Skip to main content

entrenar/finetune/
linear_probe.rs

1//! Linear probe classifier pipeline (SSC v11 Section 5)
2//!
3//! Implements the CodeBERT linear probe:
4//! 1. Extract [CLS] embeddings from frozen encoder (CLF-001)
5//! 2. Train Linear(hidden_size, num_classes) on cached embeddings (CLF-002)
6//! 3. Evaluate with MCC, accuracy, recall, bootstrap CI (CLF-003)
7//! 4. Cache confidence scores for conversation generation (CLF-007)
8//!
9//! # Architecture
10//!
11//! ```text
12//! token_ids → EncoderModel.cls_embedding() → [hidden_size]
13//!           → Linear(hidden_size, 2) → softmax → [p_safe, p_unsafe]
14//! ```
15//!
16//! # Contract (linear-probe-classifier-v1.yaml)
17//! - Frozen encoder: weights unchanged after training
18//! - Probability simplex: softmax sums to 1.0
19//! - Embedding determinism: bit-identical on repeated calls
20
21use crate::autograd::{matmul, Tensor};
22
/// Aggregate evaluation metrics for a binary or multi-class classifier (CLF-003).
#[derive(Debug, Clone)]
pub struct ClassificationMetrics {
    /// Matthews Correlation Coefficient, ranging from -1 to 1.
    pub mcc: f32,
    /// Fraction of evaluated samples that were classified correctly (0 to 1).
    pub accuracy: f32,
    /// Recall (sensitivity), one entry per class.
    pub recall: Vec<f32>,
    /// Precision, one entry per class.
    pub precision: Vec<f32>,
    /// How many samples were evaluated.
    pub num_samples: usize,
    /// Confusion matrix indexed as `[predicted][actual]`: each row is a
    /// predicted class, each column an actual class.
    pub confusion_matrix: Vec<Vec<usize>>,
}
39
/// A bootstrap confidence interval around a point estimate.
#[derive(Debug, Clone, Copy)]
pub struct BootstrapCI {
    /// The point estimate the interval is built around.
    pub estimate: f32,
    /// Lower bound of the interval (2.5th percentile).
    pub lower: f32,
    /// Upper bound of the interval (97.5th percentile).
    pub upper: f32,
    /// How many bootstrap iterations were requested.
    pub n_bootstrap: usize,
}
52
53/// Linear probe: frozen embeddings + trainable linear head (CLF-002).
54///
55/// Trains on pre-extracted embeddings (not raw token IDs), making training
56/// complete in seconds rather than minutes.
57pub struct LinearProbe {
58    /// Linear weight [hidden_size, num_classes] flattened row-major
59    pub weight: Tensor,
60    /// Bias [num_classes]
61    pub bias: Tensor,
62    /// Input dimension
63    hidden_size: usize,
64    /// Number of output classes
65    num_classes: usize,
66}
67
68impl LinearProbe {
69    /// Create with Xavier initialization.
70    pub fn new(hidden_size: usize, num_classes: usize) -> Self {
71        assert!(hidden_size > 0, "hidden_size must be > 0");
72        assert!(num_classes >= 2, "num_classes must be >= 2");
73
74        let scale = (6.0 / (hidden_size + num_classes) as f32).sqrt();
75        let mut rng: u64 = 42;
76        let weight_data: Vec<f32> = (0..hidden_size * num_classes)
77            .map(|_| {
78                rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
79                let u = (rng >> 33) as f32 / (1u64 << 31) as f32;
80                (2.0 * u - 1.0) * scale
81            })
82            .collect();
83
84        Self {
85            weight: Tensor::from_vec(weight_data, true),
86            bias: Tensor::zeros(num_classes, true),
87            hidden_size,
88            num_classes,
89        }
90    }
91
92    /// Forward pass: embedding → logits.
93    ///
94    /// # Arguments
95    /// * `embedding` - Pre-extracted [CLS] embedding [hidden_size]
96    ///
97    /// # Returns
98    /// Logits tensor [num_classes]
99    pub fn forward(&self, embedding: &Tensor) -> Tensor {
100        let logits = matmul(embedding, &self.weight, 1, self.hidden_size, self.num_classes);
101        let logits_data = logits.data();
102        let logits_slice = logits_data.as_slice().expect("contiguous logits");
103        let bias_data = self.bias.data();
104        let bias_slice = bias_data.as_slice().expect("contiguous bias");
105
106        let output: Vec<f32> =
107            logits_slice.iter().zip(bias_slice.iter()).map(|(&l, &b)| l + b).collect();
108        Tensor::from_vec(output, logits.requires_grad())
109    }
110
111    /// Predict class probabilities via softmax.
112    pub fn predict_probs(&self, embedding: &Tensor) -> Vec<f32> {
113        let logits = self.forward(embedding);
114        softmax_vec(&logits)
115    }
116
117    /// Predict class index (argmax of logits).
118    pub fn predict(&self, embedding: &Tensor) -> usize {
119        contract_pre_predict!();
120        let logits = self.forward(embedding);
121        let data = logits.data();
122        let slice = data.as_slice().expect("contiguous");
123        slice
124            .iter()
125            .enumerate()
126            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
127            .map_or(0, |(i, _)| i)
128    }
129
130    /// Train on pre-extracted embeddings using SGD.
131    ///
132    /// # Arguments
133    /// * `embeddings` - List of pre-extracted [CLS] embeddings (each len=hidden_size)
134    /// * `labels` - Corresponding class labels
135    /// * `epochs` - Number of training epochs
136    /// * `learning_rate` - SGD learning rate
137    /// * `class_weights` - Optional per-class loss weights for imbalance
138    ///
139    /// # Returns
140    /// Final training loss
141    pub fn train(
142        &mut self,
143        embeddings: &[Vec<f32>],
144        labels: &[usize],
145        epochs: usize,
146        learning_rate: f32,
147        class_weights: Option<&[f32]>,
148    ) -> f32 {
149        assert_eq!(embeddings.len(), labels.len());
150        let n = embeddings.len();
151        let mut final_loss = 0.0;
152
153        for epoch in 0..epochs {
154            let mut epoch_loss = 0.0;
155
156            for (emb, &label) in embeddings.iter().zip(labels.iter()) {
157                assert_eq!(emb.len(), self.hidden_size);
158                assert!(label < self.num_classes);
159
160                // Forward
161                let emb_tensor = Tensor::from_vec(emb.clone(), false);
162                let logits = self.forward(&emb_tensor);
163
164                // Cross-entropy loss with optional class weights
165                let probs = softmax_vec(&logits);
166                let loss_weight = class_weights.map_or(1.0, |w| w[label]);
167                let loss = -probs[label].max(1e-10).ln() * loss_weight;
168                epoch_loss += loss;
169
170                // Gradient of cross-entropy w.r.t. logits: probs - one_hot(label)
171                let mut grad_logits = probs;
172                grad_logits[label] -= 1.0;
173                if let Some(w) = class_weights {
174                    for (i, g) in grad_logits.iter_mut().enumerate() {
175                        *g *= w[i];
176                    }
177                }
178
179                // Update weight: grad_W = emb^T @ grad_logits
180                let w_data = self.weight.data();
181                let mut w_slice = w_data.as_slice().expect("contiguous").to_vec();
182                for i in 0..self.hidden_size {
183                    for j in 0..self.num_classes {
184                        w_slice[i * self.num_classes + j] -=
185                            learning_rate * emb[i] * grad_logits[j];
186                    }
187                }
188                self.weight = Tensor::from_vec(w_slice, true);
189
190                // Update bias: grad_b = grad_logits
191                let b_data = self.bias.data();
192                let mut b_slice = b_data.as_slice().expect("contiguous").to_vec();
193                for j in 0..self.num_classes {
194                    b_slice[j] -= learning_rate * grad_logits[j];
195                }
196                self.bias = Tensor::from_vec(b_slice, true);
197            }
198
199            final_loss = epoch_loss / n as f32;
200            if epoch == 0 || (epoch + 1) % 5 == 0 || epoch == epochs - 1 {
201                eprintln!("  Epoch {}/{epochs}: loss={final_loss:.4}", epoch + 1);
202            }
203        }
204
205        final_loss
206    }
207
208    /// Get total trainable parameter count (CLF-002: 1,538 for binary CodeBERT).
209    pub fn num_parameters(&self) -> usize {
210        self.hidden_size * self.num_classes + self.num_classes
211    }
212
213    /// Get number of classes.
214    pub fn num_classes(&self) -> usize {
215        self.num_classes
216    }
217}
218
219/// MLP probe: frozen embeddings + trainable 2-layer MLP head (Level 0.5).
220///
221/// Adds a hidden layer with ReLU between embeddings and classification head.
222/// This allows non-linear decision boundaries, which can capture patterns
223/// that a linear probe cannot (e.g., shell safety from CodeBERT embeddings).
224///
225/// Architecture: embedding → Linear(hidden_size, mlp_hidden) → ReLU → Linear(mlp_hidden, num_classes)
226pub struct MlpProbe {
227    /// First layer weights [hidden_size × mlp_hidden] flattened row-major
228    pub w1: Vec<f32>,
229    /// First layer bias [mlp_hidden]
230    pub b1: Vec<f32>,
231    /// Second layer weights [mlp_hidden × num_classes] flattened row-major
232    pub w2: Vec<f32>,
233    /// Second layer bias [num_classes]
234    pub b2: Vec<f32>,
235    /// Input dimension
236    pub hidden_size: usize,
237    /// Hidden layer dimension
238    pub mlp_hidden: usize,
239    /// Number of output classes
240    pub num_classes: usize,
241}
242
243impl MlpProbe {
244    /// Create with Xavier initialization.
245    pub fn new(hidden_size: usize, mlp_hidden: usize, num_classes: usize) -> Self {
246        assert!(hidden_size > 0 && mlp_hidden > 0 && num_classes >= 2);
247
248        let mut rng: u64 = 42;
249        let mut xavier = |fan_in: usize, fan_out: usize, n: usize| -> Vec<f32> {
250            let scale = (6.0 / (fan_in + fan_out) as f32).sqrt();
251            (0..n)
252                .map(|_| {
253                    rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
254                    let u = (rng >> 33) as f32 / (1u64 << 31) as f32;
255                    (2.0 * u - 1.0) * scale
256                })
257                .collect()
258        };
259
260        Self {
261            w1: xavier(hidden_size, mlp_hidden, hidden_size * mlp_hidden),
262            b1: vec![0.0; mlp_hidden],
263            w2: xavier(mlp_hidden, num_classes, mlp_hidden * num_classes),
264            b2: vec![0.0; num_classes],
265            hidden_size,
266            mlp_hidden,
267            num_classes,
268        }
269    }
270
271    /// Forward pass: embedding → hidden (ReLU) → logits.
272    pub fn forward(&self, emb: &[f32]) -> (Vec<f32>, Vec<f32>) {
273        // Layer 1: h = ReLU(W1 @ emb + b1)
274        let mut h = vec![0.0_f32; self.mlp_hidden];
275        for j in 0..self.mlp_hidden {
276            let mut sum = self.b1[j];
277            for i in 0..self.hidden_size {
278                sum += self.w1[i * self.mlp_hidden + j] * emb[i];
279            }
280            h[j] = sum.max(0.0); // ReLU
281        }
282
283        // Layer 2: logits = W2 @ h + b2
284        let mut logits = vec![0.0_f32; self.num_classes];
285        for j in 0..self.num_classes {
286            let mut sum = self.b2[j];
287            for i in 0..self.mlp_hidden {
288                sum += self.w2[i * self.num_classes + j] * h[i];
289            }
290            logits[j] = sum;
291        }
292
293        (h, logits)
294    }
295
296    /// Predict class index.
297    pub fn predict(&self, emb: &[f32]) -> usize {
298        contract_pre_predict!();
299        let (_, logits) = self.forward(emb);
300        logits
301            .iter()
302            .enumerate()
303            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
304            .map_or(0, |(i, _)| i)
305    }
306
307    /// Predict class probabilities via softmax.
308    pub fn predict_probs(&self, emb: &[f32]) -> Vec<f32> {
309        let (_, logits) = self.forward(emb);
310        softmax_slice(&logits)
311    }
312
313    /// Forward pass returning pre-ReLU activations, post-ReLU hidden, and logits.
314    fn forward_train(&self, emb: &[f32]) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
315        let mut h_pre = vec![0.0_f32; self.mlp_hidden];
316        let mut h = vec![0.0_f32; self.mlp_hidden];
317        for j in 0..self.mlp_hidden {
318            let mut sum = self.b1[j];
319            for i in 0..self.hidden_size {
320                sum += self.w1[i * self.mlp_hidden + j] * emb[i];
321            }
322            h_pre[j] = sum;
323            h[j] = sum.max(0.0);
324        }
325
326        let mut logits = vec![0.0_f32; self.num_classes];
327        for j in 0..self.num_classes {
328            let mut sum = self.b2[j];
329            for i in 0..self.mlp_hidden {
330                sum += self.w2[i * self.num_classes + j] * h[i];
331            }
332            logits[j] = sum;
333        }
334        (h_pre, h, logits)
335    }
336
337    /// Backward pass: update W1, b1, W2, b2 given gradients.
338    fn backward_step(
339        &mut self,
340        emb: &[f32],
341        h_pre: &[f32],
342        h: &[f32],
343        grad_logits: &[f32],
344        lr: f32,
345        wd: f32,
346    ) {
347        // Update W2 and b2
348        for i in 0..self.mlp_hidden {
349            for j in 0..self.num_classes {
350                let idx = i * self.num_classes + j;
351                self.w2[idx] -= lr * (h[i] * grad_logits[j] + wd * self.w2[idx]);
352            }
353        }
354        for j in 0..self.num_classes {
355            self.b2[j] -= lr * grad_logits[j];
356        }
357
358        // Compute grad_h (with ReLU mask)
359        let mut grad_h = vec![0.0_f32; self.mlp_hidden];
360        for i in 0..self.mlp_hidden {
361            if h_pre[i] > 0.0 {
362                for j in 0..self.num_classes {
363                    grad_h[i] += self.w2[i * self.num_classes + j] * grad_logits[j];
364                }
365            }
366        }
367
368        // Update W1 and b1
369        for i in 0..self.hidden_size {
370            for j in 0..self.mlp_hidden {
371                let idx = i * self.mlp_hidden + j;
372                self.w1[idx] -= lr * (emb[i] * grad_h[j] + wd * self.w1[idx]);
373            }
374        }
375        for j in 0..self.mlp_hidden {
376            self.b1[j] -= lr * grad_h[j];
377        }
378    }
379
380    /// Train with online SGD + class weights + L2 regularization.
381    #[allow(clippy::too_many_arguments)]
382    pub fn train(
383        &mut self,
384        embeddings: &[Vec<f32>],
385        labels: &[usize],
386        epochs: usize,
387        learning_rate: f32,
388        class_weights: Option<&[f32]>,
389        weight_decay: f32,
390    ) -> f32 {
391        assert_eq!(embeddings.len(), labels.len());
392        let n = embeddings.len();
393        let mut final_loss = 0.0;
394
395        for epoch in 0..epochs {
396            let mut epoch_loss = 0.0;
397
398            for (emb, &label) in embeddings.iter().zip(labels.iter()) {
399                let (h_pre, h, logits) = self.forward_train(emb);
400                let probs = softmax_slice(&logits);
401                let loss_weight = class_weights.map_or(1.0, |w| w[label]);
402                epoch_loss += -probs[label].max(1e-10).ln() * loss_weight;
403
404                let mut grad_logits = probs;
405                grad_logits[label] -= 1.0;
406                if let Some(w) = class_weights {
407                    for (i, g) in grad_logits.iter_mut().enumerate() {
408                        *g *= w[i];
409                    }
410                }
411
412                self.backward_step(emb, &h_pre, &h, &grad_logits, learning_rate, weight_decay);
413            }
414
415            final_loss = epoch_loss / n as f32;
416            if epoch == 0 || (epoch + 1) % 10 == 0 || epoch == epochs - 1 {
417                eprintln!("  Epoch {}/{epochs}: loss={final_loss:.4}", epoch + 1);
418            }
419        }
420
421        final_loss
422    }
423
424    /// Total trainable parameters.
425    pub fn num_parameters(&self) -> usize {
426        self.hidden_size * self.mlp_hidden + self.mlp_hidden  // W1 + b1
427        + self.mlp_hidden * self.num_classes + self.num_classes // W2 + b2
428    }
429}
430
/// Softmax over a slice of logits.
///
/// The largest logit is subtracted before exponentiating so the exponentials
/// cannot overflow for finite inputs.
fn softmax_slice(logits: &[f32]) -> Vec<f32> {
    let mut peak = f32::NEG_INFINITY;
    for &x in logits {
        peak = f32::max(peak, x);
    }

    let mut total = 0.0_f32;
    let mut out = Vec::with_capacity(logits.len());
    for &x in logits {
        let e = (x - peak).exp();
        total += e;
        out.push(e);
    }

    for v in &mut out {
        *v /= total;
    }
    out
}
438
439/// Compute softmax probabilities from a tensor.
440fn softmax_vec(logits: &Tensor) -> Vec<f32> {
441    let data = logits.data();
442    let slice = data.as_slice().expect("contiguous logits");
443    let max_val = slice.iter().copied().fold(f32::NEG_INFINITY, f32::max);
444    let exp_vals: Vec<f32> = slice.iter().map(|&x| (x - max_val).exp()).collect();
445    let sum: f32 = exp_vals.iter().sum();
446    exp_vals.iter().map(|&v| v / sum).collect()
447}
448
/// Compute binary MCC from confusion matrix values.
///
/// MCC = (TP*TN - FP*FN) / sqrt((TP+FP)(TP+FN)(TN+FP)(TN+FN))
///
/// All arithmetic is done in `f64`: the counts are converted before being
/// multiplied or added, so very large counts cannot overflow `usize`.
///
/// Returns `0.0` when the denominator is (numerically) zero, i.e. when any
/// row or column of the confusion matrix is empty.
pub fn binary_mcc(tp: usize, tn: usize, fp: usize, fn_count: usize) -> f32 {
    let (tp, tn, fp, fn_count) = (tp as f64, tn as f64, fp as f64, fn_count as f64);

    let numerator = tp * tn - fp * fn_count;
    let denom = ((tp + fp) * (tp + fn_count) * (tn + fp) * (tn + fn_count)).sqrt();
    if denom < 1e-10 {
        0.0
    } else {
        (numerator / denom) as f32
    }
}
463
464/// Evaluate predictions against ground truth (CLF-003).
465pub fn evaluate(
466    predictions: &[usize],
467    labels: &[usize],
468    num_classes: usize,
469) -> ClassificationMetrics {
470    assert_eq!(predictions.len(), labels.len());
471    let n = predictions.len();
472
473    // Build confusion matrix
474    let mut cm = vec![vec![0usize; num_classes]; num_classes];
475    for (&pred, &label) in predictions.iter().zip(labels.iter()) {
476        if pred < num_classes && label < num_classes {
477            cm[pred][label] += 1;
478        }
479    }
480
481    // Accuracy
482    let correct: usize = (0..num_classes).map(|c| cm[c][c]).sum();
483    let accuracy = correct as f32 / n.max(1) as f32;
484
485    // Per-class precision and recall
486    let mut precision = vec![0.0_f32; num_classes];
487    let mut recall = vec![0.0_f32; num_classes];
488    for c in 0..num_classes {
489        let pred_count: usize = cm[c].iter().sum();
490        let actual_count: usize = (0..num_classes).map(|p| cm[p][c]).sum();
491        precision[c] = if pred_count > 0 { cm[c][c] as f32 / pred_count as f32 } else { 0.0 };
492        recall[c] = if actual_count > 0 { cm[c][c] as f32 / actual_count as f32 } else { 0.0 };
493    }
494
495    // MCC (binary for 2-class, multiclass generalization otherwise)
496    let mcc = if num_classes == 2 {
497        let tp = cm[1][1];
498        let tn = cm[0][0];
499        let fp = cm[1][0];
500        let fn_count = cm[0][1];
501        binary_mcc(tp, tn, fp, fn_count)
502    } else {
503        multiclass_mcc(&cm, num_classes)
504    };
505
506    ClassificationMetrics { mcc, accuracy, recall, precision, num_samples: n, confusion_matrix: cm }
507}
508
/// Multiclass MCC using the general (K-class) formula.
///
/// With `n` the total count, `c` the trace of the matrix, `p_i` the row sums
/// and `t_i` the column sums of the first `k` classes:
///
/// MCC = (c*n - Σ p_i*t_i) / sqrt((n² - Σ p_i²)(n² - Σ t_i²))
///
/// Returns `0.0` when the denominator is (numerically) zero, e.g. when every
/// sample falls in a single row or a single column.
///
/// # Panics
/// Panics if `cm` has fewer than `k` rows or a row shorter than `k`.
fn multiclass_mcc(cm: &[Vec<usize>], k: usize) -> f32 {
    let n = cm.iter().flatten().sum::<usize>() as f64;
    let c: f64 = (0..k).map(|i| cm[i][i] as f64).sum();

    let mut s = 0.0_f64; // Σ row_sum_i * col_sum_i
    let mut p = 0.0_f64; // Σ row_sum_i²
    let mut t = 0.0_f64; // Σ col_sum_i²

    for i in 0..k {
        // Each sum is computed once per class instead of once per matrix cell.
        let row_sum = cm[i].iter().sum::<usize>() as f64;
        let col_sum: f64 = (0..k).map(|j| cm[j][i] as f64).sum();
        s += row_sum * col_sum;
        p += row_sum * row_sum;
        t += col_sum * col_sum;
    }

    let numerator = c * n - s;
    let denom = ((n * n - p) * (n * n - t)).sqrt();
    if denom < 1e-10 {
        0.0
    } else {
        (numerator / denom) as f32
    }
}
536
537/// Compute bootstrap confidence interval for MCC (CLF-003).
538///
539/// Resamples predictions/labels with replacement `n_bootstrap` times,
540/// computes MCC for each, and returns 2.5th/97.5th percentiles.
541pub fn bootstrap_mcc_ci(
542    predictions: &[usize],
543    labels: &[usize],
544    num_classes: usize,
545    n_bootstrap: usize,
546) -> BootstrapCI {
547    let n = predictions.len();
548    let point_estimate = evaluate(predictions, labels, num_classes).mcc;
549
550    let mut mcc_samples = Vec::with_capacity(n_bootstrap);
551    let mut rng: u64 = 12345;
552
553    for _ in 0..n_bootstrap {
554        let mut boot_preds = Vec::with_capacity(n);
555        let mut boot_labels = Vec::with_capacity(n);
556
557        for _ in 0..n {
558            rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1442695040888963407);
559            let idx = (rng >> 33) as usize % n;
560            boot_preds.push(predictions[idx]);
561            boot_labels.push(labels[idx]);
562        }
563
564        let metrics = evaluate(&boot_preds, &boot_labels, num_classes);
565        mcc_samples.push(metrics.mcc);
566    }
567
568    mcc_samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
569
570    let lower_idx = (n_bootstrap as f32 * 0.025) as usize;
571    let upper_idx = ((n_bootstrap as f32 * 0.975) as usize).min(n_bootstrap - 1);
572
573    BootstrapCI {
574        estimate: point_estimate,
575        lower: mcc_samples[lower_idx],
576        upper: mcc_samples[upper_idx],
577        n_bootstrap,
578    }
579}
580
/// Per-sample confidence record (CLF-007).
#[derive(Debug, Clone)]
pub struct ConfidenceScore {
    /// Index of the most probable class (argmax of `probabilities`).
    pub predicted_class: usize,
    /// Probability assigned to `predicted_class`.
    pub confidence: f32,
    /// The full probability distribution over classes.
    pub probabilities: Vec<f32>,
}
591
592/// Cache confidence scores for all samples (CLF-007).
593pub fn compute_confidence_scores(
594    probe: &LinearProbe,
595    embeddings: &[Vec<f32>],
596) -> Vec<ConfidenceScore> {
597    embeddings
598        .iter()
599        .map(|emb| {
600            let emb_tensor = Tensor::from_vec(emb.clone(), false);
601            let probs = probe.predict_probs(&emb_tensor);
602            let (predicted_class, &confidence) = probs
603                .iter()
604                .enumerate()
605                .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
606                .expect("non-empty probabilities");
607            ConfidenceScore { predicted_class, confidence, probabilities: probs }
608        })
609        .collect()
610}
611
612// =============================================================================
613// CLF-004: ESCALATION LADDER
614// =============================================================================
615
/// Rungs of the classifier-training escalation ladder (SSC v11 Section 5.4).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EscalationLevel {
    /// Level 0: Linear probe on frozen encoder (1,538 params)
    LinearProbe,
    /// Level 1: Fine-tune top-2 encoder layers + head (~15M params)
    TopLayers,
    /// Level 2: Full fine-tune all encoder layers (125M params)
    FullFinetune,
    /// Level 3: Continue-pretrain on shell + fine-tune (125M params)
    ContinuePretrain,
}

