Skip to main content

burn_mpsgraph/ops/
bool_tensor.rs

1use burn_backend::ops::{BoolTensorOps, IntTensorOps};
2use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntTensor};
3use burn_backend::{DType, ExecutionError, Scalar, TensorData};
4use burn_std::{Shape, Slice};
5use std::future::Future;
6
7use crate::bridge::{self, burn_to_mps_dtype};
8use crate::ffi::{self};
9use crate::{MpsGraph, MpsGraphTensor};
10
11impl BoolTensorOps<MpsGraph> for MpsGraph {
12    fn bool_empty(shape: Shape, device: &Device<MpsGraph>) -> BoolTensor<MpsGraph> { bridge::tensor_zeros(shape, DType::Bool, *device) }
13    fn bool_zeros(shape: Shape, device: &Device<MpsGraph>) -> BoolTensor<MpsGraph> { bridge::tensor_zeros(shape, DType::Bool, *device) }
14    fn bool_ones(shape: Shape, device: &Device<MpsGraph>) -> BoolTensor<MpsGraph> {
15        let n = shape.num_elements();
16        bridge::tensor_from_bytes(&vec![1u8; n], shape, DType::Bool, *device)
17    }
18
19    fn bool_into_data(t: BoolTensor<MpsGraph>) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send {
20        async move { Ok(TensorData::from_bytes_vec(bridge::tensor_to_bytes(&t), t.shape.clone(), t.dtype)) }
21    }
22
23    fn bool_from_data(data: TensorData, device: &Device<MpsGraph>) -> BoolTensor<MpsGraph> {
24        bridge::tensor_from_bytes(data.as_bytes(), Shape::from(data.shape.clone()), data.dtype, *device)
25    }
26
27    fn bool_into_int(t: BoolTensor<MpsGraph>) -> IntTensor<MpsGraph> {
28        let mut r = bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cast(g,ph,ffi::MPSDataType::INT32) });
29        r.dtype = DType::I32; r
30    }
31
32    fn bool_into_float(t: BoolTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
33        let mut r = bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cast(g,ph,ffi::MPSDataType::FLOAT32) });
34        r.dtype = DType::F32; r
35    }
36
37    fn bool_device(t: &BoolTensor<MpsGraph>) -> Device<MpsGraph> { t.device }
38    fn bool_to_device(t: BoolTensor<MpsGraph>, d: &Device<MpsGraph>) -> BoolTensor<MpsGraph> { { let buf = unsafe { crate::ffi::retain(t.buffer) }; MpsGraphTensor { buffer: buf, shape: t.shape.clone(), dtype: t.dtype, device: *d } } }
39
40    fn bool_reshape(t: BoolTensor<MpsGraph>, shape: Shape) -> BoolTensor<MpsGraph> {
41        let ns = bridge::shape_to_ns(&shape);
42        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_reshape(g,ph,ns) })
43    }
44
45    fn bool_slice(t: BoolTensor<MpsGraph>, slices: &[Slice]) -> BoolTensor<MpsGraph> {
46        let (sa,ea,st) = bridge::slices_to_ns(slices, &t.shape);
47        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_slice(g,ph,sa,ea,st) })
48    }
49
50    fn bool_slice_assign(t: BoolTensor<MpsGraph>, slices: &[Slice], v: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
51        let (sa,ea,st) = bridge::slices_to_ns(slices, &t.shape);
52        bridge::run_binary_ctx(&t,&v, |g,pd,pu| unsafe { ffi::graph_slice_update(g,pd,pu,sa,ea,st) })
53    }
54
55    fn bool_mask_where(t: BoolTensor<MpsGraph>, m: BoolTensor<MpsGraph>, v: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
56        bridge::run_ternary(&t,&m,&v, |g,pt,pm,pv| unsafe { ffi::graph_select(g,pm,pv,pt) })
57    }
58
59    fn bool_mask_fill(t: BoolTensor<MpsGraph>, m: BoolTensor<MpsGraph>, v: Scalar) -> BoolTensor<MpsGraph> {
60        bridge::run_binary_ctx(&t,&m, |g,pt,pm| unsafe {
61            let s = ffi::graph_constant_scalar(g, v.elem::<f64>(), burn_to_mps_dtype(t.dtype));
62            ffi::graph_select(g,pm,s,pt)
63        })
64    }
65
66    fn bool_gather(dim: usize, t: BoolTensor<MpsGraph>, idx: IntTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
67        bridge::run_binary_ctx(&t, &idx, |g,a,b| unsafe { ffi::graph_gather(g,a,b,dim,0) })
68    }
69
70    fn bool_scatter_or(dim: usize, t: BoolTensor<MpsGraph>, idx: IntTensor<MpsGraph>, v: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
71        let ti = Self::bool_into_int(t); let vi = Self::bool_into_int(v);
72        let ri = MpsGraph::int_scatter_add(dim, ti, idx, vi);
73        // Non-zero -> true
74        bridge::run_unary_ctx(&ri, |g,ph| unsafe {
75            let z = ffi::graph_constant_scalar(g, 0.0, ffi::MPSDataType::INT32);
76            ffi::graph_binary(g, "notEqualWithPrimaryTensor:secondaryTensor:name:", ph, z)
77        })
78    }
79
80    fn bool_equal(a: BoolTensor<MpsGraph>, b: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
81        bridge::run_binary(&a,&b, |g,pa,pb| unsafe { ffi::graph_binary(g, "equalWithPrimaryTensor:secondaryTensor:name:", pa, pb) })
82    }
83
84    fn bool_equal_elem(t: BoolTensor<MpsGraph>, v: Scalar) -> BoolTensor<MpsGraph> {
85        bridge::run_unary_ctx(&t, |g,ph| unsafe {
86            let s = ffi::graph_constant_scalar(g, v.elem::<f64>(), burn_to_mps_dtype(t.dtype));
87            ffi::graph_binary(g, "equalWithPrimaryTensor:secondaryTensor:name:", ph, s)
88        })
89    }
90
91    fn bool_not(t: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
92        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_unary(g, "notWithTensor:name:", ph) })
93    }
94
95    fn bool_and(a: BoolTensor<MpsGraph>, b: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
96        bridge::run_binary(&a,&b, |g,pa,pb| unsafe { ffi::graph_binary(g, "logicalANDWithPrimaryTensor:secondaryTensor:name:", pa, pb) })
97    }
98
99    fn bool_or(a: BoolTensor<MpsGraph>, b: BoolTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
100        bridge::run_binary(&a,&b, |g,pa,pb| unsafe { ffi::graph_binary(g, "logicalORWithPrimaryTensor:secondaryTensor:name:", pa, pb) })
101    }
102
103    fn bool_swap_dims(t: BoolTensor<MpsGraph>, d1: usize, d2: usize) -> BoolTensor<MpsGraph> {
104        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_transpose(g,ph,d1,d2) })
105    }
106
107    fn bool_permute(t: BoolTensor<MpsGraph>, axes: &[usize]) -> BoolTensor<MpsGraph> {
108        let p = unsafe { ffi::ns_usize_array(axes) };
109        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_permute(g,ph,p) })
110    }
111
112    fn bool_flip(t: BoolTensor<MpsGraph>, axes: &[usize]) -> BoolTensor<MpsGraph> {
113        let nd = t.shape.num_dims(); let shape = &t.shape;
114        let starts: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { shape[d] as isize-1 } else { 0 }).collect();
115        let ends:   Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { -(shape[d] as isize)-1 } else { shape[d] as isize }).collect();
116        let strides: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { -1 } else { 1 }).collect();
117        bridge::run_unary_ctx(&t, |g,ph| unsafe {
118            ffi::graph_slice_masked(g,ph, ffi::ns_isize_array(&starts), ffi::ns_isize_array(&ends), ffi::ns_isize_array(&strides), 0,0,0)
119        })
120    }
121
122    fn bool_expand(t: BoolTensor<MpsGraph>, shape: Shape) -> BoolTensor<MpsGraph> {
123        let ns = bridge::shape_to_ns(&shape);
124        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_broadcast(g,ph,ns) })
125    }
126
127    fn bool_cat(tensors: Vec<BoolTensor<MpsGraph>>, dim: usize) -> BoolTensor<MpsGraph> {
128        if tensors.len()==1 { return tensors.into_iter().next().unwrap(); }
129        let refs: Vec<&MpsGraphTensor> = tensors.iter().collect();
130        bridge::run_multi_ctx(&refs, tensors[0].device, |g,phs| unsafe { ffi::graph_concat(g, ffi::ns_array(phs), dim as isize) })
131    }
132
133    fn bool_unfold(t: BoolTensor<MpsGraph>, dim: usize, size: usize, step: usize) -> BoolTensor<MpsGraph> {
134        let ti = Self::bool_into_int(t);
135        let ui = MpsGraph::int_unfold(ti, dim, size, step);
136        bridge::run_unary_ctx(&ui, |g,ph| unsafe {
137            let z = ffi::graph_constant_scalar(g, 0.0, ffi::MPSDataType::INT32);
138            ffi::graph_binary(g, "notEqualWithPrimaryTensor:secondaryTensor:name:", ph, z)
139        })
140    }
141}