use super::linear::Linear;
use super::module::Module;
use crate::autograd::Tensor;
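/// A single-layer Gated Recurrent Unit (GRU).
///
/// For input `x_t` and previous hidden state `h_{t-1}`, one step computes
/// (with `*` denoting element-wise multiplication):
///
/// ```text
/// r_t = sigmoid(W_ir x_t + W_hr h_{t-1})        // reset gate
/// z_t = sigmoid(W_iz x_t + W_hz h_{t-1})        // update gate
/// n_t = tanh(W_in x_t + r_t * (W_hn h_{t-1}))   // candidate state
/// h_t = (1 - z_t) * n_t + z_t * h_{t-1}
/// ```
///
/// where each `W_*` is one of the `Linear` sub-layers below. A usage sketch
/// (hypothetical shapes, not a compiled doctest):
///
/// ```ignore
/// let gru = GRU::new(8, 16);
/// let x = Tensor::zeros(&[2, 5, 8]); // [batch, seq_len, input_size]
/// let (output, h_n) = gru.forward_sequence(&x, None);
/// // output: [2, 5, 16], h_n: [2, 16]
/// ```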
pub struct GRU {
input_size: usize,
hidden_size: usize,
w_ir: Linear,
w_hr: Linear,
w_iz: Linear,
w_hz: Linear,
w_in: Linear,
w_hn: Linear,
training: bool,
}
impl GRU {
#[must_use]
pub fn new(input_size: usize, hidden_size: usize) -> Self {
Self {
input_size,
hidden_size,
w_ir: Linear::new(input_size, hidden_size),
w_hr: Linear::new(hidden_size, hidden_size),
w_iz: Linear::new(input_size, hidden_size),
w_hz: Linear::new(hidden_size, hidden_size),
w_in: Linear::new(input_size, hidden_size),
w_hn: Linear::new(hidden_size, hidden_size),
training: true,
}
}
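    /// Runs one GRU step: `x` is `[batch, input_size]` and `h` is
    /// `[batch, hidden_size]`; returns the next hidden state `h_t`.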
#[must_use]
pub fn forward_step(&self, x: &Tensor, h: &Tensor) -> Tensor {
let r = sigmoid_tensor(&add_tensors(&self.w_ir.forward(x), &self.w_hr.forward(h)));
let z = sigmoid_tensor(&add_tensors(&self.w_iz.forward(x), &self.w_hz.forward(h)));
let n = tanh_tensor(&add_tensors(
&self.w_in.forward(x),
&mul_tensors(&r, &self.w_hn.forward(h)),
));
let one_minus_z = sub_from_one(&z);
add_tensors(&mul_tensors(&one_minus_z, &n), &mul_tensors(&z, h))
}
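    /// Unrolls the GRU over a `[batch, seq_len, input_size]` input and
    /// returns `(output, h_n)`: all hidden states packed as
    /// `[batch, seq_len, hidden_size]`, plus the final state as
    /// `[batch, hidden_size]`. With `h0 = None` the initial state is zeros.
    /// Note that `output` is rebuilt from raw data via `Tensor::new`, so any
    /// autograd tracking presumably flows through `h_n` rather than `output`.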
#[must_use]
pub fn forward_sequence(&self, x: &Tensor, h0: Option<&Tensor>) -> (Tensor, Tensor) {
let batch = x.shape()[0];
let seq_len = x.shape()[1];
let mut h = match h0 {
Some(h) => h.clone(),
None => Tensor::zeros(&[batch, self.hidden_size]),
};
        // Scatter each step's hidden state into its [batch, seq_len, hidden] slot;
        // appending `h.data()` per step would yield a time-major layout that
        // mismatches the declared shape whenever batch > 1.
        let mut outputs = vec![0.0; batch * seq_len * self.hidden_size];
        for t in 0..seq_len {
            let xt = slice_timestep(x, t);
            h = self.forward_step(&xt, &h);
            for b in 0..batch {
                let (src, dst) = (b * self.hidden_size, (b * seq_len + t) * self.hidden_size);
                outputs[dst..dst + self.hidden_size]
                    .copy_from_slice(&h.data()[src..src + self.hidden_size]);
            }
        }
        let output = Tensor::new(&outputs, &[batch, seq_len, self.hidden_size]);
(output, h)
}
#[must_use]
pub fn input_size(&self) -> usize {
self.input_size
}
#[must_use]
pub fn hidden_size(&self) -> usize {
self.hidden_size
}
}
impl Module for GRU {
fn forward(&self, input: &Tensor) -> Tensor {
let (output, _) = self.forward_sequence(input, None);
output
}
fn parameters(&self) -> Vec<&Tensor> {
let mut p = self.w_ir.parameters();
p.extend(self.w_hr.parameters());
p.extend(self.w_iz.parameters());
p.extend(self.w_hz.parameters());
p.extend(self.w_in.parameters());
p.extend(self.w_hn.parameters());
p
}
fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
let mut p = self.w_ir.parameters_mut();
p.extend(self.w_hr.parameters_mut());
p.extend(self.w_iz.parameters_mut());
p.extend(self.w_hz.parameters_mut());
p.extend(self.w_in.parameters_mut());
p.extend(self.w_hn.parameters_mut());
p
}
fn train(&mut self) {
self.training = true;
}
fn eval(&mut self) {
self.training = false;
}
fn training(&self) -> bool {
self.training
}
}
impl std::fmt::Debug for GRU {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("GRU")
.field("input_size", &self.input_size)
.field("hidden_size", &self.hidden_size)
.finish_non_exhaustive()
}
}
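// Thin wrappers over the functional ops so the gate arithmetic above reads
// close to the textbook equations.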
fn sigmoid_tensor(x: &Tensor) -> Tensor {
crate::nn::functional::sigmoid(x)
}
fn tanh_tensor(x: &Tensor) -> Tensor {
crate::nn::functional::tanh(x)
}
fn add_tensors(a: &Tensor, b: &Tensor) -> Tensor {
a.add(b)
}
/// Element-wise (Hadamard) product. The result is built from raw data, so
/// unlike `add_tensors` this does not go through a `Tensor` autograd op.
fn mul_tensors(a: &Tensor, b: &Tensor) -> Tensor {
    debug_assert_eq!(a.shape(), b.shape(), "element-wise mul needs matching shapes");
    let data: Vec<f32> = a
        .data()
        .iter()
        .zip(b.data())
        .map(|(&x, &y)| x * y)
        .collect();
    Tensor::new(&data, a.shape())
}
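/// Computes `1 - x` element-wise (the GRU's `(1 - z)` interpolation weight).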
fn sub_from_one(x: &Tensor) -> Tensor {
let data: Vec<f32> = x.data().iter().map(|&v| 1.0 - v).collect();
Tensor::new(&data, x.shape())
}
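/// Extracts timestep `t` from a `[batch, seq_len, features]` tensor as a
/// `[batch, features]` tensor, i.e. `x[:, t, :]`.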
fn slice_timestep(x: &Tensor, t: usize) -> Tensor {
let batch = x.shape()[0];
let input_size = x.shape()[2];
let offset = t * input_size;
let mut data = Vec::with_capacity(batch * input_size);
for b in 0..batch {
let start = b * x.shape()[1] * input_size + offset;
data.extend_from_slice(&x.data()[start..start + input_size]);
}
Tensor::new(&data, &[batch, input_size])
}
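/// Two independent `GRU`s run over the same sequence, one left-to-right and
/// one right-to-left; their per-step outputs are concatenated along the
/// feature axis, so each step carries `2 * hidden_size` features. A sketch
/// (hypothetical shapes, not a compiled doctest):
///
/// ```ignore
/// let birnn = Bidirectional::new(8, 16);
/// let x = Tensor::zeros(&[2, 5, 8]);
/// let (output, fwd_h, bwd_h) = birnn.forward_sequence(&x);
/// // output: [2, 5, 32]
/// ```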
pub struct Bidirectional {
forward_rnn: GRU,
backward_rnn: GRU,
input_size: usize,
hidden_size: usize,
training: bool,
}
impl Bidirectional {
#[must_use]
pub fn new(input_size: usize, hidden_size: usize) -> Self {
Self {
forward_rnn: GRU::new(input_size, hidden_size),
backward_rnn: GRU::new(input_size, hidden_size),
input_size,
hidden_size,
training: true,
}
}
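    /// Runs both directions over `x` (`[batch, seq_len, input_size]`) and
    /// returns `(output, fwd_h, bwd_h)`. `output` is
    /// `[batch, seq_len, 2 * hidden_size]`; `bwd_h` is the backward GRU's
    /// final state, i.e. its summary of the sequence read from the end.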
#[must_use]
pub fn forward_sequence(&self, x: &Tensor) -> (Tensor, Tensor, Tensor) {
let batch = x.shape()[0];
let seq_len = x.shape()[1];
let (fwd_out, fwd_h) = self.forward_rnn.forward_sequence(x, None);
let x_rev = reverse_sequence(x);
let (bwd_out_rev, bwd_h) = self.backward_rnn.forward_sequence(&x_rev, None);
let bwd_out = reverse_sequence(&bwd_out_rev);
let output = concat_last_dim(&fwd_out, &bwd_out, batch, seq_len, self.hidden_size);
(output, fwd_h, bwd_h)
}
#[must_use]
pub fn output_size(&self) -> usize {
self.hidden_size * 2
}
#[must_use]
pub fn hidden_size(&self) -> usize {
self.hidden_size
}
}
impl Module for Bidirectional {
fn forward(&self, input: &Tensor) -> Tensor {
let (output, _, _) = self.forward_sequence(input);
output
}
fn parameters(&self) -> Vec<&Tensor> {
let mut p = self.forward_rnn.parameters();
p.extend(self.backward_rnn.parameters());
p
}
fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
let mut p = self.forward_rnn.parameters_mut();
p.extend(self.backward_rnn.parameters_mut());
p
}
fn train(&mut self) {
self.training = true;
self.forward_rnn.train();
self.backward_rnn.train();
}
fn eval(&mut self) {
self.training = false;
self.forward_rnn.eval();
self.backward_rnn.eval();
}
fn training(&self) -> bool {
self.training
}
}
impl std::fmt::Debug for Bidirectional {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Bidirectional")
.field("input_size", &self.input_size)
.field("hidden_size", &self.hidden_size)
.finish_non_exhaustive()
}
}
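/// Reverses a `[batch, seq_len, features]` tensor along the time axis.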
fn reverse_sequence(x: &Tensor) -> Tensor {
let (batch, seq_len, features) = (x.shape()[0], x.shape()[1], x.shape()[2]);
let mut data = vec![0.0; batch * seq_len * features];
for b in 0..batch {
for t in 0..seq_len {
let src = b * seq_len * features + t * features;
let dst = b * seq_len * features + (seq_len - 1 - t) * features;
data[dst..dst + features].copy_from_slice(&x.data()[src..src + features]);
}
}
Tensor::new(&data, &[batch, seq_len, features])
}
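/// Concatenates two `[batch, seq_len, hidden]` tensors along the feature
/// axis into a single `[batch, seq_len, 2 * hidden]` tensor.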
fn concat_last_dim(a: &Tensor, b: &Tensor, batch: usize, seq_len: usize, hidden: usize) -> Tensor {
let out_size = hidden * 2;
let mut data = vec![0.0; batch * seq_len * out_size];
for ba in 0..batch {
for t in 0..seq_len {
let dst = ba * seq_len * out_size + t * out_size;
            // `a` and `b` share the same [batch, seq_len, hidden] layout, so one
            // source offset serves both.
            let src = ba * seq_len * hidden + t * hidden;
            data[dst..dst + hidden].copy_from_slice(&a.data()[src..src + hidden]);
            data[dst + hidden..dst + out_size].copy_from_slice(&b.data()[src..src + hidden]);
}
}
Tensor::new(&data, &[batch, seq_len, out_size])
}
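/// A single-layer Long Short-Term Memory (LSTM) network.
///
/// One step computes (with `*` denoting element-wise multiplication):
///
/// ```text
/// f_t = sigmoid(W_if x_t + W_hf h_{t-1})   // forget gate
/// i_t = sigmoid(W_ii x_t + W_hi h_{t-1})   // input gate
/// g_t = tanh(W_ig x_t + W_hg h_{t-1})      // candidate cell update
/// o_t = sigmoid(W_io x_t + W_ho h_{t-1})   // output gate
/// c_t = f_t * c_{t-1} + i_t * g_t
/// h_t = o_t * tanh(c_t)
/// ```
///
/// A usage sketch (hypothetical shapes, not a compiled doctest):
///
/// ```ignore
/// let lstm = LSTM::new(8, 16);
/// let x = Tensor::zeros(&[2, 5, 8]);
/// let (output, h_n, c_n) = lstm.forward_sequence(&x, None, None);
/// // output: [2, 5, 16]; h_n and c_n: [2, 16]
/// ```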
pub struct LSTM {
input_size: usize,
hidden_size: usize,
w_if: Linear,
w_hf: Linear,
w_ii: Linear,
w_hi: Linear,
w_ig: Linear,
w_hg: Linear,
w_io: Linear,
w_ho: Linear,
training: bool,
}
impl LSTM {
#[must_use]
pub fn new(input_size: usize, hidden_size: usize) -> Self {
Self {
input_size,
hidden_size,
w_if: Linear::new(input_size, hidden_size),
w_hf: Linear::new(hidden_size, hidden_size),
w_ii: Linear::new(input_size, hidden_size),
w_hi: Linear::new(hidden_size, hidden_size),
w_ig: Linear::new(input_size, hidden_size),
w_hg: Linear::new(hidden_size, hidden_size),
w_io: Linear::new(input_size, hidden_size),
w_ho: Linear::new(hidden_size, hidden_size),
training: true,
}
}
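    /// Runs one LSTM step on `x` (`[batch, input_size]`) with previous hidden
    /// state `h` and cell state `c`; returns `(h_t, c_t)`.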
#[must_use]
pub fn forward_step(&self, x: &Tensor, h: &Tensor, c: &Tensor) -> (Tensor, Tensor) {
let f = sigmoid_tensor(&add_tensors(&self.w_if.forward(x), &self.w_hf.forward(h)));
let i = sigmoid_tensor(&add_tensors(&self.w_ii.forward(x), &self.w_hi.forward(h)));
let g = tanh_tensor(&add_tensors(&self.w_ig.forward(x), &self.w_hg.forward(h)));
let o = sigmoid_tensor(&add_tensors(&self.w_io.forward(x), &self.w_ho.forward(h)));
let c_new = add_tensors(&mul_tensors(&f, c), &mul_tensors(&i, &g));
let h_new = mul_tensors(&o, &tanh_tensor(&c_new));
(h_new, c_new)
}
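    /// Unrolls the LSTM over a `[batch, seq_len, input_size]` input and
    /// returns `(output, h_n, c_n)`. Missing initial states default to zeros.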
#[must_use]
pub fn forward_sequence(
&self,
x: &Tensor,
h0: Option<&Tensor>,
c0: Option<&Tensor>,
) -> (Tensor, Tensor, Tensor) {
let batch = x.shape()[0];
let seq_len = x.shape()[1];
let mut h = match h0 {
Some(h) => h.clone(),
None => Tensor::zeros(&[batch, self.hidden_size]),
};
let mut c = match c0 {
Some(c) => c.clone(),
None => Tensor::zeros(&[batch, self.hidden_size]),
};
        // As in `GRU::forward_sequence`, write each step's hidden state straight
        // into its [batch, seq_len, hidden] slot to keep the layout batch-major.
        let mut outputs = vec![0.0; batch * seq_len * self.hidden_size];
        for t in 0..seq_len {
            let xt = slice_timestep(x, t);
            let (h_new, c_new) = self.forward_step(&xt, &h, &c);
            h = h_new;
            c = c_new;
            for b in 0..batch {
                let (src, dst) = (b * self.hidden_size, (b * seq_len + t) * self.hidden_size);
                outputs[dst..dst + self.hidden_size]
                    .copy_from_slice(&h.data()[src..src + self.hidden_size]);
            }
        }
        let output = Tensor::new(&outputs, &[batch, seq_len, self.hidden_size]);
(output, h, c)
}
#[must_use]
pub fn input_size(&self) -> usize {
self.input_size
}
#[must_use]
pub fn hidden_size(&self) -> usize {
self.hidden_size
}
}
impl Module for LSTM {
fn forward(&self, input: &Tensor) -> Tensor {
let (output, _, _) = self.forward_sequence(input, None, None);
output
}
fn parameters(&self) -> Vec<&Tensor> {
let mut p = self.w_if.parameters();
p.extend(self.w_hf.parameters());
p.extend(self.w_ii.parameters());
p.extend(self.w_hi.parameters());
p.extend(self.w_ig.parameters());
p.extend(self.w_hg.parameters());
p.extend(self.w_io.parameters());
p.extend(self.w_ho.parameters());
p
}
fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
let mut p = self.w_if.parameters_mut();
p.extend(self.w_hf.parameters_mut());
p.extend(self.w_ii.parameters_mut());
p.extend(self.w_hi.parameters_mut());
p.extend(self.w_ig.parameters_mut());
p.extend(self.w_hg.parameters_mut());
p.extend(self.w_io.parameters_mut());
p.extend(self.w_ho.parameters_mut());
p
}
fn train(&mut self) {
self.training = true;
}
fn eval(&mut self) {
self.training = false;
}
fn training(&self) -> bool {
self.training
}
}
impl std::fmt::Debug for LSTM {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("LSTM")
.field("input_size", &self.input_size)
.field("hidden_size", &self.hidden_size)
.finish_non_exhaustive()
}
}
#[cfg(test)]
#[path = "rnn_tests.rs"]
mod tests;