// kizzasi_core/ssm.rs

//! State Space Model implementations
//!
//! Implements Mamba-style selective SSM for O(1) inference steps.
//! Uses SIMD-optimized operations for high performance.

6#[cfg(not(feature = "std"))]
7use alloc::vec::Vec;
8
9use crate::config::KizzasiConfig;
10use crate::embedding::ContinuousEmbedding;
11use crate::error::CoreResult;
12use crate::simd;
13use crate::state::HiddenState;
14use crate::SignalPredictor;
15use scirs2_core::ndarray::{Array1, Array2};
16use scirs2_core::random::thread_rng;
17use serde::{Deserialize, Serialize};
18
19/// Trait for state space model implementations
20pub trait StateSpaceModel {
21    /// Perform a single recurrence step
22    fn recurrence_step(
23        &self,
24        input: &Array1<f32>,
25        state: &mut HiddenState,
26    ) -> CoreResult<Array1<f32>>;
27
28    /// Get model configuration
29    fn config(&self) -> &KizzasiConfig;
30}
31
32/// Selective State Space Model (Mamba-style)
33///
34/// Implements the selective scan mechanism from Mamba for
35/// content-aware state transitions.
36#[derive(Debug, Clone, Serialize, Deserialize)]
37pub struct SelectiveSSM {
38    config: KizzasiConfig,
39    embedding: ContinuousEmbedding,
40    state: HiddenState,
41    // SSM parameters (A, B, C, D matrices per layer)
42    a_matrices: Vec<Array2<f32>>,
43    b_matrices: Vec<Array2<f32>>,
44    c_matrices: Vec<Array2<f32>>,
45    d_vectors: Vec<Array1<f32>>,
46    // Output projection
47    output_proj: Array2<f32>,
48}
49
50impl SelectiveSSM {
51    /// Create a new SelectiveSSM from configuration
52    pub fn new(config: KizzasiConfig) -> CoreResult<Self> {
53        let hidden_dim = config.get_hidden_dim();
54        let state_dim = config.get_state_dim();
55        let num_layers = config.get_num_layers();
56        let input_dim = config.get_input_dim();
57        let output_dim = config.get_output_dim();
58
59        // Initialize embedding layer
60        let embedding = ContinuousEmbedding::new(input_dim, hidden_dim);
61
62        // Initialize hidden state
63        let state = HiddenState::new(hidden_dim, state_dim);
64
65        // Initialize SSM matrices for each layer
66        let mut rng = thread_rng();
67        let scale = 0.01;
68        let mut a_matrices = Vec::with_capacity(num_layers);
69        let mut b_matrices = Vec::with_capacity(num_layers);
70        let mut c_matrices = Vec::with_capacity(num_layers);
71        let mut d_vectors = Vec::with_capacity(num_layers);
72
73        for _ in 0..num_layers {
74            // A matrix: state transition (initialized for stability)
75            let a = Array2::from_shape_fn((hidden_dim, state_dim), |_| {
76                -0.5 + rng.random::<f32>() * scale
77            });
78            a_matrices.push(a);
79
80            // B matrix: input projection to state
81            let b = Array2::from_shape_fn((hidden_dim, state_dim), |_| {
82                (rng.random::<f32>() - 0.5) * scale
83            });
84            b_matrices.push(b);
85
86            // C matrix: state to output projection
87            let c = Array2::from_shape_fn((hidden_dim, state_dim), |_| {
88                (rng.random::<f32>() - 0.5) * scale
89            });
90            c_matrices.push(c);
91
92            // D vector: skip connection
93            let d = Array1::ones(hidden_dim);
94            d_vectors.push(d);
95        }
96
97        // Output projection
98        let output_proj = Array2::from_shape_fn((hidden_dim, output_dim), |_| {
99            (rng.random::<f32>() - 0.5) * scale
100        });
101
102        Ok(Self {
103            config,
104            embedding,
105            state,
106            a_matrices,
107            b_matrices,
108            c_matrices,
109            d_vectors,
110            output_proj,
111        })
112    }
113
114    /// Get a reference to the hidden state
115    pub fn get_state(&self) -> &HiddenState {
116        &self.state
117    }
118
119    /// Get a mutable reference to the hidden state
120    pub fn get_state_mut(&mut self) -> &mut HiddenState {
121        &mut self.state
122    }
123
124    /// Set the hidden state
125    pub fn set_state(&mut self, state: HiddenState) {
126        self.state = state;
127    }
128
129    /// Get the step count from the hidden state
130    pub fn step_count(&self) -> usize {
131        self.state.step_count()
132    }
133
134    /// Get a reference to the embedding layer
135    pub fn embedding(&self) -> &ContinuousEmbedding {
136        &self.embedding
137    }
138
139    /// Get a reference to the A matrices
140    pub fn a_matrices(&self) -> &Vec<Array2<f32>> {
141        &self.a_matrices
142    }
143
144    /// Get a reference to the B matrices
145    pub fn b_matrices(&self) -> &Vec<Array2<f32>> {
146        &self.b_matrices
147    }
148
149    /// Get a reference to the C matrices
150    pub fn c_matrices(&self) -> &Vec<Array2<f32>> {
151        &self.c_matrices
152    }
153
154    /// Get a reference to the D vectors
155    pub fn d_vectors(&self) -> &Vec<Array1<f32>> {
156        &self.d_vectors
157    }
158
159    /// Get a reference to the output projection matrix
160    pub fn output_proj(&self) -> &Array2<f32> {
161        &self.output_proj
162    }
163
164    /// Discretize continuous SSM parameters (standard precision)
165    #[allow(dead_code)]
166    fn discretize(
167        &self,
168        delta: f32,
169        a: &Array2<f32>,
170        b: &Array2<f32>,
171    ) -> (Array2<f32>, Array2<f32>) {
172        // Zero-order hold discretization
173        // A_bar = exp(delta * A)
174        // B_bar = (A^-1) * (A_bar - I) * B ≈ delta * B for small delta
175        let a_bar = a.mapv(|x| (delta * x).exp());
176        let b_bar = b.mapv(|x| delta * x);
177        (a_bar, b_bar)
178    }
179
180    /// Selective scan step for a single layer (SIMD-optimized)
181    fn selective_scan_step(
182        &self,
183        layer_idx: usize,
184        x: &Array1<f32>,
185        h: &mut Array2<f32>,
186    ) -> Array1<f32> {
187        let a = &self.a_matrices[layer_idx];
188        let b = &self.b_matrices[layer_idx];
189        let c = &self.c_matrices[layer_idx];
190        let d = &self.d_vectors[layer_idx];
191
192        // Compute input-dependent delta (simplified)
193        let delta = 0.1; // In full implementation, this is learned
194
195        // Discretize using SIMD-optimized exp
196        let (a_bar, b_bar) = self.discretize_simd(delta, a, b);
197
198        // State update: h = A_bar * h + B_bar * x (SIMD-optimized per row)
199        for i in 0..h.nrows() {
200            let x_val = x[i];
201            let row_len = h.ncols();
202            let mut h_row = h.row_mut(i);
203            let a_row = a_bar.row(i);
204            let b_row = b_bar.row(i);
205
206            // Use SIMD FMA for each row element
207            for j in 0..row_len {
208                h_row[j] = a_row[j].mul_add(h_row[j], b_row[j] * x_val);
209            }
210        }
211
212        // Output: y = C * h + D * x (SIMD-optimized dot products)
213        let mut y = Array1::zeros(x.len());
214        for i in 0..y.len() {
215            let h_row = h.row(i);
216            let c_row = c.row(i);
217            y[i] = simd::dot_view(h_row, c_row) + d[i] * x[i];
218        }
219
220        y
221    }
222
223    /// SIMD-optimized discretization using fast_exp
224    fn discretize_simd(
225        &self,
226        delta: f32,
227        a: &Array2<f32>,
228        b: &Array2<f32>,
229    ) -> (Array2<f32>, Array2<f32>) {
230        // Zero-order hold discretization with fast exp approximation
231        let a_bar = a.mapv(|x| simd::fast_exp(delta * x));
232        let b_bar = b.mapv(|x| delta * x);
233        (a_bar, b_bar)
234    }
235}
236
237impl StateSpaceModel for SelectiveSSM {
238    fn recurrence_step(
239        &self,
240        input: &Array1<f32>,
241        state: &mut HiddenState,
242    ) -> CoreResult<Array1<f32>> {
243        // Embed input
244        let mut x = self.embedding.embed(input)?;
245
246        // Apply layer normalization
247        x = ContinuousEmbedding::layer_norm(&x, 1e-5);
248
249        // Process through each layer
250        let mut h = state.state().clone();
251        for layer_idx in 0..self.config.get_num_layers() {
252            x = self.selective_scan_step(layer_idx, &x, &mut h);
253            x = ContinuousEmbedding::layer_norm(&x, 1e-5);
254        }
255
256        // Update state
257        state.update(h);
258
259        // Project to output dimension
260        let output = x.dot(&self.output_proj);
261        Ok(output)
262    }
263
264    fn config(&self) -> &KizzasiConfig {
265        &self.config
266    }
267}
268
269impl SignalPredictor for SelectiveSSM {
270    fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
271        let mut state = self.state.clone();
272        let output = self.recurrence_step(input, &mut state)?;
273        self.state = state;
274        Ok(output)
275    }
276
277    fn reset(&mut self) {
278        self.state.reset();
279    }
280
281    fn context_window(&self) -> usize {
282        self.config.get_context_window()
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end smoke test: one step through a small model produces
    /// an output of the configured dimension.
    #[test]
    fn test_selective_ssm() {
        let config = KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2);

        let mut model = SelectiveSSM::new(config).expect("SSM creation should succeed");
        let sample = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let prediction = model.step(&sample).expect("SSM step should succeed");
        assert_eq!(prediction.len(), 3);
    }
}
305}