Skip to main content

oxicuda_ssl/head/
linear_probe.rs

1//! Linear probing evaluation helper for SSL representations.
2//!
3//! Implements a full one-vs-all (OVA) multiclass logistic regression via
4//! Iteratively Reweighted Least Squares (IRLS), with k-fold cross-validation.
5//!
6//! # Protocol
7//! Freeze the backbone, extract features once, then fit a linear classifier
8//! (logistic regression) on the features.  This is the standard SSL evaluation
9//! protocol: a higher accuracy indicates a richer representation.
10//!
11//! # Algorithm
12//! * **OVA-IRLS** — one binary IRLS logistic regression per class.
13//! * **Augmented features** — a 1 is appended to each feature vector so that
14//!   the bias term is absorbed into the weight vector.
//! * **Regularisation** — isotropic L2 (λ·I) added to the Gram matrix; because
//!   the bias is absorbed into the weight vector, it is penalised as well.
16//! * **Cholesky WLS** — each IRLS step solves a (D+1)×(D+1) system via
17//!   Cholesky decomposition instead of a full matrix inversion.
18//! * **k-Fold CV** — Fisher-Yates shuffle then split into contiguous folds.
19
20use crate::error::{SslError, SslResult};
21use crate::handle::LcgRng;
22
23// ─── Config ──────────────────────────────────────────────────────────────────
24
/// Hyper-parameters shared by [`linear_probe_fit`] and [`linear_probe_eval`].
///
/// Start from `LinearProbeConfig::default()` and override individual fields
/// with struct-update syntax, e.g.
/// `LinearProbeConfig { n_classes: 10, ..Default::default() }`.
#[derive(Debug, Clone)]
pub struct LinearProbeConfig {
    /// Number of target classes; labels must lie in `0..n_classes`.
    /// Fitting rejects values below 2.
    pub n_classes: usize,
    /// Number of cross-validation folds; evaluation rejects values below 2.
    pub n_folds: usize,
    /// Upper bound on IRLS iterations for each binary (one-vs-all) sub-problem.
    pub max_iter: usize,
    /// IRLS stops once `‖w_new − w‖ / max(‖w‖, 1)` drops below this value.
    pub tol: f64,
    /// Strength λ of the L2 penalty added to the Gram-matrix diagonal;
    /// must be finite and non-negative.
    pub l2_reg: f64,
    /// Seed for the RNG that shuffles samples into folds.
    pub seed: u64,
}

