Skip to main content

mlx_native/ops/
exp_elementwise.rs

1//! ADR-020 iter-11h-c1 — elementwise exp forward + backward.
2//!
3//! Forward: `y[i] = exp(x[i])`
4//! Backward: `dx[i] = dy[i] · y[i]` (caller passes y, the forward
5//! output, NOT x — autograd-canonical pattern that avoids recompute).
6
7use metal::MTLSize;
8
9use crate::buffer::MlxBuffer;
10use crate::dtypes::DType;
11use crate::encoder::CommandEncoder;
12use crate::error::{MlxError, Result};
13use crate::kernel_registry::KernelRegistry;
14
/// Metal shader source shared by the forward (`exp_f32`) and backward
/// (`exp_backward_f32`) kernels; `register` installs it under both op names.
pub static EXP_ELEMENTWISE_SHADER_SOURCE: &str =
    include_str!("../shaders/exp_elementwise.metal");
17
18pub fn register(registry: &mut KernelRegistry) {
19    registry.register_source("exp_f32", EXP_ELEMENTWISE_SHADER_SOURCE);
20    registry.register_source("exp_backward_f32", EXP_ELEMENTWISE_SHADER_SOURCE);
21}
22
23pub fn dispatch_exp_f32(
24    encoder: &mut CommandEncoder,
25    registry: &mut KernelRegistry,
26    device: &metal::DeviceRef,
27    input: &MlxBuffer,
28    output: &MlxBuffer,
29    params: &MlxBuffer,
30) -> Result<()> {
31    const OP: &str = "exp_f32";
32    let n = input.element_count();
33    if n == 0 {
34        return Err(MlxError::InvalidArgument(format!(
35            "{OP}: input must have at least one element"
36        )));
37    }
38    if output.element_count() != n {
39        return Err(MlxError::InvalidArgument(format!(
40            "{OP}: output element_count {} != input element_count {n}",
41            output.element_count()
42        )));
43    }
44    if input.dtype() != DType::F32 || output.dtype() != DType::F32 {
45        return Err(MlxError::InvalidArgument(format!(
46            "{OP}: input/output must be f32"
47        )));
48    }
49    if params.byte_len() < 4 {
50        return Err(MlxError::InvalidArgument(format!(
51            "{OP}: params < 4 bytes (need 1 × u32 = n)"
52        )));
53    }
54
55    let pipeline = registry.get_pipeline(OP, device)?;
56    let n_u64 = n as u64;
57    let tg = std::cmp::min(256, n_u64);
58    encoder.encode(
59        pipeline,
60        &[(0, input), (1, output), (2, params)],
61        MTLSize::new(n_u64, 1, 1),
62        MTLSize::new(tg, 1, 1),
63    );
64    Ok(())
65}
66
67pub fn dispatch_exp_backward_f32(
68    encoder: &mut CommandEncoder,
69    registry: &mut KernelRegistry,
70    device: &metal::DeviceRef,
71    y: &MlxBuffer,
72    dy: &MlxBuffer,
73    dx: &MlxBuffer,
74    params: &MlxBuffer,
75) -> Result<()> {
76    const OP: &str = "exp_backward_f32";
77    let n = y.element_count();
78    if n == 0 {
79        return Err(MlxError::InvalidArgument(format!(
80            "{OP}: y must have at least one element"
81        )));
82    }
83    if dy.element_count() != n || dx.element_count() != n {
84        return Err(MlxError::InvalidArgument(format!(
85            "{OP}: dy/dx element_count must match y ({n})"
86        )));
87    }
88    if y.dtype() != DType::F32 || dy.dtype() != DType::F32 || dx.dtype() != DType::F32 {
89        return Err(MlxError::InvalidArgument(format!(
90            "{OP}: y/dy/dx must be f32"
91        )));
92    }
93    if params.byte_len() < 4 {
94        return Err(MlxError::InvalidArgument(format!(
95            "{OP}: params < 4 bytes (need 1 × u32 = n)"
96        )));
97    }
98
99    let pipeline = registry.get_pipeline(OP, device)?;
100    let n_u64 = n as u64;
101    let tg = std::cmp::min(256, n_u64);
102    encoder.encode(
103        pipeline,
104        &[(0, y), (1, dy), (2, dx), (3, params)],
105        MTLSize::new(n_u64, 1, 1),
106        MTLSize::new(tg, 1, 1),
107    );
108    Ok(())
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// Allocates an `n`-element f32 buffer and zero-fills it so tests never
    /// read uninitialized device memory.
    fn alloc_f32(device: &MlxDevice, n: usize) -> MlxBuffer {
        let mut b = device.alloc_buffer(n * 4, DType::F32, vec![n]).unwrap();
        b.as_mut_slice::<f32>().unwrap().fill(0.0);
        b
    }

    /// Builds the single-`u32` params buffer the kernels read `n` from.
    fn make_params(device: &MlxDevice, n: u32) -> MlxBuffer {
        let mut p = device.alloc_buffer(4, DType::U32, vec![1]).unwrap();
        p.as_mut_slice::<u32>().unwrap()[0] = n;
        p
    }

    /// Forward kernel vs. a CPU oracle: `exp` evaluated in f64 and rounded
    /// to f32, compared with a 1e-5 relative tolerance (floored at absolute
    /// 1e-5 via `max(1.0)`).
    #[test]
    fn forward_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 64usize;
        // Inputs sweep roughly [-1.5, 3.1) so both small and large exp
        // outputs are exercised.
        let x: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.073 - 1.5)).collect();

        let mut x_buf = alloc_f32(&device, n);
        x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
        let y_buf = alloc_f32(&device, n);
        let params = make_params(&device, n as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_exp_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &y_buf, &params,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu = y_buf.as_slice::<f32>().unwrap();
        for i in 0..n {
            let cpu = (x[i] as f64).exp() as f32;
            assert!(
                (gpu[i] - cpu).abs() < 1e-5 * cpu.abs().max(1.0),
                "exp y[{i}]: gpu={} cpu={} (x={})",
                gpu[i], cpu, x[i]
            );
        }
    }

    /// Pins the backward identity directly: with arbitrary `y` and `dy`
    /// buffers, the kernel must produce exactly `dx = dy · y` elementwise.
    #[test]
    fn backward_dx_equals_dy_times_y() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 32usize;
        let y: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.07).collect();
        let dy: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.13 - 0.5).sin()).collect();

        let mut y_buf = alloc_f32(&device, n);
        y_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&y);
        let mut dy_buf = alloc_f32(&device, n);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy);
        let dx_buf = alloc_f32(&device, n);
        let params = make_params(&device, n as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_exp_backward_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &y_buf, &dy_buf, &dx_buf, &params,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu = dx_buf.as_slice::<f32>().unwrap();
        for i in 0..n {
            // A single f32 multiply; 1e-6 relative tolerance covers its
            // rounding error.
            let expected = dy[i] * y[i];
            assert!(
                (gpu[i] - expected).abs() < 1e-6 * expected.abs().max(1.0),
                "exp dx[{i}]: gpu={} expected={}",
                gpu[i], expected
            );
        }
    }

    /// Finite-difference falsifier: `loss = sum(exp(x))`.  Analytic
    /// gradient is `dx = exp(x)`.  FD must match within 1% rel tol.
    #[test]
    fn backward_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 16usize;
        let x: Vec<f32> = (0..n).map(|i| ((i as f32) * 0.043 - 0.5)).collect();

        // Forward via mlx-native.
        let mut x_buf = alloc_f32(&device, n);
        x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
        let y_buf = alloc_f32(&device, n);
        let params = make_params(&device, n as u32);
        let mut encoder = device.command_encoder().unwrap();
        dispatch_exp_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &y_buf, &params,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let y = y_buf.as_slice::<f32>().unwrap().to_vec();

        // Analytic backward: dy = ones, dx = exp(x).
        let dy_ones = vec![1.0f32; n];
        let mut dy_buf = alloc_f32(&device, n);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy_ones);
        let dx_buf = alloc_f32(&device, n);
        let mut encoder = device.command_encoder().unwrap();
        dispatch_exp_backward_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &y_buf, &dy_buf, &dx_buf, &params,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let dx = dx_buf.as_slice::<f32>().unwrap().to_vec();

        // Central difference (loss(xp) - loss(xm)) / (2h) per coordinate,
        // with the loss accumulated in f64 to keep FD noise below the 1%
        // tolerance.
        let h = 1e-4f64;
        for i in 0..n {
            let mut xp = x.clone();
            xp[i] += h as f32;
            let mut xm = x.clone();
            xm[i] -= h as f32;
            let loss_p: f64 = xp.iter().map(|v| (*v as f64).exp()).sum();
            let loss_m: f64 = xm.iter().map(|v| (*v as f64).exp()).sum();
            let fd = (loss_p - loss_m) / (2.0 * h);
            let tol = 1e-2 * fd.abs().max(1.0);
            assert!(
                (dx[i] as f64 - fd).abs() < tol,
                "FD x[{i}]: analytic={} fd={} (y={})",
                dx[i], fd, y[i]
            );
        }
    }

    /// Validation path: an output buffer smaller than the input must be
    /// rejected with an error rather than dispatched.
    #[test]
    fn rejects_size_mismatch() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let x = alloc_f32(&device, 16);
        let y = alloc_f32(&device, 8); // wrong size
        let params = make_params(&device, 16);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_exp_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x, &y, &params,
        );
        assert!(res.is_err());
    }
}