Skip to main content

mlx_native/ops/
conv1d_depthwise_causal.rs

1//! ADR-020 iter-11h-b — depthwise causal 1D convolution forward +
2//! backward kernels for the GpuTape autograd pipeline.
3//!
4//! Distinct from `ssm_conv` (which fuses SiLU + handles autoregressive
5//! decode state).  This module is for TRAINING-MODE backward pass:
6//!
7//! Forward shape contract:
8//!   x : `[n_tokens, channels]` row-major (f32)
9//!   kernel_w : `[channels, K]` row-major (f32)
10//!   y : `[n_tokens, channels]` row-major (f32)
11//!
12//! Math (per output element `(t, c)`):
13//!   y[t, c] = Σ_{k=0..K-1, t+k-(K-1)>=0} kernel_w[c, k] · x[t+k-(K-1), c]
14//!
15//! Zero-pad on the past: outputs at `t < K-1` see fewer than K input
16//! taps (the missing taps default to 0 — equivalent to "no prior
17//! decode state", which is the training-time invariant).
18
19use metal::MTLSize;
20
21use crate::buffer::MlxBuffer;
22use crate::dtypes::DType;
23use crate::encoder::CommandEncoder;
24use crate::error::{MlxError, Result};
25use crate::kernel_registry::KernelRegistry;
26
27pub static CONV1D_DEPTHWISE_CAUSAL_SHADER_SOURCE: &str =
28    include_str!("../shaders/conv1d_depthwise_causal.metal");
29
30pub fn register(registry: &mut KernelRegistry) {
31    registry.register_source(
32        "conv1d_depthwise_causal_forward_f32",
33        CONV1D_DEPTHWISE_CAUSAL_SHADER_SOURCE,
34    );
35    registry.register_source(
36        "conv1d_depthwise_causal_backward_dx_f32",
37        CONV1D_DEPTHWISE_CAUSAL_SHADER_SOURCE,
38    );
39    registry.register_source(
40        "conv1d_depthwise_causal_backward_dw_f32",
41        CONV1D_DEPTHWISE_CAUSAL_SHADER_SOURCE,
42    );
43}
44
45fn validate_shapes(
46    op: &str,
47    n_tokens: u32,
48    channels: u32,
49    k: u32,
50    x_or_dx: &MlxBuffer,
51    w_or_dy_or_dw: &MlxBuffer,
52    out: &MlxBuffer,
53    params: &MlxBuffer,
54    expected_first_count: usize,
55    expected_second_count: usize,
56    expected_out_count: usize,
57) -> Result<()> {
58    if n_tokens == 0 || channels == 0 || k == 0 {
59        return Err(MlxError::InvalidArgument(format!(
60            "{op}: n_tokens, channels, K must all be > 0 (got {n_tokens}, {channels}, {k})"
61        )));
62    }
63    if x_or_dx.dtype() != DType::F32
64        || w_or_dy_or_dw.dtype() != DType::F32
65        || out.dtype() != DType::F32
66    {
67        return Err(MlxError::InvalidArgument(format!(
68            "{op}: all I/O buffers must be f32"
69        )));
70    }
71    if x_or_dx.element_count() != expected_first_count {
72        return Err(MlxError::InvalidArgument(format!(
73            "{op}: first-buffer element_count {} != expected {expected_first_count}",
74            x_or_dx.element_count()
75        )));
76    }
77    if w_or_dy_or_dw.element_count() != expected_second_count {
78        return Err(MlxError::InvalidArgument(format!(
79            "{op}: second-buffer element_count {} != expected {expected_second_count}",
80            w_or_dy_or_dw.element_count()
81        )));
82    }
83    if out.element_count() != expected_out_count {
84        return Err(MlxError::InvalidArgument(format!(
85            "{op}: out element_count {} != expected {expected_out_count}",
86            out.element_count()
87        )));
88    }
89    if params.byte_len() < 12 {
90        return Err(MlxError::InvalidArgument(format!(
91            "{op}: params < 12 bytes (need 3 × u32 = [n_tokens, channels, K])"
92        )));
93    }
94    Ok(())
95}
96
97#[allow(clippy::too_many_arguments)]
98pub fn dispatch_conv1d_depthwise_causal_forward_f32(
99    encoder: &mut CommandEncoder,
100    registry: &mut KernelRegistry,
101    device: &metal::DeviceRef,
102    x: &MlxBuffer,
103    kernel_w: &MlxBuffer,
104    y: &MlxBuffer,
105    params: &MlxBuffer,
106    n_tokens: u32,
107    channels: u32,
108    k: u32,
109) -> Result<()> {
110    const OP: &str = "conv1d_depthwise_causal_forward_f32";
111    let n = n_tokens as usize;
112    let c = channels as usize;
113    let k_us = k as usize;
114    validate_shapes(
115        OP, n_tokens, channels, k, x, kernel_w, y, params,
116        n * c, c * k_us, n * c,
117    )?;
118
119    let pipeline = registry.get_pipeline(OP, device)?;
120    encoder.encode(
121        pipeline,
122        &[(0, x), (1, kernel_w), (2, y), (3, params)],
123        MTLSize::new(n_tokens as u64, channels as u64, 1),
124        MTLSize::new(
125            std::cmp::min(32, n_tokens as u64),
126            std::cmp::min(8, channels as u64),
127            1,
128        ),
129    );
130    Ok(())
131}
132
133#[allow(clippy::too_many_arguments)]
134pub fn dispatch_conv1d_depthwise_causal_backward_dx_f32(
135    encoder: &mut CommandEncoder,
136    registry: &mut KernelRegistry,
137    device: &metal::DeviceRef,
138    dy: &MlxBuffer,
139    kernel_w: &MlxBuffer,
140    dx: &MlxBuffer,
141    params: &MlxBuffer,
142    n_tokens: u32,
143    channels: u32,
144    k: u32,
145) -> Result<()> {
146    const OP: &str = "conv1d_depthwise_causal_backward_dx_f32";
147    let n = n_tokens as usize;
148    let c = channels as usize;
149    let k_us = k as usize;
150    validate_shapes(
151        OP, n_tokens, channels, k, dy, kernel_w, dx, params,
152        n * c, c * k_us, n * c,
153    )?;
154
155    let pipeline = registry.get_pipeline(OP, device)?;
156    encoder.encode(
157        pipeline,
158        &[(0, dy), (1, kernel_w), (2, dx), (3, params)],
159        MTLSize::new(n_tokens as u64, channels as u64, 1),
160        MTLSize::new(
161            std::cmp::min(32, n_tokens as u64),
162            std::cmp::min(8, channels as u64),
163            1,
164        ),
165    );
166    Ok(())
167}
168
169#[allow(clippy::too_many_arguments)]
170pub fn dispatch_conv1d_depthwise_causal_backward_dw_f32(
171    encoder: &mut CommandEncoder,
172    registry: &mut KernelRegistry,
173    device: &metal::DeviceRef,
174    x: &MlxBuffer,
175    dy: &MlxBuffer,
176    dw: &MlxBuffer,
177    params: &MlxBuffer,
178    n_tokens: u32,
179    channels: u32,
180    k: u32,
181) -> Result<()> {
182    const OP: &str = "conv1d_depthwise_causal_backward_dw_f32";
183    let n = n_tokens as usize;
184    let c = channels as usize;
185    let k_us = k as usize;
186    validate_shapes(
187        OP, n_tokens, channels, k, x, dy, dw, params,
188        n * c, n * c, c * k_us,
189    )?;
190
191    let pipeline = registry.get_pipeline(OP, device)?;
192    encoder.encode(
193        pipeline,
194        &[(0, x), (1, dy), (2, dw), (3, params)],
195        MTLSize::new(channels as u64, k as u64, 1),
196        MTLSize::new(
197            std::cmp::min(32, channels as u64),
198            std::cmp::min(8, k as u64),
199            1,
200        ),
201    );
202    Ok(())
203}
204
#[cfg(test)]
mod tests {
    //! GPU-vs-CPU-oracle tests for all three kernels, a finite-difference
    //! gradient check, and argument-validation tests for the dispatchers.
    //! Every test needs a Metal device (`MlxDevice::new()`).

