Skip to main content

plotlars_core/plots/
scattergeo.rs

1use bon::bon;
2
3use polars::frame::DataFrame;
4
5use crate::{
6    components::{Legend, Mode, Rgb, Shape, Text},
7    ir::data::ColumnData,
8    ir::layout::LayoutIR,
9    ir::line::LineIR,
10    ir::marker::MarkerIR,
11    ir::trace::{ScatterGeoIR, TraceIR},
12};
13
14/// A structure representing a geographic scatter plot.
15///
16/// The `ScatterGeo` struct facilitates the creation and customization of geographic scatter plots
17/// with various options for data selection, grouping, layout configuration, and aesthetic adjustments.
18/// It supports plotting data points on a map using latitude and longitude coordinates, with customization
19/// for markers, lines, text labels, and comprehensive layout options.
20///
21/// # Backend Support
22///
23/// | Backend | Supported |
24/// |---------|-----------|
25/// | Plotly  | Yes       |
26/// | Plotters| --        |
27///
28/// # Arguments
29///
30/// * `data` - A reference to the `DataFrame` containing the data to be plotted.
31/// * `lat` - A string slice specifying the column name to be used for latitude coordinates.
32/// * `lon` - A string slice specifying the column name to be used for longitude coordinates.
33/// * `mode` - An optional `Mode` specifying the drawing mode (markers, lines, or both).
34/// * `text` - An optional string slice specifying the column name to be used for text labels.
35/// * `group` - An optional string slice specifying the column name to be used for grouping data points.
36/// * `sort_groups_by` - Optional comparator `fn(&str, &str) -> std::cmp::Ordering` to control group ordering. Groups are sorted lexically by default.
37/// * `opacity` - An optional `f64` value specifying the opacity of the plot elements (range: 0.0 to 1.0).
38/// * `size` - An optional `usize` specifying the size of the markers.
39/// * `color` - An optional `Rgb` value specifying the color of the markers. This is used when `group` is not specified.
40/// * `colors` - An optional vector of `Rgb` values specifying the colors for the markers. This is used when `group` is specified.
41/// * `shape` - An optional `Shape` specifying the shape of the markers. This is used when `group` is not specified.
42/// * `shapes` - An optional vector of `Shape` values specifying multiple shapes for the markers when plotting multiple groups.
43/// * `line_width` - An optional `f64` value specifying the width of the lines (when mode includes lines).
44/// * `line_color` - An optional `Rgb` value specifying the color of the lines.
45/// * `plot_title` - An optional `Text` struct specifying the title of the plot.
46/// * `legend_title` - An optional `Text` struct specifying the title of the legend.
47/// * `legend` - An optional reference to a `Legend` struct for customizing the legend of the plot.
48///
49/// # Example
50///
51/// ```rust
52/// use plotlars::{Plot, Rgb, ScatterGeo, Shape, Text, Mode};
53/// use polars::prelude::*;
54///
55/// let data = LazyCsvReader::new(PlRefPath::new("data/us_cities_regions.csv"))
56///     .finish()
57///     .unwrap()
58///     .collect()
59///     .unwrap();
60///
61/// ScatterGeo::builder()
62///     .data(&data)
63///     .lat("lat")
64///     .lon("lon")
65///     .mode(Mode::Markers)
66///     .text("city")
67///     .group("region")
68///     .size(15)
69///     .colors(vec![
70///         Rgb(255, 0, 0),
71///         Rgb(0, 255, 0),
72///         Rgb(0, 0, 255),
73///         Rgb(255, 165, 0),
74///     ])
75///     .plot_title(
76///         Text::from("Scatter Geo Plot")
77///             .font("Arial")
78///             .size(24)
79///             .x(0.5)
80///     )
81///     .legend_title(
82///         Text::from("Region")
83///             .size(14)
84///     )
85///     .build()
86///     .plot();
87/// ```
88///
89/// ![Example](https://imgur.com/8PCEbhN.png)
90#[derive(Clone)]
91#[allow(dead_code)]
92pub struct ScatterGeo {
93    traces: Vec<TraceIR>,
94    layout: LayoutIR,
95}
96
97#[bon]
98impl ScatterGeo {
99    #[builder(on(String, into), on(Text, into))]
100    pub fn new(
101        data: &DataFrame,
102        lat: &str,
103        lon: &str,
104        mode: Option<Mode>,
105        text: Option<&str>,
106        group: Option<&str>,
107        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
108        opacity: Option<f64>,
109        size: Option<usize>,
110        color: Option<Rgb>,
111        colors: Option<Vec<Rgb>>,
112        shape: Option<Shape>,
113        shapes: Option<Vec<Shape>>,
114        line_width: Option<f64>,
115        line_color: Option<Rgb>,
116        plot_title: Option<Text>,
117        legend_title: Option<Text>,
118        legend: Option<&Legend>,
119    ) -> Self {
120        // Build IR traces
121        let traces = Self::create_ir_traces(
122            data,
123            lat,
124            lon,
125            mode,
126            text,
127            group,
128            sort_groups_by,
129            opacity,
130            size,
131            color,
132            colors,
133            shape,
134            shapes,
135            line_width,
136            line_color,
137        );
138
139        let layout = LayoutIR {
140            title: plot_title.clone(),
141            x_title: None,
142            y_title: None,
143            y2_title: None,
144            z_title: None,
145            legend_title: legend_title.clone(),
146            legend: legend.cloned(),
147            dimensions: None,
148            bar_mode: None,
149            box_mode: None,
150            box_gap: None,
151            margin_bottom: None,
152            axes_2d: None,
153            scene_3d: None,
154            polar: None,
155            mapbox: None,
156            grid: None,
157            annotations: vec![],
158        };
159        Self { traces, layout }
160    }
161}
162
163#[bon]
164impl ScatterGeo {
165    #[builder(
166        start_fn = try_builder,
167        finish_fn = try_build,
168        builder_type = ScatterGeoTryBuilder,
169        on(String, into),
170        on(Text, into),
171    )]
172    pub fn try_new(
173        data: &DataFrame,
174        lat: &str,
175        lon: &str,
176        mode: Option<Mode>,
177        text: Option<&str>,
178        group: Option<&str>,
179        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
180        opacity: Option<f64>,
181        size: Option<usize>,
182        color: Option<Rgb>,
183        colors: Option<Vec<Rgb>>,
184        shape: Option<Shape>,
185        shapes: Option<Vec<Shape>>,
186        line_width: Option<f64>,
187        line_color: Option<Rgb>,
188        plot_title: Option<Text>,
189        legend_title: Option<Text>,
190        legend: Option<&Legend>,
191    ) -> Result<Self, crate::io::PlotlarsError> {
192        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
193            Self::__orig_new(
194                data,
195                lat,
196                lon,
197                mode,
198                text,
199                group,
200                sort_groups_by,
201                opacity,
202                size,
203                color,
204                colors,
205                shape,
206                shapes,
207                line_width,
208                line_color,
209                plot_title,
210                legend_title,
211                legend,
212            )
213        }))
214        .map_err(|panic| {
215            let msg = panic
216                .downcast_ref::<String>()
217                .cloned()
218                .or_else(|| panic.downcast_ref::<&str>().map(|s| s.to_string()))
219                .unwrap_or_else(|| "unknown error".to_string());
220            crate::io::PlotlarsError::PlotBuild { message: msg }
221        })
222    }
223}
224
225impl ScatterGeo {
226    #[allow(clippy::too_many_arguments)]
227    fn create_ir_traces(
228        data: &DataFrame,
229        lat: &str,
230        lon: &str,
231        mode: Option<Mode>,
232        text: Option<&str>,
233        group: Option<&str>,
234        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
235        opacity: Option<f64>,
236        size: Option<usize>,
237        color: Option<Rgb>,
238        colors: Option<Vec<Rgb>>,
239        shape: Option<Shape>,
240        shapes: Option<Vec<Shape>>,
241        line_width: Option<f64>,
242        line_color: Option<Rgb>,
243    ) -> Vec<TraceIR> {
244        let mut traces = Vec::new();
245
246        let line_ir = if line_width.is_some() || line_color.is_some() {
247            Some(LineIR {
248                width: line_width,
249                style: None,
250                color: line_color,
251            })
252        } else {
253            None
254        };
255
256        match group {
257            Some(group_col) => {
258                let groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
259                let groups = groups.iter().map(|s| s.as_str());
260
261                for (i, group_name) in groups.enumerate() {
262                    let subset = crate::data::filter_data_by_group(data, group_col, group_name);
263
264                    let resolved_color = Self::resolve_color(i, color, colors.clone());
265                    let resolved_shape = Self::resolve_shape(i, shape, shapes.clone());
266
267                    let marker_ir = MarkerIR {
268                        opacity,
269                        size,
270                        color: resolved_color,
271                        shape: resolved_shape,
272                    };
273
274                    let lat_data =
275                        ColumnData::Numeric(crate::data::get_numeric_column(&subset, lat));
276                    let lon_data =
277                        ColumnData::Numeric(crate::data::get_numeric_column(&subset, lon));
278
279                    let text_data = text.map(|text_col| {
280                        ColumnData::String(crate::data::get_string_column(&subset, text_col))
281                    });
282
283                    traces.push(TraceIR::ScatterGeo(ScatterGeoIR {
284                        lat: lat_data,
285                        lon: lon_data,
286                        name: Some(group_name.to_string()),
287                        text: text_data,
288                        mode,
289                        marker: Some(marker_ir),
290                        line: line_ir.clone(),
291                        show_legend: None,
292                    }));
293                }
294            }
295            None => {
296                let resolved_color = Self::resolve_color(0, color, colors.clone());
297                let resolved_shape = Self::resolve_shape(0, shape, shapes.clone());
298
299                let marker_ir = MarkerIR {
300                    opacity,
301                    size,
302                    color: resolved_color,
303                    shape: resolved_shape,
304                };
305
306                let lat_data = ColumnData::Numeric(crate::data::get_numeric_column(data, lat));
307                let lon_data = ColumnData::Numeric(crate::data::get_numeric_column(data, lon));
308
309                let text_data = text.map(|text_col| {
310                    ColumnData::String(crate::data::get_string_column(data, text_col))
311                });
312
313                traces.push(TraceIR::ScatterGeo(ScatterGeoIR {
314                    lat: lat_data,
315                    lon: lon_data,
316                    name: None,
317                    text: text_data,
318                    mode,
319                    marker: Some(marker_ir),
320                    line: line_ir,
321                    show_legend: None,
322                }));
323            }
324        }
325
326        traces
327    }
328
329    fn resolve_color(index: usize, color: Option<Rgb>, colors: Option<Vec<Rgb>>) -> Option<Rgb> {
330        if let Some(c) = color {
331            return Some(c);
332        }
333        if let Some(ref cs) = colors {
334            return cs.get(index).copied();
335        }
336        None
337    }
338
339    fn resolve_shape(
340        index: usize,
341        shape: Option<Shape>,
342        shapes: Option<Vec<Shape>>,
343    ) -> Option<Shape> {
344        if let Some(s) = shape {
345            return Some(s);
346        }
347        if let Some(ref ss) = shapes {
348            return ss.get(index).copied();
349        }
350        None
351    }
352}
353
354impl crate::Plot for ScatterGeo {
355    fn ir_traces(&self) -> &[TraceIR] {
356        &self.traces
357    }
358
359    fn ir_layout(&self) -> &LayoutIR {
360        &self.layout
361    }
362}
363
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Plot;
    use polars::prelude::*;

    /// A one-row frame holding a single (lat, lon) point.
    fn single_point() -> DataFrame {
        df![
            "lat" => [40.7],
            "lon" => [-74.0]
        ]
        .unwrap()
    }

    /// Builds an ungrouped plot from the `lat` / `lon` columns of `df`.
    fn ungrouped(df: &DataFrame) -> ScatterGeo {
        ScatterGeo::builder().data(df).lat("lat").lon("lon").build()
    }

    #[test]
    fn test_basic_one_trace() {
        let df = df![
            "lat" => [40.7, 34.0, 41.8],
            "lon" => [-74.0, -118.2, -87.6]
        ]
        .unwrap();
        let trace_count = ungrouped(&df).ir_traces().len();
        assert_eq!(trace_count, 1);
    }

    #[test]
    fn test_trace_variant() {
        let plot = ungrouped(&single_point());
        assert!(matches!(plot.ir_traces()[0], TraceIR::ScatterGeo(_)));
    }

    #[test]
    fn test_with_group() {
        let df = df![
            "lat" => [40.7, 34.0, 41.8, 29.7],
            "lon" => [-74.0, -118.2, -87.6, -95.3],
            "region" => ["east", "west", "east", "south"]
        ]
        .unwrap();
        let plot = ScatterGeo::builder()
            .data(&df)
            .lat("lat")
            .lon("lon")
            .group("region")
            .build();
        // Three distinct regions -> three traces.
        assert_eq!(plot.ir_traces().len(), 3);
    }

    #[test]
    fn test_layout_no_cartesian_axes() {
        let plot = ungrouped(&single_point());
        assert!(plot.ir_layout().axes_2d.is_none());
    }
}