1use cjc_repro::KahanAccumulatorF64;
9
10use crate::accumulator::BinnedAccumulatorF64;
11use crate::error::RuntimeError;
12use crate::tensor::Tensor;
13
14pub fn mse_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
20 if pred.len() != target.len() {
21 return Err("mse_loss: arrays must have same length".into());
22 }
23 if pred.is_empty() {
24 return Err("mse_loss: empty data".into());
25 }
26 let mut acc = KahanAccumulatorF64::new();
27 for i in 0..pred.len() {
28 let d = pred[i] - target[i];
29 acc.add(d * d);
30 }
31 Ok(acc.finalize() / pred.len() as f64)
32}
33
34pub fn cross_entropy_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
36 if pred.len() != target.len() {
37 return Err("cross_entropy_loss: arrays must have same length".into());
38 }
39 if pred.is_empty() {
40 return Err("cross_entropy_loss: empty data".into());
41 }
42 let eps = 1e-12;
43 let mut acc = KahanAccumulatorF64::new();
44 for i in 0..pred.len() {
45 acc.add(-target[i] * (pred[i] + eps).ln());
46 }
47 Ok(acc.finalize() / pred.len() as f64)
48}
49
50pub fn binary_cross_entropy(pred: &[f64], target: &[f64]) -> Result<f64, String> {
52 if pred.len() != target.len() {
53 return Err("binary_cross_entropy: arrays must have same length".into());
54 }
55 if pred.is_empty() {
56 return Err("binary_cross_entropy: empty data".into());
57 }
58 let eps = 1e-12;
59 let mut acc = KahanAccumulatorF64::new();
60 for i in 0..pred.len() {
61 let p = pred[i].max(eps).min(1.0 - eps);
62 acc.add(-(target[i] * p.ln() + (1.0 - target[i]) * (1.0 - p).ln()));
63 }
64 Ok(acc.finalize() / pred.len() as f64)
65}
66
67pub fn huber_loss(pred: &[f64], target: &[f64], delta: f64) -> Result<f64, String> {
69 if pred.len() != target.len() {
70 return Err("huber_loss: arrays must have same length".into());
71 }
72 if pred.is_empty() {
73 return Err("huber_loss: empty data".into());
74 }
75 let mut acc = KahanAccumulatorF64::new();
76 for i in 0..pred.len() {
77 let d = (pred[i] - target[i]).abs();
78 if d <= delta {
79 acc.add(0.5 * d * d);
80 } else {
81 acc.add(delta * (d - 0.5 * delta));
82 }
83 }
84 Ok(acc.finalize() / pred.len() as f64)
85}
86
87pub fn hinge_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
89 if pred.len() != target.len() {
90 return Err("hinge_loss: arrays must have same length".into());
91 }
92 if pred.is_empty() {
93 return Err("hinge_loss: empty data".into());
94 }
95 let mut acc = KahanAccumulatorF64::new();
96 for i in 0..pred.len() {
97 acc.add((1.0 - target[i] * pred[i]).max(0.0));
98 }
99 Ok(acc.finalize() / pred.len() as f64)
100}
101
102use crate::idx::ParamIdx;
107
108pub struct SgdState {
119 pub lr: f64,
120 pub momentum: f64,
121 pub velocity: Vec<f64>,
122}
123
124impl SgdState {
125 pub fn new(n_params: usize, lr: f64, momentum: f64) -> Self {
126 Self { lr, momentum, velocity: vec![0.0; n_params] }
127 }
128
129 #[inline]
131 pub fn n_params(&self) -> usize {
132 self.velocity.len()
133 }
134
135 #[inline]
138 pub fn velocity_at(&self, p: ParamIdx) -> f64 {
139 self.velocity[p.index()]
140 }
141
142 #[inline]
144 pub fn set_velocity_at(&mut self, p: ParamIdx, value: f64) {
145 self.velocity[p.index()] = value;
146 }
147}
148
149pub fn sgd_step(params: &mut [f64], grads: &[f64], state: &mut SgdState) {
155 let n = params.len();
156 for i in 0..n {
157 let p = ParamIdx::from_usize(i);
158 let new_velocity = state.momentum * state.velocity_at(p) + grads[i];
159 state.set_velocity_at(p, new_velocity);
160 params[i] -= state.lr * state.velocity_at(p);
161 }
162}
163
164pub struct AdamState {
173 pub lr: f64,
174 pub beta1: f64,
175 pub beta2: f64,
176 pub eps: f64,
177 pub t: u64,
178 pub m: Vec<f64>,
179 pub v: Vec<f64>,
180}
181
182impl AdamState {
183 pub fn new(n_params: usize, lr: f64) -> Self {
184 Self {
185 lr,
186 beta1: 0.9,
187 beta2: 0.999,
188 eps: 1e-8,
189 t: 0,
190 m: vec![0.0; n_params],
191 v: vec![0.0; n_params],
192 }
193 }
194
195 #[inline]
197 pub fn n_params(&self) -> usize {
198 self.m.len()
199 }
200
201 #[inline]
203 pub fn m_at(&self, p: ParamIdx) -> f64 {
204 self.m[p.index()]
205 }
206
207 #[inline]
209 pub fn set_m_at(&mut self, p: ParamIdx, value: f64) {
210 self.m[p.index()] = value;
211 }
212
213 #[inline]
215 pub fn v_at(&self, p: ParamIdx) -> f64 {
216 self.v[p.index()]
217 }
218
219 #[inline]
221 pub fn set_v_at(&mut self, p: ParamIdx, value: f64) {
222 self.v[p.index()] = value;
223 }
224}
225
226pub fn adam_step(params: &mut [f64], grads: &[f64], state: &mut AdamState) {
232 state.t += 1;
233 let t = state.t as f64;
234 let n = params.len();
235 for i in 0..n {
236 let p = ParamIdx::from_usize(i);
237 let new_m = state.beta1 * state.m_at(p) + (1.0 - state.beta1) * grads[i];
238 let new_v = state.beta2 * state.v_at(p) + (1.0 - state.beta2) * grads[i] * grads[i];
239 state.set_m_at(p, new_m);
240 state.set_v_at(p, new_v);
241 let m_hat = new_m / (1.0 - state.beta1.powf(t));
242 let v_hat = new_v / (1.0 - state.beta2.powf(t));
243 params[i] -= state.lr * m_hat / (v_hat.sqrt() + state.eps);
244 }
245}
246
/// Counts of the four outcomes of a binary classification.
#[derive(Debug, Clone)]
pub struct ConfusionMatrix {
    /// True positives: predicted `true`, actually `true`.
    pub tp: usize,
    /// False positives: predicted `true`, actually `false`.
    pub fp: usize,
    /// True negatives: predicted `false`, actually `false`.
    pub tn: usize,
    /// False negatives: predicted `false`, actually `true`.
    /// (Named `fn_count` because `fn` is a reserved keyword.)
    pub fn_count: usize,
}
259
260pub fn confusion_matrix(predicted: &[bool], actual: &[bool]) -> ConfusionMatrix {
262 let mut tp = 0;
263 let mut fp = 0;
264 let mut tn = 0;
265 let mut fn_count = 0;
266 for i in 0..predicted.len().min(actual.len()) {
267 match (predicted[i], actual[i]) {
268 (true, true) => tp += 1,
269 (true, false) => fp += 1,
270 (false, true) => fn_count += 1,
271 (false, false) => tn += 1,
272 }
273 }
274 ConfusionMatrix { tp, fp, tn, fn_count }
275}
276
277pub fn precision(cm: &ConfusionMatrix) -> f64 {
279 let denom = cm.tp + cm.fp;
280 if denom == 0 { 0.0 } else { cm.tp as f64 / denom as f64 }
281}
282
283pub fn recall(cm: &ConfusionMatrix) -> f64 {
285 let denom = cm.tp + cm.fn_count;
286 if denom == 0 { 0.0 } else { cm.tp as f64 / denom as f64 }
287}
288
289pub fn f1_score(cm: &ConfusionMatrix) -> f64 {
291 let p = precision(cm);
292 let r = recall(cm);
293 if p + r == 0.0 { 0.0 } else { 2.0 * p * r / (p + r) }
294}
295
296pub fn accuracy(cm: &ConfusionMatrix) -> f64 {
298 let total = cm.tp + cm.fp + cm.tn + cm.fn_count;
299 if total == 0 { 0.0 } else { (cm.tp + cm.tn) as f64 / total as f64 }
300}
301
/// Area under the ROC curve, computed with the trapezoidal rule.
///
/// Samples are visited in order of descending score. Samples sharing the same
/// score form a single threshold: the curve only advances once the whole tie
/// group has been counted, so the result does not depend on the input order of
/// tied samples (a tie between one positive and one negative contributes 0.5).
///
/// # Errors
/// Returns an error if the slices differ in length, are empty, or if the
/// labels do not contain both a positive and a negative example.
pub fn auc_roc(scores: &[f64], labels: &[bool]) -> Result<f64, String> {
    if scores.len() != labels.len() {
        return Err("auc_roc: scores and labels must have same length".into());
    }
    let n = scores.len();
    if n == 0 {
        return Err("auc_roc: empty data".into());
    }

    let pos_count = labels.iter().filter(|&&l| l).count();
    let neg_count = n - pos_count;
    if pos_count == 0 || neg_count == 0 {
        return Err("auc_roc: need both positive and negative labels".into());
    }

    // Descending by score; the index tie-break keeps the sort deterministic.
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&a, &b| scores[b].total_cmp(&scores[a]).then(a.cmp(&b)));

    let mut auc = 0.0;
    let mut tp = 0.0;
    let mut fp = 0.0;
    let mut prev_fpr = 0.0;
    let mut prev_tpr = 0.0;

    for (rank, &idx) in order.iter().enumerate() {
        if labels[idx] {
            tp += 1.0;
        } else {
            fp += 1.0;
        }

        // Delay the ROC point until the last member of a tie group. Numeric
        // `==` groups -0.0 with 0.0 and never groups NaNs together.
        let tied_with_next = order
            .get(rank + 1)
            .map_or(false, |&next| scores[next] == scores[idx]);
        if tied_with_next {
            continue;
        }

        let tpr = tp / pos_count as f64;
        let fpr = fp / neg_count as f64;
        auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0;
        prev_fpr = fpr;
        prev_tpr = tpr;
    }
    Ok(auc)
}
341
342pub fn kfold_indices(n: usize, k: usize, seed: u64) -> Vec<(Vec<usize>, Vec<usize>)> {
345 let mut rng = cjc_repro::Rng::seeded(seed);
346 let mut indices: Vec<usize> = (0..n).collect();
348 for i in (1..n).rev() {
349 let j = (rng.next_u64() as usize) % (i + 1);
350 indices.swap(i, j);
351 }
352 let fold_size = n / k;
353 let mut folds = Vec::with_capacity(k);
354 for fold in 0..k {
355 let start = fold * fold_size;
356 let end = if fold == k - 1 { n } else { start + fold_size };
357 let test: Vec<usize> = indices[start..end].to_vec();
358 let train: Vec<usize> = indices[..start].iter()
359 .chain(indices[end..].iter())
360 .copied()
361 .collect();
362 folds.push((train, test));
363 }
364 folds
365}
366
367pub fn train_test_split(n: usize, test_fraction: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
369 let mut rng = cjc_repro::Rng::seeded(seed);
370 let mut indices: Vec<usize> = (0..n).collect();
371 for i in (1..n).rev() {
372 let j = (rng.next_u64() as usize) % (i + 1);
373 indices.swap(i, j);
374 }
375 let test_size = ((n as f64) * test_fraction).round() as usize;
376 let test = indices[..test_size].to_vec();
377 let train = indices[test_size..].to_vec();
378 (train, test)
379}
380
381pub fn bootstrap(data: &[f64], n_resamples: usize, stat_fn: usize, seed: u64) -> Result<(f64, f64, f64, f64), String> {
389 if data.is_empty() { return Err("bootstrap: empty data".into()); }
390 let n = data.len();
391
392 let point = compute_stat(data, stat_fn)?;
394
395 let mut rng = cjc_repro::Rng::seeded(seed);
397 let mut stats = Vec::with_capacity(n_resamples);
398 let mut resample = Vec::with_capacity(n);
399
400 for _ in 0..n_resamples {
401 resample.clear();
402 for _ in 0..n {
403 let idx = (rng.next_u64() as usize) % n;
404 resample.push(data[idx]);
405 }
406 stats.push(compute_stat(&resample, stat_fn)?);
407 }
408
409 stats.sort_by(|a, b| a.total_cmp(b));
411
412 let ci_lower = stats[(n_resamples as f64 * 0.025) as usize];
413 let ci_upper = stats[(n_resamples as f64 * 0.975).min((n_resamples - 1) as f64) as usize];
414
415 let mean_stats: f64 = {
417 let mut acc = cjc_repro::KahanAccumulatorF64::new();
418 for &s in &stats { acc.add(s); }
419 acc.finalize() / n_resamples as f64
420 };
421 let se = {
422 let mut acc = cjc_repro::KahanAccumulatorF64::new();
423 for &s in &stats { let d = s - mean_stats; acc.add(d * d); }
424 (acc.finalize() / (n_resamples as f64 - 1.0)).sqrt()
425 };
426
427 Ok((point, ci_lower, ci_upper, se))
428}
429
430fn compute_stat(data: &[f64], stat_fn: usize) -> Result<f64, String> {
431 match stat_fn {
432 0 => {
433 let mut acc = cjc_repro::KahanAccumulatorF64::new();
435 for &x in data { acc.add(x); }
436 Ok(acc.finalize() / data.len() as f64)
437 }
438 1 => {
439 let mut sorted = data.to_vec();
441 sorted.sort_by(|a, b| a.total_cmp(b));
442 let n = sorted.len();
443 if n % 2 == 0 {
444 Ok((sorted[n/2 - 1] + sorted[n/2]) / 2.0)
445 } else {
446 Ok(sorted[n/2])
447 }
448 }
449 _ => Err(format!("bootstrap: unknown stat_fn {}", stat_fn)),
450 }
451}
452
453pub fn permutation_test(x: &[f64], y: &[f64], n_perms: usize, seed: u64) -> Result<(f64, f64), String> {
456 if x.is_empty() || y.is_empty() { return Err("permutation_test: empty group".into()); }
457
458 let nx = x.len();
459 let combined: Vec<f64> = x.iter().chain(y.iter()).copied().collect();
460 let n = combined.len();
461
462 let mean_x = compute_stat(x, 0)?;
464 let mean_y = compute_stat(y, 0)?;
465 let observed = (mean_x - mean_y).abs();
466
467 let mut rng = cjc_repro::Rng::seeded(seed);
469 let mut count_extreme = 0usize;
470 let mut perm = combined.clone();
471
472 for _ in 0..n_perms {
473 for i in (1..n).rev() {
475 let j = (rng.next_u64() as usize) % (i + 1);
476 perm.swap(i, j);
477 }
478 let perm_mean_x = compute_stat(&perm[..nx], 0)?;
479 let perm_mean_y = compute_stat(&perm[nx..], 0)?;
480 if (perm_mean_x - perm_mean_y).abs() >= observed {
481 count_extreme += 1;
482 }
483 }
484
485 let p_value = count_extreme as f64 / n_perms as f64;
486 Ok((observed, p_value))
487}
488
489pub fn stratified_split(labels: &[i64], test_frac: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
493 use std::collections::BTreeMap;
494
495 let n = labels.len();
496 let mut groups: BTreeMap<i64, Vec<usize>> = BTreeMap::new();
498 for (i, &label) in labels.iter().enumerate() {
499 groups.entry(label).or_default().push(i);
500 }
501
502 let mut train = Vec::with_capacity(n);
503 let mut test = Vec::with_capacity(n);
504 let mut rng = cjc_repro::Rng::seeded(seed);
505
506 for (_label, mut indices) in groups {
507 let m = indices.len();
509 for i in (1..m).rev() {
510 let j = (rng.next_u64() as usize) % (i + 1);
511 indices.swap(i, j);
512 }
513 let n_test = ((m as f64 * test_frac).round() as usize).max(if m > 1 { 1 } else { 0 });
514 let n_test = n_test.min(m);
515 test.extend_from_slice(&indices[..n_test]);
516 train.extend_from_slice(&indices[n_test..]);
517 }
518
519 train.sort();
521 test.sort();
522 (train, test)
523}
524
/// Applies element-wise batch normalisation with fixed running statistics.
///
/// Each output is `gamma[i] * (x[i] - running_mean[i]) / sqrt(running_var[i] + eps) + beta[i]`.
///
/// # Errors
/// Returns an error unless all five slices have the same length.
pub fn batch_norm(
    x: &[f64],
    running_mean: &[f64],
    running_var: &[f64],
    gamma: &[f64],
    beta: &[f64],
    eps: f64,
) -> Result<Vec<f64>, String> {
    let n = x.len();
    let lengths_agree = [running_mean.len(), running_var.len(), gamma.len(), beta.len()]
        .iter()
        .all(|&len| len == n);
    if !lengths_agree {
        return Err("batch_norm: all arrays must have same length".into());
    }

    let output = (0..n)
        .map(|i| {
            let std_dev = (running_var[i] + eps).sqrt();
            let normalised = (x[i] - running_mean[i]) / std_dev;
            gamma[i] * normalised + beta[i]
        })
        .collect();
    Ok(output)
}
550
551pub fn dropout_mask(n: usize, drop_prob: f64, seed: u64) -> Vec<f64> {
554 let mut rng = cjc_repro::Rng::seeded(seed);
555 let scale = if drop_prob < 1.0 { 1.0 / (1.0 - drop_prob) } else { 0.0 };
556 let mut mask = Vec::with_capacity(n);
557 for _ in 0..n {
558 let r = (rng.next_u64() as f64) / (u64::MAX as f64);
559 if r < drop_prob {
560 mask.push(0.0);
561 } else {
562 mask.push(scale);
563 }
564 }
565 mask
566}
567
/// Multiplies `data` element-wise by a dropout `mask`.
///
/// # Errors
/// Returns an error if `data` and `mask` differ in length.
pub fn apply_dropout(data: &[f64], mask: &[f64]) -> Result<Vec<f64>, String> {
    if data.len() != mask.len() {
        return Err("apply_dropout: data and mask must have same length".into());
    }
    let mut masked = Vec::with_capacity(data.len());
    for (value, keep) in data.iter().zip(mask) {
        masked.push(value * keep);
    }
    Ok(masked)
}
575
/// Step-decay schedule: `initial_lr * decay_rate^(epoch / step_size)`.
///
/// The exponent uses integer division, so the rate drops once every
/// `step_size` epochs.
///
/// * `step_size == 0` means "never decay" and returns `initial_lr`
///   (previously this panicked with a division by zero).
/// * The step count saturates at `i32::MAX` instead of wrapping, so very
///   large epoch counts can no longer produce a negative exponent.
pub fn lr_step_decay(initial_lr: f64, decay_rate: f64, epoch: usize, step_size: usize) -> f64 {
    if step_size == 0 {
        return initial_lr;
    }
    let steps = epoch / step_size;
    let exponent = steps.min(i32::MAX as usize) as i32;
    initial_lr * decay_rate.powi(exponent)
}
581
/// Cosine-annealing schedule between `max_lr` (at epoch 0) and `min_lr`
/// (at `epoch == total_epochs`):
/// `min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * epoch / total_epochs))`.
///
/// When `total_epochs == 0` the schedule has no length; it is treated as
/// already finished and `min_lr` is returned (previously the ratio was
/// `0/0` or `x/0` and the result was NaN).
pub fn lr_cosine(max_lr: f64, min_lr: f64, epoch: usize, total_epochs: usize) -> f64 {
    if total_epochs == 0 {
        return min_lr;
    }
    let ratio = epoch as f64 / total_epochs as f64;
    min_lr + 0.5 * (max_lr - min_lr) * (1.0 + (std::f64::consts::PI * ratio).cos())
}
588
/// Linear warm-up: scales `initial_lr` by `epoch / warmup_epochs`, capped at 1.
///
/// With `warmup_epochs == 0` there is no warm-up and `initial_lr` is returned.
pub fn lr_linear_warmup(initial_lr: f64, epoch: usize, warmup_epochs: usize) -> f64 {
    match warmup_epochs {
        0 => initial_lr,
        span => {
            let progress = epoch as f64 / span as f64;
            initial_lr * progress.min(1.0)
        }
    }
}
597
598pub fn l1_penalty(params: &[f64], lambda: f64) -> f64 {
600 let mut acc = KahanAccumulatorF64::new();
601 for &p in params {
602 acc.add(p.abs());
603 }
604 lambda * acc.finalize()
605}
606
607pub fn l2_penalty(params: &[f64], lambda: f64) -> f64 {
609 let mut acc = KahanAccumulatorF64::new();
610 for &p in params {
611 acc.add(p * p);
612 }
613 0.5 * lambda * acc.finalize()
614}
615
/// (Sub)gradient of the L1 penalty: `lambda * sign(p)` per parameter.
///
/// Zero (of either sign) and NaN parameters map to `0.0`.
pub fn l1_grad(params: &[f64], lambda: f64) -> Vec<f64> {
    let mut grad = Vec::with_capacity(params.len());
    for p in params {
        let g = match p.partial_cmp(&0.0) {
            Some(std::cmp::Ordering::Greater) => lambda,
            Some(std::cmp::Ordering::Less) => -lambda,
            // Equal to zero, or NaN (no ordering).
            _ => 0.0,
        };
        grad.push(g);
    }
    grad
}
622
/// Gradient of the L2 penalty `0.5 * lambda * sum(p^2)`: `lambda * p` per parameter.
pub fn l2_grad(params: &[f64], lambda: f64) -> Vec<f64> {
    let mut grad = Vec::with_capacity(params.len());
    for &p in params {
        grad.push(lambda * p);
    }
    grad
}
627
/// Tracks validation loss to decide when training should stop early.
pub struct EarlyStoppingState {
    /// Number of consecutive non-improving checks tolerated before stopping.
    pub patience: usize,
    /// Minimum decrease below `best_loss` that counts as an improvement.
    pub min_delta: f64,
    /// Lowest loss accepted as an improvement so far.
    pub best_loss: f64,
    /// Consecutive checks without improvement.
    pub wait: usize,
    /// Latched to `true` once `wait` reaches `patience`.
    pub stopped: bool,
}

