use crate::math;
pub use crate::rng::Xorshift64;
/// Computes the dense matrix–vector product `W · x` into `out`.
///
/// `w` holds a `rows × cols` matrix in row-major order (row `i` occupies
/// `w[i * cols..(i + 1) * cols]`), `x` has `cols` entries and `out` receives
/// one entry per row. The arithmetic itself is delegated to
/// `crate::simd::simd_mat_vec`.
///
/// NOTE(review): whether `out` is overwritten or accumulated into is decided by
/// `simd_mat_vec`; the tests in this file only pass a zero-filled `out` — confirm.
///
/// # Panics
///
/// With debug assertions enabled, panics if `w.len() != rows * cols`,
/// `x.len() != cols` or `out.len() != rows` (checked in that order). The checks
/// are compiled out of release builds, where upholding the length contract is
/// the caller's responsibility.
#[inline]
pub fn mat_vec(w: &[f64], x: &[f64], rows: usize, cols: usize, out: &mut [f64]) {
    // Equivalent to three `debug_assert_eq!`s: only evaluated when debug
    // assertions are on, so release builds pay nothing for the shape checks.
    if cfg!(debug_assertions) {
        assert_eq!(w.len(), rows * cols, "w must be rows*cols");
        assert_eq!(x.len(), cols, "x must have cols elements");
        assert_eq!(out.len(), rows, "out must have rows elements");
    }
    crate::simd::simd_mat_vec(w, x, rows, cols, out);
}
/// Returns the dot product `Σ a[i] · b[i]` of two equal-length slices.
///
/// The computation is delegated to `crate::simd::simd_dot`.
///
/// # Panics
///
/// With debug assertions enabled, panics if `a.len() != b.len()`. Release
/// builds skip the check; how `simd_dot` treats mismatched lengths is not
/// visible here — NOTE(review): confirm before relying on it.
#[inline]
pub fn dot(a: &[f64], b: &[f64]) -> f64 {
    // Same expansion as `debug_assert_eq!`: the length check vanishes in release.
    if cfg!(debug_assertions) {
        assert_eq!(a.len(), b.len(), "dot product requires equal lengths");
    }
    crate::simd::simd_dot(a, b)
}
/// Softplus activation, `ln(1 + e^x)`.
///
/// A smooth, strictly positive approximation of `max(0, x)`. Its value is about `x`
/// for large positive inputs and about `0` for large negative ones; the unit tests
/// check this at `±50`. All work is forwarded to `math::softplus`.
#[inline]
pub fn softplus(x: f64) -> f64 {
math::softplus(x)
}
/// Logistic sigmoid activation, forwarded to `math::sigmoid`.
///
/// The unit tests pin these properties:
/// - `sigmoid(0) == 0.5`.
/// - `sigmoid(x) + sigmoid(-x) == 1`.
/// - Results lie strictly inside `(0, 1)` for moderate inputs.
/// - At extreme inputs such as `±100` the result may saturate to exactly `0` or `1`.
///
/// NOTE(review): presumably `1 / (1 + e^(-x))`; confirm against `math::sigmoid`.
#[inline]
pub fn sigmoid(x: f64) -> f64 {
math::sigmoid(x)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the linear-algebra wrappers (`mat_vec`, `dot`), the
    //! activation functions (`softplus`, `sigmoid`) and the re-exported
    //! `Xorshift64` generator.

    use super::*;
    use alloc::vec;
    use alloc::vec::Vec;

    /// Asserts that `actual` lies strictly within `tol` of `expected`.
    ///
    /// The failure message always includes the context, the expected value, the
    /// observed value and the tolerance, so a broken kernel is diagnosable
    /// straight from the test output. `#[track_caller]` makes the reported
    /// panic location the calling test rather than this helper. A NaN `actual`
    /// fails, because the `<` comparison is false for NaN.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64, tol: f64, context: &str) {
        assert!(
            math::abs(actual - expected) < tol,
            "{}: expected {}, got {} (tolerance {})",
            context,
            expected,
            actual,
            tol
        );
    }

    #[test]
    fn mat_vec_identity() {
        let w = vec![1.0, 0.0, 0.0, 1.0];
        let x = vec![3.0, 4.0];
        let mut out = vec![0.0; 2];
        mat_vec(&w, &x, 2, 2, &mut out);
        assert_close(out[0], 3.0, 1e-12, "identity row 0");
        assert_close(out[1], 4.0, 1e-12, "identity row 1");
    }

    #[test]
    fn mat_vec_rectangular() {
        // 3x2 row-major matrix [[1, 2], [3, 4], [5, 6]] times [1, 1].
        let w = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let x = vec![1.0, 1.0];
        let mut out = vec![0.0; 3];
        mat_vec(&w, &x, 3, 2, &mut out);
        assert_close(out[0], 3.0, 1e-12, "row 0: 1+2=3");
        assert_close(out[1], 7.0, 1e-12, "row 1: 3+4=7");
        assert_close(out[2], 11.0, 1e-12, "row 2: 5+6=11");
    }

    #[test]
    fn mat_vec_single_element() {
        let w = vec![7.0];
        let x = vec![3.0];
        let mut out = vec![0.0];
        mat_vec(&w, &x, 1, 1, &mut out);
        assert_close(out[0], 21.0, 1e-12, "7*3=21");
    }

    #[test]
    fn dot_product_basic() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        // 1*4 + 2*5 + 3*6 = 32
        assert_close(dot(&a, &b), 32.0, 1e-12, "dot([1,2,3], [4,5,6])");
    }

    #[test]
    fn dot_product_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert_close(dot(&a, &b), 0.0, 1e-12, "orthogonal vectors should have dot=0");
    }

    #[test]
    fn softplus_zero() {
        assert_close(softplus(0.0), math::ln(2.0), 1e-12, "softplus(0) should be ln(2)");
    }

    #[test]
    fn softplus_moderate_values() {
        let expected = math::ln(1.0 + math::exp(5.0));
        assert_close(softplus(5.0), expected, 1e-10, "softplus(5)");
    }

    #[test]
    fn softplus_large_positive() {
        assert_close(softplus(50.0), 50.0, 1e-10, "softplus(50) should be ~50");
    }

    #[test]
    fn softplus_large_negative() {
        // ln(1 + e^-50) is about 1.9e-22: tiny, but it must never go negative.
        let result = softplus(-50.0);
        assert!(
            (0.0..1e-20).contains(&result),
            "softplus(-50) should be ~0, got {}",
            result
        );
    }

    #[test]
    fn softplus_always_positive() {
        let values = [-10.0, -1.0, 0.0, 1.0, 10.0];
        for &x in &values {
            let result = softplus(x);
            assert!(result > 0.0, "softplus({}) should be > 0, got {}", x, result);
        }
    }

    #[test]
    fn sigmoid_zero() {
        assert_close(sigmoid(0.0), 0.5, 1e-12, "sigmoid(0) should be 0.5");
    }

    #[test]
    fn sigmoid_symmetry() {
        let x = 3.0;
        assert_close(
            sigmoid(x) + sigmoid(-x),
            1.0,
            1e-12,
            "sigmoid(x) + sigmoid(-x) should be 1.0",
        );
    }

    #[test]
    fn sigmoid_range() {
        // Moderate inputs must land strictly inside (0, 1)...
        let moderate = [-10.0, -1.0, 0.0, 1.0, 10.0];
        for &x in &moderate {
            let result = sigmoid(x);
            assert!(
                result > 0.0 && result < 1.0,
                "sigmoid({}) should be in (0,1), got {}",
                x,
                result
            );
        }
        // ...while extreme inputs may saturate to exactly 0 or 1 in f64.
        let extreme = [-100.0, 100.0];
        for &x in &extreme {
            let result = sigmoid(x);
            assert!(
                (0.0..=1.0).contains(&result),
                "sigmoid({}) should be in [0,1], got {}",
                x,
                result
            );
        }
    }

    #[test]
    fn xorshift_deterministic() {
        let mut rng1 = Xorshift64(42);
        let mut rng2 = Xorshift64(42);
        for _ in 0..100 {
            assert_eq!(
                rng1.next_u64(),
                rng2.next_u64(),
                "same seed should produce same sequence"
            );
        }
    }

    #[test]
    fn xorshift_different_seeds_differ() {
        let mut rng1 = Xorshift64(1);
        let mut rng2 = Xorshift64(2);
        let seq1: Vec<u64> = (0..10).map(|_| rng1.next_u64()).collect();
        let seq2: Vec<u64> = (0..10).map(|_| rng2.next_u64()).collect();
        assert_ne!(seq1, seq2, "different seeds should produce different sequences");
    }

    #[test]
    fn xorshift_f64_in_unit_interval() {
        let mut rng = Xorshift64(12345);
        for i in 0..1000 {
            let val = rng.next_f64();
            assert!(
                (0.0..1.0).contains(&val),
                "next_f64() sample {} = {} not in [0,1)",
                i,
                val
            );
        }
    }

    #[test]
    fn xorshift_normal_distribution() {
        let mut rng = Xorshift64(9999);
        let n = 10_000;
        let mut sum = 0.0;
        let mut sum_sq = 0.0;
        for _ in 0..n {
            let x = rng.next_normal();
            sum += x;
            sum_sq += x * x;
        }
        let mean = sum / n as f64;
        let variance = sum_sq / n as f64 - mean * mean;
        // Suppose the samples are standard normal. With 10k of them, the
        // standard error is about 0.01 for the mean and about 0.014 for the
        // variance. These bounds therefore leave several sigma of headroom.
        assert_close(mean, 0.0, 0.05, "normal mean should be ~0");
        assert_close(variance, 1.0, 0.1, "normal variance should be ~1");
    }
}