# burn_attention
Flash Attention v3 implementation for the [Burn](https://github.com/tracel-ai/burn) deep learning framework.
## Overview
This crate provides an efficient implementation of Flash Attention v3, a memory-efficient attention algorithm that computes attention in tiles, reducing activation memory from quadratic to linear in sequence length. The implementation supports multiple backends:
- **WGPU** (default): Cross-platform GPU support via WebGPU
- **CubeCL**: High-performance compute kernels
- **CUDA**: Direct CUDA support for NVIDIA GPUs
## Features
- ✅ Standard scaled dot-product attention
- ✅ Causal masking for autoregressive models
- ✅ Custom attention masks
- ✅ Configurable softmax scaling
- ✅ Multiple backend support (WGPU, CubeCL, CUDA)
- ✅ Comprehensive test suite
- ✅ Criterion benchmarks for performance testing
## Installation
Add this to your `Cargo.toml`:
```toml
[dependencies]
burn_attention = "0.1"
```
### Feature Flags
- `wgpu` (default): Enable WGPU backend
- `cubecl`: Enable CubeCL backend
- `cuda`: Enable CUDA backend
Example with CUDA support:
```toml
[dependencies]
burn_attention = { version = "0.1", features = ["cuda"] }
```
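If you want CUDA without also pulling in the default WGPU backend, Cargo's standard `default-features = false` mechanism should apply (assuming the crate gates each backend purely behind the feature flags listed above):

```toml
[dependencies]
burn_attention = { version = "0.1", default-features = false, features = ["cuda"] }
```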
## Usage
### Basic Example
```rust
use burn::backend::NdArray;
use burn::tensor::{Distribution, Tensor};
use burn_attention::FlashAttentionV3;
type Backend = NdArray;
fn main() {
    let device = Default::default();

    // Create input tensors: [batch_size, num_heads, seq_len, head_dim]
    let batch_size = 2;
    let num_heads = 8;
    let seq_len = 128;
    let head_dim = 64;

    let query = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );
    let key = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );
    let value = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );

    // Compute attention (no mask, non-causal)
    let output = FlashAttentionV3::forward(query, key, value, None, false);
    println!("Output shape: {:?}", output.dims());
}
```
### Causal Attention
For autoregressive models, pass `true` for the causal argument:
```rust
let output = FlashAttentionV3::forward(query, key, value, None, true);
```
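With causal masking, query position `i` may only attend to key positions `j <= i`; disallowed scores are set to negative infinity before the softmax so they receive zero weight. A minimal sketch of the masking rule in plain Rust (illustrative only, not the crate's kernel code):

```rust
/// Returns true when query position `i` may attend to key position `j`
/// under causal masking: only current and past positions are visible.
fn causal_allowed(i: usize, j: usize) -> bool {
    j <= i
}

fn main() {
    let seq_len = 4;
    // Additive mask: allowed entries keep 0.0, masked entries become -inf,
    // so exp(score + mask) is zero for future positions after the softmax.
    let mask: Vec<Vec<f32>> = (0..seq_len)
        .map(|i| {
            (0..seq_len)
                .map(|j| if causal_allowed(i, j) { 0.0 } else { f32::NEG_INFINITY })
                .collect()
        })
        .collect();
    assert_eq!(mask[0][1], f32::NEG_INFINITY); // future position is masked
    assert_eq!(mask[3][0], 0.0); // past positions stay visible
    println!("{mask:?}");
}
```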
### Custom Configuration
```rust
use burn_attention::FlashAttentionV3Config;
let config = FlashAttentionV3Config {
    causal: true,
    dropout_p: 0.1,
    softmax_scale: Some(0.125), // 1 / sqrt(head_dim) for head_dim = 64
    block_size_q: 128,
    block_size_k: 128,
};
let output = FlashAttentionV3::forward_with_config(
    query,
    key,
    value,
    None,
    config,
);
```
## Benchmarks
Run benchmarks with:
```bash
cargo bench
```
This will run throughput benchmarks for various sequence lengths and batch sizes.
## Testing
Run the test suite:
```bash
cargo test
```
The test suite includes:
- Unit tests for basic functionality
- Numerical correctness tests comparing against a reference implementation
- Property-based tests for attention output
## Implementation Details
This implementation follows the Flash Attention v3 algorithm with optimizations for:
1. **Memory Efficiency**: Tiled computation to reduce memory usage
2. **Numerical Stability**: Online softmax computation
3. **Performance**: Kernel fusion and optimized memory access patterns
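The online softmax at the heart of points 1 and 2 can be sketched in plain Rust: a running maximum `m` and a rescaled running sum `l` are updated one block at a time, so the full score vector never needs to be materialized while still matching a two-pass softmax. This is an illustrative sketch of the technique, not the crate's kernel code.

```rust
/// Single-pass (online) softmax over `x`, processed in blocks of `block`
/// elements. Keeps a running max `m` for numerical stability and a running
/// sum `l` of exp(x - m), rescaling `l` whenever the max increases.
fn online_softmax(x: &[f32], block: usize) -> Vec<f32> {
    let mut m = f32::NEG_INFINITY; // running maximum
    let mut l = 0.0f32; // running sum of exp(x - m)
    for chunk in x.chunks(block) {
        let m_new = chunk.iter().fold(m, |a, &b| a.max(b));
        // Rescale the old sum to the new maximum, then fold in this block.
        l = l * (m - m_new).exp()
            + chunk.iter().map(|&v| (v - m_new).exp()).sum::<f32>();
        m = m_new;
    }
    x.iter().map(|&v| (v - m).exp() / l).collect()
}

fn main() {
    let x = [1.0f32, 3.0, 2.0, 0.5, 4.0];
    let p = online_softmax(&x, 2);
    let total: f32 = p.iter().sum();
    assert!((total - 1.0).abs() < 1e-6); // probabilities sum to 1
    println!("{p:?}");
}
```

Flash Attention applies the same rescaling trick to the partial output accumulator, which is what lets the tiled computation stay exact rather than approximate.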
### Tensor Shapes
- Query: `[batch_size, num_heads, seq_len_q, head_dim]`
- Key: `[batch_size, num_heads, seq_len_k, head_dim]`
- Value: `[batch_size, num_heads, seq_len_k, head_dim]`
- Output: `[batch_size, num_heads, seq_len_q, head_dim]`
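For intuition about these shapes, a naive single-head reference (the kind of implementation the correctness tests compare against) can be written in plain Rust: `q` is `[seq_len_q, head_dim]`, `k` and `v` are `[seq_len_k, head_dim]`, and the output is `[seq_len_q, head_dim]`. This sketch materializes the full `seq_len_q x seq_len_k` score matrix, which is exactly what the tiled implementation avoids.

```rust
/// Naive scaled dot-product attention for one (batch, head) slice.
/// q: [seq_q][d], k/v: [seq_k][d]; returns [seq_q][d].
fn attention_ref(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let d = q[0].len();
    let scale = 1.0 / (d as f32).sqrt(); // default softmax scale
    q.iter()
        .map(|qi| {
            // Scaled dot-product score against every key row.
            let scores: Vec<f32> = k
                .iter()
                .map(|kj| qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>() * scale)
                .collect();
            // Numerically stable softmax over the scores.
            let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = scores.iter().map(|s| (s - m).exp()).collect();
            let l: f32 = exps.iter().sum();
            // Output row: softmax-weighted sum of value rows.
            (0..d)
                .map(|c| exps.iter().zip(v).map(|(w, vj)| w * vj[c]).sum::<f32>() / l)
                .collect()
        })
        .collect()
}

fn main() {
    let q = vec![vec![1.0f32, 0.0], vec![0.0, 1.0]];
    let v = vec![vec![1.0f32, 2.0], vec![3.0, 4.0]];
    let out = attention_ref(&q, &q, &v);
    assert_eq!(out.len(), 2); // seq_len_q rows
    assert_eq!(out[0].len(), 2); // head_dim columns
    println!("{out:?}");
}
```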
## References
- [Flash Attention v3 reference implementation](https://github.com/togethercomputer/flash-attention-3)
- [Candle Flash Attention v3](https://github.com/huggingface/candle/tree/main/candle-flash-attn-v3)
- [Burn Framework](https://github.com/tracel-ai/burn)
## License
This project is licensed under either of:
- Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
- MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)
at your option.
## Contributing
Contributions are welcome! Please feel free to submit a Pull Request.