Skip to main content

entrenar/eval/evaluator/
metric.rs

1//! Evaluation metric definitions
2
3use super::super::classification::Average;
4use std::fmt;
5
/// Flavour of ROUGE score used when evaluating generated text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RougeVariant {
    /// ROUGE-1: overlap of single tokens (unigrams).
    Rouge1,
    /// ROUGE-2: overlap of adjacent token pairs (bigrams).
    Rouge2,
    /// ROUGE-L: based on the longest common subsequence.
    RougeL,
}
16
17impl fmt::Display for RougeVariant {
18    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19        match self {
20            RougeVariant::Rouge1 => write!(f, "ROUGE-1"),
21            RougeVariant::Rouge2 => write!(f, "ROUGE-2"),
22            RougeVariant::RougeL => write!(f, "ROUGE-L"),
23        }
24    }
25}
26
27/// Available evaluation metrics
28#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
29pub enum Metric {
30    // Classification
31    /// Classification accuracy
32    Accuracy,
33    /// Precision with averaging strategy
34    Precision(Average),
35    /// Recall with averaging strategy
36    Recall(Average),
37    /// F1 score with averaging strategy
38    F1(Average),
39    // Regression
40    /// R² coefficient of determination
41    R2,
42    /// Mean Squared Error
43    MSE,
44    /// Mean Absolute Error
45    MAE,
46    /// Root Mean Squared Error
47    RMSE,
48    // Clustering
49    /// Silhouette score
50    Silhouette,
51    /// Inertia
52    Inertia,
53    // ASR (Automatic Speech Recognition)
54    /// Word Error Rate (lower is better)
55    WER,
56    /// Inverse Real-Time Factor (higher is better: RTFx=100 means 100x real-time)
57    RTFx,
58    // Text Generation
59    /// BLEU score (higher is better)
60    BLEU,
61    /// ROUGE score with variant (higher is better)
62    ROUGE(RougeVariant),
63    /// Perplexity (lower is better)
64    Perplexity,
65    // LLM Benchmarks
66    /// MMLU accuracy (higher is better, covers MMLU-PRO, BBH, etc.)
67    MMLUAccuracy,
68    // Code
69    /// pass@k — unbiased estimator, parameterized by k (higher is better)
70    PassAtK(usize),
71    // Retrieval
72    /// NDCG@k — normalized discounted cumulative gain (higher is better)
73    NDCGAtK(usize),
74}
75
76impl Metric {
77    /// Whether higher values are better for this metric
78    pub fn higher_is_better(&self) -> bool {
79        !matches!(
80            self,
81            Metric::MSE
82                | Metric::MAE
83                | Metric::RMSE
84                | Metric::Inertia
85                | Metric::WER
86                | Metric::Perplexity
87        )
88    }
89
90    /// Get metric name as string
91    pub fn name(&self) -> &'static str {
92        match self {
93            Metric::Accuracy => "Accuracy",
94            Metric::Precision(_) => "Precision",
95            Metric::Recall(_) => "Recall",
96            Metric::F1(_) => "F1",
97            Metric::R2 => "R²",
98            Metric::MSE => "MSE",
99            Metric::MAE => "MAE",
100            Metric::RMSE => "RMSE",
101            Metric::Silhouette => "Silhouette",
102            Metric::Inertia => "Inertia",
103            Metric::WER => "WER",
104            Metric::RTFx => "RTFx",
105            Metric::BLEU => "BLEU",
106            Metric::ROUGE(_) => "ROUGE",
107            Metric::Perplexity => "Perplexity",
108            Metric::MMLUAccuracy => "MMLU",
109            Metric::PassAtK(_) => "pass@k",
110            Metric::NDCGAtK(_) => "NDCG@k",
111        }
112    }
113}
114
115impl fmt::Display for Metric {
116    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117        match self {
118            Metric::Accuracy
119            | Metric::R2
120            | Metric::MSE
121            | Metric::MAE
122            | Metric::RMSE
123            | Metric::Silhouette
124            | Metric::Inertia
125            | Metric::WER
126            | Metric::RTFx
127            | Metric::BLEU
128            | Metric::Perplexity
129            | Metric::MMLUAccuracy => write!(f, "{}", self.name()),
130            Metric::Precision(avg) => write!(f, "Precision({avg:?})"),
131            Metric::Recall(avg) => write!(f, "Recall({avg:?})"),
132            Metric::F1(avg) => write!(f, "F1({avg:?})"),
133            Metric::ROUGE(variant) => write!(f, "{variant}"),
134            Metric::PassAtK(k) => write!(f, "pass@{k}"),
135            Metric::NDCGAtK(k) => write!(f, "NDCG@{k}"),
136        }
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Payload-free variants and their expected label; for these, `Display`
    /// and `name()` must produce the same text.
    const SIMPLE_VARIANTS: [(Metric, &str); 12] = [
        (Metric::Accuracy, "Accuracy"),
        (Metric::R2, "R²"),
        (Metric::MSE, "MSE"),
        (Metric::MAE, "MAE"),
        (Metric::RMSE, "RMSE"),
        (Metric::Silhouette, "Silhouette"),
        (Metric::Inertia, "Inertia"),
        (Metric::WER, "WER"),
        (Metric::RTFx, "RTFx"),
        (Metric::BLEU, "BLEU"),
        (Metric::Perplexity, "Perplexity"),
        (Metric::MMLUAccuracy, "MMLU"),
    ];

    #[test]
    fn test_name_simple_variants() {
        for (metric, expected) in SIMPLE_VARIANTS {
            assert_eq!(metric.name(), expected, "name() of {metric:?}");
        }
    }

    #[test]
    fn test_display_simple_variants() {
        for (metric, expected) in SIMPLE_VARIANTS {
            assert_eq!(metric.to_string(), expected, "Display of {metric:?}");
            assert_eq!(metric.to_string(), metric.name(), "Display vs name() of {metric:?}");
        }
    }

    #[test]
    fn test_name_ignores_payload() {
        let cases = [
            (Metric::Precision(Average::Macro), "Precision"),
            (Metric::Recall(Average::Micro), "Recall"),
            (Metric::F1(Average::Weighted), "F1"),
            (Metric::ROUGE(RougeVariant::RougeL), "ROUGE"),
            (Metric::PassAtK(1), "pass@k"),
            (Metric::PassAtK(100), "pass@k"),
            (Metric::NDCGAtK(5), "NDCG@k"),
        ];
        for (metric, expected) in cases {
            assert_eq!(metric.name(), expected, "name() of {metric:?}");
        }
    }

    #[test]
    fn test_display_includes_payload() {
        let cases = [
            (Metric::Precision(Average::Macro), "Precision(Macro)"),
            (Metric::Recall(Average::Micro), "Recall(Micro)"),
            (Metric::F1(Average::Weighted), "F1(Weighted)"),
            (Metric::ROUGE(RougeVariant::Rouge1), "ROUGE-1"),
            (Metric::ROUGE(RougeVariant::Rouge2), "ROUGE-2"),
            (Metric::ROUGE(RougeVariant::RougeL), "ROUGE-L"),
            (Metric::PassAtK(5), "pass@5"),
            (Metric::NDCGAtK(10), "NDCG@10"),
        ];
        for (metric, expected) in cases {
            assert_eq!(metric.to_string(), expected, "Display of {metric:?}");
        }
    }

    #[test]
    fn test_higher_is_better_all_variants() {
        let lower_is_better = [
            Metric::MSE,
            Metric::MAE,
            Metric::RMSE,
            Metric::Inertia,
            Metric::WER,
            Metric::Perplexity,
        ];
        let higher_is_better = [
            Metric::Accuracy,
            Metric::Precision(Average::Macro),
            Metric::Recall(Average::Micro),
            Metric::F1(Average::Weighted),
            Metric::R2,
            Metric::Silhouette,
            Metric::RTFx,
            Metric::BLEU,
            Metric::ROUGE(RougeVariant::Rouge1),
            Metric::MMLUAccuracy,
            Metric::PassAtK(1),
            Metric::NDCGAtK(5),
        ];
        for metric in lower_is_better {
            assert!(!metric.higher_is_better(), "{metric:?} should be lower-is-better");
        }
        for metric in higher_is_better {
            assert!(metric.higher_is_better(), "{metric:?} should be higher-is-better");
        }
    }

    #[test]
    fn test_rouge_variant_display() {
        assert_eq!(RougeVariant::Rouge1.to_string(), "ROUGE-1");
        assert_eq!(RougeVariant::Rouge2.to_string(), "ROUGE-2");
        assert_eq!(RougeVariant::RougeL.to_string(), "ROUGE-L");
    }
}