1use crate::{
2 CubeRuntime,
3 kernel::utils::{linear_view, shape_divmod},
4};
5use crate::{element::CubeElement, tensor::CubeTensor};
6use crate::{
7 kernel::{
8 AddOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, DivOp, MulOp, PowOp, RemainderOp, SubOp,
9 launch_binop, launch_binop_int, launch_scalar_binop, launch_scalar_binop_int,
10 },
11 ops::max_line_size,
12};
13use burn_backend::{DType, Shape};
14use cubecl::std::{FastDivmod, tensor::layout::linear::LinearView};
15use cubecl::{calculate_cube_count_elemwise, prelude::*};
16use cubecl::{client::ComputeClient, server::Allocation};
17
18pub fn full<R: CubeRuntime, E: CubeElement>(
20 shape: Shape,
21 device: &R::Device,
22 value: E,
23) -> CubeTensor<R> {
24 let client = R::client(device);
25
26 full_client::<R, E>(client, shape, device.clone(), value)
27}
28
29pub fn full_client<R: CubeRuntime, E: CubeElement>(
31 client: ComputeClient<R>,
32 shape: Shape,
33 device: R::Device,
34 value: E,
35) -> CubeTensor<R> {
36 let dtype = E::dtype();
37 full_device_dtype(client, shape, device, InputScalar::new(value, dtype), dtype)
38}
39
/// Allocates a tensor of `shape` with the given runtime `dtype` and fills it
/// with the scalar `value` via an element-wise GPU kernel.
pub fn full_device_dtype<R: CubeRuntime>(
    client: ComputeClient<R>,
    shape: Shape,
    device: R::Device,
    value: InputScalar,
    dtype: DType,
) -> CubeTensor<R> {
    // Uninitialized contiguous allocation; the kernel below writes every element.
    let empty = empty_device_dtype(client, device, shape, dtype);

    // Fill kernel: each invocation writes the scalar (reinterpreted as the
    // concrete numeric type `C`, chosen at runtime via `#[define]`) to one
    // position of the linear view.
    #[cube(launch_unchecked)]
    pub fn full_kernel<C: Numeric>(
        tensor: &mut LinearView<C, ReadWrite>,
        value: InputScalar,
        #[define(C)] _dtype: StorageType,
    ) {
        // Guard against over-launch: extra invocations exit early.
        if !tensor.is_in_bounds(ABSOLUTE_POS) {
            terminate!();
        }

        tensor[ABSOLUTE_POS] = value.get::<C>();
    }

    let num_elems = empty.shape.num_elements();
    let line_size = max_line_size(&empty);

    // One working unit per vectorized line. NOTE(review): the floor division
    // assumes `line_size` evenly divides `num_elems` (presumably guaranteed by
    // `max_line_size`) — otherwise a tail of elements would never be launched;
    // confirm against `max_line_size`'s contract.
    let working_units = num_elems / line_size as usize;
    let cube_dim = CubeDim::new(&empty.client, working_units);
    let cube_count = calculate_cube_count_elemwise(&empty.client, working_units, cube_dim);

    // SAFETY: `launch_unchecked` skips launch-time bounds validation; the
    // kernel performs its own `is_in_bounds` guard above.
    unsafe {
        full_kernel::launch_unchecked(
            &empty.client,
            cube_count,
            cube_dim,
            linear_view(&empty, line_size),
            value,
            dtype.into(),
        )
        .expect("Kernel to never fail");
    }

    empty
}
84
85pub fn zeros<R: CubeRuntime>(device: R::Device, shape: Shape, dtype: DType) -> CubeTensor<R> {
87 let client = R::client(&device);
88 full_device_dtype(client, shape, device, InputScalar::new(0u32, dtype), dtype)
89}
90
91pub fn ones<R: CubeRuntime>(device: R::Device, shape: Shape, dtype: DType) -> CubeTensor<R> {
93 let client = R::client(&device);
94 full_device_dtype(client, shape, device, InputScalar::new(1u32, dtype), dtype)
95}
96
/// Creates a zero-filled tensor using an already-resolved compute `client`.
pub fn zeros_client<R: CubeRuntime>(
    client: ComputeClient<R>,
    device: R::Device,
    shape: Shape,
    dtype: DType,
) -> CubeTensor<R> {
    // The 0u32 literal is reinterpreted as `dtype` by `InputScalar`.
    full_device_dtype(client, shape, device, InputScalar::new(0u32, dtype), dtype)
}
106
/// Creates a one-filled tensor using an already-resolved compute `client`.
pub fn ones_client<R: CubeRuntime>(
    client: ComputeClient<R>,
    device: R::Device,
    shape: Shape,
    dtype: DType,
) -> CubeTensor<R> {
    // The 1u32 literal is reinterpreted as `dtype` by `InputScalar`.
    full_device_dtype(client, shape, device, InputScalar::new(1u32, dtype), dtype)
}
116
/// Allocates an uninitialized tensor of `shape` with element type `E`.
///
/// Thin typed wrapper over [`empty_device_dtype`].
pub fn empty_device<R: CubeRuntime, E: CubeElement>(
    client: ComputeClient<R>,
    device: R::Device,
    shape: Shape,
) -> CubeTensor<R> {
    empty_device_dtype(client, device, shape, E::dtype())
}
125
126pub fn empty_device_dtype<R: CubeRuntime>(
128 client: ComputeClient<R>,
129 device: R::Device,
130 shape: Shape,
131 dtype: DType,
132) -> CubeTensor<R> {
133 let buffer = client.empty(shape.num_elements() * dtype.size());
134
135 CubeTensor::new_contiguous(client, device, shape, buffer, dtype)
136}
137
138pub fn empty_device_optimized<R: CubeRuntime, E: CubeElement>(
140 client: ComputeClient<R>,
141 device: R::Device,
142 shape: Shape,
143) -> CubeTensor<R> {
144 let Allocation { handle, strides } = client.empty_tensor(&shape.dims, size_of::<E>());
145
146 CubeTensor::new(client, handle, shape, device, strides, E::dtype())
147}
148
149pub fn empty_device_optimized_dtype<R: CubeRuntime>(
151 client: ComputeClient<R>,
152 device: R::Device,
153 shape: Shape,
154 dtype: DType,
155) -> CubeTensor<R> {
156 let Allocation { handle, strides } = client.empty_tensor(&shape.dims, dtype.size());
157
158 CubeTensor::new(client, handle, shape, device, strides, dtype)
159}
160
/// Launches the element-wise `AddOp` kernel on two tensors.
pub fn add<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, AddOp>(lhs, rhs)
}

