use crate::Tensor;
use bon::bon;
use snafu::ensure;
use svod_dtype::DType;
use svod_ir::ConstValue;
use crate::error::FloatDTypeRequiredSnafu;
type Result<T> = crate::Result<T>;
impl Tensor {
    /// Looks up embedding vectors: gathers rows of `self` (the weight table) selected by
    /// `indices`.
    ///
    /// Dimension 1 of `self` is taken as the embedding width `embed_dim`; the weight is
    /// presumably 2-D `[num_embeddings, embed_dim]` (TODO confirm with callers). The result
    /// has shape `indices.shape()` with `embed_dim` appended.
    ///
    /// # Panics
    ///
    /// Panics if `self` has fewer than two dimensions, or if `embed_dim` or any dimension
    /// of `indices` is not a concrete (constant) size.
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying shape, reshape, expand and gather operations.
    pub fn embedding(&self, indices: &Tensor) -> Result<Tensor> {
        let weight_shape = self.shape()?;
        let embed_dim =
            weight_shape[1].as_const().expect("embedding weight dim 1 must be concrete") as isize;
        let idx_shape = indices.shape()?;
        // Flatten the indices to [N] and broadcast each one across the embedding width, so
        // that gathering along dim 0 selects whole rows of the weight table.
        let flat = indices.try_reshape([-1])?;
        let expanded = flat.try_unsqueeze(-1)?.try_expand([-1, embed_dim])?;
        let gathered = self.gather(0, &expanded)?;
        // Restore the caller's index layout with the embedding width appended.
        let mut out_shape: Vec<isize> = idx_shape
            .iter()
            .map(|d| d.as_const().expect("embedding index dims must be concrete") as isize)
            .collect();
        out_shape.push(embed_dim);
        gathered.try_reshape(&out_shape)
    }

