Skip to main content

chartml_chart_pie/
lib.rs

1use chartml_core::plugin::{ChartRenderer, ChartConfig};
2use chartml_core::data::DataTable;
3use chartml_core::element::*;
4use chartml_core::error::ChartError;
5use chartml_core::shapes::{ArcGenerator, PieLayout};
6use chartml_core::spec::{VisualizeSpec, FieldRef};
7use chartml_core::layout::{calculate_legend_layout, LegendConfig, LegendAlignment, TextMetrics};
8
/// Renderer for pie and doughnut charts.
///
/// The renderer holds no state: every input comes from the [`DataTable`] and
/// [`ChartConfig`] handed to `render`, so it is a zero-sized unit struct that is
/// free to copy.
#[derive(Debug, Default, Clone, Copy)]
pub struct PieRenderer;
11
12impl PieRenderer {
13    pub fn new() -> Self { Self }
14}
15
16impl ChartRenderer for PieRenderer {
17    fn render(&self, data: &DataTable, config: &ChartConfig) -> Result<ChartElement, ChartError> {
18        let chart_type = &config.visualize.chart_type;
19        let is_doughnut = chart_type == "doughnut";
20
21        // Get fields
22        let col_field = get_field_name(&config.visualize.columns)?;
23        let row_field = get_field_name(&config.visualize.rows)?;
24
25        // Extract data
26        let mut labels = Vec::new();
27        let mut values = Vec::new();
28        for i in 0..data.num_rows() {
29            if let (Some(label), Some(value)) = (data.get_string(i, &col_field), data.get_f64(i, &row_field)) {
30                labels.push(label);
31                values.push(value);
32            }
33        }
34
35        if values.is_empty() {
36            return Err(ChartError::DataError("No data for pie chart".into()));
37        }
38
39        let width = config.width;
40        let height = config.height;
41
42        // Reserve space at the bottom for the legend + 5% bottom margin.
43        // The 50px is a conservative upper-bound for single-row legend height
44        // (30px gap + ~20px legend row); multi-row legends are accommodated by
45        // the actual legend_layout.total_height used for positioning below.
46        let bottom_margin = height * 0.05;
47        let legend_reserved = 50.0 + bottom_margin;
48        let radius = (width.min(height - legend_reserved) / 2.0) - 40.0;
49        let inner_radius = if is_doughnut { radius * 0.5 } else { 0.0 };
50        let cx = width / 2.0;
51        // Shift pie center up slightly to make room for legend below
52        let cy = (height - legend_reserved) / 2.0;
53
54        // Compute pie layout
55        let pie = PieLayout::new();
56        let slices = pie.layout(&values);
57
58        // Generate arc paths
59        let arc = ArcGenerator::new(inner_radius, radius);
60        let mut slice_elements = Vec::new();
61
62        for (i, slice) in slices.iter().enumerate() {
63            let path_d = arc.generate(slice.start_angle, slice.end_angle);
64            let color = config.colors.get(i % config.colors.len())
65                .cloned()
66                .unwrap_or_else(|| "#999".to_string());
67
68            let data = ElementData::new(&labels[slice.index], chartml_core::format::format_value(values[slice.index], None))
69                .with_series(&labels[slice.index]);
70
71            slice_elements.push(ChartElement::Path {
72                d: path_d,
73                fill: Some(color),
74                stroke: Some(config.theme.bg.clone()),
75                stroke_width: Some(2.0),
76                stroke_dasharray: None,
77                stroke_dashoffset: None,
78                opacity: None,
79                class: "chartml-pie-slice".to_string(),
80                data: Some(data),
81                animation_origin: None,
82            });
83        }
84
85        // Title is rendered as HTML outside the SVG — not added here.
86        let mut children = Vec::new();
87
88        // Pie group (centered)
89        children.push(ChartElement::Group {
90            class: "chartml-pie".to_string(),
91            transform: Some(Transform::Translate(cx, cy)),
92            children: slice_elements,
93        });
94
95        // Legend — rendered below the pie, horizontally centered
96        let legend_config = LegendConfig {
97            alignment: LegendAlignment::Center,
98            text_metrics: TextMetrics::from_theme_legend(&config.theme),
99            ..LegendConfig::default()
100        };
101        // Build ordered labels and colors (original data order matches color palette order)
102        let legend_colors: Vec<String> = (0..labels.len())
103            .map(|i| config.colors.get(i % config.colors.len()).cloned().unwrap_or_else(|| "#999".to_string()))
104            .collect();
105        let legend_layout = calculate_legend_layout(&labels, &legend_colors, width, &legend_config);
106        // Position legend so the last text baseline sits at least 5% from the bottom edge
107        let legend_y = height - legend_layout.total_height - bottom_margin;
108        for item in legend_layout.items.iter().filter(|i| i.visible) {
109            // Colored swatch rect
110            children.push(ChartElement::Rect {
111                x: item.x,
112                y: legend_y + item.y,
113                width: legend_config.symbol_size,
114                height: legend_config.symbol_size,
115                fill: item.color.clone(),
116                stroke: None,
117                rx: None,
118                ry: None,
119                class: "legend-symbol".to_string(),
120                data: None,
121                animation_origin: None,
122            });
123            // Label text
124            let ts = TextStyle::for_role(&config.theme, TextRole::LegendLabel);
125            children.push(ChartElement::Text {
126                x: item.x + legend_config.symbol_size + legend_config.symbol_text_gap,
127                y: legend_y + item.y + 10.0, // vertical center of 20px row_height
128                content: item.label.clone(),
129                anchor: TextAnchor::Start,
130                dominant_baseline: None,
131                transform: None,
132                font_family: ts.font_family,
133                font_size: ts.font_size,
134                font_weight: ts.font_weight,
135                letter_spacing: ts.letter_spacing,
136                text_transform: ts.text_transform,
137                fill: Some(config.theme.text_secondary.clone()),
138                class: "legend-label".to_string(),
139                data: None,
140            });
141        }
142
143        Ok(ChartElement::Svg {
144            viewbox: ViewBox::new(0.0, 0.0, width, height),
145            width: Some(width),
146            height: Some(height),
147            class: "chartml-chart chartml-pie-chart".to_string(),
148            children,
149        })
150    }
151
152    fn default_dimensions(&self, _spec: &VisualizeSpec) -> Option<Dimensions> {
153        Some(Dimensions::new(400.0))
154    }
155}
156
157fn get_field_name(field_ref: &Option<FieldRef>) -> Result<String, ChartError> {
158    fn field_or_err(spec: &chartml_core::spec::FieldSpec) -> Result<String, ChartError> {
159        spec.field
160            .clone()
161            .ok_or_else(|| ChartError::MissingField("field (range-mark specs are not supported for pie charts)".into()))
162    }
163    match field_ref {
164        Some(FieldRef::Simple(name)) => Ok(name.clone()),
165        Some(FieldRef::Detailed(spec)) => field_or_err(spec),
166        Some(FieldRef::Multiple(items)) => {
167            match items.first() {
168                Some(chartml_core::spec::FieldRefItem::Simple(s)) => Ok(s.clone()),
169                Some(chartml_core::spec::FieldRefItem::Detailed(spec)) => field_or_err(spec),
170                None => Err(ChartError::MissingField("field".into())),
171            }
172        }
173        None => Err(ChartError::MissingField("columns/rows field".into())),
174    }
175}
176
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use chartml_core::data::Row;
    use chartml_core::element::count_elements;
    use serde_json::json;

    /// Pre-order walk over `Svg`/`Group` containers, calling `visit` on every element.
    fn walk(element: &ChartElement, visit: &mut dyn FnMut(&ChartElement)) {
        visit(element);
        if let ChartElement::Svg { children, .. } | ChartElement::Group { children, .. } = element {
            for child in children {
                walk(child, visit);
            }
        }
    }

    /// One `{ region, revenue }` data row.
    fn pie_row(region: &str, revenue: i64) -> Row {
        [
            ("region".to_string(), json!(region)),
            ("revenue".to_string(), json!(revenue)),
        ]
        .into_iter()
        .collect()
    }

    /// Three regions with distinct revenues.
    fn make_pie_data() -> DataTable {
        let rows = vec![pie_row("North", 100), pie_row("South", 200), pie_row("East", 150)];
        DataTable::from_rows(&rows).unwrap()
    }

    /// 400×400 config mapping `region` to labels and `revenue` to values.
    fn make_pie_config(chart_type: &str) -> ChartConfig {
        let yaml = format!(r#"
            type: {}
            columns: region
            rows: revenue
        "#, chart_type);
        let visualize: chartml_core::spec::VisualizeSpec = serde_yaml::from_str(&yaml).unwrap();
        ChartConfig {
            visualize,
            title: Some("Test Pie".to_string()),
            width: 400.0,
            height: 400.0,
            colors: vec!["#2E7D9A".to_string(), "#D4A445".to_string(), "#4A7C59".to_string()],
            theme: chartml_core::theme::Theme::default(),
        }
    }

    #[test]
    fn pie_chart_renders() {
        let outcome = PieRenderer::new().render(&make_pie_data(), &make_pie_config("pie"));
        assert!(outcome.is_ok(), "Pie render failed: {:?}", outcome.err());
        let svg = outcome.unwrap();
        let slices = count_elements(&svg, &|e| matches!(e, ChartElement::Path { .. }));
        assert_eq!(slices, 3, "Should have 3 slices");
    }

    #[test]
    fn doughnut_chart_renders() {
        let outcome = PieRenderer::new().render(&make_pie_data(), &make_pie_config("doughnut"));
        assert!(outcome.is_ok());
        let svg = outcome.unwrap();
        let slices = count_elements(&svg, &|e| matches!(e, ChartElement::Path { .. }));
        assert_eq!(slices, 3);
    }

    #[test]
    fn pie_has_no_title_in_svg() {
        // Title is rendered as HTML outside the SVG.
        let svg = PieRenderer::new().render(&make_pie_data(), &make_pie_config("pie")).unwrap();
        let titles = count_elements(&svg, &|e| {
            matches!(e, ChartElement::Text { class, .. } if class == "chart-title")
        });
        assert_eq!(titles, 0);
    }

    #[test]
    fn pie_has_legend() {
        let svg = PieRenderer::new().render(&make_pie_data(), &make_pie_config("pie")).unwrap();
        // 3 slices = 3 legend swatches (Rect) + 3 legend labels (Text with class "legend-label")
        let swatches = count_elements(&svg, &|e| {
            matches!(e, ChartElement::Rect { class, .. } if class == "legend-symbol")
        });
        assert_eq!(swatches, 3, "Should have 3 legend swatches (one per slice)");
        let labels = count_elements(&svg, &|e| {
            matches!(e, ChartElement::Text { class, .. } if class == "legend-label")
        });
        assert_eq!(labels, 3, "Should have 3 legend labels (one per slice)");
    }

    #[test]
    fn pie_legend_respects_5pct_bottom_margin() {
        let config = make_pie_config("pie");
        let height = config.height;
        let svg = PieRenderer::new().render(&make_pie_data(), &config).unwrap();

        // Lowest (largest y) legend-label baseline anywhere in the tree.
        let mut lowest_label_y: f64 = 0.0;
        walk(&svg, &mut |el: &ChartElement| {
            if let ChartElement::Text { y, class, .. } = el {
                if class == "legend-label" && *y > lowest_label_y {
                    lowest_label_y = *y;
                }
            }
        });

        let bottom_gap = height - lowest_label_y;
        assert!(
            bottom_gap >= height * 0.05,
            "Bottom gap {:.1}px ({:.1}%) is below 5% threshold on {:.0}px chart",
            bottom_gap, bottom_gap / height * 100.0, height
        );
    }

    #[test]
    fn pie_legend_colors_match_slices() {
        let config = make_pie_config("pie");
        let svg = PieRenderer::new().render(&make_pie_data(), &config).unwrap();

        // Legend swatch fills in document order.
        let mut swatch_fills: Vec<String> = Vec::new();
        walk(&svg, &mut |el: &ChartElement| {
            if let ChartElement::Rect { fill, class, .. } = el {
                if class == "legend-symbol" {
                    swatch_fills.push(fill.clone());
                }
            }
        });

        assert_eq!(swatch_fills.len(), 3, "Expected 3 legend swatches");
        // Colors must match the configured palette in order.
        assert_eq!(swatch_fills.as_slice(), &config.colors[..3]);
    }
}