// cjc_runtime/ml.rs
1//! ML Toolkit — loss functions, optimizers, activations, metrics.
2//!
3//! # Determinism Contract
4//! - All functions are deterministic (no randomness except seeded kfold).
5//! - Kahan summation for all reductions.
6//! - Stable sort for AUC-ROC with index tie-breaking.
7
8use cjc_repro::KahanAccumulatorF64;
9
10use crate::accumulator::BinnedAccumulatorF64;
11use crate::error::RuntimeError;
12use crate::tensor::Tensor;
13
14// ---------------------------------------------------------------------------
15// Loss functions
16// ---------------------------------------------------------------------------
17
18/// Mean Squared Error: sum((pred - target)^2) / n.
19pub fn mse_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
20    if pred.len() != target.len() {
21        return Err("mse_loss: arrays must have same length".into());
22    }
23    if pred.is_empty() {
24        return Err("mse_loss: empty data".into());
25    }
26    let mut acc = KahanAccumulatorF64::new();
27    for i in 0..pred.len() {
28        let d = pred[i] - target[i];
29        acc.add(d * d);
30    }
31    Ok(acc.finalize() / pred.len() as f64)
32}
33
34/// Cross-entropy loss: -sum(target * ln(pred + eps)) / n.
35pub fn cross_entropy_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
36    if pred.len() != target.len() {
37        return Err("cross_entropy_loss: arrays must have same length".into());
38    }
39    if pred.is_empty() {
40        return Err("cross_entropy_loss: empty data".into());
41    }
42    let eps = 1e-12;
43    let mut acc = KahanAccumulatorF64::new();
44    for i in 0..pred.len() {
45        acc.add(-target[i] * (pred[i] + eps).ln());
46    }
47    Ok(acc.finalize() / pred.len() as f64)
48}
49
50/// Binary cross-entropy: -sum(t*ln(p) + (1-t)*ln(1-p)) / n.
51pub fn binary_cross_entropy(pred: &[f64], target: &[f64]) -> Result<f64, String> {
52    if pred.len() != target.len() {
53        return Err("binary_cross_entropy: arrays must have same length".into());
54    }
55    if pred.is_empty() {
56        return Err("binary_cross_entropy: empty data".into());
57    }
58    let eps = 1e-12;
59    let mut acc = KahanAccumulatorF64::new();
60    for i in 0..pred.len() {
61        let p = pred[i].max(eps).min(1.0 - eps);
62        acc.add(-(target[i] * p.ln() + (1.0 - target[i]) * (1.0 - p).ln()));
63    }
64    Ok(acc.finalize() / pred.len() as f64)
65}
66
67/// Huber loss: quadratic for small errors, linear for large.
68pub fn huber_loss(pred: &[f64], target: &[f64], delta: f64) -> Result<f64, String> {
69    if pred.len() != target.len() {
70        return Err("huber_loss: arrays must have same length".into());
71    }
72    if pred.is_empty() {
73        return Err("huber_loss: empty data".into());
74    }
75    let mut acc = KahanAccumulatorF64::new();
76    for i in 0..pred.len() {
77        let d = (pred[i] - target[i]).abs();
78        if d <= delta {
79            acc.add(0.5 * d * d);
80        } else {
81            acc.add(delta * (d - 0.5 * delta));
82        }
83    }
84    Ok(acc.finalize() / pred.len() as f64)
85}
86
87/// Hinge loss: sum(max(0, 1 - target * pred)) / n.
88pub fn hinge_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
89    if pred.len() != target.len() {
90        return Err("hinge_loss: arrays must have same length".into());
91    }
92    if pred.is_empty() {
93        return Err("hinge_loss: empty data".into());
94    }
95    let mut acc = KahanAccumulatorF64::new();
96    for i in 0..pred.len() {
97        acc.add((1.0 - target[i] * pred[i]).max(0.0));
98    }
99    Ok(acc.finalize() / pred.len() as f64)
100}
101
102// ---------------------------------------------------------------------------
103// Optimizers
104// ---------------------------------------------------------------------------
105
/// SGD optimizer state: learning rate, momentum coefficient, and one
/// velocity slot per parameter.
pub struct SgdState {
    pub lr: f64,
    pub momentum: f64,
    pub velocity: Vec<f64>,
}

impl SgdState {
    /// Fresh state with zeroed velocity for `n_params` parameters.
    pub fn new(n_params: usize, lr: f64, momentum: f64) -> Self {
        Self {
            lr,
            momentum,
            velocity: vec![0.0; n_params],
        }
    }
}

/// SGD step: sequential, deterministic.
///
/// v <- momentum * v + g; p <- p - lr * v, element by element.
pub fn sgd_step(params: &mut [f64], grads: &[f64], state: &mut SgdState) {
    let (lr, momentum) = (state.lr, state.momentum);
    for (i, p) in params.iter_mut().enumerate() {
        let v = momentum * state.velocity[i] + grads[i];
        state.velocity[i] = v;
        *p -= lr * v;
    }
}
126
/// Adam optimizer state.
///
/// Holds the hyper-parameters plus the first (`m`) and second (`v`)
/// moment estimates for each parameter. `t` counts completed steps and
/// drives the bias correction.
pub struct AdamState {
    pub lr: f64,
    pub beta1: f64,
    pub beta2: f64,
    pub eps: f64,
    pub t: u64,
    pub m: Vec<f64>,
    pub v: Vec<f64>,
}

impl AdamState {
    /// New state with the standard defaults: beta1=0.9, beta2=0.999, eps=1e-8.
    pub fn new(n_params: usize, lr: f64) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            t: 0,
            m: vec![0.0; n_params],
            v: vec![0.0; n_params],
        }
    }
}

/// Adam step: sequential, deterministic.
///
/// Fix: the bias-correction denominators `1 - beta^t` are loop-invariant,
/// so they are now computed once per step instead of once per parameter
/// (the old code evaluated two `powf` calls per element). The arithmetic
/// per element is unchanged, so results are bit-identical.
pub fn adam_step(params: &mut [f64], grads: &[f64], state: &mut AdamState) {
    state.t += 1;
    let t = state.t as f64;
    // Hoisted loop-invariant bias corrections (same expressions as before).
    let bias_correction1 = 1.0 - state.beta1.powf(t);
    let bias_correction2 = 1.0 - state.beta2.powf(t);
    for i in 0..params.len() {
        // Exponential moving averages of the gradient and its square.
        state.m[i] = state.beta1 * state.m[i] + (1.0 - state.beta1) * grads[i];
        state.v[i] = state.beta2 * state.v[i] + (1.0 - state.beta2) * grads[i] * grads[i];
        let m_hat = state.m[i] / bias_correction1;
        let v_hat = state.v[i] / bias_correction2;
        params[i] -= state.lr * m_hat / (v_hat.sqrt() + state.eps);
    }
}
164
165// ---------------------------------------------------------------------------
166// Classification metrics (Sprint 6)
167// ---------------------------------------------------------------------------
168
/// Binary confusion matrix: true/false positives and negatives.
#[derive(Debug, Clone)]
pub struct ConfusionMatrix {
    pub tp: usize,
    pub fp: usize,
    pub tn: usize,
    pub fn_count: usize,
}

/// Build confusion matrix from predicted and actual boolean labels.
/// Pairs beyond the shorter slice are ignored.
pub fn confusion_matrix(predicted: &[bool], actual: &[bool]) -> ConfusionMatrix {
    let mut cm = ConfusionMatrix { tp: 0, fp: 0, tn: 0, fn_count: 0 };
    // zip naturally stops at the shorter slice, matching a min-length loop.
    for (&p, &a) in predicted.iter().zip(actual.iter()) {
        match (p, a) {
            (true, true) => cm.tp += 1,
            (true, false) => cm.fp += 1,
            (false, true) => cm.fn_count += 1,
            (false, false) => cm.tn += 1,
        }
    }
    cm
}

/// Precision: TP / (TP + FP). Returns 0.0 when nothing was predicted positive.
pub fn precision(cm: &ConfusionMatrix) -> f64 {
    match cm.tp + cm.fp {
        0 => 0.0,
        denom => cm.tp as f64 / denom as f64,
    }
}

/// Recall / sensitivity: TP / (TP + FN). Returns 0.0 when no actual positives.
pub fn recall(cm: &ConfusionMatrix) -> f64 {
    match cm.tp + cm.fn_count {
        0 => 0.0,
        denom => cm.tp as f64 / denom as f64,
    }
}

/// F1 score: harmonic mean of precision and recall (0.0 when both are 0).
pub fn f1_score(cm: &ConfusionMatrix) -> f64 {
    let (p, r) = (precision(cm), recall(cm));
    let denom = p + r;
    if denom == 0.0 {
        0.0
    } else {
        2.0 * p * r / denom
    }
}

