//! Recurrent layers: RNN, LSTM, and GRU cells plus multi-layer sequence
//! wrappers that share the common `Module` interface.

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::init::{xavier_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;

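/// A single-step Elman RNN cell: `h' = tanh(x W_ih^T + b_ih + h W_hh^T + b_hh)`.
///
/// A minimal usage sketch, using only the API defined in this file
/// (`ignore`d because the crate-level import paths are assumptions):
///
/// ```ignore
/// let cell = RNNCell::new(10, 20);
/// let x = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
/// let h0 = Variable::new(zeros(&[2, 20]), false);
/// let h1 = cell.forward_step(&x, &h0);
/// assert_eq!(h1.shape(), vec![2, 20]);
/// ```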
pub struct RNNCell {
    /// Input-to-hidden weights.
    pub weight_ih: Parameter,
    /// Hidden-to-hidden weights.
    pub weight_hh: Parameter,
    /// Input-to-hidden bias, length `hidden_size`.
    pub bias_ih: Parameter,
    /// Hidden-to-hidden bias, length `hidden_size`.
    pub bias_hh: Parameter,
    input_size: usize,
    hidden_size: usize,
}

impl RNNCell {
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        Self {
            weight_ih: Parameter::named("weight_ih", xavier_uniform(input_size, hidden_size), true),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    /// Number of expected input features.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Size of the hidden state.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Computes one recurrence step: `tanh(x W_ih^T + b_ih + h W_hh^T + b_hh)`.
    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
        let input_features = input.data().shape().last().copied().unwrap_or(0);
        assert_eq!(
            input_features, self.input_size,
            "RNNCell: expected input size {}, got {}",
            self.input_size, input_features
        );

        // Input-to-hidden. The transposed weight is re-wrapped as a fresh
        // leaf Variable, so the transpose itself is not an autograd op.
        let weight_ih = self.weight_ih.variable();
        let weight_ih_t = Variable::new(weight_ih.data().t().unwrap(), weight_ih.requires_grad());
        let ih = input.matmul(&weight_ih_t).add_var(&self.bias_ih.variable());

        // Hidden-to-hidden.
        let weight_hh = self.weight_hh.variable();
        let weight_hh_t = Variable::new(weight_hh.data().t().unwrap(), weight_hh.requires_grad());
        let hh = hidden.matmul(&weight_hh_t).add_var(&self.bias_hh.variable());

        ih.add_var(&hh).tanh()
    }
}

impl Module for RNNCell {
    fn forward(&self, input: &Variable) -> Variable {
        // Runs one step from a zero hidden state.
        let batch_size = input.shape()[0];
        let hidden = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        self.forward_step(input, &hidden)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn named_parameters(&self) -> HashMap<String, Parameter> {
        let mut params = HashMap::new();
        params.insert("weight_ih".to_string(), self.weight_ih.clone());
        params.insert("weight_hh".to_string(), self.weight_hh.clone());
        params.insert("bias_ih".to_string(), self.bias_ih.clone());
        params.insert("bias_hh".to_string(), self.bias_hh.clone());
        params
    }

    fn name(&self) -> &'static str {
        "RNNCell"
    }
}

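/// A multi-layer Elman RNN applied over a full sequence, returning the top
/// layer's hidden state at every time step.
///
/// A minimal usage sketch (`ignore`d because the import paths are
/// assumptions; shapes match the unit tests below):
///
/// ```ignore
/// let rnn = RNN::new(10, 20, 2);
/// // Batch-first input: [batch, seq, features].
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0; 2 * 5 * 10], &[2, 5, 10]).unwrap(),
///     false,
/// );
/// let output = rnn.forward(&input);
/// assert_eq!(output.shape(), vec![2, 5, 20]);
/// ```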
pub struct RNN {
    /// One cell per layer; the first maps `input_size` features, the rest
    /// map hidden state to hidden state.
    cells: Vec<RNNCell>,
    input_size: usize,
    hidden_size: usize,
    num_layers: usize,
    batch_first: bool,
}

impl RNN {
    /// Creates a batch-first RNN.
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::with_options(input_size, hidden_size, num_layers, true)
    }

    /// Full constructor; `batch_first` selects `[batch, seq, features]`
    /// versus `[seq, batch, features]` input layout.
    pub fn with_options(
        input_size: usize,
        hidden_size: usize,
        num_layers: usize,
        batch_first: bool,
    ) -> Self {
        let mut cells = Vec::with_capacity(num_layers);
        cells.push(RNNCell::new(input_size, hidden_size));
        for _ in 1..num_layers {
            cells.push(RNNCell::new(hidden_size, hidden_size));
        }

        Self {
            cells,
            input_size,
            hidden_size,
            num_layers,
            batch_first,
        }
    }
}

impl Module for RNN {
    fn forward(&self, input: &Variable) -> Variable {
        // Accepts `[batch, seq, features]` when `batch_first`, otherwise
        // `[seq, batch, features]`.
        let shape = input.shape();
        let (batch_size, seq_len) = if self.batch_first {
            (shape[0], shape[1])
        } else {
            (shape[1], shape[0])
        };

        // One zero-initialized hidden state per layer.
        let mut hiddens: Vec<Variable> = (0..self.num_layers)
            .map(|_| {
                Variable::new(
                    zeros(&[batch_size, self.hidden_size]),
                    input.requires_grad(),
                )
            })
            .collect();

        // Copy the input to host once, rather than once per time step.
        let input_vec = input.data().to_vec();
        let mut outputs = Vec::with_capacity(seq_len);

        for t in 0..seq_len {
            // Gather the time-step slice `[batch, input_size]` from either layout.
            let mut slice_data = vec![0.0f32; batch_size * self.input_size];
            for b in 0..batch_size {
                for f in 0..self.input_size {
                    let src_idx = if self.batch_first {
                        b * seq_len * self.input_size + t * self.input_size + f
                    } else {
                        t * batch_size * self.input_size + b * self.input_size + f
                    };
                    slice_data[b * self.input_size + f] = input_vec[src_idx];
                }
            }
            let t_input = Variable::new(
                Tensor::from_vec(slice_data, &[batch_size, self.input_size]).unwrap(),
                input.requires_grad(),
            );

            // Stack the layers: each layer consumes the hidden state the
            // previous layer just produced.
            let mut layer_input = t_input;
            for (l, cell) in self.cells.iter().enumerate() {
                hiddens[l] = cell.forward_step(&layer_input, &hiddens[l]);
                layer_input = hiddens[l].clone();
            }

            outputs.push(hiddens[self.num_layers - 1].clone());
        }

        // Scatter the per-step outputs back into the requested layout.
        let mut output_data = vec![0.0f32; batch_size * seq_len * self.hidden_size];
        for (t, out) in outputs.iter().enumerate() {
            let out_vec = out.data().to_vec();
            for b in 0..batch_size {
                for h in 0..self.hidden_size {
                    let dst_idx = if self.batch_first {
                        b * seq_len * self.hidden_size + t * self.hidden_size + h
                    } else {
                        t * batch_size * self.hidden_size + b * self.hidden_size + h
                    };
                    output_data[dst_idx] = out_vec[b * self.hidden_size + h];
                }
            }
        }

        let output_shape = if self.batch_first {
            vec![batch_size, seq_len, self.hidden_size]
        } else {
            vec![seq_len, batch_size, self.hidden_size]
        };

        Variable::new(
            Tensor::from_vec(output_data, &output_shape).unwrap(),
            input.requires_grad(),
        )
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn name(&self) -> &'static str {
        "RNN"
    }
}

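/// A single-step LSTM cell with fused gate weights of width `4 * hidden_size`
/// (gate order: input `i`, forget `f`, cell `g`, output `o`).
///
/// A minimal sketch of threading the `(h, c)` state over a sequence
/// (`ignore`d because the import paths and `inputs` iterator are
/// assumptions):
///
/// ```ignore
/// let cell = LSTMCell::new(4, 8);
/// let mut state = (
///     Variable::new(zeros(&[2, 8]), false),
///     Variable::new(zeros(&[2, 8]), false),
/// );
/// for x in inputs {
///     state = cell.forward_step(&x, &state);
/// }
/// let (h, _c) = state;
/// ```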
pub struct LSTMCell {
    /// Fused input-to-hidden weights for the `i`, `f`, `g`, `o` gates.
    pub weight_ih: Parameter,
    /// Fused hidden-to-hidden weights for the `i`, `f`, `g`, `o` gates.
    pub weight_hh: Parameter,
    /// Fused input-to-hidden bias, length `4 * hidden_size`.
    pub bias_ih: Parameter,
    /// Fused hidden-to-hidden bias, length `4 * hidden_size`.
    pub bias_hh: Parameter,
    input_size: usize,
    hidden_size: usize,
}

impl LSTMCell {
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        Self {
            weight_ih: Parameter::named(
                "weight_ih",
                xavier_uniform(input_size, 4 * hidden_size),
                true,
            ),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, 4 * hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[4 * hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[4 * hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    /// Number of expected input features.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Size of the hidden state.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Computes one LSTM step, returning the new `(hidden, cell)` state.
    pub fn forward_step(
        &self,
        input: &Variable,
        hx: &(Variable, Variable),
    ) -> (Variable, Variable) {
        let input_features = input.data().shape().last().copied().unwrap_or(0);
        assert_eq!(
            input_features, self.input_size,
            "LSTMCell: expected input size {}, got {}",
            self.input_size, input_features
        );

        let (h, c) = hx;

        // Fused gate pre-activations `x W_ih^T + b_ih + h W_hh^T + b_hh`,
        // shape `[batch, 4 * hidden_size]`. As in `RNNCell`, the transposed
        // weights are re-wrapped as fresh leaf Variables.
        let weight_ih = self.weight_ih.variable();
        let weight_ih_t = Variable::new(weight_ih.data().t().unwrap(), weight_ih.requires_grad());
        let ih = input.matmul(&weight_ih_t).add_var(&self.bias_ih.variable());

        let weight_hh = self.weight_hh.variable();
        let weight_hh_t = Variable::new(weight_hh.data().t().unwrap(), weight_hh.requires_grad());
        let hh = h.matmul(&weight_hh_t).add_var(&self.bias_hh.variable());

        let gates = ih.add_var(&hh);
        let gates_vec = gates.data().to_vec();
        let batch_size = input.shape()[0];

        // Slice the fused pre-activations into the four gates (i, f, g, o).
        // NOTE: this goes through host vectors, so the gate slices are
        // re-wrapped as fresh Variables and this step is not tracked by
        // autograd.
        let mut i_data = vec![0.0f32; batch_size * self.hidden_size];
        let mut f_data = vec![0.0f32; batch_size * self.hidden_size];
        let mut g_data = vec![0.0f32; batch_size * self.hidden_size];
        let mut o_data = vec![0.0f32; batch_size * self.hidden_size];

        for b in 0..batch_size {
            for j in 0..self.hidden_size {
                let base = b * 4 * self.hidden_size;
                i_data[b * self.hidden_size + j] = gates_vec[base + j];
                f_data[b * self.hidden_size + j] = gates_vec[base + self.hidden_size + j];
                g_data[b * self.hidden_size + j] = gates_vec[base + 2 * self.hidden_size + j];
                o_data[b * self.hidden_size + j] = gates_vec[base + 3 * self.hidden_size + j];
            }
        }

        let i = Variable::new(
            Tensor::from_vec(i_data, &[batch_size, self.hidden_size]).unwrap(),
            input.requires_grad(),
        )
        .sigmoid();
        let f = Variable::new(
            Tensor::from_vec(f_data, &[batch_size, self.hidden_size]).unwrap(),
            input.requires_grad(),
        )
        .sigmoid();
        let g = Variable::new(
            Tensor::from_vec(g_data, &[batch_size, self.hidden_size]).unwrap(),
            input.requires_grad(),
        )
        .tanh();
        let o = Variable::new(
            Tensor::from_vec(o_data, &[batch_size, self.hidden_size]).unwrap(),
            input.requires_grad(),
        )
        .sigmoid();

        // c' = f * c + i * g
        let c_new = f.mul_var(c).add_var(&i.mul_var(&g));
        // h' = o * tanh(c')
        let h_new = o.mul_var(&c_new.tanh());

        (h_new, c_new)
    }
}

impl Module for LSTMCell {
    fn forward(&self, input: &Variable) -> Variable {
        // Runs one step from a zero (h, c) state and returns only the new
        // hidden state.
        let batch_size = input.shape()[0];
        let h = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        let c = Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        );
        let (h_new, _) = self.forward_step(input, &(h, c));
        h_new
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn name(&self) -> &'static str {
        "LSTMCell"
    }
}

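/// A multi-layer LSTM applied over a full sequence, returning the top
/// layer's hidden state at every time step.
///
/// A minimal usage sketch (`ignore`d because the import paths are
/// assumptions; shapes match the unit tests below):
///
/// ```ignore
/// let lstm = LSTM::new(10, 20, 1);
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
///     false,
/// );
/// let output = lstm.forward(&input);
/// assert_eq!(output.shape(), vec![2, 5, 20]);
/// ```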
pub struct LSTM {
    /// One cell per layer.
    cells: Vec<LSTMCell>,
    input_size: usize,
    hidden_size: usize,
    num_layers: usize,
    batch_first: bool,
}

impl LSTM {
    /// Creates a batch-first LSTM.
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        Self::with_options(input_size, hidden_size, num_layers, true)
    }

    /// Full constructor; `batch_first` selects `[batch, seq, features]`
    /// versus `[seq, batch, features]` input layout.
    pub fn with_options(
        input_size: usize,
        hidden_size: usize,
        num_layers: usize,
        batch_first: bool,
    ) -> Self {
        let mut cells = Vec::with_capacity(num_layers);
        // The first layer maps `input_size` features; deeper layers map
        // hidden state to hidden state.
        cells.push(LSTMCell::new(input_size, hidden_size));
        for _ in 1..num_layers {
            cells.push(LSTMCell::new(hidden_size, hidden_size));
        }

        Self {
            cells,
            input_size,
            hidden_size,
            num_layers,
            batch_first,
        }
    }

    /// Number of expected input features.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Size of the hidden state per layer.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Number of stacked layers.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }
}
508
509impl Module for LSTM {
510 fn forward(&self, input: &Variable) -> Variable {
511 let shape = input.shape();
514 let (batch_size, seq_len, input_features) = if self.batch_first {
515 (shape[0], shape[1], shape[2])
516 } else {
517 (shape[1], shape[0], shape[2])
518 };
519
520 let mut states: Vec<(Variable, Variable)> = (0..self.num_layers)
521 .map(|_| {
522 (
523 Variable::new(
524 zeros(&[batch_size, self.hidden_size]),
525 input.requires_grad(),
526 ),
527 Variable::new(
528 zeros(&[batch_size, self.hidden_size]),
529 input.requires_grad(),
530 ),
531 )
532 })
533 .collect();
534
535 let input_data = input.data();
536 let input_vec = input_data.to_vec();
537 let mut outputs = Vec::with_capacity(seq_len);
538
539 for t in 0..seq_len {
540 let mut slice_data = vec![0.0f32; batch_size * input_features];
541 for b in 0..batch_size {
542 for f in 0..input_features {
543 let src_idx = if self.batch_first {
544 b * seq_len * input_features + t * input_features + f
545 } else {
546 t * batch_size * input_features + b * input_features + f
547 };
548 slice_data[b * input_features + f] = input_vec[src_idx];
549 }
550 }
551
552 let mut layer_input = Variable::new(
554 Tensor::from_vec(slice_data.clone(), &[batch_size, input_features]).unwrap(),
555 input.requires_grad(),
556 );
557
558 for (l, cell) in self.cells.iter().enumerate() {
559 if l > 0 {
561 layer_input = states[l - 1].0.clone();
562 }
563 states[l] = cell.forward_step(&layer_input, &states[l]);
564 }
565
566 outputs.push(states[self.num_layers - 1].0.clone());
567 }
568
569 let mut output_data = vec![0.0f32; batch_size * seq_len * self.hidden_size];
571 for (t, out) in outputs.iter().enumerate() {
572 let out_vec = out.data().to_vec();
573 for b in 0..batch_size {
574 for h in 0..self.hidden_size {
575 let dst_idx = if self.batch_first {
576 b * seq_len * self.hidden_size + t * self.hidden_size + h
577 } else {
578 t * batch_size * self.hidden_size + b * self.hidden_size + h
579 };
580 output_data[dst_idx] = out_vec[b * self.hidden_size + h];
581 }
582 }
583 }
584
585 let output_shape = if self.batch_first {
586 vec![batch_size, seq_len, self.hidden_size]
587 } else {
588 vec![seq_len, batch_size, self.hidden_size]
589 };
590
591 Variable::new(
592 Tensor::from_vec(output_data, &output_shape).unwrap(),
593 input.requires_grad(),
594 )
595 }
596
597 fn parameters(&self) -> Vec<Parameter> {
598 self.cells.iter().flat_map(|c| c.parameters()).collect()
599 }
600
601 fn name(&self) -> &'static str {
602 "LSTM"
603 }
604}
605
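/// A single-step gated recurrent unit (GRU) cell.
///
/// The standard update, assuming the conventional `r`, `z`, `n` ordering for
/// the `3 * hidden_size` fused weight blocks (this file does not pin the
/// ordering down itself), is:
///
/// ```text
/// r  = sigmoid(x W_ir^T + b_ir + h W_hr^T + b_hr)
/// z  = sigmoid(x W_iz^T + b_iz + h W_hz^T + b_hz)
/// n  = tanh(x W_in^T + b_in + r * (h W_hn^T + b_hn))
/// h' = (1 - z) * n + z * h
/// ```
///
/// NOTE: `Module::forward` below is still a placeholder that returns a zero
/// hidden state of the correct shape; see the `forward_step` sketch after
/// the accessors for what the full update could look like.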
pub struct GRUCell {
    /// Fused input-to-hidden weights for the three GRU gates.
    pub weight_ih: Parameter,
    /// Fused hidden-to-hidden weights for the three GRU gates.
    pub weight_hh: Parameter,
    /// Fused input-to-hidden bias, length `3 * hidden_size`.
    pub bias_ih: Parameter,
    /// Fused hidden-to-hidden bias, length `3 * hidden_size`.
    pub bias_hh: Parameter,
    input_size: usize,
    hidden_size: usize,
}

impl GRUCell {
    pub fn new(input_size: usize, hidden_size: usize) -> Self {
        Self {
            weight_ih: Parameter::named(
                "weight_ih",
                xavier_uniform(input_size, 3 * hidden_size),
                true,
            ),
            weight_hh: Parameter::named(
                "weight_hh",
                xavier_uniform(hidden_size, 3 * hidden_size),
                true,
            ),
            bias_ih: Parameter::named("bias_ih", zeros(&[3 * hidden_size]), true),
            bias_hh: Parameter::named("bias_hh", zeros(&[3 * hidden_size]), true),
            input_size,
            hidden_size,
        }
    }

    /// Number of expected input features.
    pub fn input_size(&self) -> usize {
        self.input_size
    }

    /// Size of the hidden state.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }
}

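// A sketch of what a full GRU step could look like, in the same style as
// `LSTMCell::forward_step` and using only the API already exercised in this
// file. Assumptions: the conventional `r`, `z`, `n` gate ordering for the
// `3 * hidden_size` blocks, and (as with the LSTM gate slicing) the
// elementwise gate math runs on host vectors, so it is not tracked by
// autograd. `Module::forward` below does not call this yet.
impl GRUCell {
    pub fn forward_step(&self, input: &Variable, hidden: &Variable) -> Variable {
        let input_features = input.data().shape().last().copied().unwrap_or(0);
        assert_eq!(
            input_features, self.input_size,
            "GRUCell: expected input size {}, got {}",
            self.input_size, input_features
        );

        let batch_size = input.shape()[0];
        let hs = self.hidden_size;

        // ih = x W_ih^T + b_ih and hh = h W_hh^T + b_hh, each [batch, 3 * hidden].
        let weight_ih = self.weight_ih.variable();
        let weight_ih_t = Variable::new(weight_ih.data().t().unwrap(), weight_ih.requires_grad());
        let ih = input.matmul(&weight_ih_t).add_var(&self.bias_ih.variable());
        let weight_hh = self.weight_hh.variable();
        let weight_hh_t = Variable::new(weight_hh.data().t().unwrap(), weight_hh.requires_grad());
        let hh = hidden.matmul(&weight_hh_t).add_var(&self.bias_hh.variable());

        let ih_vec = ih.data().to_vec();
        let hh_vec = hh.data().to_vec();
        let h_vec = hidden.data().to_vec();
        let sigmoid = |x: f32| 1.0 / (1.0 + (-x).exp());

        // Host-side combine: r and z gate the update, n is the candidate state.
        let mut h_new = vec![0.0f32; batch_size * hs];
        for b in 0..batch_size {
            let base = b * 3 * hs;
            for j in 0..hs {
                let r = sigmoid(ih_vec[base + j] + hh_vec[base + j]);
                let z = sigmoid(ih_vec[base + hs + j] + hh_vec[base + hs + j]);
                let n = (ih_vec[base + 2 * hs + j] + r * hh_vec[base + 2 * hs + j]).tanh();
                h_new[b * hs + j] = (1.0 - z) * n + z * h_vec[b * hs + j];
            }
        }

        Variable::new(
            Tensor::from_vec(h_new, &[batch_size, hs]).unwrap(),
            input.requires_grad(),
        )
    }
}
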
impl Module for GRUCell {
    fn forward(&self, input: &Variable) -> Variable {
        // Placeholder: the gate weights are not applied yet; this returns a
        // zero hidden state of the correct shape. See the `forward_step`
        // sketch above for the full update.
        let batch_size = input.shape()[0];

        Variable::new(
            zeros(&[batch_size, self.hidden_size]),
            input.requires_grad(),
        )
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![
            self.weight_ih.clone(),
            self.weight_hh.clone(),
            self.bias_ih.clone(),
            self.bias_hh.clone(),
        ]
    }

    fn name(&self) -> &'static str {
        "GRUCell"
    }
}

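/// A multi-layer GRU over a full sequence.
///
/// NOTE: like [`GRUCell`], the forward pass is still a placeholder and
/// returns zeros of the correct output shape. A shape-only sketch
/// (`ignore`d because the import paths are assumptions):
///
/// ```ignore
/// let gru = GRU::new(10, 20, 2); // always batch-first
/// let input = Variable::new(
///     Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
///     false,
/// );
/// assert_eq!(gru.forward(&input).shape(), vec![2, 5, 20]);
/// ```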
pub struct GRU {
    cells: Vec<GRUCell>,
    hidden_size: usize,
    num_layers: usize,
    batch_first: bool,
}

impl GRU {
    /// Creates a batch-first GRU (there is no `with_options` constructor
    /// yet, so `batch_first` is always `true`).
    pub fn new(input_size: usize, hidden_size: usize, num_layers: usize) -> Self {
        let mut cells = Vec::with_capacity(num_layers);
        cells.push(GRUCell::new(input_size, hidden_size));
        for _ in 1..num_layers {
            cells.push(GRUCell::new(hidden_size, hidden_size));
        }
        Self {
            cells,
            hidden_size,
            num_layers,
            batch_first: true,
        }
    }

    /// Size of the hidden state per layer.
    pub fn hidden_size(&self) -> usize {
        self.hidden_size
    }

    /// Number of stacked layers.
    pub fn num_layers(&self) -> usize {
        self.num_layers
    }
}

impl Module for GRU {
    fn forward(&self, input: &Variable) -> Variable {
        // Placeholder: computes only the output shape and returns zeros;
        // the per-step GRU recurrence is not implemented yet.
        let shape = input.shape();
        let (batch_size, seq_len) = if self.batch_first {
            (shape[0], shape[1])
        } else {
            (shape[1], shape[0])
        };

        let output_shape = if self.batch_first {
            vec![batch_size, seq_len, self.hidden_size]
        } else {
            vec![seq_len, batch_size, self.hidden_size]
        };

        Variable::new(zeros(&output_shape), input.requires_grad())
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.cells.iter().flat_map(|c| c.parameters()).collect()
    }

    fn name(&self) -> &'static str {
        "GRU"
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rnn_cell() {
        let cell = RNNCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let hidden = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let output = cell.forward_step(&input, &hidden);
        assert_eq!(output.shape(), vec![2, 20]);
    }

    #[test]
    fn test_rnn() {
        let rnn = RNN::new(10, 20, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }

    #[test]
    fn test_lstm() {
        let lstm = LSTM::new(10, 20, 1);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        let output = lstm.forward(&input);
        assert_eq!(output.shape(), vec![2, 5, 20]);
    }
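
    // Additional shape checks in the same style as the tests above. These
    // are sketches grounded in the shapes this file already uses; the GRU
    // test only pins down the placeholder's output shape.
    #[test]
    fn test_lstm_cell() {
        let cell = LSTMCell::new(10, 20);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 20], &[2, 10]).unwrap(), false);
        let h = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let c = Variable::new(Tensor::from_vec(vec![0.0; 40], &[2, 20]).unwrap(), false);
        let (h_new, c_new) = cell.forward_step(&input, &(h, c));
        assert_eq!(h_new.shape(), vec![2, 20]);
        assert_eq!(c_new.shape(), vec![2, 20]);
    }

    #[test]
    fn test_rnn_seq_first() {
        let rnn = RNN::with_options(10, 20, 1, false);
        // Seq-first input: [seq, batch, features].
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[5, 2, 10]).unwrap(),
            false,
        );
        let output = rnn.forward(&input);
        assert_eq!(output.shape(), vec![5, 2, 20]);
    }

    #[test]
    fn test_gru_output_shape() {
        let gru = GRU::new(10, 20, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0; 100], &[2, 5, 10]).unwrap(),
            false,
        );
        assert_eq!(gru.forward(&input).shape(), vec![2, 5, 20]);
    }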
}