1use crate::{
2 CubeRuntime,
3 kernel::utils::{address_type, shape_divmod},
4};
5use crate::{element::CubeElement, tensor::CubeTensor};
6use crate::{
7 kernel::{
8 AddOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, DivOp, MulOp, PowOp, RemainderOp, SubOp,
9 launch_binop, launch_binop_int, launch_scalar_binop, launch_scalar_binop_int,
10 },
11 ops::max_vector_size,
12};
13use burn_backend::{DType, Shape, TensorMetadata};
14use burn_std::Metadata;
15use cubecl::{calculate_cube_count_elemwise, prelude::*};
16use cubecl::{client::ComputeClient, server::MemoryLayout};
17use cubecl::{
18 server::MemoryLayoutDescriptor,
19 std::{FastDivmod, tensor::layout::linear::LinearView},
20};
21
22pub fn full<R: CubeRuntime, E: CubeElement>(
24 shape: Shape,
25 device: &R::Device,
26 value: E,
27) -> CubeTensor<R> {
28 let client = R::client(device);
29
30 full_client::<R, E>(client, shape, device.clone(), value)
31}
32
33pub fn full_client<R: CubeRuntime, E: CubeElement>(
35 client: ComputeClient<R>,
36 shape: Shape,
37 device: R::Device,
38 value: E,
39) -> CubeTensor<R> {
40 let dtype = E::dtype();
41 full_device_dtype(client, shape, device, InputScalar::new(value, dtype), dtype)
42}
43
/// Creates a tensor of `shape` filled with `value`, where the element type is
/// selected at runtime via `dtype`.
pub fn full_device_dtype<R: CubeRuntime>(
    client: ComputeClient<R>,
    shape: Shape,
    device: R::Device,
    value: InputScalar,
    dtype: DType,
) -> CubeTensor<R> {
    // Allocate uninitialized output first; the kernel below overwrites every element.
    let empty = empty_device_dtype(client, device, shape, dtype);

    // Element-wise fill kernel: each unit writes one vectorized chunk.
    #[cube(launch_unchecked, address_type = "dynamic")]
    pub fn full_kernel<C: Numeric, N: Size>(
        tensor: &mut LinearView<Vector<C, N>, ReadWrite>,
        value: InputScalar,
        #[define(C)] _dtype: StorageType,
    ) {
        // Guard against launch over-provisioning past the view's length.
        if !tensor.is_in_bounds(ABSOLUTE_POS) {
            terminate!();
        }

        // Broadcast the scalar into a full vector and store it.
        tensor[ABSOLUTE_POS] = Vector::new(value.get::<C>());
    }

    let num_elems = empty.meta.num_elements();
    // Widest vectorization supported for this tensor (see `max_vector_size`).
    let vector_size = max_vector_size(&empty);

    // One working unit per vectorized chunk of the output.
    let working_units = num_elems / vector_size as usize;
    let cube_dim = CubeDim::new(&empty.client, working_units);
    let cube_count = calculate_cube_count_elemwise(&empty.client, working_units, cube_dim);

    // SAFETY: positions past the end are rejected inside the kernel via
    // `is_in_bounds`, so the unchecked launch cannot write out of bounds.
    unsafe {
        full_kernel::launch_unchecked(
            &empty.client,
            cube_count,
            cube_dim,
            address_type!(empty),
            vector_size,
            empty.clone().into_linear_view(),
            value,
            empty.dtype.into(),
        );
    }

    empty
}
89
90pub fn zeros<R: CubeRuntime>(device: R::Device, shape: Shape, dtype: DType) -> CubeTensor<R> {
92 let client = R::client(&device);
93 full_device_dtype(client, shape, device, InputScalar::new(0u32, dtype), dtype)
94}
95
96pub fn ones<R: CubeRuntime>(device: R::Device, shape: Shape, dtype: DType) -> CubeTensor<R> {
98 let client = R::client(&device);
99 full_device_dtype(client, shape, device, InputScalar::new(1u32, dtype), dtype)
100}
101
/// Creates a zero-filled tensor of `shape`, using an explicit compute client.
pub fn zeros_client<R: CubeRuntime>(
    client: ComputeClient<R>,
    device: R::Device,
    shape: Shape,
    dtype: DType,
) -> CubeTensor<R> {
    // Fill value 0, tagged with the requested runtime dtype.
    full_device_dtype(client, shape, device, InputScalar::new(0u32, dtype), dtype)
}
111
/// Creates a one-filled tensor of `shape`, using an explicit compute client.
pub fn ones_client<R: CubeRuntime>(
    client: ComputeClient<R>,
    device: R::Device,
    shape: Shape,
    dtype: DType,
) -> CubeTensor<R> {
    // Fill value 1, tagged with the requested runtime dtype.
    full_device_dtype(client, shape, device, InputScalar::new(1u32, dtype), dtype)
}
121
122pub fn empty_device<R: CubeRuntime, E: CubeElement>(
124 client: ComputeClient<R>,
125 device: R::Device,
126 shape: Shape,
127) -> CubeTensor<R> {
128 let MemoryLayout { memory, strides } = client.empty_tensor(shape.clone(), size_of::<E>());
129
130 CubeTensor::new(
131 client,
132 memory,
133 Metadata::new(shape, strides),
134 device,
135 E::dtype(),
136 )
137}
138
139pub fn empty_device_dtype<R: CubeRuntime>(
141 client: ComputeClient<R>,
142 device: R::Device,
143 shape: Shape,
144 dtype: DType,
145) -> CubeTensor<R> {
146 let MemoryLayout { memory, strides } = client.empty_tensor(shape.clone(), dtype.size());
147
148 CubeTensor::new(client, memory, Metadata::new(shape, strides), device, dtype)
149}
150
151pub fn empty_device_contiguous_dtype<R: CubeRuntime>(
153 client: ComputeClient<R>,
154 device: R::Device,
155 shape: Shape,
156 dtype: DType,
157) -> CubeTensor<R> {
158 let descriptor = MemoryLayoutDescriptor::contiguous(shape.clone(), dtype.size());
159 let MemoryLayout { memory, strides } = client.empty_tensors(vec![descriptor]).remove(0);
160
161 CubeTensor::new(client, memory, Metadata::new(shape, strides), device, dtype)
162}
163
/// Element-wise addition of two tensors.
pub fn add<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, AddOp>(lhs, rhs)
}