    use super::*;
    use crate::device::MlxDevice;

    /// Allocates a zero-filled f32 buffer of `n` elements with `shape`.
    fn alloc_f32(device: &MlxDevice, n: usize, shape: Vec<usize>) -> MlxBuffer {
        let mut b = device
            .alloc_buffer(n * 4, DType::F32, shape)
            .expect("alloc f32");
        b.as_mut_slice::<f32>().unwrap().fill(0.0);
        b
    }

    /// Allocates an f32 buffer with `shape` and copies `data` into it.
    fn upload_f32(device: &MlxDevice, data: &[f32], shape: Vec<usize>) -> MlxBuffer {
        let mut b = alloc_f32(device, data.len(), shape);
        b.as_mut_slice::<f32>().unwrap().copy_from_slice(data);
        b
    }

    /// Builds the 12-byte u32 params buffer `[n_tokens, channels, K]`.
    fn make_params(device: &MlxDevice, n_tokens: u32, channels: u32, k: u32) -> MlxBuffer {
        let mut p = device
            .alloc_buffer(12, DType::U32, vec![3])
            .expect("alloc params");
        p.as_mut_slice::<u32>()
            .unwrap()
            .copy_from_slice(&[n_tokens, channels, k]);
        p
    }

    /// Asserts element-wise closeness with tolerance `rel * max(|cpu|, 1)`.
    /// The check is relative for |cpu| >= 1 and absolute (`rel`) below that.
    fn assert_all_close(label: &str, gpu: &[f32], cpu: &[f32], rel: f32) {
        assert_eq!(gpu.len(), cpu.len(), "{label}: length mismatch");
        for (i, (&got, &want)) in gpu.iter().zip(cpu).enumerate() {
            assert!(
                (got - want).abs() < rel * want.abs().max(1.0),
                "{label}[{i}]: gpu={got} cpu={want}"
            );
        }
    }

