use burn_backend::ops::{BoolTensorOps, IntTensorOps};
use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntTensor};
use burn_backend::{DType, ExecutionError, Scalar, TensorData};
use burn_std::{Shape, Slice};
use std::future::Future;
use crate::bridge::{self, burn_to_mps_dtype};
use crate::ffi::{self};
use crate::{MpsGraph, MpsGraphTensor};
impl BoolTensorOps<MpsGraph> for MpsGraph {
fn bool_empty(shape: Shape, device: &Device<MpsGraph>) -> BoolTensor<MpsGraph> { bridge::tensor_zeros(shape, DType::Bool, *device) }
fn bool_zeros(shape: Shape, device: &Device<MpsGraph>) -> BoolTensor<MpsGraph> { bridge::tensor_zeros(shape, DType::Bool, *device) }
fn bool_ones(shape: Shape, device: &Device<MpsGraph>) -> BoolTensor<MpsGraph> {
let n = shape.num_elements();
bridge::tensor_from_bytes(&vec![1u8; n], shape, DType::Bool, *device)
}
fn bool_into_data(t: BoolTensor<MpsGraph>) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send {
async move { Ok(TensorData::from_bytes_vec(bridge::tensor_to_bytes(&t), t.shape.clone(), t.dtype)) }
}
fn bool_from_data(data: TensorData, device: &Device<MpsGraph>) -> BoolTensor<MpsGraph> {
bridge::tensor_from_bytes(data.as_bytes(), Shape::from(data.shape.clone()), data.dtype, *device)
}
fn bool_into_int(t: BoolTensor<MpsGraph>) -> IntTensor<MpsGraph> {
let mut r = bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cast(g,ph,ffi::MPSDataType::INT32) });
r.dtype = DType::I32; r
}
fn bool_into_float(t: BoolTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
let mut r = bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cast(g,ph,ffi::MPSDataType::FLOAT32) });
r.dtype = DType::F32; r
}
fn bool_device(t: &BoolTensor<MpsGraph>) -> Device<MpsGraph> { t.device }
fn bool_to_device(t: BoolTensor<MpsGraph>, d: &Device<MpsGraph>) -> BoolTensor<MpsGraph> { { let buf = unsafe { crate::ffi::retain(t.buffer) }; MpsGraphTensor { buffer: buf, shape: t.shape.clone(), dtype: t.dtype, device: *d } } }
fn bool_reshape(t: BoolTensor<MpsGraph>, shape: Shape) -> BoolTensor<MpsGraph> {
let ns = bridge::shape_to_ns(&shape);
bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_reshape(g,ph,ns) })
}
fn bool_slice(t: BoolTensor<MpsGraph>, slices: &[Slice]) -> BoolTensor<MpsGraph> {
let (sa,ea,st) = bridge::slices_to_ns(slices, &t.shape);
bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_slice(g,ph,sa,ea,st) })
}
fn bool_slice_assign(t: BoolTensor<MpsGraph>, slices: &[Slice], v: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
let (sa,ea,st) = bridge::slices_to_ns(slices, &t.shape);
bridge::run_binary_ctx(&t,&v, |g,pd,pu| unsafe { ffi::graph_slice_update(g,pd,pu,sa,ea,st) })
}
fn bool_mask_where(t: BoolTensor<MpsGraph>, m: BoolTensor<MpsGraph>, v: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
bridge::run_ternary(&t,&m,&v, |g,pt,pm,pv| unsafe { ffi::graph_select(g,pm,pv,pt) })
}
fn bool_mask_fill(t: BoolTensor<MpsGraph>, m: BoolTensor<MpsGraph>, v: Scalar) -> BoolTensor<MpsGraph> {
bridge::run_binary_ctx(&t,&m, |g,pt,pm| unsafe {
let s = ffi::graph_constant_scalar(g, v.elem::<f64>(), burn_to_mps_dtype(t.dtype));
ffi::graph_select(g,pm,s,pt)
})
}
fn bool_gather(dim: usize, t: BoolTensor<MpsGraph>, idx: IntTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
bridge::run_binary_ctx(&t, &idx, |g,a,b| unsafe { ffi::graph_gather(g,a,b,dim,0) })
}
fn bool_scatter_or(dim: usize, t: BoolTensor<MpsGraph>, idx: IntTensor<MpsGraph>, v: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
let ti = Self::bool_into_int(t); let vi = Self::bool_into_int(v);
let ri = MpsGraph::int_scatter_add(dim, ti, idx, vi);
bridge::run_unary_ctx(&ri, |g,ph| unsafe {
let z = ffi::graph_constant_scalar(g, 0.0, ffi::MPSDataType::INT32);
ffi::graph_binary(g, "notEqualWithPrimaryTensor:secondaryTensor:name:", ph, z)
})
}
fn bool_equal(a: BoolTensor<MpsGraph>, b: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
bridge::run_binary(&a,&b, |g,pa,pb| unsafe { ffi::graph_binary(g, "equalWithPrimaryTensor:secondaryTensor:name:", pa, pb) })
}
fn bool_equal_elem(t: BoolTensor<MpsGraph>, v: Scalar) -> BoolTensor<MpsGraph> {
bridge::run_unary_ctx(&t, |g,ph| unsafe {
let s = ffi::graph_constant_scalar(g, v.elem::<f64>(), burn_to_mps_dtype(t.dtype));
ffi::graph_binary(g, "equalWithPrimaryTensor:secondaryTensor:name:", ph, s)
})
}
fn bool_not(t: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_unary(g, "notWithTensor:name:", ph) })
}
fn bool_and(a: BoolTensor<MpsGraph>, b: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
bridge::run_binary(&a,&b, |g,pa,pb| unsafe { ffi::graph_binary(g, "logicalANDWithPrimaryTensor:secondaryTensor:name:", pa, pb) })
}
fn bool_or(a: BoolTensor<MpsGraph>, b: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
bridge::run_binary(&a,&b, |g,pa,pb| unsafe { ffi::graph_binary(g, "logicalORWithPrimaryTensor:secondaryTensor:name:", pa, pb) })
}
fn bool_swap_dims(t: BoolTensor<MpsGraph>, d1: usize, d2: usize) -> BoolTensor<MpsGraph> {
bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_transpose(g,ph,d1,d2) })
}
fn bool_permute(t: BoolTensor<MpsGraph>, axes: &[usize]) -> BoolTensor<MpsGraph> {
let p = unsafe { ffi::ns_usize_array(axes) };
bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_permute(g,ph,p) })
}
fn bool_flip(t: BoolTensor<MpsGraph>, axes: &[usize]) -> BoolTensor<MpsGraph> {
let nd = t.shape.num_dims(); let shape = &t.shape;
let starts: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { shape[d] as isize-1 } else { 0 }).collect();
let ends: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { -(shape[d] as isize)-1 } else { shape[d] as isize }).collect();
let strides: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { -1 } else { 1 }).collect();
bridge::run_unary_ctx(&t, |g,ph| unsafe {
ffi::graph_slice_masked(g,ph, ffi::ns_isize_array(&starts), ffi::ns_isize_array(&ends), ffi::ns_isize_array(&strides), 0,0,0)
})
}
fn bool_expand(t: BoolTensor<MpsGraph>, shape: Shape) -> BoolTensor<MpsGraph> {
let ns = bridge::shape_to_ns(&shape);
bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_broadcast(g,ph,ns) })
}
fn bool_cat(tensors: Vec<BoolTensor<MpsGraph>>, dim: usize) -> BoolTensor<MpsGraph> {
if tensors.len()==1 { return tensors.into_iter().next().unwrap(); }
let refs: Vec<&MpsGraphTensor> = tensors.iter().collect();
bridge::run_multi_ctx(&refs, tensors[0].device, |g,phs| unsafe { ffi::graph_concat(g, ffi::ns_array(phs), dim as isize) })
}
fn bool_unfold(t: BoolTensor<MpsGraph>, dim: usize, size: usize, step: usize) -> BoolTensor<MpsGraph> {
let ti = Self::bool_into_int(t);
let ui = MpsGraph::int_unfold(ti, dim, size, step);
bridge::run_unary_ctx(&ui, |g,ph| unsafe {
let z = ffi::graph_constant_scalar(g, 0.0, ffi::MPSDataType::INT32);
ffi::graph_binary(g, "notEqualWithPrimaryTensor:secondaryTensor:name:", ph, z)
})
}
}