1use metal::MTLSize;
12
13use crate::buffer::MlxBuffer;
14use crate::dtypes::DType;
15use crate::encoder::CommandEncoder;
16use crate::error::{MlxError, Result};
17use crate::kernel_registry::KernelRegistry;
18
/// Metal Shading Language source embedded at compile time. A single shader
/// file provides all three entry points registered in `register`: the
/// forward outer product and both backward (gradient) kernels.
pub static OUTER_PRODUCT_SHADER_SOURCE: &str =
    include_str!("../shaders/outer_product.metal");
21
22pub fn register(registry: &mut KernelRegistry) {
23 registry.register_source("outer_product_f32", OUTER_PRODUCT_SHADER_SOURCE);
24 registry.register_source(
25 "outer_product_backward_lhs_f32",
26 OUTER_PRODUCT_SHADER_SOURCE,
27 );
28 registry.register_source(
29 "outer_product_backward_rhs_f32",
30 OUTER_PRODUCT_SHADER_SOURCE,
31 );
32}
33
34fn validate_dims(op: &str, n: u32, m: u32, params: &MlxBuffer) -> Result<()> {
35 if n == 0 || m == 0 {
36 return Err(MlxError::InvalidArgument(format!(
37 "{op}: N and M must both be > 0 (got {n}, {m})"
38 )));
39 }
40 if params.byte_len() < 8 {
41 return Err(MlxError::InvalidArgument(format!(
42 "{op}: params < 8 bytes (need 2 × u32)"
43 )));
44 }
45 Ok(())
46}
47
48pub fn dispatch_outer_product_f32(
49 encoder: &mut CommandEncoder,
50 registry: &mut KernelRegistry,
51 device: &metal::DeviceRef,
52 lhs: &MlxBuffer,
53 rhs: &MlxBuffer,
54 y: &MlxBuffer,
55 params: &MlxBuffer,
56 n: u32,
57 m: u32,
58) -> Result<()> {
59 const OP: &str = "outer_product_f32";
60 validate_dims(OP, n, m, params)?;
61 if lhs.dtype() != DType::F32 || rhs.dtype() != DType::F32 || y.dtype() != DType::F32 {
62 return Err(MlxError::InvalidArgument(format!(
63 "{OP}: all buffers must be f32"
64 )));
65 }
66 if lhs.element_count() != n as usize {
67 return Err(MlxError::InvalidArgument(format!(
68 "{OP}: lhs.element_count {} != N {n}",
69 lhs.element_count()
70 )));
71 }
72 if rhs.element_count() != m as usize {
73 return Err(MlxError::InvalidArgument(format!(
74 "{OP}: rhs.element_count {} != M {m}",
75 rhs.element_count()
76 )));
77 }
78 if y.element_count() != (n as usize) * (m as usize) {
79 return Err(MlxError::InvalidArgument(format!(
80 "{OP}: y.element_count {} != N*M = {}",
81 y.element_count(),
82 n as usize * m as usize
83 )));
84 }
85
86 let pipeline = registry.get_pipeline(OP, device)?;
87 encoder.encode(
88 pipeline,
89 &[(0, lhs), (1, rhs), (2, y), (3, params)],
90 MTLSize::new(n as u64, m as u64, 1),
91 MTLSize::new(
92 std::cmp::min(16, n as u64),
93 std::cmp::min(16, m as u64),
94 1,
95 ),
96 );
97 Ok(())
98}
99
100pub fn dispatch_outer_product_backward_lhs_f32(
101 encoder: &mut CommandEncoder,
102 registry: &mut KernelRegistry,
103 device: &metal::DeviceRef,
104 dy: &MlxBuffer,
105 rhs: &MlxBuffer,
106 dlhs: &MlxBuffer,
107 params: &MlxBuffer,
108 n: u32,
109 m: u32,
110) -> Result<()> {
111 const OP: &str = "outer_product_backward_lhs_f32";
112 validate_dims(OP, n, m, params)?;
113 if dy.element_count() != (n as usize) * (m as usize)
114 || rhs.element_count() != m as usize
115 || dlhs.element_count() != n as usize
116 {
117 return Err(MlxError::InvalidArgument(format!(
118 "{OP}: shape mismatch (dy={}, rhs={}, dlhs={})",
119 dy.element_count(),
120 rhs.element_count(),
121 dlhs.element_count()
122 )));
123 }
124
125 let pipeline = registry.get_pipeline(OP, device)?;
126 encoder.encode(
127 pipeline,
128 &[(0, dy), (1, rhs), (2, dlhs), (3, params)],
129 MTLSize::new(n as u64, 1, 1),
130 MTLSize::new(std::cmp::min(64, n as u64), 1, 1),
131 );
132 Ok(())
133}
134
135pub fn dispatch_outer_product_backward_rhs_f32(
136 encoder: &mut CommandEncoder,
137 registry: &mut KernelRegistry,
138 device: &metal::DeviceRef,
139 dy: &MlxBuffer,
140 lhs: &MlxBuffer,
141 drhs: &MlxBuffer,
142 params: &MlxBuffer,
143 n: u32,
144 m: u32,
145) -> Result<()> {
146 const OP: &str = "outer_product_backward_rhs_f32";
147 validate_dims(OP, n, m, params)?;
148 if dy.element_count() != (n as usize) * (m as usize)
149 || lhs.element_count() != n as usize
150 || drhs.element_count() != m as usize
151 {
152 return Err(MlxError::InvalidArgument(format!(
153 "{OP}: shape mismatch (dy={}, lhs={}, drhs={})",
154 dy.element_count(),
155 lhs.element_count(),
156 drhs.element_count()
157 )));
158 }
159
160 let pipeline = registry.get_pipeline(OP, device)?;
161 encoder.encode(
162 pipeline,
163 &[(0, dy), (1, lhs), (2, drhs), (3, params)],
164 MTLSize::new(m as u64, 1, 1),
165 MTLSize::new(std::cmp::min(64, m as u64), 1, 1),
166 );
167 Ok(())
168}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// Allocates an `n`-element f32 buffer with the given logical shape,
    /// zero-filled so tests never observe uninitialized memory.
    fn alloc_f32(device: &MlxDevice, n: usize, shape: Vec<usize>) -> MlxBuffer {
        // n * 4: f32 occupies 4 bytes per element.
        let mut b = device.alloc_buffer(n * 4, DType::F32, shape).unwrap();
        b.as_mut_slice::<f32>().unwrap().fill(0.0);
        b
    }

    /// Builds the 8-byte params buffer holding [n, m] as two u32s — the
    /// minimum size `validate_dims` enforces.
    fn make_params(device: &MlxDevice, n: u32, m: u32) -> MlxBuffer {
        let mut p = device.alloc_buffer(8, DType::U32, vec![2]).unwrap();
        p.as_mut_slice::<u32>().unwrap().copy_from_slice(&[n, m]);
        p
    }

    /// Forward kernel vs. scalar CPU oracle: y[i*m + j] == lhs[i] * rhs[j].
    #[test]
    fn forward_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 8usize;
        let m = 5usize;
        // Deterministic, non-degenerate inputs (no zeros in lhs, mixed signs in rhs).
        let lhs: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.1).collect();
        let rhs: Vec<f32> = (0..m).map(|i| ((i as f32) * 0.137 - 0.3)).collect();

        let mut lhs_buf = alloc_f32(&device, n, vec![n]);
        lhs_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&lhs);
        let mut rhs_buf = alloc_f32(&device, m, vec![m]);
        rhs_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&rhs);
        let y_buf = alloc_f32(&device, n * m, vec![n, m]);
        let params = make_params(&device, n as u32, m as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_outer_product_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &lhs_buf, &rhs_buf, &y_buf, &params, n as u32, m as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu = y_buf.as_slice::<f32>().unwrap();
        for i in 0..n {
            for j in 0..m {
                let expected = lhs[i] * rhs[j];
                // Mixed tolerance: relative 1e-6, floored to absolute 1e-6
                // when |expected| < 1 (a single f32 multiply is exactly
                // rounded, so this is generous).
                assert!(
                    (gpu[i * m + j] - expected).abs() < 1e-6 * expected.abs().max(1.0),
                    "y[{i},{j}]: gpu={} expected={}",
                    gpu[i * m + j], expected
                );
            }
        }
    }

    /// Both backward kernels vs. f64 CPU oracles:
    /// dlhs[i] = Σ_j dy[i,j]*rhs[j] and drhs[j] = Σ_i dy[i,j]*lhs[i].
    #[test]
    fn backward_dlhs_drhs_match_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 8usize;
        let m = 5usize;
        let lhs: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.1).collect();
        let rhs: Vec<f32> = (0..m).map(|i| 0.2 + (i as f32) * 0.07).collect();
        // sin() of a ramp: deterministic dy with varied signs/magnitudes.
        let dy: Vec<f32> = (0..(n * m)).map(|i| ((i as f32) * 0.131 - 0.4).sin()).collect();

        let mut lhs_buf = alloc_f32(&device, n, vec![n]);
        lhs_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&lhs);
        let mut rhs_buf = alloc_f32(&device, m, vec![m]);
        rhs_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&rhs);
        let mut dy_buf = alloc_f32(&device, n * m, vec![n, m]);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy);
        let dlhs_buf = alloc_f32(&device, n, vec![n]);
        let drhs_buf = alloc_f32(&device, m, vec![m]);
        let params = make_params(&device, n as u32, m as u32);

        // Both backward dispatches are encoded into one command buffer;
        // they write disjoint outputs (dlhs vs drhs).
        let mut encoder = device.command_encoder().unwrap();
        dispatch_outer_product_backward_lhs_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &dy_buf, &rhs_buf, &dlhs_buf, &params, n as u32, m as u32,
        ).unwrap();
        dispatch_outer_product_backward_rhs_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &dy_buf, &lhs_buf, &drhs_buf, &params, n as u32, m as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let dlhs = dlhs_buf.as_slice::<f32>().unwrap();
        let drhs = drhs_buf.as_slice::<f32>().unwrap();
        // Oracle accumulates in f64; 1e-5 scaled tolerance absorbs the
        // f32 accumulation error of the GPU reduction.
        for i in 0..n {
            let expected: f64 = (0..m).map(|j| dy[i * m + j] as f64 * rhs[j] as f64).sum();
            assert!(
                (dlhs[i] as f64 - expected).abs() < 1e-5 * expected.abs().max(1.0),
                "dlhs[{i}]: gpu={} expected={}",
                dlhs[i], expected
            );
        }
        for j in 0..m {
            let expected: f64 = (0..n).map(|i| dy[i * m + j] as f64 * lhs[i] as f64).sum();
            assert!(
                (drhs[j] as f64 - expected).abs() < 1e-5 * expected.abs().max(1.0),
                "drhs[{j}]: gpu={} expected={}",
                drhs[j], expected
            );
        }
    }

    /// Independent falsifier: with dy = 1, the analytic gradients must match
    /// central finite differences of loss(l, r) = Σ_ij l[i]*r[j]. The loss is
    /// linear in each coordinate, so the FD estimate is exact up to rounding.
    #[test]
    fn backward_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 6usize;
        let m = 4usize;
        let lhs: Vec<f32> = (0..n).map(|i| 0.3 + (i as f32) * 0.07).collect();
        let rhs: Vec<f32> = (0..m).map(|i| 0.5 + (i as f32) * 0.05).collect();

        let mut lhs_buf = alloc_f32(&device, n, vec![n]);
        lhs_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&lhs);
        let mut rhs_buf = alloc_f32(&device, m, vec![m]);
        rhs_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&rhs);
        // dy = all-ones selects the raw gradient of the summed loss.
        let dy_ones = vec![1.0f32; n * m];
        let mut dy_buf = alloc_f32(&device, n * m, vec![n, m]);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy_ones);
        let dlhs_buf = alloc_f32(&device, n, vec![n]);
        let drhs_buf = alloc_f32(&device, m, vec![m]);
        let params = make_params(&device, n as u32, m as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_outer_product_backward_lhs_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &dy_buf, &rhs_buf, &dlhs_buf, &params, n as u32, m as u32,
        ).unwrap();
        dispatch_outer_product_backward_rhs_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &dy_buf, &lhs_buf, &drhs_buf, &params, n as u32, m as u32,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let dlhs = dlhs_buf.as_slice::<f32>().unwrap().to_vec();
        let drhs = drhs_buf.as_slice::<f32>().unwrap().to_vec();

        // Central differences with step h; tolerance 1e-2 is loose because
        // the perturbation is applied in f32 while the loss sums in f64.
        let h = 1e-3f64;
        let loss = |l: &[f32], r: &[f32]| -> f64 {
            let mut s = 0.0f64;
            for i in 0..n { for j in 0..m { s += l[i] as f64 * r[j] as f64; } }
            s
        };
        for i in 0..n {
            let mut lp = lhs.clone(); lp[i] += h as f32;
            let mut lm = lhs.clone(); lm[i] -= h as f32;
            let fd = (loss(&lp, &rhs) - loss(&lm, &rhs)) / (2.0 * h);
            let tol = 1e-2 * fd.abs().max(1.0);
            assert!(
                (dlhs[i] as f64 - fd).abs() < tol,
                "FD lhs[{i}]: analytic={} fd={}", dlhs[i], fd
            );
        }
        for j in 0..m {
            let mut rp = rhs.clone(); rp[j] += h as f32;
            let mut rm = rhs.clone(); rm[j] -= h as f32;
            let fd = (loss(&lhs, &rp) - loss(&lhs, &rm)) / (2.0 * h);
            let tol = 1e-2 * fd.abs().max(1.0);
            assert!(
                (drhs[j] as f64 - fd).abs() < tol,
                "FD rhs[{j}]: analytic={} fd={}", drhs[j], fd
            );
        }
    }

    /// validate_dims path: N == 0 must be rejected before any GPU work.
    #[test]
    fn rejects_zero_dims() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let lhs = alloc_f32(&device, 1, vec![1]);
        let rhs = alloc_f32(&device, 1, vec![1]);
        let y = alloc_f32(&device, 1, vec![1, 1]);
        let params = make_params(&device, 0, 1);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_outer_product_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &lhs, &rhs, &y, &params, 0, 1,
        );
        assert!(res.is_err());
    }
}