// burn_backend/tensor/ops/int.rs

1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5    AutodiffBackend, Backend, Distribution, ExecutionError, TensorData,
6    element::ElementConversion,
7    ops::TransactionPrimitive,
8    tensor::{
9        BasicAutodiffOps, BasicOps, BoolTensor, Device, IndexingUpdateOp, Int, IntTensor, Numeric,
10        TensorKind,
11    },
12};
13
14impl<B: Backend> BasicOps<B> for Int {
15    type Elem = B::IntElem;
16
17    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
18        B::int_empty(shape, device, dtype.into())
19    }
20
21    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
22        B::int_zeros(shape, device, dtype.into())
23    }
24    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
25        B::int_ones(shape, device, dtype.into())
26    }
27
28    fn full<E: ElementConversion>(
29        shape: Shape,
30        fill_value: E,
31        device: &Device<B>,
32        dtype: DType,
33    ) -> Self::Primitive {
34        B::int_full(shape, fill_value.elem(), device, dtype.into())
35    }
36
37    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
38        tr.register_int(tensor);
39    }
40
41    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
42        B::int_reshape(tensor, shape)
43    }
44
45    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
46        B::int_transpose(tensor)
47    }
48
49    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
50        B::int_swap_dims(tensor, dim1, dim2)
51    }
52
53    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
54        B::int_slice(tensor, slices)
55    }
56
57    fn slice_assign(
58        tensor: Self::Primitive,
59        slices: &[Slice],
60        value: Self::Primitive,
61    ) -> Self::Primitive {
62        B::int_slice_assign(tensor, slices, value)
63    }
64
65    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
66        B::int_select(tensor, dim, indices)
67    }
68
69    fn select_assign(
70        tensor: Self::Primitive,
71        dim: usize,
72        indices: IntTensor<B>,
73        values: Self::Primitive,
74        update: IndexingUpdateOp,
75    ) -> Self::Primitive {
76        match update {
77            IndexingUpdateOp::Add => B::int_select_add(tensor, dim, indices, values),
78        }
79    }
80
81    fn mask_where(
82        tensor: Self::Primitive,
83        mask: B::BoolTensorPrimitive,
84        source: Self::Primitive,
85    ) -> Self::Primitive {
86        B::int_mask_where(tensor, mask, source)
87    }
88
89    fn mask_fill(
90        tensor: Self::Primitive,
91        mask: B::BoolTensorPrimitive,
92        value: Self::Elem,
93    ) -> Self::Primitive {
94        B::int_mask_fill(tensor, mask, value)
95    }
96
97    fn gather(
98        dim: usize,
99        tensor: Self::Primitive,
100        indices: B::IntTensorPrimitive,
101    ) -> Self::Primitive {
102        B::int_gather(dim, tensor, indices)
103    }
104
105    fn scatter(
106        dim: usize,
107        tensor: Self::Primitive,
108        indices: B::IntTensorPrimitive,
109        values: Self::Primitive,
110        update: IndexingUpdateOp,
111    ) -> Self::Primitive {
112        match update {
113            IndexingUpdateOp::Add => B::int_scatter_add(dim, tensor, indices, values),
114        }
115    }
116
117    fn device(tensor: &Self::Primitive) -> Device<B> {
118        B::int_device(tensor)
119    }
120
121    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
122        B::int_to_device(tensor, device)
123    }
124
125    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
126        B::int_into_data(tensor).await
127    }
128
129    fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {
130        B::int_from_data(data.convert::<B::IntElem>(), device)
131    }
132
133    fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
134        if !dtype.is_int() {
135            panic!("Expected int dtype, got {dtype:?}")
136        }
137
138        B::int_from_data(data.convert_dtype(dtype), device)
139    }
140
141    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
142        B::int_repeat_dim(tensor, dim, times)
143    }
144
145    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
146        B::int_equal(lhs, rhs)
147    }
148
149    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
150        B::int_not_equal(lhs, rhs)
151    }
152
153    fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
154        B::int_equal_elem(lhs, rhs)
155    }
156
157    fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
158        B::int_not_equal_elem(lhs, rhs)
159    }
160
161    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
162        B::int_cat(vectors, dim)
163    }
164
165    fn any(tensor: Self::Primitive) -> BoolTensor<B> {
166        B::int_any(tensor)
167    }
168
169    fn any_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
170        B::int_any_dim(tensor, dim)
171    }
172
173    fn all(tensor: Self::Primitive) -> BoolTensor<B> {
174        B::int_all(tensor)
175    }
176
177    fn all_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
178        B::int_all_dim(tensor, dim)
179    }
180
181    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
182        B::int_permute(tensor, axes)
183    }
184
185    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
186        B::int_expand(tensor, shape)
187    }
188
189    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
190        B::int_flip(tensor, axes)
191    }
192
193    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
194        B::int_unfold(tensor, dim, size, step)
195    }
196}
197
198impl<B: Backend> Numeric<B> for Int {
199    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
200        B::int_add(lhs, rhs)
201    }
202    fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
203        B::int_add_scalar(lhs, rhs.elem())
204    }
205    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
206        B::int_sub(lhs, rhs)
207    }
208    fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
209        B::int_sub_scalar(lhs, rhs.elem())
210    }
211    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
212        B::int_div(lhs, rhs)
213    }
214    fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
215        B::int_div_scalar(lhs, rhs.elem())
216    }
217    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
218        B::int_remainder(lhs, rhs)
219    }
220    fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
221        B::int_remainder_scalar(lhs, rhs.elem())
222    }
223    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
224        B::int_mul(lhs, rhs)
225    }
226    fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
227        B::int_mul_scalar(lhs, rhs.elem())
228    }
229    fn neg(tensor: Self::Primitive) -> Self::Primitive {
230        B::int_neg(tensor)
231    }
232
233    fn sum(tensor: Self::Primitive) -> Self::Primitive {
234        B::int_sum(tensor)
235    }
236
237    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
238        B::int_sum_dim(tensor, dim)
239    }
240
241    fn prod(tensor: Self::Primitive) -> Self::Primitive {
242        B::int_prod(tensor)
243    }
244
245    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
246        B::int_prod_dim(tensor, dim)
247    }
248
249    fn mean(tensor: Self::Primitive) -> Self::Primitive {
250        B::int_mean(tensor)
251    }
252    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
253        B::int_mean_dim(tensor, dim)
254    }
255    fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
256        B::int_cumsum(tensor, dim)
257    }
258    fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
259        B::int_cumprod(tensor, dim)
260    }
261
262    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
263        B::int_cummin(tensor, dim)
264    }
265
266    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
267        B::int_cummax(tensor, dim)
268    }
269
270    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
271        B::int_greater(lhs, rhs)
272    }
273
274    fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
275        B::int_greater_elem(lhs, rhs)
276    }
277
278    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
279        B::int_greater_equal(lhs, rhs)
280    }
281
282    fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
283        B::int_greater_equal_elem(lhs, rhs)
284    }
285
286    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
287        B::int_lower(lhs, rhs)
288    }
289
290    fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
291        B::int_lower_elem(lhs, rhs)
292    }
293
294    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
295        B::int_lower_equal(lhs, rhs)
296    }
297
298    fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
299        B::int_lower_equal_elem(lhs, rhs)
300    }
301
302    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
303        B::int_argmax(tensor, dim)
304    }
305
306    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
307        B::int_argmin(tensor, dim)
308    }
309
310    fn max(tensor: Self::Primitive) -> Self::Primitive {
311        B::int_max(tensor)
312    }
313
314    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
315        B::int_max_dim(tensor, dim)
316    }
317
318    fn max_dim_with_indices(
319        tensor: Self::Primitive,
320        dim: usize,
321    ) -> (Self::Primitive, IntTensor<B>) {
322        B::int_max_dim_with_indices(tensor, dim)
323    }
324
325    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
326        B::int_max_abs(tensor)
327    }
328
329    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
330        B::int_max_abs_dim(tensor, dim)
331    }
332
333    fn min(tensor: Self::Primitive) -> Self::Primitive {
334        B::int_min(tensor)
335    }
336
337    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
338        B::int_min_dim(tensor, dim)
339    }
340
341    fn min_dim_with_indices(
342        tensor: Self::Primitive,
343        dim: usize,
344    ) -> (Self::Primitive, IntTensor<B>) {
345        B::int_min_dim_with_indices(tensor, dim)
346    }
347
348    fn clamp(tensor: Self::Primitive, min: B::IntElem, max: B::IntElem) -> Self::Primitive {
349        B::int_clamp(tensor, min, max)
350    }
351
352    fn clamp_min(tensor: Self::Primitive, min: B::IntElem) -> Self::Primitive {
353        B::int_clamp_min(tensor, min)
354    }
355
356    fn clamp_max(tensor: Self::Primitive, max: B::IntElem) -> Self::Primitive {
357        B::int_clamp_max(tensor, max)
358    }
359
360    fn abs(tensor: Self::Primitive) -> Self::Primitive {
361        B::int_abs(tensor)
362    }
363
364    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
365        B::int_powf(lhs, B::int_into_float(rhs))
366    }
367
368    fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
369        B::int_powf_scalar(lhs, rhs.elem())
370    }
371
372    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
373        B::int_powi(lhs, rhs)
374    }
375
376    fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
377        B::int_powi_scalar(lhs, rhs.elem())
378    }
379
380    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
381        B::int_random(shape, distribution, device)
382    }
383
384    fn sign(tensor: Self::Primitive) -> Self::Primitive {
385        B::int_sign(tensor)
386    }
387
388    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
389        B::int_sort(tensor, dim, descending)
390    }
391
392    fn sort_with_indices(
393        tensor: Self::Primitive,
394        dim: usize,
395        descending: bool,
396    ) -> (Self::Primitive, IntTensor<B>) {
397        B::int_sort_with_indices(tensor, dim, descending)
398    }
399
400    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
401        B::int_argsort(tensor, dim, descending)
402    }
403
404    /// Applies the matrix multiplication operation.
405    ///
406    /// `C = AB`
407    ///
408    /// # Panics
409    ///
410    /// If the two tensors don't have a compatible shape.
411    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
412        B::int_matmul(lhs, rhs)
413    }
414}
415
416impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
417    type InnerKind = Int;
418
419    fn inner(
420        tensor: <Self as TensorKind<B>>::Primitive,
421    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
422        B::int_inner(tensor)
423    }
424
425    fn from_inner(
426        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
427    ) -> <Self as TensorKind<B>>::Primitive {
428        B::int_from_inner(inner)
429    }
430}