use alloc::vec::Vec;
use crate::math;
/// Builds the log-space initialisation `ln(1), ln(2), ..., ln(n_state)`.
///
/// Element `k` (0-based) holds `math::ln((k + 1) as f64)`. The unit tests in
/// this file recover the state-matrix entry as `A[k] = -exp(log_a[k])`, which
/// yields `-(k + 1)`.
///
/// # Arguments
///
/// * `n_state` - number of entries to generate.
///
/// # Returns
///
/// A vector of length `n_state`; empty when `n_state` is zero.
pub fn mamba_init(n_state: usize) -> Vec<f64> {
    // Iterating 1..=n_state directly avoids the `+ 1` offset on every element.
    (1..=n_state).map(|k| math::ln(k as f64)).collect()
}
/// Builds a constant log-space initialisation where every entry is `ln(0.5)`.
///
/// The unit tests in this file recover the state-matrix entry as
/// `A = -exp(log_a)`, which yields `-0.5` for every state.
///
/// # Arguments
///
/// * `n_state` - number of entries to generate.
///
/// # Returns
///
/// A vector of length `n_state` filled with `math::ln(0.5)`; empty when
/// `n_state` is zero.
pub fn s4d_lin_real(n_state: usize) -> Vec<f64> {
    // The logarithm is evaluated once and then copied into every slot.
    let fill = math::ln(0.5);
    (0..n_state).map(|_| fill).collect()
}
/// Builds a real-valued log-space initialisation with linearly growing
/// magnitude.
///
/// Element `i` holds `math::ln(0.5 + i / n_state)`, so the magnitudes run from
/// `0.5` (at `i = 0`) up to, but not including, `1.5`. The unit tests in this
/// file recover the state-matrix entry as `A[i] = -exp(log_a[i])`.
///
/// # Arguments
///
/// * `n_state` - number of entries to generate.
///
/// # Returns
///
/// A vector of length `n_state`; empty when `n_state` is zero (the division
/// by `n_state` is never evaluated in that case).
pub fn s4d_inv_real(n_state: usize) -> Vec<f64> {
    let count = n_state as f64;
    (0..n_state)
        .map(|idx| math::ln(0.5 + (idx as f64) / count))
        .collect()
}
/// Builds a complex-valued initialisation as an interleaved flat vector.
///
/// For each state `i` in `0..n_state` two values are appended, in this order:
///
/// 1. `math::ln(0.5 + n_state / (i + 1))` - the log of the real-part magnitude
///    (the unit tests recover the real part as `-exp(..)`),
/// 2. `PI * (i + 1) / n_state` - the imaginary part, stored as-is (not logged).
///
/// # Arguments
///
/// * `n_state` - number of complex states to generate.
///
/// # Returns
///
/// A vector of length `2 * n_state` laid out as
/// `[log_re_0, im_0, log_re_1, im_1, ...]`; empty when `n_state` is zero.
pub fn s4d_inv_complex(n_state: usize) -> Vec<f64> {
    use core::f64::consts::PI;

    let total = n_state as f64;
    let mut interleaved = Vec::with_capacity(2 * n_state);
    // `rank` is the 1-based state index, i.e. `i + 1` for the 0-based state `i`.
    for rank in 1..=n_state {
        let rank = rank as f64;
        let log_re = math::ln(0.5 + total / rank);
        let im = PI * rank / total;
        interleaved.extend([log_re, im]);
    }
    interleaved
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `mamba_init(4)` has four entries and matches `ln(1)`, `ln(2)`, `ln(4)`
    /// at indices 0, 1 and 3.
    #[test]
    fn mamba_init_length_and_values() {
        let values = mamba_init(4);
        assert_eq!(values.len(), 4);

        let err_first = math::abs(values[0]);
        assert!(err_first < 1e-12, "log_A[0] should be ln(1)=0");

        let err_second = math::abs(values[1] - math::ln(2.0));
        assert!(err_second < 1e-12, "log_A[1] should be ln(2)");

        let err_last = math::abs(values[3] - math::ln(4.0));
        assert!(err_last < 1e-12, "log_A[3] should be ln(4)");
    }

    /// Mapping back through `A = -exp(log_A)` gives `-(n + 1)` for every state.
    #[test]
    fn mamba_init_produces_negative_a() {
        for (idx, log_val) in mamba_init(8).into_iter().enumerate() {
            let a = -math::exp(log_val);
            assert!(
                a < 0.0,
                "A[{}] = {} should be negative (log_A={})",
                idx,
                a,
                log_val
            );

            let expected = -((idx + 1) as f64);
            let err = math::abs(a - expected);
            assert!(err < 1e-10, "A[{}] expected {}, got {}", idx, expected, a);
        }
    }

    /// Every entry of `s4d_lin_real` equals `ln(0.5)`.
    #[test]
    fn s4d_lin_all_equal() {
        let values = s4d_lin_real(5);
        assert_eq!(values.len(), 5);

        let target = math::ln(0.5);
        for (idx, log_val) in values.iter().copied().enumerate() {
            let err = math::abs(log_val - target);
            assert!(
                err < 1e-12,
                "log_A[{}] should be ln(0.5), got {}",
                idx,
                log_val
            );
        }
    }

    /// Mapping `s4d_lin_real` back through `A = -exp(log_A)` gives `-0.5`.
    #[test]
    fn s4d_lin_produces_negative_a() {
        for log_val in s4d_lin_real(3) {
            let a = -math::exp(log_val);
            assert!(a < 0.0, "A should be negative, got {}", a);
            // `a + 0.5` is the same quantity as `a - (-0.5)`.
            assert!(math::abs(a + 0.5) < 1e-12, "A should be -0.5, got {}", a);
        }
    }

    /// Consecutive entries of `s4d_inv_real` are strictly increasing.
    #[test]
    fn s4d_inv_increasing_magnitude() {
        let values = s4d_inv_real(8);
        assert_eq!(values.len(), 8);

        for (prev_idx, pair) in values.windows(2).enumerate() {
            let (prev, cur) = (pair[0], pair[1]);
            let cur_idx = prev_idx + 1;
            assert!(
                cur > prev,
                "log_A[{}]={} should be > log_A[{}]={}",
                cur_idx,
                cur,
                prev_idx,
                prev
            );
        }
    }

    /// Every `A = -exp(log_A)` derived from `s4d_inv_real` is negative.
    #[test]
    fn s4d_inv_all_negative_a() {
        for (idx, log_val) in s4d_inv_real(16).into_iter().enumerate() {
            let a = -math::exp(log_val);
            assert!(a < 0.0, "A[{}] should be negative, got {}", idx, a);
        }
    }

    /// A single-state `mamba_init` yields just `ln(1) = 0`.
    #[test]
    fn mamba_init_single_state() {
        let values = mamba_init(1);
        assert_eq!(values.len(), 1);

        let err = math::abs(values[0]);
        assert!(err < 1e-12, "single state log_A should be ln(1)=0");
    }

    /// The first entry of `s4d_inv_real` is `ln(0.5)`.
    #[test]
    fn s4d_inv_first_element() {
        let values = s4d_inv_real(4);
        let first = values[0];
        let err = math::abs(first - math::ln(0.5));
        assert!(
            err < 1e-12,
            "s4d_inv log_A[0] should be ln(0.5), got {}",
            first
        );
    }

    /// `s4d_inv_complex` returns interleaved (log real magnitude, imaginary)
    /// pairs: real parts map to negative values, imaginary parts are positive.
    #[test]
    fn s4d_inv_complex_length_and_sign() {
        let values = s4d_inv_complex(8);
        assert_eq!(
            values.len(),
            16,
            "s4d_inv_complex should return 2*n_state elements"
        );

        // First pass: all real parts (even offsets).
        for (idx, pair) in values.chunks_exact(2).enumerate() {
            let a_re = -math::exp(pair[0]);
            assert!(a_re < 0.0, "A_re[{}] should be negative, got {}", idx, a_re);
        }

        // Second pass: all imaginary parts (odd offsets).
        for (idx, pair) in values.chunks_exact(2).enumerate() {
            let a_im = pair[1];
            assert!(a_im > 0.0, "A_im[{}] should be positive, got {}", idx, a_im);
        }
    }
}