Skip to main content

chartml_chart_scatter/
lib.rs

1use chartml_core::plugin::{ChartRenderer, ChartConfig};
2use chartml_core::data::DataTable;
3use chartml_core::element::*;
4use chartml_core::error::ChartError;
5use chartml_core::scales::{ScaleLinear, ScaleSqrt};
6use chartml_core::spec::{VisualizeSpec, FieldRef, MarkEncoding};
7use chartml_core::layout::margins::{MarginConfig, calculate_margins};
8use chartml_core::layout::labels::{TextMetrics, approximate_text_width_at, format_tick_value_si, measure_text};
9use chartml_core::layout::legend::{LegendMark, LegendConfig, calculate_legend_layout, generate_legend_elements};
10use chartml_core::theme::GridStyle;
11
12/// Whether horizontal (constant-y) gridlines should be drawn.
13#[inline]
14fn should_draw_horizontal_grid(style: &GridStyle) -> bool {
15    matches!(style, GridStyle::Both | GridStyle::HorizontalOnly)
16}
17
18/// Whether vertical (constant-x) gridlines should be drawn.
19#[inline]
20fn should_draw_vertical_grid(style: &GridStyle) -> bool {
21    matches!(style, GridStyle::Both | GridStyle::VerticalOnly)
22}
23
/// Renderer for scatter charts (and bubble charts when a size encoding is
/// present).
///
/// The renderer holds no state; everything it needs comes from the
/// [`ChartConfig`] passed to [`ChartRenderer::render`].
#[derive(Debug, Clone, Copy, Default)]
pub struct ScatterRenderer;

