1use metal::MTLSize;
4
5use crate::buffer::MlxBuffer;
6use crate::dtypes::DType;
7use crate::encoder::CommandEncoder;
8use crate::error::{MlxError, Result};
9use crate::kernel_registry::KernelRegistry;
10
11pub static SQRT_ELEMENTWISE_SHADER_SOURCE: &str =
12 include_str!("../shaders/sqrt_elementwise.metal");
13
14pub fn register(registry: &mut KernelRegistry) {
15 registry.register_source("sqrt_f32", SQRT_ELEMENTWISE_SHADER_SOURCE);
16 registry.register_source("sqrt_backward_f32", SQRT_ELEMENTWISE_SHADER_SOURCE);
17}
18
19pub fn dispatch_sqrt_f32(
20 encoder: &mut CommandEncoder,
21 registry: &mut KernelRegistry,
22 device: &metal::DeviceRef,
23 input: &MlxBuffer,
24 output: &MlxBuffer,
25 params: &MlxBuffer,
26) -> Result<()> {
27 const OP: &str = "sqrt_f32";
28 let n = input.element_count();
29 if n == 0 {
30 return Err(MlxError::InvalidArgument(format!("{OP}: empty input")));
31 }
32 if output.element_count() != n {
33 return Err(MlxError::InvalidArgument(format!(
34 "{OP}: output element_count {} != input {n}",
35 output.element_count()
36 )));
37 }
38 if input.dtype() != DType::F32 || output.dtype() != DType::F32 {
39 return Err(MlxError::InvalidArgument(format!("{OP}: must be f32")));
40 }
41 if params.byte_len() < 4 {
42 return Err(MlxError::InvalidArgument(format!("{OP}: params < 4 bytes")));
43 }
44 let pipeline = registry.get_pipeline(OP, device)?;
45 let n_u64 = n as u64;
46 encoder.encode(
47 pipeline,
48 &[(0, input), (1, output), (2, params)],
49 MTLSize::new(n_u64, 1, 1),
50 MTLSize::new(std::cmp::min(256, n_u64), 1, 1),
51 );
52 Ok(())
53}
54
55pub fn dispatch_sqrt_backward_f32(
56 encoder: &mut CommandEncoder,
57 registry: &mut KernelRegistry,
58 device: &metal::DeviceRef,
59 y: &MlxBuffer,
60 dy: &MlxBuffer,
61 dx: &MlxBuffer,
62 params: &MlxBuffer,
63) -> Result<()> {
64 const OP: &str = "sqrt_backward_f32";
65 let n = y.element_count();
66 if dy.element_count() != n || dx.element_count() != n {
67 return Err(MlxError::InvalidArgument(format!(
68 "{OP}: shape mismatch (y={n}, dy={}, dx={})",
69 dy.element_count(), dx.element_count()
70 )));
71 }
72 if y.dtype() != DType::F32 || dy.dtype() != DType::F32 || dx.dtype() != DType::F32 {
73 return Err(MlxError::InvalidArgument(format!("{OP}: must be f32")));
74 }
75 let pipeline = registry.get_pipeline(OP, device)?;
76 let n_u64 = n as u64;
77 encoder.encode(
78 pipeline,
79 &[(0, y), (1, dy), (2, dx), (3, params)],
80 MTLSize::new(n_u64, 1, 1),
81 MTLSize::new(std::cmp::min(256, n_u64), 1, 1),
82 );
83 Ok(())
84}
85
#[cfg(test)]
mod tests {
    use super::*;
    use crate::device::MlxDevice;

    /// Allocates an `n`-element f32 device buffer, zero-filled host-side.
    fn alloc_f32(d: &MlxDevice, n: usize) -> MlxBuffer {
        let mut b = d.alloc_buffer(n * 4, DType::F32, vec![n]).unwrap();
        b.as_mut_slice::<f32>().unwrap().fill(0.0);
        b
    }

    /// Builds the kernel params buffer: a single u32 element count.
    fn make_params(d: &MlxDevice, n: u32) -> MlxBuffer {
        let mut p = d.alloc_buffer(4, DType::U32, vec![1]).unwrap();
        p.as_mut_slice::<u32>().unwrap()[0] = n;
        p
    }

