use burn_core as burn;
use super::gate_controller::GateController;
use crate::activation::{Activation, ActivationConfig};
use burn::config::Config;
use burn::module::Initializer;
use burn::module::Module;
use burn::module::{Content, DisplaySettings, ModuleDisplay};
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// Configuration to create a [Gru] module using the [init function](GruConfig::init).
#[derive(Config, Debug)]
pub struct GruConfig {
    /// Size of the input features (last dimension of the input tensor).
    pub d_input: usize,
    /// Size of the hidden state.
    pub d_hidden: usize,
    /// If a bias should be applied during the Gru transformation.
    pub bias: bool,
    /// If true, the reset gate is applied after the hidden-state matmul
    /// (the PyTorch/cuDNN formulation); otherwise the hidden state is scaled
    /// by the reset gate before the transform (the original formulation).
    #[config(default = "true")]
    pub reset_after: bool,
    /// Gru initializer.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
    /// Activation applied to the update and reset gates (sigmoid by default).
    #[config(default = "ActivationConfig::Sigmoid")]
    pub gate_activation: ActivationConfig,
    /// Activation applied to the candidate hidden state (tanh by default).
    #[config(default = "ActivationConfig::Tanh")]
    pub hidden_activation: ActivationConfig,
    /// Optional bound: when set, the hidden state is clamped to `[-clip, clip]`
    /// after every timestep.
    pub clip: Option<f64>,
}
/// The Gru (gated recurrent unit) module.
///
/// Should be created with [GruConfig]. The module is stateless: the hidden
/// state is supplied per call through [Gru::forward].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Gru<B: Backend> {
    /// Update gate (z) controller.
    pub update_gate: GateController<B>,
    /// Reset gate (r) controller.
    pub reset_gate: GateController<B>,
    /// New/candidate gate (n) controller.
    pub new_gate: GateController<B>,
    /// Size of the hidden state.
    pub d_hidden: usize,
    /// Whether the reset gate is applied after the hidden-state transform.
    pub reset_after: bool,
    /// Activation applied to the update and reset gates.
    pub gate_activation: Activation<B>,
    /// Activation applied to the candidate hidden state.
    pub hidden_activation: Activation<B>,
    /// Optional clamp bound for the hidden state.
    pub clip: Option<f64>,
}
impl<B: Backend> ModuleDisplay for Gru<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        let settings = DisplaySettings::new().with_new_line_after_attribute(false);
        settings.optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        // Recover the input size and bias flag from the update gate's
        // input projection (all gates share the same configuration).
        let input_transform = &self.update_gate.input_transform;
        let [d_input, _] = input_transform.weight.shape().dims();
        let has_bias = input_transform.bias.is_some();

        content
            .add("d_input", &d_input)
            .add("d_hidden", &self.d_hidden)
            .add("bias", &has_bias)
            .add("reset_after", &self.reset_after)
            .optional()
    }
}
impl GruConfig {
    /// Initialize a new [gru](Gru) module on the given device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> Gru<B> {
        let d_output = self.d_hidden;

        // All three gates share the same dimensions, bias flag and
        // initializer, so build them from a single closure instead of
        // repeating the construction three times.
        let new_gate_controller = || {
            GateController::new(
                self.d_input,
                d_output,
                self.bias,
                self.initializer.clone(),
                device,
            )
        };

