use cjc_repro::KahanAccumulatorF64;
use crate::accumulator::BinnedAccumulatorF64;
use crate::error::RuntimeError;
use crate::tensor::Tensor;
/// Mean squared error between `pred` and `target`.
///
/// Squared residuals are summed with [`KahanAccumulatorF64`] and the total is
/// divided by the number of elements.
///
/// # Errors
/// Returns `Err` when the slices differ in length or are empty.
pub fn mse_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
    let count = pred.len();
    if count != target.len() {
        return Err("mse_loss: arrays must have same length".into());
    }
    if count == 0 {
        return Err("mse_loss: empty data".into());
    }
    let mut sum_sq = KahanAccumulatorF64::new();
    for (&p, &t) in pred.iter().zip(target) {
        let residual = p - t;
        sum_sq.add(residual * residual);
    }
    Ok(sum_sq.finalize() / count as f64)
}
/// Element-averaged cross-entropy `mean(-target[i] * ln(pred[i] + 1e-12))`.
///
/// The `1e-12` offset avoids `ln(0)`. Predictions at or below `-1e-12` still
/// produce `NaN`. The sum is divided by the element count, not summed per sample.
///
/// # Errors
/// Returns `Err` when the slices differ in length or are empty.
pub fn cross_entropy_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
    let count = pred.len();
    if count != target.len() {
        return Err("cross_entropy_loss: arrays must have same length".into());
    }
    if count == 0 {
        return Err("cross_entropy_loss: empty data".into());
    }
    const EPS: f64 = 1e-12;
    let mut total = KahanAccumulatorF64::new();
    pred.iter()
        .zip(target)
        .for_each(|(&p, &t)| total.add(-t * (p + EPS).ln()));
    Ok(total.finalize() / count as f64)
}
/// Mean binary cross-entropy between probabilities `pred` and labels `target`.
///
/// Each prediction is clamped to `[1e-12, 1 - 1e-12]` with `max`/`min` before
/// taking logs. Because `f64::max` ignores a NaN operand, a NaN prediction is
/// treated as `1e-12` rather than propagating.
///
/// # Errors
/// Returns `Err` when the slices differ in length or are empty.
pub fn binary_cross_entropy(pred: &[f64], target: &[f64]) -> Result<f64, String> {
    let count = pred.len();
    if count != target.len() {
        return Err("binary_cross_entropy: arrays must have same length".into());
    }
    if count == 0 {
        return Err("binary_cross_entropy: empty data".into());
    }
    const EPS: f64 = 1e-12;
    let mut total = KahanAccumulatorF64::new();
    for (&raw, &label) in pred.iter().zip(target) {
        let prob = raw.max(EPS).min(1.0 - EPS);
        let log_likelihood = label * prob.ln() + (1.0 - label) * (1.0 - prob).ln();
        total.add(-log_likelihood);
    }
    Ok(total.finalize() / count as f64)
}
/// Mean Huber loss with threshold `delta`.
///
/// For `d = |pred - target|`, each term is `0.5 * d^2` when `d <= delta` and
/// `delta * (d - 0.5 * delta)` otherwise.
///
/// # Errors
/// Returns `Err` when the slices differ in length or are empty.
pub fn huber_loss(pred: &[f64], target: &[f64], delta: f64) -> Result<f64, String> {
    let count = pred.len();
    if count != target.len() {
        return Err("huber_loss: arrays must have same length".into());
    }
    if count == 0 {
        return Err("huber_loss: empty data".into());
    }
    let mut total = KahanAccumulatorF64::new();
    for (&p, &t) in pred.iter().zip(target) {
        let dist = (p - t).abs();
        let term = if dist <= delta {
            0.5 * dist * dist
        } else {
            delta * (dist - 0.5 * delta)
        };
        total.add(term);
    }
    Ok(total.finalize() / count as f64)
}
/// Mean hinge loss `mean(max(0, 1 - target[i] * pred[i]))`.
///
/// Targets are presumably in `{-1, +1}`. This is not checked.
///
/// # Errors
/// Returns `Err` when the slices differ in length or are empty.
pub fn hinge_loss(pred: &[f64], target: &[f64]) -> Result<f64, String> {
    let count = pred.len();
    if count != target.len() {
        return Err("hinge_loss: arrays must have same length".into());
    }
    if count == 0 {
        return Err("hinge_loss: empty data".into());
    }
    let mut total = KahanAccumulatorF64::new();
    for (&p, &t) in pred.iter().zip(target) {
        let margin = 1.0 - t * p;
        total.add(margin.max(0.0));
    }
    Ok(total.finalize() / count as f64)
}
use crate::idx::ParamIdx;
/// Hyper-parameters and per-parameter velocity buffer for momentum SGD.
pub struct SgdState {
    /// Step size applied to the velocity in `sgd_step`.
    pub lr: f64,
    /// Velocity decay factor. `0.0` makes `sgd_step` plain gradient descent.
    pub momentum: f64,
    /// One velocity entry per parameter, zero-initialised by `new`.
    pub velocity: Vec<f64>,
}

impl SgdState {
    /// Creates a state for `n_params` parameters with all velocities at zero.
    pub fn new(n_params: usize, lr: f64, momentum: f64) -> Self {
        let velocity = vec![0.0; n_params];
        SgdState {
            lr,
            momentum,
            velocity,
        }
    }

    /// Number of parameters tracked (the length of the velocity buffer).
    #[inline]
    pub fn n_params(&self) -> usize {
        self.velocity.len()
    }

    /// Velocity of parameter `p`. Panics if `p` is out of range.
    #[inline]
    pub fn velocity_at(&self, p: ParamIdx) -> f64 {
        let slot = p.index();
        self.velocity[slot]
    }

    /// Overwrites the velocity of parameter `p`. Panics if `p` is out of range.
    #[inline]
    pub fn set_velocity_at(&mut self, p: ParamIdx, value: f64) {
        let slot = p.index();
        self.velocity[slot] = value;
    }
}
/// One in-place momentum-SGD update: `v <- momentum * v + g; p <- p - lr * v`.
///
/// Iterates over every entry of `params`. It panics if `grads` or
/// `state.velocity` is shorter than `params`.
pub fn sgd_step(params: &mut [f64], grads: &[f64], state: &mut SgdState) {
    for (i, param) in params.iter_mut().enumerate() {
        let idx = ParamIdx::from_usize(i);
        let velocity = state.momentum * state.velocity_at(idx) + grads[i];
        state.set_velocity_at(idx, velocity);
        *param -= state.lr * velocity;
    }
}
/// Hyper-parameters and moment buffers for the Adam optimizer.
pub struct AdamState {
    /// Learning rate.
    pub lr: f64,
    /// Decay rate of the first-moment estimate (`0.9` from `new`).
    pub beta1: f64,
    /// Decay rate of the second-moment estimate (`0.999` from `new`).
    pub beta2: f64,
    /// Denominator stabiliser (`1e-8` from `new`).
    pub eps: f64,
    /// Number of `adam_step` calls so far (used for bias correction).
    pub t: u64,
    /// First-moment estimate per parameter.
    pub m: Vec<f64>,
    /// Second-moment estimate per parameter.
    pub v: Vec<f64>,
}

impl AdamState {
    /// Creates a state for `n_params` parameters with the default betas/eps and
    /// zeroed moments.
    pub fn new(n_params: usize, lr: f64) -> Self {
        let zeros = vec![0.0; n_params];
        AdamState {
            lr,
            beta1: 0.9,
            beta2: 0.999,
            eps: 1e-8,
            t: 0,
            m: zeros.clone(),
            v: zeros,
        }
    }

    /// Number of parameters tracked (the length of the first-moment buffer).
    #[inline]
    pub fn n_params(&self) -> usize {
        self.m.len()
    }

    /// First moment of parameter `p`. Panics if `p` is out of range.
    #[inline]
    pub fn m_at(&self, p: ParamIdx) -> f64 {
        let slot = p.index();
        self.m[slot]
    }

    /// Overwrites the first moment of parameter `p`.
    #[inline]
    pub fn set_m_at(&mut self, p: ParamIdx, value: f64) {
        let slot = p.index();
        self.m[slot] = value;
    }

    /// Second moment of parameter `p`. Panics if `p` is out of range.
    #[inline]
    pub fn v_at(&self, p: ParamIdx) -> f64 {
        let slot = p.index();
        self.v[slot]
    }

