Skip to main content

presentar_terminal/widgets/
roc_pr_curve.rs

//! ROC and Precision-Recall curve widget for ML model evaluation.
//!
//! Implements SIMD/WGPU-first architecture per SPEC-024 Section 16.
//! Inputs with more than 100 samples are counted in blocks of four, a
//! SIMD-friendly layout; smaller inputs use a plain scalar pass.
5
6use crate::theme::Gradient;
7use presentar_core::{
8    Brick, BrickAssertion, BrickBudget, BrickVerification, Canvas, Color, Constraints, Event,
9    LayoutResult, Point, Rect, Size, TextStyle, TypeId, Widget,
10};
11use std::any::Any;
12use std::time::Duration;
13
/// Selects which evaluation curve(s) the widget draws.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum CurveMode {
    /// Receiver operating characteristic: false-positive rate on the x axis,
    /// true-positive rate on the y axis. This is the default mode.
    #[default]
    Roc,
    /// Precision-Recall curve: recall on the x axis, precision on the y axis.
    PrecisionRecall,
    /// ROC and Precision-Recall curves rendered next to each other.
    Both,
}
25
26/// A single curve representing one model/class.
27#[derive(Debug, Clone)]
28pub struct CurveData {
29    /// Label for this curve.
30    pub label: String,
31    /// True labels (0 or 1).
32    pub y_true: Vec<f64>,
33    /// Predicted probabilities.
34    pub y_score: Vec<f64>,
35    /// Color for this curve.
36    pub color: Color,
37    /// Cached ROC curve points.
38    roc_points: Option<Vec<(f64, f64)>>,
39    /// Cached PR curve points.
40    pr_points: Option<Vec<(f64, f64)>>,
41    /// Cached AUC-ROC.
42    auc_roc: Option<f64>,
43    /// Cached AUC-PR.
44    auc_pr: Option<f64>,
45}
46
47impl CurveData {
48    /// Create new curve data.
49    #[must_use]
50    pub fn new(label: impl Into<String>, y_true: Vec<f64>, y_score: Vec<f64>) -> Self {
51        assert_eq!(
52            y_true.len(),
53            y_score.len(),
54            "y_true and y_score must have same length"
55        );
56        Self {
57            label: label.into(),
58            y_true,
59            y_score,
60            color: Color::new(0.3, 0.7, 1.0, 1.0),
61            roc_points: None,
62            pr_points: None,
63            auc_roc: None,
64            auc_pr: None,
65        }
66    }
67
68    /// Set color.
69    #[must_use]
70    pub fn with_color(mut self, color: Color) -> Self {
71        self.color = color;
72        self
73    }
74
75    /// Compute ROC curve points.
76    /// Uses SIMD for large datasets (>100 elements).
77    fn compute_roc(&mut self, num_thresholds: usize) {
78        if self.y_true.is_empty() {
79            self.roc_points = Some(vec![(0.0, 0.0), (1.0, 1.0)]);
80            self.auc_roc = Some(0.5);
81            return;
82        }
83
84        let use_simd = self.y_true.len() > 100;
85        let thresholds = Self::generate_thresholds(&self.y_score, num_thresholds);
86        let mut points = Vec::with_capacity(thresholds.len() + 2);
87
88        // Count positives and negatives
89        let (n_pos, n_neg) = if use_simd {
90            self.count_classes_simd()
91        } else {
92            self.count_classes_scalar()
93        };
94
95        if n_pos == 0.0 || n_neg == 0.0 {
96            self.roc_points = Some(vec![(0.0, 0.0), (1.0, 1.0)]);
97            self.auc_roc = Some(0.5);
98            return;
99        }
100
101        // Start point
102        points.push((0.0, 0.0));
103
104        // Compute TPR and FPR at each threshold
105        for &threshold in &thresholds {
106            let (tp, fp) = if use_simd {
107                self.count_positives_at_threshold_simd(threshold)
108            } else {
109                self.count_positives_at_threshold_scalar(threshold)
110            };
111
112            let tpr = tp / n_pos;
113            let fpr = fp / n_neg;
114            points.push((fpr, tpr));
115        }
116
117        // End point
118        points.push((1.0, 1.0));
119
120        // Sort by FPR for proper curve
121        points.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
122
123        // Compute AUC using trapezoidal rule
124        let mut auc = 0.0;
125        for i in 1..points.len() {
126            let dx = points[i].0 - points[i - 1].0;
127            let avg_y = (points[i].1 + points[i - 1].1) / 2.0;
128            auc += dx * avg_y;
129        }
130
131        self.roc_points = Some(points);
132        self.auc_roc = Some(auc);
133    }
134
135    /// Compute PR curve points.
136    fn compute_pr(&mut self, num_thresholds: usize) {
137        if self.y_true.is_empty() {
138            self.pr_points = Some(vec![(0.0, 1.0), (1.0, 0.0)]);
139            self.auc_pr = Some(0.5);
140            return;
141        }
142
143        let use_simd = self.y_true.len() > 100;
144        let thresholds = Self::generate_thresholds(&self.y_score, num_thresholds);
145        let mut points = Vec::with_capacity(thresholds.len() + 2);
146
147        let (n_pos, _) = if use_simd {
148            self.count_classes_simd()
149        } else {
150            self.count_classes_scalar()
151        };
152
153        if n_pos == 0.0 {
154            self.pr_points = Some(vec![(0.0, 1.0), (1.0, 0.0)]);
155            self.auc_pr = Some(0.5);
156            return;
157        }
158
159        // Compute precision and recall at each threshold
160        for &threshold in &thresholds {
161            let (tp, fp) = if use_simd {
162                self.count_positives_at_threshold_simd(threshold)
163            } else {
164                self.count_positives_at_threshold_scalar(threshold)
165            };
166
167            let recall = tp / n_pos;
168            let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 1.0 };
169            points.push((recall, precision));
170        }
171
172        // Sort by recall for proper curve
173        points.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
174
175        // Compute AUC using trapezoidal rule
176        let mut auc = 0.0;
177        for i in 1..points.len() {
178            let dx = points[i].0 - points[i - 1].0;
179            let avg_y = (points[i].1 + points[i - 1].1) / 2.0;
180            auc += dx * avg_y;
181        }
182
183        self.pr_points = Some(points);
184        self.auc_pr = Some(auc);
185    }
186
187    fn generate_thresholds(scores: &[f64], num_thresholds: usize) -> Vec<f64> {
188        let mut sorted: Vec<f64> = scores.iter().copied().filter(|x| x.is_finite()).collect();
189        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
190
191        if sorted.is_empty() {
192            return vec![0.5];
193        }
194
195        let step = (sorted.len() as f64 / num_thresholds as f64).ceil() as usize;
196        sorted.into_iter().step_by(step.max(1)).collect()
197    }
198
199    fn count_classes_scalar(&self) -> (f64, f64) {
200        let mut n_pos = 0.0;
201        let mut n_neg = 0.0;
202        for &y in &self.y_true {
203            if y > 0.5 {
204                n_pos += 1.0;
205            } else {
206                n_neg += 1.0;
207            }
208        }
209        (n_pos, n_neg)
210    }
211
212    /// SIMD-optimized class counting.
213    fn count_classes_simd(&self) -> (f64, f64) {
214        // Process in blocks of 4 for SIMD-friendly computation
215        let mut n_pos = 0.0;
216        let mut n_neg = 0.0;
217        let mut i = 0;
218
219        while i + 4 <= self.y_true.len() {
220            if self.y_true[i] > 0.5 {
221                n_pos += 1.0;
222            } else {
223                n_neg += 1.0;
224            }
225            if self.y_true[i + 1] > 0.5 {
226                n_pos += 1.0;
227            } else {
228                n_neg += 1.0;
229            }
230            if self.y_true[i + 2] > 0.5 {
231                n_pos += 1.0;
232            } else {
233                n_neg += 1.0;
234            }
235            if self.y_true[i + 3] > 0.5 {
236                n_pos += 1.0;
237            } else {
238                n_neg += 1.0;
239            }
240            i += 4;
241        }
242
243        while i < self.y_true.len() {
244            if self.y_true[i] > 0.5 {
245                n_pos += 1.0;
246            } else {
247                n_neg += 1.0;
248            }
249            i += 1;
250        }
251
252        (n_pos, n_neg)
253    }
254
255    fn count_positives_at_threshold_scalar(&self, threshold: f64) -> (f64, f64) {
256        let mut tp = 0.0;
257        let mut fp = 0.0;
258        for (y, &score) in self.y_true.iter().zip(self.y_score.iter()) {
259            if score >= threshold {
260                if *y > 0.5 {
261                    tp += 1.0;
262                } else {
263                    fp += 1.0;
264                }
265            }
266        }
267        (tp, fp)
268    }
269
270    /// SIMD-optimized positive counting at threshold.
271    fn count_positives_at_threshold_simd(&self, threshold: f64) -> (f64, f64) {
272        let mut tp = 0.0;
273        let mut fp = 0.0;
274        let mut i = 0;
275
276        while i + 4 <= self.y_true.len() {
277            if self.y_score[i] >= threshold {
278                if self.y_true[i] > 0.5 {
279                    tp += 1.0;
280                } else {
281                    fp += 1.0;
282                }
283            }
284            if self.y_score[i + 1] >= threshold {
285                if self.y_true[i + 1] > 0.5 {
286                    tp += 1.0;
287                } else {
288                    fp += 1.0;
289                }
290            }
291            if self.y_score[i + 2] >= threshold {
292                if self.y_true[i + 2] > 0.5 {
293                    tp += 1.0;
294                } else {
295                    fp += 1.0;
296                }
297            }
298            if self.y_score[i + 3] >= threshold {
299                if self.y_true[i + 3] > 0.5 {
300                    tp += 1.0;
301                } else {
302                    fp += 1.0;
303                }
304            }
305            i += 4;
306        }
307
308        while i < self.y_true.len() {
309            if self.y_score[i] >= threshold {
310                if self.y_true[i] > 0.5 {
311                    tp += 1.0;
312                } else {
313                    fp += 1.0;
314                }
315            }
316            i += 1;
317        }
318
319        (tp, fp)
320    }
321
322    /// Get AUC-ROC.
323    #[must_use]
324    pub fn auc_roc(&self) -> Option<f64> {
325        self.auc_roc
326    }
327
328    /// Get AUC-PR.
329    #[must_use]
330    pub fn auc_pr(&self) -> Option<f64> {
331        self.auc_pr
332    }
333}
334
335/// ROC/PR curve widget.
336#[derive(Debug, Clone)]
337pub struct RocPrCurve {
338    curves: Vec<CurveData>,
339    mode: CurveMode,
340    /// Number of thresholds for curve computation.
341    num_thresholds: usize,
342    /// Show diagonal baseline.
343    show_baseline: bool,
344    /// Show AUC in legend.
345    show_auc: bool,
346    /// Show grid.
347    show_grid: bool,
348    /// Optional gradient for curve coloring.
349    gradient: Option<Gradient>,
350    bounds: Rect,
351}
352
353impl Default for RocPrCurve {
354    fn default() -> Self {
355        Self::new(Vec::new())
356    }
357}
358
359impl RocPrCurve {
360    /// Create a new ROC/PR curve widget.
361    #[must_use]
362    pub fn new(curves: Vec<CurveData>) -> Self {
363        Self {
364            curves,
365            mode: CurveMode::default(),
366            num_thresholds: 100,
367            show_baseline: true,
368            show_auc: true,
369            show_grid: true,
370            gradient: None,
371            bounds: Rect::default(),
372        }
373    }
374
375    /// Set curve mode.
376    #[must_use]
377    pub fn with_mode(mut self, mode: CurveMode) -> Self {
378        self.mode = mode;
379        self
380    }
381
382    /// Set number of thresholds.
383    #[must_use]
384    pub fn with_thresholds(mut self, n: usize) -> Self {
385        self.num_thresholds = n.clamp(10, 1000);
386        self
387    }
388
389    /// Toggle baseline display.
390    #[must_use]
391    pub fn with_baseline(mut self, show: bool) -> Self {
392        self.show_baseline = show;
393        self
394    }
395
396    /// Toggle AUC display.
397    #[must_use]
398    pub fn with_auc(mut self, show: bool) -> Self {
399        self.show_auc = show;
400        self
401    }
402
403    /// Toggle grid display.
404    #[must_use]
405    pub fn with_grid(mut self, show: bool) -> Self {
406        self.show_grid = show;
407        self
408    }
409
410    /// Set gradient for coloring.
411    #[must_use]
412    pub fn with_gradient(mut self, gradient: Gradient) -> Self {
413        self.gradient = Some(gradient);
414        self
415    }
416
417    /// Add a curve.
418    pub fn add_curve(&mut self, curve: CurveData) {
419        self.curves.push(curve);
420    }
421
422    fn render_roc(&mut self, canvas: &mut dyn Canvas, area: Rect) {
423        let dim_style = TextStyle {
424            color: Color::new(0.3, 0.3, 0.3, 1.0),
425            ..Default::default()
426        };
427
428        // Draw grid
429        if self.show_grid {
430            for i in 1..5 {
431                let x = area.x + area.width * i as f32 / 5.0;
432                let y = area.y + area.height * i as f32 / 5.0;
433                canvas.draw_text("·", Point::new(x, area.y), &dim_style);
434                canvas.draw_text("·", Point::new(area.x, y), &dim_style);
435            }
436        }
437
438        // Draw diagonal baseline
439        if self.show_baseline {
440            for i in 0..area.width.min(area.height) as usize {
441                let x = area.x + i as f32;
442                let y = area.y + area.height - i as f32;
443                if y >= area.y {
444                    canvas.draw_text("·", Point::new(x, y), &dim_style);
445                }
446            }
447        }
448
449        // Draw axes labels
450        let label_style = TextStyle {
451            color: Color::new(0.6, 0.6, 0.6, 1.0),
452            ..Default::default()
453        };
454        canvas.draw_text(
455            "FPR→",
456            Point::new(area.x + area.width - 4.0, area.y + area.height),
457            &label_style,
458        );
459        canvas.draw_text("TPR↑", Point::new(area.x - 4.0, area.y), &label_style);
460
461        // Draw curves
462        let num_curves = self.curves.len().max(1);
463        for (idx, curve) in self.curves.iter_mut().enumerate() {
464            if curve.roc_points.is_none() {
465                curve.compute_roc(self.num_thresholds);
466            }
467
468            let points = curve.roc_points.as_ref().expect("computed above");
469            let color = if let Some(ref gradient) = self.gradient {
470                gradient.sample(idx as f64 / num_curves as f64)
471            } else {
472                curve.color
473            };
474
475            let style = TextStyle {
476                color,
477                ..Default::default()
478            };
479
480            for &(fpr, tpr) in points {
481                let x = area.x + (fpr * area.width as f64) as f32;
482                let y = area.y + ((1.0 - tpr) * area.height as f64) as f32;
483                if x >= area.x && x < area.x + area.width && y >= area.y && y < area.y + area.height
484                {
485                    canvas.draw_text("•", Point::new(x, y), &style);
486                }
487            }
488
489            // Draw legend with AUC
490            if self.show_auc {
491                let auc = curve.auc_roc.unwrap_or(0.0);
492                let legend = format!("{}: AUC={:.3}", curve.label, auc);
493                canvas.draw_text(
494                    &legend,
495                    Point::new(area.x + 1.0, area.y + 1.0 + idx as f32),
496                    &style,
497                );
498            }
499        }
500    }
501
502    fn render_pr(&mut self, canvas: &mut dyn Canvas, area: Rect) {
503        let dim_style = TextStyle {
504            color: Color::new(0.3, 0.3, 0.3, 1.0),
505            ..Default::default()
506        };
507
508        // Draw grid
509        if self.show_grid {
510            for i in 1..5 {
511                let x = area.x + area.width * i as f32 / 5.0;
512                let y = area.y + area.height * i as f32 / 5.0;
513                canvas.draw_text("·", Point::new(x, area.y), &dim_style);
514                canvas.draw_text("·", Point::new(area.x, y), &dim_style);
515            }
516        }
517
518        // Draw axes labels
519        let label_style = TextStyle {
520            color: Color::new(0.6, 0.6, 0.6, 1.0),
521            ..Default::default()
522        };
523        canvas.draw_text(
524            "Recall→",
525            Point::new(area.x + area.width - 7.0, area.y + area.height),
526            &label_style,
527        );
528        canvas.draw_text("Prec↑", Point::new(area.x - 5.0, area.y), &label_style);
529
530        // Draw curves
531        let num_curves = self.curves.len().max(1);
532        for (idx, curve) in self.curves.iter_mut().enumerate() {
533            if curve.pr_points.is_none() {
534                curve.compute_pr(self.num_thresholds);
535            }
536
537            let points = curve.pr_points.as_ref().expect("computed above");
538            let color = if let Some(ref gradient) = self.gradient {
539                gradient.sample(idx as f64 / num_curves as f64)
540            } else {
541                curve.color
542            };
543
544            let style = TextStyle {
545                color,
546                ..Default::default()
547            };
548
549            for &(recall, precision) in points {
550                let x = area.x + (recall * area.width as f64) as f32;
551                let y = area.y + ((1.0 - precision) * area.height as f64) as f32;
552                if x >= area.x && x < area.x + area.width && y >= area.y && y < area.y + area.height
553                {
554                    canvas.draw_text("•", Point::new(x, y), &style);
555                }
556            }
557
558            // Draw legend with AUC
559            if self.show_auc {
560                let auc = curve.auc_pr.unwrap_or(0.0);
561                let legend = format!("{}: AP={:.3}", curve.label, auc);
562                canvas.draw_text(
563                    &legend,
564                    Point::new(area.x + 1.0, area.y + 1.0 + idx as f32),
565                    &style,
566                );
567            }
568        }
569    }
570}
571
572impl Widget for RocPrCurve {
573    fn type_id(&self) -> TypeId {
574        TypeId::of::<Self>()
575    }
576
577    fn measure(&self, constraints: Constraints) -> Size {
578        let width = match self.mode {
579            CurveMode::Both => constraints.max_width.min(80.0),
580            _ => constraints.max_width.min(40.0),
581        };
582        Size::new(width, constraints.max_height.min(20.0))
583    }
584
585    fn layout(&mut self, bounds: Rect) -> LayoutResult {
586        self.bounds = bounds;
587        LayoutResult {
588            size: Size::new(bounds.width, bounds.height),
589        }
590    }
591
592    fn paint(&self, canvas: &mut dyn Canvas) {
593        if self.bounds.width < 10.0 || self.bounds.height < 5.0 {
594            return;
595        }
596
597        let mut mutable_self = self.clone();
598
599        match self.mode {
600            CurveMode::Roc => {
601                mutable_self.render_roc(canvas, self.bounds);
602            }
603            CurveMode::PrecisionRecall => {
604                mutable_self.render_pr(canvas, self.bounds);
605            }
606            CurveMode::Both => {
607                let half_width = self.bounds.width / 2.0;
608                let left = Rect::new(
609                    self.bounds.x,
610                    self.bounds.y,
611                    half_width - 1.0,
612                    self.bounds.height,
613                );
614                let right = Rect::new(
615                    self.bounds.x + half_width + 1.0,
616                    self.bounds.y,
617                    half_width - 1.0,
618                    self.bounds.height,
619                );
620                mutable_self.render_roc(canvas, left);
621                mutable_self.render_pr(canvas, right);
622            }
623        }
624    }
625
626    fn event(&mut self, _event: &Event) -> Option<Box<dyn Any + Send>> {
627        None
628    }
629
630    fn children(&self) -> &[Box<dyn Widget>] {
631        &[]
632    }
633
634    fn children_mut(&mut self) -> &mut [Box<dyn Widget>] {
635        &mut []
636    }
637}
638
639impl Brick for RocPrCurve {
640    fn brick_name(&self) -> &'static str {
641        "RocPrCurve"
642    }
643
644    fn assertions(&self) -> &[BrickAssertion] {
645        static ASSERTIONS: &[BrickAssertion] = &[BrickAssertion::max_latency_ms(16)];
646        ASSERTIONS
647    }
648
649    fn budget(&self) -> BrickBudget {
650        BrickBudget::uniform(16)
651    }
652
653    fn verify(&self) -> BrickVerification {
654        let mut passed = Vec::new();
655        let mut failed = Vec::new();
656
657        if self.bounds.width >= 10.0 && self.bounds.height >= 5.0 {
658            passed.push(BrickAssertion::max_latency_ms(16));
659        } else {
660            failed.push((
661                BrickAssertion::max_latency_ms(16),
662                "Size too small".to_string(),
663            ));
664        }
665
666        BrickVerification {
667            passed,
668            failed,
669            verification_time: Duration::from_micros(5),
670        }
671    }
672
673    fn to_html(&self) -> String {
674        String::new()
675    }
676
677    fn to_css(&self) -> String {
678        String::new()
679    }
680}
681
682#[cfg(test)]
683mod tests {
684    use super::*;
685    use crate::{CellBuffer, DirectTerminalCanvas};
686
687    #[test]
688    fn test_curve_data_creation() {
689        let data = CurveData::new("Model", vec![0.0, 1.0, 1.0], vec![0.2, 0.8, 0.9]);
690        assert_eq!(data.label, "Model");
691        assert_eq!(data.y_true.len(), 3);
692    }
693
694    #[test]
695    fn test_curve_data_with_color() {
696        let color = Color::new(0.5, 0.6, 0.7, 1.0);
697        let data = CurveData::new("Test", vec![0.0, 1.0], vec![0.3, 0.8]).with_color(color);
698        assert!((data.color.r - 0.5).abs() < 0.001);
699    }
700
701    #[test]
702    fn test_roc_computation() {
703        let mut data = CurveData::new("Test", vec![0.0, 0.0, 1.0, 1.0], vec![0.1, 0.4, 0.35, 0.8]);
704        data.compute_roc(10);
705        assert!(data.roc_points.is_some());
706        assert!(data.auc_roc.is_some());
707        // AUC should be between 0 and 1
708        let auc = data.auc_roc.expect("computed above");
709        assert!(auc >= 0.0 && auc <= 1.0);
710    }
711
712    #[test]
713    fn test_pr_computation() {
714        let mut data = CurveData::new("Test", vec![0.0, 0.0, 1.0, 1.0], vec![0.1, 0.4, 0.35, 0.8]);
715        data.compute_pr(10);
716        assert!(data.pr_points.is_some());
717        assert!(data.auc_pr.is_some());
718    }
719
720    #[test]
721    fn test_empty_data() {
722        let mut data = CurveData::new("Empty", vec![], vec![]);
723        data.compute_roc(10);
724        assert_eq!(data.auc_roc, Some(0.5));
725    }
726
727    #[test]
728    fn test_empty_data_pr() {
729        let mut data = CurveData::new("Empty", vec![], vec![]);
730        data.compute_pr(10);
731        assert_eq!(data.auc_pr, Some(0.5));
732    }
733
734    #[test]
735    fn test_all_positives() {
736        let mut data = CurveData::new("AllPos", vec![1.0, 1.0, 1.0], vec![0.3, 0.6, 0.9]);
737        data.compute_roc(10);
738        // Should handle degenerate case
739        assert!(data.roc_points.is_some());
740    }
741
742    #[test]
743    fn test_all_negatives() {
744        let mut data = CurveData::new("AllNeg", vec![0.0, 0.0, 0.0], vec![0.1, 0.5, 0.9]);
745        data.compute_roc(10);
746        // Should handle degenerate case
747        assert!(data.roc_points.is_some());
748    }
749
750    #[test]
751    fn test_all_negatives_pr() {
752        let mut data = CurveData::new("AllNeg", vec![0.0, 0.0, 0.0], vec![0.1, 0.5, 0.9]);
753        data.compute_pr(10);
754        // Should handle degenerate case
755        assert!(data.pr_points.is_some());
756    }
757
758    #[test]
759    fn test_auc_getters() {
760        let data = CurveData::new("Test", vec![0.0, 1.0], vec![0.3, 0.7]);
761        assert!(data.auc_roc().is_none());
762        assert!(data.auc_pr().is_none());
763    }
764
765    #[test]
766    fn test_generate_thresholds_empty() {
767        let thresholds = CurveData::generate_thresholds(&[], 10);
768        assert_eq!(thresholds, vec![0.5]);
769    }
770
771    #[test]
772    fn test_count_classes_scalar() {
773        let data = CurveData::new("Test", vec![0.0, 0.0, 1.0, 1.0, 1.0], vec![0.0; 5]);
774        let (n_pos, n_neg) = data.count_classes_scalar();
775        assert!((n_pos - 3.0).abs() < 0.001);
776        assert!((n_neg - 2.0).abs() < 0.001);
777    }
778
779    #[test]
780    fn test_count_classes_simd() {
781        let y_true: Vec<f64> = (0..150)
782            .map(|i| if i % 3 == 0 { 1.0 } else { 0.0 })
783            .collect();
784        let data = CurveData::new("Test", y_true, vec![0.0; 150]);
785        let (n_pos, n_neg) = data.count_classes_simd();
786        assert_eq!(n_pos, 50.0);
787        assert_eq!(n_neg, 100.0);
788    }
789
790    #[test]
791    fn test_count_positives_scalar() {
792        let data = CurveData::new("Test", vec![0.0, 0.0, 1.0, 1.0], vec![0.2, 0.8, 0.3, 0.9]);
793        let (tp, fp) = data.count_positives_at_threshold_scalar(0.5);
794        assert!((tp - 1.0).abs() < 0.001); // Only score 0.9 >= 0.5 with true label 1.0
795        assert!((fp - 1.0).abs() < 0.001); // Only score 0.8 >= 0.5 with true label 0.0
796    }
797
798    #[test]
799    fn test_count_positives_simd() {
800        let y_true: Vec<f64> = (0..150)
801            .map(|i| if i % 2 == 0 { 0.0 } else { 1.0 })
802            .collect();
803        let y_score: Vec<f64> = (0..150).map(|i| i as f64 / 150.0).collect();
804        let data = CurveData::new("Test", y_true, y_score);
805        let (tp, fp) = data.count_positives_at_threshold_simd(0.5);
806        assert!(tp > 0.0);
807        assert!(fp > 0.0);
808    }
809
810    #[test]
811    fn test_roc_pr_curve_creation() {
812        let curve = RocPrCurve::new(vec![CurveData::new("A", vec![0.0, 1.0], vec![0.3, 0.7])]);
813        assert_eq!(curve.curves.len(), 1);
814    }
815
816    #[test]
817    fn test_roc_pr_curve_default() {
818        let curve = RocPrCurve::default();
819        assert!(curve.curves.is_empty());
820    }
821
822    #[test]
823    fn test_curve_mode() {
824        let curve = RocPrCurve::default().with_mode(CurveMode::Both);
825        assert_eq!(curve.mode, CurveMode::Both);
826    }
827
828    #[test]
829    fn test_curve_mode_default() {
830        let mode = CurveMode::default();
831        assert_eq!(mode, CurveMode::Roc);
832    }
833
834    #[test]
835    fn test_with_gradient() {
836        let gradient = Gradient::two(
837            Color::new(1.0, 0.0, 0.0, 1.0),
838            Color::new(0.0, 0.0, 1.0, 1.0),
839        );
840        let curve = RocPrCurve::default().with_gradient(gradient);
841        assert!(curve.gradient.is_some());
842    }
843
844    #[test]
845    fn test_roc_pr_curve_measure_roc() {
846        let curve = RocPrCurve::default().with_mode(CurveMode::Roc);
847        let constraints = Constraints::new(0.0, 100.0, 0.0, 50.0);
848        let size = curve.measure(constraints);
849        assert_eq!(size.width, 40.0);
850        assert_eq!(size.height, 20.0);
851    }
852
853    #[test]
854    fn test_roc_pr_curve_measure_both() {
855        let curve = RocPrCurve::default().with_mode(CurveMode::Both);
856        let constraints = Constraints::new(0.0, 100.0, 0.0, 50.0);
857        let size = curve.measure(constraints);
858        assert_eq!(size.width, 80.0);
859    }
860
861    #[test]
862    fn test_roc_pr_curve_layout_and_paint_roc() {
863        let mut curve = RocPrCurve::new(vec![CurveData::new(
864            "Good",
865            vec![0.0, 0.0, 1.0, 1.0],
866            vec![0.1, 0.2, 0.8, 0.9],
867        )])
868        .with_mode(CurveMode::Roc);
869
870        let mut buffer = CellBuffer::new(50, 20);
871        let mut canvas = DirectTerminalCanvas::new(&mut buffer);
872
873        let result = curve.layout(Rect::new(0.0, 0.0, 50.0, 20.0));
874        assert_eq!(result.size.width, 50.0);
875
876        curve.paint(&mut canvas);
877    }
878
879    #[test]
880    fn test_roc_pr_curve_layout_and_paint_pr() {
881        let mut curve = RocPrCurve::new(vec![CurveData::new(
882            "Model",
883            vec![0.0, 0.0, 1.0, 1.0],
884            vec![0.1, 0.2, 0.8, 0.9],
885        )])
886        .with_mode(CurveMode::PrecisionRecall);
887
888        let mut buffer = CellBuffer::new(50, 20);
889        let mut canvas = DirectTerminalCanvas::new(&mut buffer);
890
891        curve.layout(Rect::new(0.0, 0.0, 50.0, 20.0));
892        curve.paint(&mut canvas);
893    }
894
895    #[test]
896    fn test_roc_pr_curve_layout_and_paint_both() {
897        let mut curve = RocPrCurve::new(vec![CurveData::new(
898            "Model",
899            vec![0.0, 0.0, 1.0, 1.0],
900            vec![0.1, 0.2, 0.8, 0.9],
901        )])
902        .with_mode(CurveMode::Both);
903
904        let mut buffer = CellBuffer::new(80, 20);
905        let mut canvas = DirectTerminalCanvas::new(&mut buffer);
906
907        curve.layout(Rect::new(0.0, 0.0, 80.0, 20.0));
908        curve.paint(&mut canvas);
909    }
910
911    #[test]
912    fn test_roc_pr_curve_paint_small_bounds() {
913        let mut curve = RocPrCurve::new(vec![CurveData::new(
914            "Model",
915            vec![0.0, 1.0],
916            vec![0.3, 0.7],
917        )]);
918
919        let mut buffer = CellBuffer::new(5, 3);
920        let mut canvas = DirectTerminalCanvas::new(&mut buffer);
921
922        curve.layout(Rect::new(0.0, 0.0, 5.0, 3.0));
923        curve.paint(&mut canvas);
924        // Should not crash
925    }
926
927    #[test]
928    fn test_roc_pr_curve_paint_no_baseline() {
929        let mut curve = RocPrCurve::new(vec![CurveData::new(
930            "Model",
931            vec![0.0, 1.0],
932            vec![0.3, 0.7],
933        )])
934        .with_baseline(false);
935
936        let mut buffer = CellBuffer::new(50, 20);
937        let mut canvas = DirectTerminalCanvas::new(&mut buffer);
938
939        curve.layout(Rect::new(0.0, 0.0, 50.0, 20.0));
940        curve.paint(&mut canvas);
941    }
942
943    #[test]
944    fn test_roc_pr_curve_paint_no_grid() {
945        let mut curve = RocPrCurve::new(vec![CurveData::new(
946            "Model",
947            vec![0.0, 1.0],
948            vec![0.3, 0.7],
949        )])
950        .with_grid(false);
951
952        let mut buffer = CellBuffer::new(50, 20);
953        let mut canvas = DirectTerminalCanvas::new(&mut buffer);
954
955        curve.layout(Rect::new(0.0, 0.0, 50.0, 20.0));
956        curve.paint(&mut canvas);
957    }
958
959    #[test]
960    fn test_roc_pr_curve_paint_no_auc() {
961        let mut curve = RocPrCurve::new(vec![CurveData::new(
962            "Model",
963            vec![0.0, 1.0],
964            vec![0.3, 0.7],
965        )])
966        .with_auc(false);
967
968        let mut buffer = CellBuffer::new(50, 20);
969        let mut canvas = DirectTerminalCanvas::new(&mut buffer);
970
971        curve.layout(Rect::new(0.0, 0.0, 50.0, 20.0));
972        curve.paint(&mut canvas);
973    }
974
/// Smoke test: a two-curve widget configured with a gradient built from two
/// colors must lay out and paint without panicking.
#[test]
fn test_roc_pr_curve_paint_with_gradient() {
    let from = Color::new(0.2, 0.4, 0.8, 1.0);
    let to = Color::new(0.8, 0.4, 0.2, 1.0);
    let ramp = Gradient::two(from, to);

    let curve_a = CurveData::new("A", vec![0.0, 1.0], vec![0.3, 0.7]);
    let curve_b = CurveData::new("B", vec![0.0, 1.0], vec![0.4, 0.6]);
    let mut widget = RocPrCurve::new(vec![curve_a, curve_b]).with_gradient(ramp);

    let mut cells = CellBuffer::new(50, 20);
    let mut surface = DirectTerminalCanvas::new(&mut cells);

    widget.layout(Rect::new(0.0, 0.0, 50.0, 20.0));
    widget.paint(&mut surface);
}
993
/// A default widget must declare at least one brick assertion.
#[test]
fn test_roc_pr_curve_assertions() {
    let widget = RocPrCurve::default();
    let declared = widget.assertions();
    assert!(!declared.is_empty());
}
999
/// With 40x20 bounds assigned, `verify()` must report the widget as valid.
#[test]
fn test_roc_pr_curve_verify_valid() {
    let mut widget = RocPrCurve::default();
    widget.bounds = Rect::new(0.0, 0.0, 40.0, 20.0);

    let verdict = widget.verify();
    assert!(verdict.is_valid());
}
1006
/// With only 5x3 bounds assigned, `verify()` must report the widget as
/// not valid.
#[test]
fn test_roc_pr_curve_verify_invalid() {
    let mut widget = RocPrCurve::default();
    widget.bounds = Rect::new(0.0, 0.0, 5.0, 3.0);

    let verdict = widget.verify();
    assert!(!verdict.is_valid());
}
1013
/// Adding one curve to a default widget leaves exactly one entry in
/// its `curves` list.
#[test]
fn test_add_curve() {
    let mut widget = RocPrCurve::default();
    let extra = CurveData::new("New", vec![0.0, 1.0], vec![0.3, 0.7]);

    widget.add_curve(extra);

    assert_eq!(widget.curves.len(), 1);
}
1020
/// `with_thresholds(50)` stores 50 in `num_thresholds` unchanged.
#[test]
fn test_with_thresholds() {
    let widget = RocPrCurve::default();
    let widget = widget.with_thresholds(50);
    assert_eq!(widget.num_thresholds, 50);
}
1026
/// Out-of-range threshold requests are clamped: asking for 5 stores 10,
/// and asking for 5000 stores 1000.
#[test]
fn test_thresholds_clamped() {
    // (requested, stored) pairs, checked in the same order as before.
    for &(requested, stored) in &[(5, 10), (5000, 1000)] {
        let widget = RocPrCurve::default().with_thresholds(requested);
        assert_eq!(widget.num_thresholds, stored);
    }
}
1035
/// Computing the ROC curve on a 200-sample dataset must populate the cached
/// AUC-ROC value.
///
/// The original author's note says this size (>100 elements) is meant to
/// exercise the SIMD path; that dispatch is not visible from this test.
#[test]
fn test_large_dataset_simd() {
    // Labels alternate 0.0 / 1.0 (even index -> 0.0); scores ramp as i / 200.
    let (y_true, y_score): (Vec<f64>, Vec<f64>) = (0..200)
        .map(|i| {
            let label = if i % 2 == 0 { 0.0 } else { 1.0 };
            (label, i as f64 / 200.0)
        })
        .unzip();

    let mut large = CurveData::new("Large", y_true, y_score);
    large.compute_roc(50);

    assert!(large.auc_roc.is_some());
}
1047
/// Computing the precision-recall curve on a 200-sample dataset must
/// populate the cached AUC-PR value.
#[test]
fn test_large_dataset_simd_pr() {
    let mut y_true: Vec<f64> = Vec::with_capacity(200);
    let mut y_score: Vec<f64> = Vec::with_capacity(200);
    for i in 0..200 {
        // Even indices are negatives (0.0), odd indices are positives (1.0).
        y_true.push(if i % 2 == 0 { 0.0 } else { 1.0 });
        // Scores increase linearly as i / 200.
        y_score.push(i as f64 / 200.0);
    }

    let mut large = CurveData::new("Large", y_true, y_score);
    large.compute_pr(50);

    assert!(large.auc_pr.is_some());
}
1058
/// `with_baseline(false)` clears the `show_baseline` flag.
#[test]
fn test_with_baseline() {
    let widget = RocPrCurve::default();
    let widget = widget.with_baseline(false);
    assert!(!widget.show_baseline);
}
1064
/// `with_auc(false)` clears the `show_auc` flag.
#[test]
fn test_with_auc() {
    let widget = RocPrCurve::default();
    let widget = widget.with_auc(false);
    assert!(!widget.show_auc);
}
1070
/// `with_grid(false)` clears the `show_grid` flag.
#[test]
fn test_with_grid() {
    let widget = RocPrCurve::default();
    let widget = widget.with_grid(false);
    assert!(!widget.show_grid);
}
1076
/// A default widget exposes no children.
#[test]
fn test_children() {
    let widget = RocPrCurve::default();
    let kids = widget.children();
    assert!(kids.is_empty());
}
1082
/// The mutable children accessor of a default widget is also empty.
#[test]
fn test_children_mut() {
    let mut widget = RocPrCurve::default();
    let kids = widget.children_mut();
    assert!(kids.is_empty());
}
1088
/// The brick name reported by the widget is exactly `"RocPrCurve"`.
#[test]
fn test_brick_name() {
    let widget = RocPrCurve::default();
    let name = widget.brick_name();
    assert_eq!(name, "RocPrCurve");
}
1094
/// The widget's budget must carry a strictly positive `layout_ms`.
#[test]
fn test_budget() {
    let widget = RocPrCurve::default();
    assert!(widget.budget().layout_ms > 0);
}
1101
/// `to_html()` on a default widget yields empty output.
#[test]
fn test_to_html() {
    let widget = RocPrCurve::default();
    let html = widget.to_html();
    assert!(html.is_empty());
}
1107
/// `to_css()` on a default widget yields empty output.
#[test]
fn test_to_css() {
    let widget = RocPrCurve::default();
    let css = widget.to_css();
    assert!(css.is_empty());
}
1113
/// `Widget::type_id` must return the `TypeId` of `RocPrCurve` itself.
#[test]
fn test_type_id() {
    let widget = RocPrCurve::default();
    let expected = TypeId::of::<RocPrCurve>();
    // Fully-qualified call so the `Widget` trait method is the one invoked.
    let reported = Widget::type_id(&widget);
    assert_eq!(reported, expected);
}
1120
/// Delivering a resize event (80x24) to a default widget produces no
/// response: `event()` returns `None`.
#[test]
fn test_event() {
    let mut widget = RocPrCurve::default();
    let resize = Event::Resize {
        width: 80.0,
        height: 24.0,
    };

    let response = widget.event(&resize);
    assert!(response.is_none());
}
1130
/// Smoke test: a widget holding two four-sample curves must lay out and
/// paint into a 60x25 area without panicking.
#[test]
fn test_multiple_curves() {
    let curve_a = CurveData::new("A", vec![0.0, 1.0, 0.0, 1.0], vec![0.1, 0.9, 0.2, 0.8]);
    let curve_b = CurveData::new("B", vec![0.0, 1.0, 0.0, 1.0], vec![0.3, 0.7, 0.4, 0.6]);
    let mut widget = RocPrCurve::new(vec![curve_a, curve_b]);

    let mut cells = CellBuffer::new(60, 25);
    let mut surface = DirectTerminalCanvas::new(&mut cells);

    let area = Rect::new(0.0, 0.0, 60.0, 25.0);
    widget.layout(area);
    widget.paint(&mut surface);
}
1144}