1use crate::autograd::{matmul, Tensor};
22
/// Summary metrics for a set of predictions, as produced by [`evaluate`].
#[derive(Debug, Clone)]
pub struct ClassificationMetrics {
    /// Matthews correlation coefficient (binary formula when there are exactly
    /// two classes, multi-class formula otherwise).
    pub mcc: f32,
    /// Fraction of samples on the confusion-matrix diagonal, divided by the
    /// total number of samples (including any that were out of range).
    pub accuracy: f32,
    /// Per-class recall: `cm[c][c]` divided by the number of samples whose
    /// label is `c`; `0.0` when the class never occurs as a label.
    pub recall: Vec<f32>,
    /// Per-class precision: `cm[c][c]` divided by the number of samples
    /// predicted as `c`; `0.0` when the class is never predicted.
    pub precision: Vec<f32>,
    /// Number of prediction/label pairs that were passed in.
    pub num_samples: usize,
    /// Confusion matrix indexed as `confusion_matrix[predicted][actual]`.
    pub confusion_matrix: Vec<Vec<usize>>,
}
39
/// Bootstrap confidence interval for a metric, as produced by
/// [`bootstrap_mcc_ci`].
#[derive(Debug, Clone, Copy)]
pub struct BootstrapCI {
    /// Metric computed on the original (non-resampled) data.
    pub estimate: f32,
    /// Lower bound: the sorted bootstrap sample at the 2.5% position.
    pub lower: f32,
    /// Upper bound: the sorted bootstrap sample at the 97.5% position.
    pub upper: f32,
    /// Number of bootstrap resamples that were requested.
    pub n_bootstrap: usize,
}
52
/// Single linear layer (`logits = embedding · weight + bias`) used as a
/// classification head on top of fixed-size embeddings.
pub struct LinearProbe {
    /// Weight values, `hidden_size * num_classes` of them, addressed as
    /// `weight[i * num_classes + j]` for input feature `i` and class `j`.
    pub weight: Tensor,
    /// One bias value per class.
    pub bias: Tensor,
    /// Length of the embedding vectors this probe accepts.
    hidden_size: usize,
    /// Number of output classes.
    num_classes: usize,
}
67
68impl LinearProbe {
69 pub fn new(hidden_size: usize, num_classes: usize) -> Self {
71 assert!(hidden_size > 0, "hidden_size must be > 0");
72 assert!(num_classes >= 2, "num_classes must be >= 2");
73
74 let scale = (6.0 / (hidden_size + num_classes) as f32).sqrt();
75 let mut rng: u64 = 42;
76 let weight_data: Vec<f32> = (0..hidden_size * num_classes)
77 .map(|_| {
78 rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
79 let u = (rng >> 33) as f32 / (1u64 << 31) as f32;
80 (2.0 * u - 1.0) * scale
81 })
82 .collect();
83
84 Self {
85 weight: Tensor::from_vec(weight_data, true),
86 bias: Tensor::zeros(num_classes, true),
87 hidden_size,
88 num_classes,
89 }
90 }
91
92 pub fn forward(&self, embedding: &Tensor) -> Tensor {
100 let logits = matmul(embedding, &self.weight, 1, self.hidden_size, self.num_classes);
101 let logits_data = logits.data();
102 let logits_slice = logits_data.as_slice().expect("contiguous logits");
103 let bias_data = self.bias.data();
104 let bias_slice = bias_data.as_slice().expect("contiguous bias");
105
106 let output: Vec<f32> =
107 logits_slice.iter().zip(bias_slice.iter()).map(|(&l, &b)| l + b).collect();
108 Tensor::from_vec(output, logits.requires_grad())
109 }
110
111 pub fn predict_probs(&self, embedding: &Tensor) -> Vec<f32> {
113 let logits = self.forward(embedding);
114 softmax_vec(&logits)
115 }
116
117 pub fn predict(&self, embedding: &Tensor) -> usize {
119 contract_pre_predict!();
120 let logits = self.forward(embedding);
121 let data = logits.data();
122 let slice = data.as_slice().expect("contiguous");
123 slice
124 .iter()
125 .enumerate()
126 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
127 .map_or(0, |(i, _)| i)
128 }
129
130 pub fn train(
142 &mut self,
143 embeddings: &[Vec<f32>],
144 labels: &[usize],
145 epochs: usize,
146 learning_rate: f32,
147 class_weights: Option<&[f32]>,
148 ) -> f32 {
149 assert_eq!(embeddings.len(), labels.len());
150 let n = embeddings.len();
151 let mut final_loss = 0.0;
152
153 for epoch in 0..epochs {
154 let mut epoch_loss = 0.0;
155
156 for (emb, &label) in embeddings.iter().zip(labels.iter()) {
157 assert_eq!(emb.len(), self.hidden_size);
158 assert!(label < self.num_classes);
159
160 let emb_tensor = Tensor::from_vec(emb.clone(), false);
162 let logits = self.forward(&emb_tensor);
163
164 let probs = softmax_vec(&logits);
166 let loss_weight = class_weights.map_or(1.0, |w| w[label]);
167 let loss = -probs[label].max(1e-10).ln() * loss_weight;
168 epoch_loss += loss;
169
170 let mut grad_logits = probs;
172 grad_logits[label] -= 1.0;
173 if let Some(w) = class_weights {
174 for (i, g) in grad_logits.iter_mut().enumerate() {
175 *g *= w[i];
176 }
177 }
178
179 let w_data = self.weight.data();
181 let mut w_slice = w_data.as_slice().expect("contiguous").to_vec();
182 for i in 0..self.hidden_size {
183 for j in 0..self.num_classes {
184 w_slice[i * self.num_classes + j] -=
185 learning_rate * emb[i] * grad_logits[j];
186 }
187 }
188 self.weight = Tensor::from_vec(w_slice, true);
189
190 let b_data = self.bias.data();
192 let mut b_slice = b_data.as_slice().expect("contiguous").to_vec();
193 for j in 0..self.num_classes {
194 b_slice[j] -= learning_rate * grad_logits[j];
195 }
196 self.bias = Tensor::from_vec(b_slice, true);
197 }
198
199 final_loss = epoch_loss / n as f32;
200 if epoch == 0 || (epoch + 1) % 5 == 0 || epoch == epochs - 1 {
201 eprintln!(" Epoch {}/{epochs}: loss={final_loss:.4}", epoch + 1);
202 }
203 }
204
205 final_loss
206 }
207
208 pub fn num_parameters(&self) -> usize {
210 self.hidden_size * self.num_classes + self.num_classes
211 }
212
213 pub fn num_classes(&self) -> usize {
215 self.num_classes
216 }
217}
218
/// Two-layer perceptron probe: `embedding -> ReLU(hidden) -> logits`, stored
/// as plain `f32` vectors.
pub struct MlpProbe {
    /// First-layer weights, addressed as `w1[i * mlp_hidden + j]` for input
    /// feature `i` and hidden unit `j`.
    pub w1: Vec<f32>,
    /// First-layer biases, one per hidden unit.
    pub b1: Vec<f32>,
    /// Second-layer weights, addressed as `w2[i * num_classes + j]` for hidden
    /// unit `i` and class `j`.
    pub w2: Vec<f32>,
    /// Second-layer biases, one per class.
    pub b2: Vec<f32>,
    /// Length of the input embedding.
    pub hidden_size: usize,
    /// Number of hidden units.
    pub mlp_hidden: usize,
    /// Number of output classes.
    pub num_classes: usize,
}
242
243impl MlpProbe {
244 pub fn new(hidden_size: usize, mlp_hidden: usize, num_classes: usize) -> Self {
246 assert!(hidden_size > 0 && mlp_hidden > 0 && num_classes >= 2);
247
248 let mut rng: u64 = 42;
249 let mut xavier = |fan_in: usize, fan_out: usize, n: usize| -> Vec<f32> {
250 let scale = (6.0 / (fan_in + fan_out) as f32).sqrt();
251 (0..n)
252 .map(|_| {
253 rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
254 let u = (rng >> 33) as f32 / (1u64 << 31) as f32;
255 (2.0 * u - 1.0) * scale
256 })
257 .collect()
258 };
259
260 Self {
261 w1: xavier(hidden_size, mlp_hidden, hidden_size * mlp_hidden),
262 b1: vec![0.0; mlp_hidden],
263 w2: xavier(mlp_hidden, num_classes, mlp_hidden * num_classes),
264 b2: vec![0.0; num_classes],
265 hidden_size,
266 mlp_hidden,
267 num_classes,
268 }
269 }
270
271 pub fn forward(&self, emb: &[f32]) -> (Vec<f32>, Vec<f32>) {
273 let mut h = vec![0.0_f32; self.mlp_hidden];
275 for j in 0..self.mlp_hidden {
276 let mut sum = self.b1[j];
277 for i in 0..self.hidden_size {
278 sum += self.w1[i * self.mlp_hidden + j] * emb[i];
279 }
280 h[j] = sum.max(0.0); }
282
283 let mut logits = vec![0.0_f32; self.num_classes];
285 for j in 0..self.num_classes {
286 let mut sum = self.b2[j];
287 for i in 0..self.mlp_hidden {
288 sum += self.w2[i * self.num_classes + j] * h[i];
289 }
290 logits[j] = sum;
291 }
292
293 (h, logits)
294 }
295
296 pub fn predict(&self, emb: &[f32]) -> usize {
298 contract_pre_predict!();
299 let (_, logits) = self.forward(emb);
300 logits
301 .iter()
302 .enumerate()
303 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
304 .map_or(0, |(i, _)| i)
305 }
306
307 pub fn predict_probs(&self, emb: &[f32]) -> Vec<f32> {
309 let (_, logits) = self.forward(emb);
310 softmax_slice(&logits)
311 }
312
313 fn forward_train(&self, emb: &[f32]) -> (Vec<f32>, Vec<f32>, Vec<f32>) {
315 let mut h_pre = vec![0.0_f32; self.mlp_hidden];
316 let mut h = vec![0.0_f32; self.mlp_hidden];
317 for j in 0..self.mlp_hidden {
318 let mut sum = self.b1[j];
319 for i in 0..self.hidden_size {
320 sum += self.w1[i * self.mlp_hidden + j] * emb[i];
321 }
322 h_pre[j] = sum;
323 h[j] = sum.max(0.0);
324 }
325
326 let mut logits = vec![0.0_f32; self.num_classes];
327 for j in 0..self.num_classes {
328 let mut sum = self.b2[j];
329 for i in 0..self.mlp_hidden {
330 sum += self.w2[i * self.num_classes + j] * h[i];
331 }
332 logits[j] = sum;
333 }
334 (h_pre, h, logits)
335 }
336
337 fn backward_step(
339 &mut self,
340 emb: &[f32],
341 h_pre: &[f32],
342 h: &[f32],
343 grad_logits: &[f32],
344 lr: f32,
345 wd: f32,
346 ) {
347 for i in 0..self.mlp_hidden {
349 for j in 0..self.num_classes {
350 let idx = i * self.num_classes + j;
351 self.w2[idx] -= lr * (h[i] * grad_logits[j] + wd * self.w2[idx]);
352 }
353 }
354 for j in 0..self.num_classes {
355 self.b2[j] -= lr * grad_logits[j];
356 }
357
358 let mut grad_h = vec![0.0_f32; self.mlp_hidden];
360 for i in 0..self.mlp_hidden {
361 if h_pre[i] > 0.0 {
362 for j in 0..self.num_classes {
363 grad_h[i] += self.w2[i * self.num_classes + j] * grad_logits[j];
364 }
365 }
366 }
367
368 for i in 0..self.hidden_size {
370 for j in 0..self.mlp_hidden {
371 let idx = i * self.mlp_hidden + j;
372 self.w1[idx] -= lr * (emb[i] * grad_h[j] + wd * self.w1[idx]);
373 }
374 }
375 for j in 0..self.mlp_hidden {
376 self.b1[j] -= lr * grad_h[j];
377 }
378 }
379
380 #[allow(clippy::too_many_arguments)]
382 pub fn train(
383 &mut self,
384 embeddings: &[Vec<f32>],
385 labels: &[usize],
386 epochs: usize,
387 learning_rate: f32,
388 class_weights: Option<&[f32]>,
389 weight_decay: f32,
390 ) -> f32 {
391 assert_eq!(embeddings.len(), labels.len());
392 let n = embeddings.len();
393 let mut final_loss = 0.0;
394
395 for epoch in 0..epochs {
396 let mut epoch_loss = 0.0;
397
398 for (emb, &label) in embeddings.iter().zip(labels.iter()) {
399 let (h_pre, h, logits) = self.forward_train(emb);
400 let probs = softmax_slice(&logits);
401 let loss_weight = class_weights.map_or(1.0, |w| w[label]);
402 epoch_loss += -probs[label].max(1e-10).ln() * loss_weight;
403
404 let mut grad_logits = probs;
405 grad_logits[label] -= 1.0;
406 if let Some(w) = class_weights {
407 for (i, g) in grad_logits.iter_mut().enumerate() {
408 *g *= w[i];
409 }
410 }
411
412 self.backward_step(emb, &h_pre, &h, &grad_logits, learning_rate, weight_decay);
413 }
414
415 final_loss = epoch_loss / n as f32;
416 if epoch == 0 || (epoch + 1) % 10 == 0 || epoch == epochs - 1 {
417 eprintln!(" Epoch {}/{epochs}: loss={final_loss:.4}", epoch + 1);
418 }
419 }
420
421 final_loss
422 }
423
424 pub fn num_parameters(&self) -> usize {
426 self.hidden_size * self.mlp_hidden + self.mlp_hidden + self.mlp_hidden * self.num_classes + self.num_classes }
429}
430
/// Softmax over a slice of logits.
///
/// The largest logit is subtracted before exponentiating so that large inputs
/// do not overflow. An empty slice yields an empty vector.
fn softmax_slice(logits: &[f32]) -> Vec<f32> {
    let mut peak = f32::NEG_INFINITY;
    for &value in logits {
        peak = peak.max(value);
    }

    let mut probs: Vec<f32> = logits.iter().map(|&value| (value - peak).exp()).collect();
    let total: f32 = probs.iter().sum();
    for p in &mut probs {
        *p /= total;
    }
    probs
}
438
439fn softmax_vec(logits: &Tensor) -> Vec<f32> {
441 let data = logits.data();
442 let slice = data.as_slice().expect("contiguous logits");
443 let max_val = slice.iter().copied().fold(f32::NEG_INFINITY, f32::max);
444 let exp_vals: Vec<f32> = slice.iter().map(|&x| (x - max_val).exp()).collect();
445 let sum: f32 = exp_vals.iter().sum();
446 exp_vals.iter().map(|&v| v / sum).collect()
447}
448
/// Matthews correlation coefficient for a binary confusion matrix.
///
/// * `tp`, `tn` — true positives / true negatives.
/// * `fp`, `fn_count` — false positives / false negatives.
///
/// Returns a value in `[-1, 1]`, or `0.0` when the denominator is (near) zero,
/// i.e. when any row or column of the confusion matrix is empty.
pub fn binary_mcc(tp: usize, tn: usize, fp: usize, fn_count: usize) -> f32 {
    // Convert before doing any arithmetic: products and sums of large counts
    // can overflow `usize`, whereas `f64` handles them comfortably.
    let (tp, tn, fp, fn_count) = (tp as f64, tn as f64, fp as f64, fn_count as f64);

    let numerator = tp * tn - fp * fn_count;
    let denom = ((tp + fp) * (tp + fn_count) * (tn + fp) * (tn + fn_count)).sqrt();
    if denom < 1e-10 {
        0.0
    } else {
        (numerator / denom) as f32
    }
}
463
464pub fn evaluate(
466 predictions: &[usize],
467 labels: &[usize],
468 num_classes: usize,
469) -> ClassificationMetrics {
470 assert_eq!(predictions.len(), labels.len());
471 let n = predictions.len();
472
473 let mut cm = vec![vec![0usize; num_classes]; num_classes];
475 for (&pred, &label) in predictions.iter().zip(labels.iter()) {
476 if pred < num_classes && label < num_classes {
477 cm[pred][label] += 1;
478 }
479 }
480
481 let correct: usize = (0..num_classes).map(|c| cm[c][c]).sum();
483 let accuracy = correct as f32 / n.max(1) as f32;
484
485 let mut precision = vec![0.0_f32; num_classes];
487 let mut recall = vec![0.0_f32; num_classes];
488 for c in 0..num_classes {
489 let pred_count: usize = cm[c].iter().sum();
490 let actual_count: usize = (0..num_classes).map(|p| cm[p][c]).sum();
491 precision[c] = if pred_count > 0 { cm[c][c] as f32 / pred_count as f32 } else { 0.0 };
492 recall[c] = if actual_count > 0 { cm[c][c] as f32 / actual_count as f32 } else { 0.0 };
493 }
494
495 let mcc = if num_classes == 2 {
497 let tp = cm[1][1];
498 let tn = cm[0][0];
499 let fp = cm[1][0];
500 let fn_count = cm[0][1];
501 binary_mcc(tp, tn, fp, fn_count)
502 } else {
503 multiclass_mcc(&cm, num_classes)
504 };
505
506 ClassificationMetrics { mcc, accuracy, recall, precision, num_samples: n, confusion_matrix: cm }
507}
508
/// Multi-class Matthews correlation coefficient of a `k x k` confusion matrix.
///
/// With `n` the total count, `c` the diagonal sum, and `r_i` / `t_i` the row
/// and column sums of class `i`, this computes
/// `(c * n - sum_i r_i * t_i) / sqrt((n^2 - sum_i r_i^2) * (n^2 - sum_i t_i^2))`.
///
/// Returns `0.0` when the denominator is (near) zero, e.g. when every sample
/// falls into a single row or column.
///
/// # Panics
///
/// Panics if `cm` has fewer than `k` rows or a row has fewer than `k` columns.
fn multiclass_mcc(cm: &[Vec<usize>], k: usize) -> f32 {
    let n: f64 = cm.iter().flat_map(|row| row.iter()).sum::<usize>() as f64;
    let c: f64 = (0..k).map(|i| cm[i][i] as f64).sum();

    let mut s = 0.0_f64; // sum of row_sum * col_sum
    let mut p = 0.0_f64; // sum of squared row sums
    let mut t = 0.0_f64; // sum of squared column sums
    for i in 0..k {
        let row_sum: f64 = cm[i].iter().sum::<usize>() as f64;
        let col_sum: f64 = (0..k).map(|j| cm[j][i] as f64).sum();
        // One multiplication per class instead of re-summing the row for
        // every column entry.
        s += row_sum * col_sum;
        p += row_sum * row_sum;
        t += col_sum * col_sum;
    }

    let numerator = c * n - s;
    let denom = ((n * n - p) * (n * n - t)).sqrt();
    if denom < 1e-10 {
        0.0
    } else {
        (numerator / denom) as f32
    }
}
536
537pub fn bootstrap_mcc_ci(
542 predictions: &[usize],
543 labels: &[usize],
544 num_classes: usize,
545 n_bootstrap: usize,
546) -> BootstrapCI {
547 let n = predictions.len();
548 let point_estimate = evaluate(predictions, labels, num_classes).mcc;
549
550 let mut mcc_samples = Vec::with_capacity(n_bootstrap);
551 let mut rng: u64 = 12345;
552
553 for _ in 0..n_bootstrap {
554 let mut boot_preds = Vec::with_capacity(n);
555 let mut boot_labels = Vec::with_capacity(n);
556
557 for _ in 0..n {
558 rng = rng.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1442695040888963407);
559 let idx = (rng >> 33) as usize % n;
560 boot_preds.push(predictions[idx]);
561 boot_labels.push(labels[idx]);
562 }
563
564 let metrics = evaluate(&boot_preds, &boot_labels, num_classes);
565 mcc_samples.push(metrics.mcc);
566 }
567
568 mcc_samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
569
570 let lower_idx = (n_bootstrap as f32 * 0.025) as usize;
571 let upper_idx = ((n_bootstrap as f32 * 0.975) as usize).min(n_bootstrap - 1);
572
573 BootstrapCI {
574 estimate: point_estimate,
575 lower: mcc_samples[lower_idx],
576 upper: mcc_samples[upper_idx],
577 n_bootstrap,
578 }
579}
580
/// Prediction for one embedding together with its softmax probabilities, as
/// produced by [`compute_confidence_scores`].
#[derive(Debug, Clone)]
pub struct ConfidenceScore {
    /// Index of the class with the highest probability.
    pub predicted_class: usize,
    /// Probability assigned to `predicted_class`.
    pub confidence: f32,
    /// Probability of every class, indexed by class.
    pub probabilities: Vec<f32>,
}
591
592pub fn compute_confidence_scores(
594 probe: &LinearProbe,
595 embeddings: &[Vec<f32>],
596) -> Vec<ConfidenceScore> {
597 embeddings
598 .iter()
599 .map(|emb| {
600 let emb_tensor = Tensor::from_vec(emb.clone(), false);
601 let probs = probe.predict_probs(&emb_tensor);
602 let (predicted_class, &confidence) = probs
603 .iter()
604 .enumerate()
605 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
606 .expect("non-empty probabilities");
607 ConfidenceScore { predicted_class, confidence, probabilities: probs }
608 })
609 .collect()
610}
611
/// Training-effort levels, ordered from cheapest to most expensive.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EscalationLevel {
    /// Level 0: linear probe.
    LinearProbe,
    /// Level 1: top-2 layers plus head.
    TopLayers,
    /// Level 2: full fine-tune.
    FullFinetune,
    /// Level 3: continued pre-training followed by fine-tuning.
    ContinuePretrain,
}

