// burn_backend/tensor/ops/int.rs

1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5    AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData,
6    get_device_settings,
7    ops::TransactionPrimitive,
8    tensor::{
9        BasicAutodiffOps, BasicOps, BoolTensor, Device, IndexingUpdateOp, Int, IntTensor, Numeric,
10        Ordered, TensorKind, TransactionOp,
11    },
12};
13
14impl<B: Backend> TransactionOp<B> for Int {
15    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
16        tr.register_int(tensor);
17    }
18}
19impl<B: Backend> BasicOps<B> for Int {
20    type Elem = B::IntElem;
21
22    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
23        B::int_empty(shape, device, dtype.into())
24    }
25
26    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
27        B::int_zeros(shape, device, dtype.into())
28    }
29    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
30        B::int_ones(shape, device, dtype.into())
31    }
32
33    fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
34        B::int_full(shape, fill_value, device, dtype.into())
35    }
36
37    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
38        B::int_reshape(tensor, shape)
39    }
40
41    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
42        B::int_transpose(tensor)
43    }
44
45    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
46        B::int_swap_dims(tensor, dim1, dim2)
47    }
48
49    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
50        B::int_slice(tensor, slices)
51    }
52
53    fn slice_assign(
54        tensor: Self::Primitive,
55        slices: &[Slice],
56        value: Self::Primitive,
57    ) -> Self::Primitive {
58        B::int_slice_assign(tensor, slices, value)
59    }
60
61    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
62        B::int_select(tensor, dim, indices)
63    }
64
65    fn select_assign(
66        tensor: Self::Primitive,
67        dim: usize,
68        indices: IntTensor<B>,
69        values: Self::Primitive,
70        update: IndexingUpdateOp,
71    ) -> Self::Primitive {
72        match update {
73            IndexingUpdateOp::Add => B::int_select_add(tensor, dim, indices, values),
74        }
75    }
76
77    fn mask_where(
78        tensor: Self::Primitive,
79        mask: B::BoolTensorPrimitive,
80        source: Self::Primitive,
81    ) -> Self::Primitive {
82        B::int_mask_where(tensor, mask, source)
83    }
84
85    fn mask_fill(
86        tensor: Self::Primitive,
87        mask: B::BoolTensorPrimitive,
88        value: Scalar,
89    ) -> Self::Primitive {
90        B::int_mask_fill(tensor, mask, value)
91    }
92
93    fn gather(
94        dim: usize,
95        tensor: Self::Primitive,
96        indices: B::IntTensorPrimitive,
97    ) -> Self::Primitive {
98        B::int_gather(dim, tensor, indices)
99    }
100
101    fn scatter(
102        dim: usize,
103        tensor: Self::Primitive,
104        indices: B::IntTensorPrimitive,
105        values: Self::Primitive,
106        update: IndexingUpdateOp,
107    ) -> Self::Primitive {
108        match update {
109            IndexingUpdateOp::Add => B::int_scatter_add(dim, tensor, indices, values),
110        }
111    }
112
113    fn device(tensor: &Self::Primitive) -> Device<B> {
114        B::int_device(tensor)
115    }
116
117    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
118        B::int_to_device(tensor, device)
119    }
120
121    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
122        B::int_into_data(tensor).await
123    }
124
125    fn from_data(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
126        B::int_from_data(data.convert_dtype(dtype), device)
127    }
128
129    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
130        B::int_repeat_dim(tensor, dim, times)
131    }
132
133    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
134        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
135        B::int_equal(lhs, rhs, out_dtype)
136    }
137
138    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
139        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
140        B::int_not_equal(lhs, rhs, out_dtype)
141    }
142
143    fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
144        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
145        B::int_equal_elem(lhs, rhs, out_dtype)
146    }
147
148    fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
149        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
150        B::int_not_equal_elem(lhs, rhs, out_dtype)
151    }
152
153    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
154        B::int_cat(vectors, dim)
155    }
156
157    fn any(tensor: Self::Primitive) -> BoolTensor<B> {
158        let out_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
159        B::int_any(tensor, out_dtype)
160    }
161
162    fn any_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
163        let out_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
164        B::int_any_dim(tensor, dim, out_dtype)
165    }
166
167    fn all(tensor: Self::Primitive) -> BoolTensor<B> {
168        let out_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
169        B::int_all(tensor, out_dtype)
170    }
171
172    fn all_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
173        let out_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
174        B::int_all_dim(tensor, dim, out_dtype)
175    }
176
177    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
178        B::int_permute(tensor, axes)
179    }
180
181    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
182        B::int_expand(tensor, shape)
183    }
184
185    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
186        B::int_flip(tensor, axes)
187    }
188
189    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
190        B::int_unfold(tensor, dim, size, step)
191    }
192}
193
194impl<B: Backend> Numeric<B> for Int {
195    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
196        B::int_add(lhs, rhs)
197    }
198    fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
199        B::int_add_scalar(lhs, rhs)
200    }
201    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
202        B::int_sub(lhs, rhs)
203    }
204    fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
205        B::int_sub_scalar(lhs, rhs)
206    }
207    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
208        B::int_div(lhs, rhs)
209    }
210    fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
211        B::int_div_scalar(lhs, rhs)
212    }
213    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
214        B::int_remainder(lhs, rhs)
215    }
216    fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
217        B::int_remainder_scalar(lhs, rhs)
218    }
219    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
220        B::int_mul(lhs, rhs)
221    }
222    fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
223        B::int_mul_scalar(lhs, rhs)
224    }
225    fn neg(tensor: Self::Primitive) -> Self::Primitive {
226        B::int_neg(tensor)
227    }
228
229    fn sum(tensor: Self::Primitive) -> Self::Primitive {
230        B::int_sum(tensor)
231    }
232
233    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
234        B::int_sum_dim(tensor, dim)
235    }
236
237    fn prod(tensor: Self::Primitive) -> Self::Primitive {
238        B::int_prod(tensor)
239    }
240
241    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
242        B::int_prod_dim(tensor, dim)
243    }
244
245    fn mean(tensor: Self::Primitive) -> Self::Primitive {
246        B::int_mean(tensor)
247    }
248    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
249        B::int_mean_dim(tensor, dim)
250    }
251    fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
252        B::int_cumsum(tensor, dim)
253    }
254    fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
255        B::int_cumprod(tensor, dim)
256    }
257
258    fn abs(tensor: Self::Primitive) -> Self::Primitive {
259        B::int_abs(tensor)
260    }
261
262    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
263        B::int_powi(lhs, rhs)
264    }
265
266    fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
267        B::int_powi_scalar(lhs, rhs)
268    }
269
270    fn random(
271        shape: Shape,
272        distribution: Distribution,
273        device: &Device<B>,
274        dtype: DType,
275    ) -> Self::Primitive {
276        B::int_random(shape, distribution, device, dtype.into())
277    }
278
279    fn sign(tensor: Self::Primitive) -> Self::Primitive {
280        B::int_sign(tensor)
281    }
282
283    /// Applies the matrix multiplication operation.
284    ///
285    /// `C = AB`
286    ///
287    /// # Panics
288    ///
289    /// If the two tensors don't have a compatible shape.
290    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
291        B::int_matmul(lhs, rhs)
292    }
293}
294
295impl<B: Backend> Ordered<B> for Int {
296    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
297        B::int_sort(tensor, dim, descending)
298    }
299
300    fn sort_with_indices(
301        tensor: Self::Primitive,
302        dim: usize,
303        descending: bool,
304    ) -> (Self::Primitive, IntTensor<B>) {
305        B::int_sort_with_indices(tensor, dim, descending)
306    }
307
308    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
309        B::int_argsort(tensor, dim, descending)
310    }
311
312    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
313        B::int_cummin(tensor, dim)
314    }
315
316    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
317        B::int_cummax(tensor, dim)
318    }
319
320    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
321        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
322        B::int_greater(lhs, rhs, out_dtype)
323    }
324
325    fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
326        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
327        B::int_greater_elem(lhs, rhs, out_dtype)
328    }
329
330    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
331        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
332        B::int_greater_equal(lhs, rhs, out_dtype)
333    }
334
335    fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
336        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
337        B::int_greater_equal_elem(lhs, rhs, out_dtype)
338    }
339
340    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
341        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
342        B::int_lower(lhs, rhs, out_dtype)
343    }
344
345    fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
346        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
347        B::int_lower_elem(lhs, rhs, out_dtype)
348    }
349
350    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
351        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
352        B::int_lower_equal(lhs, rhs, out_dtype)
353    }
354
355    fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
356        let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
357        B::int_lower_equal_elem(lhs, rhs, out_dtype)
358    }
359
360    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
361        B::int_argmax(tensor, dim)
362    }
363
364    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
365        B::int_argmin(tensor, dim)
366    }
367
368    fn max(tensor: Self::Primitive) -> Self::Primitive {
369        B::int_max(tensor)
370    }
371
372    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
373        B::int_max_dim(tensor, dim)
374    }
375
376    fn max_dim_with_indices(
377        tensor: Self::Primitive,
378        dim: usize,
379    ) -> (Self::Primitive, IntTensor<B>) {
380        B::int_max_dim_with_indices(tensor, dim)
381    }
382
383    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
384        B::int_max_abs(tensor)
385    }
386
387    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
388        B::int_max_abs_dim(tensor, dim)
389    }
390
391    fn min(tensor: Self::Primitive) -> Self::Primitive {
392        B::int_min(tensor)
393    }
394
395    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
396        B::int_min_dim(tensor, dim)
397    }
398
399    fn min_dim_with_indices(
400        tensor: Self::Primitive,
401        dim: usize,
402    ) -> (Self::Primitive, IntTensor<B>) {
403        B::int_min_dim_with_indices(tensor, dim)
404    }
405
406    fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive {
407        B::int_clamp(tensor, min, max)
408    }
409
410    fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive {
411        B::int_clamp_min(tensor, min)
412    }
413
414    fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive {
415        B::int_clamp_max(tensor, max)
416    }
417}
418
419impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
420    type InnerKind = Int;
421
422    fn inner(
423        tensor: <Self as TensorKind<B>>::Primitive,
424    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
425        B::int_inner(tensor)
426    }
427
428    fn from_inner(
429        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
430    ) -> <Self as TensorKind<B>>::Primitive {
431        B::int_from_inner(inner)
432    }
433}