/// Accuracy: (TP + TN) / total. Returns 0.0 for an empty matrix.
pub fn accuracy(cm: &ConfusionMatrix) -> f64 {
    match cm.tp + cm.fp + cm.tn + cm.fn_count {
        0 => 0.0,
        total => (cm.tp + cm.tn) as f64 / total as f64,
    }
}
219
/// AUC-ROC via trapezoidal rule.
/// DETERMINISM: sort by score with stable sort + index tie-breaking.
///
/// Errors on length mismatch, empty input, or single-class labels.
pub fn auc_roc(scores: &[f64], labels: &[bool]) -> Result<f64, String> {
    if scores.len() != labels.len() {
        return Err("auc_roc: scores and labels must have same length".into());
    }
    let n = scores.len();
    if n == 0 {
        return Err("auc_roc: empty data".into());
    }

    let pos_count = labels.iter().filter(|&&l| l).count();
    let neg_count = n - pos_count;
    if pos_count == 0 || neg_count == 0 {
        return Err("auc_roc: need both positive and negative labels".into());
    }

    // Order sample indices by descending score; the original index breaks
    // ties so the ordering is fully deterministic.
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&i, &j| scores[j].total_cmp(&scores[i]).then(i.cmp(&j)));

    // Sweep thresholds from highest score down, integrating the ROC curve
    // with trapezoids.
    let mut auc = 0.0;
    let mut tp = 0.0;
    let mut fp = 0.0;
    let (mut prev_fpr, mut prev_tpr) = (0.0, 0.0);
    for &idx in &order {
        if labels[idx] {
            tp += 1.0;
        } else {
            fp += 1.0;
        }
        let tpr = tp / pos_count as f64;
        let fpr = fp / neg_count as f64;
        auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0;
        prev_fpr = fpr;
        prev_tpr = tpr;
    }
    Ok(auc)
}
259
260/// K-fold cross-validation indices.
261/// DETERMINISM: uses seeded RNG (Fisher-Yates).
262pub fn kfold_indices(n: usize, k: usize, seed: u64) -> Vec<(Vec<usize>, Vec<usize>)> {
263    let mut rng = cjc_repro::Rng::seeded(seed);
264    // Fisher-Yates shuffle of [0..n]
265    let mut indices: Vec<usize> = (0..n).collect();
266    for i in (1..n).rev() {
267        let j = (rng.next_u64() as usize) % (i + 1);
268        indices.swap(i, j);
269    }
270    let fold_size = n / k;
271    let mut folds = Vec::with_capacity(k);
272    for fold in 0..k {
273        let start = fold * fold_size;
274        let end = if fold == k - 1 { n } else { start + fold_size };
275        let test: Vec<usize> = indices[start..end].to_vec();
276        let train: Vec<usize> = indices[..start].iter()
277            .chain(indices[end..].iter())
278            .copied()
279            .collect();
280        folds.push((train, test));
281    }
282    folds
283}
284
285/// Train/test split indices.
286pub fn train_test_split(n: usize, test_fraction: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
287    let mut rng = cjc_repro::Rng::seeded(seed);
288    let mut indices: Vec<usize> = (0..n).collect();
289    for i in (1..n).rev() {
290        let j = (rng.next_u64() as usize) % (i + 1);
291        indices.swap(i, j);
292    }
293    let test_size = ((n as f64) * test_fraction).round() as usize;
294    let test = indices[..test_size].to_vec();
295    let train = indices[test_size..].to_vec();
296    (train, test)
297}
298
299// ---------------------------------------------------------------------------
300// Phase B4: ML Training Extensions
301// ---------------------------------------------------------------------------
302
/// Batch normalization (inference mode).
/// y = gamma * (x - running_mean) / sqrt(running_var + eps) + beta.
///
/// All five slices must share the same length; errors otherwise.
pub fn batch_norm(
    x: &[f64],
    running_mean: &[f64],
    running_var: &[f64],
    gamma: &[f64],
    beta: &[f64],
    eps: f64,
) -> Result<Vec<f64>, String> {
    let n = x.len();
    let lengths_match =
        running_mean.len() == n && running_var.len() == n && gamma.len() == n && beta.len() == n;
    if !lengths_match {
        return Err("batch_norm: all arrays must have same length".into());
    }
    let out = (0..n)
        .map(|i| {
            // Standardize with the running statistics, then apply the
            // learned affine transform.
            let z = (x[i] - running_mean[i]) / (running_var[i] + eps).sqrt();
            gamma[i] * z + beta[i]
        })
        .collect();
    Ok(out)
}
324
325/// Dropout mask generation using seeded RNG for determinism.
326/// Returns mask of 0.0 and scale values (1/(1-p)) using inverted dropout.
327pub fn dropout_mask(n: usize, drop_prob: f64, seed: u64) -> Vec<f64> {
328    let mut rng = cjc_repro::Rng::seeded(seed);
329    let scale = if drop_prob < 1.0 { 1.0 / (1.0 - drop_prob) } else { 0.0 };
330    let mut mask = Vec::with_capacity(n);
331    for _ in 0..n {
332        let r = (rng.next_u64() as f64) / (u64::MAX as f64);
333        if r < drop_prob {
334            mask.push(0.0);
335        } else {
336            mask.push(scale);
337        }
338    }
339    mask
340}
341
/// Apply dropout: element-wise multiply data by mask.
///
/// Errors if the two slices differ in length.
pub fn apply_dropout(data: &[f64], mask: &[f64]) -> Result<Vec<f64>, String> {
    if data.len() != mask.len() {
        return Err("apply_dropout: data and mask must have same length".into());
    }
    let mut out = Vec::with_capacity(data.len());
    for i in 0..data.len() {
        out.push(data[i] * mask[i]);
    }
    Ok(out)
}
349
/// Learning rate schedule: step decay.
/// lr = initial_lr * decay_rate^(floor(epoch / step_size))
///
/// Fix: `step_size == 0` previously panicked with an integer division by
/// zero; it is now treated as "no decay" and returns `initial_lr`.
pub fn lr_step_decay(initial_lr: f64, decay_rate: f64, epoch: usize, step_size: usize) -> f64 {
    if step_size == 0 {
        return initial_lr;
    }
    initial_lr * decay_rate.powi((epoch / step_size) as i32)
}
355
/// Learning rate schedule: cosine annealing.
/// lr = min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * epoch / total_epochs))
///
/// Fix: `total_epochs == 0` previously produced NaN (0/0) or inf (x/0)
/// in the ratio; it now returns `max_lr`, the schedule's epoch-0 value.
pub fn lr_cosine(max_lr: f64, min_lr: f64, epoch: usize, total_epochs: usize) -> f64 {
    if total_epochs == 0 {
        return max_lr;
    }
    let ratio = epoch as f64 / total_epochs as f64;
    min_lr + 0.5 * (max_lr - min_lr) * (1.0 + (std::f64::consts::PI * ratio).cos())
}
362
/// Learning rate schedule: linear warmup.
/// lr = initial_lr * min(1.0, epoch / warmup_epochs).
///
/// With `warmup_epochs == 0` the full rate applies immediately.
pub fn lr_linear_warmup(initial_lr: f64, epoch: usize, warmup_epochs: usize) -> f64 {
    match warmup_epochs {
        0 => initial_lr,
        w => {
            // Fraction of warmup completed, saturating at 1.0.
            let frac = (epoch as f64 / w as f64).min(1.0);
            initial_lr * frac
        }
    }
}
371
372/// L1 regularization penalty: lambda * sum(|params|).
373pub fn l1_penalty(params: &[f64], lambda: f64) -> f64 {
374    let mut acc = KahanAccumulatorF64::new();
375    for &p in params {
376        acc.add(p.abs());
377    }
378    lambda * acc.finalize()
379}
380
381/// L2 regularization penalty: 0.5 * lambda * sum(params^2).
382pub fn l2_penalty(params: &[f64], lambda: f64) -> f64 {
383    let mut acc = KahanAccumulatorF64::new();
384    for &p in params {
385        acc.add(p * p);
386    }
387    0.5 * lambda * acc.finalize()
388}
389
/// L1 regularization gradient: lambda * sign(params).
///
/// The subgradient at 0 is taken as 0 (also the result for NaN inputs),
/// which is why explicit branches are used instead of `signum`.
pub fn l1_grad(params: &[f64], lambda: f64) -> Vec<f64> {
    let mut out = Vec::with_capacity(params.len());
    for &p in params {
        let g = if p < 0.0 {
            -lambda
        } else if p > 0.0 {
            lambda
        } else {
            0.0
        };
        out.push(g);
    }
    out
}
396
/// L2 regularization gradient: lambda * params.
pub fn l2_grad(params: &[f64], lambda: f64) -> Vec<f64> {
    let mut out = Vec::with_capacity(params.len());
    for &p in params {
        out.push(lambda * p);
    }
    out
}
401
/// Early stopping state tracker.
pub struct EarlyStoppingState {
    pub patience: usize,
    pub min_delta: f64,
    pub best_loss: f64,
    pub wait: usize,
    pub stopped: bool,
}

impl EarlyStoppingState {
    /// Fresh tracker: no best loss seen yet, zero epochs waited.
    pub fn new(patience: usize, min_delta: f64) -> Self {
        Self {
            patience,
            min_delta,
            best_loss: f64::INFINITY,
            wait: 0,
            stopped: false,
        }
    }

