1use candle_core::*;
2use std::collections::HashMap;
3use std::sync::{OnceLock, RwLock};
4
5#[derive(Default, Debug)]
6pub struct MaskCache {
7 masks: RwLock<HashMap<usize, AttentionMask>>,
8}
9
10impl MaskCache {
11 pub fn get_mask(
12 &self,
13 seq_len: usize,
14 seqlen_offset: usize,
15 device: &Device,
16 ) -> Result<AttentionMask> {
17 let mask = if let Some(mask) = {
18 let masks = self.masks.read().unwrap();
19 masks.get(&seq_len).cloned()
20 } {
21 mask
22 } else {
23 let mask: Vec<_> = (0..seq_len)
24 .flat_map(|i| (0..seq_len).map(move |j| u8::from(j > i)))
25 .collect();
26 let mask = Tensor::from_slice(&mask, (seq_len, seq_len), device)?;
27 let mut masks = self.masks.write().unwrap();
28 let mask = AttentionMask {
29 mask,
30 on_true: OnceLock::new(),
31 };
32 masks.insert(seq_len, mask.clone());
33 mask
34 };
35
36 let mask_tensor = if seqlen_offset > 0 {
37 let mask0 = Tensor::zeros((seq_len, seqlen_offset), DType::U8, device)?;
39 Tensor::cat(&[&mask0, &mask.mask], D::Minus1)?
40 } else {
41 mask.mask
42 };
43
44 let mask_tensor = mask_tensor.unsqueeze(0)?.unsqueeze(0)?;
45
46 Ok(AttentionMask {
47 mask: mask_tensor,
48 on_true: mask.on_true,
49 })
50 }
51}
52
53#[derive(Clone, Debug)]
54pub struct AttentionMask {
55 pub mask: Tensor,
56 pub on_true: OnceLock<Tensor>,
57}
58
59impl AttentionMask {
60 pub fn forward(&self, attn_weights: &mut Tensor) -> candle_core::Result<()> {
61 let shape = attn_weights.shape();
62 let attention_mask = self.mask.broadcast_as(shape)?;
63 let on_true = match self.on_true.get() {
64 Some(on_true) => on_true.clone(),
65 None => {
66 let on_true =
67 Tensor::new(f32::NEG_INFINITY, attn_weights.device())?.broadcast_as(shape)?;
68 self.on_true.set(on_true.clone()).unwrap();
69 on_true
70 }
71 };
72 *attn_weights = attention_mask.where_cond(&on_true, attn_weights)?;
73 Ok(())
74 }
75}