impl EarlyStoppingState {
    /// Creates a fresh tracker with an infinite best loss and no waiting.
    pub fn new(patience: usize, min_delta: f64) -> Self {
        Self {
            best_loss: f64::INFINITY,
            wait: 0,
            stopped: false,
            patience,
            min_delta,
        }
    }

    /// Records `current_loss` and returns whether training should stop.
    ///
    /// A loss strictly below `best_loss - min_delta` becomes the new best and
    /// resets the wait counter; anything else increments it. Once the counter
    /// reaches `patience` the stopped flag is set and stays set.
    pub fn check(&mut self, current_loss: f64) -> bool {
        let threshold = self.best_loss - self.min_delta;
        let improved = current_loss < threshold;
        if improved {
            self.best_loss = current_loss;
        }
        self.wait = if improved { 0 } else { self.wait + 1 };
        self.stopped |= self.wait >= self.patience;
        self.stopped
    }
}
662
/// Principal component analysis of a 2-D data matrix via SVD of the
/// column-centred data.
///
/// `data` is read as a row-major `[n_samples, n_features]` matrix.
///
/// Returns `(transformed, components, explained_variance_ratio)`:
/// * `transformed` — `[n_samples, k]`, built as `U[:, j] * s[j]`.
/// * `components` — `[k, n_features]`, the first `k` rows of `Vt`.
/// * `explained_variance_ratio` — for each kept component,
///   `(s_i^2 / denom)` divided by the sum of that quantity over all singular
///   values; all zeros if that sum is not above `1e-15`.
///
/// NOTE(review): keeping the *first* `k` singular triplets assumes `svd()`
/// returns singular values in descending order — confirm against `Tensor::svd`.
///
/// # Errors
/// Returns `RuntimeError::InvalidOperation` if `data` is not 2-D, has a zero
/// dimension, or `n_components` is outside `[1, min(n_samples, n_features)]`.
/// Errors from `Tensor::from_vec` and `svd` are propagated.
pub fn pca(
    data: &Tensor,
    n_components: usize,
) -> Result<(Tensor, Tensor, Vec<f64>), RuntimeError> {
    if data.ndim() != 2 {
        return Err(RuntimeError::InvalidOperation(
            "PCA requires a 2D data matrix".to_string(),
        ));
    }
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];

    if n_samples == 0 || n_features == 0 {
        return Err(RuntimeError::InvalidOperation(
            "PCA: empty data matrix".to_string(),
        ));
    }
    if n_components == 0 || n_components > n_features.min(n_samples) {
        return Err(RuntimeError::InvalidOperation(format!(
            "PCA: n_components ({}) must be in [1, min(n_samples, n_features) = {}]",
            n_components,
            n_features.min(n_samples)
        )));
    }

    let raw = data.to_vec();

    // Per-feature (column) means, each summed with a binned accumulator.
    let mut means = vec![0.0f64; n_features];
    for j in 0..n_features {
        let mut acc = BinnedAccumulatorF64::new();
        for i in 0..n_samples {
            acc.add(raw[i * n_features + j]);
        }
        means[j] = acc.finalize() / n_samples as f64;
    }

    // Subtract the column mean from every entry.
    let mut centered = vec![0.0f64; n_samples * n_features];
    for i in 0..n_samples {
        for j in 0..n_features {
            centered[i * n_features + j] = raw[i * n_features + j] - means[j];
        }
    }
    let centered_tensor = Tensor::from_vec(centered, &[n_samples, n_features])?;

    let (u, s, vt) = centered_tensor.svd()?;
    // Never keep more components than singular values were returned.
    let k = n_components.min(s.len());

    // Principal axes: the first k rows of Vt.
    let vt_data = vt.to_vec();
    let vt_cols = vt.shape()[1];
    let mut components = vec![0.0f64; k * n_features];
    for i in 0..k {
        for j in 0..n_features {
            components[i * n_features + j] = vt_data[i * vt_cols + j];
        }
    }

    // Sample-variance denominator; falls back to 1 for a single sample to
    // avoid dividing by zero.
    let denom = if n_samples > 1 {
        (n_samples - 1) as f64
    } else {
        1.0
    };

    // Total variance over all singular values, not just the kept k.
    let mut total_var_acc = BinnedAccumulatorF64::new();
    for &si in &s {
        total_var_acc.add(si * si / denom);
    }
    let total_var = total_var_acc.finalize();

    let explained_variance_ratio: Vec<f64> = if total_var > 1e-15 {
        s[..k]
            .iter()
            .map(|&si| (si * si / denom) / total_var)
            .collect()
    } else {
        // Effectively constant data: report zero ratios instead of dividing by ~0.
        vec![0.0; k]
    };

    // Projected data: column j of U scaled by singular value j.
    let u_data = u.to_vec();
    let u_cols = u.shape()[1];
    let mut transformed = vec![0.0f64; n_samples * k];
    for i in 0..n_samples {
        for j in 0..k {
            transformed[i * k + j] = u_data[i * u_cols + j] * s[j];
        }
    }

    Ok((
        Tensor::from_vec(transformed, &[n_samples, k])?,
        Tensor::from_vec(components, &[k, n_features])?,
        explained_variance_ratio,
    ))
}
777
/// Persistent state for the L-BFGS optimizer.
pub struct LbfgsState {
    /// Initial step length handed to the line search.
    pub lr: f64,
    /// Maximum number of `(s, y)` correction pairs kept in the history.
    pub m: usize,
    /// Stored parameter differences, oldest first.
    pub s_history: Vec<Vec<f64>>,
    /// Stored gradient differences, oldest first.
    pub y_history: Vec<Vec<f64>>,
    /// Parameters produced by the most recent step, if any.
    pub prev_params: Option<Vec<f64>>,
    /// Gradient at the most recent step's parameters, if any.
    pub prev_grad: Option<Vec<f64>>,
}