    /// Record one epoch's loss; returns true once training should stop.
    ///
    /// An improvement of more than `min_delta` below the best loss resets
    /// the wait counter; otherwise the counter grows, and reaching
    /// `patience` latches `stopped` (it never resets afterwards).
    pub fn check(&mut self, current_loss: f64) -> bool {
        let improved = current_loss < self.best_loss - self.min_delta;
        if improved {
            self.best_loss = current_loss;
            self.wait = 0;
        } else {
            self.wait += 1;
        }
        self.stopped = self.stopped || self.wait >= self.patience;
        self.stopped
    }
}
436
437// ---------------------------------------------------------------------------
438// Phase 3C: PCA (Principal Component Analysis)
439// ---------------------------------------------------------------------------
440
/// Principal Component Analysis via SVD of centered data.
///
/// `data` is a 2D Tensor of shape (n_samples, n_features).
/// `n_components` is the number of principal components to keep.
///
/// Returns (transformed_data, components, explained_variance_ratio):
/// - `transformed_data`: (n_samples, n_components) — data projected onto principal components
/// - `components`: (n_components, n_features) — principal component directions (rows)
/// - `explained_variance_ratio`: Vec<f64> of length n_components — fraction of variance per component
///
/// Errors (`RuntimeError::InvalidOperation`) on non-2D input, empty data,
/// or `n_components` outside [1, min(n_samples, n_features)].
///
/// **Determinism contract:** All reductions use `BinnedAccumulatorF64`.
pub fn pca(
    data: &Tensor,
    n_components: usize,
) -> Result<(Tensor, Tensor, Vec<f64>), RuntimeError> {
    if data.ndim() != 2 {
        return Err(RuntimeError::InvalidOperation(
            "PCA requires a 2D data matrix".to_string(),
        ));
    }
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];

    if n_samples == 0 || n_features == 0 {
        return Err(RuntimeError::InvalidOperation(
            "PCA: empty data matrix".to_string(),
        ));
    }
    if n_components == 0 || n_components > n_features.min(n_samples) {
        return Err(RuntimeError::InvalidOperation(format!(
            "PCA: n_components ({}) must be in [1, min(n_samples, n_features) = {}]",
            n_components,
            n_features.min(n_samples)
        )));
    }

    // NOTE(review): the indexing below assumes `to_vec` flattens row-major
    // (element (i, j) at raw[i * n_features + j]) — confirm Tensor::to_vec.
    let raw = data.to_vec();

    // Step 1: Compute column means using BinnedAccumulatorF64
    // (one accumulator per feature keeps the reduction deterministic).
    let mut means = vec![0.0f64; n_features];
    for j in 0..n_features {
        let mut acc = BinnedAccumulatorF64::new();
        for i in 0..n_samples {
            acc.add(raw[i * n_features + j]);
        }
        means[j] = acc.finalize() / n_samples as f64;
    }

    // Step 2: Center the data (subtract each column's mean).
    let mut centered = vec![0.0f64; n_samples * n_features];
    for i in 0..n_samples {
        for j in 0..n_features {
            centered[i * n_features + j] = raw[i * n_features + j] - means[j];
        }
    }
    let centered_tensor = Tensor::from_vec(centered, &[n_samples, n_features])?;

    // Step 3: SVD of centered data.
    // NOTE(review): assumes `svd` returns (U, S, Vt) with singular values
    // in descending order — required for the "first k" slices below to
    // select the top components; verify Tensor::svd's contract.
    let (u, s, vt) = centered_tensor.svd()?;
    let k = n_components.min(s.len());

    // Step 4: Components = first k rows of Vt (right singular vectors).
    let vt_data = vt.to_vec();
    let vt_cols = vt.shape()[1]; // n_features
    let mut components = vec![0.0f64; k * n_features];
    for i in 0..k {
        for j in 0..n_features {
            components[i * n_features + j] = vt_data[i * vt_cols + j];
        }
    }

    // Step 5: Explained variance = s_i^2 / (n_samples - 1)
    // Total variance = sum of all s_i^2 / (n_samples - 1)
    // Guard the n_samples == 1 case to avoid dividing by zero.
    let denom = if n_samples > 1 {
        (n_samples - 1) as f64
    } else {
        1.0
    };

    let mut total_var_acc = BinnedAccumulatorF64::new();
    for &si in &s {
        total_var_acc.add(si * si / denom);
    }
    let total_var = total_var_acc.finalize();

    // Near-zero total variance (constant data) yields all-zero ratios
    // instead of 0/0 NaNs.
    let explained_variance_ratio: Vec<f64> = if total_var > 1e-15 {
        s[..k]
            .iter()
            .map(|&si| (si * si / denom) / total_var)
            .collect()
    } else {
        vec![0.0; k]
    };

    // Step 6: Transformed data = U_k @ diag(S_k)
    // (equivalent to projecting the centered data onto the components).
    let u_data = u.to_vec();
    let u_cols = u.shape()[1];
    let mut transformed = vec![0.0f64; n_samples * k];
    for i in 0..n_samples {
        for j in 0..k {
            transformed[i * k + j] = u_data[i * u_cols + j] * s[j];
        }
    }

    Ok((
        Tensor::from_vec(transformed, &[n_samples, k])?,
        Tensor::from_vec(components, &[k, n_features])?,
        explained_variance_ratio,
    ))
}
551
552// ---------------------------------------------------------------------------
553// L-BFGS Optimizer (Sprint 2)
554// ---------------------------------------------------------------------------
555
/// L-BFGS optimizer state.
///
/// Limited-memory Broyden-Fletcher-Goldfarb-Shanno quasi-Newton method.
/// Maintains a history of `m` most recent (s, y) pairs to approximate
/// the inverse Hessian without storing it explicitly.
pub struct LbfgsState {
    /// Initial step length handed to the line search each iteration.
    pub lr: f64,
    /// History size (default 10)
    pub m: usize,
    /// Parameter differences: s_k = params_{k+1} - params_k
    pub s_history: Vec<Vec<f64>>,
    /// Gradient differences: y_k = grad_{k+1} - grad_k
    /// (kept index-aligned with `s_history`).
    pub y_history: Vec<Vec<f64>>,
    /// Parameters after the most recent accepted step, if any step has run.
    pub prev_params: Option<Vec<f64>>,
    /// Gradient after the most recent accepted step, if any step has run.
    pub prev_grad: Option<Vec<f64>>,
}

