// cubecl_matmul/components/ident.rs

/// Identifier for each of the three tensors involved in a matmul.
///
/// Lets functions specialize their behavior depending on which tensor they handle.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum Ident {
    /// The left-hand side input tensor.
    Lhs,
    /// The right-hand side input tensor.
    Rhs,
    /// The output tensor.
    Out,
}
11impl Ident {
12    pub fn as_input_ident(&self) -> InputIdent {
13        match self {
14            Ident::Lhs => InputIdent::Lhs,
15            Ident::Rhs => InputIdent::Rhs,
16            Ident::Out => panic!("Out is not an input."),
17        }
18    }
19}
20
/// Identifier for the two input tensors of a matmul.
///
/// Lets functions specialize their behavior depending on which input tensor they handle.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum InputIdent {
    /// The left-hand side input tensor.
    Lhs,
    /// The right-hand side input tensor.
    Rhs,
}
30impl InputIdent {
31    pub fn as_ident(&self) -> Ident {
32        match self {
33            InputIdent::Lhs => Ident::Lhs,
34            InputIdent::Rhs => Ident::Rhs,
35        }
36    }
37}
38
39impl From<InputIdent> for Ident {
40    fn from(value: InputIdent) -> Self {
41        value.as_ident()
42    }
43}