use metal::{
Buffer, CommandBufferRef, ComputePipelineState, MTLSize, NSUInteger,
};
use super::metal::{MetalBackend, MetalError, MtlBuffer};
use super::variants::{RMS_NORM_EPS, VARIANT};
/// Errors returned by `gpu_rms_norm_fused`.
///
/// The three `Bad*Len` variants are produced by the up-front slice-length
/// checks; `expected` is the required length and `actual` is the length the
/// caller supplied.
#[derive(Debug, thiserror::Error)]
pub enum GpuNormError {
/// `x.len()` differs from `VARIANT.hidden_dim` (lengths are in `f32` elements).
#[error("x must be HIDDEN_DIM={expected} floats, got {actual}")]
BadXLen { expected: usize, actual: usize },
/// `out.len()` differs from `VARIANT.hidden_dim` (lengths are in `f32` elements).
#[error("out must be HIDDEN_DIM={expected} floats, got {actual}")]
BadOutLen { expected: usize, actual: usize },
/// `weight_bf16.len()` differs from `VARIANT.hidden_dim * 2` (lengths are in bytes).
#[error(
"weight_bf16 must be HIDDEN_DIM*2={expected} bytes, got {actual}"
)]
BadWeightLen { expected: usize, actual: usize },
/// A failure reported by the Metal backend wrapper. `#[from]` lets `?`
/// convert a `MetalError` into this variant.
#[error("Metal backend: {0}")]
Metal(#[from] MetalError),
}
pub fn gpu_rms_norm_fused(
metal: &mut MetalBackend,
x: &[f32],
weight_bf16: &[u8],
out: &mut [f32],
) -> Result<(), GpuNormError> {
let v = VARIANT;
if x.len() != v.hidden_dim {
return Err(GpuNormError::BadXLen {
expected: v.hidden_dim,
actual: x.len(),
});
}
if out.len() != v.hidden_dim {
return Err(GpuNormError::BadOutLen {
expected: v.hidden_dim,
actual: out.len(),
});
}
let expected_w = v.hidden_dim * 2;
if weight_bf16.len() != expected_w {
return Err(GpuNormError::BadWeightLen {
expected: expected_w,
actual: weight_bf16.len(),
});
}
let sum_pipe = metal.pipeline("rms_norm_sum_sq")?.clone();
let apply_pipe = metal.pipeline("rms_norm_apply_bf16")?.clone();
let device = metal.device();
let buf_x = MtlBuffer::<f32>::with_data(device, x);
let buf_w = MtlBuffer::<u8>::with_data(device, weight_bf16);
let buf_sum_sq = MtlBuffer::<f32>::with_len(device, 1);
let buf_out = MtlBuffer::<f32>::with_len(device, v.hidden_dim);
let cmdbuf = metal.queue().new_command_buffer();
{
let enc = cmdbuf.new_compute_command_encoder();
enc.set_compute_pipeline_state(&sum_pipe);
enc.set_buffer(0, Some(buf_x.raw()), 0);
enc.set_buffer(1, Some(buf_sum_sq.raw()), 0);
let dim = v.hidden_dim as u32;
enc.set_bytes(2, 4, (&dim as *const u32).cast());
enc.dispatch_thread_groups(
MTLSize::new(1, 1, 1),
MTLSize::new(256, 1, 1),
);
enc.end_encoding();
}
{
let enc = cmdbuf.new_compute_command_encoder();
enc.set_compute_pipeline_state(&apply_pipe);
enc.set_buffer(0, Some(buf_x.raw()), 0);
enc.set_buffer(1, Some(buf_w.raw()), 0);
enc.set_buffer(2, Some(buf_sum_sq.raw()), 0);
enc.set_buffer(3, Some(buf_out.raw()), 0);
let dim = v.hidden_dim as u32;
let eps = RMS_NORM_EPS;
enc.set_bytes(4, 4, (&dim as *const u32).cast());
enc.set_bytes(5, 4, (&eps as *const f32).cast());
let num_tgs = (dim + 255) / 256;
enc.dispatch_thread_groups(
MTLSize::new(num_tgs as NSUInteger, 1, 1),
MTLSize::new(256, 1, 1),
);
enc.end_encoding();
}
cmdbuf.commit();
cmdbuf.wait_until_completed();
out.copy_from_slice(&buf_out.to_vec());
Ok(())
}
/// The two compute pipeline states that `encode_rms_norm_bf16_into` binds, one
/// per pass. Build an instance with `RmsNormBf16Pipelines::fetch`.
pub struct RmsNormBf16Pipelines {
/// Pipeline for the first pass; `fetch` looks it up as `"rms_norm_sum_sq"`.
pub sum: ComputePipelineState,
/// Pipeline for the second pass; `fetch` looks it up as `"rms_norm_apply_bf16"`.
pub apply: ComputePipelineState,
}
impl RmsNormBf16Pipelines {
    /// Looks up both RMS-norm pipelines on `metal` and returns cloned handles.
    ///
    /// The `"rms_norm_sum_sq"` pipeline is requested first, then
    /// `"rms_norm_apply_bf16"`.
    ///
    /// # Errors
    /// Returns the `MetalError` from whichever lookup fails first.
    pub fn fetch(metal: &mut MetalBackend) -> Result<Self, MetalError> {
        let sum = metal.pipeline("rms_norm_sum_sq")?.clone();
        let apply = metal.pipeline("rms_norm_apply_bf16")?.clone();
        Ok(Self { sum, apply })
    }
}
#[allow(clippy::too_many_arguments)]
pub fn encode_rms_norm_bf16_into(
cmdbuf: &CommandBufferRef,
pipes: &RmsNormBf16Pipelines,
input: &Buffer,
weight_buf: &Buffer,
weight_off: u64,
sum_sq: &Buffer,
out: &Buffer,
dim: u32,
eps: f32,
) {
{
let enc = cmdbuf.new_compute_command_encoder();
enc.set_compute_pipeline_state(&pipes.sum);
enc.set_buffer(0, Some(input), 0);
enc.set_buffer(1, Some(sum_sq), 0);
enc.set_bytes(2, 4, (&dim as *const u32).cast());
enc.dispatch_thread_groups(
MTLSize::new(1, 1, 1),
MTLSize::new(256, 1, 1),
);
enc.end_encoding();
}
{
let enc = cmdbuf.new_compute_command_encoder();
enc.set_compute_pipeline_state(&pipes.apply);
enc.set_buffer(0, Some(input), 0);
enc.set_buffer(1, Some(weight_buf), weight_off as NSUInteger);
enc.set_buffer(2, Some(sum_sq), 0);
enc.set_buffer(3, Some(out), 0);
enc.set_bytes(4, 4, (&dim as *const u32).cast());
enc.set_bytes(5, 4, (&eps as *const f32).cast());
let num_tgs = (dim + 255) / 256;
enc.dispatch_thread_groups(
MTLSize::new(num_tgs as NSUInteger, 1, 1),
MTLSize::new(256, 1, 1),
);
enc.end_encoding();
}
}