1use std::fmt::Debug;
32
/// The binary operation applied when folding many values into one.
///
/// The enum is a plain, field-less tag: it is `Copy`, comparable and hashable,
/// so it can be passed by value and used as a map key.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ReductionOp {
    /// Addition of all inputs.
    Sum,
    /// Smallest of all inputs.
    Min,
    /// Largest of all inputs.
    Max,
    /// Bitwise AND of all inputs.
    And,
    /// Bitwise OR of all inputs.
    Or,
    /// Bitwise XOR of all inputs.
    Xor,
    /// Multiplication of all inputs.
    Product,
}
60
61impl ReductionOp {
62 #[must_use]
64 pub fn atomic_name(&self) -> &'static str {
65 match self {
66 ReductionOp::Sum => "atomicAdd",
67 ReductionOp::Min => "atomicMin",
68 ReductionOp::Max => "atomicMax",
69 ReductionOp::And => "atomicAnd",
70 ReductionOp::Or => "atomicOr",
71 ReductionOp::Xor => "atomicXor",
72 ReductionOp::Product => "atomicMul", }
74 }
75
76 #[must_use]
78 pub fn wgsl_atomic_name(&self) -> Option<&'static str> {
79 match self {
80 ReductionOp::Sum => Some("atomicAdd"),
81 ReductionOp::Min => Some("atomicMin"),
82 ReductionOp::Max => Some("atomicMax"),
83 ReductionOp::And => Some("atomicAnd"),
84 ReductionOp::Or => Some("atomicOr"),
85 ReductionOp::Xor => Some("atomicXor"),
86 ReductionOp::Product => None, }
88 }
89
90 #[must_use]
92 pub fn c_operator(&self) -> &'static str {
93 match self {
94 ReductionOp::Sum => "+",
95 ReductionOp::Min => "min",
96 ReductionOp::Max => "max",
97 ReductionOp::And => "&",
98 ReductionOp::Or => "|",
99 ReductionOp::Xor => "^",
100 ReductionOp::Product => "*",
101 }
102 }
103
104 #[must_use]
106 pub const fn is_commutative(&self) -> bool {
107 true }
109
110 #[must_use]
112 pub const fn is_associative(&self) -> bool {
113 true
116 }
117}
118
119impl std::fmt::Display for ReductionOp {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 match self {
122 ReductionOp::Sum => write!(f, "sum"),
123 ReductionOp::Min => write!(f, "min"),
124 ReductionOp::Max => write!(f, "max"),
125 ReductionOp::And => write!(f, "and"),
126 ReductionOp::Or => write!(f, "or"),
127 ReductionOp::Xor => write!(f, "xor"),
128 ReductionOp::Product => write!(f, "product"),
129 }
130 }
131}
132
133pub trait ReductionScalar: Copy + Send + Sync + Debug + Default + 'static {
138 fn identity(op: ReductionOp) -> Self;
140
141 fn combine(a: Self, b: Self, op: ReductionOp) -> Self;
143
144 fn size_bytes() -> usize {
146 std::mem::size_of::<Self>()
147 }
148
149 fn cuda_type() -> &'static str;
151
152 fn wgsl_type() -> &'static str;
154}
155
156impl ReductionScalar for f32 {
157 fn identity(op: ReductionOp) -> Self {
158 match op {
159 ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0.0,
160 ReductionOp::Min => f32::INFINITY,
161 ReductionOp::Max => f32::NEG_INFINITY,
162 ReductionOp::Product | ReductionOp::And => 1.0,
163 }
164 }
165
166 fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
167 match op {
168 ReductionOp::Sum => a + b,
169 ReductionOp::Min => a.min(b),
170 ReductionOp::Max => a.max(b),
171 ReductionOp::Product => a * b,
172 ReductionOp::And => f32::from_bits(a.to_bits() & b.to_bits()),
174 ReductionOp::Or => f32::from_bits(a.to_bits() | b.to_bits()),
175 ReductionOp::Xor => f32::from_bits(a.to_bits() ^ b.to_bits()),
176 }
177 }
178
179 fn cuda_type() -> &'static str {
180 "float"
181 }
182
183 fn wgsl_type() -> &'static str {
184 "f32"
185 }
186}
187
188impl ReductionScalar for f64 {
189 fn identity(op: ReductionOp) -> Self {
190 match op {
191 ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0.0,
192 ReductionOp::Min => f64::INFINITY,
193 ReductionOp::Max => f64::NEG_INFINITY,
194 ReductionOp::Product | ReductionOp::And => 1.0,
195 }
196 }
197
198 fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
199 match op {
200 ReductionOp::Sum => a + b,
201 ReductionOp::Min => a.min(b),
202 ReductionOp::Max => a.max(b),
203 ReductionOp::Product => a * b,
204 ReductionOp::And => f64::from_bits(a.to_bits() & b.to_bits()),
205 ReductionOp::Or => f64::from_bits(a.to_bits() | b.to_bits()),
206 ReductionOp::Xor => f64::from_bits(a.to_bits() ^ b.to_bits()),
207 }
208 }
209
210 fn cuda_type() -> &'static str {
211 "double"
212 }
213
214 fn wgsl_type() -> &'static str {
215 "f32" }
217}
218
219impl ReductionScalar for i32 {
220 fn identity(op: ReductionOp) -> Self {
221 match op {
222 ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0,
223 ReductionOp::Min => i32::MAX,
224 ReductionOp::Max => i32::MIN,
225 ReductionOp::Product => 1,
226 ReductionOp::And => -1, }
228 }
229
230 fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
231 match op {
232 ReductionOp::Sum => a.wrapping_add(b),
233 ReductionOp::Min => a.min(b),
234 ReductionOp::Max => a.max(b),
235 ReductionOp::Product => a.wrapping_mul(b),
236 ReductionOp::And => a & b,
237 ReductionOp::Or => a | b,
238 ReductionOp::Xor => a ^ b,
239 }
240 }
241
242 fn cuda_type() -> &'static str {
243 "int"
244 }
245
246 fn wgsl_type() -> &'static str {
247 "i32"
248 }
249}
250
251impl ReductionScalar for i64 {
252 fn identity(op: ReductionOp) -> Self {
253 match op {
254 ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0,
255 ReductionOp::Min => i64::MAX,
256 ReductionOp::Max => i64::MIN,
257 ReductionOp::Product => 1,
258 ReductionOp::And => -1,
259 }
260 }
261
262 fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
263 match op {
264 ReductionOp::Sum => a.wrapping_add(b),
265 ReductionOp::Min => a.min(b),
266 ReductionOp::Max => a.max(b),
267 ReductionOp::Product => a.wrapping_mul(b),
268 ReductionOp::And => a & b,
269 ReductionOp::Or => a | b,
270 ReductionOp::Xor => a ^ b,
271 }
272 }
273
274 fn cuda_type() -> &'static str {
275 "long long"
276 }
277
278 fn wgsl_type() -> &'static str {
279 "i32" }
281}
282
283impl ReductionScalar for u32 {
284 fn identity(op: ReductionOp) -> Self {
285 match op {
286 ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0,
287 ReductionOp::Min | ReductionOp::And => u32::MAX,
288 ReductionOp::Max => 0,
289 ReductionOp::Product => 1,
290 }
291 }
292
293 fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
294 match op {
295 ReductionOp::Sum => a.wrapping_add(b),
296 ReductionOp::Min => a.min(b),
297 ReductionOp::Max => a.max(b),
298 ReductionOp::Product => a.wrapping_mul(b),
299 ReductionOp::And => a & b,
300 ReductionOp::Or => a | b,
301 ReductionOp::Xor => a ^ b,
302 }
303 }
304
305 fn cuda_type() -> &'static str {
306 "unsigned int"
307 }
308
309 fn wgsl_type() -> &'static str {
310 "u32"
311 }
312}
313
314impl ReductionScalar for u64 {
315 fn identity(op: ReductionOp) -> Self {
316 match op {
317 ReductionOp::Sum | ReductionOp::Or | ReductionOp::Xor => 0,
318 ReductionOp::Min | ReductionOp::And => u64::MAX,
319 ReductionOp::Max => 0,
320 ReductionOp::Product => 1,
321 }
322 }
323
324 fn combine(a: Self, b: Self, op: ReductionOp) -> Self {
325 match op {
326 ReductionOp::Sum => a.wrapping_add(b),
327 ReductionOp::Min => a.min(b),
328 ReductionOp::Max => a.max(b),
329 ReductionOp::Product => a.wrapping_mul(b),
330 ReductionOp::And => a & b,
331 ReductionOp::Or => a | b,
332 ReductionOp::Xor => a ^ b,
333 }
334 }
335
336 fn cuda_type() -> &'static str {
337 "unsigned long long"
338 }
339
340 fn wgsl_type() -> &'static str {
341 "u32" }
343}
344
/// Tunable settings used when creating a reduction buffer.
///
/// All fields are public; the `with_*` builder methods offer a chained way to
/// set them starting from the defaults.
#[derive(Clone, Debug)]
pub struct ReductionConfig {
    /// Number of accumulation slots. The `with_slots` builder clamps this to
    /// at least 1; direct field assignment is not clamped.
    pub num_slots: usize,

