1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use std::ops::Add;
4
5use hanzo_ml::{DType, Device, Result, Tensor, WithDType};
6
7use crate::pipeline::KvCache;
8
/// Zero-sized handle whose methods build causal attention masks.
pub struct CausalMasker;
11
/// Options read by [`CausalMasker::make_causal_mask`].
#[derive(Default)]
pub struct CausalMaskConfig {
    /// When `Some(w)`, the custom mask is built by the sliding-window path with window `w`;
    /// when `None`, a plain causal mask is built.
    pub sliding_window: Option<usize>,
    /// When `true`, the flash-attention shortcut is skipped and a custom mask tensor is
    /// always built for multi-token inputs.
    pub force_custom: bool,
}
22
23pub fn masked_fill<D: WithDType>(xs: &Tensor, mask: &Tensor, value: D) -> Result<Tensor> {
26 let on_true = Tensor::full(value, xs.shape(), xs.device())?.to_dtype(xs.dtype())?;
27 let on_false = xs;
28 let res = mask
29 .broadcast_as(xs.shape())?
30 .where_cond(&on_true, on_false)?;
31 Ok(res)
32}
33
/// Cache stand-in whose [`PastKvLenCache`] implementation always reports a past length of zero.
pub struct NotACache;
35
/// Source of the number of already-cached ("past") key/value positions.
///
/// The mask builders in this file add this length to the number of new tokens to size the
/// key/value axis of the mask.
pub trait PastKvLenCache {
    /// Returns the number of cached key/value positions, or an error if it cannot be read.
    fn get_past_kv_len(&self) -> Result<usize>;
}
39
impl PastKvLenCache for NotACache {
    /// Always reports zero cached positions.
    fn get_past_kv_len(&self) -> Result<usize> {
        Ok(0)
    }
}
45
46impl PastKvLenCache for Vec<KvCache> {
47 fn get_past_kv_len(&self) -> Result<usize> {
48 Ok(self.iter().map(KvCache::current_seq_len).max().unwrap_or(0))
49 }
50}
51
52impl PastKvLenCache for &[usize] {
53 fn get_past_kv_len(&self) -> Result<usize> {
54 if self.windows(2).all(|w| w[0] == w[1]) {
55 Ok(self[0])
56 } else {
57 Ok(0)
58 }
59 }
60}
61
62impl PastKvLenCache for Vec<Option<(Tensor, Tensor)>> {
63 fn get_past_kv_len(&self) -> Result<usize> {
64 let kv_cache_1 = &self[0];
65 if kv_cache_1.is_none() {
66 return Ok(0);
67 }
68 let k_cache_1 = &kv_cache_1.as_ref().unwrap().0;
69 Ok(k_cache_1.dims()[2])
70 }
71}
72
73impl CausalMasker {
74 fn make_mask(&self, tgt_len: usize, past_kv_len: usize, device: &Device) -> Result<Tensor> {
75 let offset = tgt_len + past_kv_len;
76 let mask: Vec<_> = (0..tgt_len)
77 .flat_map(|i| (0..offset).map(move |j| u8::from(j + tgt_len > i + offset)))
78 .collect();
79 Tensor::from_slice(&mask, (tgt_len, offset), device)
80 }
81
82 fn make_mask_chunked(
83 &self,
84 tgt_len: usize,
85 past_kv_len: usize,
86 chunk_size: usize,
87 device: &Device,
88 ) -> Result<Tensor> {
89 let offset = tgt_len + past_kv_len;
90 let mask: Vec<_> = (0..tgt_len)
91 .flat_map(|i| {
92 (0..offset).map(move |j| {
93 if j < past_kv_len {
95 return 0;
96 }
97
98 let j_adj = j - past_kv_len;
100
101 let i_block = i / chunk_size;
103 let j_block = j_adj / chunk_size;
104 let block_pos = (i_block as isize - j_block as isize).abs();
105
106 let token_pos = j_adj as isize - i as isize;
108
109 1 - u8::from((block_pos == 0) && (token_pos <= 0))
111 })
112 })
113 .collect();
114
115 Tensor::from_slice(&mask, (tgt_len, offset), device)
116 }
117
118 fn make_swa_mask(
119 &self,
120 tgt_len: usize,
121 past_kv_len: usize,
122 sliding_window: usize,
123 device: &Device,
124 dtype: DType,
125 ) -> Result<Tensor> {
126 let total_kv_len = tgt_len + past_kv_len;
127 let mask: Vec<_> = (0..tgt_len)
128 .flat_map(|i| {
129 let q_pos = past_kv_len + i;
130 (0..total_kv_len).map(move |j| {
131 let too_old = q_pos >= sliding_window && j <= q_pos - sliding_window;
136 if j > q_pos || too_old {
137 f32::NEG_INFINITY
138 } else {
139 0.
140 }
141 })
142 })
143 .collect();
144 Tensor::from_slice(&mask, (tgt_len, total_kv_len), device)?.to_dtype(dtype)
145 }
146
147 pub fn expand_mask(
150 &self,
151 mask: &Tensor,
152 dtype: DType,
153 tgt_len: Option<usize>,
154 ) -> Result<Tensor> {
155 let (bs, src_len) = mask.dims2()?;
156
157 let expanded_mask = mask.unsqueeze(1)?.unsqueeze(1)?;
158 let expanded_mask = expanded_mask
159 .expand((bs, 1, tgt_len.unwrap_or(src_len), src_len))?
160 .to_dtype(dtype)?;
161
162 let inverted_mask = expanded_mask.neg()?.add(1.0f64)?;
163 masked_fill(
164 &inverted_mask,
165 &inverted_mask.to_dtype(DType::U8)?,
166 f32::MIN,
167 )
168 }
169
170 pub fn calculate_past_kv_len(
171 &self,
172 cache: &[Option<(Tensor, Tensor)>],
173 ) -> hanzo_ml::Result<usize> {
174 let kv_cache_1 = &cache[0];
175 if kv_cache_1.is_none() {
176 return Ok(0);
177 }
178 let k_cache_1 = &kv_cache_1.as_ref().unwrap().0;
179 Ok(k_cache_1.dims()[2])
180 }
181
182 pub fn make_causal_mask(
189 &self,
190 input_ids: &Tensor,
191 cache: &dyn PastKvLenCache,
192 dtype: DType,
193 cfg: &CausalMaskConfig,
194 ) -> Result<crate::attention::AttentionMask> {
195 let past_kv_len = cache.get_past_kv_len()?;
196 let (_b_sz, tgt_len) = input_ids.dims2()?;
197 if tgt_len == 1 {
198 return Ok(crate::attention::AttentionMask::None);
199 }
200
201 if !cfg.force_custom && crate::using_flash_attn() && input_ids.device().is_cuda() {
202 return Ok(crate::attention::AttentionMask::CausalFlash);
203 }
204
205 let mask = if let Some(sw) = cfg.sliding_window {
206 self.make_swa_mask(tgt_len, past_kv_len, sw, input_ids.device(), dtype)?
207 } else {
208 let causal = self
209 .make_mask(tgt_len, past_kv_len, input_ids.device())?
210 .to_dtype(DType::U8)?;
211 let zero = Tensor::new(0.0f32, input_ids.device())?;
212 masked_fill(
213 &zero
214 .to_dtype(dtype)?
215 .broadcast_as((causal.dims()[0], causal.dims()[1]))?,
216 &causal,
217 f32::NEG_INFINITY,
218 )?
219 };
220
221 Ok(crate::attention::AttentionMask::Custom(mask))
222 }
223
224 pub fn make_chunked_mask_matrix(
225 &self,
226 input_ids: &Tensor,
227 chunk_size: usize,
228 cache: &dyn PastKvLenCache,
229 dtype: DType,
230 _n_attn_heads: usize,
231 ) -> Result<Option<Tensor>> {
232 let past_kv_len = cache.get_past_kv_len()?;
233 let (_b_sz, tgt_len) = input_ids.dims2()?;
234 if tgt_len == 1 {
235 return Ok(None);
236 }
237
238 let mut causal_mask = self
239 .make_mask_chunked(tgt_len, past_kv_len, chunk_size, input_ids.device())?
240 .to_dtype(DType::U8)?;
241
242 let zero = Tensor::new(0.0f32, input_ids.device())?;
243 causal_mask = {
244 let mut mask =
245 causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
246 mask = masked_fill(
248 &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
249 &mask,
250 f32::NEG_INFINITY,
251 )?;
252 mask
253 };
254
255 Ok(Some(causal_mask))
256 }
257
258 pub fn apply_mask_one_and_zero(
259 &self,
260 mask: &Option<Tensor>,
261 att: Tensor,
262 neg_inf: &Tensor,
263 ) -> Result<Tensor> {
264 match mask {
265 None => Ok(att),
266 Some(mask) => {
267 let mask = mask.broadcast_as(att.shape())?;
268 mask.where_cond(
269 &neg_inf
270 .to_device(att.device())?
271 .to_dtype(att.dtype())?
272 .broadcast_as(att.dims())?,
273 &att,
274 )
275 }
276 }
277 }
278}
279
/// Zero-sized handle whose methods build non-causal (bidirectional) attention masks.
pub struct BidirectionalMasker;
281
282impl BidirectionalMasker {
283 fn make_swa_mask(
284 &self,
285 tgt_len: usize,
286 sliding_window: usize,
287 device: &Device,
288 dtype: DType,
289 ) -> Result<Tensor> {
290 let mask: Vec<_> = (0..tgt_len)
291 .flat_map(|i| {
292 (0..tgt_len).map(move |j| {
293 if (i as isize - j as isize).unsigned_abs() >= sliding_window {
296 f32::NEG_INFINITY
297 } else {
298 0.
299 }
300 })
301 })
302 .collect();
303 let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), device)?;
304 mask.to_dtype(dtype)
305 }
306
307 pub fn make_mask(&self, input_ids: &Tensor, dtype: DType) -> Result<Tensor> {
308 let (_b_sz, tgt_len) = input_ids.dims2()?;
309
310 let mask = Tensor::zeros((tgt_len, tgt_len), dtype, input_ids.device())?;
312
313 Ok(mask)
314 }
315 pub fn make_sliding_mask(
316 &self,
317 input_ids: &Tensor,
318 dtype: DType,
319 sliding_window: usize,
320 ) -> Result<Tensor> {
321 let (_b_sz, tgt_len) = input_ids.dims2()?;
322
323 let mask = self.make_swa_mask(tgt_len, sliding_window, input_ids.device(), dtype)?;
324
325 Ok(mask)
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts a 2-D `f32` mask into per-row flags that are `true` for finite entries.
    fn finite_rows(mask: &Tensor) -> Result<Vec<Vec<bool>>> {
        let rows = mask.to_vec2::<f32>()?;
        let mut flags = Vec::with_capacity(rows.len());
        for row in rows {
            flags.push(row.iter().map(|v| v.is_finite()).collect::<Vec<bool>>());
        }
        Ok(flags)
    }

    /// Two new tokens after three cached ones with a window of two: each query must see
    /// exactly itself and the one position before it.
    #[test]
    fn causal_sliding_mask_keeps_exact_window_width() -> Result<()> {
        let (tgt_len, past_kv_len, window) = (2, 3, 2);
        let mask =
            CausalMasker.make_swa_mask(tgt_len, past_kv_len, window, &Device::Cpu, DType::F32)?;

        let expected = vec![
            vec![false, false, true, true, false],
            vec![false, false, false, true, true],
        ];
        assert_eq!(finite_rows(&mask)?, expected);
        Ok(())
    }
}