use metal::{Buffer, MTLResourceOptions, NSUInteger};
use crate::riir::backend::gpu::gpu_matvec::{MatvecPipelines, MatvecSpec, encode_matvec};
use crate::riir::backend::gpu::metal::{MetalContext, MetalError};
use crate::riir::io::mtl_weight_buf::{MtlWeightBuf, MtlWeightBufError};
use crate::riir::variants::VARIANT;
use crate::riir::io::weight_file::WeightFile;
/// Failures that can occur while building or running the GPU LM head.
#[derive(Debug, thiserror::Error)]
pub enum GpuLmHeadError {
    /// An error from the Metal backend layer, converted from [`MetalError`].
    #[error("Metal: {0}")]
    Metal(#[from] MetalError),

    /// An error from the weight-buffer layer, converted from
    /// [`MtlWeightBufError`].
    #[error("weight buffer: {0}")]
    WeightBuf(#[from] MtlWeightBufError),

    /// A tensor offset lookup succeeded but returned no entry for the named
    /// `lm_head.*` tensor. The payload is the tensor name that was requested.
    #[error("missing lm_head tensor: {0}")]
    MissingTensor(&'static str),

    /// The input slice length (`got`) differs from the variant's hidden
    /// dimension (`expected`).
    #[error("input length {got} != HIDDEN_DIM {expected}")]
    InputLen { got: usize, expected: usize },

    /// The output slice length (`got`) differs from the variant's vocabulary
    /// size (`expected`).
    #[error("output length {got} != VOCAB_SIZE {expected}")]
    OutputLen { got: usize, expected: usize },
}
/// GPU-side state for the LM-head matrix-vector product that turns a hidden
/// vector into vocabulary logits.
///
/// The struct owns two staging buffers and remembers where the three
/// `lm_head.*` tensors are located, as reported by
/// `MtlWeightBuf::tensor_offset` at construction time.
pub struct GpuLmHead {
    /// Matvec pipelines obtained from `MatvecPipelines::fetch`.
    pipelines: MatvecPipelines,
    /// Shared-storage buffer holding `hidden_dim` `f32` values; the input
    /// vector is copied here before each dispatch.
    input_buf: Buffer,
    /// Shared-storage buffer holding `vocab_size` `f32` values; the matvec
    /// writes its result here.
    logits_buf: Buffer,
    /// Offset reported for `"lm_head.weight"`.
    /// NOTE(review): presumably a byte offset into the weight buffer — confirm.
    w_off: u64,
    /// Offset reported for `"lm_head.scales"`.
    s_off: u64,
    /// Offset reported for `"lm_head.biases"`.
    b_off: u64,
}
// Asserts that a `GpuLmHead` may be moved to another thread. Only `Send` is
// asserted here, not `Sync`.
// SAFETY: NOTE(review) — soundness rests on `MatvecPipelines` and the two
// `metal::Buffer` handles being safe to use from whichever thread currently
// owns the struct; presumably this impl exists because one of those field
// types is not automatically `Send`. TODO confirm against those types.
unsafe impl Send for GpuLmHead {}
impl GpuLmHead {
    /// Creates the LM-head state.
    ///
    /// Fetches the matvec pipelines, allocates one shared-storage buffer for
    /// the input vector (`VARIANT.hidden_dim` floats) and one for the logits
    /// (`VARIANT.vocab_size` floats), then looks up the offsets of
    /// `lm_head.weight`, `lm_head.scales` and `lm_head.biases`.
    ///
    /// # Errors
    ///
    /// * Whatever `MatvecPipelines::fetch` fails with, converted into
    ///   [`GpuLmHeadError`].
    /// * [`GpuLmHeadError::WeightBuf`] if an offset lookup itself fails.
    /// * [`GpuLmHeadError::MissingTensor`] if a lookup succeeds but reports
    ///   that the tensor is absent.
    pub fn new(
        metal: &mut MetalContext,
        wf: &WeightFile,
        wf_buf: &MtlWeightBuf,
    ) -> Result<Self, GpuLmHeadError> {
        let pipelines = MatvecPipelines::fetch(metal)?;
        let device = metal.device();

        // Allocates a CPU/GPU shared buffer large enough for `count` f32s.
        let shared_f32_buffer = |count: usize| {
            let byte_len = count * std::mem::size_of::<f32>();
            device.new_buffer(byte_len as NSUInteger, MTLResourceOptions::StorageModeShared)
        };
        let input_buf = shared_f32_buffer(VARIANT.hidden_dim);
        let logits_buf = shared_f32_buffer(VARIANT.vocab_size);

        // Resolves a tensor offset, treating an absent tensor as an error.
        let required_offset = |name: &'static str| -> Result<u64, GpuLmHeadError> {
            wf_buf
                .tensor_offset(wf, name)?
                .ok_or(GpuLmHeadError::MissingTensor(name))
        };
        let w_off = required_offset("lm_head.weight")?;
        let s_off = required_offset("lm_head.scales")?;
        let b_off = required_offset("lm_head.biases")?;

        Ok(Self {
            pipelines,
            input_buf,
            logits_buf,
            w_off,
            s_off,
            b_off,
        })
    }

    /// Runs the LM-head matvec on the GPU and blocks until it has finished.
    ///
    /// `hidden` is copied into the input staging buffer, one matvec is encoded
    /// and committed on a fresh command buffer from `metal.queue()`, and after
    /// waiting for completion the logits staging buffer is copied into
    /// `logits`.
    ///
    /// NOTE(review): the command buffer's status is not inspected after the
    /// wait, so a failed GPU execution would not surface as an error here.
    ///
    /// # Errors
    ///
    /// * [`GpuLmHeadError::InputLen`] if `hidden.len() != VARIANT.hidden_dim`.
    /// * [`GpuLmHeadError::OutputLen`] if `logits.len() != VARIANT.vocab_size`.
    pub fn forward(
        &self,
        metal: &MetalContext,
        wf_buf: &MtlWeightBuf,
        hidden: &[f32],
        logits: &mut [f32],
    ) -> Result<(), GpuLmHeadError> {
        let (in_dim, out_dim) = (VARIANT.hidden_dim, VARIANT.vocab_size);

        // Validate both slices up front; the raw copies below rely on it.
        let got_in = hidden.len();
        if got_in != in_dim {
            return Err(GpuLmHeadError::InputLen {
                got: got_in,
                expected: in_dim,
            });
        }
        let got_out = logits.len();
        if got_out != out_dim {
            return Err(GpuLmHeadError::OutputLen {
                got: got_out,
                expected: out_dim,
            });
        }

        // Upload the hidden vector into the staging buffer.
        let staging_in = self.input_buf.contents().cast::<f32>();
        // SAFETY: `hidden` holds exactly `in_dim` floats (checked above) and
        // `input_buf` was allocated in `new` with room for `hidden_dim` f32s.
        // The caller's slice and the Metal buffer are distinct allocations.
        // Assumes `contents()` is non-null and f32-aligned — TODO confirm.
        unsafe {
            std::ptr::copy_nonoverlapping(hidden.as_ptr(), staging_in, in_dim);
        }

        let command_buffer = metal.queue().new_command_buffer();
        // NOTE(review): the `as u32` casts would truncate silently if a
        // dimension ever exceeded `u32::MAX`.
        let lm_head_spec = MatvecSpec {
            w_off: self.w_off,
            s_off: self.s_off,
            b_off: self.b_off,
            input: &self.input_buf,
            output: &self.logits_buf,
            out_dim: out_dim as u32,
            in_dim: in_dim as u32,
            bits: 4,
        };
        encode_matvec(command_buffer, &self.pipelines, wf_buf, &lm_head_spec);
        command_buffer.commit();
        // Block so the logits buffer is fully written before it is read back.
        command_buffer.wait_until_completed();

        // Download the logits from the staging buffer.
        let staging_out = self.logits_buf.contents().cast::<f32>();
        // SAFETY: `logits` holds exactly `out_dim` floats (checked above) and
        // `logits_buf` was allocated in `new` with room for `vocab_size` f32s.
        // The Metal buffer and the caller's slice are distinct allocations.
        // Assumes `contents()` is non-null and f32-aligned — TODO confirm.
        unsafe {
            std::ptr::copy_nonoverlapping(staging_out, logits.as_mut_ptr(), out_dim);
        }

        Ok(())
    }
}
impl std::fmt::Debug for GpuLmHead {
    /// Prints the three tensor offsets only; the pipeline and buffer fields
    /// are left out of the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            w_off,
            s_off,
            b_off,
            ..
        } = self;

        let mut repr = f.debug_struct("GpuLmHead");
        repr.field("w_off", w_off);
        repr.field("s_off", s_off);
        repr.field("b_off", b_off);
        repr.finish()
    }
}