1use core::fmt::Display;
2
3use crate::TypeHash;
4
5use crate::{BinaryOperator, OperationArgs, OperationReflect, UnaryOperator, Variable};
6
/// Arithmetic operations of the IR.
///
/// Every variant wraps an operand struct (`BinaryOperator`, `UnaryOperator`,
/// [`FmaOperator`] or [`ClampOperator`]); the `Display` impl for this enum
/// renders each one as a Rust-like expression over its operands.
///
/// The `#[operation(...)]` attributes are consumed by the `OperationReflect`
/// derive: `opcode_name = ArithmeticOpCode` presumably names the generated
/// opcode enum, `pure` presumably marks every variant as side-effect free, and
/// `commutative` flags variants whose two operands may be swapped.
/// NOTE(review): confirm these meanings against the `OperationReflect` macro.
///
/// Variant order feeds into the derived `Hash` (and presumably into any
/// index-based serialization or opcode numbering), so prefer appending new
/// variants over reordering existing ones.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationReflect)]
#[operation(opcode_name = ArithmeticOpCode, pure)]
pub enum Arithmetic {
    /// Addition, rendered as `lhs + rhs`.
    #[operation(commutative)]
    Add(BinaryOperator),
    /// Saturating addition, rendered as `saturating_add(lhs, rhs)`.
    #[operation(commutative)]
    SaturatingAdd(BinaryOperator),
    /// Multiply-add over an [`FmaOperator`], rendered as `a * b + c`.
    Fma(FmaOperator),
    /// Subtraction, rendered as `lhs - rhs`.
    Sub(BinaryOperator),
    /// Saturating subtraction, rendered as `saturating_sub(lhs, rhs)`.
    SaturatingSub(BinaryOperator),
    /// Multiplication, rendered as `lhs * rhs`.
    #[operation(commutative)]
    Mul(BinaryOperator),
    /// Division, rendered as `lhs / rhs`.
    Div(BinaryOperator),
    /// Absolute value, rendered as `input.abs()`.
    Abs(UnaryOperator),
    /// Exponential, rendered as `input.exp()`.
    Exp(UnaryOperator),
    /// Logarithm, rendered as `input.log()` (presumably the natural log — confirm).
    Log(UnaryOperator),
    /// Logarithm of `1 + input`, rendered as `input.log_1p()`
    /// (presumably the natural log — confirm).
    Log1p(UnaryOperator),
    /// Cosine, rendered as `input.cos()`.
    Cos(UnaryOperator),
    /// Sine, rendered as `input.sin()`.
    Sin(UnaryOperator),
    /// Hyperbolic tangent, rendered as `input.tanh()`.
    Tanh(UnaryOperator),
    /// `lhs` raised to the power `rhs`, rendered as `lhs.powf(rhs)`.
    Powf(BinaryOperator),
    /// `lhs` raised to the power `rhs`, rendered as `lhs.powi(rhs)`
    /// (presumably with an integer exponent — confirm).
    Powi(BinaryOperator),
    /// Square root, rendered as `input.sqrt()`.
    Sqrt(UnaryOperator),
    /// Rounding, rendered as `input.round()`.
    Round(UnaryOperator),
    /// Floor, rendered as `input.floor()`.
    Floor(UnaryOperator),
    /// Ceiling, rendered as `input.ceil()`.
    Ceil(UnaryOperator),
    /// Truncation, rendered as `input.trunc()`.
    Trunc(UnaryOperator),
    /// Error function, rendered as `input.erf()`.
    Erf(UnaryOperator),
    /// Reciprocal, rendered as `input.recip()`.
    Recip(UnaryOperator),
    /// Clamps `input` between `min_value` and `max_value`, rendered as
    /// `input.clamp(min_value, max_value)`.
    Clamp(ClampOperator),
    /// Modulo, rendered as `lhs % rhs`. This is a separate operation from
    /// [`Arithmetic::Remainder`]. The exact difference between them (e.g. sign
    /// handling for negative operands) is not visible here; confirm it against
    /// the backends.
    Modulo(BinaryOperator),
    /// Negation, rendered as `-input`.
    Neg(UnaryOperator),
    /// Maximum of the two operands, rendered as `lhs.max(rhs)`.
    #[operation(commutative)]
    Max(BinaryOperator),
    /// Minimum of the two operands, rendered as `lhs.min(rhs)`.
    #[operation(commutative)]
    Min(BinaryOperator),
    /// Remainder, rendered as `lhs rem rhs`. This is a separate operation from
    /// [`Arithmetic::Modulo`]; see the note on that variant.
    Remainder(BinaryOperator),
    /// Magnitude (length) of the input, rendered as `input.length()`.
    Magnitude(UnaryOperator),
    /// Normalization of the input, rendered as `input.normalize()`
    /// (presumably scales a vector to unit length — confirm).
    Normalize(UnaryOperator),
    /// Dot product, rendered as `lhs.dot(rhs)`.
    #[operation(commutative)]
    Dot(BinaryOperator),
    /// High part of the product, rendered as `mul_hi(lhs, rhs)` (presumably the
    /// upper half of the double-width integer product — confirm).
    #[operation(commutative)]
    MulHi(BinaryOperator),
}
53
54impl Display for Arithmetic {
55 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
56 match self {
57 Arithmetic::Add(op) => write!(f, "{} + {}", op.lhs, op.rhs),
58 Arithmetic::SaturatingAdd(op) => write!(f, "saturating_add({}, {})", op.lhs, op.rhs),
59 Arithmetic::Fma(op) => write!(f, "{} * {} + {}", op.a, op.b, op.c),
60 Arithmetic::Sub(op) => write!(f, "{} - {}", op.lhs, op.rhs),
61 Arithmetic::SaturatingSub(op) => write!(f, "saturating_sub({}, {})", op.lhs, op.rhs),
62 Arithmetic::Mul(op) => write!(f, "{} * {}", op.lhs, op.rhs),
63 Arithmetic::Div(op) => write!(f, "{} / {}", op.lhs, op.rhs),
64 Arithmetic::Abs(op) => write!(f, "{}.abs()", op.input),
65 Arithmetic::Exp(op) => write!(f, "{}.exp()", op.input),
66 Arithmetic::Log(op) => write!(f, "{}.log()", op.input),
67 Arithmetic::Log1p(op) => write!(f, "{}.log_1p()", op.input),
68 Arithmetic::Cos(op) => write!(f, "{}.cos()", op.input),
69 Arithmetic::Sin(op) => write!(f, "{}.sin()", op.input),
70 Arithmetic::Tanh(op) => write!(f, "{}.tanh()", op.input),
71 Arithmetic::Powf(op) => write!(f, "{}.powf({})", op.lhs, op.rhs),
72 Arithmetic::Powi(op) => write!(f, "{}.powi({})", op.lhs, op.rhs),
73 Arithmetic::Sqrt(op) => write!(f, "{}.sqrt()", op.input),
74 Arithmetic::Round(op) => write!(f, "{}.round()", op.input),
75 Arithmetic::Floor(op) => write!(f, "{}.floor()", op.input),
76 Arithmetic::Ceil(op) => write!(f, "{}.ceil()", op.input),
77 Arithmetic::Trunc(op) => write!(f, "{}.trunc()", op.input),
78 Arithmetic::Erf(op) => write!(f, "{}.erf()", op.input),
79 Arithmetic::Recip(op) => write!(f, "{}.recip()", op.input),
80 Arithmetic::Clamp(op) => {
81 write!(f, "{}.clamp({}, {})", op.input, op.min_value, op.max_value)
82 }
83 Arithmetic::Modulo(op) => write!(f, "{} % {}", op.lhs, op.rhs),
84 Arithmetic::Neg(op) => write!(f, "-{}", op.input),
85 Arithmetic::Max(op) => write!(f, "{}.max({})", op.lhs, op.rhs),
86 Arithmetic::Min(op) => write!(f, "{}.min({})", op.lhs, op.rhs),
87 Arithmetic::Remainder(op) => write!(f, "{} rem {}", op.lhs, op.rhs),
88 Arithmetic::Magnitude(op) => write!(f, "{}.length()", op.input),
89 Arithmetic::Normalize(op) => write!(f, "{}.normalize()", op.input),
90 Arithmetic::Dot(op) => write!(f, "{}.dot({})", op.lhs, op.rhs),
91 Arithmetic::MulHi(op) => write!(f, "mul_hi({}, {})", op.lhs, op.rhs),
92 }
93 }
94}
95
/// Operands of [`Arithmetic::Clamp`]. The `Display` impl for `Arithmetic`
/// renders them as `input.clamp(min_value, max_value)`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct ClampOperator {
    /// Value being clamped.
    pub input: Variable,
    /// Lower bound of the clamp.
    pub min_value: Variable,
    /// Upper bound of the clamp.
    pub max_value: Variable,
}
104
/// Operator that carries a single `variable` operand.
///
/// NOTE(review): the `Arithmetic` enum in this file does not use this struct.
/// Its name suggests a read from a global binding; confirm at its use sites.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct ReadGlobalOperator {
    /// The variable operand (presumably the global being read — confirm).
    pub variable: Variable,
}
111
/// A `variable` operand plus two `usize` positions, `tensor_read_pos` and
/// `tensor_layout_pos`.
///
/// Unlike the other operator structs in this file, this one does not derive
/// `OperationArgs`. NOTE(review): presumably that derive only supports
/// `Variable` fields; confirm. The `Arithmetic` enum in this file does not use
/// this struct either.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash)]
#[allow(missing_docs)]
pub struct ReadGlobalWithLayoutOperator {
    /// The variable operand (presumably the global being read — confirm).
    pub variable: Variable,
    /// Position of the tensor to read. Presumably an index into the kernel's
    /// tensor bindings; confirm.
    pub tensor_read_pos: usize,
    /// Position of the tensor whose layout is used. Presumably an index into
    /// the kernel's tensor bindings; confirm.
    pub tensor_layout_pos: usize,
}
120
/// Operands of [`Arithmetic::Fma`]. The `Display` impl for `Arithmetic`
/// renders them as `a * b + c`.
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, TypeHash, PartialEq, Eq, Hash, OperationArgs)]
#[allow(missing_docs)]
pub struct FmaOperator {
    /// First factor of the product.
    pub a: Variable,
    /// Second factor of the product.
    pub b: Variable,
    /// Addend applied to the product `a * b`.
    pub c: Variable,
}