//! Source: `chartml_chart_scatter/lib.rs`.

1use chartml_core::plugin::{ChartRenderer, ChartConfig};
2use chartml_core::data::DataTable;
3use chartml_core::element::*;
4use chartml_core::error::ChartError;
5use chartml_core::scales::{ScaleLinear, ScaleSqrt};
6use chartml_core::spec::{VisualizeSpec, FieldRef, MarkEncoding};
7use chartml_core::layout::margins::{MarginConfig, calculate_margins};
8use chartml_core::layout::labels::{TextMetrics, approximate_text_width_at, format_tick_value_si, measure_text};
9use chartml_core::layout::legend::{LegendMark, LegendConfig, calculate_legend_layout, generate_legend_elements};
10use chartml_core::theme::GridStyle;
11
12/// Whether horizontal (constant-y) gridlines should be drawn.
13#[inline]
14fn should_draw_horizontal_grid(style: &GridStyle) -> bool {
15    matches!(style, GridStyle::Both | GridStyle::HorizontalOnly)
16}
17
18/// Whether vertical (constant-x) gridlines should be drawn.
19#[inline]
20fn should_draw_vertical_grid(style: &GridStyle) -> bool {
21    matches!(style, GridStyle::Both | GridStyle::VerticalOnly)
22}
23
/// Renderer for scatter charts (bubble charts when `marks.size` is set).
///
/// The renderer is a stateless unit struct: everything it needs arrives
/// through the `ChartConfig` passed to `render`.
#[derive(Debug, Clone, Copy)]
pub struct ScatterRenderer;

impl ScatterRenderer {
    /// Create a scatter renderer. Usable in `const` contexts.
    pub const fn new() -> Self {
        Self
    }
}

