1use metal::MTLSize;
14
15use crate::buffer::MlxBuffer;
16use crate::dtypes::DType;
17use crate::encoder::CommandEncoder;
18use crate::error::{MlxError, Result};
19use crate::kernel_registry::KernelRegistry;
20
21pub static SILU_BACKWARD_SHADER_SOURCE: &str =
22 include_str!("../shaders/silu_backward.metal");
23
24pub fn register(registry: &mut KernelRegistry) {
25 registry.register_source("silu_f32", SILU_BACKWARD_SHADER_SOURCE);
26 registry.register_source("silu_backward_f32", SILU_BACKWARD_SHADER_SOURCE);
27}
28
29pub fn dispatch_silu_f32(
33 encoder: &mut CommandEncoder,
34 registry: &mut KernelRegistry,
35 device: &metal::DeviceRef,
36 input: &MlxBuffer,
37 output: &MlxBuffer,
38 params_buf: &MlxBuffer,
39) -> Result<()> {
40 let n = input.element_count();
41 if n == 0 {
42 return Err(MlxError::InvalidArgument(
43 "silu_f32: input must have at least one element".into(),
44 ));
45 }
46 if output.element_count() != n {
47 return Err(MlxError::InvalidArgument(format!(
48 "silu_f32: output element count {} != input element count {n}",
49 output.element_count()
50 )));
51 }
52 for (label, buf) in [("input", input), ("output", output)] {
53 if buf.dtype() != DType::F32 {
54 return Err(MlxError::InvalidArgument(format!(
55 "silu_f32: {label} dtype {} not f32",
56 buf.dtype()
57 )));
58 }
59 }
60 if params_buf.byte_len() < 4 {
61 return Err(MlxError::InvalidArgument(format!(
62 "silu_f32: params_buf too small (need 4 bytes for u32, got {})",
63 params_buf.byte_len()
64 )));
65 }
66
67 let pipeline = registry.get_pipeline("silu_f32", device)?;
68 let thread_count = n as u64;
69 let tg_size = std::cmp::min(256, thread_count);
70 encoder.encode(
71 pipeline,
72 &[(0, input), (1, output), (2, params_buf)],
73 MTLSize::new(thread_count, 1, 1),
74 MTLSize::new(tg_size, 1, 1),
75 );
76 Ok(())
77}
78
79#[allow(clippy::too_many_arguments)]
83pub fn dispatch_silu_backward_f32(
84 encoder: &mut CommandEncoder,
85 registry: &mut KernelRegistry,
86 device: &metal::DeviceRef,
87 x: &MlxBuffer,
88 dy: &MlxBuffer,
89 dx: &MlxBuffer,
90 params_buf: &MlxBuffer,
91) -> Result<()> {
92 let n = x.element_count();
93 if n == 0 {
94 return Err(MlxError::InvalidArgument(
95 "silu_backward_f32: x must have at least one element".into(),
96 ));
97 }
98 for (label, buf) in [("x", x), ("dy", dy), ("dx", dx)] {
99 if buf.element_count() != n {
100 return Err(MlxError::InvalidArgument(format!(
101 "silu_backward_f32: {label} element count {} != x element count {n}",
102 buf.element_count(),
103 )));
104 }
105 if buf.dtype() != DType::F32 {
106 return Err(MlxError::InvalidArgument(format!(
107 "silu_backward_f32: {label} dtype {} not f32",
108 buf.dtype()
109 )));
110 }
111 }
112 if params_buf.byte_len() < 4 {
113 return Err(MlxError::InvalidArgument(format!(
114 "silu_backward_f32: params_buf too small (need 4 bytes for u32, got {})",
115 params_buf.byte_len()
116 )));
117 }
118
119 let pipeline = registry.get_pipeline("silu_backward_f32", device)?;
120 let thread_count = n as u64;
121 let tg_size = std::cmp::min(256, thread_count);
122 encoder.encode(
123 pipeline,
124 &[(0, x), (1, dy), (2, dx), (3, params_buf)],
125 MTLSize::new(thread_count, 1, 1),
126 MTLSize::new(tg_size, 1, 1),
127 );
128 Ok(())
129}
130
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// Scalar CPU reference for SiLU: `x / (1 + exp(-x))`.
    fn silu_scalar(x: f32) -> f32 {
        x / (1.0 + (-x).exp())
    }

    /// CPU reference for the forward kernel, applied element-wise.
    fn silu_cpu(x: &[f32]) -> Vec<f32> {
        x.iter().copied().map(silu_scalar).collect()
    }

    /// CPU reference for the backward kernel:
    /// `dy * s * (1 + x * (1 - s))` with `s = sigmoid(x)`.
    fn silu_backward_cpu(x: &[f32], dy: &[f32]) -> Vec<f32> {
        x.iter()
            .zip(dy.iter())
            .map(|(&xv, &dyv)| {
                let s = 1.0 / (1.0 + (-xv).exp());
                let deriv = s * (1.0 + xv * (1.0 - s));
                dyv * deriv
            })
            .collect()
    }

    /// Allocate an f32 buffer on `device` and fill it with `data`.
    fn upload_f32(device: &MlxDevice, data: &[f32], what: &str) -> MlxBuffer {
        let n = data.len();
        let mut buf = device
            .alloc_buffer(n * 4, DType::F32, vec![n])
            .expect(what);
        buf.as_mut_slice::<f32>().unwrap().copy_from_slice(data);
        buf
    }

    /// Allocate the 4-byte params buffer holding the element count as a u32.
    fn upload_len(device: &MlxDevice, n: usize) -> MlxBuffer {
        let mut params = device.alloc_buffer(4, DType::F32, vec![1]).expect("params");
        params.as_mut_slice::<u32>().unwrap()[0] = n as u32;
        params
    }

    /// Run the forward kernel on a fresh device and read back the result.
    fn run_silu_forward(input: &[f32]) -> Vec<f32> {
        let device = MlxDevice::new().expect("device");
        let n = input.len();
        let in_buf = upload_f32(&device, input, "alloc in");
        let out_buf = device
            .alloc_buffer(n * 4, DType::F32, vec![n])
            .expect("alloc out");
        let params = upload_len(&device, n);
        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("encoder");
        dispatch_silu_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &in_buf,
            &out_buf,
            &params,
        )
        .expect("dispatch silu");
        encoder.commit_and_wait().expect("commit");
        out_buf.as_slice::<f32>().unwrap().to_vec()
    }

    /// Run the backward kernel on a fresh device and read back `dx`.
    fn run_silu_backward(input: &[f32], dy: &[f32]) -> Vec<f32> {
        let device = MlxDevice::new().expect("device");
        let n = input.len();
        let x_buf = upload_f32(&device, input, "alloc x");
        let dy_buf = upload_f32(&device, dy, "alloc dy");
        let dx_buf = device
            .alloc_buffer(n * 4, DType::F32, vec![n])
            .expect("alloc dx");
        let params = upload_len(&device, n);
        let mut registry = KernelRegistry::new();
        register(&mut registry);
        let mut encoder = device.command_encoder().expect("encoder");
        dispatch_silu_backward_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &x_buf,
            &dy_buf,
            &dx_buf,
            &params,
        )
        .expect("dispatch silu backward");
        encoder.commit_and_wait().expect("commit");
        dx_buf.as_slice::<f32>().unwrap().to_vec()
    }

    /// Assert element-wise closeness: each pair must be within `abs_tol`
    /// absolutely, or within `rel_tol` relative to max(|gpu|, |cpu|, 1).
    fn assert_close(label: &str, gpu: &[f32], cpu: &[f32], rel_tol: f32, abs_tol: f32) {
        assert_eq!(gpu.len(), cpu.len(), "{label}: length mismatch");
        for (i, (g, c)) in gpu.iter().zip(cpu.iter()).enumerate() {
            let diff = (g - c).abs();
            let scale = g.abs().max(c.abs()).max(1.0);
            assert!(
                diff <= abs_tol || diff / scale <= rel_tol,
                "{label}: i={i}: gpu={g} cpu={c} diff={diff}"
            );
        }
    }

    #[test]
    fn silu_forward_parity_with_cpu() {
        let input: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) * 0.05).collect();
        let gpu = run_silu_forward(&input);
        let cpu = silu_cpu(&input);
        assert_close("silu forward", &gpu, &cpu, 1e-6, 1e-7);
    }

    #[test]
    fn silu_forward_handles_extremes() {
        let input = vec![-20.0_f32, -10.0, -5.0, -0.5, 0.0, 0.5, 5.0, 10.0, 20.0];
        let gpu = run_silu_forward(&input);
        let cpu = silu_cpu(&input);
        assert_close("silu extremes", &gpu, &cpu, 1e-5, 1e-6);
        // silu(0) = 0 / 2 is exactly zero.
        assert_eq!(gpu[4], 0.0);
    }

    #[test]
    fn silu_backward_parity_with_cpu() {
        let input: Vec<f32> = (0..256).map(|i| (i as f32 - 128.0) * 0.05).collect();
        let dy: Vec<f32> = (0..256).map(|i| ((i as f32) * 0.013).sin()).collect();
        let gpu = run_silu_backward(&input, &dy);
        let cpu = silu_backward_cpu(&input, &dy);
        assert_close("silu backward", &gpu, &cpu, 1e-5, 1e-6);
    }

    #[test]
    fn silu_ragged_length_parity_with_cpu() {
        // 300 is not a multiple of the 256-thread threadgroup used by the
        // dispatch helpers, so the last threadgroup is partial; every element
        // must still be computed.
        let input: Vec<f32> = (0..300).map(|i| (i as f32 - 150.0) * 0.04).collect();
        let dy: Vec<f32> = (0..300).map(|i| ((i as f32) * 0.021).cos()).collect();
        assert_close(
            "silu ragged forward",
            &run_silu_forward(&input),
            &silu_cpu(&input),
            1e-5,
            1e-6,
        );
        assert_close(
            "silu ragged backward",
            &run_silu_backward(&input, &dy),
            &silu_backward_cpu(&input, &dy),
            1e-5,
            1e-6,
        );
    }

    #[test]
    fn silu_backward_finite_diff_falsifier() {
        let input: Vec<f32> = (0..32).map(|i| (i as f32 - 15.5) * 0.07).collect();
        let h = 1e-3_f32;
        for &probe in &[0usize, 7, 15, 16, 24, 31] {
            // Central difference of the CPU forward reference at x[probe].
            let xv = input[probe];
            let fd = (silu_scalar(xv + h) - silu_scalar(xv - h)) / (2.0 * h);
            // One-hot upstream gradient isolates d silu(x[probe]) / d x[probe].
            let mut dy = vec![0f32; input.len()];
            dy[probe] = 1.0;
            let dx_gpu = run_silu_backward(&input, &dy)[probe];
            let diff = (dx_gpu - fd).abs();
            let scale = dx_gpu.abs().max(fd.abs()).max(1.0);
            assert!(
                diff <= 1e-3 || diff / scale <= 5e-3,
                "silu finite-diff falsifier failed at probe {probe}: \
                 fd={fd} analytical={dx_gpu} diff={diff}"
            );
        }
    }
}