cubecl_core/ir/
operation.rs

1use std::fmt::Display;
2
3use super::{Branch, CoopMma, Item, NonSemantic, Plane, Scope, Select, Synchronization, Variable};
4use crate::{
5    cpa,
6    ir::{Elem, UIntKind},
7    prelude::AtomicOp,
8};
9use serde::{Deserialize, Serialize};
10
11/// All operations that can be used in a GPU compute shader.
12///
13/// Notes:
14///
15/// [Operator] can be vectorized, but other operations can't.
16/// Therefore, during tracing, only operators can be registered.
17///
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
19#[allow(dead_code, missing_docs, clippy::large_enum_variant)] // Some variants might not be used with different flags
20pub enum Operation {
21    Copy(Variable),
22    Operator(Operator),
23    Atomic(AtomicOp),
24    Metadata(Metadata),
25    Branch(Branch),
26    Synchronization(Synchronization),
27    Plane(Plane),
28    CoopMma(CoopMma),
29    /// Non-semantic instructions (i.e. comments, debug info)
30    NonSemantic(NonSemantic),
31}
32
33/// An instruction that contains a right hand side [`Operation`] and an optional out variable.
34#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
35pub struct Instruction {
36    pub out: Option<Variable>,
37    pub operation: Operation,
38}
39
40impl Instruction {
41    pub fn new(operation: impl Into<Operation>, out: Variable) -> Self {
42        Instruction {
43            out: Some(out),
44            operation: operation.into(),
45        }
46    }
47
48    pub fn out(&self) -> Variable {
49        self.out.unwrap()
50    }
51
52    pub fn item(&self) -> Item {
53        self.out().item
54    }
55}
56
57impl Operation {
58    /// Whether this operation is pure, aka has no side effects. Pure operations can be removed
59    /// if their output is not needed, impure operations must be kept since their execution can
60    /// affect things down the line. e.g. atomics.
61    ///
62    /// Operations that operate across multiple units are always considered impure.
63    pub fn is_pure(&self) -> bool {
64        match self {
65            Operation::Copy(_) => true,
66            Operation::Operator(_) => true,
67            Operation::Atomic(_) => false,
68            Operation::Metadata(_) => true,
69            Operation::Branch(_) => false,
70            Operation::Synchronization(_) => false,
71            Operation::Plane(_) => false,
72            Operation::CoopMma(_) => false,
73            Operation::NonSemantic(_) => false,
74        }
75    }
76}
77
78impl Display for Instruction {
79    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80        match &self.operation {
81            Operation::Operator(Operator::CopyMemory(op)) => write!(
82                f,
83                "copy_mem({}[{}], {}[{}])",
84                self.out(),
85                op.out_index,
86                op.input,
87                op.in_index
88            ),
89            Operation::Operator(Operator::CopyMemoryBulk(op)) => write!(
90                f,
91                "copy_mem_bulk({}[{}], {}[{}], {})",
92                self.out(),
93                op.out_index,
94                op.input,
95                op.in_index,
96                op.len
97            ),
98            Operation::Operator(Operator::IndexAssign(op)) => {
99                write!(f, "{}[{}] = {}", self.out(), op.lhs, op.rhs)
100            }
101            Operation::Operator(Operator::UncheckedIndexAssign(op)) => {
102                write!(f, "unchecked {}[{}] = {}", self.out(), op.lhs, op.rhs)
103            }
104            Operation::Operator(Operator::Cast(op)) => {
105                write!(f, "{} = cast<{}>({})", self.out(), self.item(), op.input)
106            }
107            Operation::Operator(Operator::Bitcast(op)) => {
108                write!(f, "{} = bitcast<{}>({})", self.out(), self.item(), op.input)
109            }
110            _ => {
111                if let Some(out) = self.out {
112                    write!(f, "{out} = {}", self.operation)
113                } else {
114                    write!(f, "{}", self.operation)
115                }
116            }
117        }
118    }
119}
120
121impl Display for Operation {
122    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
123        match self {
124            Operation::Operator(operator) => write!(f, "{operator}"),
125            Operation::Atomic(atomic) => write!(f, "{atomic}"),
126            Operation::Metadata(metadata) => write!(f, "{metadata}"),
127            Operation::Branch(branch) => write!(f, "{branch}"),
128            Operation::Synchronization(synchronization) => write!(f, "{synchronization}"),
129            Operation::Plane(plane) => write!(f, "{plane}"),
130            Operation::CoopMma(coop_mma) => write!(f, "{coop_mma}"),
131            Operation::Copy(variable) => write!(f, "{variable}"),
132            Operation::NonSemantic(non_semantic) => write!(f, "{non_semantic}"),
133        }
134    }
135}
136
/// Formats `args` as a trailing argument list, e.g. `", a, b, c"`.
///
/// Every argument is prefixed with `", "`, so the result can be appended directly
/// after a fixed leading argument. An empty slice yields an empty string.
pub fn fmt_vararg(args: &[impl Display]) -> String {
    args.iter().map(|arg| format!(", {arg}")).collect()
}
149
150/// All operators that can be used in a GPU compute shader.
151#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
152#[allow(dead_code, missing_docs)] // Some variants might not be used with different flags
153pub enum Operator {
154    Add(BinaryOperator),
155    Fma(FmaOperator),
156    Sub(BinaryOperator),
157    Mul(BinaryOperator),
158    Div(BinaryOperator),
159    Abs(UnaryOperator),
160    Exp(UnaryOperator),
161    Log(UnaryOperator),
162    Log1p(UnaryOperator),
163    Cos(UnaryOperator),
164    Sin(UnaryOperator),
165    Tanh(UnaryOperator),
166    Powf(BinaryOperator),
167    Sqrt(UnaryOperator),
168    Round(UnaryOperator),
169    Floor(UnaryOperator),
170    Ceil(UnaryOperator),
171    Erf(UnaryOperator),
172    Recip(UnaryOperator),
173    Equal(BinaryOperator),
174    NotEqual(BinaryOperator),
175    Lower(BinaryOperator),
176    Clamp(ClampOperator),
177    Greater(BinaryOperator),
178    LowerEqual(BinaryOperator),
179    GreaterEqual(BinaryOperator),
180    Cast(UnaryOperator),
181    Modulo(BinaryOperator),
182    Index(BinaryOperator),
183    CopyMemory(CopyMemoryOperator),
184    CopyMemoryBulk(CopyMemoryBulkOperator),
185    Slice(SliceOperator),
186    UncheckedIndex(BinaryOperator),
187    IndexAssign(BinaryOperator),
188    InitLine(LineInitOperator),
189    UncheckedIndexAssign(BinaryOperator),
190    And(BinaryOperator),
191    Or(BinaryOperator),
192    Not(UnaryOperator),
193    Neg(UnaryOperator),
194    Max(BinaryOperator),
195    Min(BinaryOperator),
196    BitwiseAnd(BinaryOperator),
197    BitwiseOr(BinaryOperator),
198    BitwiseXor(BinaryOperator),
199    ShiftLeft(BinaryOperator),
200    ShiftRight(BinaryOperator),
201    CountOnes(UnaryOperator),
202    ReverseBits(UnaryOperator),
203    Remainder(BinaryOperator),
204    Bitcast(UnaryOperator),
205    Magnitude(UnaryOperator),
206    Normalize(UnaryOperator),
207    Dot(BinaryOperator),
208    // A select statement/ternary
209    Select(Select),
210}
211
212impl Display for Operator {
213    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
214        match self {
215            Operator::Add(op) => write!(f, "{} + {}", op.lhs, op.rhs),
216            Operator::Fma(op) => write!(f, "{} * {} + {}", op.a, op.b, op.c),
217            Operator::Sub(op) => write!(f, "{} - {}", op.lhs, op.rhs),
218            Operator::Mul(op) => write!(f, "{} * {}", op.lhs, op.rhs),
219            Operator::Div(op) => write!(f, "{} / {}", op.lhs, op.rhs),
220            Operator::Abs(op) => write!(f, "{}.abs()", op.input),
221            Operator::Exp(op) => write!(f, "{}.exp()", op.input),
222            Operator::Log(op) => write!(f, "{}.log()", op.input),
223            Operator::Log1p(op) => write!(f, "{}.log_1p()", op.input),
224            Operator::Cos(op) => write!(f, "{}.cos()", op.input),
225            Operator::Sin(op) => write!(f, "{}.sin()", op.input),
226            Operator::Tanh(op) => write!(f, "{}.tanh()", op.input),
227            Operator::Powf(op) => write!(f, "{}.pow({})", op.lhs, op.rhs),
228            Operator::Sqrt(op) => write!(f, "{}.sqrt()", op.input),
229            Operator::Round(op) => write!(f, "{}.round()", op.input),
230            Operator::Floor(op) => write!(f, "{}.floor()", op.input),
231            Operator::Ceil(op) => write!(f, "{}.ceil()", op.input),
232            Operator::Erf(op) => write!(f, "{}.erf()", op.input),
233            Operator::Recip(op) => write!(f, "{}.recip()", op.input),
234            Operator::Equal(op) => write!(f, "{} == {}", op.lhs, op.rhs),
235            Operator::NotEqual(op) => write!(f, "{} != {}", op.lhs, op.rhs),
236            Operator::Lower(op) => write!(f, "{} < {}", op.lhs, op.rhs),
237            Operator::Clamp(op) => {
238                write!(f, "{}.clamp({}, {})", op.input, op.min_value, op.max_value)
239            }
240            Operator::Greater(op) => write!(f, "{} > {}", op.lhs, op.rhs),
241            Operator::LowerEqual(op) => write!(f, "{} <= {}", op.lhs, op.rhs),
242            Operator::GreaterEqual(op) => write!(f, "{} >= {}", op.lhs, op.rhs),
243            Operator::Modulo(op) => write!(f, "{} % {}", op.lhs, op.rhs),
244            Operator::Index(op) => write!(f, "{}[{}]", op.lhs, op.rhs),
245            Operator::CopyMemory(op) => {
246                write!(f, "[{}] = {}[{}]", op.out_index, op.input, op.in_index)
247            }
248            Operator::CopyMemoryBulk(op) => write!(
249                f,
250                "memcpy([{}], {}[{}], {})",
251                op.input, op.in_index, op.out_index, op.len
252            ),
253            Operator::Slice(op) => write!(f, "{}[{}..{}]", op.input, op.start, op.end),
254            Operator::UncheckedIndex(op) => {
255                write!(f, "unchecked {}[{}]", op.lhs, op.rhs)
256            }
257            Operator::IndexAssign(op) => write!(f, "[{}] = {}", op.lhs, op.rhs),
258            Operator::UncheckedIndexAssign(op) => {
259                write!(f, "unchecked [{}] = {}", op.lhs, op.rhs)
260            }
261            Operator::And(op) => write!(f, "{} && {}", op.lhs, op.rhs),
262            Operator::Or(op) => write!(f, "{} || {}", op.lhs, op.rhs),
263            Operator::Not(op) => write!(f, "!{}", op.input),
264            Operator::Neg(op) => write!(f, "-{}", op.input),
265            Operator::Max(op) => write!(f, "{}.max({})", op.lhs, op.rhs),
266            Operator::Min(op) => write!(f, "{}.min({})", op.lhs, op.rhs),
267            Operator::BitwiseAnd(op) => write!(f, "{} & {}", op.lhs, op.rhs),
268            Operator::BitwiseOr(op) => write!(f, "{} | {}", op.lhs, op.rhs),
269            Operator::BitwiseXor(op) => write!(f, "{} ^ {}", op.lhs, op.rhs),
270            Operator::CountOnes(op) => write!(f, "{}.count_bits()", op.input),
271            Operator::ReverseBits(op) => write!(f, "{}.reverse_bits()", op.input),
272            Operator::ShiftLeft(op) => write!(f, "{} << {}", op.lhs, op.rhs),
273            Operator::ShiftRight(op) => write!(f, "{} >> {}", op.lhs, op.rhs),
274            Operator::Remainder(op) => write!(f, "{} rem {}", op.lhs, op.rhs),
275            Operator::Magnitude(op) => write!(f, "{}.length()", op.input),
276            Operator::Normalize(op) => write!(f, "{}.normalize()", op.input),
277            Operator::Dot(op) => write!(f, "{}.dot({})", op.lhs, op.rhs),
278            Operator::InitLine(init) => {
279                let inits = init
280                    .inputs
281                    .iter()
282                    .map(|input| format!("{input}"))
283                    .collect::<Vec<_>>();
284                write!(f, "vec({})", inits.join(", "))
285            }
286            Operator::Select(op) => {
287                write!(f, "{} ? {} : {}", op.cond, op.then, op.or_else)
288            }
289            Operator::Cast(op) => write!(f, "cast({})", op.input),
290            Operator::Bitcast(op) => write!(f, "bitcast({})", op.input),
291        }
292    }
293}
294
295/// All metadata that can be accessed in a shader.
296#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
297#[allow(missing_docs)]
298pub enum Metadata {
299    /// The rank of an array.
300    Rank { var: Variable },
301    /// The stride of an array at the given dimension.
302    Stride { dim: Variable, var: Variable },
303    /// The shape of an array at the given dimension.
304    Shape { dim: Variable, var: Variable },
305    /// The length of an array.
306    Length { var: Variable },
307    /// The length of an array's underlying buffer.
308    BufferLength { var: Variable },
309}
310
311impl Display for Metadata {
312    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
313        match self {
314            Metadata::Rank { var } => write!(f, "rank({})", var),
315            Metadata::Stride { dim, var } => write!(f, "{}.strides[{}]", var, dim),
316            Metadata::Shape { dim, var } => write!(f, "{}.shape[{}]", var, dim),
317            Metadata::Length { var } => write!(f, "{}.len()", var),
318            Metadata::BufferLength { var } => write!(f, "buffer_len({})", var),
319        }
320    }
321}
322
323#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
324#[allow(missing_docs)]
325pub struct BinaryOperator {
326    pub lhs: Variable,
327    pub rhs: Variable,
328}
329
330#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
331#[allow(missing_docs)]
332pub struct UnaryOperator {
333    pub input: Variable,
334}
335
336#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
337#[allow(missing_docs)]
338pub struct LineInitOperator {
339    pub inputs: Vec<Variable>,
340}
341
342#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
343#[allow(missing_docs)]
344pub struct CopyMemoryOperator {
345    pub out_index: Variable,
346    pub input: Variable,
347    pub in_index: Variable,
348}
349
350#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
351#[allow(missing_docs)]
352pub struct CopyMemoryBulkOperator {
353    pub out_index: Variable,
354    pub input: Variable,
355    pub in_index: Variable,
356    pub len: u32,
357}
358
359#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
360#[allow(missing_docs)]
361pub struct ClampOperator {
362    pub input: Variable,
363    pub min_value: Variable,
364    pub max_value: Variable,
365}
366
367#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
368#[allow(missing_docs)]
369pub struct SliceOperator {
370    pub input: Variable,
371    pub start: Variable,
372    pub end: Variable,
373}
374
375#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
376#[allow(missing_docs)]
377pub struct CompareAndSwapOperator {
378    pub input: Variable,
379    pub cmp: Variable,
380    pub val: Variable,
381}
382
383#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
384#[allow(missing_docs)]
385pub struct ReadGlobalOperator {
386    pub variable: Variable,
387}
388
389#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
390#[allow(missing_docs)]
391pub struct ReadGlobalWithLayoutOperator {
392    pub variable: Variable,
393    pub tensor_read_pos: usize,
394    pub tensor_layout_pos: usize,
395}
396
397#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
398#[allow(missing_docs)]
399pub struct FmaOperator {
400    pub a: Variable,
401    pub b: Variable,
402    pub c: Variable,
403}
404
405#[allow(missing_docs)]
406pub fn expand_checked_index_assign(scope: &mut Scope, lhs: Variable, rhs: Variable, out: Variable) {
407    let array_len = scope.create_local(Item::new(Elem::UInt(UIntKind::U32)));
408    let inside_bound = scope.create_local(Item::new(Elem::Bool));
409
410    if out.has_buffer_length() {
411        cpa!(scope, array_len = buffer_len(out));
412    } else {
413        cpa!(scope, array_len = len(out));
414    }
415
416    cpa!(scope, inside_bound = lhs < array_len);
417    cpa!(scope, if(inside_bound).then(|scope| {
418        cpa!(scope, unchecked(out[lhs]) = rhs);
419    }));
420}
421
422impl From<Operator> for Operation {
423    fn from(val: Operator) -> Self {
424        Operation::Operator(val)
425    }
426}
427
428impl From<Branch> for Operation {
429    fn from(value: Branch) -> Self {
430        Self::Branch(value)
431    }
432}
433
434impl From<Branch> for Instruction {
435    fn from(value: Branch) -> Self {
436        Instruction {
437            out: None,
438            operation: value.into(),
439        }
440    }
441}
442
443impl From<Synchronization> for Operation {
444    fn from(value: Synchronization) -> Self {
445        Self::Synchronization(value)
446    }
447}
448
449impl From<Synchronization> for Instruction {
450    fn from(value: Synchronization) -> Self {
451        Instruction {
452            out: None,
453            operation: value.into(),
454        }
455    }
456}
457
458impl From<Metadata> for Operation {
459    fn from(val: Metadata) -> Self {
460        Operation::Metadata(val)
461    }
462}
463
464impl From<NonSemantic> for Operation {
465    fn from(val: NonSemantic) -> Self {
466        Operation::NonSemantic(val)
467    }
468}
469
470impl From<NonSemantic> for Instruction {
471    fn from(value: NonSemantic) -> Self {
472        Instruction {
473            out: None,
474            operation: value.into(),
475        }
476    }
477}