Skip to main content

hanzo_engine/
layers_masker.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::ops::Add;
4
5use hanzo_ml::{DType, Device, Result, Tensor, WithDType};
6
7use crate::pipeline::KvCache;
8
9// https://github.com/huggingface/transformers/blob/main/src/transformers/modeling_attn_mask_utils.py
/// Stateless helper that builds causal attention masks.
///
/// It carries no data; every method is invoked on the unit value
/// (`CausalMasker.make_causal_mask(..)`).
#[derive(Debug, Clone, Copy, Default)]
pub struct CausalMasker;
11
/// Configuration for [`CausalMasker::make_causal_mask`].
///
/// The default value requests full causal attention (`sliding_window` is
/// `None`) and does not force a custom mask (`force_custom` is `false`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CausalMaskConfig {
    /// Sliding window size. `None` = full causal attention.
    pub sliding_window: Option<usize>,
    /// Force `AttentionMask::Custom` even when flash attention is available.
    /// Set to `true` when you need the real mask tensor (e.g. bidirectional
    /// vision overrides, head_dim > 256 eager fallback).
    pub force_custom: bool,
}
22
23// an external reference implementation
24/// xs are on false (0), value is on true (1)
25pub fn masked_fill<D: WithDType>(xs: &Tensor, mask: &Tensor, value: D) -> Result<Tensor> {
26    let on_true = Tensor::full(value, xs.shape(), xs.device())?.to_dtype(xs.dtype())?;
27    let on_false = xs;
28    let res = mask
29        .broadcast_as(xs.shape())?
30        .where_cond(&on_true, on_false)?;
31    Ok(res)
32}
33
/// Marker for "there is no cache".
///
/// Its [`PastKvLenCache`] implementation always reports a past length of 0.
#[derive(Debug, Clone, Copy, Default)]
pub struct NotACache;
35
36pub trait PastKvLenCache {
37    fn get_past_kv_len(&self) -> Result<usize>;
38}
39
40impl PastKvLenCache for NotACache {
41    fn get_past_kv_len(&self) -> Result<usize> {
42        Ok(0)
43    }
44}
45
46impl PastKvLenCache for Vec<KvCache> {
47    fn get_past_kv_len(&self) -> Result<usize> {
48        Ok(self.iter().map(KvCache::current_seq_len).max().unwrap_or(0))
49    }
50}
51
52impl PastKvLenCache for &[usize] {
53    fn get_past_kv_len(&self) -> Result<usize> {
54        if self.windows(2).all(|w| w[0] == w[1]) {
55            Ok(self[0])
56        } else {
57            Ok(0)
58        }
59    }
60}
61
62impl PastKvLenCache for Vec<Option<(Tensor, Tensor)>> {
63    fn get_past_kv_len(&self) -> Result<usize> {
64        let kv_cache_1 = &self[0];
65        if kv_cache_1.is_none() {
66            return Ok(0);
67        }
68        let k_cache_1 = &kv_cache_1.as_ref().unwrap().0;
69        Ok(k_cache_1.dims()[2])
70    }
71}
72
73impl CausalMasker {
74    fn make_mask(&self, tgt_len: usize, past_kv_len: usize, device: &Device) -> Result<Tensor> {
75        let offset = tgt_len + past_kv_len;
76        let mask: Vec<_> = (0..tgt_len)
77            .flat_map(|i| (0..offset).map(move |j| u8::from(j + tgt_len > i + offset)))
78            .collect();
79        Tensor::from_slice(&mask, (tgt_len, offset), device)
80    }
81
82    fn make_mask_chunked(
83        &self,
84        tgt_len: usize,
85        past_kv_len: usize,
86        chunk_size: usize,
87        device: &Device,
88    ) -> Result<Tensor> {
89        let offset = tgt_len + past_kv_len;
90        let mask: Vec<_> = (0..tgt_len)
91            .flat_map(|i| {
92                (0..offset).map(move |j| {
93                    // For past key-value positions
94                    if j < past_kv_len {
95                        return 0;
96                    }
97
98                    // Adjust j to account for past_kv_len
99                    let j_adj = j - past_kv_len;
100
101                    // Calculate block position (equivalent to block_pos)
102                    let i_block = i / chunk_size;
103                    let j_block = j_adj / chunk_size;
104                    let block_pos = (i_block as isize - j_block as isize).abs();
105
106                    // Calculate token position (equivalent to token_pos)
107                    let token_pos = j_adj as isize - i as isize;
108
109                    // Apply mask conditions: same block and causal
110                    1 - u8::from((block_pos == 0) && (token_pos <= 0))
111                })
112            })
113            .collect();
114
115        Tensor::from_slice(&mask, (tgt_len, offset), device)
116    }
117
118    fn make_swa_mask(
119        &self,
120        tgt_len: usize,
121        past_kv_len: usize,
122        sliding_window: usize,
123        device: &Device,
124        dtype: DType,
125    ) -> Result<Tensor> {
126        let total_kv_len = tgt_len + past_kv_len;
127        let mask: Vec<_> = (0..tgt_len)
128            .flat_map(|i| {
129                let q_pos = past_kv_len + i;
130                (0..total_kv_len).map(move |j| {
131                    // HF's sliding causal mask uses an exclusive lower bound
132                    // (`kv_idx > q_idx - sliding_window`), so the current
133                    // token plus its visible history total exactly
134                    // `sliding_window` positions.
135                    let too_old = q_pos >= sliding_window && j <= q_pos - sliding_window;
136                    if j > q_pos || too_old {
137                        f32::NEG_INFINITY
138                    } else {
139                        0.
140                    }
141                })
142            })
143            .collect();
144        Tensor::from_slice(&mask, (tgt_len, total_kv_len), device)?.to_dtype(dtype)
145    }
146
147    /// Expands a mask from (bs, seq_len) to (bs, 1, tgt_len, seq_len)
148    /// If tgt_len is None, use seq_len
149    pub fn expand_mask(
150        &self,
151        mask: &Tensor,
152        dtype: DType,
153        tgt_len: Option<usize>,
154    ) -> Result<Tensor> {
155        let (bs, src_len) = mask.dims2()?;
156
157        let expanded_mask = mask.unsqueeze(1)?.unsqueeze(1)?;
158        let expanded_mask = expanded_mask
159            .expand((bs, 1, tgt_len.unwrap_or(src_len), src_len))?
160            .to_dtype(dtype)?;
161
162        let inverted_mask = expanded_mask.neg()?.add(1.0f64)?;
163        masked_fill(
164            &inverted_mask,
165            &inverted_mask.to_dtype(DType::U8)?,
166            f32::MIN,
167        )
168    }
169
170    pub fn calculate_past_kv_len(
171        &self,
172        cache: &[Option<(Tensor, Tensor)>],
173    ) -> hanzo_ml::Result<usize> {
174        let kv_cache_1 = &cache[0];
175        if kv_cache_1.is_none() {
176            return Ok(0);
177        }
178        let k_cache_1 = &kv_cache_1.as_ref().unwrap().0;
179        Ok(k_cache_1.dims()[2])
180    }
181
182    /// Build a causal attention mask.
183    ///
184    /// - Returns `AttentionMask::None` for single-token decode.
185    /// - Returns `AttentionMask::CausalFlash` on CUDA+flash when
186    ///   `force_custom` is false (flash handles causality internally).
187    /// - Returns `AttentionMask::Custom` with a real mask tensor otherwise.
188    pub fn make_causal_mask(
189        &self,
190        input_ids: &Tensor,
191        cache: &dyn PastKvLenCache,
192        dtype: DType,
193        cfg: &CausalMaskConfig,
194    ) -> Result<crate::attention::AttentionMask> {
195        let past_kv_len = cache.get_past_kv_len()?;
196        let (_b_sz, tgt_len) = input_ids.dims2()?;
197        if tgt_len == 1 {
198            return Ok(crate::attention::AttentionMask::None);
199        }
200
201        if !cfg.force_custom && crate::using_flash_attn() && input_ids.device().is_cuda() {
202            return Ok(crate::attention::AttentionMask::CausalFlash);
203        }
204
205        let mask = if let Some(sw) = cfg.sliding_window {
206            self.make_swa_mask(tgt_len, past_kv_len, sw, input_ids.device(), dtype)?
207        } else {
208            let causal = self
209                .make_mask(tgt_len, past_kv_len, input_ids.device())?
210                .to_dtype(DType::U8)?;
211            let zero = Tensor::new(0.0f32, input_ids.device())?;
212            masked_fill(
213                &zero
214                    .to_dtype(dtype)?
215                    .broadcast_as((causal.dims()[0], causal.dims()[1]))?,
216                &causal,
217                f32::NEG_INFINITY,
218            )?
219        };
220
221        Ok(crate::attention::AttentionMask::Custom(mask))
222    }
223
224    pub fn make_chunked_mask_matrix(
225        &self,
226        input_ids: &Tensor,
227        chunk_size: usize,
228        cache: &dyn PastKvLenCache,
229        dtype: DType,
230        _n_attn_heads: usize,
231    ) -> Result<Option<Tensor>> {
232        let past_kv_len = cache.get_past_kv_len()?;
233        let (_b_sz, tgt_len) = input_ids.dims2()?;
234        if tgt_len == 1 {
235            return Ok(None);
236        }
237
238        let mut causal_mask = self
239            .make_mask_chunked(tgt_len, past_kv_len, chunk_size, input_ids.device())?
240            .to_dtype(DType::U8)?;
241
242        let zero = Tensor::new(0.0f32, input_ids.device())?;
243        causal_mask = {
244            let mut mask =
245                causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
246            // Mask: 1 means use from x (add 0.0), 0 means mask out (add -inf)
247            mask = masked_fill(
248                &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
249                &mask,
250                f32::NEG_INFINITY,
251            )?;
252            mask
253        };
254
255        Ok(Some(causal_mask))
256    }
257
258    pub fn apply_mask_one_and_zero(
259        &self,
260        mask: &Option<Tensor>,
261        att: Tensor,
262        neg_inf: &Tensor,
263    ) -> Result<Tensor> {
264        match mask {
265            None => Ok(att),
266            Some(mask) => {
267                let mask = mask.broadcast_as(att.shape())?;
268                mask.where_cond(
269                    &neg_inf
270                        .to_device(att.device())?
271                        .to_dtype(att.dtype())?
272                        .broadcast_as(att.dims())?,
273                    &att,
274                )
275            }
276        }
277    }
278}
279
/// Stateless helper that builds bidirectional (non-causal) attention masks.
///
/// It carries no data; every method is invoked on the unit value.
#[derive(Debug, Clone, Copy, Default)]
pub struct BidirectionalMasker;
281
282impl BidirectionalMasker {
283    fn make_swa_mask(
284        &self,
285        tgt_len: usize,
286        sliding_window: usize,
287        device: &Device,
288        dtype: DType,
289    ) -> Result<Tensor> {
290        let mask: Vec<_> = (0..tgt_len)
291            .flat_map(|i| {
292                (0..tgt_len).map(move |j| {
293                    // https://github.com/huggingface/transformers/blob/a0bf5a82eebf88ee9f52145be427f6f1541329f6/src/transformers/models/gemma3/modeling_gemma3.py#L478
294                    // A token can attend to any other token if their absolute distance is within the (exclusive) sliding window size (distance < sliding_window)."
295                    if (i as isize - j as isize).unsigned_abs() >= sliding_window {
296                        f32::NEG_INFINITY
297                    } else {
298                        0.
299                    }
300                })
301            })
302            .collect();
303        let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
304        mask.to_dtype(dtype)
305    }
306
307    pub fn make_mask(&self, input_ids: &Tensor, dtype: DType) -> Result<Tensor> {
308        let (_b_sz, tgt_len) = input_ids.dims2()?;
309
310        // Do not make any -inf, bidirectional (all-zeros) mask
311        let mask = Tensor::zeros((tgt_len, tgt_len), dtype, input_ids.device())?;
312
313        Ok(mask)
314    }
315    pub fn make_sliding_mask(
316        &self,
317        input_ids: &Tensor,
318        dtype: DType,
319        sliding_window: usize,
320    ) -> Result<Tensor> {
321        let (_b_sz, tgt_len) = input_ids.dims2()?;
322
323        let mask = self.make_swa_mask(tgt_len, sliding_window, input_ids.device(), dtype)?;
324
325        Ok(mask)
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Map each mask entry to `true` when it is finite (visible) and `false`
    /// when it is not (blocked with an infinity).
    fn finite_rows(mask: &Tensor) -> Result<Vec<Vec<bool>>> {
        let rows = mask.to_vec2::<f32>()?;
        let flags = rows
            .iter()
            .map(|row| row.iter().map(|value| value.is_finite()).collect())
            .collect();
        Ok(flags)
    }

    #[test]
    fn causal_sliding_mask_keeps_exact_window_width() -> Result<()> {
        // Two new tokens after three cached ones, window of two positions.
        let (tgt_len, past_kv_len, sliding_window) = (2, 3, 2);
        let mask = CausalMasker.make_swa_mask(
            tgt_len,
            past_kv_len,
            sliding_window,
            &Device::Cpu,
            DType::F32,
        )?;

        // Each query sees exactly itself and the one position before it.
        let expected = vec![
            vec![false, false, true, true, false],
            vec![false, false, false, true, true],
        ];
        assert_eq!(finite_rows(&mask)?, expected);
        Ok(())
    }
}
354}