cubecl_reduce/shared_sum.rs
use cubecl_core::prelude::*;
use cubecl_core::{self as cubecl};

use crate::ReduceError;

/// Sum all the elements of the input tensor distributed over `cube_count` cubes.
///
/// This is an optimized version for summing large tensors using multiple cubes.
/// For summing a single axis, the regular [reduce] entry point is preferred.
///
/// Returns an error if atomic addition is not supported for the type `N`.
///
/// # Important
///
/// This doesn't set the value of the output to 0 before computing the sum.
/// It is the responsibility of the caller to ensure that the output is set to
/// the proper value. Basically, the behavior of this kernel is akin to the AddAssign operator,
/// as it updates the output instead of overwriting it.
///
/// # Example
///
/// This example shows how to sum all the elements of a small `2 x 2` matrix.
/// For more details, see the CubeCL documentation.
///
/// ```ignore
/// let client = /* ... */;
/// let size_f32 = std::mem::size_of::<f32>();
///
/// // Create input and output handles.
/// let input_handle = client.create(f32::as_bytes(&[0.0, 1.0, 2.0, 3.0]));
/// // Initialize the output to 0 since the kernel adds to it instead of overwriting it.
/// let output_handle = client.create(f32::as_bytes(&[0.0]));
/// let input = unsafe {
///     TensorHandleRef::<R>::from_raw_parts(
///         &input_handle,
///         &[2, 1],
///         &[2, 2],
///         size_f32,
///     )
/// };
/// let output = unsafe {
///     TensorHandleRef::<R>::from_raw_parts(&output_handle, &[1], &[1], size_f32)
/// };
///
/// // Here `R` is a `cubecl::Runtime` and `cube_count` is the number of cubes to launch.
/// let result = shared_sum::<R, f32>(&client, input, output, cube_count);
///
/// if result.is_ok() {
///     let binding = output_handle.binding();
///     let bytes = client.read_one(binding);
///     let output_values = f32::from_bytes(&bytes);
///     println!("Output = {:?}", output_values); // Should print [6.0].
/// }
/// ```
pub fn shared_sum<R: Runtime, N: Numeric + CubeElement>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: TensorHandleRef<R>,
    output: TensorHandleRef<R>,
    cube_count: u32,
) -> Result<(), ReduceError> {
    // Check that the client supports atomic addition.
    let atomic_elem = Atomic::<N>::as_elem_native_unchecked();
    if !client
        .properties()
        .feature_enabled(cubecl_core::Feature::Type(atomic_elem))
        || !client
            .properties()
            .feature_enabled(cubecl_core::Feature::AtomicFloat(
                cubecl_core::AtomicFeature::Add,
            ))
    {
        return Err(ReduceError::MissingAtomicAdd(N::as_elem_native_unchecked()));
    }

    let input_len = input.shape.iter().map(|s| *s as u32).product::<u32>();

    // Compute the optimal line size.
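    // This picks the largest line (vector) width supported by the runtime that evenly
    // divides the number of elements, falling back to 1 (scalar loads) when none does.
    // For example (illustrative only, supported widths vary by runtime), widths {4, 2, 1}
    // with an input of 1024 elements would give a line size of 4.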
    let elem = N::as_elem_native_unchecked();
    let line_size = R::line_size_elem(&elem)
        .filter(|line_size| input_len % *line_size as u32 == 0)
        .max()
        .unwrap_or(1) as u32;

    // Compute extra parameters.
    let cube_dim = CubeDim::new_2d(32, 8); // NOTE: If you change that, keep the unit count a power of 2.
    let num_units = cube_count * cube_dim.num_elems();
    let num_lines_per_unit = input_len.div_ceil(num_units * line_size);
    let cube_count = CubeCount::new_1d(cube_count);
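    // As an illustration (the numbers are only an example): with `cube_count = 8`, the
    // 32 x 8 cube dim gives 256 units per cube, 2048 units in total; for an input of
    // 1_000_000 elements and a line size of 4, each unit sums
    // ceil(1_000_000 / (2048 * 4)) = 123 lines.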

    // Launch kernel.
    unsafe {
        shared_sum_kernel::launch_unchecked::<N, R>(
            client,
            cube_count,
            cube_dim,
            input.as_tensor_arg(line_size as u8),
            output.as_tensor_arg(1),
            cube_dim.num_elems(),
            line_size,
            num_lines_per_unit,
        );
    }

    Ok(())
}

#[cube(launch_unchecked)]
fn shared_sum_kernel<N: Numeric>(
    input: &Tensor<Line<N>>,
    output: &mut Tensor<Atomic<N>>,
    #[comptime] shared_memory_size: u32,
    #[comptime] line_size: u32,
    #[comptime] num_lines_per_unit: u32,
) {
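    // One line of shared memory per unit in the cube (`shared_memory_size` is the number
    // of units), zero-initialized so that units with an empty range still contribute a
    // neutral element to the reduction.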
    let mut shared_memory = SharedMemory::new_lined(shared_memory_size, line_size);
    shared_memory[UNIT_POS] = Line::empty(line_size).fill(N::from_int(0));

    // Each unit reduces `num_lines_per_unit` lines.
    let start = ABSOLUTE_POS * num_lines_per_unit;
    let end = start + num_lines_per_unit;
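    // Units own contiguous blocks of lines: with `num_lines_per_unit = 123` (an
    // illustrative value), unit 0 covers lines 0..123, unit 1 covers 123..246, and so on.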

    // Prevent out-of-bound access.
    let start = select(start < input.len(), start, input.len());
    let end = select(end < input.len(), end, input.len());
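    // The last units may fall past the end of the input when the number of lines is not
    // an exact multiple of the unit count; clamping both bounds turns their range into a
    // shortened or empty loop instead of an out-of-bound read.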

    // Each unit sums its lines.
    for k in start..end {
        shared_memory[UNIT_POS] += input[k];
    }

    // Sum all lines within the shared memory to a single line.
    let line = sum_shared_memory(&mut shared_memory);

    // Sum all the elements within the line.
    let mut sum = N::from_int(0);
    #[unroll]
    for k in 0..line_size {
        sum += line[k];
    }

    // Add the sum for the current cube to the output.
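    // Only unit 0 of each cube performs the atomic add, so there is exactly one atomic
    // operation per cube regardless of the cube dimension.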
    if UNIT_POS == 0 {
        Atomic::add(&output[0], sum);
    }
}

// This is a simplified version of [tree_reduce].
// See the documentation there for details.
// Here we assume that `CUBE_DIM` is always a power of two.
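// As a small trace (illustrative), with CUBE_DIM = 4 and shared memory [a, b, c, d]:
//   pass 1 (jump = 1): unit 0 adds slot 1 into slot 0, unit 1 adds slot 3 into slot 2
//                      -> [a+b, b, c+d, d]
//   pass 2 (jump = 2): unit 0 adds slot 2 into slot 0 -> [a+b+c+d, b, c+d, d]
// so `accumulator[0]` ends up holding the sum of all lines.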
#[cube]
fn sum_shared_memory<N: Numeric>(accumulator: &mut SharedMemory<Line<N>>) -> Line<N> {
    sync_units();
    let mut num_active_units = CUBE_DIM;
    let mut jump = 1;
    while num_active_units > 1 {
        num_active_units /= 2;
        let destination = jump * 2 * UNIT_POS;
        let origin = jump * (2 * UNIT_POS + 1);
        if UNIT_POS < num_active_units {
            let element = accumulator[origin];
            accumulator[destination] += element;
        }
        jump *= 2;
        sync_units();
    }
    accumulator[0]
}