1use super::classify_pipeline::ClassifyPipeline;
6use crate::eval::classification::{ConfusionMatrix, MultiClassMetrics};
7use std::path::Path;
8
/// Evaluation report for a multi-class classifier.
///
/// The field notes below describe how `from_predictions_with_probs`
/// populates each value.
#[derive(Debug, Clone)]
pub struct ClassifyEvalReport {
    /// Overall accuracy as reported by the confusion matrix.
    pub accuracy: f64,
    /// Total loss divided by the number of samples (0.0 when there are none).
    pub avg_loss: f32,
    /// Precision per class, indexed by class id.
    pub per_class_precision: Vec<f64>,
    /// Recall per class, indexed by class id.
    pub per_class_recall: Vec<f64>,
    /// F1 score per class, indexed by class id.
    pub per_class_f1: Vec<f64>,
    /// Support (sample count) per class, indexed by class id.
    pub per_class_support: Vec<usize>,
    /// Copy of the confusion matrix cells.
    pub confusion_matrix: Vec<Vec<usize>>,
    /// Number of classes supplied by the caller.
    pub num_classes: usize,
    /// Number of evaluated samples (length of the prediction slice).
    pub total_samples: usize,
    /// Evaluation time in milliseconds, as supplied by the caller.
    pub eval_time_ms: u64,
    /// Human-readable label per class id.
    pub label_names: Vec<String>,
    /// Cohen's kappa computed from the confusion matrix.
    pub cohens_kappa: f64,
    /// Multi-class Matthews correlation coefficient.
    pub mcc: f64,
    /// Fraction of samples whose true label is among the two most probable classes.
    pub top2_accuracy: f64,
    /// Mean of the per-sample maximum probability.
    pub mean_confidence: f64,
    /// Mean confidence over correctly predicted samples.
    pub mean_confidence_correct: f64,
    /// Mean confidence over incorrectly predicted samples.
    pub mean_confidence_wrong: f64,
    /// Samples per second derived from `eval_time_ms` (0.0 when the time is 0).
    pub samples_per_sec: f64,
    /// Ten confidence bins as `(mean confidence, accuracy, count)`.
    pub calibration_bins: Vec<(f64, f64, usize)>,
    /// Expected calibration error over `calibration_bins`.
    pub ece: f64,
    /// Multi-class Brier score.
    pub brier_score: f64,
    /// Mean negative log-probability of the true class.
    pub log_loss: f64,
    /// Bootstrap percentile interval `(lower, upper)` for accuracy.
    pub ci_accuracy: (f64, f64),
    /// Bootstrap percentile interval `(lower, upper)` for macro F1.
    pub ci_macro_f1: (f64, f64),
    /// Bootstrap percentile interval `(lower, upper)` for MCC.
    pub ci_mcc: (f64, f64),
    /// Accuracy of uniform random guessing (`1 / num_classes`).
    pub baseline_random: f64,
    /// Accuracy of always predicting the class with the largest support.
    pub baseline_majority: f64,
    /// Largest off-diagonal confusion cells as `(row, column, count)`;
    /// the report renderers print them as "true → predicted".
    pub top_confusions: Vec<(usize, usize, usize)>,
}
70
71impl ClassifyEvalReport {
72 pub(crate) fn from_predictions_with_probs(
77 y_pred: &[usize],
78 y_true: &[usize],
79 all_probs: &[Vec<f32>],
80 total_loss: f32,
81 num_classes: usize,
82 label_names: &[String],
83 eval_time_ms: u64,
84 ) -> Self {
85 let total_samples = y_pred.len();
86 let avg_loss = if total_samples > 0 { total_loss / total_samples as f32 } else { 0.0 };
87
88 let cm = ConfusionMatrix::from_predictions_with_min_classes(y_pred, y_true, num_classes);
89 let metrics = MultiClassMetrics::from_confusion_matrix(&cm);
90 let accuracy = cm.accuracy();
91
92 let cohens_kappa = Self::compute_cohens_kappa(&cm, total_samples);
93 let mcc = Self::compute_mcc(&cm, cm.n_classes(), total_samples);
94 let top2_accuracy = Self::compute_top2_accuracy(all_probs, y_true, total_samples);
95
96 let confidences: Vec<f64> =
97 all_probs.iter().map(|p| f64::from(p.iter().copied().fold(0.0f32, f32::max))).collect();
98
99 let (mean_confidence, mean_confidence_correct, mean_confidence_wrong) =
100 Self::compute_confidence_stats(&confidences, y_pred, y_true);
101
102 let (calibration_bins, ece) =
103 Self::compute_calibration(&confidences, y_pred, y_true, total_samples);
104
105 let samples_per_sec = if eval_time_ms > 0 {
106 total_samples as f64 / (eval_time_ms as f64 / 1000.0)
107 } else {
108 0.0
109 };
110
111 let brier_score = Self::compute_brier_score(all_probs, y_true, num_classes);
112 let log_loss = Self::compute_log_loss(all_probs, y_true);
113
114 let (ci_accuracy, ci_macro_f1, ci_mcc) =
115 Self::compute_bootstrap_cis(y_pred, y_true, num_classes, 1000);
116
117 let (baseline_random, baseline_majority) =
118 Self::compute_baselines(&metrics.support, total_samples, num_classes);
119
120 let top_confusions = Self::compute_top_confusions(cm.matrix(), 5);
121
122 Self {
123 accuracy,
124 avg_loss,
125 per_class_precision: metrics.precision,
126 per_class_recall: metrics.recall,
127 per_class_f1: metrics.f1,
128 per_class_support: metrics.support,
129 confusion_matrix: cm.matrix().clone(),
130 num_classes,
131 total_samples,
132 eval_time_ms,
133 label_names: label_names.to_vec(),
134 cohens_kappa,
135 mcc,
136 top2_accuracy,
137 mean_confidence,
138 mean_confidence_correct,
139 mean_confidence_wrong,
140 samples_per_sec,
141 calibration_bins,
142 ece,
143 brier_score,
144 log_loss,
145 ci_accuracy,
146 ci_macro_f1,
147 ci_mcc,
148 baseline_random,
149 baseline_majority,
150 top_confusions,
151 }
152 }
153
/// Fraction of samples whose true label is one of the two most probable classes.
///
/// * `all_probs` - one probability vector per sample.
/// * `y_true` - reference class id per sample (paired with `all_probs` by position).
/// * `total` - denominator of the returned fraction; `0` yields `0.0`.
///
/// A sample counts as a hit when fewer than two classes rank ahead of the
/// true class. Ties are broken in favour of the lower class index. Samples
/// with fewer than two probabilities, an out-of-range true label, or a NaN
/// probability for the true label count as misses.
///
/// The rank is computed with a single pass per sample, so no per-sample
/// allocation or sort is needed and NaN entries cannot disturb the ordering.
pub(crate) fn compute_top2_accuracy(
    all_probs: &[Vec<f32>],
    y_true: &[usize],
    total: usize,
) -> f64 {
    if total == 0 {
        return 0.0;
    }
    let hits = all_probs
        .iter()
        .zip(y_true)
        .filter(|&(probs, &true_label)| {
            if probs.len() < 2 {
                return false;
            }
            let p_true = match probs.get(true_label) {
                Some(&p) if !p.is_nan() => p,
                _ => return false,
            };
            // Classes that would be listed before the true class in a
            // descending, index-stable ranking. NaN never compares ahead.
            let ahead = probs
                .iter()
                .enumerate()
                .filter(|&(idx, &p)| p > p_true || (p == p_true && idx < true_label))
                .count();
            ahead < 2
        })
        .count();
    hits as f64 / total as f64
}
174
/// Mean confidence overall, over correct predictions and over wrong predictions.
///
/// Returns `(mean, mean_correct, mean_wrong)`. Any group without samples
/// reports `0.0`. `y_pred` and `y_true` are indexed by the position of each
/// confidence value.
pub(crate) fn compute_confidence_stats(
    confidences: &[f64],
    y_pred: &[usize],
    y_true: &[usize],
) -> (f64, f64, f64) {
    // Slot 0 accumulates correct predictions, slot 1 the wrong ones.
    let mut sums = [0.0f64; 2];
    let mut counts = [0usize; 2];
    for (idx, conf) in confidences.iter().copied().enumerate() {
        let slot = usize::from(y_pred[idx] != y_true[idx]);
        sums[slot] += conf;
        counts[slot] += 1;
    }

    let mean_of = |sum: f64, n: usize| if n > 0 { sum / n as f64 } else { 0.0 };

    let overall = if confidences.is_empty() {
        0.0
    } else {
        confidences.iter().sum::<f64>() / confidences.len() as f64
    };

    (overall, mean_of(sums[0], counts[0]), mean_of(sums[1], counts[1]))
}
204
/// Bucket samples into ten equal-width confidence bins and compute the
/// expected calibration error.
///
/// Returns the bins as `(mean confidence, accuracy, count)` (empty bins are
/// `(0.0, 0.0, 0)`) together with the ECE, where each bin's
/// `|confidence - accuracy|` is weighted by `count / total_samples`.
pub(crate) fn compute_calibration(
    confidences: &[f64],
    y_pred: &[usize],
    y_true: &[usize],
    total_samples: usize,
) -> (Vec<(f64, f64, usize)>, f64) {
    const NUM_BINS: usize = 10;

    // Running totals per bin: (confidence sum, hit sum, sample count).
    let mut totals = [(0.0f64, 0.0f64, 0usize); NUM_BINS];
    for (idx, &conf) in confidences.iter().enumerate() {
        // A confidence of exactly 1.0 lands in the last bin.
        let slot = ((conf * NUM_BINS as f64) as usize).min(NUM_BINS - 1);
        let hit = if y_pred[idx] == y_true[idx] { 1.0 } else { 0.0 };
        let cell = &mut totals[slot];
        cell.0 += conf;
        cell.1 += hit;
        cell.2 += 1;
    }

    let bins: Vec<(f64, f64, usize)> = totals
        .iter()
        .map(|&(conf_sum, hit_sum, n)| match n {
            0 => (0.0, 0.0, 0),
            _ => (conf_sum / n as f64, hit_sum / n as f64, n),
        })
        .collect();

    let ece: f64 = bins
        .iter()
        .map(|&(conf, acc, count)| match count {
            0 => 0.0,
            _ => (conf - acc).abs() * count as f64 / total_samples as f64,
        })
        .sum();

    (bins, ece)
}
251
252 pub(crate) fn compute_cohens_kappa(cm: &ConfusionMatrix, total: usize) -> f64 {
257 if total == 0 {
258 return 0.0;
259 }
260 let mat = cm.matrix();
261 let n = total as f64;
262 let p_o = cm.accuracy();
263
264 let k = mat.len();
266 let mut p_e = 0.0f64;
267 for class in 0..k {
268 let row_sum: f64 = mat[class].iter().sum::<usize>() as f64;
269 let col_sum: f64 = mat.iter().map(|row| row[class]).sum::<usize>() as f64;
270 p_e += (row_sum * col_sum) / (n * n);
271 }
272
273 if (1.0 - p_e).abs() < 1e-10 {
274 return if (p_o - 1.0).abs() < 1e-10 { 1.0 } else { 0.0 };
275 }
276
277 (p_o - p_e) / (1.0 - p_e)
278 }
279
280 pub(crate) fn compute_mcc(cm: &ConfusionMatrix, num_classes: usize, total: usize) -> f64 {
285 if total == 0 {
286 return 0.0;
287 }
288 let mat = cm.matrix();
289 let s = total as f64;
290
291 let c: f64 = (0..num_classes).map(|k| mat[k][k] as f64).sum();
293
294 let p: Vec<f64> =
296 (0..num_classes).map(|k| mat.iter().map(|row| row[k] as f64).sum()).collect();
297
298 let t: Vec<f64> = (0..num_classes).map(|k| mat[k].iter().sum::<usize>() as f64).collect();
300
301 let sum_pk_tk: f64 = p.iter().zip(t.iter()).map(|(pk, tk)| pk * tk).sum();
302 let sum_pk_sq: f64 = p.iter().map(|pk| pk * pk).sum();
303 let sum_tk_sq: f64 = t.iter().map(|tk| tk * tk).sum();
304
305 let numer = c * s - sum_pk_tk;
306 let denom_sq = (s * s - sum_pk_sq) * (s * s - sum_tk_sq);
307
308 if denom_sq <= 0.0 {
309 return 0.0;
310 }
311
312 numer / denom_sq.sqrt()
313 }
314
/// Multi-class Brier score: mean over samples of the squared distance
/// between the probability vector and the one-hot true label.
///
/// Probabilities missing from a short vector are treated as `0.0`. The sum
/// is divided by `all_probs.len()`; an empty input returns `0.0`.
pub(crate) fn compute_brier_score(
    all_probs: &[Vec<f32>],
    y_true: &[usize],
    num_classes: usize,
) -> f64 {
    if all_probs.is_empty() {
        return 0.0;
    }

    // Squared error of one sample across all classes.
    let sample_error = |probs: &Vec<f32>, true_label: usize| -> f64 {
        (0..num_classes)
            .map(|class| {
                let predicted = f64::from(probs.get(class).copied().unwrap_or(0.0));
                let target = if class == true_label { 1.0 } else { 0.0 };
                let diff = predicted - target;
                diff * diff
            })
            .sum::<f64>()
    };

    let summed: f64 = all_probs
        .iter()
        .zip(y_true)
        .map(|(probs, &true_label)| sample_error(probs, true_label))
        .sum();
    summed / all_probs.len() as f64
}
339
/// Mean negative log-probability assigned to the true class.
///
/// The probability is clamped to `[1e-15, 1 - 1e-15]` before taking the
/// logarithm; a true label outside the vector counts as probability `0.0`.
/// The sum is divided by `all_probs.len()`; an empty input returns `0.0`.
pub(crate) fn compute_log_loss(all_probs: &[Vec<f32>], y_true: &[usize]) -> f64 {
    const EPS: f64 = 1e-15;

    if all_probs.is_empty() {
        return 0.0;
    }
    let negative_log_likelihood: f64 = all_probs
        .iter()
        .zip(y_true)
        .map(|(probs, &label)| {
            let p_true = probs.get(label).map_or(0.0, |&p| f64::from(p));
            -(p_true.clamp(EPS, 1.0 - EPS).ln())
        })
        .sum();
    negative_log_likelihood / all_probs.len() as f64
}
356
357 pub(crate) fn compute_bootstrap_cis(
362 y_pred: &[usize],
363 y_true: &[usize],
364 num_classes: usize,
365 n_boot: usize,
366 ) -> ((f64, f64), (f64, f64), (f64, f64)) {
367 let n = y_pred.len();
368 if n == 0 {
369 return ((0.0, 0.0), (0.0, 0.0), (0.0, 0.0));
370 }
371
372 let mut accs = Vec::with_capacity(n_boot);
373 let mut f1s = Vec::with_capacity(n_boot);
374 let mut mccs = Vec::with_capacity(n_boot);
375
376 let mut rng_state: u64 = 42;
378 let lcg_next = |state: &mut u64| -> usize {
379 *state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
380 ((*state >> 33) as usize) % n
381 };
382
383 for _ in 0..n_boot {
384 let mut boot_pred = Vec::with_capacity(n);
386 let mut boot_true = Vec::with_capacity(n);
387 for _ in 0..n {
388 let idx = lcg_next(&mut rng_state);
389 boot_pred.push(y_pred[idx]);
390 boot_true.push(y_true[idx]);
391 }
392
393 let cm = ConfusionMatrix::from_predictions_with_min_classes(
394 &boot_pred,
395 &boot_true,
396 num_classes,
397 );
398 let metrics = MultiClassMetrics::from_confusion_matrix(&cm);
399
400 accs.push(cm.accuracy());
401
402 let valid_f1: Vec<f64> = metrics.f1.iter().copied().filter(|v| !v.is_nan()).collect();
404 let macro_f1 = if valid_f1.is_empty() {
405 0.0
406 } else {
407 valid_f1.iter().sum::<f64>() / valid_f1.len() as f64
408 };
409 f1s.push(macro_f1);
410
411 mccs.push(Self::compute_mcc(&cm, cm.n_classes(), n));
412 }
413
414 accs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
415 f1s.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
416 mccs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
417
418 let lo = (0.025 * n_boot as f64) as usize;
419 let hi = (0.975 * n_boot as f64).ceil() as usize;
420 let hi = hi.min(n_boot - 1);
421
422 ((accs[lo], accs[hi]), (f1s[lo], f1s[hi]), (mccs[lo], mccs[hi]))
423 }
424
/// Accuracy of two trivial classifiers: uniform random guessing and always
/// predicting the most frequent class.
///
/// Returns `(random, majority)` where `random = 1 / num_classes` and
/// `majority = max(support) / total`; each is `0.0` when its divisor is zero.
pub(crate) fn compute_baselines(
    support: &[usize],
    total: usize,
    num_classes: usize,
) -> (f64, f64) {
    let random = match num_classes {
        0 => 0.0,
        k => 1.0 / k as f64,
    };
    let majority = match total {
        0 => 0.0,
        n => {
            let largest = support.iter().max().map_or(0, |&m| m);
            largest as f64 / n as f64
        }
    };
    (random, majority)
}
439
/// The `top_n` largest non-zero off-diagonal cells of a confusion matrix.
///
/// Each entry is `(row, column, count)`, ordered by descending count; cells
/// with equal counts keep their row-major order.
pub(crate) fn compute_top_confusions(
    matrix: &[Vec<usize>],
    top_n: usize,
) -> Vec<(usize, usize, usize)> {
    let mut ranked: Vec<(usize, usize, usize)> = matrix
        .iter()
        .enumerate()
        .flat_map(|(row_idx, row)| {
            row.iter().enumerate().filter_map(move |(col_idx, &count)| {
                (row_idx != col_idx && count > 0).then_some((row_idx, col_idx, count))
            })
        })
        .collect();
    // Stable sort, so ties stay in row-major order.
    ranked.sort_by_key(|&(_, _, count)| std::cmp::Reverse(count));
    ranked.truncate(top_n);
    ranked
}
457
458 #[must_use]
460 pub fn to_report(&self) -> String {
461 use crate::eval::classification::Average;
462
463 let mut out = String::new();
464
465 out.push_str(&format!(
467 "{:>18} {:>10} {:>10} {:>10} {:>10}\n",
468 "", "precision", "recall", "f1-score", "support"
469 ));
470 out.push_str(&format!("{}\n", "-".repeat(62)));
471
472 for i in 0..self.num_classes {
474 let name = self
475 .label_names
476 .get(i)
477 .map_or_else(|| format!("Class {i}"), std::clone::Clone::clone);
478 out.push_str(&format!(
479 "{:>18} {:>10.4} {:>10.4} {:>10.4} {:>10}\n",
480 name,
481 self.per_class_precision[i],
482 self.per_class_recall[i],
483 self.per_class_f1[i],
484 self.per_class_support[i],
485 ));
486 }
487
488 out.push_str(&format!("{}\n", "-".repeat(62)));
489
490 let total_support: usize = self.per_class_support.iter().sum();
491
492 let macro_p = self.avg_metric(&self.per_class_precision, Average::Macro);
494 let macro_r = self.avg_metric(&self.per_class_recall, Average::Macro);
495 let macro_f1 = self.avg_metric(&self.per_class_f1, Average::Macro);
496 out.push_str(&format!(
497 "{:>18} {:>10.4} {:>10.4} {:>10.4} {:>10}\n",
498 "macro avg", macro_p, macro_r, macro_f1, total_support,
499 ));
500
501 let weighted_p = self.avg_metric(&self.per_class_precision, Average::Weighted);
503 let weighted_r = self.avg_metric(&self.per_class_recall, Average::Weighted);
504 let weighted_f1 = self.avg_metric(&self.per_class_f1, Average::Weighted);
505 out.push_str(&format!(
506 "{:>18} {:>10.4} {:>10.4} {:>10.4} {:>10}\n",
507 "weighted avg", weighted_p, weighted_r, weighted_f1, total_support,
508 ));
509
510 self.report_summary(&mut out);
512 self.report_confidence(&mut out);
513 self.report_scoring_rules(&mut out);
514 self.report_calibration(&mut out);
515 self.report_baselines(&mut out);
516 self.report_top_confusions(&mut out);
517 self.report_throughput(&mut out);
518 out
519 }
520
521 pub(crate) fn report_summary(&self, out: &mut String) {
522 out.push_str(&format!(
523 "\nAccuracy: {:.4} ({:.1}%) 95% CI [{:.4}, {:.4}]\n",
524 self.accuracy,
525 self.accuracy * 100.0,
526 self.ci_accuracy.0,
527 self.ci_accuracy.1
528 ));
529 out.push_str(&format!(
530 "Top-2 accuracy: {:.4} ({:.1}%)\n",
531 self.top2_accuracy,
532 self.top2_accuracy * 100.0
533 ));
534 out.push_str(&format!(
535 "Cohen's kappa: {:.4} ({})\n",
536 self.cohens_kappa,
537 Self::kappa_interpretation(self.cohens_kappa)
538 ));
539 out.push_str(&format!(
540 "MCC: {:.4} 95% CI [{:.4}, {:.4}]\n",
541 self.mcc, self.ci_mcc.0, self.ci_mcc.1
542 ));
543 out.push_str(&format!(
544 "Macro F1: {:.4} 95% CI [{:.4}, {:.4}]\n",
545 self.avg_metric(&self.per_class_f1, crate::eval::classification::Average::Macro),
546 self.ci_macro_f1.0,
547 self.ci_macro_f1.1
548 ));
549 out.push_str(&format!("Avg loss: {:.4}\n", self.avg_loss));
550 }
551
552 pub(crate) fn report_confidence(&self, out: &mut String) {
553 out.push_str(&format!("\nConfidence (mean): {:.4}\n", self.mean_confidence));
554 out.push_str(&format!(" correct preds: {:.4}\n", self.mean_confidence_correct));
555 out.push_str(&format!(" wrong preds: {:.4}\n", self.mean_confidence_wrong));
556 let gap = self.mean_confidence_correct - self.mean_confidence_wrong;
557 out.push_str(&format!(" gap (higher=better): {gap:.4}\n"));
558 }
559
560 pub(crate) fn report_scoring_rules(&self, out: &mut String) {
561 out.push_str(&format!(
562 "\nBrier score: {:.4} (perfect=0, random={:.4})\n",
563 self.brier_score,
564 1.0 - 1.0 / self.num_classes as f64
565 ));
566 out.push_str(&format!(
567 "Log loss: {:.4} (random={:.4})\n",
568 self.log_loss,
569 (self.num_classes as f64).ln()
570 ));
571 }
572
573 pub(crate) fn report_calibration(&self, out: &mut String) {
574 out.push_str(&format!("\nECE (Expected Calibration Error): {:.4}\n", self.ece));
575 out.push_str("Calibration:\n");
576 out.push_str(" Bin Confidence Accuracy Count\n");
577 for (i, &(conf, acc, count)) in self.calibration_bins.iter().enumerate() {
578 if count > 0 {
579 let lo = i as f64 * 0.1;
580 let hi = lo + 0.1;
581 let overconf = if conf > acc { "+" } else { "" };
582 out.push_str(&format!(
583 " [{:.1}-{:.1}) {:.4} {:.4} {:>5} {overconf}{:.3}\n",
584 lo,
585 hi,
586 conf,
587 acc,
588 count,
589 conf - acc,
590 ));
591 }
592 }
593 }
594
595 pub(crate) fn report_baselines(&self, out: &mut String) {
596 out.push_str(&format!(
597 "\nBaselines: random={:.1}% majority={:.1}% model={:.1}% lift={:.1}x\n",
598 self.baseline_random * 100.0,
599 self.baseline_majority * 100.0,
600 self.accuracy * 100.0,
601 if self.baseline_majority > 0.0 { self.accuracy / self.baseline_majority } else { 0.0 },
602 ));
603 }
604
605 pub(crate) fn report_top_confusions(&self, out: &mut String) {
606 if self.top_confusions.is_empty() {
607 return;
608 }
609 out.push_str("\nTop confusions (true → predicted, count):\n");
610 for &(true_c, pred_c, count) in &self.top_confusions {
611 let true_name = self.label_names.get(true_c).map_or("?", |n| n.as_str());
612 let pred_name = self.label_names.get(pred_c).map_or("?", |n| n.as_str());
613 out.push_str(&format!(" {true_name} → {pred_name}: {count}\n"));
614 }
615 }
616
617 pub(crate) fn report_throughput(&self, out: &mut String) {
618 out.push_str(&format!("\nSamples: {}\n", self.total_samples));
619 out.push_str(&format!(
620 "Time: {}ms ({:.1} samples/sec)\n",
621 self.eval_time_ms, self.samples_per_sec
622 ));
623 }
624
/// Map a kappa value to a verbal agreement level.
///
/// Bands are half-open: the first upper bound that `kappa` is strictly
/// below decides the label; values of 0.80 and above (and NaN, which is
/// below no bound) map to "almost perfect".
pub(crate) fn kappa_interpretation(kappa: f64) -> &'static str {
    const BANDS: [(f64, &str); 5] = [
        (0.0, "worse than chance"),
        (0.20, "slight"),
        (0.40, "fair"),
        (0.60, "moderate"),
        (0.80, "substantial"),
    ];
    BANDS
        .iter()
        .find(|(upper, _)| kappa < *upper)
        .map_or("almost perfect", |&(_, label)| label)
}
641
642 #[must_use]
646 #[allow(clippy::disallowed_methods)]
647 pub fn to_json(&self) -> String {
648 let per_class: Vec<serde_json::Value> = (0..self.num_classes)
649 .map(|i| {
650 let name = self
651 .label_names
652 .get(i)
653 .map_or_else(|| format!("class_{i}"), std::clone::Clone::clone);
654 serde_json::json!({
655 "label": name,
656 "precision": self.per_class_precision[i],
657 "recall": self.per_class_recall[i],
658 "f1": self.per_class_f1[i],
659 "support": self.per_class_support[i],
660 })
661 })
662 .collect();
663
664 let calibration: Vec<serde_json::Value> = self
665 .calibration_bins
666 .iter()
667 .enumerate()
668 .filter(|(_, &(_, _, count))| count > 0)
669 .map(|(i, &(conf, acc, count))| {
670 serde_json::json!({
671 "bin": format!("[{:.1}-{:.1})", i as f64 * 0.1, (i + 1) as f64 * 0.1),
672 "mean_confidence": conf,
673 "mean_accuracy": acc,
674 "count": count,
675 })
676 })
677 .collect();
678
679 let confusions: Vec<serde_json::Value> = self.top_confusions.iter().map(|&(t, p, c)| {
680 serde_json::json!({
681 "true_class": self.label_names.get(t).cloned().unwrap_or_else(|| format!("class_{t}")),
682 "pred_class": self.label_names.get(p).cloned().unwrap_or_else(|| format!("class_{p}")),
683 "count": c,
684 })
685 }).collect();
686
687 let json = serde_json::json!({
688 "accuracy": self.accuracy,
689 "top2_accuracy": self.top2_accuracy,
690 "cohens_kappa": self.cohens_kappa,
691 "mcc": self.mcc,
692 "avg_loss": self.avg_loss,
693 "brier_score": self.brier_score,
694 "log_loss": self.log_loss,
695 "total_samples": self.total_samples,
696 "num_classes": self.num_classes,
697 "eval_time_ms": self.eval_time_ms,
698 "samples_per_sec": self.samples_per_sec,
699 "confidence_intervals_95": {
700 "accuracy": [self.ci_accuracy.0, self.ci_accuracy.1],
701 "macro_f1": [self.ci_macro_f1.0, self.ci_macro_f1.1],
702 "mcc": [self.ci_mcc.0, self.ci_mcc.1],
703 },
704 "baselines": {
705 "random": self.baseline_random,
706 "majority_class": self.baseline_majority,
707 "lift_over_majority": if self.baseline_majority > 0.0 { self.accuracy / self.baseline_majority } else { 0.0 },
708 },
709 "per_class": per_class,
710 "confusion_matrix": self.confusion_matrix,
711 "top_confusions": confusions,
712 "confidence": {
713 "mean": self.mean_confidence,
714 "mean_correct": self.mean_confidence_correct,
715 "mean_wrong": self.mean_confidence_wrong,
716 "gap": self.mean_confidence_correct - self.mean_confidence_wrong,
717 },
718 "calibration": {
719 "ece": self.ece,
720 "brier_score": self.brier_score,
721 "log_loss": self.log_loss,
722 "bins": calibration,
723 },
724 });
725
726 serde_json::to_string_pretty(&json).unwrap_or_default()
727 }
728
729 #[must_use]
735 pub fn to_model_card(&self, model_name: &str, base_model: Option<&str>) -> String {
736 use crate::eval::classification::Average;
737
738 let macro_f1 = self.avg_metric(&self.per_class_f1, Average::Macro);
739 let weighted_f1 = self.avg_metric(&self.per_class_f1, Average::Weighted);
740
741 let mut out = String::new();
742 self.card_yaml_front_matter(&mut out, model_name, base_model, macro_f1, weighted_f1);
743 self.card_title(&mut out, model_name, base_model);
744 self.card_summary(&mut out, macro_f1, weighted_f1);
745 self.card_labels(&mut out);
746 self.card_per_class_metrics(&mut out);
747 self.card_confusion_matrix(&mut out);
748 self.card_error_analysis(&mut out);
749 self.card_calibration(&mut out);
750 Self::card_intended_use(&mut out);
751 self.card_limitations(&mut out);
752 Self::card_ethical_considerations(&mut out);
753 self.card_training(&mut out, base_model);
754 out.push_str("---\n*Generated by [entrenar](https://github.com/paiml/entrenar)*\n");
755 out
756 }
757
758 pub(crate) fn card_yaml_front_matter(
759 &self,
760 out: &mut String,
761 model_name: &str,
762 base_model: Option<&str>,
763 macro_f1: f64,
764 weighted_f1: f64,
765 ) {
766 out.push_str("---\n");
767 out.push_str("license: apache-2.0\n");
768 out.push_str("language:\n- en\n");
769 out.push_str(
770 "tags:\n- shell-safety\n- code-classification\n- lora\n- entrenar\n- security\n",
771 );
772 if let Some(base) = base_model {
773 out.push_str(&format!("base_model: {base}\n"));
774 }
775 out.push_str("pipeline_tag: text-classification\n");
776 out.push_str("model-index:\n");
777 out.push_str(&format!("- name: {model_name}\n"));
778 out.push_str(" results:\n");
779 out.push_str(" - task:\n");
780 out.push_str(" type: text-classification\n");
781 out.push_str(" name: Shell Safety Classification\n");
782 out.push_str(" metrics:\n");
783 out.push_str(&format!(" - type: accuracy\n value: {:.4}\n", self.accuracy));
784 out.push_str(&format!(
785 " - type: f1\n value: {macro_f1:.4}\n name: Macro F1\n"
786 ));
787 out.push_str(&format!(
788 " - type: f1\n value: {weighted_f1:.4}\n name: Weighted F1\n"
789 ));
790 out.push_str(&format!(" - type: mcc\n value: {:.4}\n", self.mcc));
791 out.push_str(&format!(" - type: cohens_kappa\n value: {:.4}\n", self.cohens_kappa));
792 out.push_str("---\n\n");
793 }
794
795 pub(crate) fn card_title(&self, out: &mut String, model_name: &str, base_model: Option<&str>) {
796 out.push_str(&format!("# {model_name}\n\n"));
797 out.push_str("A shell command safety classifier that categorizes shell commands into safety classes, ");
798 out.push_str("enabling automated triage of commands before execution.\n\n");
799 out.push_str(
800 "Trained with [entrenar](https://github.com/paiml/entrenar) using LoRA fine-tuning",
801 );
802 if let Some(base) = base_model {
803 out.push_str(&format!(" on [`{base}`](https://huggingface.co/{base})"));
804 }
805 out.push_str(".\n\n");
806 }
807
808 pub(crate) fn card_summary(&self, out: &mut String, macro_f1: f64, weighted_f1: f64) {
809 out.push_str("## Summary\n\n");
810 out.push_str("| Metric | Value | 95% CI |\n");
811 out.push_str("|--------|-------|--------|\n");
812 out.push_str(&format!(
813 "| Accuracy | {:.2}% | [{:.2}%, {:.2}%] |\n",
814 self.accuracy * 100.0,
815 self.ci_accuracy.0 * 100.0,
816 self.ci_accuracy.1 * 100.0
817 ));
818 out.push_str(&format!("| Top-2 Accuracy | {:.2}% | — |\n", self.top2_accuracy * 100.0));
819 out.push_str(&format!(
820 "| Macro F1 | {macro_f1:.4} | [{:.4}, {:.4}] |\n",
821 self.ci_macro_f1.0, self.ci_macro_f1.1
822 ));
823 out.push_str(&format!("| Weighted F1 | {weighted_f1:.4} | — |\n"));
824 out.push_str(&format!(
825 "| Cohen's Kappa | {:.4} ({}) | — |\n",
826 self.cohens_kappa,
827 Self::kappa_interpretation(self.cohens_kappa)
828 ));
829 out.push_str(&format!(
830 "| MCC | {:.4} | [{:.4}, {:.4}] |\n",
831 self.mcc, self.ci_mcc.0, self.ci_mcc.1
832 ));
833 out.push_str(&format!("| Brier Score | {:.4} | — |\n", self.brier_score));
834 out.push_str(&format!("| Log Loss | {:.4} | — |\n", self.log_loss));
835 out.push_str(&format!("| ECE | {:.4} | — |\n", self.ece));
836 out.push_str(&format!("| Avg Loss | {:.4} | — |\n", self.avg_loss));
837 out.push_str(&format!("| Eval Samples | {} | — |\n", self.total_samples));
838 out.push_str(&format!("| Throughput | {:.1} samples/sec | — |\n\n", self.samples_per_sec));
839
840 let lift =
841 if self.baseline_majority > 0.0 { self.accuracy / self.baseline_majority } else { 0.0 };
842 out.push_str("**Baselines**: ");
843 out.push_str(&format!(
844 "random={:.1}%, majority={:.1}%, model={:.1}% ({:.1}x lift over majority)\n\n",
845 self.baseline_random * 100.0,
846 self.baseline_majority * 100.0,
847 self.accuracy * 100.0,
848 lift
849 ));
850 }
851
852 pub(crate) fn card_labels(&self, out: &mut String) {
853 out.push_str("## Labels\n\n");
854 out.push_str("| ID | Label | Description |\n");
855 out.push_str("|----|-------|-------------|\n");
856 let descriptions = [
857 "Command is safe to execute as-is",
858 "Command contains unquoted variable expansions (word splitting/globbing risk)",
859 "Command uses non-deterministic sources ($RANDOM, $$, date, etc.)",
860 "Command is not idempotent (unsafe to re-run: mkdir without -p, etc.)",
861 "Command is destructive or has injection risk (rm -rf, eval, etc.)",
862 ];
863 for (i, name) in self.label_names.iter().enumerate() {
864 let desc = descriptions.get(i).unwrap_or(&"");
865 out.push_str(&format!("| {i} | {name} | {desc} |\n"));
866 }
867 out.push('\n');
868 }
869
870 pub(crate) fn card_per_class_metrics(&self, out: &mut String) {
871 out.push_str("## Per-Class Metrics\n\n");
872 out.push_str("| Label | Precision | Recall | F1 | Support |\n");
873 out.push_str("|-------|-----------|--------|----|---------|\n");
874 for i in 0..self.num_classes {
875 let name = self
876 .label_names
877 .get(i)
878 .map_or_else(|| format!("class_{i}"), std::clone::Clone::clone);
879 out.push_str(&format!(
880 "| {} | {:.4} | {:.4} | {:.4} | {} |\n",
881 name,
882 self.per_class_precision[i],
883 self.per_class_recall[i],
884 self.per_class_f1[i],
885 self.per_class_support[i],
886 ));
887 }
888 out.push('\n');
889 }
890
891 fn card_confusion_matrix(&self, out: &mut String) {
892 out.push_str("## Confusion Matrix\n\n");
893 self.card_confusion_raw(out);
894 self.card_confusion_normalized(out);
895 }
896
897 pub(crate) fn card_confusion_raw(&self, out: &mut String) {
898 out.push_str("### Raw Counts\n\n```\n");
899 self.card_confusion_header(out);
900 for (i, row) in self.confusion_matrix.iter().enumerate() {
901 self.card_confusion_row_label(out, i);
902 for val in row {
903 out.push_str(&format!(" {val:>8}"));
904 }
905 out.push('\n');
906 }
907 out.push_str("```\n\n");
908 }
909
910 pub(crate) fn card_confusion_normalized(&self, out: &mut String) {
911 out.push_str("### Normalized (row %)\n\n```\n");
912 self.card_confusion_header(out);
913 for (i, row) in self.confusion_matrix.iter().enumerate() {
914 self.card_confusion_row_label(out, i);
915 let row_sum: usize = row.iter().sum();
916 for val in row {
917 if row_sum > 0 {
918 out.push_str(&format!(" {:>7.1}%", *val as f64 / row_sum as f64 * 100.0));
919 } else {
920 out.push_str(" 0.0%");
921 }
922 }
923 out.push('\n');
924 }
925 out.push_str("```\n\n");
926 }
927
928 pub(crate) fn card_confusion_header(&self, out: &mut String) {
929 out.push_str(&format!("{:>18}", "Predicted →"));
930 for name in &self.label_names {
931 let short = if name.len() > 8 { &name[..8] } else { name.as_str() };
932 out.push_str(&format!(" {short:>8}"));
933 }
934 out.push('\n');
935 }
936
937 pub(crate) fn card_confusion_row_label(&self, out: &mut String, i: usize) {
938 let name =
939 self.label_names.get(i).map_or_else(|| format!("class_{i}"), std::clone::Clone::clone);
940 let short = if name.len() > 18 { &name[..18] } else { name.as_str() };
941 out.push_str(&format!("{short:>18}"));
942 }
943
944 pub(crate) fn card_calibration(&self, out: &mut String) {
945 out.push_str("## Confidence & Calibration\n\n");
946 out.push_str("| Metric | Value |\n");
947 out.push_str("|--------|-------|\n");
948 out.push_str(&format!("| Mean confidence | {:.4} |\n", self.mean_confidence));
949 out.push_str(&format!("| Confidence (correct) | {:.4} |\n", self.mean_confidence_correct));
950 out.push_str(&format!("| Confidence (wrong) | {:.4} |\n", self.mean_confidence_wrong));
951 let gap = self.mean_confidence_correct - self.mean_confidence_wrong;
952 out.push_str(&format!("| Confidence gap | {gap:.4} |\n"));
953 out.push_str(&format!("| ECE | {:.4} |\n\n", self.ece));
954
955 out.push_str("**Calibration curve** (reliability diagram):\n\n");
956 out.push_str("```\n");
957 out.push_str("Bin Conf Acc Count\n");
958 for (i, &(conf, acc, count)) in self.calibration_bins.iter().enumerate() {
959 if count > 0 {
960 let lo = i as f64 * 0.1;
961 let hi = lo + 0.1;
962 out.push_str(&format!("[{lo:.1}-{hi:.1}) {conf:.3} {acc:.3} {count:>5}\n",));
963 }
964 }
965 out.push_str("```\n\n");
966 }
967
968 pub(crate) fn card_error_analysis(&self, out: &mut String) {
969 if self.top_confusions.is_empty() {
970 return;
971 }
972 out.push_str("## Error Analysis\n\n");
973 out.push_str("Most frequent misclassifications:\n\n");
974 out.push_str("| True Class | Predicted As | Count |\n");
975 out.push_str("|------------|-------------|-------|\n");
976 for &(true_c, pred_c, count) in &self.top_confusions {
977 let true_name = self.label_names.get(true_c).map_or("?", |n| n.as_str());
978 let pred_name = self.label_names.get(pred_c).map_or("?", |n| n.as_str());
979 out.push_str(&format!("| {true_name} | {pred_name} | {count} |\n"));
980 }
981 out.push('\n');
982 }
983
/// Append the fixed "Intended Use" section of the model card.
pub(crate) fn card_intended_use(out: &mut String) {
    const SECTION: [&str; 6] = [
        "## Intended Use\n\n",
        "This model is designed for **automated shell command safety triage** in:\n\n",
        "- **CI/CD pipelines**: Pre-flight safety check before executing generated scripts\n",
        "- **Shell purification tools**: Classify commands to determine transformation strategy\n",
        "- **Code review**: Flag potentially unsafe shell commands in pull requests\n",
        "- **Interactive shells**: Warn users before executing risky commands\n\n",
    ];
    for piece in SECTION {
        out.push_str(piece);
    }
}
998
999 fn card_limitations(&self, out: &mut String) {
1000 out.push_str("## Limitations\n\n");
1001 out.push_str("- **Not a security oracle**: This model provides *classification hints*, not security guarantees\n");
1002 out.push_str("- **Context-blind**: Cannot assess safety based on the execution environment or user permissions\n");
1003 out.push_str("- **Training distribution**: Trained on synthetic shell scripts; may underperform on novel patterns\n");
1004 out.push_str(
1005 "- **English only**: Command names and variable patterns are English-centric\n",
1006 );
1007 self.card_weak_classes(out);
1008 out.push('\n');
1009 }
1010
1011 pub(crate) fn card_weak_classes(&self, out: &mut String) {
1012 let min_f1_idx = self
1013 .per_class_f1
1014 .iter()
1015 .enumerate()
1016 .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
1017 .map(|(i, _)| i);
1018 if let Some(idx) = min_f1_idx {
1019 if self.per_class_f1[idx] < 0.5 {
1020 let name = self.label_names.get(idx).map_or("unknown", |n| n.as_str());
1021 out.push_str(&format!(
1022 "- **Weak class**: `{name}` (F1={:.2}) — consider additional training data\n",
1023 self.per_class_f1[idx]
1024 ));
1025 }
1026 }
1027 }
1028
/// Append the fixed "Ethical Considerations" section of the model card.
pub(crate) fn card_ethical_considerations(out: &mut String) {
    const SECTION: [&str; 4] = [
        "## Ethical Considerations\n\n",
        "- **False negatives are dangerous**: An `unsafe` command classified as `safe` could lead to data loss\n",
        "- **Defense in depth**: Always combine this classifier with other safety mechanisms (sandboxing, dry-run)\n",
        "- **Not adversarial-robust**: Determined attackers can craft commands to evade classification\n\n",
    ];
    for piece in SECTION {
        out.push_str(piece);
    }
}
1035
1036 pub(crate) fn card_training(&self, out: &mut String, base_model: Option<&str>) {
1037 out.push_str("## Training\n\n");
1038 out.push_str("| Parameter | Value |\n");
1039 out.push_str("|-----------|-------|\n");
1040 out.push_str("| Framework | [entrenar](https://github.com/paiml/entrenar) (Rust) |\n");
1041 out.push_str("| Method | LoRA (Low-Rank Adaptation) |\n");
1042 if let Some(base) = base_model {
1043 out.push_str(&format!("| Base model | `{base}` |\n"));
1044 }
1045 out.push_str(&format!("| Num classes | {} |\n\n", self.num_classes));
1046 }
1047
1048 pub(crate) fn avg_metric(
1050 &self,
1051 values: &[f64],
1052 average: crate::eval::classification::Average,
1053 ) -> f64 {
1054 match average {
1055 crate::eval::classification::Average::Macro => {
1056 if values.is_empty() {
1057 0.0
1058 } else {
1059 values.iter().sum::<f64>() / values.len() as f64
1060 }
1061 }
1062 crate::eval::classification::Average::Weighted => {
1063 let total: usize = self.per_class_support.iter().sum();
1064 if total == 0 {
1065 return 0.0;
1066 }
1067 values
1068 .iter()
1069 .zip(self.per_class_support.iter())
1070 .map(|(&v, &s)| v * s as f64)
1071 .sum::<f64>()
1072 / total as f64
1073 }
1074 _ => {
1075 if values.is_empty() {
1077 0.0
1078 } else {
1079 values.iter().sum::<f64>() / values.len() as f64
1080 }
1081 }
1082 }
1083 }
1084}
1085
/// The five class label names, in a fixed order.
///
/// NOTE(review): presumably the position of each label is its class index —
/// confirm against the training corpus before reordering.
pub const SSC_LABELS: [&str; 5] = [
    "safe",
    "needs-quoting",
    "non-deterministic",
    "non-idempotent",
    "unsafe",
];
1089
1090pub(crate) fn restore_class_weights_from_metadata(
1102 checkpoint_dir: &std::path::Path,
1103 num_classes: usize,
1104) -> Option<Vec<f32>> {
1105 let meta_str = std::fs::read_to_string(checkpoint_dir.join("metadata.json")).ok()?;
1106 let meta: serde_json::Value = serde_json::from_str(&meta_str).ok()?;
1107 let arr = meta.get("class_weights")?.as_array()?;
1108 let weights: Vec<f32> = arr.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
1109 (weights.len() == num_classes).then_some(weights)
1110}
1111
1112pub fn evaluate_checkpoint(
1123 checkpoint_dir: &Path,
1124 test_data: &Path,
1125 model_config: &crate::transformer::TransformerConfig,
1126 classify_config: super::classify_pipeline::ClassifyConfig,
1127 label_names: &[String],
1128) -> crate::Result<ClassifyEvalReport> {
1129 use super::classification::load_safety_corpus;
1130
1131 let start = std::time::Instant::now();
1132 let num_classes = classify_config.num_classes;
1133
1134 let mut classify_config = classify_config;
1136 if classify_config.class_weights.is_none() {
1137 if let Some(weights) = restore_class_weights_from_metadata(checkpoint_dir, num_classes) {
1138 println!("Restored class_weights from checkpoint: {weights:?}");
1139 classify_config.class_weights = Some(weights);
1140 }
1141 }
1142
1143 let adapter_config_path = checkpoint_dir.join("adapter_config.json");
1146 let mut pipeline = if adapter_config_path.exists() {
1147 let adapter_json = std::fs::read_to_string(&adapter_config_path)
1149 .map_err(|e| crate::Error::Io(format!("Failed to read adapter_config.json: {e}")))?;
1150 let peft_config: crate::lora::PeftAdapterConfig = serde_json::from_str(&adapter_json)
1151 .map_err(|e| {
1152 crate::Error::Serialization(format!("Invalid adapter_config.json: {e}"))
1153 })?;
1154
1155 if peft_config.r > 0 {
1158 classify_config.lora_rank = peft_config.r;
1159 }
1160 if peft_config.lora_alpha > 0.0 {
1161 classify_config.lora_alpha = peft_config.lora_alpha;
1162 }
1163
1164 let mut pipe = match peft_config.base_model_name_or_path.as_deref() {
1168 Some(base_model_path)
1169 if std::path::Path::new(base_model_path).is_dir()
1170 || std::path::Path::new(base_model_path)
1171 .extension()
1172 .is_some_and(|e| e == "safetensors") =>
1173 {
1174 println!("Loading base model from: {base_model_path}");
1175 ClassifyPipeline::from_pretrained(
1176 base_model_path,
1177 model_config,
1178 classify_config.clone(),
1179 )?
1180 }
1181 Some(base_model_path) => {
1182 println!("Base model path is not a SafeTensors directory: {base_model_path}");
1183 println!("Using random-init base model (adapter weights will be restored from checkpoint)");
1184 ClassifyPipeline::new(model_config, classify_config.clone())
1185 }
1186 None => {
1187 println!("No base_model_name_or_path in adapter_config.json");
1188 println!("Using random-init base model (adapter weights will be restored from checkpoint)");
1189 ClassifyPipeline::new(model_config, classify_config.clone())
1190 }
1191 };
1192
1193 let st_path = checkpoint_dir.join("model.safetensors");
1195 let st_data = std::fs::read(&st_path).map_err(|e| {
1196 crate::Error::Io(format!("Failed to read checkpoint model.safetensors: {e}"))
1197 })?;
1198 let tensors = safetensors::SafeTensors::deserialize(&st_data).map_err(|e| {
1199 crate::Error::Serialization(format!("Failed to deserialize checkpoint: {e}"))
1200 })?;
1201
1202 if let Ok(w) = tensors.tensor("classifier.weight") {
1204 let w_data: &[f32] = bytemuck::cast_slice(w.data());
1205 pipe.classifier
1206 .weight
1207 .data_mut()
1208 .as_slice_mut()
1209 .expect("contiguous classifier.weight")
1210 .copy_from_slice(w_data);
1211 }
1212 if let Ok(b) = tensors.tensor("classifier.bias") {
1213 let b_data: &[f32] = bytemuck::cast_slice(b.data());
1214 pipe.classifier
1215 .bias
1216 .data_mut()
1217 .as_slice_mut()
1218 .expect("contiguous classifier.bias")
1219 .copy_from_slice(b_data);
1220 }
1221
1222 for (idx, lora) in pipe.lora_layers.iter_mut().enumerate() {
1224 let layer = idx / 2;
1225 let proj = if idx % 2 == 0 { "q" } else { "v" };
1226
1227 if let Ok(a) = tensors.tensor(&format!("lora.{layer}.{proj}_proj.lora_a")) {
1228 let a_data: &[f32] = bytemuck::cast_slice(a.data());
1229 lora.lora_a_mut()
1230 .data_mut()
1231 .as_slice_mut()
1232 .expect("contiguous lora_a")
1233 .copy_from_slice(a_data);
1234 }
1235 if let Ok(b) = tensors.tensor(&format!("lora.{layer}.{proj}_proj.lora_b")) {
1236 let b_data: &[f32] = bytemuck::cast_slice(b.data());
1237 lora.lora_b_mut()
1238 .data_mut()
1239 .as_slice_mut()
1240 .expect("contiguous lora_b")
1241 .copy_from_slice(b_data);
1242 }
1243 }
1244
1245 let loaded_count = tensors.names().len();
1246 println!("Restored {loaded_count} tensors from checkpoint");
1247 pipe
1248 } else {
1249 ClassifyPipeline::from_pretrained(checkpoint_dir, model_config, classify_config)?
1251 };
1252
1253 let samples = load_safety_corpus(test_data, num_classes)?;
1255
1256 let mut y_true: Vec<usize> = Vec::with_capacity(samples.len());
1258 let mut y_pred: Vec<usize> = Vec::with_capacity(samples.len());
1259 let mut all_probs: Vec<Vec<f32>> = Vec::with_capacity(samples.len());
1260 let mut total_loss = 0.0f32;
1261
1262 for (i, sample) in samples.iter().enumerate() {
1263 let ids = pipeline.tokenize(&sample.input);
1264 let (loss, predicted, probs) = pipeline.forward_only_with_probs(&ids, sample.label);
1265 total_loss += loss;
1266 y_true.push(sample.label);
1267 y_pred.push(predicted);
1268 all_probs.push(probs);
1269
1270 if (i + 1) % 100 == 0 {
1272 println!(" Evaluated {}/{} samples...", i + 1, samples.len());
1273 }
1274 }
1275 println!(" Evaluated {}/{} samples (done)", samples.len(), samples.len());
1276
1277 Ok(ClassifyEvalReport::from_predictions_with_probs(
1278 &y_pred,
1279 &y_true,
1280 &all_probs,
1281 total_loss,
1282 num_classes,
1283 label_names,
1284 start.elapsed().as_millis() as u64,
1285 ))
1286}