Skip to main content

plotlars/common/plot.rs

1use std::env;
2use std::fs;
3use std::process::Command;
4
5use plotly::{Layout, Plot as Plotly, Trace};
6use serde_json::Value;
7
8use crate::components::{Rgb, Text};
9
10use serde::Serialize;
11
12/// A trait representing a generic plot that can be displayed or rendered.
13pub trait Plot {
14    fn plot(&self);
15
16    fn write_html(&self, path: impl Into<String>);
17
18    fn to_json(&self) -> Result<String, serde_json::Error>;
19
20    fn to_html(&self) -> String;
21
22    fn to_inline_html(&self, plot_div_id: Option<&str>) -> String; // We need it?
23
24    /// Exports the plot to a static image file.
25    ///
26    /// This method requires one of the export features to be enabled:
27    /// - `export-chrome` (uses ChromeDriver)
28    /// - `export-firefox` (uses GeckoDriver)
29    /// - `export-default` (uses any available driver)
30    ///
31    /// # Arguments
32    ///
33    /// * `path` - Output file path including extension (`.png`, `.jpg`, `.jpeg`, `.webp`, or `.svg`)
34    /// * `width` - Image width in pixels
35    /// * `height` - Image height in pixels
36    /// * `scale` - Scaling factor for resolution (use `1.0` for standard displays, `2.0` for high-DPI)
37    ///
38    /// # Returns
39    ///
40    /// Returns `Ok(())` on success, or an error if:
41    /// - The export feature is not enabled
42    /// - The WebDriver is not installed or accessible
43    /// - The file format is unsupported
44    /// - The file cannot be written
45    ///
46    /// # Examples
47    ///
48    /// ```no_run
49    /// use plotlars::{ScatterPlot, Plot};
50    /// use polars::prelude::*;
51    ///
52    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
53    /// let dataset = df! {
54    ///     "x" => &[1, 2, 3, 4, 5],
55    ///     "y" => &[2, 4, 6, 8, 10],
56    /// }?;
57    ///
58    /// let plot = ScatterPlot::builder()
59    ///     .data(&dataset)
60    ///     .x("x")
61    ///     .y("y")
62    ///     .build();
63    ///
64    /// // Export as high-resolution PNG
65    /// plot.write_image("scatter.png", 1920, 1080, 2.0)?;
66    /// # Ok(())
67    /// # }
68    /// ```
69    ///
70    /// # See Also
71    ///
72    /// - [`write_html`](Plot::write_html) - Export as interactive HTML
73    /// - [`to_json`](Plot::to_json) - Export plot data as JSON
74    #[cfg(any(
75        feature = "export-chrome",
76        feature = "export-firefox",
77        feature = "export-default"
78    ))]
79    fn write_image(
80        &self,
81        path: impl Into<String>,
82        width: usize,
83        height: usize,
84        scale: f64,
85    ) -> Result<(), std::boxed::Box<dyn std::error::Error + 'static>>;
86}
87
88/// Helper trait for internal use by the `Plot` trait implementation.
89/// Can be used to get the underlying layout and traces of a plot (for example, to create a subplot).
90pub trait PlotHelper {
91    #[doc(hidden)]
92    fn get_layout(&self) -> &Layout;
93    #[doc(hidden)]
94    fn get_traces(&self) -> &Vec<Box<dyn Trace + 'static>>;
95
96    #[doc(hidden)]
97    fn get_layout_override(&self) -> Option<&Value> {
98        None
99    }
100
101    #[doc(hidden)]
102    fn get_serialized_traces(&self) -> Option<Vec<Value>> {
103        None
104    }
105
106    #[doc(hidden)]
107    fn get_main_title(&self) -> Option<String> {
108        let layout_json = serde_json::to_value(self.get_layout()).ok()?;
109        layout_json
110            .get("title")
111            .and_then(|t| t.get("text"))
112            .and_then(|t| t.as_str())
113            .map(|s| s.to_string())
114    }
115
116    #[doc(hidden)]
117    fn get_x_title(&self) -> Option<String> {
118        let layout_json = serde_json::to_value(self.get_layout()).ok()?;
119        layout_json
120            .get("xaxis")
121            .and_then(|axis| axis.get("title"))
122            .and_then(|title| title.get("text"))
123            .and_then(|text| text.as_str())
124            .map(|s| s.to_string())
125    }
126
127    #[doc(hidden)]
128    fn get_y_title(&self) -> Option<String> {
129        let layout_json = serde_json::to_value(self.get_layout()).ok()?;
130        layout_json
131            .get("yaxis")
132            .and_then(|axis| axis.get("title"))
133            .and_then(|title| title.get("text"))
134            .and_then(|text| text.as_str())
135            .map(|s| s.to_string())
136    }
137
138    #[doc(hidden)]
139    fn get_main_title_text(&self) -> Option<Text> {
140        let layout_json = serde_json::to_value(self.get_layout()).ok()?;
141        let title_obj = layout_json.get("title")?;
142
143        let content = title_obj
144            .get("text")
145            .and_then(|t| t.as_str())
146            .map(|s| s.to_string())?;
147
148        let mut text = Text::from(content);
149
150        if let Some(font_obj) = title_obj.get("font") {
151            if let Some(family) = font_obj.get("family").and_then(|f| f.as_str()) {
152                if !family.is_empty() {
153                    text = text.font(family);
154                }
155            }
156
157            if let Some(size) = font_obj.get("size").and_then(|s| s.as_u64()) {
158                if size > 0 {
159                    text = text.size(size as usize);
160                }
161            }
162
163            if let Some(color) = font_obj.get("color").and_then(|c| c.as_str()) {
164                if let Some(rgb) = parse_color(color) {
165                    text = text.color(rgb);
166                }
167            }
168        }
169
170        if let Some(x) = title_obj.get("x").and_then(|v| v.as_f64()) {
171            text = text.x(x);
172        }
173
174        if let Some(y) = title_obj.get("y").and_then(|v| v.as_f64()) {
175            text = text.y(y);
176        }
177
178        Some(text)
179    }
180
181    #[doc(hidden)]
182    fn get_x_title_text(&self) -> Option<Text> {
183        let layout_json = serde_json::to_value(self.get_layout()).ok()?;
184        let title_obj = layout_json
185            .get("xaxis")
186            .and_then(|axis| axis.get("title"))?;
187
188        let content = title_obj
189            .get("text")
190            .and_then(|t| t.as_str())
191            .map(|s| s.to_string())?;
192
193        let mut text = Text::from(content);
194
195        if let Some(font_obj) = title_obj.get("font") {
196            if let Some(family) = font_obj.get("family").and_then(|f| f.as_str()) {
197                if !family.is_empty() {
198                    text = text.font(family);
199                }
200            }
201
202            if let Some(size) = font_obj.get("size").and_then(|s| s.as_u64()) {
203                if size > 0 {
204                    text = text.size(size as usize);
205                }
206            }
207
208            if let Some(color) = font_obj.get("color").and_then(|c| c.as_str()) {
209                if let Some(rgb) = parse_color(color) {
210                    text = text.color(rgb);
211                }
212            }
213        }
214
215        if let Some(x) = title_obj.get("x").and_then(|v| v.as_f64()) {
216            text = text.x(x);
217        }
218
219        if let Some(y) = title_obj.get("y").and_then(|v| v.as_f64()) {
220            text = text.y(y);
221        }
222
223        Some(text)
224    }
225
226    #[doc(hidden)]
227    fn get_y_title_text(&self) -> Option<Text> {
228        let layout_json = serde_json::to_value(self.get_layout()).ok()?;
229        let title_obj = layout_json
230            .get("yaxis")
231            .and_then(|axis| axis.get("title"))?;
232
233        let content = title_obj
234            .get("text")
235            .and_then(|t| t.as_str())
236            .map(|s| s.to_string())?;
237
238        let mut text = Text::from(content);
239
240        if let Some(font_obj) = title_obj.get("font") {
241            if let Some(family) = font_obj.get("family").and_then(|f| f.as_str()) {
242                if !family.is_empty() {
243                    text = text.font(family);
244                }
245            }
246
247            if let Some(size) = font_obj.get("size").and_then(|s| s.as_u64()) {
248                if size > 0 {
249                    text = text.size(size as usize);
250                }
251            }
252
253            if let Some(color) = font_obj.get("color").and_then(|c| c.as_str()) {
254                if let Some(rgb) = parse_color(color) {
255                    text = text.color(rgb);
256                }
257            }
258        }
259
260        if let Some(x) = title_obj.get("x").and_then(|v| v.as_f64()) {
261            text = text.x(x);
262        }
263
264        if let Some(y) = title_obj.get("y").and_then(|v| v.as_f64()) {
265            text = text.y(y);
266        }
267
268        Some(text)
269    }
270
271    #[doc(hidden)]
272    #[cfg(any(
273        feature = "export-chrome",
274        feature = "export-firefox",
275        feature = "export-default"
276    ))]
277    fn get_image_format(
278        &self,
279        extension: &str,
280    ) -> Result<plotly::ImageFormat, std::boxed::Box<dyn std::error::Error + 'static>> {
281        match extension {
282            "png" => Ok(plotly::ImageFormat::PNG),
283            "jpg" => Ok(plotly::ImageFormat::JPEG),
284            "jpeg" => Ok(plotly::ImageFormat::JPEG),
285            "webp" => Ok(plotly::ImageFormat::WEBP),
286            "svg" => Ok(plotly::ImageFormat::SVG),
287            _ => Err(format!("Unsupported image format: {extension}").into()),
288        }
289    }
290}
291
292fn parse_color(color_str: &str) -> Option<Rgb> {
293    if color_str.starts_with("rgb(") || color_str.starts_with("rgba(") {
294        let start = color_str.find('(')?;
295        let end = color_str.find(')')?;
296        let values = &color_str[start + 1..end];
297        let parts: Vec<&str> = values.split(',').map(|s| s.trim()).collect();
298
299        if parts.len() >= 3 {
300            let r = parts[0].parse::<u8>().ok()?;
301            let g = parts[1].parse::<u8>().ok()?;
302            let b = parts[2].parse::<u8>().ok()?;
303            return Some(Rgb(r, g, b));
304        }
305    }
306
307    if let Some(hex) = color_str.strip_prefix('#') {
308        if hex.len() == 6 {
309            let r = u8::from_str_radix(&hex[0..2], 16).ok()?;
310            let g = u8::from_str_radix(&hex[2..4], 16).ok()?;
311            let b = u8::from_str_radix(&hex[4..6], 16).ok()?;
312            return Some(Rgb(r, g, b));
313        } else if hex.len() == 3 {
314            let r = u8::from_str_radix(&hex[0..1], 16).ok()? * 17;
315            let g = u8::from_str_radix(&hex[1..2], 16).ok()? * 17;
316            let b = u8::from_str_radix(&hex[2..3], 16).ok()? * 17;
317            return Some(Rgb(r, g, b));
318        }
319    }
320
321    match color_str.to_lowercase().as_str() {
322        "black" => Some(Rgb(0, 0, 0)),
323        "white" => Some(Rgb(255, 255, 255)),
324        "red" => Some(Rgb(255, 0, 0)),
325        "green" => Some(Rgb(0, 128, 0)),
326        "blue" => Some(Rgb(0, 0, 255)),
327        "yellow" => Some(Rgb(255, 255, 0)),
328        "cyan" => Some(Rgb(0, 255, 255)),
329        "magenta" => Some(Rgb(255, 0, 255)),
330        "gray" | "grey" => Some(Rgb(128, 128, 128)),
331        "orange" => Some(Rgb(255, 165, 0)),
332        "purple" => Some(Rgb(128, 0, 128)),
333        "pink" => Some(Rgb(255, 192, 203)),
334        "brown" => Some(Rgb(165, 42, 42)),
335        "lime" => Some(Rgb(0, 255, 0)),
336        "navy" => Some(Rgb(0, 0, 128)),
337        "teal" => Some(Rgb(0, 128, 128)),
338        "silver" => Some(Rgb(192, 192, 192)),
339        "maroon" => Some(Rgb(128, 0, 0)),
340        "olive" => Some(Rgb(128, 128, 0)),
341        _ => None,
342    }
343}
344
345// Implement the public trait `Plot` for any type that implements `PlotHelper`.
346impl<T> Plot for T
347where
348    T: PlotHelper + Serialize + Clone,
349{
350    fn plot(&self) {
351        if self.get_layout_override().is_some() {
352            let html = self.to_html();
353
354            match env::var("EVCXR_IS_RUNTIME") {
355                Ok(_) => {
356                    // For Jupyter/evcxr, print the HTML directly
357                    println!("HTML");
358                    println!("{}", html);
359                }
360                _ => {
361                    // Write HTML to temp file and open in browser
362                    let temp_dir = env::temp_dir();
363                    let timestamp = std::time::SystemTime::now()
364                        .duration_since(std::time::UNIX_EPOCH)
365                        .unwrap()
366                        .as_nanos();
367                    let temp_file = temp_dir.join(format!(
368                        "plotlars_{}_{}.html",
369                        std::process::id(),
370                        timestamp
371                    ));
372                    fs::write(&temp_file, html).expect("Failed to write HTML file");
373
374                    // Open the file in default browser
375                    open_html_file(&temp_file);
376                }
377            }
378        } else {
379            let mut plot = Plotly::new();
380            plot.set_layout(self.get_layout().to_owned());
381            plot.add_traces(self.get_traces().to_owned());
382
383            match env::var("EVCXR_IS_RUNTIME") {
384                Ok(_) => plot.evcxr_display(),
385                _ => plot.show(),
386            }
387        }
388    }
389
390    fn write_html(&self, path: impl Into<String>) {
391        if self.get_layout_override().is_some() {
392            let html = self.to_html();
393            fs::write(path.into(), html).expect("Failed to write HTML file");
394        } else {
395            let mut plot = Plotly::new();
396            plot.set_layout(self.get_layout().to_owned());
397            plot.add_traces(self.get_traces().to_owned());
398            plot.write_html(path.into());
399        }
400    }
401
402    fn to_json(&self) -> Result<String, serde_json::Error> {
403        serde_json::to_string(self)
404    }
405
406    fn to_html(&self) -> String {
407        if self.get_layout_override().is_some() {
408            let plot_json = serde_json::to_string(self).unwrap();
409            let escaped_json = plot_json
410                .replace('\\', "\\\\")
411                .replace('\'', "\\'")
412                .replace('\n', "\\n")
413                .replace('\r', "\\r");
414
415            format!(
416                r#"<!DOCTYPE html>
417<html>
418<head>
419    <meta charset="utf-8" />
420    <script src="https://cdn.plot.ly/plotly-2.18.0.min.js"></script>
421</head>
422<body>
423    <div id="plotly-div" style="width:100%;height:100%;"></div>
424    <script type="text/javascript">
425        var plotData = JSON.parse('{}');
426        Plotly.newPlot('plotly-div', plotData.traces, plotData.layout, {{responsive: true}});
427    </script>
428</body>
429</html>"#,
430                escaped_json
431            )
432        } else {
433            let mut plot = Plotly::new();
434            plot.set_layout(self.get_layout().to_owned());
435            plot.add_traces(self.get_traces().to_owned());
436            plot.to_html()
437        }
438    }
439
440    fn to_inline_html(&self, plot_div_id: Option<&str>) -> String {
441        let div_id = plot_div_id.unwrap_or("plotly-div");
442
443        if self.get_layout_override().is_some() {
444            let plot_json = serde_json::to_string(self).unwrap();
445            let escaped_json = plot_json
446                .replace('\\', "\\\\")
447                .replace('\'', "\\'")
448                .replace('\n', "\\n")
449                .replace('\r', "\\r");
450
451            format!(
452                r#"<div id="{}" style="width:100%;height:100%;"></div>
453<script type="text/javascript">
454    var plotData = JSON.parse('{}');
455    Plotly.newPlot('{}', plotData.traces, plotData.layout, {{responsive: true}});
456</script>"#,
457                div_id, escaped_json, div_id
458            )
459        } else {
460            let mut plot = Plotly::new();
461            plot.set_layout(self.get_layout().to_owned());
462            plot.add_traces(self.get_traces().to_owned());
463            plot.to_inline_html(plot_div_id)
464        }
465    }
466
467    #[cfg(any(
468        feature = "export-chrome",
469        feature = "export-firefox",
470        feature = "export-default"
471    ))]
472    fn write_image(
473        &self,
474        path: impl Into<String>,
475        width: usize,
476        height: usize,
477        scale: f64,
478    ) -> Result<(), std::boxed::Box<dyn std::error::Error + 'static>> {
479        let path_string = path.into();
480
481        let mut plot = Plotly::new();
482        plot.set_layout(self.get_layout().to_owned());
483        plot.add_traces(self.get_traces().to_owned());
484
485        if let Some((filename, extension)) = path_string.rsplit_once('.') {
486            let format = self.get_image_format(extension)?;
487            plot.write_image(filename, format, width, height, scale)?;
488        } else {
489            return Err("No extension provided for image.".into());
490        }
491
492        Ok(())
493    }
494}
495
/// Opens the HTML file at `path` in the default browser without blocking the caller.
///
/// The platform launcher is spawned, and a background thread waits on it so the exited
/// process is reaped. The launchers are `open` on macOS, `cmd /C start` on Windows and
/// `xdg-open` on Linux and the BSDs.
///
/// Failures are reported on stderr together with the file path so the user can open it
/// manually. This function never panics.
fn open_html_file(path: &std::path::Path) {
    match browser_launcher(path) {
        Some(mut launcher) => match launcher.spawn() {
            Ok(mut child) => {
                // Detach: don't block on the browser, but reap the launcher when it exits.
                std::thread::spawn(move || {
                    let _ = child.wait();
                });
            }
            Err(err) => eprintln!(
                "Failed to launch a browser ({}). Please open the file manually: {:?}",
                err, path
            ),
        },
        None => eprintln!(
            "Cannot automatically open browser on this platform. Please open the file manually: {:?}",
            path
        ),
    }
}

/// Builds the platform-specific command that opens `path` with the default browser, or
/// returns `None` when no launcher is known for the current platform.
///
/// The path is passed as an `OsStr` argument, so non-UTF-8 paths are supported.
fn browser_launcher(path: &std::path::Path) -> Option<Command> {
    let mut launcher = if cfg!(target_os = "macos") {
        Command::new("open")
    } else if cfg!(target_os = "windows") {
        let mut cmd = Command::new("cmd");
        // The empty argument is `start`'s window title; without it, a quoted path would
        // be taken as the title instead of the file to open.
        cmd.args(["/C", "start", ""]);
        cmd
    } else if cfg!(any(
        target_os = "linux",
        target_os = "freebsd",
        target_os = "openbsd",
        target_os = "netbsd",
        target_os = "dragonfly"
    )) {
        Command::new("xdg-open")
    } else {
        return None;
    };
    launcher.arg(path);
    Some(launcher)
}