Skip to main content

ringkernel_core/
reduction.rs

1//! Global Reduction Primitives
2//!
3//! This module provides traits and types for GPU-accelerated reduction operations.
4//! Reductions aggregate values across all GPU threads using operations like sum,
5//! min, max, etc.
6//!
7//! # Use Cases
8//!
9//! - **PageRank**: Sum dangling node contributions across all nodes
10//! - **Graph algorithms**: Compute convergence metrics, global norms
11//! - **Scientific computing**: Vector norms, dot products, energy calculations
12//!
13//! # Architecture
14//!
15//! Reductions use a hierarchical approach for efficiency:
16//! 1. **Warp-level**: Use shuffle instructions for fast intra-warp reduction
17//! 2. **Block-level**: Tree reduction in shared memory with `__syncthreads()`
18//! 3. **Grid-level**: Atomic accumulation from block leaders, then broadcast
19//!
20//! # Example
21//!
22//! ```ignore
23//! use ringkernel_core::reduction::{ReductionOp, GlobalReduction};
24//!
25//! // In kernel code (DSL):
26//! let my_contrib = if out_degree[idx] == 0 { rank } else { 0.0 };
27//! let dangling_sum = reduce_and_broadcast(my_contrib, &accumulator);
28//! let new_rank = base + damping * (incoming + dangling_sum / n);
29//! ```
30
31use std::fmt::Debug;
32
/// The binary operators a global reduction can fold values with.
///
/// Each operator has a neutral (identity) element `e` such that `op(a, e) == a`:
///
/// | Operator  | Identity                                     |
/// |-----------|----------------------------------------------|
/// | `Sum`     | `0`                                          |
/// | `Min`     | the type's maximum                           |
/// | `Max`     | the type's minimum                           |
/// | `And`     | all bits set (`-1` signed, `MAX` unsigned)   |
/// | `Or`      | `0`                                          |
/// | `Xor`     | `0`                                          |
/// | `Product` | `1`                                          |
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ReductionOp {
    /// Running total of all values.
    Sum,
    /// Smallest value seen.
    Min,
    /// Largest value seen.
    Max,
    /// Bitwise AND of all values.
    And,
    /// Bitwise OR of all values.
    Or,
    /// Bitwise XOR of all values.
    Xor,
    /// Running product of all values.
    Product,
}
60
61impl ReductionOp {
62    /// Get the CUDA atomic function name for this operation.
63    #[must_use]
64    pub fn atomic_name(&self) -> &'static str {
65        match self {
66            ReductionOp::Sum => "atomicAdd",
67            ReductionOp::Min => "atomicMin",
68            ReductionOp::Max => "atomicMax",
69            ReductionOp::And => "atomicAnd",
70            ReductionOp::Or => "atomicOr",
71            ReductionOp::Xor => "atomicXor",
72            ReductionOp::Product => "atomicMul", // Requires custom implementation
73        }
74    }
75
76    /// Get the WGSL atomic function name for this operation.
77    #[must_use]
78    pub fn wgsl_atomic_name(&self) -> Option<&'static str> {
79        match self {
80            ReductionOp::Sum => Some("atomicAdd"),
81            ReductionOp::Min => Some("atomicMin"),
82            ReductionOp::Max => Some("atomicMax"),
83            ReductionOp::And => Some("atomicAnd"),
84            ReductionOp::Or => Some("atomicOr"),
85            ReductionOp::Xor => Some("atomicXor"),
86            ReductionOp::Product => None, // Not supported in WGSL
87        }
88    }
89
90    /// Get the C operator for this reduction (for code generation).
91    #[must_use]
92    pub fn c_operator(&self) -> &'static str {
93        match self {
94            ReductionOp::Sum => "+",
95            ReductionOp::Min => "min",
96            ReductionOp::Max => "max",
97            ReductionOp::And => "&",
98            ReductionOp::Or => "|",
99            ReductionOp::Xor => "^",
100            ReductionOp::Product => "*",
101        }
102    }
103
104    /// Check if this operation is commutative.
105    #[must_use]
106    pub const fn is_commutative(&self) -> bool {
107        true // All supported operations are commutative
108    }
109
110    /// Check if this operation is associative.
111    #[must_use]
112    pub const fn is_associative(&self) -> bool {
113        // Note: floating-point sum/product are not strictly associative
114        // due to rounding, but we treat them as such for parallel reduction
115        true
116    }
117}
118
119impl std::fmt::Display for ReductionOp {
120    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121        match self {
122            ReductionOp::Sum => write!(f, "sum"),
123            ReductionOp::Min => write!(f, "min"),
124            ReductionOp::Max => write!(f, "max"),
125            ReductionOp::And => write!(f, "and"),
126            ReductionOp::Or => write!(f, "or"),
127            ReductionOp::Xor => write!(f, "xor"),
128            ReductionOp::Product => write!(f, "product"),
129        }
130    }
131}
132
133/// Trait for scalar types that support reduction operations.
134///
135/// Implementors must provide identity values for each reduction operation.
136/// The identity value is the neutral element such that `op(x, identity) = x`.
137pub trait ReductionScalar: Copy + Send + Sync + Debug + Default + 'static {
138    /// Get the identity value for the given reduction operation.
139    fn identity(op: ReductionOp) -> Self;
140
141    /// Combine two values according to the reduction operation.
142    fn combine(a: Self, b: Self, op: ReductionOp) -> Self;
143
144    /// Size in bytes.
145    fn size_bytes() -> usize {
146        std::mem::size_of::<Self>()
147    }
148
149    /// CUDA type name for code generation.
150    fn cuda_type() -> &'static str;
151
152    /// WGSL type name for code generation.
153    fn wgsl_type() -> &'static str;
154}
155
156impl ReductionScalar for f32 {
157    fn identity(op: ReductionOp) -> Self {
158        match op {
159            ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0.0,
160            ReductionOp::Min => f32::INFINITY,
161            ReductionOp::Max => f32::NEG_INFINITY,
162            ReductionOp::Product | ReductionOp::And => 1.0,
163        }
164    }
165
166    fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
167        match op {
168            ReductionOp::Sum => a + b,
169            ReductionOp::Min => a.min(b),
170            ReductionOp::Max => a.max(b),
171            ReductionOp::Product => a * b,
172            // Bitwise ops on floats: use bit representation
173            ReductionOp::And => f32::from_bits(a.to_bits() & b.to_bits()),
174            ReductionOp::Or => f32::from_bits(a.to_bits() | b.to_bits()),
175            ReductionOp::Xor => f32::from_bits(a.to_bits() ^ b.to_bits()),
176        }
177    }
178
179    fn cuda_type() -> &'static str {
180        "float"
181    }
182
183    fn wgsl_type() -> &'static str {
184        "f32"
185    }
186}
187
188impl ReductionScalar for f64 {
189    fn identity(op: ReductionOp) -> Self {
190        match op {
191            ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0.0,
192            ReductionOp::Min => f64::INFINITY,
193            ReductionOp::Max => f64::NEG_INFINITY,
194            ReductionOp::Product | ReductionOp::And => 1.0,
195        }
196    }
197
198    fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
199        match op {
200            ReductionOp::Sum => a + b,
201            ReductionOp::Min => a.min(b),
202            ReductionOp::Max => a.max(b),
203            ReductionOp::Product => a * b,
204            ReductionOp::And => f64::from_bits(a.to_bits() & b.to_bits()),
205            ReductionOp::Or => f64::from_bits(a.to_bits() | b.to_bits()),
206            ReductionOp::Xor => f64::from_bits(a.to_bits() ^ b.to_bits()),
207        }
208    }
209
210    fn cuda_type() -> &'static str {
211        "double"
212    }
213
214    fn wgsl_type() -> &'static str {
215        "f32" // WGSL doesn't have f64, fallback to f32
216    }
217}
218
219impl ReductionScalar for i32 {
220    fn identity(op: ReductionOp) -> Self {
221        match op {
222            ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0,
223            ReductionOp::Min => i32::MAX,
224            ReductionOp::Max => i32::MIN,
225            ReductionOp::Product => 1,
226            ReductionOp::And => -1, // All bits set
227        }
228    }
229
230    fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
231        match op {
232            ReductionOp::Sum => a.wrapping_add(b),
233            ReductionOp::Min => a.min(b),
234            ReductionOp::Max => a.max(b),
235            ReductionOp::Product => a.wrapping_mul(b),
236            ReductionOp::And => a & b,
237            ReductionOp::Or => a | b,
238            ReductionOp::Xor => a ^ b,
239        }
240    }
241
242    fn cuda_type() -> &'static str {
243        "int"
244    }
245
246    fn wgsl_type() -> &'static str {
247        "i32"
248    }
249}
250
251impl ReductionScalar for i64 {
252    fn identity(op: ReductionOp) -> Self {
253        match op {
254            ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0,
255            ReductionOp::Min => i64::MAX,
256            ReductionOp::Max => i64::MIN,
257            ReductionOp::Product => 1,
258            ReductionOp::And => -1,
259        }
260    }
261
262    fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
263        match op {
264            ReductionOp::Sum => a.wrapping_add(b),
265            ReductionOp::Min => a.min(b),
266            ReductionOp::Max => a.max(b),
267            ReductionOp::Product => a.wrapping_mul(b),
268            ReductionOp::And => a & b,
269            ReductionOp::Or => a | b,
270            ReductionOp::Xor => a ^ b,
271        }
272    }
273
274    fn cuda_type() -> &'static str {
275        "long long"
276    }
277
278    fn wgsl_type() -> &'static str {
279        "i32" // WGSL doesn't have i64
280    }
281}
282
283impl ReductionScalar for u32 {
284    fn identity(op: ReductionOp) -> Self {
285        match op {
286            ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0,
287            ReductionOp::Min | ReductionOp::And => u32::MAX,
288            ReductionOp::Max => 0,
289            ReductionOp::Product => 1,
290        }
291    }
292
293    fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
294        match op {
295            ReductionOp::Sum => a.wrapping_add(b),
296            ReductionOp::Min => a.min(b),
297            ReductionOp::Max => a.max(b),
298            ReductionOp::Product => a.wrapping_mul(b),
299            ReductionOp::And => a & b,
300            ReductionOp::Or => a | b,
301            ReductionOp::Xor => a ^ b,
302        }
303    }
304
305    fn cuda_type() -> &'static str {
306        "unsigned int"
307    }
308
309    fn wgsl_type() -> &'static str {
310        "u32"
311    }
312}
313
314impl ReductionScalar for u64 {
315    fn identity(op: ReductionOp) -> Self {
316        match op {
317            ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0,
318            ReductionOp::Min | ReductionOp::And => u64::MAX,
319            ReductionOp::Max => 0,
320            ReductionOp::Product => 1,
321        }
322    }
323
324    fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
325        match op {
326            ReductionOp::Sum => a.wrapping_add(b),
327            ReductionOp::Min => a.min(b),
328            ReductionOp::Max => a.max(b),
329            ReductionOp::Product => a.wrapping_mul(b),
330            ReductionOp::And => a & b,
331            ReductionOp::Or => a | b,
332            ReductionOp::Xor => a ^ b,
333        }
334    }
335
336    fn cuda_type() -> &'static str {
337        "unsigned long long"
338    }
339
340    fn wgsl_type() -> &'static str {
341        "u32" // WGSL doesn't have u64
342    }
343}
344
/// Tuning options for a global reduction.
///
/// Start from [`ReductionConfig::new`] (equivalently `Default`) and adjust
/// with the `with_*` builder methods.
#[derive(Debug, Clone)]
pub struct ReductionConfig {
    /// How many accumulation slots atomic updates are spread across.
    ///
    /// More slots mean less contention on any single memory location. The
    /// final value is obtained by combining every slot on the host.
    pub num_slots: usize,

