use crate::error::{ModelError, ModelResult};
use crate::{AutoregressiveModel, ModelType};
use kizzasi_core::{
gelu, CausalConv1d, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor,
};
use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::random::{rng, RngExt};
#[allow(unused_imports)]
use tracing::{debug, instrument, trace};
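/// Configuration for the [`S4D`] model, built directly or via the chainable
/// setters below.
///
/// Note: `dropout` and `use_diagonal` are carried in the config but are not
/// read by the inference code in this module (the diagonal kernel is always
/// used).
///
/// A minimal builder sketch:
///
/// ```ignore
/// let config = S4Config::new().hidden_dim(256).state_dim(64).num_layers(4);
/// assert!(config.validate().is_ok());
/// ```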
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct S4Config {
pub input_dim: usize,
pub hidden_dim: usize,
pub state_dim: usize,
pub num_layers: usize,
pub dropout: f32,
pub dt_min: f32,
pub dt_max: f32,
pub use_diagonal: bool,
pub use_rms_norm: bool,
}
impl Default for S4Config {
fn default() -> Self {
Self {
input_dim: 1,
hidden_dim: 512,
state_dim: 64,
num_layers: 6,
dropout: 0.0,
dt_min: 0.001,
dt_max: 0.1,
            use_diagonal: true,
            use_rms_norm: true,
}
}
}
impl S4Config {
pub fn new() -> Self {
Self::default()
}
pub fn input_dim(mut self, dim: usize) -> Self {
self.input_dim = dim;
self
}
pub fn hidden_dim(mut self, dim: usize) -> Self {
self.hidden_dim = dim;
self
}
pub fn state_dim(mut self, dim: usize) -> Self {
self.state_dim = dim;
self
}
pub fn num_layers(mut self, n: usize) -> Self {
self.num_layers = n;
self
}
pub fn diagonal(mut self, use_diagonal: bool) -> Self {
self.use_diagonal = use_diagonal;
self
}
    pub fn validate(&self) -> ModelResult<()> {
        if self.input_dim == 0 {
            return Err(ModelError::invalid_config("input_dim must be > 0"));
        }
        if self.hidden_dim == 0 {
            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
        }
if self.state_dim == 0 {
return Err(ModelError::invalid_config("state_dim must be > 0"));
}
if self.num_layers == 0 {
return Err(ModelError::invalid_config("num_layers must be > 0"));
}
if self.dt_min <= 0.0 || self.dt_max <= 0.0 {
return Err(ModelError::invalid_config("dt_min and dt_max must be > 0"));
}
if self.dt_min > self.dt_max {
return Err(ModelError::invalid_config("dt_min must be <= dt_max"));
}
Ok(())
}
}
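/// Recurrent-form kernel of a diagonal state space model (S4D).
///
/// Each of the `hidden_dim` channels evolves an independent `state_dim`-wide
/// linear recurrence with a shared diagonal transition `A = -exp(log_a)`:
///
/// ```text
/// x[k] = a_bar * x[k-1] + b_bar * u[k]    (a_bar, b_bar from `discretize`)
/// y[k] = C x[k] + D * u[k]
/// ```
///
/// `log_a` and `log_dt` are kept in log space so the transition stays stable
/// (`A < 0`) and the step sizes stay positive.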
struct S4DKernel {
hidden_dim: usize,
state_dim: usize,
log_a: Array1<f32>,
b_matrix: Array2<f32>,
c_matrix: Array2<f32>,
d_skip: Array1<f32>,
log_dt: Array1<f32>,
    state: Array2<f32>,
}
impl S4DKernel {
fn new(config: &S4Config) -> ModelResult<Self> {
let mut rng = rng();
        // A_n is initialized to -(2n + 1) / 2 and stored as log(-A_n), so the
        // transition stays strictly negative (stable) after exponentiation.
        let log_a = Array1::from_shape_fn(config.state_dim, |n| ((2 * n + 1) as f32 / 2.0).ln());
let scale = (1.0 / config.state_dim as f32).sqrt();
let b_matrix = Array2::from_shape_fn((config.state_dim, config.hidden_dim), |_| {
(rng.random::<f32>() - 0.5) * 2.0 * scale
});
let c_matrix = Array2::from_shape_fn((config.hidden_dim, config.state_dim), |_| {
(rng.random::<f32>() - 0.5) * 2.0 * scale
});
let d_skip = Array1::ones(config.hidden_dim);
        // Per-channel step sizes, sampled uniformly in [dt_min, dt_max] and
        // stored in log space so they remain positive.
        let log_dt = Array1::from_shape_fn(config.hidden_dim, |_| {
            let dt = config.dt_min + rng.random::<f32>() * (config.dt_max - config.dt_min);
            dt.ln()
        });
let state = Array2::zeros((config.hidden_dim, config.state_dim));
Ok(Self {
hidden_dim: config.hidden_dim,
state_dim: config.state_dim,
log_a,
b_matrix,
c_matrix,
d_skip,
log_dt,
state,
})
}
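    /// Zero-order-hold discretization for a single step size `dt`:
    /// `a_bar = exp(dt * A)` and `b_bar = (1 - a_bar) / (-A) * B`, with
    /// `A = -exp(log_a)`.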
fn discretize(&self, dt: f32) -> (Array1<f32>, Array2<f32>) {
let mut a_bar = Array1::zeros(self.state_dim);
let mut b_bar = Array2::zeros(self.b_matrix.raw_dim());
for i in 0..self.state_dim {
let a_i = -self.log_a[i].exp();
a_bar[i] = (dt * a_i).exp();
let scale = (1.0 - a_bar[i]) / (-a_i);
for j in 0..self.hidden_dim {
b_bar[[i, j]] = self.b_matrix[[i, j]] * scale;
}
}
(a_bar, b_bar)
}
    fn forward_step(&mut self, u: &Array1<f32>) -> CoreResult<Array1<f32>> {
        // Number of channels updated this step: a per-channel width, not a
        // batch size, clamped in case the input is shorter than hidden_dim.
        let dims = u.len().min(self.hidden_dim);
        let mut output = Array1::zeros(dims);
        for dim in 0..dims {
            // Each channel has its own learned step size.
            let dt = self.log_dt[dim].exp();
            // Note: discretize rebuilds b_bar for every channel although only
            // column `dim` is consumed below.
            let (a_bar, b_bar) = self.discretize(dt);
            // Linear recurrence: x[k] = a_bar * x[k-1] + b_bar * u[k].
            for i in 0..self.state_dim {
let bu = if dim < b_bar.shape()[1] {
b_bar[[i, dim]] * u[dim]
} else {
0.0
};
self.state[[dim, i]] = a_bar[i] * self.state[[dim, i]] + bu;
}
            // Readout: y[k] = C x[k] + D * u[k].
            let mut c_h = 0.0;
for i in 0..self.state_dim {
c_h += self.c_matrix[[dim, i]] * self.state[[dim, i]];
}
output[dim] = c_h + self.d_skip[dim] * u[dim];
}
Ok(output)
}
fn reset(&mut self) {
self.state.fill(0.0);
}
}
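/// One pre-norm S4D block: norm -> causal conv -> diagonal SSM -> GELU ->
/// output projection, with a residual connection around the whole block.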
struct S4DLayer {
norm: LayerNorm,
s4_kernel: S4DKernel,
conv: CausalConv1d,
output_proj: Array2<f32>,
}
impl S4DLayer {
fn new(config: &S4Config) -> ModelResult<Self> {
let norm_type = if config.use_rms_norm {
NormType::RMSNorm
} else {
NormType::LayerNorm
};
let norm = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);
let s4_kernel = S4DKernel::new(config)?;
let conv = CausalConv1d::new(config.hidden_dim, config.hidden_dim, 3);
let mut rng = rng();
let scale = (2.0 / config.hidden_dim as f32).sqrt();
let output_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
(rng.random::<f32>() - 0.5) * 2.0 * scale
});
Ok(Self {
norm,
s4_kernel,
conv,
output_proj,
})
}
    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
        // Pre-norm -> causal conv -> diagonal SSM -> GELU.
        let x_norm = self.norm.forward(x);
        let x_vec = x_norm.to_vec();
        let conv_out = self.conv.forward_step(&x_vec);
        let x_conv = Array1::from_vec(conv_out);
        let ssm_out = self.s4_kernel.forward_step(&x_conv)?;
        let activated = gelu(&ssm_out);
        // Output projection as an explicit mat-vec, clamped to the smaller of
        // the two dimensions for safety.
        let mut projected = Array1::zeros(x.len().min(self.output_proj.shape()[0]));
for i in 0..projected.len() {
let mut sum = 0.0;
for j in 0..activated.len().min(self.output_proj.shape()[1]) {
sum += self.output_proj[[i, j]] * activated[j];
}
projected[i] = sum;
}
        // Residual connection around the whole block.
        let mut output = x.clone();
for i in 0..output.len().min(projected.len()) {
output[i] += projected[i];
}
Ok(output)
}
    fn reset(&mut self) {
        // Clears the SSM recurrence state; any history the causal conv keeps
        // internally is not reset here.
        self.s4_kernel.reset();
    }
}
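/// Diagonal structured state space model (S4D) for autoregressive signal
/// prediction: an input projection into `hidden_dim`, a stack of `S4DLayer`
/// blocks, a final norm, and a projection back to `input_dim`.
///
/// A minimal streaming sketch (dims are illustrative):
///
/// ```ignore
/// let mut model = S4D::new(S4Config::new().hidden_dim(128).state_dim(32))?;
/// for &sample in &signal {
///     let pred = model.step(&Array1::from_vec(vec![sample]))?;
/// }
/// ```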
pub struct S4D {
config: S4Config,
layers: Vec<S4DLayer>,
ln_out: LayerNorm,
input_proj: Array2<f32>,
output_proj: Array2<f32>,
}
impl S4D {
pub fn new(config: S4Config) -> ModelResult<Self> {
config.validate()?;
let mut layers = Vec::with_capacity(config.num_layers);
for _ in 0..config.num_layers {
layers.push(S4DLayer::new(&config)?);
}
let norm_type = if config.use_rms_norm {
NormType::RMSNorm
} else {
NormType::LayerNorm
};
let ln_out = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);
        let mut rng = rng();
        // Glorot-style scale; the same value serves both projections since
        // they share the same fan-in/fan-out pair.
        let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
        let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });
Ok(Self {
config,
layers,
ln_out,
input_proj,
output_proj,
})
}
pub fn config(&self) -> &S4Config {
&self.config
}
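    /// Loads every tensor that `loader` provides into this model; tensors
    /// that are absent keep their randomly initialized values.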
pub fn load_weights(&mut self, loader: &crate::loader::ModelLoader) -> ModelResult<()> {
if loader.has_tensor("input_proj") {
self.input_proj = loader.load_array2("input_proj")?;
}
if loader.has_tensor("output_proj") {
self.output_proj = loader.load_array2("output_proj")?;
}
if loader.has_tensor("ln_out.weight") {
let weight = loader.load_array1("ln_out.weight")?;
self.ln_out.set_gamma(weight);
}
if loader.has_tensor("ln_out.bias") {
let bias = loader.load_array1("ln_out.bias")?;
self.ln_out.set_beta(bias);
}
for (i, layer) in self.layers.iter_mut().enumerate() {
let prefix = format!("layers.{}", i);
if loader.has_tensor(&format!("{}.norm.weight", prefix)) {
let weight = loader.load_array1(&format!("{}.norm.weight", prefix))?;
layer.norm.set_gamma(weight);
}
if loader.has_tensor(&format!("{}.norm.bias", prefix)) {
let bias = loader.load_array1(&format!("{}.norm.bias", prefix))?;
layer.norm.set_beta(bias);
}
if loader.has_tensor(&format!("{}.output_proj", prefix)) {
layer.output_proj = loader.load_array2(&format!("{}.output_proj", prefix))?;
}
let kernel_prefix = format!("{}.s4_kernel", prefix);
if loader.has_tensor(&format!("{}.log_a", kernel_prefix)) {
layer.s4_kernel.log_a = loader.load_array1(&format!("{}.log_a", kernel_prefix))?;
}
if loader.has_tensor(&format!("{}.b_matrix", kernel_prefix)) {
layer.s4_kernel.b_matrix =
loader.load_array2(&format!("{}.b_matrix", kernel_prefix))?;
}
if loader.has_tensor(&format!("{}.c_matrix", kernel_prefix)) {
layer.s4_kernel.c_matrix =
loader.load_array2(&format!("{}.c_matrix", kernel_prefix))?;
}
if loader.has_tensor(&format!("{}.d_skip", kernel_prefix)) {
layer.s4_kernel.d_skip =
loader.load_array1(&format!("{}.d_skip", kernel_prefix))?;
}
if loader.has_tensor(&format!("{}.log_dt", kernel_prefix)) {
layer.s4_kernel.log_dt =
loader.load_array1(&format!("{}.log_dt", kernel_prefix))?;
}
if loader.has_tensor(&format!("{}.conv.weight", prefix)) {
let conv_weights = loader.load_array3(&format!("{}.conv.weight", prefix))?;
layer.conv.set_weights(conv_weights);
}
if loader.has_tensor(&format!("{}.conv.bias", prefix)) {
let conv_bias = loader.load_array1(&format!("{}.conv.bias", prefix))?;
layer.conv.set_bias(conv_bias.to_vec());
}
}
Ok(())
}
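    /// Serializes the input/output projections and the per-layer SSM
    /// parameters (`output_proj`, `log_a`, `b_matrix`, `c_matrix`, `d_skip`,
    /// `log_dt`) as a flat JSON map from tensor name to `Vec<f32>`. Norm and
    /// conv parameters are not serialized, mirroring what `load_weights_json`
    /// restores.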
pub fn save_weights_json<P: AsRef<std::path::Path>>(&self, path: P) -> ModelResult<()> {
let mut weights: std::collections::HashMap<String, Vec<f32>> =
std::collections::HashMap::new();
weights.insert(
"input_proj".to_string(),
self.input_proj.iter().copied().collect(),
);
weights.insert(
"output_proj".to_string(),
self.output_proj.iter().copied().collect(),
);
for (i, layer) in self.layers.iter().enumerate() {
let prefix = format!("layers.{}", i);
let kp = format!("{}.s4_kernel", prefix);
weights.insert(
format!("{}.output_proj", prefix),
layer.output_proj.iter().copied().collect(),
);
weights.insert(
format!("{}.log_a", kp),
layer.s4_kernel.log_a.iter().copied().collect(),
);
weights.insert(
format!("{}.b_matrix", kp),
layer.s4_kernel.b_matrix.iter().copied().collect(),
);
weights.insert(
format!("{}.c_matrix", kp),
layer.s4_kernel.c_matrix.iter().copied().collect(),
);
weights.insert(
format!("{}.d_skip", kp),
layer.s4_kernel.d_skip.iter().copied().collect(),
);
weights.insert(
format!("{}.log_dt", kp),
layer.s4_kernel.log_dt.iter().copied().collect(),
);
}
let file = std::fs::File::create(path.as_ref()).map_err(|e| {
ModelError::load_error("s4d save_weights", format!("failed to create file: {e}"))
})?;
serde_json::to_writer(file, &weights).map_err(|e| {
ModelError::load_error(
"s4d save_weights",
format!("JSON serialization failed: {e}"),
)
})?;
Ok(())
}
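    /// Counterpart of [`Self::save_weights_json`]: reads the flat JSON map
    /// and restores every key that is present, validating each tensor's
    /// length against the current configuration before reshaping.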
pub fn load_weights_json<P: AsRef<std::path::Path>>(&mut self, path: P) -> ModelResult<()> {
let file = std::fs::File::open(path.as_ref()).map_err(|e| {
ModelError::load_error("s4d load_weights", format!("failed to open file: {e}"))
})?;
let weights: std::collections::HashMap<String, Vec<f32>> = serde_json::from_reader(file)
.map_err(|e| {
ModelError::load_error(
"s4d load_weights",
format!("JSON deserialization failed: {e}"),
)
})?;
let load_array2 = |map: &std::collections::HashMap<String, Vec<f32>>,
key: &str,
rows: usize,
cols: usize|
-> ModelResult<Option<Array2<f32>>> {
if let Some(data) = map.get(key) {
if data.len() != rows * cols {
return Err(ModelError::load_error(
"s4d load_weights",
format!(
"shape mismatch for '{}': expected {}×{}={} but got {}",
key,
rows,
cols,
rows * cols,
data.len()
),
));
}
let arr = Array2::from_shape_vec((rows, cols), data.clone()).map_err(|e| {
ModelError::load_error(
"s4d load_weights",
format!("failed to reshape '{}': {e}", key),
)
})?;
Ok(Some(arr))
} else {
Ok(None)
}
};
let load_array1 = |map: &std::collections::HashMap<String, Vec<f32>>,
key: &str,
expected_len: usize|
-> ModelResult<Option<Array1<f32>>> {
if let Some(data) = map.get(key) {
if data.len() != expected_len {
return Err(ModelError::load_error(
"s4d load_weights",
format!(
"shape mismatch for '{}': expected {} but got {}",
key,
expected_len,
data.len()
),
));
}
Ok(Some(Array1::from_vec(data.clone())))
} else {
Ok(None)
}
};
let hidden = self.config.hidden_dim;
let state = self.config.state_dim;
if let Some(arr) = load_array2(&weights, "input_proj", self.config.input_dim, hidden)? {
self.input_proj = arr;
}
if let Some(arr) = load_array2(&weights, "output_proj", hidden, self.config.input_dim)? {
self.output_proj = arr;
}
for (i, layer) in self.layers.iter_mut().enumerate() {
let prefix = format!("layers.{}", i);
let kp = format!("{}.s4_kernel", prefix);
if let Some(arr) =
load_array2(&weights, &format!("{}.output_proj", prefix), hidden, hidden)?
{
layer.output_proj = arr;
}
if let Some(arr) = load_array1(&weights, &format!("{}.log_a", kp), state)? {
layer.s4_kernel.log_a = arr;
}
if let Some(arr) = load_array2(&weights, &format!("{}.b_matrix", kp), state, hidden)? {
layer.s4_kernel.b_matrix = arr;
}
if let Some(arr) = load_array2(&weights, &format!("{}.c_matrix", kp), hidden, state)? {
layer.s4_kernel.c_matrix = arr;
}
if let Some(arr) = load_array1(&weights, &format!("{}.d_skip", kp), hidden)? {
layer.s4_kernel.d_skip = arr;
}
if let Some(arr) = load_array1(&weights, &format!("{}.log_dt", kp), hidden)? {
layer.s4_kernel.log_dt = arr;
}
}
Ok(())
}
    pub fn save_weights(&self, path: &str) -> ModelResult<()> {
        self.save_weights_json(path)
    }
}
impl SignalPredictor for S4D {
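    /// Processes one sample of length `input_dim` and returns the next-step
    /// prediction of the same length. Recurrent state is carried across
    /// calls, so a sequence is consumed one `step` at a time.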
#[instrument(skip(self, input))]
fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
let mut hidden = input.dot(&self.input_proj);
for layer in &mut self.layers {
hidden = layer.forward(&hidden)?;
}
hidden = self.ln_out.forward(&hidden);
let output = hidden.dot(&self.output_proj);
Ok(output)
}
fn reset(&mut self) {
for layer in &mut self.layers {
layer.reset();
}
}
    fn context_window(&self) -> usize {
        // A recurrent SSM has no fixed attention window; its state summarizes
        // the entire history seen so far.
        usize::MAX
    }
}
impl AutoregressiveModel for S4D {
fn hidden_dim(&self) -> usize {
self.config.hidden_dim
}
fn state_dim(&self) -> usize {
self.config.state_dim
}
fn num_layers(&self) -> usize {
self.config.num_layers
}
fn model_type(&self) -> ModelType {
ModelType::S4D
}
fn get_states(&self) -> Vec<HiddenState> {
self.layers
.iter()
.map(|layer| {
let state = layer.s4_kernel.state.clone();
let mut hs = HiddenState::new(state.shape()[0], state.shape()[1]);
hs.update(state);
hs
})
.collect()
}
fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
if states.len() != self.config.num_layers {
return Err(ModelError::state_count_mismatch(
"S4D",
self.config.num_layers,
states.len(),
));
}
for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
layer.s4_kernel.state = states[layer_idx].state().clone();
}
Ok(())
}
fn load_weights_json(&mut self, path: &std::path::Path) -> ModelResult<()> {
S4D::load_weights_json(self, path)
}
fn save_weights_json(&self, path: &std::path::Path) -> ModelResult<()> {
S4D::save_weights_json(self, path)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_s4d_config() {
let config = S4Config::new().hidden_dim(256).state_dim(64).num_layers(4);
assert_eq!(config.hidden_dim, 256);
assert_eq!(config.state_dim, 64);
assert!(config.validate().is_ok());
}
#[test]
fn test_s4d_creation() {
let config = S4Config::new().hidden_dim(128).state_dim(32);
let model = S4D::new(config);
assert!(model.is_ok());
}
#[test]
fn test_s4d_forward() {
let config = S4Config::new().hidden_dim(64).state_dim(16).num_layers(2);
let mut model = S4D::new(config).expect("Failed to create S4D model");
let input = Array1::from_vec(vec![0.5]);
let output = model.step(&input);
assert!(output.is_ok());
}
#[test]
fn test_invalid_dt() {
let config = S4Config {
dt_min: 0.1,
            dt_max: 0.01,
            ..Default::default()
};
assert!(config.validate().is_err());
}
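    // Added sketch: step() should map an input of length `input_dim` back to
    // an output of the same length, exercising both projections with
    // input_dim > 1.
    #[test]
    fn test_s4d_output_dim() {
        let config = S4Config::new()
            .input_dim(3)
            .hidden_dim(64)
            .state_dim(16)
            .num_layers(2);
        let mut model = S4D::new(config).expect("Failed to create S4D model");
        let input = Array1::from_vec(vec![0.1, 0.2, 0.3]);
        let output = model.step(&input).expect("step failed");
        assert_eq!(output.len(), 3);
    }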
#[test]
fn test_s4_save_load_roundtrip() {
use std::sync::atomic::{AtomicU64, Ordering};
static S4_ROUNDTRIP_COUNTER: AtomicU64 = AtomicU64::new(0);
let uid = S4_ROUNDTRIP_COUNTER.fetch_add(1, Ordering::Relaxed);
let config = S4Config::new().hidden_dim(32).state_dim(8).num_layers(2);
let model = S4D::new(config).expect("Failed to create S4D model");
        let mut tmp = std::env::temp_dir();
        // Include the process id so concurrently running test binaries do not
        // collide on the same temp file.
        tmp.push(format!(
            "kizzasi_s4_roundtrip_test_{}_{}.json",
            std::process::id(),
            uid
        ));
model
.save_weights_json(&tmp)
.expect("save_weights_json failed");
let config2 = S4Config::new().hidden_dim(32).state_dim(8).num_layers(2);
let mut model2 = S4D::new(config2).expect("Failed to create second S4D model");
model2
.load_weights_json(&tmp)
.expect("load_weights_json failed");
let file = std::fs::File::open(&tmp).expect("temp file should exist");
let reloaded: std::collections::HashMap<String, Vec<f32>> =
serde_json::from_reader(file).expect("should deserialize");
        // 2 global projections + 2 layers x 6 kernel tensors = 14 keys.
        assert_eq!(reloaded.len(), 14, "unexpected number of weight keys");
let _ = std::fs::remove_file(&tmp);
}
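    // Added sketch: set_states must reject a state vector whose length does
    // not match num_layers. The count check fires before any per-layer shape
    // is consulted, so a single placeholder HiddenState suffices here.
    #[test]
    fn test_s4d_state_count_mismatch() {
        let config = S4Config::new().hidden_dim(32).state_dim(8).num_layers(2);
        let mut model = S4D::new(config).expect("Failed to create S4D model");
        let states = vec![HiddenState::new(32, 8)];
        assert!(model.set_states(states).is_err());
    }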
}