use crate::error::{ModelError, ModelResult};
use crate::{AutoregressiveModel, ModelType};
use kizzasi_core::{gelu, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{rng, RngExt};
#[allow(unused_imports)]
use tracing::{debug, instrument, trace};
/// Configuration for an [`S5`] model.
///
/// [`S5Config::new`] fills `state_dim`, `dt` and `block_size` with defaults;
/// every field is public and may be overridden before the config is passed to
/// [`S5::new`], which runs [`S5Config::validate`] on it.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct S5Config {
/// Length of the vector each layer consumes and produces.
pub input_dim: usize,
/// Width each layer projects its input to before the state-space block.
pub hidden_dim: usize,
/// Number of entries in each layer's recurrent state vector.
pub state_dim: usize,
/// Number of layers stacked in the model.
pub num_layers: usize,
/// Step size used when discretising each state-space block.
pub dt: f32,
/// Block size. The code visible in this file only validates that it is
/// non-zero and does not otherwise read it.
pub block_size: usize,
}
impl S5Config {
    /// Creates a configuration with the given dimensions and the default
    /// `state_dim = 64`, `dt = 0.001` and `block_size = 64`.
    ///
    /// The result is not validated; call [`S5Config::validate`] (or build the
    /// model, which validates for you) before relying on it.
    pub fn new(input_dim: usize, hidden_dim: usize, num_layers: usize) -> Self {
        Self {
            input_dim,
            hidden_dim,
            state_dim: 64,
            num_layers,
            dt: 0.001,
            block_size: 64,
        }
    }

    /// Checks that the configuration can be used to build a model.
    ///
    /// # Errors
    ///
    /// Returns an invalid-config error when any dimension, `num_layers` or
    /// `block_size` is zero, or when `dt` is not a finite, strictly positive
    /// number.
    pub fn validate(&self) -> ModelResult<()> {
        if self.hidden_dim == 0 {
            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
        }
        if self.state_dim == 0 {
            return Err(ModelError::invalid_config("state_dim must be > 0"));
        }
        if self.num_layers == 0 {
            return Err(ModelError::invalid_config("num_layers must be > 0"));
        }
        // `NaN <= 0.0` is false, so a plain `<=` comparison would let NaN (and
        // +inf) through and poison every discretised parameter.
        if !self.dt.is_finite() {
            return Err(ModelError::invalid_config("dt must be finite"));
        }
        if self.dt <= 0.0 {
            return Err(ModelError::invalid_config("dt must be > 0"));
        }
        if self.block_size == 0 {
            return Err(ModelError::invalid_config("block_size must be > 0"));
        }
        // Checked last so the error reported for previously rejected configs
        // does not change.
        if self.input_dim == 0 {
            return Err(ModelError::invalid_config("input_dim must be > 0"));
        }
        Ok(())
    }
}
/// One diagonal state-space block: a recurrent state vector together with the
/// parameters that update it and read it out.
#[allow(dead_code)]
struct S5Block {
/// Per-state parameter, initialised to `-ln(i + 1)`; `a_bar` is derived from it.
log_a: Array1<f32>,
/// Input matrix, shape `(state_dim, hidden_dim)`.
b_matrix: Array2<f32>,
/// Output matrix, shape `(hidden_dim, state_dim)`.
c_matrix: Array2<f32>,
/// Element-wise skip weights applied to the block input, length `hidden_dim`.
d_vec: Array1<f32>,
/// Step size the block was built with.
dt: f32,
/// Per-state multiplier applied to `state` on every step; computed once in
/// `new` from `log_a` and `dt`.
a_bar: Array1<f32>,
/// `b_matrix * dt`; the input matrix actually used by `forward`.
b_bar: Array2<f32>,
/// Recurrent state, length `state_dim`; zeroed by `reset`.
state: Array1<f32>,
}
impl S5Block {
    /// Builds a block with randomly initialised `B`, `C` and `D` parameters
    /// and a zeroed state.
    ///
    /// * `hidden_dim` - length of the vectors passed to [`S5Block::forward`].
    /// * `state_dim` - length of the recurrent state.
    /// * `dt` - discretisation step; expected to be finite and `> 0`.
    fn new(hidden_dim: usize, state_dim: usize, dt: f32) -> Self {
        let mut rng = rng();

        let log_a = Array1::from_shape_fn(state_dim, |i| -((i + 1) as f32).ln());

        // Uniform samples in [-scale, scale) for B and C.
        let scale_b = (2.0 / (state_dim + hidden_dim) as f32).sqrt();
        let b_matrix = Array2::from_shape_fn((state_dim, hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale_b
        });
        let scale_c = (2.0 / (hidden_dim + state_dim) as f32).sqrt();
        let c_matrix = Array2::from_shape_fn((hidden_dim, state_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale_c
        });
        let d_vec = Array1::from_shape_fn(hidden_dim, |_| rng.random::<f32>() * 0.01);

        // The continuous-time diagonal is A_i = -exp(log_a_i), which is always
        // negative, so a_bar_i = exp(dt * A_i) lies in (0, 1) for dt > 0 and
        // the recurrence decays. Without the minus sign every a_bar_i exceeds
        // 1 and the state grows geometrically with the sequence length.
        let a_bar = log_a.mapv(|log_a_i| (-dt * log_a_i.exp()).exp());
        // First-order approximation of the discretised input matrix.
        let b_bar = b_matrix.clone() * dt;

        let state = Array1::zeros(state_dim);
        Self {
            log_a,
            b_matrix,
            c_matrix,
            d_vec,
            dt,
            a_bar,
            b_bar,
            state,
        }
    }

