use std::fmt::Debug;
/// Binary operation used to fold many values into one reduction result.
///
/// The inherent methods on this type translate each variant into the atomic
/// function name or C-style operator associated with it, and the `Display`
/// impl provides a lowercase label.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ReductionOp {
    /// Addition (`+`).
    Sum,
    /// Minimum of the operands.
    Min,
    /// Maximum of the operands.
    Max,
    /// Bitwise AND (`&`).
    And,
    /// Bitwise OR (`|`).
    Or,
    /// Bitwise XOR (`^`).
    Xor,
    /// Multiplication (`*`).
    Product,
}
impl ReductionOp {
#[must_use]
pub fn atomic_name(&self) -> &'static str {
match self {
ReductionOp::Sum => "atomicAdd",
ReductionOp::Min => "atomicMin",
ReductionOp::Max => "atomicMax",
ReductionOp::And => "atomicAnd",
ReductionOp::Or => "atomicOr",
ReductionOp::Xor => "atomicXor",
ReductionOp::Product => "atomicMul", }
}
#[must_use]
pub fn wgsl_atomic_name(&self) -> Option<&'static str> {
match self {
ReductionOp::Sum => Some("atomicAdd"),
ReductionOp::Min => Some("atomicMin"),
ReductionOp::Max => Some("atomicMax"),
ReductionOp::And => Some("atomicAnd"),
ReductionOp::Or => Some("atomicOr"),
ReductionOp::Xor => Some("atomicXor"),
ReductionOp::Product => None, }
}
#[must_use]
pub fn c_operator(&self) -> &'static str {
match self {
ReductionOp::Sum => "+",
ReductionOp::Min => "min",
ReductionOp::Max => "max",
ReductionOp::And => "&",
ReductionOp::Or => "|",
ReductionOp::Xor => "^",
ReductionOp::Product => "*",
}
}
#[must_use]
pub const fn is_commutative(&self) -> bool {
true }
#[must_use]
pub const fn is_associative(&self) -> bool {
true
}
}
impl std::fmt::Display for ReductionOp {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ReductionOp::Sum => write!(f, "sum"),
ReductionOp::Min => write!(f, "min"),
ReductionOp::Max => write!(f, "max"),
ReductionOp::And => write!(f, "and"),
ReductionOp::Or => write!(f, "or"),
ReductionOp::Xor => write!(f, "xor"),
ReductionOp::Product => write!(f, "product"),
}
}
}
/// Scalar element type that can take part in a reduction.
///
/// Implementors provide the starting value and the pairwise combining step
/// for every [`ReductionOp`], plus the type names used when emitting CUDA
/// and WGSL source.
pub trait ReductionScalar: Copy + Debug + Default + Send + Sync + 'static {
    /// Returns the starting (neutral) value for `op`.
    fn identity(op: ReductionOp) -> Self;

    /// Folds `a` and `b` into one value according to `op`.
    fn combine(a: Self, b: Self, op: ReductionOp) -> Self;

    /// Size of the implementing type in bytes, as reported by `size_of`.
    fn size_bytes() -> usize {
        core::mem::size_of::<Self>()
    }

    /// Name of the corresponding type in CUDA source.
    fn cuda_type() -> &'static str;

    /// Name of the type used for this scalar in WGSL source.
    fn wgsl_type() -> &'static str;
}
impl ReductionScalar for f32 {
    /// Returns the neutral element of `op` for `f32`.
    ///
    /// For ordinary (non-NaN) inputs, combining the returned value with `x`
    /// gives back `x`. `And`, `Or` and `Xor` operate on the raw IEEE-754 bit
    /// pattern in `combine`, so their neutral elements are bit patterns:
    /// all zeros (`0.0`) for `Or`/`Xor` and all ones for `And`. The all-ones
    /// pattern reads as a NaN when interpreted as a float.
    fn identity(op: ReductionOp) -> Self {
        match op {
            ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0.0,
            ReductionOp::Min => f32::INFINITY,
            ReductionOp::Max => f32::NEG_INFINITY,
            ReductionOp::Product => 1.0,
            // `1.0` is 0x3F80_0000 and would clear most bits of the other
            // operand; only the all-ones pattern leaves a bitwise AND unchanged.
            ReductionOp::And => f32::from_bits(u32::MAX),
        }
    }

