1use burn_backend::ops::FloatTensorOps;
2use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntTensor};
3use burn_backend::{DType, Distribution, ExecutionError, Scalar, TensorData};
4use burn_std::{FloatDType, Shape, Slice};
5use std::future::Future;
6
7use crate::bridge::{self, burn_to_mps_dtype};
8use crate::ffi::{self};
9use crate::{MpsGraph, MpsGraphTensor};
10
/// Expands to a method `fn $name(tensor) -> FloatTensor<MpsGraph>` that hands
/// `tensor` to `bridge::run_unary` and, inside the callback, emits the unary
/// graph node named by `$selector` through `ffi::graph_unary`.
macro_rules! unary_op {
    ($name:ident, $selector:expr) => {
        fn $name(tensor: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
            bridge::run_unary(&tensor, |graph, input| {
                // SAFETY: assumes `graph` and `input` are the live handles that
                // `bridge::run_unary` passes to this callback — TODO confirm.
                unsafe { ffi::graph_unary(graph, $selector, input) }
            })
        }
    };
}
21
/// Expands to a method `fn $name(lhs, rhs) -> FloatTensor<MpsGraph>` that hands
/// both operands to `bridge::run_binary` and, inside the callback, emits the
/// binary graph node named by `$selector` through `ffi::graph_binary`.
macro_rules! binary_op {
    ($name:ident, $selector:expr) => {
        fn $name(
            lhs: FloatTensor<MpsGraph>,
            rhs: FloatTensor<MpsGraph>,
        ) -> FloatTensor<MpsGraph> {
            bridge::run_binary(&lhs, &rhs, |graph, primary, secondary| {
                // SAFETY: assumes the three handles are the live ones that
                // `bridge::run_binary` passes to this callback — TODO confirm.
                unsafe { ffi::graph_binary(graph, $selector, primary, secondary) }
            })
        }
    };
}
30
/// Expands to a method `fn $name(lhs, rhs: Scalar) -> FloatTensor<MpsGraph>`.
///
/// The scalar is read as `f64`, turned into a graph constant carrying the MPS
/// dtype that corresponds to `lhs.dtype`, and combined with the tensor
/// placeholder using the binary node named by `$selector`.
macro_rules! scalar_op {
    ($name:ident, $selector:expr) => {
        fn $name(lhs: FloatTensor<MpsGraph>, rhs: Scalar) -> FloatTensor<MpsGraph> {
            // SAFETY: assumes `graph` and `input` are the live handles that
            // `bridge::run_unary_ctx` passes to this callback — TODO confirm.
            bridge::run_unary_ctx(&lhs, |graph, input| unsafe {
                let value = rhs.elem::<f64>();
                let mps_dtype = burn_to_mps_dtype(lhs.dtype);
                let constant = ffi::graph_constant_scalar(graph, value, mps_dtype);
                ffi::graph_binary(graph, $selector, input, constant)
            })
        }
    };
}
42
/// Expands to a comparison method `fn $name(lhs, rhs) -> BoolTensor<MpsGraph>`
/// that hands both float operands to `bridge::run_binary` and emits the
/// comparison node named by `$selector` through `ffi::graph_binary`.
macro_rules! cmp_op {
    ($name:ident, $selector:expr) => {
        fn $name(
            lhs: FloatTensor<MpsGraph>,
            rhs: FloatTensor<MpsGraph>,
        ) -> BoolTensor<MpsGraph> {
            bridge::run_binary(&lhs, &rhs, |graph, primary, secondary| {
                // SAFETY: assumes the three handles are the live ones that
                // `bridge::run_binary` passes to this callback — TODO confirm.
                unsafe { ffi::graph_binary(graph, $selector, primary, secondary) }
            })
        }
    };
}
51
/// Expands to a comparison method
/// `fn $name(lhs, rhs: Scalar) -> BoolTensor<MpsGraph>`.
///
/// The scalar is read as `f64`, turned into a graph constant carrying the MPS
/// dtype that corresponds to `lhs.dtype`, and compared against the tensor
/// placeholder using the node named by `$selector`.
macro_rules! cmp_scalar_op {
    ($name:ident, $selector:expr) => {
        fn $name(lhs: FloatTensor<MpsGraph>, rhs: Scalar) -> BoolTensor<MpsGraph> {
            // NOTE(review): the returned tensor's `dtype` field is not touched
            // here; presumably `bridge::run_unary_ctx` derives it — confirm it
            // ends up as a boolean dtype.
            // SAFETY: assumes `graph` and `input` are the live handles that
            // `bridge::run_unary_ctx` passes to this callback — TODO confirm.
            bridge::run_unary_ctx(&lhs, |graph, input| unsafe {
                let value = rhs.elem::<f64>();
                let mps_dtype = burn_to_mps_dtype(lhs.dtype);
                let constant = ffi::graph_constant_scalar(graph, value, mps_dtype);
                ffi::graph_binary(graph, $selector, input, constant)
            })
        }
    };
}
62
// Selector strings handed to `ffi::graph_binary` for element-wise arithmetic.

