1use super::super::classification::Average;
4use std::fmt;
5
/// Selects which ROUGE score a [`Metric::ROUGE`] value refers to.
///
/// The `Display` implementation renders each variant as its label:
/// `ROUGE-1`, `ROUGE-2` or `ROUGE-L`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RougeVariant {
    /// Displayed as `ROUGE-1`.
    Rouge1,
    /// Displayed as `ROUGE-2`.
    Rouge2,
    /// Displayed as `ROUGE-L`.
    RougeL,
}
16
17impl fmt::Display for RougeVariant {
18 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
19 match self {
20 RougeVariant::Rouge1 => write!(f, "ROUGE-1"),
21 RougeVariant::Rouge2 => write!(f, "ROUGE-2"),
22 RougeVariant::RougeL => write!(f, "ROUGE-L"),
23 }
24 }
25}
26
27#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
29pub enum Metric {
30 Accuracy,
33 Precision(Average),
35 Recall(Average),
37 F1(Average),
39 R2,
42 MSE,
44 MAE,
46 RMSE,
48 Silhouette,
51 Inertia,
53 WER,
56 RTFx,
58 BLEU,
61 ROUGE(RougeVariant),
63 Perplexity,
65 MMLUAccuracy,
68 PassAtK(usize),
71 NDCGAtK(usize),
74}
75
76impl Metric {
77 pub fn higher_is_better(&self) -> bool {
79 !matches!(
80 self,
81 Metric::MSE
82 | Metric::MAE
83 | Metric::RMSE
84 | Metric::Inertia
85 | Metric::WER
86 | Metric::Perplexity
87 )
88 }
89
90 pub fn name(&self) -> &'static str {
92 match self {
93 Metric::Accuracy => "Accuracy",
94 Metric::Precision(_) => "Precision",
95 Metric::Recall(_) => "Recall",
96 Metric::F1(_) => "F1",
97 Metric::R2 => "R²",
98 Metric::MSE => "MSE",
99 Metric::MAE => "MAE",
100 Metric::RMSE => "RMSE",
101 Metric::Silhouette => "Silhouette",
102 Metric::Inertia => "Inertia",
103 Metric::WER => "WER",
104 Metric::RTFx => "RTFx",
105 Metric::BLEU => "BLEU",
106 Metric::ROUGE(_) => "ROUGE",
107 Metric::Perplexity => "Perplexity",
108 Metric::MMLUAccuracy => "MMLU",
109 Metric::PassAtK(_) => "pass@k",
110 Metric::NDCGAtK(_) => "NDCG@k",
111 }
112 }
113}
114
115impl fmt::Display for Metric {
116 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
117 match self {
118 Metric::Accuracy
119 | Metric::R2
120 | Metric::MSE
121 | Metric::MAE
122 | Metric::RMSE
123 | Metric::Silhouette
124 | Metric::Inertia
125 | Metric::WER
126 | Metric::RTFx
127 | Metric::BLEU
128 | Metric::Perplexity
129 | Metric::MMLUAccuracy => write!(f, "{}", self.name()),
130 Metric::Precision(avg) => write!(f, "Precision({avg:?})"),
131 Metric::Recall(avg) => write!(f, "Recall({avg:?})"),
132 Metric::F1(avg) => write!(f, "F1({avg:?})"),
133 Metric::ROUGE(variant) => write!(f, "{variant}"),
134 Metric::PassAtK(k) => write!(f, "pass@{k}"),
135 Metric::NDCGAtK(k) => write!(f, "NDCG@{k}"),
136 }
137 }
138}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Unit variants paired with the label that both `name()` and `Display`
    /// are expected to produce for them.
    const SIMPLE_LABELS: [(Metric, &str); 12] = [
        (Metric::Accuracy, "Accuracy"),
        (Metric::R2, "R²"),
        (Metric::MSE, "MSE"),
        (Metric::MAE, "MAE"),
        (Metric::RMSE, "RMSE"),
        (Metric::Silhouette, "Silhouette"),
        (Metric::Inertia, "Inertia"),
        (Metric::WER, "WER"),
        (Metric::RTFx, "RTFx"),
        (Metric::BLEU, "BLEU"),
        (Metric::Perplexity, "Perplexity"),
        (Metric::MMLUAccuracy, "MMLU"),
    ];

    // --- `name()` on payload-carrying variants ---------------------------

    #[test]
    fn test_name_precision_arm() {
        let metric = Metric::Precision(Average::Macro);
        assert!(matches!(metric, Metric::Precision(_)));
        assert_eq!(metric.name(), "Precision");
    }

    #[test]
    fn test_name_recall_arm() {
        let metric = Metric::Recall(Average::Micro);
        assert!(matches!(metric, Metric::Recall(_)));
        assert_eq!(metric.name(), "Recall");
    }

    #[test]
    fn test_name_f1_arm() {
        let metric = Metric::F1(Average::Weighted);
        assert!(matches!(metric, Metric::F1(_)));
        assert_eq!(metric.name(), "F1");
    }

    #[test]
    fn test_name_rouge_arm() {
        let metric = Metric::ROUGE(RougeVariant::RougeL);
        assert!(matches!(metric, Metric::ROUGE(_)));
        assert_eq!(metric.name(), "ROUGE");
    }

    #[test]
    fn test_name_pass_at_k_arm() {
        let metric = Metric::PassAtK(1);
        assert!(matches!(metric, Metric::PassAtK(_)));
        assert_eq!(metric.name(), "pass@k");
    }

    #[test]
    fn test_name_ndcg_at_k_arm() {
        let metric = Metric::NDCGAtK(5);
        assert!(matches!(metric, Metric::NDCGAtK(_)));
        assert_eq!(metric.name(), "NDCG@k");
    }

    // --- `Display` on payload-carrying variants --------------------------

    #[test]
    fn test_display_precision_avg_arm() {
        let metric = Metric::Precision(Average::Macro);
        assert!(matches!(metric, Metric::Precision(_)));
        assert_eq!(metric.to_string(), "Precision(Macro)");
    }

    #[test]
    fn test_display_recall_avg_arm() {
        let metric = Metric::Recall(Average::Micro);
        assert!(matches!(metric, Metric::Recall(_)));
        assert_eq!(metric.to_string(), "Recall(Micro)");
    }

    #[test]
    fn test_display_f1_avg_arm() {
        let metric = Metric::F1(Average::Weighted);
        assert!(matches!(metric, Metric::F1(_)));
        assert_eq!(metric.to_string(), "F1(Weighted)");
    }

    #[test]
    fn test_display_rouge_variant_arm() {
        let metric = Metric::ROUGE(RougeVariant::Rouge1);
        assert!(matches!(metric, Metric::ROUGE(_)));
        assert_eq!(metric.to_string(), "ROUGE-1");
    }

    #[test]
    fn test_display_pass_at_k_arm() {
        let metric = Metric::PassAtK(5);
        assert!(matches!(metric, Metric::PassAtK(_)));
        assert_eq!(metric.to_string(), "pass@5");
    }

    #[test]
    fn test_display_ndcg_at_k_arm() {
        let metric = Metric::NDCGAtK(10);
        assert!(matches!(metric, Metric::NDCGAtK(_)));
        assert_eq!(metric.to_string(), "NDCG@10");
    }

    // --- Unit variants ----------------------------------------------------

    #[test]
    fn test_display_simple_variants() {
        for (metric, label) in SIMPLE_LABELS {
            assert_eq!(metric.to_string(), label, "{metric:?}");
        }
    }

    #[test]
    fn test_name_simple_variants() {
        for (metric, label) in SIMPLE_LABELS {
            assert_eq!(metric.name(), label, "{metric:?}");
        }
    }

    // --- Direction ----------------------------------------------------------

    #[test]
    fn test_higher_is_better_all_variants() {
        let maximised = [
            Metric::Accuracy,
            Metric::Precision(Average::Macro),
            Metric::Recall(Average::Micro),
            Metric::F1(Average::Weighted),
            Metric::R2,
            Metric::Silhouette,
            Metric::RTFx,
            Metric::BLEU,
            Metric::ROUGE(RougeVariant::Rouge1),
            Metric::MMLUAccuracy,
            Metric::PassAtK(1),
            Metric::NDCGAtK(5),
        ];
        for metric in maximised {
            assert!(metric.higher_is_better(), "{metric:?}");
        }

        let minimised = [
            Metric::MSE,
            Metric::MAE,
            Metric::RMSE,
            Metric::Inertia,
            Metric::WER,
            Metric::Perplexity,
        ];
        for metric in minimised {
            assert!(!metric.higher_is_better(), "{metric:?}");
        }
    }

    // --- `RougeVariant` -----------------------------------------------------

    #[test]
    fn test_rouge_variant_display() {
        let expected = [
            (RougeVariant::Rouge1, "ROUGE-1"),
            (RougeVariant::Rouge2, "ROUGE-2"),
            (RougeVariant::RougeL, "ROUGE-L"),
        ];
        for (variant, label) in expected {
            assert_eq!(variant.to_string(), label);
        }
    }
}