1use metal::MTLSize;
4
5use crate::buffer::MlxBuffer;
6use crate::dtypes::DType;
7use crate::encoder::CommandEncoder;
8use crate::error::{MlxError, Result};
9use crate::kernel_registry::KernelRegistry;
10
/// Metal shading-language source for the element-wise divide kernels.
/// Both the forward (`divide_f32`) and backward (`divide_backward_f32`)
/// entry points live in this one .metal file; see `register` below.
pub static DIVIDE_ELEMENTWISE_SHADER_SOURCE: &str =
    include_str!("../shaders/divide_elementwise.metal");
13
14pub fn register(registry: &mut KernelRegistry) {
15 registry.register_source("divide_f32", DIVIDE_ELEMENTWISE_SHADER_SOURCE);
16 registry.register_source(
17 "divide_backward_f32",
18 DIVIDE_ELEMENTWISE_SHADER_SOURCE,
19 );
20}
21
22pub fn dispatch_divide_f32(
23 encoder: &mut CommandEncoder,
24 registry: &mut KernelRegistry,
25 device: &metal::DeviceRef,
26 a: &MlxBuffer,
27 b: &MlxBuffer,
28 y: &MlxBuffer,
29 params: &MlxBuffer,
30) -> Result<()> {
31 const OP: &str = "divide_f32";
32 let n = a.element_count();
33 if n == 0 {
34 return Err(MlxError::InvalidArgument(format!("{OP}: empty input")));
35 }
36 if b.element_count() != n || y.element_count() != n {
37 return Err(MlxError::InvalidArgument(format!(
38 "{OP}: shape mismatch (a={}, b={}, y={})",
39 n, b.element_count(), y.element_count()
40 )));
41 }
42 if a.dtype() != DType::F32 || b.dtype() != DType::F32 || y.dtype() != DType::F32 {
43 return Err(MlxError::InvalidArgument(format!("{OP}: must be f32")));
44 }
45 if params.byte_len() < 4 {
46 return Err(MlxError::InvalidArgument(format!(
47 "{OP}: params < 4 bytes"
48 )));
49 }
50 let pipeline = registry.get_pipeline(OP, device)?;
51 let n_u64 = n as u64;
52 encoder.encode(
53 pipeline,
54 &[(0, a), (1, b), (2, y), (3, params)],
55 MTLSize::new(n_u64, 1, 1),
56 MTLSize::new(std::cmp::min(256, n_u64), 1, 1),
57 );
58 Ok(())
59}
60
61#[allow(clippy::too_many_arguments)]
62pub fn dispatch_divide_backward_f32(
63 encoder: &mut CommandEncoder,
64 registry: &mut KernelRegistry,
65 device: &metal::DeviceRef,
66 b: &MlxBuffer,
67 y: &MlxBuffer,
68 dy: &MlxBuffer,
69 da: &MlxBuffer,
70 db: &MlxBuffer,
71 params: &MlxBuffer,
72) -> Result<()> {
73 const OP: &str = "divide_backward_f32";
74 let n = b.element_count();
75 if y.element_count() != n
76 || dy.element_count() != n
77 || da.element_count() != n
78 || db.element_count() != n
79 {
80 return Err(MlxError::InvalidArgument(format!(
81 "{OP}: shape mismatch n={n}, b/y/dy/da/db must match"
82 )));
83 }
84 if b.dtype() != DType::F32 || y.dtype() != DType::F32 || dy.dtype() != DType::F32
85 || da.dtype() != DType::F32 || db.dtype() != DType::F32
86 {
87 return Err(MlxError::InvalidArgument(format!("{OP}: must be f32")));
88 }
89 let pipeline = registry.get_pipeline(OP, device)?;
90 let n_u64 = n as u64;
91 encoder.encode(
92 pipeline,
93 &[(0, b), (1, y), (2, dy), (3, da), (4, db), (5, params)],
94 MTLSize::new(n_u64, 1, 1),
95 MTLSize::new(std::cmp::min(256, n_u64), 1, 1),
96 );
97 Ok(())
98}
99
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    // Allocates an n-element f32 buffer on `d` and zero-fills it.
    fn alloc_f32(d: &MlxDevice, n: usize) -> MlxBuffer {
        let mut bx = d.alloc_buffer(n * 4, DType::F32, vec![n]).unwrap();
        bx.as_mut_slice::<f32>().unwrap().fill(0.0);
        bx
    }
    // Builds the 4-byte params buffer holding the element count `n`
    // (the value the kernels read to bound their thread index).
    fn make_params(d: &MlxDevice, n: u32) -> MlxBuffer {
        let mut p = d.alloc_buffer(4, DType::U32, vec![1]).unwrap();
        p.as_mut_slice::<u32>().unwrap()[0] = n;
        p
    }

    // Forward kernel vs. scalar CPU division, element by element,
    // with a relative tolerance scaled by |expected|.
    #[test]
    fn forward_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 32usize;
        // Inputs chosen so every divisor is >= 1.0 (no blow-up).
        let a: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.1).collect();
        let b: Vec<f32> = (0..n).map(|i| 1.0 + (i as f32) * 0.07).collect();

        let mut a_buf = alloc_f32(&device, n);
        a_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&a);
        let mut b_buf = alloc_f32(&device, n);
        b_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&b);
        let y_buf = alloc_f32(&device, n);
        let p = make_params(&device, n as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_divide_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &a_buf, &b_buf, &y_buf, &p,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu = y_buf.as_slice::<f32>().unwrap();
        for i in 0..n {
            let cpu = a[i] / b[i];
            assert!(
                (gpu[i] - cpu).abs() < 1e-6 * cpu.abs().max(1.0),
                "y[{i}]: gpu={} cpu={}",
                gpu[i], cpu
            );
        }
    }

    // Falsifier: checks the analytic gradients (da, db) produced by
    // the backward kernel against central finite differences of the
    // scalar loss sum(a/b), computed in f64 on the CPU.
    #[test]
    fn backward_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 16usize;
        let a: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.05).collect();
        let b: Vec<f32> = (0..n).map(|i| 1.0 + (i as f32) * 0.07).collect();
        // dy = 1 everywhere, so (da, db) are the gradients of sum(a/b).
        let dy: Vec<f32> = vec![1.0; n];

        let mut a_buf = alloc_f32(&device, n);
        a_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&a);
        let mut b_buf = alloc_f32(&device, n);
        b_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&b);
        let y_buf = alloc_f32(&device, n);
        let mut dy_buf = alloc_f32(&device, n);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy);
        let da_buf = alloc_f32(&device, n);
        let db_buf = alloc_f32(&device, n);
        let p = make_params(&device, n as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_divide_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &a_buf, &b_buf, &y_buf, &p,
        ).unwrap();
        // Barrier between the two dispatches: the backward kernel takes
        // y (written by the forward dispatch above) as an input.
        encoder.memory_barrier();
        dispatch_divide_backward_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &b_buf, &y_buf, &dy_buf, &da_buf, &db_buf, &p,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let da = da_buf.as_slice::<f32>().unwrap().to_vec();
        let db = db_buf.as_slice::<f32>().unwrap().to_vec();

        // Central-difference step; loss is evaluated in f64 to keep
        // the FD truncation/round-off error below the 1e-3 tolerance.
        let h = 1e-3f64;
        let loss = |aa: &[f32], bb: &[f32]| -> f64 {
            (0..n).map(|i| aa[i] as f64 / bb[i] as f64).sum::<f64>()
        };
        // d(loss)/da[i] via central differences vs. analytic da[i].
        for i in 0..n {
            let mut ap = a.clone(); ap[i] += h as f32;
            let mut am = a.clone(); am[i] -= h as f32;
            let fd = (loss(&ap, &b) - loss(&am, &b)) / (2.0 * h);
            let tol = 1e-3 * fd.abs().max(1.0);
            assert!(
                (da[i] as f64 - fd).abs() < tol,
                "FD a[{i}]: analytic={} fd={}", da[i], fd
            );
        }
        // d(loss)/db[i] via central differences vs. analytic db[i].
        for i in 0..n {
            let mut bp = b.clone(); bp[i] += h as f32;
            let mut bm = b.clone(); bm[i] -= h as f32;
            let fd = (loss(&a, &bp) - loss(&a, &bm)) / (2.0 * h);
            let tol = 1e-3 * fd.abs().max(1.0);
            assert!(
                (db[i] as f64 - fd).abs() < tol,
                "FD b[{i}]: analytic={} fd={}", db[i], fd
            );
        }
    }

    // Mismatched element counts (b has 8, a/y have 16) must be
    // rejected before anything is encoded.
    #[test]
    fn rejects_size_mismatch() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let a = alloc_f32(&device, 16);
        let b = alloc_f32(&device, 8); let y = alloc_f32(&device, 16);
        let p = make_params(&device, 16);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_divide_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &a, &b, &y, &p,
        );
        assert!(res.is_err());
    }
}