impl std::fmt::Display for EscalationLevel {
    /// Writes the human-readable "Level N: …" label for this level.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::LinearProbe => "Level 0: Linear probe",
            Self::TopLayers => "Level 1: Top-2 layers + head",
            Self::FullFinetune => "Level 2: Full fine-tune",
            Self::ContinuePretrain => "Level 3: Continue-pretrain + fine-tune",
        };
        f.write_str(label)
    }
}
639
640pub fn should_escalate(
644 current_level: EscalationLevel,
645 mcc_ci: &BootstrapCI,
646 accuracy: f32,
647) -> Option<EscalationLevel> {
648 match current_level {
649 EscalationLevel::LinearProbe => {
650 if mcc_ci.lower < 0.2 || accuracy <= 0.935 {
651 Some(EscalationLevel::TopLayers)
652 } else {
653 None }
655 }
656 EscalationLevel::TopLayers | EscalationLevel::FullFinetune => {
657 if mcc_ci.lower < 0.3 {
658 match current_level {
659 EscalationLevel::TopLayers => Some(EscalationLevel::FullFinetune),
660 _ => Some(EscalationLevel::ContinuePretrain),
661 }
662 } else {
663 None
664 }
665 }
666 EscalationLevel::ContinuePretrain => {
667 None
669 }
670 }
671}
672
/// Outcome of comparing the model's MCC against one named baseline, as
/// produced by [`compare_baselines`].
#[derive(Debug, Clone)]
pub struct BaselineComparison {
    /// Name of the baseline.
    pub name: String,
    /// MCC achieved by the baseline.
    pub baseline_mcc: f32,
    /// MCC achieved by the model.
    pub model_mcc: f32,
    /// `true` when `model_mcc` is strictly greater than `baseline_mcc`.
    pub beats_baseline: bool,
}
689
690pub fn compare_baselines(model_mcc: f32, baseline_mccs: &[(&str, f32)]) -> Vec<BaselineComparison> {
697 baseline_mccs
698 .iter()
699 .map(|&(name, baseline_mcc)| BaselineComparison {
700 name: name.to_string(),
701 baseline_mcc,
702 model_mcc,
703 beats_baseline: model_mcc > baseline_mcc,
704 })
705 .collect()
706}
707
/// Outcome of running a probe on held-out embeddings, as produced by
/// [`generalization_test`].
#[derive(Debug, Clone)]
pub struct GeneralizationResult {
    /// Number of embeddings that were tested.
    pub total: usize,
    /// Number of embeddings predicted as the target class.
    pub detected: usize,
    /// `detected / total`, or `0.0` when nothing was tested.
    pub detection_rate: f32,
    /// `true` when `detection_rate` is at least `0.5`.
    pub passes: bool,
}
724
725pub fn generalization_test(
730 probe: &LinearProbe,
731 novel_embeddings: &[Vec<f32>],
732 unsafe_class: usize,
733) -> GeneralizationResult {
734 let total = novel_embeddings.len();
735 let detected = novel_embeddings
736 .iter()
737 .filter(|emb| {
738 let emb_tensor = Tensor::from_vec((*emb).clone(), false);
739 probe.predict(&emb_tensor) == unsafe_class
740 })
741 .count();
742
743 let detection_rate = if total > 0 { detected as f32 / total as f32 } else { 0.0 };
744
745 GeneralizationResult { total, detected, detection_rate, passes: detection_rate >= 0.5 }
746}
747
/// Per-criterion and overall verdict of the ship gate, as produced by
/// [`check_ship_gate`].
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone)]
pub struct ShipGateResult {
    /// `true` when the MCC CI lower bound is strictly above `0.2`.
    pub mcc_passes: bool,
    /// `true` when accuracy is strictly above `0.935`.
    pub accuracy_passes: bool,
    /// Copied from the generalization result's `passes` flag.
    pub generalization_passes: bool,
    /// `true` only when all three criteria above pass.
    pub ship_ready: bool,
    /// Escalation level the evaluated model was trained at.
    pub level: EscalationLevel,
}
767
768pub fn check_ship_gate(
770 mcc_ci: &BootstrapCI,
771 accuracy: f32,
772 generalization: &GeneralizationResult,
773 level: EscalationLevel,
774) -> ShipGateResult {
775 let mcc_passes = mcc_ci.lower > 0.2;
776 let accuracy_passes = accuracy > 0.935;
777 let generalization_passes = generalization.passes;
778
779 ShipGateResult {
780 mcc_passes,
781 accuracy_passes,
782 generalization_passes,
783 ship_ready: mcc_passes && accuracy_passes && generalization_passes,
784 level,
785 }
786}
787
788#[cfg(test)]
789#[allow(clippy::unwrap_used)]
790mod tests {
791 use super::*;
792
793 #[test]
794 fn clf_002_linear_probe_forward_shape() {
795 let probe = LinearProbe::new(768, 2);
796 let emb = Tensor::from_vec(vec![0.1; 768], false);
797 let logits = probe.forward(&emb);
798 assert_eq!(logits.len(), 2);
799 }
800
801 #[test]
802 fn clf_002_linear_probe_predict_probs_sum_to_one() {
803 let probe = LinearProbe::new(64, 3);
804 let emb = Tensor::from_vec(vec![0.5; 64], false);
805 let probs = probe.predict_probs(&emb);
806 assert_eq!(probs.len(), 3);
807 let sum: f32 = probs.iter().sum();
808 assert!((sum - 1.0).abs() < 1e-5, "probabilities must sum to 1.0, got {sum}");
809 assert!(probs.iter().all(|&p| p > 0.0), "all probabilities must be positive");
810 }
811
812 #[test]
813 fn clf_002_linear_probe_num_parameters() {
814 let probe = LinearProbe::new(768, 2);
815 assert_eq!(probe.num_parameters(), 768 * 2 + 2); }
817
818 #[test]
819 fn clf_002_linear_probe_train_reduces_loss() {
820 let mut probe = LinearProbe::new(8, 2);
821 let embeddings: Vec<Vec<f32>> = (0..20)
823 .map(|i| {
824 if i < 10 {
825 vec![1.0; 8] } else {
827 vec![-1.0; 8] }
829 })
830 .collect();
831 let labels: Vec<usize> = (0..20).map(|i| usize::from(i >= 10)).collect();
832
833 let loss_before = {
834 let mut temp = LinearProbe::new(8, 2);
835 temp.train(&embeddings, &labels, 1, 0.01, None)
836 };
837 let loss_after = probe.train(&embeddings, &labels, 10, 0.01, None);
838
839 assert!(loss_after < loss_before + 0.5, "training should reduce loss");
841 }
842
843 #[test]
844 fn clf_003_binary_mcc_perfect() {
845 assert!((binary_mcc(50, 50, 0, 0) - 1.0).abs() < 1e-5);
847 }
848
849 #[test]
850 fn clf_003_binary_mcc_random() {
851 assert!(binary_mcc(25, 25, 25, 25).abs() < 1e-5);
853 }
854
855 #[test]
856 fn clf_003_evaluate_perfect() {
857 let preds = vec![0, 0, 1, 1, 1];
858 let labels = vec![0, 0, 1, 1, 1];
859 let metrics = evaluate(&preds, &labels, 2);
860 assert!((metrics.accuracy - 1.0).abs() < 1e-5);
861 assert!((metrics.mcc - 1.0).abs() < 1e-5);
862 }
863
864 #[test]
865 fn clf_003_evaluate_majority_baseline() {
866 let preds = vec![0; 100];
868 let labels: Vec<usize> = (0..100).map(|i| usize::from(i >= 93)).collect();
869 let metrics = evaluate(&preds, &labels, 2);
870 assert!((metrics.accuracy - 0.93).abs() < 0.01);
871 assert_eq!(metrics.recall[1], 0.0); }
873
874 #[test]
875 fn clf_003_bootstrap_ci_contains_estimate() {
876 let preds = vec![0, 0, 1, 1, 0, 1, 0, 0, 1, 1];
877 let labels = vec![0, 0, 1, 1, 0, 0, 0, 1, 1, 1];
878 let ci = bootstrap_mcc_ci(&preds, &labels, 2, 100);
879 assert!(ci.lower <= ci.estimate, "CI lower must be <= estimate");
880 assert!(ci.upper >= ci.estimate, "CI upper must be >= estimate");
881 }
882
883 #[test]
884 fn clf_007_confidence_scores() {
885 let probe = LinearProbe::new(8, 2);
886 let embeddings = vec![vec![0.5; 8], vec![-0.5; 8]];
887 let scores = compute_confidence_scores(&probe, &embeddings);
888 assert_eq!(scores.len(), 2);
889 for score in &scores {
890 assert!(score.confidence > 0.0);
891 assert!(score.confidence <= 1.0);
892 assert_eq!(score.probabilities.len(), 2);
893 let sum: f32 = score.probabilities.iter().sum();
894 assert!((sum - 1.0).abs() < 1e-5);
895 }
896 }
897
898 #[test]
903 fn clf_004_escalate_from_linear_probe_low_mcc() {
904 let ci = BootstrapCI { estimate: 0.15, lower: 0.10, upper: 0.20, n_bootstrap: 100 };
905 let result = should_escalate(EscalationLevel::LinearProbe, &ci, 0.94);
906 assert_eq!(result, Some(EscalationLevel::TopLayers));
907 }
908
909 #[test]
910 fn clf_004_no_escalate_when_ship_gate_met() {
911 let ci = BootstrapCI { estimate: 0.45, lower: 0.30, upper: 0.60, n_bootstrap: 100 };
912 let result = should_escalate(EscalationLevel::LinearProbe, &ci, 0.96);
913 assert_eq!(result, None);
914 }
915
916 #[test]
917 fn clf_004_escalate_from_top_layers_to_full() {
918 let ci = BootstrapCI { estimate: 0.25, lower: 0.15, upper: 0.35, n_bootstrap: 100 };
919 let result = should_escalate(EscalationLevel::TopLayers, &ci, 0.95);
920 assert_eq!(result, Some(EscalationLevel::FullFinetune));
921 }
922
923 #[test]
924 fn clf_004_terminal_level_no_escalation() {
925 let ci = BootstrapCI { estimate: 0.1, lower: 0.05, upper: 0.15, n_bootstrap: 100 };
926 let result = should_escalate(EscalationLevel::ContinuePretrain, &ci, 0.90);
927 assert_eq!(result, None); }
929
930 #[test]
931 fn clf_004_escalate_on_low_accuracy() {
932 let ci = BootstrapCI { estimate: 0.45, lower: 0.30, upper: 0.60, n_bootstrap: 100 };
934 let result = should_escalate(EscalationLevel::LinearProbe, &ci, 0.93);
935 assert_eq!(result, Some(EscalationLevel::TopLayers));
936 }
937
938 #[test]
943 fn clf_005_compare_baselines_beats_majority() {
944 let baselines = vec![("majority", 0.0), ("keyword", 0.4), ("linter", 0.5)];
945 let comparisons = compare_baselines(0.35, &baselines);
946 assert!(comparisons[0].beats_baseline); assert!(!comparisons[1].beats_baseline); assert!(!comparisons[2].beats_baseline); }
950
951 #[test]
952 fn clf_005_compare_baselines_beats_all() {
953 let baselines = vec![("majority", 0.0), ("keyword", 0.4), ("linter", 0.5)];
954 let comparisons = compare_baselines(0.65, &baselines);
955 assert!(comparisons.iter().all(|c| c.beats_baseline));
956 }
957
958 #[test]
963 fn clf_006_generalization_all_detected() {
964 let mut probe = LinearProbe::new(4, 2);
965 let embeddings: Vec<Vec<f32>> =
967 (0..20).map(|i| if i < 10 { vec![1.0; 4] } else { vec![-1.0; 4] }).collect();
968 let labels: Vec<usize> = (0..20).map(|i| usize::from(i >= 10)).collect();
969 probe.train(&embeddings, &labels, 30, 0.1, None);
970
971 let novel = vec![vec![-1.0; 4]; 10]; let result = generalization_test(&probe, &novel, 1);
973 assert_eq!(result.total, 10);
974 assert!(result.passes, "trained probe should detect unsafe-pattern embeddings");
975 }
976
977 #[test]
978 fn clf_006_generalization_empty() {
979 let probe = LinearProbe::new(4, 2);
980 let result = generalization_test(&probe, &[], 1);
981 assert_eq!(result.total, 0);
982 assert_eq!(result.detection_rate, 0.0);
983 }
984
985 #[test]
990 fn clf_ship_gate_passes() {
991 let ci = BootstrapCI { estimate: 0.4, lower: 0.25, upper: 0.55, n_bootstrap: 100 };
992 let gen =
993 GeneralizationResult { total: 50, detected: 30, detection_rate: 0.6, passes: true };
994 let result = check_ship_gate(&ci, 0.96, &gen, EscalationLevel::LinearProbe);
995 assert!(result.ship_ready);
996 assert!(result.mcc_passes);
997 assert!(result.accuracy_passes);
998 assert!(result.generalization_passes);
999 }
1000
1001 #[test]
1002 fn clf_ship_gate_fails_mcc() {
1003 let ci = BootstrapCI { estimate: 0.15, lower: 0.10, upper: 0.20, n_bootstrap: 100 };
1004 let gen =
1005 GeneralizationResult { total: 50, detected: 30, detection_rate: 0.6, passes: true };
1006 let result = check_ship_gate(&ci, 0.96, &gen, EscalationLevel::LinearProbe);
1007 assert!(!result.ship_ready);
1008 assert!(!result.mcc_passes);
1009 }
1010
1011 #[test]
1012 fn clf_ship_gate_fails_generalization() {
1013 let ci = BootstrapCI { estimate: 0.4, lower: 0.25, upper: 0.55, n_bootstrap: 100 };
1014 let gen =
1015 GeneralizationResult { total: 50, detected: 20, detection_rate: 0.4, passes: false };
1016 let result = check_ship_gate(&ci, 0.96, &gen, EscalationLevel::LinearProbe);
1017 assert!(!result.ship_ready);
1018 assert!(!result.generalization_passes);
1019 }
1020
1021 #[test]
1026 fn mlp_probe_forward_shape() {
1027 let probe = MlpProbe::new(768, 128, 2);
1028 let emb = vec![0.1; 768];
1029 let (h, logits) = probe.forward(&emb);
1030 assert_eq!(h.len(), 128);
1031 assert_eq!(logits.len(), 2);
1032 }
1033
1034 #[test]
1035 fn mlp_probe_predict_probs_sum_to_one() {
1036 let probe = MlpProbe::new(64, 32, 3);
1037 let emb = vec![0.5; 64];
1038 let probs = probe.predict_probs(&emb);
1039 assert_eq!(probs.len(), 3);
1040 let sum: f32 = probs.iter().sum();
1041 assert!((sum - 1.0).abs() < 1e-5, "probabilities must sum to 1.0, got {sum}");
1042 }
1043
1044 #[test]
1045 fn mlp_probe_num_parameters() {
1046 let probe = MlpProbe::new(768, 128, 2);
1047 assert_eq!(probe.num_parameters(), 768 * 128 + 128 + 128 * 2 + 2);
1049 }
1050
1051 #[test]
1052 fn mlp_probe_relu_zeros_negative() {
1053 let probe = MlpProbe::new(4, 4, 2);
1054 let emb = vec![-10.0; 4]; let (h, _) = probe.forward(&emb);
1056 assert!(h.iter().all(|&v| v >= 0.0), "ReLU output must be non-negative");
1059 }
1060
1061 #[test]
1062 fn mlp_probe_train_learns_xor() {
1063 let embeddings = vec![vec![0.0, 0.0], vec![0.0, 1.0], vec![1.0, 0.0], vec![1.0, 1.0]];
1065 let labels = vec![0, 1, 1, 0]; let embeddings: Vec<Vec<f32>> = embeddings.iter().cycle().take(40).cloned().collect();
1069 let labels: Vec<usize> = labels.iter().cycle().take(40).copied().collect();
1070
1071 let mut mlp = MlpProbe::new(2, 8, 2);
1072 mlp.train(&embeddings, &labels, 200, 0.1, None, 0.0);
1073
1074 let pred_00 = mlp.predict(&[0.0, 0.0]);
1076 let pred_01 = mlp.predict(&[0.0, 1.0]);
1077 let pred_10 = mlp.predict(&[1.0, 0.0]);
1078 let pred_11 = mlp.predict(&[1.0, 1.0]);
1079
1080 let correct = u8::from(pred_00 == 0)
1082 + u8::from(pred_01 == 1)
1083 + u8::from(pred_10 == 1)
1084 + u8::from(pred_11 == 0);
1085 assert!(correct >= 3, "MLP should learn XOR (got {correct}/4 correct)");
1086 }
1087
1088 #[test]
1089 fn mlp_probe_train_reduces_loss() {
1090 let mut probe = MlpProbe::new(8, 16, 2);
1091 let embeddings: Vec<Vec<f32>> =
1092 (0..20).map(|i| if i < 10 { vec![1.0; 8] } else { vec![-1.0; 8] }).collect();
1093 let labels: Vec<usize> = (0..20).map(|i| usize::from(i >= 10)).collect();
1094
1095 let loss_1 = probe.train(&embeddings, &labels, 1, 0.01, None, 0.0);
1096 let loss_10 = probe.train(&embeddings, &labels, 10, 0.01, None, 0.0);
1097 assert!(loss_10 < loss_1 + 0.5, "training should reduce loss");
1098 }
1099
1100 #[test]
1103 fn test_cov4_multiclass_mcc_perfect_3class() {
1104 let preds = vec![0, 0, 1, 1, 2, 2];
1106 let labels = vec![0, 0, 1, 1, 2, 2];
1107 let metrics = evaluate(&preds, &labels, 3);
1108 assert!((metrics.accuracy - 1.0).abs() < 1e-5);
1109 assert!(metrics.mcc > 0.9, "Perfect 3-class should have high MCC, got {}", metrics.mcc);
1110 }
1111
1112 #[test]
1113 fn test_cov4_multiclass_mcc_random_3class() {
1114 let preds = vec![1, 2, 0, 2, 0, 1];
1116 let labels = vec![0, 0, 1, 1, 2, 2];
1117 let metrics = evaluate(&preds, &labels, 3);
1118 assert!(metrics.mcc < 0.1, "Random 3-class MCC should be near 0, got {}", metrics.mcc);
1119 }
1120
1121 #[test]
1122 fn test_cov4_multiclass_mcc_4class() {
1123 let preds = vec![0, 1, 2, 3, 0, 1, 2, 3];
1124 let labels = vec![0, 1, 2, 3, 0, 1, 2, 3];
1125 let metrics = evaluate(&preds, &labels, 4);
1126 assert!((metrics.mcc - 1.0).abs() < 1e-5);
1127 assert_eq!(metrics.num_samples, 8);
1128 }
1129
1130 #[test]
1131 fn test_cov4_binary_mcc_all_tp() {
1132 assert_eq!(binary_mcc(100, 0, 0, 0), 0.0); }
1135
1136 #[test]
1137 fn test_cov4_binary_mcc_all_tn() {
1138 assert_eq!(binary_mcc(0, 100, 0, 0), 0.0); }
1141
1142 #[test]
1143 fn test_cov4_binary_mcc_worst() {
1144 assert!((binary_mcc(0, 0, 50, 50) - (-1.0)).abs() < 1e-5);
1146 }
1147
1148 #[test]
1149 fn test_cov4_binary_mcc_asymmetric() {
1150 let mcc = binary_mcc(80, 10, 5, 5);
1152 assert!(mcc > 0.0 && mcc < 1.0, "Asymmetric MCC should be between 0 and 1, got {mcc}");
1153 }
1154
1155 #[test]
1156 fn test_cov4_evaluate_all_same_prediction() {
1157 let preds = vec![0, 0, 0, 0, 0];
1159 let labels = vec![0, 0, 1, 1, 1];
1160 let metrics = evaluate(&preds, &labels, 2);
1161 assert!((metrics.accuracy - 0.4).abs() < 1e-5);
1162 assert_eq!(metrics.recall[0], 1.0); assert_eq!(metrics.recall[1], 0.0); }
1165
1166 #[test]
1167 fn test_cov4_evaluate_empty() {
1168 let metrics = evaluate(&[], &[], 2);
1169 assert_eq!(metrics.num_samples, 0);
1170 assert!((metrics.accuracy - 0.0).abs() < 1e-5);
1171 }
1172
1173 #[test]
1174 fn test_cov4_evaluate_precision() {
1175 let preds = vec![0, 0, 1, 1, 1];
1176 let labels = vec![0, 1, 1, 1, 0];
1177 let metrics = evaluate(&preds, &labels, 2);
1178 assert!((metrics.precision[0] - 0.5).abs() < 1e-5);
1180 assert!((metrics.precision[1] - 2.0 / 3.0).abs() < 1e-5);
1182 }
1183
1184 #[test]
1185 fn test_cov4_evaluate_confusion_matrix() {
1186 let preds = vec![0, 1, 0, 1];
1187 let labels = vec![0, 1, 1, 0];
1188 let metrics = evaluate(&preds, &labels, 2);
1189 assert_eq!(metrics.confusion_matrix[0][0], 1); assert_eq!(metrics.confusion_matrix[0][1], 1); assert_eq!(metrics.confusion_matrix[1][0], 1); assert_eq!(metrics.confusion_matrix[1][1], 1); }
1195
1196 #[test]
1197 fn test_cov4_evaluate_out_of_bounds_ignored() {
1198 let preds = vec![0, 1, 5]; let labels = vec![0, 1, 0];
1201 let metrics = evaluate(&preds, &labels, 2);
1202 assert_eq!(metrics.num_samples, 3);
1203 assert_eq!(metrics.confusion_matrix[0][0], 1);
1205 assert_eq!(metrics.confusion_matrix[1][1], 1);
1206 }
1207
1208 #[test]
1209 fn test_cov4_bootstrap_ci_deterministic() {
1210 let preds = vec![0, 0, 1, 1, 0, 1, 0, 0, 1, 1];
1211 let labels = vec![0, 0, 1, 1, 0, 0, 0, 1, 1, 1];
1212 let ci1 = bootstrap_mcc_ci(&preds, &labels, 2, 50);
1213 let ci2 = bootstrap_mcc_ci(&preds, &labels, 2, 50);
1214 assert!((ci1.lower - ci2.lower).abs() < 1e-5);
1216 assert!((ci1.upper - ci2.upper).abs() < 1e-5);
1217 }
1218
1219 #[test]
1220 fn test_cov4_bootstrap_ci_bounds() {
1221 let preds = vec![0, 0, 1, 1, 0, 1];
1222 let labels = vec![0, 0, 1, 1, 0, 1];
1223 let ci = bootstrap_mcc_ci(&preds, &labels, 2, 200);
1224 assert!(ci.lower <= ci.upper);
1225 assert!(ci.lower >= -1.0);
1226 assert!(ci.upper <= 1.0);
1227 assert_eq!(ci.n_bootstrap, 200);
1228 }
1229
1230 #[test]
1231 fn test_cov4_confidence_scores_deterministic() {
1232 let probe = LinearProbe::new(8, 2);
1233 let embs = vec![vec![0.5; 8], vec![-0.5; 8]];
1234 let scores1 = compute_confidence_scores(&probe, &embs);
1235 let scores2 = compute_confidence_scores(&probe, &embs);
1236 for (s1, s2) in scores1.iter().zip(scores2.iter()) {
1237 assert_eq!(s1.predicted_class, s2.predicted_class);
1238 assert!((s1.confidence - s2.confidence).abs() < 1e-6);
1239 }
1240 }
1241
1242 #[test]
1243 fn test_cov4_confidence_scores_empty() {
1244 let probe = LinearProbe::new(8, 2);
1245 let scores = compute_confidence_scores(&probe, &[]);
1246 assert!(scores.is_empty());
1247 }
1248
1249 #[test]
1250 fn test_cov4_escalation_display() {
1251 assert_eq!(format!("{}", EscalationLevel::LinearProbe), "Level 0: Linear probe");
1252 assert_eq!(format!("{}", EscalationLevel::TopLayers), "Level 1: Top-2 layers + head");
1253 assert_eq!(format!("{}", EscalationLevel::FullFinetune), "Level 2: Full fine-tune");
1254 assert_eq!(
1255 format!("{}", EscalationLevel::ContinuePretrain),
1256 "Level 3: Continue-pretrain + fine-tune"
1257 );
1258 }
1259
1260 #[test]
1261 fn test_cov4_escalation_debug_clone() {
1262 let level = EscalationLevel::TopLayers;
1263 let cloned = level;
1264 assert_eq!(level, cloned);
1265 assert!(format!("{level:?}").contains("TopLayers"));
1266 }
1267
1268 #[test]
1269 fn test_cov4_escalate_full_to_continue() {
1270 let ci = BootstrapCI { estimate: 0.2, lower: 0.1, upper: 0.3, n_bootstrap: 100 };
1271 let result = should_escalate(EscalationLevel::FullFinetune, &ci, 0.95);
1272 assert_eq!(result, Some(EscalationLevel::ContinuePretrain));
1273 }
1274
1275 #[test]
1276 fn test_cov4_escalate_full_no_escalate() {
1277 let ci = BootstrapCI { estimate: 0.5, lower: 0.4, upper: 0.6, n_bootstrap: 100 };
1278 let result = should_escalate(EscalationLevel::FullFinetune, &ci, 0.96);
1279 assert_eq!(result, None);
1280 }
1281
1282 #[test]
1283 fn test_cov4_escalate_top_layers_no_escalate() {
1284 let ci = BootstrapCI { estimate: 0.5, lower: 0.35, upper: 0.65, n_bootstrap: 100 };
1285 let result = should_escalate(EscalationLevel::TopLayers, &ci, 0.96);
1286 assert_eq!(result, None);
1287 }
1288
1289 #[test]
1290 fn test_cov4_compare_baselines_details() {
1291 let comps = compare_baselines(0.5, &[("majority", 0.0), ("keyword", 0.5), ("linter", 0.6)]);
1292 assert_eq!(comps[0].name, "majority");
1293 assert!(comps[0].beats_baseline);
1294 assert!(!comps[1].beats_baseline); assert!(!comps[2].beats_baseline);
1296 assert!((comps[0].model_mcc - 0.5).abs() < 1e-5);
1297 assert!((comps[0].baseline_mcc - 0.0).abs() < 1e-5);
1298 }
1299
1300 #[test]
1301 fn test_cov4_compare_baselines_empty() {
1302 let comps = compare_baselines(0.5, &[]);
1303 assert!(comps.is_empty());
1304 }
1305
1306 #[test]
1307 fn test_cov4_generalization_result_fields() {
1308 let probe = LinearProbe::new(4, 2);
1309 let embs: Vec<Vec<f32>> = (0..5).map(|_| vec![0.0; 4]).collect();
1310 let result = generalization_test(&probe, &embs, 1);
1311 assert_eq!(result.total, 5);
1312 assert!(result.detected <= 5);
1313 assert!((result.detection_rate - result.detected as f32 / 5.0).abs() < 1e-5);
1314 }
1315
1316 #[test]
1317 fn test_cov4_ship_gate_all_fail() {
1318 let ci = BootstrapCI { estimate: 0.1, lower: 0.05, upper: 0.15, n_bootstrap: 100 };
1319 let gen =
1320 GeneralizationResult { total: 50, detected: 10, detection_rate: 0.2, passes: false };
1321 let result = check_ship_gate(&ci, 0.90, &gen, EscalationLevel::LinearProbe);
1322 assert!(!result.ship_ready);
1323 assert!(!result.mcc_passes);
1324 assert!(!result.accuracy_passes);
1325 assert!(!result.generalization_passes);
1326 assert_eq!(result.level, EscalationLevel::LinearProbe);
1327 }
1328
1329 #[test]
1330 fn test_cov4_ship_gate_fails_accuracy() {
1331 let ci = BootstrapCI { estimate: 0.4, lower: 0.25, upper: 0.55, n_bootstrap: 100 };
1332 let gen =
1333 GeneralizationResult { total: 50, detected: 30, detection_rate: 0.6, passes: true };
1334 let result = check_ship_gate(&ci, 0.90, &gen, EscalationLevel::TopLayers);
1335 assert!(!result.ship_ready);
1336 assert!(result.mcc_passes);
1337 assert!(!result.accuracy_passes);
1338 assert!(result.generalization_passes);
1339 assert_eq!(result.level, EscalationLevel::TopLayers);
1340 }
1341
1342 #[test]
1343 fn test_cov4_linear_probe_predict() {
1344 let probe = LinearProbe::new(8, 3);
1345 let emb = Tensor::from_vec(vec![0.5; 8], false);
1346 let predicted = probe.predict(&emb);
1347 assert!(predicted < 3);
1348 }
1349
1350 #[test]
1351 fn test_cov4_linear_probe_num_classes() {
1352 let probe = LinearProbe::new(64, 5);
1353 assert_eq!(probe.num_classes(), 5);
1354 }
1355
1356 #[test]
1357 fn test_cov4_linear_probe_train_with_class_weights() {
1358 let mut probe = LinearProbe::new(4, 2);
1359 let embeddings =
1360 vec![vec![1.0; 4]; 10].into_iter().chain(vec![vec![-1.0; 4]; 10]).collect::<Vec<_>>();
1361 let labels: Vec<usize> = (0..20).map(|i| usize::from(i >= 10)).collect();
1362 let weights = vec![1.0, 5.0]; let loss = probe.train(&embeddings, &labels, 5, 0.01, Some(&weights));
1365 assert!(loss.is_finite());
1366 }
1367
1368 #[test]
1369 fn test_cov4_mlp_probe_predict() {
1370 let probe = MlpProbe::new(8, 16, 3);
1371 let emb = vec![0.1; 8];
1372 let predicted = probe.predict(&emb);
1373 assert!(predicted < 3);
1374 }
1375
1376 #[test]
1377 fn test_cov4_mlp_probe_predict_probs_all_positive() {
1378 let probe = MlpProbe::new(4, 8, 2);
1379 let probs = probe.predict_probs(&[0.5, -0.5, 1.0, -1.0]);
1380 assert!(probs.iter().all(|&p| p > 0.0));
1381 assert!(probs.iter().all(|&p| p <= 1.0));
1382 }
1383
1384 #[test]
1385 fn test_cov4_mlp_probe_num_parameters() {
1386 let probe = MlpProbe::new(16, 8, 3);
1387 assert_eq!(probe.num_parameters(), 16 * 8 + 8 + 8 * 3 + 3);
1389 }
1390
1391 #[test]
1392 fn test_cov4_mlp_probe_train_with_class_weights() {
1393 let mut probe = MlpProbe::new(4, 8, 2);
1394 let embeddings: Vec<Vec<f32>> =
1395 (0..20).map(|i| if i < 10 { vec![1.0; 4] } else { vec![-1.0; 4] }).collect();
1396 let labels: Vec<usize> = (0..20).map(|i| usize::from(i >= 10)).collect();
1397 let weights = vec![1.0, 5.0];
1398
1399 let loss = probe.train(&embeddings, &labels, 5, 0.01, Some(&weights), 0.0);
1400 assert!(loss.is_finite());
1401 }
1402
1403 #[test]
1404 fn test_cov4_mlp_probe_train_with_weight_decay() {
1405 let mut probe = MlpProbe::new(4, 8, 2);
1406 let embeddings: Vec<Vec<f32>> =
1407 (0..10).map(|i| if i < 5 { vec![1.0; 4] } else { vec![-1.0; 4] }).collect();
1408 let labels: Vec<usize> = (0..10).map(|i| usize::from(i >= 5)).collect();
1409
1410 let loss = probe.train(&embeddings, &labels, 5, 0.01, None, 0.01);
1411 assert!(loss.is_finite());
1412 }
1413
1414 #[test]
1415 fn test_cov4_softmax_slice_single() {
1416 let result = softmax_slice(&[0.0]);
1417 assert_eq!(result.len(), 1);
1418 assert!((result[0] - 1.0).abs() < 1e-5);
1419 }
1420
1421 #[test]
1422 fn test_cov4_softmax_slice_large_values() {
1423 let result = softmax_slice(&[1000.0, 1001.0]);
1425 assert_eq!(result.len(), 2);
1426 let sum: f32 = result.iter().sum();
1427 assert!((sum - 1.0).abs() < 1e-5);
1428 assert!(result[1] > result[0]); }
1430
1431 #[test]
1432 fn test_cov4_softmax_slice_equal() {
1433 let result = softmax_slice(&[1.0, 1.0, 1.0]);
1434 for &p in &result {
1435 assert!((p - 1.0 / 3.0).abs() < 1e-5);
1436 }
1437 }
1438
1439 #[test]
1440 fn test_cov4_classification_metrics_clone() {
1441 let m = ClassificationMetrics {
1442 mcc: 0.5,
1443 accuracy: 0.9,
1444 recall: vec![0.8, 0.7],
1445 precision: vec![0.85, 0.75],
1446 num_samples: 100,
1447 confusion_matrix: vec![vec![40, 10], vec![5, 45]],
1448 };
1449 let m2 = m.clone();
1450 assert!((m2.mcc - 0.5).abs() < 1e-5);
1451 assert_eq!(m2.num_samples, 100);
1452 assert!(format!("{m2:?}").contains("ClassificationMetrics"));
1453 }
1454
1455 #[test]
1456 fn test_cov4_bootstrap_ci_clone() {
1457 let ci = BootstrapCI { estimate: 0.5, lower: 0.3, upper: 0.7, n_bootstrap: 1000 };
1458 let ci2 = ci;
1459 assert!((ci2.estimate - 0.5).abs() < 1e-5);
1460 assert!(format!("{ci:?}").contains("BootstrapCI"));
1461 }
1462
1463 #[test]
1464 fn test_cov4_confidence_score_clone() {
1465 let s =
1466 ConfidenceScore { predicted_class: 1, confidence: 0.8, probabilities: vec![0.2, 0.8] };
1467 let s2 = s.clone();
1468 assert_eq!(s2.predicted_class, 1);
1469 assert!((s2.confidence - 0.8).abs() < 1e-5);
1470 assert!(format!("{s2:?}").contains("ConfidenceScore"));
1471 }
1472
1473 #[test]
1474 fn test_cov4_generalization_result_clone() {
1475 let r =
1476 GeneralizationResult { total: 20, detected: 15, detection_rate: 0.75, passes: true };
1477 let r2 = r.clone();
1478 assert!(r2.passes);
1479 assert_eq!(r2.total, 20);
1480 assert!(format!("{r2:?}").contains("GeneralizationResult"));
1481 }
1482
1483 #[test]
1484 fn test_cov4_baseline_comparison_clone() {
1485 let b = BaselineComparison {
1486 name: "test".to_string(),
1487 baseline_mcc: 0.3,
1488 model_mcc: 0.5,
1489 beats_baseline: true,
1490 };
1491 let b2 = b.clone();
1492 assert!(b2.beats_baseline);
1493 assert!(format!("{b2:?}").contains("BaselineComparison"));
1494 }
1495
1496 #[test]
1497 fn test_cov4_ship_gate_result_clone() {
1498 let r = ShipGateResult {
1499 mcc_passes: true,
1500 accuracy_passes: true,
1501 generalization_passes: false,
1502 ship_ready: false,
1503 level: EscalationLevel::LinearProbe,
1504 };
1505 let r2 = r.clone();
1506 assert!(!r2.ship_ready);
1507 assert!(format!("{r2:?}").contains("ShipGateResult"));
1508 }
1509
1510 #[test]
1511 fn test_cov4_multiclass_mcc_single_class() {
1512 let preds = vec![0, 0, 0, 0];
1514 let labels = vec![0, 0, 0, 0];
1515 let metrics = evaluate(&preds, &labels, 3);
1516 assert!((metrics.accuracy - 1.0).abs() < 1e-5);
1517 assert!(metrics.mcc.abs() < 1e-5 || metrics.mcc.is_finite());
1519 }
1520}