1use super::{expand, numeric, permute};
2use crate::kernel::{launch_unary_numeric, reduce, NumericUnaryOp, NumericUnaryOpFamily};
3use crate::{
4 element::BoolElement,
5 kernel::prng::{random_bernoulli, random_normal, random_uniform},
6};
7use crate::{kernel, FloatElement, IntElement, JitBackend, JitRuntime};
8use burn_tensor::ops::{BoolTensor, Device, FloatTensor, IntElem, IntTensor};
9use burn_tensor::{ops::IntTensorOps, Distribution, ElementConversion, Shape, TensorData};
10use cubecl::frontend::Numeric;
11use cubecl::prelude::*;
12use std::ops::Range;
13
14impl<R, F, I, BT> IntTensorOps<Self> for JitBackend<R, F, I, BT>
15where
16 R: JitRuntime,
17 F: FloatElement,
18 I: IntElement,
19 BT: BoolElement,
20{
21 fn int_empty(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
22 super::empty::<R, I>(shape, device)
23 }
24
25 async fn int_into_data(tensor: IntTensor<Self>) -> TensorData {
26 super::into_data::<R, I>(tensor).await
27 }
28
29 fn int_from_data(data: TensorData, device: &Device<Self>) -> IntTensor<Self> {
30 super::from_data::<R, I>(data, device)
31 }
32
33 fn int_device(tensor: &IntTensor<Self>) -> Device<Self> {
34 tensor.device.clone()
35 }
36
37 fn int_to_device(tensor: IntTensor<Self>, device: &Device<Self>) -> IntTensor<Self> {
38 super::to_device(tensor, device)
39 }
40
41 fn int_reshape(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
42 super::reshape(tensor, shape)
43 }
44
45 fn int_slice(tensor: IntTensor<Self>, ranges: &[Range<usize>]) -> IntTensor<Self> {
46 kernel::slice::<R, I>(tensor, ranges)
47 }
48
49 fn int_slice_assign(
50 tensor: IntTensor<Self>,
51 ranges: &[Range<usize>],
52 value: IntTensor<Self>,
53 ) -> IntTensor<Self> {
54 kernel::slice_assign::<R, I>(tensor, ranges, value)
55 }
56
57 fn int_mask_where(
58 tensor: IntTensor<Self>,
59 mask: BoolTensor<Self>,
60 value: IntTensor<Self>,
61 ) -> IntTensor<Self> {
62 kernel::mask_where_auto::<R, I, BT>(tensor, mask, value)
63 }
64
65 fn int_mask_fill(
66 tensor: IntTensor<Self>,
67 mask: BoolTensor<Self>,
68 value: IntElem<Self>,
69 ) -> IntTensor<Self> {
70 kernel::mask_fill_auto::<R, I, BT>(tensor, mask, value)
71 }
72
73 fn int_gather(
74 dim: usize,
75 tensor: IntTensor<Self>,
76 indices: IntTensor<Self>,
77 ) -> IntTensor<Self> {
78 kernel::gather::<R, I, I>(dim, tensor, indices)
79 }
80
81 fn int_scatter(
82 dim: usize,
83 tensor: IntTensor<Self>,
84 indices: IntTensor<Self>,
85 value: IntTensor<Self>,
86 ) -> IntTensor<Self> {
87 kernel::scatter::<R, I, I>(dim, tensor, indices, value)
88 }
89
90 fn int_select(
91 tensor: IntTensor<Self>,
92 dim: usize,
93 indices: IntTensor<Self>,
94 ) -> IntTensor<Self> {
95 kernel::select::<R, I, I>(tensor, dim, indices)
96 }
97
98 fn int_select_assign(
99 tensor: IntTensor<Self>,
100 dim: usize,
101 indices: IntTensor<Self>,
102 value: IntTensor<Self>,
103 ) -> IntTensor<Self> {
104 kernel::select_assign::<R, I, I>(tensor, dim, indices, value)
105 }
106
107 fn int_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
108 kernel::equal::<R, I, BT>(lhs, rhs)
109 }
110
111 fn int_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
112 kernel::equal_elem::<R, I, BT>(lhs, rhs)
113 }
114
115 fn int_greater(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
116 kernel::greater::<R, I, BT>(lhs, rhs)
117 }
118
119 fn int_greater_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
120 kernel::greater_elem::<R, I, BT>(lhs, rhs)
121 }
122
123 fn int_greater_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
124 kernel::greater_equal::<R, I, BT>(lhs, rhs)
125 }
126
127 fn int_greater_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
128 kernel::greater_equal_elem::<R, I, BT>(lhs, rhs)
129 }
130
131 fn int_lower(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
132 kernel::lower::<R, I, BT>(lhs, rhs)
133 }
134
135 fn int_lower_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
136 kernel::lower_elem::<R, I, BT>(lhs, rhs)
137 }
138
139 fn int_lower_equal(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> BoolTensor<Self> {
140 kernel::lower_equal::<R, I, BT>(lhs, rhs)
141 }
142
143 fn int_lower_equal_elem(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> BoolTensor<Self> {
144 kernel::lower_equal_elem::<R, I, BT>(lhs, rhs)
145 }
146
147 fn int_add(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
148 numeric::add::<R, I>(lhs, rhs)
149 }
150
151 fn int_add_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
152 numeric::add_scalar::<R, I>(lhs, rhs)
153 }
154
155 fn int_sub(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
156 numeric::sub::<R, I>(lhs, rhs)
157 }
158
159 fn int_sub_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
160 numeric::sub_scalar::<R, I>(lhs, rhs)
161 }
162
163 fn int_mul(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
164 numeric::mul::<R, I>(lhs, rhs)
165 }
166
167 fn int_mul_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
168 numeric::mul_scalar::<R, I>(lhs, rhs)
169 }
170
171 fn int_div(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
172 numeric::div::<R, I>(lhs, rhs)
173 }
174
175 fn int_div_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
176 numeric::div_scalar::<R, I>(lhs, rhs)
177 }
178
179 fn int_remainder(lhs: IntTensor<Self>, rhs: IntTensor<Self>) -> IntTensor<Self> {
180 numeric::remainder::<R, I>(lhs, rhs)
181 }
182
183 fn int_remainder_scalar(lhs: IntTensor<Self>, rhs: IntElem<Self>) -> IntTensor<Self> {
184 numeric::remainder_scalar::<R, I>(lhs, rhs)
185 }
186
187 fn int_zeros(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
188 numeric::zeros::<R, I>(shape, device)
189 }
190
191 fn int_ones(shape: Shape, device: &Device<Self>) -> IntTensor<Self> {
192 numeric::ones::<R, I>(shape, device)
193 }
194
195 fn int_sum(tensor: IntTensor<Self>) -> IntTensor<Self> {
196 reduce::reduce::<R, I, I, reduce::Sum>(tensor, Default::default()).unwrap()
197 }
198
199 fn int_sum_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
200 reduce::reduce_dim::<R, I, I, reduce::Sum>(tensor, dim, Default::default()).unwrap()
201 }
202
203 fn int_prod(tensor: IntTensor<Self>) -> IntTensor<Self> {
204 reduce::reduce::<R, I, I, reduce::Prod>(tensor, Default::default()).unwrap()
205 }
206
207 fn int_prod_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
208 reduce::reduce_dim::<R, I, I, reduce::Prod>(tensor, dim, Default::default()).unwrap()
209 }
210
211 fn int_mean_dim(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
212 reduce::reduce_dim::<R, I, I, reduce::Mean>(tensor, dim, Default::default()).unwrap()
213 }
214
215 fn int_argmax(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
216 reduce::reduce_dim::<R, I, I, reduce::ArgMax>(tensor, dim, Default::default()).unwrap()
217 }
218
219 fn int_argmin(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
220 reduce::reduce_dim::<R, I, I, reduce::ArgMin>(tensor, dim, Default::default()).unwrap()
221 }
222
223 fn int_clamp(
224 tensor: IntTensor<Self>,
225 min: IntElem<Self>,
226 max: IntElem<Self>,
227 ) -> IntTensor<Self> {
228 kernel::clamp::<R, I>(tensor, min, max)
229 }
230
231 fn int_abs(tensor: IntTensor<Self>) -> IntTensor<Self> {
232 struct Abs;
233
234 #[cube]
235 impl<N: Numeric> NumericUnaryOp<N> for Abs {
236 type Options = ();
237
238 fn execute(input: Line<N>, _options: &Self::Options) -> Line<N> {
239 Line::abs(input)
240 }
241 }
242
243 impl NumericUnaryOpFamily for Abs {
244 type Options<N: Numeric> = ();
245 type Unary<N: Numeric> = Self;
246 }
247
248 launch_unary_numeric::<R, I, Abs, _>(tensor, |_| ())
249 }
250
251 fn int_into_float(tensor: IntTensor<Self>) -> FloatTensor<Self> {
252 kernel::cast::<R, I, F>(tensor)
253 }
254
255 fn int_swap_dims(mut tensor: IntTensor<Self>, dim1: usize, dim2: usize) -> IntTensor<Self> {
256 tensor.strides.swap(dim1, dim2);
257 tensor.shape.dims.swap(dim1, dim2);
258
259 tensor
260 }
261
262 fn int_repeat_dim(tensor: IntTensor<Self>, dim: usize, times: usize) -> IntTensor<Self> {
263 kernel::repeat_dim::<R, I>(tensor, dim, times)
264 }
265
266 fn int_random(
267 shape: Shape,
268 distribution: Distribution,
269 device: &Device<Self>,
270 ) -> IntTensor<Self> {
271 let float_tensor = match distribution {
272 Distribution::Default => random_uniform(shape, device, 0.elem::<F>(), 255.elem()),
273 Distribution::Uniform(low, high) => {
274 random_uniform(shape, device, low.elem::<F>(), high.elem())
275 }
276 Distribution::Bernoulli(prob) => random_bernoulli(shape, device, prob.elem::<F>()),
277 Distribution::Normal(mean, std) => {
278 random_normal(shape, device, mean.elem::<F>(), std.elem())
279 }
280 };
281
282 kernel::cast::<R, F, I>(float_tensor)
283 }
284
285 fn int_permute(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
286 permute(tensor, axes)
287 }
288
289 fn int_expand(tensor: IntTensor<Self>, shape: Shape) -> IntTensor<Self> {
290 expand(tensor, shape)
291 }
292
293 fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
294 kernel::flip::<R, I, BT>(tensor, axes)
295 }
296}