impl ScatterRenderer {
    /// Create a new scatter renderer.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
37
38impl ChartRenderer for ScatterRenderer {
39    fn render(&self, data: &DataTable, config: &ChartConfig) -> Result<ChartElement, ChartError> {
40        let x_field = get_field_name(&config.visualize.columns)?;
41        let y_field = get_field_name(&config.visualize.rows)?;
42        let color_field = get_color_field(config);
43        let size_field = get_size_field(config);
44
45        let width = config.width;
46        let height = config.height;
47
48        // Color mapping — compute early so has_legend is accurate
49        let color_categories: Vec<String> = if let Some(ref cf) = color_field {
50            data.unique_values(cf)
51        } else {
52            vec![]
53        };
54
55        let has_legend = color_categories.len() > 1;
56        let legend_height = if has_legend {
57            let legend_config = LegendConfig {
58                text_metrics: TextMetrics::from_theme_legend(&config.theme),
59                ..LegendConfig::default()
60            };
61            calculate_legend_layout(&color_categories, &config.colors, width, &legend_config).total_height
62        } else {
63            0.0
64        };
65        let margin_config = MarginConfig {
66            legend_height,
67            chart_height: height,
68            tick_value_metrics: TextMetrics::from_theme_tick_value(&config.theme),
69            axis_label_metrics: TextMetrics::from_theme_axis_label(&config.theme),
70            ..Default::default()
71        };
72        let margins = calculate_margins(&margin_config);
73        let inner_width = margins.inner_width(width);
74        let inner_height = margins.inner_height(height);
75
76        // Compute domains
77        let x_extent = data.extent(&x_field)
78            .ok_or_else(|| ChartError::DataError(format!("No numeric data for field '{}'", x_field)))?;
79        let y_extent = data.extent(&y_field)
80            .ok_or_else(|| ChartError::DataError(format!("No numeric data for field '{}'", y_field)))?;
81
82        // Use actual data extent for axis domains so tightly-clustered data
83        // fills the plot area instead of being crammed near a forced zero origin.
84        let x_domain = (x_extent.0, x_extent.1);
85        let y_domain = (y_extent.0, y_extent.1);
86        // Size scale (if marks.size present)
87        let size_scale = size_field.as_ref().and_then(|f| {
88            data.extent(f).map(|ext| ScaleSqrt::new(ext, (3.0, 20.0))) // radius 3-20px
89        });
90
91        // Compute maximum point radius to inset scale ranges, ensuring circles
92        // don't extend past the SVG edges. Use at least 5% of inner dimension.
93        let max_radius = match (&size_field, &size_scale) {
94            (Some(sf), Some(ss)) => {
95                let mut mr = 5.0_f64;
96                for i in 0..data.num_rows() {
97                    if let Some(v) = data.get_f64(i, sf) {
98                        mr = mr.max(ss.map(v));
99                    }
100                }
101                mr
102            }
103            _ => 5.0,
104        };
105        let x_inset = max_radius.max(inner_width * 0.05);
106        let y_inset = max_radius.max(inner_height * 0.05);
107
108        let x_scale = ScaleLinear::new(x_domain, (margins.left + x_inset, margins.left + inner_width - x_inset)).nice(5);
109        let y_scale = ScaleLinear::new(y_domain, (margins.top + inner_height - y_inset, margins.top + y_inset)).nice(5); // inverted for SVG
110
111        // Generate scatter points
112        let mut point_elements = Vec::new();
113        for i in 0..data.num_rows() {
114            let x_val = data.get_f64(i, &x_field);
115            let y_val = data.get_f64(i, &y_field);
116
117            if let (Some(x), Some(y)) = (x_val, y_val) {
118                let cx = x_scale.map(x);
119                let cy = y_scale.map(y);
120
121                let r = match (&size_field, &size_scale) {
122                    (Some(sf), Some(ss)) => {
123                        // Spec-driven sized case — the `size` field overrides
124                        // `theme.dot_radius` per-point. The inner `unwrap_or(5.0)`
125                        // is the fallback for rows with a missing size value and
126                        // is intentionally unchanged by Phase 5 theme wiring.
127                        data.get_f64(i, sf).map(|v| ss.map(v)).unwrap_or(5.0)
128                    }
129                    _ => config.theme.dot_radius as f64,
130                };
131
132                let color_idx = if let Some(ref cf) = color_field {
133                    data.get_string(i, cf)
134                        .and_then(|v| color_categories.iter().position(|c| c == &v))
135                        .unwrap_or(0)
136                } else {
137                    0
138                };
139                let fill = config.colors.get(color_idx % config.colors.len())
140                    .cloned()
141                    .unwrap_or_else(|| "#2E7D9A".to_string());
142
143                let label = data.get_string(i, &x_field).unwrap_or_default();
144                let value = format!("{}", y);
145                let mut el_data = ElementData::new(label, value);
146                if let Some(ref cf) = color_field {
147                    if let Some(series_name) = data.get_string(i, cf) {
148                        el_data = el_data.with_series(series_name);
149                    }
150                }
151
152                if let Some(halo) = emit_dot_halo_if_enabled(&config.theme, cx, cy, r) {
153                    point_elements.push(halo);
154                }
155                point_elements.push(ChartElement::Circle {
156                    cx,
157                    cy,
158                    r,
159                    fill,
160                    stroke: Some(config.theme.bg.clone()),
161                    class: "chartml-scatter-point dot-marker".to_string(),
162                    data: Some(el_data),
163                });
164            }
165        }
166
167        // Build SVG
168        let mut children = Vec::new();
169
170        // Grid lines + axes
171        let x_ticks = x_scale.ticks(((inner_width / 50.0).floor() as usize).clamp(4, 10));
172        let y_ticks = y_scale.ticks(((inner_height / 50.0).floor() as usize).clamp(4, 10));
173        let mut axis_elements = Vec::new();
174
175        // Compute tick steps for formatting
176        let y_tick_step = compute_tick_step(&y_ticks);
177        let x_tick_step = compute_tick_step(&x_ticks);
178
179        // Y-axis label skip factor: skip labels that would overlap vertically
180        let y_label_height = 18.0; // 14px font + 4px spacing
181        let y_skip = if y_ticks.len() > 1 {
182            let px_per_tick = inner_height / (y_ticks.len() - 1) as f64;
183            (y_label_height / px_per_tick).ceil() as usize
184        } else {
185            1
186        }.max(1);
187
188        // Horizontal grid lines + y-axis ticks
189        let draw_h_grid = should_draw_horizontal_grid(&config.theme.grid_style);
190        let draw_v_grid = should_draw_vertical_grid(&config.theme.grid_style);
191        for (i, &val) in y_ticks.iter().enumerate() {
192            let y = y_scale.map(val);
193            // Grid line — gated on grid_style (horizontal: constant y)
194            if draw_h_grid {
195                axis_elements.push(ChartElement::Line {
196                    x1: margins.left, y1: y, x2: margins.left + inner_width, y2: y,
197                    stroke: config.theme.grid.clone(), stroke_width: Some(config.theme.grid_line_weight as f64),
198                    stroke_dasharray: None, class: "grid-line".to_string(),
199                });
200            }
201            // Tick mark — always rendered
202            axis_elements.push(ChartElement::Line {
203                x1: margins.left - 5.0, y1: y, x2: margins.left, y2: y,
204                stroke: config.theme.tick.clone(), stroke_width: Some(config.theme.axis_line_weight as f64),
205                stroke_dasharray: None, class: "tick".to_string(),
206            });
207            // Label — only if not skipped
208            if i % y_skip == 0 {
209                let label = format_tick_value_si(val, y_tick_step);
210                let ts = TextStyle::for_role(&config.theme, TextRole::TickValue);
211                axis_elements.push(ChartElement::Text {
212                    x: margins.left - 8.0, y,
213                    content: label, anchor: TextAnchor::End,
214                    dominant_baseline: Some("middle".to_string()),
215                    transform: None,
216                    font_family: ts.font_family,
217                    font_size: ts.font_size,
218                    font_weight: ts.font_weight,
219                    letter_spacing: ts.letter_spacing,
220                    text_transform: ts.text_transform,
221                    fill: Some(config.theme.text_secondary.clone()), class: "tick-label tick-value".to_string(), data: None,
222                });
223            }
224        }
225
226        // X-axis label skip factor: skip labels that would overlap horizontally
227        // Scatter historically measured X tick labels at 11px (one px below
228        // the 12px default) — preserve that exactly when the theme is the
229        // legacy default to keep golden snapshots byte-identical. Any theme
230        // override switches to the full theme-aware measurement.
231        let scatter_tick_metrics = TextMetrics::from_theme_tick_value(&config.theme);
232        let x_label_widths: Vec<f64> = x_ticks.iter()
233            .map(|&v| {
234                let label = format_tick_value_si(v, x_tick_step);
235                if scatter_tick_metrics.is_legacy_default() {
236                    approximate_text_width_at(&label, 11.0)
237                } else {
238                    measure_text(&label, &scatter_tick_metrics)
239                }
240            })
241            .collect();
242        let x_widest = x_label_widths.iter().cloned().fold(0.0_f64, f64::max);
243        let x_skip = if x_ticks.len() > 1 {
244            let px_per_tick = inner_width / (x_ticks.len() - 1) as f64;
245            let needed = x_widest + 8.0; // label width + small gap
246            (needed / px_per_tick).ceil() as usize
247        } else {
248            1
249        }.max(1);
250
251        // Vertical grid lines + x-axis ticks
252        let x_axis_y = margins.top + inner_height;
253        for (i, &val) in x_ticks.iter().enumerate() {
254            let x = x_scale.map(val);
255            // Grid line — gated on grid_style (vertical: constant x)
256            if draw_v_grid {
257                axis_elements.push(ChartElement::Line {
258                    x1: x, y1: margins.top, x2: x, y2: x_axis_y,
259                    stroke: config.theme.grid.clone(), stroke_width: Some(config.theme.grid_line_weight as f64),
260                    stroke_dasharray: None, class: "grid-line".to_string(),
261                });
262            }
263            // Tick mark — always rendered
264            axis_elements.push(ChartElement::Line {
265                x1: x, y1: x_axis_y, x2: x, y2: x_axis_y + 5.0,
266                stroke: config.theme.tick.clone(), stroke_width: Some(config.theme.axis_line_weight as f64),
267                stroke_dasharray: None, class: "tick".to_string(),
268            });
269            // Label — only if not skipped
270            if i % x_skip == 0 {
271                let label = format_tick_value_si(val, x_tick_step);
272                let ts = TextStyle::for_role(&config.theme, TextRole::TickValue);
273                axis_elements.push(ChartElement::Text {
274                    x, y: x_axis_y + 18.0,
275                    content: label, anchor: TextAnchor::Middle,
276                    dominant_baseline: None, transform: None,
277                    font_family: ts.font_family,
278                    font_size: ts.font_size,
279                    font_weight: ts.font_weight,
280                    letter_spacing: ts.letter_spacing,
281                    text_transform: ts.text_transform,
282                    fill: Some(config.theme.text_secondary.clone()),
283                    class: "tick-label tick-value".to_string(), data: None,
284                });
285            }
286        }
287
288        // Axis lines
289        axis_elements.push(ChartElement::Line {
290            x1: margins.left, y1: margins.top, x2: margins.left, y2: x_axis_y,
291            stroke: config.theme.axis_line.clone(), stroke_width: Some(config.theme.axis_line_weight as f64),
292            stroke_dasharray: None, class: "axis-line".to_string(),
293        });
294        axis_elements.push(ChartElement::Line {
295            x1: margins.left, y1: x_axis_y, x2: margins.left + inner_width, y2: x_axis_y,
296            stroke: config.theme.axis_line.clone(), stroke_width: Some(config.theme.axis_line_weight as f64),
297            stroke_dasharray: None, class: "axis-line".to_string(),
298        });
299
300        children.push(ChartElement::Group {
301            class: "axes".to_string(),
302            transform: None,
303            children: axis_elements,
304        });
305
306        // Title is rendered as HTML outside the SVG — not added here.
307
308        // Points group
309        children.push(ChartElement::Group {
310            class: "chartml-scatter-points".to_string(),
311            transform: None,
312            children: point_elements,
313        });
314
315        // Legend — reuse color_categories computed earlier
316        if color_categories.len() > 1 {
317            let legend_config = LegendConfig {
318                text_metrics: TextMetrics::from_theme_legend(&config.theme),
319                ..LegendConfig::default()
320            };
321            let legend_layout = calculate_legend_layout(&color_categories, &config.colors, width, &legend_config);
322            let legend_y = height - legend_layout.total_height - 8.0;
323            let legend_elements = generate_legend_elements(
324                &color_categories,
325                &config.colors,
326                width,
327                legend_y,
328                LegendMark::Circle,
329                &config.theme,
330            );
331            children.push(ChartElement::Group {
332                class: "legend".to_string(),
333                transform: None,
334                children: legend_elements,
335            });
336        }
337
338        Ok(ChartElement::Svg {
339            viewbox: ViewBox::new(0.0, 0.0, width, height),
340            width: Some(width),
341            height: Some(height),
342            class: "chartml-chart chartml-scatter-chart".to_string(),
343            children,
344        })
345    }
346
347    fn default_dimensions(&self, _spec: &VisualizeSpec) -> Option<Dimensions> {
348        Some(Dimensions::new(400.0))
349    }
350}
351
352/// Extract the field name from an optional FieldRef.
353fn get_field_name(field_ref: &Option<FieldRef>) -> Result<String, ChartError> {
354    fn field_or_err(spec: &chartml_core::spec::FieldSpec) -> Result<String, ChartError> {
355        spec.field
356            .clone()
357            .ok_or_else(|| ChartError::InvalidSpec("Field spec has no `field` (range-mark specs are not supported for scatter axes)".into()))
358    }
359    match field_ref {
360        Some(FieldRef::Simple(name)) => Ok(name.clone()),
361        Some(FieldRef::Detailed(spec)) => field_or_err(spec),
362        Some(FieldRef::Multiple(items)) => {
363            // Use the first item
364            match items.first() {
365                Some(chartml_core::spec::FieldRefItem::Simple(name)) => Ok(name.clone()),
366                Some(chartml_core::spec::FieldRefItem::Detailed(spec)) => field_or_err(spec),
367                None => Err(ChartError::InvalidSpec("Empty field reference list".into())),
368            }
369        }
370        None => Err(ChartError::InvalidSpec("Missing required field reference".into())),
371    }
372}
373
374/// Extract the color field name from marks.color encoding, if present.
375fn get_color_field(config: &ChartConfig) -> Option<String> {
376    config.visualize.marks.as_ref().and_then(|marks| {
377        marks.color.as_ref().map(|enc| match enc {
378            MarkEncoding::Simple(name) => name.clone(),
379            MarkEncoding::Detailed(spec) => spec.field.clone(),
380        })
381    })
382}
383
384/// Extract the size field name from marks.size encoding, if present.
385fn get_size_field(config: &ChartConfig) -> Option<String> {
386    config.visualize.marks.as_ref().and_then(|marks| {
387        marks.size.as_ref().map(|enc| match enc {
388            MarkEncoding::Simple(name) => name.clone(),
389            MarkEncoding::Detailed(spec) => spec.field.clone(),
390        })
391    })
392}
393
/// Spacing between the first two ticks, used as the precision hint when
/// formatting tick labels. With fewer than two ticks there is no spacing to
/// measure, so `1.0` is returned.
fn compute_tick_step(ticks: &[f64]) -> f64 {
    match ticks {
        [first, second, ..] => (second - first).abs(),
        _ => 1.0,
    }
}
402
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use chartml_core::data::Row;
    use chartml_core::spec::{VisualizeSpec, MarksSpec, MarkEncoding};

