use metal::{
BufferRef, CommandBufferRef, ComputePipelineState, MTLSize, NSUInteger,
};
use crate::riir::backend::gpu::encoder::pipeline_bundle;
use crate::riir::backend::gpu::metal::{MetalContext, MetalError, MtlBuffer};
/// Errors returned by the GPU attention helpers in this file.
#[derive(Debug, thiserror::Error)]
pub enum GpuAttnError {
/// A host slice's length differs from the length implied by the shape arguments.
///
/// NOTE(review): the message says "too short", but `check_len` raises this
/// variant for any mismatch, including slices that are longer than expected.
#[error("buffer too short: {what} expected {expected} floats, got {actual}")]
BadLen {
/// Name of the offending slice.
what: &'static str,
/// Element count implied by the shape arguments.
expected: usize,
/// Element count the caller actually supplied.
actual: usize,
},
/// A shape argument was zero or negative (see `check_pos`).
#[error("non-positive shape: {what} = {value}")]
BadShape {
/// Name of the offending argument.
what: &'static str,
/// The rejected value.
value: i64,
},
/// `num_heads` is not an exact multiple of `num_kv_heads`.
#[error("num_heads ({num_heads}) must be a multiple of num_kv_heads ({num_kv_heads})")]
BadGqa { num_heads: u32, num_kv_heads: u32 },
/// Error coming from the Metal backend; `#[from]` lets `?` convert a
/// `MetalError` into this variant.
#[error("Metal backend: {0}")]
Metal(#[from] MetalError),
}
/// Accepts strictly positive shape values.
///
/// Returns `GpuAttnError::BadShape` carrying `what` and `value` when `value`
/// is zero or negative.
fn check_pos(what: &'static str, value: i64) -> Result<(), GpuAttnError> {
    if value > 0 {
        Ok(())
    } else {
        Err(GpuAttnError::BadShape { what, value })
    }
}
/// Requires `actual` to equal `expected` exactly.
///
/// Any difference (shorter or longer) yields `GpuAttnError::BadLen` carrying
/// `what` and both counts.
fn check_len(
    what: &'static str,
    expected: usize,
    actual: usize,
) -> Result<(), GpuAttnError> {
    if actual == expected {
        return Ok(());
    }
    Err(GpuAttnError::BadLen {
        what,
        expected,
        actual,
    })
}
// Declares `GpuAttnPipelines` through the project's `pipeline_bundle!` macro,
// pairing each field with a kernel name string:
//   scores  -> "attn_scores_batched"
//   softmax -> "attn_softmax_batched"
//   values  -> "attn_values_batched"
//   gate    -> "sigmoid_gate"
// These are the same names the `gpu_*` wrappers in this file pass to
// `MetalContext::pipeline`.
// NOTE(review): the macro expansion is not visible here; presumably it
// generates a struct holding one pipeline per entry — confirm against the
// macro definition in `gpu::encoder`.
pipeline_bundle! {
pub struct GpuAttnPipelines {
scores => "attn_scores_batched",
softmax => "attn_softmax_batched",
values => "attn_values_batched",
gate => "sigmoid_gate",
}
}
/// Encodes one compute pass for the batched attention-scores pipeline.
///
/// Opens a new compute command encoder on `cmdbuf`, binds the arguments,
/// dispatches `seq_len * num_heads` threadgroups of 256 threads, and ends
/// encoding. The command buffer is neither committed nor waited on here.
///
/// Binding layout (it has to agree with the kernel behind `pipe`):
/// * buffers 0, 1, 2: `q`, `k_cache`, `scores`
/// * 4-byte scalars at 3..=9: `head_dim`, `kv_dim`, `seq_len`, `seq_stride`,
///   `scale`, `heads_per_kv`, and `seq_len` a second time
///
/// NOTE(review): slot 9 repeats `seq_len`; presumably the kernel reads it as
/// the number of threadgroups per head — confirm against the shader source.
#[allow(clippy::too_many_arguments)]
pub fn encode_attn_scores_batched_into(
    cmdbuf: &CommandBufferRef,
    pipe: &ComputePipelineState,
    q: &BufferRef,
    k_cache: &BufferRef,
    scores: &BufferRef,
    num_heads: u32,
    head_dim: u32,
    kv_dim: u32,
    seq_len: u32,
    seq_stride: u32,
    heads_per_kv: u32,
    scale: f32,
) {
    // Every scalar argument is passed as a 4-byte value.
    const SCALAR_BYTES: NSUInteger = 4;

    let encoder = cmdbuf.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(pipe);

    encoder.set_buffer(0, Some(q), 0);
    encoder.set_buffer(1, Some(k_cache), 0);
    encoder.set_buffer(2, Some(scores), 0);

    let bind_u32 = |slot: NSUInteger, value: &u32| {
        encoder.set_bytes(slot, SCALAR_BYTES, (value as *const u32).cast());
    };
    bind_u32(3, &head_dim);
    bind_u32(4, &kv_dim);
    bind_u32(5, &seq_len);
    bind_u32(6, &seq_stride);
    encoder.set_bytes(7, SCALAR_BYTES, (&scale as *const f32).cast());
    bind_u32(8, &heads_per_kv);
    bind_u32(9, &seq_len);

    let threadgroup_count = (seq_len * num_heads) as NSUInteger;
    let grid = MTLSize::new(threadgroup_count, 1, 1);
    let threads_per_group = MTLSize::new(256, 1, 1);
    encoder.dispatch_thread_groups(grid, threads_per_group);
    encoder.end_encoding();
}
/// Encodes one compute pass for the batched softmax pipeline.
///
/// Binds `scores` at buffer 0 and the 4-byte scalars `seq_len` and
/// `seq_stride` at slots 1 and 2, then dispatches `num_heads` threadgroups of
/// 256 threads and ends encoding. The command buffer is not committed here.
pub fn encode_attn_softmax_batched_into(
    cmdbuf: &CommandBufferRef,
    pipe: &ComputePipelineState,
    scores: &BufferRef,
    num_heads: u32,
    seq_len: u32,
    seq_stride: u32,
) {
    // Both scalar arguments are passed as 4-byte values.
    const SCALAR_BYTES: NSUInteger = 4;

    let encoder = cmdbuf.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(pipe);
    encoder.set_buffer(0, Some(scores), 0);

    let bind_u32 = |slot: NSUInteger, value: &u32| {
        encoder.set_bytes(slot, SCALAR_BYTES, (value as *const u32).cast());
    };
    bind_u32(1, &seq_len);
    bind_u32(2, &seq_stride);

    let grid = MTLSize::new(num_heads as NSUInteger, 1, 1);
    let threads_per_group = MTLSize::new(256, 1, 1);
    encoder.dispatch_thread_groups(grid, threads_per_group);
    encoder.end_encoding();
}
/// Encodes one compute pass for the batched attention-values pipeline.
///
/// Binding layout (it has to agree with the kernel behind `pipe`):
/// * buffers 0, 1, 2: `scores`, `v_cache`, `out`
/// * 4-byte scalars at 3..=7: `head_dim`, `kv_dim`, `seq_len`, `seq_stride`,
///   `heads_per_kv`
///
/// The dispatch uses enough 256-thread threadgroups to cover
/// `head_dim * num_heads` threads (rounded up). The command buffer is not
/// committed here.
#[allow(clippy::too_many_arguments)]
pub fn encode_attn_values_batched_into(
    cmdbuf: &CommandBufferRef,
    pipe: &ComputePipelineState,
    scores: &BufferRef,
    v_cache: &BufferRef,
    out: &BufferRef,
    num_heads: u32,
    head_dim: u32,
    kv_dim: u32,
    seq_len: u32,
    seq_stride: u32,
    heads_per_kv: u32,
) {
    // Every scalar argument is passed as a 4-byte value.
    const SCALAR_BYTES: NSUInteger = 4;

    let encoder = cmdbuf.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(pipe);

    encoder.set_buffer(0, Some(scores), 0);
    encoder.set_buffer(1, Some(v_cache), 0);
    encoder.set_buffer(2, Some(out), 0);

    let bind_u32 = |slot: NSUInteger, value: &u32| {
        encoder.set_bytes(slot, SCALAR_BYTES, (value as *const u32).cast());
    };
    bind_u32(3, &head_dim);
    bind_u32(4, &kv_dim);
    bind_u32(5, &seq_len);
    bind_u32(6, &seq_stride);
    bind_u32(7, &heads_per_kv);

    // Round the thread count up to whole 256-thread groups.
    let thread_count = head_dim * num_heads;
    let group_count = (thread_count + 255) / 256;
    let grid = MTLSize::new(group_count as NSUInteger, 1, 1);
    let threads_per_group = MTLSize::new(256, 1, 1);
    encoder.dispatch_thread_groups(grid, threads_per_group);
    encoder.end_encoding();
}
/// Encodes one compute pass for the sigmoid-gate pipeline.
///
/// Binds `x_inout` at buffer 0, `gate` at buffer 1 and the 4-byte scalar `dim`
/// at slot 2, then dispatches enough 256-thread threadgroups to cover `dim`
/// threads (rounded up) and ends encoding. The command buffer is not
/// committed here.
pub fn encode_sigmoid_gate_into(
    cmdbuf: &CommandBufferRef,
    pipe: &ComputePipelineState,
    x_inout: &BufferRef,
    gate: &BufferRef,
    dim: u32,
) {
    let encoder = cmdbuf.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(pipe);

    encoder.set_buffer(0, Some(x_inout), 0);
    encoder.set_buffer(1, Some(gate), 0);

    let dim_ptr = &dim as *const u32;
    encoder.set_bytes(2, 4, dim_ptr.cast());

    // Round the element count up to whole 256-thread groups.
    let group_count = (dim + 255) / 256;
    let grid = MTLSize::new(group_count as NSUInteger, 1, 1);
    let threads_per_group = MTLSize::new(256, 1, 1);
    encoder.dispatch_thread_groups(grid, threads_per_group);
    encoder.end_encoding();
}
/// Runs the `attn_scores_batched` pipeline on host slices and copies the
/// result into `scores_out`.
///
/// The call is synchronous: it creates device buffers from `q` and `k_cache`,
/// encodes a single pass with `encode_attn_scores_batched_into`, commits the
/// command buffer, waits for completion, and then copies the score buffer
/// back. `seq_len` is also used as the row stride of the score buffer.
///
/// Expected slice lengths (in `f32` elements):
/// * `q`: `num_heads * head_dim`
/// * `k_cache`: `seq_len * num_kv_heads * head_dim`
/// * `scores_out`: `num_heads * seq_len`
///
/// # Errors
///
/// * `GpuAttnError::BadShape` if any of `num_heads`, `num_kv_heads`,
///   `head_dim`, `seq_len` is zero, or if one of the products above does not
///   fit in `u32` (the variant is reused for overflow; `what` says so).
/// * `GpuAttnError::BadGqa` if `num_heads` is not a multiple of `num_kv_heads`.
/// * `GpuAttnError::BadLen` if a slice length differs from the expected one.
/// * `GpuAttnError::Metal` if the pipeline lookup fails.
#[allow(clippy::too_many_arguments)]
pub fn gpu_attn_scores_batched(
    metal: &mut MetalContext,
    num_heads: u32,
    num_kv_heads: u32,
    head_dim: u32,
    seq_len: u32,
    q: &[f32],
    k_cache: &[f32],
    scale: f32,
    scores_out: &mut [f32],
) -> Result<(), GpuAttnError> {
    check_pos("num_heads", i64::from(num_heads))?;
    check_pos("num_kv_heads", i64::from(num_kv_heads))?;
    check_pos("head_dim", i64::from(head_dim))?;
    check_pos("seq_len", i64::from(seq_len))?;
    if num_heads % num_kv_heads != 0 {
        return Err(GpuAttnError::BadGqa {
            num_heads,
            num_kv_heads,
        });
    }

    // Shape products are carried as `u32` all the way to the kernel. A plain
    // `a * b` wraps in release builds (letting an undersized slice pass the
    // length checks) and panics in debug builds, so multiply checked and
    // report an error instead.
    let product = |what: &'static str, a: u32, b: u32| -> Result<u32, GpuAttnError> {
        a.checked_mul(b).ok_or_else(|| GpuAttnError::BadShape {
            what,
            value: i64::from(a).saturating_mul(i64::from(b)),
        })
    };
    let kv_dim = product("num_kv_heads * head_dim (overflows u32)", num_kv_heads, head_dim)?;
    let heads_per_kv = num_heads / num_kv_heads;
    let q_len = product("num_heads * head_dim (overflows u32)", num_heads, head_dim)?;
    let k_len = product("seq_len * kv_dim (overflows u32)", seq_len, kv_dim)?;
    let scores_len = product("num_heads * seq_len (overflows u32)", num_heads, seq_len)?;

