use crate::s4::config::S4Config;
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::Complex64;
use std::f32::consts::PI;
use trustformers_core::{
device::Device,
errors::{
compute_error, invalid_format, invalid_input, runtime_error, tensor_op_error, Result,
},
layers::{Embedding, LayerNorm, Linear},
ops::activations::gelu,
tensor::Tensor,
traits::{Layer, Model},
};
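/// HiPPO basis used to initialize the SSM state matrix A.
///
/// Variant names follow the HiPPO literature: `LEGS` (scaled Legendre),
/// `LEGT` (translated Legendre), `LAGT` (translated Laguerre), a Fourier
/// basis, and a random skew-symmetric baseline. The constructions below are
/// simplified real-valued forms rather than the full normal-plus-low-rank
/// parameterization used by S4.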
#[derive(Debug, Clone)]
pub enum HiPPOMatrix {
LEGS,
LEGT,
LAGT,
Fourier,
Random,
}
impl HiPPOMatrix {
pub fn initialize(&self, n: usize) -> Array2<f32> {
match self {
HiPPOMatrix::LEGS => self.init_legs(n),
HiPPOMatrix::LEGT => self.init_legt(n),
HiPPOMatrix::LAGT => self.init_lagt(n),
HiPPOMatrix::Fourier => self.init_fourier(n),
HiPPOMatrix::Random => self.init_random(n),
}
}
    fn init_legs(&self, n: usize) -> Array2<f32> {
        // Lower-triangular HiPPO-LegS coefficients sqrt(2i+1)*sqrt(2j+1) for
        // i > j; the diagonal and upper triangle stay zero, and the
        // skew-symmetric (normal) part is then formed as A - A^T.
        let mut a = Array2::<f32>::zeros((n, n));
        for i in 0..n {
            for j in 0..i {
                a[[i, j]] = (2.0 * i as f32 + 1.0).sqrt() * (2.0 * j as f32 + 1.0).sqrt();
            }
        }
        &a - &a.t()
    }
    fn init_legt(&self, n: usize) -> Array2<f32> {
        // Simplified LegT-style matrix: ones below the diagonal, -(2i+1)/2 on it.
        let mut a = Array2::<f32>::zeros((n, n));
for i in 0..n {
for j in 0..n {
if i > j {
a[[i, j]] = 1.0;
} else if i == j {
a[[i, j]] = -(2.0 * i as f32 + 1.0) / 2.0;
}
}
}
a
}
    fn init_lagt(&self, n: usize) -> Array2<f32> {
        // Simplified LagT-style matrix: alternating +/-1 below the diagonal,
        // -1/2 on it.
        let mut a = Array2::<f32>::zeros((n, n));
for i in 0..n {
for j in 0..n {
if i > j {
a[[i, j]] = (-1.0_f32).powi((i - j) as i32);
} else if i == j {
a[[i, j]] = -0.5;
}
}
}
a
}
    fn init_fourier(&self, n: usize) -> Array2<f32> {
        // Fourier-style matrix: off-diagonal entries +/- pi*(i - j), with the
        // sign alternating by the parity of i + j; the diagonal is zero.
        let mut a = Array2::<f32>::zeros((n, n));
for i in 0..n {
for j in 0..n {
if i == j {
a[[i, j]] = 0.0;
} else {
let sign = if (i + j) % 2 == 0 { 1.0 } else { -1.0 };
a[[i, j]] = sign * PI * (i as f32 - j as f32);
}
}
}
a
}
#[allow(deprecated)]
    fn init_random(&self, n: usize) -> Array2<f32> {
        use scirs2_core::random::*;
        let mut rng = thread_rng();
        // Random skew-symmetric matrix: fill the strict lower triangle and
        // mirror each entry with the opposite sign.
        let mut a = Array2::<f32>::zeros((n, n));
        for i in 0..n {
            for j in 0..i {
                let val = rng.random_range(-1.0..1.0);
                a[[i, j]] = val;
                a[[j, i]] = -val;
            }
        }
        a
    }
}
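/// Scheme for converting the continuous-time pair (A, B) into discrete
/// operators (A_bar, B_bar) at step size `dt`.
///
/// For reference, exact ZOH gives A_bar = exp(dt A) and
/// B_bar = A^{-1}(exp(dt A) - I) B, and the bilinear (Tustin) transform gives
/// A_bar = (I - dt/2 A)^{-1} (I + dt/2 A). The methods below all apply the
/// first-order forward-Euler approximation A_bar = I + dt A, B_bar = dt B,
/// which agrees with the exact forms to O(dt^2).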
#[derive(Debug, Clone)]
pub enum Discretization {
ZOH,
Bilinear,
Euler,
BackwardEuler,
}
impl Discretization {
pub fn discretize(
&self,
a: &Array2<f32>,
b: &Array1<f32>,
dt: f32,
) -> (Array2<f32>, Array1<f32>) {
match self {
Discretization::ZOH => self.zoh_discretize(a, b, dt),
Discretization::Bilinear => self.bilinear_discretize(a, b, dt),
Discretization::Euler => self.euler_discretize(a, b, dt),
Discretization::BackwardEuler => self.backward_euler_discretize(a, b, dt),
}
}
fn zoh_discretize(
&self,
a: &Array2<f32>,
b: &Array1<f32>,
dt: f32,
) -> (Array2<f32>, Array1<f32>) {
        // First-order approximation; exact ZOH would use A_bar = exp(dt A).
        let n = a.nrows();
let eye = Array2::<f32>::eye(n);
let a_bar = &eye + a * dt;
let b_bar = b * dt;
(a_bar, b_bar)
}
fn bilinear_discretize(
&self,
a: &Array2<f32>,
b: &Array1<f32>,
dt: f32,
) -> (Array2<f32>, Array1<f32>) {
        // Exact bilinear (Tustin) would solve A_bar = (I - dt/2 A)^{-1}(I + dt/2 A);
        // without a matrix inverse, the same first-order update is used here.
        let n = a.nrows();
        let eye = Array2::<f32>::eye(n);
let a_bar = &eye + a * dt;
let b_bar = b * dt;
(a_bar, b_bar)
}
fn euler_discretize(
&self,
a: &Array2<f32>,
b: &Array1<f32>,
dt: f32,
) -> (Array2<f32>, Array1<f32>) {
        // Forward Euler: A_bar = I + dt A, B_bar = dt B.
        let n = a.nrows();
let eye = Array2::<f32>::eye(n);
let a_bar = &eye + a * dt;
let b_bar = b * dt;
(a_bar, b_bar)
}
fn backward_euler_discretize(
&self,
a: &Array2<f32>,
b: &Array1<f32>,
dt: f32,
) -> (Array2<f32>, Array1<f32>) {
        // Exact backward Euler would invert (I - dt A); delegate to forward Euler.
        self.euler_discretize(a, b, dt)
}
}
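/// A single S4 state space layer holding the continuous-time parameters of
/// x'(t) = A x(t) + B u(t), y(t) = C x(t) + D u(t), stored as separate
/// real/imaginary parts, along with per-channel step sizes `dt` and the
/// cached discretized operators `a_bar`/`b_bar`.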
pub struct S4Layer {
#[allow(dead_code)]
config: S4Config,
    a_real: Array2<f32>,
    a_imag: Array2<f32>,
    b_real: Array1<f32>,
    b_imag: Array1<f32>,
    c_real: Array1<f32>,
    c_imag: Array1<f32>,
    d: Array1<f32>,
    dt: Array1<f32>,
    a_bar: Option<Array2<Complex64>>,
b_bar: Option<Array1<Complex64>>,
device: Device,
}
impl S4Layer {
pub fn new_with_device(config: &S4Config, device: Device) -> Result<Self> {
let n = config.d_state;
let h = config.get_n_ssm();
let hippo = match config.hippo_matrix.as_str() {
"legs" => HiPPOMatrix::LEGS,
"legt" => HiPPOMatrix::LEGT,
"lagt" => HiPPOMatrix::LAGT,
"fourier" => HiPPOMatrix::Fourier,
"random" => HiPPOMatrix::Random,
_ => HiPPOMatrix::LEGS,
};
let a_base = hippo.initialize(n);
let a_real = a_base.clone();
let a_imag = Array2::<f32>::zeros((n, n));
let b_real = Array1::<f32>::ones(n) / (n as f32).sqrt();
let b_imag = Array1::<f32>::zeros(n);
let c_real = Array1::<f32>::ones(n) / (n as f32).sqrt();
let c_imag = Array1::<f32>::zeros(n);
let d = Array1::<f32>::ones(h);
let dt = Array1::<f32>::from_elem(h, config.dt);
Ok(Self {
config: config.clone(),
a_real,
a_imag,
b_real,
b_imag,
c_real,
c_imag,
d,
dt,
a_bar: None,
b_bar: None,
device,
})
}
pub fn new(config: &S4Config) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn device(&self) -> Device {
self.device
}
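    /// Discretizes (A, B) at the mean step size and caches the complex-valued
    /// operators in `a_bar`/`b_bar`; must be called before `apply_s4`.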
#[allow(dead_code)]
fn discretize(&mut self) -> Result<()> {
let disc = match self.config.discretization.as_str() {
"zoh" => Discretization::ZOH,
"bilinear" => Discretization::Bilinear,
"euler" => Discretization::Euler,
"backward_euler" => Discretization::BackwardEuler,
_ => Discretization::ZOH,
};
let dt_avg = self.dt.mean().unwrap_or(self.config.dt);
let (a_bar_real, b_bar_real) = disc.discretize(&self.a_real, &self.b_real, dt_avg);
let n = self.config.d_state;
let mut a_bar_complex = Array2::<Complex64>::zeros((n, n));
let mut b_bar_complex = Array1::<Complex64>::zeros(n);
for i in 0..n {
for j in 0..n {
a_bar_complex[[i, j]] = Complex64::new(
a_bar_real[[i, j]] as f64,
self.a_imag[[i, j]] as f64 * dt_avg as f64,
);
}
b_bar_complex[i] =
Complex64::new(b_bar_real[i] as f64, self.b_imag[i] as f64 * dt_avg as f64);
}
self.a_bar = Some(a_bar_complex);
self.b_bar = Some(b_bar_complex);
Ok(())
}
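    /// Runs the cached discrete recurrence over `input` of shape
    /// (batch, seq_len) and returns an output of the same shape; errors if
    /// `discretize` has not been called.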
#[allow(dead_code)]
fn apply_s4(&self, input: &Array2<f32>) -> Result<Array2<f32>> {
let (batch_size, seq_len) = (input.nrows(), input.ncols());
let _h = self.config.get_n_ssm();
let mut state = Array1::<Complex64>::zeros(self.config.d_state);
let mut output = Array2::<f32>::zeros((batch_size, seq_len));
        let a_bar = self
            .a_bar
            .as_ref()
            .ok_or_else(|| runtime_error("S4 layer not discretized"))?;
        let b_bar = self
            .b_bar
            .as_ref()
            .ok_or_else(|| runtime_error("S4 layer not discretized"))?;
        // Sequential scan of the discretized SSM:
        //   x_t = A_bar x_{t-1} + B_bar u_t
        //   y_t = Re(C x_t) + D u_t
        // The input is collapsed to its batch mean at each step, so this is a
        // simplified scalar-input recurrence shared across the batch.
        for t in 0..seq_len {
            let u_t = input.column(t);
            let u_mean = u_t.mean().unwrap_or(0.0) as f64;
            // Compute the full new state before overwriting `state`, so the
            // matrix-vector product uses the previous state for every row.
            let mut new_state = Array1::<Complex64>::zeros(self.config.d_state);
            for i in 0..self.config.d_state {
                let mut acc = Complex64::new(0.0, 0.0);
                for j in 0..self.config.d_state {
                    acc += a_bar[[i, j]] * state[j];
                }
                new_state[i] = acc + b_bar[i] * u_mean;
            }
            state = new_state;
            let mut y_t = 0.0;
            for i in 0..self.config.d_state {
                y_t += (self.c_real[i] as f64 * state[i].re
                    - self.c_imag[i] as f64 * state[i].im) as f32;
            }
            y_t += self.d[0] * u_mean as f32;
            for b in 0..batch_size {
                output[[b, t]] = y_t;
            }
        }
Ok(output)
}
fn parameter_count(&self) -> usize {
let mut total = 0;
        total += self.a_real.len();
        total += self.a_imag.len();
        total += self.b_real.len();
        total += self.b_imag.len();
        total += self.c_real.len();
        total += self.c_imag.len();
        total += self.d.len();
        total += self.dt.len();
total
}
}
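/// Residual S4 block in pre-norm arrangement: LayerNorm, input projection,
/// S4 state space layer, output projection, post-activation, plus a skip
/// connection around the whole block.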
pub struct S4Block {
config: S4Config,
s4_layer: S4Layer,
norm: LayerNorm,
in_proj: Linear,
out_proj: Linear,
#[allow(dead_code)]
dropout: f32,
device: Device,
}
impl S4Block {
pub fn new_with_device(config: &S4Config, device: Device) -> Result<Self> {
let d_model = config.d_model;
let n_ssm = config.get_n_ssm();
let s4_layer = S4Layer::new_with_device(config, device)?;
let norm = LayerNorm::new_with_device(vec![d_model], config.layer_norm_eps, device)?;
let in_proj = Linear::new_with_device(d_model, n_ssm, config.use_bias, device);
let out_proj = Linear::new_with_device(n_ssm, d_model, config.use_bias, device);
Ok(Self {
config: config.clone(),
s4_layer,
norm,
in_proj,
out_proj,
dropout: config.dropout,
device,
})
}
pub fn new(config: &S4Config) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn device(&self) -> Device {
self.device
}
}
impl Layer for S4Block {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let residual = input.clone();
let normed = self.norm.forward(input)?;
let projected = self.in_proj.forward(normed)?;
        let s4_out = match &projected {
            Tensor::F32(arr) => {
                // Without cached discretization the SSM cannot run; fall back
                // to the identity (residual) path.
                if self.s4_layer.a_bar.is_none() {
                    return Ok(residual);
                }
                let shape = arr.shape();
                if shape.len() == 3 {
                    let batch = shape[0];
                    let seq_len = shape[1];
                    let channels = shape[2];
                    // Placeholder: emits a constant activation in place of the
                    // full S4 convolution until the kernel path is wired in.
                    let mut result = Array2::<f32>::zeros((batch * seq_len, channels));
                    result.fill(0.1);
                    Tensor::F32(result.into_dyn())
                } else {
                    projected.clone()
                }
            },
_ => {
return Err(tensor_op_error(
"tensor_operation",
"Unsupported tensor type".to_string(),
))
},
};
let output = self.out_proj.forward(s4_out)?;
        // The config names the post-activation "glu", but GELU is applied as a
        // stand-in (a true GLU would need a gated, twice-width projection).
        let activated = match self.config.postact.as_str() {
            "glu" => gelu(&output)?,
            _ => output,
        };
match (&residual, &activated) {
(Tensor::F32(r), Tensor::F32(a)) => Ok(Tensor::F32(r + a)),
_ => Err(tensor_op_error(
"tensor_operation",
"Unsupported tensor type".to_string(),
)),
}
}
}
impl S4Block {
pub fn parameter_count(&self) -> usize {
let mut total = 0;
total += self.s4_layer.parameter_count();
total += self.norm.parameter_count();
total += self.in_proj.parameter_count();
total += self.out_proj.parameter_count();
total
}
}
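/// Full S4 backbone: token embeddings, a stack of `n_layer` S4 blocks, and a
/// final LayerNorm.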
pub struct S4Model {
pub config: S4Config,
pub embeddings: Embedding,
pub blocks: Vec<S4Block>,
pub ln_f: LayerNorm,
pub device: Device,
}
impl S4Model {
pub fn new_with_device(config: S4Config, device: Device) -> Result<Self> {
let embeddings =
Embedding::new_with_device(config.vocab_size, config.d_model, None, device)?;
let mut blocks = Vec::new();
for _ in 0..config.n_layer {
if let Ok(block) = S4Block::new_with_device(&config, device) {
blocks.push(block);
}
}
let ln_f = LayerNorm::new_with_device(vec![config.d_model], config.layer_norm_eps, device)?;
Ok(Self {
config,
embeddings,
blocks,
ln_f,
device,
})
}
pub fn new(config: S4Config) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn device(&self) -> Device {
self.device
}
}
impl Model for S4Model {
type Config = S4Config;
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let (batch_size, seq_len, input_ids) = match &input {
Tensor::I64(ref arr) => {
if arr.ndim() == 2 {
let batch_size = arr.shape()[0];
let seq_len = arr.shape()[1];
let ids = arr.mapv(|x| x as u32).into_raw_vec_and_offset().0;
(batch_size, seq_len, ids)
} else if arr.ndim() == 1 {
let seq_len = arr.len();
let ids = arr.mapv(|x| x as u32).into_raw_vec_and_offset().0;
(1, seq_len, ids)
} else {
return Err(tensor_op_error(
"tensor_operation",
"Input tensor must be 1D or 2D".to_string(),
));
}
},
_ => {
return Err(tensor_op_error(
"tensor_operation",
"Unsupported tensor type".to_string(),
))
},
};
let embedded = self.embeddings.forward(input_ids)?;
let mut hidden = if embedded.shape().len() == 2 {
let total_tokens = embedded.shape()[0];
let d_model = embedded.shape()[1];
if total_tokens == batch_size * seq_len {
embedded.reshape(&[batch_size, seq_len, d_model])?
} else {
embedded.reshape(&[1, total_tokens, d_model])?
}
} else {
embedded
};
for block in &self.blocks {
hidden = block.forward(hidden)?;
}
self.ln_f.forward(hidden)
}
fn load_pretrained(&mut self, reader: &mut dyn std::io::Read) -> Result<()> {
let mut buffer = Vec::new();
reader
.read_to_end(&mut buffer)
.map_err(|e| invalid_input(format!("Failed to read S4 weights: {}", e)))?;
if buffer.is_empty() {
return Err(invalid_input("S4 weight file is empty"));
}
self.load_weights_from_buffer(&buffer)
}
fn get_config(&self) -> &Self::Config {
&self.config
}
fn num_parameters(&self) -> usize {
let mut total = 0;
total += self.embeddings.parameter_count();
for block in &self.blocks {
total += block.parameter_count();
}
total += self.ln_f.parameter_count();
total
}
}
impl S4Model {
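    /// Parses the S4 weight container. Layout (integers little-endian):
    /// magic `0x53344D4C` (u32), format version (u32, must be <= 1), metadata
    /// length (u32) followed by that many bytes of UTF-8 JSON, then the
    /// tensors in model order. Each tensor is a u32 byte-length header
    /// followed by raw little-endian f32 data.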
fn load_weights_from_buffer(&mut self, buffer: &[u8]) -> Result<()> {
if buffer.len() < 12 {
return Err(invalid_input(
"S4 weight file too small to contain valid header",
));
}
let mut offset = 0;
let magic = u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]);
offset += 4;
if magic != 0x53344D4C {
return Err(invalid_format(
"S4 magic number 0x53344D4C",
format!("0x{:08X}", magic),
));
}
let version = u32::from_le_bytes([
buffer[offset],
buffer[offset + 1],
buffer[offset + 2],
buffer[offset + 3],
]);
offset += 4;
if version > 1 {
return Err(invalid_format("S4 version ≤ 1", version.to_string()));
}
let metadata_size = u32::from_le_bytes([
buffer[offset],
buffer[offset + 1],
buffer[offset + 2],
buffer[offset + 3],
]) as usize;
offset += 4;
if buffer.len() < offset + metadata_size {
return Err(invalid_input("Insufficient data for metadata"));
}
let metadata_bytes = &buffer[offset..offset + metadata_size];
let metadata_str = std::str::from_utf8(metadata_bytes)
.map_err(|e| invalid_input(format!("Invalid UTF-8 in metadata: {}", e)))?;
let metadata: serde_json::Value = serde_json::from_str(metadata_str)
.map_err(|e| invalid_input(format!("Invalid JSON in metadata: {}", e)))?;
offset += metadata_size;
if let Some(config_obj) = metadata.get("config") {
self.validate_config_compatibility(config_obj)?;
}
offset = self.load_embedding_weights(buffer, offset)?;
offset = self.load_block_weights(buffer, offset)?;
offset = self.load_final_norm_weights(buffer, offset)?;
if offset != buffer.len() {
eprintln!(
"Warning: S4 weight file contains unused data ({} bytes remaining)",
buffer.len() - offset
);
}
Ok(())
}
fn validate_config_compatibility(&self, config_obj: &serde_json::Value) -> Result<()> {
if let Some(d_model) = config_obj.get("d_model").and_then(|v| v.as_u64()) {
if d_model as usize != self.config.d_model {
return Err(compute_error(
"model_loading",
format!(
"Model dimension mismatch: expected {}, found {}",
self.config.d_model, d_model
),
));
}
}
if let Some(n_layer) = config_obj.get("n_layer").and_then(|v| v.as_u64()) {
if n_layer as usize != self.config.n_layer {
return Err(compute_error(
"model_loading",
format!(
"Layer count mismatch: expected {}, found {}",
self.config.n_layer, n_layer
),
));
}
}
if let Some(d_state) = config_obj.get("d_state").and_then(|v| v.as_u64()) {
if d_state as usize != self.config.d_state {
return Err(compute_error(
"model_loading",
format!(
"State dimension mismatch: expected {}, found {}",
self.config.d_state, d_state
),
));
}
}
Ok(())
}
fn load_embedding_weights(&mut self, buffer: &[u8], mut offset: usize) -> Result<usize> {
if buffer.len() < offset + 4 {
return Err(invalid_input("Insufficient data for embedding weights"));
}
let weight_size = u32::from_le_bytes([
buffer[offset],
buffer[offset + 1],
buffer[offset + 2],
buffer[offset + 3],
]) as usize;
offset += 4;
        // Embedding weights are little-endian f32 values, 4 bytes each.
        let expected_size = self.config.vocab_size * self.config.d_model * 4;
        if weight_size != expected_size {
return Err(invalid_format(
format!("embedding weight size {}", expected_size),
weight_size.to_string(),
));
}
if buffer.len() < offset + weight_size {
return Err(invalid_input(
"Insufficient data for embedding weight tensor",
));
}
let weight_bytes = &buffer[offset..offset + weight_size];
let mut weights = Vec::with_capacity(self.config.vocab_size * self.config.d_model);
for chunk in weight_bytes.chunks_exact(4) {
let value = f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
weights.push(value);
}
        // Validate that the data reshapes to [vocab_size, d_model]; the parsed
        // array is not yet installed into the embedding table.
        let _weight_array =
            Array2::from_shape_vec((self.config.vocab_size, self.config.d_model), weights)
                .map_err(|e| {
                    runtime_error(format!("Failed to reshape embedding weights: {}", e))
                })?;
offset += weight_size;
Ok(offset)
}
fn load_block_weights(&mut self, buffer: &[u8], mut offset: usize) -> Result<usize> {
for block_idx in 0..self.config.n_layer {
offset = self.load_single_block_weights(buffer, offset, block_idx)?;
}
Ok(offset)
}
fn load_single_block_weights(
&mut self,
buffer: &[u8],
mut offset: usize,
_block_idx: usize,
) -> Result<usize> {
offset = self.load_state_space_parameters(buffer, offset)?;
offset = self.load_layer_norm_weights(buffer, offset)?;
offset =
self.load_linear_weights(buffer, offset, self.config.d_model, self.config.d_model * 2)?;
offset =
self.load_linear_weights(buffer, offset, self.config.d_model, self.config.d_model)?;
Ok(offset)
}
    fn load_state_space_parameters(&mut self, buffer: &[u8], mut offset: usize) -> Result<usize> {
        // All sizes are in bytes (4 bytes per f32 element).
        let a_size = self.config.d_state * self.config.d_state * 4;
        offset = self.validate_and_skip_tensor(buffer, offset, a_size, "A matrix real part")?;
        offset =
            self.validate_and_skip_tensor(buffer, offset, a_size, "A matrix imaginary part")?;
        let b_size = self.config.d_state * 4;
        offset = self.validate_and_skip_tensor(buffer, offset, b_size, "B vector real part")?;
        offset =
            self.validate_and_skip_tensor(buffer, offset, b_size, "B vector imaginary part")?;
        let c_size = self.config.d_state * 4;
        offset = self.validate_and_skip_tensor(buffer, offset, c_size, "C vector real part")?;
        offset =
            self.validate_and_skip_tensor(buffer, offset, c_size, "C vector imaginary part")?;
        let d_size = self.config.d_model * 4;
        offset = self.validate_and_skip_tensor(buffer, offset, d_size, "D vector")?;
        let dt_size = self.config.d_model * 4;
        offset = self.validate_and_skip_tensor(buffer, offset, dt_size, "dt parameter")?;
Ok(offset)
}
fn load_layer_norm_weights(&self, buffer: &[u8], mut offset: usize) -> Result<usize> {
        let weight_size = self.config.d_model * 4;
        offset = self.validate_and_skip_tensor(buffer, offset, weight_size, "LayerNorm weight")?;
        let bias_size = self.config.d_model * 4;
        offset = self.validate_and_skip_tensor(buffer, offset, bias_size, "LayerNorm bias")?;
Ok(offset)
}
fn load_linear_weights(
&self,
buffer: &[u8],
mut offset: usize,
in_features: usize,
out_features: usize,
) -> Result<usize> {
        let weight_size = out_features * in_features * 4;
        offset = self.validate_and_skip_tensor(buffer, offset, weight_size, "Linear weight")?;
        let bias_size = out_features * 4;
        offset = self.validate_and_skip_tensor(buffer, offset, bias_size, "Linear bias")?;
Ok(offset)
}
fn load_final_norm_weights(&self, buffer: &[u8], mut offset: usize) -> Result<usize> {
offset = self.load_layer_norm_weights(buffer, offset)?;
Ok(offset)
}
fn validate_and_skip_tensor(
&self,
buffer: &[u8],
offset: usize,
expected_size: usize,
tensor_name: &str,
) -> Result<usize> {
use trustformers_core::errors::TrustformersError;
if buffer.len() < offset + 4 {
return Err(invalid_input(format!(
"Insufficient data for {} size header",
tensor_name
)));
}
let tensor_size = u32::from_le_bytes([
buffer[offset],
buffer[offset + 1],
buffer[offset + 2],
buffer[offset + 3],
]) as usize;
        if tensor_size != expected_size {
            return Err(TrustformersError::invalid_format(
                format!("{} of {} bytes", tensor_name, expected_size),
                format!("{} bytes", tensor_size),
            ));
        }
if buffer.len() < offset + 4 + tensor_size {
return Err(TrustformersError::invalid_input_simple(format!(
"Insufficient data for {} tensor",
tensor_name
)));
}
Ok(offset + 4 + tensor_size)
}
}
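/// S4 backbone with a bias-free linear language-modeling head that maps
/// `d_model` hidden states to `vocab_size` logits.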
pub struct S4ForLanguageModeling {
pub s4: S4Model,
pub lm_head: Linear,
pub device: Device,
}
impl S4ForLanguageModeling {
pub fn new_with_device(config: S4Config, device: Device) -> Result<Self> {
let s4 = S4Model::new_with_device(config.clone(), device)?;
        let lm_head =
            Linear::new_with_device(config.d_model, config.vocab_size, false, device);
Ok(Self {
s4,
lm_head,
device,
})
}
pub fn new(config: S4Config) -> Result<Self> {
Self::new_with_device(config, Device::CPU)
}
pub fn device(&self) -> Device {
self.device
}
}
impl Model for S4ForLanguageModeling {
type Config = S4Config;
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
let hidden = self.s4.forward(input)?;
self.lm_head.forward(hidden)
}
fn load_pretrained(&mut self, reader: &mut dyn std::io::Read) -> Result<()> {
self.s4.load_pretrained(reader)?;
Ok(())
}
fn get_config(&self) -> &Self::Config {
self.s4.get_config()
}
fn num_parameters(&self) -> usize {
self.s4.num_parameters() + self.lm_head.parameter_count()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_hippo_initialization() {
let n = 4;
let legs = HiPPOMatrix::LEGS;
let a_legs = legs.initialize(n);
assert_eq!(a_legs.shape(), &[n, n]);
        let sum = &a_legs + &a_legs.t();
        assert!(sum.iter().all(|&x| x.abs() < 1e-6));
let legt = HiPPOMatrix::LEGT;
let a_legt = legt.initialize(n);
assert_eq!(a_legt.shape(), &[n, n]);
let fourier = HiPPOMatrix::Fourier;
let a_fourier = fourier.initialize(n);
assert_eq!(a_fourier.shape(), &[n, n]);
}
#[test]
fn test_discretization() {
let n = 4;
let a = Array2::<f32>::eye(n);
let b = Array1::<f32>::ones(n);
let dt = 0.01;
let zoh = Discretization::ZOH;
let (a_bar, b_bar) = zoh.discretize(&a, &b, dt);
assert_eq!(a_bar.shape(), &[n, n]);
assert_eq!(b_bar.shape(), &[n]);
let euler = Discretization::Euler;
let (a_bar_euler, b_bar_euler) = euler.discretize(&a, &b, dt);
assert_eq!(a_bar_euler.shape(), &[n, n]);
assert_eq!(b_bar_euler.shape(), &[n]);
}
#[test]
fn test_s4_layer_creation() {
let config = S4Config::default();
let layer = S4Layer::new(&config);
assert!(layer.is_ok());
let layer = layer.expect("operation failed");
assert_eq!(layer.a_real.shape(), &[config.d_state, config.d_state]);
assert_eq!(layer.b_real.shape(), &[config.d_state]);
assert_eq!(layer.c_real.shape(), &[config.d_state]);
assert_eq!(layer.d.shape(), &[config.get_n_ssm()]);
}
#[test]
fn test_s4_model_creation() {
let config = S4Config::s4_small();
let model = S4Model::new(config.clone()).expect("operation failed");
assert_eq!(model.config.d_model, config.d_model);
assert_eq!(model.blocks.len(), config.n_layer);
}
#[test]
fn test_s4_lm_creation() {
let config = S4Config::s4_base();
let _model = S4ForLanguageModeling::new(config).expect("operation failed");
}
#[test]
fn test_hippo_matrix_shapes_all_types() {
for (hippo, name) in [
(HiPPOMatrix::LEGS, "LEGS"),
(HiPPOMatrix::LEGT, "LEGT"),
(HiPPOMatrix::LAGT, "LAGT"),
(HiPPOMatrix::Fourier, "Fourier"),
(HiPPOMatrix::Random, "Random"),
] {
let a = hippo.initialize(8);
assert_eq!(a.shape(), &[8, 8], "HiPPO {} must produce 8x8 matrix", name);
}
}
#[test]
fn test_hippo_legs_skew_symmetric() {
let n = 6;
let legs = HiPPOMatrix::LEGS;
let a = legs.initialize(n);
let sum = &a + &a.t();
for val in sum.iter() {
assert!(
val.abs() < 1e-5,
"LEGS must be skew-symmetric; got residual {}",
val
);
}
}
#[test]
fn test_hippo_legs_different_sizes() {
for n in [4usize, 8, 16, 32] {
let a = HiPPOMatrix::LEGS.initialize(n);
assert_eq!(
a.shape(),
&[n, n],
"LEGS matrix must have shape [n, n] for n={}",
n
);
}
}
#[test]
fn test_bilinear_discretization_shapes() {
let n = 4;
let a = Array2::<f32>::eye(n);
let b = Array1::<f32>::ones(n);
let dt = 0.001f32;
let (a_bar, b_bar) = Discretization::Bilinear.discretize(&a, &b, dt);
assert_eq!(a_bar.shape(), &[n, n]);
assert_eq!(b_bar.shape(), &[n]);
}
#[test]
fn test_zoh_identity_a_bar_approaches_identity_for_small_dt() {
let n = 4;
let a_zero = Array2::<f32>::zeros((n, n));
let b = Array1::<f32>::ones(n);
let dt = 1e-6f32;
let (a_bar, _b_bar) = Discretization::ZOH.discretize(&a_zero, &b, dt);
for i in 0..n {
assert!(
(a_bar[[i, i]] - 1.0).abs() < 1e-4,
"ZOH with A=0 should give diagonal ~1; got {}",
a_bar[[i, i]]
);
}
}
#[test]
fn test_euler_discretization_matches_formula() {
let n = 3;
let dt = 0.1f32;
let a = Array2::<f32>::zeros((n, n));
let b = Array1::<f32>::ones(n);
let (a_bar, b_bar) = Discretization::Euler.discretize(&a, &b, dt);
for i in 0..n {
assert!((a_bar[[i, i]] - 1.0).abs() < 1e-6);
}
for i in 0..n {
assert!((b_bar[i] - dt).abs() < 1e-6, "Euler B_bar must equal dt*B");
}
}
#[test]
fn test_ssm_output_y_equals_cx_plus_du_at_t0() {
let d_state = 4usize;
let c: Vec<f64> = vec![1.0, 0.5, -0.5, 0.25];
let d_scalar = 2.0f64;
let state: Vec<f64> = vec![0.0; d_state];
let u = 1.5f64;
let cx: f64 = c.iter().zip(state.iter()).map(|(ci, xi)| ci * xi).sum();
let y = cx + d_scalar * u;
assert!(
(y - 3.0).abs() < 1e-12,
"y at t=0 from zero state must be D*u = {}",
3.0
);
}
#[test]
fn test_ssm_recurrence_state_update() {
let a_bar = 0.9f64;
let b_bar = 0.1f64;
let mut x = 0.0f64;
let u = 1.0f64;
for _ in 0..500 {
x = a_bar * x + b_bar * u;
}
let fixed_point = b_bar * u / (1.0 - a_bar);
assert!(
(x - fixed_point).abs() < 1e-4,
"Recurrence must converge to fixed-point {}; got {}",
fixed_point,
x
);
}
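    // Added sanity check: every `Discretization` variant here uses the
    // first-order update A_bar = 1 + a*dt, which matches the exact ZOH
    // exponential exp(a*dt) up to an O(dt^2) error for a scalar system.
    #[test]
    fn test_first_order_update_matches_exact_zoh_for_scalar() {
        let a = -0.5f64;
        let dt = 0.01f64;
        let euler = 1.0 + a * dt; // first-order A_bar
        let exact = (a * dt).exp(); // exact zero-order-hold A_bar
        let err = (euler - exact).abs();
        // Leading error term is (a*dt)^2 / 2 = 1.25e-5, well below dt^2.
        assert!(err < dt * dt, "first-order error {} exceeds dt^2", err);
    }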
#[test]
fn test_s4_layer_a_real_is_square() {
let config = S4Config {
d_state: 8,
..Default::default()
};
let layer = S4Layer::new(&config).expect("S4Layer creation must succeed");
let (r, c) = (layer.a_real.shape()[0], layer.a_real.shape()[1]);
assert_eq!(r, c, "A_real must be square");
assert_eq!(r, 8, "A_real must have d_state rows");
}
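    // Added integration sketch: discretize a default layer, then run the
    // sequential scan. `discretize` and `apply_s4` are private but visible to
    // this in-file test module; the finiteness check assumes the default `dt`
    // and `d_state` keep a 4-step scan bounded.
    #[test]
    fn test_s4_layer_discretize_then_apply() {
        let config = S4Config::default();
        let mut layer = S4Layer::new(&config).expect("S4Layer creation must succeed");
        layer.discretize().expect("discretization must succeed");
        assert!(layer.a_bar.is_some(), "a_bar must be cached after discretize");
        let input = Array2::<f32>::ones((2, 4));
        let output = layer.apply_s4(&input).expect("apply_s4 must succeed");
        assert_eq!(output.shape(), &[2, 4], "output must match input shape");
        assert!(output.iter().all(|v| v.is_finite()));
    }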
#[test]
fn test_s4_layer_b_c_lengths_match_d_state() {
let d_state = 16;
let config = S4Config {
d_state,
..Default::default()
};
let layer = S4Layer::new(&config).expect("S4Layer creation must succeed");
assert_eq!(layer.b_real.len(), d_state);
assert_eq!(layer.c_real.len(), d_state);
}
#[test]
fn test_s4_layer_d_length_matches_n_ssm() {
let config = S4Config {
d_model: 64,
n_ssm: None,
..Default::default()
};
let layer = S4Layer::new(&config).expect("S4Layer creation must succeed");
assert_eq!(
layer.d.len(),
config.get_n_ssm(),
"D skip-connection must have length n_ssm"
);
}
#[test]
fn test_cauchy_kernel_resolvent_invertible() {
let a: u64 = 6364136223846793005;
let c_lcg: u64 = 1442695040888963407;
let mut lcg: u64 = 0xABCD_EF01_2345_6789;
let d_state = 4;
let mut eigenvalues = Vec::with_capacity(d_state);
        for _ in 0..d_state {
            lcg = lcg.wrapping_mul(a).wrapping_add(c_lcg);
            // Map the LCG output to a negative real eigenvalue in roughly
            // [-1.45, -0.55), keeping every resolvent denominator non-zero.
            let ev = (lcg as i64 as f64) / (u64::MAX as f64) * 0.9;
            eigenvalues.push(ev - 1.0);
        }
let omega = 0.5f64;
for &ev in &eigenvalues {
let denom = omega - ev;
assert!(
denom.abs() > 1e-8,
"Resolvent denominator must be non-zero for ω={}, a_i={}",
omega,
ev
);
}
}
#[test]
fn test_s4_block_creation_succeeds() {
let config = S4Config {
d_model: 64,
d_state: 8,
n_ssm: Some(64),
..Default::default()
};
let block = S4Block::new(&config);
assert!(block.is_ok(), "S4Block creation must succeed");
}
#[test]
fn test_s4_layer_parameter_count_positive() {
let config = S4Config::default();
let layer = S4Layer::new(&config).expect("S4Layer creation must succeed");
assert!(
layer.parameter_count() > 0,
"S4Layer must have > 0 parameters"
);
}
#[test]
fn test_s4_model_block_count() {
let config = S4Config {
n_layer: 4,
..S4Config::default()
};
let model = S4Model::new(config.clone()).expect("S4Model must be created");
assert_eq!(
model.blocks.len(),
config.n_layer,
"S4Model must have n_layer blocks"
);
}
#[test]
fn test_s4_model_num_parameters_positive() {
let config = S4Config {
d_model: 64,
d_state: 8,
n_layer: 2,
n_ssm: Some(64),
..Default::default()
};
let model = S4Model::new(config).expect("S4Model creation must succeed");
assert!(
model.num_parameters() > 0,
"S4Model must have > 0 parameters"
);
}
#[test]
fn test_causal_convolution_future_is_zero_at_init() {
let d_state = 4;
let state: Vec<f64> = vec![0.0; d_state];
assert!(
state.iter().all(|&x| x == 0.0),
"Initial causal state must be zero"
);
}
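    // Added check: the weight loader must reject a buffer whose magic number
    // differs from 0x53344D4C, before reading any tensor data.
    #[test]
    fn test_load_weights_rejects_bad_magic() {
        let config = S4Config::default();
        let mut model = S4Model::new(config).expect("S4Model creation must succeed");
        let mut buf = Vec::new();
        buf.extend_from_slice(&0xDEADBEEFu32.to_le_bytes()); // wrong magic
        buf.extend_from_slice(&1u32.to_le_bytes()); // version
        buf.extend_from_slice(&0u32.to_le_bytes()); // empty metadata
        assert!(model.load_weights_from_buffer(&buf).is_err());
    }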
}