1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5 AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData,
6 get_device_settings,
7 ops::TransactionPrimitive,
8 tensor::{
9 BasicAutodiffOps, BasicOps, BoolTensor, Device, IndexingUpdateOp, Int, IntTensor, Numeric,
10 Ordered, TensorKind, TransactionOp,
11 },
12};
13
14impl<B: Backend> TransactionOp<B> for Int {
15 fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
16 tr.register_int(tensor);
17 }
18}
19impl<B: Backend> BasicOps<B> for Int {
20 type Elem = B::IntElem;
21
22 fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
23 B::int_empty(shape, device, dtype.into())
24 }
25
26 fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
27 B::int_zeros(shape, device, dtype.into())
28 }
29 fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
30 B::int_ones(shape, device, dtype.into())
31 }
32
33 fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
34 B::int_full(shape, fill_value, device, dtype.into())
35 }
36
37 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
38 B::int_reshape(tensor, shape)
39 }
40
41 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
42 B::int_transpose(tensor)
43 }
44
45 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
46 B::int_swap_dims(tensor, dim1, dim2)
47 }
48
49 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
50 B::int_slice(tensor, slices)
51 }
52
53 fn slice_assign(
54 tensor: Self::Primitive,
55 slices: &[Slice],
56 value: Self::Primitive,
57 ) -> Self::Primitive {
58 B::int_slice_assign(tensor, slices, value)
59 }
60
61 fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
62 B::int_select(tensor, dim, indices)
63 }
64
65 fn select_assign(
66 tensor: Self::Primitive,
67 dim: usize,
68 indices: IntTensor<B>,
69 values: Self::Primitive,
70 update: IndexingUpdateOp,
71 ) -> Self::Primitive {
72 match update {
73 IndexingUpdateOp::Add => B::int_select_add(tensor, dim, indices, values),
74 }
75 }
76
77 fn mask_where(
78 tensor: Self::Primitive,
79 mask: B::BoolTensorPrimitive,
80 source: Self::Primitive,
81 ) -> Self::Primitive {
82 B::int_mask_where(tensor, mask, source)
83 }
84
85 fn mask_fill(
86 tensor: Self::Primitive,
87 mask: B::BoolTensorPrimitive,
88 value: Scalar,
89 ) -> Self::Primitive {
90 B::int_mask_fill(tensor, mask, value)
91 }
92
93 fn gather(
94 dim: usize,
95 tensor: Self::Primitive,
96 indices: B::IntTensorPrimitive,
97 ) -> Self::Primitive {
98 B::int_gather(dim, tensor, indices)
99 }
100
101 fn scatter(
102 dim: usize,
103 tensor: Self::Primitive,
104 indices: B::IntTensorPrimitive,
105 values: Self::Primitive,
106 update: IndexingUpdateOp,
107 ) -> Self::Primitive {
108 match update {
109 IndexingUpdateOp::Add => B::int_scatter_add(dim, tensor, indices, values),
110 }
111 }
112
113 fn device(tensor: &Self::Primitive) -> Device<B> {
114 B::int_device(tensor)
115 }
116
117 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
118 B::int_to_device(tensor, device)
119 }
120
121 async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
122 B::int_into_data(tensor).await
123 }
124
125 fn from_data(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
126 B::int_from_data(data.convert_dtype(dtype), device)
127 }
128
129 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
130 B::int_repeat_dim(tensor, dim, times)
131 }
132
133 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
134 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
135 B::int_equal(lhs, rhs, out_dtype)
136 }
137
138 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
139 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
140 B::int_not_equal(lhs, rhs, out_dtype)
141 }
142
143 fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
144 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
145 B::int_equal_elem(lhs, rhs, out_dtype)
146 }
147
148 fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
149 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
150 B::int_not_equal_elem(lhs, rhs, out_dtype)
151 }
152
153 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
154 B::int_cat(vectors, dim)
155 }
156
157 fn any(tensor: Self::Primitive) -> BoolTensor<B> {
158 let out_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
159 B::int_any(tensor, out_dtype)
160 }
161
162 fn any_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
163 let out_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
164 B::int_any_dim(tensor, dim, out_dtype)
165 }
166
167 fn all(tensor: Self::Primitive) -> BoolTensor<B> {
168 let out_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
169 B::int_all(tensor, out_dtype)
170 }
171
172 fn all_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
173 let out_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
174 B::int_all_dim(tensor, dim, out_dtype)
175 }
176
177 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
178 B::int_permute(tensor, axes)
179 }
180
181 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
182 B::int_expand(tensor, shape)
183 }
184
185 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
186 B::int_flip(tensor, axes)
187 }
188
189 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
190 B::int_unfold(tensor, dim, size, step)
191 }
192}
193
194impl<B: Backend> Numeric<B> for Int {
195 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
196 B::int_add(lhs, rhs)
197 }
198 fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
199 B::int_add_scalar(lhs, rhs)
200 }
201 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
202 B::int_sub(lhs, rhs)
203 }
204 fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
205 B::int_sub_scalar(lhs, rhs)
206 }
207 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
208 B::int_div(lhs, rhs)
209 }
210 fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
211 B::int_div_scalar(lhs, rhs)
212 }
213 fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
214 B::int_remainder(lhs, rhs)
215 }
216 fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
217 B::int_remainder_scalar(lhs, rhs)
218 }
219 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
220 B::int_mul(lhs, rhs)
221 }
222 fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
223 B::int_mul_scalar(lhs, rhs)
224 }
225 fn neg(tensor: Self::Primitive) -> Self::Primitive {
226 B::int_neg(tensor)
227 }
228
229 fn sum(tensor: Self::Primitive) -> Self::Primitive {
230 B::int_sum(tensor)
231 }
232
233 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
234 B::int_sum_dim(tensor, dim)
235 }
236
237 fn prod(tensor: Self::Primitive) -> Self::Primitive {
238 B::int_prod(tensor)
239 }
240
241 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
242 B::int_prod_dim(tensor, dim)
243 }
244
245 fn mean(tensor: Self::Primitive) -> Self::Primitive {
246 B::int_mean(tensor)
247 }
248 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
249 B::int_mean_dim(tensor, dim)
250 }
251 fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
252 B::int_cumsum(tensor, dim)
253 }
254 fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
255 B::int_cumprod(tensor, dim)
256 }
257
258 fn abs(tensor: Self::Primitive) -> Self::Primitive {
259 B::int_abs(tensor)
260 }
261
262 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
263 B::int_powi(lhs, rhs)
264 }
265
266 fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
267 B::int_powi_scalar(lhs, rhs)
268 }
269
270 fn random(
271 shape: Shape,
272 distribution: Distribution,
273 device: &Device<B>,
274 dtype: DType,
275 ) -> Self::Primitive {
276 B::int_random(shape, distribution, device, dtype.into())
277 }
278
279 fn sign(tensor: Self::Primitive) -> Self::Primitive {
280 B::int_sign(tensor)
281 }
282
283 fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
291 B::int_matmul(lhs, rhs)
292 }
293}
294
295impl<B: Backend> Ordered<B> for Int {
296 fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
297 B::int_sort(tensor, dim, descending)
298 }
299
300 fn sort_with_indices(
301 tensor: Self::Primitive,
302 dim: usize,
303 descending: bool,
304 ) -> (Self::Primitive, IntTensor<B>) {
305 B::int_sort_with_indices(tensor, dim, descending)
306 }
307
308 fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
309 B::int_argsort(tensor, dim, descending)
310 }
311
312 fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
313 B::int_cummin(tensor, dim)
314 }
315
316 fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
317 B::int_cummax(tensor, dim)
318 }
319
320 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
321 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
322 B::int_greater(lhs, rhs, out_dtype)
323 }
324
325 fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
326 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
327 B::int_greater_elem(lhs, rhs, out_dtype)
328 }
329
330 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
331 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
332 B::int_greater_equal(lhs, rhs, out_dtype)
333 }
334
335 fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
336 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
337 B::int_greater_equal_elem(lhs, rhs, out_dtype)
338 }
339
340 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
341 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
342 B::int_lower(lhs, rhs, out_dtype)
343 }
344
345 fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
346 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
347 B::int_lower_elem(lhs, rhs, out_dtype)
348 }
349
350 fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
351 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
352 B::int_lower_equal(lhs, rhs, out_dtype)
353 }
354
355 fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
356 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
357 B::int_lower_equal_elem(lhs, rhs, out_dtype)
358 }
359
360 fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
361 B::int_argmax(tensor, dim)
362 }
363
364 fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
365 B::int_argmin(tensor, dim)
366 }
367
368 fn max(tensor: Self::Primitive) -> Self::Primitive {
369 B::int_max(tensor)
370 }
371
372 fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
373 B::int_max_dim(tensor, dim)
374 }
375
376 fn max_dim_with_indices(
377 tensor: Self::Primitive,
378 dim: usize,
379 ) -> (Self::Primitive, IntTensor<B>) {
380 B::int_max_dim_with_indices(tensor, dim)
381 }
382
383 fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
384 B::int_max_abs(tensor)
385 }
386
387 fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
388 B::int_max_abs_dim(tensor, dim)
389 }
390
391 fn min(tensor: Self::Primitive) -> Self::Primitive {
392 B::int_min(tensor)
393 }
394
395 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
396 B::int_min_dim(tensor, dim)
397 }
398
399 fn min_dim_with_indices(
400 tensor: Self::Primitive,
401 dim: usize,
402 ) -> (Self::Primitive, IntTensor<B>) {
403 B::int_min_dim_with_indices(tensor, dim)
404 }
405
406 fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive {
407 B::int_clamp(tensor, min, max)
408 }
409
410 fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive {
411 B::int_clamp_min(tensor, min)
412 }
413
414 fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive {
415 B::int_clamp_max(tensor, max)
416 }
417}
418
419impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
420 type InnerKind = Int;
421
422 fn inner(
423 tensor: <Self as TensorKind<B>>::Primitive,
424 ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
425 B::int_inner(tensor)
426 }
427
428 fn from_inner(
429 inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
430 ) -> <Self as TensorKind<B>>::Primitive {
431 B::int_from_inner(inner)
432 }
433}