    /// Combines `a` and `b` according to `op`.
    ///
    /// Arithmetic variants use the usual float operations. The bitwise
    /// variants are applied to the operands' bit patterns and the result is
    /// reinterpreted as an `f32`.
    fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
        match op {
            ReductionOp::Sum => a + b,
            ReductionOp::Min => a.min(b),
            ReductionOp::Max => a.max(b),
            ReductionOp::Product => a * b,
            ReductionOp::And => f32::from_bits(a.to_bits() & b.to_bits()),
            ReductionOp::Or => f32::from_bits(a.to_bits() | b.to_bits()),
            ReductionOp::Xor => f32::from_bits(a.to_bits() ^ b.to_bits()),
        }
    }

    /// CUDA type name: `"float"`.
    fn cuda_type() -> &'static str {
        "float"
    }

    /// WGSL type name: `"f32"`.
    fn wgsl_type() -> &'static str {
        "f32"
    }
}
impl ReductionScalar for f64 {
    /// Returns the neutral element of `op` for `f64`.
    ///
    /// For ordinary (non-NaN) inputs, combining the returned value with `x`
    /// gives back `x`. `And`, `Or` and `Xor` operate on the raw IEEE-754 bit
    /// pattern in `combine`, so their neutral elements are bit patterns:
    /// all zeros (`0.0`) for `Or`/`Xor` and all ones for `And`. The all-ones
    /// pattern reads as a NaN when interpreted as a float.
    fn identity(op: ReductionOp) -> Self {
        match op {
            ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0.0,
            ReductionOp::Min => f64::INFINITY,
            ReductionOp::Max => f64::NEG_INFINITY,
            ReductionOp::Product => 1.0,
            // `1.0` has most bits cleared and is not neutral for a bitwise
            // AND; only the all-ones pattern leaves the other operand intact.
            ReductionOp::And => f64::from_bits(u64::MAX),
        }
    }

    /// Combines `a` and `b` according to `op`.
    ///
    /// Arithmetic variants use the usual float operations. The bitwise
    /// variants are applied to the operands' bit patterns and the result is
    /// reinterpreted as an `f64`.
    fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
        match op {
            ReductionOp::Sum => a + b,
            ReductionOp::Min => a.min(b),
            ReductionOp::Max => a.max(b),
            ReductionOp::Product => a * b,
            ReductionOp::And => f64::from_bits(a.to_bits() & b.to_bits()),
            ReductionOp::Or => f64::from_bits(a.to_bits() | b.to_bits()),
            ReductionOp::Xor => f64::from_bits(a.to_bits() ^ b.to_bits()),
        }
    }

    /// CUDA type name: `"double"`.
    fn cuda_type() -> &'static str {
        "double"
    }

    /// WGSL type name: `"f32"`, which is narrower than `f64`.
    ///
    /// NOTE(review): presumably chosen because WGSL has no 64-bit float
    /// type; confirm that callers account for the loss of precision.
    fn wgsl_type() -> &'static str {
        "f32"
    }
}
impl ReductionScalar for i32 {
    /// Neutral element of `op` for `i32`: zero for `Sum`/`Or`/`Xor`, one for
    /// `Product`, all bits set (`-1`) for `And`, and the extreme opposite
    /// bound for `Min`/`Max`.
    fn identity(op: ReductionOp) -> Self {
        match op {
            ReductionOp::Min => Self::MAX,
            ReductionOp::Max => Self::MIN,
            // Two's-complement all-ones, i.e. -1.
            ReductionOp::And => !0,
            ReductionOp::Product => 1,
            ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0,
        }
    }