impl LbfgsState {
    /// Creates a state with empty history and no previous iterate.
    pub fn new(lr: f64, m: usize) -> Self {
        let s_history: Vec<Vec<f64>> = Vec::new();
        let y_history: Vec<Vec<f64>> = Vec::new();
        Self {
            lr,
            m,
            s_history,
            y_history,
            prev_params: None,
            prev_grad: None,
        }
    }
}
811
812fn kahan_dot(a: &[f64], b: &[f64]) -> f64 {
814 debug_assert_eq!(a.len(), b.len());
815 let mut acc = KahanAccumulatorF64::new();
816 for (&ai, &bi) in a.iter().zip(b.iter()) {
817 acc.add(ai * bi);
818 }
819 acc.finalize()
820}
821
/// Line search along `direction` from `params`, looking for a step length
/// that satisfies the strong Wolfe conditions (`c1 = 1e-4`, `c2 = 0.9`).
///
/// * `f` — objective returning `(value, gradient)` at a point; it is called
///   once per trial step (and once more up front, see below).
/// * `f0`, `g0` — objective value and gradient at `params`.
/// * `alpha_init` — first trial step length.
///
/// Returns `(alpha, new_params, value, gradient)` for the accepted step. If
/// no step is accepted within 30 bracketing iterations, the trial point with
/// the lowest objective value seen so far is returned instead.
pub fn wolfe_line_search<F>(
    params: &[f64],
    direction: &[f64],
    f: &mut F,
    f0: f64,
    g0: &[f64],
    alpha_init: f64,
) -> (f64, Vec<f64>, f64, Vec<f64>)
where
    F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
    let c1 = 1e-4_f64;
    let c2 = 0.9_f64;
    // Directional derivative at the starting point.
    let derphi0 = kahan_dot(g0, direction);
    // Point reached by moving `alpha` along the search direction.
    let step = |alpha: f64| -> Vec<f64> {
        params.iter().zip(direction.iter()).map(|(&p, &d)| p + alpha * d).collect()
    };

    let max_iter = 30;
    let mut alpha_lo = 0.0_f64;
    let mut alpha_hi = f64::INFINITY;
    let mut phi_lo = f0;
    let mut dphi_lo = derphi0;
    let mut alpha = alpha_init;

    // Fallback result: best point seen so far, seeded with the initial step.
    // NOTE(review): the first loop iteration evaluates `f` at this same point again.
    let mut best_alpha = alpha_init;
    let mut best_params = step(alpha_init);
    let (mut best_val, mut best_grad) = f(&best_params);

    for _iter in 0..max_iter {
        let x_alpha = step(alpha);
        let (phi_alpha, grad_alpha) = f(&x_alpha);
        let dphi_alpha = kahan_dot(&grad_alpha, direction);

        if phi_alpha < best_val {
            best_alpha = alpha;
            best_params = x_alpha.clone();
            best_val = phi_alpha;
            best_grad = grad_alpha.clone();
        }

        // Sufficient decrease violated, or no improvement over the previous
        // trial: a minimiser is bracketed in [alpha_lo, alpha].
        if phi_alpha > f0 + c1 * alpha * derphi0 || (phi_alpha >= phi_lo && alpha_lo > 0.0) {
            alpha_hi = alpha;
            let (za, zp, zv, zg) = wolfe_zoom(
                params, direction, f, f0, derphi0, c1, c2,
                alpha_lo, alpha_hi, phi_lo, dphi_lo,
            );
            return (za, zp, zv, zg);
        }

        // Curvature condition holds as well: accept this step.
        if dphi_alpha.abs() <= c2 * derphi0.abs() {
            return (alpha, x_alpha, phi_alpha, grad_alpha);
        }

        // Slope turned non-negative: the minimiser lies between the previous
        // trial and this one, with this trial as the better ("lo") end.
        if dphi_alpha >= 0.0 {
            let (za, zp, zv, zg) = wolfe_zoom(
                params, direction, f, f0, derphi0, c1, c2,
                alpha, alpha_lo, phi_alpha, dphi_alpha,
            );
            return (za, zp, zv, zg);
        }

        alpha_lo = alpha;
        phi_lo = phi_alpha;
        dphi_lo = dphi_alpha;

        // `alpha_hi` is only assigned immediately before returning, so it is
        // still infinite here and the step is doubled (capped at 1e8).
        alpha = if alpha_hi.is_finite() {
            (alpha_lo + alpha_hi) * 0.5
        } else {
            (alpha * 2.0).min(1e8)
        };
    }

    (best_alpha, best_params, best_val, best_grad)
}
926
/// Zoom phase of the Wolfe line search: bisects the bracket
/// `[alpha_lo, alpha_hi]` until a step satisfying both the sufficient-decrease
/// and curvature conditions is found.
///
/// `alpha_lo` is the bracket end with the lower objective value (`phi_lo`);
/// it is not necessarily the smaller step length. `_dphi_lo` is accepted for
/// signature symmetry but unused.
///
/// Returns `(alpha, new_params, value, gradient)`. If no step is accepted
/// within 20 bisections, or the bracket shrinks below `1e-14`, the trial with
/// the lowest objective value seen (including `alpha_lo` itself) is returned.
// The argument list mirrors the line-search state; grouping it into a struct
// would only add indirection for this private helper.
#[allow(clippy::too_many_arguments)]
fn wolfe_zoom<F>(
    params: &[f64],
    direction: &[f64],
    f: &mut F,
    f0: f64,
    derphi0: f64,
    c1: f64,
    c2: f64,
    mut alpha_lo: f64,
    mut alpha_hi: f64,
    mut phi_lo: f64,
    _dphi_lo: f64,
) -> (f64, Vec<f64>, f64, Vec<f64>)
where
    F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
    // Point reached by moving `alpha` along the search direction.
    let step = |alpha: f64| -> Vec<f64> {
        params.iter().zip(direction.iter()).map(|(&p, &d)| p + alpha * d).collect()
    };

    let max_zoom = 20;
    // Fallback result, seeded by evaluating the objective at `alpha_lo`.
    let mut best_alpha = alpha_lo;
    let mut best_x = step(alpha_lo);
    let (mut best_val, mut best_grad) = f(&best_x);

    for _i in 0..max_zoom {
        // Plain bisection of the current bracket.
        let alpha_j = (alpha_lo + alpha_hi) * 0.5;
        let x_j = step(alpha_j);
        let (phi_j, grad_j) = f(&x_j);
        let dphi_j = kahan_dot(&grad_j, direction);

        if phi_j < best_val {
            best_alpha = alpha_j;
            best_x = x_j.clone();
            best_val = phi_j;
            best_grad = grad_j.clone();
        }

        if phi_j > f0 + c1 * alpha_j * derphi0 || phi_j >= phi_lo {
            // Trial is no better than the `lo` end: shrink from the `hi` side.
            alpha_hi = alpha_j;
        } else {
            if dphi_j.abs() <= c2 * derphi0.abs() {
                return (alpha_j, x_j, phi_j, grad_j);
            }
            // Slope points away from `lo`, so the minimiser is between the
            // trial and the old `lo`; the old `lo` becomes the new `hi`.
            if dphi_j * (alpha_hi - alpha_lo) >= 0.0 {
                alpha_hi = alpha_lo;
            }
            alpha_lo = alpha_j;
            phi_lo = phi_j;
        }

        // Bracket has collapsed; further bisection cannot make progress.
        if (alpha_hi - alpha_lo).abs() < 1e-14 {
            break;
        }
    }

    (best_alpha, best_x, best_val, best_grad)
}
989
/// Performs one L-BFGS iteration: builds a search direction from the stored
/// curvature history (two-loop recursion), runs a Wolfe line search along it,
/// and records the new correction pair.
///
/// * `params`, `grads` — current point and its gradient.
/// * `state` — history of `(s, y)` pairs; updated in place. `state.lr` is the
///   initial step length for the line search.
/// * `f` — objective returning `(value, gradient)`; it is evaluated once at
///   `params` here and then repeatedly by the line search.
///
/// Returns `(new_params, new_grads, is_descent)`, where `is_descent` is
/// `false` when the quasi-Newton direction was rejected (non-negative or
/// non-finite slope) and a unit-length negative gradient was used instead.
pub fn lbfgs_step<F>(
    params: &[f64],
    grads: &[f64],
    state: &mut LbfgsState,
    mut f: F,
) -> (Vec<f64>, Vec<f64>, bool)
where
    F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
    let n = params.len();
    debug_assert_eq!(grads.len(), n);

    let hist_len = state.s_history.len();
    let mut q: Vec<f64> = grads.to_vec();
    let mut alphas = vec![0.0_f64; hist_len];
    let mut rhos = vec![0.0_f64; hist_len];

    // First loop: newest to oldest pair.
    for i in (0..hist_len).rev() {
        let sy = kahan_dot(&state.s_history[i], &state.y_history[i]);
        // A (near-)zero curvature pair contributes nothing rather than dividing by zero.
        rhos[i] = if sy.abs() < 1e-300 { 0.0 } else { 1.0 / sy };
        let sq = kahan_dot(&state.s_history[i], &q);
        alphas[i] = rhos[i] * sq;
        for j in 0..n {
            q[j] -= alphas[i] * state.y_history[i][j];
        }
    }

    // Initial Hessian scaling from the most recent pair: (s.y) / (y.y).
    let scale = if hist_len > 0 {
        let last = hist_len - 1;
        let sy = kahan_dot(&state.s_history[last], &state.y_history[last]);
        let yy = kahan_dot(&state.y_history[last], &state.y_history[last]);
        if yy.abs() < 1e-300 { 1.0 } else { sy / yy }
    } else {
        1.0
    };
    let mut r: Vec<f64> = q.iter().map(|&qi| scale * qi).collect();

    // Second loop: oldest to newest pair.
    for i in 0..hist_len {
        let yr = kahan_dot(&state.y_history[i], &r);
        let beta = rhos[i] * yr;
        let diff = alphas[i] - beta;
        for j in 0..n {
            r[j] += diff * state.s_history[i][j];
        }
    }

    let direction: Vec<f64> = r.iter().map(|&ri| -ri).collect();

    // Fall back to normalised steepest descent if the direction does not
    // point downhill (or the slope is not finite).
    let descent_check = kahan_dot(&direction, grads);
    let (direction, is_descent) = if descent_check >= 0.0 || !descent_check.is_finite() {
        // The floor on the norm avoids dividing by zero for a zero gradient.
        let norm_g = kahan_dot(grads, grads).sqrt().max(1e-300);
        (grads.iter().map(|&g| -g / norm_g).collect::<Vec<f64>>(), false)
    } else {
        (direction, true)
    };

    let (f0, _) = f(params);
    let (_, new_params, _, new_grads) = wolfe_line_search(
        params,
        &direction,
        &mut f,
        f0,
        grads,
        state.lr,
    );

    // Correction pair: parameter change and gradient change for this step.
    let s_k: Vec<f64> = new_params.iter().zip(params.iter()).map(|(&np, &p)| np - p).collect();
    let y_k: Vec<f64> = new_grads.iter().zip(grads.iter()).map(|(&ng, &g)| ng - g).collect();

    // Only pairs with positive curvature (s.y > 0) are stored.
    let sy = kahan_dot(&s_k, &y_k);
    if sy > 1e-300 {
        state.s_history.push(s_k);
        state.y_history.push(y_k);
        // Drop the oldest pair once the history exceeds `m` entries.
        if state.s_history.len() > state.m {
            state.s_history.remove(0);
            state.y_history.remove(0);
        }
    }

    state.prev_params = Some(new_params.clone());
    state.prev_grad = Some(new_grads.clone());

    (new_params, new_grads, is_descent)
}
1106
1107pub fn lstm_cell(
1126 x: &Tensor,
1127 h_prev: &Tensor,
1128 c_prev: &Tensor,
1129 w_ih: &Tensor,
1130 w_hh: &Tensor,
1131 b_ih: &Tensor,
1132 b_hh: &Tensor,
1133) -> Result<(Tensor, Tensor), String> {
1134 let map_err = |e: crate::error::RuntimeError| format!("{e}");
1135
1136 if x.ndim() != 2 {
1138 return Err("lstm_cell: x must be 2-D [batch, input_size]".into());
1139 }
1140 if h_prev.ndim() != 2 {
1141 return Err("lstm_cell: h_prev must be 2-D [batch, hidden_size]".into());
1142 }
1143 if c_prev.ndim() != 2 {
1144 return Err("lstm_cell: c_prev must be 2-D [batch, hidden_size]".into());
1145 }
1146 let hidden_size = h_prev.shape()[1];
1147 if w_ih.ndim() != 2 || w_ih.shape()[0] != 4 * hidden_size {
1148 return Err(format!(
1149 "lstm_cell: w_ih must be [4*hidden_size, input_size], got {:?}",
1150 w_ih.shape()
1151 ));
1152 }
1153 if w_hh.ndim() != 2 || w_hh.shape()[0] != 4 * hidden_size {
1154 return Err(format!(
1155 "lstm_cell: w_hh must be [4*hidden_size, hidden_size], got {:?}",
1156 w_hh.shape()
1157 ));
1158 }
1159 if b_ih.len() != 4 * hidden_size {
1160 return Err(format!(
1161 "lstm_cell: b_ih must have length 4*hidden_size={}, got {}",
1162 4 * hidden_size,
1163 b_ih.len()
1164 ));
1165 }
1166 if b_hh.len() != 4 * hidden_size {
1167 return Err(format!(
1168 "lstm_cell: b_hh must have length 4*hidden_size={}, got {}",
1169 4 * hidden_size,
1170 b_hh.len()
1171 ));
1172 }
1173
1174 let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
1177 let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
1178 let gates = gates_ih.add(&gates_hh).map_err(map_err)?;
1179
1180 let chunks = gates.chunk(4, 1).map_err(map_err)?;
1182 let gates_i = &chunks[0];
1183 let gates_f = &chunks[1];
1184 let gates_g = &chunks[2];
1185 let gates_o = &chunks[3];
1186
1187 let i = gates_i.sigmoid();
1189 let f = gates_f.sigmoid();
1190 let g = gates_g.tanh_activation();
1191 let o = gates_o.sigmoid();
1192
1193 let fc = f.mul_elem(c_prev).map_err(map_err)?;
1195 let ig = i.mul_elem(&g).map_err(map_err)?;
1196 let c_new = fc.add(&ig).map_err(map_err)?;
1197
1198 let c_tanh = c_new.tanh_activation();
1200 let h_new = o.mul_elem(&c_tanh).map_err(map_err)?;
1201
1202 Ok((h_new, c_new))
1203}
1204
1205pub fn gru_cell(
1223 x: &Tensor,
1224 h_prev: &Tensor,
1225 w_ih: &Tensor,
1226 w_hh: &Tensor,
1227 b_ih: &Tensor,
1228 b_hh: &Tensor,
1229) -> Result<Tensor, String> {
1230 let map_err = |e: crate::error::RuntimeError| format!("{e}");
1231
1232 if x.ndim() != 2 {
1234 return Err("gru_cell: x must be 2-D [batch, input_size]".into());
1235 }
1236 if h_prev.ndim() != 2 {
1237 return Err("gru_cell: h_prev must be 2-D [batch, hidden_size]".into());
1238 }
1239 let hidden_size = h_prev.shape()[1];
1240 if w_ih.ndim() != 2 || w_ih.shape()[0] != 3 * hidden_size {
1241 return Err(format!(
1242 "gru_cell: w_ih must be [3*hidden_size, input_size], got {:?}",
1243 w_ih.shape()
1244 ));
1245 }
1246 if w_hh.ndim() != 2 || w_hh.shape()[0] != 3 * hidden_size {
1247 return Err(format!(
1248 "gru_cell: w_hh must be [3*hidden_size, hidden_size], got {:?}",
1249 w_hh.shape()
1250 ));
1251 }
1252 if b_ih.len() != 3 * hidden_size {
1253 return Err(format!(
1254 "gru_cell: b_ih must have length 3*hidden_size={}, got {}",
1255 3 * hidden_size,
1256 b_ih.len()
1257 ));
1258 }
1259 if b_hh.len() != 3 * hidden_size {
1260 return Err(format!(
1261 "gru_cell: b_hh must have length 3*hidden_size={}, got {}",
1262 3 * hidden_size,
1263 b_hh.len()
1264 ));
1265 }
1266
1267 let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
1270 let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
1271
1272 let ih_chunks = gates_ih.chunk(3, 1).map_err(map_err)?;
1274 let hh_chunks = gates_hh.chunk(3, 1).map_err(map_err)?;
1275 let r_ih = &ih_chunks[0];
1276 let z_ih = &ih_chunks[1];
1277 let n_ih = &ih_chunks[2];
1278 let r_hh = &hh_chunks[0];
1279 let z_hh = &hh_chunks[1];
1280 let n_hh = &hh_chunks[2];
1281
1282 let r = r_ih.add(r_hh).map_err(map_err)?.sigmoid();
1284 let z = z_ih.add(z_hh).map_err(map_err)?.sigmoid();
1286 let r_n_hh = r.mul_elem(n_hh).map_err(map_err)?;
1288 let n = n_ih.add(&r_n_hh).map_err(map_err)?.tanh_activation();
1289
1290 let ones = Tensor::ones(z.shape());
1293 let one_minus_z = ones.sub(&z).map_err(map_err)?;
1294 let term1 = one_minus_z.mul_elem(&n).map_err(map_err)?;
1295 let term2 = z.mul_elem(h_prev).map_err(map_err)?;
1296 let h_new = term1.add(&term2).map_err(map_err)?;
1297
1298 Ok(h_new)
1299}
1300
1301pub fn lstm_cell_fused(
1315 x: &Tensor,
1316 h_prev: &Tensor,
1317 c_prev: &Tensor,
1318 w_ih: &Tensor,
1319 w_hh: &Tensor,
1320 b_ih: &Tensor,
1321 b_hh: &Tensor,
1322) -> Result<(Tensor, Tensor), String> {
1323 let map_err = |e: crate::error::RuntimeError| format!("{e}");
1324
1325 if x.ndim() != 2 {
1327 return Err("lstm_cell_fused: x must be 2-D [batch, input_size]".into());
1328 }
1329 if h_prev.ndim() != 2 {
1330 return Err("lstm_cell_fused: h_prev must be 2-D [batch, hidden_size]".into());
1331 }
1332 if c_prev.ndim() != 2 {
1333 return Err("lstm_cell_fused: c_prev must be 2-D [batch, hidden_size]".into());
1334 }
1335 let batch = x.shape()[0];
1336 let hidden_size = h_prev.shape()[1];
1337 if w_ih.ndim() != 2 || w_ih.shape()[0] != 4 * hidden_size {
1338 return Err(format!(
1339 "lstm_cell_fused: w_ih must be [4*hidden_size, input_size], got {:?}",
1340 w_ih.shape()
1341 ));
1342 }
1343 if w_hh.ndim() != 2 || w_hh.shape()[0] != 4 * hidden_size {
1344 return Err(format!(
1345 "lstm_cell_fused: w_hh must be [4*hidden_size, hidden_size], got {:?}",
1346 w_hh.shape()
1347 ));
1348 }
1349 if b_ih.len() != 4 * hidden_size {
1350 return Err(format!(
1351 "lstm_cell_fused: b_ih must have length 4*hidden_size={}, got {}",
1352 4 * hidden_size,
1353 b_ih.len()
1354 ));
1355 }
1356 if b_hh.len() != 4 * hidden_size {
1357 return Err(format!(
1358 "lstm_cell_fused: b_hh must have length 4*hidden_size={}, got {}",
1359 4 * hidden_size,
1360 b_hh.len()
1361 ));
1362 }
1363
1364 let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
1366 let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
1367
1368 let gih = gates_ih.to_vec();
1370 let ghh = gates_hh.to_vec();
1371 let cprev = c_prev.to_vec();
1372
1373 let mut h_new_data = vec![0.0f64; batch * hidden_size];
1374 let mut c_new_data = vec![0.0f64; batch * hidden_size];
1375
1376 for b_idx in 0..batch {
1377 let base = b_idx * 4 * hidden_size;
1378 for h in 0..hidden_size {
1379 let gi = gih[base + h] + ghh[base + h];
1381 let gf = gih[base + hidden_size + h] + ghh[base + hidden_size + h];
1382 let gg = gih[base + 2 * hidden_size + h] + ghh[base + 2 * hidden_size + h];
1383 let go = gih[base + 3 * hidden_size + h] + ghh[base + 3 * hidden_size + h];
1384
1385 let i_val = 1.0 / (1.0 + (-gi).exp()); let f_val = 1.0 / (1.0 + (-gf).exp()); let g_val = gg.tanh(); let o_val = 1.0 / (1.0 + (-go).exp()); let c_idx = b_idx * hidden_size + h;
1393 let c_val = f_val * cprev[c_idx] + i_val * g_val;
1394 c_new_data[c_idx] = c_val;
1395 h_new_data[c_idx] = o_val * c_val.tanh();
1396 }
1397 }
1398
1399 let h_new = Tensor::from_vec(h_new_data, &[batch, hidden_size]).map_err(map_err)?;
1400 let c_new = Tensor::from_vec(c_new_data, &[batch, hidden_size]).map_err(map_err)?;
1401
1402 Ok((h_new, c_new))
1403}
1404
1405pub fn gru_cell_fused(
1418 x: &Tensor,
1419 h_prev: &Tensor,
1420 w_ih: &Tensor,
1421 w_hh: &Tensor,
1422 b_ih: &Tensor,
1423 b_hh: &Tensor,
1424) -> Result<Tensor, String> {
1425 let map_err = |e: crate::error::RuntimeError| format!("{e}");
1426
1427 if x.ndim() != 2 {
1429 return Err("gru_cell_fused: x must be 2-D [batch, input_size]".into());
1430 }
1431 if h_prev.ndim() != 2 {
1432 return Err("gru_cell_fused: h_prev must be 2-D [batch, hidden_size]".into());
1433 }
1434 let batch = x.shape()[0];
1435 let hidden_size = h_prev.shape()[1];
1436 if w_ih.ndim() != 2 || w_ih.shape()[0] != 3 * hidden_size {
1437 return Err(format!(
1438 "gru_cell_fused: w_ih must be [3*hidden_size, input_size], got {:?}",
1439 w_ih.shape()
1440 ));
1441 }
1442 if w_hh.ndim() != 2 || w_hh.shape()[0] != 3 * hidden_size {
1443 return Err(format!(
1444 "gru_cell_fused: w_hh must be [3*hidden_size, hidden_size], got {:?}",
1445 w_hh.shape()
1446 ));
1447 }
1448 if b_ih.len() != 3 * hidden_size {
1449 return Err(format!(
1450 "gru_cell_fused: b_ih must have length 3*hidden_size={}, got {}",
1451 3 * hidden_size,
1452 b_ih.len()
1453 ));
1454 }
1455 if b_hh.len() != 3 * hidden_size {
1456 return Err(format!(
1457 "gru_cell_fused: b_hh must have length 3*hidden_size={}, got {}",
1458 3 * hidden_size,
1459 b_hh.len()
1460 ));
1461 }
1462
1463 let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
1465 let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
1466
1467 let gih = gates_ih.to_vec();
1469 let ghh = gates_hh.to_vec();
1470 let hp = h_prev.to_vec();
1471
1472 let mut h_new_data = vec![0.0f64; batch * hidden_size];
1473
1474 for b_idx in 0..batch {
1475 let base = b_idx * 3 * hidden_size;
1476 for h in 0..hidden_size {
1477 let r_val =
1479 1.0 / (1.0 + (-(gih[base + h] + ghh[base + h])).exp());
1480 let z_val = 1.0
1482 / (1.0
1483 + (-(gih[base + hidden_size + h]
1484 + ghh[base + hidden_size + h]))
1485 .exp());
1486 let n_val = (gih[base + 2 * hidden_size + h]
1488 + r_val * ghh[base + 2 * hidden_size + h])
1489 .tanh();
1490
1491 let h_idx = b_idx * hidden_size + h;
1493 h_new_data[h_idx] = (1.0 - z_val) * n_val + z_val * hp[h_idx];
1494 }
1495 }
1496
1497 Tensor::from_vec(h_new_data, &[batch, hidden_size])
1498 .map_err(|e| format!("{e}"))
1499}
1500
1501pub fn multi_head_attention(
1517 q: &Tensor,
1518 k: &Tensor,
1519 v: &Tensor,
1520 w_q: &Tensor,
1521 w_k: &Tensor,
1522 w_v: &Tensor,
1523 w_o: &Tensor,
1524 b_q: &Tensor,
1525 b_k: &Tensor,
1526 b_v: &Tensor,
1527 b_o: &Tensor,
1528 num_heads: usize,
1529) -> Result<Tensor, String> {
1530 let map_err = |e: crate::error::RuntimeError| format!("{e}");
1531
1532 if q.ndim() != 3 {
1533 return Err("multi_head_attention: q must be 3-D [batch, seq, model_dim]".into());
1534 }
1535
1536 let q_proj = q.linear(w_q, b_q).map_err(map_err)?;
1538 let k_proj = k.linear(w_k, b_k).map_err(map_err)?;
1539 let v_proj = v.linear(w_v, b_v).map_err(map_err)?;
1540
1541 let q_heads = q_proj.split_heads(num_heads).map_err(map_err)?;
1543 let k_heads = k_proj.split_heads(num_heads).map_err(map_err)?;
1544 let v_heads = v_proj.split_heads(num_heads).map_err(map_err)?;
1545
1546 let attn = Tensor::scaled_dot_product_attention(&q_heads, &k_heads, &v_heads)
1548 .map_err(map_err)?;
1549
1550 let merged = attn.merge_heads().map_err(map_err)?;
1552
1553 let output = merged.linear(w_o, b_o).map_err(map_err)?;
1555
1556 Ok(output)
1557}
1558
1559pub fn embedding(weight: &crate::tensor::Tensor, indices: &[i64]) -> Result<crate::tensor::Tensor, String> {
1581 let shape = weight.shape();
1582 if shape.len() != 2 {
1583 return Err(format!("embedding: weight must be 2-D [vocab_size, embed_dim], got {:?}", shape));
1584 }
1585 let vocab_size = shape[0];
1586 let embed_dim = shape[1];
1587 let weight_data = weight.to_vec();
1588
1589 let mut out = Vec::with_capacity(indices.len() * embed_dim);
1590 for &idx in indices {
1591 let i = idx as usize;
1592 if i >= vocab_size {
1593 return Err(format!("embedding: index {} out of bounds for vocab_size {}", idx, vocab_size));
1594 }
1595 let start = i * embed_dim;
1596 out.extend_from_slice(&weight_data[start..start + embed_dim]);
1597 }
1598 crate::tensor::Tensor::from_vec(out, &[indices.len(), embed_dim])
1599 .map_err(|e| e.to_string())
1600}
1601
/// Splits `0..dataset_size` into consecutive half-open `(start, end)` ranges.
///
/// Every range holds `batch_size` elements except possibly the last one, which
/// holds whatever remains. The ranges are contiguous, start at 0 and together
/// cover the whole dataset exactly once.
///
/// # Arguments
/// * `dataset_size` - number of samples to partition.
/// * `batch_size` - maximum number of samples per range.
/// * `seed` - accepted for interface compatibility; it has no effect on the
///   returned ranges.
///
/// # Returns
/// The list of `(start, end)` pairs. The list is empty when `dataset_size` is
/// 0 or when `batch_size` is 0 (no valid partition exists in that case).
pub fn batch_indices(dataset_size: usize, batch_size: usize, seed: u64) -> Vec<(usize, usize)> {
    // NOTE(review): the previous implementation shuffled a scratch index vector
    // with a seeded RNG but never returned or read it, so the seed never
    // influenced the result. If shuffled sample order is wanted, it needs a
    // different return type; until then the dead work is omitted.
    let _ = seed;

    // A zero batch size can never make progress; bail out instead of looping
    // forever on empty ranges.
    if batch_size == 0 {
        return Vec::new();
    }

    (0..dataset_size)
        .step_by(batch_size)
        .map(|start| (start, start.saturating_add(batch_size).min(dataset_size)))
        .collect()
}
1638
#[cfg(test)]
mod tests {
    //! Unit tests for the loss functions, optimizers, metrics, data-splitting
    //! helpers, schedules, regularizers, PCA, L-BFGS and tensor-based layers
    //! defined in the parent module.

