use crate::error::{AttentionError, AttentionResult};
/// Tiling and masking options for `FlashAttention3::forward`.
///
/// All fields are public, so a value built with a struct literal skips the checks done by
/// [`FlashConfig::new`] and [`FlashConfig::with_dropout`].
#[derive(Clone, Debug)]
pub struct FlashConfig {
    /// Number of query rows processed per tile.
    pub block_size_q: usize,
    /// Number of key/value rows processed per tile.
    pub block_size_kv: usize,
    /// When set, `FlashAttention3::forward` masks every key whose index exceeds the query index.
    pub causal: bool,
    /// Dropout probability; `with_dropout` only accepts values in `[0, 1]`.
    /// NOTE(review): `FlashAttention3::forward` in this module never reads this field.
    pub dropout_p: f32,
}

impl Default for FlashConfig {
    /// 64 x 64 tiles, no causal mask, no dropout.
    fn default() -> Self {
        FlashConfig {
            block_size_q: 64,
            block_size_kv: 64,
            causal: false,
            dropout_p: 0.0,
        }
    }
}

impl FlashConfig {
    /// Builds a non-causal, dropout-free configuration with the given tile sizes.
    ///
    /// # Errors
    /// Returns `AttentionError::InvalidConfig` if either block size is zero.
    pub fn new(block_size_q: usize, block_size_kv: usize) -> AttentionResult<Self> {
        match (block_size_q, block_size_kv) {
            (0, _) | (_, 0) => Err(AttentionError::InvalidConfig(
                "Block sizes must be > 0".into(),
            )),
            _ => Ok(FlashConfig {
                block_size_q,
                block_size_kv,
                ..FlashConfig::default()
            }),
        }
    }

    /// Returns this configuration with causal masking switched on.
    pub fn with_causal(self) -> Self {
        Self {
            causal: true,
            ..self
        }
    }

    /// Returns this configuration with dropout probability `p`.
    ///
    /// # Errors
    /// Returns `AttentionError::InvalidConfig` unless `p` lies in `[0, 1]` (NaN is rejected).
    pub fn with_dropout(self, p: f32) -> AttentionResult<Self> {
        if (0.0..=1.0).contains(&p) {
            Ok(Self {
                dropout_p: p,
                ..self
            })
        } else {
            Err(AttentionError::InvalidConfig(
                "Dropout must be in [0, 1]".into(),
            ))
        }
    }
}
/// Memory-traffic and FLOP counters for one attention pass.
///
/// As accumulated by `FlashAttention3::forward` in this module, `memory_reads` and
/// `memory_writes` are element counts (rows times head dimension), not byte counts.
#[derive(Clone, Debug, Default)]
pub struct IOStats {
    /// Floating-point operations performed.
    pub total_flops: u64,
    /// Elements read from the inputs.
    pub memory_reads: u64,
    /// Elements written to the output.
    pub memory_writes: u64,
    /// Sequence length `n` used by the naive I/O estimate in `flop_ratio`.
    seq_len: usize,
    /// Head dimension `d` used by the naive I/O estimate in `flop_ratio`.
    head_dim: usize,
    // Tile sizes are recorded for `Debug` output only; no method reads them.
    #[allow(dead_code)]
    block_size_q: usize,
    #[allow(dead_code)]
    block_size_kv: usize,
}

impl IOStats {
    /// Ratio of a naive attention I/O estimate to the I/O actually recorded.
    ///
    /// Despite its name this compares memory traffic, not FLOPs: the naive estimate
    /// `n^2 + n*d` (`n = seq_len`, `d = head_dim`) is divided by `memory_reads + memory_writes`.
    /// Returns `1.0` when no FLOPs were recorded or when no traffic was recorded.
    pub fn flop_ratio(&self) -> f32 {
        let tiled_io = self.memory_reads as f64 + self.memory_writes as f64;
        if self.total_flops == 0 || tiled_io < 1.0 {
            return 1.0;
        }
        let (n, d) = (self.seq_len as f64, self.head_dim as f64);
        ((n * n + n * d) / tiled_io) as f32
    }

