# burn_attention


Flash Attention v3 implementation for the [Burn](https://github.com/tracel-ai/burn) deep learning framework.

## Overview


This crate provides an efficient implementation of Flash Attention v3, a tiled attention algorithm that computes exact attention while reducing memory usage from quadratic to linear in the sequence length. The implementation supports multiple backends:

- **WGPU** (default): Cross-platform GPU support via WebGPU
- **CubeCL**: High-performance compute kernels
- **CUDA**: Direct CUDA support for NVIDIA GPUs

## Features


- ✅ Standard scaled dot-product attention
- ✅ Causal masking for autoregressive models
- ✅ Custom attention masks
- ✅ Configurable softmax scaling
- ✅ Multiple backend support (WGPU, CubeCL, CUDA)
- ✅ Comprehensive test suite
- ✅ Criterion benchmarks for performance testing

## Installation


Add this to your `Cargo.toml`:

```toml
[dependencies]
burn_attention = "0.1"
```

### Feature Flags


- `wgpu` (default): Enable WGPU backend
- `cubecl`: Enable CubeCL backend
- `cuda`: Enable CUDA backend

Example with CUDA support:

```toml
[dependencies]
burn_attention = { version = "0.1", features = ["cuda"] }
```

## Usage


### Basic Example


```rust
use burn::backend::NdArray;
use burn::tensor::{Distribution, Tensor};
use burn_attention::FlashAttentionV3;

type Backend = NdArray;

fn main() {
    let device = Default::default();

    // Create input tensors
    let batch_size = 2;
    let num_heads = 8;
    let seq_len = 128;
    let head_dim = 64;

    let query = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );
    let key = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );
    let value = Tensor::<Backend, 4>::random(
        [batch_size, num_heads, seq_len, head_dim],
        Distribution::Normal(0.0, 1.0),
        &device,
    );

    // Compute attention
    let output = FlashAttentionV3::forward(query, key, value, None, false);

    println!("Output shape: {:?}", output.dims());
}
```

### Causal Attention


For autoregressive models, use causal masking:

```rust
let output = FlashAttentionV3::forward(query, key, value, None, true);
```
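The causal flag restricts each query position to keys at the same or earlier positions. As an illustrative sketch (not part of the crate's API), the implied mask looks like this, where entry `(i, j)` is true when query position `i` may attend to key position `j`:

```rust
// Build the boolean mask that `causal = true` applies implicitly:
// position i may attend to position j exactly when j <= i.
fn causal_mask(seq_len: usize) -> Vec<Vec<bool>> {
    (0..seq_len)
        .map(|i| (0..seq_len).map(|j| j <= i).collect())
        .collect()
}

fn main() {
    let mask = causal_mask(4);
    for row in &mask {
        let line: String = row.iter().map(|&m| if m { '1' } else { '0' }).collect();
        println!("{line}");
    }
    // Prints a lower-triangular pattern:
    // 1000
    // 1100
    // 1110
    // 1111
}
```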

### Custom Configuration


```rust
use burn_attention::FlashAttentionV3Config;

let config = FlashAttentionV3Config {
    causal: true,
    dropout_p: 0.1,
    softmax_scale: Some(0.125),
    block_size_q: 128,
    block_size_k: 128,
};

let output = FlashAttentionV3::forward_with_config(
    query,
    key,
    value,
    None,
    config,
);
```

## Benchmarks


Run benchmarks with:

```bash
cargo bench
```

This will run throughput benchmarks for various sequence lengths and batch sizes.

## Testing


Run the test suite:

```bash
cargo test
```

The test suite includes:
- Unit tests for basic functionality
- Numerical correctness tests comparing against a reference implementation
- Property-based tests for attention output
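The kind of reference the correctness tests can compare against is plain scaled dot-product attention. The sketch below (illustrative names, plain `f32` matrices, no batching or heads; not the crate's API) computes `softmax(QKᵀ/√d)·V` directly:

```rust
// Naive reference attention: one row of output per query row.
fn naive_attention(q: &[Vec<f32>], k: &[Vec<f32>], v: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let d = q[0].len() as f32;
    let scale = 1.0 / d.sqrt();
    q.iter()
        .map(|qi| {
            // scores_j = (q_i . k_j) / sqrt(d)
            let scores: Vec<f32> = k
                .iter()
                .map(|kj| qi.iter().zip(kj).map(|(a, b)| a * b).sum::<f32>() * scale)
                .collect();
            // numerically stable softmax over the scores
            let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
            let exps: Vec<f32> = scores.iter().map(|s| (s - m).exp()).collect();
            let z: f32 = exps.iter().sum();
            // output row = sum_j softmax_j * v_j
            let dim_v = v[0].len();
            (0..dim_v)
                .map(|c| exps.iter().zip(v).map(|(e, vj)| e / z * vj[c]).sum())
                .collect()
        })
        .collect()
}

fn main() {
    let q = vec![vec![1.0f32, 0.0], vec![0.0, 1.0]];
    let k = q.clone();
    let v = vec![vec![1.0f32, 2.0], vec![3.0, 4.0]];
    let out = naive_attention(&q, &k, &v);
    println!("{out:?}");
}
```

Each output row is a convex combination of the value rows, which is an easy property to assert in tests even without golden values.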

## Implementation Details


This implementation follows the Flash Attention v3 algorithm with optimizations for:

1. **Memory Efficiency**: Tiled computation to reduce memory usage
2. **Numerical Stability**: Online softmax computation
3. **Performance**: Kernel fusion and optimized memory access patterns
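The online softmax in point 2 is what makes tiling possible: scores can be consumed in a single streaming pass while a running maximum, normalizer, and accumulator are rescaled whenever a new maximum appears. A minimal scalar sketch (illustrative only; the real kernel applies the same recurrence to whole tiles):

```rust
// Streaming softmax-weighted sum: never materializes the full softmax vector.
fn online_softmax_weighted_sum(scores: &[f32], values: &[f32]) -> f32 {
    let mut m = f32::NEG_INFINITY; // running max of scores seen so far
    let mut z = 0.0f32;            // running sum of exp(score - m)
    let mut acc = 0.0f32;          // running sum of exp(score - m) * value
    for (&s, &v) in scores.iter().zip(values) {
        let m_new = m.max(s);
        let correction = (m - m_new).exp(); // rescale old terms to the new max
        z = z * correction + (s - m_new).exp();
        acc = acc * correction + (s - m_new).exp() * v;
        m = m_new;
    }
    acc / z
}

fn main() {
    let scores = [0.5f32, 2.0, -1.0, 3.0];
    let values = [1.0f32, 2.0, 3.0, 4.0];
    // One streaming pass:
    let online = online_softmax_weighted_sum(&scores, &values);
    // Two-pass reference for comparison:
    let m = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let z: f32 = scores.iter().map(|s| (s - m).exp()).sum();
    let reference: f32 = scores
        .iter()
        .zip(&values)
        .map(|(s, v)| (s - m).exp() / z * v)
        .sum();
    println!("online = {online}, reference = {reference}");
}
```

Because every old term is rescaled by `exp(m_old - m_new)` when the maximum changes, the streaming result matches the two-pass softmax while keeping all exponents non-positive, which is the numerical-stability property the algorithm relies on.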

### Tensor Shapes


- Query: `[batch_size, num_heads, seq_len_q, head_dim]`
- Key: `[batch_size, num_heads, seq_len_k, head_dim]`
- Value: `[batch_size, num_heads, seq_len_k, head_dim]`
- Output: `[batch_size, num_heads, seq_len_q, head_dim]`
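The contract above can be summarized as: the output takes `seq_len_q` from the query, while key and value must agree on `seq_len_k` (so cross-attention with differing query/key lengths is well-formed). A hedged sketch of the dimension checks an implementation would perform (illustrative function, not the crate's API):

```rust
// Shapes are [batch_size, num_heads, seq_len, head_dim].
fn attention_output_shape(q: [usize; 4], k: [usize; 4], v: [usize; 4]) -> [usize; 4] {
    assert_eq!(q[0], k[0], "batch_size must match");
    assert_eq!(q[1], k[1], "num_heads must match");
    assert_eq!(q[3], k[3], "query/key head_dim must match");
    assert_eq!(k[2], v[2], "key/value seq_len_k must match");
    // Output takes seq_len_q from the query and head_dim from the value.
    [q[0], q[1], q[2], v[3]]
}

fn main() {
    let q = [2, 8, 128, 64]; // seq_len_q = 128
    let k = [2, 8, 256, 64]; // seq_len_k = 256 (cross-attention)
    let v = [2, 8, 256, 64];
    println!("{:?}", attention_output_shape(q, k, v)); // [2, 8, 128, 64]
}
```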

## References


- [Flash Attention v3 Paper](https://github.com/togethercomputer/flash-attention-3)
- [Candle Flash Attention v3](https://github.com/huggingface/candle/tree/main/candle-flash-attn-v3)
- [Burn Framework](https://github.com/tracel-ai/burn)

## License


This project is licensed under either of:

- Apache License, Version 2.0 ([LICENSE-APACHE](LICENSE-APACHE) or http://www.apache.org/licenses/LICENSE-2.0)
- MIT license ([LICENSE-MIT](LICENSE-MIT) or http://opensource.org/licenses/MIT)

at your option.

## Contributing


Contributions are welcome! Please feel free to submit a Pull Request.