1use crate::{UtilsError, UtilsResult};
7use scirs2_core::ndarray::{Array1, Array2, Axis};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::fmt;
11
12pub struct ChartData;
14
15impl ChartData {
16 pub fn prepare_scatter_plot(
18 x: &Array1<f64>,
19 y: &Array1<f64>,
20 labels: Option<&Array1<String>>,
21 ) -> UtilsResult<ScatterPlotData> {
22 if x.len() != y.len() {
23 return Err(UtilsError::ShapeMismatch {
24 expected: vec![x.len()],
25 actual: vec![y.len()],
26 });
27 }
28
29 let points: Vec<Point2D> = x
30 .iter()
31 .zip(y.iter())
32 .map(|(&x_val, &y_val)| Point2D { x: x_val, y: y_val })
33 .collect();
34
35 let labels = labels
36 .map(|l| l.to_vec())
37 .unwrap_or_else(|| (0..x.len()).map(|i| format!("Point {i}")).collect());
38
39 Ok(ScatterPlotData { points, labels })
40 }
41
42 pub fn prepare_line_plot(
44 x: &Array1<f64>,
45 y: &Array1<f64>,
46 line_name: Option<String>,
47 ) -> UtilsResult<LinePlotData> {
48 if x.len() != y.len() {
49 return Err(UtilsError::ShapeMismatch {
50 expected: vec![x.len()],
51 actual: vec![y.len()],
52 });
53 }
54
55 let points: Vec<Point2D> = x
56 .iter()
57 .zip(y.iter())
58 .map(|(&x_val, &y_val)| Point2D { x: x_val, y: y_val })
59 .collect();
60
61 Ok(LinePlotData {
62 points,
63 name: line_name.unwrap_or_else(|| "Line".to_string()),
64 })
65 }
66
67 pub fn prepare_histogram(
69 data: &Array1<f64>,
70 bins: Option<usize>,
71 ) -> UtilsResult<HistogramData> {
72 if data.is_empty() {
73 return Err(UtilsError::EmptyInput);
74 }
75
76 let bins = bins.unwrap_or(10);
77 let min_val = data.iter().cloned().fold(f64::INFINITY, f64::min);
78 let max_val = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
79
80 if min_val == max_val {
81 return Err(UtilsError::InvalidParameter(
82 "All values are the same, cannot create histogram".to_string(),
83 ));
84 }
85
86 let bin_width = (max_val - min_val) / bins as f64;
87 let mut bin_counts = vec![0; bins];
88 let mut bin_edges = Vec::with_capacity(bins + 1);
89
90 for i in 0..=bins {
92 bin_edges.push(min_val + i as f64 * bin_width);
93 }
94
95 for &value in data.iter() {
97 let bin_index = ((value - min_val) / bin_width).floor() as usize;
98 let bin_index = bin_index.min(bins - 1); bin_counts[bin_index] += 1;
100 }
101
102 Ok(HistogramData {
103 counts: bin_counts,
104 bin_edges,
105 total_count: data.len(),
106 })
107 }
108
109 pub fn prepare_heatmap(
111 data: &Array2<f64>,
112 row_labels: Option<&[String]>,
113 col_labels: Option<&[String]>,
114 ) -> UtilsResult<HeatmapData> {
115 let (rows, cols) = data.dim();
116
117 if rows == 0 || cols == 0 {
118 return Err(UtilsError::EmptyInput);
119 }
120
121 let values: Vec<Vec<f64>> = data.axis_iter(Axis(0)).map(|row| row.to_vec()).collect();
122
123 let row_labels = row_labels
124 .map(|labels| labels.to_vec())
125 .unwrap_or_else(|| (0..rows).map(|i| format!("Row {i}")).collect());
126
127 let col_labels = col_labels
128 .map(|labels| labels.to_vec())
129 .unwrap_or_else(|| (0..cols).map(|i| format!("Col {i}")).collect());
130
131 let min_val = data.iter().cloned().fold(f64::INFINITY, f64::min);
133 let max_val = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
134
135 Ok(HeatmapData {
136 values,
137 row_labels,
138 col_labels,
139 min_value: min_val,
140 max_value: max_val,
141 })
142 }
143
144 pub fn prepare_box_plot(data: &Array1<f64>, label: Option<String>) -> UtilsResult<BoxPlotData> {
146 if data.is_empty() {
147 return Err(UtilsError::EmptyInput);
148 }
149
150 let mut sorted_data = data.to_vec();
151 sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
152
153 let len = sorted_data.len();
154 let q1 = Self::calculate_quantile(&sorted_data, 0.25);
155 let median = Self::calculate_quantile(&sorted_data, 0.5);
156 let q3 = Self::calculate_quantile(&sorted_data, 0.75);
157
158 let iqr = q3 - q1;
159 let lower_fence = q1 - 1.5 * iqr;
160 let upper_fence = q3 + 1.5 * iqr;
161
162 let outliers: Vec<f64> = sorted_data
163 .iter()
164 .copied()
165 .filter(|&x| x < lower_fence || x > upper_fence)
166 .collect();
167
168 let whisker_low = sorted_data
169 .iter()
170 .find(|&&x| x >= lower_fence)
171 .copied()
172 .unwrap_or(sorted_data[0]);
173
174 let whisker_high = sorted_data
175 .iter()
176 .rev()
177 .find(|&&x| x <= upper_fence)
178 .copied()
179 .unwrap_or(sorted_data[len - 1]);
180
181 Ok(BoxPlotData {
182 q1,
183 median,
184 q3,
185 whisker_low,
186 whisker_high,
187 outliers,
188 label: label.unwrap_or_else(|| "Data".to_string()),
189 })
190 }
191
192 fn calculate_quantile(sorted_data: &[f64], quantile: f64) -> f64 {
193 let index = quantile * (sorted_data.len() - 1) as f64;
194 let lower_index = index.floor() as usize;
195 let upper_index = index.ceil() as usize;
196
197 if lower_index == upper_index {
198 sorted_data[lower_index]
199 } else {
200 let weight = index - index.floor();
201 sorted_data[lower_index] * (1.0 - weight) + sorted_data[upper_index] * weight
202 }
203 }
204}
205
206pub struct PlotUtils;
208
209impl PlotUtils {
210 pub fn create_color_palette(num_colors: usize) -> Vec<Color> {
212 let base_colors = vec![
213 Color::rgb(31, 119, 180), Color::rgb(255, 127, 14), Color::rgb(44, 160, 44), Color::rgb(214, 39, 40), Color::rgb(148, 103, 189), Color::rgb(140, 86, 75), Color::rgb(227, 119, 194), Color::rgb(127, 127, 127), Color::rgb(188, 189, 34), Color::rgb(23, 190, 207), ];
224
225 if num_colors <= base_colors.len() {
226 base_colors.into_iter().take(num_colors).collect()
227 } else {
228 let base_len = base_colors.len();
230 let mut colors = base_colors;
231 for i in base_len..num_colors {
232 let hue = (i as f64 * 360.0 / num_colors as f64) % 360.0;
233 let color = Color::from_hsv(hue, 0.8, 0.8);
234 colors.push(color);
235 }
236 colors
237 }
238 }
239
240 pub fn create_axis_config(
242 label: &str,
243 min_val: Option<f64>,
244 max_val: Option<f64>,
245 tick_count: Option<usize>,
246 ) -> AxisConfig {
247 AxisConfig {
248 label: label.to_string(),
249 min_value: min_val,
250 max_value: max_val,
251 tick_count: tick_count.unwrap_or(10),
252 grid_lines: true,
253 log_scale: false,
254 }
255 }
256
257 pub fn to_json(plot_data: &PlotData) -> UtilsResult<String> {
259 serde_json::to_string_pretty(plot_data)
260 .map_err(|e| UtilsError::InvalidParameter(format!("JSON serialization error: {e}")))
261 }
262
263 pub fn to_csv(scatter_data: &ScatterPlotData) -> UtilsResult<String> {
265 let mut csv = String::new();
266 csv.push_str("x,y,label\n");
267
268 for (point, label) in scatter_data.points.iter().zip(&scatter_data.labels) {
269 csv.push_str(&format!("{},{},{}\n", point.x, point.y, label));
270 }
271
272 Ok(csv)
273 }
274
275 pub fn create_layout(
277 title: &str,
278 x_axis: AxisConfig,
279 y_axis: AxisConfig,
280 width: Option<u32>,
281 height: Option<u32>,
282 ) -> PlotLayout {
283 PlotLayout {
284 title: title.to_string(),
285 x_axis,
286 y_axis,
287 width: width.unwrap_or(800),
288 height: height.unwrap_or(600),
289 background_color: Color::rgb(255, 255, 255),
290 margin: PlotMargin {
291 top: 50,
292 right: 50,
293 bottom: 80,
294 left: 80,
295 },
296 }
297 }
298
299 pub fn generate_plot_summary(plot_data: &PlotData) -> PlotSummary {
301 match plot_data {
302 PlotData::Scatter(data) => PlotSummary {
303 plot_type: "scatter".to_string(),
304 data_points: data.points.len(),
305 summary_stats: Self::calculate_scatter_stats(&data.points),
306 },
307 PlotData::Line(data) => PlotSummary {
308 plot_type: "line".to_string(),
309 data_points: data.points.len(),
310 summary_stats: Self::calculate_scatter_stats(&data.points),
311 },
312 PlotData::Histogram(data) => PlotSummary {
313 plot_type: "histogram".to_string(),
314 data_points: data.total_count,
315 summary_stats: HashMap::from([
316 ("bins".to_string(), data.counts.len() as f64),
317 (
318 "max_count".to_string(),
319 *data.counts.iter().max().unwrap_or(&0) as f64,
320 ),
321 ]),
322 },
323 PlotData::Heatmap(data) => PlotSummary {
324 plot_type: "heatmap".to_string(),
325 data_points: data.values.len() * data.values.first().map_or(0, |row| row.len()),
326 summary_stats: HashMap::from([
327 ("rows".to_string(), data.values.len() as f64),
328 (
329 "cols".to_string(),
330 data.values.first().map_or(0.0, |row| row.len() as f64),
331 ),
332 ("min_value".to_string(), data.min_value),
333 ("max_value".to_string(), data.max_value),
334 ]),
335 },
336 PlotData::BoxPlot(data) => PlotSummary {
337 plot_type: "boxplot".to_string(),
338 data_points: 1, summary_stats: HashMap::from([
340 ("q1".to_string(), data.q1),
341 ("median".to_string(), data.median),
342 ("q3".to_string(), data.q3),
343 ("outliers".to_string(), data.outliers.len() as f64),
344 ]),
345 },
346 }
347 }
348
349 fn calculate_scatter_stats(points: &[Point2D]) -> HashMap<String, f64> {
350 if points.is_empty() {
351 return HashMap::new();
352 }
353
354 let x_values: Vec<f64> = points.iter().map(|p| p.x).collect();
355 let y_values: Vec<f64> = points.iter().map(|p| p.y).collect();
356
357 let x_min = x_values.iter().cloned().fold(f64::INFINITY, f64::min);
358 let x_max = x_values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
359 let y_min = y_values.iter().cloned().fold(f64::INFINITY, f64::min);
360 let y_max = y_values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
361
362 HashMap::from([
363 ("x_min".to_string(), x_min),
364 ("x_max".to_string(), x_max),
365 ("y_min".to_string(), y_min),
366 ("y_max".to_string(), y_max),
367 ("x_range".to_string(), x_max - x_min),
368 ("y_range".to_string(), y_max - y_min),
369 ])
370 }
371}
372
373#[derive(Debug, Clone, Serialize, Deserialize)]
376pub struct Point2D {
377 pub x: f64,
378 pub y: f64,
379}
380
381#[derive(Debug, Clone, Serialize, Deserialize)]
382pub struct ScatterPlotData {
383 pub points: Vec<Point2D>,
384 pub labels: Vec<String>,
385}
386
387#[derive(Debug, Clone, Serialize, Deserialize)]
388pub struct LinePlotData {
389 pub points: Vec<Point2D>,
390 pub name: String,
391}
392
393#[derive(Debug, Clone, Serialize, Deserialize)]
394pub struct HistogramData {
395 pub counts: Vec<usize>,
396 pub bin_edges: Vec<f64>,
397 pub total_count: usize,
398}
399
400#[derive(Debug, Clone, Serialize, Deserialize)]
401pub struct HeatmapData {
402 pub values: Vec<Vec<f64>>,
403 pub row_labels: Vec<String>,
404 pub col_labels: Vec<String>,
405 pub min_value: f64,
406 pub max_value: f64,
407}
408
409#[derive(Debug, Clone, Serialize, Deserialize)]
410pub struct BoxPlotData {
411 pub q1: f64,
412 pub median: f64,
413 pub q3: f64,
414 pub whisker_low: f64,
415 pub whisker_high: f64,
416 pub outliers: Vec<f64>,
417 pub label: String,
418}
419
420#[derive(Debug, Clone, Serialize, Deserialize)]
421pub enum PlotData {
422 Scatter(ScatterPlotData),
423 Line(LinePlotData),
424 Histogram(HistogramData),
425 Heatmap(HeatmapData),
426 BoxPlot(BoxPlotData),
427}
428
429#[derive(Debug, Clone, Serialize, Deserialize)]
430pub struct Color {
431 pub r: u8,
432 pub g: u8,
433 pub b: u8,
434 pub a: f64,
435}
436
437impl Color {
438 pub fn rgb(r: u8, g: u8, b: u8) -> Self {
439 Self { r, g, b, a: 1.0 }
440 }
441
442 pub fn rgba(r: u8, g: u8, b: u8, a: f64) -> Self {
443 Self { r, g, b, a }
444 }
445
446 pub fn from_hsv(h: f64, s: f64, v: f64) -> Self {
447 let c = v * s;
448 let x = c * (1.0 - ((h / 60.0) % 2.0 - 1.0).abs());
449 let m = v - c;
450
451 let (r_prime, g_prime, b_prime) = if h < 60.0 {
452 (c, x, 0.0)
453 } else if h < 120.0 {
454 (x, c, 0.0)
455 } else if h < 180.0 {
456 (0.0, c, x)
457 } else if h < 240.0 {
458 (0.0, x, c)
459 } else if h < 300.0 {
460 (x, 0.0, c)
461 } else {
462 (c, 0.0, x)
463 };
464
465 Self {
466 r: ((r_prime + m) * 255.0) as u8,
467 g: ((g_prime + m) * 255.0) as u8,
468 b: ((b_prime + m) * 255.0) as u8,
469 a: 1.0,
470 }
471 }
472}
473
474impl fmt::Display for Color {
475 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
476 if self.a < 1.0 {
477 write!(f, "rgba({}, {}, {}, {:.2})", self.r, self.g, self.b, self.a)
478 } else {
479 write!(f, "rgb({}, {}, {})", self.r, self.g, self.b)
480 }
481 }
482}
483
484#[derive(Debug, Clone, Serialize, Deserialize)]
485pub struct AxisConfig {
486 pub label: String,
487 pub min_value: Option<f64>,
488 pub max_value: Option<f64>,
489 pub tick_count: usize,
490 pub grid_lines: bool,
491 pub log_scale: bool,
492}
493
494#[derive(Debug, Clone, Serialize, Deserialize)]
495pub struct PlotMargin {
496 pub top: u32,
497 pub right: u32,
498 pub bottom: u32,
499 pub left: u32,
500}
501
502#[derive(Debug, Clone, Serialize, Deserialize)]
503pub struct PlotLayout {
504 pub title: String,
505 pub x_axis: AxisConfig,
506 pub y_axis: AxisConfig,
507 pub width: u32,
508 pub height: u32,
509 pub background_color: Color,
510 pub margin: PlotMargin,
511}
512
/// Lightweight description of a plot payload: its kind, how many data points
/// it represents, and named numeric statistics.
#[derive(Debug, Clone, PartialEq)]
pub struct PlotSummary {
    pub plot_type: String,
    pub data_points: usize,
    pub summary_stats: HashMap<String, f64>,
}