    #[test]
    fn forward_matches_cpu_oracle() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        // Fix: install this module's shader sources before dispatching —
        // the tests previously used a bare registry, leaving nothing for
        // get_pipeline("sqrt_f32") to resolve.
        register(&mut registry);
        let n = 32usize;
        // Strictly positive inputs keep sqrt well-defined.
        let x: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.3).collect();

        let mut x_buf = alloc_f32(&device, n);
        x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
        let y_buf = alloc_f32(&device, n);
        let p = make_params(&device, n as u32);

        let mut encoder = device.command_encoder().unwrap();
        dispatch_sqrt_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &y_buf, &p,
        ).unwrap();
        encoder.commit_and_wait().unwrap();

        // Compare against an f64 CPU oracle with a relative tolerance.
        let gpu = y_buf.as_slice::<f32>().unwrap();
        for i in 0..n {
            let cpu = (x[i] as f64).sqrt() as f32;
            assert!(
                (gpu[i] - cpu).abs() < 1e-6 * cpu.abs().max(1.0),
                "y[{i}]: gpu={} cpu={} (x={})",
                gpu[i], cpu, x[i]
            );
        }
    }

    #[test]
    fn backward_finite_difference_falsifier() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        // Fix: register kernel sources (same omission as the forward test).
        register(&mut registry);
        let n = 16usize;
        let x: Vec<f32> = (0..n).map(|i| 0.5 + (i as f32) * 0.1).collect();

        let mut x_buf = alloc_f32(&device, n);
        x_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&x);
        let y_buf = alloc_f32(&device, n);
        let p = make_params(&device, n as u32);
        let mut encoder = device.command_encoder().unwrap();
        dispatch_sqrt_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x_buf, &y_buf, &p,
        ).unwrap();
        // Forward and backward share one command buffer; the barrier orders
        // the backward kernel's reads of y after the forward kernel's writes.
        encoder.memory_barrier();

        let dy_ones = vec![1.0f32; n];
        let mut dy_buf = alloc_f32(&device, n);
        dy_buf.as_mut_slice::<f32>().unwrap().copy_from_slice(&dy_ones);
        let dx_buf = alloc_f32(&device, n);
        dispatch_sqrt_backward_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &y_buf, &dy_buf, &dx_buf, &p,
        ).unwrap();
        encoder.commit_and_wait().unwrap();
        let dx = dx_buf.as_slice::<f32>().unwrap().to_vec();

        // Central finite differences on the scalar loss L = sum(sqrt(x)):
        // with dy = 1 everywhere, dx[i] should match dL/dx[i].
        let h = 1e-3f64;
        for i in 0..n {
            let mut xp = x.clone(); xp[i] += h as f32;
            let mut xm = x.clone(); xm[i] -= h as f32;
            let lp: f64 = xp.iter().map(|v| (*v as f64).sqrt()).sum();
            let lm: f64 = xm.iter().map(|v| (*v as f64).sqrt()).sum();
            let fd = (lp - lm) / (2.0 * h);
            let tol = 1e-2 * fd.abs().max(1.0);
            assert!(
                (dx[i] as f64 - fd).abs() < tol,
                "FD x[{i}]: analytic={} fd={}", dx[i], fd
            );
        }
    }

    #[test]
    fn rejects_size_mismatch() {
        let device = MlxDevice::new().unwrap();
        let mut registry = KernelRegistry::new();
        // Registered for consistency, though validation should fail before
        // any pipeline lookup happens.
        register(&mut registry);
        // Output (8 elements) deliberately smaller than input (16).
        let x = alloc_f32(&device, 16);
        let y = alloc_f32(&device, 8);
        let p = make_params(&device, 16);
        let mut encoder = device.command_encoder().unwrap();
        let res = dispatch_sqrt_f32(
            &mut encoder, &mut registry, device.metal_device(),
            &x, &y, &p,
        );
        assert!(res.is_err());
    }
}