use crate::{CoreError, CoreResult, HiddenState};
use scirs2_core::ndarray::{s, Array1, Array2, Axis};
use serde::{Deserialize, Serialize};
/// Configuration for an S5 (Simplified State Space) layer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct S5Config {
    /// Dimensionality of the raw input vector fed to `forward`.
    pub d_input: usize,
    /// Width of the internal model representation (and of the layer output).
    pub d_model: usize,
    /// Size of the recurrent state vector kept by each SSM block.
    pub d_state: usize,
    /// Number of independent SSM blocks; must evenly divide `d_model`.
    pub n_blocks: usize,
    /// Flag for complex-valued state dynamics — not read by any code path
    /// visible in this file; NOTE(review): confirm where it is honored.
    pub use_complex: bool,
    /// Discretization time step used to convert the continuous-time
    /// (A, B) matrices into their discrete forms.
    pub dt: f32,
    /// Whether the input/output projections carry bias vectors.
    pub use_bias: bool,
    /// Dropout probability — stored but not applied anywhere in this file;
    /// NOTE(review): confirm dropout is handled elsewhere.
    pub dropout: f32,
    /// Epsilon added to the variance in layer normalization for stability.
    pub layer_norm_eps: f32,
}
impl Default for S5Config {
fn default() -> Self {
Self {
d_input: 1,
d_model: 256,
d_state: 64,
n_blocks: 8,
use_complex: false,
dt: 0.001,
use_bias: false,
dropout: 0.0,
layer_norm_eps: 1e-5,
}
}
}
impl S5Config {
    /// Creates a config with the given dimensions; every other field takes
    /// its `Default` value.
    pub fn new(d_input: usize, d_model: usize, d_state: usize) -> Self {
        Self {
            d_input,
            d_model,
            d_state,
            ..Default::default()
        }
    }

    /// Builder-style setter for the number of SSM blocks.
    pub fn with_blocks(mut self, n_blocks: usize) -> Self {
        self.n_blocks = n_blocks;
        self
    }

    /// Builder-style setter for the discretization step `dt`.
    pub fn with_dt(mut self, dt: f32) -> Self {
        self.dt = dt;
        self
    }

    /// Builder-style setter for the complex-dynamics flag.
    pub fn with_complex(mut self, use_complex: bool) -> Self {
        self.use_complex = use_complex;
        self
    }

    /// Builder-style setter for the dropout probability.
    pub fn with_dropout(mut self, dropout: f32) -> Self {
        self.dropout = dropout;
        self
    }

    /// Number of model channels handled by each block.
    ///
    /// Run `validate()` first: with `n_blocks == 0` this division panics.
    pub fn block_size(&self) -> usize {
        self.d_model / self.n_blocks
    }

