Skip to main content

chartml_core/
element.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
/// The output of any ChartRenderer. Framework adapters walk this tree
/// and produce framework-specific DOM/view output.
///
/// Serialized with serde as an internally tagged enum: every node carries a
/// `type` key holding the camelCase variant name (`"svg"`, `"group"`,
/// `"path"`, ...). The variants with multi-word fields (`Path`, `Line`,
/// `Text`) additionally rename those fields to camelCase (`strokeWidth`,
/// `dominantBaseline`, ...).
///
/// `Svg`, `Group` and `Div` are the only variants that hold child nodes; all
/// other variants are leaves.
///
/// NOTE(review): every variant has a `class` string, presumably the CSS class
/// emitted by the adapter; confirm against the adapters.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "camelCase")]
pub enum ChartElement {
    /// SVG container node.
    Svg {
        /// View box of the drawing; see [`ViewBox::to_svg_string`] for the
        /// attribute form.
        viewbox: ViewBox,
        /// Optional explicit width.
        width: Option<f64>,
        /// Optional explicit height.
        height: Option<f64>,
        class: String,
        /// Nested elements, in document order.
        children: Vec<ChartElement>,
    },
    /// Grouping node that can apply one transform to all of its children.
    Group {
        class: String,
        /// Optional transform; see [`Transform::to_svg_string`].
        transform: Option<Transform>,
        /// Nested elements, in document order.
        children: Vec<ChartElement>,
    },
    /// Rectangle positioned at (`x`, `y`) with the given `width` and `height`.
    Rect {
        x: f64,
        y: f64,
        width: f64,
        height: f64,
        fill: String,
        stroke: Option<String>,
        class: String,
        /// Optional tooltip payload.
        data: Option<ElementData>,
    },
    /// Path element. Fields serialize in camelCase (e.g. `strokeWidth`).
    #[serde(rename_all = "camelCase")]
    Path {
        /// Path data string (the tests in this file use `"M0,0 L10,10"`).
        d: String,
        fill: Option<String>,
        stroke: Option<String>,
        stroke_width: Option<f64>,
        stroke_dasharray: Option<String>,
        opacity: Option<f64>,
        class: String,
        /// Optional tooltip payload.
        data: Option<ElementData>,
    },
    /// Circle centred at (`cx`, `cy`) with radius `r`.
    Circle {
        cx: f64,
        cy: f64,
        r: f64,
        fill: String,
        stroke: Option<String>,
        class: String,
        /// Optional tooltip payload.
        data: Option<ElementData>,
    },
    /// Straight segment from (`x1`, `y1`) to (`x2`, `y2`). Fields serialize
    /// in camelCase (e.g. `strokeWidth`). Unlike the other shapes it carries
    /// no tooltip `data`.
    #[serde(rename_all = "camelCase")]
    Line {
        x1: f64,
        y1: f64,
        x2: f64,
        y2: f64,
        stroke: String,
        stroke_width: Option<f64>,
        stroke_dasharray: Option<String>,
        class: String,
    },
    /// Text label positioned at (`x`, `y`). Fields serialize in camelCase
    /// (e.g. `dominantBaseline`, `fontSize`).
    #[serde(rename_all = "camelCase")]
    Text {
        x: f64,
        y: f64,
        /// The text to show.
        content: String,
        /// Anchoring of the text; see [`TextAnchor`].
        anchor: TextAnchor,
        dominant_baseline: Option<String>,
        /// Optional transform; see [`Transform::to_svg_string`].
        transform: Option<Transform>,
        /// Font size as a string (the tests in this file use `"14px"`).
        font_size: Option<String>,
        /// Font weight as a string (the tests in this file use `"bold"`).
        font_weight: Option<String>,
        fill: Option<String>,
        class: String,
        /// Optional tooltip payload.
        data: Option<ElementData>,
    },
    /// Non-SVG container (e.g., metric card uses div-based layout)
    Div {
        class: String,
        /// Style entries as property/value strings (the tests in this file
        /// use `"display"` -> `"flex"`).
        style: HashMap<String, String>,
        /// Nested elements, in document order.
        children: Vec<ChartElement>,
    },
    /// Raw text node (for metric values, labels in div-based charts)
    Span {
        class: String,
        /// Style entries as property/value strings.
        style: HashMap<String, String>,
        /// The text to show.
        content: String,
    },
}
90
/// Data attached to interactive elements for tooltips.
///
/// Usually built with [`ElementData::new`], optionally followed by
/// [`ElementData::with_series`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ElementData {
    /// Label text.
    pub label: String,
    /// Value text, stored as a string.
    pub value: String,
    /// Optional series name; `None` unless set (e.g. via `with_series`).
    pub series: Option<String>,
    /// Extra JSON values keyed by name; `new` leaves this map empty.
    /// NOTE(review): presumably the raw data row behind the element; confirm
    /// against the renderers that fill it.
    pub raw: HashMap<String, serde_json::Value>,
}
99
/// View box rectangle: an origin (`x`, `y`) plus a `width` and `height`.
///
/// Both `to_svg_string` and the `Display` impl render it as
/// `"x y width height"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ViewBox {
    /// Horizontal origin.
    pub x: f64,
    /// Vertical origin.
    pub y: f64,
    /// Horizontal extent.
    pub width: f64,
    /// Vertical extent.
    pub height: f64,
}
107
/// A transform that can be attached to a `Group` or `Text` element.
///
/// `to_svg_string` (and `Display`, which delegates to it) renders the value
/// as an SVG transform attribute string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Transform {
    /// `Translate(x, y)`, rendered as `translate(x,y)`.
    Translate(f64, f64),
    /// `Rotate(angle, cx, cy)`, rendered as `rotate(angle,cx,cy)`.
    Rotate(f64, f64, f64),
    /// A list of transforms, rendered in list order separated by single spaces.
    Multiple(Vec<Transform>),
}
114
/// Anchoring choice for a `Text` element.
///
/// The `Display` impl prints the lowercase keywords `start`, `middle` and
/// `end`.
/// NOTE(review): these look like SVG `text-anchor` values; confirm how the
/// adapters emit them.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum TextAnchor {
    /// Printed as `start`.
    Start,
    /// Printed as `middle`.
    Middle,
    /// Printed as `end`.
    End,
}
121
/// A size with a mandatory height and an optional width.
///
/// `Dimensions::new` sets only the height; `with_width` supplies the width.
/// NOTE(review): how a missing width is resolved is not visible in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Dimensions {
    /// Width, if one has been set.
    pub width: Option<f64>,
    /// Height; always present.
    pub height: f64,
}
127
128impl ViewBox {
129    pub fn new(x: f64, y: f64, width: f64, height: f64) -> Self {
130        Self { x, y, width, height }
131    }
132
133    /// Format as SVG viewBox attribute string: "x y width height"
134    pub fn to_svg_string(&self) -> String {
135        format!("{} {} {} {}", self.x, self.y, self.width, self.height)
136    }
137}
138
139impl std::fmt::Display for ViewBox {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        write!(f, "{} {} {} {}", self.x, self.y, self.width, self.height)
142    }
143}
144
145impl Transform {
146    /// Format as SVG transform attribute string.
147    pub fn to_svg_string(&self) -> String {
148        match self {
149            Transform::Translate(x, y) => format!("translate({},{})", x, y),
150            Transform::Rotate(angle, cx, cy) => format!("rotate({},{},{})", angle, cx, cy),
151            Transform::Multiple(transforms) => {
152                transforms.iter().map(|t| t.to_svg_string()).collect::<Vec<_>>().join(" ")
153            }
154        }
155    }
156}
157
158impl std::fmt::Display for Transform {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        write!(f, "{}", self.to_svg_string())
161    }
162}
163
164impl std::fmt::Display for TextAnchor {
165    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166        match self {
167            TextAnchor::Start => write!(f, "start"),
168            TextAnchor::Middle => write!(f, "middle"),
169            TextAnchor::End => write!(f, "end"),
170        }
171    }
172}
173
174impl ElementData {
175    pub fn new(label: impl Into<String>, value: impl Into<String>) -> Self {
176        Self {
177            label: label.into(),
178            value: value.into(),
179            series: None,
180            raw: HashMap::new(),
181        }
182    }
183
184    pub fn with_series(mut self, series: impl Into<String>) -> Self {
185        self.series = Some(series.into());
186        self
187    }
188}
189
190impl Dimensions {
191    pub fn new(height: f64) -> Self {
192        Self { width: None, height }
193    }
194
195    pub fn with_width(mut self, width: f64) -> Self {
196        self.width = Some(width);
197        self
198    }
199}
200
201/// Count elements in the tree matching a predicate.
202pub fn count_elements<F>(element: &ChartElement, predicate: &F) -> usize
203where
204    F: Fn(&ChartElement) -> bool,
205{
206    let mut count = if predicate(element) { 1 } else { 0 };
207    match element {
208        ChartElement::Svg { children, .. }
209        | ChartElement::Group { children, .. }
210        | ChartElement::Div { children, .. } => {
211            for child in children {
212                count += count_elements(child, predicate);
213            }
214        }
215        _ => {}
216    }
217    count
218}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the plain bar `Rect` used by the counting fixture: 50 wide, at
    /// `y = 0`, class `"bar"`, with no stroke and no tooltip data.
    fn bar(x: f64, height: f64, fill: &str) -> ChartElement {
        ChartElement::Rect {
            x,
            y: 0.0,
            width: 50.0,
            height,
            fill: fill.to_string(),
            stroke: None,
            class: "bar".to_string(),
            data: None,
        }
    }

    #[test]
    fn viewbox_display() {
        let rendered = ViewBox::new(0.0, 0.0, 800.0, 400.0).to_string();
        assert_eq!(rendered, "0 0 800 400");
    }

    #[test]
    fn transform_translate_display() {
        let rendered = Transform::Translate(10.0, 20.0).to_string();
        assert_eq!(rendered, "translate(10,20)");
    }

    #[test]
    fn transform_rotate_display() {
        let rendered = Transform::Rotate(45.0, 100.0, 200.0).to_string();
        assert_eq!(rendered, "rotate(45,100,200)");
    }

    #[test]
    fn transform_multiple_display() {
        let steps = vec![
            Transform::Translate(10.0, 20.0),
            Transform::Rotate(45.0, 0.0, 0.0),
        ];
        let rendered = Transform::Multiple(steps).to_string();
        assert_eq!(rendered, "translate(10,20) rotate(45,0,0)");
    }

    #[test]
    fn text_anchor_display() {
        let cases = [
            (TextAnchor::Start, "start"),
            (TextAnchor::Middle, "middle"),
            (TextAnchor::End, "end"),
        ];
        for (anchor, expected) in &cases {
            assert_eq!(anchor.to_string(), *expected);
        }
    }

    #[test]
    fn element_data_builder() {
        let built = ElementData::new("Jan", "1234").with_series("Revenue");
        assert_eq!(built.label, "Jan");
        assert_eq!(built.value, "1234");
        assert_eq!(built.series, Some("Revenue".to_string()));
    }

    #[test]
    fn count_rects_in_tree() {
        // Two bars inside a group, plus a title that must not be counted.
        let bars = ChartElement::Group {
            class: "bars".to_string(),
            transform: None,
            children: vec![bar(0.0, 100.0, "red"), bar(60.0, 150.0, "blue")],
        };
        let title = ChartElement::Text {
            x: 400.0,
            y: 20.0,
            content: "Title".to_string(),
            anchor: TextAnchor::Middle,
            dominant_baseline: None,
            transform: None,
            font_size: None,
            font_weight: None,
            fill: None,
            class: "title".to_string(),
            data: None,
        };
        let root = ChartElement::Svg {
            viewbox: ViewBox::new(0.0, 0.0, 800.0, 400.0),
            width: Some(800.0),
            height: Some(400.0),
            class: "chart".to_string(),
            children: vec![bars, title],
        };

        let is_rect = |node: &ChartElement| matches!(node, ChartElement::Rect { .. });
        assert_eq!(count_elements(&root, &is_rect), 2);
    }

    #[test]
    fn dimensions_builder() {
        let sized = Dimensions::new(400.0).with_width(800.0);
        assert_eq!(sized.height, 400.0);
        assert_eq!(sized.width, Some(800.0));
    }

    #[test]
    fn serde_round_trip_chart_element_tree() {
        // One node of every variant, so each serde shape is exercised.
        let tooltip_bar = ChartElement::Rect {
            x: 0.0,
            y: 0.0,
            width: 50.0,
            height: 100.0,
            fill: "red".to_string(),
            stroke: None,
            class: "bar".to_string(),
            data: Some(ElementData::new("Jan", "1234").with_series("Revenue")),
        };
        let dashed_path = ChartElement::Path {
            d: "M0,0 L10,10".to_string(),
            fill: None,
            stroke: Some("blue".to_string()),
            stroke_width: Some(2.0),
            stroke_dasharray: Some("4,2".to_string()),
            opacity: Some(0.8),
            class: "line".to_string(),
            data: None,
        };
        let bars = ChartElement::Group {
            class: "bars".to_string(),
            transform: Some(Transform::Translate(50.0, 10.0)),
            children: vec![tooltip_bar, dashed_path],
        };
        let axis = ChartElement::Line {
            x1: 0.0,
            y1: 0.0,
            x2: 100.0,
            y2: 100.0,
            stroke: "black".to_string(),
            stroke_width: Some(1.0),
            stroke_dasharray: None,
            class: "axis".to_string(),
        };
        let title = ChartElement::Text {
            x: 400.0,
            y: 20.0,
            content: "Title".to_string(),
            anchor: TextAnchor::Middle,
            dominant_baseline: Some("central".to_string()),
            transform: Some(Transform::Rotate(45.0, 400.0, 20.0)),
            font_size: Some("14px".to_string()),
            font_weight: Some("bold".to_string()),
            fill: Some("black".to_string()),
            class: "title".to_string(),
            data: None,
        };
        let dot = ChartElement::Circle {
            cx: 50.0,
            cy: 50.0,
            r: 5.0,
            fill: "green".to_string(),
            stroke: None,
            class: "dot".to_string(),
            data: None,
        };
        let metric_value = ChartElement::Span {
            class: "value".to_string(),
            style: HashMap::from([("font-size".to_string(), "24px".to_string())]),
            content: "$1,234".to_string(),
        };
        let metric_card = ChartElement::Div {
            class: "metric-card".to_string(),
            style: HashMap::from([("display".to_string(), "flex".to_string())]),
            children: vec![metric_value],
        };
        // Child order matters: the JSON assertions below index into it.
        let tree = ChartElement::Svg {
            viewbox: ViewBox::new(0.0, 0.0, 800.0, 400.0),
            width: Some(800.0),
            height: Some(400.0),
            class: "chart".to_string(),
            children: vec![bars, axis, title, dot, metric_card],
        };

        let first_pass = serde_json::to_string(&tree).expect("serialize");
        let restored: ChartElement = serde_json::from_str(&first_pass).expect("deserialize");

        // The tree has no `PartialEq`, so compare the serialized forms instead.
        let second_pass = serde_json::to_string(&restored).expect("re-serialize");
        assert_eq!(first_pass, second_pass);

        // Check the `type` tags and the camelCase field names on the wire.
        let value: serde_json::Value =
            serde_json::from_str(&first_pass).expect("parse as Value");
        assert_eq!(value["type"], "svg");

        let group_node = &value["children"][0];
        assert_eq!(group_node["type"], "group");

        let path_node = &group_node["children"][1];
        assert_eq!(path_node["type"], "path");
        assert_eq!(path_node["strokeWidth"], serde_json::json!(2.0));

        let text_node = &value["children"][2];
        assert_eq!(text_node["dominantBaseline"], serde_json::json!("central"));
    }
}