1use crate::theme::Gradient;
7use presentar_core::{
8 Brick, BrickAssertion, BrickBudget, BrickVerification, Canvas, Color, Constraints, Event,
9 LayoutResult, Point, Rect, Size, TextStyle, TypeId, Widget,
10};
11use std::any::Any;
12use std::time::Duration;
13
/// Selects which chart(s) a `RocPrCurve` draws.
#[derive(Debug, Clone, Copy)]
#[derive(Default, PartialEq, Eq)]
pub enum CurveMode {
    /// Receiver operating characteristic: true-positive rate against false-positive rate.
    /// This is the default mode.
    #[default]
    Roc,
    /// Precision against recall.
    PrecisionRecall,
    /// ROC on the left half of the widget and precision-recall on the right half.
    Both,
}
25
26#[derive(Debug, Clone)]
28pub struct CurveData {
29 pub label: String,
31 pub y_true: Vec<f64>,
33 pub y_score: Vec<f64>,
35 pub color: Color,
37 roc_points: Option<Vec<(f64, f64)>>,
39 pr_points: Option<Vec<(f64, f64)>>,
41 auc_roc: Option<f64>,
43 auc_pr: Option<f64>,
45}
46
47impl CurveData {
48 #[must_use]
50 pub fn new(label: impl Into<String>, y_true: Vec<f64>, y_score: Vec<f64>) -> Self {
51 assert_eq!(
52 y_true.len(),
53 y_score.len(),
54 "y_true and y_score must have same length"
55 );
56 Self {
57 label: label.into(),
58 y_true,
59 y_score,
60 color: Color::new(0.3, 0.7, 1.0, 1.0),
61 roc_points: None,
62 pr_points: None,
63 auc_roc: None,
64 auc_pr: None,
65 }
66 }
67
68 #[must_use]
70 pub fn with_color(mut self, color: Color) -> Self {
71 self.color = color;
72 self
73 }
74
75 fn compute_roc(&mut self, num_thresholds: usize) {
78 if self.y_true.is_empty() {
79 self.roc_points = Some(vec![(0.0, 0.0), (1.0, 1.0)]);
80 self.auc_roc = Some(0.5);
81 return;
82 }
83
84 let use_simd = self.y_true.len() > 100;
85 let thresholds = Self::generate_thresholds(&self.y_score, num_thresholds);
86 let mut points = Vec::with_capacity(thresholds.len() + 2);
87
88 let (n_pos, n_neg) = if use_simd {
90 self.count_classes_simd()
91 } else {
92 self.count_classes_scalar()
93 };
94
95 if n_pos == 0.0 || n_neg == 0.0 {
96 self.roc_points = Some(vec![(0.0, 0.0), (1.0, 1.0)]);
97 self.auc_roc = Some(0.5);
98 return;
99 }
100
101 points.push((0.0, 0.0));
103
104 for &threshold in &thresholds {
106 let (tp, fp) = if use_simd {
107 self.count_positives_at_threshold_simd(threshold)
108 } else {
109 self.count_positives_at_threshold_scalar(threshold)
110 };
111
112 let tpr = tp / n_pos;
113 let fpr = fp / n_neg;
114 points.push((fpr, tpr));
115 }
116
117 points.push((1.0, 1.0));
119
120 points.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
122
123 let mut auc = 0.0;
125 for i in 1..points.len() {
126 let dx = points[i].0 - points[i - 1].0;
127 let avg_y = (points[i].1 + points[i - 1].1) / 2.0;
128 auc += dx * avg_y;
129 }
130
131 self.roc_points = Some(points);
132 self.auc_roc = Some(auc);
133 }
134
135 fn compute_pr(&mut self, num_thresholds: usize) {
137 if self.y_true.is_empty() {
138 self.pr_points = Some(vec![(0.0, 1.0), (1.0, 0.0)]);
139 self.auc_pr = Some(0.5);
140 return;
141 }
142
143 let use_simd = self.y_true.len() > 100;
144 let thresholds = Self::generate_thresholds(&self.y_score, num_thresholds);
145 let mut points = Vec::with_capacity(thresholds.len() + 2);
146
147 let (n_pos, _) = if use_simd {
148 self.count_classes_simd()
149 } else {
150 self.count_classes_scalar()
151 };
152
153 if n_pos == 0.0 {
154 self.pr_points = Some(vec![(0.0, 1.0), (1.0, 0.0)]);
155 self.auc_pr = Some(0.5);
156 return;
157 }
158
159 for &threshold in &thresholds {
161 let (tp, fp) = if use_simd {
162 self.count_positives_at_threshold_simd(threshold)
163 } else {
164 self.count_positives_at_threshold_scalar(threshold)
165 };
166
167 let recall = tp / n_pos;
168 let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 1.0 };
169 points.push((recall, precision));
170 }
171
172 points.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
174
175 let mut auc = 0.0;
177 for i in 1..points.len() {
178 let dx = points[i].0 - points[i - 1].0;
179 let avg_y = (points[i].1 + points[i - 1].1) / 2.0;
180 auc += dx * avg_y;
181 }
182
183 self.pr_points = Some(points);
184 self.auc_pr = Some(auc);
185 }
186
187 fn generate_thresholds(scores: &[f64], num_thresholds: usize) -> Vec<f64> {
188 let mut sorted: Vec<f64> = scores.iter().copied().filter(|x| x.is_finite()).collect();
189 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
190
191 if sorted.is_empty() {
192 return vec![0.5];
193 }
194
195 let step = (sorted.len() as f64 / num_thresholds as f64).ceil() as usize;
196 sorted.into_iter().step_by(step.max(1)).collect()
197 }
198
199 fn count_classes_scalar(&self) -> (f64, f64) {
200 let mut n_pos = 0.0;
201 let mut n_neg = 0.0;
202 for &y in &self.y_true {
203 if y > 0.5 {
204 n_pos += 1.0;
205 } else {
206 n_neg += 1.0;
207 }
208 }
209 (n_pos, n_neg)
210 }
211
212 fn count_classes_simd(&self) -> (f64, f64) {
214 let mut n_pos = 0.0;
216 let mut n_neg = 0.0;
217 let mut i = 0;
218
219 while i + 4 <= self.y_true.len() {
220 if self.y_true[i] > 0.5 {
221 n_pos += 1.0;
222 } else {
223 n_neg += 1.0;
224 }
225 if self.y_true[i + 1] > 0.5 {
226 n_pos += 1.0;
227 } else {
228 n_neg += 1.0;
229 }
230 if self.y_true[i + 2] > 0.5 {
231 n_pos += 1.0;
232 } else {
233 n_neg += 1.0;
234 }
235 if self.y_true[i + 3] > 0.5 {
236 n_pos += 1.0;
237 } else {
238 n_neg += 1.0;
239 }
240 i += 4;
241 }
242
243 while i < self.y_true.len() {
244 if self.y_true[i] > 0.5 {
245 n_pos += 1.0;
246 } else {
247 n_neg += 1.0;
248 }
249 i += 1;
250 }
251
252 (n_pos, n_neg)
253 }
254
255 fn count_positives_at_threshold_scalar(&self, threshold: f64) -> (f64, f64) {
256 let mut tp = 0.0;
257 let mut fp = 0.0;
258 for (y, &score) in self.y_true.iter().zip(self.y_score.iter()) {
259 if score >= threshold {
260 if *y > 0.5 {
261 tp += 1.0;
262 } else {
263 fp += 1.0;
264 }
265 }
266 }
267 (tp, fp)
268 }
269
270 fn count_positives_at_threshold_simd(&self, threshold: f64) -> (f64, f64) {
272 let mut tp = 0.0;
273 let mut fp = 0.0;
274 let mut i = 0;
275
276 while i + 4 <= self.y_true.len() {
277 if self.y_score[i] >= threshold {
278 if self.y_true[i] > 0.5 {
279 tp += 1.0;
280 } else {
281 fp += 1.0;
282 }
283 }
284 if self.y_score[i + 1] >= threshold {
285 if self.y_true[i + 1] > 0.5 {
286 tp += 1.0;
287 } else {
288 fp += 1.0;
289 }
290 }
291 if self.y_score[i + 2] >= threshold {
292 if self.y_true[i + 2] > 0.5 {
293 tp += 1.0;
294 } else {
295 fp += 1.0;
296 }
297 }
298 if self.y_score[i + 3] >= threshold {
299 if self.y_true[i + 3] > 0.5 {
300 tp += 1.0;
301 } else {
302 fp += 1.0;
303 }
304 }
305 i += 4;
306 }
307
308 while i < self.y_true.len() {
309 if self.y_score[i] >= threshold {
310 if self.y_true[i] > 0.5 {
311 tp += 1.0;
312 } else {
313 fp += 1.0;
314 }
315 }
316 i += 1;
317 }
318
319 (tp, fp)
320 }
321
322 #[must_use]
324 pub fn auc_roc(&self) -> Option<f64> {
325 self.auc_roc
326 }
327
328 #[must_use]
330 pub fn auc_pr(&self) -> Option<f64> {
331 self.auc_pr
332 }
333}
334
335#[derive(Debug, Clone)]
337pub struct RocPrCurve {
338 curves: Vec<CurveData>,
339 mode: CurveMode,
340 num_thresholds: usize,
342 show_baseline: bool,
344 show_auc: bool,
346 show_grid: bool,
348 gradient: Option<Gradient>,
350 bounds: Rect,
351}
352
353impl Default for RocPrCurve {
354 fn default() -> Self {
355 Self::new(Vec::new())
356 }
357}
358
359impl RocPrCurve {
360 #[must_use]
362 pub fn new(curves: Vec<CurveData>) -> Self {
363 Self {
364 curves,
365 mode: CurveMode::default(),
366 num_thresholds: 100,
367 show_baseline: true,
368 show_auc: true,
369 show_grid: true,
370 gradient: None,
371 bounds: Rect::default(),
372 }
373 }
374
375 #[must_use]
377 pub fn with_mode(mut self, mode: CurveMode) -> Self {
378 self.mode = mode;
379 self
380 }
381
382 #[must_use]
384 pub fn with_thresholds(mut self, n: usize) -> Self {
385 self.num_thresholds = n.clamp(10, 1000);
386 self
387 }
388
389 #[must_use]
391 pub fn with_baseline(mut self, show: bool) -> Self {
392 self.show_baseline = show;
393 self
394 }
395
396 #[must_use]
398 pub fn with_auc(mut self, show: bool) -> Self {
399 self.show_auc = show;
400 self
401 }
402
403 #[must_use]
405 pub fn with_grid(mut self, show: bool) -> Self {
406 self.show_grid = show;
407 self
408 }
409
410 #[must_use]
412 pub fn with_gradient(mut self, gradient: Gradient) -> Self {
413 self.gradient = Some(gradient);
414 self
415 }
416
417 pub fn add_curve(&mut self, curve: CurveData) {
419 self.curves.push(curve);
420 }
421
422 fn render_roc(&mut self, canvas: &mut dyn Canvas, area: Rect) {
423 let dim_style = TextStyle {
424 color: Color::new(0.3, 0.3, 0.3, 1.0),
425 ..Default::default()
426 };
427
428 if self.show_grid {
430 for i in 1..5 {
431 let x = area.x + area.width * i as f32 / 5.0;
432 let y = area.y + area.height * i as f32 / 5.0;
433 canvas.draw_text("ยท", Point::new(x, area.y), &dim_style);
434 canvas.draw_text("ยท", Point::new(area.x, y), &dim_style);
435 }
436 }
437
438 if self.show_baseline {
440 for i in 0..area.width.min(area.height) as usize {
441 let x = area.x + i as f32;
442 let y = area.y + area.height - i as f32;
443 if y >= area.y {
444 canvas.draw_text("ยท", Point::new(x, y), &dim_style);
445 }
446 }
447 }
448
449 let label_style = TextStyle {
451 color: Color::new(0.6, 0.6, 0.6, 1.0),
452 ..Default::default()
453 };
454 canvas.draw_text(
455 "FPRโ",
456 Point::new(area.x + area.width - 4.0, area.y + area.height),
457 &label_style,
458 );
459 canvas.draw_text("TPRโ", Point::new(area.x - 4.0, area.y), &label_style);
460
461 let num_curves = self.curves.len().max(1);
463 for (idx, curve) in self.curves.iter_mut().enumerate() {
464 if curve.roc_points.is_none() {
465 curve.compute_roc(self.num_thresholds);
466 }
467
468 let points = curve.roc_points.as_ref().expect("computed above");
469 let color = if let Some(ref gradient) = self.gradient {
470 gradient.sample(idx as f64 / num_curves as f64)
471 } else {
472 curve.color
473 };
474
475 let style = TextStyle {
476 color,
477 ..Default::default()
478 };
479
480 for &(fpr, tpr) in points {
481 let x = area.x + (fpr * area.width as f64) as f32;
482 let y = area.y + ((1.0 - tpr) * area.height as f64) as f32;
483 if x >= area.x && x < area.x + area.width && y >= area.y && y < area.y + area.height
484 {
485 canvas.draw_text("โข", Point::new(x, y), &style);
486 }
487 }
488
489 if self.show_auc {
491 let auc = curve.auc_roc.unwrap_or(0.0);
492 let legend = format!("{}: AUC={:.3}", curve.label, auc);
493 canvas.draw_text(
494 &legend,
495 Point::new(area.x + 1.0, area.y + 1.0 + idx as f32),
496 &style,
497 );
498 }
499 }
500 }
501
502 fn render_pr(&mut self, canvas: &mut dyn Canvas, area: Rect) {
503 let dim_style = TextStyle {
504 color: Color::new(0.3, 0.3, 0.3, 1.0),
505 ..Default::default()
506 };
507
508 if self.show_grid {
510 for i in 1..5 {
511 let x = area.x + area.width * i as f32 / 5.0;
512 let y = area.y + area.height * i as f32 / 5.0;
513 canvas.draw_text("ยท", Point::new(x, area.y), &dim_style);
514 canvas.draw_text("ยท", Point::new(area.x, y), &dim_style);
515 }
516 }
517
518 let label_style = TextStyle {
520 color: Color::new(0.6, 0.6, 0.6, 1.0),
521 ..Default::default()
522 };
523 canvas.draw_text(
524 "Recallโ",
525 Point::new(area.x + area.width - 7.0, area.y + area.height),
526 &label_style,
527 );
528 canvas.draw_text("Precโ", Point::new(area.x - 5.0, area.y), &label_style);
529
530 let num_curves = self.curves.len().max(1);
532 for (idx, curve) in self.curves.iter_mut().enumerate() {
533 if curve.pr_points.is_none() {
534 curve.compute_pr(self.num_thresholds);
535 }
536
537 let points = curve.pr_points.as_ref().expect("computed above");
538 let color = if let Some(ref gradient) = self.gradient {
539 gradient.sample(idx as f64 / num_curves as f64)
540 } else {
541 curve.color
542 };
543
544 let style = TextStyle {
545 color,
546 ..Default::default()
547 };
548
549 for &(recall, precision) in points {
550 let x = area.x + (recall * area.width as f64) as f32;
551 let y = area.y + ((1.0 - precision) * area.height as f64) as f32;
552 if x >= area.x && x < area.x + area.width && y >= area.y && y < area.y + area.height
553 {
554 canvas.draw_text("โข", Point::new(x, y), &style);
555 }
556 }
557
558 if self.show_auc {
560 let auc = curve.auc_pr.unwrap_or(0.0);
561 let legend = format!("{}: AP={:.3}", curve.label, auc);
562 canvas.draw_text(
563 &legend,
564 Point::new(area.x + 1.0, area.y + 1.0 + idx as f32),
565 &style,
566 );
567 }
568 }
569 }
570}
571
572impl Widget for RocPrCurve {
573 fn type_id(&self) -> TypeId {
574 TypeId::of::<Self>()
575 }
576
577 fn measure(&self, constraints: Constraints) -> Size {
578 let width = match self.mode {
579 CurveMode::Both => constraints.max_width.min(80.0),
580 _ => constraints.max_width.min(40.0),
581 };
582 Size::new(width, constraints.max_height.min(20.0))
583 }
584
585 fn layout(&mut self, bounds: Rect) -> LayoutResult {
586 self.bounds = bounds;
587 LayoutResult {
588 size: Size::new(bounds.width, bounds.height),
589 }
590 }
591
592 fn paint(&self, canvas: &mut dyn Canvas) {
593 if self.bounds.width < 10.0 || self.bounds.height < 5.0 {
594 return;
595 }
596
597 let mut mutable_self = self.clone();
598
599 match self.mode {
600 CurveMode::Roc => {
601 mutable_self.render_roc(canvas, self.bounds);
602 }
603 CurveMode::PrecisionRecall => {
604 mutable_self.render_pr(canvas, self.bounds);
605 }
606 CurveMode::Both => {
607 let half_width = self.bounds.width / 2.0;
608 let left = Rect::new(
609 self.bounds.x,
610 self.bounds.y,
611 half_width - 1.0,
612 self.bounds.height,
613 );
614 let right = Rect::new(
615 self.bounds.x + half_width + 1.0,
616 self.bounds.y,
617 half_width - 1.0,
618 self.bounds.height,
619 );
620 mutable_self.render_roc(canvas, left);
621 mutable_self.render_pr(canvas, right);
622 }
623 }
624 }
625
626 fn event(&mut self, _event: &Event) -> Option<Box<dyn Any + Send>> {
627 None
628 }
629
630 fn children(&self) -> &[Box<dyn Widget>] {
631 &[]
632 }
633
634 fn children_mut(&mut self) -> &mut [Box<dyn Widget>] {
635 &mut []
636 }
637}
638
639impl Brick for RocPrCurve {
640 fn brick_name(&self) -> &'static str {
641 "RocPrCurve"
642 }
643
644 fn assertions(&self) -> &[BrickAssertion] {
645 static ASSERTIONS: &[BrickAssertion] = &[BrickAssertion::max_latency_ms(16)];
646 ASSERTIONS
647 }
648
649 fn budget(&self) -> BrickBudget {
650 BrickBudget::uniform(16)
651 }
652
653 fn verify(&self) -> BrickVerification {
654 let mut passed = Vec::new();
655 let mut failed = Vec::new();
656
657 if self.bounds.width >= 10.0 && self.bounds.height >= 5.0 {
658 passed.push(BrickAssertion::max_latency_ms(16));
659 } else {
660 failed.push((
661 BrickAssertion::max_latency_ms(16),
662 "Size too small".to_string(),
663 ));
664 }
665
666 BrickVerification {
667 passed,
668 failed,
669 verification_time: Duration::from_micros(5),
670 }
671 }
672
673 fn to_html(&self) -> String {
674 String::new()
675 }
676
677 fn to_css(&self) -> String {
678 String::new()
679 }
680}
681
#[cfg(test)]
mod tests {
    // Unit tests for `CurveData` (curve computation, class and threshold counting) and
    // `RocPrCurve` (builders, measuring, Brick metadata). Paint tests are smoke tests: they
    // lay out and paint into a `CellBuffer` and only check that nothing panics.
    use super::*;
    use crate::{CellBuffer, DirectTerminalCanvas};