impl Default for ScatterRenderer {
    /// Equivalent to [`ScatterRenderer::new`].
    fn default() -> Self {
        Self::new()
    }
}
37
38impl ChartRenderer for ScatterRenderer {
39    fn render(&self, data: &DataTable, config: &ChartConfig) -> Result<ChartElement, ChartError> {
40        let x_field = get_field_name(&config.visualize.columns)?;
41        let y_field = get_field_name(&config.visualize.rows)?;
42        let color_field = get_color_field(config);
43        let size_field = get_size_field(config);
44
45        let width = config.width;
46        let height = config.height;
47
48        // Color mapping — compute early so has_legend is accurate
49        let color_categories: Vec<String> = if let Some(ref cf) = color_field {
50            data.unique_values(cf)
51        } else {
52            vec![]
53        };
54
55        let has_legend = color_categories.len() > 1;
56        let legend_height = if has_legend {
57            let legend_config = LegendConfig {
58                text_metrics: TextMetrics::from_theme_legend(&config.theme),
59                ..LegendConfig::default()
60            };
61            calculate_legend_layout(&color_categories, &config.colors, width, &legend_config).total_height
62        } else {
63            0.0
64        };
65        let margin_config = MarginConfig {
66            legend_height,
67            chart_height: height,
68            tick_value_metrics: TextMetrics::from_theme_tick_value(&config.theme),
69            axis_label_metrics: TextMetrics::from_theme_axis_label(&config.theme),
70            ..Default::default()
71        };
72        let margins = calculate_margins(&margin_config);
73        let inner_width = margins.inner_width(width);
74        let inner_height = margins.inner_height(height);
75
76        // Compute domains
77        let x_extent = data.extent(&x_field)
78            .ok_or_else(|| ChartError::DataError(format!("No numeric data for field '{}'", x_field)))?;
79        let y_extent = data.extent(&y_field)
80            .ok_or_else(|| ChartError::DataError(format!("No numeric data for field '{}'", y_field)))?;
81
82        // Use actual data extent for axis domains so tightly-clustered data
83        // fills the plot area instead of being crammed near a forced zero origin.
84        let x_domain = (x_extent.0, x_extent.1);
85        let y_domain = (y_extent.0, y_extent.1);
86        // Size scale (if marks.size present)
87        let size_scale = size_field.as_ref().and_then(|f| {
88            data.extent(f).map(|ext| ScaleSqrt::new(ext, (3.0, 20.0))) // radius 3-20px
89        });
90
91        // Compute maximum point radius to inset scale ranges, ensuring circles
92        // don't extend past the SVG edges. Use at least 5% of inner dimension.
93        let max_radius = match (&size_field, &size_scale) {
94            (Some(sf), Some(ss)) => {
95                let mut mr = 5.0_f64;
96                for i in 0..data.num_rows() {
97                    if let Some(v) = data.get_f64(i, sf) {
98                        mr = mr.max(ss.map(v));
99                    }
100                }
101                mr
102            }
103            _ => 5.0,
104        };
105        let x_inset = max_radius.max(inner_width * 0.05);
106        let y_inset = max_radius.max(inner_height * 0.05);
107
108        let x_scale = ScaleLinear::new(x_domain, (margins.left + x_inset, margins.left + inner_width - x_inset)).nice(5);
109        let y_scale = ScaleLinear::new(y_domain, (margins.top + inner_height - y_inset, margins.top + y_inset)).nice(5); // inverted for SVG
110
111        // Generate scatter points
112        let mut point_elements = Vec::new();
113        for i in 0..data.num_rows() {
114            let x_val = data.get_f64(i, &x_field);
115            let y_val = data.get_f64(i, &y_field);
116
117            if let (Some(x), Some(y)) = (x_val, y_val) {
118                let cx = x_scale.map(x);
119                let cy = y_scale.map(y);
120
121                let r = match (&size_field, &size_scale) {
122                    (Some(sf), Some(ss)) => {
123                        // Spec-driven sized case — the `size` field overrides
124                        // `theme.dot_radius` per-point. The inner `unwrap_or(5.0)`
125                        // is the fallback for rows with a missing size value and
126                        // is intentionally unchanged by Phase 5 theme wiring.
127                        data.get_f64(i, sf).map(|v| ss.map(v)).unwrap_or(5.0)
128                    }
129                    _ => config.theme.dot_radius as f64,
130                };
131
132                let color_idx = if let Some(ref cf) = color_field {
133                    data.get_string(i, cf)
134                        .and_then(|v| color_categories.iter().position(|c| c == &v))
135                        .unwrap_or(0)
136                } else {
137                    0
138                };
139                let fill = config.colors.get(color_idx % config.colors.len())
140                    .cloned()
141                    .unwrap_or_else(|| "#2E7D9A".to_string());
142
143                let label = data.get_string(i, &x_field).unwrap_or_default();
144                let value = format!("{}", y);
145                let mut el_data = ElementData::new(label, value);
146                if let Some(ref cf) = color_field {
147                    if let Some(series_name) = data.get_string(i, cf) {
148                        el_data = el_data.with_series(series_name);
149                    }
150                }
151
152                if let Some(halo) = emit_dot_halo_if_enabled(&config.theme, cx, cy, r) {
153                    point_elements.push(halo);
154                }
155                point_elements.push(ChartElement::Circle {
156                    cx,
157                    cy,
158                    r,
159                    fill,
160                    stroke: Some(config.theme.bg.clone()),
161                    class: "chartml-scatter-point dot-marker".to_string(),
162                    data: Some(el_data),
163                });
164            }
165        }
166
167        // Build SVG
168        let mut children = Vec::new();
169
170        // Grid lines + axes
171        let x_ticks = x_scale.ticks(((inner_width / 50.0).floor() as usize).clamp(4, 10));
172        let y_ticks = y_scale.ticks(((inner_height / 50.0).floor() as usize).clamp(4, 10));
173        let mut axis_elements = Vec::new();
174
175        // Compute tick steps for formatting
176        let y_tick_step = compute_tick_step(&y_ticks);
177        let x_tick_step = compute_tick_step(&x_ticks);
178
179        // Y-axis label skip factor: skip labels that would overlap vertically
180        let y_label_height = 18.0; // 14px font + 4px spacing
181        let y_skip = if y_ticks.len() > 1 {
182            let px_per_tick = inner_height / (y_ticks.len() - 1) as f64;
183            (y_label_height / px_per_tick).ceil() as usize
184        } else {
185            1
186        }.max(1);
187
188        // Horizontal grid lines + y-axis ticks
189        let draw_h_grid = should_draw_horizontal_grid(&config.theme.grid_style);
190        let draw_v_grid = should_draw_vertical_grid(&config.theme.grid_style);
191        for (i, &val) in y_ticks.iter().enumerate() {
192            let y = y_scale.map(val);
193            // Grid line — gated on grid_style (horizontal: constant y)
194            if draw_h_grid {
195                axis_elements.push(ChartElement::Line {
196                    x1: margins.left, y1: y, x2: margins.left + inner_width, y2: y,
197                    stroke: config.theme.grid.clone(), stroke_width: Some(config.theme.grid_line_weight as f64),
198                    stroke_dasharray: None, class: "grid-line".to_string(),
199                });
200            }
201            // Tick mark — always rendered
202            axis_elements.push(ChartElement::Line {
203                x1: margins.left - 5.0, y1: y, x2: margins.left, y2: y,
204                stroke: config.theme.tick.clone(), stroke_width: Some(config.theme.axis_line_weight as f64),
205                stroke_dasharray: None, class: "tick".to_string(),
206            });
207            // Label — only if not skipped
208            if i % y_skip == 0 {
209                let label = format_tick_value_si(val, y_tick_step);
210                let ts = TextStyle::for_role(&config.theme, TextRole::TickValue);
211                axis_elements.push(ChartElement::Text {
212                    x: margins.left - 8.0, y,
213                    content: label, anchor: TextAnchor::End,
214                    dominant_baseline: Some("middle".to_string()),
215                    transform: None,
216                    font_family: ts.font_family,
217                    font_size: ts.font_size,
218                    font_weight: ts.font_weight,
219                    letter_spacing: ts.letter_spacing,
220                    text_transform: ts.text_transform,
221                    fill: Some(config.theme.text_secondary.clone()), class: "tick-label tick-value".to_string(), data: None,
222                });
223            }
224        }
225
226        // X-axis label skip factor: skip labels that would overlap horizontally
227        // Scatter historically measured X tick labels at 11px (one px below
228        // the 12px default) — preserve that exactly when the theme is the
229        // legacy default to keep golden snapshots byte-identical. Any theme
230        // override switches to the full theme-aware measurement.
231        let scatter_tick_metrics = TextMetrics::from_theme_tick_value(&config.theme);
232        let x_label_widths: Vec<f64> = x_ticks.iter()
233            .map(|&v| {
234                let label = format_tick_value_si(v, x_tick_step);
235                if scatter_tick_metrics.is_legacy_default() {
236                    approximate_text_width_at(&label, 11.0)
237                } else {
238                    measure_text(&label, &scatter_tick_metrics)
239                }
240            })
241            .collect();
242        let x_widest = x_label_widths.iter().cloned().fold(0.0_f64, f64::max);
243        let x_skip = if x_ticks.len() > 1 {
244            let px_per_tick = inner_width / (x_ticks.len() - 1) as f64;
245            let needed = x_widest + 8.0; // label width + small gap
246            (needed / px_per_tick).ceil() as usize
247        } else {
248            1
249        }.max(1);
250
251        // Vertical grid lines + x-axis ticks
252        let x_axis_y = margins.top + inner_height;
253        for (i, &val) in x_ticks.iter().enumerate() {
254            let x = x_scale.map(val);
255            // Grid line — gated on grid_style (vertical: constant x)
256            if draw_v_grid {
257                axis_elements.push(ChartElement::Line {
258                    x1: x, y1: margins.top, x2: x, y2: x_axis_y,
259                    stroke: config.theme.grid.clone(), stroke_width: Some(config.theme.grid_line_weight as f64),
260                    stroke_dasharray: None, class: "grid-line".to_string(),
261                });
262            }
263            // Tick mark — always rendered
264            axis_elements.push(ChartElement::Line {
265                x1: x, y1: x_axis_y, x2: x, y2: x_axis_y + 5.0,
266                stroke: config.theme.tick.clone(), stroke_width: Some(config.theme.axis_line_weight as f64),
267                stroke_dasharray: None, class: "tick".to_string(),
268            });
269            // Label — only if not skipped
270            if i % x_skip == 0 {
271                let label = format_tick_value_si(val, x_tick_step);
272                let ts = TextStyle::for_role(&config.theme, TextRole::TickValue);
273                axis_elements.push(ChartElement::Text {
274                    x, y: x_axis_y + 18.0,
275                    content: label, anchor: TextAnchor::Middle,
276                    dominant_baseline: None, transform: None,
277                    font_family: ts.font_family,
278                    font_size: ts.font_size,
279                    font_weight: ts.font_weight,
280                    letter_spacing: ts.letter_spacing,
281                    text_transform: ts.text_transform,
282                    fill: Some(config.theme.text_secondary.clone()),
283                    class: "tick-label tick-value".to_string(), data: None,
284                });
285            }
286        }
287
288        // Axis lines
289        axis_elements.push(ChartElement::Line {
290            x1: margins.left, y1: margins.top, x2: margins.left, y2: x_axis_y,
291            stroke: config.theme.axis_line.clone(), stroke_width: Some(config.theme.axis_line_weight as f64),
292            stroke_dasharray: None, class: "axis-line".to_string(),
293        });
294        axis_elements.push(ChartElement::Line {
295            x1: margins.left, y1: x_axis_y, x2: margins.left + inner_width, y2: x_axis_y,
296            stroke: config.theme.axis_line.clone(), stroke_width: Some(config.theme.axis_line_weight as f64),
297            stroke_dasharray: None, class: "axis-line".to_string(),
298        });
299
300        children.push(ChartElement::Group {
301            class: "axes".to_string(),
302            transform: None,
303            children: axis_elements,
304        });
305
306        // Title is rendered as HTML outside the SVG — not added here.
307
308        // Points group
309        children.push(ChartElement::Group {
310            class: "chartml-scatter-points".to_string(),
311            transform: None,
312            children: point_elements,
313        });
314
315        // Legend — reuse color_categories computed earlier
316        if color_categories.len() > 1 {
317            let legend_config = LegendConfig {
318                text_metrics: TextMetrics::from_theme_legend(&config.theme),
319                ..LegendConfig::default()
320            };
321            let legend_layout = calculate_legend_layout(&color_categories, &config.colors, width, &legend_config);
322            let legend_y = height - legend_layout.total_height - 8.0;
323            let legend_elements = generate_legend_elements(
324                &color_categories,
325                &config.colors,
326                width,
327                legend_y,
328                LegendMark::Circle,
329                &config.theme,
330            );
331            children.push(ChartElement::Group {
332                class: "legend".to_string(),
333                transform: None,
334                children: legend_elements,
335            });
336        }
337
338        Ok(ChartElement::Svg {
339            viewbox: ViewBox::new(0.0, 0.0, width, height),
340            width: Some(width),
341            height: Some(height),
342            class: "chartml-chart chartml-scatter-chart".to_string(),
343            children,
344        })
345    }
346
347    fn default_dimensions(&self, _spec: &VisualizeSpec) -> Option<Dimensions> {
348        Some(Dimensions::new(400.0))
349    }
350}
351
352/// Extract the field name from an optional FieldRef.
353fn get_field_name(field_ref: &Option<FieldRef>) -> Result<String, ChartError> {
354    match field_ref {
355        Some(FieldRef::Simple(name)) => Ok(name.clone()),
356        Some(FieldRef::Detailed(spec)) => Ok(spec.field.clone()),
357        Some(FieldRef::Multiple(items)) => {
358            // Use the first item
359            match items.first() {
360                Some(chartml_core::spec::FieldRefItem::Simple(name)) => Ok(name.clone()),
361                Some(chartml_core::spec::FieldRefItem::Detailed(spec)) => Ok(spec.field.clone()),
362                None => Err(ChartError::InvalidSpec("Empty field reference list".into())),
363            }
364        }
365        None => Err(ChartError::InvalidSpec("Missing required field reference".into())),
366    }
367}
368
369/// Extract the color field name from marks.color encoding, if present.
370fn get_color_field(config: &ChartConfig) -> Option<String> {
371    config.visualize.marks.as_ref().and_then(|marks| {
372        marks.color.as_ref().map(|enc| match enc {
373            MarkEncoding::Simple(name) => name.clone(),
374            MarkEncoding::Detailed(spec) => spec.field.clone(),
375        })
376    })
377}
378
379/// Extract the size field name from marks.size encoding, if present.
380fn get_size_field(config: &ChartConfig) -> Option<String> {
381    config.visualize.marks.as_ref().and_then(|marks| {
382        marks.size.as_ref().map(|enc| match enc {
383            MarkEncoding::Simple(name) => name.clone(),
384            MarkEncoding::Detailed(spec) => spec.field.clone(),
385        })
386    })
387}
388
/// Spacing between the first two ticks; the renderer passes it to
/// `format_tick_value_si` alongside each tick value.
///
/// Returns `|ticks[1] - ticks[0]|`, or `1.0` when fewer than two ticks exist.
fn compute_tick_step(ticks: &[f64]) -> f64 {
    match ticks {
        [first, second, ..] => (second - first).abs(),
        _ => 1.0,
    }
}
397
#[cfg(test)]
mod tests {
    //! Unit tests for `ScatterRenderer`:
    //! - point and legend counts;
    //! - series metadata;
    //! - error handling;
    //! - `theme.grid_style` gating (Phase 6);
    //! - `dot_halo` wiring (Phase 8).

