oxicuda-blas 0.1.3

OxiCUDA BLAS - GPU-accelerated BLAS operations (cuBLAS equivalent)
Documentation
//! Reduction operation type enumeration.
//!
//! Maps BLAS-level reduction names to their corresponding PTX template
//! variants in [`oxicuda_ptx::templates::reduction::ReductionOp`].

use oxicuda_ptx::templates::reduction::ReductionOp as PtxReductionOp;

/// Reduction operation types supported by the BLAS reduction module.
///
/// Each variant corresponds to a PTX kernel generated by
/// [`oxicuda_ptx::templates::reduction::ReductionTemplate`]. The crate-internal
/// `to_ptx_op` conversion maps each variant to the matching
/// [`oxicuda_ptx::templates::reduction::ReductionOp`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionOp {
    /// Summation: `result = sum(input[i])`.
    Sum,
    /// Maximum: `result = max(input[i])`.
    Max,
    /// Minimum: `result = min(input[i])`.
    Min,
    /// Product: `result = prod(input[i])`.
    ///
    /// Note that the corresponding PTX template variant is named `Prod`.
    Product,
}

impl ReductionOp {
    /// Converts this operation to the corresponding PTX template reduction op.
    ///
    /// The mapping is one-to-one. The only naming difference is that
    /// [`ReductionOp::Product`] maps to the PTX variant `Prod`.
    // The allow is kept because this crate-internal conversion may have no
    // non-test caller yet; it is exercised by this module's unit tests.
    // NOTE(review): remove the allow once a production caller exists.
    #[allow(dead_code)]
    #[must_use]
    pub(crate) const fn to_ptx_op(self) -> PtxReductionOp {
        match self {
            Self::Sum => PtxReductionOp::Sum,
            Self::Max => PtxReductionOp::Max,
            Self::Min => PtxReductionOp::Min,
            Self::Product => PtxReductionOp::Prod,
        }
    }

    /// Returns a short, static, lowercase name for diagnostics and logging.
    ///
    /// The names are `"sum"`, `"max"`, `"min"` and `"product"`. They are not
    /// guaranteed to match the PTX-side names: the PTX op for
    /// [`ReductionOp::Product`] is named `"prod"`.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Sum => "sum",
            Self::Max => "max",
            Self::Min => "min",
            Self::Product => "product",
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant of [`ReductionOp`]; update when a variant is added.
    const ALL_OPS: [ReductionOp; 4] = [
        ReductionOp::Sum,
        ReductionOp::Max,
        ReductionOp::Min,
        ReductionOp::Product,
    ];

    #[test]
    fn op_names_are_lowercase() {
        for op in ALL_OPS {
            let name = op.as_str();
            // An empty string is trivially "lowercase", so reject it explicitly.
            assert!(!name.is_empty(), "{:?} has an empty name", op);
            assert_eq!(name, name.to_lowercase(), "{:?} name is not lowercase", op);
        }
    }

    #[test]
    fn op_names_are_exact_and_distinct() {
        assert_eq!(ReductionOp::Sum.as_str(), "sum");
        assert_eq!(ReductionOp::Max.as_str(), "max");
        assert_eq!(ReductionOp::Min.as_str(), "min");
        assert_eq!(ReductionOp::Product.as_str(), "product");

        // Names are used in diagnostics, so two ops must never share one.
        for (i, a) in ALL_OPS.iter().enumerate() {
            for b in &ALL_OPS[i + 1..] {
                assert_ne!(a.as_str(), b.as_str(), "{:?} and {:?} share a name", a, b);
            }
        }
    }

    #[test]
    fn ptx_op_conversion() {
        assert_eq!(ReductionOp::Sum.to_ptx_op().as_str(), "sum");
        assert_eq!(ReductionOp::Max.to_ptx_op().as_str(), "max");
        assert_eq!(ReductionOp::Min.to_ptx_op().as_str(), "min");
        assert_eq!(ReductionOp::Product.to_ptx_op().as_str(), "prod");
    }
}