    fn make_row(pairs: &[(&str, serde_json::Value)]) -> Row {
        let mut map = HashMap::new();
        for (k, v) in pairs {
            map.insert(k.to_string(), v.clone());
        }
        map
    }

    fn make_scatter_data() -> DataTable {
        let rows = vec![
            make_row(&[("price", serde_json::json!(10.0)), ("units", serde_json::json!(100.0)), ("category", serde_json::json!("A"))]),
            make_row(&[("price", serde_json::json!(20.0)), ("units", serde_json::json!(200.0)), ("category", serde_json::json!("B"))]),
            make_row(&[("price", serde_json::json!(30.0)), ("units", serde_json::json!(150.0)), ("category", serde_json::json!("A"))]),
            make_row(&[("price", serde_json::json!(40.0)), ("units", serde_json::json!(300.0)), ("category", serde_json::json!("B"))]),
        ];
        DataTable::from_rows(&rows).unwrap()
    }

    fn make_scatter_config() -> ChartConfig {
        ChartConfig {
            visualize: VisualizeSpec {
                chart_type: "scatter".to_string(),
                mode: None,
                orientation: None,
                columns: Some(FieldRef::Simple("price".to_string())),
                rows: Some(FieldRef::Simple("units".to_string())),
                marks: Some(MarksSpec {
                    color: Some(MarkEncoding::Simple("category".to_string())),
                    size: None,
                    shape: None,
                    text: None,
                }),
                axes: None,
                annotations: None,
                style: None,
                value: None,
                label: None,
                format: None,
                compare_with: None,
                invert_trend: None,
                data_labels: None,
            },
            title: Some("Scatter Test".to_string()),
            width: 800.0,
            height: 400.0,
            colors: vec![
                "#2E7D9A".to_string(),
                "#E8533E".to_string(),
                "#4CAF50".to_string(),
            ],
            theme: chartml_core::theme::Theme::default(),
        }
    }

