cubecl_reduce/
tune_key.rs

1use cubecl_core as cubecl;
2
3use cubecl_core::{AutotuneKey, ir::Elem};
4use serde::{Deserialize, Serialize};
5
6#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
7/// Autotune key representative of reduce versions
8pub struct ReduceAutotuneKey {
9    elem_input: Elem,
10    elem_output: Elem,
11    elem_acc: Elem,
12    potential_line_size: u8,
13    axis_is_contiguous: bool,
14    #[autotune(anchor(exp(min = 16, max = 4096)))]
15    reduce_axis_shape: usize,
16    #[autotune(anchor(exp(max = 16384, base = 4)))]
17    reduce_count: usize,
18}
19
20impl ReduceAutotuneKey {
21    pub fn generate(
22        elem_input: Elem,
23        elem_output: Elem,
24        elem_acc: Elem,
25        input_shape: &[usize],
26        axis_is_contiguous: bool,
27        axis: usize,
28    ) -> Self {
29        let rank = input_shape.len();
30
31        if axis > rank {
32            panic!("axis {axis} is out-of-bound for a rank of {rank}");
33        }
34
35        let reduce_axis_shape = input_shape[axis];
36
37        let reduce_count = input_shape
38            .iter()
39            .enumerate()
40            .filter_map(|(i, shape)| (i != axis).then_some(shape))
41            .product();
42
43        let potential_line_size = Self::potential_line_size(elem_input.size(), reduce_axis_shape);
44
45        ReduceAutotuneKey::new(
46            elem_input,
47            elem_output,
48            elem_acc,
49            potential_line_size,
50            axis_is_contiguous,
51            reduce_axis_shape,
52            reduce_count,
53        )
54    }
55
56    fn potential_line_size(elem_size: usize, mut shape: usize) -> u8 {
57        let mut potential_line_size = 1;
58        let max_bytes_in_line = 16; // 128 bits
59        //
60        while shape % 2 == 0 && potential_line_size as usize * elem_size < max_bytes_in_line {
61            potential_line_size *= 2;
62            shape /= 2;
63        }
64        potential_line_size
65    }
66}