    use super::*;

    // ---- Loss functions ----

    // Identical prediction and target give exactly zero error.
    #[test]
    fn test_mse_zero() {
        let pred = [1.0, 2.0, 3.0];
        let target = [1.0, 2.0, 3.0];
        assert_eq!(mse_loss(&pred, &target).unwrap(), 0.0);
    }

    // Every element is off by one, so the mean squared error is 1.
    #[test]
    fn test_mse_basic() {
        let pred = [1.0, 2.0, 3.0];
        let target = [2.0, 3.0, 4.0];
        assert_eq!(mse_loss(&pred, &target).unwrap(), 1.0);
    }

    // |pred - target| = 0.5 is below delta = 1.0; expected 0.125 = 0.5 * 0.5^2.
    #[test]
    fn test_huber_loss_quadratic() {
        let pred = [1.0];
        let target = [1.5];
        let h = huber_loss(&pred, &target, 1.0).unwrap();
        assert!((h - 0.125).abs() < 1e-12);
    }

    // ---- Optimizers ----

    // Expected values equal param - 0.1 * grad for each parameter.
    #[test]
    fn test_sgd_step() {
        let mut params = [1.0, 2.0];
        let grads = [0.1, 0.2];
        let mut state = SgdState::new(2, 0.1, 0.0);
        sgd_step(&mut params, &grads, &mut state);
        assert!((params[0] - 0.99).abs() < 1e-12);
        assert!((params[1] - 1.98).abs() < 1e-12);
    }