    use super::*;
    use std::collections::HashMap;
    use chartml_core::data::Row;
    use chartml_core::spec::{VisualizeSpec, MarksSpec, MarkEncoding};

    /// Build a data row from `(column, value)` pairs.
    fn make_row(pairs: &[(&str, serde_json::Value)]) -> Row {
        let mut map = HashMap::new();
        for (k, v) in pairs {
            map.insert(k.to_string(), v.clone());
        }
        map
    }

    /// Four points in two categories ("A", "B").
    fn make_scatter_data() -> DataTable {
        let rows = vec![
            make_row(&[("price", serde_json::json!(10.0)), ("units", serde_json::json!(100.0)), ("category", serde_json::json!("A"))]),
            make_row(&[("price", serde_json::json!(20.0)), ("units", serde_json::json!(200.0)), ("category", serde_json::json!("B"))]),
            make_row(&[("price", serde_json::json!(30.0)), ("units", serde_json::json!(150.0)), ("category", serde_json::json!("A"))]),
            make_row(&[("price", serde_json::json!(40.0)), ("units", serde_json::json!(300.0)), ("category", serde_json::json!("B"))]),
        ];
        DataTable::from_rows(&rows).unwrap()
    }

    /// price → x, units → y, colour by category (so a legend is drawn).
    fn make_scatter_config() -> ChartConfig {
        ChartConfig {
            visualize: VisualizeSpec {
                chart_type: "scatter".to_string(),
                mode: None,
                orientation: None,
                columns: Some(FieldRef::Simple("price".to_string())),
                rows: Some(FieldRef::Simple("units".to_string())),
                marks: Some(MarksSpec {
                    color: Some(MarkEncoding::Simple("category".to_string())),
                    size: None,
                    shape: None,
                    text: None,
                }),
                axes: None,
                annotations: None,
                style: None,
                value: None,
                label: None,
                format: None,
                compare_with: None,
                invert_trend: None,
                data_labels: None,
            },
            title: Some("Scatter Test".to_string()),
            width: 800.0,
            height: 400.0,
            colors: vec![
                "#2E7D9A".to_string(),
                "#E8533E".to_string(),
                "#4CAF50".to_string(),
            ],
            theme: chartml_core::theme::Theme::default(),
        }
    }