    fn make_bubble_data() -> DataTable {
        let rows = vec![
            make_row(&[("x", serde_json::json!(5.0)), ("y", serde_json::json!(10.0)), ("size", serde_json::json!(100.0))]),
            make_row(&[("x", serde_json::json!(15.0)), ("y", serde_json::json!(20.0)), ("size", serde_json::json!(400.0))]),
            make_row(&[("x", serde_json::json!(25.0)), ("y", serde_json::json!(15.0)), ("size", serde_json::json!(200.0))]),
        ];
        DataTable::from_rows(&rows).unwrap()
    }

    fn make_bubble_config() -> ChartConfig {
        ChartConfig {
            visualize: VisualizeSpec {
                chart_type: "scatter".to_string(),
                mode: None,
                orientation: None,
                columns: Some(FieldRef::Simple("x".to_string())),
                rows: Some(FieldRef::Simple("y".to_string())),
                marks: Some(MarksSpec {
                    color: None,
                    size: Some(MarkEncoding::Simple("size".to_string())),
                    shape: None,
                    text: None,
                }),
                axes: None,
                annotations: None,
                style: None,
                value: None,
                label: None,
                format: None,
                compare_with: None,
                invert_trend: None,
                data_labels: None,
            },
            title: None,
            width: 600.0,
            height: 400.0,
            colors: vec!["#2E7D9A".to_string()],
            theme: chartml_core::theme::Theme::default(),
        }
    }

