1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5 AutodiffBackend, Backend, Distribution, ExecutionError, Scalar, TensorData,
6 ops::TransactionPrimitive,
7 tensor::{
8 BasicAutodiffOps, BasicOps, BoolTensor, Device, IndexingUpdateOp, Int, IntTensor, Numeric,
9 Ordered, TensorKind,
10 },
11};
12
13impl<B: Backend> BasicOps<B> for Int {
14 type Elem = B::IntElem;
15
16 fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
17 B::int_empty(shape, device, dtype.into())
18 }
19
20 fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
21 B::int_zeros(shape, device, dtype.into())
22 }
23 fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
24 B::int_ones(shape, device, dtype.into())
25 }
26
27 fn full(shape: Shape, fill_value: Scalar, device: &Device<B>, dtype: DType) -> Self::Primitive {
28 B::int_full(shape, fill_value, device, dtype.into())
29 }
30
31 fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
32 tr.register_int(tensor);
33 }
34
35 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
36 B::int_reshape(tensor, shape)
37 }
38
39 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
40 B::int_transpose(tensor)
41 }
42
43 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
44 B::int_swap_dims(tensor, dim1, dim2)
45 }
46
47 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
48 B::int_slice(tensor, slices)
49 }
50
51 fn slice_assign(
52 tensor: Self::Primitive,
53 slices: &[Slice],
54 value: Self::Primitive,
55 ) -> Self::Primitive {
56 B::int_slice_assign(tensor, slices, value)
57 }
58
59 fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
60 B::int_select(tensor, dim, indices)
61 }
62
63 fn select_assign(
64 tensor: Self::Primitive,
65 dim: usize,
66 indices: IntTensor<B>,
67 values: Self::Primitive,
68 update: IndexingUpdateOp,
69 ) -> Self::Primitive {
70 match update {
71 IndexingUpdateOp::Add => B::int_select_add(tensor, dim, indices, values),
72 }
73 }
74
75 fn mask_where(
76 tensor: Self::Primitive,
77 mask: B::BoolTensorPrimitive,
78 source: Self::Primitive,
79 ) -> Self::Primitive {
80 B::int_mask_where(tensor, mask, source)
81 }
82
83 fn mask_fill(
84 tensor: Self::Primitive,
85 mask: B::BoolTensorPrimitive,
86 value: Scalar,
87 ) -> Self::Primitive {
88 B::int_mask_fill(tensor, mask, value)
89 }
90
91 fn gather(
92 dim: usize,
93 tensor: Self::Primitive,
94 indices: B::IntTensorPrimitive,
95 ) -> Self::Primitive {
96 B::int_gather(dim, tensor, indices)
97 }
98
99 fn scatter(
100 dim: usize,
101 tensor: Self::Primitive,
102 indices: B::IntTensorPrimitive,
103 values: Self::Primitive,
104 update: IndexingUpdateOp,
105 ) -> Self::Primitive {
106 match update {
107 IndexingUpdateOp::Add => B::int_scatter_add(dim, tensor, indices, values),
108 }
109 }
110
111 fn device(tensor: &Self::Primitive) -> Device<B> {
112 B::int_device(tensor)
113 }
114
115 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
116 B::int_to_device(tensor, device)
117 }
118
119 async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
120 B::int_into_data(tensor).await
121 }
122
123 fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {
124 B::int_from_data(data.convert::<B::IntElem>(), device)
125 }
126
127 fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
128 if !dtype.is_int() {
129 panic!("Expected int dtype, got {dtype:?}")
130 }
131
132 B::int_from_data(data.convert_dtype(dtype), device)
133 }
134
135 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
136 B::int_repeat_dim(tensor, dim, times)
137 }
138
139 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
140 B::int_equal(lhs, rhs)
141 }
142
143 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
144 B::int_not_equal(lhs, rhs)
145 }
146
147 fn equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
148 B::int_equal_elem(lhs, rhs)
149 }
150
151 fn not_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
152 B::int_not_equal_elem(lhs, rhs)
153 }
154
155 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
156 B::int_cat(vectors, dim)
157 }
158
159 fn any(tensor: Self::Primitive) -> BoolTensor<B> {
160 B::int_any(tensor)
161 }
162
163 fn any_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
164 B::int_any_dim(tensor, dim)
165 }
166
167 fn all(tensor: Self::Primitive) -> BoolTensor<B> {
168 B::int_all(tensor)
169 }
170
171 fn all_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
172 B::int_all_dim(tensor, dim)
173 }
174
175 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
176 B::int_permute(tensor, axes)
177 }
178
179 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
180 B::int_expand(tensor, shape)
181 }
182
183 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
184 B::int_flip(tensor, axes)
185 }
186
187 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
188 B::int_unfold(tensor, dim, size, step)
189 }
190}
191
192impl<B: Backend> Numeric<B> for Int {
193 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
194 B::int_add(lhs, rhs)
195 }
196 fn add_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
197 B::int_add_scalar(lhs, rhs)
198 }
199 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
200 B::int_sub(lhs, rhs)
201 }
202 fn sub_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
203 B::int_sub_scalar(lhs, rhs)
204 }
205 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
206 B::int_div(lhs, rhs)
207 }
208 fn div_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
209 B::int_div_scalar(lhs, rhs)
210 }
211 fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
212 B::int_remainder(lhs, rhs)
213 }
214 fn remainder_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
215 B::int_remainder_scalar(lhs, rhs)
216 }
217 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
218 B::int_mul(lhs, rhs)
219 }
220 fn mul_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
221 B::int_mul_scalar(lhs, rhs)
222 }
223 fn neg(tensor: Self::Primitive) -> Self::Primitive {
224 B::int_neg(tensor)
225 }
226
227 fn sum(tensor: Self::Primitive) -> Self::Primitive {
228 B::int_sum(tensor)
229 }
230
231 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
232 B::int_sum_dim(tensor, dim)
233 }
234
235 fn prod(tensor: Self::Primitive) -> Self::Primitive {
236 B::int_prod(tensor)
237 }
238
239 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
240 B::int_prod_dim(tensor, dim)
241 }
242
243 fn mean(tensor: Self::Primitive) -> Self::Primitive {
244 B::int_mean(tensor)
245 }
246 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
247 B::int_mean_dim(tensor, dim)
248 }
249 fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
250 B::int_cumsum(tensor, dim)
251 }
252 fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
253 B::int_cumprod(tensor, dim)
254 }
255
256 fn abs(tensor: Self::Primitive) -> Self::Primitive {
257 B::int_abs(tensor)
258 }
259
260 fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
261 B::int_powf(lhs, B::int_into_float(rhs))
262 }
263
264 fn powf_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
265 B::int_powf_scalar(lhs, rhs)
266 }
267
268 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
269 B::int_powi(lhs, rhs)
270 }
271
272 fn powi_scalar(lhs: Self::Primitive, rhs: Scalar) -> Self::Primitive {
273 B::int_powi_scalar(lhs, rhs)
274 }
275
276 fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
277 B::int_random(shape, distribution, device)
278 }
279
280 fn sign(tensor: Self::Primitive) -> Self::Primitive {
281 B::int_sign(tensor)
282 }
283
284 fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
292 B::int_matmul(lhs, rhs)
293 }
294}
295
296impl<B: Backend> Ordered<B> for Int {
297 fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
298 B::int_sort(tensor, dim, descending)
299 }
300
301 fn sort_with_indices(
302 tensor: Self::Primitive,
303 dim: usize,
304 descending: bool,
305 ) -> (Self::Primitive, IntTensor<B>) {
306 B::int_sort_with_indices(tensor, dim, descending)
307 }
308
309 fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
310 B::int_argsort(tensor, dim, descending)
311 }
312
313 fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
314 B::int_cummin(tensor, dim)
315 }
316
317 fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
318 B::int_cummax(tensor, dim)
319 }
320
321 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
322 B::int_greater(lhs, rhs)
323 }
324
325 fn greater_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
326 B::int_greater_elem(lhs, rhs)
327 }
328
329 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
330 B::int_greater_equal(lhs, rhs)
331 }
332
333 fn greater_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
334 B::int_greater_equal_elem(lhs, rhs)
335 }
336
337 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
338 B::int_lower(lhs, rhs)
339 }
340
341 fn lower_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
342 B::int_lower_elem(lhs, rhs)
343 }
344
345 fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
346 B::int_lower_equal(lhs, rhs)
347 }
348
349 fn lower_equal_elem(lhs: Self::Primitive, rhs: Scalar) -> B::BoolTensorPrimitive {
350 B::int_lower_equal_elem(lhs, rhs)
351 }
352
353 fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
354 B::int_argmax(tensor, dim)
355 }
356
357 fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
358 B::int_argmin(tensor, dim)
359 }
360
361 fn max(tensor: Self::Primitive) -> Self::Primitive {
362 B::int_max(tensor)
363 }
364
365 fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
366 B::int_max_dim(tensor, dim)
367 }
368
369 fn max_dim_with_indices(
370 tensor: Self::Primitive,
371 dim: usize,
372 ) -> (Self::Primitive, IntTensor<B>) {
373 B::int_max_dim_with_indices(tensor, dim)
374 }
375
376 fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
377 B::int_max_abs(tensor)
378 }
379
380 fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
381 B::int_max_abs_dim(tensor, dim)
382 }
383
384 fn min(tensor: Self::Primitive) -> Self::Primitive {
385 B::int_min(tensor)
386 }
387
388 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
389 B::int_min_dim(tensor, dim)
390 }
391
392 fn min_dim_with_indices(
393 tensor: Self::Primitive,
394 dim: usize,
395 ) -> (Self::Primitive, IntTensor<B>) {
396 B::int_min_dim_with_indices(tensor, dim)
397 }
398
399 fn clamp(tensor: Self::Primitive, min: Scalar, max: Scalar) -> Self::Primitive {
400 B::int_clamp(tensor, min, max)
401 }
402
403 fn clamp_min(tensor: Self::Primitive, min: Scalar) -> Self::Primitive {
404 B::int_clamp_min(tensor, min)
405 }
406
407 fn clamp_max(tensor: Self::Primitive, max: Scalar) -> Self::Primitive {
408 B::int_clamp_max(tensor, max)
409 }
410}
411
412impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
413 type InnerKind = Int;
414
415 fn inner(
416 tensor: <Self as TensorKind<B>>::Primitive,
417 ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
418 B::int_inner(tensor)
419 }
420
421 fn from_inner(
422 inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
423 ) -> <Self as TensorKind<B>>::Primitive {
424 B::int_from_inner(inner)
425 }
426}