Skip to main content

hanzo_engine/attention/
mod.rs

1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use crate::{attention::backends::cpu, pipeline::text_models_inputs_processor::FlashParams};
4
5use hanzo_ml::{DType, Device, Result, Tensor};
6
7/// Attention mask passed to [`Sdpa::run_attention`].
8///
9/// Encodes both the mask data and the *intent*, whether the attention layer
10/// should use flash attention (causal handled by the kernel), eager attention
11/// with an explicit mask tensor, or no masking at all.
12#[derive(Clone, Debug)]
13pub enum AttentionMask {
14    /// No masking. Used for single-token decode or truly unmasked attention.
15    None,
16    /// Flash attention with `is_causal = true`. No mask tensor is needed;
17    /// the flash kernel applies causal masking internally. Also signals
18    /// "this is a prefill" to the paged attention layer.
19    CausalFlash,
20    /// An explicit mask tensor (causal, sliding window, bidirectional, etc).
21    /// Dispatches to the eager (non-flash) attention path.
22    Custom(Tensor),
23}
24
25impl AttentionMask {
26    /// Extract the inner tensor as `Option<&Tensor>`.
27    ///
28    /// Returns `Some(&tensor)` for [`Custom`](Self::Custom), `None` otherwise.
29    /// Useful for interfacing with paged-attention and MLA helpers that still
30    /// accept `Option<&Tensor>`.
31    pub fn as_option_tensor(&self) -> Option<&Tensor> {
32        match self {
33            Self::Custom(t) => Some(t),
34            _ => None,
35        }
36    }
37
38    /// Returns `true` when the mask carries an explicit tensor
39    /// ([`Custom`](Self::Custom) variant), mirroring the old
40    /// `Option<Tensor>::is_some()` semantics.
41    pub fn is_custom(&self) -> bool {
42        matches!(self, Self::Custom(_))
43    }
44}
45
46mod backends;
47
48#[allow(unused)]
49pub(crate) use backends::{flash_attn, maybe_synchronize, naive_sdpa, sinks_attn};
50
51/// Chunk size for attention computation to avoid OOM on long sequences
52pub(crate) const ATTENTION_CHUNK_SIZE: usize = 1024;
53
54/// Generic chunked attention computation that can be used by different backends
55pub(crate) fn chunked_attention<F>(
56    q: &Tensor,
57    k: &Tensor,
58    v: &Tensor,
59    mask: Option<&Tensor>,
60    attention_fn: F,
61) -> Result<Tensor>
62where
63    F: Fn(&Tensor, &Tensor, &Tensor, Option<&Tensor>) -> Result<Tensor>,
64{
65    let seq_len = q.dim(2)?;
66
67    if seq_len <= ATTENTION_CHUNK_SIZE {
68        // For short sequences, use the regular path
69        return attention_fn(q, k, v, mask);
70    }
71
72    // Chunk the query to avoid OOM on long sequences
73    let num_chunks = seq_len.div_ceil(ATTENTION_CHUNK_SIZE);
74    let mut attn_chunks = Vec::with_capacity(num_chunks);
75
76    for chunk_idx in 0..num_chunks {
77        let offset = chunk_idx * ATTENTION_CHUNK_SIZE;
78        let chunk_len = ATTENTION_CHUNK_SIZE.min(seq_len - offset);
79
80        // Extract query chunk
81        let q_chunk = q.narrow(2, offset, chunk_len)?;
82
83        // Extract mask chunk if present
84        let mask_chunk = mask
85            .map(|m| {
86                match m.rank() {
87                    2 => {
88                        // For 2D masks (seq_len, seq_len), narrow along dimension 0
89                        m.narrow(0, offset, chunk_len)
90                    }
91                    3 => {
92                        // For 3D masks (batch, seq_len, seq_len), narrow along dimension 1
93                        m.narrow(1, offset, chunk_len)
94                    }
95                    4 => {
96                        // For 4D masks (batch, heads, seq_len, seq_len), narrow along dimension 2
97                        m.narrow(2, offset, chunk_len)
98                    }
99                    _ => m.narrow(2, offset, chunk_len), // Default to dimension 2
100                }
101            })
102            .transpose()?;
103
104        // Compute attention for this chunk
105        let att_chunk = attention_fn(&q_chunk, k, v, mask_chunk.as_ref())?;
106
107        attn_chunks.push(att_chunk);
108    }
109
110    // Concatenate all chunks along the sequence dimension
111    Tensor::cat(&attn_chunks, 2)
112}
113
114fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
115    if n_rep == 1 {
116        Ok(x)
117    } else {
118        let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
119        Tensor::cat(&vec![&x; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
120    }
121}
122
123pub struct SdpaParams {
124    pub n_kv_groups: usize,
125    pub softcap: Option<f32>,
126    pub softmax_scale: f32,
127    pub sliding_window: Option<usize>,
128    pub sinks: Option<Tensor>,
129}
130
131pub struct Sdpa;
132
133impl Sdpa {
134    /// Computes softmax(QK^T*sqrt(d_k))V
135    ///
136    /// Inputs:
137    /// - q: (b_sz, n_attn_heads, q_len, head_dim)
138    /// - k: (b_sz, n_kv_heads, q_len, head_dim)
139    /// - v: (b_sz, n_kv_heads, q_len, head_dim)
140    ///
141    /// Dispatch attention based on the `AttentionMask` variant:
142    ///
143    /// - `AttentionMask::CausalFlash`: flash attention with `is_causal = true`
144    /// - `AttentionMask::None`: flash if available (decode), else eager without mask
145    /// - `AttentionMask::Custom`: eager attention with the explicit mask tensor
146    #[allow(unused_variables, clippy::too_many_arguments)]
147    pub fn run_attention(
148        &self,
149        q: &Tensor,
150        k: &Tensor,
151        v: &Tensor,
152        mask: &AttentionMask,
153        flash_params: Option<&FlashParams>,
154        sdpa_params: &SdpaParams,
155    ) -> Result<Tensor> {
156        // If sinks are present, dispatch to the sinks backend
157        if let Some(sinks) = &sdpa_params.sinks {
158            let mask_tensor = match mask {
159                AttentionMask::Custom(t) => Some(t),
160                _ => None,
161            };
162            return sinks_attn(q, k, v, sinks, mask_tensor, flash_params, sdpa_params);
163        }
164
165        // The mask carries causality already; the kernel-level do_causal
166        // early-exit is safe to enable only when the request is known causal.
167        let do_causal = flash_params.is_some_and(|p| p.causal);
168
169        // Custom mask, eager attention (flash can't use arbitrary mask tensors)
170        if let AttentionMask::Custom(mask_tensor) = mask {
171            return self.run_attention_noflash(q, k, v, Some(mask_tensor), sdpa_params, do_causal);
172        }
173
174        // CausalFlash or None: try flash attention, fall back to eager
175        let can_use_flash = q.device().is_cpu()
176            || q.device().is_cuda() && crate::using_flash_attn() && q.dtype() != DType::F32;
177
178        if can_use_flash {
179            // flash-attn expects (b_sz, seq_len, nheads, head_dim)
180            let q = q.transpose(1, 2)?;
181            let k = k.transpose(1, 2)?;
182            let v = v.transpose(1, 2)?;
183
184            if q.device().is_cpu() {
185                match q.dtype() {
186                    DType::F32 => {
187                        return cpu::run_flash_attn_cpu::<f32>(&q, &k, &v, None, sdpa_params);
188                    }
189                    DType::F16 => {
190                        return cpu::run_flash_attn_cpu::<half::f16>(&q, &k, &v, None, sdpa_params)
191                    }
192                    DType::BF16 => {
193                        return cpu::run_flash_attn_cpu::<half::bf16>(
194                            &q,
195                            &k,
196                            &v,
197                            None,
198                            sdpa_params,
199                        );
200                    }
201                    _ => {
202                        return Err(hanzo_ml::Error::Msg("Unsupported data type".into()));
203                    }
204                }
205            } else {
206                return flash_attn(&q, &k, &v, flash_params, sdpa_params)?.transpose(1, 2);
207            }
208        }
209
210        self.run_attention_noflash(q, k, v, None, sdpa_params, do_causal)
211    }
212
213    /// Same as `run_attention`, but skips the flash-attention dispatch.
214    ///
215    /// `causal` tells the Metal SDPA-full kernel to enable its upper-triangle skip (`do_causal=true`).
216    /// Pass `true` only when the caller's mask is causal-or-stricter.
217    /// Pass false` for bidirectional masks (e.g. vision attention).
218    #[allow(unused_variables, clippy::too_many_arguments)]
219    pub fn run_attention_noflash(
220        &self,
221        q: &Tensor,
222        k: &Tensor,
223        v: &Tensor,
224        mask: Option<&Tensor>,
225        sdpa_params: &SdpaParams,
226        causal: bool,
227    ) -> Result<Tensor> {
228        let (b_sz, n_attn_heads, seq_len, head_dim) = q.dims4()?;
229        let (_, _, _, k_head_dim) = k.dims4()?;
230        let (_, _, _, v_head_dim) = v.dims4()?;
231
232        // We can use Metal SDPA (vector/full) if the mask is the correct size and head dims match.
233        // If the mask is provided, then softcapping isn't allowed - default back to naive SDPA
234        // Softcapping is implemented for vector SDPA.
235        let all_head_dims_match = head_dim == k_head_dim && k_head_dim == v_head_dim;
236        let tgt_mask_shape = vec![b_sz, n_attn_heads, seq_len, k.dim(2)?];
237        let can_use_mask = mask.is_none_or(|mask| {
238            mask.layout().broadcast_as(tgt_mask_shape.clone()).is_ok()
239                && sdpa_params.softcap.is_none_or(|x| x == 1.0)
240        });
241        let valid_head_dims: &[usize] = &[32, 64, 72, 80, 96, 128, 256, 512];
242        // Metal SDPA full kernel requires q_seq <= k_seq when a mask is present.
243        let metal_supports_mask = mask.is_none() || seq_len <= k.dim(2)?;
244
245        // Metal FA path for DK=512 BF16 with a mask. Two specializations:
246        // prefill (seq_len > 8) goes through the BlockMMA kernel; decode
247        // (seq_len == 1) uses a vector FA kernel ported from llama.cpp.
248        if [q, k, v].into_iter().all(|x| x.device().is_metal())
249            && head_dim == 512
250            && k_head_dim == 512
251            && v_head_dim == 512
252            && q.dtype() == DType::BF16
253            && k.dtype() == DType::BF16
254            && v.dtype() == DType::BF16
255            && seq_len == 1
256            && mask.is_some()
257            && sdpa_params.softcap.is_none_or(|x| x == 1.0)
258        {
259            if let Some(out) =
260                crate::attention::backends::metal_flash_attn::try_flash_attn_ext_vec_bf16_dk512(
261                    q,
262                    k,
263                    v,
264                    mask,
265                    sdpa_params.softmax_scale,
266                )?
267            {
268                return Ok(out);
269            }
270        }
271        if [q, k, v].into_iter().all(|x| x.device().is_metal())
272            && head_dim == 512
273            && k_head_dim == 512
274            && v_head_dim == 512
275            && q.dtype() == DType::BF16
276            && k.dtype() == DType::BF16
277            && v.dtype() == DType::BF16
278            && seq_len > 8
279            && sdpa_params.softcap.is_none_or(|x| x == 1.0)
280        {
281            if let Some(mask) = mask {
282                if let Some(out) =
283                    crate::attention::backends::metal_flash_attn::try_flash_attn_ext_bf16_dk512(
284                        q,
285                        k,
286                        v,
287                        mask,
288                        sdpa_params.softmax_scale,
289                    )?
290                {
291                    return Ok(out);
292                }
293            }
294        }
295
296        if [q, k, v].into_iter().all(|x| x.device().is_metal())
297            && all_head_dims_match
298            && valid_head_dims.contains(&head_dim)
299            && can_use_mask
300            && metal_supports_mask
301            && !(head_dim == 512 && seq_len > 8)
302        {
303            let mask = match mask {
304                Some(mask) => Some(mask.broadcast_as(tgt_mask_shape)?),
305                None => None,
306            };
307            // do_causal lets the steel_attention kernel bound its kb-loop to
308            // the per-query position, skipping the upper triangle of Q*K^T
309            // entirely (roughly halves matmul cost for prefill).
310            let do_causal = seq_len > 1 && causal;
311            return hanzo_nn::ops::sdpa(
312                q,
313                k,
314                v,
315                mask.as_ref(),
316                do_causal,
317                sdpa_params.softmax_scale,
318                sdpa_params.softcap.unwrap_or(1.0),
319            );
320        }
321
322        let k = repeat_kv(k.clone(), sdpa_params.n_kv_groups)?;
323        let v = repeat_kv(v.clone(), sdpa_params.n_kv_groups)?;
324
325        if mask.is_some_and(|x| x.rank() == 2) || hanzo_quant::distributed::use_nccl() {
326            return naive_sdpa(
327                &q.contiguous()?,
328                &k.contiguous()?,
329                &v.contiguous()?,
330                mask,
331                sdpa_params,
332            );
333        }
334
335        // TODO: bench?
336        #[allow(unused)]
337        if let (Device::Cuda(_), Some(cublaslt)) = (
338            q.device(),
339            hanzo_quant::cublaslt::CUBLASLT_CONTROLLER.get_for_device(q.device()),
340        ) {
341            #[cfg(feature = "cuda")]
342            {
343                maybe_synchronize(q.device())?;
344
345                // Use chunked attention for cuBLASLt path
346                let k_flat = k.flatten(0, 1)?;
347                let v_flat = v.flatten(0, 1)?;
348
349                chunked_attention(q, &k, &v, mask, |q_chunk, _k, _v, mask_chunk| {
350                    // cuBLASLt batch matmul implementation requires inputs to be dims3
351                    let (chunk_b_sz, chunk_n_heads, chunk_seq_len, chunk_head_dim) =
352                        q_chunk.dims4()?;
353                    let q_flat = q_chunk.flatten(0, 1)?;
354
355                    let attention_bias = match mask_chunk {
356                        Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
357                            Some(mask.repeat((chunk_n_heads, 1, 1))?)
358                        }
359                        Some(mask) if mask.rank() == 3 => Some(mask.clone()),
360                        Some(mask) if mask.rank() == 4 => {
361                            let tgt_shape =
362                                vec![chunk_b_sz, chunk_n_heads, chunk_seq_len, k.dim(2)?];
363                            Some(mask.broadcast_as(tgt_shape)?.flatten(0, 1)?)
364                        }
365                        Some(mask) => {
366                            hanzo_ml::bail!("cublaslt attn mask: rank must be 3 or 4")
367                        }
368                        None => None,
369                    };
370
371                    // If attention_bias is set, we fuse the add by giving it as the output matrix
372                    // and setting beta to 1.0
373                    let beta = match attention_bias.is_some() {
374                        true => Some(1.0),
375                        false => None,
376                    };
377
378                    // Batch matrix multiplication
379                    // Fuse softmax scale and attention_bias add
380                    let mut attention_scores = cublaslt.batch_matmul(
381                        &k_flat,
382                        &q_flat,
383                        attention_bias.as_ref(),
384                        Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
385                        beta,
386                        None,
387                        None,
388                    )?;
389                    if let Some(softcap) = sdpa_params.softcap {
390                        attention_scores = (attention_scores.tanh()? * softcap as f64)?;
391                    }
392                    // Compute softmax in F32 for precision. BF16's 7 mantissa
393                    // bits cause exp() to lose information on long sequences.
394                    // Flash attention already computes softmax in F32; this
395                    // matches that behaviour for the eager path.
396                    let scores_dtype = attention_scores.dtype();
397                    if scores_dtype == DType::BF16 || scores_dtype == DType::F16 {
398                        attention_scores = attention_scores.to_dtype(DType::F32)?;
399                    }
400                    attention_scores = hanzo_nn::ops::softmax_last_dim(&attention_scores)?;
401                    if attention_scores.dtype() != scores_dtype {
402                        attention_scores = attention_scores.to_dtype(scores_dtype)?;
403                    }
404
405                    let context_layer = cublaslt.batch_matmul(
406                        &v_flat.t()?.contiguous()?,
407                        &attention_scores,
408                        // We save one allocation
409                        Some(&q_flat),
410                        None,
411                        None,
412                        None,
413                        None,
414                    )?;
415
416                    // Reshape to dims4
417                    context_layer.reshape((chunk_b_sz, chunk_n_heads, chunk_seq_len, v_head_dim))
418                })
419            }
420            #[cfg(not(feature = "cuda"))]
421            {
422                hanzo_ml::bail!("`cuda` feature is not enabled")
423            }
424        } else {
425            naive_sdpa(q, &k, &v, mask, sdpa_params)
426        }
427    }
428}