1pub mod metrics;
41pub mod reasoning;
42
43pub use metrics::{
44 average_precision, mean_average_precision, mean_reciprocal_rank, ndcg_at_k, precision_at_k,
45 recall_at_k, EvaluationResult, QueryResult, RetrievalMetrics,
46};
47
48pub use reasoning::{
49 BedRockMetrics, BenchmarkResult, BrutalHonestyMetrics, CalibrationMetrics, ConsistencyMetrics,
50 GigaThinkMetrics, LaserLogicMetrics, Profile, ProofGuardMetrics, ReasoningMetrics,
51 ThinkToolMetrics,
52};
53
54#[derive(Debug, Clone)]
56pub struct ReasoningEvalConfig {
57 pub consistency_runs: usize,
59 pub profile: Profile,
61 pub benchmarks: Vec<String>,
63 pub measure_calibration: bool,
65 pub measure_thinktool_effectiveness: bool,
67}
68
69impl Default for ReasoningEvalConfig {
70 fn default() -> Self {
71 Self {
72 consistency_runs: 5,
73 profile: Profile::Balanced,
74 benchmarks: vec!["gsm8k".into(), "arc_challenge".into()],
75 measure_calibration: true,
76 measure_thinktool_effectiveness: true,
77 }
78 }
79}
80
/// Settings that control which retrieval metrics are computed.
///
/// `Default` yields cutoffs of 5, 10 and 20 with both MRR and MAP enabled.
#[derive(Clone, Debug)]
pub struct RetrievalEvalConfig {
    /// Cutoff values `k` for the `@k` metrics.
    pub k_values: Vec<usize>,
    /// Whether mean reciprocal rank should be computed.
    pub compute_mrr: bool,
    /// Whether mean average precision should be computed.
    pub compute_map: bool,
}
91
92impl Default for RetrievalEvalConfig {
93 fn default() -> Self {
94 Self {
95 k_values: vec![5, 10, 20],
96 compute_mrr: true,
97 compute_map: true,
98 }
99 }
100}
101
102pub fn evaluate_reasoning(
104 results: &[BenchmarkResult],
105 config: &ReasoningEvalConfig,
106) -> ReasoningEvalSummary {
107 let mut summary = ReasoningEvalSummary::new(&config.benchmarks);
108
109 for result in results {
110 summary.add_result(result);
111 }
112
113 summary.finalize()
114}
115
/// Aggregated outcome of a reasoning evaluation.
///
/// Rate-like values (`improvement`, `self_consistency`) are handled as
/// fractions: `check_targets` multiplies them by 100 when formatting them as
/// percentages.
#[derive(Clone, Debug)]
pub struct ReasoningEvalSummary {
    /// Number of benchmark names the summary was created with.
    pub num_benchmarks: usize,
    /// Mean accuracy per benchmark; `0.0` until results have been finalized.
    pub accuracy: std::collections::HashMap<String, f64>,
    /// Improvement per benchmark; starts at `0.0` for every benchmark.
    pub improvement: std::collections::HashMap<String, f64>,
    /// Self-consistency score, compared against `ReasoningTargets::self_consistency`.
    pub self_consistency: f64,
    /// Calibration error, compared against `ReasoningTargets::calibration_ece_max`.
    pub calibration_ece: f64,
    /// Per-ThinkTool scores; empty on construction.
    pub thinktool_scores: std::collections::HashMap<String, f64>,

