1use metal::MTLSize;
8
9use crate::buffer::MlxBuffer;
10use crate::dtypes::DType;
11use crate::encoder::CommandEncoder;
12use crate::error::{MlxError, Result};
13use crate::kernel_registry::KernelRegistry;
14
/// Metal Shading Language source for the elementwise exp forward and
/// backward kernels, embedded into the binary at compile time.
pub static EXP_ELEMENTWISE_SHADER_SOURCE: &str =
    include_str!("../shaders/exp_elementwise.metal");
17
18pub fn register(registry: &mut KernelRegistry) {
19 registry.register_source("exp_f32", EXP_ELEMENTWISE_SHADER_SOURCE);
20 registry.register_source("exp_backward_f32", EXP_ELEMENTWISE_SHADER_SOURCE);
21}
22
23pub fn dispatch_exp_f32(
24 encoder: &mut CommandEncoder,
25 registry: &mut KernelRegistry,
26 device: &metal::DeviceRef,
27 input: &MlxBuffer,
28 output: &MlxBuffer,
29 params: &MlxBuffer,
30) -> Result<()> {
31 const OP: &str = "exp_f32";
32 let n = input.element_count();
33 if n == 0 {
34 return Err(MlxError::InvalidArgument(format!(
35 "{OP}: input must have at least one element"
36 )));
37 }
38 if output.element_count() != n {
39 return Err(MlxError::InvalidArgument(format!(
40 "{OP}: output element_count {} != input element_count {n}",
41 output.element_count()
42 )));
43 }
44 if input.dtype() != DType::F32 || output.dtype() != DType::F32 {
45 return Err(MlxError::InvalidArgument(format!(
46 "{OP}: input/output must be f32"
47 )));
48 }
49 if params.byte_len() < 4 {
50 return Err(MlxError::InvalidArgument(format!(
51 "{OP}: params < 4 bytes (need 1 × u32 = n)"
52 )));
53 }
54
55 let pipeline = registry.get_pipeline(OP, device)?;
56 let n_u64 = n as u64;
57 let tg = std::cmp::min(256, n_u64);
58 encoder.encode(
59 pipeline,
60 &[(0, input), (1, output), (2, params)],
61 MTLSize::new(n_u64, 1, 1),
62 MTLSize::new(tg, 1, 1),
63 );
64 Ok(())
65}
66
67pub fn dispatch_exp_backward_f32(
68 encoder: &mut CommandEncoder,
69 registry: &mut KernelRegistry,
70 device: &metal::DeviceRef,
71 y: &MlxBuffer,
72 dy: &MlxBuffer,
73 dx: &MlxBuffer,
74 params: &MlxBuffer,
75) -> Result<()> {
76 const OP: &str = "exp_backward_f32";
77 let n = y.element_count();
78 if n == 0 {
79 return Err(MlxError::InvalidArgument(format!(
80 "{OP}: y must have at least one element"
81 )));
82 }
83 if dy.element_count() != n || dx.element_count() != n {
84 return Err(MlxError::InvalidArgument(format!(
85 "{OP}: dy/dx element_count must match y ({n})"
86 )));
87 }
88 if y.dtype() != DType::F32 || dy.dtype() != DType::F32 || dx.dtype() != DType::F32 {
89 return Err(MlxError::InvalidArgument(format!(
90 "{OP}: y/dy/dx must be f32"
91 )));
92 }
93 if params.byte_len() < 4 {
94 return Err(MlxError::InvalidArgument(format!(
95 "{OP}: params < 4 bytes (need 1 × u32 = n)"
96 )));
97 }
98
99 let pipeline = registry.get_pipeline(OP, device)?;
100 let n_u64 = n as u64;
101 let tg = std::cmp::min(256, n_u64);
102 encoder.encode(
103 pipeline,
104 &[(0, y), (1, dy), (2, dx), (3, params)],
105 MTLSize::new(n_u64, 1, 1),
106 MTLSize::new(tg, 1, 1),
107 );
108 Ok(())
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    // NOTE(review): these tests build a fresh `KernelRegistry::new()` but
    // never call `register(&mut registry)` from this module — confirm that
    // `KernelRegistry::new` (or `get_pipeline`) self-registers kernel
    // sources; otherwise the GPU tests will fail pipeline lookup.

    /// Allocates an `n`-element f32 buffer on `device`, zero-filled.
    fn alloc_f32(device: &MlxDevice, n: usize) -> MlxBuffer {
        let mut b = device.alloc_buffer(n * 4, DType::F32, vec![n]).unwrap();
        b.as_mut_slice::<f32>().unwrap().fill(0.0);
        b
    }

    /// Builds the one-`u32` params buffer holding the element count `n`.
    fn make_params(device: &MlxDevice, n: u32) -> MlxBuffer {
        let mut p = device.alloc_buffer(4, DType::U32, vec![1]).unwrap();
        p.as_mut_slice::<u32>().unwrap()[0] = n;
        p
    }

    /// Forward kernel output must match a CPU f64 exp oracle elementwise.
    #[test]
    fn forward_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 64usize;
        let x: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.073 - 1.5)).collect();

        let mut x_buf = alloc_f32(&device, n);
        x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
        let y_buf = alloc_f32(&device, n);
        let params = make_params(&device, n as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_exp_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &y_buf, &params,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu = y_buf.as_slice::<f32>().unwrap();
        for i in 0..n {
            // f64 oracle, rounded to f32; relative tolerance floored at 1.0
            // so values near zero use an absolute bound.
            let cpu = (x[i] as f64).exp() as f32;
            assert!(
                (gpu[i] - cpu).abs() < 1e-5 * cpu.abs().max(1.0),
                "exp y[{i}]: gpu={} cpu={} (x={})",
                gpu[i], cpu, x[i]
            );
        }
    }

    /// Backward kernel must implement the exp gradient `dx = dy * y`.
    #[test]
    fn backward_dx_equals_dy_times_y() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 32usize;
        let y: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.07).collect();
        let dy: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.13 - 0.5).sin()).collect();

        let mut y_buf = alloc_f32(&device, n);
        y_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&y);
        let mut dy_buf = alloc_f32(&device, n);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy);
        let dx_buf = alloc_f32(&device, n);
        let params = make_params(&device, n as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_exp_backward_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &y_buf, &dy_buf, &dx_buf, &params,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu = dx_buf.as_slice::<f32>().unwrap();
        for i in 0..n {
            let expected = dy[i] * y[i];
            assert!(
                (gpu[i] - expected).abs() < 1e-6 * expected.abs().max(1.0),
                "exp dx[{i}]: gpu={} expected={}",
                gpu[i], expected
            );
        }
    }

    /// Central finite differences on sum(exp(x)) must agree with the
    /// analytic gradient produced by forward + backward dispatch.
    #[test]
    fn backward_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 16usize;
        let x: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.043 - 0.5)).collect();

        // Forward pass: y = exp(x) on the GPU.
        let mut x_buf = alloc_f32(&device, n);
        x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
        let y_buf = alloc_f32(&device, n);
        let params = make_params(&device, n as u32);
        let mut encoder = device.command_encoder().unwrap();
        dispatch_exp_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &y_buf, &params,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let y = y_buf.as_slice::<f32>().unwrap().to_vec();

        // Backward pass with dy = 1 gives d(sum(exp(x)))/dx directly.
        let dy_ones = vec![1.0f32; n];
        let mut dy_buf = alloc_f32(&device, n);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy_ones);
        let dx_buf = alloc_f32(&device, n);
        let mut encoder = device.command_encoder().unwrap();
        dispatch_exp_backward_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &y_buf, &dy_buf, &dx_buf, &params,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let dx = dx_buf.as_slice::<f32>().unwrap().to_vec();

        // Central-difference check in f64 to keep the oracle's own error
        // well below the (loose, FD-appropriate) 1% tolerance.
        let h = 1e-4f64;
        for i in 0..n {
            let mut xp = x.clone();
            xp[i] += h as f32;
            let mut xm = x.clone();
            xm[i] -= h as f32;
            let loss_p: f64 = xp.iter().map(|v| (*v as f64).exp()).sum();
            let loss_m: f64 = xm.iter().map(|v| (*v as f64).exp()).sum();
            let fd = (loss_p - loss_m) / (2.0 * h);
            let tol = 1e-2 * fd.abs().max(1.0);
            assert!(
                (dx[i] as f64 - fd).abs() < tol,
                "FD x[{i}]: analytic={} fd={} (y={})",
                dx[i], fd, y[i]
            );
        }
    }

    /// Mismatched input/output lengths must be rejected before dispatch.
    #[test]
    fn rejects_size_mismatch() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let x = alloc_f32(&device, 16);
        let y = alloc_f32(&device, 8);
        let params = make_params(&device, 16);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_exp_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x, &y, &params,
        );
        assert!(res.is_err());
    }
}