1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
use cubecl::zspace::Strides;
/// Selects how vectorization is oriented relative to the reduce axis.
///
/// NOTE(review): the variant descriptions below are inferred from the comment
/// inside `output_vectorization_axis`; confirm against the code that consumes
/// this enum.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum VectorizationMode {
/// Presumably: the reduce axis is itself the contiguous (stride 1) axis.
Parallel,
/// Presumably: the contiguous axis is a non-reduce axis.
Perpendicular,
}
/// Chooses the axis along which the output of a reduction is vectorized.
///
/// The result is the axis of `input_strides` with the smallest stride that is
/// not `reduce_axis`. When several axes share that stride, the lowest axis
/// index wins (comparisons are strict).
///
/// * `input_strides` - strides of the input, one entry per axis.
/// * `reduce_axis` - axis being reduced; it is skipped when it holds the
///   overall smallest stride.
/// * `_vectorization_mode` - currently unused by the selection.
///
/// Returns `0` when the input has fewer than two axes.
pub fn output_vectorization_axis(
    input_strides: &Strides,
    reduce_axis: usize,
    _vectorization_mode: VectorizationMode,
) -> usize {
    // With at most one axis, input and output both vectorize along axis 0.
    if input_strides.len() <= 1 {
        return 0;
    }

    // One pass over the strides, carrying the smallest and second-smallest
    // entries as `(stride, axis)` pairs. `usize::MAX` marks "not seen yet".
    let unset = (usize::MAX, 0usize);
    let (smallest, runner_up) = input_strides.iter().enumerate().fold(
        (unset, unset),
        |(smallest, runner_up), (axis, &stride)| {
            if stride < smallest.0 {
                // New overall minimum: the previous one becomes the runner-up.
                ((stride, axis), smallest)
            } else if stride < runner_up.0 {
                (smallest, (stride, axis))
            } else {
                (smallest, runner_up)
            }
        },
    );

    // Vectorize along the smallest-stride axis that is *not* being reduced.
    // In a parallel reduction the reduce axis is the contiguous (stride 1)
    // one, so the runner-up is used. In a perpendicular reduction the overall
    // minimum is normally already a non-reduce axis, unless the reduce axis
    // ties for the minimum (e.g. a broadcast stride of 0), which again selects
    // the runner-up.
    if smallest.1 != reduce_axis {
        smallest.1
    } else {
        runner_up.1
    }
}
/// How bound checks are handled for inner reductions.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum BoundChecks {
    /// No bound check is necessary.
    None,
    /// Using a mask is enough for bound checks.
    /// This will still read the memory in an out-of-bound location,
    /// but will replace the value by the null value.
    Mask,
    /// Branching is necessary for bound checks.
    ///
    /// Probably the right setting when performing fuse on read.
    Branch,
}

impl BoundChecks {
    /// Returns the bound-check strategy to use when idle handling is active.
    ///
    /// The result is never weaker than `self`:
    /// * [`BoundChecks::None`] and [`BoundChecks::Mask`] both become
    ///   [`BoundChecks::Mask`].
    /// * [`BoundChecks::Branch`] is kept as is. Branching is documented as
    ///   *necessary* for that variant, whereas a mask still performs the
    ///   out-of-bound read, so downgrading it to `Mask` would be unsound.
    pub fn idle(self) -> Self {
        match self {
            Self::None | Self::Mask => Self::Mask,
            Self::Branch => Self::Branch,
        }
    }
}
/// Selects whether, and in which way, idle handling is applied.
///
/// NOTE(review): what each enabled variant does is not visible from this
/// definition; the variant notes are assumptions to confirm against the
/// code that consumes this enum.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
pub enum IdleMode {
    /// Idle handling is off; the only variant for which
    /// [`IdleMode::is_enabled`] returns `false`.
    None,
    /// Idle handling is on; presumably idle work is masked out (TODO confirm).
    Mask,
    /// Idle handling is on; presumably idle work stops early (TODO confirm).
    Terminate,
}

impl IdleMode {
    /// Whether idle is activated, i.e. the mode is anything but `None`.
    pub fn is_enabled(&self) -> bool {
        match self {
            Self::None => false,
            Self::Mask | Self::Terminate => true,
        }
    }
}