    /// Whether to use cooperative groups for grid-wide synchronization.
    ///
    /// Needs compute capability 6.0 or newer (Pascal+). When disabled, the
    /// reduction falls back to software barriers or multiple kernel launches.
    pub use_cooperative: bool,

    /// Whether to use a software barrier when cooperative groups are
    /// unavailable.
    ///
    /// The barrier is built from atomic counters in global memory. It works
    /// on every device but has higher latency.
    pub use_software_barrier: bool,

    /// Per-block shared-memory budget for the reduction, in bytes.
    ///
    /// A full block reduction needs at least `block_size * sizeof(T)`.
    /// `0` (the default) means "derive it from the block size".
    pub shared_mem_bytes: usize,
}

impl Default for ReductionConfig {
    /// One slot, cooperative groups and the software-barrier fallback both
    /// enabled, and automatic shared-memory sizing.
    fn default() -> Self {
        ReductionConfig {
            num_slots: 1,
            use_cooperative: true,
            use_software_barrier: true,
            shared_mem_bytes: 0,
        }
    }
}

impl ReductionConfig {
    /// Equivalent to [`ReductionConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Use `num_slots` accumulation slots; a request for `0` is raised to `1`.
    #[must_use]
    pub fn with_slots(self, num_slots: usize) -> Self {
        Self {
            num_slots: std::cmp::max(num_slots, 1),
            ..self
        }
    }

