use alloc::vec::Vec;
use crate::math;
pub use crate::rng::Xorshift64;
/// Inner product of `a` and `b`.
///
/// Thin private shim so the gate helpers in this module share one call site
/// for the crate's `simd_dot` kernel. Any length requirements on the two
/// slices are whatever `simd_dot` imposes; none are checked here.
#[inline]
fn dot(a: &[f64], b: &[f64]) -> f64 {
    crate::simd::simd_dot(a, b)
}
/// Matrix-vector product `out = W * x`, delegated to the crate's
/// `simd_mat_vec` kernel.
///
/// * `w`    - the matrix as a flat slice of `rows * cols` values, one row
///   after another (row-major).
/// * `x`    - input vector with `cols` elements.
/// * `rows` - number of matrix rows, which is also the length of `out`.
/// * `cols` - number of matrix columns.
/// * `out`  - destination buffer with `rows` elements.
///
/// # Panics
///
/// With debug assertions enabled, panics when any of the three slice lengths
/// disagrees with `rows` / `cols`. These checks are compiled out otherwise,
/// leaving the behaviour on mismatched shapes up to `simd_mat_vec`.
#[inline]
pub fn mat_vec(w: &[f64], x: &[f64], rows: usize, cols: usize, out: &mut [f64]) {
    // Shape validation is debug-only so release builds pay nothing for it.
    debug_assert_eq!(w.len(), rows * cols, "w must be rows*cols");
    debug_assert_eq!(x.len(), cols, "x must have cols elements");
    debug_assert_eq!(out.len(), rows, "out must have rows elements");

    crate::simd::simd_mat_vec(w, x, rows, cols, out);
}
/// Builds a weight vector of `len` small random values.
///
/// Each element is one draw of `rng.next_normal()` multiplied by `0.01`.
/// The generator is advanced exactly `len` times, in element order, so the
/// result is reproducible for a given generator state.
pub fn init_weights(rng: &mut Xorshift64, len: usize) -> Vec<f64> {
    (0..len).map(|_| rng.next_normal() * 0.01).collect()
}
/// Constant decay: returns `gamma` exactly as given.
///
/// This variant ignores the input entirely; the value is neither clamped
/// nor validated.
#[inline]
pub fn fixed_decay(gamma: f64) -> f64 {
    gamma
}
/// Scalar gate: `math::sigmoid` applied to the dot product of `w_gate`
/// and `x`.
///
/// A zero pre-activation (for example all-zero weights) yields `0.5`.
#[inline]
pub fn sigmoid_gate(w_gate: &[f64], x: &[f64]) -> f64 {
    let pre_activation = dot(w_gate, x);
    math::sigmoid(pre_activation)
}
/// Scalar exponential decay gate.
///
/// Computes `exp(-(initial_decay + softplus(w_decay . x)))`. `initial_decay`
/// is a constant offset on the exponent; the input-dependent part goes
/// through `math::softplus` before being added.
///
/// NOTE(review): the result stays within `(0, 1]` only if `initial_decay`
/// is non-negative and `math::softplus` never returns a negative value —
/// neither is enforced here.
#[inline]
pub fn exponential_gate(w_decay: &[f64], x: &[f64], initial_decay: f64) -> f64 {
    let input_term = math::softplus(dot(w_decay, x));
    let exponent = initial_decay + input_term;
    math::exp(-exponent)
}
/// Computes a pair of independent sigmoid gates over the same input.
///
/// Returns `(forget, input)`, where `forget = sigmoid(w_f . x)` and
/// `input = sigmoid(w_i . x)`. The two gates share nothing but `x`.
#[inline]
pub fn lstm_gates(w_f: &[f64], w_i: &[f64], x: &[f64]) -> (f64, f64) {
    let forget = math::sigmoid(dot(w_f, x));
    let input = math::sigmoid(dot(w_i, x));
    (forget, input)
}
/// Per-channel decay vector with `d_key` elements.
///
/// `w_decay` is read as a row-major `d_key x d_model` matrix, where
/// `d_model` is `x.len()`. For every row the result is
/// `exp(ln(0.6) * sigmoid(row . x))`, i.e. `0.6` raised to the power of the
/// sigmoid. Provided `math::sigmoid` stays within `(0, 1)`, each element
/// therefore lies between `0.6` and `1.0`.
///
/// # Panics
///
/// With debug assertions enabled, panics if `w_decay.len()` differs from
/// `d_key * d_model`. In any build, slicing panics if `w_decay` is too short
/// to hold `d_key` rows.
pub fn vector_decay(w_decay: &[f64], x: &[f64], d_key: usize) -> Vec<f64> {
    let d_model = x.len();
    debug_assert_eq!(
        w_decay.len(),
        d_key * d_model,
        "w_decay must be d_key * d_model"
    );

    // Loop-invariant magnitude of the exponent: -ln(0.6) is positive, so the
    // largest possible sigmoid maps to the 0.6 floor.
    let decay_rate = -math::ln(0.6);

    (0..d_key)
        .map(|row_idx| {
            let start = row_idx * d_model;
            let end = (row_idx + 1) * d_model;
            let gate = math::sigmoid(dot(&w_decay[start..end], x));
            math::exp(-decay_rate * gate)
        })
        .collect()
}
/// Per-channel sigmoid gate with `d_key` elements.
///
/// `w_gate` is read as a row-major `d_key x d_model` matrix, where
/// `d_model` is `x.len()`. Element `i` of the result is
/// `sigmoid(row_i . x)`.
///
/// # Panics
///
/// With debug assertions enabled, panics if `w_gate.len()` differs from
/// `d_key * d_model`. In any build, slicing panics if `w_gate` is too short
/// to hold `d_key` rows.
pub fn vector_sigmoid_gate(w_gate: &[f64], x: &[f64], d_key: usize) -> Vec<f64> {
    let d_model = x.len();
    debug_assert_eq!(
        w_gate.len(),
        d_key * d_model,
        "w_gate must be d_key * d_model"
    );

    (0..d_key)
        .map(|row_idx| {
            let start = row_idx * d_model;
            let end = (row_idx + 1) * d_model;
            math::sigmoid(dot(&w_gate[start..end], x))
        })
        .collect()
}
/// Per-channel sigmoid gate rescaled so it cannot drop below `lower_bound`.
///
/// `w_gate` is read as a row-major `d_key x d_model` matrix, where
/// `d_model` is `x.len()`. Element `i` of the result is
/// `lower_bound + (1 - lower_bound) * sigmoid(row_i . x)`, so a sigmoid near
/// zero maps to `lower_bound` and a sigmoid near one maps to `1.0`.
/// `lower_bound` itself is not validated.
///
/// # Panics
///
/// With debug assertions enabled, panics if `w_gate.len()` differs from
/// `d_key * d_model`. In any build, slicing panics if `w_gate` is too short
/// to hold `d_key` rows.
pub fn vector_lower_bounded_gate(
    w_gate: &[f64],
    x: &[f64],
    d_key: usize,
    lower_bound: f64,
) -> Vec<f64> {
    let d_model = x.len();
    debug_assert_eq!(
        w_gate.len(),
        d_key * d_model,
        "w_gate must be d_key * d_model"
    );

    // Width of the interval left above the floor; identical for every row.
    let span = 1.0 - lower_bound;

    (0..d_key)
        .map(|row_idx| {
            let start = row_idx * d_model;
            let end = (row_idx + 1) * d_model;
            let pre_activation = dot(&w_gate[start..end], x);
            lower_bound + span * math::sigmoid(pre_activation)
        })
        .collect()
}
/// Scalar sigmoid gate stretched to twice the usual range.
///
/// Returns `2 * sigmoid(w . x)`: a zero pre-activation gives `1.0`, and the
/// output approaches `0.0` and `2.0` for strongly negative and strongly
/// positive pre-activations respectively.
#[inline]
pub fn extended_sigmoid_gate(w: &[f64], x: &[f64]) -> f64 {
    let squashed = math::sigmoid(dot(w, x));
    2.0 * squashed
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// `fixed_decay` hands its argument back untouched.
    #[test]
    fn fixed_decay_returns_gamma() {
        let typical = fixed_decay(0.9);
        assert!(
            (typical - 0.9).abs() < 1e-12,
            "fixed_decay(0.9) should return 0.9"
        );

        let zero = fixed_decay(0.0);
        assert!(
            (zero - 0.0).abs() < 1e-12,
            "fixed_decay(0.0) should return 0.0"
        );
    }

    /// All-zero weights give a zero pre-activation, hence a gate of 0.5.
    #[test]
    fn sigmoid_gate_at_zero_bias() {
        let weights = vec![0.0; 4];
        let input = vec![1.0, 2.0, 3.0, 4.0];

        let gate = sigmoid_gate(&weights, &input);
        let error = (gate - 0.5).abs();

        assert!(error < 1e-12, "sigmoid(0) should be 0.5, got {}", gate);
    }

    /// A pre-activation of +40 saturates the gate towards one.
    #[test]
    fn sigmoid_gate_large_positive() {
        let weights = vec![10.0; 4];
        let input = vec![1.0; 4];

        let gate = sigmoid_gate(&weights, &input);

        assert!(
            gate > 0.99,
            "sigmoid of large positive should be > 0.99, got {}",
            gate
        );
    }

    /// A pre-activation of -40 saturates the gate towards zero.
    #[test]
    fn sigmoid_gate_large_negative() {
        let weights = vec![-10.0; 4];
        let input = vec![1.0; 4];

        let gate = sigmoid_gate(&weights, &input);

        assert!(
            gate < 0.01,
            "sigmoid of large negative should be < 0.01, got {}",
            gate
        );
    }

    /// With a positive initial decay the gate is strictly between 0 and 1.
    #[test]
    fn exponential_gate_in_unit_interval() {
        let weights = vec![0.1, -0.1, 0.05, 0.0];
        let input = vec![1.0, 2.0, -1.0, 0.5];

        let gate = exponential_gate(&weights, &input, 0.5);
        let inside_unit_interval = g_in_open_unit(gate);

        assert!(
            inside_unit_interval,
            "exponential gate should be in (0, 1), got {}",
            gate
        );
    }

    /// Helper for the unit-interval check above: `0 < g < 1`.
    fn g_in_open_unit(g: f64) -> bool {
        g > 0.0 && g < 1.0
    }

    /// A large initial decay drives the gate close to zero.
    #[test]
    fn exponential_gate_large_decay_small_output() {
        let weights = vec![0.0; 4];
        let input = vec![0.0; 4];

        let gate = exponential_gate(&weights, &input, 10.0);

        assert!(
            gate < 0.001,
            "large decay should produce very small gate, got {}",
            gate
        );
    }

    /// Zero weights put both LSTM gates at the sigmoid midpoint.
    #[test]
    fn lstm_gates_at_zero() {
        let forget_weights = vec![0.0; 4];
        let input_weights = vec![0.0; 4];
        let x = vec![1.0; 4];

        let (forget, input) = lstm_gates(&forget_weights, &input_weights, &x);

        assert!(
            (forget - 0.5).abs() < 1e-12,
            "forget gate at zero should be 0.5, got {}",
            forget
        );
        assert!(
            (input - 0.5).abs() < 1e-12,
            "input gate at zero should be 0.5, got {}",
            input
        );
    }

    /// Opposite-signed weights push the two LSTM gates to opposite extremes.
    #[test]
    fn lstm_gates_independent() {
        let forget_weights = vec![10.0; 2];
        let input_weights = vec![-10.0; 2];
        let x = vec![1.0; 2];

        let (forget, input) = lstm_gates(&forget_weights, &input_weights, &x);

        assert!(
            forget > 0.99,
            "forget gate should be near 1, got {}",
            forget
        );
        assert!(input < 0.01, "input gate should be near 0, got {}", input);
    }

    /// Two generators seeded alike must emit the same stream.
    #[test]
    fn xorshift_deterministic_same_seed() {
        let mut first = Xorshift64(42);
        let mut second = Xorshift64(42);

        for _ in 0..50 {
            let from_first = first.next_u64();
            let from_second = second.next_u64();
            assert_eq!(
                from_first, from_second,
                "same seed must produce same sequence"
            );
        }
    }

    /// `init_weights` returns the requested count of small-magnitude values.
    #[test]
    fn init_weights_correct_length_and_small() {
        let mut rng = Xorshift64(123);
        let weights = init_weights(&mut rng, 100);
        assert_eq!(weights.len(), 100, "should produce 100 weights");

        // Largest absolute value, computed with plain comparisons.
        let max_abs = weights
            .iter()
            .map(|&value| if value < 0.0 { -value } else { value })
            .fold(
                0.0f64,
                |largest, magnitude| {
                    if magnitude > largest {
                        magnitude
                    } else {
                        largest
                    }
                },
            );

        assert!(
            max_abs < 0.5,
            "weights with scale 0.01 should be small, max_abs={}",
            max_abs
        );
    }

    /// Every decay element lies strictly between 0.6 and 1.0.
    #[test]
    fn vector_decay_bounded() {
        let weights = vec![0.1, -0.2, 0.3, -0.1, 0.05, 0.15, -0.05, 0.2];
        let input = vec![1.0, 2.0, -1.0, 0.5];

        let decay = vector_decay(&weights, &input, 2);
        assert_eq!(decay.len(), 2, "should produce 2 decay values");

        for (idx, &value) in decay.iter().enumerate() {
            let within_bounds = value > 0.6 && value < 1.0;
            assert!(
                within_bounds,
                "decay[{}] should be in (0.6, 1.0) per Peng et al. 2025 Eq. 8, got {}",
                idx,
                value
            );
        }
    }

    /// A saturated sigmoid pins the decay at its 0.6 floor.
    #[test]
    fn rwkv7_w_t_lower_bound_is_paper_spec() {
        let d_key = 4;
        let d_model = 4;
        let weights = vec![100.0f64; d_key * d_model];
        let input = vec![1.0f64; d_model];

        let decay = vector_decay(&weights, &input, d_key);
        assert_eq!(decay.len(), d_key, "should produce d_key decay values");

        for (idx, &value) in decay.iter().enumerate() {
            let distance_from_floor = (value - 0.6_f64).abs();
            assert!(
                distance_from_floor < 1e-9,
                "w_t lower bound must be 0.6 per Peng et al. 2025 Eq. 8, got decay[{}]={}",
                idx,
                value
            );
        }
    }

    /// Row 0 saturates high and row 1 saturates low.
    #[test]
    fn vector_sigmoid_gate_bounded() {
        let weights = vec![10.0, 0.0, -10.0, 0.0];
        let input = vec![1.0, 0.0];

        let gates = vector_sigmoid_gate(&weights, &input, 2);

        assert!(
            gates[0] > 0.99,
            "large positive should give ~1, got {}",
            gates[0]
        );
        assert!(
            gates[1] < 0.01,
            "large negative should give ~0, got {}",
            gates[1]
        );
    }

    /// The extended gate spans roughly 0 to 2 with a midpoint of 1.
    #[test]
    fn extended_sigmoid_gate_range() {
        let positive_weights = vec![100.0];
        let negative_weights = vec![-100.0];
        let input = vec![1.0];

        let high = extended_sigmoid_gate(&positive_weights, &input);
        let low = extended_sigmoid_gate(&negative_weights, &input);
        assert!(high > 1.99, "large positive should give ~2.0, got {}", high);
        assert!(low < 0.01, "large negative should give ~0.0, got {}", low);

        let mid = extended_sigmoid_gate(&[0.0], &[0.0]);
        let midpoint_error = (mid - 1.0).abs();
        assert!(
            midpoint_error < 1e-6,
            "zero input should give 1.0, got {}",
            mid
        );
    }

    /// Saturated rows land on 1.0 and on the lower bound respectively.
    #[test]
    fn vector_lower_bounded_gate_range() {
        let weights = vec![10.0, 0.0, -10.0, 0.0];
        let input = vec![1.0, 0.0];
        let lower_bound = 0.9;

        let gates = vector_lower_bounded_gate(&weights, &input, 2, lower_bound);

        assert!(
            gates[0] > 0.999,
            "large positive should give ~1.0, got {}",
            gates[0]
        );
        assert!(
            (gates[1] - lower_bound).abs() < 0.001,
            "large negative should give ~lower_bound ({}), got {}",
            lower_bound,
            gates[1]
        );
    }

    /// With a zero lower bound the gate reduces to a plain sigmoid.
    #[test]
    fn vector_lower_bounded_gate_zero_bound() {
        let weights = vec![0.0, 0.0];
        let input = vec![0.0, 0.0];

        let gates = vector_lower_bounded_gate(&weights, &input, 1, 0.0);

        assert!(
            (gates[0] - 0.5).abs() < 1e-12,
            "with lb=0 and zero input, gate should be sigmoid(0)=0.5, got {}",
            gates[0]
        );
    }

    /// A zero pre-activation sits halfway between the bound and 1.0.
    #[test]
    fn vector_lower_bounded_gate_at_midpoint() {
        let weights = vec![0.0, 0.0];
        let input = vec![0.0, 0.0];
        let lb = 0.9;

        let gates = vector_lower_bounded_gate(&weights, &input, 1, lb);
        let expected = lb + (1.0 - lb) * 0.5;

        assert!(
            (gates[0] - expected).abs() < 1e-12,
            "at zero input with lb=0.9, gate should be {}, got {}",
            expected,
            gates[0]
        );
    }

    /// 2x2 row-major matrix times a vector of ones yields the row sums.
    #[test]
    fn mat_vec_basic() {
        let matrix = vec![1.0, 2.0, 3.0, 4.0];
        let ones = vec![1.0, 1.0];
        let mut out = vec![0.0; 2];

        mat_vec(&matrix, &ones, 2, 2, &mut out);

        let (row0, row1) = (out[0], out[1]);
        assert!((row0 - 3.0).abs() < 1e-12, "row 0: 1+2=3, got {}", row0);
        assert!((row1 - 7.0).abs() < 1e-12, "row 1: 3+4=7, got {}", row1);
    }
}