1use chartml_core::plugin::{ChartRenderer, ChartConfig};
2use chartml_core::data::DataTable;
3use chartml_core::element::*;
4use chartml_core::error::ChartError;
5use chartml_core::scales::{ScaleLinear, ScaleSqrt};
6use chartml_core::spec::{VisualizeSpec, FieldRef, MarkEncoding};
7use chartml_core::layout::margins::Margins;
8use chartml_core::layout::legend::{LegendMark, LegendConfig, calculate_legend_layout, generate_legend_elements};
9
/// Stateless renderer that draws scatter (and bubble) charts.
///
/// The type carries no data, so it is free to copy and every instance is
/// interchangeable with every other.
#[derive(Debug, Clone, Copy, Default)]
pub struct ScatterRenderer;

impl ScatterRenderer {
    /// Creates a new renderer; equivalent to [`ScatterRenderer::default`].
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
23
24impl ChartRenderer for ScatterRenderer {
25 fn render(&self, data: &DataTable, config: &ChartConfig) -> Result<ChartElement, ChartError> {
26 let x_field = get_field_name(&config.visualize.columns)?;
27 let y_field = get_field_name(&config.visualize.rows)?;
28 let color_field = get_color_field(config);
29 let size_field = get_size_field(config);
30
31 let width = config.width;
32 let height = config.height;
33
34 let has_legend = color_field.is_some();
35 let margins = if has_legend {
36 Margins::new(30.0, 20.0, 70.0, 60.0)
38 } else {
39 Margins::default()
40 };
41 let inner_width = margins.inner_width(width);
42 let inner_height = margins.inner_height(height);
43
44 let x_extent = data.extent(&x_field)
46 .ok_or_else(|| ChartError::DataError(format!("No numeric data for field '{}'", x_field)))?;
47 let y_extent = data.extent(&y_field)
48 .ok_or_else(|| ChartError::DataError(format!("No numeric data for field '{}'", y_field)))?;
49
50 let x_domain = (x_extent.0, x_extent.1);
53 let y_domain = (y_extent.0, y_extent.1);
54 let x_scale = ScaleLinear::new(x_domain, (margins.left, margins.left + inner_width)).nice(5);
55 let y_scale = ScaleLinear::new(y_domain, (margins.top + inner_height, margins.top)).nice(5); let size_scale = size_field.as_ref().and_then(|f| {
59 data.extent(f).map(|ext| ScaleSqrt::new(ext, (3.0, 20.0))) });
61
62 let color_categories: Vec<String> = if let Some(ref cf) = color_field {
64 data.unique_values(cf)
65 } else {
66 vec![]
67 };
68
69 let mut point_elements = Vec::new();
71 for i in 0..data.num_rows() {
72 let x_val = data.get_f64(i, &x_field);
73 let y_val = data.get_f64(i, &y_field);
74
75 if let (Some(x), Some(y)) = (x_val, y_val) {
76 let cx = x_scale.map(x);
77 let cy = y_scale.map(y);
78
79 let r = match (&size_field, &size_scale) {
80 (Some(sf), Some(ss)) => {
81 data.get_f64(i, sf).map(|v| ss.map(v)).unwrap_or(5.0)
82 }
83 _ => 5.0,
84 };
85
86 let color_idx = if let Some(ref cf) = color_field {
87 data.get_string(i, cf)
88 .and_then(|v| color_categories.iter().position(|c| c == &v))
89 .unwrap_or(0)
90 } else {
91 0
92 };
93 let fill = config.colors.get(color_idx % config.colors.len())
94 .cloned()
95 .unwrap_or_else(|| "#2E7D9A".to_string());
96
97 let label = data.get_string(i, &x_field).unwrap_or_default();
98 let value = format!("{}", y);
99 let el_data = ElementData::new(label, value);
100
101 point_elements.push(ChartElement::Circle {
102 cx,
103 cy,
104 r,
105 fill,
106 stroke: Some("#fff".to_string()),
107 class: "chartml-scatter-point".to_string(),
108 data: Some(el_data),
109 });
110 }
111 }
112
113 let mut children = Vec::new();
115
116 let x_ticks = x_scale.ticks(((inner_width / 50.0).floor() as usize).clamp(4, 10));
118 let y_ticks = y_scale.ticks(((inner_height / 50.0).floor() as usize).clamp(4, 10));
119 let mut axis_elements = Vec::new();
120
121 let y_tick_step = compute_tick_step(&y_ticks);
123 let x_tick_step = compute_tick_step(&x_ticks);
124
125 for &val in &y_ticks {
127 let y = y_scale.map(val);
128 axis_elements.push(ChartElement::Line {
130 x1: margins.left, y1: y, x2: margins.left + inner_width, y2: y,
131 stroke: "#e0e0e0".to_string(), stroke_width: Some(1.0),
132 stroke_dasharray: None, class: "grid-line".to_string(),
133 });
134 axis_elements.push(ChartElement::Line {
136 x1: margins.left - 5.0, y1: y, x2: margins.left, y2: y,
137 stroke: "#999".to_string(), stroke_width: Some(1.0),
138 stroke_dasharray: None, class: "tick".to_string(),
139 });
140 let label = format_tick_value(val, y_tick_step);
142 axis_elements.push(ChartElement::Text {
143 x: margins.left - 8.0, y,
144 content: label, anchor: TextAnchor::End,
145 dominant_baseline: Some("middle".to_string()),
146 transform: None, font_size: Some("11px".to_string()),
147 font_weight: None,
148 fill: Some("#666".to_string()), class: "tick-label".to_string(), data: None,
149 });
150 }
151
152 let x_axis_y = margins.top + inner_height;
154 for &val in &x_ticks {
155 let x = x_scale.map(val);
156 axis_elements.push(ChartElement::Line {
158 x1: x, y1: margins.top, x2: x, y2: x_axis_y,
159 stroke: "#e0e0e0".to_string(), stroke_width: Some(1.0),
160 stroke_dasharray: None, class: "grid-line".to_string(),
161 });
162 axis_elements.push(ChartElement::Line {
164 x1: x, y1: x_axis_y, x2: x, y2: x_axis_y + 5.0,
165 stroke: "#999".to_string(), stroke_width: Some(1.0),
166 stroke_dasharray: None, class: "tick".to_string(),
167 });
168 let label = format_tick_value(val, x_tick_step);
170 axis_elements.push(ChartElement::Text {
171 x, y: x_axis_y + 18.0,
172 content: label, anchor: TextAnchor::Middle,
173 dominant_baseline: None, transform: None,
174 font_size: Some("11px".to_string()), font_weight: None,
175 fill: Some("#666".to_string()),
176 class: "tick-label".to_string(), data: None,
177 });
178 }
179
180 axis_elements.push(ChartElement::Line {
182 x1: margins.left, y1: margins.top, x2: margins.left, y2: x_axis_y,
183 stroke: "#ccc".to_string(), stroke_width: Some(1.0),
184 stroke_dasharray: None, class: "axis-line".to_string(),
185 });
186 axis_elements.push(ChartElement::Line {
187 x1: margins.left, y1: x_axis_y, x2: margins.left + inner_width, y2: x_axis_y,
188 stroke: "#ccc".to_string(), stroke_width: Some(1.0),
189 stroke_dasharray: None, class: "axis-line".to_string(),
190 });
191
192 children.push(ChartElement::Group {
193 class: "axes".to_string(),
194 transform: None,
195 children: axis_elements,
196 });
197
198 children.push(ChartElement::Group {
202 class: "chartml-scatter-points".to_string(),
203 transform: None,
204 children: point_elements,
205 });
206
207 if let Some(ref cf) = color_field {
209 let series_names = data.unique_values(cf);
210 if series_names.len() > 1 {
211 let legend_config = LegendConfig::default();
212 let legend_layout = calculate_legend_layout(&series_names, &config.colors, width, &legend_config);
213 let legend_y = height - legend_layout.total_height - 8.0;
214 let legend_elements = generate_legend_elements(
215 &series_names,
216 &config.colors,
217 width,
218 legend_y,
219 LegendMark::Circle,
220 );
221 children.push(ChartElement::Group {
222 class: "legend".to_string(),
223 transform: None,
224 children: legend_elements,
225 });
226 }
227 }
228
229 Ok(ChartElement::Svg {
230 viewbox: ViewBox::new(0.0, 0.0, width, height),
231 width: Some(width),
232 height: Some(height),
233 class: "chartml-chart chartml-scatter-chart".to_string(),
234 children,
235 })
236 }
237
238 fn default_dimensions(&self, _spec: &VisualizeSpec) -> Option<Dimensions> {
239 Some(Dimensions::new(400.0))
240 }
241}
242
243fn get_field_name(field_ref: &Option<FieldRef>) -> Result<String, ChartError> {
245 match field_ref {
246 Some(FieldRef::Simple(name)) => Ok(name.clone()),
247 Some(FieldRef::Detailed(spec)) => Ok(spec.field.clone()),
248 Some(FieldRef::Multiple(items)) => {
249 match items.first() {
251 Some(chartml_core::spec::FieldRefItem::Simple(name)) => Ok(name.clone()),
252 Some(chartml_core::spec::FieldRefItem::Detailed(spec)) => Ok(spec.field.clone()),
253 None => Err(ChartError::InvalidSpec("Empty field reference list".into())),
254 }
255 }
256 None => Err(ChartError::InvalidSpec("Missing required field reference".into())),
257 }
258}
259
260fn get_color_field(config: &ChartConfig) -> Option<String> {
262 config.visualize.marks.as_ref().and_then(|marks| {
263 marks.color.as_ref().map(|enc| match enc {
264 MarkEncoding::Simple(name) => name.clone(),
265 MarkEncoding::Detailed(spec) => spec.field.clone(),
266 })
267 })
268}
269
270fn get_size_field(config: &ChartConfig) -> Option<String> {
272 config.visualize.marks.as_ref().and_then(|marks| {
273 marks.size.as_ref().map(|enc| match enc {
274 MarkEncoding::Simple(name) => name.clone(),
275 MarkEncoding::Detailed(spec) => spec.field.clone(),
276 })
277 })
278}
279
/// Returns the absolute distance between the first two ticks, or `1.0` when
/// fewer than two ticks are given.
fn compute_tick_step(ticks: &[f64]) -> f64 {
    match ticks {
        [first, second, ..] => (second - first).abs(),
        _ => 1.0,
    }
}
288
/// Formats an axis tick `value` for display.
///
/// The number of decimals is derived from `tick_step`: a step of `0.5` gives
/// one decimal, `0.05` gives two, and steps of `1` or more (or steps smaller
/// than `1e-15`) give none. The integer part is grouped with thousands
/// separators, e.g. `1,234,567`.
///
/// A value that rounds to zero at the chosen precision is printed without a
/// sign, so floating-point noise such as `-1e-17` reads `0` rather than `-0`.
fn format_tick_value(value: f64, tick_step: f64) -> String {
    let precision = if tick_step.abs() < 1e-15 {
        0usize
    } else {
        // Negated order of magnitude of the step, e.g. 0.05 -> 2 decimals.
        let decimals = (-tick_step.abs().log10().floor()) as i64;
        decimals.max(0) as usize
    };

    let formatted = format!("{:.prec$}", value, prec = precision);

    // Split into the integer part and the (possibly empty) ".ddd" suffix.
    let (int_part, dec_part) = match formatted.find('.') {
        Some(dot_pos) => formatted.split_at(dot_pos),
        None => (formatted.as_str(), ""),
    };

    let (negative, digits) = match int_part.strip_prefix('-') {
        Some(stripped) => (true, stripped),
        None => (false, int_part),
    };

    // Keep the minus sign only when some rendered digit is non-zero; this
    // turns "-0" / "-0.00" into "0" / "0.00" while leaving "-inf" untouched.
    let rounds_to_zero = digits.bytes().all(|b| b == b'0')
        && dec_part.bytes().all(|b| b == b'0' || b == b'.');
    let sign = if negative && !rounds_to_zero { "-" } else { "" };

    format!("{}{}{}", sign, insert_commas(digits), dec_part)
}

/// Inserts a comma between every group of three digits, counting from the
/// right: `"1234567"` becomes `"1,234,567"`. Inputs of three characters or
/// fewer are returned unchanged.
fn insert_commas(digits: &str) -> String {
    let len = digits.len();
    if len <= 3 {
        return digits.to_string();
    }

    // Length of the leading, possibly shorter, group; a separator belongs
    // before every later index that sits on the same residue modulo 3.
    let lead = len % 3;
    let mut result = String::with_capacity(len + len / 3);
    for (i, ch) in digits.chars().enumerate() {
        if i > 0 && i % 3 == lead {
            result.push(',');
        }
        result.push(ch);
    }
    result
}
341
#[cfg(test)]
mod tests {
    use super::*;
    use chartml_core::data::Row;
    use chartml_core::spec::{MarkEncoding, MarksSpec, VisualizeSpec};
    use serde_json::json;
    use std::collections::HashMap;

