subplot_grid/
subplot_grid.rs

1use plotlars::{
2    Axis, BarPlot, BoxPlot, CandlestickPlot, ColorBar, Direction, HeatMap, Histogram, Legend, Line,
3    Orientation, Palette, Plot, Rgb, ScatterPlot, Shape, SubplotGrid, Text, TickDirection,
4    TimeSeriesPlot, ValueExponent,
5};
6use polars::prelude::*;
7
8fn main() {
9    regular_grid_example();
10    irregular_grid_example();
11}
12
13fn regular_grid_example() {
14    let dataset1 = LazyCsvReader::new(PlPath::new("data/animal_statistics.csv"))
15        .finish()
16        .unwrap()
17        .collect()
18        .unwrap();
19
20    let plot1 = BarPlot::builder()
21        .data(&dataset1)
22        .labels("animal")
23        .values("value")
24        .orientation(Orientation::Vertical)
25        .group("gender")
26        .sort_groups_by(|a, b| a.len().cmp(&b.len()))
27        .error("error")
28        .colors(vec![Rgb(255, 127, 80), Rgb(64, 224, 208)])
29        .plot_title(Text::from("Bar Plot").x(-0.05).y(1.35).size(14))
30        .y_title(Text::from("value").x(-0.055).y(0.76))
31        .x_title(Text::from("animal").x(0.97).y(-0.2))
32        .legend(
33            &Legend::new()
34                .orientation(Orientation::Horizontal)
35                .x(0.4)
36                .y(1.2),
37        )
38        .build();
39
40    let dataset2 = LazyCsvReader::new(PlPath::new("data/penguins.csv"))
41        .finish()
42        .unwrap()
43        .select([
44            col("species"),
45            col("sex").alias("gender"),
46            col("flipper_length_mm").cast(DataType::Int16),
47            col("body_mass_g").cast(DataType::Int16),
48        ])
49        .collect()
50        .unwrap();
51
52    let axis = Axis::new()
53        .show_line(true)
54        .tick_direction(TickDirection::OutSide)
55        .value_thousands(true);
56
57    let plot2 = ScatterPlot::builder()
58        .data(&dataset2)
59        .x("body_mass_g")
60        .y("flipper_length_mm")
61        .group("species")
62        .sort_groups_by(|a, b| {
63            if a.len() == b.len() {
64                a.cmp(b)
65            } else {
66                a.len().cmp(&b.len())
67            }
68        })
69        .opacity(0.5)
70        .size(12)
71        .colors(vec![Rgb(178, 34, 34), Rgb(65, 105, 225), Rgb(255, 140, 0)])
72        .shapes(vec![Shape::Circle, Shape::Square, Shape::Diamond])
73        .plot_title(Text::from("Scatter Plot").x(-0.075).y(1.35).size(14))
74        .x_title(Text::from("body mass (g)").y(-0.4))
75        .y_title(Text::from("flipper length (mm)").x(-0.078).y(0.5))
76        .legend_title("species")
77        .x_axis(&axis.clone().value_range(vec![2500.0, 6500.0]))
78        .y_axis(&axis.clone().value_range(vec![170.0, 240.0]))
79        .legend(&Legend::new().x(0.98).y(0.95))
80        .build();
81
82    let dataset3 = LazyCsvReader::new(PlPath::new("data/debilt_2023_temps.csv"))
83        .with_has_header(true)
84        .with_try_parse_dates(true)
85        .finish()
86        .unwrap()
87        .with_columns(vec![
88            (col("tavg") / lit(10)).alias("avg"),
89            (col("tmin") / lit(10)).alias("min"),
90            (col("tmax") / lit(10)).alias("max"),
91        ])
92        .collect()
93        .unwrap();
94
95    let plot3 = TimeSeriesPlot::builder()
96        .data(&dataset3)
97        .x("date")
98        .y("avg")
99        .additional_series(vec!["min", "max"])
100        .colors(vec![Rgb(128, 128, 128), Rgb(0, 122, 255), Rgb(255, 128, 0)])
101        .lines(vec![Line::Solid, Line::Dot, Line::Dot])
102        .plot_title(Text::from("Time Series Plot").x(-0.05).y(1.35).size(14))
103        .y_title(Text::from("temperature (ÂșC)").x(-0.055).y(0.6))
104        .legend(&Legend::new().x(0.9).y(1.25))
105        .build();
106
107    let plot4 = BoxPlot::builder()
108        .data(&dataset2)
109        .labels("species")
110        .values("body_mass_g")
111        .orientation(Orientation::Vertical)
112        .group("gender")
113        .box_points(true)
114        .point_offset(-1.5)
115        .jitter(0.01)
116        .opacity(0.1)
117        .colors(vec![Rgb(0, 191, 255), Rgb(57, 255, 20), Rgb(255, 105, 180)])
118        .plot_title(Text::from("Box Plot").x(-0.075).y(1.35).size(14))
119        .x_title(Text::from("species").y(-0.3))
120        .y_title(Text::from("body mass (g)").x(-0.08).y(0.5))
121        .legend_title(Text::from("gender").size(12))
122        .y_axis(&Axis::new().value_thousands(true))
123        .legend(&Legend::new().x(1.0))
124        .build();
125
126    SubplotGrid::regular()
127        .plots(vec![&plot1, &plot2, &plot3, &plot4])
128        .rows(2)
129        .cols(2)
130        .v_gap(0.4)
131        .title(
132            Text::from("Regular Subplot Grid")
133                .size(16)
134                .font("Arial bold")
135                .y(0.95),
136        )
137        .build()
138        .plot();
139}
140
141fn irregular_grid_example() {
142    let dataset1 = LazyCsvReader::new(PlPath::new("data/penguins.csv"))
143        .finish()
144        .unwrap()
145        .select([
146            col("species"),
147            col("sex").alias("gender"),
148            col("flipper_length_mm").cast(DataType::Int16),
149            col("body_mass_g").cast(DataType::Int16),
150        ])
151        .collect()
152        .unwrap();
153
154    let axis = Axis::new()
155        .show_line(true)
156        .show_grid(true)
157        .value_thousands(true)
158        .tick_direction(TickDirection::OutSide);
159
160    let plot1 = Histogram::builder()
161        .data(&dataset1)
162        .x("body_mass_g")
163        .group("species")
164        .opacity(0.5)
165        .colors(vec![Rgb(255, 165, 0), Rgb(147, 112, 219), Rgb(46, 139, 87)])
166        .plot_title(Text::from("Histogram").x(0.0).y(1.35).size(14))
167        .x_title(Text::from("body mass (g)").x(0.94).y(-0.35))
168        .y_title(Text::from("count").x(-0.062).y(0.83))
169        .x_axis(&axis)
170        .y_axis(&axis)
171        .legend_title(Text::from("species"))
172        .legend(&Legend::new().x(0.87).y(1.2))
173        .build();
174
175    let dataset2 = LazyCsvReader::new(PlPath::new("data/stock_prices.csv"))
176        .finish()
177        .unwrap()
178        .collect()
179        .unwrap();
180
181    let increasing = Direction::new()
182        .line_color(Rgb(0, 200, 100))
183        .line_width(0.5);
184
185    let decreasing = Direction::new()
186        .line_color(Rgb(200, 50, 50))
187        .line_width(0.5);
188
189    let plot2 = CandlestickPlot::builder()
190        .data(&dataset2)
191        .dates("date")
192        .open("open")
193        .high("high")
194        .low("low")
195        .close("close")
196        .increasing(&increasing)
197        .decreasing(&decreasing)
198        .whisker_width(0.1)
199        .plot_title(Text::from("Candlestick").x(0.0).y(1.35).size(14))
200        .y_title(Text::from("price ($)").x(-0.06).y(0.76))
201        .y_axis(&Axis::new().show_axis(true).show_grid(true))
202        .build();
203
204    let dataset3 = LazyCsvReader::new(PlPath::new("data/heatmap.csv"))
205        .finish()
206        .unwrap()
207        .collect()
208        .unwrap();
209
210    let plot3 = HeatMap::builder()
211        .data(&dataset3)
212        .x("x")
213        .y("y")
214        .z("z")
215        .color_bar(
216            &ColorBar::new()
217                .value_exponent(ValueExponent::None)
218                .separate_thousands(true)
219                .tick_length(5)
220                .tick_step(5000.0),
221        )
222        .plot_title(Text::from("Heat Map").x(0.0).y(1.35).size(14))
223        .color_scale(Palette::Viridis)
224        .build();
225
226    SubplotGrid::irregular()
227        .plots(vec![
228            (&plot1, 0, 0, 1, 1),
229            (&plot2, 0, 1, 1, 1),
230            (&plot3, 1, 0, 1, 2),
231        ])
232        .rows(2)
233        .cols(2)
234        .v_gap(0.35)
235        .h_gap(0.05)
236        .title(
237            Text::from("Irregular Subplot Grid")
238                .size(16)
239                .font("Arial bold")
240                .y(0.95),
241        )
242        .build()
243        .plot();
244}