Skip to main content

burn_backend/tensor/ops/
int.rs

1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5    AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData,
6    ops::TransactionPrimitive,
7    tensor::{
8        BasicAutodiffOps, BasicOps, BoolTensor, Device, IndexingUpdateOp, Int, IntTensor, Numeric,
9        Ordered, TensorKind,
10    },
11};
12
13impl<B: Backend> BasicOps<B> for Int {
14    type Elem = B::IntElem;
15
16    fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
17        B::int_empty(shape, device, dtype.into())
18    }
19
20    fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
21        B::int_zeros(shape, device, dtype.into())
22    }
23    fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
24        B::int_ones(shape, device, dtype.into())
25    }
26
27    fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
28        B::int_full(shape, fill_value, device, dtype.into())
29    }
30
31    fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
32        tr.register_int(tensor);
33    }
34
35    fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
36        B::int_reshape(tensor, shape)
37    }
38
39    fn transpose(tensor: Self::Primitive) -> Self::Primitive {
40        B::int_transpose(tensor)
41    }
42
43    fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
44        B::int_swap_dims(tensor, dim1, dim2)
45    }
46
47    fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
48        B::int_slice(tensor, slices)
49    }
50
51    fn slice_assign(
52        tensor: Self::Primitive,
53        slices: &[Slice],
54        value: Self::Primitive,
55    ) -> Self::Primitive {
56        B::int_slice_assign(tensor, slices, value)
57    }
58
59    fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
60        B::int_select(tensor, dim, indices)
61    }
62
63    fn select_assign(
64        tensor: Self::Primitive,
65        dim: usize,
66        indices: IntTensor<B>,
67        values: Self::Primitive,
68        update: IndexingUpdateOp,
69    ) -> Self::Primitive {
70        match update {
71            IndexingUpdateOp::Add => B::int_select_add(tensor, dim, indices, values),
72        }
73    }
74
75    fn mask_where(
76        tensor: Self::Primitive,
77        mask: B::BoolTensorPrimitive,
78        source: Self::Primitive,
79    ) -> Self::Primitive {
80        B::int_mask_where(tensor, mask, source)
81    }
82
83    fn mask_fill(
84        tensor: Self::Primitive,
85        mask: B::BoolTensorPrimitive,
86        value: Scalar,
87    ) -> Self::Primitive {
88        B::int_mask_fill(tensor, mask, value)
89    }
90
91    fn gather(
92        dim: usize,
93        tensor: Self::Primitive,
94        indices: B::IntTensorPrimitive,
95    ) -> Self::Primitive {
96        B::int_gather(dim, tensor, indices)
97    }
98
99    fn scatter(
100        dim: usize,
101        tensor: Self::Primitive,
102        indices: B::IntTensorPrimitive,
103        values: Self::Primitive,
104        update: IndexingUpdateOp,
105    ) -> Self::Primitive {
106        match update {
107            IndexingUpdateOp::Add => B::int_scatter_add(dim, tensor, indices, values),
108        }
109    }
110
111    fn device(tensor: &Self::Primitive) -> Device<B> {
112        B::int_device(tensor)
113    }
114
115    fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
116        B::int_to_device(tensor, device)
117    }
118
119    async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
120        B::int_into_data(tensor).await
121    }
122
123    fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {
124        B::int_from_data(data.convert::<B::IntElem>(), device)
125    }
126
127    fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
128        if !dtype.is_int() {
129            panic!("Expected int dtype, got {dtype:?}")
130        }
131
132        B::int_from_data(data.convert_dtype(dtype), device)
133    }
134
135    fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
136        B::int_repeat_dim(tensor, dim, times)
137    }
138
139    fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
140        B::int_equal(lhs, rhs)
141    }
142
143    fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
144        B::int_not_equal(lhs, rhs)
145    }
146
147    fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
148        B::int_equal_elem(lhs, rhs)
149    }
150
151    fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
152        B::int_not_equal_elem(lhs, rhs)
153    }
154
155    fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
156        B::int_cat(vectors, dim)
157    }
158
159    fn any(tensor: Self::Primitive) -> BoolTensor<B> {
160        B::int_any(tensor)
161    }
162
163    fn any_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
164        B::int_any_dim(tensor, dim)
165    }
166
167    fn all(tensor: Self::Primitive) -> BoolTensor<B> {
168        B::int_all(tensor)
169    }
170
171    fn all_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
172        B::int_all_dim(tensor, dim)
173    }
174
175    fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
176        B::int_permute(tensor, axes)
177    }
178
179    fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
180        B::int_expand(tensor, shape)
181    }
182
183    fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
184        B::int_flip(tensor, axes)
185    }
186
187    fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
188        B::int_unfold(tensor, dim, size, step)
189    }
190}
191
192impl<B: Backend> Numeric<B> for Int {
193    fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
194        B::int_add(lhs, rhs)
195    }
196    fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
197        B::int_add_scalar(lhs, rhs)
198    }
199    fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
200        B::int_sub(lhs, rhs)
201    }
202    fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
203        B::int_sub_scalar(lhs, rhs)
204    }
205    fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
206        B::int_div(lhs, rhs)
207    }
208    fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
209        B::int_div_scalar(lhs, rhs)
210    }
211    fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
212        B::int_remainder(lhs, rhs)
213    }
214    fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
215        B::int_remainder_scalar(lhs, rhs)
216    }
217    fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
218        B::int_mul(lhs, rhs)
219    }
220    fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
221        B::int_mul_scalar(lhs, rhs)
222    }
223    fn neg(tensor: Self::Primitive) -> Self::Primitive {
224        B::int_neg(tensor)
225    }
226
227    fn sum(tensor: Self::Primitive) -> Self::Primitive {
228        B::int_sum(tensor)
229    }
230
231    fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
232        B::int_sum_dim(tensor, dim)
233    }
234
235    fn prod(tensor: Self::Primitive) -> Self::Primitive {
236        B::int_prod(tensor)
237    }
238
239    fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
240        B::int_prod_dim(tensor, dim)
241    }
242
243    fn mean(tensor: Self::Primitive) -> Self::Primitive {
244        B::int_mean(tensor)
245    }
246    fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
247        B::int_mean_dim(tensor, dim)
248    }
249    fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
250        B::int_cumsum(tensor, dim)
251    }
252    fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
253        B::int_cumprod(tensor, dim)
254    }
255
256    fn abs(tensor: Self::Primitive) -> Self::Primitive {
257        B::int_abs(tensor)
258    }
259
260    fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
261        B::int_powf(lhs, B::int_into_float(rhs))
262    }
263
264    fn powf_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
265        B::int_powf_scalar(lhs, rhs)
266    }
267
268    fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
269        B::int_powi(lhs, rhs)
270    }
271
272    fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
273        B::int_powi_scalar(lhs, rhs)
274    }
275
276    fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
277        B::int_random(shape, distribution, device)
278    }
279
280    fn sign(tensor: Self::Primitive) -> Self::Primitive {
281        B::int_sign(tensor)
282    }
283
284    /// Applies the matrix multiplication operation.
285    ///
286    /// `C = AB`
287    ///
288    /// # Panics
289    ///
290    /// If the two tensors don't have a compatible shape.
291    fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
292        B::int_matmul(lhs, rhs)
293    }
294}
295
296impl<B: Backend> Ordered<B> for Int {
297    fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
298        B::int_sort(tensor, dim, descending)
299    }
300
301    fn sort_with_indices(
302        tensor: Self::Primitive,
303        dim: usize,
304        descending: bool,
305    ) -> (Self::Primitive, IntTensor<B>) {
306        B::int_sort_with_indices(tensor, dim, descending)
307    }
308
309    fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
310        B::int_argsort(tensor, dim, descending)
311    }
312
313    fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
314        B::int_cummin(tensor, dim)
315    }
316
317    fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
318        B::int_cummax(tensor, dim)
319    }
320
321    fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
322        B::int_greater(lhs, rhs)
323    }
324
325    fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
326        B::int_greater_elem(lhs, rhs)
327    }
328
329    fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
330        B::int_greater_equal(lhs, rhs)
331    }
332
333    fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
334        B::int_greater_equal_elem(lhs, rhs)
335    }
336
337    fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
338        B::int_lower(lhs, rhs)
339    }
340
341    fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
342        B::int_lower_elem(lhs, rhs)
343    }
344
345    fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
346        B::int_lower_equal(lhs, rhs)
347    }
348
349    fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
350        B::int_lower_equal_elem(lhs, rhs)
351    }
352
353    fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
354        B::int_argmax(tensor, dim)
355    }
356
357    fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
358        B::int_argmin(tensor, dim)
359    }
360
361    fn max(tensor: Self::Primitive) -> Self::Primitive {
362        B::int_max(tensor)
363    }
364
365    fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
366        B::int_max_dim(tensor, dim)
367    }
368
369    fn max_dim_with_indices(
370        tensor: Self::Primitive,
371        dim: usize,
372    ) -> (Self::Primitive, IntTensor<B>) {
373        B::int_max_dim_with_indices(tensor, dim)
374    }
375
376    fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
377        B::int_max_abs(tensor)
378    }
379
380    fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
381        B::int_max_abs_dim(tensor, dim)
382    }
383
384    fn min(tensor: Self::Primitive) -> Self::Primitive {
385        B::int_min(tensor)
386    }
387
388    fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
389        B::int_min_dim(tensor, dim)
390    }
391
392    fn min_dim_with_indices(
393        tensor: Self::Primitive,
394        dim: usize,
395    ) -> (Self::Primitive, IntTensor<B>) {
396        B::int_min_dim_with_indices(tensor, dim)
397    }
398
399    fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive {
400        B::int_clamp(tensor, min, max)
401    }
402
403    fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive {
404        B::int_clamp_min(tensor, min)
405    }
406
407    fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive {
408        B::int_clamp_max(tensor, max)
409    }
410}
411
412impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
413    type InnerKind = Int;
414
415    fn inner(
416        tensor: <Self as TensorKind<B>>::Primitive,
417    ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
418        B::int_inner(tensor)
419    }
420
421    fn from_inner(
422        inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
423    ) -> <Self as TensorKind<B>>::Primitive {
424        B::int_from_inner(inner)
425    }
426}