        Gru {
            update_gate: new_gate_controller(),
            reset_gate: new_gate_controller(),
            new_gate: new_gate_controller(),
            d_hidden: self.d_hidden,
            reset_after: self.reset_after,
            gate_activation: self.gate_activation.init(device),
            hidden_activation: self.hidden_activation.init(device),
            clip: self.clip,
        }
    }
}
impl<B: Backend> Gru<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    /// - batched_input: `[batch_size, sequence_length, input_size]`
    /// - state: optional initial hidden state, `[batch_size, d_hidden]`
    /// - returns: hidden state per timestep, `[batch_size, sequence_length, d_hidden]`
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<Tensor<B, 2>>,
    ) -> Tensor<B, 3> {
        let device = batched_input.device();
        let [batch_size, seq_length, _] = batched_input.shape().dims();

        // Iterate over the sequence dimension (dim 1), pairing each timestep
        // with its index; only the stacked hidden states (tuple field 0) are
        // returned here, the final state is dropped.
        self.forward_iter(
            batched_input.iter_dim(1).zip(0..seq_length),
            state,
            batch_size,
            seq_length,
            &device,
        )
        .0
    }

    /// Runs the GRU recurrence over an explicit `(timestep, index)` iterator.
    ///
    /// Returns the per-timestep hidden states
    /// (`[batch_size, seq_length, d_hidden]`) together with the final hidden
    /// state (`[batch_size, d_hidden]`). Shared with [BiGru], which feeds the
    /// sequence in reverse while keeping the original timestep indices.
    pub(crate) fn forward_iter<I: Iterator<Item = (Tensor<B, 3>, usize)>>(
        &self,
        input_timestep_iter: I,
        state: Option<Tensor<B, 2>>,
        batch_size: usize,
        seq_length: usize,
        device: &B::Device,
    ) -> (Tensor<B, 3>, Tensor<B, 2>) {
        // Output buffer; every slot is overwritten by slice_assign below, so
        // uninitialized contents are never observed.
        let mut batched_hidden_state =
            Tensor::empty([batch_size, seq_length, self.d_hidden], device);

        // Initial hidden state defaults to zeros when none is provided.
        let mut hidden_t = match state {
            Some(state) => state,
            None => Tensor::zeros([batch_size, self.d_hidden], device),
        };

        for (input_t, t) in input_timestep_iter {
            let input_t = input_t.squeeze_dim(1);

            // z(t): update gate.
            let biased_ug_input_sum =
                self.gate_product(&input_t, &hidden_t, None, &self.update_gate);
            let update_values = self.gate_activation.forward(biased_ug_input_sum);

            // r(t): reset gate.
            let biased_rg_input_sum =
                self.gate_product(&input_t, &hidden_t, None, &self.reset_gate);
            let reset_values = self.gate_activation.forward(biased_rg_input_sum);

            // n(t): candidate state. With `reset_after`, the reset gate scales
            // the hidden-side product inside gate_product; otherwise the
            // hidden state itself is scaled before the transform.
            let biased_ng_input_sum = if self.reset_after {
                self.gate_product(&input_t, &hidden_t, Some(&reset_values), &self.new_gate)
            } else {
                let reset_t = hidden_t.clone().mul(reset_values);
                self.gate_product(&input_t, &reset_t, None, &self.new_gate)
            };
            let candidate_state = self.hidden_activation.forward(biased_ng_input_sum);

            // h(t) = n(t) * (1 - z(t)) + z(t) * h(t-1)
            let one_minus_z = update_values.clone().neg().add_scalar(1.0);
            hidden_t = candidate_state.mul(one_minus_z) + update_values.mul(hidden_t);

            // Optional clamping of the hidden state to [-clip, clip].
            if let Some(clip) = self.clip {
                hidden_t = hidden_t.clamp(-clip, clip);
            }

            // Store this timestep's hidden state at its original index `t`.
            let unsqueezed_hidden_state = hidden_t.clone().unsqueeze_dim(1);
            batched_hidden_state = batched_hidden_state.slice_assign(
                [0..batch_size, t..(t + 1), 0..self.d_hidden],
                unsqueezed_hidden_state,
            );
        }

        (batched_hidden_state, hidden_t)
    }

    /// Computes the pre-activation value of one gate:
    /// `W_i·x(t) + b_i + [r ∘] (W_h·h(t-1) + b_h)`, where each bias is added
    /// only if present and `r` (the reset values) multiplies the hidden-side
    /// term element-wise when provided (the `reset_after` formulation).
    fn gate_product(
        &self,
        input: &Tensor<B, 2>,
        hidden: &Tensor<B, 2>,
        reset: Option<&Tensor<B, 2>>,
        gate: &GateController<B>,
    ) -> Tensor<B, 2> {
        let input_product = input.clone().matmul(gate.input_transform.weight.val());
        let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val());

        let input_part = match &gate.input_transform.bias {
            Some(bias) => input_product + bias.val().unsqueeze(),
            None => input_product,
        };
        let hidden_part = match &gate.hidden_transform.bias {
            Some(bias) => hidden_product + bias.val().unsqueeze(),
            None => hidden_product,
        };

        match reset {
            Some(r) => input_part + r.clone().mul(hidden_part),
            None => input_part + hidden_part,
        }
    }
}
/// Configuration to create a [BiGru] module using the [init function](BiGruConfig::init).
#[derive(Config, Debug)]
pub struct BiGruConfig {
    /// Size of the input features.
    pub d_input: usize,
    /// Size of the hidden state, per direction.
    pub d_hidden: usize,
    /// If a bias should be applied during the BiGru transformation.
    pub bias: bool,
    /// If the reset gate is applied after the hidden-state matmul
    /// (PyTorch/cuDNN formulation).
    #[config(default = "true")]
    pub reset_after: bool,
    /// BiGru initializer.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
    /// If true, inputs/outputs are `[batch, seq, feature]`; otherwise
    /// `[seq, batch, feature]`.
    #[config(default = true)]
    pub batch_first: bool,
    /// Activation applied to the update and reset gates (sigmoid by default).
    #[config(default = "ActivationConfig::Sigmoid")]
    pub gate_activation: ActivationConfig,
    /// Activation applied to the candidate hidden state (tanh by default).
    #[config(default = "ActivationConfig::Tanh")]
    pub hidden_activation: ActivationConfig,
    /// Optional bound: when set, each direction clamps its hidden state to
    /// `[-clip, clip]` after every timestep.
    pub clip: Option<f64>,
}
/// The BiGru (bidirectional gated recurrent unit) module.
///
/// Runs two independent [Gru] modules over the sequence — one in original
/// order and one in reverse — and concatenates their hidden states.
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct BiGru<B: Backend> {
    /// Gru processing the sequence in its original order.
    pub forward: Gru<B>,
    /// Gru processing the sequence back-to-front.
    pub reverse: Gru<B>,
    /// Size of the hidden state, per direction.
    pub d_hidden: usize,
    /// Whether inputs/outputs use the `[batch, seq, feature]` layout.
    pub batch_first: bool,
}
impl<B: Backend> ModuleDisplay for BiGru<B> {
    fn custom_settings(&self) -> Option<DisplaySettings> {
        let settings = DisplaySettings::new().with_new_line_after_attribute(false);
        settings.optional()
    }

