// burn_cubecl/ops/numeric.rs

1use crate::{
2    CubeRuntime,
3    kernel::utils::{linear_view, shape_divmod},
4};
5use crate::{element::CubeElement, tensor::CubeTensor};
6use crate::{
7    kernel::{
8        AddOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, DivOp, MulOp, PowOp, RemainderOp, SubOp,
9        launch_binop, launch_binop_int, launch_scalar_binop, launch_scalar_binop_int,
10    },
11    ops::max_line_size,
12};
13use burn_backend::{DType, Shape};
14use cubecl::std::{FastDivmod, tensor::layout::linear::LinearView};
15use cubecl::{calculate_cube_count_elemwise, prelude::*};
16use cubecl::{client::ComputeClient, server::Allocation};
17
18/// Creates a tensor filled with `value`
19pub fn full<R: CubeRuntime, E: CubeElement>(
20    shape: Shape,
21    device: &R::Device,
22    value: E,
23) -> CubeTensor<R> {
24    let client = R::client(device);
25
26    full_client::<R, E>(client, shape, device.clone(), value)
27}
28
29/// Creates a tensor filled with `value`
30pub fn full_client<R: CubeRuntime, E: CubeElement>(
31    client: ComputeClient<R>,
32    shape: Shape,
33    device: R::Device,
34    value: E,
35) -> CubeTensor<R> {
36    let dtype = E::dtype();
37    full_device_dtype(client, shape, device, InputScalar::new(value, dtype), dtype)
38}
39
/// Creates a tensor filled with `value`
///
/// Allocates an uninitialized tensor for `shape`/`dtype`, then launches an
/// element-wise kernel that writes `value` into every position.
pub fn full_device_dtype<R: CubeRuntime>(
    client: ComputeClient<R>,
    shape: Shape,
    device: R::Device,
    value: InputScalar,
    dtype: DType,
) -> CubeTensor<R> {
    let empty = empty_device_dtype(client, device, shape, dtype);

    // Each unit writes the scalar (reinterpreted at element type `C`) to its
    // own position of the linear view.
    #[cube(launch_unchecked)]
    pub fn full_kernel<C: Numeric>(
        tensor: &mut LinearView<C, ReadWrite>,
        value: InputScalar,
        #[define(C)] _dtype: StorageType,
    ) {
        // Guard against the cube count over-provisioning units past the end.
        if !tensor.is_in_bounds(ABSOLUTE_POS) {
            terminate!();
        }

        tensor[ABSOLUTE_POS] = value.get::<C>();
    }

    let num_elems = empty.shape.num_elements();
    let line_size = max_line_size(&empty);

    // Each working unit covers one line of `line_size` elements.
    let working_units = num_elems / line_size as usize;
    let cube_dim = CubeDim::new(&empty.client, working_units);
    let cube_count = calculate_cube_count_elemwise(&empty.client, working_units, cube_dim);

    // SAFETY: out-of-range positions are rejected inside the kernel via
    // `is_in_bounds`, so the unchecked launch cannot write out of bounds.
    unsafe {
        full_kernel::launch_unchecked(
            &empty.client,
            cube_count,
            cube_dim,
            linear_view(&empty, line_size),
            value,
            dtype.into(),
        )
        .expect("Kernel to never fail");
    }

    empty
}
84
85/// Creates a tensor filled with zeros
86pub fn zeros<R: CubeRuntime>(device: R::Device, shape: Shape, dtype: DType) -> CubeTensor<R> {
87    let client = R::client(&device);
88    full_device_dtype(client, shape, device, InputScalar::new(0u32, dtype), dtype)
89}
90
91/// Creates a tensor filled with ones
92pub fn ones<R: CubeRuntime>(device: R::Device, shape: Shape, dtype: DType) -> CubeTensor<R> {
93    let client = R::client(&device);
94    full_device_dtype(client, shape, device, InputScalar::new(1u32, dtype), dtype)
95}
96
97/// Creates a tensor filled with zeros
98pub fn zeros_client<R: CubeRuntime>(
99    client: ComputeClient<R>,
100    device: R::Device,
101    shape: Shape,
102    dtype: DType,
103) -> CubeTensor<R> {
104    full_device_dtype(client, shape, device, InputScalar::new(0u32, dtype), dtype)
105}
106
107/// Creates a tensor filled with ones
108pub fn ones_client<R: CubeRuntime>(
109    client: ComputeClient<R>,
110    device: R::Device,
111    shape: Shape,
112    dtype: DType,
113) -> CubeTensor<R> {
114    full_device_dtype(client, shape, device, InputScalar::new(1u32, dtype), dtype)
115}
116
/// Creates a tensor with uninitialized memory
///
/// Typed convenience wrapper: resolves the dtype from `E` and delegates to
/// [`empty_device_dtype`].
pub fn empty_device<R: CubeRuntime, E: CubeElement>(
    client: ComputeClient<R>,
    device: R::Device,
    shape: Shape,
) -> CubeTensor<R> {
    empty_device_dtype(client, device, shape, E::dtype())
}
125
126/// Creates a tensor with uninitialized memory with the specific dtype.
127pub fn empty_device_dtype<R: CubeRuntime>(
128    client: ComputeClient<R>,
129    device: R::Device,
130    shape: Shape,
131    dtype: DType,
132) -> CubeTensor<R> {
133    let buffer = client.empty(shape.num_elements() * dtype.size());
134
135    CubeTensor::new_contiguous(client, device, shape, buffer, dtype)
136}
137
138/// Create a tensor with uninitialized memory
139pub fn empty_device_optimized<R: CubeRuntime, E: CubeElement>(
140    client: ComputeClient<R>,
141    device: R::Device,
142    shape: Shape,
143) -> CubeTensor<R> {
144    let Allocation { handle, strides } = client.empty_tensor(&shape.dims, size_of::<E>());
145
146    CubeTensor::new(client, handle, shape, device, strides, E::dtype())
147}
148
149/// Create a tensor with uninitialized memory
150pub fn empty_device_optimized_dtype<R: CubeRuntime>(
151    client: ComputeClient<R>,
152    device: R::Device,
153    shape: Shape,
154    dtype: DType,
155) -> CubeTensor<R> {
156    let Allocation { handle, strides } = client.empty_tensor(&shape.dims, dtype.size());
157
158    CubeTensor::new(client, handle, shape, device, strides, dtype)
159}
160
/// Add two tensors
///
/// Delegates to [`launch_binop`] with [`AddOp`].
pub fn add<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, AddOp>(lhs, rhs)
}
165
/// Add a tensor and a scalar
///
/// Delegates to [`launch_scalar_binop`] with [`AddOp`].
pub fn add_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, AddOp>(lhs, rhs)
}
170
/// Subtract two tensors
///
/// Delegates to [`launch_binop`] with [`SubOp`].
pub fn sub<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, SubOp>(lhs, rhs)
}
175
/// Subtract a tensor and a scalar
///
/// Delegates to [`launch_scalar_binop`] with [`SubOp`].
pub fn sub_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, SubOp>(lhs, rhs)
}
180
/// Multiply two tensors
///
/// Delegates to [`launch_binop`] with [`MulOp`].
pub fn mul<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, MulOp>(lhs, rhs)
}
185
/// Multiply a tensor and a scalar
///
/// Delegates to [`launch_scalar_binop`] with [`MulOp`].
pub fn mul_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, MulOp>(lhs, rhs)
}
190
/// Divide two tensors
///
/// Delegates to [`launch_binop`] with [`DivOp`].
pub fn div<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, DivOp>(lhs, rhs)
}
195
/// Divide a tensor by a scalar
///
/// Delegates to [`launch_scalar_binop`] with [`DivOp`].
pub fn div_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, DivOp>(lhs, rhs)
}
200
/// Calculate remainder of two tensors
///
/// Delegates to [`launch_binop`] with [`RemainderOp`].
pub fn remainder<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, RemainderOp>(lhs, rhs)
}
205
/// Calculate the remainder of a tensor with a scalar
///
/// Delegates to [`launch_scalar_binop`] with [`RemainderOp`].
pub fn remainder_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop::<R, RemainderOp>(lhs, rhs)
}
210
/// Calculate the power of two tensors
///
/// Delegates to [`launch_binop`] with [`PowOp`].
pub fn pow<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop::<R, PowOp>(lhs, rhs)
}
215
/// Bitwise and two tensors
///
/// Delegates to the integer-only launcher [`launch_binop_int`] with [`BitwiseAndOp`].
pub fn bitwise_and<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop_int::<R, BitwiseAndOp>(lhs, rhs)
}
220
/// Bitwise and with a scalar
///
/// Delegates to the integer-only launcher [`launch_scalar_binop_int`] with [`BitwiseAndOp`].
pub fn bitwise_and_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop_int::<R, BitwiseAndOp>(lhs, rhs)
}
225
/// Bitwise or two tensors
///
/// Delegates to the integer-only launcher [`launch_binop_int`] with [`BitwiseOrOp`].
pub fn bitwise_or<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop_int::<R, BitwiseOrOp>(lhs, rhs)
}
230
/// Bitwise or with a scalar
///
/// Delegates to the integer-only launcher [`launch_scalar_binop_int`] with [`BitwiseOrOp`].
pub fn bitwise_or_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop_int::<R, BitwiseOrOp>(lhs, rhs)
}
235
/// Bitwise xor two tensors
///
/// Delegates to the integer-only launcher [`launch_binop_int`] with [`BitwiseXorOp`].
pub fn bitwise_xor<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
    launch_binop_int::<R, BitwiseXorOp>(lhs, rhs)
}
240
/// Bitwise xor with a scalar
///
/// Delegates to the integer-only launcher [`launch_scalar_binop_int`] with [`BitwiseXorOp`].
pub fn bitwise_xor_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
    launch_scalar_binop_int::<R, BitwiseXorOp>(lhs, rhs)
}
245
/// Operation family trait for cumulative operations
///
/// The generic associated type lets one family select a concrete
/// [`CumulativeOp`] implementation per numeric element type `C`.
pub(crate) trait CumulativeOpFamily: Send + Sync + 'static {
    /// The concrete cumulative operation for element type `C`.
    type CumulativeOp<C: Numeric>: CumulativeOp<C>;
}
250
/// Trait for cumulative operations
#[cube]
pub(crate) trait CumulativeOp<C: Numeric>: 'static + Send + Sync {
    /// Execute a cumulative operation
    ///
    /// Combines the running accumulator (`lhs`) with the next element (`rhs`).
    fn execute(lhs: C, rhs: C) -> C;

    /// Get the initial value for the accumulator
    ///
    /// `first_element` is provided so operations without a natural identity
    /// (min/max below) can seed the accumulator with it.
    fn init_value(first_element: C) -> C;
}
260
// Operation types (zero-sized markers dispatched through `CumulativeOpFamily`)
struct SumOp; // cumulative sum (accumulator seeded with 0)
struct ProdOp; // cumulative product (accumulator seeded with 1)
struct MaxOp; // running maximum (accumulator seeded with the first element)
struct MinOp; // running minimum (accumulator seeded with the first element)
266
// Implement CumulativeOpFamily for each operation
impl CumulativeOpFamily for SumOp {
    // The marker type itself is the op for every numeric element type.
    type CumulativeOp<C: Numeric> = Self;
}
271
impl CumulativeOpFamily for ProdOp {
    // The marker type itself is the op for every numeric element type.
    type CumulativeOp<C: Numeric> = Self;
}
275
impl CumulativeOpFamily for MaxOp {
    // The marker type itself is the op for every numeric element type.
    type CumulativeOp<C: Numeric> = Self;
}
279
impl CumulativeOpFamily for MinOp {
    // The marker type itself is the op for every numeric element type.
    type CumulativeOp<C: Numeric> = Self;
}
283
// Implement CumulativeOp for each operation type
#[cube]
impl<N: Numeric> CumulativeOp<N> for SumOp {
    fn execute(lhs: N, rhs: N) -> N {
        lhs + rhs
    }