/// Selector for element-wise addition.
pub(crate) const ADD: &str = "additionWithPrimaryTensor:secondaryTensor:name:";
/// Selector for element-wise subtraction.
pub(crate) const SUB: &str = "subtractionWithPrimaryTensor:secondaryTensor:name:";
/// Selector for element-wise multiplication.
pub(crate) const MUL: &str = "multiplicationWithPrimaryTensor:secondaryTensor:name:";
/// Selector for element-wise division.
pub(crate) const DIV: &str = "divisionWithPrimaryTensor:secondaryTensor:name:";
/// Selector for element-wise modulo.
pub(crate) const MOD: &str = "moduloWithPrimaryTensor:secondaryTensor:name:";
/// Selector for element-wise power.
pub(crate) const POW: &str = "powerWithPrimaryTensor:secondaryTensor:name:";

// Selector strings handed to `ffi::graph_binary` for element-wise comparisons.

/// Selector for `==`.
pub(crate) const EQ: &str = "equalWithPrimaryTensor:secondaryTensor:name:";
/// Selector for `>`.
pub(crate) const GT: &str = "greaterThanWithPrimaryTensor:secondaryTensor:name:";
/// Selector for `>=`.
pub(crate) const GTE: &str = "greaterThanOrEqualToWithPrimaryTensor:secondaryTensor:name:";
/// Selector for `<`.
pub(crate) const LT: &str = "lessThanWithPrimaryTensor:secondaryTensor:name:";
/// Selector for `<=`.
pub(crate) const LTE: &str = "lessThanOrEqualToWithPrimaryTensor:secondaryTensor:name:";
74
75impl FloatTensorOps<MpsGraph> for MpsGraph {
76 fn float_from_data(data: TensorData, device: &Device<MpsGraph>) -> FloatTensor<MpsGraph> {
77 bridge::tensor_from_bytes(data.as_bytes(), Shape::from(data.shape.clone()), data.dtype, *device)
78 }
79
80 fn float_random(shape: Shape, distribution: Distribution, device: &Device<MpsGraph>) -> FloatTensor<MpsGraph> {
81 let n = shape.num_elements();
82 let mut buf = vec![0f32; n];
83 let mut rng = crate::ops::get_seeded_rng();
84 use rand::Rng;
85 match distribution {
86 Distribution::Default => buf.iter_mut().for_each(|v| *v = rng.gen_range(0.0..1.0)),
87 Distribution::Bernoulli(p) => buf.iter_mut().for_each(|v| *v = if rng.gen_range(0.0..1.0) < p { 1.0 } else { 0.0 }),
88 Distribution::Uniform(lo, hi) => buf.iter_mut().for_each(|v| *v = rng.gen_range(lo as f32..hi as f32)),
89 Distribution::Normal(mu, sigma) => buf.iter_mut().for_each(|v| {
90 let u1: f64 = rng.gen_range(1e-7..1.0);
91 let u2: f64 = rng.gen_range(0.0..1.0);
92 *v = (mu + sigma * (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos()) as f32;
93 }),
94 }
95 let bytes = unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const u8, n * 4) };
96 bridge::tensor_from_bytes(bytes, shape, burn_backend::DType::F32, *device)
97 }
98
99 fn float_into_data(t: FloatTensor<MpsGraph>) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send {
100 async move {
101 let bytes = bridge::tensor_to_bytes(&t);
102 Ok(TensorData::from_bytes_vec(bytes, t.shape.clone(), t.dtype))
103 }
104 }
105
106 fn float_device(t: &FloatTensor<MpsGraph>) -> Device<MpsGraph> { t.device }
107
108 fn float_to_device(t: FloatTensor<MpsGraph>, device: &Device<MpsGraph>) -> FloatTensor<MpsGraph> {
109 { let buf = unsafe { crate::ffi::retain(t.buffer) }; MpsGraphTensor { buffer: buf, shape: t.shape.clone(), dtype: t.dtype, device: *device } }
110 }
111
112 fn float_into_int(t: FloatTensor<MpsGraph>) -> IntTensor<MpsGraph> {
113 let mut r = bridge::run_unary_ctx(&t, |g, ph| unsafe {
114 ffi::graph_cast(g, ph, ffi::MPSDataType::INT32)
115 });
116 r.dtype = burn_backend::DType::I32;
117 r
118 }
119
120 fn float_empty(shape: Shape, device: &Device<MpsGraph>, dtype: FloatDType) -> FloatTensor<MpsGraph> {
121 let dt: burn_backend::DType = dtype.into();
122 bridge::tensor_zeros(shape, dt, *device)
123 }
124
125 binary_op!(float_add, ADD);
127 scalar_op!(float_add_scalar, ADD);
128 binary_op!(float_sub, SUB);
129 scalar_op!(float_sub_scalar, SUB);
130 binary_op!(float_mul, MUL);
131 scalar_op!(float_mul_scalar, MUL);
132 binary_op!(float_div, DIV);
133 scalar_op!(float_div_scalar, DIV);
134 binary_op!(float_remainder, MOD);
135 scalar_op!(float_remainder_scalar, MOD);
136 binary_op!(float_powf, POW);
137
138 fn float_powf_scalar_impl(t: FloatTensor<MpsGraph>, v: Scalar) -> FloatTensor<MpsGraph> {
139 bridge::run_unary_ctx(&t, |g, ph| unsafe {
140 let s = ffi::graph_constant_scalar(g, v.elem::<f64>(), burn_to_mps_dtype(t.dtype));
141 ffi::graph_binary(g, POW, ph, s)
142 })
143 }
144
145 fn float_matmul(a: FloatTensor<MpsGraph>, b: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
146 bridge::run_binary(&a, &b, |g, pa, pb| unsafe { ffi::graph_matmul(g, pa, pb) })
147 }
148
149 fn float_recip(t: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
150 unary_op_impl(&t, "reciprocalWithTensor:name:")
151 }
152
153 fn float_cross(lhs: FloatTensor<MpsGraph>, rhs: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
154 let sl = |shape: &Shape, idx: usize| -> Vec<Slice> {
155 (0..shape.num_dims()).map(|d| {
156 if d == dim { Slice::new(idx as isize, Some((idx+1) as isize), 1) }
157 else { Slice::new(0, Some(shape[d] as isize), 1) }
158 }).collect()
159 };
160 let (a0,a1,a2) = (Self::float_slice(lhs.clone(), &sl(&lhs.shape,0)),
161 Self::float_slice(lhs.clone(), &sl(&lhs.shape,1)),
162 Self::float_slice(lhs.clone(), &sl(&lhs.shape,2)));
163 let (b0,b1,b2) = (Self::float_slice(rhs.clone(), &sl(&rhs.shape,0)),
164 Self::float_slice(rhs.clone(), &sl(&rhs.shape,1)),
165 Self::float_slice(rhs.clone(), &sl(&rhs.shape,2)));
166 let c0 = Self::float_sub(Self::float_mul(a1.clone(),b2.clone()), Self::float_mul(a2.clone(),b1.clone()));
167 let c1 = Self::float_sub(Self::float_mul(a2,b0.clone()), Self::float_mul(a0.clone(),b2));
168 let c2 = Self::float_sub(Self::float_mul(a0,b1), Self::float_mul(a1,b0));
169 Self::float_cat(vec![c0,c1,c2], dim)
170 }
171
172 cmp_op!(float_equal, EQ);
174 cmp_scalar_op!(float_equal_elem, EQ);
175 cmp_op!(float_greater, GT);
176 cmp_scalar_op!(float_greater_elem, GT);
177 cmp_op!(float_greater_equal, GTE);
178 cmp_scalar_op!(float_greater_equal_elem, GTE);
179 cmp_op!(float_lower, LT);
180 cmp_scalar_op!(float_lower_elem, LT);
181 cmp_op!(float_lower_equal, LTE);
182 cmp_scalar_op!(float_lower_equal_elem, LTE);
183
184 unary_op!(float_exp, "exponentWithTensor:name:");
186 unary_op!(float_log, "logarithmWithTensor:name:");
187 unary_op!(float_sqrt, "squareRootWithTensor:name:");
188 unary_op!(float_abs, "absoluteWithTensor:name:");
189 unary_op!(float_cos, "cosWithTensor:name:");
190 unary_op!(float_sin, "sinWithTensor:name:");
191 unary_op!(float_tan, "tanWithTensor:name:");
192 unary_op!(float_cosh, "coshWithTensor:name:");
193 unary_op!(float_sinh, "sinhWithTensor:name:");
194 unary_op!(float_tanh, "tanhWithTensor:name:");
195 unary_op!(float_acos, "acosWithTensor:name:");
196 unary_op!(float_acosh, "acoshWithTensor:name:");
197 unary_op!(float_asin, "asinWithTensor:name:");
198 unary_op!(float_asinh, "asinhWithTensor:name:");
199 unary_op!(float_atan, "atanWithTensor:name:");
200 unary_op!(float_atanh, "atanhWithTensor:name:");
201 unary_op!(float_erf, "erfWithTensor:name:");
202 unary_op!(float_floor, "floorWithTensor:name:");
203 unary_op!(float_ceil, "ceilWithTensor:name:");
204 unary_op!(float_round, "rintWithTensor:name:");
205
206 fn float_atan2(lhs: FloatTensor<MpsGraph>, rhs: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
207 bridge::run_binary(&lhs, &rhs, |g,a,b| unsafe { ffi::graph_binary(g, "atan2WithPrimaryTensor:secondaryTensor:name:", a, b) })
208 }
209
210 fn float_log1p(t: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
211 Self::float_log(Self::float_add_scalar(t, 1.0f32.into()))
212 }
213
214 fn float_trunc(t: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
215 bridge::run_unary_ctx(&t, |g, ph| unsafe {
216 let abs = ffi::graph_unary(g, "absoluteWithTensor:name:", ph);
217 let fl = ffi::graph_unary(g, "floorWithTensor:name:", abs);
218 let sgn = ffi::graph_unary(g, "signWithTensor:name:", ph);
219 ffi::graph_binary(g, MUL, sgn, fl)
220 })
221 }
222
223 fn float_swap_dims(t: FloatTensor<MpsGraph>, d1: usize, d2: usize) -> FloatTensor<MpsGraph> {
226 bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_transpose(g, ph, d1, d2) })
227 }
228
229 fn float_permute(t: FloatTensor<MpsGraph>, axes: &[usize]) -> FloatTensor<MpsGraph> {
230 let perm_ns = unsafe { ffi::ns_usize_array(axes) };
231 bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_permute(g, ph, perm_ns) })
232 }
233
234 fn float_flip(t: FloatTensor<MpsGraph>, axes: &[usize]) -> FloatTensor<MpsGraph> {
235 let nd = t.shape.num_dims();
236 let shape = &t.shape;
237 let starts: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { shape[d] as isize - 1 } else { 0 }).collect();
238 let ends: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { -(shape[d] as isize) - 1 } else { shape[d] as isize }).collect();
239 let strides: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { -1 } else { 1 }).collect();
240 bridge::run_unary_ctx(&t, |g, ph| unsafe {
241 ffi::graph_slice_masked(g, ph,
242 ffi::ns_isize_array(&starts), ffi::ns_isize_array(&ends),
243 ffi::ns_isize_array(&strides), 0, 0, 0)
244 })
245 }
246
247 fn float_reshape(t: FloatTensor<MpsGraph>, shape: Shape) -> FloatTensor<MpsGraph> {
248 let ns = bridge::shape_to_ns(&shape);
249 bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reshape(g, ph, ns) })
250 }
251
252 fn float_expand(t: FloatTensor<MpsGraph>, shape: Shape) -> FloatTensor<MpsGraph> {
253 let ns = bridge::shape_to_ns(&shape);
254 bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_broadcast(g, ph, ns) })
255 }
256
257 fn float_slice(t: FloatTensor<MpsGraph>, slices: &[Slice]) -> FloatTensor<MpsGraph> {
258 let (sa, ea, st) = bridge::slices_to_ns(slices, &t.shape);
259 bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_slice(g, ph, sa, ea, st) })
260 }
261
262 fn float_slice_assign(t: FloatTensor<MpsGraph>, slices: &[Slice], value: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
263 let (sa, ea, st) = bridge::slices_to_ns(slices, &t.shape);
264 bridge::run_binary_ctx(&t, &value, |g, pd, pu| unsafe {
265 ffi::graph_slice_update(g, pd, pu, sa, ea, st)
266 })
267 }
268
269 fn float_cat(tensors: Vec<FloatTensor<MpsGraph>>, dim: usize) -> FloatTensor<MpsGraph> {
270 if tensors.len() == 1 { return tensors.into_iter().next().unwrap(); }
271 let refs: Vec<&MpsGraphTensor> = tensors.iter().collect();
272 bridge::run_multi_ctx(&refs, tensors[0].device, |g, phs| unsafe {
273 let arr = ffi::ns_array(phs);
274 ffi::graph_concat(g, arr, dim as isize)
275 })
276 }
277
278 fn float_gather(dim: usize, t: FloatTensor<MpsGraph>, idx: IntTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
281 bridge::run_binary_ctx(&t, &idx, |g,a,b| unsafe { ffi::graph_gather(g, a, b, dim, 0) })
282 }
283
284 fn float_scatter_add(dim: usize, t: FloatTensor<MpsGraph>, idx: IntTensor<MpsGraph>, val: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
285 bridge::run_multi_ctx(&[&t, &idx, &val], t.device, |g, phs| unsafe {
286 ffi::graph_scatter_along(g, dim as isize, phs[0], phs[2], phs[1], ffi::MPSGraphScatterMode::ADD)
287 })
288 }
289
290 fn float_select(t: FloatTensor<MpsGraph>, dim: usize, idx: IntTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
291 Self::float_gather(dim, t, idx)
292 }
293
294 fn float_select_add(t: FloatTensor<MpsGraph>, dim: usize, idx: IntTensor<MpsGraph>, val: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
295 Self::float_scatter_add(dim, t, idx, val)
296 }
297
298 fn float_mask_where(t: FloatTensor<MpsGraph>, mask: BoolTensor<MpsGraph>, val: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
301 bridge::run_ternary(&t, &mask, &val, |g, pt, pm, pv| unsafe { ffi::graph_select(g, pm, pv, pt) })
302 }
303
304 fn float_mask_fill(t: FloatTensor<MpsGraph>, mask: BoolTensor<MpsGraph>, val: Scalar) -> FloatTensor<MpsGraph> {
305 bridge::run_binary_ctx(&t, &mask, |g, pt, pm| unsafe {
306 let s = ffi::graph_constant_scalar(g, val.elem::<f64>(), burn_to_mps_dtype(t.dtype));
307 ffi::graph_select(g, pm, s, pt)
308 })
309 }
310
311 fn float_sum(t: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
314 let axes: Vec<isize> = (0..t.shape.num_dims() as isize).collect();
315 bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_sum(g, ph, ffi::ns_isize_array(&axes)) })
316 }
317
318 fn float_sum_dim(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
319 bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_sum_axis(g, ph, dim as isize) })
320 }
321
322 fn float_mean_dim(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
323 let n = t.shape[dim] as f64;
324 let sum = Self::float_sum_dim(t, dim);
325 Self::float_div_scalar(sum, Scalar::from(n as f32))
326 }
327
328 fn float_argmax(t: FloatTensor<MpsGraph>, dim: usize) -> IntTensor<MpsGraph> {
329 let mut r = bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_argmax(g, ph, dim as isize) });
330 r.dtype = DType::I32; r
331 }
332
333 fn float_argmin(t: FloatTensor<MpsGraph>, dim: usize) -> IntTensor<MpsGraph> {
334 let mut r = bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_argmin(g, ph, dim as isize) });
335 r.dtype = DType::I32; r
336 }
337
338 fn float_cumsum(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cumsum(g,ph,dim as isize) }) }
341 fn float_cumprod(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cumprod(g,ph,dim as isize) }) }
342 fn float_cummin(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cummin(g,ph,dim as isize) }) }
343 fn float_cummax(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cummax(g,ph,dim as isize) }) }
344
345 fn float_sort(t: FloatTensor<MpsGraph>, dim: usize, desc: bool) -> FloatTensor<MpsGraph> {
348 bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_sort(g, ph, dim as isize, desc) })
349 }
350
351 fn float_argsort(t: FloatTensor<MpsGraph>, dim: usize, desc: bool) -> IntTensor<MpsGraph> {
352 let mut r = bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_argsort(g, ph, dim as isize, desc) });
353 r.dtype = DType::I32; r
354 }
355
356 fn float_cast(t: FloatTensor<MpsGraph>, dtype: FloatDType) -> FloatTensor<MpsGraph> {
359 let dt: DType = dtype.into();
360 if t.dtype == dt { return t; }
361 let mps = burn_to_mps_dtype(dt);
362 let mut r = bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_cast(g, ph, mps) });
363 r.dtype = dt; r
364 }
365
366 fn float_prod(t: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
369 let axes: Vec<isize> = (0..t.shape.num_dims() as isize).collect();
370 bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_prod(g, ph, ffi::ns_isize_array(&axes)) })
371 }
372
373 fn float_prod_dim(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
374 bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_prod_axis(g, ph, dim as isize) })
375 }
376
377 fn float_max_dim(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
378 bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_max_axis(g, ph, dim as isize) })
379 }
380
381 fn float_min_dim(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
382 bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_min_axis(g, ph, dim as isize) })
383 }
384
385 fn float_unfold(t: FloatTensor<MpsGraph>, dim: usize, size: usize, step: usize) -> FloatTensor<MpsGraph> {
388 let dim_size = t.shape[dim];
389 let n_win = (dim_size.saturating_sub(size)) / step + 1;
390 let mut windows = Vec::with_capacity(n_win);
391 for i in 0..n_win {
392 let start = i * step;
393 let slices: Vec<Slice> = (0..t.shape.num_dims()).map(|d| {
394 if d == dim { Slice::new(start as isize, Some((start+size) as isize), 1) }
395 else { Slice::new(0, Some(t.shape[d] as isize), 1) }
396 }).collect();
397 let w = Self::float_slice(t.clone(), &slices);
398 let mut dims: Vec<usize> = (0..w.shape.num_dims()).map(|d| w.shape[d]).collect();
399 dims[dim] = 1; dims.push(size);
400 windows.push(Self::float_reshape(w, Shape::from(dims)));
401 }
402 Self::float_cat(windows, dim)
403 }
404}
405
406fn unary_op_impl(t: &MpsGraphTensor, sel: &'static str) -> MpsGraphTensor {
408 bridge::run_unary_ctx(t, |g, ph| unsafe { ffi::graph_unary(g, sel, ph) })
409}