use zyx::{DType, Tensor, ZyxError};
use zyx_derive::Module;
/// A single Elman RNN cell: `h' = act(x·W_ihᵀ + b_ih + h·W_hhᵀ + b_hh)`,
/// where `act` is relu or tanh depending on `nonlinearity`.
#[derive(Debug, Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct RNNCell {
/// Input-to-hidden weights, shape `[hidden_size, input_size]`.
pub weight_ih: Tensor,
/// Hidden-to-hidden weights, shape `[hidden_size, hidden_size]`.
pub weight_hh: Tensor,
/// Input-to-hidden bias, shape `[hidden_size]`; `None` when built with `bias == false`.
pub bias_ih: Option<Tensor>,
/// Hidden-to-hidden bias, shape `[hidden_size]`; `None` when built with `bias == false`.
pub bias_hh: Option<Tensor>,
/// Number of features in the hidden state.
pub hidden_size: u64,
/// Activation selector: `"relu"` applies relu; any other value applies tanh.
nonlinearity: &'static str,
}
impl RNNCell {
    /// Creates a new Elman RNN cell.
    ///
    /// `weight_ih` has shape `[hidden_size, input_size]` and `weight_hh` has
    /// shape `[hidden_size, hidden_size]`; both are sampled uniformly from
    /// `(-scale, scale)` with `scale = gain / sqrt(hidden_size)`, then cast
    /// to `dtype` (default [`DType::F32`]). When `bias` is true, both bias
    /// vectors are zero-initialized with shape `[hidden_size]`.
    ///
    /// # Errors
    /// Propagates any [`ZyxError`] raised while sampling the weight tensors.
    pub fn new(
        input_size: u64,
        hidden_size: u64,
        bias: bool,
        nonlinearity: &'static str,
        dtype: Option<DType>,
    ) -> Result<RNNCell, ZyxError> {
        let dtype = dtype.unwrap_or(DType::F32);
        // NOTE(review): the relu gain here is 1/3 — the original spelled it
        // `1.0 / (3_u32.pow(2) as f32).sqrt()`, i.e. 1/sqrt(9), which equals
        // exactly 1.0 / 3.0 in f32. The conventional Kaiming gain for relu is
        // sqrt(2); confirm this value is intentional.
        let gain: f32 = if nonlinearity == "relu" { 1.0 / 3.0 } else { 1.0 };
        let scale = gain / (hidden_size as f32).sqrt();
        let weight_ih = Tensor::uniform([hidden_size, input_size], -scale..scale)?.cast(dtype);
        let weight_hh = Tensor::uniform([hidden_size, hidden_size], -scale..scale)?.cast(dtype);
        // `bool::then` replaces the duplicated if/else bias blocks.
        let bias_ih = bias.then(|| Tensor::zeros([hidden_size], dtype));
        let bias_hh = bias.then(|| Tensor::zeros([hidden_size], dtype));
        Ok(RNNCell {
            weight_ih,
            weight_hh,
            bias_ih,
            bias_hh,
            hidden_size,
            nonlinearity,
        })
    }

    /// Computes one step: `h' = act(x·W_ihᵀ + h·W_hhᵀ + b_ih + b_hh)`.
    ///
    /// `act` is relu when `nonlinearity == "relu"`, otherwise tanh (unknown
    /// strings silently fall back to tanh — matches the original behavior).
    ///
    /// # Errors
    /// Propagates any [`ZyxError`] from the two matrix multiplications.
    pub fn forward(&self, x: &Tensor, hx: &Tensor) -> Result<Tensor, ZyxError> {
        let mut h_new = x.matmul(&self.weight_ih.t())? + hx.matmul(&self.weight_hh.t())?;
        // Biases are added only when present; order (ih then hh) preserved.
        if let Some(b) = &self.bias_ih {
            h_new = h_new + b;
        }
        if let Some(b) = &self.bias_hh {
            h_new = h_new + b;
        }
        Ok(match self.nonlinearity {
            "relu" => h_new.relu(),
            _ => h_new.tanh(),
        })
    }
}