//! Confidence calibration tracking: records predictions together with their
//! outcomes and computes Brier score, expected calibration error (ECE),
//! maximum calibration error (MCE), a qualitative diagnosis, and
//! recommendations.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// A single recorded prediction: a confidence score paired with its outcome.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Prediction {
    /// Stated confidence in [0.0, 1.0].
    pub confidence: f32,
    /// Whether the prediction turned out to be correct.
    pub correct: bool,
    /// Optional category label, used for per-category statistics.
    pub category: Option<String>,
    /// Optional timestamp of when the prediction was made.
    pub timestamp: Option<u64>,
    /// Arbitrary key-value metadata.
    pub metadata: HashMap<String, String>,
}

impl Prediction {
    /// Creates a new prediction, clamping `confidence` into [0.0, 1.0].
    pub fn new(confidence: f32, correct: bool) -> Self {
        Self {
            confidence: confidence.clamp(0.0, 1.0),
            correct,
            category: None,
            timestamp: None,
            metadata: HashMap::new(),
        }
    }

    /// Attaches a category label (builder style).
    pub fn with_category(mut self, category: impl Into<String>) -> Self {
        self.category = Some(category.into());
        self
    }

    /// Attaches a timestamp (builder style).
    pub fn with_timestamp(mut self, timestamp: u64) -> Self {
        self.timestamp = Some(timestamp);
        self
    }
}

/// A single histogram bucket of predictions grouped by confidence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfidenceBin {
    /// Inclusive lower edge of the bin.
    pub lower: f32,
    /// Upper edge of the bin: exclusive, except for the last bin.
    pub upper: f32,
    /// Number of predictions that fell into this bin.
    pub count: usize,
    /// Mean confidence of the predictions in this bin.
    pub avg_confidence: f32,
    /// Fraction of predictions in this bin that were correct.
    pub accuracy: f32,
    /// |avg_confidence - accuracy| for this bin.
    pub calibration_error: f32,
}

/// Summary of calibration quality across all recorded predictions.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationReport {
    pub total_predictions: usize,
    pub overall_accuracy: f32,
    pub avg_confidence: f32,
    /// Mean squared error between confidence and outcome (lower is better).
    pub brier_score: f32,
    /// Expected calibration error: count-weighted mean per-bin error.
    pub ece: f32,
    /// Maximum calibration error: worst error over the non-empty bins.
    pub mce: f32,
    pub bins: Vec<ConfidenceBin>,
    pub diagnosis: CalibrationDiagnosis,
    pub recommendations: Vec<String>,
    pub category_stats: HashMap<String, CategoryCalibration>,
}

/// Calibration statistics restricted to a single prediction category.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryCalibration {
    pub count: usize,
    pub accuracy: f32,
    pub avg_confidence: f32,
    pub brier_score: f32,
}

/// Qualitative judgment of how well confidence matches accuracy.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CalibrationDiagnosis {
    WellCalibrated,
    SlightlyOverconfident,
    Overconfident,
    SeverelyOverconfident,
    Underconfident,
    Mixed,
    InsufficientData,
}

impl CalibrationDiagnosis {
    /// Classifies calibration from ECE and the confidence/accuracy gap.
    /// The thresholds are fixed heuristics: a gap above 0.15 sets the
    /// direction, and for overconfidence the ECE sets the severity.
    pub fn from_metrics(ece: f32, avg_confidence: f32, accuracy: f32) -> Self {
        if avg_confidence > accuracy + 0.15 {
            if ece >= 0.20 {
                Self::SeverelyOverconfident
            } else if ece >= 0.10 {
                Self::Overconfident
            } else {
                Self::SlightlyOverconfident
            }
        } else if avg_confidence < accuracy - 0.15 {
            Self::Underconfident
        } else if ece < 0.05 {
            Self::WellCalibrated
        } else if ece < 0.10 {
            // Modest ECE with a small confidence/accuracy gap is treated as
            // slight overconfidence.
            Self::SlightlyOverconfident
        } else {
            Self::Mixed
        }
    }

    /// Short human-readable description of the diagnosis.
    pub fn description(&self) -> &'static str {
        match self {
            Self::WellCalibrated => "Confidence matches accuracy well",
            Self::SlightlyOverconfident => "Slightly too confident in predictions",
            Self::Overconfident => "Significantly overconfident - reduce certainty",
            Self::SeverelyOverconfident => "Severely overconfident - major recalibration needed",
            Self::Underconfident => "Too cautious - can trust predictions more",
            Self::Mixed => "Calibration varies by confidence level",
            Self::InsufficientData => "Not enough data to assess calibration",
        }
    }
}

/// Tuning knobs for the tracker.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationConfig {
    /// Number of equal-width confidence bins for the histogram.
    pub num_bins: usize,
    /// Minimum predictions required before a diagnosis is attempted.
    pub min_predictions: usize,
    /// ECE below which calibration is considered well calibrated.
    pub well_calibrated_threshold: f32,
    /// Whether to compute per-category statistics in reports.
    pub track_categories: bool,
}

impl Default for CalibrationConfig {
    fn default() -> Self {
        Self {
            num_bins: 10,
            min_predictions: 20,
            well_calibrated_threshold: 0.05,
            track_categories: true,
        }
    }
}

/// Accumulates predictions and computes calibration metrics over them.
pub struct CalibrationTracker {
    pub config: CalibrationConfig,
    predictions: Vec<Prediction>,
}

impl CalibrationTracker {
    /// Creates a tracker with the default configuration.
    pub fn new() -> Self {
        Self {
            config: CalibrationConfig::default(),
            predictions: Vec::new(),
        }
    }

    /// Creates a tracker with a custom configuration.
    pub fn with_config(config: CalibrationConfig) -> Self {
        Self {
            config,
            predictions: Vec::new(),
        }
    }

    /// Records a single prediction.
    pub fn record(&mut self, prediction: Prediction) {
        self.predictions.push(prediction);
    }

    /// Records many predictions at once.
    pub fn record_batch(&mut self, predictions: Vec<Prediction>) {
        self.predictions.extend(predictions);
    }

    /// Number of predictions recorded so far.
    pub fn count(&self) -> usize {
        self.predictions.len()
    }

    /// Discards all recorded predictions.
    pub fn clear(&mut self) {
        self.predictions.clear();
    }

    /// Brier score: the mean of (confidence - outcome)^2, where the outcome
    /// is 1.0 for a correct prediction and 0.0 otherwise. 0.0 is perfect.
    pub fn brier_score(&self) -> f32 {
        if self.predictions.is_empty() {
            return 0.0;
        }

        self.predictions
            .iter()
            .map(|p| {
                let outcome = if p.correct { 1.0 } else { 0.0 };
                (p.confidence - outcome).powi(2)
            })
            .sum::<f32>()
            / self.predictions.len() as f32
    }

    /// Splits predictions into `num_bins` equal-width confidence bins.
    fn compute_bins(&self) -> Vec<ConfidenceBin> {
        let num_bins = self.config.num_bins;
        let bin_width = 1.0 / num_bins as f32;

        (0..num_bins)
            .map(|i| {
                let lower = i as f32 * bin_width;
                let upper = (i + 1) as f32 * bin_width;
                // Bins are half-open [lower, upper), except the last bin,
                // which is closed on the right so that a confidence of
                // exactly 1.0 is still counted.
                let is_last = i == num_bins - 1;

                let in_bin: Vec<_> = self
                    .predictions
                    .iter()
                    .filter(|p| {
                        p.confidence >= lower
                            && (p.confidence < upper || (is_last && p.confidence <= upper))
                    })
                    .collect();

                let count = in_bin.len();

                if count == 0 {
                    return ConfidenceBin {
                        lower,
                        upper,
                        count: 0,
                        avg_confidence: (lower + upper) / 2.0,
                        accuracy: 0.0,
                        calibration_error: 0.0,
                    };
                }

                let avg_confidence =
                    in_bin.iter().map(|p| p.confidence).sum::<f32>() / count as f32;
                let accuracy = in_bin.iter().filter(|p| p.correct).count() as f32 / count as f32;
                let calibration_error = (avg_confidence - accuracy).abs();

                ConfidenceBin {
                    lower,
                    upper,
                    count,
                    avg_confidence,
                    accuracy,
                    calibration_error,
                }
            })
            .collect()
    }

    /// Expected calibration error: the count-weighted average of each bin's
    /// |avg_confidence - accuracy|. Values below ~0.05 indicate good
    /// calibration.
    pub fn ece(&self) -> f32 {
        if self.predictions.is_empty() {
            return 0.0;
        }

        let bins = self.compute_bins();
        let total = self.predictions.len() as f32;

        bins.iter()
            .map(|bin| (bin.count as f32 / total) * bin.calibration_error)
            .sum()
    }

    /// Maximum calibration error: the worst |avg_confidence - accuracy|
    /// across all non-empty bins.
    pub fn mce(&self) -> f32 {
        self.compute_bins()
            .iter()
            .filter(|bin| bin.count > 0)
            .map(|bin| bin.calibration_error)
            .max_by(|a, b| a.partial_cmp(b).unwrap())
            .unwrap_or(0.0)
    }

    pub fn accuracy(&self) -> f32 {
        if self.predictions.is_empty() {
            return 0.0;
        }

        self.predictions.iter().filter(|p| p.correct).count() as f32 / self.predictions.len() as f32
    }

    pub fn avg_confidence(&self) -> f32 {
        if self.predictions.is_empty() {
            return 0.0;
        }

        self.predictions.iter().map(|p| p.confidence).sum::<f32>() / self.predictions.len() as f32
    }

    fn compute_category_stats(&self) -> HashMap<String, CategoryCalibration> {
        let mut categories: HashMap<String, Vec<&Prediction>> = HashMap::new();

        for pred in &self.predictions {
            if let Some(ref cat) = pred.category {
                categories.entry(cat.clone()).or_default().push(pred);
            }
        }

        categories
            .into_iter()
            .map(|(cat, preds)| {
                let count = preds.len();
                let accuracy = preds.iter().filter(|p| p.correct).count() as f32 / count as f32;
                let avg_confidence = preds.iter().map(|p| p.confidence).sum::<f32>() / count as f32;
                let brier_score = preds
                    .iter()
                    .map(|p| {
                        let outcome = if p.correct { 1.0 } else { 0.0 };
                        (p.confidence - outcome).powi(2)
                    })
                    .sum::<f32>()
                    / count as f32;

                (
                    cat,
                    CategoryCalibration {
                        count,
                        accuracy,
                        avg_confidence,
                        brier_score,
                    },
                )
            })
            .collect()
    }

    fn generate_recommendations(
        &self,
        diagnosis: CalibrationDiagnosis,
        bins: &[ConfidenceBin],
    ) -> Vec<String> {
        let mut recs = Vec::new();

        match diagnosis {
            CalibrationDiagnosis::SeverelyOverconfident => {
                recs.push("Reduce confidence by 20-30% across all predictions".into());
                recs.push("Add explicit uncertainty language (\"possibly\", \"likely\")".into());
                recs.push("Consider using --paranoid profile for verification".into());
            }
            CalibrationDiagnosis::Overconfident => {
                recs.push("Reduce confidence by 10-20%".into());
                recs.push("Add qualifiers to high-confidence claims".into());
            }
            CalibrationDiagnosis::SlightlyOverconfident => {
                recs.push("Minor confidence adjustment recommended".into());
                recs.push("Focus on claims in 80-100% confidence range".into());
            }
            CalibrationDiagnosis::Underconfident => {
                recs.push("Can trust predictions more".into());
                recs.push("Consider increasing confidence by 10-15%".into());
            }
            CalibrationDiagnosis::Mixed => {
                for bin in bins {
                    if bin.count >= 5
                        && bin.calibration_error > 0.15
                        && bin.avg_confidence > bin.accuracy
                    {
                        recs.push(format!(
                            "For {:.0}%-{:.0}% confidence: reduce by {:.0}%",
                            bin.lower * 100.0,
                            bin.upper * 100.0,
                            bin.calibration_error * 100.0
                        ));
                    }
                }
            }
            CalibrationDiagnosis::WellCalibrated => {
                recs.push("Calibration is good - maintain current approach".into());
            }
            CalibrationDiagnosis::InsufficientData => {
                recs.push("Need more predictions to assess calibration".into());
            }
        }

        recs
    }

    /// Computes all metrics and assembles a full [`CalibrationReport`].
    pub fn generate_report(&self) -> CalibrationReport {
        let bins = self.compute_bins();
        let brier_score = self.brier_score();
        let ece = self.ece();
        let mce = self.mce();
        let overall_accuracy = self.accuracy();
        let avg_confidence = self.avg_confidence();

        let diagnosis = if self.predictions.len() < self.config.min_predictions {
            CalibrationDiagnosis::InsufficientData
        } else {
            CalibrationDiagnosis::from_metrics(ece, avg_confidence, overall_accuracy)
        };

        let recommendations = self.generate_recommendations(diagnosis, &bins);

        let category_stats = if self.config.track_categories {
            self.compute_category_stats()
        } else {
            HashMap::new()
        };

        CalibrationReport {
            total_predictions: self.predictions.len(),
            overall_accuracy,
            avg_confidence,
            brier_score,
            ece,
            mce,
            bins,
            diagnosis,
            recommendations,
            category_stats,
        }
    }
}

impl Default for CalibrationTracker {
    fn default() -> Self {
        Self::new()
    }
}

impl CalibrationReport {
    /// Renders the report as a fixed-width (72-column) text box. Borders and
    /// padding are computed from a single width constant so every row lines
    /// up, instead of relying on per-line hardcoded pad widths.
    pub fn format(&self) -> String {
        // Interior width between the │ borders.
        const WIDTH: usize = 70;
        let top = format!("┌{}┐\n", "─".repeat(WIDTH));
        let sep = format!("├{}┤\n", "─".repeat(WIDTH));
        let bottom = format!("└{}┘\n", "─".repeat(WIDTH));
        // Pads a content line (after the leading "│ ") to the interior width.
        let row = |text: String| format!("│ {:<width$}│\n", text, width = WIDTH - 1);

        let mut output = String::new();

        output.push_str(&top);
        output.push_str(&format!(
            "│{:^width$}│\n",
            "CALIBRATION REPORT",
            width = WIDTH
        ));
        output.push_str(&sep);

        output.push_str(&row(format!(
            "Total Predictions: {}",
            self.total_predictions
        )));
        output.push_str(&row(format!(
            "Overall Accuracy: {:.1}%",
            self.overall_accuracy * 100.0
        )));
        output.push_str(&row(format!(
            "Avg Confidence: {:.1}%",
            self.avg_confidence * 100.0
        )));

        output.push_str(&sep);
        output.push_str(&row("CALIBRATION METRICS".to_string()));
        output.push_str(&row(format!(
            "Brier Score: {:.3} (0=perfect, <0.25 good)",
            self.brier_score
        )));
        output.push_str(&row(format!(
            "ECE: {:.3} (<0.05 well-calibrated)",
            self.ece
        )));
        output.push_str(&row(format!("MCE: {:.3} (worst bin)", self.mce)));

        output.push_str(&sep);
        output.push_str(&row(format!("DIAGNOSIS: {:?}", self.diagnosis)));
        output.push_str(&row(self.diagnosis.description().to_string()));

        output.push_str(&sep);
        output.push_str(&row("CALIBRATION CURVE".to_string()));
        output.push_str(&row("Confidence → Accuracy".to_string()));

        for bin in &self.bins {
            if bin.count > 0 {
                let bar_len = (bin.accuracy * 30.0) as usize;
                let bar = "█".repeat(bar_len);
                let gap = " ".repeat(30 - bar_len);

                let indicator = if bin.calibration_error > 0.15 {
                    "⚠"
                } else if bin.calibration_error > 0.05 {
                    "○"
                } else {
                    "✓"
                };

                output.push_str(&row(format!(
                    "{:>3.0}-{:>3.0}%: {} |{}{}| {:>3.0}% (n={})",
                    bin.lower * 100.0,
                    bin.upper * 100.0,
                    indicator,
                    bar,
                    gap,
                    bin.accuracy * 100.0,
                    bin.count
                )));
            }
        }

        if !self.recommendations.is_empty() {
            output.push_str(&sep);
            output.push_str(&row("RECOMMENDATIONS".to_string()));
            for rec in &self.recommendations {
                output.push_str(&row(format!("• {}", rec)));
            }
        }

        output.push_str(&bottom);

        output
    }
}

/// Platt-style recalibration: sigmoid(a * confidence - b). With a = 1 and
/// b = 0 this reduces to the plain logistic sigmoid; in practice `a` and `b`
/// are fit on held-out labeled data.
pub fn platt_scale(confidence: f32, a: f32, b: f32) -> f32 {
    1.0 / (1.0 + (-a * confidence + b).exp())
}

/// Temperature scaling: sigmoid(logit / temperature). Temperatures above 1.0
/// soften scores (pull them toward 0.5) and below 1.0 sharpen them. Expects
/// a positive temperature.
pub fn temperature_scale(logit: f32, temperature: f32) -> f32 {
    1.0 / (1.0 + (-logit / temperature).exp())
}

/// Applies simple multiplicative corrections to raw confidence scores based
/// on a prior calibration diagnosis.
pub struct ConfidenceAdjuster;

impl ConfidenceAdjuster {
    /// Scales a raw confidence according to the diagnosed miscalibration.
    pub fn adjust(raw_confidence: f32, diagnosis: CalibrationDiagnosis) -> f32 {
        match diagnosis {
            CalibrationDiagnosis::SeverelyOverconfident => raw_confidence * 0.75,
            CalibrationDiagnosis::Overconfident => raw_confidence * 0.85,
            CalibrationDiagnosis::SlightlyOverconfident => raw_confidence * 0.95,
            CalibrationDiagnosis::Underconfident => (raw_confidence * 1.1).min(0.95),
            _ => raw_confidence,
        }
    }

    /// Maps a numeric confidence to a natural-language qualifier.
    pub fn confidence_to_qualifier(confidence: f32) -> &'static str {
        if confidence >= 0.95 {
            "certainly"
        } else if confidence >= 0.85 {
            "very likely"
        } else if confidence >= 0.70 {
            "probably"
        } else if confidence >= 0.50 {
            "possibly"
        } else if confidence >= 0.30 {
            "unlikely"
        } else {
            "very unlikely"
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_perfect_calibration() {
        let mut tracker = CalibrationTracker::new();

        // Nine correct out of ten at 0.9 confidence: accuracy matches confidence.
        for _ in 0..9 {
            tracker.record(Prediction::new(0.9, true));
        }
        tracker.record(Prediction::new(0.9, false));

        // Five correct out of ten at 0.5 confidence.
        for _ in 0..5 {
            tracker.record(Prediction::new(0.5, true));
        }
        for _ in 0..5 {
            tracker.record(Prediction::new(0.5, false));
        }

        let report = tracker.generate_report();
        assert!(report.ece < 0.15);
    }
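
    // Regression check (added sketch) for the inclusive upper edge of the
    // last bin: predictions at exactly 1.0 confidence must be counted. Ten
    // wrong predictions at 1.0 put all mass in the top bin with accuracy 0.0,
    // driving ECE toward 1.0; if they were silently dropped, ECE would read 0.
    #[test]
    fn test_full_confidence_is_binned() {
        let mut tracker = CalibrationTracker::new();
        for _ in 0..10 {
            tracker.record(Prediction::new(1.0, false));
        }
        assert!(tracker.ece() > 0.9);
    }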

    #[test]
    fn test_overconfident() {
        let mut tracker = CalibrationTracker::new();

        for _ in 0..25 {
            tracker.record(Prediction::new(0.9, true));
            tracker.record(Prediction::new(0.9, false));
        }

        let report = tracker.generate_report();
        assert!(matches!(
            report.diagnosis,
            CalibrationDiagnosis::Overconfident | CalibrationDiagnosis::SeverelyOverconfident
        ));
    }
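
    // Added sketch for the opposite failure mode: stated confidence well
    // below realized accuracy should be diagnosed as Underconfident
    // (avg confidence 0.5 vs accuracy 0.9; 20 predictions meets the default
    // `min_predictions`, so a diagnosis is actually attempted).
    #[test]
    fn test_underconfident() {
        let mut tracker = CalibrationTracker::new();
        for _ in 0..18 {
            tracker.record(Prediction::new(0.5, true));
        }
        for _ in 0..2 {
            tracker.record(Prediction::new(0.5, false));
        }
        let report = tracker.generate_report();
        assert_eq!(report.diagnosis, CalibrationDiagnosis::Underconfident);
    }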

    #[test]
    fn test_brier_score() {
        let mut tracker = CalibrationTracker::new();

        // Confidence exactly matching the outcome (1.0/correct, 0.0/incorrect)
        // gives a Brier score of zero.
        tracker.record(Prediction::new(1.0, true));
        tracker.record(Prediction::new(0.0, false));

        let brier = tracker.brier_score();
        assert!((brier - 0.0).abs() < 0.01);
    }
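
    // Worked ECE check (added sketch): ten predictions at 0.7 confidence with
    // seven correct all land in one bin whose average confidence equals its
    // accuracy, so the expected calibration error is essentially zero.
    #[test]
    fn test_ece_matched_bin() {
        let mut tracker = CalibrationTracker::new();
        for _ in 0..7 {
            tracker.record(Prediction::new(0.7, true));
        }
        for _ in 0..3 {
            tracker.record(Prediction::new(0.7, false));
        }
        assert!(tracker.ece() < 0.01);
    }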

    #[test]
    fn test_category_tracking() {
        let mut tracker = CalibrationTracker::with_config(CalibrationConfig {
            track_categories: true,
            ..Default::default()
        });

        tracker.record(Prediction::new(0.8, true).with_category("math"));
        tracker.record(Prediction::new(0.7, true).with_category("math"));
        tracker.record(Prediction::new(0.9, false).with_category("logic"));

        let report = tracker.generate_report();
        assert!(report.category_stats.contains_key("math"));
        assert_eq!(report.category_stats["math"].count, 2);
    }
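
    // Added sketch: below the default `min_predictions` threshold (20), the
    // report should decline to diagnose rather than guess from sparse data.
    #[test]
    fn test_insufficient_data() {
        let mut tracker = CalibrationTracker::new();
        tracker.record(Prediction::new(0.9, true));
        tracker.record(Prediction::new(0.2, false));
        let report = tracker.generate_report();
        assert_eq!(report.diagnosis, CalibrationDiagnosis::InsufficientData);
    }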

    #[test]
    fn test_confidence_adjuster() {
        let adjusted = ConfidenceAdjuster::adjust(0.9, CalibrationDiagnosis::SeverelyOverconfident);
        assert!((adjusted - 0.675).abs() < 0.01);

        let qualifier = ConfidenceAdjuster::confidence_to_qualifier(0.85);
        assert_eq!(qualifier, "very likely");
    }
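
    // Sanity checks on the recalibration transforms (illustrative values):
    // with temperature 1.0, temperature scaling is the plain logistic
    // sigmoid, and a larger temperature pulls the output toward 0.5; Platt
    // scaling with a = 1, b = 0 likewise reduces to sigmoid(confidence).
    #[test]
    fn test_scaling_helpers() {
        let sharp = temperature_scale(2.0, 1.0); // sigmoid(2.0) ≈ 0.8808
        let soft = temperature_scale(2.0, 2.0); // sigmoid(1.0) ≈ 0.7311
        assert!((sharp - 0.8808).abs() < 0.001);
        assert!(sharp > soft && soft > 0.5);

        let p = platt_scale(0.9, 1.0, 0.0); // sigmoid(0.9) ≈ 0.7109
        assert!((p - 0.7109).abs() < 0.001);
    }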
}