cubecl_ir/
operator.rs

1use core::fmt::Display;
2
3use alloc::{format, vec::Vec};
4
5use crate::{IndexAssignOperator, IndexOperator, TypeHash};
6
7use crate::{BinaryOperator, OperationArgs, OperationReflect, UnaryOperator, Variable};
8
/// Operators available on the GPU
///
/// Each variant wraps the operand struct it reads. Variants tagged
/// `#[operation(pure)]` / `#[operation(commutative, pure)]` pass those flags to
/// the `OperationReflect` derive. NOTE(review): how the flags are consumed
/// (e.g. by optimization passes) is defined by the derive macro, not in this
/// file — confirm there before relying on them.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
#[operation(opcode_name = OperatorOpCode)]
pub enum Operator {
    /// Indexed read, displayed as `list[index]`.
    #[operation(pure)]
    Index(IndexOperator),
    /// Element copy, displayed as `[out_index] = input[in_index]`.
    CopyMemory(CopyMemoryOperator),
    /// Bulk copy carrying a `len` operand plus input/output offsets.
    /// NOTE(review): the unit of `len` is not visible here — confirm.
    CopyMemoryBulk(CopyMemoryBulkOperator),
    /// Indexed read like `Index`, displayed with an `unchecked` prefix;
    /// presumably skips bounds checking — confirm in the backends.
    #[operation(pure)]
    UncheckedIndex(IndexOperator),
    /// Indexed write, displayed as `[index] = value`.
    IndexAssign(IndexAssignOperator),
    /// Indexed write like `IndexAssign`, displayed with an `unchecked` prefix.
    UncheckedIndexAssign(IndexAssignOperator),
    /// Builds a line from `inputs`, displayed as `vec(a, b, ...)`.
    #[operation(pure)]
    InitLine(LineInitOperator),
    /// Logical AND, displayed as `lhs && rhs`.
    #[operation(commutative, pure)]
    And(BinaryOperator),
    /// Logical OR, displayed as `lhs || rhs`.
    #[operation(commutative, pure)]
    Or(BinaryOperator),
    /// Logical NOT, displayed as `!input`.
    #[operation(pure)]
    Not(UnaryOperator),
    /// Conversion of `input`, displayed as `cast(input)`; the target type
    /// presumably comes from the instruction's output — confirm.
    #[operation(pure)]
    Cast(UnaryOperator),
    /// Displayed as `reinterpret(input)`; presumably a bit-level
    /// reinterpretation rather than a value conversion — confirm.
    #[operation(pure)]
    Reinterpret(UnaryOperator),
    /// A select statement/ternary
    /// (displayed as `cond ? then : or_else`).
    #[operation(pure)]
    Select(Select),
}
38
39impl Display for Operator {
40    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
41        match self {
42            Operator::Index(op) => write!(f, "{}[{}]", op.list, op.index),
43            Operator::CopyMemory(op) => {
44                write!(f, "[{}] = {}[{}]", op.out_index, op.input, op.in_index)
45            }
46            Operator::CopyMemoryBulk(op) => write!(
47                f,
48                "memcpy([{}], {}[{}], {})",
49                op.input, op.in_index, op.out_index, op.len
50            ),
51            Operator::UncheckedIndex(op) => {
52                write!(f, "unchecked {}[{}]", op.list, op.index)
53            }
54            Operator::IndexAssign(op) => write!(f, "[{}] = {}", op.index, op.value),
55            Operator::UncheckedIndexAssign(op) => {
56                write!(f, "unchecked [{}] = {}", op.index, op.value)
57            }
58            Operator::And(op) => write!(f, "{} && {}", op.lhs, op.rhs),
59            Operator::Or(op) => write!(f, "{} || {}", op.lhs, op.rhs),
60            Operator::Not(op) => write!(f, "!{}", op.input),
61            Operator::InitLine(init) => {
62                let inits = init
63                    .inputs
64                    .iter()
65                    .map(|input| format!("{input}"))
66                    .collect::<Vec<_>>();
67                write!(f, "vec({})", inits.join(", "))
68            }
69            Operator::Select(op) => {
70                write!(f, "{} ? {} : {}", op.cond, op.then, op.or_else)
71            }
72            Operator::Cast(op) => write!(f, "cast({})", op.input),
73            Operator::Reinterpret(op) => write!(f, "reinterpret({})", op.input),
74        }
75    }
76}
77
/// Operands describing a slice of `input` delimited by `start` and `end`.
///
/// NOTE(review): whether `end` is exclusive is not visible in this file —
/// confirm against the code that lowers this operator.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct SliceOperator {
    /// The variable being sliced.
    pub input: Variable,
    /// Start of the slice range.
    pub start: Variable,
    /// End of the slice range (inclusivity unconfirmed, see above).
    pub end: Variable,
}
86
/// Operands for reinterpreting the slice `input` with a new `line_size`.
///
/// NOTE(review): `line_size` presumably counts elements per line — confirm
/// against the backends that consume this operator.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct ReinterpretSliceOperator {
    /// The slice being reinterpreted.
    pub input: Variable,
    /// Line size to reinterpret `input` with (a plain `u32`, not a `Variable`).
    pub line_size: u32,
}
94
/// Operands for `Operator::InitLine`: the values a line is built from.
///
/// The `Display` impl for `Operator` prints these in order as `vec(a, b, ...)`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct LineInitOperator {
    /// Values making up the line, in order.
    pub inputs: Vec<Variable>,
}
101
/// Operands for `Operator::CopyMemory`, displayed as
/// `[out_index] = input[in_index]`.
///
/// The destination variable is not a field here; presumably it is the
/// output of the enclosing instruction — confirm.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct CopyMemoryOperator {
    /// Index written in the destination.
    pub out_index: Variable,
    /// Source variable read from.
    pub input: Variable,
    /// Index read in `input`.
    pub in_index: Variable,
}
110
/// Operands for `Operator::CopyMemoryBulk`: the `CopyMemoryOperator` fields
/// plus a `len` and separate input/output offsets.
///
/// NOTE(review): the unit of `len` and how the offsets combine with
/// `in_index` / `out_index` are not visible in this file — confirm against the
/// backends that lower this operator.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct CopyMemoryBulkOperator {
    /// Index written in the destination.
    pub out_index: Variable,
    /// Source variable read from.
    pub input: Variable,
    /// Index read in `input`.
    pub in_index: Variable,
    /// Length of the copy (unit unconfirmed).
    pub len: Variable,
    /// Offset applied on the source side (combination rule unconfirmed).
    pub offset_input: Variable,
    /// Offset applied on the destination side (combination rule unconfirmed).
    pub offset_out: Variable,
}
122
/// Operands for `Operator::Select`, a ternary displayed as
/// `cond ? then : or_else`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct Select {
    /// Condition choosing between the two branches.
    pub cond: Variable,
    /// Value selected when `cond` holds.
    pub then: Variable,
    /// Value selected otherwise.
    pub or_else: Variable,
}