// presentar_terminal/widgets/confusion_matrix.rs

//! Confusion matrix widget for ML classification visualization.
//!
//! Displays a confusion matrix with color-coded cells showing
//! classification performance across classes. Rows are the actual
//! classes and columns are the predicted classes.

6use presentar_core::{
7    Brick, BrickAssertion, BrickBudget, BrickVerification, Canvas, Color, Constraints, Event,
8    LayoutResult, Point, Rect, Size, TextStyle, TypeId, Widget,
9};
10use std::any::Any;
11use std::time::Duration;
12
/// How cell counts are scaled before they are turned into colors
/// (and into percentages, when percentages are displayed).
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum Normalization {
    /// Keep raw counts; cells are scaled only relative to the largest cell.
    #[default]
    None,
    /// Divide each cell by its row total (per-class recall view).
    Row,
    /// Divide each cell by its column total (per-class precision view).
    Column,
    /// Divide each cell by the grand total (overall distribution).
    Total,
}

/// Color scheme used to paint confusion-matrix cell backgrounds.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum MatrixPalette {
    /// Ramp from blue (low intensity) to red (high intensity).
    #[default]
    BlueRed,
    /// Green shades on the diagonal, red shades everywhere else.
    DiagonalGreen,
    /// Shades of gray.
    Grayscale,
}

