burn_jit/ops/int_ops.rs

1use super::{expand, numeric, permute};
2use crate::kernel::{launch_unary_numeric, reduce, NumericUnaryOp, NumericUnaryOpFamily};
3use crate::{
4    element::BoolElement,
5    kernel::prng::{random_bernoulli, random_normal, random_uniform},
6};
7use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime};
8use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
9use burn_tensor::{ops::IntTensorOps, Distribution, ElementConversion, Shape, TensorData};
10use cubecl::frontend::Numeric;
11use cubecl::prelude::*;
12use std::ops::Range;
13
14impl<R, F, I, BT> IntTensorOps<Self> for JitBackend<R, F, I, BT>
15where
16    R: JitRuntime,
17    F: FloatElement,
18    I: IntElement,
19    BT: BoolElement,
20{
21    fn int_empty(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
22        super::empty::<R, I>(shape, device)
23    }
24
25    async fn int_into_data(tensor: IntTensor<Self>) -> TensorData {
26        super::into_data::<R, I>(tensor).await
27    }
28
29    fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
30        super::from_data::<R, I>(data, device)
31    }
32
33    fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
34        tensor.device.clone()
35    }
36
37    fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
38        super::to_device(tensor, device)
39    }
40
41    fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
42        super::reshape(tensor, shape)
43    }
44
45    fn int_slice(tensor: IntTensor<Self>, ranges: &[Range<usize>]) -> IntTensor<Self> {
46        kernel::slice::<R, I>(tensor, ranges)
47    }
48
49    fn int_slice_assign(
50        tensor: IntTensor<Self>,
51        ranges: &[Range<usize>],
52        value: IntTensor<Self>,
53    ) -> IntTensor<Self> {
54        kernel::slice_assign::<R, I>(tensor, ranges, value)
55    }
56
57    fn int_mask_where(
58        tensor: IntTensor<Self>,
59        mask: BoolTensor<Self>,
60        value: IntTensor<Self>,
61    ) -> IntTensor<Self> {
62        kernel::mask_where_auto::<R, I, BT>(tensor, mask, value)
63    }
64
65    fn int_mask_fill(
66        tensor: IntTensor<Self>,
67        mask: BoolTensor<Self>,
68        value: IntElem<Self>,
69    ) -> IntTensor<Self> {
70        kernel::mask_fill_auto::<R, I, BT>(tensor, mask, value)
71    }
72
73    fn int_gather(
74        dim: usize,
75        tensor: IntTensor<Self>,
76        indices: IntTensor<Self>,
77    ) -> IntTensor<Self> {
78        kernel::gather::<R, I, I>(dim, tensor, indices)
79    }
80
81    fn int_scatter(
82        dim: usize,
83        tensor: IntTensor<Self>,
84        indices: IntTensor<Self>,
85        value: IntTensor<Self>,
86    ) -> IntTensor<Self> {
87        kernel::scatter::<R, I, I>(dim, tensor, indices, value)
88    }
89
90    fn int_select(
91        tensor: IntTensor<Self>,
92        dim: usize,
93        indices: IntTensor<Self>,
94    ) -> IntTensor<Self> {
95        kernel::select::<R, I, I>(tensor, dim, indices)
96    }
97
98    fn int_select_assign(
99        tensor: IntTensor<Self>,
100        dim: usize,
101        indices: IntTensor<Self>,
102        value: IntTensor<Self>,
103    ) -> IntTensor<Self> {
104        kernel::select_assign::<R, I, I>(tensor, dim, indices, value)
105    }
106
107    fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
108        kernel::equal::<R, I, BT>(lhs, rhs)
109    }
110
111    fn int_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
112        kernel::equal_elem::<R, I, BT>(lhs, rhs)
113    }
114
115    fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
116        kernel::greater::<R, I, BT>(lhs, rhs)
117    }
118
119    fn int_greater_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
120        kernel::greater_elem::<R, I, BT>(lhs, rhs)
121    }
122
123    fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
124        kernel::greater_equal::<R, I, BT>(lhs, rhs)
125    }
126
127    fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
128        kernel::greater_equal_elem::<R, I, BT>(lhs, rhs)
129    }
130
131    fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
132        kernel::lower::<R, I, BT>(lhs, rhs)
133    }
134
135    fn int_lower_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
136        kernel::lower_elem::<R, I, BT>(lhs, rhs)
137    }
138
139    fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
140        kernel::lower_equal::<R, I, BT>(lhs, rhs)
141    }
142
143    fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
144        kernel::lower_equal_elem::<R, I, BT>(lhs, rhs)
145    }
146
147    fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
148        numeric::add::<R, I>(lhs, rhs)
149    }
150
151    fn int_add_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
152        numeric::add_scalar::<R, I>(lhs, rhs)
153    }
154
155    fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
156        numeric::sub::<R, I>(lhs, rhs)
157    }
158
159    fn int_sub_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
160        numeric::sub_scalar::<R, I>(lhs, rhs)
161    }
162
163    fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
164        numeric::mul::<R, I>(lhs, rhs)
165    }
166
167    fn int_mul_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
168        numeric::mul_scalar::<R, I>(lhs, rhs)
169    }
170
171    fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
172        numeric::div::<R, I>(lhs, rhs)
173    }
174
175    fn int_div_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
176        numeric::div_scalar::<R, I>(lhs, rhs)
177    }
178
179    fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
180        numeric::remainder::<R, I>(lhs, rhs)
181    }
182
183    fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
184        numeric::remainder_scalar::<R, I>(lhs, rhs)
185    }
186
187    fn int_zeros(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
188        numeric::zeros::<R, I>(shape, device)
189    }
190
191    fn int_ones(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
192        numeric::ones::<R, I>(shape, device)
193    }
194
195    fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
196        reduce::reduce::<R, I, I, reduce::Sum>(tensor, Default::default()).unwrap()
197    }
198
199    fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
200        reduce::reduce_dim::<R, I, I, reduce::Sum>(tensor, dim, Default::default()).unwrap()
201    }
202
203    fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
204        reduce::reduce::<R, I, I, reduce::Prod>(tensor, Default::default()).unwrap()
205    }
206
207    fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
208        reduce::reduce_dim::<R, I, I, reduce::Prod>(tensor, dim, Default::default()).unwrap()
209    }
210
211    fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
212        reduce::reduce_dim::<R, I, I, reduce::Mean>(tensor, dim, Default::default()).unwrap()
213    }
214
215    fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
216        reduce::reduce_dim::<R, I, I, reduce::ArgMax>(tensor, dim, Default::default()).unwrap()
217    }
218
219    fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
220        reduce::reduce_dim::<R, I, I, reduce::ArgMin>(tensor, dim, Default::default()).unwrap()
221    }
222
223    fn int_clamp(
224        tensor: IntTensor<Self>,
225        min: IntElem<Self>,
226        max: IntElem<Self>,
227    ) -> IntTensor<Self> {
228        kernel::clamp::<R, I>(tensor, min, max)
229    }
230
231    fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
232        struct Abs;
233
234        #[cube]
235        impl<N: Numeric> NumericUnaryOp<N> for Abs {
236            type Options = ();
237
238            fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
239                Line::abs(input)
240            }
241        }
242
243        impl NumericUnaryOpFamily for Abs {
244            type Options<N: Numeric> = ();
245            type Unary<N: Numeric> = Self;
246        }
247
248        launch_unary_numeric::<R, I, Abs, _>(tensor, |_| ())
249    }
250
251    fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
252        kernel::cast::<R, I, F>(tensor)
253    }
254
255    fn int_swap_dims(mut tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
256        tensor.strides.swap(dim1, dim2);
257        tensor.shape.dims.swap(dim1, dim2);
258
259        tensor
260    }
261
262    fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
263        kernel::repeat_dim::<R, I>(tensor, dim, times)
264    }
265
266    fn int_random(
267        shape: Shape,
268        distribution: Distribution,
269        device: &Device<Self>,
270    ) -> IntTensor<Self> {
271        let float_tensor = match distribution {
272            Distribution::Default => random_uniform(shape, device, 0.elem::<F>(), 255.elem()),
273            Distribution::Uniform(low, high) => {
274                random_uniform(shape, device, low.elem::<F>(), high.elem())
275            }
276            Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob.elem::<F>()),
277            Distribution::Normal(mean, std) => {
278                random_normal(shape, device, mean.elem::<F>(), std.elem())
279            }
280        };
281
282        kernel::cast::<R, F, I>(float_tensor)
283    }
284
285    fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
286        permute(tensor, axes)
287    }
288
289    fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
290        expand(tensor, shape)
291    }
292
293    fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
294        kernel::flip::<R, I, BT>(tensor, axes)
295    }
296}