basic_decoder/basic_decoder.rs

use train_station::{
    optimizers::{Adam, Optimizer},
    Tensor,
};

#[path = "basic_linear_layer.rs"]
mod basic_linear_layer;
use basic_linear_layer::LinearLayer;

#[allow(clippy::duplicate_mod)]
#[path = "multi_head_attention.rs"]
mod multi_head_attention;
use multi_head_attention::MultiHeadAttention;

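/// One transformer decoder block: masked self-attention, cross-attention over
/// the encoder memory, and a two-layer feed-forward network, each followed by
/// a residual add.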
pub struct DecoderBlock {
    pub _embed_dim: usize,
    pub _num_heads: usize,
    self_attn: MultiHeadAttention,
    cross_attn: MultiHeadAttention,
    ffn_in: LinearLayer,
    ffn_out: LinearLayer,
}

impl DecoderBlock {
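    /// Build a decoder block, deriving a distinct seed for each sub-layer
    /// from the optional base seed.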
    pub fn new(embed_dim: usize, num_heads: usize, seed: Option<u64>) -> Self {
        let s0 = seed;
        let s1 = s0.map(|s| s + 1);
        let s2 = s0.map(|s| s + 2);
        let s3 = s0.map(|s| s + 3);
        Self {
            _embed_dim: embed_dim,
            _num_heads: num_heads,
            self_attn: MultiHeadAttention::new(embed_dim, num_heads, s0),
            cross_attn: MultiHeadAttention::new(embed_dim, num_heads, s1),
            ffn_in: LinearLayer::new(embed_dim, embed_dim * 2, s2),
            ffn_out: LinearLayer::new(embed_dim * 2, embed_dim, s3),
        }
    }

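    /// Collect mutable references to every trainable tensor in the block so
    /// they can be registered with an optimizer.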
    pub fn parameters(&mut self) -> Vec<&mut Tensor> {
        let mut params = Vec::new();
        params.extend(self.self_attn.parameters());
        params.extend(self.cross_attn.parameters());
        params.extend(self.ffn_in.parameters());
        params.extend(self.ffn_out.parameters());
        params
    }

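    /// Forward pass: masked self-attention over `tgt` plus residual, then
    /// cross-attention with queries from the target and keys/values from the
    /// encoder `memory` plus residual, then a position-wise feed-forward
    /// network plus residual.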
    pub fn forward(
        &self,
        tgt: &Tensor,
        memory: &Tensor,
        causal_mask: Option<&Tensor>,
        cross_mask: Option<&Tensor>,
    ) -> Tensor {
        // Masked self-attention over the target sequence, with a residual add.
        let self_attn = self.self_attn.forward(tgt, tgt, tgt, causal_mask);
        let res1 = self_attn.add_tensor(tgt);

        // Cross-attention: queries from the target, keys/values from the memory.
        let cross = self.cross_attn.forward(&res1, memory, memory, cross_mask);
        let res2 = cross.add_tensor(&res1);

        // Feed-forward network applied to a flattened (batch * seq, embed) view.
        let (b, t, e) = Self::triple(tgt);
        let x2d = res2.contiguous().view(vec![(b * t) as i32, e as i32]);
        let hidden = self.ffn_in.forward(&x2d).relu();
        let out2d = self.ffn_out.forward(&hidden);
        let out = out2d.view(vec![b as i32, t as i32, e as i32]);
        out.add_tensor(&res2)
    }

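    /// Read (batch, seq_len, embed_dim) from a rank-3 tensor's shape.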
    fn triple(t: &Tensor) -> (usize, usize, usize) {
        let d = t.shape().dims();
        (d[0], d[1], d[2])
    }
}

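// Build a small decoder block, run one forward pass on random data, and take
// a single Adam step on the mean of the output.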
#[allow(unused)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
    println!("=== Basic Decoder Example ===");

    let batch = 2usize;
    let src = 7usize;
    let tgt = 5usize;
    let embed = 32usize;
    let heads = 4usize;

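    // Seeded random encoder memory and target embeddings.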
    let memory = Tensor::randn(vec![batch, src, embed], Some(21));
    let tgt_in = Tensor::randn(vec![batch, tgt, embed], Some(22));

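    // One forward pass through the decoder block (no masks in this example).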
    let mut dec = DecoderBlock::new(embed, heads, Some(456));
    let out = dec.forward(&tgt_in, &memory, None, None);
    println!("Output shape: {:?}", out.shape().dims());

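    // Register the parameters with Adam, backprop from the mean of the
    // output, and apply a single optimizer step.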
    let mut opt = Adam::with_learning_rate(0.01);
    let mut params = dec.parameters();
    for p in &params {
        opt.add_parameter(p);
    }
    let mut loss = out.mean();
    loss.backward(None);
    opt.step(&mut params);
    opt.zero_grad(&mut params);
    println!("Loss: {:.6}", loss.value());
    println!("=== Done ===");
    Ok(())
}