    // With positive gradients a single Adam step must decrease every parameter.
    #[test]
    fn test_adam_step() {
        let mut params = [1.0, 2.0];
        let grads = [0.1, 0.2];
        let mut state = AdamState::new(2, 0.001);
        adam_step(&mut params, &grads, &mut state);
        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
    }

    // Typed `ParamIdx` accessors must return bit-identical values to direct
    // reads of the `m` / `v` fields after a few Adam steps.
    #[test]
    fn paramidx_accessors_agree_with_direct_field_reads_adam() {
        let mut params = vec![0.5, -0.3, 1.7, 0.0];
        let grads = vec![0.1, -0.2, 0.05, 0.4];
        let mut state = AdamState::new(4, 0.01);
        for _ in 0..5 {
            adam_step(&mut params, &grads, &mut state);
        }
        for i in 0..4 {
            let p = crate::idx::ParamIdx::from_usize(i);
            assert_eq!(state.m_at(p).to_bits(), state.m[i].to_bits());
            assert_eq!(state.v_at(p).to_bits(), state.v[i].to_bits());
        }
        assert_eq!(state.n_params(), 4);
    }

    // Same agreement check for the SGD velocity buffer.
    #[test]
    fn paramidx_accessors_agree_with_direct_field_reads_sgd() {
        let mut params = vec![1.0, 2.0, 3.0];
        let grads = vec![0.01, -0.02, 0.03];
        let mut state = SgdState::new(3, 0.05, 0.9);
        for _ in 0..7 {
            sgd_step(&mut params, &grads, &mut state);
        }
        for i in 0..3 {
            let p = crate::idx::ParamIdx::from_usize(i);
            assert_eq!(state.velocity_at(p).to_bits(), state.velocity[i].to_bits());
        }
        assert_eq!(state.n_params(), 3);
    }

