use metal::{
Buffer, CommandBufferRef, ComputePipelineState, Device, MTLResourceOptions,
MTLSize, NSUInteger,
};
use super::encoder::pipeline_bundle;
use super::metal::{MetalContext, MetalError, MtlBuffer};
use crate::riir::variants::{RMS_NORM_EPS, VARIANT};
#[derive(Debug, thiserror::Error)]
pub enum GpuNormError {
#[error("x must be HIDDEN_DIM={expected} floats, got {actual}")]
BadXLen { expected: usize, actual: usize },
#[error("out must be HIDDEN_DIM={expected} floats, got {actual}")]
BadOutLen { expected: usize, actual: usize },
#[error(
"weight_bf16 must be HIDDEN_DIM*2={expected} bytes, got {actual}"
)]
BadWeightLen { expected: usize, actual: usize },
#[error("Metal backend: {0}")]
Metal(#[from] MetalError),
}
/// Runs the two-pass bf16-weight RMS-norm on the GPU and blocks until done.
///
/// The function validates the slice lengths, looks up the
/// `"rms_norm_sum_sq"` and `"rms_norm_apply_bf16"` pipelines, uploads `x` and
/// `weight_bf16` into freshly created device buffers, encodes both passes on
/// one command buffer, commits it, waits for completion and finally copies
/// the output buffer into `out`.
///
/// # Errors
///
/// * [`GpuNormError::BadXLen`] if `x.len() != VARIANT.hidden_dim`.
/// * [`GpuNormError::BadOutLen`] if `out.len() != VARIANT.hidden_dim`.
/// * [`GpuNormError::BadWeightLen`] if `weight_bf16.len() != hidden_dim * 2`.
/// * [`GpuNormError::Metal`] if a pipeline lookup fails.
///
/// The checks run in the order listed, so the first mismatch wins.
///
/// NOTE(review): the command buffer's completion status is not inspected
/// after `wait_until_completed`, so a GPU-side failure is not reported here.
pub fn gpu_rms_norm_fused(
    metal: &mut MetalContext,
    x: &[f32],
    weight_bf16: &[u8],
    out: &mut [f32],
) -> Result<(), GpuNormError> {
    let hidden = VARIANT.hidden_dim;

    // --- argument validation (x, then out, then weight) --------------------
    let x_len = x.len();
    if x_len != hidden {
        return Err(GpuNormError::BadXLen {
            expected: hidden,
            actual: x_len,
        });
    }
    let out_len = out.len();
    if out_len != hidden {
        return Err(GpuNormError::BadOutLen {
            expected: hidden,
            actual: out_len,
        });
    }
    // Two bytes per weight element.
    let weight_bytes = hidden * 2;
    let weight_len = weight_bf16.len();
    if weight_len != weight_bytes {
        return Err(GpuNormError::BadWeightLen {
            expected: weight_bytes,
            actual: weight_len,
        });
    }

    // --- pipelines ---------------------------------------------------------
    // Cloned so the handles do not keep `metal` borrowed while the device and
    // queue are used below.
    let sum_sq_pso = metal.pipeline("rms_norm_sum_sq")?.clone();
    let apply_pso = metal.pipeline("rms_norm_apply_bf16")?.clone();

    // --- device buffers ----------------------------------------------------
    let device = metal.device();
    let x_dev = MtlBuffer::<f32>::with_data(device, x);
    let weight_dev = MtlBuffer::<u8>::with_data(device, weight_bf16);
    let sum_sq_dev = MtlBuffer::<f32>::with_len(device, 1);
    let out_dev = MtlBuffer::<f32>::with_len(device, hidden);

    // Scalars handed to both kernels via `set_bytes` (4 bytes each).
    let dim = hidden as u32;
    let eps = RMS_NORM_EPS;
    let dim_ptr: *const u32 = &dim;
    let eps_ptr: *const f32 = &eps;
    let threads_per_group = MTLSize::new(256, 1, 1);

    let cmdbuf = metal.queue().new_command_buffer();

    // --- pass 1: single threadgroup over the whole input ---------------------
    let sum_enc = cmdbuf.new_compute_command_encoder();
    sum_enc.set_compute_pipeline_state(&sum_sq_pso);
    sum_enc.set_buffer(0, Some(x_dev.raw()), 0);
    sum_enc.set_buffer(1, Some(sum_sq_dev.raw()), 0);
    sum_enc.set_bytes(2, 4, dim_ptr.cast());
    sum_enc.dispatch_thread_groups(MTLSize::new(1, 1, 1), threads_per_group);
    sum_enc.end_encoding();

    // --- pass 2: ceil(dim / 256) threadgroups --------------------------------
    let group_count = (dim + 255) / 256;
    let apply_enc = cmdbuf.new_compute_command_encoder();
    apply_enc.set_compute_pipeline_state(&apply_pso);
    apply_enc.set_buffer(0, Some(x_dev.raw()), 0);
    apply_enc.set_buffer(1, Some(weight_dev.raw()), 0);
    apply_enc.set_buffer(2, Some(sum_sq_dev.raw()), 0);
    apply_enc.set_buffer(3, Some(out_dev.raw()), 0);
    apply_enc.set_bytes(4, 4, dim_ptr.cast());
    apply_enc.set_bytes(5, 4, eps_ptr.cast());
    apply_enc.dispatch_thread_groups(
        MTLSize::new(group_count as NSUInteger, 1, 1),
        threads_per_group,
    );
    apply_enc.end_encoding();

    // --- submit, block, read back --------------------------------------------
    cmdbuf.commit();
    cmdbuf.wait_until_completed();

    let result = out_dev.to_vec();
    out.copy_from_slice(&result);
    Ok(())
}
// Declares `RmsNormBf16Pipelines` through the project's `pipeline_bundle!`
// macro, pairing the field `sum` with the kernel name "rms_norm_sum_sq" and
// the field `apply` with "rms_norm_apply_bf16".
//
// `encode_rms_norm_bf16_into` reads `pipes.sum` and `pipes.apply` and passes
// them to `set_compute_pipeline_state`.
//
// NOTE(review): the macro expansion is not visible in this file; presumably
// it also generates a constructor that resolves each name to a compiled
// pipeline — confirm in `super::encoder`.
pipeline_bundle! {
pub struct RmsNormBf16Pipelines {
sum => "rms_norm_sum_sq",
apply => "rms_norm_apply_bf16",
}
}
/// Encodes the two-pass bf16-weight RMS-norm onto an existing command buffer.
///
/// Two compute encoders are opened and closed one after the other:
///
/// 1. `pipes.sum` — `input` at buffer 0, `sum_sq` at buffer 1, `dim` at
///    index 2; dispatched as one threadgroup of 256 threads.
/// 2. `pipes.apply` — `input` (0), `weight_buf` at byte offset `weight_off`
///    (1), `sum_sq` (2), `out` (3), `dim` (4), `eps` (5); dispatched as
///    `ceil(dim / 256)` threadgroups of 256 threads.
///
/// The command buffer is neither committed nor waited on here.
///
/// NOTE(review): no buffer sizes are checked in this function; presumably
/// `input` and `out` hold at least `dim` f32 values and `sum_sq` at least
/// one — confirm against the kernels.
// Mirrors the kernel argument tables one-to-one, hence the long parameter list.
#[allow(clippy::too_many_arguments)]
pub fn encode_rms_norm_bf16_into(
    cmdbuf: &CommandBufferRef,
    pipes: &RmsNormBf16Pipelines,
    input: &Buffer,
    weight_buf: &Buffer,
    weight_off: u64,
    sum_sq: &Buffer,
    out: &Buffer,
    dim: u32,
    eps: f32,
) {
    // Inline constants are copied by `set_bytes` at call time, so pointing at
    // the by-value parameters is sufficient.
    let dim_ptr: *const u32 = &dim;
    let eps_ptr: *const f32 = &eps;
    let threads_per_group = MTLSize::new(256, 1, 1);
    let apply_groups = (dim + 255) / 256;

    // Pass 1: a single threadgroup.
    let sum_enc = cmdbuf.new_compute_command_encoder();
    sum_enc.set_compute_pipeline_state(&pipes.sum);
    sum_enc.set_buffer(0, Some(input), 0);
    sum_enc.set_buffer(1, Some(sum_sq), 0);
    sum_enc.set_bytes(2, 4, dim_ptr.cast());
    sum_enc.dispatch_thread_groups(MTLSize::new(1, 1, 1), threads_per_group);
    sum_enc.end_encoding();

    // Pass 2: one thread slot per element, rounded up to whole threadgroups.
    let apply_enc = cmdbuf.new_compute_command_encoder();
    apply_enc.set_compute_pipeline_state(&pipes.apply);
    apply_enc.set_buffer(0, Some(input), 0);
    apply_enc.set_buffer(1, Some(weight_buf), weight_off as NSUInteger);
    apply_enc.set_buffer(2, Some(sum_sq), 0);
    apply_enc.set_buffer(3, Some(out), 0);
    apply_enc.set_bytes(4, 4, dim_ptr.cast());
    apply_enc.set_bytes(5, 4, eps_ptr.cast());
    apply_enc.dispatch_thread_groups(
        MTLSize::new(apply_groups as NSUInteger, 1, 1),
        threads_per_group,
    );
    apply_enc.end_encoding();
}
// Declares `RmsNormBf16FusedNTokensPipeline` through the project's
// `pipeline_bundle!` macro, pairing the single field `pso` with the kernel
// name "rms_norm_bf16_fused_n_tokens".
//
// `encode_rms_norm_bf16_fused_n_tokens` reads `pipe.pso` and passes it to
// `set_compute_pipeline_state`.
//
// NOTE(review): the macro expansion is not visible in this file — confirm the
// generated API in `super::encoder`.
pipeline_bundle! {
pub struct RmsNormBf16FusedNTokensPipeline {
pso => "rms_norm_bf16_fused_n_tokens",
}
}
/// Encodes the single-pass, multi-token bf16-weight RMS-norm kernel.
///
/// Bindings: `input` at buffer 0, `weight_buf` at byte offset `weight_off`
/// at buffer 1, `out` at buffer 2, `dim` at index 3 and `eps` at index 4.
/// The dispatch uses `n_tokens` threadgroups of 256 threads each.
///
/// `n_tokens` only sizes the grid; it is not passed to the kernel as an
/// argument. The command buffer is neither committed nor waited on here.
///
/// NOTE(review): presumably each threadgroup handles one token row of `dim`
/// f32 values — confirm against the kernel source.
// Mirrors the kernel argument table one-to-one, hence the long parameter list.
#[allow(clippy::too_many_arguments)]
pub fn encode_rms_norm_bf16_fused_n_tokens(
    cmdbuf: &CommandBufferRef,
    pipe: &RmsNormBf16FusedNTokensPipeline,
    input: &Buffer,
    weight_buf: &Buffer,
    weight_off: u64,
    out: &Buffer,
    dim: u32,
    n_tokens: u32,
    eps: f32,
) {
    let dim_ptr: *const u32 = &dim;
    let eps_ptr: *const f32 = &eps;
    let grid = MTLSize::new(n_tokens as NSUInteger, 1, 1);
    let threads_per_group = MTLSize::new(256, 1, 1);

    let encoder = cmdbuf.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(&pipe.pso);

    // Buffers.
    encoder.set_buffer(0, Some(input), 0);
    encoder.set_buffer(1, Some(weight_buf), weight_off as NSUInteger);
    encoder.set_buffer(2, Some(out), 0);

    // Inline scalars (4 bytes each).
    encoder.set_bytes(3, 4, dim_ptr.cast());
    encoder.set_bytes(4, 4, eps_ptr.cast());

    encoder.dispatch_thread_groups(grid, threads_per_group);
    encoder.end_encoding();
}
/// Encodes an element-wise kernel over `n_tokens * dim` elements.
///
/// Bindings: `a` at buffer 0, `b` at buffer 1, `out` at buffer 2 and the
/// total element count (`n_tokens * dim`, as a `u32`) at index 3. The
/// dispatch uses `ceil(total / 256)` threadgroups of 256 threads.
///
/// The command buffer is neither committed nor waited on here.
///
/// NOTE(review): the function name suggests `out = a + b`; the actual
/// operation is defined by whatever kernel `pipeline` was built from.
///
/// # Panics
///
/// Panics if `n_tokens * dim` does not fit in a `u32`. The kernel receives
/// the count as a single `u32`, so a wrapped product (which an unchecked
/// multiply yields in release builds) would silently process only part of
/// the buffers.
pub fn encode_residual_add_n_tokens_into(
    cmdbuf: &CommandBufferRef,
    pipeline: &ComputePipelineState,
    a: &Buffer,
    b: &Buffer,
    out: &Buffer,
    n_tokens: u32,
    dim: u32,
) {
    let total = n_tokens
        .checked_mul(dim)
        .expect("encode_residual_add_n_tokens_into: n_tokens * dim overflows u32");

    let enc = cmdbuf.new_compute_command_encoder();
    enc.set_compute_pipeline_state(pipeline);
    enc.set_buffer(0, Some(a), 0);
    enc.set_buffer(1, Some(b), 0);
    enc.set_buffer(2, Some(out), 0);
    enc.set_bytes(3, 4, (&total as *const u32).cast());

    // Round up in the wider dispatch type so `+ 255` cannot overflow when
    // `total` is close to `u32::MAX`.
    let total_wide = total as NSUInteger;
    let num_tgs = (total_wide + 255) / 256;
    enc.dispatch_thread_groups(
        MTLSize::new(num_tgs, 1, 1),
        MTLSize::new(256, 1, 1),
    );
    enc.end_encoding();
}
/// Encodes the multi-token RoPE kernel, which operates on `x` in place
/// (no separate output buffer is bound).
///
/// Bindings: `x` at buffer 0, `inv_freq` at buffer 1, then the scalars
/// `n_tokens` (2), `num_heads` (3), `head_dim` (4), `rotary_dim` (5) and
/// `start_pos` (6), 4 bytes each. The dispatch covers
/// `n_tokens * num_heads * (rotary_dim / 2)` thread slots, rounded up to
/// whole threadgroups of 256 threads.
///
/// The command buffer is neither committed nor waited on here.
///
/// NOTE(review): presumably one thread handles one rotated pair and
/// `start_pos` is the sequence position of the first token — confirm against
/// the kernel source.
///
/// # Panics
///
/// Panics if `n_tokens * num_heads * (rotary_dim / 2)` overflows `u32`
/// (evaluated left to right). An unchecked multiply would wrap in release
/// builds and dispatch too few threads, leaving part of `x` unprocessed.
// Mirrors the kernel argument table one-to-one, hence the long parameter list.
#[allow(clippy::too_many_arguments)]
pub fn encode_rope_n_tokens_into(
    cmdbuf: &CommandBufferRef,
    pipeline: &ComputePipelineState,
    x: &Buffer,
    inv_freq: &Buffer,
    n_tokens: u32,
    num_heads: u32,
    head_dim: u32,
    rotary_dim: u32,
    start_pos: i32,
) {
    let pairs_per_head = rotary_dim / 2;
    let total = n_tokens
        .checked_mul(num_heads)
        .and_then(|token_heads| token_heads.checked_mul(pairs_per_head))
        .expect(
            "encode_rope_n_tokens_into: n_tokens * num_heads * (rotary_dim / 2) overflows u32",
        );

    let enc = cmdbuf.new_compute_command_encoder();
    enc.set_compute_pipeline_state(pipeline);
    enc.set_buffer(0, Some(x), 0);
    enc.set_buffer(1, Some(inv_freq), 0);
    enc.set_bytes(2, 4, (&n_tokens as *const u32).cast());
    enc.set_bytes(3, 4, (&num_heads as *const u32).cast());
    enc.set_bytes(4, 4, (&head_dim as *const u32).cast());
    enc.set_bytes(5, 4, (&rotary_dim as *const u32).cast());
    enc.set_bytes(6, 4, (&start_pos as *const i32).cast());

    // Round up in the wider dispatch type so `+ 255` cannot overflow when
    // `total` is close to `u32::MAX`.
    let total_wide = total as NSUInteger;
    let num_tgs = (total_wide + 255) / 256;
    enc.dispatch_thread_groups(
        MTLSize::new(num_tgs, 1, 1),
        MTLSize::new(256, 1, 1),
    );
    enc.end_encoding();
}
/// Encodes the 4-bit embedding gather kernel.
///
/// Bindings: `w_buf` is bound three times — at byte offsets `w_off`
/// (buffer 0), `s_off` (buffer 1) and `b_off` (buffer 2) — followed by
/// `token_ids` (buffer 3), `out` (buffer 4) and the scalars `n_tokens` (5),
/// `hidden_dim` (6) and `group_size` (7), 4 bytes each. The dispatch covers
/// `n_tokens * hidden_dim` thread slots, rounded up to whole threadgroups of
/// 256 threads.
///
/// The command buffer is neither committed nor waited on here.
///
/// NOTE(review): the offset names suggest packed weights, scales and biases
/// living in one buffer, with `group_size` the quantization group width —
/// confirm against the kernel source.
///
/// # Panics
///
/// Panics if `n_tokens * hidden_dim` does not fit in a `u32`. An unchecked
/// multiply would wrap in release builds and dispatch too few threads,
/// leaving part of `out` unwritten.
// Mirrors the kernel argument table one-to-one, hence the long parameter list.
#[allow(clippy::too_many_arguments)]
pub fn encode_embed_gather_4bit_into(
    cmdbuf: &CommandBufferRef,
    pipeline: &ComputePipelineState,
    w_buf: &Buffer,
    w_off: u64,
    s_off: u64,
    b_off: u64,
    token_ids: &Buffer,
    out: &Buffer,
    n_tokens: u32,
    hidden_dim: u32,
    group_size: u32,
) {
    let total = n_tokens
        .checked_mul(hidden_dim)
        .expect("encode_embed_gather_4bit_into: n_tokens * hidden_dim overflows u32");

    let enc = cmdbuf.new_compute_command_encoder();
    enc.set_compute_pipeline_state(pipeline);
    enc.set_buffer(0, Some(w_buf), w_off as NSUInteger);
    enc.set_buffer(1, Some(w_buf), s_off as NSUInteger);
    enc.set_buffer(2, Some(w_buf), b_off as NSUInteger);
    enc.set_buffer(3, Some(token_ids), 0);
    enc.set_buffer(4, Some(out), 0);
    enc.set_bytes(5, 4, (&n_tokens as *const u32).cast());
    enc.set_bytes(6, 4, (&hidden_dim as *const u32).cast());
    enc.set_bytes(7, 4, (&group_size as *const u32).cast());

    // Round up in the wider dispatch type so `+ 255` cannot overflow when
    // `total` is close to `u32::MAX`.
    let total_wide = total as NSUInteger;
    let num_tgs = (total_wide + 255) / 256;
    enc.dispatch_thread_groups(
        MTLSize::new(num_tgs, 1, 1),
        MTLSize::new(256, 1, 1),
    );
    enc.end_encoding();
}
/// Encodes a blit that copies `dim` f32 values from the start of `src` to the
/// start of `dst`.
///
/// The byte count is `dim * size_of::<f32>()`, computed in `NSUInteger`.
/// Both source and destination offsets are 0. The command buffer is neither
/// committed nor waited on here.
pub fn encode_buffer_copy_f32(
    cmdbuf: &CommandBufferRef,
    src: &Buffer,
    dst: &Buffer,
    dim: u32,
) {
    let elem_size = std::mem::size_of::<f32>() as NSUInteger;
    let elem_count = dim as NSUInteger;
    let byte_len = elem_count * elem_size;

    let copy_enc = cmdbuf.new_blit_command_encoder();
    copy_enc.copy_from_buffer(src, 0, dst, 0, byte_len);
    copy_enc.end_encoding();
}
/// Encodes an element-wise kernel over `dim` elements.
///
/// Bindings: `a` at buffer 0, `b` at buffer 1, `out` at buffer 2 and `dim`
/// (4 bytes) at index 3. The dispatch uses `ceil(dim / 256)` threadgroups of
/// 256 threads. The command buffer is neither committed nor waited on here.
///
/// NOTE(review): the function name suggests `out = a + b`; the actual
/// operation is defined by whatever kernel `pipeline` was built from.
pub fn encode_residual_add_into(
    cmdbuf: &CommandBufferRef,
    pipeline: &ComputePipelineState,
    a: &Buffer,
    b: &Buffer,
    out: &Buffer,
    dim: u32,
) {
    let dim_ptr: *const u32 = &dim;
    let group_count = (dim + 255) / 256;
    let grid = MTLSize::new(group_count as NSUInteger, 1, 1);
    let threads_per_group = MTLSize::new(256, 1, 1);

    let encoder = cmdbuf.new_compute_command_encoder();
    encoder.set_compute_pipeline_state(pipeline);

    let operands = [a, b, out];
    encoder.set_buffer(0, Some(operands[0]), 0);
    encoder.set_buffer(1, Some(operands[1]), 0);
    encoder.set_buffer(2, Some(operands[2]), 0);
    encoder.set_bytes(3, 4, dim_ptr.cast());

    encoder.dispatch_thread_groups(grid, threads_per_group);
    encoder.end_encoding();
}
/// A set of Metal buffers created together by [`MlaForwardScratch::new`].
///
/// `new` sizes `hidden`, `residual`, `normed` and `block_out` for
/// `VARIANT.hidden_dim` f32 values each, and `sum_sq` for a single f32.
///
/// NOTE(review): the field roles below are inferred from their names only —
/// confirm against the code that consumes this struct.
pub struct MlaForwardScratch {
    /// Presumably the current hidden-state vector.
    pub hidden: Buffer,

    /// Presumably the saved residual-stream vector.
    pub residual: Buffer,

    /// Presumably the RMS-normalised activations.
    pub normed: Buffer,

    /// Presumably the output of the current block.
    pub block_out: Buffer,

    /// One-element f32 buffer; presumably the sum-of-squares scratch slot for
    /// the two-pass RMS-norm.
    pub sum_sq: Buffer,
}
impl MlaForwardScratch {
pub fn new(device: &Device) -> Self {
let v = VARIANT;
let f32_buf = |n: usize| {
device.new_buffer(
(n * std::mem::size_of::<f32>()) as NSUInteger,
MTLResourceOptions::StorageModeShared,
)
};
Self {
hidden: f32_buf(v.hidden_dim),
residual: f32_buf(v.hidden_dim),
normed: f32_buf(v.hidden_dim),
block_out: f32_buf(v.hidden_dim),
sum_sq: f32_buf(1),
}
}
}