    /// Builds one table row from `(column, value)` pairs.
    fn make_row(pairs: &[(&str, serde_json::Value)]) -> Row {
        pairs
            .iter()
            .map(|(key, value)| (key.to_string(), value.clone()))
            .collect::<HashMap<_, _>>()
    }

    /// Builds a scatter `VisualizeSpec` with the given x/y fields and marks;
    /// every other option is left unset.
    fn scatter_spec(x_field: &str, y_field: &str, marks: MarksSpec) -> VisualizeSpec {
        VisualizeSpec {
            chart_type: "scatter".to_string(),
            mode: None,
            orientation: None,
            columns: Some(FieldRef::Simple(x_field.to_string())),
            rows: Some(FieldRef::Simple(y_field.to_string())),
            marks: Some(marks),
            axes: None,
            annotations: None,
            style: None,
            value: None,
            label: None,
            format: None,
            compare_with: None,
            invert_trend: None,
            data_labels: None,
        }
    }

    /// Renders with a fresh `ScatterRenderer`, failing the test on error.
    fn render_ok(data: &DataTable, config: &ChartConfig) -> ChartElement {
        match ScatterRenderer::new().render(data, config) {
            Ok(element) => element,
            Err(err) => panic!("render failed: {:?}", Some(err)),
        }
    }

