1#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2
3use crate::{attention::backends::cpu, pipeline::text_models_inputs_processor::FlashParams};
4
5use hanzo_ml::{DType, Device, Result, Tensor};
6
/// How an attention call should be masked.
#[derive(Clone, Debug)]
pub enum AttentionMask {
    /// No explicit mask tensor.
    None,
    /// No explicit mask tensor. `Sdpa::run_attention` handles this exactly like `None`,
    /// taking its causal flag from `FlashParams::causal`; presumably this variant signals
    /// that causality is left to the kernel — confirm with callers.
    CausalFlash,
    /// An explicit mask tensor. `Sdpa::run_attention` always routes it to the non-flash
    /// path (or to `sinks_attn` when sinks are configured).
    Custom(Tensor),
}
24
25impl AttentionMask {
26 pub fn as_option_tensor(&self) -> Option<&Tensor> {
32 match self {
33 Self::Custom(t) => Some(t),
34 _ => None,
35 }
36 }
37
38 pub fn is_custom(&self) -> bool {
42 matches!(self, Self::Custom(_))
43 }
44}
45
46mod backends;
47
48#[allow(unused)]
49pub(crate) use backends::{flash_attn, maybe_synchronize, naive_sdpa, sinks_attn};
50
/// Maximum number of query rows processed per call by `chunked_attention`; longer query
/// sequences are split into blocks of this many rows (1024).
pub(crate) const ATTENTION_CHUNK_SIZE: usize = 1 << 10;
53
54pub(crate) fn chunked_attention<F>(
56 q: &Tensor,
57 k: &Tensor,
58 v: &Tensor,
59 mask: Option<&Tensor>,
60 attention_fn: F,
61) -> Result<Tensor>
62where
63 F: Fn(&Tensor, &Tensor, &Tensor, Option<&Tensor>) -> Result<Tensor>,
64{
65 let seq_len = q.dim(2)?;
66
67 if seq_len <= ATTENTION_CHUNK_SIZE {
68 return attention_fn(q, k, v, mask);
70 }
71
72 let num_chunks = seq_len.div_ceil(ATTENTION_CHUNK_SIZE);
74 let mut attn_chunks = Vec::with_capacity(num_chunks);
75
76 for chunk_idx in 0..num_chunks {
77 let offset = chunk_idx * ATTENTION_CHUNK_SIZE;
78 let chunk_len = ATTENTION_CHUNK_SIZE.min(seq_len - offset);
79
80 let q_chunk = q.narrow(2, offset, chunk_len)?;
82
83 let mask_chunk = mask
85 .map(|m| {
86 match m.rank() {
87 2 => {
88 m.narrow(0, offset, chunk_len)
90 }
91 3 => {
92 m.narrow(1, offset, chunk_len)
94 }
95 4 => {
96 m.narrow(2, offset, chunk_len)
98 }
99 _ => m.narrow(2, offset, chunk_len), }
101 })
102 .transpose()?;
103
104 let att_chunk = attention_fn(&q_chunk, k, v, mask_chunk.as_ref())?;
106
107 attn_chunks.push(att_chunk);
108 }
109
110 Tensor::cat(&attn_chunks, 2)
112}
113
114fn repeat_kv(x: Tensor, n_rep: usize) -> Result<Tensor> {
115 if n_rep == 1 {
116 Ok(x)
117 } else {
118 let (b_sz, n_kv_head, seq_len, head_dim) = x.dims4()?;
119 Tensor::cat(&vec![&x; n_rep], 2)?.reshape((b_sz, n_kv_head * n_rep, seq_len, head_dim))
120 }
121}
122
/// Parameters shared by every scaled dot-product attention backend.
pub struct SdpaParams {
    /// Repetition factor passed to `repeat_kv` before the non-fused kernels (presumably
    /// the number of query heads per KV head — confirm against model configs).
    pub n_kv_groups: usize,
    /// Optional logit soft-cap. `None` and `Some(1.0)` are both treated as "no cap" when
    /// deciding whether the Metal kernels may be used.
    pub softcap: Option<f32>,
    /// Scale factor applied to the query-key scores.
    pub softmax_scale: f32,
    /// Optional sliding-window length, forwarded to the backends with these params
    /// (presumably limits how far back each query attends — confirm in the backends).
    pub sliding_window: Option<usize>,
    /// Attention-sink tensor. When set, `Sdpa::run_attention` routes the whole call
    /// through `sinks_attn`.
    pub sinks: Option<Tensor>,
}
130
/// Stateless scaled dot-product attention dispatcher. All per-call configuration is passed
/// to its methods through `SdpaParams` and `FlashParams`.
pub struct Sdpa;
132
133impl Sdpa {
134 #[allow(unused_variables, clippy::too_many_arguments)]
147 pub fn run_attention(
148 &self,
149 q: &Tensor,
150 k: &Tensor,
151 v: &Tensor,
152 mask: &AttentionMask,
153 flash_params: Option<&FlashParams>,
154 sdpa_params: &SdpaParams,
155 ) -> Result<Tensor> {
156 if let Some(sinks) = &sdpa_params.sinks {
158 let mask_tensor = match mask {
159 AttentionMask::Custom(t) => Some(t),
160 _ => None,
161 };
162 return sinks_attn(q, k, v, sinks, mask_tensor, flash_params, sdpa_params);
163 }
164
165 let do_causal = flash_params.is_some_and(|p| p.causal);
168
169 if let AttentionMask::Custom(mask_tensor) = mask {
171 return self.run_attention_noflash(q, k, v, Some(mask_tensor), sdpa_params, do_causal);
172 }
173
174 let can_use_flash = q.device().is_cpu()
176 || q.device().is_cuda() && crate::using_flash_attn() && q.dtype() != DType::F32;
177
178 if can_use_flash {
179 let q = q.transpose(1, 2)?;
181 let k = k.transpose(1, 2)?;
182 let v = v.transpose(1, 2)?;
183
184 if q.device().is_cpu() {
185 match q.dtype() {
186 DType::F32 => {
187 return cpu::run_flash_attn_cpu::<f32>(&q, &k, &v, None, sdpa_params);
188 }
189 DType::F16 => {
190 return cpu::run_flash_attn_cpu::<half::f16>(&q, &k, &v, None, sdpa_params)
191 }
192 DType::BF16 => {
193 return cpu::run_flash_attn_cpu::<half::bf16>(
194 &q,
195 &k,
196 &v,
197 None,
198 sdpa_params,
199 );
200 }
201 _ => {
202 return Err(hanzo_ml::Error::Msg("Unsupported data type".into()));
203 }
204 }
205 } else {
206 return flash_attn(&q, &k, &v, flash_params, sdpa_params)?.transpose(1, 2);
207 }
208 }
209
210 self.run_attention_noflash(q, k, v, None, sdpa_params, do_causal)
211 }
212
213 #[allow(unused_variables, clippy::too_many_arguments)]
219 pub fn run_attention_noflash(
220 &self,
221 q: &Tensor,
222 k: &Tensor,
223 v: &Tensor,
224 mask: Option<&Tensor>,
225 sdpa_params: &SdpaParams,
226 causal: bool,
227 ) -> Result<Tensor> {
228 let (b_sz, n_attn_heads, seq_len, head_dim) = q.dims4()?;
229 let (_, _, _, k_head_dim) = k.dims4()?;
230 let (_, _, _, v_head_dim) = v.dims4()?;
231
232 let all_head_dims_match = head_dim == k_head_dim && k_head_dim == v_head_dim;
236 let tgt_mask_shape = vec![b_sz, n_attn_heads, seq_len, k.dim(2)?];
237 let can_use_mask = mask.is_none_or(|mask| {
238 mask.layout().broadcast_as(tgt_mask_shape.clone()).is_ok()
239 && sdpa_params.softcap.is_none_or(|x| x == 1.0)
240 });
241 let valid_head_dims: &[usize] = &[32, 64, 72, 80, 96, 128, 256, 512];
242 let metal_supports_mask = mask.is_none() || seq_len <= k.dim(2)?;
244
245 if [q, k, v].into_iter().all(|x| x.device().is_metal())
249 && head_dim == 512
250 && k_head_dim == 512
251 && v_head_dim == 512
252 && q.dtype() == DType::BF16
253 && k.dtype() == DType::BF16
254 && v.dtype() == DType::BF16
255 && seq_len == 1
256 && mask.is_some()
257 && sdpa_params.softcap.is_none_or(|x| x == 1.0)
258 {
259 if let Some(out) =
260 crate::attention::backends::metal_flash_attn::try_flash_attn_ext_vec_bf16_dk512(
261 q,
262 k,
263 v,
264 mask,
265 sdpa_params.softmax_scale,
266 )?
267 {
268 return Ok(out);
269 }
270 }
271 if [q, k, v].into_iter().all(|x| x.device().is_metal())
272 && head_dim == 512
273 && k_head_dim == 512
274 && v_head_dim == 512
275 && q.dtype() == DType::BF16
276 && k.dtype() == DType::BF16
277 && v.dtype() == DType::BF16
278 && seq_len > 8
279 && sdpa_params.softcap.is_none_or(|x| x == 1.0)
280 {
281 if let Some(mask) = mask {
282 if let Some(out) =
283 crate::attention::backends::metal_flash_attn::try_flash_attn_ext_bf16_dk512(
284 q,
285 k,
286 v,
287 mask,
288 sdpa_params.softmax_scale,
289 )?
290 {
291 return Ok(out);
292 }
293 }
294 }
295
296 if [q, k, v].into_iter().all(|x| x.device().is_metal())
297 && all_head_dims_match
298 && valid_head_dims.contains(&head_dim)
299 && can_use_mask
300 && metal_supports_mask
301 && !(head_dim == 512 && seq_len > 8)
302 {
303 let mask = match mask {
304 Some(mask) => Some(mask.broadcast_as(tgt_mask_shape)?),
305 None => None,
306 };
307 let do_causal = seq_len > 1 && causal;
311 return hanzo_nn::ops::sdpa(
312 q,
313 k,
314 v,
315 mask.as_ref(),
316 do_causal,
317 sdpa_params.softmax_scale,
318 sdpa_params.softcap.unwrap_or(1.0),
319 );
320 }
321
322 let k = repeat_kv(k.clone(), sdpa_params.n_kv_groups)?;
323 let v = repeat_kv(v.clone(), sdpa_params.n_kv_groups)?;
324
325 if mask.is_some_and(|x| x.rank() == 2) || hanzo_quant::distributed::use_nccl() {
326 return naive_sdpa(
327 &q.contiguous()?,
328 &k.contiguous()?,
329 &v.contiguous()?,
330 mask,
331 sdpa_params,
332 );
333 }
334
335 #[allow(unused)]
337 if let (Device::Cuda(_), Some(cublaslt)) = (
338 q.device(),
339 hanzo_quant::cublaslt::CUBLASLT_CONTROLLER.get_for_device(q.device()),
340 ) {
341 #[cfg(feature = "cuda")]
342 {
343 maybe_synchronize(q.device())?;
344
345 let k_flat = k.flatten(0, 1)?;
347 let v_flat = v.flatten(0, 1)?;
348
349 chunked_attention(q, &k, &v, mask, |q_chunk, _k, _v, mask_chunk| {
350 let (chunk_b_sz, chunk_n_heads, chunk_seq_len, chunk_head_dim) =
352 q_chunk.dims4()?;
353 let q_flat = q_chunk.flatten(0, 1)?;
354
355 let attention_bias = match mask_chunk {
356 Some(mask) if mask.rank() == 3 && mask.dims()[0] == 1 => {
357 Some(mask.repeat((chunk_n_heads, 1, 1))?)
358 }
359 Some(mask) if mask.rank() == 3 => Some(mask.clone()),
360 Some(mask) if mask.rank() == 4 => {
361 let tgt_shape =
362 vec![chunk_b_sz, chunk_n_heads, chunk_seq_len, k.dim(2)?];
363 Some(mask.broadcast_as(tgt_shape)?.flatten(0, 1)?)
364 }
365 Some(mask) => {
366 hanzo_ml::bail!("cublaslt attn mask: rank must be 3 or 4")
367 }
368 None => None,
369 };
370
371 let beta = match attention_bias.is_some() {
374 true => Some(1.0),
375 false => None,
376 };
377
378 let mut attention_scores = cublaslt.batch_matmul(
381 &k_flat,
382 &q_flat,
383 attention_bias.as_ref(),
384 Some(sdpa_params.softmax_scale / sdpa_params.softcap.unwrap_or(1.0)),
385 beta,
386 None,
387 None,
388 )?;
389 if let Some(softcap) = sdpa_params.softcap {
390 attention_scores = (attention_scores.tanh()? * softcap as f64)?;
391 }
392 let scores_dtype = attention_scores.dtype();
397 if scores_dtype == DType::BF16 || scores_dtype == DType::F16 {
398 attention_scores = attention_scores.to_dtype(DType::F32)?;
399 }
400 attention_scores = hanzo_nn::ops::softmax_last_dim(&attention_scores)?;
401 if attention_scores.dtype() != scores_dtype {
402 attention_scores = attention_scores.to_dtype(scores_dtype)?;
403 }
404
405 let context_layer = cublaslt.batch_matmul(
406 &v_flat.t()?.contiguous()?,
407 &attention_scores,
408 Some(&q_flat),
410 None,
411 None,
412 None,
413 None,
414 )?;
415
416 context_layer.reshape((chunk_b_sz, chunk_n_heads, chunk_seq_len, v_head_dim))
418 })
419 }
420 #[cfg(not(feature = "cuda"))]
421 {
422 hanzo_ml::bail!("`cuda` feature is not enabled")
423 }
424 } else {
425 naive_sdpa(q, &k, &v, mask, sdpa_params)
426 }
427 }
428}