pub struct FlashAttention { /* private fields */ }
Block-wise attention computation optimized for CPU cache locality.
Instead of materializing the full N×N attention matrix, it processes the computation in blocks sized to fit in L1/L2 cache, reducing memory complexity in the sequence length N from O(N²) to O(N).
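The O(N) bound rests on the online-softmax trick: the softmax normalizer can be accumulated one block at a time by tracking a running maximum and a running sum that is rescaled whenever the maximum grows. The following is a minimal sketch of that recurrence, not this crate's code; OnlineSoftmax and blockwise_log_sum_exp are hypothetical names, and block_size must be non-zero.

```rust
/// Hypothetical running state for an online softmax over a stream of scores.
/// `m` is the largest score seen so far; `l` is the sum of exp(score - m).
struct OnlineSoftmax {
    m: f32,
    l: f32,
}

impl OnlineSoftmax {
    fn new() -> Self {
        Self { m: f32::NEG_INFINITY, l: 0.0 }
    }

    /// Fold one block of scores into the running state.
    fn update(&mut self, block: &[f32]) {
        let block_max = block.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        let new_m = self.m.max(block_max);
        // Rescale the old sum to the new maximum (nothing to rescale on the first block).
        let rescale = if self.m == f32::NEG_INFINITY { 0.0 } else { (self.m - new_m).exp() };
        self.l = self.l * rescale + block.iter().map(|&s| (s - new_m).exp()).sum::<f32>();
        self.m = new_m;
    }

    /// log(sum(exp(score))) over every score seen so far.
    fn log_sum_exp(&self) -> f32 {
        self.m + self.l.ln()
    }
}

/// Block-wise log-sum-exp: only `block_size` scores are inspected at a time.
fn blockwise_log_sum_exp(scores: &[f32], block_size: usize) -> f32 {
    let mut state = OnlineSoftmax::new();
    for block in scores.chunks(block_size) {
        state.update(block);
    }
    state.log_sum_exp()
}
```

Because earlier partial sums are rescaled by exp(m_old - m_new), the result matches a single pass over all scores regardless of how they are split into blocks.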
Implementations
impl FlashAttention
pub fn new(config: FlashAttentionConfig) -> Self
Create a new FlashAttention with the given configuration.
pub fn with_dimensions(dimensions: usize) -> Self
Create a new FlashAttention for the given vector dimension, using the default configuration for all other settings.
pub fn config(&self) -> &FlashAttentionConfig
Returns a reference to the configuration.
pub fn attention(
    &self,
    queries: &[Vec<f32>],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
) -> Vec<Vec<f32>>
Compute scaled dot-product attention using the block-wise algorithm.
For sequences of length N with dimension D:
- Naive: O(N²) memory (full attention matrix)
- Flash: O(N) memory (block-wise accumulation via online softmax)
Arguments
- queries - Query vectors [N_q × D]
- keys - Key vectors [N_k × D]
- values - Value vectors [N_k × D]
Returns
Output vectors [N_q × D]
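A minimal sketch of the block-wise computation described above, assuming the usual 1/√D scaling of dot-product scores and a hypothetical block_size parameter standing in for whatever FlashAttentionConfig actually uses. flash_attention_sketch is a hypothetical free function, not this type's API: for each query it walks the keys and values in blocks, keeping only a running max, a running sum, and an unnormalized output accumulator, so no N_q × N_k matrix is ever built.

```rust
/// Hypothetical free-function sketch of block-wise attention with an online
/// softmax; `block_size` is an assumed stand-in for the real config's blocking.
fn flash_attention_sketch(
    queries: &[Vec<f32>],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
    block_size: usize,
) -> Vec<Vec<f32>> {
    assert_eq!(keys.len(), values.len(), "keys and values must have the same length");
    assert!(block_size > 0, "block_size must be non-zero");
    let d = queries.first().map_or(0, |q| q.len());
    // Assumed 1/sqrt(D) scaling of dot-product scores.
    let scale = if d > 0 { 1.0 / (d as f32).sqrt() } else { 1.0 };
    let out_dim = values.first().map_or(0, |v| v.len());

    queries
        .iter()
        .map(|q| {
            let mut m = f32::NEG_INFINITY; // running max score
            let mut l = 0.0f32; // running sum of exp(score - m)
            let mut acc = vec![0.0f32; out_dim]; // running unnormalized output
            for (k_block, v_block) in keys.chunks(block_size).zip(values.chunks(block_size)) {
                // Only this block's scores exist at any one time.
                let scores: Vec<f32> = k_block
                    .iter()
                    .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale)
                    .collect();
                let block_max = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                let new_m = m.max(block_max);
                // Rescale earlier partial results to the new maximum.
                let rescale = if m == f32::NEG_INFINITY { 0.0 } else { (m - new_m).exp() };
                l *= rescale;
                acc.iter_mut().for_each(|a| *a *= rescale);
                for (s, v) in scores.iter().zip(v_block) {
                    let p = (s - new_m).exp();
                    l += p;
                    for (a, x) in acc.iter_mut().zip(v) {
                        *a += p * x;
                    }
                }
                m = new_m;
            }
            // Normalize once at the end; with no keys the output stays zero.
            if l > 0.0 {
                acc.iter().map(|a| a / l).collect()
            } else {
                acc
            }
        })
        .collect()
}
```

Up to floating-point rounding, the output does not depend on block_size; in a cache-oriented implementation that parameter only controls how much data is live at once.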
pub fn naive_attention(
    &self,
    queries: &[Vec<f32>],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
) -> Vec<Vec<f32>>
Naive attention implementation for benchmarking comparison.
Materializes the full N×N attention matrix: O(N²) memory.
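For contrast, a minimal sketch of the naive path, assuming the same 1/√D scaling. naive_attention_sketch is a hypothetical free function, not this type's method; the weights variable is the full N_q × N_k matrix that the block-wise version avoids.

```rust
/// Hypothetical reference implementation that materializes the full
/// N_q × N_k attention-weight matrix before multiplying by the values.
fn naive_attention_sketch(
    queries: &[Vec<f32>],
    keys: &[Vec<f32>],
    values: &[Vec<f32>],
) -> Vec<Vec<f32>> {
    let d = queries.first().map_or(0, |q| q.len());
    // Assumed 1/sqrt(D) scaling of dot-product scores.
    let scale = if d > 0 { 1.0 / (d as f32).sqrt() } else { 1.0 };
    let out_dim = values.first().map_or(0, |v| v.len());

    // Full score matrix: N_q × N_k floats, the O(N²) term.
    let mut weights: Vec<Vec<f32>> = queries
        .iter()
        .map(|q| {
            keys.iter()
                .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale)
                .collect::<Vec<f32>>()
        })
        .collect();

    // Row-wise softmax, subtracting the row max for numerical stability.
    for row in weights.iter_mut() {
        let m = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        row.iter_mut().for_each(|s| *s = (*s - m).exp());
        let z: f32 = row.iter().sum();
        row.iter_mut().for_each(|s| *s /= z);
    }

    // Each output row is the weight-averaged combination of the value vectors.
    weights
        .iter()
        .map(|row| {
            let mut out = vec![0.0f32; out_dim];
            for (w, v) in row.iter().zip(values) {
                for (o, x) in out.iter_mut().zip(v) {
                    *o += w * x;
                }
            }
            out
        })
        .collect()
}
```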
pub fn benchmark(&self, num_vectors: usize) -> BenchmarkResult
Run a benchmark comparing naive vs flash attention.
Generates random vectors and measures wall-clock time for both methods. Also verifies that both implementations produce equivalent results.
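A minimal sketch of such a benchmark harness, under stated assumptions: BenchSketch is a hypothetical stand-in because BenchmarkResult's fields are not documented here, the two attention implementations are passed in as closures, and a 64-bit linear congruential generator replaces whatever random source the crate actually uses, so the sketch needs no external dependency. Equivalence is reported as the largest element-wise difference between the two outputs.

```rust
use std::time::{Duration, Instant};

/// Hypothetical result type; the real BenchmarkResult's fields are not shown here.
struct BenchSketch {
    naive_time: Duration,
    flash_time: Duration,
    max_abs_diff: f32,
}

/// Deterministic pseudo-random vectors in [-1, 1) from a 64-bit LCG.
fn random_vectors(n: usize, dim: usize, seed: u64) -> Vec<Vec<f32>> {
    let mut state = seed;
    let mut out = Vec::with_capacity(n);
    for _ in 0..n {
        let mut v = Vec::with_capacity(dim);
        for _ in 0..dim {
            state = state
                .wrapping_mul(6364136223846793005)
                .wrapping_add(1442695040888963407);
            // Top 24 bits give a uniform value in [0, 1); shift it to [-1, 1).
            v.push(((state >> 40) as f32 / (1u64 << 24) as f32) * 2.0 - 1.0);
        }
        out.push(v);
    }
    out
}

/// Times both implementations on the same inputs and compares their outputs.
fn benchmark_sketch<N, F>(naive: N, flash: F, num_vectors: usize, dim: usize) -> BenchSketch
where
    N: Fn(&[Vec<f32>], &[Vec<f32>], &[Vec<f32>]) -> Vec<Vec<f32>>,
    F: Fn(&[Vec<f32>], &[Vec<f32>], &[Vec<f32>]) -> Vec<Vec<f32>>,
{
    let q = random_vectors(num_vectors, dim, 1);
    let k = random_vectors(num_vectors, dim, 2);
    let v = random_vectors(num_vectors, dim, 3);

    let t0 = Instant::now();
    let naive_out = naive(&q, &k, &v);
    let naive_time = t0.elapsed();

    let t1 = Instant::now();
    let flash_out = flash(&q, &k, &v);
    let flash_time = t1.elapsed();

    // Largest element-wise difference between the two outputs.
    let max_abs_diff = naive_out
        .iter()
        .flatten()
        .zip(flash_out.iter().flatten())
        .map(|(a, b)| (a - b).abs())
        .fold(0.0f32, f32::max);

    BenchSketch { naive_time, flash_time, max_abs_diff }
}
```

Using a fixed seed per input keeps runs reproducible, so timing differences between runs are not confounded by different data.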
pub fn self_attention(&self, sequence: &[Vec<f32>]) -> Vec<Vec<f32>>
Compute self-attention: a sequence attends to itself.
Convenience wrapper equivalent to attention(sequence, sequence, sequence).
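A minimal sketch of the wrapper, assuming a compact reference attention function (hypothetical, standing in for FlashAttention::attention with 1/√D scaling). The check it supports is a property of plain self-attention: with no positional information in the computation, permuting the input sequence permutes the output rows in the same way.

```rust
/// Hypothetical compact stand-in for FlashAttention::attention:
/// scaled dot-product attention with a row-wise softmax.
/// Assumes non-empty queries and values.
fn attention(queries: &[Vec<f32>], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let scale = 1.0 / (queries[0].len() as f32).sqrt();
    queries
        .iter()
        .map(|q| {
            let scores: Vec<f32> = keys
                .iter()
                .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale)
                .collect();
            let m = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let weights: Vec<f32> = scores.iter().map(|s| (s - m).exp()).collect();
            let z: f32 = weights.iter().sum();
            let mut out = vec![0.0f32; values[0].len()];
            for (w, v) in weights.iter().zip(values) {
                for (o, x) in out.iter_mut().zip(v) {
                    *o += w / z * x;
                }
            }
            out
        })
        .collect()
}

/// Self-attention: the same sequence supplies queries, keys and values.
fn self_attention(sequence: &[Vec<f32>]) -> Vec<Vec<f32>> {
    attention(sequence, sequence, sequence)
}
```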
pub fn cross_attention(
    &self,
    queries: &[Vec<f32>],
    kv_sequence: &[Vec<f32>],
) -> Vec<Vec<f32>>
Compute cross-attention between two sequences.
Queries from one sequence attend to keys/values from another.
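A minimal sketch, assuming (from the single kv_sequence parameter) that the second sequence supplies both the keys and the values; attention here is a hypothetical compact stand-in for FlashAttention::attention with 1/√D scaling. The output has one row per query, and the two sequences may have different lengths.

```rust
/// Hypothetical compact stand-in for FlashAttention::attention.
/// Assumes non-empty queries and values.
fn attention(queries: &[Vec<f32>], keys: &[Vec<f32>], values: &[Vec<f32>]) -> Vec<Vec<f32>> {
    let scale = 1.0 / (queries[0].len() as f32).sqrt();
    queries
        .iter()
        .map(|q| {
            let scores: Vec<f32> = keys
                .iter()
                .map(|k| q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale)
                .collect();
            let m = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let weights: Vec<f32> = scores.iter().map(|s| (s - m).exp()).collect();
            let z: f32 = weights.iter().sum();
            let mut out = vec![0.0f32; values[0].len()];
            for (w, v) in weights.iter().zip(values) {
                for (o, x) in out.iter_mut().zip(v) {
                    *o += w / z * x;
                }
            }
            out
        })
        .collect()
}

/// Cross-attention: queries from one sequence, keys and values from another.
/// Assumption: kv_sequence serves as both the keys and the values.
fn cross_attention(queries: &[Vec<f32>], kv_sequence: &[Vec<f32>]) -> Vec<Vec<f32>> {
    attention(queries, kv_sequence, kv_sequence)
}
```

Each output row is a convex combination of the kv_sequence vectors, so every coordinate lies between that coordinate's minimum and maximum over kv_sequence.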
pub fn memory_estimate(&self, seq_len: usize) -> MemoryEstimate
Estimate peak memory usage in bytes for a given sequence length.
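A back-of-the-envelope version of such an estimate, under stated assumptions: MemoryEstimateSketch is a hypothetical stand-in for MemoryEstimate (whose fields are not documented here), the dim and block_size parameters are assumptions, inputs are excluded, and the block-wise path is assumed to hold one block_size × block_size score tile plus a running max and sum per query row.

```rust
/// Hypothetical result type; the real MemoryEstimate's fields are not shown here.
#[derive(Debug, PartialEq)]
struct MemoryEstimateSketch {
    naive_bytes: usize,
    flash_bytes: usize,
}

/// Rough peak memory for f32 attention over `seq_len` vectors of dimension
/// `dim`, excluding the input Q/K/V buffers.
fn memory_estimate_sketch(seq_len: usize, dim: usize, block_size: usize) -> MemoryEstimateSketch {
    let f32_bytes = std::mem::size_of::<f32>();
    // Output buffer: seq_len × dim floats, needed by both methods.
    let output = seq_len * dim;
    // Naive: the full seq_len × seq_len score matrix plus the output.
    let naive = seq_len * seq_len + output;
    // Block-wise: one score tile, a running max and running sum per
    // query row (the O(N) term), plus the output.
    let flash = block_size * block_size + 2 * seq_len + output;
    MemoryEstimateSketch {
        naive_bytes: naive * f32_bytes,
        flash_bytes: flash * f32_bytes,
    }
}
```

With seq_len = 1024, dim = 64 and block_size = 64, this gives 4,456,448 bytes for the naive path versus 286,720 bytes block-wise; doubling seq_len roughly quadruples the first figure but only doubles the second.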