impl Default for LinearProbeConfig {
    /// Binary problem, 5 folds, at most 200 IRLS steps, `tol = 1e-5`,
    /// `λ = 1e-3`, seed 42.
    fn default() -> Self {
        let (n_classes, n_folds, max_iter) = (2, 5, 200);
        LinearProbeConfig {
            n_classes,
            n_folds,
            max_iter,
            tol: 1e-5,
            l2_reg: 1e-3,
            seed: 42,
        }
    }
}
54
55// ─── Result / fitted struct ───────────────────────────────────────────────────
56
/// Summary of a k-fold cross-validation linear probing run.
///
/// Produced by [`linear_probe_eval`].
#[derive(Debug, Clone)]
pub struct LinearProbeResult {
    /// Mean accuracy across all folds.
    pub mean_accuracy: f64,
    /// Standard deviation of per-fold accuracies, in population form: the sum
    /// of squared deviations is divided by `n_folds`, not `n_folds - 1`.
    pub std_accuracy: f64,
    /// Per-fold accuracy values (length = n_folds).
    pub per_fold_accuracy: Vec<f64>,
    /// Macro-averaged F1: the unweighted mean of `per_class_f1` over classes.
    pub macro_f1: f64,
    /// Per-class F1 scores averaged over folds (length = n_classes).
    /// In a fold where a class is neither present in the held-out labels nor
    /// predicted, that class contributes an F1 of 0 for the fold.
    pub per_class_f1: Vec<f64>,
}
71
/// A fitted one-vs-all logistic regression model, as returned by
/// [`linear_probe_fit`] and consumed by [`linear_probe_predict`].
#[derive(Debug, Clone)]
pub struct FittedLinearProbe {
    /// Weight matrix stored row-major `[n_classes × (in_dim + 1)]`.
    /// Row `k` is the binary "class k vs rest" classifier; the last column
    /// of each row is the absorbed bias term.
    pub weights: Vec<f64>,
    /// Feature dimensionality (before bias augmentation).
    pub in_dim: usize,
    /// Number of classes, i.e. the number of weight rows.
    pub n_classes: usize,
    /// IRLS iterations taken by each binary sub-problem (length `n_classes`).
    pub n_iter: Vec<usize>,
    /// Whether each binary sub-problem met the `tol` criterion within
    /// `max_iter` iterations (length `n_classes`).
    pub converged: Vec<bool>,
}
87
88// ─── Private numerics ─────────────────────────────────────────────────────────
89
/// Logistic function `σ(x) = 1 / (1 + e^{-x})`, evaluated without overflow.
///
/// For any non-NaN `x`, `exp` is only applied to a non-positive argument, so
/// it stays within `[0, 1]`. Negative inputs use the algebraically equal form
/// `e^x / (1 + e^x)`.
#[inline]
fn sigmoid(x: f64) -> f64 {
    if x < 0.0 {
        let e = x.exp();
        return e / (1.0 + e);
    }
    (1.0 + (-x).exp()).recip()
}
100
101/// Cholesky decomposition + forward/back substitution to solve `A·x = b`.
102///
103/// `a` is a row-major n×n **symmetric positive-definite** matrix.
104/// Returns `Err(Internal)` if the matrix is not positive-definite.
105fn cholesky_solve(a: &[f64], b: &[f64], n: usize) -> SslResult<Vec<f64>> {
106    debug_assert_eq!(a.len(), n * n);
107    debug_assert_eq!(b.len(), n);
108
109    // ── Cholesky factorisation A = L·Lᵀ (lower triangular L in-place) ──────
110    let mut l = vec![0.0_f64; n * n];
111    for i in 0..n {
112        for j in 0..=i {
113            let mut s = a[i * n + j];
114            for k in 0..j {
115                s -= l[i * n + k] * l[j * n + k];
116            }
117            if i == j {
118                if s <= 0.0 {
119                    return Err(SslError::Internal(
120                        "cholesky_solve: matrix not positive-definite".into(),
121                    ));
122                }
123                l[i * n + j] = s.sqrt();
124            } else {
125                l[i * n + j] = s / l[j * n + j];
126            }
127        }
128    }
129
130    // ── Forward substitution: solve L·y = b ─────────────────────────────────
131    let mut y = vec![0.0_f64; n];
132    for i in 0..n {
133        let mut s = b[i];
134        for k in 0..i {
135            s -= l[i * n + k] * y[k];
136        }
137        y[i] = s / l[i * n + i];
138    }
139
140    // ── Back substitution: solve Lᵀ·x = y ───────────────────────────────────
141    let mut x = vec![0.0_f64; n];
142    for i in (0..n).rev() {
143        let mut s = y[i];
144        for k in (i + 1)..n {
145            s -= l[k * n + i] * x[k];
146        }
147        x[i] = s / l[i * n + i];
148    }
149
150    Ok(x)
151}
152
/// Fraction of positions at which `predicted` and `truth` agree.
///
/// Matches are counted over the common prefix of the two slices, but the
/// denominator is always `predicted.len()`. An empty `predicted` yields `0.0`.
fn accuracy(predicted: &[usize], truth: &[usize]) -> f64 {
    let total = predicted.len();
    if total == 0 {
        return 0.0;
    }
    let mut hits = 0usize;
    for (p, t) in predicted.iter().zip(truth) {
        if p == t {
            hits += 1;
        }
    }
    hits as f64 / total as f64
}
165
/// Per-class F1 score, `TP / (TP + ½·(FP + FN))`, for classes `0..n_classes`.
///
/// Pairs whose predicted or true label is outside `0..n_classes` are ignored.
/// A class with no true positives, false positives or false negatives (zero
/// denominator) scores `0.0`.
fn f1_per_class(predicted: &[usize], truth: &[usize], n_classes: usize) -> Vec<f64> {
    // One (true positives, false positives, false negatives) tally per class.
    let mut tally = vec![(0usize, 0usize, 0usize); n_classes];

    let pairs = predicted.iter().copied().zip(truth.iter().copied());
    for (pred, actual) in pairs.filter(|&(p, t)| p < n_classes && t < n_classes) {
        if pred == actual {
            tally[pred].0 += 1;
        } else {
            tally[pred].1 += 1;
            tally[actual].2 += 1;
        }
    }

    tally
        .into_iter()
        .map(|(tp, fp, fneg)| {
            let denom = tp as f64 + 0.5 * (fp + fneg) as f64;
            if denom < 1e-12 {
                0.0
            } else {
                tp as f64 / denom
            }
        })
        .collect()
}
194
/// Shuffle `indices` in place by delegating to `LcgRng::shuffle`.
///
/// This wrapper advances `rng`, so the resulting permutation depends on the
/// RNG's current state. NOTE(review): the Fisher-Yates algorithm named here is
/// assumed to be what `LcgRng::shuffle` implements — confirm in
/// `crate::handle`.
fn fisher_yates_shuffle(indices: &mut [usize], rng: &mut LcgRng) {
    rng.shuffle(indices);
}
199
200// ─── Core IRLS binary logistic regression ─────────────────────────────────────
201
202/// Fit a single binary (OVA) logistic regression for class `k` against the
203/// rest using IRLS.
204///
205/// `x_aug` — augmented feature matrix [n × (d+1)] (row-major).
206/// `y_bin` — binary labels for this class (0 or 1), length n.
207///
208/// Returns `(weights, iterations_taken, converged)`.
209fn irls_binary(
210    x_aug: &[f64],
211    y_bin: &[f64],
212    n: usize,
213    d_aug: usize,
214    config: &LinearProbeConfig,
215) -> SslResult<(Vec<f64>, usize, bool)> {
216    const EPS: f64 = 1e-7;
217
218    let mut w = vec![0.0_f64; d_aug];
219    let mut iters_done = 0usize;
220    let mut converged = false;
221
222    for iter in 0..config.max_iter {
223        // ── Step 1-3: compute η_i, p_i, W_i ─────────────────────────────────
224        let mut p_vec = vec![0.0_f64; n];
225        for i in 0..n {
226            let row = &x_aug[i * d_aug..(i + 1) * d_aug];
227            let eta_i: f64 = row.iter().zip(w.iter()).map(|(&xi, &wi)| xi * wi).sum();
228            p_vec[i] = sigmoid(eta_i).clamp(EPS, 1.0 - EPS);
229        }
230
231        // ── Step 4: working response z_i ─────────────────────────────────────
232        // z_i = η_i + (y_i - p_i) / (p_i · (1 - p_i))
233        // but we recompute η_i from w to keep things numerically fresh.
234        let mut eta_vec = vec![0.0_f64; n];
235        for i in 0..n {
236            let row = &x_aug[i * d_aug..(i + 1) * d_aug];
237            eta_vec[i] = row.iter().zip(w.iter()).map(|(&xi, &wi)| xi * wi).sum();
238        }
239
240        // ── Step 5: WLS normal equations ──────────────────────────────────────
241        // Accumulate Xᵀ W X and Xᵀ W z rank-1.
242        let mut xtwx = vec![0.0_f64; d_aug * d_aug];
243        let mut xtwz = vec![0.0_f64; d_aug];
244
245        for i in 0..n {
246            let p_i = p_vec[i];
247            let w_i = p_i * (1.0 - p_i); // IRLS weight
248            let z_i = eta_vec[i] + (y_bin[i] - p_i) / w_i;
249            let row = &x_aug[i * d_aug..(i + 1) * d_aug];
250
251            // rank-1 update of Xᵀ W X
252            for r in 0..d_aug {
253                let val_r = w_i * row[r];
254                for c in 0..d_aug {
255                    xtwx[r * d_aug + c] += val_r * row[c];
256                }
257                xtwz[r] += val_r * z_i;
258            }
259        }
260
261        // Add λ·I (L2 regularisation).
262        for j in 0..d_aug {
263            xtwx[j * d_aug + j] += config.l2_reg;
264        }
265
266        // Solve (Xᵀ W X + λI) · w_new = Xᵀ W z.
267        let w_new = cholesky_solve(&xtwx, &xtwz, d_aug)?;
268
269        // ── Step 6: convergence check ─────────────────────────────────────────
270        let delta_norm: f64 = w_new
271            .iter()
272            .zip(w.iter())
273            .map(|(&a, &b)| (a - b) * (a - b))
274            .sum::<f64>()
275            .sqrt();
276        let w_norm: f64 = w.iter().map(|&v| v * v).sum::<f64>().sqrt();
277        let rel = delta_norm / w_norm.max(1.0);
278
279        w = w_new;
280        iters_done = iter + 1;
281
282        if rel < config.tol {
283            converged = true;
284            break;
285        }
286    }
287
288    // Validate no NaN leaked through.
289    for &v in &w {
290        if v.is_nan() {
291            return Err(SslError::NanEncountered {
292                location: "irls_binary weight",
293            });
294        }
295    }
296
297    Ok((w, iters_done, converged))
298}
299
300// ─── Public API ───────────────────────────────────────────────────────────────
301
302/// Fit a one-vs-all multiclass logistic regression on `(features, labels)`.
303///
304/// * `features` — row-major `[n_samples × in_dim]` slice.
305/// * `labels`   — class indices in `0..config.n_classes`, length `n_samples`.
306///
307/// # Errors
308/// Returns [`SslError::EmptyInput`] if `n_samples == 0`,
309/// [`SslError::InvalidParameter`] for degenerate configuration, or
310/// [`SslError::DimensionMismatch`] on shape mismatches.
311pub fn linear_probe_fit(
312    features: &[f64],
313    labels: &[usize],
314    n_samples: usize,
315    in_dim: usize,
316    config: &LinearProbeConfig,
317) -> SslResult<FittedLinearProbe> {
318    // ── Validation ────────────────────────────────────────────────────────────
319    if n_samples == 0 {
320        return Err(SslError::EmptyInput);
321    }
322    if in_dim == 0 {
323        return Err(SslError::InvalidParameter {
324            name: "in_dim".into(),
325            reason: "feature dimension must be > 0".into(),
326        });
327    }
328    if config.n_classes < 2 {
329        return Err(SslError::InvalidParameter {
330            name: "n_classes".into(),
331            reason: "must be >= 2".into(),
332        });
333    }
334    if config.l2_reg < 0.0 || !config.l2_reg.is_finite() {
335        return Err(SslError::InvalidParameter {
336            name: "l2_reg".into(),
337            reason: "must be non-negative and finite".into(),
338        });
339    }
340    if features.len() != n_samples * in_dim {
341        return Err(SslError::DimensionMismatch {
342            expected: n_samples * in_dim,
343            got: features.len(),
344        });
345    }
346    if labels.len() != n_samples {
347        return Err(SslError::DimensionMismatch {
348            expected: n_samples,
349            got: labels.len(),
350        });
351    }
352    for (i, &lbl) in labels.iter().enumerate() {
353        if lbl >= config.n_classes {
354            return Err(SslError::InvalidParameter {
355                name: "labels".into(),
356                reason: format!(
357                    "label {} at index {} is out of range [0, {})",
358                    lbl, i, config.n_classes
359                ),
360            });
361        }
362    }
363
364    // ── Build augmented feature matrix [n × (D+1)] ───────────────────────────
365    let d_aug = in_dim + 1;
366    let mut x_aug = vec![0.0_f64; n_samples * d_aug];
367    for i in 0..n_samples {
368        let src = &features[i * in_dim..(i + 1) * in_dim];
369        let dst = &mut x_aug[i * d_aug..(i + 1) * d_aug];
370        dst[..in_dim].copy_from_slice(src);
371        dst[in_dim] = 1.0; // bias
372    }
373
374    // Check for NaN/Inf in features.
375    for (j, &v) in x_aug.iter().enumerate() {
376        if !v.is_finite() {
377            let sample = j / d_aug;
378            let _ = sample; // suppress unused var warning if debug assert only
379            return Err(SslError::NanEncountered {
380                location: "features (augmented)",
381            });
382        }
383    }
384
385    // ── Fit one binary classifier per class ───────────────────────────────────
386    let mut all_weights = vec![0.0_f64; config.n_classes * d_aug];
387    let mut n_iter_per_class = vec![0usize; config.n_classes];
388    let mut converged_per_class = vec![false; config.n_classes];
389
390    for k in 0..config.n_classes {
391        let y_bin: Vec<f64> = labels
392            .iter()
393            .map(|&lbl| if lbl == k { 1.0 } else { 0.0 })
394            .collect();
395
396        let (w_k, iters, conv) = irls_binary(&x_aug, &y_bin, n_samples, d_aug, config)?;
397
398        all_weights[k * d_aug..(k + 1) * d_aug].copy_from_slice(&w_k);
399        n_iter_per_class[k] = iters;
400        converged_per_class[k] = conv;
401    }
402
403    Ok(FittedLinearProbe {
404        weights: all_weights,
405        in_dim,
406        n_classes: config.n_classes,
407        n_iter: n_iter_per_class,
408        converged: converged_per_class,
409    })
410}
411
412/// Predict class labels for `n_samples` feature vectors.
413///
414/// Uses argmax over OVA sigmoid scores.
415///
416/// # Errors
417/// Returns [`SslError::DimensionMismatch`] if `features.len() != n_samples * probe.in_dim`.
418pub fn linear_probe_predict(
419    probe: &FittedLinearProbe,
420    features: &[f64],
421    n_samples: usize,
422) -> SslResult<Vec<usize>> {
423    let d_aug = probe.in_dim + 1;
424
425    if features.len() != n_samples * probe.in_dim {
426        return Err(SslError::DimensionMismatch {
427            expected: n_samples * probe.in_dim,
428            got: features.len(),
429        });
430    }
431
432    let mut predictions = vec![0usize; n_samples];
433    for i in 0..n_samples {
434        let src = &features[i * probe.in_dim..(i + 1) * probe.in_dim];
435
436        // Build augmented row.
437        let mut x_aug = vec![0.0_f64; d_aug];
438        x_aug[..probe.in_dim].copy_from_slice(src);
439        x_aug[probe.in_dim] = 1.0;
440
441        // Compute sigmoid score for each class and pick argmax.
442        let mut best_class = 0usize;
443        let mut best_score = f64::NEG_INFINITY;
444        for k in 0..probe.n_classes {
445            let w_k = &probe.weights[k * d_aug..(k + 1) * d_aug];
446            let eta: f64 = w_k.iter().zip(x_aug.iter()).map(|(&w, &x)| w * x).sum();
447            let score = sigmoid(eta);
448            if score > best_score {
449                best_score = score;
450                best_class = k;
451            }
452        }
453        predictions[i] = best_class;
454    }
455
456    Ok(predictions)
457}
458
459/// k-Fold cross-validation linear probing evaluation.
460///
461/// Shuffles `n_samples` indices with `LcgRng::new(config.seed)`, splits into
462/// `config.n_folds` contiguous folds, trains on the remaining folds, evaluates
463/// on the held-out fold, and aggregates accuracy + macro-F1.
464///
465/// # Errors
466/// Propagates any errors from [`linear_probe_fit`] or [`linear_probe_predict`].
467pub fn linear_probe_eval(
468    features: &[f64],
469    labels: &[usize],
470    n_samples: usize,
471    in_dim: usize,
472    config: &LinearProbeConfig,
473) -> SslResult<LinearProbeResult> {
474    if n_samples == 0 {
475        return Err(SslError::EmptyInput);
476    }
477    if config.n_folds < 2 {
478        return Err(SslError::InvalidParameter {
479            name: "n_folds".into(),
480            reason: "must be >= 2".into(),
481        });
482    }
483    if n_samples < config.n_folds {
484        return Err(SslError::BatchTooSmall);
485    }
486
487    // ── Shuffle indices ───────────────────────────────────────────────────────
488    let mut indices: Vec<usize> = (0..n_samples).collect();
489    let mut rng = LcgRng::new(config.seed);
490    fisher_yates_shuffle(&mut indices, &mut rng);
491
492    // ── Build fold boundaries ─────────────────────────────────────────────────
493    // Each fold gets floor(n/k) elements; the last fold absorbs the remainder.
494    let fold_size = n_samples / config.n_folds;
495    let mut fold_starts = Vec::with_capacity(config.n_folds + 1);
496    for f in 0..config.n_folds {
497        fold_starts.push(f * fold_size);
498    }
499    fold_starts.push(n_samples); // sentinel end
500
501    // ── Per-fold evaluation ───────────────────────────────────────────────────
502    let mut per_fold_accuracy = Vec::with_capacity(config.n_folds);
503    // Accumulate per-class F1 across folds (we'll average at the end).
504    let mut per_class_f1_sum = vec![0.0_f64; config.n_classes];
505
506    for fold_idx in 0..config.n_folds {
507        let val_start = fold_starts[fold_idx];
508        let val_end = fold_starts[fold_idx + 1];
509
510        // Collect validation indices and training indices.
511        let val_indices: Vec<usize> = indices[val_start..val_end].to_vec();
512        let train_indices: Vec<usize> = indices[..val_start]
513            .iter()
514            .chain(&indices[val_end..])
515            .copied()
516            .collect();
517
518        let n_train = train_indices.len();
519        let n_val = val_indices.len();
520
521        if n_train == 0 || n_val == 0 {
522            return Err(SslError::BatchTooSmall);
523        }
524
525        // Build train/val feature arrays.
526        let mut train_feat = vec![0.0_f64; n_train * in_dim];
527        let mut train_lbl = vec![0usize; n_train];
528        for (out_i, &src_i) in train_indices.iter().enumerate() {
529            train_feat[out_i * in_dim..(out_i + 1) * in_dim]
530                .copy_from_slice(&features[src_i * in_dim..(src_i + 1) * in_dim]);
531            train_lbl[out_i] = labels[src_i];
532        }
533
534        let mut val_feat = vec![0.0_f64; n_val * in_dim];
535        let mut val_lbl = vec![0usize; n_val];
536        for (out_i, &src_i) in val_indices.iter().enumerate() {
537            val_feat[out_i * in_dim..(out_i + 1) * in_dim]
538                .copy_from_slice(&features[src_i * in_dim..(src_i + 1) * in_dim]);
539            val_lbl[out_i] = labels[src_i];
540        }
541
542        // Fit and predict.
543        let probe = linear_probe_fit(&train_feat, &train_lbl, n_train, in_dim, config)?;
544        let preds = linear_probe_predict(&probe, &val_feat, n_val)?;
545
546        // Metrics.
547        let fold_acc = accuracy(&preds, &val_lbl);
548        per_fold_accuracy.push(fold_acc);
549
550        let f1s = f1_per_class(&preds, &val_lbl, config.n_classes);
551        for (k, &f1_k) in f1s.iter().enumerate() {
552            per_class_f1_sum[k] += f1_k;
553        }
554    }
555
556    // ── Aggregate ─────────────────────────────────────────────────────────────
557    let mean_accuracy = per_fold_accuracy.iter().sum::<f64>() / config.n_folds as f64;
558
559    let variance = per_fold_accuracy
560        .iter()
561        .map(|&a| {
562            let d = a - mean_accuracy;
563            d * d
564        })
565        .sum::<f64>()
566        / config.n_folds as f64;
567    let std_accuracy = variance.sqrt();
568
569    let per_class_f1: Vec<f64> = per_class_f1_sum
570        .iter()
571        .map(|&s| s / config.n_folds as f64)
572        .collect();
573
574    let macro_f1 = per_class_f1.iter().sum::<f64>() / config.n_classes as f64;
575
576    Ok(LinearProbeResult {
577        mean_accuracy,
578        std_accuracy,
579        per_fold_accuracy,
580        macro_f1,
581        per_class_f1,
582    })
583}
584
585// ─── Tests ────────────────────────────────────────────────────────────────────
586
#[cfg(test)]
mod tests {
    use super::*;

