// cjc_runtime/ml.rs
//! ML Toolkit — loss functions, optimizers, activations, metrics.
//!
//! # Determinism Contract
//! - All functions are deterministic (no randomness except seeded kfold).
//! - Kahan summation for all reductions.
//! - Stable sort for AUC-ROC with index tie-breaking.

use cjc_repro::KahanAccumulatorF64;

use crate::accumulator::BinnedAccumulatorF64;
use crate::error::RuntimeError;
use crate::tensor::Tensor;
// ---------------------------------------------------------------------------
// Loss functions
// ---------------------------------------------------------------------------

18/// Mean Squared Error: sum((pred - target)^2) / n.
19pub fn mse_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
20    if pred.len() != target.len() {
21        return Err("mse_loss: arrays must have same length".into());
22    }
23    if pred.is_empty() {
24        return Err("mse_loss: empty data".into());
25    }
26    let mut acc = KahanAccumulatorF64::new();
27    for i in 0..pred.len() {
28        let d = pred[i] - target[i];
29        acc.add(d * d);
30    }
31    Ok(acc.finalize() / pred.len() as f64)
32}
33
34/// Cross-entropy loss: -sum(target * ln(pred + eps)) / n.
35pub fn cross_entropy_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
36    if pred.len() != target.len() {
37        return Err("cross_entropy_loss: arrays must have same length".into());
38    }
39    if pred.is_empty() {
40        return Err("cross_entropy_loss: empty data".into());
41    }
42    let eps = 1e-12;
43    let mut acc = KahanAccumulatorF64::new();
44    for i in 0..pred.len() {
45        acc.add(-target[i] * (pred[i] + eps).ln());
46    }
47    Ok(acc.finalize() / pred.len() as f64)
48}
49
50/// Binary cross-entropy: -sum(t*ln(p) + (1-t)*ln(1-p)) / n.
51pub fn binary_cross_entropy(pred: &[f64], target: &[f64]) -> Result<f64, String> {
52    if pred.len() != target.len() {
53        return Err("binary_cross_entropy: arrays must have same length".into());
54    }
55    if pred.is_empty() {
56        return Err("binary_cross_entropy: empty data".into());
57    }
58    let eps = 1e-12;
59    let mut acc = KahanAccumulatorF64::new();
60    for i in 0..pred.len() {
61        let p = pred[i].max(eps).min(1.0 - eps);
62        acc.add(-(target[i] * p.ln() + (1.0 - target[i]) * (1.0 - p).ln()));
63    }
64    Ok(acc.finalize() / pred.len() as f64)
65}
66
67/// Huber loss: quadratic for small errors, linear for large.
68pub fn huber_loss(pred: &[f64], target: &[f64], delta: f64) -> Result<f64, String> {
69    if pred.len() != target.len() {
70        return Err("huber_loss: arrays must have same length".into());
71    }
72    if pred.is_empty() {
73        return Err("huber_loss: empty data".into());
74    }
75    let mut acc = KahanAccumulatorF64::new();
76    for i in 0..pred.len() {
77        let d = (pred[i] - target[i]).abs();
78        if d <= delta {
79            acc.add(0.5 * d * d);
80        } else {
81            acc.add(delta * (d - 0.5 * delta));
82        }
83    }
84    Ok(acc.finalize() / pred.len() as f64)
85}
86
87/// Hinge loss: sum(max(0, 1 - target * pred)) / n.
88pub fn hinge_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
89    if pred.len() != target.len() {
90        return Err("hinge_loss: arrays must have same length".into());
91    }
92    if pred.is_empty() {
93        return Err("hinge_loss: empty data".into());
94    }
95    let mut acc = KahanAccumulatorF64::new();
96    for i in 0..pred.len() {
97        acc.add((1.0 - target[i] * pred[i]).max(0.0));
98    }
99    Ok(acc.finalize() / pred.len() as f64)
100}

// ---------------------------------------------------------------------------
// Optimizers
// ---------------------------------------------------------------------------

/// SGD optimizer state.
pub struct SgdState {
    pub lr: f64,
    pub momentum: f64,
    pub velocity: Vec<f64>,
}

impl SgdState {
    /// State for `n_params` parameters with zero-initialized velocity.
    pub fn new(n_params: usize, lr: f64, momentum: f64) -> Self {
        Self {
            lr,
            momentum,
            velocity: vec![0.0; n_params],
        }
    }
}

/// SGD step: sequential, deterministic.
pub fn sgd_step(params: &mut [f64], grads: &[f64], state: &mut SgdState) {
    for (i, p) in params.iter_mut().enumerate() {
        // v <- momentum * v + g; then p <- p - lr * v
        state.velocity[i] = state.momentum * state.velocity[i] + grads[i];
        *p -= state.lr * state.velocity[i];
    }
}

/// Adam optimizer state.
pub struct AdamState {
    pub lr: f64,
    pub beta1: f64,
    pub beta2: f64,
    pub eps: f64,
    pub t: u64,
    pub m: Vec<f64>,
    pub v: Vec<f64>,
}

impl AdamState {
    /// State with the standard defaults: beta1=0.9, beta2=0.999, eps=1e-8.
    pub fn new(n_params: usize, lr: f64) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            t: 0,
            m: vec![0.0; n_params],
            v: vec![0.0; n_params],
        }
    }
}

/// Adam step: sequential, deterministic.
pub fn adam_step(params: &mut [f64], grads: &[f64], state: &mut AdamState) {
    state.t += 1;
    let t = state.t as f64;
    let (b1, b2) = (state.beta1, state.beta2);
    for (i, p) in params.iter_mut().enumerate() {
        let g = grads[i];
        // Exponential moving averages of the gradient and squared gradient.
        state.m[i] = b1 * state.m[i] + (1.0 - b1) * g;
        state.v[i] = b2 * state.v[i] + (1.0 - b2) * g * g;
        // Bias-corrected first/second moment estimates.
        let m_hat = state.m[i] / (1.0 - b1.powf(t));
        let v_hat = state.v[i] / (1.0 - b2.powf(t));
        *p -= state.lr * m_hat / (v_hat.sqrt() + state.eps);
    }
}

// ---------------------------------------------------------------------------
// Classification metrics (Sprint 6)
// ---------------------------------------------------------------------------

/// Binary confusion matrix.
#[derive(Debug, Clone)]
pub struct ConfusionMatrix {
    pub tp: usize,
    pub fp: usize,
    pub tn: usize,
    pub fn_count: usize,
}

/// Build confusion matrix from predicted and actual boolean labels.
/// NOTE(review): if the slices differ in length, the extra tail entries are
/// silently ignored — confirm callers always pass equal-length slices.
pub fn confusion_matrix(predicted: &[bool], actual: &[bool]) -> ConfusionMatrix {
    let mut cm = ConfusionMatrix { tp: 0, fp: 0, tn: 0, fn_count: 0 };
    for (&pred, &act) in predicted.iter().zip(actual.iter()) {
        match (pred, act) {
            (true, true) => cm.tp += 1,
            (true, false) => cm.fp += 1,
            (false, true) => cm.fn_count += 1,
            (false, false) => cm.tn += 1,
        }
    }
    cm
}

195/// Precision: TP / (TP + FP).
196pub fn precision(cm: &ConfusionMatrix) -> f64 {
197    let denom = cm.tp + cm.fp;
198    if denom == 0 { 0.0 } else { cm.tp as f64 / denom as f64 }
199}
200
201/// Recall / sensitivity: TP / (TP + FN).
202pub fn recall(cm: &ConfusionMatrix) -> f64 {
203    let denom = cm.tp + cm.fn_count;
204    if denom == 0 { 0.0 } else { cm.tp as f64 / denom as f64 }
205}
206
207/// F1 score: 2 * (precision * recall) / (precision + recall).
208pub fn f1_score(cm: &ConfusionMatrix) -> f64 {
209    let p = precision(cm);
210    let r = recall(cm);
211    if p + r == 0.0 { 0.0 } else { 2.0 * p * r / (p + r) }
212}
213
214/// Accuracy: (TP + TN) / total.
215pub fn accuracy(cm: &ConfusionMatrix) -> f64 {
216    let total = cm.tp + cm.fp + cm.tn + cm.fn_count;
217    if total == 0 { 0.0 } else { (cm.tp + cm.tn) as f64 / total as f64 }
218}
219
/// AUC-ROC via trapezoidal rule.
/// DETERMINISM: sort by score with stable sort + index tie-breaking.
///
/// Tied scores are accumulated as a single group before each trapezoid is
/// added. (Previously each sample produced its own trapezoid, so ties were
/// resolved by input index and biased the area — e.g. scores [0.5, 0.5] with
/// labels [true, false] yielded 1.0 instead of the correct 0.5.)
///
/// # Errors
/// Mismatched lengths, empty input, or a single-class label set.
pub fn auc_roc(scores: &[f64], labels: &[bool]) -> Result<f64, String> {
    if scores.len() != labels.len() {
        return Err("auc_roc: scores and labels must have same length".into());
    }
    let n = scores.len();
    if n == 0 {
        return Err("auc_roc: empty data".into());
    }
    // Sort by score descending, stable with index tie-breaking
    let mut indexed: Vec<(usize, f64, bool)> = scores.iter().zip(labels.iter())
        .enumerate()
        .map(|(i, (&s, &l))| (i, s, l))
        .collect();
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));

    let pos_count = labels.iter().filter(|&&l| l).count();
    let neg_count = n - pos_count;
    if pos_count == 0 || neg_count == 0 {
        return Err("auc_roc: need both positive and negative labels".into());
    }

    let mut auc = 0.0;
    let mut tp = 0.0;
    let mut fp = 0.0;
    let mut prev_fpr = 0.0;
    let mut prev_tpr = 0.0;

    let mut i = 0;
    while i < n {
        // Consume the entire run of entries tied at this score so the
        // trapezoid does not depend on the within-tie ordering.
        // (total_cmp equality matches the sort order, including NaN handling.)
        let group_score = indexed[i].1;
        while i < n && indexed[i].1.total_cmp(&group_score).is_eq() {
            if indexed[i].2 { tp += 1.0; } else { fp += 1.0; }
            i += 1;
        }
        let tpr = tp / pos_count as f64;
        let fpr = fp / neg_count as f64;
        auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0;
        prev_fpr = fpr;
        prev_tpr = tpr;
    }
    Ok(auc)
}

