Skip to main content

plotlars_core/plots/
scattermap.rs

1use bon::bon;
2
3use polars::frame::DataFrame;
4
5use crate::{
6    components::{Legend, Rgb, Shape, Text},
7    ir::data::ColumnData,
8    ir::layout::{LayoutIR, MapboxIR},
9    ir::marker::MarkerIR,
10    ir::trace::{ScatterMapIR, TraceIR},
11};
12
13/// A structure representing a scatter plot on a map.
14///
15/// The `ScatterMap` struct allows for visualizing geographical data points on an interactive map.
16/// Each data point is defined by its latitude and longitude, with additional options for grouping,
17/// coloring, size, opacity, and map configuration such as zoom level and center coordinates.
18/// This struct is ideal for displaying spatial data distributions, such as city locations or geospatial datasets.
19///
20/// # Backend Support
21///
22/// | Backend | Supported |
23/// |---------|-----------|
24/// | Plotly  | Yes       |
25/// | Plotters| --        |
26///
27/// # Arguments
28///
29/// * `data` - A reference to the `DataFrame` containing the data to be plotted.
30/// * `latitude` - A string slice specifying the column name containing latitude values.
31/// * `longitude` - A string slice specifying the column name containing longitude values.
32/// * `center` - An optional array `[f64; 2]` specifying the initial center point of the map ([latitude, longitude]).
33/// * `zoom` - An optional `u8` specifying the initial zoom level of the map.
34/// * `group` - An optional string slice specifying the column name for grouping data points (e.g., by city or category).
35/// * `sort_groups_by` - Optional comparator `fn(&str, &str) -> std::cmp::Ordering` to control group ordering. Groups are sorted lexically by default.
36/// * `opacity` - An optional `f64` value between `0.0` and `1.0` specifying the opacity of the points.
37/// * `size` - An optional `usize` specifying the size of the scatter points.
38/// * `color` - An optional `Rgb` value specifying the color of the points (if no grouping is applied).
39/// * `colors` - An optional vector of `Rgb` values specifying colors for grouped points.
40/// * `shape` - An optional `Shape` enum specifying the marker shape for the points.
41/// * `shapes` - An optional vector of `Shape` enums specifying shapes for grouped points.
42/// * `plot_title` - An optional `Text` struct specifying the title of the plot.
43/// * `legend_title` - An optional `Text` struct specifying the title of the legend.
44/// * `legend` - An optional reference to a `Legend` struct for customizing the legend (e.g., positioning, font, etc.).
45///
46/// # Example
47///
48/// ## Basic Scatter Map Plot
49///
50/// ```rust
51/// use plotlars::{Plot, ScatterMap, Text};
52/// use polars::prelude::*;
53///
54/// let dataset = LazyCsvReader::new(PlRefPath::new("data/cities.csv"))
55///     .finish()
56///     .unwrap()
57///     .collect()
58///     .unwrap();
59///
60/// ScatterMap::builder()
61///     .data(&dataset)
62///     .latitude("latitude")
63///     .longitude("longitude")
64///     .center([48.856613, 2.352222])
65///     .zoom(4)
66///     .group("city")
67///     .opacity(0.5)
68///     .size(12)
69///     .plot_title(
70///         Text::from("Scatter Map")
71///             .font("Arial")
72///             .size(18)
73///     )
74///     .legend_title("cities")
75///     .build()
76///     .plot();
77/// ```
78///
79/// ![Example](https://imgur.com/8MCjVOd.png)
80#[derive(Clone)]
81#[allow(dead_code)]
82pub struct ScatterMap {
83    traces: Vec<TraceIR>,
84    layout: LayoutIR,
85}
86
87#[bon]
88impl ScatterMap {
89    #[builder(on(String, into), on(Text, into))]
90    pub fn new(
91        data: &DataFrame,
92        latitude: &str,
93        longitude: &str,
94        center: Option<[f64; 2]>,
95        zoom: Option<u8>,
96        group: Option<&str>,
97        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
98        opacity: Option<f64>,
99        size: Option<usize>,
100        color: Option<Rgb>,
101        colors: Option<Vec<Rgb>>,
102        shape: Option<Shape>,
103        shapes: Option<Vec<Shape>>,
104        plot_title: Option<Text>,
105        legend_title: Option<Text>,
106        legend: Option<&Legend>,
107    ) -> Self {
108        let traces = Self::create_ir_traces(
109            data,
110            latitude,
111            longitude,
112            group,
113            sort_groups_by,
114            opacity,
115            size,
116            color,
117            colors,
118            shape,
119            shapes,
120        );
121
122        let layout = LayoutIR {
123            title: plot_title,
124            x_title: None,
125            y_title: None,
126            y2_title: None,
127            z_title: None,
128            legend_title,
129            legend: legend.cloned(),
130            dimensions: None,
131            bar_mode: None,
132            box_mode: None,
133            box_gap: None,
134            margin_bottom: Some(0),
135            axes_2d: None,
136            scene_3d: None,
137            polar: None,
138            mapbox: Some(MapboxIR {
139                center: center.map(|c| (c[0], c[1])),
140                zoom: zoom.map(|z| z as f64),
141                style: None,
142            }),
143            grid: None,
144            annotations: vec![],
145        };
146
147        Self { traces, layout }
148    }
149}
150
151#[bon]
152impl ScatterMap {
153    #[builder(
154        start_fn = try_builder,
155        finish_fn = try_build,
156        builder_type = ScatterMapTryBuilder,
157        on(String, into),
158        on(Text, into),
159    )]
160    pub fn try_new(
161        data: &DataFrame,
162        latitude: &str,
163        longitude: &str,
164        center: Option<[f64; 2]>,
165        zoom: Option<u8>,
166        group: Option<&str>,
167        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
168        opacity: Option<f64>,
169        size: Option<usize>,
170        color: Option<Rgb>,
171        colors: Option<Vec<Rgb>>,
172        shape: Option<Shape>,
173        shapes: Option<Vec<Shape>>,
174        plot_title: Option<Text>,
175        legend_title: Option<Text>,
176        legend: Option<&Legend>,
177    ) -> Result<Self, crate::io::PlotlarsError> {
178        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
179            Self::__orig_new(
180                data,
181                latitude,
182                longitude,
183                center,
184                zoom,
185                group,
186                sort_groups_by,
187                opacity,
188                size,
189                color,
190                colors,
191                shape,
192                shapes,
193                plot_title,
194                legend_title,
195                legend,
196            )
197        }))
198        .map_err(|panic| {
199            let msg = panic
200                .downcast_ref::<String>()
201                .cloned()
202                .or_else(|| panic.downcast_ref::<&str>().map(|s| s.to_string()))
203                .unwrap_or_else(|| "unknown error".to_string());
204            crate::io::PlotlarsError::PlotBuild { message: msg }
205        })
206    }
207}
208
209impl ScatterMap {
210    #[allow(clippy::too_many_arguments)]
211    fn create_ir_traces(
212        data: &DataFrame,
213        latitude: &str,
214        longitude: &str,
215        group: Option<&str>,
216        sort_groups_by: Option<fn(&str, &str) -> std::cmp::Ordering>,
217        opacity: Option<f64>,
218        size: Option<usize>,
219        color: Option<Rgb>,
220        colors: Option<Vec<Rgb>>,
221        shape: Option<Shape>,
222        shapes: Option<Vec<Shape>>,
223    ) -> Vec<TraceIR> {
224        let mut traces = Vec::new();
225
226        match group {
227            Some(group_col) => {
228                let groups = crate::data::get_unique_groups(data, group_col, sort_groups_by);
229                let groups = groups.iter().map(|s| s.as_str());
230
231                for (i, group_name) in groups.enumerate() {
232                    let subset = crate::data::filter_data_by_group(data, group_col, group_name);
233
234                    let resolved_color = Self::resolve_color(i, color, colors.clone());
235                    let resolved_shape = Self::resolve_shape(i, shape, shapes.clone());
236
237                    let marker_ir = MarkerIR {
238                        opacity,
239                        size,
240                        color: resolved_color,
241                        shape: resolved_shape,
242                    };
243
244                    let lat_data =
245                        ColumnData::Numeric(crate::data::get_numeric_column(&subset, latitude));
246                    let lon_data =
247                        ColumnData::Numeric(crate::data::get_numeric_column(&subset, longitude));
248
249                    traces.push(TraceIR::ScatterMap(ScatterMapIR {
250                        lat: lat_data,
251                        lon: lon_data,
252                        name: Some(group_name.to_string()),
253                        marker: Some(marker_ir),
254                        show_legend: None,
255                    }));
256                }
257            }
258            None => {
259                let resolved_color = Self::resolve_color(0, color, colors.clone());
260                let resolved_shape = Self::resolve_shape(0, shape, shapes.clone());
261
262                let marker_ir = MarkerIR {
263                    opacity,
264                    size,
265                    color: resolved_color,
266                    shape: resolved_shape,
267                };
268
269                let lat_data = ColumnData::Numeric(crate::data::get_numeric_column(data, latitude));
270                let lon_data =
271                    ColumnData::Numeric(crate::data::get_numeric_column(data, longitude));
272
273                traces.push(TraceIR::ScatterMap(ScatterMapIR {
274                    lat: lat_data,
275                    lon: lon_data,
276                    name: None,
277                    marker: Some(marker_ir),
278                    show_legend: None,
279                }));
280            }
281        }
282
283        traces
284    }
285
286    fn resolve_color(index: usize, color: Option<Rgb>, colors: Option<Vec<Rgb>>) -> Option<Rgb> {
287        if let Some(c) = color {
288            return Some(c);
289        }
290        if let Some(ref cs) = colors {
291            return cs.get(index).copied();
292        }
293        None
294    }
295
296    fn resolve_shape(
297        index: usize,
298        shape: Option<Shape>,
299        shapes: Option<Vec<Shape>>,
300    ) -> Option<Shape> {
301        if let Some(s) = shape {
302            return Some(s);
303        }
304        if let Some(ref ss) = shapes {
305            return ss.get(index).copied();
306        }
307        None
308    }
309}
310
311impl crate::Plot for ScatterMap {
312    fn ir_traces(&self) -> &[TraceIR] {
313        &self.traces
314    }
315
316    fn ir_layout(&self) -> &LayoutIR {
317        &self.layout
318    }
319}
320
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Plot;
    use polars::prelude::*;

    /// Three points, each in a different city, used by the grouping tests.
    fn cities_df() -> DataFrame {
        df![
            "latitude" => [48.8, 51.5, 40.7],
            "longitude" => [2.3, -0.1, -74.0],
            "city" => ["paris", "london", "nyc"]
        ]
        .unwrap()
    }

    /// Whether each trace (in order) carries an explicit marker color.
    fn has_color(plot: &ScatterMap) -> Vec<bool> {
        plot.ir_traces()
            .iter()
            .map(|trace| match trace {
                TraceIR::ScatterMap(ir) => ir.marker.as_ref().is_some_and(|m| m.color.is_some()),
                _ => false,
            })
            .collect()
    }

    #[test]
    fn test_basic_one_trace() {
        let df = df![
            "latitude" => [48.8, 51.5, 40.7],
            "longitude" => [2.3, -0.1, -74.0]
        ]
        .unwrap();
        let plot = ScatterMap::builder()
            .data(&df)
            .latitude("latitude")
            .longitude("longitude")
            .build();
        assert_eq!(plot.ir_traces().len(), 1);
    }

    #[test]
    fn test_trace_variant() {
        let df = df![
            "latitude" => [48.8],
            "longitude" => [2.3]
        ]
        .unwrap();
        let plot = ScatterMap::builder()
            .data(&df)
            .latitude("latitude")
            .longitude("longitude")
            .build();
        assert!(matches!(plot.ir_traces()[0], TraceIR::ScatterMap(_)));
    }

    #[test]
    fn test_with_group() {
        let df = df![
            "latitude" => [48.8, 51.5, 40.7],
            "longitude" => [2.3, -0.1, -74.0],
            "city" => ["paris", "london", "nyc"]
        ]
        .unwrap();
        let plot = ScatterMap::builder()
            .data(&df)
            .latitude("latitude")
            .longitude("longitude")
            .group("city")
            .build();
        assert_eq!(plot.ir_traces().len(), 3);
    }

    #[test]
    fn test_layout_has_mapbox() {
        let df = df![
            "latitude" => [48.8],
            "longitude" => [2.3]
        ]
        .unwrap();
        let plot = ScatterMap::builder()
            .data(&df)
            .latitude("latitude")
            .longitude("longitude")
            .build();
        assert!(plot.ir_layout().mapbox.is_some());
    }

    #[test]
    fn test_center_and_zoom_forwarded_to_mapbox() {
        let df = df![
            "latitude" => [48.8],
            "longitude" => [2.3]
        ]
        .unwrap();
        let plot = ScatterMap::builder()
            .data(&df)
            .latitude("latitude")
            .longitude("longitude")
            .center([48.8, 2.3])
            .zoom(4)
            .build();
        let mapbox = plot.ir_layout().mapbox.as_ref().unwrap();
        assert_eq!(mapbox.center, Some((48.8, 2.3)));
        assert_eq!(mapbox.zoom, Some(4.0));
        assert_eq!(plot.ir_layout().margin_bottom, Some(0));
    }

    #[test]
    fn test_group_traces_are_named_after_groups() {
        let df = cities_df();
        let plot = ScatterMap::builder()
            .data(&df)
            .latitude("latitude")
            .longitude("longitude")
            .group("city")
            .build();
        let mut names: Vec<String> = plot
            .ir_traces()
            .iter()
            .filter_map(|trace| match trace {
                TraceIR::ScatterMap(ir) => ir.name.clone(),
                _ => None,
            })
            .collect();
        names.sort();
        assert_eq!(names, ["london", "nyc", "paris"]);
    }

    #[test]
    fn test_short_palette_leaves_extra_groups_uncolored() {
        let df = cities_df();
        let plot = ScatterMap::builder()
            .data(&df)
            .latitude("latitude")
            .longitude("longitude")
            .group("city")
            .colors(vec![Rgb(255, 0, 0), Rgb(0, 0, 255)])
            .build();
        assert_eq!(has_color(&plot), [true, true, false]);
    }

    #[test]
    fn test_single_color_overrides_palette() {
        let df = cities_df();
        let plot = ScatterMap::builder()
            .data(&df)
            .latitude("latitude")
            .longitude("longitude")
            .group("city")
            .color(Rgb(0, 128, 0))
            .colors(vec![Rgb(255, 0, 0)])
            .build();
        assert_eq!(has_color(&plot), [true, true, true]);
    }

    #[test]
    fn test_try_build_ok() {
        let df = df![
            "latitude" => [48.8],
            "longitude" => [2.3]
        ]
        .unwrap();
        let result = ScatterMap::try_builder()
            .data(&df)
            .latitude("latitude")
            .longitude("longitude")
            .try_build();
        assert!(result.is_ok());
    }
}
388}