    /// Whether to use the cooperative path.
    // NOTE(review): presumably refers to cooperative kernel launches — confirm.
    pub use_cooperative: bool,

    /// Whether to use a software barrier.
    // NOTE(review): presumably a fallback when cooperative launch is
    // unavailable — confirm against the backends.
    pub use_software_barrier: bool,

    /// Amount of shared memory, in bytes.
    pub shared_mem_bytes: usize,
}
373
374impl Default for ReductionConfig {
375 fn default() -> Self {
376 Self {
377 num_slots: 1,
378 use_cooperative: true,
379 use_software_barrier: true,
380 shared_mem_bytes: 0,
381 }
382 }
383}
384
385impl ReductionConfig {
386 #[must_use]
388 pub fn new() -> Self {
389 Self::default()
390 }
391
392 #[must_use]
394 pub fn with_slots(mut self, num_slots: usize) -> Self {
395 self.num_slots = num_slots.max(1);
396 self
397 }
398
399 #[must_use]
401 pub fn with_cooperative(mut self, enabled: bool) -> Self {
402 self.use_cooperative = enabled;
403 self
404 }
405
406 #[must_use]
408 pub fn with_software_barrier(mut self, enabled: bool) -> Self {
409 self.use_software_barrier = enabled;
410 self
411 }
412
413 #[must_use]
415 pub fn with_shared_mem(mut self, bytes: usize) -> Self {
416 self.shared_mem_bytes = bytes;
417 self
418 }
419}
420
421pub trait ReductionHandle<T: ReductionScalar>: Send + Sync {
426 fn device_ptr(&self) -> u64;
428
429 fn reset(&self) -> crate::error::Result<()>;
431
432 fn read(&self) -> crate::error::Result<T>;
434
435 fn read_combined(&self) -> crate::error::Result<T>;
437
438 fn sync_and_read(&self) -> crate::error::Result<T>;
442
443 fn op(&self) -> ReductionOp;
445
446 fn num_slots(&self) -> usize;
448}
449
450pub trait GlobalReduction: Send + Sync {
455 fn create_reduction_buffer<T: ReductionScalar>(
457 &self,
458 op: ReductionOp,
459 config: &ReductionConfig,
460 ) -> crate::error::Result<Box<dyn ReductionHandle<T>>>;
461
462 fn supports_cooperative(&self) -> bool;
464
465 fn supports_grid_reduction(&self) -> bool;
467
468 fn cooperative_compute_capability(&self) -> Option<(u32, u32)> {
472 Some((6, 0)) }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// `Display` renders the lowercase operation name.
    #[test]
    fn test_reduction_op_display() {
        let expected = [
            (ReductionOp::Sum, "sum"),
            (ReductionOp::Min, "min"),
            (ReductionOp::Max, "max"),
        ];
        for &(op, text) in expected.iter() {
            assert_eq!(format!("{}", op), text);
        }
    }

    /// Identity values of the arithmetic operations on `f32`.
    #[test]
    fn test_f32_identity() {
        let expected = [
            (ReductionOp::Sum, 0.0),
            (ReductionOp::Min, f32::INFINITY),
            (ReductionOp::Max, f32::NEG_INFINITY),
            (ReductionOp::Product, 1.0),
        ];
        for &(op, want) in expected.iter() {
            assert_eq!(f32::identity(op), want);
        }
    }

    /// Pairwise combination of two small `f32` values.
    #[test]
    fn test_f32_combine() {
        let (lhs, rhs) = (2.0_f32, 3.0_f32);
        let expected = [
            (ReductionOp::Sum, 5.0),
            (ReductionOp::Min, 2.0),
            (ReductionOp::Max, 3.0),
            (ReductionOp::Product, 6.0),
        ];
        for &(op, want) in expected.iter() {
            assert_eq!(f32::combine(lhs, rhs, op), want);
        }
    }

    /// Identity values on `i32`, including the bitwise operations.
    #[test]
    fn test_i32_identity() {
        let expected = [
            (ReductionOp::Sum, 0),
            (ReductionOp::Min, i32::MAX),
            (ReductionOp::Max, i32::MIN),
            (ReductionOp::And, -1),
            (ReductionOp::Or, 0),
        ];
        for &(op, want) in expected.iter() {
            assert_eq!(i32::identity(op), want);
        }
    }

    /// Arithmetic and bitwise combination on `u32`.
    #[test]
    fn test_u32_combine() {
        let arithmetic = [
            (ReductionOp::Sum, 8),
            (ReductionOp::Min, 3),
            (ReductionOp::Max, 5),
        ];
        for &(op, want) in arithmetic.iter() {
            assert_eq!(u32::combine(5, 3, op), want);
        }

        let bitwise = [
            (ReductionOp::And, 0b1000),
            (ReductionOp::Or, 0b1110),
            (ReductionOp::Xor, 0b0110),
        ];
        for &(op, want) in bitwise.iter() {
            assert_eq!(u32::combine(0b1100, 0b1010, op), want);
        }
    }

    /// Chained builder calls end up in the corresponding fields.
    #[test]
    fn test_reduction_config_builder() {
        let built = ReductionConfig::new()
            .with_slots(4)
            .with_cooperative(false)
            .with_shared_mem(4096);

        assert_eq!(built.num_slots, 4);
        assert!(!built.use_cooperative);
        assert_eq!(built.shared_mem_bytes, 4096);
    }

    /// CUDA type names for every supported scalar.
    #[test]
    fn test_cuda_type_names() {
        let actual = [
            f32::cuda_type(),
            f64::cuda_type(),
            i32::cuda_type(),
            i64::cuda_type(),
            u32::cuda_type(),
            u64::cuda_type(),
        ];
        let expected = [
            "float",
            "double",
            "int",
            "long long",
            "unsigned int",
            "unsigned long long",
        ];
        for (got, want) in actual.iter().zip(expected.iter()) {
            assert_eq!(got, want);
        }
    }

    /// Atomic function names for the arithmetic operations.
    #[test]
    fn test_atomic_names() {
        let expected = [
            (ReductionOp::Sum, "atomicAdd"),
            (ReductionOp::Min, "atomicMin"),
            (ReductionOp::Max, "atomicMax"),
        ];
        for &(op, name) in expected.iter() {
            assert_eq!(op.atomic_name(), name);
        }
    }
}