ghostflow_nn/attention.rs

//! Attention mechanisms

use ghostflow_core::Tensor;
use crate::module::Module;
use crate::linear::Linear;
use crate::dropout::Dropout;

/// Scaled Dot-Product Attention
///
/// Attention(Q, K, V) = softmax(QK^T / sqrt(d_k)) * V
///
/// `query`, `key`, and `value` share the trailing dimension `d_k`; the optional
/// `mask` (1 = keep, 0 = mask out) is applied to the scores before the softmax.
/// Returns `(output, attention_weights)`.
pub fn scaled_dot_product_attention(
    query: &Tensor,
    key: &Tensor,
    value: &Tensor,
    mask: Option<&Tensor>,
    dropout_p: f32,
    training: bool,
) -> (Tensor, Tensor) {
    let d_k = query.dims()[query.ndim() - 1] as f32;
    let scale = 1.0 / d_k.sqrt();

    // QK^T
    let key_t = key.transpose(key.ndim() - 2, key.ndim() - 1).unwrap();
    let scores = query.matmul(&key_t).unwrap();
    let scaled_scores = scores.mul_scalar(scale);

    // Apply mask if provided
    let masked_scores = if let Some(m) = mask {
        apply_attention_mask(&scaled_scores, m)
    } else {
        scaled_scores
    };

    // Softmax
    let attn_weights = masked_scores.softmax(-1);

    // Apply dropout during training
    let attn_weights = if training && dropout_p > 0.0 {
        let dropout = Dropout::new(dropout_p);
        dropout.forward(&attn_weights)
    } else {
        attn_weights
    };

    // Attention output
    let output = attn_weights.matmul(value).unwrap();

    (output, attn_weights)
}

fn apply_attention_mask(scores: &Tensor, mask: &Tensor) -> Tensor {
    // mask: 1 = keep, 0 = mask out.
    // The mask is broadcast by cycling its elements over the flattened scores,
    // so its length must evenly divide the number of score elements and its
    // layout must match the trailing dimensions of `scores` (e.g. a
    // [seq_q, seq_k] mask applied to [batch, heads, seq_q, seq_k] scores).
    let mask_data = mask.data_f32();
    let scores_data = scores.data_f32();

    let result: Vec<f32> = scores_data.iter()
        .zip(mask_data.iter().cycle())
        .map(|(&s, &m)| if m > 0.5 { s } else { f32::NEG_INFINITY })
        .collect();

    Tensor::from_slice(&result, scores.dims()).unwrap()
}
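
// Illustration only: one way to build a causal (look-ahead) mask in the
// 1 = keep / 0 = mask-out convention expected by `apply_attention_mask`.
// This helper is not part of the original module; it is a minimal sketch that
// assumes only the `Tensor::from_slice` constructor already used above.
#[allow(dead_code)]
fn make_causal_mask(seq_len: usize) -> Tensor {
    // Lower-triangular [seq_len, seq_len] matrix: position i may attend to j <= i.
    let mut data = vec![0.0f32; seq_len * seq_len];
    for i in 0..seq_len {
        for j in 0..=i {
            data[i * seq_len + j] = 1.0;
        }
    }
    Tensor::from_slice(&data, &[seq_len, seq_len]).unwrap()
}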

/// Multi-Head Attention
pub struct MultiHeadAttention {
    embed_dim: usize,
    num_heads: usize,
    head_dim: usize,

    q_proj: Linear,
    k_proj: Linear,
    v_proj: Linear,
    out_proj: Linear,

    dropout_p: f32,
    training: bool,
}

impl MultiHeadAttention {
    pub fn new(embed_dim: usize, num_heads: usize, dropout: f32) -> Self {
        assert!(embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads");
        let head_dim = embed_dim / num_heads;

        MultiHeadAttention {
            embed_dim,
            num_heads,
            head_dim,
            q_proj: Linear::new(embed_dim, embed_dim),
            k_proj: Linear::new(embed_dim, embed_dim),
            v_proj: Linear::new(embed_dim, embed_dim),
            out_proj: Linear::new(embed_dim, embed_dim),
            dropout_p: dropout,
            training: true,
        }
    }

    /// Forward pass with optional key-value caching.
    ///
    /// `past_key` / `past_value` are the `k_cache` / `v_cache` returned by a
    /// previous call; they are stored in `[batch, kv_seq, embed_dim]` layout and
    /// concatenated with the current projections along the sequence dimension.
    pub fn forward_with_cache(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
        mask: Option<&Tensor>,
        past_key: Option<&Tensor>,
        past_value: Option<&Tensor>,
    ) -> (Tensor, Tensor, Tensor, Tensor) {
        let batch_size = query.dims()[0];
        let seq_len = query.dims()[1];

        // Project Q, K, V
        let q = self.q_proj.forward(query);
        let mut k = self.k_proj.forward(key);
        let mut v = self.v_proj.forward(value);

        // Concatenate with past key-value if provided (for incremental decoding)
        if let (Some(pk), Some(pv)) = (past_key, past_value) {
            k = concat_tensors(pk, &k, 1);
            v = concat_tensors(pv, &v, 1);
        }

        // Cache the full (past + current) projections before reshaping, so the
        // caller can pass them back as `past_key` / `past_value` on the next step.
        let k_cache = k.clone();
        let v_cache = v.clone();

        let kv_seq_len = k.dims()[1];

        // Reshape to [batch, heads, seq, head_dim]
        let q = self.reshape_for_attention(&q, batch_size, seq_len);
        let k = self.reshape_for_attention(&k, batch_size, kv_seq_len);
        let v = self.reshape_for_attention(&v, batch_size, kv_seq_len);

        // Scaled dot-product attention
        let (attn_output, attn_weights) = scaled_dot_product_attention(
            &q, &k, &v, mask, self.dropout_p, self.training
        );

        // Reshape back to [batch, seq, embed_dim]
        let attn_output = self.reshape_from_attention(&attn_output, batch_size, seq_len);

        // Output projection
        let output = self.out_proj.forward(&attn_output);

        (output, attn_weights, k_cache, v_cache)
    }

    fn reshape_for_attention(&self, x: &Tensor, batch_size: usize, seq_len: usize) -> Tensor {
        // [batch, seq, embed_dim] -> [batch, seq, heads, head_dim] -> [batch, heads, seq, head_dim]
        let reshaped = x.reshape(&[batch_size, seq_len, self.num_heads, self.head_dim]).unwrap();
        reshaped.transpose(1, 2).unwrap()
    }

    fn reshape_from_attention(&self, x: &Tensor, batch_size: usize, seq_len: usize) -> Tensor {
        // [batch, heads, seq, head_dim] -> [batch, seq, heads, head_dim] -> [batch, seq, embed_dim]
        let transposed = x.transpose(1, 2).unwrap();
        transposed.reshape(&[batch_size, seq_len, self.embed_dim]).unwrap()
    }
}

impl Module for MultiHeadAttention {
    fn forward(&self, input: &Tensor) -> Tensor {
        // Self-attention: Q = K = V = input
        let (output, _, _, _) = self.forward_with_cache(input, input, input, None, None, None);
        output
    }

    fn parameters(&self) -> Vec<Tensor> {
        let mut params = self.q_proj.parameters();
        params.extend(self.k_proj.parameters());
        params.extend(self.v_proj.parameters());
        params.extend(self.out_proj.parameters());
        params
    }

    fn train(&mut self) { self.training = true; }
    fn eval(&mut self) { self.training = false; }
    fn is_training(&self) -> bool { self.training }
}

/// Cross-Attention (for encoder-decoder models)
pub struct CrossAttention {
    mha: MultiHeadAttention,
}

impl CrossAttention {
    pub fn new(embed_dim: usize, num_heads: usize, dropout: f32) -> Self {
        CrossAttention {
            mha: MultiHeadAttention::new(embed_dim, num_heads, dropout),
        }
    }

    pub fn forward_cross(&self, query: &Tensor, key: &Tensor, value: &Tensor, mask: Option<&Tensor>) -> Tensor {
        let (output, _, _, _) = self.mha.forward_with_cache(query, key, value, mask, None, None);
        output
    }
}

impl Module for CrossAttention {
    fn forward(&self, input: &Tensor) -> Tensor {
        self.mha.forward(input)
    }

    fn parameters(&self) -> Vec<Tensor> {
        self.mha.parameters()
    }

    fn train(&mut self) { self.mha.train(); }
    fn eval(&mut self) { self.mha.eval(); }
    fn is_training(&self) -> bool { self.mha.is_training() }
}

/// Helper function to concatenate two tensors along a dimension.
/// Only `dim == 1` (the sequence dimension) is currently supported.
fn concat_tensors(a: &Tensor, b: &Tensor, dim: usize) -> Tensor {
    let a_dims = a.dims();
    let b_dims = b.dims();

    let mut new_dims = a_dims.to_vec();
    new_dims[dim] = a_dims[dim] + b_dims[dim];

    let a_data = a.data_f32();
    let b_data = b.data_f32();

    // Simple concatenation for dim = 1 (sequence dimension)
    if dim == 1 {
        let batch = a_dims[0];
        let a_seq = a_dims[1];
        let b_seq = b_dims[1];
        let rest: usize = a_dims[2..].iter().product();

        let mut result = Vec::with_capacity(batch * (a_seq + b_seq) * rest);

        for b_idx in 0..batch {
            // Copy this batch element's rows from a, then from b
            let a_start = b_idx * a_seq * rest;
            result.extend_from_slice(&a_data[a_start..a_start + a_seq * rest]);
            let b_start = b_idx * b_seq * rest;
            result.extend_from_slice(&b_data[b_start..b_start + b_seq * rest]);
        }

        Tensor::from_slice(&result, &new_dims).unwrap()
    } else {
        // Silently returning `b` would hide a wrong result; fail loudly instead.
        unimplemented!("concat_tensors only supports concatenation along dim 1");
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_scaled_dot_product_attention() {
        let q = Tensor::randn(&[2, 4, 8]); // [batch, seq, dim]
        let k = Tensor::randn(&[2, 4, 8]);
        let v = Tensor::randn(&[2, 4, 8]);

        let (output, weights) = scaled_dot_product_attention(&q, &k, &v, None, 0.0, false);

        assert_eq!(output.dims(), &[2, 4, 8]);
        assert_eq!(weights.dims(), &[2, 4, 4]);
    }
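
    // Added sketch (not in the original test suite): exercises the masking path
    // using the illustrative `make_causal_mask` helper above. It assumes the
    // softmax gives (near-)zero weight to `-inf` scores and that `data_f32()`
    // returns the weights in row-major order, as used elsewhere in this file.
    #[test]
    fn test_causal_mask_zeroes_future_positions() {
        let q = Tensor::randn(&[1, 4, 8]);
        let k = Tensor::randn(&[1, 4, 8]);
        let v = Tensor::randn(&[1, 4, 8]);
        let mask = make_causal_mask(4); // [4, 4], 1 = keep, 0 = mask out

        let (_, weights) = scaled_dot_product_attention(&q, &k, &v, Some(&mask), 0.0, false);

        // Weights are [1, 4, 4]; every position above the diagonal should be ~0.
        let w = weights.data_f32();
        for i in 0..4 {
            for j in (i + 1)..4 {
                assert!(w[i * 4 + j].abs() < 1e-6);
            }
        }
    }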

    #[test]
    fn test_multi_head_attention() {
        let mha = MultiHeadAttention::new(64, 8, 0.1);
        let input = Tensor::randn(&[2, 10, 64]); // [batch, seq, embed_dim]

        let output = mha.forward(&input);

        assert_eq!(output.dims(), &[2, 10, 64]);
    }
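
    // Added sketch (not in the original test suite): illustrates incremental
    // decoding with `forward_with_cache`. Shape checks only; relies on the cache
    // layout documented on that method ([batch, kv_seq, embed_dim]).
    #[test]
    fn test_forward_with_cache_accumulates_sequence() {
        let mut mha = MultiHeadAttention::new(32, 4, 0.0);
        mha.eval();

        // First step: a 3-token prefix, no past cache.
        let prefix = Tensor::randn(&[1, 3, 32]);
        let (out0, _, k_cache, v_cache) =
            mha.forward_with_cache(&prefix, &prefix, &prefix, None, None, None);
        assert_eq!(out0.dims(), &[1, 3, 32]);
        assert_eq!(k_cache.dims(), &[1, 3, 32]);

        // Second step: a single new token attends to the cached prefix as well.
        let step = Tensor::randn(&[1, 1, 32]);
        let (out1, weights, k_cache, _) =
            mha.forward_with_cache(&step, &step, &step, None, Some(&k_cache), Some(&v_cache));
        assert_eq!(out1.dims(), &[1, 1, 32]);
        assert_eq!(k_cache.dims(), &[1, 4, 32]); // 3 cached + 1 new token
        assert_eq!(weights.dims(), &[1, 4, 1, 4]); // [batch, heads, q_seq, kv_seq]
    }

    // Added sketch (not in the original test suite): cross-attention where the
    // query sequence (decoder side) is shorter than the key/value sequence
    // (encoder side); checks output shape only.
    #[test]
    fn test_cross_attention_shapes() {
        let cross = CrossAttention::new(64, 8, 0.0);
        let decoder_state = Tensor::randn(&[2, 5, 64]);  // queries
        let encoder_output = Tensor::randn(&[2, 7, 64]); // keys and values

        let output = cross.forward_cross(&decoder_state, &encoder_output, &encoder_output, None);

        assert_eq!(output.dims(), &[2, 5, 64]);
    }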
}