use zyx::{DType, Tensor, ZyxError};
use zyx_derive::Module;
/// A single gated recurrent unit (GRU) step.
///
/// Stores the stacked input-to-hidden and hidden-to-hidden projections for the three
/// GRU gate chunks, plus optional biases. Build one with [`GRUCell::new`] and advance
/// it by one time step with [`GRUCell::forward`].
#[derive(Debug, Module)]
#[cfg_attr(feature = "py", pyo3::pyclass)]
pub struct GRUCell {
    /// Stacked input-to-hidden weights (`[3 * hidden_size, input_size]` as created by `new`).
    pub weight_ih: Tensor,
    /// Stacked hidden-to-hidden weights (`[3 * hidden_size, hidden_size]` as created by `new`).
    pub weight_hh: Tensor,
    /// Optional input-to-hidden bias (length `3 * hidden_size` as created by `new`).
    pub bias_ih: Option<Tensor>,
    /// Optional hidden-to-hidden bias (length `3 * hidden_size` as created by `new`).
    pub bias_hh: Option<Tensor>,
    /// Width of the hidden state; also the width of each stacked gate chunk.
    pub hidden_size: u64,
}
impl GRUCell {
    /// Creates a GRU cell with randomly initialized parameters.
    ///
    /// Every weight and bias is drawn with `Tensor::uniform` over the range `-k..k`,
    /// where `k = 1 / sqrt(hidden_size)`, and then cast to `dtype`. The biases are
    /// only created when `bias` is `true`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Tensor::uniform`.
    pub fn new(
        input_size: u64,
        hidden_size: u64,
        bias: bool,
        dtype: DType,
    ) -> Result<Self, ZyxError> {
        // Each projection stacks the three gate chunks along its rows.
        let gate_rows = 3 * hidden_size;
        let bound = (hidden_size as f32).sqrt().recip();

        // Keep this draw order (weight_ih, weight_hh, bias_ih, bias_hh): it decides
        // which random values each parameter receives.
        let weight_ih = Tensor::uniform([gate_rows, input_size], -bound..bound)?.cast(dtype);
        let weight_hh = Tensor::uniform([gate_rows, hidden_size], -bound..bound)?.cast(dtype);
        let bias_ih = bias
            .then(|| Tensor::uniform([gate_rows], -bound..bound))
            .transpose()?
            .map(|b| b.cast(dtype));
        let bias_hh = bias
            .then(|| Tensor::uniform([gate_rows], -bound..bound))
            .transpose()?
            .map(|b| b.cast(dtype));

        Ok(Self {
            weight_ih,
            weight_hh,
            bias_ih,
            bias_hh,
            hidden_size,
        })
    }

    /// Advances the cell by one time step and returns the next hidden state.
    ///
    /// With `σ` the logistic sigmoid and `⊙` elementwise multiplication:
    ///
    /// ```text
    /// z  = σ(x·W_izᵀ + b_iz + h·W_hzᵀ + b_hz)
    /// r  = σ(x·W_irᵀ + b_ir + h·W_hrᵀ + b_hr)
    /// n  = tanh(x·W_inᵀ + b_in + r ⊙ (h·W_hnᵀ + b_hn))
    /// h' = (1 - z) ⊙ n + z ⊙ h
    /// ```
    ///
    /// The gate chunks are read from columns `[z | r | n]` of the projected tensors.
    /// NOTE(review): PyTorch's `GRUCell` stacks its weights as `[r | z | n]`, so
    /// weights imported from PyTorch would need their first two chunks swapped.
    /// Confirm which layout is intended.
    ///
    /// Assumes `input` is `(batch, input_size)` and `hx` is `(batch, hidden_size)`.
    /// The column slicing below indexes the second axis. TODO: confirm for
    /// unbatched inputs.
    ///
    /// # Errors
    ///
    /// Propagates errors from `matmul`, `reshape` and `slice`, e.g. on shape mismatch.
    pub fn forward(&self, input: Tensor, hx: Tensor) -> Result<Tensor, ZyxError> {
        let width = self.hidden_size;

        // Project the input and the previous state. Each bias is reshaped to a single
        // row so it broadcasts over the batch.
        let from_input = {
            let projected = input.matmul(&self.weight_ih.t())?;
            match &self.bias_ih {
                Some(bias) => projected + bias.reshape([1, 3 * width])?,
                None => projected,
            }
        };
        let from_hidden = {
            let projected = hx.matmul(&self.weight_hh.t())?;
            match &self.bias_hh {
                Some(bias) => projected + bias.reshape([1, 3 * width])?,
                None => projected,
            }
        };

        // Update gate (first chunk) and reset gate (second chunk).
        let update = (from_input.slice((.., 0..width))? + from_hidden.slice((.., 0..width))?)
            .sigmoid();
        let reset = (from_input.slice((.., width..2 * width))?
            + from_hidden.slice((.., width..2 * width))?)
        .sigmoid();

        // Candidate state (third chunk). The reset gate scales only the hidden-side
        // term, including its bias.
        let candidate_x = from_input.slice((.., 2 * width..3 * width))?;
        let candidate_h = from_hidden.slice((.., 2 * width..3 * width))?;
        let candidate = (candidate_x + reset * candidate_h).tanh();

        // Blend the candidate with the previous state using the update gate.
        let keep_candidate = Tensor::ones_like(&update) - &update;
        Ok(keep_candidate * candidate + &update * hx)
    }
}