Flash Attention: Memory-efficient attention with block-wise computation.
This module implements the Flash Attention algorithm, which reduces memory complexity from O(N^2) to O(N) by using tiled computation with online softmax.
§Algorithm Overview
Instead of computing the full N x N attention matrix, Flash Attention:
- Splits Q, K, V into blocks of size `block_size`
- For each query block, iterates over all key/value blocks
- Computes local attention scores for each block pair
- Uses online softmax with running max/sum for numerical stability
- Accumulates weighted values with proper rescaling
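The running max/sum bookkeeping in the steps above can be sketched with a scalar-valued online softmax. This is an illustrative helper, not part of this module's API: scores and values arrive in blocks, and previously accumulated quantities are rescaled whenever a new block raises the running max, so the final result matches a full softmax over all scores.

```rust
/// Numerically stable online softmax over a stream of score blocks
/// (hypothetical sketch; one scalar value per key for simplicity).
fn online_softmax_weighted_sum(score_blocks: &[Vec<f64>], value_blocks: &[Vec<f64>]) -> f64 {
    let mut m = f64::NEG_INFINITY; // running max of all scores seen so far
    let mut l = 0.0; // running sum of exp(score - m)
    let mut acc = 0.0; // running weighted sum of values
    for (scores, values) in score_blocks.iter().zip(value_blocks) {
        let block_max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let new_m = m.max(block_max);
        // Rescale previous accumulators to the new max before adding this block.
        let scale = (m - new_m).exp();
        l = l * scale + scores.iter().map(|s| (s - new_m).exp()).sum::<f64>();
        acc = acc * scale
            + scores
                .iter()
                .zip(values)
                .map(|(s, v)| (s - new_m).exp() * v)
                .sum::<f64>();
        m = new_m;
    }
    acc / l
}
```

Because each block pair only needs the three running scalars (`m`, `l`, `acc`) rather than the full score matrix, memory stays O(N) even though the result is exact.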
§References
- Dao et al., “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”, NeurIPS 2022
Structs§
- FlashAttention - Flash Attention: Memory-efficient scaled dot-product attention.
- FlashAttentionConfig - Configuration for Flash Attention computation.
Functions§
- flash_attention - Compute flash attention with default settings.
- flash_attention_with_config - Compute flash attention with custom configuration.