use oxicuda_ptx::templates::elementwise::ElementwiseOp as PtxElementwiseOp;
/// Elementwise operations that can be lowered to a PTX elementwise kernel template.
///
/// Each variant maps onto the same-named variant of
/// `oxicuda_ptx::templates::elementwise::ElementwiseOp` (via `to_ptx_op`). Each variant
/// also has a stable lowercase name (via `as_str`).
///
/// NOTE(review): the exact math of each variant is defined by the PTX template, not here.
/// Examples are which GELU approximation is used, operand arity, and the formulas of the
/// fused ops — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ElementwiseOp {
/// ReLU activation (`"relu"`).
Relu,
/// GELU activation (`"gelu"`).
Gelu,
/// Sigmoid activation (`"sigmoid"`).
Sigmoid,
/// SiLU activation (`"silu"`).
Silu,
/// Hyperbolic tangent (`"tanh"`).
Tanh,
/// Elementwise addition (`"add"`); presumably of two inputs — confirm arity in the template.
Add,
/// Elementwise multiplication (`"mul"`); presumably of two inputs — confirm arity in the template.
Mul,
/// Scaling (`"scale"`); presumably multiplies each element by a scalar — confirm.
Scale,
/// Scalar addition (`"add_scalar"`); presumably adds a scalar to each element — confirm.
AddScalar,
/// Fused add + ReLU (`"fused_add_relu"`); presumably `relu(a + b)` — confirm in the template.
FusedAddRelu,
/// Fused scale + add (`"fused_scale_add"`); exact formula is defined by the PTX template — confirm.
FusedScaleAdd,
}
impl ElementwiseOp {
    /// Converts this op into the corresponding PTX template op
    /// (`oxicuda_ptx::templates::elementwise::ElementwiseOp`).
    ///
    /// The mapping is one-to-one: every variant maps to the PTX variant of the same name.
    // `dead_code` is allowed because this crate-internal helper may have no non-test caller
    // in some builds. NOTE(review): drop the allow once a kernel launcher calls it.
    #[allow(dead_code)]
    #[must_use]
    pub(crate) fn to_ptx_op(self) -> PtxElementwiseOp {
        match self {
            Self::Relu => PtxElementwiseOp::Relu,
            Self::Gelu => PtxElementwiseOp::Gelu,
            Self::Sigmoid => PtxElementwiseOp::Sigmoid,
            Self::Silu => PtxElementwiseOp::Silu,
            Self::Tanh => PtxElementwiseOp::Tanh,
            Self::Add => PtxElementwiseOp::Add,
            Self::Mul => PtxElementwiseOp::Mul,
            Self::Scale => PtxElementwiseOp::Scale,
            Self::AddScalar => PtxElementwiseOp::AddScalar,
            Self::FusedAddRelu => PtxElementwiseOp::FusedAddRelu,
            Self::FusedScaleAdd => PtxElementwiseOp::FusedScaleAdd,
        }
    }

    /// Returns the stable, lowercase snake_case name of this op (e.g. `"add_scalar"`).
    ///
    /// This is a `const fn`, so it can also be used in constant contexts.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Relu => "relu",
            Self::Gelu => "gelu",
            Self::Sigmoid => "sigmoid",
            Self::Silu => "silu",
            Self::Tanh => "tanh",
            Self::Add => "add",
            Self::Mul => "mul",
            Self::Scale => "scale",
            Self::AddScalar => "add_scalar",
            Self::FusedAddRelu => "fused_add_relu",
            Self::FusedScaleAdd => "fused_scale_add",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Every variant, in declaration order; shared by all tests so the list lives in one place.
    const ALL_OPS: [ElementwiseOp; 11] = [
        ElementwiseOp::Relu,
        ElementwiseOp::Gelu,
        ElementwiseOp::Sigmoid,
        ElementwiseOp::Silu,
        ElementwiseOp::Tanh,
        ElementwiseOp::Add,
        ElementwiseOp::Mul,
        ElementwiseOp::Scale,
        ElementwiseOp::AddScalar,
        ElementwiseOp::FusedAddRelu,
        ElementwiseOp::FusedScaleAdd,
    ];

    /// Expected position of `op` in `ALL_OPS`.
    ///
    /// The match is deliberately exhaustive (no `_` arm). Adding a variant to `ElementwiseOp`
    /// therefore stops this from compiling until both this function and `ALL_OPS` are extended.
    fn index_of(op: ElementwiseOp) -> usize {
        match op {
            ElementwiseOp::Relu => 0,
            ElementwiseOp::Gelu => 1,
            ElementwiseOp::Sigmoid => 2,
            ElementwiseOp::Silu => 3,
            ElementwiseOp::Tanh => 4,
            ElementwiseOp::Add => 5,
            ElementwiseOp::Mul => 6,
            ElementwiseOp::Scale => 7,
            ElementwiseOp::AddScalar => 8,
            ElementwiseOp::FusedAddRelu => 9,
            ElementwiseOp::FusedScaleAdd => 10,
        }
    }

    #[test]
    fn all_ops_covers_every_variant() {
        // A fixed length of 11 plus a unique index per entry means no variant is missing
        // or duplicated.
        for (i, &op) in ALL_OPS.iter().enumerate() {
            assert_eq!(index_of(op), i, "ALL_OPS out of sync with ElementwiseOp at {:?}", op);
        }
    }

    #[test]
    fn op_names_are_lowercase() {
        for op in ALL_OPS {
            let name = op.as_str();
            assert_eq!(name, name.to_lowercase(), "op name should be lowercase");
        }
    }

    #[test]
    fn op_names_are_unique() {
        let names: HashSet<&str> = ALL_OPS.iter().map(|op| op.as_str()).collect();
        assert_eq!(names.len(), ALL_OPS.len(), "op names must be distinct");
    }

    #[test]
    fn ptx_op_roundtrip() {
        // The PTX-side op must report the same name as the high-level op it was mapped from.
        for op in ALL_OPS {
            assert_eq!(op.to_ptx_op().as_str(), op.as_str());
        }
    }
}