use candle_core::{DType, Device, Result, Tensor, D};
use std::sync::OnceLock;
/// Attention implementation used by `attention`.
///
/// Chosen once per process by `AttentionBackend::resolve` from the `MOLD_ATTN`
/// environment variable (`flash` / `math`), falling back to `default_backend()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AttentionBackend {
/// Unfused matmul -> softmax -> matmul attention (`math_attention`).
Math,
/// Fused FlashAttention (`flash_attention`); degrades to the math path when the
/// input is not eligible or the FlashAttention FFI is not compiled in.
Flash,
}
/// Latch so the "flash requested but FlashAttention FFI is gated off" warning is
/// logged at most once per process (see `warn_flash_fallback_once`).
static FLASH_FALLBACK_WARNED: OnceLock<()> = OnceLock::new();
impl AttentionBackend {
    /// Returns the process-wide attention backend.
    ///
    /// `MOLD_ATTN` is read and parsed (via `parse_backend_env`) on the first call
    /// only; the result is cached in a `OnceLock`, so later changes to the
    /// environment variable have no effect. The choice is logged once at info level.
    pub fn resolve() -> AttentionBackend {
        static CACHED: OnceLock<AttentionBackend> = OnceLock::new();
        let chosen = CACHED.get_or_init(|| {
            let env_value = std::env::var("MOLD_ATTN").ok();
            let picked = parse_backend_env(env_value.as_deref());
            tracing::info!(backend = ?picked, "attention backend selected");
            picked
        });
        *chosen
    }
}
/// Maps a raw `MOLD_ATTN` value to a backend.
///
/// Matching is case-insensitive and ignores surrounding whitespace:
/// - `flash` -> `Flash`, `math` -> `Math`
/// - `sdpa` (a removed alias) -> `Math`, with a warning
/// - any other non-empty value -> `default_backend()`, with a warning
/// - `None`, empty, or whitespace-only -> `default_backend()` silently
fn parse_backend_env(raw: Option<&str>) -> AttentionBackend {
    let normalized = raw.map(|value| value.trim().to_ascii_lowercase());
    match normalized.as_deref() {
        Some("flash") => AttentionBackend::Flash,
        Some("math") => AttentionBackend::Math,
        Some("sdpa") => {
            tracing::warn!(
                "MOLD_ATTN=sdpa was removed (it was a no-op alias for math); using math"
            );
            AttentionBackend::Math
        }
        Some(other) if !other.is_empty() => {
            tracing::warn!(
                "MOLD_ATTN={other} is not one of flash/math; falling back to default"
            );
            default_backend()
        }
        _ => default_backend(),
    }
}
/// Logs the "flash requested but gated off" warning the first time it is called.
///
/// Returns `true` only for the call that set the latch and emitted the warning;
/// every later call is a no-op returning `false`.
pub(crate) fn warn_flash_fallback_once() -> bool {
    let is_first = FLASH_FALLBACK_WARNED.set(()).is_ok();
    if is_first {
        tracing::warn!(
            "attention backend 'flash' requested but FlashAttention FFI is gated off \
             (build with --features cuda,flash-attn AND RUSTFLAGS='--cfg mold_flash_attn_real'); \
             falling back to math"
        );
    }
    is_first
}
/// Test-only probe reporting whether the flash fallback warning latch is set.
#[cfg(test)]
pub(crate) fn flash_fallback_warned() -> bool {
    matches!(FLASH_FALLBACK_WARNED.get(), Some(()))
}
/// Backend used when `MOLD_ATTN` is unset, blank, or unrecognised.
///
/// Builds with the `flash-attn` cargo feature default to `Flash`; all other
/// builds default to `Math`.
fn default_backend() -> AttentionBackend {
    if cfg!(feature = "flash-attn") {
        AttentionBackend::Flash
    } else {
        AttentionBackend::Math
    }
}
pub fn attention(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32) -> Result<Tensor> {
match AttentionBackend::resolve() {
AttentionBackend::Flash => flash_attention(q, k, v, scale),
AttentionBackend::Math => math_attention(q, k, v, scale),
}
}
/// `attention` with the conventional `1 / sqrt(head_dim)` scale, where
/// `head_dim` is the size of `q`'s last dimension.
///
/// # Errors
/// Fails if `q` has rank 0, or propagates any error from `attention`.
pub fn attention_default_scale(q: &Tensor, k: &Tensor, v: &Tensor) -> Result<Tensor> {
    let head_dim = q.dim(D::Minus1)? as f64;
    let inv_sqrt_dim = head_dim.sqrt().recip();
    attention(q, k, v, inv_sqrt_dim as f32)
}
/// Latch so the "using chunked math attention" info line is logged at most once
/// per process (set in `math_attention_chunked_flat`).
static CHUNKED_MATH_LOGGED: OnceLock<()> = OnceLock::new();
/// Unfused attention: `softmax(q · kᵀ · scale) · v` over the last two dims.
///
/// Whether queries are processed in chunks is decided by
/// `math_attention_chunk_size` (device, query length, `MOLD_ATTN_CHUNK`).
///
/// # Errors
/// Propagates any candle error (e.g. incompatible shapes).
pub fn math_attention(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32) -> Result<Tensor> {
    let chunking = math_attention_chunk_size(q);
    math_attention_impl(q, k, v, scale, chunking)
}
fn math_attention_impl(
q: &Tensor,
k: &Tensor,
v: &Tensor,
scale: f32,
chunk_size: Option<usize>,
) -> Result<Tensor> {
let mut batch_dims = q.dims().to_vec();
batch_dims.pop();
batch_dims.pop();
let q3 = q.flatten_to(batch_dims.len() - 1)?;
let k3 = k.flatten_to(batch_dims.len() - 1)?;
let v3 = v.flatten_to(batch_dims.len() - 1)?;
let attn = if let Some(chunk_size) = chunk_size {
math_attention_chunked_flat(&q3, &k3, &v3, scale, chunk_size)?
} else {
let attn_weights = (q3.matmul(&k3.t()?)? * f64::from(scale))?;
candle_nn::ops::softmax_last_dim(&attn_weights)?.matmul(&v3)?
};
batch_dims.push(attn.dim(D::Minus2)?);
batch_dims.push(attn.dim(D::Minus1)?);
attn.reshape(batch_dims)
}
fn math_attention_chunk_size(q: &Tensor) -> Option<usize> {
let q_len = q.dim(D::Minus2).ok()?;
if let Ok(raw) = std::env::var("MOLD_ATTN_CHUNK") {
let trimmed = raw.trim();
if trimmed == "0" || trimmed.eq_ignore_ascii_case("off") {
return None;
}
match trimmed.parse::<usize>() {
Ok(size) if size > 0 && size < q_len => return Some(size),
Ok(_) => return None,
Err(_) => tracing::warn!(
value = trimmed,
"MOLD_ATTN_CHUNK must be a positive integer, 0, or off; using default"
),
}
}
if matches!(q.device(), Device::Cuda(_)) && q_len > 1024 {
Some(512)
} else {
None
}
}
/// Math attention over rank-3 inputs, computed `chunk_size` query rows at a time.
///
/// Dim 1 is treated as the query axis. Each slice of at most `chunk_size` rows
/// is attended against all of `k3`/`v3`, so only a `(batch, chunk, k_len)` score
/// matrix is live at once. The per-chunk outputs are concatenated back along
/// dim 1. Logs an info line the first time chunking is used in the process.
///
/// `chunk_size` must be non-zero (`step_by(0)` panics).
///
/// # Errors
/// Propagates any candle error, including from `Tensor::cat` when `q_len == 0`.
fn math_attention_chunked_flat(
    q3: &Tensor,
    k3: &Tensor,
    v3: &Tensor,
    scale: f32,
    chunk_size: usize,
) -> Result<Tensor> {
    let q_len = q3.dim(1)?;
    let keys_t = k3.t()?;
    let pieces = (0..q_len)
        .step_by(chunk_size)
        .map(|offset| {
            let rows = chunk_size.min(q_len - offset);
            let q_slice = q3.narrow(1, offset, rows)?;
            let scores = (q_slice.matmul(&keys_t)? * f64::from(scale))?;
            candle_nn::ops::softmax_last_dim(&scores)?.matmul(v3)
        })
        .collect::<Result<Vec<Tensor>>>()?;
    CHUNKED_MATH_LOGGED.get_or_init(|| {
        tracing::info!(
            chunk_size,
            q_len,
            "using chunked math attention to reduce peak VRAM"
        );
    });
    Tensor::cat(&pieces, 1)
}
/// Attention via fused FlashAttention when possible, otherwise the math path.
///
/// Inputs that fail `flash_is_eligible` go straight to `math_attention`. Eligible
/// inputs use the fused kernel only when built with the `flash-attn` feature AND
/// `--cfg mold_flash_attn_real`. Otherwise a one-time warning is logged and the
/// math path is used.
///
/// # Errors
/// Propagates any error from the fused kernel or the math fallback.
pub fn flash_attention(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32) -> Result<Tensor> {
    if !flash_is_eligible(q) {
        return math_attention(q, k, v, scale);
    }
    // The real/fallback split lives in two cfg'd helpers. An early `return`
    // inside a cfg'd block here would leave the trailing fallback call as an
    // unreachable expression (an `unreachable_code` warning) in real builds.
    flash_attention_kernel(q, k, v, scale)
}

