Skip to main content

entrenar/finetune/
classify_eval_report.rs

1//! Classification evaluation report and checkpoint evaluation.
2//!
3//! Extracted from `classify_trainer.rs` to reduce file size.
4
5use super::classify_pipeline::ClassifyPipeline;
6use crate::eval::classification::{ConfusionMatrix, MultiClassMetrics};
7use std::path::Path;
8
/// Evaluation report from the classification pipeline.
///
/// Contains per-class precision/recall/F1, confusion matrix, aggregate metrics,
/// and advanced diagnostics (Cohen's kappa, MCC, calibration, confidence distribution).
/// Produced by [`ClassifyTrainer::evaluate`] or [`evaluate_checkpoint`].
///
/// Render it with [`ClassifyEvalReport::to_report`] (plain text),
/// [`ClassifyEvalReport::to_json`] or [`ClassifyEvalReport::to_model_card`].
#[derive(Debug, Clone)]
pub struct ClassifyEvalReport {
    /// Overall accuracy (0.0-1.0)
    pub accuracy: f64,
    /// Average loss per sample (`total_loss / total_samples`, 0.0 when there are no samples)
    pub avg_loss: f32,
    /// Per-class precision (0.0-1.0), indexed by class id
    pub per_class_precision: Vec<f64>,
    /// Per-class recall (0.0-1.0), indexed by class id
    pub per_class_recall: Vec<f64>,
    /// Per-class F1 score (0.0-1.0), indexed by class id
    pub per_class_f1: Vec<f64>,
    /// Per-class support (number of samples whose true label is that class)
    pub per_class_support: Vec<usize>,
    /// Confusion matrix: `confusion_matrix[true][predicted]`
    pub confusion_matrix: Vec<Vec<usize>>,
    /// Number of classes
    pub num_classes: usize,
    /// Total samples evaluated
    pub total_samples: usize,
    /// Evaluation wall-clock time in milliseconds
    pub eval_time_ms: u64,
    /// Human-readable class names, indexed by class id. May be shorter than
    /// `num_classes`; renderers fall back to generated names for missing entries.
    pub label_names: Vec<String>,
    /// Cohen's kappa (chance-corrected agreement, -1 to 1)
    pub cohens_kappa: f64,
    /// Matthews Correlation Coefficient (-1 to 1, robust to class imbalance)
    pub mcc: f64,
    /// Top-2 accuracy (fraction of samples whose true class is among the 2 most probable)
    pub top2_accuracy: f64,
    /// Mean prediction confidence (largest predicted probability per sample)
    pub mean_confidence: f64,
    /// Mean confidence when prediction is correct
    pub mean_confidence_correct: f64,
    /// Mean confidence when prediction is wrong
    pub mean_confidence_wrong: f64,
    /// Samples per second throughput (0.0 when `eval_time_ms` is 0)
    pub samples_per_sec: f64,
    /// Calibration bins, one per equal-width confidence interval (10 bins; bin `b`
    /// covers `[b/10, (b+1)/10)`): `(mean_confidence, mean_accuracy, count)`.
    /// Empty bins are `(0.0, 0.0, 0)`.
    pub calibration_bins: Vec<(f64, f64, usize)>,
    /// Expected Calibration Error: count-weighted mean of `|confidence - accuracy|`
    /// over the calibration bins (lower is better, 0 = perfectly calibrated)
    pub ece: f64,
    /// Brier score: mean over samples of `sum_k (p_k - y_k)^2` against the one-hot
    /// target (lower is better, 0 = perfect)
    pub brier_score: f64,
    /// Log loss (mean negative log-likelihood of the true class, lower is better)
    pub log_loss: f64,
    /// Bootstrap 95% confidence interval `(lower, upper)` for `accuracy`
    pub ci_accuracy: (f64, f64),
    /// Bootstrap 95% confidence interval `(lower, upper)` for macro-averaged F1
    pub ci_macro_f1: (f64, f64),
    /// Bootstrap 95% confidence interval `(lower, upper)` for `mcc`
    pub ci_mcc: (f64, f64),
    /// Accuracy of uniform random guessing: `1 / num_classes` (0.0 with no classes)
    pub baseline_random: f64,
    /// Accuracy of always predicting the most frequent true class:
    /// `max(per_class_support) / total_samples` (0.0 with no samples)
    pub baseline_majority: f64,
    /// Most confused class pairs `(true_class, pred_class, count)`: off-diagonal
    /// confusion-matrix cells sorted by descending count (at most 5 when built by
    /// `from_predictions_with_probs`)
    pub top_confusions: Vec<(usize, usize, usize)>,
}
70
71impl ClassifyEvalReport {
72    /// Build a report from raw predictions with full probability distributions.
73    ///
74    /// Computes all metrics including Cohen's kappa, MCC, calibration ECE,
75    /// top-2 accuracy, and confidence analysis.
76    pub(crate) fn from_predictions_with_probs(
77        y_pred: &[usize],
78        y_true: &[usize],
79        all_probs: &[Vec<f32>],
80        total_loss: f32,
81        num_classes: usize,
82        label_names: &[String],
83        eval_time_ms: u64,
84    ) -> Self {
85        let total_samples = y_pred.len();
86        let avg_loss = if total_samples > 0 { total_loss / total_samples as f32 } else { 0.0 };
87
88        let cm = ConfusionMatrix::from_predictions_with_min_classes(y_pred, y_true, num_classes);
89        let metrics = MultiClassMetrics::from_confusion_matrix(&cm);
90        let accuracy = cm.accuracy();
91
92        let cohens_kappa = Self::compute_cohens_kappa(&cm, total_samples);
93        let mcc = Self::compute_mcc(&cm, cm.n_classes(), total_samples);
94        let top2_accuracy = Self::compute_top2_accuracy(all_probs, y_true, total_samples);
95
96        let confidences: Vec<f64> =
97            all_probs.iter().map(|p| f64::from(p.iter().copied().fold(0.0f32, f32::max))).collect();
98
99        let (mean_confidence, mean_confidence_correct, mean_confidence_wrong) =
100            Self::compute_confidence_stats(&confidences, y_pred, y_true);
101
102        let (calibration_bins, ece) =
103            Self::compute_calibration(&confidences, y_pred, y_true, total_samples);
104
105        let samples_per_sec = if eval_time_ms > 0 {
106            total_samples as f64 / (eval_time_ms as f64 / 1000.0)
107        } else {
108            0.0
109        };
110
111        let brier_score = Self::compute_brier_score(all_probs, y_true, num_classes);
112        let log_loss = Self::compute_log_loss(all_probs, y_true);
113
114        let (ci_accuracy, ci_macro_f1, ci_mcc) =
115            Self::compute_bootstrap_cis(y_pred, y_true, num_classes, 1000);
116
117        let (baseline_random, baseline_majority) =
118            Self::compute_baselines(&metrics.support, total_samples, num_classes);
119
120        let top_confusions = Self::compute_top_confusions(cm.matrix(), 5);
121
122        Self {
123            accuracy,
124            avg_loss,
125            per_class_precision: metrics.precision,
126            per_class_recall: metrics.recall,
127            per_class_f1: metrics.f1,
128            per_class_support: metrics.support,
129            confusion_matrix: cm.matrix().clone(),
130            num_classes,
131            total_samples,
132            eval_time_ms,
133            label_names: label_names.to_vec(),
134            cohens_kappa,
135            mcc,
136            top2_accuracy,
137            mean_confidence,
138            mean_confidence_correct,
139            mean_confidence_wrong,
140            samples_per_sec,
141            calibration_bins,
142            ece,
143            brier_score,
144            log_loss,
145            ci_accuracy,
146            ci_macro_f1,
147            ci_mcc,
148            baseline_random,
149            baseline_majority,
150            top_confusions,
151        }
152    }
153
154    /// Compute top-2 accuracy: fraction of samples where the true label is in the top 2 predictions.
155    pub(crate) fn compute_top2_accuracy(
156        all_probs: &[Vec<f32>],
157        y_true: &[usize],
158        total: usize,
159    ) -> f64 {
160        if total == 0 {
161            return 0.0;
162        }
163        let correct = all_probs
164            .iter()
165            .zip(y_true.iter())
166            .filter(|(probs, &true_label)| {
167                let mut indexed: Vec<(usize, f32)> = probs.iter().copied().enumerate().collect();
168                indexed.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
169                indexed.len() >= 2 && (indexed[0].0 == true_label || indexed[1].0 == true_label)
170            })
171            .count();
172        correct as f64 / total as f64
173    }
174
175    /// Compute mean confidence overall, for correct predictions, and for wrong predictions.
176    pub(crate) fn compute_confidence_stats(
177        confidences: &[f64],
178        y_pred: &[usize],
179        y_true: &[usize],
180    ) -> (f64, f64, f64) {
181        let mean = if confidences.is_empty() {
182            0.0
183        } else {
184            confidences.iter().sum::<f64>() / confidences.len() as f64
185        };
186
187        let (mut sum_correct, mut n_correct) = (0.0f64, 0usize);
188        let (mut sum_wrong, mut n_wrong) = (0.0f64, 0usize);
189        for (i, &conf) in confidences.iter().enumerate() {
190            if y_pred[i] == y_true[i] {
191                sum_correct += conf;
192                n_correct += 1;
193            } else {
194                sum_wrong += conf;
195                n_wrong += 1;
196            }
197        }
198
199        let mean_correct = if n_correct > 0 { sum_correct / n_correct as f64 } else { 0.0 };
200        let mean_wrong = if n_wrong > 0 { sum_wrong / n_wrong as f64 } else { 0.0 };
201
202        (mean, mean_correct, mean_wrong)
203    }
204
205    /// Compute calibration bins (10 equal-width bins) and Expected Calibration Error.
206    pub(crate) fn compute_calibration(
207        confidences: &[f64],
208        y_pred: &[usize],
209        y_true: &[usize],
210        total_samples: usize,
211    ) -> (Vec<(f64, f64, usize)>, f64) {
212        let num_bins = 10;
213        let mut bin_sum_conf = vec![0.0f64; num_bins];
214        let mut bin_sum_acc = vec![0.0f64; num_bins];
215        let mut bin_count = vec![0usize; num_bins];
216
217        for (i, &conf) in confidences.iter().enumerate() {
218            let bin = ((conf * num_bins as f64) as usize).min(num_bins - 1);
219            bin_sum_conf[bin] += conf;
220            bin_sum_acc[bin] += if y_pred[i] == y_true[i] { 1.0 } else { 0.0 };
221            bin_count[bin] += 1;
222        }
223
224        let bins: Vec<(f64, f64, usize)> = (0..num_bins)
225            .map(|b| {
226                if bin_count[b] > 0 {
227                    (
228                        bin_sum_conf[b] / bin_count[b] as f64,
229                        bin_sum_acc[b] / bin_count[b] as f64,
230                        bin_count[b],
231                    )
232                } else {
233                    (0.0, 0.0, 0)
234                }
235            })
236            .collect();
237
238        let ece: f64 = bins
239            .iter()
240            .map(|&(conf, acc, count)| {
241                if count > 0 {
242                    (conf - acc).abs() * count as f64 / total_samples as f64
243                } else {
244                    0.0
245                }
246            })
247            .sum();
248
249        (bins, ece)
250    }
251
252    /// Cohen's kappa: chance-corrected agreement.
253    ///
254    /// kappa = (p_o - p_e) / (1 - p_e)
255    /// where p_o = observed agreement (accuracy), p_e = expected agreement by chance
256    pub(crate) fn compute_cohens_kappa(cm: &ConfusionMatrix, total: usize) -> f64 {
257        if total == 0 {
258            return 0.0;
259        }
260        let mat = cm.matrix();
261        let n = total as f64;
262        let p_o = cm.accuracy();
263
264        // p_e = sum_k (row_k_total * col_k_total) / n^2
265        let k = mat.len();
266        let mut p_e = 0.0f64;
267        for class in 0..k {
268            let row_sum: f64 = mat[class].iter().sum::<usize>() as f64;
269            let col_sum: f64 = mat.iter().map(|row| row[class]).sum::<usize>() as f64;
270            p_e += (row_sum * col_sum) / (n * n);
271        }
272
273        if (1.0 - p_e).abs() < 1e-10 {
274            return if (p_o - 1.0).abs() < 1e-10 { 1.0 } else { 0.0 };
275        }
276
277        (p_o - p_e) / (1.0 - p_e)
278    }
279
280    /// Matthews Correlation Coefficient for multiclass.
281    ///
282    /// MCC = (c * s - sum_k(p_k * t_k)) / sqrt((s^2 - sum_k(p_k^2)) * (s^2 - sum_k(t_k^2)))
283    /// where c = correct predictions, s = total samples, p_k = predicted counts, t_k = true counts
284    pub(crate) fn compute_mcc(cm: &ConfusionMatrix, num_classes: usize, total: usize) -> f64 {
285        if total == 0 {
286            return 0.0;
287        }
288        let mat = cm.matrix();
289        let s = total as f64;
290
291        // c = sum of diagonal (correct predictions)
292        let c: f64 = (0..num_classes).map(|k| mat[k][k] as f64).sum();
293
294        // p_k = column sums (predicted counts per class)
295        let p: Vec<f64> =
296            (0..num_classes).map(|k| mat.iter().map(|row| row[k] as f64).sum()).collect();
297
298        // t_k = row sums (true counts per class)
299        let t: Vec<f64> = (0..num_classes).map(|k| mat[k].iter().sum::<usize>() as f64).collect();
300
301        let sum_pk_tk: f64 = p.iter().zip(t.iter()).map(|(pk, tk)| pk * tk).sum();
302        let sum_pk_sq: f64 = p.iter().map(|pk| pk * pk).sum();
303        let sum_tk_sq: f64 = t.iter().map(|tk| tk * tk).sum();
304
305        let numer = c * s - sum_pk_tk;
306        let denom_sq = (s * s - sum_pk_sq) * (s * s - sum_tk_sq);
307
308        if denom_sq <= 0.0 {
309            return 0.0;
310        }
311
312        numer / denom_sq.sqrt()
313    }
314
315    /// Multi-class Brier score: mean of sum_k (p_k - y_k)^2 across samples.
316    pub(crate) fn compute_brier_score(
317        all_probs: &[Vec<f32>],
318        y_true: &[usize],
319        num_classes: usize,
320    ) -> f64 {
321        if all_probs.is_empty() {
322            return 0.0;
323        }
324        let total: f64 = all_probs
325            .iter()
326            .zip(y_true.iter())
327            .map(|(probs, &true_label)| {
328                (0..num_classes)
329                    .map(|k| {
330                        let p = f64::from(*probs.get(k).unwrap_or(&0.0));
331                        let y = if k == true_label { 1.0 } else { 0.0 };
332                        (p - y) * (p - y)
333                    })
334                    .sum::<f64>()
335            })
336            .sum();
337        total / all_probs.len() as f64
338    }
339
340    /// Log loss: -mean(log(p_true_class)).
341    pub(crate) fn compute_log_loss(all_probs: &[Vec<f32>], y_true: &[usize]) -> f64 {
342        if all_probs.is_empty() {
343            return 0.0;
344        }
345        let eps = 1e-15_f64;
346        let total: f64 = all_probs
347            .iter()
348            .zip(y_true.iter())
349            .map(|(probs, &true_label)| {
350                let p = f64::from(probs.get(true_label).copied().unwrap_or(0.0));
351                -p.clamp(eps, 1.0 - eps).ln()
352            })
353            .sum();
354        total / all_probs.len() as f64
355    }
356
357    /// Bootstrap 95% confidence intervals for accuracy, macro F1, and MCC.
358    ///
359    /// Uses percentile method with `n_boot` resamples. Deterministic seed
360    /// for reproducibility.
361    pub(crate) fn compute_bootstrap_cis(
362        y_pred: &[usize],
363        y_true: &[usize],
364        num_classes: usize,
365        n_boot: usize,
366    ) -> ((f64, f64), (f64, f64), (f64, f64)) {
367        let n = y_pred.len();
368        if n == 0 {
369            return ((0.0, 0.0), (0.0, 0.0), (0.0, 0.0));
370        }
371
372        let mut accs = Vec::with_capacity(n_boot);
373        let mut f1s = Vec::with_capacity(n_boot);
374        let mut mccs = Vec::with_capacity(n_boot);
375
376        // Simple LCG PRNG (deterministic, no dependency needed)
377        let mut rng_state: u64 = 42;
378        let lcg_next = |state: &mut u64| -> usize {
379            *state = state.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
380            ((*state >> 33) as usize) % n
381        };
382
383        for _ in 0..n_boot {
384            // Resample with replacement
385            let mut boot_pred = Vec::with_capacity(n);
386            let mut boot_true = Vec::with_capacity(n);
387            for _ in 0..n {
388                let idx = lcg_next(&mut rng_state);
389                boot_pred.push(y_pred[idx]);
390                boot_true.push(y_true[idx]);
391            }
392
393            let cm = ConfusionMatrix::from_predictions_with_min_classes(
394                &boot_pred,
395                &boot_true,
396                num_classes,
397            );
398            let metrics = MultiClassMetrics::from_confusion_matrix(&cm);
399
400            accs.push(cm.accuracy());
401
402            // Macro F1
403            let valid_f1: Vec<f64> = metrics.f1.iter().copied().filter(|v| !v.is_nan()).collect();
404            let macro_f1 = if valid_f1.is_empty() {
405                0.0
406            } else {
407                valid_f1.iter().sum::<f64>() / valid_f1.len() as f64
408            };
409            f1s.push(macro_f1);
410
411            mccs.push(Self::compute_mcc(&cm, cm.n_classes(), n));
412        }
413
414        accs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
415        f1s.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
416        mccs.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
417
418        let lo = (0.025 * n_boot as f64) as usize;
419        let hi = (0.975 * n_boot as f64).ceil() as usize;
420        let hi = hi.min(n_boot - 1);
421
422        ((accs[lo], accs[hi]), (f1s[lo], f1s[hi]), (mccs[lo], mccs[hi]))
423    }
424
425    /// Compute baseline accuracies: random and majority-class.
426    pub(crate) fn compute_baselines(
427        support: &[usize],
428        total: usize,
429        num_classes: usize,
430    ) -> (f64, f64) {
431        let random = if num_classes > 0 { 1.0 / num_classes as f64 } else { 0.0 };
432        let majority = if total > 0 {
433            support.iter().copied().max().unwrap_or(0) as f64 / total as f64
434        } else {
435            0.0
436        };
437        (random, majority)
438    }
439
440    /// Extract top-N most confused class pairs from the confusion matrix (off-diagonal).
441    pub(crate) fn compute_top_confusions(
442        matrix: &[Vec<usize>],
443        top_n: usize,
444    ) -> Vec<(usize, usize, usize)> {
445        let mut pairs: Vec<(usize, usize, usize)> = Vec::new();
446        for (i, row) in matrix.iter().enumerate() {
447            for (j, &count) in row.iter().enumerate() {
448                if i != j && count > 0 {
449                    pairs.push((i, j, count));
450                }
451            }
452        }
453        pairs.sort_by(|a, b| b.2.cmp(&a.2));
454        pairs.truncate(top_n);
455        pairs
456    }
457
458    /// Format as a human-readable sklearn-style classification report.
459    #[must_use]
460    pub fn to_report(&self) -> String {
461        use crate::eval::classification::Average;
462
463        let mut out = String::new();
464
465        // Header
466        out.push_str(&format!(
467            "{:>18} {:>10} {:>10} {:>10} {:>10}\n",
468            "", "precision", "recall", "f1-score", "support"
469        ));
470        out.push_str(&format!("{}\n", "-".repeat(62)));
471
472        // Per-class rows
473        for i in 0..self.num_classes {
474            let name = self
475                .label_names
476                .get(i)
477                .map_or_else(|| format!("Class {i}"), std::clone::Clone::clone);
478            out.push_str(&format!(
479                "{:>18} {:>10.4} {:>10.4} {:>10.4} {:>10}\n",
480                name,
481                self.per_class_precision[i],
482                self.per_class_recall[i],
483                self.per_class_f1[i],
484                self.per_class_support[i],
485            ));
486        }
487
488        out.push_str(&format!("{}\n", "-".repeat(62)));
489
490        let total_support: usize = self.per_class_support.iter().sum();
491
492        // Macro average
493        let macro_p = self.avg_metric(&self.per_class_precision, Average::Macro);
494        let macro_r = self.avg_metric(&self.per_class_recall, Average::Macro);
495        let macro_f1 = self.avg_metric(&self.per_class_f1, Average::Macro);
496        out.push_str(&format!(
497            "{:>18} {:>10.4} {:>10.4} {:>10.4} {:>10}\n",
498            "macro avg", macro_p, macro_r, macro_f1, total_support,
499        ));
500
501        // Weighted average
502        let weighted_p = self.avg_metric(&self.per_class_precision, Average::Weighted);
503        let weighted_r = self.avg_metric(&self.per_class_recall, Average::Weighted);
504        let weighted_f1 = self.avg_metric(&self.per_class_f1, Average::Weighted);
505        out.push_str(&format!(
506            "{:>18} {:>10.4} {:>10.4} {:>10.4} {:>10}\n",
507            "weighted avg", weighted_p, weighted_r, weighted_f1, total_support,
508        ));
509
510        // ── Summary metrics ───────────────────────────────────────
511        self.report_summary(&mut out);
512        self.report_confidence(&mut out);
513        self.report_scoring_rules(&mut out);
514        self.report_calibration(&mut out);
515        self.report_baselines(&mut out);
516        self.report_top_confusions(&mut out);
517        self.report_throughput(&mut out);
518        out
519    }
520
521    pub(crate) fn report_summary(&self, out: &mut String) {
522        out.push_str(&format!(
523            "\nAccuracy:       {:.4}  ({:.1}%)  95% CI [{:.4}, {:.4}]\n",
524            self.accuracy,
525            self.accuracy * 100.0,
526            self.ci_accuracy.0,
527            self.ci_accuracy.1
528        ));
529        out.push_str(&format!(
530            "Top-2 accuracy: {:.4}  ({:.1}%)\n",
531            self.top2_accuracy,
532            self.top2_accuracy * 100.0
533        ));
534        out.push_str(&format!(
535            "Cohen's kappa:  {:.4}  ({})\n",
536            self.cohens_kappa,
537            Self::kappa_interpretation(self.cohens_kappa)
538        ));
539        out.push_str(&format!(
540            "MCC:            {:.4}  95% CI [{:.4}, {:.4}]\n",
541            self.mcc, self.ci_mcc.0, self.ci_mcc.1
542        ));
543        out.push_str(&format!(
544            "Macro F1:       {:.4}  95% CI [{:.4}, {:.4}]\n",
545            self.avg_metric(&self.per_class_f1, crate::eval::classification::Average::Macro),
546            self.ci_macro_f1.0,
547            self.ci_macro_f1.1
548        ));
549        out.push_str(&format!("Avg loss:       {:.4}\n", self.avg_loss));
550    }
551
552    pub(crate) fn report_confidence(&self, out: &mut String) {
553        out.push_str(&format!("\nConfidence (mean): {:.4}\n", self.mean_confidence));
554        out.push_str(&format!("  correct preds:   {:.4}\n", self.mean_confidence_correct));
555        out.push_str(&format!("  wrong preds:     {:.4}\n", self.mean_confidence_wrong));
556        let gap = self.mean_confidence_correct - self.mean_confidence_wrong;
557        out.push_str(&format!("  gap (higher=better): {gap:.4}\n"));
558    }
559
560    pub(crate) fn report_scoring_rules(&self, out: &mut String) {
561        out.push_str(&format!(
562            "\nBrier score:    {:.4}  (perfect=0, random={:.4})\n",
563            self.brier_score,
564            1.0 - 1.0 / self.num_classes as f64
565        ));
566        out.push_str(&format!(
567            "Log loss:       {:.4}  (random={:.4})\n",
568            self.log_loss,
569            (self.num_classes as f64).ln()
570        ));
571    }
572
573    pub(crate) fn report_calibration(&self, out: &mut String) {
574        out.push_str(&format!("\nECE (Expected Calibration Error): {:.4}\n", self.ece));
575        out.push_str("Calibration:\n");
576        out.push_str("  Bin       Confidence  Accuracy    Count\n");
577        for (i, &(conf, acc, count)) in self.calibration_bins.iter().enumerate() {
578            if count > 0 {
579                let lo = i as f64 * 0.1;
580                let hi = lo + 0.1;
581                let overconf = if conf > acc { "+" } else { "" };
582                out.push_str(&format!(
583                    "  [{:.1}-{:.1})  {:.4}      {:.4}      {:>5}  {overconf}{:.3}\n",
584                    lo,
585                    hi,
586                    conf,
587                    acc,
588                    count,
589                    conf - acc,
590                ));
591            }
592        }
593    }
594
595    pub(crate) fn report_baselines(&self, out: &mut String) {
596        out.push_str(&format!(
597            "\nBaselines:  random={:.1}%  majority={:.1}%  model={:.1}%  lift={:.1}x\n",
598            self.baseline_random * 100.0,
599            self.baseline_majority * 100.0,
600            self.accuracy * 100.0,
601            if self.baseline_majority > 0.0 { self.accuracy / self.baseline_majority } else { 0.0 },
602        ));
603    }
604
605    pub(crate) fn report_top_confusions(&self, out: &mut String) {
606        if self.top_confusions.is_empty() {
607            return;
608        }
609        out.push_str("\nTop confusions (true → predicted, count):\n");
610        for &(true_c, pred_c, count) in &self.top_confusions {
611            let true_name = self.label_names.get(true_c).map_or("?", |n| n.as_str());
612            let pred_name = self.label_names.get(pred_c).map_or("?", |n| n.as_str());
613            out.push_str(&format!("  {true_name} → {pred_name}: {count}\n"));
614        }
615    }
616
617    pub(crate) fn report_throughput(&self, out: &mut String) {
618        out.push_str(&format!("\nSamples:   {}\n", self.total_samples));
619        out.push_str(&format!(
620            "Time:      {}ms ({:.1} samples/sec)\n",
621            self.eval_time_ms, self.samples_per_sec
622        ));
623    }
624
625    /// Interpret Cohen's kappa value.
626    pub(crate) fn kappa_interpretation(kappa: f64) -> &'static str {
627        if kappa < 0.0 {
628            "worse than chance"
629        } else if kappa < 0.20 {
630            "slight"
631        } else if kappa < 0.40 {
632            "fair"
633        } else if kappa < 0.60 {
634            "moderate"
635        } else if kappa < 0.80 {
636            "substantial"
637        } else {
638            "almost perfect"
639        }
640    }
641
642    /// Format as JSON string.
643    ///
644    /// Uses `serde_json::json!` internally — infallible.
645    #[must_use]
646    #[allow(clippy::disallowed_methods)]
647    pub fn to_json(&self) -> String {
648        let per_class: Vec<serde_json::Value> = (0..self.num_classes)
649            .map(|i| {
650                let name = self
651                    .label_names
652                    .get(i)
653                    .map_or_else(|| format!("class_{i}"), std::clone::Clone::clone);
654                serde_json::json!({
655                    "label": name,
656                    "precision": self.per_class_precision[i],
657                    "recall": self.per_class_recall[i],
658                    "f1": self.per_class_f1[i],
659                    "support": self.per_class_support[i],
660                })
661            })
662            .collect();
663
664        let calibration: Vec<serde_json::Value> = self
665            .calibration_bins
666            .iter()
667            .enumerate()
668            .filter(|(_, &(_, _, count))| count > 0)
669            .map(|(i, &(conf, acc, count))| {
670                serde_json::json!({
671                    "bin": format!("[{:.1}-{:.1})", i as f64 * 0.1, (i + 1) as f64 * 0.1),
672                    "mean_confidence": conf,
673                    "mean_accuracy": acc,
674                    "count": count,
675                })
676            })
677            .collect();
678
679        let confusions: Vec<serde_json::Value> = self.top_confusions.iter().map(|&(t, p, c)| {
680            serde_json::json!({
681                "true_class": self.label_names.get(t).cloned().unwrap_or_else(|| format!("class_{t}")),
682                "pred_class": self.label_names.get(p).cloned().unwrap_or_else(|| format!("class_{p}")),
683                "count": c,
684            })
685        }).collect();
686
687        let json = serde_json::json!({
688            "accuracy": self.accuracy,
689            "top2_accuracy": self.top2_accuracy,
690            "cohens_kappa": self.cohens_kappa,
691            "mcc": self.mcc,
692            "avg_loss": self.avg_loss,
693            "brier_score": self.brier_score,
694            "log_loss": self.log_loss,
695            "total_samples": self.total_samples,
696            "num_classes": self.num_classes,
697            "eval_time_ms": self.eval_time_ms,
698            "samples_per_sec": self.samples_per_sec,
699            "confidence_intervals_95": {
700                "accuracy": [self.ci_accuracy.0, self.ci_accuracy.1],
701                "macro_f1": [self.ci_macro_f1.0, self.ci_macro_f1.1],
702                "mcc": [self.ci_mcc.0, self.ci_mcc.1],
703            },
704            "baselines": {
705                "random": self.baseline_random,
706                "majority_class": self.baseline_majority,
707                "lift_over_majority": if self.baseline_majority > 0.0 { self.accuracy / self.baseline_majority } else { 0.0 },
708            },
709            "per_class": per_class,
710            "confusion_matrix": self.confusion_matrix,
711            "top_confusions": confusions,
712            "confidence": {
713                "mean": self.mean_confidence,
714                "mean_correct": self.mean_confidence_correct,
715                "mean_wrong": self.mean_confidence_wrong,
716                "gap": self.mean_confidence_correct - self.mean_confidence_wrong,
717            },
718            "calibration": {
719                "ece": self.ece,
720                "brier_score": self.brier_score,
721                "log_loss": self.log_loss,
722                "bins": calibration,
723            },
724        });
725
726        serde_json::to_string_pretty(&json).unwrap_or_default()
727    }
728
729    /// Generate a HuggingFace-compatible model card (README.md) from evaluation results.
730    ///
731    /// Produces a publication-quality model card with YAML front matter, summary metrics,
732    /// per-class breakdown, confusion matrix (raw + normalized), calibration analysis,
733    /// intended use, limitations, and ethical considerations.
734    #[must_use]
735    pub fn to_model_card(&self, model_name: &str, base_model: Option<&str>) -> String {
736        use crate::eval::classification::Average;
737
738        let macro_f1 = self.avg_metric(&self.per_class_f1, Average::Macro);
739        let weighted_f1 = self.avg_metric(&self.per_class_f1, Average::Weighted);
740
741        let mut out = String::new();
742        self.card_yaml_front_matter(&mut out, model_name, base_model, macro_f1, weighted_f1);
743        self.card_title(&mut out, model_name, base_model);
744        self.card_summary(&mut out, macro_f1, weighted_f1);
745        self.card_labels(&mut out);
746        self.card_per_class_metrics(&mut out);
747        self.card_confusion_matrix(&mut out);
748        self.card_error_analysis(&mut out);
749        self.card_calibration(&mut out);
750        Self::card_intended_use(&mut out);
751        self.card_limitations(&mut out);
752        Self::card_ethical_considerations(&mut out);
753        self.card_training(&mut out, base_model);
754        out.push_str("---\n*Generated by [entrenar](https://github.com/paiml/entrenar)*\n");
755        out
756    }
757
758    pub(crate) fn card_yaml_front_matter(
759        &self,
760        out: &mut String,
761        model_name: &str,
762        base_model: Option<&str>,
763        macro_f1: f64,
764        weighted_f1: f64,
765    ) {
766        out.push_str("---\n");
767        out.push_str("license: apache-2.0\n");
768        out.push_str("language:\n- en\n");
769        out.push_str(
770            "tags:\n- shell-safety\n- code-classification\n- lora\n- entrenar\n- security\n",
771        );
772        if let Some(base) = base_model {
773            out.push_str(&format!("base_model: {base}\n"));
774        }
775        out.push_str("pipeline_tag: text-classification\n");
776        out.push_str("model-index:\n");
777        out.push_str(&format!("- name: {model_name}\n"));
778        out.push_str("  results:\n");
779        out.push_str("  - task:\n");
780        out.push_str("      type: text-classification\n");
781        out.push_str("      name: Shell Safety Classification\n");
782        out.push_str("    metrics:\n");
783        out.push_str(&format!("    - type: accuracy\n      value: {:.4}\n", self.accuracy));
784        out.push_str(&format!(
785            "    - type: f1\n      value: {macro_f1:.4}\n      name: Macro F1\n"
786        ));
787        out.push_str(&format!(
788            "    - type: f1\n      value: {weighted_f1:.4}\n      name: Weighted F1\n"
789        ));
790        out.push_str(&format!("    - type: mcc\n      value: {:.4}\n", self.mcc));
791        out.push_str(&format!("    - type: cohens_kappa\n      value: {:.4}\n", self.cohens_kappa));
792        out.push_str("---\n\n");
793    }
794
795    pub(crate) fn card_title(&self, out: &mut String, model_name: &str, base_model: Option<&str>) {
796        out.push_str(&format!("# {model_name}\n\n"));
797        out.push_str("A shell command safety classifier that categorizes shell commands into safety classes, ");
798        out.push_str("enabling automated triage of commands before execution.\n\n");
799        out.push_str(
800            "Trained with [entrenar](https://github.com/paiml/entrenar) using LoRA fine-tuning",
801        );
802        if let Some(base) = base_model {
803            out.push_str(&format!(" on [`{base}`](https://huggingface.co/{base})"));
804        }
805        out.push_str(".\n\n");
806    }
807
808    pub(crate) fn card_summary(&self, out: &mut String, macro_f1: f64, weighted_f1: f64) {
809        out.push_str("## Summary\n\n");
810        out.push_str("| Metric | Value | 95% CI |\n");
811        out.push_str("|--------|-------|--------|\n");
812        out.push_str(&format!(
813            "| Accuracy | {:.2}% | [{:.2}%, {:.2}%] |\n",
814            self.accuracy * 100.0,
815            self.ci_accuracy.0 * 100.0,
816            self.ci_accuracy.1 * 100.0
817        ));
818        out.push_str(&format!("| Top-2 Accuracy | {:.2}% | — |\n", self.top2_accuracy * 100.0));
819        out.push_str(&format!(
820            "| Macro F1 | {macro_f1:.4} | [{:.4}, {:.4}] |\n",
821            self.ci_macro_f1.0, self.ci_macro_f1.1
822        ));
823        out.push_str(&format!("| Weighted F1 | {weighted_f1:.4} | — |\n"));
824        out.push_str(&format!(
825            "| Cohen's Kappa | {:.4} ({}) | — |\n",
826            self.cohens_kappa,
827            Self::kappa_interpretation(self.cohens_kappa)
828        ));
829        out.push_str(&format!(
830            "| MCC | {:.4} | [{:.4}, {:.4}] |\n",
831            self.mcc, self.ci_mcc.0, self.ci_mcc.1
832        ));
833        out.push_str(&format!("| Brier Score | {:.4} | — |\n", self.brier_score));
834        out.push_str(&format!("| Log Loss | {:.4} | — |\n", self.log_loss));
835        out.push_str(&format!("| ECE | {:.4} | — |\n", self.ece));
836        out.push_str(&format!("| Avg Loss | {:.4} | — |\n", self.avg_loss));
837        out.push_str(&format!("| Eval Samples | {} | — |\n", self.total_samples));
838        out.push_str(&format!("| Throughput | {:.1} samples/sec | — |\n\n", self.samples_per_sec));
839
840        let lift =
841            if self.baseline_majority > 0.0 { self.accuracy / self.baseline_majority } else { 0.0 };
842        out.push_str("**Baselines**: ");
843        out.push_str(&format!(
844            "random={:.1}%, majority={:.1}%, model={:.1}% ({:.1}x lift over majority)\n\n",
845            self.baseline_random * 100.0,
846            self.baseline_majority * 100.0,
847            self.accuracy * 100.0,
848            lift
849        ));
850    }
851
852    pub(crate) fn card_labels(&self, out: &mut String) {
853        out.push_str("## Labels\n\n");
854        out.push_str("| ID | Label | Description |\n");
855        out.push_str("|----|-------|-------------|\n");
856        let descriptions = [
857            "Command is safe to execute as-is",
858            "Command contains unquoted variable expansions (word splitting/globbing risk)",
859            "Command uses non-deterministic sources ($RANDOM, $$, date, etc.)",
860            "Command is not idempotent (unsafe to re-run: mkdir without -p, etc.)",
861            "Command is destructive or has injection risk (rm -rf, eval, etc.)",
862        ];
863        for (i, name) in self.label_names.iter().enumerate() {
864            let desc = descriptions.get(i).unwrap_or(&"");
865            out.push_str(&format!("| {i} | {name} | {desc} |\n"));
866        }
867        out.push('\n');
868    }
869
870    pub(crate) fn card_per_class_metrics(&self, out: &mut String) {
871        out.push_str("## Per-Class Metrics\n\n");
872        out.push_str("| Label | Precision | Recall | F1 | Support |\n");
873        out.push_str("|-------|-----------|--------|----|---------|\n");
874        for i in 0..self.num_classes {
875            let name = self
876                .label_names
877                .get(i)
878                .map_or_else(|| format!("class_{i}"), std::clone::Clone::clone);
879            out.push_str(&format!(
880                "| {} | {:.4} | {:.4} | {:.4} | {} |\n",
881                name,
882                self.per_class_precision[i],
883                self.per_class_recall[i],
884                self.per_class_f1[i],
885                self.per_class_support[i],
886            ));
887        }
888        out.push('\n');
889    }
890
891    fn card_confusion_matrix(&self, out: &mut String) {
892        out.push_str("## Confusion Matrix\n\n");
893        self.card_confusion_raw(out);
894        self.card_confusion_normalized(out);
895    }
896
897    pub(crate) fn card_confusion_raw(&self, out: &mut String) {
898        out.push_str("### Raw Counts\n\n```\n");
899        self.card_confusion_header(out);
900        for (i, row) in self.confusion_matrix.iter().enumerate() {
901            self.card_confusion_row_label(out, i);
902            for val in row {
903                out.push_str(&format!(" {val:>8}"));
904            }
905            out.push('\n');
906        }
907        out.push_str("```\n\n");
908    }
909
910    pub(crate) fn card_confusion_normalized(&self, out: &mut String) {
911        out.push_str("### Normalized (row %)\n\n```\n");
912        self.card_confusion_header(out);
913        for (i, row) in self.confusion_matrix.iter().enumerate() {
914            self.card_confusion_row_label(out, i);
915            let row_sum: usize = row.iter().sum();
916            for val in row {
917                if row_sum > 0 {
918                    out.push_str(&format!(" {:>7.1}%", *val as f64 / row_sum as f64 * 100.0));
919                } else {
920                    out.push_str("     0.0%");
921                }
922            }
923            out.push('\n');
924        }
925        out.push_str("```\n\n");
926    }
927
928    pub(crate) fn card_confusion_header(&self, out: &mut String) {
929        out.push_str(&format!("{:>18}", "Predicted →"));
930        for name in &self.label_names {
931            let short = if name.len() > 8 { &name[..8] } else { name.as_str() };
932            out.push_str(&format!(" {short:>8}"));
933        }
934        out.push('\n');
935    }
936
937    pub(crate) fn card_confusion_row_label(&self, out: &mut String, i: usize) {
938        let name =
939            self.label_names.get(i).map_or_else(|| format!("class_{i}"), std::clone::Clone::clone);
940        let short = if name.len() > 18 { &name[..18] } else { name.as_str() };
941        out.push_str(&format!("{short:>18}"));
942    }
943
944    pub(crate) fn card_calibration(&self, out: &mut String) {
945        out.push_str("## Confidence & Calibration\n\n");
946        out.push_str("| Metric | Value |\n");
947        out.push_str("|--------|-------|\n");
948        out.push_str(&format!("| Mean confidence | {:.4} |\n", self.mean_confidence));
949        out.push_str(&format!("| Confidence (correct) | {:.4} |\n", self.mean_confidence_correct));
950        out.push_str(&format!("| Confidence (wrong) | {:.4} |\n", self.mean_confidence_wrong));
951        let gap = self.mean_confidence_correct - self.mean_confidence_wrong;
952        out.push_str(&format!("| Confidence gap | {gap:.4} |\n"));
953        out.push_str(&format!("| ECE | {:.4} |\n\n", self.ece));
954
955        out.push_str("**Calibration curve** (reliability diagram):\n\n");
956        out.push_str("```\n");
957        out.push_str("Bin         Conf    Acc     Count\n");
958        for (i, &(conf, acc, count)) in self.calibration_bins.iter().enumerate() {
959            if count > 0 {
960                let lo = i as f64 * 0.1;
961                let hi = lo + 0.1;
962                out.push_str(&format!("[{lo:.1}-{hi:.1})   {conf:.3}   {acc:.3}   {count:>5}\n",));
963            }
964        }
965        out.push_str("```\n\n");
966    }
967
968    pub(crate) fn card_error_analysis(&self, out: &mut String) {
969        if self.top_confusions.is_empty() {
970            return;
971        }
972        out.push_str("## Error Analysis\n\n");
973        out.push_str("Most frequent misclassifications:\n\n");
974        out.push_str("| True Class | Predicted As | Count |\n");
975        out.push_str("|------------|-------------|-------|\n");
976        for &(true_c, pred_c, count) in &self.top_confusions {
977            let true_name = self.label_names.get(true_c).map_or("?", |n| n.as_str());
978            let pred_name = self.label_names.get(pred_c).map_or("?", |n| n.as_str());
979            out.push_str(&format!("| {true_name} | {pred_name} | {count} |\n"));
980        }
981        out.push('\n');
982    }
983
984    pub(crate) fn card_intended_use(out: &mut String) {
985        out.push_str("## Intended Use\n\n");
986        out.push_str(
987            "This model is designed for **automated shell command safety triage** in:\n\n",
988        );
989        out.push_str(
990            "- **CI/CD pipelines**: Pre-flight safety check before executing generated scripts\n",
991        );
992        out.push_str("- **Shell purification tools**: Classify commands to determine transformation strategy\n");
993        out.push_str(
994            "- **Code review**: Flag potentially unsafe shell commands in pull requests\n",
995        );
996        out.push_str("- **Interactive shells**: Warn users before executing risky commands\n\n");
997    }
998
999    fn card_limitations(&self, out: &mut String) {
1000        out.push_str("## Limitations\n\n");
1001        out.push_str("- **Not a security oracle**: This model provides *classification hints*, not security guarantees\n");
1002        out.push_str("- **Context-blind**: Cannot assess safety based on the execution environment or user permissions\n");
1003        out.push_str("- **Training distribution**: Trained on synthetic shell scripts; may underperform on novel patterns\n");
1004        out.push_str(
1005            "- **English only**: Command names and variable patterns are English-centric\n",
1006        );
1007        self.card_weak_classes(out);
1008        out.push('\n');
1009    }
1010
1011    pub(crate) fn card_weak_classes(&self, out: &mut String) {
1012        let min_f1_idx = self
1013            .per_class_f1
1014            .iter()
1015            .enumerate()
1016            .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
1017            .map(|(i, _)| i);
1018        if let Some(idx) = min_f1_idx {
1019            if self.per_class_f1[idx] < 0.5 {
1020                let name = self.label_names.get(idx).map_or("unknown", |n| n.as_str());
1021                out.push_str(&format!(
1022                    "- **Weak class**: `{name}` (F1={:.2}) — consider additional training data\n",
1023                    self.per_class_f1[idx]
1024                ));
1025            }
1026        }
1027    }
1028
1029    pub(crate) fn card_ethical_considerations(out: &mut String) {
1030        out.push_str("## Ethical Considerations\n\n");
1031        out.push_str("- **False negatives are dangerous**: An `unsafe` command classified as `safe` could lead to data loss\n");
1032        out.push_str("- **Defense in depth**: Always combine this classifier with other safety mechanisms (sandboxing, dry-run)\n");
1033        out.push_str("- **Not adversarial-robust**: Determined attackers can craft commands to evade classification\n\n");
1034    }
1035
1036    pub(crate) fn card_training(&self, out: &mut String, base_model: Option<&str>) {
1037        out.push_str("## Training\n\n");
1038        out.push_str("| Parameter | Value |\n");
1039        out.push_str("|-----------|-------|\n");
1040        out.push_str("| Framework | [entrenar](https://github.com/paiml/entrenar) (Rust) |\n");
1041        out.push_str("| Method | LoRA (Low-Rank Adaptation) |\n");
1042        if let Some(base) = base_model {
1043            out.push_str(&format!("| Base model | `{base}` |\n"));
1044        }
1045        out.push_str(&format!("| Num classes | {} |\n\n", self.num_classes));
1046    }
1047
1048    /// Average a metric vector using the given strategy.
1049    pub(crate) fn avg_metric(
1050        &self,
1051        values: &[f64],
1052        average: crate::eval::classification::Average,
1053    ) -> f64 {
1054        match average {
1055            crate::eval::classification::Average::Macro => {
1056                if values.is_empty() {
1057                    0.0
1058                } else {
1059                    values.iter().sum::<f64>() / values.len() as f64
1060                }
1061            }
1062            crate::eval::classification::Average::Weighted => {
1063                let total: usize = self.per_class_support.iter().sum();
1064                if total == 0 {
1065                    return 0.0;
1066                }
1067                values
1068                    .iter()
1069                    .zip(self.per_class_support.iter())
1070                    .map(|(&v, &s)| v * s as f64)
1071                    .sum::<f64>()
1072                    / total as f64
1073            }
1074            _ => {
1075                // Fallback to macro
1076                if values.is_empty() {
1077                    0.0
1078                } else {
1079                    values.iter().sum::<f64>() / values.len() as f64
1080                }
1081            }
1082        }
1083    }
1084}
1085
/// SSC label names used across the shell safety classifier.
///
/// Five entries, listed in this fixed order.
pub const SSC_LABELS: [&str; 5] = [
    "safe",
    "needs-quoting",
    "non-deterministic",
    "non-idempotent",
    "unsafe",
];
1089
1090/// Evaluate a saved checkpoint against a test JSONL dataset.
1091///
1092/// Standalone function that loads a checkpoint, builds a pipeline, and runs
1093/// evaluation without needing a full `ClassifyTrainer` setup.
1094///
1095/// Handles LoRA adapter checkpoints: reads `adapter_config.json` to find the
1096/// base model path, loads the full transformer from that path, then restores
1097/// trained LoRA + classifier head weights from the checkpoint's `model.safetensors`.
1098/// Restore class_weights from checkpoint metadata.json if present.
1099/// Training runs that use `auto_balance_classes()` save weights to metadata;
1100/// without this, evaluation would use uniform weights while training used weighted loss.
1101pub(crate) fn restore_class_weights_from_metadata(
1102    checkpoint_dir: &std::path::Path,
1103    num_classes: usize,
1104) -> Option<Vec<f32>> {
1105    let meta_str = std::fs::read_to_string(checkpoint_dir.join("metadata.json")).ok()?;
1106    let meta: serde_json::Value = serde_json::from_str(&meta_str).ok()?;
1107    let arr = meta.get("class_weights")?.as_array()?;
1108    let weights: Vec<f32> = arr.iter().filter_map(|v| v.as_f64().map(|f| f as f32)).collect();
1109    (weights.len() == num_classes).then_some(weights)
1110}
1111
1112///
1113/// # Arguments
1114/// * `checkpoint_dir` - Directory containing `model.safetensors` + `adapter_config.json`
1115/// * `test_data` - JSONL file with `{"input": "...", "label": N}` entries
1116/// * `model_config` - Transformer architecture config (must match checkpoint)
1117/// * `classify_config` - Classification config (num_classes, etc.)
1118/// * `label_names` - Human-readable class names
1119///
1120/// # Errors
1121/// Returns error if checkpoint or test data cannot be loaded.
1122pub fn evaluate_checkpoint(
1123    checkpoint_dir: &Path,
1124    test_data: &Path,
1125    model_config: &crate::transformer::TransformerConfig,
1126    classify_config: super::classify_pipeline::ClassifyConfig,
1127    label_names: &[String],
1128) -> crate::Result<ClassifyEvalReport> {
1129    use super::classification::load_safety_corpus;
1130
1131    let start = std::time::Instant::now();
1132    let num_classes = classify_config.num_classes;
1133
1134    // Restore class_weights from checkpoint metadata if not provided by caller.
1135    let mut classify_config = classify_config;
1136    if classify_config.class_weights.is_none() {
1137        if let Some(weights) = restore_class_weights_from_metadata(checkpoint_dir, num_classes) {
1138            println!("Restored class_weights from checkpoint: {weights:?}");
1139            classify_config.class_weights = Some(weights);
1140        }
1141    }
1142
1143    // Resolve the base model directory from adapter_config.json (LoRA checkpoint)
1144    // or fall back to loading directly from checkpoint_dir (full model checkpoint)
1145    let adapter_config_path = checkpoint_dir.join("adapter_config.json");
1146    let mut pipeline = if adapter_config_path.exists() {
1147        // LoRA adapter checkpoint: load base model, then restore adapter weights
1148        let adapter_json = std::fs::read_to_string(&adapter_config_path)
1149            .map_err(|e| crate::Error::Io(format!("Failed to read adapter_config.json: {e}")))?;
1150        let peft_config: crate::lora::PeftAdapterConfig = serde_json::from_str(&adapter_json)
1151            .map_err(|e| {
1152                crate::Error::Serialization(format!("Invalid adapter_config.json: {e}"))
1153            })?;
1154
1155        // Update classify_config with LoRA rank/alpha from the checkpoint's
1156        // adapter_config.json so LoRA layers are created with matching dimensions.
1157        if peft_config.r > 0 {
1158            classify_config.lora_rank = peft_config.r;
1159        }
1160        if peft_config.lora_alpha > 0.0 {
1161            classify_config.lora_alpha = peft_config.lora_alpha;
1162        }
1163
1164        // Try to load base model from pretrained weights.  Fall back to random
1165        // init when the path is missing, points to an .apr file (not a
1166        // SafeTensors directory), or otherwise fails to load.
1167        let mut pipe = match peft_config.base_model_name_or_path.as_deref() {
1168            Some(base_model_path)
1169                if std::path::Path::new(base_model_path).is_dir()
1170                    || std::path::Path::new(base_model_path)
1171                        .extension()
1172                        .is_some_and(|e| e == "safetensors") =>
1173            {
1174                println!("Loading base model from: {base_model_path}");
1175                ClassifyPipeline::from_pretrained(
1176                    base_model_path,
1177                    model_config,
1178                    classify_config.clone(),
1179                )?
1180            }
1181            Some(base_model_path) => {
1182                println!("Base model path is not a SafeTensors directory: {base_model_path}");
1183                println!("Using random-init base model (adapter weights will be restored from checkpoint)");
1184                ClassifyPipeline::new(model_config, classify_config.clone())
1185            }
1186            None => {
1187                println!("No base_model_name_or_path in adapter_config.json");
1188                println!("Using random-init base model (adapter weights will be restored from checkpoint)");
1189                ClassifyPipeline::new(model_config, classify_config.clone())
1190            }
1191        };
1192
1193        // Load trained LoRA + classifier weights from checkpoint
1194        let st_path = checkpoint_dir.join("model.safetensors");
1195        let st_data = std::fs::read(&st_path).map_err(|e| {
1196            crate::Error::Io(format!("Failed to read checkpoint model.safetensors: {e}"))
1197        })?;
1198        let tensors = safetensors::SafeTensors::deserialize(&st_data).map_err(|e| {
1199            crate::Error::Serialization(format!("Failed to deserialize checkpoint: {e}"))
1200        })?;
1201
1202        // Restore classifier head weights
1203        if let Ok(w) = tensors.tensor("classifier.weight") {
1204            let w_data: &[f32] = bytemuck::cast_slice(w.data());
1205            pipe.classifier
1206                .weight
1207                .data_mut()
1208                .as_slice_mut()
1209                .expect("contiguous classifier.weight")
1210                .copy_from_slice(w_data);
1211        }
1212        if let Ok(b) = tensors.tensor("classifier.bias") {
1213            let b_data: &[f32] = bytemuck::cast_slice(b.data());
1214            pipe.classifier
1215                .bias
1216                .data_mut()
1217                .as_slice_mut()
1218                .expect("contiguous classifier.bias")
1219                .copy_from_slice(b_data);
1220        }
1221
1222        // Restore LoRA adapter weights (convention: 2 per layer, Q=even V=odd)
1223        for (idx, lora) in pipe.lora_layers.iter_mut().enumerate() {
1224            let layer = idx / 2;
1225            let proj = if idx % 2 == 0 { "q" } else { "v" };
1226
1227            if let Ok(a) = tensors.tensor(&format!("lora.{layer}.{proj}_proj.lora_a")) {
1228                let a_data: &[f32] = bytemuck::cast_slice(a.data());
1229                lora.lora_a_mut()
1230                    .data_mut()
1231                    .as_slice_mut()
1232                    .expect("contiguous lora_a")
1233                    .copy_from_slice(a_data);
1234            }
1235            if let Ok(b) = tensors.tensor(&format!("lora.{layer}.{proj}_proj.lora_b")) {
1236                let b_data: &[f32] = bytemuck::cast_slice(b.data());
1237                lora.lora_b_mut()
1238                    .data_mut()
1239                    .as_slice_mut()
1240                    .expect("contiguous lora_b")
1241                    .copy_from_slice(b_data);
1242            }
1243        }
1244
1245        let loaded_count = tensors.names().len();
1246        println!("Restored {loaded_count} tensors from checkpoint");
1247        pipe
1248    } else {
1249        // Full model checkpoint: load directly
1250        ClassifyPipeline::from_pretrained(checkpoint_dir, model_config, classify_config)?
1251    };
1252
1253    // Load test corpus
1254    let samples = load_safety_corpus(test_data, num_classes)?;
1255
1256    // Run forward-only on all samples, collecting full probability distributions
1257    let mut y_true: Vec<usize> = Vec::with_capacity(samples.len());
1258    let mut y_pred: Vec<usize> = Vec::with_capacity(samples.len());
1259    let mut all_probs: Vec<Vec<f32>> = Vec::with_capacity(samples.len());
1260    let mut total_loss = 0.0f32;
1261
1262    for (i, sample) in samples.iter().enumerate() {
1263        let ids = pipeline.tokenize(&sample.input);
1264        let (loss, predicted, probs) = pipeline.forward_only_with_probs(&ids, sample.label);
1265        total_loss += loss;
1266        y_true.push(sample.label);
1267        y_pred.push(predicted);
1268        all_probs.push(probs);
1269
1270        // Progress indicator every 100 samples
1271        if (i + 1) % 100 == 0 {
1272            println!("  Evaluated {}/{} samples...", i + 1, samples.len());
1273        }
1274    }
1275    println!("  Evaluated {}/{} samples (done)", samples.len(), samples.len());
1276
1277    Ok(ClassifyEvalReport::from_predictions_with_probs(
1278        &y_pred,
1279        &y_true,
1280        &all_probs,
1281        total_loss,
1282        num_classes,
1283        label_names,
1284        start.elapsed().as_millis() as u64,
1285    ))
1286}