use crate as burn;

use crate::config::Config;
use crate::module::Module;
use crate::nn::rnn::gate_controller;
use crate::nn::Initializer;
use crate::nn::LinearConfig;
use crate::tensor::backend::Backend;
use crate::tensor::Tensor;
use burn_tensor::activation;

use super::gate_controller::GateController;
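
/// The configuration for a [gru](Gru) module.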
#[derive(Config)]
pub struct GruConfig {
    /// The size of the input features.
    pub d_input: usize,
    /// The size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the Gru transformation.
    pub bias: bool,
    /// The initializer used for the gate weights. Defaults to Xavier normal.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
}
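
/// The Gru module. This implementation is for a unidirectional, stateless, Gru:
/// the hidden state is passed in and returned explicitly rather than stored on the module.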
#[derive(Module, Debug)]
pub struct Gru<B: Backend> {
    update_gate: GateController<B>,
    reset_gate: GateController<B>,
    new_gate: GateController<B>,
    d_hidden: usize,
}

impl GruConfig {
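    /// Initialize a new [gru](Gru) module.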
    pub fn init<B: Backend>(&self) -> Gru<B> {
        let d_output = self.d_hidden;

        let update_gate = gate_controller::GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
        );
        let reset_gate = gate_controller::GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
        );
        let new_gate = gate_controller::GateController::new(
            self.d_input,
            d_output,
            self.bias,
            self.initializer.clone(),
        );

        Gru {
            update_gate,
            reset_gate,
            new_gate,
            d_hidden: self.d_hidden,
        }
    }
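
    /// Initialize a new [gru](Gru) module with a [record](GruRecord).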
    pub fn init_with<B: Backend>(self, record: GruRecord<B>) -> Gru<B> {
        // All three gates share the same linear dimensions, so one config suffices.
        let linear_config = LinearConfig {
            d_input: self.d_input,
            d_output: self.d_hidden,
            bias: self.bias,
            initializer: self.initializer.clone(),
        };

        Gru {
            update_gate: gate_controller::GateController::new_with(
                &linear_config,
                record.update_gate,
            ),
            reset_gate: gate_controller::GateController::new_with(
                &linear_config,
                record.reset_gate,
            ),
            new_gate: gate_controller::GateController::new_with(&linear_config, record.new_gate),
            d_hidden: self.d_hidden,
        }
    }
}

impl<B: Backend> Gru<B> {
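    /// Applies the forward pass on the input tensor. This GRU implementation
    /// returns a hidden state for each element in the sequence (i.e., across `seq_length`).
    ///
    /// Parameters:
    ///     batched_input: `[batch_size, sequence_length, input_size]`.
    ///     state: an optional initial hidden state of shape
    ///            `[batch_size, sequence_length, hidden_size]`; its slice at `t = 0`
    ///            seeds the recurrence. If `None`, zeros are used.
    ///
    /// Returns the hidden state tensor of shape `[batch_size, sequence_length, hidden_size]`.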
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<Tensor<B, 3>>,
    ) -> Tensor<B, 3> {
        let [batch_size, seq_length, _] = batched_input.shape().dims;

        let mut hidden_state = match state {
            Some(state) => state,
            None => Tensor::zeros([batch_size, seq_length, self.d_hidden]),
        };

        // Seed the recurrence with the state at t = 0 (provided, or zeros). Iterating
        // over a pre-loop clone of `hidden_state` instead would always yield the
        // initial values, because `slice_assign` below returns a new tensor and never
        // updates the cloned iterator, which would break the recurrence across steps.
        let mut hidden_t: Tensor<B, 2> = hidden_state
            .clone()
            .slice([0..batch_size, 0..1, 0..self.d_hidden])
            .squeeze(1);

        for (t, input_t) in batched_input.iter_dim(1).enumerate() {
            let input_t: Tensor<B, 2> = input_t.squeeze(1);

            // Update gate: z_t = sigmoid(Wx_z * x_t + Wh_z * h_{t-1} + b_z)
            let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, &self.update_gate);
            let update_values = activation::sigmoid(biased_ug_input_sum);

            // Reset gate: r_t = sigmoid(Wx_r * x_t + Wh_r * h_{t-1} + b_r)
            let biased_rg_input_sum = self.gate_product(&input_t, &hidden_t, &self.reset_gate);
            let reset_values = activation::sigmoid(biased_rg_input_sum);
            let reset_t = hidden_t.clone().mul(reset_values);

            // Candidate state: n_t = tanh(Wx_n * x_t + Wh_n * (r_t * h_{t-1}) + b_n)
            let biased_ng_input_sum = self.gate_product(&input_t, &reset_t, &self.new_gate);
            let candidate_state = biased_ng_input_sum.tanh();

            // h_t = (1 - z_t) * n_t + z_t * h_{t-1}
            let state_vector = candidate_state
                .mul(update_values.clone().sub_scalar(1).mul_scalar(-1))
                + update_values.mul(hidden_t.clone());

            // Carry h_t into the next step of the recurrence.
            hidden_t = state_vector.clone();

            // Unsqueeze h_t along the time dimension and write it to output slot t.
            let current_shape = state_vector.shape().dims;
            let unsqueezed_shape = [current_shape[0], 1, current_shape[1]];
            let reshaped_state_vector = state_vector.reshape(unsqueezed_shape);

            hidden_state = hidden_state.slice_assign(
                [0..batch_size, t..(t + 1), 0..self.d_hidden],
                reshaped_state_vector,
            );
        }

        hidden_state
    }
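
    /// Helper to compute the weighted product for a gate, adding biases when present.
    ///
    /// Mathematically, performs `Wx * X + Wh * H + b`, where:
    ///     Wx = weight matrix for the connection to input vector X
    ///     Wh = weight matrix for the connection to hidden state H
    ///     X  = input vector
    ///     H  = hidden state
    ///     b  = bias terms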
    fn gate_product(
        &self,
        input: &Tensor<B, 2>,
        hidden: &Tensor<B, 2>,
        gate: &GateController<B>,
    ) -> Tensor<B, 2> {
        let input_product = input.clone().matmul(gate.input_transform.weight.val());
        let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val());

        let input_bias = gate
            .input_transform
            .bias
            .as_ref()
            .map(|bias_param| bias_param.val());
        let hidden_bias = gate
            .hidden_transform
            .bias
            .as_ref()
            .map(|bias_param| bias_param.val());

        // Biases are 1-D, so unsqueeze them to broadcast over the batch dimension.
        match (input_bias, hidden_bias) {
            (Some(input_bias), Some(hidden_bias)) => {
                input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze()
            }
            (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product,
            (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(),
            (None, None) => input_product + hidden_product,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{module::Param, nn::LinearRecord, TestBackend};
    use burn_tensor::{Data, Distribution};

    #[test]
    fn tests_forward_single_input_single_feature() {
        TestBackend::seed(0);
        let config = GruConfig::new(1, 1, false);
        let mut gru = config.init::<TestBackend>();
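
        // Builds a gate whose input and hidden transforms share the given scalar
        // weight and bias, so the forward pass is easy to verify by hand.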
        fn create_gate_controller(
            weights: f32,
            biases: f32,
            d_input: usize,
            d_output: usize,
            bias: bool,
            initializer: Initializer,
        ) -> GateController<TestBackend> {
            let record = LinearRecord {
                weight: Param::from(Tensor::from_data(Data::from([[weights]]))),
                bias: Some(Param::from(Tensor::from_data(Data::from([biases])))),
            };
            gate_controller::GateController::create_with_weights(
                d_input,
                d_output,
                bias,
                initializer,
                record.clone(),
                record,
            )
        }

        gru.update_gate = create_gate_controller(
            0.5,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
        );
        gru.reset_gate = create_gate_controller(
            0.6,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
        );
        gru.new_gate = create_gate_controller(
            0.7,
            0.0,
            1,
            1,
            false,
            Initializer::XavierNormal { gain: 1.0 },
        );
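
        // With h_0 = 0 and x = 0.1:
        //   z = sigmoid(0.5 * 0.1) ≈ 0.5125
        //   n = tanh(0.7 * 0.1) ≈ 0.0699  (the reset gate has no effect since h_0 = 0)
        //   h_1 = (1 - z) * n ≈ 0.034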
        let input = Tensor::<TestBackend, 3>::from_data(Data::from([[[0.1]]]));

        let state = gru.forward(input, None);

        let output = state.select(0, Tensor::arange(0..1)).squeeze(0);
        output.to_data().assert_approx_eq(&Data::from([[0.034]]), 3);
    }

    #[test]
    fn test_batched_forward_pass() {
        let gru = GruConfig::new(64, 1024, true).init::<TestBackend>();
        let batched_input = Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default);

        let hidden_state = gru.forward(batched_input, None);

        assert_eq!(hidden_state.shape().dims, [8, 10, 1024]);
    }
}