Skip to main content

plotlars/plots/subplot_grid/
mod.rs

1use bon::bon;
2use plotly::{layout::Layout as LayoutPlotly, Trace};
3use serde::ser::{Serialize, SerializeStruct, Serializer};
4use serde_json::Value;
5
6use crate::common::{Layout, PlotHelper, Polar};
7use crate::components::{Dimensions, Text};
8
9mod custom_legend;
10mod irregular;
11mod regular;
12mod shared;
13
14/// A structure representing a subplot grid layout.
15///
16/// The `SubplotGrid` struct facilitates the creation of multi-plot layouts arranged in a grid configuration.
17/// Plots are automatically arranged in rows and columns in row-major order (left-to-right, top-to-bottom).
18/// Each subplot retains its own title, axis labels, and legend, providing flexibility for displaying
19/// multiple related visualizations in a single figure.
20///
21/// # Features
22///
23/// - Automatic grid layout with configurable rows and columns
24/// - Individual subplot titles (extracted from plot titles)
25/// - Independent axis labels for each subplot
26/// - Configurable horizontal and vertical spacing
27/// - Overall figure title
28/// - Sparse grid support (fewer plots than grid capacity)
29///
30#[derive(Clone)]
31pub struct SubplotGrid {
32    traces: Vec<Box<dyn Trace + 'static>>,
33    layout: LayoutPlotly,
34    layout_json: Option<Value>,
35}
36
37impl Serialize for SubplotGrid {
38    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
39    where
40        S: Serializer,
41    {
42        let mut state = serializer.serialize_struct("SubplotGrid", 2)?;
43        state.serialize_field("traces", &self.traces)?;
44
45        if let Some(ref layout_json) = self.layout_json {
46            state.serialize_field("layout", layout_json)?;
47        } else {
48            state.serialize_field("layout", &self.layout)?;
49        }
50
51        state.end()
52    }
53}
54
55#[bon]
56impl SubplotGrid {
57    /// Creates a subplot grid layout.
58    ///
59    /// Arranges plots in a row × column grid with automatic positioning. Plots are placed
60    /// in row-major order (left-to-right, top-to-bottom). Each subplot retains its individual title
61    /// (from the plot's `plot_title`), axis labels, and legend.
62    ///
63    /// # Arguments
64    ///
65    /// * `plots` - Vector of plot references to arrange in the grid. Plots are positioned in row-major order.
66    /// * `rows` - Number of rows in the grid (default: 1).
67    /// * `cols` - Number of columns in the grid (default: 1).
68    /// * `title` - Overall title for the entire subplot figure (optional).
69    /// * `h_gap` - Horizontal spacing between columns (range: 0.0 to 1.0, default: 0.1).
70    /// * `v_gap` - Vertical spacing between rows (range: 0.0 to 1.0, default: 0.1).
71    ///
72    /// # Example
73    ///
74    /// ```rust
75    /// use plotlars::{
76    ///     Axis, BarPlot, BoxPlot, Legend, Line, Orientation, Plot, Rgb, ScatterPlot, Shape,
77    ///     SubplotGrid, Text, TickDirection, TimeSeriesPlot,
78    /// };
79    /// use polars::prelude::*;
80    ///
81    /// let dataset1 = LazyCsvReader::new(PlRefPath::new("data/animal_statistics.csv"))
82    ///     .finish()
83    ///     .unwrap()
84    ///     .collect()
85    ///     .unwrap();
86    ///
87    /// let plot1 = BarPlot::builder()
88    ///     .data(&dataset1)
89    ///     .labels("animal")
90    ///     .values("value")
91    ///     .orientation(Orientation::Vertical)
92    ///     .group("gender")
93    ///     .sort_groups_by(|a, b| a.len().cmp(&b.len()))
94    ///     .error("error")
95    ///     .colors(vec![Rgb(255, 127, 80), Rgb(64, 224, 208)])
96    ///     .plot_title(Text::from("Bar Plot").x(-0.05).y(1.35).size(14))
97    ///     .y_title(Text::from("value").x(-0.055).y(0.76))
98    ///     .x_title(Text::from("animal").x(0.97).y(-0.2))
99    ///     .legend(
100    ///         &Legend::new()
101    ///             .orientation(Orientation::Horizontal)
102    ///             .x(0.4)
103    ///             .y(1.2),
104    ///     )
105    ///     .build();
106    ///
107    /// let dataset2 = LazyCsvReader::new(PlRefPath::new("data/penguins.csv"))
108    ///     .finish()
109    ///     .unwrap()
110    ///     .select([
111    ///         col("species"),
112    ///         col("sex").alias("gender"),
113    ///         col("flipper_length_mm").cast(DataType::Int16),
114    ///         col("body_mass_g").cast(DataType::Int16),
115    ///     ])
116    ///     .collect()
117    ///     .unwrap();
118    ///
119    /// let axis = Axis::new()
120    ///     .show_line(true)
121    ///     .tick_direction(TickDirection::OutSide)
122    ///     .value_thousands(true);
123    ///
124    /// let plot2 = ScatterPlot::builder()
125    ///     .data(&dataset2)
126    ///     .x("body_mass_g")
127    ///     .y("flipper_length_mm")
128    ///     .group("species")
129    ///     .sort_groups_by(|a, b| {
130    ///         if a.len() == b.len() {
131    ///             a.cmp(b)
132    ///         } else {
133    ///             a.len().cmp(&b.len())
134    ///         }
135    ///     })
136    ///     .opacity(0.5)
137    ///     .size(12)
138    ///     .colors(vec![Rgb(178, 34, 34), Rgb(65, 105, 225), Rgb(255, 140, 0)])
139    ///     .shapes(vec![Shape::Circle, Shape::Square, Shape::Diamond])
140    ///     .plot_title(Text::from("Scatter Plot").x(-0.075).y(1.35).size(14))
141    ///     .x_title(Text::from("body mass (g)").y(-0.4))
142    ///     .y_title(Text::from("flipper length (mm)").x(-0.078).y(0.5))
143    ///     .legend_title("species")
144    ///     .x_axis(&axis.clone().value_range(vec![2500.0, 6500.0]))
145    ///     .y_axis(&axis.clone().value_range(vec![170.0, 240.0]))
146    ///     .legend(&Legend::new().x(0.98).y(0.95))
147    ///     .build();
148    ///
149    /// let dataset3 = LazyCsvReader::new(PlRefPath::new("data/debilt_2023_temps.csv"))
150    ///     .with_has_header(true)
151    ///     .with_try_parse_dates(true)
152    ///     .finish()
153    ///     .unwrap()
154    ///     .with_columns(vec![
155    ///         (col("tavg") / lit(10)).alias("avg"),
156    ///         (col("tmin") / lit(10)).alias("min"),
157    ///         (col("tmax") / lit(10)).alias("max"),
158    ///     ])
159    ///     .collect()
160    ///     .unwrap();
161    ///
162    /// let plot3 = TimeSeriesPlot::builder()
163    ///     .data(&dataset3)
164    ///     .x("date")
165    ///     .y("avg")
166    ///     .additional_series(vec!["min", "max"])
167    ///     .colors(vec![Rgb(128, 128, 128), Rgb(0, 122, 255), Rgb(255, 128, 0)])
168    ///     .lines(vec![Line::Solid, Line::Dot, Line::Dot])
169    ///     .plot_title(Text::from("Time Series Plot").x(-0.05).y(1.35).size(14))
170    ///     .y_title(Text::from("temperature (ºC)").x(-0.055).y(0.6))
171    ///     .legend(&Legend::new().x(0.9).y(1.25))
172    ///     .build();
173    ///
174    /// let plot4 = BoxPlot::builder()
175    ///     .data(&dataset2)
176    ///     .labels("species")
177    ///     .values("body_mass_g")
178    ///     .orientation(Orientation::Vertical)
179    ///     .group("gender")
180    ///     .box_points(true)
181    ///     .point_offset(-1.5)
182    ///     .jitter(0.01)
183    ///     .opacity(0.1)
184    ///     .colors(vec![Rgb(0, 191, 255), Rgb(57, 255, 20), Rgb(255, 105, 180)])
185    ///     .plot_title(Text::from("Box Plot").x(-0.075).y(1.35).size(14))
186    ///     .x_title(Text::from("species").y(-0.3))
187    ///     .y_title(Text::from("body mass (g)").x(-0.08).y(0.5))
188    ///     .legend_title(Text::from("gender").size(12))
189    ///     .y_axis(&Axis::new().value_thousands(true))
190    ///     .legend(&Legend::new().x(1.0))
191    ///     .build();
192    ///
193    /// SubplotGrid::regular()
194    ///     .plots(vec![&plot1, &plot2, &plot3, &plot4])
195    ///     .rows(2)
196    ///     .cols(2)
197    ///     .v_gap(0.4)
198    ///     .title(
199    ///         Text::from("Regular Subplot Grid")
200    ///             .size(16)
201    ///             .font("Arial bold")
202    ///             .y(0.95),
203    ///     )
204    ///     .build()
205    ///     .plot();
206    /// ```
207    ///
208    /// ![Example](https://imgur.com/q0K7cyP.png)
209    #[builder(on(String, into), on(Text, into), finish_fn = build)]
210    pub fn regular(
211        plots: Vec<&dyn PlotHelper>,
212        rows: Option<usize>,
213        cols: Option<usize>,
214        title: Option<Text>,
215        h_gap: Option<f64>,
216        v_gap: Option<f64>,
217        dimensions: Option<&Dimensions>,
218    ) -> Self {
219        regular::build_regular(plots, rows, cols, title, h_gap, v_gap, None, dimensions)
220    }
221
222    /// Creates an irregular grid subplot layout with custom row/column spanning.
223    ///
224    /// Allows plots to span multiple rows and/or columns, enabling dashboard-style
225    /// layouts and asymmetric grid arrangements. Each plot explicitly specifies its
226    /// position and span.
227    ///
228    /// # Arguments
229    ///
230    /// * `plots` - Vector of tuples `(plot, row, col, rowspan, colspan)` where:
231    ///   - `plot`: Reference to the plot
232    ///   - `row`: Starting row (0-indexed)
233    ///   - `col`: Starting column (0-indexed)
234    ///   - `rowspan`: Number of rows to span (minimum 1)
235    ///   - `colspan`: Number of columns to span (minimum 1)
236    /// * `rows` - Total number of rows in the grid (default: 1).
237    /// * `cols` - Total number of columns in the grid (default: 1).
238    /// * `title` - Overall title for the subplot figure (optional).
239    /// * `h_gap` - Horizontal spacing between columns (range: 0.0 to 1.0, default: 0.1).
240    /// * `v_gap` - Vertical spacing between rows (range: 0.0 to 1.0, default: 0.1).
241    ///
242    /// # Example
243    ///
244    /// ```rust
245    /// use plotlars::{
246    ///     Axis, CandlestickPlot, ColorBar, Direction, HeatMap, Histogram, Legend, Palette, Plot,
247    ///     Rgb, SubplotGrid, Text, TickDirection, ValueExponent,
248    /// };
249    /// use polars::prelude::*;
250    ///
251    /// let dataset1 = LazyCsvReader::new(PlRefPath::new("data/penguins.csv"))
252    ///     .finish()
253    ///     .unwrap()
254    ///     .select([
255    ///         col("species"),
256    ///         col("sex").alias("gender"),
257    ///         col("flipper_length_mm").cast(DataType::Int16),
258    ///         col("body_mass_g").cast(DataType::Int16),
259    ///     ])
260    ///     .collect()
261    ///     .unwrap();
262    ///
263    /// let axis = Axis::new()
264    ///     .show_line(true)
265    ///     .show_grid(true)
266    ///     .value_thousands(true)
267    ///     .tick_direction(TickDirection::OutSide);
268    ///
269    /// let plot1 = Histogram::builder()
270    ///     .data(&dataset1)
271    ///     .x("body_mass_g")
272    ///     .group("species")
273    ///     .opacity(0.5)
274    ///     .colors(vec![Rgb(255, 165, 0), Rgb(147, 112, 219), Rgb(46, 139, 87)])
275    ///     .plot_title(Text::from("Histogram").x(0.0).y(1.35).size(14))
276    ///     .x_title(Text::from("body mass (g)").x(0.94).y(-0.35))
277    ///     .y_title(Text::from("count").x(-0.062).y(0.83))
278    ///     .x_axis(&axis)
279    ///     .y_axis(&axis)
280    ///     .legend_title(Text::from("species"))
281    ///     .legend(&Legend::new().x(0.87).y(1.2))
282    ///     .build();
283    ///
284    /// let dataset2 = LazyCsvReader::new(PlRefPath::new("data/stock_prices.csv"))
285    ///     .finish()
286    ///     .unwrap()
287    ///     .collect()
288    ///     .unwrap();
289    ///
290    /// let increasing = Direction::new()
291    ///     .line_color(Rgb(0, 200, 100))
292    ///     .line_width(0.5);
293    ///
294    /// let decreasing = Direction::new()
295    ///     .line_color(Rgb(200, 50, 50))
296    ///     .line_width(0.5);
297    ///
298    /// let plot2 = CandlestickPlot::builder()
299    ///     .data(&dataset2)
300    ///     .dates("date")
301    ///     .open("open")
302    ///     .high("high")
303    ///     .low("low")
304    ///     .close("close")
305    ///     .increasing(&increasing)
306    ///     .decreasing(&decreasing)
307    ///     .whisker_width(0.1)
308    ///     .plot_title(Text::from("Candlestick").x(0.0).y(1.35).size(14))
309    ///     .y_title(Text::from("Price ($)").x(-0.06).y(0.76))
310    ///     .y_axis(&Axis::new().show_axis(true).show_grid(true))
311    ///     .build();
312    ///
313    /// let dataset3 = LazyCsvReader::new(PlRefPath::new("data/heatmap.csv"))
314    ///     .finish()
315    ///     .unwrap()
316    ///     .collect()
317    ///     .unwrap();
318    ///
319    /// let plot3 = HeatMap::builder()
320    ///     .data(&dataset3)
321    ///     .x("x")
322    ///     .y("y")
323    ///     .z("z")
324    ///     .color_bar(
325    ///         &ColorBar::new()
326    ///             .value_exponent(ValueExponent::None)
327    ///             .separate_thousands(true)
328    ///             .tick_length(5)
329    ///             .tick_step(5000.0),
330    ///     )
331    ///     .plot_title(Text::from("Heat Map").x(0.0).y(1.35).size(14))
332    ///     .color_scale(Palette::Viridis)
333    ///     .build();
334    ///
335    /// SubplotGrid::irregular()
336    ///     .plots(vec![
337    ///         (&plot1, 0, 0, 1, 1),
338    ///         (&plot2, 0, 1, 1, 1),
339    ///         (&plot3, 1, 0, 1, 2),
340    ///     ])
341    ///     .rows(2)
342    ///     .cols(2)
343    ///     .v_gap(0.35)
344    ///     .h_gap(0.05)
345    ///     .title(
346    ///         Text::from("Irregular Subplot Grid")
347    ///             .size(16)
348    ///             .font("Arial bold")
349    ///             .y(0.95),
350    ///     )
351    ///     .build()
352    ///     .plot();
353    /// ```
354    ///
355    /// ![Example](https://imgur.com/RvZwv3O.png)
356    #[builder(on(String, into), on(Text, into), finish_fn = build)]
357    pub fn irregular(
358        plots: Vec<(&dyn PlotHelper, usize, usize, usize, usize)>,
359        rows: Option<usize>,
360        cols: Option<usize>,
361        title: Option<Text>,
362        h_gap: Option<f64>,
363        v_gap: Option<f64>,
364        dimensions: Option<&Dimensions>,
365    ) -> Self {
366        irregular::build_irregular(plots, rows, cols, title, h_gap, v_gap, dimensions)
367    }
368}
369
370impl Layout for SubplotGrid {}
371impl Polar for SubplotGrid {}
372
373#[doc(hidden)]
374impl PlotHelper for SubplotGrid {
375    fn get_layout(&self) -> &LayoutPlotly {
376        &self.layout
377    }
378
379    fn get_traces(&self) -> &Vec<Box<dyn Trace + 'static>> {
380        &self.traces
381    }
382
383    fn get_layout_override(&self) -> Option<&Value> {
384        self.layout_json.as_ref()
385    }
386}