use crate::error::{Error, Result};
use numr::autograd::{Var, var_add, var_matmul, var_mul_scalar, var_softmax, var_transpose};
use numr::dtype::DType;
use numr::ops::ScalarOps;
use numr::runtime::{Runtime, RuntimeClient};
pub fn scaled_dot_product_attention_impl<R, C>(
client: &C,
q: &Var<R>,
k: &Var<R>,
v: &Var<R>,
scale: f64,
causal: bool,
) -> Result<Var<R>>
where
R: Runtime<DType = DType>,
C: RuntimeClient<R> + ScalarOps<R>,
R::Client: ScalarOps<R>,
{
let q_shape = q.tensor().shape().to_vec();
let k_shape = k.tensor().shape().to_vec();
let v_shape = v.tensor().shape().to_vec();
if q_shape.len() != 4 || k_shape.len() != 4 || v_shape.len() != 4 {
return Err(Error::InvalidArgument {
arg: "q/k/v",
reason: "expected 4D tensors [B, H, S, D]".into(),
});
}
if q_shape[0] != k_shape[0] || q_shape[1] != k_shape[1] {
return Err(Error::InvalidArgument {
arg: "k",
reason: format!("B,H mismatch: q={:?}, k={:?}", q_shape, k_shape),
});
}
if k_shape[0] != v_shape[0] || k_shape[1] != v_shape[1] || k_shape[2] != v_shape[2] {
return Err(Error::InvalidArgument {
arg: "v",
reason: format!("B,H,S_kv mismatch: k={:?}, v={:?}", k_shape, v_shape),
});
}
if q_shape[3] != k_shape[3] {
return Err(Error::InvalidArgument {
arg: "k",
reason: format!("D_k mismatch: q D={}, k D={}", q_shape[3], k_shape[3]),
});
}
let k_t = var_transpose(k).map_err(Error::Numr)?;
let scores = var_matmul(q, &k_t, client).map_err(Error::Numr)?;
let scores = var_mul_scalar(&scores, scale, client).map_err(Error::Numr)?;
let scores = if causal {
let s_q = q_shape[2];
let s_kv = k_shape[2];
let mut mask_data = vec![0.0f32; s_q * s_kv];
for i in 0..s_q {
for j in (i + 1)..s_kv {
mask_data[i * s_kv + j] = f32::NEG_INFINITY;
}
}
let mask_tensor = numr::tensor::Tensor::<R>::from_slice(
&mask_data,
&[1, 1, s_q, s_kv],
q.tensor().device(),
);
let mask = Var::new(mask_tensor, false);
var_add(&scores, &mask, client).map_err(Error::Numr)?
} else {
scores
};
let weights = var_softmax(&scores, -1, client).map_err(Error::Numr)?;
var_matmul(&weights, v, client).map_err(Error::Numr)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_utils::cpu_setup;
    use numr::runtime::cpu::CpuRuntime;
    use numr::tensor::Tensor;

    /// V may carry a different head dim than Q/K (`d_v != d_k`); the output
    /// shape must take its last dim from V.
    #[test]
    fn test_sdpa_different_kv_dims() {
        let (client, device) = cpu_setup();
        let (b, h, s) = (1usize, 2usize, 4usize);
        let (d_k, d_v) = (8usize, 6usize);

        // Constant-filled [b, h, s, dim] variable (no gradient tracking).
        let make = |dim: usize| {
            Var::new(
                Tensor::<CpuRuntime>::from_slice(
                    &vec![0.1f32; b * h * s * dim],
                    &[b, h, s, dim],
                    &device,
                ),
                false,
            )
        };
        let q = make(d_k);
        let k = make(d_k);
        let v = make(d_v);

        // 1/sqrt(d_k); recip() is exactly 1.0 / x for f64.
        let scale = (d_k as f64).sqrt().recip();
        let out = scaled_dot_product_attention_impl(&client, &q, &k, &v, scale, true).unwrap();
        assert_eq!(out.tensor().shape(), &[b, h, s, d_v]);
    }

    /// Square non-causal case with Q = K = V: output shape matches input.
    #[test]
    fn test_sdpa_same_dims() {
        let (client, device) = cpu_setup();
        let (b, h, s, d) = (1usize, 1usize, 3usize, 4usize);

        let data = vec![0.1f32; b * h * s * d];
        let q = Var::new(
            Tensor::<CpuRuntime>::from_slice(&data, &[b, h, s, d], &device),
            false,
        );
        let k = q.clone();
        let v = q.clone();

        let scale = 1.0 / (d as f64).sqrt();
        let out = scaled_dot_product_attention_impl(&client, &q, &k, &v, scale, false).unwrap();
        assert_eq!(out.tensor().shape(), &[b, h, s, d]);
    }
}