use std::fmt;

/// Reduction operators supported by the generated CUDA helpers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ReductionOp {
    /// Sum reduction (`+`). The default operator.
    #[default]
    Sum,
    /// Minimum reduction.
    Min,
    /// Maximum reduction.
    Max,
    /// Bitwise AND reduction.
    And,
    /// Bitwise OR reduction.
    Or,
    /// Bitwise XOR reduction.
    Xor,
    /// Product reduction (`*`).
    Product,
}

impl ReductionOp {
    /// Returns the CUDA operator or function name that combines two values.
    pub fn operator(&self) -> &'static str {
        match self {
            ReductionOp::Sum => "+",
            ReductionOp::Min => "min",
            ReductionOp::Max => "max",
            ReductionOp::And => "&",
            ReductionOp::Or => "|",
            ReductionOp::Xor => "^",
            ReductionOp::Product => "*",
        }
    }

    /// Returns the CUDA atomic function used for cross-block accumulation.
    ///
    /// Note: CUDA provides no native `atomicMul`, so `Product` requires a
    /// CAS-loop emulation at the call site; `atomicMin`/`atomicMax` also do
    /// not support floating-point operands.
    pub fn atomic_fn(&self) -> &'static str {
        match self {
            ReductionOp::Sum => "atomicAdd",
            ReductionOp::Min => "atomicMin",
            ReductionOp::Max => "atomicMax",
            ReductionOp::And => "atomicAnd",
            ReductionOp::Or => "atomicOr",
            ReductionOp::Xor => "atomicXor",
            ReductionOp::Product => "atomicMul",
        }
    }

    /// Returns the identity element of this operator as a C literal for `ty`.
    pub fn identity(&self, ty: &str) -> String {
        match (self, ty) {
            (ReductionOp::Sum, _) => "0".to_string(),
            (ReductionOp::Min, "float") => "INFINITY".to_string(),
            (ReductionOp::Min, "double") => "HUGE_VAL".to_string(),
            (ReductionOp::Min, "int") => "INT_MAX".to_string(),
            (ReductionOp::Min, "long long") => "LLONG_MAX".to_string(),
            (ReductionOp::Min, "unsigned int") => "UINT_MAX".to_string(),
            (ReductionOp::Min, "unsigned long long") => "ULLONG_MAX".to_string(),
            (ReductionOp::Max, "float") => "-INFINITY".to_string(),
            (ReductionOp::Max, "double") => "-HUGE_VAL".to_string(),
            (ReductionOp::Max, "int") => "INT_MIN".to_string(),
            (ReductionOp::Max, "long long") => "LLONG_MIN".to_string(),
            // Unsigned Max: 0 is the true identity.
            (ReductionOp::Max, _) => "0".to_string(),
            (ReductionOp::And, _) => "~0".to_string(),
            (ReductionOp::Or, _) => "0".to_string(),
            (ReductionOp::Xor, _) => "0".to_string(),
            (ReductionOp::Product, _) => "1".to_string(),
            // Fallback: only `Min` over unlisted types reaches this arm.
            // "0" is not a true Min identity, so callers should pass a
            // type listed above.
            _ => "0".to_string(),
        }
    }

    /// Renders the C expression combining `a` and `b` under this operator.
    pub fn combine_expr(&self, a: &str, b: &str) -> String {
        match self {
            ReductionOp::Sum => format!("{} + {}", a, b),
            ReductionOp::Min => format!("min({}, {})", a, b),
            ReductionOp::Max => format!("max({}, {})", a, b),
            ReductionOp::And => format!("{} & {}", a, b),
            ReductionOp::Or => format!("{} | {}", a, b),
            ReductionOp::Xor => format!("{} ^ {}", a, b),
            ReductionOp::Product => format!("{} * {}", a, b),
        }
    }
}

impl fmt::Display for ReductionOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ReductionOp::Sum => write!(f, "sum"),
            ReductionOp::Min => write!(f, "min"),
            ReductionOp::Max => write!(f, "max"),
            ReductionOp::And => write!(f, "and"),
            ReductionOp::Or => write!(f, "or"),
            ReductionOp::Xor => write!(f, "xor"),
            ReductionOp::Product => write!(f, "product"),
        }
    }
}

/// Configuration for generating reduction helper code.
#[derive(Debug, Clone)]
pub struct ReductionCodegenConfig {
    /// Threads per block; assumed to be a multiple of 32.
    pub block_size: u32,
    /// C type of the values being reduced (e.g. "float").
    pub value_type: String,
    /// The reduction operator to generate code for.
    pub op: ReductionOp,
    /// Use cooperative groups for grid-wide synchronization.
    pub use_cooperative: bool,
    /// Whether to emit the helper-function preamble.
    pub generate_helpers: bool,
}

impl Default for ReductionCodegenConfig {
    fn default() -> Self {
        Self {
            block_size: 256,
            value_type: "float".to_string(),
            op: ReductionOp::Sum,
            use_cooperative: true,
            generate_helpers: true,
        }
    }
}

impl ReductionCodegenConfig {
    /// Creates a configuration with the given block size and default settings.
    pub fn new(block_size: u32) -> Self {
        Self {
            block_size,
            ..Default::default()
        }
    }

    /// Sets the C value type (e.g. "float", "double", "int").
    pub fn with_type(mut self, ty: &str) -> Self {
        self.value_type = ty.to_string();
        self
    }

    /// Sets the reduction operator.
    pub fn with_op(mut self, op: ReductionOp) -> Self {
        self.op = op;
        self
    }

    /// Enables or disables cooperative-groups grid synchronization.
    pub fn with_cooperative(mut self, enabled: bool) -> Self {
        self.use_cooperative = enabled;
        self
    }
}

/// Reduction intrinsics recognized by the transpiler.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionIntrinsic {
    /// Block-level sum reduction.
    BlockReduceSum,
    /// Block-level minimum reduction.
    BlockReduceMin,
    /// Block-level maximum reduction.
    BlockReduceMax,
    /// Block-level bitwise AND reduction.
    BlockReduceAnd,
    /// Block-level bitwise OR reduction.
    BlockReduceOr,

    /// Grid-level sum reduction.
    GridReduceSum,
    /// Grid-level minimum reduction.
    GridReduceMin,
    /// Grid-level maximum reduction.
    GridReduceMax,

    /// Atomic accumulation into a global accumulator.
    AtomicAccumulate,
    /// Atomic accumulation that also yields the previous value.
    AtomicAccumulateFetch,

    /// Read a value broadcast through global memory.
    BroadcastRead,

    /// Grid-wide reduction followed by a broadcast of the result.
    ReduceAndBroadcast,
}

impl ReductionIntrinsic {
    /// Looks up an intrinsic by its source-language name.
    pub fn from_name(name: &str) -> Option<Self> {
        match name {
            "block_reduce_sum" => Some(ReductionIntrinsic::BlockReduceSum),
            "block_reduce_min" => Some(ReductionIntrinsic::BlockReduceMin),
            "block_reduce_max" => Some(ReductionIntrinsic::BlockReduceMax),
            "block_reduce_and" => Some(ReductionIntrinsic::BlockReduceAnd),
            "block_reduce_or" => Some(ReductionIntrinsic::BlockReduceOr),
            "grid_reduce_sum" => Some(ReductionIntrinsic::GridReduceSum),
            "grid_reduce_min" => Some(ReductionIntrinsic::GridReduceMin),
            "grid_reduce_max" => Some(ReductionIntrinsic::GridReduceMax),
            "atomic_accumulate" => Some(ReductionIntrinsic::AtomicAccumulate),
            "atomic_accumulate_fetch" => Some(ReductionIntrinsic::AtomicAccumulateFetch),
            "broadcast_read" => Some(ReductionIntrinsic::BroadcastRead),
            "reduce_and_broadcast" => Some(ReductionIntrinsic::ReduceAndBroadcast),
            _ => None,
        }
    }

    /// Returns the reduction operator implied by this intrinsic, if any.
    pub fn op(&self) -> Option<ReductionOp> {
        match self {
            ReductionIntrinsic::BlockReduceSum | ReductionIntrinsic::GridReduceSum => {
                Some(ReductionOp::Sum)
            }
            ReductionIntrinsic::BlockReduceMin | ReductionIntrinsic::GridReduceMin => {
                Some(ReductionOp::Min)
            }
            ReductionIntrinsic::BlockReduceMax | ReductionIntrinsic::GridReduceMax => {
                Some(ReductionOp::Max)
            }
            ReductionIntrinsic::BlockReduceAnd => Some(ReductionOp::And),
            ReductionIntrinsic::BlockReduceOr => Some(ReductionOp::Or),
            _ => None,
        }
    }
}

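/// Generates the CUDA helper preamble: block reduce, grid reduce, and
/// reduce-and-broadcast functions for the configured type and operator,
/// plus a software grid barrier when cooperative groups are disabled.
///
/// A minimal usage sketch (marked `ignore` because the enclosing crate
/// path is not assumed here):
///
/// ```ignore
/// let cfg = ReductionCodegenConfig::new(256)
///     .with_type("float")
///     .with_op(ReductionOp::Sum);
/// let cuda_src = generate_reduction_helpers(&cfg);
/// assert!(cuda_src.contains("__block_reduce_sum"));
/// ```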
pub fn generate_reduction_helpers(config: &ReductionCodegenConfig) -> String {
    let mut code = String::new();

    if config.use_cooperative {
        code.push_str("#include <cooperative_groups.h>\n");
        code.push_str("namespace cg = cooperative_groups;\n\n");
    }
    code.push_str("#include <limits.h>\n");
    code.push_str("#include <math.h>\n\n");

    code.push_str(&generate_block_reduce_fn(
        &config.value_type,
        config.block_size,
        &config.op,
    ));

    code.push_str(&generate_grid_reduce_fn(
        &config.value_type,
        config.block_size,
        &config.op,
    ));

    code.push_str(&generate_reduce_and_broadcast_fn(
        &config.value_type,
        config.block_size,
        &config.op,
        config.use_cooperative,
    ));

    if !config.use_cooperative {
        code.push_str(&generate_software_barrier());
    }

    code
}

/// Builds the warp-shuffle combine expression for one butterfly step.
fn shfl_combine_expr(op: &ReductionOp, val: &str, offset: &str) -> String {
    let shfl = format!("__shfl_down_sync(0xFFFFFFFF, {}, {})", val, offset);
    op.combine_expr(val, &shfl)
}

/// Generates a `__device__` block-level reduction function for `op`.
fn generate_block_reduce_fn(ty: &str, block_size: u32, op: &ReductionOp) -> String {
    let warp_combine = shfl_combine_expr(op, "val", "offset");
    let cross_warp_combine = shfl_combine_expr(op, "val", "offset");
    let num_warps = block_size / 32;
    let identity = op.identity(ty);

    format!(
        r#"
// Block-level {op} reduction using warp-shuffle + shared memory
__device__ {ty} __block_reduce_{op}({ty} val, {ty}* shared) {{
    int tid = threadIdx.x;
    int warp_id = tid / 32;
    int lane_id = tid % 32;

    // Phase 1: Intra-warp reduction via shuffle
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {{
        val = {warp_combine};
    }}

    // Warp leaders store partial results
    if (lane_id == 0) {{
        shared[warp_id] = val;
    }}
    __syncthreads();

    // Phase 2: First warp reduces across warp results
    val = (tid < {num_warps}) ? shared[tid] : {identity};
    if (warp_id == 0) {{
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {{
            val = {cross_warp_combine};
        }}
    }}

    // The fully reduced value is valid in lane 0 of warp 0 only.
    return val;
}}

"#,
        ty = ty,
        op = op,
        num_warps = num_warps,
        identity = identity,
        warp_combine = warp_combine,
        cross_warp_combine = cross_warp_combine
    )
}

/// Generates a `__device__` grid-level reduction function for `op` that
/// atomically folds each block's partial result into an accumulator.
fn generate_grid_reduce_fn(ty: &str, block_size: u32, op: &ReductionOp) -> String {
    let atomic_fn = op.atomic_fn();
    let warp_combine = shfl_combine_expr(op, "val", "offset");
    let cross_warp_combine = shfl_combine_expr(op, "val", "offset");
    let num_warps = block_size / 32;
    let identity = op.identity(ty);

    format!(
        r#"
// Grid-level {op} reduction with warp-shuffle + atomic accumulation
__device__ void __grid_reduce_{op}({ty} val, {ty}* shared, {ty}* accumulator) {{
    int tid = threadIdx.x;
    int warp_id = tid / 32;
    int lane_id = tid % 32;

    // Phase 1: Intra-warp reduction via shuffle
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {{
        val = {warp_combine};
    }}

    // Warp leaders store partial results
    if (lane_id == 0) {{
        shared[warp_id] = val;
    }}
    __syncthreads();

    // Phase 2: First warp reduces across warp results
    val = (tid < {num_warps}) ? shared[tid] : {identity};
    if (warp_id == 0) {{
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {{
            val = {cross_warp_combine};
        }}
    }}

    // Block leader atomically accumulates
    if (tid == 0) {{
        {atomic_fn}(accumulator, val);
    }}
}}

"#,
        ty = ty,
        op = op,
        num_warps = num_warps,
        identity = identity,
        warp_combine = warp_combine,
        cross_warp_combine = cross_warp_combine,
        atomic_fn = atomic_fn
    )
}

/// Generates a `__device__` reduce-and-broadcast function for `op`: every
/// thread receives the grid-wide result after a grid synchronization.
fn generate_reduce_and_broadcast_fn(
    ty: &str,
    block_size: u32,
    op: &ReductionOp,
    use_cooperative: bool,
) -> String {
    let atomic_fn = op.atomic_fn();
    let warp_combine = shfl_combine_expr(op, "val", "offset");
    let cross_warp_combine = shfl_combine_expr(op, "val", "offset");
    let num_warps = block_size / 32;
    let identity = op.identity(ty);

    let grid_sync = if use_cooperative {
        "    cg::grid_group grid = cg::this_grid();\n    grid.sync();"
    } else {
        // The barrier arguments are already `unsigned int*` parameters, so
        // they are forwarded as-is (taking their address would pass `unsigned int**`).
        "    __software_grid_sync(__barrier_counter, __barrier_gen, gridDim.x * gridDim.y * gridDim.z);"
    };

    let params = if use_cooperative {
        format!("{} val, {}* shared, {}* accumulator", ty, ty, ty)
    } else {
        format!(
            "{} val, {}* shared, {}* accumulator, unsigned int* __barrier_counter, unsigned int* __barrier_gen",
            ty, ty, ty
        )
    };

    format!(
        r#"
// Reduce-and-broadcast: all threads get the global {op} result (warp-shuffle)
__device__ {ty} __reduce_and_broadcast_{op}({params}) {{
    int tid = threadIdx.x;
    int warp_id = tid / 32;
    int lane_id = tid % 32;

    // Phase 1: Intra-warp reduction via shuffle
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {{
        val = {warp_combine};
    }}

    // Warp leaders store partial results
    if (lane_id == 0) {{
        shared[warp_id] = val;
    }}
    __syncthreads();

    // Phase 2: First warp reduces across warp results
    val = (tid < {num_warps}) ? shared[tid] : {identity};
    if (warp_id == 0) {{
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {{
            val = {cross_warp_combine};
        }}
    }}

    // Phase 3: Block leader atomically accumulates
    if (tid == 0) {{
        {atomic_fn}(accumulator, val);
    }}

    // Phase 4: Grid-wide synchronization
{grid_sync}

    // Phase 5: All threads read the result
    return *accumulator;
}}

"#,
        ty = ty,
        op = op,
        num_warps = num_warps,
        identity = identity,
        warp_combine = warp_combine,
        cross_warp_combine = cross_warp_combine,
        atomic_fn = atomic_fn,
        params = params,
        grid_sync = grid_sync
    )
}

/// Generates a software grid barrier for targets without cooperative launch.
fn generate_software_barrier() -> String {
    r#"
// Software grid barrier using global memory atomics.
// Use when cooperative groups are unavailable.
// NOTE: all blocks must be simultaneously resident on the device,
// otherwise this barrier spins forever.
__device__ void __software_grid_sync(
    volatile unsigned int* barrier_counter,
    volatile unsigned int* barrier_gen,
    unsigned int num_blocks
) {
    __syncthreads();

    if (threadIdx.x == 0) {
        unsigned int gen = *barrier_gen;
        unsigned int arrived = atomicAdd((unsigned int*)barrier_counter, 1) + 1;

        if (arrived == num_blocks) {
            // Last block to arrive: reset counter and increment generation
            *barrier_counter = 0;
            __threadfence();
            atomicAdd((unsigned int*)barrier_gen, 1);
        } else {
            // Wait for all blocks
            while (atomicAdd((unsigned int*)barrier_gen, 0) == gen) {
                __threadfence();
            }
        }
    }

    __syncthreads();
}

"#
    .to_string()
}

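/// Transpiles a reduction intrinsic call into the corresponding CUDA
/// expression, validating the argument count for each intrinsic.
///
/// A minimal usage sketch (marked `ignore` because the enclosing crate
/// path is not assumed here):
///
/// ```ignore
/// let cfg = ReductionCodegenConfig::default();
/// let call = transpile_reduction_call(
///     ReductionIntrinsic::BlockReduceSum,
///     &["val".to_string(), "shared".to_string()],
///     &cfg,
/// )
/// .unwrap();
/// assert_eq!(call, "__block_reduce_sum(val, shared)");
/// ```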
pub fn transpile_reduction_call(
    intrinsic: ReductionIntrinsic,
    args: &[String],
    config: &ReductionCodegenConfig,
) -> Result<String, String> {
    match intrinsic {
        ReductionIntrinsic::BlockReduceSum => {
            validate_args(args, 2, "block_reduce_sum")?;
            Ok(format!("__block_reduce_sum({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::BlockReduceMin => {
            validate_args(args, 2, "block_reduce_min")?;
            Ok(format!("__block_reduce_min({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::BlockReduceMax => {
            validate_args(args, 2, "block_reduce_max")?;
            Ok(format!("__block_reduce_max({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::BlockReduceAnd => {
            validate_args(args, 2, "block_reduce_and")?;
            Ok(format!("__block_reduce_and({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::BlockReduceOr => {
            validate_args(args, 2, "block_reduce_or")?;
            Ok(format!("__block_reduce_or({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::GridReduceSum => {
            validate_args(args, 3, "grid_reduce_sum")?;
            Ok(format!(
                "__grid_reduce_sum({}, {}, {})",
                args[0], args[1], args[2]
            ))
        }
        ReductionIntrinsic::GridReduceMin => {
            validate_args(args, 3, "grid_reduce_min")?;
            Ok(format!(
                "__grid_reduce_min({}, {}, {})",
                args[0], args[1], args[2]
            ))
        }
        ReductionIntrinsic::GridReduceMax => {
            validate_args(args, 3, "grid_reduce_max")?;
            Ok(format!(
                "__grid_reduce_max({}, {}, {})",
                args[0], args[1], args[2]
            ))
        }
        ReductionIntrinsic::AtomicAccumulate => {
            validate_args(args, 2, "atomic_accumulate")?;
            // Accumulate under the configured operator (atomicAdd for the
            // default Sum configuration).
            Ok(format!("{}({}, {})", config.op.atomic_fn(), args[0], args[1]))
        }
        ReductionIntrinsic::AtomicAccumulateFetch => {
            validate_args(args, 2, "atomic_accumulate_fetch")?;
            // CUDA atomics return the previous value, so the fetch variant
            // maps to the same call used as an expression.
            Ok(format!("{}({}, {})", config.op.atomic_fn(), args[0], args[1]))
        }
        ReductionIntrinsic::BroadcastRead => {
            validate_args(args, 1, "broadcast_read")?;
            Ok(format!("(*{})", args[0]))
        }
        ReductionIntrinsic::ReduceAndBroadcast => {
            // Use the configured operator so the call matches the generated
            // __reduce_and_broadcast_{op} helper (was hard-coded to sum).
            if config.use_cooperative {
                validate_args(args, 3, "reduce_and_broadcast")?;
                Ok(format!(
                    "__reduce_and_broadcast_{}({}, {}, {})",
                    config.op, args[0], args[1], args[2]
                ))
            } else {
                validate_args(args, 5, "reduce_and_broadcast (software barrier)")?;
                Ok(format!(
                    "__reduce_and_broadcast_{}({}, {}, {}, {}, {})",
                    config.op, args[0], args[1], args[2], args[3], args[4]
                ))
            }
        }
    }
}

/// Checks that `args` has exactly `expected` entries.
fn validate_args(args: &[String], expected: usize, name: &str) -> Result<(), String> {
    if args.len() != expected {
        Err(format!(
            "{} requires {} arguments, got {}",
            name,
            expected,
            args.len()
        ))
    } else {
        Ok(())
    }
}

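/// Generates an inline block-reduction statement. `result_var` is declared
/// before the braced block so it remains visible to subsequent kernel code,
/// and is assigned the block-wide result inside the block.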
pub fn generate_inline_block_reduce(
    value_expr: &str,
    shared_array: &str,
    result_var: &str,
    ty: &str,
    block_size: u32,
    op: &ReductionOp,
) -> String {
    let warp_combine = shfl_combine_expr(op, "__val", "__offset");
    let cross_warp_combine = shfl_combine_expr(op, "__val", "__offset");
    let num_warps = block_size / 32;
    let identity = op.identity(ty);

    format!(
        r#"{ty} {result_var};
{{
    int __tid = threadIdx.x;
    int __warp_id = __tid / 32;
    int __lane_id = __tid % 32;
    {ty} __val = {value_expr};

    // Phase 1: Intra-warp reduction via shuffle
    #pragma unroll
    for (int __offset = 16; __offset > 0; __offset >>= 1) {{
        __val = {warp_combine};
    }}

    // Warp leaders store partial results
    if (__lane_id == 0) {{
        {shared_array}[__warp_id] = __val;
    }}
    __syncthreads();

    // Phase 2: First warp reduces across warp results
    __val = (__tid < {num_warps}) ? {shared_array}[__tid] : {identity};
    if (__warp_id == 0) {{
        #pragma unroll
        for (int __offset = 16; __offset > 0; __offset >>= 1) {{
            __val = {cross_warp_combine};
        }}
    }}

    if (__tid == 0) {{
        {shared_array}[0] = __val;
    }}
    // Make the leader's store visible before the broadcast read
    __syncthreads();

    {result_var} = {shared_array}[0];
}}"#,
        value_expr = value_expr,
        shared_array = shared_array,
        result_var = result_var,
        ty = ty,
        num_warps = num_warps,
        identity = identity,
        warp_combine = warp_combine,
        cross_warp_combine = cross_warp_combine
    )
}

/// Generates an inline grid-reduction statement that atomically folds the
/// block's partial result into `accumulator`.
pub fn generate_inline_grid_reduce(
    value_expr: &str,
    shared_array: &str,
    accumulator: &str,
    ty: &str,
    block_size: u32,
    op: &ReductionOp,
) -> String {
    let atomic_fn = op.atomic_fn();
    let warp_combine = shfl_combine_expr(op, "__val", "__offset");
    let cross_warp_combine = shfl_combine_expr(op, "__val", "__offset");
    let num_warps = block_size / 32;
    let identity = op.identity(ty);

    format!(
        r#"{{
    int __tid = threadIdx.x;
    int __warp_id = __tid / 32;
    int __lane_id = __tid % 32;
    {ty} __val = {value_expr};

    // Phase 1: Intra-warp reduction via shuffle
    #pragma unroll
    for (int __offset = 16; __offset > 0; __offset >>= 1) {{
        __val = {warp_combine};
    }}

    // Warp leaders store partial results
    if (__lane_id == 0) {{
        {shared_array}[__warp_id] = __val;
    }}
    __syncthreads();

    // Phase 2: First warp reduces across warp results
    __val = (__tid < {num_warps}) ? {shared_array}[__tid] : {identity};
    if (__warp_id == 0) {{
        #pragma unroll
        for (int __offset = 16; __offset > 0; __offset >>= 1) {{
            __val = {cross_warp_combine};
        }}
    }}

    if (__tid == 0) {{
        {atomic_fn}({accumulator}, __val);
    }}
}}"#,
        ty = ty,
        value_expr = value_expr,
        shared_array = shared_array,
        accumulator = accumulator,
        num_warps = num_warps,
        identity = identity,
        warp_combine = warp_combine,
        cross_warp_combine = cross_warp_combine,
        atomic_fn = atomic_fn
    )
}

#[allow(clippy::too_many_arguments)]
/// Generates an inline reduce-and-broadcast statement; `result_var` is
/// declared before the braced block so later kernel code can use it. In the
/// software-barrier form the surrounding kernel is expected to provide
/// `__barrier_counter` and `__barrier_gen` as `unsigned int*` parameters,
/// which are forwarded without taking their address.
pub fn generate_inline_reduce_and_broadcast(
    value_expr: &str,
    shared_array: &str,
    accumulator: &str,
    result_var: &str,
    ty: &str,
    block_size: u32,
    op: &ReductionOp,
    use_cooperative: bool,
) -> String {
    let atomic_fn = op.atomic_fn();
    let warp_combine = shfl_combine_expr(op, "__val", "__offset");
    let cross_warp_combine = shfl_combine_expr(op, "__val", "__offset");
    let num_warps = block_size / 32;
    let identity = op.identity(ty);

    let grid_sync = if use_cooperative {
        "    cg::grid_group __grid = cg::this_grid();\n    __grid.sync();"
    } else {
        "    __software_grid_sync(__barrier_counter, __barrier_gen, gridDim.x * gridDim.y * gridDim.z);"
    };

    format!(
        r#"{ty} {result_var};
{{
    int __tid = threadIdx.x;
    int __warp_id = __tid / 32;
    int __lane_id = __tid % 32;
    {ty} __val = {value_expr};

    // Phase 1: Intra-warp reduction via shuffle
    #pragma unroll
    for (int __offset = 16; __offset > 0; __offset >>= 1) {{
        __val = {warp_combine};
    }}

    // Warp leaders store partial results
    if (__lane_id == 0) {{
        {shared_array}[__warp_id] = __val;
    }}
    __syncthreads();

    // Phase 2: First warp reduces across warp results
    __val = (__tid < {num_warps}) ? {shared_array}[__tid] : {identity};
    if (__warp_id == 0) {{
        #pragma unroll
        for (int __offset = 16; __offset > 0; __offset >>= 1) {{
            __val = {cross_warp_combine};
        }}
    }}

    // Atomic accumulation
    if (__tid == 0) {{
        {atomic_fn}({accumulator}, __val);
    }}

    // Grid synchronization
{grid_sync}

    // Broadcast: all threads read result
    {result_var} = *{accumulator};
}}"#,
        value_expr = value_expr,
        shared_array = shared_array,
        accumulator = accumulator,
        result_var = result_var,
        ty = ty,
        num_warps = num_warps,
        identity = identity,
        warp_combine = warp_combine,
        cross_warp_combine = cross_warp_combine,
        atomic_fn = atomic_fn,
        grid_sync = grid_sync
    )
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reduction_op_combine() {
        assert_eq!(ReductionOp::Sum.combine_expr("a", "b"), "a + b");
        assert_eq!(ReductionOp::Min.combine_expr("a", "b"), "min(a, b)");
        assert_eq!(ReductionOp::Max.combine_expr("a", "b"), "max(a, b)");
        assert_eq!(ReductionOp::And.combine_expr("a", "b"), "a & b");
    }

    #[test]
    fn test_reduction_op_identity() {
        assert_eq!(ReductionOp::Sum.identity("float"), "0");
        assert_eq!(ReductionOp::Min.identity("float"), "INFINITY");
        assert_eq!(ReductionOp::Max.identity("float"), "-INFINITY");
        assert_eq!(ReductionOp::Product.identity("int"), "1");
    }

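    // Added example: pins down the operator-to-atomic mapping. Note that
    // "atomicMul" is not a native CUDA intrinsic; this only asserts what
    // atomic_fn() returns, not that the name resolves on device.
    #[test]
    fn test_reduction_op_atomic_fn() {
        assert_eq!(ReductionOp::Sum.atomic_fn(), "atomicAdd");
        assert_eq!(ReductionOp::Min.atomic_fn(), "atomicMin");
        assert_eq!(ReductionOp::Xor.atomic_fn(), "atomicXor");
        assert_eq!(ReductionOp::Product.atomic_fn(), "atomicMul");
    }
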
    #[test]
    fn test_intrinsic_from_name() {
        assert_eq!(
            ReductionIntrinsic::from_name("block_reduce_sum"),
            Some(ReductionIntrinsic::BlockReduceSum)
        );
        assert_eq!(
            ReductionIntrinsic::from_name("grid_reduce_min"),
            Some(ReductionIntrinsic::GridReduceMin)
        );
        assert_eq!(
            ReductionIntrinsic::from_name("reduce_and_broadcast"),
            Some(ReductionIntrinsic::ReduceAndBroadcast)
        );
        assert_eq!(ReductionIntrinsic::from_name("not_an_intrinsic"), None);
    }

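    // Added example: from_name and op() should agree on the implied operator.
    #[test]
    fn test_intrinsic_op_mapping() {
        let intrinsic = ReductionIntrinsic::from_name("grid_reduce_max").unwrap();
        assert_eq!(intrinsic.op(), Some(ReductionOp::Max));
        // Intrinsics without a fixed operator report None.
        assert_eq!(ReductionIntrinsic::BroadcastRead.op(), None);
    }
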
    #[test]
    fn test_generate_block_reduce() {
        let code = generate_block_reduce_fn("float", 256, &ReductionOp::Sum);
        assert!(code.contains("__device__ float __block_reduce_sum"));
        assert!(code.contains("__shfl_down_sync(0xFFFFFFFF, val, offset)"));
        assert!(code.contains("__syncthreads()"));
        assert!(code.contains("warp_id"));
        assert!(code.contains("lane_id"));
    }

    #[test]
    fn test_generate_grid_reduce() {
        let code = generate_grid_reduce_fn("double", 128, &ReductionOp::Max);
        assert!(code.contains("__device__ void __grid_reduce_max"));
        assert!(code.contains("__shfl_down_sync(0xFFFFFFFF, val, offset)"));
        assert!(code.contains("atomicMax(accumulator, val)"));
    }

    #[test]
    fn test_generate_reduce_and_broadcast_cooperative() {
        let code = generate_reduce_and_broadcast_fn("float", 256, &ReductionOp::Sum, true);
        assert!(code.contains("cg::grid_group grid = cg::this_grid()"));
        assert!(code.contains("grid.sync()"));
    }

    #[test]
    fn test_generate_reduce_and_broadcast_software() {
        let code = generate_reduce_and_broadcast_fn("float", 256, &ReductionOp::Sum, false);
        assert!(code.contains("__barrier_counter"));
        assert!(code.contains("__software_grid_sync"));
    }

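    // Added example: with cooperative groups disabled, the helper preamble
    // should carry the software barrier definition instead of the cg include.
    #[test]
    fn test_helpers_include_software_barrier() {
        let config = ReductionCodegenConfig::new(256).with_cooperative(false);
        let helpers = generate_reduction_helpers(&config);
        assert!(helpers.contains("__device__ void __software_grid_sync"));
        assert!(!helpers.contains("#include <cooperative_groups.h>"));
    }
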
    #[test]
    fn test_inline_block_reduce() {
        let code = generate_inline_block_reduce(
            "my_value",
            "shared_mem",
            "result",
            "float",
            256,
            &ReductionOp::Sum,
        );
        assert!(code.contains("__shfl_down_sync(0xFFFFFFFF, __val, __offset)"));
        assert!(code.contains("shared_mem[__warp_id] = __val"));
        // The result variable is declared outside the block and assigned inside.
        assert!(code.contains("float result;"));
        assert!(code.contains("result = shared_mem[0]"));
    }

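    // Added example: the inline grid reduce should fold the block's partial
    // result into the accumulator using the operator's atomic function.
    #[test]
    fn test_inline_grid_reduce() {
        let code = generate_inline_grid_reduce(
            "my_value",
            "shared_mem",
            "global_accum",
            "int",
            128,
            &ReductionOp::Max,
        );
        assert!(code.contains("__shfl_down_sync(0xFFFFFFFF, __val, __offset)"));
        assert!(code.contains("atomicMax(global_accum, __val)"));
    }
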
    #[test]
    fn test_transpile_reduction_call() {
        let config = ReductionCodegenConfig::default();

        let result = transpile_reduction_call(
            ReductionIntrinsic::BlockReduceSum,
            &["val".to_string(), "shared".to_string()],
            &config,
        );
        assert_eq!(result.unwrap(), "__block_reduce_sum(val, shared)");

        let result = transpile_reduction_call(
            ReductionIntrinsic::AtomicAccumulate,
            &["&accum".to_string(), "val".to_string()],
            &config,
        );
        assert_eq!(result.unwrap(), "atomicAdd(&accum, val)");
    }

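    // Added example: argument-count validation, and the operator-aware
    // reduce-and-broadcast name (the software-barrier form takes five args).
    #[test]
    fn test_transpile_arg_validation() {
        let config = ReductionCodegenConfig::default();
        let err = transpile_reduction_call(
            ReductionIntrinsic::BlockReduceSum,
            &["val".to_string()],
            &config,
        );
        assert!(err.is_err());

        let config = ReductionCodegenConfig::default().with_cooperative(false);
        let args: Vec<String> = ["v", "s", "acc", "ctr", "gen"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let call = transpile_reduction_call(
            ReductionIntrinsic::ReduceAndBroadcast,
            &args,
            &config,
        )
        .unwrap();
        assert_eq!(call, "__reduce_and_broadcast_sum(v, s, acc, ctr, gen)");
    }
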
    #[test]
    fn test_reduction_helpers() {
        let config = ReductionCodegenConfig::new(256)
            .with_type("float")
            .with_op(ReductionOp::Sum)
            .with_cooperative(true);

        let helpers = generate_reduction_helpers(&config);
        assert!(helpers.contains("#include <cooperative_groups.h>"));
        assert!(helpers.contains("__block_reduce_sum"));
        assert!(helpers.contains("__grid_reduce_sum"));
        assert!(helpers.contains("__reduce_and_broadcast_sum"));
    }
}
965}