    /// Four price/units points split across categories "A" and "B".
    fn make_scatter_data() -> DataTable {
        let rows: Vec<Row> = [
            (10.0, 100.0, "A"),
            (20.0, 200.0, "B"),
            (30.0, 150.0, "A"),
            (40.0, 300.0, "B"),
        ]
        .iter()
        .map(|&(price, units, category)| {
            make_row(&[
                ("price", json!(price)),
                ("units", json!(units)),
                ("category", json!(category)),
            ])
        })
        .collect();
        DataTable::from_rows(&rows).unwrap()
    }

    /// Color-encoded scatter config with a three-color palette.
    fn make_scatter_config() -> ChartConfig {
        let marks = MarksSpec {
            color: Some(MarkEncoding::Simple("category".to_string())),
            size: None,
            shape: None,
            text: None,
        };
        ChartConfig {
            visualize: scatter_spec("price", "units", marks),
            title: Some("Scatter Test".to_string()),
            width: 800.0,
            height: 400.0,
            colors: vec![
                "#2E7D9A".to_string(),
                "#E8533E".to_string(),
                "#4CAF50".to_string(),
            ],
        }
    }

    /// Three x/y points, each with a "size" value.
    fn make_bubble_data() -> DataTable {
        let rows: Vec<Row> = [(5.0, 10.0, 100.0), (15.0, 20.0, 400.0), (25.0, 15.0, 200.0)]
            .iter()
            .map(|&(x, y, size)| {
                make_row(&[("x", json!(x)), ("y", json!(y)), ("size", json!(size))])
            })
            .collect();
        DataTable::from_rows(&rows).unwrap()
    }

