use metal::{
Buffer, CommandBufferRef, ComputePipelineState, Device, MTLResourceOptions,
MTLSize, NSUInteger,
};
use super::gpu_matvec::{encode_matvec, MatvecPipelines, MatvecSpec};
use super::metal::{MetalContext, MetalError};
use crate::riir::io::mtl_weight_buf::{MtlWeightBuf, MtlWeightBufError};
use crate::riir::variants::VARIANT;
use crate::riir::io::weight_file::WeightFile;
/// Shared-storage (CPU-visible) scratch buffers for one dense MLP layer.
///
/// As allocated by [`DenseMlpBuffers::new`], sizes in `f32` elements are:
/// `hidden_in` / `out` hold `hidden_dim`; `gate_out` / `up_out` / `act` hold
/// `dense_intermediate` (each clamped to at least one element).
pub struct DenseMlpBuffers {
    /// Layer input; `dense_mlp_layer_forward_gpu` uploads `hidden` here.
    pub hidden_in: Buffer,
    /// Result of the gate projection.
    pub gate_out: Buffer,
    /// Result of the up projection.
    pub up_out: Buffer,
    /// Output of the SwiGLU kernel over `gate_out` / `up_out`; input to the
    /// down projection.
    pub act: Buffer,
    /// Down-projection result; `dense_mlp_layer_forward_gpu` reads it back.
    pub out: Buffer,
}