/// Launches the `AddOp` kernel with a scalar right-hand side.
pub fn add_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, AddOp>(lhs, rhs)
}

/// Launches the element-wise `SubOp` kernel on two tensors.
pub fn sub<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, SubOp>(lhs, rhs)
}

/// Launches the `SubOp` kernel with a scalar right-hand side.
pub fn sub_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, SubOp>(lhs, rhs)
}

/// Launches the element-wise `MulOp` kernel on two tensors.
pub fn mul<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, MulOp>(lhs, rhs)
}

/// Launches the `MulOp` kernel with a scalar right-hand side.
pub fn mul_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, MulOp>(lhs, rhs)
}

/// Launches the element-wise `DivOp` kernel on two tensors.
pub fn div<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, DivOp>(lhs, rhs)
}

/// Launches the `DivOp` kernel with a scalar right-hand side.
pub fn div_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, DivOp>(lhs, rhs)
}

/// Launches the element-wise `RemainderOp` kernel on two tensors.
pub fn remainder<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, RemainderOp>(lhs, rhs)
}

/// Launches the `RemainderOp` kernel with a scalar right-hand side.
pub fn remainder_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, RemainderOp>(lhs, rhs)
}

/// Launches the element-wise `PowOp` kernel on two tensors.
pub fn pow<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, PowOp>(lhs, rhs)
}