    /// Size-encoded (bubble) config with a single-color palette.
    fn make_bubble_config() -> ChartConfig {
        let marks = MarksSpec {
            color: None,
            size: Some(MarkEncoding::Simple("size".to_string())),
            shape: None,
            text: None,
        };
        ChartConfig {
            visualize: scatter_spec("x", "y", marks),
            title: None,
            width: 600.0,
            height: 400.0,
            colors: vec!["#2E7D9A".to_string()],
        }
    }

    #[test]
    fn scatter_chart_renders() {
        let element = render_ok(&make_scatter_data(), &make_scatter_config());
        let circle_count = count_elements(&element, &|e| matches!(e, ChartElement::Circle { .. }));
        // Four data points; the remaining two circles are presumably the
        // legend marks for categories "A" and "B".
        assert_eq!(circle_count, 6);
    }

    #[test]
    fn scatter_with_size_encoding() {
        let element = render_ok(&make_bubble_data(), &make_bubble_config());
        let circle_count = count_elements(&element, &|e| matches!(e, ChartElement::Circle { .. }));
        assert!(circle_count > 0);
    }

    #[test]
    fn scatter_empty_data_errors() {
        let empty = DataTable::from_rows(&Vec::<Row>::new()).unwrap();
        let outcome = ScatterRenderer::new().render(&empty, &make_scatter_config());
        assert!(outcome.is_err());
    }
}