/// Identifies a single elementwise operation.
///
/// Each variant has a snake_case name via [`ElementwiseOp::as_str`] and is
/// classified by [`ElementwiseOp::is_binary`] and [`ElementwiseOp::needs_scalar`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ElementwiseOp {
    // Binary group (`is_binary` == true).
    Add,
    Sub,
    Mul,
    Div,
    // Neither binary nor scalar-taking.
    Relu,
    Gelu,
    Sigmoid,
    Silu,
    Tanh,
    Neg,
    Abs,
    Sqrt,
    Rsqrt,
    Exp,
    Log,
    // Scalar-taking (`needs_scalar` == true), not binary.
    Scale,
    AddScalar,
    // Neither binary nor scalar-taking.
    Ceil,
    Floor,
    HardSigmoid,
    HardSwish,
    Softplus,
    LeakyRelu,
    OneMinus,
    // Binary; `FusedScaleAdd` is additionally scalar-taking.
    FusedAddRelu,
    FusedScaleAdd,
    // Binary group: power, min/max, comparisons and logical-style combinators.
    Pow,
    Min,
    Max,
    CmpEq,
    CmpNe,
    CmpLt,
    CmpGt,
    CmpLe,
    CmpGe,
    OrMax,
    OrProbSum,
    Nand,
    Nor,
    Xor,
    // Scalar-taking, not binary.
    Fill,
}

impl ElementwiseOp {
    /// Returns the snake_case name of this operation.
    ///
    /// Every variant maps to a distinct string, e.g. `HardSigmoid` -> `"hard_sigmoid"`.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        // Arms are listed in the same order as the enum declaration.
        match self {
            Self::Add => "add",
            Self::Sub => "sub",
            Self::Mul => "mul",
            Self::Div => "div",
            Self::Relu => "relu",
            Self::Gelu => "gelu",
            Self::Sigmoid => "sigmoid",
            Self::Silu => "silu",
            Self::Tanh => "tanh",
            Self::Neg => "neg",
            Self::Abs => "abs",
            Self::Sqrt => "sqrt",
            Self::Rsqrt => "rsqrt",
            Self::Exp => "exp",
            Self::Log => "log",
            Self::Scale => "scale",
            Self::AddScalar => "add_scalar",
            Self::Ceil => "ceil",
            Self::Floor => "floor",
            Self::HardSigmoid => "hard_sigmoid",
            Self::HardSwish => "hard_swish",
            Self::Softplus => "softplus",
            Self::LeakyRelu => "leaky_relu",
            Self::OneMinus => "one_minus",
            Self::FusedAddRelu => "fused_add_relu",
            Self::FusedScaleAdd => "fused_scale_add",
            Self::Pow => "pow",
            Self::Min => "min",
            Self::Max => "max",
            Self::CmpEq => "cmp_eq",
            Self::CmpNe => "cmp_ne",
            Self::CmpLt => "cmp_lt",
            Self::CmpGt => "cmp_gt",
            Self::CmpLe => "cmp_le",
            Self::CmpGe => "cmp_ge",
            Self::OrMax => "or_max",
            Self::OrProbSum => "or_prob_sum",
            Self::Nand => "nand",
            Self::Nor => "nor",
            Self::Xor => "xor",
            Self::Fill => "fill",
        }
    }

    /// Returns `true` if this operation belongs to the binary group.
    ///
    /// `FusedScaleAdd` is the only operation that is both binary and
    /// scalar-taking (see [`ElementwiseOp::needs_scalar`]).
    #[must_use]
    pub const fn is_binary(self) -> bool {
        // Exhaustive on purpose (no `_` arm): adding a variant forces an
        // explicit classification here instead of silently defaulting.
        match self {
            Self::Add
            | Self::Sub
            | Self::Mul
            | Self::Div
            | Self::FusedAddRelu
            | Self::FusedScaleAdd
            | Self::Pow
            | Self::Min
            | Self::Max
            | Self::CmpEq
            | Self::CmpNe
            | Self::CmpLt
            | Self::CmpGt
            | Self::CmpLe
            | Self::CmpGe
            | Self::OrMax
            | Self::OrProbSum
            | Self::Nand
            | Self::Nor
            | Self::Xor => true,
            Self::Relu
            | Self::Gelu
            | Self::Sigmoid
            | Self::Silu
            | Self::Tanh
            | Self::Neg
            | Self::Abs
            | Self::Sqrt
            | Self::Rsqrt
            | Self::Exp
            | Self::Log
            | Self::Scale
            | Self::AddScalar
            | Self::Ceil
            | Self::Floor
            | Self::HardSigmoid
            | Self::HardSwish
            | Self::Softplus
            | Self::LeakyRelu
            | Self::OneMinus
            | Self::Fill => false,
        }
    }

    /// Returns `true` for operations that take a scalar argument:
    /// `Scale`, `AddScalar`, `FusedScaleAdd` and `Fill`.
    #[must_use]
    pub const fn needs_scalar(self) -> bool {
        // Exhaustive on purpose, mirroring `is_binary`.
        match self {
            Self::Scale | Self::AddScalar | Self::FusedScaleAdd | Self::Fill => true,
            Self::Add
            | Self::Sub
            | Self::Mul
            | Self::Div
            | Self::Relu
            | Self::Gelu
            | Self::Sigmoid
            | Self::Silu
            | Self::Tanh
            | Self::Neg
            | Self::Abs
            | Self::Sqrt
            | Self::Rsqrt
            | Self::Exp
            | Self::Log
            | Self::Ceil
            | Self::Floor
            | Self::HardSigmoid
            | Self::HardSwish
            | Self::Softplus
            | Self::LeakyRelu
            | Self::OneMinus
            | Self::FusedAddRelu
            | Self::Pow
            | Self::Min
            | Self::Max
            | Self::CmpEq
            | Self::CmpNe
            | Self::CmpLt
            | Self::CmpGt
            | Self::CmpLe
            | Self::CmpGe
            | Self::OrMax
            | Self::OrProbSum
            | Self::Nand
            | Self::Nor
            | Self::Xor => false,
        }
    }
}