cubecl_ir/
arithmetic.rs

1use core::fmt::Display;
2
3use crate::TypeHash;
4
5use crate::{BinaryOperator, OperationArgs, OperationReflect, UnaryOperator, Variable};
6
7/// Arithmetic operations
8#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
9#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
10#[operation(opcode_name = ArithmeticOpCode, pure)]
11pub enum Arithmetic {
12    #[operation(commutative)]
13    Add(BinaryOperator),
14    Fma(FmaOperator),
15    Sub(BinaryOperator),
16    #[operation(commutative)]
17    Mul(BinaryOperator),
18    Div(BinaryOperator),
19    Abs(UnaryOperator),
20    Exp(UnaryOperator),
21    Log(UnaryOperator),
22    Log1p(UnaryOperator),
23    Cos(UnaryOperator),
24    Sin(UnaryOperator),
25    Tanh(UnaryOperator),
26    Powf(BinaryOperator),
27    Sqrt(UnaryOperator),
28    Round(UnaryOperator),
29    Floor(UnaryOperator),
30    Ceil(UnaryOperator),
31    Erf(UnaryOperator),
32    Recip(UnaryOperator),
33    Clamp(ClampOperator),
34    Modulo(BinaryOperator),
35    Neg(UnaryOperator),
36    #[operation(commutative)]
37    Max(BinaryOperator),
38    #[operation(commutative)]
39    Min(BinaryOperator),
40    Remainder(BinaryOperator),
41    Magnitude(UnaryOperator),
42    Normalize(UnaryOperator),
43    #[operation(commutative)]
44    Dot(BinaryOperator),
45    #[operation(commutative)]
46    MulHi(BinaryOperator),
47}
48
49impl Display for Arithmetic {
50    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
51        match self {
52            Arithmetic::Add(op) => write!(f, "{} + {}", op.lhs, op.rhs),
53            Arithmetic::Fma(op) => write!(f, "{} * {} + {}", op.a, op.b, op.c),
54            Arithmetic::Sub(op) => write!(f, "{} - {}", op.lhs, op.rhs),
55            Arithmetic::Mul(op) => write!(f, "{} * {}", op.lhs, op.rhs),
56            Arithmetic::Div(op) => write!(f, "{} / {}", op.lhs, op.rhs),
57            Arithmetic::Abs(op) => write!(f, "{}.abs()", op.input),
58            Arithmetic::Exp(op) => write!(f, "{}.exp()", op.input),
59            Arithmetic::Log(op) => write!(f, "{}.log()", op.input),
60            Arithmetic::Log1p(op) => write!(f, "{}.log_1p()", op.input),
61            Arithmetic::Cos(op) => write!(f, "{}.cos()", op.input),
62            Arithmetic::Sin(op) => write!(f, "{}.sin()", op.input),
63            Arithmetic::Tanh(op) => write!(f, "{}.tanh()", op.input),
64            Arithmetic::Powf(op) => write!(f, "{}.pow({})", op.lhs, op.rhs),
65            Arithmetic::Sqrt(op) => write!(f, "{}.sqrt()", op.input),
66            Arithmetic::Round(op) => write!(f, "{}.round()", op.input),
67            Arithmetic::Floor(op) => write!(f, "{}.floor()", op.input),
68            Arithmetic::Ceil(op) => write!(f, "{}.ceil()", op.input),
69            Arithmetic::Erf(op) => write!(f, "{}.erf()", op.input),
70            Arithmetic::Recip(op) => write!(f, "{}.recip()", op.input),
71            Arithmetic::Clamp(op) => {
72                write!(f, "{}.clamp({}, {})", op.input, op.min_value, op.max_value)
73            }
74            Arithmetic::Modulo(op) => write!(f, "{} % {}", op.lhs, op.rhs),
75            Arithmetic::Neg(op) => write!(f, "-{}", op.input),
76            Arithmetic::Max(op) => write!(f, "{}.max({})", op.lhs, op.rhs),
77            Arithmetic::Min(op) => write!(f, "{}.min({})", op.lhs, op.rhs),
78            Arithmetic::Remainder(op) => write!(f, "{} rem {}", op.lhs, op.rhs),
79            Arithmetic::Magnitude(op) => write!(f, "{}.length()", op.input),
80            Arithmetic::Normalize(op) => write!(f, "{}.normalize()", op.input),
81            Arithmetic::Dot(op) => write!(f, "{}.dot({})", op.lhs, op.rhs),
82            Arithmetic::MulHi(op) => write!(f, "mul_hi({}, {})", op.lhs, op.rhs),
83        }
84    }
85}
86
87#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
88#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
89#[allow(missing_docs)]
90pub struct ClampOperator {
91    pub input: Variable,
92    pub min_value: Variable,
93    pub max_value: Variable,
94}
95
96#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
97#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
98#[allow(missing_docs)]
99pub struct ReadGlobalOperator {
100    pub variable: Variable,
101}
102
103#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
104#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash)]
105#[allow(missing_docs)]
106pub struct ReadGlobalWithLayoutOperator {
107    pub variable: Variable,
108    pub tensor_read_pos: usize,
109    pub tensor_layout_pos: usize,
110}
111
112#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
113#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
114#[allow(missing_docs)]
115pub struct FmaOperator {
116    pub a: Variable,
117    pub b: Variable,
118    pub c: Variable,
119}