impl fmt::Display for PlotSummary {
    /// Multi-line, human-readable report.
    ///
    /// Statistics are listed in ascending key order so the text is identical
    /// across runs; iterating the `HashMap` directly would produce a random
    /// order every time.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "Plot Summary:")?;
        writeln!(f, " Type: {}", self.plot_type)?;
        writeln!(f, " Data Points: {}", self.data_points)?;
        writeln!(f, " Statistics:")?;

        let mut stats: Vec<(&String, &f64)> = self.summary_stats.iter().collect();
        stats.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
        for (key, value) in stats {
            writeln!(f, " {key}: {value:.4}")?;
        }
        Ok(())
    }
}
532
533pub struct MLVisualizationUtils;
535
536impl MLVisualizationUtils {
537 pub fn prepare_confusion_matrix(
539 y_true: &Array1<usize>,
540 y_pred: &Array1<usize>,
541 class_names: Option<&[String]>,
542 ) -> UtilsResult<HeatmapData> {
543 if y_true.len() != y_pred.len() {
544 return Err(UtilsError::ShapeMismatch {
545 expected: vec![y_true.len()],
546 actual: vec![y_pred.len()],
547 });
548 }
549
550 let num_classes = y_true.iter().max().unwrap_or(&0) + 1;
551 let mut matrix = Array2::zeros((num_classes, num_classes));
552
553 for (&true_label, &pred_label) in y_true.iter().zip(y_pred.iter()) {
554 matrix[(true_label, pred_label)] += 1.0;
555 }
556
557 let labels = class_names
558 .map(|names| names.to_vec())
559 .unwrap_or_else(|| (0..num_classes).map(|i| format!("Class {i}")).collect());
560
561 ChartData::prepare_heatmap(&matrix, Some(&labels), Some(&labels))
562 }
563
564 pub fn prepare_learning_curve(
566 train_sizes: &Array1<usize>,
567 train_scores: &Array1<f64>,
568 val_scores: &Array1<f64>,
569 ) -> UtilsResult<(LinePlotData, LinePlotData)> {
570 if train_sizes.len() != train_scores.len() || train_sizes.len() != val_scores.len() {
571 return Err(UtilsError::ShapeMismatch {
572 expected: vec![train_sizes.len()],
573 actual: vec![train_scores.len(), val_scores.len()],
574 });
575 }
576
577 let x_values: Array1<f64> = train_sizes.mapv(|x| x as f64);
578
579 let train_line = ChartData::prepare_line_plot(
580 &x_values,
581 train_scores,
582 Some("Training Score".to_string()),
583 )?;
584 let val_line = ChartData::prepare_line_plot(
585 &x_values,
586 val_scores,
587 Some("Validation Score".to_string()),
588 )?;
589
590 Ok((train_line, val_line))
591 }
592
593 pub fn prepare_feature_importance(
595 feature_names: &[String],
596 importance_scores: &Array1<f64>,
597 ) -> UtilsResult<ScatterPlotData> {
598 if feature_names.len() != importance_scores.len() {
599 return Err(UtilsError::ShapeMismatch {
600 expected: vec![feature_names.len()],
601 actual: vec![importance_scores.len()],
602 });
603 }
604
605 let x_values: Array1<f64> = (0..feature_names.len()).map(|i| i as f64).collect();
606 ChartData::prepare_scatter_plot(
607 &x_values,
608 importance_scores,
609 Some(&feature_names.to_vec().into()),
610 )
611 }
612
613 pub fn prepare_roc_curve(
615 fpr: &Array1<f64>,
616 tpr: &Array1<f64>,
617 auc: f64,
618 ) -> UtilsResult<LinePlotData> {
619 if fpr.len() != tpr.len() {
620 return Err(UtilsError::ShapeMismatch {
621 expected: vec![fpr.len()],
622 actual: vec![tpr.len()],
623 });
624 }
625
626 ChartData::prepare_line_plot(fpr, tpr, Some(format!("ROC Curve (AUC = {auc:.3})")))
627 }
628}
629
// Unit tests for the chart-data builders, plot utilities, colour helpers and
// ML visualisation helpers defined in this file.
// NOTE(review): no non-snake-case identifiers are visible in this module, so
// this `allow` may be removable — confirm before deleting.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    // Points pair x[i] with y[i]; explicit labels are kept in order.
    #[test]
    fn test_scatter_plot_preparation() {
        let x = array![1.0, 2.0, 3.0, 4.0];
        let y = array![2.0, 4.0, 6.0, 8.0];
        let labels = array![
            "A".to_string(),
            "B".to_string(),
            "C".to_string(),
            "D".to_string()
        ];

        let scatter_data = ChartData::prepare_scatter_plot(&x, &y, Some(&labels)).unwrap();

        assert_eq!(scatter_data.points.len(), 4);
        assert_eq!(scatter_data.labels.len(), 4);
        assert_eq!(scatter_data.points[0].x, 1.0);
        assert_eq!(scatter_data.points[0].y, 2.0);
        assert_eq!(scatter_data.labels[0], "A");
    }

    // x and y of different lengths must be rejected, not truncated.
    #[test]
    fn test_scatter_plot_shape_mismatch() {
        let x = array![1.0, 2.0, 3.0];
        let y = array![2.0, 4.0];

        let result = ChartData::prepare_scatter_plot(&x, &y, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_line_plot_preparation() {
        let x = array![1.0, 2.0, 3.0];
        let y = array![1.0, 4.0, 9.0];

        let line_data =
            ChartData::prepare_line_plot(&x, &y, Some("Quadratic".to_string())).unwrap();

        assert_eq!(line_data.points.len(), 3);
        assert_eq!(line_data.name, "Quadratic");
        assert_eq!(line_data.points[1].x, 2.0);
        assert_eq!(line_data.points[1].y, 4.0);
    }

    // 4 bins need 5 edges, and the edges must cover the data range [1, 5].
    #[test]
    fn test_histogram_preparation() {
        let data = array![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 5.0];

        let hist_data = ChartData::prepare_histogram(&data, Some(4)).unwrap();

        assert_eq!(hist_data.counts.len(), 4);
        assert_eq!(hist_data.bin_edges.len(), 5);
        assert_eq!(hist_data.total_count, 8);
        assert!(hist_data.bin_edges[0] <= 1.0);
        assert!(hist_data.bin_edges[4] >= 5.0);
    }

    #[test]
    fn test_histogram_empty_data() {
        let data = array![];
        let result = ChartData::prepare_histogram(&data, Some(10));
        assert!(result.is_err());
    }

    #[test]
    fn test_heatmap_preparation() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let row_labels = vec!["Row1".to_string(), "Row2".to_string()];
        let col_labels = vec!["Col1".to_string(), "Col2".to_string(), "Col3".to_string()];

        let heatmap_data =
            ChartData::prepare_heatmap(&data, Some(&row_labels), Some(&col_labels)).unwrap();

        assert_eq!(heatmap_data.values.len(), 2);
        assert_eq!(heatmap_data.values[0].len(), 3);
        assert_eq!(heatmap_data.row_labels.len(), 2);
        assert_eq!(heatmap_data.col_labels.len(), 3);
        assert_eq!(heatmap_data.min_value, 1.0);
        assert_eq!(heatmap_data.max_value, 6.0);
    }

    // Linear-interpolation quartiles for 1..=10 sit at positions 2.25, 4.5
    // and 6.75 of the sorted data, giving 3.25, 5.5 and 7.75.
    #[test]
    fn test_box_plot_preparation() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let box_data = ChartData::prepare_box_plot(&data, Some("Test Data".to_string())).unwrap();

        assert_eq!(box_data.label, "Test Data");
        assert_abs_diff_eq!(box_data.median, 5.5, epsilon = 1e-10);
        assert_abs_diff_eq!(box_data.q1, 3.25, epsilon = 1e-10);
        assert_abs_diff_eq!(box_data.q3, 7.75, epsilon = 1e-10);
        assert!(box_data.outliers.is_empty());
    }

    // Q1 = 2.25 and Q3 = 4.75, so the upper fence is 8.5 and 100.0 is an outlier.
    #[test]
    fn test_box_plot_with_outliers() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0, 100.0];
        let box_data = ChartData::prepare_box_plot(&data, None).unwrap();

        assert!(!box_data.outliers.is_empty());
        assert!(box_data.outliers.contains(&100.0));
    }

    // 15 exceeds the 10-colour base palette, exercising the HSV extension.
    #[test]
    fn test_color_palette_generation() {
        let colors = PlotUtils::create_color_palette(5);
        assert_eq!(colors.len(), 5);

        let many_colors = PlotUtils::create_color_palette(15);
        assert_eq!(many_colors.len(), 15);
    }

    #[test]
    fn test_color_from_hsv() {
        let red = Color::from_hsv(0.0, 1.0, 1.0);
        assert_eq!(red.r, 255);
        assert_eq!(red.g, 0);
        assert_eq!(red.b, 0);

        let green = Color::from_hsv(120.0, 1.0, 1.0);
        assert_eq!(green.r, 0);
        assert_eq!(green.g, 255);
        assert_eq!(green.b, 0);
    }

    // Opaque colours render as `rgb(...)`; translucent ones as `rgba(...)`
    // with alpha to two decimals.
    #[test]
    fn test_color_display() {
        let color_rgb = Color::rgb(255, 128, 64);
        assert_eq!(color_rgb.to_string(), "rgb(255, 128, 64)");

        let color_rgba = Color::rgba(255, 128, 64, 0.5);
        assert_eq!(color_rgba.to_string(), "rgba(255, 128, 64, 0.50)");
    }

    #[test]
    fn test_axis_config_creation() {
        let axis = PlotUtils::create_axis_config("X Axis", Some(0.0), Some(10.0), Some(5));

        assert_eq!(axis.label, "X Axis");
        assert_eq!(axis.min_value, Some(0.0));
        assert_eq!(axis.max_value, Some(10.0));
        assert_eq!(axis.tick_count, 5);
        assert!(axis.grid_lines);
        assert!(!axis.log_scale);
    }

    #[test]
    fn test_plot_layout_creation() {
        let x_axis = PlotUtils::create_axis_config("X", None, None, None);
        let y_axis = PlotUtils::create_axis_config("Y", None, None, None);

        let layout = PlotUtils::create_layout("Test Plot", x_axis, y_axis, Some(1000), Some(800));

        assert_eq!(layout.title, "Test Plot");
        assert_eq!(layout.width, 1000);
        assert_eq!(layout.height, 800);
    }

    // The enum variant name ("Scatter") appears in the JSON because of serde's
    // default externally-tagged enum representation.
    #[test]
    fn test_json_export() {
        let x = array![1.0, 2.0];
        let y = array![3.0, 4.0];
        let scatter_data = ChartData::prepare_scatter_plot(&x, &y, None).unwrap();
        let plot_data = PlotData::Scatter(scatter_data);

        let json_result = PlotUtils::to_json(&plot_data);
        assert!(json_result.is_ok());

        let json = json_result.unwrap();
        assert!(json.contains("Scatter"));
        assert!(json.contains("points"));
    }

    // `f64` Display prints 1.0 as "1", hence the "1,3" row prefix.
    #[test]
    fn test_csv_export() {
        let x = array![1.0, 2.0];
        let y = array![3.0, 4.0];
        let scatter_data = ChartData::prepare_scatter_plot(&x, &y, None).unwrap();

        let csv = PlotUtils::to_csv(&scatter_data).unwrap();

        assert!(csv.contains("x,y,label"));
        assert!(csv.contains("1,3"));
        assert!(csv.contains("2,4"));
    }

    #[test]
    fn test_plot_summary_generation() {
        let x = array![1.0, 2.0, 3.0];
        let y = array![2.0, 4.0, 6.0];
        let scatter_data = ChartData::prepare_scatter_plot(&x, &y, None).unwrap();
        let plot_data = PlotData::Scatter(scatter_data);

        let summary = PlotUtils::generate_plot_summary(&plot_data);

        assert_eq!(summary.plot_type, "scatter");
        assert_eq!(summary.data_points, 3);
        assert!(summary.summary_stats.contains_key("x_min"));
        assert!(summary.summary_stats.contains_key("x_max"));
        assert!(summary.summary_stats.contains_key("y_min"));
        assert!(summary.summary_stats.contains_key("y_max"));
    }

    // Pairs (true, pred): (0,0), (0,1), (1,1), (1,1), (2,2), (2,0).
    #[test]
    fn test_confusion_matrix_preparation() {
        let y_true = array![0, 0, 1, 1, 2, 2];
        let y_pred = array![0, 1, 1, 1, 2, 0];
        let class_names = vec![
            "Class A".to_string(),
            "Class B".to_string(),
            "Class C".to_string(),
        ];

        let heatmap =
            MLVisualizationUtils::prepare_confusion_matrix(&y_true, &y_pred, Some(&class_names))
                .unwrap();

        assert_eq!(heatmap.values.len(), 3);
        assert_eq!(heatmap.values[0].len(), 3);
        assert_eq!(heatmap.row_labels[0], "Class A");
        assert_eq!(heatmap.col_labels[1], "Class B");

        // True class 0 predicted as 0 once.
        assert_eq!(heatmap.values[0][0], 1.0);
        // True class 0 predicted as 1 once.
        assert_eq!(heatmap.values[0][1], 1.0);
        // True class 1 predicted as 1 twice.
        assert_eq!(heatmap.values[1][1], 2.0);
    }

    // Training sizes become x coordinates (as f64) for both lines.
    #[test]
    fn test_learning_curve_preparation() {
        let train_sizes = array![100, 200, 300];
        let train_scores = array![0.8, 0.85, 0.87];
        let val_scores = array![0.75, 0.82, 0.83];

        let (train_line, val_line) =
            MLVisualizationUtils::prepare_learning_curve(&train_sizes, &train_scores, &val_scores)
                .unwrap();

        assert_eq!(train_line.name, "Training Score");
        assert_eq!(val_line.name, "Validation Score");
        assert_eq!(train_line.points.len(), 3);
        assert_eq!(val_line.points.len(), 3);

        assert_eq!(train_line.points[0].x, 100.0);
        assert_eq!(train_line.points[0].y, 0.8);
        assert_eq!(val_line.points[1].x, 200.0);
        assert_eq!(val_line.points[1].y, 0.82);
    }

    // Feature i is placed at x = i with its importance as y.
    #[test]
    fn test_feature_importance_preparation() {
        let features = vec![
            "Feature1".to_string(),
            "Feature2".to_string(),
            "Feature3".to_string(),
        ];
        let importance = array![0.5, 0.3, 0.2];

        let scatter_data =
            MLVisualizationUtils::prepare_feature_importance(&features, &importance).unwrap();

        assert_eq!(scatter_data.points.len(), 3);
        assert_eq!(scatter_data.labels.len(), 3);
        assert_eq!(scatter_data.labels[0], "Feature1");
        assert_eq!(scatter_data.points[0].x, 0.0);
        assert_eq!(scatter_data.points[0].y, 0.5);
    }

    // The AUC is embedded in the line name with three decimals.
    #[test]
    fn test_roc_curve_preparation() {
        let fpr = array![0.0, 0.2, 0.4, 1.0];
        let tpr = array![0.0, 0.6, 0.8, 1.0];
        let auc = 0.85;

        let roc_line = MLVisualizationUtils::prepare_roc_curve(&fpr, &tpr, auc).unwrap();

        assert_eq!(roc_line.points.len(), 4);
        assert!(roc_line.name.contains("ROC Curve"));
        assert!(roc_line.name.contains("0.850"));
        assert_eq!(roc_line.points[0].x, 0.0);
        assert_eq!(roc_line.points[0].y, 0.0);
        assert_eq!(roc_line.points[3].x, 1.0);
        assert_eq!(roc_line.points[3].y, 1.0);
    }

    #[test]
    fn test_plot_summary_display() {
        let x = array![1.0, 2.0];
        let y = array![3.0, 4.0];
        let scatter_data = ChartData::prepare_scatter_plot(&x, &y, None).unwrap();
        let plot_data = PlotData::Scatter(scatter_data);

        let summary = PlotUtils::generate_plot_summary(&plot_data);
        let display = format!("{summary}");

        assert!(display.contains("Plot Summary:"));
        assert!(display.contains("Type: scatter"));
        assert!(display.contains("Data Points: 2"));
        assert!(display.contains("Statistics:"));
    }
}