cubecl_ir/
reflect.rs

1use alloc::collections::VecDeque;
2
3use alloc::vec;
4use alloc::vec::Vec;
5
6use crate::Variable;
7
8/// An operation that can be reflected on
9pub trait OperationReflect: Sized {
10    /// Type of the op codes for this operation
11    type OpCode;
12
13    /// Get the opcode for this operation
14    fn op_code(&self) -> Self::OpCode;
15    /// Get the list of arguments for this operation. If not all arguments are [`Variable`], returns
16    /// `None` instead.
17    fn args(&self) -> Option<Vec<Variable>> {
18        None
19    }
20    /// Create typed operation from an opcode and a list of arguments. Returns `None` if not all
21    /// arguments are [`Variable`].
22    #[allow(unused)]
23    fn from_code_and_args(op_code: Self::OpCode, args: &[Variable]) -> Option<Self> {
24        None
25    }
26    /// Whether this operation is commutative (arguments can be freely reordered). Ignored for
27    /// single argument operations.
28    fn is_commutative(&self) -> bool {
29        false
30    }
31    /// Whether this operation is pure (has no side effects). Things like uniform/plane operations
32    /// are considered impure, because they affect other units.
33    fn is_pure(&self) -> bool {
34        false
35    }
36}
37
38/// A type that represents an operation's arguments
39pub trait OperationArgs: Sized {
40    /// Construct this type from a list of arguments. If not all arguments are [`Variable`], returns
41    /// `None`
42    #[allow(unused)]
43    fn from_args(args: &[Variable]) -> Option<Self> {
44        None
45    }
46
47    /// Turns this type into a flat list of arguments. If not all arguments are [`Variable`],
48    /// returns `None`
49    fn as_args(&self) -> Option<Vec<Variable>> {
50        None
51    }
52}
53
54impl OperationArgs for Variable {
55    fn from_args(args: &[Variable]) -> Option<Self> {
56        Some(args[0])
57    }
58
59    fn as_args(&self) -> Option<Vec<Variable>> {
60        Some(vec![*self])
61    }
62}
63
64/// Types that can be destructured into and created from a list of [`Variable`]s.
65pub trait FromArgList: Sized {
66    /// Creates this type from a list of variables. This works like a parse stream, where consumed
67    /// variables are popped from the front.
68    fn from_arg_list(args: &mut VecDeque<Variable>) -> Self;
69    /// Turns this type into a list of [`Variable`]s.
70    fn as_arg_list(&self) -> impl IntoIterator<Item = Variable>;
71}
72
73impl FromArgList for Variable {
74    fn from_arg_list(args: &mut VecDeque<Variable>) -> Self {
75        args.pop_front().expect("Missing variable from arg list")
76    }
77
78    fn as_arg_list(&self) -> impl IntoIterator<Item = Variable> {
79        [*self]
80    }
81}
82
83impl FromArgList for Vec<Variable> {
84    fn from_arg_list(args: &mut VecDeque<Variable>) -> Self {
85        core::mem::take(args).into_iter().collect()
86    }
87
88    fn as_arg_list(&self) -> impl IntoIterator<Item = Variable> {
89        self.iter().cloned()
90    }
91}
92
93impl FromArgList for bool {
94    fn from_arg_list(args: &mut VecDeque<Variable>) -> Self {
95        args.pop_front()
96            .expect("Missing variable from arg list")
97            .as_const()
98            .unwrap()
99            .as_bool()
100    }
101
102    fn as_arg_list(&self) -> impl IntoIterator<Item = Variable> {
103        [(*self).into()]
104    }
105}
106
107impl FromArgList for u32 {
108    fn from_arg_list(args: &mut VecDeque<Variable>) -> Self {
109        args.pop_front()
110            .expect("Missing variable from arg list")
111            .as_const()
112            .unwrap()
113            .as_u32()
114    }
115
116    fn as_arg_list(&self) -> impl IntoIterator<Item = Variable> {
117        [(*self).into()]
118    }
119}