use chartml_core::plugin::{ChartRenderer, ChartConfig};
use chartml_core::data::DataTable;
use chartml_core::element::*;
use chartml_core::error::ChartError;
use chartml_core::scales::{ScaleLinear, ScaleSqrt};
use chartml_core::spec::{VisualizeSpec, FieldRef, MarkEncoding};
use chartml_core::layout::margins::Margins;
use chartml_core::layout::legend::{LegendMark, LegendConfig, calculate_legend_layout, generate_legend_elements};
/// Renderer for scatter charts, and bubble charts when a size encoding is present.
///
/// The renderer holds no state; everything it draws comes from the data table and
/// chart configuration passed to `render`.
#[derive(Debug, Clone, Copy, Default)]
pub struct ScatterRenderer;

impl ScatterRenderer {
    /// Creates a scatter renderer (equivalent to `ScatterRenderer::default()`).
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
impl ChartRenderer for ScatterRenderer {
    /// Renders a scatter chart (a bubble chart when `marks.size` is set) as an SVG tree.
    ///
    /// * x positions come from `visualize.columns` and y positions from `visualize.rows`,
    ///   each through a linear scale niced to round bounds.
    /// * `marks.color` (optional) picks each point's palette colour by category; a
    ///   legend is drawn along the bottom when there are at least two categories.
    /// * `marks.size` (optional) maps a numeric column to the circle radius through a
    ///   square-root scale over 3–20 px; otherwise every point has a 5 px radius.
    ///
    /// Rows for which `get_f64` yields no x or no y value are skipped. The returned
    /// `Svg` contains an `axes` group, a `chartml-scatter-points` group and, when drawn,
    /// a `legend` group.
    ///
    /// # Errors
    ///
    /// * `ChartError::InvalidSpec` when `columns` or `rows` is missing or an empty list.
    /// * `ChartError::DataError` when the x or y column has no numeric extent (this is
    ///   the case for an empty table).
    fn render(&self, data: &DataTable, config: &ChartConfig) -> Result<ChartElement, ChartError> {
        let x_field = get_field_name(&config.visualize.columns)?;
        let y_field = get_field_name(&config.visualize.rows)?;
        let color_field = get_color_field(config);
        let size_field = get_size_field(config);
        let width = config.width;
        let height = config.height;

        // Distinct colour categories, computed once and shared by point colouring and
        // the legend.
        let color_categories: Vec<String> = color_field
            .as_ref()
            .map(|cf| data.unique_values(cf))
            .unwrap_or_default();
        // The legend is only drawn for two or more categories, so only then reserve
        // the extra room for it.
        let has_legend = color_categories.len() > 1;
        let margins = if has_legend {
            Margins::new(30.0, 20.0, 70.0, 60.0)
        } else {
            Margins::default()
        };
        let inner_width = margins.inner_width(width);
        let inner_height = margins.inner_height(height);

        let x_extent = data.extent(&x_field)
            .ok_or_else(|| ChartError::DataError(format!("No numeric data for field '{}'", x_field)))?;
        let y_extent = data.extent(&y_field)
            .ok_or_else(|| ChartError::DataError(format!("No numeric data for field '{}'", y_field)))?;
        let x_domain = widen_degenerate_domain(x_extent.0, x_extent.1);
        let y_domain = widen_degenerate_domain(y_extent.0, y_extent.1);
        let x_scale = ScaleLinear::new(x_domain, (margins.left, margins.left + inner_width)).nice(5);
        // The y range is inverted so larger values are drawn higher up.
        let y_scale = ScaleLinear::new(y_domain, (margins.top + inner_height, margins.top)).nice(5);
        let size_scale = size_field
            .as_ref()
            .and_then(|f| data.extent(f).map(|ext| ScaleSqrt::new(ext, (3.0, 20.0))));

        let mut point_elements = Vec::new();
        for i in 0..data.num_rows() {
            let (Some(x), Some(y)) = (data.get_f64(i, &x_field), data.get_f64(i, &y_field)) else {
                continue;
            };
            let cx = x_scale.map(x);
            let cy = y_scale.map(y);
            let r = match (&size_field, &size_scale) {
                (Some(sf), Some(ss)) => data.get_f64(i, sf).map(|v| ss.map(v)).unwrap_or(5.0),
                _ => 5.0,
            };
            let color_idx = color_field
                .as_ref()
                .and_then(|cf| data.get_string(i, cf))
                .and_then(|v| color_categories.iter().position(|c| c == &v))
                .unwrap_or(0);
            // `checked_rem` yields `None` for an empty palette instead of panicking on
            // a remainder by zero; the fixed fallback colour is used in that case.
            let fill = color_idx
                .checked_rem(config.colors.len())
                .and_then(|idx| config.colors.get(idx))
                .cloned()
                .unwrap_or_else(|| "#2E7D9A".to_string());
            // Fall back to the numeric x value so the tooltip label is never blank.
            let label = data.get_string(i, &x_field).unwrap_or_else(|| x.to_string());
            let value = y.to_string();
            point_elements.push(ChartElement::Circle {
                cx,
                cy,
                r,
                fill,
                stroke: Some("#fff".to_string()),
                class: "chartml-scatter-point".to_string(),
                data: Some(ElementData::new(label, value)),
            });
        }

        let mut children = Vec::new();
        // Roughly one tick per 50 px of plot, kept within 4..=10 ticks.
        let x_ticks = x_scale.ticks(((inner_width / 50.0).floor() as usize).clamp(4, 10));
        let y_ticks = y_scale.ticks(((inner_height / 50.0).floor() as usize).clamp(4, 10));
        let mut axis_elements = Vec::new();
        let y_tick_step = compute_tick_step(&y_ticks);
        let x_tick_step = compute_tick_step(&x_ticks);

        // Horizontal grid lines, y tick marks and y tick labels.
        for &val in &y_ticks {
            let y = y_scale.map(val);
            axis_elements.push(ChartElement::Line {
                x1: margins.left, y1: y, x2: margins.left + inner_width, y2: y,
                stroke: "#e0e0e0".to_string(), stroke_width: Some(1.0),
                stroke_dasharray: None, class: "grid-line".to_string(),
            });
            axis_elements.push(ChartElement::Line {
                x1: margins.left - 5.0, y1: y, x2: margins.left, y2: y,
                stroke: "#999".to_string(), stroke_width: Some(1.0),
                stroke_dasharray: None, class: "tick".to_string(),
            });
            axis_elements.push(ChartElement::Text {
                x: margins.left - 8.0, y,
                content: format_tick_value(val, y_tick_step), anchor: TextAnchor::End,
                dominant_baseline: Some("middle".to_string()),
                transform: None, font_size: Some("11px".to_string()),
                font_weight: None,
                fill: Some("#666".to_string()), class: "tick-label".to_string(), data: None,
            });
        }

        // Vertical grid lines, x tick marks and x tick labels along the bottom axis.
        let x_axis_y = margins.top + inner_height;
        for &val in &x_ticks {
            let x = x_scale.map(val);
            axis_elements.push(ChartElement::Line {
                x1: x, y1: margins.top, x2: x, y2: x_axis_y,
                stroke: "#e0e0e0".to_string(), stroke_width: Some(1.0),
                stroke_dasharray: None, class: "grid-line".to_string(),
            });
            axis_elements.push(ChartElement::Line {
                x1: x, y1: x_axis_y, x2: x, y2: x_axis_y + 5.0,
                stroke: "#999".to_string(), stroke_width: Some(1.0),
                stroke_dasharray: None, class: "tick".to_string(),
            });
            axis_elements.push(ChartElement::Text {
                x, y: x_axis_y + 18.0,
                content: format_tick_value(val, x_tick_step), anchor: TextAnchor::Middle,
                dominant_baseline: None, transform: None,
                font_size: Some("11px".to_string()), font_weight: None,
                fill: Some("#666".to_string()),
                class: "tick-label".to_string(), data: None,
            });
        }

        // Left (y) and bottom (x) axis lines.
        axis_elements.push(ChartElement::Line {
            x1: margins.left, y1: margins.top, x2: margins.left, y2: x_axis_y,
            stroke: "#ccc".to_string(), stroke_width: Some(1.0),
            stroke_dasharray: None, class: "axis-line".to_string(),
        });
        axis_elements.push(ChartElement::Line {
            x1: margins.left, y1: x_axis_y, x2: margins.left + inner_width, y2: x_axis_y,
            stroke: "#ccc".to_string(), stroke_width: Some(1.0),
            stroke_dasharray: None, class: "axis-line".to_string(),
        });

        children.push(ChartElement::Group {
            class: "axes".to_string(),
            transform: None,
            children: axis_elements,
        });
        children.push(ChartElement::Group {
            class: "chartml-scatter-points".to_string(),
            transform: None,
            children: point_elements,
        });

        if has_legend {
            let legend_config = LegendConfig::default();
            let legend_layout =
                calculate_legend_layout(&color_categories, &config.colors, width, &legend_config);
            // Anchor the legend to the bottom edge with an 8 px gap.
            let legend_y = height - legend_layout.total_height - 8.0;
            let legend_elements = generate_legend_elements(
                &color_categories,
                &config.colors,
                width,
                legend_y,
                LegendMark::Circle,
            );
            children.push(ChartElement::Group {
                class: "legend".to_string(),
                transform: None,
                children: legend_elements,
            });
        }

        Ok(ChartElement::Svg {
            viewbox: ViewBox::new(0.0, 0.0, width, height),
            width: Some(width),
            height: Some(height),
            class: "chartml-chart chartml-scatter-chart".to_string(),
            children,
        })
    }

