use ghostflow_core::Tensor;
use std::cmp;

/// Configuration for tiled (Flash-style) attention.
#[derive(Debug, Clone)]
pub struct FlashAttentionConfig {
    /// Tile size along the query (row) dimension.
    pub block_size_m: usize,
    /// Tile size along the key/value (column) dimension.
    pub block_size_n: usize,
    /// Apply a causal (lower-triangular) attention mask.
    pub causal: bool,
    /// Dropout probability (not applied in the current forward pass).
    pub dropout: f32,
    /// Scaling factor for attention scores, typically 1 / sqrt(d_head).
    pub scale: f32,
}

impl Default for FlashAttentionConfig {
    fn default() -> Self {
        FlashAttentionConfig {
            block_size_m: 64,
            block_size_n: 64,
            causal: false,
            dropout: 0.0,
            scale: 1.0,
        }
    }
}

impl FlashAttentionConfig {
    /// Causal (autoregressive) attention with the given score scale.
    pub fn causal(scale: f32) -> Self {
        FlashAttentionConfig {
            causal: true,
            scale,
            ..Default::default()
        }
    }

    /// Bidirectional (unmasked) attention with the given score scale.
    pub fn bidirectional(scale: f32) -> Self {
        FlashAttentionConfig {
            causal: false,
            scale,
            ..Default::default()
        }
    }

    /// Larger tiles for long sequences, trading per-tile working-set size for fewer blocks.
    pub fn long_sequence(scale: f32) -> Self {
        FlashAttentionConfig {
            block_size_m: 128,
            block_size_n: 128,
            scale,
            ..Default::default()
        }
    }
}

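/// Memory-efficient attention that processes queries and keys/values in tiles,
/// maintaining running row maxima and normalizers (online softmax) instead of
/// materializing the full `seq_len x seq_len` score matrix.
///
/// Minimal usage sketch (assumes `Tensor::randn` from `ghostflow_core`, as used
/// in the tests below):
///
/// ```ignore
/// let attn = FlashAttention::new(FlashAttentionConfig::causal(0.125));
/// let q = Tensor::randn(&[1, 8, 16]);
/// let k = Tensor::randn(&[1, 8, 16]);
/// let v = Tensor::randn(&[1, 8, 16]);
/// let out = attn.forward(&q, &k, &v).unwrap();
/// ```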
pub struct FlashAttention {
    config: FlashAttentionConfig,
}

impl FlashAttention {
    pub fn new(config: FlashAttentionConfig) -> Self {
        FlashAttention { config }
    }

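    /// Batched forward pass.
    ///
    /// `query`, `key`, and `value` are `[batch, seq_len, d_model]` tensors;
    /// the result has the same shape as `query`. Each batch element is
    /// processed independently by the tiled single-batch kernel.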
    pub fn forward(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
    ) -> Result<Tensor, String> {
        let q_dims = query.dims();
        let k_dims = key.dims();
        let v_dims = value.dims();

        if q_dims.len() != 3 || k_dims.len() != 3 || v_dims.len() != 3 {
            return Err("Expected 3D tensors [batch, seq_len, d_model]".to_string());
        }

        let batch_size = q_dims[0];
        let seq_len_q = q_dims[1];
        let seq_len_k = k_dims[1];
        let d_model = q_dims[2];

        if k_dims[2] != d_model || v_dims[2] != d_model {
            return Err("Key and Value must have same d_model as Query".to_string());
        }
        if v_dims[1] != seq_len_k {
            return Err("Key and Value must have the same sequence length".to_string());
        }

        let mut batch_outputs = Vec::with_capacity(batch_size);

        for b in 0..batch_size {
            let q_batch = self.extract_batch(query, b)?;
            let k_batch = self.extract_batch(key, b)?;
            let v_batch = self.extract_batch(value, b)?;

            let output = self.flash_attention_single_batch(&q_batch, &k_batch, &v_batch)?;
            batch_outputs.push(output);
        }

        self.concatenate_batches(&batch_outputs, batch_size, seq_len_q, d_model)
    }

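    /// Tiled attention for a single `[seq_len, d_model]` batch element.
    ///
    /// Iterates over query-row tiles of `block_size_m` and key/value-column
    /// tiles of `block_size_n`, accumulating unnormalized outputs together
    /// with per-row running maxima and softmax denominators, then divides by
    /// the denominators once all tiles have been processed.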
    fn flash_attention_single_batch(
        &self,
        query: &Tensor,
        key: &Tensor,
        value: &Tensor,
    ) -> Result<Tensor, String> {
        let q_data = query.data_f32();
        let k_data = key.data_f32();
        let v_data = value.data_f32();

        let seq_len_q = query.dims()[0];
        let seq_len_k = key.dims()[0];
        let d_model = query.dims()[1];

        // Unnormalized output plus per-row running max and softmax denominator.
        let mut output = vec![0.0f32; seq_len_q * d_model];
        let mut row_max = vec![f32::NEG_INFINITY; seq_len_q];
        let mut row_sum = vec![0.0f32; seq_len_q];

        let block_m = self.config.block_size_m;
        let block_n = self.config.block_size_n;

        for i in (0..seq_len_q).step_by(block_m) {
            let end_i = cmp::min(i + block_m, seq_len_q);

            for j in (0..seq_len_k).step_by(block_n) {
                let end_j = cmp::min(j + block_n, seq_len_k);

                // Skip key/value tiles that lie entirely above the causal diagonal.
                if self.config.causal && j >= end_i {
                    continue;
                }

                self.process_block(
                    &q_data, &k_data, &v_data,
                    &mut output, &mut row_max, &mut row_sum,
                    i, end_i, j, end_j,
                    seq_len_q, seq_len_k, d_model,
                )?;
            }
        }

        // Final softmax normalization by the accumulated denominators.
        for i in 0..seq_len_q {
            if row_sum[i] > 0.0 {
                for d in 0..d_model {
                    output[i * d_model + d] /= row_sum[i];
                }
            }
        }

        Tensor::from_slice(&output, &[seq_len_q, d_model])
            .map_err(|e| format!("Failed to create output tensor: {:?}", e))
    }

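    /// Processes one tile of query rows `i_start..i_end` against key columns
    /// `j_start..j_end` with the online softmax update: for each row, with old
    /// running max `m_old` and new max `m_new = max(m_old, m_block)`, the
    /// previously accumulated output and denominator are rescaled by
    /// `exp(m_old - m_new)` before this tile's `exp(score - m_new)`-weighted
    /// value contributions are added.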
    fn process_block(
        &self,
        q_data: &[f32],
        k_data: &[f32],
        v_data: &[f32],
        output: &mut [f32],
        row_max: &mut [f32],
        row_sum: &mut [f32],
        i_start: usize,
        i_end: usize,
        j_start: usize,
        j_end: usize,
        _seq_len_q: usize,
        seq_len_k: usize,
        d_model: usize,
    ) -> Result<(), String> {
        for i in i_start..i_end {
            let mut block_max = f32::NEG_INFINITY;
            let mut scores = Vec::with_capacity(j_end - j_start);

            // Scaled dot-product scores for this row against the tile's columns.
            for j in j_start..j_end {
                if self.config.causal && j > i {
                    scores.push(f32::NEG_INFINITY);
                    continue;
                }

                let mut score = 0.0;
                for d in 0..d_model {
                    score += q_data[i * d_model + d] * k_data[j * d_model + d];
                }
                score *= self.config.scale;

                scores.push(score);
                block_max = block_max.max(score);
            }

            let old_max = row_max[i];
            let new_max = old_max.max(block_max);
            row_max[i] = new_max;

            // Exponentiate against the updated running max; masked entries are
            // zeroed so they contribute nothing.
            let mut block_sum = 0.0;
            for score in &mut scores {
                if *score != f32::NEG_INFINITY {
                    *score = (*score - new_max).exp();
                    block_sum += *score;
                } else {
                    *score = 0.0;
                }
            }

            // Rescale the previously accumulated denominator and output once per
            // tile (not once per column, which would over-apply the correction).
            let correction = (old_max - new_max).exp();
            row_sum[i] = row_sum[i] * correction + block_sum;

            for d in 0..d_model {
                output[i * d_model + d] *= correction;
            }

            // Add this tile's exp-weighted value contributions.
            for (idx, &score) in scores.iter().enumerate() {
                let j = j_start + idx;
                if j < seq_len_k {
                    for d in 0..d_model {
                        output[i * d_model + d] += score * v_data[j * d_model + d];
                    }
                }
            }
        }

        Ok(())
    }

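    /// Copies batch element `batch_idx` out of a `[batch, seq_len, d_model]`
    /// tensor into a standalone `[seq_len, d_model]` tensor.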
    fn extract_batch(&self, tensor: &Tensor, batch_idx: usize) -> Result<Tensor, String> {
        let data = tensor.data_f32();
        let dims = tensor.dims();
        let seq_len = dims[1];
        let d_model = dims[2];

        let start = batch_idx * seq_len * d_model;
        let end = start + seq_len * d_model;

        Tensor::from_slice(&data[start..end], &[seq_len, d_model])
            .map_err(|e| format!("Failed to extract batch: {:?}", e))
    }

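    /// Stacks per-batch `[seq_len, d_model]` outputs back into a single
    /// `[batch, seq_len, d_model]` tensor.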
    fn concatenate_batches(
        &self,
        batches: &[Tensor],
        batch_size: usize,
        seq_len: usize,
        d_model: usize,
    ) -> Result<Tensor, String> {
        let mut result = Vec::with_capacity(batch_size * seq_len * d_model);

        for batch in batches {
            result.extend_from_slice(&batch.data_f32());
        }

        Tensor::from_slice(&result, &[batch_size, seq_len, d_model])
            .map_err(|e| format!("Failed to concatenate batches: {:?}", e))
    }

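    /// Rough ratio of one tile's working set (`block_size_m * block_size_n`
    /// scores) to the full `seq_len * seq_len` score matrix that standard
    /// attention would materialize; values well below 1.0 indicate memory
    /// savings.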
    pub fn memory_usage_ratio(&self, seq_len: usize, _d_model: usize) -> f32 {
        let standard_memory = seq_len * seq_len;

        let flash_memory = self.config.block_size_m * self.config.block_size_n;

        flash_memory as f32 / standard_memory as f32
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_flash_attention_config() {
        let config = FlashAttentionConfig::default();
        assert_eq!(config.block_size_m, 64);
        assert!(!config.causal);

        let causal = FlashAttentionConfig::causal(0.125);
        assert!(causal.causal);
        assert_eq!(causal.scale, 0.125);
    }

    #[test]
    fn test_flash_attention_forward() {
        let config = FlashAttentionConfig::default();
        let flash_attn = FlashAttention::new(config);

        let batch_size = 2;
        let seq_len = 8;
        let d_model = 16;

        let query = Tensor::randn(&[batch_size, seq_len, d_model]);
        let key = Tensor::randn(&[batch_size, seq_len, d_model]);
        let value = Tensor::randn(&[batch_size, seq_len, d_model]);

        let output = flash_attn.forward(&query, &key, &value).unwrap();
        assert_eq!(output.dims(), &[batch_size, seq_len, d_model]);
    }

    #[test]
    fn test_causal_attention() {
        let config = FlashAttentionConfig::causal(1.0);
        let flash_attn = FlashAttention::new(config);

        let query = Tensor::randn(&[1, 4, 8]);
        let key = Tensor::randn(&[1, 4, 8]);
        let value = Tensor::randn(&[1, 4, 8]);

        let output = flash_attn.forward(&query, &key, &value).unwrap();
        assert_eq!(output.dims(), &[1, 4, 8]);
    }

    #[test]
    fn test_memory_usage_ratio() {
        let config = FlashAttentionConfig::default();
        let flash_attn = FlashAttention::new(config);

        let ratio = flash_attn.memory_usage_ratio(1024, 512);
        assert!(ratio > 0.0);
        assert!(ratio < 1.0);
    }
}