    // ── helpers ──────────────────────────────────────────────────────────────

    /// Build a linearly separable binary dataset:
    /// first `n/2` samples at origin (label 0), last `n/2` samples at [c, 0, …] (label 1).
    fn make_binary_separable(n: usize, dim: usize, offset: f64) -> (Vec<f64>, Vec<usize>) {
        let half = n / 2;
        let mut feats = vec![0.0_f64; n * dim];
        let mut lbls = vec![0usize; n];
        for i in half..n {
            feats[i * dim] = offset;
            lbls[i] = 1;
        }
        (feats, lbls)
    }

    /// Build a 3-class perfectly separated dataset: class `k` is displaced by
    /// `20·(k+1)` along axis `min(k, dim − 1)`.
    fn make_multiclass_separable(n_per_class: usize, dim: usize) -> (Vec<f64>, Vec<usize>) {
        let n = n_per_class * 3;
        let mut feats = vec![0.0_f64; n * dim];
        let mut lbls = vec![0usize; n];
        for k in 0..3usize {
            for i in 0..n_per_class {
                let row = k * n_per_class + i;
                // Place class k far from the others along dimension k.
                feats[row * dim + k.min(dim - 1)] = (k + 1) as f64 * 20.0;
                lbls[row] = k;
            }
        }
        (feats, lbls)
    }