    /// Overwrites the second moment of parameter `p`.
    #[inline]
    pub fn set_v_at(&mut self, p: ParamIdx, value: f64) {
        let slot = p.index();
        self.v[slot] = value;
    }
}
/// One in-place Adam update with bias correction.
///
/// Increments `state.t`, then for each parameter:
/// `m <- b1*m + (1-b1)*g`, `v <- b2*v + (1-b2)*g^2`,
/// `p <- p - lr * (m / (1-b1^t)) / (sqrt(v / (1-b2^t)) + eps)`.
///
/// Panics if `grads`, `state.m` or `state.v` is shorter than `params`.
pub fn adam_step(params: &mut [f64], grads: &[f64], state: &mut AdamState) {
    state.t += 1;
    let t = state.t as f64;
    // The bias-correction denominators depend only on the step count. Compute
    // them once per step instead of twice per parameter; the values are
    // bit-identical to the per-iteration form.
    let bias1 = 1.0 - state.beta1.powf(t);
    let bias2 = 1.0 - state.beta2.powf(t);
    for (i, param) in params.iter_mut().enumerate() {
        let p = ParamIdx::from_usize(i);
        let g = grads[i];
        let new_m = state.beta1 * state.m_at(p) + (1.0 - state.beta1) * g;
        let new_v = state.beta2 * state.v_at(p) + (1.0 - state.beta2) * g * g;
        state.set_m_at(p, new_m);
        state.set_v_at(p, new_v);
        let m_hat = new_m / bias1;
        let v_hat = new_v / bias2;
        *param -= state.lr * m_hat / (v_hat.sqrt() + state.eps);
    }
}
/// Counts of a binary classifier's outcomes.
///
/// This is plain data (four counters), so it derives `Copy`, equality and a
/// zeroed `Default` in addition to `Debug`/`Clone`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct ConfusionMatrix {
    /// Predicted positive, actually positive.
    pub tp: usize,
    /// Predicted positive, actually negative.
    pub fp: usize,
    /// Predicted negative, actually negative.
    pub tn: usize,
    /// Predicted negative, actually positive (`fn` is a keyword, hence the name).
    pub fn_count: usize,
}
/// Tallies a [`ConfusionMatrix`] from paired boolean predictions and labels.
///
/// Only the first `min(predicted.len(), actual.len())` pairs are counted.
/// Extra entries in the longer slice are ignored silently.
pub fn confusion_matrix(predicted: &[bool], actual: &[bool]) -> ConfusionMatrix {
    let mut cm = ConfusionMatrix { tp: 0, fp: 0, tn: 0, fn_count: 0 };
    for (&pred, &truth) in predicted.iter().zip(actual) {
        let counter = match (pred, truth) {
            (true, true) => &mut cm.tp,
            (true, false) => &mut cm.fp,
            (false, true) => &mut cm.fn_count,
            (false, false) => &mut cm.tn,
        };
        *counter += 1;
    }
    cm
}
/// Precision `tp / (tp + fp)`, or `0.0` when nothing was predicted positive.
pub fn precision(cm: &ConfusionMatrix) -> f64 {
    match cm.tp + cm.fp {
        0 => 0.0,
        predicted_pos => cm.tp as f64 / predicted_pos as f64,
    }
}
/// Recall `tp / (tp + fn)`, or `0.0` when there are no actual positives.
pub fn recall(cm: &ConfusionMatrix) -> f64 {
    match cm.tp + cm.fn_count {
        0 => 0.0,
        actual_pos => cm.tp as f64 / actual_pos as f64,
    }
}
/// Harmonic mean of [`precision`] and [`recall`]; `0.0` when both are zero.
pub fn f1_score(cm: &ConfusionMatrix) -> f64 {
    let (prec, rec) = (precision(cm), recall(cm));
    let sum = prec + rec;
    if sum == 0.0 {
        return 0.0;
    }
    2.0 * prec * rec / sum
}
/// Fraction of correct predictions `(tp + tn) / total`; `0.0` for an empty matrix.
pub fn accuracy(cm: &ConfusionMatrix) -> f64 {
    let correct = cm.tp + cm.tn;
    match correct + cm.fp + cm.fn_count {
        0 => 0.0,
        total => correct as f64 / total as f64,
    }
}
/// Area under the ROC curve for binary `labels` ranked by `scores`
/// (a higher score means "more likely positive").
///
/// Samples are sorted by descending score and the ROC curve is integrated with
/// the trapezoidal rule. All samples sharing a score form a single threshold, so
/// ties produce one diagonal ROC segment and each tied positive/negative pair
/// contributes one half, as in the Mann–Whitney U statistic. The result
/// therefore does not depend on the input order of tied samples.
///
/// # Errors
/// Returns `Err` when the slices differ in length, are empty, or the labels are
/// all positive or all negative.
pub fn auc_roc(scores: &[f64], labels: &[bool]) -> Result<f64, String> {
    if scores.len() != labels.len() {
        return Err("auc_roc: scores and labels must have same length".into());
    }
    let n = scores.len();
    if n == 0 {
        return Err("auc_roc: empty data".into());
    }
    let pos_count = labels.iter().filter(|&&l| l).count();
    let neg_count = n - pos_count;
    if pos_count == 0 || neg_count == 0 {
        return Err("auc_roc: need both positive and negative labels".into());
    }

    // Descending by score. The index tie-break only makes the permutation deterministic.
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&a, &b| scores[b].total_cmp(&scores[a]).then(a.cmp(&b)));

    let (pos, neg) = (pos_count as f64, neg_count as f64);
    let (mut tp, mut fp) = (0usize, 0usize);
    let (mut prev_tpr, mut prev_fpr) = (0.0, 0.0);
    let mut auc = 0.0;
    let mut i = 0;
    while i < n {
        // Consume every sample tied with this score before emitting a ROC point.
        // The first sample is always consumed, so a NaN score (never `==` to
        // itself) still makes progress and forms its own group. Using `==` also
        // merges 0.0 and -0.0, which sort adjacently under `total_cmp`.
        let threshold = scores[order[i]];
        loop {
            if labels[order[i]] {
                tp += 1;
            } else {
                fp += 1;
            }
            i += 1;
            if i == n || scores[order[i]] != threshold {
                break;
            }
        }
        let tpr = tp as f64 / pos;
        let fpr = fp as f64 / neg;
        auc += (fpr - prev_fpr) * (tpr + prev_tpr) / 2.0;
        prev_tpr = tpr;
        prev_fpr = fpr;
    }
    Ok(auc)
}
/// Deterministic `k`-fold cross-validation split of `0..n`.
///
/// The indices are shuffled with a Fisher–Yates pass driven by
/// `cjc_repro::Rng::seeded(seed)` and cut into `k` contiguous test folds of
/// `n / k` indices. The last fold also absorbs the `n % k` remainder. Each entry is
/// `(train, test)`, where `train` is every index not in that fold.
///
/// When `k > n`, all folds but the last have empty test sets. `k == 0` returns an
/// empty vector; previously it panicked with a division by zero.
pub fn kfold_indices(n: usize, k: usize, seed: u64) -> Vec<(Vec<usize>, Vec<usize>)> {
    if k == 0 {
        return Vec::new();
    }
    let mut rng = cjc_repro::Rng::seeded(seed);
    let mut indices: Vec<usize> = (0..n).collect();
    for i in (1..n).rev() {
        let j = (rng.next_u64() as usize) % (i + 1);
        indices.swap(i, j);
    }
    let fold_size = n / k;
    (0..k)
        .map(|fold| {
            let start = fold * fold_size;
            let end = if fold == k - 1 { n } else { start + fold_size };
            let test = indices[start..end].to_vec();
            let train = indices[..start]
                .iter()
                .chain(&indices[end..])
                .copied()
                .collect();
            (train, test)
        })
        .collect()
}
/// Deterministic shuffled split of `0..n` into `(train, test)` index sets.
///
/// The indices are Fisher–Yates shuffled with `cjc_repro::Rng::seeded(seed)`. The
/// first `round(n * test_fraction)` shuffled indices become the test set and the
/// rest become the training set.
///
/// The test size is clamped to `n`, so a fraction above `1.0` gives an all-test
/// split instead of panicking on an out-of-range slice. Negative or NaN
/// fractions give an empty test set, because Rust's float-to-int `as` cast
/// saturates.
pub fn train_test_split(n: usize, test_fraction: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
    let mut rng = cjc_repro::Rng::seeded(seed);
    let mut indices: Vec<usize> = (0..n).collect();
    for i in (1..n).rev() {
        let j = (rng.next_u64() as usize) % (i + 1);
        indices.swap(i, j);
    }
    let test_size = (((n as f64) * test_fraction).round() as usize).min(n);
    let test = indices[..test_size].to_vec();
    let train = indices[test_size..].to_vec();
    (train, test)
}
/// Non-parametric bootstrap of a statistic over `data`.
///
/// `stat_fn` selects the statistic (`0` = mean, `1` = median; see
/// `compute_stat`). The function draws `n_resamples` with-replacement
/// resamples of size `data.len()` using `cjc_repro::Rng::seeded(seed)`.
///
/// Returns `(point, ci_lower, ci_upper, se)`:
/// - `point` is the statistic on `data` itself.
/// - `ci_lower` and `ci_upper` are the sorted resample statistics at positions
///   `floor(0.025 * n_resamples)` and `floor(0.975 * n_resamples)` (capped to
///   the last index).
/// - `se` is the sample standard deviation (divisor `n_resamples - 1`) of the
///   resample statistics. It is NaN when `n_resamples == 1`.
///
/// # Errors
/// Returns `Err` for empty `data`, an unknown `stat_fn`, or `n_resamples == 0`.
/// Zero resamples previously panicked while indexing into the empty statistics
/// vector.
pub fn bootstrap(data: &[f64], n_resamples: usize, stat_fn: usize, seed: u64) -> Result<(f64, f64, f64, f64), String> {
    if data.is_empty() {
        return Err("bootstrap: empty data".into());
    }
    let n = data.len();
    let point = compute_stat(data, stat_fn)?;
    if n_resamples == 0 {
        return Err("bootstrap: n_resamples must be at least 1".into());
    }
    let mut rng = cjc_repro::Rng::seeded(seed);
    let mut stats = Vec::with_capacity(n_resamples);
    // One buffer reused for every resample.
    let mut resample = Vec::with_capacity(n);
    for _ in 0..n_resamples {
        resample.clear();
        for _ in 0..n {
            let idx = (rng.next_u64() as usize) % n;
            resample.push(data[idx]);
        }
        stats.push(compute_stat(&resample, stat_fn)?);
    }
    stats.sort_by(|a, b| a.total_cmp(b));
    let ci_lower = stats[(n_resamples as f64 * 0.025) as usize];
    let ci_upper = stats[(n_resamples as f64 * 0.975).min((n_resamples - 1) as f64) as usize];
    let mean_stats: f64 = {
        let mut acc = cjc_repro::KahanAccumulatorF64::new();
        for &s in &stats {
            acc.add(s);
        }
        acc.finalize() / n_resamples as f64
    };
    let se = {
        let mut acc = cjc_repro::KahanAccumulatorF64::new();
        for &s in &stats {
            let d = s - mean_stats;
            acc.add(d * d);
        }
        (acc.finalize() / (n_resamples as f64 - 1.0)).sqrt()
    };
    Ok((point, ci_lower, ci_upper, se))
}
/// Evaluates the statistic chosen by `stat_fn` on `data`.
///
/// - `0`: arithmetic mean, summed with `KahanAccumulatorF64`.
/// - `1`: median after a `total_cmp` sort; for even lengths it is the average of
///   the two middle values.
///
/// Callers in this module always pass a non-empty slice. An empty slice would
/// give NaN for the mean and panic for the median.
///
/// # Errors
/// Any other `stat_fn` yields `Err("bootstrap: unknown stat_fn …")`.
fn compute_stat(data: &[f64], stat_fn: usize) -> Result<f64, String> {
    if stat_fn == 0 {
        let mut total = cjc_repro::KahanAccumulatorF64::new();
        data.iter().for_each(|&v| total.add(v));
        return Ok(total.finalize() / data.len() as f64);
    }
    if stat_fn == 1 {
        let mut ordered = data.to_vec();
        ordered.sort_by(f64::total_cmp);
        let len = ordered.len();
        let mid = len / 2;
        let median = if len % 2 == 1 {
            ordered[mid]
        } else {
            (ordered[mid - 1] + ordered[mid]) / 2.0
        };
        return Ok(median);
    }
    Err(format!("bootstrap: unknown stat_fn {}", stat_fn))
}
/// Two-sided permutation test for a difference in means between `x` and `y`.
///
/// The observed statistic is `|mean(x) - mean(y)|`. In each of `n_perms` rounds
/// the pooled sample is Fisher–Yates shuffled in place with
/// `cjc_repro::Rng::seeded(seed)`. The first `x.len()` values are treated as the
/// permuted `x` group.
///
/// Returns `(observed, p_value)`, where `p_value` is the fraction of rounds whose
/// statistic is `>= observed`. No `+1` correction is applied.
///
/// # Errors
/// Returns `Err` if either group is empty, or if `n_perms == 0` (which
/// previously produced a NaN p-value).
pub fn permutation_test(x: &[f64], y: &[f64], n_perms: usize, seed: u64) -> Result<(f64, f64), String> {
    if x.is_empty() || y.is_empty() {
        return Err("permutation_test: empty group".into());
    }
    if n_perms == 0 {
        return Err("permutation_test: n_perms must be at least 1".into());
    }
    let nx = x.len();
    let mean_x = compute_stat(x, 0)?;
    let mean_y = compute_stat(y, 0)?;
    let observed = (mean_x - mean_y).abs();
    // Shuffling an already-shuffled buffer is still a uniform permutation, so
    // one pooled buffer is reused and no pristine copy needs to be kept.
    let mut perm: Vec<f64> = x.iter().chain(y).copied().collect();
    let n = perm.len();
    let mut rng = cjc_repro::Rng::seeded(seed);
    let mut count_extreme = 0usize;
    for _ in 0..n_perms {
        for i in (1..n).rev() {
            let j = (rng.next_u64() as usize) % (i + 1);
            perm.swap(i, j);
        }
        let perm_mean_x = compute_stat(&perm[..nx], 0)?;
        let perm_mean_y = compute_stat(&perm[nx..], 0)?;
        if (perm_mean_x - perm_mean_y).abs() >= observed {
            count_extreme += 1;
        }
    }
    let p_value = count_extreme as f64 / n_perms as f64;
    Ok((observed, p_value))
}
/// Per-label shuffled train/test split of the positions `0..labels.len()`.
///
/// Positions are grouped by label and the groups are visited in ascending label
/// order (`BTreeMap`). Each group is Fisher–Yates shuffled with one shared
/// `cjc_repro::Rng::seeded(seed)`. The first `round(len * test_frac)` members go
/// to the test set, but at least one member does whenever a group has more than
/// one member, even with `test_frac == 0`. The count is capped at the group size.
///
/// Both returned vectors are sorted ascending.
pub fn stratified_split(labels: &[i64], test_frac: f64, seed: u64) -> (Vec<usize>, Vec<usize>) {
    use std::collections::BTreeMap;
    let total = labels.len();
    let mut by_label: BTreeMap<i64, Vec<usize>> = BTreeMap::new();
    labels
        .iter()
        .enumerate()
        .for_each(|(pos, &lab)| by_label.entry(lab).or_default().push(pos));

    let mut train = Vec::with_capacity(total);
    let mut test = Vec::with_capacity(total);
    let mut rng = cjc_repro::Rng::seeded(seed);
    for (_, mut members) in by_label {
        let size = members.len();
        for hi in (1..size).rev() {
            let pick = (rng.next_u64() as usize) % (hi + 1);
            members.swap(hi, pick);
        }
        let floor = if size > 1 { 1 } else { 0 };
        let n_test = ((size as f64 * test_frac).round() as usize)
            .max(floor)
            .min(size);
        let (test_part, train_part) = members.split_at(n_test);
        test.extend_from_slice(test_part);
        train.extend_from_slice(train_part);
    }
    // Plain usize values: an unstable sort gives the same result as a stable one.
    train.sort_unstable();
    test.sort_unstable();
    (train, test)
}
/// Inference-mode batch normalisation applied element-wise:
/// `gamma[i] * (x[i] - running_mean[i]) / sqrt(running_var[i] + eps) + beta[i]`.
///
/// # Errors
/// Returns `Err` unless all five slices have the same length.
pub fn batch_norm(
    x: &[f64],
    running_mean: &[f64],
    running_var: &[f64],
    gamma: &[f64],
    beta: &[f64],
    eps: f64,
) -> Result<Vec<f64>, String> {
    let len = x.len();
    let all_match = [running_mean.len(), running_var.len(), gamma.len(), beta.len()]
        .iter()
        .all(|&other| other == len);
    if !all_match {
        return Err("batch_norm: all arrays must have same length".into());
    }
    let normalized = x
        .iter()
        .zip(running_mean)
        .zip(running_var)
        .zip(gamma.iter().zip(beta))
        .map(|(((&xi, &mu), &var), (&g, &b))| {
            let standardized = (xi - mu) / (var + eps).sqrt();
            g * standardized + b
        })
        .collect();
    Ok(normalized)
}
/// Deterministic inverted-dropout mask of length `n`.
///
/// Each entry draws `u = next_u64() / u64::MAX` from `cjc_repro::Rng::seeded(seed)`.
/// The entry is `0.0` when `u < drop_prob` and `1 / (1 - drop_prob)` otherwise.
/// The keep scale is `0.0` when `drop_prob >= 1.0`.
pub fn dropout_mask(n: usize, drop_prob: f64, seed: u64) -> Vec<f64> {
    let mut rng = cjc_repro::Rng::seeded(seed);
    let keep_scale = if drop_prob < 1.0 {
        1.0 / (1.0 - drop_prob)
    } else {
        0.0
    };
    let denom = u64::MAX as f64;
    (0..n)
        .map(|_| {
            let u = (rng.next_u64() as f64) / denom;
            if u < drop_prob {
                0.0
            } else {
                keep_scale
            }
        })
        .collect()
}
/// Element-wise product of `data` with a dropout `mask`.
///
/// # Errors
/// Returns `Err` when the two slices differ in length.
pub fn apply_dropout(data: &[f64], mask: &[f64]) -> Result<Vec<f64>, String> {
    if data.len() == mask.len() {
        let mut out = Vec::with_capacity(data.len());
        for (value, factor) in data.iter().zip(mask) {
            out.push(value * factor);
        }
        Ok(out)
    } else {
        Err("apply_dropout: data and mask must have same length".into())
    }
}
/// Step-decay schedule: `initial_lr * decay_rate^(epoch / step_size)`
/// (integer division).
///
/// `step_size == 0` has no meaningful period, so it returns `initial_lr`
/// unchanged, like `lr_linear_warmup` does for zero warm-up epochs. It
/// previously panicked with a division by zero. The exponent is saturated at
/// `i32::MAX`, so very large step counts no longer wrap to a negative power,
/// which made the learning rate explode.
pub fn lr_step_decay(initial_lr: f64, decay_rate: f64, epoch: usize, step_size: usize) -> f64 {
    if step_size == 0 {
        return initial_lr;
    }
    let steps = (epoch / step_size).min(i32::MAX as usize) as i32;
    initial_lr * decay_rate.powi(steps)
}
/// Cosine-annealing schedule from `max_lr` (epoch 0) to `min_lr`
/// (epoch `total_epochs`):
/// `min_lr + 0.5 * (max_lr - min_lr) * (1 + cos(pi * epoch / total_epochs))`.
///
/// Beyond `total_epochs` the curve keeps following the cosine, so the rate rises
/// again; it is not clamped. With `total_epochs == 0` the schedule is already
/// complete and `min_lr` is returned. Previously the `0 / 0` ratio produced NaN.
pub fn lr_cosine(max_lr: f64, min_lr: f64, epoch: usize, total_epochs: usize) -> f64 {
    if total_epochs == 0 {
        return min_lr;
    }
    let ratio = epoch as f64 / total_epochs as f64;
    min_lr + 0.5 * (max_lr - min_lr) * (1.0 + (std::f64::consts::PI * ratio).cos())
}
/// Linear warm-up: scales `initial_lr` by `epoch / warmup_epochs`, capped at 1.
///
/// A zero-length warm-up returns `initial_lr` directly.
pub fn lr_linear_warmup(initial_lr: f64, epoch: usize, warmup_epochs: usize) -> f64 {
    match warmup_epochs {
        0 => initial_lr,
        span => {
            let progress = epoch as f64 / span as f64;
            initial_lr * progress.min(1.0)
        }
    }
}
/// L1 regularisation term `lambda * sum(|p|)`, summed with [`KahanAccumulatorF64`].
pub fn l1_penalty(params: &[f64], lambda: f64) -> f64 {
    let abs_sum = params
        .iter()
        .fold(KahanAccumulatorF64::new(), |mut acc, &w| {
            acc.add(w.abs());
            acc
        })
        .finalize();
    lambda * abs_sum
}
/// L2 regularisation term `0.5 * lambda * sum(p^2)`, summed with
/// [`KahanAccumulatorF64`]. Its gradient is `l2_grad`.
pub fn l2_penalty(params: &[f64], lambda: f64) -> f64 {
    let sq_sum = params
        .iter()
        .fold(KahanAccumulatorF64::new(), |mut acc, &w| {
            acc.add(w * w);
            acc
        })
        .finalize();
    0.5 * lambda * sq_sum
}
/// Subgradient of the L1 penalty: `lambda * sign(p)`.
///
/// It uses `0` at `p == 0` (including `-0.0`) and for NaN parameters.
pub fn l1_grad(params: &[f64], lambda: f64) -> Vec<f64> {
    use std::cmp::Ordering;
    params
        .iter()
        .map(|w| match w.partial_cmp(&0.0) {
            Some(Ordering::Greater) => lambda,
            Some(Ordering::Less) => -lambda,
            Some(Ordering::Equal) | None => 0.0,
        })
        .collect()
}
/// Gradient of `l2_penalty`: `lambda * p` for each parameter.
pub fn l2_grad(params: &[f64], lambda: f64) -> Vec<f64> {
    let mut grad = Vec::with_capacity(params.len());
    grad.extend(params.iter().map(|&w| lambda * w));
    grad
}
/// Patience-based early-stopping tracker.
#[derive(Debug, Clone)]
pub struct EarlyStoppingState {
    /// Number of consecutive non-improving checks tolerated before stopping.
    pub patience: usize,
    /// Amount by which a loss must undercut `best_loss` to count as an improvement.
    pub min_delta: f64,
    /// Lowest loss seen so far (starts at `+inf`).
    pub best_loss: f64,
    /// Consecutive non-improving checks since the last improvement.
    pub wait: usize,
    /// Set once patience is exhausted; `check` never clears it.
    pub stopped: bool,
}