/// Adds a scalar to every element of `lhs`.
pub fn add_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, AddOp>(lhs, rhs)
}
173
/// Element-wise subtraction of two tensors.
pub fn sub<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, SubOp>(lhs, rhs)
}

/// Subtracts a scalar from every element of `lhs`.
pub fn sub_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, SubOp>(lhs, rhs)
}
183
/// Element-wise multiplication of two tensors.
pub fn mul<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, MulOp>(lhs, rhs)
}

/// Multiplies every element of `lhs` by a scalar.
pub fn mul_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, MulOp>(lhs, rhs)
}
193
/// Element-wise division of two tensors.
pub fn div<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, DivOp>(lhs, rhs)
}

/// Divides every element of `lhs` by a scalar.
pub fn div_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, DivOp>(lhs, rhs)
}
203
/// Element-wise remainder of two tensors (semantics defined by `RemainderOp`).
pub fn remainder<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, RemainderOp>(lhs, rhs)
}

/// Remainder of every element of `lhs` with a scalar divisor.
pub fn remainder_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, RemainderOp>(lhs, rhs)
}
213
/// Element-wise power: each element of `lhs` raised to the matching element of `rhs`.
pub fn pow<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, PowOp>(lhs, rhs)
}
218
/// Element-wise bitwise AND of two tensors (integer launch path).
pub fn bitwise_and<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop_int::<R, BitwiseAndOp>(lhs, rhs)
}

/// Bitwise AND of every element of `lhs` with a scalar.
pub fn bitwise_and_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop_int::<R, BitwiseAndOp>(lhs, rhs)
}
228
/// Element-wise bitwise OR of two tensors (integer launch path).
pub fn bitwise_or<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop_int::<R, BitwiseOrOp>(lhs, rhs)
}

/// Bitwise OR of every element of `lhs` with a scalar.
pub fn bitwise_or_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop_int::<R, BitwiseOrOp>(lhs, rhs)
}
238
/// Element-wise bitwise XOR of two tensors (integer launch path).
pub fn bitwise_xor<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop_int::<R, BitwiseXorOp>(lhs, rhs)
}

/// Bitwise XOR of every element of `lhs` with a scalar.
pub fn bitwise_xor_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop_int::<R, BitwiseXorOp>(lhs, rhs)
}
248
/// Family trait selecting a concrete [`CumulativeOp`] implementation for each
/// numeric element type `C`.
pub(crate) trait CumulativeOpFamily: Send + Sync + 'static {
    type CumulativeOp<C: Numeric>: CumulativeOp<C>;
}
253
/// Binary reduction used by the cumulative-scan kernel (cumsum, cumprod, ...).
#[cube]
pub(crate) trait CumulativeOp<C: Numeric>: 'static + Send + Sync {
    /// Combines the running accumulator (`lhs`) with the next element (`rhs`).
    fn execute(lhs: C, rhs: C) -> C;

    /// Initial accumulator value. Receives the first element the scan will see
    /// so order-based ops (min/max) can seed from it.
    fn init_value(first_element: C) -> C;
}
263
// Marker types selecting the scan operation; each gets a `CumulativeOp` impl.
struct SumOp;
struct ProdOp;
struct MaxOp;
struct MinOp;
269
// Each marker type is its own family: the same struct implements
// `CumulativeOp` for every numeric element type `C`.
impl CumulativeOpFamily for SumOp {
    type CumulativeOp<C: Numeric> = Self;
}

impl CumulativeOpFamily for ProdOp {
    type CumulativeOp<C: Numeric> = Self;
}

impl CumulativeOpFamily for MaxOp {
    type CumulativeOp<C: Numeric> = Self;
}

impl CumulativeOpFamily for MinOp {
    type CumulativeOp<C: Numeric> = Self;
}
286
// Sum: identity 0, accumulate with `+`.
#[cube]
impl<N: Numeric> CumulativeOp<N> for SumOp {
    fn execute(lhs: N, rhs: N) -> N {
        lhs + rhs
    }