    /// Three points with distinct `size` values (100, 400, 200).
    fn make_bubble_data() -> DataTable {
        let rows = vec![
            make_row(&[("x", serde_json::json!(5.0)), ("y", serde_json::json!(10.0)), ("size", serde_json::json!(100.0))]),
            make_row(&[("x", serde_json::json!(15.0)), ("y", serde_json::json!(20.0)), ("size", serde_json::json!(400.0))]),
            make_row(&[("x", serde_json::json!(25.0)), ("y", serde_json::json!(15.0)), ("size", serde_json::json!(200.0))]),
        ];
        DataTable::from_rows(&rows).unwrap()
    }

    /// x → x, y → y, size → radius; no colour encoding (so no legend).
    fn make_bubble_config() -> ChartConfig {
        ChartConfig {
            visualize: VisualizeSpec {
                chart_type: "scatter".to_string(),
                mode: None,
                orientation: None,
                columns: Some(FieldRef::Simple("x".to_string())),
                rows: Some(FieldRef::Simple("y".to_string())),
                marks: Some(MarksSpec {
                    color: None,
                    size: Some(MarkEncoding::Simple("size".to_string())),
                    shape: None,
                    text: None,
                }),
                axes: None,
                annotations: None,
                style: None,
                value: None,
                label: None,
                format: None,
                compare_with: None,
                invert_trend: None,
                data_labels: None,
            },
            title: None,
            width: 600.0,
            height: 400.0,
            colors: vec!["#2E7D9A".to_string()],
            theme: chartml_core::theme::Theme::default(),
        }
    }

