Skip to main content

cjc_runtime/
ml.rs

1//! ML Toolkit — loss functions, optimizers, activations, metrics.
2//!
3//! # Determinism Contract
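//! - All functions are deterministic; every randomized routine (k-fold, train/test
//!   and stratified splits, bootstrap, permutation test, dropout) takes an explicit seed.
//! - Kahan summation for reductions (PCA uses `BinnedAccumulatorF64` instead).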
6//! - Stable sort for AUC-ROC with index tie-breaking.
7
8use cjc_repro::KahanAccumulatorF64;
9
10use crate::accumulator::BinnedAccumulatorF64;
11use crate::error::RuntimeError;
12use crate::tensor::Tensor;
13
14// ---------------------------------------------------------------------------
15// Loss functions
16// ---------------------------------------------------------------------------
17
18/// Mean Squared Error: sum((pred - target)^2) / n.
19pub fn mse_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
20    if pred.len() != target.len() {
21        return Err("mse_loss: arrays must have same length".into());
22    }
23    if pred.is_empty() {
24        return Err("mse_loss: empty data".into());
25    }
26    let mut acc = KahanAccumulatorF64::new();
27    for i in 0..pred.len() {
28        let d = pred[i] - target[i];
29        acc.add(d * d);
30    }
31    Ok(acc.finalize() / pred.len() as f64)
32}
33
34/// Cross-entropy loss: -sum(target * ln(pred + eps)) / n.
35pub fn cross_entropy_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
36    if pred.len() != target.len() {
37        return Err("cross_entropy_loss: arrays must have same length".into());
38    }
39    if pred.is_empty() {
40        return Err("cross_entropy_loss: empty data".into());
41    }
42    let eps = 1e-12;
43    let mut acc = KahanAccumulatorF64::new();
44    for i in 0..pred.len() {
45        acc.add(-target[i] * (pred[i] + eps).ln());
46    }
47    Ok(acc.finalize() / pred.len() as f64)
48}
49
50/// Binary cross-entropy: -sum(t*ln(p) + (1-t)*ln(1-p)) / n.
51pub fn binary_cross_entropy(pred: &[f64], target: &[f64]) -> Result<f64, String> {
52    if pred.len() != target.len() {
53        return Err("binary_cross_entropy: arrays must have same length".into());
54    }
55    if pred.is_empty() {
56        return Err("binary_cross_entropy: empty data".into());
57    }
58    let eps = 1e-12;
59    let mut acc = KahanAccumulatorF64::new();
60    for i in 0..pred.len() {
61        let p = pred[i].max(eps).min(1.0 - eps);
62        acc.add(-(target[i] * p.ln() + (1.0 - target[i]) * (1.0 - p).ln()));
63    }
64    Ok(acc.finalize() / pred.len() as f64)
65}
66
67/// Huber loss: quadratic for small errors, linear for large.
68pub fn huber_loss(pred: &[f64], target: &[f64], delta: f64) -> Result<f64, String> {
69    if pred.len() != target.len() {
70        return Err("huber_loss: arrays must have same length".into());
71    }
72    if pred.is_empty() {
73        return Err("huber_loss: empty data".into());
74    }
75    let mut acc = KahanAccumulatorF64::new();
76    for i in 0..pred.len() {
77        let d = (pred[i] - target[i]).abs();
78        if d <= delta {
79            acc.add(0.5 * d * d);
80        } else {
81            acc.add(delta * (d - 0.5 * delta));
82        }
83    }
84    Ok(acc.finalize() / pred.len() as f64)
85}
86
87/// Hinge loss: sum(max(0, 1 - target * pred)) / n.
88pub fn hinge_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
89    if pred.len() != target.len() {
90        return Err("hinge_loss: arrays must have same length".into());
91    }
92    if pred.is_empty() {
93        return Err("hinge_loss: empty data".into());
94    }
95    let mut acc = KahanAccumulatorF64::new();
96    for i in 0..pred.len() {
97        acc.add((1.0 - target[i] * pred[i]).max(0.0));
98    }
99    Ok(acc.finalize() / pred.len() as f64)
100}
101
102// ---------------------------------------------------------------------------
103// Optimizers
104// ---------------------------------------------------------------------------
105
106use crate::idx::ParamIdx;
107
108/// SGD optimizer state.
109///
110/// Phase 2b: the `velocity: Vec<f64>` is the canonical per-parameter
111/// state buffer. The field stays `pub` for backward compatibility with
112/// existing readers (cjc-eval / cjc-mir-exec read `state.velocity.len()`),
113/// but new code should prefer the typed accessors `velocity_at`,
114/// `set_velocity_at`, `n_params` so that parameter-position vs node-
115/// index confusion is caught at compile time.
116///
117/// See `ADR-0022` (Phase 2 — typed-ID newtypes) for the policy.
118pub struct SgdState {
119    pub lr: f64,
120    pub momentum: f64,
121    pub velocity: Vec<f64>,
122}
123
124impl SgdState {
125    pub fn new(n_params: usize, lr: f64, momentum: f64) -> Self {
126        Self { lr, momentum, velocity: vec![0.0; n_params] }
127    }
128
129    /// Number of parameters tracked by this state.
130    #[inline]
131    pub fn n_params(&self) -> usize {
132        self.velocity.len()
133    }
134
135    /// Read the velocity at a typed parameter index. Out-of-range
136    /// `ParamIdx` panics (consistent with `Vec` indexing).
137    #[inline]
138    pub fn velocity_at(&self, p: ParamIdx) -> f64 {
139        self.velocity[p.index()]
140    }
141
142    /// Write the velocity at a typed parameter index.
143    #[inline]
144    pub fn set_velocity_at(&mut self, p: ParamIdx, value: f64) {
145        self.velocity[p.index()] = value;
146    }
147}
148
149/// SGD step: sequential, deterministic.
150///
151/// Loop iterates `ParamIdx(0)..ParamIdx(n_params)` so each step uses
152/// the typed accessor. This produces bit-identical output to the prior
153/// `usize`-indexed loop (verified by regression).
154pub fn sgd_step(params: &mut [f64], grads: &[f64], state: &mut SgdState) {
155    let n = params.len();
156    for i in 0..n {
157        let p = ParamIdx::from_usize(i);
158        let new_velocity = state.momentum * state.velocity_at(p) + grads[i];
159        state.set_velocity_at(p, new_velocity);
160        params[i] -= state.lr * state.velocity_at(p);
161    }
162}
163
164/// Adam optimizer state.
165///
166/// Phase 2b: `m`/`v` are first/second moment buffers. The fields stay
167/// `pub` for backward compatibility (cjc-eval / cjc-mir-exec read
168/// `state.m.len()`), but new code should prefer the typed accessors
169/// `m_at` / `v_at` / `set_m_at` / `set_v_at` / `n_params`.
170///
171/// See `ADR-0022` (Phase 2 — typed-ID newtypes) for the policy.
172pub struct AdamState {
173    pub lr: f64,
174    pub beta1: f64,
175    pub beta2: f64,
176    pub eps: f64,
177    pub t: u64,
178    pub m: Vec<f64>,
179    pub v: Vec<f64>,
180}
181
182impl AdamState {
183    pub fn new(n_params: usize, lr: f64) -> Self {
184        Self {
185            lr,
186            beta1: 0.9,
187            beta2: 0.999,
188            eps: 1e-8,
189            t: 0,
190            m: vec![0.0; n_params],
191            v: vec![0.0; n_params],
192        }
193    }
194
195    /// Number of parameters tracked by this state.
196    #[inline]
197    pub fn n_params(&self) -> usize {
198        self.m.len()
199    }
200
201    /// Read the first-moment estimate at a typed parameter index.
202    #[inline]
203    pub fn m_at(&self, p: ParamIdx) -> f64 {
204        self.m[p.index()]
205    }
206
207    /// Write the first-moment estimate at a typed parameter index.
208    #[inline]
209    pub fn set_m_at(&mut self, p: ParamIdx, value: f64) {
210        self.m[p.index()] = value;
211    }
212
213    /// Read the second-moment estimate at a typed parameter index.
214    #[inline]
215    pub fn v_at(&self, p: ParamIdx) -> f64 {
216        self.v[p.index()]
217    }
218
219    /// Write the second-moment estimate at a typed parameter index.
220    #[inline]
221    pub fn set_v_at(&mut self, p: ParamIdx, value: f64) {
222        self.v[p.index()] = value;
223    }
224}
225
226/// Adam step: sequential, deterministic.
227///
228/// Loop iterates `ParamIdx(0)..ParamIdx(n_params)` so each step uses
229/// the typed accessors. Same arithmetic as the prior `usize`-indexed
230/// loop; verified bit-identical by regression.
231pub fn adam_step(params: &mut [f64], grads: &[f64], state: &mut AdamState) {
232    state.t += 1;
233    let t = state.t as f64;
234    let n = params.len();
235    for i in 0..n {
236        let p = ParamIdx::from_usize(i);
237        let new_m = state.beta1 * state.m_at(p) + (1.0 - state.beta1) * grads[i];
238        let new_v = state.beta2 * state.v_at(p) + (1.0 - state.beta2) * grads[i] * grads[i];
239        state.set_m_at(p, new_m);
240        state.set_v_at(p, new_v);
241        let m_hat = new_m / (1.0 - state.beta1.powf(t));
242        let v_hat = new_v / (1.0 - state.beta2.powf(t));
243        params[i] -= state.lr * m_hat / (v_hat.sqrt() + state.eps);
244    }
245}
246
247// ---------------------------------------------------------------------------
248// Classification metrics (Sprint 6)
249// ---------------------------------------------------------------------------
250
/// Binary confusion matrix: counts of each (predicted, actual) label pair.
#[derive(Debug, Clone)]
pub struct ConfusionMatrix {
    /// True positives: predicted `true`, actual `true`.
    pub tp: usize,
    /// False positives: predicted `true`, actual `false`.
    pub fp: usize,
    /// True negatives: predicted `false`, actual `false`.
    pub tn: usize,
    /// False negatives: predicted `false`, actual `true`.
    /// (Named `fn_count` because `fn` is a Rust keyword.)
    pub fn_count: usize,
}
259
260/// Build confusion matrix from predicted and actual boolean labels.
261pub fn confusion_matrix(predicted: &[bool], actual: &[bool]) -> ConfusionMatrix {
262    let mut tp = 0;
263    let mut fp = 0;
264    let mut tn = 0;
265    let mut fn_count = 0;
266    for i in 0..predicted.len().min(actual.len()) {
267        match (predicted[i], actual[i]) {
268            (true, true) => tp += 1,
269            (true, false) => fp += 1,
270            (false, true) => fn_count += 1,
271            (false, false) => tn += 1,
272        }
273    }
274    ConfusionMatrix { tp, fp, tn, fn_count }
275}
276
277/// Precision: TP / (TP + FP).
278pub fn precision(cm: &ConfusionMatrix) -> f64 {
279    let denom = cm.tp + cm.fp;
280    if denom == 0 { 0.0 } else { cm.tp as f64 / denom as f64 }
281}
282
283/// Recall / sensitivity: TP / (TP + FN).
284pub fn recall(cm: &ConfusionMatrix) -> f64 {
285    let denom = cm.tp + cm.fn_count;
286    if denom == 0 { 0.0 } else { cm.tp as f64 / denom as f64 }
287}
288
289/// F1 score: 2 * (precision * recall) / (precision + recall).
290pub fn f1_score(cm: &ConfusionMatrix) -> f64 {
291    let p = precision(cm);
292    let r = recall(cm);
293    if p + r == 0.0 { 0.0 } else { 2.0 * p * r / (p + r) }
294}
295
296/// Accuracy: (TP + TN) / total.
297pub fn accuracy(cm: &ConfusionMatrix) -> f64 {
298    let total = cm.tp + cm.fp + cm.tn + cm.fn_count;
299    if total == 0 { 0.0 } else { (cm.tp + cm.tn) as f64 / total as f64 }
300}
301
/// Area under the ROC curve, computed with the trapezoidal rule.
///
/// Samples are visited in order of descending score. A ROC vertex is emitted
/// only after every sample sharing the current score has been counted, so a
/// group of tied scores contributes one diagonal segment and the result does
/// not depend on the order in which tied samples appear in the input. For
/// inputs without tied scores the result equals the one-sample-at-a-time
/// computation.
///
/// DETERMINISM: stable sort on `total_cmp` with index tie-breaking.
///
/// # Errors
/// Returns an error when the slices differ in length, are empty, or when
/// `labels` does not contain both `true` and `false`.
pub fn auc_roc(scores: &[f64], labels: &[bool]) -> Result<f64, String> {
    if scores.len() != labels.len() {
        return Err("auc_roc: scores and labels must have same length".into());
    }
    let n = scores.len();
    if n == 0 {
        return Err("auc_roc: empty data".into());
    }
    let pos_count = labels.iter().filter(|&&l| l).count();
    let neg_count = n - pos_count;
    if pos_count == 0 || neg_count == 0 {
        return Err("auc_roc: need both positive and negative labels".into());
    }

    // Sample indices ordered by score, highest first.
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&a, &b| scores[b].total_cmp(&scores[a]).then(a.cmp(&b)));

    let mut auc = 0.0;
    let mut tp = 0.0;
    let mut fp = 0.0;
    let mut prev_fpr = 0.0;
    let mut prev_tpr = 0.0;

    for (rank, &i) in order.iter().enumerate() {
        if labels[i] {
            tp += 1.0;
        } else {
            fp += 1.0;
        }
        // Numerically equal neighbours belong to the same threshold; wait for
        // the last member of the group before adding a trapezoid.
        let tie_continues = order
            .get(rank + 1)
            .map_or(false, |&next| scores[next] == scores[i]);
        if tie_continues {
            continue;
        }
        let tpr = tp / pos_count as f64;
        let fpr = fp / neg_count as f64;
        auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0;
        prev_fpr = fpr;
        prev_tpr = tpr;
    }
    Ok(auc)
}
341
342/// K-fold cross-validation indices.
343/// DETERMINISM: uses seeded RNG (Fisher-Yates).
344pub fn kfold_indices(n: usize, k: usize, seed: u64) -> Vec<(Vec<usize>, Vec<usize>)> {
345    let mut rng = cjc_repro::Rng::seeded(seed);
346    // Fisher-Yates shuffle of [0..n]
347    let mut indices: Vec<usize> = (0..n).collect();
348    for i in (1..n).rev() {
349        let j = (rng.next_u64() as usize) % (i + 1);
350        indices.swap(i, j);
351    }
352    let fold_size = n / k;
353    let mut folds = Vec::with_capacity(k);
354    for fold in 0..k {
355        let start = fold * fold_size;
356        let end = if fold == k - 1 { n } else { start + fold_size };
357        let test: Vec<usize> = indices[start..end].to_vec();
358        let train: Vec<usize> = indices[..start].iter()
359            .chain(indices[end..].iter())
360            .copied()
361            .collect();
362        folds.push((train, test));
363    }
364    folds
365}
366
367/// Train/test split indices.
368pub fn train_test_split(n: usize, test_fraction: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
369    let mut rng = cjc_repro::Rng::seeded(seed);
370    let mut indices: Vec<usize> = (0..n).collect();
371    for i in (1..n).rev() {
372        let j = (rng.next_u64() as usize) % (i + 1);
373        indices.swap(i, j);
374    }
375    let test_size = ((n as f64) * test_fraction).round() as usize;
376    let test = indices[..test_size].to_vec();
377    let train = indices[test_size..].to_vec();
378    (train, test)
379}
380
381// ---------------------------------------------------------------------------
382// Bootstrap Resampling
383// ---------------------------------------------------------------------------
384
385/// Bootstrap confidence interval for a statistic (e.g., mean).
386/// Returns (point_estimate, ci_lower, ci_upper, standard_error).
387/// `stat_fn` is 0=mean, 1=median.
388pub fn bootstrap(data: &[f64], n_resamples: usize, stat_fn: usize, seed: u64) -> Result<(f64, f64, f64, f64), String> {
389    if data.is_empty() { return Err("bootstrap: empty data".into()); }
390    let n = data.len();
391
392    // Compute the statistic on original data
393    let point = compute_stat(data, stat_fn)?;
394
395    // Bootstrap resampling
396    let mut rng = cjc_repro::Rng::seeded(seed);
397    let mut stats = Vec::with_capacity(n_resamples);
398    let mut resample = Vec::with_capacity(n);
399
400    for _ in 0..n_resamples {
401        resample.clear();
402        for _ in 0..n {
403            let idx = (rng.next_u64() as usize) % n;
404            resample.push(data[idx]);
405        }
406        stats.push(compute_stat(&resample, stat_fn)?);
407    }
408
409    // Sort for percentile CI
410    stats.sort_by(|a, b| a.total_cmp(b));
411
412    let ci_lower = stats[(n_resamples as f64 * 0.025) as usize];
413    let ci_upper = stats[(n_resamples as f64 * 0.975).min((n_resamples - 1) as f64) as usize];
414
415    // Standard error
416    let mean_stats: f64 = {
417        let mut acc = cjc_repro::KahanAccumulatorF64::new();
418        for &s in &stats { acc.add(s); }
419        acc.finalize() / n_resamples as f64
420    };
421    let se = {
422        let mut acc = cjc_repro::KahanAccumulatorF64::new();
423        for &s in &stats { let d = s - mean_stats; acc.add(d * d); }
424        (acc.finalize() / (n_resamples as f64 - 1.0)).sqrt()
425    };
426
427    Ok((point, ci_lower, ci_upper, se))
428}
429
430fn compute_stat(data: &[f64], stat_fn: usize) -> Result<f64, String> {
431    match stat_fn {
432        0 => {
433            // Mean
434            let mut acc = cjc_repro::KahanAccumulatorF64::new();
435            for &x in data { acc.add(x); }
436            Ok(acc.finalize() / data.len() as f64)
437        }
438        1 => {
439            // Median
440            let mut sorted = data.to_vec();
441            sorted.sort_by(|a, b| a.total_cmp(b));
442            let n = sorted.len();
443            if n % 2 == 0 {
444                Ok((sorted[n/2 - 1] + sorted[n/2]) / 2.0)
445            } else {
446                Ok(sorted[n/2])
447            }
448        }
449        _ => Err(format!("bootstrap: unknown stat_fn {}", stat_fn)),
450    }
451}
452
453/// Permutation test: test whether two groups differ on a statistic.
454/// Returns (observed_diff, p_value).
455pub fn permutation_test(x: &[f64], y: &[f64], n_perms: usize, seed: u64) -> Result<(f64, f64), String> {
456    if x.is_empty() || y.is_empty() { return Err("permutation_test: empty group".into()); }
457
458    let nx = x.len();
459    let combined: Vec<f64> = x.iter().chain(y.iter()).copied().collect();
460    let n = combined.len();
461
462    // Observed difference of means
463    let mean_x = compute_stat(x, 0)?;
464    let mean_y = compute_stat(y, 0)?;
465    let observed = (mean_x - mean_y).abs();
466
467    // Permutation
468    let mut rng = cjc_repro::Rng::seeded(seed);
469    let mut count_extreme = 0usize;
470    let mut perm = combined.clone();
471
472    for _ in 0..n_perms {
473        // Fisher-Yates shuffle
474        for i in (1..n).rev() {
475            let j = (rng.next_u64() as usize) % (i + 1);
476            perm.swap(i, j);
477        }
478        let perm_mean_x = compute_stat(&perm[..nx], 0)?;
479        let perm_mean_y = compute_stat(&perm[nx..], 0)?;
480        if (perm_mean_x - perm_mean_y).abs() >= observed {
481            count_extreme += 1;
482        }
483    }
484
485    let p_value = count_extreme as f64 / n_perms as f64;
486    Ok((observed, p_value))
487}
488
489/// Stratified train/test split: maintains class proportions in both sets.
490/// `labels` is an array of integer class labels, `test_frac` is fraction for test set.
491/// Returns (train_indices, test_indices).
492pub fn stratified_split(labels: &[i64], test_frac: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
493    use std::collections::BTreeMap;
494
495    let n = labels.len();
496    // Group indices by label
497    let mut groups: BTreeMap<i64, Vec<usize>> = BTreeMap::new();
498    for (i, &label) in labels.iter().enumerate() {
499        groups.entry(label).or_default().push(i);
500    }
501
502    let mut train = Vec::with_capacity(n);
503    let mut test = Vec::with_capacity(n);
504    let mut rng = cjc_repro::Rng::seeded(seed);
505
506    for (_label, mut indices) in groups {
507        // Shuffle indices within each stratum
508        let m = indices.len();
509        for i in (1..m).rev() {
510            let j = (rng.next_u64() as usize) % (i + 1);
511            indices.swap(i, j);
512        }
513        let n_test = ((m as f64 * test_frac).round() as usize).max(if m > 1 { 1 } else { 0 });
514        let n_test = n_test.min(m);
515        test.extend_from_slice(&indices[..n_test]);
516        train.extend_from_slice(&indices[n_test..]);
517    }
518
519    // Sort both for deterministic output order
520    train.sort();
521    test.sort();
522    (train, test)
523}
524
525// ---------------------------------------------------------------------------
526// Phase B4: ML Training Extensions
527// ---------------------------------------------------------------------------
528
/// Batch normalization in inference mode, applied element-wise.
///
/// `y[i] = gamma[i] * (x[i] - running_mean[i]) / sqrt(running_var[i] + eps) + beta[i]`.
///
/// # Errors
/// Returns an error unless all five slices have the same length.
pub fn batch_norm(
    x: &[f64],
    running_mean: &[f64],
    running_var: &[f64],
    gamma: &[f64],
    beta: &[f64],
    eps: f64,
) -> Result<Vec<f64>, String> {
    let n = x.len();
    let companions = [running_mean.len(), running_var.len(), gamma.len(), beta.len()];
    if companions.iter().any(|&len| len != n) {
        return Err("batch_norm: all arrays must have same length".into());
    }
    let output = (0..n)
        .map(|i| {
            let std_dev = (running_var[i] + eps).sqrt();
            let standardized = (x[i] - running_mean[i]) / std_dev;
            gamma[i] * standardized + beta[i]
        })
        .collect();
    Ok(output)
}
550
551/// Dropout mask generation using seeded RNG for determinism.
552/// Returns mask of 0.0 and scale values (1/(1-p)) using inverted dropout.
553pub fn dropout_mask(n: usize, drop_prob: f64, seed: u64) -> Vec<f64> {
554    let mut rng = cjc_repro::Rng::seeded(seed);
555    let scale = if drop_prob < 1.0 { 1.0 / (1.0 - drop_prob) } else { 0.0 };
556    let mut mask = Vec::with_capacity(n);
557    for _ in 0..n {
558        let r = (rng.next_u64() as f64) / (u64::MAX as f64);
559        if r < drop_prob {
560            mask.push(0.0);
561        } else {
562            mask.push(scale);
563        }
564    }
565    mask
566}
567
/// Applies a dropout mask by multiplying `data` and `mask` element-wise.
///
/// # Errors
/// Returns an error when the two slices differ in length.
pub fn apply_dropout(data: &[f64], mask: &[f64]) -> Result<Vec<f64>, String> {
    if data.len() != mask.len() {
        return Err("apply_dropout: data and mask must have same length".into());
    }
    let mut masked = Vec::with_capacity(data.len());
    for (value, keep) in data.iter().zip(mask.iter()) {
        masked.push(value * keep);
    }
    Ok(masked)
}
575
/// Learning rate schedule: step decay.
///
/// `lr = initial_lr * decay_rate^(floor(epoch / step_size))`.
///
/// A `step_size` of zero disables decay and returns `initial_lr` (it used to
/// divide by zero). The exponent saturates at `i32::MAX` instead of wrapping
/// when `epoch / step_size` does not fit in an `i32`.
pub fn lr_step_decay(initial_lr: f64, decay_rate: f64, epoch: usize, step_size: usize) -> f64 {
    if step_size == 0 {
        return initial_lr;
    }
    let completed_steps = (epoch / step_size).min(i32::MAX as usize) as i32;
    initial_lr * decay_rate.powi(completed_steps)
}
581
/// Learning rate schedule: cosine annealing.
///
/// `lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * epoch / total_epochs))`.
///
/// When `total_epochs` is zero the schedule has no length; it is treated as
/// already finished and `min_lr` is returned (the formula would give NaN).
pub fn lr_cosine(max_lr: f64, min_lr: f64, epoch: usize, total_epochs: usize) -> f64 {
    if total_epochs == 0 {
        return min_lr;
    }
    let ratio = epoch as f64 / total_epochs as f64;
    min_lr + 0.5 * (max_lr - min_lr) * (1.0 + (std::f64::consts::PI * ratio).cos())
}
588
/// Learning rate schedule: linear warmup.
///
/// `lr = initial_lr * min(1.0, epoch / warmup_epochs)`; with no warmup
/// (`warmup_epochs == 0`) the full `initial_lr` is returned immediately.
pub fn lr_linear_warmup(initial_lr: f64, epoch: usize, warmup_epochs: usize) -> f64 {
    match warmup_epochs {
        0 => initial_lr,
        span => {
            let progress = epoch as f64 / span as f64;
            initial_lr * progress.min(1.0)
        }
    }
}
597
598/// L1 regularization penalty: lambda * sum(|params|).
599pub fn l1_penalty(params: &[f64], lambda: f64) -> f64 {
600    let mut acc = KahanAccumulatorF64::new();
601    for &p in params {
602        acc.add(p.abs());
603    }
604    lambda * acc.finalize()
605}
606
607/// L2 regularization penalty: 0.5 * lambda * sum(params^2).
608pub fn l2_penalty(params: &[f64], lambda: f64) -> f64 {
609    let mut acc = KahanAccumulatorF64::new();
610    for &p in params {
611        acc.add(p * p);
612    }
613    0.5 * lambda * acc.finalize()
614}
615
/// L1 regularization gradient: `lambda * sign(params[i])` per element.
///
/// The sign of an element that is neither positive nor negative (zero, NaN)
/// is taken as zero.
pub fn l1_grad(params: &[f64], lambda: f64) -> Vec<f64> {
    let mut grad = Vec::with_capacity(params.len());
    for &w in params {
        let g = if w > 0.0 {
            lambda
        } else if w < 0.0 {
            -lambda
        } else {
            0.0
        };
        grad.push(g);
    }
    grad
}
622
/// L2 regularization gradient: `lambda * params[i]` per element.
pub fn l2_grad(params: &[f64], lambda: f64) -> Vec<f64> {
    let mut grad = Vec::with_capacity(params.len());
    for &w in params {
        grad.push(lambda * w);
    }
    grad
}
627
/// Tracks validation loss across epochs to decide when to stop training.
pub struct EarlyStoppingState {
    /// Number of consecutive non-improving checks tolerated before stopping.
    pub patience: usize,
    /// A loss must undercut `best_loss` by more than this to count as an improvement.
    pub min_delta: f64,
    /// Lowest loss accepted as an improvement so far (starts at infinity).
    pub best_loss: f64,
    /// Consecutive checks without improvement.
    pub wait: usize,
    /// Latched to `true` once `wait` reaches `patience`.
    pub stopped: bool,
}

impl EarlyStoppingState {
    /// Creates a tracker with no recorded loss (`best_loss` is infinity).
    pub fn new(patience: usize, min_delta: f64) -> Self {
        EarlyStoppingState {
            patience,
            min_delta,
            best_loss: f64::INFINITY,
            wait: 0,
            stopped: false,
        }
    }

    /// Records `current_loss` and reports whether training should stop.
    ///
    /// An improvement (`current_loss < best_loss - min_delta`) updates
    /// `best_loss` and resets `wait`; anything else increments `wait`. The
    /// return value is the `stopped` flag, which is never cleared once set.
    pub fn check(&mut self, current_loss: f64) -> bool {
        let improved = current_loss < self.best_loss - self.min_delta;
        if improved {
            self.best_loss = current_loss;
            self.wait = 0;
        } else {
            self.wait += 1;
        }
        self.stopped = self.stopped || self.wait >= self.patience;
        self.stopped
    }
}
662
663// ---------------------------------------------------------------------------
664// Phase 3C: PCA (Principal Component Analysis)
665// ---------------------------------------------------------------------------
666
/// Principal Component Analysis via SVD of centered data.
///
/// `data` is a 2D Tensor of shape (n_samples, n_features).
/// `n_components` is the number of principal components to keep.
///
/// Returns (transformed_data, components, explained_variance_ratio):
/// - `transformed_data`: (n_samples, n_components) — data projected onto principal components
/// - `components`: (n_components, n_features) — principal component directions (rows)
/// - `explained_variance_ratio`: Vec<f64> of length n_components — fraction of variance per component
///
/// # Errors
/// Returns `RuntimeError::InvalidOperation` when `data` is not 2D, when either
/// dimension is zero, or when `n_components` is outside
/// `[1, min(n_samples, n_features)]`. Errors from `Tensor::from_vec` and
/// `Tensor::svd` are propagated.
///
/// **Determinism contract:** All reductions use `BinnedAccumulatorF64`.
pub fn pca(
    data: &Tensor,
    n_components: usize,
) -> Result<(Tensor, Tensor, Vec<f64>), RuntimeError> {
    if data.ndim() != 2 {
        return Err(RuntimeError::InvalidOperation(
            "PCA requires a 2D data matrix".to_string(),
        ));
    }
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];

    if n_samples == 0 || n_features == 0 {
        return Err(RuntimeError::InvalidOperation(
            "PCA: empty data matrix".to_string(),
        ));
    }
    if n_components == 0 || n_components > n_features.min(n_samples) {
        return Err(RuntimeError::InvalidOperation(format!(
            "PCA: n_components ({}) must be in [1, min(n_samples, n_features) = {}]",
            n_components,
            n_features.min(n_samples)
        )));
    }

    // Flat copy of the data, indexed below as `raw[i * n_features + j]`.
    // NOTE(review): assumes `to_vec` yields row-major order — confirm against `Tensor`.
    let raw = data.to_vec();

    // Step 1: Compute column means using BinnedAccumulatorF64
    let mut means = vec![0.0f64; n_features];
    for j in 0..n_features {
        let mut acc = BinnedAccumulatorF64::new();
        for i in 0..n_samples {
            acc.add(raw[i * n_features + j]);
        }
        means[j] = acc.finalize() / n_samples as f64;
    }

    // Step 2: Center the data (subtract each column's mean)
    let mut centered = vec![0.0f64; n_samples * n_features];
    for i in 0..n_samples {
        for j in 0..n_features {
            centered[i * n_features + j] = raw[i * n_features + j] - means[j];
        }
    }
    let centered_tensor = Tensor::from_vec(centered, &[n_samples, n_features])?;

    // Step 3: SVD of centered data
    // NOTE(review): the code below relies on `s` being ordered so that its first
    // `k` entries are the leading components — confirm `svd` returns them sorted.
    let (u, s, vt) = centered_tensor.svd()?;
    // Never keep more components than there are singular values.
    let k = n_components.min(s.len());

    // Step 4: Components = first k rows of Vt
    let vt_data = vt.to_vec();
    let vt_cols = vt.shape()[1]; // n_features
    let mut components = vec![0.0f64; k * n_features];
    for i in 0..k {
        for j in 0..n_features {
            components[i * n_features + j] = vt_data[i * vt_cols + j];
        }
    }

    // Step 5: Explained variance = s_i^2 / (n_samples - 1)
    // Total variance = sum of all s_i^2 / (n_samples - 1)
    // With a single sample the divisor falls back to 1.0 to avoid dividing by zero.
    let denom = if n_samples > 1 {
        (n_samples - 1) as f64
    } else {
        1.0
    };

    // The total runs over every singular value, not only the first k.
    let mut total_var_acc = BinnedAccumulatorF64::new();
    for &si in &s {
        total_var_acc.add(si * si / denom);
    }
    let total_var = total_var_acc.finalize();

    // When the total variance is (near) zero, report all ratios as 0.0
    // instead of dividing by it.
    let explained_variance_ratio: Vec<f64> = if total_var > 1e-15 {
        s[..k]
            .iter()
            .map(|&si| (si * si / denom) / total_var)
            .collect()
    } else {
        vec![0.0; k]
    };

    // Step 6: Transformed data = U_k @ diag(S_k)
    // (column j of U scaled by the j-th singular value)
    let u_data = u.to_vec();
    let u_cols = u.shape()[1];
    let mut transformed = vec![0.0f64; n_samples * k];
    for i in 0..n_samples {
        for j in 0..k {
            transformed[i * k + j] = u_data[i * u_cols + j] * s[j];
        }
    }

    Ok((
        Tensor::from_vec(transformed, &[n_samples, k])?,
        Tensor::from_vec(components, &[k, n_features])?,
        explained_variance_ratio,
    ))
}
777
778// ---------------------------------------------------------------------------
779// L-BFGS Optimizer (Sprint 2)
780// ---------------------------------------------------------------------------
781
/// State carried between L-BFGS steps.
///
/// L-BFGS (limited-memory Broyden-Fletcher-Goldfarb-Shanno) is a quasi-Newton
/// method that approximates the inverse Hessian from the `m` most recent
/// `(s, y)` pairs instead of storing it explicitly.
pub struct LbfgsState {
    /// Learning rate.
    pub lr: f64,
    /// History size (10 is the documented default; `new` requires an explicit value).
    pub m: usize,
    /// Parameter differences: s_k = params_{k+1} - params_k
    pub s_history: Vec<Vec<f64>>,
    /// Gradient differences: y_k = grad_{k+1} - grad_k
    pub y_history: Vec<Vec<f64>>,
    /// Parameters from the previous step; `None` until one has been recorded.
    pub prev_params: Option<Vec<f64>>,
    /// Gradient from the previous step; `None` until one has been recorded.
    pub prev_grad: Option<Vec<f64>>,
}

impl LbfgsState {
    /// Creates a state with empty histories and no previous step recorded.
    pub fn new(lr: f64, m: usize) -> Self {
        let (s_history, y_history) = (Vec::new(), Vec::new());
        LbfgsState {
            lr,
            m,
            s_history,
            y_history,
            prev_params: None,
            prev_grad: None,
        }
    }
}
811
812/// Dot product of two slices using Kahan summation for determinism.
813fn kahan_dot(a: &[f64], b: &[f64]) -> f64 {
814    debug_assert_eq!(a.len(), b.len());
815    let mut acc = KahanAccumulatorF64::new();
816    for (&ai, &bi) in a.iter().zip(b.iter()) {
817        acc.add(ai * bi);
818    }
819    acc.finalize()
820}
821
/// Strong Wolfe line search.
///
/// Finds a step length `alpha` satisfying both:
/// - Armijo (sufficient decrease): f(x + alpha*d) <= f(x) + c1*alpha*g^T*d
/// - Curvature condition: |g(x + alpha*d)^T*d| <= c2*|g(x)^T*d|
///
/// with `c1 = 1e-4` and `c2 = 0.9`.
///
/// Starts at `alpha_init` and, for at most 30 iterations, doubles the step
/// (capped at 1e8) while the trial point satisfies Armijo but its directional
/// derivative is still negative and too steep. As soon as a trial point
/// violates Armijo, fails to improve on the previous trial, or has a
/// non-negative directional derivative, the bracket is handed to `wolfe_zoom`
/// and its result is returned.
///
/// # Arguments
/// * `params` - Current parameters
/// * `direction` - Search direction (should be a descent direction)
/// * `f` - Function that returns (value, gradient) at a parameter vector
/// * `f0` - Current function value
/// * `g0` - Current gradient
/// * `alpha_init` - Initial step length (typically `lr`)
///
/// # Returns
/// `(alpha, new_params, new_val, new_grad)`. If the iteration budget runs out
/// without accepting a step or zooming, the trial point with the lowest
/// function value seen so far is returned.
pub fn wolfe_line_search<F>(
    params: &[f64],
    direction: &[f64],
    f: &mut F,
    f0: f64,
    g0: &[f64],
    alpha_init: f64,
) -> (f64, Vec<f64>, f64, Vec<f64>)
where
    F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
    let c1 = 1e-4_f64;
    let c2 = 0.9_f64;
    // Directional derivative at alpha = 0.
    let derphi0 = kahan_dot(g0, direction); // should be negative for descent

    // Helper: build the trial point params + alpha * direction (does not call `f`)
    let step = |alpha: f64| -> Vec<f64> {
        params.iter().zip(direction.iter()).map(|(&p, &d)| p + alpha * d).collect()
    };

    let max_iter = 30;
    // `alpha_lo` is the last accepted (too short) step; `phi_lo` / `dphi_lo`
    // are the function value and directional derivative there.
    let mut alpha_lo = 0.0_f64;
    let mut alpha_hi = f64::INFINITY;
    let mut phi_lo = f0;
    let mut dphi_lo = derphi0;
    let mut alpha = alpha_init;

    // Keep track of best valid alpha found (fallback)
    // NOTE(review): this evaluates `f` at `alpha_init`, and the first loop
    // iteration evaluates `f` at the same point again.
    let mut best_alpha = alpha_init;
    let mut best_params = step(alpha_init);
    let (mut best_val, mut best_grad) = f(&best_params);

    for _iter in 0..max_iter {
        let x_alpha = step(alpha);
        let (phi_alpha, grad_alpha) = f(&x_alpha);
        let dphi_alpha = kahan_dot(&grad_alpha, direction);

        // Track the best point for fallback
        if phi_alpha < best_val {
            best_alpha = alpha;
            best_params = x_alpha.clone();
            best_val = phi_alpha;
            best_grad = grad_alpha.clone();
        }

        // Armijo condition violated or function increased beyond bracket — shrink hi
        if phi_alpha > f0 + c1 * alpha * derphi0 || (phi_alpha >= phi_lo && alpha_lo > 0.0) {
            alpha_hi = alpha;
            // Zoom between alpha_lo and alpha_hi
            let (za, zp, zv, zg) = wolfe_zoom(
                params, direction, f, f0, derphi0, c1, c2,
                alpha_lo, alpha_hi, phi_lo, dphi_lo,
            );
            return (za, zp, zv, zg);
        }

        // Strong Wolfe curvature condition satisfied — done
        if dphi_alpha.abs() <= c2 * derphi0.abs() {
            return (alpha, x_alpha, phi_alpha, grad_alpha);
        }

        // Derivative is positive — zoom between alpha and alpha_lo
        // (the current point is passed as the "lo" end of the bracket here)
        if dphi_alpha >= 0.0 {
            let (za, zp, zv, zg) = wolfe_zoom(
                params, direction, f, f0, derphi0, c1, c2,
                alpha, alpha_lo, phi_alpha, dphi_alpha,
            );
            return (za, zp, zv, zg);
        }

        // Expand: step is too short
        alpha_lo = alpha;
        phi_lo = phi_alpha;
        dphi_lo = dphi_alpha;

        // Expand by factor 2 (capped to prevent explosion)
        // NOTE(review): `alpha_hi` is only assigned immediately before a
        // `return`, so it is still infinite whenever this line is reached and
        // the bisection arm appears to be unreachable.
        alpha = if alpha_hi.is_finite() {
            (alpha_lo + alpha_hi) * 0.5
        } else {
            (alpha * 2.0).min(1e8)
        };
    }

    // Return best found
    (best_alpha, best_params, best_val, best_grad)
}
926
927/// Zoom phase of Wolfe line search — bisects between [alpha_lo, alpha_hi].
928#[allow(clippy::too_many_arguments)]
929fn wolfe_zoom<F>(
930    params: &[f64],
931    direction: &[f64],
932    f: &mut F,
933    f0: f64,
934    derphi0: f64,
935    c1: f64,
936    c2: f64,
937    mut alpha_lo: f64,
938    mut alpha_hi: f64,
939    mut phi_lo: f64,
940    _dphi_lo: f64,
941) -> (f64, Vec<f64>, f64, Vec<f64>)
942where
943    F: FnMut(&[f64]) -> (f64, Vec<f64>),
944{
945    let step = |alpha: f64| -> Vec<f64> {
946        params.iter().zip(direction.iter()).map(|(&p, &d)| p + alpha * d).collect()
947    };
948
949    let max_zoom = 20;
950    let mut best_alpha = alpha_lo;
951    let mut best_x = step(alpha_lo);
952    let (mut best_val, mut best_grad) = f(&best_x);
953
954    for _i in 0..max_zoom {
955        let alpha_j = (alpha_lo + alpha_hi) * 0.5;
956        let x_j = step(alpha_j);
957        let (phi_j, grad_j) = f(&x_j);
958        let dphi_j = kahan_dot(&grad_j, direction);
959
960        if phi_j < best_val {
961            best_alpha = alpha_j;
962            best_x = x_j.clone();
963            best_val = phi_j;
964            best_grad = grad_j.clone();
965        }
966
967        if phi_j > f0 + c1 * alpha_j * derphi0 || phi_j >= phi_lo {
968            alpha_hi = alpha_j;
969        } else {
970            // Curvature condition satisfied
971            if dphi_j.abs() <= c2 * derphi0.abs() {
972                return (alpha_j, x_j, phi_j, grad_j);
973            }
974            if dphi_j * (alpha_hi - alpha_lo) >= 0.0 {
975                alpha_hi = alpha_lo;
976            }
977            alpha_lo = alpha_j;
978            phi_lo = phi_j;
979        }
980
981        // Convergence check on bracket width
982        if (alpha_hi - alpha_lo).abs() < 1e-14 {
983            break;
984        }
985    }
986
987    (best_alpha, best_x, best_val, best_grad)
988}
989
990/// L-BFGS step with strong Wolfe line search.
991///
992/// Given current parameters and gradients (and a closure to evaluate the
993/// function and gradient at any point), computes the L-BFGS search direction
994/// via the two-loop recursion, performs a Wolfe line search along that
995/// direction, then updates the (s, y) history buffers.
996///
997/// # Arguments
998/// * `params` - Current parameter vector (modified in-place on success)
999/// * `grads` - Gradient at current params
1000/// * `state` - L-BFGS state (history + meta-parameters)
1001/// * `f` - Closure: takes parameter slice, returns `(loss, gradient_vec)`
1002///
1003/// # Returns
1004/// `(new_params, new_grads, step_taken)` — step_taken is false if the search
1005/// direction was not a descent direction (resets to gradient descent).
1006pub fn lbfgs_step<F>(
1007    params: &[f64],
1008    grads: &[f64],
1009    state: &mut LbfgsState,
1010    mut f: F,
1011) -> (Vec<f64>, Vec<f64>, bool)
1012where
1013    F: FnMut(&[f64]) -> (f64, Vec<f64>),
1014{
1015    let n = params.len();
1016    debug_assert_eq!(grads.len(), n);
1017
1018    // ---- Two-loop L-BFGS recursion ----
1019    // Computes H_k * g using the compact L-BFGS inverse Hessian approximation.
1020    let hist_len = state.s_history.len();
1021    let mut q: Vec<f64> = grads.to_vec();
1022    let mut alphas = vec![0.0_f64; hist_len];
1023    let mut rhos = vec![0.0_f64; hist_len];
1024
1025    // Loop 1: backward through history
1026    for i in (0..hist_len).rev() {
1027        let sy = kahan_dot(&state.s_history[i], &state.y_history[i]);
1028        rhos[i] = if sy.abs() < 1e-300 { 0.0 } else { 1.0 / sy };
1029        let sq = kahan_dot(&state.s_history[i], &q);
1030        alphas[i] = rhos[i] * sq;
1031        // q -= alpha_i * y_i
1032        for j in 0..n {
1033            q[j] -= alphas[i] * state.y_history[i][j];
1034        }
1035    }
1036
1037    // Scale by initial Hessian approximation H_0 = (s^T y / y^T y) * I
1038    let scale = if hist_len > 0 {
1039        let last = hist_len - 1;
1040        let sy = kahan_dot(&state.s_history[last], &state.y_history[last]);
1041        let yy = kahan_dot(&state.y_history[last], &state.y_history[last]);
1042        if yy.abs() < 1e-300 { 1.0 } else { sy / yy }
1043    } else {
1044        1.0
1045    };
1046    let mut r: Vec<f64> = q.iter().map(|&qi| scale * qi).collect();
1047
1048    // Loop 2: forward through history
1049    for i in 0..hist_len {
1050        let yr = kahan_dot(&state.y_history[i], &r);
1051        let beta = rhos[i] * yr;
1052        // r += (alpha_i - beta) * s_i
1053        let diff = alphas[i] - beta;
1054        for j in 0..n {
1055            r[j] += diff * state.s_history[i][j];
1056        }
1057    }
1058
1059    // Search direction: d = -H_k * g = -r
1060    let direction: Vec<f64> = r.iter().map(|&ri| -ri).collect();
1061
1062    // Check descent condition: d^T g < 0
1063    let descent_check = kahan_dot(&direction, grads);
1064    let (direction, is_descent) = if descent_check >= 0.0 || !descent_check.is_finite() {
1065        // Fall back to scaled negative gradient
1066        let norm_g = kahan_dot(grads, grads).sqrt().max(1e-300);
1067        (grads.iter().map(|&g| -g / norm_g).collect::<Vec<f64>>(), false)
1068    } else {
1069        (direction, true)
1070    };
1071
1072    // ---- Wolfe line search ----
1073    let (f0, _) = f(params);
1074    let (_, new_params, _, new_grads) = wolfe_line_search(
1075        params,
1076        &direction,
1077        &mut f,
1078        f0,
1079        grads,
1080        state.lr,
1081    );
1082
1083    // ---- Update history ----
1084    // s_k = new_params - params
1085    let s_k: Vec<f64> = new_params.iter().zip(params.iter()).map(|(&np, &p)| np - p).collect();
1086    // y_k = new_grads - grads
1087    let y_k: Vec<f64> = new_grads.iter().zip(grads.iter()).map(|(&ng, &g)| ng - g).collect();
1088
1089    let sy = kahan_dot(&s_k, &y_k);
1090    // Only add to history if curvature condition holds (sy > 0)
1091    if sy > 1e-300 {
1092        state.s_history.push(s_k);
1093        state.y_history.push(y_k);
1094        // Trim to m most recent
1095        if state.s_history.len() > state.m {
1096            state.s_history.remove(0);
1097            state.y_history.remove(0);
1098        }
1099    }
1100
1101    state.prev_params = Some(new_params.clone());
1102    state.prev_grad = Some(new_grads.clone());
1103
1104    (new_params, new_grads, is_descent)
1105}
1106
1107// ---------------------------------------------------------------------------
1108// LSTM cell forward pass
1109// ---------------------------------------------------------------------------
1110
1111/// LSTM cell forward pass.
1112///
1113/// * `x`:      `[batch, input_size]`
1114/// * `h_prev`: `[batch, hidden_size]`
1115/// * `c_prev`: `[batch, hidden_size]`
1116/// * `w_ih`:   `[4*hidden_size, input_size]`
1117/// * `w_hh`:   `[4*hidden_size, hidden_size]`
1118/// * `b_ih`:   `[4*hidden_size]`
1119/// * `b_hh`:   `[4*hidden_size]`
1120///
1121/// Returns `(h_new, c_new)`.
1122///
1123/// Gate layout (PyTorch convention): i, f, g, o — each of size `hidden_size`.
1124/// All reductions use Kahan summation via the existing `Tensor::linear` method.
1125pub fn lstm_cell(
1126    x: &Tensor,
1127    h_prev: &Tensor,
1128    c_prev: &Tensor,
1129    w_ih: &Tensor,
1130    w_hh: &Tensor,
1131    b_ih: &Tensor,
1132    b_hh: &Tensor,
1133) -> Result<(Tensor, Tensor), String> {
1134    let map_err = |e: crate::error::RuntimeError| format!("{e}");
1135
1136    // Validate shapes
1137    if x.ndim() != 2 {
1138        return Err("lstm_cell: x must be 2-D [batch, input_size]".into());
1139    }
1140    if h_prev.ndim() != 2 {
1141        return Err("lstm_cell: h_prev must be 2-D [batch, hidden_size]".into());
1142    }
1143    if c_prev.ndim() != 2 {
1144        return Err("lstm_cell: c_prev must be 2-D [batch, hidden_size]".into());
1145    }
1146    let hidden_size = h_prev.shape()[1];
1147    if w_ih.ndim() != 2 || w_ih.shape()[0] != 4 * hidden_size {
1148        return Err(format!(
1149            "lstm_cell: w_ih must be [4*hidden_size, input_size], got {:?}",
1150            w_ih.shape()
1151        ));
1152    }
1153    if w_hh.ndim() != 2 || w_hh.shape()[0] != 4 * hidden_size {
1154        return Err(format!(
1155            "lstm_cell: w_hh must be [4*hidden_size, hidden_size], got {:?}",
1156            w_hh.shape()
1157        ));
1158    }
1159    if b_ih.len() != 4 * hidden_size {
1160        return Err(format!(
1161            "lstm_cell: b_ih must have length 4*hidden_size={}, got {}",
1162            4 * hidden_size,
1163            b_ih.len()
1164        ));
1165    }
1166    if b_hh.len() != 4 * hidden_size {
1167        return Err(format!(
1168            "lstm_cell: b_hh must have length 4*hidden_size={}, got {}",
1169            4 * hidden_size,
1170            b_hh.len()
1171        ));
1172    }
1173
1174    // gates = x @ w_ih^T + b_ih + h_prev @ w_hh^T + b_hh
1175    // Use linear which does: input @ weight^T + bias
1176    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
1177    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
1178    let gates = gates_ih.add(&gates_hh).map_err(map_err)?;
1179
1180    // gates is [batch, 4*hidden_size]. Split into 4 chunks along dim 1.
1181    let chunks = gates.chunk(4, 1).map_err(map_err)?;
1182    let gates_i = &chunks[0];
1183    let gates_f = &chunks[1];
1184    let gates_g = &chunks[2];
1185    let gates_o = &chunks[3];
1186
1187    // Apply activations
1188    let i = gates_i.sigmoid();
1189    let f = gates_f.sigmoid();
1190    let g = gates_g.tanh_activation();
1191    let o = gates_o.sigmoid();
1192
1193    // c_new = f * c_prev + i * g
1194    let fc = f.mul_elem(c_prev).map_err(map_err)?;
1195    let ig = i.mul_elem(&g).map_err(map_err)?;
1196    let c_new = fc.add(&ig).map_err(map_err)?;
1197
1198    // h_new = o * tanh(c_new)
1199    let c_tanh = c_new.tanh_activation();
1200    let h_new = o.mul_elem(&c_tanh).map_err(map_err)?;
1201
1202    Ok((h_new, c_new))
1203}
1204
1205// ---------------------------------------------------------------------------
1206// GRU cell forward pass
1207// ---------------------------------------------------------------------------
1208
1209/// GRU cell forward pass.
1210///
1211/// * `x`:      `[batch, input_size]`
1212/// * `h_prev`: `[batch, hidden_size]`
1213/// * `w_ih`:   `[3*hidden_size, input_size]`
1214/// * `w_hh`:   `[3*hidden_size, hidden_size]`
1215/// * `b_ih`:   `[3*hidden_size]`
1216/// * `b_hh`:   `[3*hidden_size]`
1217///
1218/// Returns `h_new` tensor of shape `[batch, hidden_size]`.
1219///
1220/// Gate layout: r (reset), z (update), n (new) — each of size `hidden_size`.
1221/// All reductions use Kahan summation via the existing `Tensor::linear` method.
1222pub fn gru_cell(
1223    x: &Tensor,
1224    h_prev: &Tensor,
1225    w_ih: &Tensor,
1226    w_hh: &Tensor,
1227    b_ih: &Tensor,
1228    b_hh: &Tensor,
1229) -> Result<Tensor, String> {
1230    let map_err = |e: crate::error::RuntimeError| format!("{e}");
1231
1232    // Validate shapes
1233    if x.ndim() != 2 {
1234        return Err("gru_cell: x must be 2-D [batch, input_size]".into());
1235    }
1236    if h_prev.ndim() != 2 {
1237        return Err("gru_cell: h_prev must be 2-D [batch, hidden_size]".into());
1238    }
1239    let hidden_size = h_prev.shape()[1];
1240    if w_ih.ndim() != 2 || w_ih.shape()[0] != 3 * hidden_size {
1241        return Err(format!(
1242            "gru_cell: w_ih must be [3*hidden_size, input_size], got {:?}",
1243            w_ih.shape()
1244        ));
1245    }
1246    if w_hh.ndim() != 2 || w_hh.shape()[0] != 3 * hidden_size {
1247        return Err(format!(
1248            "gru_cell: w_hh must be [3*hidden_size, hidden_size], got {:?}",
1249            w_hh.shape()
1250        ));
1251    }
1252    if b_ih.len() != 3 * hidden_size {
1253        return Err(format!(
1254            "gru_cell: b_ih must have length 3*hidden_size={}, got {}",
1255            3 * hidden_size,
1256            b_ih.len()
1257        ));
1258    }
1259    if b_hh.len() != 3 * hidden_size {
1260        return Err(format!(
1261            "gru_cell: b_hh must have length 3*hidden_size={}, got {}",
1262            3 * hidden_size,
1263            b_hh.len()
1264        ));
1265    }
1266
1267    // gates_ih = x @ w_ih^T + b_ih   → [batch, 3*hidden_size]
1268    // gates_hh = h_prev @ w_hh^T + b_hh → [batch, 3*hidden_size]
1269    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
1270    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
1271
1272    // Split into r, z, n portions
1273    let ih_chunks = gates_ih.chunk(3, 1).map_err(map_err)?;
1274    let hh_chunks = gates_hh.chunk(3, 1).map_err(map_err)?;
1275    let r_ih = &ih_chunks[0];
1276    let z_ih = &ih_chunks[1];
1277    let n_ih = &ih_chunks[2];
1278    let r_hh = &hh_chunks[0];
1279    let z_hh = &hh_chunks[1];
1280    let n_hh = &hh_chunks[2];
1281
1282    // r = sigmoid(r_ih + r_hh)
1283    let r = r_ih.add(r_hh).map_err(map_err)?.sigmoid();
1284    // z = sigmoid(z_ih + z_hh)
1285    let z = z_ih.add(z_hh).map_err(map_err)?.sigmoid();
1286    // n = tanh(n_ih + r * n_hh)
1287    let r_n_hh = r.mul_elem(n_hh).map_err(map_err)?;
1288    let n = n_ih.add(&r_n_hh).map_err(map_err)?.tanh_activation();
1289
1290    // h_new = (1 - z) * n + z * h_prev
1291    // Build (1 - z): negate z then add ones
1292    let ones = Tensor::ones(z.shape());
1293    let one_minus_z = ones.sub(&z).map_err(map_err)?;
1294    let term1 = one_minus_z.mul_elem(&n).map_err(map_err)?;
1295    let term2 = z.mul_elem(h_prev).map_err(map_err)?;
1296    let h_new = term1.add(&term2).map_err(map_err)?;
1297
1298    Ok(h_new)
1299}
1300
1301// ---------------------------------------------------------------------------
1302// Fused LSTM cell forward pass (allocation-optimized)
1303// ---------------------------------------------------------------------------
1304
1305/// Fused LSTM cell: minimizes intermediate tensor allocations.
1306///
1307/// Produces bit-identical results to [`lstm_cell`] but reduces tensor
1308/// allocations from 13 to 4 (2 matmuls via `linear`, 2 output `from_vec`).
1309/// After the two `linear` calls the gate combination, activations, and cell /
1310/// hidden-state updates are computed element-wise in a single scalar loop
1311/// with no additional tensor temporaries.
1312///
1313/// Shape requirements are identical to [`lstm_cell`].
1314pub fn lstm_cell_fused(
1315    x: &Tensor,
1316    h_prev: &Tensor,
1317    c_prev: &Tensor,
1318    w_ih: &Tensor,
1319    w_hh: &Tensor,
1320    b_ih: &Tensor,
1321    b_hh: &Tensor,
1322) -> Result<(Tensor, Tensor), String> {
1323    let map_err = |e: crate::error::RuntimeError| format!("{e}");
1324
1325    // --- Validate shapes (same checks as lstm_cell) --------------------------
1326    if x.ndim() != 2 {
1327        return Err("lstm_cell_fused: x must be 2-D [batch, input_size]".into());
1328    }
1329    if h_prev.ndim() != 2 {
1330        return Err("lstm_cell_fused: h_prev must be 2-D [batch, hidden_size]".into());
1331    }
1332    if c_prev.ndim() != 2 {
1333        return Err("lstm_cell_fused: c_prev must be 2-D [batch, hidden_size]".into());
1334    }
1335    let batch = x.shape()[0];
1336    let hidden_size = h_prev.shape()[1];
1337    if w_ih.ndim() != 2 || w_ih.shape()[0] != 4 * hidden_size {
1338        return Err(format!(
1339            "lstm_cell_fused: w_ih must be [4*hidden_size, input_size], got {:?}",
1340            w_ih.shape()
1341        ));
1342    }
1343    if w_hh.ndim() != 2 || w_hh.shape()[0] != 4 * hidden_size {
1344        return Err(format!(
1345            "lstm_cell_fused: w_hh must be [4*hidden_size, hidden_size], got {:?}",
1346            w_hh.shape()
1347        ));
1348    }
1349    if b_ih.len() != 4 * hidden_size {
1350        return Err(format!(
1351            "lstm_cell_fused: b_ih must have length 4*hidden_size={}, got {}",
1352            4 * hidden_size,
1353            b_ih.len()
1354        ));
1355    }
1356    if b_hh.len() != 4 * hidden_size {
1357        return Err(format!(
1358            "lstm_cell_fused: b_hh must have length 4*hidden_size={}, got {}",
1359            4 * hidden_size,
1360            b_hh.len()
1361        ));
1362    }
1363
1364    // --- Step 1: two matmuls (reuse existing Kahan-summation linear) ---------
1365    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
1366    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
1367
1368    // --- Step 2: fused gate combination + activations + state update ---------
1369    let gih = gates_ih.to_vec();
1370    let ghh = gates_hh.to_vec();
1371    let cprev = c_prev.to_vec();
1372
1373    let mut h_new_data = vec![0.0f64; batch * hidden_size];
1374    let mut c_new_data = vec![0.0f64; batch * hidden_size];
1375
1376    for b_idx in 0..batch {
1377        let base = b_idx * 4 * hidden_size;
1378        for h in 0..hidden_size {
1379            // Combine ih + hh for each gate (i, f, g, o)
1380            let gi = gih[base + h] + ghh[base + h];
1381            let gf = gih[base + hidden_size + h] + ghh[base + hidden_size + h];
1382            let gg = gih[base + 2 * hidden_size + h] + ghh[base + 2 * hidden_size + h];
1383            let go = gih[base + 3 * hidden_size + h] + ghh[base + 3 * hidden_size + h];
1384
1385            // Activations (scalar, no tensor allocation)
1386            let i_val = 1.0 / (1.0 + (-gi).exp()); // sigmoid
1387            let f_val = 1.0 / (1.0 + (-gf).exp()); // sigmoid
1388            let g_val = gg.tanh();                   // tanh
1389            let o_val = 1.0 / (1.0 + (-go).exp()); // sigmoid
1390
1391            // Cell and hidden state update
1392            let c_idx = b_idx * hidden_size + h;
1393            let c_val = f_val * cprev[c_idx] + i_val * g_val;
1394            c_new_data[c_idx] = c_val;
1395            h_new_data[c_idx] = o_val * c_val.tanh();
1396        }
1397    }
1398
1399    let h_new = Tensor::from_vec(h_new_data, &[batch, hidden_size]).map_err(map_err)?;
1400    let c_new = Tensor::from_vec(c_new_data, &[batch, hidden_size]).map_err(map_err)?;
1401
1402    Ok((h_new, c_new))
1403}
1404
1405// ---------------------------------------------------------------------------
1406// Fused GRU cell forward pass (allocation-optimized)
1407// ---------------------------------------------------------------------------
1408
1409/// Fused GRU cell: minimizes intermediate tensor allocations.
1410///
1411/// Produces bit-identical results to [`gru_cell`] but reduces tensor
1412/// allocations from ~12 to 3 (2 matmuls via `linear`, 1 output `from_vec`).
1413/// After the two `linear` calls the gate combination, activations, and
1414/// hidden-state update are computed element-wise in a single scalar loop.
1415///
1416/// Shape requirements are identical to [`gru_cell`].
1417pub fn gru_cell_fused(
1418    x: &Tensor,
1419    h_prev: &Tensor,
1420    w_ih: &Tensor,
1421    w_hh: &Tensor,
1422    b_ih: &Tensor,
1423    b_hh: &Tensor,
1424) -> Result<Tensor, String> {
1425    let map_err = |e: crate::error::RuntimeError| format!("{e}");
1426
1427    // --- Validate shapes (same checks as gru_cell) ---------------------------
1428    if x.ndim() != 2 {
1429        return Err("gru_cell_fused: x must be 2-D [batch, input_size]".into());
1430    }
1431    if h_prev.ndim() != 2 {
1432        return Err("gru_cell_fused: h_prev must be 2-D [batch, hidden_size]".into());
1433    }
1434    let batch = x.shape()[0];
1435    let hidden_size = h_prev.shape()[1];
1436    if w_ih.ndim() != 2 || w_ih.shape()[0] != 3 * hidden_size {
1437        return Err(format!(
1438            "gru_cell_fused: w_ih must be [3*hidden_size, input_size], got {:?}",
1439            w_ih.shape()
1440        ));
1441    }
1442    if w_hh.ndim() != 2 || w_hh.shape()[0] != 3 * hidden_size {
1443        return Err(format!(
1444            "gru_cell_fused: w_hh must be [3*hidden_size, hidden_size], got {:?}",
1445            w_hh.shape()
1446        ));
1447    }
1448    if b_ih.len() != 3 * hidden_size {
1449        return Err(format!(
1450            "gru_cell_fused: b_ih must have length 3*hidden_size={}, got {}",
1451            3 * hidden_size,
1452            b_ih.len()
1453        ));
1454    }
1455    if b_hh.len() != 3 * hidden_size {
1456        return Err(format!(
1457            "gru_cell_fused: b_hh must have length 3*hidden_size={}, got {}",
1458            3 * hidden_size,
1459            b_hh.len()
1460        ));
1461    }
1462
1463    // --- Step 1: two matmuls (reuse existing Kahan-summation linear) ---------
1464    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
1465    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
1466
1467    // --- Step 2: fused gate combination + activations + state update ---------
1468    let gih = gates_ih.to_vec();
1469    let ghh = gates_hh.to_vec();
1470    let hp = h_prev.to_vec();
1471
1472    let mut h_new_data = vec![0.0f64; batch * hidden_size];
1473
1474    for b_idx in 0..batch {
1475        let base = b_idx * 3 * hidden_size;
1476        for h in 0..hidden_size {
1477            // r = sigmoid(ih_r + hh_r)
1478            let r_val =
1479                1.0 / (1.0 + (-(gih[base + h] + ghh[base + h])).exp());
1480            // z = sigmoid(ih_z + hh_z)
1481            let z_val = 1.0
1482                / (1.0
1483                    + (-(gih[base + hidden_size + h]
1484                        + ghh[base + hidden_size + h]))
1485                        .exp());
1486            // n = tanh(ih_n + r * hh_n)
1487            let n_val = (gih[base + 2 * hidden_size + h]
1488                + r_val * ghh[base + 2 * hidden_size + h])
1489                .tanh();
1490
1491            // h_new = (1 - z) * n + z * h_prev
1492            let h_idx = b_idx * hidden_size + h;
1493            h_new_data[h_idx] = (1.0 - z_val) * n_val + z_val * hp[h_idx];
1494        }
1495    }
1496
1497    Tensor::from_vec(h_new_data, &[batch, hidden_size])
1498        .map_err(|e| format!("{e}"))
1499}
1500
1501// ---------------------------------------------------------------------------
1502// Multi-Head Attention
1503// ---------------------------------------------------------------------------
1504
1505/// Multi-head attention: Q, K, V projections + scaled dot-product attention + output projection.
1506///
1507/// * `q`, `k`, `v`: `[batch, seq, model_dim]`
1508/// * `w_q`, `w_k`, `w_v`, `w_o`: `[model_dim, model_dim]`
1509/// * `b_q`, `b_k`, `b_v`, `b_o`: `[model_dim]`
1510/// * `num_heads`: number of attention heads (`model_dim` must be divisible by `num_heads`)
1511///
1512/// Returns `[batch, seq, model_dim]`.
1513///
1514/// All reductions use Kahan summation via the existing `Tensor::linear` and
1515/// `Tensor::scaled_dot_product_attention` methods.
1516pub fn multi_head_attention(
1517    q: &Tensor,
1518    k: &Tensor,
1519    v: &Tensor,
1520    w_q: &Tensor,
1521    w_k: &Tensor,
1522    w_v: &Tensor,
1523    w_o: &Tensor,
1524    b_q: &Tensor,
1525    b_k: &Tensor,
1526    b_v: &Tensor,
1527    b_o: &Tensor,
1528    num_heads: usize,
1529) -> Result<Tensor, String> {
1530    let map_err = |e: crate::error::RuntimeError| format!("{e}");
1531
1532    if q.ndim() != 3 {
1533        return Err("multi_head_attention: q must be 3-D [batch, seq, model_dim]".into());
1534    }
1535
1536    // Linear projections: Q = q.linear(w_q, b_q), etc.
1537    let q_proj = q.linear(w_q, b_q).map_err(map_err)?;
1538    let k_proj = k.linear(w_k, b_k).map_err(map_err)?;
1539    let v_proj = v.linear(w_v, b_v).map_err(map_err)?;
1540
1541    // Split heads: [batch, seq, model_dim] -> [batch, num_heads, seq, head_dim]
1542    let q_heads = q_proj.split_heads(num_heads).map_err(map_err)?;
1543    let k_heads = k_proj.split_heads(num_heads).map_err(map_err)?;
1544    let v_heads = v_proj.split_heads(num_heads).map_err(map_err)?;
1545
1546    // Scaled dot-product attention: works on [..., seq, head_dim]
1547    let attn = Tensor::scaled_dot_product_attention(&q_heads, &k_heads, &v_heads)
1548        .map_err(map_err)?;
1549
1550    // Merge heads back: [batch, num_heads, seq, head_dim] -> [batch, seq, model_dim]
1551    let merged = attn.merge_heads().map_err(map_err)?;
1552
1553    // Output projection
1554    let output = merged.linear(w_o, b_o).map_err(map_err)?;
1555
1556    Ok(output)
1557}
1558
1559// ---------------------------------------------------------------------------
1560// Embedding layer
1561// ---------------------------------------------------------------------------
1562
1563/// Embedding lookup: maps integer indices to dense vectors.
1564///
1565/// Performs a table lookup in the weight matrix, selecting rows
1566/// corresponding to the given indices.
1567///
1568/// # Arguments
1569///
1570/// * `weight` - Embedding matrix of shape `[vocab_size, embed_dim]`
1571/// * `indices` - 1-D tensor of integer indices
1572///
1573/// # Returns
1574///
1575/// Tensor of shape `[len(indices), embed_dim]`
1576///
1577/// # Errors
1578///
1579/// Returns an error if any index is out of bounds for the vocabulary size.
1580pub fn embedding(weight: &crate::tensor::Tensor, indices: &[i64]) -> Result<crate::tensor::Tensor, String> {
1581    let shape = weight.shape();
1582    if shape.len() != 2 {
1583        return Err(format!("embedding: weight must be 2-D [vocab_size, embed_dim], got {:?}", shape));
1584    }
1585    let vocab_size = shape[0];
1586    let embed_dim = shape[1];
1587    let weight_data = weight.to_vec();
1588
1589    let mut out = Vec::with_capacity(indices.len() * embed_dim);
1590    for &idx in indices {
1591        let i = idx as usize;
1592        if i >= vocab_size {
1593            return Err(format!("embedding: index {} out of bounds for vocab_size {}", idx, vocab_size));
1594        }
1595        let start = i * embed_dim;
1596        out.extend_from_slice(&weight_data[start..start + embed_dim]);
1597    }
1598    crate::tensor::Tensor::from_vec(out, &[indices.len(), embed_dim])
1599        .map_err(|e| e.to_string())
1600}
1601
1602// ---------------------------------------------------------------------------
1603// Deterministic mini-batch indices
1604// ---------------------------------------------------------------------------
1605
/// Creates deterministic batch index ranges for mini-batch training.
///
/// Generates consecutive half-open `(start, end)` ranges that tile
/// `0..dataset_size` in ascending order.
///
/// NOTE(review): the ranges do not depend on `seed`. The seeded permutation
/// this function used to build was never consulted when forming the ranges,
/// so that dead work has been removed; the parameter is kept so existing
/// callers continue to compile. Callers that need a shuffled sample order must
/// permute the sample indices themselves.
///
/// # Arguments
///
/// * `dataset_size` - Total number of samples
/// * `batch_size` - Number of samples per batch (last batch may be smaller)
/// * `seed` - Currently unused; see the note above
///
/// # Returns
///
/// A vector of `(start, end)` index pairs covering all samples. Empty when
/// `dataset_size` is 0, and also when `batch_size` is 0 (no non-empty batch
/// can be formed, and looping would never terminate).
pub fn batch_indices(dataset_size: usize, batch_size: usize, seed: u64) -> Vec<(usize, usize)> {
    let _ = seed;

    // A zero batch size can never advance through the dataset.
    if batch_size == 0 {
        return Vec::new();
    }

    // ceil(dataset_size / batch_size) without risking overflow.
    let batch_count = dataset_size / batch_size + usize::from(dataset_size % batch_size != 0);
    let mut batches = Vec::with_capacity(batch_count);

    let mut start = 0;
    while start < dataset_size {
        let end = start.saturating_add(batch_size).min(dataset_size);
        batches.push((start, end));
        start = end;
    }
    batches
}
1638
1639// ---------------------------------------------------------------------------
1640// Tests
1641// ---------------------------------------------------------------------------
1642
1643#[cfg(test)]
1644mod tests {
1645    use super::*;
1646
1647    #[test]
1648    fn test_mse_zero() {
1649        let pred = [1.0, 2.0, 3.0];
1650        let target = [1.0, 2.0, 3.0];
1651        assert_eq!(mse_loss(&pred, &target).unwrap(), 0.0);
1652    }
1653
1654    #[test]
1655    fn test_mse_basic() {
1656        let pred = [1.0, 2.0, 3.0];
1657        let target = [2.0, 3.0, 4.0];
1658        assert_eq!(mse_loss(&pred, &target).unwrap(), 1.0);
1659    }
1660
1661    #[test]
1662    fn test_huber_loss_quadratic() {
1663        let pred = [1.0];
1664        let target = [1.5];
1665        let h = huber_loss(&pred, &target, 1.0).unwrap();
1666        // |0.5| < 1.0, so quadratic: 0.5 * 0.25 = 0.125
1667        assert!((h - 0.125).abs() < 1e-12);
1668    }
1669
1670    #[test]
1671    fn test_sgd_step() {
1672        let mut params = [1.0, 2.0];
1673        let grads = [0.1, 0.2];
1674        let mut state = SgdState::new(2, 0.1, 0.0);
1675        sgd_step(&mut params, &grads, &mut state);
1676        assert!((params[0] - 0.99).abs() < 1e-12);
1677        assert!((params[1] - 1.98).abs() < 1e-12);
1678    }
1679
1680    #[test]
1681    fn test_adam_step() {
1682        let mut params = [1.0, 2.0];
1683        let grads = [0.1, 0.2];
1684        let mut state = AdamState::new(2, 0.001);
1685        adam_step(&mut params, &grads, &mut state);
1686        // After one step, params should be slightly different
1687        assert!(params[0] < 1.0);
1688        assert!(params[1] < 2.0);
1689    }
1690
1691    // ── Phase 2b: ParamIdx-typed accessors ──────────────────────────
1692
1693    #[test]
1694    fn paramidx_accessors_agree_with_direct_field_reads_adam() {
1695        // After several Adam steps, the typed accessors must read the same
1696        // bytes as direct `state.m[i]` / `state.v[i]` access.
1697        let mut params = vec![0.5, -0.3, 1.7, 0.0];
1698        let grads = vec![0.1, -0.2, 0.05, 0.4];
1699        let mut state = AdamState::new(4, 0.01);
1700        for _ in 0..5 {
1701            adam_step(&mut params, &grads, &mut state);
1702        }
1703        for i in 0..4 {
1704            let p = crate::idx::ParamIdx::from_usize(i);
1705            assert_eq!(state.m_at(p).to_bits(), state.m[i].to_bits());
1706            assert_eq!(state.v_at(p).to_bits(), state.v[i].to_bits());
1707        }
1708        assert_eq!(state.n_params(), 4);
1709    }
1710
1711    #[test]
1712    fn paramidx_accessors_agree_with_direct_field_reads_sgd() {
1713        let mut params = vec![1.0, 2.0, 3.0];
1714        let grads = vec![0.01, -0.02, 0.03];
1715        let mut state = SgdState::new(3, 0.05, 0.9);
1716        for _ in 0..7 {
1717            sgd_step(&mut params, &grads, &mut state);
1718        }
1719        for i in 0..3 {
1720            let p = crate::idx::ParamIdx::from_usize(i);
1721            assert_eq!(state.velocity_at(p).to_bits(), state.velocity[i].to_bits());
1722        }
1723        assert_eq!(state.n_params(), 3);
1724    }
1725
1726    #[test]
1727    fn paramidx_typed_setters_round_trip() {
1728        let mut state = AdamState::new(8, 0.001);
1729        for i in 0..8 {
1730            let p = crate::idx::ParamIdx::from_usize(i);
1731            state.set_m_at(p, (i as f64) * 0.5);
1732            state.set_v_at(p, (i as f64) * 0.25);
1733        }
1734        for i in 0..8 {
1735            let p = crate::idx::ParamIdx::from_usize(i);
1736            assert_eq!(state.m_at(p), (i as f64) * 0.5);
1737            assert_eq!(state.v_at(p), (i as f64) * 0.25);
1738            // and direct field read agrees
1739            assert_eq!(state.m[i], (i as f64) * 0.5);
1740            assert_eq!(state.v[i], (i as f64) * 0.25);
1741        }
1742    }
1743
1744    #[test]
1745    fn paramidx_adam_step_byte_equal_across_independent_runs() {
1746        // Determinism: two independent runs with the same initial state
1747        // must produce byte-equal final state and params. This is the
1748        // strongest contract for the typed-accessor migration:
1749        // re-ordering the loop body or replacing direct field access
1750        // with typed accessors must never perturb the output bits.
1751        let init_params = vec![0.5, -0.5, 1.5, -1.5, 0.7];
1752        let grads = vec![0.1, -0.05, 0.02, 0.08, -0.03];
1753
1754        let run = || {
1755            let mut params = init_params.clone();
1756            let mut state = AdamState::new(5, 0.01);
1757            for _ in 0..20 {
1758                adam_step(&mut params, &grads, &mut state);
1759            }
1760            (params, state.m.clone(), state.v.clone(), state.t)
1761        };
1762
1763        let (p1, m1, v1, t1) = run();
1764        let (p2, m2, v2, t2) = run();
1765
1766        let bits = |xs: &[f64]| -> Vec<u64> { xs.iter().map(|x| x.to_bits()).collect() };
1767        assert_eq!(bits(&p1), bits(&p2));
1768        assert_eq!(bits(&m1), bits(&m2));
1769        assert_eq!(bits(&v1), bits(&v2));
1770        assert_eq!(t1, t2);
1771    }
1772
1773    #[test]
1774    fn test_confusion_matrix() {
1775        let pred = [true, true, false, false, true];
1776        let actual = [true, false, true, false, true];
1777        let cm = confusion_matrix(&pred, &actual);
1778        assert_eq!(cm.tp, 2);
1779        assert_eq!(cm.fp, 1);
1780        assert_eq!(cm.fn_count, 1);
1781        assert_eq!(cm.tn, 1);
1782    }
1783
1784    #[test]
1785    fn test_precision_recall_f1() {
1786        let cm = ConfusionMatrix { tp: 5, fp: 2, tn: 8, fn_count: 1 };
1787        assert!((precision(&cm) - 5.0 / 7.0).abs() < 1e-12);
1788        assert!((recall(&cm) - 5.0 / 6.0).abs() < 1e-12);
1789    }
1790
1791    #[test]
1792    fn test_auc_perfect() {
1793        let scores = [0.9, 0.8, 0.2, 0.1];
1794        let labels = [true, true, false, false];
1795        let auc = auc_roc(&scores, &labels).unwrap();
1796        assert!((auc - 1.0).abs() < 1e-12);
1797    }
1798
1799    #[test]
1800    fn test_kfold_deterministic() {
1801        let f1 = kfold_indices(100, 5, 42);
1802        let f2 = kfold_indices(100, 5, 42);
1803        for i in 0..5 {
1804            assert_eq!(f1[i].0, f2[i].0);
1805            assert_eq!(f1[i].1, f2[i].1);
1806        }
1807    }
1808
1809    #[test]
1810    fn test_train_test_split_coverage() {
1811        let (train, test) = train_test_split(100, 0.2, 42);
1812        assert_eq!(train.len() + test.len(), 100);
1813        assert_eq!(test.len(), 20);
1814    }
1815
1816    // --- B4: ML Training Extensions tests ---
1817
1818    #[test]
1819    fn test_batch_norm_identity() {
1820        // mean=0, var=1, gamma=1, beta=0, eps=0 → input unchanged
1821        let x = vec![1.0, 2.0, 3.0];
1822        let mean = vec![0.0, 0.0, 0.0];
1823        let var = vec![1.0, 1.0, 1.0];
1824        let gamma = vec![1.0, 1.0, 1.0];
1825        let beta = vec![0.0, 0.0, 0.0];
1826        let result = batch_norm(&x, &mean, &var, &gamma, &beta, 0.0).unwrap();
1827        assert!((result[0] - 1.0).abs() < 1e-12);
1828        assert!((result[1] - 2.0).abs() < 1e-12);
1829        assert!((result[2] - 3.0).abs() < 1e-12);
1830    }
1831
1832    #[test]
1833    fn test_batch_norm_shift_scale() {
1834        let x = vec![0.0];
1835        let mean = vec![1.0]; // shift: x - 1 = -1
1836        let var = vec![4.0];  // scale: -1/sqrt(4) = -0.5
1837        let gamma = vec![2.0]; // multiply: 2 * -0.5 = -1
1838        let beta = vec![3.0]; // add: -1 + 3 = 2
1839        let result = batch_norm(&x, &mean, &var, &gamma, &beta, 0.0).unwrap();
1840        assert!((result[0] - 2.0).abs() < 1e-12);
1841    }
1842
1843    #[test]
1844    fn test_dropout_mask_seed_determinism() {
1845        let m1 = dropout_mask(100, 0.5, 42);
1846        let m2 = dropout_mask(100, 0.5, 42);
1847        assert_eq!(m1, m2);
1848    }
1849
1850    #[test]
1851    fn test_dropout_mask_different_seeds() {
1852        let m1 = dropout_mask(100, 0.5, 42);
1853        let m2 = dropout_mask(100, 0.5, 99);
1854        assert_ne!(m1, m2);
1855    }
1856
1857    #[test]
1858    fn test_lr_step_decay_schedule() {
1859        let lr0 = lr_step_decay(0.1, 0.5, 0, 10);
1860        assert!((lr0 - 0.1).abs() < 1e-12);
1861        let lr10 = lr_step_decay(0.1, 0.5, 10, 10);
1862        assert!((lr10 - 0.05).abs() < 1e-12);
1863        let lr20 = lr_step_decay(0.1, 0.5, 20, 10);
1864        assert!((lr20 - 0.025).abs() < 1e-12);
1865    }
1866
1867    #[test]
1868    fn test_lr_cosine_endpoints() {
1869        let lr0 = lr_cosine(0.1, 0.001, 0, 100);
1870        assert!((lr0 - 0.1).abs() < 1e-10);
1871        let lr_end = lr_cosine(0.1, 0.001, 100, 100);
1872        assert!((lr_end - 0.001).abs() < 1e-10);
1873    }
1874
1875    #[test]
1876    fn test_lr_linear_warmup() {
1877        let lr0 = lr_linear_warmup(0.1, 0, 10);
1878        assert!((lr0).abs() < 1e-12);
1879        let lr5 = lr_linear_warmup(0.1, 5, 10);
1880        assert!((lr5 - 0.05).abs() < 1e-12);
1881        let lr15 = lr_linear_warmup(0.1, 15, 10);
1882        assert!((lr15 - 0.1).abs() < 1e-12);
1883    }
1884
1885    #[test]
1886    fn test_l1_penalty_known() {
1887        let params = [1.0, -2.0, 3.0];
1888        let p = l1_penalty(&params, 0.1);
1889        assert!((p - 0.6).abs() < 1e-12);
1890    }
1891
1892    #[test]
1893    fn test_l2_penalty_known() {
1894        let params = [1.0, -2.0, 3.0];
1895        let p = l2_penalty(&params, 0.1);
1896        // 0.5 * 0.1 * (1 + 4 + 9) = 0.5 * 0.1 * 14 = 0.7
1897        assert!((p - 0.7).abs() < 1e-12);
1898    }
1899
1900    #[test]
1901    fn test_early_stopping_triggers() {
1902        let mut es = EarlyStoppingState::new(3, 0.01);
1903        assert!(!es.check(1.0)); // best_loss=1.0, wait=0
1904        assert!(!es.check(1.0)); // no improvement, wait=1
1905        assert!(!es.check(1.0)); // no improvement, wait=2
1906        assert!(es.check(1.0));  // no improvement, wait=3 >= patience
1907    }
1908
1909    #[test]
1910    fn test_early_stopping_resets() {
1911        let mut es = EarlyStoppingState::new(3, 0.01);
1912        es.check(1.0);
1913        es.check(1.0); // wait=1
1914        assert!(!es.check(0.5)); // improvement, wait=0
1915        assert!(!es.check(0.5)); // wait=1
1916    }
1917
1918    // --- Phase 3C: PCA tests ---
1919
1920    #[test]
1921    fn test_pca_basic_2d() {
1922        // 4 samples, 2 features — data lies mostly along first axis
1923        let data = Tensor::from_vec(
1924            vec![
1925                1.0, 0.1,
1926                2.0, 0.2,
1927                3.0, 0.3,
1928                4.0, 0.4,
1929            ],
1930            &[4, 2],
1931        )
1932        .unwrap();
1933        let (transformed, components, evr) = pca(&data, 2).unwrap();
1934        assert_eq!(transformed.shape(), &[4, 2]);
1935        assert_eq!(components.shape(), &[2, 2]);
1936        assert_eq!(evr.len(), 2);
1937        // Explained variance ratios should sum to ~1.0
1938        let total: f64 = evr.iter().sum();
1939        assert!(
1940            (total - 1.0).abs() < 1e-8,
1941            "explained variance ratios sum to {} (expected ~1.0)",
1942            total
1943        );
1944        // First component should explain most variance
1945        assert!(evr[0] > 0.9, "first component explains {} of variance", evr[0]);
1946    }
1947
1948    #[test]
1949    fn test_pca_single_component() {
1950        let data = Tensor::from_vec(
1951            vec![
1952                1.0, 2.0, 3.0,
1953                4.0, 5.0, 6.0,
1954                7.0, 8.0, 9.0,
1955            ],
1956            &[3, 3],
1957        )
1958        .unwrap();
1959        let (transformed, components, evr) = pca(&data, 1).unwrap();
1960        assert_eq!(transformed.shape(), &[3, 1]);
1961        assert_eq!(components.shape(), &[1, 3]);
1962        assert_eq!(evr.len(), 1);
1963        assert!(evr[0] > 0.0 && evr[0] <= 1.0);
1964    }
1965
1966    #[test]
1967    fn test_pca_explained_variance_ratio_bounded() {
1968        let data = Tensor::from_vec(
1969            vec![
1970                1.0, 0.0, 0.5,
1971                0.0, 1.0, 0.5,
1972                1.0, 1.0, 1.0,
1973                2.0, 0.0, 1.0,
1974                0.0, 2.0, 1.0,
1975            ],
1976            &[5, 3],
1977        )
1978        .unwrap();
1979        let (_, _, evr) = pca(&data, 3).unwrap();
1980        let total: f64 = evr.iter().sum();
1981        assert!(
1982            total <= 1.0 + 1e-10,
1983            "explained variance ratios sum to {} (should be <= 1.0)",
1984            total
1985        );
1986        for &r in &evr {
1987            assert!(r >= -1e-10, "negative explained variance ratio: {}", r);
1988        }
1989    }
1990
1991    #[test]
1992    fn test_pca_deterministic() {
1993        let data = Tensor::from_vec(
1994            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
1995            &[2, 3],
1996        )
1997        .unwrap();
1998        let (t1, c1, e1) = pca(&data, 2).unwrap();
1999        let (t2, c2, e2) = pca(&data, 2).unwrap();
2000        assert_eq!(t1.to_vec(), t2.to_vec(), "PCA transformed not deterministic");
2001        assert_eq!(c1.to_vec(), c2.to_vec(), "PCA components not deterministic");
2002        assert_eq!(e1, e2, "PCA explained variance not deterministic");
2003    }
2004
2005    #[test]
2006    fn test_pca_invalid_n_components() {
2007        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
2008        assert!(pca(&data, 0).is_err(), "n_components=0 should fail");
2009        assert!(pca(&data, 3).is_err(), "n_components > min(n,p) should fail");
2010    }
2011
2012    // --- Sprint 2: L-BFGS tests ---
2013
2014    /// Rosenbrock function: f(x,y) = (1-x)^2 + 100*(y-x^2)^2
2015    /// Gradient: df/dx = -2*(1-x) - 400*x*(y-x^2)
2016    ///           df/dy = 200*(y-x^2)
2017    /// Global minimum at (1, 1) with f=0.
2018    fn rosenbrock(p: &[f64]) -> (f64, Vec<f64>) {
2019        let x = p[0];
2020        let y = p[1];
2021        let a = 1.0 - x;
2022        let b = y - x * x;
2023        let val = a * a + 100.0 * b * b;
2024        let gx = -2.0 * a - 400.0 * x * b;
2025        let gy = 200.0 * b;
2026        (val, vec![gx, gy])
2027    }
2028
2029    #[test]
2030    fn test_lbfgs_rosenbrock_converges() {
2031        // Start far from optimum — L-BFGS should converge near (1, 1)
2032        let mut params = vec![-1.0_f64, 2.0_f64];
2033        let mut state = LbfgsState::new(0.5, 10);
2034
2035        let mut converged = false;
2036        for _iter in 0..200 {
2037            let (_, grads) = rosenbrock(&params);
2038            let grad_norm: f64 = kahan_dot(&grads, &grads).sqrt();
2039            if grad_norm < 1e-5 {
2040                converged = true;
2041                break;
2042            }
2043            let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, rosenbrock);
2044            params = new_p;
2045        }
2046        assert!(converged, "L-BFGS did not converge on Rosenbrock; params = {:?}", params);
2047        assert!(
2048            (params[0] - 1.0).abs() < 1e-3,
2049            "x should converge near 1.0, got {}",
2050            params[0]
2051        );
2052        assert!(
2053            (params[1] - 1.0).abs() < 1e-3,
2054            "y should converge near 1.0, got {}",
2055            params[1]
2056        );
2057    }
2058
2059    #[test]
2060    fn test_lbfgs_determinism() {
2061        // Same initial params + same seed → identical results
2062        let init = vec![-1.0_f64, 2.0_f64];
2063
2064        let run = |init: &[f64]| -> Vec<f64> {
2065            let mut params = init.to_vec();
2066            let mut state = LbfgsState::new(0.5, 10);
2067            for _ in 0..20 {
2068                let (_, grads) = rosenbrock(&params);
2069                let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, rosenbrock);
2070                params = new_p;
2071            }
2072            params
2073        };
2074
2075        let r1 = run(&init);
2076        let r2 = run(&init);
2077        assert_eq!(r1, r2, "L-BFGS must be bit-identical across runs");
2078    }
2079
2080    #[test]
2081    fn test_lbfgs_simple_quadratic() {
2082        // f(x) = x^2, minimum at x=0
2083        let mut params = vec![3.0_f64];
2084        let mut state = LbfgsState::new(1.0, 5);
2085        let quadratic = |p: &[f64]| -> (f64, Vec<f64>) {
2086            (p[0] * p[0], vec![2.0 * p[0]])
2087        };
2088
2089        for _ in 0..30 {
2090            let (_, grads) = quadratic(&params);
2091            let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, quadratic);
2092            params = new_p;
2093        }
2094        assert!(
2095            params[0].abs() < 1e-6,
2096            "L-BFGS should minimize x^2 to ~0, got {}",
2097            params[0]
2098        );
2099    }
2100
2101    #[test]
2102    fn test_embedding_basic() {
2103        let weight = crate::tensor::Tensor::from_vec(
2104            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
2105            &[3, 2],
2106        ).unwrap();
2107        let indices = vec![0, 2, 1];
2108        let result = super::embedding(&weight, &indices).unwrap();
2109        assert_eq!(result.shape(), &[3, 2]);
2110        let data = result.to_vec();
2111        assert!((data[0] - 0.1).abs() < 1e-12);
2112        assert!((data[1] - 0.2).abs() < 1e-12);
2113        assert!((data[2] - 0.5).abs() < 1e-12);
2114        assert!((data[3] - 0.6).abs() < 1e-12);
2115        assert!((data[4] - 0.3).abs() < 1e-12);
2116        assert!((data[5] - 0.4).abs() < 1e-12);
2117    }
2118
2119    #[test]
2120    fn test_embedding_out_of_bounds() {
2121        let weight = crate::tensor::Tensor::from_vec(vec![1.0, 2.0], &[1, 2]).unwrap();
2122        let result = super::embedding(&weight, &[1]);
2123        assert!(result.is_err());
2124    }
2125
2126    #[test]
2127    fn test_batch_indices_deterministic() {
2128        let b1 = super::batch_indices(10, 3, 42);
2129        let b2 = super::batch_indices(10, 3, 42);
2130        assert_eq!(b1, b2);
2131        // Should cover all indices
2132        let total: usize = b1.iter().map(|(s, e)| e - s).sum();
2133        assert_eq!(total, 10);
2134    }
2135
2136    #[test]
2137    fn test_wolfe_line_search_armijo() {
2138        // For f(x) = x^2, starting at x=3, direction d=-1 (descent)
2139        // Wolfe search should find a valid step
2140        let params = vec![3.0_f64];
2141        let direction = vec![-1.0_f64];
2142        let grads = vec![6.0_f64]; // gradient at x=3
2143        let f0 = 9.0;
2144
2145        let mut eval_count = 0;
2146        let mut obj = |p: &[f64]| -> (f64, Vec<f64>) {
2147            eval_count += 1;
2148            (p[0] * p[0], vec![2.0 * p[0]])
2149        };
2150
2151        let (alpha, new_params, new_val, _) =
2152            wolfe_line_search(&params, &direction, &mut obj, f0, &grads, 1.0);
2153
2154        // Armijo: f(x + alpha*d) <= f(x) + c1*alpha*g^T*d
2155        let c1 = 1e-4;
2156        let derphi0 = kahan_dot(&grads, &direction); // = -6
2157        assert!(
2158            new_val <= f0 + c1 * alpha * derphi0,
2159            "Armijo condition violated: {} > {} + {} * {} * {}",
2160            new_val, f0, c1, alpha, derphi0
2161        );
2162        assert!(new_params[0] < 3.0, "Step should move toward minimum");
2163    }
2164}