Skip to main content

burn_cubecl/ops/
numeric.rs

1use crate::{
2    CubeRuntime,
3    kernel::utils::{address_type, shape_divmod},
4};
5use crate::{element::CubeElement, tensor::CubeTensor};
6use crate::{
7    kernel::{
8        AddOp, BitwiseAndOp, BitwiseOrOp, BitwiseXorOp, DivOp, MulOp, PowOp, RemainderOp, SubOp,
9        launch_binop, launch_binop_int, launch_scalar_binop, launch_scalar_binop_int,
10    },
11    ops::max_vector_size,
12};
13use burn_backend::{DType, Shape, TensorMetadata};
14use burn_std::Metadata;
15use cubecl::{calculate_cube_count_elemwise, prelude::*};
16use cubecl::{client::ComputeClient, server::MemoryLayout};
17use cubecl::{
18    server::MemoryLayoutDescriptor,
19    std::{FastDivmod, tensor::layout::linear::LinearView},
20};
21
22/// Creates a tensor filled with `value`
23pub fn full<R: CubeRuntime, E: CubeElement>(
24    shape: Shape,
25    device: &R::Device,
26    value: E,
27) -> CubeTensor<R> {
28    let client = R::client(device);
29
30    full_client::<R, E>(client, shape, device.clone(), value)
31}
32
33/// Creates a tensor filled with `value`
34pub fn full_client<R: CubeRuntime, E: CubeElement>(
35    client: ComputeClient<R>,
36    shape: Shape,
37    device: R::Device,
38    value: E,
39) -> CubeTensor<R> {
40    let dtype = E::dtype();
41    full_device_dtype(client, shape, device, InputScalar::new(value, dtype), dtype)
42}
43
/// Creates a tensor of the given `shape` and `dtype` on `device` with every element set
/// to `value`.
///
/// The output is allocated with [`empty_device_dtype`] (uninitialized) and then filled by
/// a vectorized kernel launched with one work unit per vector of `vector_size` elements.
pub fn full_device_dtype<R: CubeRuntime>(
    client: ComputeClient<R>,
    shape: Shape,
    device: R::Device,
    value: InputScalar,
    dtype: DType,
) -> CubeTensor<R> {
    let empty = empty_device_dtype(client, device, shape, dtype);

    // Fill kernel: each work unit stores `Vector::new(value.get::<C>())` at its position in
    // the linear view. Units whose position is outside the view exit without writing.
    #[cube(launch_unchecked, address_type = "dynamic")]
    pub fn full_kernel<C: Numeric, N: Size>(
        tensor: &mut LinearView<Vector<C, N>, ReadWrite>,
        value: InputScalar,
        #[define(C)] _dtype: StorageType,
    ) {
        if !tensor.is_in_bounds(ABSOLUTE_POS) {
            terminate!();
        }

        tensor[ABSOLUTE_POS] = Vector::new(value.get::<C>());
    }

    let num_elems = empty.meta.num_elements();
    let vector_size = max_vector_size(&empty);

    // One work unit per vector of `vector_size` elements.
    // NOTE(review): the integer division assumes `max_vector_size` returns a size that
    // evenly divides `num_elems`; confirm, otherwise trailing elements may stay unwritten.
    let working_units = num_elems / vector_size as usize;
    let cube_dim = CubeDim::new(&empty.client, working_units);
    let cube_count = calculate_cube_count_elemwise(&empty.client, working_units, cube_dim);

    // SAFETY: `full_kernel` checks `tensor.is_in_bounds(ABSOLUTE_POS)` before writing, so
    // surplus work units created by cube-count rounding never write past the view.
    unsafe {
        full_kernel::launch_unchecked(
            &empty.client,
            cube_count,
            cube_dim,
            address_type!(empty),
            vector_size,
            empty.clone().into_linear_view(),
            value,
            empty.dtype.into(),
        );
    }

    empty
}
89
90/// Creates a tensor filled with zeros
91pub fn zeros<R: CubeRuntime>(device: R::Device, shape: Shape, dtype: DType) -> CubeTensor<R> {
92    let client = R::client(&device);
93    full_device_dtype(client, shape, device, InputScalar::new(0u32, dtype), dtype)
94}
95
96/// Creates a tensor filled with ones
97pub fn ones<R: CubeRuntime>(device: R::Device, shape: Shape, dtype: DType) -> CubeTensor<R> {
98    let client = R::client(&device);
99    full_device_dtype(client, shape, device, InputScalar::new(1u32, dtype), dtype)
100}
101
102/// Creates a tensor filled with zeros
103pub fn zeros_client<R: CubeRuntime>(
104    client: ComputeClient<R>,
105    device: R::Device,
106    shape: Shape,
107    dtype: DType,
108) -> CubeTensor<R> {
109    full_device_dtype(client, shape, device, InputScalar::new(0u32, dtype), dtype)
110}
111
112/// Creates a tensor filled with ones
113pub fn ones_client<R: CubeRuntime>(
114    client: ComputeClient<R>,
115    device: R::Device,
116    shape: Shape,
117    dtype: DType,
118) -> CubeTensor<R> {
119    full_device_dtype(client, shape, device, InputScalar::new(1u32, dtype), dtype)
120}
121
122/// Create a tensor with uninitialized memory
123pub fn empty_device<R: CubeRuntime, E: CubeElement>(
124    client: ComputeClient<R>,
125    device: R::Device,
126    shape: Shape,
127) -> CubeTensor<R> {
128    let MemoryLayout { memory, strides } = client.empty_tensor(shape.clone(), size_of::<E>());
129
130    CubeTensor::new(
131        client,
132        memory,
133        Metadata::new(shape, strides),
134        device,
135        E::dtype(),
136    )
137}
138
139/// Create a tensor with uninitialized memory
140pub fn empty_device_dtype<R: CubeRuntime>(
141    client: ComputeClient<R>,
142    device: R::Device,
143    shape: Shape,
144    dtype: DType,
145) -> CubeTensor<R> {
146    let MemoryLayout { memory, strides } = client.empty_tensor(shape.clone(), dtype.size());
147
148    CubeTensor::new(client, memory, Metadata::new(shape, strides), device, dtype)
149}
150
151/// Create a contiguous tensor with uninitialized memory
152pub fn empty_device_contiguous_dtype<R: CubeRuntime>(
153    client: ComputeClient<R>,
154    device: R::Device,
155    shape: Shape,
156    dtype: DType,
157) -> CubeTensor<R> {
158    let descriptor = MemoryLayoutDescriptor::contiguous(shape.clone(), dtype.size());
159    let MemoryLayout { memory, strides } = client.empty_tensors(vec![descriptor]).remove(0);
160
161    CubeTensor::new(client, memory, Metadata::new(shape, strides), device, dtype)
162}
163
164/// Add two tensors
165pub fn add<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
166    launch_binop::<R, AddOp>(lhs, rhs)
167}
168
169/// Add a tensor and a scalar
170pub fn add_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
171    launch_scalar_binop::<R, AddOp>(lhs, rhs)
172}
173
174/// Subtract two tensors
175pub fn sub<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
176    launch_binop::<R, SubOp>(lhs, rhs)
177}
178
179/// Subtract a tensor and a scalar
180pub fn sub_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
181    launch_scalar_binop::<R, SubOp>(lhs, rhs)
182}
183
184/// Multiply two tensors
185pub fn mul<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
186    launch_binop::<R, MulOp>(lhs, rhs)
187}
188
189/// Multiply a tensor and a scalar
190pub fn mul_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
191    launch_scalar_binop::<R, MulOp>(lhs, rhs)
192}
193
194/// Divide two tensors
195pub fn div<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
196    launch_binop::<R, DivOp>(lhs, rhs)
197}
198
199/// Divide a tensor by a scalar
200pub fn div_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
201    launch_scalar_binop::<R, DivOp>(lhs, rhs)
202}
203
204/// Calculate remainder of two tensors
205pub fn remainder<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
206    launch_binop::<R, RemainderOp>(lhs, rhs)
207}
208
209/// Calculate the remainder of a tensor with a scalar
210pub fn remainder_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
211    launch_scalar_binop::<R, RemainderOp>(lhs, rhs)
212}
213
214/// Calculate the power of two tensors
215pub fn pow<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
216    launch_binop::<R, PowOp>(lhs, rhs)
217}
218
219/// Bitwise and two tensors
220pub fn bitwise_and<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
221    launch_binop_int::<R, BitwiseAndOp>(lhs, rhs)
222}
223
224/// Bitwise and with a scalar
225pub fn bitwise_and_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
226    launch_scalar_binop_int::<R, BitwiseAndOp>(lhs, rhs)
227}
228
229/// Bitwise or two tensors
230pub fn bitwise_or<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
231    launch_binop_int::<R, BitwiseOrOp>(lhs, rhs)
232}
233
234/// Bitwise or with a scalar
235pub fn bitwise_or_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
236    launch_scalar_binop_int::<R, BitwiseOrOp>(lhs, rhs)
237}
238
239/// Bitwise xor two tensors
240pub fn bitwise_xor<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: CubeTensor<R>) -> CubeTensor<R> {
241    launch_binop_int::<R, BitwiseXorOp>(lhs, rhs)
242}
243
244/// Bitwise xor with a scalar
245pub fn bitwise_xor_scalar<R: CubeRuntime>(lhs: CubeTensor<R>, rhs: InputScalar) -> CubeTensor<R> {
246    launch_scalar_binop_int::<R, BitwiseXorOp>(lhs, rhs)
247}
248
249/// Operation family trait for cumulative operations
250pub(crate) trait CumulativeOpFamily: Send + Sync + 'static {
251    type CumulativeOp<C: Numeric>: CumulativeOp<C>;
252}
253
/// A binary accumulation step used by `cumulative_kernel`.
///
/// The kernel seeds an accumulator with `init_value` and then folds `execute` over every
/// element from index `0` up to and including the current index along the cumulative
/// dimension.
#[cube]
pub(crate) trait CumulativeOp<C: Numeric>: 'static + Send + Sync {
    /// Combines the running accumulator `lhs` with the next input element `rhs`.
    fn execute(lhs: C, rhs: C) -> C;

    /// Returns the starting accumulator value, given an element taken from the prefix.
    ///
    /// The kernel still folds `execute` over the whole prefix after seeding, so the
    /// returned value must not alter the result. It should be either an identity element
    /// (zero for sum, one for product) or an element already in the prefix, which works
    /// for idempotent operations such as min/max.
    fn init_value(first_element: C) -> C;
}
263
// Marker types selecting a cumulative operation. Each one implements both
// `CumulativeOpFamily` and `CumulativeOp` further down in this file.
/// Marker for the cumulative sum.
struct SumOp;
/// Marker for the cumulative product.
struct ProdOp;
/// Marker for the cumulative maximum.
struct MaxOp;
/// Marker for the cumulative minimum.
struct MinOp;
269
270// Implement CumulativeOpFamily for each operation
271impl CumulativeOpFamily for SumOp {
272    type CumulativeOp<C: Numeric> = Self;
273}
274
275impl CumulativeOpFamily for ProdOp {
276    type CumulativeOp<C: Numeric> = Self;
277}
278
279impl CumulativeOpFamily for MaxOp {
280    type CumulativeOp<C: Numeric> = Self;
281}
282
283impl CumulativeOpFamily for MinOp {
284    type CumulativeOp<C: Numeric> = Self;
285}
286
// `CumulativeOp` implementations, one per marker type.
// Cumulative sum: adds each element to the accumulator, starting from zero.
#[cube]
impl<N: Numeric> CumulativeOp<N> for SumOp {
    fn execute(lhs: N, rhs: N) -> N {
        lhs + rhs
    }

