use crate::utils::EncoderProvider;
use crate::{
set_params, Buffer, ComputeCommandEncoder, ConstantValues, Device, EncoderParam, Kernels,
MetalKernelError, Source, Value,
};
use objc2_metal::{MTLResourceUsage, MTLSize};
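
/// Element types supported by the SDPA kernels.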
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum SdpaDType {
BF16,
F16,
F32,
}
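
/// Dispatch the tiled ("steel") SDPA kernel over full query/key sequences.
///
/// `q_shape` is `[batch, heads, q_len, head_dim]`; each stride slice supplies the
/// batch, head, and sequence strides of its tensor. Passing a mask requires
/// `mask_type`, `mask_buffer`, and `m_strides` together, and `do_causal` enables
/// causal masking inside the kernel. Supported head dims are 32, 64, 72, 80, 96,
/// 128, 256, and 512 (512 is unsupported for `F32`).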
#[allow(clippy::too_many_arguments)]
pub fn call_sdpa_full(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
q_offset: usize,
q_shape: &[usize],
q_strides: &[usize],
q_buffer: &Buffer,
k_offset: usize,
k_shape: &[usize],
k_strides: &[usize],
k_buffer: &Buffer,
v_offset: usize,
v_buffer: &Buffer,
v_strides: &[usize],
mask_type: Option<SdpaDType>,
mask_buffer: Option<&Buffer>,
m_strides: Option<&[usize]>,
output: &Buffer,
o_strides: &[usize],
scale: f32,
do_causal: bool,
itype: SdpaDType,
) -> Result<(), MetalKernelError> {
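    // These #[repr(C)] structs mirror the parameter layouts expected by the Metal
    // kernel; they are uploaded with `set_bytes`.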
#[derive(Debug)]
#[repr(C)]
struct AttnParams {
b: i32,
h: i32,
d: i32,
ql: i32,
kl: i32,
gqa_factor: i32,
scale: f32,
        softcapping: f32,
        nq: i32,
nk: i32,
nq_aligned: i32,
nk_aligned: i32,
ql_rem: i32,
kl_rem: i32,
ql_off: i32,
q_strides: [i64; 3],
k_strides: [i64; 3],
v_strides: [i64; 3],
o_strides: [i64; 3],
}
#[derive(Debug)]
#[repr(C)]
struct AttnMaskParams {
m_strides: [i64; 3],
}
let bd = q_shape[q_shape.len() - 1];
if ![32, 64, 72, 80, 96, 128, 256, 512].contains(&bd) {
return Err(MetalKernelError::SdpaHeadSizeMismatch {
variation: "full",
got: bd,
expected: vec![32, 64, 72, 80, 96, 128, 256, 512],
});
};
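    // Query/key tile sizes (bq, bk) and simdgroups per threadgroup (wm, wn);
    // head_dim == 512 falls back to small 8x8 tiles with a single simdgroup.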
let (bq, bk, wm, wn): (usize, usize, usize, usize) = if bd == 512 {
if itype == SdpaDType::F32 {
return Err(MetalKernelError::SdpaHeadSizeMismatch {
variation: "full (f32 unsupported at head_dim=512)",
got: bd,
expected: vec![32, 64, 72, 80, 96, 128, 256],
});
}
(8, 8, 1, 1)
} else {
let bk = if bd < 128 { 32 } else { 16 };
(32, bk, 4, 1)
};
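    // Shapes: Q is [b, h, ql, d]; K/V use [b, h_k, kl, d] with gqa_factor = h / h_k.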
let b = q_shape[0];
let h = q_shape[1];
let d = q_shape[3];
let gqa_factor = q_shape[1] / k_shape[1];
let ql = q_shape[2];
let kl = k_shape[2];
let align_q = (ql % bq) == 0;
let align_k = (kl % bk) == 0;
let has_mask = mask_buffer.is_some();
let itype_repr = match itype {
SdpaDType::BF16 => "bfloat16",
SdpaDType::F16 => "float16",
SdpaDType::F32 => "float32",
};
let mask_repr = match mask_type {
Some(SdpaDType::BF16) => "bfloat16",
Some(SdpaDType::F16) => "float16",
Some(SdpaDType::F32) => "float32",
None => itype_repr,
};
let name =
format!("steel_attention_{itype_repr}_bq{bq}_bk{bk}_bd{bd}_wm{wm}_wn{wn}_mask{mask_repr}");
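    // Metal function constants: 200/201 flag whether the query/key lengths are
    // block-aligned, 300 flags mask presence, and 301 enables causal masking.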
    let constants = Some(ConstantValues::new(vec![
        (200, Value::Bool(align_q)),
        (201, Value::Bool(align_k)),
        (300, Value::Bool(has_mask)),
        (301, Value::Bool(do_causal)),
    ]));
let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, name, constants)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoder = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
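    // Block counts along the query/key dimensions (rounded up) and the number of
    // fully populated blocks.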
let nq = (ql + bq - 1) / bq;
let nk = (kl + bk - 1) / bk;
let nq_aligned = ql / bq;
let nk_aligned = kl / bk;
let params = AttnParams {
b: b as i32,
h: h as i32,
d: d as i32,
ql: ql as i32,
kl: kl as i32,
gqa_factor: gqa_factor as i32,
scale,
        softcapping: 1.0,
        nq: nq as i32,
nk: nk as i32,
nq_aligned: nq_aligned as i32,
nk_aligned: nk_aligned as i32,
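        // Sizes of the trailing partial blocks, plus the key/query length
        // difference used to align causal masking.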
ql_rem: ql.wrapping_sub(nq_aligned * bq) as i32,
kl_rem: kl.wrapping_sub(nk_aligned * bk) as i32,
ql_off: kl.wrapping_sub(ql) as i32,
q_strides: [
q_strides[0] as i64,
q_strides[1] as i64,
q_strides[2] as i64,
],
k_strides: [
k_strides[0] as i64,
k_strides[1] as i64,
k_strides[2] as i64,
],
v_strides: [
v_strides[0] as i64,
v_strides[1] as i64,
v_strides[2] as i64,
],
o_strides: [
o_strides[0] as i64,
o_strides[1] as i64,
o_strides[2] as i64,
],
};
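    // The parameter structs are passed to the kernel as raw bytes.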
impl EncoderParam for AttnParams {
fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
encoder.set_bytes(position, &data);
}
}
impl EncoderParam for AttnMaskParams {
fn set_param(encoder: &ComputeCommandEncoder, position: usize, data: Self) {
encoder.set_bytes(position, &data);
}
}
if let Some(mask) = mask_buffer {
        let mask_strides =
            m_strides.expect("mask strides must be provided alongside a mask buffer");
let mask_params = AttnMaskParams {
m_strides: [
mask_strides[0] as i64,
mask_strides[1] as i64,
mask_strides[2] as i64,
],
};
encoder.use_resource(mask, MTLResourceUsage::Read);
set_params!(
encoder,
(
(q_buffer, q_offset),
(k_buffer, k_offset),
(v_buffer, v_offset),
output,
params,
mask_params,
mask
)
);
} else {
set_params!(
encoder,
(
(q_buffer, q_offset),
(k_buffer, k_offset),
(v_buffer, v_offset),
output,
params
)
);
}
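    // One threadgroup per (query block, head, batch); each threadgroup runs
    // wm x wn simdgroups of 32 threads.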
let grid_dims = MTLSize {
width: nq,
height: h,
depth: b,
};
let group_dims = MTLSize {
width: 32,
height: wm,
depth: wn,
};
encoder.use_resource(q_buffer, MTLResourceUsage::Read);
encoder.use_resource(k_buffer, MTLResourceUsage::Read);
encoder.use_resource(v_buffer, MTLResourceUsage::Read);
encoder.use_resource(output, MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_dims, group_dims);
Ok(())
}
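
/// Dispatch the single-pass SDPA "vector" kernel, one 1024-thread threadgroup per
/// (batch * head) row; intended for short query lengths (e.g. decoding).
/// `alpha` is the softmax scale; when `softcapping != 1` it is pre-divided by the
/// cap so the kernel can apply tanh soft-capping. Supported head dims are 32, 64,
/// 96, 128, 256, and 512.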
#[allow(clippy::too_many_arguments)]
pub fn call_sdpa_vector(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
q_offset: usize,
q_shape: &[usize],
q_buffer: &Buffer,
k_offset: usize,
k_shape: &[usize],
k_stride: &[usize],
k_buffer: &Buffer,
v_offset: usize,
v_stride: &[usize],
v_buffer: &Buffer,
output: &Buffer,
alpha: f32,
softcapping: f32,
itype: SdpaDType,
) -> Result<(), MetalKernelError> {
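    // `bk` is the head dimension here; it selects the specialized kernel variant.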
let bk = q_shape.last().unwrap();
let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
let n = k_shape[2] as i32;
let b = (q_shape[0] * q_shape[1]) as i32;
let kstride = k_stride[1];
let vstride = v_stride[1];
let name = match (bk, itype) {
(32, SdpaDType::F16) => "sdpa_vector_float16_t_32",
(64, SdpaDType::F16) => "sdpa_vector_float16_t_64",
(96, SdpaDType::F16) => "sdpa_vector_float16_t_96",
(128, SdpaDType::F16) => "sdpa_vector_float16_t_128",
(256, SdpaDType::F16) => "sdpa_vector_float16_t_256",
(512, SdpaDType::F16) => "sdpa_vector_float16_t_512",
(32, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_32",
(64, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_64",
(96, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_96",
(128, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_128",
(256, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_256",
(512, SdpaDType::BF16) => "sdpa_vector_bfloat16_t_512",
(32, SdpaDType::F32) => "sdpa_vector_float_32",
(64, SdpaDType::F32) => "sdpa_vector_float_64",
(96, SdpaDType::F32) => "sdpa_vector_float_96",
(128, SdpaDType::F32) => "sdpa_vector_float_128",
(256, SdpaDType::F32) => "sdpa_vector_float_256",
(512, SdpaDType::F32) => "sdpa_vector_float_512",
(other, _) => {
return Err(MetalKernelError::SdpaHeadSizeMismatch {
variation: "vector",
got: *other,
expected: vec![32, 64, 96, 128, 256, 512],
})
}
};
let alpha = if softcapping != 1. {
alpha / softcapping
} else {
alpha
};
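    // Function constant 20 is fixed to `false` (presumably the kernel's mask flag;
    // this path binds no mask argument).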
    let constants = Some(ConstantValues::new(vec![(20, Value::Bool(false))]));
let pipeline = kernels.load_pipeline_with_constants(device, Source::Sdpa, name, constants)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoder = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
(q_buffer, q_offset),
(k_buffer, k_offset),
(v_buffer, v_offset),
output,
gqa_factor,
n,
kstride,
vstride,
alpha,
softcapping
)
);
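    // One 1024-thread threadgroup per (batch * head) row.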
let grid_dims = MTLSize {
width: 1,
height: b as usize,
depth: 1,
};
let group_dims = MTLSize {
width: 1024,
height: 1,
depth: 1,
};
encoder.use_resource(q_buffer, MTLResourceUsage::Read);
encoder.use_resource(k_buffer, MTLResourceUsage::Read);
encoder.use_resource(v_buffer, MTLResourceUsage::Read);
encoder.use_resource(output, MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_dims, group_dims);
Ok(())
}
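
/// Number of key/value blocks the two-pass SDPA vector kernel splits the
/// sequence into during pass 1.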
pub const SDPA_2PASS_BLOCKS: usize = 32;
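
/// Two-pass variant of [`call_sdpa_vector`] for longer key/value sequences.
/// Pass 1 splits the keys into [`SDPA_2PASS_BLOCKS`] blocks and writes partial
/// results into `intermediate`, `sums`, and `maxs`; pass 2 reduces those partials
/// into `output`.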
#[allow(clippy::too_many_arguments)]
pub fn call_sdpa_vector_2pass(
device: &Device,
ep: impl EncoderProvider,
kernels: &Kernels,
q_offset: usize,
q_shape: &[usize],
q_buffer: &Buffer,
k_offset: usize,
k_shape: &[usize],
k_stride: &[usize],
k_buffer: &Buffer,
v_offset: usize,
v_stride: &[usize],
v_buffer: &Buffer,
output: &Buffer,
intermediate: &Buffer,
sums: &Buffer,
maxs: &Buffer,
alpha: f32,
softcapping: f32,
itype: SdpaDType,
) -> Result<(), MetalKernelError> {
let bk = q_shape.last().unwrap();
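    // Pass 1: one threadgroup per (batch * head, key block) pair writes a partial
    // attention result along with its per-block softmax statistics (sums and maxs).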
{
let name_pass1 = match (bk, itype) {
(32, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_32",
(64, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_64",
(96, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_96",
(128, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_128",
(256, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_256",
(512, SdpaDType::F16) => "sdpa_vector_2pass_1_float16_t_512",
(32, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_32",
(64, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_64",
(96, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_96",
(128, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_128",
(256, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_256",
(512, SdpaDType::BF16) => "sdpa_vector_2pass_1_bfloat16_t_512",
(32, SdpaDType::F32) => "sdpa_vector_2pass_1_float_32",
(64, SdpaDType::F32) => "sdpa_vector_2pass_1_float_64",
(96, SdpaDType::F32) => "sdpa_vector_2pass_1_float_96",
(128, SdpaDType::F32) => "sdpa_vector_2pass_1_float_128",
(256, SdpaDType::F32) => "sdpa_vector_2pass_1_float_256",
(512, SdpaDType::F32) => "sdpa_vector_2pass_1_float_512",
(other, _) => {
return Err(MetalKernelError::SdpaHeadSizeMismatch {
variation: "vector_2pass_1",
got: *other,
expected: vec![32, 64, 96, 128, 256, 512],
})
}
};
let gqa_factor = (q_shape[1] / k_shape[1]) as i32;
let n = k_shape[2] as i32;
let b = (q_shape[0] * q_shape[1]) as i32;
let kstride = k_stride[1];
let vstride = v_stride[1];
let alpha = if softcapping != 1. {
alpha / softcapping
} else {
alpha
};
        let constants = Some(ConstantValues::new(vec![(20, Value::Bool(false))]));
let pipeline =
kernels.load_pipeline_with_constants(device, Source::Sdpa, name_pass1, constants)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoder = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(
encoder,
(
(q_buffer, q_offset),
(k_buffer, k_offset),
(v_buffer, v_offset),
intermediate,
sums,
maxs,
gqa_factor,
n,
kstride,
vstride,
alpha,
softcapping
)
);
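        // Grid: (1, batch * head, SDPA_2PASS_BLOCKS); each threadgroup is 8
        // simdgroups (256 threads).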
let grid_dims = MTLSize {
width: 1,
height: b as usize,
depth: SDPA_2PASS_BLOCKS,
};
let group_dims = MTLSize {
width: 8 * 32,
height: 1,
depth: 1,
};
encoder.use_resource(q_buffer, MTLResourceUsage::Read);
encoder.use_resource(k_buffer, MTLResourceUsage::Read);
encoder.use_resource(v_buffer, MTLResourceUsage::Read);
encoder.use_resource(intermediate, MTLResourceUsage::Write);
encoder.use_resource(sums, MTLResourceUsage::Write);
encoder.use_resource(maxs, MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_dims, group_dims);
}
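    // Pass 2: reduce the per-block partials into the final output with one
    // 1024-thread threadgroup per (batch * head) row.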
{
let name_pass2 = match (bk, itype) {
(32, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_32",
(64, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_64",
(96, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_96",
(128, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_128",
(256, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_256",
(512, SdpaDType::F16) => "sdpa_vector_2pass_2_float16_t_512",
(32, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_32",
(64, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_64",
(96, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_96",
(128, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_128",
(256, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_256",
(512, SdpaDType::BF16) => "sdpa_vector_2pass_2_bfloat16_t_512",
(32, SdpaDType::F32) => "sdpa_vector_2pass_2_float_32",
(64, SdpaDType::F32) => "sdpa_vector_2pass_2_float_64",
(96, SdpaDType::F32) => "sdpa_vector_2pass_2_float_96",
(128, SdpaDType::F32) => "sdpa_vector_2pass_2_float_128",
(256, SdpaDType::F32) => "sdpa_vector_2pass_2_float_256",
(512, SdpaDType::F32) => "sdpa_vector_2pass_2_float_512",
(other, _) => {
return Err(MetalKernelError::SdpaHeadSizeMismatch {
variation: "vector_2pass_2",
got: *other,
expected: vec![32, 64, 96, 128, 256, 512],
})
}
};
let b = q_shape[0] * q_shape[1];
let pipeline = kernels.load_pipeline(device, Source::Sdpa, name_pass2)?;
let encoder = ep.encoder();
let encoder: &ComputeCommandEncoder = encoder.as_ref();
encoder.set_compute_pipeline_state(&pipeline);
set_params!(encoder, (intermediate, sums, maxs, output));
let grid_dims = MTLSize {
width: 1,
height: b,
depth: 1,
};
let group_dims = MTLSize {
width: 1024,
height: 1,
depth: 1,
};
        // Pass 2 only reads the per-block partials; the output is the sole write target.
        encoder.use_resource(intermediate, MTLResourceUsage::Read);
        encoder.use_resource(sums, MTLResourceUsage::Read);
        encoder.use_resource(maxs, MTLResourceUsage::Read);
encoder.use_resource(output, MTLResourceUsage::Write);
encoder.dispatch_thread_groups(grid_dims, group_dims);
}
Ok(())
}