kizzasi_core/
state.rs

1//! Hidden state management for SSM
2
3use scirs2_core::ndarray::{Array1, Array2};
4use serde::{Deserialize, Serialize};
5
6/// Represents the hidden state of the SSM
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct HiddenState {
9    /// The recurrent state tensor (batch, layers, hidden_dim, state_dim)
10    state: Array2<f32>,
11    /// Number of steps processed
12    step_count: usize,
13    /// Optional convolution history buffer (for causal conv layers)
14    /// Note: The `default` attribute ensures backward compatibility with older checkpoints
15    #[serde(default)]
16    conv_history: Option<Vec<Vec<f32>>>,
17}
18
19impl HiddenState {
20    /// Create a new hidden state with given dimensions
21    pub fn new(hidden_dim: usize, state_dim: usize) -> Self {
22        Self {
23            state: Array2::zeros((hidden_dim, state_dim)),
24            step_count: 0,
25            conv_history: None,
26        }
27    }
28
29    /// Reset the state to zeros
30    pub fn reset(&mut self) {
31        self.state.fill(0.0);
32        self.step_count = 0;
33        if let Some(ref mut hist) = self.conv_history {
34            for h in hist {
35                h.fill(0.0);
36            }
37        }
38    }
39
40    /// Set the convolution history
41    pub fn set_conv_history(&mut self, history: Vec<Vec<f32>>) {
42        self.conv_history = Some(history);
43    }
44
45    /// Get the convolution history
46    pub fn conv_history(&self) -> Option<&Vec<Vec<f32>>> {
47        self.conv_history.as_ref()
48    }
49
50    /// Take the convolution history (moves ownership)
51    pub fn take_conv_history(&mut self) -> Option<Vec<Vec<f32>>> {
52        self.conv_history.take()
53    }
54
55    /// Update the state with new values
56    pub fn update(&mut self, new_state: Array2<f32>) {
57        self.state = new_state;
58        self.step_count += 1;
59    }
60
61    /// Get the current state
62    pub fn state(&self) -> &Array2<f32> {
63        &self.state
64    }
65
66    /// Get mutable reference to state
67    pub fn state_mut(&mut self) -> &mut Array2<f32> {
68        &mut self.state
69    }
70
71    /// Get the step count
72    pub fn step_count(&self) -> usize {
73        self.step_count
74    }
75
76    /// Get a row of the state as 1D array
77    pub fn get_row(&self, idx: usize) -> Array1<f32> {
78        self.state.row(idx).to_owned()
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hidden_state() {
        let mut state = HiddenState::new(256, 16);
        assert_eq!(state.step_count(), 0);
        assert_eq!(state.state().shape(), &[256, 16]);
        assert!(state.conv_history().is_none());

        state.reset();
        assert_eq!(state.step_count(), 0);
    }

    /// Drives the state away from zero first, so that `reset` is actually
    /// exercised (resetting a freshly built state proves nothing).
    #[test]
    fn test_update_and_reset() {
        let mut state = HiddenState::new(4, 3);
        state.update(Array2::from_elem((4, 3), 1.5));
        state.update(Array2::from_elem((4, 3), 2.0));
        assert_eq!(state.step_count(), 2);
        assert_eq!(state.get_row(1), Array1::from_elem(3, 2.0));

        // Mutation through state_mut is visible but does not count as a step.
        state.state_mut()[[0, 0]] = 7.0;
        assert_eq!(state.state()[[0, 0]], 7.0);
        assert_eq!(state.step_count(), 2);

        state.set_conv_history(vec![vec![1.0, 2.0], vec![3.0]]);
        state.reset();

        assert_eq!(state.step_count(), 0);
        assert!(state.state().iter().all(|&v| v == 0.0));
        // reset zeroes the conv buffers but keeps their layout.
        assert_eq!(
            state.conv_history(),
            Some(&vec![vec![0.0, 0.0], vec![0.0]])
        );

        assert_eq!(
            state.take_conv_history(),
            Some(vec![vec![0.0, 0.0], vec![0.0]])
        );
        assert!(state.conv_history().is_none());
    }
}