    // ── test 1: config defaults ───────────────────────────────────────────────

    #[test]
    fn config_defaults() {
        let cfg = LinearProbeConfig::default();
        assert_eq!(cfg.n_folds, 5);
        assert_eq!(cfg.max_iter, 200);
        assert!((cfg.l2_reg - 1e-3).abs() < 1e-15);
        assert_eq!(cfg.n_classes, 2);
        assert!((cfg.tol - 1e-5).abs() < 1e-18);
        assert_eq!(cfg.seed, 42);
    }

    // ── test 2: sigmoid numerical stability ──────────────────────────────────

    #[test]
    fn sigmoid_stable() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-15);
        assert!((sigmoid(100.0) - 1.0).abs() < 1e-6);
        assert!(sigmoid(-100.0) < 1e-6);
        // Check it doesn't produce NaN for extreme values.
        assert!(sigmoid(f64::MAX / 2.0).is_finite());
        assert!(sigmoid(f64::MIN / 2.0).is_finite());
    }

    // ── test 3: empty input error ─────────────────────────────────────────────

    #[test]
    fn fit_empty_error() {
        let cfg = LinearProbeConfig::default();
        let result = linear_probe_fit(&[], &[], 0, 4, &cfg);
        assert!(matches!(result, Err(SslError::EmptyInput)));
    }

    // ── test 4: single-class error ────────────────────────────────────────────

    #[test]
    fn fit_single_class_error() {
        let cfg = LinearProbeConfig {
            n_classes: 1,
            ..Default::default()
        };
        let feats = vec![0.0_f64; 10 * 4];
        let lbls = vec![0usize; 10];
        let result = linear_probe_fit(&feats, &lbls, 10, 4, &cfg);
        assert!(matches!(result, Err(SslError::InvalidParameter { .. })));
    }

    // ── test 5: binary linearly separable ────────────────────────────────────

    #[test]
    fn fit_binary_linearly_separable() {
        let cfg = LinearProbeConfig {
            n_classes: 2,
            max_iter: 200,
            l2_reg: 1e-4,
            ..Default::default()
        };
        let (feats, lbls) = make_binary_separable(20, 2, 10.0);
        let probe =
            linear_probe_fit(&feats, &lbls, 20, 2, &cfg).expect("linear_probe_fit should succeed");
        let preds =
            linear_probe_predict(&probe, &feats, 20).expect("linear_probe_predict should succeed");
        let acc = accuracy(&preds, &lbls);
        assert!(
            acc >= 0.9,
            "expected accuracy >= 0.9 on separable data, got {acc:.4}"
        );
    }

    // ── test 6: predict shape ─────────────────────────────────────────────────

    #[test]
    fn predict_shape() {
        let cfg = LinearProbeConfig::default();
        let (feats, lbls) = make_binary_separable(20, 4, 5.0);
        let probe =
            linear_probe_fit(&feats, &lbls, 20, 4, &cfg).expect("linear_probe_fit should succeed");
        let preds =
            linear_probe_predict(&probe, &feats, 20).expect("linear_probe_predict should succeed");
        assert_eq!(preds.len(), 20);
    }

    // ── test 7: multiclass perfectly separated → accuracy = 1.0 ──────────────

    #[test]
    fn fit_multiclass() {
        let cfg = LinearProbeConfig {
            n_classes: 3,
            max_iter: 300,
            l2_reg: 1e-4,
            ..Default::default()
        };
        let (feats, lbls) = make_multiclass_separable(10, 4);
        let probe =
            linear_probe_fit(&feats, &lbls, 30, 4, &cfg).expect("linear_probe_fit should succeed");
        let preds =
            linear_probe_predict(&probe, &feats, 30).expect("linear_probe_predict should succeed");
        let acc = accuracy(&preds, &lbls);
        assert!(
            (acc - 1.0).abs() < 1e-9,
            "expected perfect accuracy, got {acc:.4}"
        );
    }

    // ── test 8: weights layout ────────────────────────────────────────────────

    #[test]
    fn fit_returns_n_class_rows() {
        let cfg = LinearProbeConfig {
            n_classes: 3,
            ..Default::default()
        };
        let in_dim = 5;
        let (feats, lbls) = make_multiclass_separable(5, in_dim);
        let probe = linear_probe_fit(&feats, &lbls, 15, in_dim, &cfg)
            .expect("linear_probe_fit should succeed");
        assert_eq!(probe.weights.len(), cfg.n_classes * (in_dim + 1));
        assert_eq!(probe.in_dim, in_dim);
        assert_eq!(probe.n_classes, cfg.n_classes);
    }

    // ── test 9: CV mean accuracy on separable data ────────────────────────────

    #[test]
    fn eval_cv_mean_accuracy_positive() {
        let cfg = LinearProbeConfig {
            n_classes: 2,
            n_folds: 5,
            max_iter: 200,
            l2_reg: 1e-4,
            ..Default::default()
        };
        // Build a larger separable dataset so each fold has enough training data.
        let (feats, lbls) = make_binary_separable(50, 4, 10.0);
        let result = linear_probe_eval(&feats, &lbls, 50, 4, &cfg)
            .expect("linear_probe_eval should succeed");
        assert!(
            result.mean_accuracy > 0.8,
            "expected mean_accuracy > 0.8, got {:.4}",
            result.mean_accuracy
        );
    }

    // ── test 10: std accuracy finite and non-negative ─────────────────────────

    #[test]
    fn eval_std_accuracy_finite() {
        let cfg = LinearProbeConfig {
            n_classes: 2,
            n_folds: 5,
            l2_reg: 1e-3,
            ..Default::default()
        };
        let (feats, lbls) = make_binary_separable(50, 4, 10.0);
        let result = linear_probe_eval(&feats, &lbls, 50, 4, &cfg)
            .expect("linear_probe_eval should succeed");
        assert!(result.std_accuracy.is_finite());
        assert!(result.std_accuracy >= 0.0);
    }

    // ── test 11: macro F1 in [0, 1] ───────────────────────────────────────────

    #[test]
    fn eval_macro_f1_range() {
        let cfg = LinearProbeConfig {
            n_classes: 2,
            n_folds: 5,
            l2_reg: 1e-3,
            ..Default::default()
        };
        let (feats, lbls) = make_binary_separable(50, 4, 10.0);
        let result = linear_probe_eval(&feats, &lbls, 50, 4, &cfg)
            .expect("linear_probe_eval should succeed");
        assert!(
            result.macro_f1 >= 0.0 && result.macro_f1 <= 1.0,
            "macro_f1 = {:.4} out of [0, 1]",
            result.macro_f1
        );
    }

    // ── test 12: per_class_f1 length ──────────────────────────────────────────

    #[test]
    fn per_class_f1_length() {
        let cfg = LinearProbeConfig {
            n_classes: 3,
            n_folds: 3,
            l2_reg: 1e-3,
            ..Default::default()
        };
        let (feats, lbls) = make_multiclass_separable(15, 4);
        let result = linear_probe_eval(&feats, &lbls, 45, 4, &cfg)
            .expect("linear_probe_eval should succeed");
        assert_eq!(result.per_class_f1.len(), 3);
    }

    // ── test 13: cholesky_solve on identity ───────────────────────────────────

    #[test]
    fn cholesky_solve_identity() {
        let n = 4;
        let mut a = vec![0.0_f64; n * n];
        for i in 0..n {
            a[i * n + i] = 1.0;
        }
        let b = vec![1.0, -2.0, std::f64::consts::PI, 0.0];
        let x = cholesky_solve(&a, &b, n).expect("cholesky_solve should succeed");
        for (xi, bi) in x.iter().zip(b.iter()) {
            assert!((xi - bi).abs() < 1e-12, "expected x={bi}, got {xi}");
        }
    }

    // ── test 14: cholesky_solve on a 3×3 SPD matrix ──────────────────────────

    #[test]
    fn cholesky_solve_spd_3x3() {
        // A = [[4,2,1],[2,5,3],[1,3,6]] — positive definite.
        let a = vec![4.0, 2.0, 1.0, 2.0, 5.0, 3.0, 1.0, 3.0, 6.0];
        let b = vec![1.0, 2.0, 3.0];
        let x = cholesky_solve(&a, &b, 3).expect("cholesky_solve should succeed");
        // Verify A·x ≈ b.
        let ax0 = 4.0 * x[0] + 2.0 * x[1] + x[2];
        let ax1 = 2.0 * x[0] + 5.0 * x[1] + 3.0 * x[2];
        let ax2 = x[0] + 3.0 * x[1] + 6.0 * x[2];
        assert!((ax0 - 1.0).abs() < 1e-10);
        assert!((ax1 - 2.0).abs() < 1e-10);
        assert!((ax2 - 3.0).abs() < 1e-10);
    }

    // ── test 15: cholesky_solve rejects an indefinite matrix ─────────────────

    #[test]
    fn cholesky_solve_rejects_indefinite() {
        // Symmetric with eigenvalues 3 and −1 → not positive-definite.
        let a = vec![1.0, 2.0, 2.0, 1.0];
        let result = cholesky_solve(&a, &[1.0, 1.0], 2);
        assert!(matches!(result, Err(SslError::Internal(_))));
    }

    // ── test 16: accuracy / F1 on a hand-computed example ────────────────────

    #[test]
    fn metrics_hand_computed() {
        let preds = [0usize, 1, 1, 2];
        let truth = [0usize, 1, 2, 2];
        assert!((accuracy(&preds, &truth) - 0.75).abs() < 1e-15);
        assert_eq!(accuracy(&[], &[]), 0.0);

        // class 0: TP=1; class 1: TP=1, FP=1; class 2: TP=1, FN=1.
        let f1 = f1_per_class(&preds, &truth, 3);
        let expected = [1.0, 2.0 / 3.0, 2.0 / 3.0];
        for (got, want) in f1.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-12, "expected {want}, got {got}");
        }
        // A class that never occurs and is never predicted scores 0.
        assert_eq!(f1_per_class(&[0, 0], &[0, 0], 2), vec![1.0, 0.0]);
    }

    // ── test 17: predict rejects mis-sized features ───────────────────────────

    #[test]
    fn predict_dimension_mismatch() {
        let cfg = LinearProbeConfig::default();
        let (feats, lbls) = make_binary_separable(10, 2, 5.0);
        let probe =
            linear_probe_fit(&feats, &lbls, 10, 2, &cfg).expect("linear_probe_fit should succeed");
        let result = linear_probe_predict(&probe, &feats[..3], 2);
        assert!(matches!(
            result,
            Err(SslError::DimensionMismatch { expected: 4, got: 3 })
        ));
    }

    // ── test 18: eval parameter errors ────────────────────────────────────────

    #[test]
    fn eval_rejects_too_few_folds() {
        let cfg = LinearProbeConfig {
            n_folds: 1,
            ..Default::default()
        };
        let (feats, lbls) = make_binary_separable(10, 2, 5.0);
        let result = linear_probe_eval(&feats, &lbls, 10, 2, &cfg);
        assert!(matches!(result, Err(SslError::InvalidParameter { .. })));
    }

    #[test]
    fn eval_batch_too_small() {
        // Default config uses 5 folds; 4 samples cannot fill them.
        let cfg = LinearProbeConfig::default();
        let (feats, lbls) = make_binary_separable(4, 2, 5.0);
        let result = linear_probe_eval(&feats, &lbls, 4, 2, &cfg);
        assert!(matches!(result, Err(SslError::BatchTooSmall)));
    }

    // ── test 19: CV is reproducible for a fixed seed ─────────────────────────

    #[test]
    fn eval_is_deterministic_for_fixed_seed() {
        let cfg = LinearProbeConfig {
            n_folds: 4,
            ..Default::default()
        };
        let (feats, lbls) = make_binary_separable(20, 2, 10.0);
        let first = linear_probe_eval(&feats, &lbls, 20, 2, &cfg)
            .expect("linear_probe_eval should succeed");
        let second = linear_probe_eval(&feats, &lbls, 20, 2, &cfg)
            .expect("linear_probe_eval should succeed");
        assert_eq!(first.per_fold_accuracy.len(), 4);
        assert_eq!(first.per_fold_accuracy, second.per_fold_accuracy);
    }
}