    /// Fixed label for the memory complexity of the tiled algorithm.
    pub fn memory_complexity(&self) -> &'static str {
        "O(N)"
    }

    /// Fixed label for the memory complexity of naive attention.
    pub fn naive_memory_complexity(&self) -> &'static str {
        "O(N^2)"
    }
}
/// Tiled ("flash") scaled dot-product attention on the CPU. Stateless; see
/// [`FlashAttention3::forward`].
pub struct FlashAttention3;

/// Result of a [`FlashAttention3::forward`] pass.
#[derive(Clone, Debug)]
pub struct FlashOutput {
    /// Attention output: one row of length `d` per query row.
    pub output: Vec<Vec<f32>>,
    /// Per-query log-sum-exp of the scaled scores; `f32::NEG_INFINITY` for a row whose
    /// keys were all masked out.
    pub lse: Vec<f32>,
    /// Memory-traffic and FLOP counters accumulated during the pass.
    pub stats: IOStats,
}

impl FlashAttention3 {
    /// Computes `softmax(q * k^T / sqrt(d)) * v` tile by tile with an online softmax.
    ///
    /// Queries are processed in tiles of `config.block_size_q` rows and keys/values in tiles
    /// of `config.block_size_kv` rows. Each query row keeps a running score maximum and a running
    /// softmax denominator, so the full `n_q x n_kv` score matrix is never materialized.
    /// `d` is taken from `q[0].len()`.
    ///
    /// With `config.causal`, key `j` is masked for query `i` whenever `j > i` (absolute row
    /// indices into `q` and `k`). `config.dropout_p` is not applied.
    ///
    /// # Errors
    /// - `EmptyInput` if `q`, `k` or `v` is empty.
    /// - `DimensionMismatch` if `k` and `v` differ in length, or if any row of `q`, `k` or `v`
    ///   has a length other than `d`.
    /// - `InvalidConfig` if `d` is zero or either block size is zero.
    pub fn forward(
        q: &[Vec<f32>],
        k: &[Vec<f32>],
        v: &[Vec<f32>],
        config: &FlashConfig,
    ) -> AttentionResult<FlashOutput> {
        if q.is_empty() {
            return Err(AttentionError::EmptyInput("queries".into()));
        }
        if k.is_empty() || v.is_empty() {
            return Err(AttentionError::EmptyInput("keys or values".into()));
        }
        if k.len() != v.len() {
            return Err(AttentionError::DimensionMismatch {
                expected: k.len(),
                actual: v.len(),
            });
        }
        let d = q[0].len();
        if d == 0 {
            return Err(AttentionError::InvalidConfig("Dimension must be > 0".into()));
        }
        let (br, bc) = (config.block_size_q, config.block_size_kv);
        // The config fields are public, so a struct literal can bypass `FlashConfig::new`;
        // `step_by(0)` below would panic on a zero block size.
        if br == 0 || bc == 0 {
            return Err(AttentionError::InvalidConfig(
                "Block sizes must be > 0".into(),
            ));
        }
        // Every row is indexed up to `d`: reject ragged input instead of panicking on short
        // rows or silently ignoring the tail of long ones.
        Self::check_row_dims(q, d)?;
        Self::check_row_dims(k, d)?;
        Self::check_row_dims(v, d)?;

        let scale = 1.0 / (d as f32).sqrt();
        let (n_q, n_kv) = (q.len(), k.len());

        let mut output = vec![vec![0.0f32; d]; n_q];
        // Per query row: running max score, and running softmax denominator relative to it.
        let mut row_max = vec![f32::NEG_INFINITY; n_q];
        let mut row_sum = vec![0.0f32; n_q];
        let mut stats = IOStats {
            seq_len: n_q.max(n_kv),
            head_dim: d,
            block_size_q: br,
            block_size_kv: bc,
            ..Default::default()
        };

        // Scratch buffers reused for every (query row, KV tile) pair.
        let mut weights: Vec<f32> = Vec::with_capacity(bc.min(n_kv));
        let mut pv = vec![0.0f32; d];

        for qi_start in (0..n_q).step_by(br) {
            let qi_end = (qi_start + br).min(n_q);
            for kj_start in (0..n_kv).step_by(bc) {
                let kj_end = (kj_start + bc).min(n_kv);
                let tile_len = kj_end - kj_start;
                let v_tile = &v[kj_start..kj_end];
                // One Q tile plus one K tile and one V tile.
                stats.memory_reads += ((qi_end - qi_start) * d + tile_len * d * 2) as u64;

                for qi in qi_start..qi_end {
                    let q_row = &q[qi];
                    weights.clear();
                    weights.extend((kj_start..kj_end).map(|kj| {
                        let score = Self::dot(q_row, &k[kj]) * scale;
                        if config.causal && kj > qi {
                            f32::NEG_INFINITY
                        } else {
                            score
                        }
                    }));
                    stats.total_flops += (2 * d * tile_len) as u64;

                    let m_ij = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
                    if !m_ij.is_finite() {
                        // Every key in this tile is masked (or the tile maximum overflowed to
                        // +inf): the tile is skipped and this row's running state is unchanged.
                        continue;
                    }
                    // Scores -> weights relative to the tile maximum.
                    for w in weights.iter_mut() {
                        *w = (*w - m_ij).exp();
                    }
                    // Non-finite weights (only reachable through NaN inputs) are ignored.
                    let l_ij: f32 = weights.iter().copied().filter(|w| w.is_finite()).sum();

                    let m_old = row_max[qi];
                    let m_new = m_old.max(m_ij);
                    let exp_old = if m_old.is_finite() {
                        (m_old - m_new).exp()
                    } else {
                        0.0
                    };
                    let exp_new = (m_ij - m_new).exp();
                    let l_old = row_sum[qi];
                    let l_new = exp_old * l_old + exp_new * l_ij;

                    if l_new > 0.0 {
                        // `output[qi]` is kept normalized by `l_old`; rescale it and this tile's
                        // contribution onto the shared maximum `m_new` and denominator `l_new`.
                        let inv_l_new = 1.0 / l_new;
                        let scale_old = exp_old * l_old * inv_l_new;
                        let scale_new = exp_new * inv_l_new;

                        pv.fill(0.0);
                        for (&w, v_row) in weights.iter().zip(v_tile) {
                            if w.is_finite() {
                                for (acc, &val) in pv.iter_mut().zip(v_row) {
                                    *acc += w * val;
                                }
                            }
                        }
                        for (out, &p) in output[qi].iter_mut().zip(&pv) {
                            *out = scale_old * *out + scale_new * p;
                        }
                        stats.total_flops += (2 * d * tile_len) as u64;
                    }
                    row_max[qi] = m_new;
                    row_sum[qi] = l_new;
                }
            }
            // The finished Q tile's output rows are written once.
            stats.memory_writes += ((qi_end - qi_start) * d) as u64;
        }

        // lse = m + ln(l) is the log of the full, unshifted softmax denominator.
        let lse: Vec<f32> = row_max
            .iter()
            .zip(&row_sum)
            .map(|(&m, &l)| {
                if l > 0.0 && m.is_finite() {
                    m + l.ln()
                } else {
                    f32::NEG_INFINITY
                }
            })
            .collect();

        Ok(FlashOutput {
            output,
            lse,
            stats,
        })
    }