    /// Advances the recurrence by one step and returns the block output.
    ///
    /// Updates `state` to `state * a_bar + b_bar . x` (element-wise product,
    /// then matrix-vector product) and returns `c_matrix . state + d_vec * x`.
    /// `x` must have length `hidden_dim`.
    #[instrument(skip(self, x))]
    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        self.state = &self.state * &self.a_bar + self.b_bar.dot(x);
        let y = self.c_matrix.dot(&self.state) + &self.d_vec * x;
        Ok(y)
    }

    /// Clears the recurrent state back to all zeros.
    fn reset(&mut self) {
        self.state.fill(0.0);
    }
}
/// One model layer: input projection, state-space block, activation and
/// normalisation, then an output projection added back onto the layer input.
struct S5Layer {
/// Projection from the layer input to the hidden width, shape `(input_dim, hidden_dim)`.
input_proj: Array2<f32>,
/// Recurrent state-space block operating on the projected input.
s5_block: S5Block,
/// Normalisation over `hidden_dim`, constructed with `NormType::RMSNorm`.
layer_norm: LayerNorm,
/// Projection from the hidden width back to the input width, shape `(hidden_dim, input_dim)`.
output_proj: Array2<f32>,
}
impl S5Layer {
    /// Builds a layer from `config` with randomly initialised projections.
    ///
    /// Random draws happen in a fixed order: input projection, the
    /// state-space block's own parameters, then the output projection.
    fn new(config: &S5Config) -> ModelResult<Self> {
        let mut rng = rng();
        let (in_dim, hid_dim) = (config.input_dim, config.hidden_dim);

        // Both projections share the same fan-in + fan-out, hence one scale.
        let proj_scale = (2.0 / (in_dim + hid_dim) as f32).sqrt();

        let input_proj = Array2::from_shape_fn((in_dim, hid_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * proj_scale
        });
        let s5_block = S5Block::new(hid_dim, config.state_dim, config.dt);
        let layer_norm = LayerNorm::new(hid_dim, NormType::RMSNorm);
        let output_proj = Array2::from_shape_fn((hid_dim, in_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * proj_scale
        });

        Ok(Self {
            input_proj,
            s5_block,
            layer_norm,
            output_proj,
        })
    }

    /// Runs one step of the layer on `x` (length `input_dim`).
    ///
    /// The input is projected to the hidden width, passed through the
    /// state-space block, GELU and the norm, projected back, and finally added
    /// to `x` as a residual.
    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        let projected = x.dot(&self.input_proj);
        let block_out = self.s5_block.forward(&projected)?;
        let normalised = self.layer_norm.forward(&gelu(&block_out));
        Ok(normalised.dot(&self.output_proj) + x)
    }

    /// Clears the recurrent state held by this layer's block.
    fn reset(&mut self) {
        self.s5_block.reset();
    }
}
/// S5 model: a stack of layers applied in sequence, each carrying its own
/// recurrent state between steps.
pub struct S5 {
/// Configuration the model was built from.
config: S5Config,
/// Layers in application order; `config.num_layers` of them.
layers: Vec<S5Layer>,
}
impl S5 {
    /// Validates `config` and builds a model with `config.num_layers` freshly
    /// initialised layers.
    ///
    /// # Errors
    ///
    /// Returns the error from [`S5Config::validate`] if the configuration is
    /// invalid, or the first error produced while constructing a layer.
    #[instrument(skip(config), fields(input_dim = config.input_dim, hidden_dim = config.hidden_dim, num_layers = config.num_layers))]
    pub fn new(config: S5Config) -> ModelResult<Self> {
        debug!("Creating new S5 model");
        config.validate()?;

        // Collecting into a `Result` stops at the first layer that fails.
        let layers = (0..config.num_layers)
            .map(|idx| {
                trace!("Initializing S5 layer {}", idx);
                S5Layer::new(&config)
            })
            .collect::<ModelResult<Vec<_>>>()?;

        debug!("Initialized {} S5 layers", layers.len());
        debug!("S5 model created successfully");
        Ok(Self { config, layers })
    }