    fn init_value(_first_element: N) -> N {
        N::zero()
    }
}

// Product: identity 1, accumulate with `*`.
#[cube]
impl<N: Numeric> CumulativeOp<N> for ProdOp {
    fn execute(lhs: N, rhs: N) -> N {
        lhs * rhs
    }

    fn init_value(_first_element: N) -> N {
        N::from_int(1)
    }
}

// Max: seeded with the first element (max is idempotent, so re-combining the
// seed element in the scan loop is harmless).
#[cube]
impl<N: Numeric> CumulativeOp<N> for MaxOp {
    fn execute(lhs: N, rhs: N) -> N {
        max(lhs, rhs)
    }

    fn init_value(first_element: N) -> N {
        first_element
    }
}

// Min: seeded with the first element, same reasoning as `MaxOp`.
#[cube]
impl<N: Numeric> CumulativeOp<N> for MinOp {
    fn execute(lhs: N, rhs: N) -> N {
        min(lhs, rhs)
    }

    fn init_value(first_element: N) -> N {
        first_element
    }
}
331
/// Inclusive cumulative scan along `dim`. Each unit produces one output
/// element by folding every input element at or before its position along
/// `dim` — O(dim extent) reads per output element.
#[cube(launch_unchecked, address_type = "dynamic")]
fn cumulative_kernel<C: Numeric, O: CumulativeOpFamily>(
    input: &Tensor<C>,
    output: &mut LinearView<C, ReadWrite>,
    shape: Sequence<FastDivmod<usize>>,
    #[comptime] dim: usize,
    #[define(C)] _dtype: StorageType,
) {
    // Guard against launch over-provisioning past the output length.
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let rank = comptime![shape.len()];
    let dim_stride = input.stride(dim);

    // Decompose the linear output position into coordinates:
    // `offset` accumulates the input offset of all non-scan axes,
    // `dim_idx` is this element's position along the scan axis.
    let mut remainder = ABSOLUTE_POS;
    let mut offset = 0;
    let mut dim_idx = 0;

    #[unroll]
    for i in 0..shape.len() {
        // Walk axes from last to first (comptime index for the unrolled loop).
        let i = comptime![rank - i - 1];
        let (rem, local_idx) = shape.index(i).div_mod(remainder);
        remainder = rem;
        if i == dim {
            dim_idx = local_idx;
        } else {
            offset += local_idx * input.stride(i);
        }
    }

    // Seed the accumulator from the element at this position: sum/prod ignore
    // it (their init is the identity), min/max use it — and since min/max are
    // idempotent, folding it again in the loop below does not change results.
    let first_read_idx = offset + dim_idx * dim_stride;
    let first_elem = input[first_read_idx];

    let mut result = O::CumulativeOp::<C>::init_value(first_elem);

    // Inclusive scan: fold every element up to and including `dim_idx`.
    for i in 0..=dim_idx {
        let read_idx = offset + i * dim_stride;
        result = O::CumulativeOp::<C>::execute(result, input[read_idx]);
    }
    output[ABSOLUTE_POS] = result;
}
390
/// Inclusive cumulative sum along `dim`.
pub fn cumsum<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
    cumulative_op::<R, SumOp>(input, dim)
}

/// Inclusive cumulative product along `dim`.
pub fn cumprod<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
    cumulative_op::<R, ProdOp>(input, dim)
}

/// Inclusive running minimum along `dim`.
pub fn cummin<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
    cumulative_op::<R, MinOp>(input, dim)
}

/// Inclusive running maximum along `dim`.
pub fn cummax<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
    cumulative_op::<R, MaxOp>(input, dim)
}
410
/// Launches the cumulative-scan kernel for op family `O` along `dim`.
fn cumulative_op<R: CubeRuntime, O: CumulativeOpFamily>(
    input: CubeTensor<R>,
    dim: usize,
) -> CubeTensor<R> {
    let client = input.client.clone();
    let device = input.device.clone();

    // Output matches the input's shape and dtype.
    let output = empty_device_dtype(client.clone(), device, input.shape(), input.dtype);

    // One working unit per output element; the kernel re-reads the prefix
    // along `dim` for each element, so no vectorization is applied here.
    let num_elems = output.meta.num_elements();
    let working_units = num_elems;
    let cube_dim = CubeDim::new(&client, working_units);
    let cube_count = calculate_cube_count_elemwise(&client, working_units, cube_dim);
    // Per-axis fast divmod used by the kernel to decode linear positions.
    let shape = shape_divmod(&input);

    // SAFETY: the kernel checks `is_in_bounds(ABSOLUTE_POS)` before any
    // access, so the unchecked launch cannot read or write out of bounds.
    unsafe {
        cumulative_kernel::launch_unchecked::<O, R>(
            &client,
            cube_count,
            cube_dim,
            address_type!(input, output),
            input.into_tensor_arg(),
            output.clone().into_linear_view(),
            shape,
            dim,
            output.dtype.into(),
        );
    }

    output
}