    /// Returns `DimensionMismatch` for the first row whose length differs from `d`.
    fn check_row_dims(rows: &[Vec<f32>], d: usize) -> AttentionResult<()> {
        match rows.iter().map(Vec::len).find(|&len| len != d) {
            Some(actual) => Err(AttentionError::DimensionMismatch { expected: d, actual }),
            None => Ok(()),
        }
    }

    /// Dot product of two rows, accumulated left to right.
    fn dot(a: &[f32], b: &[f32]) -> f32 {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }
}
/// Builds the boolean causal mask for one (query tile, key tile) pair.
///
/// `mask[r][c]` is `true` when key position `kj_start + c` is visible to query position
/// `qi_start + r`, i.e. when `kj <= qi`. Both ranges are half-open (`start..end`). An empty
/// or reversed query range yields no rows, and an empty or reversed key range yields empty
/// rows. Neither case panics on subtraction underflow.
pub fn causal_block_mask(
    qi_start: usize,
    qi_end: usize,
    kj_start: usize,
    kj_end: usize,
) -> Vec<Vec<bool>> {
    (qi_start..qi_end)
        .map(|qi| (kj_start..kj_end).map(|kj| kj <= qi).collect())
        .collect()
}
pub struct RingAttention;
#[derive(Clone, Debug)]
pub struct RingDeviceOutput {
pub output: Vec<Vec<f32>>,
pub lse: Vec<f32>,
pub transfers: usize,
}
impl RingAttention {
pub fn ring_forward(
q_shards: &[Vec<Vec<f32>>],
k_shards: &[Vec<Vec<f32>>],
v_shards: &[Vec<Vec<f32>>],
) -> AttentionResult<Vec<RingDeviceOutput>> {
let num_devices = q_shards.len();
if num_devices == 0 {
return Err(AttentionError::EmptyInput("shards".into()));
}
if k_shards.len() != num_devices || v_shards.len() != num_devices {
return Err(AttentionError::DimensionMismatch {
expected: num_devices,
actual: k_shards.len().min(v_shards.len()),
});
}
let config = FlashConfig {
block_size_q: 32,
block_size_kv: 32,
causal: false,
dropout_p: 0.0,
};
let mut results = Vec::with_capacity(num_devices);
for device_id in 0..num_devices {
let local_q = &q_shards[device_id];
if local_q.is_empty() {
return Err(AttentionError::EmptyInput(
format!("Q shard on device {device_id}"),
));
}
let d = local_q[0].len();
let n_q = local_q.len();
let mut output = vec![vec![0.0f32; d]; n_q];
let mut row_max = vec![f32::NEG_INFINITY; n_q];
let mut row_sum = vec![0.0f32; n_q];
let mut lse = vec![f32::NEG_INFINITY; n_q];
let mut transfers = 0usize;
for step in 0..num_devices {
let kv_idx = (device_id + step) % num_devices;
if step > 0 {
transfers += 1; }
let partial = FlashAttention3::forward(
local_q,
&k_shards[kv_idx],
&v_shards[kv_idx],
&config,
)?;
for qi in 0..n_q {
let m_partial = if partial.lse[qi].is_finite() {
partial.lse[qi]
} else {
continue;
};
let m_old = row_max[qi];
let m_new = m_old.max(m_partial);
let exp_old = if m_old.is_finite() {
(m_old - m_new).exp()
} else {
0.0
};
let exp_partial = (m_partial - m_new).exp();
let l_partial = if partial.lse[qi].is_finite() {
partial.lse[qi].exp()
} else {
0.0
};
let l_old = row_sum[qi];
let l_new = exp_old * l_old + exp_partial * l_partial;
if l_new > 0.0 {
let inv_l = 1.0 / l_new;
for dd in 0..d {
output[qi][dd] = (exp_old * l_old * output[qi][dd]
+ exp_partial * l_partial * partial.output[qi][dd])
* inv_l;
}
}
row_max[qi] = m_new;
row_sum[qi] = l_new;
}
}
for qi in 0..n_q {
if row_sum[qi] > 0.0 && row_max[qi].is_finite() {
lse[qi] = row_max[qi] + row_sum[qi].ln();
}
}
results.push(RingDeviceOutput {
output,
lse,
transfers,
});
}
Ok(results)
}
}
/// Straightforward (non-tiled) scaled dot-product attention, used as a reference by the tests.
///
/// For each query row this builds the full score row `q[i] . k[j] / sqrt(d)` (with
/// `d = q[0].len()`). It applies a max-shifted softmax and returns the weighted sum of `v` rows.
/// With `causal`, key `j` is excluded for query `i` whenever `j > i`.
///
/// Returns an empty vector when `q` is empty. Unmasked `k` rows shorter than `d`, `v` rows
/// shorter than `d`, or a `v` with fewer rows than `k` cause an index panic.
///
/// Compiled only for tests: nothing outside the test module calls it, so a non-test build
/// would otherwise report it as dead code.
#[cfg(test)]
fn naive_attention(
    q: &[Vec<f32>],
    k: &[Vec<f32>],
    v: &[Vec<f32>],
    causal: bool,
) -> Vec<Vec<f32>> {
    let d = match q.first() {
        Some(row) => row.len(),
        None => return Vec::new(),
    };
    let scale = 1.0 / (d as f32).sqrt();
    q.iter()
        .enumerate()
        .map(|(qi, q_row)| {
            let scores: Vec<f32> = k
                .iter()
                .enumerate()
                .map(|(kj, k_row)| {
                    if causal && kj > qi {
                        f32::NEG_INFINITY
                    } else {
                        (0..d).map(|dd| q_row[dd] * k_row[dd]).sum::<f32>() * scale
                    }
                })
                .collect();
            let max_s = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let exp_s: Vec<f32> = scores.iter().map(|&s| (s - max_s).exp()).collect();
            let sum_s: f32 = exp_s.iter().sum();
            (0..d)
                .map(|dd| {
                    exp_s
                        .iter()
                        .enumerate()
                        .map(|(kj, &e)| (e / sum_s) * v[kj][dd])
                        .sum::<f32>()
                })
                .collect()
        })
        .collect()
}
#[cfg(test)]
mod tests {
    // Unit tests comparing `FlashAttention3::forward` against `naive_attention`,
    // plus checks of the ring driver, mask helper, stats and input validation.
    use super::*;
    /// Deterministic `n` x `d` test matrix with entries
    /// `0.5 * sin((i + 1) * (j + 1) * seed)`, so every value lies in [-0.5, 0.5].
    fn make_seq(n: usize, d: usize, seed: f32) -> Vec<Vec<f32>> {
        (0..n)
            .map(|i| {
                (0..d)
                    .map(|j| ((i as f32 + 1.0) * (j as f32 + 1.0) * seed).sin() * 0.5)
                    .collect()
            })
            .collect()
    }
    // Non-causal tiled output (4x4 tiles, 12 rows) must match the naive reference.
    #[test]
    fn test_forward_matches_naive() {
        let d = 16;
        let n = 12;
        let q = make_seq(n, d, 0.1);
        let k = make_seq(n, d, 0.2);
        let v = make_seq(n, d, 0.3);
        let config = FlashConfig::new(4, 4).unwrap();
        let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let naive = naive_attention(&q, &k, &v, false);
        for qi in 0..n {
            for dd in 0..d {
                let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                assert!(diff < 1e-4, "row={qi} col={dd} flash={} naive={} diff={diff}",
                    flash.output[qi][dd], naive[qi][dd]);
            }
        }
    }
    // Causal tiled output (2x2 tiles) must match the causal naive reference; with 2x2
    // tiles some KV tiles are fully masked for some query rows.
    #[test]
    fn test_causal_masking() {
        let d = 8;
        let n = 6;
        let q = make_seq(n, d, 0.4);
        let k = make_seq(n, d, 0.5);
        let v = make_seq(n, d, 0.6);
        let config = FlashConfig::new(2, 2).unwrap().with_causal();
        let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let naive = naive_attention(&q, &k, &v, true);
        for qi in 0..n {
            for dd in 0..d {
                let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                assert!(diff < 1e-4, "causal row={qi} col={dd} diff={diff}");
            }
        }
    }
    // Large-magnitude, identical Q and K rows produce very large scores; outputs and LSE
    // must still be finite thanks to the max-shifted softmax.
    #[test]
    fn test_numerical_stability_large_values() {
        let d = 8;
        let n = 4;
        let q: Vec<Vec<f32>> = (0..n)
            .map(|i| vec![100.0 * (i as f32 + 1.0); d])
            .collect();
        let k = q.clone();
        let v: Vec<Vec<f32>> = (0..n).map(|i| vec![i as f32; d]).collect();
        let config = FlashConfig::new(2, 2).unwrap();
        let result = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        for row in &result.output {
            for &val in row {
                assert!(val.is_finite(), "Non-finite output: {val}");
            }
        }
        for &l in &result.lse {
            assert!(l.is_finite(), "Non-finite LSE: {l}");
        }
    }
    // Tile sizes that do not divide n (e.g. (3,5), (7,3)) exercise ragged final tiles;
    // (1,1) and (10,10) cover the per-element and single-tile extremes.
    #[test]
    fn test_block_size_variations() {
        let d = 8;
        let n = 10;
        let q = make_seq(n, d, 0.7);
        let k = make_seq(n, d, 0.8);
        let v = make_seq(n, d, 0.9);
        let block_sizes = [(2, 2), (3, 5), (1, 1), (10, 10), (7, 3)];
        let naive = naive_attention(&q, &k, &v, false);
        for (bq, bk) in block_sizes {
            let config = FlashConfig::new(bq, bk).unwrap();
            let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
            for qi in 0..n {
                for dd in 0..d {
                    let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                    assert!(
                        diff < 1e-4,
                        "blocks=({bq},{bk}) row={qi} col={dd} diff={diff}"
                    );
                }
            }
        }
    }
    // Counters are populated and the fixed complexity labels are returned; only
    // positivity of the I/O ratio is checked, not its value.
    #[test]
    fn test_io_stats_tracking() {
        let d = 8;
        let n = 16;
        let q = make_seq(n, d, 1.0);
        let k = make_seq(n, d, 1.1);
        let v = make_seq(n, d, 1.2);
        let config = FlashConfig::new(4, 4).unwrap();
        let result = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        assert!(result.stats.total_flops > 0, "FLOPs should be tracked");
        assert!(result.stats.memory_reads > 0, "Reads should be tracked");
        assert!(result.stats.memory_writes > 0, "Writes should be tracked");
        assert_eq!(result.stats.memory_complexity(), "O(N)");
        assert_eq!(result.stats.naive_memory_complexity(), "O(N^2)");
        let ratio = result.stats.flop_ratio();
        assert!(ratio > 0.0, "IO ratio should be positive");
    }
    // NOTE(review): this only checks output shapes, transfer counts and finiteness. It does
    // not compare the merged output or LSE against a single-device reference over the
    // concatenated K/V, so an incorrect cross-shard merge would still pass.
    #[test]
    fn test_ring_attention() {
        let d = 8;
        let shard_size = 4;
        let num_devices = 3;
        let q_shards: Vec<Vec<Vec<f32>>> = (0..num_devices)
            .map(|dev| make_seq(shard_size, d, 0.1 * (dev as f32 + 1.0)))
            .collect();
        let k_shards: Vec<Vec<Vec<f32>>> = (0..num_devices)
            .map(|dev| make_seq(shard_size, d, 0.2 * (dev as f32 + 1.0)))
            .collect();
        let v_shards: Vec<Vec<Vec<f32>>> = (0..num_devices)
            .map(|dev| make_seq(shard_size, d, 0.3 * (dev as f32 + 1.0)))
            .collect();
        let results =
            RingAttention::ring_forward(&q_shards, &k_shards, &v_shards).unwrap();
        assert_eq!(results.len(), num_devices);
        for (dev_id, res) in results.iter().enumerate() {
            assert_eq!(res.output.len(), shard_size);
            assert_eq!(res.output[0].len(), d);
            assert_eq!(res.transfers, num_devices - 1,
                "Device {dev_id} should have {} transfers", num_devices - 1);
            for row in &res.output {
                for &val in row {
                    assert!(val.is_finite(), "Device {dev_id} has non-finite output");
                }
            }
        }
    }
    // Tiles as large as the whole input: a single online-softmax step per row, so the
    // tolerance is tighter (1e-5).
    #[test]
    fn test_single_block() {
        let d = 4;
        let n = 3;
        let q = make_seq(n, d, 1.5);
        let k = make_seq(n, d, 1.6);
        let v = make_seq(n, d, 1.7);
        let config = FlashConfig::new(n, n).unwrap();
        let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let naive = naive_attention(&q, &k, &v, false);
        for qi in 0..n {
            for dd in 0..d {
                let diff = (flash.output[qi][dd] - naive[qi][dd]).abs();
                assert!(diff < 1e-5, "single block row={qi} col={dd} diff={diff}");
            }
        }
    }
    // 128 rows in 16x16 tiles (8 KV tiles per row); checks the maximum absolute error.
    #[test]
    fn test_large_sequence() {
        let d = 16;
        let n = 128;
        let q = make_seq(n, d, 2.0);
        let k = make_seq(n, d, 2.1);
        let v = make_seq(n, d, 2.2);
        let config = FlashConfig::new(16, 16).unwrap();
        let flash = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        let naive = naive_attention(&q, &k, &v, false);
        let mut max_diff = 0.0f32;
        for qi in 0..n {
            for dd in 0..d {
                max_diff = max_diff.max((flash.output[qi][dd] - naive[qi][dd]).abs());
            }
        }
        assert!(max_diff < 1e-3, "Large seq max diff: {max_diff}");
    }
    // The returned LSE must equal log(sum_j exp(score_ij)), recomputed here directly
    // with a max shift; uneven tiles (2,3) exercise the running-max rescaling.
    #[test]
    fn test_lse_correctness() {
        let d = 8;
        let n = 6;
        let q = make_seq(n, d, 3.0);
        let k = make_seq(n, d, 3.1);
        let v = make_seq(n, d, 3.2);
        let scale = 1.0 / (d as f32).sqrt();
        let config = FlashConfig::new(2, 3).unwrap();
        let result = FlashAttention3::forward(&q, &k, &v, &config).unwrap();
        for qi in 0..n {
            let mut scores = Vec::with_capacity(n);
            for kj in 0..n {
                let dot: f32 = (0..d).map(|dd| q[qi][dd] * k[kj][dd]).sum();
                scores.push(dot * scale);
            }
            let max_s = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let sum_exp: f32 = scores.iter().map(|&s| (s - max_s).exp()).sum();
            let expected_lse = max_s + sum_exp.ln();
            let diff = (result.lse[qi] - expected_lse).abs();
            assert!(diff < 1e-3, "LSE row={qi} flash={} expected={expected_lse} diff={diff}",
                result.lse[qi]);
        }
    }
    // Query rows 2..5 against keys 0..4: row qi allows exactly the keys kj <= qi.
    #[test]
    fn test_causal_block_mask_utility() {
        let mask = causal_block_mask(2, 5, 0, 4);
        assert_eq!(mask[0], vec![true, true, true, false]);
        assert_eq!(mask[1], vec![true, true, true, true]);
        assert_eq!(mask[2], vec![true, true, true, true]);
    }
    // An empty Q, K or V must be reported as an error rather than panicking.
    #[test]
    fn test_empty_input_errors() {
        let config = FlashConfig::default();
        let empty: Vec<Vec<f32>> = vec![];
        let q = vec![vec![1.0; 4]];
        assert!(FlashAttention3::forward(&empty, &q, &q, &config).is_err());
        assert!(FlashAttention3::forward(&q, &empty, &q, &config).is_err());
        assert!(FlashAttention3::forward(&q, &q, &empty, &config).is_err());
    }
    // Zero block sizes and dropout outside [0, 1] are rejected by the builder methods.
    #[test]
    fn test_config_validation() {
        assert!(FlashConfig::new(0, 4).is_err());
        assert!(FlashConfig::new(4, 0).is_err());
        assert!(FlashConfig::new(4, 4).is_ok());
        assert!(FlashConfig::default().with_dropout(1.5).is_err());
        assert!(FlashConfig::default().with_dropout(-0.1).is_err());
        assert!(FlashConfig::default().with_dropout(0.5).is_ok());
    }
}