cubecl_matmul/components/
ident.rs

1use crate::components::global::memory::ViewDirection;
2
/// Identifies one of the three tensors taking part in a matmul.
///
/// Lets functions specialize their behavior depending on which tensor
/// they are working with.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum MatmulIdent {
    /// The left-hand side input tensor.
    Lhs,
    /// The right-hand side input tensor.
    Rhs,
    /// The output tensor.
    Out,
}
12
13impl MatmulIdent {
14    /// Equivalent to into, but type inference works better within Cube functions
15    pub fn into_stage(self) -> StageIdent {
16        self.into()
17    }
18
19    pub fn view_direction(&self) -> ViewDirection {
20        match self {
21            MatmulIdent::Lhs => ViewDirection::Col,
22            MatmulIdent::Rhs => ViewDirection::Row,
23            MatmulIdent::Out => ViewDirection::None,
24        }
25    }
26}
27
/// Identifier for a tensor at the stage level.
///
/// Unlike [`MatmulIdent`], this enum has a separate `Acc` variant in
/// addition to `Out`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StageIdent {
    /// The left-hand side input.
    Lhs,
    /// The right-hand side input.
    Rhs,
    /// NOTE(review): presumably the accumulator — confirm against usage.
    Acc,
    /// The output.
    Out,
}
35
36impl From<MatmulIdent> for StageIdent {
37    fn from(matmul_ident: MatmulIdent) -> Self {
38        match matmul_ident {
39            MatmulIdent::Lhs => StageIdent::Lhs,
40            MatmulIdent::Rhs => StageIdent::Rhs,
41            MatmulIdent::Out => StageIdent::Acc,
42        }
43    }
44}