use hanzo_ml::{backend::BackendStorage, DType, MetalStorage, Result, Shape, Storage, Tensor};
use hanzo_quant::metal_kernels::{
call_flash_attn_ext_bf16_dk512, call_flash_attn_ext_vec_bf16_dk512,
flash_attn_ext_blk_scratch_size, Kernels, FA_NCPSG,
};
/// Head dimension (last axis of Q, K and V) that the `dk512` kernels are specialised for.
const HEAD_DIM: usize = 512;

/// Runs the Metal `flash_attn_ext` BF16 kernel specialised for a head
/// dimension of [`HEAD_DIM`] (512), applying an additive attention mask.
///
/// `q` is destructured as `(batch, n_heads_q, q_seq, HEAD_DIM)` and `k`/`v` as
/// `(batch, n_heads_kv, k_seq, HEAD_DIM)`. `mask` may be rank 2, 3 or 4; lower
/// ranks get leading singleton axes prepended before launch, and a BF16 mask is
/// converted to F16. `scale` is forwarded to the kernel unchanged.
///
/// Returns `Ok(Some(out))` with a BF16 tensor of shape
/// `(batch, n_heads_q, q_seq, HEAD_DIM)` on success. Returns `Ok(None)` when the
/// inputs fall outside what this kernel handles, so the caller can use a
/// fallback path. That happens when:
/// - `q`, `k`, `v` are not all BF16, or `mask` is neither F16 nor BF16;
/// - any head dimension differs from [`HEAD_DIM`];
/// - the batch sizes of `q` and `k` differ, or `v` does not match `k` in batch,
///   head count and sequence length;
/// - any batch, head count or sequence length is zero, or `n_heads_q` is not a
///   multiple of `n_heads_kv`;
/// - `mask` has a rank outside `2..=4`, or its last axis is shorter than `k_seq`;
/// - any operand is not backed by Metal storage.
///
/// # Errors
///
/// Returns an error in these cases:
/// - `q`, `k` or `v` is not rank 4;
/// - a contiguous copy or dtype conversion fails;
/// - Metal buffer allocation fails;
/// - the kernel launch fails.
pub(crate) fn try_flash_attn_ext_bf16_dk512(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    mask: &Tensor,
    scale: f32,
) -> Result<Option<Tensor>> {
    if q.dtype() != DType::BF16 || k.dtype() != DType::BF16 || v.dtype() != DType::BF16 {
        return Ok(None);
    }
    if mask.dtype() != DType::F16 && mask.dtype() != DType::BF16 {
        return Ok(None);
    }
    let q_dims = q.dims4()?;
    let k_dims = k.dims4()?;
    let v_dims = v.dims4()?;
    if q_dims.3 != HEAD_DIM || k_dims.3 != HEAD_DIM || v_dims.3 != HEAD_DIM {
        return Ok(None);
    }
    let (b, n_heads_q, q_seq, _) = q_dims;
    let (b_kv, n_heads_kv, k_seq, _) = k_dims;
    if b != b_kv {
        return Ok(None);
    }
    // Only V's strides are passed to the kernel; K's dims are used for V, so V
    // must agree with K in batch, head count and sequence length.
    if (v_dims.0, v_dims.1, v_dims.2) != (b_kv, n_heads_kv, k_seq) {
        return Ok(None);
    }
    // Nothing to launch for empty problems; this also guards the modulo below.
    if b == 0 || n_heads_q == 0 || n_heads_kv == 0 || q_seq == 0 || k_seq == 0 {
        return Ok(None);
    }
    // Grouped-query attention: each KV head must serve a whole number of Q heads.
    if n_heads_q % n_heads_kv != 0 {
        return Ok(None);
    }
    // Lift the mask to rank 4 by prepending singleton axes.
    let mask = match mask.rank() {
        2 => mask.unsqueeze(0)?.unsqueeze(0)?,
        3 => mask.unsqueeze(0)?,
        4 => mask.clone(),
        _ => return Ok(None),
    };
    // A mask shorter than the key sequence cannot cover every key position.
    if mask.dims()[3] < k_seq {
        return Ok(None);
    }
    let q = q.contiguous()?;
    let k = k.contiguous()?;
    let v = v.contiguous()?;
    // The kernel is launched with an F16 mask; BF16 masks are converted first.
    let mask = if mask.dtype() == DType::F16 {
        mask.contiguous()?
    } else {
        mask.to_dtype(DType::F16)?.contiguous()?
    };
    let q_s = q.storage_and_layout().0;
    let Storage::Metal(q_s) = &*q_s else {
        return Ok(None);
    };
    let k_s = k.storage_and_layout().0;
    let Storage::Metal(k_s) = &*k_s else {
        return Ok(None);
    };
    let v_s = v.storage_and_layout().0;
    let Storage::Metal(v_s) = &*v_s else {
        return Ok(None);
    };
    let m_s = mask.storage_and_layout().0;
    let Storage::Metal(m_s) = &*m_s else {
        return Ok(None);
    };
    let device = q_s.device().clone();
    let out_shape = vec![b, n_heads_q, q_seq, HEAD_DIM];
    let out_buf = device.new_buffer(out_shape.iter().product(), DType::BF16, "fa-ext-out")?;
    let mask_shape = mask.dims();
    let mask_stride = mask.stride();
    let blk_bytes =
        flash_attn_ext_blk_scratch_size(q_seq, k_seq, mask_shape[1].max(1), mask_shape[0].max(1));
    let blk_scratch = device.new_buffer(blk_bytes, DType::U8, "fa-ext-blk")?;
    // Extra scratch is only allocated when k_seq is not a multiple of FA_NCPSG.
    // The factors of 2 match the 2-byte BF16 K/V rows and F16 mask elements.
    // NOTE(review): presumably this holds K/V/mask padded to a full FA_NCPSG
    // tile; confirm against the kernel wrapper.
    let pad_scratch = if k_seq % FA_NCPSG != 0 {
        let head_bytes = HEAD_DIM * 2;
        let kv_pad_bytes = head_bytes * FA_NCPSG * n_heads_kv * b;
        let mask_pad_bytes =
            2 * FA_NCPSG * q_seq * mask_shape[1].max(1) * mask_shape[0].max(1);
        Some(device.new_buffer(2 * kv_pad_bytes + mask_pad_bytes, DType::U8, "fa-ext-pad")?)
    } else {
        None
    };
    let encoder = device.command_encoder()?;
    encoder.set_label("flash-attn-ext-bf16-dk512");
    // NOTE(review): a fresh `Kernels` is built per call; confirm it reuses
    // compiled pipelines rather than rebuilding them on every launch.
    call_flash_attn_ext_bf16_dk512(
        device.device(),
        &encoder,
        &Kernels::new(),
        (
            q_s.buffer(),
            q.layout().start_offset() * q.dtype().size_in_bytes(),
        ),
        (
            k_s.buffer(),
            k.layout().start_offset() * k.dtype().size_in_bytes(),
        ),
        (
            v_s.buffer(),
            v.layout().start_offset() * v.dtype().size_in_bytes(),
        ),
        (
            m_s.buffer(),
            mask.layout().start_offset() * mask.dtype().size_in_bytes(),
        ),
        &out_buf,
        &blk_scratch,
        pad_scratch.as_deref(),
        q.dims(),
        q.stride(),
        k.dims(),
        k.stride(),
        v.stride(),
        mask_shape,
        mask_stride,
        scale,
    )
    .map_err(hanzo_ml::Error::wrap)?;
    let out = Tensor::from((
        Storage::Metal(MetalStorage::new(
            out_buf,
            device.clone(),
            out_shape.iter().product(),
            DType::BF16,
        )),
        Shape::from(out_shape),
    ));
    Ok(Some(out))
}
/// Runs the Metal `flash_attn_ext_vec` BF16 kernel specialised for a head
/// dimension of 512 (`HEAD_DIM`), with an optional additive attention mask.
///
/// `q` is destructured as `(batch, n_heads_q, q_seq, HEAD_DIM)` and `k`/`v` as
/// `(batch, n_heads_kv, k_seq, HEAD_DIM)`.
///
/// A supplied `mask` may be rank 2, 3 or 4. Lower ranks get leading singleton
/// axes prepended, and a BF16 mask is converted to F16.
///
/// When `mask` is `None`, an all-zero F16 mask of shape `(1, 1, q_seq, k_seq)`
/// is launched instead. Being additive, it leaves the scores unchanged. The
/// launch always takes a mask buffer, so this real zero mask replaces the
/// earlier approach of handing the kernel an unrelated buffer to read.
///
/// `scale` is forwarded to the kernel unchanged.
///
/// Returns `Ok(Some(out))` with a BF16 tensor of shape
/// `(batch, n_heads_q, q_seq, HEAD_DIM)` on success. Returns `Ok(None)` when the
/// inputs fall outside what this kernel handles, so the caller can use a
/// fallback path. That happens when:
/// - `q`, `k`, `v` are not all BF16, or a supplied `mask` is neither F16 nor BF16;
/// - any head dimension differs from `HEAD_DIM`;
/// - the batch sizes of `q` and `k` differ, or `v` does not match `k` in batch,
///   head count and sequence length;
/// - any batch, head count or sequence length is zero, or `n_heads_q` is not a
///   multiple of `n_heads_kv`;
/// - a supplied `mask` has a rank outside `2..=4`, or its last axis is shorter
///   than `k_seq`;
/// - any operand is not backed by Metal storage.
///
/// # Errors
///
/// Returns an error in these cases:
/// - `q`, `k` or `v` is not rank 4;
/// - a contiguous copy, dtype conversion or zero-mask allocation fails;
/// - Metal buffer allocation fails;
/// - the kernel launch fails.
pub(crate) fn try_flash_attn_ext_vec_bf16_dk512(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    mask: Option<&Tensor>,
    scale: f32,
) -> Result<Option<Tensor>> {
    if q.dtype() != DType::BF16 || k.dtype() != DType::BF16 || v.dtype() != DType::BF16 {
        return Ok(None);
    }
    if let Some(m) = mask {
        if m.dtype() != DType::F16 && m.dtype() != DType::BF16 {
            return Ok(None);
        }
    }
    let q_dims = q.dims4()?;
    let k_dims = k.dims4()?;
    let v_dims = v.dims4()?;
    if q_dims.3 != HEAD_DIM || k_dims.3 != HEAD_DIM || v_dims.3 != HEAD_DIM {
        return Ok(None);
    }
    let (b, n_heads_q, q_seq, _) = q_dims;
    let (b_kv, n_heads_kv, k_seq, _) = k_dims;
    if b != b_kv {
        return Ok(None);
    }
    // Only V's strides are passed to the kernel; K's dims are used for V, so V
    // must agree with K in batch, head count and sequence length.
    if (v_dims.0, v_dims.1, v_dims.2) != (b_kv, n_heads_kv, k_seq) {
        return Ok(None);
    }
    // Nothing to launch for empty problems; this also guards the modulo below.
    if b == 0 || n_heads_q == 0 || n_heads_kv == 0 || q_seq == 0 || k_seq == 0 {
        return Ok(None);
    }
    // Grouped-query attention: each KV head must serve a whole number of Q heads.
    if n_heads_q % n_heads_kv != 0 {
        return Ok(None);
    }
    let mask = match mask {
        Some(m) => {
            // Lift the mask to rank 4 by prepending singleton axes.
            let m = match m.rank() {
                2 => m.unsqueeze(0)?.unsqueeze(0)?,
                3 => m.unsqueeze(0)?,
                4 => m.clone(),
                _ => return Ok(None),
            };
            // A mask shorter than the key sequence cannot cover every key position.
            if m.dims()[3] < k_seq {
                return Ok(None);
            }
            // The kernel is launched with an F16 mask; BF16 masks are converted first.
            if m.dtype() == DType::F16 {
                m.contiguous()?
            } else {
                m.to_dtype(DType::F16)?.contiguous()?
            }
        }
        // A zero additive mask is equivalent to "no mask". It has the same shape
        // a rank-2 `(q_seq, k_seq)` mask is lifted to above.
        None => Tensor::zeros((1, 1, q_seq, k_seq), DType::F16, q.device())?,
    };
    let q = q.contiguous()?;
    let k = k.contiguous()?;
    let v = v.contiguous()?;
    let q_s = q.storage_and_layout().0;
    let Storage::Metal(q_s) = &*q_s else {
        return Ok(None);
    };
    let k_s = k.storage_and_layout().0;
    let Storage::Metal(k_s) = &*k_s else {
        return Ok(None);
    };
    let v_s = v.storage_and_layout().0;
    let Storage::Metal(v_s) = &*v_s else {
        return Ok(None);
    };
    let m_s = mask.storage_and_layout().0;
    let Storage::Metal(m_s) = &*m_s else {
        return Ok(None);
    };
    let device = q_s.device().clone();
    let out_shape = vec![b, n_heads_q, q_seq, HEAD_DIM];
    let out_buf = device.new_buffer(out_shape.iter().product(), DType::BF16, "fa-vec-out")?;
    let encoder = device.command_encoder()?;
    encoder.set_label("flash-attn-ext-vec-bf16-dk512");
    // NOTE(review): a fresh `Kernels` is built per call; confirm it reuses
    // compiled pipelines rather than rebuilding them on every launch.
    call_flash_attn_ext_vec_bf16_dk512(
        device.device(),
        &encoder,
        &Kernels::new(),
        (
            q_s.buffer(),
            q.layout().start_offset() * q.dtype().size_in_bytes(),
        ),
        (
            k_s.buffer(),
            k.layout().start_offset() * k.dtype().size_in_bytes(),
        ),
        (
            v_s.buffer(),
            v.layout().start_offset() * v.dtype().size_in_bytes(),
        ),
        (
            m_s.buffer(),
            mask.layout().start_offset() * mask.dtype().size_in_bytes(),
        ),
        &out_buf,
        q.dims(),
        q.stride(),
        k.dims(),
        k.stride(),
        v.stride(),
        mask.dims(),
        mask.stride(),
        scale,
    )
    .map_err(hanzo_ml::Error::wrap)?;
    let out = Tensor::from((
        Storage::Metal(MetalStorage::new(
            out_buf,
            device.clone(),
            out_shape.iter().product(),
            DType::BF16,
        )),
        Shape::from(out_shape),
    ));
    Ok(Some(out))
}