39impl MatrixPalette {
40    /// Get color for a normalized value (0.0 to 1.0).
41    #[must_use]
42    pub fn color(&self, value: f64, is_diagonal: bool) -> Color {
43        let v = value.clamp(0.0, 1.0) as f32;
44        match self {
45            Self::BlueRed => Color::new(v, 0.2, 1.0 - v, 1.0),
46            Self::DiagonalGreen => {
47                if is_diagonal {
48                    // Green for diagonal (correct predictions)
49                    Color::new(0.2, 0.3 + v * 0.7, 0.2, 1.0)
50                } else {
51                    // Red for off-diagonal (errors)
52                    Color::new(0.3 + v * 0.7, 0.2, 0.2, 1.0)
53                }
54            }
55            Self::Grayscale => {
56                let g = 0.2 + v * 0.6;
57                Color::new(g, g, g, 1.0)
58            }
59        }
60    }
61}
62
63/// Confusion matrix widget for classification visualization.
64#[derive(Debug, Clone)]
65pub struct ConfusionMatrix {
66    /// Matrix data (rows are actual, columns are predicted).
67    matrix: Vec<Vec<u64>>,
68    /// Class labels.
69    labels: Vec<String>,
70    /// Normalization mode.
71    normalization: Normalization,
72    /// Color palette.
73    palette: MatrixPalette,
74    /// Cell width in characters.
75    cell_width: usize,
76    /// Whether to show values in cells.
77    show_values: bool,
78    /// Whether to show percentages instead of counts.
79    show_percentages: bool,
80    /// Title.
81    title: Option<String>,
82    /// Cached bounds.
83    bounds: Rect,
84}
85
86impl Default for ConfusionMatrix {
87    fn default() -> Self {
88        Self::new(vec![vec![0]])
89    }
90}
91
92impl ConfusionMatrix {
93    /// Create a new confusion matrix.
94    #[must_use]
95    pub fn new(matrix: Vec<Vec<u64>>) -> Self {
96        let size = matrix.len();
97        let labels: Vec<String> = (0..size).map(|i| format!("{i}")).collect();
98        Self {
99            matrix,
100            labels,
101            normalization: Normalization::None,
102            palette: MatrixPalette::default(),
103            cell_width: 6,
104            show_values: true,
105            show_percentages: false,
106            title: None,
107            bounds: Rect::default(),
108        }
109    }
110
111    /// Set class labels.
112    #[must_use]
113    pub fn with_labels(mut self, labels: Vec<String>) -> Self {
114        self.labels = labels;
115        self
116    }
117
118    /// Set normalization mode.
119    #[must_use]
120    pub fn with_normalization(mut self, normalization: Normalization) -> Self {
121        self.normalization = normalization;
122        self
123    }
124
125    /// Set color palette.
126    #[must_use]
127    pub fn with_palette(mut self, palette: MatrixPalette) -> Self {
128        self.palette = palette;
129        self
130    }
131
132    /// Set cell width.
133    #[must_use]
134    pub fn with_cell_width(mut self, width: usize) -> Self {
135        self.cell_width = width.max(3);
136        self
137    }
138
139    /// Show or hide values in cells.
140    #[must_use]
141    pub fn with_values(mut self, show: bool) -> Self {
142        self.show_values = show;
143        self
144    }
145
146    /// Show percentages instead of counts.
147    #[must_use]
148    pub fn with_percentages(mut self, show: bool) -> Self {
149        self.show_percentages = show;
150        self
151    }
152
153    /// Set title.
154    #[must_use]
155    pub fn with_title(mut self, title: impl Into<String>) -> Self {
156        self.title = Some(title.into());
157        self
158    }
159
160    /// Update matrix data.
161    pub fn set_matrix(&mut self, matrix: Vec<Vec<u64>>) {
162        self.matrix = matrix;
163    }
164
165    /// Get matrix dimensions.
166    #[must_use]
167    pub fn size(&self) -> usize {
168        self.matrix.len()
169    }
170
171    /// Get total count.
172    #[must_use]
173    pub fn total(&self) -> u64 {
174        self.matrix.iter().flatten().sum()
175    }
176
177    /// Get accuracy (correct / total).
178    #[must_use]
179    pub fn accuracy(&self) -> f64 {
180        let total = self.total();
181        if total == 0 {
182            return 0.0;
183        }
184        let correct: u64 = self
185            .matrix
186            .iter()
187            .enumerate()
188            .map(|(i, row)| row.get(i).copied().unwrap_or(0))
189            .sum();
190        correct as f64 / total as f64
191    }
192
193    /// Get precision for a class.
194    #[must_use]
195    pub fn precision(&self, class: usize) -> f64 {
196        let col_sum: u64 = self
197            .matrix
198            .iter()
199            .map(|row| row.get(class).copied().unwrap_or(0))
200            .sum();
201        if col_sum == 0 {
202            return 0.0;
203        }
204        self.matrix
205            .get(class)
206            .and_then(|row| row.get(class))
207            .copied()
208            .unwrap_or(0) as f64
209            / col_sum as f64
210    }
211
212    /// Get recall for a class.
213    #[must_use]
214    pub fn recall(&self, class: usize) -> f64 {
215        let row_sum: u64 = self.matrix.get(class).map_or(0, |row| row.iter().sum());
216        if row_sum == 0 {
217            return 0.0;
218        }
219        self.matrix
220            .get(class)
221            .and_then(|row| row.get(class))
222            .copied()
223            .unwrap_or(0) as f64
224            / row_sum as f64
225    }
226
227    /// Get F1 score for a class.
228    #[must_use]
229    pub fn f1_score(&self, class: usize) -> f64 {
230        let p = self.precision(class);
231        let r = self.recall(class);
232        if p + r == 0.0 {
233            return 0.0;
234        }
235        2.0 * p * r / (p + r)
236    }
237
238    fn normalize_value(&self, row: usize, col: usize, value: u64) -> f64 {
239        match self.normalization {
240            Normalization::None => {
241                let max_val = self.matrix.iter().flatten().max().copied().unwrap_or(1);
242                if max_val == 0 {
243                    0.0
244                } else {
245                    value as f64 / max_val as f64
246                }
247            }
248            Normalization::Row => {
249                let row_sum: u64 = self.matrix.get(row).map_or(1, |r| r.iter().sum());
250                if row_sum == 0 {
251                    0.0
252                } else {
253                    value as f64 / row_sum as f64
254                }
255            }
256            Normalization::Column => {
257                let col_sum: u64 = self
258                    .matrix
259                    .iter()
260                    .map(|r| r.get(col).copied().unwrap_or(0))
261                    .sum();
262                if col_sum == 0 {
263                    0.0
264                } else {
265                    value as f64 / col_sum as f64
266                }
267            }
268            Normalization::Total => {
269                let total = self.total();
270                if total == 0 {
271                    0.0
272                } else {
273                    value as f64 / total as f64
274                }
275            }
276        }
277    }
278
279    fn format_value(&self, value: u64, normalized: f64) -> String {
280        if self.show_percentages {
281            format!("{:.0}%", normalized * 100.0)
282        } else {
283            value.to_string()
284        }
285    }
286
287    fn label_width(&self) -> usize {
288        self.labels
289            .iter()
290            .map(String::len)
291            .max()
292            .unwrap_or(3)
293            .max(3)
294    }
295}
296
297impl Brick for ConfusionMatrix {
298    fn brick_name(&self) -> &'static str {
299        "confusion_matrix"
300    }
301
302    fn assertions(&self) -> &[BrickAssertion] {
303        static ASSERTIONS: &[BrickAssertion] = &[BrickAssertion::max_latency_ms(16)];
304        ASSERTIONS
305    }
306
307    fn budget(&self) -> BrickBudget {
308        BrickBudget::uniform(16)
309    }
310
311    fn verify(&self) -> BrickVerification {
312        BrickVerification {
313            passed: self.assertions().to_vec(),
314            failed: vec![],
315            verification_time: Duration::from_micros(10),
316        }
317    }
318
319    fn to_html(&self) -> String {
320        String::new()
321    }
322
323    fn to_css(&self) -> String {
324        String::new()
325    }
326}
327
328impl Widget for ConfusionMatrix {
329    fn type_id(&self) -> TypeId {
330        TypeId::of::<Self>()
331    }
332
333    fn measure(&self, constraints: Constraints) -> Size {
334        let label_w = self.label_width();
335        let n = self.size();
336        let title_rows = if self.title.is_some() { 2 } else { 0 };
337
338        // Width: label column + header labels + cells
339        let width = (label_w + 2 + n * (self.cell_width + 1)) as f32;
340        // Height: title + header row + data rows + accuracy row
341        let height = (title_rows + 1 + n + 1) as f32;
342
343        constraints.constrain(Size::new(width.min(constraints.max_width), height))
344    }
345
346    fn layout(&mut self, bounds: Rect) -> LayoutResult {
347        self.bounds = bounds;
348        LayoutResult {
349            size: Size::new(bounds.width, bounds.height),
350        }
351    }
352
353    fn paint(&self, canvas: &mut dyn Canvas) {
354        if self.matrix.is_empty() || self.bounds.width < 1.0 {
355            return;
356        }
357
358        let label_w = self.label_width();
359        let n = self.size();
360        let mut y = self.bounds.y;
361
362        let header_style = TextStyle {
363            color: Color::new(0.9, 0.9, 0.9, 1.0),
364            ..Default::default()
365        };
366
367        let dim_style = TextStyle {
368            color: Color::new(0.6, 0.6, 0.6, 1.0),
369            ..Default::default()
370        };
371
372        // Draw title
373        if let Some(ref title) = self.title {
374            canvas.draw_text(title, Point::new(self.bounds.x, y), &header_style);
375            y += 2.0;
376        }
377
378        // Draw header row (predicted labels)
379        let header_x = self.bounds.x + label_w as f32 + 2.0;
380        canvas.draw_text("Pred→", Point::new(self.bounds.x, y), &dim_style);
381        for (i, label) in self.labels.iter().enumerate().take(n) {
382            let x = header_x + (i * (self.cell_width + 1)) as f32;
383            let truncated = if label.len() > self.cell_width {
384                &label[..self.cell_width]
385            } else {
386                label
387            };
388            canvas.draw_text(truncated, Point::new(x, y), &header_style);
389        }
390        y += 1.0;
391
392        // Draw matrix rows
393        for (row_idx, row) in self.matrix.iter().enumerate().take(n) {
394            // Row label (actual)
395            let label = self.labels.get(row_idx).map_or("?", String::as_str);
396            let truncated = if label.len() > label_w {
397                &label[..label_w]
398            } else {
399                label
400            };
401            canvas.draw_text(truncated, Point::new(self.bounds.x, y), &header_style);
402
403            // Cells
404            for (col_idx, &value) in row.iter().enumerate().take(n) {
405                let x = header_x + (col_idx * (self.cell_width + 1)) as f32;
406                let normalized = self.normalize_value(row_idx, col_idx, value);
407                let is_diagonal = row_idx == col_idx;
408
409                // Draw cell background
410                let bg_color = self.palette.color(normalized, is_diagonal);
411                canvas.fill_rect(Rect::new(x, y, self.cell_width as f32, 1.0), bg_color);
412
413                // Draw value
414                if self.show_values {
415                    let text = self.format_value(value, normalized);
416                    let text_color = if normalized > 0.5 {
417                        Color::new(0.0, 0.0, 0.0, 1.0) // Dark text on light bg
418                    } else {
419                        Color::new(1.0, 1.0, 1.0, 1.0) // Light text on dark bg
420                    };
421                    let value_style = TextStyle {
422                        color: text_color,
423                        ..Default::default()
424                    };
425                    canvas.draw_text(&text, Point::new(x, y), &value_style);
426                }
427            }
428            y += 1.0;
429        }
430
431        // Draw accuracy
432        let accuracy = self.accuracy();
433        let acc_text = format!("Accuracy: {:.1}%", accuracy * 100.0);
434        canvas.draw_text(&acc_text, Point::new(self.bounds.x, y), &header_style);
435    }
436
437    fn event(&mut self, _event: &Event) -> Option<Box<dyn Any + Send>> {
438        None
439    }
440
441    fn children(&self) -> &[Box<dyn Widget>] {
442        &[]
443    }
444
445    fn children_mut(&mut self) -> &mut [Box<dyn Widget>] {
446        &mut []
447    }
448}
449
#[cfg(test)]
mod tests {
    use super::*;

