use plotly::{
layout::{Annotation, Axis as AxisPlotly, GridPattern, Layout as LayoutPlotly, LayoutGrid},
Trace,
};
use serde_json::Value;
use crate::converters::components as conv;
use plotlars_core::components::{Dimensions, Text};
use plotlars_core::Plot;
use super::custom_legend::CustomLegend;
use super::shared::{
adjust_domain_for_type, calculate_spanning_domain, detect_plot_type, determine_bar_mode,
determine_box_mode, extract_axis_title_from_annotations, inject_non_cartesian_domains,
AxisConfig, JsonTrace, NonCartesianLayout, PlotType,
};
use super::SubplotGrid;
/// Geometry of a regular subplot grid, shared by layout construction and
/// colorbar placement.
struct GridConfig {
    /// Number of grid rows.
    rows: usize,
    /// Number of grid columns.
    cols: usize,
    /// Horizontal gap between adjacent columns. Forwarded to the Plotly grid as
    /// `x_gap` and subtracted from a total width of `1.0` when cell widths are
    /// computed.
    h_gap: f64,
    /// Vertical gap between adjacent rows. Forwarded to the Plotly grid as
    /// `y_gap` and subtracted from a total height of `1.0` when cell heights
    /// are computed.
    v_gap: f64,
}
/// Rebuilds a typed Plotly axis from the serialized axis settings of one plot.
///
/// Only the keys `showline`, `showgrid`, `zeroline`, `range`,
/// `separatethousands` and `ticks` are carried over; every other key in the
/// JSON is ignored. A `range` is applied only when it holds exactly two
/// numeric entries. Any `ticks` value other than `"inside"` maps to
/// `TicksDirection::Outside`.
///
/// Returns `None` when `config.axis_json` is not a JSON object.
fn build_axis_from_config(config: &AxisConfig) -> Option<AxisPlotly> {
    use plotly::layout::TicksDirection;

    let fields = config.axis_json.as_object()?;
    // Shared lookup for the boolean-valued axis settings.
    let flag = |key: &str| fields.get(key).and_then(|value| value.as_bool());

    let mut axis = AxisPlotly::new();

    if let Some(visible) = flag("showline") {
        axis = axis.show_line(visible);
    }
    if let Some(visible) = flag("showgrid") {
        axis = axis.show_grid(visible);
    }
    if let Some(visible) = flag("zeroline") {
        axis = axis.zero_line(visible);
    }

    let bounds = fields
        .get("range")
        .and_then(|value| value.as_array())
        .filter(|items| items.len() == 2)
        .and_then(|items| Some((items[0].as_f64()?, items[1].as_f64()?)));
    if let Some((min, max)) = bounds {
        axis = axis.range(vec![min, max]);
    }

    if let Some(enabled) = flag("separatethousands") {
        axis = axis.separate_thousands(enabled);
    }

    if let Some(placement) = fields.get("ticks").and_then(|value| value.as_str()) {
        let direction = if placement == "inside" {
            TicksDirection::Inside
        } else {
            TicksDirection::Outside
        };
        axis = axis.ticks(direction);
    }

    Some(axis)
}
/// Checks that a regular grid request is usable before any plot is converted.
///
/// # Panics
///
/// Panics with a descriptive, user-facing message when:
/// - `n_plots` is zero,
/// - `rows` is zero,
/// - `cols` is zero,
/// - `n_plots` exceeds the number of grid cells (`rows * cols`).
fn validate_regular_grid(n_plots: usize, rows: usize, cols: usize) {
    if n_plots == 0 {
        panic!(
            "SubplotGrid validation error: plots vector cannot be empty.\n\
            \n\
            Problem: You provided an empty plots vector.\n\
            Solution: Create at least one plot and add it to the plots vector.\n\
            \n\
            Example:\n\
            let plot1 = ScatterPlot::builder().data(&df).x(\"x\").y(\"y\").build();\n\
            SubplotGrid::regular().plots(vec![&plot1])\n\
            .build();"
        );
    }
    if rows == 0 {
        panic!(
            "SubplotGrid validation error: rows must be greater than 0.\n\
            \n\
            Problem: You specified rows = 0, but rows must be at least 1.\n\
            Solution: Set rows to a positive integer (e.g., 1, 2, or 3).\n\
            \n\
            Example:\n\
            SubplotGrid::regular()\n\
            .plots(vec![&plot1])\n\
            .rows(2) // Use positive integer\n\
            .cols(2)\n\
            .build();"
        );
    }
    if cols == 0 {
        panic!(
            "SubplotGrid validation error: cols must be greater than 0.\n\
            \n\
            Problem: You specified cols = 0, but cols must be at least 1.\n\
            Solution: Set cols to a positive integer (e.g., 1, 2, or 3).\n\
            \n\
            Example:\n\
            SubplotGrid::regular()\n\
            .plots(vec![&plot1])\n\
            .rows(2)\n\
            .cols(2) // Use positive integer\n\
            .build();"
        );
    }

    // Saturate instead of overflowing: a grid whose cell count does not fit in
    // `usize` can hold any number of plots, so it must never be rejected (or
    // abort with an unrelated arithmetic panic) because of the multiplication.
    let grid_capacity = rows.saturating_mul(cols);
    if n_plots > grid_capacity {
        // Grid shapes offered as examples in the message below.
        let first_rows = (n_plots as f64).sqrt().ceil() as usize;
        let first_cols = ((n_plots as f64) / 2.0).ceil() as usize;
        let second_rows = ((n_plots + 1) as f64).sqrt().ceil() as usize;
        let second_cols = ((n_plots + 1) as f64 / 2.0).ceil() as usize;
        panic!(
            "SubplotGrid validation error: too many plots for grid size.\n\
            \n\
            Problem: You provided {} plot(s) but the grid only has {} cells ({}x{} = {}).\n\
            Solution: Either reduce the number of plots or increase the grid size.\n\
            \n\
            Option 1 - Reduce plots:\n\
            Use {} plots instead of {}\n\
            \n\
            Option 2 - Increase grid size:\n\
            Example calculations:\n\
            - For {} plots: {}x{} grid works\n\
            - For {} plots: {}x{} grid works",
            n_plots,
            grid_capacity,
            rows,
            cols,
            grid_capacity,
            grid_capacity,
            n_plots,
            n_plots,
            first_rows,
            first_cols,
            n_plots,
            second_rows,
            second_cols
        );
    }
}
/// Converts one plot's intermediate representation into Plotly traces plus its
/// layout serialized as JSON.
///
/// The layout JSON falls back to `Value::Null` when serialization fails. The
/// second value returned by the layout converter is not used here.
fn convert_plot(plot: &dyn Plot) -> (Vec<Box<dyn Trace + 'static>>, Value) {
    let mut traces: Vec<Box<dyn Trace + 'static>> = Vec::new();
    for ir_trace in plot.ir_traces().iter() {
        traces.push(crate::converters::trace::convert(ir_trace));
    }

    let (layout, _layout_overrides) =
        crate::converters::layout::convert_layout_ir(plot.ir_layout());
    let layout_json = match serde_json::to_value(&layout) {
        Ok(json) => json,
        Err(_) => Value::Null,
    };

    (traces, layout_json)
}
#[allow(clippy::too_many_arguments)]
pub(super) fn build_regular(
plots: Vec<&dyn Plot>,
rows: Option<usize>,
cols: Option<usize>,
title: Option<Text>,
h_gap: Option<f64>,
v_gap: Option<f64>,
legends: Option<Vec<Option<&CustomLegend>>>,
dimensions: Option<&Dimensions>,
) -> SubplotGrid {
let rows = rows.unwrap_or(1);
let cols = cols.unwrap_or(1);
let h_gap = h_gap.unwrap_or(0.1);
let v_gap = v_gap.unwrap_or(0.1);
validate_regular_grid(plots.len(), rows, cols);
let mut all_traces: Vec<Box<dyn Trace + 'static>> = Vec::new();
let mut plot_titles: Vec<Option<Text>> = Vec::new();
let mut axis_configs: Vec<(AxisConfig, AxisConfig)> = Vec::new();
let mut subplot_info: Vec<NonCartesianLayout> = Vec::new();
let mut legend_sources: Vec<Vec<JsonTrace>> = Vec::new();
let mut per_plot_traces: Vec<Vec<Box<dyn Trace + 'static>>> = Vec::new();
let mut scene_count = 0;
let mut polar_count = 0;
let mut mapbox_count = 0;
let mut geo_count = 0;
for (plot_idx, plot) in plots.iter().enumerate() {
let (traces, layout_json) = convert_plot(*plot);
let plot_type = detect_plot_type(traces[0].as_ref());
plot_titles.push(plot.ir_layout().title.clone());
let x_axis_json = layout_json.get("xaxis").cloned().unwrap_or(Value::Null);
let y_axis_json = layout_json.get("yaxis").cloned().unwrap_or(Value::Null);
let layout_fragment = match plot_type {
PlotType::Cartesian3D => layout_json.get("scene").cloned(),
PlotType::Polar => layout_json.get("polar").cloned(),
PlotType::Mapbox => layout_json.get("mapbox").cloned(),
PlotType::Geo => layout_json.get("geo").cloned(),
_ => None,
};
let x_title = plot
.ir_layout()
.x_title
.clone()
.or_else(|| extract_axis_title_from_annotations(&layout_json, true));
let y_title = plot
.ir_layout()
.y_title
.clone()
.or_else(|| extract_axis_title_from_annotations(&layout_json, false));
let x_config = AxisConfig {
title: x_title,
axis_json: x_axis_json,
};
let y_config = AxisConfig {
title: y_title,
axis_json: y_axis_json,
};
axis_configs.push((x_config, y_config));
let x_axis = if plot_idx == 0 {
"x".to_string()
} else {
format!("x{}", plot_idx + 1)
};
let y_axis = if plot_idx == 0 {
"y".to_string()
} else {
format!("y{}", plot_idx + 1)
};
let (row, col) = (plot_idx / cols, plot_idx % cols);
let (x_start, x_end, y_start, y_end) =
calculate_spanning_domain(row, col, 1, 1, rows, cols, h_gap, v_gap);
let (domain_x, domain_y) =
adjust_domain_for_type(plot_type.clone(), x_start, x_end, y_start, y_end);
let subplot_ref = match plot_type {
PlotType::Cartesian3D => {
let name = if scene_count == 0 {
"scene".to_string()
} else {
format!("scene{}", scene_count + 1)
};
scene_count += 1;
name
}
PlotType::Polar => {
let name = if polar_count == 0 {
"polar".to_string()
} else {
format!("polar{}", polar_count + 1)
};
polar_count += 1;
name
}
PlotType::Mapbox => {
let name = if mapbox_count == 0 {
"mapbox".to_string()
} else {
format!("mapbox{}", mapbox_count + 1)
};
mapbox_count += 1;
name
}
PlotType::Geo => {
let name = if geo_count == 0 {
"geo".to_string()
} else {
format!("geo{}", geo_count + 1)
};
geo_count += 1;
name
}
PlotType::Cartesian2D | PlotType::Domain => String::new(),
};
subplot_info.push(NonCartesianLayout {
plot_type: plot_type.clone(),
domain_x,
domain_y,
layout_fragment,
subplot_ref: subplot_ref.clone(),
});
let mut legend_traces: Vec<JsonTrace> = Vec::new();
for (trace_idx, trace) in traces.iter().enumerate() {
let mut json_trace = JsonTrace::new(trace.clone());
match plot_type {
PlotType::Cartesian2D => json_trace.set_axis_references(&x_axis, &y_axis),
PlotType::Cartesian3D => json_trace.set_scene_reference(&subplot_ref),
PlotType::Polar => json_trace.set_subplot_reference(&subplot_ref),
PlotType::Domain => json_trace.set_domain(domain_x, domain_y),
PlotType::Mapbox => json_trace.set_subplot_reference(&subplot_ref),
PlotType::Geo => json_trace.set_subplot_reference(&subplot_ref),
}
json_trace.ensure_color(trace_idx);
legend_traces.push(json_trace.clone());
all_traces.push(Box::new(json_trace));
}
legend_sources.push(legend_traces);
per_plot_traces.push(traces);
}
let grid_config = GridConfig {
rows,
cols,
h_gap,
v_gap,
};
let owned_legends: Option<Vec<Option<CustomLegend>>> =
legends.map(|vec| vec.iter().map(|opt| opt.cloned()).collect());
let auto_legends_owned: Option<Vec<Option<CustomLegend>>> =
owned_legends.clone().or_else(|| {
let generated: Vec<Option<CustomLegend>> = plots
.iter()
.map(|plot| CustomLegend::from_ir(plot.ir_traces(), plot.ir_layout()))
.collect();
Some(generated)
});
let (layout, layout_json) = create_regular_layout(
&grid_config,
title,
&plot_titles,
&axis_configs,
auto_legends_owned,
&per_plot_traces,
&subplot_info,
dimensions,
);
scale_colorbars_for_regular_grid(&mut all_traces, &per_plot_traces, &grid_config);
SubplotGrid {
traces: all_traces,
layout,
layout_json: Some(layout_json),
}
}
#[allow(clippy::too_many_arguments)]
fn create_regular_layout(
grid_config: &GridConfig,
plot_title: Option<Text>,
subplot_titles: &[Option<Text>],
axis_configs: &[(AxisConfig, AxisConfig)],
legends: Option<Vec<Option<CustomLegend>>>,
per_plot_traces: &[Vec<Box<dyn Trace + 'static>>],
subplot_info: &[NonCartesianLayout],
dimensions: Option<&Dimensions>,
) -> (LayoutPlotly, Value) {
let mut layout = LayoutPlotly::new().show_legend(false);
if let Some(bar_mode) = determine_bar_mode(per_plot_traces) {
layout = layout.bar_mode(bar_mode);
}
if let Some(box_mode) = determine_box_mode(per_plot_traces) {
layout = layout.box_mode(box_mode);
}
if let Some(title) = plot_title {
layout = layout.title(conv::convert_text_to_title(
&title.with_plot_title_defaults(),
));
}
let grid = LayoutGrid::new()
.rows(grid_config.rows)
.columns(grid_config.cols)
.pattern(GridPattern::Independent)
.x_gap(grid_config.h_gap)
.y_gap(grid_config.v_gap);
layout = layout.grid(grid);
for (idx, (x_config, y_config)) in axis_configs.iter().enumerate() {
let is_cartesian = subplot_info
.get(idx)
.map(|info| matches!(info.plot_type, PlotType::Cartesian2D))
.unwrap_or(true);
if !is_cartesian {
continue;
}
if let Some(x_axis) = build_axis_from_config(x_config) {
layout = match idx {
0 => layout.x_axis(x_axis),
1 => layout.x_axis2(x_axis),
2 => layout.x_axis3(x_axis),
3 => layout.x_axis4(x_axis),
4 => layout.x_axis5(x_axis),
5 => layout.x_axis6(x_axis),
6 => layout.x_axis7(x_axis),
7 => layout.x_axis8(x_axis),
_ => layout,
};
}
if let Some(y_axis) = build_axis_from_config(y_config) {
layout = match idx {
0 => layout.y_axis(y_axis),
1 => layout.y_axis2(y_axis),
2 => layout.y_axis3(y_axis),
3 => layout.y_axis4(y_axis),
4 => layout.y_axis5(y_axis),
5 => layout.y_axis6(y_axis),
6 => layout.y_axis7(y_axis),
7 => layout.y_axis8(y_axis),
_ => layout,
};
}
}
let mut annotations = Vec::new();
for (idx, (x_config, y_config)) in axis_configs.iter().enumerate() {
let is_cartesian = subplot_info
.get(idx)
.map(|info| matches!(info.plot_type, PlotType::Cartesian2D))
.unwrap_or(true);
if !is_cartesian {
continue;
}
if let Some(ref x_title) = x_config.title {
let axis_ref = if idx == 0 {
"x".to_string()
} else {
format!("x{}", idx + 1)
};
let x_title_with_defaults = x_title.clone().with_x_title_defaults();
let annotation = conv::convert_text_to_axis_annotation(
&x_title_with_defaults,
true,
&axis_ref,
true,
);
annotations.push(annotation);
}
if let Some(ref y_title) = y_config.title {
let axis_ref = if idx == 0 {
"y".to_string()
} else {
format!("y{}", idx + 1)
};
let y_title_with_defaults = y_title.clone().with_y_title_defaults();
let annotation = conv::convert_text_to_axis_annotation(
&y_title_with_defaults,
false,
&axis_ref,
true,
);
annotations.push(annotation);
}
}
for (idx, title_opt) in subplot_titles.iter().enumerate() {
if let Some(title_text) = title_opt {
let title = title_text.clone().with_subplot_title_defaults();
if let Some(info) = subplot_info.get(idx) {
if matches!(info.plot_type, PlotType::Cartesian2D) {
let x_ref = if idx == 0 {
"x domain".to_string()
} else {
format!("x{} domain", idx + 1)
};
let y_ref = if idx == 0 {
"y domain".to_string()
} else {
format!("y{} domain", idx + 1)
};
let ann = Annotation::new()
.text(&title.content)
.font(conv::convert_text_to_font(&title))
.x_ref(&x_ref)
.y_ref(&y_ref)
.x(title.x)
.y(title.y)
.show_arrow(false);
annotations.push(ann);
} else {
let width = info.domain_x[1] - info.domain_x[0];
let height = info.domain_y[1] - info.domain_y[0];
let x_pos = info.domain_x[0] + width * title.x;
let y_pos = if matches!(info.plot_type, PlotType::Polar)
&& !title_text.has_custom_position()
{
info.domain_y[1] + height * 0.20
} else {
info.domain_y[0] + height * title.y
};
annotations.push(
Annotation::new()
.text(&title.content)
.font(conv::convert_text_to_font(&title))
.x_ref("paper")
.y_ref("paper")
.x(x_pos)
.y(y_pos)
.show_arrow(false),
);
}
}
}
}
let get_domain = |idx: usize| {
subplot_info.get(idx).and_then(|info| {
if matches!(info.plot_type, PlotType::Cartesian2D) {
None
} else {
Some((info.domain_x, info.domain_y))
}
})
};
if let Some(legend_configs) = legends {
for (subplot_idx, legend_opt) in legend_configs.iter().enumerate() {
if let Some(legend) = legend_opt {
if let Some(legend_annotation) =
legend.to_annotation(subplot_idx, get_domain(subplot_idx))
{
annotations.push(legend_annotation);
}
}
}
}
if !annotations.is_empty() {
layout = layout.annotations(annotations);
}
if let Some(dims) = dimensions {
if let Some(width) = dims.width {
layout = layout.width(width);
}
if let Some(height) = dims.height {
layout = layout.height(height);
}
if let Some(auto_size) = dims.auto_size {
layout = layout.auto_size(auto_size);
}
}
let mut layout_json = serde_json::to_value(&layout).unwrap();
inject_non_cartesian_domains(&mut layout_json, subplot_info);
(layout, layout_json)
}
/// Computes the domain of one cell in a uniform grid.
///
/// Cells are numbered row by row starting at the top-left. The result is
/// `(x_start, x_end, y_start, y_end)` on a `0.0..=1.0` scale where `y`
/// increases upwards, so the first row sits at the top (`y_end == 1.0`).
/// `h_gap` / `v_gap` are the spaces left between neighbouring columns / rows.
///
/// `rows` and `cols` must be non-zero.
fn calculate_subplot_domain(
    plot_idx: usize,
    rows: usize,
    cols: usize,
    h_gap: f64,
    v_gap: f64,
) -> (f64, f64, f64, f64) {
    let (row, col) = (plot_idx / cols, plot_idx % cols);

    // Size of one cell along an axis once all gaps on that axis are removed.
    let cell_extent = |gap: f64, count: usize| (1.0 - gap * (count - 1) as f64) / count as f64;
    let width = cell_extent(h_gap, cols);
    let height = cell_extent(v_gap, rows);

    let left = col as f64 * (width + h_gap);
    let right = left + width;

    // Rows are counted from the top, while the y axis grows from the bottom.
    let top_offset = row as f64 * (height + v_gap);
    let bottom_offset = top_offset + height;

    (left, right, 1.0 - bottom_offset, 1.0 - top_offset)
}
/// Fits the top-level colorbar of every trace to the grid cell of its plot.
///
/// `all_traces` is walked in the same plot-by-plot order as `per_plot_traces`.
/// For each trace that has a `colorbar` entry, the colorbar is resized and
/// moved to paper coordinates derived from the cell domain, and the trace is
/// replaced by a `JsonTrace` built from the adjusted JSON. An empty colorbar is
/// first created for `heatmap`, `contour` and `surface` traces that have none,
/// unless `showscale` is `false`. Traces that cannot be serialized, or that
/// end up without a colorbar, are left untouched.
///
/// Existing `yanchor`, `yref`, `xref` and `xanchor` values are kept; `len`,
/// `y` and `x` are rewritten relative to the cell.
fn scale_colorbars_for_regular_grid(
    all_traces: &mut [Box<dyn Trace + 'static>],
    per_plot_traces: &[Vec<Box<dyn Trace + 'static>>],
    grid_config: &GridConfig,
) {
    /// Writes `fallback` under `key` only when the colorbar has no such key.
    fn set_if_absent(colorbar: &mut serde_json::Value, key: &str, fallback: &str) {
        if colorbar.get(key).is_none() {
            colorbar[key] = serde_json::json!(fallback);
        }
    }

    // Shared cursor over the flattened trace list; each plot consumes as many
    // slots as it has traces, or fewer when the list runs out.
    let mut slots = all_traces.iter_mut();

    for (plot_idx, traces) in per_plot_traces.iter().enumerate() {
        let col = plot_idx % grid_config.cols;
        let (x_start, x_end, y_start, y_end) = calculate_subplot_domain(
            plot_idx,
            grid_config.rows,
            grid_config.cols,
            grid_config.h_gap,
            grid_config.v_gap,
        );
        let domain_width = x_end - x_start;
        let domain_height = y_end - y_start;

        for slot in slots.by_ref().take(traces.len()) {
            let mut trace_value = match serde_json::to_value(&*slot) {
                Ok(value) => value,
                Err(_) => continue,
            };

            // Give color-mapped trace types an explicit colorbar to adjust.
            let lacks_colorbar = trace_value.get("colorbar").is_none();
            let scale_shown = trace_value
                .get("showscale")
                .and_then(|v| v.as_bool())
                .unwrap_or(true);
            let is_color_mapped = trace_value
                .get("type")
                .and_then(|v| v.as_str())
                .map_or(false, |kind| {
                    matches!(kind, "heatmap" | "contour" | "surface")
                });
            if lacks_colorbar && scale_shown && is_color_mapped {
                trace_value["colorbar"] = serde_json::json!({});
            }

            let colorbar = match trace_value.get_mut("colorbar") {
                Some(colorbar) => colorbar,
                None => continue,
            };

            // Length: shrink an oversized fractional length, or default to the
            // full cell height.
            match colorbar.get("len").and_then(|v| v.as_f64()) {
                Some(len) => {
                    let is_fraction =
                        colorbar.get("lenmode").and_then(|v| v.as_str()) == Some("fraction");
                    if is_fraction && len > domain_height {
                        colorbar["len"] = serde_json::json!(len * domain_height);
                    }
                }
                None => {
                    colorbar["len"] = serde_json::json!(domain_height);
                    colorbar["lenmode"] = serde_json::json!("fraction");
                }
            }

            // Vertical placement: the existing `y` (default 0.5) is read as a
            // position inside the cell and converted to paper coordinates.
            let user_y = colorbar.get("y").and_then(|v| v.as_f64()).unwrap_or(0.5);
            let default_yanchor = if user_y >= 0.8 {
                "top"
            } else if user_y <= 0.2 {
                "bottom"
            } else {
                "middle"
            };
            set_if_absent(colorbar, "yanchor", default_yanchor);
            set_if_absent(colorbar, "yref", "paper");
            colorbar["y"] = serde_json::json!(y_start + user_y * domain_height);

            // Horizontal placement: scale an explicit `x`; otherwise park the
            // bar just right of the cell (last column) or in the middle of the
            // gap to the next column.
            set_if_absent(colorbar, "xref", "paper");
            let paper_x = if let Some(user_x) = colorbar.get("x").and_then(|v| v.as_f64()) {
                x_start + user_x * domain_width
            } else if col == grid_config.cols - 1 {
                set_if_absent(colorbar, "xanchor", "left");
                x_end + 0.01
            } else {
                set_if_absent(colorbar, "xanchor", "center");
                x_end + (grid_config.h_gap / 2.0)
            };
            colorbar["x"] = serde_json::json!(paper_x);

            *slot = Box::new(JsonTrace::from_value(trace_value));
        }
    }
}