use cjc_repro::KahanAccumulatorF64;

use crate::accumulator::BinnedAccumulatorF64;
use crate::error::RuntimeError;
use crate::tensor::Tensor;

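/// Mean squared error: `(1/n) * Σ (pred[i] - target[i])²`, accumulated with
/// Kahan summation for bit-reproducible results.
///
/// A minimal usage sketch (illustrative, not run as a doctest):
///
/// ```ignore
/// // (0.5² + 0.5²) / 2 = 0.25
/// let mse = mse_loss(&[1.0, 2.0], &[1.5, 2.5]).unwrap();
/// assert!((mse - 0.25).abs() < 1e-12);
/// ```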
pub fn mse_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
    if pred.len() != target.len() {
        return Err("mse_loss: arrays must have same length".into());
    }
    if pred.is_empty() {
        return Err("mse_loss: empty data".into());
    }
    let mut acc = KahanAccumulatorF64::new();
    for i in 0..pred.len() {
        let d = pred[i] - target[i];
        acc.add(d * d);
    }
    Ok(acc.finalize() / pred.len() as f64)
}

pub fn cross_entropy_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
    if pred.len() != target.len() {
        return Err("cross_entropy_loss: arrays must have same length".into());
    }
    if pred.is_empty() {
        return Err("cross_entropy_loss: empty data".into());
    }
    let eps = 1e-12;
    let mut acc = KahanAccumulatorF64::new();
    for i in 0..pred.len() {
        acc.add(-target[i] * (pred[i] + eps).ln());
    }
    Ok(acc.finalize() / pred.len() as f64)
}

pub fn binary_cross_entropy(pred: &[f64], target: &[f64]) -> Result<f64, String> {
    if pred.len() != target.len() {
        return Err("binary_cross_entropy: arrays must have same length".into());
    }
    if pred.is_empty() {
        return Err("binary_cross_entropy: empty data".into());
    }
    let eps = 1e-12;
    let mut acc = KahanAccumulatorF64::new();
    for i in 0..pred.len() {
        // Clamp predictions away from 0 and 1 so both log terms stay finite.
        let p = pred[i].max(eps).min(1.0 - eps);
        acc.add(-(target[i] * p.ln() + (1.0 - target[i]) * (1.0 - p).ln()));
    }
    Ok(acc.finalize() / pred.len() as f64)
}

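/// Huber loss with threshold `delta`: quadratic (`0.5·d²`) for residuals with
/// `|d| <= delta`, linear (`delta·(|d| - 0.5·delta)`) beyond it.
///
/// Illustrative sketch: with `delta = 1.0`, a residual of 0.5 stays in the
/// quadratic regime and contributes `0.5 · 0.25 = 0.125`:
///
/// ```ignore
/// let h = huber_loss(&[1.0], &[1.5], 1.0).unwrap();
/// assert!((h - 0.125).abs() < 1e-12);
/// ```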
pub fn huber_loss(pred: &[f64], target: &[f64], delta: f64) -> Result<f64, String> {
    if pred.len() != target.len() {
        return Err("huber_loss: arrays must have same length".into());
    }
    if pred.is_empty() {
        return Err("huber_loss: empty data".into());
    }
    let mut acc = KahanAccumulatorF64::new();
    for i in 0..pred.len() {
        let d = (pred[i] - target[i]).abs();
        if d <= delta {
            acc.add(0.5 * d * d);
        } else {
            acc.add(delta * (d - 0.5 * delta));
        }
    }
    Ok(acc.finalize() / pred.len() as f64)
}

pub fn hinge_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
    if pred.len() != target.len() {
        return Err("hinge_loss: arrays must have same length".into());
    }
    if pred.is_empty() {
        return Err("hinge_loss: empty data".into());
    }
    let mut acc = KahanAccumulatorF64::new();
    for i in 0..pred.len() {
        acc.add((1.0 - target[i] * pred[i]).max(0.0));
    }
    Ok(acc.finalize() / pred.len() as f64)
}

pub struct SgdState {
    pub lr: f64,
    pub momentum: f64,
    pub velocity: Vec<f64>,
}

impl SgdState {
    pub fn new(n_params: usize, lr: f64, momentum: f64) -> Self {
        Self { lr, momentum, velocity: vec![0.0; n_params] }
    }
}

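/// SGD with classical momentum: `v ← μ·v + g`, then `θ ← θ - lr·v`.
///
/// Illustrative sketch (zero momentum reduces to plain gradient descent):
///
/// ```ignore
/// let mut params = [1.0, 2.0];
/// let mut state = SgdState::new(2, 0.1, 0.0);
/// sgd_step(&mut params, &[0.1, 0.2], &mut state);
/// // params[0] = 1.0 - 0.1 * 0.1 = 0.99
/// ```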
pub fn sgd_step(params: &mut [f64], grads: &[f64], state: &mut SgdState) {
    debug_assert_eq!(params.len(), grads.len());
    debug_assert_eq!(params.len(), state.velocity.len());
    for i in 0..params.len() {
        state.velocity[i] = state.momentum * state.velocity[i] + grads[i];
        params[i] -= state.lr * state.velocity[i];
    }
}

pub struct AdamState {
    pub lr: f64,
    pub beta1: f64,
    pub beta2: f64,
    pub eps: f64,
    pub t: u64,
    pub m: Vec<f64>,
    pub v: Vec<f64>,
}

impl AdamState {
    pub fn new(n_params: usize, lr: f64) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            t: 0,
            m: vec![0.0; n_params],
            v: vec![0.0; n_params],
        }
    }
}

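/// One Adam step (Kingma & Ba, 2015) with bias-corrected moment estimates:
///
/// ```text
/// m ← β₁·m + (1-β₁)·g          m̂ = m / (1 - β₁ᵗ)
/// v ← β₂·v + (1-β₂)·g²         v̂ = v / (1 - β₂ᵗ)
/// θ ← θ - lr · m̂ / (√v̂ + ε)
/// ```
///
/// Illustrative sketch:
///
/// ```ignore
/// let mut params = [1.0];
/// let mut state = AdamState::new(1, 0.001);
/// adam_step(&mut params, &[0.1], &mut state);
/// // On the first step m̂/(√v̂ + ε) ≈ sign(g), so params[0] ≈ 1.0 - 0.001.
/// ```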
pub fn adam_step(params: &mut [f64], grads: &[f64], state: &mut AdamState) {
    debug_assert_eq!(params.len(), grads.len());
    state.t += 1;
    let t = state.t as f64;
    for i in 0..params.len() {
        state.m[i] = state.beta1 * state.m[i] + (1.0 - state.beta1) * grads[i];
        state.v[i] = state.beta2 * state.v[i] + (1.0 - state.beta2) * grads[i] * grads[i];
        // Bias-corrected moment estimates.
        let m_hat = state.m[i] / (1.0 - state.beta1.powf(t));
        let v_hat = state.v[i] / (1.0 - state.beta2.powf(t));
        params[i] -= state.lr * m_hat / (v_hat.sqrt() + state.eps);
    }
}

#[derive(Debug, Clone)]
pub struct ConfusionMatrix {
    pub tp: usize,
    pub fp: usize,
    pub tn: usize,
    pub fn_count: usize,
}

pub fn confusion_matrix(predicted: &[bool], actual: &[bool]) -> ConfusionMatrix {
    let mut tp = 0;
    let mut fp = 0;
    let mut tn = 0;
    let mut fn_count = 0;
    // Mismatched lengths are tolerated: only the overlapping prefix is counted.
    for i in 0..predicted.len().min(actual.len()) {
        match (predicted[i], actual[i]) {
            (true, true) => tp += 1,
            (true, false) => fp += 1,
            (false, true) => fn_count += 1,
            (false, false) => tn += 1,
        }
    }
    ConfusionMatrix { tp, fp, tn, fn_count }
}

pub fn precision(cm: &ConfusionMatrix) -> f64 {
    let denom = cm.tp + cm.fp;
    if denom == 0 { 0.0 } else { cm.tp as f64 / denom as f64 }
}

pub fn recall(cm: &ConfusionMatrix) -> f64 {
    let denom = cm.tp + cm.fn_count;
    if denom == 0 { 0.0 } else { cm.tp as f64 / denom as f64 }
}

pub fn f1_score(cm: &ConfusionMatrix) -> f64 {
    let p = precision(cm);
    let r = recall(cm);
    if p + r == 0.0 { 0.0 } else { 2.0 * p * r / (p + r) }
}

pub fn accuracy(cm: &ConfusionMatrix) -> f64 {
    let total = cm.tp + cm.fp + cm.tn + cm.fn_count;
    if total == 0 { 0.0 } else { (cm.tp + cm.tn) as f64 / total as f64 }
}

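/// AUC-ROC via the trapezoidal rule: samples are sorted by score (descending,
/// original index as a deterministic tie-break), the ROC curve is swept one
/// sample at a time, and the area increments by `(Δfpr)·(tpr + prev_tpr)/2`.
///
/// Illustrative sketch (a perfect ranking gives AUC = 1):
///
/// ```ignore
/// let auc = auc_roc(&[0.9, 0.8, 0.2, 0.1], &[true, true, false, false]).unwrap();
/// assert!((auc - 1.0).abs() < 1e-12);
/// ```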
pub fn auc_roc(scores: &[f64], labels: &[bool]) -> Result<f64, String> {
    if scores.len() != labels.len() {
        return Err("auc_roc: scores and labels must have same length".into());
    }
    let n = scores.len();
    if n == 0 {
        return Err("auc_roc: empty data".into());
    }
    let mut indexed: Vec<(usize, f64, bool)> = scores
        .iter()
        .zip(labels.iter())
        .enumerate()
        .map(|(i, (&s, &l))| (i, s, l))
        .collect();
    // Sort by score descending; break ties by original index for determinism.
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));

    let pos_count = labels.iter().filter(|&&l| l).count();
    let neg_count = n - pos_count;
    if pos_count == 0 || neg_count == 0 {
        return Err("auc_roc: need both positive and negative labels".into());
    }

    let mut auc = 0.0;
    let mut tp = 0.0;
    let mut fp = 0.0;
    let mut prev_fpr = 0.0;
    let mut prev_tpr = 0.0;

    for &(_, _, label) in &indexed {
        if label { tp += 1.0; } else { fp += 1.0; }
        let tpr = tp / pos_count as f64;
        let fpr = fp / neg_count as f64;
        // Trapezoidal area increment between consecutive ROC points.
        auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0;
        prev_fpr = fpr;
        prev_tpr = tpr;
    }
    Ok(auc)
}

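/// K-fold cross-validation indices: a seeded Fisher-Yates shuffle of `0..n`,
/// then `k` contiguous folds (the last fold absorbs the remainder). Returns
/// `(train, test)` index pairs, deterministic for a given seed.
///
/// Illustrative sketch:
///
/// ```ignore
/// let folds = kfold_indices(10, 5, 42);
/// assert_eq!(folds.len(), 5);
/// for (train, test) in &folds {
///     assert_eq!(train.len() + test.len(), 10);
/// }
/// ```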
pub fn kfold_indices(n: usize, k: usize, seed: u64) -> Vec<(Vec<usize>, Vec<usize>)> {
    let mut rng = cjc_repro::Rng::seeded(seed);
    let mut indices: Vec<usize> = (0..n).collect();
    // Fisher-Yates shuffle driven by the seeded RNG.
    for i in (1..n).rev() {
        let j = (rng.next_u64() as usize) % (i + 1);
        indices.swap(i, j);
    }
    let fold_size = n / k;
    let mut folds = Vec::with_capacity(k);
    for fold in 0..k {
        let start = fold * fold_size;
        // The last fold absorbs any remainder when k does not divide n.
        let end = if fold == k - 1 { n } else { start + fold_size };
        let test: Vec<usize> = indices[start..end].to_vec();
        let train: Vec<usize> = indices[..start]
            .iter()
            .chain(indices[end..].iter())
            .copied()
            .collect();
        folds.push((train, test));
    }
    folds
}

pub fn train_test_split(n: usize, test_fraction: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
    let mut rng = cjc_repro::Rng::seeded(seed);
    let mut indices: Vec<usize> = (0..n).collect();
    for i in (1..n).rev() {
        let j = (rng.next_u64() as usize) % (i + 1);
        indices.swap(i, j);
    }
    let test_size = ((n as f64) * test_fraction).round() as usize;
    let test = indices[..test_size].to_vec();
    let train = indices[test_size..].to_vec();
    (train, test)
}

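/// Inference-mode batch normalization using running statistics:
/// `y[i] = gamma[i] · (x[i] - mean[i]) / sqrt(var[i] + eps) + beta[i]`.
///
/// Illustrative sketch (identity statistics leave the input unchanged):
///
/// ```ignore
/// let y = batch_norm(&[1.0, 2.0], &[0.0, 0.0], &[1.0, 1.0],
///                    &[1.0, 1.0], &[0.0, 0.0], 0.0).unwrap();
/// assert!((y[0] - 1.0).abs() < 1e-12 && (y[1] - 2.0).abs() < 1e-12);
/// ```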
pub fn batch_norm(
    x: &[f64],
    running_mean: &[f64],
    running_var: &[f64],
    gamma: &[f64],
    beta: &[f64],
    eps: f64,
) -> Result<Vec<f64>, String> {
    let n = x.len();
    if running_mean.len() != n || running_var.len() != n || gamma.len() != n || beta.len() != n {
        return Err("batch_norm: all arrays must have same length".into());
    }
    let mut result = Vec::with_capacity(n);
    for i in 0..n {
        let normed = (x[i] - running_mean[i]) / (running_var[i] + eps).sqrt();
        result.push(gamma[i] * normed + beta[i]);
    }
    Ok(result)
}

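/// Inverted-dropout mask: each element is 0.0 with probability `drop_prob`,
/// otherwise `1/(1 - drop_prob)`, so the expected value of `x · mask` equals
/// `x` and no rescaling is needed at inference time.
///
/// Illustrative sketch:
///
/// ```ignore
/// let mask = dropout_mask(4, 0.5, 42); // deterministic for a given seed
/// let out = apply_dropout(&[1.0, 2.0, 3.0, 4.0], &mask).unwrap();
/// // Surviving entries are scaled by 1/(1 - 0.5) = 2.0.
/// ```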
pub fn dropout_mask(n: usize, drop_prob: f64, seed: u64) -> Vec<f64> {
    let mut rng = cjc_repro::Rng::seeded(seed);
    let scale = if drop_prob < 1.0 { 1.0 / (1.0 - drop_prob) } else { 0.0 };
    let mut mask = Vec::with_capacity(n);
    for _ in 0..n {
        let r = (rng.next_u64() as f64) / (u64::MAX as f64);
        if r < drop_prob {
            mask.push(0.0);
        } else {
            mask.push(scale);
        }
    }
    mask
}

pub fn apply_dropout(data: &[f64], mask: &[f64]) -> Result<Vec<f64>, String> {
    if data.len() != mask.len() {
        return Err("apply_dropout: data and mask must have same length".into());
    }
    Ok(data.iter().zip(mask.iter()).map(|(&d, &m)| d * m).collect())
}

pub fn lr_step_decay(initial_lr: f64, decay_rate: f64, epoch: usize, step_size: usize) -> f64 {
    initial_lr * decay_rate.powi((epoch / step_size) as i32)
}

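/// Cosine annealing between `max_lr` and `min_lr`:
/// `lr = min_lr + 0.5·(max_lr - min_lr)·(1 + cos(π · epoch/total_epochs))`,
/// so epoch 0 yields `max_lr` and the final epoch yields `min_lr`.
///
/// ```ignore
/// assert!((lr_cosine(0.1, 0.001, 0, 100) - 0.1).abs() < 1e-10);
/// assert!((lr_cosine(0.1, 0.001, 100, 100) - 0.001).abs() < 1e-10);
/// ```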
pub fn lr_cosine(max_lr: f64, min_lr: f64, epoch: usize, total_epochs: usize) -> f64 {
    let ratio = epoch as f64 / total_epochs as f64;
    min_lr + 0.5 * (max_lr - min_lr) * (1.0 + (std::f64::consts::PI * ratio).cos())
}

pub fn lr_linear_warmup(initial_lr: f64, epoch: usize, warmup_epochs: usize) -> f64 {
    if warmup_epochs == 0 {
        return initial_lr;
    }
    initial_lr * (epoch as f64 / warmup_epochs as f64).min(1.0)
}

pub fn l1_penalty(params: &[f64], lambda: f64) -> f64 {
    let mut acc = KahanAccumulatorF64::new();
    for &p in params {
        acc.add(p.abs());
    }
    lambda * acc.finalize()
}

pub fn l2_penalty(params: &[f64], lambda: f64) -> f64 {
    let mut acc = KahanAccumulatorF64::new();
    for &p in params {
        acc.add(p * p);
    }
    0.5 * lambda * acc.finalize()
}

pub fn l1_grad(params: &[f64], lambda: f64) -> Vec<f64> {
    params.iter().map(|&p| {
        if p > 0.0 { lambda } else if p < 0.0 { -lambda } else { 0.0 }
    }).collect()
}

pub fn l2_grad(params: &[f64], lambda: f64) -> Vec<f64> {
    params.iter().map(|&p| lambda * p).collect()
}

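/// Early stopping on a validation loss: `check` returns `true` once the loss
/// has failed to improve by more than `min_delta` for `patience` consecutive
/// calls. Any sufficient improvement resets the wait counter.
///
/// Illustrative sketch:
///
/// ```ignore
/// let mut es = EarlyStoppingState::new(2, 0.01);
/// assert!(!es.check(1.0)); // first value becomes the best loss
/// assert!(!es.check(1.0)); // no improvement: wait = 1
/// assert!(es.check(1.0));  // wait = 2 >= patience: stop
/// ```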
pub struct EarlyStoppingState {
    pub patience: usize,
    pub min_delta: f64,
    pub best_loss: f64,
    pub wait: usize,
    pub stopped: bool,
}

impl EarlyStoppingState {
    pub fn new(patience: usize, min_delta: f64) -> Self {
        Self {
            patience,
            min_delta,
            best_loss: f64::INFINITY,
            wait: 0,
            stopped: false,
        }
    }

    pub fn check(&mut self, current_loss: f64) -> bool {
        if current_loss < self.best_loss - self.min_delta {
            self.best_loss = current_loss;
            self.wait = 0;
        } else {
            self.wait += 1;
        }
        if self.wait >= self.patience {
            self.stopped = true;
        }
        self.stopped
    }
}

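/// PCA via SVD of the column-centered data matrix `X = U·S·Vᵀ`: the rows of
/// `Vᵀ` are the principal axes, the scores are `U·S`, and the explained
/// variance ratios come from `sᵢ²/(n-1)` normalized by the total variance.
/// Returns `(transformed [n, k], components [k, p], explained_variance_ratio)`.
///
/// Illustrative sketch (collinear rows are exactly one-dimensional):
///
/// ```ignore
/// let data = Tensor::from_vec(vec![1.0, 0.1, 2.0, 0.2, 3.0, 0.3], &[3, 2])?;
/// let (scores, axes, evr) = pca(&data, 1)?;
/// assert_eq!(scores.shape(), &[3, 1]);
/// assert_eq!(axes.shape(), &[1, 2]);
/// assert!(evr[0] > 0.99);
/// ```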
pub fn pca(
    data: &Tensor,
    n_components: usize,
) -> Result<(Tensor, Tensor, Vec<f64>), RuntimeError> {
    if data.ndim() != 2 {
        return Err(RuntimeError::InvalidOperation(
            "PCA requires a 2D data matrix".to_string(),
        ));
    }
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];

    if n_samples == 0 || n_features == 0 {
        return Err(RuntimeError::InvalidOperation(
            "PCA: empty data matrix".to_string(),
        ));
    }
    if n_components == 0 || n_components > n_features.min(n_samples) {
        return Err(RuntimeError::InvalidOperation(format!(
            "PCA: n_components ({}) must be in [1, min(n_samples, n_features) = {}]",
            n_components,
            n_features.min(n_samples)
        )));
    }

    let raw = data.to_vec();

    // Per-feature means, accumulated deterministically.
    let mut means = vec![0.0f64; n_features];
    for j in 0..n_features {
        let mut acc = BinnedAccumulatorF64::new();
        for i in 0..n_samples {
            acc.add(raw[i * n_features + j]);
        }
        means[j] = acc.finalize() / n_samples as f64;
    }

    // Center each column.
    let mut centered = vec![0.0f64; n_samples * n_features];
    for i in 0..n_samples {
        for j in 0..n_features {
            centered[i * n_features + j] = raw[i * n_features + j] - means[j];
        }
    }
    let centered_tensor = Tensor::from_vec(centered, &[n_samples, n_features])?;

    let (u, s, vt) = centered_tensor.svd()?;
    let k = n_components.min(s.len());

    // The first k rows of V^T are the principal axes.
    let vt_data = vt.to_vec();
    let vt_cols = vt.shape()[1];
    let mut components = vec![0.0f64; k * n_features];
    for i in 0..k {
        for j in 0..n_features {
            components[i * n_features + j] = vt_data[i * vt_cols + j];
        }
    }

    let denom = if n_samples > 1 {
        (n_samples - 1) as f64
    } else {
        1.0
    };

    let mut total_var_acc = BinnedAccumulatorF64::new();
    for &si in &s {
        total_var_acc.add(si * si / denom);
    }
    let total_var = total_var_acc.finalize();

    let explained_variance_ratio: Vec<f64> = if total_var > 1e-15 {
        s[..k]
            .iter()
            .map(|&si| (si * si / denom) / total_var)
            .collect()
    } else {
        vec![0.0; k]
    };

    // Scores: the first k columns of U scaled by the singular values.
    let u_data = u.to_vec();
    let u_cols = u.shape()[1];
    let mut transformed = vec![0.0f64; n_samples * k];
    for i in 0..n_samples {
        for j in 0..k {
            transformed[i * k + j] = u_data[i * u_cols + j] * s[j];
        }
    }

    Ok((
        Tensor::from_vec(transformed, &[n_samples, k])?,
        Tensor::from_vec(components, &[k, n_features])?,
        explained_variance_ratio,
    ))
}

pub struct LbfgsState {
    pub lr: f64,
    pub m: usize,
    pub s_history: Vec<Vec<f64>>,
    pub y_history: Vec<Vec<f64>>,
    pub prev_params: Option<Vec<f64>>,
    pub prev_grad: Option<Vec<f64>>,
}

impl LbfgsState {
    pub fn new(lr: f64, m: usize) -> Self {
        Self {
            lr,
            m,
            s_history: Vec::new(),
            y_history: Vec::new(),
            prev_params: None,
            prev_grad: None,
        }
    }
}

fn kahan_dot(a: &[f64], b: &[f64]) -> f64 {
    debug_assert_eq!(a.len(), b.len());
    let mut acc = KahanAccumulatorF64::new();
    for (&ai, &bi) in a.iter().zip(b.iter()) {
        acc.add(ai * bi);
    }
    acc.finalize()
}

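/// Bracketing line search for the strong Wolfe conditions: sufficient decrease
/// `φ(α) <= φ(0) + c₁·α·φ'(0)` (Armijo, c₁ = 1e-4) and curvature
/// `|φ'(α)| <= c₂·|φ'(0)|` (c₂ = 0.9), where `φ(α) = f(x + α·d)`. Brackets
/// are refined by `wolfe_zoom`; if no point satisfies both conditions within
/// the iteration budget, the best point seen is returned, so the search never
/// fails outright.
///
/// Illustrative sketch on `f(x) = x²` from `x = 3` along `d = -1`:
///
/// ```ignore
/// let mut f = |p: &[f64]| (p[0] * p[0], vec![2.0 * p[0]]);
/// let (alpha, x, val, _g) =
///     wolfe_line_search(&[3.0], &[-1.0], &mut f, 9.0, &[6.0], 1.0);
/// assert!(val < 9.0 && x[0] < 3.0);
/// ```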
pub fn wolfe_line_search<F>(
    params: &[f64],
    direction: &[f64],
    f: &mut F,
    f0: f64,
    g0: &[f64],
    alpha_init: f64,
) -> (f64, Vec<f64>, f64, Vec<f64>)
where
    F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
    let c1 = 1e-4_f64;
    let c2 = 0.9_f64;
    let derphi0 = kahan_dot(g0, direction);
    let step = |alpha: f64| -> Vec<f64> {
        params.iter().zip(direction.iter()).map(|(&p, &d)| p + alpha * d).collect()
    };

    let max_iter = 30;
    let mut alpha_lo = 0.0_f64;
    let mut alpha_hi = f64::INFINITY;
    let mut phi_lo = f0;
    let mut dphi_lo = derphi0;
    let mut alpha = alpha_init;

    // Track the best point seen so the search always returns something useful.
    let mut best_alpha = alpha_init;
    let mut best_params = step(alpha_init);
    let (mut best_val, mut best_grad) = f(&best_params);

    for _iter in 0..max_iter {
        let x_alpha = step(alpha);
        let (phi_alpha, grad_alpha) = f(&x_alpha);
        let dphi_alpha = kahan_dot(&grad_alpha, direction);

        if phi_alpha < best_val {
            best_alpha = alpha;
            best_params = x_alpha.clone();
            best_val = phi_alpha;
            best_grad = grad_alpha.clone();
        }

        // Armijo violated (or no longer improving): zoom into [alpha_lo, alpha].
        if phi_alpha > f0 + c1 * alpha * derphi0 || (phi_alpha >= phi_lo && alpha_lo > 0.0) {
            alpha_hi = alpha;
            let (za, zp, zv, zg) = wolfe_zoom(
                params, direction, f, f0, derphi0, c1, c2,
                alpha_lo, alpha_hi, phi_lo, dphi_lo,
            );
            return (za, zp, zv, zg);
        }

        // Strong Wolfe curvature condition satisfied.
        if dphi_alpha.abs() <= c2 * derphi0.abs() {
            return (alpha, x_alpha, phi_alpha, grad_alpha);
        }

        // Slope turned non-negative: the minimum is bracketed behind us.
        if dphi_alpha >= 0.0 {
            let (za, zp, zv, zg) = wolfe_zoom(
                params, direction, f, f0, derphi0, c1, c2,
                alpha, alpha_lo, phi_alpha, dphi_alpha,
            );
            return (za, zp, zv, zg);
        }

        alpha_lo = alpha;
        phi_lo = phi_alpha;
        dphi_lo = dphi_alpha;

        alpha = if alpha_hi.is_finite() {
            (alpha_lo + alpha_hi) * 0.5
        } else {
            (alpha * 2.0).min(1e8)
        };
    }

    (best_alpha, best_params, best_val, best_grad)
}

#[allow(clippy::too_many_arguments)]
fn wolfe_zoom<F>(
    params: &[f64],
    direction: &[f64],
    f: &mut F,
    f0: f64,
    derphi0: f64,
    c1: f64,
    c2: f64,
    mut alpha_lo: f64,
    mut alpha_hi: f64,
    mut phi_lo: f64,
    _dphi_lo: f64,
) -> (f64, Vec<f64>, f64, Vec<f64>)
where
    F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
    let step = |alpha: f64| -> Vec<f64> {
        params.iter().zip(direction.iter()).map(|(&p, &d)| p + alpha * d).collect()
    };

    let max_zoom = 20;
    let mut best_alpha = alpha_lo;
    let mut best_x = step(alpha_lo);
    let (mut best_val, mut best_grad) = f(&best_x);

    for _i in 0..max_zoom {
        let alpha_j = (alpha_lo + alpha_hi) * 0.5;
        let x_j = step(alpha_j);
        let (phi_j, grad_j) = f(&x_j);
        let dphi_j = kahan_dot(&grad_j, direction);

        if phi_j < best_val {
            best_alpha = alpha_j;
            best_x = x_j.clone();
            best_val = phi_j;
            best_grad = grad_j.clone();
        }

        if phi_j > f0 + c1 * alpha_j * derphi0 || phi_j >= phi_lo {
            alpha_hi = alpha_j;
        } else {
            if dphi_j.abs() <= c2 * derphi0.abs() {
                return (alpha_j, x_j, phi_j, grad_j);
            }
            if dphi_j * (alpha_hi - alpha_lo) >= 0.0 {
                alpha_hi = alpha_lo;
            }
            alpha_lo = alpha_j;
            phi_lo = phi_j;
        }

        if (alpha_hi - alpha_lo).abs() < 1e-14 {
            break;
        }
    }

    (best_alpha, best_x, best_val, best_grad)
}

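/// One L-BFGS step: the search direction is `-H·g`, where `H` approximates
/// the inverse Hessian implicitly from the last `m` curvature pairs
/// `(s_k = x_{k+1} - x_k, y_k = g_{k+1} - g_k)` via the standard two-loop
/// recursion, followed by a Wolfe line search along that direction. If the
/// two-loop direction is not a descent direction, the step falls back to
/// normalized steepest descent (third return value `false`). Pairs are stored
/// only when `sᵀy > 0`, which keeps the implicit approximation positive
/// definite.
///
/// Illustrative sketch on `f(x) = x²`:
///
/// ```ignore
/// let f = |p: &[f64]| (p[0] * p[0], vec![2.0 * p[0]]);
/// let mut params = vec![3.0];
/// let mut state = LbfgsState::new(1.0, 5);
/// for _ in 0..30 {
///     let (_, grads) = f(&params);
///     let (next, _, _) = lbfgs_step(&params, &grads, &mut state, f);
///     params = next;
/// }
/// assert!(params[0].abs() < 1e-6);
/// ```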
pub fn lbfgs_step<F>(
    params: &[f64],
    grads: &[f64],
    state: &mut LbfgsState,
    mut f: F,
) -> (Vec<f64>, Vec<f64>, bool)
where
    F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
    let n = params.len();
    debug_assert_eq!(grads.len(), n);

    let hist_len = state.s_history.len();
    let mut q: Vec<f64> = grads.to_vec();
    let mut alphas = vec![0.0_f64; hist_len];
    let mut rhos = vec![0.0_f64; hist_len];

    // Two-loop recursion, first pass: newest pair to oldest.
    for i in (0..hist_len).rev() {
        let sy = kahan_dot(&state.s_history[i], &state.y_history[i]);
        rhos[i] = if sy.abs() < 1e-300 { 0.0 } else { 1.0 / sy };
        let sq = kahan_dot(&state.s_history[i], &q);
        alphas[i] = rhos[i] * sq;
        for j in 0..n {
            q[j] -= alphas[i] * state.y_history[i][j];
        }
    }

    // Initial Hessian scaling gamma = s'y / y'y from the most recent pair.
    let scale = if hist_len > 0 {
        let last = hist_len - 1;
        let sy = kahan_dot(&state.s_history[last], &state.y_history[last]);
        let yy = kahan_dot(&state.y_history[last], &state.y_history[last]);
        if yy.abs() < 1e-300 { 1.0 } else { sy / yy }
    } else {
        1.0
    };
    let mut r: Vec<f64> = q.iter().map(|&qi| scale * qi).collect();

    // Second pass: oldest pair to newest.
    for i in 0..hist_len {
        let yr = kahan_dot(&state.y_history[i], &r);
        let beta = rhos[i] * yr;
        let diff = alphas[i] - beta;
        for j in 0..n {
            r[j] += diff * state.s_history[i][j];
        }
    }

    let direction: Vec<f64> = r.iter().map(|&ri| -ri).collect();

    // Fall back to normalized steepest descent when the two-loop direction
    // is not a descent direction (or is non-finite).
    let descent_check = kahan_dot(&direction, grads);
    let (direction, is_descent) = if descent_check >= 0.0 || !descent_check.is_finite() {
        let norm_g = kahan_dot(grads, grads).sqrt().max(1e-300);
        (grads.iter().map(|&g| -g / norm_g).collect::<Vec<f64>>(), false)
    } else {
        (direction, true)
    };

    let (f0, _) = f(params);
    let (_, new_params, _, new_grads) = wolfe_line_search(
        params,
        &direction,
        &mut f,
        f0,
        grads,
        state.lr,
    );

    let s_k: Vec<f64> = new_params.iter().zip(params.iter()).map(|(&np, &p)| np - p).collect();
    let y_k: Vec<f64> = new_grads.iter().zip(grads.iter()).map(|(&ng, &g)| ng - g).collect();

    // Store the pair only if curvature is positive; evict the oldest beyond m.
    let sy = kahan_dot(&s_k, &y_k);
    if sy > 1e-300 {
        state.s_history.push(s_k);
        state.y_history.push(y_k);
        if state.s_history.len() > state.m {
            state.s_history.remove(0);
            state.y_history.remove(0);
        }
    }

    state.prev_params = Some(new_params.clone());
    state.prev_grad = Some(new_grads.clone());

    (new_params, new_grads, is_descent)
}

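/// A single LSTM cell step with packed weights (gate order `i, f, g, o` along
/// dimension 0 of `w_ih`/`w_hh`, as in the chunked computation below):
///
/// ```text
/// i = σ(gates_i)    f = σ(gates_f)    g = tanh(gates_g)    o = σ(gates_o)
/// c' = f ⊙ c + i ⊙ g
/// h' = o ⊙ tanh(c')
/// ```
///
/// Shapes: `x` is `[batch, input_size]`, `h_prev`/`c_prev` are
/// `[batch, hidden_size]`, and the packed weights are
/// `[4·hidden_size, input_size]` / `[4·hidden_size, hidden_size]`.
/// Returns `(h_new, c_new)`.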
pub fn lstm_cell(
    x: &Tensor,
    h_prev: &Tensor,
    c_prev: &Tensor,
    w_ih: &Tensor,
    w_hh: &Tensor,
    b_ih: &Tensor,
    b_hh: &Tensor,
) -> Result<(Tensor, Tensor), String> {
    let map_err = |e: crate::error::RuntimeError| format!("{e}");

    if x.ndim() != 2 {
        return Err("lstm_cell: x must be 2-D [batch, input_size]".into());
    }
    if h_prev.ndim() != 2 {
        return Err("lstm_cell: h_prev must be 2-D [batch, hidden_size]".into());
    }
    if c_prev.ndim() != 2 {
        return Err("lstm_cell: c_prev must be 2-D [batch, hidden_size]".into());
    }
    let hidden_size = h_prev.shape()[1];
    if w_ih.ndim() != 2 || w_ih.shape()[0] != 4 * hidden_size {
        return Err(format!(
            "lstm_cell: w_ih must be [4*hidden_size, input_size], got {:?}",
            w_ih.shape()
        ));
    }
    if w_hh.ndim() != 2 || w_hh.shape()[0] != 4 * hidden_size {
        return Err(format!(
            "lstm_cell: w_hh must be [4*hidden_size, hidden_size], got {:?}",
            w_hh.shape()
        ));
    }
    if b_ih.len() != 4 * hidden_size {
        return Err(format!(
            "lstm_cell: b_ih must have length 4*hidden_size={}, got {}",
            4 * hidden_size,
            b_ih.len()
        ));
    }
    if b_hh.len() != 4 * hidden_size {
        return Err(format!(
            "lstm_cell: b_hh must have length 4*hidden_size={}, got {}",
            4 * hidden_size,
            b_hh.len()
        ));
    }

    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
    let gates = gates_ih.add(&gates_hh).map_err(map_err)?;

    // Gate order along dimension 1: input, forget, cell candidate, output.
    let chunks = gates.chunk(4, 1).map_err(map_err)?;
    let gates_i = &chunks[0];
    let gates_f = &chunks[1];
    let gates_g = &chunks[2];
    let gates_o = &chunks[3];

    let i = gates_i.sigmoid();
    let f = gates_f.sigmoid();
    let g = gates_g.tanh_activation();
    let o = gates_o.sigmoid();

    // c' = f ⊙ c + i ⊙ g
    let fc = f.mul_elem(c_prev).map_err(map_err)?;
    let ig = i.mul_elem(&g).map_err(map_err)?;
    let c_new = fc.add(&ig).map_err(map_err)?;

    // h' = o ⊙ tanh(c')
    let c_tanh = c_new.tanh_activation();
    let h_new = o.mul_elem(&c_tanh).map_err(map_err)?;

    Ok((h_new, c_new))
}

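/// A single GRU cell step with packed weights (gate order `r, z, n`):
///
/// ```text
/// r = σ(r_ih + r_hh)
/// z = σ(z_ih + z_hh)
/// n = tanh(n_ih + r ⊙ n_hh)
/// h' = (1 - z) ⊙ n + z ⊙ h
/// ```
///
/// Note that `r` gates the hidden contribution *inside* the tanh, which is
/// why the `ih` and `hh` projections must be chunked separately rather than
/// summed up front as in `lstm_cell`.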
pub fn gru_cell(
    x: &Tensor,
    h_prev: &Tensor,
    w_ih: &Tensor,
    w_hh: &Tensor,
    b_ih: &Tensor,
    b_hh: &Tensor,
) -> Result<Tensor, String> {
    let map_err = |e: crate::error::RuntimeError| format!("{e}");

    if x.ndim() != 2 {
        return Err("gru_cell: x must be 2-D [batch, input_size]".into());
    }
    if h_prev.ndim() != 2 {
        return Err("gru_cell: h_prev must be 2-D [batch, hidden_size]".into());
    }
    let hidden_size = h_prev.shape()[1];
    if w_ih.ndim() != 2 || w_ih.shape()[0] != 3 * hidden_size {
        return Err(format!(
            "gru_cell: w_ih must be [3*hidden_size, input_size], got {:?}",
            w_ih.shape()
        ));
    }
    if w_hh.ndim() != 2 || w_hh.shape()[0] != 3 * hidden_size {
        return Err(format!(
            "gru_cell: w_hh must be [3*hidden_size, hidden_size], got {:?}",
            w_hh.shape()
        ));
    }
    if b_ih.len() != 3 * hidden_size {
        return Err(format!(
            "gru_cell: b_ih must have length 3*hidden_size={}, got {}",
            3 * hidden_size,
            b_ih.len()
        ));
    }
    if b_hh.len() != 3 * hidden_size {
        return Err(format!(
            "gru_cell: b_hh must have length 3*hidden_size={}, got {}",
            3 * hidden_size,
            b_hh.len()
        ));
    }

    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;

    // The ih and hh projections stay separate because r gates only the hh
    // contribution of the candidate state. Gate order: r, z, n.
    let ih_chunks = gates_ih.chunk(3, 1).map_err(map_err)?;
    let hh_chunks = gates_hh.chunk(3, 1).map_err(map_err)?;
    let r_ih = &ih_chunks[0];
    let z_ih = &ih_chunks[1];
    let n_ih = &ih_chunks[2];
    let r_hh = &hh_chunks[0];
    let z_hh = &hh_chunks[1];
    let n_hh = &hh_chunks[2];

    let r = r_ih.add(r_hh).map_err(map_err)?.sigmoid();
    let z = z_ih.add(z_hh).map_err(map_err)?.sigmoid();
    let r_n_hh = r.mul_elem(n_hh).map_err(map_err)?;
    let n = n_ih.add(&r_n_hh).map_err(map_err)?.tanh_activation();

    // h' = (1 - z) ⊙ n + z ⊙ h_prev
    let ones = Tensor::ones(z.shape());
    let one_minus_z = ones.sub(&z).map_err(map_err)?;
    let term1 = one_minus_z.mul_elem(&n).map_err(map_err)?;
    let term2 = z.mul_elem(h_prev).map_err(map_err)?;
    let h_new = term1.add(&term2).map_err(map_err)?;

    Ok(h_new)
}

pub fn lstm_cell_fused(
    x: &Tensor,
    h_prev: &Tensor,
    c_prev: &Tensor,
    w_ih: &Tensor,
    w_hh: &Tensor,
    b_ih: &Tensor,
    b_hh: &Tensor,
) -> Result<(Tensor, Tensor), String> {
    let map_err = |e: crate::error::RuntimeError| format!("{e}");

    if x.ndim() != 2 {
        return Err("lstm_cell_fused: x must be 2-D [batch, input_size]".into());
    }
    if h_prev.ndim() != 2 {
        return Err("lstm_cell_fused: h_prev must be 2-D [batch, hidden_size]".into());
    }
    if c_prev.ndim() != 2 {
        return Err("lstm_cell_fused: c_prev must be 2-D [batch, hidden_size]".into());
    }
    let batch = x.shape()[0];
    let hidden_size = h_prev.shape()[1];
    if w_ih.ndim() != 2 || w_ih.shape()[0] != 4 * hidden_size {
        return Err(format!(
            "lstm_cell_fused: w_ih must be [4*hidden_size, input_size], got {:?}",
            w_ih.shape()
        ));
    }
    if w_hh.ndim() != 2 || w_hh.shape()[0] != 4 * hidden_size {
        return Err(format!(
            "lstm_cell_fused: w_hh must be [4*hidden_size, hidden_size], got {:?}",
            w_hh.shape()
        ));
    }
    if b_ih.len() != 4 * hidden_size {
        return Err(format!(
            "lstm_cell_fused: b_ih must have length 4*hidden_size={}, got {}",
            4 * hidden_size,
            b_ih.len()
        ));
    }
    if b_hh.len() != 4 * hidden_size {
        return Err(format!(
            "lstm_cell_fused: b_hh must have length 4*hidden_size={}, got {}",
            4 * hidden_size,
            b_hh.len()
        ));
    }

    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;

    let gih = gates_ih.to_vec();
    let ghh = gates_hh.to_vec();
    let cprev = c_prev.to_vec();

    let mut h_new_data = vec![0.0f64; batch * hidden_size];
    let mut c_new_data = vec![0.0f64; batch * hidden_size];

    for b_idx in 0..batch {
        let base = b_idx * 4 * hidden_size;
        for h in 0..hidden_size {
            // Gate pre-activations in packed order: i, f, g, o.
            let gi = gih[base + h] + ghh[base + h];
            let gf = gih[base + hidden_size + h] + ghh[base + hidden_size + h];
            let gg = gih[base + 2 * hidden_size + h] + ghh[base + 2 * hidden_size + h];
            let go = gih[base + 3 * hidden_size + h] + ghh[base + 3 * hidden_size + h];

            // Scalar activations, fused to avoid intermediate tensors.
            let i_val = 1.0 / (1.0 + (-gi).exp());
            let f_val = 1.0 / (1.0 + (-gf).exp());
            let g_val = gg.tanh();
            let o_val = 1.0 / (1.0 + (-go).exp());

            let c_idx = b_idx * hidden_size + h;
            let c_val = f_val * cprev[c_idx] + i_val * g_val;
            c_new_data[c_idx] = c_val;
            h_new_data[c_idx] = o_val * c_val.tanh();
        }
    }

    let h_new = Tensor::from_vec(h_new_data, &[batch, hidden_size]).map_err(map_err)?;
    let c_new = Tensor::from_vec(c_new_data, &[batch, hidden_size]).map_err(map_err)?;

    Ok((h_new, c_new))
}

pub fn gru_cell_fused(
    x: &Tensor,
    h_prev: &Tensor,
    w_ih: &Tensor,
    w_hh: &Tensor,
    b_ih: &Tensor,
    b_hh: &Tensor,
) -> Result<Tensor, String> {
    let map_err = |e: crate::error::RuntimeError| format!("{e}");

    if x.ndim() != 2 {
        return Err("gru_cell_fused: x must be 2-D [batch, input_size]".into());
    }
    if h_prev.ndim() != 2 {
        return Err("gru_cell_fused: h_prev must be 2-D [batch, hidden_size]".into());
    }
    let batch = x.shape()[0];
    let hidden_size = h_prev.shape()[1];
    if w_ih.ndim() != 2 || w_ih.shape()[0] != 3 * hidden_size {
        return Err(format!(
            "gru_cell_fused: w_ih must be [3*hidden_size, input_size], got {:?}",
            w_ih.shape()
        ));
    }
    if w_hh.ndim() != 2 || w_hh.shape()[0] != 3 * hidden_size {
        return Err(format!(
            "gru_cell_fused: w_hh must be [3*hidden_size, hidden_size], got {:?}",
            w_hh.shape()
        ));
    }
    if b_ih.len() != 3 * hidden_size {
        return Err(format!(
            "gru_cell_fused: b_ih must have length 3*hidden_size={}, got {}",
            3 * hidden_size,
            b_ih.len()
        ));
    }
    if b_hh.len() != 3 * hidden_size {
        return Err(format!(
            "gru_cell_fused: b_hh must have length 3*hidden_size={}, got {}",
            3 * hidden_size,
            b_hh.len()
        ));
    }

    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;

    let gih = gates_ih.to_vec();
    let ghh = gates_hh.to_vec();
    let hp = h_prev.to_vec();

    let mut h_new_data = vec![0.0f64; batch * hidden_size];

    for b_idx in 0..batch {
        let base = b_idx * 3 * hidden_size;
        for h in 0..hidden_size {
            // Gate order: r, z, n; r gates only the hh part of the candidate.
            let r_val = 1.0 / (1.0 + (-(gih[base + h] + ghh[base + h])).exp());
            let z_val = 1.0
                / (1.0
                    + (-(gih[base + hidden_size + h]
                        + ghh[base + hidden_size + h]))
                        .exp());
            let n_val = (gih[base + 2 * hidden_size + h]
                + r_val * ghh[base + 2 * hidden_size + h])
                .tanh();

            let h_idx = b_idx * hidden_size + h;
            h_new_data[h_idx] = (1.0 - z_val) * n_val + z_val * hp[h_idx];
        }
    }

    Tensor::from_vec(h_new_data, &[batch, hidden_size])
        .map_err(|e| format!("{e}"))
}

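/// Multi-head attention over 3-D `[batch, seq, model_dim]` inputs: project
/// `q`/`k`/`v` through their weight matrices, split the model dimension into
/// `num_heads` heads, apply scaled dot-product attention per head, merge the
/// heads back, and apply the output projection `w_o`/`b_o`. (The per-head
/// attention is the usual `softmax(Q·Kᵀ/√d_head)·V`, assuming that is what
/// `Tensor::scaled_dot_product_attention` computes.)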
pub fn multi_head_attention(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    w_q: &Tensor,
    w_k: &Tensor,
    w_v: &Tensor,
    w_o: &Tensor,
    b_q: &Tensor,
    b_k: &Tensor,
    b_v: &Tensor,
    b_o: &Tensor,
    num_heads: usize,
) -> Result<Tensor, String> {
    let map_err = |e: crate::error::RuntimeError| format!("{e}");

    if q.ndim() != 3 {
        return Err("multi_head_attention: q must be 3-D [batch, seq, model_dim]".into());
    }

    let q_proj = q.linear(w_q, b_q).map_err(map_err)?;
    let k_proj = k.linear(w_k, b_k).map_err(map_err)?;
    let v_proj = v.linear(w_v, b_v).map_err(map_err)?;

    let q_heads = q_proj.split_heads(num_heads).map_err(map_err)?;
    let k_heads = k_proj.split_heads(num_heads).map_err(map_err)?;
    let v_heads = v_proj.split_heads(num_heads).map_err(map_err)?;

    let attn = Tensor::scaled_dot_product_attention(&q_heads, &k_heads, &v_heads)
        .map_err(map_err)?;

    let merged = attn.merge_heads().map_err(map_err)?;

    let output = merged.linear(w_o, b_o).map_err(map_err)?;

    Ok(output)
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mse_zero() {
        let pred = [1.0, 2.0, 3.0];
        let target = [1.0, 2.0, 3.0];
        assert_eq!(mse_loss(&pred, &target).unwrap(), 0.0);
    }

    #[test]
    fn test_mse_basic() {
        let pred = [1.0, 2.0, 3.0];
        let target = [2.0, 3.0, 4.0];
        assert_eq!(mse_loss(&pred, &target).unwrap(), 1.0);
    }

    #[test]
    fn test_huber_loss_quadratic() {
        let pred = [1.0];
        let target = [1.5];
        let h = huber_loss(&pred, &target, 1.0).unwrap();
        assert!((h - 0.125).abs() < 1e-12);
    }

    #[test]
    fn test_sgd_step() {
        let mut params = [1.0, 2.0];
        let grads = [0.1, 0.2];
        let mut state = SgdState::new(2, 0.1, 0.0);
        sgd_step(&mut params, &grads, &mut state);
        assert!((params[0] - 0.99).abs() < 1e-12);
        assert!((params[1] - 1.98).abs() < 1e-12);
    }

    #[test]
    fn test_adam_step() {
        let mut params = [1.0, 2.0];
        let grads = [0.1, 0.2];
        let mut state = AdamState::new(2, 0.001);
        adam_step(&mut params, &grads, &mut state);
        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
    }

    #[test]
    fn test_confusion_matrix() {
        let pred = [true, true, false, false, true];
        let actual = [true, false, true, false, true];
        let cm = confusion_matrix(&pred, &actual);
        assert_eq!(cm.tp, 2);
        assert_eq!(cm.fp, 1);
        assert_eq!(cm.fn_count, 1);
        assert_eq!(cm.tn, 1);
    }

    #[test]
    fn test_precision_recall_f1() {
        let cm = ConfusionMatrix { tp: 5, fp: 2, tn: 8, fn_count: 1 };
        assert!((precision(&cm) - 5.0 / 7.0).abs() < 1e-12);
        assert!((recall(&cm) - 5.0 / 6.0).abs() < 1e-12);
    }

    #[test]
    fn test_auc_perfect() {
        let scores = [0.9, 0.8, 0.2, 0.1];
        let labels = [true, true, false, false];
        let auc = auc_roc(&scores, &labels).unwrap();
        assert!((auc - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_kfold_deterministic() {
        let f1 = kfold_indices(100, 5, 42);
        let f2 = kfold_indices(100, 5, 42);
        for i in 0..5 {
            assert_eq!(f1[i].0, f2[i].0);
            assert_eq!(f1[i].1, f2[i].1);
        }
    }

    #[test]
    fn test_train_test_split_coverage() {
        let (train, test) = train_test_split(100, 0.2, 42);
        assert_eq!(train.len() + test.len(), 100);
        assert_eq!(test.len(), 20);
    }

    #[test]
    fn test_batch_norm_identity() {
        let x = vec![1.0, 2.0, 3.0];
        let mean = vec![0.0, 0.0, 0.0];
        let var = vec![1.0, 1.0, 1.0];
        let gamma = vec![1.0, 1.0, 1.0];
        let beta = vec![0.0, 0.0, 0.0];
        let result = batch_norm(&x, &mean, &var, &gamma, &beta, 0.0).unwrap();
        assert!((result[0] - 1.0).abs() < 1e-12);
        assert!((result[1] - 2.0).abs() < 1e-12);
        assert!((result[2] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_batch_norm_shift_scale() {
        let x = vec![0.0];
        let mean = vec![1.0];
        let var = vec![4.0];
        let gamma = vec![2.0];
        let beta = vec![3.0];
        let result = batch_norm(&x, &mean, &var, &gamma, &beta, 0.0).unwrap();
        assert!((result[0] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_dropout_mask_seed_determinism() {
        let m1 = dropout_mask(100, 0.5, 42);
        let m2 = dropout_mask(100, 0.5, 42);
        assert_eq!(m1, m2);
    }

    #[test]
    fn test_dropout_mask_different_seeds() {
        let m1 = dropout_mask(100, 0.5, 42);
        let m2 = dropout_mask(100, 0.5, 99);
        assert_ne!(m1, m2);
    }

    #[test]
    fn test_lr_step_decay_schedule() {
        let lr0 = lr_step_decay(0.1, 0.5, 0, 10);
        assert!((lr0 - 0.1).abs() < 1e-12);
        let lr10 = lr_step_decay(0.1, 0.5, 10, 10);
        assert!((lr10 - 0.05).abs() < 1e-12);
        let lr20 = lr_step_decay(0.1, 0.5, 20, 10);
        assert!((lr20 - 0.025).abs() < 1e-12);
    }

    #[test]
    fn test_lr_cosine_endpoints() {
        let lr0 = lr_cosine(0.1, 0.001, 0, 100);
        assert!((lr0 - 0.1).abs() < 1e-10);
        let lr_end = lr_cosine(0.1, 0.001, 100, 100);
        assert!((lr_end - 0.001).abs() < 1e-10);
    }

    #[test]
    fn test_lr_linear_warmup() {
        let lr0 = lr_linear_warmup(0.1, 0, 10);
        assert!((lr0).abs() < 1e-12);
        let lr5 = lr_linear_warmup(0.1, 5, 10);
        assert!((lr5 - 0.05).abs() < 1e-12);
        let lr15 = lr_linear_warmup(0.1, 15, 10);
        assert!((lr15 - 0.1).abs() < 1e-12);
    }

    #[test]
    fn test_l1_penalty_known() {
        let params = [1.0, -2.0, 3.0];
        // 0.1 * (1 + 2 + 3) = 0.6
        let p = l1_penalty(&params, 0.1);
        assert!((p - 0.6).abs() < 1e-12);
    }

    #[test]
    fn test_l2_penalty_known() {
        let params = [1.0, -2.0, 3.0];
        // 0.5 * 0.1 * (1 + 4 + 9) = 0.7
        let p = l2_penalty(&params, 0.1);
        assert!((p - 0.7).abs() < 1e-12);
    }

    #[test]
    fn test_early_stopping_triggers() {
        let mut es = EarlyStoppingState::new(3, 0.01);
        assert!(!es.check(1.0)); // first loss becomes the best
        assert!(!es.check(1.0)); // wait = 1
        assert!(!es.check(1.0)); // wait = 2
        assert!(es.check(1.0)); // wait = 3 >= patience: stopped
    }

    #[test]
    fn test_early_stopping_resets() {
        let mut es = EarlyStoppingState::new(3, 0.01);
        es.check(1.0);
        es.check(1.0); // wait = 1
        assert!(!es.check(0.5)); // improvement resets wait to 0
        assert!(!es.check(0.5)); // wait = 1, still below patience
    }

    #[test]
    fn test_pca_basic_2d() {
        let data = Tensor::from_vec(
            vec![
                1.0, 0.1,
                2.0, 0.2,
                3.0, 0.3,
                4.0, 0.4,
            ],
            &[4, 2],
        )
        .unwrap();
        let (transformed, components, evr) = pca(&data, 2).unwrap();
        assert_eq!(transformed.shape(), &[4, 2]);
        assert_eq!(components.shape(), &[2, 2]);
        assert_eq!(evr.len(), 2);
        let total: f64 = evr.iter().sum();
        assert!(
            (total - 1.0).abs() < 1e-8,
            "explained variance ratios sum to {} (expected ~1.0)",
            total
        );
        assert!(evr[0] > 0.9, "first component explains {} of variance", evr[0]);
    }

    #[test]
    fn test_pca_single_component() {
        let data = Tensor::from_vec(
            vec![
                1.0, 2.0, 3.0,
                4.0, 5.0, 6.0,
                7.0, 8.0, 9.0,
            ],
            &[3, 3],
        )
        .unwrap();
        let (transformed, components, evr) = pca(&data, 1).unwrap();
        assert_eq!(transformed.shape(), &[3, 1]);
        assert_eq!(components.shape(), &[1, 3]);
        assert_eq!(evr.len(), 1);
        assert!(evr[0] > 0.0 && evr[0] <= 1.0);
    }

    #[test]
    fn test_pca_explained_variance_ratio_bounded() {
        let data = Tensor::from_vec(
            vec![
                1.0, 0.0, 0.5,
                0.0, 1.0, 0.5,
                1.0, 1.0, 1.0,
                2.0, 0.0, 1.0,
                0.0, 2.0, 1.0,
            ],
            &[5, 3],
        )
        .unwrap();
        let (_, _, evr) = pca(&data, 3).unwrap();
        let total: f64 = evr.iter().sum();
        assert!(
            total <= 1.0 + 1e-10,
            "explained variance ratios sum to {} (should be <= 1.0)",
            total
        );
        for &r in &evr {
            assert!(r >= -1e-10, "negative explained variance ratio: {}", r);
        }
    }

    #[test]
    fn test_pca_deterministic() {
        let data = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[2, 3],
        )
        .unwrap();
        let (t1, c1, e1) = pca(&data, 2).unwrap();
        let (t2, c2, e2) = pca(&data, 2).unwrap();
        assert_eq!(t1.to_vec(), t2.to_vec(), "PCA transformed not deterministic");
        assert_eq!(c1.to_vec(), c2.to_vec(), "PCA components not deterministic");
        assert_eq!(e1, e2, "PCA explained variance not deterministic");
    }

    #[test]
    fn test_pca_invalid_n_components() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert!(pca(&data, 0).is_err(), "n_components=0 should fail");
        assert!(pca(&data, 3).is_err(), "n_components > min(n,p) should fail");
    }

    fn rosenbrock(p: &[f64]) -> (f64, Vec<f64>) {
        let x = p[0];
        let y = p[1];
        let a = 1.0 - x;
        let b = y - x * x;
        let val = a * a + 100.0 * b * b;
        let gx = -2.0 * a - 400.0 * x * b;
        let gy = 200.0 * b;
        (val, vec![gx, gy])
    }

    #[test]
    fn test_lbfgs_rosenbrock_converges() {
        let mut params = vec![-1.0_f64, 2.0_f64];
        let mut state = LbfgsState::new(0.5, 10);

        let mut converged = false;
        for _iter in 0..200 {
            let (_, grads) = rosenbrock(&params);
            let grad_norm: f64 = kahan_dot(&grads, &grads).sqrt();
            if grad_norm < 1e-5 {
                converged = true;
                break;
            }
            let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, rosenbrock);
            params = new_p;
        }
        assert!(converged, "L-BFGS did not converge on Rosenbrock; params = {:?}", params);
        assert!(
            (params[0] - 1.0).abs() < 1e-3,
            "x should converge near 1.0, got {}",
            params[0]
        );
        assert!(
            (params[1] - 1.0).abs() < 1e-3,
            "y should converge near 1.0, got {}",
            params[1]
        );
    }

    #[test]
    fn test_lbfgs_determinism() {
        let init = vec![-1.0_f64, 2.0_f64];

        let run = |init: &[f64]| -> Vec<f64> {
            let mut params = init.to_vec();
            let mut state = LbfgsState::new(0.5, 10);
            for _ in 0..20 {
                let (_, grads) = rosenbrock(&params);
                let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, rosenbrock);
                params = new_p;
            }
            params
        };

        let r1 = run(&init);
        let r2 = run(&init);
        assert_eq!(r1, r2, "L-BFGS must be bit-identical across runs");
    }

    #[test]
    fn test_lbfgs_simple_quadratic() {
        let mut params = vec![3.0_f64];
        let mut state = LbfgsState::new(1.0, 5);
        let quadratic = |p: &[f64]| -> (f64, Vec<f64>) {
            (p[0] * p[0], vec![2.0 * p[0]])
        };

        for _ in 0..30 {
            let (_, grads) = quadratic(&params);
            let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, quadratic);
            params = new_p;
        }
        assert!(
            params[0].abs() < 1e-6,
            "L-BFGS should minimize x^2 to ~0, got {}",
            params[0]
        );
    }

    #[test]
    fn test_wolfe_line_search_armijo() {
        // Minimize f(x) = x^2 from x = 3 along the descent direction d = -1.
        let params = vec![3.0_f64];
        let direction = vec![-1.0_f64];
        let grads = vec![6.0_f64];
        let f0 = 9.0;

        let mut eval_count = 0;
        let mut obj = |p: &[f64]| -> (f64, Vec<f64>) {
            eval_count += 1;
            (p[0] * p[0], vec![2.0 * p[0]])
        };

        let (alpha, new_params, new_val, _) =
            wolfe_line_search(&params, &direction, &mut obj, f0, &grads, 1.0);

        let c1 = 1e-4;
        let derphi0 = kahan_dot(&grads, &direction);
        assert!(
            new_val <= f0 + c1 * alpha * derphi0,
            "Armijo condition violated: {} > {} + {} * {} * {}",
            new_val, f0, c1, alpha, derphi0
        );
        assert!(new_params[0] < 3.0, "Step should move toward minimum");
    }
}