    /// CPU oracle: forward causal depthwise conv with zero-pad on the past.
    fn forward_cpu(
        x: &[f32], kernel_w: &[f32], n: usize, c: usize, k: usize,
    ) -> Vec<f32> {
        let mut y = vec![0.0f32; n * c];
        for t in 0..n {
            for ch in 0..c {
                let mut sum = 0.0f64;
                for kk in 0..k {
                    let i_signed = (t as isize) + (kk as isize) - (k as isize - 1);
                    if i_signed < 0 {
                        continue;
                    }
                    let i = i_signed as usize;
                    sum += kernel_w[ch * k + kk] as f64 * x[i * c + ch] as f64;
                }
                y[t * c + ch] = sum as f32;
            }
        }
        y
    }

    /// CPU oracle: backward dx (transpose of the forward tap pattern).
    fn backward_dx_cpu(
        dy: &[f32], kernel_w: &[f32], n: usize, c: usize, k: usize,
    ) -> Vec<f32> {
        let mut dx = vec![0.0f32; n * c];
        for i in 0..n {
            for ch in 0..c {
                let mut sum = 0.0f64;
                for kk in 0..k {
                    let t_signed = (i as isize) + (k as isize - 1) - (kk as isize);
                    if t_signed < 0 || t_signed >= n as isize {
                        continue;
                    }
                    let t = t_signed as usize;
                    sum += kernel_w[ch * k + kk] as f64 * dy[t * c + ch] as f64;
                }
                dx[i * c + ch] = sum as f32;
            }
        }
        dx
    }

    /// CPU oracle: backward dw. Taps with `k - 1 - kk >= n` see no valid
    /// token pairs and stay zero.
    fn backward_dw_cpu(
        x: &[f32], dy: &[f32], n: usize, c: usize, k: usize,
    ) -> Vec<f32> {
        let mut dw = vec![0.0f32; c * k];
        for ch in 0..c {
            for kk in 0..k {
                let mut sum = 0.0f64;
                for t in (k - 1 - kk)..n {
                    let i = t + kk - (k - 1);
                    sum += x[i * c + ch] as f64 * dy[t * c + ch] as f64;
                }
                dw[ch * k + kk] = sum as f32;
            }
        }
        dw
    }

    #[test]
    fn forward_matches_cpu_oracle() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let (n, c, k) = (16usize, 8usize, 4usize);

        let x: Vec<f32> = (0..(n * c))
            .map(|i| ((i as f32) * 0.137 - 0.4).sin() * 0.7)
            .collect();
        let w: Vec<f32> = (0..(c * k))
            .map(|i| ((i as f32) * 0.231 + 0.1).cos() * 0.5)
            .collect();