impl EarlyStoppingState {
    /// Creates a tracker with no history (`best_loss = +inf`).
    pub fn new(patience: usize, min_delta: f64) -> Self {
        Self {
            patience,
            min_delta,
            best_loss: f64::INFINITY,
            wait: 0,
            stopped: false,
        }
    }

    /// Records `current_loss` and returns whether training should stop.
    ///
    /// A loss below `best_loss - min_delta` is an improvement and resets `wait`.
    /// Any other loss, including NaN, increments `wait`. Once `wait` reaches
    /// `patience`, the tracker latches into the stopped state.
    pub fn check(&mut self, current_loss: f64) -> bool {
        if current_loss < self.best_loss - self.min_delta {
            self.best_loss = current_loss;
            self.wait = 0;
        } else {
            self.wait = self.wait.saturating_add(1);
            // Only a non-improving check may exhaust patience. Testing this
            // after an improvement made `patience == 0` stop on the very first,
            // improving, call.
            if self.wait >= self.patience {
                self.stopped = true;
            }
        }
        self.stopped
    }
}
/// Principal component analysis of a 2-D `[n_samples, n_features]` tensor via SVD.
///
/// The columns are mean-centred, with means accumulated by
/// [`BinnedAccumulatorF64`]. The centred matrix is decomposed with `Tensor::svd`,
/// and the leading `k = min(n_components, s.len())` singular triplets are kept.
///
/// Returns `(transformed, components, explained_variance_ratio)`:
/// - `transformed` is `[n_samples, k]`, with entry `(i, j)` equal to `U[i, j] * s[j]`.
/// - `components` is `[k, n_features]`, the first `k` rows of `Vt`.
/// - `explained_variance_ratio[j]` is `s[j]^2 / d` divided by the sum of
///   `s^2 / d` over *all* singular values, where `d = max(n_samples - 1, 1)`.
///   It is all zeros when that sum is `<= 1e-15`.
///
/// No sign normalisation is applied, so component signs are whatever the SVD
/// routine produces.
///
/// # Errors
/// Returns `RuntimeError::InvalidOperation` if `data` is not 2-D, is empty, or
/// `n_components` is outside `[1, min(n_samples, n_features)]`. Tensor
/// construction and SVD errors are propagated.
pub fn pca(
    data: &Tensor,
    n_components: usize,
) -> Result<(Tensor, Tensor, Vec<f64>), RuntimeError> {
    if data.ndim() != 2 {
        return Err(RuntimeError::InvalidOperation(
            "PCA requires a 2D data matrix".to_string(),
        ));
    }
    let rows = data.shape()[0];
    let cols = data.shape()[1];
    if rows == 0 || cols == 0 {
        return Err(RuntimeError::InvalidOperation(
            "PCA: empty data matrix".to_string(),
        ));
    }
    let rank_limit = cols.min(rows);
    if n_components == 0 || n_components > rank_limit {
        return Err(RuntimeError::InvalidOperation(format!(
            "PCA: n_components ({}) must be in [1, min(n_samples, n_features) = {}]",
            n_components, rank_limit
        )));
    }

    // Column means, accumulated row by row for each feature.
    let raw = data.to_vec();
    let col_means: Vec<f64> = (0..cols)
        .map(|col| {
            let mut acc = BinnedAccumulatorF64::new();
            for row in 0..rows {
                acc.add(raw[row * cols + col]);
            }
            acc.finalize() / rows as f64
        })
        .collect();

    // Row-major centred copy: flat position `idx` belongs to column `idx % cols`.
    let centered: Vec<f64> = raw[..rows * cols]
        .iter()
        .enumerate()
        .map(|(idx, &value)| value - col_means[idx % cols])
        .collect();
    let centered_tensor = Tensor::from_vec(centered, &[rows, cols])?;
    let (u, s, vt) = centered_tensor.svd()?;
    let k = n_components.min(s.len());

    // The leading `k` rows of Vt, each truncated to `cols` entries.
    let vt_data = vt.to_vec();
    let vt_stride = vt.shape()[1];
    let mut components = Vec::with_capacity(k * cols);
    for row in 0..k {
        let offset = row * vt_stride;
        components.extend_from_slice(&vt_data[offset..offset + cols]);
    }

    // Sample-variance divisor; a single sample divides by 1.
    let denom = rows.saturating_sub(1).max(1) as f64;
    let total_var = {
        let mut acc = BinnedAccumulatorF64::new();
        for &sv in s.iter() {
            acc.add(sv * sv / denom);
        }
        acc.finalize()
    };
    let explained_variance_ratio: Vec<f64> = if total_var > 1e-15 {
        s[..k].iter().map(|&sv| (sv * sv / denom) / total_var).collect()
    } else {
        vec![0.0; k]
    };

    // Scores: the first k columns of U, scaled column-wise by the singular values.
    let u_data = u.to_vec();
    let u_stride = u.shape()[1];
    let mut transformed = Vec::with_capacity(rows * k);
    for row in 0..rows {
        for col in 0..k {
            transformed.push(u_data[row * u_stride + col] * s[col]);
        }
    }

    Ok((
        Tensor::from_vec(transformed, &[rows, k])?,
        Tensor::from_vec(components, &[k, cols])?,
        explained_variance_ratio,
    ))
}
/// Persistent state for `lbfgs_step`.
pub struct LbfgsState {
    /// Initial step length handed to the line search.
    pub lr: f64,
    /// Maximum number of `(s, y)` curvature pairs retained.
    pub m: usize,
    /// Parameter differences `s_k = x_{k+1} - x_k`, oldest first.
    pub s_history: Vec<Vec<f64>>,
    /// Gradient differences `y_k = g_{k+1} - g_k`, aligned with `s_history`.
    pub y_history: Vec<Vec<f64>>,
    /// Parameters produced by the most recent step, if any.
    pub prev_params: Option<Vec<f64>>,
    /// Gradient at `prev_params`, if any.
    pub prev_grad: Option<Vec<f64>>,
}

