cubecl_ir/
operation.rs

1use core::fmt::Display;
2
3use super::{Branch, CoopMma, Item, NonSemantic, Plane, Synchronization, Variable};
4use crate::{
5    Arithmetic, AtomicOp, Bitwise, Metadata, OperationArgs, OperationReflect, Operator, TmaOps,
6    comparison::Comparison,
7};
8use crate::{BarrierOps, SourceLoc, TypeHash};
9use alloc::{
10    format,
11    string::{String, ToString},
12    vec::Vec,
13};
14use derive_more::derive::From;
15
/// All operations that can be used in a GPU compute shader.
///
/// Every variant except `Copy` wraps a nested operation type and is marked
/// `#[operation(nested)]`; see the wrapped type for the individual operations.
///
/// Notes:
///
/// [Operator] can be vectorized, but other operations can't.
/// Therefore, during tracing, only operators can be registered.
///
/// NOTE(review): the note above may predate the dedicated `Arithmetic`, `Comparison` and
/// `Bitwise` families — confirm it still holds.
///
/// The `#[operation(...)]` attributes appear to be helper attributes of the `OperationReflect`
/// derive, which is configured here to name its generated opcode enum `OpCode`. Presumably
/// `nested` forwards reflection to the wrapped type and `pure` marks an operation without side
/// effects — TODO confirm against the derive implementation.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, From, OperationReflect)]
#[operation(opcode_name = OpCode)]
#[allow(dead_code, missing_docs, clippy::large_enum_variant)] // Some variants might not be used with different flags
pub enum Operation {
    /// Assigns the wrapped variable to the instruction's output (displayed as `out = variable`).
    ///
    /// `#[from(ignore)]` stops `derive_more` from generating `From<Variable> for Operation`,
    /// so a bare `Variable` never silently converts into a copy.
    #[operation(pure)]
    #[from(ignore)]
    Copy(Variable),
    /// Arithmetic operations; see [`Arithmetic`].
    #[operation(nested)]
    Arithmetic(Arithmetic),
    /// Comparison operations; see [`Comparison`].
    #[operation(nested)]
    Comparison(Comparison),
    /// Bitwise operations; see [`Bitwise`].
    #[operation(nested)]
    Bitwise(Bitwise),
    /// General operators (indexing, memory copies, casts, ...); see [`Operator`].
    #[operation(nested)]
    Operator(Operator),
    /// Atomic operations; see [`AtomicOp`].
    #[operation(nested)]
    Atomic(AtomicOp),
    /// Metadata queries; see [`Metadata`].
    #[operation(nested)]
    Metadata(Metadata),
    /// Branching / control-flow operations; see [`Branch`].
    #[operation(nested)]
    Branch(Branch),
    /// Synchronization operations; see [`Synchronization`].
    #[operation(nested)]
    Synchronization(Synchronization),
    /// Plane-level operations; see [`Plane`].
    #[operation(nested)]
    Plane(Plane),
    /// Cooperative matrix operations; see [`CoopMma`].
    #[operation(nested)]
    CoopMma(CoopMma),
    /// Barrier operations; see [`BarrierOps`].
    #[operation(nested)]
    Barrier(BarrierOps),
    /// TMA operations; see [`TmaOps`].
    #[operation(nested)]
    Tma(TmaOps),
    /// Non-semantic instructions (i.e. comments, debug info)
    #[operation(nested)]
    NonSemantic(NonSemantic),
}
59
/// An instruction that contains a right hand side [`Operation`] and an optional out variable.
///
/// Instructions are usually built with [`Instruction::new`] (with an output) or
/// [`Instruction::no_out`] (without one); both leave `source_loc` as `None`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, Hash, TypeHash)]
pub struct Instruction {
    /// Variable the instruction writes to, or `None` when it has no output.
    ///
    /// For indexed stores and memory copies this is the destination list (displayed as
    /// `out[index] = value` / `copy_mem(out[...], ...)`).
    pub out: Option<Variable>,
    /// Optional [`SourceLoc`] attached to this instruction; the constructors leave it `None`.
    pub source_loc: Option<SourceLoc>,
    /// The operation this instruction performs.
    pub operation: Operation,
}
68
69impl Instruction {
70    pub fn new(operation: impl Into<Operation>, out: Variable) -> Self {
71        Instruction {
72            out: Some(out),
73            operation: operation.into(),
74            source_loc: None,
75        }
76    }
77
78    pub fn no_out(operation: impl Into<Operation>) -> Self {
79        Instruction {
80            out: None,
81            operation: operation.into(),
82            source_loc: None,
83        }
84    }
85
86    pub fn out(&self) -> Variable {
87        self.out.unwrap()
88    }
89
90    pub fn item(&self) -> Item {
91        self.out().item
92    }
93}
94
95impl Display for Instruction {
96    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
97        match &self.operation {
98            Operation::Operator(Operator::CopyMemory(op)) => write!(
99                f,
100                "copy_mem({}[{}], {}[{}])",
101                self.out(),
102                op.out_index,
103                op.input,
104                op.in_index
105            ),
106            Operation::Operator(Operator::CopyMemoryBulk(op)) => write!(
107                f,
108                "copy_mem_bulk({}[{}], {}[{}], {})",
109                self.out(),
110                op.out_index,
111                op.input,
112                op.in_index,
113                op.len
114            ),
115            Operation::Operator(Operator::IndexAssign(op)) => {
116                write!(f, "{}[{}] = {}", self.out(), op.index, op.value)
117            }
118            Operation::Operator(Operator::UncheckedIndexAssign(op)) => {
119                write!(f, "unchecked {}[{}] = {}", self.out(), op.index, op.value)
120            }
121            Operation::Operator(Operator::Cast(op)) => {
122                write!(f, "{} = cast<{}>({})", self.out(), self.item(), op.input)
123            }
124            Operation::Operator(Operator::Reinterpret(op)) => {
125                write!(f, "{} = bitcast<{}>({})", self.out(), self.item(), op.input)
126            }
127            _ => {
128                if let Some(out) = self.out {
129                    write!(f, "{out} = {}", self.operation)
130                } else {
131                    write!(f, "{}", self.operation)
132                }
133            }
134        }
135    }
136}
137
138impl Display for Operation {
139    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
140        match self {
141            Operation::Arithmetic(arithmetic) => write!(f, "{arithmetic}"),
142            Operation::Comparison(comparison) => write!(f, "{comparison}"),
143            Operation::Bitwise(bitwise) => write!(f, "{bitwise}"),
144            Operation::Operator(operator) => write!(f, "{operator}"),
145            Operation::Atomic(atomic) => write!(f, "{atomic}"),
146            Operation::Metadata(metadata) => write!(f, "{metadata}"),
147            Operation::Branch(branch) => write!(f, "{branch}"),
148            Operation::Synchronization(synchronization) => write!(f, "{synchronization}"),
149            Operation::Plane(plane) => write!(f, "{plane}"),
150            Operation::CoopMma(coop_mma) => write!(f, "{coop_mma}"),
151            Operation::Copy(variable) => write!(f, "{variable}"),
152            Operation::NonSemantic(non_semantic) => write!(f, "{non_semantic}"),
153            Operation::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
154            Operation::Tma(tma_ops) => write!(f, "{tma_ops}"),
155        }
156    }
157}
158
/// Renders `args` as a comma-separated list prefixed with `", "`.
///
/// Each element is formatted with its `Display` implementation. An empty slice yields an empty
/// string; otherwise the result looks like `", a, b, c"`, so it can be appended directly after
/// preceding text.
pub fn fmt_vararg(args: &[impl Display]) -> String {
    match args {
        [] => String::new(),
        _ => {
            let rendered: Vec<String> = args.iter().map(ToString::to_string).collect();
            format!(", {}", rendered.join(", "))
        }
    }
}
171
/// Operands for indexing into `list` at `index`.
///
/// NOTE(review): the indexed value is presumably written to the owning instruction's `out`;
/// confirm against the `Operator` variants that wrap this struct.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct IndexOperator {
    /// The variable being indexed.
    pub list: Variable,
    /// The index into `list`.
    pub index: Variable,
    /// Line size override for the access; `0` means "same as `list`".
    pub line_size: u32, // 0 == same as list.
}
180
/// Operands for storing `value` at `index` in a list.
///
/// The target list is not a field here: it is the owning instruction's `out` variable.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct IndexAssignOperator {
    // list is out.
    /// Position in the output list at which `value` is stored.
    pub index: Variable,
    /// The value to store.
    pub value: Variable,
    /// Line size override for the store; `0` means "same as the list".
    pub line_size: u32, // 0 == same as list.
}
190
/// Operands of a binary operation.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct BinaryOperator {
    /// Left-hand operand.
    pub lhs: Variable,
    /// Right-hand operand.
    pub rhs: Variable,
}
198
/// Operand of a unary operation.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct UnaryOperator {
    /// The single input operand.
    pub input: Variable,
}
205
206impl From<Branch> for Instruction {
207    fn from(value: Branch) -> Self {
208        Instruction::no_out(value)
209    }
210}
211
212impl From<Synchronization> for Instruction {
213    fn from(value: Synchronization) -> Self {
214        Instruction::no_out(value)
215    }
216}
217
218impl From<NonSemantic> for Instruction {
219    fn from(value: NonSemantic) -> Self {
220        Instruction::no_out(value)
221    }
222}