//! `chartml_chart_scatter/lib.rs`: scatter chart renderer for ChartML.

1use chartml_core::plugin::{ChartRenderer, ChartConfig};
2use chartml_core::data::DataTable;
3use chartml_core::element::*;
4use chartml_core::error::ChartError;
5use chartml_core::scales::{ScaleLinear, ScaleSqrt};
6use chartml_core::spec::{VisualizeSpec, FieldRef, MarkEncoding};
7use chartml_core::layout::margins::Margins;
8use chartml_core::layout::legend::{LegendMark, LegendConfig, calculate_legend_layout, generate_legend_elements};
9
/// Renders scatter plots (optionally colored and sized by data fields) into a
/// [`ChartElement`] tree.
///
/// The renderer holds no state; every `render` call works purely from its
/// arguments, so it is cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct ScatterRenderer;

impl ScatterRenderer {
    /// Creates a new scatter renderer.
    pub fn new() -> Self {
        Self
    }
}

impl Default for ScatterRenderer {
    /// Equivalent to [`ScatterRenderer::new`].
    fn default() -> Self {
        Self::new()
    }
}

24impl ChartRenderer for ScatterRenderer {
25    fn render(&self, data: &DataTable, config: &ChartConfig) -> Result<ChartElement, ChartError> {
26        let x_field = get_field_name(&config.visualize.columns)?;
27        let y_field = get_field_name(&config.visualize.rows)?;
28        let color_field = get_color_field(config);
29        let size_field = get_size_field(config);
30
31        let width = config.width;
32        let height = config.height;
33
34        let has_legend = color_field.is_some();
35        let margins = if has_legend {
36            // Add 30px bottom margin for the legend row
37            Margins::new(30.0, 20.0, 70.0, 60.0)
38        } else {
39            Margins::default()
40        };
41        let inner_width = margins.inner_width(width);
42        let inner_height = margins.inner_height(height);
43
44        // Compute domains
45        let x_extent = data.extent(&x_field)
46            .ok_or_else(|| ChartError::DataError(format!("No numeric data for field '{}'", x_field)))?;
47        let y_extent = data.extent(&y_field)
48            .ok_or_else(|| ChartError::DataError(format!("No numeric data for field '{}'", y_field)))?;
49
50        // Use actual data extent for axis domains so tightly-clustered data
51        // fills the plot area instead of being crammed near a forced zero origin.
52        let x_domain = (x_extent.0, x_extent.1);
53        let y_domain = (y_extent.0, y_extent.1);
54        let x_scale = ScaleLinear::new(x_domain, (margins.left, margins.left + inner_width)).nice(5);
55        let y_scale = ScaleLinear::new(y_domain, (margins.top + inner_height, margins.top)).nice(5); // inverted for SVG
56
57        // Size scale (if marks.size present)
58        let size_scale = size_field.as_ref().and_then(|f| {
59            data.extent(f).map(|ext| ScaleSqrt::new(ext, (3.0, 20.0))) // radius 3-20px
60        });
61
62        // Color mapping
63        let color_categories: Vec<String> = if let Some(ref cf) = color_field {
64            data.unique_values(cf)
65        } else {
66            vec![]
67        };
68
69        // Generate scatter points
70        let mut point_elements = Vec::new();
71        for i in 0..data.num_rows() {
72            let x_val = data.get_f64(i, &x_field);
73            let y_val = data.get_f64(i, &y_field);
74
75            if let (Some(x), Some(y)) = (x_val, y_val) {
76                let cx = x_scale.map(x);
77                let cy = y_scale.map(y);
78
79                let r = match (&size_field, &size_scale) {
80                    (Some(sf), Some(ss)) => {
81                        data.get_f64(i, sf).map(|v| ss.map(v)).unwrap_or(5.0)
82                    }
83                    _ => 5.0,
84                };
85
86                let color_idx = if let Some(ref cf) = color_field {
87                    data.get_string(i, cf)
88                        .and_then(|v| color_categories.iter().position(|c| c == &v))
89                        .unwrap_or(0)
90                } else {
91                    0
92                };
93                let fill = config.colors.get(color_idx % config.colors.len())
94                    .cloned()
95                    .unwrap_or_else(|| "#2E7D9A".to_string());
96
97                let label = data.get_string(i, &x_field).unwrap_or_default();
98                let value = format!("{}", y);
99                let el_data = ElementData::new(label, value);
100
101                point_elements.push(ChartElement::Circle {
102                    cx,
103                    cy,
104                    r,
105                    fill,
106                    stroke: Some("#fff".to_string()),
107                    class: "chartml-scatter-point".to_string(),
108                    data: Some(el_data),
109                });
110            }
111        }
112
113        // Build SVG
114        let mut children = Vec::new();
115
116        // Grid lines + axes
117        let x_ticks = x_scale.ticks(((inner_width / 50.0).floor() as usize).clamp(4, 10));
118        let y_ticks = y_scale.ticks(((inner_height / 50.0).floor() as usize).clamp(4, 10));
119        let mut axis_elements = Vec::new();
120
121        // Compute tick steps for formatting
122        let y_tick_step = compute_tick_step(&y_ticks);
123        let x_tick_step = compute_tick_step(&x_ticks);
124
125        // Horizontal grid lines + y-axis ticks
126        for &val in &y_ticks {
127            let y = y_scale.map(val);
128            // Grid line
129            axis_elements.push(ChartElement::Line {
130                x1: margins.left, y1: y, x2: margins.left + inner_width, y2: y,
131                stroke: "#e0e0e0".to_string(), stroke_width: Some(1.0),
132                stroke_dasharray: None, class: "grid-line".to_string(),
133            });
134            // Tick
135            axis_elements.push(ChartElement::Line {
136                x1: margins.left - 5.0, y1: y, x2: margins.left, y2: y,
137                stroke: "#999".to_string(), stroke_width: Some(1.0),
138                stroke_dasharray: None, class: "tick".to_string(),
139            });
140            // Label
141            let label = format_tick_value(val, y_tick_step);
142            axis_elements.push(ChartElement::Text {
143                x: margins.left - 8.0, y,
144                content: label, anchor: TextAnchor::End,
145                dominant_baseline: Some("middle".to_string()),
146                transform: None, font_size: Some("11px".to_string()),
147                font_weight: None,
148                fill: Some("#666".to_string()), class: "tick-label".to_string(), data: None,
149            });
150        }
151
152        // Vertical grid lines + x-axis ticks
153        let x_axis_y = margins.top + inner_height;
154        for &val in &x_ticks {
155            let x = x_scale.map(val);
156            // Grid line
157            axis_elements.push(ChartElement::Line {
158                x1: x, y1: margins.top, x2: x, y2: x_axis_y,
159                stroke: "#e0e0e0".to_string(), stroke_width: Some(1.0),
160                stroke_dasharray: None, class: "grid-line".to_string(),
161            });
162            // Tick
163            axis_elements.push(ChartElement::Line {
164                x1: x, y1: x_axis_y, x2: x, y2: x_axis_y + 5.0,
165                stroke: "#999".to_string(), stroke_width: Some(1.0),
166                stroke_dasharray: None, class: "tick".to_string(),
167            });
168            // Label
169            let label = format_tick_value(val, x_tick_step);
170            axis_elements.push(ChartElement::Text {
171                x, y: x_axis_y + 18.0,
172                content: label, anchor: TextAnchor::Middle,
173                dominant_baseline: None, transform: None,
174                font_size: Some("11px".to_string()), font_weight: None,
175                fill: Some("#666".to_string()),
176                class: "tick-label".to_string(), data: None,
177            });
178        }
179
180        // Axis lines
181        axis_elements.push(ChartElement::Line {
182            x1: margins.left, y1: margins.top, x2: margins.left, y2: x_axis_y,
183            stroke: "#ccc".to_string(), stroke_width: Some(1.0),
184            stroke_dasharray: None, class: "axis-line".to_string(),
185        });
186        axis_elements.push(ChartElement::Line {
187            x1: margins.left, y1: x_axis_y, x2: margins.left + inner_width, y2: x_axis_y,
188            stroke: "#ccc".to_string(), stroke_width: Some(1.0),
189            stroke_dasharray: None, class: "axis-line".to_string(),
190        });
191
192        children.push(ChartElement::Group {
193            class: "axes".to_string(),
194            transform: None,
195            children: axis_elements,
196        });
197
198        // Title is rendered as HTML outside the SVG — not added here.
199
200        // Points group
201        children.push(ChartElement::Group {
202            class: "chartml-scatter-points".to_string(),
203            transform: None,
204            children: point_elements,
205        });
206
207        // Legend
208        if let Some(ref cf) = color_field {
209            let series_names = data.unique_values(cf);
210            if series_names.len() > 1 {
211                let legend_config = LegendConfig::default();
212                let legend_layout = calculate_legend_layout(&series_names, &config.colors, width, &legend_config);
213                let legend_y = height - legend_layout.total_height - 8.0;
214                let legend_elements = generate_legend_elements(
215                    &series_names,
216                    &config.colors,
217                    width,
218                    legend_y,
219                    LegendMark::Circle,
220                );
221                children.push(ChartElement::Group {
222                    class: "legend".to_string(),
223                    transform: None,
224                    children: legend_elements,
225                });
226            }
227        }
228
229        Ok(ChartElement::Svg {
230            viewbox: ViewBox::new(0.0, 0.0, width, height),
231            width: Some(width),
232            height: Some(height),
233            class: "chartml-chart chartml-scatter-chart".to_string(),
234            children,
235        })
236    }
237
238    fn default_dimensions(&self, _spec: &VisualizeSpec) -> Option<Dimensions> {
239        Some(Dimensions::new(400.0))
240    }
241}
242
243/// Extract the field name from an optional FieldRef.
244fn get_field_name(field_ref: &Option<FieldRef>) -> Result<String, ChartError> {
245    match field_ref {
246        Some(FieldRef::Simple(name)) => Ok(name.clone()),
247        Some(FieldRef::Detailed(spec)) => Ok(spec.field.clone()),
248        Some(FieldRef::Multiple(items)) => {
249            // Use the first item
250            match items.first() {
251                Some(chartml_core::spec::FieldRefItem::Simple(name)) => Ok(name.clone()),
252                Some(chartml_core::spec::FieldRefItem::Detailed(spec)) => Ok(spec.field.clone()),
253                None => Err(ChartError::InvalidSpec("Empty field reference list".into())),
254            }
255        }
256        None => Err(ChartError::InvalidSpec("Missing required field reference".into())),
257    }
258}
259
260/// Extract the color field name from marks.color encoding, if present.
261fn get_color_field(config: &ChartConfig) -> Option<String> {
262    config.visualize.marks.as_ref().and_then(|marks| {
263        marks.color.as_ref().map(|enc| match enc {
264            MarkEncoding::Simple(name) => name.clone(),
265            MarkEncoding::Detailed(spec) => spec.field.clone(),
266        })
267    })
268}
269
270/// Extract the size field name from marks.size encoding, if present.
271fn get_size_field(config: &ChartConfig) -> Option<String> {
272    config.visualize.marks.as_ref().and_then(|marks| {
273        marks.size.as_ref().map(|enc| match enc {
274            MarkEncoding::Simple(name) => name.clone(),
275            MarkEncoding::Detailed(spec) => spec.field.clone(),
276        })
277    })
278}
279
/// Returns the spacing between the first two ticks, used to decide how many
/// decimal places the tick labels need.
///
/// Falls back to `1.0` when fewer than two ticks are available.
fn compute_tick_step(ticks: &[f64]) -> f64 {
    match ticks {
        [first, second, ..] => (second - first).abs(),
        _ => 1.0,
    }
}

