use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Text of `../shaders/silu_mul.metal`, embedded into the binary at compile
/// time by `include_str!`.
///
/// [`register`] hands this source to the kernel registry under the name
/// `"silu_mul_f32"`.
pub static SILU_MUL_SHADER_SOURCE: &str = include_str!("../shaders/silu_mul.metal");
/// Registers the SiLU-mul shader source with `registry`.
///
/// The source is stored under the name `"silu_mul_f32"`, which is the same
/// name `dispatch_silu_mul` passes to `get_pipeline`.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = "silu_mul_f32";
    registry.register_source(kernel_name, SILU_MUL_SHADER_SOURCE);
}
/// Validates the operands and encodes one `silu_mul_f32` dispatch on `encoder`.
///
/// This function only encodes work; it does not commit the encoder itself.
///
/// # Arguments
///
/// * `encoder` - command encoder that receives the dispatch.
/// * `registry` - source of the `"silu_mul_f32"` pipeline.
/// * `device` - Metal device passed to the pipeline lookup.
/// * `gate`, `up`, `output` - buffers that must each hold at least
///   `n * DType::F32.size_of()` bytes. They are passed to `encode` paired with
///   0, 1 and 2 (presumably the kernel's buffer binding indices; confirm
///   against `CommandEncoder::encode`).
/// * `params_buf` - buffer of at least 4 bytes, paired with 3. `silu_mul_gpu`
///   fills it with `n` as a single `u32`.
/// * `n` - element count; also the width of the 1-D dispatch grid.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when `n` is zero or when any buffer is
/// smaller than required. Errors from the pipeline lookup are propagated
/// unchanged.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_silu_mul(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    gate: &MlxBuffer,
    up: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
    n: u32,
) -> Result<()> {
    if n == 0 {
        return Err(MlxError::InvalidArgument("silu_mul: n must be > 0".into()));
    }

    // Every data operand has to hold `n` f32 values.
    let elem_bytes = (n as usize) * DType::F32.size_of();
    let operands = [("gate", gate), ("up", up), ("output", output)];
    // Report the first undersized operand, checked in declaration order.
    if let Some((name, buf)) = operands.iter().find(|(_, b)| b.byte_len() < elem_bytes) {
        return Err(MlxError::InvalidArgument(format!(
            "silu_mul: {name} buffer too small: need {elem_bytes} bytes, have {}",
            buf.byte_len()
        )));
    }

    if params_buf.byte_len() < 4 {
        return Err(MlxError::InvalidArgument(format!(
            "silu_mul: params_buf too small: need 4 bytes, have {}",
            params_buf.byte_len()
        )));
    }

    let pipeline = registry.get_pipeline("silu_mul_f32", device)?;

    // 1-D launch: `n` threads in total, at most 256 per threadgroup.
    let total_threads = u64::from(n);
    let threads_per_group = MTLSize::new(total_threads.min(256), 1, 1);
    let grid = MTLSize::new(total_threads, 1, 1);

    encoder.encode(
        pipeline,
        &[(0, gate), (1, up), (2, output), (3, params_buf)],
        grid,
        threads_per_group,
    );
    Ok(())
}
/// Runs the `silu_mul_f32` kernel over `gate` and `up` and returns a freshly
/// allocated output buffer.
///
/// Steps:
/// 1. allocate an F32 output buffer of shape `[n]`;
/// 2. allocate a 4-byte U32 parameter buffer and store `n` in it;
/// 3. encode the dispatch through [`dispatch_silu_mul`], which also validates
///    `n` and the input buffer sizes;
/// 4. call `commit_and_wait` on the encoder (by its name, this blocks until the
///    GPU work has finished; confirm against `CommandEncoder`).
///
/// # Errors
///
/// Allocation, parameter-write, encoder-creation and commit failures are each
/// wrapped in `MlxError::InvalidArgument` with a `silu_mul_gpu: ...` prefix.
/// Errors from [`dispatch_silu_mul`] (for example `n == 0` or an undersized
/// `gate`/`up` buffer) are propagated unchanged.
pub fn silu_mul_gpu(
    registry: &mut KernelRegistry,
    device: &MlxDevice,
    gate: &MlxBuffer,
    up: &MlxBuffer,
    n: u32,
) -> Result<MlxBuffer> {
    let n_usize = n as usize;
    let out_bytes = n_usize * DType::F32.size_of();
    let output = device
        .alloc_buffer(out_bytes, DType::F32, vec![n_usize])
        .map_err(|e| MlxError::InvalidArgument(format!("silu_mul_gpu: alloc output: {e}")))?;

    // The parameter buffer holds a single u32: the element count.
    let mut params_buf = device
        .alloc_buffer(4, DType::U32, vec![1])
        .map_err(|e| MlxError::InvalidArgument(format!("silu_mul_gpu: alloc params: {e}")))?;
    params_buf
        .as_mut_slice::<u32>()
        .map_err(|e| MlxError::InvalidArgument(format!("silu_mul_gpu: write params: {e}")))?[0] = n;

    let mut enc = device
        .command_encoder()
        .map_err(|e| MlxError::InvalidArgument(format!("silu_mul_gpu: command_encoder: {e}")))?;

    // Fix: the original passed the mis-encoded token `¶ms_buf` here, which is
    // not valid Rust; the intended argument is a shared borrow of `params_buf`.
    dispatch_silu_mul(
        &mut enc,
        registry,
        device.metal_device(),
        gate,
        up,
        &output,
        &params_buf,
        n,
    )?;

    enc.commit_and_wait()
        .map_err(|e| MlxError::InvalidArgument(format!("silu_mul_gpu: commit: {e}")))?;
    Ok(output)
}