1use crate::error::{ModelError, ModelResult};
43use crate::{AutoregressiveModel, ModelType};
44use kizzasi_core::{
45 gelu, CausalConv1d, CoreResult, HiddenState, LayerNorm, NormType, SignalPredictor,
46};
47use scirs2_core::ndarray::{Array1, Array2};
48use scirs2_core::random::{rng, Rng};
49#[allow(unused_imports)]
50use tracing::{debug, instrument, trace};
51
/// Configuration for the S4D (diagonal structured state-space) model.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct S4Config {
    /// Dimensionality of the model input (and of the final output projection).
    pub input_dim: usize,
    /// Width of the hidden representation flowing between layers.
    pub hidden_dim: usize,
    /// Size of the per-channel SSM state (N in S4 notation).
    pub state_dim: usize,
    /// Number of stacked S4D layers.
    pub num_layers: usize,
    /// Dropout probability.
    /// NOTE(review): not read by any code in this file — confirm it is applied elsewhere.
    pub dropout: f32,
    /// Lower bound for the randomly initialized per-channel step size Δt.
    pub dt_min: f32,
    /// Upper bound for the randomly initialized per-channel step size Δt.
    pub dt_max: f32,
    /// Whether to use the diagonal SSM parameterization.
    /// NOTE(review): not read by the kernel code in this file — confirm before relying on it.
    pub use_diagonal: bool,
    /// Use RMSNorm instead of LayerNorm for per-layer and output normalization.
    pub use_rms_norm: bool,
}
73
74impl Default for S4Config {
75 fn default() -> Self {
76 Self {
77 input_dim: 1,
78 hidden_dim: 512,
79 state_dim: 64,
80 num_layers: 6,
81 dropout: 0.0,
82 dt_min: 0.001,
83 dt_max: 0.1,
84 use_diagonal: true, use_rms_norm: true,
86 }
87 }
88}
89
90impl S4Config {
91 pub fn new() -> Self {
93 Self::default()
94 }
95
96 pub fn input_dim(mut self, dim: usize) -> Self {
98 self.input_dim = dim;
99 self
100 }
101
102 pub fn hidden_dim(mut self, dim: usize) -> Self {
104 self.hidden_dim = dim;
105 self
106 }
107
108 pub fn state_dim(mut self, dim: usize) -> Self {
110 self.state_dim = dim;
111 self
112 }
113
114 pub fn num_layers(mut self, n: usize) -> Self {
116 self.num_layers = n;
117 self
118 }
119
120 pub fn diagonal(mut self, use_diagonal: bool) -> Self {
122 self.use_diagonal = use_diagonal;
123 self
124 }
125
126 pub fn validate(&self) -> ModelResult<()> {
128 if self.hidden_dim == 0 {
129 return Err(ModelError::invalid_config("hidden_dim must be > 0"));
130 }
131 if self.state_dim == 0 {
132 return Err(ModelError::invalid_config("state_dim must be > 0"));
133 }
134 if self.num_layers == 0 {
135 return Err(ModelError::invalid_config("num_layers must be > 0"));
136 }
137 if self.dt_min <= 0.0 || self.dt_max <= 0.0 {
138 return Err(ModelError::invalid_config("dt_min and dt_max must be > 0"));
139 }
140 if self.dt_min > self.dt_max {
141 return Err(ModelError::invalid_config("dt_min must be <= dt_max"));
142 }
143 Ok(())
144 }
145}
146
/// Diagonal state-space (S4D) recurrence kernel for one layer.
///
/// Maintains an independent `state_dim`-sized SSM state per hidden channel
/// and advances it one timestep at a time (recurrent mode).
struct S4DKernel {
    // Number of hidden channels (rows of `state`).
    hidden_dim: usize,
    // Per-channel SSM state size (columns of `state`).
    state_dim: usize,

    // Log of the negated diagonal of A: A_ii = -exp(log_a[i]).
    log_a: Array1<f32>,