    /// Returns the configuration this model was built from.
    pub fn config(&self) -> &S5Config {
        &self.config
    }
}
impl SignalPredictor for S5 {
#[instrument(skip(self, input))]
fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
let mut x = input.clone();
for layer in &mut self.layers {
x = layer.forward(&x)?;
}
Ok(x)
}
#[instrument(skip(self))]
fn reset(&mut self) {
debug!("Resetting S5 model state");
for layer in &mut self.layers {
layer.reset();
}
}
fn context_window(&self) -> usize {
usize::MAX
}
}
impl AutoregressiveModel for S5 {
    /// Hidden width taken from the configuration.
    fn hidden_dim(&self) -> usize {
        self.config.hidden_dim
    }

    /// Recurrent state length taken from the configuration.
    fn state_dim(&self) -> usize {
        self.config.state_dim
    }

    /// Number of layers taken from the configuration.
    fn num_layers(&self) -> usize {
        self.config.num_layers
    }

    /// Model family tag.
    // NOTE(review): this reports `S4`; confirm whether `ModelType` has (or
    // should gain) a dedicated S5 variant.
    fn model_type(&self) -> ModelType {
        ModelType::S4
    }

    /// Returns one [`HiddenState`] per layer, each holding that layer's
    /// recurrent state as a single-row matrix of shape `(1, state_dim)`.
    fn get_states(&self) -> Vec<HiddenState> {
        self.layers
            .iter()
            .map(|layer| {
                let state_1d = layer.s5_block.state.clone();
                let state_2d = state_1d.insert_axis(scirs2_core::ndarray::Axis(0));
                let mut hidden_state = HiddenState::new(
                    self.config.hidden_dim,
                    state_2d.len_of(scirs2_core::ndarray::Axis(1)),
                );
                hidden_state.update(state_2d);
                hidden_state
            })
            .collect()
    }

    /// Restores each layer's recurrent state from row 0 of the matching entry
    /// in `states`. Entries with zero rows or zero columns leave their layer
    /// untouched.
    ///
    /// # Errors
    ///
    /// Returns a state-count-mismatch error when `states.len()` differs from
    /// the number of layers, and an invalid-config error when a non-empty
    /// entry's width differs from the layer's state length. No layer is
    /// modified when an error is returned.
    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
        if states.len() != self.config.num_layers {
            return Err(ModelError::state_count_mismatch(
                "S5",
                self.config.num_layers,
                states.len(),
            ));
        }

        // Validate everything up front: a state of the wrong width would
        // otherwise be accepted here and only surface as a shape panic on the
        // next `step`, and a late failure must not leave earlier layers
        // already overwritten.
        for (layer, state) in self.layers.iter().zip(states.iter()) {
            let state_2d = state.state();
            let is_empty = state_2d.nrows() == 0 || state_2d.ncols() == 0;
            if !is_empty && state_2d.ncols() != layer.s5_block.state.len() {
                return Err(ModelError::invalid_config(
                    "S5 hidden state width must match state_dim",
                ));
            }
        }

        for (layer, state) in self.layers.iter_mut().zip(states.iter()) {
            let state_2d = state.state();
            if state_2d.nrows() > 0 && state_2d.ncols() > 0 {
                layer.s5_block.state = state_2d.row(0).to_owned();
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small configuration shared by the tests: 32 inputs, hidden width 64,
    /// two layers.
    fn small_config() -> S5Config {
        S5Config::new(32, 64, 2)
    }

    #[test]
    fn test_s5_creation() {
        let model = S5::new(small_config());
        assert!(model.is_ok());
    }

    #[test]
    fn test_s5_forward() {
        let mut model = S5::new(small_config()).expect("Failed to create S5 model");
        let input = Array1::from_vec(vec![1.0; 32]);
        let output = model.step(&input);
        assert!(output.is_ok());
        assert_eq!(output.expect("Failed to get output").len(), 32);
    }

    #[test]
    fn test_s5_reset() {
        let mut model = S5::new(small_config()).expect("Failed to create S5 model");
        let input = Array1::from_vec(vec![1.0; 32]);
        let output1 = model.step(&input).expect("Failed to get output1");
        model.reset();
        let output2 = model.step(&input).expect("Failed to get output2");
        assert_eq!(output2.len(), 32);
        // With the recurrent state cleared, replaying the same input must
        // reproduce the very first step; a length check alone would pass even
        // if `reset` did nothing.
        assert_eq!(output1, output2);
    }

    #[test]
    fn test_invalid_config() {
        let mut config = small_config();
        config.state_dim = 0;
        assert!(config.validate().is_err());
    }
}