impl LbfgsState {
    /// Creates an empty state with step length `lr` and history capacity `m`.
    pub fn new(lr: f64, m: usize) -> Self {
        LbfgsState {
            lr,
            m,
            s_history: vec![],
            y_history: vec![],
            prev_params: None,
            prev_grad: None,
        }
    }
}
/// Dot product of `a` and `b`, summed with [`KahanAccumulatorF64`].
///
/// Debug builds assert equal lengths. In release builds, pairs beyond the
/// shorter slice are ignored.
fn kahan_dot(a: &[f64], b: &[f64]) -> f64 {
    debug_assert_eq!(a.len(), b.len());
    a.iter()
        .zip(b)
        .fold(KahanAccumulatorF64::new(), |mut acc, (&lhs, &rhs)| {
            acc.add(lhs * rhs);
            acc
        })
        .finalize()
}
/// Line search along `direction` from `params` for a step length satisfying the
/// strong Wolfe conditions.
///
/// Let `phi(a) = f(params + a * direction)` and `phi'(0) = g0 · direction`. A
/// step `a` is accepted when both conditions hold:
/// - sufficient decrease, `phi(a) <= f0 + c1 * a * phi'(0)` with `c1 = 1e-4`;
/// - curvature, `|phi'(a)| <= c2 * |phi'(0)|` with `c2 = 0.9`.
///
/// When a bracket around an acceptable step is detected, the search hands off to
/// `wolfe_zoom`.
///
/// - `f`: objective returning `(value, gradient)`.
/// - `f0`, `g0`: the value and gradient at `params`.
/// - `alpha_init`: first trial step.
///
/// Returns `(alpha, x, f(x), grad f(x))` for the accepted step. If no step is
/// accepted within 30 iterations, it returns the lowest-valued point evaluated
/// here.
///
/// NOTE(review): the code does not check that `direction` is a descent direction
/// (`phi'(0) < 0`). `lbfgs_step` falls back to steepest descent before calling,
/// so this only matters in degenerate cases such as a zero gradient.
pub fn wolfe_line_search<F>(
params: &[f64],
direction: &[f64],
f: &mut F,
f0: f64,
g0: &[f64],
alpha_init: f64,
) -> (f64, Vec<f64>, f64, Vec<f64>)
where
F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
let c1 = 1e-4_f64;
let c2 = 0.9_f64;
let derphi0 = kahan_dot(g0, direction);
// Trial point params + alpha * direction.
let step = |alpha: f64| -> Vec<f64> {
params.iter().zip(direction.iter()).map(|(&p, &d)| p + alpha * d).collect()
};
let max_iter = 30;
let mut alpha_lo = 0.0_f64;
let mut alpha_hi = f64::INFINITY;
let mut phi_lo = f0;
let mut dphi_lo = derphi0;
let mut alpha = alpha_init;
let mut best_alpha = alpha_init;
// Fallback result if no step is accepted.
// NOTE(review): this evaluates f at alpha_init, and the first loop iteration
// evaluates the same point again (alpha starts at alpha_init). That costs one
// redundant objective call per search.
let mut best_params = step(alpha_init);
let (mut best_val, mut best_grad) = f(&best_params);
for _iter in 0..max_iter {
let x_alpha = step(alpha);
let (phi_alpha, grad_alpha) = f(&x_alpha);
let dphi_alpha = kahan_dot(&grad_alpha, direction);
if phi_alpha < best_val {
best_alpha = alpha;
best_params = x_alpha.clone();
best_val = phi_alpha;
best_grad = grad_alpha.clone();
}
// Sufficient decrease fails, or the value stopped decreasing relative to the
// previous trial: [alpha_lo, alpha] brackets an acceptable step.
if phi_alpha > f0 + c1 * alpha * derphi0 || (phi_alpha >= phi_lo && alpha_lo > 0.0) {
alpha_hi = alpha;
let (za, zp, zv, zg) = wolfe_zoom(
params, direction, f, f0, derphi0, c1, c2,
alpha_lo, alpha_hi, phi_lo, dphi_lo,
);
return (za, zp, zv, zg);
}
// Strong Wolfe curvature condition holds: accept this step.
if dphi_alpha.abs() <= c2 * derphi0.abs() {
return (alpha, x_alpha, phi_alpha, grad_alpha);
}
// The slope became non-negative, so a minimizer lies between the previous
// trial and this one. The zoom receives (alpha, alpha_lo) in that order.
if dphi_alpha >= 0.0 {
let (za, zp, zv, zg) = wolfe_zoom(
params, direction, f, f0, derphi0, c1, c2,
alpha, alpha_lo, phi_alpha, dphi_alpha,
);
return (za, zp, zv, zg);
}
alpha_lo = alpha;
phi_lo = phi_alpha;
dphi_lo = dphi_alpha;
// NOTE(review): alpha_hi is only assigned immediately before a return, so it
// is always infinite here. The step therefore always doubles (capped at 1e8),
// and the bisection arm is unreachable.
alpha = if alpha_hi.is_finite() {
(alpha_lo + alpha_hi) * 0.5
} else {
(alpha * 2.0).min(1e8)
};
}
(best_alpha, best_params, best_val, best_grad)
}
/// Bisection "zoom" phase of the strong-Wolfe line search in `wolfe_line_search`.
///
/// The function repeatedly evaluates the midpoint of the interval between
/// `alpha_lo` and `alpha_hi`; the endpoints may arrive in either order. It
/// shrinks the interval until one of these happens:
/// - a midpoint satisfies both Wolfe conditions (`c1`, `c2`, relative to `f0`
///   and `derphi0`);
/// - 20 iterations pass;
/// - the interval width drops below `1e-14`.
///
/// It returns the accepted point, or otherwise the lowest-valued point evaluated
/// here. `_dphi_lo` is not used by this implementation.
// The search state (bracket, f0, slope, Wolfe constants) is passed as separate
// arguments, which trips the clippy lint.
#[allow(clippy::too_many_arguments)]
fn wolfe_zoom<F>(
params: &[f64],
direction: &[f64],
f: &mut F,
f0: f64,
derphi0: f64,
c1: f64,
c2: f64,
mut alpha_lo: f64,
mut alpha_hi: f64,
mut phi_lo: f64,
_dphi_lo: f64,
) -> (f64, Vec<f64>, f64, Vec<f64>)
where
F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
let step = |alpha: f64| -> Vec<f64> {
params.iter().zip(direction.iter()).map(|(&p, &d)| p + alpha * d).collect()
};
let max_zoom = 20;
let mut best_alpha = alpha_lo;
// Seed the fallback with f at alpha_lo.
// NOTE(review): phi_lo is already known, so this objective call is made only to
// obtain the point and gradient at alpha_lo.
let mut best_x = step(alpha_lo);
let (mut best_val, mut best_grad) = f(&best_x);
for _i in 0..max_zoom {
let alpha_j = (alpha_lo + alpha_hi) * 0.5;
let x_j = step(alpha_j);
let (phi_j, grad_j) = f(&x_j);
let dphi_j = kahan_dot(&grad_j, direction);
if phi_j < best_val {
best_alpha = alpha_j;
best_x = x_j.clone();
best_val = phi_j;
best_grad = grad_j.clone();
}
// The midpoint fails sufficient decrease or is no better than alpha_lo, so it
// becomes the new hi end.
if phi_j > f0 + c1 * alpha_j * derphi0 || phi_j >= phi_lo {
alpha_hi = alpha_j;
} else {
if dphi_j.abs() <= c2 * derphi0.abs() {
return (alpha_j, x_j, phi_j, grad_j);
}
// The slope at the midpoint points toward alpha_hi's side. Keep the half
// that still brackets a minimizer by moving hi to the old lo.
if dphi_j * (alpha_hi - alpha_lo) >= 0.0 {
alpha_hi = alpha_lo;
}
alpha_lo = alpha_j;
phi_lo = phi_j;
}
// abs() because the endpoints are not guaranteed to be ordered.
if (alpha_hi - alpha_lo).abs() < 1e-14 {
break;
}
}
(best_alpha, best_x, best_val, best_grad)
}
/// One L-BFGS iteration. It builds a search direction from the stored curvature
/// pairs, runs a strong-Wolfe line search along it, and updates the history.
///
/// - `params`, `grads`: the current point and its gradient. Their lengths must
///   match; this is checked only in debug builds.
/// - `state`: supplies the line search's initial step (`state.lr`), the history
///   capacity (`state.m`) and the `(s, y)` history.
/// - `f`: objective returning `(value, gradient)` at a point.
///
/// Returns `(new_params, new_grads, is_descent)`. `is_descent` is `false` when the
/// quasi-Newton direction was not a finite descent direction, in which case the
/// normalized steepest-descent direction `-g / ||g||` was used instead.
///
/// A new pair `(s, y)` is stored only if `s · y > 1e-300`. The oldest pair is
/// dropped once more than `m` are held. `prev_params` and `prev_grad` are
/// recorded but not read by this function.
pub fn lbfgs_step<F>(
params: &[f64],
grads: &[f64],
state: &mut LbfgsState,
mut f: F,
) -> (Vec<f64>, Vec<f64>, bool)
where
F: FnMut(&[f64]) -> (f64, Vec<f64>),
{
let n = params.len();
debug_assert_eq!(grads.len(), n);
let hist_len = state.s_history.len();
// Two-loop recursion, first pass (newest to oldest): q <- q - alpha_i * y_i.
let mut q: Vec<f64> = grads.to_vec();
let mut alphas = vec![0.0_f64; hist_len];
let mut rhos = vec![0.0_f64; hist_len];
for i in (0..hist_len).rev() {
let sy = kahan_dot(&state.s_history[i], &state.y_history[i]);
rhos[i] = if sy.abs() < 1e-300 { 0.0 } else { 1.0 / sy };
let sq = kahan_dot(&state.s_history[i], &q);
alphas[i] = rhos[i] * sq;
for j in 0..n {
q[j] -= alphas[i] * state.y_history[i][j];
}
}
// Initial inverse-Hessian scaling gamma = (s·y)/(y·y) from the newest pair,
// or 1.0 when there is no history.
let scale = if hist_len > 0 {
let last = hist_len - 1;
let sy = kahan_dot(&state.s_history[last], &state.y_history[last]);
let yy = kahan_dot(&state.y_history[last], &state.y_history[last]);
if yy.abs() < 1e-300 { 1.0 } else { sy / yy }
} else {
1.0
};
// Second pass (oldest to newest): r <- r + (alpha_i - beta_i) * s_i.
// The result r approximates H · grads.
let mut r: Vec<f64> = q.iter().map(|&qi| scale * qi).collect();
for i in 0..hist_len {
let yr = kahan_dot(&state.y_history[i], &r);
let beta = rhos[i] * yr;
let diff = alphas[i] - beta;
for j in 0..n {
r[j] += diff * state.s_history[i][j];
}
}
let direction: Vec<f64> = r.iter().map(|&ri| -ri).collect();
// Guard: if -r is not a finite descent direction, fall back to normalized
// steepest descent.
let descent_check = kahan_dot(&direction, grads);
let (direction, is_descent) = if descent_check >= 0.0 || !descent_check.is_finite() {
let norm_g = kahan_dot(grads, grads).sqrt().max(1e-300);
(grads.iter().map(|&g| -g / norm_g).collect::<Vec<f64>>(), false)
} else {
(direction, true)
};
// One objective call to obtain f(params). Its gradient is discarded in favour
// of `grads`.
let (f0, _) = f(params);
let (_, new_params, _, new_grads) = wolfe_line_search(
params,
&direction,
&mut f,
f0,
grads,
state.lr,
);
let s_k: Vec<f64> = new_params.iter().zip(params.iter()).map(|(&np, &p)| np - p).collect();
let y_k: Vec<f64> = new_grads.iter().zip(grads.iter()).map(|(&ng, &g)| ng - g).collect();
// Store the pair only when the curvature condition s·y > 0 holds. This is the
// standard L-BFGS requirement for keeping the implicit inverse Hessian
// positive definite.
let sy = kahan_dot(&s_k, &y_k);
if sy > 1e-300 {
state.s_history.push(s_k);
state.y_history.push(y_k);
// Drop the oldest pair (Vec::remove(0) is O(m)).
if state.s_history.len() > state.m {
state.s_history.remove(0);
state.y_history.remove(0);
}
}
state.prev_params = Some(new_params.clone());
state.prev_grad = Some(new_grads.clone());
(new_params, new_grads, is_descent)
}
pub fn lstm_cell(
x: &Tensor,
h_prev: &Tensor,
c_prev: &Tensor,
w_ih: &Tensor,
w_hh: &Tensor,
b_ih: &Tensor,
b_hh: &Tensor,
) -> Result<(Tensor, Tensor), String> {
let map_err = |e: crate::error::RuntimeError| format!("{e}");
if x.ndim() != 2 {
return Err("lstm_cell: x must be 2-D [batch, input_size]".into());
}
if h_prev.ndim() != 2 {
return Err("lstm_cell: h_prev must be 2-D [batch, hidden_size]".into());
}
if c_prev.ndim() != 2 {
return Err("lstm_cell: c_prev must be 2-D [batch, hidden_size]".into());
}
let hidden_size = h_prev.shape()[1];
if w_ih.ndim() != 2 || w_ih.shape()[0] != 4 * hidden_size {
return Err(format!(
"lstm_cell: w_ih must be [4*hidden_size, input_size], got {:?}",
w_ih.shape()
));
}
if w_hh.ndim() != 2 || w_hh.shape()[0] != 4 * hidden_size {
return Err(format!(
"lstm_cell: w_hh must be [4*hidden_size, hidden_size], got {:?}",
w_hh.shape()
));
}
if b_ih.len() != 4 * hidden_size {
return Err(format!(
"lstm_cell: b_ih must have length 4*hidden_size={}, got {}",
4 * hidden_size,
b_ih.len()
));
}
if b_hh.len() != 4 * hidden_size {
return Err(format!(
"lstm_cell: b_hh must have length 4*hidden_size={}, got {}",
4 * hidden_size,
b_hh.len()
));
}
let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
let gates = gates_ih.add(&gates_hh).map_err(map_err)?;
let chunks = gates.chunk(4, 1).map_err(map_err)?;
let gates_i = &chunks[0];
let gates_f = &chunks[1];
let gates_g = &chunks[2];
let gates_o = &chunks[3];
let i = gates_i.sigmoid();
let f = gates_f.sigmoid();
let g = gates_g.tanh_activation();
let o = gates_o.sigmoid();
let fc = f.mul_elem(c_prev).map_err(map_err)?;
let ig = i.mul_elem(&g).map_err(map_err)?;
let c_new = fc.add(&ig).map_err(map_err)?;
let c_tanh = c_new.tanh_activation();
let h_new = o.mul_elem(&c_tanh).map_err(map_err)?;
Ok((h_new, c_new))
}
pub fn gru_cell(
x: &Tensor,
h_prev: &Tensor,
w_ih: &Tensor,
w_hh: &Tensor,
b_ih: &Tensor,
b_hh: &Tensor,
) -> Result<Tensor, String> {
let map_err = |e: crate::error::RuntimeError| format!("{e}");
if x.ndim() != 2 {
return Err("gru_cell: x must be 2-D [batch, input_size]".into());
}
if h_prev.ndim() != 2 {
return Err("gru_cell: h_prev must be 2-D [batch, hidden_size]".into());
}
let hidden_size = h_prev.shape()[1];
if w_ih.ndim() != 2 || w_ih.shape()[0] != 3 * hidden_size {
return Err(format!(
"gru_cell: w_ih must be [3*hidden_size, input_size], got {:?}",
w_ih.shape()
));
}
if w_hh.ndim() != 2 || w_hh.shape()[0] != 3 * hidden_size {
return Err(format!(
"gru_cell: w_hh must be [3*hidden_size, hidden_size], got {:?}",
w_hh.shape()
));
}
if b_ih.len() != 3 * hidden_size {
return Err(format!(
"gru_cell: b_ih must have length 3*hidden_size={}, got {}",
3 * hidden_size,
b_ih.len()
));
}
if b_hh.len() != 3 * hidden_size {
return Err(format!(
"gru_cell: b_hh must have length 3*hidden_size={}, got {}",
3 * hidden_size,
b_hh.len()
));
}
let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
let ih_chunks = gates_ih.chunk(3, 1).map_err(map_err)?;
let hh_chunks = gates_hh.chunk(3, 1).map_err(map_err)?;
let r_ih = &ih_chunks[0];
let z_ih = &ih_chunks[1];
let n_ih = &ih_chunks[2];
let r_hh = &hh_chunks[0];
let z_hh = &hh_chunks[1];
let n_hh = &hh_chunks[2];
let r = r_ih.add(r_hh).map_err(map_err)?.sigmoid();
let z = z_ih.add(z_hh).map_err(map_err)?.sigmoid();
let r_n_hh = r.mul_elem(n_hh).map_err(map_err)?;
let n = n_ih.add(&r_n_hh).map_err(map_err)?.tanh_activation();
let ones = Tensor::ones(z.shape());
let one_minus_z = ones.sub(&z).map_err(map_err)?;
let term1 = one_minus_z.mul_elem(&n).map_err(map_err)?;
let term2 = z.mul_elem(h_prev).map_err(map_err)?;
let h_new = term1.add(&term2).map_err(map_err)?;
Ok(h_new)
}
/// One LSTM time step, computing the gate activations in a single scalar pass.
///
/// The math matches `lstm_cell`: pre-activations are
/// `x.linear(w_ih, b_ih) + h_prev.linear(w_hh, b_hh)`, with the gates ordered
/// `[i | f | g | o]` in blocks of `hidden_size` per row. Then
/// `c' = sigmoid(f) * c_prev + sigmoid(i) * tanh(g)` and
/// `h' = sigmoid(o) * tanh(c')`. The flattened `to_vec()` buffers are assumed to
/// be row-major `[batch, 4*hidden_size]` (TODO confirm against `Tensor::linear`).
///
/// Returns `(h_new, c_new)`, both `[batch, hidden_size]`.
///
/// # Errors
/// Returns `Err` in these cases:
/// - an input is not 2-D;
/// - a weight or bias has the wrong gate dimension;
/// - `h_prev` and `c_prev` are not `[batch, hidden_size]` with `batch` taken
///   from `x`. Before this check, a mismatch caused an index panic or silently
///   read the wrong rows;
/// - a tensor operation fails.
pub fn lstm_cell_fused(
    x: &Tensor,
    h_prev: &Tensor,
    c_prev: &Tensor,
    w_ih: &Tensor,
    w_hh: &Tensor,
    b_ih: &Tensor,
    b_hh: &Tensor,
) -> Result<(Tensor, Tensor), String> {
    let map_err = |e: crate::error::RuntimeError| format!("{e}");
    if x.ndim() != 2 {
        return Err("lstm_cell_fused: x must be 2-D [batch, input_size]".into());
    }
    if h_prev.ndim() != 2 {
        return Err("lstm_cell_fused: h_prev must be 2-D [batch, hidden_size]".into());
    }
    if c_prev.ndim() != 2 {
        return Err("lstm_cell_fused: c_prev must be 2-D [batch, hidden_size]".into());
    }
    let batch = x.shape()[0];
    let hidden_size = h_prev.shape()[1];
    if w_ih.ndim() != 2 || w_ih.shape()[0] != 4 * hidden_size {
        return Err(format!(
            "lstm_cell_fused: w_ih must be [4*hidden_size, input_size], got {:?}",
            w_ih.shape()
        ));
    }
    if w_hh.ndim() != 2 || w_hh.shape()[0] != 4 * hidden_size {
        return Err(format!(
            "lstm_cell_fused: w_hh must be [4*hidden_size, hidden_size], got {:?}",
            w_hh.shape()
        ));
    }
    if b_ih.len() != 4 * hidden_size {
        return Err(format!(
            "lstm_cell_fused: b_ih must have length 4*hidden_size={}, got {}",
            4 * hidden_size,
            b_ih.len()
        ));
    }
    if b_hh.len() != 4 * hidden_size {
        return Err(format!(
            "lstm_cell_fused: b_hh must have length 4*hidden_size={}, got {}",
            4 * hidden_size,
            b_hh.len()
        ));
    }
    // The scalar loop indexes gih/ghh/cprev by `batch` and `hidden_size`, so the
    // state tensors must agree with x before any raw indexing happens.
    if h_prev.shape()[0] != batch {
        return Err(format!(
            "lstm_cell_fused: h_prev batch ({}) must match x batch ({})",
            h_prev.shape()[0],
            batch
        ));
    }
    if c_prev.shape()[0] != batch || c_prev.shape()[1] != hidden_size {
        return Err(format!(
            "lstm_cell_fused: c_prev must be [batch={}, hidden_size={}], got {:?}",
            batch,
            hidden_size,
            c_prev.shape()
        ));
    }

    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
    let gih = gates_ih.to_vec();
    let ghh = gates_hh.to_vec();
    let cprev = c_prev.to_vec();
    let sigmoid = |v: f64| 1.0 / (1.0 + (-v).exp());

    let mut h_new_data = vec![0.0f64; batch * hidden_size];
    let mut c_new_data = vec![0.0f64; batch * hidden_size];
    for b_idx in 0..batch {
        let base = b_idx * 4 * hidden_size;
        for h in 0..hidden_size {
            let gi = gih[base + h] + ghh[base + h];
            let gf = gih[base + hidden_size + h] + ghh[base + hidden_size + h];
            let gg = gih[base + 2 * hidden_size + h] + ghh[base + 2 * hidden_size + h];
            let go = gih[base + 3 * hidden_size + h] + ghh[base + 3 * hidden_size + h];
            let i_val = sigmoid(gi);
            let f_val = sigmoid(gf);
            let g_val = gg.tanh();
            let o_val = sigmoid(go);
            let c_idx = b_idx * hidden_size + h;
            let c_val = f_val * cprev[c_idx] + i_val * g_val;
            c_new_data[c_idx] = c_val;
            h_new_data[c_idx] = o_val * c_val.tanh();
        }
    }
    let h_new = Tensor::from_vec(h_new_data, &[batch, hidden_size]).map_err(map_err)?;
    let c_new = Tensor::from_vec(c_new_data, &[batch, hidden_size]).map_err(map_err)?;
    Ok((h_new, c_new))
}
/// One GRU time step, computing the gate activations in a single scalar pass.
///
/// The math matches `gru_cell`. With the projections laid out as `[r | z | n]`
/// in blocks of `hidden_size` per row:
/// - `r = sigmoid(r_ih + r_hh)`
/// - `z = sigmoid(z_ih + z_hh)`
/// - `n = tanh(n_ih + r * n_hh)`
/// - `h' = (1 - z) * n + z * h_prev`
///
/// The flattened `to_vec()` buffers are assumed to be row-major
/// `[batch, 3*hidden_size]` (TODO confirm against `Tensor::linear`).
///
/// # Errors
/// Returns `Err` in these cases:
/// - an input is not 2-D;
/// - a weight or bias has the wrong gate dimension;
/// - `h_prev`'s batch differs from `x`'s. Before this check, a mismatch caused
///   an index panic or silently read the wrong rows;
/// - a tensor operation fails.
pub fn gru_cell_fused(
    x: &Tensor,
    h_prev: &Tensor,
    w_ih: &Tensor,
    w_hh: &Tensor,
    b_ih: &Tensor,
    b_hh: &Tensor,
) -> Result<Tensor, String> {
    let map_err = |e: crate::error::RuntimeError| format!("{e}");
    if x.ndim() != 2 {
        return Err("gru_cell_fused: x must be 2-D [batch, input_size]".into());
    }
    if h_prev.ndim() != 2 {
        return Err("gru_cell_fused: h_prev must be 2-D [batch, hidden_size]".into());
    }
    let batch = x.shape()[0];
    let hidden_size = h_prev.shape()[1];
    if w_ih.ndim() != 2 || w_ih.shape()[0] != 3 * hidden_size {
        return Err(format!(
            "gru_cell_fused: w_ih must be [3*hidden_size, input_size], got {:?}",
            w_ih.shape()
        ));
    }
    if w_hh.ndim() != 2 || w_hh.shape()[0] != 3 * hidden_size {
        return Err(format!(
            "gru_cell_fused: w_hh must be [3*hidden_size, hidden_size], got {:?}",
            w_hh.shape()
        ));
    }
    if b_ih.len() != 3 * hidden_size {
        return Err(format!(
            "gru_cell_fused: b_ih must have length 3*hidden_size={}, got {}",
            3 * hidden_size,
            b_ih.len()
        ));
    }
    if b_hh.len() != 3 * hidden_size {
        return Err(format!(
            "gru_cell_fused: b_hh must have length 3*hidden_size={}, got {}",
            3 * hidden_size,
            b_hh.len()
        ));
    }
    // ghh and hp are indexed with x's batch below, so h_prev must agree with it.
    if h_prev.shape()[0] != batch {
        return Err(format!(
            "gru_cell_fused: h_prev batch ({}) must match x batch ({})",
            h_prev.shape()[0],
            batch
        ));
    }

    let gates_ih = x.linear(w_ih, b_ih).map_err(map_err)?;
    let gates_hh = h_prev.linear(w_hh, b_hh).map_err(map_err)?;
    let gih = gates_ih.to_vec();
    let ghh = gates_hh.to_vec();
    let hp = h_prev.to_vec();
    let sigmoid = |v: f64| 1.0 / (1.0 + (-v).exp());

    let mut h_new_data = vec![0.0f64; batch * hidden_size];
    for b_idx in 0..batch {
        let base = b_idx * 3 * hidden_size;
        for h in 0..hidden_size {
            let r_off = base + h;
            let z_off = base + hidden_size + h;
            let n_off = base + 2 * hidden_size + h;
            let r_val = sigmoid(gih[r_off] + ghh[r_off]);
            let z_val = sigmoid(gih[z_off] + ghh[z_off]);
            let n_val = (gih[n_off] + r_val * ghh[n_off]).tanh();
            let h_idx = b_idx * hidden_size + h;
            h_new_data[h_idx] = (1.0 - z_val) * n_val + z_val * hp[h_idx];
        }
    }
    Tensor::from_vec(h_new_data, &[batch, hidden_size]).map_err(map_err)
}
/// Multi-head attention built on `Tensor` primitives.
///
/// The steps are:
/// 1. project `q`, `k` and `v` with their own `linear` weights and biases;
/// 2. split each projection into `num_heads` heads (`split_heads`);
/// 3. apply `Tensor::scaled_dot_product_attention`;
/// 4. merge the heads and apply the output projection `w_o`/`b_o`.
///
/// Only `q` is checked to be 3-D `[batch, seq, model_dim]`. The shapes of `k` and
/// `v` are left to the tensor operations to validate.
///
/// # Errors
/// Returns `Err` if `q` is not 3-D, if `num_heads == 0`, or if any tensor
/// operation fails.
// One weight and one bias per projection are passed separately, as callers
// expect; bundling them would break the public signature.
#[allow(clippy::too_many_arguments)]
pub fn multi_head_attention(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    w_q: &Tensor,
    w_k: &Tensor,
    w_v: &Tensor,
    w_o: &Tensor,
    b_q: &Tensor,
    b_k: &Tensor,
    b_v: &Tensor,
    b_o: &Tensor,
    num_heads: usize,
) -> Result<Tensor, String> {
    let map_err = |e: crate::error::RuntimeError| format!("{e}");
    if q.ndim() != 3 {
        return Err("multi_head_attention: q must be 3-D [batch, seq, model_dim]".into());
    }
    // Zero heads is never meaningful. Reject it here rather than rely on how
    // split_heads handles it (NOTE(review): it may divide by num_heads).
    if num_heads == 0 {
        return Err("multi_head_attention: num_heads must be at least 1".into());
    }
    let q_proj = q.linear(w_q, b_q).map_err(map_err)?;
    let k_proj = k.linear(w_k, b_k).map_err(map_err)?;
    let v_proj = v.linear(w_v, b_v).map_err(map_err)?;
    let q_heads = q_proj.split_heads(num_heads).map_err(map_err)?;
    let k_heads = k_proj.split_heads(num_heads).map_err(map_err)?;
    let v_heads = v_proj.split_heads(num_heads).map_err(map_err)?;
    let attn = Tensor::scaled_dot_product_attention(&q_heads, &k_heads, &v_heads)
        .map_err(map_err)?;
    let merged = attn.merge_heads().map_err(map_err)?;
    merged.linear(w_o, b_o).map_err(map_err)
}
/// Embedding lookup: gathers the row `weight[idx]` for each index.
///
/// `weight` must be 2-D `[vocab_size, embed_dim]`. The result is
/// `[indices.len(), embed_dim]`, with rows in the order of `indices`.
///
/// # Errors
/// Returns `Err` if `weight` is not 2-D, if any index is negative or
/// `>= vocab_size`, or if building the output tensor fails.
pub fn embedding(weight: &crate::tensor::Tensor, indices: &[i64]) -> Result<crate::tensor::Tensor, String> {
    let shape = weight.shape();
    if shape.len() != 2 {
        return Err(format!("embedding: weight must be 2-D [vocab_size, embed_dim], got {:?}", shape));
    }
    let vocab_size = shape[0];
    let embed_dim = shape[1];
    let weight_data = weight.to_vec();
    let mut out = Vec::with_capacity(indices.len() * embed_dim);
    for &idx in indices {
        // Bounds-check in u64 before narrowing. `idx as usize` truncates on
        // 32-bit targets, which could alias an out-of-range index onto a valid row.
        if idx < 0 || (idx as u64) >= (vocab_size as u64) {
            return Err(format!("embedding: index {} out of bounds for vocab_size {}", idx, vocab_size));
        }
        let start = idx as usize * embed_dim;
        out.extend_from_slice(&weight_data[start..start + embed_dim]);
    }
    crate::tensor::Tensor::from_vec(out, &[indices.len(), embed_dim])
        .map_err(|e| e.to_string())
}
/// Splits `0..dataset_size` into consecutive half-open `(start, end)` batch
/// ranges of at most `batch_size` elements each; only the last may be shorter.
///
/// The ranges are always contiguous and in ascending order. `seed` does not
/// affect the output and is kept only for API compatibility. Callers that want a
/// shuffled order must permute the sample positions themselves.
///
/// Returns an empty vector when `dataset_size == 0` or `batch_size == 0`.
/// A zero batch size previously looped forever.
pub fn batch_indices(dataset_size: usize, batch_size: usize, seed: u64) -> Vec<(usize, usize)> {
    // The previous body shuffled a local index vector with this seed and then
    // discarded it, so dropping that dead work does not change any output.
    let _ = seed;
    if batch_size == 0 {
        return Vec::new();
    }
    (0..dataset_size)
        .step_by(batch_size)
        .map(|start| (start, start.saturating_add(batch_size).min(dataset_size)))
        .collect()
}
#[cfg(test)]
mod tests {
    // Unit tests for the losses, optimizers, metrics, schedulers, PCA, L-BFGS and
    // data-pipeline helpers defined in the parent module.
    use super::*;