    // Values written through the typed setters must be visible through both
    // the typed getters and the raw fields.
    #[test]
    fn paramidx_typed_setters_round_trip() {
        let mut state = AdamState::new(8, 0.001);
        for i in 0..8 {
            let p = crate::idx::ParamIdx::from_usize(i);
            state.set_m_at(p, (i as f64) * 0.5);
            state.set_v_at(p, (i as f64) * 0.25);
        }
        for i in 0..8 {
            let p = crate::idx::ParamIdx::from_usize(i);
            assert_eq!(state.m_at(p), (i as f64) * 0.5);
            assert_eq!(state.v_at(p), (i as f64) * 0.25);
            assert_eq!(state.m[i], (i as f64) * 0.5);
            assert_eq!(state.v[i], (i as f64) * 0.25);
        }
    }

    // Two independent 20-step Adam runs from the same start must agree
    // bit-for-bit in parameters, both moment buffers and the step counter.
    #[test]
    fn paramidx_adam_step_byte_equal_across_independent_runs() {
        let init_params = vec![0.5, -0.5, 1.5, -1.5, 0.7];
        let grads = vec![0.1, -0.05, 0.02, 0.08, -0.03];

        let run = || {
            let mut params = init_params.clone();
            let mut state = AdamState::new(5, 0.01);
            for _ in 0..20 {
                adam_step(&mut params, &grads, &mut state);
            }
            (params, state.m.clone(), state.v.clone(), state.t)
        };

        let (p1, m1, v1, t1) = run();
        let (p2, m2, v2, t2) = run();

        let bits = |xs: &[f64]| -> Vec<u64> { xs.iter().map(|x| x.to_bits()).collect() };
        assert_eq!(bits(&p1), bits(&p2));
        assert_eq!(bits(&m1), bits(&m2));
        assert_eq!(bits(&v1), bits(&v2));
        assert_eq!(t1, t2);
    }