    /// Combines `a` and `b` according to `op`; `Sum` and `Product` wrap on
    /// overflow instead of panicking.
    fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
        match op {
            ReductionOp::And => a & b,
            ReductionOp::Or => a | b,
            ReductionOp::Xor => a ^ b,
            ReductionOp::Min => Ord::min(a, b),
            ReductionOp::Max => Ord::max(a, b),
            ReductionOp::Sum => Self::wrapping_add(a, b),
            ReductionOp::Product => Self::wrapping_mul(a, b),
        }
    }

    /// CUDA type name: `"int"`.
    fn cuda_type() -> &'static str {
        "int"
    }

    /// WGSL type name: `"i32"`.
    fn wgsl_type() -> &'static str {
        "i32"
    }
}
impl ReductionScalar for i64 {
    /// Neutral element of `op` for `i64`: zero for `Sum`/`Or`/`Xor`, one for
    /// `Product`, all bits set (`-1`) for `And`, and the extreme opposite
    /// bound for `Min`/`Max`.
    fn identity(op: ReductionOp) -> Self {
        match op {
            ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0,
            ReductionOp::Product => 1,
            // Two's-complement all-ones, i.e. -1.
            ReductionOp::And => !0,
            ReductionOp::Max => Self::MIN,
            ReductionOp::Min => Self::MAX,
        }
    }

    /// Combines `a` and `b` according to `op`; `Sum` and `Product` wrap on
    /// overflow instead of panicking.
    fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
        match op {
            ReductionOp::Sum => Self::wrapping_add(a, b),
            ReductionOp::Product => Self::wrapping_mul(a, b),
            ReductionOp::Min => std::cmp::min(a, b),
            ReductionOp::Max => std::cmp::max(a, b),
            ReductionOp::Xor => a ^ b,
            ReductionOp::Or => a | b,
            ReductionOp::And => a & b,
        }
    }

    /// CUDA type name: `"long long"`.
    fn cuda_type() -> &'static str {
        "long long"
    }

    /// WGSL type name: `"i32"`, which is narrower than `i64`.
    ///
    /// NOTE(review): presumably chosen because WGSL has no 64-bit integer
    /// type; confirm that callers account for the reduced range.
    fn wgsl_type() -> &'static str {
        "i32"
    }
}
impl ReductionScalar for u32 {
    /// Neutral element of `op` for `u32`: zero for `Sum`/`Or`/`Xor`/`Max`,
    /// one for `Product`, and the all-ones maximum for `Min`/`And`.
    fn identity(op: ReductionOp) -> Self {
        match op {
            ReductionOp::Product => 1,
            ReductionOp::And | ReductionOp::Min => Self::MAX,
            // Zero is both the additive/bitwise neutral value and the
            // smallest unsigned value, so it also serves `Max`.
            ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor | ReductionOp::Max => 0,
        }
    }

    /// Combines `a` and `b` according to `op`; `Sum` and `Product` wrap on
    /// overflow instead of panicking.
    fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
        match op {
            ReductionOp::Min => std::cmp::min(a, b),
            ReductionOp::Max => std::cmp::max(a, b),
            ReductionOp::And => a & b,
            ReductionOp::Or => a | b,
            ReductionOp::Xor => a ^ b,
            ReductionOp::Product => Self::wrapping_mul(a, b),
            ReductionOp::Sum => Self::wrapping_add(a, b),
        }
    }

    /// CUDA type name: `"unsigned int"`.
    fn cuda_type() -> &'static str {
        "unsigned int"
    }

    /// WGSL type name: `"u32"`.
    fn wgsl_type() -> &'static str {
        "u32"
    }
}
impl ReductionScalar for u64 {
    /// Neutral element of `op` for `u64`: zero for `Sum`/`Or`/`Xor`/`Max`,
    /// one for `Product`, and the all-ones maximum for `Min`/`And`.
    fn identity(op: ReductionOp) -> Self {
        match op {
            ReductionOp::Min | ReductionOp::And => !0,
            ReductionOp::Product => 1,
            ReductionOp::Max | ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0,
        }
    }