    #[test]
    fn test_curve_data_creation() {
        let data = CurveData::new("Model", vec![0.0, 1.0, 1.0], vec![0.2, 0.8, 0.9]);
        assert_eq!(data.label, "Model");
        assert_eq!(data.y_true.len(), 3);
    }

    #[test]
    fn test_curve_data_with_color() {
        let color = Color::new(0.5, 0.6, 0.7, 1.0);
        let data = CurveData::new("Test", vec![0.0, 1.0], vec![0.3, 0.8]).with_color(color);
        assert!((data.color.r - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_roc_computation() {
        let mut data = CurveData::new("Test", vec![0.0, 0.0, 1.0, 1.0], vec![0.1, 0.4, 0.35, 0.8]);
        data.compute_roc(10);
        assert!(data.roc_points.is_some());
        assert!(data.auc_roc.is_some());
        let auc = data.auc_roc.expect("computed above");
        // An area under a curve inside the unit square must lie in [0, 1].
        assert!(auc >= 0.0 && auc <= 1.0);
    }

    #[test]
    fn test_pr_computation() {
        let mut data = CurveData::new("Test", vec![0.0, 0.0, 1.0, 1.0], vec![0.1, 0.4, 0.35, 0.8]);
        data.compute_pr(10);
        assert!(data.pr_points.is_some());
        assert!(data.auc_pr.is_some());
    }

    // Empty input falls back to the chance value 0.5 for both curve kinds.
    #[test]
    fn test_empty_data() {
        let mut data = CurveData::new("Empty", vec![], vec![]);
        data.compute_roc(10);
        assert_eq!(data.auc_roc, Some(0.5));
    }

    #[test]
    fn test_empty_data_pr() {
        let mut data = CurveData::new("Empty", vec![], vec![]);
        data.compute_pr(10);
        assert_eq!(data.auc_pr, Some(0.5));
    }

    // Single-class inputs must still produce a curve instead of dividing by zero.
    #[test]
    fn test_all_positives() {
        let mut data = CurveData::new("AllPos", vec![1.0, 1.0, 1.0], vec![0.3, 0.6, 0.9]);
        data.compute_roc(10);
        assert!(data.roc_points.is_some());
    }

    #[test]
    fn test_all_negatives() {
        let mut data = CurveData::new("AllNeg", vec![0.0, 0.0, 0.0], vec![0.1, 0.5, 0.9]);
        data.compute_roc(10);
        assert!(data.roc_points.is_some());
    }

    #[test]
    fn test_all_negatives_pr() {
        let mut data = CurveData::new("AllNeg", vec![0.0, 0.0, 0.0], vec![0.1, 0.5, 0.9]);
        data.compute_pr(10);
        assert!(data.pr_points.is_some());
    }

    // AUC getters report `None` until a curve has been computed.
    #[test]
    fn test_auc_getters() {
        let data = CurveData::new("Test", vec![0.0, 1.0], vec![0.3, 0.7]);
        assert!(data.auc_roc().is_none());
        assert!(data.auc_pr().is_none());
    }

    #[test]
    fn test_generate_thresholds_empty() {
        let thresholds = CurveData::generate_thresholds(&[], 10);
        assert_eq!(thresholds, vec![0.5]);
    }

    #[test]
    fn test_count_classes_scalar() {
        let data = CurveData::new("Test", vec![0.0, 0.0, 1.0, 1.0, 1.0], vec![0.0; 5]);
        let (n_pos, n_neg) = data.count_classes_scalar();
        assert!((n_pos - 3.0).abs() < 0.001);
        assert!((n_neg - 2.0).abs() < 0.001);
    }

    // 150 labels (not a multiple of 4) also exercise the unrolled helper's remainder loop.
    #[test]
    fn test_count_classes_simd() {
        let y_true: Vec<f64> = (0..150)
            .map(|i| if i % 3 == 0 { 1.0 } else { 0.0 })
            .collect();
        let data = CurveData::new("Test", y_true, vec![0.0; 150]);
        let (n_pos, n_neg) = data.count_classes_simd();
        assert_eq!(n_pos, 50.0);
        assert_eq!(n_neg, 100.0);
    }

    // At threshold 0.5 only scores 0.8 (negative) and 0.9 (positive) count: one TP, one FP.
    #[test]
    fn test_count_positives_scalar() {
        let data = CurveData::new("Test", vec![0.0, 0.0, 1.0, 1.0], vec![0.2, 0.8, 0.3, 0.9]);
        let (tp, fp) = data.count_positives_at_threshold_scalar(0.5);
        assert!((tp - 1.0).abs() < 0.001);
        assert!((fp - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_count_positives_simd() {
        let y_true: Vec<f64> = (0..150)
            .map(|i| if i % 2 == 0 { 0.0 } else { 1.0 })
            .collect();
        let y_score: Vec<f64> = (0..150).map(|i| i as f64 / 150.0).collect();
        let data = CurveData::new("Test", y_true, y_score);
        let (tp, fp) = data.count_positives_at_threshold_simd(0.5);
        assert!(tp > 0.0);
        assert!(fp > 0.0);
    }

    #[test]
    fn test_roc_pr_curve_creation() {
        let curve = RocPrCurve::new(vec![CurveData::new("A", vec![0.0, 1.0], vec![0.3, 0.7])]);
        assert_eq!(curve.curves.len(), 1);
    }

    #[test]
    fn test_roc_pr_curve_default() {
        let curve = RocPrCurve::default();
        assert!(curve.curves.is_empty());
    }

    #[test]
    fn test_curve_mode() {
        let curve = RocPrCurve::default().with_mode(CurveMode::Both);
        assert_eq!(curve.mode, CurveMode::Both);
    }

    #[test]
    fn test_curve_mode_default() {
        let mode = CurveMode::default();
        assert_eq!(mode, CurveMode::Roc);
    }

    #[test]
    fn test_with_gradient() {
        let gradient = Gradient::two(
            Color::new(1.0, 0.0, 0.0, 1.0),
            Color::new(0.0, 0.0, 1.0, 1.0),
        );
        let curve = RocPrCurve::default().with_gradient(gradient);
        assert!(curve.gradient.is_some());
    }

    // Single-plot modes prefer 40 columns and cap the height at 20 rows.
    #[test]
    fn test_roc_pr_curve_measure_roc() {
        let curve = RocPrCurve::default().with_mode(CurveMode::Roc);
        let constraints = Constraints::new(0.0, 100.0, 0.0, 50.0);
        let size = curve.measure(constraints);
        assert_eq!(size.width, 40.0);
        assert_eq!(size.height, 20.0);
    }

    // Side-by-side mode prefers 80 columns.
    #[test]
    fn test_roc_pr_curve_measure_both() {
        let curve = RocPrCurve::default().with_mode(CurveMode::Both);
        let constraints = Constraints::new(0.0, 100.0, 0.0, 50.0);
        let size = curve.measure(constraints);
        assert_eq!(size.width, 80.0);
    }

    #[test]
    fn test_roc_pr_curve_layout_and_paint_roc() {
        let mut curve = RocPrCurve::new(vec![CurveData::new(
            "Good",
            vec![0.0, 0.0, 1.0, 1.0],
            vec![0.1, 0.2, 0.8, 0.9],
        )])
        .with_mode(CurveMode::Roc);

        let mut buffer = CellBuffer::new(50, 20);
        let mut canvas = DirectTerminalCanvas::new(&mut buffer);

        let result = curve.layout(Rect::new(0.0, 0.0, 50.0, 20.0));
        assert_eq!(result.size.width, 50.0);

        curve.paint(&mut canvas);
    }

    #[test]
    fn test_roc_pr_curve_layout_and_paint_pr() {
        let mut curve = RocPrCurve::new(vec![CurveData::new(
            "Model",
            vec![0.0, 0.0, 1.0, 1.0],
            vec![0.1, 0.2, 0.8, 0.9],
        )])
        .with_mode(CurveMode::PrecisionRecall);

        let mut buffer = CellBuffer::new(50, 20);
        let mut canvas = DirectTerminalCanvas::new(&mut buffer);

        curve.layout(Rect::new(0.0, 0.0, 50.0, 20.0));
        curve.paint(&mut canvas);
    }

    #[test]
    fn test_roc_pr_curve_layout_and_paint_both() {
        let mut curve = RocPrCurve::new(vec![CurveData::new(
            "Model",
            vec![0.0, 0.0, 1.0, 1.0],
            vec![0.1, 0.2, 0.8, 0.9],
        )])
        .with_mode(CurveMode::Both);

        let mut buffer = CellBuffer::new(80, 20);
        let mut canvas = DirectTerminalCanvas::new(&mut buffer);

        curve.layout(Rect::new(0.0, 0.0, 80.0, 20.0));
        curve.paint(&mut canvas);
    }

    // Bounds below the 10x5 minimum hit paint's early return.
    #[test]
    fn test_roc_pr_curve_paint_small_bounds() {
        let mut curve = RocPrCurve::new(vec![CurveData::new(
            "Model",
            vec![0.0, 1.0],
            vec![0.3, 0.7],
        )]);

        let mut buffer = CellBuffer::new(5, 3);
        let mut canvas = DirectTerminalCanvas::new(&mut buffer);

        curve.layout(Rect::new(0.0, 0.0, 5.0, 3.0));
        curve.paint(&mut canvas);
    }

    #[test]
    fn test_roc_pr_curve_paint_no_baseline() {
        let mut curve = RocPrCurve::new(vec![CurveData::new(
            "Model",
            vec![0.0, 1.0],
            vec![0.3, 0.7],
        )])
        .with_baseline(false);

        let mut buffer = CellBuffer::new(50, 20);
        let mut canvas = DirectTerminalCanvas::new(&mut buffer);

        curve.layout(Rect::new(0.0, 0.0, 50.0, 20.0));
        curve.paint(&mut canvas);
    }

    #[test]
    fn test_roc_pr_curve_paint_no_grid() {
        let mut curve = RocPrCurve::new(vec![CurveData::new(
            "Model",
            vec![0.0, 1.0],
            vec![0.3, 0.7],
        )])
        .with_grid(false);

        let mut buffer = CellBuffer::new(50, 20);
        let mut canvas = DirectTerminalCanvas::new(&mut buffer);

        curve.layout(Rect::new(0.0, 0.0, 50.0, 20.0));
        curve.paint(&mut canvas);
    }

    #[test]
    fn test_roc_pr_curve_paint_no_auc() {
        let mut curve = RocPrCurve::new(vec![CurveData::new(
            "Model",
            vec![0.0, 1.0],
            vec![0.3, 0.7],
        )])
        .with_auc(false);

        let mut buffer = CellBuffer::new(50, 20);
        let mut canvas = DirectTerminalCanvas::new(&mut buffer);

        curve.layout(Rect::new(0.0, 0.0, 50.0, 20.0));
        curve.paint(&mut canvas);
    }

    #[test]
    fn test_roc_pr_curve_paint_with_gradient() {
        let gradient = Gradient::two(
            Color::new(0.2, 0.4, 0.8, 1.0),
            Color::new(0.8, 0.4, 0.2, 1.0),
        );
        let mut curve = RocPrCurve::new(vec![
            CurveData::new("A", vec![0.0, 1.0], vec![0.3, 0.7]),
            CurveData::new("B", vec![0.0, 1.0], vec![0.4, 0.6]),
        ])
        .with_gradient(gradient);

        let mut buffer = CellBuffer::new(50, 20);
        let mut canvas = DirectTerminalCanvas::new(&mut buffer);

        curve.layout(Rect::new(0.0, 0.0, 50.0, 20.0));
        curve.paint(&mut canvas);
    }

    #[test]
    fn test_roc_pr_curve_assertions() {
        let curve = RocPrCurve::default();
        assert!(!curve.assertions().is_empty());
    }

    // `verify` passes only when the bounds meet the 10x5 minimum.
    #[test]
    fn test_roc_pr_curve_verify_valid() {
        let mut curve = RocPrCurve::default();
        curve.bounds = Rect::new(0.0, 0.0, 40.0, 20.0);
        assert!(curve.verify().is_valid());
    }

    #[test]
    fn test_roc_pr_curve_verify_invalid() {
        let mut curve = RocPrCurve::default();
        curve.bounds = Rect::new(0.0, 0.0, 5.0, 3.0);
        assert!(!curve.verify().is_valid());
    }

    #[test]
    fn test_add_curve() {
        let mut curve = RocPrCurve::default();
        curve.add_curve(CurveData::new("New", vec![0.0, 1.0], vec![0.3, 0.7]));
        assert_eq!(curve.curves.len(), 1);
    }

    #[test]
    fn test_with_thresholds() {
        let curve = RocPrCurve::default().with_thresholds(50);
        assert_eq!(curve.num_thresholds, 50);
    }

    // `with_thresholds` clamps to the range 10..=1000.
    #[test]
    fn test_thresholds_clamped() {
        let curve = RocPrCurve::default().with_thresholds(5);
        assert_eq!(curve.num_thresholds, 10);

        let curve = RocPrCurve::default().with_thresholds(5000);
        assert_eq!(curve.num_thresholds, 1000);
    }

    // 200 samples is above the 100-sample cutoff, so the unrolled counting helpers are used.
    #[test]
    fn test_large_dataset_simd() {
        let y_true: Vec<f64> = (0..200)
            .map(|i| if i % 2 == 0 { 0.0 } else { 1.0 })
            .collect();
        let y_score: Vec<f64> = (0..200).map(|i| i as f64 / 200.0).collect();
        let mut data = CurveData::new("Large", y_true, y_score);
        data.compute_roc(50);
        assert!(data.auc_roc.is_some());
    }

    #[test]
    fn test_large_dataset_simd_pr() {
        let y_true: Vec<f64> = (0..200)
            .map(|i| if i % 2 == 0 { 0.0 } else { 1.0 })
            .collect();
        let y_score: Vec<f64> = (0..200).map(|i| i as f64 / 200.0).collect();
        let mut data = CurveData::new("Large", y_true, y_score);
        data.compute_pr(50);
        assert!(data.auc_pr.is_some());
    }

    #[test]
    fn test_with_baseline() {
        let curve = RocPrCurve::default().with_baseline(false);
        assert!(!curve.show_baseline);
    }

    #[test]
    fn test_with_auc() {
        let curve = RocPrCurve::default().with_auc(false);
        assert!(!curve.show_auc);
    }

    #[test]
    fn test_with_grid() {
        let curve = RocPrCurve::default().with_grid(false);
        assert!(!curve.show_grid);
    }

    #[test]
    fn test_children() {
        let curve = RocPrCurve::default();
        assert!(curve.children().is_empty());
    }

    #[test]
    fn test_children_mut() {
        let mut curve = RocPrCurve::default();
        assert!(curve.children_mut().is_empty());
    }

    #[test]
    fn test_brick_name() {
        let curve = RocPrCurve::default();
        assert_eq!(curve.brick_name(), "RocPrCurve");
    }

    #[test]
    fn test_budget() {
        let curve = RocPrCurve::default();
        let budget = curve.budget();
        assert!(budget.layout_ms > 0);
    }

    #[test]
    fn test_to_html() {
        let curve = RocPrCurve::default();
        assert!(curve.to_html().is_empty());
    }

    #[test]
    fn test_to_css() {
        let curve = RocPrCurve::default();
        assert!(curve.to_css().is_empty());
    }

    // Calls the `Widget` trait's `type_id` explicitly to avoid ambiguity with `Any::type_id`.
    #[test]
    fn test_type_id() {
        let curve = RocPrCurve::default();
        let type_id = Widget::type_id(&curve);
        assert_eq!(type_id, TypeId::of::<RocPrCurve>());
    }

    #[test]
    fn test_event() {
        let mut curve = RocPrCurve::default();
        let event = Event::Resize {
            width: 80.0,
            height: 24.0,
        };
        assert!(curve.event(&event).is_none());
    }

    #[test]
    fn test_multiple_curves() {
        let mut curve = RocPrCurve::new(vec![
            CurveData::new("A", vec![0.0, 1.0, 0.0, 1.0], vec![0.1, 0.9, 0.2, 0.8]),
            CurveData::new("B", vec![0.0, 1.0, 0.0, 1.0], vec![0.3, 0.7, 0.4, 0.6]),
        ]);

        let mut buffer = CellBuffer::new(60, 25);
        let mut canvas = DirectTerminalCanvas::new(&mut buffer);

        curve.layout(Rect::new(0.0, 0.0, 60.0, 25.0));
        curve.paint(&mut canvas);
    }
}