// ghostflow_nn/transformer.rs

//! Transformer architecture components

use ghostflow_core::Tensor;
use crate::module::Module;
use crate::linear::Linear;
use crate::norm::LayerNorm;
use crate::dropout::Dropout;
use crate::attention::MultiHeadAttention;

/// Feed-Forward Network (FFN) used in Transformers
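/// Computes `linear2(dropout(activation(linear1(x))))`, expanding `d_model` to
/// `d_ff` and projecting back to `d_model`.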
pub struct FeedForward {
    linear1: Linear,
    linear2: Linear,
    dropout: Dropout,
    activation: Activation,
    training: bool,
}

#[derive(Clone, Copy)]
pub enum Activation {
    ReLU,
    GELU,
    SiLU,
}

impl FeedForward {
    pub fn new(d_model: usize, d_ff: usize, dropout: f32) -> Self {
        Self::with_activation(d_model, d_ff, dropout, Activation::GELU)
    }

    pub fn with_activation(d_model: usize, d_ff: usize, dropout: f32, activation: Activation) -> Self {
        FeedForward {
            linear1: Linear::new(d_model, d_ff),
            linear2: Linear::new(d_ff, d_model),
            dropout: Dropout::new(dropout),
            activation,
            training: true,
        }
    }
}

impl Module for FeedForward {
    fn forward(&self, input: &Tensor) -> Tensor {
        let x = self.linear1.forward(input);
        let x = match self.activation {
            Activation::ReLU => x.relu(),
            Activation::GELU => x.gelu(),
            Activation::SiLU => x.silu(),
        };
        let x = if self.training {
            self.dropout.forward(&x)
        } else {
            x
        };
        self.linear2.forward(&x)
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = self.linear1.parameters();
        params.extend(self.linear2.parameters());
        params
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

/// Transformer Encoder Layer
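/// Self-attention followed by a feed-forward block, each wrapped in a residual
/// connection; `pre_norm` selects pre-LN or post-LN placement of the layer norms.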
pub struct TransformerEncoderLayer {
    self_attn: MultiHeadAttention,
    ffn: FeedForward,
    norm1: LayerNorm,
    norm2: LayerNorm,
    dropout: Dropout,
    pre_norm: bool,
    training: bool,
}

impl TransformerEncoderLayer {
    pub fn new(d_model: usize, nhead: usize, d_ff: usize, dropout: f32) -> Self {
        Self::with_config(d_model, nhead, d_ff, dropout, false)
    }

    pub fn with_config(d_model: usize, nhead: usize, d_ff: usize, dropout: f32, pre_norm: bool) -> Self {
        TransformerEncoderLayer {
            self_attn: MultiHeadAttention::new(d_model, nhead, dropout),
            ffn: FeedForward::new(d_model, d_ff, dropout),
            norm1: LayerNorm::new(&[d_model]),
            norm2: LayerNorm::new(&[d_model]),
            dropout: Dropout::new(dropout),
            pre_norm,
            training: true,
        }
    }

    /// Runs the layer on `src`. `_mask` is accepted for API parity but is
    /// currently unused: the self-attention below is invoked without a mask.
    pub fn forward_with_mask(&self, src: &Tensor, _mask: Option<&Tensor>) -> Tensor {
        if self.pre_norm {
            // Pre-LN: x + Attn(LN(x)), then x + FFN(LN(x))
            let x = self.norm1.forward(src);
            let attn_out = self.self_attn.forward(&x);
            let x = src.add(&self.dropout.forward(&attn_out)).unwrap();

            let x2 = self.norm2.forward(&x);
            let ffn_out = self.ffn.forward(&x2);
            x.add(&self.dropout.forward(&ffn_out)).unwrap()
        } else {
            // Post-LN: LN(x + Attn(x)), then LN(x + FFN(x))
            let attn_out = self.self_attn.forward(src);
            let x = self.norm1.forward(&src.add(&self.dropout.forward(&attn_out)).unwrap());

            let ffn_out = self.ffn.forward(&x);
            self.norm2.forward(&x.add(&self.dropout.forward(&ffn_out)).unwrap())
        }
    }
}

impl Module for TransformerEncoderLayer {
    fn forward(&self, input: &Tensor) -> Tensor {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = self.self_attn.parameters();
        params.extend(self.ffn.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params
    }

    fn train(&mut self) {
        self.training = true;
        self.self_attn.train();
        self.ffn.train();
    }

    fn eval(&mut self) {
        self.training = false;
        self.self_attn.eval();
        self.ffn.eval();
    }

    fn is_training(&self) -> bool { self.training }
}

/// Transformer Decoder Layer
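/// Self-attention over the target, cross-attention over the encoder memory, then
/// a feed-forward block, all with post-norm residual connections.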
pub struct TransformerDecoderLayer {
    self_attn: MultiHeadAttention,
    cross_attn: MultiHeadAttention,
    ffn: FeedForward,
    norm1: LayerNorm,
    norm2: LayerNorm,
    norm3: LayerNorm,
    dropout: Dropout,
    #[allow(dead_code)]
    pre_norm: bool,
    training: bool,
}

impl TransformerDecoderLayer {
    pub fn new(d_model: usize, nhead: usize, d_ff: usize, dropout: f32) -> Self {
        TransformerDecoderLayer {
            self_attn: MultiHeadAttention::new(d_model, nhead, dropout),
            cross_attn: MultiHeadAttention::new(d_model, nhead, dropout),
            ffn: FeedForward::new(d_model, d_ff, dropout),
            norm1: LayerNorm::new(&[d_model]),
            norm2: LayerNorm::new(&[d_model]),
            norm3: LayerNorm::new(&[d_model]),
            dropout: Dropout::new(dropout),
            pre_norm: false,
            training: true,
        }
    }

    /// Decodes `tgt` against the encoder output `memory`. `_tgt_mask` is accepted
    /// for API parity but is currently unused; `memory_mask` is passed through to
    /// the cross-attention.
    pub fn forward_with_memory(
        &self,
        tgt: &Tensor,
        memory: &Tensor,
        _tgt_mask: Option<&Tensor>,
        memory_mask: Option<&Tensor>,
    ) -> Tensor {
        // Self-attention on target
        let x = self.norm1.forward(&tgt.add(&self.dropout.forward(&self.self_attn.forward(tgt))).unwrap());

        // Cross-attention with encoder output
        let (cross_out, _, _, _) = self.cross_attn.forward_with_cache(&x, memory, memory, memory_mask, None, None);
        let x = self.norm2.forward(&x.add(&self.dropout.forward(&cross_out)).unwrap());

        // Feed-forward
        let ffn_out = self.ffn.forward(&x);
        self.norm3.forward(&x.add(&self.dropout.forward(&ffn_out)).unwrap())
    }
}

impl Module for TransformerDecoderLayer {
    fn forward(&self, input: &Tensor) -> Tensor {
        // For standalone use, treat as self-attention only
        self.self_attn.forward(input)
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = self.self_attn.parameters();
        params.extend(self.cross_attn.parameters());
        params.extend(self.ffn.parameters());
        params.extend(self.norm1.parameters());
        params.extend(self.norm2.parameters());
        params.extend(self.norm3.parameters());
        params
    }

    fn train(&mut self) {
        self.training = true;
        self.self_attn.train();
        self.cross_attn.train();
        self.ffn.train();
    }

    fn eval(&mut self) {
        self.training = false;
        self.self_attn.eval();
        self.cross_attn.eval();
        self.ffn.eval();
    }

    fn is_training(&self) -> bool { self.training }
}

/// Transformer Encoder (stack of encoder layers)
pub struct TransformerEncoder {
    layers: Vec<TransformerEncoderLayer>,
    norm: Option<LayerNorm>,
}

impl TransformerEncoder {
    pub fn new(d_model: usize, nhead: usize, d_ff: usize, num_layers: usize, dropout: f32) -> Self {
        let layers = (0..num_layers)
            .map(|_| TransformerEncoderLayer::new(d_model, nhead, d_ff, dropout))
            .collect();

        TransformerEncoder {
            layers,
            norm: Some(LayerNorm::new(&[d_model])),
        }
    }

    pub fn forward_with_mask(&self, src: &Tensor, mask: Option<&Tensor>) -> Tensor {
        let mut output = src.clone();

        for layer in &self.layers {
            output = layer.forward_with_mask(&output, mask);
        }

        if let Some(ref norm) = self.norm {
            output = norm.forward(&output);
        }

        output
    }
}

impl Module for TransformerEncoder {
    fn forward(&self, input: &Tensor) -> Tensor {
        self.forward_with_mask(input, None)
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params: Vec<Tensor> = self.layers.iter()
            .flat_map(|l| l.parameters())
            .collect();
        if let Some(ref norm) = self.norm {
            params.extend(norm.parameters());
        }
        params
    }

    fn train(&mut self) {
        for layer in &mut self.layers {
            layer.train();
        }
    }

    fn eval(&mut self) {
        for layer in &mut self.layers {
            layer.eval();
        }
    }

    fn is_training(&self) -> bool {
        self.layers.first().is_some_and(|l| l.is_training())
    }
}

/// Positional Encoding (sinusoidal)
pub struct PositionalEncoding {
    encoding: Tensor,
    dropout: Dropout,
    #[allow(dead_code)]
    max_len: usize,
    d_model: usize,
}

impl PositionalEncoding {
    pub fn new(d_model: usize, max_len: usize, dropout: f32) -> Self {
        let encoding = Self::create_encoding(d_model, max_len);

        PositionalEncoding {
            encoding,
            dropout: Dropout::new(dropout),
            max_len,
            d_model,
        }
    }

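    /// Builds the sinusoidal table, stored row-major as `[max_len, d_model]`:
    /// PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)).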
    fn create_encoding(d_model: usize, max_len: usize) -> Tensor {
        let mut pe = vec![0.0f32; max_len * d_model];

        for pos in 0..max_len {
            for i in 0..d_model / 2 {
                let angle = pos as f32 / (10000.0f32).powf(2.0 * i as f32 / d_model as f32);
                pe[pos * d_model + 2 * i] = angle.sin();
                pe[pos * d_model + 2 * i + 1] = angle.cos();
            }
        }

        Tensor::from_slice(&pe, &[max_len, d_model]).unwrap()
    }
}

impl Module for PositionalEncoding {
    fn forward(&self, input: &Tensor) -> Tensor {
        let seq_len = input.dims()[1];

        // Slice the cached table to this sequence length
        // (input is expected to be [batch, seq_len, d_model])
        let pe_data = self.encoding.data_f32();
        let pe_slice: Vec<f32> = pe_data[..seq_len * self.d_model].to_vec();
        let pe = Tensor::from_slice(&pe_slice, &[seq_len, self.d_model]).unwrap();

        // Add the positional encoding to the input (broadcast over the batch dimension)
        let result = input.add(&pe).unwrap();
        self.dropout.forward(&result)
    }

    fn parameters(&self) -> Vec<Tensor> {
        vec![] // Positional encoding has no learnable parameters
    }

    fn train(&mut self) {}
    fn eval(&mut self) {}
    fn is_training(&self) -> bool { false }
}

/// Rotary Position Embedding (RoPE) - used in modern LLMs
pub struct RotaryEmbedding {
    #[allow(dead_code)]
    dim: usize,
    #[allow(dead_code)]
    max_seq_len: usize,
    cos_cache: Tensor,
    sin_cache: Tensor,
}

impl RotaryEmbedding {
    pub fn new(dim: usize, max_seq_len: usize, base: f32) -> Self {
        let (cos_cache, sin_cache) = Self::compute_freqs(dim, max_seq_len, base);

        RotaryEmbedding {
            dim,
            max_seq_len,
            cos_cache,
            sin_cache,
        }
    }

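    /// Precomputes the cos/sin tables: inv_freq[i] = base^(-2i/dim) and
    /// angle(pos, i) = pos * inv_freq[i], each stored as `[max_seq_len, dim / 2]`.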
    fn compute_freqs(dim: usize, max_seq_len: usize, base: f32) -> (Tensor, Tensor) {
        let half_dim = dim / 2;

        // Compute inverse frequencies
        let inv_freq: Vec<f32> = (0..half_dim)
            .map(|i| 1.0 / base.powf(2.0 * i as f32 / dim as f32))
            .collect();

        // Compute position * inv_freq
        let mut cos_data = vec![0.0f32; max_seq_len * half_dim];
        let mut sin_data = vec![0.0f32; max_seq_len * half_dim];

        for pos in 0..max_seq_len {
            for (i, &freq) in inv_freq.iter().enumerate() {
                let angle = pos as f32 * freq;
                cos_data[pos * half_dim + i] = angle.cos();
                sin_data[pos * half_dim + i] = angle.sin();
            }
        }

        (
            Tensor::from_slice(&cos_data, &[max_seq_len, half_dim]).unwrap(),
            Tensor::from_slice(&sin_data, &[max_seq_len, half_dim]).unwrap(),
        )
    }

    /// Apply rotary embedding to query and key tensors
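    ///
    /// Each feature pair (x[i], x[i + head_dim/2]) at position `start_pos + s` is
    /// rotated by angle pos * inv_freq[i]: (x1, x2) -> (x1*cos - x2*sin, x1*sin + x2*cos).
    /// Expects `q` and `k` with trailing dims `[seq_len, head_dim]`, where `head_dim`
    /// equals the `dim` this embedding was built with.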
    pub fn apply(&self, q: &Tensor, k: &Tensor, start_pos: usize) -> (Tensor, Tensor) {
        let seq_len = q.dims()[q.ndim() - 2];
        let head_dim = q.dims()[q.ndim() - 1];
        let half_dim = head_dim / 2;

        let cos_data = self.cos_cache.data_f32();
        let sin_data = self.sin_cache.data_f32();

        let apply_rope = |x: &Tensor| -> Tensor {
            let data = x.data_f32();
            let batch_heads: usize = x.dims()[..x.ndim() - 2].iter().product();

            let mut result = vec![0.0f32; data.len()];

            for bh in 0..batch_heads {
                for s in 0..seq_len {
                    let pos = start_pos + s;
                    for i in 0..half_dim {
                        let cos_val = cos_data[pos * half_dim + i];
                        let sin_val = sin_data[pos * half_dim + i];

                        let idx1 = bh * seq_len * head_dim + s * head_dim + i;
                        let idx2 = bh * seq_len * head_dim + s * head_dim + i + half_dim;

                        let x1 = data[idx1];
                        let x2 = data[idx2];

                        result[idx1] = x1 * cos_val - x2 * sin_val;
                        result[idx2] = x1 * sin_val + x2 * cos_val;
                    }
                }
            }

            Tensor::from_slice(&result, x.dims()).unwrap()
        };

        (apply_rope(q), apply_rope(k))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_feed_forward() {
        let ffn = FeedForward::new(64, 256, 0.1);
        let input = Tensor::randn(&[2, 10, 64]);
        let output = ffn.forward(&input);

        assert_eq!(output.dims(), &[2, 10, 64]);
    }
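
    // A minimal additional check for the non-default activation path; shapes
    // mirror test_feed_forward above.
    #[test]
    fn test_feed_forward_relu() {
        let ffn = FeedForward::with_activation(64, 256, 0.1, Activation::ReLU);
        let input = Tensor::randn(&[2, 10, 64]);
        let output = ffn.forward(&input);

        assert_eq!(output.dims(), &[2, 10, 64]);
    }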

    #[test]
    fn test_transformer_encoder_layer() {
        let layer = TransformerEncoderLayer::new(64, 8, 256, 0.1);
        let input = Tensor::randn(&[2, 10, 64]);
        let output = layer.forward(&input);

        assert_eq!(output.dims(), &[2, 10, 64]);
    }
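
    // A minimal sketch of the pre-norm configuration via with_config; again a
    // shape-only check with the same input layout as the test above.
    #[test]
    fn test_transformer_encoder_layer_pre_norm() {
        let layer = TransformerEncoderLayer::with_config(64, 8, 256, 0.1, true);
        let input = Tensor::randn(&[2, 10, 64]);
        let output = layer.forward(&input);

        assert_eq!(output.dims(), &[2, 10, 64]);
    }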

    #[test]
    fn test_transformer_encoder() {
        let encoder = TransformerEncoder::new(64, 8, 256, 6, 0.1);
        let input = Tensor::randn(&[2, 10, 64]);
        let output = encoder.forward(&input);

        assert_eq!(output.dims(), &[2, 10, 64]);
    }
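
    // A minimal sketch for the decoder layer: target and encoder memory share
    // d_model = 64 (and, for simplicity, the same sequence length); shape-only
    // check through forward_with_memory with no masks.
    #[test]
    fn test_transformer_decoder_layer() {
        let layer = TransformerDecoderLayer::new(64, 8, 256, 0.1);
        let tgt = Tensor::randn(&[2, 10, 64]);
        let memory = Tensor::randn(&[2, 10, 64]);
        let output = layer.forward_with_memory(&tgt, &memory, None, None);

        assert_eq!(output.dims(), &[2, 10, 64]);
    }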

    #[test]
    fn test_positional_encoding() {
        let pe = PositionalEncoding::new(64, 512, 0.1);
        let input = Tensor::randn(&[2, 10, 64]);
        let output = pe.forward(&input);

        assert_eq!(output.dims(), &[2, 10, 64]);
    }
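
    // A minimal sketch for RoPE, assuming Tensor::randn accepts a 4-D shape the
    // same way it accepts 3-D shapes above: q and k are [batch, heads, seq, head_dim]
    // with head_dim equal to the embedding dim. Shapes must be preserved, and with
    // start_pos = 0 the first position is rotated by angle 0, i.e. left unchanged.
    #[test]
    fn test_rotary_embedding() {
        let rope = RotaryEmbedding::new(16, 128, 10000.0);
        let q = Tensor::randn(&[2, 4, 10, 16]);
        let k = Tensor::randn(&[2, 4, 10, 16]);
        let (q_rot, k_rot) = rope.apply(&q, &k, 0);

        assert_eq!(q_rot.dims(), &[2, 4, 10, 16]);
        assert_eq!(k_rot.dims(), &[2, 4, 10, 16]);
        // Position 0: cos = 1, sin = 0, so the first row of the first head is unchanged.
        assert_eq!(q_rot.data_f32()[..16], q.data_f32()[..16]);
    }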
}