    // ---- Classification metrics ----

    #[test]
    fn test_confusion_matrix() {
        let pred = [true, true, false, false, true];
        let actual = [true, false, true, false, true];
        let cm = confusion_matrix(&pred, &actual);
        assert_eq!(cm.tp, 2);
        assert_eq!(cm.fp, 1);
        assert_eq!(cm.fn_count, 1);
        assert_eq!(cm.tn, 1);
    }

    // precision = tp / (tp + fp) = 5/7, recall = tp / (tp + fn) = 5/6.
    #[test]
    fn test_precision_recall_f1() {
        let cm = ConfusionMatrix { tp: 5, fp: 2, tn: 8, fn_count: 1 };
        assert!((precision(&cm) - 5.0 / 7.0).abs() < 1e-12);
        assert!((recall(&cm) - 5.0 / 6.0).abs() < 1e-12);
    }

    // Every positive is scored above every negative, so the AUC must be 1.
    #[test]
    fn test_auc_perfect() {
        let scores = [0.9, 0.8, 0.2, 0.1];
        let labels = [true, true, false, false];
        let auc = auc_roc(&scores, &labels).unwrap();
        assert!((auc - 1.0).abs() < 1e-12);
    }

    // ---- Data splitting ----

    // Same seed must give identical folds.
    #[test]
    fn test_kfold_deterministic() {
        let f1 = kfold_indices(100, 5, 42);
        let f2 = kfold_indices(100, 5, 42);
        for i in 0..5 {
            assert_eq!(f1[i].0, f2[i].0);
            assert_eq!(f1[i].1, f2[i].1);
        }
    }

    // Train and test together cover all 100 samples; 20% of them are test.
    #[test]
    fn test_train_test_split_coverage() {
        let (train, test) = train_test_split(100, 0.2, 42);
        assert_eq!(train.len() + test.len(), 100);
        assert_eq!(test.len(), 20);
    }

    // ---- Batch norm / dropout ----

    // Zero mean, unit variance, gamma = 1, beta = 0 and eps = 0 must leave the
    // input unchanged.
    #[test]
    fn test_batch_norm_identity() {
        let x = vec![1.0, 2.0, 3.0];
        let mean = vec![0.0, 0.0, 0.0];
        let var = vec![1.0, 1.0, 1.0];
        let gamma = vec![1.0, 1.0, 1.0];
        let beta = vec![0.0, 0.0, 0.0];
        let result = batch_norm(&x, &mean, &var, &gamma, &beta, 0.0).unwrap();
        assert!((result[0] - 1.0).abs() < 1e-12);
        assert!((result[1] - 2.0).abs() < 1e-12);
        assert!((result[2] - 3.0).abs() < 1e-12);
    }