impl LbfgsState {
    /// Create an empty state: no history, no previous point recorded.
    /// `lr` seeds the line search; `m` caps the (s, y) history length.
    pub fn new(lr: f64, m: usize) -> Self {
        Self {
            lr,
            m,
            s_history: Vec::new(),
            y_history: Vec::new(),
            prev_params: None,
            prev_grad: None,
        }
    }
}
585
586/// Dot product of two slices using Kahan summation for determinism.
587fn kahan_dot(a: &[f64], b: &[f64]) -> f64 {
588    debug_assert_eq!(a.len(), b.len());
589    let mut acc = KahanAccumulatorF64::new();
590    for (&ai, &bi) in a.iter().zip(b.iter()) {
591        acc.add(ai * bi);
592    }
593    acc.finalize()
594}
595
/// Strong Wolfe line search.
///
/// Finds a step length `alpha` satisfying both:
/// - Armijo (sufficient decrease): f(x + alpha*d) <= f(x) + c1*alpha*g^T*d
/// - Curvature condition: |g(x + alpha*d)^T*d| <= c2*|g(x)^T*d|
///
/// Uses backtracking bracketing followed by bisection zoom.
///
/// # Arguments
/// * `params` - Current parameters
/// * `direction` - Search direction (should be a descent direction)
/// * `f` - Function that returns (value, gradient) at a parameter vector
/// * `f0` - Current function value
/// * `g0` - Current gradient
/// * `alpha_init` - Initial step length (typically `lr`)
///
/// # Returns
/// `(alpha, new_params, new_val, new_grad)` or the best found so far on failure.
pub fn wolfe_line_search<F>(
    params: &[f64],
    direction: &[f64],
    f: &mut F,
    f0: f64,
    g0: &[f64],
    alpha_init: f64,
) -> (f64, Vec<f64>, f64, Vec<f64>)
where
    F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
    // Standard Wolfe constants: loose sufficient-decrease, tight curvature.
    let c1 = 1e-4_f64;
    let c2 = 0.9_f64;
    let derphi0 = kahan_dot(g0, direction); // should be negative for descent

    // Helper: evaluate f at params + alpha * direction
    let step = |alpha: f64| -> Vec<f64> {
        params.iter().zip(direction.iter()).map(|(&p, &d)| p + alpha * d).collect()
    };

    let max_iter = 30;
    let mut alpha_lo = 0.0_f64;
    let mut alpha_hi = f64::INFINITY;
    let mut phi_lo = f0;
    let mut dphi_lo = derphi0;
    let mut alpha = alpha_init;

    // Keep track of best valid alpha found (fallback)
    // NOTE(review): f is evaluated at alpha_init here AND again on the
    // first loop iteration — one redundant (value, gradient) evaluation
    // per search. Harmless but wasteful for expensive objectives.
    let mut best_alpha = alpha_init;
    let mut best_params = step(alpha_init);
    let (mut best_val, mut best_grad) = f(&best_params);

    for _iter in 0..max_iter {
        let x_alpha = step(alpha);
        let (phi_alpha, grad_alpha) = f(&x_alpha);
        let dphi_alpha = kahan_dot(&grad_alpha, direction);

        // Track the best point for fallback
        if phi_alpha < best_val {
            best_alpha = alpha;
            best_params = x_alpha.clone();
            best_val = phi_alpha;
            best_grad = grad_alpha.clone();
        }

        // Armijo condition violated or function increased beyond bracket — shrink hi
        if phi_alpha > f0 + c1 * alpha * derphi0 || (phi_alpha >= phi_lo && alpha_lo > 0.0) {
            alpha_hi = alpha;
            // Zoom between alpha_lo and alpha_hi
            let (za, zp, zv, zg) = wolfe_zoom(
                params, direction, f, f0, derphi0, c1, c2,
                alpha_lo, alpha_hi, phi_lo, dphi_lo,
            );
            return (za, zp, zv, zg);
        }

        // Strong Wolfe curvature condition satisfied — done
        if dphi_alpha.abs() <= c2 * derphi0.abs() {
            return (alpha, x_alpha, phi_alpha, grad_alpha);
        }

        // Derivative is positive — zoom between alpha and alpha_lo
        // (minimum has been stepped over; note the swapped bracket order).
        if dphi_alpha >= 0.0 {
            let (za, zp, zv, zg) = wolfe_zoom(
                params, direction, f, f0, derphi0, c1, c2,
                alpha, alpha_lo, phi_alpha, dphi_alpha,
            );
            return (za, zp, zv, zg);
        }

        // Expand: step is too short
        alpha_lo = alpha;
        phi_lo = phi_alpha;
        dphi_lo = dphi_alpha;

        // Expand by factor 2 (capped to prevent explosion)
        // NOTE(review): every branch that assigns a finite alpha_hi returns
        // via zoom above, so the bisection arm here appears unreachable —
        // confirm before relying on it.
        alpha = if alpha_hi.is_finite() {
            (alpha_lo + alpha_hi) * 0.5
        } else {
            (alpha * 2.0).min(1e8)
        };
    }

    // Return best found
    (best_alpha, best_params, best_val, best_grad)
}
700
/// Zoom phase of Wolfe line search — bisects between [alpha_lo, alpha_hi].
///
/// Pure bisection (no cubic/quadratic interpolation): each iteration
/// evaluates the bracket midpoint and shrinks the bracket toward a point
/// satisfying both strong Wolfe conditions. On exhaustion or a collapsed
/// bracket, returns the lowest-value point evaluated so far.
/// Note: `alpha_lo` need not be less than `alpha_hi` — the caller may pass
/// a reversed bracket; the midpoint and update logic handle both signs.
#[allow(clippy::too_many_arguments)]
fn wolfe_zoom<F>(
    params: &[f64],
    direction: &[f64],
    f: &mut F,
    f0: f64,
    derphi0: f64,
    c1: f64,
    c2: f64,
    mut alpha_lo: f64,
    mut alpha_hi: f64,
    mut phi_lo: f64,
    _dphi_lo: f64,
) -> (f64, Vec<f64>, f64, Vec<f64>)
where
    F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
    // Evaluate the trial point params + alpha * direction.
    let step = |alpha: f64| -> Vec<f64> {
        params.iter().zip(direction.iter()).map(|(&p, &d)| p + alpha * d).collect()
    };

    let max_zoom = 20;
    // Fallback: best (lowest-phi) point seen, seeded at alpha_lo.
    let mut best_alpha = alpha_lo;
    let mut best_x = step(alpha_lo);
    let (mut best_val, mut best_grad) = f(&best_x);

    for _i in 0..max_zoom {
        // Midpoint trial.
        let alpha_j = (alpha_lo + alpha_hi) * 0.5;
        let x_j = step(alpha_j);
        let (phi_j, grad_j) = f(&x_j);
        let dphi_j = kahan_dot(&grad_j, direction);

        if phi_j < best_val {
            best_alpha = alpha_j;
            best_x = x_j.clone();
            best_val = phi_j;
            best_grad = grad_j.clone();
        }

        // Armijo violated (or no improvement over lo) — shrink from the hi side.
        if phi_j > f0 + c1 * alpha_j * derphi0 || phi_j >= phi_lo {
            alpha_hi = alpha_j;
        } else {
            // Curvature condition satisfied
            if dphi_j.abs() <= c2 * derphi0.abs() {
                return (alpha_j, x_j, phi_j, grad_j);
            }
            // Slope sign says the minimum lies on the lo side: flip the bracket.
            if dphi_j * (alpha_hi - alpha_lo) >= 0.0 {
                alpha_hi = alpha_lo;
            }
            alpha_lo = alpha_j;
            phi_lo = phi_j;
        }

        // Convergence check on bracket width
        if (alpha_hi - alpha_lo).abs() < 1e-14 {
            break;
        }
    }

    (best_alpha, best_x, best_val, best_grad)
}
763
/// L-BFGS step with strong Wolfe line search.
///
/// Given current parameters and gradients (and a closure to evaluate the
/// function and gradient at any point), computes the L-BFGS search direction
/// via the two-loop recursion, performs a Wolfe line search along that
/// direction, then updates the (s, y) history buffers.
///
/// # Arguments
/// * `params` - Current parameter vector (modified in-place on success)
/// * `grads` - Gradient at current params
/// * `state` - L-BFGS state (history + meta-parameters)
/// * `f` - Closure: takes parameter slice, returns `(loss, gradient_vec)`
///
/// # Returns
/// `(new_params, new_grads, step_taken)` — step_taken is false if the search
/// direction was not a descent direction (resets to gradient descent).
pub fn lbfgs_step<F>(
    params: &[f64],
    grads: &[f64],
    state: &mut LbfgsState,
    mut f: F,
) -> (Vec<f64>, Vec<f64>, bool)
where
    F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
    let n = params.len();
    debug_assert_eq!(grads.len(), n);

    // ---- Two-loop L-BFGS recursion ----
    // Computes H_k * g using the compact L-BFGS inverse Hessian approximation.
    let hist_len = state.s_history.len();
    let mut q: Vec<f64> = grads.to_vec();
    let mut alphas = vec![0.0_f64; hist_len];
    let mut rhos = vec![0.0_f64; hist_len];

    // Loop 1: backward through history (newest pair first).
    for i in (0..hist_len).rev() {
        let sy = kahan_dot(&state.s_history[i], &state.y_history[i]);
        // Guard near-zero curvature: rho = 0 makes this pair a no-op.
        rhos[i] = if sy.abs() < 1e-300 { 0.0 } else { 1.0 / sy };
        let sq = kahan_dot(&state.s_history[i], &q);
        alphas[i] = rhos[i] * sq;
        // q -= alpha_i * y_i
        for j in 0..n {
            q[j] -= alphas[i] * state.y_history[i][j];
        }
    }

    // Scale by initial Hessian approximation H_0 = (s^T y / y^T y) * I
    // (Barzilai-Borwein-style scaling from the most recent pair).
    let scale = if hist_len > 0 {
        let last = hist_len - 1;
        let sy = kahan_dot(&state.s_history[last], &state.y_history[last]);
        let yy = kahan_dot(&state.y_history[last], &state.y_history[last]);
        if yy.abs() < 1e-300 { 1.0 } else { sy / yy }
    } else {
        1.0
    };
    let mut r: Vec<f64> = q.iter().map(|&qi| scale * qi).collect();

    // Loop 2: forward through history (oldest pair first).
    for i in 0..hist_len {
        let yr = kahan_dot(&state.y_history[i], &r);
        let beta = rhos[i] * yr;
        // r += (alpha_i - beta) * s_i
        let diff = alphas[i] - beta;
        for j in 0..n {
            r[j] += diff * state.s_history[i][j];
        }
    }

    // Search direction: d = -H_k * g = -r
    let direction: Vec<f64> = r.iter().map(|&ri| -ri).collect();

    // Check descent condition: d^T g < 0
    let descent_check = kahan_dot(&direction, grads);
    let (direction, is_descent) = if descent_check >= 0.0 || !descent_check.is_finite() {
        // Fall back to scaled negative gradient (normalized; max guards a
        // zero-gradient division).
        let norm_g = kahan_dot(grads, grads).sqrt().max(1e-300);
        (grads.iter().map(|&g| -g / norm_g).collect::<Vec<f64>>(), false)
    } else {
        (direction, true)
    };

    // ---- Wolfe line search ----
    // NOTE(review): this evaluation recomputes the gradient at `params`
    // (already supplied as `grads`) only to obtain f0 — one extra
    // objective evaluation per step.
    let (f0, _) = f(params);
    let (_, new_params, _, new_grads) = wolfe_line_search(
        params,
        &direction,
        &mut f,
        f0,
        grads,
        state.lr,
    );

    // ---- Update history ----
    // s_k = new_params - params
    let s_k: Vec<f64> = new_params.iter().zip(params.iter()).map(|(&np, &p)| np - p).collect();
    // y_k = new_grads - grads
    let y_k: Vec<f64> = new_grads.iter().zip(grads.iter()).map(|(&ng, &g)| ng - g).collect();

    let sy = kahan_dot(&s_k, &y_k);
    // Only add to history if curvature condition holds (sy > 0); this keeps
    // the inverse Hessian approximation positive definite.
    if sy > 1e-300 {
        state.s_history.push(s_k);
        state.y_history.push(y_k);
        // Trim to m most recent (oldest pair is at index 0).
        if state.s_history.len() > state.m {
            state.s_history.remove(0);
            state.y_history.remove(0);
        }
    }

    state.prev_params = Some(new_params.clone());
    state.prev_grad = Some(new_grads.clone());

    (new_params, new_grads, is_descent)
}
880
881// ---------------------------------------------------------------------------
882// LSTM cell forward pass
883// ---------------------------------------------------------------------------
884
885/// LSTM cell forward pass.
886///
887/// * `x`:      `[batch, input_size]`
888/// * `h_prev`: `[batch, hidden_size]`
889/// * `c_prev`: `[batch, hidden_size]`
890/// * `w_ih`:   `[4*hidden_size, input_size]`
891/// * `w_hh`:   `[4*hidden_size, hidden_size]`
892/// * `b_ih`:   `[4*hidden_size]`
893/// * `b_hh`:   `[4*hidden_size]`
894///
895/// Returns `(h_new, c_new)`.
896///
897/// Gate layout (PyTorch convention): i, f, g, o — each of size `hidden_size`.
898/// All reductions use Kahan summation via the existing `Tensor::linear` method.
899pub fn lstm_cell(
900    x: &Tensor,
901    h_prev: &Tensor,
902    c_prev: &Tensor,
903    w_ih: &Tensor,
904    w_hh: &Tensor,
905    b_ih: &Tensor,
906    b_hh: &Tensor,
907) -> Result<(Tensor, Tensor), String> {
908    let map_err = |e: crate::error::RuntimeError| format!("{e}");
909
910    // Validate shapes
911    if x.ndim() != 2 {
912        return Err("lstm_cell: x must be 2-D [batch, input_size]".into());
913    }
914    if h_prev.ndim() != 2 {
915        return Err("lstm_cell: h_prev must be 2-D [batch, hidden_size]".into());
916    }
917    if c_prev.ndim() != 2 {
918        return Err("lstm_cell: c_prev must be 2-D [batch, hidden_size]".into());
919    }
920    let hidden_size = h_prev.shape()[1];
921    if w_ih.ndim() != 2 || w_ih.shape()[0] != 4 * hidden_size {
922        return Err(format!(
923            "lstm_cell: w_ih must be [4*hidden_size, input_size], got {:?}",
924            w_ih.shape()
925        ));
926    }
927    if w_hh.ndim() != 2 || w_hh.shape()[0] != 4 * hidden_size {
928        return Err(format!(
929            "lstm_cell: w_hh must be [4*hidden_size, hidden_size], got {:?}",
930            w_hh.shape()
931        ));
932    }
933    if b_ih.len() != 4 * hidden_size {
934        return Err(format!(
935            "lstm_cell: b_ih must have length 4*hidden_size={}, got {}",
936            4 * hidden_size,
937            b_ih.len()
938        ));
939    }
940    if b_hh.len() != 4 * hidden_size {
941        return Err(format!(
942            "lstm_cell: b_hh must have length 4*hidden_size={}, got {}",
943            4 * hidden_size,
944            b_hh.len()
945        ));
946    }
947
948    // gates = x @ w_ih^T + b_ih + h_prev @ w_hh^T + b_hh
949    // Use linear which does: input @ weight^T + bias
950    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
951    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
952    let gates = gates_ih.add(&gates_hh).map_err(map_err)?;
953
954    // gates is [batch, 4*hidden_size]. Split into 4 chunks along dim 1.
955    let chunks = gates.chunk(4, 1).map_err(map_err)?;
956    let gates_i = &chunks[0];
957    let gates_f = &chunks[1];
958    let gates_g = &chunks[2];
959    let gates_o = &chunks[3];
960
961    // Apply activations
962    let i = gates_i.sigmoid();
963    let f = gates_f.sigmoid();
964    let g = gates_g.tanh_activation();
965    let o = gates_o.sigmoid();
966
967    // c_new = f * c_prev + i * g
968    let fc = f.mul_elem(c_prev).map_err(map_err)?;
969    let ig = i.mul_elem(&g).map_err(map_err)?;
970    let c_new = fc.add(&ig).map_err(map_err)?;
971
972    // h_new = o * tanh(c_new)
973    let c_tanh = c_new.tanh_activation();
974    let h_new = o.mul_elem(&c_tanh).map_err(map_err)?;
975
976    Ok((h_new, c_new))
977}
978
979// ---------------------------------------------------------------------------
980// GRU cell forward pass
981// ---------------------------------------------------------------------------
982
983/// GRU cell forward pass.
984///
985/// * `x`:      `[batch, input_size]`
986/// * `h_prev`: `[batch, hidden_size]`
987/// * `w_ih`:   `[3*hidden_size, input_size]`
988/// * `w_hh`:   `[3*hidden_size, hidden_size]`
989/// * `b_ih`:   `[3*hidden_size]`
990/// * `b_hh`:   `[3*hidden_size]`
991///
992/// Returns `h_new` tensor of shape `[batch, hidden_size]`.
993///
994/// Gate layout: r (reset), z (update), n (new) — each of size `hidden_size`.
995/// All reductions use Kahan summation via the existing `Tensor::linear` method.
996pub fn gru_cell(
997    x: &Tensor,
998    h_prev: &Tensor,
999    w_ih: &Tensor,
1000    w_hh: &Tensor,
1001    b_ih: &Tensor,
1002    b_hh: &Tensor,
1003) -> Result<Tensor, String> {
1004    let map_err = |e: crate::error::RuntimeError| format!("{e}");
1005
1006    // Validate shapes
1007    if x.ndim() != 2 {
1008        return Err("gru_cell: x must be 2-D [batch, input_size]".into());
1009    }
1010    if h_prev.ndim() != 2 {
1011        return Err("gru_cell: h_prev must be 2-D [batch, hidden_size]".into());
1012    }
1013    let hidden_size = h_prev.shape()[1];
1014    if w_ih.ndim() != 2 || w_ih.shape()[0] != 3 * hidden_size {
1015        return Err(format!(
1016            "gru_cell: w_ih must be [3*hidden_size, input_size], got {:?}",
1017            w_ih.shape()
1018        ));
1019    }
1020    if w_hh.ndim() != 2 || w_hh.shape()[0] != 3 * hidden_size {
1021        return Err(format!(
1022            "gru_cell: w_hh must be [3*hidden_size, hidden_size], got {:?}",
1023            w_hh.shape()
1024        ));
1025    }
1026    if b_ih.len() != 3 * hidden_size {
1027        return Err(format!(
1028            "gru_cell: b_ih must have length 3*hidden_size={}, got {}",
1029            3 * hidden_size,
1030            b_ih.len()
1031        ));
1032    }
1033    if b_hh.len() != 3 * hidden_size {
1034        return Err(format!(
1035            "gru_cell: b_hh must have length 3*hidden_size={}, got {}",
1036            3 * hidden_size,
1037            b_hh.len()
1038        ));
1039    }
1040
1041    // gates_ih = x @ w_ih^T + b_ih   → [batch, 3*hidden_size]
1042    // gates_hh = h_prev @ w_hh^T + b_hh → [batch, 3*hidden_size]
1043    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
1044    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
1045
1046    // Split into r, z, n portions
1047    let ih_chunks = gates_ih.chunk(3, 1).map_err(map_err)?;
1048    let hh_chunks = gates_hh.chunk(3, 1).map_err(map_err)?;
1049    let r_ih = &ih_chunks[0];
1050    let z_ih = &ih_chunks[1];
1051    let n_ih = &ih_chunks[2];
1052    let r_hh = &hh_chunks[0];
1053    let z_hh = &hh_chunks[1];
1054    let n_hh = &hh_chunks[2];
1055
1056    // r = sigmoid(r_ih + r_hh)
1057    let r = r_ih.add(r_hh).map_err(map_err)?.sigmoid();
1058    // z = sigmoid(z_ih + z_hh)
1059    let z = z_ih.add(z_hh).map_err(map_err)?.sigmoid();
1060    // n = tanh(n_ih + r * n_hh)
1061    let r_n_hh = r.mul_elem(n_hh).map_err(map_err)?;
1062    let n = n_ih.add(&r_n_hh).map_err(map_err)?.tanh_activation();
1063
1064    // h_new = (1 - z) * n + z * h_prev
1065    // Build (1 - z): negate z then add ones
1066    let ones = Tensor::ones(z.shape());
1067    let one_minus_z = ones.sub(&z).map_err(map_err)?;
1068    let term1 = one_minus_z.mul_elem(&n).map_err(map_err)?;
1069    let term2 = z.mul_elem(h_prev).map_err(map_err)?;
1070    let h_new = term1.add(&term2).map_err(map_err)?;
1071
1072    Ok(h_new)
1073}
1074
1075// ---------------------------------------------------------------------------
1076// Fused LSTM cell forward pass (allocation-optimized)
1077// ---------------------------------------------------------------------------
1078
1079/// Fused LSTM cell: minimizes intermediate tensor allocations.
1080///
1081/// Produces bit-identical results to [`lstm_cell`] but reduces tensor
1082/// allocations from 13 to 4 (2 matmuls via `linear`, 2 output `from_vec`).
1083/// After the two `linear` calls the gate combination, activations, and cell /
1084/// hidden-state updates are computed element-wise in a single scalar loop
1085/// with no additional tensor temporaries.
1086///
1087/// Shape requirements are identical to [`lstm_cell`].
1088pub fn lstm_cell_fused(
1089    x: &Tensor,
1090    h_prev: &Tensor,
1091    c_prev: &Tensor,
1092    w_ih: &Tensor,
1093    w_hh: &Tensor,
1094    b_ih: &Tensor,
1095    b_hh: &Tensor,
1096) -> Result<(Tensor, Tensor), String> {
1097    let map_err = |e: crate::error::RuntimeError| format!("{e}");
1098
1099    // --- Validate shapes (same checks as lstm_cell) --------------------------
1100    if x.ndim() != 2 {
1101        return Err("lstm_cell_fused: x must be 2-D [batch, input_size]".into());
1102    }
1103    if h_prev.ndim() != 2 {
1104        return Err("lstm_cell_fused: h_prev must be 2-D [batch, hidden_size]".into());
1105    }
1106    if c_prev.ndim() != 2 {
1107        return Err("lstm_cell_fused: c_prev must be 2-D [batch, hidden_size]".into());
1108    }
1109    let batch = x.shape()[0];
1110    let hidden_size = h_prev.shape()[1];
1111    if w_ih.ndim() != 2 || w_ih.shape()[0] != 4 * hidden_size {
1112        return Err(format!(
1113            "lstm_cell_fused: w_ih must be [4*hidden_size, input_size], got {:?}",
1114            w_ih.shape()
1115        ));
1116    }
1117    if w_hh.ndim() != 2 || w_hh.shape()[0] != 4 * hidden_size {
1118        return Err(format!(
1119            "lstm_cell_fused: w_hh must be [4*hidden_size, hidden_size], got {:?}",
1120            w_hh.shape()
1121        ));
1122    }
1123    if b_ih.len() != 4 * hidden_size {
1124        return Err(format!(
1125            "lstm_cell_fused: b_ih must have length 4*hidden_size={}, got {}",
1126            4 * hidden_size,
1127            b_ih.len()
1128        ));
1129    }
1130    if b_hh.len() != 4 * hidden_size {
1131        return Err(format!(
1132            "lstm_cell_fused: b_hh must have length 4*hidden_size={}, got {}",
1133            4 * hidden_size,
1134            b_hh.len()
1135        ));
1136    }
1137
1138    // --- Step 1: two matmuls (reuse existing Kahan-summation linear) ---------
1139    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
1140    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
1141
1142    // --- Step 2: fused gate combination + activations + state update ---------
1143    let gih = gates_ih.to_vec();
1144    let ghh = gates_hh.to_vec();
1145    let cprev = c_prev.to_vec();
1146
1147    let mut h_new_data = vec![0.0f64; batch * hidden_size];
1148    let mut c_new_data = vec![0.0f64; batch * hidden_size];
1149
1150    for b_idx in 0..batch {
1151        let base = b_idx * 4 * hidden_size;
1152        for h in 0..hidden_size {
1153            // Combine ih + hh for each gate (i, f, g, o)
1154            let gi = gih[base + h] + ghh[base + h];
1155            let gf = gih[base + hidden_size + h] + ghh[base + hidden_size + h];
1156            let gg = gih[base + 2 * hidden_size + h] + ghh[base + 2 * hidden_size + h];
1157            let go = gih[base + 3 * hidden_size + h] + ghh[base + 3 * hidden_size + h];
1158
1159            // Activations (scalar, no tensor allocation)
1160            let i_val = 1.0 / (1.0 + (-gi).exp()); // sigmoid
1161            let f_val = 1.0 / (1.0 + (-gf).exp()); // sigmoid
1162            let g_val = gg.tanh();                   // tanh
1163            let o_val = 1.0 / (1.0 + (-go).exp()); // sigmoid
1164
1165            // Cell and hidden state update
1166            let c_idx = b_idx * hidden_size + h;
1167            let c_val = f_val * cprev[c_idx] + i_val * g_val;
1168            c_new_data[c_idx] = c_val;
1169            h_new_data[c_idx] = o_val * c_val.tanh();
1170        }
1171    }
1172
1173    let h_new = Tensor::from_vec(h_new_data, &[batch, hidden_size]).map_err(map_err)?;
1174    let c_new = Tensor::from_vec(c_new_data, &[batch, hidden_size]).map_err(map_err)?;
1175
1176    Ok((h_new, c_new))
1177}
1178
1179// ---------------------------------------------------------------------------
1180// Fused GRU cell forward pass (allocation-optimized)
1181// ---------------------------------------------------------------------------
1182
1183/// Fused GRU cell: minimizes intermediate tensor allocations.
1184///
1185/// Produces bit-identical results to [`gru_cell`] but reduces tensor
1186/// allocations from ~12 to 3 (2 matmuls via `linear`, 1 output `from_vec`).
1187/// After the two `linear` calls the gate combination, activations, and
1188/// hidden-state update are computed element-wise in a single scalar loop.
1189///
1190/// Shape requirements are identical to [`gru_cell`].
1191pub fn gru_cell_fused(
1192    x: &Tensor,
1193    h_prev: &Tensor,
1194    w_ih: &Tensor,
1195    w_hh: &Tensor,
1196    b_ih: &Tensor,
1197    b_hh: &Tensor,
1198) -> Result<Tensor, String> {
1199    let map_err = |e: crate::error::RuntimeError| format!("{e}");
1200
1201    // --- Validate shapes (same checks as gru_cell) ---------------------------
1202    if x.ndim() != 2 {
1203        return Err("gru_cell_fused: x must be 2-D [batch, input_size]".into());
1204    }
1205    if h_prev.ndim() != 2 {
1206        return Err("gru_cell_fused: h_prev must be 2-D [batch, hidden_size]".into());
1207    }
1208    let batch = x.shape()[0];
1209    let hidden_size = h_prev.shape()[1];
1210    if w_ih.ndim() != 2 || w_ih.shape()[0] != 3 * hidden_size {
1211        return Err(format!(
1212            "gru_cell_fused: w_ih must be [3*hidden_size, input_size], got {:?}",
1213            w_ih.shape()
1214        ));
1215    }
1216    if w_hh.ndim() != 2 || w_hh.shape()[0] != 3 * hidden_size {
1217        return Err(format!(
1218            "gru_cell_fused: w_hh must be [3*hidden_size, hidden_size], got {:?}",
1219            w_hh.shape()
1220        ));
1221    }
1222    if b_ih.len() != 3 * hidden_size {
1223        return Err(format!(
1224            "gru_cell_fused: b_ih must have length 3*hidden_size={}, got {}",
1225            3 * hidden_size,
1226            b_ih.len()
1227        ));
1228    }
1229    if b_hh.len() != 3 * hidden_size {
1230        return Err(format!(
1231            "gru_cell_fused: b_hh must have length 3*hidden_size={}, got {}",
1232            3 * hidden_size,
1233            b_hh.len()
1234        ));
1235    }
1236
1237    // --- Step 1: two matmuls (reuse existing Kahan-summation linear) ---------
1238    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
1239    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
1240
1241    // --- Step 2: fused gate combination + activations + state update ---------
1242    let gih = gates_ih.to_vec();
1243    let ghh = gates_hh.to_vec();
1244    let hp = h_prev.to_vec();
1245
1246    let mut h_new_data = vec![0.0f64; batch * hidden_size];
1247
1248    for b_idx in 0..batch {
1249        let base = b_idx * 3 * hidden_size;
1250        for h in 0..hidden_size {
1251            // r = sigmoid(ih_r + hh_r)
1252            let r_val =
1253                1.0 / (1.0 + (-(gih[base + h] + ghh[base + h])).exp());
1254            // z = sigmoid(ih_z + hh_z)
1255            let z_val = 1.0
1256                / (1.0
1257                    + (-(gih[base + hidden_size + h]
1258                        + ghh[base + hidden_size + h]))
1259                        .exp());
1260            // n = tanh(ih_n + r * hh_n)
1261            let n_val = (gih[base + 2 * hidden_size + h]
1262                + r_val * ghh[base + 2 * hidden_size + h])
1263                .tanh();
1264
1265            // h_new = (1 - z) * n + z * h_prev
1266            let h_idx = b_idx * hidden_size + h;
1267            h_new_data[h_idx] = (1.0 - z_val) * n_val + z_val * hp[h_idx];
1268        }
1269    }
1270
1271    Tensor::from_vec(h_new_data, &[batch, hidden_size])
1272        .map_err(|e| format!("{e}"))
1273}
1274
1275// ---------------------------------------------------------------------------
1276// Multi-Head Attention
1277// ---------------------------------------------------------------------------
1278
1279/// Multi-head attention: Q, K, V projections + scaled dot-product attention + output projection.
1280///
1281/// * `q`, `k`, `v`: `[batch, seq, model_dim]`
1282/// * `w_q`, `w_k`, `w_v`, `w_o`: `[model_dim, model_dim]`
1283/// * `b_q`, `b_k`, `b_v`, `b_o`: `[model_dim]`
1284/// * `num_heads`: number of attention heads (`model_dim` must be divisible by `num_heads`)
1285///
1286/// Returns `[batch, seq, model_dim]`.
1287///
1288/// All reductions use Kahan summation via the existing `Tensor::linear` and
1289/// `Tensor::scaled_dot_product_attention` methods.
1290pub fn multi_head_attention(
1291    q: &Tensor,
1292    k: &Tensor,
1293    v: &Tensor,
1294    w_q: &Tensor,
1295    w_k: &Tensor,
1296    w_v: &Tensor,
1297    w_o: &Tensor,
1298    b_q: &Tensor,
1299    b_k: &Tensor,
1300    b_v: &Tensor,
1301    b_o: &Tensor,
1302    num_heads: usize,
1303) -> Result<Tensor, String> {
1304    let map_err = |e: crate::error::RuntimeError| format!("{e}");
1305
1306    if q.ndim() != 3 {
1307        return Err("multi_head_attention: q must be 3-D [batch, seq, model_dim]".into());
1308    }
1309
1310    // Linear projections: Q = q.linear(w_q, b_q), etc.
1311    let q_proj = q.linear(w_q, b_q).map_err(map_err)?;
1312    let k_proj = k.linear(w_k, b_k).map_err(map_err)?;
1313    let v_proj = v.linear(w_v, b_v).map_err(map_err)?;
1314
1315    // Split heads: [batch, seq, model_dim] -> [batch, num_heads, seq, head_dim]
1316    let q_heads = q_proj.split_heads(num_heads).map_err(map_err)?;
1317    let k_heads = k_proj.split_heads(num_heads).map_err(map_err)?;
1318    let v_heads = v_proj.split_heads(num_heads).map_err(map_err)?;
1319
1320    // Scaled dot-product attention: works on [..., seq, head_dim]
1321    let attn = Tensor::scaled_dot_product_attention(&q_heads, &k_heads, &v_heads)
1322        .map_err(map_err)?;
1323
1324    // Merge heads back: [batch, num_heads, seq, head_dim] -> [batch, seq, model_dim]
1325    let merged = attn.merge_heads().map_err(map_err)?;
1326
1327    // Output projection
1328    let output = merged.linear(w_o, b_o).map_err(map_err)?;
1329
1330    Ok(output)
1331}
1332
1333// ---------------------------------------------------------------------------
1334// Tests
1335// ---------------------------------------------------------------------------
1336
#[cfg(test)]
mod tests {
    use super::*;

