use crate::error::Result;
use candle_core::Tensor;
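
/// Reports whether a CubeCL-backed CUDA device is usable at runtime.
///
/// Returns `true` only when the crate is built with the `cuda` feature and
/// candle can initialise CUDA device 0; in every other case it returns `false`.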
#[must_use]
pub fn has_cubecl_support() -> bool {
    #[cfg(feature = "cuda")]
    {
        use candle_core::Device;
        matches!(Device::cuda_if_available(0), Ok(Device::Cuda(_)))
    }
    #[cfg(not(feature = "cuda"))]
    {
        false
    }
}
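
/// Computes scaled dot-product attention, dispatching to the CubeCL CUDA
/// kernel when available and falling back to a naive candle implementation
/// otherwise.
///
/// `q`, `k` and `v` must be 4D tensors shaped `[batch, num_heads, seq_len,
/// head_dim]`; `scale` is typically `1.0 / sqrt(head_dim)`, and `mask`, if
/// given, is broadcast-added to the attention scores before the softmax.
///
/// Minimal usage sketch (the `unsloth_rs` crate path below is an assumption;
/// adjust it to this crate's actual name and module layout):
///
/// ```ignore
/// use candle_core::{Device, Tensor};
/// use unsloth_rs::kernels::flash_attention_cubecl; // assumed path
///
/// let device = Device::Cpu;
/// let (batch, heads, seq, dim) = (1, 4, 16, 64);
/// let q = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device)?;
/// let k = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device)?;
/// let v = Tensor::randn(0.0f32, 1.0, (batch, heads, seq, dim), &device)?;
/// let scale = 1.0 / (dim as f64).sqrt();
/// let out = flash_attention_cubecl(&q, &k, &v, scale, None)?;
/// assert_eq!(out.dims(), &[batch, heads, seq, dim]);
/// ```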
pub fn flash_attention_cubecl(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    scale: f64,
    mask: Option<&Tensor>,
) -> Result<Tensor> {
    let q_shape = q.dims();
    let k_shape = k.dims();
    let v_shape = v.dims();
    if q_shape.len() != 4 || k_shape.len() != 4 || v_shape.len() != 4 {
        return Err(crate::error::UnslothError::InvalidConfig(format!(
            "Expected 4D tensors, got Q: {q_shape:?}, K: {k_shape:?}, V: {v_shape:?}"
        )));
    }
    let (batch, num_heads, seq_len, head_dim) = (q_shape[0], q_shape[1], q_shape[2], q_shape[3]);
    let num_kv_heads = k_shape[1];
    tracing::debug!(
        "Flash Attention CubeCL: batch={}, heads={}/{}, seq={}, dim={}",
        batch,
        num_heads,
        num_kv_heads,
        seq_len,
        head_dim
    );
    // Prefer the CubeCL CUDA kernel when the inputs live on a CUDA device.
    #[cfg(feature = "cuda")]
    {
        if q.device().is_cuda() && has_cubecl_support() {
            use super::cubecl::{flash_attention_kernel, FlashAttentionConfig};
            let config = FlashAttentionConfig::default().with_head_dim(head_dim as u32);
            return flash_attention_kernel(q, k, v, scale, mask, &config);
        }
    }
    // Otherwise fall back to the naive candle implementation.
    flash_attention_fallback(q, k, v, scale, mask)
}
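
/// Naive O(seq_len²) attention used when no CUDA/CubeCL path is available:
/// materialises the full score matrix, applies the optional additive mask,
/// and takes the softmax over the key dimension.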
fn flash_attention_fallback(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    scale: f64,
    mask: Option<&Tensor>,
) -> Result<Tensor> {
    // Standard scaled dot-product attention: softmax(Q K^T * scale + mask) V.
    let scores = q.matmul(&k.transpose(2, 3)?.contiguous()?)?;
    let scores = (scores * scale)?;
    let scores = match mask {
        Some(m) => scores.broadcast_add(m)?,
        None => scores,
    };
    let attn_weights = candle_nn::ops::softmax(&scores, 3)?;
    let output = attn_weights.matmul(v)?;
    Ok(output)
}
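
/// Estimates the VRAM (in bytes) needed by the flash-attention kernel for f32
/// tensors: the Q/K/V activations, the output tensor, and a small per-tile
/// softmax workspace.
///
/// Hypothetical sizing example (the numbers below are just the arithmetic of
/// this formula, not measured usage):
///
/// ```ignore
/// // batch=1, heads=8, seq=512, head_dim=64, tile=64
/// let bytes = estimate_flash_attention_vram(1, 8, 512, 64, 64);
/// // 3 * 1 * 8 * 512 * 64 * 4 = 3_145_728 bytes for Q/K/V
/// // 1 * 8 * 512 * 64 * 4     = 1_048_576 bytes for the output
/// // 1 * 8 * (2 * 64) * 4     =     4_096 bytes of workspace
/// assert_eq!(bytes, 4_198_400);
/// ```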
#[must_use]
pub fn estimate_flash_attention_vram(
    batch_size: usize,
    num_heads: usize,
    seq_len: usize,
    head_dim: usize,
    tile_size: usize,
) -> usize {
    // Assumes f32 tensors (4 bytes per element).
    let bytes_per_elem = 4;
    // Q, K and V activations.
    let qkv_size = 3 * batch_size * num_heads * seq_len * head_dim * bytes_per_elem;
    // Attention output.
    let output_size = batch_size * num_heads * seq_len * head_dim * bytes_per_elem;
    // Workspace: two values per tile for each (batch, head) pair.
    let workspace = batch_size * num_heads * (2 * tile_size) * bytes_per_elem;
    qkv_size + output_size + workspace
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::Device;

    #[test]
    fn test_has_cubecl_support() {
        // Must not panic regardless of whether CUDA is available.
        let _ = has_cubecl_support();
    }

    #[test]
    fn test_flash_attention_shape() {
        let device = Device::Cpu;
        let (batch, num_heads, seq_len, head_dim) = (2, 4, 8, 64);
        let q = Tensor::randn(0.0f32, 1.0, (batch, num_heads, seq_len, head_dim), &device).unwrap();
        let k = Tensor::randn(0.0f32, 1.0, (batch, num_heads, seq_len, head_dim), &device).unwrap();
        let v = Tensor::randn(0.0f32, 1.0, (batch, num_heads, seq_len, head_dim), &device).unwrap();
        let scale = 1.0 / (head_dim as f64).sqrt();
        let output = flash_attention_cubecl(&q, &k, &v, scale, None).unwrap();
        assert_eq!(output.dims(), &[batch, num_heads, seq_len, head_dim]);
    }

    #[test]
    fn test_flash_attention_numerical_stability() {
        // Large-magnitude inputs should not produce NaN or infinite outputs.
        let device = Device::Cpu;
        let q = Tensor::randn(0.0f32, 10.0, (1, 2, 4, 64), &device).unwrap();
        let k = Tensor::randn(0.0f32, 10.0, (1, 2, 4, 64), &device).unwrap();
        let v = Tensor::randn(0.0f32, 10.0, (1, 2, 4, 64), &device).unwrap();
        let scale = 1.0 / 8.0;
        let output = flash_attention_cubecl(&q, &k, &v, scale, None).unwrap();
        let values: Vec<f32> = output.flatten_all().unwrap().to_vec1().unwrap();
        for v in values {
            assert!(v.is_finite());
        }
    }

    #[test]
    fn test_flash_attention_invalid_shape() {
        // 3D inputs must be rejected with an error rather than panicking.
        let device = Device::Cpu;
        let q = Tensor::randn(0.0f32, 1.0, (2, 8, 64), &device).unwrap();
        let k = Tensor::randn(0.0f32, 1.0, (2, 8, 64), &device).unwrap();
        let v = Tensor::randn(0.0f32, 1.0, (2, 8, 64), &device).unwrap();
        let result = flash_attention_cubecl(&q, &k, &v, 1.0, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_estimate_flash_attention_vram() {
        let vram = estimate_flash_attention_vram(4, 12, 2048, 64, 128);
        assert!(vram > 1_000_000);
        assert!(vram < 10_000_000_000);
        // Doubling the batch size must increase the estimate.
        let vram_2x_batch = estimate_flash_attention_vram(8, 12, 2048, 64, 128);
        assert!(vram_2x_batch > vram);
    }
}