    /// Running `(accuracy sum, result count)` per benchmark, consumed when
    /// the mean accuracies are computed.
    accuracy_sums: std::collections::HashMap<String, (f64, usize)>,
}
134
135impl ReasoningEvalSummary {
136 pub fn new(benchmarks: &[String]) -> Self {
137 let mut accuracy = std::collections::HashMap::new();
138 let mut improvement = std::collections::HashMap::new();
139 let mut accuracy_sums = std::collections::HashMap::new();
140
141 for b in benchmarks {
142 accuracy.insert(b.clone(), 0.0);
143 improvement.insert(b.clone(), 0.0);
144 accuracy_sums.insert(b.clone(), (0.0, 0));
145 }
146
147 Self {
148 num_benchmarks: benchmarks.len(),
149 accuracy,
150 improvement,
151 self_consistency: 0.0,
152 calibration_ece: 0.0,
153 thinktool_scores: std::collections::HashMap::new(),
154 accuracy_sums,
155 }
156 }
157
158 fn add_result(&mut self, result: &BenchmarkResult) {
159 if let Some((sum, count)) = self.accuracy_sums.get_mut(&result.benchmark) {
160 *sum += result.accuracy;
161 *count += 1;
162 }
163 }
164
165 fn finalize(mut self) -> Self {
166 for (benchmark, (sum, count)) in &self.accuracy_sums {
167 if *count > 0 {
168 self.accuracy
169 .insert(benchmark.clone(), sum / (*count as f64));
170 }
171 }
172 self
173 }
174
175 pub fn check_targets(&self, targets: &ReasoningTargets) -> TargetResult {
177 let mut passed = true;
178 let mut failures = Vec::new();
179
180 if let Some(&target) = targets.gsm8k_improvement.as_ref() {
182 if let Some(&actual) = self.improvement.get("gsm8k") {
183 if actual < target {
184 passed = false;
185 failures.push(format!(
186 "GSM8K improvement: {:.1}% < {:.1}%",
187 actual * 100.0,
188 target * 100.0
189 ));
190 }
191 }
192 }
193
194 if let Some(target) = targets.self_consistency {
196 if self.self_consistency < target {
197 passed = false;
198 failures.push(format!(
199 "Self-consistency: {:.1}% < {:.1}%",
200 self.self_consistency * 100.0,
201 target * 100.0
202 ));
203 }
204 }
205
206 if let Some(target) = targets.calibration_ece_max {
208 if self.calibration_ece > target {
209 passed = false;
210 failures.push(format!(
211 "Calibration ECE: {:.3} > {:.3}",
212 self.calibration_ece, target
213 ));
214 }
215 }
216
217 TargetResult { passed, failures }
218 }
219}
220
/// Thresholds a reasoning evaluation is expected to meet.
///
/// Each field is optional; `None` means no target is set for that metric.
/// Values are fractions (for example `0.15` for 15%).
#[derive(Clone, Debug, Default)]
pub struct ReasoningTargets {
    /// Minimum improvement on GSM8K.
    pub gsm8k_improvement: Option<f64>,
    /// Minimum improvement on ARC-C.
    pub arc_c_improvement: Option<f64>,
    /// Minimum improvement on LogiQA.
    pub logiqa_improvement: Option<f64>,
    /// Minimum self-consistency score.
    pub self_consistency: Option<f64>,
    /// Maximum allowed calibration ECE.
    pub calibration_ece_max: Option<f64>,
}
230
231impl ReasoningTargets {
232 pub fn v1_targets() -> Self {
234 Self {
235 gsm8k_improvement: Some(0.15), arc_c_improvement: Some(0.08), logiqa_improvement: None, self_consistency: Some(0.85), calibration_ece_max: Some(0.10), }
241 }
242
243 pub fn v1_5_targets() -> Self {
245 Self {
246 gsm8k_improvement: Some(0.20),
247 arc_c_improvement: Some(0.10),
248 logiqa_improvement: Some(0.20),
249 self_consistency: Some(0.90),
250 calibration_ece_max: Some(0.08),
251 }
252 }
253}
254
/// Outcome of checking a summary against a set of targets.
#[derive(Clone, Debug)]
pub struct TargetResult {
    /// `true` when no target was missed.
    pub passed: bool,
    /// One human-readable message per missed target.
    pub failures: Vec<String>,
}
261
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a summary that tracks only `gsm8k`, preloaded with the given
    /// improvement, self-consistency and calibration values.
    fn gsm8k_summary(
        improvement: f64,
        self_consistency: f64,
        calibration_ece: f64,
    ) -> ReasoningEvalSummary {
        let mut summary = ReasoningEvalSummary::new(&[String::from("gsm8k")]);
        summary
            .improvement
            .insert(String::from("gsm8k"), improvement);
        summary.self_consistency = self_consistency;
        summary.calibration_ece = calibration_ece;
        summary
    }

    /// The default reasoning config uses five consistency runs and measures
    /// calibration.
    #[test]
    fn test_reasoning_config_default() {
        let config = ReasoningEvalConfig::default();

        assert_eq!(config.consistency_runs, 5);
        assert!(config.measure_calibration);
    }

    /// The v1 targets carry the expected GSM8K and self-consistency thresholds.
    #[test]
    fn test_v1_targets() {
        let targets = ReasoningTargets::v1_targets();

        assert_eq!(targets.gsm8k_improvement, Some(0.15));
        assert_eq!(targets.self_consistency, Some(0.85));
    }

    /// A summary that beats every checked v1 threshold passes with no failures.
    #[test]
    fn test_target_check_pass() {
        let summary = gsm8k_summary(0.20, 0.90, 0.05);

        let result = summary.check_targets(&ReasoningTargets::v1_targets());

        assert!(result.passed);
        assert!(result.failures.is_empty());
    }

    /// Missing the GSM8K, self-consistency and calibration thresholds yields
    /// one failure message each.
    #[test]
    fn test_target_check_fail() {
        let summary = gsm8k_summary(0.10, 0.75, 0.15);

        let result = summary.check_targets(&ReasoningTargets::v1_targets());

        assert!(!result.passed);
        assert_eq!(result.failures.len(), 3);
    }
}