cubecl_reduce/
tune_key.rs1use cubecl_core as cubecl;
2
3use cubecl_core::{AutotuneKey, ir::Elem};
4use serde::{Deserialize, Serialize};
5
6#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
7pub struct ReduceAutotuneKey {
9 elem_input: Elem,
10 elem_output: Elem,
11 elem_acc: Elem,
12 potential_line_size: u8,
13 axis_is_contiguous: bool,
14 #[autotune(anchor(exp(min = 16, max = 4096)))]
15 reduce_axis_shape: usize,
16 #[autotune(anchor(exp(max = 16384, base = 4)))]
17 reduce_count: usize,
18}
19
20impl ReduceAutotuneKey {
21 pub fn generate(
22 elem_input: Elem,
23 elem_output: Elem,
24 elem_acc: Elem,
25 input_shape: &[usize],
26 axis_is_contiguous: bool,
27 axis: usize,
28 ) -> Self {
29 let rank = input_shape.len();
30
31 if axis > rank {
32 panic!("axis {axis} is out-of-bound for a rank of {rank}");
33 }
34
35 let reduce_axis_shape = input_shape[axis];
36
37 let reduce_count = input_shape
38 .iter()
39 .enumerate()
40 .filter_map(|(i, shape)| (i != axis).then_some(shape))
41 .product();
42
43 let potential_line_size = Self::potential_line_size(elem_input.size(), reduce_axis_shape);
44
45 ReduceAutotuneKey::new(
46 elem_input,
47 elem_output,
48 elem_acc,
49 potential_line_size,
50 axis_is_contiguous,
51 reduce_axis_shape,
52 reduce_count,
53 )
54 }
55
56 fn potential_line_size(elem_size: usize, mut shape: usize) -> u8 {
57 let mut potential_line_size = 1;
58 let max_bytes_in_line = 16; while shape % 2 == 0 && potential_line_size as usize * elem_size < max_bytes_in_line {
61 potential_line_size *= 2;
62 shape /= 2;
63 }
64 potential_line_size
65 }
66}