1use metal::MTLSize;
20
21use crate::buffer::MlxBuffer;
22use crate::dtypes::DType;
23use crate::encoder::CommandEncoder;
24use crate::error::{MlxError, Result};
25use crate::kernel_registry::KernelRegistry;
26
27pub static CONV1D_DEPTHWISE_CAUSAL_SHADER_SOURCE: &str =
28 include_str!("../shaders/conv1d_depthwise_causal.metal");
29
30pub fn register(registry: &mut KernelRegistry) {
31 registry.register_source(
32 "conv1d_depthwise_causal_forward_f32",
33 CONV1D_DEPTHWISE_CAUSAL_SHADER_SOURCE,
34 );
35 registry.register_source(
36 "conv1d_depthwise_causal_backward_dx_f32",
37 CONV1D_DEPTHWISE_CAUSAL_SHADER_SOURCE,
38 );
39 registry.register_source(
40 "conv1d_depthwise_causal_backward_dw_f32",
41 CONV1D_DEPTHWISE_CAUSAL_SHADER_SOURCE,
42 );
43}
44
45fn validate_shapes(
46 op: &str,
47 n_tokens: u32,
48 channels: u32,
49 k: u32,
50 x_or_dx: &MlxBuffer,
51 w_or_dy_or_dw: &MlxBuffer,
52 out: &MlxBuffer,
53 params: &MlxBuffer,
54 expected_first_count: usize,
55 expected_second_count: usize,
56 expected_out_count: usize,
57) -> Result<()> {
58 if n_tokens == 0 || channels == 0 || k == 0 {
59 return Err(MlxError::InvalidArgument(format!(
60 "{op}: n_tokens, channels, K must all be > 0 (got {n_tokens}, {channels}, {k})"
61 )));
62 }
63 if x_or_dx.dtype() != DType::F32
64 || w_or_dy_or_dw.dtype() != DType::F32
65 || out.dtype() != DType::F32
66 {
67 return Err(MlxError::InvalidArgument(format!(
68 "{op}: all I/O buffers must be f32"
69 )));
70 }
71 if x_or_dx.element_count() != expected_first_count {
72 return Err(MlxError::InvalidArgument(format!(
73 "{op}: first-buffer element_count {} != expected {expected_first_count}",
74 x_or_dx.element_count()
75 )));
76 }
77 if w_or_dy_or_dw.element_count() != expected_second_count {
78 return Err(MlxError::InvalidArgument(format!(
79 "{op}: second-buffer element_count {} != expected {expected_second_count}",
80 w_or_dy_or_dw.element_count()
81 )));
82 }
83 if out.element_count() != expected_out_count {
84 return Err(MlxError::InvalidArgument(format!(
85 "{op}: out element_count {} != expected {expected_out_count}",
86 out.element_count()
87 )));
88 }
89 if params.byte_len() < 12 {
90 return Err(MlxError::InvalidArgument(format!(
91 "{op}: params < 12 bytes (need 3 × u32 = [n_tokens, channels, K])"
92 )));
93 }
94 Ok(())
95}
96
97#[allow(clippy::too_many_arguments)]
98pub fn dispatch_conv1d_depthwise_causal_forward_f32(
99 encoder: &mut CommandEncoder,
100 registry: &mut KernelRegistry,
101 device: &metal::DeviceRef,
102 x: &MlxBuffer,
103 kernel_w: &MlxBuffer,
104 y: &MlxBuffer,
105 params: &MlxBuffer,
106 n_tokens: u32,
107 channels: u32,
108 k: u32,
109) -> Result<()> {
110 const OP: &str = "conv1d_depthwise_causal_forward_f32";
111 let n = n_tokens as usize;
112 let c = channels as usize;
113 let k_us = k as usize;
114 validate_shapes(
115 OP, n_tokens, channels, k, x, kernel_w, y, params,
116 n * c, c * k_us, n * c,
117 )?;
118
119 let pipeline = registry.get_pipeline(OP, device)?;
120 encoder.encode(
121 pipeline,
122 &[(0, x), (1, kernel_w), (2, y), (3, params)],
123 MTLSize::new(n_tokens as u64, channels as u64, 1),
124 MTLSize::new(
125 std::cmp::min(32, n_tokens as u64),
126 std::cmp::min(8, channels as u64),
127 1,
128 ),
129 );
130 Ok(())
131}
132
133#[allow(clippy::too_many_arguments)]
134pub fn dispatch_conv1d_depthwise_causal_backward_dx_f32(
135 encoder: &mut CommandEncoder,
136 registry: &mut KernelRegistry,
137 device: &metal::DeviceRef,
138 dy: &MlxBuffer,
139 kernel_w: &MlxBuffer,
140 dx: &MlxBuffer,
141 params: &MlxBuffer,
142 n_tokens: u32,
143 channels: u32,
144 k: u32,
145) -> Result<()> {
146 const OP: &str = "conv1d_depthwise_causal_backward_dx_f32";
147 let n = n_tokens as usize;
148 let c = channels as usize;
149 let k_us = k as usize;
150 validate_shapes(
151 OP, n_tokens, channels, k, dy, kernel_w, dx, params,
152 n * c, c * k_us, n * c,
153 )?;
154
155 let pipeline = registry.get_pipeline(OP, device)?;
156 encoder.encode(
157 pipeline,
158 &[(0, dy), (1, kernel_w), (2, dx), (3, params)],
159 MTLSize::new(n_tokens as u64, channels as u64, 1),
160 MTLSize::new(
161 std::cmp::min(32, n_tokens as u64),
162 std::cmp::min(8, channels as u64),
163 1,
164 ),
165 );
166 Ok(())
167}
168
169#[allow(clippy::too_many_arguments)]
170pub fn dispatch_conv1d_depthwise_causal_backward_dw_f32(
171 encoder: &mut CommandEncoder,
172 registry: &mut KernelRegistry,
173 device: &metal::DeviceRef,
174 x: &MlxBuffer,
175 dy: &MlxBuffer,
176 dw: &MlxBuffer,
177 params: &MlxBuffer,
178 n_tokens: u32,
179 channels: u32,
180 k: u32,
181) -> Result<()> {
182 const OP: &str = "conv1d_depthwise_causal_backward_dw_f32";
183 let n = n_tokens as usize;
184 let c = channels as usize;
185 let k_us = k as usize;
186 validate_shapes(
187 OP, n_tokens, channels, k, x, dy, dw, params,
188 n * c, n * c, c * k_us,
189 )?;
190
191 let pipeline = registry.get_pipeline(OP, device)?;
192 encoder.encode(
193 pipeline,
194 &[(0, x), (1, dy), (2, dw), (3, params)],
195 MTLSize::new(channels as u64, k as u64, 1),
196 MTLSize::new(
197 std::cmp::min(32, channels as u64),
198 std::cmp::min(8, k as u64),
199 1,
200 ),
201 );
202 Ok(())
203}
204
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// Allocates an f32 buffer of `n` elements with the given shape, zero-filled.
    fn alloc_f32(device: &MlxDevice, n: usize, shape: Vec<usize>) -> MlxBuffer {
        let mut b = device
            .alloc_buffer(n * 4, DType::F32, shape)
            .expect("alloc f32");
        b.as_mut_slice::<f32>().unwrap().fill(0.0);
        b
    }

    /// Allocates an f32 buffer with the given shape and copies `data` into it.
    fn upload_f32(device: &MlxDevice, data: &[f32], shape: Vec<usize>) -> MlxBuffer {
        let mut b = alloc_f32(device, data.len(), shape);
        b.as_mut_slice::<f32>().unwrap().copy_from_slice(data);
        b
    }

    /// Builds the 12-byte params buffer holding `[n_tokens, channels, k]` as u32.
    fn make_params(device: &MlxDevice, n_tokens: u32, channels: u32, k: u32) -> MlxBuffer {
        let mut p = device
            .alloc_buffer(12, DType::U32, vec![3])
            .expect("alloc params");
        p.as_mut_slice::<u32>()
            .unwrap()
            .copy_from_slice(&[n_tokens, channels, k]);
        p
    }

    /// Asserts that the first `cpu.len()` values of `gpu` match `cpu` within
    /// `rel * max(|cpu|, 1)`, i.e. relative for large values and absolute
    /// below 1.0.
    fn assert_close(label: &str, gpu: &[f32], cpu: &[f32], rel: f32) {
        assert!(
            gpu.len() >= cpu.len(),
            "{label}: gpu buffer has {} values, expected at least {}",
            gpu.len(),
            cpu.len()
        );
        for (i, (&g, &c)) in gpu.iter().zip(cpu).enumerate() {
            assert!(
                (g - c).abs() < rel * c.abs().max(1.0),
                "{label}[{i}]: gpu={g} cpu={c}"
            );
        }
    }

    /// CPU reference for the forward pass, accumulating in f64:
    /// `y[t, ch] = Σ_kk w[ch, kk] · x[t + kk - (k - 1), ch]`. Taps that land
    /// before the first token are skipped, which acts as implicit left
    /// zero-padding. `x`/`y` are `[n, c]` row-major and `w` is `[c, k]`.
    fn forward_cpu(x: &[f32], kernel_w: &[f32], n: usize, c: usize, k: usize) -> Vec<f32> {
        let mut y = vec![0.0f32; n * c];
        for t in 0..n {
            for ch in 0..c {
                let mut sum = 0.0f64;
                for kk in 0..k {
                    let i_signed = (t as isize) + (kk as isize) - (k as isize - 1);
                    if i_signed < 0 {
                        continue;
                    }
                    let i = i_signed as usize;
                    sum += kernel_w[ch * k + kk] as f64 * x[i * c + ch] as f64;
                }
                y[t * c + ch] = sum as f32;
            }
        }
        y
    }

    /// CPU reference for the input gradient: for input position `i`, sums
    /// `w[ch, kk] · dy[t, ch]` over the output positions
    /// `t = i + (k - 1) - kk` that lie inside `[0, n)`.
    fn backward_dx_cpu(dy: &[f32], kernel_w: &[f32], n: usize, c: usize, k: usize) -> Vec<f32> {
        let mut dx = vec![0.0f32; n * c];
        for i in 0..n {
            for ch in 0..c {
                let mut sum = 0.0f64;
                for kk in 0..k {
                    let t_signed = (i as isize) + (k as isize - 1) - (kk as isize);
                    if t_signed < 0 || t_signed >= n as isize {
                        continue;
                    }
                    let t = t_signed as usize;
                    sum += kernel_w[ch * k + kk] as f64 * dy[t * c + ch] as f64;
                }
                dx[i * c + ch] = sum as f32;
            }
        }
        dx
    }

    /// CPU reference for the weight gradient:
    /// `dw[ch, kk] = Σ_t x[t + kk - (k - 1), ch] · dy[t, ch]`, summed over
    /// the `t` for which that input index is non-negative.
    fn backward_dw_cpu(x: &[f32], dy: &[f32], n: usize, c: usize, k: usize) -> Vec<f32> {
        let mut dw = vec![0.0f32; c * k];
        for ch in 0..c {
            for kk in 0..k {
                let mut sum = 0.0f64;
                for t in (k - 1 - kk)..n {
                    let i = t + kk - (k - 1);
                    sum += x[i * c + ch] as f64 * dy[t * c + ch] as f64;
                }
                dw[ch * k + kk] = sum as f32;
            }
        }
        dw
    }

    /// Runs the forward kernel on `x`/`w` and returns the GPU output.
    fn run_forward(
        device: &MlxDevice,
        x: &[f32],
        w: &[f32],
        n: usize,
        c: usize,
        k: usize,
    ) -> Vec<f32> {
        let mut registry = KernelRegistry::new();
        let x_buf = upload_f32(device, x, vec![n, c]);
        let w_buf = upload_f32(device, w, vec![c, k]);
        let y_buf = alloc_f32(device, n * c, vec![n, c]);
        let params = make_params(device, n as u32, c as u32, k as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_conv1d_depthwise_causal_forward_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &x_buf,
            &w_buf,
            &y_buf,
            &params,
            n as u32,
            c as u32,
            k as u32,
        )
        .unwrap();
        encoder.commit_and_wait().unwrap();
        y_buf.as_slice::<f32>().unwrap().to_vec()
    }

    #[test]
    fn forward_matches_cpu_oracle() {
        let device = MlxDevice::new().expect("device");
        let (n, c, k) = (16usize, 8usize, 4usize);

        let x: Vec<f32> = (0..(n * c))
            .map(|i| ((i as f32) * 0.137 - 0.4).sin() * 0.7)
            .collect();
        let w: Vec<f32> = (0..(c * k))
            .map(|i| ((i as f32) * 0.231 + 0.1).cos() * 0.5)
            .collect();

        let gpu = run_forward(&device, &x, &w, n, c, k);
        let cpu = forward_cpu(&x, &w, n, c, k);
        assert_close("forward y", &gpu, &cpu, 1e-5);
    }

    /// Boundary case: the kernel is wider than the sequence, so every output
    /// position reads at least one implicitly zero-padded tap.
    #[test]
    fn forward_handles_kernel_longer_than_sequence() {
        let device = MlxDevice::new().expect("device");
        let (n, c, k) = (3usize, 4usize, 5usize);

        let x: Vec<f32> = (0..(n * c))
            .map(|i| ((i as f32) * 0.29 + 0.2).sin())
            .collect();
        let w: Vec<f32> = (0..(c * k)).map(|i| 0.3 - (i as f32) * 0.02).collect();

        let gpu = run_forward(&device, &x, &w, n, c, k);
        let cpu = forward_cpu(&x, &w, n, c, k);
        assert_close("forward (k > n) y", &gpu, &cpu, 1e-5);
    }

    #[test]
    fn backward_dx_matches_cpu_oracle() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let (n, c, k) = (16usize, 8usize, 4usize);

        let dy: Vec<f32> = (0..(n * c))
            .map(|i| ((i as f32) * 0.073 - 0.3).sin() * 0.6)
            .collect();
        let w: Vec<f32> = (0..(c * k)).map(|i| 0.1 + (i as f32) * 0.013).collect();

        let dy_buf = upload_f32(&device, &dy, vec![n, c]);
        let w_buf = upload_f32(&device, &w, vec![c, k]);
        let dx_buf = alloc_f32(&device, n * c, vec![n, c]);
        let params = make_params(&device, n as u32, c as u32, k as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_conv1d_depthwise_causal_backward_dx_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &dy_buf,
            &w_buf,
            &dx_buf,
            &params,
            n as u32,
            c as u32,
            k as u32,
        )
        .unwrap();
        encoder.commit_and_wait().unwrap();

        let cpu = backward_dx_cpu(&dy, &w, n, c, k);
        assert_close("dx", dx_buf.as_slice::<f32>().unwrap(), &cpu, 1e-5);
    }

    #[test]
    fn backward_dw_matches_cpu_oracle() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let (n, c, k) = (32usize, 8usize, 4usize);

        let x: Vec<f32> = (0..(n * c))
            .map(|i| ((i as f32) * 0.041 - 0.5).cos() * 0.7)
            .collect();
        let dy: Vec<f32> = (0..(n * c))
            .map(|i| ((i as f32) * 0.073 - 0.3).sin() * 0.6)
            .collect();

        let x_buf = upload_f32(&device, &x, vec![n, c]);
        let dy_buf = upload_f32(&device, &dy, vec![n, c]);
        let dw_buf = alloc_f32(&device, c * k, vec![c, k]);
        let params = make_params(&device, n as u32, c as u32, k as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_conv1d_depthwise_causal_backward_dw_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &x_buf,
            &dy_buf,
            &dw_buf,
            &params,
            n as u32,
            c as u32,
            k as u32,
        )
        .unwrap();
        encoder.commit_and_wait().unwrap();

        let cpu = backward_dw_cpu(&x, &dy, n, c, k);
        // dw reduces over all n tokens, so allow a looser tolerance.
        assert_close("dw", dw_buf.as_slice::<f32>().unwrap(), &cpu, 1e-4);
    }

    /// Checks both GPU gradients against central finite differences of the
    /// loss `L = Σ y` computed by the CPU forward oracle. With that loss,
    /// `dL/dy` is all ones, which is the `dy` fed to the GPU.
    #[test]
    fn backward_finite_difference_falsifier() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let (n, c, k) = (8usize, 4usize, 3usize);

        let x: Vec<f32> = (0..(n * c)).map(|i| ((i as f32) * 0.137).sin() * 0.6).collect();
        let w: Vec<f32> = (0..(c * k)).map(|i| 0.2 + (i as f32) * 0.05).collect();

        let forward_loss = |x: &[f32], w: &[f32]| -> f64 {
            let y = forward_cpu(x, w, n, c, k);
            y.iter().map(|v| *v as f64).sum::<f64>()
        };

        let dy_ones = vec![1.0f32; n * c];
        let dy_buf = upload_f32(&device, &dy_ones, vec![n, c]);
        let x_buf = upload_f32(&device, &x, vec![n, c]);
        let w_buf = upload_f32(&device, &w, vec![c, k]);
        let dx_buf = alloc_f32(&device, n * c, vec![n, c]);
        let dw_buf = alloc_f32(&device, c * k, vec![c, k]);
        let params = make_params(&device, n as u32, c as u32, k as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_conv1d_depthwise_causal_backward_dx_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &dy_buf,
            &w_buf,
            &dx_buf,
            &params,
            n as u32,
            c as u32,
            k as u32,
        )
        .unwrap();
        dispatch_conv1d_depthwise_causal_backward_dw_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &x_buf,
            &dy_buf,
            &dw_buf,
            &params,
            n as u32,
            c as u32,
            k as u32,
        )
        .unwrap();
        encoder.commit_and_wait().unwrap();
        let dx = dx_buf.as_slice::<f32>().unwrap().to_vec();
        let dw = dw_buf.as_slice::<f32>().unwrap().to_vec();

        let h = 1e-3f64;
        for i in 0..(n * c) {
            let mut xp = x.clone();
            xp[i] += h as f32;
            let mut xm = x.clone();
            xm[i] -= h as f32;
            let fd = (forward_loss(&xp, &w) - forward_loss(&xm, &w)) / (2.0 * h);
            let tol = 1e-2 * fd.abs().max(1.0);
            assert!(
                (dx[i] as f64 - fd).abs() < tol,
                "FD x[{i}]: analytic={} fd={}",
                dx[i],
                fd
            );
        }
        for i in 0..(c * k) {
            let mut wp = w.clone();
            wp[i] += h as f32;
            let mut wm = w.clone();
            wm[i] -= h as f32;
            let fd = (forward_loss(&x, &wp) - forward_loss(&x, &wm)) / (2.0 * h);
            let tol = 1e-2 * fd.abs().max(1.0);
            assert!(
                (dw[i] as f64 - fd).abs() < tol,
                "FD w[{i}]: analytic={} fd={}",
                dw[i],
                fd
            );
        }
    }

    #[test]
    fn rejects_zero_dimensions() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let x_buf = alloc_f32(&device, 1, vec![1, 1]);
        let w_buf = alloc_f32(&device, 1, vec![1, 1]);
        let y_buf = alloc_f32(&device, 1, vec![1, 1]);
        let params = make_params(&device, 0, 1, 1);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_conv1d_depthwise_causal_forward_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &x_buf,
            &w_buf,
            &y_buf,
            &params,
            0,
            1,
            1,
        );
        assert!(res.is_err());
    }

    /// A weight buffer one element short of `channels * k` must be rejected
    /// before anything is encoded.
    #[test]
    fn rejects_mismatched_element_counts() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let (n, c, k) = (4usize, 2usize, 3usize);
        let x_buf = alloc_f32(&device, n * c, vec![n, c]);
        let w_buf = alloc_f32(&device, c * k - 1, vec![c * k - 1]);
        let y_buf = alloc_f32(&device, n * c, vec![n, c]);
        let params = make_params(&device, n as u32, c as u32, k as u32);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_conv1d_depthwise_causal_forward_f32(
            &mut encoder,
            &mut registry,
            device.metal_device(),
            &x_buf,
            &w_buf,
            &y_buf,
            &params,
            n as u32,
            c as u32,
            k as u32,
        );
        assert!(res.is_err());
    }
}