use scirs2_core::ndarray::{Array1, Array2};
use serde::{Deserialize, Serialize};
/// Recurrent hidden state: a 2-D `f32` state matrix, a step counter, and an
/// optional list of `Vec<f32>` history buffers.
///
/// Serializable via serde. `conv_history` may be absent from serialized input;
/// it then deserializes as `None` (see `#[serde(default)]` below). Field names and
/// order are part of the serialized format, so change them with care.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HiddenState {
/// State matrix. [`HiddenState::new`] creates it zero-filled with shape
/// `(hidden_dim, state_dim)`; `update` replaces it wholesale without a shape check.
state: Array2<f32>,
/// Incremented by `update`; set to `0` by `new` and `reset`.
step_count: usize,
/// Optional history buffers, `None` until `set_conv_history` is called.
/// NOTE(review): presumably one buffer per channel of a convolution — confirm
/// with the producer. `reset` zeroes these buffers in place but keeps them.
#[serde(default)]
conv_history: Option<Vec<Vec<f32>>>,
}
impl HiddenState {
    /// Builds a zero-filled state matrix of shape `(hidden_dim, state_dim)`.
    ///
    /// The step counter starts at `0` and no history buffers are attached.
    pub fn new(hidden_dim: usize, state_dim: usize) -> Self {
        let state: Array2<f32> = Array2::zeros((hidden_dim, state_dim));
        Self {
            state,
            step_count: 0,
            conv_history: None,
        }
    }

    /// Returns the state to its initial condition without reallocating.
    ///
    /// The state matrix is zeroed and the step counter goes back to `0`. If history
    /// buffers are attached, each one is zeroed in place; the buffers (and their
    /// lengths) are kept rather than dropped.
    pub fn reset(&mut self) {
        self.state.fill(0.0);
        self.step_count = 0;
        // `Option::iter_mut` yields the outer Vec (if any); `flatten` walks its buffers.
        self.conv_history
            .iter_mut()
            .flatten()
            .for_each(|buffer| buffer.fill(0.0));
    }

    /// Attaches `history` as the history buffers, replacing any existing ones.
    pub fn set_conv_history(&mut self, history: Vec<Vec<f32>>) {
        self.conv_history = Some(history);
    }

    /// Borrows the attached history buffers, or `None` if none are attached.
    pub fn conv_history(&self) -> Option<&Vec<Vec<f32>>> {
        self.conv_history.as_ref()
    }

    /// Detaches and returns the history buffers, leaving `None` in their place.
    pub fn take_conv_history(&mut self) -> Option<Vec<Vec<f32>>> {
        self.conv_history.take()
    }

    /// Swaps in `new_state` as the state matrix and counts one step.
    ///
    /// The shape of `new_state` is not validated against the previous state.
    pub fn update(&mut self, new_state: Array2<f32>) {
        self.state = new_state;
        self.step_count += 1;
    }

    /// Borrows the state matrix.
    pub fn state(&self) -> &Array2<f32> {
        &self.state
    }

    /// Mutably borrows the state matrix. Editing through this borrow does not
    /// change the step counter.
    pub fn state_mut(&mut self) -> &mut Array2<f32> {
        &mut self.state
    }

    /// Number of `update` calls since construction or the last `reset`.
    pub fn step_count(&self) -> usize {
        self.step_count
    }

    /// Copies row `idx` of the state matrix into a new owned 1-D array.
    ///
    /// # Panics
    ///
    /// Panics if `idx` is out of bounds for the state's first axis.
    pub fn get_row(&self, idx: usize) -> Array1<f32> {
        let row_view = self.state.row(idx);
        row_view.to_owned()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Construction yields the requested shape and a zero step count.
    #[test]
    fn test_hidden_state() {
        let mut state = HiddenState::new(256, 16);
        assert_eq!(state.step_count(), 0);
        assert_eq!(state.state().shape(), &[256, 16]);
        assert!(state.state().iter().all(|&x| x == 0.0));
        state.reset();
        assert_eq!(state.step_count(), 0);
    }

    /// `update` replaces the matrix and bumps the counter once per call.
    #[test]
    fn update_replaces_state_and_counts_steps() {
        let mut state = HiddenState::new(2, 3);
        state.update(Array2::from_elem((2, 3), 1.5));
        state.update(Array2::from_elem((2, 3), 2.5));
        assert_eq!(state.step_count(), 2);
        assert!(state.state().iter().all(|&x| x == 2.5));
        assert_eq!(state.get_row(1).to_vec(), vec![2.5_f32; 3]);
    }

    /// `reset` must actually clear non-zero data, not just leave a fresh state alone.
    #[test]
    fn reset_zeroes_state_and_conv_history() {
        let mut state = HiddenState::new(2, 2);
        state.update(Array2::from_elem((2, 2), 4.0));
        state.state_mut()[[0, 1]] = 7.0;
        state.set_conv_history(vec![vec![1.0, 2.0], vec![3.0]]);

        state.reset();

        assert_eq!(state.step_count(), 0);
        assert!(state.state().iter().all(|&x| x == 0.0));
        // History buffers are kept with their original lengths, but zeroed.
        assert_eq!(
            state.conv_history(),
            Some(&vec![vec![0.0, 0.0], vec![0.0]])
        );
    }

    /// `take_conv_history` moves the buffers out and leaves `None` behind.
    #[test]
    fn take_conv_history_leaves_none() {
        let mut state = HiddenState::new(1, 1);
        assert!(state.conv_history().is_none());

        state.set_conv_history(vec![vec![1.0]]);
        assert_eq!(state.take_conv_history(), Some(vec![vec![1.0]]));
        assert!(state.conv_history().is_none());

        // Resetting without history must not create one.
        state.reset();
        assert!(state.take_conv_history().is_none());
    }
}