Skip to main content

plotlars_plotters/
ext.rs

1use plotlars_core::Plot;
2
3/// Plotters rendering extension trait. Provides static image output methods.
4///
5/// Available on all types implementing the core `Plot` trait via blanket impl.
6pub trait PlottersExt: Plot {
7    /// Render and display the plot.
8    ///
9    /// In Jupyter/evcxr: displays inline PNG.
10    /// Otherwise: writes a temp PNG and opens the OS image viewer.
11    fn plot(&self);
12
13    /// Save the plot to a file. Format inferred from extension:
14    /// - `.png` -> BitMapBackend
15    /// - `.svg` -> SVGBackend
16    fn save(&self, path: &str);
17
18    /// Render the plot to an SVG string (in-memory, no file I/O).
19    fn to_svg(&self) -> String;
20}
21
22impl<T: Plot> PlottersExt for T {
23    fn plot(&self) {
24        crate::render::plot_interactive(self);
25    }
26
27    fn save(&self, path: &str) {
28        crate::render::save_to_file(self, path);
29    }
30
31    fn to_svg(&self) -> String {
32        crate::render::render_to_svg_string(self)
33    }
34}
35
#[cfg(test)]
mod tests {
    //! Smoke tests that render each plot type imported below to an in-memory
    //! SVG string through `PlottersExt::to_svg` and check basic properties of
    //! the output.

    use super::*;
    use plotlars_core::components::Rgb;
    use plotlars_core::plots::barplot::BarPlot;
    use plotlars_core::plots::boxplot::BoxPlot;
    use plotlars_core::plots::candlestick::CandlestickPlot;
    use plotlars_core::plots::heatmap::HeatMap;
    use plotlars_core::plots::histogram::Histogram;
    use plotlars_core::plots::lineplot::LinePlot;
    use plotlars_core::plots::scatterplot::ScatterPlot;
    use plotlars_core::plots::timeseriesplot::TimeSeriesPlot;
    use polars::prelude::*;

    /// Asserts that `svg` is non-empty and contains an `<svg` element.
    fn assert_is_svg(svg: &str) {
        assert!(!svg.is_empty(), "rendered SVG should not be empty");
        assert!(
            svg.contains("<svg"),
            "rendered output should contain an <svg> element"
        );
    }

    /// Three days of OHLC prices shared by the candlestick tests.
    fn ohlc_df() -> DataFrame {
        df![
            "date" => ["2024-01-01", "2024-01-02", "2024-01-03"],
            "open" => [100.0, 102.5, 101.0],
            "high" => [103.0, 104.0, 103.5],
            "low" => [99.0, 101.5, 100.0],
            "close" => [102.5, 101.0, 103.5]
        ]
        .unwrap()
    }

