cubecl_reduce/
tune_key.rs1use cubecl_core as cubecl;
2
3use cubecl_core::{AutotuneKey, ir::Elem};
4use serde::{Deserialize, Serialize};
5
6#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
7pub struct ReduceAutotuneKey {
9 elem_input: Elem,
10 elem_output: Elem,
11 potential_line_size: u8,
12 axis_is_contiguous: bool,
13 #[autotune(anchor(exp(min = 16, max = 4096)))]
14 reduce_axis_shape: usize,
15 #[autotune(anchor(exp(max = 16384, base = 4)))]
16 reduce_count: usize,
17}
18
19impl ReduceAutotuneKey {
20 pub fn generate(
21 elem_input: Elem,
22 elem_output: Elem,
23 input_shape: &[usize],
24 axis_is_contiguous: bool,
25 axis: usize,
26 ) -> Self {
27 let rank = input_shape.len();
28
29 if axis > rank {
30 panic!("axis {axis} is out-of-bound for a rank of {rank}");
31 }
32
33 let reduce_axis_shape = input_shape[axis];
34
35 let reduce_count = input_shape
36 .iter()
37 .enumerate()
38 .filter_map(|(i, shape)| (i != axis).then_some(shape))
39 .product();
40
41 let potential_line_size = Self::potential_line_size(elem_input.size(), reduce_axis_shape);
42
43 ReduceAutotuneKey::new(
44 elem_input,
45 elem_output,
46 potential_line_size,
47 axis_is_contiguous,
48 reduce_axis_shape,
49 reduce_count,
50 )
51 }
52
53 fn potential_line_size(elem_size: usize, mut shape: usize) -> u8 {
54 let mut potential_line_size = 1;
55 let max_bytes_in_line = 16; while shape % 2 == 0 && potential_line_size as usize * elem_size < max_bytes_in_line {
58 potential_line_size *= 2;
59 shape /= 2;
60 }
61 potential_line_size
62 }
63}