260/// K-fold cross-validation indices.
261/// DETERMINISM: uses seeded RNG (Fisher-Yates).
262pub fn kfold_indices(n: usize, k: usize, seed: u64) -> Vec<(Vec<usize>, Vec<usize>)> {
263    let mut rng = cjc_repro::Rng::seeded(seed);
264    // Fisher-Yates shuffle of [0..n]
265    let mut indices: Vec<usize> = (0..n).collect();
266    for i in (1..n).rev() {
267        let j = (rng.next_u64() as usize) % (i + 1);
268        indices.swap(i, j);
269    }
270    let fold_size = n / k;
271    let mut folds = Vec::with_capacity(k);
272    for fold in 0..k {
273        let start = fold * fold_size;
274        let end = if fold == k - 1 { n } else { start + fold_size };
275        let test: Vec<usize> = indices[start..end].to_vec();
276        let train: Vec<usize> = indices[..start].iter()
277            .chain(indices[end..].iter())
278            .copied()
279            .collect();
280        folds.push((train, test));
281    }
282    folds
283}
284
285/// Train/test split indices.
286pub fn train_test_split(n: usize, test_fraction: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
287    let mut rng = cjc_repro::Rng::seeded(seed);
288    let mut indices: Vec<usize> = (0..n).collect();
289    for i in (1..n).rev() {
290        let j = (rng.next_u64() as usize) % (i + 1);
291        indices.swap(i, j);
292    }
293    let test_size = ((n as f64) * test_fraction).round() as usize;
294    let test = indices[..test_size].to_vec();
295    let train = indices[test_size..].to_vec();
296    (train, test)
297}

// ---------------------------------------------------------------------------
// Bootstrap Resampling
// ---------------------------------------------------------------------------

303/// Bootstrap confidence interval for a statistic (e.g., mean).
304/// Returns (point_estimate, ci_lower, ci_upper, standard_error).
305/// `stat_fn` is 0=mean, 1=median.
306pub fn bootstrap(data: &[f64], n_resamples: usize, stat_fn: usize, seed: u64) -> Result<(f64, f64, f64, f64), String> {
307    if data.is_empty() { return Err("bootstrap: empty data".into()); }
308    let n = data.len();
309
310    // Compute the statistic on original data
311    let point = compute_stat(data, stat_fn)?;
312
313    // Bootstrap resampling
314    let mut rng = cjc_repro::Rng::seeded(seed);
315    let mut stats = Vec::with_capacity(n_resamples);
316    let mut resample = Vec::with_capacity(n);
317
318    for _ in 0..n_resamples {
319        resample.clear();
320        for _ in 0..n {
321            let idx = (rng.next_u64() as usize) % n;
322            resample.push(data[idx]);
323        }
324        stats.push(compute_stat(&resample, stat_fn)?);
325    }
326
327    // Sort for percentile CI
328    stats.sort_by(|a, b| a.total_cmp(b));
329
330    let ci_lower = stats[(n_resamples as f64 * 0.025) as usize];
331    let ci_upper = stats[(n_resamples as f64 * 0.975).min((n_resamples - 1) as f64) as usize];
332
333    // Standard error
334    let mean_stats: f64 = {
335        let mut acc = cjc_repro::KahanAccumulatorF64::new();
336        for &s in &stats { acc.add(s); }
337        acc.finalize() / n_resamples as f64
338    };
339    let se = {
340        let mut acc = cjc_repro::KahanAccumulatorF64::new();
341        for &s in &stats { let d = s - mean_stats; acc.add(d * d); }
342        (acc.finalize() / (n_resamples as f64 - 1.0)).sqrt()
343    };
344
345    Ok((point, ci_lower, ci_upper, se))
346}
347
348fn compute_stat(data: &[f64], stat_fn: usize) -> Result<f64, String> {
349    match stat_fn {
350        0 => {
351            // Mean
352            let mut acc = cjc_repro::KahanAccumulatorF64::new();
353            for &x in data { acc.add(x); }
354            Ok(acc.finalize() / data.len() as f64)
355        }
356        1 => {
357            // Median
358            let mut sorted = data.to_vec();
359            sorted.sort_by(|a, b| a.total_cmp(b));
360            let n = sorted.len();
361            if n % 2 == 0 {
362                Ok((sorted[n/2 - 1] + sorted[n/2]) / 2.0)
363            } else {
364                Ok(sorted[n/2])
365            }
366        }
367        _ => Err(format!("bootstrap: unknown stat_fn {}", stat_fn)),
368    }
369}
370
371/// Permutation test: test whether two groups differ on a statistic.
372/// Returns (observed_diff, p_value).
373pub fn permutation_test(x: &[f64], y: &[f64], n_perms: usize, seed: u64) -> Result<(f64, f64), String> {
374    if x.is_empty() || y.is_empty() { return Err("permutation_test: empty group".into()); }
375
376    let nx = x.len();
377    let combined: Vec<f64> = x.iter().chain(y.iter()).copied().collect();
378    let n = combined.len();
379
380    // Observed difference of means
381    let mean_x = compute_stat(x, 0)?;
382    let mean_y = compute_stat(y, 0)?;
383    let observed = (mean_x - mean_y).abs();
384
385    // Permutation
386    let mut rng = cjc_repro::Rng::seeded(seed);
387    let mut count_extreme = 0usize;
388    let mut perm = combined.clone();
389
390    for _ in 0..n_perms {
391        // Fisher-Yates shuffle
392        for i in (1..n).rev() {
393            let j = (rng.next_u64() as usize) % (i + 1);
394            perm.swap(i, j);
395        }
396        let perm_mean_x = compute_stat(&perm[..nx], 0)?;
397        let perm_mean_y = compute_stat(&perm[nx..], 0)?;
398        if (perm_mean_x - perm_mean_y).abs() >= observed {
399            count_extreme += 1;
400        }
401    }
402
403    let p_value = count_extreme as f64 / n_perms as f64;
404    Ok((observed, p_value))
405}
406
407/// Stratified train/test split: maintains class proportions in both sets.
408/// `labels` is an array of integer class labels, `test_frac` is fraction for test set.
409/// Returns (train_indices, test_indices).
410pub fn stratified_split(labels: &[i64], test_frac: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
411    use std::collections::BTreeMap;
412
413    let n = labels.len();
414    // Group indices by label
415    let mut groups: BTreeMap<i64, Vec<usize>> = BTreeMap::new();
416    for (i, &label) in labels.iter().enumerate() {
417        groups.entry(label).or_default().push(i);
418    }
419
420    let mut train = Vec::with_capacity(n);
421    let mut test = Vec::with_capacity(n);
422    let mut rng = cjc_repro::Rng::seeded(seed);
423
424    for (_label, mut indices) in groups {
425        // Shuffle indices within each stratum
426        let m = indices.len();
427        for i in (1..m).rev() {
428            let j = (rng.next_u64() as usize) % (i + 1);
429            indices.swap(i, j);
430        }
431        let n_test = ((m as f64 * test_frac).round() as usize).max(if m > 1 { 1 } else { 0 });
432        let n_test = n_test.min(m);
433        test.extend_from_slice(&indices[..n_test]);
434        train.extend_from_slice(&indices[n_test..]);
435    }
436
437    // Sort both for deterministic output order
438    train.sort();
439    test.sort();
440    (train, test)
441}

// ---------------------------------------------------------------------------
// Phase B4: ML Training Extensions
// ---------------------------------------------------------------------------

/// Batch normalization (inference mode).
/// y = gamma * (x - running_mean) / sqrt(running_var + eps) + beta.
pub fn batch_norm(
    x: &[f64],
    running_mean: &[f64],
    running_var: &[f64],
    gamma: &[f64],
    beta: &[f64],
    eps: f64,
) -> Result<Vec<f64>, String> {
    let n = x.len();
    let same_len = running_mean.len() == n
        && running_var.len() == n
        && gamma.len() == n
        && beta.len() == n;
    if !same_len {
        return Err("batch_norm: all arrays must have same length".into());
    }
    let out = (0..n)
        .map(|i| {
            // Normalize with the running statistics, then apply the affine
            // scale/shift.
            let z = (x[i] - running_mean[i]) / (running_var[i] + eps).sqrt();
            gamma[i] * z + beta[i]
        })
        .collect();
    Ok(out)
}

469/// Dropout mask generation using seeded RNG for determinism.
470/// Returns mask of 0.0 and scale values (1/(1-p)) using inverted dropout.
471pub fn dropout_mask(n: usize, drop_prob: f64, seed: u64) -> Vec<f64> {
472    let mut rng = cjc_repro::Rng::seeded(seed);
473    let scale = if drop_prob < 1.0 { 1.0 / (1.0 - drop_prob) } else { 0.0 };
474    let mut mask = Vec::with_capacity(n);
475    for _ in 0..n {
476        let r = (rng.next_u64() as f64) / (u64::MAX as f64);
477        if r < drop_prob {
478            mask.push(0.0);
479        } else {
480            mask.push(scale);
481        }
482    }
483    mask
484}
485
/// Apply dropout: element-wise multiply data by mask.
pub fn apply_dropout(data: &[f64], mask: &[f64]) -> Result<Vec<f64>, String> {
    if data.len() != mask.len() {
        return Err("apply_dropout: data and mask must have same length".into());
    }
    let mut out = Vec::with_capacity(data.len());
    for (&value, &gate) in data.iter().zip(mask.iter()) {
        out.push(value * gate);
    }
    Ok(out)
}

/// Learning rate schedule: step decay.
/// lr = initial_lr * decay_rate^(floor(epoch / step_size))
///
/// A `step_size` of 0 returns `initial_lr` unchanged (previously this
/// panicked on integer division by zero).
pub fn lr_step_decay(initial_lr: f64, decay_rate: f64, epoch: usize, step_size: usize) -> f64 {
    if step_size == 0 {
        return initial_lr;
    }
    initial_lr * decay_rate.powi((epoch / step_size) as i32)
}

/// Learning rate schedule: cosine annealing.
/// lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * epoch / total_epochs))
///
/// `total_epochs == 0` returns `min_lr` (the fully-annealed rate); previously
/// the 0/0 division produced NaN.
pub fn lr_cosine(max_lr: f64, min_lr: f64, epoch: usize, total_epochs: usize) -> f64 {
    if total_epochs == 0 {
        return min_lr;
    }
    let ratio = epoch as f64 / total_epochs as f64;
    min_lr + 0.5 * (max_lr - min_lr) * (1.0 + (std::f64::consts::PI * ratio).cos())
}

