oxicuda-blas 0.1.3

OxiCUDA BLAS - GPU-accelerated BLAS operations (cuBLAS equivalent)
Documentation
//! Elementwise operation type enumeration.
//!
//! Maps BLAS-level operation names to their corresponding PTX template
//! variants in [`oxicuda_ptx::templates::elementwise::ElementwiseOp`].

use oxicuda_ptx::templates::elementwise::ElementwiseOp as PtxElementwiseOp;

/// The elementwise operations that the BLAS elementwise module supports.
///
/// Every variant has a matching PTX kernel produced by
/// [`oxicuda_ptx::templates::elementwise::ElementwiseTemplate`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ElementwiseOp {
    /// ReLU (rectified linear unit) — `output[i] = max(0, input[i])`.
    Relu,
    /// GELU (Gaussian error linear unit), computed with the tanh approximation.
    Gelu,
    /// Sigmoid activation — `output[i] = 1 / (1 + exp(-input[i]))`.
    Sigmoid,
    /// SiLU (sigmoid linear unit) — `output[i] = input[i] * sigmoid(input[i])`.
    Silu,
    /// Hyperbolic tangent (tanh) activation.
    Tanh,
    /// Sum of two inputs, element by element — `C[i] = A[i] + B[i]`.
    Add,
    /// Hadamard (element-by-element) product — `C[i] = A[i] * B[i]`.
    Mul,
    /// Multiplication of every element by a scalar — `output[i] = alpha * input[i]`.
    Scale,
    /// Addition of a scalar to every element — `output[i] = input[i] + scalar`.
    AddScalar,
    /// Addition followed by ReLU in one step — `C[i] = max(0, A[i] + B[i])`.
    FusedAddRelu,
    /// Scaled sum of two inputs in one step — `C[i] = alpha * A[i] + beta * B[i]`.
    FusedScaleAdd,
}

impl ElementwiseOp {
    /// Converts this BLAS-level op to the corresponding PTX template op.
    #[allow(dead_code)]
    pub(crate) fn to_ptx_op(self) -> PtxElementwiseOp {
        match self {
            Self::Relu => PtxElementwiseOp::Relu,
            Self::Gelu => PtxElementwiseOp::Gelu,
            Self::Sigmoid => PtxElementwiseOp::Sigmoid,
            Self::Silu => PtxElementwiseOp::Silu,
            Self::Tanh => PtxElementwiseOp::Tanh,
            Self::Add => PtxElementwiseOp::Add,
            Self::Mul => PtxElementwiseOp::Mul,
            Self::Scale => PtxElementwiseOp::Scale,
            Self::AddScalar => PtxElementwiseOp::AddScalar,
            Self::FusedAddRelu => PtxElementwiseOp::FusedAddRelu,
            Self::FusedScaleAdd => PtxElementwiseOp::FusedScaleAdd,
        }
    }

    /// Returns a short lowercase name for diagnostics and logging.
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Relu => "relu",
            Self::Gelu => "gelu",
            Self::Sigmoid => "sigmoid",
            Self::Silu => "silu",
            Self::Tanh => "tanh",
            Self::Add => "add",
            Self::Mul => "mul",
            Self::Scale => "scale",
            Self::AddScalar => "add_scalar",
            Self::FusedAddRelu => "fused_add_relu",
            Self::FusedScaleAdd => "fused_scale_add",
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant of [`ElementwiseOp`], in declaration order, shared by the
    /// tests in this module.
    const ALL_OPS: [ElementwiseOp; 11] = [
        ElementwiseOp::Relu,
        ElementwiseOp::Gelu,
        ElementwiseOp::Sigmoid,
        ElementwiseOp::Silu,
        ElementwiseOp::Tanh,
        ElementwiseOp::Add,
        ElementwiseOp::Mul,
        ElementwiseOp::Scale,
        ElementwiseOp::AddScalar,
        ElementwiseOp::FusedAddRelu,
        ElementwiseOp::FusedScaleAdd,
    ];

    /// The diagnostic name of each op must already be in lowercase form.
    #[test]
    fn op_names_are_lowercase() {
        for op in ALL_OPS.iter() {
            let name = op.as_str();
            let lowered = name.to_lowercase();
            assert_eq!(name, lowered, "op name should be lowercase");
        }
    }

    /// Each op must convert to a PTX op, and the PTX op must report the same
    /// name as the BLAS-level op it came from.
    #[test]
    fn ptx_op_roundtrip() {
        for op in ALL_OPS.iter() {
            let converted = op.to_ptx_op();
            assert_eq!(converted.as_str(), op.as_str());
        }
    }
}