    /// Applies a rotary position embedding over the last dimension of `self`.
    ///
    /// The last dimension (length `d`) is viewed as `d / 2` pairs `(x1, x2)`:
    /// - `interleaved == false`: `x1` is the first half and `x2` the second half.
    /// - `interleaved == true`: `x1` / `x2` are the even / odd positions (adjacent elements).
    ///
    /// Each pair is rotated as `real = x1 * cos - x2 * sin` and `imag = x1 * sin + x2 * cos`.
    /// The results are written back in the same layout they were read from, so the output
    /// has the input's shape.
    ///
    /// `cos` and `sin` must broadcast against the half-width tensors `[..., d / 2]`. They are
    /// presumably per-position tables built by the caller (TODO confirm).
    ///
    /// # Panics
    ///
    /// - Panics if `self` is a scalar or its last dimension is not concrete.
    /// - In interleaved mode, also panics if any dimension is not concrete.
    ///
    /// An odd `d` is not special-cased (`d / 2` truncates). Presumably the split/reshape
    /// below rejects it.
    ///
    /// # Errors
    ///
    /// Propagates errors from the underlying tensor operations.
    pub fn apply_rotary_emb(&self, cos: &Tensor, sin: &Tensor, interleaved: bool) -> Result<Tensor> {
        let shape = self.shape()?;
        let last_dim = shape
            .last()
            .expect("apply_rotary_emb requires non-scalar input")
            .as_const()
            .expect("last dim must be concrete");
        let half = last_dim / 2;

        // The rotation itself is identical for both layouts; only how the pairs are
        // extracted and recombined differs.
        let rotate = |x1: &Tensor, x2: &Tensor| -> Result<(Tensor, Tensor)> {
            let real = x1.try_mul(cos)?.try_sub(&x2.try_mul(sin)?)?;
            let imag = x1.try_mul(sin)?.try_add(&x2.try_mul(cos)?)?;
            Ok((real, imag))
        };

        if interleaved {
            // Explicit reshapes below need every dimension as a concrete size; compute them
            // once and reuse them for both the split and the final reshape.
            let dims: Vec<isize> = shape
                .iter()
                .map(|d| d.as_const().expect("dims must be concrete") as isize)
                .collect();
            // [..., d] -> [..., d/2, 2]: the trailing axis holds each adjacent (x1, x2) pair.
            let mut pair_shape = dims[..dims.len() - 1].to_vec();
            pair_shape.extend([half as isize, 2]);
            let pairs = self.try_reshape(&pair_shape)?;
            let parts = pairs.split(&[1, 1], -1)?;
            let (real, imag) =
                rotate(&parts[0].try_squeeze(Some(-1))?, &parts[1].try_squeeze(Some(-1))?)?;
            // Re-interleave: stacking gives [..., d/2, 2], which flattens back to [..., d].
            Tensor::stack(&[&real, &imag], -1)?.try_reshape(&dims)
        } else {
            let parts = self.split(&[half, half], -1)?;
            let (real, imag) = rotate(&parts[0], &parts[1])?;
            Tensor::cat(&[&real, &imag], -1)
        }
    }
}
#[bon]
impl Tensor {
    /// Scaled dot-product attention with `self` as the query.
    ///
    /// Steps, in order:
    /// 1. `scores = (self @ key.transpose(-1, -2)) * scale`. `scale` defaults to
    ///    `1 / sqrt(head_dim)`, where `head_dim` is the last dimension of the query.
    /// 2. If `is_causal`, `scores` is combined with a lower-triangular (`tril(0)`) boolean
    ///    grid of shape `[q_len, k_len]` through `where_`, using the score dtype's
    ///    `ConstValue::min` as the fill value.
    /// 3. If `attn_mask` is given:
    ///    - A `Bool` mask is turned into an additive mask via
    ///      `neg_large.where_(mask, zero)`.
    ///    - A mask of any other dtype is added to `scores` as-is.
    /// 4. If `softcap` is positive, `scores = softcap * tanh(scores / softcap)`.
    ///    - This runs after masking.
    ///    - Scores filled with the (presumably very negative) minimum value are therefore
    ///      squashed towards `-softcap`.
    /// 5. Softmax over the last axis.
    /// 6. For a `Bool` mask, the weights are passed through `zero.where_(mask, weights)`
    ///    (no renormalization).
    /// 7. The weights are multiplied by `value`.
    ///
    /// NOTE(review): this assumes `a.where_(cond, b)` yields `a` where `cond` holds and `b`
    /// elsewhere (TODO confirm). Under that reading:
    /// - The causal grid uses `true` = keep.
    /// - A `Bool` `attn_mask` uses `true` = exclude.
    /// - PyTorch/ONNX use `true` = attend for both.
    ///
    /// Confirm the intended convention before relying on it.
    ///
    /// # Errors
    ///
    /// - Returns a `FloatDTypeRequired` error if the query, key or value dtype is not
    ///   floating point. They are checked in that order.
    /// - Otherwise propagates errors from the underlying tensor operations.
    ///
    /// # Panics
    ///
    /// - Panics if `scale` is `None` and the query's last dimension is not concrete.
    /// - Panics if `is_causal` is set and either sequence length (dim -2 of Q or K) is not
    ///   concrete.
    #[builder]
    pub fn scaled_dot_product_attention(
        &self,
        key: &Tensor,
        value: &Tensor,
        attn_mask: Option<&Tensor>,
        scale: Option<f64>,
        #[builder(default)] is_causal: bool,
        softcap: Option<f64>,
    ) -> Result<Tensor> {
        // Q, K and V must all be floating point; reported in argument order.
        for (arg, tensor) in [("query", self), ("key", key), ("value", value)] {
            let dtype = tensor.uop().dtype();
            ensure!(
                dtype.is_float(),
                FloatDTypeRequiredSnafu { op: "scaled_dot_product_attention", arg, dtype }
            );
        }

        let q_shape = self.shape()?;
        let k_shape = key.shape()?;
        let scores_dtype = self.uop().dtype();

        // Only the default scale needs a concrete head_dim. Computing it lazily lets an
        // explicit `scale` work even when the query's last dimension is symbolic.
        let scale_val = scale.unwrap_or_else(|| {
            let head_dim =
                q_shape[q_shape.len() - 1].as_const().expect("Q head_dim must be concrete");
            1.0 / (head_dim as f64).sqrt()
        });

        // Fill constants shared by the causal mask, the boolean mask and the final zeroing.
        let neg_large = Tensor::const_(ConstValue::min(scores_dtype.base()), scores_dtype.clone());
        let zero = Tensor::const_(ConstValue::zero(scores_dtype.base()), scores_dtype.clone());

        let kt = key.try_transpose(-1, -2)?;
        let mut scores = self.matmul(&kt)?;
        scores = scores.try_mul(&Tensor::const_(scale_val, scores_dtype.clone()))?;

        if is_causal {
            let q_len = q_shape[q_shape.len() - 2].as_const().expect("Q seq_len must be concrete");
            let k_len = k_shape[k_shape.len() - 2].as_const().expect("K seq_len must be concrete");
            let causal = Tensor::full(&[q_len, k_len], true, DType::Bool)?.tril(0)?;
            scores = scores.where_(&causal, &neg_large)?;
        }

        // A boolean mask is converted to an additive one (and reused after the softmax);
        // a mask of any other dtype is added to the scores directly.
        let bool_mask = attn_mask.filter(|m| m.uop().dtype() == DType::Bool);
        if let Some(mask) = bool_mask {
            scores = scores.try_add(&neg_large.where_(mask, &zero)?)?;
        } else if let Some(mask) = attn_mask {
            scores = scores.try_add(mask)?;
        }

        if let Some(cap) = softcap.filter(|&cap| cap > 0.0) {
            let cap_t = Tensor::const_(cap, scores_dtype.clone());
            scores = scores.try_div(&cap_t)?.tanh()?.try_mul(&cap_t)?;
        }

        let mut attn_weights = scores.softmax(-1isize)?;
        if let Some(mask) = bool_mask {
            attn_weights = zero.where_(mask, &attn_weights)?;
        }
        attn_weights.matmul(value)
    }
}