/// Learning rate schedule: linear warmup.
/// lr = initial_lr * min(1.0, epoch / warmup_epochs).
pub fn lr_linear_warmup(initial_lr: f64, epoch: usize, warmup_epochs: usize) -> f64 {
    match warmup_epochs {
        // No warmup phase: full rate immediately.
        0 => initial_lr,
        w => {
            let frac = (epoch as f64 / w as f64).min(1.0);
            initial_lr * frac
        }
    }
}

516/// L1 regularization penalty: lambda * sum(|params|).
517pub fn l1_penalty(params: &[f64], lambda: f64) -> f64 {
518    let mut acc = KahanAccumulatorF64::new();
519    for &p in params {
520        acc.add(p.abs());
521    }
522    lambda * acc.finalize()
523}
524
525/// L2 regularization penalty: 0.5 * lambda * sum(params^2).
526pub fn l2_penalty(params: &[f64], lambda: f64) -> f64 {
527    let mut acc = KahanAccumulatorF64::new();
528    for &p in params {
529        acc.add(p * p);
530    }
531    0.5 * lambda * acc.finalize()
532}
533
/// L1 regularization gradient: lambda * sign(params).
pub fn l1_grad(params: &[f64], lambda: f64) -> Vec<f64> {
    params
        .iter()
        .map(|&w| {
            // Subgradient at 0 (and for NaN) is taken as 0.
            match w.partial_cmp(&0.0) {
                Some(std::cmp::Ordering::Greater) => lambda,
                Some(std::cmp::Ordering::Less) => -lambda,
                _ => 0.0,
            }
        })
        .collect()
}

/// L2 regularization gradient: lambda * params.
pub fn l2_grad(params: &[f64], lambda: f64) -> Vec<f64> {
    let mut grad = Vec::with_capacity(params.len());
    for &w in params {
        grad.push(lambda * w);
    }
    grad
}

/// Early stopping state tracker.
pub struct EarlyStoppingState {
    /// Number of consecutive non-improving checks tolerated before stopping.
    pub patience: usize,
    /// Minimum decrease in loss that counts as an improvement.
    pub min_delta: f64,
    /// Best (lowest) loss observed so far.
    pub best_loss: f64,
    /// Consecutive non-improving checks since the last improvement.
    pub wait: usize,
    /// Whether the stop condition has been reached (latches true).
    pub stopped: bool,
}

impl EarlyStoppingState {
    /// New tracker with no observations yet (best loss = +infinity).
    pub fn new(patience: usize, min_delta: f64) -> Self {
        Self {
            patience,
            min_delta,
            best_loss: f64::INFINITY,
            wait: 0,
            stopped: false,
        }
    }

    /// Check if training should stop.
    /// Records `current_loss`; returns true once `patience` consecutive
    /// checks have failed to improve on the best loss by more than `min_delta`.
    pub fn check(&mut self, current_loss: f64) -> bool {
        let improved = current_loss < self.best_loss - self.min_delta;
        if improved {
            self.best_loss = current_loss;
            self.wait = 0;
        } else {
            self.wait += 1;
        }
        // `stopped` latches: once set it is never cleared.
        self.stopped = self.stopped || self.wait >= self.patience;
        self.stopped
    }
}

// ---------------------------------------------------------------------------
// Phase 3C: PCA (Principal Component Analysis)
// ---------------------------------------------------------------------------

/// Principal Component Analysis via SVD of centered data.
///
/// `data` is a 2D Tensor of shape (n_samples, n_features).
/// `n_components` is the number of principal components to keep.
///
/// Returns (transformed_data, components, explained_variance_ratio):
/// - `transformed_data`: (n_samples, n_components) — data projected onto principal components
/// - `components`: (n_components, n_features) — principal component directions (rows)
/// - `explained_variance_ratio`: Vec<f64> of length n_components — fraction of variance per component
///
/// **Determinism contract:** All reductions use `BinnedAccumulatorF64`.
pub fn pca(
    data: &Tensor,
    n_components: usize,
) -> Result<(Tensor, Tensor, Vec<f64>), RuntimeError> {
    if data.ndim() != 2 {
        return Err(RuntimeError::InvalidOperation(
            "PCA requires a 2D data matrix".to_string(),
        ));
    }
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];

    if n_samples == 0 || n_features == 0 {
        return Err(RuntimeError::InvalidOperation(
            "PCA: empty data matrix".to_string(),
        ));
    }
    if n_components == 0 || n_components > n_features.min(n_samples) {
        return Err(RuntimeError::InvalidOperation(format!(
            "PCA: n_components ({}) must be in [1, min(n_samples, n_features) = {}]",
            n_components,
            n_features.min(n_samples)
        )));
    }

    // Flattened copy of the input; indexed as raw[i * n_features + j]
    // (row-major) throughout this function.
    let raw = data.to_vec();

    // Step 1: Compute column means using BinnedAccumulatorF64
    let mut means = vec![0.0f64; n_features];
    for j in 0..n_features {
        let mut acc = BinnedAccumulatorF64::new();
        for i in 0..n_samples {
            acc.add(raw[i * n_features + j]);
        }
        means[j] = acc.finalize() / n_samples as f64;
    }

    // Step 2: Center the data (subtract each column's mean)
    let mut centered = vec![0.0f64; n_samples * n_features];
    for i in 0..n_samples {
        for j in 0..n_features {
            centered[i * n_features + j] = raw[i * n_features + j] - means[j];
        }
    }
    let centered_tensor = Tensor::from_vec(centered, &[n_samples, n_features])?;

    // Step 3: SVD of centered data
    // NOTE(review): assumes `svd()` returns singular values in descending
    // order (so the first k are the dominant components) — confirm against
    // the Tensor::svd contract.
    let (u, s, vt) = centered_tensor.svd()?;
    let k = n_components.min(s.len());

    // Step 4: Components = first k rows of Vt
    let vt_data = vt.to_vec();
    let vt_cols = vt.shape()[1]; // n_features
    let mut components = vec![0.0f64; k * n_features];
    for i in 0..k {
        for j in 0..n_features {
            components[i * n_features + j] = vt_data[i * vt_cols + j];
        }
    }

    // Step 5: Explained variance = s_i^2 / (n_samples - 1)
    // Total variance = sum of all s_i^2 / (n_samples - 1)
    let denom = if n_samples > 1 {
        (n_samples - 1) as f64
    } else {
        1.0 // single sample: avoid dividing by zero
    };

    let mut total_var_acc = BinnedAccumulatorF64::new();
    for &si in &s {
        total_var_acc.add(si * si / denom);
    }
    let total_var = total_var_acc.finalize();

    // Near-zero total variance (constant data): report zero ratios rather
    // than dividing by ~0.
    let explained_variance_ratio: Vec<f64> = if total_var > 1e-15 {
        s[..k]
            .iter()
            .map(|&si| (si * si / denom) / total_var)
            .collect()
    } else {
        vec![0.0; k]
    };

    // Step 6: Transformed data = U_k @ diag(S_k)
    let u_data = u.to_vec();
    let u_cols = u.shape()[1];
    let mut transformed = vec![0.0f64; n_samples * k];
    for i in 0..n_samples {
        for j in 0..k {
            transformed[i * k + j] = u_data[i * u_cols + j] * s[j];
        }
    }

    Ok((
        Tensor::from_vec(transformed, &[n_samples, k])?,
        Tensor::from_vec(components, &[k, n_features])?,
        explained_variance_ratio,
    ))
}

// ---------------------------------------------------------------------------
// L-BFGS Optimizer (Sprint 2)
// ---------------------------------------------------------------------------

/// L-BFGS optimizer state.
///
/// Limited-memory Broyden-Fletcher-Goldfarb-Shanno quasi-Newton method.
/// Maintains a history of `m` most recent (s, y) pairs to approximate
/// the inverse Hessian without storing it explicitly.
pub struct LbfgsState {
    pub lr: f64,
    /// History size (default 10)
    pub m: usize,
    /// Parameter differences: s_k = params_{k+1} - params_k
    pub s_history: Vec<Vec<f64>>,
    /// Gradient differences: y_k = grad_{k+1} - grad_k
    pub y_history: Vec<Vec<f64>>,
    pub prev_params: Option<Vec<f64>>,
    pub prev_grad: Option<Vec<f64>>,
}

impl LbfgsState {
    /// Fresh state with empty history; `m` bounds the stored (s, y) pairs.
    pub fn new(lr: f64, m: usize) -> Self {
        Self {
            lr,
            m,
            prev_params: None,
            prev_grad: None,
            s_history: Vec::new(),
            y_history: Vec::new(),
        }
    }
}

