use candle_core::{DType, Result, Tensor, D};

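/// Configuration for the block-wise (tiled) flash attention forward pass.
///
/// * `block_size` - number of query/key positions processed per tile (default: 64).
/// * `causal` - when `true`, applies a causal (lower-triangular) mask so each query
///   position only attends to key positions at or before it.
/// * `softmax_eps` - small constant added to the softmax denominator to avoid
///   division by zero.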
#[derive(Debug, Clone, Copy)]
pub struct FlashAttentionConfig {
    pub block_size: usize,
    pub causal: bool,
    pub softmax_eps: f64,
}

impl Default for FlashAttentionConfig {
    fn default() -> Self {
        Self {
            block_size: 64,
            causal: false,
            softmax_eps: 1e-6,
        }
    }
}

impl FlashAttentionConfig {
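    /// Creates a configuration with the given tile size and default values for the
    /// remaining fields.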
    pub fn with_block_size(block_size: usize) -> Self {
        Self {
            block_size,
            ..Default::default()
        }
    }

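    /// Enables or disables causal masking on this configuration.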
    #[allow(dead_code)]
    pub fn with_causal(mut self, causal: bool) -> Self {
        self.causal = causal;
        self
    }
}

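/// Block-wise attention in the style of FlashAttention: queries and keys are split into
/// tiles and combined with an online (running) softmax, so the full attention matrix
/// for a query tile against the whole key sequence never has to be materialized at once.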
#[derive(Debug, Clone)]
pub struct FlashAttention {
    config: FlashAttentionConfig,
    scale: f64,
}

impl FlashAttention {
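    /// Builds a `FlashAttention` layer for the given head dimension; attention scores
    /// are scaled by `1 / sqrt(dim_head)`.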
    pub fn new(dim_head: usize, config: FlashAttentionConfig) -> Self {
        let scale = 1.0 / (dim_head as f64).sqrt();
        Self { config, scale }
    }

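    /// Convenience constructor that uses the default configuration.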
    pub fn with_dim_head(dim_head: usize) -> Self {
        Self::new(dim_head, FlashAttentionConfig::default())
    }

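    /// Computes attention for `q`, `k`, `v` of shape `(batch, heads, seq, dim_head)`.
    ///
    /// Sequences that fit inside a single tile fall back to standard softmax attention.
    /// Longer sequences are processed in `block_size` tiles: for each query tile the
    /// key/value tiles are streamed through, and the running maximum `m`, normalizer
    /// `l`, and unnormalized output `o` are updated with the online softmax recurrence
    /// (see `online_softmax_update`). The output is normalized by `l` at the end and
    /// returned in the input dtype.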
    pub fn forward(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
        let (batch, heads, seq_q, dim_head) = q.dims4()?;
        let (_, _, seq_k, _) = k.dims4()?;

        // Short sequences fit in a single tile, so the plain softmax path is cheaper.
        let use_standard = seq_q <= self.config.block_size && seq_k <= self.config.block_size;
        if use_standard {
            return self.standard_attention(q, k, v);
        }

        // Accumulate in f32 regardless of the input dtype for numerical stability.
        let in_dtype = q.dtype();
        let q = q.to_dtype(DType::F32)?;
        let k = k.to_dtype(DType::F32)?;
        let v = v.to_dtype(DType::F32)?;

        let block_size = self.config.block_size;
        let num_q_blocks = seq_q.div_ceil(block_size);
        let num_k_blocks = seq_k.div_ceil(block_size);

        let device = q.device();
        let neg_inf = f32::NEG_INFINITY;

        let mut output_blocks: Vec<Tensor> = Vec::with_capacity(num_q_blocks);

        for q_block_idx in 0..num_q_blocks {
            let q_start = q_block_idx * block_size;
            let q_end = (q_start + block_size).min(seq_q);
            let q_len = q_end - q_start;

            let q_block = q.narrow(2, q_start, q_len)?;

            // Running statistics for the online softmax: row-wise maximum `m`,
            // normalizer `l`, and unnormalized output `o`.
            let mut m = Tensor::full(neg_inf, (batch, heads, q_len), device)?;
            let mut l = Tensor::zeros((batch, heads, q_len), DType::F32, device)?;
            let mut o = Tensor::zeros((batch, heads, q_len, dim_head), DType::F32, device)?;

            for k_block_idx in 0..num_k_blocks {
                let k_start = k_block_idx * block_size;
                let k_end = (k_start + block_size).min(seq_k);
                let k_len = k_end - k_start;

                // With a causal mask, key blocks that lie entirely after the last
                // query position of this tile contribute nothing and can be skipped.
                if self.config.causal && k_start >= q_end {
                    continue;
                }

                let k_block = k.narrow(2, k_start, k_len)?;
                let v_block = v.narrow(2, k_start, k_len)?;

                let k_t = k_block.transpose(D::Minus2, D::Minus1)?;
                let scores = (q_block.matmul(&k_t)? * self.scale)?;

                let scores = if self.config.causal {
                    self.apply_causal_mask(&scores, q_start, k_start)?
                } else {
                    scores
                };

                let (m_new, l_new, o_new) =
                    self.online_softmax_update(&m, &l, &o, &scores, &v_block)?;

                m = m_new;
                l = l_new;
                o = o_new;
            }

            // Normalize the accumulated output by the softmax denominator.
            let l_expanded = l.unsqueeze(D::Minus1)?;
            let l_safe = (l_expanded + self.config.softmax_eps)?;
            let block_output = o.broadcast_div(&l_safe)?;

            output_blocks.push(block_output);
        }

        let output = Tensor::cat(&output_blocks, 2)?;
        output.to_dtype(in_dtype)
    }

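    /// Naive softmax attention over the full sequence, used when both sequences fit in
    /// a single tile.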
    fn standard_attention(&self, q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
        let in_dtype = q.dtype();
        let q = q.to_dtype(DType::F32)?;
        let k = k.to_dtype(DType::F32)?;
        let v = v.to_dtype(DType::F32)?;

        let k_t = k.transpose(D::Minus2, D::Minus1)?.contiguous()?;
        let attn = (q.matmul(&k_t)? * self.scale)?;
        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
        let out = attn.matmul(&v)?;
        out.to_dtype(in_dtype)
    }

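    /// Folds one block of scores into the running softmax statistics.
    ///
    /// Given the current row maximum `m`, normalizer `l`, and unnormalized output `o`,
    /// plus a new score block `S` with its value block `V`, the update is:
    ///
    /// ```text
    /// m_new = max(m, rowmax(S))
    /// p     = exp(S - m_new)
    /// l_new = exp(m - m_new) * l + rowsum(p)
    /// o_new = exp(m - m_new) * o + p @ V
    /// ```
    ///
    /// Because `p` is computed against `m_new` directly, only the previously
    /// accumulated statistics need to be rescaled.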
    fn online_softmax_update(
        &self,
        m: &Tensor,
        l: &Tensor,
        o: &Tensor,
        scores: &Tensor,
        v_block: &Tensor,
    ) -> Result<(Tensor, Tensor, Tensor)> {
        // Row-wise maximum of the incoming score block, and the updated running maximum.
        let m_block = scores.max(D::Minus1)?;
        let m_new = m.maximum(&m_block)?;

        // Rescale factor for the statistics accumulated so far.
        let m_diff_old = m.broadcast_sub(&m_new)?;
        let rescale_old = m_diff_old.exp()?;

        // Probabilities of the new block, already referenced to the updated maximum,
        // so no separate rescale factor is needed for the new contribution.
        let m_new_expanded = m_new.unsqueeze(D::Minus1)?;
        let p_block = scores.broadcast_sub(&m_new_expanded)?.exp()?;

        let l_block = p_block.sum(D::Minus1)?;
        let l_new = (l.mul(&rescale_old)? + l_block)?;

        let rescale_old_expanded = rescale_old.unsqueeze(D::Minus1)?;
        let o_rescaled = o.broadcast_mul(&rescale_old_expanded)?;
        let pv = p_block.matmul(v_block)?;
        let o_new = (o_rescaled + pv)?;

        Ok((m_new, l_new, o_new))
    }

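    /// Adds `-inf` to every score whose key position lies after its query position,
    /// taking the absolute offsets of the current query and key tiles into account.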
    fn apply_causal_mask(&self, scores: &Tensor, q_start: usize, k_start: usize) -> Result<Tensor> {
        let (batch, heads, q_len, k_len) = scores.dims4()?;
        let device = scores.device();

        let mut mask_data = vec![0.0f32; q_len * k_len];
        let neg_inf = f32::NEG_INFINITY;

        for i in 0..q_len {
            let q_pos = q_start + i;
            for j in 0..k_len {
                let k_pos = k_start + j;
                if k_pos > q_pos {
                    mask_data[i * k_len + j] = neg_inf;
                }
            }
        }

        let mask = Tensor::from_vec(mask_data, (1, 1, q_len, k_len), device)?;
        let mask = mask.broadcast_as((batch, heads, q_len, k_len))?;
        scores.add(&mask)
    }
}

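/// Runs flash attention over `q`, `k`, `v` of shape `(batch, heads, seq, dim_head)` using
/// the default [`FlashAttentionConfig`].
///
/// A minimal usage sketch (marked `ignore` because the surrounding module path is
/// assumed here, so it is not compiled as a doctest):
///
/// ```ignore
/// use candle_core::{DType, Device, Tensor};
///
/// let device = Device::Cpu;
/// // (batch, heads, seq, dim_head) = (1, 2, 128, 32)
/// let q = Tensor::zeros((1, 2, 128, 32), DType::F32, &device)?;
/// let k = Tensor::zeros((1, 2, 128, 32), DType::F32, &device)?;
/// let v = Tensor::zeros((1, 2, 128, 32), DType::F32, &device)?;
///
/// let out = flash_attention(&q, &k, &v, 32)?;
/// assert_eq!(out.dims(), &[1, 2, 128, 32]);
/// # Ok::<(), candle_core::Error>(())
/// ```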
pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, dim_head: usize) -> Result<Tensor> {
    let flash = FlashAttention::with_dim_head(dim_head);
    flash.forward(q, k, v)
}

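/// Same as [`flash_attention`], but with an explicit [`FlashAttentionConfig`]
/// (e.g. a custom tile size or causal masking).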
pub fn flash_attention_with_config(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    dim_head: usize,
    config: FlashAttentionConfig,
) -> Result<Tensor> {
    let flash = FlashAttention::new(dim_head, config);
    flash.forward(q, k, v)
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use candle_core::Device;

    fn create_test_tensors(
        batch: usize,
        heads: usize,
        seq_q: usize,
        seq_k: usize,
        dim_head: usize,
        device: &Device,
    ) -> Result<(Tensor, Tensor, Tensor)> {
        let q_size = batch * heads * seq_q * dim_head;
        let k_size = batch * heads * seq_k * dim_head;

        let q_data: Vec<f32> = (0..q_size).map(|i| (i as f32 * 0.01).sin()).collect();
        let k_data: Vec<f32> = (0..k_size).map(|i| (i as f32 * 0.02).cos()).collect();
        let v_data: Vec<f32> = (0..k_size).map(|i| (i as f32 * 0.03).sin()).collect();

        let q = Tensor::from_vec(q_data, (batch, heads, seq_q, dim_head), device)?;
        let k = Tensor::from_vec(k_data, (batch, heads, seq_k, dim_head), device)?;
        let v = Tensor::from_vec(v_data, (batch, heads, seq_k, dim_head), device)?;

        Ok((q, k, v))
    }

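    /// Reference implementation: plain softmax attention used to check the tiled path.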
    fn standard_attention(q: &Tensor, k: &Tensor, v: &Tensor, scale: f64) -> Result<Tensor> {
        let q = q.to_dtype(DType::F32)?;
        let k = k.to_dtype(DType::F32)?;
        let v = v.to_dtype(DType::F32)?;

        let k_t = k.transpose(D::Minus2, D::Minus1)?;
        let attn = (q.matmul(&k_t)? * scale)?;
        let attn = candle_nn::ops::softmax_last_dim(&attn)?;
        attn.matmul(&v)
    }

    #[test]
    fn test_flash_attention_small_sequence() -> Result<()> {
        let device = Device::Cpu;
        let batch = 2;
        let heads = 4;
        let seq_len = 32;
        let dim_head = 64;

        let (q, k, v) = create_test_tensors(batch, heads, seq_len, seq_len, dim_head, &device)?;

        let flash = FlashAttention::with_dim_head(dim_head);
        let flash_out = flash.forward(&q, &k, &v)?;

        let scale = 1.0 / (dim_head as f64).sqrt();
        let std_out = standard_attention(&q, &k, &v, scale)?;

        let flash_vec: Vec<f32> = flash_out.to_dtype(DType::F32)?.flatten_all()?.to_vec1()?;
        let std_vec: Vec<f32> = std_out.flatten_all()?.to_vec1()?;

        assert_eq!(flash_vec.len(), std_vec.len());
        for (f, s) in flash_vec.iter().zip(std_vec.iter()) {
            assert_relative_eq!(f, s, epsilon = 1e-4);
        }

        Ok(())
    }

    #[test]
    fn test_flash_attention_large_sequence() -> Result<()> {
        let device = Device::Cpu;
        let batch = 1;
        let heads = 2;
        let seq_len = 128;
        let dim_head = 32;

        let (q, k, v) = create_test_tensors(batch, heads, seq_len, seq_len, dim_head, &device)?;

        let flash = FlashAttention::with_dim_head(dim_head);
        let flash_out = flash.forward(&q, &k, &v)?;

        let scale = 1.0 / (dim_head as f64).sqrt();
        let std_out = standard_attention(&q, &k, &v, scale)?;

        let flash_vec: Vec<f32> = flash_out.to_dtype(DType::F32)?.flatten_all()?.to_vec1()?;
        let std_vec: Vec<f32> = std_out.flatten_all()?.to_vec1()?;

        assert_eq!(flash_vec.len(), std_vec.len());
        for (f, s) in flash_vec.iter().zip(std_vec.iter()) {
            assert_relative_eq!(f, s, epsilon = 1e-3);
        }

        Ok(())
    }

    #[test]
    fn test_flash_attention_asymmetric_sequences() -> Result<()> {
        let device = Device::Cpu;
        let batch = 1;
        let heads = 2;
        let seq_q = 100;
        let seq_k = 150;
        let dim_head = 32;

        let (q, k, v) = create_test_tensors(batch, heads, seq_q, seq_k, dim_head, &device)?;

        let flash = FlashAttention::with_dim_head(dim_head);
        let flash_out = flash.forward(&q, &k, &v)?;

        let scale = 1.0 / (dim_head as f64).sqrt();
        let std_out = standard_attention(&q, &k, &v, scale)?;

        let flash_vec: Vec<f32> = flash_out.to_dtype(DType::F32)?.flatten_all()?.to_vec1()?;
        let std_vec: Vec<f32> = std_out.flatten_all()?.to_vec1()?;

        assert_eq!(flash_vec.len(), std_vec.len());
        for (f, s) in flash_vec.iter().zip(std_vec.iter()) {
            assert_relative_eq!(f, s, epsilon = 1e-3);
        }

        Ok(())
    }

    #[test]
    fn test_flash_attention_output_shape() -> Result<()> {
        let device = Device::Cpu;
        let batch = 2;
        let heads = 4;
        let seq_q = 96;
        let seq_k = 128;
        let dim_head = 64;

        let (q, k, v) = create_test_tensors(batch, heads, seq_q, seq_k, dim_head, &device)?;

        let flash = FlashAttention::with_dim_head(dim_head);
        let out = flash.forward(&q, &k, &v)?;

        assert_eq!(out.dims(), &[batch, heads, seq_q, dim_head]);

        Ok(())
    }

    #[test]
    fn test_flash_attention_single_element() -> Result<()> {
        let device = Device::Cpu;
        let batch = 1;
        let heads = 1;
        let seq_len = 1;
        let dim_head = 16;

        let (q, k, v) = create_test_tensors(batch, heads, seq_len, seq_len, dim_head, &device)?;

        let flash = FlashAttention::with_dim_head(dim_head);
        let flash_out = flash.forward(&q, &k, &v)?;

        // With a single key position, the softmax weight is 1 and the output equals `v`.
        let flash_vec: Vec<f32> = flash_out.to_dtype(DType::F32)?.flatten_all()?.to_vec1()?;
        let v_vec: Vec<f32> = v.flatten_all()?.to_vec1()?;

        for (f, vv) in flash_vec.iter().zip(v_vec.iter()) {
            assert_relative_eq!(f, vv, epsilon = 1e-5);
        }

        Ok(())
    }

    #[test]
    fn test_flash_attention_config_block_size() -> Result<()> {
        let device = Device::Cpu;
        let batch = 1;
        let heads = 2;
        let seq_len = 200;
        let dim_head = 32;

        let (q, k, v) = create_test_tensors(batch, heads, seq_len, seq_len, dim_head, &device)?;

        for block_size in [32, 64, 128] {
            let config = FlashAttentionConfig::with_block_size(block_size);
            let flash = FlashAttention::new(dim_head, config);
            let flash_out = flash.forward(&q, &k, &v)?;

            let scale = 1.0 / (dim_head as f64).sqrt();
            let std_out = standard_attention(&q, &k, &v, scale)?;

            let flash_vec: Vec<f32> = flash_out.to_dtype(DType::F32)?.flatten_all()?.to_vec1()?;
            let std_vec: Vec<f32> = std_out.flatten_all()?.to_vec1()?;

            for (f, s) in flash_vec.iter().zip(std_vec.iter()) {
                assert_relative_eq!(f, s, epsilon = 1e-3);
            }
        }

        Ok(())
    }

    #[test]
    fn test_flash_attention_convenience_function() -> Result<()> {
        let device = Device::Cpu;
        let batch = 1;
        let heads = 2;
        let seq_len = 64;
        let dim_head = 32;

        let (q, k, v) = create_test_tensors(batch, heads, seq_len, seq_len, dim_head, &device)?;

        let out = flash_attention(&q, &k, &v, dim_head)?;
        assert_eq!(out.dims(), &[batch, heads, seq_len, dim_head]);

        Ok(())
    }
}