1use crate as burn;
2
3use crate::config::Config;
4use crate::module::Module;
5use crate::module::{Content, DisplaySettings, ModuleDisplay};
6use crate::nn::rnn::gate_controller;
7use crate::nn::Initializer;
8use crate::tensor::activation;
9use crate::tensor::backend::Backend;
10use crate::tensor::Tensor;
11
12use super::gate_controller::GateController;
13
/// Configuration for creating a [`Gru`] module.
#[derive(Config)]
pub struct GruConfig {
    /// Number of features in each input time step (the input width of every gate).
    pub d_input: usize,
    /// Size of the hidden state produced at each time step (the output width of every gate).
    pub d_hidden: usize,
    /// Whether the gates are created with bias parameters; forwarded to each `GateController`.
    pub bias: bool,
    /// Initializer forwarded to each `GateController` when its parameters are created.
    /// Defaults to Xavier-normal with a gain of 1.0.
    #[config(default = "Initializer::XavierNormal{gain:1.0}")]
    pub initializer: Initializer,
}
27
/// Gated recurrent unit (GRU) layer.
///
/// Created with [`GruConfig::init`] and applied with [`Gru::forward`].
#[derive(Module, Debug)]
#[module(custom_display)]
pub struct Gru<B: Backend> {
    /// Update gate `z(t)`: its sigmoid output blends the hidden state fed into a step
    /// with the candidate state.
    pub update_gate: GateController<B>,
    /// Reset gate `r(t)`: its sigmoid output scales the hidden state before it is passed
    /// to the new gate.
    pub reset_gate: GateController<B>,
    /// New gate: its `tanh` output is the candidate hidden state.
    pub new_gate: GateController<B>,
    /// Size of the hidden state.
    pub d_hidden: usize,
}
45
46impl<B: Backend> ModuleDisplay for Gru<B> {
47 fn custom_settings(&self) -> Option<DisplaySettings> {
48 DisplaySettings::new()
49 .with_new_line_after_attribute(false)
50 .optional()
51 }
52
53 fn custom_content(&self, content: Content) -> Option<Content> {
54 let [d_input, _] = self.update_gate.input_transform.weight.shape().dims();
55 let bias = self.update_gate.input_transform.bias.is_some();
56
57 content
58 .add("d_input", &d_input)
59 .add("d_hidden", &self.d_hidden)
60 .add("bias", &bias)
61 .optional()
62 }
63}
64
65impl GruConfig {
66 pub fn init<B: Backend>(&self, device: &B::Device) -> Gru<B> {
68 let d_output = self.d_hidden;
69
70 let update_gate = gate_controller::GateController::new(
71 self.d_input,
72 d_output,
73 self.bias,
74 self.initializer.clone(),
75 device,
76 );
77 let reset_gate = gate_controller::GateController::new(
78 self.d_input,
79 d_output,
80 self.bias,
81 self.initializer.clone(),
82 device,
83 );
84 let new_gate = gate_controller::GateController::new(
85 self.d_input,
86 d_output,
87 self.bias,
88 self.initializer.clone(),
89 device,
90 );
91
92 Gru {
93 update_gate,
94 reset_gate,
95 new_gate,
96 d_hidden: self.d_hidden,
97 }
98 }
99}
100
101impl<B: Backend> Gru<B> {
102 pub fn forward(
111 &self,
112 batched_input: Tensor<B, 3>,
113 state: Option<Tensor<B, 3>>,
114 ) -> Tensor<B, 3> {
115 let [batch_size, seq_length, _] = batched_input.shape().dims();
116
117 let mut hidden_state = match state {
118 Some(state) => state,
119 None => Tensor::zeros(
120 [batch_size, seq_length, self.d_hidden],
121 &batched_input.device(),
122 ),
123 };
124
125 for (t, (input_t, hidden_t)) in batched_input
126 .iter_dim(1)
127 .zip(hidden_state.clone().iter_dim(1))
128 .enumerate()
129 {
130 let input_t = input_t.squeeze(1);
131 let hidden_t = hidden_t.squeeze(1);
132 let biased_ug_input_sum = self.gate_product(&input_t, &hidden_t, &self.update_gate);
134 let update_values = activation::sigmoid(biased_ug_input_sum); let biased_rg_input_sum = self.gate_product(&input_t, &hidden_t, &self.reset_gate);
138 let reset_values = activation::sigmoid(biased_rg_input_sum); let reset_t = hidden_t.clone().mul(reset_values); let biased_ng_input_sum = self.gate_product(&input_t, &reset_t, &self.new_gate);
143 let candidate_state = biased_ng_input_sum.tanh(); let state_vector = candidate_state
148 .clone()
149 .mul(update_values.clone().sub_scalar(1).mul_scalar(-1)) + update_values.clone().mul(hidden_t);
151
152 let current_shape = state_vector.shape().dims;
153 let unsqueezed_shape = [current_shape[0], 1, current_shape[1]];
154 let reshaped_state_vector = state_vector.reshape(unsqueezed_shape);
155 hidden_state = hidden_state.slice_assign(
156 [0..batch_size, t..(t + 1), 0..self.d_hidden],
157 reshaped_state_vector,
158 );
159 }
160
161 hidden_state
162 }
163
164 fn gate_product(
174 &self,
175 input: &Tensor<B, 2>,
176 hidden: &Tensor<B, 2>,
177 gate: &GateController<B>,
178 ) -> Tensor<B, 2> {
179 let input_product = input.clone().matmul(gate.input_transform.weight.val());
180 let hidden_product = hidden.clone().matmul(gate.hidden_transform.weight.val());
181
182 let input_bias = gate
183 .input_transform
184 .bias
185 .as_ref()
186 .map(|bias_param| bias_param.val());
187 let hidden_bias = gate
188 .hidden_transform
189 .bias
190 .as_ref()
191 .map(|bias_param| bias_param.val());
192
193 match (input_bias, hidden_bias) {
194 (Some(input_bias), Some(hidden_bias)) => {
195 input_product + input_bias.unsqueeze() + hidden_product + hidden_bias.unsqueeze()
196 }
197 (Some(input_bias), None) => input_product + input_bias.unsqueeze() + hidden_product,
198 (None, Some(hidden_bias)) => input_product + hidden_product + hidden_bias.unsqueeze(),
199 (None, None) => input_product + hidden_product,
200 }
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::{Distribution, TensorData};
    use crate::{module::Param, nn::LinearRecord, TestBackend};

    /// One time step with one feature, hand-picked gate weights and zero biases.
    /// The expected hidden state was derived by hand from the GRU equations.
    #[test]
    fn tests_forward_single_input_single_feature() {
        TestBackend::seed(0);
        let device = Default::default();
        let mut gru = GruConfig::new(1, 1, false).init::<TestBackend>(&device);

        /// Builds a gate whose input and hidden transforms both hold the scalar weight
        /// `weights` and the scalar bias `biases`.
        fn create_gate_controller(
            weights: f32,
            biases: f32,
            d_input: usize,
            d_output: usize,
            bias: bool,
            initializer: Initializer,
            device: &<TestBackend as Backend>::Device,
        ) -> GateController<TestBackend> {
            // The two transforms get identical but independently constructed records.
            let build_record = || -> LinearRecord<TestBackend> {
                LinearRecord {
                    weight: Param::from_data(TensorData::from([[weights]]), device),
                    bias: Some(Param::from_data(TensorData::from([biases]), device)),
                }
            };
            let input_record = build_record();
            let hidden_record = build_record();

            gate_controller::GateController::create_with_weights(
                d_input,
                d_output,
                bias,
                initializer,
                input_record,
                hidden_record,
            )
        }

        // Every gate differs only in its scalar weight.
        let gate_with_weight = |weight: f32| {
            create_gate_controller(
                weight,
                0.0,
                1,
                1,
                false,
                Initializer::XavierNormal { gain: 1.0 },
                &device,
            )
        };
        gru.update_gate = gate_with_weight(0.5);
        gru.reset_gate = gate_with_weight(0.6);
        gru.new_gate = gate_with_weight(0.7);

        let input = Tensor::<TestBackend, 3>::from_data(TensorData::from([[[0.1]]]), &device);
        let state = gru.forward(input, None);

        // Keep batch element 0 and drop the batch axis.
        let output = state
            .select(0, Tensor::arange(0..1, &device))
            .squeeze::<2>(0);

        output
            .to_data()
            .assert_approx_eq(&TensorData::from([[0.034]]), 3);
    }

    /// The output holds one `d_hidden`-wide state per batch element and time step.
    #[test]
    fn test_batched_forward_pass() {
        let device = Default::default();
        let (batch_size, seq_length, d_input, d_hidden) = (8, 10, 64, 1024);

        let gru = GruConfig::new(d_input, d_hidden, true).init::<TestBackend>(&device);
        let batched_input = Tensor::<TestBackend, 3>::random(
            [batch_size, seq_length, d_input],
            Distribution::Default,
            &device,
        );

        let hidden_state = gru.forward(batched_input, None);

        assert_eq!(hidden_state.shape().dims, [batch_size, seq_length, d_hidden]);
    }

    /// The one-line display lists the dimensions, the bias flag and the parameter count.
    #[test]
    fn display() {
        let layer = GruConfig::new(2, 8, true).init::<TestBackend>(&Default::default());
        let rendered = alloc::format!("{}", layer);

        assert_eq!(
            rendered,
            "Gru {d_input: 2, d_hidden: 8, bias: true, params: 288}"
        );
    }
}