    #[test]
    fn scatter_chart_renders() {
        let renderer = ScatterRenderer::new();
        let result = renderer.render(&make_scatter_data(), &make_scatter_config());
        assert!(result.is_ok(), "render failed: {:?}", result.err());
        let element = result.unwrap();
        let circle_count = count_elements(&element, &|e| matches!(e, ChartElement::Circle { .. }));
        assert_eq!(circle_count, 6); // 4 data points + 2 legend circles (categories A, B)
    }

    #[test]
    fn scatter_with_size_encoding() {
        let renderer = ScatterRenderer::new();
        let result = renderer.render(&make_bubble_data(), &make_bubble_config());
        assert!(result.is_ok(), "render failed: {:?}", result.err());
        let element = result.unwrap();
        let circle_count = count_elements(&element, &|e| matches!(e, ChartElement::Circle { .. }));
        // 3 data points; no colour encoding, so no legend circles.
        assert_eq!(circle_count, 3);
    }

    #[test]
    fn scatter_data_series_populated_with_color_encoding() {
        let renderer = ScatterRenderer::new();
        let element = renderer.render(&make_scatter_data(), &make_scatter_config()).unwrap();
        // Collect series values from data circles (not legend circles)
        let mut series_values = Vec::new();
        fn collect_series(el: &ChartElement, out: &mut Vec<Option<String>>) {
            match el {
                ChartElement::Circle { data: Some(d), class, .. } if !class.contains("legend") => {
                    out.push(d.series.clone());
                }
                ChartElement::Svg { children, .. } | ChartElement::Group { children, .. } => {
                    for child in children { collect_series(child, out); }
                }
                _ => {}
            }
        }
        collect_series(&element, &mut series_values);
        assert_eq!(series_values.len(), 4, "Expected 4 data circles");
        for (i, series) in series_values.iter().enumerate() {
            assert!(series.is_some(), "Circle {} has null data.series", i);
        }
        // Verify actual category values are present
        let series_strs: Vec<&str> = series_values.iter().map(|s| s.as_deref().unwrap()).collect();
        assert!(series_strs.contains(&"A"));
        assert!(series_strs.contains(&"B"));
    }