        let x_buf = upload_f32(&device, &x, vec![n, c]);
        let w_buf = upload_f32(&device, &w, vec![c, k]);
        let y_buf = alloc_f32(&device, n * c, vec![n, c]);
        let params = make_params(&device, n as u32, c as u32, k as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_conv1d_depthwise_causal_forward_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &w_buf, &y_buf, &params,
            n as u32, c as u32, k as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu = y_buf.as_slice::<f32>().unwrap();
        assert_all_close("forward y", &gpu[..], &forward_cpu(&x, &w, n, c, k), 1e-5);
    }

    #[test]
    fn backward_dx_matches_cpu_oracle() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let (n, c, k) = (16usize, 8usize, 4usize);

        let dy: Vec<f32> = (0..(n * c)).map(|i| ((i as f32) * 0.073 - 0.3).sin() * 0.6).collect();
        let w: Vec<f32> = (0..(c * k)).map(|i| 0.1 + (i as f32) * 0.013).collect();

        let dy_buf = upload_f32(&device, &dy, vec![n, c]);
        let w_buf = upload_f32(&device, &w, vec![c, k]);
        let dx_buf = alloc_f32(&device, n * c, vec![n, c]);
        let params = make_params(&device, n as u32, c as u32, k as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_conv1d_depthwise_causal_backward_dx_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &dy_buf, &w_buf, &dx_buf, &params,
            n as u32, c as u32, k as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu = dx_buf.as_slice::<f32>().unwrap();
        assert_all_close("dx", &gpu[..], &backward_dx_cpu(&dy, &w, n, c, k), 1e-5);
    }

    #[test]
    fn backward_dw_matches_cpu_oracle() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let (n, c, k) = (32usize, 8usize, 4usize);

        let x: Vec<f32> = (0..(n * c)).map(|i| ((i as f32) * 0.041 - 0.5).cos() * 0.7).collect();
        let dy: Vec<f32> = (0..(n * c)).map(|i| ((i as f32) * 0.073 - 0.3).sin() * 0.6).collect();

        let x_buf = upload_f32(&device, &x, vec![n, c]);
        let dy_buf = upload_f32(&device, &dy, vec![n, c]);
        let dw_buf = alloc_f32(&device, c * k, vec![c, k]);
        let params = make_params(&device, n as u32, c as u32, k as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_conv1d_depthwise_causal_backward_dw_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &dy_buf, &dw_buf, &params,
            n as u32, c as u32, k as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        // dw reduces over all n tokens, so allow a looser tolerance.
        let gpu = dw_buf.as_slice::<f32>().unwrap();
        assert_all_close("dw", &gpu[..], &backward_dw_cpu(&x, &dy, n, c, k), 1e-4);
    }

    /// `K > n_tokens`: every output lies in the zero-padded prefix, and the
    /// dw taps with `k - 1 - kk >= n` have no contributing token at all.
    #[test]
    fn kernel_longer_than_sequence_matches_cpu_oracle() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let (n, c, k) = (3usize, 5usize, 6usize);
        let (nu, cu, ku) = (n as u32, c as u32, k as u32);

        let x: Vec<f32> = (0..(n * c)).map(|i| ((i as f32) * 0.29 + 0.2).sin()).collect();
        let w: Vec<f32> = (0..(c * k)).map(|i| ((i as f32) * 0.17 - 0.6).cos() * 0.4).collect();
        let dy: Vec<f32> = (0..(n * c)).map(|i| ((i as f32) * 0.11 - 0.1).cos() * 0.8).collect();

        let x_buf = upload_f32(&device, &x, vec![n, c]);
        let w_buf = upload_f32(&device, &w, vec![c, k]);
        let dy_buf = upload_f32(&device, &dy, vec![n, c]);
        let y_buf = alloc_f32(&device, n * c, vec![n, c]);
        let dx_buf = alloc_f32(&device, n * c, vec![n, c]);
        let dw_buf = alloc_f32(&device, c * k, vec![c, k]);
        let params = make_params(&device, nu, cu, ku);