    /// Turn cooperative-group synchronization on or off.
    #[must_use]
    pub fn with_cooperative(self, enabled: bool) -> Self {
        Self {
            use_cooperative: enabled,
            ..self
        }
    }

    /// Turn the software-barrier fallback on or off.
    #[must_use]
    pub fn with_software_barrier(self, enabled: bool) -> Self {
        Self {
            use_software_barrier: enabled,
            ..self
        }
    }

    /// Set an explicit shared-memory size in bytes (`0` = automatic).
    #[must_use]
    pub fn with_shared_mem(self, bytes: usize) -> Self {
        Self {
            shared_mem_bytes: bytes,
            ..self
        }
    }
}
420
421/// Handle to a reduction buffer for streaming operations.
422///
423/// This trait abstracts over backend-specific reduction buffer implementations,
424/// allowing the same code to work with CUDA, WebGPU, or CPU backends.
425pub trait ReductionHandle<T: ReductionScalar>: Send + Sync {
426    /// Get device pointer for kernel parameter passing.
427    fn device_ptr(&self) -> u64;
428
429    /// Reset buffer to identity value.
430    fn reset(&self) -> crate::error::Result<()>;
431
432    /// Read the current reduction result from slot 0.
433    fn read(&self) -> crate::error::Result<T>;
434
435    /// Read and combine all slots into a single result.
436    fn read_combined(&self) -> crate::error::Result<T>;
437
438    /// Synchronize device and read result.
439    ///
440    /// Ensures all GPU operations complete before reading.
441    fn sync_and_read(&self) -> crate::error::Result<T>;
442
443    /// Get the reduction operation type.
444    fn op(&self) -> ReductionOp;
445
446    /// Get number of slots.
447    fn num_slots(&self) -> usize;
448}
449
450/// Trait for GPU runtimes that support global reduction operations.
451///
452/// Implemented by backend-specific runtimes (CUDA, WebGPU, etc.) to provide
453/// efficient reduction primitives.
454pub trait GlobalReduction: Send + Sync {
455    /// Create a reduction buffer for the specified type and operation.
456    fn create_reduction_buffer<T: ReductionScalar>(
457        &self,
458        op: ReductionOp,
459        config: &ReductionConfig,
460    ) -> crate::error::Result<Box<dyn ReductionHandle<T>>>;
461
462    /// Check if cooperative groups are supported.
463    fn supports_cooperative(&self) -> bool;
464
465    /// Check if grid-wide reduction is available.
466    fn supports_grid_reduction(&self) -> bool;
467
468    /// Get minimum compute capability for cooperative groups.
469    ///
470    /// Returns (major, minor) version tuple, or None if not applicable.
471    fn cooperative_compute_capability(&self) -> Option<(u32, u32)> {
472        Some((6, 0)) // Default: Pascal or newer
473    }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ReductionOp` variant, for table-driven checks.
    const ALL_OPS: [ReductionOp; 7] = [
        ReductionOp::Sum,
        ReductionOp::Min,
        ReductionOp::Max,
        ReductionOp::And,
        ReductionOp::Or,
        ReductionOp::Xor,
        ReductionOp::Product,
    ];

    /// Asserts that `T::identity(op)` is neutral on both sides of `T::combine`
    /// for every value in `samples` and every operation in `ops`.
    fn assert_identity_law<T: ReductionScalar + PartialEq>(samples: &[T], ops: &[ReductionOp]) {
        for &op in ops {
            let id = T::identity(op);
            for &x in samples {
                assert_eq!(
                    T::combine(x, id, op),
                    x,
                    "{} identity is not right-neutral for {:?}",
                    op,
                    x
                );
                assert_eq!(
                    T::combine(id, x, op),
                    x,
                    "{} identity is not left-neutral for {:?}",
                    op,
                    x
                );
            }
        }
    }

    #[test]
    fn test_reduction_op_display() {
        let names: Vec<String> = ALL_OPS.iter().map(|op| op.to_string()).collect();
        assert_eq!(names, ["sum", "min", "max", "and", "or", "xor", "product"]);
    }

    #[test]
    fn test_f32_identity() {
        assert_eq!(f32::identity(ReductionOp::Sum), 0.0);
        assert_eq!(f32::identity(ReductionOp::Min), f32::INFINITY);
        assert_eq!(f32::identity(ReductionOp::Max), f32::NEG_INFINITY);
        assert_eq!(f32::identity(ReductionOp::Product), 1.0);
    }

    #[test]
    fn test_f32_combine() {
        assert_eq!(f32::combine(2.0, 3.0, ReductionOp::Sum), 5.0);
        assert_eq!(f32::combine(2.0, 3.0, ReductionOp::Min), 2.0);
        assert_eq!(f32::combine(2.0, 3.0, ReductionOp::Max), 3.0);
        assert_eq!(f32::combine(2.0, 3.0, ReductionOp::Product), 6.0);
    }

    #[test]
    fn test_i32_identity() {
        assert_eq!(i32::identity(ReductionOp::Sum), 0);
        assert_eq!(i32::identity(ReductionOp::Min), i32::MAX);
        assert_eq!(i32::identity(ReductionOp::Max), i32::MIN);
        assert_eq!(i32::identity(ReductionOp::And), -1);
        assert_eq!(i32::identity(ReductionOp::Or), 0);
    }

    #[test]
    fn test_u32_combine() {
        assert_eq!(u32::combine(5, 3, ReductionOp::Sum), 8);
        assert_eq!(u32::combine(5, 3, ReductionOp::Min), 3);
        assert_eq!(u32::combine(5, 3, ReductionOp::Max), 5);
        assert_eq!(u32::combine(0b1100, 0b1010, ReductionOp::And), 0b1000);
        assert_eq!(u32::combine(0b1100, 0b1010, ReductionOp::Or), 0b1110);
        assert_eq!(u32::combine(0b1100, 0b1010, ReductionOp::Xor), 0b0110);
    }

    #[test]
    fn test_integer_identity_laws() {
        assert_identity_law(&[0i32, 1, -1, 42, i32::MIN, i32::MAX], &ALL_OPS);
        assert_identity_law(&[0i64, 1, -1, 42, i64::MIN, i64::MAX], &ALL_OPS);
        assert_identity_law(&[0u32, 1, 42, u32::MAX], &ALL_OPS);
        assert_identity_law(&[0u64, 1, 42, u64::MAX], &ALL_OPS);
    }

    #[test]
    fn test_float_arithmetic_identity_laws() {
        let arithmetic = [
            ReductionOp::Sum,
            ReductionOp::Min,
            ReductionOp::Max,
            ReductionOp::Product,
        ];
        assert_identity_law(&[0.0f32, 1.5, -2.25, 1.0e30], &arithmetic);
        assert_identity_law(&[0.0f64, 1.5, -2.25, 1.0e300], &arithmetic);
    }

    #[test]
    fn test_reduction_config_builder() {
        let config = ReductionConfig::new()
            .with_slots(4)
            .with_cooperative(false)
            .with_shared_mem(4096);

        assert_eq!(config.num_slots, 4);
        assert!(!config.use_cooperative);
        assert_eq!(config.shared_mem_bytes, 4096);
    }

    #[test]
    fn test_reduction_config_defaults_and_slot_clamp() {
        let config = ReductionConfig::default();
        assert_eq!(config.num_slots, 1);
        assert!(config.use_cooperative);
        assert!(config.use_software_barrier);
        assert_eq!(config.shared_mem_bytes, 0);

        assert_eq!(ReductionConfig::new().with_slots(0).num_slots, 1);
        assert!(!ReductionConfig::new().with_software_barrier(false).use_software_barrier);
    }

    #[test]
    fn test_cuda_type_names() {
        assert_eq!(f32::cuda_type(), "float");
        assert_eq!(f64::cuda_type(), "double");
        assert_eq!(i32::cuda_type(), "int");
        assert_eq!(i64::cuda_type(), "long long");
        assert_eq!(u32::cuda_type(), "unsigned int");
        assert_eq!(u64::cuda_type(), "unsigned long long");
    }

    #[test]
    fn test_wgsl_type_names() {
        assert_eq!(f32::wgsl_type(), "f32");
        assert_eq!(f64::wgsl_type(), "f32");
        assert_eq!(i32::wgsl_type(), "i32");
        assert_eq!(i64::wgsl_type(), "i32");
        assert_eq!(u32::wgsl_type(), "u32");
        assert_eq!(u64::wgsl_type(), "u32");
    }

    #[test]
    fn test_size_bytes() {
        assert_eq!(f32::size_bytes(), 4);
        assert_eq!(f64::size_bytes(), 8);
        assert_eq!(i64::size_bytes(), 8);
        assert_eq!(u32::size_bytes(), 4);
    }

    #[test]
    fn test_atomic_names() {
        assert_eq!(ReductionOp::Sum.atomic_name(), "atomicAdd");
        assert_eq!(ReductionOp::Min.atomic_name(), "atomicMin");
        assert_eq!(ReductionOp::Max.atomic_name(), "atomicMax");
        assert_eq!(ReductionOp::And.atomic_name(), "atomicAnd");
        assert_eq!(ReductionOp::Or.atomic_name(), "atomicOr");
        assert_eq!(ReductionOp::Xor.atomic_name(), "atomicXor");
        assert_eq!(ReductionOp::Product.atomic_name(), "atomicMul");
    }

    #[test]
    fn test_wgsl_atomic_names() {
        assert_eq!(ReductionOp::Sum.wgsl_atomic_name(), Some("atomicAdd"));
        assert_eq!(ReductionOp::Xor.wgsl_atomic_name(), Some("atomicXor"));
        assert_eq!(ReductionOp::Product.wgsl_atomic_name(), None);
    }
}