/// Launches the element-wise `BitwiseAndOp` kernel (integer tensors only).
pub fn bitwise_and<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop_int::<R, BitwiseAndOp>(lhs, rhs)
}

/// Launches the `BitwiseAndOp` kernel with a scalar right-hand side.
pub fn bitwise_and_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop_int::<R, BitwiseAndOp>(lhs, rhs)
}

/// Launches the element-wise `BitwiseOrOp` kernel (integer tensors only).
pub fn bitwise_or<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop_int::<R, BitwiseOrOp>(lhs, rhs)
}

/// Launches the `BitwiseOrOp` kernel with a scalar right-hand side.
pub fn bitwise_or_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop_int::<R, BitwiseOrOp>(lhs, rhs)
}

/// Launches the element-wise `BitwiseXorOp` kernel (integer tensors only).
pub fn bitwise_xor<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop_int::<R, BitwiseXorOp>(lhs, rhs)
}

/// Launches the `BitwiseXorOp` kernel with a scalar right-hand side.
pub fn bitwise_xor_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop_int::<R, BitwiseXorOp>(lhs, rhs)
}
245
/// Family of cumulative operators: maps each numeric element type `C` to the
/// concrete [`CumulativeOp`] implementation used by the cumulative kernels.
pub(crate) trait CumulativeOpFamily: Send + Sync + 'static {
    type CumulativeOp<C: Numeric>: CumulativeOp<C>;
}

/// A binary accumulation operator usable inside `#[cube]` kernels.
#[cube]
pub(crate) trait CumulativeOp<C: Numeric>: 'static + Send + Sync {
    /// Folds `rhs` into the running accumulator `lhs`.
    fn execute(lhs: C, rhs: C) -> C;

    /// Initial accumulator value, given the first element of the scanned slice.
    fn init_value(first_element: C) -> C;
}
260
// Marker types selecting the accumulation operator for the cumulative kernels
// (cumsum / cumprod / cummax / cummin respectively).
struct SumOp;
struct ProdOp;
struct MaxOp;
struct MinOp;
266
// Each marker type implements `CumulativeOp` itself for every numeric type,
// so the family associated type simply points back at `Self`.
impl CumulativeOpFamily for SumOp {
    type CumulativeOp<C: Numeric> = Self;
}

impl CumulativeOpFamily for ProdOp {
    type CumulativeOp<C: Numeric> = Self;
}

impl CumulativeOpFamily for MaxOp {
    type CumulativeOp<C: Numeric> = Self;
}

impl CumulativeOpFamily for MinOp {
    type CumulativeOp<C: Numeric> = Self;
}
283
// Running sum: identity element 0, fold by addition.
#[cube]
impl<N: Numeric> CumulativeOp<N> for SumOp {
    fn execute(lhs: N, rhs: N) -> N {
        lhs + rhs
    }

    fn init_value(_first_element: N) -> N {
        N::from_int(0)
    }
}

// Running product: identity element 1, fold by multiplication.
#[cube]
impl<N: Numeric> CumulativeOp<N> for ProdOp {
    fn execute(lhs: N, rhs: N) -> N {
        lhs * rhs
    }

    fn init_value(_first_element: N) -> N {
        N::from_int(1)
    }
}

// Running maximum: seeded with the slice's first element (max has no universal
// identity across numeric types); re-folding that element is harmless since
// `max` is idempotent.
#[cube]
impl<N: Numeric> CumulativeOp<N> for MaxOp {
    fn execute(lhs: N, rhs: N) -> N {
        max(lhs, rhs)
    }

    fn init_value(first_element: N) -> N {
        first_element
    }
}

// Running minimum: same seeding strategy as `MaxOp`; `min` is idempotent.
#[cube]
impl<N: Numeric> CumulativeOp<N> for MinOp {
    fn execute(lhs: N, rhs: N) -> N {
        min(lhs, rhs)
    }

