oxicuda-blas 0.1.4

OxiCUDA BLAS - GPU-accelerated BLAS operations (cuBLAS equivalent)
Documentation
//! Elementwise operation type enumeration.
//!
//! Maps BLAS-level operation names to their corresponding PTX template
//! variants in [`oxicuda_ptx::templates::elementwise::ElementwiseOp`].

use oxicuda_ptx::templates::elementwise::ElementwiseOp as PtxElementwiseOp;

/// Elementwise operation types supported by the BLAS elementwise module.
///
/// Each variant corresponds to a PTX kernel generated by [`oxicuda_ptx::templates::elementwise::ElementwiseTemplate`].
/// Use [`ElementwiseOp::as_str`] to obtain a stable lowercase name for logging.
///
/// The variants fall into three groups (see the per-variant formulas):
/// - unary ops: `output[i] = f(input[i])`;
/// - binary ops: `C[i] = f(A[i], B[i])`, including the fused variants;
/// - scalar-parameterized ops (`Scale`, `AddScalar`): one input plus a scalar.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ElementwiseOp {
    // ---- Unary activations and math functions ----
    /// Rectified linear unit: `output[i] = max(0, input[i])`.
    Relu,
    /// Gaussian error linear unit (tanh approximation).
    Gelu,
    /// Sigmoid activation: `output[i] = 1 / (1 + exp(-input[i]))`.
    Sigmoid,
    /// Sigmoid linear unit: `output[i] = input[i] * sigmoid(input[i])`.
    Silu,
    /// Hyperbolic tangent activation: `output[i] = tanh(input[i])`.
    Tanh,
    /// Arithmetic negation: `output[i] = -input[i]`.
    Neg,
    /// Absolute value: `output[i] = |input[i]|`.
    Abs,
    /// Square root: `output[i] = sqrt(input[i])`.
    Sqrt,
    /// Reciprocal square root: `output[i] = 1 / sqrt(input[i])`.
    Rsqrt,
    /// Exponential: `output[i] = exp(input[i])`.
    Exp,
    /// Natural logarithm: `output[i] = ln(input[i])`.
    Log,
    /// Ceiling: `output[i] = ceil(input[i])`.
    Ceil,
    /// Floor: `output[i] = floor(input[i])`.
    Floor,
    /// Hard sigmoid: `output[i] = max(0, min(1, 0.2*input[i] + 0.5))`.
    HardSigmoid,
    /// Hard swish: `output[i] = input[i] * max(0, min(6, input[i]+3)) / 6`.
    HardSwish,
    /// Softplus: `output[i] = ln(1 + exp(input[i]))`.
    Softplus,
    /// Leaky relu: `output[i] = input[i] >= 0 ? input[i] : 0.01 * input[i]`.
    LeakyRelu,
    /// One-minus: `output[i] = 1 - input[i]`.
    OneMinus,
    // ---- Binary arithmetic ----
    /// Element-wise addition: `C[i] = A[i] + B[i]`.
    Add,
    /// Element-wise subtraction: `C[i] = A[i] - B[i]`.
    Sub,
    /// Element-wise multiplication (Hadamard product): `C[i] = A[i] * B[i]`.
    Mul,
    /// Element-wise division: `C[i] = A[i] / B[i]`.
    Div,
    /// Element-wise power: `C[i] = A[i]^B[i]`.
    Pow,
    /// Element-wise minimum: `C[i] = min(A[i], B[i])`.
    Min,
    /// Element-wise maximum: `C[i] = max(A[i], B[i])`.
    Max,
    // ---- Binary comparisons (result is 1.0 for true, 0.0 for false) ----
    /// Comparison equal: `C[i] = (A[i] == B[i]) ? 1.0 : 0.0`.
    CmpEq,
    /// Comparison not-equal: `C[i] = (A[i] != B[i]) ? 1.0 : 0.0`.
    CmpNe,
    /// Comparison less-than: `C[i] = (A[i] < B[i]) ? 1.0 : 0.0`.
    CmpLt,
    /// Comparison greater-than: `C[i] = (A[i] > B[i]) ? 1.0 : 0.0`.
    CmpGt,
    /// Comparison less-or-equal: `C[i] = (A[i] <= B[i]) ? 1.0 : 0.0`.
    CmpLe,
    /// Comparison greater-or-equal: `C[i] = (A[i] >= B[i]) ? 1.0 : 0.0`.
    CmpGe,
    // ---- Binary fuzzy logic ----
    // NOTE(review): these formulas read as fuzzy-logic connectives, which
    // presumably expect inputs in [0, 1]; nothing here shows the kernels
    // clamping or validating that range — confirm with the PTX templates.
    /// Fuzzy OR via max: `C[i] = max(A[i], B[i])`.
    ///
    /// Same formula as [`ElementwiseOp::Max`].
    OrMax,
    /// Probabilistic OR: `C[i] = A[i] + B[i] - A[i]*B[i]`.
    OrProbSum,
    /// Fuzzy NAND: `C[i] = 1 - A[i]*B[i]`.
    Nand,
    /// Fuzzy NOR: `C[i] = 1 - (A[i] + B[i] - A[i]*B[i])`.
    Nor,
    /// Fuzzy XOR: `C[i] = A[i] + B[i] - 2*A[i]*B[i]`.
    Xor,
    // ---- Scalar-parameterized unary ops ----
    /// Scalar scaling: `output[i] = alpha * input[i]`.
    Scale,
    /// Add scalar: `output[i] = input[i] + scalar`.
    AddScalar,
    // ---- Fused binary ops ----
    /// Fused add + ReLU: `C[i] = max(0, A[i] + B[i])`.
    FusedAddRelu,
    /// Fused scale-add: `C[i] = alpha * A[i] + beta * B[i]`.
    FusedScaleAdd,
}