    #[test]
    fn scatter_empty_data_errors() {
        let renderer = ScatterRenderer::new();
        let data = DataTable::from_rows(&Vec::<Row>::new()).unwrap();
        let result = renderer.render(&data, &make_scatter_config());
        assert!(result.is_err());
    }

    // ----- Phase 6: theme.grid_style gating -----

    /// Endpoints `(x1, y1, x2, y2)` of every Line whose class list contains
    /// `grid-line`, in traversal order.
    fn collect_scatter_grid_lines(el: &ChartElement) -> Vec<(f64, f64, f64, f64)> {
        fn visit(el: &ChartElement, out: &mut Vec<(f64, f64, f64, f64)>) {
            match el {
                ChartElement::Line { class, x1, y1, x2, y2, .. } => {
                    if class.split_whitespace().any(|c| c == "grid-line") {
                        out.push((*x1, *y1, *x2, *y2));
                    }
                }
                ChartElement::Svg { children, .. }
                | ChartElement::Group { children, .. } => {
                    for c in children {
                        visit(c, out);
                    }
                }
                _ => {}
            }
        }
        let mut out = Vec::new();
        visit(el, &mut out);
        out
    }

    fn count_scatter_grid_lines(el: &ChartElement) -> usize {
        collect_scatter_grid_lines(el).len()
    }

    // ----- Phase 8: dot_halo wiring -----

    /// Classes of every Circle and dot-halo Path, in flat traversal order.
    /// Used by Phase 8 tests to assert halo-before-dot ordering and counts.
    fn collect_dot_and_halo_order(el: &ChartElement) -> Vec<String> {
        fn visit(el: &ChartElement, out: &mut Vec<String>) {
            match el {
                ChartElement::Circle { class, .. } => out.push(class.clone()),
                ChartElement::Path { class, .. } if class == "dot-halo" => out.push(class.clone()),
                ChartElement::Svg { children, .. }
                | ChartElement::Group { children, .. } => {
                    for c in children {
                        visit(c, out);
                    }
                }
                _ => {}
            }
        }
        let mut out = Vec::new();
        visit(el, &mut out);
        out
    }

    fn count_halos(el: &ChartElement) -> usize {
        count_elements(el, &|e| matches!(e, ChartElement::Path { class, .. } if class == "dot-halo"))
    }

    #[test]
    fn phase8_scatter_default_theme_emits_no_halo() {
        use chartml_core::theme::Theme;
        let renderer = ScatterRenderer::new();
        let mut config = make_scatter_config();
        config.theme = Theme::default();
        let element = renderer.render(&make_scatter_data(), &config).unwrap();
        assert_eq!(count_halos(&element), 0, "default theme must emit zero dot-halo elements");
    }