    #[test]
    fn scatter_chart_renders() {
        let renderer = ScatterRenderer::new();
        let result = renderer.render(&make_scatter_data(), &make_scatter_config());
        assert!(result.is_ok(), "render failed: {:?}", result.err());
        let element = result.unwrap();
        let circle_count = count_elements(&element, &|e| matches!(e, ChartElement::Circle { .. }));
        assert_eq!(circle_count, 6); // 4 data points + 2 legend circles (categories A, B)
    }

    #[test]
    fn scatter_with_size_encoding() {
        let renderer = ScatterRenderer::new();
        let result = renderer.render(&make_bubble_data(), &make_bubble_config());
        assert!(result.is_ok(), "render failed: {:?}", result.err());
        let element = result.unwrap();
        let circle_count = count_elements(&element, &|e| matches!(e, ChartElement::Circle { .. }));
        assert!(circle_count > 0);
    }

    #[test]
    fn scatter_data_series_populated_with_color_encoding() {
        let renderer = ScatterRenderer::new();
        let element = renderer.render(&make_scatter_data(), &make_scatter_config()).unwrap();
        // Collect series values from data circles (not legend circles)
        let mut series_values = Vec::new();
        fn collect_series(el: &ChartElement, out: &mut Vec<Option<String>>) {
            match el {
                ChartElement::Circle { data: Some(d), class, .. } if !class.contains("legend") => {
                    out.push(d.series.clone());
                }
                ChartElement::Svg { children, .. } | ChartElement::Group { children, .. } => {
                    for child in children { collect_series(child, out); }
                }
                _ => {}
            }
        }
        collect_series(&element, &mut series_values);
        assert_eq!(series_values.len(), 4, "Expected 4 data circles");
        for (i, series) in series_values.iter().enumerate() {
            assert!(series.is_some(), "Circle {} has null data.series", i);
        }
        // Verify actual category values are present
        let series_strs: Vec<&str> = series_values.iter().map(|s| s.as_deref().unwrap()).collect();
        assert!(series_strs.contains(&"A"));
        assert!(series_strs.contains(&"B"));
    }

