// reasonkit/thinktool/calibration.rs

//! # Metacognitive Calibration System
//!
//! Tracks how well stated confidence matches observed accuracy, and provides helpers for
//! recalibrating confidence when the two diverge.
//!
//! ## Scientific Foundation
//!
//! Based on:
//! - Brier score for probabilistic forecasting
//! - Expected Calibration Error (ECE)
//! - Metacognitive sensitivity research
//!
//! ## Core Concept
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────────┐
//! │                 CALIBRATION = CONFIDENCE ≈ ACCURACY                 │
//! ├─────────────────────────────────────────────────────────────────────┤
//! │                                                                     │
//! │   Perfect Calibration:                                              │
//! │   • 90% confident claims are correct 90% of the time                │
//! │   • 70% confident claims are correct 70% of the time                │
//! │   • 50% confident claims are correct 50% of the time                │
//! │                                                                     │
//! │   Overconfidence (common):                                          │
//! │   • 90% confident but only 60% accurate → needs recalibration       │
//! │                                                                     │
//! │   Underconfidence (rare):                                           │
//! │   • 50% confident but 80% accurate → can trust more                 │
//! │                                                                     │
//! └─────────────────────────────────────────────────────────────────────┘
//! ```
//!
//! ## Key Metrics
//!
//! - **Brier Score**: Mean squared error of probabilistic predictions (0 = perfect)
//! - **ECE**: Expected calibration error across confidence bins
//! - **MCE**: Maximum calibration error (worst bin)
//! - **meta-d'**: Metacognitive sensitivity measure (not computed by this module)
//!
//! ## Usage
//!
//! ```rust,ignore
//! use reasonkit::thinktool::calibration::{CalibrationTracker, Prediction};
//!
//! let mut tracker = CalibrationTracker::new();
//!
//! tracker.record(Prediction::new(0.9, true)); // 90% confident, was correct
//! tracker.record(Prediction::new(0.8, false)); // 80% confident, was wrong
//!
//! let report = tracker.generate_report();
//! println!("Brier Score: {:.3}", report.brier_score);
//! ```
54
55use serde::{Deserialize, Serialize};
56use std::collections::HashMap;
57
58/// A single prediction with confidence and outcome
59#[derive(Debug, Clone, Serialize, Deserialize)]
60pub struct Prediction {
61    /// Confidence level (0.0 - 1.0)
62    pub confidence: f32,
63    /// Was the prediction correct?
64    pub correct: bool,
65    /// Category/domain of prediction
66    pub category: Option<String>,
67    /// Timestamp if tracking over time
68    pub timestamp: Option<u64>,
69    /// Additional metadata
70    pub metadata: HashMap<String, String>,
71}
72
73impl Prediction {
74    pub fn new(confidence: f32, correct: bool) -> Self {
75        Self {
76            confidence: confidence.clamp(0.0, 1.0),
77            correct,
78            category: None,
79            timestamp: None,
80            metadata: HashMap::new(),
81        }
82    }
83
84    pub fn with_category(mut self, category: impl Into<String>) -> Self {
85        self.category = Some(category.into());
86        self
87    }
88
89    pub fn with_timestamp(mut self, timestamp: u64) -> Self {
90        self.timestamp = Some(timestamp);
91        self
92    }
93}
94
/// Aggregated statistics for one confidence range, used to build the calibration curve.
///
/// Empty bins are still reported (with `count == 0`); for those, `avg_confidence` holds the
/// bin midpoint and `accuracy` / `calibration_error` are 0.0.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfidenceBin {
    /// Bin lower bound
    pub lower: f32,
    /// Bin upper bound
    pub upper: f32,
    /// Number of predictions in bin
    pub count: usize,
    /// Average confidence of the predictions in the bin
    pub avg_confidence: f32,
    /// Fraction of the predictions in the bin that were correct
    pub accuracy: f32,
    /// Calibration error for this bin (|avg_confidence - accuracy|)
    pub calibration_error: f32,
}
111
/// Full calibration report produced by `CalibrationTracker::generate_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationReport {
    /// Total predictions analyzed
    pub total_predictions: usize,
    /// Overall accuracy (fraction of predictions that were correct; 0.0 when empty)
    pub overall_accuracy: f32,
    /// Average confidence across all predictions (0.0 when empty)
    pub avg_confidence: f32,
    /// Brier score (lower is better, 0 = perfect)
    pub brier_score: f32,
    /// Expected Calibration Error: count-weighted mean of the per-bin calibration errors
    pub ece: f32,
    /// Maximum Calibration Error: largest calibration error among non-empty bins
    pub mce: f32,
    /// Per-bin statistics (empty bins included)
    pub bins: Vec<ConfidenceBin>,
    /// Calibration diagnosis
    pub diagnosis: CalibrationDiagnosis,
    /// Human-readable recommendations derived from the diagnosis
    pub recommendations: Vec<String>,
    /// Per-category stats keyed by category name; empty when category tracking is disabled
    pub category_stats: HashMap<String, CategoryCalibration>,
}
136
/// Calibration summary for the predictions that share one `Prediction::category`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CategoryCalibration {
    /// Number of predictions in the category
    pub count: usize,
    /// Fraction of the category's predictions that were correct
    pub accuracy: f32,
    /// Mean confidence of the category's predictions
    pub avg_confidence: f32,
    /// Brier score (mean squared error of confidence vs. 0/1 outcome) within the category
    pub brier_score: f32,
}
144
145#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
146pub enum CalibrationDiagnosis {
147    /// Well calibrated (ECE < 0.05)
148    WellCalibrated,
149    /// Slightly overconfident (ECE < 0.10)
150    SlightlyOverconfident,
151    /// Significantly overconfident (ECE < 0.20)
152    Overconfident,
153    /// Severely overconfident (ECE >= 0.20)
154    SeverelyOverconfident,
155    /// Underconfident
156    Underconfident,
157    /// Mixed calibration issues
158    Mixed,
159    /// Not enough data
160    InsufficientData,
161}
162
163impl CalibrationDiagnosis {
164    pub fn from_metrics(ece: f32, avg_confidence: f32, accuracy: f32) -> Self {
165        if avg_confidence > accuracy + 0.15 {
166            if ece >= 0.20 {
167                Self::SeverelyOverconfident
168            } else if ece >= 0.10 {
169                Self::Overconfident
170            } else {
171                Self::SlightlyOverconfident
172            }
173        } else if avg_confidence < accuracy - 0.15 {
174            Self::Underconfident
175        } else if ece < 0.05 {
176            Self::WellCalibrated
177        } else if ece < 0.10 {
178            Self::SlightlyOverconfident
179        } else {
180            Self::Mixed
181        }
182    }
183
184    pub fn description(&self) -> &'static str {
185        match self {
186            Self::WellCalibrated => "Confidence matches accuracy well",
187            Self::SlightlyOverconfident => "Slightly too confident in predictions",
188            Self::Overconfident => "Significantly overconfident - reduce certainty",
189            Self::SeverelyOverconfident => "Severely overconfident - major recalibration needed",
190            Self::Underconfident => "Too cautious - can trust predictions more",
191            Self::Mixed => "Calibration varies by confidence level",
192            Self::InsufficientData => "Not enough data to assess calibration",
193        }
194    }
195}
196
197/// Configuration for calibration tracking
198#[derive(Debug, Clone, Serialize, Deserialize)]
199pub struct CalibrationConfig {
200    /// Number of confidence bins
201    pub num_bins: usize,
202    /// Minimum predictions for valid analysis
203    pub min_predictions: usize,
204    /// ECE threshold for "well calibrated"
205    pub well_calibrated_threshold: f32,
206    /// Track per-category stats
207    pub track_categories: bool,
208}
209
210impl Default for CalibrationConfig {
211    fn default() -> Self {
212        Self {
213            num_bins: 10,
214            min_predictions: 20,
215            well_calibrated_threshold: 0.05,
216            track_categories: true,
217        }
218    }
219}
220
221/// Calibration tracker
222pub struct CalibrationTracker {
223    pub config: CalibrationConfig,
224    predictions: Vec<Prediction>,
225}
226
227impl CalibrationTracker {
228    pub fn new() -> Self {
229        Self {
230            config: CalibrationConfig::default(),
231            predictions: Vec::new(),
232        }
233    }
234
235    pub fn with_config(config: CalibrationConfig) -> Self {
236        Self {
237            config,
238            predictions: Vec::new(),
239        }
240    }
241
242    /// Record a prediction
243    pub fn record(&mut self, prediction: Prediction) {
244        self.predictions.push(prediction);
245    }
246
247    /// Record multiple predictions
248    pub fn record_batch(&mut self, predictions: Vec<Prediction>) {
249        self.predictions.extend(predictions);
250    }
251
252    /// Get number of predictions
253    pub fn count(&self) -> usize {
254        self.predictions.len()
255    }
256
257    /// Clear all predictions
258    pub fn clear(&mut self) {
259        self.predictions.clear();
260    }
261
262    /// Compute Brier score
263    pub fn brier_score(&self) -> f32 {
264        if self.predictions.is_empty() {
265            return 0.0;
266        }
267
268        self.predictions
269            .iter()
270            .map(|p| {
271                let outcome = if p.correct { 1.0 } else { 0.0 };
272                (p.confidence - outcome).powi(2)
273            })
274            .sum::<f32>()
275            / self.predictions.len() as f32
276    }
277
278    /// Compute binned calibration
279    fn compute_bins(&self) -> Vec<ConfidenceBin> {
280        let num_bins = self.config.num_bins;
281        let bin_width = 1.0 / num_bins as f32;
282
283        (0..num_bins)
284            .map(|i| {
285                let lower = i as f32 * bin_width;
286                let upper = (i + 1) as f32 * bin_width;
287
288                let in_bin: Vec<_> = self
289                    .predictions
290                    .iter()
291                    .filter(|p| p.confidence >= lower && p.confidence < upper.min(1.001))
292                    .collect();
293
294                let count = in_bin.len();
295
296                if count == 0 {
297                    return ConfidenceBin {
298                        lower,
299                        upper,
300                        count: 0,
301                        avg_confidence: (lower + upper) / 2.0,
302                        accuracy: 0.0,
303                        calibration_error: 0.0,
304                    };
305                }
306
307                let avg_confidence =
308                    in_bin.iter().map(|p| p.confidence).sum::<f32>() / count as f32;
309                let accuracy = in_bin.iter().filter(|p| p.correct).count() as f32 / count as f32;
310                let calibration_error = (avg_confidence - accuracy).abs();
311
312                ConfidenceBin {
313                    lower,
314                    upper,
315                    count,
316                    avg_confidence,
317                    accuracy,
318                    calibration_error,
319                }
320            })
321            .collect()
322    }
323
324    /// Compute Expected Calibration Error
325    pub fn ece(&self) -> f32 {
326        if self.predictions.is_empty() {
327            return 0.0;
328        }
329
330        let bins = self.compute_bins();
331        let total = self.predictions.len() as f32;
332
333        bins.iter()
334            .map(|bin| (bin.count as f32 / total) * bin.calibration_error)
335            .sum()
336    }
337
338    /// Compute Maximum Calibration Error
339    pub fn mce(&self) -> f32 {
340        self.compute_bins()
341            .iter()
342            .filter(|bin| bin.count > 0)
343            .map(|bin| bin.calibration_error)
344            .max_by(|a, b| a.partial_cmp(b).unwrap())
345            .unwrap_or(0.0)
346    }
347
348    /// Overall accuracy
349    pub fn accuracy(&self) -> f32 {
350        if self.predictions.is_empty() {
351            return 0.0;
352        }
353
354        self.predictions.iter().filter(|p| p.correct).count() as f32 / self.predictions.len() as f32
355    }
356
357    /// Average confidence
358    pub fn avg_confidence(&self) -> f32 {
359        if self.predictions.is_empty() {
360            return 0.0;
361        }
362
363        self.predictions.iter().map(|p| p.confidence).sum::<f32>() / self.predictions.len() as f32
364    }
365
366    /// Compute per-category stats
367    fn compute_category_stats(&self) -> HashMap<String, CategoryCalibration> {
368        let mut categories: HashMap<String, Vec<&Prediction>> = HashMap::new();
369
370        for pred in &self.predictions {
371            if let Some(ref cat) = pred.category {
372                categories.entry(cat.clone()).or_default().push(pred);
373            }
374        }
375
376        categories
377            .into_iter()
378            .map(|(cat, preds)| {
379                let count = preds.len();
380                let accuracy = preds.iter().filter(|p| p.correct).count() as f32 / count as f32;
381                let avg_confidence = preds.iter().map(|p| p.confidence).sum::<f32>() / count as f32;
382                let brier_score = preds
383                    .iter()
384                    .map(|p| {
385                        let outcome = if p.correct { 1.0 } else { 0.0 };
386                        (p.confidence - outcome).powi(2)
387                    })
388                    .sum::<f32>()
389                    / count as f32;
390
391                (
392                    cat,
393                    CategoryCalibration {
394                        count,
395                        accuracy,
396                        avg_confidence,
397                        brier_score,
398                    },
399                )
400            })
401            .collect()
402    }
403
404    /// Generate recommendations
405    fn generate_recommendations(
406        &self,
407        diagnosis: CalibrationDiagnosis,
408        bins: &[ConfidenceBin],
409    ) -> Vec<String> {
410        let mut recs = Vec::new();
411
412        match diagnosis {
413            CalibrationDiagnosis::SeverelyOverconfident => {
414                recs.push("Reduce confidence by 20-30% across all predictions".into());
415                recs.push("Add explicit uncertainty language (\"possibly\", \"likely\")".into());
416                recs.push("Consider using --paranoid profile for verification".into());
417            }
418            CalibrationDiagnosis::Overconfident => {
419                recs.push("Reduce confidence by 10-20%".into());
420                recs.push("Add qualifiers to high-confidence claims".into());
421            }
422            CalibrationDiagnosis::SlightlyOverconfident => {
423                recs.push("Minor confidence adjustment recommended".into());
424                recs.push("Focus on claims in 80-100% confidence range".into());
425            }
426            CalibrationDiagnosis::Underconfident => {
427                recs.push("Can trust predictions more".into());
428                recs.push("Consider increasing confidence by 10-15%".into());
429            }
430            CalibrationDiagnosis::Mixed => {
431                // Find problematic bins
432                for bin in bins {
433                    if bin.count >= 5
434                        && bin.calibration_error > 0.15
435                        && bin.avg_confidence > bin.accuracy
436                    {
437                        recs.push(format!(
438                            "For {:.0}%-{:.0}% confidence: reduce by {:.0}%",
439                            bin.lower * 100.0,
440                            bin.upper * 100.0,
441                            bin.calibration_error * 100.0
442                        ));
443                    }
444                }
445            }
446            CalibrationDiagnosis::WellCalibrated => {
447                recs.push("Calibration is good - maintain current approach".into());
448            }
449            CalibrationDiagnosis::InsufficientData => {
450                recs.push("Need more predictions to assess calibration".into());
451            }
452        }
453
454        recs
455    }
456
457    /// Generate full calibration report
458    pub fn generate_report(&self) -> CalibrationReport {
459        let bins = self.compute_bins();
460        let brier_score = self.brier_score();
461        let ece = self.ece();
462        let mce = self.mce();
463        let overall_accuracy = self.accuracy();
464        let avg_confidence = self.avg_confidence();
465
466        let diagnosis = if self.predictions.len() < self.config.min_predictions {
467            CalibrationDiagnosis::InsufficientData
468        } else {
469            CalibrationDiagnosis::from_metrics(ece, avg_confidence, overall_accuracy)
470        };
471
472        let recommendations = self.generate_recommendations(diagnosis, &bins);
473
474        let category_stats = if self.config.track_categories {
475            self.compute_category_stats()
476        } else {
477            HashMap::new()
478        };
479
480        CalibrationReport {
481            total_predictions: self.predictions.len(),
482            overall_accuracy,
483            avg_confidence,
484            brier_score,
485            ece,
486            mce,
487            bins,
488            diagnosis,
489            recommendations,
490            category_stats,
491        }
492    }
493}
494
495impl Default for CalibrationTracker {
496    fn default() -> Self {
497        Self::new()
498    }
499}
500
501impl CalibrationReport {
502    /// Format as a readable report
503    pub fn format(&self) -> String {
504        let mut output = String::new();
505
506        output
507            .push_str("┌─────────────────────────────────────────────────────────────────────┐\n");
508        output
509            .push_str("│                    CALIBRATION REPORT                               │\n");
510        output
511            .push_str("├─────────────────────────────────────────────────────────────────────┤\n");
512
513        output.push_str(&format!(
514            "│ Total Predictions: {:<50}│\n",
515            self.total_predictions
516        ));
517        output.push_str(&format!(
518            "│ Overall Accuracy:  {:.1}%{:>45}│\n",
519            self.overall_accuracy * 100.0,
520            ""
521        ));
522        output.push_str(&format!(
523            "│ Avg Confidence:    {:.1}%{:>45}│\n",
524            self.avg_confidence * 100.0,
525            ""
526        ));
527
528        output
529            .push_str("├─────────────────────────────────────────────────────────────────────┤\n");
530        output
531            .push_str("│ CALIBRATION METRICS                                                 │\n");
532        output.push_str(&format!(
533            "│   Brier Score: {:.3} (0=perfect, <0.25 good){:>21}│\n",
534            self.brier_score, ""
535        ));
536        output.push_str(&format!(
537            "│   ECE:         {:.3} (<0.05 well-calibrated){:>21}│\n",
538            self.ece, ""
539        ));
540        output.push_str(&format!(
541            "│   MCE:         {:.3} (worst bin){:>33}│\n",
542            self.mce, ""
543        ));
544
545        output
546            .push_str("├─────────────────────────────────────────────────────────────────────┤\n");
547        output.push_str(&format!("│ DIAGNOSIS: {:?} {:>42}│\n", self.diagnosis, ""));
548        output.push_str(&format!(
549            "│   {}{:>52}│\n",
550            self.diagnosis.description(),
551            ""
552        ));
553
554        // Confidence bins visualization
555        output
556            .push_str("├─────────────────────────────────────────────────────────────────────┤\n");
557        output
558            .push_str("│ CALIBRATION CURVE                                                   │\n");
559        output
560            .push_str("│   Confidence → Accuracy                                             │\n");
561
562        for bin in &self.bins {
563            if bin.count > 0 {
564                let bar_len = (bin.accuracy * 30.0) as usize;
565                let bar = "█".repeat(bar_len);
566                let gap = " ".repeat(30 - bar_len);
567
568                let indicator = if bin.calibration_error > 0.15 {
569                    "⚠"
570                } else if bin.calibration_error > 0.05 {
571                    "○"
572                } else {
573                    "✓"
574                };
575
576                output.push_str(&format!(
577                    "│   {:.0}-{:.0}%: {} |{}{}| {:.0}% (n={}){}│\n",
578                    bin.lower * 100.0,
579                    bin.upper * 100.0,
580                    indicator,
581                    bar,
582                    gap,
583                    bin.accuracy * 100.0,
584                    bin.count,
585                    " ".repeat(10)
586                ));
587            }
588        }
589
590        // Recommendations
591        if !self.recommendations.is_empty() {
592            output.push_str(
593                "├─────────────────────────────────────────────────────────────────────┤\n",
594            );
595            output.push_str(
596                "│ RECOMMENDATIONS                                                     │\n",
597            );
598            for rec in &self.recommendations {
599                output.push_str(&format!("│   • {:<62}│\n", rec));
600            }
601        }
602
603        output
604            .push_str("└─────────────────────────────────────────────────────────────────────┘\n");
605
606        output
607    }
608}
609
/// Recalibrate a confidence with Platt scaling.
///
/// Returns `σ(a · confidence − b)`, where `σ(z) = 1 / (1 + e^(−z))` is the logistic function.
/// `a` and `b` are the fitted slope and offset.
pub fn platt_scale(confidence: f32, a: f32, b: f32) -> f32 {
    let z = a * confidence - b;
    (1.0 + (-z).exp()).recip()
}
614
/// Recalibrate a raw logit with temperature scaling: `σ(logit / temperature)`.
///
/// A temperature above 1 pulls probabilities toward 0.5; below 1 sharpens them. A temperature
/// of 0 yields 0.0 or 1.0 for non-zero logits and NaN for a zero logit.
pub fn temperature_scale(logit: f32, temperature: f32) -> f32 {
    let scaled = logit / temperature;
    (1.0 + (-scaled).exp()).recip()
}
619
620/// Confidence adjustment recommendations
621pub struct ConfidenceAdjuster;
622
623impl ConfidenceAdjuster {
624    /// Adjust confidence based on calibration data
625    pub fn adjust(raw_confidence: f32, diagnosis: CalibrationDiagnosis) -> f32 {
626        match diagnosis {
627            CalibrationDiagnosis::SeverelyOverconfident => {
628                // Reduce by 25%
629                raw_confidence * 0.75
630            }
631            CalibrationDiagnosis::Overconfident => {
632                // Reduce by 15%
633                raw_confidence * 0.85
634            }
635            CalibrationDiagnosis::SlightlyOverconfident => {
636                // Reduce by 5%
637                raw_confidence * 0.95
638            }
639            CalibrationDiagnosis::Underconfident => {
640                // Increase by 10% (but cap at 0.95)
641                (raw_confidence * 1.1).min(0.95)
642            }
643            _ => raw_confidence,
644        }
645    }
646
647    /// Convert confidence to appropriate qualifier
648    pub fn confidence_to_qualifier(confidence: f32) -> &'static str {
649        if confidence >= 0.95 {
650            "certainly"
651        } else if confidence >= 0.85 {
652            "very likely"
653        } else if confidence >= 0.70 {
654            "probably"
655        } else if confidence >= 0.50 {
656            "possibly"
657        } else if confidence >= 0.30 {
658            "unlikely"
659        } else {
660            "very unlikely"
661        }
662    }
663}
664
#[cfg(test)]
mod tests {
    use super::*;

