cubecl_reduce/config.rs

use cubecl_core::{
    prelude::*, server::ComputeServer, tensor_line_size_parallel, tensor_line_size_perpendicular,
};
use cubecl_std::tensor::is_contiguous;

use crate::ReduceStrategy;

// TODO: Should we allow the user to change this?
const DEFAULT_PLANE_COUNT: u32 = 8;

#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
/// Whether lines run along the reduced axis (parallel) or across it (perpendicular).
pub enum LineMode {
    Parallel,
    Perpendicular,
}

#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
/// How bound checks are handled for inner reductions.
pub enum BoundChecksInner {
    /// No bound check is necessary.
    None,
    /// Using a mask is enough for bound checks.
    /// This will still read memory at an out-of-bounds location,
    /// but will replace the value with the null value.
    Mask,
    /// Branching is necessary for bound checks.
    ///
    /// Probably the right setting when performing fuse on read.
    Branch,
}
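// Illustrative sketch (not actual kernel code) of the difference between the two variants
// above, for a read at `index` into a buffer of `length` elements:
//
//   Mask:   let value = select(index < length, input[index], null_value); // the read still happens
//   Branch: let value = if index < length { input[index] } else { null_value }; // the read is skipped
//
// Masking keeps the control flow uniform, while branching avoids the out-of-bounds read
// entirely, which matters when the read is fused with other operations.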

#[derive(Debug, Clone)]
pub struct ReduceConfig {
    pub cube_count: CubeCount,
    pub cube_dim: CubeDim,
    pub line_mode: LineMode,
    pub line_size_input: u32,
    pub line_size_output: u32,
    pub bound_checks: bool,
    pub bound_checks_inner: BoundChecksInner,
}

impl ReduceConfig {
    pub(crate) fn generate<R: Runtime>(
        client: &ComputeClient<R::Server>,
        input: &TensorHandleRef<R>,
        output: &TensorHandleRef<R>,
        axis: usize,
        strategy: &ReduceStrategy,
        dtype: StorageType,
    ) -> ReduceConfig {
        let reduce_count = output.size() as u32;
        ReduceConfig::new()
            .generate_line_mode(input, axis)
            .generate_line_size::<R>(input, output, axis, dtype)
            .generate_cube_dim(client, strategy.use_planes)
            .generate_cube_count::<R>(reduce_count, strategy)
    }

    fn new() -> Self {
        // This is only a dummy configuration to use as a starting point.
        Self {
            cube_count: CubeCount::new_single(),
            cube_dim: CubeDim::new_single(),
            line_mode: LineMode::Parallel,
            line_size_input: 1,
            line_size_output: 1,
            bound_checks: true,
            bound_checks_inner: BoundChecksInner::Mask,
        }
    }

    fn generate_line_mode<R: Runtime>(mut self, input: &TensorHandleRef<R>, axis: usize) -> Self {
        let stride = input.strides[axis];
        self.line_mode = if stride == 1 {
            LineMode::Parallel
        } else {
            LineMode::Perpendicular
        };
        self
    }

    fn generate_line_size<R: Runtime>(
        mut self,
        input: &TensorHandleRef<R>,
        output: &TensorHandleRef<R>,
        axis: usize,
        dtype: StorageType,
    ) -> Self {
        let supported_line_sizes = R::io_optimized_line_sizes_unchecked(dtype.size());
        self.line_size_input = match self.line_mode {
            LineMode::Parallel => {
                tensor_line_size_parallel(supported_line_sizes, input.shape, input.strides, axis)
                    as u32
            }
            LineMode::Perpendicular => {
                // To compute the maximum line size we can use,
                // we first sort both the input and output axes by increasing stride.
                // As an example, consider
                //    input shape = [2, 4, 6, 8]
                //    input strides = [1, 16, 64, 2]
                //    output shape = [2, 1, 6, 8]
                //    output strides = [1, 1, 2, 12]
                //    axis = 1
                //
                // then we have
                //    input sorted axes = [0, 3, 1, 2]
                //    output sorted axes = [0, 1, 2, 3]
                //
                // From that point, we look at all the axes before the target axis in the sorted input.
                // That is [0, 3] in the example.
                // In the output, we remove the target axis, leading to [0, 2, 3] in the example.
                //
                // In order to use perpendicular lines, we are limited by the number of entries that are
                // contiguous in both the input and the output. This is obtained by taking the head of each list until they differ.
                // In the above example, only axis 0 is contiguous in both tensors, but if the output sorted axes were [0, 1, 3, 2] instead,
                // both axes 0 and 3 would be contiguous in the two tensors.
                // The corresponding number of entries is the product of the shapes of the contiguous axes.
                // In the example, it is simply 2.
                //
                // This gives us an upper bound on the line size we can use.
                // Then, we use the regular method to find the best line size that matches the device capabilities.
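                // In the example, max_line_size is 2, so the filter below only keeps supported
                // line sizes that divide 2, namely 1 and 2.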

                let mut input_axis_and_strides =
                    input.strides.iter().enumerate().collect::<Vec<_>>();
                input_axis_and_strides.sort_by_key(|(_, stride)| *stride);
                let input_sorted_axis = input_axis_and_strides
                    .into_iter()
                    .map(|(a, _)| a)
                    .take_while(|a| *a != axis);

                let mut output_axis_and_strides =
                    output.strides.iter().enumerate().collect::<Vec<_>>();
                output_axis_and_strides.sort_by_key(|(_, stride)| *stride);
                let output_sorted_axis = output_axis_and_strides
                    .into_iter()
                    .filter_map(|(a, _)| (a != axis).then_some(a));

                let max_line_size = input_sorted_axis
                    .zip(output_sorted_axis)
                    .filter_map(|(i, o)| (i == o).then_some(output.shape[i]))
                    .product();

                tensor_line_size_perpendicular(
                    supported_line_sizes.filter(|size| {
                        *size as usize <= max_line_size && max_line_size % *size as usize == 0
                    }),
                    input.shape,
                    input.strides,
                    axis,
                ) as u32
            }
        };

        if self.line_size_input > 1 && self.line_mode == LineMode::Perpendicular {
            // TODO: This can be improved.
            let rank = output.strides.len();
            let is_contiguous =
                is_contiguous(&output.shape[axis..rank], &output.strides[axis..rank])
                    && output.strides[rank - 1] == 1;
            let shape = output.shape.get(axis + 1).cloned().unwrap_or(1) as u32;

            if is_contiguous && shape.is_multiple_of(self.line_size_input) {
                self.line_size_output = self.line_size_input;
            }
        }
        self
    }

    pub fn generate_cube_dim<S: ComputeServer>(
        mut self,
        client: &ComputeClient<S>,
        use_planes: bool,
    ) -> Self {
        let hw_properties = &client.properties().hardware;

        let plane_dim = if use_planes {
            hw_properties.plane_size_min
        } else {
            hw_properties.plane_size_max
        };

        let plane_count = if plane_dim * DEFAULT_PLANE_COUNT > hw_properties.max_units_per_cube {
            hw_properties.max_units_per_cube / plane_dim
        } else {
            DEFAULT_PLANE_COUNT
        };
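        // As an illustration, with a plane size of 32 and a limit of 1024 units per cube
        // (typical values, but device-dependent), 32 * DEFAULT_PLANE_COUNT = 256 <= 1024,
        // so the default of 8 planes is kept; with a limit of 128 units per cube, the
        // plane count would instead shrink to 128 / 32 = 4.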

        self.cube_dim = CubeDim::new_2d(plane_dim, plane_count);
        self
    }

    pub fn generate_cube_count<R: Runtime>(
        mut self,
        reduce_count: u32,
        strategy: &ReduceStrategy,
    ) -> Self {
        // An agent is either a unit, a plane, or a whole cube depending on the strategy.
        let agent_count_per_cube = match strategy {
            ReduceStrategy { shared: true, .. } => 1,
            ReduceStrategy { use_planes: true, .. } => self.cube_dim.y,
            ReduceStrategy { use_planes: false, .. } => self.cube_dim.num_elems(),
        };
        let reduce_count_per_cube = match self.line_mode {
            LineMode::Parallel => agent_count_per_cube,
            LineMode::Perpendicular => agent_count_per_cube * self.line_size_input,
        };
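        // For instance, with a cube dim of 32 x 8 (256 units, illustrative values), a shared
        // strategy performs 1 reduction per cube, a plane strategy 8 (one per plane), and a
        // plain unit strategy 256; in perpendicular mode each agent additionally covers
        // `line_size_input` reductions at once.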

        let cube_count = reduce_count.div_ceil(reduce_count_per_cube);

        self.do_bound_checks_if(reduce_count_per_cube * cube_count > reduce_count);

        // If needed, we decompose the cube count to stay within the runtime's limits.
        let (max_x, max_y, _) = R::max_cube_count();
        let mut cube_count_x = cube_count;
        let mut cube_count_y = 1;
        let mut cube_count_z = 1;
        while cube_count_x > max_x {
            cube_count_x /= 2;
            cube_count_y *= 2;
        }
        while cube_count_y > max_y {
            cube_count_y /= 2;
            cube_count_z *= 2;
        }
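        // Because the halving uses integer division, cube_count_x * cube_count_y can end up
        // different from the original cube_count (for example, with an assumed max_x of
        // 65_535, a cube_count of 70_001 becomes 35_000 x 2 = 70_000), in which case bound
        // checks are also enabled below.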
        self.cube_count = CubeCount::new_3d(cube_count_x, cube_count_y, cube_count_z);
        self.do_bound_checks_if(cube_count_x * cube_count_y != cube_count);

        self
    }

    fn do_bound_checks_if(&mut self, condition: bool) {
        self.bound_checks = self.bound_checks || condition;
    }
}