1use burn_backend::ops::{BoolTensorOps, IntTensorOps};
2use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntTensor};
3use burn_backend::{DType, ExecutionError, Scalar, TensorData};
4use burn_std::{Shape, Slice};
5use std::future::Future;
6
7use crate::bridge::{self, burn_to_mps_dtype};
8use crate::ffi::{self};
9use crate::{MpsGraph, MpsGraphTensor};
10
11impl BoolTensorOps<MpsGraph> for MpsGraph {
12 fn bool_empty(shape: Shape, device: &Device<MpsGraph>) -> BoolTensor<MpsGraph> { bridge::tensor_zeros(shape, DType::Bool, *device) }
13 fn bool_zeros(shape: Shape, device: &Device<MpsGraph>) -> BoolTensor<MpsGraph> { bridge::tensor_zeros(shape, DType::Bool, *device) }
14 fn bool_ones(shape: Shape, device: &Device<MpsGraph>) -> BoolTensor<MpsGraph> {
15 let n = shape.num_elements();
16 bridge::tensor_from_bytes(&vec![1u8; n], shape, DType::Bool, *device)
17 }
18
19 fn bool_into_data(t: BoolTensor<MpsGraph>) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send {
20 async move { Ok(TensorData::from_bytes_vec(bridge::tensor_to_bytes(&t), t.shape.clone(), t.dtype)) }
21 }
22
23 fn bool_from_data(data: TensorData, device: &Device<MpsGraph>) -> BoolTensor<MpsGraph> {
24 bridge::tensor_from_bytes(data.as_bytes(), Shape::from(data.shape.clone()), data.dtype, *device)
25 }
26
27 fn bool_into_int(t: BoolTensor<MpsGraph>) -> IntTensor<MpsGraph> {
28 let mut r = bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cast(g,ph,ffi::MPSDataType::INT32) });
29 r.dtype = DType::I32; r
30 }
31
32 fn bool_into_float(t: BoolTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
33 let mut r = bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cast(g,ph,ffi::MPSDataType::FLOAT32) });
34 r.dtype = DType::F32; r
35 }
36
37 fn bool_device(t: &BoolTensor<MpsGraph>) -> Device<MpsGraph> { t.device }
38 fn bool_to_device(t: BoolTensor<MpsGraph>, d: &Device<MpsGraph>) -> BoolTensor<MpsGraph> { { let buf = unsafe { crate::ffi::retain(t.buffer) }; MpsGraphTensor { buffer: buf, shape: t.shape.clone(), dtype: t.dtype, device: *d } } }
39
40 fn bool_reshape(t: BoolTensor<MpsGraph>, shape: Shape) -> BoolTensor<MpsGraph> {
41 let ns = bridge::shape_to_ns(&shape);
42 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_reshape(g,ph,ns) })
43 }
44
45 fn bool_slice(t: BoolTensor<MpsGraph>, slices: &[Slice]) -> BoolTensor<MpsGraph> {
46 let (sa,ea,st) = bridge::slices_to_ns(slices, &t.shape);
47 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_slice(g,ph,sa,ea,st) })
48 }
49
50 fn bool_slice_assign(t: BoolTensor<MpsGraph>, slices: &[Slice], v: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
51 let (sa,ea,st) = bridge::slices_to_ns(slices, &t.shape);
52 bridge::run_binary_ctx(&t,&v, |g,pd,pu| unsafe { ffi::graph_slice_update(g,pd,pu,sa,ea,st) })
53 }
54
55 fn bool_mask_where(t: BoolTensor<MpsGraph>, m: BoolTensor<MpsGraph>, v: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
56 bridge::run_ternary(&t,&m,&v, |g,pt,pm,pv| unsafe { ffi::graph_select(g,pm,pv,pt) })
57 }
58
59 fn bool_mask_fill(t: BoolTensor<MpsGraph>, m: BoolTensor<MpsGraph>, v: Scalar) -> BoolTensor<MpsGraph> {
60 bridge::run_binary_ctx(&t,&m, |g,pt,pm| unsafe {
61 let s = ffi::graph_constant_scalar(g, v.elem::<f64>(), burn_to_mps_dtype(t.dtype));
62 ffi::graph_select(g,pm,s,pt)
63 })
64 }
65
66 fn bool_gather(dim: usize, t: BoolTensor<MpsGraph>, idx: IntTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
67 bridge::run_binary_ctx(&t, &idx, |g,a,b| unsafe { ffi::graph_gather(g,a,b,dim,0) })
68 }
69
70 fn bool_scatter_or(dim: usize, t: BoolTensor<MpsGraph>, idx: IntTensor<MpsGraph>, v: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
71 let ti = Self::bool_into_int(t); let vi = Self::bool_into_int(v);
72 let ri = MpsGraph::int_scatter_add(dim, ti, idx, vi);
73 bridge::run_unary_ctx(&ri, |g,ph| unsafe {
75 let z = ffi::graph_constant_scalar(g, 0.0, ffi::MPSDataType::INT32);
76 ffi::graph_binary(g, "notEqualWithPrimaryTensor:secondaryTensor:name:", ph, z)
77 })
78 }
79
80 fn bool_equal(a: BoolTensor<MpsGraph>, b: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
81 bridge::run_binary(&a,&b, |g,pa,pb| unsafe { ffi::graph_binary(g, "equalWithPrimaryTensor:secondaryTensor:name:", pa, pb) })
82 }
83
84 fn bool_equal_elem(t: BoolTensor<MpsGraph>, v: Scalar) -> BoolTensor<MpsGraph> {
85 bridge::run_unary_ctx(&t, |g,ph| unsafe {
86 let s = ffi::graph_constant_scalar(g, v.elem::<f64>(), burn_to_mps_dtype(t.dtype));
87 ffi::graph_binary(g, "equalWithPrimaryTensor:secondaryTensor:name:", ph, s)
88 })
89 }
90
91 fn bool_not(t: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
92 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_unary(g, "notWithTensor:name:", ph) })
93 }
94
95 fn bool_and(a: BoolTensor<MpsGraph>, b: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
96 bridge::run_binary(&a,&b, |g,pa,pb| unsafe { ffi::graph_binary(g, "logicalANDWithPrimaryTensor:secondaryTensor:name:", pa, pb) })
97 }
98
99 fn bool_or(a: BoolTensor<MpsGraph>, b: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
100 bridge::run_binary(&a,&b, |g,pa,pb| unsafe { ffi::graph_binary(g, "logicalORWithPrimaryTensor:secondaryTensor:name:", pa, pb) })
101 }
102
103 fn bool_swap_dims(t: BoolTensor<MpsGraph>, d1: usize, d2: usize) -> BoolTensor<MpsGraph> {
104 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_transpose(g,ph,d1,d2) })
105 }
106
107 fn bool_permute(t: BoolTensor<MpsGraph>, axes: &[usize]) -> BoolTensor<MpsGraph> {
108 let p = unsafe { ffi::ns_usize_array(axes) };
109 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_permute(g,ph,p) })
110 }
111
112 fn bool_flip(t: BoolTensor<MpsGraph>, axes: &[usize]) -> BoolTensor<MpsGraph> {
113 let nd = t.shape.num_dims(); let shape = &t.shape;
114 let starts: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { shape[d] as isize-1 } else { 0 }).collect();
115 let ends: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { -(shape[d] as isize)-1 } else { shape[d] as isize }).collect();
116 let strides: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { -1 } else { 1 }).collect();
117 bridge::run_unary_ctx(&t, |g,ph| unsafe {
118 ffi::graph_slice_masked(g,ph, ffi::ns_isize_array(&starts), ffi::ns_isize_array(&ends), ffi::ns_isize_array(&strides), 0,0,0)
119 })
120 }
121
122 fn bool_expand(t: BoolTensor<MpsGraph>, shape: Shape) -> BoolTensor<MpsGraph> {
123 let ns = bridge::shape_to_ns(&shape);
124 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_broadcast(g,ph,ns) })
125 }
126
127 fn bool_cat(tensors: Vec<BoolTensor<MpsGraph>>, dim: usize) -> BoolTensor<MpsGraph> {
128 if tensors.len()==1 { return tensors.into_iter().next().unwrap(); }
129 let refs: Vec<&MpsGraphTensor> = tensors.iter().collect();
130 bridge::run_multi_ctx(&refs, tensors[0].device, |g,phs| unsafe { ffi::graph_concat(g, ffi::ns_array(phs), dim as isize) })
131 }
132
133 fn bool_unfold(t: BoolTensor<MpsGraph>, dim: usize, size: usize, step: usize) -> BoolTensor<MpsGraph> {
134 let ti = Self::bool_into_int(t);
135 let ui = MpsGraph::int_unfold(ti, dim, size, step);
136 bridge::run_unary_ctx(&ui, |g,ph| unsafe {
137 let z = ffi::graph_constant_scalar(g, 0.0, ffi::MPSDataType::INT32);
138 ffi::graph_binary(g, "notEqualWithPrimaryTensor:secondaryTensor:name:", ph, z)
139 })
140 }
141}