use cjc_repro::KahanAccumulatorF64;

use crate::accumulator::BinnedAccumulatorF64;
use crate::error::RuntimeError;
use crate::tensor::Tensor;

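/// Mean squared error, `(1/n) * sum((pred_i - target_i)^2)`, accumulated with
/// Kahan compensation so the result is reproducible across runs.
///
/// A minimal usage sketch (illustrative values only):
///
/// ```ignore
/// let loss = mse_loss(&[1.0, 2.0], &[1.5, 2.5]).unwrap();
/// assert!((loss - 0.25).abs() < 1e-12);
/// ```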
pub fn mse_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
    if pred.len() != target.len() {
        return Err("mse_loss: arrays must have same length".into());
    }
    if pred.is_empty() {
        return Err("mse_loss: empty data".into());
    }
    let mut acc = KahanAccumulatorF64::new();
    for i in 0..pred.len() {
        let d = pred[i] - target[i];
        acc.add(d * d);
    }
    Ok(acc.finalize() / pred.len() as f64)
}

pub fn cross_entropy_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
    if pred.len() != target.len() {
        return Err("cross_entropy_loss: arrays must have same length".into());
    }
    if pred.is_empty() {
        return Err("cross_entropy_loss: empty data".into());
    }
    let eps = 1e-12;
    let mut acc = KahanAccumulatorF64::new();
    for i in 0..pred.len() {
        acc.add(-target[i] * (pred[i] + eps).ln());
    }
    Ok(acc.finalize() / pred.len() as f64)
}

pub fn binary_cross_entropy(pred: &[f64], target: &[f64]) -> Result<f64, String> {
    if pred.len() != target.len() {
        return Err("binary_cross_entropy: arrays must have same length".into());
    }
    if pred.is_empty() {
        return Err("binary_cross_entropy: empty data".into());
    }
    let eps = 1e-12;
    let mut acc = KahanAccumulatorF64::new();
    for i in 0..pred.len() {
        let p = pred[i].max(eps).min(1.0 - eps);
        acc.add(-(target[i] * p.ln() + (1.0 - target[i]) * (1.0 - p).ln()));
    }
    Ok(acc.finalize() / pred.len() as f64)
}

pub fn huber_loss(pred: &[f64], target: &[f64], delta: f64) -> Result<f64, String> {
    if pred.len() != target.len() {
        return Err("huber_loss: arrays must have same length".into());
    }
    if pred.is_empty() {
        return Err("huber_loss: empty data".into());
    }
    let mut acc = KahanAccumulatorF64::new();
    for i in 0..pred.len() {
        let d = (pred[i] - target[i]).abs();
        if d <= delta {
            acc.add(0.5 * d * d);
        } else {
            acc.add(delta * (d - 0.5 * delta));
        }
    }
    Ok(acc.finalize() / pred.len() as f64)
}

pub fn hinge_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
    if pred.len() != target.len() {
        return Err("hinge_loss: arrays must have same length".into());
    }
    if pred.is_empty() {
        return Err("hinge_loss: empty data".into());
    }
    let mut acc = KahanAccumulatorF64::new();
    for i in 0..pred.len() {
        acc.add((1.0 - target[i] * pred[i]).max(0.0));
    }
    Ok(acc.finalize() / pred.len() as f64)
}

pub struct SgdState {
    pub lr: f64,
    pub momentum: f64,
    pub velocity: Vec<f64>,
}

impl SgdState {
    pub fn new(n_params: usize, lr: f64, momentum: f64) -> Self {
        Self { lr, momentum, velocity: vec![0.0; n_params] }
    }
}

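/// One SGD update with classical momentum:
/// `v <- momentum * v + g`, then `theta <- theta - lr * v`.
///
/// Illustrative sketch (values are made up for the example):
///
/// ```ignore
/// let mut params = [1.0];
/// let mut state = SgdState::new(1, 0.1, 0.9);
/// sgd_step(&mut params, &[0.5], &mut state);
/// // v = 0.9 * 0.0 + 0.5 = 0.5, so params[0] = 1.0 - 0.1 * 0.5 = 0.95
/// ```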
pub fn sgd_step(params: &mut [f64], grads: &[f64], state: &mut SgdState) {
    for i in 0..params.len() {
        state.velocity[i] = state.momentum * state.velocity[i] + grads[i];
        params[i] -= state.lr * state.velocity[i];
    }
}

pub struct AdamState {
    pub lr: f64,
    pub beta1: f64,
    pub beta2: f64,
    pub eps: f64,
    pub t: u64,
    pub m: Vec<f64>,
    pub v: Vec<f64>,
}

impl AdamState {
    pub fn new(n_params: usize, lr: f64) -> Self {
        Self {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            t: 0,
            m: vec![0.0; n_params],
            v: vec![0.0; n_params],
        }
    }
}

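/// One Adam update with bias correction:
/// `m <- b1*m + (1-b1)*g`, `v <- b2*v + (1-b2)*g^2`,
/// `m_hat = m / (1 - b1^t)`, `v_hat = v / (1 - b2^t)`,
/// `theta <- theta - lr * m_hat / (sqrt(v_hat) + eps)`.
///
/// Illustrative sketch (values are made up):
///
/// ```ignore
/// let mut params = [1.0];
/// let mut state = AdamState::new(1, 0.001);
/// adam_step(&mut params, &[0.5], &mut state);
/// // the first step moves by roughly lr: params[0] ~= 1.0 - 0.001
/// ```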
pub fn adam_step(params: &mut [f64], grads: &[f64], state: &mut AdamState) {
    state.t += 1;
    let t = state.t as f64;
    for i in 0..params.len() {
        state.m[i] = state.beta1 * state.m[i] + (1.0 - state.beta1) * grads[i];
        state.v[i] = state.beta2 * state.v[i] + (1.0 - state.beta2) * grads[i] * grads[i];
        let m_hat = state.m[i] / (1.0 - state.beta1.powf(t));
        let v_hat = state.v[i] / (1.0 - state.beta2.powf(t));
        params[i] -= state.lr * m_hat / (v_hat.sqrt() + state.eps);
    }
}

#[derive(Debug, Clone)]
pub struct ConfusionMatrix {
    pub tp: usize,
    pub fp: usize,
    pub tn: usize,
    pub fn_count: usize,
}

pub fn confusion_matrix(predicted: &[bool], actual: &[bool]) -> ConfusionMatrix {
    let mut tp = 0;
    let mut fp = 0;
    let mut tn = 0;
    let mut fn_count = 0;
    for i in 0..predicted.len().min(actual.len()) {
        match (predicted[i], actual[i]) {
            (true, true) => tp += 1,
            (true, false) => fp += 1,
            (false, true) => fn_count += 1,
            (false, false) => tn += 1,
        }
    }
    ConfusionMatrix { tp, fp, tn, fn_count }
}

pub fn precision(cm: &ConfusionMatrix) -> f64 {
    let denom = cm.tp + cm.fp;
    if denom == 0 { 0.0 } else { cm.tp as f64 / denom as f64 }
}

pub fn recall(cm: &ConfusionMatrix) -> f64 {
    let denom = cm.tp + cm.fn_count;
    if denom == 0 { 0.0 } else { cm.tp as f64 / denom as f64 }
}

pub fn f1_score(cm: &ConfusionMatrix) -> f64 {
    let p = precision(cm);
    let r = recall(cm);
    if p + r == 0.0 { 0.0 } else { 2.0 * p * r / (p + r) }
}

pub fn accuracy(cm: &ConfusionMatrix) -> f64 {
    let total = cm.tp + cm.fp + cm.tn + cm.fn_count;
    if total == 0 { 0.0 } else { (cm.tp + cm.tn) as f64 / total as f64 }
}

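/// Area under the ROC curve via the trapezoidal rule: samples are sorted by
/// descending score (ties broken by original index) and the (FPR, TPR) path
/// is integrated one sample at a time, so exact score ties trace a diagonal
/// segment rather than the tie-averaged step.
///
/// ```ignore
/// // A perfect ranking scores every positive above every negative:
/// let auc = auc_roc(&[0.9, 0.8, 0.2], &[true, true, false]).unwrap();
/// assert!((auc - 1.0).abs() < 1e-12);
/// ```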
pub fn auc_roc(scores: &[f64], labels: &[bool]) -> Result<f64, String> {
    if scores.len() != labels.len() {
        return Err("auc_roc: scores and labels must have same length".into());
    }
    let n = scores.len();
    if n == 0 {
        return Err("auc_roc: empty data".into());
    }
    let mut indexed: Vec<(usize, f64, bool)> = scores.iter().zip(labels.iter())
        .enumerate()
        .map(|(i, (&s, &l))| (i, s, l))
        .collect();
    indexed.sort_by(|a, b| b.1.total_cmp(&a.1).then(a.0.cmp(&b.0)));

    let pos_count = labels.iter().filter(|&&l| l).count();
    let neg_count = n - pos_count;
    if pos_count == 0 || neg_count == 0 {
        return Err("auc_roc: need both positive and negative labels".into());
    }

    let mut auc = 0.0;
    let mut tp = 0.0;
    let mut fp = 0.0;
    let mut prev_fpr = 0.0;
    let mut prev_tpr = 0.0;

    for &(_, _, label) in &indexed {
        if label { tp += 1.0; } else { fp += 1.0; }
        let tpr = tp / pos_count as f64;
        let fpr = fp / neg_count as f64;
        auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0;
        prev_fpr = fpr;
        prev_tpr = tpr;
    }
    Ok(auc)
}

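/// Deterministic k-fold split: indices are shuffled with a seeded
/// Fisher-Yates pass, then cut into `k` contiguous folds, the last fold
/// absorbing any remainder. Each tuple is `(train_indices, test_indices)`.
/// Assumes `1 <= k <= n`; `k == 0` would divide by zero.
///
/// ```ignore
/// let folds = kfold_indices(10, 5, 42);
/// assert_eq!(folds.len(), 5);
/// assert_eq!(folds[0].1.len(), 2); // each test fold holds n/k indices
/// ```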
pub fn kfold_indices(n: usize, k: usize, seed: u64) -> Vec<(Vec<usize>, Vec<usize>)> {
    let mut rng = cjc_repro::Rng::seeded(seed);
    let mut indices: Vec<usize> = (0..n).collect();
    for i in (1..n).rev() {
        let j = (rng.next_u64() as usize) % (i + 1);
        indices.swap(i, j);
    }
    let fold_size = n / k;
    let mut folds = Vec::with_capacity(k);
    for fold in 0..k {
        let start = fold * fold_size;
        let end = if fold == k - 1 { n } else { start + fold_size };
        let test: Vec<usize> = indices[start..end].to_vec();
        let train: Vec<usize> = indices[..start].iter()
            .chain(indices[end..].iter())
            .copied()
            .collect();
        folds.push((train, test));
    }
    folds
}

pub fn train_test_split(n: usize, test_fraction: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
    let mut rng = cjc_repro::Rng::seeded(seed);
    let mut indices: Vec<usize> = (0..n).collect();
    for i in (1..n).rev() {
        let j = (rng.next_u64() as usize) % (i + 1);
        indices.swap(i, j);
    }
    let test_size = ((n as f64) * test_fraction).round() as usize;
    let test = indices[..test_size].to_vec();
    let train = indices[test_size..].to_vec();
    (train, test)
}

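/// Nonparametric bootstrap: draws `n_resamples` resamples with replacement,
/// recomputes the statistic on each, and returns
/// `(point_estimate, ci_lower, ci_upper, standard_error)` using the
/// percentile method (2.5% / 97.5% quantiles). `stat_fn` selects the
/// statistic (0 = mean, 1 = median). The percentile indexing assumes
/// `n_resamples >= 1`; zero resamples would panic.
///
/// ```ignore
/// // data: &[f64] supplied by the caller
/// let (mean, lo, hi, se) = bootstrap(&data, 1000, 0, 42)?;
/// // lo/hi bracket the middle 95% of the resampled means
/// ```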
pub fn bootstrap(data: &[f64], n_resamples: usize, stat_fn: usize, seed: u64) -> Result<(f64, f64, f64, f64), String> {
    if data.is_empty() { return Err("bootstrap: empty data".into()); }
    let n = data.len();

    let point = compute_stat(data, stat_fn)?;

    let mut rng = cjc_repro::Rng::seeded(seed);
    let mut stats = Vec::with_capacity(n_resamples);
    let mut resample = Vec::with_capacity(n);

    for _ in 0..n_resamples {
        resample.clear();
        for _ in 0..n {
            let idx = (rng.next_u64() as usize) % n;
            resample.push(data[idx]);
        }
        stats.push(compute_stat(&resample, stat_fn)?);
    }

    stats.sort_by(|a, b| a.total_cmp(b));

    let ci_lower = stats[(n_resamples as f64 * 0.025) as usize];
    let ci_upper = stats[(n_resamples as f64 * 0.975).min((n_resamples - 1) as f64) as usize];

    let mean_stats: f64 = {
        let mut acc = cjc_repro::KahanAccumulatorF64::new();
        for &s in &stats { acc.add(s); }
        acc.finalize() / n_resamples as f64
    };
    let se = {
        let mut acc = cjc_repro::KahanAccumulatorF64::new();
        for &s in &stats { let d = s - mean_stats; acc.add(d * d); }
        (acc.finalize() / (n_resamples as f64 - 1.0)).sqrt()
    };

    Ok((point, ci_lower, ci_upper, se))
}

fn compute_stat(data: &[f64], stat_fn: usize) -> Result<f64, String> {
    match stat_fn {
        0 => {
            let mut acc = cjc_repro::KahanAccumulatorF64::new();
            for &x in data { acc.add(x); }
            Ok(acc.finalize() / data.len() as f64)
        }
        1 => {
            let mut sorted = data.to_vec();
            sorted.sort_by(|a, b| a.total_cmp(b));
            let n = sorted.len();
            if n % 2 == 0 {
                Ok((sorted[n / 2 - 1] + sorted[n / 2]) / 2.0)
            } else {
                Ok(sorted[n / 2])
            }
        }
        _ => Err(format!("bootstrap: unknown stat_fn {}", stat_fn)),
    }
}

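/// Two-sample permutation test on the absolute difference of means. The
/// pooled values are re-shuffled `n_perms` times with a seeded Fisher-Yates
/// pass, and the p-value is the fraction of permutations whose statistic is
/// at least as extreme as the observed one. Returns `(observed_diff, p_value)`.
///
/// ```ignore
/// // group_a, group_b: &[f64] supplied by the caller
/// let (diff, p) = permutation_test(&group_a, &group_b, 10_000, 42)?;
/// // a small p suggests the group means differ by more than chance
/// ```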
pub fn permutation_test(x: &[f64], y: &[f64], n_perms: usize, seed: u64) -> Result<(f64, f64), String> {
    if x.is_empty() || y.is_empty() { return Err("permutation_test: empty group".into()); }

    let nx = x.len();
    let combined: Vec<f64> = x.iter().chain(y.iter()).copied().collect();
    let n = combined.len();

    let mean_x = compute_stat(x, 0)?;
    let mean_y = compute_stat(y, 0)?;
    let observed = (mean_x - mean_y).abs();

    let mut rng = cjc_repro::Rng::seeded(seed);
    let mut count_extreme = 0usize;
    let mut perm = combined.clone();

    for _ in 0..n_perms {
        for i in (1..n).rev() {
            let j = (rng.next_u64() as usize) % (i + 1);
            perm.swap(i, j);
        }
        let perm_mean_x = compute_stat(&perm[..nx], 0)?;
        let perm_mean_y = compute_stat(&perm[nx..], 0)?;
        if (perm_mean_x - perm_mean_y).abs() >= observed {
            count_extreme += 1;
        }
    }

    let p_value = count_extreme as f64 / n_perms as f64;
    Ok((observed, p_value))
}

pub fn stratified_split(labels: &[i64], test_frac: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
    use std::collections::BTreeMap;

    let n = labels.len();
    let mut groups: BTreeMap<i64, Vec<usize>> = BTreeMap::new();
    for (i, &label) in labels.iter().enumerate() {
        groups.entry(label).or_default().push(i);
    }

    let mut train = Vec::with_capacity(n);
    let mut test = Vec::with_capacity(n);
    let mut rng = cjc_repro::Rng::seeded(seed);

    for (_label, mut indices) in groups {
        let m = indices.len();
        for i in (1..m).rev() {
            let j = (rng.next_u64() as usize) % (i + 1);
            indices.swap(i, j);
        }
        let n_test = ((m as f64 * test_frac).round() as usize).max(if m > 1 { 1 } else { 0 });
        let n_test = n_test.min(m);
        test.extend_from_slice(&indices[..n_test]);
        train.extend_from_slice(&indices[n_test..]);
    }

    train.sort();
    test.sort();
    (train, test)
}

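/// Inference-mode batch normalization using stored running statistics:
/// `y_i = gamma_i * (x_i - mean_i) / sqrt(var_i + eps) + beta_i`.
/// This is the per-feature affine normalization only; batch statistics are
/// not recomputed here.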
pub fn batch_norm(
    x: &[f64],
    running_mean: &[f64],
    running_var: &[f64],
    gamma: &[f64],
    beta: &[f64],
    eps: f64,
) -> Result<Vec<f64>, String> {
    let n = x.len();
    if running_mean.len() != n || running_var.len() != n || gamma.len() != n || beta.len() != n {
        return Err("batch_norm: all arrays must have same length".into());
    }
    let mut result = Vec::with_capacity(n);
    for i in 0..n {
        let normed = (x[i] - running_mean[i]) / (running_var[i] + eps).sqrt();
        result.push(gamma[i] * normed + beta[i]);
    }
    Ok(result)
}

pub fn dropout_mask(n: usize, drop_prob: f64, seed: u64) -> Vec<f64> {
    let mut rng = cjc_repro::Rng::seeded(seed);
    let scale = if drop_prob < 1.0 { 1.0 / (1.0 - drop_prob) } else { 0.0 };
    let mut mask = Vec::with_capacity(n);
    for _ in 0..n {
        let r = (rng.next_u64() as f64) / (u64::MAX as f64);
        if r < drop_prob {
            mask.push(0.0);
        } else {
            mask.push(scale);
        }
    }
    mask
}

pub fn apply_dropout(data: &[f64], mask: &[f64]) -> Result<Vec<f64>, String> {
    if data.len() != mask.len() {
        return Err("apply_dropout: data and mask must have same length".into());
    }
    Ok(data.iter().zip(mask.iter()).map(|(&d, &m)| d * m).collect())
}

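/// Step decay: the learning rate is multiplied by `decay_rate` once every
/// `step_size` epochs, i.e. `lr(e) = initial_lr * decay_rate^(e / step_size)`
/// with integer division. For example, `lr_step_decay(0.1, 0.5, 20, 10)`
/// gives `0.1 * 0.5^2 = 0.025`.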
pub fn lr_step_decay(initial_lr: f64, decay_rate: f64, epoch: usize, step_size: usize) -> f64 {
    initial_lr * decay_rate.powi((epoch / step_size) as i32)
}

pub fn lr_cosine(max_lr: f64, min_lr: f64, epoch: usize, total_epochs: usize) -> f64 {
    let ratio = epoch as f64 / total_epochs as f64;
    min_lr + 0.5 * (max_lr - min_lr) * (1.0 + (std::f64::consts::PI * ratio).cos())
}

pub fn lr_linear_warmup(initial_lr: f64, epoch: usize, warmup_epochs: usize) -> f64 {
    if warmup_epochs == 0 {
        return initial_lr;
    }
    initial_lr * (epoch as f64 / warmup_epochs as f64).min(1.0)
}

pub fn l1_penalty(params: &[f64], lambda: f64) -> f64 {
    let mut acc = KahanAccumulatorF64::new();
    for &p in params {
        acc.add(p.abs());
    }
    lambda * acc.finalize()
}

pub fn l2_penalty(params: &[f64], lambda: f64) -> f64 {
    let mut acc = KahanAccumulatorF64::new();
    for &p in params {
        acc.add(p * p);
    }
    0.5 * lambda * acc.finalize()
}

pub fn l1_grad(params: &[f64], lambda: f64) -> Vec<f64> {
    params.iter().map(|&p| {
        if p > 0.0 { lambda } else if p < 0.0 { -lambda } else { 0.0 }
    }).collect()
}

pub fn l2_grad(params: &[f64], lambda: f64) -> Vec<f64> {
    params.iter().map(|&p| lambda * p).collect()
}

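/// Early stopping on a monitored loss: training is flagged to stop after
/// `patience` consecutive checks without an improvement of more than
/// `min_delta` over the best loss seen so far.
///
/// Illustrative sketch; `validation_loss()` stands in for the caller's own
/// evaluation routine:
///
/// ```ignore
/// let mut es = EarlyStoppingState::new(3, 0.01);
/// while !es.check(validation_loss()) {
///     // keep training; any improvement > 0.01 resets the patience counter
/// }
/// ```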
pub struct EarlyStoppingState {
    pub patience: usize,
    pub min_delta: f64,
    pub best_loss: f64,
    pub wait: usize,
    pub stopped: bool,
}

impl EarlyStoppingState {
    pub fn new(patience: usize, min_delta: f64) -> Self {
        Self {
            patience,
            min_delta,
            best_loss: f64::INFINITY,
            wait: 0,
            stopped: false,
        }
    }

    pub fn check(&mut self, current_loss: f64) -> bool {
        if current_loss < self.best_loss - self.min_delta {
            self.best_loss = current_loss;
            self.wait = 0;
        } else {
            self.wait += 1;
        }
        if self.wait >= self.patience {
            self.stopped = true;
        }
        self.stopped
    }
}

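/// PCA via SVD of the column-centered data matrix `X = U * S * V^T`:
/// the principal axes are the top `k` rows of `V^T`, the projected scores
/// are `U[:, ..k] * S[..k]`, and the explained-variance ratio of component
/// `i` is `s_i^2 / (n - 1)` divided by the total variance over all
/// components. Returns
/// `(scores [n_samples, k], components [k, n_features], ratios)`.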
pub fn pca(
    data: &Tensor,
    n_components: usize,
) -> Result<(Tensor, Tensor, Vec<f64>), RuntimeError> {
    if data.ndim() != 2 {
        return Err(RuntimeError::InvalidOperation(
            "PCA requires a 2D data matrix".to_string(),
        ));
    }
    let n_samples = data.shape()[0];
    let n_features = data.shape()[1];

    if n_samples == 0 || n_features == 0 {
        return Err(RuntimeError::InvalidOperation(
            "PCA: empty data matrix".to_string(),
        ));
    }
    if n_components == 0 || n_components > n_features.min(n_samples) {
        return Err(RuntimeError::InvalidOperation(format!(
            "PCA: n_components ({}) must be in [1, min(n_samples, n_features) = {}]",
            n_components,
            n_features.min(n_samples)
        )));
    }

    let raw = data.to_vec();

    let mut means = vec![0.0f64; n_features];
    for j in 0..n_features {
        let mut acc = BinnedAccumulatorF64::new();
        for i in 0..n_samples {
            acc.add(raw[i * n_features + j]);
        }
        means[j] = acc.finalize() / n_samples as f64;
    }

    let mut centered = vec![0.0f64; n_samples * n_features];
    for i in 0..n_samples {
        for j in 0..n_features {
            centered[i * n_features + j] = raw[i * n_features + j] - means[j];
        }
    }
    let centered_tensor = Tensor::from_vec(centered, &[n_samples, n_features])?;

    let (u, s, vt) = centered_tensor.svd()?;
    let k = n_components.min(s.len());

    let vt_data = vt.to_vec();
    let vt_cols = vt.shape()[1];
    let mut components = vec![0.0f64; k * n_features];
    for i in 0..k {
        for j in 0..n_features {
            components[i * n_features + j] = vt_data[i * vt_cols + j];
        }
    }

    let denom = if n_samples > 1 {
        (n_samples - 1) as f64
    } else {
        1.0
    };

    let mut total_var_acc = BinnedAccumulatorF64::new();
    for &si in &s {
        total_var_acc.add(si * si / denom);
    }
    let total_var = total_var_acc.finalize();

    let explained_variance_ratio: Vec<f64> = if total_var > 1e-15 {
        s[..k]
            .iter()
            .map(|&si| (si * si / denom) / total_var)
            .collect()
    } else {
        vec![0.0; k]
    };

    let u_data = u.to_vec();
    let u_cols = u.shape()[1];
    let mut transformed = vec![0.0f64; n_samples * k];
    for i in 0..n_samples {
        for j in 0..k {
            transformed[i * k + j] = u_data[i * u_cols + j] * s[j];
        }
    }

    Ok((
        Tensor::from_vec(transformed, &[n_samples, k])?,
        Tensor::from_vec(components, &[k, n_features])?,
        explained_variance_ratio,
    ))
}

pub struct LbfgsState {
    pub lr: f64,
    pub m: usize,
    pub s_history: Vec<Vec<f64>>,
    pub y_history: Vec<Vec<f64>>,
    pub prev_params: Option<Vec<f64>>,
    pub prev_grad: Option<Vec<f64>>,
}

impl LbfgsState {
    pub fn new(lr: f64, m: usize) -> Self {
        Self {
            lr,
            m,
            s_history: Vec::new(),
            y_history: Vec::new(),
            prev_params: None,
            prev_grad: None,
        }
    }
}

fn kahan_dot(a: &[f64], b: &[f64]) -> f64 {
    debug_assert_eq!(a.len(), b.len());
    let mut acc = KahanAccumulatorF64::new();
    for (&ai, &bi) in a.iter().zip(b.iter()) {
        acc.add(ai * bi);
    }
    acc.finalize()
}

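/// Line search enforcing the strong Wolfe conditions with `c1 = 1e-4`
/// (Armijo sufficient decrease) and `c2 = 0.9` (curvature):
/// `phi(a) <= phi(0) + c1 * a * phi'(0)` and `|phi'(a)| <= c2 * |phi'(0)|`.
/// The step is bracketed by doubling and then refined by bisection in
/// `wolfe_zoom`; if no conforming step is found within the iteration budget,
/// the best point evaluated so far is returned. Returns
/// `(alpha, params_at_alpha, f(params_at_alpha), grad_at_alpha)`.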
pub fn wolfe_line_search<F>(
    params: &[f64],
    direction: &[f64],
    f: &mut F,
    f0: f64,
    g0: &[f64],
    alpha_init: f64,
) -> (f64, Vec<f64>, f64, Vec<f64>)
where
    F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
    let c1 = 1e-4_f64;
    let c2 = 0.9_f64;
    let derphi0 = kahan_dot(g0, direction);

    let step = |alpha: f64| -> Vec<f64> {
        params.iter().zip(direction.iter()).map(|(&p, &d)| p + alpha * d).collect()
    };

    let max_iter = 30;
    let mut alpha_lo = 0.0_f64;
    let mut alpha_hi = f64::INFINITY;
    let mut phi_lo = f0;
    let mut dphi_lo = derphi0;
    let mut alpha = alpha_init;

    let mut best_alpha = alpha_init;
    let mut best_params = step(alpha_init);
    let (mut best_val, mut best_grad) = f(&best_params);

    for _iter in 0..max_iter {
        let x_alpha = step(alpha);
        let (phi_alpha, grad_alpha) = f(&x_alpha);
        let dphi_alpha = kahan_dot(&grad_alpha, direction);

        if phi_alpha < best_val {
            best_alpha = alpha;
            best_params = x_alpha.clone();
            best_val = phi_alpha;
            best_grad = grad_alpha.clone();
        }

        if phi_alpha > f0 + c1 * alpha * derphi0 || (phi_alpha >= phi_lo && alpha_lo > 0.0) {
            alpha_hi = alpha;
            let (za, zp, zv, zg) = wolfe_zoom(
                params, direction, f, f0, derphi0, c1, c2,
                alpha_lo, alpha_hi, phi_lo, dphi_lo,
            );
            return (za, zp, zv, zg);
        }

        if dphi_alpha.abs() <= c2 * derphi0.abs() {
            return (alpha, x_alpha, phi_alpha, grad_alpha);
        }

        if dphi_alpha >= 0.0 {
            let (za, zp, zv, zg) = wolfe_zoom(
                params, direction, f, f0, derphi0, c1, c2,
                alpha, alpha_lo, phi_alpha, dphi_alpha,
            );
            return (za, zp, zv, zg);
        }

        alpha_lo = alpha;
        phi_lo = phi_alpha;
        dphi_lo = dphi_alpha;

        alpha = if alpha_hi.is_finite() {
            (alpha_lo + alpha_hi) * 0.5
        } else {
            (alpha * 2.0).min(1e8)
        };
    }

    (best_alpha, best_params, best_val, best_grad)
}

#[allow(clippy::too_many_arguments)]
fn wolfe_zoom<F>(
    params: &[f64],
    direction: &[f64],
    f: &mut F,
    f0: f64,
    derphi0: f64,
    c1: f64,
    c2: f64,
    mut alpha_lo: f64,
    mut alpha_hi: f64,
    mut phi_lo: f64,
    _dphi_lo: f64,
) -> (f64, Vec<f64>, f64, Vec<f64>)
where
    F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
    let step = |alpha: f64| -> Vec<f64> {
        params.iter().zip(direction.iter()).map(|(&p, &d)| p + alpha * d).collect()
    };

    let max_zoom = 20;
    let mut best_alpha = alpha_lo;
    let mut best_x = step(alpha_lo);
    let (mut best_val, mut best_grad) = f(&best_x);

    for _i in 0..max_zoom {
        let alpha_j = (alpha_lo + alpha_hi) * 0.5;
        let x_j = step(alpha_j);
        let (phi_j, grad_j) = f(&x_j);
        let dphi_j = kahan_dot(&grad_j, direction);

        if phi_j < best_val {
            best_alpha = alpha_j;
            best_x = x_j.clone();
            best_val = phi_j;
            best_grad = grad_j.clone();
        }

        if phi_j > f0 + c1 * alpha_j * derphi0 || phi_j >= phi_lo {
            alpha_hi = alpha_j;
        } else {
            if dphi_j.abs() <= c2 * derphi0.abs() {
                return (alpha_j, x_j, phi_j, grad_j);
            }
            if dphi_j * (alpha_hi - alpha_lo) >= 0.0 {
                alpha_hi = alpha_lo;
            }
            alpha_lo = alpha_j;
            phi_lo = phi_j;
        }

        if (alpha_hi - alpha_lo).abs() < 1e-14 {
            break;
        }
    }

    (best_alpha, best_x, best_val, best_grad)
}

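/// One L-BFGS step using the standard two-loop recursion over the last `m`
/// `(s, y)` curvature pairs, with the usual `(s.y / y.y)` initial Hessian
/// scaling. If the resulting direction is not a descent direction (or is
/// non-finite), it falls back to normalized steepest descent; the step
/// length comes from `wolfe_line_search`, and a curvature pair is only
/// stored when `s.y > 0` so the inverse-Hessian estimate stays positive
/// definite. Returns `(new_params, new_grads, used_lbfgs_direction)`.
///
/// ```ignore
/// // f returns (value, gradient); see the Rosenbrock test below for a full loop.
/// let (p1, g1, _) = lbfgs_step(&params, &grads, &mut state, f);
/// ```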
pub fn lbfgs_step<F>(
    params: &[f64],
    grads: &[f64],
    state: &mut LbfgsState,
    mut f: F,
) -> (Vec<f64>, Vec<f64>, bool)
where
    F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
    let n = params.len();
    debug_assert_eq!(grads.len(), n);

    let hist_len = state.s_history.len();
    let mut q: Vec<f64> = grads.to_vec();
    let mut alphas = vec![0.0_f64; hist_len];
    let mut rhos = vec![0.0_f64; hist_len];

    for i in (0..hist_len).rev() {
        let sy = kahan_dot(&state.s_history[i], &state.y_history[i]);
        rhos[i] = if sy.abs() < 1e-300 { 0.0 } else { 1.0 / sy };
        let sq = kahan_dot(&state.s_history[i], &q);
        alphas[i] = rhos[i] * sq;
        for j in 0..n {
            q[j] -= alphas[i] * state.y_history[i][j];
        }
    }

    let scale = if hist_len > 0 {
        let last = hist_len - 1;
        let sy = kahan_dot(&state.s_history[last], &state.y_history[last]);
        let yy = kahan_dot(&state.y_history[last], &state.y_history[last]);
        if yy.abs() < 1e-300 { 1.0 } else { sy / yy }
    } else {
        1.0
    };
    let mut r: Vec<f64> = q.iter().map(|&qi| scale * qi).collect();

    for i in 0..hist_len {
        let yr = kahan_dot(&state.y_history[i], &r);
        let beta = rhos[i] * yr;
        let diff = alphas[i] - beta;
        for j in 0..n {
            r[j] += diff * state.s_history[i][j];
        }
    }

    let direction: Vec<f64> = r.iter().map(|&ri| -ri).collect();

    let descent_check = kahan_dot(&direction, grads);
    let (direction, is_descent) = if descent_check >= 0.0 || !descent_check.is_finite() {
        let norm_g = kahan_dot(grads, grads).sqrt().max(1e-300);
        (grads.iter().map(|&g| -g / norm_g).collect::<Vec<f64>>(), false)
    } else {
        (direction, true)
    };

    let (f0, _) = f(params);
    let (_, new_params, _, new_grads) = wolfe_line_search(
        params,
        &direction,
        &mut f,
        f0,
        grads,
        state.lr,
    );

    let s_k: Vec<f64> = new_params.iter().zip(params.iter()).map(|(&np, &p)| np - p).collect();
    let y_k: Vec<f64> = new_grads.iter().zip(grads.iter()).map(|(&ng, &g)| ng - g).collect();

    let sy = kahan_dot(&s_k, &y_k);
    if sy > 1e-300 {
        state.s_history.push(s_k);
        state.y_history.push(y_k);
        if state.s_history.len() > state.m {
            state.s_history.remove(0);
            state.y_history.remove(0);
        }
    }

    state.prev_params = Some(new_params.clone());
    state.prev_grad = Some(new_grads.clone());

    (new_params, new_grads, is_descent)
}

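/// A single LSTM cell step with packed gates in `[i, f, g, o]` order
/// (PyTorch-style layout). With the fused projections chunked into the four
/// gates: `i = sigmoid(..)`, `f = sigmoid(..)`, `g = tanh(..)`,
/// `o = sigmoid(..)`, then `c' = f * c + i * g` and `h' = o * tanh(c')`
/// (elementwise). Weights are `[4*hidden_size, in_dim]`.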
pub fn lstm_cell(
    x: &Tensor,
    h_prev: &Tensor,
    c_prev: &Tensor,
    w_ih: &Tensor,
    w_hh: &Tensor,
    b_ih: &Tensor,
    b_hh: &Tensor,
) -> Result<(Tensor, Tensor), String> {
    let map_err = |e: crate::error::RuntimeError| format!("{e}");

    if x.ndim() != 2 {
        return Err("lstm_cell: x must be 2-D [batch, input_size]".into());
    }
    if h_prev.ndim() != 2 {
        return Err("lstm_cell: h_prev must be 2-D [batch, hidden_size]".into());
    }
    if c_prev.ndim() != 2 {
        return Err("lstm_cell: c_prev must be 2-D [batch, hidden_size]".into());
    }
    let hidden_size = h_prev.shape()[1];
    if w_ih.ndim() != 2 || w_ih.shape()[0] != 4 * hidden_size {
        return Err(format!(
            "lstm_cell: w_ih must be [4*hidden_size, input_size], got {:?}",
            w_ih.shape()
        ));
    }
    if w_hh.ndim() != 2 || w_hh.shape()[0] != 4 * hidden_size {
        return Err(format!(
            "lstm_cell: w_hh must be [4*hidden_size, hidden_size], got {:?}",
            w_hh.shape()
        ));
    }
    if b_ih.len() != 4 * hidden_size {
        return Err(format!(
            "lstm_cell: b_ih must have length 4*hidden_size={}, got {}",
            4 * hidden_size,
            b_ih.len()
        ));
    }
    if b_hh.len() != 4 * hidden_size {
        return Err(format!(
            "lstm_cell: b_hh must have length 4*hidden_size={}, got {}",
            4 * hidden_size,
            b_hh.len()
        ));
    }

    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
    let gates = gates_ih.add(&gates_hh).map_err(map_err)?;

    let chunks = gates.chunk(4, 1).map_err(map_err)?;
    let gates_i = &chunks[0];
    let gates_f = &chunks[1];
    let gates_g = &chunks[2];
    let gates_o = &chunks[3];

    let i = gates_i.sigmoid();
    let f = gates_f.sigmoid();
    let g = gates_g.tanh_activation();
    let o = gates_o.sigmoid();

    let fc = f.mul_elem(c_prev).map_err(map_err)?;
    let ig = i.mul_elem(&g).map_err(map_err)?;
    let c_new = fc.add(&ig).map_err(map_err)?;

    let c_tanh = c_new.tanh_activation();
    let h_new = o.mul_elem(&c_tanh).map_err(map_err)?;

    Ok((h_new, c_new))
}

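/// A single GRU cell step with packed `[r, z, n]` gates:
/// `r = sigmoid(W_ir x + b_ir + W_hr h + b_hr)`,
/// `z = sigmoid(W_iz x + b_iz + W_hz h + b_hz)`,
/// `n = tanh(W_in x + b_in + r * (W_hn h + b_hn))`,
/// `h' = (1 - z) * n + z * h` (elementwise).
/// Note the reset gate multiplies only the hidden contribution to `n`,
/// matching the common PyTorch-style formulation.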
pub fn gru_cell(
    x: &Tensor,
    h_prev: &Tensor,
    w_ih: &Tensor,
    w_hh: &Tensor,
    b_ih: &Tensor,
    b_hh: &Tensor,
) -> Result<Tensor, String> {
    let map_err = |e: crate::error::RuntimeError| format!("{e}");

    if x.ndim() != 2 {
        return Err("gru_cell: x must be 2-D [batch, input_size]".into());
    }
    if h_prev.ndim() != 2 {
        return Err("gru_cell: h_prev must be 2-D [batch, hidden_size]".into());
    }
    let hidden_size = h_prev.shape()[1];
    if w_ih.ndim() != 2 || w_ih.shape()[0] != 3 * hidden_size {
        return Err(format!(
            "gru_cell: w_ih must be [3*hidden_size, input_size], got {:?}",
            w_ih.shape()
        ));
    }
    if w_hh.ndim() != 2 || w_hh.shape()[0] != 3 * hidden_size {
        return Err(format!(
            "gru_cell: w_hh must be [3*hidden_size, hidden_size], got {:?}",
            w_hh.shape()
        ));
    }
    if b_ih.len() != 3 * hidden_size {
        return Err(format!(
            "gru_cell: b_ih must have length 3*hidden_size={}, got {}",
            3 * hidden_size,
            b_ih.len()
        ));
    }
    if b_hh.len() != 3 * hidden_size {
        return Err(format!(
            "gru_cell: b_hh must have length 3*hidden_size={}, got {}",
            3 * hidden_size,
            b_hh.len()
        ));
    }

    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;

    let ih_chunks = gates_ih.chunk(3, 1).map_err(map_err)?;
    let hh_chunks = gates_hh.chunk(3, 1).map_err(map_err)?;
    let r_ih = &ih_chunks[0];
    let z_ih = &ih_chunks[1];
    let n_ih = &ih_chunks[2];
    let r_hh = &hh_chunks[0];
    let z_hh = &hh_chunks[1];
    let n_hh = &hh_chunks[2];

    let r = r_ih.add(r_hh).map_err(map_err)?.sigmoid();
    let z = z_ih.add(z_hh).map_err(map_err)?.sigmoid();
    let r_n_hh = r.mul_elem(n_hh).map_err(map_err)?;
    let n = n_ih.add(&r_n_hh).map_err(map_err)?.tanh_activation();

    let ones = Tensor::ones(z.shape());
    let one_minus_z = ones.sub(&z).map_err(map_err)?;
    let term1 = one_minus_z.mul_elem(&n).map_err(map_err)?;
    let term2 = z.mul_elem(h_prev).map_err(map_err)?;
    let h_new = term1.add(&term2).map_err(map_err)?;

    Ok(h_new)
}

pub fn lstm_cell_fused(
    x: &Tensor,
    h_prev: &Tensor,
    c_prev: &Tensor,
    w_ih: &Tensor,
    w_hh: &Tensor,
    b_ih: &Tensor,
    b_hh: &Tensor,
) -> Result<(Tensor, Tensor), String> {
    let map_err = |e: crate::error::RuntimeError| format!("{e}");

    if x.ndim() != 2 {
        return Err("lstm_cell_fused: x must be 2-D [batch, input_size]".into());
    }
    if h_prev.ndim() != 2 {
        return Err("lstm_cell_fused: h_prev must be 2-D [batch, hidden_size]".into());
    }
    if c_prev.ndim() != 2 {
        return Err("lstm_cell_fused: c_prev must be 2-D [batch, hidden_size]".into());
    }
    let batch = x.shape()[0];
    let hidden_size = h_prev.shape()[1];
    if w_ih.ndim() != 2 || w_ih.shape()[0] != 4 * hidden_size {
        return Err(format!(
            "lstm_cell_fused: w_ih must be [4*hidden_size, input_size], got {:?}",
            w_ih.shape()
        ));
    }
    if w_hh.ndim() != 2 || w_hh.shape()[0] != 4 * hidden_size {
        return Err(format!(
            "lstm_cell_fused: w_hh must be [4*hidden_size, hidden_size], got {:?}",
            w_hh.shape()
        ));
    }
    if b_ih.len() != 4 * hidden_size {
        return Err(format!(
            "lstm_cell_fused: b_ih must have length 4*hidden_size={}, got {}",
            4 * hidden_size,
            b_ih.len()
        ));
    }
    if b_hh.len() != 4 * hidden_size {
        return Err(format!(
            "lstm_cell_fused: b_hh must have length 4*hidden_size={}, got {}",
            4 * hidden_size,
            b_hh.len()
        ));
    }

    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;

    let gih = gates_ih.to_vec();
    let ghh = gates_hh.to_vec();
    let cprev = c_prev.to_vec();

    let mut h_new_data = vec![0.0f64; batch * hidden_size];
    let mut c_new_data = vec![0.0f64; batch * hidden_size];

    for b_idx in 0..batch {
        let base = b_idx * 4 * hidden_size;
        for h in 0..hidden_size {
            let gi = gih[base + h] + ghh[base + h];
            let gf = gih[base + hidden_size + h] + ghh[base + hidden_size + h];
            let gg = gih[base + 2 * hidden_size + h] + ghh[base + 2 * hidden_size + h];
            let go = gih[base + 3 * hidden_size + h] + ghh[base + 3 * hidden_size + h];

            let i_val = 1.0 / (1.0 + (-gi).exp());
            let f_val = 1.0 / (1.0 + (-gf).exp());
            let g_val = gg.tanh();
            let o_val = 1.0 / (1.0 + (-go).exp());

            let c_idx = b_idx * hidden_size + h;
            let c_val = f_val * cprev[c_idx] + i_val * g_val;
            c_new_data[c_idx] = c_val;
            h_new_data[c_idx] = o_val * c_val.tanh();
        }
    }

    let h_new = Tensor::from_vec(h_new_data, &[batch, hidden_size]).map_err(map_err)?;
    let c_new = Tensor::from_vec(c_new_data, &[batch, hidden_size]).map_err(map_err)?;

    Ok((h_new, c_new))
}

pub fn gru_cell_fused(
    x: &Tensor,
    h_prev: &Tensor,
    w_ih: &Tensor,
    w_hh: &Tensor,
    b_ih: &Tensor,
    b_hh: &Tensor,
) -> Result<Tensor, String> {
    let map_err = |e: crate::error::RuntimeError| format!("{e}");

    if x.ndim() != 2 {
        return Err("gru_cell_fused: x must be 2-D [batch, input_size]".into());
    }
    if h_prev.ndim() != 2 {
        return Err("gru_cell_fused: h_prev must be 2-D [batch, hidden_size]".into());
    }
    let batch = x.shape()[0];
    let hidden_size = h_prev.shape()[1];
    if w_ih.ndim() != 2 || w_ih.shape()[0] != 3 * hidden_size {
        return Err(format!(
            "gru_cell_fused: w_ih must be [3*hidden_size, input_size], got {:?}",
            w_ih.shape()
        ));
    }
    if w_hh.ndim() != 2 || w_hh.shape()[0] != 3 * hidden_size {
        return Err(format!(
            "gru_cell_fused: w_hh must be [3*hidden_size, hidden_size], got {:?}",
            w_hh.shape()
        ));
    }
    if b_ih.len() != 3 * hidden_size {
        return Err(format!(
            "gru_cell_fused: b_ih must have length 3*hidden_size={}, got {}",
            3 * hidden_size,
            b_ih.len()
        ));
    }
    if b_hh.len() != 3 * hidden_size {
        return Err(format!(
            "gru_cell_fused: b_hh must have length 3*hidden_size={}, got {}",
            3 * hidden_size,
            b_hh.len()
        ));
    }

    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;

    let gih = gates_ih.to_vec();
    let ghh = gates_hh.to_vec();
    let hp = h_prev.to_vec();

    let mut h_new_data = vec![0.0f64; batch * hidden_size];

    for b_idx in 0..batch {
        let base = b_idx * 3 * hidden_size;
        for h in 0..hidden_size {
            let r_val = 1.0 / (1.0 + (-(gih[base + h] + ghh[base + h])).exp());
            let z_val = 1.0
                / (1.0
                    + (-(gih[base + hidden_size + h]
                        + ghh[base + hidden_size + h]))
                    .exp());
            let n_val = (gih[base + 2 * hidden_size + h]
                + r_val * ghh[base + 2 * hidden_size + h])
                .tanh();

            let h_idx = b_idx * hidden_size + h;
            h_new_data[h_idx] = (1.0 - z_val) * n_val + z_val * hp[h_idx];
        }
    }

    Tensor::from_vec(h_new_data, &[batch, hidden_size])
        .map_err(|e| format!("{e}"))
}

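/// Multi-head attention: project `q`, `k`, `v` with their weight matrices,
/// split the model dimension into `num_heads` heads, apply scaled
/// dot-product attention `softmax(Q K^T / sqrt(d_head)) V` per head, merge
/// the heads back, and apply the output projection. Inputs are
/// `[batch, seq, model_dim]`; `model_dim` is assumed to be divisible by
/// `num_heads`.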
pub fn multi_head_attention(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    w_q: &Tensor,
    w_k: &Tensor,
    w_v: &Tensor,
    w_o: &Tensor,
    b_q: &Tensor,
    b_k: &Tensor,
    b_v: &Tensor,
    b_o: &Tensor,
    num_heads: usize,
) -> Result<Tensor, String> {
    let map_err = |e: crate::error::RuntimeError| format!("{e}");

    if q.ndim() != 3 {
        return Err("multi_head_attention: q must be 3-D [batch, seq, model_dim]".into());
    }

    let q_proj = q.linear(w_q, b_q).map_err(map_err)?;
    let k_proj = k.linear(w_k, b_k).map_err(map_err)?;
    let v_proj = v.linear(w_v, b_v).map_err(map_err)?;

    let q_heads = q_proj.split_heads(num_heads).map_err(map_err)?;
    let k_heads = k_proj.split_heads(num_heads).map_err(map_err)?;
    let v_heads = v_proj.split_heads(num_heads).map_err(map_err)?;

    let attn = Tensor::scaled_dot_product_attention(&q_heads, &k_heads, &v_heads)
        .map_err(map_err)?;

    let merged = attn.merge_heads().map_err(map_err)?;

    let output = merged.linear(w_o, b_o).map_err(map_err)?;

    Ok(output)
}

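/// Embedding lookup: gathers row `indices[i]` of the `[vocab_size,
/// embed_dim]` weight matrix into row `i` of the output. Out-of-range
/// indices (including negative ones, which wrap to huge `usize` values when
/// cast) are rejected with an error.
///
/// ```ignore
/// // rows 0 and 2 of a 3 x 2 table
/// let out = embedding(&weight, &[0, 2])?; // shape [2, 2]
/// ```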
pub fn embedding(weight: &crate::tensor::Tensor, indices: &[i64]) -> Result<crate::tensor::Tensor, String> {
    let shape = weight.shape();
    if shape.len() != 2 {
        return Err(format!("embedding: weight must be 2-D [vocab_size, embed_dim], got {:?}", shape));
    }
    let vocab_size = shape[0];
    let embed_dim = shape[1];
    let weight_data = weight.to_vec();

    let mut out = Vec::with_capacity(indices.len() * embed_dim);
    for &idx in indices {
        let i = idx as usize;
        if i >= vocab_size {
            return Err(format!("embedding: index {} out of bounds for vocab_size {}", idx, vocab_size));
        }
        let start = i * embed_dim;
        out.extend_from_slice(&weight_data[start..start + embed_dim]);
    }
    crate::tensor::Tensor::from_vec(out, &[indices.len(), embed_dim])
        .map_err(|e| e.to_string())
}

pub fn batch_indices(dataset_size: usize, batch_size: usize, seed: u64) -> Vec<(usize, usize)> {
    use cjc_repro::Rng;
    let mut rng = Rng::seeded(seed);
    let mut indices: Vec<usize> = (0..dataset_size).collect();
    for i in (1..dataset_size).rev() {
        let j = (rng.next_u64() as usize) % (i + 1);
        indices.swap(i, j);
    }
    // Note: the shuffled `indices` are not reflected in the returned ranges;
    // the (start, end) pairs below simply partition 0..dataset_size in order,
    // so `seed` currently has no observable effect on the output.
    let mut batches = Vec::new();
    let mut i = 0;
    while i < dataset_size {
        let end = (i + batch_size).min(dataset_size);
        batches.push((i, end));
        i = end;
    }
    batches
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mse_zero() {
        let pred = [1.0, 2.0, 3.0];
        let target = [1.0, 2.0, 3.0];
        assert_eq!(mse_loss(&pred, &target).unwrap(), 0.0);
    }

    #[test]
    fn test_mse_basic() {
        let pred = [1.0, 2.0, 3.0];
        let target = [2.0, 3.0, 4.0];
        assert_eq!(mse_loss(&pred, &target).unwrap(), 1.0);
    }

    #[test]
    fn test_huber_loss_quadratic() {
        let pred = [1.0];
        let target = [1.5];
        let h = huber_loss(&pred, &target, 1.0).unwrap();
        assert!((h - 0.125).abs() < 1e-12);
    }

    #[test]
    fn test_sgd_step() {
        let mut params = [1.0, 2.0];
        let grads = [0.1, 0.2];
        let mut state = SgdState::new(2, 0.1, 0.0);
        sgd_step(&mut params, &grads, &mut state);
        assert!((params[0] - 0.99).abs() < 1e-12);
        assert!((params[1] - 1.98).abs() < 1e-12);
    }

    #[test]
    fn test_adam_step() {
        let mut params = [1.0, 2.0];
        let grads = [0.1, 0.2];
        let mut state = AdamState::new(2, 0.001);
        adam_step(&mut params, &grads, &mut state);
        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
    }

    #[test]
    fn test_confusion_matrix() {
        let pred = [true, true, false, false, true];
        let actual = [true, false, true, false, true];
        let cm = confusion_matrix(&pred, &actual);
        assert_eq!(cm.tp, 2);
        assert_eq!(cm.fp, 1);
        assert_eq!(cm.fn_count, 1);
        assert_eq!(cm.tn, 1);
    }

    #[test]
    fn test_precision_recall_f1() {
        let cm = ConfusionMatrix { tp: 5, fp: 2, tn: 8, fn_count: 1 };
        assert!((precision(&cm) - 5.0 / 7.0).abs() < 1e-12);
        assert!((recall(&cm) - 5.0 / 6.0).abs() < 1e-12);
    }

    #[test]
    fn test_auc_perfect() {
        let scores = [0.9, 0.8, 0.2, 0.1];
        let labels = [true, true, false, false];
        let auc = auc_roc(&scores, &labels).unwrap();
        assert!((auc - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_kfold_deterministic() {
        let f1 = kfold_indices(100, 5, 42);
        let f2 = kfold_indices(100, 5, 42);
        for i in 0..5 {
            assert_eq!(f1[i].0, f2[i].0);
            assert_eq!(f1[i].1, f2[i].1);
        }
    }

    #[test]
    fn test_train_test_split_coverage() {
        let (train, test) = train_test_split(100, 0.2, 42);
        assert_eq!(train.len() + test.len(), 100);
        assert_eq!(test.len(), 20);
    }

    #[test]
    fn test_batch_norm_identity() {
        let x = vec![1.0, 2.0, 3.0];
        let mean = vec![0.0, 0.0, 0.0];
        let var = vec![1.0, 1.0, 1.0];
        let gamma = vec![1.0, 1.0, 1.0];
        let beta = vec![0.0, 0.0, 0.0];
        let result = batch_norm(&x, &mean, &var, &gamma, &beta, 0.0).unwrap();
        assert!((result[0] - 1.0).abs() < 1e-12);
        assert!((result[1] - 2.0).abs() < 1e-12);
        assert!((result[2] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_batch_norm_shift_scale() {
        let x = vec![0.0];
        let mean = vec![1.0];
        let var = vec![4.0];
        let gamma = vec![2.0];
        let beta = vec![3.0];
        let result = batch_norm(&x, &mean, &var, &gamma, &beta, 0.0).unwrap();
        assert!((result[0] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_dropout_mask_seed_determinism() {
        let m1 = dropout_mask(100, 0.5, 42);
        let m2 = dropout_mask(100, 0.5, 42);
        assert_eq!(m1, m2);
    }

    #[test]
    fn test_dropout_mask_different_seeds() {
        let m1 = dropout_mask(100, 0.5, 42);
        let m2 = dropout_mask(100, 0.5, 99);
        assert_ne!(m1, m2);
    }

    #[test]
    fn test_lr_step_decay_schedule() {
        let lr0 = lr_step_decay(0.1, 0.5, 0, 10);
        assert!((lr0 - 0.1).abs() < 1e-12);
        let lr10 = lr_step_decay(0.1, 0.5, 10, 10);
        assert!((lr10 - 0.05).abs() < 1e-12);
        let lr20 = lr_step_decay(0.1, 0.5, 20, 10);
        assert!((lr20 - 0.025).abs() < 1e-12);
    }

    #[test]
    fn test_lr_cosine_endpoints() {
        let lr0 = lr_cosine(0.1, 0.001, 0, 100);
        assert!((lr0 - 0.1).abs() < 1e-10);
        let lr_end = lr_cosine(0.1, 0.001, 100, 100);
        assert!((lr_end - 0.001).abs() < 1e-10);
    }

    #[test]
    fn test_lr_linear_warmup() {
        let lr0 = lr_linear_warmup(0.1, 0, 10);
        assert!(lr0.abs() < 1e-12);
        let lr5 = lr_linear_warmup(0.1, 5, 10);
        assert!((lr5 - 0.05).abs() < 1e-12);
        let lr15 = lr_linear_warmup(0.1, 15, 10);
        assert!((lr15 - 0.1).abs() < 1e-12);
    }

    #[test]
    fn test_l1_penalty_known() {
        let params = [1.0, -2.0, 3.0];
        let p = l1_penalty(&params, 0.1);
        assert!((p - 0.6).abs() < 1e-12);
    }

    #[test]
    fn test_l2_penalty_known() {
        let params = [1.0, -2.0, 3.0];
        let p = l2_penalty(&params, 0.1);
        assert!((p - 0.7).abs() < 1e-12);
    }

    #[test]
    fn test_early_stopping_triggers() {
        let mut es = EarlyStoppingState::new(3, 0.01);
        assert!(!es.check(1.0));
        assert!(!es.check(1.0));
        assert!(!es.check(1.0));
        assert!(es.check(1.0));
    }

    #[test]
    fn test_early_stopping_resets() {
        let mut es = EarlyStoppingState::new(3, 0.01);
        es.check(1.0);
        es.check(1.0);
        assert!(!es.check(0.5));
        assert!(!es.check(0.5));
    }

    #[test]
    fn test_pca_basic_2d() {
        let data = Tensor::from_vec(
            vec![
                1.0, 0.1,
                2.0, 0.2,
                3.0, 0.3,
                4.0, 0.4,
            ],
            &[4, 2],
        )
        .unwrap();
        let (transformed, components, evr) = pca(&data, 2).unwrap();
        assert_eq!(transformed.shape(), &[4, 2]);
        assert_eq!(components.shape(), &[2, 2]);
        assert_eq!(evr.len(), 2);
        let total: f64 = evr.iter().sum();
        assert!(
            (total - 1.0).abs() < 1e-8,
            "explained variance ratios sum to {} (expected ~1.0)",
            total
        );
        assert!(evr[0] > 0.9, "first component explains {} of variance", evr[0]);
    }

    #[test]
    fn test_pca_single_component() {
        let data = Tensor::from_vec(
            vec![
                1.0, 2.0, 3.0,
                4.0, 5.0, 6.0,
                7.0, 8.0, 9.0,
            ],
            &[3, 3],
        )
        .unwrap();
        let (transformed, components, evr) = pca(&data, 1).unwrap();
        assert_eq!(transformed.shape(), &[3, 1]);
        assert_eq!(components.shape(), &[1, 3]);
        assert_eq!(evr.len(), 1);
        assert!(evr[0] > 0.0 && evr[0] <= 1.0);
    }

    #[test]
    fn test_pca_explained_variance_ratio_bounded() {
        let data = Tensor::from_vec(
            vec![
                1.0, 0.0, 0.5,
                0.0, 1.0, 0.5,
                1.0, 1.0, 1.0,
                2.0, 0.0, 1.0,
                0.0, 2.0, 1.0,
            ],
            &[5, 3],
        )
        .unwrap();
        let (_, _, evr) = pca(&data, 3).unwrap();
        let total: f64 = evr.iter().sum();
        assert!(
            total <= 1.0 + 1e-10,
            "explained variance ratios sum to {} (should be <= 1.0)",
            total
        );
        for &r in &evr {
            assert!(r >= -1e-10, "negative explained variance ratio: {}", r);
        }
    }

    #[test]
    fn test_pca_deterministic() {
        let data = Tensor::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            &[2, 3],
        )
        .unwrap();
        let (t1, c1, e1) = pca(&data, 2).unwrap();
        let (t2, c2, e2) = pca(&data, 2).unwrap();
        assert_eq!(t1.to_vec(), t2.to_vec(), "PCA transformed not deterministic");
        assert_eq!(c1.to_vec(), c2.to_vec(), "PCA components not deterministic");
        assert_eq!(e1, e2, "PCA explained variance not deterministic");
    }

    #[test]
    fn test_pca_invalid_n_components() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert!(pca(&data, 0).is_err(), "n_components=0 should fail");
        assert!(pca(&data, 3).is_err(), "n_components > min(n,p) should fail");
    }

    fn rosenbrock(p: &[f64]) -> (f64, Vec<f64>) {
        let x = p[0];
        let y = p[1];
        let a = 1.0 - x;
        let b = y - x * x;
        let val = a * a + 100.0 * b * b;
        let gx = -2.0 * a - 400.0 * x * b;
        let gy = 200.0 * b;
        (val, vec![gx, gy])
    }

    #[test]
    fn test_lbfgs_rosenbrock_converges() {
        let mut params = vec![-1.0_f64, 2.0_f64];
        let mut state = LbfgsState::new(0.5, 10);

        let mut converged = false;
        for _iter in 0..200 {
            let (_, grads) = rosenbrock(&params);
            let grad_norm: f64 = kahan_dot(&grads, &grads).sqrt();
            if grad_norm < 1e-5 {
                converged = true;
                break;
            }
            let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, rosenbrock);
            params = new_p;
        }
        assert!(converged, "L-BFGS did not converge on Rosenbrock; params = {:?}", params);
        assert!(
            (params[0] - 1.0).abs() < 1e-3,
            "x should converge near 1.0, got {}",
            params[0]
        );
        assert!(
            (params[1] - 1.0).abs() < 1e-3,
            "y should converge near 1.0, got {}",
            params[1]
        );
    }

    #[test]
    fn test_lbfgs_determinism() {
        let init = vec![-1.0_f64, 2.0_f64];

        let run = |init: &[f64]| -> Vec<f64> {
            let mut params = init.to_vec();
            let mut state = LbfgsState::new(0.5, 10);
            for _ in 0..20 {
                let (_, grads) = rosenbrock(&params);
                let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, rosenbrock);
                params = new_p;
            }
            params
        };

        let r1 = run(&init);
        let r2 = run(&init);
        assert_eq!(r1, r2, "L-BFGS must be bit-identical across runs");
    }

    #[test]
    fn test_lbfgs_simple_quadratic() {
        let mut params = vec![3.0_f64];
        let mut state = LbfgsState::new(1.0, 5);
        let quadratic = |p: &[f64]| -> (f64, Vec<f64>) {
            (p[0] * p[0], vec![2.0 * p[0]])
        };

        for _ in 0..30 {
            let (_, grads) = quadratic(&params);
            let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, quadratic);
            params = new_p;
        }
        assert!(
            params[0].abs() < 1e-6,
            "L-BFGS should minimize x^2 to ~0, got {}",
            params[0]
        );
    }

    #[test]
    fn test_embedding_basic() {
        let weight = crate::tensor::Tensor::from_vec(
            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            &[3, 2],
        ).unwrap();
        let indices = vec![0, 2, 1];
        let result = super::embedding(&weight, &indices).unwrap();
        assert_eq!(result.shape(), &[3, 2]);
        let data = result.to_vec();
        assert!((data[0] - 0.1).abs() < 1e-12);
        assert!((data[1] - 0.2).abs() < 1e-12);
        assert!((data[2] - 0.5).abs() < 1e-12);
        assert!((data[3] - 0.6).abs() < 1e-12);
        assert!((data[4] - 0.3).abs() < 1e-12);
        assert!((data[5] - 0.4).abs() < 1e-12);
    }

    #[test]
    fn test_embedding_out_of_bounds() {
        let weight = crate::tensor::Tensor::from_vec(vec![1.0, 2.0], &[1, 2]).unwrap();
        let result = super::embedding(&weight, &[1]);
        assert!(result.is_err());
    }

    #[test]
    fn test_batch_indices_deterministic() {
        let b1 = super::batch_indices(10, 3, 42);
        let b2 = super::batch_indices(10, 3, 42);
        assert_eq!(b1, b2);
        let total: usize = b1.iter().map(|(s, e)| e - s).sum();
        assert_eq!(total, 10);
    }

    #[test]
    fn test_wolfe_line_search_armijo() {
        let params = vec![3.0_f64];
        let direction = vec![-1.0_f64];
        let grads = vec![6.0_f64]; // gradient of x^2 at x = 3
        let f0 = 9.0;

        let mut eval_count = 0;
        let mut obj = |p: &[f64]| -> (f64, Vec<f64>) {
            eval_count += 1;
            (p[0] * p[0], vec![2.0 * p[0]])
        };

        let (alpha, new_params, new_val, _) =
            wolfe_line_search(&params, &direction, &mut obj, f0, &grads, 1.0);

        let c1 = 1e-4;
        let derphi0 = kahan_dot(&grads, &direction);
        assert!(
            new_val <= f0 + c1 * alpha * derphi0,
            "Armijo condition violated: {} > {} + {} * {} * {}",
            new_val, f0, c1, alpha, derphi0
        );
        assert!(new_params[0] < 3.0, "Step should move toward minimum");
    }
}