    /// Combines `a` and `b` according to `op`; `Sum` and `Product` wrap on
    /// overflow instead of panicking.
    fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
        match op {
            ReductionOp::Xor => a ^ b,
            ReductionOp::Or => a | b,
            ReductionOp::And => a & b,
            ReductionOp::Max => Ord::max(a, b),
            ReductionOp::Min => Ord::min(a, b),
            ReductionOp::Sum => u64::wrapping_add(a, b),
            ReductionOp::Product => u64::wrapping_mul(a, b),
        }
    }

    /// CUDA type name: `"unsigned long long"`.
    fn cuda_type() -> &'static str {
        "unsigned long long"
    }

    /// WGSL type name: `"u32"`, which is narrower than `u64`.
    ///
    /// NOTE(review): presumably chosen because WGSL has no 64-bit integer
    /// type; confirm that callers account for the reduced range.
    fn wgsl_type() -> &'static str {
        "u32"
    }
}
/// Options supplied when a reduction buffer is created through
/// `GlobalReduction::create_reduction_buffer`.
#[derive(Clone, Debug)]
pub struct ReductionConfig {
    /// Number of slots. The `with_slots` builder never stores less than 1.
    /// NOTE(review): presumably the count of independent accumulation slots
    /// in the buffer; confirm against a backend.
    pub num_slots: usize,
    /// Whether the cooperative path is requested.
    /// NOTE(review): exact semantics are backend-defined; confirm.
    pub use_cooperative: bool,
    /// Whether a software barrier may be used.
    /// NOTE(review): presumably a fallback when cooperative support is
    /// unavailable; confirm.
    pub use_software_barrier: bool,
    /// Shared-memory size in bytes.
    pub shared_mem_bytes: usize,
}
impl Default for ReductionConfig {
fn default() -> Self {
Self {
num_slots: 1,
use_cooperative: true,
use_software_barrier: true,
shared_mem_bytes: 0,
}
}
}
impl ReductionConfig {
    /// Creates a configuration holding the `Default` values.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Sets the slot count, raising a request of `0` to `1`.
    #[must_use]
    pub fn with_slots(self, num_slots: usize) -> Self {
        Self {
            num_slots: std::cmp::max(num_slots, 1),
            ..self
        }
    }

    /// Sets the `use_cooperative` flag.
    #[must_use]
    pub fn with_cooperative(self, enabled: bool) -> Self {
        Self {
            use_cooperative: enabled,
            ..self
        }
    }

    /// Sets the `use_software_barrier` flag.
    #[must_use]
    pub fn with_software_barrier(self, enabled: bool) -> Self {
        Self {
            use_software_barrier: enabled,
            ..self
        }
    }