/// Format a numeric value for use as an axis tick label, with comma separators.
///
/// Mirrors the D3-style `format_tick_value` from the cartesian helpers: the
/// number of decimal places is derived from `tick_step` (D3's
/// `precisionFixed`), and commas are inserted into the integer part.
///
/// Values that round to zero at the chosen precision are printed without a
/// minus sign, as d3-format does: `-0.0` with a step of `1.0` gives `"0"`, and
/// `-0.04` with a step of `0.1` gives `"0.0"`.
fn format_tick_value(value: f64, tick_step: f64) -> String {
    // D3's precisionFixed(step): max(0, -floor(log10(abs(step)))).
    // A (near-)zero step would make log10 diverge, so treat it as integral.
    let step = tick_step.abs();
    let precision = if step < 1e-15 {
        0
    } else {
        (-step.log10().floor()).max(0.0) as usize
    };

    let formatted = format!("{:.prec$}", value, prec = precision);

    // Split off an optional leading minus sign.
    let (negative, unsigned) = match formatted.strip_prefix('-') {
        Some(rest) => (true, rest),
        None => (false, formatted.as_str()),
    };

    // Split on the decimal point; `dec_part` keeps the '.' (or is empty).
    let (int_part, dec_part) = match unsigned.find('.') {
        Some(dot) => unsigned.split_at(dot),
        None => (unsigned, ""),
    };

    // Rust keeps the sign of negative zero and of tiny negatives that round to
    // zero ("-0", "-0.00"); D3 drops it, so do the same.
    let rounds_to_zero = unsigned.bytes().all(|b| b == b'0' || b == b'.');
    let sign = if negative && !rounds_to_zero { "-" } else { "" };

    format!("{}{}{}", sign, insert_commas(int_part), dec_part)
}