    /// Checks the configuration for internal consistency.
    ///
    /// # Errors
    /// Returns `CoreError::InvalidConfig` when any dimension is zero, when
    /// `d_model` is not divisible by `n_blocks`, or when `dt` is not a
    /// strictly positive finite number.
    pub fn validate(&self) -> CoreResult<()> {
        if self.d_input == 0 {
            return Err(CoreError::InvalidConfig("d_input must be > 0".into()));
        }
        if self.d_model == 0 {
            return Err(CoreError::InvalidConfig("d_model must be > 0".into()));
        }
        if self.d_state == 0 {
            return Err(CoreError::InvalidConfig("d_state must be > 0".into()));
        }
        if self.n_blocks == 0 {
            return Err(CoreError::InvalidConfig("n_blocks must be > 0".into()));
        }
        if self.d_model % self.n_blocks != 0 {
            return Err(CoreError::InvalidConfig(
                "d_model must be divisible by n_blocks".into(),
            ));
        }
        // BUG FIX: `dt <= 0.0` alone misses NaN (every comparison with NaN is
        // false), so a NaN time step previously passed validation. Require a
        // finite, strictly positive step explicitly.
        if !(self.dt.is_finite() && self.dt > 0.0) {
            return Err(CoreError::InvalidConfig("dt must be > 0".into()));
        }
        Ok(())
    }
}
/// A single S5 layer: input projection -> layer norm -> `n_blocks`
/// independent state-space blocks -> output projection, with a residual
/// connection added when the input and output widths match.
pub struct S5Layer {
    /// Layer configuration (dimensions, dt, flags).
    config: S5Config,
    /// Input projection weights, shape (d_model, d_input).
    in_proj_w: Array2<f32>,
    /// Optional input projection bias (present when `use_bias`).
    in_proj_b: Option<Array1<f32>>,
    /// Output projection weights, shape (d_model, d_model).
    out_proj_w: Array2<f32>,
    /// Optional output projection bias (present when `use_bias`).
    out_proj_b: Option<Array1<f32>>,
    /// Continuous-time state matrices, one (d_state, d_state) per block;
    /// unused after construction (only the discrete forms are stepped).
    #[allow(dead_code)]
    a_matrices: Vec<Array2<f32>>,
    /// Continuous-time input matrices, one (d_state, block_size) per block.
    #[allow(dead_code)]
    b_matrices: Vec<Array2<f32>>,
    /// Readout (C) matrices, one (block_size, d_state) per block.
    c_matrices: Vec<Array2<f32>>,
    /// Per-block skip-through (D) vectors of length block_size.
    d_vectors: Vec<Array1<f32>>,
    /// Discretized state matrices (exact `exp(a*dt)` on the diagonal).
    a_bar: Vec<Array2<f32>>,
    /// Discretized input matrices (`B * dt`).
    b_bar: Vec<Array2<f32>>,
    /// Layer-norm scale, length d_model.
    norm_w: Array1<f32>,
    /// Layer-norm shift, length d_model.
    norm_b: Array1<f32>,
    /// Per-block recurrent hidden state; persists across `forward` calls.
    hidden_states: Vec<HiddenState>,
}
impl S5Layer {
pub fn new(config: S5Config) -> CoreResult<Self> {
config.validate()?;
use scirs2_core::random::thread_rng;
let mut rng = thread_rng();
let init_scale = (2.0 / config.d_model as f32).sqrt();
let block_size = config.block_size();
let in_proj_w = Array2::from_shape_fn((config.d_model, config.d_input), |_| {
(rng.random::<f32>() - 0.5) * 2.0 * init_scale
});
let in_proj_b = if config.use_bias {
Some(Array1::zeros(config.d_model))
} else {
None
};
let out_proj_w = Array2::from_shape_fn((config.d_model, config.d_model), |_| {
(rng.random::<f32>() - 0.5) * 2.0 * init_scale
});
let out_proj_b = if config.use_bias {
Some(Array1::zeros(config.d_model))
} else {
None
};
let mut a_matrices = Vec::with_capacity(config.n_blocks);
let mut b_matrices = Vec::with_capacity(config.n_blocks);
let mut c_matrices = Vec::with_capacity(config.n_blocks);
let mut d_vectors = Vec::with_capacity(config.n_blocks);
let mut hidden_states = Vec::with_capacity(config.n_blocks);
for _ in 0..config.n_blocks {
let a = Self::init_a_matrix(config.d_state)?;
a_matrices.push(a);
let b = Array2::from_shape_fn((config.d_state, block_size), |_| {
(rng.random::<f32>() - 0.5) * 2.0 * init_scale
});
b_matrices.push(b);
let c = Array2::from_shape_fn((block_size, config.d_state), |_| {
(rng.random::<f32>() - 0.5) * 2.0 * init_scale
});
c_matrices.push(c);
let d = Array1::from_elem(block_size, 0.01);
d_vectors.push(d);
hidden_states.push(HiddenState::new(1, config.d_state));
}
let (a_bar, b_bar) = Self::discretize_all(&a_matrices, &b_matrices, config.dt);
let norm_w = Array1::ones(config.d_model);
let norm_b = Array1::zeros(config.d_model);
Ok(Self {
config,
in_proj_w,
in_proj_b,
out_proj_w,
out_proj_b,
a_matrices,
b_matrices,
c_matrices,
d_vectors,
a_bar,
b_bar,
norm_w,
norm_b,
hidden_states,
})
}
fn init_a_matrix(d_state: usize) -> CoreResult<Array2<f32>> {
use scirs2_core::random::thread_rng;
let mut rng = thread_rng();
let mut a = Array2::zeros((d_state, d_state));
for i in 0..d_state {
let log_lambda = -rng.random::<f32>() * 4.0 - 1.0; a[[i, i]] = log_lambda.exp();
if i > 0 {
a[[i, i - 1]] = (rng.random::<f32>() - 0.5) * 0.1;
}
if i < d_state - 1 {
a[[i, i + 1]] = (rng.random::<f32>() - 0.5) * 0.1;
}
}
Ok(a)
}
fn discretize_all(
a_matrices: &[Array2<f32>],
b_matrices: &[Array2<f32>],
dt: f32,
) -> (Vec<Array2<f32>>, Vec<Array2<f32>>) {
let mut a_bar_vec = Vec::with_capacity(a_matrices.len());
let mut b_bar_vec = Vec::with_capacity(b_matrices.len());
for (a, b) in a_matrices.iter().zip(b_matrices.iter()) {
let (a_bar, b_bar) = Self::discretize_block(a, b, dt);
a_bar_vec.push(a_bar);
b_bar_vec.push(b_bar);
}
(a_bar_vec, b_bar_vec)
}
fn discretize_block(a: &Array2<f32>, b: &Array2<f32>, dt: f32) -> (Array2<f32>, Array2<f32>) {
let d_state = a.nrows();
let mut a_bar = Array2::zeros((d_state, d_state));
let mut b_bar = b.clone();
for i in 0..d_state {
for j in 0..d_state {
if i == j {
a_bar[[i, j]] = (a[[i, j]] * dt).exp();
} else {
a_bar[[i, j]] = a[[i, j]] * dt;
}
}
}
b_bar *= dt;
(a_bar, b_bar)
}
fn layer_norm(&self, x: &Array1<f32>) -> Array1<f32> {
let mean = x.mean().unwrap_or(0.0);
let variance = x.iter().map(|&xi| (xi - mean).powi(2)).sum::<f32>() / (x.len() as f32);
let std = (variance + self.config.layer_norm_eps).sqrt();
let normalized = x.mapv(|xi| (xi - mean) / std);
&normalized * &self.norm_w + &self.norm_b
}
fn block_step(&mut self, block_idx: usize, u: &Array1<f32>) -> CoreResult<Array1<f32>> {
let a_bar = &self.a_bar[block_idx];
let b_bar = &self.b_bar[block_idx];
let c = &self.c_matrices[block_idx];
let d = &self.d_vectors[block_idx];
let h_state = self.hidden_states[block_idx].state();
let h = h_state.row(0);
let h_new_1d = a_bar.dot(&h) + b_bar.dot(u);
let mut h_new = Array2::zeros((1, h_new_1d.len()));
h_new.row_mut(0).assign(&h_new_1d);
let y = c.dot(&h_new_1d) + d * u;
self.hidden_states[block_idx].update(h_new);
Ok(y)
}
pub fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
if x.len() != self.config.d_input {
return Err(CoreError::DimensionMismatch {
expected: self.config.d_input,
got: x.len(),
});
}
let mut u = self.in_proj_w.dot(x);
if let Some(ref bias) = self.in_proj_b {
u = &u + bias;
}
let u_norm = self.layer_norm(&u);
let block_size = self.config.block_size();
let mut y = Array1::zeros(self.config.d_model);
for block_idx in 0..self.config.n_blocks {
let start = block_idx * block_size;
let end = start + block_size;
let u_block = u_norm.slice(s![start..end]).to_owned();
let y_block = self.block_step(block_idx, &u_block)?;
y.slice_mut(s![start..end]).assign(&y_block);
}
let mut out = self.out_proj_w.dot(&y);
if let Some(ref bias) = self.out_proj_b {
out = &out + bias;
}
if x.len() == out.len() {
out = &out + x;
}
Ok(out)
}
pub fn forward_batch(&mut self, x: &Array2<f32>) -> CoreResult<Array2<f32>> {
let batch_size = x.nrows();
let mut outputs = Array2::zeros((batch_size, self.config.d_model));
for (i, input) in x.axis_iter(Axis(0)).enumerate() {
let output = self.forward(&input.to_owned())?;
outputs.row_mut(i).assign(&output);
}
Ok(outputs)
}
pub fn reset(&mut self) {
for state in &mut self.hidden_states {
state.reset();
}
}
pub fn config(&self) -> &S5Config {
&self.config
}
}
/// A stack of S5 layers applied in sequence; the first layer maps `d_input`
/// to `d_model` and all subsequent layers operate at `d_model` width.
pub struct S5Model {
    /// The layer stack, applied in order by `forward`.
    layers: Vec<S5Layer>,
    /// Configuration of the first layer (its `d_input` is the model input width).
    config: S5Config,
}
impl S5Model {
pub fn new(config: S5Config, n_layers: usize) -> CoreResult<Self> {
let mut layers = Vec::with_capacity(n_layers);
layers.push(S5Layer::new(config.clone())?);
for _ in 1..n_layers {
let mut layer_config = config.clone();
layer_config.d_input = config.d_model;
layers.push(S5Layer::new(layer_config)?);
}
Ok(Self { layers, config })
}
pub fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
let mut hidden = x.clone();
for layer in &mut self.layers {
hidden = layer.forward(&hidden)?;
}
Ok(hidden)
}
pub fn forward_batch(&mut self, x: &Array2<f32>) -> CoreResult<Array2<f32>> {
let mut hidden = x.clone();
for layer in &mut self.layers {
hidden = layer.forward_batch(&hidden)?;
}
Ok(hidden)
}
pub fn reset(&mut self) {
for layer in &mut self.layers {
layer.reset();
}
}
pub fn n_layers(&self) -> usize {
self.layers.len()
}
pub fn config(&self) -> &S5Config {
&self.config
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_s5_config() {
        let cfg = S5Config::new(10, 256, 64);
        assert_eq!(cfg.d_input, 10);
        assert_eq!(cfg.d_model, 256);
        assert_eq!(cfg.d_state, 64);
        // 256 channels over the default 8 blocks -> 32 channels per block.
        assert_eq!(cfg.block_size(), 32);
        assert!(cfg.validate().is_ok());
    }

    #[test]
    fn test_s5_config_validation() {
        // A zero model width must be rejected.
        let mut cfg = S5Config::new(10, 256, 64);
        cfg.d_model = 0;
        assert!(cfg.validate().is_err());
        // 255 channels cannot be split evenly into 8 blocks.
        let mut cfg = S5Config::new(10, 255, 64);
        cfg.n_blocks = 8;
        assert!(cfg.validate().is_err());
        // A negative discretization step must be rejected.
        let mut cfg = S5Config::new(10, 256, 64);
        cfg.dt = -0.1;
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_s5_layer_creation() {
        assert!(S5Layer::new(S5Config::new(10, 128, 32)).is_ok());
    }

    #[test]
    fn test_s5_forward() {
        let mut layer = S5Layer::new(S5Config::new(10, 64, 16)).unwrap();
        let x = Array1::from_vec(vec![0.1; 10]);
        let result = layer.forward(&x);
        assert!(result.is_ok());
        let y = result.unwrap();
        assert_eq!(y.len(), 64);
        assert!(y.iter().all(|&v| v.is_finite()));
    }

    #[test]
    fn test_s5_reset() {
        let mut layer = S5Layer::new(S5Config::new(10, 64, 16)).unwrap();
        let x = Array1::from_vec(vec![0.1; 10]);
        layer.forward(&x).unwrap();
        layer.forward(&x).unwrap();
        layer.reset();
        // Every block's hidden state must be all zeros after a reset.
        for state in &layer.hidden_states {
            assert!(state.state().iter().all(|&v| v == 0.0));
        }
    }

    #[test]
    fn test_s5_model() {
        let mut model = S5Model::new(S5Config::new(10, 64, 16), 3).unwrap();
        assert_eq!(model.n_layers(), 3);
        let x = Array1::from_vec(vec![0.1; 10]);
        let result = model.forward(&x);
        assert!(result.is_ok());
        let y = result.unwrap();
        assert_eq!(y.len(), 64);
        assert!(y.iter().all(|&v| v.is_finite()));
    }

    #[test]
    fn test_s5_batch() {
        let mut layer = S5Layer::new(S5Config::new(10, 64, 16)).unwrap();
        let batch = Array2::from_shape_fn((4, 10), |(i, j)| 0.1 * (i as f32 + j as f32));
        let result = layer.forward_batch(&batch);
        assert!(result.is_ok());
        let out = result.unwrap();
        assert_eq!(out.shape(), &[4, 64]);
        assert!(out.iter().all(|&v| v.is_finite()));
    }

    #[test]
    fn test_s5_no_nan() {
        let mut layer = S5Layer::new(S5Config::new(10, 64, 16)).unwrap();
        // Repeated steps with a constant input must never produce NaN.
        for _ in 0..10 {
            let x = Array1::from_elem(10, 0.5);
            let y = layer.forward(&x).unwrap();
            assert!(y.iter().all(|&v| !v.is_nan()));
        }
    }

    #[test]
    fn test_layer_norm() {
        let layer = S5Layer::new(S5Config::new(10, 64, 16)).unwrap();
        let x = Array1::from_vec((0..64).map(|i| i as f32).collect());
        let normalized = layer.layer_norm(&x);
        let mean = normalized.mean().unwrap();
        let std = (normalized.iter().map(|&v| v.powi(2)).sum::<f32>() / 64.0).sqrt();
        assert!((mean.abs()) < 1e-5, "Mean should be close to 0");
        assert!((std - 1.0).abs() < 1e-4, "Std should be close to 1");
    }

    #[test]
    fn test_discretization() {
        let a = Array2::from_shape_fn((4, 4), |(i, j)| if i == j { -0.5 } else { 0.0 });
        let b = Array2::from_shape_fn((4, 2), |_| 0.1);
        let dt = 0.01;
        let (a_bar, b_bar) = S5Layer::discretize_block(&a, &b, dt);
        // Diagonal entries use the exact exponential map.
        let expected_a = (-0.5 * dt).exp();
        for i in 0..4 {
            assert!((a_bar[[i, i]] - expected_a).abs() < 1e-5);
        }
        // B is scaled elementwise by dt.
        for i in 0..4 {
            for j in 0..2 {
                assert!((b_bar[[i, j]] - 0.1 * dt).abs() < 1e-6);
            }
        }
    }
}