// reasonkit/thinktool/quality.rs

//! # Quality Metrics Collection System
//!
//! Comprehensive quality tracking across all ThinkTools modules.
//! Provides dashboards, trends, and improvement recommendations.
//!
//! ## Metrics Collected
//!
//! | Category | Metrics |
//! |----------|---------|
//! | Accuracy | Benchmark scores (e.g. GSM8K, MATH, ARC-C, TruthfulQA) |
//! | Calibration | Brier score, ECE, overconfidence ratio |
//! | Reasoning | PRM step correctness, critical issues, sound-chain rate |
//! | Exploration | Tree-of-Thoughts success rate, depth, nodes explored, pruning rate |
//! | Verification | Triangulation verification rate, average sources, contradiction rate |
//! | Debate | Advocate win rate, argument strength, consensus rate |
//! | Argumentation | Toulmin soundness rate, grounds and warrant scores |
//! | Performance | Latency (average, p95, p99) |
//! | Efficiency | Token usage and efficiency |
//! | Custom | User-defined named values |
//!
//! ## Usage
//!
//! ```rust,ignore
//! use reasonkit::thinktool::quality::{QualityDashboard, QualityMetric};
//!
//! let mut dashboard = QualityDashboard::new();
//! dashboard.record_metric(QualityMetric::Accuracy {
//!     benchmark: "GSM8K".into(),
//!     score: 0.859,
//!     samples: 100,
//! });
//! dashboard.record_metric(QualityMetric::Calibration {
//!     brier_score: 0.15,
//!     ece: 0.08,
//!     overconfidence_ratio: 0.1,
//! });
//!
//! let report = dashboard.generate_report();
//! println!("{}", report.format());
//! ```

use serde::{Deserialize, Serialize};
use std::cmp::Ordering;
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};