    #[test]
    fn phase8_scatter_halo_color_emits_one_halo_per_point() {
        use chartml_core::theme::Theme;
        let renderer = ScatterRenderer::new();
        let mut config = make_scatter_config();
        let mut t = Theme::default();
        t.dot_halo_color = Some("#ffffff".to_string());
        t.dot_halo_width = 1.5;
        config.theme = t;
        let element = renderer.render(&make_scatter_data(), &config).unwrap();

        // 4 data points → 4 halos (legend circles don't get halos).
        assert_eq!(count_halos(&element), 4);

        // Data-point circles should also still number 4.
        let data_circles = count_elements(&element, &|e| {
            matches!(e, ChartElement::Circle { class, .. } if class.contains("chartml-scatter-point"))
        });
        assert_eq!(data_circles, 4);

        // Verify halo stroke + width on at least one emitted halo.
        fn find_halo(el: &ChartElement) -> Option<(String, f64)> {
            match el {
                ChartElement::Path { class, stroke, stroke_width, .. } if class == "dot-halo" => {
                    Some((stroke.clone().unwrap_or_default(), stroke_width.unwrap_or(-1.0)))
                }
                ChartElement::Svg { children, .. } | ChartElement::Group { children, .. } => {
                    children.iter().find_map(find_halo)
                }
                _ => None,
            }
        }
        let (stroke, width) = find_halo(&element).expect("at least one halo");
        assert_eq!(stroke, "#ffffff");
        assert!((width - 1.5).abs() < 1e-9, "halo stroke-width {} != 1.5", width);
    }

    #[test]
    fn phase8_scatter_halo_precedes_dot_in_order() {
        use chartml_core::theme::Theme;
        let renderer = ScatterRenderer::new();
        let mut config = make_scatter_config();
        let mut t = Theme::default();
        t.dot_halo_color = Some("#ffffff".to_string());
        t.dot_halo_width = 1.5;
        config.theme = t;
        let element = renderer.render(&make_scatter_data(), &config).unwrap();

        // Walk the points group and assert every dot-halo is immediately
        // followed by a dot-marker circle.
        fn find_points_group(el: &ChartElement) -> Option<&Vec<ChartElement>> {
            match el {
                ChartElement::Group { class, children, .. }
                    if class == "chartml-scatter-points" => Some(children),
                ChartElement::Svg { children, .. } | ChartElement::Group { children, .. } => {
                    children.iter().find_map(find_points_group)
                }
                _ => None,
            }
        }
        let points = find_points_group(&element).expect("points group");
        let mut pair_count = 0;
        let mut iter = points.iter().peekable();
        while let Some(el) = iter.next() {
            if let ChartElement::Path { class, .. } = el {
                if class == "dot-halo" {
                    let next = iter.peek().expect("halo must be followed by dot");
                    match next {
                        ChartElement::Circle { class, .. } => {
                            assert!(class.contains("dot-marker"));
                            pair_count += 1;
                        }
                        _ => panic!("halo not immediately followed by a Circle"),
                    }
                }
            }
        }
        assert_eq!(pair_count, 4);
    }

    #[test]
    fn phase8_bubble_halo_radius_tracks_per_point_size() {
        // For a bubble chart each point has a distinct radius from the size
        // scale. The halo's path must encode that same radius, not a static
        // theme.dot_radius default.
        use chartml_core::theme::Theme;
        let renderer = ScatterRenderer::new();
        let mut config = make_bubble_config();
        let mut t = Theme::default();
        t.dot_halo_color = Some("#000000".to_string());
        t.dot_halo_width = 1.0;
        config.theme = t;
        let element = renderer.render(&make_bubble_data(), &config).unwrap();

        // Collect halo path d strings and dot radii in traversal order.
        fn collect(el: &ChartElement, halos: &mut Vec<String>, dots: &mut Vec<f64>) {
            match el {
                ChartElement::Path { class, d, .. } if class == "dot-halo" => halos.push(d.clone()),
                ChartElement::Circle { class, r, .. } if class.contains("chartml-scatter-point") => {
                    dots.push(*r);
                }
                ChartElement::Svg { children, .. } | ChartElement::Group { children, .. } => {
                    for c in children { collect(c, halos, dots); }
                }
                _ => {}
            }
        }
        let mut halos = Vec::new();
        let mut dots = Vec::new();
        collect(&element, &mut halos, &mut dots);
        assert_eq!(halos.len(), 3);
        assert_eq!(dots.len(), 3);
        // Distinct radii → distinct halo d strings.
        let unique_d: std::collections::HashSet<&String> = halos.iter().collect();
        assert_eq!(unique_d.len(), 3, "bubble halos should have 3 distinct path d values, one per radius");
        // Each halo d must contain the per-point radius.
        for (d, r) in halos.iter().zip(dots.iter()) {
            assert!(
                d.contains(&format!("{},{}", r, r)) || d.contains(&format!(" {},", r)),
                "halo d {:?} should encode per-point radius {}",
                d,
                r
            );
        }
    }

