use super::config::FlashAttentionConfig;
use super::interop::has_cubecl_cuda_support;
use crate::error::{Result as UnslothResult, UnslothError};
use candle_core::Tensor;
#[cfg(feature = "cuda")]
use cubecl::prelude::*;
#[cfg(feature = "cuda")]
use cubecl_cuda::CudaRuntime;
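/// Maximum threads per CUDA block supported by these kernels; also the size
/// of the shared-memory reduction scratch.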
pub const MAX_BLOCK_SIZE: u32 = 1024;
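/// Number of threads in a CUDA warp.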
pub const WARP_SIZE: u32 = 32;
/// Tiling parameters for the phase-2 tiled kernel.
#[cfg(feature = "cuda")]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct TileConfig {
pub tile_size: u32,
pub head_dim: u32,
pub seq_len: u32,
pub num_kv_tiles: u32,
pub causal: bool,
}
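/// Rounds `n` up to the next power of two via bit smearing (returns 1 for 0);
/// equivalent to `u32::next_power_of_two`.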
#[inline]
#[allow(dead_code)]
fn next_power_of_two(n: u32) -> u32 {
if n == 0 {
return 1;
}
let mut v = n - 1;
v |= v >> 1;
v |= v >> 2;
v |= v >> 4;
v |= v >> 8;
v |= v >> 16;
v + 1
}
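/// Non-causal flash attention: one cube per (batch * head, query row), one
/// thread per element of the head dimension. Streams over every KV position
/// with a numerically stable online softmax:
///   m' = max(m, s),
///   l' = l * exp(m - m') + exp(s - m'),
///   o' = (l * exp(m - m') * o + exp(s - m') * v) / l'.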
#[cfg(feature = "cuda")]
#[cube(launch)]
fn flash_attention_tile<F: Float + CubeElement>(
    q: &Array<F>,
    k: &Array<F>,
    v: &Array<F>,
    out: &mut Array<F>,
    scale: F,
    seq_len_val: u32,
    head_dim_val: u32,
    block_size_val: u32,
) {
    let batch_head_idx = CUBE_POS_X;
    let q_row_idx = CUBE_POS_Y;
    let tid = UNIT_POS_X;
    let tid_usize = tid as usize;
    let head_stride = (seq_len_val as usize) * (head_dim_val as usize);
    let base_offset = (batch_head_idx as usize) * head_stride;
    // Threads beyond head_dim contribute zeros to the reduction but still hit the barriers.
    let is_active = tid_usize < (head_dim_val as usize);
    let mut running_max = F::new(-1e30);
    let mut running_sum = F::new(0.0);
    let mut running_out = F::new(0.0);
let q_val = if is_active {
let q_offset = base_offset + ((q_row_idx as usize) * (head_dim_val as usize) + tid_usize);
q[q_offset]
} else {
F::new(0.0)
};
    // Shared scratch for the block-wide dot-product reduction; 1024 = MAX_BLOCK_SIZE.
    let mut score_tile = SharedMemory::<F>::new(1024usize);
    // Stream over every KV position, folding each score into the online softmax.
    for kv_idx in 0u32..seq_len_val {
let kv_idx_usize = kv_idx as usize;
let score_contrib = if is_active {
let k_offset = base_offset + (kv_idx_usize * (head_dim_val as usize) + tid_usize);
let k_val = k[k_offset];
q_val * k_val
} else {
F::new(0.0)
};
score_tile[tid_usize] = score_contrib;
sync_cube();
        // Tree-reduce the per-dimension products across the block;
        // score_tile[0] ends up holding the full dot product q · k.
        let mut stride = (block_size_val / 2) as usize;
while stride > 0 {
if tid_usize < stride {
let partner_idx = tid_usize + stride;
if partner_idx < (block_size_val as usize) {
score_tile[tid_usize] = score_tile[tid_usize] + score_tile[partner_idx];
}
}
sync_cube();
stride = stride / 2;
}
        // Thread 0 applies the softmax scale; all threads re-read after the
        // barrier, avoiding a shared-memory read/write race on score_tile[0].
        if tid == 0 {
            score_tile[0] = score_tile[0] * scale;
        }
        sync_cube();
let attn_score = score_tile[0];
        // Online-softmax update: fold this score into the running max/sum,
        // rescaling previous contributions by exp(old_max - new_max).
        let new_max = F::max(running_max, attn_score);
let exp_old = F::exp(running_max - new_max);
let exp_new = F::exp(attn_score - new_max);
let new_sum = exp_old * running_sum + exp_new;
if is_active {
let v_offset = base_offset + (kv_idx_usize * (head_dim_val as usize) + tid_usize);
let v_val = v[v_offset];
running_out = (exp_old * running_sum * running_out + exp_new * v_val) / new_sum;
}
running_max = new_max;
running_sum = new_sum;
}
if is_active {
let out_offset = base_offset + ((q_row_idx as usize) * (head_dim_val as usize) + tid_usize);
out[out_offset] = running_out;
}
}
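/// Causal variant of `flash_attention_tile`: same layout and online-softmax
/// loop, but query row `q` only visits KV positions `0..=q`, so future
/// positions never enter the softmax.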
#[cfg(feature = "cuda")]
#[cube(launch)]
fn flash_attention_causal<F: Float + CubeElement>(
q: &Array<F>,
k: &Array<F>,
v: &Array<F>,
out: &mut Array<F>,
scale: F,
seq_len_val: u32,
head_dim_val: u32,
block_size_val: u32,
) {
let batch_head_idx = CUBE_POS_X;
let q_row_idx = CUBE_POS_Y;
let tid = UNIT_POS_X;
let tid_usize = tid as usize;
let head_stride = (seq_len_val as usize) * (head_dim_val as usize);
let base_offset = (batch_head_idx as usize) * head_stride;
let is_active = tid_usize < (head_dim_val as usize);
let mut running_max = F::new(-1e30);
let mut running_sum = F::new(0.0);
let mut running_out = F::new(0.0);
let q_val = if is_active {
let q_offset = base_offset + ((q_row_idx as usize) * (head_dim_val as usize) + tid_usize);
q[q_offset]
} else {
F::new(0.0)
};
    // Shared scratch for the block-wide dot-product reduction; 1024 = MAX_BLOCK_SIZE.
    let mut score_tile = SharedMemory::<F>::new(1024usize);
    // Causal: query row q attends only to KV positions 0..=q.
    let max_kv_idx = q_row_idx + 1;
    for kv_idx in 0u32..max_kv_idx {
let kv_idx_usize = kv_idx as usize;
let score_contrib = if is_active {
let k_offset = base_offset + (kv_idx_usize * (head_dim_val as usize) + tid_usize);
let k_val = k[k_offset];
q_val * k_val
} else {
F::new(0.0)
};
score_tile[tid_usize] = score_contrib;
sync_cube();
let mut stride = (block_size_val / 2) as usize;
while stride > 0 {
if tid_usize < stride {
let partner_idx = tid_usize + stride;
if partner_idx < (block_size_val as usize) {
score_tile[tid_usize] = score_tile[tid_usize] + score_tile[partner_idx];
}
}
sync_cube();
stride = stride / 2;
}
        // Thread 0 applies the softmax scale; all threads re-read after the
        // barrier, avoiding a shared-memory read/write race on score_tile[0].
        if tid == 0 {
            score_tile[0] = score_tile[0] * scale;
        }
        sync_cube();
let attn_score = score_tile[0];
let new_max = F::max(running_max, attn_score);
let exp_old = F::exp(running_max - new_max);
let exp_new = F::exp(attn_score - new_max);
let new_sum = exp_old * running_sum + exp_new;
if is_active {
let v_offset = base_offset + (kv_idx_usize * (head_dim_val as usize) + tid_usize);
let v_val = v[v_offset];
running_out = (exp_old * running_sum * running_out + exp_new * v_val) / new_sum;
}
running_max = new_max;
running_sum = new_sum;
}
if is_active {
let out_offset = base_offset + ((q_row_idx as usize) * (head_dim_val as usize) + tid_usize);
out[out_offset] = running_out;
}
}
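/// Phase-2 tiled kernel (feature-gated, currently unused): one cube per
/// (batch * head, query tile), one thread per query row within the tile.
/// K/V tiles are staged through shared memory before the online-softmax pass.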
#[cfg(all(feature = "cuda", feature = "_phase2_tiled_kernel"))]
#[cube(launch)]
#[allow(dead_code)]
fn flash_attention_tiled<F: Float + CubeElement>(
    q: &Array<F>,
    k: &Array<F>,
    v: &Array<F>,
    out: &mut Array<F>,
    scale: F,
    #[comptime] tile_size: u32,
#[comptime] head_dim: u32,
#[comptime] seq_len: u32,
#[comptime] num_kv_tiles: u32,
#[comptime] causal: bool,
) {
let batch_head_idx = CUBE_POS_X;
let q_tile_idx = CUBE_POS_Y;
let thread_in_tile = UNIT_POS_X;
let thread_in_tile_usize = thread_in_tile as usize;
    let q_row_global = (q_tile_idx as usize) * (tile_size as usize) + thread_in_tile_usize;
    // No early exit for out-of-range rows: every thread must reach the
    // sync_cube() barriers below and help stage K/V tiles, so out-of-range
    // rows are masked at the global loads and stores instead.
    let row_in_bounds = q_row_global < (seq_len as usize);
let head_stride = (seq_len as usize) * (head_dim as usize);
let base_offset = (batch_head_idx as usize) * head_stride;
let mut q_tile = SharedMemory::<F>::new((tile_size as usize) * (head_dim as usize));
    for dim_idx in 0u32..(head_dim) {
        let dim_idx_usize = dim_idx as usize;
        let tile_offset = thread_in_tile_usize * (head_dim as usize) + dim_idx_usize;
        q_tile[tile_offset] = if row_in_bounds {
            let q_offset = base_offset + (q_row_global * (head_dim as usize) + dim_idx_usize);
            q[q_offset]
        } else {
            F::new(0.0)
        };
    }
sync_cube();
let mut running_max = F::new(-1e30);
let mut running_sum = F::new(0.0);
let mut out_acc = SharedMemory::<F>::new((tile_size as usize) * (head_dim as usize));
for dim_idx in 0u32..(head_dim) {
let dim_idx_usize = dim_idx as usize;
out_acc[thread_in_tile_usize * (head_dim as usize) + dim_idx_usize] = F::new(0.0);
}
for kv_tile_idx in 0u32..(num_kv_tiles) {
let kv_start = ((kv_tile_idx as usize) * (tile_size as usize));
let kv_end = if kv_start + (tile_size as usize) < (seq_len as usize) {
kv_start + (tile_size as usize)
} else {
seq_len as usize
};
let kv_tile_actual_size = kv_end - kv_start;
let mut k_tile = SharedMemory::<F>::new((tile_size as usize) * (head_dim as usize));
let mut v_tile = SharedMemory::<F>::new((tile_size as usize) * (head_dim as usize));
        // Cooperative load: thread t stages K/V row t of this tile into
        // shared memory (K and V share the same global offset layout).
        if thread_in_tile_usize < kv_tile_actual_size {
            let kv_row_global = kv_start + thread_in_tile_usize;
            for dim_idx in 0u32..(head_dim) {
                let dim_idx_usize = dim_idx as usize;
                let kv_offset =
                    base_offset + (kv_row_global * (head_dim as usize) + dim_idx_usize);
                let tile_offset = thread_in_tile_usize * (head_dim as usize) + dim_idx_usize;
                k_tile[tile_offset] = k[kv_offset];
                v_tile[tile_offset] = v[kv_offset];
            }
        }
sync_cube();
for local_kv_idx in 0usize..(kv_tile_actual_size) {
let kv_row_global = kv_start + local_kv_idx;
            let should_process = if causal {
                row_in_bounds && q_row_global >= kv_row_global
            } else {
                row_in_bounds
            };
if should_process {
let mut score = F::new(0.0);
for dim_idx in 0u32..(head_dim) {
let dim_idx_usize = dim_idx as usize;
let q_val = q_tile[thread_in_tile_usize * (head_dim as usize) + dim_idx_usize];
let k_val = k_tile[local_kv_idx * (head_dim as usize) + dim_idx_usize];
score = score + q_val * k_val;
}
score = score * scale;
let new_max = F::max(running_max, score);
let exp_old = F::exp(running_max - new_max);
let exp_new = F::exp(score - new_max);
let new_sum = exp_old * running_sum + exp_new;
for dim_idx in 0u32..(head_dim) {
let dim_idx_usize = dim_idx as usize;
let out_offset = thread_in_tile_usize * (head_dim as usize) + dim_idx_usize;
let old_out = out_acc[out_offset];
let v_val = v_tile[local_kv_idx * (head_dim as usize) + dim_idx_usize];
let corrected_old = exp_old * running_sum * old_out;
let new_contrib = exp_new * v_val;
out_acc[out_offset] = (corrected_old + new_contrib) / new_sum;
}
running_max = new_max;
running_sum = new_sum;
}
}
sync_cube();
}
    // Only in-range query rows write results back to global memory.
    if row_in_bounds {
        for dim_idx in 0u32..(head_dim) {
            let dim_idx_usize = dim_idx as usize;
            let out_offset = base_offset + (q_row_global * (head_dim as usize) + dim_idx_usize);
            let tile_offset = thread_in_tile_usize * (head_dim as usize) + dim_idx_usize;
            out[out_offset] = out_acc[tile_offset];
        }
    }
}
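/// Scaled dot-product attention over `[batch, heads, seq, head_dim]` tensors.
///
/// Dispatches to the CubeCL CUDA kernels when support is detected and no
/// explicit additive `mask` is supplied (the kernels only implement the
/// built-in causal mask); otherwise computes attention via the Candle fallback.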
pub fn flash_attention_kernel(
q: &Tensor,
k: &Tensor,
v: &Tensor,
scale: f64,
mask: Option<&Tensor>,
config: &FlashAttentionConfig,
) -> UnslothResult<Tensor> {
validate_attention_inputs(q, k, v)?;
if !has_cubecl_cuda_support() {
tracing::debug!("CubeCL CUDA not available, using fallback implementation");
return fallback_attention(q, k, v, scale, mask, config);
}
#[cfg(feature = "cuda")]
{
match launch_cubecl_attention(q, k, v, scale, config) {
Ok(output) => return Ok(output),
Err(e) => {
tracing::warn!("CubeCL kernel launch failed: {}, using fallback", e);
}
}
}
fallback_attention(q, k, v, scale, mask, config)
}
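/// Uploads Q/K/V to the CubeCL CUDA device, launches the causal or non-causal
/// kernel with one cube per (batch * head, query row), and copies the result
/// back into a Candle tensor on the original device.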
#[cfg(feature = "cuda")]
fn launch_cubecl_attention(
q: &Tensor,
k: &Tensor,
v: &Tensor,
scale: f64,
config: &FlashAttentionConfig,
) -> UnslothResult<Tensor> {
use super::interop::{candle_to_cubecl_handle, cubecl_to_candle_tensor};
let dims = q.dims();
let batch = dims[0];
let num_heads = dims[1];
let seq_len = dims[2];
let head_dim = dims[3];
tracing::debug!(
"Launching CubeCL Flash Attention: batch={}, heads={}, seq={}, dim={}",
batch,
num_heads,
seq_len,
head_dim
);
if head_dim > MAX_BLOCK_SIZE as usize {
return Err(UnslothError::InvalidConfig(format!(
"head_dim={} exceeds maximum supported size of {}. \
Consider using a model with smaller head dimensions.",
head_dim, MAX_BLOCK_SIZE
)));
}
let (q_bytes, _, _) = candle_to_cubecl_handle(q)?;
let (k_bytes, _, _) = candle_to_cubecl_handle(k)?;
let (v_bytes, _, _) = candle_to_cubecl_handle(v)?;
let num_elements = batch * num_heads * seq_len * head_dim;
let device = cubecl_cuda::CudaDevice::new(0);
let client = CudaRuntime::client(&device);
let q_handle = client.create(cubecl::bytes::Bytes::from_bytes_vec(q_bytes));
let k_handle = client.create(cubecl::bytes::Bytes::from_bytes_vec(k_bytes));
let v_handle = client.create(cubecl::bytes::Bytes::from_bytes_vec(v_bytes));
let out_handle = client.empty(num_elements * std::mem::size_of::<f32>());
let cube_count = CubeCount::Static((batch * num_heads) as u32, seq_len as u32, 1);
let block_size = next_power_of_two(head_dim as u32).min(MAX_BLOCK_SIZE);
let cube_dim = CubeDim::new(&client, block_size as usize);
let scale_f32 = scale as f32;
unsafe {
if config.causal_mask {
flash_attention_causal::launch::<f32, CudaRuntime>(
&client,
cube_count,
cube_dim,
ArrayArg::from_raw_parts::<f32>(&q_handle, num_elements, 1),
ArrayArg::from_raw_parts::<f32>(&k_handle, num_elements, 1),
ArrayArg::from_raw_parts::<f32>(&v_handle, num_elements, 1),
ArrayArg::from_raw_parts::<f32>(&out_handle, num_elements, 1),
ScalarArg::new(scale_f32),
ScalarArg::new(seq_len as u32),
ScalarArg::new(head_dim as u32),
ScalarArg::new(block_size),
)
.map_err(|e| {
UnslothError::Kernel(format!("flash_attention_causal kernel launch failed: {e}"))
})?;
} else {
flash_attention_tile::launch::<f32, CudaRuntime>(
&client,
cube_count,
cube_dim,
ArrayArg::from_raw_parts::<f32>(&q_handle, num_elements, 1),
ArrayArg::from_raw_parts::<f32>(&k_handle, num_elements, 1),
ArrayArg::from_raw_parts::<f32>(&v_handle, num_elements, 1),
ArrayArg::from_raw_parts::<f32>(&out_handle, num_elements, 1),
ScalarArg::new(scale_f32),
ScalarArg::new(seq_len as u32),
ScalarArg::new(head_dim as u32),
ScalarArg::new(block_size),
)
.map_err(|e| {
UnslothError::Kernel(format!("flash_attention_tile kernel launch failed: {e}"))
})?;
}
}
let output_bytes = client.read_one(out_handle);
cubecl_to_candle_tensor(
&output_bytes,
&[batch, num_heads, seq_len, head_dim],
q.device(),
)
}
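/// Validates that Q, K, V are 4D `[batch, heads, seq, head_dim]` tensors with
/// matching batch sizes, head counts, and head dimensions (K and V must match
/// in shape exactly).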
fn validate_attention_inputs(q: &Tensor, k: &Tensor, v: &Tensor) -> UnslothResult<()> {
let q_dims = q.dims();
let k_dims = k.dims();
let v_dims = v.dims();
if q_dims.len() != 4 || k_dims.len() != 4 || v_dims.len() != 4 {
return Err(UnslothError::InvalidConfig(format!(
"Expected 4D tensors [batch, heads, seq, dim], got Q: {q_dims:?}, K: {k_dims:?}, V: {v_dims:?}"
)));
}
    if q_dims[0] != k_dims[0] || q_dims[0] != v_dims[0] {
        return Err(UnslothError::InvalidConfig(format!(
            "Batch size mismatch: Q={}, K={}, V={}",
            q_dims[0], k_dims[0], v_dims[0]
        )));
    }
    if q_dims[1] != k_dims[1] || q_dims[1] != v_dims[1] {
        return Err(UnslothError::InvalidConfig(format!(
            "Head count mismatch: Q={}, K={}, V={}",
            q_dims[1], k_dims[1], v_dims[1]
        )));
    }
if q_dims[3] != k_dims[3] || q_dims[3] != v_dims[3] {
return Err(UnslothError::InvalidConfig(format!(
"Head dimension mismatch: Q={}, K={}, V={}",
q_dims[3], k_dims[3], v_dims[3]
)));
}
if k_dims != v_dims {
return Err(UnslothError::InvalidConfig(format!(
"K and V shape mismatch: K={k_dims:?}, V={v_dims:?}"
)));
}
Ok(())
}
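/// Portable reference path: softmax(Q Kᵀ · scale + mask) V on the tensors'
/// native device, applying the config's causal mask and/or the caller's mask.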
fn fallback_attention(
q: &Tensor,
k: &Tensor,
v: &Tensor,
scale: f64,
mask: Option<&Tensor>,
config: &FlashAttentionConfig,
) -> UnslothResult<Tensor> {
let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
let scores = (scores * scale)?;
let scores = if config.causal_mask {
let seq_len = q.dims()[2];
let causal_mask = create_causal_mask_tensor(seq_len, q.device())?;
scores.broadcast_add(&causal_mask)?
} else {
scores
};
let scores = match mask {
Some(m) => scores.broadcast_add(m)?,
None => scores,
};
let attn_weights = candle_nn::ops::softmax(&scores, 3)?;
let output = attn_weights.matmul(v)?;
Ok(output)
}
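/// Builds a `[1, 1, seq, seq]` additive mask with `-inf` strictly above the
/// diagonal, so softmax assigns zero weight to future positions.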
fn create_causal_mask_tensor(
seq_len: usize,
device: &candle_core::Device,
) -> UnslothResult<Tensor> {
let mut mask_data = vec![0.0f32; seq_len * seq_len];
for i in 0..seq_len {
for j in 0..seq_len {
if j > i {
mask_data[i * seq_len + j] = f32::NEG_INFINITY;
}
}
}
let mask = Tensor::from_vec(mask_data, (1, 1, seq_len, seq_len), device)?;
Ok(mask)
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::Device;
#[test]
fn test_validate_attention_inputs_valid() {
let device = Device::Cpu;
let q = Tensor::zeros((2, 4, 8, 64), candle_core::DType::F32, &device).unwrap();
let k = Tensor::zeros((2, 4, 8, 64), candle_core::DType::F32, &device).unwrap();
let v = Tensor::zeros((2, 4, 8, 64), candle_core::DType::F32, &device).unwrap();
assert!(validate_attention_inputs(&q, &k, &v).is_ok());
}
#[test]
fn test_validate_attention_inputs_wrong_dims() {
let device = Device::Cpu;
        let q = Tensor::zeros((2, 8, 64), candle_core::DType::F32, &device).unwrap();
        let k = Tensor::zeros((2, 4, 8, 64), candle_core::DType::F32, &device).unwrap();
let v = Tensor::zeros((2, 4, 8, 64), candle_core::DType::F32, &device).unwrap();
assert!(validate_attention_inputs(&q, &k, &v).is_err());
}
#[test]
fn test_validate_attention_inputs_batch_mismatch() {
let device = Device::Cpu;
let q = Tensor::zeros((2, 4, 8, 64), candle_core::DType::F32, &device).unwrap();
        let k = Tensor::zeros((3, 4, 8, 64), candle_core::DType::F32, &device).unwrap();
        let v = Tensor::zeros((3, 4, 8, 64), candle_core::DType::F32, &device).unwrap();
assert!(validate_attention_inputs(&q, &k, &v).is_err());
}
#[test]
fn test_flash_attention_kernel_shape() {
let device = Device::Cpu;
let batch = 2;
let num_heads = 4;
let seq_len = 8;
let head_dim = 64;
let q = Tensor::randn(0.0f32, 1.0, (batch, num_heads, seq_len, head_dim), &device).unwrap();
let k = Tensor::randn(0.0f32, 1.0, (batch, num_heads, seq_len, head_dim), &device).unwrap();
let v = Tensor::randn(0.0f32, 1.0, (batch, num_heads, seq_len, head_dim), &device).unwrap();
let scale = 1.0 / (head_dim as f64).sqrt();
let config = FlashAttentionConfig::default();
let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
assert_eq!(output.dims(), &[batch, num_heads, seq_len, head_dim]);
}
#[test]
fn test_flash_attention_kernel_numerical_stability() {
let device = Device::Cpu;
let q = Tensor::randn(0.0f32, 10.0, (1, 2, 4, 64), &device).unwrap();
let k = Tensor::randn(0.0f32, 10.0, (1, 2, 4, 64), &device).unwrap();
let v = Tensor::randn(0.0f32, 10.0, (1, 2, 4, 64), &device).unwrap();
        let scale = 1.0 / 8.0;
        let config = FlashAttentionConfig::default();
let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
let values: Vec<f32> = output.flatten_all().unwrap().to_vec1().unwrap();
for val in values {
assert!(!val.is_nan(), "Output contains NaN");
assert!(!val.is_infinite(), "Output contains Inf");
}
}
#[test]
fn test_flash_attention_with_config() {
let device = Device::Cpu;
let q = Tensor::randn(0.0f32, 1.0, (1, 2, 16, 64), &device).unwrap();
let k = Tensor::randn(0.0f32, 1.0, (1, 2, 16, 64), &device).unwrap();
let v = Tensor::randn(0.0f32, 1.0, (1, 2, 16, 64), &device).unwrap();
let scale = 1.0 / 8.0;
let config = FlashAttentionConfig::for_rtx_5080().with_causal_mask();
let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
assert_eq!(output.dims(), &[1, 2, 16, 64]);
}
fn reference_attention(q: &Tensor, k: &Tensor, v: &Tensor, scale: f64) -> Tensor {
let scores = q
.matmul(&k.transpose(2, 3).unwrap().contiguous().unwrap())
.unwrap();
let scores = (scores * scale).unwrap();
let attn_weights = candle_nn::ops::softmax(&scores, 3).unwrap();
attn_weights.matmul(v).unwrap()
}
fn mean_absolute_error(a: &Tensor, b: &Tensor) -> f32 {
let diff = (a - b).unwrap().abs().unwrap();
let mean = diff.mean_all().unwrap();
mean.to_scalar::<f32>().unwrap()
}
#[test]
fn test_numerical_equivalence_batch2_heads4_seq8_dim64() {
let device = Device::Cpu;
let (batch, heads, seq, dim) = (2, 4, 8, 64);
let scale = 1.0 / (dim as f64).sqrt();
let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let config = FlashAttentionConfig::default();
let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
let reference = reference_attention(&q, &k, &v, scale);
let mae = mean_absolute_error(&output, &reference);
assert!(
mae < 1e-5,
"MAE {} exceeds tolerance 1e-5 for batch={}, heads={}, seq={}, dim={}",
mae,
batch,
heads,
seq,
dim
);
}
#[test]
fn test_numerical_equivalence_batch2_heads4_seq16_dim64() {
let device = Device::Cpu;
let (batch, heads, seq, dim) = (2, 4, 16, 64);
let scale = 1.0 / (dim as f64).sqrt();
let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let config = FlashAttentionConfig::default();
let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
let reference = reference_attention(&q, &k, &v, scale);
let mae = mean_absolute_error(&output, &reference);
assert!(mae < 1e-5, "MAE {} exceeds tolerance 1e-5", mae);
}
#[test]
fn test_numerical_equivalence_batch2_heads4_seq32_dim64() {
let device = Device::Cpu;
let (batch, heads, seq, dim) = (2, 4, 32, 64);
let scale = 1.0 / (dim as f64).sqrt();
let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let config = FlashAttentionConfig::default();
let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
let reference = reference_attention(&q, &k, &v, scale);
let mae = mean_absolute_error(&output, &reference);
assert!(mae < 1e-5, "MAE {} exceeds tolerance 1e-5", mae);
}
#[test]
fn test_numerical_equivalence_batch2_heads4_seq64_dim64() {
let device = Device::Cpu;
let (batch, heads, seq, dim) = (2, 4, 64, 64);
let scale = 1.0 / (dim as f64).sqrt();
let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let config = FlashAttentionConfig::default();
let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
let reference = reference_attention(&q, &k, &v, scale);
let mae = mean_absolute_error(&output, &reference);
assert!(mae < 1e-5, "MAE {} exceeds tolerance 1e-5", mae);
}
#[test]
fn test_numerical_equivalence_batch1_heads1_seq8_dim64() {
let device = Device::Cpu;
let (batch, heads, seq, dim) = (1, 1, 8, 64);
let scale = 1.0 / (dim as f64).sqrt();
let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let config = FlashAttentionConfig::default();
let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
let reference = reference_attention(&q, &k, &v, scale);
let mae = mean_absolute_error(&output, &reference);
assert!(mae < 1e-5, "MAE {} exceeds tolerance 1e-5", mae);
}
#[test]
fn test_numerical_equivalence_with_different_configs() {
let device = Device::Cpu;
let (batch, heads, seq, dim) = (2, 4, 16, 64);
let scale = 1.0 / (dim as f64).sqrt();
let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let configs = [
FlashAttentionConfig::default(),
FlashAttentionConfig::for_rtx_5080(),
FlashAttentionConfig::for_rtx_3090_ti(),
];
let reference = reference_attention(&q, &k, &v, scale);
for config in configs {
let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
let mae = mean_absolute_error(&output, &reference);
assert!(
mae < 1e-5,
"MAE {} exceeds tolerance for config {:?}",
mae,
config
);
}
}
#[test]
fn test_determinism_multiple_runs() {
let device = Device::Cpu;
let (batch, heads, seq, dim) = (2, 4, 8, 64);
let scale = 1.0 / (dim as f64).sqrt();
let data: Vec<f32> = (0..(batch * heads * seq * dim))
.map(|i| (i as f32 * 0.001).sin())
.collect();
let q = Tensor::from_vec(data.clone(), (batch, heads, seq, dim), &device).unwrap();
let k = Tensor::from_vec(data.clone(), (batch, heads, seq, dim), &device).unwrap();
let v = Tensor::from_vec(data, (batch, heads, seq, dim), &device).unwrap();
let config = FlashAttentionConfig::default();
let output1 = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
let output2 = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
let mae = mean_absolute_error(&output1, &output2);
assert!(mae < 1e-10, "Non-deterministic output: MAE = {}", mae);
}
#[test]
fn test_identity_attention_pattern() {
let device = Device::Cpu;
let (batch, heads, seq, dim) = (1, 1, 4, 64);
let scale = 1.0 / (dim as f64).sqrt();
let mut data = vec![0.0f32; seq * dim];
for i in 0..seq {
for j in 0..dim {
data[i * dim + j] = if j == i { 1.0 } else { 0.0 };
}
}
let q = Tensor::from_vec(data.clone(), (batch, heads, seq, dim), &device).unwrap();
let k = q.clone();
let v = q.clone();
let config = FlashAttentionConfig::default();
let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
let mae = mean_absolute_error(&output, &v);
assert!(mae < 0.5, "Identity pattern MAE {} too high", mae);
}
#[test]
fn test_causal_masking_basic() {
let device = Device::Cpu;
let (batch, heads, seq, dim) = (1, 1, 4, 64);
let scale = 1.0 / (dim as f64).sqrt();
let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let config = FlashAttentionConfig::default().with_causal_mask();
let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
assert_eq!(output.dims(), &[batch, heads, seq, dim]);
let values: Vec<f32> = output.flatten_all().unwrap().to_vec1().unwrap();
for val in values {
assert!(!val.is_nan() && !val.is_infinite());
}
}
#[test]
fn test_causal_vs_non_causal_difference() {
let device = Device::Cpu;
let (batch, heads, seq, dim) = (1, 2, 8, 64);
let scale = 1.0 / (dim as f64).sqrt();
let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let config_normal = FlashAttentionConfig::default();
let config_causal = FlashAttentionConfig::default().with_causal_mask();
let output_normal =
flash_attention_kernel(&q, &k, &v, scale, None, &config_normal).unwrap();
let output_causal =
flash_attention_kernel(&q, &k, &v, scale, None, &config_causal).unwrap();
let mae = mean_absolute_error(&output_normal, &output_causal);
assert!(mae > 1e-4, "Causal and non-causal outputs are too similar");
}
#[test]
fn test_causal_masking_numerical_equivalence() {
let device = Device::Cpu;
let (batch, heads, seq, dim) = (2, 4, 16, 64);
let scale = 1.0 / (dim as f64).sqrt();
let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device).unwrap();
let causal_mask = create_causal_mask(seq, &device);
let config = FlashAttentionConfig::default().with_causal_mask();
let output = flash_attention_kernel(&q, &k, &v, scale, None, &config).unwrap();
let reference = reference_attention_with_mask(&q, &k, &v, scale, Some(&causal_mask));
let mae = mean_absolute_error(&output, &reference);
assert!(mae < 1e-5, "Causal MAE {} exceeds tolerance 1e-5", mae);
}
fn create_causal_mask(seq_len: usize, device: &Device) -> Tensor {
let mut mask_data = vec![0.0f32; seq_len * seq_len];
for i in 0..seq_len {
for j in 0..seq_len {
if j > i {
mask_data[i * seq_len + j] = f32::NEG_INFINITY;
}
}
}
        Tensor::from_vec(mask_data, (1, 1, seq_len, seq_len), device).unwrap()
}
fn reference_attention_with_mask(
q: &Tensor,
k: &Tensor,
v: &Tensor,
scale: f64,
mask: Option<&Tensor>,
) -> Tensor {
let scores = q
.matmul(&k.transpose(2, 3).unwrap().contiguous().unwrap())
.unwrap();
let scores = (scores * scale).unwrap();
let scores = if let Some(m) = mask {
scores.broadcast_add(m).unwrap()
} else {
scores
};
let attn_weights = candle_nn::ops::softmax(&scores, 3).unwrap();
attn_weights.matmul(v).unwrap()
}
}