// kizzasi_model/s4.rs

1//! S4 and S4D: Structured State Space Models
2//!
3//! S4 (Structured State Space Sequence model) is a deep learning architecture
4//! that leverages properties of continuous-time linear time-invariant (LTI) systems
5//! for efficient and effective sequence modeling.
6//!
7//! # S4D Variant
8//!
9//! S4D simplifies S4 by using diagonal state matrices, which:
10//! - Reduces computation from O(N²) to O(N)
11//! - Simplifies implementation while maintaining quality
12//! - Makes the model more interpretable
13//!
14//! # SSM Formulation
15//!
16//! Continuous-time:
17//! ```text
18//! h'(t) = A h(t) + B u(t)
19//! y(t) = C h(t) + D u(t)
20//! ```
21//!
22//! Discrete-time (after discretization):
23//! ```text
24//! h[k] = A̅ h[k-1] + B̅ u[k]
25//! y[k] = C̅ h[k] + D̅ u[k]
26//! ```
27//!
28//! Where:
29//! - A̅ = exp(Δ·A) for ZOH (Zero-Order Hold)
30//! - B̅ = (A̅ - I) A^(-1) B
31//!
32//! # S4D Architecture
33//!
34//! For S4D, A is diagonal: A = diag(-exp(α₁), -exp(α₂), ..., -exp(αₙ))
35//! This makes discretization and computation much simpler.
36//!
37//! # References
38//!
39//! - S4 paper: https://arxiv.org/abs/2111.00396
40//! - S4D paper: https://arxiv.org/abs/2206.11893
41
42use crate::error::{ModelError, ModelResult};
43use crate::{AutoregressiveModel, ModelType};
44use kizzasi_core::{
45    gelu, CausalConv1d, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor,
46};
47use scirs2_core::ndarray::{Array1, Array2};
48use scirs2_core::random::{rng, Rng};
49#[allow(unused_imports)]
50use tracing::{debug, instrument, trace};
51
/// Configuration for S4D
///
/// Constructed with [`S4Config::new`] / builder-style setters and checked
/// with `validate()` before model construction.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct S4Config {
    /// Input dimension
    pub input_dim: usize,
    /// Hidden dimension (d_model)
    pub hidden_dim: usize,
    /// State dimension (N)
    pub state_dim: usize,
    /// Number of layers
    pub num_layers: usize,
    /// Dropout rate
    /// NOTE(review): not referenced by the inference path in this file —
    /// presumably consumed by a training loop elsewhere; confirm.
    pub dropout: f32,
    /// Minimum discretization step size (Δ); lower bound for random init
    pub dt_min: f32,
    /// Maximum discretization step size (Δ); upper bound for random init
    pub dt_max: f32,
    /// Use diagonal state matrix (S4D)
    pub use_diagonal: bool,
    /// Use RMSNorm instead of LayerNorm
    pub use_rms_norm: bool,
}
73
74impl Default for S4Config {
75    fn default() -> Self {
76        Self {
77            input_dim: 1,
78            hidden_dim: 512,
79            state_dim: 64,
80            num_layers: 6,
81            dropout: 0.0,
82            dt_min: 0.001,
83            dt_max: 0.1,
84            use_diagonal: true, // S4D by default
85            use_rms_norm: true,
86        }
87    }
88}
89
90impl S4Config {
91    /// Create a new S4 configuration
92    pub fn new() -> Self {
93        Self::default()
94    }
95
96    /// Set input dimension
97    pub fn input_dim(mut self, dim: usize) -> Self {
98        self.input_dim = dim;
99        self
100    }
101
102    /// Set hidden dimension
103    pub fn hidden_dim(mut self, dim: usize) -> Self {
104        self.hidden_dim = dim;
105        self
106    }
107
108    /// Set state dimension
109    pub fn state_dim(mut self, dim: usize) -> Self {
110        self.state_dim = dim;
111        self
112    }
113
114    /// Set number of layers
115    pub fn num_layers(mut self, n: usize) -> Self {
116        self.num_layers = n;
117        self
118    }
119
120    /// Use diagonal state matrix (S4D)
121    pub fn diagonal(mut self, use_diagonal: bool) -> Self {
122        self.use_diagonal = use_diagonal;
123        self
124    }
125
126    /// Validate the configuration
127    pub fn validate(&self) -> ModelResult<()> {
128        if self.hidden_dim == 0 {
129            return Err(ModelError::invalid_config("hidden_dim must be > 0"));
130        }
131        if self.state_dim == 0 {
132            return Err(ModelError::invalid_config("state_dim must be > 0"));
133        }
134        if self.num_layers == 0 {
135            return Err(ModelError::invalid_config("num_layers must be > 0"));
136        }
137        if self.dt_min <= 0.0 || self.dt_max <= 0.0 {
138            return Err(ModelError::invalid_config("dt_min and dt_max must be > 0"));
139        }
140        if self.dt_min > self.dt_max {
141            return Err(ModelError::invalid_config("dt_min must be <= dt_max"));
142        }
143        Ok(())
144    }
145}
146
/// S4D kernel: Diagonal state space model
///
/// Holds the continuous-time SSM parameters in the S4D parameterization
/// plus the running discrete hidden state for recurrent inference.
struct S4DKernel {
    // Output width of the kernel (d_model).
    hidden_dim: usize,
    // State size N (per hidden dimension).
    state_dim: usize,

