Skip to main content

burn_mpsgraph/ops/
int_tensor.rs

1use burn_backend::ops::{FloatTensorOps, IntTensorOps};
2use burn_backend::tensor::{BoolTensor, Device, FloatTensor, IntTensor};
3use burn_backend::{DType, Distribution, ExecutionError, Scalar, TensorData};
4use burn_std::{IntDType, Shape, Slice};
5use std::future::Future;
6
7use crate::bridge::{self, burn_to_mps_dtype};
8use crate::ffi::{self};
9use crate::{MpsGraph, MpsGraphTensor};
10
11use super::float_tensor::{ADD, SUB, MUL, DIV, MOD, EQ, GT, GTE, LT, LTE};
12
/// Expands to an element-wise binary int op `fn $op(a, b) -> IntTensor` that applies the
/// MPSGraph selector `$selector` to both tensors via `bridge::run_binary`.
macro_rules! int_binary {
    ($op:ident, $selector:expr) => {
        fn $op(a: IntTensor<MpsGraph>, b: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
            bridge::run_binary(&a, &b, |graph, lhs, rhs| {
                // SAFETY (assumed): `graph`, `lhs`, `rhs` are live handles supplied by
                // `bridge::run_binary` for the duration of this callback — TODO confirm.
                unsafe { ffi::graph_binary(graph, $selector, lhs, rhs) }
            })
        }
    };
}
/// Expands to a tensor-scalar int op `fn $op(a, b: Scalar) -> IntTensor`. The scalar is
/// materialized as a graph constant with the tensor's dtype, then combined with the
/// tensor using the MPSGraph selector `$selector`.
macro_rules! int_scalar {
    ($op:ident, $selector:expr) => {
        fn $op(a: IntTensor<MpsGraph>, b: Scalar) -> IntTensor<MpsGraph> {
            let value = b.elem::<f64>();
            let mps_dtype = burn_to_mps_dtype(a.dtype);
            bridge::run_unary_ctx(&a, |graph, input| {
                // SAFETY (assumed): `graph` and `input` are live handles provided by
                // `bridge::run_unary_ctx` for this callback — TODO confirm.
                unsafe {
                    let rhs = ffi::graph_constant_scalar(graph, value, mps_dtype);
                    ffi::graph_binary(graph, $selector, input, rhs)
                }
            })
        }
    };
}
/// Expands to an element-wise int comparison `fn $op(a, b) -> BoolTensor` that applies
/// the MPSGraph selector `$selector` to both tensors via `bridge::run_binary`.
macro_rules! int_cmp {
    ($op:ident, $selector:expr) => {
        fn $op(a: IntTensor<MpsGraph>, b: IntTensor<MpsGraph>) -> BoolTensor<MpsGraph> {
            bridge::run_binary(&a, &b, |graph, lhs, rhs| {
                // SAFETY (assumed): handles are valid for the callback's duration — TODO confirm.
                unsafe { ffi::graph_binary(graph, $selector, lhs, rhs) }
            })
        }
    };
}
/// Expands to a tensor-scalar int comparison `fn $op(a, b: Scalar) -> BoolTensor`. The
/// scalar becomes a graph constant in the tensor's dtype before `$selector` is applied.
macro_rules! int_cmp_s {
    ($op:ident, $selector:expr) => {
        fn $op(a: IntTensor<MpsGraph>, b: Scalar) -> BoolTensor<MpsGraph> {
            let value = b.elem::<f64>();
            let mps_dtype = burn_to_mps_dtype(a.dtype);
            bridge::run_unary_ctx(&a, |graph, input| {
                // SAFETY (assumed): handles are valid for the callback's duration — TODO confirm.
                unsafe {
                    let rhs = ffi::graph_constant_scalar(graph, value, mps_dtype);
                    ffi::graph_binary(graph, $selector, input, rhs)
                }
            })
        }
    };
}
17
18impl IntTensorOps<MpsGraph> for MpsGraph {
19    fn int_empty(shape: Shape, device: &Device<MpsGraph>, dtype: IntDType) -> IntTensor<MpsGraph> {
20        bridge::tensor_zeros(shape, dtype.into(), *device)
21    }
22
23    fn int_into_data(t: IntTensor<MpsGraph>) -> impl Future<Output = Result<TensorData, ExecutionError>> + Send {
24        async move { Ok(TensorData::from_bytes_vec(bridge::tensor_to_bytes(&t), t.shape.clone(), t.dtype)) }
25    }
26
27    fn int_from_data(data: TensorData, device: &Device<MpsGraph>) -> IntTensor<MpsGraph> {
28        bridge::tensor_from_bytes(data.as_bytes(), Shape::from(data.shape.clone()), data.dtype, *device)
29    }
30
31    fn int_device(t: &IntTensor<MpsGraph>) -> Device<MpsGraph> { t.device }
32    fn int_to_device(t: IntTensor<MpsGraph>, d: &Device<MpsGraph>) -> IntTensor<MpsGraph> { { let buf = unsafe { crate::ffi::retain(t.buffer) }; MpsGraphTensor { buffer: buf, shape: t.shape.clone(), dtype: t.dtype, device: *d } } }
33
34    fn int_reshape(t: IntTensor<MpsGraph>, shape: Shape) -> IntTensor<MpsGraph> {
35        let ns = bridge::shape_to_ns(&shape);
36        bridge::run_unary_ctx(&t, |g, ph| unsafe { ffi::graph_reshape(g, ph, ns) })
37    }
38
39    fn int_slice(t: IntTensor<MpsGraph>, slices: &[Slice]) -> IntTensor<MpsGraph> {
40        let (sa,ea,st) = bridge::slices_to_ns(slices, &t.shape);
41        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_slice(g,ph,sa,ea,st) })
42    }
43
44    fn int_slice_assign(t: IntTensor<MpsGraph>, slices: &[Slice], v: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
45        let (sa,ea,st) = bridge::slices_to_ns(slices, &t.shape);
46        bridge::run_binary_ctx(&t, &v, |g,pd,pu| unsafe { ffi::graph_slice_update(g,pd,pu,sa,ea,st) })
47    }
48
49    fn int_into_float(t: IntTensor<MpsGraph>) -> FloatTensor<MpsGraph> {
50        let mut r = bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cast(g,ph,ffi::MPSDataType::FLOAT32) });
51        r.dtype = DType::F32; r
52    }
53
54    fn int_mask_where(t: IntTensor<MpsGraph>, m: BoolTensor<MpsGraph>, v: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
55        bridge::run_ternary(&t,&m,&v, |g,pt,pm,pv| unsafe { ffi::graph_select(g,pm,pv,pt) })
56    }
57
58    fn int_mask_fill(t: IntTensor<MpsGraph>, m: BoolTensor<MpsGraph>, v: Scalar) -> IntTensor<MpsGraph> {
59        bridge::run_binary_ctx(&t,&m, |g,pt,pm| unsafe {
60            let s = ffi::graph_constant_scalar(g, v.elem::<f64>(), burn_to_mps_dtype(t.dtype));
61            ffi::graph_select(g, pm, s, pt)
62        })
63    }
64
65    fn int_gather(dim: usize, t: IntTensor<MpsGraph>, idx: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
66        bridge::run_binary_ctx(&t, &idx, |g,a,b| unsafe { ffi::graph_gather(g,a,b,dim,0) })
67    }
68
69    fn int_scatter_add(dim: usize, t: IntTensor<MpsGraph>, idx: IntTensor<MpsGraph>, v: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
70        bridge::run_multi_ctx(&[&t,&idx,&v], t.device, |g,phs| unsafe {
71            ffi::graph_scatter_along(g, dim as isize, phs[0], phs[2], phs[1], ffi::MPSGraphScatterMode::ADD)
72        })
73    }
74
75    fn int_select(t: IntTensor<MpsGraph>, dim: usize, idx: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> { Self::int_gather(dim,t,idx) }
76    fn int_select_add(t: IntTensor<MpsGraph>, dim: usize, idx: IntTensor<MpsGraph>, v: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> { Self::int_scatter_add(dim,t,idx,v) }
77
78    int_cmp!(int_equal, EQ); int_cmp_s!(int_equal_elem, EQ);
79    int_cmp!(int_greater, GT); int_cmp_s!(int_greater_elem, GT);
80    int_cmp!(int_greater_equal, GTE); int_cmp_s!(int_greater_equal_elem, GTE);
81    int_cmp!(int_lower, LT); int_cmp_s!(int_lower_elem, LT);
82    int_cmp!(int_lower_equal, LTE); int_cmp_s!(int_lower_equal_elem, LTE);
83
84    int_binary!(int_add, ADD); int_scalar!(int_add_scalar, ADD);
85    int_binary!(int_sub, SUB); int_scalar!(int_sub_scalar, SUB);
86    int_binary!(int_mul, MUL); int_scalar!(int_mul_scalar, MUL);
87    int_binary!(int_div, DIV); int_scalar!(int_div_scalar, DIV);
88    int_binary!(int_remainder, MOD); int_scalar!(int_remainder_scalar, MOD);
89
90    fn int_matmul(a: IntTensor<MpsGraph>, b: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
91        let af = Self::int_into_float(a); let bf = Self::int_into_float(b);
92        MpsGraph::float_into_int(MpsGraph::float_matmul(af, bf))
93    }
94
95    fn int_neg(t: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
96        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_unary(g, "negativeWithTensor:name:", ph) })
97    }
98
99    fn int_abs(t: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
100        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_unary(g, "absoluteWithTensor:name:", ph) })
101    }
102
103    fn int_swap_dims(t: IntTensor<MpsGraph>, d1: usize, d2: usize) -> IntTensor<MpsGraph> {
104        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_transpose(g,ph,d1,d2) })
105    }
106
107    fn int_permute(t: IntTensor<MpsGraph>, axes: &[usize]) -> IntTensor<MpsGraph> {
108        let p = unsafe { ffi::ns_usize_array(axes) };
109        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_permute(g,ph,p) })
110    }
111
112    fn int_flip(t: IntTensor<MpsGraph>, axes: &[usize]) -> IntTensor<MpsGraph> {
113        let nd = t.shape.num_dims(); let shape = &t.shape;
114        let starts: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { shape[d] as isize-1 } else { 0 }).collect();
115        let ends:   Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { -(shape[d] as isize)-1 } else { shape[d] as isize }).collect();
116        let strides: Vec<isize> = (0..nd).map(|d| if axes.contains(&d) { -1 } else { 1 }).collect();
117        bridge::run_unary_ctx(&t, |g,ph| unsafe {
118            ffi::graph_slice_masked(g,ph, ffi::ns_isize_array(&starts), ffi::ns_isize_array(&ends), ffi::ns_isize_array(&strides), 0,0,0)
119        })
120    }
121
122    fn int_expand(t: IntTensor<MpsGraph>, shape: Shape) -> IntTensor<MpsGraph> {
123        let ns = bridge::shape_to_ns(&shape);
124        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_broadcast(g,ph,ns) })
125    }
126
127    fn int_sum(t: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
128        let axes: Vec<isize> = (0..t.shape.num_dims() as isize).collect();
129        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_reduction_sum(g,ph, ffi::ns_isize_array(&axes)) })
130    }
131
132    fn int_sum_dim(t: IntTensor<MpsGraph>, dim: usize) -> IntTensor<MpsGraph> {
133        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_reduction_sum_axis(g,ph,dim as isize) })
134    }
135
136    fn int_prod(t: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
137        let axes: Vec<isize> = (0..t.shape.num_dims() as isize).collect();
138        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_reduction_prod(g,ph, ffi::ns_isize_array(&axes)) })
139    }
140
141    fn int_prod_dim(t: IntTensor<MpsGraph>, dim: usize) -> IntTensor<MpsGraph> {
142        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_reduction_prod_axis(g,ph,dim as isize) })
143    }
144
145    fn int_mean_dim(t: IntTensor<MpsGraph>, dim: usize) -> IntTensor<MpsGraph> {
146        let n = t.shape[dim]; let sum = Self::int_sum_dim(t, dim);
147        Self::int_div_scalar(sum, Scalar::from(n as i64))
148    }
149
150    fn int_cumsum(t: IntTensor<MpsGraph>, d: usize) -> IntTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cumsum(g,ph,d as isize) }) }
151    fn int_cumprod(t: IntTensor<MpsGraph>, d: usize) -> IntTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cumprod(g,ph,d as isize) }) }
152    fn int_cummin(t: IntTensor<MpsGraph>, d: usize) -> IntTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cummin(g,ph,d as isize) }) }
153    fn int_cummax(t: IntTensor<MpsGraph>, d: usize) -> IntTensor<MpsGraph> { bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cummax(g,ph,d as isize) }) }
154
155    fn int_argmax(t: IntTensor<MpsGraph>, dim: usize) -> IntTensor<MpsGraph> {
156        let f = Self::int_into_float(t); MpsGraph::float_argmax(f, dim)
157    }
158    fn int_argmin(t: IntTensor<MpsGraph>, dim: usize) -> IntTensor<MpsGraph> {
159        let f = Self::int_into_float(t); MpsGraph::float_argmin(f, dim)
160    }
161
162    fn int_random(shape: Shape, dist: Distribution, device: &Device<MpsGraph>) -> IntTensor<MpsGraph> {
163        let n = shape.num_elements(); let mut buf = vec![0i32; n];
164        let mut rng = crate::ops::get_seeded_rng(); use rand::Rng;
165        match dist {
166            Distribution::Uniform(lo,hi) => buf.iter_mut().for_each(|v| *v = rng.gen_range(lo as i32..hi as i32)),
167            _ => buf.iter_mut().for_each(|v| *v = rng.gen_range(-1000..1000)),
168        }
169        let bytes = unsafe { std::slice::from_raw_parts(buf.as_ptr() as *const u8, n*4) };
170        bridge::tensor_from_bytes(bytes, shape, DType::I32, *device)
171    }
172
173    fn int_sort(t: IntTensor<MpsGraph>, dim: usize, desc: bool) -> IntTensor<MpsGraph> {
174        bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_sort(g,ph,dim as isize,desc) })
175    }
176
177    fn int_argsort(t: IntTensor<MpsGraph>, dim: usize, desc: bool) -> IntTensor<MpsGraph> {
178        let mut r = bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_argsort(g,ph,dim as isize,desc) });
179        r.dtype = DType::I32; r
180    }
181
182    // ── Bitwise ─────────────────────────────────────────────────────────
183    int_binary!(bitwise_and, "bitwiseANDWithPrimaryTensor:secondaryTensor:name:");
184    int_scalar!(bitwise_and_scalar, "bitwiseANDWithPrimaryTensor:secondaryTensor:name:");
185    int_binary!(bitwise_or, "bitwiseORWithPrimaryTensor:secondaryTensor:name:");
186    int_scalar!(bitwise_or_scalar, "bitwiseORWithPrimaryTensor:secondaryTensor:name:");
187    int_binary!(bitwise_xor, "bitwiseXORWithPrimaryTensor:secondaryTensor:name:");
188    int_scalar!(bitwise_xor_scalar, "bitwiseXORWithPrimaryTensor:secondaryTensor:name:");
189    fn bitwise_not(t: IntTensor<MpsGraph>) -> IntTensor<MpsGraph> {
190        bridge::run_unary_ctx(&t, |g,ph| unsafe {
191            let ones = ffi::graph_constant_scalar(g, -1.0, burn_to_mps_dtype(t.dtype));
192            ffi::graph_binary(g, "bitwiseXORWithPrimaryTensor:secondaryTensor:name:", ph, ones)
193        })
194    }
195    int_binary!(bitwise_left_shift, "bitwiseLeftShiftWithPrimaryTensor:secondaryTensor:name:");
196    int_scalar!(bitwise_left_shift_scalar, "bitwiseLeftShiftWithPrimaryTensor:secondaryTensor:name:");
197    int_binary!(bitwise_right_shift, "bitwiseRightShiftWithPrimaryTensor:secondaryTensor:name:");
198    int_scalar!(bitwise_right_shift_scalar, "bitwiseRightShiftWithPrimaryTensor:secondaryTensor:name:");
199
200    fn int_unfold(t: IntTensor<MpsGraph>, dim: usize, size: usize, step: usize) -> IntTensor<MpsGraph> {
201        let n = t.shape[dim]; let nw = (n.saturating_sub(size))/step+1;
202        let mut wins = Vec::with_capacity(nw);
203        for i in 0..nw {
204            let st = i*step;
205            let slices: Vec<Slice> = (0..t.shape.num_dims()).map(|d| {
206                if d==dim { Slice::new(st as isize, Some((st+size) as isize), 1) }
207                else      { Slice::new(0, Some(t.shape[d] as isize), 1) }
208            }).collect();
209            let w = Self::int_slice(t.clone(), &slices);
210            let mut dims: Vec<usize> = (0..w.shape.num_dims()).map(|d| w.shape[d]).collect();
211            dims[dim]=1; dims.push(size);
212            wins.push(Self::int_reshape(w, Shape::from(dims)));
213        }
214        Self::int_cat(wins, dim)
215    }
216
217    fn int_cat(tensors: Vec<IntTensor<MpsGraph>>, dim: usize) -> IntTensor<MpsGraph> {
218        if tensors.len()==1 { return tensors.into_iter().next().unwrap(); }
219        let refs: Vec<&MpsGraphTensor> = tensors.iter().collect();
220        bridge::run_multi_ctx(&refs, tensors[0].device, |g,phs| unsafe { ffi::graph_concat(g, ffi::ns_array(phs), dim as isize) })
221    }
222
223    fn int_cast(t: IntTensor<MpsGraph>, dtype: IntDType) -> IntTensor<MpsGraph> {
224        let dt: DType = dtype.into(); if t.dtype == dt { return t; }
225        let mps = burn_to_mps_dtype(dt);
226        let mut r = bridge::run_unary_ctx(&t, |g,ph| unsafe { ffi::graph_cast(g,ph,mps) });
227        r.dtype = dt; r
228    }
229}