impl ElementwiseOp {
    /// Converts this BLAS-level op to the corresponding PTX template op.
    ///
    /// Every variant maps to the PTX variant of the same name. The `match` is
    /// exhaustive, so adding a variant to [`ElementwiseOp`] without extending
    /// this mapping is a compile error rather than a silent fallback.
    // Within this file only the unit tests call this function, hence the
    // dead-code allowance. NOTE(review): remove the `allow` once a non-test
    // caller in the crate uses it — confirm whether one already exists.
    #[allow(dead_code)]
    #[must_use]
    pub(crate) fn to_ptx_op(self) -> PtxElementwiseOp {
        match self {
            Self::Relu => PtxElementwiseOp::Relu,
            Self::Gelu => PtxElementwiseOp::Gelu,
            Self::Sigmoid => PtxElementwiseOp::Sigmoid,
            Self::Silu => PtxElementwiseOp::Silu,
            Self::Tanh => PtxElementwiseOp::Tanh,
            Self::Neg => PtxElementwiseOp::Neg,
            Self::Abs => PtxElementwiseOp::Abs,
            Self::Sqrt => PtxElementwiseOp::Sqrt,
            Self::Rsqrt => PtxElementwiseOp::Rsqrt,
            Self::Exp => PtxElementwiseOp::Exp,
            Self::Log => PtxElementwiseOp::Log,
            Self::Ceil => PtxElementwiseOp::Ceil,
            Self::Floor => PtxElementwiseOp::Floor,
            Self::HardSigmoid => PtxElementwiseOp::HardSigmoid,
            Self::HardSwish => PtxElementwiseOp::HardSwish,
            Self::Softplus => PtxElementwiseOp::Softplus,
            Self::LeakyRelu => PtxElementwiseOp::LeakyRelu,
            Self::OneMinus => PtxElementwiseOp::OneMinus,
            Self::Add => PtxElementwiseOp::Add,
            Self::Sub => PtxElementwiseOp::Sub,
            Self::Mul => PtxElementwiseOp::Mul,
            Self::Div => PtxElementwiseOp::Div,
            Self::Pow => PtxElementwiseOp::Pow,
            Self::Min => PtxElementwiseOp::Min,
            Self::Max => PtxElementwiseOp::Max,
            Self::CmpEq => PtxElementwiseOp::CmpEq,
            Self::CmpNe => PtxElementwiseOp::CmpNe,
            Self::CmpLt => PtxElementwiseOp::CmpLt,
            Self::CmpGt => PtxElementwiseOp::CmpGt,
            Self::CmpLe => PtxElementwiseOp::CmpLe,
            Self::CmpGe => PtxElementwiseOp::CmpGe,
            Self::OrMax => PtxElementwiseOp::OrMax,
            Self::OrProbSum => PtxElementwiseOp::OrProbSum,
            Self::Nand => PtxElementwiseOp::Nand,
            Self::Nor => PtxElementwiseOp::Nor,
            Self::Xor => PtxElementwiseOp::Xor,
            Self::Scale => PtxElementwiseOp::Scale,
            Self::AddScalar => PtxElementwiseOp::AddScalar,
            Self::FusedAddRelu => PtxElementwiseOp::FusedAddRelu,
            Self::FusedScaleAdd => PtxElementwiseOp::FusedScaleAdd,
        }
    }

