Skip to main content

cubecl_ir/
arithmetic.rs

1use core::fmt::Display;
2
3use crate::TypeHash;
4
5use crate::{BinaryOperator, OperationArgs, OperationReflect, UnaryOperator, Variable};
6
/// Arithmetic operations of the IR.
///
/// Operands are carried by [`UnaryOperator`] (`input`), [`BinaryOperator`] (`lhs`, `rhs`),
/// [`FmaOperator`] (`a`, `b`, `c`) or [`ClampOperator`] (`input`, `min_value`, `max_value`).
/// The textual form quoted on each variant below is the one produced by its `Display` impl.
///
/// NOTE(review): `pure` (on the enum) and `commutative` (on individual variants) are flags
/// consumed by the `OperationReflect` derive; presumably they mark the operations as
/// side-effect free and the two operands as swappable, respectively — confirm in the macro.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
#[operation(opcode_name = ArithmeticOpCode, pure)]
pub enum Arithmetic {
    /// Addition: `lhs + rhs`.
    #[operation(commutative)]
    Add(BinaryOperator),
    /// Saturating addition: `saturating_add(lhs, rhs)`.
    #[operation(commutative)]
    SaturatingAdd(BinaryOperator),
    /// Fused multiply-add: `a * b + c`.
    Fma(FmaOperator),
    /// Subtraction: `lhs - rhs`.
    Sub(BinaryOperator),
    /// Saturating subtraction: `saturating_sub(lhs, rhs)`.
    SaturatingSub(BinaryOperator),
    /// Multiplication: `lhs * rhs`.
    #[operation(commutative)]
    Mul(BinaryOperator),
    /// Division: `lhs / rhs`.
    Div(BinaryOperator),
    /// Absolute value: `input.abs()`.
    Abs(UnaryOperator),
    /// Exponential: `input.exp()`.
    Exp(UnaryOperator),
    /// Logarithm: `input.log()` (base presumably `e` — confirm).
    Log(UnaryOperator),
    /// `input.log_1p()`; presumably `ln(1 + input)`.
    Log1p(UnaryOperator),
    /// Cosine: `input.cos()`.
    Cos(UnaryOperator),
    /// Sine: `input.sin()`.
    Sin(UnaryOperator),
    /// Tangent: `input.tan()`.
    Tan(UnaryOperator),
    /// Hyperbolic tangent: `input.tanh()`.
    Tanh(UnaryOperator),
    /// Hyperbolic sine: `input.sinh()`.
    Sinh(UnaryOperator),
    /// Hyperbolic cosine: `input.cosh()`.
    Cosh(UnaryOperator),
    /// Inverse cosine: `input.acos()`.
    ArcCos(UnaryOperator),
    /// Inverse sine: `input.asin()`.
    ArcSin(UnaryOperator),
    /// Inverse tangent: `input.atan()`.
    ArcTan(UnaryOperator),
    /// Inverse hyperbolic sine: `input.asinh()`.
    ArcSinh(UnaryOperator),
    /// Inverse hyperbolic cosine: `input.acosh()`.
    ArcCosh(UnaryOperator),
    /// Inverse hyperbolic tangent: `input.atanh()`.
    ArcTanh(UnaryOperator),
    /// `input.degrees()`; presumably converts radians to degrees.
    Degrees(UnaryOperator),
    /// `input.radians()`; presumably converts degrees to radians.
    Radians(UnaryOperator),
    /// Two-argument inverse tangent: `lhs.atan2(rhs)`.
    ArcTan2(BinaryOperator),
    /// Power: `lhs.pow(rhs)`; presumably a floating-point exponent (cf. `Powi`).
    Powf(BinaryOperator),
    /// Power: `lhs.powi(rhs)`; presumably an integer exponent.
    Powi(BinaryOperator),
    /// `lhs.hypot(rhs)`; presumably `sqrt(lhs^2 + rhs^2)`.
    Hypot(BinaryOperator),
    /// `lhs.rhypot(rhs)`; presumably the reciprocal of `Hypot` — confirm.
    Rhypot(BinaryOperator),
    /// Square root: `input.sqrt()`.
    Sqrt(UnaryOperator),
    /// `input.inverse_sqrt()`; presumably `1 / sqrt(input)`.
    InverseSqrt(UnaryOperator),
    /// Round to nearest: `input.round()` (tie-breaking rule not visible here).
    Round(UnaryOperator),
    /// Round toward negative infinity: `input.floor()`.
    Floor(UnaryOperator),
    /// Round toward positive infinity: `input.ceil()`.
    Ceil(UnaryOperator),
    /// Round toward zero: `input.trunc()`.
    Trunc(UnaryOperator),
    /// Error function: `input.erf()`.
    Erf(UnaryOperator),
    /// Reciprocal: `input.recip()`.
    Recip(UnaryOperator),
    /// Clamp to a range: `input.clamp(min_value, max_value)`.
    Clamp(ClampOperator),
    /// `lhs % rhs`.
    ///
    /// NOTE(review): how this differs from `Remainder` (e.g. sign convention for negative
    /// operands) is not visible in this file — confirm against the backends.
    Modulo(BinaryOperator),
    /// Negation: `-input`.
    Neg(UnaryOperator),
    /// Maximum: `lhs.max(rhs)`.
    #[operation(commutative)]
    Max(BinaryOperator),
    /// Minimum: `lhs.min(rhs)`.
    #[operation(commutative)]
    Min(BinaryOperator),
    /// `lhs rem rhs`.
    ///
    /// NOTE(review): the semantic difference from `Modulo` is not visible here — confirm.
    Remainder(BinaryOperator),
    /// `input.length()`; presumably the Euclidean length of a vector input.
    Magnitude(UnaryOperator),
    /// `input.normalize()`; presumably scales a vector input to unit length.
    Normalize(UnaryOperator),
    /// Dot product: `lhs.dot(rhs)`.
    #[operation(commutative)]
    Dot(BinaryOperator),
    /// `mul_hi(lhs, rhs)`; presumably the high half of the widened integer product.
    #[operation(commutative)]
    MulHi(BinaryOperator),
    /// `input.vector_sum()`; presumably the sum of a vector input's lanes.
    VectorSum(UnaryOperator),
}
69
70impl Display for Arithmetic {
71    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
72        match self {
73            Arithmetic::Add(op) => write!(f, "{} + {}", op.lhs, op.rhs),
74            Arithmetic::SaturatingAdd(op) => write!(f, "saturating_add({}, {})", op.lhs, op.rhs),
75            Arithmetic::Fma(op) => write!(f, "{} * {} + {}", op.a, op.b, op.c),
76            Arithmetic::Sub(op) => write!(f, "{} - {}", op.lhs, op.rhs),
77            Arithmetic::SaturatingSub(op) => write!(f, "saturating_sub({}, {})", op.lhs, op.rhs),
78            Arithmetic::Mul(op) => write!(f, "{} * {}", op.lhs, op.rhs),
79            Arithmetic::Div(op) => write!(f, "{} / {}", op.lhs, op.rhs),
80            Arithmetic::Abs(op) => write!(f, "{}.abs()", op.input),
81            Arithmetic::Exp(op) => write!(f, "{}.exp()", op.input),
82            Arithmetic::Log(op) => write!(f, "{}.log()", op.input),
83            Arithmetic::Log1p(op) => write!(f, "{}.log_1p()", op.input),
84            Arithmetic::Cos(op) => write!(f, "{}.cos()", op.input),
85            Arithmetic::Sin(op) => write!(f, "{}.sin()", op.input),
86            Arithmetic::Tan(op) => write!(f, "{}.tan()", op.input),
87            Arithmetic::Tanh(op) => write!(f, "{}.tanh()", op.input),
88            Arithmetic::Sinh(op) => write!(f, "{}.sinh()", op.input),
89            Arithmetic::Cosh(op) => write!(f, "{}.cosh()", op.input),
90            Arithmetic::ArcCos(op) => write!(f, "{}.acos()", op.input),
91            Arithmetic::ArcSin(op) => write!(f, "{}.asin()", op.input),
92            Arithmetic::ArcTan(op) => write!(f, "{}.atan()", op.input),
93            Arithmetic::ArcSinh(op) => write!(f, "{}.asinh()", op.input),
94            Arithmetic::ArcCosh(op) => write!(f, "{}.acosh()", op.input),
95            Arithmetic::ArcTanh(op) => write!(f, "{}.atanh()", op.input),
96            Arithmetic::Degrees(op) => write!(f, "{}.degrees()", op.input),
97            Arithmetic::Radians(op) => write!(f, "{}.radians()", op.input),
98            Arithmetic::ArcTan2(op) => write!(f, "{}.atan2({})", op.lhs, op.rhs),
99            Arithmetic::Powf(op) => write!(f, "{}.pow({})", op.lhs, op.rhs),
100            Arithmetic::Powi(op) => write!(f, "{}.powi({})", op.lhs, op.rhs),
101            Arithmetic::Hypot(op) => write!(f, "{}.hypot({})", op.lhs, op.rhs),
102            Arithmetic::Rhypot(op) => write!(f, "{}.rhypot({})", op.lhs, op.rhs),
103            Arithmetic::Sqrt(op) => write!(f, "{}.sqrt()", op.input),
104            Arithmetic::InverseSqrt(op) => write!(f, "{}.inverse_sqrt()", op.input),
105            Arithmetic::Round(op) => write!(f, "{}.round()", op.input),
106            Arithmetic::Floor(op) => write!(f, "{}.floor()", op.input),
107            Arithmetic::Ceil(op) => write!(f, "{}.ceil()", op.input),
108            Arithmetic::Trunc(op) => write!(f, "{}.trunc()", op.input),
109            Arithmetic::Erf(op) => write!(f, "{}.erf()", op.input),
110            Arithmetic::Recip(op) => write!(f, "{}.recip()", op.input),
111            Arithmetic::Clamp(op) => {
112                write!(f, "{}.clamp({}, {})", op.input, op.min_value, op.max_value)
113            }
114            Arithmetic::Modulo(op) => write!(f, "{} % {}", op.lhs, op.rhs),
115            Arithmetic::Neg(op) => write!(f, "-{}", op.input),
116            Arithmetic::Max(op) => write!(f, "{}.max({})", op.lhs, op.rhs),
117            Arithmetic::Min(op) => write!(f, "{}.min({})", op.lhs, op.rhs),
118            Arithmetic::Remainder(op) => write!(f, "{} rem {}", op.lhs, op.rhs),
119            Arithmetic::Magnitude(op) => write!(f, "{}.length()", op.input),
120            Arithmetic::Normalize(op) => write!(f, "{}.normalize()", op.input),
121            Arithmetic::Dot(op) => write!(f, "{}.dot({})", op.lhs, op.rhs),
122            Arithmetic::MulHi(op) => write!(f, "mul_hi({}, {})", op.lhs, op.rhs),
123            Arithmetic::VectorSum(op) => write!(f, "{}.vector_sum()", op.input),
124        }
125    }
126}
127
/// Operands of [`Arithmetic::Clamp`], rendered as `input.clamp(min_value, max_value)`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct ClampOperator {
    /// Value being clamped.
    pub input: Variable,
    /// Lower bound of the clamp range.
    pub min_value: Variable,
    /// Upper bound of the clamp range.
    pub max_value: Variable,
}
136
/// Single-operand struct wrapping the `variable` to be read.
///
/// NOTE(review): not referenced by [`Arithmetic`]; the name suggests it describes a read of a
/// global variable and is consumed by another operation type elsewhere in the crate — confirm.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct ReadGlobalOperator {
    /// The variable to read.
    pub variable: Variable,
}
143
/// Operand struct for a read of `variable` that also carries two tensor positions.
///
/// Unlike the other operand structs in this file, it does not derive `OperationArgs`;
/// presumably that is because its `usize` fields are not `Variable`s — NOTE(review): confirm.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct ReadGlobalWithLayoutOperator {
    /// The variable to read.
    pub variable: Variable,
    /// NOTE(review): presumably the position of the tensor being read among the kernel's
    /// tensor bindings — confirm against callers.
    pub tensor_read_pos: usize,
    /// NOTE(review): presumably the position of the tensor whose layout is used for the
    /// read — confirm against callers.
    pub tensor_layout_pos: usize,
}
152
/// Operands of [`Arithmetic::Fma`] (fused multiply-add), rendered as `a * b + c`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct FmaOperator {
    /// First multiplicand.
    pub a: Variable,
    /// Second multiplicand.
    pub b: Variable,
    /// Addend, added to the product `a * b`.
    pub c: Variable,
}