/// Fused FlashAttention path.
///
/// Transposes `(b, h, s, d)` to the `(b, s, h, d)` layout taken by
/// `candle_flash_attn::flash_attn` (non-causal), then transposes the result back.
#[cfg(all(feature = "flash-attn", mold_flash_attn_real))]
fn flash_attention_kernel(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32) -> Result<Tensor> {
    let q_t = q.transpose(1, 2)?.contiguous()?;
    let k_t = k.transpose(1, 2)?.contiguous()?;
    let v_t = v.transpose(1, 2)?.contiguous()?;
    let out = candle_flash_attn::flash_attn(&q_t, &k_t, &v_t, scale, false)?;
    out.transpose(1, 2)?.contiguous()
}

/// Stand-in when the FlashAttention FFI is gated off: warn once, use math.
#[cfg(not(all(feature = "flash-attn", mold_flash_attn_real)))]
fn flash_attention_kernel(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32) -> Result<Tensor> {
    warn_flash_fallback_once();
    math_attention(q, k, v, scale)
}
/// Whether `q` can take the fused FlashAttention path.
///
/// Requires all of:
/// - a CUDA device;
/// - f16/bf16 dtype;
/// - rank 4.
///
/// The rank requirement exists because the fused path swaps dims 1 and 2 to
/// produce the 4-D `(batch, seq, heads, head_dim)` layout that
/// `candle_flash_attn::flash_attn` takes. Applied to a 3-D tensor, that
/// transpose would silently swap seq and head_dim instead.
///
/// Only `q` is inspected; `k`/`v` are assumed to match (TODO confirm at call sites).
fn flash_is_eligible(q: &Tensor) -> bool {
    q.rank() == 4
        && matches!(q.device(), Device::Cuda(_))
        && matches!(q.dtype(), DType::F16 | DType::BF16)
}
#[cfg(test)]
mod tests {
    // `Device`, `DType`, `Tensor`, etc. come from the parent's imports via the glob.
    use super::*;

