1use crate::{Autodiff, checkpoint::strategy::CheckpointStrategy, tensor::AutodiffTensor};
2use alloc::vec::Vec;
3
4use burn_tensor::{
5 Device, Distribution, IntDType, Shape, TensorData,
6 backend::Backend,
7 ops::{BoolTensor, IntTensor, IntTensorOps},
8};
9
10impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
11 fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<B> {
12 B::int_from_data(data, device)
13 }
14
15 async fn int_into_data(tensor: IntTensor<B>) -> TensorData {
16 B::int_into_data(tensor).await
17 }
18
19 fn int_to_device(tensor: IntTensor<B>, device: &Device<Self>) -> IntTensor<B> {
20 B::int_to_device(tensor, device)
21 }
22
23 fn int_device(tensor: &IntTensor<B>) -> Device<Self> {
24 B::int_device(tensor)
25 }
26
27 fn int_reshape(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B> {
28 B::int_reshape(tensor, shape)
29 }
30
31 fn int_slice(tensor: IntTensor<B>, slices: &[burn_tensor::Slice]) -> IntTensor<B> {
32 B::int_slice(tensor, slices)
33 }
34
35 fn int_empty(
36 shape: Shape,
37 device: &<Autodiff<B> as Backend>::Device,
38 dtype: IntDType,
39 ) -> IntTensor<B> {
40 B::int_empty(shape, device, dtype)
41 }
42
43 fn int_slice_assign(
44 tensor: IntTensor<B>,
45 slices: &[burn_tensor::Slice],
46 value: IntTensor<B>,
47 ) -> IntTensor<B> {
48 B::int_slice_assign(tensor, slices, value)
49 }
50
51 fn int_cat(tensors: Vec<IntTensor<B>>, dim: usize) -> IntTensor<B> {
52 B::int_cat(tensors, dim)
53 }
54
55 fn int_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
56 B::int_equal(lhs, rhs)
57 }
58
59 fn int_equal_elem(lhs: IntTensor<B>, rhs: B::IntElem) -> BoolTensor<B> {
60 B::int_equal_elem(lhs, rhs)
61 }
62
63 fn int_add(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
64 B::int_add(lhs, rhs)
65 }
66
67 fn int_add_scalar(lhs: IntTensor<B>, rhs: B::IntElem) -> IntTensor<B> {
68 B::int_add_scalar(lhs, rhs)
69 }
70
71 fn int_clamp_min(tensor: IntTensor<B>, min: B::IntElem) -> IntTensor<B> {
72 B::int_clamp_min(tensor, min)
73 }
74
75 fn int_clamp_max(tensor: IntTensor<B>, max: B::IntElem) -> IntTensor<B> {
76 B::int_clamp_max(tensor, max)
77 }
78
79 fn int_clamp(tensor: IntTensor<B>, min: B::IntElem, max: B::IntElem) -> IntTensor<B> {
80 B::int_clamp(tensor, min, max)
81 }
82
83 fn int_sub(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
84 B::int_sub(lhs, rhs)
85 }
86
87 fn int_sub_scalar(lhs: IntTensor<B>, rhs: B::IntElem) -> IntTensor<B> {
88 B::int_sub_scalar(lhs, rhs)
89 }
90
91 fn int_mul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
92 B::int_mul(lhs, rhs)
93 }
94
95 fn int_mul_scalar(lhs: IntTensor<B>, rhs: B::IntElem) -> IntTensor<B> {
96 B::int_mul_scalar(lhs, rhs)
97 }
98
99 fn int_div(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
100 B::int_div(lhs, rhs)
101 }
102
103 fn int_div_scalar(lhs: IntTensor<B>, rhs: B::IntElem) -> IntTensor<B> {
104 B::int_div_scalar(lhs, rhs)
105 }
106
107 fn int_remainder(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
108 B::int_remainder(lhs, rhs)
109 }
110
111 fn int_remainder_scalar(lhs: IntTensor<B>, rhs: B::IntElem) -> IntTensor<B> {
112 B::int_remainder_scalar(lhs, rhs)
113 }
114
115 fn int_matmul(lhs: IntTensor<B>, rhs: IntTensor<B>) -> IntTensor<B> {
116 B::int_matmul(lhs, rhs)
117 }
118
119 fn int_neg(tensor: IntTensor<B>) -> IntTensor<B> {
120 B::int_neg(tensor)
121 }
122
123 fn int_zeros(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<B> {
124 B::int_zeros(shape, device, dtype)
125 }
126
127 fn int_ones(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<B> {
128 B::int_ones(shape, device, dtype)
129 }
130
131 fn int_full(
132 shape: Shape,
133 fill_value: B::IntElem,
134 device: &Device<Self>,
135 dtype: IntDType,
136 ) -> IntTensor<B> {
137 B::int_full(shape, fill_value, device, dtype)
138 }
139
140 fn int_sum(tensor: IntTensor<B>) -> IntTensor<B> {
141 B::int_sum(tensor)
142 }
143
144 fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
145 B::int_sum_dim(tensor, dim)
146 }
147
148 fn int_mean(tensor: IntTensor<B>) -> IntTensor<B> {
149 B::int_mean(tensor)
150 }
151
152 fn int_mean_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
153 B::int_mean_dim(tensor, dim)
154 }
155
156 fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
157 B::int_cumsum(tensor, dim)
158 }
159
160 fn int_cumprod(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
161 B::int_cumprod(tensor, dim)
162 }
163
164 fn int_cummin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
165 B::int_cummin(tensor, dim)
166 }
167
168 fn int_cummax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
169 B::int_cummax(tensor, dim)
170 }
171
172 fn int_repeat_dim(tensor: IntTensor<B>, dim: usize, times: usize) -> IntTensor<B> {
173 B::int_repeat_dim(tensor, dim, times)
174 }
175
176 fn int_greater(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
177 B::int_greater(lhs, rhs)
178 }
179
180 fn int_greater_elem(lhs: IntTensor<B>, rhs: B::IntElem) -> BoolTensor<B> {
181 B::int_greater_elem(lhs, rhs)
182 }
183
184 fn int_greater_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
185 B::int_greater_equal(lhs, rhs)
186 }
187
188 fn int_greater_equal_elem(lhs: IntTensor<B>, rhs: B::IntElem) -> BoolTensor<B> {
189 B::int_greater_equal_elem(lhs, rhs)
190 }
191
192 fn int_lower(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
193 B::int_lower(lhs, rhs)
194 }
195
196 fn int_lower_elem(lhs: IntTensor<B>, rhs: B::IntElem) -> BoolTensor<B> {
197 B::int_lower_elem(lhs, rhs)
198 }
199
200 fn int_lower_equal(lhs: IntTensor<B>, rhs: IntTensor<B>) -> BoolTensor<B> {
201 B::int_lower_equal(lhs, rhs)
202 }
203
204 fn int_lower_equal_elem(lhs: IntTensor<B>, rhs: B::IntElem) -> BoolTensor<B> {
205 B::int_lower_equal_elem(lhs, rhs)
206 }
207
208 fn int_gather(dim: usize, tensor: IntTensor<B>, indices: IntTensor<B>) -> IntTensor<B> {
209 B::int_gather(dim, tensor, indices)
210 }
211
212 fn int_scatter(
213 dim: usize,
214 tensor: IntTensor<B>,
215 indices: IntTensor<B>,
216 value: IntTensor<B>,
217 ) -> IntTensor<B> {
218 B::int_scatter(dim, tensor, indices, value)
219 }
220
221 fn int_select(tensor: IntTensor<B>, dim: usize, indices: IntTensor<B>) -> IntTensor<B> {
222 B::int_select(tensor, dim, indices)
223 }
224
225 fn int_select_assign(
226 tensor: IntTensor<B>,
227 dim: usize,
228 indices: IntTensor<B>,
229 value: IntTensor<B>,
230 ) -> IntTensor<B> {
231 B::int_select_assign(tensor, dim, indices, value)
232 }
233
234 fn int_mask_where(
235 tensor: IntTensor<B>,
236 mask: BoolTensor<B>,
237 value: IntTensor<B>,
238 ) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
239 B::int_mask_where(tensor, mask, value)
240 }
241
242 fn int_mask_fill(
243 tensor: IntTensor<B>,
244 mask: BoolTensor<B>,
245 value: B::IntElem,
246 ) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
247 B::int_mask_fill(tensor, mask, value)
248 }
249
250 fn int_argmax(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
251 B::int_argmax(tensor, dim)
252 }
253 fn int_argmin(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
254 B::int_argmin(tensor, dim)
255 }
256 fn int_max(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
257 B::int_max(tensor)
258 }
259 fn int_max_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTensorPrimitive {
260 B::int_max_dim(tensor, dim)
261 }
262 fn int_max_dim_with_indices(
263 tensor: B::IntTensorPrimitive,
264 dim: usize,
265 ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) {
266 B::int_max_dim_with_indices(tensor, dim)
267 }
268 fn int_min(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
269 B::int_min(tensor)
270 }
271 fn int_min_dim(tensor: B::IntTensorPrimitive, dim: usize) -> B::IntTensorPrimitive {
272 B::int_min_dim(tensor, dim)
273 }
274 fn int_min_dim_with_indices(
275 tensor: B::IntTensorPrimitive,
276 dim: usize,
277 ) -> (B::IntTensorPrimitive, B::IntTensorPrimitive) {
278 B::int_min_dim_with_indices(tensor, dim)
279 }
280 fn int_abs(tensor: B::IntTensorPrimitive) -> B::IntTensorPrimitive {
281 B::int_abs(tensor)
282 }
283 fn int_into_float(
284 tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,
285 ) -> <Autodiff<B> as Backend>::FloatTensorPrimitive {
286 AutodiffTensor::new(B::int_into_float(tensor))
287 }
288
289 fn int_swap_dims(
290 tensor: <Autodiff<B> as Backend>::IntTensorPrimitive,
291 dim1: usize,
292 dim2: usize,
293 ) -> <Autodiff<B> as Backend>::IntTensorPrimitive {
294 B::int_swap_dims(tensor, dim1, dim2)
295 }
296
297 fn int_random(
298 shape: Shape,
299 distribution: Distribution,
300 device: &Device<Self>,
301 ) -> IntTensor<Self> {
302 B::int_random(shape, distribution, device)
303 }
304
305 fn int_arange(range: core::ops::Range<i64>, device: &Device<Self>) -> IntTensor<Self> {
306 B::int_arange(range, device)
307 }
308
309 fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
310 B::int_permute(tensor, axes)
311 }
312
313 fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
314 B::int_flip(tensor, axes)
315 }
316
317 fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
318 B::int_sign(tensor)
319 }
320
321 fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
322 B::int_prod(tensor)
323 }
324
325 fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
326 B::int_prod_dim(tensor, dim)
327 }
328
329 fn int_expand(tensor: IntTensor<B>, shape: Shape) -> IntTensor<B> {
330 B::int_expand(tensor, shape)
331 }
332
333 fn int_sort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
334 B::int_sort(tensor, dim, descending)
335 }
336
337 fn int_sort_with_indices(
338 tensor: IntTensor<Self>,
339 dim: usize,
340 descending: bool,
341 ) -> (IntTensor<Self>, IntTensor<Self>) {
342 B::int_sort_with_indices(tensor, dim, descending)
343 }
344
345 fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
346 B::int_argsort(tensor, dim, descending)
347 }
348
349 fn bitwise_and(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
350 B::bitwise_and(lhs, rhs)
351 }
352
353 fn bitwise_and_scalar(lhs: IntTensor<Self>, rhs: B::IntElem) -> IntTensor<Self> {
354 B::bitwise_and_scalar(lhs, rhs)
355 }
356
357 fn bitwise_or(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
358 B::bitwise_or(lhs, rhs)
359 }
360
361 fn bitwise_or_scalar(lhs: IntTensor<Self>, rhs: B::IntElem) -> IntTensor<Self> {
362 B::bitwise_or_scalar(lhs, rhs)
363 }
364
365 fn bitwise_xor(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
366 B::bitwise_xor(lhs, rhs)
367 }
368
369 fn bitwise_xor_scalar(lhs: IntTensor<Self>, rhs: B::IntElem) -> IntTensor<Self> {
370 B::bitwise_xor_scalar(lhs, rhs)
371 }
372
373 fn bitwise_not(tensor: IntTensor<Self>) -> IntTensor<Self> {
374 B::bitwise_not(tensor)
375 }
376
377 fn bitwise_left_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
378 B::bitwise_left_shift(lhs, rhs)
379 }
380
381 fn bitwise_left_shift_scalar(lhs: IntTensor<Self>, rhs: B::IntElem) -> IntTensor<Self> {
382 B::bitwise_left_shift_scalar(lhs, rhs)
383 }
384
385 fn bitwise_right_shift(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
386 B::bitwise_right_shift(lhs, rhs)
387 }
388
389 fn bitwise_right_shift_scalar(lhs: IntTensor<Self>, rhs: B::IntElem) -> IntTensor<Self> {
390 B::bitwise_right_shift_scalar(lhs, rhs)
391 }
392
393 fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
394 B::int_cast(tensor, dtype)
395 }
396
397 fn int_unfold(
398 tensor: IntTensor<Self>,
399 dim: usize,
400 size: usize,
401 step: usize,
402 ) -> IntTensor<Self> {
403 B::int_unfold(tensor, dim, size, step)
404 }
405}