// mlx_native/ops/divide_elementwise.rs

1//! ADR-020 iter-11h-misc-1 — elementwise division forward + backward.
2
3use metal::MTLSize;
4
5use crate::buffer::MlxBuffer;
6use crate::dtypes::DType;
7use crate::encoder::CommandEncoder;
8use crate::error::{MlxError, Result};
9use crate::kernel_registry::KernelRegistry;
10
/// Metal shading-language source backing both the forward (`divide_f32`) and
/// backward (`divide_backward_f32`) pipelines; registered under both names in
/// [`register`] and compiled on demand by the kernel registry.
pub static DIVIDE_ELEMENTWISE_SHADER_SOURCE: &str =
    include_str!("../shaders/divide_elementwise.metal");
13
14pub fn register(registry: &mut KernelRegistry) {
15    registry.register_source("divide_f32", DIVIDE_ELEMENTWISE_SHADER_SOURCE);
16    registry.register_source(
17        "divide_backward_f32",
18        DIVIDE_ELEMENTWISE_SHADER_SOURCE,
19    );
20}
21
22pub fn dispatch_divide_f32(
23    encoder: &mut CommandEncoder,
24    registry: &mut KernelRegistry,
25    device: &metal::DeviceRef,
26    a: &MlxBuffer,
27    b: &MlxBuffer,
28    y: &MlxBuffer,
29    params: &MlxBuffer,
30) -> Result<()> {
31    const OP: &str = "divide_f32";
32    let n = a.element_count();
33    if n == 0 {
34        return Err(MlxError::InvalidArgument(format!("{OP}: empty input")));
35    }
36    if b.element_count() != n || y.element_count() != n {
37        return Err(MlxError::InvalidArgument(format!(
38            "{OP}: shape mismatch (a={}, b={}, y={})",
39            n, b.element_count(), y.element_count()
40        )));
41    }
42    if a.dtype() != DType::F32 || b.dtype() != DType::F32 || y.dtype() != DType::F32 {
43        return Err(MlxError::InvalidArgument(format!("{OP}: must be f32")));
44    }
45    if params.byte_len() < 4 {
46        return Err(MlxError::InvalidArgument(format!(
47            "{OP}: params < 4 bytes"
48        )));
49    }
50    let pipeline = registry.get_pipeline(OP, device)?;
51    let n_u64 = n as u64;
52    encoder.encode(
53        pipeline,
54        &[(0, a), (1, b), (2, y), (3, params)],
55        MTLSize::new(n_u64, 1, 1),
56        MTLSize::new(std::cmp::min(256, n_u64), 1, 1),
57    );
58    Ok(())
59}
60
61#[allow(clippy::too_many_arguments)]
62pub fn dispatch_divide_backward_f32(
63    encoder: &mut CommandEncoder,
64    registry: &mut KernelRegistry,
65    device: &metal::DeviceRef,
66    b: &MlxBuffer,
67    y: &MlxBuffer,
68    dy: &MlxBuffer,
69    da: &MlxBuffer,
70    db: &MlxBuffer,
71    params: &MlxBuffer,
72) -> Result<()> {
73    const OP: &str = "divide_backward_f32";
74    let n = b.element_count();
75    if y.element_count() != n
76        || dy.element_count() != n
77        || da.element_count() != n
78        || db.element_count() != n
79    {
80        return Err(MlxError::InvalidArgument(format!(
81            "{OP}: shape mismatch n={n}, b/y/dy/da/db must match"
82        )));
83    }
84    if b.dtype() != DType::F32 || y.dtype() != DType::F32 || dy.dtype() != DType::F32
85        || da.dtype() != DType::F32 || db.dtype() != DType::F32
86    {
87        return Err(MlxError::InvalidArgument(format!("{OP}: must be f32")));
88    }
89    let pipeline = registry.get_pipeline(OP, device)?;
90    let n_u64 = n as u64;
91    encoder.encode(
92        pipeline,
93        &[(0, b), (1, y), (2, dy), (3, da), (4, db), (5, params)],
94        MTLSize::new(n_u64, 1, 1),
95        MTLSize::new(std::cmp::min(256, n_u64), 1, 1),
96    );
97    Ok(())
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// Allocates an n-element f32 buffer on `d` and zero-fills it so tests
    /// never read uninitialized GPU memory.
    fn alloc_f32(d: &MlxDevice, n: usize) -> MlxBuffer {
        let mut bx = d.alloc_buffer(n * 4, DType::F32, vec![n]).unwrap();
        bx.as_mut_slice::<f32>().unwrap().fill(0.0);
        bx
    }
    /// Builds the 4-byte u32 params buffer (element count) both kernels read.
    fn make_params(d: &MlxDevice, n: u32) -> MlxBuffer {
        let mut p = d.alloc_buffer(4, DType::U32, vec![1]).unwrap();
        p.as_mut_slice::<u32>().unwrap()[0] = n;
        p
    }

    /// GPU forward result must match an elementwise CPU division oracle.
    /// Inputs are chosen strictly positive so the divisor is never near zero.
    #[test]
    fn forward_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 32usize;
        let a: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..n).map(|i| 1.0 + (i as f32) * 0.07).collect();

        let mut a_buf = alloc_f32(&device, n);
        a_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&a);
        let mut b_buf = alloc_f32(&device, n);
        b_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&b);
        let y_buf = alloc_f32(&device, n);
        let p = make_params(&device, n as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_divide_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &a_buf, &b_buf, &y_buf, &p,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu = y_buf.as_slice::<f32>().unwrap();
        for i in 0..n {
            let cpu = a[i] / b[i];
            // Relative tolerance with an absolute floor of 1e-6 for small values.
            assert!(
                (gpu[i] - cpu).abs() < 1e-6 * cpu.abs().max(1.0),
                "y[{i}]: gpu={} cpu={}",
                gpu[i], cpu
            );
        }
    }

    /// Validates analytic gradients against central finite differences of the
    /// scalar loss sum(a/b); dy = 1 everywhere so da/db are the raw gradients.
    #[test]
    fn backward_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 16usize;
        let a: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.05).collect();
        let b: Vec<f32> = (0..n).map(|i| 1.0 + (i as f32) * 0.07).collect();
        let dy: Vec<f32> = vec![1.0; n];

        let mut a_buf = alloc_f32(&device, n);
        a_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&a);
        let mut b_buf = alloc_f32(&device, n);
        b_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&b);
        let y_buf = alloc_f32(&device, n);
        let mut dy_buf = alloc_f32(&device, n);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy);
        let da_buf = alloc_f32(&device, n);
        let db_buf = alloc_f32(&device, n);
        let p = make_params(&device, n as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_divide_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &a_buf, &b_buf, &y_buf, &p,
        ).unwrap();
        // RAW barrier — backward reads y written by forward.  Apple Metal
        // compute encoders run threadgroups in parallel by default;
        // without barriers the backward kernel reads zero-init y values
        // (per `feedback_metal_raw_barrier_per_dispatch` memory).
        encoder.memory_barrier();
        dispatch_divide_backward_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &b_buf, &y_buf, &dy_buf, &da_buf, &db_buf, &p,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let da = da_buf.as_slice::<f32>().unwrap().to_vec();
        let db = db_buf.as_slice::<f32>().unwrap().to_vec();

        // FD on a and b. Loss = sum(a/b).
        // FD arithmetic is done in f64 to keep truncation error well below
        // the 1e-3 relative tolerance used for the f32 GPU gradients.
        let h = 1e-3f64;
        let loss = |aa: &[f32], bb: &[f32]| -> f64 {
            (0..n).map(|i| aa[i] as f64 / bb[i] as f64).sum::<f64>()
        };
        for i in 0..n {
            let mut ap = a.clone(); ap[i] += h as f32;
            let mut am = a.clone(); am[i] -= h as f32;
            let fd = (loss(&ap, &b) - loss(&am, &b)) / (2.0 * h);
            let tol = 1e-3 * fd.abs().max(1.0);
            assert!(
                (da[i] as f64 - fd).abs() < tol,
                "FD a[{i}]: analytic={} fd={}", da[i], fd
            );
        }
        for i in 0..n {
            let mut bp = b.clone(); bp[i] += h as f32;
            let mut bm = b.clone(); bm[i] -= h as f32;
            let fd = (loss(&a, &bp) - loss(&a, &bm)) / (2.0 * h);
            let tol = 1e-3 * fd.abs().max(1.0);
            assert!(
                (db[i] as f64 - fd).abs() < tol,
                "FD b[{i}]: analytic={} fd={}", db[i], fd
            );
        }
    }

    /// Mismatched element counts must be rejected before anything is encoded.
    #[test]
    fn rejects_size_mismatch() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let a = alloc_f32(&device, 16);
        let b = alloc_f32(&device, 8); // wrong
        let y = alloc_f32(&device, 16);
        let p = make_params(&device, 16);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_divide_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &a, &b, &y, &p,
        );
        assert!(res.is_err());
    }
}