1use plotters::backend::SVGBackend;
2use plotters::drawing::IntoDrawingArea;
3use plotters::element::DashedPathElement;
4use plotters::prelude::*;
5
6use crate::abnormal::{AbnormalSample, abnormal_smaples_series};
7use crate::chart_data::ChannelChartData;
8use crate::config::SnapshotConfig;
9use crate::util::{
10 INPUT_CHANNEL_COLORS, OUTPUT_CHANNEL_COLORS, get_contrasting_color, num_x_labels,
11 parse_hex_color, time_formatter,
12};
13
/// How the channels of a snapshot are arranged in the rendered SVG.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum Layout {
    /// Every channel (inputs first, then outputs) gets its own chart row.
    #[default]
    SeparateChannels,
    /// One chart for all input channels and one chart for all output channels.
    CombinedPerChannelType,
    /// All channels share a single chart; input channels are drawn dashed.
    Combined,
}
29
30pub(crate) fn generate_svg(
31 input_data: &[Vec<f32>],
32 output_data: &[Vec<f32>],
33 abnormalities: &[Vec<(usize, AbnormalSample)>],
34 config: &SnapshotConfig,
35) -> String {
36 let height_per_channel = config.svg_height_per_channel;
37 let num_channels = output_data.len()
38 + if config.with_inputs {
39 input_data.len()
40 } else {
41 0
42 };
43 let num_samples = output_data.first().map(|c| c.len()).unwrap_or(0);
44
45 if num_samples == 0 || num_channels == 0 {
46 return "<svg xmlns=\"http://www.w3.org/2000/svg\" viewBox=\"0 0 100 100\"><text>Empty</text></svg>".to_string();
47 }
48
49 let svg_width = config.svg_width.unwrap_or(config.num_samples * 2) as u32;
50 let total_height = (height_per_channel * num_channels) as u32;
51
52 let mut svg_buffer = String::new();
54 {
55 let root =
56 SVGBackend::with_string(&mut svg_buffer, (svg_width, total_height)).into_drawing_area();
57
58 let bg_color = parse_hex_color(&config.background_color);
60 root.fill(&bg_color).unwrap();
61
62 let current_area = if let Some(ref title) = config.chart_title {
64 let title_color = get_contrasting_color(&bg_color);
65 let text_style = TextStyle::from(("sans-serif", 20)).color(&title_color);
66 root.titled(title, text_style).unwrap()
67 } else {
68 root
69 };
70
71 let input_charts: Vec<ChannelChartData> = if config.with_inputs {
72 input_data
73 .iter()
74 .enumerate()
75 .map(|(i, data)| ChannelChartData::from_input_data(data, i, config))
76 .collect()
77 } else {
78 vec![]
79 };
80
81 let output_charts: Vec<ChannelChartData> = output_data
82 .iter()
83 .zip(abnormalities)
84 .enumerate()
85 .map(|(i, (data, abnormalities))| {
86 ChannelChartData::from_output_data(data, abnormalities, i, config)
87 })
88 .collect();
89
90 let start_sample = config.warm_up.num_samples(config.sample_rate);
91
92 let output_axis_color = parse_hex_color(OUTPUT_CHANNEL_COLORS[0]);
93 let input_axis_color = parse_hex_color(INPUT_CHANNEL_COLORS[0]);
94
95 match config.chart_layout {
96 Layout::SeparateChannels => {
97 let areas = current_area.split_evenly((num_channels, 1));
99 for (chart, area) in input_charts
100 .into_iter()
101 .chain(output_charts.into_iter())
102 .zip(areas)
103 {
104 one_channel_chart(chart, config, start_sample, &area);
105 }
106 }
107 Layout::CombinedPerChannelType => {
108 if config.with_inputs {
109 let areas = current_area.split_evenly((2, 1));
110
111 multi_channel_chart(
112 input_charts,
113 config,
114 true,
115 start_sample,
116 input_axis_color,
117 &areas[0],
118 );
119 multi_channel_chart(
120 output_charts,
121 config,
122 true,
123 start_sample,
124 output_axis_color,
125 &areas[1],
126 );
127 } else {
128 multi_channel_chart(
129 output_charts,
130 config,
131 true,
132 start_sample,
133 output_axis_color,
134 ¤t_area,
135 );
136 }
137 }
138 Layout::Combined => {
139 let charts = output_charts.into_iter().chain(input_charts).collect();
140 multi_channel_chart(
141 charts,
142 config,
143 false,
144 start_sample,
145 output_axis_color,
146 ¤t_area,
147 );
148 }
149 }
150
151 current_area.present().unwrap();
152 }
153
154 svg_buffer
155}
156
157fn multi_channel_chart(
158 charts_data: Vec<ChannelChartData>,
159 config: &SnapshotConfig,
160 solid_input: bool,
161 start_from: usize,
162 axis_color: RGBColor,
163 area: &DrawingArea<SVGBackend<'_>, plotters::coord::Shift>,
164) {
165 let num_samples = charts_data
166 .iter()
167 .map(|chart| chart.data.len())
168 .max()
169 .unwrap_or_default();
170 let min_val = charts_data
171 .iter()
172 .flat_map(|c| c.data.iter())
173 .cloned()
174 .fold(f32::INFINITY, f32::min);
175 let max_val = charts_data
176 .iter()
177 .flat_map(|c| c.data.iter())
178 .cloned()
179 .fold(f32::NEG_INFINITY, f32::max);
180
181 let range = (max_val - min_val).max(f32::EPSILON);
182 let y_min = (min_val - range * 0.1) as f64;
183 let y_max = (max_val + range * 0.1) as f64;
184
185 let mut chart = ChartBuilder::on(area)
187 .margin(5)
188 .x_label_area_size(35)
189 .y_label_area_size(50)
190 .build_cartesian_2d(
191 start_from as f64..(num_samples + start_from) as f64,
192 y_min..y_max,
193 )
194 .unwrap();
195
196 let mut mesh = chart.configure_mesh();
197
198 mesh.axis_style(axis_color.mix(0.3));
199
200 if !config.show_grid {
201 mesh.disable_mesh();
202 } else {
203 mesh.light_line_style(axis_color.mix(0.1))
204 .bold_line_style(axis_color.mix(0.2));
205 }
206
207 if config.show_labels {
208 let x_labels = num_x_labels(num_samples, config.sample_rate);
209 mesh.x_labels(
210 config
211 .max_labels_x_axis
212 .map(|mx| x_labels.min(mx))
213 .unwrap_or(x_labels),
214 )
215 .y_labels(3)
216 .label_style(("sans-serif", 10, &axis_color));
217 }
218
219 let formatter = |v: &f64| time_formatter(*v as usize, config.sample_rate);
220 if config.format_x_axis_labels_as_time {
221 mesh.x_label_formatter(&formatter);
222 }
223
224 mesh.draw().unwrap();
225
226 let mut has_legend = false;
227
228 for entry in charts_data.iter().filter(|d| !d.is_input || solid_input) {
231 let ChannelChartData {
232 data: channel_data,
233 color,
234 label,
235 ..
236 } = entry;
237
238 let line_style = ShapeStyle {
239 color: color.to_rgba(),
240 filled: false,
241 stroke_width: config.line_width as u32,
242 };
243
244 let series = chart
245 .draw_series(std::iter::once(PathElement::new(
246 channel_data
247 .iter()
248 .enumerate()
249 .map(|(i, &sample)| ((i + start_from) as f64, sample as f64))
250 .collect::<Vec<(f64, f64)>>(),
251 line_style,
252 )))
253 .unwrap();
254
255 if let Some(label) = label {
256 series
257 .label(label)
258 .legend(|(x, y)| PathElement::new(vec![(x, y), (x + 20, y)], entry.color));
259 has_legend = true;
260 }
261 }
262
263 if !solid_input && charts_data.iter().any(|d| d.is_input) {
265 for entry in charts_data.iter().filter(|d| d.is_input) {
266 let ChannelChartData {
267 data: channel_data,
268 color,
269 label,
270 ..
271 } = entry;
272
273 let line_style = ShapeStyle {
274 color: color.to_rgba(),
275 filled: false,
276 stroke_width: config.line_width as u32,
277 };
278
279 let dashed = DashedPathElement::new(
280 channel_data
281 .iter()
282 .enumerate()
283 .map(|(i, &sample)| ((i + start_from) as f64, sample as f64))
284 .collect::<Vec<(f64, f64)>>(),
285 2,
286 3,
287 line_style,
288 );
289
290 let series = chart.draw_series(std::iter::once(dashed)).unwrap();
291
292 if let Some(label) = label {
293 series.label(label).legend(|(x, y)| {
294 DashedPathElement::new(vec![(x, y), (x + 20, y)], 2, 3, entry.color)
295 });
296 has_legend = true;
297 }
298 }
299 }
300
301 abnormal_smaples_series(&charts_data, &mut chart, y_min, y_max);
302
303 if has_legend {
304 let background = parse_hex_color(&config.background_color);
305 let contrasting = get_contrasting_color(&background);
306
307 chart
308 .configure_series_labels()
309 .border_style(contrasting)
310 .background_style(background)
311 .label_font(TextStyle::from(("sans-serif", 10)).color(&contrasting))
312 .draw()
313 .unwrap();
314 }
315}
316
317fn one_channel_chart(
318 chart_data: ChannelChartData,
319 config: &SnapshotConfig,
320 start_from: usize,
321 area: &DrawingArea<SVGBackend<'_>, plotters::coord::Shift>,
322) {
323 let ChannelChartData {
324 data: channel_data,
325 color,
326 label,
327 ..
328 } = &chart_data;
329
330 let num_samples = channel_data.len();
331
332 let min_val = channel_data.iter().cloned().fold(f32::INFINITY, f32::min);
334 let max_val = channel_data
335 .iter()
336 .cloned()
337 .fold(f32::NEG_INFINITY, f32::max);
338 let range = (max_val - min_val).max(f32::EPSILON);
339 let y_min = (min_val - range * 0.1) as f64;
340 let y_max = (max_val + range * 0.1) as f64;
341
342 let mut chart = ChartBuilder::on(area)
344 .margin(5)
345 .x_label_area_size(if label.is_some() { 35 } else { 0 })
346 .y_label_area_size(if label.is_some() { 50 } else { 0 })
347 .build_cartesian_2d(
348 start_from as f64..(num_samples + start_from) as f64,
349 y_min..y_max,
350 )
351 .unwrap();
352
353 let mut mesh = chart.configure_mesh();
354
355 mesh.axis_style(color.mix(0.3));
356
357 if !config.show_grid {
358 mesh.disable_mesh();
359 } else {
360 mesh.light_line_style(color.mix(0.1))
361 .bold_line_style(color.mix(0.2));
362 }
363
364 if let Some(label) = label {
365 let x_labels = num_x_labels(num_samples, config.sample_rate);
366 mesh.x_labels(
367 config
368 .max_labels_x_axis
369 .map(|mx| x_labels.min(mx))
370 .unwrap_or(x_labels),
371 )
372 .y_labels(3)
373 .x_desc(label)
374 .label_style(("sans-serif", 10, &color));
375 }
376
377 let formatter = |v: &f64| time_formatter(*v as usize, config.sample_rate);
378 if config.format_x_axis_labels_as_time {
379 mesh.x_label_formatter(&formatter);
380 }
381
382 mesh.draw().unwrap();
383
384 let line_style = ShapeStyle {
386 color: color.to_rgba(),
387 filled: false,
388 stroke_width: config.line_width as u32,
389 };
390
391 chart
392 .draw_series(std::iter::once(PathElement::new(
393 channel_data
394 .iter()
395 .enumerate()
396 .map(|(i, &sample)| ((i + start_from) as f64, sample as f64))
397 .collect::<Vec<(f64, f64)>>(),
398 line_style,
399 )))
400 .unwrap();
401
402 abnormal_smaples_series(&[chart_data], &mut chart, y_min, y_max);
403}