1use burn_backend::ops::{FloatTensorOps, IntTensorOps};
2use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntTensor};
3use burn_backend::{DType, Distribution, ExecutionError, Scalar, TensorData};
4use burn_std::{IntDType, Shape, Slice};
5use std::future::Future;
6
7use crate::bridge::{self, burn_to_mps_dtype};
8use crate::ffi::{self};
9use crate::{MpsGraph, MpsGraphTensor};
10
11use super::float_tensor::{ADD, SUB, MUL, DIV, MOD, EQ, GT, GTE, LT, LTE};
12
// Generates an elementwise tensor-tensor integer op: both operands are fed
// to the MPSGraph binary kernel identified by the selector string `$s`.
macro_rules! int_binary {
    ($name:ident, $s:expr) => {
        fn $name(lhs: IntTensor<MpsGraph>, rhs: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
            bridge::run_binary(&lhs, &rhs, |graph, l, r| unsafe {
                ffi::graph_binary(graph, $s, l, r)
            })
        }
    };
}
// Generates a tensor-scalar integer op: the scalar is materialized as a graph
// constant matching the tensor's dtype, then the binary kernel `$s` is applied.
macro_rules! int_scalar {
    ($name:ident, $s:expr) => {
        fn $name(lhs: IntTensor<MpsGraph>, rhs: Scalar) -> IntTensor<MpsGraph> {
            bridge::run_unary_ctx(&lhs, |graph, ph| unsafe {
                let constant =
                    ffi::graph_constant_scalar(graph, rhs.elem::<f64>(), burn_to_mps_dtype(lhs.dtype));
                ffi::graph_binary(graph, $s, ph, constant)
            })
        }
    };
}
// Generates an elementwise tensor-tensor comparison: same plumbing as
// `int_binary!`, but the generated function is typed as returning a BoolTensor.
macro_rules! int_cmp {
    ($name:ident, $s:expr) => {
        fn $name(lhs: IntTensor<MpsGraph>, rhs: IntTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
            bridge::run_binary(&lhs, &rhs, |graph, l, r| unsafe {
                ffi::graph_binary(graph, $s, l, r)
            })
        }
    };
}
// Generates a tensor-scalar comparison: the scalar becomes a graph constant of
// the tensor's dtype and the comparison kernel `$s` produces a BoolTensor.
macro_rules! int_cmp_s {
    ($name:ident, $s:expr) => {
        fn $name(lhs: IntTensor<MpsGraph>, rhs: Scalar) -> BoolTensor<MpsGraph> {
            bridge::run_unary_ctx(&lhs, |graph, ph| unsafe {
                let constant =
                    ffi::graph_constant_scalar(graph, rhs.elem::<f64>(), burn_to_mps_dtype(lhs.dtype));
                ffi::graph_binary(graph, $s, ph, constant)
            })
        }
    };
}
17
18impl IntTensorOps<MpsGraph> for MpsGraph {
19 fn int_empty(shape: Shape, device: &Device<MpsGraph>, dtype: IntDType) -> IntTensor<MpsGraph> {
20 bridge::tensor_zeros(shape, dtype.into(), *device)
21 }
22
23 fn int_into_data(t: IntTensor<MpsGraph>) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send {
24 async move { Ok(TensorData::from_bytes_vec(bridge::tensor_to_bytes(&t), t.shape.clone(), t.dtype)) }
25 }
26
27 fn int_from_data(data: TensorData, device: &Device<MpsGraph>) -> IntTensor<MpsGraph> {
28 bridge::tensor_from_bytes(data.as_bytes(), Shape::from(data.shape.clone()), data.dtype, *device)
29 }
30
31 fn int_device(t: &IntTensor<MpsGraph>) -> Device<MpsGraph> { t.device }
32 fn int_to_device(t: IntTensor<MpsGraph>, d: &Device<MpsGraph>) -> IntTensor<MpsGraph> { { let buf = unsafe { crate::ffi::retain(t.buffer) }; MpsGraphTensor { buffer: buf, shape: t.shape.clone(), dtype: t.dtype, device: *d } } }
33
34 fn int_reshape(t: IntTensor<MpsGraph>, shape: Shape) -> IntTensor<MpsGraph> {
35 let ns = bridge::shape_to_ns(&shape);
36 bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reshape(g, ph, ns) })
37 }
38
39 fn int_slice(t: IntTensor<MpsGraph>, slices: &[Slice]) -> IntTensor<MpsGraph> {
40 let (sa,ea,st) = bridge::slices_to_ns(slices, &t.shape);
41 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_slice(g,ph,sa,ea,st) })
42 }
43
44 fn int_slice_assign(t: IntTensor<MpsGraph>, slices: &[Slice], v: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
45 let (sa,ea,st) = bridge::slices_to_ns(slices, &t.shape);
46 bridge::run_binary_ctx(&t, &v, |g,pd,pu| unsafe { ffi::graph_slice_update(g,pd,pu,sa,ea,st) })
47 }
48
49 fn int_into_float(t: IntTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
50 let mut r = bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cast(g,ph,ffi::MPSDataType::FLOAT32) });
51 r.dtype = DType::F32; r
52 }
53
54 fn int_mask_where(t: IntTensor<MpsGraph>, m: BoolTensor<MpsGraph>, v: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
55 bridge::run_ternary(&t,&m,&v, |g,pt,pm,pv| unsafe { ffi::graph_select(g,pm,pv,pt) })
56 }
57
58 fn int_mask_fill(t: IntTensor<MpsGraph>, m: BoolTensor<MpsGraph>, v: Scalar) -> IntTensor<MpsGraph> {
59 bridge::run_binary_ctx(&t,&m, |g,pt,pm| unsafe {
60 let s = ffi::graph_constant_scalar(g, v.elem::<f64>(), burn_to_mps_dtype(t.dtype));
61 ffi::graph_select(g, pm, s, pt)
62 })
63 }
64
65 fn int_gather(dim: usize, t: IntTensor<MpsGraph>, idx: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
66 bridge::run_binary_ctx(&t, &idx, |g,a,b| unsafe { ffi::graph_gather(g,a,b,dim,0) })
67 }
68
69 fn int_scatter_add(dim: usize, t: IntTensor<MpsGraph>, idx: IntTensor<MpsGraph>, v: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
70 bridge::run_multi_ctx(&[&t,&idx,&v], t.device, |g,phs| unsafe {
71 ffi::graph_scatter_along(g, dim as isize, phs[0], phs[2], phs[1], ffi::MPSGraphScatterMode::ADD)
72 })
73 }
74
75 fn int_select(t: IntTensor<MpsGraph>, dim: usize, idx: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> { Self::int_gather(dim,t,idx) }
76 fn int_select_add(t: IntTensor<MpsGraph>, dim: usize, idx: IntTensor<MpsGraph>, v: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> { Self::int_scatter_add(dim,t,idx,v) }
77
78 int_cmp!(int_equal, EQ); int_cmp_s!(int_equal_elem, EQ);
79 int_cmp!(int_greater, GT); int_cmp_s!(int_greater_elem, GT);
80 int_cmp!(int_greater_equal, GTE); int_cmp_s!(int_greater_equal_elem, GTE);
81 int_cmp!(int_lower, LT); int_cmp_s!(int_lower_elem, LT);
82 int_cmp!(int_lower_equal, LTE); int_cmp_s!(int_lower_equal_elem, LTE);
83
84 int_binary!(int_add, ADD); int_scalar!(int_add_scalar, ADD);
85 int_binary!(int_sub, SUB); int_scalar!(int_sub_scalar, SUB);
86 int_binary!(int_mul, MUL); int_scalar!(int_mul_scalar, MUL);
87 int_binary!(int_div, DIV); int_scalar!(int_div_scalar, DIV);
88 int_binary!(int_remainder, MOD); int_scalar!(int_remainder_scalar, MOD);
89
90 fn int_matmul(a: IntTensor<MpsGraph>, b: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
91 let af = Self::int_into_float(a); let bf = Self::int_into_float(b);
92 MpsGraph::float_into_int(MpsGraph::float_matmul(af, bf))
93 }
94
95 fn int_neg(t: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
96 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_unary(g, "negativeWithTensor:name:", ph) })
97 }
98
99 fn int_abs(t: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
100 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_unary(g, "absoluteWithTensor:name:", ph) })
101 }
102
103 fn int_swap_dims(t: IntTensor<MpsGraph>, d1: usize, d2: usize) -> IntTensor<MpsGraph> {
104 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_transpose(g,ph,d1,d2) })
105 }
106
107 fn int_permute(t: IntTensor<MpsGraph>, axes: &[usize]) -> IntTensor<MpsGraph> {
108 let p = unsafe { ffi::ns_usize_array(axes) };
109 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_permute(g,ph,p) })
110 }
111
112 fn int_flip(t: IntTensor<MpsGraph>, axes: &[usize]) -> IntTensor<MpsGraph> {
113 let nd = t.shape.num_dims(); let shape = &t.shape;
114 let starts: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { shape[d] as isize-1 } else { 0 }).collect();
115 let ends: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { -(shape[d] as isize)-1 } else { shape[d] as isize }).collect();
116 let strides: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { -1 } else { 1 }).collect();
117 bridge::run_unary_ctx(&t, |g,ph| unsafe {
118 ffi::graph_slice_masked(g,ph, ffi::ns_isize_array(&starts), ffi::ns_isize_array(&ends), ffi::ns_isize_array(&strides), 0,0,0)
119 })
120 }
121
122 fn int_expand(t: IntTensor<MpsGraph>, shape: Shape) -> IntTensor<MpsGraph> {
123 let ns = bridge::shape_to_ns(&shape);
124 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_broadcast(g,ph,ns) })
125 }
126
127 fn int_sum(t: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
128 let axes: Vec<isize> = (0..t.shape.num_dims() as isize).collect();
129 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_reduction_sum(g,ph, ffi::ns_isize_array(&axes)) })
130 }
131
132 fn int_sum_dim(t: IntTensor<MpsGraph>, dim: usize) -> IntTensor<MpsGraph> {
133 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_reduction_sum_axis(g,ph,dim as isize) })
134 }
135
136 fn int_prod(t: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
137 let axes: Vec<isize> = (0..t.shape.num_dims() as isize).collect();
138 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_reduction_prod(g,ph, ffi::ns_isize_array(&axes)) })
139 }
140
141 fn int_prod_dim(t: IntTensor<MpsGraph>, dim: usize) -> IntTensor<MpsGraph> {
142 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_reduction_prod_axis(g,ph,dim as isize) })
143 }
144
145 fn int_mean_dim(t: IntTensor<MpsGraph>, dim: usize) -> IntTensor<MpsGraph> {
146 let n = t.shape[dim]; let sum = Self::int_sum_dim(t, dim);
147 Self::int_div_scalar(sum, Scalar::from(n as i64))
148 }
149
150 fn int_cumsum(t: IntTensor<MpsGraph>, d: usize) -> IntTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cumsum(g,ph,d as isize) }) }
151 fn int_cumprod(t: IntTensor<MpsGraph>, d: usize) -> IntTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cumprod(g,ph,d as isize) }) }
152 fn int_cummin(t: IntTensor<MpsGraph>, d: usize) -> IntTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cummin(g,ph,d as isize) }) }
153 fn int_cummax(t: IntTensor<MpsGraph>, d: usize) -> IntTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cummax(g,ph,d as isize) }) }
154
155 fn int_argmax(t: IntTensor<MpsGraph>, dim: usize) -> IntTensor<MpsGraph> {
156 let f = Self::int_into_float(t); MpsGraph::float_argmax(f, dim)
157 }
158 fn int_argmin(t: IntTensor<MpsGraph>, dim: usize) -> IntTensor<MpsGraph> {
159 let f = Self::int_into_float(t); MpsGraph::float_argmin(f, dim)
160 }
161
162 fn int_random(shape: Shape, dist: Distribution, device: &Device<MpsGraph>) -> IntTensor<MpsGraph> {
163 let n = shape.num_elements(); let mut buf = vec![0i32; n];
164 let mut rng = crate::ops::get_seeded_rng(); use rand::Rng;
165 match dist {
166 Distribution::Uniform(lo,hi) => buf.iter_mut().for_each(|v| *v = rng.gen_range(lo as i32..hi as i32)),
167 _ => buf.iter_mut().for_each(|v| *v = rng.gen_range(-1000..1000)),
168 }
169 let bytes = unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const u8, n*4) };
170 bridge::tensor_from_bytes(bytes, shape, DType::I32, *device)
171 }
172
173 fn int_sort(t: IntTensor<MpsGraph>, dim: usize, desc: bool) -> IntTensor<MpsGraph> {
174 bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_sort(g,ph,dim as isize,desc) })
175 }
176
177 fn int_argsort(t: IntTensor<MpsGraph>, dim: usize, desc: bool) -> IntTensor<MpsGraph> {
178 let mut r = bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_argsort(g,ph,dim as isize,desc) });
179 r.dtype = DType::I32; r
180 }
181
182 int_binary!(bitwise_and, "bitwiseANDWithPrimaryTensor:secondaryTensor:name:");
184 int_scalar!(bitwise_and_scalar, "bitwiseANDWithPrimaryTensor:secondaryTensor:name:");
185 int_binary!(bitwise_or, "bitwiseORWithPrimaryTensor:secondaryTensor:name:");
186 int_scalar!(bitwise_or_scalar, "bitwiseORWithPrimaryTensor:secondaryTensor:name:");
187 int_binary!(bitwise_xor, "bitwiseXORWithPrimaryTensor:secondaryTensor:name:");
188 int_scalar!(bitwise_xor_scalar, "bitwiseXORWithPrimaryTensor:secondaryTensor:name:");
189 fn bitwise_not(t: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
190 bridge::run_unary_ctx(&t, |g,ph| unsafe {
191 let ones = ffi::graph_constant_scalar(g, -1.0, burn_to_mps_dtype(t.dtype));
192 ffi::graph_binary(g, "bitwiseXORWithPrimaryTensor:secondaryTensor:name:", ph, ones)
193 })
194 }
195 int_binary!(bitwise_left_shift, "bitwiseLeftShiftWithPrimaryTensor:secondaryTensor:name:");
196 int_scalar!(bitwise_left_shift_scalar, "bitwiseLeftShiftWithPrimaryTensor:secondaryTensor:name:");
197 int_binary!(bitwise_right_shift, "bitwiseRightShiftWithPrimaryTensor:secondaryTensor:name:");
198 int_scalar!(bitwise_right_shift_scalar, "bitwiseRightShiftWithPrimaryTensor:secondaryTensor:name:");
199
200 fn int_unfold(t: IntTensor<MpsGraph>, dim: usize, size: usize, step: usize) -> IntTensor<MpsGraph> {
201 let n = t.shape[dim]; let nw = (n.saturating_sub(size))/step+1;
202 let mut wins = Vec::with_capacity(nw);
203 for i in 0..nw {
204 let st = i*step;
205 let slices: Vec<Slice> = (0..t.shape.num_dims()).map(|d| {
206 if d==dim { Slice::new(st as isize, Some((st+size) as isize), 1) }
207 else { Slice::new(0, Some(t.shape[d] as isize), 1) }
208 }).collect();
209 let w = Self::int_slice(t.clone(), &slices);
210 let mut dims: Vec<usize> = (0..w.shape.num_dims()).map(|d| w.shape[d]).collect();
211 dims[dim]=1; dims.push(size);
212 wins.push(Self::int_reshape(w, Shape::from(dims)));
213 }
214 Self::int_cat(wins, dim)
215 }
216
217 fn int_cat(tensors: Vec<IntTensor<MpsGraph>>, dim: usize) -> IntTensor<MpsGraph> {
218 if tensors.len()==1 { return tensors.into_iter().next().unwrap(); }
219 let refs: Vec<&MpsGraphTensor> = tensors.iter().collect();
220 bridge::run_multi_ctx(&refs, tensors[0].device, |g,phs| unsafe { ffi::graph_concat(g, ffi::ns_array(phs), dim as isize) })
221 }
222
223 fn int_cast(t: IntTensor<MpsGraph>, dtype: IntDType) -> IntTensor<MpsGraph> {
224 let dt: DType = dtype.into(); if t.dtype == dt { return t; }
225 let mps = burn_to_mps_dtype(dt);
226 let mut r = bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cast(g,ph,mps) });
227 r.dtype = dt; r
228 }
229}