    #[test]
    fn scatter_empty_data_errors() {
        let renderer = ScatterRenderer::new();
        let data = DataTable::from_rows(&Vec::<Row>::new()).unwrap();
        let result = renderer.render(&data, &make_scatter_config());
        assert!(result.is_err());
    }

    // ----- Phase 8: dot_halo wiring -----

    /// Collect (index, class) of every Circle and dot-halo Path in a flat
    /// traversal order. Used by Phase 8 tests to assert halo-before-dot
    /// ordering and counts. `index` is the entry's position in the result.
    fn collect_dot_and_halo_order(el: &ChartElement) -> Vec<(usize, String)> {
        let mut out = Vec::new();
        fn visit(el: &ChartElement, out: &mut Vec<(usize, String)>) {
            match el {
                ChartElement::Circle { class, .. } => {
                    let idx = out.len();
                    out.push((idx, class.clone()));
                }
                ChartElement::Path { class, .. } if class == "dot-halo" => {
                    let idx = out.len();
                    out.push((idx, class.clone()));
                }
                ChartElement::Svg { children, .. }
                | ChartElement::Group { children, .. } => {
                    for c in children {
                        visit(c, out);
                    }
                }
                _ => {}
            }
        }
        visit(el, &mut out);
        out
    }

    fn count_halos(el: &ChartElement) -> usize {
        count_elements(el, &|e| matches!(e, ChartElement::Path { class, .. } if class == "dot-halo"))
    }

