// ghostflow_nn/flash_attention.rs

//! Flash Attention
//!
//! Implements memory-efficient attention computation:
//! - Flash Attention algorithm for reduced memory usage
//! - Tiling and recomputation strategies
//! - Support for causal and non-causal attention
//! - Optimized for long sequences
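//!
//! A minimal usage sketch (the `ghostflow_nn::flash_attention` module path is
//! assumed from this file's location; `Tensor::randn` is used as in the tests
//! below):
//!
//! ```no_run
//! use ghostflow_core::Tensor;
//! use ghostflow_nn::flash_attention::{FlashAttention, FlashAttentionConfig};
//!
//! // Causal (GPT-style) attention; scale = 1/sqrt(d_k) with d_k = 64.
//! let attn = FlashAttention::new(FlashAttentionConfig::causal(0.125));
//!
//! let q = Tensor::randn(&[2, 128, 64]);
//! let k = Tensor::randn(&[2, 128, 64]);
//! let v = Tensor::randn(&[2, 128, 64]);
//!
//! let out = attn.forward(&q, &k, &v).expect("flash attention forward failed");
//! assert_eq!(out.dims(), &[2, 128, 64]);
//! ```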

use ghostflow_core::Tensor;
use std::cmp;

/// Flash Attention configuration
#[derive(Debug, Clone)]
pub struct FlashAttentionConfig {
    /// Block size for tiling (M)
    pub block_size_m: usize,
    /// Block size for tiling (N)
    pub block_size_n: usize,
    /// Whether to use causal masking
    pub causal: bool,
    /// Dropout probability (not yet applied in this implementation)
    pub dropout: f32,
    /// Scale factor (usually 1/sqrt(d_k))
    pub scale: f32,
}

impl Default for FlashAttentionConfig {
    fn default() -> Self {
        FlashAttentionConfig {
            block_size_m: 64,
            block_size_n: 64,
            causal: false,
            dropout: 0.0,
            scale: 1.0,
        }
    }
}

impl FlashAttentionConfig {
    /// Configuration for causal attention (GPT-style)
    pub fn causal(scale: f32) -> Self {
        FlashAttentionConfig {
            causal: true,
            scale,
            ..Default::default()
        }
    }

    /// Configuration for bidirectional attention (BERT-style)
    pub fn bidirectional(scale: f32) -> Self {
        FlashAttentionConfig {
            causal: false,
            scale,
            ..Default::default()
        }
    }

    /// Configuration for long sequences
    pub fn long_sequence(scale: f32) -> Self {
        FlashAttentionConfig {
            block_size_m: 128,
            block_size_n: 128,
            scale,
            ..Default::default()
        }
    }
}
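
// Note on `scale`: the conventional value is 1/sqrt(d_k). For example
// (illustrative d_k, not taken from elsewhere in this crate):
//
//     let d_k = 64usize;
//     let config = FlashAttentionConfig::causal(1.0 / (d_k as f32).sqrt());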

/// Flash Attention implementation
pub struct FlashAttention {
    config: FlashAttentionConfig,
}

impl FlashAttention {
    /// Create new Flash Attention
    pub fn new(config: FlashAttentionConfig) -> Self {
        FlashAttention { config }
    }

    /// Forward pass with Flash Attention algorithm
    pub fn forward(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
    ) -> Result<Tensor, String> {
        let q_dims = query.dims();
        let k_dims = key.dims();
        let v_dims = value.dims();

        // Validate dimensions
        if q_dims.len() != 3 || k_dims.len() != 3 || v_dims.len() != 3 {
            return Err("Expected 3D tensors [batch, seq_len, d_model]".to_string());
        }

        let batch_size = q_dims[0];
        let seq_len_q = q_dims[1];
        let seq_len_k = k_dims[1];
        let d_model = q_dims[2];

        if k_dims[2] != d_model || v_dims[2] != d_model {
            return Err("Key and Value must have the same d_model as Query".to_string());
        }

        if v_dims[1] != seq_len_k {
            return Err("Key and Value must have the same sequence length".to_string());
        }

        // Process each batch independently
        let mut batch_outputs = Vec::new();

        for b in 0..batch_size {
            let q_batch = self.extract_batch(query, b)?;
            let k_batch = self.extract_batch(key, b)?;
            let v_batch = self.extract_batch(value, b)?;

            let output = self.flash_attention_single_batch(&q_batch, &k_batch, &v_batch)?;
            batch_outputs.push(output);
        }

        // Concatenate batch outputs
        self.concatenate_batches(&batch_outputs, batch_size, seq_len_q, d_model)
    }

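    // The single-batch kernel below uses the standard online-softmax
    // recurrence. For each query row, a running max `m`, running normalizer
    // `l`, and un-normalized output accumulator `o` are updated per key block:
    //
    //     m_new = max(m_old, max_j s_j)
    //     l_new = l_old * exp(m_old - m_new) + sum_j exp(s_j - m_new)
    //     o_new = o_old * exp(m_old - m_new) + sum_j exp(s_j - m_new) * v_j
    //
    // and the final output is o / l. This is why `row_max` and `row_sum` are
    // carried across blocks and the output is only normalized at the end,
    // without ever materializing the full seq_len_q x seq_len_k score matrix.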
    /// Flash Attention for a single batch
    fn flash_attention_single_batch(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
    ) -> Result<Tensor, String> {
        let q_data = query.data_f32();
        let k_data = key.data_f32();
        let v_data = value.data_f32();

        let seq_len_q = query.dims()[0];
        let seq_len_k = key.dims()[0];
        let d_model = query.dims()[1];

        // Un-normalized output plus per-row running max and softmax sum
        let mut output = vec![0.0f32; seq_len_q * d_model];
        let mut row_max = vec![f32::NEG_INFINITY; seq_len_q];
        let mut row_sum = vec![0.0f32; seq_len_q];

        // Tile over query and key sequence lengths
        let block_m = self.config.block_size_m;
        let block_n = self.config.block_size_n;

        for i in (0..seq_len_q).step_by(block_m) {
            let end_i = cmp::min(i + block_m, seq_len_q);

            for j in (0..seq_len_k).step_by(block_n) {
                let end_j = cmp::min(j + block_n, seq_len_k);

                // With causal masking, skip key blocks that start past the
                // last query row of this block (they are fully masked)
                if self.config.causal && j >= end_i {
                    continue;
                }

                self.process_block(
                    &q_data, &k_data, &v_data,
                    &mut output, &mut row_max, &mut row_sum,
                    i, end_i, j, end_j,
                    seq_len_q, seq_len_k, d_model,
                )?;
            }
        }

        // Final normalization: divide accumulated values by the softmax sum
        for i in 0..seq_len_q {
            if row_sum[i] > 0.0 {
                for d in 0..d_model {
                    output[i * d_model + d] /= row_sum[i];
                }
            }
        }

        Tensor::from_slice(&output, &[seq_len_q, d_model])
            .map_err(|e| format!("Failed to create output tensor: {:?}", e))
    }

    /// Process a single attention block using the online-softmax update
    fn process_block(
        &self,
        q_data: &[f32],
        k_data: &[f32],
        v_data: &[f32],
        output: &mut [f32],
        row_max: &mut [f32],
        row_sum: &mut [f32],
        i_start: usize,
        i_end: usize,
        j_start: usize,
        j_end: usize,
        _seq_len_q: usize,
        seq_len_k: usize,
        d_model: usize,
    ) -> Result<(), String> {
        // Compute attention scores for this block
        for i in i_start..i_end {
            let mut block_max = f32::NEG_INFINITY;
            let mut scores = Vec::with_capacity(j_end - j_start);

            // Compute scores Q_i @ K_j^T
            for j in j_start..j_end {
                // Apply causal mask: query i may not attend to keys after it
                if self.config.causal && j > i {
                    scores.push(f32::NEG_INFINITY);
                    continue;
                }

                let mut score = 0.0;
                for d in 0..d_model {
                    score += q_data[i * d_model + d] * k_data[j * d_model + d];
                }
                score *= self.config.scale;

                scores.push(score);
                block_max = block_max.max(score);
            }

            // Update the running row maximum
            let old_max = row_max[i];
            let new_max = old_max.max(block_max);
            row_max[i] = new_max;

            // Exponentiate scores against the new maximum and accumulate the block sum
            let mut block_sum = 0.0;
            for score in &mut scores {
                if *score != f32::NEG_INFINITY {
                    *score = (*score - new_max).exp();
                    block_sum += *score;
                } else {
                    *score = 0.0;
                }
            }

            // Rescale previously accumulated quantities to the new maximum,
            // then fold in this block's contribution
            let correction = (old_max - new_max).exp();
            row_sum[i] = row_sum[i] * correction + block_sum;

            // Update the (un-normalized) output with this block's weighted values
            for (idx, &score) in scores.iter().enumerate() {
                let j = j_start + idx;
                if j < seq_len_k {
                    for d in 0..d_model {
                        output[i * d_model + d] = output[i * d_model + d] * correction
                            + score * v_data[j * d_model + d];
                    }
                }
            }
        }

        Ok(())
    }

    /// Extract single batch from tensor
    fn extract_batch(&self, tensor: &Tensor, batch_idx: usize) -> Result<Tensor, String> {
        let data = tensor.data_f32();
        let dims = tensor.dims();
        let seq_len = dims[1];
        let d_model = dims[2];

        let start = batch_idx * seq_len * d_model;
        let end = start + seq_len * d_model;

        Tensor::from_slice(&data[start..end], &[seq_len, d_model])
            .map_err(|e| format!("Failed to extract batch: {:?}", e))
    }

    /// Concatenate batch outputs
    fn concatenate_batches(
        &self,
        batches: &[Tensor],
        batch_size: usize,
        seq_len: usize,
        d_model: usize,
    ) -> Result<Tensor, String> {
        let mut result = Vec::with_capacity(batch_size * seq_len * d_model);

        for batch in batches {
            result.extend_from_slice(&batch.data_f32());
        }

        Tensor::from_slice(&result, &[batch_size, seq_len, d_model])
            .map_err(|e| format!("Failed to concatenate batches: {:?}", e))
    }

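    // Back-of-the-envelope check for the estimate below: with the default
    // 64 x 64 blocks and seq_len = 1024, the ratio is
    // (64 * 64) / (1024 * 1024) = 4096 / 1_048_576 ≈ 0.0039, i.e. the score
    // tile held in memory is roughly 0.4% of the full attention matrix.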
    /// Estimate memory usage compared to standard attention
    pub fn memory_usage_ratio(&self, seq_len: usize, _d_model: usize) -> f32 {
        // Standard attention materializes an O(seq_len^2) score matrix
        let standard_memory = seq_len * seq_len;

        // Flash attention only keeps an O(block_size_m * block_size_n) tile
        let flash_memory = self.config.block_size_m * self.config.block_size_n;

        flash_memory as f32 / standard_memory as f32
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_flash_attention_config() {
        let config = FlashAttentionConfig::default();
        assert_eq!(config.block_size_m, 64);
        assert!(!config.causal);

        let causal = FlashAttentionConfig::causal(0.125);
        assert!(causal.causal);
        assert_eq!(causal.scale, 0.125);
    }

    #[test]
    fn test_flash_attention_forward() {
        let config = FlashAttentionConfig::default();
        let flash_attn = FlashAttention::new(config);

        let batch_size = 2;
        let seq_len = 8;
        let d_model = 16;

        let query = Tensor::randn(&[batch_size, seq_len, d_model]);
        let key = Tensor::randn(&[batch_size, seq_len, d_model]);
        let value = Tensor::randn(&[batch_size, seq_len, d_model]);

        let output = flash_attn.forward(&query, &key, &value).unwrap();
        assert_eq!(output.dims(), &[batch_size, seq_len, d_model]);
    }

    #[test]
    fn test_causal_attention() {
        let config = FlashAttentionConfig::causal(1.0);
        let flash_attn = FlashAttention::new(config);

        let query = Tensor::randn(&[1, 4, 8]);
        let key = Tensor::randn(&[1, 4, 8]);
        let value = Tensor::randn(&[1, 4, 8]);

        let output = flash_attn.forward(&query, &key, &value).unwrap();
        assert_eq!(output.dims(), &[1, 4, 8]);
    }

    #[test]
    fn test_memory_usage_ratio() {
        let config = FlashAttentionConfig::default();
        let flash_attn = FlashAttention::new(config);

        let ratio = flash_attn.memory_usage_ratio(1024, 512);
        assert!(ratio < 1.0); // Flash attention should use less memory
        assert!(ratio > 0.0);
    }
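
    // Extra consistency sketch: for a small non-causal case, the streaming
    // (online-softmax) result should match a straightforward O(n^2) softmax
    // reference built with the same Tensor helpers used above (randn,
    // data_f32). Inputs are random, so the tolerance is deliberately loose.
    #[test]
    fn test_matches_naive_attention() {
        let config = FlashAttentionConfig {
            scale: 0.25,
            ..Default::default()
        };
        let flash_attn = FlashAttention::new(config);

        let seq_len = 5;
        let d_model = 4;
        let query = Tensor::randn(&[1, seq_len, d_model]);
        let key = Tensor::randn(&[1, seq_len, d_model]);
        let value = Tensor::randn(&[1, seq_len, d_model]);

        let flash_out = flash_attn.forward(&query, &key, &value).unwrap();
        let flash_data = flash_out.data_f32();

        let q = query.data_f32();
        let k = key.data_f32();
        let v = value.data_f32();

        for i in 0..seq_len {
            // Full (un-tiled) scaled dot-product scores for query row i
            let mut scores = vec![0.0f32; seq_len];
            for j in 0..seq_len {
                for d in 0..d_model {
                    scores[j] += q[i * d_model + d] * k[j * d_model + d];
                }
                scores[j] *= 0.25;
            }

            // Ordinary softmax over the whole row
            let max = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let sum: f32 = scores.iter().map(|s| (s - max).exp()).sum();

            for d in 0..d_model {
                let mut expected = 0.0f32;
                for j in 0..seq_len {
                    expected += (scores[j] - max).exp() / sum * v[j * d_model + d];
                }
                assert!((flash_data[i * d_model + d] - expected).abs() < 1e-4);
            }
        }
    }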
}