// kalosm_common/mask.rs

1use candle_core::*;
2use std::collections::HashMap;
3use std::sync::{OnceLock, RwLock};
4
5#[derive(Default, Debug)]
6pub struct MaskCache {
7    masks: RwLock<HashMap<usize, AttentionMask>>,
8}
9
10impl MaskCache {
11    pub fn get_mask(
12        &self,
13        seq_len: usize,
14        seqlen_offset: usize,
15        device: &Device,
16    ) -> Result<AttentionMask> {
17        let mask = if let Some(mask) = {
18            let masks = self.masks.read().unwrap();
19            masks.get(&seq_len).cloned()
20        } {
21            mask
22        } else {
23            let mask: Vec<_> = (0..seq_len)
24                .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
25                .collect();
26            let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
27            let mut masks = self.masks.write().unwrap();
28            let mask = AttentionMask {
29                mask,
30                on_true: OnceLock::new(),
31            };
32            masks.insert(seq_len, mask.clone());
33            mask
34        };
35
36        let mask_tensor = if seqlen_offset > 0 {
37            // If this isn't the first token, we need to pad the mask with zeros for the previous tokens.
38            let mask0 = Tensor::zeros((seq_len, seqlen_offset), DType::U8, device)?;
39            Tensor::cat(&[&mask0, &mask.mask], D::Minus1)?
40        } else {
41            mask.mask
42        };
43
44        let mask_tensor = mask_tensor.unsqueeze(0)?.unsqueeze(0)?;
45
46        Ok(AttentionMask {
47            mask: mask_tensor,
48            on_true: mask.on_true,
49        })
50    }
51}
52
53#[derive(Clone, Debug)]
54pub struct AttentionMask {
55    pub mask: Tensor,
56    pub on_true: OnceLock<Tensor>,
57}
58
59impl AttentionMask {
60    pub fn forward(&self, attn_weights: &mut Tensor) -> candle_core::Result<()> {
61        let shape = attn_weights.shape();
62        let attention_mask = self.mask.broadcast_as(shape)?;
63        let on_true = match self.on_true.get() {
64            Some(on_true) => on_true.clone(),
65            None => {
66                let on_true =
67                    Tensor::new(f32::NEG_INFINITY, attn_weights.device())?.broadcast_as(shape)?;
68                self.on_true.set(on_true.clone()).unwrap();
69                on_true
70            }
71        };
72        *attn_weights = attention_mask.where_cond(&on_true, attn_weights)?;
73        Ok(())
74    }
75}