33/// Individual quality metric
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub enum QualityMetric {
36    /// Benchmark accuracy scores
37    Accuracy {
38        benchmark: String,
39        score: f32,
40        samples: usize,
41    },
42    /// Calibration metrics
43    Calibration {
44        brier_score: f32,
45        ece: f32,
46        overconfidence_ratio: f32,
47    },
48    /// Process Reward Model metrics
49    PrmScore {
50        avg_step_correctness: f32,
51        critical_issues: usize,
52        sound_chains: f32,
53    },
54    /// Tree-of-Thoughts metrics
55    TotMetrics {
56        success_rate: f32,
57        avg_depth: f32,
58        nodes_explored: usize,
59        pruning_rate: f32,
60    },
61    /// Triangulation metrics
62    Triangulation {
63        verification_rate: f32,
64        avg_sources: f32,
65        contradiction_rate: f32,
66    },
67    /// Debate metrics
68    Debate {
69        advocate_win_rate: f32,
70        avg_argument_strength: f32,
71        consensus_rate: f32,
72    },
73    /// Toulmin argument quality
74    Argumentation {
75        soundness_rate: f32,
76        avg_grounds_score: f32,
77        avg_warrant_score: f32,
78    },
79    /// Latency metrics
80    Latency {
81        avg_ms: f64,
82        p95_ms: f64,
83        p99_ms: f64,
84    },
85    /// Token usage
86    TokenUsage {
87        avg_tokens: usize,
88        total_tokens: usize,
89        efficiency: f32,
90    },
91    /// Custom metric
92    Custom {
93        name: String,
94        value: f32,
95        unit: Option<String>,
96    },
97}
98
99impl QualityMetric {
100    pub fn category(&self) -> &'static str {
101        match self {
102            QualityMetric::Accuracy { .. } => "accuracy",
103            QualityMetric::Calibration { .. } => "calibration",
104            QualityMetric::PrmScore { .. } => "reasoning",
105            QualityMetric::TotMetrics { .. } => "exploration",
106            QualityMetric::Triangulation { .. } => "verification",
107            QualityMetric::Debate { .. } => "debate",
108            QualityMetric::Argumentation { .. } => "argumentation",
109            QualityMetric::Latency { .. } => "performance",
110            QualityMetric::TokenUsage { .. } => "efficiency",
111            QualityMetric::Custom { .. } => "custom",
112        }
113    }
114
115    pub fn primary_value(&self) -> f32 {
116        match self {
117            QualityMetric::Accuracy { score, .. } => *score,
118            QualityMetric::Calibration { brier_score, .. } => 1.0 - *brier_score, // Invert (lower is better)
119            QualityMetric::PrmScore {
120                avg_step_correctness,
121                ..
122            } => *avg_step_correctness,
123            QualityMetric::TotMetrics { success_rate, .. } => *success_rate,
124            QualityMetric::Triangulation {
125                verification_rate, ..
126            } => *verification_rate,
127            QualityMetric::Debate {
128                avg_argument_strength,
129                ..
130            } => *avg_argument_strength,
131            QualityMetric::Argumentation { soundness_rate, .. } => *soundness_rate,
132            QualityMetric::Latency { avg_ms, .. } => (1000.0 / *avg_ms as f32).min(1.0), // Faster = better
133            QualityMetric::TokenUsage { efficiency, .. } => *efficiency,
134            QualityMetric::Custom { value, .. } => *value,
135        }
136    }
137}
138
139/// Timestamped metric record
140#[derive(Debug, Clone, Serialize, Deserialize)]
141pub struct MetricRecord {
142    pub metric: QualityMetric,
143    pub timestamp: u64,
144    pub profile: Option<String>,
145    pub session_id: Option<String>,
146}
147
148impl MetricRecord {
149    pub fn new(metric: QualityMetric) -> Self {
150        let timestamp = SystemTime::now()
151            .duration_since(UNIX_EPOCH)
152            .map(|d| d.as_secs())
153            .unwrap_or(0);
154
155        Self {
156            metric,
157            timestamp,
158            profile: None,
159            session_id: None,
160        }
161    }
162
163    pub fn with_profile(mut self, profile: impl Into<String>) -> Self {
164        self.profile = Some(profile.into());
165        self
166    }
167}
168
169/// Quality targets for comparison
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct QualityTargets {
172    /// Accuracy targets per benchmark
173    pub accuracy: HashMap<String, f32>,
174    /// Maximum acceptable Brier score
175    pub max_brier_score: f32,
176    /// Maximum acceptable ECE
177    pub max_ece: f32,
178    /// Minimum PRM step correctness
179    pub min_prm_correctness: f32,
180    /// Minimum ToT success rate
181    pub min_tot_success: f32,
182    /// Minimum triangulation verification rate
183    pub min_triangulation: f32,
184    /// Maximum acceptable latency (ms)
185    pub max_latency_ms: f64,
186}
187
188impl Default for QualityTargets {
189    fn default() -> Self {
190        let mut accuracy = HashMap::new();
191        accuracy.insert("GSM8K".into(), 0.859);
192        accuracy.insert("MATH".into(), 0.365);
193        accuracy.insert("ARC-C".into(), 0.90);
194        accuracy.insert("TruthfulQA".into(), 0.72);
195
196        Self {
197            accuracy,
198            max_brier_score: 0.20,
199            max_ece: 0.10,
200            min_prm_correctness: 0.80,
201            min_tot_success: 0.60,
202            min_triangulation: 0.70,
203            max_latency_ms: 5000.0,
204        }
205    }
206}
207
208/// Quality score aggregation
209#[derive(Debug, Clone, Serialize, Deserialize)]
210pub struct QualityScore {
211    /// Overall quality score (0-100)
212    pub overall: f32,
213    /// Per-category scores
214    pub categories: HashMap<String, f32>,
215    /// Grade (A-F)
216    pub grade: QualityGrade,
217    /// Trend (improving/declining/stable)
218    pub trend: Trend,
219    /// Areas needing improvement
220    pub improvement_areas: Vec<String>,
221}
222
223#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
224pub enum QualityGrade {
225    A, // 90-100
226    B, // 80-89
227    C, // 70-79
228    D, // 60-69
229    F, // < 60
230}
231
232impl QualityGrade {
233    pub fn from_score(score: f32) -> Self {
234        match (score * 100.0) as u32 {
235            90..=100 => QualityGrade::A,
236            80..=89 => QualityGrade::B,
237            70..=79 => QualityGrade::C,
238            60..=69 => QualityGrade::D,
239            _ => QualityGrade::F,
240        }
241    }
242
243    pub fn label(&self) -> &'static str {
244        match self {
245            QualityGrade::A => "Excellent",
246            QualityGrade::B => "Good",
247            QualityGrade::C => "Acceptable",
248            QualityGrade::D => "Needs Improvement",
249            QualityGrade::F => "Failing",
250        }
251    }
252}
253
254#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
255pub enum Trend {
256    Improving,
257    Stable,
258    Declining,
259    Unknown,
260}
261
262/// Quality dashboard for tracking and reporting
263pub struct QualityDashboard {
264    pub targets: QualityTargets,
265    records: Vec<MetricRecord>,
266    /// Category weights for overall score
267    weights: HashMap<String, f32>,
268}
269
270impl QualityDashboard {
271    pub fn new() -> Self {
272        let mut weights = HashMap::new();
273        weights.insert("accuracy".into(), 0.25);
274        weights.insert("calibration".into(), 0.15);
275        weights.insert("reasoning".into(), 0.15);
276        weights.insert("verification".into(), 0.15);
277        weights.insert("argumentation".into(), 0.10);
278        weights.insert("exploration".into(), 0.10);
279        weights.insert("performance".into(), 0.05);
280        weights.insert("efficiency".into(), 0.05);
281
282        Self {
283            targets: QualityTargets::default(),
284            records: Vec::new(),
285            weights,
286        }
287    }
288
289    pub fn with_targets(mut self, targets: QualityTargets) -> Self {
290        self.targets = targets;
291        self
292    }
293
294    /// Record a metric
295    pub fn record_metric(&mut self, metric: QualityMetric) {
296        self.records.push(MetricRecord::new(metric));
297    }
298
299    /// Record a metric with profile
300    pub fn record_with_profile(&mut self, metric: QualityMetric, profile: &str) {
301        self.records
302            .push(MetricRecord::new(metric).with_profile(profile));
303    }
304
305    /// Get records by category
306    pub fn get_by_category(&self, category: &str) -> Vec<&MetricRecord> {
307        self.records
308            .iter()
309            .filter(|r| r.metric.category() == category)
310            .collect()
311    }
312
313    /// Get latest record for each category
314    pub fn get_latest_by_category(&self) -> HashMap<String, &MetricRecord> {
315        let mut latest: HashMap<String, &MetricRecord> = HashMap::new();
316
317        for record in &self.records {
318            let cat = record.metric.category().to_string();
319            match latest.get(&cat) {
320                None => {
321                    latest.insert(cat, record);
322                }
323                Some(existing) if record.timestamp > existing.timestamp => {
324                    latest.insert(cat, record);
325                }
326                _ => {}
327            }
328        }
329
330        latest
331    }
332
333    /// Compute category score
334    fn compute_category_score(&self, category: &str) -> Option<f32> {
335        let records: Vec<_> = self.get_by_category(category);
336        if records.is_empty() {
337            return None;
338        }
339
340        // Use latest N records for averaging
341        let recent: Vec<_> = records.into_iter().rev().take(10).collect();
342        let avg =
343            recent.iter().map(|r| r.metric.primary_value()).sum::<f32>() / recent.len() as f32;
344
345        Some(avg)
346    }
347
348    /// Compute overall quality score
349    pub fn compute_score(&self) -> QualityScore {
350        let mut categories = HashMap::new();
351        let mut weighted_sum = 0.0f32;
352        let mut weight_sum = 0.0f32;
353
354        for (cat, weight) in &self.weights {
355            if let Some(score) = self.compute_category_score(cat) {
356                categories.insert(cat.clone(), score);
357                weighted_sum += score * weight;
358                weight_sum += weight;
359            }
360        }
361
362        let overall = if weight_sum > 0.0 {
363            weighted_sum / weight_sum
364        } else {
365            0.0
366        };
367
368        let grade = QualityGrade::from_score(overall);
369
370        // Find improvement areas
371        let mut improvement_areas = Vec::new();
372        for (cat, score) in &categories {
373            if *score < 0.7 {
374                improvement_areas.push(format!("{} ({:.0}%)", cat, score * 100.0));
375            }
376        }
377
378        // Compute trend (compare last 10 vs previous 10)
379        let trend = self.compute_trend();
380
381        QualityScore {
382            overall,
383            categories,
384            grade,
385            trend,
386            improvement_areas,
387        }
388    }
389
390    fn compute_trend(&self) -> Trend {
391        if self.records.len() < 20 {
392            return Trend::Unknown;
393        }
394
395        let mid = self.records.len() / 2;
396        let first_half_avg = self.records[..mid]
397            .iter()
398            .map(|r| r.metric.primary_value())
399            .sum::<f32>()
400            / mid as f32;
401
402        let second_half_avg = self.records[mid..]
403            .iter()
404            .map(|r| r.metric.primary_value())
405            .sum::<f32>()
406            / (self.records.len() - mid) as f32;
407
408        let diff = second_half_avg - first_half_avg;
409
410        if diff > 0.05 {
411            Trend::Improving
412        } else if diff < -0.05 {
413            Trend::Declining
414        } else {
415            Trend::Stable
416        }
417    }
418
419    /// Check metrics against targets
420    pub fn check_targets(&self) -> Vec<TargetViolation> {
421        let mut violations = Vec::new();
422
423        for record in self.get_latest_by_category().values() {
424            match &record.metric {
425                QualityMetric::Accuracy {
426                    benchmark, score, ..
427                } => {
428                    if let Some(&target) = self.targets.accuracy.get(benchmark) {
429                        if *score < target {
430                            violations.push(TargetViolation {
431                                metric: format!("{} accuracy", benchmark),
432                                target,
433                                actual: *score,
434                                gap: target - score,
435                            });
436                        }
437                    }
438                }
439                QualityMetric::Calibration {
440                    brier_score, ece, ..
441                } => {
442                    if *brier_score > self.targets.max_brier_score {
443                        violations.push(TargetViolation {
444                            metric: "Brier score".into(),
445                            target: self.targets.max_brier_score,
446                            actual: *brier_score,
447                            gap: *brier_score - self.targets.max_brier_score,
448                        });
449                    }
450                    if *ece > self.targets.max_ece {
451                        violations.push(TargetViolation {
452                            metric: "ECE".into(),
453                            target: self.targets.max_ece,
454                            actual: *ece,
455                            gap: *ece - self.targets.max_ece,
456                        });
457                    }
458                }
459                QualityMetric::PrmScore {
460                    avg_step_correctness,
461                    ..
462                } => {
463                    if *avg_step_correctness < self.targets.min_prm_correctness {
464                        violations.push(TargetViolation {
465                            metric: "PRM step correctness".into(),
466                            target: self.targets.min_prm_correctness,
467                            actual: *avg_step_correctness,
468                            gap: self.targets.min_prm_correctness - avg_step_correctness,
469                        });
470                    }
471                }
472                QualityMetric::Latency { avg_ms, .. } => {
473                    if *avg_ms > self.targets.max_latency_ms {
474                        violations.push(TargetViolation {
475                            metric: "Latency".into(),
476                            target: self.targets.max_latency_ms as f32,
477                            actual: *avg_ms as f32,
478                            gap: (*avg_ms - self.targets.max_latency_ms) as f32,
479                        });
480                    }
481                }
482                _ => {}
483            }
484        }
485
486        violations
487    }
488
489    /// Generate quality report
490    pub fn generate_report(&self) -> QualityReport {
491        let score = self.compute_score();
492        let violations = self.check_targets();
493
494        let recommendations = self.generate_recommendations(&score, &violations);
495
496        QualityReport {
497            score,
498            violations,
499            total_records: self.records.len(),
500            recommendations,
501            timestamp: SystemTime::now()
502                .duration_since(UNIX_EPOCH)
503                .map(|d| d.as_secs())
504                .unwrap_or(0),
505        }
506    }
507
508    fn generate_recommendations(
509        &self,
510        score: &QualityScore,
511        violations: &[TargetViolation],
512    ) -> Vec<String> {
513        let mut recs = Vec::new();
514
515        // Based on grade
516        match score.grade {
517            QualityGrade::F | QualityGrade::D => {
518                recs.push("Use --paranoid profile for maximum verification".into());
519                recs.push("Enable PRM for step-by-step validation".into());
520            }
521            QualityGrade::C => {
522                recs.push("Consider using --deep profile for thorough analysis".into());
523            }
524            _ => {}
525        }
526
527        // Based on violations
528        for violation in violations {
529            if violation.metric.contains("accuracy") {
530                recs.push(format!(
531                    "Improve {} - currently {:.1}% below target",
532                    violation.metric,
533                    violation.gap * 100.0
534                ));
535            }
536            if violation.metric.contains("Brier") || violation.metric.contains("ECE") {
537                recs.push("Recalibrate confidence levels - currently overconfident".into());
538            }
539            if violation.metric.contains("Latency") {
540                recs.push("Consider using lighter models or caching".into());
541            }
542        }
543
544        // Based on trend
545        if score.trend == Trend::Declining {
546            recs.push("Quality is declining - review recent changes".into());
547        }
548
549        recs
550    }
551
552    /// Clear all records
553    pub fn clear(&mut self) {
554        self.records.clear();
555    }
556
557    /// Export records as JSON
558    pub fn export_json(&self) -> String {
559        serde_json::to_string_pretty(&self.records).unwrap_or_default()
560    }
561}
562
563impl Default for QualityDashboard {
564    fn default() -> Self {
565        Self::new()
566    }
567}
568
569#[derive(Debug, Clone, Serialize, Deserialize)]
570pub struct TargetViolation {
571    pub metric: String,
572    pub target: f32,
573    pub actual: f32,
574    pub gap: f32,
575}
576
577#[derive(Debug, Clone, Serialize, Deserialize)]
578pub struct QualityReport {
579    pub score: QualityScore,
580    pub violations: Vec<TargetViolation>,
581    pub total_records: usize,
582    pub recommendations: Vec<String>,
583    pub timestamp: u64,
584}
585
586impl QualityReport {
587    pub fn format(&self) -> String {
588        let mut output = String::new();
589
590        output
591            .push_str("┌─────────────────────────────────────────────────────────────────────┐\n");
592        output
593            .push_str("│                    QUALITY METRICS REPORT                           │\n");
594        output
595            .push_str("├─────────────────────────────────────────────────────────────────────┤\n");
596
597        // Overall score
598        let grade_icon = match self.score.grade {
599            QualityGrade::A => "⭐",
600            QualityGrade::B => "✓",
601            QualityGrade::C => "○",
602            QualityGrade::D => "⚠",
603            QualityGrade::F => "✗",
604        };
605
606        output.push_str(&format!(
607            "│ OVERALL SCORE: {:.0}/100 {} {:?} ({})            \n",
608            self.score.overall * 100.0,
609            grade_icon,
610            self.score.grade,
611            self.score.grade.label()
612        ));
613
614        let trend_icon = match self.score.trend {
615            Trend::Improving => "📈",
616            Trend::Stable => "➡️",
617            Trend::Declining => "📉",
618            Trend::Unknown => "❓",
619        };
620        output.push_str(&format!(
621            "│ TREND: {:?} {}                                              \n",
622            self.score.trend, trend_icon
623        ));
624
625        // Category breakdown
626        output
627            .push_str("├─────────────────────────────────────────────────────────────────────┤\n");
628        output
629            .push_str("│ CATEGORY SCORES:                                                    │\n");
630
631        let mut cats: Vec<_> = self.score.categories.iter().collect();
632        cats.sort_by(|a, b| b.1.partial_cmp(a.1).unwrap());
633
634        for (cat, score) in cats {
635            let bar_len = (*score * 30.0) as usize;
636            let bar = "█".repeat(bar_len);
637            let gap = " ".repeat(30 - bar_len);
638            let icon = if *score >= 0.8 {
639                "✓"
640            } else if *score >= 0.6 {
641                "○"
642            } else {
643                "✗"
644            };
645            output.push_str(&format!(
646                "│   {:<15} {} |{}{}| {:.0}%\n",
647                cat,
648                icon,
649                bar,
650                gap,
651                score * 100.0
652            ));
653        }
654
655        // Violations
656        if !self.violations.is_empty() {
657            output.push_str(
658                "├─────────────────────────────────────────────────────────────────────┤\n",
659            );
660            output.push_str(
661                "│ TARGET VIOLATIONS:                                                  │\n",
662            );
663            for v in &self.violations {
664                output.push_str(&format!(
665                    "│   ⚠ {}: {:.1} (target: {:.1}, gap: {:.1})\n",
666                    v.metric, v.actual, v.target, v.gap
667                ));
668            }
669        }
670
671        // Improvement areas
672        if !self.score.improvement_areas.is_empty() {
673            output.push_str(
674                "├─────────────────────────────────────────────────────────────────────┤\n",
675            );
676            output.push_str(
677                "│ NEEDS IMPROVEMENT:                                                  │\n",
678            );
679            for area in &self.score.improvement_areas {
680                output.push_str(&format!("│   • {}\n", area));
681            }
682        }
683
684        // Recommendations
685        if !self.recommendations.is_empty() {
686            output.push_str(
687                "├─────────────────────────────────────────────────────────────────────┤\n",
688            );
689            output.push_str(
690                "│ RECOMMENDATIONS:                                                    │\n",
691            );
692            for rec in &self.recommendations {
693                output.push_str(&format!("│   → {}\n", rec));
694            }
695        }
696
697        output
698            .push_str("├─────────────────────────────────────────────────────────────────────┤\n");
699        output.push_str(&format!(
700            "│ Total metrics recorded: {}                                          \n",
701            self.total_records
702        ));
703        output
704            .push_str("└─────────────────────────────────────────────────────────────────────┘\n");
705
706        output
707    }
708}
709
#[cfg(test)]
mod tests {
    use super::*;