    // The seed element is ignored: zero is the additive identity, so folding `execute`
    // over the whole prefix yields the sum.
    fn init_value(_first_element: N) -> N {
        N::zero()
    }
}
298
#[cube]
impl<N: Numeric> CumulativeOp<N> for ProdOp {
    // Cumulative product: multiplies each element into the accumulator.
    fn execute(lhs: N, rhs: N) -> N {
        lhs * rhs
    }

    // The seed element is ignored: one is the multiplicative identity, so folding
    // `execute` over the whole prefix yields the product.
    fn init_value(_first_element: N) -> N {
        N::from_int(1)
    }
}
309
#[cube]
impl<N: Numeric> CumulativeOp<N> for MaxOp {
    // Cumulative maximum: keeps the larger of the accumulator and the next element.
    fn execute(lhs: N, rhs: N) -> N {
        max(lhs, rhs)
    }

    // Seeds with an element taken from the prefix. `max` is idempotent, so folding over
    // the prefix again (which includes that element) does not change the result.
    fn init_value(first_element: N) -> N {
        first_element
    }
}
320
#[cube]
impl<N: Numeric> CumulativeOp<N> for MinOp {
    // Cumulative minimum: keeps the smaller of the accumulator and the next element.
    fn execute(lhs: N, rhs: N) -> N {
        min(lhs, rhs)
    }

