use oxicuda_ptx::templates::reduction::ReductionOp as PtxReductionOp;
/// A reduction operator that combines many values into one.
///
/// Each variant corresponds to a reduction operator of the PTX reduction
/// templates (see `ReductionOp::to_ptx_op`).
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ReductionOp {
    /// Combine values by addition.
    Sum,
    /// Keep the largest value.
    Max,
    /// Keep the smallest value.
    Min,
    /// Combine values by multiplication.
    Product,
}
impl ReductionOp {
    /// Converts this operator into the matching PTX template reduction operator.
    ///
    /// The mapping is one-to-one. Note that [`ReductionOp::Product`] maps to
    /// `PtxReductionOp::Prod`, whose PTX-side name is abbreviated.
    // NOTE(review): `dead_code` is allowed because the only caller visible in this
    // file is the test module; presumably code generation elsewhere in the crate
    // will call this. Confirm, and drop the attribute once a non-test caller exists.
    #[allow(dead_code)]
    #[must_use]
    pub(crate) fn to_ptx_op(self) -> PtxReductionOp {
        match self {
            Self::Sum => PtxReductionOp::Sum,
            Self::Max => PtxReductionOp::Max,
            Self::Min => PtxReductionOp::Min,
            Self::Product => PtxReductionOp::Prod,
        }
    }

    /// Returns the lowercase name of this operator.
    ///
    /// The names are `"sum"`, `"max"`, `"min"` and `"product"`. `Product` is
    /// spelled out in full here, unlike the PTX template operator's `"prod"`.
    ///
    /// This is a `const fn`, so it can also be used in constant contexts.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            Self::Sum => "sum",
            Self::Max => "max",
            Self::Min => "min",
            Self::Product => "product",
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every variant paired with the exact name `ReductionOp::as_str` must return.
    const EXPECTED_NAMES: [(ReductionOp, &str); 4] = [
        (ReductionOp::Sum, "sum"),
        (ReductionOp::Max, "max"),
        (ReductionOp::Min, "min"),
        (ReductionOp::Product, "product"),
    ];

    #[test]
    fn op_names_are_lowercase() {
        for &(op, _) in &EXPECTED_NAMES {
            let name = op.as_str();
            assert_eq!(name, name.to_lowercase(), "{:?} name is not lowercase", op);
        }
    }

    /// Pins the exact names, so that a rename that stays lowercase is still caught.
    #[test]
    fn op_names_match_expected() {
        for &(op, expected) in &EXPECTED_NAMES {
            assert_eq!(op.as_str(), expected, "unexpected name for {:?}", op);
        }
    }

    #[test]
    fn ptx_op_conversion() {
        assert_eq!(ReductionOp::Sum.to_ptx_op().as_str(), "sum");
        assert_eq!(ReductionOp::Max.to_ptx_op().as_str(), "max");
        assert_eq!(ReductionOp::Min.to_ptx_op().as_str(), "min");
        // The PTX side abbreviates `Product` to "prod".
        assert_eq!(ReductionOp::Product.to_ptx_op().as_str(), "prod");
    }
}