Module flash_attention

Flash Attention: Memory-efficient attention with block-wise computation.

This module implements the Flash Attention algorithm, which reduces memory complexity from O(N^2) to O(N) by using tiled computation with online softmax.

Algorithm Overview

Instead of computing the full N x N attention matrix, Flash Attention:

  1. Splits Q, K, V into blocks of size block_size
  2. For each query block, iterates over all key/value blocks
  3. Computes local attention scores for each block pair
  4. Uses online softmax with running max/sum for numerical stability
  5. Accumulates weighted values with proper rescaling
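The block-wise loop above can be sketched in plain Rust. This is an illustrative single-threaded reference over `Vec<Vec<f64>>` matrices, not this module's actual implementation; the function name `flash_attention_sketch` and its signature are hypothetical.

```rust
/// Illustrative flash attention over N x d matrices (hypothetical helper,
/// not the module's API). Never materializes the full N x N score matrix.
fn flash_attention_sketch(
    q: &[Vec<f64>],
    k: &[Vec<f64>],
    v: &[Vec<f64>],
    block_size: usize,
) -> Vec<Vec<f64>> {
    let n = q.len();
    let d = q[0].len();
    let scale = 1.0 / (d as f64).sqrt();
    let mut out = vec![vec![0.0; d]; n];

    // 1. Process query rows in blocks of size block_size.
    for q_start in (0..n).step_by(block_size) {
        let q_end = (q_start + block_size).min(n);
        // Running statistics per query row: max score, softmax denominator,
        // and the unnormalized weighted-value accumulator.
        let mut row_max = vec![f64::NEG_INFINITY; q_end - q_start];
        let mut row_sum = vec![0.0; q_end - q_start];
        let mut acc = vec![vec![0.0; d]; q_end - q_start];

        // 2. For each query block, iterate over all key/value blocks.
        for kv_start in (0..n).step_by(block_size) {
            let kv_end = (kv_start + block_size).min(n);
            for (qi, i) in (q_start..q_end).enumerate() {
                // 3. Local attention scores for this block pair.
                let scores: Vec<f64> = (kv_start..kv_end)
                    .map(|j| {
                        scale * q[i].iter().zip(&k[j]).map(|(a, b)| a * b).sum::<f64>()
                    })
                    .collect();

                // 4. Online softmax: update the running max and rescale the
                // previous sum and accumulator by exp(old_max - new_max).
                let block_max = scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
                let new_max = row_max[qi].max(block_max);
                let correction = (row_max[qi] - new_max).exp();
                row_sum[qi] *= correction;
                for a in acc[qi].iter_mut() {
                    *a *= correction;
                }

                // 5. Accumulate this block's weighted values.
                for (&s, j) in scores.iter().zip(kv_start..kv_end) {
                    let w = (s - new_max).exp();
                    row_sum[qi] += w;
                    for (a, vj) in acc[qi].iter_mut().zip(&v[j]) {
                        *a += w * vj;
                    }
                }
                row_max[qi] = new_max;
            }
        }

        // Final normalization by the softmax denominator.
        for (qi, i) in (q_start..q_end).enumerate() {
            for (o, a) in out[i].iter_mut().zip(&acc[qi]) {
                *o = a / row_sum[qi];
            }
        }
    }
    out
}

fn main() {
    let q = vec![vec![1.0, 0.5], vec![0.2, 1.3], vec![0.7, 0.1]];
    let k = vec![vec![0.9, 0.4], vec![0.1, 1.0], vec![0.5, 0.5]];
    let v = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
    // block_size = 2 does not divide N = 3; the tiling handles the remainder.
    let out = flash_attention_sketch(&q, &k, &v, 2);
    println!("{:?}", out);
}
```

Because the online-softmax rescaling is exact, the result is independent of `block_size`: a single block (`block_size = N`) reproduces standard scaled dot-product attention, and any smaller tiling matches it up to floating-point rounding.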

References

  • Dao et al., “FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness”, NeurIPS 2022

Structs

FlashAttention
Flash Attention: Memory-efficient scaled dot-product attention.
FlashAttentionConfig
Configuration for Flash Attention computation.

Functions

flash_attention
Compute flash attention with default settings.
flash_attention_with_config
Compute flash attention with custom configuration.