1use candle::{DType, Device, IndexOp, Result, Tensor};
3
4#[allow(clippy::upper_case_acronyms)]
6pub trait RNN {
7 type State: Clone;
8
9 fn zero_state(&self, batch_dim: usize) -> Result<Self::State>;
11
12 fn step(&self, input: &Tensor, state: &Self::State) -> Result<Self::State>;
16
17 fn seq(&self, input: &Tensor) -> Result<Vec<Self::State>> {
22 let batch_dim = input.dim(0)?;
23 let state = self.zero_state(batch_dim)?;
24 self.seq_init(input, &state)
25 }
26
27 fn seq_init(&self, input: &Tensor, init_state: &Self::State) -> Result<Vec<Self::State>> {
31 let (_b_size, seq_len, _features) = input.dims3()?;
32 let mut output = Vec::with_capacity(seq_len);
33 for seq_index in 0..seq_len {
34 let input = input.i((.., seq_index, ..))?.contiguous()?;
35 let state = if seq_index == 0 {
36 self.step(&input, init_state)?
37 } else {
38 self.step(&input, &output[seq_index - 1])?
39 };
40 output.push(state);
41 }
42 Ok(output)
43 }
44
45 fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor>;
47}
48
/// The state of an LSTM cell: the hidden state and the cell state.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone)]
pub struct LSTMState {
    /// Hidden state; this is also what the layer exposes as its output.
    pub h: Tensor,
    /// Cell (memory) state.
    pub c: Tensor,
}
56
impl LSTMState {
    /// Bundles a hidden tensor `h` and a cell tensor `c` into a state.
    pub fn new(h: Tensor, c: Tensor) -> Self {
        LSTMState { h, c }
    }

    /// The hidden state; `states_to_tensor` collects these as the layer output.
    pub fn h(&self) -> &Tensor {
        &self.h
    }