    #[test]
    fn test_mse_zero() {
        let pred = [1.0, 2.0, 3.0];
        let target = [1.0, 2.0, 3.0];
        assert_eq!(mse_loss(&pred, &target).unwrap(), 0.0);
    }

    #[test]
    fn test_mse_basic() {
        let pred = [1.0, 2.0, 3.0];
        let target = [2.0, 3.0, 4.0];
        assert_eq!(mse_loss(&pred, &target).unwrap(), 1.0);
    }

    #[test]
    fn test_huber_loss_quadratic() {
        let pred = [1.0];
        let target = [1.5];
        let h = huber_loss(&pred, &target, 1.0).unwrap();
        assert!((h - 0.125).abs() < 1e-12);
    }

    #[test]
    fn test_sgd_step() {
        let mut params = [1.0, 2.0];
        let grads = [0.1, 0.2];
        let mut state = SgdState::new(2, 0.1, 0.0);
        sgd_step(&mut params, &grads, &mut state);
        assert!((params[0] - 0.99).abs() < 1e-12);
        assert!((params[1] - 1.98).abs() < 1e-12);
    }

    #[test]
    fn test_adam_step() {
        let mut params = [1.0, 2.0];
        let grads = [0.1, 0.2];
        let mut state = AdamState::new(2, 0.001);
        adam_step(&mut params, &grads, &mut state);
        assert!(params[0] < 1.0);
        assert!(params[1] < 2.0);
    }