/// Insert comma separators into a string of digits, grouping by threes from
/// the right (`"1234567"` becomes `"1,234,567"`).
fn insert_commas(digits: &str) -> String {
    let len = digits.chars().count();
    if len <= 3 {
        return digits.to_string();
    }
    let mut result = String::with_capacity(digits.len() + len / 3);
    // The leading group holds 1-3 characters; every later group holds three.
    let mut remaining_in_group = match len % 3 {
        0 => 3,
        r => r,
    };
    for ch in digits.chars() {
        if remaining_in_group == 0 {
            result.push(',');
            remaining_in_group = 3;
        }
        result.push(ch);
        remaining_in_group -= 1;
    }
    result
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use chartml_core::data::Row;
    use chartml_core::spec::{VisualizeSpec, MarksSpec, MarkEncoding};

    /// Builds a data row from `(column, value)` pairs.
    fn make_row(pairs: &[(&str, serde_json::Value)]) -> Row {
        let mut map = HashMap::new();
        for (k, v) in pairs {
            map.insert(k.to_string(), v.clone());
        }
        map
    }

    /// Four points split across two `category` values ("A" and "B").
    fn make_scatter_data() -> DataTable {
        let rows = vec![
            make_row(&[("price", serde_json::json!(10.0)), ("units", serde_json::json!(100.0)), ("category", serde_json::json!("A"))]),
            make_row(&[("price", serde_json::json!(20.0)), ("units", serde_json::json!(200.0)), ("category", serde_json::json!("B"))]),
            make_row(&[("price", serde_json::json!(30.0)), ("units", serde_json::json!(150.0)), ("category", serde_json::json!("A"))]),
            make_row(&[("price", serde_json::json!(40.0)), ("units", serde_json::json!(300.0)), ("category", serde_json::json!("B"))]),
        ];
        DataTable::from_rows(&rows).unwrap()
    }

    /// Scatter config: x = `price`, y = `units`, colored by `category`.
    fn make_scatter_config() -> ChartConfig {
        ChartConfig {
            visualize: VisualizeSpec {
                chart_type: "scatter".to_string(),
                mode: None,
                orientation: None,
                columns: Some(FieldRef::Simple("price".to_string())),
                rows: Some(FieldRef::Simple("units".to_string())),
                marks: Some(MarksSpec {
                    color: Some(MarkEncoding::Simple("category".to_string())),
                    size: None,
                    shape: None,
                    text: None,
                }),
                axes: None,
                annotations: None,
                style: None,
                value: None,
                label: None,
                format: None,
                compare_with: None,
                invert_trend: None,
                data_labels: None,
            },
            title: Some("Scatter Test".to_string()),
            width: 800.0,
            height: 400.0,
            colors: vec![
                "#2E7D9A".to_string(),
                "#E8533E".to_string(),
                "#4CAF50".to_string(),
            ],
        }
    }

    /// Three points with a numeric `size` column for bubble sizing.
    fn make_bubble_data() -> DataTable {
        let rows = vec![
            make_row(&[("x", serde_json::json!(5.0)), ("y", serde_json::json!(10.0)), ("size", serde_json::json!(100.0))]),
            make_row(&[("x", serde_json::json!(15.0)), ("y", serde_json::json!(20.0)), ("size", serde_json::json!(400.0))]),
            make_row(&[("x", serde_json::json!(25.0)), ("y", serde_json::json!(15.0)), ("size", serde_json::json!(200.0))]),
        ];
        DataTable::from_rows(&rows).unwrap()
    }

    /// Bubble config: x = `x`, y = `y`, sized by `size`, no color encoding.
    fn make_bubble_config() -> ChartConfig {
        ChartConfig {
            visualize: VisualizeSpec {
                chart_type: "scatter".to_string(),
                mode: None,
                orientation: None,
                columns: Some(FieldRef::Simple("x".to_string())),
                rows: Some(FieldRef::Simple("y".to_string())),
                marks: Some(MarksSpec {
                    color: None,
                    size: Some(MarkEncoding::Simple("size".to_string())),
                    shape: None,
                    text: None,
                }),
                axes: None,
                annotations: None,
                style: None,
                value: None,
                label: None,
                format: None,
                compare_with: None,
                invert_trend: None,
                data_labels: None,
            },
            title: None,
            width: 600.0,
            height: 400.0,
            colors: vec!["#2E7D9A".to_string()],
        }
    }

    #[test]
    fn scatter_chart_renders() {
        let renderer = ScatterRenderer::new();
        let result = renderer.render(&make_scatter_data(), &make_scatter_config());
        assert!(result.is_ok(), "render failed: {:?}", result.err());
        let element = result.unwrap();
        let circle_count = count_elements(&element, &|e| matches!(e, ChartElement::Circle { .. }));
        assert_eq!(circle_count, 6); // 4 data points + 2 legend circles (categories A, B)
    }

    #[test]
    fn scatter_with_size_encoding() {
        let renderer = ScatterRenderer::new();
        let result = renderer.render(&make_bubble_data(), &make_bubble_config());
        assert!(result.is_ok(), "render failed: {:?}", result.err());
        let element = result.unwrap();
        let circle_count = count_elements(&element, &|e| matches!(e, ChartElement::Circle { .. }));
        // One circle per data row; without a color encoding no legend is drawn.
        assert_eq!(circle_count, 3);
    }

    #[test]
    fn scatter_empty_data_errors() {
        let renderer = ScatterRenderer::new();
        let data = DataTable::from_rows(&Vec::<Row>::new()).unwrap();
        let result = renderer.render(&data, &make_scatter_config());
        // Both field references are valid, so the failure must come from the
        // missing numeric extent.
        assert!(
            matches!(result, Err(ChartError::DataError(_))),
            "expected DataError, got {:?}",
            result.err()
        );
    }
}