    // Seeds with an element taken from the prefix. `min` is idempotent, so folding over
    // the prefix again (which includes that element) does not change the result.
    fn init_value(first_element: N) -> N {
        first_element
    }
}
331
/// Generic cumulative operation kernel.
///
/// Each work unit produces one output element. It maps `ABSOLUTE_POS` to coordinates in
/// `input`, then folds `O`'s operation over every input element along `dim` from index `0`
/// up to and including its own index.
///
/// # Arguments
///
/// * `input` - source tensor, addressed through its strides.
/// * `output` - destination, addressed linearly by `ABSOLUTE_POS`.
/// * `shape` - per-dimension divisors of the input shape, as built by `shape_divmod` on the host.
/// * `dim` - (comptime) the dimension along which the operation accumulates.
/// * `_dtype` - runtime storage type that defines `C`.
///
/// # Limitations
///
/// This is a **naive sequential implementation** along the cumulative dimension:
/// - Each output element sequentially reads all previous elements along the dimension
/// - Computational complexity: O(n^2) memory reads where n is the size of the cumulative dimension
/// - **Performance:** Suitable for small tensors or small dimensions. For large tensors,
///   performance will degrade significantly compared to an optimized parallel scan algorithm.
///
/// # TODO
///
/// Implement an efficient GPU-optimized parallel scan algorithm.
#[cube(launch_unchecked, address_type = "dynamic")]
fn cumulative_kernel<C: Numeric, O: CumulativeOpFamily>(
    input: &Tensor<C>,
    output: &mut LinearView<C, ReadWrite>,
    shape: Sequence<FastDivmod<usize>>,
    #[comptime] dim: usize,
    #[define(C)] _dtype: StorageType,
) {
    // Surplus work units (from cube-count rounding) exit without writing.
    if !output.is_in_bounds(ABSOLUTE_POS) {
        terminate!();
    }

    let rank = comptime![shape.len()];
    let dim_stride = input.stride(dim);

    let mut remainder = ABSOLUTE_POS;
    let mut offset = 0;
    let mut dim_idx = 0;

    // Decompose the linear position into per-dimension coordinates, starting from the last
    // (innermost) dimension. The coordinate along `dim` is kept in `dim_idx`. All other
    // coordinates are folded into the strided base `offset` into `input`.
    #[unroll]
    for i in 0..shape.len() {
        let i = comptime![rank - i - 1];
        // The first component carries on to the outer dimensions; the second is the
        // coordinate along dimension `i`.
        let (rem, local_idx) = shape.index(i).div_mod(remainder);
        remainder = rem;
        if i == dim {
            dim_idx = local_idx;
        } else {
            offset += local_idx * input.stride(i);
        }
    }

    // Read the seed element. Despite the name, this is the element at this unit's own
    // index `dim_idx` along `dim`, not index 0. That is sound: sum/prod ignore the seed,
    // and for min/max the element is part of the prefix folded below.
    let first_read_idx = offset + dim_idx * dim_stride;
    let first_elem = input[first_read_idx];

    // Initialize accumulator
    let mut result = O::CumulativeOp::<C>::init_value(first_elem);

    // Accumulate values over indices 0..=dim_idx along `dim` (inclusive of this position).
    for i in 0..=dim_idx {
        let read_idx = offset + i * dim_stride;
        result = O::CumulativeOp::<C>::execute(result, input[read_idx]);
    }
    output[ABSOLUTE_POS] = result;
}
390
391/// Compute the cumulative sum along a dimension
392pub fn cumsum<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
393    cumulative_op::<R, SumOp>(input, dim)
394}
395
396/// Compute the cumulative product along a dimension
397pub fn cumprod<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
398    cumulative_op::<R, ProdOp>(input, dim)
399}
400
401/// Compute the cumulative minimum along a dimension
402pub fn cummin<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
403    cumulative_op::<R, MinOp>(input, dim)
404}
405
406/// Compute the cumulative maximum along a dimension
407pub fn cummax<R: CubeRuntime>(input: CubeTensor<R>, dim: usize) -> CubeTensor<R> {
408    cumulative_op::<R, MaxOp>(input, dim)
409}
410
411/// Generic cumulative operation function
412fn cumulative_op<R: CubeRuntime, O: CumulativeOpFamily>(
413    input: CubeTensor<R>,
414    dim: usize,
415) -> CubeTensor<R> {
416    let client = input.client.clone();
417    let device = input.device.clone();
418
419    let output = empty_device_dtype(client.clone(), device, input.shape(), input.dtype);
420
421    let num_elems = output.meta.num_elements();
422    let working_units = num_elems;
423    let cube_dim = CubeDim::new(&client, working_units);
424    let cube_count = calculate_cube_count_elemwise(&client, working_units, cube_dim);
425    let shape = shape_divmod(&input);
426
427    unsafe {
428        cumulative_kernel::launch_unchecked::<O, R>(
429            &client,
430            cube_count,
431            cube_dim,
432            address_type!(input, output),
433            input.into_tensor_arg(),
434            output.clone().into_linear_view(),
435            shape,
436            dim,
437            output.dtype.into(),
438        );
439    }
440
441    output
442}