    /// Log of diagonal A matrix elements
    /// A = diag(-exp(log_a[0]), -exp(log_a[1]), ..., -exp(log_a[N-1]))
    /// Storing logs keeps A strictly negative (stable dynamics) regardless
    /// of the stored value.
    log_a: Array1<f32>,

    /// B matrix (input-to-state) [state_dim, hidden_dim]
    b_matrix: Array2<f32>,

    /// C matrix (state-to-output) [hidden_dim, state_dim]
    c_matrix: Array2<f32>,

    /// D matrix (skip connection) [hidden_dim]
    d_skip: Array1<f32>,

    /// Discretization step size Δ, per hidden dimension, stored in log
    /// space (learnable)
    log_dt: Array1<f32>,

    /// Hidden state h[k], reset to zero at sequence boundaries
    state: Array2<f32>, // [hidden_dim, state_dim]
}
171
172impl S4DKernel {
173    fn new(config: &S4Config) -> ModelResult<Self> {
174        let mut rng = rng();
175
176        // Initialize diagonal A with HiPPO initialization
177        // A = diag(-1/2, -3/2, -5/2, ..., -(2N-1)/2)
178        // Store log of absolute value since we negate later: A[n] = -exp(log_a[n])
179        let log_a = Array1::from_shape_fn(config.state_dim, |n| ((2 * n + 1) as f32 / 2.0).ln());
180
181        // Initialize B with random values
182        let scale = (1.0 / config.state_dim as f32).sqrt();
183        let b_matrix = Array2::from_shape_fn((config.state_dim, config.hidden_dim), |_| {
184            (rng.random::<f32>() - 0.5) * 2.0 * scale
185        });
186
187        // Initialize C with random values
188        let c_matrix = Array2::from_shape_fn((config.hidden_dim, config.state_dim), |_| {
189            (rng.random::<f32>() - 0.5) * 2.0 * scale
190        });
191
192        // Initialize D (skip connection)
193        let d_skip = Array1::ones(config.hidden_dim);
194
195        // Initialize discretization step (log scale)
196        let log_dt = Array1::from_shape_fn(config.hidden_dim, |_| {
197            let dt = config.dt_min + rng.random::<f32>() * (config.dt_max - config.dt_min);
198            dt.ln()
199        });
200
201        // Initialize state
202        let state = Array2::zeros((config.hidden_dim, config.state_dim));
203
204        Ok(Self {
205            hidden_dim: config.hidden_dim,
206            state_dim: config.state_dim,
207            log_a,
208            b_matrix,
209            c_matrix,
210            d_skip,
211            log_dt,
212            state,
213        })
214    }
215
216    /// Discretize continuous SSM to discrete SSM using ZOH (Zero-Order Hold)
217    ///
218    /// For diagonal A:
219    /// A̅[i] = exp(Δ·A[i])
220    /// B̅[i] = B[i] * (1 - A̅[i]) / (-A[i])
221    fn discretize(&self, dt: f32) -> (Array1<f32>, Array2<f32>) {
222        let mut a_bar = Array1::zeros(self.state_dim);
223        let mut b_bar = Array2::zeros(self.b_matrix.raw_dim());
224
225        for i in 0..self.state_dim {
226            // A[i] = -exp(log_a[i])
227            let a_i = -self.log_a[i].exp();
228
229            // A̅[i] = exp(Δ·A[i])
230            a_bar[i] = (dt * a_i).exp();
231
232            // B̅[i, :] = B[i, :] * (1 - A̅[i]) / (-A[i])
233            let scale = (1.0 - a_bar[i]) / (-a_i);
234            for j in 0..self.hidden_dim {
235                b_bar[[i, j]] = self.b_matrix[[i, j]] * scale;
236            }
237        }
238
239        (a_bar, b_bar)
240    }
241
242    /// Forward step: compute next state and output
243    fn forward_step(&mut self, u: &Array1<f32>) -> CoreResult<Array1<f32>> {
244        let batch_size = u.len().min(self.hidden_dim);
245
246        // Get discretization parameters (per dimension)
247        let mut output = Array1::zeros(batch_size);
248
249        for dim in 0..batch_size {
250            let dt = self.log_dt[dim].exp();
251            let (a_bar, b_bar) = self.discretize(dt);
252
253            // Update state: h[k] = A̅ ⊙ h[k-1] + B̅ u[k]
254            // Since A̅ is diagonal, this is element-wise multiplication
255            for i in 0..self.state_dim {
256                let bu = if dim < b_bar.shape()[1] {
257                    b_bar[[i, dim]] * u[dim]
258                } else {
259                    0.0
260                };
261                self.state[[dim, i]] = a_bar[i] * self.state[[dim, i]] + bu;
262            }
263
264            // Compute output: y[k] = C h[k] + D u[k]
265            let mut c_h = 0.0;
266            for i in 0..self.state_dim {
267                c_h += self.c_matrix[[dim, i]] * self.state[[dim, i]];
268            }
269            output[dim] = c_h + self.d_skip[dim] * u[dim];
270        }
271
272        Ok(output)
273    }
274
275    fn reset(&mut self) {
276        self.state.fill(0.0);
277    }
278}
279
/// S4D Layer
///
/// One residual block: pre-norm → short causal convolution → S4D SSM →
/// GELU → output projection → residual add (see `S4DLayer::forward`).
struct S4DLayer {
    // Pre-normalization (RMSNorm or LayerNorm per config).
    norm: LayerNorm,
    // Diagonal state-space kernel (global/long-range mixing).
    s4_kernel: S4DKernel,
    // Short causal convolution for local context (constructed with a
    // third argument of 3 — presumably the kernel size; confirm against
    // `CausalConv1d::new`).
    conv: CausalConv1d,
    // Linear output projection, [hidden_dim, hidden_dim].
    output_proj: Array2<f32>,
}
287
288impl S4DLayer {
289    fn new(config: &S4Config) -> ModelResult<Self> {
290        let norm_type = if config.use_rms_norm {
291            NormType::RMSNorm
292        } else {
293            NormType::LayerNorm
294        };
295
296        let norm = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);
297        let s4_kernel = S4DKernel::new(config)?;
298
299        // Short convolution for local context
300        let conv = CausalConv1d::new(config.hidden_dim, config.hidden_dim, 3);
301
302        // Output projection
303        let mut rng = rng();
304        let scale = (2.0 / config.hidden_dim as f32).sqrt();
305        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
306            (rng.random::<f32>() - 0.5) * 2.0 * scale
307        });
308
309        Ok(Self {
310            norm,
311            s4_kernel,
312            conv,
313            output_proj,
314        })
315    }
316
317    fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
318        // 1. Normalize
319        let x_norm = self.norm.forward(x);
320
321        // 2. Short convolution
322        let x_vec = x_norm.to_vec();
323        let conv_out = self.conv.forward_step(&x_vec);
324        let x_conv = Array1::from_vec(conv_out);
325
326        // 3. S4D SSM
327        let ssm_out = self.s4_kernel.forward_step(&x_conv)?;
328
329        // 4. Activation
330        let activated = gelu(&ssm_out);
331
332        // 5. Output projection
333        let mut projected = Array1::zeros(x.len().min(self.output_proj.shape()[0]));
334        for i in 0..projected.len() {
335            let mut sum = 0.0;
336            for j in 0..activated.len().min(self.output_proj.shape()[1]) {
337                sum += self.output_proj[[i, j]] * activated[j];
338            }
339            projected[i] = sum;
340        }
341
342        // 6. Residual connection
343        let mut output = x.clone();
344        for i in 0..output.len().min(projected.len()) {
345            output[i] += projected[i];
346        }
347
348        Ok(output)
349    }
350
351    fn reset(&mut self) {
352        self.s4_kernel.reset();
353    }
354}
355
/// S4D model
///
/// A stack of S4D residual layers sandwiched between a linear input
/// projection (input_dim → hidden_dim) and output projection
/// (hidden_dim → input_dim), with a final normalization before readout.
pub struct S4D {
    // Model hyperparameters (immutable after construction).
    config: S4Config,
    // Residual S4D blocks, applied in order.
    layers: Vec<S4DLayer>,
    // Final normalization before the output projection.
    ln_out: LayerNorm,
    // Input projection, [input_dim, hidden_dim].
    input_proj: Array2<f32>,
    // Output projection, [hidden_dim, input_dim].
    output_proj: Array2<f32>,
}
364
impl S4D {
    /// Create a new S4D model
    ///
    /// Validates the configuration, builds `num_layers` residual blocks,
    /// and randomly initializes both projections (Xavier-style uniform).
    ///
    /// # Errors
    ///
    /// Returns an invalid-config error if `config.validate()` fails.
    pub fn new(config: S4Config) -> ModelResult<Self> {
        config.validate()?;

        // Initialize layers
        let mut layers = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            layers.push(S4DLayer::new(&config)?);
        }

