1use cubecl_core::{
2    prelude::*, server::ComputeServer, tensor_line_size_parallel, tensor_line_size_perpendicular,
3};
4use cubecl_std::tensor::is_contiguous;
5
6use crate::ReduceStrategy;
7
/// Size of the second (`y`) dimension of every cube built by `ReduceConfig::generate_cube_dim`.
/// When planes are used, `generate_cube_count` treats `cube_dim.y` as the number of agents
/// per cube, so this is also the number of planes working in one cube.
const DEFAULT_PLANE_COUNT: u32 = 8;
10
/// How elements are packed into lines (vectorized loads) relative to the reduced axis.
///
/// Chosen by `ReduceConfig::generate_line_mode` from the stride of the reduced axis.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum LineMode {
    /// The reduced axis has stride 1: a line packs consecutive elements of the reduced axis.
    Parallel,
    /// The reduced axis is strided: each lane of a line belongs to a different reduction,
    /// so one line advances `line_size_input` reductions at once.
    Perpendicular,
}
16
/// Strategy for bound checks inside the reduction kernel.
///
/// `ReduceConfig` defaults to `Mask`. NOTE(review): the exact meaning of each variant lives in
/// the kernel code, which is not visible here. The descriptions below are inferred from the
/// variant names — confirm before relying on them.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum BoundChecksInner {
    /// Presumably: no inner bound checks at all.
    None,
    /// Presumably: out-of-bounds elements are masked out instead of branched around.
    Mask,
    /// Presumably: out-of-bounds elements are skipped with an explicit branch.
    Branch,
}
31
32#[derive(Debug, Clone)]
33pub struct ReduceConfig {
34    pub cube_count: CubeCount,
35    pub cube_dim: CubeDim,
36    pub line_mode: LineMode,
37    pub line_size_input: u32,
38    pub line_size_output: u32,
39    pub bound_checks: bool,
40    pub bound_checks_inner: BoundChecksInner,
41}
42
43impl ReduceConfig {
44    pub(crate) fn generate<R: Runtime, In: CubePrimitive>(
45        client: &ComputeClient<R::Server>,
46        input: &TensorHandleRef<R>,
47        output: &TensorHandleRef<R>,
48        axis: usize,
49        strategy: &ReduceStrategy,
50    ) -> ReduceConfig {
51        let reduce_count = output.size() as u32;
52        ReduceConfig::new()
53            .generate_line_mode(input, axis)
54            .generate_line_size::<R, In>(input, output, axis)
55            .generate_cube_dim(client, strategy.use_planes)
56            .generate_cube_count::<R>(reduce_count, strategy)
57    }
58
59    fn new() -> Self {
60        Self {
62            cube_count: CubeCount::new_single(),
63            cube_dim: CubeDim::new_single(),
64            line_mode: LineMode::Parallel,
65            line_size_input: 1,
66            line_size_output: 1,
67            bound_checks: true,
68            bound_checks_inner: BoundChecksInner::Mask,
69        }
70    }
71
72    fn generate_line_mode<R: Runtime>(mut self, input: &TensorHandleRef<R>, axis: usize) -> Self {
73        let stride = input.strides[axis];
74        self.line_mode = if stride == 1 {
75            LineMode::Parallel
76        } else {
77            LineMode::Perpendicular
78        };
79        self
80    }
81
82    fn generate_line_size<R: Runtime, In: CubePrimitive>(
83        mut self,
84        input: &TensorHandleRef<R>,
85        output: &TensorHandleRef<R>,
86        axis: usize,
87    ) -> Self {
88        let supported_line_sizes = R::io_optimized_line_sizes_unchecked(size_of::<In>());
89        self.line_size_input = match self.line_mode {
90            LineMode::Parallel => {
91                tensor_line_size_parallel(supported_line_sizes, input.shape, input.strides, axis)
92                    as u32
93            }
94            LineMode::Perpendicular => {
95                let mut input_axis_and_strides =
123                    input.strides.iter().enumerate().collect::<Vec<_>>();
124                input_axis_and_strides.sort_by_key(|(_, stride)| *stride);
125                let input_sorted_axis = input_axis_and_strides
126                    .into_iter()
127                    .map(|(a, _)| a)
128                    .take_while(|a| *a != axis);
129
130                let mut output_axis_and_strides =
131                    output.strides.iter().enumerate().collect::<Vec<_>>();
132                output_axis_and_strides.sort_by_key(|(_, stride)| *stride);
133                let output_sorted_axis = output_axis_and_strides
134                    .into_iter()
135                    .filter_map(|(a, _)| (a != axis).then_some(a));
136
137                let max_line_size = input_sorted_axis
138                    .zip(output_sorted_axis)
139                    .filter_map(|(i, o)| (i == o).then_some(output.shape[i]))
140                    .product();
141
142                tensor_line_size_perpendicular(
143                    supported_line_sizes.filter(|size| {
144                        *size as usize <= max_line_size && max_line_size % *size as usize == 0
145                    }),
146                    input.shape,
147                    input.strides,
148                    axis,
149                ) as u32
150            }
151        };
152
153        if self.line_size_input > 1 && self.line_mode == LineMode::Perpendicular {
154            let rank = output.strides.len();
156            let is_contiguous =
157                is_contiguous(&output.shape[axis..rank], &output.strides[axis..rank])
158                    && output.strides[rank - 1] == 1;
159            let shape = output.shape.get(axis + 1).cloned().unwrap_or(1) as u32;
160
161            if is_contiguous && shape.is_multiple_of(self.line_size_input) {
162                self.line_size_output = self.line_size_input;
163            }
164        }
165        self
166    }
167
168    pub fn generate_cube_dim<S: ComputeServer>(
169        mut self,
170        client: &ComputeClient<S>,
171        use_planes: bool,
172    ) -> Self {
173        self.cube_dim = if use_planes {
174            let plane_dim = client.properties().hardware.plane_size_min;
175            CubeDim::new_2d(plane_dim, DEFAULT_PLANE_COUNT)
176        } else {
177            let plane_dim = client.properties().hardware.plane_size_max;
178            CubeDim::new_2d(plane_dim, DEFAULT_PLANE_COUNT)
179        };
180        self
181    }
182
183    pub fn generate_cube_count<R: Runtime>(
184        mut self,
185        reduce_count: u32,
186        strategy: &ReduceStrategy,
187    ) -> Self {
188        let agent_count_per_cube =  match strategy {
190                ReduceStrategy { shared: true, .. } => 1,
191                ReduceStrategy { use_planes: true, .. } => self.cube_dim.y,
192                ReduceStrategy { use_planes: false, .. } => self.cube_dim.num_elems(),
193            };
194        let reduce_count_per_cube = match self.line_mode {
195            LineMode::Parallel => agent_count_per_cube,
196            LineMode::Perpendicular => agent_count_per_cube * self.line_size_input,
197        };
198
199        let cube_count = reduce_count.div_ceil(reduce_count_per_cube);
200
201        self.do_bound_checks_if(reduce_count_per_cube * cube_count > reduce_count);
202
203        let (max_x, max_y, _) = R::max_cube_count();
205        let mut cube_count_x = cube_count;
206        let mut cube_count_y = 1;
207        let mut cube_count_z = 1;
208        while cube_count_x > max_x {
209            cube_count_x /= 2;
210            cube_count_y *= 2;
211        }
212        while cube_count_y > max_y {
213            cube_count_y /= 2;
214            cube_count_z *= 2;
215        }
216        self.cube_count = CubeCount::new_3d(cube_count_x, cube_count_y, cube_count_z);
217        self.do_bound_checks_if(cube_count_x * cube_count_y != cube_count);
218
219        self
220    }
221
222    fn do_bound_checks_if(&mut self, condition: bool) {
223        self.bound_checks = self.bound_checks || condition;
224    }
225}