1use scirs2_core::ndarray::{Array1, Array2};
4use serde::{Deserialize, Serialize};
5
/// Mutable hidden-state container that also counts how many updates it has seen.
///
/// Bundles a 2-D `f32` state matrix, a step counter advanced by
/// [`HiddenState::update`], and an optional set of history buffers
/// (presumably inputs retained for a causal convolution — confirm with callers).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HiddenState {
    /// State matrix. [`HiddenState::new`] allocates it as zeros of shape
    /// `(hidden_dim, state_dim)`; [`HiddenState::update`] replaces it wholesale
    /// without checking that the new shape matches.
    state: Array2<f32>,
    /// Number of [`HiddenState::update`] calls since construction or the last
    /// [`HiddenState::reset`]. Edits made through `state_mut` do not advance it.
    step_count: usize,
    /// Optional history buffers, `None` until [`HiddenState::set_conv_history`]
    /// is called. With `#[serde(default)]`, input that omits this field
    /// deserializes it as `None`.
    #[serde(default)]
    conv_history: Option<Vec<Vec<f32>>>,
}
18
19impl HiddenState {
20 pub fn new(hidden_dim: usize, state_dim: usize) -> Self {
22 Self {
23 state: Array2::zeros((hidden_dim, state_dim)),
24 step_count: 0,
25 conv_history: None,
26 }
27 }
28
29 pub fn reset(&mut self) {
31 self.state.fill(0.0);
32 self.step_count = 0;
33 if let Some(ref mut hist) = self.conv_history {
34 for h in hist {
35 h.fill(0.0);
36 }
37 }
38 }
39
40 pub fn set_conv_history(&mut self, history: Vec<Vec<f32>>) {
42 self.conv_history = Some(history);
43 }
44
45 pub fn conv_history(&self) -> Option<&Vec<Vec<f32>>> {
47 self.conv_history.as_ref()
48 }
49
50 pub fn take_conv_history(&mut self) -> Option<Vec<Vec<f32>>> {
52 self.conv_history.take()
53 }
54
55 pub fn update(&mut self, new_state: Array2<f32>) {
57 self.state = new_state;
58 self.step_count += 1;
59 }
60
61 pub fn state(&self) -> &Array2<f32> {
63 &self.state
64 }
65
66 pub fn state_mut(&mut self) -> &mut Array2<f32> {
68 &mut self.state
69 }
70
71 pub fn step_count(&self) -> usize {
73 self.step_count
74 }
75
76 pub fn get_row(&self, idx: usize) -> Array1<f32> {
78 self.state.row(idx).to_owned()
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hidden_state() {
        let mut state = HiddenState::new(256, 16);
        assert_eq!(state.step_count(), 0);
        assert_eq!(state.state().shape(), &[256, 16]);

        state.reset();
        assert_eq!(state.step_count(), 0);
    }

    /// Calling `reset` on a fresh state is a no-op, so it must be exercised
    /// after real updates and with history attached.
    #[test]
    fn reset_clears_updates_and_history() {
        let mut state = HiddenState::new(2, 3);
        state.update(Array2::from_elem((2, 3), 1.5));
        state.update(Array2::from_elem((2, 3), 2.5));
        assert_eq!(state.step_count(), 2);
        assert!(state.state().iter().all(|&v| v == 2.5));

        state.set_conv_history(vec![vec![1.0, 2.0], vec![3.0]]);
        state.reset();

        assert_eq!(state.step_count(), 0);
        assert_eq!(state.state().shape(), &[2, 3]);
        assert!(state.state().iter().all(|&v| v == 0.0));
        // History is zeroed in place (lengths kept), not detached.
        assert_eq!(
            state.conv_history(),
            Some(&vec![vec![0.0, 0.0], vec![0.0]])
        );
    }

    #[test]
    fn take_conv_history_moves_buffers_out() {
        let mut state = HiddenState::new(1, 1);
        assert!(state.conv_history().is_none());

        state.set_conv_history(vec![vec![1.0; 4]]);
        assert_eq!(state.take_conv_history(), Some(vec![vec![1.0; 4]]));
        assert!(state.conv_history().is_none());
        assert!(state.take_conv_history().is_none());
    }

    #[test]
    fn get_row_returns_independent_copy() {
        let mut state = HiddenState::new(2, 3);
        state.state_mut()[[1, 2]] = 4.0;

        let row = state.get_row(1);
        assert_eq!(row, Array1::from(vec![0.0_f32, 0.0, 4.0]));

        // Later writes to the state must not leak into the copied row.
        state.state_mut().fill(9.0);
        assert_eq!(row[2], 4.0);
        // Direct mutation via `state_mut` is not counted as a step.
        assert_eq!(state.step_count(), 0);
    }
}