    /// GSM8K accuracy metric over 100 samples with the given score.
    fn gsm8k(score: f32) -> QualityMetric {
        QualityMetric::Accuracy {
            benchmark: "GSM8K".into(),
            score,
            samples: 100,
        }
    }

    #[test]
    fn test_quality_dashboard() {
        let mut dashboard = QualityDashboard::new();
        dashboard.record_metric(gsm8k(0.85));
        dashboard.record_metric(QualityMetric::Calibration {
            brier_score: 0.15,
            ece: 0.08,
            overconfidence_ratio: 0.2,
        });

        let overall = dashboard.compute_score().overall;
        assert!(overall > 0.0);
    }

    #[test]
    fn test_grade_from_score() {
        let expectations: [(f32, QualityGrade); 5] = [
            (0.95, QualityGrade::A),
            (0.85, QualityGrade::B),
            (0.75, QualityGrade::C),
            (0.65, QualityGrade::D),
            (0.50, QualityGrade::F),
        ];
        for &(score, grade) in expectations.iter() {
            assert_eq!(QualityGrade::from_score(score), grade);
        }
    }

    #[test]
    fn test_target_violations() {
        let mut dashboard = QualityDashboard::new();
        // Below the default GSM8K target of 0.859.
        dashboard.record_metric(gsm8k(0.70));

        let violations = dashboard.check_targets();
        let first = violations
            .first()
            .expect("a below-target GSM8K score must produce a violation");
        assert!(first.metric.contains("GSM8K"));
    }

    #[test]
    fn test_metric_categories() {
        let accuracy = QualityMetric::Accuracy {
            benchmark: "test".into(),
            score: 0.9,
            samples: 10,
        };
        let prm = QualityMetric::PrmScore {
            avg_step_correctness: 0.8,
            critical_issues: 0,
            sound_chains: 0.9,
        };

        assert_eq!(accuracy.category(), "accuracy");
        assert_eq!(prm.category(), "reasoning");
    }

    #[test]
    fn test_report_generation() {
        let metrics = vec![
            gsm8k(0.88),
            QualityMetric::PrmScore {
                avg_step_correctness: 0.85,
                critical_issues: 2,
                sound_chains: 0.90,
            },
            QualityMetric::Triangulation {
                verification_rate: 0.75,
                avg_sources: 3.2,
                contradiction_rate: 0.05,
            },
        ];

        let mut dashboard = QualityDashboard::new();
        for metric in metrics {
            dashboard.record_metric(metric);
        }

        let report = dashboard.generate_report();
        assert!(report.score.overall > 0.0);
        assert!(!report.score.categories.is_empty());
    }
}