    #[test]
    fn phase8_scatter_traversal_order_sanity() {
        // With the halo enabled, the flat Circle/halo sequence must alternate
        // halo, dot for each of the 4 data points. The axes group contains
        // no circles, and legend circles come after the points group.
        use chartml_core::theme::Theme;
        let renderer = ScatterRenderer::new();
        let mut config = make_scatter_config();
        let mut t = Theme::default();
        t.dot_halo_color = Some("#ffffff".to_string());
        t.dot_halo_width = 1.0;
        config.theme = t;
        let element = renderer.render(&make_scatter_data(), &config).unwrap();
        let order = collect_dot_and_halo_order(&element);
        assert!(order.len() >= 8, "expected at least 4 halo/dot pairs, got {:?}", order);
        for (i, class) in order.iter().take(8).enumerate() {
            if i % 2 == 0 {
                assert_eq!(class.as_str(), "dot-halo", "entry {} should be a dot-halo: {:?}", i, order);
            } else {
                assert!(
                    class.contains("chartml-scatter-point"),
                    "entry {} should be a scatter point: {:?}",
                    i,
                    order
                );
            }
        }
    }

    #[test]
    fn phase6_scatter_grid_style_none_skips_all_gridlines() {
        use chartml_core::theme::{GridStyle, Theme};
        let renderer = ScatterRenderer::new();
        let data = make_scatter_data();
        let mut config = make_scatter_config();

        // Sanity: Both (default) produces > 0 gridlines.
        let baseline = renderer.render(&data, &config).unwrap();
        let baseline_count = count_scatter_grid_lines(&baseline);
        assert!(
            baseline_count > 0,
            "scatter default (GridStyle::Both) should emit gridlines, got {}",
            baseline_count
        );

        // None: zero gridlines of either orientation.
        let mut t = Theme::default();
        t.grid_style = GridStyle::None;
        config.theme = t;
        let element = renderer.render(&data, &config).unwrap();
        let n = count_scatter_grid_lines(&element);
        assert_eq!(n, 0, "GridStyle::None: expected 0 scatter gridlines, got {}", n);
    }

    #[test]
    fn phase6_scatter_grid_style_horizontal_only_emits_only_horizontal_gridlines() {
        use chartml_core::theme::{GridStyle, Theme};
        let renderer = ScatterRenderer::new();
        let mut config = make_scatter_config();
        let mut t = Theme::default();
        t.grid_style = GridStyle::HorizontalOnly;
        config.theme = t;
        let element = renderer.render(&make_scatter_data(), &config).unwrap();
        let lines = collect_scatter_grid_lines(&element);
        assert!(!lines.is_empty(), "HorizontalOnly should still emit horizontal gridlines");
        for (x1, y1, x2, y2) in lines {
            assert_eq!(
                y1, y2,
                "HorizontalOnly emitted a non-horizontal gridline ({}, {}) -> ({}, {})",
                x1, y1, x2, y2
            );
        }
    }

    #[test]
    fn phase6_scatter_grid_style_vertical_only_emits_only_vertical_gridlines() {
        use chartml_core::theme::{GridStyle, Theme};
        let renderer = ScatterRenderer::new();
        let mut config = make_scatter_config();
        let mut t = Theme::default();
        t.grid_style = GridStyle::VerticalOnly;
        config.theme = t;
        let element = renderer.render(&make_scatter_data(), &config).unwrap();
        let lines = collect_scatter_grid_lines(&element);
        assert!(!lines.is_empty(), "VerticalOnly should still emit vertical gridlines");
        for (x1, y1, x2, y2) in lines {
            assert_eq!(
                x1, x2,
                "VerticalOnly emitted a non-vertical gridline ({}, {}) -> ({}, {})",
                x1, y1, x2, y2
            );
        }
    }
}