Skip to main content

mlx_native/ops/
silu_backward.rs

1//! Elementwise SiLU (swish) forward + reverse-mode backward.
2//!
3//! Used by hf2q's ADR-020 Track 1 SwiGLU FFN on GpuTape (iter-11b).
4//!
5//! Forward:  `silu(x) = x · sigmoid(x)`
6//! Backward: `dx[i] = dy[i] · silu'(x[i])`
7//!           where `silu'(x) = sigmoid(x) · (1 + x · (1 − sigmoid(x)))`
8//!
9//! Note: mlx-native already has `silu_mul_f32` (= `silu(gate) * up`)
10//! used by the inference forward.  This module adds the standalone
11//! `silu_f32` + matching backward kernel needed for autograd.
12
13use metal::MTLSize;
14
15use crate::buffer::MlxBuffer;
16use crate::dtypes::DType;
17use crate::encoder::CommandEncoder;
18use crate::error::{MlxError, Result};
19use crate::kernel_registry::KernelRegistry;
20
21pub static SILU_BACKWARD_SHADER_SOURCE: &str =
22    include_str!("../shaders/silu_backward.metal");
23
24pub fn register(registry: &mut KernelRegistry) {
25    registry.register_source("silu_f32", SILU_BACKWARD_SHADER_SOURCE);
26    registry.register_source("silu_backward_f32", SILU_BACKWARD_SHADER_SOURCE);
27}
28
29/// Encode `output[i] = silu(input[i]) = input[i] · sigmoid(input[i])`.
30///
31/// `params_buf` must be at least 4 bytes (1 × u32: n).
32pub fn dispatch_silu_f32(
33    encoder: &mut CommandEncoder,
34    registry: &mut KernelRegistry,
35    device: &metal::DeviceRef,
36    input: &MlxBuffer,
37    output: &MlxBuffer,
38    params_buf: &MlxBuffer,
39) -> Result<()> {
40    let n = input.element_count();
41    if n == 0 {
42        return Err(MlxError::InvalidArgument(
43            "silu_f32: input must have at least one element".into(),
44        ));
45    }
46    if output.element_count() != n {
47        return Err(MlxError::InvalidArgument(format!(
48            "silu_f32: output element count {} != input element count {n}",
49            output.element_count()
50        )));
51    }
52    for (label, buf) in [("input", input), ("output", output)] {
53        if buf.dtype() != DType::F32 {
54            return Err(MlxError::InvalidArgument(format!(
55                "silu_f32: {label} dtype {} not f32",
56                buf.dtype()
57            )));
58        }
59    }
60    if params_buf.byte_len() < 4 {
61        return Err(MlxError::InvalidArgument(format!(
62            "silu_f32: params_buf too small (need 4 bytes for u32, got {})",
63            params_buf.byte_len()
64        )));
65    }
66
67    let pipeline = registry.get_pipeline("silu_f32", device)?;
68    let thread_count = n as u64;
69    let tg_size = std::cmp::min(256, thread_count);
70    encoder.encode(
71        pipeline,
72        &[(0, input), (1, output), (2, params_buf)],
73        MTLSize::new(thread_count, 1, 1),
74        MTLSize::new(tg_size, 1, 1),
75    );
76    Ok(())
77}
78
79/// Encode `dx[i] = dy[i] · silu'(x[i])`.  `x` is the FORWARD INPUT.
80///
81/// `params_buf` must be at least 4 bytes (1 × u32: n).
82#[allow(clippy::too_many_arguments)]
83pub fn dispatch_silu_backward_f32(
84    encoder: &mut CommandEncoder,
85    registry: &mut KernelRegistry,
86    device: &metal::DeviceRef,
87    x: &MlxBuffer,
88    dy: &MlxBuffer,
89    dx: &MlxBuffer,
90    params_buf: &MlxBuffer,
91) -> Result<()> {
92    let n = x.element_count();
93    if n == 0 {
94        return Err(MlxError::InvalidArgument(
95            "silu_backward_f32: x must have at least one element".into(),
96        ));
97    }
98    for (label, buf) in [("x", x), ("dy", dy), ("dx", dx)] {
99        if buf.element_count() != n {
100            return Err(MlxError::InvalidArgument(format!(
101                "silu_backward_f32: {label} element count {} != x element count {n}",
102                buf.element_count(),
103            )));
104        }
105        if buf.dtype() != DType::F32 {
106            return Err(MlxError::InvalidArgument(format!(
107                "silu_backward_f32: {label} dtype {} not f32",
108                buf.dtype()
109            )));
110        }
111    }
112    if params_buf.byte_len() < 4 {
113        return Err(MlxError::InvalidArgument(format!(
114            "silu_backward_f32: params_buf too small (need 4 bytes for u32, got {})",
115            params_buf.byte_len()
116        )));
117    }
118
119    let pipeline = registry.get_pipeline("silu_backward_f32", device)?;
120    let thread_count = n as u64;
121    let tg_size = std::cmp::min(256, thread_count);
122    encoder.encode(
123        pipeline,
124        &[(0, x), (1, dy), (2, dx), (3, params_buf)],
125        MTLSize::new(thread_count, 1, 1),
126        MTLSize::new(tg_size, 1, 1),
127    );
128    Ok(())
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// CPU reference for the forward pass: `x / (1 + e^{-x})`, i.e.
    /// `x · sigmoid(x)`.
    fn silu_cpu(x: &[f32]) -> Vec<f32> {
        x.iter().map(|&xv| xv / (1.0 + (-xv).exp())).collect()
    }

    /// CPU reference for the backward pass:
    /// `dx[i] = dy[i] · s · (1 + x[i] · (1 − s))` with `s = sigmoid(x[i])`.
    fn silu_backward_cpu(x: &[f32], dy: &[f32]) -> Vec<f32> {
        x.iter()
            .zip(dy.iter())
            .map(|(&xv, &dyv)| {
                let s = 1.0 / (1.0 + (-xv).exp());
                let deriv = s * (1.0 + xv * (1.0 - s));
                dyv * deriv
            })
            .collect()
    }

    /// Run `dispatch_silu_f32` on a freshly created device and return the
    /// output buffer's contents. The params buffer is filled with `n` before
    /// dispatch.
    fn run_silu_forward(input: &[f32]) -> Vec<f32> {
        let device = MlxDevice::new().expect("device");
        let n = input.len();
        let mut in_buf = device
            .alloc_buffer(n * 4, DType::F32, vec![n])
            .expect("alloc in");
        in_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(input);
        let out_buf = device
            .alloc_buffer(n * 4, DType::F32, vec![n])
            .expect("alloc out");
        // NOTE(review): `n` is written through a u32 view of a buffer tagged
        // `DType::F32`; confirm whether a dedicated u32 dtype should be used.
        let mut params = device.alloc_buffer(4, DType::F32, vec![1]).expect("params");
        params.as_mut_slice::<u32>().unwrap()[0] = n as u32;
        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("encoder");
        dispatch_silu_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &in_buf,
            &out_buf,
            &params,
        )
        .expect("dispatch silu");
        encoder.commit_and_wait().expect("commit");
        out_buf.as_slice::<f32>().unwrap().to_vec()
    }

    /// Run `dispatch_silu_backward_f32` with forward input `input` and
    /// upstream gradient `dy` on a freshly created device, returning `dx`.
    fn run_silu_backward(input: &[f32], dy: &[f32]) -> Vec<f32> {
        let device = MlxDevice::new().expect("device");
        let n = input.len();
        let mut x_buf = device
            .alloc_buffer(n * 4, DType::F32, vec![n])
            .expect("alloc x");
        x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(input);
        let mut dy_buf = device
            .alloc_buffer(n * 4, DType::F32, vec![n])
            .expect("alloc dy");
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(dy);
        let dx_buf = device
            .alloc_buffer(n * 4, DType::F32, vec![n])
            .expect("alloc dx");
        // NOTE(review): same u32-in-F32-tagged-buffer pattern as the forward
        // helper; confirm the intended dtype for params buffers.
        let mut params = device.alloc_buffer(4, DType::F32, vec![1]).expect("params");
        params.as_mut_slice::<u32>().unwrap()[0] = n as u32;
        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("encoder");
        dispatch_silu_backward_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &x_buf,
            &dy_buf,
            &dx_buf,
            &params,
        )
        .expect("dispatch silu backward");
        encoder.commit_and_wait().expect("commit");
        dx_buf.as_slice::<f32>().unwrap().to_vec()
    }

    /// Element-wise comparison of GPU and CPU results.
    ///
    /// A pair passes when `|gpu − cpu| <= abs_tol`, or when the difference
    /// divided by `max(|gpu|, |cpu|, 1)` is at most `rel_tol`. Because the
    /// scale is clamped to at least 1, `rel_tol` behaves as an absolute
    /// tolerance for values smaller than 1 in magnitude.
    fn assert_close(label: &str, gpu: &[f32], cpu: &[f32], rel_tol: f32, abs_tol: f32) {
        assert_eq!(gpu.len(), cpu.len(), "{label}: length mismatch");
        for (i, (g, c)) in gpu.iter().zip(cpu.iter()).enumerate() {
            let diff = (g - c).abs();
            let scale = g.abs().max(c.abs()).max(1.0);
            assert!(
                diff <= abs_tol || diff / scale <= rel_tol,
                "{label}: i={i}: gpu={g} cpu={c} diff={diff}"
            );
        }
    }

    #[test]
    fn silu_forward_parity_with_cpu() {
        // 256 evenly spaced inputs in [-6.4, 6.35].
        let input: Vec<f32> = (0..256)
            .map(|i| (i as f32 - 128.0) * 0.05)
            .collect();
        let gpu = run_silu_forward(&input);
        let cpu = silu_cpu(&input);
        assert_close("silu forward", &gpu, &cpu, 1e-6, 1e-7);
    }

    #[test]
    fn silu_forward_handles_extremes() {
        // Values in the saturation regions (very negative: sigmoid → 0; very
        // positive: sigmoid → 1), plus a few points around zero.
        let input = vec![-20.0_f32, -10.0, -5.0, -0.5, 0.0, 0.5, 5.0, 10.0, 20.0];
        let gpu = run_silu_forward(&input);
        let cpu = silu_cpu(&input);
        assert_close("silu extremes", &gpu, &cpu, 1e-5, 1e-6);
        // x=0 → silu(0) = 0 · sigmoid(0) = 0 · 0.5 = 0.
        assert_eq!(gpu[4], 0.0);
    }

    #[test]
    fn silu_backward_parity_with_cpu() {
        let input: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) * 0.05).collect();
        let dy: Vec<f32> = (0..256).map(|i| ((i as f32) * 0.013).sin()).collect();
        let gpu = run_silu_backward(&input, &dy);
        let cpu = silu_backward_cpu(&input, &dy);
        assert_close("silu backward", &gpu, &cpu, 1e-5, 1e-6);
    }

    #[test]
    fn silu_backward_finite_diff_falsifier() {
        // Acid test: backward analytical gradient must match
        // central finite-difference of the forward.
        let input: Vec<f32> = (0..32).map(|i| (i as f32 - 15.5) * 0.07).collect();
        let h = 1e-3_f32;
        // Inputs span only about [-1.085, 1.085], so these probes are the
        // symmetric pairs x ≈ ±1.085, ±0.595 and ±0.035. They cover both signs
        // and the neighbourhood of zero, not the saturation regions.
        for &probe in &[0usize, 7, 15, 16, 24, 31] {
            let mut x_plus = input.clone();
            let mut x_minus = input.clone();
            x_plus[probe] += h;
            x_minus[probe] -= h;
            let f_plus = silu_cpu(&x_plus)[probe];
            let f_minus = silu_cpu(&x_minus)[probe];
            let fd = (f_plus - f_minus) / (2.0 * h);
            // Use dy = e_probe (one-hot) so dx[probe] equals exactly
            // silu'(x[probe]).
            let mut dy = vec![0f32; input.len()];
            dy[probe] = 1.0;
            let dx_gpu = run_silu_backward(&input, &dy)[probe];
            let diff = (dx_gpu - fd).abs();
            let scale = dx_gpu.abs().max(fd.abs()).max(1.0);
            // Loose tolerances absorb the O(h²) truncation error plus f32
            // rounding in the difference quotient.
            assert!(
                diff <= 1e-3 || diff / scale <= 5e-3,
                "silu finite-diff falsifier failed at probe {probe}: \
                 fd={fd} analytical={dx_gpu} diff={diff}"
            );
        }
    }
}