    /// Sets the shared-memory size in bytes.
    #[must_use]
    pub fn with_shared_mem(self, bytes: usize) -> Self {
        Self {
            shared_mem_bytes: bytes,
            ..self
        }
    }
}
/// Handle to a reduction buffer holding values of scalar type `T`.
///
/// Instances are produced as boxed trait objects by
/// `GlobalReduction::create_reduction_buffer`.
pub trait ReductionHandle<T>: Send + Sync
where
    T: ReductionScalar,
{
    /// Returns the buffer's device pointer as a raw 64-bit value.
    fn device_ptr(&self) -> u64;

    /// Resets the buffer.
    /// NOTE(review): presumably back to the identity of `op()`; confirm
    /// against an implementation.
    fn reset(&self) -> crate::error::Result<()>;

    /// Reads a value from the buffer.
    /// NOTE(review): how this differs from `read_combined` when there are
    /// several slots is not visible here; confirm.
    fn read(&self) -> crate::error::Result<T>;

    /// Reads the combined value.
    /// NOTE(review): presumably folds all `num_slots()` slots with `op()`;
    /// confirm.
    fn read_combined(&self) -> crate::error::Result<T>;

    /// Synchronizes, then reads the value.
    /// NOTE(review): the scope of the synchronization is implementation
    /// defined; confirm.
    fn sync_and_read(&self) -> crate::error::Result<T>;

    /// Returns the reduction operation this handle was created for.
    fn op(&self) -> ReductionOp;

    /// Returns the number of slots in the buffer.
    fn num_slots(&self) -> usize;
}
pub trait GlobalReduction: Send + Sync {
fn create_reduction_buffer<T: ReductionScalar>(
&self,
op: ReductionOp,
config: &ReductionConfig,
) -> crate::error::Result<Box<dyn ReductionHandle<T>>>;
fn supports_cooperative(&self) -> bool;
fn supports_grid_reduction(&self) -> bool;
fn cooperative_compute_capability(&self) -> Option<(u32, u32)> {
Some((6, 0)) }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` yields the lowercase operation name.
    #[test]
    fn test_reduction_op_display() {
        assert_eq!(ReductionOp::Sum.to_string(), "sum");
        assert_eq!(ReductionOp::Min.to_string(), "min");
        assert_eq!(ReductionOp::Max.to_string(), "max");
    }

    /// Neutral elements of the arithmetic operations for `f32`.
    #[test]
    fn test_f32_identity() {
        let identity = |op| <f32 as ReductionScalar>::identity(op);
        assert_eq!(identity(ReductionOp::Sum), 0.0);
        assert_eq!(identity(ReductionOp::Min), f32::INFINITY);
        assert_eq!(identity(ReductionOp::Max), f32::NEG_INFINITY);
        assert_eq!(identity(ReductionOp::Product), 1.0);
    }

    /// Arithmetic combination of two `f32` values.
    #[test]
    fn test_f32_combine() {
        let fold = |op| <f32 as ReductionScalar>::combine(2.0, 3.0, op);
        assert_eq!(fold(ReductionOp::Sum), 5.0);
        assert_eq!(fold(ReductionOp::Min), 2.0);
        assert_eq!(fold(ReductionOp::Max), 3.0);
        assert_eq!(fold(ReductionOp::Product), 6.0);
    }

    /// Neutral elements for `i32`, including the bitwise operations.
    #[test]
    fn test_i32_identity() {
        let identity = |op| <i32 as ReductionScalar>::identity(op);
        assert_eq!(identity(ReductionOp::Sum), 0);
        assert_eq!(identity(ReductionOp::Min), i32::MAX);
        assert_eq!(identity(ReductionOp::Max), i32::MIN);
        assert_eq!(identity(ReductionOp::And), -1);
        assert_eq!(identity(ReductionOp::Or), 0);
    }

    /// Arithmetic and bitwise combination of two `u32` values.
    #[test]
    fn test_u32_combine() {
        let arith = |op| <u32 as ReductionScalar>::combine(5, 3, op);
        assert_eq!(arith(ReductionOp::Sum), 8);
        assert_eq!(arith(ReductionOp::Min), 3);
        assert_eq!(arith(ReductionOp::Max), 5);

        let bits = |op| <u32 as ReductionScalar>::combine(0b1100, 0b1010, op);
        assert_eq!(bits(ReductionOp::And), 0b1000);
        assert_eq!(bits(ReductionOp::Or), 0b1110);
        assert_eq!(bits(ReductionOp::Xor), 0b0110);
    }

    /// Builder methods store the values they are given.
    #[test]
    fn test_reduction_config_builder() {
        let ReductionConfig {
            num_slots,
            use_cooperative,
            shared_mem_bytes,
            ..
        } = ReductionConfig::new()
            .with_slots(4)
            .with_cooperative(false)
            .with_shared_mem(4096);
        assert_eq!(num_slots, 4);
        assert!(!use_cooperative);
        assert_eq!(shared_mem_bytes, 4096);
    }

    /// CUDA type names for every supported scalar.
    #[test]
    fn test_cuda_type_names() {
        assert_eq!(<f32 as ReductionScalar>::cuda_type(), "float");
        assert_eq!(<f64 as ReductionScalar>::cuda_type(), "double");
        assert_eq!(<i32 as ReductionScalar>::cuda_type(), "int");
        assert_eq!(<i64 as ReductionScalar>::cuda_type(), "long long");
        assert_eq!(<u32 as ReductionScalar>::cuda_type(), "unsigned int");
        assert_eq!(<u64 as ReductionScalar>::cuda_type(), "unsigned long long");
    }

    /// Atomic function names for the arithmetic operations.
    #[test]
    fn test_atomic_names() {
        let expected = [
            (ReductionOp::Sum, "atomicAdd"),
            (ReductionOp::Min, "atomicMin"),
            (ReductionOp::Max, "atomicMax"),
        ];
        for (op, name) in expected.iter() {
            assert_eq!(op.atomic_name(), *name);
        }
    }
}