sklears_utils/
visualization.rs

1//! Visualization utilities for machine learning data preparation
2//!
3//! This module provides utilities for preparing data for visualization,
4//! chart data formatting, and plotting helpers for ML workflows.
5
6use crate::{UtilsError, UtilsResult};
7use scirs2_core::ndarray::{Array1, Array2, Axis};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::fmt;
11
/// Chart data preparation utilities.
///
/// `ChartData` is a field-less unit struct used purely as a namespace: every
/// function in its `impl` block is an associated function (no `self`) that
/// converts input arrays into one of the plot-data structs defined in this module.
pub struct ChartData;
14
15impl ChartData {
16    /// Prepare data for scatter plot visualization
17    pub fn prepare_scatter_plot(
18        x: &Array1<f64>,
19        y: &Array1<f64>,
20        labels: Option<&Array1<String>>,
21    ) -> UtilsResult<ScatterPlotData> {
22        if x.len() != y.len() {
23            return Err(UtilsError::ShapeMismatch {
24                expected: vec![x.len()],
25                actual: vec![y.len()],
26            });
27        }
28
29        let points: Vec<Point2D> = x
30            .iter()
31            .zip(y.iter())
32            .map(|(&x_val, &y_val)| Point2D { x: x_val, y: y_val })
33            .collect();
34
35        let labels = labels
36            .map(|l| l.to_vec())
37            .unwrap_or_else(|| (0..x.len()).map(|i| format!("Point {i}")).collect());
38
39        Ok(ScatterPlotData { points, labels })
40    }
41
42    /// Prepare data for line plot visualization
43    pub fn prepare_line_plot(
44        x: &Array1<f64>,
45        y: &Array1<f64>,
46        line_name: Option<String>,
47    ) -> UtilsResult<LinePlotData> {
48        if x.len() != y.len() {
49            return Err(UtilsError::ShapeMismatch {
50                expected: vec![x.len()],
51                actual: vec![y.len()],
52            });
53        }
54
55        let points: Vec<Point2D> = x
56            .iter()
57            .zip(y.iter())
58            .map(|(&x_val, &y_val)| Point2D { x: x_val, y: y_val })
59            .collect();
60
61        Ok(LinePlotData {
62            points,
63            name: line_name.unwrap_or_else(|| "Line".to_string()),
64        })
65    }
66
67    /// Prepare data for histogram visualization
68    pub fn prepare_histogram(
69        data: &Array1<f64>,
70        bins: Option<usize>,
71    ) -> UtilsResult<HistogramData> {
72        if data.is_empty() {
73            return Err(UtilsError::EmptyInput);
74        }
75
76        let bins = bins.unwrap_or(10);
77        let min_val = data.iter().cloned().fold(f64::INFINITY, f64::min);
78        let max_val = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
79
80        if min_val == max_val {
81            return Err(UtilsError::InvalidParameter(
82                "All values are the same, cannot create histogram".to_string(),
83            ));
84        }
85
86        let bin_width = (max_val - min_val) / bins as f64;
87        let mut bin_counts = vec![0; bins];
88        let mut bin_edges = Vec::with_capacity(bins + 1);
89
90        // Create bin edges
91        for i in 0..=bins {
92            bin_edges.push(min_val + i as f64 * bin_width);
93        }
94
95        // Count values in each bin
96        for &value in data.iter() {
97            let bin_index = ((value - min_val) / bin_width).floor() as usize;
98            let bin_index = bin_index.min(bins - 1); // Handle edge case for max value
99            bin_counts[bin_index] += 1;
100        }
101
102        Ok(HistogramData {
103            counts: bin_counts,
104            bin_edges,
105            total_count: data.len(),
106        })
107    }
108
109    /// Prepare data for heatmap visualization
110    pub fn prepare_heatmap(
111        data: &Array2<f64>,
112        row_labels: Option<&[String]>,
113        col_labels: Option<&[String]>,
114    ) -> UtilsResult<HeatmapData> {
115        let (rows, cols) = data.dim();
116
117        if rows == 0 || cols == 0 {
118            return Err(UtilsError::EmptyInput);
119        }
120
121        let values: Vec<Vec<f64>> = data.axis_iter(Axis(0)).map(|row| row.to_vec()).collect();
122
123        let row_labels = row_labels
124            .map(|labels| labels.to_vec())
125            .unwrap_or_else(|| (0..rows).map(|i| format!("Row {i}")).collect());
126
127        let col_labels = col_labels
128            .map(|labels| labels.to_vec())
129            .unwrap_or_else(|| (0..cols).map(|i| format!("Col {i}")).collect());
130
131        // Calculate min/max for color scaling
132        let min_val = data.iter().cloned().fold(f64::INFINITY, f64::min);
133        let max_val = data.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
134
135        Ok(HeatmapData {
136            values,
137            row_labels,
138            col_labels,
139            min_value: min_val,
140            max_value: max_val,
141        })
142    }
143
144    /// Prepare data for box plot visualization
145    pub fn prepare_box_plot(data: &Array1<f64>, label: Option<String>) -> UtilsResult<BoxPlotData> {
146        if data.is_empty() {
147            return Err(UtilsError::EmptyInput);
148        }
149
150        let mut sorted_data = data.to_vec();
151        sorted_data.sort_by(|a, b| a.partial_cmp(b).unwrap());
152
153        let len = sorted_data.len();
154        let q1 = Self::calculate_quantile(&sorted_data, 0.25);
155        let median = Self::calculate_quantile(&sorted_data, 0.5);
156        let q3 = Self::calculate_quantile(&sorted_data, 0.75);
157
158        let iqr = q3 - q1;
159        let lower_fence = q1 - 1.5 * iqr;
160        let upper_fence = q3 + 1.5 * iqr;
161
162        let outliers: Vec<f64> = sorted_data
163            .iter()
164            .copied()
165            .filter(|&x| x < lower_fence || x > upper_fence)
166            .collect();
167
168        let whisker_low = sorted_data
169            .iter()
170            .find(|&&x| x >= lower_fence)
171            .copied()
172            .unwrap_or(sorted_data[0]);
173
174        let whisker_high = sorted_data
175            .iter()
176            .rev()
177            .find(|&&x| x <= upper_fence)
178            .copied()
179            .unwrap_or(sorted_data[len - 1]);
180
181        Ok(BoxPlotData {
182            q1,
183            median,
184            q3,
185            whisker_low,
186            whisker_high,
187            outliers,
188            label: label.unwrap_or_else(|| "Data".to_string()),
189        })
190    }
191
192    fn calculate_quantile(sorted_data: &[f64], quantile: f64) -> f64 {
193        let index = quantile * (sorted_data.len() - 1) as f64;
194        let lower_index = index.floor() as usize;
195        let upper_index = index.ceil() as usize;
196
197        if lower_index == upper_index {
198            sorted_data[lower_index]
199        } else {
200            let weight = index - index.floor();
201            sorted_data[lower_index] * (1.0 - weight) + sorted_data[upper_index] * weight
202        }
203    }
204}
205
/// Plotting utilities for ML visualization.
///
/// `PlotUtils` is a field-less unit struct used as a namespace for associated
/// functions that build color palettes, axis and layout configuration, export
/// plot data as JSON or CSV, and summarize plot data.
pub struct PlotUtils;
208
209impl PlotUtils {
210    /// Create color palette for categorical data
211    pub fn create_color_palette(num_colors: usize) -> Vec<Color> {
212        let base_colors = vec![
213            Color::rgb(31, 119, 180),  // Blue
214            Color::rgb(255, 127, 14),  // Orange
215            Color::rgb(44, 160, 44),   // Green
216            Color::rgb(214, 39, 40),   // Red
217            Color::rgb(148, 103, 189), // Purple
218            Color::rgb(140, 86, 75),   // Brown
219            Color::rgb(227, 119, 194), // Pink
220            Color::rgb(127, 127, 127), // Gray
221            Color::rgb(188, 189, 34),  // Olive
222            Color::rgb(23, 190, 207),  // Cyan
223        ];
224
225        if num_colors <= base_colors.len() {
226            base_colors.into_iter().take(num_colors).collect()
227        } else {
228            // Generate additional colors using HSV color space
229            let base_len = base_colors.len();
230            let mut colors = base_colors;
231            for i in base_len..num_colors {
232                let hue = (i as f64 * 360.0 / num_colors as f64) % 360.0;
233                let color = Color::from_hsv(hue, 0.8, 0.8);
234                colors.push(color);
235            }
236            colors
237        }
238    }
239
240    /// Create axis configuration for plots
241    pub fn create_axis_config(
242        label: &str,
243        min_val: Option<f64>,
244        max_val: Option<f64>,
245        tick_count: Option<usize>,
246    ) -> AxisConfig {
247        AxisConfig {
248            label: label.to_string(),
249            min_value: min_val,
250            max_value: max_val,
251            tick_count: tick_count.unwrap_or(10),
252            grid_lines: true,
253            log_scale: false,
254        }
255    }
256
257    /// Format data for JSON export
258    pub fn to_json(plot_data: &PlotData) -> UtilsResult<String> {
259        serde_json::to_string_pretty(plot_data)
260            .map_err(|e| UtilsError::InvalidParameter(format!("JSON serialization error: {e}")))
261    }
262
263    /// Format data for CSV export
264    pub fn to_csv(scatter_data: &ScatterPlotData) -> UtilsResult<String> {
265        let mut csv = String::new();
266        csv.push_str("x,y,label\n");
267
268        for (point, label) in scatter_data.points.iter().zip(&scatter_data.labels) {
269            csv.push_str(&format!("{},{},{}\n", point.x, point.y, label));
270        }
271
272        Ok(csv)
273    }
274
275    /// Create plot layout configuration
276    pub fn create_layout(
277        title: &str,
278        x_axis: AxisConfig,
279        y_axis: AxisConfig,
280        width: Option<u32>,
281        height: Option<u32>,
282    ) -> PlotLayout {
283        PlotLayout {
284            title: title.to_string(),
285            x_axis,
286            y_axis,
287            width: width.unwrap_or(800),
288            height: height.unwrap_or(600),
289            background_color: Color::rgb(255, 255, 255),
290            margin: PlotMargin {
291                top: 50,
292                right: 50,
293                bottom: 80,
294                left: 80,
295            },
296        }
297    }
298
299    /// Generate plot summary statistics
300    pub fn generate_plot_summary(plot_data: &PlotData) -> PlotSummary {
301        match plot_data {
302            PlotData::Scatter(data) => PlotSummary {
303                plot_type: "scatter".to_string(),
304                data_points: data.points.len(),
305                summary_stats: Self::calculate_scatter_stats(&data.points),
306            },
307            PlotData::Line(data) => PlotSummary {
308                plot_type: "line".to_string(),
309                data_points: data.points.len(),
310                summary_stats: Self::calculate_scatter_stats(&data.points),
311            },
312            PlotData::Histogram(data) => PlotSummary {
313                plot_type: "histogram".to_string(),
314                data_points: data.total_count,
315                summary_stats: HashMap::from([
316                    ("bins".to_string(), data.counts.len() as f64),
317                    (
318                        "max_count".to_string(),
319                        *data.counts.iter().max().unwrap_or(&0) as f64,
320                    ),
321                ]),
322            },
323            PlotData::Heatmap(data) => PlotSummary {
324                plot_type: "heatmap".to_string(),
325                data_points: data.values.len() * data.values.first().map_or(0, |row| row.len()),
326                summary_stats: HashMap::from([
327                    ("rows".to_string(), data.values.len() as f64),
328                    (
329                        "cols".to_string(),
330                        data.values.first().map_or(0.0, |row| row.len() as f64),
331                    ),
332                    ("min_value".to_string(), data.min_value),
333                    ("max_value".to_string(), data.max_value),
334                ]),
335            },
336            PlotData::BoxPlot(data) => PlotSummary {
337                plot_type: "boxplot".to_string(),
338                data_points: 1, // One box
339                summary_stats: HashMap::from([
340                    ("q1".to_string(), data.q1),
341                    ("median".to_string(), data.median),
342                    ("q3".to_string(), data.q3),
343                    ("outliers".to_string(), data.outliers.len() as f64),
344                ]),
345            },
346        }
347    }
348
349    fn calculate_scatter_stats(points: &[Point2D]) -> HashMap<String, f64> {
350        if points.is_empty() {
351            return HashMap::new();
352        }
353
354        let x_values: Vec<f64> = points.iter().map(|p| p.x).collect();
355        let y_values: Vec<f64> = points.iter().map(|p| p.y).collect();
356
357        let x_min = x_values.iter().cloned().fold(f64::INFINITY, f64::min);
358        let x_max = x_values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
359        let y_min = y_values.iter().cloned().fold(f64::INFINITY, f64::min);
360        let y_max = y_values.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
361
362        HashMap::from([
363            ("x_min".to_string(), x_min),
364            ("x_max".to_string(), x_max),
365            ("y_min".to_string(), y_min),
366            ("y_max".to_string(), y_max),
367            ("x_range".to_string(), x_max - x_min),
368            ("y_range".to_string(), y_max - y_min),
369        ])
370    }
371}
372
// Data structures for different plot types