    check_len("q", q_len as usize, q.len())?;
    check_len("k_cache", k_len as usize, k_cache.len())?;
    check_len("scores_out", scores_len as usize, scores_out.len())?;

    // Clone the pipeline handle so the borrow of `metal` ends before the
    // device and queue are accessed.
    let pipe = metal.pipeline("attn_scores_batched")?.clone();
    let device = metal.device();
    let buf_q = MtlBuffer::<f32>::with_data(device, q);
    let buf_k = MtlBuffer::<f32>::with_data(device, k_cache);
    let buf_scores = MtlBuffer::<f32>::with_len(device, scores_out.len());

    let cmdbuf = metal.queue().new_command_buffer();
    encode_attn_scores_batched_into(
        cmdbuf,
        &pipe,
        buf_q.raw(),
        buf_k.raw(),
        buf_scores.raw(),
        num_heads,
        head_dim,
        kv_dim,
        seq_len,
        seq_len, // seq_stride: rows of the score buffer are packed
        heads_per_kv,
        scale,
    );
    cmdbuf.commit();
    cmdbuf.wait_until_completed();

    scores_out.copy_from_slice(&buf_scores.to_vec());
    Ok(())
}
pub fn gpu_attn_softmax_batched(
metal: &mut MetalContext,
num_heads: u32,
seq_len: u32,
scores_inout: &mut [f32],
) -> Result<(), GpuAttnError> {
check_pos("num_heads", num_heads as i64)?;
check_pos("seq_len", seq_len as i64)?;
check_len(
"scores_inout",
(num_heads * seq_len) as usize,
scores_inout.len(),
)?;
let pipe = metal.pipeline("attn_softmax_batched")?.clone();
let device = metal.device();
let buf_scores = MtlBuffer::<f32>::with_data(device, scores_inout);
let cmdbuf = metal.queue().new_command_buffer();
encode_attn_softmax_batched_into(
cmdbuf,
&pipe,
buf_scores.raw(),
num_heads,
seq_len,
seq_len,
);
cmdbuf.commit();
cmdbuf.wait_until_completed();
scores_inout.copy_from_slice(&buf_scores.to_vec());
Ok(())
}
/// Runs the `attn_values_batched` pipeline on host slices and copies the
/// result into `out`.
///
/// The call is synchronous: it creates device buffers from `scores` and
/// `v_cache`, encodes a single pass with `encode_attn_values_batched_into`,
/// commits the command buffer, waits for completion, and then copies the
/// output buffer back. `seq_len` is also used as the row stride of `scores`.
///
/// Expected slice lengths (in `f32` elements):
/// * `scores`: `num_heads * seq_len`
/// * `v_cache`: `seq_len * num_kv_heads * head_dim`
/// * `out`: `num_heads * head_dim`
///
/// # Errors
///
/// * `GpuAttnError::BadShape` if any of `num_heads`, `num_kv_heads`,
///   `head_dim`, `seq_len` is zero, or if one of the products above does not
///   fit in `u32` (the variant is reused for overflow; `what` says so).
/// * `GpuAttnError::BadGqa` if `num_heads` is not a multiple of `num_kv_heads`.
/// * `GpuAttnError::BadLen` if a slice length differs from the expected one.
/// * `GpuAttnError::Metal` if the pipeline lookup fails.
#[allow(clippy::too_many_arguments)]
pub fn gpu_attn_values_batched(
    metal: &mut MetalContext,
    num_heads: u32,
    num_kv_heads: u32,
    head_dim: u32,
    seq_len: u32,
    scores: &[f32],
    v_cache: &[f32],
    out: &mut [f32],
) -> Result<(), GpuAttnError> {
    check_pos("num_heads", i64::from(num_heads))?;
    check_pos("num_kv_heads", i64::from(num_kv_heads))?;
    check_pos("head_dim", i64::from(head_dim))?;
    check_pos("seq_len", i64::from(seq_len))?;
    if num_heads % num_kv_heads != 0 {
        return Err(GpuAttnError::BadGqa {
            num_heads,
            num_kv_heads,
        });
    }

    // Shape products are carried as `u32` all the way to the kernel. A plain
    // `a * b` wraps in release builds (letting an undersized slice pass the
    // length checks) and panics in debug builds, so multiply checked and
    // report an error instead.
    let product = |what: &'static str, a: u32, b: u32| -> Result<u32, GpuAttnError> {
        a.checked_mul(b).ok_or_else(|| GpuAttnError::BadShape {
            what,
            value: i64::from(a).saturating_mul(i64::from(b)),
        })
    };
    let kv_dim = product("num_kv_heads * head_dim (overflows u32)", num_kv_heads, head_dim)?;
    let heads_per_kv = num_heads / num_kv_heads;
    let scores_len = product("num_heads * seq_len (overflows u32)", num_heads, seq_len)?;
    let v_len = product("seq_len * kv_dim (overflows u32)", seq_len, kv_dim)?;
    let out_len = product("num_heads * head_dim (overflows u32)", num_heads, head_dim)?;

    check_len("scores", scores_len as usize, scores.len())?;
    check_len("v_cache", v_len as usize, v_cache.len())?;
    check_len("out", out_len as usize, out.len())?;

    // Clone the pipeline handle so the borrow of `metal` ends before the
    // device and queue are accessed.
    let pipe = metal.pipeline("attn_values_batched")?.clone();
    let device = metal.device();
    let buf_scores = MtlBuffer::<f32>::with_data(device, scores);
    let buf_v = MtlBuffer::<f32>::with_data(device, v_cache);
    let buf_out = MtlBuffer::<f32>::with_len(device, out.len());

    let cmdbuf = metal.queue().new_command_buffer();
    encode_attn_values_batched_into(
        cmdbuf,
        &pipe,
        buf_scores.raw(),
        buf_v.raw(),
        buf_out.raw(),
        num_heads,
        head_dim,
        kv_dim,
        seq_len,
        seq_len, // seq_stride: rows of the score buffer are packed
        heads_per_kv,
    );
    cmdbuf.commit();
    cmdbuf.wait_until_completed();

    out.copy_from_slice(&buf_out.to_vec());
    Ok(())
}
/// Runs the `sigmoid_gate` pipeline on `x_inout` with `gate` and writes the
/// result back into `x_inout`.
///
/// Both slices must hold exactly `dim` elements. The call is synchronous: it
/// creates device buffers from both slices, encodes one pass with
/// `encode_sigmoid_gate_into`, commits, waits for completion, and copies the
/// `x` buffer back.
///
/// # Errors
///
/// * `GpuAttnError::BadShape` if `dim` is zero.
/// * `GpuAttnError::BadLen` if either slice length differs from `dim`.
/// * `GpuAttnError::Metal` if the pipeline lookup fails.
pub fn gpu_sigmoid_gate(
    metal: &mut MetalContext,
    dim: u32,
    gate: &[f32],
    x_inout: &mut [f32],
) -> Result<(), GpuAttnError> {
    check_pos("dim", i64::from(dim))?;
    let element_count = dim as usize;
    check_len("gate", element_count, gate.len())?;
    check_len("x_inout", element_count, x_inout.len())?;

    let pipeline = metal.pipeline("sigmoid_gate")?.clone();
    let dev = metal.device();
    let x_buf = MtlBuffer::<f32>::with_data(dev, x_inout);
    let gate_buf = MtlBuffer::<f32>::with_data(dev, gate);

    let commands = metal.queue().new_command_buffer();
    encode_sigmoid_gate_into(commands, &pipeline, x_buf.raw(), gate_buf.raw(), dim);
    commands.commit();
    commands.wait_until_completed();

    let gated = x_buf.to_vec();
    x_inout.copy_from_slice(&gated);
    Ok(())
}