1use alloc::vec::Vec;
2use burn_std::{DType, Shape, Slice};
3
4use crate::{
5 AutodiffBackend, Backend, Distribution, ExecutionError, TensorData,
6 element::ElementConversion,
7 ops::TransactionPrimitive,
8 tensor::{
9 BasicAutodiffOps, BasicOps, BoolTensor, Device, IndexingUpdateOp, Int, IntTensor, Numeric,
10 TensorKind,
11 },
12};
13
14impl<B: Backend> BasicOps<B> for Int {
15 type Elem = B::IntElem;
16
17 fn empty(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
18 B::int_empty(shape, device, dtype.into())
19 }
20
21 fn zeros(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
22 B::int_zeros(shape, device, dtype.into())
23 }
24 fn ones(shape: Shape, device: &Device<B>, dtype: DType) -> Self::Primitive {
25 B::int_ones(shape, device, dtype.into())
26 }
27
28 fn full<E: ElementConversion>(
29 shape: Shape,
30 fill_value: E,
31 device: &Device<B>,
32 dtype: DType,
33 ) -> Self::Primitive {
34 B::int_full(shape, fill_value.elem(), device, dtype.into())
35 }
36
37 fn register_transaction(tr: &mut TransactionPrimitive<B>, tensor: Self::Primitive) {
38 tr.register_int(tensor);
39 }
40
41 fn reshape(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
42 B::int_reshape(tensor, shape)
43 }
44
45 fn transpose(tensor: Self::Primitive) -> Self::Primitive {
46 B::int_transpose(tensor)
47 }
48
49 fn swap_dims(tensor: Self::Primitive, dim1: usize, dim2: usize) -> Self::Primitive {
50 B::int_swap_dims(tensor, dim1, dim2)
51 }
52
53 fn slice(tensor: Self::Primitive, slices: &[Slice]) -> Self::Primitive {
54 B::int_slice(tensor, slices)
55 }
56
57 fn slice_assign(
58 tensor: Self::Primitive,
59 slices: &[Slice],
60 value: Self::Primitive,
61 ) -> Self::Primitive {
62 B::int_slice_assign(tensor, slices, value)
63 }
64
65 fn select(tensor: Self::Primitive, dim: usize, indices: IntTensor<B>) -> Self::Primitive {
66 B::int_select(tensor, dim, indices)
67 }
68
69 fn select_assign(
70 tensor: Self::Primitive,
71 dim: usize,
72 indices: IntTensor<B>,
73 values: Self::Primitive,
74 update: IndexingUpdateOp,
75 ) -> Self::Primitive {
76 match update {
77 IndexingUpdateOp::Add => B::int_select_add(tensor, dim, indices, values),
78 }
79 }
80
81 fn mask_where(
82 tensor: Self::Primitive,
83 mask: B::BoolTensorPrimitive,
84 source: Self::Primitive,
85 ) -> Self::Primitive {
86 B::int_mask_where(tensor, mask, source)
87 }
88
89 fn mask_fill(
90 tensor: Self::Primitive,
91 mask: B::BoolTensorPrimitive,
92 value: Self::Elem,
93 ) -> Self::Primitive {
94 B::int_mask_fill(tensor, mask, value)
95 }
96
97 fn gather(
98 dim: usize,
99 tensor: Self::Primitive,
100 indices: B::IntTensorPrimitive,
101 ) -> Self::Primitive {
102 B::int_gather(dim, tensor, indices)
103 }
104
105 fn scatter(
106 dim: usize,
107 tensor: Self::Primitive,
108 indices: B::IntTensorPrimitive,
109 values: Self::Primitive,
110 update: IndexingUpdateOp,
111 ) -> Self::Primitive {
112 match update {
113 IndexingUpdateOp::Add => B::int_scatter_add(dim, tensor, indices, values),
114 }
115 }
116
117 fn device(tensor: &Self::Primitive) -> Device<B> {
118 B::int_device(tensor)
119 }
120
121 fn to_device(tensor: Self::Primitive, device: &Device<B>) -> Self::Primitive {
122 B::int_to_device(tensor, device)
123 }
124
125 async fn into_data_async(tensor: Self::Primitive) -> Result<TensorData, ExecutionError> {
126 B::int_into_data(tensor).await
127 }
128
129 fn from_data(data: TensorData, device: &Device<B>) -> Self::Primitive {
130 B::int_from_data(data.convert::<B::IntElem>(), device)
131 }
132
133 fn from_data_dtype(data: TensorData, device: &Device<B>, dtype: DType) -> Self::Primitive {
134 if !dtype.is_int() {
135 panic!("Expected int dtype, got {dtype:?}")
136 }
137
138 B::int_from_data(data.convert_dtype(dtype), device)
139 }
140
141 fn repeat_dim(tensor: Self::Primitive, dim: usize, times: usize) -> Self::Primitive {
142 B::int_repeat_dim(tensor, dim, times)
143 }
144
145 fn equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
146 B::int_equal(lhs, rhs)
147 }
148
149 fn not_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> BoolTensor<B> {
150 B::int_not_equal(lhs, rhs)
151 }
152
153 fn equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
154 B::int_equal_elem(lhs, rhs)
155 }
156
157 fn not_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
158 B::int_not_equal_elem(lhs, rhs)
159 }
160
161 fn cat(vectors: Vec<Self::Primitive>, dim: usize) -> Self::Primitive {
162 B::int_cat(vectors, dim)
163 }
164
165 fn any(tensor: Self::Primitive) -> BoolTensor<B> {
166 B::int_any(tensor)
167 }
168
169 fn any_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
170 B::int_any_dim(tensor, dim)
171 }
172
173 fn all(tensor: Self::Primitive) -> BoolTensor<B> {
174 B::int_all(tensor)
175 }
176
177 fn all_dim(tensor: Self::Primitive, dim: usize) -> BoolTensor<B> {
178 B::int_all_dim(tensor, dim)
179 }
180
181 fn permute(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
182 B::int_permute(tensor, axes)
183 }
184
185 fn expand(tensor: Self::Primitive, shape: Shape) -> Self::Primitive {
186 B::int_expand(tensor, shape)
187 }
188
189 fn flip(tensor: Self::Primitive, axes: &[usize]) -> Self::Primitive {
190 B::int_flip(tensor, axes)
191 }
192
193 fn unfold(tensor: Self::Primitive, dim: usize, size: usize, step: usize) -> Self::Primitive {
194 B::int_unfold(tensor, dim, size, step)
195 }
196}
197
198impl<B: Backend> Numeric<B> for Int {
199 fn add(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
200 B::int_add(lhs, rhs)
201 }
202 fn add_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
203 B::int_add_scalar(lhs, rhs.elem())
204 }
205 fn sub(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
206 B::int_sub(lhs, rhs)
207 }
208 fn sub_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
209 B::int_sub_scalar(lhs, rhs.elem())
210 }
211 fn div(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
212 B::int_div(lhs, rhs)
213 }
214 fn div_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
215 B::int_div_scalar(lhs, rhs.elem())
216 }
217 fn remainder(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
218 B::int_remainder(lhs, rhs)
219 }
220 fn remainder_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
221 B::int_remainder_scalar(lhs, rhs.elem())
222 }
223 fn mul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
224 B::int_mul(lhs, rhs)
225 }
226 fn mul_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
227 B::int_mul_scalar(lhs, rhs.elem())
228 }
229 fn neg(tensor: Self::Primitive) -> Self::Primitive {
230 B::int_neg(tensor)
231 }
232
233 fn sum(tensor: Self::Primitive) -> Self::Primitive {
234 B::int_sum(tensor)
235 }
236
237 fn sum_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
238 B::int_sum_dim(tensor, dim)
239 }
240
241 fn prod(tensor: Self::Primitive) -> Self::Primitive {
242 B::int_prod(tensor)
243 }
244
245 fn prod_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
246 B::int_prod_dim(tensor, dim)
247 }
248
249 fn mean(tensor: Self::Primitive) -> Self::Primitive {
250 B::int_mean(tensor)
251 }
252 fn mean_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
253 B::int_mean_dim(tensor, dim)
254 }
255 fn cumsum(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
256 B::int_cumsum(tensor, dim)
257 }
258 fn cumprod(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
259 B::int_cumprod(tensor, dim)
260 }
261
262 fn cummin(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
263 B::int_cummin(tensor, dim)
264 }
265
266 fn cummax(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
267 B::int_cummax(tensor, dim)
268 }
269
270 fn greater(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
271 B::int_greater(lhs, rhs)
272 }
273
274 fn greater_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
275 B::int_greater_elem(lhs, rhs)
276 }
277
278 fn greater_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
279 B::int_greater_equal(lhs, rhs)
280 }
281
282 fn greater_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
283 B::int_greater_equal_elem(lhs, rhs)
284 }
285
286 fn lower(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
287 B::int_lower(lhs, rhs)
288 }
289
290 fn lower_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
291 B::int_lower_elem(lhs, rhs)
292 }
293
294 fn lower_equal(lhs: Self::Primitive, rhs: Self::Primitive) -> B::BoolTensorPrimitive {
295 B::int_lower_equal(lhs, rhs)
296 }
297
298 fn lower_equal_elem(lhs: Self::Primitive, rhs: Self::Elem) -> B::BoolTensorPrimitive {
299 B::int_lower_equal_elem(lhs, rhs)
300 }
301
302 fn argmax(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
303 B::int_argmax(tensor, dim)
304 }
305
306 fn argmin(tensor: Self::Primitive, dim: usize) -> IntTensor<B> {
307 B::int_argmin(tensor, dim)
308 }
309
310 fn max(tensor: Self::Primitive) -> Self::Primitive {
311 B::int_max(tensor)
312 }
313
314 fn max_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
315 B::int_max_dim(tensor, dim)
316 }
317
318 fn max_dim_with_indices(
319 tensor: Self::Primitive,
320 dim: usize,
321 ) -> (Self::Primitive, IntTensor<B>) {
322 B::int_max_dim_with_indices(tensor, dim)
323 }
324
325 fn max_abs(tensor: Self::Primitive) -> Self::Primitive {
326 B::int_max_abs(tensor)
327 }
328
329 fn max_abs_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
330 B::int_max_abs_dim(tensor, dim)
331 }
332
333 fn min(tensor: Self::Primitive) -> Self::Primitive {
334 B::int_min(tensor)
335 }
336
337 fn min_dim(tensor: Self::Primitive, dim: usize) -> Self::Primitive {
338 B::int_min_dim(tensor, dim)
339 }
340
341 fn min_dim_with_indices(
342 tensor: Self::Primitive,
343 dim: usize,
344 ) -> (Self::Primitive, IntTensor<B>) {
345 B::int_min_dim_with_indices(tensor, dim)
346 }
347
348 fn clamp(tensor: Self::Primitive, min: B::IntElem, max: B::IntElem) -> Self::Primitive {
349 B::int_clamp(tensor, min, max)
350 }
351
352 fn clamp_min(tensor: Self::Primitive, min: B::IntElem) -> Self::Primitive {
353 B::int_clamp_min(tensor, min)
354 }
355
356 fn clamp_max(tensor: Self::Primitive, max: B::IntElem) -> Self::Primitive {
357 B::int_clamp_max(tensor, max)
358 }
359
360 fn abs(tensor: Self::Primitive) -> Self::Primitive {
361 B::int_abs(tensor)
362 }
363
364 fn powf(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
365 B::int_powf(lhs, B::int_into_float(rhs))
366 }
367
368 fn powf_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
369 B::int_powf_scalar(lhs, rhs.elem())
370 }
371
372 fn powi(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
373 B::int_powi(lhs, rhs)
374 }
375
376 fn powi_scalar<E: ElementConversion>(lhs: Self::Primitive, rhs: E) -> Self::Primitive {
377 B::int_powi_scalar(lhs, rhs.elem())
378 }
379
380 fn random(shape: Shape, distribution: Distribution, device: &Device<B>) -> Self::Primitive {
381 B::int_random(shape, distribution, device)
382 }
383
384 fn sign(tensor: Self::Primitive) -> Self::Primitive {
385 B::int_sign(tensor)
386 }
387
388 fn sort(tensor: Self::Primitive, dim: usize, descending: bool) -> Self::Primitive {
389 B::int_sort(tensor, dim, descending)
390 }
391
392 fn sort_with_indices(
393 tensor: Self::Primitive,
394 dim: usize,
395 descending: bool,
396 ) -> (Self::Primitive, IntTensor<B>) {
397 B::int_sort_with_indices(tensor, dim, descending)
398 }
399
400 fn argsort(tensor: Self::Primitive, dim: usize, descending: bool) -> IntTensor<B> {
401 B::int_argsort(tensor, dim, descending)
402 }
403
404 fn matmul(lhs: Self::Primitive, rhs: Self::Primitive) -> Self::Primitive {
412 B::int_matmul(lhs, rhs)
413 }
414}
415
416impl<B: AutodiffBackend> BasicAutodiffOps<B> for Int {
417 type InnerKind = Int;
418
419 fn inner(
420 tensor: <Self as TensorKind<B>>::Primitive,
421 ) -> <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive {
422 B::int_inner(tensor)
423 }
424
425 fn from_inner(
426 inner: <Self::InnerKind as TensorKind<<B as AutodiffBackend>::InnerBackend>>::Primitive,
427 ) -> <Self as TensorKind<B>>::Primitive {
428 B::int_from_inner(inner)
429 }
430}