cubecl_ir/
operation.rs

1use core::fmt::Display;
2
3use super::{Branch, CoopMma, NonSemantic, Plane, Synchronization, Type, Variable};
4use crate::{
5    Arithmetic, AtomicOp, Bitwise, InstructionModes, Metadata, OperationArgs, OperationReflect,
6    Operator, TmaOps, comparison::Comparison, marker::Marker,
7};
8use crate::{BarrierOps, SourceLoc, TypeHash};
9use alloc::{
10    format,
11    string::{String, ToString},
12    vec::Vec,
13};
14use derive_more::derive::From;
15
/// All operations that can be used in a GPU compute shader.
///
/// Every variant except [`Operation::Copy`] wraps a nested operation family
/// (marked `#[operation(nested)]`); `Copy` is the only variant marked
/// `#[operation(pure)]`.
///
/// Notes:
///
/// [Operator] can be vectorized, but other operations can't.
/// Therefore, during tracing, only operators can be registered.
///
/// `From` conversions into `Operation` are derived for every variant except
/// [`Operation::Copy`], which opts out via `#[from(ignore)]`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, From, OperationReflect)]
// Name handed to the `OperationReflect` derive for its opcode type
// (presumably the generated opcode enum) — confirm against the derive macro.
#[operation(opcode_name = OpCode)]
#[allow(dead_code, missing_docs, clippy::large_enum_variant)] // Some variants might not be used with different flags
pub enum Operation {
    /// A plain copy of a single variable; `Display` renders it as just the variable.
    #[operation(pure)]
    #[from(ignore)]
    Copy(Variable),
    /// Arithmetic operations ([`Arithmetic`]).
    #[operation(nested)]
    Arithmetic(Arithmetic),
    /// Comparison operations ([`Comparison`]).
    #[operation(nested)]
    Comparison(Comparison),
    /// Bitwise operations ([`Bitwise`]).
    #[operation(nested)]
    Bitwise(Bitwise),
    /// General operators ([`Operator`]), e.g. indexing, casts and memory copies.
    #[operation(nested)]
    Operator(Operator),
    /// Atomic operations ([`AtomicOp`]).
    #[operation(nested)]
    Atomic(AtomicOp),
    /// Metadata queries ([`Metadata`]).
    #[operation(nested)]
    Metadata(Metadata),
    /// Control flow ([`Branch`]).
    #[operation(nested)]
    Branch(Branch),
    /// Synchronization operations ([`Synchronization`]).
    #[operation(nested)]
    Synchronization(Synchronization),
    /// Plane-level operations ([`Plane`]).
    #[operation(nested)]
    Plane(Plane),
    /// Cooperative matrix operations ([`CoopMma`]).
    #[operation(nested)]
    CoopMma(CoopMma),
    /// Barrier operations ([`BarrierOps`]).
    #[operation(nested)]
    Barrier(BarrierOps),
    /// TMA operations ([`TmaOps`]).
    #[operation(nested)]
    Tma(TmaOps),
    /// Non-semantic instructions (i.e. comments, debug info)
    #[operation(nested)]
    NonSemantic(NonSemantic),
    /// Markers used by compilers to update state or modes, but don't emit instructions.
    #[operation(nested)]
    Marker(Marker),
}
62
/// An instruction that contains a right hand side [`Operation`] and an optional out variable.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, PartialEq, Eq, Hash, TypeHash)]
pub struct Instruction {
    /// Variable receiving the result of `operation`; `None` for instructions
    /// built with [`Instruction::no_out`].
    pub out: Option<Variable>,
    /// Optional source location attached to the instruction (presumably the
    /// user code it was traced from — confirm). Both constructors in this file
    /// leave it as `None`.
    pub source_loc: Option<SourceLoc>,
    /// Per-instruction modes; both constructors in this file start from
    /// `Default::default()`.
    pub modes: InstructionModes,
    /// The right hand side operation.
    pub operation: Operation,
}
72
73impl Instruction {
74    pub fn new(operation: impl Into<Operation>, out: Variable) -> Self {
75        Instruction {
76            out: Some(out),
77            operation: operation.into(),
78            source_loc: None,
79            modes: Default::default(),
80        }
81    }
82
83    pub fn no_out(operation: impl Into<Operation>) -> Self {
84        Instruction {
85            out: None,
86            operation: operation.into(),
87            source_loc: None,
88            modes: Default::default(),
89        }
90    }
91
92    pub fn out(&self) -> Variable {
93        self.out.unwrap()
94    }
95
96    pub fn ty(&self) -> Type {
97        self.out().ty
98    }
99}
100
101impl Display for Instruction {
102    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
103        match &self.operation {
104            Operation::Operator(Operator::CopyMemory(op)) => write!(
105                f,
106                "copy_mem({}[{}], {}[{}])",
107                self.out(),
108                op.out_index,
109                op.input,
110                op.in_index
111            ),
112            Operation::Operator(Operator::CopyMemoryBulk(op)) => write!(
113                f,
114                "copy_mem_bulk({}[{}], {}[{}], {})",
115                self.out(),
116                op.out_index,
117                op.input,
118                op.in_index,
119                op.len
120            ),
121            Operation::Operator(Operator::IndexAssign(op)) => {
122                write!(
123                    f,
124                    "{}[{}] = {}  : ({}, {}) -> ({})",
125                    self.out(),
126                    op.index,
127                    op.value,
128                    op.index.ty,
129                    op.value.ty,
130                    self.out().ty,
131                )
132            }
133            Operation::Operator(Operator::UncheckedIndexAssign(op)) => {
134                write!(
135                    f,
136                    "unchecked {}[{}] = {} : ({}, {}) -> ({})",
137                    self.out(),
138                    op.index,
139                    op.value,
140                    op.index.ty,
141                    op.value.ty,
142                    self.out().ty,
143                )
144            }
145            Operation::Operator(Operator::Cast(op)) => {
146                write!(
147                    f,
148                    "{} = cast<{}>({}) : ({}) -> ({})",
149                    self.out(),
150                    self.ty(),
151                    op.input,
152                    op.input.ty,
153                    self.out().ty,
154                )
155            }
156            Operation::Operator(Operator::Reinterpret(op)) => {
157                write!(f, "{} = bitcast<{}>({})", self.out(), self.ty(), op.input)
158            }
159            _ => {
160                if let Some(out) = self.out {
161                    let mut vars_str = String::new();
162                    for (i, var) in self.operation.args().unwrap_or_default().iter().enumerate() {
163                        if i != 0 {
164                            vars_str.push_str(", ");
165                        }
166                        vars_str.push_str(&var.ty.to_string());
167                    }
168                    write!(
169                        f,
170                        "{out} = {} : ({}) -> ({})",
171                        self.operation, vars_str, out.ty
172                    )
173                } else {
174                    write!(f, "{}", self.operation)
175                }
176            }
177        }
178    }
179}
180
181impl Display for Operation {
182    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
183        match self {
184            Operation::Arithmetic(arithmetic) => write!(f, "{arithmetic}"),
185            Operation::Comparison(comparison) => write!(f, "{comparison}"),
186            Operation::Bitwise(bitwise) => write!(f, "{bitwise}"),
187            Operation::Operator(operator) => write!(f, "{operator}"),
188            Operation::Atomic(atomic) => write!(f, "{atomic}"),
189            Operation::Metadata(metadata) => write!(f, "{metadata}"),
190            Operation::Branch(branch) => write!(f, "{branch}"),
191            Operation::Synchronization(synchronization) => write!(f, "{synchronization}"),
192            Operation::Plane(plane) => write!(f, "{plane}"),
193            Operation::CoopMma(coop_mma) => write!(f, "{coop_mma}"),
194            Operation::Copy(variable) => write!(f, "{variable}"),
195            Operation::NonSemantic(non_semantic) => write!(f, "{non_semantic}"),
196            Operation::Barrier(barrier_ops) => write!(f, "{barrier_ops}"),
197            Operation::Tma(tma_ops) => write!(f, "{tma_ops}"),
198            Operation::Marker(marker) => write!(f, "{marker}"),
199        }
200    }
201}
202
/// Formats a variadic argument list for appending after fixed arguments.
///
/// Returns an empty string when `args` is empty. Otherwise it returns every
/// argument joined with `", "` and preceded by a leading `", "`, e.g.
/// `", a, b"`.
pub fn fmt_vararg(args: &[impl Display]) -> String {
    // No extra arguments: emit nothing, not even the leading separator.
    if args.is_empty() {
        return String::new();
    }

    let rendered: Vec<String> = args.iter().map(ToString::to_string).collect();
    format!(", {}", rendered.join(", "))
}
215
/// Operands of an index operation on `list`. This is presumably the read
/// counterpart of [`IndexAssignOperator`] — confirm against the `Operator` enum.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct IndexOperator {
    /// The variable being indexed.
    pub list: Variable,
    /// The index into `list`.
    pub index: Variable,
    /// Line size of the access; `0` means the same line size as `list`.
    pub line_size: u32,
    /// Adjustment factor for the bounds check.
    pub unroll_factor: u32,
}
225
/// Operands of an indexed assignment `out[index] = value`.
///
/// The list being written to is not stored here: it is the enclosing
/// [`Instruction`]'s `out` variable.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct IndexAssignOperator {
    // list is out.
    /// Index within the output list that is written.
    pub index: Variable,
    /// Value stored at `index`.
    pub value: Variable,
    /// Line size of the access; `0` means the same line size as the list.
    pub line_size: u32,
    /// Adjustment factor for the bounds check.
    pub unroll_factor: u32,
}
236
/// Operands of a two-input operation.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct BinaryOperator {
    /// Left-hand operand.
    pub lhs: Variable,
    /// Right-hand operand.
    pub rhs: Variable,
}
244
/// Operand of a single-input operation. Cast and bitcast likely use this type;
/// they read `op.input` in `Display for Instruction` — confirm in the `Operator` enum.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct UnaryOperator {
    /// The single input operand.
    pub input: Variable,
}
251
252impl From<Branch> for Instruction {
253    fn from(value: Branch) -> Self {
254        Instruction::no_out(value)
255    }
256}
257
258impl From<Synchronization> for Instruction {
259    fn from(value: Synchronization) -> Self {
260        Instruction::no_out(value)
261    }
262}
263
264impl From<NonSemantic> for Instruction {
265    fn from(value: NonSemantic) -> Self {
266        Instruction::no_out(value)
267    }
268}