use candle_core::{D, Result, Tensor};
#[inline]
fn can_skip_mask_for_single_query(
    q_len: usize,
    kv_len: usize,
    is_causal: bool,
    context_window: Option<usize>,
) -> bool {
    // A single causal query sits at the last position, so every cached key is
    // in its past and no causal mask is needed. A sliding window only matters
    // when the cache is longer than the window.
    is_causal && q_len == 1 && context_window.map_or(true, |ctx| kv_len <= ctx)
}
#[inline]
/// Scaled dot-product attention: `softmax(q @ k^T * scale + mask) @ v`.
///
/// `q`, `k`, `v` are `(batch, heads, seq, dim)`; the output has q's shape.
/// When `is_causal` is set, keys after a query's position are masked; a
/// `context_window` additionally limits each query to its last `ctx` keys.
/// Long query sequences are processed in row tiles to bound the size of the
/// materialized score matrix.
pub fn sdpa(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    scale: f64,
    is_causal: bool,
    context_window: Option<usize>,
) -> Result<Tensor> {
    // Matmul requires contiguous operands; normalize inputs up front.
    let q = q.contiguous()?;
    let k = k.contiguous()?;
    let v = v.contiguous()?;
    let (_b, _h, q_len, _dim) = q.dims4()?;
    // `dim` returns an error for under-ranked tensors instead of panicking
    // (the previous `dims()[2]` indexing would panic).
    let kv_len = k.dim(2)?;
    // Below this query length a single matmul is cheaper than tiling.
    const TILING_THRESHOLD: usize = 512;
    let k_t = k.transpose(2, 3)?.contiguous()?;
    if q_len < TILING_THRESHOLD {
        let scores = (q.matmul(&k_t)? * scale)?;
        // Single-token decode over a fully visible cache needs no mask at all.
        let scores = if can_skip_mask_for_single_query(q_len, kv_len, is_causal, context_window) {
            scores
        } else if is_causal || context_window.is_some() {
            let mask = generate_mask_chunk(
                0,
                q_len,
                kv_len,
                q_len,
                is_causal,
                context_window,
                q.device(),
            )?;
            scores.broadcast_add(&mask)?
        } else {
            scores
        };
        let probs = candle_nn::ops::softmax(&scores, D::Minus1)?;
        return probs.matmul(&v);
    }
    // Tiled path: process BLOCK_SIZE query rows at a time.
    const BLOCK_SIZE: usize = 128;
    let mut outputs = Vec::with_capacity(q_len.div_ceil(BLOCK_SIZE));
    for start in (0..q_len).step_by(BLOCK_SIZE) {
        let len = usize::min(BLOCK_SIZE, q_len - start);
        let q_chunk = q.narrow(2, start, len)?;
        let scores = (q_chunk.matmul(&k_t)? * scale)?;
        let scores = if is_causal || context_window.is_some() {
            // Mask rows are offset by `start` so positions line up globally.
            let mask_chunk = generate_mask_chunk(
                start,
                len,
                kv_len,
                q_len,
                is_causal,
                context_window,
                q.device(),
            )?;
            scores.broadcast_add(&mask_chunk)?
        } else {
            scores
        };
        let probs = candle_nn::ops::softmax(&scores, D::Minus1)?;
        outputs.push(probs.matmul(&v)?);
    }
    Tensor::cat(&outputs, 2)
}
/// Builds an additive attention mask chunk of shape `(1, 1, num_q, k_len)`
/// with `0.0` for visible positions and `-inf` for masked ones.
///
/// `start_q` is the offset of this query chunk within the full query
/// sequence of length `total_q_len`. Queries are aligned to the *end* of the
/// key sequence (`shift = k_len - total_q_len`), which matches decoding with
/// a KV cache: the last query attends up to the last key.
fn generate_mask_chunk(
    start_q: usize,
    num_q: usize,
    k_len: usize,
    total_q_len: usize,
    is_causal: bool,
    context_window: Option<usize>,
    device: &candle_core::Device,
) -> Result<Tensor> {
    let shift = k_len.saturating_sub(total_q_len);
    // Absolute query positions, shape (num_q, 1), broadcast against keys.
    let pos_q = (Tensor::arange(0u32, num_q as u32, device)?
        .to_dtype(candle_core::DType::F32)?
        .affine(1.0, (start_q + shift) as f64)?
        .reshape((num_q, 1)))?;
    // Key positions, shape (1, k_len).
    let pos_k = Tensor::arange(0u32, k_len as u32, device)?
        .to_dtype(candle_core::DType::F32)?
        .reshape((1, k_len))?;
    let mut mask = Tensor::zeros((num_q, k_len), candle_core::DType::F32, device)?;
    // Allocate the -inf fill once; previously it was built separately for the
    // causal and window branches, doubling the allocation when both apply.
    let neg_inf = Tensor::full(f32::NEG_INFINITY, (num_q, k_len), device)?;
    if is_causal {
        // Mask keys strictly in the future of each query.
        let is_future = pos_k.broadcast_gt(&pos_q)?;
        mask = is_future.where_cond(&neg_inf, &mask)?;
    }
    if let Some(ctx) = context_window {
        // Sliding window: keep keys with pos_k > pos_q - ctx, i.e. exactly
        // `ctx` trailing positions (the query itself included) stay visible.
        let limit = pos_q.broadcast_sub(&Tensor::full(ctx as f32, (num_q, 1), device)?)?;
        let is_out = pos_k.broadcast_le(&limit)?;
        mask = is_out.where_cond(&neg_inf, &mask)?;
    }
    mask.reshape((1, 1, num_q, k_len))
}
/// Scaled dot-product attention over a KV cache stored as chunks along the
/// sequence axis (dim 2).
///
/// Scores are computed per key chunk, concatenated, and softmaxed over the
/// full key length, so the result matches [`sdpa`] on the concatenated
/// cache. `k_chunks` and `v_chunks` are expected to have matching chunk
/// lengths. An empty cache yields zeros with q's shape.
pub fn sdpa_chunked(
    q: &Tensor,
    k_chunks: &[Tensor],
    v_chunks: &[Tensor],
    scale: f64,
    is_causal: bool,
    context_window: Option<usize>,
) -> Result<Tensor> {
    if k_chunks.is_empty() {
        // Nothing to attend to: define the output as zeros of q's shape.
        let (b, h, q_len, d) = q.dims4()?;
        return Tensor::zeros((b, h, q_len, d), q.dtype(), q.device());
    }
    let device = q.device();
    let dtype = q.dtype();
    let q = q.contiguous()?;
    let (b, h, q_len, d) = q.dims4()?;
    // Matmul requires contiguous operands; normalize every chunk up front.
    let k_chunks: Vec<Tensor> = k_chunks
        .iter()
        .map(|t| t.contiguous())
        .collect::<Result<_>>()?;
    let v_chunks: Vec<Tensor> = v_chunks
        .iter()
        .map(|t| t.contiguous())
        .collect::<Result<_>>()?;
    if k_chunks.len() == 1 {
        // Fast path: a single chunk behaves exactly like plain attention.
        let k_t = k_chunks[0].transpose(2, 3)?.contiguous()?;
        let scores = (q.matmul(&k_t)? * scale)?;
        // `dim` errors on under-ranked tensors instead of panicking.
        let kv_len = k_chunks[0].dim(2)?;
        let masked_scores =
            if can_skip_mask_for_single_query(q_len, kv_len, is_causal, context_window) {
                scores
            } else if is_causal || context_window.is_some() {
                let mask = generate_mask_chunk(
                    0,
                    q_len,
                    kv_len,
                    q_len,
                    is_causal,
                    context_window,
                    device,
                )?;
                scores.broadcast_add(&mask)?
            } else {
                scores
            };
        let probs = candle_nn::ops::softmax(&masked_scores, D::Minus1)?;
        return probs.matmul(&v_chunks[0]);
    }
    // Score q against each key chunk; concatenating along the key axis lets
    // the softmax normalize over the whole cache at once.
    let mut score_chunks = Vec::with_capacity(k_chunks.len());
    let mut total_kv_len = 0;
    for k_chunk in &k_chunks {
        total_kv_len += k_chunk.dim(2)?;
        let k_t = k_chunk.transpose(2, 3)?.contiguous()?;
        score_chunks.push((q.matmul(&k_t)? * scale)?);
    }
    let all_scores = Tensor::cat(&score_chunks, 3)?;
    let masked_scores =
        if can_skip_mask_for_single_query(q_len, total_kv_len, is_causal, context_window) {
            all_scores
        } else if is_causal || context_window.is_some() {
            let mask = generate_mask_chunk(
                0,
                q_len,
                total_kv_len,
                q_len,
                is_causal,
                context_window,
                device,
            )?;
            all_scores.broadcast_add(&mask)?
        } else {
            all_scores
        };
    let probs = candle_nn::ops::softmax(&masked_scores, D::Minus1)?;
    // Accumulate the weighted values chunk by chunk, slicing the matching
    // probability columns for each value chunk.
    let mut output = Tensor::zeros((b, h, q_len, d), dtype, device)?;
    let mut offset = 0;
    for v_chunk in &v_chunks {
        let chunk_len = v_chunk.dim(2)?;
        let probs_chunk = probs.narrow(3, offset, chunk_len)?.contiguous()?;
        output = (output + probs_chunk.matmul(v_chunk)?)?;
        offset += chunk_len;
    }
    Ok(output)
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device};

    /// A single causal query over a cache sees everything (all-zero mask);
    /// a square causal mask keeps only the lower triangle visible.
    #[test]
    fn test_generate_mask_chunk_causal() -> Result<()> {
        let dev = Device::Cpu;

        let single = generate_mask_chunk(0, 1, 5, 1, true, None, &dev)?;
        assert_eq!(single.flatten_all()?.to_vec1::<f32>()?, vec![0.0; 5]);

        let square = generate_mask_chunk(0, 3, 3, 3, true, None, &dev)?;
        let rows = square.reshape((3, 3))?.to_vec2::<f32>()?;
        assert_eq!(rows[0], vec![0.0, f32::NEG_INFINITY, f32::NEG_INFINITY]);
        assert_eq!(rows[1], vec![0.0, 0.0, f32::NEG_INFINITY]);
        assert_eq!(rows[2], vec![0.0, 0.0, 0.0]);
        Ok(())
    }

    /// A window of 2 over a 6-long cache leaves only the last two keys open.
    #[test]
    fn test_generate_mask_chunk_window() -> Result<()> {
        let dev = Device::Cpu;
        let mask = generate_mask_chunk(0, 1, 6, 1, false, Some(2), &dev)?;
        let row = mask.flatten_all()?.to_vec1::<f32>()?;
        let mut expected = vec![f32::NEG_INFINITY; 4];
        expected.extend([0.0, 0.0]);
        assert_eq!(row, expected);
        Ok(())
    }

    #[test]
    fn test_can_skip_mask_for_single_query() {
        // (q_len, kv_len, is_causal, context_window) -> expected skip.
        let cases = [
            (1, 64, true, None, true),
            (1, 64, true, Some(64), true),
            (1, 65, true, Some(64), false),
            (2, 64, true, None, false),
            (1, 64, false, None, false),
        ];
        for (q, kv, causal, win, expect) in cases {
            assert_eq!(can_skip_mask_for_single_query(q, kv, causal, win), expect);
        }
    }

    /// Transposed/narrowed views exercise the internal contiguous() calls.
    #[test]
    fn test_sdpa_handles_non_contiguous_inputs() -> Result<()> {
        let dev = Device::Cpu;
        let scale = 1.0 / (64f64).sqrt();
        let q = Tensor::zeros((1, 8, 64, 128), DType::F32, &dev)?.transpose(2, 3)?;
        let k = Tensor::zeros((1, 8, 1600, 64), DType::F32, &dev)?;
        let v = Tensor::zeros((1, 8, 64, 1601), DType::F32, &dev)?
            .transpose(2, 3)?
            .narrow(2, 1, 1600)?;
        let out = sdpa(&q, &k, &v, scale, true, None)?;
        assert_eq!(out.dims(), &[1, 8, 128, 64]);
        Ok(())
    }

    /// Non-contiguous value chunks must be normalized by sdpa_chunked.
    #[test]
    fn test_sdpa_chunked_handles_non_contiguous_value_chunks() -> Result<()> {
        let dev = Device::Cpu;
        let scale = 1.0 / (32f64).sqrt();
        let q = Tensor::zeros((1, 4, 64, 32), DType::F32, &dev)?;
        let k_full = Tensor::zeros((1, 4, 320, 32), DType::F32, &dev)?;
        let v_full = Tensor::zeros((1, 4, 32, 321), DType::F32, &dev)?
            .transpose(2, 3)?
            .narrow(2, 1, 320)?;
        let k_chunks = vec![k_full.narrow(2, 0, 160)?, k_full.narrow(2, 160, 160)?];
        let v_chunks = vec![v_full.narrow(2, 0, 160)?, v_full.narrow(2, 160, 160)?];
        let out = sdpa_chunked(&q, &k_chunks, &v_chunks, scale, true, None)?;
        assert_eq!(out.dims(), &[1, 4, 64, 32]);
        Ok(())
    }
}