    /// Recording canvas. It keeps every `draw_text` and `fill_rect` call and
    /// ignores all other drawing operations.
    struct MockCanvas {
        texts: Vec<(String, Point)>,
        rects: Vec<(Rect, Color)>,
    }

    impl MockCanvas {
        fn new() -> Self {
            Self {
                texts: Vec::new(),
                rects: Vec::new(),
            }
        }

        /// True when any recorded text satisfies `pred`.
        fn has_text(&self, pred: impl Fn(&str) -> bool) -> bool {
            self.texts.iter().any(|(text, _)| pred(text))
        }
    }

    impl Canvas for MockCanvas {
        fn fill_rect(&mut self, rect: Rect, color: Color) {
            self.rects.push((rect, color));
        }
        fn stroke_rect(&mut self, _rect: Rect, _color: Color, _width: f32) {}
        fn draw_text(&mut self, text: &str, position: Point, _style: &TextStyle) {
            self.texts.push((text.to_owned(), position));
        }
        fn draw_line(&mut self, _from: Point, _to: Point, _color: Color, _width: f32) {}
        fn fill_circle(&mut self, _center: Point, _radius: f32, _color: Color) {}
        fn stroke_circle(&mut self, _center: Point, _radius: f32, _color: Color, _width: f32) {}
        fn fill_arc(&mut self, _c: Point, _r: f32, _s: f32, _e: f32, _color: Color) {}
        fn draw_path(&mut self, _points: &[Point], _color: Color, _width: f32) {}
        fn fill_polygon(&mut self, _points: &[Point], _color: Color) {}
        fn push_clip(&mut self, _rect: Rect) {}
        fn pop_clip(&mut self) {}
        fn push_transform(&mut self, _transform: presentar_core::Transform2D) {}
        fn pop_transform(&mut self) {}
    }

