1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5 AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData,
6 get_device_settings,
7 ops::TransactionPrimitive,
8 tensor::{
9 BasicAutodiffOps, BasicOps, BoolTensor, Device, IndexingUpdateOp, Int, IntTensor, Numeric,
10 Ordered, TensorKind, TransactionOp,
11 },
12};
13
14impl<B: Backend> TransactionOp<B> for Int {
15 fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
16 tr.register_int(tensor);
17 }
18}
19impl<B: Backend> BasicOps<B> for Int {
20 type Elem = B::IntElem;
21
22 fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
23 B::int_empty(shape, device, dtype.into())
24 }
25
26 fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
27 B::int_zeros(shape, device, dtype.into())
28 }
29 fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
30 B::int_ones(shape, device, dtype.into())
31 }
32
33 fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
34 B::int_full(shape, fill_value, device, dtype.into())
35 }
36
37 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
38 B::int_reshape(tensor, shape)
39 }
40
41 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
42 B::int_transpose(tensor)
43 }
44
45 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
46 B::int_swap_dims(tensor, dim1, dim2)
47 }
48
49 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
50 B::int_slice(tensor, slices)
51 }
52
53 fn slice_assign(
54 tensor: Self::Primitive,
55 slices: &[Slice],
56 value: Self::Primitive,
57 ) -> Self::Primitive {
58 B::int_slice_assign(tensor, slices, value)
59 }
60
61 fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
62 B::int_select(tensor, dim, indices)
63 }
64
65 fn select_assign(
66 tensor: Self::Primitive,
67 dim: usize,
68 indices: IntTensor<B>,
69 values: Self::Primitive,
70 update: IndexingUpdateOp,
71 ) -> Self::Primitive {
72 match update {
73 IndexingUpdateOp::Add => B::int_select_add(tensor, dim, indices, values),
74 _ => unimplemented!(),
75 }
76 }
77
78 fn mask_where(
79 tensor: Self::Primitive,
80 mask: B::BoolTensorPrimitive,
81 source: Self::Primitive,
82 ) -> Self::Primitive {
83 B::int_mask_where(tensor, mask, source)
84 }
85
86 fn mask_fill(
87 tensor: Self::Primitive,
88 mask: B::BoolTensorPrimitive,
89 value: Scalar,
90 ) -> Self::Primitive {
91 B::int_mask_fill(tensor, mask, value)
92 }
93
94 fn gather(
95 dim: usize,
96 tensor: Self::Primitive,
97 indices: B::IntTensorPrimitive,
98 ) -> Self::Primitive {
99 B::int_gather(dim, tensor, indices)
100 }
101
102 fn scatter(
103 dim: usize,
104 tensor: Self::Primitive,
105 indices: B::IntTensorPrimitive,
106 values: Self::Primitive,
107 update: IndexingUpdateOp,
108 ) -> Self::Primitive {
109 match update {
110 IndexingUpdateOp::Add => B::int_scatter_add(dim, tensor, indices, values),
111 _ => unimplemented!(),
112 }
113 }
114
115 fn scatter_nd(
116 data: Self::Primitive,
117 indices: IntTensor<B>,
118 values: Self::Primitive,
119 reduction: IndexingUpdateOp,
120 ) -> Self::Primitive {
121 B::int_scatter_nd(data, indices, values, reduction)
122 }
123
124 fn gather_nd(data: Self::Primitive, indices: IntTensor<B>) -> Self::Primitive {
125 B::int_gather_nd(data, indices)
126 }
127
128 fn device(tensor: &Self::Primitive) -> Device<B> {
129 B::int_device(tensor)
130 }
131
132 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
133 B::int_to_device(tensor, device)
134 }
135
136 async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
137 B::int_into_data(tensor).await
138 }
139
140 fn from_data(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
141 B::int_from_data(data.convert_dtype(dtype), device)
142 }
143
144 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
145 B::int_repeat_dim(tensor, dim, times)
146 }
147
148 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
149 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
150 B::int_equal(lhs, rhs, out_dtype)
151 }
152
153 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
154 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
155 B::int_not_equal(lhs, rhs, out_dtype)
156 }
157
158 fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
159 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
160 B::int_equal_elem(lhs, rhs, out_dtype)
161 }
162
163 fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
164 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
165 B::int_not_equal_elem(lhs, rhs, out_dtype)
166 }
167
168 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
169 B::int_cat(vectors, dim)
170 }
171
172 fn any(tensor: Self::Primitive) -> BoolTensor<B> {
173 let out_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
174 B::int_any(tensor, out_dtype)
175 }
176
177 fn any_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
178 let out_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
179 B::int_any_dim(tensor, dim, out_dtype)
180 }
181
182 fn all(tensor: Self::Primitive) -> BoolTensor<B> {
183 let out_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
184 B::int_all(tensor, out_dtype)
185 }
186
187 fn all_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
188 let out_dtype = get_device_settings::<B>(&B::int_device(&tensor)).bool_dtype;
189 B::int_all_dim(tensor, dim, out_dtype)
190 }
191
192 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
193 B::int_permute(tensor, axes)
194 }
195
196 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
197 B::int_expand(tensor, shape)
198 }
199
200 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
201 B::int_flip(tensor, axes)
202 }
203
204 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
205 B::int_unfold(tensor, dim, size, step)
206 }
207}
208
209impl<B: Backend> Numeric<B> for Int {
210 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
211 B::int_add(lhs, rhs)
212 }
213 fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
214 B::int_add_scalar(lhs, rhs)
215 }
216 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
217 B::int_sub(lhs, rhs)
218 }
219 fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
220 B::int_sub_scalar(lhs, rhs)
221 }
222 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
223 B::int_div(lhs, rhs)
224 }
225 fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
226 B::int_div_scalar(lhs, rhs)
227 }
228 fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
229 B::int_remainder(lhs, rhs)
230 }
231 fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
232 B::int_remainder_scalar(lhs, rhs)
233 }
234 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
235 B::int_mul(lhs, rhs)
236 }
237 fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
238 B::int_mul_scalar(lhs, rhs)
239 }
240 fn neg(tensor: Self::Primitive) -> Self::Primitive {
241 B::int_neg(tensor)
242 }
243
244 fn sum(tensor: Self::Primitive) -> Self::Primitive {
245 B::int_sum(tensor)
246 }
247
248 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
249 B::int_sum_dim(tensor, dim)
250 }
251
252 fn prod(tensor: Self::Primitive) -> Self::Primitive {
253 B::int_prod(tensor)
254 }
255
256 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
257 B::int_prod_dim(tensor, dim)
258 }
259
260 fn mean(tensor: Self::Primitive) -> Self::Primitive {
261 B::int_mean(tensor)
262 }
263 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
264 B::int_mean_dim(tensor, dim)
265 }
266 fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
267 B::int_cumsum(tensor, dim)
268 }
269 fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
270 B::int_cumprod(tensor, dim)
271 }
272
273 fn abs(tensor: Self::Primitive) -> Self::Primitive {
274 B::int_abs(tensor)
275 }
276
277 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
278 B::int_powi(lhs, rhs)
279 }
280
281 fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
282 B::int_powi_scalar(lhs, rhs)
283 }
284
285 fn random(
286 shape: Shape,
287 distribution: Distribution,
288 device: &Device<B>,
289 dtype: DType,
290 ) -> Self::Primitive {
291 B::int_random(shape, distribution, device, dtype.into())
292 }
293
294 fn sign(tensor: Self::Primitive) -> Self::Primitive {
295 B::int_sign(tensor)
296 }
297
298 fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
306 B::int_matmul(lhs, rhs)
307 }
308}
309
310impl<B: Backend> Ordered<B> for Int {
311 fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
312 B::int_sort(tensor, dim, descending)
313 }
314
315 fn sort_with_indices(
316 tensor: Self::Primitive,
317 dim: usize,
318 descending: bool,
319 ) -> (Self::Primitive, IntTensor<B>) {
320 B::int_sort_with_indices(tensor, dim, descending)
321 }
322
323 fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
324 B::int_argsort(tensor, dim, descending)
325 }
326
327 fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
328 B::int_cummin(tensor, dim)
329 }
330
331 fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
332 B::int_cummax(tensor, dim)
333 }
334
335 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
336 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
337 B::int_greater(lhs, rhs, out_dtype)
338 }
339
340 fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
341 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
342 B::int_greater_elem(lhs, rhs, out_dtype)
343 }
344
345 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
346 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
347 B::int_greater_equal(lhs, rhs, out_dtype)
348 }
349
350 fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
351 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
352 B::int_greater_equal_elem(lhs, rhs, out_dtype)
353 }
354
355 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
356 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
357 B::int_lower(lhs, rhs, out_dtype)
358 }
359
360 fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
361 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
362 B::int_lower_elem(lhs, rhs, out_dtype)
363 }
364
365 fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
366 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
367 B::int_lower_equal(lhs, rhs, out_dtype)
368 }
369
370 fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
371 let out_dtype = get_device_settings::<B>(&B::int_device(&lhs)).bool_dtype;
372 B::int_lower_equal_elem(lhs, rhs, out_dtype)
373 }
374
375 fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
376 B::int_argmax(tensor, dim)
377 }
378
379 fn argtopk(tensor: Self::Primitive, dim: usize, k: usize) -> IntTensor<B> {
380 B::int_argtopk(tensor, dim, k)
381 }
382
383 fn topk(tensor: Self::Primitive, dim: usize, k: usize) -> IntTensor<B> {
384 B::int_topk(tensor, dim, k)
385 }
386
387 fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
388 B::int_argmin(tensor, dim)
389 }
390
391 fn max(tensor: Self::Primitive) -> Self::Primitive {
392 B::int_max(tensor)
393 }
394
395 fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
396 B::int_max_dim(tensor, dim)
397 }
398
399 fn max_dim_with_indices(
400 tensor: Self::Primitive,
401 dim: usize,
402 ) -> (Self::Primitive, IntTensor<B>) {
403 B::int_max_dim_with_indices(tensor, dim)
404 }
405
406 fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
407 B::int_max_abs(tensor)
408 }
409
410 fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
411 B::int_max_abs_dim(tensor, dim)
412 }
413
414 fn min(tensor: Self::Primitive) -> Self::Primitive {
415 B::int_min(tensor)
416 }
417
418 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
419 B::int_min_dim(tensor, dim)
420 }
421
422 fn min_dim_with_indices(
423 tensor: Self::Primitive,
424 dim: usize,
425 ) -> (Self::Primitive, IntTensor<B>) {
426 B::int_min_dim_with_indices(tensor, dim)
427 }
428
429 fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive {
430 B::int_clamp(tensor, min, max)
431 }
432
433 fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive {
434 B::int_clamp_min(tensor, min)
435 }
436
437 fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive {
438 B::int_clamp_max(tensor, max)
439 }
440}
441
442impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
443 type InnerKind = Int;
444
445 fn inner(
446 tensor: <Self as TensorKind<B>>::Primitive,
447 ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
448 B::int_inner(tensor)
449 }
450
451 fn from_inner(
452 inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
453 ) -> <Self as TensorKind<B>>::Primitive {
454 B::int_from_inner(inner)
455 }
456}