impl DenseMlpBuffers {
    /// Allocate zero-filled shared-storage buffers sized for [`VARIANT`].
    ///
    /// Every buffer holds at least one `f32`. Metal refuses zero-length
    /// buffers, and a variant with `dense_intermediate == 0` would otherwise
    /// request them. Such a variant cannot run the dense MLP anyway:
    /// `encode_dense_mlp_layer_forward_gpu` rejects it with `NoDenseMlp`.
    pub fn new(device: &Device) -> Self {
        let v = VARIANT;
        let zeroed_f32_buffer = |len: usize| {
            // Clamp to one element so Metal is never asked for a 0-byte buffer.
            let bytes = len.max(1) * std::mem::size_of::<f32>();
            let buf = device.new_buffer(
                bytes as NSUInteger,
                MTLResourceOptions::StorageModeShared,
            );
            // SAFETY: `buf` was just allocated with shared storage and exactly
            // `bytes` bytes, so its `contents()` pointer is valid for `bytes`
            // writes. No GPU work can reference it yet.
            unsafe {
                std::ptr::write_bytes(buf.contents().cast::<u8>(), 0, bytes);
            }
            buf
        };
        Self {
            hidden_in: zeroed_f32_buffer(v.hidden_dim),
            gate_out: zeroed_f32_buffer(v.dense_intermediate),
            up_out: zeroed_f32_buffer(v.dense_intermediate),
            act: zeroed_f32_buffer(v.dense_intermediate),
            out: zeroed_f32_buffer(v.hidden_dim),
        }
    }
}
/// Run one dense SwiGLU MLP layer on the GPU and block until it finishes.
///
/// The steps are:
/// 1. Copy `hidden` into `bufs.hidden_in`.
/// 2. Encode layer `layer_idx` into a fresh command buffer on `metal`'s queue,
///    via [`encode_dense_mlp_layer_forward_gpu`].
/// 3. Commit it and wait for completion.
/// 4. Copy `bufs.out` into `out`.
///
/// # Errors
///
/// - [`DenseMlpGpuError::HiddenLen`] / [`DenseMlpGpuError::OutLen`] if either
///   slice is not exactly `hidden_dim` long. These are checked before any copy.
/// - Any error from [`encode_dense_mlp_layer_forward_gpu`]. In that case the
///   command buffer is never committed and `out` is left untouched.
///
/// # Panics
///
/// Panics if `bufs.hidden_in` or `bufs.out` is shorter than `hidden_dim` `f32`s.
/// The fields are public, so a caller could substitute an undersized buffer.
#[allow(clippy::too_many_arguments)] // one explicit argument per GPU resource, like the encode_* fns
pub fn dense_mlp_layer_forward_gpu(
    metal: &mut MetalContext,
    pipes: &DenseMlpPipelines,
    bufs: &mut DenseMlpBuffers,
    wf: &WeightFile,
    wf_buf: &MtlWeightBuf,
    layer_idx: usize,
    hidden: &[f32],
    out: &mut [f32],
) -> Result<(), DenseMlpGpuError> {
    let v = VARIANT;
    if hidden.len() != v.hidden_dim {
        return Err(DenseMlpGpuError::HiddenLen {
            got: hidden.len(),
            expected: v.hidden_dim,
        });
    }
    if out.len() != v.hidden_dim {
        return Err(DenseMlpGpuError::OutLen {
            got: out.len(),
            expected: v.hidden_dim,
        });
    }

    // The raw copies below need the staging buffers to hold `hidden_dim` f32s.
    // Check this once instead of risking an out-of-bounds access.
    let bytes = v.hidden_dim * std::mem::size_of::<f32>();
    assert!(
        bufs.hidden_in.length() as usize >= bytes && bufs.out.length() as usize >= bytes,
        "DenseMlpBuffers: hidden_in/out must hold at least hidden_dim ({}) f32s",
        v.hidden_dim
    );

    // SAFETY: `hidden_in.contents()` is the CPU-visible pointer of a buffer
    // holding at least `hidden_dim` f32s (asserted above). This assumes shared
    // storage, as allocated by `DenseMlpBuffers::new`. `hidden.len()` equals
    // `hidden_dim` (checked above). The source is a Rust slice and the
    // destination is Metal-owned memory, so they cannot overlap.
    unsafe {
        std::ptr::copy_nonoverlapping(
            hidden.as_ptr(),
            bufs.hidden_in.contents().cast::<f32>(),
            v.hidden_dim,
        );
    }

    let cmdbuf = metal.queue().new_command_buffer();
    encode_dense_mlp_layer_forward_gpu(
        cmdbuf,
        pipes,
        wf,
        wf_buf,
        layer_idx,
        &bufs.hidden_in,
        &bufs.gate_out,
        &bufs.up_out,
        &bufs.act,
        &bufs.out,
    )?;
    cmdbuf.commit();
    // Block so that `bufs.out` holds this layer's result before the read-back.
    // NOTE(review): the command buffer's completion status is not inspected,
    // so a GPU-side execution failure is not reported here.
    cmdbuf.wait_until_completed();

    // SAFETY: the GPU work writing `bufs.out` has completed (waited above).
    // `bufs.out` holds at least `hidden_dim` f32s (asserted above), and
    // `out.len()` equals `hidden_dim` (checked above). The regions are distinct
    // allocations.
    unsafe {
        std::ptr::copy_nonoverlapping(
            bufs.out.contents().cast::<f32>(),
            out.as_mut_ptr(),
            v.hidden_dim,
        );
    }
    Ok(())
}
/// Failures from encoding or running the dense / SwiGLU MLP on the GPU.
#[derive(Debug, thiserror::Error)]
pub enum DenseMlpGpuError {
    /// The `hidden` input slice does not have `hidden_dim` elements.
    #[error("hidden length {got} != hidden_dim ({expected})")]
    HiddenLen {
        /// Length actually supplied.
        got: usize,
        /// Required length (`hidden_dim`).
        expected: usize,
    },
    /// The `out` slice does not have `hidden_dim` elements.
    #[error("out length {got} != hidden_dim ({expected})")]
    OutLen {
        /// Length actually supplied.
        got: usize,
        /// Required length (`hidden_dim`).
        expected: usize,
    },
    /// `MtlWeightBuf::tensor_offset` found no tensor with this name.
    #[error("missing tensor '{name}'")]
    MissingTensor {
        /// Full tensor name that was looked up.
        name: String,
    },
    /// Wrapped [`MetalError`] (convertible via `From`).
    #[error("Metal: {0}")]
    Metal(#[from] MetalError),
    /// Resolving a tensor's offset in the weight buffer failed.
    #[error("weight-buffer offset: {0}")]
    Offset(#[from] MtlWeightBufError),
    /// A dense MLP layer was requested, but this variant's
    /// `dense_intermediate` is zero.
    #[error(
        "variant has dense_intermediate={got}; this build's variant has \
         no dense MLP (first_k_dense_replace=0). Don't dispatch dense \
         layers."
    )]
    NoDenseMlp {
        /// The offending `dense_intermediate` value.
        got: usize,
    },
}
/// Compute pipelines required to encode a dense (SwiGLU) MLP layer.
pub struct DenseMlpPipelines {
    /// Quantized matrix-vector pipelines, used for the gate, up and down
    /// projections.
    pub matvec: MatvecPipelines,
    /// Pipeline that combines the gate and up projections into the activation.
    pub swiglu: ComputePipelineState,
}