    /// Default dimensions reported for scatter charts; the spec is not consulted.
    ///
    /// NOTE(review): the single `400.0` is presumably the height in pixels — confirm
    /// against `Dimensions::new`.
    fn default_dimensions(&self, _spec: &VisualizeSpec) -> Option<Dimensions> {
        Some(Dimensions::new(400.0))
    }
}

/// Widens a zero-width domain so a linear scale has a non-zero span to map over.
///
/// Data whose values are all equal (a single row, for instance) yields `lo == hi`.
/// Mapping through such a domain divides by zero unless the scale special-cases it.
/// The domain is padded on each side by 10% of the value's magnitude, or by 1 when the
/// value is zero, so the points sit inside the plot. Non-degenerate domains (including
/// NaN bounds) are returned unchanged.
fn widen_degenerate_domain(lo: f64, hi: f64) -> (f64, f64) {
    if lo != hi {
        return (lo, hi);
    }
    let pad = if lo == 0.0 { 1.0 } else { lo.abs() * 0.1 };
    (lo - pad, hi + pad)
}
/// Resolves a field reference to the name of the single column it designates.
///
/// A multi-field reference contributes only its first entry.
///
/// # Errors
///
/// Returns `ChartError::InvalidSpec` when the reference is absent or is an empty list.
fn get_field_name(field_ref: &Option<FieldRef>) -> Result<String, ChartError> {
    let Some(reference) = field_ref else {
        return Err(ChartError::InvalidSpec("Missing required field reference".into()));
    };
    let column = match reference {
        FieldRef::Simple(name) => name,
        FieldRef::Detailed(spec) => &spec.field,
        FieldRef::Multiple(items) => match items.first() {
            Some(chartml_core::spec::FieldRefItem::Simple(name)) => name,
            Some(chartml_core::spec::FieldRefItem::Detailed(spec)) => &spec.field,
            None => return Err(ChartError::InvalidSpec("Empty field reference list".into())),
        },
    };
    Ok(column.clone())
}
fn get_color_field(config: &ChartConfig) -> Option<String> {
config.visualize.marks.as_ref().and_then(|marks| {
marks.color.as_ref().map(|enc| match enc {
MarkEncoding::Simple(name) => name.clone(),
MarkEncoding::Detailed(spec) => spec.field.clone(),
})
})
}
fn get_size_field(config: &ChartConfig) -> Option<String> {
config.visualize.marks.as_ref().and_then(|marks| {
marks.size.as_ref().map(|enc| match enc {
MarkEncoding::Simple(name) => name.clone(),
MarkEncoding::Detailed(spec) => spec.field.clone(),
})
})
}
/// Spacing between consecutive ticks, measured from the first two ticks.
///
/// Falls back to `1.0` when fewer than two ticks are available.
fn compute_tick_step(ticks: &[f64]) -> f64 {
    match ticks {
        [first, second, ..] => (second - first).abs(),
        _ => 1.0,
    }
}
/// Formats an axis tick value with thousands separators and as many decimals as the
/// tick spacing needs.
///
/// The decimal count comes from the magnitude of `tick_step`:
/// * a step of 1 or more gives none;
/// * a step in `[0.1, 1)` gives one;
/// * a step in `[0.01, 0.1)` gives two, and so on.
///
/// A step below `1e-15` in magnitude falls back to no decimals. Values that round to
/// zero are printed without a minus sign.
fn format_tick_value(value: f64, tick_step: f64) -> String {
    let step = tick_step.abs();
    let precision = if step < 1e-15 {
        0usize
    } else {
        // The step is a difference of two float ticks and can land a hair under a power
        // of ten (10.1 - 10.0 == 0.09999999999999964), which would push `log10` just
        // below the integer and request a spurious extra decimal. The small nudge
        // absorbs that rounding error.
        let p = -((step.log10() + 1e-9).floor()) as i64;
        p.max(0) as usize
    };
    let formatted = format!("{:.prec$}", value, prec = precision);

    // Separate the sign so separators go into the digits only.
    let (negative, body) = match formatted.strip_prefix('-') {
        Some(rest) => (true, rest),
        None => (false, formatted.as_str()),
    };
    // A tiny negative value (float residue such as -1e-17) rounds to all zeros; show it
    // as "0", not "-0".
    let rounds_to_zero = body.bytes().all(|b| b == b'0' || b == b'.');
    let sign = if negative && !rounds_to_zero { "-" } else { "" };

    let (int_part, dec_part) = match body.find('.') {
        Some(dot) => body.split_at(dot),
        None => (body, ""),
    };
    format!("{sign}{}{dec_part}", insert_commas(int_part))
}
/// Inserts a `,` between every group of three characters, counted from the right.
///
/// Meant for a run of ASCII digits with no sign or decimal point. Inputs of three or
/// fewer bytes are returned unchanged.
fn insert_commas(digits: &str) -> String {
    let byte_len = digits.len();
    if byte_len > 3 {
        digits.chars().enumerate().fold(
            String::with_capacity(byte_len + byte_len / 3),
            |mut grouped, (pos, ch)| {
                if pos != 0 && (byte_len - pos).is_multiple_of(3) {
                    grouped.push(',');
                }
                grouped.push(ch);
                grouped
            },
        )
    } else {
        digits.to_string()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use chartml_core::data::Row;
    use chartml_core::spec::{VisualizeSpec, MarksSpec, MarkEncoding};

    /// Builds a data row from `(column, value)` pairs.
    fn make_row(pairs: &[(&str, serde_json::Value)]) -> Row {
        let mut map = HashMap::new();
        for (k, v) in pairs {
            map.insert(k.to_string(), v.clone());
        }
        map
    }

    /// Collects the radius of every `Circle` in the tree, in depth-first order.
    fn collect_circle_radii(element: &ChartElement, out: &mut Vec<f64>) {
        match element {
            ChartElement::Circle { r, .. } => out.push(*r),
            ChartElement::Svg { children, .. } | ChartElement::Group { children, .. } => {
                for child in children {
                    collect_circle_radii(child, out);
                }
            }
            _ => {}
        }
    }

    /// Four points in two colour categories ("A", "B").
    fn make_scatter_data() -> DataTable {
        let rows = vec![
            make_row(&[("price", serde_json::json!(10.0)), ("units", serde_json::json!(100.0)), ("category", serde_json::json!("A"))]),
            make_row(&[("price", serde_json::json!(20.0)), ("units", serde_json::json!(200.0)), ("category", serde_json::json!("B"))]),
            make_row(&[("price", serde_json::json!(30.0)), ("units", serde_json::json!(150.0)), ("category", serde_json::json!("A"))]),
            make_row(&[("price", serde_json::json!(40.0)), ("units", serde_json::json!(300.0)), ("category", serde_json::json!("B"))]),
        ];
        DataTable::from_rows(&rows).unwrap()
    }

    fn make_scatter_config() -> ChartConfig {
        ChartConfig {
            visualize: VisualizeSpec {
                chart_type: "scatter".to_string(),
                mode: None,
                orientation: None,
                columns: Some(FieldRef::Simple("price".to_string())),
                rows: Some(FieldRef::Simple("units".to_string())),
                marks: Some(MarksSpec {
                    color: Some(MarkEncoding::Simple("category".to_string())),
                    size: None,
                    shape: None,
                    text: None,
                }),
                axes: None,
                annotations: None,
                style: None,
                value: None,
                label: None,
                format: None,
                compare_with: None,
                invert_trend: None,
                data_labels: None,
            },
            title: Some("Scatter Test".to_string()),
            width: 800.0,
            height: 400.0,
            colors: vec![
                "#2E7D9A".to_string(),
                "#E8533E".to_string(),
                "#4CAF50".to_string(),
            ],
        }
    }

    /// Three points whose `size` column is 100, 400 and 200 respectively.
    fn make_bubble_data() -> DataTable {
        let rows = vec![
            make_row(&[("x", serde_json::json!(5.0)), ("y", serde_json::json!(10.0)), ("size", serde_json::json!(100.0))]),
            make_row(&[("x", serde_json::json!(15.0)), ("y", serde_json::json!(20.0)), ("size", serde_json::json!(400.0))]),
            make_row(&[("x", serde_json::json!(25.0)), ("y", serde_json::json!(15.0)), ("size", serde_json::json!(200.0))]),
        ];
        DataTable::from_rows(&rows).unwrap()
    }

    fn make_bubble_config() -> ChartConfig {
        ChartConfig {
            visualize: VisualizeSpec {
                chart_type: "scatter".to_string(),
                mode: None,
                orientation: None,
                columns: Some(FieldRef::Simple("x".to_string())),
                rows: Some(FieldRef::Simple("y".to_string())),
                marks: Some(MarksSpec {
                    color: None,
                    size: Some(MarkEncoding::Simple("size".to_string())),
                    shape: None,
                    text: None,
                }),
                axes: None,
                annotations: None,
                style: None,
                value: None,
                label: None,
                format: None,
                compare_with: None,
                invert_trend: None,
                data_labels: None,
            },
            title: None,
            width: 600.0,
            height: 400.0,
            colors: vec!["#2E7D9A".to_string()],
        }
    }

    #[test]
    fn scatter_chart_renders() {
        let renderer = ScatterRenderer::new();
        let result = renderer.render(&make_scatter_data(), &make_scatter_config());
        assert!(result.is_ok(), "render failed: {:?}", result.err());
        let element = result.unwrap();
        let circle_count = count_elements(&element, &|e| matches!(e, ChartElement::Circle { .. }));
        // 4 data points plus, presumably, one circular legend marker per category (A, B).
        assert_eq!(circle_count, 6);
    }

    #[test]
    fn scatter_with_size_encoding() {
        let renderer = ScatterRenderer::new();
        let result = renderer.render(&make_bubble_data(), &make_bubble_config());
        assert!(result.is_ok(), "render failed: {:?}", result.err());
        let element = result.unwrap();

        let mut radii = Vec::new();
        collect_circle_radii(&element, &mut radii);
        // No colour field means no legend, so every circle is a data point (row order).
        assert_eq!(radii.len(), 3);
        assert!(
            radii.iter().all(|r| *r >= 3.0 - 1e-9 && *r <= 20.0 + 1e-9),
            "radii outside the 3–20 px size range: {:?}",
            radii
        );
        // Sizes 100 < 200 < 400 (rows 0, 2, 1) must give strictly increasing radii.
        assert!(radii[0] < radii[2] && radii[2] < radii[1], "radii not ordered by size: {:?}", radii);
    }

    #[test]
    fn scatter_empty_data_errors() {
        let renderer = ScatterRenderer::new();
        let data = DataTable::from_rows(&Vec::<Row>::new()).unwrap();
        let result = renderer.render(&data, &make_scatter_config());
        assert!(result.is_err());
    }

    #[test]
    fn compute_tick_step_handles_short_input() {
        assert_eq!(compute_tick_step(&[0.0, 5.0, 10.0]), 5.0);
        assert_eq!(compute_tick_step(&[10.0, 0.0]), 10.0);
        assert_eq!(compute_tick_step(&[3.0]), 1.0);
        assert_eq!(compute_tick_step(&[]), 1.0);
    }

    #[test]
    fn insert_commas_groups_thousands() {
        assert_eq!(insert_commas(""), "");
        assert_eq!(insert_commas("999"), "999");
        assert_eq!(insert_commas("1000"), "1,000");
        assert_eq!(insert_commas("1234567"), "1,234,567");
    }

    #[test]
    fn format_tick_value_uses_step_precision() {
        assert_eq!(format_tick_value(1234.0, 100.0), "1,234");
        assert_eq!(format_tick_value(-1234567.0, 1000.0), "-1,234,567");
        assert_eq!(format_tick_value(0.5, 0.1), "0.5");
        assert_eq!(format_tick_value(0.25, 0.05), "0.25");
        assert_eq!(format_tick_value(7.0, 0.0), "7");
    }
}