Skip to main content

burn_backend/tensor/ops/
int.rs

1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5    AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData,
6    get_device_settings,
7    ops::TransactionPrimitive,
8    tensor::{
9        BasicAutodiffOps, BasicOps, BoolTensor, Device, IndexingUpdateOp, Int, IntTensor, Numeric,
10        Ordered, TensorKind, TransactionOp,
11    },
12};
13
14impl<B: Backend> TransactionOp<B> for Int {
15    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
16        tr.register_int(tensor);
17    }
18}
19impl<B: Backend> BasicOps<B> for Int {
20    type Elem = B::IntElem;
21
22    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
23        B::int_empty(shape, device, dtype.into())
24    }
25
26    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
27        B::int_zeros(shape, device, dtype.into())
28    }
29    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
30        B::int_ones(shape, device, dtype.into())
31    }
32
33    fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
34        B::int_full(shape, fill_value, device, dtype.into())
35    }
36
37    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
38        B::int_reshape(tensor, shape)
39    }
40
41    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
42        B::int_transpose(tensor)
43    }
44
45    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
46        B::int_swap_dims(tensor, dim1, dim2)
47    }
48
49    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
50        B::int_slice(tensor, slices)
51    }
52
53    fn slice_assign(
54        tensor: Self::Primitive,
55        slices: &[Slice],
56        value: Self::Primitive,
57    ) -> Self::Primitive {
58        B::int_slice_assign(tensor, slices, value)
59    }
60
61    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
62        B::int_select(tensor, dim, indices)
63    }
64
65    fn select_assign(
66        tensor: Self::Primitive,
67        dim: usize,
68        indices: IntTensor<B>,
69        values: Self::Primitive,
70        update: IndexingUpdateOp,
71    ) -> Self::Primitive {
72        match update {
73            IndexingUpdateOp::Add => B::int_select_add(tensor, dim, indices, values),
74            _ => unimplemented!(),
75        }
76    }
77
78    fn mask_where(
79        tensor: Self::Primitive,
80        mask: B::BoolTensorPrimitive,
81        source: Self::Primitive,
82    ) -> Self::Primitive {
83        B::int_mask_where(tensor, mask, source)
84    }
85
86    fn mask_fill(
87        tensor: Self::Primitive,
88        mask: B::BoolTensorPrimitive,
89        value: Scalar,
90    ) -> Self::Primitive {
91        B::int_mask_fill(tensor, mask, value)
92    }
93
94    fn gather(
95        dim: usize,
96        tensor: Self::Primitive,
97        indices: B::IntTensorPrimitive,
98    ) -> Self::Primitive {
99        B::int_gather(dim, tensor, indices)
100    }
101
102    fn scatter(
103        dim: usize,
104        tensor: Self::Primitive,
105        indices: B::IntTensorPrimitive,
106        values: Self::Primitive,
107        update: IndexingUpdateOp,
108    ) -> Self::Primitive {
109        match update {
110            IndexingUpdateOp::Add => B::int_scatter_add(dim, tensor, indices, values),
111            _ => unimplemented!(),
112        }
113    }
114
115    fn scatter_nd(
116        data: Self::Primitive,
117        indices: IntTensor<B>,
118        values: Self::Primitive,
119        reduction: IndexingUpdateOp,
120    ) -> Self::Primitive {
121        B::int_scatter_nd(data, indices, values, reduction)
122    }
123
124    fn gather_nd(data: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive {
125        B::int_gather_nd(data, indices)
126    }
127
128    fn device(tensor: &Self::Primitive) -> Device<B> {
129        B::int_device(tensor)
130    }
131
132    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
133        B::int_to_device(tensor, device)
134    }
135
136    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
137        B::int_into_data(tensor).await
138    }
139
140    fn from_data(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
141        B::int_from_data(data.convert_dtype(dtype), device)
142    }
143
144    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
145        B::int_repeat_dim(tensor, dim, times)
146    }
147
148    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
149        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
150        B::int_equal(lhs, rhs, out_dtype)
151    }
152
153    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
154        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
155        B::int_not_equal(lhs, rhs, out_dtype)
156    }
157
158    fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
159        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
160        B::int_equal_elem(lhs, rhs, out_dtype)
161    }
162
163    fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
164        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
165        B::int_not_equal_elem(lhs, rhs, out_dtype)
166    }
167
168    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
169        B::int_cat(vectors, dim)
170    }
171
172    fn any(tensor: Self::Primitive) -> BoolTensor<B> {
173        let out_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
174        B::int_any(tensor, out_dtype)
175    }
176
177    fn any_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
178        let out_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
179        B::int_any_dim(tensor, dim, out_dtype)
180    }
181
182    fn all(tensor: Self::Primitive) -> BoolTensor<B> {
183        let out_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
184        B::int_all(tensor, out_dtype)
185    }
186
187    fn all_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
188        let out_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
189        B::int_all_dim(tensor, dim, out_dtype)
190    }
191
192    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
193        B::int_permute(tensor, axes)
194    }
195
196    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
197        B::int_expand(tensor, shape)
198    }
199
200    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
201        B::int_flip(tensor, axes)
202    }
203
204    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
205        B::int_unfold(tensor, dim, size, step)
206    }
207}
208
209impl<B: Backend> Numeric<B> for Int {
210    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
211        B::int_add(lhs, rhs)
212    }
213    fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
214        B::int_add_scalar(lhs, rhs)
215    }
216    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
217        B::int_sub(lhs, rhs)
218    }
219    fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
220        B::int_sub_scalar(lhs, rhs)
221    }
222    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
223        B::int_div(lhs, rhs)
224    }
225    fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
226        B::int_div_scalar(lhs, rhs)
227    }
228    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
229        B::int_remainder(lhs, rhs)
230    }
231    fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
232        B::int_remainder_scalar(lhs, rhs)
233    }
234    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
235        B::int_mul(lhs, rhs)
236    }
237    fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
238        B::int_mul_scalar(lhs, rhs)
239    }
240    fn neg(tensor: Self::Primitive) -> Self::Primitive {
241        B::int_neg(tensor)
242    }
243
244    fn sum(tensor: Self::Primitive) -> Self::Primitive {
245        B::int_sum(tensor)
246    }
247
248    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
249        B::int_sum_dim(tensor, dim)
250    }
251
252    fn prod(tensor: Self::Primitive) -> Self::Primitive {
253        B::int_prod(tensor)
254    }
255
256    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
257        B::int_prod_dim(tensor, dim)
258    }
259
260    fn mean(tensor: Self::Primitive) -> Self::Primitive {
261        B::int_mean(tensor)
262    }
263    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
264        B::int_mean_dim(tensor, dim)
265    }
266    fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
267        B::int_cumsum(tensor, dim)
268    }
269    fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
270        B::int_cumprod(tensor, dim)
271    }
272
273    fn abs(tensor: Self::Primitive) -> Self::Primitive {
274        B::int_abs(tensor)
275    }
276
277    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
278        B::int_powi(lhs, rhs)
279    }
280
281    fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
282        B::int_powi_scalar(lhs, rhs)
283    }
284
285    fn random(
286        shape: Shape,
287        distribution: Distribution,
288        device: &Device<B>,
289        dtype: DType,
290    ) -> Self::Primitive {
291        B::int_random(shape, distribution, device, dtype.into())
292    }
293
294    fn sign(tensor: Self::Primitive) -> Self::Primitive {
295        B::int_sign(tensor)
296    }
297
298    /// Applies the matrix multiplication operation.
299    ///
300    /// `C = AB`
301    ///
302    /// # Panics
303    ///
304    /// If the two tensors don't have a compatible shape.
305    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
306        B::int_matmul(lhs, rhs)
307    }
308}
309
310impl<B: Backend> Ordered<B> for Int {
311    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
312        B::int_sort(tensor, dim, descending)
313    }
314
315    fn sort_with_indices(
316        tensor: Self::Primitive,
317        dim: usize,
318        descending: bool,
319    ) -> (Self::Primitive, IntTensor<B>) {
320        B::int_sort_with_indices(tensor, dim, descending)
321    }
322
323    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
324        B::int_argsort(tensor, dim, descending)
325    }
326
327    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
328        B::int_cummin(tensor, dim)
329    }
330
331    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
332        B::int_cummax(tensor, dim)
333    }
334
335    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
336        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
337        B::int_greater(lhs, rhs, out_dtype)
338    }
339
340    fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
341        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
342        B::int_greater_elem(lhs, rhs, out_dtype)
343    }
344
345    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
346        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
347        B::int_greater_equal(lhs, rhs, out_dtype)
348    }
349
350    fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
351        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
352        B::int_greater_equal_elem(lhs, rhs, out_dtype)
353    }
354
355    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
356        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
357        B::int_lower(lhs, rhs, out_dtype)
358    }
359
360    fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
361        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
362        B::int_lower_elem(lhs, rhs, out_dtype)
363    }
364
365    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
366        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
367        B::int_lower_equal(lhs, rhs, out_dtype)
368    }
369
370    fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
371        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
372        B::int_lower_equal_elem(lhs, rhs, out_dtype)
373    }
374
375    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
376        B::int_argmax(tensor, dim)
377    }
378
379    fn argtopk(tensor: Self::Primitive, dim: usize, k: usize) -> IntTensor<B> {
380        B::int_argtopk(tensor, dim, k)
381    }
382
383    fn topk(tensor: Self::Primitive, dim: usize, k: usize) -> IntTensor<B> {
384        B::int_topk(tensor, dim, k)
385    }
386
387    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
388        B::int_argmin(tensor, dim)
389    }
390
391    fn max(tensor: Self::Primitive) -> Self::Primitive {
392        B::int_max(tensor)
393    }
394
395    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
396        B::int_max_dim(tensor, dim)
397    }
398
399    fn max_dim_with_indices(
400        tensor: Self::Primitive,
401        dim: usize,
402    ) -> (Self::Primitive, IntTensor<B>) {
403        B::int_max_dim_with_indices(tensor, dim)
404    }
405
406    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
407        B::int_max_abs(tensor)
408    }
409
410    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
411        B::int_max_abs_dim(tensor, dim)
412    }
413
414    fn min(tensor: Self::Primitive) -> Self::Primitive {
415        B::int_min(tensor)
416    }
417
418    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
419        B::int_min_dim(tensor, dim)
420    }
421
422    fn min_dim_with_indices(
423        tensor: Self::Primitive,
424        dim: usize,
425    ) -> (Self::Primitive, IntTensor<B>) {
426        B::int_min_dim_with_indices(tensor, dim)
427    }
428
429    fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive {
430        B::int_clamp(tensor, min, max)
431    }
432
433    fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive {
434        B::int_clamp_min(tensor, min)
435    }
436
437    fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive {
438        B::int_clamp_max(tensor, max)
439    }
440}
441
442impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
443    type InnerKind = Int;
444
445    fn inner(
446        tensor: <Self as TensorKind<B>>::Primitive,
447    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
448        B::int_inner(tensor)
449    }
450
451    fn from_inner(
452        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
453    ) -> <Self as TensorKind<B>>::Primitive {
454        B::int_from_inner(inner)
455    }
456}