use zyx::{DType, Tensor, ZyxError};
use zyx_derive::Module;
/// A single-step LSTM cell.
///
/// Holds the input-to-hidden and hidden-to-hidden gate weights, the two
/// optional bias vectors, and the hidden-state width used to split the
/// stacked gate activations.
#[derive(Debug, Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct LSTMCell {
    /// Input-to-hidden weights; created with shape `[4 * hidden_size, input_size]`.
    w_ih: Tensor,
    /// Hidden-to-hidden weights; created with shape `[4 * hidden_size, hidden_size]`.
    w_hh: Tensor,
    /// Optional bias added to the gate pre-activations; `None` when built without bias.
    b_ih: Option<Tensor>,
    /// Second optional bias added to the gate pre-activations; `None` when built without bias.
    b_hh: Option<Tensor>,
    /// Width of the hidden state, i.e. the size of each of the four gate slices.
    hidden_size: u64,
}
impl LSTMCell {
    /// Builds a cell that maps `input_size` features onto a state of width
    /// `hidden_size`.
    ///
    /// Both weight matrices are sampled with `Tensor::uniform` over `-k..k`,
    /// where `k = 1 / sqrt(hidden_size)`, and then cast to `dtype`
    /// (`DType::F32` when `None` is passed). When `bias` is `true` the two
    /// bias vectors start as zeros of length `4 * hidden_size`; otherwise no
    /// bias tensors are stored.
    ///
    /// # Errors
    ///
    /// Returns any `ZyxError` raised while sampling the weight tensors.
    pub fn new(
        input_size: u64,
        hidden_size: u64,
        bias: bool,
        dtype: Option<DType>,
    ) -> Result<Self, ZyxError> {
        let dtype = dtype.unwrap_or(DType::F32);
        // The four gates are stacked along the first dimension of every parameter.
        let gate_rows = 4 * hidden_size;
        let bound = 1.0 / (hidden_size as f32).sqrt();

        // Sampling order (input weights first, then recurrent weights) is kept
        // so that seeded initialisation yields the same parameters.
        let w_ih = Tensor::uniform([gate_rows, input_size], -bound..bound)?.cast(dtype);
        let w_hh = Tensor::uniform([gate_rows, hidden_size], -bound..bound)?.cast(dtype);

        let zero_bias = || bias.then(|| Tensor::zeros([gate_rows], dtype));
        let b_ih = zero_bias();
        let b_hh = zero_bias();

        Ok(Self {
            w_ih,
            w_hh,
            b_ih,
            b_hh,
            hidden_size,
        })
    }

    /// Advances the cell by one step and returns `(h_next, c_next)`.
    ///
    /// The gate pre-activations are `x · w_ihᵀ + h · w_hhᵀ` plus whichever
    /// biases are present. They are split along axis 1 into four consecutive
    /// slices of width `hidden_size`, in the order input gate, forget gate,
    /// candidate, output gate. Sigmoid is applied to the three gates and tanh
    /// to the candidate. The new state is
    /// `c_next = f * c + i * g` and `h_next = o * tanh(c_next)`.
    ///
    /// NOTE(review): slicing on axis 1 suggests `x`, `h` and `c` are expected
    /// to be batched 2-D tensors (`[batch, features]`) — confirm with callers.
    ///
    /// # Errors
    ///
    /// Returns any `ZyxError` raised by the matrix multiplications or by
    /// slicing the gate tensor.
    pub fn forward(
        &self,
        x: &Tensor,
        h: &Tensor,
        c: &Tensor,
    ) -> Result<(Tensor, Tensor), ZyxError> {
        let from_input = x.matmul(&self.w_ih.t())?;
        let from_hidden = h.matmul(&self.w_hh.t())?;
        let mut pre_act = from_input + from_hidden;

        // Add the input bias first, then the hidden bias, skipping absent ones.
        for bias in self.b_ih.iter().chain(self.b_hh.iter()) {
            pre_act = &pre_act + bias;
        }

        let width = self.hidden_size;
        // Extracts the `index`-th gate slice from the stacked pre-activations.
        let gate_slice = |index: u64| pre_act.narrow(1, index * width, width);

        let input_gate = gate_slice(0)?.sigmoid();
        let forget_gate = gate_slice(1)?.sigmoid();
        let candidate = gate_slice(2)?.tanh();
        let output_gate = gate_slice(3)?.sigmoid();

        let retained = &forget_gate * c;
        let written = &input_gate * &candidate;
        let c_next = retained + written;
        let h_next = &output_gate * c_next.tanh();

        Ok((h_next, c_next))
    }
}