burn_autodiff/ops/int_tensor.rs

1use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor};
2use alloc::vec::Vec;
3
4use burn_tensor::{
5    Device, Distribution, IntDType, Shape, TensorData,
6    backend::Backend,
7    ops::{BoolTensor, IntTensor, IntTensorOps},
8};
9
10impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
11    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<B> {
12        B::int_from_data(data, device)
13    }
14
15    async fn int_into_data(tensor: IntTensor<B>) -> TensorData {
16        B::int_into_data(tensor).await
17    }
18
19    fn int_to_device(tensor: IntTensor<B>, device: &Device<Self>) -> IntTensor<B> {
20        B::int_to_device(tensor, device)
21    }
22
23    fn int_device(tensor: &IntTensor<B>) -> Device<Self> {
24        B::int_device(tensor)
25    }
26
27    fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B> {
28        B::int_reshape(tensor, shape)
29    }
30
31    fn int_slice(tensor: IntTensor<B>, slices: &[burn_tensor::Slice]) -> IntTensor<B> {
32        B::int_slice(tensor, slices)
33    }
34
35    fn int_empty(
36        shape: Shape,
37        device: &<Autodiff<B> as Backend>::Device,
38        dtype: IntDType,
39    ) -> IntTensor<B> {
40        B::int_empty(shape, device, dtype)
41    }
42
43    fn int_slice_assign(
44        tensor: IntTensor<B>,
45        slices: &[burn_tensor::Slice],
46        value: IntTensor<B>,
47    ) -> IntTensor<B> {
48        B::int_slice_assign(tensor, slices, value)
49    }
50
51    fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
52        B::int_cat(tensors, dim)
53    }
54
55    fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
56        B::int_equal(lhs, rhs)
57    }
58
59    fn int_equal_elem(lhs: IntTensor<B>, rhs: B::IntElem) -> BoolTensor<B> {
60        B::int_equal_elem(lhs, rhs)
61    }
62
63    fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
64        B::int_add(lhs, rhs)
65    }
66
67    fn int_add_scalar(lhs: IntTensor<B>, rhs: B::IntElem) -> IntTensor<B> {
68        B::int_add_scalar(lhs, rhs)
69    }
70
71    fn int_clamp_min(tensor: IntTensor<B>, min: B::IntElem) -> IntTensor<B> {
72        B::int_clamp_min(tensor, min)
73    }
74
75    fn int_clamp_max(tensor: IntTensor<B>, max: B::IntElem) -> IntTensor<B> {
76        B::int_clamp_max(tensor, max)
77    }
78
79    fn int_clamp(tensor: IntTensor<B>, min: B::IntElem, max: B::IntElem) -> IntTensor<B> {
80        B::int_clamp(tensor, min, max)
81    }
82
83    fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
84        B::int_sub(lhs, rhs)
85    }
86
87    fn int_sub_scalar(lhs: IntTensor<B>, rhs: B::IntElem) -> IntTensor<B> {
88        B::int_sub_scalar(lhs, rhs)
89    }
90
91    fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
92        B::int_mul(lhs, rhs)
93    }
94
95    fn int_mul_scalar(lhs: IntTensor<B>, rhs: B::IntElem) -> IntTensor<B> {
96        B::int_mul_scalar(lhs, rhs)
97    }
98
99    fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
100        B::int_div(lhs, rhs)
101    }
102
103    fn int_div_scalar(lhs: IntTensor<B>, rhs: B::IntElem) -> IntTensor<B> {
104        B::int_div_scalar(lhs, rhs)
105    }
106
107    fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
108        B::int_remainder(lhs, rhs)
109    }
110
111    fn int_remainder_scalar(lhs: IntTensor<B>, rhs: B::IntElem) -> IntTensor<B> {
112        B::int_remainder_scalar(lhs, rhs)
113    }
114
115    fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
116        B::int_matmul(lhs, rhs)
117    }
118
119    fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
120        B::int_neg(tensor)
121    }
122
123    fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<B> {
124        B::int_zeros(shape, device, dtype)
125    }
126
127    fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<B> {
128        B::int_ones(shape, device, dtype)
129    }
130
131    fn int_full(
132        shape: Shape,
133        fill_value: B::IntElem,
134        device: &Device<Self>,
135        dtype: IntDType,
136    ) -> IntTensor<B> {
137        B::int_full(shape, fill_value, device, dtype)
138    }
139
140    fn int_sum(tensor: IntTensor<B>) -> IntTensor<B> {
141        B::int_sum(tensor)
142    }
143
144    fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
145        B::int_sum_dim(tensor, dim)
146    }
147
148    fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
149        B::int_mean(tensor)
150    }
151
152    fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
153        B::int_mean_dim(tensor, dim)
154    }
155
156    fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
157        B::int_cumsum(tensor, dim)
158    }
159
160    fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
161        B::int_cumprod(tensor, dim)
162    }
163
164    fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
165        B::int_cummin(tensor, dim)
166    }
167
168    fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
169        B::int_cummax(tensor, dim)
170    }
171
172    fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
173        B::int_repeat_dim(tensor, dim, times)
174    }
175
176    fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
177        B::int_greater(lhs, rhs)
178    }
179
180    fn int_greater_elem(lhs: IntTensor<B>, rhs: B::IntElem) -> BoolTensor<B> {
181        B::int_greater_elem(lhs, rhs)
182    }
183
184    fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
185        B::int_greater_equal(lhs, rhs)
186    }
187
188    fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: B::IntElem) -> BoolTensor<B> {
189        B::int_greater_equal_elem(lhs, rhs)
190    }
191
192    fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
193        B::int_lower(lhs, rhs)
194    }
195
196    fn int_lower_elem(lhs: IntTensor<B>, rhs: B::IntElem) -> BoolTensor<B> {
197        B::int_lower_elem(lhs, rhs)
198    }
199
200    fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
201        B::int_lower_equal(lhs, rhs)
202    }
203
204    fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: B::IntElem) -> BoolTensor<B> {
205        B::int_lower_equal_elem(lhs, rhs)
206    }
207
208    fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B> {
209        B::int_gather(dim, tensor, indices)
210    }
211
212    fn int_scatter(
213        dim: usize,
214        tensor: IntTensor<B>,
215        indices: IntTensor<B>,
216        value: IntTensor<B>,
217    ) -> IntTensor<B> {
218        B::int_scatter(dim, tensor, indices, value)
219    }
220
221    fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B> {
222        B::int_select(tensor, dim, indices)
223    }
224
225    fn int_select_assign(
226        tensor: IntTensor<B>,
227        dim: usize,
228        indices: IntTensor<B>,
229        value: IntTensor<B>,
230    ) -> IntTensor<B> {
231        B::int_select_assign(tensor, dim, indices, value)
232    }
233
234    fn int_mask_where(
235        tensor: IntTensor<B>,
236        mask: BoolTensor<B>,
237        value: IntTensor<B>,
238    ) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
239        B::int_mask_where(tensor, mask, value)
240    }
241
242    fn int_mask_fill(
243        tensor: IntTensor<B>,
244        mask: BoolTensor<B>,
245        value: B::IntElem,
246    ) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
247        B::int_mask_fill(tensor, mask, value)
248    }
249
250    fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
251        B::int_argmax(tensor, dim)
252    }
253    fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
254        B::int_argmin(tensor, dim)
255    }
256    fn int_max(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
257        B::int_max(tensor)
258    }
259    fn int_max_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTensorPrimitive {
260        B::int_max_dim(tensor, dim)
261    }
262    fn int_max_dim_with_indices(
263        tensor: B::IntTensorPrimitive,
264        dim: usize,
265    ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) {
266        B::int_max_dim_with_indices(tensor, dim)
267    }
268    fn int_min(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
269        B::int_min(tensor)
270    }
271    fn int_min_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTensorPrimitive {
272        B::int_min_dim(tensor, dim)
273    }
274    fn int_min_dim_with_indices(
275        tensor: B::IntTensorPrimitive,
276        dim: usize,
277    ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) {
278        B::int_min_dim_with_indices(tensor, dim)
279    }
280    fn int_abs(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
281        B::int_abs(tensor)
282    }
283    fn int_into_float(
284        tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,
285    ) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {
286        AutodiffTensor::new(B::int_into_float(tensor))
287    }
288
289    fn int_swap_dims(
290        tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,
291        dim1: usize,
292        dim2: usize,
293    ) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
294        B::int_swap_dims(tensor, dim1, dim2)
295    }
296
297    fn int_random(
298        shape: Shape,
299        distribution: Distribution,
300        device: &Device<Self>,
301    ) -> IntTensor<Self> {
302        B::int_random(shape, distribution, device)
303    }
304
305    fn int_arange(range: core::ops::Range<i64>, device: &Device<Self>) -> IntTensor<Self> {
306        B::int_arange(range, device)
307    }
308
309    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
310        B::int_permute(tensor, axes)
311    }
312
313    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
314        B::int_flip(tensor, axes)
315    }
316
317    fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
318        B::int_sign(tensor)
319    }
320
321    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
322        B::int_prod(tensor)
323    }
324
325    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
326        B::int_prod_dim(tensor, dim)
327    }
328
329    fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B> {
330        B::int_expand(tensor, shape)
331    }
332
333    fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
334        B::int_sort(tensor, dim, descending)
335    }
336
337    fn int_sort_with_indices(
338        tensor: IntTensor<Self>,
339        dim: usize,
340        descending: bool,
341    ) -> (IntTensor<Self>, IntTensor<Self>) {
342        B::int_sort_with_indices(tensor, dim, descending)
343    }
344
345    fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
346        B::int_argsort(tensor, dim, descending)
347    }
348
349    fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
350        B::bitwise_and(lhs, rhs)
351    }
352
353    fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: B::IntElem) -> IntTensor<Self> {
354        B::bitwise_and_scalar(lhs, rhs)
355    }
356
357    fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
358        B::bitwise_or(lhs, rhs)
359    }
360
361    fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: B::IntElem) -> IntTensor<Self> {
362        B::bitwise_or_scalar(lhs, rhs)
363    }
364
365    fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
366        B::bitwise_xor(lhs, rhs)
367    }
368
369    fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: B::IntElem) -> IntTensor<Self> {
370        B::bitwise_xor_scalar(lhs, rhs)
371    }
372
373    fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
374        B::bitwise_not(tensor)
375    }
376
377    fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
378        B::bitwise_left_shift(lhs, rhs)
379    }
380
381    fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: B::IntElem) -> IntTensor<Self> {
382        B::bitwise_left_shift_scalar(lhs, rhs)
383    }
384
385    fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
386        B::bitwise_right_shift(lhs, rhs)
387    }
388
389    fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: B::IntElem) -> IntTensor<Self> {
390        B::bitwise_right_shift_scalar(lhs, rhs)
391    }
392
393    fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
394        B::int_cast(tensor, dtype)
395    }
396
397    fn int_unfold(
398        tensor: IntTensor<Self>,
399        dim: usize,
400        size: usize,
401        step: usize,
402    ) -> IntTensor<Self> {
403        B::int_unfold(tensor, dim, size, step)
404    }
405}