use crate::dtype::{DType, Element};
/// Selects a (reduce, combine) pair of binary operations.
///
/// `combine` merges two operands into one value and `reduce` folds such
/// values into an accumulator; each variant is named after the pair it
/// selects, written here as `(reduce, combine)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SemiringOp {
    /// `(min, +)`: combine by addition, reduce by minimum.
    MinPlus,
    /// `(max, +)`: combine by addition, reduce by maximum.
    MaxPlus,
    /// `(max, min)`: combine by minimum, reduce by maximum.
    MaxMin,
    /// `(min, max)`: combine by maximum, reduce by minimum.
    MinMax,
    /// `(OR, AND)`: combine by logical AND, reduce by logical OR, where any
    /// non-zero value counts as true.
    OrAnd,
    /// `(+, max)`: combine by maximum, reduce by addition.
    PlusMax,
}
impl SemiringOp {
pub fn reduce_identity_f64(self) -> f64 {
match self {
SemiringOp::MinPlus | SemiringOp::MinMax => f64::INFINITY,
SemiringOp::MaxPlus | SemiringOp::MaxMin => f64::NEG_INFINITY,
SemiringOp::OrAnd => 0.0,
SemiringOp::PlusMax => 0.0,
}
}
pub fn reduce_identity<T: Element>(self) -> T {
T::from_f64(self.reduce_identity_f64())
}
#[inline]
pub fn combine<T: Element>(self, a: T, b: T) -> T {
match self {
SemiringOp::MinPlus | SemiringOp::MaxPlus => a + b,
SemiringOp::MaxMin => {
if a <= b { a } else { b }
}
SemiringOp::MinMax | SemiringOp::PlusMax => {
if a >= b { a } else { b }
}
SemiringOp::OrAnd => {
let az = a.to_f64();
let bz = b.to_f64();
if az != 0.0 && bz != 0.0 {
T::one()
} else {
T::zero()
}
}
}
}
#[inline]
pub fn reduce<T: Element>(self, acc: T, val: T) -> T {
match self {
SemiringOp::MinPlus | SemiringOp::MinMax => {
if val <= acc { val } else { acc }
}
SemiringOp::MaxPlus | SemiringOp::MaxMin => {
if val >= acc { val } else { acc }
}
SemiringOp::OrAnd => {
let az = acc.to_f64();
let vz = val.to_f64();
if az != 0.0 || vz != 0.0 {
T::one()
} else {
T::zero()
}
}
SemiringOp::PlusMax => {
acc + val
}
}
}
pub fn combine_name(self) -> &'static str {
match self {
SemiringOp::MinPlus | SemiringOp::MaxPlus => "add",
SemiringOp::MaxMin => "min",
SemiringOp::MinMax | SemiringOp::PlusMax => "max",
SemiringOp::OrAnd => "and",
}
}
pub fn reduce_name(self) -> &'static str {
match self {
SemiringOp::MinPlus | SemiringOp::MinMax => "min",
SemiringOp::MaxPlus | SemiringOp::MaxMin => "max",
SemiringOp::OrAnd => "or",
SemiringOp::PlusMax => "add",
}
}
pub fn validate_dtype(self, dtype: DType) -> bool {
match self {
SemiringOp::OrAnd => matches!(dtype, DType::Bool | DType::U8),
_ => {
matches!(dtype, DType::F32 | DType::F64 | DType::I32 | DType::I64) || {
#[cfg(feature = "f16")]
if matches!(dtype, DType::F16 | DType::BF16) {
return true;
}
#[cfg(feature = "fp8")]
if matches!(dtype, DType::FP8E4M3 | DType::FP8E5M2) {
return true;
}
false
}
}
}
}
}
impl std::fmt::Display for SemiringOp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let (reduce, combine) = match self {
SemiringOp::MinPlus => ("min", "+"),
SemiringOp::MaxPlus => ("max", "+"),
SemiringOp::MaxMin => ("max", "min"),
SemiringOp::MinMax => ("min", "max"),
SemiringOp::OrAnd => ("OR", "AND"),
SemiringOp::PlusMax => ("+", "max"),
};
write!(f, "({}, {})", reduce, combine)
}
}
/// Unit tests pinning the combine/reduce results, identities, `Display`
/// output and dtype validation of [`SemiringOp`].
#[cfg(test)]
mod tests {
    use super::*;