    /// 2×2 fixture: 10 + 15 correct, 2 + 3 misclassified (30 in total).
    fn sample() -> ConfusionMatrix {
        ConfusionMatrix::new(vec![vec![10, 2], vec![3, 15]])
    }

    /// All-zero 2×2 matrix.
    fn zeros() -> ConfusionMatrix {
        ConfusionMatrix::new(vec![vec![0, 0], vec![0, 0]])
    }

    /// 1×1 matrix holding a single count of 1.
    fn unit() -> ConfusionMatrix {
        ConfusionMatrix::new(vec![vec![1]])
    }

    /// 1×1 matrix holding a single count of 5 (used by the builder tests).
    fn five() -> ConfusionMatrix {
        ConfusionMatrix::new(vec![vec![5]])
    }

    /// Bounds comfortably large enough for the 2×2 fixtures.
    fn roomy() -> Rect {
        Rect::new(0.0, 0.0, 50.0, 15.0)
    }

    /// Assign `bounds` directly, then paint `cm` into a fresh recording canvas.
    fn render(mut cm: ConfusionMatrix, bounds: Rect) -> MockCanvas {
        cm.bounds = bounds;
        let mut canvas = MockCanvas::new();
        cm.paint(&mut canvas);
        canvas
    }

    // ---- construction and metrics -------------------------------------

    #[test]
    fn test_confusion_matrix_creation() {
        assert_eq!(sample().size(), 2);
    }

    #[test]
    fn test_confusion_matrix_total() {
        assert_eq!(sample().total(), 30);
    }

    #[test]
    fn test_confusion_matrix_accuracy() {
        // 25 correct out of 30.
        assert!((sample().accuracy() - 0.833).abs() < 0.01);
    }

    #[test]
    fn test_confusion_matrix_precision() {
        // Class 0 column: 10 + 3 = 13, of which 10 are correct.
        assert!((sample().precision(0) - 0.769).abs() < 0.01);
    }

    #[test]
    fn test_confusion_matrix_recall() {
        // Class 0 row: 10 + 2 = 12, of which 10 are correct.
        assert!((sample().recall(0) - 0.833).abs() < 0.01);
    }