impl DenseMlpPipelines {
    /// Fetch the matvec pipelines and the `swiglu_fused` pipeline from `metal`.
    ///
    /// # Errors
    ///
    /// Propagates any [`MetalError`] raised while fetching either pipeline.
    pub fn fetch(metal: &mut MetalContext) -> Result<Self, MetalError> {
        let matvec = MatvecPipelines::fetch(metal)?;
        let swiglu = metal.pipeline("swiglu_fused")?.clone();
        Ok(Self { matvec, swiglu })
    }
}
/// Encode the dense MLP of layer `layer_idx` into `cmdbuf` without committing it.
///
/// Uses the tensors under `model.layers.{layer_idx}.mlp` and [`VARIANT`]'s
/// `dense_intermediate` width. The projection and activation work is delegated
/// to [`encode_swiglu_ffn_layer_forward_gpu`], which documents each buffer's role.
///
/// # Errors
///
/// - [`DenseMlpGpuError::NoDenseMlp`] if this variant has
///   `dense_intermediate == 0`. This is checked before anything is encoded.
/// - Any error from [`encode_swiglu_ffn_layer_forward_gpu`], such as a missing
///   tensor or a weight-buffer offset failure.
#[allow(clippy::too_many_arguments)] // one explicit argument per GPU resource
pub fn encode_dense_mlp_layer_forward_gpu(
    cmdbuf: &CommandBufferRef,
    pipes: &DenseMlpPipelines,
    wf: &WeightFile,
    wf_buf: &MtlWeightBuf,
    layer_idx: usize,
    hidden: &Buffer,
    gate_out: &Buffer,
    up_out: &Buffer,
    act: &Buffer,
    out: &Buffer,
) -> Result<(), DenseMlpGpuError> {
    let v = VARIANT;
    if v.dense_intermediate == 0 {
        return Err(DenseMlpGpuError::NoDenseMlp {
            got: v.dense_intermediate,
        });
    }
    let prefix = format!("model.layers.{layer_idx}.mlp");
    encode_swiglu_ffn_layer_forward_gpu(
        cmdbuf,
        pipes,
        wf,
        wf_buf,
        &prefix,
        // NOTE(review): assumes dense_intermediate fits in u32 (`as` would truncate).
        v.dense_intermediate as u32,
        hidden,
        gate_out,
        up_out,
        act,
        out,
    )
}
/// Encode a 4-bit quantized SwiGLU feed-forward block into `cmdbuf`.
///
/// The block reads `{tensor_prefix}.{gate,up,down}_proj.{weight,scales,biases}`
/// from `wf_buf` and records, in order:
/// 1. the gate projection `hidden -> gate_out` (`hidden_dim -> intermediate`);
/// 2. the up projection `hidden -> up_out` (`hidden_dim -> intermediate`);
/// 3. the `pipes.swiglu` kernel over `gate_out` / `up_out`, writing `act`;
/// 4. the down projection `act -> out` (`intermediate -> hidden_dim`).
///
/// The command buffer is not committed; that is left to the caller.
///
/// # Errors
///
/// All nine tensor offsets are resolved before anything is encoded, so on
/// error nothing has been recorded into `cmdbuf`.
/// - [`DenseMlpGpuError::MissingTensor`] if a tensor is absent.
/// - [`DenseMlpGpuError::Offset`] if the offset lookup itself fails.
#[allow(clippy::too_many_arguments)]
pub fn encode_swiglu_ffn_layer_forward_gpu(
    cmdbuf: &CommandBufferRef,
    pipes: &DenseMlpPipelines,
    wf: &WeightFile,
    wf_buf: &MtlWeightBuf,
    tensor_prefix: &str,
    intermediate: u32,
    hidden: &Buffer,
    gate_out: &Buffer,
    up_out: &Buffer,
    act: &Buffer,
    out: &Buffer,
) -> Result<(), DenseMlpGpuError> {
    let hidden_dim = VARIANT.hidden_dim as u32;

    // Offset of one named tensor; an absent tensor becomes `MissingTensor`.
    let offset_of = |name: String| -> Result<u64, DenseMlpGpuError> {
        let found = wf_buf.tensor_offset(wf, &name)?;
        found.ok_or(DenseMlpGpuError::MissingTensor { name })
    };

    // (weight, scales, biases) offsets of the `{tensor_prefix}.{proj}` projection,
    // looked up in that order.
    let projection = |proj: &str| -> Result<(u64, u64, u64), DenseMlpGpuError> {
        let base = format!("{tensor_prefix}.{proj}");
        Ok((
            offset_of(format!("{base}.weight"))?,
            offset_of(format!("{base}.scales"))?,
            offset_of(format!("{base}.biases"))?,
        ))
    };

    let gate = projection("gate_proj")?;
    let up = projection("up_proj")?;
    let down = projection("down_proj")?;

    // Every projection is a 4-bit quantized matvec against `wf_buf`.
    let matvec = |(w_off, s_off, b_off): (u64, u64, u64),
                  input: &Buffer,
                  output: &Buffer,
                  out_dim: u32,
                  in_dim: u32| {
        encode_matvec(
            cmdbuf,
            &pipes.matvec,
            wf_buf,
            &MatvecSpec {
                w_off,
                s_off,
                b_off,
                input,
                output,
                out_dim,
                in_dim,
                bits: 4,
            },
        );
    };

    matvec(gate, hidden, gate_out, intermediate, hidden_dim);
    matvec(up, hidden, up_out, intermediate, hidden_dim);
    encode_swiglu(cmdbuf, &pipes.swiglu, gate_out, up_out, act, intermediate);
    matvec(down, act, out, hidden_dim, intermediate);
    Ok(())
}
/// Encode one compute pass of `pipeline` over `dim` elements.
///
/// The kernel is bound as: buffer 0 = `gate`, buffer 1 = `up`,
/// buffer 2 = `act`, and `dim` as a `u32` constant at index 3. The grid is
/// rounded up to whole threadgroups of `THREADS_PER_GROUP`. The kernel
/// presumably bounds-checks against `dim`; this cannot be verified from here.
/// When `dim == 0` nothing is encoded.
fn encode_swiglu(
    cmdbuf: &CommandBufferRef,
    pipeline: &ComputePipelineState,
    gate: &Buffer,
    up: &Buffer,
    act: &Buffer,
    dim: u32,
) {
    const THREADS_PER_GROUP: u32 = 256;
    if dim == 0 {
        // Nothing to compute; don't encode an empty dispatch.
        return;
    }
    let enc = cmdbuf.new_compute_command_encoder();
    enc.set_compute_pipeline_state(pipeline);
    enc.set_buffer(0, Some(gate), 0);
    enc.set_buffer(1, Some(up), 0);
    enc.set_buffer(2, Some(act), 0);
    enc.set_bytes(
        3,
        std::mem::size_of::<u32>() as NSUInteger,
        (&dim as *const u32).cast(),
    );
    // Ceil division written so it cannot overflow. `(dim + 255) / 256` would
    // wrap for `dim` near u32::MAX. `dim > 0` is guaranteed above.
    let num_tgs = (dim - 1) / THREADS_PER_GROUP + 1;
    enc.dispatch_thread_groups(
        MTLSize::new(num_tgs as NSUInteger, 1, 1),
        MTLSize::new(THREADS_PER_GROUP as NSUInteger, 1, 1),
    );
    enc.end_encoding();
}