        // The three dispatches write disjoint outputs and only read x, w, dy.
        let mut encoder = device.command_encoder().unwrap();
        dispatch_conv1d_depthwise_causal_forward_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &w_buf, &y_buf, &params, nu, cu, ku,
        ).unwrap();
        dispatch_conv1d_depthwise_causal_backward_dx_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &dy_buf, &w_buf, &dx_buf, &params, nu, cu, ku,
        ).unwrap();
        dispatch_conv1d_depthwise_causal_backward_dw_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &dy_buf, &dw_buf, &params, nu, cu, ku,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let y = y_buf.as_slice::<f32>().unwrap();
        assert_all_close("long-K y", &y[..], &forward_cpu(&x, &w, n, c, k), 1e-5);
        let dx = dx_buf.as_slice::<f32>().unwrap();
        assert_all_close("long-K dx", &dx[..], &backward_dx_cpu(&dy, &w, n, c, k), 1e-5);
        let dw = dw_buf.as_slice::<f32>().unwrap();
        assert_all_close("long-K dw", &dw[..], &backward_dw_cpu(&x, &dy, n, c, k), 1e-5);
    }

    /// Finite-difference falsifier. The GPU dx/dw computed with `dy = ones`
    /// must match central differences of `loss = sum(forward_cpu(x, w))`.
    ///
    /// The tolerance is `1e-2 * max(|fd|, 1)`: 1% relative when |fd| >= 1 and
    /// 0.01 absolute otherwise. The loss is linear in x for fixed w and linear
    /// in w for fixed x, so central differences are exact up to rounding.
    /// The numerical side uses the CPU forward oracle; agreement of the GPU
    /// forward pass is covered by `forward_matches_cpu_oracle`.
    /// This is THE load-bearing correctness gate.
    #[test]
    fn backward_finite_difference_falsifier() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let (n, c, k) = (8usize, 4usize, 3usize);

        let x: Vec<f32> = (0..(n * c)).map(|i| ((i as f32) * 0.137).sin() * 0.6).collect();
        let w: Vec<f32> = (0..(c * k)).map(|i| 0.2 + (i as f32) * 0.05).collect();

        let forward_loss = |x: &[f32], w: &[f32]| -> f64 {
            let y = forward_cpu(x, w, n, c, k);
            y.iter().map(|v| *v as f64).sum::<f64>()
        };

        // Analytic gradients via dy = ones (d loss / d y for a plain sum).
        let dy_ones = vec![1.0f32; n * c];
        let dy_buf = upload_f32(&device, &dy_ones, vec![n, c]);
        let x_buf = upload_f32(&device, &x, vec![n, c]);
        let w_buf = upload_f32(&device, &w, vec![c, k]);
        let dx_buf = alloc_f32(&device, n * c, vec![n, c]);
        let dw_buf = alloc_f32(&device, c * k, vec![c, k]);
        let params = make_params(&device, n as u32, c as u32, k as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_conv1d_depthwise_causal_backward_dx_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &dy_buf, &w_buf, &dx_buf, &params,
            n as u32, c as u32, k as u32,
        ).unwrap();
        dispatch_conv1d_depthwise_causal_backward_dw_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &dy_buf, &dw_buf, &params,
            n as u32, c as u32, k as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let dx = dx_buf.as_slice::<f32>().unwrap().to_vec();
        let dw = dw_buf.as_slice::<f32>().unwrap().to_vec();

        // FD on x.
        let h = 1e-3f64;
        for i in 0..(n * c) {
            let mut xp = x.clone();
            xp[i] += h as f32;
            let mut xm = x.clone();
            xm[i] -= h as f32;
            let fd = (forward_loss(&xp, &w) - forward_loss(&xm, &w)) / (2.0 * h);
            let tol = 1e-2 * fd.abs().max(1.0);
            assert!(
                (dx[i] as f64 - fd).abs() < tol,
                "FD x[{i}]: analytic={} fd={}", dx[i], fd
            );
        }
        // FD on w.
        for i in 0..(c * k) {
            let mut wp = w.clone();
            wp[i] += h as f32;
            let mut wm = w.clone();
            wm[i] -= h as f32;
            let fd = (forward_loss(&x, &wp) - forward_loss(&x, &wm)) / (2.0 * h);
            let tol = 1e-2 * fd.abs().max(1.0);
            assert!(
                (dw[i] as f64 - fd).abs() < tol,
                "FD w[{i}]: analytic={} fd={}", dw[i], fd
            );
        }
    }

    /// Every zero dimension must be rejected by every entry point. Validation
    /// fails before any pipeline lookup, so nothing is encoded.
    #[test]
    fn rejects_zero_dimensions() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let a = alloc_f32(&device, 1, vec![1, 1]);
        let b = alloc_f32(&device, 1, vec![1, 1]);
        let out = alloc_f32(&device, 1, vec![1, 1]);
        let mut encoder = device.command_encoder().unwrap();

        for (n, c, k) in [(0u32, 1u32, 1u32), (1, 0, 1), (1, 1, 0)] {
            let params = make_params(&device, n, c, k);
            let fwd = dispatch_conv1d_depthwise_causal_forward_f32(
                &mut encoder, &mut registry, device.metal_device(),
                &a, &b, &out, &params, n, c, k,
            );
            assert!(fwd.is_err(), "forward accepted dims ({n}, {c}, {k})");
            let bdx = dispatch_conv1d_depthwise_causal_backward_dx_f32(
                &mut encoder, &mut registry, device.metal_device(),
                &a, &b, &out, &params, n, c, k,
            );
            assert!(bdx.is_err(), "backward_dx accepted dims ({n}, {c}, {k})");
            let bdw = dispatch_conv1d_depthwise_causal_backward_dw_f32(
                &mut encoder, &mut registry, device.metal_device(),
                &a, &b, &out, &params, n, c, k,
            );
            assert!(bdw.is_err(), "backward_dw accepted dims ({n}, {c}, {k})");
        }
    }

    /// Each entry point must reject wrong element counts, a non-f32 I/O buffer,
    /// and an undersized params buffer before any pipeline lookup.
    #[test]
    fn rejects_mismatched_buffers() {
        let device = MlxDevice::new().expect("device");
        let mut registry = KernelRegistry::new();
        let (n, c, k) = (4usize, 2usize, 3usize);
        let (nu, cu, ku) = (n as u32, c as u32, k as u32);

        let act = alloc_f32(&device, n * c, vec![n, c]);
        let out = alloc_f32(&device, n * c, vec![n, c]);
        let taps = alloc_f32(&device, c * k, vec![c, k]);
        let short_act = alloc_f32(&device, n * c - 1, vec![n * c - 1]);
        let long_taps = alloc_f32(&device, c * k + 1, vec![c * k + 1]);
        let u32_act = device
            .alloc_buffer(n * c * 4, DType::U32, vec![n, c])
            .expect("alloc u32");
        let params = make_params(&device, nu, cu, ku);
        let short_params = device
            .alloc_buffer(8, DType::U32, vec![2])
            .expect("alloc short params");
        let dev = device.metal_device();
        let mut enc = device.command_encoder().unwrap();

        // Forward: x too short, w too long, y not f32, params too small.
        assert!(dispatch_conv1d_depthwise_causal_forward_f32(
            &mut enc, &mut registry, dev, &short_act, &taps, &out, &params, nu, cu, ku,
        ).is_err());
        assert!(dispatch_conv1d_depthwise_causal_forward_f32(
            &mut enc, &mut registry, dev, &act, &long_taps, &out, &params, nu, cu, ku,
        ).is_err());
        assert!(dispatch_conv1d_depthwise_causal_forward_f32(
            &mut enc, &mut registry, dev, &act, &taps, &u32_act, &params, nu, cu, ku,
        ).is_err());
        assert!(dispatch_conv1d_depthwise_causal_forward_f32(
            &mut enc, &mut registry, dev, &act, &taps, &out, &short_params, nu, cu, ku,
        ).is_err());

        // Backward dx: dx too short.
        assert!(dispatch_conv1d_depthwise_causal_backward_dx_f32(
            &mut enc, &mut registry, dev, &act, &taps, &short_act, &params, nu, cu, ku,
        ).is_err());

        // Backward dw: dw too long, and the dy/dw roles swapped.
        assert!(dispatch_conv1d_depthwise_causal_backward_dw_f32(
            &mut enc, &mut registry, dev, &act, &out, &long_taps, &params, nu, cu, ku,
        ).is_err());
        assert!(dispatch_conv1d_depthwise_causal_backward_dw_f32(
            &mut enc, &mut registry, dev, &act, &taps, &out, &params, nu, cu, ku,
        ).is_err());
    }
}