    // Expected value 2.0 = 2 * (0 - 1) / sqrt(4) + 3.
    #[test]
    fn test_batch_norm_shift_scale() {
        let x = vec![0.0];
        let mean = vec![1.0];
        let var = vec![4.0];
        let gamma = vec![2.0];
        let beta = vec![3.0];
        let result = batch_norm(&x, &mean, &var, &gamma, &beta, 0.0).unwrap();
        assert!((result[0] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_dropout_mask_seed_determinism() {
        let m1 = dropout_mask(100, 0.5, 42);
        let m2 = dropout_mask(100, 0.5, 42);
        assert_eq!(m1, m2);
    }

    #[test]
    fn test_dropout_mask_different_seeds() {
        let m1 = dropout_mask(100, 0.5, 42);
        let m2 = dropout_mask(100, 0.5, 99);
        assert_ne!(m1, m2);
    }

    // ---- Learning-rate schedules ----

    // The rate halves every 10 epochs: 0.1 -> 0.05 -> 0.025.
    #[test]
    fn test_lr_step_decay_schedule() {
        let lr0 = lr_step_decay(0.1, 0.5, 0, 10);
        assert!((lr0 - 0.1).abs() < 1e-12);
        let lr10 = lr_step_decay(0.1, 0.5, 10, 10);
        assert!((lr10 - 0.05).abs() < 1e-12);
        let lr20 = lr_step_decay(0.1, 0.5, 20, 10);
        assert!((lr20 - 0.025).abs() < 1e-12);
    }

    // The schedule starts at the first rate (0.1) and ends at the second (0.001).
    #[test]
    fn test_lr_cosine_endpoints() {
        let lr0 = lr_cosine(0.1, 0.001, 0, 100);
        assert!((lr0 - 0.1).abs() < 1e-10);
        let lr_end = lr_cosine(0.1, 0.001, 100, 100);
        assert!((lr_end - 0.001).abs() < 1e-10);
    }

    // Zero at step 0, half the target halfway through, the full target after
    // the warmup window.
    #[test]
    fn test_lr_linear_warmup() {
        let lr0 = lr_linear_warmup(0.1, 0, 10);
        assert!((lr0).abs() < 1e-12);
        let lr5 = lr_linear_warmup(0.1, 5, 10);
        assert!((lr5 - 0.05).abs() < 1e-12);
        let lr15 = lr_linear_warmup(0.1, 15, 10);
        assert!((lr15 - 0.1).abs() < 1e-12);
    }

    // ---- Regularization ----

    // Expected value 0.6 = 0.1 * (1 + 2 + 3).
    #[test]
    fn test_l1_penalty_known() {
        let params = [1.0, -2.0, 3.0];
        let p = l1_penalty(&params, 0.1);
        assert!((p - 0.6).abs() < 1e-12);
    }

    // Expected value 0.7 = 0.5 * 0.1 * (1 + 4 + 9).
    #[test]
    fn test_l2_penalty_known() {
        let params = [1.0, -2.0, 3.0];
        let p = l2_penalty(&params, 0.1);
        assert!((p - 0.7).abs() < 1e-12);
    }

    // ---- Early stopping ----

    // With the same loss repeated, the first three checks do not stop and the
    // fourth does.
    #[test]
    fn test_early_stopping_triggers() {
        let mut es = EarlyStoppingState::new(3, 0.01);
        assert!(!es.check(1.0));
        assert!(!es.check(1.0));
        assert!(!es.check(1.0));
        assert!(es.check(1.0));
    }

    // A lower loss after two stalled checks must not trigger a stop, and
    // neither must the check immediately after it.
    #[test]
    fn test_early_stopping_resets() {
        let mut es = EarlyStoppingState::new(3, 0.01);
        es.check(1.0);
        es.check(1.0);
        assert!(!es.check(0.5));
        assert!(!es.check(0.5));
    }

    // ---- PCA ----

    // The second column is exactly 0.1 times the first, so nearly all variance
    // must fall on the first component.
    #[test]
    fn test_pca_basic_2d() {
        let data = Tensor::from_vec(
            vec![
                1.0, 0.1,
                2.0, 0.2,
                3.0, 0.3,
                4.0, 0.4,
            ],
            &[4, 2],
        )
        .unwrap();
        let (transformed, components, evr) = pca(&data, 2).unwrap();
        assert_eq!(transformed.shape(), &[4, 2]);
        assert_eq!(components.shape(), &[2, 2]);
        assert_eq!(evr.len(), 2);
        let total: f64 = evr.iter().sum();
        assert!(
            (total - 1.0).abs() < 1e-8,
            "explained variance ratios sum to {} (expected ~1.0)",
            total
        );
        assert!(evr[0] > 0.9, "first component explains {} of variance", evr[0]);
    }

    #[test]
    fn test_pca_single_component() {
        let data = Tensor::from_vec(
            vec![
                1.0, 2.0, 3.0,
                4.0, 5.0, 6.0,
                7.0, 8.0, 9.0,
            ],
            &[3, 3],
        )
        .unwrap();
        let (transformed, components, evr) = pca(&data, 1).unwrap();
        assert_eq!(transformed.shape(), &[3, 1]);
        assert_eq!(components.shape(), &[1, 3]);
        assert_eq!(evr.len(), 1);
        assert!(evr[0] > 0.0 && evr[0] <= 1.0);
    }

    // Ratios must be non-negative and sum to at most 1 (within tolerance).
    #[test]
    fn test_pca_explained_variance_ratio_bounded() {
        let data = Tensor::from_vec(
            vec![
                1.0, 0.0, 0.5,
                0.0, 1.0, 0.5,
                1.0, 1.0, 1.0,
                2.0, 0.0, 1.0,
                0.0, 2.0, 1.0,
            ],
            &[5, 3],
        )
        .unwrap();
        let (_, _, evr) = pca(&data, 3).unwrap();
        let total: f64 = evr.iter().sum();
        assert!(
            total <= 1.0 + 1e-10,
            "explained variance ratios sum to {} (should be <= 1.0)",
            total
        );
        for &r in &evr {
            assert!(r >= -1e-10, "negative explained variance ratio: {}", r);
        }
    }

    #[test]
    fn test_pca_deterministic() {
        let data = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[2, 3],
        )
        .unwrap();
        let (t1, c1, e1) = pca(&data, 2).unwrap();
        let (t2, c2, e2) = pca(&data, 2).unwrap();
        assert_eq!(t1.to_vec(), t2.to_vec(), "PCA transformed not deterministic");
        assert_eq!(c1.to_vec(), c2.to_vec(), "PCA components not deterministic");
        assert_eq!(e1, e2, "PCA explained variance not deterministic");
    }

    #[test]
    fn test_pca_invalid_n_components() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert!(pca(&data, 0).is_err(), "n_components=0 should fail");
        assert!(pca(&data, 3).is_err(), "n_components > min(n,p) should fail");
    }

    // ---- L-BFGS ----

    // Rosenbrock function f(x, y) = (1 - x)^2 + 100 * (y - x^2)^2 together
    // with its analytic gradient; the minimum is at (1, 1).
    fn rosenbrock(p: &[f64]) -> (f64, Vec<f64>) {
        let x = p[0];
        let y = p[1];
        let a = 1.0 - x;
        let b = y - x * x;
        let val = a * a + 100.0 * b * b;
        let gx = -2.0 * a - 400.0 * x * b;
        let gy = 200.0 * b;
        (val, vec![gx, gy])
    }

    // Iterates until the gradient norm drops below 1e-5 (at most 200 steps)
    // and checks that the iterate ends near (1, 1).
    #[test]
    fn test_lbfgs_rosenbrock_converges() {
        let mut params = vec![-1.0_f64, 2.0_f64];
        let mut state = LbfgsState::new(0.5, 10);

        let mut converged = false;
        for _iter in 0..200 {
            let (_, grads) = rosenbrock(&params);
            let grad_norm: f64 = kahan_dot(&grads, &grads).sqrt();
            if grad_norm < 1e-5 {
                converged = true;
                break;
            }
            let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, rosenbrock);
            params = new_p;
        }
        assert!(converged, "L-BFGS did not converge on Rosenbrock; params = {:?}", params);
        assert!(
            (params[0] - 1.0).abs() < 1e-3,
            "x should converge near 1.0, got {}",
            params[0]
        );
        assert!(
            (params[1] - 1.0).abs() < 1e-3,
            "y should converge near 1.0, got {}",
            params[1]
        );
    }

    // Two runs from the same start must give exactly equal parameters.
    #[test]
    fn test_lbfgs_determinism() {
        let init = vec![-1.0_f64, 2.0_f64];

        let run = |init: &[f64]| -> Vec<f64> {
            let mut params = init.to_vec();
            let mut state = LbfgsState::new(0.5, 10);
            for _ in 0..20 {
                let (_, grads) = rosenbrock(&params);
                let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, rosenbrock);
                params = new_p;
            }
            params
        };

        let r1 = run(&init);
        let r2 = run(&init);
        assert_eq!(r1, r2, "L-BFGS must be bit-identical across runs");
    }

    // Minimizing f(x) = x^2 from x = 3 must end close to 0.
    #[test]
    fn test_lbfgs_simple_quadratic() {
        let mut params = vec![3.0_f64];
        let mut state = LbfgsState::new(1.0, 5);
        let quadratic = |p: &[f64]| -> (f64, Vec<f64>) {
            (p[0] * p[0], vec![2.0 * p[0]])
        };

        for _ in 0..30 {
            let (_, grads) = quadratic(&params);
            let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, quadratic);
            params = new_p;
        }
        assert!(
            params[0].abs() < 1e-6,
            "L-BFGS should minimize x^2 to ~0, got {}",
            params[0]
        );
    }

    // ---- Embedding / batching ----

    // Rows 0, 2, 1 of a 3x2 table are gathered in that order.
    #[test]
    fn test_embedding_basic() {
        let weight = crate::tensor::Tensor::from_vec(
            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            &[3, 2],
        ).unwrap();
        let indices = vec![0, 2, 1];
        let result = super::embedding(&weight, &indices).unwrap();
        assert_eq!(result.shape(), &[3, 2]);
        let data = result.to_vec();
        assert!((data[0] - 0.1).abs() < 1e-12);
        assert!((data[1] - 0.2).abs() < 1e-12);
        assert!((data[2] - 0.5).abs() < 1e-12);
        assert!((data[3] - 0.6).abs() < 1e-12);
        assert!((data[4] - 0.3).abs() < 1e-12);
        assert!((data[5] - 0.4).abs() < 1e-12);
    }

    // Index 1 is outside a one-row table.
    #[test]
    fn test_embedding_out_of_bounds() {
        let weight = crate::tensor::Tensor::from_vec(vec![1.0, 2.0], &[1, 2]).unwrap();
        let result = super::embedding(&weight, &[1]);
        assert!(result.is_err());
    }

    // Same arguments give the same ranges, and the ranges cover all 10 samples.
    #[test]
    fn test_batch_indices_deterministic() {
        let b1 = super::batch_indices(10, 3, 42);
        let b2 = super::batch_indices(10, 3, 42);
        assert_eq!(b1, b2);
        let total: usize = b1.iter().map(|(s, e)| e - s).sum();
        assert_eq!(total, 10);
    }

    // ---- Line search ----

    // f(x) = x^2 at x = 3 (f = 9, gradient 6), searching along direction -1.
    // The returned step must satisfy the sufficient-decrease (Armijo)
    // inequality with c1 = 1e-4 and move the point below 3.
    #[test]
    fn test_wolfe_line_search_armijo() {
        let params = vec![3.0_f64];
        let direction = vec![-1.0_f64];
        let grads = vec![6.0_f64];
        let f0 = 9.0;

        let mut eval_count = 0;
        let mut obj = |p: &[f64]| -> (f64, Vec<f64>) {
            eval_count += 1;
            (p[0] * p[0], vec![2.0 * p[0]])
        };

        let (alpha, new_params, new_val, _) =
            wolfe_line_search(&params, &direction, &mut obj, f0, &grads, 1.0);

        let c1 = 1e-4;
        // Directional derivative at the starting point.
        let derphi0 = kahan_dot(&grads, &direction);
        assert!(
            new_val <= f0 + c1 * alpha * derphi0,
            "Armijo condition violated: {} > {} + {} * {} * {}",
            new_val, f0, c1, alpha, derphi0
        );
        assert!(new_params[0] < 3.0, "Step should move toward minimum");
    }
}