use burn_backend::ops::FloatTensorOps;
use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntTensor};
use burn_backend::{DType, Distribution, ExecutionError, Scalar, TensorData};
use burn_std::{FloatDType, Shape, Slice};
use std::future::Future;
use crate::bridge::{self, burn_to_mps_dtype};
use crate::ffi::{self};
use crate::{MpsGraph, MpsGraphTensor};
/// Generates a trait method applying a single-input MPSGraph operation,
/// identified by its Objective-C selector string, to one float tensor.
macro_rules! unary_op {
    ($name:ident, $selector:expr) => {
        fn $name(input: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
            bridge::run_unary(&input, |graph, placeholder| unsafe {
                ffi::graph_unary(graph, $selector, placeholder)
            })
        }
    };
}
/// Generates a trait method combining two float tensors through a binary
/// MPSGraph operation identified by its Objective-C selector string.
macro_rules! binary_op {
    ($name:ident, $selector:expr) => {
        fn $name(lhs: FloatTensor<MpsGraph>, rhs: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
            bridge::run_binary(&lhs, &rhs, |graph, pa, pb| unsafe {
                ffi::graph_binary(graph, $selector, pa, pb)
            })
        }
    };
}
/// Generates a tensor-scalar trait method: the scalar is materialized as a
/// graph constant in the tensor's dtype, then combined via `$selector`.
macro_rules! scalar_op {
    ($name:ident, $selector:expr) => {
        fn $name(lhs: FloatTensor<MpsGraph>, rhs: Scalar) -> FloatTensor<MpsGraph> {
            bridge::run_unary_ctx(&lhs, |graph, placeholder| unsafe {
                let constant =
                    ffi::graph_constant_scalar(graph, rhs.elem::<f64>(), burn_to_mps_dtype(lhs.dtype));
                ffi::graph_binary(graph, $selector, placeholder, constant)
            })
        }
    };
}
/// Generates a tensor-tensor comparison method; the MPSGraph comparison
/// selector produces the boolean result tensor.
macro_rules! cmp_op {
    ($name:ident, $selector:expr) => {
        fn $name(lhs: FloatTensor<MpsGraph>, rhs: FloatTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
            bridge::run_binary(&lhs, &rhs, |graph, pa, pb| unsafe {
                ffi::graph_binary(graph, $selector, pa, pb)
            })
        }
    };
}
/// Generates a tensor-scalar comparison method: the scalar becomes a graph
/// constant in the tensor's dtype, compared element-wise via `$selector`.
macro_rules! cmp_scalar_op {
    ($name:ident, $selector:expr) => {
        fn $name(lhs: FloatTensor<MpsGraph>, rhs: Scalar) -> BoolTensor<MpsGraph> {
            bridge::run_unary_ctx(&lhs, |graph, placeholder| unsafe {
                let constant =
                    ffi::graph_constant_scalar(graph, rhs.elem::<f64>(), burn_to_mps_dtype(lhs.dtype));
                ffi::graph_binary(graph, $selector, placeholder, constant)
            })
        }
    };
}
// Objective-C selector strings for MPSGraph's binary element-wise operations.
// These are passed verbatim to `ffi::graph_binary`, which dispatches the
// selector on the graph object — do not edit the strings themselves.
pub(crate) const ADD: &str = "additionWithPrimaryTensor:secondaryTensor:name:";
pub(crate) const SUB: &str = "subtractionWithPrimaryTensor:secondaryTensor:name:";
pub(crate) const MUL: &str = "multiplicationWithPrimaryTensor:secondaryTensor:name:";
pub(crate) const DIV: &str = "divisionWithPrimaryTensor:secondaryTensor:name:";
pub(crate) const MOD: &str = "moduloWithPrimaryTensor:secondaryTensor:name:";
pub(crate) const POW: &str = "powerWithPrimaryTensor:secondaryTensor:name:";
// Comparison selectors; results are boolean tensors.
pub(crate) const EQ: &str = "equalWithPrimaryTensor:secondaryTensor:name:";
pub(crate) const GT: &str = "greaterThanWithPrimaryTensor:secondaryTensor:name:";
pub(crate) const GTE: &str = "greaterThanOrEqualToWithPrimaryTensor:secondaryTensor:name:";
pub(crate) const LT: &str = "lessThanWithPrimaryTensor:secondaryTensor:name:";
pub(crate) const LTE: &str = "lessThanOrEqualToWithPrimaryTensor:secondaryTensor:name:";
impl FloatTensorOps<MpsGraph> for MpsGraph {
fn float_from_data(data: TensorData, device: &Device<MpsGraph>) -> FloatTensor<MpsGraph> {
bridge::tensor_from_bytes(data.as_bytes(), Shape::from(data.shape.clone()), data.dtype, *device)
}
fn float_random(shape: Shape, distribution: Distribution, device: &Device<MpsGraph>) -> FloatTensor<MpsGraph> {
let n = shape.num_elements();
let mut buf = vec![0f32; n];
let mut rng = crate::ops::get_seeded_rng();
use rand::Rng;
match distribution {
Distribution::Default => buf.iter_mut().for_each(|v| *v = rng.gen_range(0.0..1.0)),
Distribution::Bernoulli(p) => buf.iter_mut().for_each(|v| *v = if rng.gen_range(0.0..1.0) < p { 1.0 } else { 0.0 }),
Distribution::Uniform(lo, hi) => buf.iter_mut().for_each(|v| *v = rng.gen_range(lo as f32..hi as f32)),
Distribution::Normal(mu, sigma) => buf.iter_mut().for_each(|v| {
let u1: f64 = rng.gen_range(1e-7..1.0);
let u2: f64 = rng.gen_range(0.0..1.0);
*v = (mu + sigma * (-2.0 * u1.ln()).sqrt() * (std::f64::consts::TAU * u2).cos()) as f32;
}),
}
let bytes = unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const u8, n * 4) };
bridge::tensor_from_bytes(bytes, shape, burn_backend::DType::F32, *device)
}
fn float_into_data(t: FloatTensor<MpsGraph>) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send {
async move {
let bytes = bridge::tensor_to_bytes(&t);
Ok(TensorData::from_bytes_vec(bytes, t.shape.clone(), t.dtype))
}
}
fn float_device(t: &FloatTensor<MpsGraph>) -> Device<MpsGraph> { t.device }
fn float_to_device(t: FloatTensor<MpsGraph>, device: &Device<MpsGraph>) -> FloatTensor<MpsGraph> {
{ let buf = unsafe { crate::ffi::retain(t.buffer) }; MpsGraphTensor { buffer: buf, shape: t.shape.clone(), dtype: t.dtype, device: *device } }
}
fn float_into_int(t: FloatTensor<MpsGraph>) -> IntTensor<MpsGraph> {
let mut r = bridge::run_unary_ctx(&t, |g, ph| unsafe {
ffi::graph_cast(g, ph, ffi::MPSDataType::INT32)
});
r.dtype = burn_backend::DType::I32;
r
}
fn float_empty(shape: Shape, device: &Device<MpsGraph>, dtype: FloatDType) -> FloatTensor<MpsGraph> {
let dt: burn_backend::DType = dtype.into();
bridge::tensor_zeros(shape, dt, *device)
}
binary_op!(float_add, ADD);
scalar_op!(float_add_scalar, ADD);
binary_op!(float_sub, SUB);
scalar_op!(float_sub_scalar, SUB);
binary_op!(float_mul, MUL);
scalar_op!(float_mul_scalar, MUL);
binary_op!(float_div, DIV);
scalar_op!(float_div_scalar, DIV);
binary_op!(float_remainder, MOD);
scalar_op!(float_remainder_scalar, MOD);
binary_op!(float_powf, POW);
fn float_powf_scalar_impl(t: FloatTensor<MpsGraph>, v: Scalar) -> FloatTensor<MpsGraph> {
bridge::run_unary_ctx(&t, |g, ph| unsafe {
let s = ffi::graph_constant_scalar(g, v.elem::<f64>(), burn_to_mps_dtype(t.dtype));
ffi::graph_binary(g, POW, ph, s)
})
}
fn float_matmul(a: FloatTensor<MpsGraph>, b: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
bridge::run_binary(&a, &b, |g, pa, pb| unsafe { ffi::graph_matmul(g, pa, pb) })
}
fn float_recip(t: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
unary_op_impl(&t, "reciprocalWithTensor:name:")
}
fn float_cross(lhs: FloatTensor<MpsGraph>, rhs: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
let sl = |shape: &Shape, idx: usize| -> Vec<Slice> {
(0..shape.num_dims()).map(|d| {
if d == dim { Slice::new(idx as isize, Some((idx+1) as isize), 1) }
else { Slice::new(0, Some(shape[d] as isize), 1) }
}).collect()
};
let (a0,a1,a2) = (Self::float_slice(lhs.clone(), &sl(&lhs.shape,0)),
Self::float_slice(lhs.clone(), &sl(&lhs.shape,1)),
Self::float_slice(lhs.clone(), &sl(&lhs.shape,2)));
let (b0,b1,b2) = (Self::float_slice(rhs.clone(), &sl(&rhs.shape,0)),
Self::float_slice(rhs.clone(), &sl(&rhs.shape,1)),
Self::float_slice(rhs.clone(), &sl(&rhs.shape,2)));
let c0 = Self::float_sub(Self::float_mul(a1.clone(),b2.clone()), Self::float_mul(a2.clone(),b1.clone()));
let c1 = Self::float_sub(Self::float_mul(a2,b0.clone()), Self::float_mul(a0.clone(),b2));
let c2 = Self::float_sub(Self::float_mul(a0,b1), Self::float_mul(a1,b0));
Self::float_cat(vec![c0,c1,c2], dim)
}
cmp_op!(float_equal, EQ);
cmp_scalar_op!(float_equal_elem, EQ);
cmp_op!(float_greater, GT);
cmp_scalar_op!(float_greater_elem, GT);
cmp_op!(float_greater_equal, GTE);
cmp_scalar_op!(float_greater_equal_elem, GTE);
cmp_op!(float_lower, LT);
cmp_scalar_op!(float_lower_elem, LT);
cmp_op!(float_lower_equal, LTE);
cmp_scalar_op!(float_lower_equal_elem, LTE);
unary_op!(float_exp, "exponentWithTensor:name:");
unary_op!(float_log, "logarithmWithTensor:name:");
unary_op!(float_sqrt, "squareRootWithTensor:name:");
unary_op!(float_abs, "absoluteWithTensor:name:");
unary_op!(float_cos, "cosWithTensor:name:");
unary_op!(float_sin, "sinWithTensor:name:");
unary_op!(float_tan, "tanWithTensor:name:");
unary_op!(float_cosh, "coshWithTensor:name:");
unary_op!(float_sinh, "sinhWithTensor:name:");
unary_op!(float_tanh, "tanhWithTensor:name:");
unary_op!(float_acos, "acosWithTensor:name:");
unary_op!(float_acosh, "acoshWithTensor:name:");
unary_op!(float_asin, "asinWithTensor:name:");
unary_op!(float_asinh, "asinhWithTensor:name:");
unary_op!(float_atan, "atanWithTensor:name:");
unary_op!(float_atanh, "atanhWithTensor:name:");
unary_op!(float_erf, "erfWithTensor:name:");
unary_op!(float_floor, "floorWithTensor:name:");
unary_op!(float_ceil, "ceilWithTensor:name:");
unary_op!(float_round, "rintWithTensor:name:");
fn float_atan2(lhs: FloatTensor<MpsGraph>, rhs: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
bridge::run_binary(&lhs, &rhs, |g,a,b| unsafe { ffi::graph_binary(g, "atan2WithPrimaryTensor:secondaryTensor:name:", a, b) })
}
fn float_log1p(t: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
Self::float_log(Self::float_add_scalar(t, 1.0f32.into()))
}
fn float_trunc(t: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
bridge::run_unary_ctx(&t, |g, ph| unsafe {
let abs = ffi::graph_unary(g, "absoluteWithTensor:name:", ph);
let fl = ffi::graph_unary(g, "floorWithTensor:name:", abs);
let sgn = ffi::graph_unary(g, "signWithTensor:name:", ph);
ffi::graph_binary(g, MUL, sgn, fl)
})
}
fn float_swap_dims(t: FloatTensor<MpsGraph>, d1: usize, d2: usize) -> FloatTensor<MpsGraph> {
bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_transpose(g, ph, d1, d2) })
}
fn float_permute(t: FloatTensor<MpsGraph>, axes: &[usize]) -> FloatTensor<MpsGraph> {
let perm_ns = unsafe { ffi::ns_usize_array(axes) };
bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_permute(g, ph, perm_ns) })
}
fn float_flip(t: FloatTensor<MpsGraph>, axes: &[usize]) -> FloatTensor<MpsGraph> {
let nd = t.shape.num_dims();
let shape = &t.shape;
let starts: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { shape[d] as isize - 1 } else { 0 }).collect();
let ends: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { -(shape[d] as isize) - 1 } else { shape[d] as isize }).collect();
let strides: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { -1 } else { 1 }).collect();
bridge::run_unary_ctx(&t, |g, ph| unsafe {
ffi::graph_slice_masked(g, ph,
ffi::ns_isize_array(&starts), ffi::ns_isize_array(&ends),
ffi::ns_isize_array(&strides), 0, 0, 0)
})
}
fn float_reshape(t: FloatTensor<MpsGraph>, shape: Shape) -> FloatTensor<MpsGraph> {
let ns = bridge::shape_to_ns(&shape);
bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reshape(g, ph, ns) })
}
fn float_expand(t: FloatTensor<MpsGraph>, shape: Shape) -> FloatTensor<MpsGraph> {
let ns = bridge::shape_to_ns(&shape);
bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_broadcast(g, ph, ns) })
}
fn float_slice(t: FloatTensor<MpsGraph>, slices: &[Slice]) -> FloatTensor<MpsGraph> {
let (sa, ea, st) = bridge::slices_to_ns(slices, &t.shape);
bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_slice(g, ph, sa, ea, st) })
}
fn float_slice_assign(t: FloatTensor<MpsGraph>, slices: &[Slice], value: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
let (sa, ea, st) = bridge::slices_to_ns(slices, &t.shape);
bridge::run_binary_ctx(&t, &value, |g, pd, pu| unsafe {
ffi::graph_slice_update(g, pd, pu, sa, ea, st)
})
}
fn float_cat(tensors: Vec<FloatTensor<MpsGraph>>, dim: usize) -> FloatTensor<MpsGraph> {
if tensors.len() == 1 { return tensors.into_iter().next().unwrap(); }
let refs: Vec<&MpsGraphTensor> = tensors.iter().collect();
bridge::run_multi_ctx(&refs, tensors[0].device, |g, phs| unsafe {
let arr = ffi::ns_array(phs);
ffi::graph_concat(g, arr, dim as isize)
})
}
fn float_gather(dim: usize, t: FloatTensor<MpsGraph>, idx: IntTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
bridge::run_binary_ctx(&t, &idx, |g,a,b| unsafe { ffi::graph_gather(g, a, b, dim, 0) })
}
fn float_scatter_add(dim: usize, t: FloatTensor<MpsGraph>, idx: IntTensor<MpsGraph>, val: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
bridge::run_multi_ctx(&[&t, &idx, &val], t.device, |g, phs| unsafe {
ffi::graph_scatter_along(g, dim as isize, phs[0], phs[2], phs[1], ffi::MPSGraphScatterMode::ADD)
})
}
fn float_select(t: FloatTensor<MpsGraph>, dim: usize, idx: IntTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
Self::float_gather(dim, t, idx)
}
fn float_select_add(t: FloatTensor<MpsGraph>, dim: usize, idx: IntTensor<MpsGraph>, val: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
Self::float_scatter_add(dim, t, idx, val)
}
fn float_mask_where(t: FloatTensor<MpsGraph>, mask: BoolTensor<MpsGraph>, val: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
bridge::run_ternary(&t, &mask, &val, |g, pt, pm, pv| unsafe { ffi::graph_select(g, pm, pv, pt) })
}
fn float_mask_fill(t: FloatTensor<MpsGraph>, mask: BoolTensor<MpsGraph>, val: Scalar) -> FloatTensor<MpsGraph> {
bridge::run_binary_ctx(&t, &mask, |g, pt, pm| unsafe {
let s = ffi::graph_constant_scalar(g, val.elem::<f64>(), burn_to_mps_dtype(t.dtype));
ffi::graph_select(g, pm, s, pt)
})
}
fn float_sum(t: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
let axes: Vec<isize> = (0..t.shape.num_dims() as isize).collect();
bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_sum(g, ph, ffi::ns_isize_array(&axes)) })
}
fn float_sum_dim(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_sum_axis(g, ph, dim as isize) })
}
fn float_mean_dim(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
let n = t.shape[dim] as f64;
let sum = Self::float_sum_dim(t, dim);
Self::float_div_scalar(sum, Scalar::from(n as f32))
}
fn float_argmax(t: FloatTensor<MpsGraph>, dim: usize) -> IntTensor<MpsGraph> {
let mut r = bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_argmax(g, ph, dim as isize) });
r.dtype = DType::I32; r
}
fn float_argmin(t: FloatTensor<MpsGraph>, dim: usize) -> IntTensor<MpsGraph> {
let mut r = bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_argmin(g, ph, dim as isize) });
r.dtype = DType::I32; r
}
fn float_cumsum(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cumsum(g,ph,dim as isize) }) }
fn float_cumprod(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cumprod(g,ph,dim as isize) }) }
fn float_cummin(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cummin(g,ph,dim as isize) }) }
fn float_cummax(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cummax(g,ph,dim as isize) }) }
fn float_sort(t: FloatTensor<MpsGraph>, dim: usize, desc: bool) -> FloatTensor<MpsGraph> {
bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_sort(g, ph, dim as isize, desc) })
}
fn float_argsort(t: FloatTensor<MpsGraph>, dim: usize, desc: bool) -> IntTensor<MpsGraph> {
let mut r = bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_argsort(g, ph, dim as isize, desc) });
r.dtype = DType::I32; r
}
fn float_cast(t: FloatTensor<MpsGraph>, dtype: FloatDType) -> FloatTensor<MpsGraph> {
let dt: DType = dtype.into();
if t.dtype == dt { return t; }
let mps = burn_to_mps_dtype(dt);
let mut r = bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_cast(g, ph, mps) });
r.dtype = dt; r
}
fn float_prod(t: FloatTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
let axes: Vec<isize> = (0..t.shape.num_dims() as isize).collect();
bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_prod(g, ph, ffi::ns_isize_array(&axes)) })
}
fn float_prod_dim(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_prod_axis(g, ph, dim as isize) })
}
fn float_max_dim(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_max_axis(g, ph, dim as isize) })
}
fn float_min_dim(t: FloatTensor<MpsGraph>, dim: usize) -> FloatTensor<MpsGraph> {
bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reduction_min_axis(g, ph, dim as isize) })
}
fn float_unfold(t: FloatTensor<MpsGraph>, dim: usize, size: usize, step: usize) -> FloatTensor<MpsGraph> {
let dim_size = t.shape[dim];
let n_win = (dim_size.saturating_sub(size)) / step + 1;
let mut windows = Vec::with_capacity(n_win);
for i in 0..n_win {
let start = i * step;
let slices: Vec<Slice> = (0..t.shape.num_dims()).map(|d| {
if d == dim { Slice::new(start as isize, Some((start+size) as isize), 1) }
else { Slice::new(0, Some(t.shape[d] as isize), 1) }
}).collect();
let w = Self::float_slice(t.clone(), &slices);
let mut dims: Vec<usize> = (0..w.shape.num_dims()).map(|d| w.shape[d]).collect();
dims[dim] = 1; dims.push(size);
windows.push(Self::float_reshape(w, Shape::from(dims)));
}
Self::float_cat(windows, dim)
}
}
/// Shared helper: apply the single-input MPSGraph operation named by the
/// Objective-C selector `sel` to `t`.
fn unary_op_impl(t: &MpsGraphTensor, sel: &'static str) -> MpsGraphTensor {
    bridge::run_unary_ctx(t, |graph, placeholder| unsafe {
        ffi::graph_unary(graph, sel, placeholder)
    })
}