/// A single point with `x` and `y` coordinates, as used by scatter and line plots.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Point2D {
    /// X coordinate.
    pub x: f64,
    /// Y coordinate.
    pub y: f64,
}
380
/// Scatter-plot data: a list of points and a parallel list of labels.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScatterPlotData {
    /// The points to draw, in input order.
    pub points: Vec<Point2D>,
    /// Labels paired positionally with `points` (`labels[i]` names `points[i]`).
    pub labels: Vec<String>,
}
386
/// Line-plot data: one named series of points.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LinePlotData {
    /// The points of the line, in input order.
    pub points: Vec<Point2D>,
    /// Name of the series (`"Line"` when `ChartData::prepare_line_plot` is given no name).
    pub name: String,
}
392
/// Histogram data: per-bin counts together with the bin boundaries.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistogramData {
    /// Number of values counted in each bin.
    pub counts: Vec<usize>,
    /// Bin boundaries; `ChartData::prepare_histogram` produces one more edge than bins.
    pub bin_edges: Vec<f64>,
    /// Total number of input values.
    pub total_count: usize,
}
399
/// Heatmap data: a matrix of values with axis labels and the value range.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeatmapData {
    /// Cell values stored as a vector of rows.
    pub values: Vec<Vec<f64>>,
    /// One label per row.
    pub row_labels: Vec<String>,
    /// One label per column.
    pub col_labels: Vec<String>,
    /// Smallest cell value, intended for color scaling.
    pub min_value: f64,
    /// Largest cell value, intended for color scaling.
    pub max_value: f64,
}
408
/// Box-plot data: quartiles, whisker ends, and outliers for one data set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoxPlotData {
    /// First quartile (0.25 quantile).
    pub q1: f64,
    /// Median (0.5 quantile).
    pub median: f64,
    /// Third quartile (0.75 quantile).
    pub q3: f64,
    /// Smallest value that is not below the lower fence `q1 - 1.5 * IQR`.
    pub whisker_low: f64,
    /// Largest value that is not above the upper fence `q3 + 1.5 * IQR`.
    pub whisker_high: f64,
    /// Values lying outside the fences, in ascending order.
    pub outliers: Vec<f64>,
    /// Label of the box (`"Data"` when `ChartData::prepare_box_plot` is given no label).
    pub label: String,
}
419
/// Any one of the supported plot payloads.
///
/// Used by `PlotUtils::to_json` and `PlotUtils::generate_plot_summary` to handle
/// all plot kinds through a single type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum PlotData {
    /// Scatter plot payload.
    Scatter(ScatterPlotData),
    /// Line plot payload.
    Line(LinePlotData),
    /// Histogram payload.
    Histogram(HistogramData),
    /// Heatmap payload.
    Heatmap(HeatmapData),
    /// Box plot payload.
    BoxPlot(BoxPlotData),
}
428
/// An RGB color with an alpha component.
///
/// The `Display` impl renders it as `rgb(r, g, b)`, or as `rgba(r, g, b, a)` when
/// `a` is below `1.0`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Color {
    /// Red channel, 0-255.
    pub r: u8,
    /// Green channel, 0-255.
    pub g: u8,
    /// Blue channel, 0-255.
    pub b: u8,
    /// Alpha component; the constructors without an alpha argument set it to `1.0`.
    pub a: f64,
}
436
437impl Color {
438    pub fn rgb(r: u8, g: u8, b: u8) -> Self {
439        Self { r, g, b, a: 1.0 }
440    }
441
442    pub fn rgba(r: u8, g: u8, b: u8, a: f64) -> Self {
443        Self { r, g, b, a }
444    }
445
446    pub fn from_hsv(h: f64, s: f64, v: f64) -> Self {
447        let c = v * s;
448        let x = c * (1.0 - ((h / 60.0) % 2.0 - 1.0).abs());
449        let m = v - c;
450
451        let (r_prime, g_prime, b_prime) = if h < 60.0 {
452            (c, x, 0.0)
453        } else if h < 120.0 {
454            (x, c, 0.0)
455        } else if h < 180.0 {
456            (0.0, c, x)
457        } else if h < 240.0 {
458            (0.0, x, c)
459        } else if h < 300.0 {
460            (x, 0.0, c)
461        } else {
462            (c, 0.0, x)
463        };
464
465        Self {
466            r: ((r_prime + m) * 255.0) as u8,
467            g: ((g_prime + m) * 255.0) as u8,
468            b: ((b_prime + m) * 255.0) as u8,
469            a: 1.0,
470        }
471    }
472}
473
474impl fmt::Display for Color {
475    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
476        if self.a < 1.0 {
477            write!(f, "rgba({}, {}, {}, {:.2})", self.r, self.g, self.b, self.a)
478        } else {
479            write!(f, "rgb({}, {}, {})", self.r, self.g, self.b)
480        }
481    }
482}
483
/// Configuration of a single plot axis.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AxisConfig {
    /// Axis label text.
    pub label: String,
    /// Optional lower bound of the axis; `None` leaves it unspecified.
    pub min_value: Option<f64>,
    /// Optional upper bound of the axis; `None` leaves it unspecified.
    pub max_value: Option<f64>,
    /// Number of ticks (`PlotUtils::create_axis_config` defaults this to 10).
    pub tick_count: usize,
    /// Whether grid lines are enabled.
    pub grid_lines: bool,
    /// Whether the axis uses a logarithmic scale.
    pub log_scale: bool,
}
493
/// Margins around the plot area, one value per side.
///
/// NOTE(review): the unit is not stated anywhere in this file; presumably pixels,
/// matching `PlotLayout::width` / `height` — confirm with the rendering code.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlotMargin {
    /// Margin at the top.
    pub top: u32,
    /// Margin on the right.
    pub right: u32,
    /// Margin at the bottom.
    pub bottom: u32,
    /// Margin on the left.
    pub left: u32,
}
501
/// Overall layout of a plot: title, axes, size, background, and margins.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PlotLayout {
    /// Plot title.
    pub title: String,
    /// Configuration of the x axis.
    pub x_axis: AxisConfig,
    /// Configuration of the y axis.
    pub y_axis: AxisConfig,
    /// Plot width (`PlotUtils::create_layout` defaults this to 800).
    pub width: u32,
    /// Plot height (`PlotUtils::create_layout` defaults this to 600).
    pub height: u32,
    /// Background color (`PlotUtils::create_layout` uses opaque white).
    pub background_color: Color,
    /// Margins around the plot area.
    pub margin: PlotMargin,
}
512
/// Summary of a plot as produced by `PlotUtils::generate_plot_summary`.
#[derive(Debug, Clone)]
pub struct PlotSummary {
    /// Lower-case plot kind, e.g. `"scatter"` or `"histogram"`.
    pub plot_type: String,
    /// Number of data points the plot represents.
    pub data_points: usize,
    /// Named numeric statistics; the set of keys depends on the plot kind.
    pub summary_stats: HashMap<String, f64>,
}
519
520impl fmt::Display for PlotSummary {
521    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
522        writeln!(f, "Plot Summary:")?;
523        writeln!(f, "  Type: {}", self.plot_type)?;
524        writeln!(f, "  Data Points: {}", self.data_points)?;
525        writeln!(f, "  Statistics:")?;
526        for (key, value) in &self.summary_stats {
527            writeln!(f, "    {key}: {value:.4}")?;
528        }
529        Ok(())
530    }
531}
532
/// ML visualization specific utilities.
///
/// `MLVisualizationUtils` is a field-less unit struct used as a namespace for
/// associated functions that turn model-evaluation artifacts (confusion
/// matrices, learning curves, feature importances, ROC curves) into the plot
/// data structs of this module by delegating to `ChartData`.
pub struct MLVisualizationUtils;
535
536impl MLVisualizationUtils {
537    /// Prepare confusion matrix for visualization
538    pub fn prepare_confusion_matrix(
539        y_true: &Array1<usize>,
540        y_pred: &Array1<usize>,
541        class_names: Option<&[String]>,
542    ) -> UtilsResult<HeatmapData> {
543        if y_true.len() != y_pred.len() {
544            return Err(UtilsError::ShapeMismatch {
545                expected: vec![y_true.len()],
546                actual: vec![y_pred.len()],
547            });
548        }
549
550        let num_classes = y_true.iter().max().unwrap_or(&0) + 1;
551        let mut matrix = Array2::zeros((num_classes, num_classes));
552
553        for (&true_label, &pred_label) in y_true.iter().zip(y_pred.iter()) {
554            matrix[(true_label, pred_label)] += 1.0;
555        }
556
557        let labels = class_names
558            .map(|names| names.to_vec())
559            .unwrap_or_else(|| (0..num_classes).map(|i| format!("Class {i}")).collect());
560
561        ChartData::prepare_heatmap(&matrix, Some(&labels), Some(&labels))
562    }
563
564    /// Prepare learning curve data for visualization
565    pub fn prepare_learning_curve(
566        train_sizes: &Array1<usize>,
567        train_scores: &Array1<f64>,
568        val_scores: &Array1<f64>,
569    ) -> UtilsResult<(LinePlotData, LinePlotData)> {
570        if train_sizes.len() != train_scores.len() || train_sizes.len() != val_scores.len() {
571            return Err(UtilsError::ShapeMismatch {
572                expected: vec![train_sizes.len()],
573                actual: vec![train_scores.len(), val_scores.len()],
574            });
575        }
576
577        let x_values: Array1<f64> = train_sizes.mapv(|x| x as f64);
578
579        let train_line = ChartData::prepare_line_plot(
580            &x_values,
581            train_scores,
582            Some("Training Score".to_string()),
583        )?;
584        let val_line = ChartData::prepare_line_plot(
585            &x_values,
586            val_scores,
587            Some("Validation Score".to_string()),
588        )?;
589
590        Ok((train_line, val_line))
591    }
592
593    /// Prepare feature importance visualization
594    pub fn prepare_feature_importance(
595        feature_names: &[String],
596        importance_scores: &Array1<f64>,
597    ) -> UtilsResult<ScatterPlotData> {
598        if feature_names.len() != importance_scores.len() {
599            return Err(UtilsError::ShapeMismatch {
600                expected: vec![feature_names.len()],
601                actual: vec![importance_scores.len()],
602            });
603        }
604
605        let x_values: Array1<f64> = (0..feature_names.len()).map(|i| i as f64).collect();
606        ChartData::prepare_scatter_plot(
607            &x_values,
608            importance_scores,
609            Some(&feature_names.to_vec().into()),
610        )
611    }
612
613    /// Prepare ROC curve data
614    pub fn prepare_roc_curve(
615        fpr: &Array1<f64>,
616        tpr: &Array1<f64>,
617        auc: f64,
618    ) -> UtilsResult<LinePlotData> {
619        if fpr.len() != tpr.len() {
620            return Err(UtilsError::ShapeMismatch {
621                expected: vec![fpr.len()],
622                actual: vec![tpr.len()],
623            });
624        }
625
626        ChartData::prepare_line_plot(fpr, tpr, Some(format!("ROC Curve (AUC = {auc:.3})")))
627    }
628}
629
// Unit tests for the chart-data builders, plotting helpers, and ML-specific
// wrappers defined in this module.
//
// NOTE(review): no non-snake-case identifier is visible in this module, so the
// `non_snake_case` allowance below looks unnecessary — confirm before removing.
#[allow(non_snake_case)]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;
    use scirs2_core::ndarray::array;

    // Scatter data keeps the input points and the caller-supplied labels.
    #[test]
    fn test_scatter_plot_preparation() {
        let x = array![1.0, 2.0, 3.0, 4.0];
        let y = array![2.0, 4.0, 6.0, 8.0];
        let labels = array![
            "A".to_string(),
            "B".to_string(),
            "C".to_string(),
            "D".to_string()
        ];

        let scatter_data = ChartData::prepare_scatter_plot(&x, &y, Some(&labels)).unwrap();

        assert_eq!(scatter_data.points.len(), 4);
        assert_eq!(scatter_data.labels.len(), 4);
        assert_eq!(scatter_data.points[0].x, 1.0);
        assert_eq!(scatter_data.points[0].y, 2.0);
        assert_eq!(scatter_data.labels[0], "A");
    }

    // x and y of different lengths must be rejected.
    #[test]
    fn test_scatter_plot_shape_mismatch() {
        let x = array![1.0, 2.0, 3.0];
        let y = array![2.0, 4.0];

        let result = ChartData::prepare_scatter_plot(&x, &y, None);
        assert!(result.is_err());
    }

    // Line data keeps the points in order and uses the supplied name.
    #[test]
    fn test_line_plot_preparation() {
        let x = array![1.0, 2.0, 3.0];
        let y = array![1.0, 4.0, 9.0];

        let line_data =
            ChartData::prepare_line_plot(&x, &y, Some("Quadratic".to_string())).unwrap();

        assert_eq!(line_data.points.len(), 3);
        assert_eq!(line_data.name, "Quadratic");
        assert_eq!(line_data.points[1].x, 2.0);
        assert_eq!(line_data.points[1].y, 4.0);
    }

    // Four bins give four counts and five edges spanning the data range.
    #[test]
    fn test_histogram_preparation() {
        let data = array![1.0, 2.0, 2.0, 3.0, 3.0, 3.0, 4.0, 5.0];

        let hist_data = ChartData::prepare_histogram(&data, Some(4)).unwrap();

        assert_eq!(hist_data.counts.len(), 4);
        assert_eq!(hist_data.bin_edges.len(), 5);
        assert_eq!(hist_data.total_count, 8);
        assert!(hist_data.bin_edges[0] <= 1.0);
        assert!(hist_data.bin_edges[4] >= 5.0);
    }

    // An empty array cannot be binned.
    #[test]
    fn test_histogram_empty_data() {
        let data = array![];
        let result = ChartData::prepare_histogram(&data, Some(10));
        assert!(result.is_err());
    }

    // Heatmap data mirrors the matrix shape, labels, and value range.
    #[test]
    fn test_heatmap_preparation() {
        let data = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let row_labels = vec!["Row1".to_string(), "Row2".to_string()];
        let col_labels = vec!["Col1".to_string(), "Col2".to_string(), "Col3".to_string()];

        let heatmap_data =
            ChartData::prepare_heatmap(&data, Some(&row_labels), Some(&col_labels)).unwrap();

        assert_eq!(heatmap_data.values.len(), 2);
        assert_eq!(heatmap_data.values[0].len(), 3);
        assert_eq!(heatmap_data.row_labels.len(), 2);
        assert_eq!(heatmap_data.col_labels.len(), 3);
        assert_eq!(heatmap_data.min_value, 1.0);
        assert_eq!(heatmap_data.max_value, 6.0);
    }

    // Quartiles of 1..=10 use linear interpolation; no value is an outlier.
    #[test]
    fn test_box_plot_preparation() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];

        let box_data = ChartData::prepare_box_plot(&data, Some("Test Data".to_string())).unwrap();

        assert_eq!(box_data.label, "Test Data");
        assert_abs_diff_eq!(box_data.median, 5.5, epsilon = 1e-10);
        assert_abs_diff_eq!(box_data.q1, 3.25, epsilon = 1e-10);
        assert_abs_diff_eq!(box_data.q3, 7.75, epsilon = 1e-10);
        assert!(box_data.outliers.is_empty());
    }

    // A value far above the upper fence is reported as an outlier.
    #[test]
    fn test_box_plot_with_outliers() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0, 100.0]; // 100 is an outlier

        let box_data = ChartData::prepare_box_plot(&data, None).unwrap();

        assert!(!box_data.outliers.is_empty());
        assert!(box_data.outliers.contains(&100.0));
    }

    // The palette has exactly the requested size, below and above the base palette.
    #[test]
    fn test_color_palette_generation() {
        let colors = PlotUtils::create_color_palette(5);
        assert_eq!(colors.len(), 5);

        // Test that we get more colors than base palette
        let many_colors = PlotUtils::create_color_palette(15);
        assert_eq!(many_colors.len(), 15);
    }

    // Fully saturated hues 0 and 120 map to pure red and pure green.
    #[test]
    fn test_color_from_hsv() {
        let red = Color::from_hsv(0.0, 1.0, 1.0);
        assert_eq!(red.r, 255);
        assert_eq!(red.g, 0);
        assert_eq!(red.b, 0);

        let green = Color::from_hsv(120.0, 1.0, 1.0);
        assert_eq!(green.r, 0);
        assert_eq!(green.g, 255);
        assert_eq!(green.b, 0);
    }

    // Display uses `rgb(...)` for alpha 1.0 and `rgba(...)` below it.
    #[test]
    fn test_color_display() {
        let color_rgb = Color::rgb(255, 128, 64);
        assert_eq!(color_rgb.to_string(), "rgb(255, 128, 64)");

        let color_rgba = Color::rgba(255, 128, 64, 0.5);
        assert_eq!(color_rgba.to_string(), "rgba(255, 128, 64, 0.50)");
    }

    // Explicit arguments are stored; grid lines on and log scale off are fixed.
    #[test]
    fn test_axis_config_creation() {
        let axis = PlotUtils::create_axis_config("X Axis", Some(0.0), Some(10.0), Some(5));

        assert_eq!(axis.label, "X Axis");
        assert_eq!(axis.min_value, Some(0.0));
        assert_eq!(axis.max_value, Some(10.0));
        assert_eq!(axis.tick_count, 5);
        assert!(axis.grid_lines);
        assert!(!axis.log_scale);
    }

    // Explicit width and height override the layout defaults.
    #[test]
    fn test_plot_layout_creation() {
        let x_axis = PlotUtils::create_axis_config("X", None, None, None);
        let y_axis = PlotUtils::create_axis_config("Y", None, None, None);

        let layout = PlotUtils::create_layout("Test Plot", x_axis, y_axis, Some(1000), Some(800));

        assert_eq!(layout.title, "Test Plot");
        assert_eq!(layout.width, 1000);
        assert_eq!(layout.height, 800);
    }

    // JSON output names the enum variant and contains the point list.
    #[test]
    fn test_json_export() {
        let x = array![1.0, 2.0];
        let y = array![3.0, 4.0];
        let scatter_data = ChartData::prepare_scatter_plot(&x, &y, None).unwrap();
        let plot_data = PlotData::Scatter(scatter_data);

        let json_result = PlotUtils::to_json(&plot_data);
        assert!(json_result.is_ok());

        let json = json_result.unwrap();
        assert!(json.contains("Scatter"));
        assert!(json.contains("points"));
    }

    // CSV output has the header and one `x,y` pair per point.
    #[test]
    fn test_csv_export() {
        let x = array![1.0, 2.0];
        let y = array![3.0, 4.0];
        let scatter_data = ChartData::prepare_scatter_plot(&x, &y, None).unwrap();

        let csv = PlotUtils::to_csv(&scatter_data).unwrap();

        assert!(csv.contains("x,y,label"));
        assert!(csv.contains("1,3"));
        assert!(csv.contains("2,4"));
    }

    // A scatter summary reports the point count and the coordinate extents.
    #[test]
    fn test_plot_summary_generation() {
        let x = array![1.0, 2.0, 3.0];
        let y = array![2.0, 4.0, 6.0];
        let scatter_data = ChartData::prepare_scatter_plot(&x, &y, None).unwrap();
        let plot_data = PlotData::Scatter(scatter_data);

        let summary = PlotUtils::generate_plot_summary(&plot_data);

        assert_eq!(summary.plot_type, "scatter");
        assert_eq!(summary.data_points, 3);
        assert!(summary.summary_stats.contains_key("x_min"));
        assert!(summary.summary_stats.contains_key("x_max"));
        assert!(summary.summary_stats.contains_key("y_min"));
        assert!(summary.summary_stats.contains_key("y_max"));
    }

    // Three classes give a 3x3 matrix indexed as [true][predicted].
    #[test]
    fn test_confusion_matrix_preparation() {
        let y_true = array![0, 0, 1, 1, 2, 2];
        let y_pred = array![0, 1, 1, 1, 2, 0];
        let class_names = vec![
            "Class A".to_string(),
            "Class B".to_string(),
            "Class C".to_string(),
        ];

        let heatmap =
            MLVisualizationUtils::prepare_confusion_matrix(&y_true, &y_pred, Some(&class_names))
                .unwrap();

        assert_eq!(heatmap.values.len(), 3);
        assert_eq!(heatmap.values[0].len(), 3);
        assert_eq!(heatmap.row_labels[0], "Class A");
        assert_eq!(heatmap.col_labels[1], "Class B");

        // Check some values in confusion matrix
        assert_eq!(heatmap.values[0][0], 1.0); // True A, Pred A
        assert_eq!(heatmap.values[0][1], 1.0); // True A, Pred B
        assert_eq!(heatmap.values[1][1], 2.0); // True B, Pred B
    }

    // Both lines share the training sizes as x values and carry fixed names.
    #[test]
    fn test_learning_curve_preparation() {
        let train_sizes = array![100, 200, 300];
        let train_scores = array![0.8, 0.85, 0.87];
        let val_scores = array![0.75, 0.82, 0.83];

        let (train_line, val_line) =
            MLVisualizationUtils::prepare_learning_curve(&train_sizes, &train_scores, &val_scores)
                .unwrap();

        assert_eq!(train_line.name, "Training Score");
        assert_eq!(val_line.name, "Validation Score");
        assert_eq!(train_line.points.len(), 3);
        assert_eq!(val_line.points.len(), 3);

        assert_eq!(train_line.points[0].x, 100.0);
        assert_eq!(train_line.points[0].y, 0.8);
        assert_eq!(val_line.points[1].x, 200.0);
        assert_eq!(val_line.points[1].y, 0.82);
    }

    // Features are placed at x = index and labelled with their names.
    #[test]
    fn test_feature_importance_preparation() {
        let features = vec![
            "Feature1".to_string(),
            "Feature2".to_string(),
            "Feature3".to_string(),
        ];
        let importance = array![0.5, 0.3, 0.2];

        let scatter_data =
            MLVisualizationUtils::prepare_feature_importance(&features, &importance).unwrap();

        assert_eq!(scatter_data.points.len(), 3);
        assert_eq!(scatter_data.labels.len(), 3);
        assert_eq!(scatter_data.labels[0], "Feature1");
        assert_eq!(scatter_data.points[0].x, 0.0);
        assert_eq!(scatter_data.points[0].y, 0.5);
    }

    // The ROC line keeps all points and embeds the AUC with three decimals in its name.
    #[test]
    fn test_roc_curve_preparation() {
        let fpr = array![0.0, 0.2, 0.4, 1.0];
        let tpr = array![0.0, 0.6, 0.8, 1.0];
        let auc = 0.85;

        let roc_line = MLVisualizationUtils::prepare_roc_curve(&fpr, &tpr, auc).unwrap();

        assert_eq!(roc_line.points.len(), 4);
        assert!(roc_line.name.contains("ROC Curve"));
        assert!(roc_line.name.contains("0.850"));
        assert_eq!(roc_line.points[0].x, 0.0);
        assert_eq!(roc_line.points[0].y, 0.0);
        assert_eq!(roc_line.points[3].x, 1.0);
        assert_eq!(roc_line.points[3].y, 1.0);
    }

    // The Display output contains the header, the type, the count, and the stats section.
    #[test]
    fn test_plot_summary_display() {
        let x = array![1.0, 2.0];
        let y = array![3.0, 4.0];
        let scatter_data = ChartData::prepare_scatter_plot(&x, &y, None).unwrap();
        let plot_data = PlotData::Scatter(scatter_data);

        let summary = PlotUtils::generate_plot_summary(&plot_data);
        let display = format!("{summary}");

        assert!(display.contains("Plot Summary:"));
        assert!(display.contains("Type: scatter"));
        assert!(display.contains("Data Points: 2"));
        assert!(display.contains("Statistics:"));
    }
}