# burn_attention

Flash Attention v3 implementation for the Burn deep learning framework.
## Overview

This crate provides an efficient implementation of Flash Attention v3, a memory-efficient attention algorithm that reduces memory usage from quadratic to linear in sequence length. The implementation supports multiple backends including:
- WGPU (default): Cross-platform GPU support via WebGPU
- CubeCL: High-performance compute kernels
- CUDA: Direct CUDA support for NVIDIA GPUs
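To make the quadratic-versus-linear claim concrete, here is a back-of-the-envelope comparison in plain Rust (independent of this crate; the tile size and dimensions are illustrative, not measured):

```rust
// Illustrative memory comparison for one attention head with f32 activations.

/// Bytes for the full [seq_len, seq_len] score matrix that standard attention materializes.
fn standard_scores_bytes(seq_len: usize) -> usize {
    seq_len * seq_len * 4
}

/// Bytes a tiled (Flash-style) kernel needs: O(seq_len * head_dim) activations plus one tile.
fn flash_bytes(seq_len: usize, head_dim: usize, tile: usize) -> usize {
    seq_len * head_dim * 4 + tile * tile * 4
}

fn main() {
    let (seq_len, head_dim) = (8192, 64);
    println!("standard: {} MiB", standard_scores_bytes(seq_len) / (1 << 20)); // 256 MiB
    println!("flash:    {} MiB", flash_bytes(seq_len, head_dim, 128) / (1 << 20)); // 2 MiB
}
```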
## Features
- ✅ Standard scaled dot-product attention
- ✅ Causal masking for autoregressive models
- ✅ Custom attention masks
- ✅ Configurable softmax scaling
- ✅ Multiple backend support (WGPU, CubeCL, CUDA)
- ✅ Comprehensive test suite
- ✅ Criterion benchmarks for performance testing
## Installation

Add this to your `Cargo.toml`:

```toml
[dependencies]
burn_attention = "0.1"
```
## Feature Flags

- `wgpu` (default): Enable the WGPU backend
- `cubecl`: Enable the CubeCL backend
- `cuda`: Enable the CUDA backend
Example with CUDA support:

```toml
[dependencies]
burn_attention = { version = "0.1", features = ["cuda"] }
```
## Usage

### Basic Example

```rust
// Module paths reconstructed; they may differ from the published crate.
use burn::backend::NdArray;
use burn::tensor::Tensor;
use burn_attention::FlashAttentionV3;

type Backend = NdArray;
```
### Causal Attention

For autoregressive models, use causal masking:

```rust
// Argument list reconstructed; see the crate docs for the exact signature.
let output = forward(query, key, value, /* causal */ true);
```
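Independently of this crate's API, the effect of causal masking can be shown in plain Rust: future positions (column index greater than row index) get a score of negative infinity before the softmax, so they receive exactly zero attention weight.

```rust
// Causal masking illustrated on a tiny score matrix, independent of Burn.
// Unmasked scores are all equal here, to keep the arithmetic obvious.
fn causal_softmax(n: usize) -> Vec<Vec<f32>> {
    let mut weights = vec![vec![0.0f32; n]; n];
    for i in 0..n {
        // Positions j > i (the "future") are masked to -inf before softmax.
        let row: Vec<f32> = (0..n)
            .map(|j| if j > i { f32::NEG_INFINITY } else { 1.0 })
            .collect();
        // Row-wise softmax: masked entries become exactly 0.
        let max = row.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
        let exps: Vec<f32> = row.iter().map(|s| (s - max).exp()).collect();
        let sum: f32 = exps.iter().sum();
        for j in 0..n {
            weights[i][j] = exps[j] / sum;
        }
    }
    weights
}

fn main() {
    let w = causal_softmax(4);
    println!("{:?}", w[0]); // [1.0, 0.0, 0.0, 0.0]: token 0 attends only to itself
    println!("{:?}", w[3]); // [0.25, 0.25, 0.25, 0.25]: token 3 attends to all four
}
```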
### Custom Configuration

```rust
use burn_attention::FlashAttentionV3Config;

// Config fields (e.g. the softmax scale) and the argument list are
// reconstructed here; see the crate's API docs for the exact signatures.
let config = FlashAttentionV3Config::default();
let output = forward_with_config(query, key, value, &config);
```
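For reference, the conventional default for the softmax scale in scaled dot-product attention is 1/√head_dim; whether this crate uses the same default is an assumption here.

```rust
// Conventional default scaling factor for scaled dot-product attention.
fn default_scale(head_dim: usize) -> f32 {
    1.0 / (head_dim as f32).sqrt()
}

fn main() {
    println!("{}", default_scale(64)); // 0.125
}
```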
## Benchmarks

Run the benchmarks with:

```shell
cargo bench
```

This runs Criterion throughput benchmarks across a range of sequence lengths and batch sizes.
## Testing

Run the test suite:

```shell
cargo test
```

The test suite includes:

- Unit tests for basic functionality
- Numerical correctness tests against a reference implementation
- Property-based tests on attention outputs
## Implementation Details

This implementation follows the Flash Attention v3 algorithm, with optimizations for:
- Memory Efficiency: Tiled computation to reduce memory usage
- Numerical Stability: Online softmax computation
- Performance: Kernel fusion and optimized memory access patterns
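The online softmax at the heart of the algorithm can be sketched in plain Rust, independent of this crate: scores are consumed in a single pass while a running maximum and normalizer are maintained, so the full score row never has to be materialized.

```rust
// One-pass ("online") softmax weighted sum, as used in Flash Attention:
// maintain a running max `m` and normalizer `l`, and rescale the partial
// output whenever `m` grows.
fn online_softmax_weighted_sum(scores: &[f32], values: &[f32]) -> f32 {
    let mut m = f32::NEG_INFINITY; // running max (for numerical stability)
    let mut l = 0.0f32;            // running sum of exp(score - m)
    let mut acc = 0.0f32;          // running weighted sum of values
    for (&s, &v) in scores.iter().zip(values) {
        let m_new = m.max(s);
        let correction = (m - m_new).exp(); // rescale previously accumulated terms
        l = l * correction + (s - m_new).exp();
        acc = acc * correction + (s - m_new).exp() * v;
        m = m_new;
    }
    acc / l
}

fn main() {
    // Matches the two-pass result: sum_i softmax(scores)_i * values_i ≈ 25.75.
    let out = online_softmax_weighted_sum(&[1.0, 2.0, 3.0], &[10.0, 20.0, 30.0]);
    println!("{out}");
}
```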
## Tensor Shapes

- Query: `[batch_size, num_heads, seq_len_q, head_dim]`
- Key: `[batch_size, num_heads, seq_len_k, head_dim]`
- Value: `[batch_size, num_heads, seq_len_k, head_dim]`
- Output: `[batch_size, num_heads, seq_len_q, head_dim]`
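The shape contract above can be captured as a small checker (plain Rust; this function is illustrative, not part of the crate's API):

```rust
// Given Q/K/V shapes, the attention output takes Q's sequence length and
// head_dim, while K and V must share seq_len_k.
fn output_shape(q: [usize; 4], k: [usize; 4], v: [usize; 4]) -> [usize; 4] {
    assert_eq!(q[0], k[0]); // batch_size must match
    assert_eq!(q[1], k[1]); // num_heads must match
    assert_eq!(k[2], v[2]); // K and V share seq_len_k
    assert_eq!(q[3], k[3]); // head_dim must match
    [q[0], q[1], q[2], q[3]] // [batch_size, num_heads, seq_len_q, head_dim]
}

fn main() {
    let out = output_shape([2, 8, 128, 64], [2, 8, 256, 64], [2, 8, 256, 64]);
    println!("{out:?}"); // [2, 8, 128, 64]
}
```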
## References

- Dao et al., "FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" (2022)
- Dao, "FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning" (2023)
- Shah et al., "FlashAttention-3: Fast and Accurate Attention with Asynchrony and Low-precision" (2024)
## License
This project is licensed under either of:
- Apache License, Version 2.0 (LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0)
- MIT license (LICENSE-MIT or http://opensource.org/licenses/MIT)
at your option.
## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.