    /// All tests run on CPU, so no GPU is required.
    fn cpu() -> Device {
        Device::Cpu
    }

    /// Straightforward f32 reference: `softmax(q · kᵀ · scale) · v`.
    fn reference_attention(q: &Tensor, k: &Tensor, v: &Tensor, scale: f32) -> Tensor {
        let q = q.to_dtype(DType::F32).unwrap();
        let k = k.to_dtype(DType::F32).unwrap();
        let v = v.to_dtype(DType::F32).unwrap();
        let weights = q.matmul(&k.t().unwrap()).unwrap();
        let weights = (weights * f64::from(scale)).unwrap();
        let weights = candle_nn::ops::softmax_last_dim(&weights).unwrap();
        weights.matmul(&v).unwrap()
    }

    /// Random standard-normal q, k, v of identical shape on CPU.
    fn rand_qkv(shape: (usize, usize, usize, usize)) -> (Tensor, Tensor, Tensor) {
        let dev = cpu();
        let q = Tensor::randn(0.0_f32, 1.0_f32, shape, &dev).unwrap();
        let k = Tensor::randn(0.0_f32, 1.0_f32, shape, &dev).unwrap();
        let v = Tensor::randn(0.0_f32, 1.0_f32, shape, &dev).unwrap();
        (q, k, v)
    }

    /// Largest element-wise absolute difference between two f32 tensors.
    fn max_abs_diff(a: &Tensor, b: &Tensor) -> f32 {
        let diff = (a - b).unwrap().abs().unwrap();
        diff.flatten_all()
            .unwrap()
            .max(0)
            .unwrap()
            .to_scalar::<f32>()
            .unwrap()
    }

    #[test]
    fn test_math_attention_matches_reference() {
        let (q, k, v) = rand_qkv((2, 4, 16, 32));
        let scale = 1.0 / (32f32).sqrt();
        let got = math_attention(&q, &k, &v, scale).unwrap();
        let want = reference_attention(&q, &k, &v, scale);
        assert_eq!(got.dims(), &[2, 4, 16, 32]);
        assert!(
            max_abs_diff(&got, &want) < 1e-5,
            "math attention diverged from reference"
        );
    }

