use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal shader source text, embedded at compile time from
/// `../shaders/silu_backward.metal`.
///
/// [`register`] registers this same text under both the `silu_f32` and the
/// `silu_backward_f32` kernel names.
pub static SILU_BACKWARD_SHADER_SOURCE: &str =
include_str!("../shaders/silu_backward.metal");
/// Registers the SiLU shader source with `registry`.
///
/// The same source text ([`SILU_BACKWARD_SHADER_SOURCE`]) is registered once
/// per kernel name, first for `silu_f32` and then for `silu_backward_f32`.
pub fn register(registry: &mut KernelRegistry) {
    for kernel_name in ["silu_f32", "silu_backward_f32"] {
        registry.register_source(kernel_name, SILU_BACKWARD_SHADER_SOURCE);
    }
}
/// Validates the buffers for the forward SiLU kernel and encodes its dispatch.
///
/// The pipeline registered as `"silu_f32"` is fetched from `registry` and
/// encoded with `input`, `output` and `params_buf` bound at indices 0, 1 and 2.
/// The grid is one-dimensional with one thread per input element; the
/// threadgroup width is the smaller of 256 and the element count.
///
/// `params_buf` must be at least 4 bytes long (room for one `u32`). Its
/// contents are not inspected here.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when, checked in this order:
/// - `input` has no elements,
/// - `output` has a different element count than `input`,
/// - `input` or `output` is not `DType::F32`,
/// - `params_buf` is shorter than 4 bytes.
///
/// Any error from `KernelRegistry::get_pipeline` is propagated unchanged.
pub fn dispatch_silu_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    input: &MlxBuffer,
    output: &MlxBuffer,
    params_buf: &MlxBuffer,
) -> Result<()> {
    let n = input.element_count();
    if n == 0 {
        let msg = "silu_f32: input must have at least one element";
        return Err(MlxError::InvalidArgument(msg.into()));
    }

    if output.element_count() != n {
        let msg = format!(
            "silu_f32: output element count {} != input element count {n}",
            output.element_count()
        );
        return Err(MlxError::InvalidArgument(msg));
    }

    // Report the first buffer (input before output) that is not f32.
    let typed = [("input", input), ("output", output)];
    if let Some(&(label, buf)) = typed.iter().find(|(_, b)| b.dtype() != DType::F32) {
        let msg = format!(
            "silu_f32: {label} dtype {} not f32",
            buf.dtype()
        );
        return Err(MlxError::InvalidArgument(msg));
    }

    let params_len = params_buf.byte_len();
    if params_len < 4 {
        let msg = format!(
            "silu_f32: params_buf too small (need 4 bytes for u32, got {})",
            params_len
        );
        return Err(MlxError::InvalidArgument(msg));
    }

    let pipeline = registry.get_pipeline("silu_f32", device)?;

    // One thread per element; threadgroup width is capped at 256.
    let grid_width = n as u64;
    let group_width = grid_width.min(256);
    let grid = MTLSize::new(grid_width, 1, 1);
    let threadgroup = MTLSize::new(group_width, 1, 1);

    encoder.encode(
        pipeline,
        &[(0, input), (1, output), (2, params_buf)],
        grid,
        threadgroup,
    );
    Ok(())
}
/// Validates the buffers for the backward SiLU kernel and encodes its dispatch.
///
/// The pipeline registered as `"silu_backward_f32"` is fetched from `registry`
/// and encoded with `x`, `dy`, `dx` and `params_buf` bound at indices 0, 1, 2
/// and 3. The grid is one-dimensional with one thread per element of `x`; the
/// threadgroup width is the smaller of 256 and the element count.
///
/// `params_buf` must be at least 4 bytes long (room for one `u32`). Its
/// contents are not inspected here.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] when:
/// - `x` has no elements,
/// - any of `x`, `dy`, `dx` (visited in that order, element count checked
///   before dtype for each buffer) has an element count different from `x`
///   or is not `DType::F32`,
/// - `params_buf` is shorter than 4 bytes.
///
/// Any error from `KernelRegistry::get_pipeline` is propagated unchanged.
#[allow(clippy::too_many_arguments)]
pub fn dispatch_silu_backward_f32(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    x: &MlxBuffer,
    dy: &MlxBuffer,
    dx: &MlxBuffer,
    params_buf: &MlxBuffer,
) -> Result<()> {
    let n = x.element_count();
    if n == 0 {
        let msg = "silu_backward_f32: x must have at least one element";
        return Err(MlxError::InvalidArgument(msg.into()));
    }

    // Per buffer: element count first, then dtype; stop at the first failure.
    [("x", x), ("dy", dy), ("dx", dx)]
        .iter()
        .try_for_each(|&(label, buf)| -> Result<()> {
            if buf.element_count() != n {
                let msg = format!(
                    "silu_backward_f32: {label} element count {} != x element count {n}",
                    buf.element_count(),
                );
                return Err(MlxError::InvalidArgument(msg));
            }
            if buf.dtype() != DType::F32 {
                let msg = format!(
                    "silu_backward_f32: {label} dtype {} not f32",
                    buf.dtype()
                );
                return Err(MlxError::InvalidArgument(msg));
            }
            Ok(())
        })?;

    let params_len = params_buf.byte_len();
    if params_len < 4 {
        let msg = format!(
            "silu_backward_f32: params_buf too small (need 4 bytes for u32, got {})",
            params_len
        );
        return Err(MlxError::InvalidArgument(msg));
    }

    let pipeline = registry.get_pipeline("silu_backward_f32", device)?;

    // One thread per element; threadgroup width is capped at 256.
    let grid_width = n as u64;
    let group_width = grid_width.min(256);
    let grid = MTLSize::new(grid_width, 1, 1);
    let threadgroup = MTLSize::new(group_width, 1, 1);

    encoder.encode(
        pipeline,
        &[(0, x), (1, dy), (2, dx), (3, params_buf)],
        grid,
        threadgroup,
    );
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// CPU reference for the forward pass: `x / (1 + exp(-x))` per element.
    fn silu_cpu(x: &[f32]) -> Vec<f32> {
        x.iter().map(|&xv| xv / (1.0 + (-xv).exp())).collect()
    }

    /// CPU reference for the backward pass.
    ///
    /// With `s = 1 / (1 + exp(-x))`, each output is
    /// `dy * s * (1 + x * (1 - s))`.
    fn silu_backward_cpu(x: &[f32], dy: &[f32]) -> Vec<f32> {
        x.iter()
            .zip(dy.iter())
            .map(|(&xv, &dyv)| {
                let s = 1.0 / (1.0 + (-xv).exp());
                let deriv = s * (1.0 + xv * (1.0 - s));
                dyv * deriv
            })
            .collect()
    }

    /// Runs `dispatch_silu_f32` on `input` and returns the output buffer
    /// contents after the encoder has been committed and waited on.
    fn run_silu_forward(input: &[f32]) -> Vec<f32> {
        let device = MlxDevice::new().expect("device");
        let n = input.len();

        let mut in_buf = device
            .alloc_buffer(n * 4, DType::F32, vec![n])
            .expect("alloc in");
        in_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(input);

        let out_buf = device
            .alloc_buffer(n * 4, DType::F32, vec![n])
            .expect("alloc out");

        // The 4-byte params buffer is allocated with an F32 dtype tag but is
        // written as a single u32 holding the element count.
        let mut params = device.alloc_buffer(4, DType::F32, vec![1]).expect("params");
        params.as_mut_slice::<u32>().unwrap()[0] = n as u32;

        let mut registry = KernelRegistry::new();
        register(&mut registry);

        let mut encoder = device.command_encoder().expect("encoder");
        dispatch_silu_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &in_buf,
            &out_buf,
            &params,
        )
        .expect("dispatch silu");
        encoder.commit_and_wait().expect("commit");

        out_buf.as_slice::<f32>().unwrap().to_vec()
    }

    /// Runs `dispatch_silu_backward_f32` on `input` / `dy` and returns the
    /// `dx` buffer contents after the encoder has been committed and waited on.
    fn run_silu_backward(input: &[f32], dy: &[f32]) -> Vec<f32> {
        let device = MlxDevice::new().expect("device");
        let n = input.len();

        let mut x_buf = device
            .alloc_buffer(n * 4, DType::F32, vec![n])
            .expect("alloc x");
        x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(input);

        let mut dy_buf = device
            .alloc_buffer(n * 4, DType::F32, vec![n])
            .expect("alloc dy");
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(dy);

        let dx_buf = device
            .alloc_buffer(n * 4, DType::F32, vec![n])
            .expect("alloc dx");

        // The 4-byte params buffer is allocated with an F32 dtype tag but is
        // written as a single u32 holding the element count.
        let mut params = device.alloc_buffer(4, DType::F32, vec![1]).expect("params");
        params.as_mut_slice::<u32>().unwrap()[0] = n as u32;

        let mut registry = KernelRegistry::new();
        register(&mut registry);

        let mut encoder = device.command_encoder().expect("encoder");
        dispatch_silu_backward_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &x_buf,
            &dy_buf,
            &dx_buf,
            &params,
        )
        .expect("dispatch silu backward");
        encoder.commit_and_wait().expect("commit");

        dx_buf.as_slice::<f32>().unwrap().to_vec()
    }

    /// Asserts element-wise closeness of `gpu` and `cpu`.
    ///
    /// An element passes when its absolute difference is within `abs_tol`, or
    /// when the difference divided by `max(|gpu|, |cpu|, 1.0)` is within
    /// `rel_tol`.
    fn assert_close(label: &str, gpu: &[f32], cpu: &[f32], rel_tol: f32, abs_tol: f32) {
        assert_eq!(gpu.len(), cpu.len(), "{label}: length mismatch");
        for (i, (g, c)) in gpu.iter().zip(cpu.iter()).enumerate() {
            let diff = (g - c).abs();
            let scale = g.abs().max(c.abs()).max(1.0);
            assert!(
                diff <= abs_tol || diff / scale <= rel_tol,
                "{label}: i={i}: gpu={g} cpu={c} diff={diff}"
            );
        }
    }

    /// Forward kernel matches the CPU reference on 256 points in [-6.4, 6.35].
    #[test]
    fn silu_forward_parity_with_cpu() {
        let input: Vec<f32> = (0..256)
            .map(|i| (i as f32 - 128.0) * 0.05)
            .collect();
        let gpu = run_silu_forward(&input);
        let cpu = silu_cpu(&input);
        assert_close("silu forward", &gpu, &cpu, 1e-6, 1e-7);
    }

    /// Forward kernel matches the CPU reference at large-magnitude inputs and
    /// returns exactly zero for an input of zero.
    #[test]
    fn silu_forward_handles_extremes() {
        let input = vec![-20.0_f32, -10.0, -5.0, -0.5, 0.0, 0.5, 5.0, 10.0, 20.0];
        let gpu = run_silu_forward(&input);
        let cpu = silu_cpu(&input);
        assert_close("silu extremes", &gpu, &cpu, 1e-5, 1e-6);
        // Index 4 is the 0.0 input.
        assert_eq!(gpu[4], 0.0);
    }

    /// Backward kernel matches the CPU reference for a non-trivial `dy`.
    #[test]
    fn silu_backward_parity_with_cpu() {
        let input: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) * 0.05).collect();
        let dy: Vec<f32> = (0..256).map(|i| ((i as f32) * 0.013).sin()).collect();
        let gpu = run_silu_backward(&input, &dy);
        let cpu = silu_backward_cpu(&input, &dy);
        assert_close("silu backward", &gpu, &cpu, 1e-5, 1e-6);
    }

    /// Cross-checks the GPU gradient against a central finite difference of
    /// the CPU forward reference at a handful of probe indices.
    #[test]
    fn silu_backward_finite_diff_falsifier() {
        let input: Vec<f32> = (0..32).map(|i| (i as f32 - 15.5) * 0.07).collect();
        let h = 1e-3_f32;
        for &probe in &[0usize, 7, 15, 16, 24, 31] {
            let mut x_plus = input.clone();
            let mut x_minus = input.clone();
            x_plus[probe] += h;
            x_minus[probe] -= h;
            let f_plus = silu_cpu(&x_plus)[probe];
            let f_minus = silu_cpu(&x_minus)[probe];
            let fd = (f_plus - f_minus) / (2.0 * h);

            // A one-hot upstream gradient isolates d(out[probe])/d(x[probe]).
            let mut dy = vec![0f32; input.len()];
            dy[probe] = 1.0;
            let dx_gpu = run_silu_backward(&input, &dy)[probe];

            let diff = (dx_gpu - fd).abs();
            let scale = dx_gpu.abs().max(fd.abs()).max(1.0);
            assert!(
                diff <= 1e-3 || diff / scale <= 5e-3,
                "silu finite-diff falsifier failed at probe {probe}: \
                 fd={fd} analytical={dx_gpu} diff={diff}"
            );
        }
    }
}