insta_fun/
chart.rs

1use plotters::backend::SVGBackend;
2use plotters::drawing::IntoDrawingArea;
3use plotters::element::DashedPathElement;
4use plotters::prelude::*;
5
6use crate::abnormal::{AbnormalSample, abnormal_smaples_series};
7use crate::chart_data::ChannelChartData;
8use crate::config::SvgChartConfig;
9use crate::util::{
10    INPUT_CHANNEL_COLORS, OUTPUT_CHANNEL_COLORS, get_contrasting_color, num_x_labels,
11    parse_hex_color, time_formatter,
12};
13
/// How channels are distributed across charts in the generated SVG.
///
/// Selects between one chart per channel, one chart per channel group
/// (inputs / outputs), or a single chart holding every channel.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum Layout {
    /// One chart per channel (the default).
    #[default]
    SeparateChannels,
    /// One chart for all input channels and a second chart for all output channels.
    ///
    /// Same as `Combined` when `config.with_inputs` is `false`.
    CombinedPerChannelType,
    /// A single chart containing every channel.
    Combined,
}
29
30pub(crate) fn generate_svg(
31    input_data: &[Vec<f32>],
32    output_data: &[Vec<f32>],
33    abnormalities: &[Vec<(usize, AbnormalSample)>],
34    config: &SvgChartConfig,
35    sample_rate: f64,
36    num_samples: usize,
37    start_sample: usize,
38) -> String {
39    let height_per_channel = config.svg_height_per_channel;
40    let num_channels = output_data.len()
41        + if config.with_inputs {
42            input_data.len()
43        } else {
44            0
45        };
46
47    if num_samples == 0 || num_channels == 0 {
48        return "<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 100 100\"><text>Empty</text></svg>".to_string();
49    }
50
51    let svg_width = config.svg_width.unwrap_or(num_samples * 2) as u32;
52    let total_height = (height_per_channel * num_channels) as u32;
53
54    // Create SVG backend with buffer
55    let mut svg_buffer = String::new();
56    {
57        let root =
58            SVGBackend::with_string(&mut svg_buffer, (svg_width, total_height)).into_drawing_area();
59
60        // Fill background
61        let bg_color = parse_hex_color(&config.background_color);
62        root.fill(&bg_color).unwrap();
63
64        // Add optional title with contrasting color
65        let current_area = if let Some(ref title) = config.chart_title {
66            let title_color = get_contrasting_color(&bg_color);
67            let text_style = TextStyle::from(("sans-serif", 20)).color(&title_color);
68            root.titled(title, text_style).unwrap()
69        } else {
70            root
71        };
72
73        let input_charts: Vec<ChannelChartData> = if config.with_inputs {
74            input_data
75                .iter()
76                .enumerate()
77                .map(|(i, data)| ChannelChartData::from_input_data(data, i, config))
78                .collect()
79        } else {
80            vec![]
81        };
82
83        let output_charts: Vec<ChannelChartData> = output_data
84            .iter()
85            .zip(abnormalities)
86            .enumerate()
87            .map(|(i, (data, abnormalities))| {
88                ChannelChartData::from_output_data(data, abnormalities, i, config)
89            })
90            .collect();
91
92        let output_axis_color = parse_hex_color(OUTPUT_CHANNEL_COLORS[0]);
93        let input_axis_color = parse_hex_color(INPUT_CHANNEL_COLORS[0]);
94
95        match config.chart_layout {
96            Layout::SeparateChannels => {
97                // Split area for each channel
98                let areas = current_area.split_evenly((num_channels, 1));
99                for (chart, area) in input_charts
100                    .into_iter()
101                    .chain(output_charts.into_iter())
102                    .zip(areas)
103                {
104                    one_channel_chart(chart, config, start_sample, &area, sample_rate);
105                }
106            }
107            Layout::CombinedPerChannelType => {
108                if config.with_inputs {
109                    let areas = current_area.split_evenly((2, 1));
110
111                    multi_channel_chart(
112                        input_charts,
113                        config,
114                        true,
115                        start_sample,
116                        input_axis_color,
117                        &areas[0],
118                        sample_rate,
119                    );
120                    multi_channel_chart(
121                        output_charts,
122                        config,
123                        true,
124                        start_sample,
125                        output_axis_color,
126                        &areas[1],
127                        sample_rate,
128                    );
129                } else {
130                    multi_channel_chart(
131                        output_charts,
132                        config,
133                        true,
134                        start_sample,
135                        output_axis_color,
136                        &current_area,
137                        sample_rate,
138                    );
139                }
140            }
141            Layout::Combined => {
142                let charts = output_charts.into_iter().chain(input_charts).collect();
143                multi_channel_chart(
144                    charts,
145                    config,
146                    false,
147                    start_sample,
148                    output_axis_color,
149                    &current_area,
150                    sample_rate,
151                );
152            }
153        }
154
155        current_area.present().unwrap();
156    }
157
158    svg_buffer
159}
160
161fn multi_channel_chart(
162    charts_data: Vec<ChannelChartData>,
163    config: &SvgChartConfig,
164    solid_input: bool,
165    start_from: usize,
166    axis_color: RGBColor,
167    area: &DrawingArea<SVGBackend<'_>, plotters::coord::Shift>,
168    sample_rate: f64,
169) {
170    let num_samples = charts_data
171        .iter()
172        .map(|chart| chart.data.len())
173        .max()
174        .unwrap_or_default();
175    let min_val = charts_data
176        .iter()
177        .flat_map(|c| c.data.iter())
178        .cloned()
179        .fold(f32::INFINITY, f32::min);
180    let max_val = charts_data
181        .iter()
182        .flat_map(|c| c.data.iter())
183        .cloned()
184        .fold(f32::NEG_INFINITY, f32::max);
185
186    let range = (max_val - min_val).max(f32::EPSILON);
187    let y_min = (min_val - range * 0.1) as f64;
188    let y_max = (max_val + range * 0.1) as f64;
189
190    // Build chart
191    let mut chart = ChartBuilder::on(area)
192        .margin(5)
193        .x_label_area_size(35)
194        .y_label_area_size(50)
195        .build_cartesian_2d(
196            start_from as f64..(num_samples + start_from) as f64,
197            y_min..y_max,
198        )
199        .unwrap();
200
201    let mut mesh = chart.configure_mesh();
202
203    mesh.axis_style(axis_color.mix(0.3));
204
205    if !config.show_grid {
206        mesh.disable_mesh();
207    } else {
208        mesh.light_line_style(axis_color.mix(0.1))
209            .bold_line_style(axis_color.mix(0.2));
210    }
211
212    if config.show_labels {
213        let x_labels = num_x_labels(num_samples, sample_rate);
214        mesh.x_labels(
215            config
216                .max_labels_x_axis
217                .map(|mx| x_labels.min(mx))
218                .unwrap_or(x_labels),
219        )
220        .y_labels(3)
221        .label_style(("sans-serif", 10, &axis_color));
222    }
223
224    let formatter = |v: &f64| time_formatter(*v as usize, sample_rate);
225    if config.format_x_axis_labels_as_time {
226        mesh.x_label_formatter(&formatter);
227    }
228
229    mesh.draw().unwrap();
230
231    let mut has_legend = false;
232
233    // Draw outputs (or inputs as solid when `solid_input` is true) one by one,
234    // registering a legend entry per series.
235    for entry in charts_data.iter().filter(|d| !d.is_input || solid_input) {
236        let ChannelChartData {
237            data: channel_data,
238            color,
239            label,
240            ..
241        } = entry;
242
243        let line_style = ShapeStyle {
244            color: color.to_rgba(),
245            filled: false,
246            stroke_width: config.line_width as u32,
247        };
248
249        let series = chart
250            .draw_series(std::iter::once(PathElement::new(
251                channel_data
252                    .iter()
253                    .enumerate()
254                    .map(|(i, &sample)| ((i + start_from) as f64, sample as f64))
255                    .collect::<Vec<(f64, f64)>>(),
256                line_style,
257            )))
258            .unwrap();
259
260        if let Some(label) = label {
261            series
262                .label(label)
263                .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], entry.color));
264            has_legend = true;
265        }
266    }
267
268    // Dashed inputs when not solid
269    if !solid_input && charts_data.iter().any(|d| d.is_input) {
270        for entry in charts_data.iter().filter(|d| d.is_input) {
271            let ChannelChartData {
272                data: channel_data,
273                color,
274                label,
275                ..
276            } = entry;
277
278            let line_style = ShapeStyle {
279                color: color.to_rgba(),
280                filled: false,
281                stroke_width: config.line_width as u32,
282            };
283
284            let dashed = DashedPathElement::new(
285                channel_data
286                    .iter()
287                    .enumerate()
288                    .map(|(i, &sample)| ((i + start_from) as f64, sample as f64))
289                    .collect::<Vec<(f64, f64)>>(),
290                2,
291                3,
292                line_style,
293            );
294
295            let series = chart.draw_series(std::iter::once(dashed)).unwrap();
296
297            if let Some(label) = label {
298                series.label(label).legend(|(x, y)| {
299                    DashedPathElement::new(vec![(x, y), (x + 20, y)], 2, 3, entry.color)
300                });
301                has_legend = true;
302            }
303        }
304    }
305
306    abnormal_smaples_series(&charts_data, &mut chart, y_min, y_max);
307
308    if has_legend {
309        let background = parse_hex_color(&config.background_color);
310        let contrasting = get_contrasting_color(&background);
311
312        chart
313            .configure_series_labels()
314            .border_style(contrasting)
315            .background_style(background)
316            .label_font(TextStyle::from(("sans-serif", 10)).color(&contrasting))
317            .draw()
318            .unwrap();
319    }
320}
321
322fn one_channel_chart(
323    chart_data: ChannelChartData,
324    config: &SvgChartConfig,
325    start_from: usize,
326    area: &DrawingArea<SVGBackend<'_>, plotters::coord::Shift>,
327    sample_rate: f64,
328) {
329    let ChannelChartData {
330        data: channel_data,
331        color,
332        label,
333        ..
334    } = &chart_data;
335
336    let num_samples = channel_data.len();
337
338    // Calculate data range
339    let min_val = channel_data.iter().cloned().fold(f32::INFINITY, f32::min);
340    let max_val = channel_data
341        .iter()
342        .cloned()
343        .fold(f32::NEG_INFINITY, f32::max);
344    let range = (max_val - min_val).max(f32::EPSILON);
345    let y_min = (min_val - range * 0.1) as f64;
346    let y_max = (max_val + range * 0.1) as f64;
347
348    // Build chart
349    let mut chart = ChartBuilder::on(area)
350        .margin(5)
351        .x_label_area_size(if label.is_some() { 35 } else { 0 })
352        .y_label_area_size(if label.is_some() { 50 } else { 0 })
353        .build_cartesian_2d(
354            start_from as f64..(num_samples + start_from) as f64,
355            y_min..y_max,
356        )
357        .unwrap();
358
359    let mut mesh = chart.configure_mesh();
360
361    mesh.axis_style(color.mix(0.3));
362
363    if !config.show_grid {
364        mesh.disable_mesh();
365    } else {
366        mesh.light_line_style(color.mix(0.1))
367            .bold_line_style(color.mix(0.2));
368    }
369
370    if let Some(label) = label {
371        let x_labels = num_x_labels(num_samples, sample_rate);
372        mesh.x_labels(
373            config
374                .max_labels_x_axis
375                .map(|mx| x_labels.min(mx))
376                .unwrap_or(x_labels),
377        )
378        .y_labels(3)
379        .x_desc(label)
380        .label_style(("sans-serif", 10, &color));
381    }
382
383    let formatter = |v: &f64| time_formatter(*v as usize, sample_rate);
384    if config.format_x_axis_labels_as_time {
385        mesh.x_label_formatter(&formatter);
386    }
387
388    mesh.draw().unwrap();
389
390    // Draw waveform
391    let line_style = ShapeStyle {
392        color: color.to_rgba(),
393        filled: false,
394        stroke_width: config.line_width as u32,
395    };
396
397    chart
398        .draw_series(std::iter::once(PathElement::new(
399            channel_data
400                .iter()
401                .enumerate()
402                .map(|(i, &sample)| ((i + start_from) as f64, sample as f64))
403                .collect::<Vec<(f64, f64)>>(),
404            line_style,
405        )))
406        .unwrap();
407
408    abnormal_smaples_series(&[chart_data], &mut chart, y_min, y_max);
409}