use crate::{Device, Tensor};
use std::borrow::Borrow;

/// A trait covering recurrent neural networks such as `LSTM` and `GRU`.
#[allow(clippy::upper_case_acronyms)]
pub trait RNN {
    type State;

    /// Returns a zero-initialized state for a batch of the given size.
    fn zero_state(&self, batch_dim: i64) -> Self::State;

    /// Applies a single step of the recurrent network.
    ///
    /// The input should have dimensions [batch_size, features].
    fn step(&self, input: &Tensor, state: &Self::State) -> Self::State;

    /// Applies multiple steps of the recurrent network starting from a zero state.
    ///
    /// The input should have dimensions [batch_size, seq_len, features]
    /// (batch-first, which is the default layout).
    fn seq(&self, input: &Tensor) -> (Tensor, Self::State) {
        let batch_dim = input.size()[0];
        let state = self.zero_state(batch_dim);
        self.seq_init(input, &state)
    }

    /// Applies multiple steps of the recurrent network starting from the given state.
    ///
    /// The input should have dimensions [batch_size, seq_len, features].
    fn seq_init(&self, input: &Tensor, state: &Self::State) -> (Tensor, Self::State);
}

/// The state of an LSTM network, holding a hidden state and a cell state tensor.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug)]
pub struct LSTMState(pub (Tensor, Tensor));

impl LSTMState {
    /// The hidden state vector, which is also the output of the LSTM.
    pub fn h(&self) -> Tensor {
        (self.0).0.shallow_clone()
    }

    /// The cell state vector.
    pub fn c(&self) -> Tensor {
        (self.0).1.shallow_clone()
    }
}

/// Configuration shared by the recurrent network layers in this module.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy)]
pub struct RNNConfig {
    pub has_biases: bool,
    pub num_layers: i64,
    pub dropout: f64,
    pub train: bool,
    pub bidirectional: bool,
    pub batch_first: bool,
    pub w_ih_init: super::Init,
    pub w_hh_init: super::Init,
    pub b_ih_init: Option<super::Init>,
    pub b_hh_init: Option<super::Init>,
}

impl Default for RNNConfig {
    fn default() -> Self {
        RNNConfig {
            has_biases: true,
            num_layers: 1,
            dropout: 0.,
            train: true,
            bidirectional: false,
            batch_first: true,
            w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
            w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
            b_ih_init: Some(super::Init::Const(0.)),
            b_hh_init: Some(super::Init::Const(0.)),
        }
    }
}
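
// Configuration sketch: since `RNNConfig` implements `Default`, individual
// fields can be overridden with struct-update syntax, e.g. for a two-layer
// bidirectional network:
//
//     let cfg = RNNConfig { num_layers: 2, bidirectional: true, ..Default::default() };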

/// Creates the flat list of weight (and optional bias) tensors expected by the
/// cuDNN-style RNN functions, registering them under the standard
/// `weight_ih_l{layer}` / `weight_hh_l{layer}` (and matching `bias_*`) names.
fn rnn_weights<'a, T: Borrow<super::Path<'a>>>(
    vs: T,
    in_dim: i64,
    hidden_dim: i64,
    gate_dim: i64,
    num_directions: i64,
    c: RNNConfig,
) -> Vec<Tensor> {
    let vs = vs.borrow();
    let mut flat_weights = vec![];
    for layer_idx in 0..c.num_layers {
        for direction_idx in 0..num_directions {
            // Layers after the first take the (possibly bidirectional) hidden state as input.
            let in_dim = if layer_idx == 0 { in_dim } else { hidden_dim * num_directions };
            let suffix = if direction_idx == 1 { "_reverse" } else { "" };
            let w_ih = vs.var(
                &format!("weight_ih_l{layer_idx}{suffix}"),
                &[gate_dim, in_dim],
                c.w_ih_init,
            );
            let w_hh = vs.var(
                &format!("weight_hh_l{layer_idx}{suffix}"),
                &[gate_dim, hidden_dim],
                c.w_hh_init,
            );
            flat_weights.push(w_ih);
            flat_weights.push(w_hh);
            if c.has_biases {
                let b_ih = vs.var(
                    &format!("bias_ih_l{layer_idx}{suffix}"),
                    &[gate_dim],
                    c.b_ih_init.unwrap(),
                );
                let b_hh = vs.var(
                    &format!("bias_hh_l{layer_idx}{suffix}"),
                    &[gate_dim],
                    c.b_hh_init.unwrap(),
                );
                flat_weights.push(b_ih);
                flat_weights.push(b_hh);
            }
        }
    }
    flat_weights
}

/// A Long Short-Term Memory (LSTM) layer.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug)]
pub struct LSTM {
    flat_weights: Vec<Tensor>,
    hidden_dim: i64,
    config: RNNConfig,
    device: Device,
}

/// Creates an LSTM layer.
pub fn lstm<'a, T: Borrow<super::Path<'a>>>(
    vs: T,
    in_dim: i64,
    hidden_dim: i64,
    c: RNNConfig,
) -> LSTM {
    let vs = vs.borrow();
    let num_directions = if c.bidirectional { 2 } else { 1 };
    // The four LSTM gates are packed along a single dimension.
    let gate_dim = 4 * hidden_dim;
    let flat_weights = rnn_weights(vs, in_dim, hidden_dim, gate_dim, num_directions, c);

    if vs.device().is_cuda() && crate::Cuda::cudnn_is_available() {
        // Ask cuDNN to lay the weights out contiguously for faster kernels.
        let _ = Tensor::internal_cudnn_rnn_flatten_weight(
            &flat_weights,
            4, // weight_stride0
            in_dim,
            2, // cuDNN mode 2 is LSTM
            hidden_dim,
            0, // proj_size: no projections
            c.num_layers,
            c.batch_first,
            c.bidirectional,
        );
    }
    LSTM { flat_weights, hidden_dim, config: c, device: vs.device() }
}
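
// A minimal usage sketch (hypothetical names; it assumes the crate is consumed
// as `tch`, with a `nn::VarStore` providing the variable path and the `nn::RNN`
// trait in scope):
//
//     let vs = tch::nn::VarStore::new(tch::Device::Cpu);
//     let lstm = tch::nn::lstm(&vs.root() / "lstm", 10, 20, Default::default());
//     // Batch-first input of shape [batch_size, seq_len, features].
//     let xs = tch::Tensor::zeros([1, 5, 10], (tch::Kind::Float, tch::Device::Cpu));
//     let (output, _state) = lstm.seq(&xs); // output shape: [1, 5, 20]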

impl RNN for LSTM {
    type State = LSTMState;

    fn zero_state(&self, batch_dim: i64) -> LSTMState {
        let num_directions = if self.config.bidirectional { 2 } else { 1 };
        let layer_dim = self.config.num_layers * num_directions;
        let shape = [layer_dim, batch_dim, self.hidden_dim];
        let zeros = Tensor::zeros(shape, (self.flat_weights[0].kind(), self.device));
        LSTMState((zeros.shallow_clone(), zeros.shallow_clone()))
    }

    fn step(&self, input: &Tensor, in_state: &LSTMState) -> LSTMState {
        // Treat the single step as a sequence of length one.
        let input = input.unsqueeze(1);
        let (_output, state) = self.seq_init(&input, in_state);
        state
    }

    fn seq_init(&self, input: &Tensor, in_state: &LSTMState) -> (Tensor, LSTMState) {
        let LSTMState((h, c)) = in_state;
        let flat_weights = self.flat_weights.iter().collect::<Vec<_>>();
        let (output, h, c) = input.lstm(
            &[h, c],
            &flat_weights,
            self.config.has_biases,
            self.config.num_layers,
            self.config.dropout,
            self.config.train,
            self.config.bidirectional,
            self.config.batch_first,
        );
        (output, LSTMState((h, c)))
    }
}
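
// Step-by-step sketch (hypothetical variable names): feeding a batch-first
// input `xs` one timestep at a time through `step` should end in the same
// final state as a single `seq` call over the whole sequence:
//
//     let mut state = lstm.zero_state(batch_size);
//     for t in 0..seq_len {
//         let x_t = xs.select(1, t); // [batch_size, features]
//         state = lstm.step(&x_t, &state);
//     }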

/// A GRU state, holding a single hidden-state tensor.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug)]
pub struct GRUState(pub Tensor);

impl GRUState {
    /// The hidden state vector.
    pub fn value(&self) -> Tensor {
        self.0.shallow_clone()
    }
}

/// A Gated Recurrent Unit (GRU) layer.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug)]
pub struct GRU {
    flat_weights: Vec<Tensor>,
    hidden_dim: i64,
    config: RNNConfig,
    device: Device,
}

/// Creates a GRU layer.
pub fn gru<'a, T: Borrow<super::Path<'a>>>(
    vs: T,
    in_dim: i64,
    hidden_dim: i64,
    c: RNNConfig,
) -> GRU {
    let vs = vs.borrow();
    let num_directions = if c.bidirectional { 2 } else { 1 };
    // The three GRU gates are packed along a single dimension.
    let gate_dim = 3 * hidden_dim;
    let flat_weights = rnn_weights(vs, in_dim, hidden_dim, gate_dim, num_directions, c);

    if vs.device().is_cuda() && crate::Cuda::cudnn_is_available() {
        // Ask cuDNN to lay the weights out contiguously for faster kernels.
        let _ = Tensor::internal_cudnn_rnn_flatten_weight(
            &flat_weights,
            4, // weight_stride0
            in_dim,
            3, // cuDNN mode 3 is GRU
            hidden_dim,
            0, // proj_size: no projections
            c.num_layers,
            c.batch_first,
            c.bidirectional,
        );
    }
    GRU { flat_weights, hidden_dim, config: c, device: vs.device() }
}

impl RNN for GRU {
    type State = GRUState;

    fn zero_state(&self, batch_dim: i64) -> GRUState {
        let num_directions = if self.config.bidirectional { 2 } else { 1 };
        let layer_dim = self.config.num_layers * num_directions;
        let shape = [layer_dim, batch_dim, self.hidden_dim];
        GRUState(Tensor::zeros(shape, (self.flat_weights[0].kind(), self.device)))
    }

    fn step(&self, input: &Tensor, in_state: &GRUState) -> GRUState {
        let input = input.unsqueeze(1);
        let (_output, state) = self.seq_init(&input, in_state);
        state
    }

    fn seq_init(&self, input: &Tensor, in_state: &GRUState) -> (Tensor, GRUState) {
        let GRUState(h) = in_state;
        let (output, h) = input.gru(
            h,
            &self.flat_weights,
            self.config.has_biases,
            self.config.num_layers,
            self.config.dropout,
            self.config.train,
            self.config.bidirectional,
            self.config.batch_first,
        );
        (output, GRUState(h))
    }
}
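
#[cfg(test)]
mod tests {
    // A minimal shape-checking sketch rather than an exhaustive test; it
    // assumes `VarStore` is exported from the parent `nn` module as usual.
    use super::{gru, lstm, RNNConfig, RNN};
    use crate::{Device, Kind, Tensor};

    #[test]
    fn lstm_and_gru_output_shapes() {
        let vs = super::super::VarStore::new(Device::Cpu);
        let cfg = RNNConfig::default();
        // Batch-first input: [batch_size = 2, seq_len = 5, features = 10].
        let xs = Tensor::zeros([2, 5, 10], (Kind::Float, Device::Cpu));

        let net = lstm(&vs.root() / "lstm", 10, 20, cfg);
        let (out, _state) = net.seq(&xs);
        assert_eq!(out.size(), [2, 5, 20]);

        let net = gru(&vs.root() / "gru", 10, 20, cfg);
        let (out, _state) = net.seq(&xs);
        assert_eq!(out.size(), [2, 5, 20]);
    }
}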