1use crate::error::{SslError, SslResult};
21use crate::handle::LcgRng;
22
/// Hyper-parameters for fitting and cross-validating a linear probe.
#[derive(Clone, Debug)]
pub struct LinearProbeConfig {
    /// Number of target classes; labels must lie in `0..n_classes`.
    pub n_classes: usize,
    /// Number of cross-validation folds used by `linear_probe_eval`.
    pub n_folds: usize,
    /// Maximum number of IRLS iterations for each one-vs-rest classifier.
    pub max_iter: usize,
    /// IRLS stops once the relative weight-update norm falls below this value.
    pub tol: f64,
    /// Ridge (L2) penalty added to the diagonal of the IRLS normal matrix.
    pub l2_reg: f64,
    /// Seed for the RNG that shuffles samples before fold assignment.
    pub seed: u64,
}

impl Default for LinearProbeConfig {
    /// Binary problem, 5 folds, 200 IRLS iterations, `tol = 1e-5`,
    /// `l2_reg = 1e-3`, seed 42.
    fn default() -> Self {
        LinearProbeConfig {
            seed: 42,
            l2_reg: 1e-3,
            tol: 1e-5,
            max_iter: 200,
            n_folds: 5,
            n_classes: 2,
        }
    }
}
54
/// Cross-validated summary produced by `linear_probe_eval`.
#[derive(Clone, Debug)]
pub struct LinearProbeResult {
    /// Unweighted mean of `per_fold_accuracy`.
    pub mean_accuracy: f64,
    /// Population standard deviation (divides by the fold count) of the fold accuracies.
    pub std_accuracy: f64,
    /// Validation accuracy of each fold, in fold order.
    pub per_fold_accuracy: Vec<f64>,
    /// Unweighted mean of `per_class_f1`.
    pub macro_f1: f64,
    /// F1 score of every class, averaged over the folds.
    pub per_class_f1: Vec<f64>,
}
71
/// One-vs-rest logistic-regression probe returned by `linear_probe_fit`.
#[derive(Clone, Debug)]
pub struct FittedLinearProbe {
    /// Row-major `n_classes × (in_dim + 1)` weights; the last entry of each row
    /// is the bias (the weight of the implicit constant-1 feature).
    pub weights: Vec<f64>,
    /// Feature dimension the probe was trained on, excluding the bias column.
    pub in_dim: usize,
    /// Number of classes, i.e. number of weight rows.
    pub n_classes: usize,
    /// IRLS iterations performed for each class's binary classifier.
    pub n_iter: Vec<usize>,
    /// Whether each class's IRLS run reached the tolerance within `max_iter`.
    pub converged: Vec<bool>,
}
87
/// Numerically stable logistic function `1 / (1 + e^{-x})`.
///
/// For negative inputs the algebraically equal form `e^x / (1 + e^x)` is used,
/// so `exp` is only ever evaluated on a non-positive argument and cannot overflow.
#[inline]
fn sigmoid(x: f64) -> f64 {
    if x < 0.0 {
        let e = x.exp();
        return e / (1.0 + e);
    }
    (1.0 + (-x).exp()).recip()
}
100
101fn cholesky_solve(a: &[f64], b: &[f64], n: usize) -> SslResult<Vec<f64>> {
106 debug_assert_eq!(a.len(), n * n);
107 debug_assert_eq!(b.len(), n);
108
109 let mut l = vec![0.0_f64; n * n];
111 for i in 0..n {
112 for j in 0..=i {
113 let mut s = a[i * n + j];
114 for k in 0..j {
115 s -= l[i * n + k] * l[j * n + k];
116 }
117 if i == j {
118 if s <= 0.0 {
119 return Err(SslError::Internal(
120 "cholesky_solve: matrix not positive-definite".into(),
121 ));
122 }
123 l[i * n + j] = s.sqrt();
124 } else {
125 l[i * n + j] = s / l[j * n + j];
126 }
127 }
128 }
129
130 let mut y = vec![0.0_f64; n];
132 for i in 0..n {
133 let mut s = b[i];
134 for k in 0..i {
135 s -= l[i * n + k] * y[k];
136 }
137 y[i] = s / l[i * n + i];
138 }
139
140 let mut x = vec![0.0_f64; n];
142 for i in (0..n).rev() {
143 let mut s = y[i];
144 for k in (i + 1)..n {
145 s -= l[k * n + i] * x[k];
146 }
147 x[i] = s / l[i * n + i];
148 }
149
150 Ok(x)
151}
152
/// Fraction of positions where `predicted` agrees with `truth`.
///
/// Pairs are compared up to the shorter slice, but the count is divided by
/// `predicted.len()`; an empty `predicted` yields `0.0`.
fn accuracy(predicted: &[usize], truth: &[usize]) -> f64 {
    let total = predicted.len();
    if total == 0 {
        return 0.0;
    }
    let mut hits = 0usize;
    for (p, t) in predicted.iter().zip(truth) {
        if p == t {
            hits += 1;
        }
    }
    hits as f64 / total as f64
}
165
/// Per-class F1 score, `tp / (tp + (fp + fn) / 2)`, for classes `0..n_classes`.
///
/// Pairs in which either the prediction or the truth is `>= n_classes` are
/// ignored. A class with no true positives, false positives or false negatives
/// (near-zero denominator) scores `0.0`.
fn f1_per_class(predicted: &[usize], truth: &[usize], n_classes: usize) -> Vec<f64> {
    // (true positives, false positives, false negatives) for each class.
    let mut counts = vec![(0usize, 0usize, 0usize); n_classes];
    for (&pred, &actual) in predicted.iter().zip(truth) {
        if pred >= n_classes || actual >= n_classes {
            continue;
        }
        if pred == actual {
            counts[pred].0 += 1;
        } else {
            counts[pred].1 += 1;
            counts[actual].2 += 1;
        }
    }

    let mut scores = Vec::with_capacity(n_classes);
    for &(tp, fp, fn_) in &counts {
        let denom = tp as f64 + 0.5 * (fp + fn_) as f64;
        scores.push(if denom < 1e-12 { 0.0 } else { tp as f64 / denom });
    }
    scores
}
194
195fn fisher_yates_shuffle(indices: &mut [usize], rng: &mut LcgRng) {
197 rng.shuffle(indices);
198}
199
200fn irls_binary(
210 x_aug: &[f64],
211 y_bin: &[f64],
212 n: usize,
213 d_aug: usize,
214 config: &LinearProbeConfig,
215) -> SslResult<(Vec<f64>, usize, bool)> {
216 const EPS: f64 = 1e-7;
217
218 let mut w = vec![0.0_f64; d_aug];
219 let mut iters_done = 0usize;
220 let mut converged = false;
221
222 for iter in 0..config.max_iter {
223 let mut p_vec = vec![0.0_f64; n];
225 for i in 0..n {
226 let row = &x_aug[i * d_aug..(i + 1) * d_aug];
227 let eta_i: f64 = row.iter().zip(w.iter()).map(|(&xi, &wi)| xi * wi).sum();
228 p_vec[i] = sigmoid(eta_i).clamp(EPS, 1.0 - EPS);
229 }
230
231 let mut eta_vec = vec![0.0_f64; n];
235 for i in 0..n {
236 let row = &x_aug[i * d_aug..(i + 1) * d_aug];
237 eta_vec[i] = row.iter().zip(w.iter()).map(|(&xi, &wi)| xi * wi).sum();
238 }
239
240 let mut xtwx = vec![0.0_f64; d_aug * d_aug];
243 let mut xtwz = vec![0.0_f64; d_aug];
244
245 for i in 0..n {
246 let p_i = p_vec[i];
247 let w_i = p_i * (1.0 - p_i); let z_i = eta_vec[i] + (y_bin[i] - p_i) / w_i;
249 let row = &x_aug[i * d_aug..(i + 1) * d_aug];
250
251 for r in 0..d_aug {
253 let val_r = w_i * row[r];
254 for c in 0..d_aug {
255 xtwx[r * d_aug + c] += val_r * row[c];
256 }
257 xtwz[r] += val_r * z_i;
258 }
259 }
260
261 for j in 0..d_aug {
263 xtwx[j * d_aug + j] += config.l2_reg;
264 }
265
266 let w_new = cholesky_solve(&xtwx, &xtwz, d_aug)?;
268
269 let delta_norm: f64 = w_new
271 .iter()
272 .zip(w.iter())
273 .map(|(&a, &b)| (a - b) * (a - b))
274 .sum::<f64>()
275 .sqrt();
276 let w_norm: f64 = w.iter().map(|&v| v * v).sum::<f64>().sqrt();
277 let rel = delta_norm / w_norm.max(1.0);
278
279 w = w_new;
280 iters_done = iter + 1;
281
282 if rel < config.tol {
283 converged = true;
284 break;
285 }
286 }
287
288 for &v in &w {
290 if v.is_nan() {
291 return Err(SslError::NanEncountered {
292 location: "irls_binary weight",
293 });
294 }
295 }
296
297 Ok((w, iters_done, converged))
298}
299
300pub fn linear_probe_fit(
312 features: &[f64],
313 labels: &[usize],
314 n_samples: usize,
315 in_dim: usize,
316 config: &LinearProbeConfig,
317) -> SslResult<FittedLinearProbe> {
318 if n_samples == 0 {
320 return Err(SslError::EmptyInput);
321 }
322 if in_dim == 0 {
323 return Err(SslError::InvalidParameter {
324 name: "in_dim".into(),
325 reason: "feature dimension must be > 0".into(),
326 });
327 }
328 if config.n_classes < 2 {
329 return Err(SslError::InvalidParameter {
330 name: "n_classes".into(),
331 reason: "must be >= 2".into(),
332 });
333 }
334 if config.l2_reg < 0.0 || !config.l2_reg.is_finite() {
335 return Err(SslError::InvalidParameter {
336 name: "l2_reg".into(),
337 reason: "must be non-negative and finite".into(),
338 });
339 }
340 if features.len() != n_samples * in_dim {
341 return Err(SslError::DimensionMismatch {
342 expected: n_samples * in_dim,
343 got: features.len(),
344 });
345 }
346 if labels.len() != n_samples {
347 return Err(SslError::DimensionMismatch {
348 expected: n_samples,
349 got: labels.len(),
350 });
351 }
352 for (i, &lbl) in labels.iter().enumerate() {
353 if lbl >= config.n_classes {
354 return Err(SslError::InvalidParameter {
355 name: "labels".into(),
356 reason: format!(
357 "label {} at index {} is out of range [0, {})",
358 lbl, i, config.n_classes
359 ),
360 });
361 }
362 }
363
364 let d_aug = in_dim + 1;
366 let mut x_aug = vec![0.0_f64; n_samples * d_aug];
367 for i in 0..n_samples {
368 let src = &features[i * in_dim..(i + 1) * in_dim];
369 let dst = &mut x_aug[i * d_aug..(i + 1) * d_aug];
370 dst[..in_dim].copy_from_slice(src);
371 dst[in_dim] = 1.0; }
373
374 for (j, &v) in x_aug.iter().enumerate() {
376 if !v.is_finite() {
377 let sample = j / d_aug;
378 let _ = sample; return Err(SslError::NanEncountered {
380 location: "features (augmented)",
381 });
382 }
383 }
384
385 let mut all_weights = vec![0.0_f64; config.n_classes * d_aug];
387 let mut n_iter_per_class = vec![0usize; config.n_classes];
388 let mut converged_per_class = vec![false; config.n_classes];
389
390 for k in 0..config.n_classes {
391 let y_bin: Vec<f64> = labels
392 .iter()
393 .map(|&lbl| if lbl == k { 1.0 } else { 0.0 })
394 .collect();
395
396 let (w_k, iters, conv) = irls_binary(&x_aug, &y_bin, n_samples, d_aug, config)?;
397
398 all_weights[k * d_aug..(k + 1) * d_aug].copy_from_slice(&w_k);
399 n_iter_per_class[k] = iters;
400 converged_per_class[k] = conv;
401 }
402
403 Ok(FittedLinearProbe {
404 weights: all_weights,
405 in_dim,
406 n_classes: config.n_classes,
407 n_iter: n_iter_per_class,
408 converged: converged_per_class,
409 })
410}
411
412pub fn linear_probe_predict(
419 probe: &FittedLinearProbe,
420 features: &[f64],
421 n_samples: usize,
422) -> SslResult<Vec<usize>> {
423 let d_aug = probe.in_dim + 1;
424
425 if features.len() != n_samples * probe.in_dim {
426 return Err(SslError::DimensionMismatch {
427 expected: n_samples * probe.in_dim,
428 got: features.len(),
429 });
430 }
431
432 let mut predictions = vec![0usize; n_samples];
433 for i in 0..n_samples {
434 let src = &features[i * probe.in_dim..(i + 1) * probe.in_dim];
435
436 let mut x_aug = vec![0.0_f64; d_aug];
438 x_aug[..probe.in_dim].copy_from_slice(src);
439 x_aug[probe.in_dim] = 1.0;
440
441 let mut best_class = 0usize;
443 let mut best_score = f64::NEG_INFINITY;
444 for k in 0..probe.n_classes {
445 let w_k = &probe.weights[k * d_aug..(k + 1) * d_aug];
446 let eta: f64 = w_k.iter().zip(x_aug.iter()).map(|(&w, &x)| w * x).sum();
447 let score = sigmoid(eta);
448 if score > best_score {
449 best_score = score;
450 best_class = k;
451 }
452 }
453 predictions[i] = best_class;
454 }
455
456 Ok(predictions)
457}
458
459pub fn linear_probe_eval(
468 features: &[f64],
469 labels: &[usize],
470 n_samples: usize,
471 in_dim: usize,
472 config: &LinearProbeConfig,
473) -> SslResult<LinearProbeResult> {
474 if n_samples == 0 {
475 return Err(SslError::EmptyInput);
476 }
477 if config.n_folds < 2 {
478 return Err(SslError::InvalidParameter {
479 name: "n_folds".into(),
480 reason: "must be >= 2".into(),
481 });
482 }
483 if n_samples < config.n_folds {
484 return Err(SslError::BatchTooSmall);
485 }
486
487 let mut indices: Vec<usize> = (0..n_samples).collect();
489 let mut rng = LcgRng::new(config.seed);
490 fisher_yates_shuffle(&mut indices, &mut rng);
491
492 let fold_size = n_samples / config.n_folds;
495 let mut fold_starts = Vec::with_capacity(config.n_folds + 1);
496 for f in 0..config.n_folds {
497 fold_starts.push(f * fold_size);
498 }
499 fold_starts.push(n_samples); let mut per_fold_accuracy = Vec::with_capacity(config.n_folds);
503 let mut per_class_f1_sum = vec![0.0_f64; config.n_classes];
505
506 for fold_idx in 0..config.n_folds {
507 let val_start = fold_starts[fold_idx];
508 let val_end = fold_starts[fold_idx + 1];
509
510 let val_indices: Vec<usize> = indices[val_start..val_end].to_vec();
512 let train_indices: Vec<usize> = indices[..val_start]
513 .iter()
514 .chain(&indices[val_end..])
515 .copied()
516 .collect();
517
518 let n_train = train_indices.len();
519 let n_val = val_indices.len();
520
521 if n_train == 0 || n_val == 0 {
522 return Err(SslError::BatchTooSmall);
523 }
524
525 let mut train_feat = vec![0.0_f64; n_train * in_dim];
527 let mut train_lbl = vec![0usize; n_train];
528 for (out_i, &src_i) in train_indices.iter().enumerate() {
529 train_feat[out_i * in_dim..(out_i + 1) * in_dim]
530 .copy_from_slice(&features[src_i * in_dim..(src_i + 1) * in_dim]);
531 train_lbl[out_i] = labels[src_i];
532 }
533
534 let mut val_feat = vec![0.0_f64; n_val * in_dim];
535 let mut val_lbl = vec![0usize; n_val];
536 for (out_i, &src_i) in val_indices.iter().enumerate() {
537 val_feat[out_i * in_dim..(out_i + 1) * in_dim]
538 .copy_from_slice(&features[src_i * in_dim..(src_i + 1) * in_dim]);
539 val_lbl[out_i] = labels[src_i];
540 }
541
542 let probe = linear_probe_fit(&train_feat, &train_lbl, n_train, in_dim, config)?;
544 let preds = linear_probe_predict(&probe, &val_feat, n_val)?;
545
546 let fold_acc = accuracy(&preds, &val_lbl);
548 per_fold_accuracy.push(fold_acc);
549
550 let f1s = f1_per_class(&preds, &val_lbl, config.n_classes);
551 for (k, &f1_k) in f1s.iter().enumerate() {
552 per_class_f1_sum[k] += f1_k;
553 }
554 }
555
556 let mean_accuracy = per_fold_accuracy.iter().sum::<f64>() / config.n_folds as f64;
558
559 let variance = per_fold_accuracy
560 .iter()
561 .map(|&a| {
562 let d = a - mean_accuracy;
563 d * d
564 })
565 .sum::<f64>()
566 / config.n_folds as f64;
567 let std_accuracy = variance.sqrt();
568
569 let per_class_f1: Vec<f64> = per_class_f1_sum
570 .iter()
571 .map(|&s| s / config.n_folds as f64)
572 .collect();
573
574 let macro_f1 = per_class_f1.iter().sum::<f64>() / config.n_classes as f64;
575
576 Ok(LinearProbeResult {
577 mean_accuracy,
578 std_accuracy,
579 per_fold_accuracy,
580 macro_f1,
581 per_class_f1,
582 })
583}
584
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds `n` rows of width `dim`. Rows `n/2..n` carry `offset` in column 0
    /// and label 1; every other entry is 0 with label 0.
    fn make_binary_separable(n: usize, dim: usize, offset: f64) -> (Vec<f64>, Vec<usize>) {
        let half = n / 2;
        let mut feats = vec![0.0_f64; n * dim];
        let mut lbls = vec![0usize; n];
        for i in half..n {
            feats[i * dim] = offset;
            lbls[i] = 1;
        }
        (feats, lbls)
    }

    /// Builds three classes of `n_per_class` rows each; class `k` carries
    /// `(k + 1) * 20` in column `min(k, dim - 1)` and zeros elsewhere.
    fn make_multiclass_separable(n_per_class: usize, dim: usize) -> (Vec<f64>, Vec<usize>) {
        let n = n_per_class * 3;
        let mut feats = vec![0.0_f64; n * dim];
        let mut lbls = vec![0usize; n];
        for k in 0..3usize {
            for i in 0..n_per_class {
                let row = k * n_per_class + i;
                feats[row * dim + k.min(dim - 1)] = (k + 1) as f64 * 20.0;
                lbls[row] = k;
            }
        }
        (feats, lbls)
    }

    #[test]
    fn config_defaults() {
        let cfg = LinearProbeConfig::default();
        assert_eq!(cfg.n_folds, 5);
        assert_eq!(cfg.max_iter, 200);
        assert!((cfg.l2_reg - 1e-3).abs() < 1e-15);
        assert_eq!(cfg.n_classes, 2);
        assert!((cfg.tol - 1e-5).abs() < 1e-18);
        assert_eq!(cfg.seed, 42);
    }

    #[test]
    fn sigmoid_stable() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-15);
        assert!((sigmoid(100.0) - 1.0).abs() < 1e-6);
        assert!(sigmoid(-100.0) < 1e-6);
        // Extreme inputs must not overflow `exp`.
        assert!(sigmoid(f64::MAX / 2.0).is_finite());
        assert!(sigmoid(f64::MIN / 2.0).is_finite());
    }

    #[test]
    fn accuracy_fraction_and_empty() {
        assert_eq!(accuracy(&[], &[]), 0.0);
        assert!((accuracy(&[1, 0, 1, 1], &[1, 1, 1, 0]) - 0.5).abs() < 1e-15);
    }

    #[test]
    fn f1_per_class_known_values() {
        // tp = [2, 1], fp = [0, 1], fn = [1, 0].
        let f1 = f1_per_class(&[0, 1, 1, 0], &[0, 1, 0, 0], 2);
        assert!((f1[0] - 0.8).abs() < 1e-12);
        assert!((f1[1] - 2.0 / 3.0).abs() < 1e-12);
        // Classes that never occur score exactly zero.
        assert_eq!(f1_per_class(&[0, 0], &[0, 0], 3), vec![1.0, 0.0, 0.0]);
    }

    #[test]
    fn fit_empty_error() {
        let cfg = LinearProbeConfig::default();
        let result = linear_probe_fit(&[], &[], 0, 4, &cfg);
        assert!(matches!(result, Err(SslError::EmptyInput)));
    }

    #[test]
    fn fit_single_class_error() {
        let cfg = LinearProbeConfig {
            n_classes: 1,
            ..Default::default()
        };
        let feats = vec![0.0_f64; 10 * 4];
        let lbls = vec![0usize; 10];
        let result = linear_probe_fit(&feats, &lbls, 10, 4, &cfg);
        assert!(matches!(result, Err(SslError::InvalidParameter { .. })));
    }

    #[test]
    fn fit_rejects_out_of_range_label() {
        let cfg = LinearProbeConfig::default();
        let feats = vec![0.0_f64; 3 * 2];
        let lbls = vec![0usize, 1, 2];
        let result = linear_probe_fit(&feats, &lbls, 3, 2, &cfg);
        assert!(matches!(result, Err(SslError::InvalidParameter { .. })));
    }

    #[test]
    fn fit_rejects_non_finite_features() {
        let cfg = LinearProbeConfig::default();
        let mut feats = vec![0.0_f64; 4 * 2];
        feats[3] = f64::NAN;
        let lbls = vec![0usize, 1, 0, 1];
        let result = linear_probe_fit(&feats, &lbls, 4, 2, &cfg);
        assert!(matches!(result, Err(SslError::NanEncountered { .. })));
    }

    #[test]
    fn fit_binary_linearly_separable() {
        let cfg = LinearProbeConfig {
            n_classes: 2,
            max_iter: 200,
            l2_reg: 1e-4,
            ..Default::default()
        };
        let (feats, lbls) = make_binary_separable(20, 2, 10.0);
        let probe =
            linear_probe_fit(&feats, &lbls, 20, 2, &cfg).expect("linear_probe_fit should succeed");
        let preds =
            linear_probe_predict(&probe, &feats, 20).expect("linear_probe_predict should succeed");
        let acc = accuracy(&preds, &lbls);
        assert!(
            acc >= 0.9,
            "expected accuracy >= 0.9 on separable data, got {acc:.4}"
        );
    }

    #[test]
    fn predict_shape() {
        let cfg = LinearProbeConfig::default();
        let (feats, lbls) = make_binary_separable(20, 4, 5.0);
        let probe =
            linear_probe_fit(&feats, &lbls, 20, 4, &cfg).expect("linear_probe_fit should succeed");
        let preds =
            linear_probe_predict(&probe, &feats, 20).expect("linear_probe_predict should succeed");
        assert_eq!(preds.len(), 20);
    }

    #[test]
    fn predict_rejects_wrong_feature_length() {
        let probe = FittedLinearProbe {
            weights: vec![0.0; 2 * 3],
            in_dim: 2,
            n_classes: 2,
            n_iter: vec![0; 2],
            converged: vec![false; 2],
        };
        let result = linear_probe_predict(&probe, &[0.0; 3], 2);
        assert!(matches!(result, Err(SslError::DimensionMismatch { .. })));
    }

    #[test]
    fn predict_picks_highest_scoring_class() {
        // in_dim = 1, zero biases: class 0 scores +x, class 1 scores -x.
        let probe = FittedLinearProbe {
            weights: vec![1.0, 0.0, -1.0, 0.0],
            in_dim: 1,
            n_classes: 2,
            n_iter: vec![0; 2],
            converged: vec![true; 2],
        };
        let preds =
            linear_probe_predict(&probe, &[2.0, -3.0], 2).expect("linear_probe_predict should succeed");
        assert_eq!(preds, vec![0, 1]);
    }

    #[test]
    fn fit_multiclass() {
        let cfg = LinearProbeConfig {
            n_classes: 3,
            max_iter: 300,
            l2_reg: 1e-4,
            ..Default::default()
        };
        let (feats, lbls) = make_multiclass_separable(10, 4);
        let probe =
            linear_probe_fit(&feats, &lbls, 30, 4, &cfg).expect("linear_probe_fit should succeed");
        let preds =
            linear_probe_predict(&probe, &feats, 30).expect("linear_probe_predict should succeed");
        let acc = accuracy(&preds, &lbls);
        assert!(
            (acc - 1.0).abs() < 1e-9,
            "expected perfect accuracy, got {acc:.4}"
        );
    }

    #[test]
    fn fit_returns_n_class_rows() {
        let cfg = LinearProbeConfig {
            n_classes: 3,
            ..Default::default()
        };
        let in_dim = 5;
        let (feats, lbls) = make_multiclass_separable(5, in_dim);
        let probe = linear_probe_fit(&feats, &lbls, 15, in_dim, &cfg)
            .expect("linear_probe_fit should succeed");
        assert_eq!(probe.weights.len(), cfg.n_classes * (in_dim + 1));
        assert_eq!(probe.in_dim, in_dim);
        assert_eq!(probe.n_classes, cfg.n_classes);
    }

    #[test]
    fn eval_rejects_single_fold() {
        let cfg = LinearProbeConfig {
            n_folds: 1,
            ..Default::default()
        };
        let (feats, lbls) = make_binary_separable(10, 2, 5.0);
        let result = linear_probe_eval(&feats, &lbls, 10, 2, &cfg);
        assert!(matches!(result, Err(SslError::InvalidParameter { .. })));
    }

    #[test]
    fn eval_rejects_fewer_samples_than_folds() {
        let cfg = LinearProbeConfig {
            n_folds: 5,
            ..Default::default()
        };
        let (feats, lbls) = make_binary_separable(4, 2, 5.0);
        let result = linear_probe_eval(&feats, &lbls, 4, 2, &cfg);
        assert!(matches!(result, Err(SslError::BatchTooSmall)));
    }

    #[test]
    fn eval_cv_mean_accuracy_positive() {
        let cfg = LinearProbeConfig {
            n_classes: 2,
            n_folds: 5,
            max_iter: 200,
            l2_reg: 1e-4,
            ..Default::default()
        };
        let (feats, lbls) = make_binary_separable(50, 4, 10.0);
        let result = linear_probe_eval(&feats, &lbls, 50, 4, &cfg)
            .expect("linear_probe_eval should succeed");
        assert!(
            result.mean_accuracy > 0.8,
            "expected mean_accuracy > 0.8, got {:.4}",
            result.mean_accuracy
        );
    }

    #[test]
    fn eval_std_accuracy_finite() {
        let cfg = LinearProbeConfig {
            n_classes: 2,
            n_folds: 5,
            l2_reg: 1e-3,
            ..Default::default()
        };
        let (feats, lbls) = make_binary_separable(50, 4, 10.0);
        let result = linear_probe_eval(&feats, &lbls, 50, 4, &cfg)
            .expect("linear_probe_eval should succeed");
        assert!(result.std_accuracy.is_finite());
        assert!(result.std_accuracy >= 0.0);
    }

    #[test]
    fn eval_macro_f1_range() {
        let cfg = LinearProbeConfig {
            n_classes: 2,
            n_folds: 5,
            l2_reg: 1e-3,
            ..Default::default()
        };
        let (feats, lbls) = make_binary_separable(50, 4, 10.0);
        let result = linear_probe_eval(&feats, &lbls, 50, 4, &cfg)
            .expect("linear_probe_eval should succeed");
        assert!(
            result.macro_f1 >= 0.0 && result.macro_f1 <= 1.0,
            "macro_f1 = {:.4} out of [0, 1]",
            result.macro_f1
        );
    }

    #[test]
    fn per_class_f1_length() {
        let cfg = LinearProbeConfig {
            n_classes: 3,
            n_folds: 3,
            l2_reg: 1e-3,
            ..Default::default()
        };
        let (feats, lbls) = make_multiclass_separable(15, 4);
        let result = linear_probe_eval(&feats, &lbls, 45, 4, &cfg)
            .expect("linear_probe_eval should succeed");
        assert_eq!(result.per_class_f1.len(), 3);
    }

    #[test]
    fn cholesky_solve_identity() {
        let n = 4;
        let mut a = vec![0.0_f64; n * n];
        for i in 0..n {
            a[i * n + i] = 1.0;
        }
        let b = vec![1.0, -2.0, std::f64::consts::PI, 0.0];
        let x = cholesky_solve(&a, &b, n).expect("cholesky_solve should succeed");
        for (xi, bi) in x.iter().zip(b.iter()) {
            assert!((xi - bi).abs() < 1e-12, "expected x={bi}, got {xi}");
        }
    }

    #[test]
    fn cholesky_solve_spd_3x3() {
        let a = vec![4.0, 2.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 6.0];
        let b = vec![1.0, 2.0, 3.0];
        let x = cholesky_solve(&a, &b, 3).expect("cholesky_solve should succeed");
        // Verify A x = b row by row.
        let ax0 = 4.0 * x[0] + 2.0 * x[1] + 1.0 * x[2];
        let ax1 = 2.0 * x[0] + 5.0 * x[1] + 3.0 * x[2];
        let ax2 = 1.0 * x[0] + 3.0 * x[1] + 6.0 * x[2];
        assert!((ax0 - 1.0).abs() < 1e-10);
        assert!((ax1 - 2.0).abs() < 1e-10);
        assert!((ax2 - 3.0).abs() < 1e-10);
    }

    #[test]
    fn cholesky_solve_rejects_indefinite() {
        // Symmetric with eigenvalues 3 and -1, so not positive-definite.
        let a = vec![1.0, 2.0, 2.0, 1.0];
        let result = cholesky_solve(&a, &[1.0, 1.0], 2);
        assert!(matches!(result, Err(SslError::Internal(_))));
    }
}