1use core::fmt::Display;
2
3use crate::TypeHash;
4
5use crate::{BinaryOperator, OperationArgs, OperationReflect, UnaryOperator, Variable};
6
7#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
9#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
10#[operation(opcode_name = ArithmeticOpCode, pure)]
11pub enum Arithmetic {
12 #[operation(commutative)]
13 Add(BinaryOperator),
14 #[operation(commutative)]
15 SaturatingAdd(BinaryOperator),
16 Fma(FmaOperator),
17 Sub(BinaryOperator),
18 SaturatingSub(BinaryOperator),
19 #[operation(commutative)]
20 Mul(BinaryOperator),
21 Div(BinaryOperator),
22 Abs(UnaryOperator),
23 Exp(UnaryOperator),
24 Log(UnaryOperator),
25 Log1p(UnaryOperator),
26 Expm1(UnaryOperator),
27 Cos(UnaryOperator),
28 Sin(UnaryOperator),
29 Tan(UnaryOperator),
30 Tanh(UnaryOperator),
31 Sinh(UnaryOperator),
32 Cosh(UnaryOperator),
33 ArcCos(UnaryOperator),
34 ArcSin(UnaryOperator),
35 ArcTan(UnaryOperator),
36 ArcSinh(UnaryOperator),
37 ArcCosh(UnaryOperator),
38 ArcTanh(UnaryOperator),
39 Degrees(UnaryOperator),
40 Radians(UnaryOperator),
41 ArcTan2(BinaryOperator),
42 Powf(BinaryOperator),
43 Powi(BinaryOperator),
44 Hypot(BinaryOperator),
45 Rhypot(BinaryOperator),
46 Sqrt(UnaryOperator),
47 InverseSqrt(UnaryOperator),
48 Round(UnaryOperator),
49 Floor(UnaryOperator),
50 Ceil(UnaryOperator),
51 Trunc(UnaryOperator),
52 Erf(UnaryOperator),
53 Recip(UnaryOperator),
54 Clamp(ClampOperator),
55 Modulo(BinaryOperator),
56 Neg(UnaryOperator),
57 #[operation(commutative)]
58 Max(BinaryOperator),
59 #[operation(commutative)]
60 Min(BinaryOperator),
61 Remainder(BinaryOperator),
62 Magnitude(UnaryOperator),
63 Normalize(UnaryOperator),
64 #[operation(commutative)]
65 Dot(BinaryOperator),
66 #[operation(commutative)]
67 MulHi(BinaryOperator),
68 Conj(UnaryOperator),
69 VectorSum(UnaryOperator),
70}
71
72impl Display for Arithmetic {
73 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
74 match self {
75 Arithmetic::Add(op) => write!(f, "{} + {}", op.lhs, op.rhs),
76 Arithmetic::SaturatingAdd(op) => write!(f, "saturating_add({}, {})", op.lhs, op.rhs),
77 Arithmetic::Fma(op) => write!(f, "{} * {} + {}", op.a, op.b, op.c),
78 Arithmetic::Sub(op) => write!(f, "{} - {}", op.lhs, op.rhs),
79 Arithmetic::SaturatingSub(op) => write!(f, "saturating_sub({}, {})", op.lhs, op.rhs),
80 Arithmetic::Mul(op) => write!(f, "{} * {}", op.lhs, op.rhs),
81 Arithmetic::Div(op) => write!(f, "{} / {}", op.lhs, op.rhs),
82 Arithmetic::Abs(op) => write!(f, "{}.abs()", op.input),
83 Arithmetic::Exp(op) => write!(f, "{}.exp()", op.input),
84 Arithmetic::Log(op) => write!(f, "{}.log()", op.input),
85 Arithmetic::Log1p(op) => write!(f, "{}.log_1p()", op.input),
86 Arithmetic::Expm1(op) => write!(f, "{}.exp_m1()", op.input),
87 Arithmetic::Cos(op) => write!(f, "{}.cos()", op.input),
88 Arithmetic::Sin(op) => write!(f, "{}.sin()", op.input),
89 Arithmetic::Tan(op) => write!(f, "{}.tan()", op.input),
90 Arithmetic::Tanh(op) => write!(f, "{}.tanh()", op.input),
91 Arithmetic::Sinh(op) => write!(f, "{}.sinh()", op.input),
92 Arithmetic::Cosh(op) => write!(f, "{}.cosh()", op.input),
93 Arithmetic::ArcCos(op) => write!(f, "{}.acos()", op.input),
94 Arithmetic::ArcSin(op) => write!(f, "{}.asin()", op.input),
95 Arithmetic::ArcTan(op) => write!(f, "{}.atan()", op.input),
96 Arithmetic::ArcSinh(op) => write!(f, "{}.asinh()", op.input),
97 Arithmetic::ArcCosh(op) => write!(f, "{}.acosh()", op.input),
98 Arithmetic::ArcTanh(op) => write!(f, "{}.atanh()", op.input),
99 Arithmetic::Degrees(op) => write!(f, "{}.degrees()", op.input),
100 Arithmetic::Radians(op) => write!(f, "{}.radians()", op.input),
101 Arithmetic::ArcTan2(op) => write!(f, "{}.atan2({})", op.lhs, op.rhs),
102 Arithmetic::Powf(op) => write!(f, "{}.pow({})", op.lhs, op.rhs),
103 Arithmetic::Powi(op) => write!(f, "{}.powi({})", op.lhs, op.rhs),
104 Arithmetic::Hypot(op) => write!(f, "{}.hypot({})", op.lhs, op.rhs),
105 Arithmetic::Rhypot(op) => write!(f, "{}.rhypot({})", op.lhs, op.rhs),
106 Arithmetic::Sqrt(op) => write!(f, "{}.sqrt()", op.input),
107 Arithmetic::InverseSqrt(op) => write!(f, "{}.inverse_sqrt()", op.input),
108 Arithmetic::Round(op) => write!(f, "{}.round()", op.input),
109 Arithmetic::Floor(op) => write!(f, "{}.floor()", op.input),
110 Arithmetic::Ceil(op) => write!(f, "{}.ceil()", op.input),
111 Arithmetic::Trunc(op) => write!(f, "{}.trunc()", op.input),
112 Arithmetic::Erf(op) => write!(f, "{}.erf()", op.input),
113 Arithmetic::Recip(op) => write!(f, "{}.recip()", op.input),
114 Arithmetic::Clamp(op) => {
115 write!(f, "{}.clamp({}, {})", op.input, op.min_value, op.max_value)
116 }
117 Arithmetic::Modulo(op) => write!(f, "{} % {}", op.lhs, op.rhs),
118 Arithmetic::Neg(op) => write!(f, "-{}", op.input),
119 Arithmetic::Max(op) => write!(f, "{}.max({})", op.lhs, op.rhs),
120 Arithmetic::Min(op) => write!(f, "{}.min({})", op.lhs, op.rhs),
121 Arithmetic::Remainder(op) => write!(f, "{} rem {}", op.lhs, op.rhs),
122 Arithmetic::Magnitude(op) => write!(f, "{}.length()", op.input),
123 Arithmetic::Normalize(op) => write!(f, "{}.normalize()", op.input),
124 Arithmetic::Dot(op) => write!(f, "{}.dot({})", op.lhs, op.rhs),
125 Arithmetic::MulHi(op) => write!(f, "mul_hi({}, {})", op.lhs, op.rhs),
126 Arithmetic::Conj(op) => write!(f, "{}.conj()", op.input),
127 Arithmetic::VectorSum(op) => write!(f, "{}.vector_sum()", op.input),
128 }
129 }
130}
131
132#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
133#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
134#[allow(missing_docs)]
135pub struct ClampOperator {
136 pub input: Variable,
137 pub min_value: Variable,
138 pub max_value: Variable,
139}
140
141#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
142#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
143#[allow(missing_docs)]
144pub struct ReadGlobalOperator {
145 pub variable: Variable,
146}
147
148#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
149#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash)]
150#[allow(missing_docs)]
151pub struct ReadGlobalWithLayoutOperator {
152 pub variable: Variable,
153 pub tensor_read_pos: usize,
154 pub tensor_layout_pos: usize,
155}
156
157#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
158#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
159#[allow(missing_docs)]
160pub struct FmaOperator {
161 pub a: Variable,
162 pub b: Variable,
163 pub c: Variable,
164}