use super::ops;
/// Test-only reference attention that materialises the full `n * n` weight
/// matrix instead of streaming over tiles.
///
/// The work is delegated to three helpers from `ops`, run in this order:
/// `score_matrix`, `softmax_rows`, `matmul_sv`. The result is written into
/// `output`.
///
/// NOTE(review): any `1 / sqrt(d)` scaling presumably happens inside
/// `ops::score_matrix` — confirm against that helper.
#[cfg(test)]
fn naive_attention(q: &[f32], k: &[f32], v: &[f32], n: usize, d: usize, output: &mut [f32]) {
    // Queries and keys both have `n` rows here, so the matrix is square.
    let rows = n;
    let cols = n;

    // One scratch buffer is filled with scores and then normalised in place.
    let mut weights = vec![0.0f32; rows * cols];
    ops::score_matrix(q, k, rows, cols, d, &mut weights);
    ops::softmax_rows(&mut weights, rows, cols);
    ops::matmul_sv(&weights, v, rows, cols, d, output);
}
/// Computes attention one query row at a time, streaming over the key/value
/// rows in tiles of at most `tile_size` rows.
///
/// `q`, `k`, `v` and `output` are each treated as `n` rows of `d` contiguous
/// values. For every query row the key/value rows are visited in consecutive
/// tiles; `process_tile` updates a running maximum, a running sum and an
/// unnormalised accumulator, and `normalize_row` finally writes the row into
/// `output`. Scores are scaled by `1 / sqrt(d)`.
///
/// # Panics
///
/// Panics if `n * d` overflows `usize`, if any of `q`, `k`, `v` or `output`
/// does not contain exactly `n * d` elements, or if `tile_size` is zero.
pub fn flash_attention_scalar(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    n: usize,
    d: usize,
    tile_size: usize,
    output: &mut [f32],
) {
    // Compute the expected length once, and refuse a wrapped product: in a
    // release build a silently wrapped `n * d` could make the checks below
    // pass for buffers of the wrong size.
    let expected_len = n.checked_mul(d).expect("n * d overflows usize");
    assert_eq!(
        q.len(),
        expected_len,
        "Q dimension mismatch: expected {} got {}",
        expected_len,
        q.len()
    );
    assert_eq!(
        k.len(),
        expected_len,
        "K dimension mismatch: expected {} got {}",
        expected_len,
        k.len()
    );
    assert_eq!(
        v.len(),
        expected_len,
        "V dimension mismatch: expected {} got {}",
        expected_len,
        v.len()
    );
    assert_eq!(
        output.len(),
        expected_len,
        "output dimension mismatch: expected {} got {}",
        expected_len,
        output.len()
    );
    assert!(tile_size > 0, "tile_size must be > 0");

    let scale = 1.0 / (d as f32).sqrt();

    // One accumulator buffer is reused for every query row instead of
    // allocating a fresh `Vec` per row.
    let mut acc = vec![0.0f32; d];

    for i in 0..n {
        // Per-row softmax state starts from "nothing seen yet".
        let mut running_max = f32::NEG_INFINITY;
        let mut running_sum = 0.0f32;
        acc.fill(0.0);

        for tile_start in (0..n).step_by(tile_size) {
            // The last tile may be shorter; saturate so a huge `tile_size`
            // cannot overflow the end bound.
            let tile_end = tile_start.saturating_add(tile_size).min(n);
            process_tile(
                q,
                k,
                v,
                i,
                d,
                scale,
                tile_start,
                tile_end,
                &mut running_max,
                &mut running_sum,
                &mut acc,
            );
        }

        normalize_row(&acc, running_sum, &mut output[i * d..(i + 1) * d]);
    }
}
/// Folds one tile of key/value rows into the streaming-softmax state of query
/// row `i`.
///
/// For each key row `j` in `tile_start..tile_end` the score is
/// `ops::dot(q_i, k_j) * scale`. The running maximum is raised to cover the
/// tile, the previously gathered `acc` and `running_sum` are rescaled by
/// `exp(old_max - new_max)`, and each value row is then passed to
/// `ops::weighted_accumulate` with weight `exp(score - new_max)`, which is
/// also added to `running_sum`.
///
/// `running_max`, `running_sum` and `acc` are updated in place; the caller is
/// expected to divide `acc` by `running_sum` after the last tile.
///
/// NOTE(review): `ops::weighted_accumulate(acc, w, row)` presumably performs
/// `acc += w * row` — confirm against `ops`.
// The streaming state is threaded through explicitly rather than bundled into
// a struct, hence the long parameter list.
#[allow(clippy::too_many_arguments)]
fn process_tile(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    i: usize,
    d: usize,
    scale: f32,
    tile_start: usize,
    tile_end: usize,
    running_max: &mut f32,
    running_sum: &mut f32,
    acc: &mut [f32],
) {
    let q_row = &q[i * d..(i + 1) * d];

    // Scaled scores of this query row against every key row in the tile.
    let tile_scores: Vec<f32> = (tile_start..tile_end)
        .map(|j| ops::dot(q_row, &k[j * d..(j + 1) * d]) * scale)
        .collect();

    // `f32::max` ignores NaN operands, so the maximum itself is never NaN.
    let tile_max = tile_scores
        .iter()
        .copied()
        .fold(f32::NEG_INFINITY, f32::max);
    let new_max = running_max.max(tile_max);

    // If nothing finite has been seen yet, both maxima are `-inf` and
    // `(-inf) - (-inf)` would be NaN. That NaN would enter `acc` and
    // `running_sum` and survive every later rescale (`NaN * 0 == NaN`), so the
    // whole row would come out as zeros. The state is still empty in that
    // case, so leave it unscaled.
    let correction = if new_max == f32::NEG_INFINITY {
        1.0
    } else {
        (*running_max - new_max).exp()
    };
    for a in acc.iter_mut() {
        *a *= correction;
    }
    *running_sum *= correction;

    for (&score, j) in tile_scores.iter().zip(tile_start..tile_end) {
        // A `-inf` score always contributes zero weight; spelling that out
        // avoids the `(-inf) - (-inf)` NaN when `new_max` is also `-inf`.
        // NaN scores still yield a NaN weight, exactly as before.
        let w = if score == f32::NEG_INFINITY {
            0.0
        } else {
            (score - new_max).exp()
        };
        ops::weighted_accumulate(acc, w, &v[j * d..(j + 1) * d]);
        *running_sum += w;
    }

    *running_max = new_max;
}
/// Writes the normalised accumulator into `output`.
///
/// When `running_sum` is strictly positive, each element of `acc` is divided
/// by it and stored in the matching slot of `output` (pairs are formed up to
/// the shorter of the two slices). For any other `running_sum` — zero,
/// negative or NaN — every element of `output` is set to `0.0`.
fn normalize_row(acc: &[f32], running_sum: f32, output: &mut [f32]) {
    if running_sum > 0.0 {
        output
            .iter_mut()
            .zip(acc)
            .for_each(|(dst, &numerator)| *dst = numerator / running_sum);
    } else {
        // No usable denominator: emit an all-zero row.
        output.fill(0.0);
    }
}
/// AVX2-enabled entry point for flash attention (x86_64 only).
///
/// The body contains no explicit intrinsics: it forwards every argument
/// unchanged to [`flash_attention_scalar`], inside a function compiled with
/// the `avx2` target feature enabled.
///
/// NOTE(review): presumably the intent is to let the compiler vectorise the
/// scalar code with AVX2 once it is inlined here — confirm.
///
/// # Safety
///
/// The caller must ensure the CPU executing this function supports AVX2, for
/// example by checking `is_x86_feature_detected!("avx2")` first.
///
/// # Panics
///
/// Panics under the same conditions as [`flash_attention_scalar`].
#[cfg(target_arch = "x86_64")]
#[target_feature(enable = "avx2")]
pub unsafe fn flash_attention_avx2(
    q: &[f32],
    k: &[f32],
    v: &[f32],
    n: usize,
    d: usize,
    tile_size: usize,
    output: &mut [f32],
) {
    flash_attention_scalar(q, k, v, n, d, tile_size, output);
}
// Textually splices the contents of `flash_attention_ptx.rs` into this module
// at this point, so its items share this module's scope.
// NOTE(review): presumably the PTX (GPU) variant of the kernel — the included
// file is not visible from here; confirm.
include!("flash_attention_ptx.rs");
// The unit tests live in `flash_attention_tests.rs` and are textually included
// into this module, which is compiled only for test builds.
#[cfg(test)]
mod tests {
    include!("flash_attention_tests.rs");
}