tch_plus/nn/

//! Recurrent Neural Networks
use crate::{Device, Tensor};
use std::borrow::Borrow;

/// Trait for Recurrent Neural Networks.
#[allow(clippy::upper_case_acronyms)]
pub trait RNN {
    type State;

    /// A zero state from which the recurrent network is usually initialized.
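    ///
    /// For the LSTM and GRU layers in this module, each state tensor has shape
    /// `[num_layers * num_directions, batch_dim, hidden_dim]`.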
    fn zero_state(&self, batch_dim: i64) -> Self::State;

    /// Applies a single step of the recurrent network.
    ///
    /// The input should have dimensions [batch_size, features].
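    ///
    /// # Example
    ///
    /// A minimal sketch of manual stepping with an LSTM, assuming the crate
    /// re-exports `nn::VarStore`, `nn::lstm` and `Kind` in the same way as tch:
    ///
    /// ```no_run
    /// use tch_plus::{nn, nn::RNN, Device, Kind, Tensor};
    ///
    /// let vs = nn::VarStore::new(Device::Cpu);
    /// // 10 input features, 20 hidden units.
    /// let lstm = nn::lstm(&vs.root(), 10, 20, Default::default());
    /// let mut state = lstm.zero_state(8);
    /// let inputs = Tensor::zeros([8, 5, 10], (Kind::Float, Device::Cpu));
    /// for t in 0..5 {
    ///     // Feed one [batch_size, features] slice of the sequence per step.
    ///     state = lstm.step(&inputs.select(1, t), &state);
    /// }
    /// let _h = state.h();
    /// ```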
    fn step(&self, input: &Tensor, state: &Self::State) -> Self::State;

    /// Applies multiple steps of the recurrent network.
    ///
    /// The input should have dimensions [batch_size, seq_len, features].
    /// The initial state is the result of applying `zero_state`.
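    ///
    /// # Example
    ///
    /// A minimal sketch, assuming the crate re-exports `nn::VarStore`, `nn::gru`
    /// and `Kind` in the same way as tch:
    ///
    /// ```no_run
    /// use tch_plus::{nn, nn::RNN, Device, Kind, Tensor};
    ///
    /// let vs = nn::VarStore::new(Device::Cpu);
    /// let gru = nn::gru(&vs.root(), 10, 20, Default::default());
    /// // A batch of 8 sequences of length 50 with 10 features per step.
    /// let input = Tensor::zeros([8, 50, 10], (Kind::Float, Device::Cpu));
    /// let (output, _state) = gru.seq(&input);
    /// // With the default batch_first config, the output is [batch_size, seq_len, hidden_dim].
    /// assert_eq!(output.size(), [8, 50, 20]);
    /// ```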
    fn seq(&self, input: &Tensor) -> (Tensor, Self::State) {
        let batch_dim = input.size()[0];
        let state = self.zero_state(batch_dim);
        self.seq_init(input, &state)
    }

    /// Applies multiple steps of the recurrent network.
    ///
    /// The input should have dimensions [batch_size, seq_len, features].
    fn seq_init(&self, input: &Tensor, state: &Self::State) -> (Tensor, Self::State);
}

/// The state for an LSTM network. This contains two tensors.
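///
/// Both tensors have shape `[num_layers * num_directions, batch_size, hidden_dim]`.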
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug)]
pub struct LSTMState(pub (Tensor, Tensor));

impl LSTMState {
    /// The hidden state vector, which is also the output of the LSTM.
    pub fn h(&self) -> Tensor {
        (self.0).0.shallow_clone()
    }

    /// The cell state vector.
    pub fn c(&self) -> Tensor {
        (self.0).1.shallow_clone()
    }
}

// The GRU and LSTM layers share the same config.
/// Configuration for the GRU and LSTM layers.
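///
/// # Example
///
/// A minimal sketch using struct update syntax over the defaults, assuming
/// `RNNConfig` is re-exported from the `nn` module as in tch:
///
/// ```no_run
/// use tch_plus::nn::RNNConfig;
///
/// let cfg = RNNConfig { num_layers: 2, dropout: 0.1, bidirectional: true, ..Default::default() };
/// ```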
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy)]
pub struct RNNConfig {
    pub has_biases: bool,
    pub num_layers: i64,
    pub dropout: f64,
    pub train: bool,
    pub bidirectional: bool,
    pub batch_first: bool,
    pub w_ih_init: super::Init,
    pub w_hh_init: super::Init,
    pub b_ih_init: Option<super::Init>,
    pub b_hh_init: Option<super::Init>,
}

impl Default for RNNConfig {
    fn default() -> Self {
        RNNConfig {
            has_biases: true,
            num_layers: 1,
            dropout: 0.,
            train: true,
            bidirectional: false,
            batch_first: true,
            w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
            w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
            b_ih_init: Some(super::Init::Const(0.)),
            b_hh_init: Some(super::Init::Const(0.)),
        }
    }
}

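// Registers the per-layer, per-direction weights (and optional biases) in the
// variable store using PyTorch-compatible names (`weight_ih_l{layer}`,
// `weight_hh_l{layer}`, `bias_ih_l{layer}`, `bias_hh_l{layer}`, with a
// `_reverse` suffix for the backward direction), and returns them in that
// order. `gate_dim` is 4 * hidden_dim for LSTM and 3 * hidden_dim for GRU.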
fn rnn_weights<'a, T: Borrow<super::Path<'a>>>(
    vs: T,
    in_dim: i64,
    hidden_dim: i64,
    gate_dim: i64,
    num_directions: i64,
    c: RNNConfig,
) -> Vec<Tensor> {
    let vs = vs.borrow();
    let mut flat_weights = vec![];
    for layer_idx in 0..c.num_layers {
        for direction_idx in 0..num_directions {
            let in_dim = if layer_idx == 0 { in_dim } else { hidden_dim * num_directions };
            let suffix = if direction_idx == 1 { "_reverse" } else { "" };
            let w_ih = vs.var(
                &format!("weight_ih_l{layer_idx}{suffix}"),
                &[gate_dim, in_dim],
                c.w_ih_init,
            );
            let w_hh = vs.var(
                &format!("weight_hh_l{layer_idx}{suffix}"),
                &[gate_dim, hidden_dim],
                c.w_hh_init,
            );
            flat_weights.push(w_ih);
            flat_weights.push(w_hh);
            if c.has_biases {
                let b_ih = vs.var(
                    &format!("bias_ih_l{layer_idx}{suffix}"),
                    &[gate_dim],
                    c.b_ih_init.unwrap(),
                );
                let b_hh = vs.var(
                    &format!("bias_hh_l{layer_idx}{suffix}"),
                    &[gate_dim],
                    c.b_hh_init.unwrap(),
                );
                flat_weights.push(b_ih);
                flat_weights.push(b_hh);
            }
        }
    }
    flat_weights
}

/// A Long Short-Term Memory (LSTM) layer.
///
/// <https://en.wikipedia.org/wiki/Long_short-term_memory>
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug)]
pub struct LSTM {
    flat_weights: Vec<Tensor>,
    hidden_dim: i64,
    config: RNNConfig,
    device: Device,
}

/// Creates an LSTM layer.
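///
/// # Example
///
/// A minimal sketch, assuming the crate re-exports `nn::VarStore` and `Kind` as
/// in tch; the two state tensors can be read back via `LSTMState::h` and
/// `LSTMState::c`:
///
/// ```no_run
/// use tch_plus::{nn, nn::RNN, Device, Kind, Tensor};
///
/// let vs = nn::VarStore::new(Device::Cpu);
/// let lstm = nn::lstm(&vs.root(), 10, 20, Default::default());
/// let input = Tensor::zeros([8, 50, 10], (Kind::Float, Device::Cpu));
/// let (_output, state) = lstm.seq(&input);
/// // Both state tensors have shape [num_layers * num_directions, batch_size, hidden_dim].
/// assert_eq!(state.h().size(), [1, 8, 20]);
/// assert_eq!(state.c().size(), [1, 8, 20]);
/// ```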
pub fn lstm<'a, T: Borrow<super::Path<'a>>>(
    vs: T,
    in_dim: i64,
    hidden_dim: i64,
    c: RNNConfig,
) -> LSTM {
    let vs = vs.borrow();
    let num_directions = if c.bidirectional { 2 } else { 1 };
    let gate_dim = 4 * hidden_dim;
    let flat_weights = rnn_weights(vs, in_dim, hidden_dim, gate_dim, num_directions, c);

    if vs.device().is_cuda() && crate::Cuda::cudnn_is_available() {
        let _ = Tensor::internal_cudnn_rnn_flatten_weight(
            &flat_weights,
            4, /* weight_stride0 */
            in_dim,
            2, /* 2 for LSTM, see rnn.cpp in PyTorch */
            hidden_dim,
            0, /* disables projections */
            c.num_layers,
            c.batch_first,
            c.bidirectional,
        );
    }
    LSTM { flat_weights, hidden_dim, config: c, device: vs.device() }
}

impl RNN for LSTM {
    type State = LSTMState;

    fn zero_state(&self, batch_dim: i64) -> LSTMState {
        let num_directions = if self.config.bidirectional { 2 } else { 1 };
        let layer_dim = self.config.num_layers * num_directions;
        let shape = [layer_dim, batch_dim, self.hidden_dim];
        let zeros = Tensor::zeros(shape, (self.flat_weights[0].kind(), self.device));
        LSTMState((zeros.shallow_clone(), zeros.shallow_clone()))
    }

    fn step(&self, input: &Tensor, in_state: &LSTMState) -> LSTMState {
        let input = input.unsqueeze(1);
        let (_output, state) = self.seq_init(&input, in_state);
        state
    }

    fn seq_init(&self, input: &Tensor, in_state: &LSTMState) -> (Tensor, LSTMState) {
        let LSTMState((h, c)) = in_state;
        // Borrow the weights so that both slice arguments to `lstm` have the
        // same element type as `&[h, c]`.
        let flat_weights = self.flat_weights.iter().collect::<Vec<_>>();
        let (output, h, c) = input.lstm(
            &[h, c],
            &flat_weights,
            self.config.has_biases,
            self.config.num_layers,
            self.config.dropout,
            self.config.train,
            self.config.bidirectional,
            self.config.batch_first,
        );
        (output, LSTMState((h, c)))
    }
}

/// A GRU state. This contains a single tensor.
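///
/// The tensor has shape `[num_layers * num_directions, batch_size, hidden_dim]`.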
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug)]
pub struct GRUState(pub Tensor);

impl GRUState {
    /// The hidden state vector, which is also the output of the GRU.
    pub fn value(&self) -> Tensor {
        self.0.shallow_clone()
    }
}

/// A Gated Recurrent Unit (GRU) layer.
///
/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug)]
pub struct GRU {
    flat_weights: Vec<Tensor>,
    hidden_dim: i64,
    config: RNNConfig,
    device: Device,
}

/// Creates a new GRU layer.
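///
/// # Example
///
/// A minimal sketch, assuming the crate re-exports `nn::VarStore` and `Kind` as
/// in tch; the hidden state can be read back via `GRUState::value`:
///
/// ```no_run
/// use tch_plus::{nn, nn::RNN, Device, Kind, Tensor};
///
/// let vs = nn::VarStore::new(Device::Cpu);
/// let gru = nn::gru(&vs.root(), 10, 20, Default::default());
/// let input = Tensor::zeros([8, 50, 10], (Kind::Float, Device::Cpu));
/// let (_output, state) = gru.seq(&input);
/// // Shape [num_layers * num_directions, batch_size, hidden_dim].
/// assert_eq!(state.value().size(), [1, 8, 20]);
/// ```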
pub fn gru<'a, T: Borrow<super::Path<'a>>>(
    vs: T,
    in_dim: i64,
    hidden_dim: i64,
    c: RNNConfig,
) -> GRU {
    let vs = vs.borrow();
    let num_directions = if c.bidirectional { 2 } else { 1 };
    let gate_dim = 3 * hidden_dim;
    let flat_weights = rnn_weights(vs, in_dim, hidden_dim, gate_dim, num_directions, c);

    if vs.device().is_cuda() && crate::Cuda::cudnn_is_available() {
        let _ = Tensor::internal_cudnn_rnn_flatten_weight(
            &flat_weights,
            4, /* weight_stride0 */
            in_dim,
            3, /* 3 for GRU, see rnn.cpp in PyTorch */
            hidden_dim,
            0, /* disables projections */
            c.num_layers,
            c.batch_first,
            c.bidirectional,
        );
    }
    GRU { flat_weights, hidden_dim, config: c, device: vs.device() }
}

impl RNN for GRU {
    type State = GRUState;

    fn zero_state(&self, batch_dim: i64) -> GRUState {
        let num_directions = if self.config.bidirectional { 2 } else { 1 };
        let layer_dim = self.config.num_layers * num_directions;
        let shape = [layer_dim, batch_dim, self.hidden_dim];
        GRUState(Tensor::zeros(shape, (self.flat_weights[0].kind(), self.device)))
    }

    fn step(&self, input: &Tensor, in_state: &GRUState) -> GRUState {
        let input = input.unsqueeze(1);
        let (_output, state) = self.seq_init(&input, in_state);
        state
    }

    fn seq_init(&self, input: &Tensor, in_state: &GRUState) -> (Tensor, GRUState) {
        let GRUState(h) = in_state;
        let (output, h) = input.gru(
            h,
            &self.flat_weights,
            self.config.has_biases,
            self.config.num_layers,
            self.config.dropout,
            self.config.train,
            self.config.bidirectional,
            self.config.batch_first,
        );
        (output, GRUState(h))
    }
}