//! Attention with learned "sink" logits: an extra per-head logit that joins
//! the softmax so probability mass can be absorbed without attending to any
//! token. Fused CUDA/Metal kernels are used when available, with a naive CPU
//! fallback otherwise.

use candle_core::{Result, Tensor};
use mistralrs_quant::MatMul;

use crate::{
    attention::{repeat_kv, SdpaParams},
    pipeline::text_models_inputs_processor::FlashParams,
};
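
/// Attention with learned sink logits, dispatching on the batch layout.
///
/// Routes to the varlen path when the batch is packed (batch size > 1 and
/// cumulative sequence lengths have been uploaded for this device), otherwise
/// to the regular padded path.
///
/// A minimal sketch of a padded (non-varlen) call; the shapes and the
/// per-head `sinks` layout here are assumptions for illustration:
///
/// ```ignore
/// // q:     (b_sz, n_heads, q_len, head_dim)
/// // k, v:  (b_sz, n_kv_heads, kv_len, head_dim)
/// // sinks: one learned sink logit per attention head
/// let out = sinks_attn(&q, &k, &v, &sinks, mask.as_ref(), None, &sdpa_params)?;
/// ```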
#[allow(unused_variables, clippy::too_many_arguments)]
pub(crate) fn sinks_attn(
q: &Tensor,
k: &Tensor,
v: &Tensor,
sinks: &Tensor,
mask: Option<&Tensor>,
flash_params: Option<&FlashParams>,
sdpa_params: &SdpaParams,
) -> Result<Tensor> {
let (b_sz, _n_heads, _q_len, _head_dim) = q.dims4()?;
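    // Fall back to a window size of 0 when no sliding window is configured.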
let window_size = sdpa_params.sliding_window.unwrap_or(0);
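    // Packed (varlen) batches are detected by the presence of cumulative
    // sequence lengths for this device in the flash params.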
let is_varlen = b_sz > 1
&& flash_params.is_some_and(|fp| {
fp.k_meta(sdpa_params.sliding_window)
.cumulative_seqlens
.contains_key(&q.device().location())
});
if is_varlen {
return sinks_attn_varlen(
q,
k,
v,
sinks,
flash_params.unwrap(),
sdpa_params,
window_size,
);
}
sinks_attn_regular(q, k, v, sinks, mask, sdpa_params, window_size)
}
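
/// Dense (padded) sink attention: prefers the fused CUDA/Metal
/// flash-attention kernels and falls back to the naive CPU path.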
#[allow(unused_variables)]
fn sinks_attn_regular(
q: &Tensor,
k: &Tensor,
v: &Tensor,
sinks: &Tensor,
mask: Option<&Tensor>,
sdpa_params: &SdpaParams,
window_size: usize,
) -> Result<Tensor> {
#[cfg(all(feature = "cuda", target_family = "unix"))]
if q.device().is_cuda() {
return mistralrs_paged_attn::flash_attn_sinks(
q,
k,
v,
Some(sinks),
sdpa_params.softmax_scale,
window_size,
);
}
#[cfg(feature = "metal")]
if q.device().is_metal() {
return mistralrs_quant::flash_attn_sinks_metal(
q,
k,
v,
Some(sinks),
sdpa_params.softmax_scale,
window_size,
);
}
sinks_attn_cpu(q, k, v, sinks, mask, sdpa_params)
}
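
/// Variable-length (packed) sink attention over a batch described by
/// cumulative sequence lengths, with CUDA/Metal fast paths and a CPU
/// fallback that loops over the individual sequences.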
#[allow(unused_variables)]
fn sinks_attn_varlen(
q: &Tensor,
k: &Tensor,
v: &Tensor,
sinks: &Tensor,
flash_params: &FlashParams,
sdpa_params: &SdpaParams,
window_size: usize,
) -> Result<Tensor> {
let device = q.device();
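    // Flatten a padded (1, n_kv_heads, total_k, head_dim) cache into the
    // packed (total_k, n_kv_heads, head_dim) layout the varlen kernels expect.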
let (k_packed, v_packed) = if k.dims().len() == 4 {
(
k.squeeze(0)?.transpose(0, 1)?,
v.squeeze(0)?.transpose(0, 1)?,
)
} else {
(k.clone(), v.clone())
};
let cu_seqlens_q = &flash_params.cumulative_seqlens_q[&device.location()];
let cu_seqlens_k = &flash_params
.k_meta(sdpa_params.sliding_window)
.cumulative_seqlens[&device.location()];
#[cfg(all(feature = "cuda", target_family = "unix"))]
if device.is_cuda() {
return mistralrs_paged_attn::flash_attn_sinks_varlen(
q,
&k_packed,
&v_packed,
Some(sinks),
cu_seqlens_q,
cu_seqlens_k,
sdpa_params.softmax_scale,
window_size,
);
}
#[cfg(feature = "metal")]
if device.is_metal() {
return mistralrs_quant::flash_attn_sinks_varlen_metal(
q,
&k_packed,
&v_packed,
Some(sinks),
cu_seqlens_q,
cu_seqlens_k,
sdpa_params.softmax_scale,
window_size,
);
}
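    // No fused kernel available: materialize the offsets and loop on the CPU.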
let cu_q_vec: Vec<u32> = cu_seqlens_q.to_vec1()?;
let cu_k_vec: Vec<u32> = cu_seqlens_k.to_vec1()?;
sinks_attn_cpu_varlen(
q,
&k_packed,
&v_packed,
sinks,
sdpa_params,
&cu_q_vec,
&cu_k_vec,
)
}
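
/// Naive reference implementation: scaled `q @ k^T`, a sink-aware softmax,
/// then the weighted sum over `v`.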
fn sinks_attn_cpu(
q: &Tensor,
k: &Tensor,
v: &Tensor,
sinks: &Tensor,
mask: Option<&Tensor>,
sdpa_params: &SdpaParams,
) -> Result<Tensor> {
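    // Expand grouped K/V heads so they line up one-to-one with query heads.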
let k = repeat_kv(k.clone(), sdpa_params.n_kv_groups)?;
let v = repeat_kv(v.clone(), sdpa_params.n_kv_groups)?;
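    // Scale during the matmul, then fold the per-head sink logits into the
    // softmax normalizer before weighting the values.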
let att = MatMul.matmul_affine_mul(q, &k.t()?, sdpa_params.softmax_scale.into())?;
let att = mistralrs_quant::softmax_with_sinks(&att, sinks, mask)?;
MatMul.matmul(&att, &v)
}
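
/// CPU fallback for the varlen path: slices every sequence out of the packed
/// K/V buffers, runs the dense kernel per sequence, and right-pads each
/// output back to `max_q` so the batch can be re-concatenated.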
fn sinks_attn_cpu_varlen(
q: &Tensor,
k_packed: &Tensor,
v_packed: &Tensor,
sinks: &Tensor,
sdpa_params: &SdpaParams,
cu_q: &[u32],
cu_k: &[u32],
) -> Result<Tensor> {
let (b_sz, num_heads, max_q, head_dim) = q.dims4()?;
let device = q.device();
let mut outputs = Vec::with_capacity(b_sz);
for i in 0..b_sz {
let q_len = (cu_q[i + 1] - cu_q[i]) as usize;
let kv_start = cu_k[i] as usize;
let kv_len = (cu_k[i + 1] - cu_k[i]) as usize;
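        // Slice this sequence's (possibly padded) queries and its span of the
        // packed K/V buffers.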
let qi = q.narrow(0, i, 1)?;
let qi = qi.narrow(2, 0, q_len)?;
let ki = k_packed
.narrow(0, kv_start, kv_len)?
.transpose(0, 1)?
.unsqueeze(0)?;
let vi = v_packed
.narrow(0, kv_start, kv_len)?
.transpose(0, 1)?
.unsqueeze(0)?;
let oi = sinks_attn_cpu(&qi, &ki, &vi, sinks, None, sdpa_params)?;
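        // Right-pad with zeros so every sequence is `max_q` long again before
        // concatenating along the batch dimension.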
if q_len < max_q {
let pad = Tensor::zeros((1, num_heads, max_q - q_len, head_dim), qi.dtype(), device)?;
outputs.push(Tensor::cat(&[&oi, &pad], 2)?);
} else {
outputs.push(oi);
}
}
Tensor::cat(&outputs, 0)
}