    #[test]
    fn phase8_scatter_default_theme_emits_no_halo() {
        use chartml_core::theme::Theme;
        let renderer = ScatterRenderer::new();
        let mut config = make_scatter_config();
        config.theme = Theme::default();
        let element = renderer.render(&make_scatter_data(), &config).unwrap();
        assert_eq!(count_halos(&element), 0, "default theme must emit zero dot-halo elements");
    }

    #[test]
    fn phase8_scatter_halo_color_emits_one_halo_per_point() {
        use chartml_core::theme::Theme;
        let renderer = ScatterRenderer::new();
        let mut config = make_scatter_config();
        let mut t = Theme::default();
        t.dot_halo_color = Some("#ffffff".to_string());
        t.dot_halo_width = 1.5;
        config.theme = t;
        let element = renderer.render(&make_scatter_data(), &config).unwrap();

        // 4 data points → 4 halos (legend circles don't get halos).
        assert_eq!(count_halos(&element), 4);

        // Data-point circles should also still number 4.
        let data_circles = count_elements(&element, &|e| {
            matches!(e, ChartElement::Circle { class, .. } if class.contains("chartml-scatter-point"))
        });
        assert_eq!(data_circles, 4);

        // Verify halo stroke + width on at least one emitted halo.
        fn find_halo(el: &ChartElement) -> Option<(String, f64)> {
            match el {
                ChartElement::Path { class, stroke, stroke_width, .. } if class == "dot-halo" => {
                    Some((stroke.clone().unwrap_or_default(), stroke_width.unwrap_or(-1.0)))
                }
                ChartElement::Svg { children, .. } | ChartElement::Group { children, .. } => {
                    children.iter().find_map(find_halo)
                }
                _ => None,
            }
        }
        let (stroke, width) = find_halo(&element).expect("at least one halo");
        assert_eq!(stroke, "#ffffff");
        assert!((width - 1.5).abs() < 1e-9, "halo stroke-width {} != 1.5", width);
    }

    #[test]
    fn phase8_scatter_halo_precedes_dot_in_order() {
        use chartml_core::theme::Theme;
        let renderer = ScatterRenderer::new();
        let mut config = make_scatter_config();
        let mut t = Theme::default();
        t.dot_halo_color = Some("#ffffff".to_string());
        t.dot_halo_width = 1.5;
        config.theme = t;
        let element = renderer.render(&make_scatter_data(), &config).unwrap();

        // Walk the points group and assert every dot-marker circle is
        // preceded immediately by a dot-halo.
        fn find_points_group(el: &ChartElement) -> Option<&Vec<ChartElement>> {
            match el {
                ChartElement::Group { class, children, .. }
                    if class == "chartml-scatter-points" => Some(children),
                ChartElement::Svg { children, .. } | ChartElement::Group { children, .. } => {
                    children.iter().find_map(find_points_group)
                }
                _ => None,
            }
        }
        let points = find_points_group(&element).expect("points group");
        let mut pair_count = 0;
        let mut iter = points.iter().peekable();
        while let Some(el) = iter.next() {
            if let ChartElement::Path { class, .. } = el {
                if class == "dot-halo" {
                    let next = iter.peek().expect("halo must be followed by dot");
                    match next {
                        ChartElement::Circle { class, .. } => {
                            assert!(class.contains("dot-marker"));
                            pair_count += 1;
                        }
                        _ => panic!("halo not immediately followed by a Circle"),
                    }
                }
            }
        }
        assert_eq!(pair_count, 4);
    }