    /// Returns a short lowercase name for diagnostics and logging.
    ///
    /// Names are the `snake_case` form of the variant name, e.g.
    /// [`ElementwiseOp::HardSigmoid`] becomes `"hard_sigmoid"`. The function
    /// is `const`, so the name can also be computed in constant contexts.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Relu => "relu",
            Self::Gelu => "gelu",
            Self::Sigmoid => "sigmoid",
            Self::Silu => "silu",
            Self::Tanh => "tanh",
            Self::Neg => "neg",
            Self::Abs => "abs",
            Self::Sqrt => "sqrt",
            Self::Rsqrt => "rsqrt",
            Self::Exp => "exp",
            Self::Log => "log",
            Self::Ceil => "ceil",
            Self::Floor => "floor",
            Self::HardSigmoid => "hard_sigmoid",
            Self::HardSwish => "hard_swish",
            Self::Softplus => "softplus",
            Self::LeakyRelu => "leaky_relu",
            Self::OneMinus => "one_minus",
            Self::Add => "add",
            Self::Sub => "sub",
            Self::Mul => "mul",
            Self::Div => "div",
            Self::Pow => "pow",
            Self::Min => "min",
            Self::Max => "max",
            Self::CmpEq => "cmp_eq",
            Self::CmpNe => "cmp_ne",
            Self::CmpLt => "cmp_lt",
            Self::CmpGt => "cmp_gt",
            Self::CmpLe => "cmp_le",
            Self::CmpGe => "cmp_ge",
            Self::OrMax => "or_max",
            Self::OrProbSum => "or_prob_sum",
            Self::Nand => "nand",
            Self::Nor => "nor",
            Self::Xor => "xor",
            Self::Scale => "scale",
            Self::AddScalar => "add_scalar",
            Self::FusedAddRelu => "fused_add_relu",
            Self::FusedScaleAdd => "fused_scale_add",
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Every `ElementwiseOp` variant, in declaration order.
    ///
    /// Shared by all tests so the variant list is maintained in exactly one
    /// place. This list is not compiler-checked: extend it whenever a variant
    /// is added to `ElementwiseOp`.
    const ALL_OPS: [ElementwiseOp; 40] = [
        ElementwiseOp::Relu,
        ElementwiseOp::Gelu,
        ElementwiseOp::Sigmoid,
        ElementwiseOp::Silu,
        ElementwiseOp::Tanh,
        ElementwiseOp::Neg,
        ElementwiseOp::Abs,
        ElementwiseOp::Sqrt,
        ElementwiseOp::Rsqrt,
        ElementwiseOp::Exp,
        ElementwiseOp::Log,
        ElementwiseOp::Ceil,
        ElementwiseOp::Floor,
        ElementwiseOp::HardSigmoid,
        ElementwiseOp::HardSwish,
        ElementwiseOp::Softplus,
        ElementwiseOp::LeakyRelu,
        ElementwiseOp::OneMinus,
        ElementwiseOp::Add,
        ElementwiseOp::Sub,
        ElementwiseOp::Mul,
        ElementwiseOp::Div,
        ElementwiseOp::Pow,
        ElementwiseOp::Min,
        ElementwiseOp::Max,
        ElementwiseOp::CmpEq,
        ElementwiseOp::CmpNe,
        ElementwiseOp::CmpLt,
        ElementwiseOp::CmpGt,
        ElementwiseOp::CmpLe,
        ElementwiseOp::CmpGe,
        ElementwiseOp::OrMax,
        ElementwiseOp::OrProbSum,
        ElementwiseOp::Nand,
        ElementwiseOp::Nor,
        ElementwiseOp::Xor,
        ElementwiseOp::Scale,
        ElementwiseOp::AddScalar,
        ElementwiseOp::FusedAddRelu,
        ElementwiseOp::FusedScaleAdd,
    ];

    #[test]
    fn op_names_are_lowercase() {
        for &op in &ALL_OPS {
            let name = op.as_str();
            assert_eq!(name, name.to_lowercase(), "op name should be lowercase");
        }
    }

    #[test]
    fn ops_and_names_are_unique() {
        // Guard the shared list itself against accidental duplicates.
        let distinct_ops: HashSet<ElementwiseOp> = ALL_OPS.iter().copied().collect();
        assert_eq!(distinct_ops.len(), ALL_OPS.len(), "ALL_OPS contains a duplicate");

        // Names are used for diagnostics, so two ops must never share one.
        let mut seen = HashSet::new();
        for &op in &ALL_OPS {
            assert!(seen.insert(op.as_str()), "duplicate op name: {}", op.as_str());
        }
    }

    #[test]
    fn ptx_op_roundtrip() {
        // Every op must convert, and the PTX op must report the same name.
        for &op in &ALL_OPS {
            assert_eq!(op.to_ptx_op().as_str(), op.as_str());
        }
    }
}