impl std::fmt::Display for EscalationLevel {
    /// Writes the human-readable "Level N: ..." label for this level.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::LinearProbe => "Level 0: Linear probe",
            Self::TopLayers => "Level 1: Top-2 layers + head",
            Self::FullFinetune => "Level 2: Full fine-tune",
            Self::ContinuePretrain => "Level 3: Continue-pretrain + fine-tune",
        };
        f.write_str(label)
    }
}
639
640/// Decide whether to escalate based on MCC CI (CLF-004).
641///
642/// Returns `Some(next_level)` if escalation needed, `None` if ship gate met.
643pub fn should_escalate(
644    current_level: EscalationLevel,
645    mcc_ci: &BootstrapCI,
646    accuracy: f32,
647) -> Option<EscalationLevel> {
648    match current_level {
649        EscalationLevel::LinearProbe => {
650            if mcc_ci.lower < 0.2 || accuracy <= 0.935 {
651                Some(EscalationLevel::TopLayers)
652            } else {
653                None // Ship gate C-CLF-001 met
654            }
655        }
656        EscalationLevel::TopLayers | EscalationLevel::FullFinetune => {
657            if mcc_ci.lower < 0.3 {
658                match current_level {
659                    EscalationLevel::TopLayers => Some(EscalationLevel::FullFinetune),
660                    _ => Some(EscalationLevel::ContinuePretrain),
661                }
662            } else {
663                None
664            }
665        }
666        EscalationLevel::ContinuePretrain => {
667            // Terminal level — if this fails, classifier adds no value
668            None
669        }
670    }
671}
672
673// =============================================================================
674// CLF-005: BASELINES COMPARISON
675// =============================================================================
676
/// Outcome of comparing the model against one baseline (CLF-005).
#[derive(Debug, Clone)]
pub struct BaselineComparison {
    /// Baseline's name.
    pub name: String,
    /// MCC achieved by the baseline.
    pub baseline_mcc: f32,
    /// MCC achieved by the model.
    pub model_mcc: f32,
    /// `true` when the model beats this baseline.
    pub beats_baseline: bool,
}
689
690/// Compare model against baselines (CLF-005).
691///
692/// Baselines from SSC v11 Section 5.5:
693/// - Majority class: MCC = 0.0
694/// - Keyword regex: MCC ~0.3-0.5
695/// - bashrs linter: MCC ~0.4-0.6
696pub fn compare_baselines(model_mcc: f32, baseline_mccs: &[(&str, f32)]) -> Vec<BaselineComparison> {
697    baseline_mccs
698        .iter()
699        .map(|&(name, baseline_mcc)| BaselineComparison {
700            name: name.to_string(),
701            baseline_mcc,
702            model_mcc,
703            beats_baseline: model_mcc > baseline_mcc,
704        })
705        .collect()
706}
707
708// =============================================================================
709// CLF-006: GENERALIZATION TEST
710// =============================================================================
711
/// Outcome of the generalization test (CLF-006).
#[derive(Debug, Clone)]
pub struct GeneralizationResult {
    /// How many novel unsafe scripts were tested.
    pub total: usize,
    /// How many of them were correctly classified as unsafe.
    pub detected: usize,
    /// `detected / total`.
    pub detection_rate: f32,
    /// Whether the detection rate meets the threshold (>= 50%).
    pub passes: bool,
}
724
725/// Run generalization test on novel unsafe scripts (CLF-006).
726///
727/// Tests the classifier on out-of-distribution scripts that have
728/// no lexical overlap with training data.
729pub fn generalization_test(
730    probe: &LinearProbe,
731    novel_embeddings: &[Vec<f32>],
732    unsafe_class: usize,
733) -> GeneralizationResult {
734    let total = novel_embeddings.len();
735    let detected = novel_embeddings
736        .iter()
737        .filter(|emb| {
738            let emb_tensor = Tensor::from_vec((*emb).clone(), false);
739            probe.predict(&emb_tensor) == unsafe_class
740        })
741        .count();
742
743    let detection_rate = if total > 0 { detected as f32 / total as f32 } else { 0.0 };
744
745    GeneralizationResult { total, detected, detection_rate, passes: detection_rate >= 0.5 }
746}
747
748// =============================================================================
749// SHIP GATE (C-CLF-001)
750// =============================================================================
751
752/// Ship gate check result (SSC v11 Section 5.7).
753#[allow(clippy::struct_excessive_bools)]
754#[derive(Debug, Clone)]
755pub struct ShipGateResult {
756    /// MCC CI lower bound > 0.2
757    pub mcc_passes: bool,
758    /// Accuracy > 0.935
759    pub accuracy_passes: bool,
760    /// Generalization >= 50%
761    pub generalization_passes: bool,
762    /// All criteria met
763    pub ship_ready: bool,
764    /// Escalation level that achieved these results
765    pub level: EscalationLevel,
766}
767
768/// Check ship gate C-CLF-001 (SSC v11 Section 5.7).
769pub fn check_ship_gate(
770    mcc_ci: &BootstrapCI,
771    accuracy: f32,
772    generalization: &GeneralizationResult,
773    level: EscalationLevel,
774) -> ShipGateResult {
775    let mcc_passes = mcc_ci.lower > 0.2;
776    let accuracy_passes = accuracy > 0.935;
777    let generalization_passes = generalization.passes;
778
779    ShipGateResult {
780        mcc_passes,
781        accuracy_passes,
782        generalization_passes,
783        ship_ready: mcc_passes && accuracy_passes && generalization_passes,
784        level,
785    }
786}
787
788#[cfg(test)]
789#[allow(clippy::unwrap_used)]
790mod tests {
791    use super::*;
792
793    #[test]
794    fn clf_002_linear_probe_forward_shape() {
795        let probe = LinearProbe::new(768, 2);
796        let emb = Tensor::from_vec(vec![0.1; 768], false);
797        let logits = probe.forward(&emb);
798        assert_eq!(logits.len(), 2);
799    }
800
801    #[test]
802    fn clf_002_linear_probe_predict_probs_sum_to_one() {
803        let probe = LinearProbe::new(64, 3);
804        let emb = Tensor::from_vec(vec![0.5; 64], false);
805        let probs = probe.predict_probs(&emb);
806        assert_eq!(probs.len(), 3);
807        let sum: f32 = probs.iter().sum();
808        assert!((sum - 1.0).abs() < 1e-5, "probabilities must sum to 1.0, got {sum}");
809        assert!(probs.iter().all(|&p| p > 0.0), "all probabilities must be positive");
810    }
811
812    #[test]
813    fn clf_002_linear_probe_num_parameters() {
814        let probe = LinearProbe::new(768, 2);
815        assert_eq!(probe.num_parameters(), 768 * 2 + 2); // 1538
816    }
817
818    #[test]
819    fn clf_002_linear_probe_train_reduces_loss() {
820        let mut probe = LinearProbe::new(8, 2);
821        // Simple linearly separable data
822        let embeddings: Vec<Vec<f32>> = (0..20)
823            .map(|i| {
824                if i < 10 {
825                    vec![1.0; 8] // class 0
826                } else {
827                    vec![-1.0; 8] // class 1
828                }
829            })
830            .collect();
831        let labels: Vec<usize> = (0..20).map(|i| usize::from(i >= 10)).collect();
832
833        let loss_before = {
834            let mut temp = LinearProbe::new(8, 2);
835            temp.train(&embeddings, &labels, 1, 0.01, None)
836        };
837        let loss_after = probe.train(&embeddings, &labels, 10, 0.01, None);
838
839        // After 10 epochs, loss should be lower
840        assert!(loss_after < loss_before + 0.5, "training should reduce loss");
841    }
842
843    #[test]
844    fn clf_003_binary_mcc_perfect() {
845        // Perfect predictions
846        assert!((binary_mcc(50, 50, 0, 0) - 1.0).abs() < 1e-5);
847    }
848
849    #[test]
850    fn clf_003_binary_mcc_random() {
851        // Random predictions: MCC ≈ 0
852        assert!(binary_mcc(25, 25, 25, 25).abs() < 1e-5);
853    }
854
855    #[test]
856    fn clf_003_evaluate_perfect() {
857        let preds = vec![0, 0, 1, 1, 1];
858        let labels = vec![0, 0, 1, 1, 1];
859        let metrics = evaluate(&preds, &labels, 2);
860        assert!((metrics.accuracy - 1.0).abs() < 1e-5);
861        assert!((metrics.mcc - 1.0).abs() < 1e-5);
862    }
863
864    #[test]
865    fn clf_003_evaluate_majority_baseline() {
866        // All predict class 0
867        let preds = vec![0; 100];
868        let labels: Vec<usize> = (0..100).map(|i| usize::from(i >= 93)).collect();
869        let metrics = evaluate(&preds, &labels, 2);
870        assert!((metrics.accuracy - 0.93).abs() < 0.01);
871        assert_eq!(metrics.recall[1], 0.0); // unsafe recall is 0
872    }
873
874    #[test]
875    fn clf_003_bootstrap_ci_contains_estimate() {
876        let preds = vec![0, 0, 1, 1, 0, 1, 0, 0, 1, 1];
877        let labels = vec![0, 0, 1, 1, 0, 0, 0, 1, 1, 1];
878        let ci = bootstrap_mcc_ci(&preds, &labels, 2, 100);
879        assert!(ci.lower <= ci.estimate, "CI lower must be <= estimate");
880        assert!(ci.upper >= ci.estimate, "CI upper must be >= estimate");
881    }
882
883    #[test]
884    fn clf_007_confidence_scores() {
885        let probe = LinearProbe::new(8, 2);
886        let embeddings = vec![vec![0.5; 8], vec![-0.5; 8]];
887        let scores = compute_confidence_scores(&probe, &embeddings);
888        assert_eq!(scores.len(), 2);
889        for score in &scores {
890            assert!(score.confidence > 0.0);
891            assert!(score.confidence <= 1.0);
892            assert_eq!(score.probabilities.len(), 2);
893            let sum: f32 = score.probabilities.iter().sum();
894            assert!((sum - 1.0).abs() < 1e-5);
895        }
896    }
897
898    // =========================================================================
899    // CLF-004: ESCALATION TESTS
900    // =========================================================================
901
902    #[test]
903    fn clf_004_escalate_from_linear_probe_low_mcc() {
904        let ci = BootstrapCI { estimate: 0.15, lower: 0.10, upper: 0.20, n_bootstrap: 100 };
905        let result = should_escalate(EscalationLevel::LinearProbe, &ci, 0.94);
906        assert_eq!(result, Some(EscalationLevel::TopLayers));
907    }
908
909    #[test]
910    fn clf_004_no_escalate_when_ship_gate_met() {
911        let ci = BootstrapCI { estimate: 0.45, lower: 0.30, upper: 0.60, n_bootstrap: 100 };
912        let result = should_escalate(EscalationLevel::LinearProbe, &ci, 0.96);
913        assert_eq!(result, None);
914    }
915
916    #[test]
917    fn clf_004_escalate_from_top_layers_to_full() {
918        let ci = BootstrapCI { estimate: 0.25, lower: 0.15, upper: 0.35, n_bootstrap: 100 };
919        let result = should_escalate(EscalationLevel::TopLayers, &ci, 0.95);
920        assert_eq!(result, Some(EscalationLevel::FullFinetune));
921    }
922
923    #[test]
924    fn clf_004_terminal_level_no_escalation() {
925        let ci = BootstrapCI { estimate: 0.1, lower: 0.05, upper: 0.15, n_bootstrap: 100 };
926        let result = should_escalate(EscalationLevel::ContinuePretrain, &ci, 0.90);
927        assert_eq!(result, None); // Terminal — can't escalate further
928    }
929
930    #[test]
931    fn clf_004_escalate_on_low_accuracy() {
932        // MCC CI OK but accuracy below threshold
933        let ci = BootstrapCI { estimate: 0.45, lower: 0.30, upper: 0.60, n_bootstrap: 100 };
934        let result = should_escalate(EscalationLevel::LinearProbe, &ci, 0.93);
935        assert_eq!(result, Some(EscalationLevel::TopLayers));
936    }
937
938    // =========================================================================
939    // CLF-005: BASELINES COMPARISON TESTS
940    // =========================================================================
941
942    #[test]
943    fn clf_005_compare_baselines_beats_majority() {
944        let baselines = vec![("majority", 0.0), ("keyword", 0.4), ("linter", 0.5)];
945        let comparisons = compare_baselines(0.35, &baselines);
946        assert!(comparisons[0].beats_baseline); // beats majority (0.35 > 0.0)
947        assert!(!comparisons[1].beats_baseline); // loses to keyword (0.35 < 0.4)
948        assert!(!comparisons[2].beats_baseline); // loses to linter (0.35 < 0.5)
949    }
950
951    #[test]
952    fn clf_005_compare_baselines_beats_all() {
953        let baselines = vec![("majority", 0.0), ("keyword", 0.4), ("linter", 0.5)];
954        let comparisons = compare_baselines(0.65, &baselines);
955        assert!(comparisons.iter().all(|c| c.beats_baseline));
956    }
957
958    // =========================================================================
959    // CLF-006: GENERALIZATION TEST
960    // =========================================================================
961
962    #[test]
963    fn clf_006_generalization_all_detected() {
964        let mut probe = LinearProbe::new(4, 2);
965        // Train probe to always predict unsafe (class 1) for negative embeddings
966        let embeddings: Vec<Vec<f32>> =
967            (0..20).map(|i| if i < 10 { vec![1.0; 4] } else { vec![-1.0; 4] }).collect();
968        let labels: Vec<usize> = (0..20).map(|i| usize::from(i >= 10)).collect();
969        probe.train(&embeddings, &labels, 30, 0.1, None);
970
971        let novel = vec![vec![-1.0; 4]; 10]; // all "unsafe" pattern
972        let result = generalization_test(&probe, &novel, 1);
973        assert_eq!(result.total, 10);
974        assert!(result.passes, "trained probe should detect unsafe-pattern embeddings");
975    }
976
977    #[test]
978    fn clf_006_generalization_empty() {
979        let probe = LinearProbe::new(4, 2);
980        let result = generalization_test(&probe, &[], 1);
981        assert_eq!(result.total, 0);
982        assert_eq!(result.detection_rate, 0.0);
983    }
984
985    // =========================================================================
986    // SHIP GATE TESTS
987    // =========================================================================
988
989    #[test]
990    fn clf_ship_gate_passes() {
991        let ci = BootstrapCI { estimate: 0.4, lower: 0.25, upper: 0.55, n_bootstrap: 100 };
992        let gen =
993            GeneralizationResult { total: 50, detected: 30, detection_rate: 0.6, passes: true };
994        let result = check_ship_gate(&ci, 0.96, &gen, EscalationLevel::LinearProbe);
995        assert!(result.ship_ready);
996        assert!(result.mcc_passes);
997        assert!(result.accuracy_passes);
998        assert!(result.generalization_passes);
999    }
1000
1001    #[test]
1002    fn clf_ship_gate_fails_mcc() {
1003        let ci = BootstrapCI { estimate: 0.15, lower: 0.10, upper: 0.20, n_bootstrap: 100 };
1004        let gen =
1005            GeneralizationResult { total: 50, detected: 30, detection_rate: 0.6, passes: true };
1006        let result = check_ship_gate(&ci, 0.96, &gen, EscalationLevel::LinearProbe);
1007        assert!(!result.ship_ready);
1008        assert!(!result.mcc_passes);
1009    }
1010
1011    #[test]
1012    fn clf_ship_gate_fails_generalization() {
1013        let ci = BootstrapCI { estimate: 0.4, lower: 0.25, upper: 0.55, n_bootstrap: 100 };
1014        let gen =
1015            GeneralizationResult { total: 50, detected: 20, detection_rate: 0.4, passes: false };
1016        let result = check_ship_gate(&ci, 0.96, &gen, EscalationLevel::LinearProbe);
1017        assert!(!result.ship_ready);
1018        assert!(!result.generalization_passes);
1019    }
1020
1021    // =========================================================================
1022    // MLP PROBE TESTS (Level 0.5)
1023    // =========================================================================
1024
1025    #[test]
1026    fn mlp_probe_forward_shape() {
1027        let probe = MlpProbe::new(768, 128, 2);
1028        let emb = vec![0.1; 768];
1029        let (h, logits) = probe.forward(&emb);
1030        assert_eq!(h.len(), 128);
1031        assert_eq!(logits.len(), 2);
1032    }
1033
1034    #[test]
1035    fn mlp_probe_predict_probs_sum_to_one() {
1036        let probe = MlpProbe::new(64, 32, 3);
1037        let emb = vec![0.5; 64];
1038        let probs = probe.predict_probs(&emb);
1039        assert_eq!(probs.len(), 3);
1040        let sum: f32 = probs.iter().sum();
1041        assert!((sum - 1.0).abs() < 1e-5, "probabilities must sum to 1.0, got {sum}");
1042    }
1043
1044    #[test]
1045    fn mlp_probe_num_parameters() {
1046        let probe = MlpProbe::new(768, 128, 2);
1047        // W1: 768*128 + b1: 128 + W2: 128*2 + b2: 2 = 98,434 + 128 + 256 + 2 = 98,818 + 2 = 98,690
1048        assert_eq!(probe.num_parameters(), 768 * 128 + 128 + 128 * 2 + 2);
1049    }
1050
1051    #[test]
1052    fn mlp_probe_relu_zeros_negative() {
1053        let probe = MlpProbe::new(4, 4, 2);
1054        let emb = vec![-10.0; 4]; // all negative
1055        let (h, _) = probe.forward(&emb);
1056        // After ReLU, some hidden units may be zero (depends on init)
1057        // At least verify h values are non-negative
1058        assert!(h.iter().all(|&v| v >= 0.0), "ReLU output must be non-negative");
1059    }
1060
1061    #[test]
1062    fn mlp_probe_train_learns_xor() {
1063        // XOR is not linearly separable — MLP should learn it, linear probe can't
1064        let embeddings = vec![vec![0.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0], vec![1.0, 1.0]];
1065        let labels = vec![0, 1, 1, 0]; // XOR pattern
1066
1067        // Repeat data for more training signal
1068        let embeddings: Vec<Vec<f32>> = embeddings.iter().cycle().take(40).cloned().collect();
1069        let labels: Vec<usize> = labels.iter().cycle().take(40).copied().collect();
1070
1071        let mut mlp = MlpProbe::new(2, 8, 2);
1072        mlp.train(&embeddings, &labels, 200, 0.1, None, 0.0);
1073
1074        // Test XOR predictions
1075        let pred_00 = mlp.predict(&[0.0, 0.0]);
1076        let pred_01 = mlp.predict(&[0.0, 1.0]);
1077        let pred_10 = mlp.predict(&[1.0, 0.0]);
1078        let pred_11 = mlp.predict(&[1.0, 1.0]);
1079
1080        // MLP should learn XOR (at least partially)
1081        let correct = u8::from(pred_00 == 0)
1082            + u8::from(pred_01 == 1)
1083            + u8::from(pred_10 == 1)
1084            + u8::from(pred_11 == 0);
1085        assert!(correct >= 3, "MLP should learn XOR (got {correct}/4 correct)");
1086    }
1087
1088    #[test]
1089    fn mlp_probe_train_reduces_loss() {
1090        let mut probe = MlpProbe::new(8, 16, 2);
1091        let embeddings: Vec<Vec<f32>> =
1092            (0..20).map(|i| if i < 10 { vec![1.0; 8] } else { vec![-1.0; 8] }).collect();
1093        let labels: Vec<usize> = (0..20).map(|i| usize::from(i >= 10)).collect();
1094
1095        let loss_1 = probe.train(&embeddings, &labels, 1, 0.01, None, 0.0);
1096        let loss_10 = probe.train(&embeddings, &labels, 10, 0.01, None, 0.0);
1097        assert!(loss_10 < loss_1 + 0.5, "training should reduce loss");
1098    }
1099
1100    // ── test_cov4 additional coverage tests ────────────────────────
1101
1102    #[test]
1103    fn test_cov4_multiclass_mcc_perfect_3class() {
1104        // Perfect 3-class predictions
1105        let preds = vec![0, 0, 1, 1, 2, 2];
1106        let labels = vec![0, 0, 1, 1, 2, 2];
1107        let metrics = evaluate(&preds, &labels, 3);
1108        assert!((metrics.accuracy - 1.0).abs() < 1e-5);
1109        assert!(metrics.mcc > 0.9, "Perfect 3-class should have high MCC, got {}", metrics.mcc);
1110    }
1111
1112    #[test]
1113    fn test_cov4_multiclass_mcc_random_3class() {
1114        // Completely wrong predictions
1115        let preds = vec![1, 2, 0, 2, 0, 1];
1116        let labels = vec![0, 0, 1, 1, 2, 2];
1117        let metrics = evaluate(&preds, &labels, 3);
1118        assert!(metrics.mcc < 0.1, "Random 3-class MCC should be near 0, got {}", metrics.mcc);
1119    }
1120
1121    #[test]
1122    fn test_cov4_multiclass_mcc_4class() {
1123        let preds = vec![0, 1, 2, 3, 0, 1, 2, 3];
1124        let labels = vec![0, 1, 2, 3, 0, 1, 2, 3];
1125        let metrics = evaluate(&preds, &labels, 4);
1126        assert!((metrics.mcc - 1.0).abs() < 1e-5);
1127        assert_eq!(metrics.num_samples, 8);
1128    }
1129
1130    #[test]
1131    fn test_cov4_binary_mcc_all_tp() {
1132        // All positive, all predicted positive
1133        assert_eq!(binary_mcc(100, 0, 0, 0), 0.0); // denom = 0
1134    }
1135
1136    #[test]
1137    fn test_cov4_binary_mcc_all_tn() {
1138        // All negative, all predicted negative
1139        assert_eq!(binary_mcc(0, 100, 0, 0), 0.0); // denom = 0
1140    }
1141
1142    #[test]
1143    fn test_cov4_binary_mcc_worst() {
1144        // All predictions wrong
1145        assert!((binary_mcc(0, 0, 50, 50) - (-1.0)).abs() < 1e-5);
1146    }
1147
1148    #[test]
1149    fn test_cov4_binary_mcc_asymmetric() {
1150        // TP=80, TN=10, FP=5, FN=5
1151        let mcc = binary_mcc(80, 10, 5, 5);
1152        assert!(mcc > 0.0 && mcc < 1.0, "Asymmetric MCC should be between 0 and 1, got {mcc}");
1153    }
1154
1155    #[test]
1156    fn test_cov4_evaluate_all_same_prediction() {
1157        // All predict class 0, mixed true labels
1158        let preds = vec![0, 0, 0, 0, 0];
1159        let labels = vec![0, 0, 1, 1, 1];
1160        let metrics = evaluate(&preds, &labels, 2);
1161        assert!((metrics.accuracy - 0.4).abs() < 1e-5);
1162        assert_eq!(metrics.recall[0], 1.0); // all class-0 detected
1163        assert_eq!(metrics.recall[1], 0.0); // no class-1 detected
1164    }
1165
1166    #[test]
1167    fn test_cov4_evaluate_empty() {
1168        let metrics = evaluate(&[], &[], 2);
1169        assert_eq!(metrics.num_samples, 0);
1170        assert!((metrics.accuracy - 0.0).abs() < 1e-5);
1171    }
1172
1173    #[test]
1174    fn test_cov4_evaluate_precision() {
1175        let preds = vec![0, 0, 1, 1, 1];
1176        let labels = vec![0, 1, 1, 1, 0];
1177        let metrics = evaluate(&preds, &labels, 2);
1178        // Predicted class 0: [0]=true, [1]=false → precision[0]=1/2=0.5
1179        assert!((metrics.precision[0] - 0.5).abs() < 1e-5);
1180        // Predicted class 1: [1]=true(2), [0]=false(1) → precision[1]=2/3
1181        assert!((metrics.precision[1] - 2.0 / 3.0).abs() < 1e-5);
1182    }
1183
1184    #[test]
1185    fn test_cov4_evaluate_confusion_matrix() {
1186        let preds = vec![0, 1, 0, 1];
1187        let labels = vec![0, 1, 1, 0];
1188        let metrics = evaluate(&preds, &labels, 2);
1189        // cm[pred][actual]
1190        assert_eq!(metrics.confusion_matrix[0][0], 1); // TP for class 0
1191        assert_eq!(metrics.confusion_matrix[0][1], 1); // FN for class 1 (predicted 0)
1192        assert_eq!(metrics.confusion_matrix[1][0], 1); // FP for class 0 (predicted 1)
1193        assert_eq!(metrics.confusion_matrix[1][1], 1); // TP for class 1
1194    }
1195
1196    #[test]
1197    fn test_cov4_evaluate_out_of_bounds_ignored() {
1198        // Labels or predictions >= num_classes should be silently skipped in cm
1199        let preds = vec![0, 1, 5]; // 5 is out of bounds for num_classes=2
1200        let labels = vec![0, 1, 0];
1201        let metrics = evaluate(&preds, &labels, 2);
1202        assert_eq!(metrics.num_samples, 3);
1203        // cm should only count valid entries
1204        assert_eq!(metrics.confusion_matrix[0][0], 1);
1205        assert_eq!(metrics.confusion_matrix[1][1], 1);
1206    }
1207
1208    #[test]
1209    fn test_cov4_bootstrap_ci_deterministic() {
1210        let preds = vec![0, 0, 1, 1, 0, 1, 0, 0, 1, 1];
1211        let labels = vec![0, 0, 1, 1, 0, 0, 0, 1, 1, 1];
1212        let ci1 = bootstrap_mcc_ci(&preds, &labels, 2, 50);
1213        let ci2 = bootstrap_mcc_ci(&preds, &labels, 2, 50);
1214        // Same seed should give same result
1215        assert!((ci1.lower - ci2.lower).abs() < 1e-5);
1216        assert!((ci1.upper - ci2.upper).abs() < 1e-5);
1217    }
1218
1219    #[test]
1220    fn test_cov4_bootstrap_ci_bounds() {
1221        let preds = vec![0, 0, 1, 1, 0, 1];
1222        let labels = vec![0, 0, 1, 1, 0, 1];
1223        let ci = bootstrap_mcc_ci(&preds, &labels, 2, 200);
1224        assert!(ci.lower <= ci.upper);
1225        assert!(ci.lower >= -1.0);
1226        assert!(ci.upper <= 1.0);
1227        assert_eq!(ci.n_bootstrap, 200);
1228    }
1229
1230    #[test]
1231    fn test_cov4_confidence_scores_deterministic() {
1232        let probe = LinearProbe::new(8, 2);
1233        let embs = vec![vec![0.5; 8], vec![-0.5; 8]];
1234        let scores1 = compute_confidence_scores(&probe, &embs);
1235        let scores2 = compute_confidence_scores(&probe, &embs);
1236        for (s1, s2) in scores1.iter().zip(scores2.iter()) {
1237            assert_eq!(s1.predicted_class, s2.predicted_class);
1238            assert!((s1.confidence - s2.confidence).abs() < 1e-6);
1239        }
1240    }
1241
1242    #[test]
1243    fn test_cov4_confidence_scores_empty() {
1244        let probe = LinearProbe::new(8, 2);
1245        let scores = compute_confidence_scores(&probe, &[]);
1246        assert!(scores.is_empty());
1247    }
1248
1249    #[test]
1250    fn test_cov4_escalation_display() {
1251        assert_eq!(format!("{}", EscalationLevel::LinearProbe), "Level 0: Linear probe");
1252        assert_eq!(format!("{}", EscalationLevel::TopLayers), "Level 1: Top-2 layers + head");
1253        assert_eq!(format!("{}", EscalationLevel::FullFinetune), "Level 2: Full fine-tune");
1254        assert_eq!(
1255            format!("{}", EscalationLevel::ContinuePretrain),
1256            "Level 3: Continue-pretrain + fine-tune"
1257        );
1258    }
1259
1260    #[test]
1261    fn test_cov4_escalation_debug_clone() {
1262        let level = EscalationLevel::TopLayers;
1263        let cloned = level;
1264        assert_eq!(level, cloned);
1265        assert!(format!("{level:?}").contains("TopLayers"));
1266    }
1267
1268    #[test]
1269    fn test_cov4_escalate_full_to_continue() {
1270        let ci = BootstrapCI { estimate: 0.2, lower: 0.1, upper: 0.3, n_bootstrap: 100 };
1271        let result = should_escalate(EscalationLevel::FullFinetune, &ci, 0.95);
1272        assert_eq!(result, Some(EscalationLevel::ContinuePretrain));
1273    }
1274
1275    #[test]
1276    fn test_cov4_escalate_full_no_escalate() {
1277        let ci = BootstrapCI { estimate: 0.5, lower: 0.4, upper: 0.6, n_bootstrap: 100 };
1278        let result = should_escalate(EscalationLevel::FullFinetune, &ci, 0.96);
1279        assert_eq!(result, None);
1280    }
1281
1282    #[test]
1283    fn test_cov4_escalate_top_layers_no_escalate() {
1284        let ci = BootstrapCI { estimate: 0.5, lower: 0.35, upper: 0.65, n_bootstrap: 100 };
1285        let result = should_escalate(EscalationLevel::TopLayers, &ci, 0.96);
1286        assert_eq!(result, None);
1287    }
1288
1289    #[test]
1290    fn test_cov4_compare_baselines_details() {
1291        let comps = compare_baselines(0.5, &[("majority", 0.0), ("keyword", 0.5), ("linter", 0.6)]);
1292        assert_eq!(comps[0].name, "majority");
1293        assert!(comps[0].beats_baseline);
1294        assert!(!comps[1].beats_baseline); // 0.5 > 0.5 is false
1295        assert!(!comps[2].beats_baseline);
1296        assert!((comps[0].model_mcc - 0.5).abs() < 1e-5);
1297        assert!((comps[0].baseline_mcc - 0.0).abs() < 1e-5);
1298    }
1299
1300    #[test]
1301    fn test_cov4_compare_baselines_empty() {
1302        let comps = compare_baselines(0.5, &[]);
1303        assert!(comps.is_empty());
1304    }
1305
1306    #[test]
1307    fn test_cov4_generalization_result_fields() {
1308        let probe = LinearProbe::new(4, 2);
1309        let embs: Vec<Vec<f32>> = (0..5).map(|_| vec![0.0; 4]).collect();
1310        let result = generalization_test(&probe, &embs, 1);
1311        assert_eq!(result.total, 5);
1312        assert!(result.detected <= 5);
1313        assert!((result.detection_rate - result.detected as f32 / 5.0).abs() < 1e-5);
1314    }
1315
1316    #[test]
1317    fn test_cov4_ship_gate_all_fail() {
1318        let ci = BootstrapCI { estimate: 0.1, lower: 0.05, upper: 0.15, n_bootstrap: 100 };
1319        let gen =
1320            GeneralizationResult { total: 50, detected: 10, detection_rate: 0.2, passes: false };
1321        let result = check_ship_gate(&ci, 0.90, &gen, EscalationLevel::LinearProbe);
1322        assert!(!result.ship_ready);
1323        assert!(!result.mcc_passes);
1324        assert!(!result.accuracy_passes);
1325        assert!(!result.generalization_passes);
1326        assert_eq!(result.level, EscalationLevel::LinearProbe);
1327    }
1328
1329    #[test]
1330    fn test_cov4_ship_gate_fails_accuracy() {
1331        let ci = BootstrapCI { estimate: 0.4, lower: 0.25, upper: 0.55, n_bootstrap: 100 };
1332        let gen =
1333            GeneralizationResult { total: 50, detected: 30, detection_rate: 0.6, passes: true };
1334        let result = check_ship_gate(&ci, 0.90, &gen, EscalationLevel::TopLayers);
1335        assert!(!result.ship_ready);
1336        assert!(result.mcc_passes);
1337        assert!(!result.accuracy_passes);
1338        assert!(result.generalization_passes);
1339        assert_eq!(result.level, EscalationLevel::TopLayers);
1340    }
1341
1342    #[test]
1343    fn test_cov4_linear_probe_predict() {
1344        let probe = LinearProbe::new(8, 3);
1345        let emb = Tensor::from_vec(vec![0.5; 8], false);
1346        let predicted = probe.predict(&emb);
1347        assert!(predicted < 3);
1348    }
1349
1350    #[test]
1351    fn test_cov4_linear_probe_num_classes() {
1352        let probe = LinearProbe::new(64, 5);
1353        assert_eq!(probe.num_classes(), 5);
1354    }
1355
1356    #[test]
1357    fn test_cov4_linear_probe_train_with_class_weights() {
1358        let mut probe = LinearProbe::new(4, 2);
1359        let embeddings =
1360            vec![vec![1.0; 4]; 10].into_iter().chain(vec![vec![-1.0; 4]; 10]).collect::<Vec<_>>();
1361        let labels: Vec<usize> = (0..20).map(|i| usize::from(i >= 10)).collect();
1362        let weights = vec![1.0, 5.0]; // upweight class 1
1363
1364        let loss = probe.train(&embeddings, &labels, 5, 0.01, Some(&weights));
1365        assert!(loss.is_finite());
1366    }
1367
1368    #[test]
1369    fn test_cov4_mlp_probe_predict() {
1370        let probe = MlpProbe::new(8, 16, 3);
1371        let emb = vec![0.1; 8];
1372        let predicted = probe.predict(&emb);
1373        assert!(predicted < 3);
1374    }
1375
1376    #[test]
1377    fn test_cov4_mlp_probe_predict_probs_all_positive() {
1378        let probe = MlpProbe::new(4, 8, 2);
1379        let probs = probe.predict_probs(&[0.5, -0.5, 1.0, -1.0]);
1380        assert!(probs.iter().all(|&p| p > 0.0));
1381        assert!(probs.iter().all(|&p| p <= 1.0));
1382    }
1383
1384    #[test]
1385    fn test_cov4_mlp_probe_num_parameters() {
1386        let probe = MlpProbe::new(16, 8, 3);
1387        // W1: 16*8 + b1: 8 + W2: 8*3 + b2: 3 = 128 + 8 + 24 + 3 = 163
1388        assert_eq!(probe.num_parameters(), 16 * 8 + 8 + 8 * 3 + 3);
1389    }
1390
1391    #[test]
1392    fn test_cov4_mlp_probe_train_with_class_weights() {
1393        let mut probe = MlpProbe::new(4, 8, 2);
1394        let embeddings: Vec<Vec<f32>> =
1395            (0..20).map(|i| if i < 10 { vec![1.0; 4] } else { vec![-1.0; 4] }).collect();
1396        let labels: Vec<usize> = (0..20).map(|i| usize::from(i >= 10)).collect();
1397        let weights = vec![1.0, 5.0];
1398
1399        let loss = probe.train(&embeddings, &labels, 5, 0.01, Some(&weights), 0.0);
1400        assert!(loss.is_finite());
1401    }
1402
1403    #[test]
1404    fn test_cov4_mlp_probe_train_with_weight_decay() {
1405        let mut probe = MlpProbe::new(4, 8, 2);
1406        let embeddings: Vec<Vec<f32>> =
1407            (0..10).map(|i| if i < 5 { vec![1.0; 4] } else { vec![-1.0; 4] }).collect();
1408        let labels: Vec<usize> = (0..10).map(|i| usize::from(i >= 5)).collect();
1409
1410        let loss = probe.train(&embeddings, &labels, 5, 0.01, None, 0.01);
1411        assert!(loss.is_finite());
1412    }
1413
1414    #[test]
1415    fn test_cov4_softmax_slice_single() {
1416        let result = softmax_slice(&[0.0]);
1417        assert_eq!(result.len(), 1);
1418        assert!((result[0] - 1.0).abs() < 1e-5);
1419    }
1420
1421    #[test]
1422    fn test_cov4_softmax_slice_large_values() {
1423        // Should not overflow due to max subtraction
1424        let result = softmax_slice(&[1000.0, 1001.0]);
1425        assert_eq!(result.len(), 2);
1426        let sum: f32 = result.iter().sum();
1427        assert!((sum - 1.0).abs() < 1e-5);
1428        assert!(result[1] > result[0]); // higher logit → higher prob
1429    }
1430
1431    #[test]
1432    fn test_cov4_softmax_slice_equal() {
1433        let result = softmax_slice(&[1.0, 1.0, 1.0]);
1434        for &p in &result {
1435            assert!((p - 1.0 / 3.0).abs() < 1e-5);
1436        }
1437    }
1438
1439    #[test]
1440    fn test_cov4_classification_metrics_clone() {
1441        let m = ClassificationMetrics {
1442            mcc: 0.5,
1443            accuracy: 0.9,
1444            recall: vec![0.8, 0.7],
1445            precision: vec![0.85, 0.75],
1446            num_samples: 100,
1447            confusion_matrix: vec![vec![40, 10], vec![5, 45]],
1448        };
1449        let m2 = m.clone();
1450        assert!((m2.mcc - 0.5).abs() < 1e-5);
1451        assert_eq!(m2.num_samples, 100);
1452        assert!(format!("{m2:?}").contains("ClassificationMetrics"));
1453    }
1454
1455    #[test]
1456    fn test_cov4_bootstrap_ci_clone() {
1457        let ci = BootstrapCI { estimate: 0.5, lower: 0.3, upper: 0.7, n_bootstrap: 1000 };
1458        let ci2 = ci;
1459        assert!((ci2.estimate - 0.5).abs() < 1e-5);
1460        assert!(format!("{ci:?}").contains("BootstrapCI"));
1461    }
1462
1463    #[test]
1464    fn test_cov4_confidence_score_clone() {
1465        let s =
1466            ConfidenceScore { predicted_class: 1, confidence: 0.8, probabilities: vec![0.2, 0.8] };
1467        let s2 = s.clone();
1468        assert_eq!(s2.predicted_class, 1);
1469        assert!((s2.confidence - 0.8).abs() < 1e-5);
1470        assert!(format!("{s2:?}").contains("ConfidenceScore"));
1471    }
1472
1473    #[test]
1474    fn test_cov4_generalization_result_clone() {
1475        let r =
1476            GeneralizationResult { total: 20, detected: 15, detection_rate: 0.75, passes: true };
1477        let r2 = r.clone();
1478        assert!(r2.passes);
1479        assert_eq!(r2.total, 20);
1480        assert!(format!("{r2:?}").contains("GeneralizationResult"));
1481    }
1482
1483    #[test]
1484    fn test_cov4_baseline_comparison_clone() {
1485        let b = BaselineComparison {
1486            name: "test".to_string(),
1487            baseline_mcc: 0.3,
1488            model_mcc: 0.5,
1489            beats_baseline: true,
1490        };
1491        let b2 = b.clone();
1492        assert!(b2.beats_baseline);
1493        assert!(format!("{b2:?}").contains("BaselineComparison"));
1494    }
1495
1496    #[test]
1497    fn test_cov4_ship_gate_result_clone() {
1498        let r = ShipGateResult {
1499            mcc_passes: true,
1500            accuracy_passes: true,
1501            generalization_passes: false,
1502            ship_ready: false,
1503            level: EscalationLevel::LinearProbe,
1504        };
1505        let r2 = r.clone();
1506        assert!(!r2.ship_ready);
1507        assert!(format!("{r2:?}").contains("ShipGateResult"));
1508    }
1509
1510    #[test]
1511    fn test_cov4_multiclass_mcc_single_class() {
1512        // Edge case: all predictions and labels same class
1513        let preds = vec![0, 0, 0, 0];
1514        let labels = vec![0, 0, 0, 0];
1515        let metrics = evaluate(&preds, &labels, 3);
1516        assert!((metrics.accuracy - 1.0).abs() < 1e-5);
1517        // MCC is 0 when only one class is present (denom = 0)
1518        assert!(metrics.mcc.abs() < 1e-5 || metrics.mcc.is_finite());
1519    }
1520}