    #[test]
    fn test_confusion_matrix_f1() {
        let score = sample().f1_score(0);
        assert!(score > 0.0 && score < 1.0);
    }

    // ---- builders -----------------------------------------------------

    #[test]
    fn test_confusion_matrix_with_labels() {
        let cm = ConfusionMatrix::new(vec![vec![5, 1], vec![2, 8]])
            .with_labels(vec!["Cat".to_string(), "Dog".to_string()]);
        assert_eq!(cm.labels.len(), 2);
        assert_eq!(cm.labels[0], "Cat");
    }

    #[test]
    fn test_confusion_matrix_with_normalization() {
        let cm = five().with_normalization(Normalization::Row);
        assert_eq!(cm.normalization, Normalization::Row);
    }

    #[test]
    fn test_confusion_matrix_with_palette() {
        let cm = five().with_palette(MatrixPalette::DiagonalGreen);
        assert_eq!(cm.palette, MatrixPalette::DiagonalGreen);
    }

    #[test]
    fn test_confusion_matrix_with_cell_width() {
        assert_eq!(five().with_cell_width(10).cell_width, 10);
    }

    #[test]
    fn test_confusion_matrix_with_cell_width_min() {
        // Widths below the minimum of 3 are raised to 3.
        assert_eq!(five().with_cell_width(1).cell_width, 3);
    }

    #[test]
    fn test_confusion_matrix_with_values() {
        assert!(!five().with_values(false).show_values);
    }

    #[test]
    fn test_confusion_matrix_with_percentages() {
        assert!(five().with_percentages(true).show_percentages);
    }

    #[test]
    fn test_confusion_matrix_with_title() {
        let cm = five().with_title("Test Matrix");
        assert_eq!(cm.title, Some("Test Matrix".to_string()));
    }

    #[test]
    fn test_confusion_matrix_set_matrix() {
        let mut cm = unit();
        cm.set_matrix(vec![vec![2, 3], vec![4, 5]]);
        assert_eq!(cm.size(), 2);
    }

    // ---- widget / brick plumbing --------------------------------------

    #[test]
    fn test_confusion_matrix_paint() {
        let canvas = render(sample(), Rect::new(0.0, 0.0, 40.0, 10.0));
        assert!(!canvas.texts.is_empty());
        assert!(!canvas.rects.is_empty());
    }

    #[test]
    fn test_confusion_matrix_paint_empty() {
        // An empty matrix paints nothing; bounds stay at their default.
        let canvas = render(ConfusionMatrix::new(vec![]), Rect::default());
        assert!(canvas.texts.is_empty());
    }

    #[test]
    fn test_confusion_matrix_measure() {
        let measured = sample().measure(Constraints::loose(Size::new(100.0, 50.0)));
        assert!(measured.width > 0.0);
        assert!(measured.height > 0.0);
    }

    #[test]
    fn test_confusion_matrix_layout() {
        let mut cm = unit();
        let bounds = Rect::new(5.0, 10.0, 30.0, 20.0);
        let laid_out = cm.layout(bounds);
        assert_eq!(laid_out.size.width, 30.0);
        assert_eq!(cm.bounds, bounds);
    }

    #[test]
    fn test_confusion_matrix_brick_name() {
        assert_eq!(unit().brick_name(), "confusion_matrix");
    }

    #[test]
    fn test_confusion_matrix_assertions() {
        assert!(!unit().assertions().is_empty());
    }

    #[test]
    fn test_confusion_matrix_budget() {
        assert!(unit().budget().paint_ms > 0);
    }

    #[test]
    fn test_confusion_matrix_verify() {
        assert!(unit().verify().is_valid());
    }

    #[test]
    fn test_confusion_matrix_type_id() {
        let cm = unit();
        assert_eq!(Widget::type_id(&cm), TypeId::of::<ConfusionMatrix>());
    }

    #[test]
    fn test_confusion_matrix_children() {
        assert!(unit().children().is_empty());
    }

    #[test]
    fn test_confusion_matrix_children_mut() {
        let mut cm = unit();
        assert!(cm.children_mut().is_empty());
    }

    #[test]
    fn test_confusion_matrix_event() {
        let mut cm = unit();
        let key_press = Event::KeyDown {
            key: presentar_core::Key::Enter,
        };
        assert!(cm.event(&key_press).is_none());
    }

    #[test]
    fn test_confusion_matrix_default() {
        assert_eq!(ConfusionMatrix::default().size(), 1);
    }

