kizzasi_model/
s5.rs

1//! S5: Simplified State Space Model
2//!
3//! S5 simplifies S4 by using a more efficient parameterization of the state space
4//! while maintaining competitive performance. Key simplifications include:
5//!
6//! - **Simplified initialization**: Easier parameter initialization
7//! - **Faster computation**: Reduced computational overhead
8//! - **Better numerical stability**: Improved gradient flow
9//! - **Diagonal state matrix**: Like S4D, but with optimized discretization
10//!
11//! # Architecture
12//!
13//! ```text
14//! Input → [Linear] → [SSM Block] → [Activation] → [LayerNorm] → Output
15//!                         ↓
16//!                     [State]
17//! ```
18//!
19//! # SSM Formulation
20//!
21//! Continuous-time:
22//! ```text
23//! h'(t) = Ah(t) + Bx(t)
24//! y(t) = Ch(t)
25//! ```
26//!
27//! Where A is diagonal and initialized more simply than S4.
28//!
29//! # References
30//!
31//! - S5 paper: https://arxiv.org/abs/2208.04933
//! - S4 paper: Efficiently Modeling Long Sequences with Structured State Spaces, https://arxiv.org/abs/2111.00396
33
34use crate::error::{ModelError, ModelResult};
35use crate::{AutoregressiveModel, ModelType};
36use kizzasi_core::{gelu, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor};
37use scirs2_core::ndarray::{Array1, Array2};
38use scirs2_core::random::{rng, Rng};
39
40#[allow(unused_imports)]
41use tracing::{debug, instrument, trace};
42
43/// Configuration for S5 model
44#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
45pub struct S5Config {
46    /// Input dimension
47    pub input_dim: usize,
48    /// Hidden dimension
49    pub hidden_dim: usize,
50    /// State dimension (typically 64-256)
51    pub state_dim: usize,
52    /// Number of layers
53    pub num_layers: usize,
54    /// Discretization step size (Δt)
55    pub dt: f32,
56    /// Block size for chunked computation
57    pub block_size: usize,
58}
59
60impl S5Config {
61    /// Create default S5 configuration
62    pub fn new(input_dim: usize, hidden_dim: usize, num_layers: usize) -> Self {
63        Self {
64            input_dim,
65            hidden_dim,
66            state_dim: 64,
67            num_layers,
68            dt: 0.001,
69            block_size: 64,
70        }
71    }
72
73    /// Validate configuration
74    pub fn validate(&self) -> ModelResult<()> {
75        if self.hidden_dim == 0 {
76            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
77        }
78        if self.state_dim == 0 {
79            return Err(ModelError::invalid_config("state_dim must be > 0"));
80        }
81        if self.num_layers == 0 {
82            return Err(ModelError::invalid_config("num_layers must be > 0"));
83        }
84        if self.dt <= 0.0 {
85            return Err(ModelError::invalid_config("dt must be > 0"));
86        }
87        if self.block_size == 0 {
88            return Err(ModelError::invalid_config("block_size must be > 0"));
89        }
90        Ok(())
91    }
92}
93
94/// S5 SSM block with diagonal state matrix
95#[allow(dead_code)]
96struct S5Block {
97    /// Diagonal of A matrix (log-space for stability)
98    log_a: Array1<f32>,
99    /// B matrix [state_dim, hidden_dim]
100    b_matrix: Array2<f32>,
101    /// C matrix [hidden_dim, state_dim]
102    c_matrix: Array2<f32>,
103    /// D skip connection [hidden_dim]
104    d_vec: Array1<f32>,
105    /// Discretization step
106    dt: f32,
107    /// Discretized A diagonal
108    a_bar: Array1<f32>,
109    /// Discretized B matrix
110    b_bar: Array2<f32>,
111    /// Current state [state_dim]
112    state: Array1<f32>,
113}
114
115impl S5Block {
116    /// Create new S5 block with simplified initialization
117    fn new(hidden_dim: usize, state_dim: usize, dt: f32) -> Self {
118        let mut rng = rng();
119
120        // Initialize log_a with uniform spacing (simplified vs S4's HiPPO)
121        let log_a = Array1::from_shape_fn(state_dim, |i| -((i + 1) as f32).ln());
122
123        // Initialize B and C with random values
124        let scale_b = (2.0 / (state_dim + hidden_dim) as f32).sqrt();
125        let b_matrix = Array2::from_shape_fn((state_dim, hidden_dim), |_| {
126            (rng.random::<f32>() - 0.5) * 2.0 * scale_b
127        });
128
129        let scale_c = (2.0 / (hidden_dim + state_dim) as f32).sqrt();
130        let c_matrix = Array2::from_shape_fn((hidden_dim, state_dim), |_| {
131            (rng.random::<f32>() - 0.5) * 2.0 * scale_c
132        });
133
134        // Initialize D (skip connection) to small values
135        let d_vec = Array1::from_shape_fn(hidden_dim, |_| rng.random::<f32>() * 0.01);
136
137        // Discretize using zero-order hold (ZOH)
138        let a_bar = log_a.mapv(|log_a_i| (dt * log_a_i.exp()).exp());
139        let b_bar = b_matrix.clone() * dt;
140
141        let state = Array1::zeros(state_dim);
142
143        Self {
144            log_a,
145            b_matrix,
146            c_matrix,
147            d_vec,
148            dt,
149            a_bar,
150            b_bar,
151            state,
152        }
153    }
154
155    /// Forward pass through S5 block
156    #[instrument(skip(self, x))]
157    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
158        // Update state: h[t] = A̅·h[t-1] + B̅·x[t]
159        self.state = &self.state * &self.a_bar + self.b_bar.dot(x);
160
161        // Compute output: y[t] = C·h[t] + D·x[t]
162        let y = self.c_matrix.dot(&self.state) + &self.d_vec * x;
163
164        Ok(y)
165    }
166
167    /// Reset the state
168    fn reset(&mut self) {
169        self.state.fill(0.0);
170    }
171}
172
173/// S5 layer with SSM block, activation, and normalization
174struct S5Layer {
175    /// Input projection
176    input_proj: Array2<f32>,
177    /// S5 SSM block
178    s5_block: S5Block,
179    /// Layer normalization
180    layer_norm: LayerNorm,
181    /// Output projection
182    output_proj: Array2<f32>,
183}
184
185impl S5Layer {
186    /// Create a new S5 layer
187    fn new(config: &S5Config) -> ModelResult<Self> {
188        let mut rng = rng();
189
190        // Input projection
191        let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
192        let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
193            (rng.random::<f32>() - 0.5) * 2.0 * scale
194        });
195
196        // S5 block
197        let s5_block = S5Block::new(config.hidden_dim, config.state_dim, config.dt);
198
199        // Layer normalization
200        let layer_norm = LayerNorm::new(config.hidden_dim, NormType::RMSNorm);
201
202        // Output projection
203        let scale = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
204        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
205            (rng.random::<f32>() - 0.5) * 2.0 * scale
206        });
207
208        Ok(Self {
209            input_proj,
210            s5_block,
211            layer_norm,
212            output_proj,
213        })
214    }
215
216    /// Forward pass
217    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
218        // Project input
219        let hidden = x.dot(&self.input_proj);
220
221        // S5 SSM block
222        let ssm_out = self.s5_block.forward(&hidden)?;
223
224        // Activation
225        let activated = gelu(&ssm_out);
226
227        // Layer norm
228        let normed = self.layer_norm.forward(&activated);
229
230        // Output projection with residual
231        let output = normed.dot(&self.output_proj) + x;
232
233        Ok(output)
234    }
235
236    /// Reset layer state
237    fn reset(&mut self) {
238        self.s5_block.reset();
239    }
240}
241
242/// S5 model with multiple layers
243pub struct S5 {
244    config: S5Config,
245    layers: Vec<S5Layer>,
246}
247
248impl S5 {
249    /// Create a new S5 model
250    #[instrument(skip(config), fields(input_dim = config.input_dim, hidden_dim = config.hidden_dim, num_layers = config.num_layers))]
251    pub fn new(config: S5Config) -> ModelResult<Self> {
252        debug!("Creating new S5 model");
253        config.validate()?;
254
255        let mut layers = Vec::with_capacity(config.num_layers);
256        for layer_idx in 0..config.num_layers {
257            trace!("Initializing S5 layer {}", layer_idx);
258            layers.push(S5Layer::new(&config)?);
259        }
260        debug!("Initialized {} S5 layers", layers.len());
261
262        debug!("S5 model created successfully");
263        Ok(Self { config, layers })
264    }
265
266    /// Get configuration
267    pub fn config(&self) -> &S5Config {
268        &self.config
269    }
270}
271
272impl SignalPredictor for S5 {
273    #[instrument(skip(self, input))]
274    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
275        let mut x = input.clone();
276
277        for layer in &mut self.layers {
278            x = layer.forward(&x)?;
279        }
280
281        Ok(x)
282    }
283
284    #[instrument(skip(self))]
285    fn reset(&mut self) {
286        debug!("Resetting S5 model state");
287        for layer in &mut self.layers {
288            layer.reset();
289        }
290    }
291
292    fn context_window(&self) -> usize {
293        // SSMs have theoretically infinite context via recurrence
294        usize::MAX
295    }
296}
297
298impl AutoregressiveModel for S5 {
299    fn hidden_dim(&self) -> usize {
300        self.config.hidden_dim
301    }
302
303    fn state_dim(&self) -> usize {
304        self.config.state_dim
305    }
306
307    fn num_layers(&self) -> usize {
308        self.config.num_layers
309    }
310
311    fn model_type(&self) -> ModelType {
312        ModelType::S4 // S5 is a variant of S4
313    }
314
315    fn get_states(&self) -> Vec<HiddenState> {
316        self.layers
317            .iter()
318            .map(|layer| {
319                // S5 uses 1D state, so expand to 2D for HiddenState
320                let state_1d = layer.s5_block.state.clone();
321                let state_2d = state_1d.insert_axis(scirs2_core::ndarray::Axis(0));
322                let mut hidden_state = HiddenState::new(
323                    self.config.hidden_dim,
324                    state_2d.len_of(scirs2_core::ndarray::Axis(1)),
325                );
326                hidden_state.update(state_2d);
327                hidden_state
328            })
329            .collect()
330    }
331
332    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
333        if states.len() != self.config.num_layers {
334            return Err(ModelError::state_count_mismatch(
335                "S5",
336                self.config.num_layers,
337                states.len(),
338            ));
339        }
340
341        for (layer, state) in self.layers.iter_mut().zip(states.iter()) {
342            // Convert from 2D back to 1D
343            let state_2d = state.state();
344            if state_2d.nrows() > 0 && state_2d.ncols() > 0 {
345                layer.s5_block.state = state_2d.row(0).to_owned();
346            }
347        }
348
349        Ok(())
350    }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    /// A default configuration must produce a model.
    #[test]
    fn test_s5_creation() {
        let config = S5Config::new(32, 64, 2);
        let model = S5::new(config);
        assert!(model.is_ok());
    }

    /// One step maps an `input_dim` vector to an `input_dim` vector.
    #[test]
    fn test_s5_forward() {
        let config = S5Config::new(32, 64, 2);
        let mut model = S5::new(config).expect("Failed to create S5 model");

        let input = Array1::from_vec(vec![1.0; 32]);
        let output = model.step(&input);
        assert!(output.is_ok());
        assert_eq!(output.expect("Failed to get output").len(), 32);
    }

    /// After `reset`, replaying the same input reproduces the first output.
    #[test]
    fn test_s5_reset() {
        let config = S5Config::new(32, 64, 2);
        let mut model = S5::new(config).expect("Failed to create S5 model");

        let input = Array1::from_vec(vec![1.0; 32]);
        let output1 = model.step(&input).expect("Failed to get output1");

        model.reset();

        let output2 = model.step(&input).expect("Failed to get output2");
        assert_eq!(output2.len(), 32);
        // Both steps start from an all-zero state with identical weights, so
        // the outputs must agree element by element.
        for (idx, (first, second)) in output1.iter().zip(output2.iter()).enumerate() {
            assert!(
                (first - second).abs() <= 1e-5,
                "output differs after reset at index {}: {} vs {}",
                idx,
                first,
                second
            );
        }
    }

    /// Zero-valued fields and a non-positive step size are rejected.
    #[test]
    fn test_invalid_config() {
        let mut config = S5Config::new(32, 64, 2);
        config.state_dim = 0;
        assert!(config.validate().is_err());

        let mut config = S5Config::new(32, 64, 2);
        config.num_layers = 0;
        assert!(config.validate().is_err());

        let mut config = S5Config::new(32, 64, 2);
        config.dt = 0.0;
        assert!(config.validate().is_err());
    }

    /// `get_states` yields one state per layer and `set_states` rejects a
    /// vector of the wrong length.
    #[test]
    fn test_state_count() {
        let config = S5Config::new(32, 64, 2);
        let mut model = S5::new(config).expect("Failed to create S5 model");

        assert_eq!(model.get_states().len(), 2);
        assert!(model.set_states(Vec::new()).is_err());
    }
}