    #[test]
    fn phase8_bubble_halo_radius_tracks_per_point_size() {
        // For a bubble chart each point has a distinct radius from the size
        // scale. The halo's path must encode that same radius, not a static
        // theme.dot_radius default.
        use chartml_core::theme::Theme;
        let renderer = ScatterRenderer::new();
        let mut config = make_bubble_config();
        let mut t = Theme::default();
        t.dot_halo_color = Some("#000000".to_string());
        t.dot_halo_width = 1.0;
        config.theme = t;
        let element = renderer.render(&make_bubble_data(), &config).unwrap();

        // Collect halo path d strings and dot radii in traversal order.
        fn collect(el: &ChartElement, halos: &mut Vec<String>, dots: &mut Vec<f64>) {
            match el {
                ChartElement::Path { class, d, .. } if class == "dot-halo" => halos.push(d.clone()),
                ChartElement::Circle { class, r, .. } if class.contains("chartml-scatter-point") => {
                    dots.push(*r);
                }
                ChartElement::Svg { children, .. } | ChartElement::Group { children, .. } => {
                    for c in children { collect(c, halos, dots); }
                }
                _ => {}
            }
        }
        let mut halos = Vec::new();
        let mut dots = Vec::new();
        collect(&element, &mut halos, &mut dots);
        assert_eq!(halos.len(), 3);
        assert_eq!(dots.len(), 3);
        // Distinct radii → distinct halo d strings.
        let unique_d: std::collections::HashSet<&String> = halos.iter().collect();
        assert_eq!(unique_d.len(), 3, "bubble halos should have 3 distinct path d values, one per radius");
        // Each halo d must contain the per-point radius.
        for (d, r) in halos.iter().zip(dots.iter()) {
            assert!(
                d.contains(&format!("{},{}", r, r)) || d.contains(&format!(" {},", r)),
                "halo d {:?} should encode per-point radius {}",
                d,
                r
            );
        }
    }

    #[test]
    fn phase8_scatter_traversal_order_sanity() {
        // With halo enabled, the flat traversal must contain one halo per
        // data point, and every halo must be immediately followed by its
        // scatter-point dot (halo/dot alternation).
        use chartml_core::theme::Theme;
        let renderer = ScatterRenderer::new();
        let mut config = make_scatter_config();
        let mut t = Theme::default();
        t.dot_halo_color = Some("#ffffff".to_string());
        t.dot_halo_width = 1.0;
        config.theme = t;
        let element = renderer.render(&make_scatter_data(), &config).unwrap();
        let order = collect_dot_and_halo_order(&element);

        let halo_positions: Vec<usize> = order
            .iter()
            .filter(|(_, c)| c == "dot-halo")
            .map(|(idx, _)| *idx)
            .collect();
        assert_eq!(halo_positions.len(), 4, "expected one dot-halo per data point");

        for &pos in &halo_positions {
            let (_, next_class) = order
                .get(pos + 1)
                .unwrap_or_else(|| panic!("dot-halo at position {} has no following element", pos));
            assert!(
                next_class.contains("dot-marker"),
                "dot-halo at position {} is followed by {:?}, expected a dot-marker circle",
                pos,
                next_class
            );
        }
    }

    // ----- Phase 6: theme.grid_style gating -----

    fn count_scatter_grid_lines(el: &ChartElement) -> usize {
        let mut n = 0usize;
        fn visit(el: &ChartElement, n: &mut usize) {
            match el {
                ChartElement::Line { class, .. } => {
                    if class.split_whitespace().any(|c| c == "grid-line") {
                        *n += 1;
                    }
                }
                ChartElement::Svg { children, .. }
                | ChartElement::Group { children, .. } => {
                    for c in children {
                        visit(c, n);
                    }
                }
                _ => {}
            }
        }
        visit(el, &mut n);
        n
    }

    #[test]
    fn phase6_scatter_grid_style_none_skips_all_gridlines() {
        use chartml_core::theme::{GridStyle, Theme};
        let renderer = ScatterRenderer::new();
        let data = make_scatter_data();
        let mut config = make_scatter_config();

        // Sanity: Both (default) produces > 0 gridlines.
        let baseline = renderer.render(&data, &config).unwrap();
        let baseline_count = count_scatter_grid_lines(&baseline);
        assert!(
            baseline_count > 0,
            "scatter default (GridStyle::Both) should emit gridlines, got {}",
            baseline_count
        );

        // None: zero gridlines of either orientation.
        let mut t = Theme::default();
        t.grid_style = GridStyle::None;
        config.theme = t;
        let element = renderer.render(&data, &config).unwrap();
        let n = count_scatter_grid_lines(&element);
        assert_eq!(n, 0, "GridStyle::None: expected 0 scatter gridlines, got {}", n);
    }
}