// cubecl_reduce/shared_sum.rs

use cubecl_core::prelude::*;
use cubecl_core::{self as cubecl};

use crate::ReduceError;

/// Sum every element of `input` into `output[0]` with a single kernel launch.
/// The caller is expected to provide a zero-initialized, single-element
/// `output`, since the kernel accumulates into it with atomic additions.
/// Returns an error if the client does not support atomic addition on `N`.
pub fn shared_sum<R: Runtime, N: Numeric + CubeElement>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: TensorHandleRef<R>,
    output: TensorHandleRef<R>,
    cube_count: u32,
) -> Result<(), ReduceError> {
    // Check that the client supports both the atomic version of `N` and
    // atomic addition on it; the kernel cannot accumulate without them.
    let atomic_elem = Atomic::<N>::as_elem_native_unchecked();
    if !client
        .properties()
        .feature_enabled(cubecl_core::Feature::Type(atomic_elem))
        || !client
            .properties()
            .feature_enabled(cubecl_core::Feature::AtomicFloat(
                cubecl_core::AtomicFeature::Add,
            ))
    {
        return Err(ReduceError::MissingAtomicAdd(N::as_elem_native_unchecked()));
    }

    let input_len = input.shape.iter().map(|s| *s as u32).product::<u32>();

    // Pick the largest supported line size that evenly divides the input length.
    let elem = N::as_elem_native_unchecked();
    let line_size = R::line_size_elem(&elem)
        .filter(|line_size| input_len % *line_size as u32 == 0)
        .max()
        .unwrap_or(1) as u32;
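    // Illustrative example (assumed values, not from the original source): if
    // `input_len` is 30 and the runtime supports line sizes 4, 2, and 1, the
    // filter keeps 2 and 1 (30 % 4 != 0), so `line_size` becomes 2.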

    let cube_dim = CubeDim::new_2d(32, 8);
    let num_units = cube_count * cube_dim.num_elems();
    let num_lines_per_unit = input_len.div_ceil(num_units * line_size);
    let cube_count = CubeCount::new_1d(cube_count);
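    // Illustrative arithmetic (assumed values): with `cube_count == 4`, the
    // `32 * 8` cube dimension gives `num_units == 4 * 256 == 1024`; for
    // `input_len == 4096` and `line_size == 2`, each unit then handles
    // `4096.div_ceil(1024 * 2) == 2` lines.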

    unsafe {
        // SAFETY: the kernel clamps its read range to `input.len()` and only
        // writes `output[0]`, which the caller must provide.
        shared_sum_kernel::launch_unchecked::<N, R>(
            client,
            cube_count,
            cube_dim,
            input.as_tensor_arg(line_size as u8),
            output.as_tensor_arg(1),
            cube_dim.num_elems(),
            line_size,
            num_lines_per_unit,
        );
    }

    Ok(())
}

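// A minimal usage sketch (illustrative, not part of the original file): calls
// `shared_sum` for `f32` with an assumed cube count of 8. `input` and `output`
// are assumed to be valid handles created elsewhere, with `output` pointing to
// a single zero-initialized element.
#[allow(dead_code)]
fn shared_sum_usage_sketch<R: Runtime>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: TensorHandleRef<R>,
    output: TensorHandleRef<R>,
) -> Result<(), ReduceError> {
    // The result accumulates into `output[0]` via atomic additions.
    shared_sum::<R, f32>(client, input, output, 8)
}
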
/// One-shot sum kernel: each unit accumulates its chunk of lines into shared
/// memory, the cube tree-reduces the shared memory, and unit 0 atomically adds
/// the cube's partial sum into `output[0]`.
#[cube(launch_unchecked)]
fn shared_sum_kernel<N: Numeric>(
    input: &Tensor<Line<N>>,
    output: &mut Tensor<Atomic<N>>,
    #[comptime] shared_memory_size: u32,
    #[comptime] line_size: u32,
    #[comptime] num_lines_per_unit: u32,
) {
    let mut shared_memory = SharedMemory::new_lined(shared_memory_size, line_size);
    shared_memory[UNIT_POS] = Line::empty(line_size).fill(N::from_int(0));

    // Each unit reads `num_lines_per_unit` consecutive lines, with the range
    // clamped so that units past the end of the input read nothing.
    let start = ABSOLUTE_POS * num_lines_per_unit;
    let end = start + num_lines_per_unit;

    let start = select(start < input.len(), start, input.len());
    let end = select(end < input.len(), end, input.len());

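    // Example (assumed values): with `input.len() == 2048` lines and
    // `num_lines_per_unit == 2`, the unit at `ABSOLUTE_POS == 1024` computes
    // `start == 2048`, which the clamp leaves at `input.len()`; its loop below
    // is then empty, so out-of-range units contribute nothing.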
    for k in start..end {
        shared_memory[UNIT_POS] += input[k];
    }

    // Reduce all lines in shared memory down to a single line.
    let line = sum_shared_memory(&mut shared_memory);

    // Sum the `line_size` lanes of the reduced line into a single scalar.
    let sum = RuntimeCell::<N>::new(N::from_int(0));
    #[unroll]
    for k in 0..line_size {
        let update = line[k] + sum.read();
        sum.store(update);
    }

    // One unit per cube publishes the cube's partial sum.
    if UNIT_POS == 0 {
        Atomic::add(&output[0], sum.consume());
    }
}

/// Tree-reduce all lines of `accumulator` in-place and return the total, which
/// ends up in `accumulator[0]`. The halving strategy assumes the cube dimension
/// is a power of two, which holds for the `32 * 8` cube used above.
#[cube]
fn sum_shared_memory<N: Numeric>(accumulator: &mut SharedMemory<Line<N>>) -> Line<N> {
    sync_cube();
    let mut num_active_units = CUBE_DIM;
    let mut jump = 1;
    while num_active_units > 1 {
        num_active_units /= 2;
        let destination = jump * 2 * UNIT_POS;
        let origin = jump * (2 * UNIT_POS + 1);
        if UNIT_POS < num_active_units {
            let element = accumulator[origin];
            accumulator[destination] += element;
        }
        jump *= 2;
        sync_cube();
    }
    accumulator[0]
}
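
// Illustration (not in the original source): with a cube of 4 units the loop
// above runs twice. Pass 1 (jump == 1): unit 0 adds slot 1 into slot 0, and
// unit 1 adds slot 3 into slot 2. Pass 2 (jump == 2): unit 0 adds slot 2 into
// slot 0. Slot 0 then holds the sum of all four lines.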