// File: mlx_native/ops/outer_product.rs
//! ADR-020 iter-11h-c2 — vector outer product forward + backward.
//!
//! Forward:  `y[i, j] = lhs[i] · rhs[j]`
//! Backward: `dlhs[i] = Σ_j dy[i, j] · rhs[j]`
//!           `drhs[j] = Σ_i dy[i, j] · lhs[i]`
//!
//! Distinct from matmul: matmul kernel has a 32-element floor on each
//! dim (`M, N, K ≥ 32` for dW backward); outer products have
//! inner-dim = 1, falling below that floor.
10
11use metal::MTLSize;
12
13use crate::buffer::MlxBuffer;
14use crate::dtypes::DType;
15use crate::encoder::CommandEncoder;
16use crate::error::{MlxError, Result};
17use crate::kernel_registry::KernelRegistry;
18
/// Metal shader source shared by the forward kernel and both backward
/// kernels (all three entry points live in the same `.metal` file).
pub static OUTER_PRODUCT_SHADER_SOURCE: &str =
    include_str!("../shaders/outer_product.metal");
22pub fn register(registry: &mut KernelRegistry) {
23    registry.register_source("outer_product_f32", OUTER_PRODUCT_SHADER_SOURCE);
24    registry.register_source(
25        "outer_product_backward_lhs_f32",
26        OUTER_PRODUCT_SHADER_SOURCE,
27    );
28    registry.register_source(
29        "outer_product_backward_rhs_f32",
30        OUTER_PRODUCT_SHADER_SOURCE,
31    );
32}
33
34fn validate_dims(op: &str, n: u32, m: u32, params: &MlxBuffer) -> Result<()> {
35    if n == 0 || m == 0 {
36        return Err(MlxError::InvalidArgument(format!(
37            "{op}: N and M must both be > 0 (got {n}, {m})"
38        )));
39    }
40    if params.byte_len() < 8 {
41        return Err(MlxError::InvalidArgument(format!(
42            "{op}: params < 8 bytes (need 2 × u32)"
43        )));
44    }
45    Ok(())
46}
47
48pub fn dispatch_outer_product_f32(
49    encoder: &mut CommandEncoder,
50    registry: &mut KernelRegistry,
51    device: &metal::DeviceRef,
52    lhs: &MlxBuffer,
53    rhs: &MlxBuffer,
54    y: &MlxBuffer,
55    params: &MlxBuffer,
56    n: u32,
57    m: u32,
58) -> Result<()> {
59    const OP: &str = "outer_product_f32";
60    validate_dims(OP, n, m, params)?;
61    if lhs.dtype() != DType::F32 || rhs.dtype() != DType::F32 || y.dtype() != DType::F32 {
62        return Err(MlxError::InvalidArgument(format!(
63            "{OP}: all buffers must be f32"
64        )));
65    }
66    if lhs.element_count() != n as usize {
67        return Err(MlxError::InvalidArgument(format!(
68            "{OP}: lhs.element_count {} != N {n}",
69            lhs.element_count()
70        )));
71    }
72    if rhs.element_count() != m as usize {
73        return Err(MlxError::InvalidArgument(format!(
74            "{OP}: rhs.element_count {} != M {m}",
75            rhs.element_count()
76        )));
77    }
78    if y.element_count() != (n as usize) * (m as usize) {
79        return Err(MlxError::InvalidArgument(format!(
80            "{OP}: y.element_count {} != N*M = {}",
81            y.element_count(),
82            n as usize * m as usize
83        )));
84    }
85
86    let pipeline = registry.get_pipeline(OP, device)?;
87    encoder.encode(
88        pipeline,
89        &[(0, lhs), (1, rhs), (2, y), (3, params)],
90        MTLSize::new(n as u64, m as u64, 1),
91        MTLSize::new(
92            std::cmp::min(16, n as u64),
93            std::cmp::min(16, m as u64),
94            1,
95        ),
96    );
97    Ok(())
98}
99
100pub fn dispatch_outer_product_backward_lhs_f32(
101    encoder: &mut CommandEncoder,
102    registry: &mut KernelRegistry,
103    device: &metal::DeviceRef,
104    dy: &MlxBuffer,
105    rhs: &MlxBuffer,
106    dlhs: &MlxBuffer,
107    params: &MlxBuffer,
108    n: u32,
109    m: u32,
110) -> Result<()> {
111    const OP: &str = "outer_product_backward_lhs_f32";
112    validate_dims(OP, n, m, params)?;
113    if dy.element_count() != (n as usize) * (m as usize)
114        || rhs.element_count() != m as usize
115        || dlhs.element_count() != n as usize
116    {
117        return Err(MlxError::InvalidArgument(format!(
118            "{OP}: shape mismatch (dy={}, rhs={}, dlhs={})",
119            dy.element_count(),
120            rhs.element_count(),
121            dlhs.element_count()
122        )));
123    }
124
125    let pipeline = registry.get_pipeline(OP, device)?;
126    encoder.encode(
127        pipeline,
128        &[(0, dy), (1, rhs), (2, dlhs), (3, params)],
129        MTLSize::new(n as u64, 1, 1),
130        MTLSize::new(std::cmp::min(64, n as u64), 1, 1),
131    );
132    Ok(())
133}
134
135pub fn dispatch_outer_product_backward_rhs_f32(
136    encoder: &mut CommandEncoder,
137    registry: &mut KernelRegistry,
138    device: &metal::DeviceRef,
139    dy: &MlxBuffer,
140    lhs: &MlxBuffer,
141    drhs: &MlxBuffer,
142    params: &MlxBuffer,
143    n: u32,
144    m: u32,
145) -> Result<()> {
146    const OP: &str = "outer_product_backward_rhs_f32";
147    validate_dims(OP, n, m, params)?;
148    if dy.element_count() != (n as usize) * (m as usize)
149        || lhs.element_count() != n as usize
150        || drhs.element_count() != m as usize
151    {
152        return Err(MlxError::InvalidArgument(format!(
153            "{OP}: shape mismatch (dy={}, lhs={}, drhs={})",
154            dy.element_count(),
155            lhs.element_count(),
156            drhs.element_count()
157        )));
158    }
159
160    let pipeline = registry.get_pipeline(OP, device)?;
161    encoder.encode(
162        pipeline,
163        &[(0, dy), (1, lhs), (2, drhs), (3, params)],
164        MTLSize::new(m as u64, 1, 1),
165        MTLSize::new(std::cmp::min(64, m as u64), 1, 1),
166    );
167    Ok(())
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// Allocates a zero-initialised f32 buffer of `n` elements with `shape`.
    fn alloc_f32(device: &MlxDevice, n: usize, shape: Vec<usize>) -> MlxBuffer {
        let mut b = device.alloc_buffer(n * 4, DType::F32, shape).unwrap();
        b.as_mut_slice::<f32>().unwrap().fill(0.0);
        b
    }

    /// Packs `[n, m]` into the 8-byte u32 params buffer the kernels read.
    fn make_params(device: &MlxDevice, n: u32, m: u32) -> MlxBuffer {
        let mut p = device.alloc_buffer(8, DType::U32, vec![2]).unwrap();
        p.as_mut_slice::<u32>().unwrap().copy_from_slice(&[n, m]);
        p
    }

    /// GPU forward result must match an element-wise CPU oracle.
    /// N=8, M=5 deliberately sit below matmul's 32-element floor (see
    /// module docs) to exercise the dedicated outer-product kernel.
    #[test]
    fn forward_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 8usize;
        let m = 5usize;
        // Deterministic non-trivial inputs (rhs includes negative values).
        let lhs: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.1).collect();
        let rhs: Vec<f32> = (0..m).map(|i| ((i as f32) * 0.137 - 0.3)).collect();

        let mut lhs_buf = alloc_f32(&device, n, vec![n]);
        lhs_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&lhs);
        let mut rhs_buf = alloc_f32(&device, m, vec![m]);
        rhs_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&rhs);
        let y_buf = alloc_f32(&device, n * m, vec![n, m]);
        let params = make_params(&device, n as u32, m as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_outer_product_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &lhs_buf, &rhs_buf, &y_buf, &params, n as u32, m as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu = y_buf.as_slice::<f32>().unwrap();
        for i in 0..n {
            for j in 0..m {
                let expected = lhs[i] * rhs[j];
                // Relative tolerance with an absolute floor of 1e-6.
                assert!(
                    (gpu[i * m + j] - expected).abs() < 1e-6 * expected.abs().max(1.0),
                    "y[{i},{j}]: gpu={} expected={}",
                    gpu[i * m + j], expected
                );
            }
        }
    }

    /// Both backward kernels must match f64 CPU reduction oracles.
    /// Oracle sums are accumulated in f64 to keep the reference tighter
    /// than the f32 GPU result under test.
    #[test]
    fn backward_dlhs_drhs_match_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 8usize;
        let m = 5usize;
        let lhs: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.1).collect();
        let rhs: Vec<f32> = (0..m).map(|i| 0.2 + (i as f32) * 0.07).collect();
        // sin() gives a sign-varying upstream gradient.
        let dy: Vec<f32> = (0..(n * m)).map(|i| ((i as f32) * 0.131 - 0.4).sin()).collect();

        let mut lhs_buf = alloc_f32(&device, n, vec![n]);
        lhs_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&lhs);
        let mut rhs_buf = alloc_f32(&device, m, vec![m]);
        rhs_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&rhs);
        let mut dy_buf = alloc_f32(&device, n * m, vec![n, m]);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy);
        let dlhs_buf = alloc_f32(&device, n, vec![n]);
        let drhs_buf = alloc_f32(&device, m, vec![m]);
        let params = make_params(&device, n as u32, m as u32);

        // Both backward dispatches are encoded into one command buffer;
        // they read disjoint outputs, so ordering between them is free.
        let mut encoder = device.command_encoder().unwrap();
        dispatch_outer_product_backward_lhs_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &dy_buf, &rhs_buf, &dlhs_buf, &params, n as u32, m as u32,
        ).unwrap();
        dispatch_outer_product_backward_rhs_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &dy_buf, &lhs_buf, &drhs_buf, &params, n as u32, m as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let dlhs = dlhs_buf.as_slice::<f32>().unwrap();
        let drhs = drhs_buf.as_slice::<f32>().unwrap();
        for i in 0..n {
            let expected: f64 = (0..m).map(|j| dy[i * m + j] as f64 * rhs[j] as f64).sum();
            assert!(
                (dlhs[i] as f64 - expected).abs() < 1e-5 * expected.abs().max(1.0),
                "dlhs[{i}]: gpu={} expected={}",
                dlhs[i], expected
            );
        }
        for j in 0..m {
            let expected: f64 = (0..n).map(|i| dy[i * m + j] as f64 * lhs[i] as f64).sum();
            assert!(
                (drhs[j] as f64 - expected).abs() < 1e-5 * expected.abs().max(1.0),
                "drhs[{j}]: gpu={} expected={}",
                drhs[j], expected
            );
        }
    }

    /// FD falsifier: loss = sum(outer(lhs, rhs)) = sum(lhs) * sum(rhs).
    /// Analytic dlhs[i] = sum(rhs), drhs[j] = sum(lhs). Verify FD matches.
    #[test]
    fn backward_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 6usize;
        let m = 4usize;
        let lhs: Vec<f32> = (0..n).map(|i| 0.3 + (i as f32) * 0.07).collect();
        let rhs: Vec<f32> = (0..m).map(|i| 0.5 + (i as f32) * 0.05).collect();

        let mut lhs_buf = alloc_f32(&device, n, vec![n]);
        lhs_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&lhs);
        let mut rhs_buf = alloc_f32(&device, m, vec![m]);
        rhs_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&rhs);
        // dy = ones makes the backward outputs equal the loss gradients.
        let dy_ones = vec![1.0f32; n * m];
        let mut dy_buf = alloc_f32(&device, n * m, vec![n, m]);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy_ones);
        let dlhs_buf = alloc_f32(&device, n, vec![n]);
        let drhs_buf = alloc_f32(&device, m, vec![m]);
        let params = make_params(&device, n as u32, m as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_outer_product_backward_lhs_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &dy_buf, &rhs_buf, &dlhs_buf, &params, n as u32, m as u32,
        ).unwrap();
        dispatch_outer_product_backward_rhs_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &dy_buf, &lhs_buf, &drhs_buf, &params, n as u32, m as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let dlhs = dlhs_buf.as_slice::<f32>().unwrap().to_vec();
        let drhs = drhs_buf.as_slice::<f32>().unwrap().to_vec();

        // Central differences: step h = 1e-3, loose 1% tolerance because
        // the perturbation is applied to f32 inputs.
        let h = 1e-3f64;
        let loss = |l: &[f32], r: &[f32]| -> f64 {
            let mut s = 0.0f64;
            for i in 0..n { for j in 0..m { s += l[i] as f64 * r[j] as f64; } }
            s
        };
        for i in 0..n {
            let mut lp = lhs.clone(); lp[i] += h as f32;
            let mut lm = lhs.clone(); lm[i] -= h as f32;
            let fd = (loss(&lp, &rhs) - loss(&lm, &rhs)) / (2.0 * h);
            let tol = 1e-2 * fd.abs().max(1.0);
            assert!(
                (dlhs[i] as f64 - fd).abs() < tol,
                "FD lhs[{i}]: analytic={} fd={}", dlhs[i], fd
            );
        }
        for j in 0..m {
            let mut rp = rhs.clone(); rp[j] += h as f32;
            let mut rm = rhs.clone(); rm[j] -= h as f32;
            let fd = (loss(&lhs, &rp) - loss(&lhs, &rm)) / (2.0 * h);
            let tol = 1e-2 * fd.abs().max(1.0);
            assert!(
                (drhs[j] as f64 - fd).abs() < tol,
                "FD rhs[{j}]: analytic={} fd={}", drhs[j], fd
            );
        }
    }

    /// Zero dims must be rejected before anything is encoded.
    #[test]
    fn rejects_zero_dims() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let lhs = alloc_f32(&device, 1, vec![1]);
        let rhs = alloc_f32(&device, 1, vec![1]);
        let y = alloc_f32(&device, 1, vec![1, 1]);
        let params = make_params(&device, 0, 1);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_outer_product_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &lhs, &rhs, &y, &params, 0, 1,
        );
        assert!(res.is_err());
    }
}
352}