Skip to main content

nova_plot/
render.rs

1//! SVG rendering for charts.
2
3use crate::chart::{Chart, ChartType};
4use crate::data::{DataSource, Table};
5use crate::error::{Error, Result};
6use crate::style::Color;
7
8/// Chart renderer.
9pub struct Renderer<'a> {
10    chart: &'a Chart,
11}
12
13impl<'a> Renderer<'a> {
14    /// Create a new renderer.
15    pub fn new(chart: &'a Chart) -> Self {
16        Self { chart }
17    }
18
19    /// Render the chart to SVG.
20    pub fn render(&self, data: &DataSource) -> Result<String> {
21        match self.chart.chart_type {
22            ChartType::Line => self.render_line(data),
23            ChartType::Bar => self.render_bar(data),
24            ChartType::Scatter => self.render_scatter(data),
25            ChartType::Pie => self.render_pie(data),
26            ChartType::Area => self.render_area(data),
27        }
28    }
29
30    /// Render a line chart.
31    fn render_line(&self, data: &DataSource) -> Result<String> {
32        let table = data.as_table().ok_or(Error::InvalidData {
33            message: "line chart requires table data".to_string(),
34        })?;
35
36        let (x_data, y_data) = self.extract_xy(table)?;
37        let (x_min, x_max, y_min, y_max) = self.calculate_bounds(&x_data, &y_data);
38
39        let mut svg = self.svg_header();
40        svg.push_str(&self.render_background());
41
42        if self.chart.style.show_grid {
43            svg.push_str(&self.render_grid(x_min, x_max, y_min, y_max));
44        }
45
46        svg.push_str(&self.render_axes(x_min, x_max, y_min, y_max));
47        svg.push_str(&self.render_line_path(&x_data, &y_data, x_min, x_max, y_min, y_max));
48
49        if let Some(ref title) = self.chart.title {
50            svg.push_str(&self.render_title(title));
51        }
52
53        svg.push_str(&self.render_axis_labels());
54        svg.push_str("</svg>");
55
56        Ok(svg)
57    }
58
59    /// Render a bar chart.
60    fn render_bar(&self, data: &DataSource) -> Result<String> {
61        let table = data.as_table().ok_or(Error::InvalidData {
62            message: "bar chart requires table data".to_string(),
63        })?;
64
65        let (labels, values) = self.extract_labels_values(table)?;
66        let y_max = values.iter().cloned().fold(0.0_f64, f64::max);
67
68        let mut svg = self.svg_header();
69        svg.push_str(&self.render_background());
70        svg.push_str(&self.render_bars(&labels, &values, y_max));
71
72        if let Some(ref title) = self.chart.title {
73            svg.push_str(&self.render_title(title));
74        }
75
76        svg.push_str("</svg>");
77
78        Ok(svg)
79    }
80
81    /// Render a scatter plot.
82    fn render_scatter(&self, data: &DataSource) -> Result<String> {
83        let table = data.as_table().ok_or(Error::InvalidData {
84            message: "scatter chart requires table data".to_string(),
85        })?;
86
87        let (x_data, y_data) = self.extract_xy(table)?;
88        let (x_min, x_max, y_min, y_max) = self.calculate_bounds(&x_data, &y_data);
89
90        let mut svg = self.svg_header();
91        svg.push_str(&self.render_background());
92
93        if self.chart.style.show_grid {
94            svg.push_str(&self.render_grid(x_min, x_max, y_min, y_max));
95        }
96
97        svg.push_str(&self.render_axes(x_min, x_max, y_min, y_max));
98        svg.push_str(&self.render_points(&x_data, &y_data, x_min, x_max, y_min, y_max));
99
100        if let Some(ref title) = self.chart.title {
101            svg.push_str(&self.render_title(title));
102        }
103
104        svg.push_str("</svg>");
105
106        Ok(svg)
107    }
108
109    /// Render a pie chart.
110    fn render_pie(&self, data: &DataSource) -> Result<String> {
111        let table = data.as_table().ok_or(Error::InvalidData {
112            message: "pie chart requires table data".to_string(),
113        })?;
114
115        let (labels, values) = self.extract_labels_values(table)?;
116
117        let mut svg = self.svg_header();
118        svg.push_str(&self.render_background());
119        svg.push_str(&self.render_pie_slices(&labels, &values));
120
121        if let Some(ref title) = self.chart.title {
122            svg.push_str(&self.render_title(title));
123        }
124
125        svg.push_str("</svg>");
126
127        Ok(svg)
128    }
129
130    /// Render an area chart.
131    fn render_area(&self, data: &DataSource) -> Result<String> {
132        // Area chart is similar to line chart but filled
133        let table = data.as_table().ok_or(Error::InvalidData {
134            message: "area chart requires table data".to_string(),
135        })?;
136
137        let (x_data, y_data) = self.extract_xy(table)?;
138        let (x_min, x_max, y_min, y_max) = self.calculate_bounds(&x_data, &y_data);
139
140        let mut svg = self.svg_header();
141        svg.push_str(&self.render_background());
142
143        if self.chart.style.show_grid {
144            svg.push_str(&self.render_grid(x_min, x_max, y_min, y_max));
145        }
146
147        svg.push_str(&self.render_axes(x_min, x_max, y_min, y_max));
148        svg.push_str(&self.render_area_path(&x_data, &y_data, x_min, x_max, y_min, y_max));
149
150        if let Some(ref title) = self.chart.title {
151            svg.push_str(&self.render_title(title));
152        }
153
154        svg.push_str("</svg>");
155
156        Ok(svg)
157    }
158
159    // Helper methods
160
161    fn svg_header(&self) -> String {
162        format!(
163            r#"<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 {} {}" width="{}" height="{}">"#,
164            self.chart.width, self.chart.height, self.chart.width, self.chart.height
165        )
166    }
167
168    fn render_background(&self) -> String {
169        format!(
170            r#"<rect width="100%" height="100%" fill="{}"/>"#,
171            self.chart.style.background.to_hex()
172        )
173    }
174
175    fn render_title(&self, title: &str) -> String {
176        format!(
177            r#"<text x="{}" y="{}" text-anchor="middle" font-family="{}" font-size="{}" fill="{}">{}</text>"#,
178            self.chart.width / 2.0,
179            self.chart.style.padding.top / 2.0 + 5.0,
180            self.chart.style.font_family,
181            self.chart.style.title_font_size,
182            self.chart.style.text_color.to_hex(),
183            html_escape(title)
184        )
185    }
186
187    fn render_axis_labels(&self) -> String {
188        let mut svg = String::new();
189
190        if let Some(ref label) = self.chart.x_label {
191            svg.push_str(&format!(
192                r#"<text x="{}" y="{}" text-anchor="middle" font-family="{}" font-size="{}" fill="{}">{}</text>"#,
193                self.chart.width / 2.0,
194                self.chart.height - 10.0,
195                self.chart.style.font_family,
196                self.chart.style.label_font_size,
197                self.chart.style.text_color.to_hex(),
198                html_escape(label)
199            ));
200        }
201
202        if let Some(ref label) = self.chart.y_label {
203            svg.push_str(&format!(
204                r#"<text x="{}" y="{}" text-anchor="middle" font-family="{}" font-size="{}" fill="{}" transform="rotate(-90, {}, {})">{}</text>"#,
205                15.0,
206                self.chart.height / 2.0,
207                self.chart.style.font_family,
208                self.chart.style.label_font_size,
209                self.chart.style.text_color.to_hex(),
210                15.0,
211                self.chart.height / 2.0,
212                html_escape(label)
213            ));
214        }
215
216        svg
217    }
218
219    fn plot_area(&self) -> (f64, f64, f64, f64) {
220        let padding = &self.chart.style.padding;
221        (
222            padding.left,
223            padding.top,
224            self.chart.width - padding.left - padding.right,
225            self.chart.height - padding.top - padding.bottom,
226        )
227    }
228
229    fn render_grid(&self, _x_min: f64, _x_max: f64, _y_min: f64, _y_max: f64) -> String {
230        let (x, y, w, h) = self.plot_area();
231        let grid_color = self.chart.style.grid_color.to_hex();
232        let mut svg = String::new();
233
234        // Horizontal grid lines
235        for i in 0..=5 {
236            let y_pos = y + h - (h * f64::from(i) / 5.0);
237            svg.push_str(&format!(
238                r#"<line x1="{}" y1="{}" x2="{}" y2="{}" stroke="{}" stroke-width="1"/>"#,
239                x,
240                y_pos,
241                x + w,
242                y_pos,
243                grid_color
244            ));
245        }
246
247        // Vertical grid lines
248        for i in 0..=5 {
249            let x_pos = x + (w * f64::from(i) / 5.0);
250            svg.push_str(&format!(
251                r#"<line x1="{}" y1="{}" x2="{}" y2="{}" stroke="{}" stroke-width="1"/>"#,
252                x_pos,
253                y,
254                x_pos,
255                y + h,
256                grid_color
257            ));
258        }
259
260        svg
261    }
262
263    fn render_axes(&self, x_min: f64, x_max: f64, y_min: f64, y_max: f64) -> String {
264        let (x, y, w, h) = self.plot_area();
265        let axis_color = self.chart.style.axis_color.to_hex();
266        let text_color = self.chart.style.text_color.to_hex();
267        let mut svg = String::new();
268
269        // X axis
270        svg.push_str(&format!(
271            r#"<line x1="{}" y1="{}" x2="{}" y2="{}" stroke="{}" stroke-width="2"/>"#,
272            x,
273            y + h,
274            x + w,
275            y + h,
276            axis_color
277        ));
278
279        // Y axis
280        svg.push_str(&format!(
281            r#"<line x1="{}" y1="{}" x2="{}" y2="{}" stroke="{}" stroke-width="2"/>"#,
282            x,
283            y,
284            x,
285            y + h,
286            axis_color
287        ));
288
289        // X axis ticks and labels
290        for i in 0..=5 {
291            let x_pos = x + (w * f64::from(i) / 5.0);
292            let value = x_min + (x_max - x_min) * f64::from(i) / 5.0;
293            svg.push_str(&format!(
294                r#"<text x="{}" y="{}" text-anchor="middle" font-family="{}" font-size="{}" fill="{}">{:.1}</text>"#,
295                x_pos, y + h + 20.0, self.chart.style.font_family, self.chart.style.axis_font_size, text_color, value
296            ));
297        }
298
299        // Y axis ticks and labels
300        for i in 0..=5 {
301            let y_pos = y + h - (h * f64::from(i) / 5.0);
302            let value = y_min + (y_max - y_min) * f64::from(i) / 5.0;
303            svg.push_str(&format!(
304                r#"<text x="{}" y="{}" text-anchor="end" font-family="{}" font-size="{}" fill="{}">{:.1}</text>"#,
305                x - 5.0, y_pos + 4.0, self.chart.style.font_family, self.chart.style.axis_font_size, text_color, value
306            ));
307        }
308
309        svg
310    }
311
312    fn render_line_path(
313        &self,
314        x_data: &[f64],
315        y_data: &[f64],
316        x_min: f64,
317        x_max: f64,
318        y_min: f64,
319        y_max: f64,
320    ) -> String {
321        let (px, py, pw, ph) = self.plot_area();
322        let color = self.chart.style.primary.to_hex();
323
324        let points: Vec<String> = x_data
325            .iter()
326            .zip(y_data.iter())
327            .map(|(&x, &y)| {
328                let sx = px + (x - x_min) / (x_max - x_min) * pw;
329                let sy = py + ph - (y - y_min) / (y_max - y_min) * ph;
330                format!("{sx},{sy}")
331            })
332            .collect();
333
334        format!(
335            r#"<polyline points="{}" fill="none" stroke="{}" stroke-width="{}"/>"#,
336            points.join(" "),
337            color,
338            self.chart.style.line_width
339        )
340    }
341
342    fn render_area_path(
343        &self,
344        x_data: &[f64],
345        y_data: &[f64],
346        x_min: f64,
347        x_max: f64,
348        y_min: f64,
349        y_max: f64,
350    ) -> String {
351        let (px, py, pw, ph) = self.plot_area();
352        let color = self.chart.style.primary;
353        let fill_color = Color::with_alpha(color.r, color.g, color.b, 100);
354
355        let mut path = String::new();
356
357        // Move to first point
358        if let (Some(&first_x), Some(&first_y)) = (x_data.first(), y_data.first()) {
359            let sx = px + (first_x - x_min) / (x_max - x_min) * pw;
360            let sy = py + ph - (first_y - y_min) / (y_max - y_min) * ph;
361            path.push_str(&format!("M{sx},{} L{sx},{sy}", py + ph));
362        }
363
364        // Line through all points
365        for (&x, &y) in x_data.iter().zip(y_data.iter()) {
366            let sx = px + (x - x_min) / (x_max - x_min) * pw;
367            let sy = py + ph - (y - y_min) / (y_max - y_min) * ph;
368            path.push_str(&format!(" L{sx},{sy}"));
369        }
370
371        // Close path
372        if let Some(&last_x) = x_data.last() {
373            let sx = px + (last_x - x_min) / (x_max - x_min) * pw;
374            path.push_str(&format!(" L{sx},{} Z", py + ph));
375        }
376
377        format!(
378            r#"<path d="{}" fill="{}" stroke="{}" stroke-width="{}"/>"#,
379            path,
380            fill_color.to_rgba(),
381            color.to_hex(),
382            self.chart.style.line_width
383        )
384    }
385
386    fn render_points(
387        &self,
388        x_data: &[f64],
389        y_data: &[f64],
390        x_min: f64,
391        x_max: f64,
392        y_min: f64,
393        y_max: f64,
394    ) -> String {
395        let (px, py, pw, ph) = self.plot_area();
396        let color = self.chart.style.primary.to_hex();
397
398        x_data
399            .iter()
400            .zip(y_data.iter())
401            .map(|(&x, &y)| {
402                let sx = px + (x - x_min) / (x_max - x_min) * pw;
403                let sy = py + ph - (y - y_min) / (y_max - y_min) * ph;
404                format!(
405                    r#"<circle cx="{}" cy="{}" r="{}" fill="{}"/>"#,
406                    sx, sy, self.chart.style.point_radius, color
407                )
408            })
409            .collect()
410    }
411
412    fn render_bars(&self, labels: &[String], values: &[f64], y_max: f64) -> String {
413        let (px, py, pw, ph) = self.plot_area();
414        let bar_count = values.len();
415        let bar_width = pw / bar_count as f64 * 0.8;
416        let gap = pw / bar_count as f64 * 0.2;
417
418        let mut svg = String::new();
419
420        for (i, (label, &value)) in labels.iter().zip(values.iter()).enumerate() {
421            let x = px + (i as f64 * (bar_width + gap)) + gap / 2.0;
422            let height = (value / y_max) * ph;
423            let y = py + ph - height;
424            let color = self.chart.style.series_color(i);
425
426            svg.push_str(&format!(
427                r#"<rect x="{}" y="{}" width="{}" height="{}" fill="{}"/>"#,
428                x,
429                y,
430                bar_width,
431                height,
432                color.to_hex()
433            ));
434
435            // Label
436            svg.push_str(&format!(
437                r#"<text x="{}" y="{}" text-anchor="middle" font-family="{}" font-size="{}" fill="{}">{}</text>"#,
438                x + bar_width / 2.0,
439                py + ph + 20.0,
440                self.chart.style.font_family,
441                self.chart.style.axis_font_size,
442                self.chart.style.text_color.to_hex(),
443                html_escape(label)
444            ));
445        }
446
447        svg
448    }
449
450    fn render_pie_slices(&self, labels: &[String], values: &[f64]) -> String {
451        let cx = self.chart.width / 2.0;
452        let cy = self.chart.height / 2.0;
453        let radius = f64::min(cx, cy) - self.chart.style.padding.top;
454        let total: f64 = values.iter().sum();
455
456        let mut svg = String::new();
457        let mut start_angle = -std::f64::consts::FRAC_PI_2; // Start at top
458
459        for (i, (&value, label)) in values.iter().zip(labels.iter()).enumerate() {
460            let angle = (value / total) * 2.0 * std::f64::consts::PI;
461            let end_angle = start_angle + angle;
462            let color = self.chart.style.series_color(i);
463
464            let x1 = cx + radius * start_angle.cos();
465            let y1 = cy + radius * start_angle.sin();
466            let x2 = cx + radius * end_angle.cos();
467            let y2 = cy + radius * end_angle.sin();
468
469            let large_arc = if angle > std::f64::consts::PI { 1 } else { 0 };
470
471            svg.push_str(&format!(
472                r#"<path d="M{cx},{cy} L{x1},{y1} A{radius},{radius} 0 {large_arc},1 {x2},{y2} Z" fill="{}"/>"#,
473                color.to_hex()
474            ));
475
476            // Label
477            let label_angle = start_angle + angle / 2.0;
478            let label_radius = radius * 0.7;
479            let lx = cx + label_radius * label_angle.cos();
480            let ly = cy + label_radius * label_angle.sin();
481
482            svg.push_str(&format!(
483                r#"<text x="{}" y="{}" text-anchor="middle" font-family="{}" font-size="{}" fill="white">{}</text>"#,
484                lx, ly,
485                self.chart.style.font_family,
486                self.chart.style.axis_font_size,
487                html_escape(label)
488            ));
489
490            start_angle = end_angle;
491        }
492
493        svg
494    }
495
496    fn extract_xy(&self, table: &Table) -> Result<(Vec<f64>, Vec<f64>)> {
497        let x_col = self.chart.x_column.as_deref().unwrap_or("x");
498        let y_col = self.chart.y_column.as_deref().unwrap_or("y");
499
500        // Try to get columns by name, or use first two columns
501        let x_data = table
502            .column_as_f64(x_col)
503            .or_else(|| table.headers.first().and_then(|h| table.column_as_f64(h)))
504            .ok_or_else(|| Error::MissingColumn {
505                column: x_col.to_string(),
506            })?;
507
508        let y_data = table
509            .column_as_f64(y_col)
510            .or_else(|| table.headers.get(1).and_then(|h| table.column_as_f64(h)))
511            .ok_or_else(|| Error::MissingColumn {
512                column: y_col.to_string(),
513            })?;
514
515        Ok((x_data, y_data))
516    }
517
518    fn extract_labels_values(&self, table: &Table) -> Result<(Vec<String>, Vec<f64>)> {
519        if table.headers.len() < 2 {
520            return Err(Error::InvalidData {
521                message: "need at least 2 columns".to_string(),
522            });
523        }
524
525        let labels = table.column_as_str(&table.headers[0]).unwrap_or_default();
526        let values =
527            table
528                .column_as_f64(&table.headers[1])
529                .ok_or_else(|| Error::MissingColumn {
530                    column: table.headers[1].clone(),
531                })?;
532
533        Ok((labels, values))
534    }
535
536    fn calculate_bounds(&self, x_data: &[f64], y_data: &[f64]) -> (f64, f64, f64, f64) {
537        let x_min = x_data.iter().cloned().fold(f64::INFINITY, f64::min);
538        let x_max = x_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
539        let y_min = y_data.iter().cloned().fold(f64::INFINITY, f64::min);
540        let y_max = y_data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
541
542        // Add some padding to bounds
543        let x_padding = (x_max - x_min) * 0.05;
544        let y_padding = (y_max - y_min) * 0.05;
545
546        (
547            x_min - x_padding,
548            x_max + x_padding,
549            (y_min - y_padding).min(0.0),
550            y_max + y_padding,
551        )
552    }
553}
554
/// Replace the five markup-significant characters (`&`, `<`, `>`, `"`, `'`)
/// with their entity forms; every other character is copied unchanged.
fn html_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());

    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&#39;"),
            other => escaped.push(other),
        }
    }

    escaped
}
563
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chart::Chart;

    /// A line chart yields an SVG root, a polyline and its title text.
    #[test]
    fn render_line_chart() {
        let points = vec![(1.0, 10.0), (2.0, 20.0), (3.0, 15.0), (4.0, 25.0)];
        let source = DataSource::from_points(points);

        let line_chart = Chart::new(ChartType::Line)
            .with_title("Test Line Chart")
            .with_data(source);

        let rendered = line_chart.render().unwrap();
        for needle in &["<svg", "polyline", "Test Line Chart"] {
            assert!(rendered.contains(*needle));
        }
    }

    /// A bar chart built from CSV yields an SVG root and rectangles.
    #[test]
    fn render_bar_chart() {
        let csv = "label,value\nA,10\nB,20\nC,15";
        let source = DataSource::from_csv_string(csv).unwrap();

        let bar_chart = Chart::new(ChartType::Bar)
            .with_title("Test Bar Chart")
            .with_data(source);

        let rendered = bar_chart.render().unwrap();
        for needle in &["<svg", "rect"] {
            assert!(rendered.contains(*needle));
        }
    }

    /// A scatter plot yields circle elements.
    #[test]
    fn render_scatter_plot() {
        let points = vec![(1.0, 1.0), (2.0, 4.0), (3.0, 9.0)];
        let scatter = Chart::new(ChartType::Scatter).with_data(DataSource::from_points(points));

        let rendered = scatter.render().unwrap();
        assert!(rendered.contains("circle"));
    }

    /// Angle brackets and ampersands are replaced by entities.
    #[test]
    fn html_escape_works() {
        let cases = [("<script>", "&lt;script&gt;"), ("a & b", "a &amp; b")];
        for (input, expected) in cases.iter() {
            assert_eq!(html_escape(input), *expected);
        }
    }
}