// cubecl_reduce/tune_key.rs

1use cubecl_core as cubecl;
2
3use cubecl_core::{AutotuneKey, ir::Elem};
4use serde::{Deserialize, Serialize};
5
6#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
7/// Autotune key representative of reduce versions
8pub struct ReduceAutotuneKey {
9    elem_input: Elem,
10    elem_output: Elem,
11    potential_line_size: u8,
12    axis_is_contiguous: bool,
13    #[autotune(anchor(exp(min = 16, max = 4096)))]
14    reduce_axis_shape: usize,
15    #[autotune(anchor(exp(max = 16384, base = 4)))]
16    reduce_count: usize,
17}
18
19impl ReduceAutotuneKey {
20    pub fn generate(
21        elem_input: Elem,
22        elem_output: Elem,
23        input_shape: &[usize],
24        axis_is_contiguous: bool,
25        axis: usize,
26    ) -> Self {
27        let rank = input_shape.len();
28
29        if axis > rank {
30            panic!("axis {axis} is out-of-bound for a rank of {rank}");
31        }
32
33        let reduce_axis_shape = input_shape[axis];
34
35        let reduce_count = input_shape
36            .iter()
37            .enumerate()
38            .filter_map(|(i, shape)| (i != axis).then_some(shape))
39            .product();
40
41        let potential_line_size = Self::potential_line_size(elem_input.size(), reduce_axis_shape);
42
43        ReduceAutotuneKey::new(
44            elem_input,
45            elem_output,
46            potential_line_size,
47            axis_is_contiguous,
48            reduce_axis_shape,
49            reduce_count,
50        )
51    }
52
53    fn potential_line_size(elem_size: usize, mut shape: usize) -> u8 {
54        let mut potential_line_size = 1;
55        let max_bytes_in_line = 16; // 128 bits
56        //
57        while shape % 2 == 0 && potential_line_size as usize * elem_size < max_bytes_in_line {
58            potential_line_size *= 2;
59            shape /= 2;
60        }
61        potential_line_size
62    }
63}