cubecl_reduce/
config.rs

1use cubecl_core::{
2    channel::ComputeChannel, prelude::*, server::ComputeServer, tensor_line_size_parallel,
3    tensor_line_size_perpendicular,
4};
5use cubecl_std::tensor::is_contiguous;
6
7use crate::ReduceStrategy;
8
// TODO: Should we allow the user to change this?
/// Default cube shape (32 x 8 units) used when planes are not requested.
const DEFAULT_CUBE_DIM: CubeDim = CubeDim::new_2d(32, 8);
/// Default number of planes per cube when the plane strategy is used.
const DEFAULT_PLANE_COUNT: u32 = 8;
12
/// Orientation of the vectorized lines relative to the reduced axis.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum LineMode {
    /// Lines run along the reduced axis (chosen when that axis has stride 1).
    Parallel,
    /// Lines run along an axis other than the reduced one.
    Perpendicular,
}
18
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
/// How bound checks are handled for inner reductions.
pub enum BoundChecksInner {
    /// No bound check is necessary.
    None,
    /// Using a mask is enough for bound checks.
    /// This will still read the memory in an out-of-bound location,
    /// but will replace the value with the null value.
    Mask,
    /// Branching is necessary for bound checks.
    ///
    /// Probably the right setting when performing fuse on read.
    Branch,
}
33
/// Launch configuration for a reduce kernel.
#[derive(Debug, Clone)]
pub struct ReduceConfig {
    /// Number of cubes to dispatch, decomposed over up to three dimensions.
    pub cube_count: CubeCount,
    /// Shape of each cube (units per cube).
    pub cube_dim: CubeDim,
    /// Whether lines are parallel or perpendicular to the reduced axis.
    pub line_mode: LineMode,
    /// Vectorization (line) size used when reading the input.
    pub line_size_input: u32,
    /// Vectorization (line) size used when writing the output.
    pub line_size_output: u32,
    /// Whether cube-level bound checks are required.
    pub bound_checks: bool,
    /// How bound checks are handled inside the inner reduction.
    pub bound_checks_inner: BoundChecksInner,
}
44
45impl ReduceConfig {
46    pub(crate) fn generate<R: Runtime, In: CubePrimitive>(
47        client: &ComputeClient<R::Server, R::Channel>,
48        input: &TensorHandleRef<R>,
49        output: &TensorHandleRef<R>,
50        axis: usize,
51        strategy: &ReduceStrategy,
52    ) -> ReduceConfig {
53        let reduce_count = output.size() as u32;
54        ReduceConfig::new()
55            .generate_line_mode(input, axis)
56            .generate_line_size::<R, In>(input, output, axis)
57            .generate_cube_dim(client, strategy.use_planes)
58            .generate_cube_count::<R>(reduce_count, strategy)
59    }
60
61    fn new() -> Self {
62        // This is only a dummy configuration to use as a starting point.
63        Self {
64            cube_count: CubeCount::new_single(),
65            cube_dim: CubeDim::new_single(),
66            line_mode: LineMode::Parallel,
67            line_size_input: 1,
68            line_size_output: 1,
69            bound_checks: true,
70            bound_checks_inner: BoundChecksInner::Mask,
71        }
72    }
73
74    fn generate_line_mode<R: Runtime>(mut self, input: &TensorHandleRef<R>, axis: usize) -> Self {
75        let stride = input.strides[axis];
76        self.line_mode = if stride == 1 {
77            LineMode::Parallel
78        } else {
79            LineMode::Perpendicular
80        };
81        self
82    }
83
84    fn generate_line_size<R: Runtime, In: CubePrimitive>(
85        mut self,
86        input: &TensorHandleRef<R>,
87        output: &TensorHandleRef<R>,
88        axis: usize,
89    ) -> Self {
90        let elem = In::as_elem_native_unchecked();
91        let supported_line_sizes = R::line_size_elem(&elem);
92        self.line_size_input = match self.line_mode {
93            LineMode::Parallel => {
94                tensor_line_size_parallel(supported_line_sizes, input.shape, input.strides, axis)
95                    as u32
96            }
97            LineMode::Perpendicular => {
98                // To compute the maximum line size we can used,
99                // we first sort both the input and output axes by increasing strides.
100                // As example, consider
101                //    input shape = [2, 4, 6, 8]
102                //    input stride = [1, 16, 64, 2]
103                //    output shape = [2, 1, 6, 8]
104                //    output stride = [1, 1, 2, 12]
105                //    axis = 1
106                //
107                // then we have
108                //    input sorted axis = [0, 3, 1, 2]
109                //    output sorted axis = [0, 1, 2, 3]
110                //
111                // From that point, we look at all the axes before the target axis in the sorted input.
112                // That is [0, 3] in the example.
113                // In the output, we remove the target axis leading to [0, 2, 3] in the example.
114                //
115                // In order to use perpendicular line, we are limited by the number of entries that are both
116                // contiguous in the input and output. This is obtained by taking the head of each list until they are different.
117                // In the above example, only the 0 axis is contiguous in both tensor, but it output sorted axis were [0, 1, 3, 2] instead,
118                // both the 0 and 3 axes would be contiguous in the two tensors.
119                // The corresponding number of entries is the product of the shape for the contiguous axes.
120                // In the example, it is simply 2.
121                //
122                // This gives us an upper bound on the line size we can used.
123                // Then, we use the regular method to find the best line size that match the device capacities.
124
125                let mut input_axis_and_strides =
126                    input.strides.iter().enumerate().collect::<Vec<_>>();
127                input_axis_and_strides.sort_by_key(|(_, stride)| *stride);
128                let input_sorted_axis = input_axis_and_strides
129                    .into_iter()
130                    .map(|(a, _)| a)
131                    .take_while(|a| *a != axis);
132
133                let mut output_axis_and_strides =
134                    output.strides.iter().enumerate().collect::<Vec<_>>();
135                output_axis_and_strides.sort_by_key(|(_, stride)| *stride);
136                let output_sorted_axis = output_axis_and_strides
137                    .into_iter()
138                    .filter_map(|(a, _)| (a != axis).then_some(a));
139
140                let max_line_size = input_sorted_axis
141                    .zip(output_sorted_axis)
142                    .filter_map(|(i, o)| (i == o).then_some(output.shape[i]))
143                    .product();
144
145                tensor_line_size_perpendicular(
146                    supported_line_sizes.filter(|size| {
147                        *size as usize <= max_line_size && max_line_size % *size as usize == 0
148                    }),
149                    input.shape,
150                    input.strides,
151                    axis,
152                ) as u32
153            }
154        };
155
156        if self.line_size_input > 1 && self.line_mode == LineMode::Perpendicular {
157            // TODO that this can be improved
158            let rank = output.strides.len();
159            let is_contiguous =
160                is_contiguous(&output.shape[axis..rank], &output.strides[axis..rank])
161                    && output.strides[rank - 1] == 1;
162            let shape = output.shape.get(axis + 1).cloned().unwrap_or(1) as u32;
163
164            if is_contiguous && shape % self.line_size_input == 0 {
165                self.line_size_output = self.line_size_input;
166            }
167        }
168        self
169    }
170
171    pub fn generate_cube_dim<S: ComputeServer, C: ComputeChannel<S>>(
172        mut self,
173        client: &ComputeClient<S, C>,
174        use_planes: bool,
175    ) -> Self {
176        self.cube_dim = if use_planes {
177            let plane_dim = client.properties().hardware.plane_size_min;
178            CubeDim::new_2d(plane_dim, DEFAULT_PLANE_COUNT)
179        } else {
180            DEFAULT_CUBE_DIM
181        };
182        self
183    }
184
185    pub fn generate_cube_count<R: Runtime>(
186        mut self,
187        reduce_count: u32,
188        strategy: &ReduceStrategy,
189    ) -> Self {
190        let agent_count_per_cube =  // An agent is either a unit, a plane or a whole cube depending on the strategy.
191            match strategy {
192                ReduceStrategy { shared: true, .. } => 1,
193                ReduceStrategy { use_planes: true, .. } => self.cube_dim.y,
194                ReduceStrategy { use_planes: false, .. } => self.cube_dim.num_elems(),
195            };
196        let reduce_count_per_cube = match self.line_mode {
197            LineMode::Parallel => agent_count_per_cube,
198            LineMode::Perpendicular => agent_count_per_cube * self.line_size_input,
199        };
200
201        let cube_count = reduce_count.div_ceil(reduce_count_per_cube);
202
203        self.do_bound_checks_if(reduce_count_per_cube * cube_count > reduce_count);
204
205        // If needed, we decompose the cube count to be within runtime limitation.
206        let (max_x, max_y, _) = R::max_cube_count();
207        let mut cube_count_x = cube_count;
208        let mut cube_count_y = 1;
209        let mut cube_count_z = 1;
210        while cube_count_x > max_x {
211            cube_count_x /= 2;
212            cube_count_y *= 2;
213        }
214        while cube_count_y > max_y {
215            cube_count_y /= 2;
216            cube_count_z *= 2;
217        }
218        self.cube_count = CubeCount::new_3d(cube_count_x, cube_count_y, cube_count_z);
219        self.do_bound_checks_if(cube_count_x * cube_count_y != cube_count);
220
221        self
222    }
223
224    fn do_bound_checks_if(&mut self, condition: bool) {
225        self.bound_checks = self.bound_checks || condition;
226    }
227}