// cubecl_reduce/shared_sum.rs

1use cubecl_core::prelude::*;
2use cubecl_core::{self as cubecl};
3use cubecl_runtime::TypeUsage;
4
5use crate::ReduceError;
6
7/// Sum all the elements of the input tensor distributed over `cube_count` cubes.
8///
9/// This is an optimized version for summing large tensors using multiple cubes.
10/// For summing a single axis, the regular [reduce] entry point is preferred.
11///
12/// Return an error if atomic addition is not supported for the type `N`.
13///
14/// # Important
15///
16/// This doesn't set the value of output to 0 before computing the sums.
17/// It is the responsibility of the caller to ensure that output is set to
18/// the proper value. Basically, the behavior of this kernel is akin to the AddAssign operator
19/// as it update the output instead of overwriting it.
20///
21/// # Example
22///
23/// This examples show how to sum all the elements of a small `2 x 2` matrix.
24/// For more details, see the CubeCL documentation.
25///
26/// ```ignore
27/// let client = /* ... */;
28/// let size_f32 = std::mem::size_of::<f32>();
29///
30/// // Create input and output handles.
31/// let input_handle = client.create(f32::as_bytes(&[0, 1, 2, 3]));
32/// let output_handle = client.empty(size_of::<F>());
33/// let input = unsafe {
34///     TensorHandleRef::<R>::from_raw_parts(
35///         &input_handle,
36///         &[2, 1],
37///         &[2, 2],
38///         size_f32,
39///     )
40/// };
41/// let output = unsafe {
42///     TensorHandleRef::<R>::from_raw_parts(&output_handle, &[1], &[1], size_of::<F>())
43/// };
44///
45/// // Here `R` is a `cubecl::Runtime`.
46/// let result = shared_sum::<R, f32>(&client, input, output, cube_count);
47///
48/// if result.is_ok() {
49///        let binding = output_handle.binding();
50///        let bytes = client.read_one(binding);
51///        let output_values = f32::from_bytes(&bytes);
52///        println!("Output = {:?}", output_values); // Should print [6].
53/// }
54/// ```
55pub fn shared_sum<R: Runtime, N: Numeric + CubeElement>(
56    client: &ComputeClient<R::Server>,
57    input: TensorHandleRef<R>,
58    output: TensorHandleRef<R>,
59    cube_count: u32,
60) -> Result<(), ReduceError> {
61    // Check that the client supports atomic addition.
62    let atomic_elem = Atomic::<N>::as_type_native_unchecked();
63    if !client
64        .properties()
65        .type_usage(atomic_elem)
66        .contains(TypeUsage::AtomicAdd)
67    {
68        return Err(ReduceError::MissingAtomicAdd(N::as_type_native_unchecked()));
69    }
70
71    let input_len = input.shape.iter().map(|s| *s as u32).product::<u32>();
72
73    // Compute the optimal line size.
74    let line_size = R::io_optimized_line_sizes_unchecked(size_of::<N>())
75        .filter(|line_size| input_len % *line_size as u32 == 0)
76        .max()
77        .unwrap_or(1) as u32;
78
79    // Compute extra parameters.
80    let cube_dim = CubeDim::new_2d(32, 8); // NOTE: If you change that, keep the unit count a power of 2.
81    let num_units = cube_count * cube_dim.num_elems();
82    let num_lines_per_unit = input_len.div_ceil(num_units * line_size);
83    let cube_count = CubeCount::new_1d(cube_count);
84
85    // Launch kernel
86    unsafe {
87        shared_sum_kernel::launch_unchecked::<N, R>(
88            client,
89            cube_count,
90            cube_dim,
91            input.as_tensor_arg(line_size as u8),
92            output.as_tensor_arg(1),
93            cube_dim.num_elems(),
94            line_size,
95            num_lines_per_unit,
96        );
97    }
98
99    Ok(())
100}
101
/// Kernel backing [`shared_sum`]: each cube computes a partial sum of its share
/// of `input` and atomically adds it into `output[0]`.
///
/// * `shared_memory_size` - number of lines in shared memory; passed as the
///   number of units per cube by the launcher, so each unit owns one
///   accumulator line (indexed by `UNIT_POS`).
/// * `line_size` - vectorization width of `input` lines.
/// * `num_lines_per_unit` - how many input lines each unit sums serially.
#[cube(launch_unchecked)]
fn shared_sum_kernel<N: Numeric>(
    input: &Tensor<Line<N>>,
    output: &mut Tensor<Atomic<N>>,
    #[comptime] shared_memory_size: u32,
    #[comptime] line_size: u32,
    #[comptime] num_lines_per_unit: u32,
) {
    // One zero-initialized accumulator line per unit.
    let mut shared_memory = SharedMemory::new_lined(shared_memory_size, line_size);
    shared_memory[UNIT_POS] = Line::empty(line_size).fill(N::from_int(0));

    // Each unit reduces the contiguous range of `num_lines_per_unit` lines
    // starting at its global (absolute) position.
    let start = ABSOLUTE_POS * num_lines_per_unit;
    let end = start + num_lines_per_unit;

    // Prevent out-of-bound access: clamp both bounds to the input length
    // (units past the end get an empty range).
    let start = select(start < input.len(), start, input.len());
    let end = select(end < input.len(), end, input.len());

    // Each unit sums its lines into its private shared-memory slot.
    for k in start..end {
        shared_memory[UNIT_POS] += input[k];
    }

    // Sum all lines within the shared_memory to a single line
    // (cube-wide tree reduction; includes the necessary syncs).
    let line = sum_shared_memory(&mut shared_memory);

    // Sum all the elements within the line into a single scalar.
    let sum = RuntimeCell::<N>::new(N::from_int(0));
    #[unroll]
    for k in 0..line_size {
        let update = line[k] + sum.read();
        sum.store(update);
    }

    // Add the sum for the current cube to the output. Only unit 0 writes,
    // so `output[0]` receives exactly one atomic add per cube.
    if UNIT_POS == 0 {
        Atomic::add(&output[0], sum.consume());
    }
}
142
// This is a simplified version of [tree_reduce].
// See the documentation there for details.
// Here we assume that `CUBE_DIM` is always a power of two.
//
// Pairwise tree reduction over the per-unit accumulator lines: each round
// halves the number of active units, with the active unit at `UNIT_POS`
// adding the line at `origin` into the line at `destination`. The stride
// (`jump`) doubles each round, so after log2(CUBE_DIM) rounds the total
// lands in `accumulator[0]`.
#[cube]
fn sum_shared_memory<N: Numeric>(accumulator: &mut SharedMemory<Line<N>>) -> Line<N> {
    // Make sure every unit has finished writing its accumulator line
    // before any unit starts reading its neighbors'.
    sync_cube();
    let mut num_active_units = CUBE_DIM;
    let mut jump = 1;
    while num_active_units > 1 {
        num_active_units /= 2;
        // Interleaved addressing: pair (2i, 2i+1) scaled by the current stride.
        let destination = jump * 2 * UNIT_POS;
        let origin = jump * (2 * UNIT_POS + 1);
        if UNIT_POS < num_active_units {
            let element = accumulator[origin];
            accumulator[destination] += element;
        }
        jump *= 2;
        // All writes of this round must be visible before the next round reads.
        sync_cube();
    }
    accumulator[0]
}
163}