// candle_nn/rnn.rs

1//! Recurrent Neural Networks
2use candle::{DType, Device, IndexOp, Result, Tensor};
3
4/// Trait for Recurrent Neural Networks.
5#[allow(clippy::upper_case_acronyms)]
6pub trait RNN {
7    type State: Clone;
8
9    /// A zero state from which the recurrent network is usually initialized.
10    fn zero_state(&self, batch_dim: usize) -> Result<Self::State>;
11
12    /// Applies a single step of the recurrent network.
13    ///
14    /// The input should have dimensions [batch_size, features].
15    fn step(&self, input: &Tensor, state: &Self::State) -> Result<Self::State>;
16
17    /// Applies multiple steps of the recurrent network.
18    ///
19    /// The input should have dimensions [batch_size, seq_len, features].
20    /// The initial state is the result of applying zero_state.
21    fn seq(&self, input: &Tensor) -> Result<Vec<Self::State>> {
22        let batch_dim = input.dim(0)?;
23        let state = self.zero_state(batch_dim)?;
24        self.seq_init(input, &state)
25    }
26
27    /// Applies multiple steps of the recurrent network.
28    ///
29    /// The input should have dimensions [batch_size, seq_len, features].
30    fn seq_init(&self, input: &Tensor, init_state: &Self::State) -> Result<Vec<Self::State>> {
31        let (_b_size, seq_len, _features) = input.dims3()?;
32        let mut output = Vec::with_capacity(seq_len);
33        for seq_index in 0..seq_len {
34            let input = input.i((.., seq_index, ..))?.contiguous()?;
35            let state = if seq_index == 0 {
36                self.step(&input, init_state)?
37            } else {
38                self.step(&input, &output[seq_index - 1])?
39            };
40            output.push(state);
41        }
42        Ok(output)
43    }
44
45    /// Converts a sequence of state to a tensor.
46    fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor>;
47}
48
49/// The state for a LSTM network, this contains two tensors.
50#[allow(clippy::upper_case_acronyms)]
51#[derive(Debug, Clone)]
52pub struct LSTMState {
53    pub h: Tensor,
54    pub c: Tensor,
55}
56
57impl LSTMState {
58    pub fn new(h: Tensor, c: Tensor) -> Self {
59        LSTMState { h, c }
60    }
61
62    /// The hidden state vector, which is also the output of the LSTM.
63    pub fn h(&self) -> &Tensor {
64        &self.h
65    }
66
67    /// The cell state vector.
68    pub fn c(&self) -> &Tensor {
69        &self.c
70    }
71}
72
/// Direction tag stored in [`LSTMConfig`].
///
/// `LSTM::new` uses it to pick the parameter-name suffix: nothing for
/// `Forward`, `_reverse` for `Backward`.
#[derive(Debug, Clone, Copy)]
pub enum Direction {
    /// Parameters are looked up without a suffix.
    Forward,
    /// Parameters are looked up with the `_reverse` suffix.
    Backward,
}
78
79#[allow(clippy::upper_case_acronyms)]
80#[derive(Debug, Clone, Copy)]
81pub struct LSTMConfig {
82    pub w_ih_init: super::Init,
83    pub w_hh_init: super::Init,
84    pub b_ih_init: Option<super::Init>,
85    pub b_hh_init: Option<super::Init>,
86    pub layer_idx: usize,
87    pub direction: Direction,
88}
89
90impl Default for LSTMConfig {
91    fn default() -> Self {
92        Self {
93            w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
94            w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
95            b_ih_init: Some(super::Init::Const(0.)),
96            b_hh_init: Some(super::Init::Const(0.)),
97            layer_idx: 0,
98            direction: Direction::Forward,
99        }
100    }
101}
102
103impl LSTMConfig {
104    pub fn default_no_bias() -> Self {
105        Self {
106            w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
107            w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
108            b_ih_init: None,
109            b_hh_init: None,
110            layer_idx: 0,
111            direction: Direction::Forward,
112        }
113    }
114}
115
116/// A Long Short-Term Memory (LSTM) layer.
117///
118/// <https://en.wikipedia.org/wiki/Long_short-term_memory>
119#[allow(clippy::upper_case_acronyms)]
120#[derive(Clone, Debug)]
121pub struct LSTM {
122    w_ih: Tensor,
123    w_hh: Tensor,
124    b_ih: Option<Tensor>,
125    b_hh: Option<Tensor>,
126    hidden_dim: usize,
127    config: LSTMConfig,
128    device: Device,
129    dtype: DType,
130}
131
132impl LSTM {
133    /// Creates a LSTM layer.
134    pub fn new(
135        in_dim: usize,
136        hidden_dim: usize,
137        config: LSTMConfig,
138        vb: crate::VarBuilder,
139    ) -> Result<Self> {
140        let layer_idx = config.layer_idx;
141        let direction_str = match config.direction {
142            Direction::Forward => "",
143            Direction::Backward => "_reverse",
144        };
145        let w_ih = vb.get_with_hints(
146            (4 * hidden_dim, in_dim),
147            &format!("weight_ih_l{layer_idx}{direction_str}"), // Only a single layer is supported.
148            config.w_ih_init,
149        )?;
150        let w_hh = vb.get_with_hints(
151            (4 * hidden_dim, hidden_dim),
152            &format!("weight_hh_l{layer_idx}{direction_str}"), // Only a single layer is supported.
153            config.w_hh_init,
154        )?;
155        let b_ih = match config.b_ih_init {
156            Some(init) => Some(vb.get_with_hints(
157                4 * hidden_dim,
158                &format!("bias_ih_l{layer_idx}{direction_str}"),
159                init,
160            )?),
161            None => None,
162        };
163        let b_hh = match config.b_hh_init {
164            Some(init) => Some(vb.get_with_hints(
165                4 * hidden_dim,
166                &format!("bias_hh_l{layer_idx}{direction_str}"),
167                init,
168            )?),
169            None => None,
170        };
171        Ok(Self {
172            w_ih,
173            w_hh,
174            b_ih,
175            b_hh,
176            hidden_dim,
177            config,
178            device: vb.device().clone(),
179            dtype: vb.dtype(),
180        })
181    }
182
183    pub fn config(&self) -> &LSTMConfig {
184        &self.config
185    }
186}
187
188/// Creates a LSTM layer.
189pub fn lstm(
190    in_dim: usize,
191    hidden_dim: usize,
192    config: LSTMConfig,
193    vb: crate::VarBuilder,
194) -> Result<LSTM> {
195    LSTM::new(in_dim, hidden_dim, config, vb)
196}
197
198impl RNN for LSTM {
199    type State = LSTMState;
200
201    fn zero_state(&self, batch_dim: usize) -> Result<Self::State> {
202        let zeros =
203            Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?.contiguous()?;
204        Ok(Self::State {
205            h: zeros.clone(),
206            c: zeros.clone(),
207        })
208    }
209
210    fn step(&self, input: &Tensor, in_state: &Self::State) -> Result<Self::State> {
211        let w_ih = input.matmul(&self.w_ih.t()?)?;
212        let w_hh = in_state.h.matmul(&self.w_hh.t()?)?;
213        let w_ih = match &self.b_ih {
214            None => w_ih,
215            Some(b_ih) => w_ih.broadcast_add(b_ih)?,
216        };
217        let w_hh = match &self.b_hh {
218            None => w_hh,
219            Some(b_hh) => w_hh.broadcast_add(b_hh)?,
220        };
221        let chunks = (&w_ih + &w_hh)?.chunk(4, 1)?;
222        let in_gate = crate::ops::sigmoid(&chunks[0])?;
223        let forget_gate = crate::ops::sigmoid(&chunks[1])?;
224        let cell_gate = chunks[2].tanh()?;
225        let out_gate = crate::ops::sigmoid(&chunks[3])?;
226
227        let next_c = ((forget_gate * &in_state.c)? + (in_gate * cell_gate)?)?;
228        let next_h = (out_gate * next_c.tanh()?)?;
229        Ok(LSTMState {
230            c: next_c,
231            h: next_h,
232        })
233    }
234
235    fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> {
236        let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
237        Tensor::stack(&states, 1)
238    }
239}
240
241/// The state for a GRU network, this contains a single tensor.
242#[allow(clippy::upper_case_acronyms)]
243#[derive(Debug, Clone)]
244pub struct GRUState {
245    pub h: Tensor,
246}
247
248impl GRUState {
249    /// The hidden state vector, which is also the output of the LSTM.
250    pub fn h(&self) -> &Tensor {
251        &self.h
252    }
253}
254
255#[allow(clippy::upper_case_acronyms)]
256#[derive(Debug, Clone, Copy)]
257pub struct GRUConfig {
258    pub w_ih_init: super::Init,
259    pub w_hh_init: super::Init,
260    pub b_ih_init: Option<super::Init>,
261    pub b_hh_init: Option<super::Init>,
262}
263
264impl Default for GRUConfig {
265    fn default() -> Self {
266        Self {
267            w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
268            w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
269            b_ih_init: Some(super::Init::Const(0.)),
270            b_hh_init: Some(super::Init::Const(0.)),
271        }
272    }
273}
274
275impl GRUConfig {
276    pub fn default_no_bias() -> Self {
277        Self {
278            w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
279            w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
280            b_ih_init: None,
281            b_hh_init: None,
282        }
283    }
284}
285
286/// A Gated Recurrent Unit (GRU) layer.
287///
288/// <https://en.wikipedia.org/wiki/Gated_recurrent_unit>
289#[allow(clippy::upper_case_acronyms)]
290#[derive(Clone, Debug)]
291pub struct GRU {
292    w_ih: Tensor,
293    w_hh: Tensor,
294    b_ih: Option<Tensor>,
295    b_hh: Option<Tensor>,
296    hidden_dim: usize,
297    config: GRUConfig,
298    device: Device,
299    dtype: DType,
300}
301
302impl GRU {
303    /// Creates a GRU layer.
304    pub fn new(
305        in_dim: usize,
306        hidden_dim: usize,
307        config: GRUConfig,
308        vb: crate::VarBuilder,
309    ) -> Result<Self> {
310        let w_ih = vb.get_with_hints(
311            (3 * hidden_dim, in_dim),
312            "weight_ih_l0", // Only a single layer is supported.
313            config.w_ih_init,
314        )?;
315        let w_hh = vb.get_with_hints(
316            (3 * hidden_dim, hidden_dim),
317            "weight_hh_l0", // Only a single layer is supported.
318            config.w_hh_init,
319        )?;
320        let b_ih = match config.b_ih_init {
321            Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
322            None => None,
323        };
324        let b_hh = match config.b_hh_init {
325            Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
326            None => None,
327        };
328        Ok(Self {
329            w_ih,
330            w_hh,
331            b_ih,
332            b_hh,
333            hidden_dim,
334            config,
335            device: vb.device().clone(),
336            dtype: vb.dtype(),
337        })
338    }
339
340    pub fn config(&self) -> &GRUConfig {
341        &self.config
342    }
343}
344
345pub fn gru(
346    in_dim: usize,
347    hidden_dim: usize,
348    config: GRUConfig,
349    vb: crate::VarBuilder,
350) -> Result<GRU> {
351    GRU::new(in_dim, hidden_dim, config, vb)
352}
353
354impl RNN for GRU {
355    type State = GRUState;
356
357    fn zero_state(&self, batch_dim: usize) -> Result<Self::State> {
358        let h =
359            Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?.contiguous()?;
360        Ok(Self::State { h })
361    }
362
363    fn step(&self, input: &Tensor, in_state: &Self::State) -> Result<Self::State> {
364        let w_ih = input.matmul(&self.w_ih.t()?)?;
365        let w_hh = in_state.h.matmul(&self.w_hh.t()?)?;
366        let w_ih = match &self.b_ih {
367            None => w_ih,
368            Some(b_ih) => w_ih.broadcast_add(b_ih)?,
369        };
370        let w_hh = match &self.b_hh {
371            None => w_hh,
372            Some(b_hh) => w_hh.broadcast_add(b_hh)?,
373        };
374        let chunks_ih = w_ih.chunk(3, 1)?;
375        let chunks_hh = w_hh.chunk(3, 1)?;
376        let r_gate = crate::ops::sigmoid(&(&chunks_ih[0] + &chunks_hh[0])?)?;
377        let z_gate = crate::ops::sigmoid(&(&chunks_ih[1] + &chunks_hh[1])?)?;
378        let n_gate = (&chunks_ih[2] + (r_gate * &chunks_hh[2])?)?.tanh();
379
380        let next_h = ((&z_gate * &in_state.h)? - ((&z_gate - 1.)? * n_gate)?)?;
381        Ok(GRUState { h: next_h })
382    }
383
384    fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> {
385        let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
386        Tensor::cat(&states, 1)
387    }
388}