    #[test]
    fn test_confusion_matrix_to_html() {
        assert!(unit().to_html().is_empty());
    }

    #[test]
    fn test_confusion_matrix_to_css() {
        assert!(unit().to_css().is_empty());
    }

    // ---- palettes and enums -------------------------------------------

    #[test]
    fn test_palette_blue_red() {
        let palette = MatrixPalette::BlueRed;
        let cold = palette.color(0.0, false);
        let hot = palette.color(1.0, false);
        assert!(cold.b > cold.r); // Blue dominant for low
        assert!(hot.r > hot.b); // Red dominant for high
    }

    #[test]
    fn test_palette_diagonal_green() {
        let palette = MatrixPalette::DiagonalGreen;
        let on_diag = palette.color(0.8, true);
        let off_diag = palette.color(0.8, false);
        assert!(on_diag.g > on_diag.r); // Green for diagonal
        assert!(off_diag.r > off_diag.g); // Red for off-diagonal
    }

    #[test]
    fn test_palette_grayscale() {
        let gray = MatrixPalette::Grayscale.color(0.5, false);
        assert!((gray.r - gray.g).abs() < 0.01);
        assert!((gray.g - gray.b).abs() < 0.01);
    }

    #[test]
    fn test_normalization_default() {
        assert_eq!(Normalization::default(), Normalization::None);
    }

    // ---- zero-count edge cases ----------------------------------------

    #[test]
    fn test_zero_accuracy() {
        assert_eq!(zeros().accuracy(), 0.0);
    }

    #[test]
    fn test_zero_precision() {
        assert_eq!(zeros().precision(0), 0.0);
    }

    #[test]
    fn test_zero_recall() {
        assert_eq!(zeros().recall(0), 0.0);
    }

    #[test]
    fn test_zero_f1() {
        assert_eq!(zeros().f1_score(0), 0.0);
    }

    // ---- painting variants --------------------------------------------

    #[test]
    fn test_paint_with_title() {
        let canvas = render(sample().with_title("Classification Results"), roomy());
        // Title should be in the texts
        assert!(canvas.has_text(|t| t.contains("Classification")));
    }

    #[test]
    fn test_paint_with_row_normalization() {
        let canvas = render(sample().with_normalization(Normalization::Row), roomy());
        assert!(!canvas.rects.is_empty());
    }

    #[test]
    fn test_paint_with_column_normalization() {
        let canvas = render(sample().with_normalization(Normalization::Column), roomy());
        assert!(!canvas.rects.is_empty());
    }

    #[test]
    fn test_paint_with_total_normalization() {
        let canvas = render(sample().with_normalization(Normalization::Total), roomy());
        assert!(!canvas.rects.is_empty());
    }

    #[test]
    fn test_paint_with_percentages() {
        let canvas = render(sample().with_percentages(true), roomy());
        // Should show percentage values
        assert!(canvas.has_text(|t| t.contains('%')));
    }

    #[test]
    fn test_normalize_value_none_zero_max() {
        // All-zero matrix: the largest cell is 0, which must not break painting.
        let canvas = render(zeros(), roomy());
        assert!(!canvas.rects.is_empty());
    }

    #[test]
    fn test_normalize_row_zero_sum() {
        // First row sums to zero.
        let cm = ConfusionMatrix::new(vec![vec![0, 0], vec![3, 15]])
            .with_normalization(Normalization::Row);
        let canvas = render(cm, roomy());
        assert!(!canvas.rects.is_empty());
    }

    #[test]
    fn test_normalize_column_zero_sum() {
        // Second column sums to zero.
        let cm = ConfusionMatrix::new(vec![vec![10, 0], vec![3, 0]])
            .with_normalization(Normalization::Column);
        let canvas = render(cm, roomy());
        assert!(!canvas.rects.is_empty());
    }

    #[test]
    fn test_normalize_total_zero() {
        let canvas = render(zeros().with_normalization(Normalization::Total), roomy());
        assert!(!canvas.rects.is_empty());
    }

    #[test]
    fn test_paint_without_values() {
        // Cells are still filled even though no values are written into them.
        let canvas = render(sample().with_values(false), roomy());
        assert!(!canvas.rects.is_empty());
    }

    #[test]
    fn test_paint_long_labels() {
        let cm = sample().with_labels(vec![
            "VeryLongLabelName".to_string(),
            "AnotherLongLabel".to_string(),
        ]);
        let canvas = render(cm, roomy());
        // Long labels are truncated rather than breaking the paint.
        assert!(!canvas.texts.is_empty());
    }
}