    /// Every calendar day of 2023 (a non-leap year) formatted as `YYYY-MM-DD`.
    fn dates_2023() -> Vec<String> {
        const DAYS_IN_MONTH: [u32; 12] = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31];
        (1..=12u32)
            .zip(DAYS_IN_MONTH.iter().copied())
            .flat_map(|(month, days)| {
                (1..=days).map(move |day| format!("2023-{:02}-{:02}", month, day))
            })
            .collect()
    }

    #[test]
    fn scatter_plot_renders_to_svg() {
        let df = df![
            "x" => [1.0, 2.0, 3.0, 4.0],
            "y" => [10.0, 20.0, 15.0, 25.0]
        ]
        .unwrap();
        let plot = ScatterPlot::builder().data(&df).x("x").y("y").build();
        assert_is_svg(&plot.to_svg());
    }

    #[test]
    fn scatter_plot_grouped_renders() {
        let df = df![
            "x" => [1.0, 2.0, 3.0, 4.0],
            "y" => [10.0, 20.0, 15.0, 25.0],
            "g" => ["a", "a", "b", "b"]
        ]
        .unwrap();
        let plot = ScatterPlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .group("g")
            .build();
        assert_is_svg(&plot.to_svg());
    }

    #[test]
    fn scatter_plot_styled_renders() {
        let df = df![
            "x" => [1.0, 2.0, 3.0],
            "y" => [4.0, 5.0, 6.0]
        ]
        .unwrap();
        let plot = ScatterPlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .color(Rgb(255, 0, 0))
            .opacity(0.7)
            .size(10)
            .build();
        assert_is_svg(&plot.to_svg());
    }

    #[test]
    fn scatter_plot_with_title_renders() {
        let df = df![
            "x" => [1.0, 2.0, 3.0],
            "y" => [4.0, 5.0, 6.0]
        ]
        .unwrap();
        let plot = ScatterPlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .plot_title("My Plot")
            .x_title("X Axis")
            .y_title("Y Axis")
            .build();
        let svg = plot.to_svg();
        assert_is_svg(&svg);
        assert!(svg.contains("My Plot"));
    }

    #[test]
    fn horizontal_legend_border_width_is_rendered() {
        use plotlars_core::components::{Legend, Orientation};
        let df = df![
            "x" => [1.0, 2.0, 3.0],
            "y" => [4.0, 5.0, 6.0],
            "g" => ["a", "b", "c"]
        ]
        .unwrap();
        let plot = ScatterPlot::builder()
            .data(&df)
            .x("x")
            .y("y")
            .group("g")
            .legend_title("test")
            .legend(
                &Legend::new()
                    .orientation(Orientation::Horizontal)
                    .border_width(50),
            )
            .build();
        let svg = plot.to_svg();
        assert_is_svg(&svg);
        assert!(
            svg.contains("stroke-width=\"50\""),
            "SVG should contain stroke-width=50"
        );
    }

    #[test]
    fn line_plot_renders_to_svg() {
        let df = df![
            "x" => [1.0, 2.0, 3.0, 4.0],
            "y" => [10.0, 20.0, 15.0, 25.0]
        ]
        .unwrap();
        let plot = LinePlot::builder().data(&df).x("x").y("y").build();
        assert_is_svg(&plot.to_svg());
    }

    #[test]
    fn line_plot_additional_lines_renders() {
        let df = df![
            "x" => [1.0, 2.0, 3.0],
            "y1" => [4.0, 5.0, 6.0],
            "y2" => [7.0, 8.0, 9.0]
        ]
        .unwrap();
        let plot = LinePlot::builder()
            .data(&df)
            .x("x")
            .y("y1")
            .additional_lines(vec!["y2"])
            .build();
        assert_is_svg(&plot.to_svg());
    }

    #[test]
    fn bar_plot_renders_to_svg() {
        let df = df![
            "labels" => ["a", "b", "c"],
            "values" => [10.0, 20.0, 30.0]
        ]
        .unwrap();
        let plot = BarPlot::builder()
            .data(&df)
            .labels("labels")
            .values("values")
            .build();
        assert_is_svg(&plot.to_svg());
    }

    #[test]
    fn bar_plot_grouped_renders() {
        let df = df![
            "labels" => ["a", "b", "a", "b"],
            "values" => [10.0, 20.0, 30.0, 40.0],
            "g" => ["x", "x", "y", "y"]
        ]
        .unwrap();
        let plot = BarPlot::builder()
            .data(&df)
            .labels("labels")
            .values("values")
            .group("g")
            .build();
        assert_is_svg(&plot.to_svg());
    }

    #[test]
    fn bar_plot_horizontal_renders() {
        let df = df![
            "labels" => ["a", "b", "c"],
            "values" => [10.0, 20.0, 30.0]
        ]
        .unwrap();
        let plot = BarPlot::builder()
            .data(&df)
            .labels("labels")
            .values("values")
            .orientation(plotlars_core::components::Orientation::Horizontal)
            .build();
        assert_is_svg(&plot.to_svg());
    }

    #[test]
    fn histogram_renders_to_svg() {
        let df = df!["x" => [1.0, 2.0, 2.0, 3.0, 3.0, 3.0]].unwrap();
        let plot = Histogram::builder().data(&df).x("x").build();
        assert_is_svg(&plot.to_svg());
    }

    #[test]
    fn boxplot_renders_to_svg() {
        let df = df![
            "species" => ["a", "a", "a", "a", "a", "b", "b", "b", "b", "b"],
            "value" => [1.0, 2.0, 3.0, 4.0, 5.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        ]
        .unwrap();
        let plot = BoxPlot::builder()
            .data(&df)
            .labels("species")
            .values("value")
            .build();
        assert_is_svg(&plot.to_svg());
    }

    #[test]
    fn boxplot_grouped_renders() {
        let df = df![
            "species" => ["a", "a", "a", "a", "b", "b", "b", "b"],
            "value" => [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            "g" => ["x", "x", "y", "y", "x", "x", "y", "y"]
        ]
        .unwrap();
        let plot = BoxPlot::builder()
            .data(&df)
            .labels("species")
            .values("value")
            .group("g")
            .build();
        assert_is_svg(&plot.to_svg());
    }

    #[test]
    fn heatmap_renders_to_svg() {
        let df = df![
            "x" => ["a", "b", "c", "a", "b", "c"],
            "y" => ["p", "p", "p", "q", "q", "q"],
            "z" => [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        ]
        .unwrap();
        let plot = HeatMap::builder().data(&df).x("x").y("y").z("z").build();
        assert_is_svg(&plot.to_svg());
    }

    #[test]
    fn heatmap_with_palette_renders() {
        let df = df![
            "x" => ["a", "b", "a", "b"],
            "y" => ["p", "p", "q", "q"],
            "z" => [10.0, 20.0, 30.0, 40.0]
        ]
        .unwrap();
        let plot = HeatMap::builder()
            .data(&df)
            .x("x")
            .y("y")
            .z("z")
            .color_scale(plotlars_core::components::Palette::Hot)
            .build();
        assert_is_svg(&plot.to_svg());
    }

    #[test]
    fn candlestick_renders_to_svg() {
        let df = ohlc_df();
        let plot = CandlestickPlot::builder()
            .data(&df)
            .dates("date")
            .open("open")
            .high("high")
            .low("low")
            .close("close")
            .build();
        assert_is_svg(&plot.to_svg());
    }

    #[test]
    fn candlestick_with_colors_renders() {
        use plotlars_core::components::Direction;
        let df = ohlc_df();
        let inc = Direction::new().line_color(Rgb(0, 150, 255));
        let dec = Direction::new().line_color(Rgb(200, 0, 100));
        let plot = CandlestickPlot::builder()
            .data(&df)
            .dates("date")
            .open("open")
            .high("high")
            .low("low")
            .close("close")
            .increasing(&inc)
            .decreasing(&dec)
            .build();
        assert_is_svg(&plot.to_svg());
    }

    #[test]
    fn timeseries_renders_to_svg() {
        let df = df![
            "date" => ["2024-01", "2024-02", "2024-03", "2024-04"],
            "y" => [10.0, 20.0, 15.0, 25.0]
        ]
        .unwrap();
        let plot = TimeSeriesPlot::builder().data(&df).x("date").y("y").build();
        assert_is_svg(&plot.to_svg());
    }

    #[test]
    fn timeseries_additional_series_renders() {
        let df = df![
            "date" => ["2024-01", "2024-02", "2024-03", "2024-04"],
            "y1" => [10.0, 20.0, 15.0, 25.0],
            "y2" => [5.0, 15.0, 10.0, 20.0],
            "y3" => [8.0, 18.0, 12.0, 22.0]
        ]
        .unwrap();
        let plot = TimeSeriesPlot::builder()
            .data(&df)
            .x("date")
            .y("y1")
            .additional_series(vec!["y2", "y3"])
            .colors(vec![Rgb(128, 128, 128), Rgb(0, 122, 255), Rgb(255, 128, 0)])
            .build();
        let svg = plot.to_svg();
        assert_is_svg(&svg);
        // Each of the 3 series is differently coloured, so each needs its own
        // line element; checking a single `<polyline` would not catch a
        // dropped series.
        let polylines = svg.matches("<polyline").count();
        assert!(
            polylines >= 3,
            "expected a polyline for each of the 3 series, found {}",
            polylines
        );
    }

    #[test]
    fn timeseries_dual_y_axis_renders() {
        use plotlars_core::components::axis::AxisSide;
        use plotlars_core::components::Axis;
        let df = df![
            "date" => ["2024-01", "2024-02", "2024-03", "2024-04"],
            "revenue" => [1000.0, 2000.0, 3000.0, 4000.0],
            "cost" => [100.0, 200.0, 150.0, 250.0]
        ]
        .unwrap();
        let plot = TimeSeriesPlot::builder()
            .data(&df)
            .x("date")
            .y("revenue")
            .additional_series(vec!["cost"])
            .colors(vec![Rgb(0, 0, 255), Rgb(255, 0, 0)])
            .y_title("revenue")
            .y2_title("cost")
            .y_axis(&Axis::new().value_color(Rgb(0, 0, 255)))
            .y2_axis(
                &Axis::new()
                    .axis_side(AxisSide::Right)
                    .value_color(Rgb(255, 0, 0)),
            )
            .build();
        assert_is_svg(&plot.to_svg());
        // TODO: assert that the y2 tick labels are placed on the right-hand
        // side and that no "unsupported" warning text is emitted; for now this
        // only checks that a dual-axis configuration renders.
    }

    #[test]
    fn timeseries_365_points_with_dashed_lines() {
        use plotlars_core::components::Line as LineStyle;
        // Synthetic stand-in for the De Bilt 2023 daily temperatures: one
        // point per calendar day and three series, two of them dotted.
        let dates = dates_2023();
        assert_eq!(dates.len(), 365);
        let tavg: Vec<f64> = (0..dates.len())
            .map(|i| 10.0 + 10.0 * (i as f64 * 0.017).sin())
            .collect();
        let tmin: Vec<f64> = tavg.iter().map(|t| t - 5.0).collect();
        let tmax: Vec<f64> = tavg.iter().map(|t| t + 5.0).collect();

        let df = df![
            "date" => dates,
            "tavg" => tavg,
            "tmin" => tmin,
            "tmax" => tmax
        ]
        .unwrap();

        let start = std::time::Instant::now();
        let plot = TimeSeriesPlot::builder()
            .data(&df)
            .x("date")
            .y("tavg")
            .additional_series(vec!["tmin", "tmax"])
            .colors(vec![Rgb(128, 128, 128), Rgb(0, 122, 255), Rgb(255, 128, 0)])
            .lines(vec![LineStyle::Solid, LineStyle::Dot, LineStyle::Dot])
            .build();
        let svg = plot.to_svg();
        let elapsed = start.elapsed();

        assert_is_svg(&svg);
        assert!(
            svg.contains("stroke-dasharray"),
            "Dashed lines should have stroke-dasharray in SVG"
        );
        // Coarse performance guard against pathological slowdowns, not a benchmark.
        assert!(
            elapsed.as_secs() < 5,
            "Rendering took too long: {:?}",
            elapsed
        );
    }

    #[test]
    fn histogram_grouped_renders() {
        let df = df![
            "x" => [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            "g" => ["a", "a", "a", "b", "b", "b"]
        ]
        .unwrap();
        let plot = Histogram::builder().data(&df).x("x").group("g").build();
        assert_is_svg(&plot.to_svg());
    }
}