    // ---- Loss function tests ------------------------------------------

    #[test]
    fn test_mse_zero() {
        // Identical prediction and target must give exactly zero loss.
        let pred = [1.0, 2.0, 3.0];
        let target = [1.0, 2.0, 3.0];
        assert_eq!(mse_loss(&pred, &target).unwrap(), 0.0);
    }

    #[test]
    fn test_mse_basic() {
        // Every error is exactly 1.0, so MSE = mean(1, 1, 1) = 1.0.
        let pred = [1.0, 2.0, 3.0];
        let target = [2.0, 3.0, 4.0];
        assert_eq!(mse_loss(&pred, &target).unwrap(), 1.0);
    }

    #[test]
    fn test_huber_loss_quadratic() {
        let pred = [1.0];
        let target = [1.5];
        let h = huber_loss(&pred, &target, 1.0).unwrap();
        // |0.5| < 1.0, so quadratic: 0.5 * 0.25 = 0.125
        assert!((h - 0.125).abs() < 1e-12);
    }

    // ---- Optimizer tests ------------------------------------------------

    #[test]
    fn test_sgd_step() {
        // Plain SGD with lr=0.1, no momentum: p' = p - 0.1 * g.
        let mut params = [1.0, 2.0];
        let grads = [0.1, 0.2];
        let mut state = SgdState::new(2, 0.1, 0.0);
        sgd_step(&mut params, &grads, &mut state);
        assert!((params[0] - 0.99).abs() < 1e-12);
        assert!((params[1] - 1.98).abs() < 1e-12);
    }