    #[test]
    fn test_chunked_math_attention_matches_full_math() {
        let (q, k, v) = rand_qkv((1, 3, 17, 16));
        let scale = 1.0 / (16f32).sqrt();
        let full = math_attention_impl(&q, &k, &v, scale, None).unwrap();
        // Single-row chunks, a ragged tail (17 = 3 * 5 + 2), one short of q_len,
        // and chunks at/above q_len (a single chunk).
        for chunk_size in [1, 5, 16, 17, 32] {
            let chunked = math_attention_impl(&q, &k, &v, scale, Some(chunk_size)).unwrap();
            assert_eq!(chunked.dims(), full.dims(), "chunk_size={chunk_size}");
            assert!(
                max_abs_diff(&chunked, &full) < 1e-5,
                "chunked math attention (chunk_size={chunk_size}) diverged from full math"
            );
        }
    }

    #[test]
    fn test_flash_falls_back_on_cpu() {
        let (q, k, v) = rand_qkv((1, 2, 8, 16));
        let scale = 1.0 / (16f32).sqrt();
        let math = math_attention(&q, &k, &v, scale).unwrap();
        let flash = flash_attention(&q, &k, &v, scale).unwrap();
        assert_eq!(flash.dims(), math.dims());
        assert!(max_abs_diff(&math, &flash) < 1e-5);
    }

    #[test]
    fn test_attention_default_scale() {
        let (q, k, v) = rand_qkv((1, 2, 4, 8));
        let scale = 1.0 / (8f32).sqrt();
        let explicit = math_attention(&q, &k, &v, scale).unwrap();
        let implicit = attention_default_scale(&q, &k, &v).unwrap();
        assert!(max_abs_diff(&explicit, &implicit) < 1e-5);
    }

    #[test]
    fn test_resolve_backend_from_env() {
        assert_eq!(parse_backend_env(Some("flash")), AttentionBackend::Flash);
        assert_eq!(parse_backend_env(Some("FLASH")), AttentionBackend::Flash);
        assert_eq!(parse_backend_env(Some("math")), AttentionBackend::Math);
        assert_eq!(parse_backend_env(Some(" math ")), AttentionBackend::Math);
        assert_eq!(parse_backend_env(Some("xformers")), default_backend());
        assert_eq!(parse_backend_env(Some("")), default_backend());
        assert_eq!(parse_backend_env(Some("   ")), default_backend());
        assert_eq!(parse_backend_env(None), default_backend());
    }

    #[test]
    fn resolve_returns_only_known_backends() {
        // The removed `sdpa` alias maps to math regardless of case/whitespace.
        for value in ["sdpa", "SDPA", " sdpa "] {
            assert_eq!(
                parse_backend_env(Some(value)),
                AttentionBackend::Math,
                "MOLD_ATTN={value:?}"
            );
        }
        // Each known name maps to exactly its backend.
        for (value, want) in [
            ("flash", AttentionBackend::Flash),
            (" Flash ", AttentionBackend::Flash),
            ("math", AttentionBackend::Math),
            ("MATH", AttentionBackend::Math),
        ] {
            assert_eq!(parse_backend_env(Some(value)), want, "MOLD_ATTN={value:?}");
        }
    }

    #[test]
    fn flash_fallback_warns_once() {
        // The first call may or may not fire, depending on whether another test
        // already tripped the process-wide latch; every later call must not.
        let _ = warn_flash_fallback_once();
        let refired = (0..2).any(|_| warn_flash_fallback_once());
        assert!(
            !refired,
            "warn_flash_fallback_once must not re-fire after the first call"
        );
        assert!(
            flash_fallback_warned(),
            "warn_flash_fallback_once must always leave the latch set"
        );
    }

    #[test]
    #[cfg(not(feature = "flash-attn"))]
    fn test_resolve_default_without_feature() {
        assert_eq!(default_backend(), AttentionBackend::Math);
        assert_eq!(parse_backend_env(None), AttentionBackend::Math);
    }

    #[test]
    #[cfg(feature = "flash-attn")]
    fn test_resolve_default_with_feature() {
        assert_eq!(default_backend(), AttentionBackend::Flash);
        assert_eq!(parse_backend_env(None), AttentionBackend::Flash);
    }
}