//! Transformer building blocks: layer normalization, position-wise
//! feed-forward networks, and encoder/decoder blocks and stacks.

use axonml_autograd::Variable;
use axonml_nn::{Dropout, Linear, Module, Parameter};
use axonml_tensor::Tensor;
use axonml_tensor::creation::{ones, zeros};

use crate::attention::{CausalSelfAttention, MultiHeadSelfAttention};

/// Layer normalization over the last (feature) dimension, with a learnable
/// per-feature scale and shift.
#[derive(Debug)]
pub struct LayerNorm {
    /// Learnable scale (gamma), shape `[dim]`.
    pub weight: Parameter,
    /// Learnable shift (beta), shape `[dim]`.
    pub bias: Parameter,
    /// Small constant added to the variance for numerical stability.
    pub eps: f32,
    /// Size of the normalized dimension.
    pub dim: usize,
}

impl LayerNorm {
    /// Creates a `LayerNorm` over `dim` features, with the scale
    /// initialized to ones and the shift to zeros.
    pub fn new(dim: usize, eps: f32) -> Self {
        Self {
            weight: Parameter::new(ones::<f32>(&[dim]), true),
            bias: Parameter::new(zeros::<f32>(&[dim]), true),
            eps,
            dim,
        }
    }
}

impl Module for LayerNorm {
    fn forward(&self, input: &Variable) -> Variable {
        // Statistics over the last dimension, kept for broadcasting.
        let mean = input.mean_dim(-1, true);
        let variance = input.var_dim(-1, true);

        // Normalize: (x - mean) / sqrt(var + eps).
        let x_normalized = input.sub(&mean).div(&variance.add_scalar(self.eps).sqrt());

        // Apply the learnable affine transform.
        let weight_var =
            Variable::from_tensor_with_grad(self.weight.data().clone(), self.weight.requires_grad());
        let bias_var =
            Variable::from_tensor_with_grad(self.bias.data().clone(), self.bias.requires_grad());

        x_normalized.mul(&weight_var).add(&bias_var)
    }

    fn parameters(&self) -> Vec<Parameter> {
        vec![self.weight.clone(), self.bias.clone()]
    }
}

/// Position-wise feed-forward network: two linear layers with an activation
/// and dropout in between.
#[derive(Debug)]
pub struct FeedForward {
    /// Expansion projection, `hidden_size -> intermediate_size`.
    pub fc1: Linear,
    /// Contraction projection, `intermediate_size -> hidden_size`.
    pub fc2: Linear,
    /// Dropout applied between the activation and the second projection.
    pub dropout: Dropout,
    /// Activation name: "gelu", "relu", "silu"/"swish", or "tanh".
    pub activation: String,
}

impl FeedForward {
    pub fn new(hidden_size: usize, intermediate_size: usize, dropout: f32, activation: &str) -> Self {
        Self {
            fc1: Linear::new(hidden_size, intermediate_size),
            fc2: Linear::new(intermediate_size, hidden_size),
            dropout: Dropout::new(dropout),
            activation: activation.to_string(),
        }
    }

    fn activate(&self, x: &Variable) -> Variable {
        match self.activation.as_str() {
            "gelu" => x.gelu(),
            "relu" => x.relu(),
            "silu" | "swish" => x.silu(),
            "tanh" => x.tanh(),
            // Unrecognized names fall back to GELU.
            _ => x.gelu(),
        }
    }
}

impl Module for FeedForward {
    fn forward(&self, input: &Variable) -> Variable {
        let x = self.fc1.forward(input);
        let x = self.activate(&x);
        let x = self.dropout.forward(&x);
        self.fc2.forward(&x)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.fc1.parameters());
        params.extend(self.fc2.parameters());
        params
    }

    fn train(&mut self) {
        self.dropout.train();
    }

    fn eval(&mut self) {
        self.dropout.eval();
    }
}

/// A single Transformer encoder block: self-attention plus a feed-forward
/// network, each wrapped in a residual connection with layer norm.
#[derive(Debug)]
pub struct TransformerEncoderBlock {
    /// Multi-head self-attention sublayer.
    pub attention: MultiHeadSelfAttention,
    /// Layer norm for the attention sublayer.
    pub ln1: LayerNorm,
    /// Feed-forward sublayer.
    pub ffn: FeedForward,
    /// Layer norm for the feed-forward sublayer.
    pub ln2: LayerNorm,
    /// Dropout applied to each sublayer's output.
    pub dropout: Dropout,
    /// If true, layer norm is applied before each sublayer (pre-norm);
    /// otherwise after the residual addition (post-norm).
    pub pre_norm: bool,
}

impl TransformerEncoderBlock {
    pub fn new(
        hidden_size: usize,
        num_heads: usize,
        intermediate_size: usize,
        dropout: f32,
        layer_norm_eps: f32,
        activation: &str,
        pre_norm: bool,
    ) -> Self {
        Self {
            attention: MultiHeadSelfAttention::new(hidden_size, num_heads, dropout),
            ln1: LayerNorm::new(hidden_size, layer_norm_eps),
            ffn: FeedForward::new(hidden_size, intermediate_size, dropout, activation),
            ln2: LayerNorm::new(hidden_size, layer_norm_eps),
            dropout: Dropout::new(dropout),
            pre_norm,
        }
    }

    pub fn forward_with_mask(
        &self,
        hidden_states: &Variable,
        attention_mask: Option<&Tensor<f32>>,
    ) -> Variable {
        if self.pre_norm {
            // Pre-norm: x + Dropout(Attn(LN(x))), then x + Dropout(FFN(LN(x))).
            let residual = hidden_states.clone();
            let x = self.ln1.forward(hidden_states);
            let x = self.attention.forward_with_mask(&x, attention_mask);
            let x = self.dropout.forward(&x);
            let x = x.add(&residual);

            let residual = x.clone();
            let x = self.ln2.forward(&x);
            let x = self.ffn.forward(&x);
            let x = self.dropout.forward(&x);
            x.add(&residual)
        } else {
            // Post-norm: LN(x + Dropout(Attn(x))), then LN(x + Dropout(FFN(x))).
            let residual = hidden_states.clone();
            let x = self.attention.forward_with_mask(hidden_states, attention_mask);
            let x = self.dropout.forward(&x);
            let x = self.ln1.forward(&x.add(&residual));

            let residual = x.clone();
            let x = self.ffn.forward(&x);
            let x = self.dropout.forward(&x);
            self.ln2.forward(&x.add(&residual))
        }
    }
}

impl Module for TransformerEncoderBlock {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.attention.parameters());
        params.extend(self.ln1.parameters());
        params.extend(self.ffn.parameters());
        params.extend(self.ln2.parameters());
        params
    }

    fn train(&mut self) {
        self.attention.train();
        self.ffn.train();
        self.dropout.train();
    }

    fn eval(&mut self) {
        self.attention.eval();
        self.ffn.eval();
        self.dropout.eval();
    }
}

/// A GPT-style Transformer decoder block: causal self-attention plus a
/// feed-forward network, with pre-norm residual connections.
#[derive(Debug)]
pub struct TransformerDecoderBlock {
    /// Causal (masked) self-attention sublayer.
    pub attention: CausalSelfAttention,
    /// Layer norm for the attention sublayer.
    pub ln1: LayerNorm,
    /// Feed-forward sublayer.
    pub ffn: FeedForward,
    /// Layer norm for the feed-forward sublayer.
    pub ln2: LayerNorm,
}

impl TransformerDecoderBlock {
    pub fn new(
        n_embd: usize,
        n_head: usize,
        max_seq_len: usize,
        dropout: f32,
        layer_norm_eps: f32,
        activation: &str,
    ) -> Self {
        Self {
            attention: CausalSelfAttention::new(n_embd, n_head, max_seq_len, dropout),
            ln1: LayerNorm::new(n_embd, layer_norm_eps),
            // GPT-style 4x expansion in the feed-forward layer.
            ffn: FeedForward::new(n_embd, 4 * n_embd, dropout, activation),
            ln2: LayerNorm::new(n_embd, layer_norm_eps),
        }
    }
}

impl Module for TransformerDecoderBlock {
    fn forward(&self, input: &Variable) -> Variable {
        // Pre-norm residual block: x + Attn(LN(x)), then x + FFN(LN(x)).
        let residual = input.clone();
        let x = self.ln1.forward(input);
        let x = self.attention.forward(&x);
        let x = x.add(&residual);

        let residual = x.clone();
        let x = self.ln2.forward(&x);
        let x = self.ffn.forward(&x);
        x.add(&residual)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params = Vec::new();
        params.extend(self.attention.parameters());
        params.extend(self.ln1.parameters());
        params.extend(self.ffn.parameters());
        params.extend(self.ln2.parameters());
        params
    }

    fn train(&mut self) {
        self.attention.train();
        self.ffn.train();
    }

    fn eval(&mut self) {
        self.attention.eval();
        self.ffn.eval();
    }
}

/// A stack of [`TransformerEncoderBlock`]s applied in sequence.
#[derive(Debug)]
pub struct TransformerEncoder {
    pub layers: Vec<TransformerEncoderBlock>,
}

impl TransformerEncoder {
    pub fn new(
        num_layers: usize,
        hidden_size: usize,
        num_heads: usize,
        intermediate_size: usize,
        dropout: f32,
        layer_norm_eps: f32,
        activation: &str,
        pre_norm: bool,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| {
                TransformerEncoderBlock::new(
                    hidden_size,
                    num_heads,
                    intermediate_size,
                    dropout,
                    layer_norm_eps,
                    activation,
                    pre_norm,
                )
            })
            .collect();

        Self { layers }
    }

    pub fn forward_with_mask(
        &self,
        hidden_states: &Variable,
        attention_mask: Option<&Tensor<f32>>,
    ) -> Variable {
        // Thread the same attention mask through every layer.
        let mut output = hidden_states.clone();
        for layer in &self.layers {
            output = layer.forward_with_mask(&output, attention_mask);
        }
        output
    }
}

impl Module for TransformerEncoder {
    fn forward(&self, input: &Variable) -> Variable {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Parameter> {
        self.layers.iter().flat_map(|l| l.parameters()).collect()
    }

    fn train(&mut self) {
        for layer in &mut self.layers {
            layer.train();
        }
    }

    fn eval(&mut self) {
        for layer in &mut self.layers {
            layer.eval();
        }
    }
}

/// A stack of [`TransformerDecoderBlock`]s followed by a final layer norm.
#[derive(Debug)]
pub struct TransformerDecoder {
    pub layers: Vec<TransformerDecoderBlock>,
    /// Final layer norm applied after the last block (GPT-style).
    pub ln_f: LayerNorm,
}

impl TransformerDecoder {
    pub fn new(
        num_layers: usize,
        n_embd: usize,
        n_head: usize,
        max_seq_len: usize,
        dropout: f32,
        layer_norm_eps: f32,
        activation: &str,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| {
                TransformerDecoderBlock::new(
                    n_embd,
                    n_head,
                    max_seq_len,
                    dropout,
                    layer_norm_eps,
                    activation,
                )
            })
            .collect();

        Self {
            layers,
            ln_f: LayerNorm::new(n_embd, layer_norm_eps),
        }
    }
}

impl Module for TransformerDecoder {
    fn forward(&self, input: &Variable) -> Variable {
        let mut output = input.clone();
        for layer in &self.layers {
            output = layer.forward(&output);
        }
        // Final layer norm over the output of the last block.
        self.ln_f.forward(&output)
    }

    fn parameters(&self) -> Vec<Parameter> {
        let mut params: Vec<Parameter> = self.layers.iter().flat_map(|l| l.parameters()).collect();
        params.extend(self.ln_f.parameters());
        params
    }

    fn train(&mut self) {
        for layer in &mut self.layers {
            layer.train();
        }
    }

    fn eval(&mut self) {
        for layer in &mut self.layers {
            layer.eval();
        }
    }
}

/// Either kind of Transformer block, so heterogeneous stacks can be stored
/// in a single `Vec`.
#[derive(Debug)]
pub enum TransformerBlock {
    Encoder(TransformerEncoderBlock),
    Decoder(TransformerDecoderBlock),
}

impl Module for TransformerBlock {
    fn forward(&self, input: &Variable) -> Variable {
        match self {
            TransformerBlock::Encoder(block) => block.forward(input),
            TransformerBlock::Decoder(block) => block.forward(input),
        }
    }

    fn parameters(&self) -> Vec<Parameter> {
        match self {
            TransformerBlock::Encoder(block) => block.parameters(),
            TransformerBlock::Decoder(block) => block.parameters(),
        }
    }

    fn train(&mut self) {
        match self {
            TransformerBlock::Encoder(block) => block.train(),
            TransformerBlock::Decoder(block) => block.train(),
        }
    }

    fn eval(&mut self) {
        match self {
            TransformerBlock::Encoder(block) => block.eval(),
            TransformerBlock::Decoder(block) => block.eval(),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_layer_norm() {
        let ln = LayerNorm::new(64, 1e-5);
        let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
        let output = ln.forward(&input);

        assert_eq!(output.data().shape(), &[2, 8, 64]);
    }

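    // Added sanity check (not in the original suite): LayerNorm should
    // expose exactly its weight and bias as trainable parameters.
    #[test]
    fn test_layer_norm_parameters() {
        let ln = LayerNorm::new(64, 1e-5);
        assert_eq!(ln.parameters().len(), 2);
    }
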
    #[test]
    fn test_feed_forward() {
        let ffn = FeedForward::new(64, 256, 0.0, "gelu");
        let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
        let output = ffn.forward(&input);

        assert_eq!(output.data().shape(), &[2, 8, 64]);
    }

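    // Added check (assumes Dropout preserves shape in both modes):
    // toggling train/eval must not change the output shape, even with a
    // non-zero dropout probability.
    #[test]
    fn test_feed_forward_train_eval() {
        let mut ffn = FeedForward::new(64, 256, 0.5, "relu");
        let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);

        ffn.train();
        assert_eq!(ffn.forward(&input).data().shape(), &[2, 8, 64]);

        ffn.eval();
        assert_eq!(ffn.forward(&input).data().shape(), &[2, 8, 64]);
    }
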
    #[test]
    fn test_encoder_block() {
        let block = TransformerEncoderBlock::new(64, 4, 256, 0.0, 1e-5, "gelu", false);
        let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
        let output = block.forward(&input);

        assert_eq!(output.data().shape(), &[2, 8, 64]);
    }

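    // Added check: the pre-norm path (`pre_norm = true`) should preserve
    // shape just like the post-norm path exercised above.
    #[test]
    fn test_encoder_block_pre_norm() {
        let block = TransformerEncoderBlock::new(64, 4, 256, 0.0, 1e-5, "gelu", true);
        let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
        let output = block.forward(&input);

        assert_eq!(output.data().shape(), &[2, 8, 64]);
    }
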
    #[test]
    fn test_decoder_block() {
        let block = TransformerDecoderBlock::new(64, 4, 128, 0.0, 1e-5, "gelu");
        let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
        let output = block.forward(&input);

        assert_eq!(output.data().shape(), &[2, 8, 64]);
    }

    #[test]
    fn test_transformer_encoder() {
        let encoder = TransformerEncoder::new(2, 64, 4, 256, 0.0, 1e-5, "gelu", false);
        let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
        let output = encoder.forward(&input);

        assert_eq!(output.data().shape(), &[2, 8, 64]);
    }

    #[test]
    fn test_transformer_decoder() {
        let decoder = TransformerDecoder::new(2, 64, 4, 128, 0.0, 1e-5, "gelu");
        let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
        let output = decoder.forward(&input);

        assert_eq!(output.data().shape(), &[2, 8, 64]);
    }
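
    // Added check: the TransformerBlock enum should forward exactly like
    // the block it wraps.
    #[test]
    fn test_transformer_block_enum() {
        let block = TransformerBlock::Encoder(TransformerEncoderBlock::new(
            64, 4, 256, 0.0, 1e-5, "gelu", false,
        ));
        let input = Variable::new(Tensor::randn(&[2, 8, 64]), false);
        let output = block.forward(&input);

        assert_eq!(output.data().shape(), &[2, 8, 64]);
    }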
}