    /// Push `n` identical predictions into `tracker`.
    fn record_repeated(tracker: &mut CalibrationTracker, n: usize, confidence: f32, correct: bool) {
        (0..n).for_each(|_| tracker.record(Prediction::new(confidence, correct)));
    }

    #[test]
    fn test_perfect_calibration() {
        let mut tracker = CalibrationTracker::new();

        // 90% confident and right nine times out of ten.
        record_repeated(&mut tracker, 9, 0.9, true);
        record_repeated(&mut tracker, 1, 0.9, false);
        // 50% confident and right half of the time.
        record_repeated(&mut tracker, 5, 0.5, true);
        record_repeated(&mut tracker, 5, 0.5, false);

        let ece = tracker.generate_report().ece;
        assert!(ece < 0.15, "expected a reasonably calibrated tracker, got ECE = {}", ece);
    }

    #[test]
    fn test_overconfident() {
        let mut tracker = CalibrationTracker::new();

        // Always 90% confident, but outcomes alternate right/wrong: only 50% accuracy.
        for i in 0..50 {
            tracker.record(Prediction::new(0.9, i % 2 == 0));
        }

        let diagnosis = tracker.generate_report().diagnosis;
        assert!(
            matches!(
                diagnosis,
                CalibrationDiagnosis::Overconfident | CalibrationDiagnosis::SeverelyOverconfident
            ),
            "unexpected diagnosis: {:?}",
            diagnosis
        );
    }

    #[test]
    fn test_brier_score() {
        let mut tracker = CalibrationTracker::new();

        // Fully confident and right; zero confidence and (rightly) wrong.
        tracker.record_batch(vec![Prediction::new(1.0, true), Prediction::new(0.0, false)]);

        assert!(tracker.brier_score().abs() < 0.01);
    }

    #[test]
    fn test_category_tracking() {
        let config = CalibrationConfig {
            track_categories: true,
            ..Default::default()
        };
        let mut tracker = CalibrationTracker::with_config(config);

        tracker.record_batch(vec![
            Prediction::new(0.8, true).with_category("math"),
            Prediction::new(0.7, true).with_category("math"),
            Prediction::new(0.9, false).with_category("logic"),
        ]);

        let stats = tracker.generate_report().category_stats;
        let math = stats.get("math").expect("math category should be tracked");
        assert_eq!(math.count, 2);
    }

    #[test]
    fn test_confidence_adjuster() {
        let adjusted =
            ConfidenceAdjuster::adjust(0.9, CalibrationDiagnosis::SeverelyOverconfident);
        assert!((adjusted - 0.675).abs() < 0.01);

        assert_eq!(ConfidenceAdjuster::confidence_to_qualifier(0.85), "very likely");
    }
}