    /// The cell state.
    pub fn c(&self) -> &Tensor {
        &self.c
    }
}
72
/// Direction flavor of an LSTM layer.
///
/// Within this module the direction only selects which parameter set is loaded
/// (`Backward` adds a `_reverse` suffix to the weight/bias names); actually
/// iterating the sequence in reverse is left to the caller.
#[derive(Debug, Clone, Copy)]
pub enum Direction {
    Forward,
    Backward,
}
78
/// Configuration for an LSTM layer.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy)]
pub struct LSTMConfig {
    /// Initializer for the input-to-hidden weights.
    pub w_ih_init: super::Init,
    /// Initializer for the hidden-to-hidden weights.
    pub w_hh_init: super::Init,
    /// Initializer for the input-to-hidden bias; `None` disables this bias.
    pub b_ih_init: Option<super::Init>,
    /// Initializer for the hidden-to-hidden bias; `None` disables this bias.
    pub b_hh_init: Option<super::Init>,
    /// Layer index used in the parameter names (`..._l{layer_idx}`).
    pub layer_idx: usize,
    /// Selects the forward or `_reverse` parameter set.
    pub direction: Direction,
}
89
impl Default for LSTMConfig {
    /// Kaiming-uniform weight init, zero-initialized biases, layer 0, forward.
    fn default() -> Self {
        Self {
            w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
            w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
            b_ih_init: Some(super::Init::Const(0.)),
            b_hh_init: Some(super::Init::Const(0.)),
            layer_idx: 0,
            direction: Direction::Forward,
        }
    }
}
102
103impl LSTMConfig {
104 pub fn default_no_bias() -> Self {
105 Self {
106 w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
107 w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
108 b_ih_init: None,
109 b_hh_init: None,
110 layer_idx: 0,
111 direction: Direction::Forward,
112 }
113 }
114}
115
/// A Long Short-Term Memory (LSTM) layer.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug)]
pub struct LSTM {
    // Input-to-hidden weights, shape (4 * hidden_dim, in_dim).
    w_ih: Tensor,
    // Hidden-to-hidden weights, shape (4 * hidden_dim, hidden_dim).
    w_hh: Tensor,
    // Optional biases of size 4 * hidden_dim.
    b_ih: Option<Tensor>,
    b_hh: Option<Tensor>,
    // Size of the hidden (and cell) state.
    hidden_dim: usize,
    config: LSTMConfig,
    // Captured from the VarBuilder so zero_state allocates matching tensors.
    device: Device,
    dtype: DType,
}
131
132impl LSTM {
133 pub fn new(
135 in_dim: usize,
136 hidden_dim: usize,
137 config: LSTMConfig,
138 vb: crate::VarBuilder,
139 ) -> Result<Self> {
140 let layer_idx = config.layer_idx;
141 let direction_str = match config.direction {
142 Direction::Forward => "",
143 Direction::Backward => "_reverse",
144 };
145 let w_ih = vb.get_with_hints(
146 (4 * hidden_dim, in_dim),
147 &format!("weight_ih_l{layer_idx}{direction_str}"), config.w_ih_init,
149 )?;
150 let w_hh = vb.get_with_hints(
151 (4 * hidden_dim, hidden_dim),
152 &format!("weight_hh_l{layer_idx}{direction_str}"), config.w_hh_init,
154 )?;
155 let b_ih = match config.b_ih_init {
156 Some(init) => Some(vb.get_with_hints(
157 4 * hidden_dim,
158 &format!("bias_ih_l{layer_idx}{direction_str}"),
159 init,
160 )?),
161 None => None,
162 };
163 let b_hh = match config.b_hh_init {
164 Some(init) => Some(vb.get_with_hints(
165 4 * hidden_dim,
166 &format!("bias_hh_l{layer_idx}{direction_str}"),
167 init,
168 )?),
169 None => None,
170 };
171 Ok(Self {
172 w_ih,
173 w_hh,
174 b_ih,
175 b_hh,
176 hidden_dim,
177 config,
178 device: vb.device().clone(),
179 dtype: vb.dtype(),
180 })
181 }
182
183 pub fn config(&self) -> &LSTMConfig {
184 &self.config
185 }
186}
187
/// Creates an LSTM layer; thin convenience wrapper around [`LSTM::new`].
pub fn lstm(
    in_dim: usize,
    hidden_dim: usize,
    config: LSTMConfig,
    vb: crate::VarBuilder,
) -> Result<LSTM> {
    LSTM::new(in_dim, hidden_dim, config, vb)
}
197
198impl RNN for LSTM {
199 type State = LSTMState;
200
201 fn zero_state(&self, batch_dim: usize) -> Result<Self::State> {
202 let zeros =
203 Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?.contiguous()?;
204 Ok(Self::State {
205 h: zeros.clone(),
206 c: zeros.clone(),
207 })
208 }
209
210 fn step(&self, input: &Tensor, in_state: &Self::State) -> Result<Self::State> {
211 let w_ih = input.matmul(&self.w_ih.t()?)?;
212 let w_hh = in_state.h.matmul(&self.w_hh.t()?)?;
213 let w_ih = match &self.b_ih {
214 None => w_ih,
215 Some(b_ih) => w_ih.broadcast_add(b_ih)?,
216 };
217 let w_hh = match &self.b_hh {
218 None => w_hh,
219 Some(b_hh) => w_hh.broadcast_add(b_hh)?,
220 };
221 let chunks = (&w_ih + &w_hh)?.chunk(4, 1)?;
222 let in_gate = crate::ops::sigmoid(&chunks[0])?;
223 let forget_gate = crate::ops::sigmoid(&chunks[1])?;
224 let cell_gate = chunks[2].tanh()?;
225 let out_gate = crate::ops::sigmoid(&chunks[3])?;
226
227 let next_c = ((forget_gate * &in_state.c)? + (in_gate * cell_gate)?)?;
228 let next_h = (out_gate * next_c.tanh()?)?;
229 Ok(LSTMState {
230 c: next_c,
231 h: next_h,
232 })
233 }
234
235 fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> {
236 let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
237 Tensor::stack(&states, 1)
238 }
239}
240
/// The state of a GRU cell: a single hidden tensor.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone)]
pub struct GRUState {
    /// Hidden state; this is also what the layer exposes as its output.
    pub h: Tensor,
}
247
impl GRUState {
    /// The hidden state; `states_to_tensor` collects these as the layer output.
    pub fn h(&self) -> &Tensor {
        &self.h
    }
}
254
/// Configuration for a GRU layer.
///
/// Unlike [`LSTMConfig`] there is no layer index or direction: `GRU::new`
/// always loads the `..._l0` forward parameter set.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy)]
pub struct GRUConfig {
    /// Initializer for the input-to-hidden weights.
    pub w_ih_init: super::Init,
    /// Initializer for the hidden-to-hidden weights.
    pub w_hh_init: super::Init,
    /// Initializer for the input-to-hidden bias; `None` disables this bias.
    pub b_ih_init: Option<super::Init>,
    /// Initializer for the hidden-to-hidden bias; `None` disables this bias.
    pub b_hh_init: Option<super::Init>,
}
263
impl Default for GRUConfig {
    /// Kaiming-uniform weight init and zero-initialized biases.
    fn default() -> Self {
        Self {
            w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
            w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
            b_ih_init: Some(super::Init::Const(0.)),
            b_hh_init: Some(super::Init::Const(0.)),
        }
    }
}
274
275impl GRUConfig {
276 pub fn default_no_bias() -> Self {
277 Self {
278 w_ih_init: super::init::DEFAULT_KAIMING_UNIFORM,
279 w_hh_init: super::init::DEFAULT_KAIMING_UNIFORM,
280 b_ih_init: None,
281 b_hh_init: None,
282 }
283 }
284}
285
/// A Gated Recurrent Unit (GRU) layer.
#[allow(clippy::upper_case_acronyms)]
#[derive(Clone, Debug)]
pub struct GRU {
    // Input-to-hidden weights, shape (3 * hidden_dim, in_dim).
    w_ih: Tensor,
    // Hidden-to-hidden weights, shape (3 * hidden_dim, hidden_dim).
    w_hh: Tensor,
    // Optional biases of size 3 * hidden_dim.
    b_ih: Option<Tensor>,
    b_hh: Option<Tensor>,
    // Size of the hidden state.
    hidden_dim: usize,
    config: GRUConfig,
    // Captured from the VarBuilder so zero_state allocates matching tensors.
    device: Device,
    dtype: DType,
}
301
302impl GRU {
303 pub fn new(
305 in_dim: usize,
306 hidden_dim: usize,
307 config: GRUConfig,
308 vb: crate::VarBuilder,
309 ) -> Result<Self> {
310 let w_ih = vb.get_with_hints(
311 (3 * hidden_dim, in_dim),
312 "weight_ih_l0", config.w_ih_init,
314 )?;
315 let w_hh = vb.get_with_hints(
316 (3 * hidden_dim, hidden_dim),
317 "weight_hh_l0", config.w_hh_init,
319 )?;
320 let b_ih = match config.b_ih_init {
321 Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_ih_l0", init)?),
322 None => None,
323 };
324 let b_hh = match config.b_hh_init {
325 Some(init) => Some(vb.get_with_hints(3 * hidden_dim, "bias_hh_l0", init)?),
326 None => None,
327 };
328 Ok(Self {
329 w_ih,
330 w_hh,
331 b_ih,
332 b_hh,
333 hidden_dim,
334 config,
335 device: vb.device().clone(),
336 dtype: vb.dtype(),
337 })
338 }
339
340 pub fn config(&self) -> &GRUConfig {
341 &self.config
342 }
343}
344
/// Creates a GRU layer; thin convenience wrapper around [`GRU::new`].
pub fn gru(
    in_dim: usize,
    hidden_dim: usize,
    config: GRUConfig,
    vb: crate::VarBuilder,
) -> Result<GRU> {
    GRU::new(in_dim, hidden_dim, config, vb)
}
353
354impl RNN for GRU {
355 type State = GRUState;
356
357 fn zero_state(&self, batch_dim: usize) -> Result<Self::State> {
358 let h =
359 Tensor::zeros((batch_dim, self.hidden_dim), self.dtype, &self.device)?.contiguous()?;
360 Ok(Self::State { h })
361 }
362
363 fn step(&self, input: &Tensor, in_state: &Self::State) -> Result<Self::State> {
364 let w_ih = input.matmul(&self.w_ih.t()?)?;
365 let w_hh = in_state.h.matmul(&self.w_hh.t()?)?;
366 let w_ih = match &self.b_ih {
367 None => w_ih,
368 Some(b_ih) => w_ih.broadcast_add(b_ih)?,
369 };
370 let w_hh = match &self.b_hh {
371 None => w_hh,
372 Some(b_hh) => w_hh.broadcast_add(b_hh)?,
373 };
374 let chunks_ih = w_ih.chunk(3, 1)?;
375 let chunks_hh = w_hh.chunk(3, 1)?;
376 let r_gate = crate::ops::sigmoid(&(&chunks_ih[0] + &chunks_hh[0])?)?;
377 let z_gate = crate::ops::sigmoid(&(&chunks_ih[1] + &chunks_hh[1])?)?;
378 let n_gate = (&chunks_ih[2] + (r_gate * &chunks_hh[2])?)?.tanh();
379
380 let next_h = ((&z_gate * &in_state.h)? - ((&z_gate - 1.)? * n_gate)?)?;
381 Ok(GRUState { h: next_h })
382 }
383
384 fn states_to_tensor(&self, states: &[Self::State]) -> Result<Tensor> {
385 let states = states.iter().map(|s| s.h.clone()).collect::<Vec<_>>();
386 Tensor::cat(&states, 1)
387 }
388}