#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use std::ops::Add;
use candle_core::{DType, Device, Result, Tensor, WithDType};
use crate::pipeline::KvCache;
/// Stateless builder for causal (autoregressive) attention masks, including
/// sliding-window and chunked variants.
pub struct CausalMasker;
/// Replace every element of `xs` where `mask` is nonzero with `value`,
/// leaving the remaining elements untouched.
///
/// `mask` is broadcast to `xs`'s shape and the fill value is cast to `xs`'s
/// dtype before selection.
pub fn masked_fill<D: WithDType>(xs: &Tensor, mask: &Tensor, value: D) -> Result<Tensor> {
    // Constant tensor holding the fill value at every position.
    let fill = Tensor::full(value, xs.shape(), xs.device())?.to_dtype(xs.dtype())?;
    // where_cond: nonzero mask -> fill, zero mask -> original element.
    mask.broadcast_as(xs.shape())?.where_cond(&fill, xs)
}
/// Placeholder cache for callers that keep no KV cache; always reports length 0.
pub struct NotACache;
/// Abstraction over the different KV-cache representations used in this crate,
/// exposing how many tokens are already cached (used to size/offset masks).
pub trait PastKvLenCache {
    /// Number of past key/value positions currently stored.
    fn get_past_kv_len(&self) -> Result<usize>;
}
impl PastKvLenCache for NotACache {
    fn get_past_kv_len(&self) -> Result<usize> {
        Ok(0)
    }
}
impl PastKvLenCache for Vec<KvCache> {
    /// Longest current sequence length across all per-layer caches;
    /// 0 when the Vec is empty.
    fn get_past_kv_len(&self) -> Result<usize> {
        let mut longest = 0;
        for layer in self {
            longest = longest.max(layer.current_seq_len());
        }
        Ok(longest)
    }
}
impl PastKvLenCache for &[usize] {
    /// Returns the shared past-KV length when every entry agrees, and falls
    /// back to 0 when entries disagree or the slice is empty.
    ///
    /// Fix: the original checked `windows(2).all(..)` — vacuously true for an
    /// empty slice — and then indexed `self[0]`, panicking on empty input.
    /// `first()` makes the empty case return `Ok(0)` instead.
    fn get_past_kv_len(&self) -> Result<usize> {
        match self.first() {
            Some(&len) if self.iter().all(|&l| l == len) => Ok(len),
            _ => Ok(0),
        }
    }
}
impl PastKvLenCache for Vec<Option<(Tensor, Tensor)>> {
    /// Cached sequence length read from the first layer's K tensor, dim 2
    /// (assumes a `(b, h, seq, head_dim)` layout — TODO confirm at call sites).
    /// Returns 0 when the first entry is `None`.
    ///
    /// Fix: the original indexed `self[0]` (panic on an empty Vec) and used
    /// `is_none()` + `unwrap()`; a single pattern match handles both cases.
    fn get_past_kv_len(&self) -> Result<usize> {
        match self.first() {
            Some(Some((k_cache, _))) => Ok(k_cache.dims()[2]),
            _ => Ok(0),
        }
    }
}
impl CausalMasker {
    /// Build a boolean (u8) causal mask of shape `(tgt_len, tgt_len + past_kv_len)`:
    /// 1 where the key must be masked (it lies in the query's future), 0 where
    /// attention is allowed.
    fn make_mask(&self, tgt_len: usize, past_kv_len: usize, device: &Device) -> Result<Tensor> {
        let offset = tgt_len + past_kv_len;
        // `j + tgt_len > i + offset` simplifies to `j > i + past_kv_len`:
        // mask keys strictly after the query's absolute position.
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| (0..offset).map(move |j| u8::from(j + tgt_len > i + offset)))
            .collect();
        Tensor::from_slice(&mask, (tgt_len, offset), device)
    }
    /// Build a u8 chunked-causal mask of shape `(tgt_len, tgt_len + past_kv_len)`:
    /// a query may only attend to keys inside its own `chunk_size` chunk that are
    /// not after it. 1 = masked, 0 = visible.
    fn make_mask_chunked(
        &self,
        tgt_len: usize,
        past_kv_len: usize,
        chunk_size: usize,
        device: &Device,
    ) -> Result<Tensor> {
        let offset = tgt_len + past_kv_len;
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                (0..offset).map(move |j| {
                    // Positions already in the KV cache are left unmasked (0).
                    if j < past_kv_len {
                        return 0;
                    }
                    // Re-index the key relative to the new tokens.
                    let j_adj = j - past_kv_len;
                    let i_block = i / chunk_size;
                    let j_block = j_adj / chunk_size;
                    let block_pos = (i_block as isize - j_block as isize).abs();
                    let token_pos = j_adj as isize - i as isize;
                    // Visible iff same chunk (block_pos == 0) and the key is not
                    // after the query (token_pos <= 0); otherwise emit 1 (masked).
                    1 - u8::from((block_pos == 0) && (token_pos <= 0))
                })
            })
            .collect();
        Tensor::from_slice(&mask, (tgt_len, offset), device)
    }
    /// Build an additive sliding-window causal mask of shape
    /// `(tgt_len, tgt_len + past_kv_len)` in `dtype`: 0 where attention is
    /// allowed, -inf where it is not. A query at absolute position `q` sees
    /// keys `j` with `q - sliding_window <= j <= q`.
    fn make_swa_mask(
        &self,
        tgt_len: usize,
        past_kv_len: usize,
        sliding_window: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<Tensor> {
        let total_kv_len = tgt_len + past_kv_len;
        let mask: Vec<_> = (0..tgt_len)
            .flat_map(|i| {
                // Absolute position of this query within the full sequence.
                let q_pos = past_kv_len + i;
                (0..total_kv_len).map(move |j| {
                    // Mask future keys (j > q_pos) and keys that fell out of
                    // the sliding window (j + sliding_window < q_pos).
                    if j > q_pos || j + sliding_window < q_pos {
                        f32::NEG_INFINITY
                    } else {
                        0.
                    }
                })
            })
            .collect();
        Tensor::from_slice(&mask, (tgt_len, total_kv_len), device)?.to_dtype(dtype)
    }
    /// Expand a 2D `(bs, src_len)` attention mask to an additive 4D mask of
    /// shape `(bs, 1, tgt_len, src_len)` in `dtype`, with `f32::MIN` at masked
    /// positions and 0 elsewhere. Assumes the input uses 1 = attend, 0 = masked
    /// (the usual HF convention) — TODO confirm at call sites.
    pub fn expand_mask(
        &self,
        mask: &Tensor,
        dtype: DType,
        tgt_len: Option<usize>,
    ) -> Result<Tensor> {
        let (bs, src_len) = mask.dims2()?;
        // (bs, src_len) -> (bs, 1, 1, src_len) -> (bs, 1, tgt_len, src_len).
        let expanded_mask = mask.unsqueeze(1)?.unsqueeze(1)?;
        let expanded_mask = expanded_mask
            .expand((bs, 1, tgt_len.unwrap_or(src_len), src_len))?
            .to_dtype(dtype)?;
        // Invert (1 - mask) so masked positions become nonzero, then fill those
        // positions with f32::MIN via the u8 view of the inverted mask.
        let inverted_mask = expanded_mask.neg()?.add(1.0f64)?;
        masked_fill(
            &inverted_mask,
            &inverted_mask.to_dtype(DType::U8)?,
            f32::MIN,
        )
    }
    /// Sequence length already stored in the KV cache, read from the first
    /// layer's K tensor (dim 2; presumably a `(b, h, seq, head_dim)` layout —
    /// verify against callers). Returns 0 when the first entry is `None`.
    /// NOTE(review): indexes `cache[0]`, so this panics on an empty slice.
    pub fn calculate_past_kv_len(
        &self,
        cache: &[Option<(Tensor, Tensor)>],
    ) -> candle_core::Result<usize> {
        let kv_cache_1 = &cache[0];
        if kv_cache_1.is_none() {
            return Ok(0);
        }
        let k_cache_1 = &kv_cache_1.as_ref().unwrap().0;
        Ok(k_cache_1.dims()[2])
    }
    /// Build the full additive causal mask (-inf blocked / 0 allowed, in
    /// `dtype`) for `input_ids` of shape `(b_sz, tgt_len)`.
    ///
    /// Returns `None` for single-token decode (no mask needed) and a `(1, 1)`
    /// zero placeholder when flash attention on CUDA handles causality itself.
    pub fn make_causal_mask_matrix(
        &self,
        input_ids: &Tensor,
        cache: &dyn PastKvLenCache,
        dtype: DType,
        _n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        let past_kv_len = cache.get_past_kv_len()?;
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        // A single new token can always attend to everything before it.
        if tgt_len == 1 {
            return Ok(None);
        }
        if crate::using_flash_attn() && input_ids.device().is_cuda() {
            // Flash attention applies causal masking internally; return a
            // placeholder so callers still observe `Some(_)`.
            return Ok(Some(Tensor::zeros((1, 1), dtype, input_ids.device())?));
        }
        let mut causal_mask = self
            .make_mask(tgt_len, past_kv_len, input_ids.device())?
            .to_dtype(DType::U8)?;
        let zero = Tensor::new(0.0f32, input_ids.device())?;
        causal_mask = {
            // NOTE(review): broadcasting the mask to its own dims is a no-op.
            let mut mask =
                causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
            // Convert the u8 mask to an additive bias: -inf where masked,
            // 0 where visible.
            mask = masked_fill(
                &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
                &mask,
                f32::NEG_INFINITY,
            )?;
            mask
        };
        Ok(Some(causal_mask))
    }
    /// Same as [`Self::make_causal_mask_matrix`] but without the flash-attention
    /// fast path: always materializes the additive -inf/0 bias tensor.
    pub fn make_causal_mask_as_attn_bias(
        &self,
        input_ids: &Tensor,
        cache: &dyn PastKvLenCache,
        dtype: DType,
    ) -> Result<Option<Tensor>> {
        let past_kv_len = cache.get_past_kv_len()?;
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        // Single-token decode needs no mask.
        if tgt_len == 1 {
            return Ok(None);
        }
        let mut causal_mask = self
            .make_mask(tgt_len, past_kv_len, input_ids.device())?
            .to_dtype(DType::U8)?;
        let zero = Tensor::new(0.0f32, input_ids.device())?;
        causal_mask = {
            // NOTE(review): broadcasting the mask to its own dims is a no-op.
            let mask = causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
            // u8 mask -> additive bias: -inf where masked, 0 where visible.
            masked_fill(
                &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
                &mask,
                f32::NEG_INFINITY,
            )?
        };
        Ok(Some(causal_mask))
    }
    /// Sliding-window variant of [`Self::make_causal_mask_as_attn_bias`];
    /// `sliding_window == None` defers to the plain causal version.
    pub fn make_sliding_window_causal_mask_as_attn_bias(
        &self,
        input_ids: &Tensor,
        cache: &dyn PastKvLenCache,
        sliding_window: Option<usize>,
        dtype: DType,
    ) -> Result<Option<Tensor>> {
        if sliding_window.is_none() {
            return self.make_causal_mask_as_attn_bias(input_ids, cache, dtype);
        }
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        let sliding_window = sliding_window.unwrap();
        let past_kv_len = cache.get_past_kv_len()?;
        // Single-token decode needs no mask.
        if tgt_len == 1 {
            return Ok(None);
        }
        Ok(Some(self.make_swa_mask(
            tgt_len,
            past_kv_len,
            sliding_window,
            input_ids.device(),
            dtype,
        )?))
    }
    /// Chunked-attention variant of [`Self::make_causal_mask_matrix`]: the
    /// additive bias restricts each query to its own `chunk_size` chunk
    /// (see [`Self::make_mask_chunked`]). No flash-attention fast path here.
    pub fn make_chunked_mask_matrix(
        &self,
        input_ids: &Tensor,
        chunk_size: usize,
        cache: &dyn PastKvLenCache,
        dtype: DType,
        _n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        let past_kv_len = cache.get_past_kv_len()?;
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        // Single-token decode needs no mask.
        if tgt_len == 1 {
            return Ok(None);
        }
        let mut causal_mask = self
            .make_mask_chunked(tgt_len, past_kv_len, chunk_size, input_ids.device())?
            .to_dtype(DType::U8)?;
        let zero = Tensor::new(0.0f32, input_ids.device())?;
        causal_mask = {
            // NOTE(review): broadcasting the mask to its own dims is a no-op.
            let mut mask =
                causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
            // u8 mask -> additive bias: -inf where masked, 0 where visible.
            mask = masked_fill(
                &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
                &mask,
                f32::NEG_INFINITY,
            )?;
            mask
        };
        Ok(Some(causal_mask))
    }
    /// Sliding-window variant of [`Self::make_causal_mask_matrix`].
    /// `None` window defers to the plain causal version; on CUDA with flash
    /// attention a `(1, 1)` placeholder is returned for multi-token prefill.
    pub fn make_sliding_window_causal_mask_matrix(
        &self,
        input_ids: &Tensor,
        cache: &dyn PastKvLenCache,
        sliding_window: Option<usize>,
        dtype: DType,
        n_attn_heads: usize,
    ) -> Result<Option<Tensor>> {
        if sliding_window.is_none() {
            return self.make_causal_mask_matrix(input_ids, cache, dtype, n_attn_heads);
        }
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        let sliding_window = sliding_window.unwrap();
        if tgt_len > 1 && crate::using_flash_attn() && input_ids.device().is_cuda() {
            // Flash attention handles windowed causality internally.
            return Ok(Some(Tensor::zeros((1, 1), dtype, input_ids.device())?));
        }
        let past_kv_len = cache.get_past_kv_len()?;
        // Single-token decode needs no mask.
        if tgt_len == 1 {
            return Ok(None);
        }
        Ok(Some(self.make_swa_mask(
            tgt_len,
            past_kv_len,
            sliding_window,
            input_ids.device(),
            dtype,
        )?))
    }
    /// Apply a 1/0 mask to attention scores: positions where `mask` is nonzero
    /// are replaced by `neg_inf` (moved/cast to `att`'s device and dtype), the
    /// rest keep their score. A `None` mask is a no-op.
    pub fn apply_mask_one_and_zero(
        &self,
        mask: &Option<Tensor>,
        att: Tensor,
        neg_inf: &Tensor,
    ) -> Result<Tensor> {
        match mask {
            None => Ok(att),
            Some(mask) => {
                let mask = mask.broadcast_as(att.shape())?;
                mask.where_cond(
                    &neg_inf
                        .to_device(att.device())?
                        .to_dtype(att.dtype())?
                        .broadcast_as(att.dims())?,
                    &att,
                )
            }
        }
    }
}
/// Builds non-causal (bidirectional) attention masks, optionally restricted
/// to a sliding window.
pub struct BidirectionalMasker;
impl BidirectionalMasker {
    /// Additive `(tgt_len, tgt_len)` sliding-window mask in `dtype`:
    /// -inf where two positions are `sliding_window` or more apart, 0 otherwise.
    fn make_swa_mask(
        &self,
        tgt_len: usize,
        sliding_window: usize,
        device: &Device,
        dtype: DType,
    ) -> Result<Tensor> {
        let mut values = Vec::with_capacity(tgt_len * tgt_len);
        for row in 0..tgt_len {
            for col in 0..tgt_len {
                let blocked = row.abs_diff(col) >= sliding_window;
                values.push(if blocked { f32::NEG_INFINITY } else { 0. });
            }
        }
        Tensor::from_slice(&values, (tgt_len, tgt_len), device)?.to_dtype(dtype)
    }
    /// All-zero `(tgt_len, tgt_len)` bias: every position may attend to every
    /// other. With flash attention on CUDA, a `(1, 1)` placeholder is returned.
    pub fn make_mask(&self, input_ids: &Tensor, dtype: DType) -> Result<Tensor> {
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        if crate::using_flash_attn() && input_ids.device().is_cuda() {
            return Tensor::zeros((1, 1), dtype, input_ids.device());
        }
        Tensor::zeros((tgt_len, tgt_len), dtype, input_ids.device())
    }
    /// Bidirectional mask limited to a `sliding_window`-wide band. With flash
    /// attention on CUDA, a `(1, 1)` placeholder is returned instead.
    pub fn make_sliding_mask(
        &self,
        input_ids: &Tensor,
        dtype: DType,
        sliding_window: usize,
    ) -> Result<Tensor> {
        let (_b_sz, tgt_len) = input_ids.dims2()?;
        if crate::using_flash_attn() && input_ids.device().is_cuda() {
            return Tensor::zeros((1, 1), dtype, input_ids.device());
        }
        self.make_swa_mask(tgt_len, sliding_window, input_ids.device(), dtype)
    }
}