    #[test]
    fn test_adam_step() {
        let mut params = [1.0, 2.0];
        let grads = [0.1, 0.2];
        let mut state = AdamState::new(2, 0.001);
        adam_step(&mut params, &grads, &mut state);
        // After one step, params should be slightly different
        // (positive gradients must push both parameters downward).
        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
    }

    // ---- Classification metric tests ------------------------------------

    #[test]
    fn test_confusion_matrix() {
        // Hand-counted: TP at idx 0 and 4, FP at idx 1, FN at idx 2, TN at idx 3.
        let pred = [true, true, false, false, true];
        let actual = [true, false, true, false, true];
        let cm = confusion_matrix(&pred, &actual);
        assert_eq!(cm.tp, 2);
        assert_eq!(cm.fp, 1);
        assert_eq!(cm.fn_count, 1);
        assert_eq!(cm.tn, 1);
    }

    #[test]
    fn test_precision_recall_f1() {
        // precision = tp/(tp+fp) = 5/7, recall = tp/(tp+fn) = 5/6.
        // NOTE(review): despite the name, f1 itself is not asserted here.
        let cm = ConfusionMatrix { tp: 5, fp: 2, tn: 8, fn_count: 1 };
        assert!((precision(&cm) - 5.0 / 7.0).abs() < 1e-12);
        assert!((recall(&cm) - 5.0 / 6.0).abs() < 1e-12);
    }