    /// `MinPlus` combines by addition.
    #[test]
    fn test_min_plus_combine() {
        let combined = SemiringOp::MinPlus.combine(3.0f32, 5.0);
        assert_eq!(combined, 8.0);
    }

    /// `MinPlus` reduces to the smaller value regardless of argument order.
    #[test]
    fn test_min_plus_reduce() {
        let op = SemiringOp::MinPlus;
        let smaller_first = op.reduce(3.0f32, 5.0);
        let larger_first = op.reduce(5.0f32, 3.0);
        assert_eq!(smaller_first, 3.0);
        assert_eq!(larger_first, 3.0);
    }

    /// `MaxMin` combines by taking the minimum.
    #[test]
    fn test_max_min_combine() {
        let combined = SemiringOp::MaxMin.combine(3.0f32, 5.0);
        assert_eq!(combined, 3.0);
    }

    /// `MaxMin` reduces by taking the maximum.
    #[test]
    fn test_max_min_reduce() {
        let reduced = SemiringOp::MaxMin.reduce(3.0f32, 5.0);
        assert_eq!(reduced, 5.0);
    }

    /// `MinMax` combines by taking the maximum.
    #[test]
    fn test_min_max_combine() {
        let combined = SemiringOp::MinMax.combine(3.0f32, 5.0);
        assert_eq!(combined, 5.0);
    }

    /// `MinMax` reduces by taking the minimum.
    #[test]
    fn test_min_max_reduce() {
        let reduced = SemiringOp::MinMax.reduce(3.0f32, 5.0);
        assert_eq!(reduced, 3.0);
    }

    /// `PlusMax` combines by maximum and reduces by addition.
    #[test]
    fn test_plus_max() {
        let op = SemiringOp::PlusMax;
        let combined = op.combine(3.0f32, 5.0);
        let reduced = op.reduce(3.0f32, 5.0);
        assert_eq!(combined, 5.0);
        assert_eq!(reduced, 8.0);
    }

    /// Reduce identities for `f32`: `+∞` for min, `-∞` for max, zero for add.
    #[test]
    fn test_identity_elements() {
        let min_identity = SemiringOp::MinPlus.reduce_identity::<f32>();
        let max_identity = SemiringOp::MaxPlus.reduce_identity::<f32>();
        let add_identity = SemiringOp::PlusMax.reduce_identity::<f32>();
        assert_eq!(min_identity, f32::INFINITY);
        assert_eq!(max_identity, f32::NEG_INFINITY);
        assert_eq!(add_identity, 0.0);
    }

    /// `Display` renders the pair as `(reduce, combine)`.
    #[test]
    fn test_display() {
        assert_eq!(SemiringOp::MinPlus.to_string(), "(min, +)");
        assert_eq!(SemiringOp::OrAnd.to_string(), "(OR, AND)");
    }

    /// Numeric ops accept numeric dtypes only; `OrAnd` accepts `Bool` but not
    /// floats.
    #[test]
    fn test_validate_dtype() {
        let numeric = SemiringOp::MinPlus;
        assert!(numeric.validate_dtype(DType::F32));
        assert!(numeric.validate_dtype(DType::F64));
        assert!(numeric.validate_dtype(DType::I32));
        assert!(!numeric.validate_dtype(DType::Bool));

        let logical = SemiringOp::OrAnd;
        assert!(logical.validate_dtype(DType::Bool));
        assert!(!logical.validate_dtype(DType::F32));
    }
}