Skip to main content

mlx_native/ops/
sqrt_elementwise.rs

1//! ADR-020 iter-11h-misc-3 — elementwise sqrt forward + backward.
2
3use metal::MTLSize;
4
5use crate::buffer::MlxBuffer;
6use crate::dtypes::DType;
7use crate::encoder::CommandEncoder;
8use crate::error::{MlxError, Result};
9use crate::kernel_registry::KernelRegistry;
10
11pub static SQRT_ELEMENTWISE_SHADER_SOURCE: &str =
12    include_str!("../shaders/sqrt_elementwise.metal");
13
14pub fn register(registry: &mut KernelRegistry) {
15    registry.register_source("sqrt_f32", SQRT_ELEMENTWISE_SHADER_SOURCE);
16    registry.register_source("sqrt_backward_f32", SQRT_ELEMENTWISE_SHADER_SOURCE);
17}
18
19pub fn dispatch_sqrt_f32(
20    encoder: &mut CommandEncoder,
21    registry: &mut KernelRegistry,
22    device: &metal::DeviceRef,
23    input: &MlxBuffer,
24    output: &MlxBuffer,
25    params: &MlxBuffer,
26) -> Result<()> {
27    const OP: &str = "sqrt_f32";
28    let n = input.element_count();
29    if n == 0 {
30        return Err(MlxError::InvalidArgument(format!("{OP}: empty input")));
31    }
32    if output.element_count() != n {
33        return Err(MlxError::InvalidArgument(format!(
34            "{OP}: output element_count {} != input {n}",
35            output.element_count()
36        )));
37    }
38    if input.dtype() != DType::F32 || output.dtype() != DType::F32 {
39        return Err(MlxError::InvalidArgument(format!("{OP}: must be f32")));
40    }
41    if params.byte_len() < 4 {
42        return Err(MlxError::InvalidArgument(format!("{OP}: params < 4 bytes")));
43    }
44    let pipeline = registry.get_pipeline(OP, device)?;
45    let n_u64 = n as u64;
46    encoder.encode(
47        pipeline,
48        &[(0, input), (1, output), (2, params)],
49        MTLSize::new(n_u64, 1, 1),
50        MTLSize::new(std::cmp::min(256, n_u64), 1, 1),
51    );
52    Ok(())
53}
54
55pub fn dispatch_sqrt_backward_f32(
56    encoder: &mut CommandEncoder,
57    registry: &mut KernelRegistry,
58    device: &metal::DeviceRef,
59    y: &MlxBuffer,
60    dy: &MlxBuffer,
61    dx: &MlxBuffer,
62    params: &MlxBuffer,
63) -> Result<()> {
64    const OP: &str = "sqrt_backward_f32";
65    let n = y.element_count();
66    if dy.element_count() != n || dx.element_count() != n {
67        return Err(MlxError::InvalidArgument(format!(
68            "{OP}: shape mismatch (y={n}, dy={}, dx={})",
69            dy.element_count(), dx.element_count()
70        )));
71    }
72    if y.dtype() != DType::F32 || dy.dtype() != DType::F32 || dx.dtype() != DType::F32 {
73        return Err(MlxError::InvalidArgument(format!("{OP}: must be f32")));
74    }
75    let pipeline = registry.get_pipeline(OP, device)?;
76    let n_u64 = n as u64;
77    encoder.encode(
78        pipeline,
79        &[(0, y), (1, dy), (2, dx), (3, params)],
80        MTLSize::new(n_u64, 1, 1),
81        MTLSize::new(std::cmp::min(256, n_u64), 1, 1),
82    );
83    Ok(())
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// Allocate a zero-filled f32 buffer of `n` elements with shape `[n]`.
    fn alloc_f32(d: &MlxDevice, n: usize) -> MlxBuffer {
        let mut b = d.alloc_buffer(n * 4, DType::F32, vec![n]).unwrap();
        b.as_mut_slice::<f32>().unwrap().fill(0.0);
        b
    }

    /// Allocate the 4-byte params buffer holding the element count `n` as u32.
    fn make_params(d: &MlxDevice, n: u32) -> MlxBuffer {
        let mut p = d.alloc_buffer(4, DType::U32, vec![1]).unwrap();
        p.as_mut_slice::<u32>().unwrap()[0] = n;
        p
    }

    /// Forward kernel output must match an f64 CPU sqrt to ~1e-6 relative.
    #[test]
    fn forward_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 32usize;
        let x: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.3).collect();

        let mut x_buf = alloc_f32(&device, n);
        x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
        let y_buf = alloc_f32(&device, n);
        let p = make_params(&device, n as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_sqrt_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &y_buf, &p,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        let gpu = y_buf.as_slice::<f32>().unwrap();
        for (i, (&g, &xi)) in gpu.iter().zip(&x).enumerate() {
            let cpu = (xi as f64).sqrt() as f32;
            assert!(
                (g - cpu).abs() < 1e-6 * cpu.abs().max(1.0),
                "y[{i}]: gpu={} cpu={} (x={})",
                g, cpu, xi
            );
        }
    }

    /// FD falsifier: loss = sum(sqrt(x)).  dx[i] = 1/(2·sqrt(x[i])).
    #[test]
    fn backward_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 16usize;
        let x: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.1).collect();

        let mut x_buf = alloc_f32(&device, n);
        x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
        let y_buf = alloc_f32(&device, n);
        let p = make_params(&device, n as u32);
        let mut encoder = device.command_encoder().unwrap();
        dispatch_sqrt_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &y_buf, &p,
        ).unwrap();
        // RAW barrier (per feedback_metal_raw_barrier_per_dispatch).
        encoder.memory_barrier();

        let dy_ones = vec![1.0f32; n];
        let mut dy_buf = alloc_f32(&device, n);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy_ones);
        let dx_buf = alloc_f32(&device, n);
        dispatch_sqrt_backward_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &y_buf, &dy_buf, &dx_buf, &p,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let dx = dx_buf.as_slice::<f32>().unwrap().to_vec();

        // Central difference on the f64 loss; perturbations are applied in
        // f32 to mirror the inputs actually fed to the kernel.
        let h = 1e-3f64;
        for i in 0..n {
            let mut xp = x.clone(); xp[i] += h as f32;
            let mut xm = x.clone(); xm[i] -= h as f32;
            let lp: f64 = xp.iter().map(|v| (*v as f64).sqrt()).sum();
            let lm: f64 = xm.iter().map(|v| (*v as f64).sqrt()).sum();
            let fd = (lp - lm) / (2.0 * h);
            let tol = 1e-2 * fd.abs().max(1.0);
            assert!(
                (dx[i] as f64 - fd).abs() < tol,
                "FD x[{i}]: analytic={} fd={}", dx[i], fd
            );
        }
    }

    /// Forward dispatch must refuse an output whose length differs from input.
    #[test]
    fn rejects_size_mismatch() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let x = alloc_f32(&device, 16);
        let y = alloc_f32(&device, 8);
        let p = make_params(&device, 16);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_sqrt_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x, &y, &p,
        );
        assert!(res.is_err());
    }

    /// Forward dispatch must refuse a non-f32 input buffer.
    #[test]
    fn rejects_non_f32_input() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let n = 16usize;
        let x = device.alloc_buffer(n * 4, DType::U32, vec![n]).unwrap();
        let y = alloc_f32(&device, n);
        let p = make_params(&device, n as u32);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_sqrt_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x, &y, &p,
        );
        assert!(res.is_err());
    }

    /// Backward dispatch must refuse a `dx` whose length differs from `y`.
    #[test]
    fn backward_rejects_size_mismatch() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        let y = alloc_f32(&device, 16);
        let dy = alloc_f32(&device, 16);
        let dx = alloc_f32(&device, 8);
        let p = make_params(&device, 16);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_sqrt_backward_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &y, &dy, &dx, &p,
        );
        assert!(res.is_err());
    }
}