use candle_core::{D, Result, Tensor};
#[inline]
/// Scaled dot-product attention: `softmax(q · kᵀ * scale + mask) · v`.
///
/// * `q`, `k`, `v` — attention tensors; `q` must be rank 4
///   (checked via `dims4`, assumed `(batch, heads, seq, head_dim)` —
///   `k`/`v` are assumed to share batch/head/head_dim).
/// * `scale` — multiplier applied to the raw scores (typically `1/sqrt(head_dim)`).
/// * `is_causal` — mask out keys lying in the future of each query.
/// * `context_window` — if set, additionally mask keys more than this many
///   positions behind the query (sliding-window attention).
///
/// For long sequences (`q_len >= 512`) the query dimension is processed in
/// blocks of 128 rows so the intermediate `(q_len, kv_len)` score matrix is
/// never materialized in full.
pub fn sdpa(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    scale: f64,
    is_causal: bool,
    context_window: Option<usize>,
) -> Result<Tensor> {
    let (_b, _h, q_len, _dim) = q.dims4()?;
    // `dim` returns an Err instead of panicking when k has fewer than 3 dims.
    let kv_len = k.dim(2)?;
    const TILING_THRESHOLD: usize = 512;
    let needs_mask = is_causal || context_window.is_some();
    let k_t = k.transpose(2, 3)?;

    if q_len < TILING_THRESHOLD {
        // Short sequences: one dense score matrix.
        let mut scores = (q.matmul(&k_t)? * scale)?;
        if needs_mask {
            let mask = generate_mask_chunk(
                0,
                q_len,
                kv_len,
                q_len,
                is_causal,
                context_window,
                q.device(),
            )?;
            // The mask is built as f32; cast so it matches f16/bf16 scores
            // (no-op when scores are already f32).
            scores = scores.broadcast_add(&mask.to_dtype(scores.dtype())?)?;
        }
        let probs = candle_nn::ops::softmax(&scores, D::Minus1)?;
        return probs.matmul(v);
    }

    // Long sequences: tile over the query dimension in fixed-size row blocks.
    let block_size = 128;
    let mut outputs = Vec::with_capacity((q_len + block_size - 1) / block_size);
    for start in (0..q_len).step_by(block_size) {
        // Final block may be shorter than block_size.
        let len = std::cmp::min(block_size, q_len - start);
        let q_chunk = q.narrow(2, start, len)?;
        let mut scores = (q_chunk.matmul(&k_t)? * scale)?;
        if needs_mask {
            let mask_chunk = generate_mask_chunk(
                start,
                len,
                kv_len,
                q_len,
                is_causal,
                context_window,
                q.device(),
            )?;
            scores = scores.broadcast_add(&mask_chunk.to_dtype(scores.dtype())?)?;
        }
        let probs = candle_nn::ops::softmax(&scores, D::Minus1)?;
        outputs.push(probs.matmul(v)?);
    }
    Tensor::cat(&outputs, 2)
}
/// Builds an additive attention-mask chunk of shape `(1, 1, num_q, k_len)`.
///
/// Rows cover absolute query positions `start_q..start_q + num_q` out of
/// `total_q_len` total queries. Entries are `0.0` where attention is allowed
/// and `f32::NEG_INFINITY` where it is forbidden, so the mask can be added to
/// the scores before softmax.
///
/// When `k_len > total_q_len` (e.g. keys include cached earlier tokens),
/// query positions are shifted right so the last query aligns with the last
/// key.
///
/// * `is_causal` — forbid keys strictly after the (shifted) query position.
/// * `context_window` — with `Some(ctx)`, only keys in `(pos_q - ctx, pos_q]`
///   remain visible; everything at or before `pos_q - ctx` is masked.
fn generate_mask_chunk(
    start_q: usize,
    num_q: usize,
    k_len: usize,
    total_q_len: usize,
    is_causal: bool,
    context_window: Option<usize>,
    device: &candle_core::Device,
) -> Result<Tensor> {
    // Alignment offset between query and key indices; loop-invariant, so
    // compute it once instead of once per element.
    let shift = k_len.saturating_sub(total_q_len);
    // Exact final size is known — allocate once.
    let mut mask = Vec::with_capacity(num_q * k_len);
    for i_rel in 0..num_q {
        let pos_q = start_q + i_rel + shift;
        for j in 0..k_len {
            let is_future = is_causal && j > pos_q;
            let is_out_of_context = match context_window {
                // `pos_q >= ctx` guards the unsigned subtraction below.
                Some(ctx) if pos_q >= ctx => j <= pos_q - ctx,
                _ => false,
            };
            mask.push(if is_future || is_out_of_context {
                f32::NEG_INFINITY
            } else {
                0.0
            });
        }
    }
    Tensor::from_vec(mask, (1, 1, num_q, k_len), device)
}