1use crate::error::{ModelError, ModelResult};
35use crate::{AutoregressiveModel, ModelType};
36use kizzasi_core::{gelu, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor};
37use scirs2_core::ndarray::{Array1, Array2};
38use scirs2_core::random::{rng, Rng};
39
40#[allow(unused_imports)]
41use tracing::{debug, instrument, trace};
42
/// Configuration for the [`S5`] model.
///
/// Use [`S5Config::new`] for defaults and [`S5Config::validate`] to check the values
/// before building a model ([`S5::new`] validates the configuration itself).
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct S5Config {
    /// Width of each input vector passed to `step`; the model output has the same width.
    pub input_dim: usize,
    /// Width of the hidden space each layer projects its input into.
    pub hidden_dim: usize,
    /// Number of state channels in each layer's diagonal state-space block.
    pub state_dim: usize,
    /// Number of stacked S5 layers.
    pub num_layers: usize,
    /// Discretisation step size used to derive the discrete-time SSM parameters.
    pub dt: f32,
    /// Block size parameter.
    ///
    /// NOTE(review): no code in this file reads this value — presumably reserved for a
    /// chunked/parallel scan; confirm before relying on it.
    pub block_size: usize,
}
59
60impl S5Config {
61 pub fn new(input_dim: usize, hidden_dim: usize, num_layers: usize) -> Self {
63 Self {
64 input_dim,
65 hidden_dim,
66 state_dim: 64,
67 num_layers,
68 dt: 0.001,
69 block_size: 64,
70 }
71 }
72
73 pub fn validate(&self) -> ModelResult<()> {
75 if self.hidden_dim == 0 {
76 return Err(ModelError::invalid_config("hidden_dim must be > 0"));
77 }
78 if self.state_dim == 0 {
79 return Err(ModelError::invalid_config("state_dim must be > 0"));
80 }
81 if self.num_layers == 0 {
82 return Err(ModelError::invalid_config("num_layers must be > 0"));
83 }
84 if self.dt <= 0.0 {
85 return Err(ModelError::invalid_config("dt must be > 0"));
86 }
87 if self.block_size == 0 {
88 return Err(ModelError::invalid_config("block_size must be > 0"));
89 }
90 Ok(())
91 }
92}
93
/// One diagonal state-space block: a recurrent state that is scaled elementwise by
/// `a_bar` each step and driven by the input through `b_bar`.
// `log_a`, `b_matrix` and `dt` are only used while constructing the block (to derive
// `a_bar` / `b_bar`) and are never read afterwards, hence the `dead_code` allowance.
#[allow(dead_code)]
struct S5Block {
    /// Per-channel log parameters from which `a_bar` is computed (length `state_dim`).
    log_a: Array1<f32>,
    /// Continuous-time input matrix, shape `(state_dim, hidden_dim)`.
    b_matrix: Array2<f32>,
    /// Output matrix mapping the state back to hidden space, shape `(hidden_dim, state_dim)`.
    c_matrix: Array2<f32>,
    /// Direct (skip) term applied elementwise to the input, length `hidden_dim`.
    d_vec: Array1<f32>,
    /// Discretisation step size the discrete parameters were computed with.
    dt: f32,
    /// Discrete-time diagonal transition, multiplied elementwise into `state` each step.
    a_bar: Array1<f32>,
    /// Discrete-time input matrix, shape `(state_dim, hidden_dim)`.
    b_bar: Array2<f32>,
    /// Recurrent state carried between steps (length `state_dim`); zeroed by `reset`.
    state: Array1<f32>,
}
114
115impl S5Block {
116 fn new(hidden_dim: usize, state_dim: usize, dt: f32) -> Self {
118 let mut rng = rng();
119
120 let log_a = Array1::from_shape_fn(state_dim, |i| -((i + 1) as f32).ln());
122
123 let scale_b = (2.0 / (state_dim + hidden_dim) as f32).sqrt();
125 let b_matrix = Array2::from_shape_fn((state_dim, hidden_dim), |_| {
126 (rng.random::<f32>() - 0.5) * 2.0 * scale_b
127 });
128
129 let scale_c = (2.0 / (hidden_dim + state_dim) as f32).sqrt();
130 let c_matrix = Array2::from_shape_fn((hidden_dim, state_dim), |_| {
131 (rng.random::<f32>() - 0.5) * 2.0 * scale_c
132 });
133
134 let d_vec = Array1::from_shape_fn(hidden_dim, |_| rng.random::<f32>() * 0.01);
136
137 let a_bar = log_a.mapv(|log_a_i| (dt * log_a_i.exp()).exp());
139 let b_bar = b_matrix.clone() * dt;
140
141 let state = Array1::zeros(state_dim);
142
143 Self {
144 log_a,
145 b_matrix,
146 c_matrix,
147 d_vec,
148 dt,
149 a_bar,
150 b_bar,
151 state,
152 }
153 }
154
155 #[instrument(skip(self, x))]
157 fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
158 self.state = &self.state * &self.a_bar + self.b_bar.dot(x);
160
161 let y = self.c_matrix.dot(&self.state) + &self.d_vec * x;
163
164 Ok(y)
165 }
166
167 fn reset(&mut self) {
169 self.state.fill(0.0);
170 }
171}
172
/// A residual layer: input projection → S5 block → GELU → RMS norm → output projection,
/// with the layer input added back onto the result.
struct S5Layer {
    /// Projection from input to hidden space, shape `(input_dim, hidden_dim)`.
    input_proj: Array2<f32>,
    /// Recurrent state-space block operating in hidden space.
    s5_block: S5Block,
    /// Normalisation applied after the activation (constructed as RMSNorm).
    layer_norm: LayerNorm,
    /// Projection from hidden space back to input width, shape `(hidden_dim, input_dim)`.
    output_proj: Array2<f32>,
}
184
185impl S5Layer {
186 fn new(config: &S5Config) -> ModelResult<Self> {
188 let mut rng = rng();
189
190 let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
192 let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
193 (rng.random::<f32>() - 0.5) * 2.0 * scale
194 });
195
196 let s5_block = S5Block::new(config.hidden_dim, config.state_dim, config.dt);
198
199 let layer_norm = LayerNorm::new(config.hidden_dim, NormType::RMSNorm);
201
202 let scale = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
204 let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
205 (rng.random::<f32>() - 0.5) * 2.0 * scale
206 });
207
208 Ok(Self {
209 input_proj,
210 s5_block,
211 layer_norm,
212 output_proj,
213 })
214 }
215
216 fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
218 let hidden = x.dot(&self.input_proj);
220
221 let ssm_out = self.s5_block.forward(&hidden)?;
223
224 let activated = gelu(&ssm_out);
226
227 let normed = self.layer_norm.forward(&activated);
229
230 let output = normed.dot(&self.output_proj) + x;
232
233 Ok(output)
234 }
235
236 fn reset(&mut self) {
238 self.s5_block.reset();
239 }
240}
241
/// Multi-layer S5 state-space model that processes a signal one time step at a time.
///
/// Each call to `SignalPredictor::step` feeds the input through every layer in order;
/// the per-layer recurrent states persist between calls until `reset` is called.
pub struct S5 {
    /// Validated configuration the model was built from.
    config: S5Config,
    /// Layers in evaluation order (`config.num_layers` of them).
    layers: Vec<S5Layer>,
}
247
248impl S5 {
249 #[instrument(skip(config), fields(input_dim = config.input_dim, hidden_dim = config.hidden_dim, num_layers = config.num_layers))]
251 pub fn new(config: S5Config) -> ModelResult<Self> {
252 debug!("Creating new S5 model");
253 config.validate()?;
254
255 let mut layers = Vec::with_capacity(config.num_layers);
256 for layer_idx in 0..config.num_layers {
257 trace!("Initializing S5 layer {}", layer_idx);
258 layers.push(S5Layer::new(&config)?);
259 }
260 debug!("Initialized {} S5 layers", layers.len());
261
262 debug!("S5 model created successfully");
263 Ok(Self { config, layers })
264 }
265
266 pub fn config(&self) -> &S5Config {
268 &self.config
269 }
270}
271
272impl SignalPredictor for S5 {
273 #[instrument(skip(self, input))]
274 fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
275 let mut x = input.clone();
276
277 for layer in &mut self.layers {
278 x = layer.forward(&x)?;
279 }
280
281 Ok(x)
282 }
283
284 #[instrument(skip(self))]
285 fn reset(&mut self) {
286 debug!("Resetting S5 model state");
287 for layer in &mut self.layers {
288 layer.reset();
289 }
290 }
291
292 fn context_window(&self) -> usize {
293 usize::MAX
295 }
296}
297
298impl AutoregressiveModel for S5 {
299 fn hidden_dim(&self) -> usize {
300 self.config.hidden_dim
301 }
302
303 fn state_dim(&self) -> usize {
304 self.config.state_dim
305 }
306
307 fn num_layers(&self) -> usize {
308 self.config.num_layers
309 }
310
311 fn model_type(&self) -> ModelType {
312 ModelType::S4 }
314
315 fn get_states(&self) -> Vec<HiddenState> {
316 self.layers
317 .iter()
318 .map(|layer| {
319 let state_1d = layer.s5_block.state.clone();
321 let state_2d = state_1d.insert_axis(scirs2_core::ndarray::Axis(0));
322 let mut hidden_state = HiddenState::new(
323 self.config.hidden_dim,
324 state_2d.len_of(scirs2_core::ndarray::Axis(1)),
325 );
326 hidden_state.update(state_2d);
327 hidden_state
328 })
329 .collect()
330 }
331
332 fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
333 if states.len() != self.config.num_layers {
334 return Err(ModelError::state_count_mismatch(
335 "S5",
336 self.config.num_layers,
337 states.len(),
338 ));
339 }
340
341 for (layer, state) in self.layers.iter_mut().zip(states.iter()) {
342 let state_2d = state.state();
344 if state_2d.nrows() > 0 && state_2d.ncols() > 0 {
345 layer.s5_block.state = state_2d.row(0).to_owned();
346 }
347 }
348
349 Ok(())
350 }
351}
352
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_s5_creation() {
        let config = S5Config::new(32, 64, 2);
        let model = S5::new(config);
        assert!(model.is_ok());
    }

    #[test]
    fn test_s5_forward() {
        let config = S5Config::new(32, 64, 2);
        let mut model = S5::new(config).expect("Failed to create S5 model");

        let input = Array1::from_vec(vec![1.0; 32]);
        let output = model.step(&input);
        assert!(output.is_ok());
        assert_eq!(output.expect("Failed to get output").len(), 32);
    }

    #[test]
    fn test_s5_reset() {
        let config = S5Config::new(32, 64, 2);
        let mut model = S5::new(config).expect("Failed to create S5 model");

        let input = Array1::from_vec(vec![1.0; 32]);
        let output1 = model.step(&input).expect("Failed to get output1");

        model.reset();

        let output2 = model.step(&input).expect("Failed to get output2");
        assert_eq!(output2.len(), 32);
        // After reset every layer starts again from a zero state with unchanged weights,
        // so the same input must reproduce the first output exactly.
        assert_eq!(output1, output2);
    }

    #[test]
    fn test_invalid_config() {
        let mut config = S5Config::new(32, 64, 2);
        config.state_dim = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_invalid_config_other_fields() {
        let base = S5Config::new(32, 64, 2);
        assert!(base.validate().is_ok());

        let mut config = base.clone();
        config.hidden_dim = 0;
        assert!(config.validate().is_err());

        let mut config = base.clone();
        config.num_layers = 0;
        assert!(config.validate().is_err());

        let mut config = base.clone();
        config.dt = 0.0;
        assert!(config.validate().is_err());

        let mut config = base;
        config.block_size = 0;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_set_states_rejects_wrong_count() {
        let mut model = S5::new(S5Config::new(8, 16, 2)).expect("Failed to create S5 model");
        assert!(model.set_states(Vec::new()).is_err());
    }

    #[test]
    fn test_states_round_trip() {
        let mut model = S5::new(S5Config::new(8, 16, 2)).expect("Failed to create S5 model");
        let states = model.get_states();
        assert_eq!(states.len(), 2);
        assert!(model.set_states(states).is_ok());
    }
}