    // Additive identity: the scan starts from zero, so the first element is unused.
    fn init_value(_first_element: N) -> N {
        N::from_int(0)
    }
}
295
#[cube]
impl<N: Numeric> CumulativeOp<N> for ProdOp {
    fn execute(lhs: N, rhs: N) -> N {
        lhs * rhs
    }

    // Multiplicative identity: the scan starts from one, so the first element is unused.
    fn init_value(_first_element: N) -> N {
        N::from_int(1)
    }
}
306
#[cube]
impl<N: Numeric> CumulativeOp<N> for MaxOp {
    fn execute(lhs: N, rhs: N) -> N {
        max(lhs, rhs)
    }

    // Max has no identity over all numeric types, so seed with the first
    // element; re-combining it in the scan loop is idempotent.
    fn init_value(first_element: N) -> N {
        first_element
    }
}
317
#[cube]
impl<N: Numeric> CumulativeOp<N> for MinOp {
    fn execute(lhs: N, rhs: N) -> N {
        min(lhs, rhs)
    }

    // Min has no identity over all numeric types, so seed with the first
    // element; re-combining it in the scan loop is idempotent.
    fn init_value(first_element: N) -> N {
        first_element
    }
}
328
/// Generic cumulative operation kernel
///
/// # Limitations
///
/// This is a **naive sequential implementation** along the cumulative dimension:
/// - Each output element sequentially reads all previous elements along the dimension
/// - Computational complexity: O(n^2) memory reads where n is the size of the cumulative dimension
/// - **Performance:** Suitable for small tensors or small dimensions. For large tensors,
///   performance will degrade significantly compared to an optimized parallel scan algorithm.
///
/// # TODO
///
/// Implement an efficient GPU-optimized parallel scan algorithm.
#[cube(launch_unchecked)]
fn cumulative_kernel<C: Numeric, O: CumulativeOpFamily>(
    input: &Tensor<C>,
    output: &mut LinearView<C, ReadWrite>,
    shape: Sequence<FastDivmod<usize>>,
    #[comptime] dim: usize,
    #[define(C)] _dtype: StorageType,
) {
    // Guard against the cube count over-provisioning units past the end.
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let rank = comptime![shape.len()];
    let dim_stride = input.stride(dim);

    // Decompose the linear output position into per-dimension coordinates,
    // accumulating the input offset for every dimension except `dim`.
    let mut remainder = ABSOLUTE_POS;
    let mut offset = 0;
    let mut dim_idx = 0;

    #[unroll]
    for i in 0..shape.len() {
        // Walk dimensions from last to first (innermost coordinate first).
        let i = comptime![rank - i - 1];
        let (rem, local_idx) = shape.index(i).div_mod(remainder);
        remainder = rem;
        if i == dim {
            // Coordinate along the cumulative dimension is kept separate.
            dim_idx = local_idx;
        } else {
            offset += local_idx * input.stride(i);
        }
    }

    // Read first element
    let first_read_idx = offset + dim_idx * dim_stride;
    let first_elem = input[first_read_idx];

    // Initialize accumulator (identity for sum/prod, first element for min/max).
    let mut result = O::CumulativeOp::<C>::init_value(first_elem);

    // Accumulate values from the start of the dimension up to this position.
    for i in 0..=dim_idx {
        let read_idx = offset + i * dim_stride;
        result = O::CumulativeOp::<C>::execute(result, input[read_idx]);
    }
    output[ABSOLUTE_POS] = result;
}
387
/// Compute the cumulative sum along a dimension
///
/// Delegates to the generic [`cumulative_op`] driver with [`SumOp`].
pub fn cumsum<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
    cumulative_op::<R, SumOp>(input, dim)
}
392
/// Compute the cumulative product along a dimension
///
/// Delegates to the generic [`cumulative_op`] driver with [`ProdOp`].
pub fn cumprod<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
    cumulative_op::<R, ProdOp>(input, dim)
}
397
/// Compute the cumulative minimum along a dimension
///
/// Delegates to the generic [`cumulative_op`] driver with [`MinOp`].
pub fn cummin<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
    cumulative_op::<R, MinOp>(input, dim)
}
402
/// Compute the cumulative maximum along a dimension
///
/// Delegates to the generic [`cumulative_op`] driver with [`MaxOp`].
pub fn cummax<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
    cumulative_op::<R, MaxOp>(input, dim)
}
407
408/// Generic cumulative operation function
409fn cumulative_op<R: CubeRuntime, O: CumulativeOpFamily>(
410    input: CubeTensor<R>,
411    dim: usize,
412) -> CubeTensor<R> {
413    let client = input.client.clone();
414    let device = input.device.clone();
415
416    let output = empty_device_dtype(client.clone(), device, input.shape.clone(), input.dtype);
417
418    let num_elems = output.shape.num_elements();
419    let working_units = num_elems;
420    let cube_dim = CubeDim::new(&client, working_units);
421    let cube_count = calculate_cube_count_elemwise(&client, working_units, cube_dim);
422
423    unsafe {
424        cumulative_kernel::launch_unchecked::<O, R>(
425            &client,
426            cube_count,
427            cube_dim,
428            input.as_tensor_arg(1),
429            linear_view(&output, 1),
430            shape_divmod(&input),
431            dim,
432            output.dtype.into(),
433        )
434        .expect("Kernel to never fail");
435    }
436
437    output
438}