730/// Dot product of two slices using Kahan summation for determinism.
731fn kahan_dot(a: &[f64], b: &[f64]) -> f64 {
732    debug_assert_eq!(a.len(), b.len());
733    let mut acc = KahanAccumulatorF64::new();
734    for (&ai, &bi) in a.iter().zip(b.iter()) {
735        acc.add(ai * bi);
736    }
737    acc.finalize()
738}
739
/// Strong Wolfe line search.
///
/// Finds a step length `alpha` satisfying both:
/// - Armijo (sufficient decrease): f(x + alpha*d) <= f(x) + c1*alpha*g^T*d
/// - Curvature condition: |g(x + alpha*d)^T*d| <= c2*|g(x)^T*d|
///
/// Uses backtracking bracketing followed by bisection zoom.
///
/// # Arguments
/// * `params` - Current parameters
/// * `direction` - Search direction (should be a descent direction)
/// * `f` - Function that returns (value, gradient) at a parameter vector
/// * `f0` - Current function value
/// * `g0` - Current gradient
/// * `alpha_init` - Initial step length (typically `lr`)
///
/// # Returns
/// `(alpha, new_params, new_val, new_grad)` or the best found so far on failure.
pub fn wolfe_line_search<F>(
    params: &[f64],
    direction: &[f64],
    f: &mut F,
    f0: f64,
    g0: &[f64],
    alpha_init: f64,
) -> (f64, Vec<f64>, f64, Vec<f64>)
where
    F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
    // Standard Wolfe constants: loose sufficient-decrease, tight curvature.
    let c1 = 1e-4_f64;
    let c2 = 0.9_f64;
    let derphi0 = kahan_dot(g0, direction); // should be negative for descent

    // Helper: evaluate f at params + alpha * direction
    let step = |alpha: f64| -> Vec<f64> {
        params.iter().zip(direction.iter()).map(|(&p, &d)| p + alpha * d).collect()
    };

    let max_iter = 30;
    let mut alpha_lo = 0.0_f64;
    let mut alpha_hi = f64::INFINITY;
    let mut phi_lo = f0;
    let mut dphi_lo = derphi0;
    let mut alpha = alpha_init;

    // Keep track of best valid alpha found (fallback)
    // NOTE(review): this costs one extra evaluation of `f` at alpha_init —
    // the first loop iteration re-evaluates the same point.
    let mut best_alpha = alpha_init;
    let mut best_params = step(alpha_init);
    let (mut best_val, mut best_grad) = f(&best_params);

    for _iter in 0..max_iter {
        let x_alpha = step(alpha);
        let (phi_alpha, grad_alpha) = f(&x_alpha);
        let dphi_alpha = kahan_dot(&grad_alpha, direction);

        // Track the best point for fallback
        if phi_alpha < best_val {
            best_alpha = alpha;
            best_params = x_alpha.clone();
            best_val = phi_alpha;
            best_grad = grad_alpha.clone();
        }

        // Armijo condition violated or function increased beyond bracket — shrink hi
        if phi_alpha > f0 + c1 * alpha * derphi0 || (phi_alpha >= phi_lo && alpha_lo > 0.0) {
            alpha_hi = alpha;
            // Zoom between alpha_lo and alpha_hi
            let (za, zp, zv, zg) = wolfe_zoom(
                params, direction, f, f0, derphi0, c1, c2,
                alpha_lo, alpha_hi, phi_lo, dphi_lo,
            );
            return (za, zp, zv, zg);
        }

        // Strong Wolfe curvature condition satisfied — done
        if dphi_alpha.abs() <= c2 * derphi0.abs() {
            return (alpha, x_alpha, phi_alpha, grad_alpha);
        }

        // Derivative is positive — zoom between alpha and alpha_lo
        // (note the swapped bracket order: alpha becomes the new "lo" side).
        if dphi_alpha >= 0.0 {
            let (za, zp, zv, zg) = wolfe_zoom(
                params, direction, f, f0, derphi0, c1, c2,
                alpha, alpha_lo, phi_alpha, dphi_alpha,
            );
            return (za, zp, zv, zg);
        }

        // Expand: step is too short
        alpha_lo = alpha;
        phi_lo = phi_alpha;
        dphi_lo = dphi_alpha;

        // Expand by factor 2 (capped to prevent explosion)
        // NOTE(review): alpha_hi only becomes finite via the zoom branches,
        // which return immediately — so in practice this arm always doubles.
        alpha = if alpha_hi.is_finite() {
            (alpha_lo + alpha_hi) * 0.5
        } else {
            (alpha * 2.0).min(1e8)
        };
    }

    // Return best found
    (best_alpha, best_params, best_val, best_grad)
}

/// Zoom phase of Wolfe line search — bisects between [alpha_lo, alpha_hi].
///
/// Invariant: `alpha_lo` is the endpoint satisfying the Armijo condition with
/// the lowest function value seen; `alpha_hi` brackets from the other side.
/// The bracket is NOT required to satisfy alpha_lo < alpha_hi (callers may
/// pass the endpoints in either order), which is why the width check below
/// uses `abs()`.
#[allow(clippy::too_many_arguments)]
fn wolfe_zoom<F>(
    params: &[f64],
    direction: &[f64],
    f: &mut F,
    f0: f64,
    derphi0: f64,
    c1: f64,
    c2: f64,
    mut alpha_lo: f64,
    mut alpha_hi: f64,
    mut phi_lo: f64,
    _dphi_lo: f64,
) -> (f64, Vec<f64>, f64, Vec<f64>)
where
    F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
    // Evaluate params + alpha * direction.
    let step = |alpha: f64| -> Vec<f64> {
        params.iter().zip(direction.iter()).map(|(&p, &d)| p + alpha * d).collect()
    };

    let max_zoom = 20;
    // Fallback: best point found so far, seeded at alpha_lo (one extra f call).
    let mut best_alpha = alpha_lo;
    let mut best_x = step(alpha_lo);
    let (mut best_val, mut best_grad) = f(&best_x);

    for _i in 0..max_zoom {
        // Bisect the bracket (no interpolation — simple and deterministic).
        let alpha_j = (alpha_lo + alpha_hi) * 0.5;
        let x_j = step(alpha_j);
        let (phi_j, grad_j) = f(&x_j);
        let dphi_j = kahan_dot(&grad_j, direction);

        if phi_j < best_val {
            best_alpha = alpha_j;
            best_x = x_j.clone();
            best_val = phi_j;
            best_grad = grad_j.clone();
        }

        if phi_j > f0 + c1 * alpha_j * derphi0 || phi_j >= phi_lo {
            alpha_hi = alpha_j;
        } else {
            // Curvature condition satisfied
            if dphi_j.abs() <= c2 * derphi0.abs() {
                return (alpha_j, x_j, phi_j, grad_j);
            }
            // Derivative sign says the minimum lies on the lo side:
            // flip the hi endpoint before advancing lo.
            if dphi_j * (alpha_hi - alpha_lo) >= 0.0 {
                alpha_hi = alpha_lo;
            }
            alpha_lo = alpha_j;
            phi_lo = phi_j;
        }

        // Convergence check on bracket width
        if (alpha_hi - alpha_lo).abs() < 1e-14 {
            break;
        }
    }

    (best_alpha, best_x, best_val, best_grad)
}