    #[test]
    fn paramidx_accessors_agree_with_direct_field_reads_adam() {
        let mut params = vec![0.5, -0.3, 1.7, 0.0];
        let grads = vec![0.1, -0.2, 0.05, 0.4];
        let mut state = AdamState::new(4, 0.01);
        for _ in 0..5 {
            adam_step(&mut params, &grads, &mut state);
        }
        for i in 0..4 {
            let p = crate::idx::ParamIdx::from_usize(i);
            assert_eq!(state.m_at(p).to_bits(), state.m[i].to_bits());
            assert_eq!(state.v_at(p).to_bits(), state.v[i].to_bits());
        }
        assert_eq!(state.n_params(), 4);
    }

    #[test]
    fn paramidx_accessors_agree_with_direct_field_reads_sgd() {
        let mut params = vec![1.0, 2.0, 3.0];
        let grads = vec![0.01, -0.02, 0.03];
        let mut state = SgdState::new(3, 0.05, 0.9);
        for _ in 0..7 {
            sgd_step(&mut params, &grads, &mut state);
        }
        for i in 0..3 {
            let p = crate::idx::ParamIdx::from_usize(i);
            assert_eq!(state.velocity_at(p).to_bits(), state.velocity[i].to_bits());
        }
        assert_eq!(state.n_params(), 3);
    }

