use ghostflow_core::Tensor;

use crate::module::Module;
use crate::linear::Linear;
use crate::norm::LayerNorm;
use crate::dropout::Dropout;
use crate::attention::MultiHeadAttention;
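
/// Position-wise feed-forward block used inside transformer layers:
/// `linear1 -> activation -> dropout -> linear2`, expanding `d_model`
/// to `d_ff` and projecting back to `d_model`.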
pub struct FeedForward {
    linear1: Linear,
    linear2: Linear,
    dropout: Dropout,
    activation: Activation,
    training: bool,
}
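
/// Activation applied between the two linear layers of [`FeedForward`].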
#[derive(Clone, Copy)]
pub enum Activation {
    ReLU,
    GELU,
    SiLU,
}

impl FeedForward {
    pub fn new(d_model: usize, d_ff: usize, dropout: f32) -> Self {
        Self::with_activation(d_model, d_ff, dropout, Activation::GELU)
    }

    pub fn with_activation(d_model: usize, d_ff: usize, dropout: f32, activation: Activation) -> Self {
        FeedForward {
            linear1: Linear::new(d_model, d_ff),
            linear2: Linear::new(d_ff, d_model),
            dropout: Dropout::new(dropout),
            activation,
            training: true,
        }
    }
}

impl Module for FeedForward {
    fn forward(&self, input: &Tensor) -> Tensor {
        let x = self.linear1.forward(input);
        let x = match self.activation {
            Activation::ReLU => x.relu(),
            Activation::GELU => x.gelu(),
            Activation::SiLU => x.silu(),
        };
        let x = if self.training {
            self.dropout.forward(&x)
        } else {
            x
        };
        self.linear2.forward(&x)
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = self.linear1.parameters();
        params.extend(self.linear2.parameters());
        params
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}
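
/// Single transformer encoder layer: self-attention followed by a feed-forward
/// block, each with dropout, a residual connection, and layer normalization.
/// `pre_norm` selects normalize-before-sublayer ordering instead of the
/// post-norm ordering used by default.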
pub struct TransformerEncoderLayer {
    self_attn: MultiHeadAttention,
    ffn: FeedForward,
    norm1: LayerNorm,
    norm2: LayerNorm,
    dropout: Dropout,
    pre_norm: bool,
    training: bool,
}

impl TransformerEncoderLayer {
    pub fn new(d_model: usize, nhead: usize, d_ff: usize, dropout: f32) -> Self {
        Self::with_config(d_model, nhead, d_ff, dropout, false)
    }

    pub fn with_config(d_model: usize, nhead: usize, d_ff: usize, dropout: f32, pre_norm: bool) -> Self {
        TransformerEncoderLayer {
            self_attn: MultiHeadAttention::new(d_model, nhead, dropout),
            ffn: FeedForward::new(d_model, d_ff, dropout),
            norm1: LayerNorm::new(&[d_model]),
            norm2: LayerNorm::new(&[d_model]),
            dropout: Dropout::new(dropout),
            pre_norm,
            training: true,
        }
    }

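    /// Runs one encoder layer over `src`. The `mask` argument is accepted for
    /// API compatibility but is not yet forwarded to the attention module, so it
    /// currently has no effect.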
    pub fn forward_with_mask(&self, src: &Tensor, _mask: Option<&Tensor>) -> Tensor {
        if self.pre_norm {
            // Pre-norm: normalize before each sublayer, then add the residual.
            let x = self.norm1.forward(src);
            let attn_out = self.self_attn.forward(&x);
            let x = src.add(&self.dropout.forward(&attn_out)).unwrap();

            let x2 = self.norm2.forward(&x);
            let ffn_out = self.ffn.forward(&x2);
            x.add(&self.dropout.forward(&ffn_out)).unwrap()
        } else {
            // Post-norm: add the residual first, then normalize (original transformer ordering).
            let attn_out = self.self_attn.forward(src);
            let x = self.norm1.forward(&src.add(&self.dropout.forward(&attn_out)).unwrap());

            let ffn_out = self.ffn.forward(&x);
            self.norm2.forward(&x.add(&self.dropout.forward(&ffn_out)).unwrap())
        }
    }
}

impl Module for TransformerEncoderLayer {
    fn forward(&self, input: &Tensor) -> Tensor {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = self.self_attn.parameters();
        params.extend(self.ffn.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params
    }

    fn train(&mut self) {
        self.training = true;
        self.self_attn.train();
        self.ffn.train();
    }

    fn eval(&mut self) {
        self.training = false;
        self.self_attn.eval();
        self.ffn.eval();
    }

    fn is_training(&self) -> bool { self.training }
}
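
/// Transformer decoder layer (post-norm): self-attention over the target,
/// cross-attention over the encoder memory, and a feed-forward block, each
/// followed by dropout, a residual connection, and layer normalization.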
pub struct TransformerDecoderLayer {
    self_attn: MultiHeadAttention,
    cross_attn: MultiHeadAttention,
    ffn: FeedForward,
    norm1: LayerNorm,
    norm2: LayerNorm,
    norm3: LayerNorm,
    dropout: Dropout,
    #[allow(dead_code)]
    pre_norm: bool,
    training: bool,
}

impl TransformerDecoderLayer {
    pub fn new(d_model: usize, nhead: usize, d_ff: usize, dropout: f32) -> Self {
        TransformerDecoderLayer {
            self_attn: MultiHeadAttention::new(d_model, nhead, dropout),
            cross_attn: MultiHeadAttention::new(d_model, nhead, dropout),
            ffn: FeedForward::new(d_model, d_ff, dropout),
            norm1: LayerNorm::new(&[d_model]),
            norm2: LayerNorm::new(&[d_model]),
            norm3: LayerNorm::new(&[d_model]),
            dropout: Dropout::new(dropout),
            pre_norm: false,
            training: true,
        }
    }

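    /// Decodes `tgt` while attending to the encoder output `memory`. The
    /// `tgt_mask` argument is accepted but not yet applied to the self-attention
    /// call; `memory_mask` is forwarded to the cross-attention call.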
    pub fn forward_with_memory(
        &self,
        tgt: &Tensor,
        memory: &Tensor,
        _tgt_mask: Option<&Tensor>,
        memory_mask: Option<&Tensor>,
    ) -> Tensor {
        // Self-attention sublayer: residual connection followed by layer norm.
        let x = self.norm1.forward(&tgt.add(&self.dropout.forward(&self.self_attn.forward(tgt))).unwrap());

        // Cross-attention sublayer: queries from the decoder, keys/values from the encoder memory.
        let (cross_out, _, _, _) = self.cross_attn.forward_with_cache(&x, memory, memory, memory_mask, None, None);
        let x = self.norm2.forward(&x.add(&self.dropout.forward(&cross_out)).unwrap());

        // Feed-forward sublayer: residual connection followed by layer norm.
        let ffn_out = self.ffn.forward(&x);
        self.norm3.forward(&x.add(&self.dropout.forward(&ffn_out)).unwrap())
    }
}

impl Module for TransformerDecoderLayer {
    fn forward(&self, input: &Tensor) -> Tensor {
        // Without encoder memory there is nothing to cross-attend to, so only the
        // self-attention and feed-forward sublayers (with residuals and norms) run.
        let x = self.norm1.forward(&input.add(&self.dropout.forward(&self.self_attn.forward(input))).unwrap());
        let ffn_out = self.ffn.forward(&x);
        self.norm3.forward(&x.add(&self.dropout.forward(&ffn_out)).unwrap())
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = self.self_attn.parameters();
        params.extend(self.cross_attn.parameters());
        params.extend(self.ffn.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params.extend(self.norm3.parameters());
        params
    }

    fn train(&mut self) {
        self.training = true;
        self.self_attn.train();
        self.cross_attn.train();
        self.ffn.train();
    }

    fn eval(&mut self) {
        self.training = false;
        self.self_attn.eval();
        self.cross_attn.eval();
        self.ffn.eval();
    }

    fn is_training(&self) -> bool { self.training }
}
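
/// Stack of [`TransformerEncoderLayer`]s with an optional final [`LayerNorm`]
/// applied after the last layer.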
pub struct TransformerEncoder {
    layers: Vec<TransformerEncoderLayer>,
    norm: Option<LayerNorm>,
}

impl TransformerEncoder {
    pub fn new(d_model: usize, nhead: usize, d_ff: usize, num_layers: usize, dropout: f32) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerEncoderLayer::new(d_model, nhead, d_ff, dropout))
            .collect();

        TransformerEncoder {
            layers,
            norm: Some(LayerNorm::new(&[d_model])),
        }
    }

    pub fn forward_with_mask(&self, src: &Tensor, mask: Option<&Tensor>) -> Tensor {
        let mut output = src.clone();

        for layer in &self.layers {
            output = layer.forward_with_mask(&output, mask);
        }

        if let Some(ref norm) = self.norm {
            output = norm.forward(&output);
        }

        output
    }
}

impl Module for TransformerEncoder {
    fn forward(&self, input: &Tensor) -> Tensor {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params: Vec<Tensor> = self.layers.iter()
            .flat_map(|l| l.parameters())
            .collect();
        if let Some(ref norm) = self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn train(&mut self) {
        for layer in &mut self.layers {
            layer.train();
        }
    }

    fn eval(&mut self) {
        for layer in &mut self.layers {
            layer.eval();
        }
    }

    fn is_training(&self) -> bool {
        self.layers.first().is_some_and(|l| l.is_training())
    }
}
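
/// Sinusoidal positional encoding from "Attention Is All You Need":
/// `PE(pos, 2i) = sin(pos / 10000^(2i / d_model))`,
/// `PE(pos, 2i + 1) = cos(pos / 10000^(2i / d_model))`,
/// precomputed for `max_len` positions, added to the input embeddings,
/// and followed by dropout.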
pub struct PositionalEncoding {
    encoding: Tensor,
    dropout: Dropout,
    #[allow(dead_code)]
    max_len: usize,
    d_model: usize,
}

impl PositionalEncoding {
    pub fn new(d_model: usize, max_len: usize, dropout: f32) -> Self {
        let encoding = Self::create_encoding(d_model, max_len);

        PositionalEncoding {
            encoding,
            dropout: Dropout::new(dropout),
            max_len,
            d_model,
        }
    }

    fn create_encoding(d_model: usize, max_len: usize) -> Tensor {
        let mut pe = vec![0.0f32; max_len * d_model];

        for pos in 0..max_len {
            for i in 0..d_model / 2 {
                let angle = pos as f32 / (10000.0f32).powf(2.0 * i as f32 / d_model as f32);
                pe[pos * d_model + 2 * i] = angle.sin();
                pe[pos * d_model + 2 * i + 1] = angle.cos();
            }
        }

        Tensor::from_slice(&pe, &[max_len, d_model]).unwrap()
    }
}
impl Module for PositionalEncoding {
    fn forward(&self, input: &Tensor) -> Tensor {
        // Slice the precomputed table to the input sequence length (dimension 1 of
        // [batch, seq, d_model]); the input is expected to satisfy seq_len <= max_len.
        let seq_len = input.dims()[1];

        let pe_data = self.encoding.data_f32();
        let pe_slice: Vec<f32> = pe_data[..seq_len * self.d_model].to_vec();
        let pe = Tensor::from_slice(&pe_slice, &[seq_len, self.d_model]).unwrap();

        let result = input.add(&pe).unwrap();
        self.dropout.forward(&result)
    }

    fn parameters(&self) -> Vec<Tensor> {
        vec![]
    }

    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}
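
/// Rotary positional embedding (RoPE) with precomputed cosine/sine tables for
/// `max_seq_len` positions, applied to query/key tensors using the rotate-half
/// convention (dimension `i` is paired with dimension `i + dim / 2`).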
pub struct RotaryEmbedding {
    #[allow(dead_code)]
    dim: usize,
    #[allow(dead_code)]
    max_seq_len: usize,
    cos_cache: Tensor,
    sin_cache: Tensor,
}

impl RotaryEmbedding {
    pub fn new(dim: usize, max_seq_len: usize, base: f32) -> Self {
        let (cos_cache, sin_cache) = Self::compute_freqs(dim, max_seq_len, base);

        RotaryEmbedding {
            dim,
            max_seq_len,
            cos_cache,
            sin_cache,
        }
    }

    fn compute_freqs(dim: usize, max_seq_len: usize, base: f32) -> (Tensor, Tensor) {
        let half_dim = dim / 2;

        let inv_freq: Vec<f32> = (0..half_dim)
            .map(|i| 1.0 / base.powf(2.0 * i as f32 / dim as f32))
            .collect();

        let mut cos_data = vec![0.0f32; max_seq_len * half_dim];
        let mut sin_data = vec![0.0f32; max_seq_len * half_dim];

        for pos in 0..max_seq_len {
            for (i, &freq) in inv_freq.iter().enumerate() {
                let angle = pos as f32 * freq;
                cos_data[pos * half_dim + i] = angle.cos();
                sin_data[pos * half_dim + i] = angle.sin();
            }
        }

        (
            Tensor::from_slice(&cos_data, &[max_seq_len, half_dim]).unwrap(),
            Tensor::from_slice(&sin_data, &[max_seq_len, half_dim]).unwrap(),
        )
    }

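    /// Applies rotary embeddings to `q` and `k`, whose trailing two dimensions are
    /// `[seq_len, head_dim]`. `start_pos` offsets the position index into the
    /// precomputed tables, e.g. when continuing generation from a cached prefix.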
    pub fn apply(&self, q: &Tensor, k: &Tensor, start_pos: usize) -> (Tensor, Tensor) {
        let seq_len = q.dims()[q.ndim() - 2];
        let head_dim = q.dims()[q.ndim() - 1];
        let half_dim = head_dim / 2;

        let cos_data = self.cos_cache.data_f32();
        let sin_data = self.sin_cache.data_f32();

        let apply_rope = |x: &Tensor| -> Tensor {
            let data = x.data_f32();
            // All leading dimensions (batch, heads, ...) are flattened together.
            let batch_heads: usize = x.dims()[..x.ndim() - 2].iter().product();

            let mut result = vec![0.0f32; data.len()];

            for bh in 0..batch_heads {
                for s in 0..seq_len {
                    let pos = start_pos + s;
                    for i in 0..half_dim {
                        let cos_val = cos_data[pos * half_dim + i];
                        let sin_val = sin_data[pos * half_dim + i];

                        // Rotate-half pairing: dimension i rotates with dimension i + half_dim.
                        let idx1 = bh * seq_len * head_dim + s * head_dim + i;
                        let idx2 = bh * seq_len * head_dim + s * head_dim + i + half_dim;

                        let x1 = data[idx1];
                        let x2 = data[idx2];

                        result[idx1] = x1 * cos_val - x2 * sin_val;
                        result[idx2] = x1 * sin_val + x2 * cos_val;
                    }
                }
            }

            Tensor::from_slice(&result, x.dims()).unwrap()
        };

        (apply_rope(q), apply_rope(k))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_feed_forward() {
        let ffn = FeedForward::new(64, 256, 0.1);
        let input = Tensor::randn(&[2, 10, 64]);
        let output = ffn.forward(&input);

        assert_eq!(output.dims(), &[2, 10, 64]);
    }

    #[test]
    fn test_transformer_encoder_layer() {
        let layer = TransformerEncoderLayer::new(64, 8, 256, 0.1);
        let input = Tensor::randn(&[2, 10, 64]);
        let output = layer.forward(&input);

        assert_eq!(output.dims(), &[2, 10, 64]);
    }

    #[test]
    fn test_transformer_encoder() {
        let encoder = TransformerEncoder::new(64, 8, 256, 6, 0.1);
        let input = Tensor::randn(&[2, 10, 64]);
        let output = encoder.forward(&input);

        assert_eq!(output.dims(), &[2, 10, 64]);
    }

    #[test]
    fn test_positional_encoding() {
        let pe = PositionalEncoding::new(64, 512, 0.1);
        let input = Tensor::randn(&[2, 10, 64]);
        let output = pe.forward(&input);

        assert_eq!(output.dims(), &[2, 10, 64]);
    }
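
    // Additional shape-preservation checks, sketched under the same assumptions as
    // the tests above (Tensor::randn and dims() behave as used there). They cover
    // the pre-norm encoder path, the decoder layer, and rotary embeddings.
    #[test]
    fn test_encoder_layer_pre_norm() {
        let layer = TransformerEncoderLayer::with_config(64, 8, 256, 0.1, true);
        let input = Tensor::randn(&[2, 10, 64]);
        let output = layer.forward(&input);

        assert_eq!(output.dims(), &[2, 10, 64]);
    }

    #[test]
    fn test_transformer_decoder_layer() {
        let layer = TransformerDecoderLayer::new(64, 8, 256, 0.1);
        let tgt = Tensor::randn(&[2, 10, 64]);
        let memory = Tensor::randn(&[2, 10, 64]);
        let output = layer.forward_with_memory(&tgt, &memory, None, None);

        assert_eq!(output.dims(), &[2, 10, 64]);
    }

    #[test]
    fn test_rotary_embedding() {
        let rope = RotaryEmbedding::new(64, 512, 10000.0);
        let q = Tensor::randn(&[2, 8, 10, 64]);
        let k = Tensor::randn(&[2, 8, 10, 64]);
        let (q_rot, k_rot) = rope.apply(&q, &k, 0);

        assert_eq!(q_rot.dims(), &[2, 8, 10, 64]);
        assert_eq!(k_rot.dims(), &[2, 8, 10, 64]);
    }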
}