/// L-BFGS step with strong Wolfe line search.
///
/// Given current parameters and gradients (and a closure to evaluate the
/// function and gradient at any point), computes the L-BFGS search direction
/// via the two-loop recursion, performs a Wolfe line search along that
/// direction, then updates the (s, y) history buffers.
///
/// # Arguments
/// * `params` - Current parameter vector (read-only; the updated vector is
///   returned — it is NOT modified in place)
/// * `grads` - Gradient at current params
/// * `state` - L-BFGS state (history + meta-parameters)
/// * `f` - Closure: takes parameter slice, returns `(loss, gradient_vec)`
///
/// # Returns
/// `(new_params, new_grads, step_taken)` — step_taken is false if the search
/// direction was not a descent direction (resets to gradient descent).
pub fn lbfgs_step<F>(
    params: &[f64],
    grads: &[f64],
    state: &mut LbfgsState,
    mut f: F,
) -> (Vec<f64>, Vec<f64>, bool)
where
    F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
    let n = params.len();
    debug_assert_eq!(grads.len(), n);

    // ---- Two-loop L-BFGS recursion ----
    // Computes H_k * g using the compact L-BFGS inverse Hessian approximation.
    let hist_len = state.s_history.len();
    let mut q: Vec<f64> = grads.to_vec();
    let mut alphas = vec![0.0_f64; hist_len];
    let mut rhos = vec![0.0_f64; hist_len];

    // Loop 1: backward through history
    for i in (0..hist_len).rev() {
        let sy = kahan_dot(&state.s_history[i], &state.y_history[i]);
        // Guard division by a (near-)zero curvature product; rho = 0
        // effectively skips this history pair in both loops.
        rhos[i] = if sy.abs() < 1e-300 { 0.0 } else { 1.0 / sy };
        let sq = kahan_dot(&state.s_history[i], &q);
        alphas[i] = rhos[i] * sq;
        // q -= alpha_i * y_i
        for j in 0..n {
            q[j] -= alphas[i] * state.y_history[i][j];
        }
    }

    // Scale by initial Hessian approximation H_0 = (s^T y / y^T y) * I
    let scale = if hist_len > 0 {
        let last = hist_len - 1;
        let sy = kahan_dot(&state.s_history[last], &state.y_history[last]);
        let yy = kahan_dot(&state.y_history[last], &state.y_history[last]);
        if yy.abs() < 1e-300 { 1.0 } else { sy / yy }
    } else {
        1.0
    };
    let mut r: Vec<f64> = q.iter().map(|&qi| scale * qi).collect();

    // Loop 2: forward through history
    for i in 0..hist_len {
        let yr = kahan_dot(&state.y_history[i], &r);
        let beta = rhos[i] * yr;
        // r += (alpha_i - beta) * s_i
        let diff = alphas[i] - beta;
        for j in 0..n {
            r[j] += diff * state.s_history[i][j];
        }
    }

    // Search direction: d = -H_k * g = -r
    let direction: Vec<f64> = r.iter().map(|&ri| -ri).collect();

    // Check descent condition: d^T g < 0
    let descent_check = kahan_dot(&direction, grads);
    let (direction, is_descent) = if descent_check >= 0.0 || !descent_check.is_finite() {
        // Fall back to the unit-norm negative gradient; max(1e-300) avoids
        // dividing by zero when the gradient itself is zero.
        let norm_g = kahan_dot(grads, grads).sqrt().max(1e-300);
        (grads.iter().map(|&g| -g / norm_g).collect::<Vec<f64>>(), false)
    } else {
        (direction, true)
    };

    // ---- Wolfe line search ----
    // NOTE(review): this re-evaluates f at the current params solely to get
    // the loss value f0; the gradient it returns is discarded because the
    // caller already supplied `grads`.
    let (f0, _) = f(params);
    let (_, new_params, _, new_grads) = wolfe_line_search(
        params,
        &direction,
        &mut f,
        f0,
        grads,
        state.lr,
    );

    // ---- Update history ----
    // s_k = new_params - params
    let s_k: Vec<f64> = new_params.iter().zip(params.iter()).map(|(&np, &p)| np - p).collect();
    // y_k = new_grads - grads
    let y_k: Vec<f64> = new_grads.iter().zip(grads.iter()).map(|(&ng, &g)| ng - g).collect();

    let sy = kahan_dot(&s_k, &y_k);
    // Only add to history if curvature condition holds (sy > 0)
    if sy > 1e-300 {
        state.s_history.push(s_k);
        state.y_history.push(y_k);
        // Trim to m most recent
        if state.s_history.len() > state.m {
            state.s_history.remove(0);
            state.y_history.remove(0);
        }
    }

    state.prev_params = Some(new_params.clone());
    state.prev_grad = Some(new_grads.clone());

    (new_params, new_grads, is_descent)
}
1024
1025// ---------------------------------------------------------------------------
1026// LSTM cell forward pass
1027// ---------------------------------------------------------------------------
1028
1029/// LSTM cell forward pass.
1030///
1031/// * `x`:      `[batch, input_size]`
1032/// * `h_prev`: `[batch, hidden_size]`
1033/// * `c_prev`: `[batch, hidden_size]`
1034/// * `w_ih`:   `[4*hidden_size, input_size]`
1035/// * `w_hh`:   `[4*hidden_size, hidden_size]`
1036/// * `b_ih`:   `[4*hidden_size]`
1037/// * `b_hh`:   `[4*hidden_size]`
1038///
1039/// Returns `(h_new, c_new)`.
1040///
1041/// Gate layout (PyTorch convention): i, f, g, o — each of size `hidden_size`.
1042/// All reductions use Kahan summation via the existing `Tensor::linear` method.
1043pub fn lstm_cell(
1044    x: &Tensor,
1045    h_prev: &Tensor,
1046    c_prev: &Tensor,
1047    w_ih: &Tensor,
1048    w_hh: &Tensor,
1049    b_ih: &Tensor,
1050    b_hh: &Tensor,
1051) -> Result<(Tensor, Tensor), String> {
1052    let map_err = |e: crate::error::RuntimeError| format!("{e}");
1053
1054    // Validate shapes
1055    if x.ndim() != 2 {
1056        return Err("lstm_cell: x must be 2-D [batch, input_size]".into());
1057    }
1058    if h_prev.ndim() != 2 {
1059        return Err("lstm_cell: h_prev must be 2-D [batch, hidden_size]".into());
1060    }
1061    if c_prev.ndim() != 2 {
1062        return Err("lstm_cell: c_prev must be 2-D [batch, hidden_size]".into());
1063    }
1064    let hidden_size = h_prev.shape()[1];
1065    if w_ih.ndim() != 2 || w_ih.shape()[0] != 4 * hidden_size {
1066        return Err(format!(
1067            "lstm_cell: w_ih must be [4*hidden_size, input_size], got {:?}",
1068            w_ih.shape()
1069        ));
1070    }
1071    if w_hh.ndim() != 2 || w_hh.shape()[0] != 4 * hidden_size {
1072        return Err(format!(
1073            "lstm_cell: w_hh must be [4*hidden_size, hidden_size], got {:?}",
1074            w_hh.shape()
1075        ));
1076    }
1077    if b_ih.len() != 4 * hidden_size {
1078        return Err(format!(
1079            "lstm_cell: b_ih must have length 4*hidden_size={}, got {}",
1080            4 * hidden_size,
1081            b_ih.len()
1082        ));
1083    }
1084    if b_hh.len() != 4 * hidden_size {
1085        return Err(format!(
1086            "lstm_cell: b_hh must have length 4*hidden_size={}, got {}",
1087            4 * hidden_size,
1088            b_hh.len()
1089        ));
1090    }
1091
1092    // gates = x @ w_ih^T + b_ih + h_prev @ w_hh^T + b_hh
1093    // Use linear which does: input @ weight^T + bias
1094    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
1095    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
1096    let gates = gates_ih.add(&gates_hh).map_err(map_err)?;
1097
1098    // gates is [batch, 4*hidden_size]. Split into 4 chunks along dim 1.
1099    let chunks = gates.chunk(4, 1).map_err(map_err)?;
1100    let gates_i = &chunks[0];
1101    let gates_f = &chunks[1];
1102    let gates_g = &chunks[2];
1103    let gates_o = &chunks[3];
1104
1105    // Apply activations
1106    let i = gates_i.sigmoid();
1107    let f = gates_f.sigmoid();
1108    let g = gates_g.tanh_activation();
1109    let o = gates_o.sigmoid();
1110
1111    // c_new = f * c_prev + i * g
1112    let fc = f.mul_elem(c_prev).map_err(map_err)?;
1113    let ig = i.mul_elem(&g).map_err(map_err)?;
1114    let c_new = fc.add(&ig).map_err(map_err)?;
1115
1116    // h_new = o * tanh(c_new)
1117    let c_tanh = c_new.tanh_activation();
1118    let h_new = o.mul_elem(&c_tanh).map_err(map_err)?;
1119
1120    Ok((h_new, c_new))
1121}
1122
1123// ---------------------------------------------------------------------------
1124// GRU cell forward pass
1125// ---------------------------------------------------------------------------
1126
1127/// GRU cell forward pass.
1128///
1129/// * `x`:      `[batch, input_size]`
1130/// * `h_prev`: `[batch, hidden_size]`
1131/// * `w_ih`:   `[3*hidden_size, input_size]`
1132/// * `w_hh`:   `[3*hidden_size, hidden_size]`
1133/// * `b_ih`:   `[3*hidden_size]`
1134/// * `b_hh`:   `[3*hidden_size]`
1135///
1136/// Returns `h_new` tensor of shape `[batch, hidden_size]`.
1137///
1138/// Gate layout: r (reset), z (update), n (new) — each of size `hidden_size`.
1139/// All reductions use Kahan summation via the existing `Tensor::linear` method.
1140pub fn gru_cell(
1141    x: &Tensor,
1142    h_prev: &Tensor,
1143    w_ih: &Tensor,
1144    w_hh: &Tensor,
1145    b_ih: &Tensor,
1146    b_hh: &Tensor,
1147) -> Result<Tensor, String> {
1148    let map_err = |e: crate::error::RuntimeError| format!("{e}");
1149
1150    // Validate shapes
1151    if x.ndim() != 2 {
1152        return Err("gru_cell: x must be 2-D [batch, input_size]".into());
1153    }
1154    if h_prev.ndim() != 2 {
1155        return Err("gru_cell: h_prev must be 2-D [batch, hidden_size]".into());
1156    }
1157    let hidden_size = h_prev.shape()[1];
1158    if w_ih.ndim() != 2 || w_ih.shape()[0] != 3 * hidden_size {
1159        return Err(format!(
1160            "gru_cell: w_ih must be [3*hidden_size, input_size], got {:?}",
1161            w_ih.shape()
1162        ));
1163    }
1164    if w_hh.ndim() != 2 || w_hh.shape()[0] != 3 * hidden_size {
1165        return Err(format!(
1166            "gru_cell: w_hh must be [3*hidden_size, hidden_size], got {:?}",
1167            w_hh.shape()
1168        ));
1169    }
1170    if b_ih.len() != 3 * hidden_size {
1171        return Err(format!(
1172            "gru_cell: b_ih must have length 3*hidden_size={}, got {}",
1173            3 * hidden_size,
1174            b_ih.len()
1175        ));
1176    }
1177    if b_hh.len() != 3 * hidden_size {
1178        return Err(format!(
1179            "gru_cell: b_hh must have length 3*hidden_size={}, got {}",
1180            3 * hidden_size,
1181            b_hh.len()
1182        ));
1183    }
1184
1185    // gates_ih = x @ w_ih^T + b_ih   → [batch, 3*hidden_size]
1186    // gates_hh = h_prev @ w_hh^T + b_hh → [batch, 3*hidden_size]
1187    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
1188    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
1189
1190    // Split into r, z, n portions
1191    let ih_chunks = gates_ih.chunk(3, 1).map_err(map_err)?;
1192    let hh_chunks = gates_hh.chunk(3, 1).map_err(map_err)?;
1193    let r_ih = &ih_chunks[0];
1194    let z_ih = &ih_chunks[1];
1195    let n_ih = &ih_chunks[2];
1196    let r_hh = &hh_chunks[0];
1197    let z_hh = &hh_chunks[1];
1198    let n_hh = &hh_chunks[2];
1199
1200    // r = sigmoid(r_ih + r_hh)
1201    let r = r_ih.add(r_hh).map_err(map_err)?.sigmoid();
1202    // z = sigmoid(z_ih + z_hh)
1203    let z = z_ih.add(z_hh).map_err(map_err)?.sigmoid();
1204    // n = tanh(n_ih + r * n_hh)
1205    let r_n_hh = r.mul_elem(n_hh).map_err(map_err)?;
1206    let n = n_ih.add(&r_n_hh).map_err(map_err)?.tanh_activation();
1207
1208    // h_new = (1 - z) * n + z * h_prev
1209    // Build (1 - z): negate z then add ones
1210    let ones = Tensor::ones(z.shape());
1211    let one_minus_z = ones.sub(&z).map_err(map_err)?;
1212    let term1 = one_minus_z.mul_elem(&n).map_err(map_err)?;
1213    let term2 = z.mul_elem(h_prev).map_err(map_err)?;
1214    let h_new = term1.add(&term2).map_err(map_err)?;
1215
1216    Ok(h_new)
1217}
1218
1219// ---------------------------------------------------------------------------
1220// Fused LSTM cell forward pass (allocation-optimized)
1221// ---------------------------------------------------------------------------
1222
1223/// Fused LSTM cell: minimizes intermediate tensor allocations.
1224///
1225/// Produces bit-identical results to [`lstm_cell`] but reduces tensor
1226/// allocations from 13 to 4 (2 matmuls via `linear`, 2 output `from_vec`).
1227/// After the two `linear` calls the gate combination, activations, and cell /
1228/// hidden-state updates are computed element-wise in a single scalar loop
1229/// with no additional tensor temporaries.
1230///
1231/// Shape requirements are identical to [`lstm_cell`].
1232pub fn lstm_cell_fused(
1233    x: &Tensor,
1234    h_prev: &Tensor,
1235    c_prev: &Tensor,
1236    w_ih: &Tensor,
1237    w_hh: &Tensor,
1238    b_ih: &Tensor,
1239    b_hh: &Tensor,
1240) -> Result<(Tensor, Tensor), String> {
1241    let map_err = |e: crate::error::RuntimeError| format!("{e}");
1242
1243    // --- Validate shapes (same checks as lstm_cell) --------------------------
1244    if x.ndim() != 2 {
1245        return Err("lstm_cell_fused: x must be 2-D [batch, input_size]".into());
1246    }
1247    if h_prev.ndim() != 2 {
1248        return Err("lstm_cell_fused: h_prev must be 2-D [batch, hidden_size]".into());
1249    }
1250    if c_prev.ndim() != 2 {
1251        return Err("lstm_cell_fused: c_prev must be 2-D [batch, hidden_size]".into());
1252    }
1253    let batch = x.shape()[0];
1254    let hidden_size = h_prev.shape()[1];
1255    if w_ih.ndim() != 2 || w_ih.shape()[0] != 4 * hidden_size {
1256        return Err(format!(
1257            "lstm_cell_fused: w_ih must be [4*hidden_size, input_size], got {:?}",
1258            w_ih.shape()
1259        ));
1260    }
1261    if w_hh.ndim() != 2 || w_hh.shape()[0] != 4 * hidden_size {
1262        return Err(format!(
1263            "lstm_cell_fused: w_hh must be [4*hidden_size, hidden_size], got {:?}",
1264            w_hh.shape()
1265        ));
1266    }
1267    if b_ih.len() != 4 * hidden_size {
1268        return Err(format!(
1269            "lstm_cell_fused: b_ih must have length 4*hidden_size={}, got {}",
1270            4 * hidden_size,
1271            b_ih.len()
1272        ));
1273    }
1274    if b_hh.len() != 4 * hidden_size {
1275        return Err(format!(
1276            "lstm_cell_fused: b_hh must have length 4*hidden_size={}, got {}",
1277            4 * hidden_size,
1278            b_hh.len()
1279        ));
1280    }
1281
1282    // --- Step 1: two matmuls (reuse existing Kahan-summation linear) ---------
1283    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
1284    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
1285
1286    // --- Step 2: fused gate combination + activations + state update ---------
1287    let gih = gates_ih.to_vec();
1288    let ghh = gates_hh.to_vec();
1289    let cprev = c_prev.to_vec();
1290
1291    let mut h_new_data = vec![0.0f64; batch * hidden_size];
1292    let mut c_new_data = vec![0.0f64; batch * hidden_size];
1293
1294    for b_idx in 0..batch {
1295        let base = b_idx * 4 * hidden_size;
1296        for h in 0..hidden_size {
1297            // Combine ih + hh for each gate (i, f, g, o)
1298            let gi = gih[base + h] + ghh[base + h];
1299            let gf = gih[base + hidden_size + h] + ghh[base + hidden_size + h];
1300            let gg = gih[base + 2 * hidden_size + h] + ghh[base + 2 * hidden_size + h];
1301            let go = gih[base + 3 * hidden_size + h] + ghh[base + 3 * hidden_size + h];
1302
1303            // Activations (scalar, no tensor allocation)
1304            let i_val = 1.0 / (1.0 + (-gi).exp()); // sigmoid
1305            let f_val = 1.0 / (1.0 + (-gf).exp()); // sigmoid
1306            let g_val = gg.tanh();                   // tanh
1307            let o_val = 1.0 / (1.0 + (-go).exp()); // sigmoid
1308
1309            // Cell and hidden state update
1310            let c_idx = b_idx * hidden_size + h;
1311            let c_val = f_val * cprev[c_idx] + i_val * g_val;
1312            c_new_data[c_idx] = c_val;
1313            h_new_data[c_idx] = o_val * c_val.tanh();
1314        }
1315    }
1316
1317    let h_new = Tensor::from_vec(h_new_data, &[batch, hidden_size]).map_err(map_err)?;
1318    let c_new = Tensor::from_vec(c_new_data, &[batch, hidden_size]).map_err(map_err)?;
1319
1320    Ok((h_new, c_new))
1321}
1322
1323// ---------------------------------------------------------------------------
1324// Fused GRU cell forward pass (allocation-optimized)
1325// ---------------------------------------------------------------------------
1326
1327/// Fused GRU cell: minimizes intermediate tensor allocations.
1328///
1329/// Produces bit-identical results to [`gru_cell`] but reduces tensor
1330/// allocations from ~12 to 3 (2 matmuls via `linear`, 1 output `from_vec`).
1331/// After the two `linear` calls the gate combination, activations, and
1332/// hidden-state update are computed element-wise in a single scalar loop.
1333///
1334/// Shape requirements are identical to [`gru_cell`].
1335pub fn gru_cell_fused(
1336    x: &Tensor,
1337    h_prev: &Tensor,
1338    w_ih: &Tensor,
1339    w_hh: &Tensor,
1340    b_ih: &Tensor,
1341    b_hh: &Tensor,
1342) -> Result<Tensor, String> {
1343    let map_err = |e: crate::error::RuntimeError| format!("{e}");
1344
1345    // --- Validate shapes (same checks as gru_cell) ---------------------------
1346    if x.ndim() != 2 {
1347        return Err("gru_cell_fused: x must be 2-D [batch, input_size]".into());
1348    }
1349    if h_prev.ndim() != 2 {
1350        return Err("gru_cell_fused: h_prev must be 2-D [batch, hidden_size]".into());
1351    }
1352    let batch = x.shape()[0];
1353    let hidden_size = h_prev.shape()[1];
1354    if w_ih.ndim() != 2 || w_ih.shape()[0] != 3 * hidden_size {
1355        return Err(format!(
1356            "gru_cell_fused: w_ih must be [3*hidden_size, input_size], got {:?}",
1357            w_ih.shape()
1358        ));
1359    }
1360    if w_hh.ndim() != 2 || w_hh.shape()[0] != 3 * hidden_size {
1361        return Err(format!(
1362            "gru_cell_fused: w_hh must be [3*hidden_size, hidden_size], got {:?}",
1363            w_hh.shape()
1364        ));
1365    }
1366    if b_ih.len() != 3 * hidden_size {
1367        return Err(format!(
1368            "gru_cell_fused: b_ih must have length 3*hidden_size={}, got {}",
1369            3 * hidden_size,
1370            b_ih.len()
1371        ));
1372    }
1373    if b_hh.len() != 3 * hidden_size {
1374        return Err(format!(
1375            "gru_cell_fused: b_hh must have length 3*hidden_size={}, got {}",
1376            3 * hidden_size,
1377            b_hh.len()
1378        ));
1379    }
1380
1381    // --- Step 1: two matmuls (reuse existing Kahan-summation linear) ---------
1382    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
1383    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
1384
1385    // --- Step 2: fused gate combination + activations + state update ---------
1386    let gih = gates_ih.to_vec();
1387    let ghh = gates_hh.to_vec();
1388    let hp = h_prev.to_vec();
1389
1390    let mut h_new_data = vec![0.0f64; batch * hidden_size];
1391
1392    for b_idx in 0..batch {
1393        let base = b_idx * 3 * hidden_size;
1394        for h in 0..hidden_size {
1395            // r = sigmoid(ih_r + hh_r)
1396            let r_val =
1397                1.0 / (1.0 + (-(gih[base + h] + ghh[base + h])).exp());
1398            // z = sigmoid(ih_z + hh_z)
1399            let z_val = 1.0
1400                / (1.0
1401                    + (-(gih[base + hidden_size + h]
1402                        + ghh[base + hidden_size + h]))
1403                        .exp());
1404            // n = tanh(ih_n + r * hh_n)
1405            let n_val = (gih[base + 2 * hidden_size + h]
1406                + r_val * ghh[base + 2 * hidden_size + h])
1407                .tanh();
1408
1409            // h_new = (1 - z) * n + z * h_prev
1410            let h_idx = b_idx * hidden_size + h;
1411            h_new_data[h_idx] = (1.0 - z_val) * n_val + z_val * hp[h_idx];
1412        }
1413    }
1414
1415    Tensor::from_vec(h_new_data, &[batch, hidden_size])
1416        .map_err(|e| format!("{e}"))
1417}
1418
1419// ---------------------------------------------------------------------------
1420// Multi-Head Attention
1421// ---------------------------------------------------------------------------
1422
1423/// Multi-head attention: Q, K, V projections + scaled dot-product attention + output projection.
1424///
1425/// * `q`, `k`, `v`: `[batch, seq, model_dim]`
1426/// * `w_q`, `w_k`, `w_v`, `w_o`: `[model_dim, model_dim]`
1427/// * `b_q`, `b_k`, `b_v`, `b_o`: `[model_dim]`
1428/// * `num_heads`: number of attention heads (`model_dim` must be divisible by `num_heads`)
1429///
1430/// Returns `[batch, seq, model_dim]`.
1431///
1432/// All reductions use Kahan summation via the existing `Tensor::linear` and
1433/// `Tensor::scaled_dot_product_attention` methods.
1434pub fn multi_head_attention(
1435    q: &Tensor,
1436    k: &Tensor,
1437    v: &Tensor,
1438    w_q: &Tensor,
1439    w_k: &Tensor,
1440    w_v: &Tensor,
1441    w_o: &Tensor,
1442    b_q: &Tensor,
1443    b_k: &Tensor,
1444    b_v: &Tensor,
1445    b_o: &Tensor,
1446    num_heads: usize,
1447) -> Result<Tensor, String> {
1448    let map_err = |e: crate::error::RuntimeError| format!("{e}");
1449
1450    if q.ndim() != 3 {
1451        return Err("multi_head_attention: q must be 3-D [batch, seq, model_dim]".into());
1452    }
1453
1454    // Linear projections: Q = q.linear(w_q, b_q), etc.
1455    let q_proj = q.linear(w_q, b_q).map_err(map_err)?;
1456    let k_proj = k.linear(w_k, b_k).map_err(map_err)?;
1457    let v_proj = v.linear(w_v, b_v).map_err(map_err)?;
1458
1459    // Split heads: [batch, seq, model_dim] -> [batch, num_heads, seq, head_dim]
1460    let q_heads = q_proj.split_heads(num_heads).map_err(map_err)?;
1461    let k_heads = k_proj.split_heads(num_heads).map_err(map_err)?;
1462    let v_heads = v_proj.split_heads(num_heads).map_err(map_err)?;
1463
1464    // Scaled dot-product attention: works on [..., seq, head_dim]
1465    let attn = Tensor::scaled_dot_product_attention(&q_heads, &k_heads, &v_heads)
1466        .map_err(map_err)?;
1467
1468    // Merge heads back: [batch, num_heads, seq, head_dim] -> [batch, seq, model_dim]
1469    let merged = attn.merge_heads().map_err(map_err)?;
1470
1471    // Output projection
1472    let output = merged.linear(w_o, b_o).map_err(map_err)?;
1473
1474    Ok(output)
1475}
1476
1477// ---------------------------------------------------------------------------
1478// Embedding layer
1479// ---------------------------------------------------------------------------
1480
1481/// Embedding lookup: maps integer indices to dense vectors.
1482///
1483/// Performs a table lookup in the weight matrix, selecting rows
1484/// corresponding to the given indices.
1485///
1486/// # Arguments
1487///
1488/// * `weight` - Embedding matrix of shape `[vocab_size, embed_dim]`
1489/// * `indices` - 1-D tensor of integer indices
1490///
1491/// # Returns
1492///
1493/// Tensor of shape `[len(indices), embed_dim]`
1494///
1495/// # Errors
1496///
1497/// Returns an error if any index is out of bounds for the vocabulary size.
1498pub fn embedding(weight: &crate::tensor::Tensor, indices: &[i64]) -> Result<crate::tensor::Tensor, String> {
1499    let shape = weight.shape();
1500    if shape.len() != 2 {
1501        return Err(format!("embedding: weight must be 2-D [vocab_size, embed_dim], got {:?}", shape));
1502    }
1503    let vocab_size = shape[0];
1504    let embed_dim = shape[1];
1505    let weight_data = weight.to_vec();
1506
1507    let mut out = Vec::with_capacity(indices.len() * embed_dim);
1508    for &idx in indices {
1509        let i = idx as usize;
1510        if i >= vocab_size {
1511            return Err(format!("embedding: index {} out of bounds for vocab_size {}", idx, vocab_size));
1512        }
1513        let start = i * embed_dim;
1514        out.extend_from_slice(&weight_data[start..start + embed_dim]);
1515    }
1516    crate::tensor::Tensor::from_vec(out, &[indices.len(), embed_dim])
1517        .map_err(|e| e.to_string())
1518}
1519
1520// ---------------------------------------------------------------------------
1521// Deterministic mini-batch indices
1522// ---------------------------------------------------------------------------
1523
1524/// Creates deterministic batch index ranges for mini-batch training.
1525///
1526/// Generates `(start, end)` pairs that cover the entire dataset, shuffled
1527/// deterministically using the provided seed via [`SplitMix64`].
1528///
1529/// # Arguments
1530///
1531/// * `dataset_size` - Total number of samples
1532/// * `batch_size` - Number of samples per batch (last batch may be smaller)
1533/// * `seed` - RNG seed for deterministic shuffling
1534///
1535/// # Returns
1536///
1537/// A vector of `(start, end)` index pairs covering all samples.
1538pub fn batch_indices(dataset_size: usize, batch_size: usize, seed: u64) -> Vec<(usize, usize)> {
1539    use cjc_repro::Rng;
1540    let mut rng = Rng::seeded(seed);
1541    // Fisher-Yates shuffle with deterministic RNG
1542    let mut indices: Vec<usize> = (0..dataset_size).collect();
1543    for i in (1..dataset_size).rev() {
1544        let j = (rng.next_u64() as usize) % (i + 1);
1545        indices.swap(i, j);
1546    }
1547    let mut batches = Vec::new();
1548    let mut i = 0;
1549    while i < dataset_size {
1550        let end = (i + batch_size).min(dataset_size);
1551        batches.push((i, end));
1552        i = end;
1553    }
1554    batches
1555}
1556
1557// ---------------------------------------------------------------------------
1558// Tests
1559// ---------------------------------------------------------------------------
1560
1561#[cfg(test)]
1562mod tests {
1563    use super::*;
1564
    // Identical prediction and target → exactly zero loss.
    #[test]
    fn test_mse_zero() {
        let pred = [1.0, 2.0, 3.0];
        let target = [1.0, 2.0, 3.0];
        assert_eq!(mse_loss(&pred, &target).unwrap(), 0.0);
    }

    // Constant error of 1.0 per element → MSE of exactly 1.0.
    #[test]
    fn test_mse_basic() {
        let pred = [1.0, 2.0, 3.0];
        let target = [2.0, 3.0, 4.0];
        assert_eq!(mse_loss(&pred, &target).unwrap(), 1.0);
    }

    // A residual below delta exercises the quadratic branch of Huber loss.
    #[test]
    fn test_huber_loss_quadratic() {
        let pred = [1.0];
        let target = [1.5];
        let h = huber_loss(&pred, &target, 1.0).unwrap();
        // |0.5| < 1.0, so quadratic: 0.5 * 0.25 = 0.125
        assert!((h - 0.125).abs() < 1e-12);
    }

    // One SGD update: expected values match plain p -= lr * g with lr = 0.1.
    #[test]
    fn test_sgd_step() {
        let mut params = [1.0, 2.0];
        let grads = [0.1, 0.2];
        let mut state = SgdState::new(2, 0.1, 0.0);
        sgd_step(&mut params, &grads, &mut state);
        assert!((params[0] - 0.99).abs() < 1e-12);
        assert!((params[1] - 1.98).abs() < 1e-12);
    }

    // Adam with positive gradients must move both params downward; the exact
    // magnitude is not pinned, only the direction.
    #[test]
    fn test_adam_step() {
        let mut params = [1.0, 2.0];
        let grads = [0.1, 0.2];
        let mut state = AdamState::new(2, 0.001);
        adam_step(&mut params, &grads, &mut state);
        // After one step, params should be slightly different
        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
    }
1608
    // Hand-counted confusion: 2 TP, 1 FP, 1 FN, 1 TN over five samples.
    #[test]
    fn test_confusion_matrix() {
        let pred = [true, true, false, false, true];
        let actual = [true, false, true, false, true];
        let cm = confusion_matrix(&pred, &actual);
        assert_eq!(cm.tp, 2);
        assert_eq!(cm.fp, 1);
        assert_eq!(cm.fn_count, 1);
        assert_eq!(cm.tn, 1);
    }

    // precision = tp/(tp+fp) = 5/7; recall = tp/(tp+fn) = 5/6.
    #[test]
    fn test_precision_recall_f1() {
        let cm = ConfusionMatrix { tp: 5, fp: 2, tn: 8, fn_count: 1 };
        assert!((precision(&cm) - 5.0 / 7.0).abs() < 1e-12);
        assert!((recall(&cm) - 5.0 / 6.0).abs() < 1e-12);
    }

    // All positives scored above all negatives → perfect ranking, AUC = 1.
    #[test]
    fn test_auc_perfect() {
        let scores = [0.9, 0.8, 0.2, 0.1];
        let labels = [true, true, false, false];
        let auc = auc_roc(&scores, &labels).unwrap();
        assert!((auc - 1.0).abs() < 1e-12);
    }

    // Same seed must yield identical folds (module determinism contract).
    #[test]
    fn test_kfold_deterministic() {
        let f1 = kfold_indices(100, 5, 42);
        let f2 = kfold_indices(100, 5, 42);
        for i in 0..5 {
            assert_eq!(f1[i].0, f2[i].0);
            assert_eq!(f1[i].1, f2[i].1);
        }
    }

    // The split must partition the dataset exactly; 20% of 100 → 20 test samples.
    #[test]
    fn test_train_test_split_coverage() {
        let (train, test) = train_test_split(100, 0.2, 42);
        assert_eq!(train.len() + test.len(), 100);
        assert_eq!(test.len(), 20);
    }
1651
    // --- B4: ML Training Extensions tests ---

    // With mean=0, var=1, gamma=1, beta=0, eps=0 the normalization reduces
    // to the identity transform.
    #[test]
    fn test_batch_norm_identity() {
        // mean=0, var=1, gamma=1, beta=0, eps=0 → input unchanged
        let x = vec![1.0, 2.0, 3.0];
        let mean = vec![0.0, 0.0, 0.0];
        let var = vec![1.0, 1.0, 1.0];
        let gamma = vec![1.0, 1.0, 1.0];
        let beta = vec![0.0, 0.0, 0.0];
        let result = batch_norm(&x, &mean, &var, &gamma, &beta, 0.0).unwrap();
        assert!((result[0] - 1.0).abs() < 1e-12);
        assert!((result[1] - 2.0).abs() < 1e-12);
        assert!((result[2] - 3.0).abs() < 1e-12);
    }

    // Each stage of gamma * (x - mean) / sqrt(var) + beta traced by hand.
    #[test]
    fn test_batch_norm_shift_scale() {
        let x = vec![0.0];
        let mean = vec![1.0]; // shift: x - 1 = -1
        let var = vec![4.0];  // scale: -1/sqrt(4) = -0.5
        let gamma = vec![2.0]; // multiply: 2 * -0.5 = -1
        let beta = vec![3.0]; // add: -1 + 3 = 2
        let result = batch_norm(&x, &mean, &var, &gamma, &beta, 0.0).unwrap();
        assert!((result[0] - 2.0).abs() < 1e-12);
    }

    // Same seed → identical dropout mask (determinism contract).
    #[test]
    fn test_dropout_mask_seed_determinism() {
        let m1 = dropout_mask(100, 0.5, 42);
        let m2 = dropout_mask(100, 0.5, 42);
        assert_eq!(m1, m2);
    }

    // Different seeds should produce different 100-element masks.
    #[test]
    fn test_dropout_mask_different_seeds() {
        let m1 = dropout_mask(100, 0.5, 42);
        let m2 = dropout_mask(100, 0.5, 99);
        assert_ne!(m1, m2);
    }

    // Step decay with factor 0.5 every 10 epochs: 0.1 → 0.05 → 0.025.
    #[test]
    fn test_lr_step_decay_schedule() {
        let lr0 = lr_step_decay(0.1, 0.5, 0, 10);
        assert!((lr0 - 0.1).abs() < 1e-12);
        let lr10 = lr_step_decay(0.1, 0.5, 10, 10);
        assert!((lr10 - 0.05).abs() < 1e-12);
        let lr20 = lr_step_decay(0.1, 0.5, 20, 10);
        assert!((lr20 - 0.025).abs() < 1e-12);
    }

    // Cosine schedule hits the max rate at epoch 0 and the min at the last epoch.
    #[test]
    fn test_lr_cosine_endpoints() {
        let lr0 = lr_cosine(0.1, 0.001, 0, 100);
        assert!((lr0 - 0.1).abs() < 1e-10);
        let lr_end = lr_cosine(0.1, 0.001, 100, 100);
        assert!((lr_end - 0.001).abs() < 1e-10);
    }

    // Warmup ramps linearly from 0 to the target rate, then holds it.
    #[test]
    fn test_lr_linear_warmup() {
        let lr0 = lr_linear_warmup(0.1, 0, 10);
        assert!((lr0).abs() < 1e-12);
        let lr5 = lr_linear_warmup(0.1, 5, 10);
        assert!((lr5 - 0.05).abs() < 1e-12);
        let lr15 = lr_linear_warmup(0.1, 15, 10);
        assert!((lr15 - 0.1).abs() < 1e-12);
    }

    // L1 penalty: 0.1 * (|1| + |-2| + |3|) = 0.6.
    #[test]
    fn test_l1_penalty_known() {
        let params = [1.0, -2.0, 3.0];
        let p = l1_penalty(&params, 0.1);
        assert!((p - 0.6).abs() < 1e-12);
    }

    // L2 penalty including the conventional 1/2 factor.
    #[test]
    fn test_l2_penalty_known() {
        let params = [1.0, -2.0, 3.0];
        let p = l2_penalty(&params, 0.1);
        // 0.5 * 0.1 * (1 + 4 + 9) = 0.5 * 0.1 * 14 = 0.7
        assert!((p - 0.7).abs() < 1e-12);
    }

    // With patience 3, the fourth consecutive non-improving check triggers.
    #[test]
    fn test_early_stopping_triggers() {
        let mut es = EarlyStoppingState::new(3, 0.01);
        assert!(!es.check(1.0)); // best_loss=1.0, wait=0
        assert!(!es.check(1.0)); // no improvement, wait=1
        assert!(!es.check(1.0)); // no improvement, wait=2
        assert!(es.check(1.0));  // no improvement, wait=3 >= patience
    }

    // A sufficient improvement must reset the patience counter.
    #[test]
    fn test_early_stopping_resets() {
        let mut es = EarlyStoppingState::new(3, 0.01);
        es.check(1.0);
        es.check(1.0); // wait=1
        assert!(!es.check(0.5)); // improvement, wait=0
        assert!(!es.check(0.5)); // wait=1
    }
1753
1754    // --- Phase 3C: PCA tests ---
1755
1756    #[test]
1757    fn test_pca_basic_2d() {
1758        // 4 samples, 2 features — data lies mostly along first axis
1759        let data = Tensor::from_vec(
1760            vec![
1761                1.0, 0.1,
1762                2.0, 0.2,
1763                3.0, 0.3,
1764                4.0, 0.4,
1765            ],
1766            &[4, 2],
1767        )
1768        .unwrap();
1769        let (transformed, components, evr) = pca(&data, 2).unwrap();
1770        assert_eq!(transformed.shape(), &[4, 2]);
1771        assert_eq!(components.shape(), &[2, 2]);
1772        assert_eq!(evr.len(), 2);
1773        // Explained variance ratios should sum to ~1.0
1774        let total: f64 = evr.iter().sum();
1775        assert!(
1776            (total - 1.0).abs() < 1e-8,
1777            "explained variance ratios sum to {} (expected ~1.0)",
1778            total
1779        );
1780        // First component should explain most variance
1781        assert!(evr[0] > 0.9, "first component explains {} of variance", evr[0]);
1782    }
1783
1784    #[test]
1785    fn test_pca_single_component() {
1786        let data = Tensor::from_vec(
1787            vec![
1788                1.0, 2.0, 3.0,
1789                4.0, 5.0, 6.0,
1790                7.0, 8.0, 9.0,
1791            ],
1792            &[3, 3],
1793        )
1794        .unwrap();
1795        let (transformed, components, evr) = pca(&data, 1).unwrap();
1796        assert_eq!(transformed.shape(), &[3, 1]);
1797        assert_eq!(components.shape(), &[1, 3]);
1798        assert_eq!(evr.len(), 1);
1799        assert!(evr[0] > 0.0 && evr[0] <= 1.0);
1800    }
1801
1802    #[test]
1803    fn test_pca_explained_variance_ratio_bounded() {
1804        let data = Tensor::from_vec(
1805            vec![
1806                1.0, 0.0, 0.5,
1807                0.0, 1.0, 0.5,
1808                1.0, 1.0, 1.0,
1809                2.0, 0.0, 1.0,
1810                0.0, 2.0, 1.0,
1811            ],
1812            &[5, 3],
1813        )
1814        .unwrap();
1815        let (_, _, evr) = pca(&data, 3).unwrap();
1816        let total: f64 = evr.iter().sum();
1817        assert!(
1818            total <= 1.0 + 1e-10,
1819            "explained variance ratios sum to {} (should be <= 1.0)",
1820            total
1821        );
1822        for &r in &evr {
1823            assert!(r >= -1e-10, "negative explained variance ratio: {}", r);
1824        }
1825    }
1826
1827    #[test]
1828    fn test_pca_deterministic() {
1829        let data = Tensor::from_vec(
1830            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
1831            &[2, 3],
1832        )
1833        .unwrap();
1834        let (t1, c1, e1) = pca(&data, 2).unwrap();
1835        let (t2, c2, e2) = pca(&data, 2).unwrap();
1836        assert_eq!(t1.to_vec(), t2.to_vec(), "PCA transformed not deterministic");
1837        assert_eq!(c1.to_vec(), c2.to_vec(), "PCA components not deterministic");
1838        assert_eq!(e1, e2, "PCA explained variance not deterministic");
1839    }
1840
1841    #[test]
1842    fn test_pca_invalid_n_components() {
1843        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
1844        assert!(pca(&data, 0).is_err(), "n_components=0 should fail");
1845        assert!(pca(&data, 3).is_err(), "n_components > min(n,p) should fail");
1846    }
1847
1848    // --- Sprint 2: L-BFGS tests ---
1849
1850    /// Rosenbrock function: f(x,y) = (1-x)^2 + 100*(y-x^2)^2
1851    /// Gradient: df/dx = -2*(1-x) - 400*x*(y-x^2)
1852    ///           df/dy = 200*(y-x^2)
1853    /// Global minimum at (1, 1) with f=0.
1854    fn rosenbrock(p: &[f64]) -> (f64, Vec<f64>) {
1855        let x = p[0];
1856        let y = p[1];
1857        let a = 1.0 - x;
1858        let b = y - x * x;
1859        let val = a * a + 100.0 * b * b;
1860        let gx = -2.0 * a - 400.0 * x * b;
1861        let gy = 200.0 * b;
1862        (val, vec![gx, gy])
1863    }
1864
1865    #[test]
1866    fn test_lbfgs_rosenbrock_converges() {
1867        // Start far from optimum — L-BFGS should converge near (1, 1)
1868        let mut params = vec![-1.0_f64, 2.0_f64];
1869        let mut state = LbfgsState::new(0.5, 10);
1870
1871        let mut converged = false;
1872        for _iter in 0..200 {
1873            let (_, grads) = rosenbrock(&params);
1874            let grad_norm: f64 = kahan_dot(&grads, &grads).sqrt();
1875            if grad_norm < 1e-5 {
1876                converged = true;
1877                break;
1878            }
1879            let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, rosenbrock);
1880            params = new_p;
1881        }
1882        assert!(converged, "L-BFGS did not converge on Rosenbrock; params = {:?}", params);
1883        assert!(
1884            (params[0] - 1.0).abs() < 1e-3,
1885            "x should converge near 1.0, got {}",
1886            params[0]
1887        );
1888        assert!(
1889            (params[1] - 1.0).abs() < 1e-3,
1890            "y should converge near 1.0, got {}",
1891            params[1]
1892        );
1893    }
1894
1895    #[test]
1896    fn test_lbfgs_determinism() {
1897        // Same initial params + same seed → identical results
1898        let init = vec![-1.0_f64, 2.0_f64];
1899
1900        let run = |init: &[f64]| -> Vec<f64> {
1901            let mut params = init.to_vec();
1902            let mut state = LbfgsState::new(0.5, 10);
1903            for _ in 0..20 {
1904                let (_, grads) = rosenbrock(&params);
1905                let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, rosenbrock);
1906                params = new_p;
1907            }
1908            params
1909        };
1910
1911        let r1 = run(&init);
1912        let r2 = run(&init);
1913        assert_eq!(r1, r2, "L-BFGS must be bit-identical across runs");
1914    }
1915
1916    #[test]
1917    fn test_lbfgs_simple_quadratic() {
1918        // f(x) = x^2, minimum at x=0
1919        let mut params = vec![3.0_f64];
1920        let mut state = LbfgsState::new(1.0, 5);
1921        let quadratic = |p: &[f64]| -> (f64, Vec<f64>) {
1922            (p[0] * p[0], vec![2.0 * p[0]])
1923        };
1924
1925        for _ in 0..30 {
1926            let (_, grads) = quadratic(&params);
1927            let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, quadratic);
1928            params = new_p;
1929        }
1930        assert!(
1931            params[0].abs() < 1e-6,
1932            "L-BFGS should minimize x^2 to ~0, got {}",
1933            params[0]
1934        );
1935    }
1936
1937    #[test]
1938    fn test_embedding_basic() {
1939        let weight = crate::tensor::Tensor::from_vec(
1940            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
1941            &[3, 2],
1942        ).unwrap();
1943        let indices = vec![0, 2, 1];
1944        let result = super::embedding(&weight, &indices).unwrap();
1945        assert_eq!(result.shape(), &[3, 2]);
1946        let data = result.to_vec();
1947        assert!((data[0] - 0.1).abs() < 1e-12);
1948        assert!((data[1] - 0.2).abs() < 1e-12);
1949        assert!((data[2] - 0.5).abs() < 1e-12);
1950        assert!((data[3] - 0.6).abs() < 1e-12);
1951        assert!((data[4] - 0.3).abs() < 1e-12);
1952        assert!((data[5] - 0.4).abs() < 1e-12);
1953    }
1954
1955    #[test]
1956    fn test_embedding_out_of_bounds() {
1957        let weight = crate::tensor::Tensor::from_vec(vec![1.0, 2.0], &[1, 2]).unwrap();
1958        let result = super::embedding(&weight, &[1]);
1959        assert!(result.is_err());
1960    }
1961
1962    #[test]
1963    fn test_batch_indices_deterministic() {
1964        let b1 = super::batch_indices(10, 3, 42);
1965        let b2 = super::batch_indices(10, 3, 42);
1966        assert_eq!(b1, b2);
1967        // Should cover all indices
1968        let total: usize = b1.iter().map(|(s, e)| e - s).sum();
1969        assert_eq!(total, 10);
1970    }
1971
    #[test]
    fn test_wolfe_line_search_armijo() {
        // For f(x) = x^2, starting at x=3, direction d=-1 (descent)
        // Wolfe search should find a valid step
        let params = vec![3.0_f64];
        let direction = vec![-1.0_f64];
        let grads = vec![6.0_f64]; // gradient at x=3
        let f0 = 9.0; // f(3) = 9, the objective value at the starting point

        // NOTE(review): eval_count is written but never read — presumably it
        // exists so the closure is FnMut as wolfe_line_search expects; confirm
        // against the wolfe_line_search signature.
        let mut eval_count = 0;
        let mut obj = |p: &[f64]| -> (f64, Vec<f64>) {
            eval_count += 1;
            (p[0] * p[0], vec![2.0 * p[0]])
        };

        // Returns (step size alpha, updated params, new objective value, _).
        let (alpha, new_params, new_val, _) =
            wolfe_line_search(&params, &direction, &mut obj, f0, &grads, 1.0);

        // Armijo: f(x + alpha*d) <= f(x) + c1*alpha*g^T*d
        // (sufficient-decrease condition with the conventional c1 = 1e-4)
        let c1 = 1e-4;
        let derphi0 = kahan_dot(&grads, &direction); // = -6
        assert!(
            new_val <= f0 + c1 * alpha * derphi0,
            "Armijo condition violated: {} > {} + {} * {} * {}",
            new_val, f0, c1, alpha, derphi0
        );
        assert!(new_params[0] < 3.0, "Step should move toward minimum");
    }
2000}