use std::ffi::c_void;
use std::ptr::NonNull;
use metal::{Buffer, Device, MTLResourceOptions, NSUInteger};
use super::weight_file::WeightFile;
/// Errors reported while locating tensors inside an [`MtlWeightBuf`].
#[derive(Debug, thiserror::Error)]
pub enum MtlWeightBufError {
/// The named tensor's byte range falls outside the wrapped mmap region
/// (e.g. a corrupt or mismatched manifest).
#[error("tensor '{name}' bytes outside mmap region")]
TensorOutOfBounds { name: String },
/// A tensor the caller requires is not listed in the weight manifest.
#[error("required tensor '{name}' missing from weight manifest")]
MissingTensor { name: String },
}
/// A Metal buffer that aliases a `WeightFile`'s memory mapping without copying it.
///
/// NOTE(review): this type stores a raw pointer into the weight file's mapping but does not
/// borrow the `WeightFile`; callers must keep the mapping alive (and unmapped-never) for as
/// long as this value or any GPU work using `buf` exists — confirm at call sites.
pub struct MtlWeightBuf {
/// Shared-storage Metal buffer created over the mapped bytes.
buf: Buffer,
/// Start address of the mapping; `buf` begins at this address.
base_ptr: NonNull<u8>,
/// Length of `buf` in bytes: the file size rounded up to a page multiple.
aligned_len: usize,
}
// SAFETY: `NonNull<u8>` is `!Send`, so the auto-impl does not apply. This type never
// dereferences `base_ptr`; it only hands it to Metal at construction and exposes it as a
// `*const u8`. NOTE(review): soundness additionally assumes `metal::Buffer` may be moved
// across threads and that the underlying mapping outlives this value — confirm.
unsafe impl Send for MtlWeightBuf {}
impl MtlWeightBuf {
    /// Page granularity used to round the Metal no-copy buffer length.
    ///
    /// `newBufferWithBytesNoCopy` requires a page-aligned pointer and a length that is a
    /// multiple of the VM page size. 16 KiB is the Apple-silicon page size and is also a
    /// multiple of 4 KiB.
    /// NOTE(review): on 4 KiB-page systems the rounded length may extend past the last page
    /// actually mapped for the file — confirm all targets use 16 KiB pages.
    const PAGE_SIZE: usize = 16384;

    /// Wraps the memory mapping behind `wf` in a shared-storage Metal buffer without copying.
    ///
    /// The mapping's base address is recovered from the first tensor by subtracting its
    /// recorded offset from the address of its bytes. The buffer length is
    /// `wf.file_size()` rounded up to [`Self::PAGE_SIZE`].
    ///
    /// NOTE(review): assumes `tensor_info(..).offset` is relative to the start of the
    /// mapping, and that the mapping outlives the returned value (no lifetime ties them
    /// together) — confirm.
    ///
    /// # Panics
    ///
    /// Panics if `wf` contains no tensors, or if the first tensor's bytes or info cannot be
    /// resolved.
    pub fn wrap(wf: &WeightFile, device: &Device) -> Self {
        let base_ptr = wf
            .iter()
            .next()
            .and_then(|(name, _)| {
                let bytes = wf.tensor_bytes(name)?;
                let info = wf.tensor_info(name)?;
                // `wrapping_sub` yields the same address as `sub` when the offset is valid,
                // but does not make an out-of-range offset undefined behaviour. The address
                // is never dereferenced here, only passed to Metal.
                let mmap_base = bytes.as_ptr().wrapping_sub(info.offset as usize);
                NonNull::new(mmap_base as *mut u8)
            })
            .expect("WeightFile must contain at least one tensor with resolvable bytes and info");

        let raw_len = wf.file_size();
        let aligned_len = (raw_len + Self::PAGE_SIZE - 1) & !(Self::PAGE_SIZE - 1);

        // No deallocator: the mapping is owned by the `WeightFile`, not by Metal.
        let buf = device.new_buffer_with_bytes_no_copy(
            base_ptr.as_ptr() as *const c_void,
            aligned_len as NSUInteger,
            MTLResourceOptions::StorageModeShared,
            None,
        );

        Self {
            buf,
            base_ptr,
            aligned_len,
        }
    }

    /// Returns the underlying Metal buffer.
    pub fn buffer(&self) -> &Buffer {
        &self.buf
    }

    /// Returns the byte offset of tensor `name` within [`Self::buffer`].
    ///
    /// Returns `Ok(None)` when `wf` has no tensor called `name`.
    ///
    /// # Errors
    ///
    /// Returns [`MtlWeightBufError::TensorOutOfBounds`] when the tensor's
    /// `[offset, offset + size)` range does not fit inside the mapped file, including when
    /// that range is not representable in `usize`.
    pub fn tensor_offset(
        &self,
        wf: &WeightFile,
        name: &str,
    ) -> Result<Option<u64>, MtlWeightBufError> {
        let Some(info) = wf.tensor_info(name) else {
            return Ok(None);
        };
        let off = info.offset;

        // Tensor data must lie inside the file itself. Bytes between the file end and
        // `aligned_len` are page padding, not tensor data.
        let limit = self.aligned_len.min(wf.file_size());

        // Checked arithmetic: a corrupt manifest with a huge offset or size must not wrap
        // around (release builds) or panic (debug builds) instead of being rejected.
        let end = usize::try_from(off)
            .ok()
            .zip(usize::try_from(info.size).ok())
            .and_then(|(start, len)| start.checked_add(len));

        match end {
            Some(end) if end <= limit => Ok(Some(off)),
            _ => Err(MtlWeightBufError::TensorOutOfBounds {
                name: name.to_string(),
            }),
        }
    }

    /// Returns the buffer length in bytes (file size rounded up to a page multiple).
    pub fn aligned_len(&self) -> usize {
        self.aligned_len
    }

    /// Returns the base address of the mapped region that backs the buffer.
    pub fn base_ptr(&self) -> *const u8 {
        self.base_ptr.as_ptr()
    }
}
impl std::fmt::Debug for MtlWeightBuf {
    /// Reports the buffer length in bytes and in decimal gigabytes; the Metal handle and the
    /// base pointer are intentionally left out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let byte_len = self.aligned_len;
        let gigabytes = byte_len as f64 / 1e9;

        let mut out = f.debug_struct("MtlWeightBuf");
        out.field("aligned_len", &byte_len);
        out.field("size_gb", &gigabytes);
        out.finish()
    }
}