1#[cfg(not(feature = "std"))]
7use alloc::vec::Vec;
8
9use crate::config::KizzasiConfig;
10use crate::embedding::ContinuousEmbedding;
11use crate::error::CoreResult;
12use crate::simd;
13use crate::state::HiddenState;
14use crate::SignalPredictor;
15use scirs2_core::ndarray::{Array1, Array2};
16use scirs2_core::random::thread_rng;
17use serde::{Deserialize, Serialize};
18
19pub trait StateSpaceModel {
21 fn recurrence_step(
23 &self,
24 input: &Array1<f32>,
25 state: &mut HiddenState,
26 ) -> CoreResult<Array1<f32>>;
27
28 fn config(&self) -> &KizzasiConfig;
30}
31
/// Multi-layer state-space model with per-layer `A`, `B`, `C`, `D` parameters,
/// an input embedding, an owned hidden state and a final output projection.
///
/// Shapes as produced by `SelectiveSSM::new`:
/// - `a_matrices`, `b_matrices`, `c_matrices`: one `(hidden_dim, state_dim)`
///   matrix per layer
/// - `d_vectors`: one length-`hidden_dim` vector per layer
/// - `output_proj`: `(hidden_dim, output_dim)`
///
/// NOTE: this type derives `Serialize`/`Deserialize`; renaming fields changes
/// serialized keys and reordering them changes order-sensitive formats.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SelectiveSSM {
    // Dimensions and layer count the model was built from.
    config: KizzasiConfig,
    // Maps raw `input_dim` samples into the `hidden_dim` space.
    embedding: ContinuousEmbedding,
    // Hidden state advanced by `SignalPredictor::step`.
    state: HiddenState,
    // Per-layer continuous-time state matrices `A`.
    a_matrices: Vec<Array2<f32>>,
    // Per-layer input matrices `B`.
    b_matrices: Vec<Array2<f32>>,
    // Per-layer readout matrices `C`.
    c_matrices: Vec<Array2<f32>>,
    // Per-layer skip-connection weights `D`.
    d_vectors: Vec<Array1<f32>>,
    // Projects the final `hidden_dim` activation to `output_dim`.
    output_proj: Array2<f32>,
}
49
50impl SelectiveSSM {
51 pub fn new(config: KizzasiConfig) -> CoreResult<Self> {
53 let hidden_dim = config.get_hidden_dim();
54 let state_dim = config.get_state_dim();
55 let num_layers = config.get_num_layers();
56 let input_dim = config.get_input_dim();
57 let output_dim = config.get_output_dim();
58
59 let embedding = ContinuousEmbedding::new(input_dim, hidden_dim);
61
62 let state = HiddenState::new(hidden_dim, state_dim);
64
65 let mut rng = thread_rng();
67 let scale = 0.01;
68 let mut a_matrices = Vec::with_capacity(num_layers);
69 let mut b_matrices = Vec::with_capacity(num_layers);
70 let mut c_matrices = Vec::with_capacity(num_layers);
71 let mut d_vectors = Vec::with_capacity(num_layers);
72
73 for _ in 0..num_layers {
74 let a = Array2::from_shape_fn((hidden_dim, state_dim), |_| {
76 -0.5 + rng.random::<f32>() * scale
77 });
78 a_matrices.push(a);
79
80 let b = Array2::from_shape_fn((hidden_dim, state_dim), |_| {
82 (rng.random::<f32>() - 0.5) * scale
83 });
84 b_matrices.push(b);
85
86 let c = Array2::from_shape_fn((hidden_dim, state_dim), |_| {
88 (rng.random::<f32>() - 0.5) * scale
89 });
90 c_matrices.push(c);
91
92 let d = Array1::ones(hidden_dim);
94 d_vectors.push(d);
95 }
96
97 let output_proj = Array2::from_shape_fn((hidden_dim, output_dim), |_| {
99 (rng.random::<f32>() - 0.5) * scale
100 });
101
102 Ok(Self {
103 config,
104 embedding,
105 state,
106 a_matrices,
107 b_matrices,
108 c_matrices,
109 d_vectors,
110 output_proj,
111 })
112 }
113
114 pub fn get_state(&self) -> &HiddenState {
116 &self.state
117 }
118
119 pub fn get_state_mut(&mut self) -> &mut HiddenState {
121 &mut self.state
122 }
123
124 pub fn set_state(&mut self, state: HiddenState) {
126 self.state = state;
127 }
128
129 pub fn step_count(&self) -> usize {
131 self.state.step_count()
132 }
133
134 pub fn embedding(&self) -> &ContinuousEmbedding {
136 &self.embedding
137 }
138
139 pub fn a_matrices(&self) -> &Vec<Array2<f32>> {
141 &self.a_matrices
142 }
143
144 pub fn b_matrices(&self) -> &Vec<Array2<f32>> {
146 &self.b_matrices
147 }
148
149 pub fn c_matrices(&self) -> &Vec<Array2<f32>> {
151 &self.c_matrices
152 }
153
154 pub fn d_vectors(&self) -> &Vec<Array1<f32>> {
156 &self.d_vectors
157 }
158
159 pub fn output_proj(&self) -> &Array2<f32> {
161 &self.output_proj
162 }
163
164 #[allow(dead_code)]
166 fn discretize(
167 &self,
168 delta: f32,
169 a: &Array2<f32>,
170 b: &Array2<f32>,
171 ) -> (Array2<f32>, Array2<f32>) {
172 let a_bar = a.mapv(|x| (delta * x).exp());
176 let b_bar = b.mapv(|x| delta * x);
177 (a_bar, b_bar)
178 }
179
180 fn selective_scan_step(
182 &self,
183 layer_idx: usize,
184 x: &Array1<f32>,
185 h: &mut Array2<f32>,
186 ) -> Array1<f32> {
187 let a = &self.a_matrices[layer_idx];
188 let b = &self.b_matrices[layer_idx];
189 let c = &self.c_matrices[layer_idx];
190 let d = &self.d_vectors[layer_idx];
191
192 let delta = 0.1; let (a_bar, b_bar) = self.discretize_simd(delta, a, b);
197
198 for i in 0..h.nrows() {
200 let x_val = x[i];
201 let row_len = h.ncols();
202 let mut h_row = h.row_mut(i);
203 let a_row = a_bar.row(i);
204 let b_row = b_bar.row(i);
205
206 for j in 0..row_len {
208 h_row[j] = a_row[j].mul_add(h_row[j], b_row[j] * x_val);
209 }
210 }
211
212 let mut y = Array1::zeros(x.len());
214 for i in 0..y.len() {
215 let h_row = h.row(i);
216 let c_row = c.row(i);
217 y[i] = simd::dot_view(h_row, c_row) + d[i] * x[i];
218 }
219
220 y
221 }
222
223 fn discretize_simd(
225 &self,
226 delta: f32,
227 a: &Array2<f32>,
228 b: &Array2<f32>,
229 ) -> (Array2<f32>, Array2<f32>) {
230 let a_bar = a.mapv(|x| simd::fast_exp(delta * x));
232 let b_bar = b.mapv(|x| delta * x);
233 (a_bar, b_bar)
234 }
235}
236
237impl StateSpaceModel for SelectiveSSM {
238 fn recurrence_step(
239 &self,
240 input: &Array1<f32>,
241 state: &mut HiddenState,
242 ) -> CoreResult<Array1<f32>> {
243 let mut x = self.embedding.embed(input)?;
245
246 x = ContinuousEmbedding::layer_norm(&x, 1e-5);
248
249 let mut h = state.state().clone();
251 for layer_idx in 0..self.config.get_num_layers() {
252 x = self.selective_scan_step(layer_idx, &x, &mut h);
253 x = ContinuousEmbedding::layer_norm(&x, 1e-5);
254 }
255
256 state.update(h);
258
259 let output = x.dot(&self.output_proj);
261 Ok(output)
262 }
263
264 fn config(&self) -> &KizzasiConfig {
265 &self.config
266 }
267}
268
269impl SignalPredictor for SelectiveSSM {
270 fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
271 let mut state = self.state.clone();
272 let output = self.recurrence_step(input, &mut state)?;
273 self.state = state;
274 Ok(output)
275 }
276
277 fn reset(&mut self) {
278 self.state.reset();
279 }
280
281 fn context_window(&self) -> usize {
282 self.config.get_context_window()
283 }
284}
285
#[cfg(test)]
mod tests {
    use super::*;

    /// Three-in / three-out, two-layer configuration small enough to build quickly.
    fn tiny_config() -> KizzasiConfig {
        KizzasiConfig::new()
            .input_dim(3)
            .output_dim(3)
            .hidden_dim(64)
            .state_dim(8)
            .num_layers(2)
    }

    #[test]
    fn test_selective_ssm() {
        let mut model = SelectiveSSM::new(tiny_config()).expect("SSM creation should succeed");
        let sample = Array1::from_vec(vec![0.1, 0.2, 0.3]);

        let prediction = model.step(&sample).expect("SSM step should succeed");

        assert_eq!(prediction.len(), 3);
    }
}
305}