    fn custom_content(&self, content: Content) -> Option<Content> {
        // Both directions are configured identically; read the metadata off
        // the forward direction's update gate.
        let input_transform = &self.forward.update_gate.input_transform;
        let [d_input, _] = input_transform.weight.shape().dims();
        let has_bias = input_transform.bias.is_some();

        content
            .add("d_input", &d_input)
            .add("d_hidden", &self.d_hidden)
            .add("bias", &has_bias)
            .optional()
    }
}
impl BiGruConfig {
    /// Initialize a new [BiGru] module on the given device.
    pub fn init<B: Backend>(&self, device: &B::Device) -> BiGru<B> {
        // Both directions are built from the exact same unidirectional
        // configuration; only the traversal order differs at forward time.
        let make_direction = || {
            GruConfig::new(self.d_input, self.d_hidden, self.bias)
                .with_initializer(self.initializer.clone())
                .with_reset_after(self.reset_after)
                .with_gate_activation(self.gate_activation.clone())
                .with_hidden_activation(self.hidden_activation.clone())
                .with_clip(self.clip)
                .init(device)
        };

        BiGru {
            forward: make_direction(),
            reverse: make_direction(),
            d_hidden: self.d_hidden,
            batch_first: self.batch_first,
        }
    }
}
impl<B: Backend> BiGru<B> {
    /// Applies the forward pass on the input tensor.
    ///
    /// # Shapes
    /// - batched_input: `[batch_size, seq_length, input_size]` when
    ///   `batch_first`, otherwise `[seq_length, batch_size, input_size]`
    /// - state: optional initial hidden states, `[2, batch_size, d_hidden]`
    ///   (index 0 = forward direction, index 1 = reverse direction)
    /// - returns: `(output, state)` where output is
    ///   `[batch_size, seq_length, 2 * d_hidden]` (seq-first when
    ///   `batch_first` is false) and state is `[2, batch_size, d_hidden]`
    pub fn forward(
        &self,
        batched_input: Tensor<B, 3>,
        state: Option<Tensor<B, 3>>,
    ) -> (Tensor<B, 3>, Tensor<B, 3>) {
        // Normalize to batch-first layout internally.
        let batched_input = if self.batch_first {
            batched_input
        } else {
            batched_input.swap_dims(0, 1)
        };
        // Fix: `device()` only borrows the tensor, so the previous
        // `batched_input.clone().device()` cloned for nothing.
        let device = batched_input.device();
        let [batch_size, seq_length, _] = batched_input.shape().dims();

        // Split the optional [2, batch, d_hidden] state into one
        // [batch, d_hidden] initial state per direction.
        let [init_state_forward, init_state_reverse] = match state {
            Some(state) => {
                let hidden_state_forward = state
                    .clone()
                    .slice([0..1, 0..batch_size, 0..self.d_hidden])
                    .squeeze_dim(0);
                let hidden_state_reverse = state
                    .slice([1..2, 0..batch_size, 0..self.d_hidden])
                    .squeeze_dim(0);
                [Some(hidden_state_forward), Some(hidden_state_reverse)]
            }
            None => [None, None],
        };

        // Forward direction walks timesteps 0..seq_length.
        let (batched_hidden_state_forward, final_state_forward) = self.forward.forward_iter(
            batched_input.clone().iter_dim(1).zip(0..seq_length),
            init_state_forward,
            batch_size,
            seq_length,
            &device,
        );

        // Reverse direction walks the same timesteps back-to-front while
        // still writing each hidden state at its original index.
        let (batched_hidden_state_reverse, final_state_reverse) = self.reverse.forward_iter(
            batched_input.iter_dim(1).rev().zip((0..seq_length).rev()),
            init_state_reverse,
            batch_size,
            seq_length,
            &device,
        );

        // Concatenate both directions along the feature dimension.
        let output = Tensor::cat(
            [batched_hidden_state_forward, batched_hidden_state_reverse].to_vec(),
            2,
        );
        // Restore the caller's layout if it was seq-first.
        let output = if self.batch_first {
            output
        } else {
            output.swap_dims(0, 1)
        };

        let state = Tensor::stack([final_state_forward, final_state_reverse].to_vec(), 0);
        (output, state)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{LinearRecord, TestBackend};
use burn::module::Param;
use burn::tensor::{Distribution, TensorData};
use burn::tensor::{Tolerance, ops::FloatElem};
type FT = FloatElem<TestBackend>;
// Builds a 1-in/1-out Gru with fixed scalar weights (update 0.5, reset 0.6,
// new 0.7, all biases 0) so the tests below can compare against hand-computed
// expected values.
fn init_gru<B: Backend>(reset_after: bool, device: &B::Device) -> Gru<B> {
    // Creates a gate whose input and hidden transforms both use the same
    // scalar weight/bias pair.
    fn create_gate_controller<B: Backend>(
        weights: f32,
        biases: f32,
        d_input: usize,
        d_output: usize,
        bias: bool,
        initializer: Initializer,
        device: &B::Device,
    ) -> GateController<B> {
        let record_1 = LinearRecord {
            weight: Param::from_data(TensorData::from([[weights]]), device),
            bias: Some(Param::from_data(TensorData::from([biases]), device)),
        };
        let record_2 = LinearRecord {
            weight: Param::from_data(TensorData::from([[weights]]), device),
            bias: Some(Param::from_data(TensorData::from([biases]), device)),
        };
        GateController::create_with_weights(
            d_input,
            d_output,
            bias,
            initializer,
            record_1,
            record_2,
        )
    }

    let config = GruConfig::new(1, 1, false).with_reset_after(reset_after);
    let mut gru = config.init::<B>(device);

    // Replace the randomly initialized gates with the known weights.
    gru.update_gate = create_gate_controller(
        0.5,
        0.0,
        1,
        1,
        false,
        Initializer::XavierNormal { gain: 1.0 },
        device,
    );
    gru.reset_gate = create_gate_controller(
        0.6,
        0.0,
        1,
        1,
        false,
        Initializer::XavierNormal { gain: 1.0 },
        device,
    );
    gru.new_gate = create_gate_controller(
        0.7,
        0.0,
        1,
        1,
        false,
        Initializer::XavierNormal { gain: 1.0 },
        device,
    );
    gru
}
#[test]
fn tests_forward_single_input_single_feature() {
    let device = Default::default();
    TestBackend::seed(&device, 0);
    let mut gru = init_gru::<TestBackend>(false, &device);
    let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);
    let expected = TensorData::from([[0.034]]);
    let tolerance = Tolerance::default();

    // With all gate biases fixed at zero, both GRU variants (reset before and
    // after the hidden transform) must produce the same output.
    for reset_after in [false, true] {
        gru.reset_after = reset_after;
        let state = gru.forward(input.clone(), None);
        let output = state
            .select(0, Tensor::arange(0..1, &device))
            .squeeze_dim::<2>(0);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);
    }
}
#[test]
fn tests_forward_seq_len_3() {
    let device = Default::default();
    TestBackend::seed(&device, 0);
    let mut gru = init_gru::<TestBackend>(true, &device);
    let input =
        Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1], [0.2], [0.3]]]), &device);
    let expected = TensorData::from([[0.0341], [0.0894], [0.1575]]);
    let tolerance = Tolerance::default();

    // Zero biases make both GRU variants agree over the whole sequence.
    for reset_after in [true, false] {
        gru.reset_after = reset_after;
        let output = gru
            .forward(input.clone(), None)
            .select(0, Tensor::arange(0..1, &device))
            .squeeze_dim::<2>(0);
        output
            .to_data()
            .assert_approx_eq::<FT>(&expected, tolerance);
    }
}
#[test]
fn test_batched_forward_pass() {
    let device = Default::default();
    let gru = GruConfig::new(64, 1024, true).init::<TestBackend>(&device);
    let input = Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);

    let hidden_states = gru.forward(input, None);

    // One hidden vector per (batch, timestep) pair.
    assert_eq!(&*hidden_states.shape(), [8, 10, 1024]);
}
#[test]
fn display() {
    let layer = GruConfig::new(2, 8, true).init::<TestBackend>(&Default::default());
    let rendered = alloc::format!("{layer}");
    assert_eq!(
        rendered,
        "Gru {d_input: 2, d_hidden: 8, bias: true, reset_after: true, params: 288}"
    );
}
#[test]
fn test_bigru_batched_forward_pass() {
    let device = Default::default();
    let bigru = BiGruConfig::new(64, 1024, true).init::<TestBackend>(&device);
    let input = Tensor::<TestBackend, 3>::random([8, 10, 64], Distribution::Default, &device);

    let (output, state) = bigru.forward(input, None);

    // Output concatenates both directions along the feature axis.
    assert_eq!(&*output.shape(), [8, 10, 2048]);
    // State stacks the two final hidden states.
    assert_eq!(&*state.shape(), [2, 8, 1024]);
}
#[test]
fn test_bigru_with_initial_state() {
    let device = Default::default();
    let bigru = BiGruConfig::new(32, 64, true).init::<TestBackend>(&device);
    let input = Tensor::<TestBackend, 3>::random([4, 5, 32], Distribution::Default, &device);
    // One initial hidden state per direction: [2, batch, d_hidden].
    let h0 = Tensor::<TestBackend, 3>::random([2, 4, 64], Distribution::Default, &device);

    let (output, state) = bigru.forward(input, Some(h0));

    assert_eq!(&*output.shape(), [4, 5, 128]);
    assert_eq!(&*state.shape(), [2, 4, 64]);
}
#[test]
fn test_bigru_seq_first() {
    let device = Default::default();
    // batch_first = false means the layout is [seq, batch, feature].
    let bigru = BiGruConfig::new(32, 64, true)
        .with_batch_first(false)
        .init::<TestBackend>(&device);
    let input = Tensor::<TestBackend, 3>::random([5, 4, 32], Distribution::Default, &device);

    let (output, state) = bigru.forward(input, None);

    // Output keeps the seq-first layout; state is always [2, batch, d_hidden].
    assert_eq!(&*output.shape(), [5, 4, 128]);
    assert_eq!(&*state.shape(), [2, 4, 64]);
}
#[test]
fn test_bigru_against_pytorch() {
    use burn::tensor::Device;
    let device = Default::default();
    TestBackend::seed(&device, 0);
    let config = BiGruConfig::new(2, 3, true);
    let mut bigru = config.init::<TestBackend>(&device);

    // Builds a gate from explicit weight/bias matrices so the module exactly
    // mirrors a fixed reference network (presumably PyTorch's nn.GRU, per the
    // test name; the expected tensors below are that network's outputs).
    // NOTE(review): `d_input`/`d_output` appear transposed relative to the
    // 2-in/3-out weight shapes — confirm `create_with_weights` does not rely
    // on them beyond bookkeeping.
    fn create_gate_controller<const D1: usize, const D2: usize>(
        input_weights: [[f32; D1]; D2],
        input_biases: [f32; D1],
        hidden_weights: [[f32; D1]; D1],
        hidden_biases: [f32; D1],
        device: &Device<TestBackend>,
    ) -> GateController<TestBackend> {
        let d_input = input_weights[0].len();
        let d_output = input_weights.len();
        let input_record = LinearRecord {
            weight: Param::from_data(TensorData::from(input_weights), device),
            bias: Some(Param::from_data(TensorData::from(input_biases), device)),
        };
        let hidden_record = LinearRecord {
            weight: Param::from_data(TensorData::from(hidden_weights), device),
            bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
        };
        GateController::create_with_weights(
            d_input,
            d_output,
            true,
            Initializer::XavierUniform { gain: 1.0 },
            input_record,
            hidden_record,
        )
    }

    // Input sequence: batch of 1, 4 timesteps, 2 features per step.
    let input = Tensor::<TestBackend, 3>::from_data(
        TensorData::from([[
            [0.949, -0.861],
            [0.892, 0.927],
            [-0.173, -0.301],
            [-0.081, 0.992],
        ]]),
        &device,
    );
    // Initial hidden state: [2 directions, batch of 1, 3 hidden units].
    let h0 = Tensor::<TestBackend, 3>::from_data(
        TensorData::from([[[0.280, 0.360, -1.242]], [[-0.588, 0.729, -0.788]]]),
        &device,
    );

    // Forward-direction gates with the reference network's weights.
    bigru.forward.update_gate = create_gate_controller(
        [[-0.2811, 0.5090, 0.5018], [0.3391, -0.4236, 0.1081]],
        [0.2932, -0.3519, -0.5715],
        [
            [-0.3471, 0.5214, 0.0961],
            [0.0545, -0.4904, -0.1875],
            [-0.5702, 0.4457, 0.3568],
        ],
        [-0.0100, 0.4518, -0.4102],
        &device,
    );
    bigru.forward.reset_gate = create_gate_controller(
        [[0.4414, -0.1353, -0.1265], [0.4792, 0.5304, 0.1165]],
        [-0.2524, 0.3333, 0.1033],
        [
            [-0.2695, -0.0677, -0.4557],
            [0.1472, -0.2345, -0.2662],
            [-0.2660, 0.3830, -0.1630],
        ],
        [0.1663, 0.2391, 0.1826],
        &device,
    );
    bigru.forward.new_gate = create_gate_controller(
        [[0.4266, 0.2784, 0.4451], [0.0782, -0.0815, 0.0853]],
        [-0.2231, -0.4428, 0.4737],
        [
            [0.0900, -0.1821, 0.2430],
            [0.4665, 0.1551, 0.5155],
            [0.0631, -0.1566, 0.3337],
        ],
        [0.0364, -0.3941, 0.1780],
        &device,
    );

    // Reverse-direction gates.
    bigru.reverse.update_gate = create_gate_controller(
        [[-0.3444, 0.1924, -0.4765], [0.5193, 0.5556, -0.5727]],
        [0.1090, 0.1779, -0.5385],
        [
            [0.1221, 0.3925, 0.5287],
            [-0.1472, -0.4187, -0.1948],
            [0.3441, -0.3082, -0.2047],
        ],
        [0.0016, -0.2148, -0.0400],
        &device,
    );
    bigru.reverse.reset_gate = create_gate_controller(
        [[-0.1988, -0.1203, -0.3422], [0.1769, 0.4788, -0.3443]],
        [-0.5053, -0.3676, 0.5771],
        [
            [-0.3936, 0.3504, -0.4486],
            [0.3063, -0.1370, -0.2914],
            [-0.2334, 0.3303, 0.1760],
        ],
        [-0.5080, -0.2488, -0.3456],
        &device,
    );
    bigru.reverse.new_gate = create_gate_controller(
        [[-0.4517, 0.2339, 0.4797], [-0.3884, 0.2067, -0.2982]],
        [-0.3792, -0.1922, 0.0903],
        [
            [-0.5586, -0.0762, -0.3944],
            [-0.3306, -0.4191, -0.4898],
            [0.1442, 0.0135, -0.3179],
        ],
        [-0.3912, -0.3963, -0.3368],
        &device,
    );

    // Reference outputs: per-timestep concatenated (forward ++ reverse)
    // hidden states, and the final state of each direction.
    let expected_output_with_init = TensorData::from([[
        [0.24537, 0.14018, 0.19449, -0.49777, -0.15647, 0.48392],
        [0.27468, -0.14514, 0.56205, -0.60381, -0.04986, 0.15683],
        [-0.04062, -0.33486, 0.52330, -0.42244, -0.12644, -0.12034],
        [-0.11743, -0.53873, 0.54429, -0.64943, 0.30127, -0.41943],
    ]]);
    let expected_hn_with_init = TensorData::from([
        [[-0.11743, -0.53873, 0.54429]],
        [[-0.49777, -0.15647, 0.48392]],
    ]);
    let expected_output_without_init = TensorData::from([[
        [0.07452, -0.08247, 0.46677, -0.46770, -0.18086, 0.47519],
        [0.15843, -0.27144, 0.65781, -0.50286, -0.12806, 0.14884],
        [-0.10704, -0.41573, 0.53954, -0.24794, -0.24003, -0.10294],
        [-0.16505, -0.57952, 0.53565, -0.23598, -0.07137, -0.28937],
    ]]);
    let expected_hn_without_init = TensorData::from([
        [[-0.16505, -0.57952, 0.53565]],
        [[-0.46770, -0.18086, 0.47519]],
    ]);

    let (output_with_init, hn_with_init) = bigru.forward(input.clone(), Some(h0));
    let (output_without_init, hn_without_init) = bigru.forward(input, None);

    let tolerance = Tolerance::permissive();
    output_with_init
        .to_data()
        .assert_approx_eq::<FT>(&expected_output_with_init, tolerance);
    output_without_init
        .to_data()
        .assert_approx_eq::<FT>(&expected_output_without_init, tolerance);
    hn_with_init
        .to_data()
        .assert_approx_eq::<FT>(&expected_hn_with_init, tolerance);
    hn_without_init
        .to_data()
        .assert_approx_eq::<FT>(&expected_hn_without_init, tolerance);
}
#[test]
fn bigru_display() {
    let layer = BiGruConfig::new(2, 8, true).init::<TestBackend>(&Default::default());
    let rendered = alloc::format!("{layer}");
    assert_eq!(
        rendered,
        "BiGru {d_input: 2, d_hidden: 8, bias: true, params: 576}"
    );
}
#[test]
fn test_gru_custom_activations() {
    let device = Default::default();
    // Swap both default activations (sigmoid/tanh) for ReLU.
    let gru = GruConfig::new(4, 8, true)
        .with_gate_activation(ActivationConfig::Relu)
        .with_hidden_activation(ActivationConfig::Relu)
        .init::<TestBackend>(&device);
    let input = Tensor::<TestBackend, 3>::random([2, 3, 4], Distribution::Default, &device);

    let output = gru.forward(input, None);

    assert_eq!(&*output.shape(), [2, 3, 8]);
}
#[test]
fn test_bigru_custom_activations() {
    let device = Default::default();
    // Swap both default activations (sigmoid/tanh) for ReLU.
    let bigru = BiGruConfig::new(4, 8, true)
        .with_gate_activation(ActivationConfig::Relu)
        .with_hidden_activation(ActivationConfig::Relu)
        .init::<TestBackend>(&device);
    let input = Tensor::<TestBackend, 3>::random([2, 3, 4], Distribution::Default, &device);

    let (output, state) = bigru.forward(input, None);

    assert_eq!(&*output.shape(), [2, 3, 16]);
    assert_eq!(&*state.shape(), [2, 2, 8]);
}
#[test]
fn test_gru_clipping() {
    let device = Default::default();
    let clip_value = 0.5;
    let gru = GruConfig::new(4, 8, true)
        .with_clip(Some(clip_value))
        .init::<TestBackend>(&device);
    let input = Tensor::<TestBackend, 3>::random([2, 5, 4], Distribution::Default, &device);

    let values: Vec<f32> = gru.forward(input, None).to_data().to_vec().unwrap();

    // Every hidden-state element must lie inside the clip interval.
    let bound = clip_value as f32;
    for val in values {
        assert!(
            (-bound..=bound).contains(&val),
            "Value {} is outside clip range [-{}, {}]",
            val,
            clip_value,
            clip_value
        );
    }
}
#[test]
fn test_bigru_clipping() {
    let device = Default::default();
    let clip_value = 0.3;
    let bigru = BiGruConfig::new(4, 8, true)
        .with_clip(Some(clip_value))
        .init::<TestBackend>(&device);
    let input = Tensor::<TestBackend, 3>::random([2, 5, 4], Distribution::Default, &device);

    let (output, state) = bigru.forward(input, None);

    // Both the per-timestep outputs and the final states must respect the bound.
    let bound = clip_value as f32;
    let assert_clipped = |values: Vec<f32>, kind: &str| {
        for val in values {
            assert!(
                (-bound..=bound).contains(&val),
                "{} value {} is outside clip range [-{}, {}]",
                kind,
                val,
                clip_value,
                clip_value
            );
        }
    };
    assert_clipped(output.to_data().to_vec().unwrap(), "Output");
    assert_clipped(state.to_data().to_vec().unwrap(), "State");
}
#[test]
fn test_gru_against_pytorch() {
    use burn::tensor::Device;
    let device = Default::default();
    TestBackend::seed(&device, 0);
    let config = GruConfig::new(2, 3, true);
    let mut gru = config.init::<TestBackend>(&device);

    // Builds a gate from explicit weight/bias matrices so the module exactly
    // mirrors a fixed reference network (presumably PyTorch's nn.GRU, per the
    // test name; the expected tensors below are that network's outputs).
    // NOTE(review): `d_input`/`d_output` appear transposed relative to the
    // 2-in/3-out weight shapes — confirm `create_with_weights` does not rely
    // on them beyond bookkeeping.
    fn create_gate_controller<const D1: usize, const D2: usize>(
        input_weights: [[f32; D1]; D2],
        input_biases: [f32; D1],
        hidden_weights: [[f32; D1]; D1],
        hidden_biases: [f32; D1],
        device: &Device<TestBackend>,
    ) -> GateController<TestBackend> {
        let d_input = input_weights[0].len();
        let d_output = input_weights.len();
        let input_record = LinearRecord {
            weight: Param::from_data(TensorData::from(input_weights), device),
            bias: Some(Param::from_data(TensorData::from(input_biases), device)),
        };
        let hidden_record = LinearRecord {
            weight: Param::from_data(TensorData::from(hidden_weights), device),
            bias: Some(Param::from_data(TensorData::from(hidden_biases), device)),
        };
        GateController::create_with_weights(
            d_input,
            d_output,
            true,
            Initializer::XavierUniform { gain: 1.0 },
            input_record,
            hidden_record,
        )
    }

    // Input sequence: batch of 1, 4 timesteps, 2 features per step.
    let input = Tensor::<TestBackend, 3>::from_data(
        TensorData::from([[
            [-0.11147, 0.12036],
            [-0.36963, -0.24042],
            [-1.19692, 0.20927],
            [-0.97236, -0.75505],
        ]]),
        &device,
    );
    // Initial hidden state for the single batch entry.
    let h0 = Tensor::<TestBackend, 2>::from_data(
        TensorData::from([[0.3239, -0.10852, 0.21033]]),
        &device,
    );

    // Install the reference network's weights into all three gates.
    gru.update_gate = create_gate_controller(
        [[-0.2811, 0.5090, 0.5018], [0.3391, -0.4236, 0.1081]],
        [0.2932, -0.3519, -0.5715],
        [
            [-0.3471, 0.5214, 0.0961],
            [0.0545, -0.4904, -0.1875],
            [-0.5702, 0.4457, 0.3568],
        ],
        [-0.0100, 0.4518, -0.4102],
        &device,
    );
    gru.reset_gate = create_gate_controller(
        [[0.4414, -0.1353, -0.1265], [0.4792, 0.5304, 0.1165]],
        [-0.2524, 0.3333, 0.1033],
        [
            [-0.2695, -0.0677, -0.4557],
            [0.1472, -0.2345, -0.2662],
            [-0.2660, 0.3830, -0.1630],
        ],
        [0.1663, 0.2391, 0.1826],
        &device,
    );
    gru.new_gate = create_gate_controller(
        [[0.4266, 0.2784, 0.4451], [0.0782, -0.0815, 0.0853]],
        [-0.2231, -0.4428, 0.4737],
        [
            [0.0900, -0.1821, 0.2430],
            [0.4665, 0.1551, 0.5155],
            [0.0631, -0.1566, 0.3337],
        ],
        [0.0364, -0.3941, 0.1780],
        &device,
    );

    // Reference hidden states for every timestep, with and without h0.
    let expected_output_with_h0 = TensorData::from([[
        [0.05665, -0.34932, 0.43267],
        [-0.1737, -0.49246, 0.38099],
        [-0.35401, -0.68099, 0.05061],
        [-0.47854, -0.70427, -0.13648],
    ]]);
    let expected_output_no_h0 = TensorData::from([[
        [-0.0985, -0.31661, 0.36126],
        [-0.24563, -0.47784, 0.34609],
        [-0.39497, -0.67659, 0.03083],
        [-0.50146, -0.70066, -0.14894],
    ]]);

    let output_with_h0 = gru.forward(input.clone(), Some(h0));
    let output_no_h0 = gru.forward(input, None);

    let tolerance = Tolerance::permissive();
    output_with_h0
        .to_data()
        .assert_approx_eq::<FT>(&expected_output_with_h0, tolerance);
    output_no_h0
        .to_data()
        .assert_approx_eq::<FT>(&expected_output_no_h0, tolerance);
}
}