use metal::{Buffer, Device, MTLResourceOptions};
use std::ffi::c_void;
use super::error::MetalGraphError;
pub(super) struct MetalBuffers {
pub(super) hidden_buf: Buffer,
pub(super) attn_out_buf: Buffer,
pub(super) norm_weight_buf: Buffer,
pub(super) proj_buf: Buffer,
pub(super) normed_buf: Buffer,
pub(super) swiglu_buf: Buffer,
pub(super) down_buf: Buffer,
pub(super) hidden_size: usize,
pub(super) intermediate_size: usize,
}
impl MetalBuffers {
pub(super) fn allocate(
device: &Device,
hidden_size: usize,
intermediate_size: usize,
) -> Result<Self, MetalGraphError> {
let h_bytes = (hidden_size * std::mem::size_of::<f32>()) as u64;
let inter_bytes = (intermediate_size * std::mem::size_of::<f32>()) as u64;
let shared = MTLResourceOptions::StorageModeShared;
let private = MTLResourceOptions::StorageModePrivate;
Ok(Self {
hidden_buf: alloc_buf(device, h_bytes, shared)?, attn_out_buf: alloc_buf(device, h_bytes, shared)?, norm_weight_buf: alloc_buf(device, h_bytes, shared)?, proj_buf: alloc_buf(device, h_bytes, private)?, normed_buf: alloc_buf(device, h_bytes, private)?, swiglu_buf: alloc_buf(device, inter_bytes, private)?,
down_buf: alloc_buf(device, h_bytes, private)?, hidden_size,
intermediate_size,
})
}
pub(super) fn matches(&self, hidden_size: usize, intermediate_size: usize) -> bool {
self.hidden_size == hidden_size && self.intermediate_size == intermediate_size
}
}
/// Creates a Metal buffer of `byte_len` bytes with resource options `opts` and
/// sanity-checks the result before handing it back.
///
/// The check depends on the storage mode:
/// - if `opts` includes `StorageModePrivate`, the buffer's reported `length()` must be
///   at least `byte_len`;
/// - otherwise, its CPU-side `contents()` pointer must be non-null.
///
/// # Errors
///
/// Returns [`MetalGraphError::BufferCreationFailed`] when `byte_len` is zero or when
/// the storage-mode check above fails.
pub(crate) fn alloc_buf(
    device: &Device,
    byte_len: u64,
    opts: MTLResourceOptions,
) -> Result<Buffer, MetalGraphError> {
    if byte_len == 0 {
        return Err(MetalGraphError::BufferCreationFailed);
    }

    let buffer = device.new_buffer(byte_len, opts);

    // Private-storage buffers are not host-mapped, so `contents()` cannot signal a
    // failed allocation for them; their length is checked instead.
    let gpu_only = opts.contains(MTLResourceOptions::StorageModePrivate);
    let usable = if gpu_only {
        buffer.length() >= byte_len
    } else {
        !buffer.contents().is_null()
    };

    if usable {
        Ok(buffer)
    } else {
        Err(MetalGraphError::BufferCreationFailed)
    }
}
/// Copies `data` into the beginning of `buf`'s CPU-visible contents.
///
/// Does nothing when `data` is empty.
///
/// # Safety
///
/// - `buf` must be CPU-accessible (not `StorageModePrivate`), so that `contents()`
///   returns a valid, non-null host pointer.
/// - `buf` must be at least `data.len() * size_of::<f32>()` bytes long.
/// - Nothing else (including in-flight GPU work) may access that region of `buf`
///   for the duration of the copy.
///
/// In debug builds the first two conditions are checked with `debug_assert!`.
pub(crate) unsafe fn upload_f32(buf: &Buffer, data: &[f32]) {
    // `copy_nonoverlapping` requires non-null pointers even for a zero count, so
    // skip the copy entirely rather than risk a null `contents()`.
    if data.is_empty() {
        return;
    }
    let dst = buf.contents() as *mut f32;
    debug_assert!(
        !dst.is_null(),
        "upload_f32: buffer has no CPU-visible contents"
    );
    debug_assert!(
        (data.len() * std::mem::size_of::<f32>()) as u64 <= buf.length(),
        "upload_f32: data is larger than the destination buffer"
    );
    // SAFETY: the caller guarantees `dst` is valid for `data.len()` f32 writes and that
    // the buffer's host memory does not overlap the Rust slice `data`.
    unsafe {
        std::ptr::copy_nonoverlapping(data.as_ptr(), dst, data.len());
    }
}
/// Copies the first `out.len()` `f32` values from `buf`'s CPU-visible contents into `out`.
///
/// Does nothing when `out` is empty.
///
/// # Safety
///
/// - `buf` must be CPU-accessible (not `StorageModePrivate`), so that `contents()`
///   returns a valid, non-null host pointer.
/// - `buf` must be at least `out.len() * size_of::<f32>()` bytes long.
/// - Nothing may write to that region of `buf` (including in-flight GPU work) for the
///   duration of the copy.
///
/// In debug builds the first two conditions are checked with `debug_assert!`.
pub(crate) unsafe fn download_f32(buf: &Buffer, out: &mut [f32]) {
    // `copy_nonoverlapping` requires non-null pointers even for a zero count.
    if out.is_empty() {
        return;
    }
    let src = buf.contents() as *const f32;
    debug_assert!(
        !src.is_null(),
        "download_f32: buffer has no CPU-visible contents"
    );
    debug_assert!(
        (out.len() * std::mem::size_of::<f32>()) as u64 <= buf.length(),
        "download_f32: output is larger than the source buffer"
    );
    // SAFETY: the caller guarantees `src` is valid for `out.len()` f32 reads and that
    // the buffer's host memory does not overlap the Rust slice `out`.
    unsafe {
        std::ptr::copy_nonoverlapping(src, out.as_mut_ptr(), out.len());
    }
}
pub(super) fn upload_bytes(device: &Device, data: &[u8]) -> Result<Buffer, MetalGraphError> {
if data.is_empty() {
return Err(MetalGraphError::BufferCreationFailed);
}
let opts = MTLResourceOptions::StorageModeShared;
let buf = device.new_buffer(data.len() as u64, opts);
unsafe {
std::ptr::copy_nonoverlapping(data.as_ptr(), buf.contents() as *mut u8, data.len());
}
Ok(buf)
}
/// Divides `n` by `divisor`, rounding the quotient up to the next whole number.
///
/// # Panics
///
/// Panics if `divisor` is zero.
#[inline]
pub(crate) fn div_ceil(n: usize, divisor: usize) -> usize {
    usize::div_ceil(n, divisor)
}
/// Binds the raw bytes of `*value` to argument-table slot `index` of `encoder`.
///
/// The value is passed inline through `set_bytes` (`size_of::<T>()` bytes starting at
/// `value`), so no separate buffer object is created for it here.
///
/// # Safety
///
/// NOTE(review): the body performs no raw-pointer dereference itself. The `unsafe`
/// contract presumably requires that `T`'s in-memory layout matches what the bound
/// shader expects at `index` (e.g. a plain scalar or a `#[repr(C)]` struct). Confirm
/// this with callers.
pub(crate) unsafe fn set_scalar<T: Copy>(
    encoder: &metal::ComputeCommandEncoderRef,
    index: u64,
    value: &T,
) {
    let byte_len = std::mem::size_of::<T>() as u64;
    let src = (value as *const T).cast::<c_void>();
    encoder.set_bytes(index, byte_len, src);
}