    #[test]
    fn paramidx_typed_setters_round_trip() {
        let mut state = AdamState::new(8, 0.001);
        for i in 0..8 {
            let p = crate::idx::ParamIdx::from_usize(i);
            state.set_m_at(p, (i as f64) * 0.5);
            state.set_v_at(p, (i as f64) * 0.25);
        }
        for i in 0..8 {
            let p = crate::idx::ParamIdx::from_usize(i);
            assert_eq!(state.m_at(p), (i as f64) * 0.5);
            assert_eq!(state.v_at(p), (i as f64) * 0.25);
            assert_eq!(state.m[i], (i as f64) * 0.5);
            assert_eq!(state.v[i], (i as f64) * 0.25);
        }
    }

    #[test]
    fn paramidx_adam_step_byte_equal_across_independent_runs() {
        let init_params = vec![0.5, -0.5, 1.5, -1.5, 0.7];
        let grads = vec![0.1, -0.05, 0.02, 0.08, -0.03];
        let run = || {
            let mut params = init_params.clone();
            let mut state = AdamState::new(5, 0.01);
            for _ in 0..20 {
                adam_step(&mut params, &grads, &mut state);
            }
            (params, state.m.clone(), state.v.clone(), state.t)
        };
        let (p1, m1, v1, t1) = run();
        let (p2, m2, v2, t2) = run();
        let bits = |xs: &[f64]| -> Vec<u64> { xs.iter().map(|x| x.to_bits()).collect() };
        assert_eq!(bits(&p1), bits(&p2));
        assert_eq!(bits(&m1), bits(&m2));
        assert_eq!(bits(&v1), bits(&v2));
        assert_eq!(t1, t2);
    }

    #[test]
    fn test_confusion_matrix() {
        let pred = [true, true, false, false, true];
        let actual = [true, false, true, false, true];
        let cm = confusion_matrix(&pred, &actual);
        assert_eq!(cm.tp, 2);
        assert_eq!(cm.fp, 1);
        assert_eq!(cm.fn_count, 1);
        assert_eq!(cm.tn, 1);
    }

    #[test]
    fn test_precision_recall_f1() {
        let cm = ConfusionMatrix { tp: 5, fp: 2, tn: 8, fn_count: 1 };
        assert!((precision(&cm) - 5.0 / 7.0).abs() < 1e-12);
        assert!((recall(&cm) - 5.0 / 6.0).abs() < 1e-12);
    }