    fn init_value(first_element: N) -> N {
        first_element
    }
}
328
/// Inclusive cumulative reduction of `input` along `dim`, one output element
/// per invocation, using op family `O`.
///
/// Each invocation decodes its linear `ABSOLUTE_POS` into per-dimension
/// coordinates via `shape` (fast div/mod), then re-scans the input from the
/// start of its slice up to its own position along `dim`. Work per output
/// element is therefore proportional to that position — a simple serial
/// re-scan, not a parallel prefix-scan.
#[cube(launch_unchecked)]
fn cumulative_kernel<C: Numeric, O: CumulativeOpFamily>(
    input: &Tensor<C>,
    output: &mut LinearView<C, ReadWrite>,
    shape: Sequence<FastDivmod<usize>>,
    #[comptime] dim: usize,
    #[define(C)] _dtype: StorageType,
) {
    // Guard against over-launch: extra invocations exit early.
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let rank = comptime![shape.len()];
    let dim_stride = input.stride(dim);

    // Decode ABSOLUTE_POS into `offset` (input offset contributed by every
    // coordinate except `dim`) and `dim_idx` (this element's position along
    // `dim`).
    let mut remainder = ABSOLUTE_POS;
    let mut offset = 0;
    let mut dim_idx = 0;

    #[unroll]
    for i in 0..shape.len() {
        // Walk dimensions from innermost (last) to outermost.
        let i = comptime![rank - i - 1];
        let (rem, local_idx) = shape.index(i).div_mod(remainder);
        remainder = rem;
        if i == dim {
            dim_idx = local_idx;
        } else {
            offset += local_idx * input.stride(i);
        }
    }

    // Seed the accumulator from this element itself: Sum/Prod ignore the seed
    // (identity element), Max/Min start from it — and since the fold below
    // includes every index from 0, re-folding it is harmless for max/min.
    let first_read_idx = offset + dim_idx * dim_stride;
    let first_elem = input[first_read_idx];

    let mut result = O::CumulativeOp::<C>::init_value(first_elem);

    // Inclusive fold over the slice prefix [0, dim_idx].
    for i in 0..=dim_idx {
        let read_idx = offset + i * dim_stride;
        result = O::CumulativeOp::<C>::execute(result, input[read_idx]);
    }
    output[ABSOLUTE_POS] = result;
}
387
/// Inclusive cumulative sum of `input` along `dim`.
pub fn cumsum<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
    cumulative_op::<R, SumOp>(input, dim)
}

/// Inclusive cumulative product of `input` along `dim`.
pub fn cumprod<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
    cumulative_op::<R, ProdOp>(input, dim)
}

/// Inclusive cumulative minimum of `input` along `dim`.
pub fn cummin<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
    cumulative_op::<R, MinOp>(input, dim)
}

/// Inclusive cumulative maximum of `input` along `dim`.
pub fn cummax<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
    cumulative_op::<R, MaxOp>(input, dim)
}
407
/// Launches [`cumulative_kernel`] for op family `O` over `input` along `dim`,
/// returning a new contiguous tensor of the same shape and dtype.
fn cumulative_op<R: CubeRuntime, O: CumulativeOpFamily>(
    input: CubeTensor<R>,
    dim: usize,
) -> CubeTensor<R> {
    let client = input.client.clone();
    let device = input.device.clone();

    let output = empty_device_dtype(client.clone(), device, input.shape.clone(), input.dtype);

    // One working unit per output element; line size is 1 here (no
    // vectorization), unlike the element-wise fill path.
    let num_elems = output.shape.num_elements();
    let working_units = num_elems;
    let cube_dim = CubeDim::new(&client, working_units);
    let cube_count = calculate_cube_count_elemwise(&client, working_units, cube_dim);

    // SAFETY: `launch_unchecked` skips launch-time bounds validation; the
    // kernel guards with `is_in_bounds` itself.
    unsafe {
        cumulative_kernel::launch_unchecked::<O, R>(
            &client,
            cube_count,
            cube_dim,
            input.as_tensor_arg(1),
            linear_view(&output, 1),
            shape_divmod(&input),
            dim,
            output.dtype.into(),
        )
        .expect("Kernel to never fail");
    }

    output
}