        // Output layer normalization
        let norm_type = if config.use_rms_norm {
            NormType::RMSNorm
        } else {
            NormType::LayerNorm
        };
        let ln_out = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);

        // Initialize input/output projections
        // Xavier-style scale: sqrt(2 / (fan_in + fan_out)).
        let mut rng = rng();
        let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
        let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        // Same (fan_in + fan_out) sum, so this yields the identical scale
        // value; recomputed only for symmetry with the block above.
        let scale = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        Ok(Self {
            config,
            layers,
            ln_out,
            input_proj,
            output_proj,
        })
    }

    /// Get the configuration
    pub fn config(&self) -> &S4Config {
        &self.config
    }

    /// Load weights from a SafeTensors model file
    ///
    /// Every tensor is optional: entries missing from the loader leave the
    /// corresponding random initialization from [`S4D::new`] in place.
    ///
    /// NOTE(review): loaded tensor shapes are not validated against the
    /// config here — confirm that `ModelLoader` or an upstream caller
    /// checks them, since mismatched shapes would surface later as panics
    /// in the forward path.
    ///
    /// # Weight Naming Convention
    ///
    /// The following tensor names are expected:
    /// - `input_proj`: Input projection matrix (input_dim, hidden_dim)
    /// - `output_proj`: Output projection matrix (hidden_dim, input_dim)
    /// - `ln_out.weight`: Output layer norm weight (gamma)
    /// - `ln_out.bias`: Output layer norm bias (beta, optional)
    ///
    /// For each layer i:
    /// - `layers.{i}.norm.weight`: Layer normalization weight
    /// - `layers.{i}.norm.bias`: Layer normalization bias (optional)
    /// - `layers.{i}.output_proj`: Output projection matrix
    ///
    /// S4D kernel parameters:
    /// - `layers.{i}.s4_kernel.log_a`: Log of diagonal A matrix
    /// - `layers.{i}.s4_kernel.b_matrix`: B matrix (input-to-state)
    /// - `layers.{i}.s4_kernel.c_matrix`: C matrix (state-to-output)
    /// - `layers.{i}.s4_kernel.d_skip`: D skip connection
    /// - `layers.{i}.s4_kernel.log_dt`: Log of discretization step size
    pub fn load_weights(&mut self, loader: &crate::loader::ModelLoader) -> ModelResult<()> {
        // Load input/output projections
        if loader.has_tensor("input_proj") {
            self.input_proj = loader.load_array2("input_proj")?;
        }
        if loader.has_tensor("output_proj") {
            self.output_proj = loader.load_array2("output_proj")?;
        }

        // Load output layer norm
        if loader.has_tensor("ln_out.weight") {
            let weight = loader.load_array1("ln_out.weight")?;
            self.ln_out.set_gamma(weight);
        }
        if loader.has_tensor("ln_out.bias") {
            let bias = loader.load_array1("ln_out.bias")?;
            self.ln_out.set_beta(bias);
        }

        // Load each layer's weights
        for (i, layer) in self.layers.iter_mut().enumerate() {
            let prefix = format!("layers.{}", i);

            // Load layer norm
            if loader.has_tensor(&format!("{}.norm.weight", prefix)) {
                let weight = loader.load_array1(&format!("{}.norm.weight", prefix))?;
                layer.norm.set_gamma(weight);
            }
            if loader.has_tensor(&format!("{}.norm.bias", prefix)) {
                let bias = loader.load_array1(&format!("{}.norm.bias", prefix))?;
                layer.norm.set_beta(bias);
            }

            // Load output projection
            if loader.has_tensor(&format!("{}.output_proj", prefix)) {
                layer.output_proj = loader.load_array2(&format!("{}.output_proj", prefix))?;
            }

            // Load S4D kernel parameters
            let kernel_prefix = format!("{}.s4_kernel", prefix);
            if loader.has_tensor(&format!("{}.log_a", kernel_prefix)) {
                layer.s4_kernel.log_a = loader.load_array1(&format!("{}.log_a", kernel_prefix))?;
            }
            if loader.has_tensor(&format!("{}.b_matrix", kernel_prefix)) {
                layer.s4_kernel.b_matrix =
                    loader.load_array2(&format!("{}.b_matrix", kernel_prefix))?;
            }
            if loader.has_tensor(&format!("{}.c_matrix", kernel_prefix)) {
                layer.s4_kernel.c_matrix =
                    loader.load_array2(&format!("{}.c_matrix", kernel_prefix))?;
            }
            if loader.has_tensor(&format!("{}.d_skip", kernel_prefix)) {
                layer.s4_kernel.d_skip =
                    loader.load_array1(&format!("{}.d_skip", kernel_prefix))?;
            }
            if loader.has_tensor(&format!("{}.log_dt", kernel_prefix)) {
                layer.s4_kernel.log_dt =
                    loader.load_array1(&format!("{}.log_dt", kernel_prefix))?;
            }

            // Load convolution weights [out_channels, in_channels, kernel_size]
            if loader.has_tensor(&format!("{}.conv.weight", prefix)) {
                let conv_weights = loader.load_array3(&format!("{}.conv.weight", prefix))?;
                layer.conv.set_weights(conv_weights);
            }
            if loader.has_tensor(&format!("{}.conv.bias", prefix)) {
                let conv_bias = loader.load_array1(&format!("{}.conv.bias", prefix))?;
                layer.conv.set_bias(conv_bias.to_vec());
            }
        }

        Ok(())
    }

    /// Save weights to a SafeTensors model file (stub for future implementation)
    ///
    /// # Errors
    ///
    /// Always returns a load-error variant until saving is implemented.
    #[allow(unused_variables)]
    pub fn save_weights(&self, path: &str) -> ModelResult<()> {
        // TODO: Implement SafeTensors saving
        Err(ModelError::simple_load_error(
            "S4D save_weights not yet implemented".to_string(),
        ))
    }
}
514
515impl SignalPredictor for S4D {
516    #[instrument(skip(self, input))]
517    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
518        // Project input to hidden dimension
519        let mut hidden = input.dot(&self.input_proj);
520
521        // Pass through each layer
522        for layer in &mut self.layers {
523            hidden = layer.forward(&hidden)?;
524        }
525
526        // Final layer normalization
527        hidden = self.ln_out.forward(&hidden);
528
529        // Project back to input dimension
530        let output = hidden.dot(&self.output_proj);
531        Ok(output)
532    }
533
534    fn reset(&mut self) {
535        for layer in &mut self.layers {
536            layer.reset();
537        }
538    }
539
540    fn context_window(&self) -> usize {
541        // S4D has theoretically infinite context via recurrence
542        usize::MAX
543    }
544}
545
546impl AutoregressiveModel for S4D {
547    fn hidden_dim(&self) -> usize {
548        self.config.hidden_dim
549    }
550
551    fn state_dim(&self) -> usize {
552        self.config.state_dim
553    }
554
555    fn num_layers(&self) -> usize {
556        self.config.num_layers
557    }
558
559    fn model_type(&self) -> ModelType {
560        ModelType::S4D
561    }
562
563    fn get_states(&self) -> Vec<HiddenState> {
564        self.layers
565            .iter()
566            .map(|layer| {
567                let state = layer.s4_kernel.state.clone();
568                let mut hs = HiddenState::new(state.shape()[0], state.shape()[1]);
569                hs.update(state);
570                hs
571            })
572            .collect()
573    }
574
575    fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
576        if states.len() != self.config.num_layers {
577            return Err(ModelError::state_count_mismatch(
578                "S4D",
579                self.config.num_layers,
580                states.len(),
581            ));
582        }
583
584        for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
585            layer.s4_kernel.state = states[layer_idx].state().clone();
586        }
587
588        Ok(())
589    }
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder-style construction stores every dimension and validates.
    #[test]
    fn test_s4d_config() {
        let cfg = S4Config::new().hidden_dim(256).state_dim(64).num_layers(4);

        assert_eq!((cfg.hidden_dim, cfg.state_dim), (256, 64));
        assert!(cfg.validate().is_ok());
    }

    /// A valid config produces a model without error.
    #[test]
    fn test_s4d_creation() {
        let cfg = S4Config::new().hidden_dim(128).state_dim(32);
        assert!(S4D::new(cfg).is_ok());
    }

    /// A single forward step on a scalar input succeeds.
    #[test]
    fn test_s4d_forward() {
        let cfg = S4Config::new().hidden_dim(64).state_dim(16).num_layers(2);
        let mut model = S4D::new(cfg).expect("Failed to create S4D model");

        let result = model.step(&Array1::from_vec(vec![0.5]));
        assert!(result.is_ok());
    }

    /// dt_max below dt_min must be rejected by validation.
    #[test]
    fn test_invalid_dt() {
        let cfg = S4Config {
            dt_min: 0.1,
            dt_max: 0.01, // max < min
            ..S4Config::default()
        };
        assert!(cfg.validate().is_err());
    }
}