    #[test]
    fn test_auc_perfect() {
        let scores = [0.9, 0.8, 0.2, 0.1];
        let labels = [true, true, false, false];
        let auc = auc_roc(&scores, &labels).unwrap();
        assert!((auc - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_kfold_deterministic() {
        let f1 = kfold_indices(100, 5, 42);
        let f2 = kfold_indices(100, 5, 42);
        for i in 0..5 {
            assert_eq!(f1[i].0, f2[i].0);
            assert_eq!(f1[i].1, f2[i].1);
        }
    }

    #[test]
    fn test_train_test_split_coverage() {
        let (train, test) = train_test_split(100, 0.2, 42);
        assert_eq!(train.len() + test.len(), 100);
        assert_eq!(test.len(), 20);
    }

    #[test]
    fn test_batch_norm_identity() {
        let x = vec![1.0, 2.0, 3.0];
        let mean = vec![0.0, 0.0, 0.0];
        let var = vec![1.0, 1.0, 1.0];
        let gamma = vec![1.0, 1.0, 1.0];
        let beta = vec![0.0, 0.0, 0.0];
        let result = batch_norm(&x, &mean, &var, &gamma, &beta, 0.0).unwrap();
        assert!((result[0] - 1.0).abs() < 1e-12);
        assert!((result[1] - 2.0).abs() < 1e-12);
        assert!((result[2] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_batch_norm_shift_scale() {
        // (0 - 1) / sqrt(4) * 2 + 3 = 2
        let x = vec![0.0];
        let mean = vec![1.0];
        let var = vec![4.0];
        let gamma = vec![2.0];
        let beta = vec![3.0];
        let result = batch_norm(&x, &mean, &var, &gamma, &beta, 0.0).unwrap();
        assert!((result[0] - 2.0).abs() < 1e-12);
    }

    #[test]
    fn test_dropout_mask_seed_determinism() {
        let m1 = dropout_mask(100, 0.5, 42);
        let m2 = dropout_mask(100, 0.5, 42);
        assert_eq!(m1, m2);
    }

    #[test]
    fn test_dropout_mask_different_seeds() {
        let m1 = dropout_mask(100, 0.5, 42);
        let m2 = dropout_mask(100, 0.5, 99);
        assert_ne!(m1, m2);
    }

    #[test]
    fn test_lr_step_decay_schedule() {
        let lr0 = lr_step_decay(0.1, 0.5, 0, 10);
        assert!((lr0 - 0.1).abs() < 1e-12);
        let lr10 = lr_step_decay(0.1, 0.5, 10, 10);
        assert!((lr10 - 0.05).abs() < 1e-12);
        let lr20 = lr_step_decay(0.1, 0.5, 20, 10);
        assert!((lr20 - 0.025).abs() < 1e-12);
    }

    #[test]
    fn test_lr_cosine_endpoints() {
        let lr0 = lr_cosine(0.1, 0.001, 0, 100);
        assert!((lr0 - 0.1).abs() < 1e-10);
        let lr_end = lr_cosine(0.1, 0.001, 100, 100);
        assert!((lr_end - 0.001).abs() < 1e-10);
    }

    #[test]
    fn test_lr_linear_warmup() {
        let lr0 = lr_linear_warmup(0.1, 0, 10);
        assert!(lr0.abs() < 1e-12);
        let lr5 = lr_linear_warmup(0.1, 5, 10);
        assert!((lr5 - 0.05).abs() < 1e-12);
        let lr15 = lr_linear_warmup(0.1, 15, 10);
        assert!((lr15 - 0.1).abs() < 1e-12);
    }

    #[test]
    fn test_l1_penalty_known() {
        let params = [1.0, -2.0, 3.0];
        let p = l1_penalty(&params, 0.1);
        assert!((p - 0.6).abs() < 1e-12);
    }

    #[test]
    fn test_l2_penalty_known() {
        let params = [1.0, -2.0, 3.0];
        let p = l2_penalty(&params, 0.1);
        assert!((p - 0.7).abs() < 1e-12);
    }

    #[test]
    fn test_early_stopping_triggers() {
        let mut es = EarlyStoppingState::new(3, 0.01);
        // The first call records the baseline; three non-improving calls exhaust patience.
        assert!(!es.check(1.0));
        assert!(!es.check(1.0));
        assert!(!es.check(1.0));
        assert!(es.check(1.0));
    }

    #[test]
    fn test_early_stopping_resets() {
        let mut es = EarlyStoppingState::new(3, 0.01);
        es.check(1.0);
        es.check(1.0);
        // An improvement resets the patience counter.
        assert!(!es.check(0.5));
        assert!(!es.check(0.5));
    }

    #[test]
    fn test_pca_basic_2d() {
        let data = Tensor::from_vec(
            vec![
                1.0, 0.1,
                2.0, 0.2,
                3.0, 0.3,
                4.0, 0.4,
            ],
            &[4, 2],
        )
        .unwrap();
        let (transformed, components, evr) = pca(&data, 2).unwrap();
        assert_eq!(transformed.shape(), &[4, 2]);
        assert_eq!(components.shape(), &[2, 2]);
        assert_eq!(evr.len(), 2);
        let total: f64 = evr.iter().sum();
        assert!(
            (total - 1.0).abs() < 1e-8,
            "explained variance ratios sum to {} (expected ~1.0)",
            total
        );
        assert!(evr[0] > 0.9, "first component explains {} of variance", evr[0]);
    }

    #[test]
    fn test_pca_single_component() {
        let data = Tensor::from_vec(
            vec![
                1.0, 2.0, 3.0,
                4.0, 5.0, 6.0,
                7.0, 8.0, 9.0,
            ],
            &[3, 3],
        )
        .unwrap();
        let (transformed, components, evr) = pca(&data, 1).unwrap();
        assert_eq!(transformed.shape(), &[3, 1]);
        assert_eq!(components.shape(), &[1, 3]);
        assert_eq!(evr.len(), 1);
        assert!(evr[0] > 0.0 && evr[0] <= 1.0);
    }

    #[test]
    fn test_pca_explained_variance_ratio_bounded() {
        let data = Tensor::from_vec(
            vec![
                1.0, 0.0, 0.5,
                0.0, 1.0, 0.5,
                1.0, 1.0, 1.0,
                2.0, 0.0, 1.0,
                0.0, 2.0, 1.0,
            ],
            &[5, 3],
        )
        .unwrap();
        let (_, _, evr) = pca(&data, 3).unwrap();
        let total: f64 = evr.iter().sum();
        assert!(
            total <= 1.0 + 1e-10,
            "explained variance ratios sum to {} (should be <= 1.0)",
            total
        );
        for &r in &evr {
            assert!(r >= -1e-10, "negative explained variance ratio: {}", r);
        }
    }

    #[test]
    fn test_pca_deterministic() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]).unwrap();
        let (t1, c1, e1) = pca(&data, 2).unwrap();
        let (t2, c2, e2) = pca(&data, 2).unwrap();
        assert_eq!(t1.to_vec(), t2.to_vec(), "PCA transformed not deterministic");
        assert_eq!(c1.to_vec(), c2.to_vec(), "PCA components not deterministic");
        assert_eq!(e1, e2, "PCA explained variance not deterministic");
    }

    #[test]
    fn test_pca_invalid_n_components() {
        let data = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap();
        assert!(pca(&data, 0).is_err(), "n_components=0 should fail");
        assert!(pca(&data, 3).is_err(), "n_components > min(n,p) should fail");
    }

    /// Rosenbrock function `(1 - x)^2 + 100 (y - x^2)^2` and its analytic gradient.
    /// Its global minimum is at `(1, 1)`.
    fn rosenbrock(p: &[f64]) -> (f64, Vec<f64>) {
        let x = p[0];
        let y = p[1];
        let a = 1.0 - x;
        let b = y - x * x;
        let val = a * a + 100.0 * b * b;
        let gx = -2.0 * a - 400.0 * x * b;
        let gy = 200.0 * b;
        (val, vec![gx, gy])
    }

    #[test]
    fn test_lbfgs_rosenbrock_converges() {
        let mut params = vec![-1.0_f64, 2.0_f64];
        let mut state = LbfgsState::new(0.5, 10);
        let mut converged = false;
        for _iter in 0..200 {
            let (_, grads) = rosenbrock(&params);
            let grad_norm: f64 = kahan_dot(&grads, &grads).sqrt();
            if grad_norm < 1e-5 {
                converged = true;
                break;
            }
            let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, rosenbrock);
            params = new_p;
        }
        assert!(converged, "L-BFGS did not converge on Rosenbrock; params = {:?}", params);
        assert!(
            (params[0] - 1.0).abs() < 1e-3,
            "x should converge near 1.0, got {}",
            params[0]
        );
        assert!(
            (params[1] - 1.0).abs() < 1e-3,
            "y should converge near 1.0, got {}",
            params[1]
        );
    }

    #[test]
    fn test_lbfgs_determinism() {
        let init = vec![-1.0_f64, 2.0_f64];
        let run = |init: &[f64]| -> Vec<f64> {
            let mut params = init.to_vec();
            let mut state = LbfgsState::new(0.5, 10);
            for _ in 0..20 {
                let (_, grads) = rosenbrock(&params);
                let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, rosenbrock);
                params = new_p;
            }
            params
        };
        let r1 = run(&init);
        let r2 = run(&init);
        assert_eq!(r1, r2, "L-BFGS must be bit-identical across runs");
    }

    #[test]
    fn test_lbfgs_simple_quadratic() {
        let mut params = vec![3.0_f64];
        let mut state = LbfgsState::new(1.0, 5);
        let quadratic = |p: &[f64]| -> (f64, Vec<f64>) { (p[0] * p[0], vec![2.0 * p[0]]) };
        for _ in 0..30 {
            let (_, grads) = quadratic(&params);
            let (new_p, _, _) = lbfgs_step(&params, &grads, &mut state, quadratic);
            params = new_p;
        }
        assert!(
            params[0].abs() < 1e-6,
            "L-BFGS should minimize x^2 to ~0, got {}",
            params[0]
        );
    }

    #[test]
    fn test_embedding_basic() {
        let weight = crate::tensor::Tensor::from_vec(
            vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
            &[3, 2],
        )
        .unwrap();
        let indices = vec![0, 2, 1];
        let result = super::embedding(&weight, &indices).unwrap();
        assert_eq!(result.shape(), &[3, 2]);
        let data = result.to_vec();
        assert!((data[0] - 0.1).abs() < 1e-12);
        assert!((data[1] - 0.2).abs() < 1e-12);
        assert!((data[2] - 0.5).abs() < 1e-12);
        assert!((data[3] - 0.6).abs() < 1e-12);
        assert!((data[4] - 0.3).abs() < 1e-12);
        assert!((data[5] - 0.4).abs() < 1e-12);
    }

    #[test]
    fn test_embedding_out_of_bounds() {
        let weight = crate::tensor::Tensor::from_vec(vec![1.0, 2.0], &[1, 2]).unwrap();
        let result = super::embedding(&weight, &[1]);
        assert!(result.is_err());
    }

    #[test]
    fn test_embedding_negative_index() {
        let weight = crate::tensor::Tensor::from_vec(vec![1.0, 2.0], &[1, 2]).unwrap();
        assert!(super::embedding(&weight, &[-1]).is_err());
    }

    #[test]
    fn test_batch_indices_deterministic() {
        let b1 = super::batch_indices(10, 3, 42);
        let b2 = super::batch_indices(10, 3, 42);
        assert_eq!(b1, b2);
        let total: usize = b1.iter().map(|(s, e)| e - s).sum();
        assert_eq!(total, 10);
    }

    #[test]
    fn test_wolfe_line_search_armijo() {
        // Minimize f(x) = x^2 from x = 3 along the descent direction -1.
        let params = vec![3.0_f64];
        let direction = vec![-1.0_f64];
        let grads = vec![6.0_f64];
        let f0 = 9.0;
        let mut eval_count = 0;
        let mut obj = |p: &[f64]| -> (f64, Vec<f64>) {
            eval_count += 1;
            (p[0] * p[0], vec![2.0 * p[0]])
        };
        let (alpha, new_params, new_val, _) =
            wolfe_line_search(&params, &direction, &mut obj, f0, &grads, 1.0);
        let c1 = 1e-4;
        let derphi0 = kahan_dot(&grads, &direction);
        assert!(
            new_val <= f0 + c1 * alpha * derphi0,
            "Armijo condition violated: {} > {} + {} * {} * {}",
            new_val, f0, c1, alpha, derphi0
        );
        assert!(new_params[0] < 3.0, "Step should move toward minimum");
    }
}