    #[test]
    fn test_auc_perfect() {
        // All positive scores strictly above all negative scores → AUC = 1.
        let scores = [0.9, 0.8, 0.2, 0.1];
        let labels = [true, true, false, false];
        let auc = auc_roc(&scores, &labels).unwrap();
        assert!((auc - 1.0).abs() < 1e-12);
    }

    // ---- Data-splitting tests --------------------------------------------

    #[test]
    fn test_kfold_deterministic() {
        // Same (n, k, seed) must reproduce identical fold index sets.
        let f1 = kfold_indices(100, 5, 42);
        let f2 = kfold_indices(100, 5, 42);
        for i in 0..5 {
            assert_eq!(f1[i].0, f2[i].0);
            assert_eq!(f1[i].1, f2[i].1);
        }
    }

    #[test]
    fn test_train_test_split_coverage() {
        // Split partitions all 100 indices; test fraction 0.2 → 20 samples.
        let (train, test) = train_test_split(100, 0.2, 42);
        assert_eq!(train.len() + test.len(), 100);
        assert_eq!(test.len(), 20);
    }

    // --- B4: ML Training Extensions tests ---

    #[test]
    fn test_batch_norm_identity() {
        // mean=0, var=1, gamma=1, beta=0, eps=0 → input unchanged
        let x = vec![1.0, 2.0, 3.0];
        let mean = vec![0.0, 0.0, 0.0];
        let var = vec![1.0, 1.0, 1.0];
        let gamma = vec![1.0, 1.0, 1.0];
        let beta = vec![0.0, 0.0, 0.0];
        let result = batch_norm(&x, &mean, &var, &gamma, &beta, 0.0).unwrap();
        assert!((result[0] - 1.0).abs() < 1e-12);
        assert!((result[1] - 2.0).abs() < 1e-12);
        assert!((result[2] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_batch_norm_shift_scale() {
        // Traces one element through normalize → scale → shift.
        let x = vec![0.0];
        let mean = vec![1.0]; // shift: x - 1 = -1
        let var = vec![4.0];  // scale: -1/sqrt(4) = -0.5
        let gamma = vec![2.0]; // multiply: 2 * -0.5 = -1
        let beta = vec![3.0]; // add: -1 + 3 = 2
        let result = batch_norm(&x, &mean, &var, &gamma, &beta, 0.0).unwrap();
        assert!((result[0] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_dropout_mask_seed_determinism() {
        // Same seed → identical mask (determinism contract).
        let m1 = dropout_mask(100, 0.5, 42);
        let m2 = dropout_mask(100, 0.5, 42);
        assert_eq!(m1, m2);
    }

    #[test]
    fn test_dropout_mask_different_seeds() {
        // Different seeds should (with overwhelming probability) differ.
        let m1 = dropout_mask(100, 0.5, 42);
        let m2 = dropout_mask(100, 0.5, 99);
        assert_ne!(m1, m2);
    }

    // ---- Learning-rate schedule tests -------------------------------------

    #[test]
    fn test_lr_step_decay_schedule() {
        // lr halves every 10 steps: 0.1 → 0.05 → 0.025.
        let lr0 = lr_step_decay(0.1, 0.5, 0, 10);
        assert!((lr0 - 0.1).abs() < 1e-12);
        let lr10 = lr_step_decay(0.1, 0.5, 10, 10);
        assert!((lr10 - 0.05).abs() < 1e-12);
        let lr20 = lr_step_decay(0.1, 0.5, 20, 10);
        assert!((lr20 - 0.025).abs() < 1e-12);
    }

    #[test]
    fn test_lr_cosine_endpoints() {
        // Cosine schedule starts at lr_max and ends at lr_min.
        let lr0 = lr_cosine(0.1, 0.001, 0, 100);
        assert!((lr0 - 0.1).abs() < 1e-10);
        let lr_end = lr_cosine(0.1, 0.001, 100, 100);
        assert!((lr_end - 0.001).abs() < 1e-10);
    }

    #[test]
    fn test_lr_linear_warmup() {
        // Ramps 0 → target over 10 steps, then stays at target.
        let lr0 = lr_linear_warmup(0.1, 0, 10);
        assert!((lr0).abs() < 1e-12);
        let lr5 = lr_linear_warmup(0.1, 5, 10);
        assert!((lr5 - 0.05).abs() < 1e-12);
        let lr15 = lr_linear_warmup(0.1, 15, 10);
        assert!((lr15 - 0.1).abs() < 1e-12);
    }

    // ---- Regularization tests ---------------------------------------------

    #[test]
    fn test_l1_penalty_known() {
        // L1: 0.1 * (|1| + |-2| + |3|) = 0.6.
        let params = [1.0, -2.0, 3.0];
        let p = l1_penalty(&params, 0.1);
        assert!((p - 0.6).abs() < 1e-12);
    }

    #[test]
    fn test_l2_penalty_known() {
        let params = [1.0, -2.0, 3.0];
        let p = l2_penalty(&params, 0.1);
        // 0.5 * 0.1 * (1 + 4 + 9) = 0.5 * 0.1 * 14 = 0.7
        assert!((p - 0.7).abs() < 1e-12);
    }

    #[test]
    fn test_early_stopping_triggers() {
        // Stops once `wait` reaches patience with no loss improvement.
        let mut es = EarlyStoppingState::new(3, 0.01);
        assert!(!es.check(1.0)); // best_loss=1.0, wait=0
        assert!(!es.check(1.0)); // no improvement, wait=1
        assert!(!es.check(1.0)); // no improvement, wait=2
        assert!(es.check(1.0));  // no improvement, wait=3 >= patience
    }

    #[test]
    fn test_early_stopping_resets() {
        // A genuine improvement resets the patience counter.
        let mut es = EarlyStoppingState::new(3, 0.01);
        es.check(1.0);
        es.check(1.0); // wait=1
        assert!(!es.check(0.5)); // improvement, wait=0
        assert!(!es.check(0.5)); // wait=1
    }

    // --- Phase 3C: PCA tests ---

    #[test]
    fn test_pca_basic_2d() {
        // 4 samples, 2 features — data lies mostly along first axis
        let data = Tensor::from_vec(
            vec![
                1.0, 0.1,
                2.0, 0.2,
                3.0, 0.3,
                4.0, 0.4,
            ],
            &[4, 2],
        )
        .unwrap();
        let (transformed, components, evr) = pca(&data, 2).unwrap();
        assert_eq!(transformed.shape(), &[4, 2]);
        assert_eq!(components.shape(), &[2, 2]);
        assert_eq!(evr.len(), 2);
        // Explained variance ratios should sum to ~1.0
        let total: f64 = evr.iter().sum();
        assert!(
            (total - 1.0).abs() < 1e-8,
            "explained variance ratios sum to {} (expected ~1.0)",
            total
        );
        // First component should explain most variance
        assert!(evr[0] > 0.9, "first component explains {} of variance", evr[0]);
    }

    #[test]
    fn test_pca_single_component() {
        // Requesting fewer components than features shrinks the outputs.
        let data = Tensor::from_vec(
            vec![
                1.0, 2.0, 3.0,
                4.0, 5.0, 6.0,
                7.0, 8.0, 9.0,
            ],
            &[3, 3],
        )
        .unwrap();
        let (transformed, components, evr) = pca(&data, 1).unwrap();
        assert_eq!(transformed.shape(), &[3, 1]);
        assert_eq!(components.shape(), &[1, 3]);
        assert_eq!(evr.len(), 1);
        assert!(evr[0] > 0.0 && evr[0] <= 1.0);
    }

    #[test]
    fn test_pca_explained_variance_ratio_bounded() {
        // Each ratio must be in [0, 1] and the total must not exceed 1
        // (small tolerances absorb floating-point round-off).
        let data = Tensor::from_vec(
            vec![
                1.0, 0.0, 0.5,
                0.0, 1.0, 0.5,
                1.0, 1.0, 1.0,
                2.0, 0.0, 1.0,
                0.0, 2.0, 1.0,
            ],
            &[5, 3],
        )
        .unwrap();
        let (_, _, evr) = pca(&data, 3).unwrap();
        let total: f64 = evr.iter().sum();
        assert!(
            total <= 1.0 + 1e-10,
            "explained variance ratios sum to {} (should be <= 1.0)",
            total
        );
        for &r in &evr {
            assert!(r >= -1e-10, "negative explained variance ratio: {}", r);
        }
    }

    #[test]
    fn test_pca_deterministic() {
        // Bit-identical outputs on repeated calls — the module's
        // determinism contract applies to PCA as well.
        let data = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[2, 3],
        )
        .unwrap();
        let (t1, c1, e1) = pca(&data, 2).unwrap();
        let (t2, c2, e2) = pca(&data, 2).unwrap();
        assert_eq!(t1.to_vec(), t2.to_vec(), "PCA transformed not deterministic");
        assert_eq!(c1.to_vec(), c2.to_vec(), "PCA components not deterministic");
        assert_eq!(e1, e2, "PCA explained variance not deterministic");
    }

    #[test]
    fn test_pca_invalid_n_components() {
        // Component counts outside 1..=min(n_samples, n_features) are rejected.
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert!(pca(&data, 0).is_err(), "n_components=0 should fail");
        assert!(pca(&data, 3).is_err(), "n_components > min(n,p) should fail");
    }

    // --- Sprint 2: L-BFGS tests ---

    /// Rosenbrock function: f(x,y) = (1-x)^2 + 100*(y-x^2)^2
    /// Gradient: df/dx = -2*(1-x) - 400*x*(y-x^2)
    ///           df/dy = 200*(y-x^2)
    /// Global minimum at (1, 1) with f=0.
    fn rosenbrock(p: &[f64]) -> (f64, Vec<f64>) {
        let x = p[0];
        let y = p[1];
        let a = 1.0 - x;
        let b = y - x * x;
        let val = a * a + 100.0 * b * b;
        let gx = -2.0 * a - 400.0 * x * b;
        let gy = 200.0 * b;
        (val, vec![gx, gy])
    }

    #[test]
    fn test_lbfgs_rosenbrock_converges() {
        // Start far from optimum — L-BFGS should converge near (1, 1)
        let mut params = vec![-1.0_f64, 2.0_f64];
        let mut state = LbfgsState::new(0.5, 10);

        let mut converged = false;
        for _iter in 0..200 {
            let (_, grads) = rosenbrock(&params);
            // Converged when the gradient 2-norm drops below tolerance.
            let grad_norm: f64 = kahan_dot(&grads, &grads).sqrt();
            if grad_norm < 1e-5 {
                converged = true;
                break;
            }
            let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, rosenbrock);
            params = new_p;
        }
        assert!(converged, "L-BFGS did not converge on Rosenbrock; params = {:?}", params);
        assert!(
            (params[0] - 1.0).abs() < 1e-3,
            "x should converge near 1.0, got {}",
            params[0]
        );
        assert!(
            (params[1] - 1.0).abs() < 1e-3,
            "y should converge near 1.0, got {}",
            params[1]
        );
    }

    #[test]
    fn test_lbfgs_determinism() {
        // Same initial params + same seed → identical results
        let init = vec![-1.0_f64, 2.0_f64];

        let run = |init: &[f64]| -> Vec<f64> {
            let mut params = init.to_vec();
            let mut state = LbfgsState::new(0.5, 10);
            for _ in 0..20 {
                let (_, grads) = rosenbrock(&params);
                let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, rosenbrock);
                params = new_p;
            }
            params
        };

        let r1 = run(&init);
        let r2 = run(&init);
        assert_eq!(r1, r2, "L-BFGS must be bit-identical across runs");
    }

    #[test]
    fn test_lbfgs_simple_quadratic() {
        // f(x) = x^2, minimum at x=0
        let mut params = vec![3.0_f64];
        let mut state = LbfgsState::new(1.0, 5);
        let quadratic = |p: &[f64]| -> (f64, Vec<f64>) {
            (p[0] * p[0], vec![2.0 * p[0]])
        };

        for _ in 0..30 {
            let (_, grads) = quadratic(&params);
            let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, quadratic);
            params = new_p;
        }
        assert!(
            params[0].abs() < 1e-6,
            "L-BFGS should minimize x^2 to ~0, got {}",
            params[0]
        );
    }

    #[test]
    fn test_wolfe_line_search_armijo() {
        // For f(x) = x^2, starting at x=3, direction d=-1 (descent)
        // Wolfe search should find a valid step
        let params = vec![3.0_f64];
        let direction = vec![-1.0_f64];
        let grads = vec![6.0_f64]; // gradient at x=3
        let f0 = 9.0;

        let mut eval_count = 0;
        let mut obj = |p: &[f64]| -> (f64, Vec<f64>) {
            eval_count += 1;
            (p[0] * p[0], vec![2.0 * p[0]])
        };

        let (alpha, new_params, new_val, _) =
            wolfe_line_search(&params, &direction, &mut obj, f0, &grads, 1.0);

        // Armijo: f(x + alpha*d) <= f(x) + c1*alpha*g^T*d
        let c1 = 1e-4;
        let derphi0 = kahan_dot(&grads, &direction); // = -6
        assert!(
            new_val <= f0 + c1 * alpha * derphi0,
            "Armijo condition violated: {} > {} + {} * {} * {}",
            new_val, f0, c1, alpha, derphi0
        );
        assert!(new_params[0] < 3.0, "Step should move toward minimum");
    }
}