    // Input matrix B, shape (state_dim, hidden_dim).
    b_matrix: Array2<f32>,

    // Output matrix C, shape (hidden_dim, state_dim).
    c_matrix: Array2<f32>,

    // Per-channel skip (D) connection weights; initialized to one.
    d_skip: Array1<f32>,

    // Log of the per-channel discretization step Δt.
    log_dt: Array1<f32>,

    // Current recurrent state, shape (hidden_dim, state_dim); zeroed by `reset`.
    state: Array2<f32>,
}
171
172impl S4DKernel {
173 fn new(config: &S4Config) -> ModelResult<Self> {
174 let mut rng = rng();
175
176 let log_a = Array1::from_shape_fn(config.state_dim, |n| ((2 * n + 1) as f32 / 2.0).ln());
180
181 let scale = (1.0 / config.state_dim as f32).sqrt();
183 let b_matrix = Array2::from_shape_fn((config.state_dim, config.hidden_dim), |_| {
184 (rng.random::<f32>() - 0.5) * 2.0 * scale
185 });
186
187 let c_matrix = Array2::from_shape_fn((config.hidden_dim, config.state_dim), |_| {
189 (rng.random::<f32>() - 0.5) * 2.0 * scale
190 });
191
192 let d_skip = Array1::ones(config.hidden_dim);
194
195 let log_dt = Array1::from_shape_fn(config.hidden_dim, |_| {
197 let dt = config.dt_min + rng.random::<f32>() * (config.dt_max - config.dt_min);
198 dt.ln()
199 });
200
201 let state = Array2::zeros((config.hidden_dim, config.state_dim));
203
204 Ok(Self {
205 hidden_dim: config.hidden_dim,
206 state_dim: config.state_dim,
207 log_a,
208 b_matrix,
209 c_matrix,
210 d_skip,
211 log_dt,
212 state,
213 })
214 }
215
216 fn discretize(&self, dt: f32) -> (Array1<f32>, Array2<f32>) {
222 let mut a_bar = Array1::zeros(self.state_dim);
223 let mut b_bar = Array2::zeros(self.b_matrix.raw_dim());
224
225 for i in 0..self.state_dim {
226 let a_i = -self.log_a[i].exp();
228
229 a_bar[i] = (dt * a_i).exp();
231
232 let scale = (1.0 - a_bar[i]) / (-a_i);
234 for j in 0..self.hidden_dim {
235 b_bar[[i, j]] = self.b_matrix[[i, j]] * scale;
236 }
237 }
238
239 (a_bar, b_bar)
240 }
241
242 fn forward_step(&mut self, u: &Array1<f32>) -> CoreResult<Array1<f32>> {
244 let batch_size = u.len().min(self.hidden_dim);
245
246 let mut output = Array1::zeros(batch_size);
248
249 for dim in 0..batch_size {
250 let dt = self.log_dt[dim].exp();
251 let (a_bar, b_bar) = self.discretize(dt);
252
253 for i in 0..self.state_dim {
256 let bu = if dim < b_bar.shape()[1] {
257 b_bar[[i, dim]] * u[dim]
258 } else {
259 0.0
260 };
261 self.state[[dim, i]] = a_bar[i] * self.state[[dim, i]] + bu;
262 }
263
264 let mut c_h = 0.0;
266 for i in 0..self.state_dim {
267 c_h += self.c_matrix[[dim, i]] * self.state[[dim, i]];
268 }
269 output[dim] = c_h + self.d_skip[dim] * u[dim];
270 }
271
272 Ok(output)
273 }
274
275 fn reset(&mut self) {
276 self.state.fill(0.0);
277 }
278}
279
/// One S4D block: pre-norm, causal convolution, diagonal SSM kernel,
/// GELU activation, output projection, and a residual add (see `forward`).
struct S4DLayer {
    // Pre-normalization (RMSNorm or LayerNorm per config).
    norm: LayerNorm,
    // Diagonal state-space recurrence.
    s4_kernel: S4DKernel,
    // Causal convolution applied before the SSM.
    conv: CausalConv1d,
    // Square (hidden_dim × hidden_dim) projection after the activation.
    output_proj: Array2<f32>,
}
287
288impl S4DLayer {
289 fn new(config: &S4Config) -> ModelResult<Self> {
290 let norm_type = if config.use_rms_norm {
291 NormType::RMSNorm
292 } else {
293 NormType::LayerNorm
294 };
295
296 let norm = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);
297 let s4_kernel = S4DKernel::new(config)?;
298
299 let conv = CausalConv1d::new(config.hidden_dim, config.hidden_dim, 3);
301
302 let mut rng = rng();
304 let scale = (2.0 / config.hidden_dim as f32).sqrt();
305 let output_proj = Array2::from_shape_fn((config.hidden_dim, config.hidden_dim), |_| {
306 (rng.random::<f32>() - 0.5) * 2.0 * scale
307 });
308
309 Ok(Self {
310 norm,
311 s4_kernel,
312 conv,
313 output_proj,
314 })
315 }
316
317 fn forward(&mut self, x: &Array1<f32>) -> CoreResult<Array1<f32>> {
318 let x_norm = self.norm.forward(x);
320
321 let x_vec = x_norm.to_vec();
323 let conv_out = self.conv.forward_step(&x_vec);
324 let x_conv = Array1::from_vec(conv_out);
325
326 let ssm_out = self.s4_kernel.forward_step(&x_conv)?;
328
329 let activated = gelu(&ssm_out);
331
332 let mut projected = Array1::zeros(x.len().min(self.output_proj.shape()[0]));
334 for i in 0..projected.len() {
335 let mut sum = 0.0;
336 for j in 0..activated.len().min(self.output_proj.shape()[1]) {
337 sum += self.output_proj[[i, j]] * activated[j];
338 }
339 projected[i] = sum;
340 }
341
342 let mut output = x.clone();
344 for i in 0..output.len().min(projected.len()) {
345 output[i] += projected[i];
346 }
347
348 Ok(output)
349 }
350
351 fn reset(&mut self) {
352 self.s4_kernel.reset();
353 }
354}
355
/// S4D model: input projection → stacked S4D layers → final norm → output
/// projection. Operates step-by-step via the `SignalPredictor` impl.
pub struct S4D {
    // The validated configuration used to build the model.
    config: S4Config,
    // Stacked residual S4D blocks, length `config.num_layers`.
    layers: Vec<S4DLayer>,
    // Final normalization before the output projection.
    ln_out: LayerNorm,
    // (input_dim, hidden_dim) projection applied to each input step.
    input_proj: Array2<f32>,
    // (hidden_dim, input_dim) projection producing the prediction.
    output_proj: Array2<f32>,
}
364
impl S4D {
    /// Builds a new S4D model from `config`, validating it first.
    ///
    /// Weights are randomly initialized (Xavier-style scaling,
    /// sqrt(2 / (fan_in + fan_out)), for the input/output projections);
    /// use [`Self::load_weights`] to install trained parameters.
    ///
    /// # Errors
    /// Propagates any validation error from `config.validate()`.
    pub fn new(config: S4Config) -> ModelResult<Self> {
        config.validate()?;

        let mut layers = Vec::with_capacity(config.num_layers);
        for _ in 0..config.num_layers {
            layers.push(S4DLayer::new(&config)?);
        }

        // Final pre-output normalization (RMSNorm or LayerNorm per config).
        let norm_type = if config.use_rms_norm {
            NormType::RMSNorm
        } else {
            NormType::LayerNorm
        };
        let ln_out = LayerNorm::new(config.hidden_dim, norm_type).with_eps(1e-5);

        let mut rng = rng();
        let scale = (2.0 / (config.input_dim + config.hidden_dim) as f32).sqrt();
        let input_proj = Array2::from_shape_fn((config.input_dim, config.hidden_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        let scale = (2.0 / (config.hidden_dim + config.input_dim) as f32).sqrt();
        let output_proj = Array2::from_shape_fn((config.hidden_dim, config.input_dim), |_| {
            (rng.random::<f32>() - 0.5) * 2.0 * scale
        });

        Ok(Self {
            config,
            layers,
            ln_out,
            input_proj,
            output_proj,
        })
    }

    /// Returns the model configuration.
    pub fn config(&self) -> &S4Config {
        &self.config
    }

    /// Loads any weights present in `loader`; tensors that are absent keep
    /// their random initialization.
    ///
    /// Recognized tensor names: `input_proj`, `output_proj`,
    /// `ln_out.{weight,bias}`, and per layer `i`:
    /// `layers.{i}.norm.{weight,bias}`, `layers.{i}.output_proj`,
    /// `layers.{i}.s4_kernel.{log_a,b_matrix,c_matrix,d_skip,log_dt}`,
    /// `layers.{i}.conv.{weight,bias}`.
    ///
    /// NOTE(review): loaded tensors are not shape-checked here — presumably a
    /// mismatched checkpoint is caught by the loader; confirm.
    pub fn load_weights(&mut self, loader: &crate::loader::ModelLoader) -> ModelResult<()> {
        if loader.has_tensor("input_proj") {
            self.input_proj = loader.load_array2("input_proj")?;
        }
        if loader.has_tensor("output_proj") {
            self.output_proj = loader.load_array2("output_proj")?;
        }

        if loader.has_tensor("ln_out.weight") {
            let weight = loader.load_array1("ln_out.weight")?;
            self.ln_out.set_gamma(weight);
        }
        if loader.has_tensor("ln_out.bias") {
            let bias = loader.load_array1("ln_out.bias")?;
            self.ln_out.set_beta(bias);
        }

        for (i, layer) in self.layers.iter_mut().enumerate() {
            let prefix = format!("layers.{}", i);

            // Per-layer normalization parameters.
            if loader.has_tensor(&format!("{}.norm.weight", prefix)) {
                let weight = loader.load_array1(&format!("{}.norm.weight", prefix))?;
                layer.norm.set_gamma(weight);
            }
            if loader.has_tensor(&format!("{}.norm.bias", prefix)) {
                let bias = loader.load_array1(&format!("{}.norm.bias", prefix))?;
                layer.norm.set_beta(bias);
            }

            if loader.has_tensor(&format!("{}.output_proj", prefix)) {
                layer.output_proj = loader.load_array2(&format!("{}.output_proj", prefix))?;
            }

            // SSM kernel parameters.
            let kernel_prefix = format!("{}.s4_kernel", prefix);
            if loader.has_tensor(&format!("{}.log_a", kernel_prefix)) {
                layer.s4_kernel.log_a = loader.load_array1(&format!("{}.log_a", kernel_prefix))?;
            }
            if loader.has_tensor(&format!("{}.b_matrix", kernel_prefix)) {
                layer.s4_kernel.b_matrix =
                    loader.load_array2(&format!("{}.b_matrix", kernel_prefix))?;
            }
            if loader.has_tensor(&format!("{}.c_matrix", kernel_prefix)) {
                layer.s4_kernel.c_matrix =
                    loader.load_array2(&format!("{}.c_matrix", kernel_prefix))?;
            }
            if loader.has_tensor(&format!("{}.d_skip", kernel_prefix)) {
                layer.s4_kernel.d_skip =
                    loader.load_array1(&format!("{}.d_skip", kernel_prefix))?;
            }
            if loader.has_tensor(&format!("{}.log_dt", kernel_prefix)) {
                layer.s4_kernel.log_dt =
                    loader.load_array1(&format!("{}.log_dt", kernel_prefix))?;
            }

            // Causal convolution parameters.
            if loader.has_tensor(&format!("{}.conv.weight", prefix)) {
                let conv_weights = loader.load_array3(&format!("{}.conv.weight", prefix))?;
                layer.conv.set_weights(conv_weights);
            }
            if loader.has_tensor(&format!("{}.conv.bias", prefix)) {
                let conv_bias = loader.load_array1(&format!("{}.conv.bias", prefix))?;
                layer.conv.set_bias(conv_bias.to_vec());
            }
        }

        Ok(())
    }

    /// Saving is not implemented yet; always returns a load error.
    #[allow(unused_variables)]
    pub fn save_weights(&self, path: &str) -> ModelResult<()> {
        Err(ModelError::simple_load_error(
            "S4D save_weights not yet implemented".to_string(),
        ))
    }
}
514
515impl SignalPredictor for S4D {
516 #[instrument(skip(self, input))]
517 fn step(&mut self, input: &Array1<f32>) -> CoreResult<Array1<f32>> {
518 let mut hidden = input.dot(&self.input_proj);
520
521 for layer in &mut self.layers {
523 hidden = layer.forward(&hidden)?;
524 }
525
526 hidden = self.ln_out.forward(&hidden);
528
529 let output = hidden.dot(&self.output_proj);
531 Ok(output)
532 }
533
534 fn reset(&mut self) {
535 for layer in &mut self.layers {
536 layer.reset();
537 }
538 }
539
540 fn context_window(&self) -> usize {
541 usize::MAX
543 }
544}
545
546impl AutoregressiveModel for S4D {
547 fn hidden_dim(&self) -> usize {
548 self.config.hidden_dim
549 }
550
551 fn state_dim(&self) -> usize {
552 self.config.state_dim
553 }
554
555 fn num_layers(&self) -> usize {
556 self.config.num_layers
557 }
558
559 fn model_type(&self) -> ModelType {
560 ModelType::S4D
561 }
562
563 fn get_states(&self) -> Vec<HiddenState> {
564 self.layers
565 .iter()
566 .map(|layer| {
567 let state = layer.s4_kernel.state.clone();
568 let mut hs = HiddenState::new(state.shape()[0], state.shape()[1]);
569 hs.update(state);
570 hs
571 })
572 .collect()
573 }
574
575 fn set_states(&mut self, states: Vec<HiddenState>) -> ModelResult<()> {
576 if states.len() != self.config.num_layers {
577 return Err(ModelError::state_count_mismatch(
578 "S4D",
579 self.config.num_layers,
580 states.len(),
581 ));
582 }
583
584 for (layer_idx, layer) in self.layers.iter_mut().enumerate() {
585 layer.s4_kernel.state = states[layer_idx].state().clone();
586 }
587
588 Ok(())
589 }
590}
591
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder chain produces the requested dimensions and validates cleanly.
    #[test]
    fn test_s4d_config() {
        let cfg = S4Config::new().hidden_dim(256).state_dim(64).num_layers(4);
        assert_eq!(cfg.hidden_dim, 256);
        assert_eq!(cfg.state_dim, 64);
        assert!(cfg.validate().is_ok());
    }

    /// A valid configuration constructs a model.
    #[test]
    fn test_s4d_creation() {
        let cfg = S4Config::new().hidden_dim(128).state_dim(32);
        assert!(S4D::new(cfg).is_ok());
    }

    /// A single forward step on a scalar input succeeds.
    #[test]
    fn test_s4d_forward() {
        let cfg = S4Config::new().hidden_dim(64).state_dim(16).num_layers(2);
        let mut model = S4D::new(cfg).expect("Failed to create S4D model");
        assert!(model.step(&Array1::from_vec(vec![0.5])).is_ok());
    }

    /// An inverted Δt range is rejected by validation.
    #[test]
    fn test_invalid_dt() {
        let cfg = S4Config {
            dt_min: 0.1,
            dt_max: 0.01,
            ..Default::default()
        };
        assert!(cfg.validate().is_err());
    }
}