ringkernel_cuda_codegen/reduction_intrinsics.rs

//! Reduction Intrinsics for CUDA Code Generation
//!
//! This module provides DSL intrinsics for efficient parallel reductions that
//! transpile to optimized CUDA code. Reductions aggregate values across threads
//! using operations like sum, min, max, etc.
//!
//! # Reduction Hierarchy
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────┐
//! │                     Grid-Level Reduction                        │
//! │  ┌───────────────────────────────────────────────────────────┐  │
//! │  │                    Block 0                                │  │
//! │  │  ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐         │  │
//! │  │  │ Warp 0  │ │ Warp 1  │ │ Warp 2  │ │ Warp N  │         │  │
//! │  │  │ shuffle │ │ shuffle │ │ shuffle │ │ shuffle │         │  │
//! │  │  └────┬────┘ └────┬────┘ └────┬────┘ └────┬────┘         │  │
//! │  │       └──────────────┴──────────┴──────────┘             │  │
//! │  │                      │ shared memory                     │  │
//! │  │                      ▼                                   │  │
//! │  │              block_reduce_sum()                          │  │
//! │  │                      │                                   │  │
//! │  └──────────────────────┼───────────────────────────────────┘  │
//! │                         ▼ atomicAdd                            │
//! │  ┌─────────────────────────────────────────────────────────┐  │
//! │  │              Global Accumulator (mapped memory)          │  │
//! │  └─────────────────────────────────────────────────────────┘  │
//! │                         │ grid.sync() / barrier                │
//! │                         ▼                                      │
//! │              All threads read final result                     │
//! └─────────────────────────────────────────────────────────────────┘
//! ```
//!
//! # Example
//!
//! ```ignore
//! // Rust DSL
//! fn pagerank_phase1(ranks: &[f64], out_degree: &[u32], dangling_sum: &mut f64, n: u32) {
//!     let idx = global_thread_idx();
//!     if idx >= n { return; }
//!
//!     let contrib = if out_degree[idx] == 0 { ranks[idx] } else { 0.0 };
//!     reduce_and_broadcast(contrib, dangling_sum);  // All threads get sum
//! }
//! ```
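//!
//! # Codegen Usage
//!
//! A minimal host-side sketch of how the pieces in this module fit together
//! (illustrative only; the argument expressions are placeholders, and the call
//! is assumed to run in a function that can propagate the `Result`):
//!
//! ```ignore
//! let config = ReductionCodegenConfig::new(256)
//!     .with_type("double")
//!     .with_op(ReductionOp::Sum);
//!
//! // Helper definitions are prepended to the generated kernel source.
//! let mut cuda_src = generate_reduction_helpers(&config);
//!
//! // Each DSL reduction call is rewritten into a call to those helpers.
//! let call = transpile_reduction_call(
//!     ReductionIntrinsic::GridReduceSum,
//!     &["contrib".to_string(), "shared".to_string(), "dangling_sum".to_string()],
//!     &config,
//! )?;
//! cuda_src.push_str(&call);
//! ```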

use std::fmt;

/// Reduction operation types for code generation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ReductionOp {
    /// Sum: `a + b`
    #[default]
    Sum,
    /// Minimum: `min(a, b)`
    Min,
    /// Maximum: `max(a, b)`
    Max,
    /// Bitwise AND: `a & b`
    And,
    /// Bitwise OR: `a | b`
    Or,
    /// Bitwise XOR: `a ^ b`
    Xor,
    /// Product: `a * b`
    Product,
}

impl ReductionOp {
    /// Get the binary operator or builtin function name used to combine two values.
    pub fn operator(&self) -> &'static str {
        match self {
            ReductionOp::Sum => "+",
            ReductionOp::Min => "min",
            ReductionOp::Max => "max",
            ReductionOp::And => "&",
            ReductionOp::Or => "|",
            ReductionOp::Xor => "^",
            ReductionOp::Product => "*",
        }
    }

    /// Get the CUDA atomic function name.
    ///
    /// Note that `atomicMin`/`atomicMax` are only built in for integer types, and
    /// `atomicMul` does not exist in CUDA; floating-point min/max and product
    /// reductions require custom (e.g. CAS-based) implementations.
    pub fn atomic_fn(&self) -> &'static str {
        match self {
            ReductionOp::Sum => "atomicAdd",
            ReductionOp::Min => "atomicMin",
            ReductionOp::Max => "atomicMax",
            ReductionOp::And => "atomicAnd",
            ReductionOp::Or => "atomicOr",
            ReductionOp::Xor => "atomicXor",
            ReductionOp::Product => "atomicMul", // Requires a custom implementation
        }
    }

    /// Get the identity value expression for this operation.
    pub fn identity(&self, ty: &str) -> String {
        match (self, ty) {
            (ReductionOp::Sum, _) => "0".to_string(),
            (ReductionOp::Min, "float") => "INFINITY".to_string(),
            (ReductionOp::Min, "double") => "HUGE_VAL".to_string(),
            (ReductionOp::Min, "int") => "INT_MAX".to_string(),
            (ReductionOp::Min, "long long") => "LLONG_MAX".to_string(),
            (ReductionOp::Min, "unsigned int") => "UINT_MAX".to_string(),
            (ReductionOp::Min, "unsigned long long") => "ULLONG_MAX".to_string(),
            (ReductionOp::Max, "float") => "-INFINITY".to_string(),
            (ReductionOp::Max, "double") => "-HUGE_VAL".to_string(),
            (ReductionOp::Max, "int") => "INT_MIN".to_string(),
            (ReductionOp::Max, "long long") => "LLONG_MIN".to_string(),
            (ReductionOp::Max, _) => "0".to_string(),
            (ReductionOp::And, _) => "~0".to_string(),
            (ReductionOp::Or, _) => "0".to_string(),
            (ReductionOp::Xor, _) => "0".to_string(),
            (ReductionOp::Product, _) => "1".to_string(),
            _ => "0".to_string(),
        }
    }

    /// Get the combination expression for two values.
    pub fn combine_expr(&self, a: &str, b: &str) -> String {
        match self {
            ReductionOp::Sum => format!("{} + {}", a, b),
            ReductionOp::Min => format!("min({}, {})", a, b),
            ReductionOp::Max => format!("max({}, {})", a, b),
            ReductionOp::And => format!("{} & {}", a, b),
            ReductionOp::Or => format!("{} | {}", a, b),
            ReductionOp::Xor => format!("{} ^ {}", a, b),
            ReductionOp::Product => format!("{} * {}", a, b),
        }
    }
}

impl fmt::Display for ReductionOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ReductionOp::Sum => write!(f, "sum"),
            ReductionOp::Min => write!(f, "min"),
            ReductionOp::Max => write!(f, "max"),
            ReductionOp::And => write!(f, "and"),
            ReductionOp::Or => write!(f, "or"),
            ReductionOp::Xor => write!(f, "xor"),
            ReductionOp::Product => write!(f, "product"),
        }
    }
}

/// Configuration for reduction code generation.
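///
/// A small illustrative sketch of building a config with the builder methods below:
///
/// ```ignore
/// let config = ReductionCodegenConfig::new(128)
///     .with_type("double")
///     .with_op(ReductionOp::Max)
///     .with_cooperative(false);
/// assert_eq!(config.block_size, 128);
/// ```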
#[derive(Debug, Clone)]
pub struct ReductionCodegenConfig {
    /// Thread block size (must be power of 2).
    pub block_size: u32,
    /// Value type (CUDA type name).
    pub value_type: String,
    /// Reduction operation.
    pub op: ReductionOp,
    /// Use cooperative groups for grid-wide sync.
    pub use_cooperative: bool,
    /// Generate helper function definitions.
    pub generate_helpers: bool,
}

impl Default for ReductionCodegenConfig {
    fn default() -> Self {
        Self {
            block_size: 256,
            value_type: "float".to_string(),
            op: ReductionOp::Sum,
            use_cooperative: true,
            generate_helpers: true,
        }
    }
}

impl ReductionCodegenConfig {
    /// Create a new config with the given block size.
    pub fn new(block_size: u32) -> Self {
        Self {
            block_size,
            ..Default::default()
        }
    }

    /// Set the value type.
    pub fn with_type(mut self, ty: &str) -> Self {
        self.value_type = ty.to_string();
        self
    }

    /// Set the reduction operation.
    pub fn with_op(mut self, op: ReductionOp) -> Self {
        self.op = op;
        self
    }

    /// Enable or disable cooperative groups.
    pub fn with_cooperative(mut self, enabled: bool) -> Self {
        self.use_cooperative = enabled;
        self
    }
}

/// Reduction intrinsic types for the DSL.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionIntrinsic {
    /// Block-level sum reduction: `block_reduce_sum(value, shared_mem)`
    BlockReduceSum,
    /// Block-level min reduction: `block_reduce_min(value, shared_mem)`
    BlockReduceMin,
    /// Block-level max reduction: `block_reduce_max(value, shared_mem)`
    BlockReduceMax,
    /// Block-level AND reduction: `block_reduce_and(value, shared_mem)`
    BlockReduceAnd,
    /// Block-level OR reduction: `block_reduce_or(value, shared_mem)`
    BlockReduceOr,

    /// Grid-level sum with atomic accumulation: `grid_reduce_sum(value, shared, accumulator)`
    GridReduceSum,
    /// Grid-level min with atomic accumulation: `grid_reduce_min(value, shared, accumulator)`
    GridReduceMin,
    /// Grid-level max with atomic accumulation: `grid_reduce_max(value, shared, accumulator)`
    GridReduceMax,

    /// Atomic accumulate into a buffer: `atomic_accumulate(accumulator, value)`
    AtomicAccumulate,
    /// Atomic accumulate that also yields the accumulator's previous value:
    /// `atomic_accumulate_fetch(accumulator, value)`
    AtomicAccumulateFetch,

    /// Read from the reduction buffer (broadcast): `broadcast_read(buffer)`
    BroadcastRead,

    /// Full reduce-and-broadcast: `reduce_and_broadcast(value, shared, accumulator)`.
    /// Returns the global result to all threads.
    ReduceAndBroadcast,
}

impl ReductionIntrinsic {
    /// Parse an intrinsic from its DSL function name.
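    ///
    /// Illustrative examples (mirroring the match arms below):
    ///
    /// ```ignore
    /// assert_eq!(
    ///     ReductionIntrinsic::from_name("block_reduce_sum"),
    ///     Some(ReductionIntrinsic::BlockReduceSum)
    /// );
    /// assert_eq!(ReductionIntrinsic::from_name("unknown"), None);
    /// ```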
    pub fn from_name(name: &str) -> Option<Self> {
        match name {
            "block_reduce_sum" => Some(ReductionIntrinsic::BlockReduceSum),
            "block_reduce_min" => Some(ReductionIntrinsic::BlockReduceMin),
            "block_reduce_max" => Some(ReductionIntrinsic::BlockReduceMax),
            "block_reduce_and" => Some(ReductionIntrinsic::BlockReduceAnd),
            "block_reduce_or" => Some(ReductionIntrinsic::BlockReduceOr),
            "grid_reduce_sum" => Some(ReductionIntrinsic::GridReduceSum),
            "grid_reduce_min" => Some(ReductionIntrinsic::GridReduceMin),
            "grid_reduce_max" => Some(ReductionIntrinsic::GridReduceMax),
            "atomic_accumulate" => Some(ReductionIntrinsic::AtomicAccumulate),
            "atomic_accumulate_fetch" => Some(ReductionIntrinsic::AtomicAccumulateFetch),
            "broadcast_read" => Some(ReductionIntrinsic::BroadcastRead),
            "reduce_and_broadcast" => Some(ReductionIntrinsic::ReduceAndBroadcast),
            _ => None,
        }
    }

    /// Get the reduction operation for this intrinsic.
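    ///
    /// Illustrative examples (mirroring the match below):
    ///
    /// ```ignore
    /// assert_eq!(ReductionIntrinsic::GridReduceMin.op(), Some(ReductionOp::Min));
    /// assert_eq!(ReductionIntrinsic::BroadcastRead.op(), None);
    /// ```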
    pub fn op(&self) -> Option<ReductionOp> {
        match self {
            ReductionIntrinsic::BlockReduceSum | ReductionIntrinsic::GridReduceSum => {
                Some(ReductionOp::Sum)
            }
            ReductionIntrinsic::BlockReduceMin | ReductionIntrinsic::GridReduceMin => {
                Some(ReductionOp::Min)
            }
            ReductionIntrinsic::BlockReduceMax | ReductionIntrinsic::GridReduceMax => {
                Some(ReductionOp::Max)
            }
            ReductionIntrinsic::BlockReduceAnd => Some(ReductionOp::And),
            ReductionIntrinsic::BlockReduceOr => Some(ReductionOp::Or),
            _ => None,
        }
    }
}

// =============================================================================
// Code Generation Functions
// =============================================================================

/// Generate CUDA helper functions for reduction operations.
///
/// This should be included at the top of kernels that use reduction intrinsics.
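///
/// A hedged usage sketch (the kernel body string here is a placeholder):
///
/// ```ignore
/// let config = ReductionCodegenConfig::default();
/// let mut kernel_src = generate_reduction_helpers(&config);
/// kernel_src.push_str("__global__ void my_kernel(float* data) { /* ... */ }\n");
/// assert!(kernel_src.contains("__block_reduce_sum"));
/// ```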
pub fn generate_reduction_helpers(config: &ReductionCodegenConfig) -> String {
    let mut code = String::new();

    // Includes
    if config.use_cooperative {
        code.push_str("#include <cooperative_groups.h>\n");
        code.push_str("namespace cg = cooperative_groups;\n\n");
    }
    code.push_str("#include <limits.h>\n");
    code.push_str("#include <math.h>\n\n");

    // Software grid barrier; emitted before the reduce-and-broadcast helper,
    // which calls it when cooperative groups are unavailable
    if !config.use_cooperative {
        code.push_str(&generate_software_barrier());
    }

    // Block-level reduction function
    code.push_str(&generate_block_reduce_fn(
        &config.value_type,
        config.block_size,
        &config.op,
    ));

    // Grid-level reduction function
    code.push_str(&generate_grid_reduce_fn(
        &config.value_type,
        config.block_size,
        &config.op,
    ));

    // Reduce-and-broadcast function
    code.push_str(&generate_reduce_and_broadcast_fn(
        &config.value_type,
        config.block_size,
        &config.op,
        config.use_cooperative,
    ));

    code
}

/// Generate a block-level reduction function.
fn generate_block_reduce_fn(ty: &str, block_size: u32, op: &ReductionOp) -> String {
    let combine = op.combine_expr("shared[tid]", "shared[tid + s]");

    format!(
        r#"
// Block-level {op} reduction using shared memory
__device__ {ty} __block_reduce_{op}({ty} val, {ty}* shared) {{
    int tid = threadIdx.x;
    shared[tid] = val;
    __syncthreads();

    // Tree reduction
    #pragma unroll
    for (int s = {block_size} / 2; s > 0; s >>= 1) {{
        if (tid < s) {{
            shared[tid] = {combine};
        }}
        __syncthreads();
    }}

    return shared[0];
}}

"#,
        ty = ty,
        op = op,
        block_size = block_size,
        combine = combine
    )
}

/// Generate a grid-level reduction function with atomic accumulation.
fn generate_grid_reduce_fn(ty: &str, block_size: u32, op: &ReductionOp) -> String {
    let atomic_fn = op.atomic_fn();
    let combine = op.combine_expr("shared[tid]", "shared[tid + s]");

    format!(
        r#"
// Grid-level {op} reduction with atomic accumulation
__device__ void __grid_reduce_{op}({ty} val, {ty}* shared, {ty}* accumulator) {{
    int tid = threadIdx.x;
    shared[tid] = val;
    __syncthreads();

    // Block-level tree reduction
    #pragma unroll
    for (int s = {block_size} / 2; s > 0; s >>= 1) {{
        if (tid < s) {{
            shared[tid] = {combine};
        }}
        __syncthreads();
    }}

    // Block leader atomically accumulates
    if (tid == 0) {{
        {atomic_fn}(accumulator, shared[0]);
    }}
}}

"#,
        ty = ty,
        op = op,
        block_size = block_size,
        combine = combine,
        atomic_fn = atomic_fn
    )
}

/// Generate a reduce-and-broadcast function.
fn generate_reduce_and_broadcast_fn(
    ty: &str,
    block_size: u32,
    op: &ReductionOp,
    use_cooperative: bool,
) -> String {
    let atomic_fn = op.atomic_fn();
    let combine = op.combine_expr("shared[tid]", "shared[tid + s]");

    let grid_sync = if use_cooperative {
        "    cg::grid_group grid = cg::this_grid();\n    grid.sync();"
    } else {
        // The barrier parameters are already pointers, so they are passed directly.
        "    __software_grid_sync(__barrier_counter, __barrier_gen, gridDim.x * gridDim.y * gridDim.z);"
    };

    let params = if use_cooperative {
        format!("{} val, {}* shared, {}* accumulator", ty, ty, ty)
    } else {
        format!(
            "{} val, {}* shared, {}* accumulator, unsigned int* __barrier_counter, unsigned int* __barrier_gen",
            ty, ty, ty
        )
    };

    format!(
        r#"
// Reduce-and-broadcast: all threads get the global {op} result
__device__ {ty} __reduce_and_broadcast_{op}({params}) {{
    int tid = threadIdx.x;
    shared[tid] = val;
    __syncthreads();

    // Phase 1: Block-level reduction
    #pragma unroll
    for (int s = {block_size} / 2; s > 0; s >>= 1) {{
        if (tid < s) {{
            shared[tid] = {combine};
        }}
        __syncthreads();
    }}

    // Phase 2: Block leader atomically accumulates
    if (tid == 0) {{
        {atomic_fn}(accumulator, shared[0]);
    }}

    // Phase 3: Grid-wide synchronization
{grid_sync}

    // Phase 4: All threads read the result
    return *accumulator;
}}

"#,
        ty = ty,
        op = op,
        block_size = block_size,
        combine = combine,
        atomic_fn = atomic_fn,
        params = params,
        grid_sync = grid_sync
    )
}

/// Generate a software grid barrier for non-cooperative kernels.
fn generate_software_barrier() -> String {
    r#"
// Software grid barrier using global memory atomics
// Use when cooperative groups are unavailable
__device__ void __software_grid_sync(
    volatile unsigned int* barrier_counter,
    volatile unsigned int* barrier_gen,
    unsigned int num_blocks
) {
    __syncthreads();

    if (threadIdx.x == 0) {
        unsigned int gen = *barrier_gen;
        unsigned int arrived = atomicAdd((unsigned int*)barrier_counter, 1) + 1;

        if (arrived == num_blocks) {
            // Last block to arrive: reset counter and increment generation
            *barrier_counter = 0;
            __threadfence();
            atomicAdd((unsigned int*)barrier_gen, 1);
        } else {
            // Wait for all blocks
            while (atomicAdd((unsigned int*)barrier_gen, 0) == gen) {
                __threadfence();
            }
        }
    }

    __syncthreads();
}

"#
    .to_string()
}

/// Transpile a reduction intrinsic call to CUDA code.
///
/// # Arguments
///
/// * `intrinsic` - The reduction intrinsic being called
/// * `args` - The transpiled argument expressions
/// * `config` - Reduction configuration
///
/// # Returns
///
/// The CUDA expression or statement for this intrinsic.
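///
/// A brief illustrative example (mirroring the unit tests below):
///
/// ```ignore
/// let config = ReductionCodegenConfig::default();
/// let expr = transpile_reduction_call(
///     ReductionIntrinsic::BlockReduceSum,
///     &["val".to_string(), "shared".to_string()],
///     &config,
/// )?;
/// assert_eq!(expr, "__block_reduce_sum(val, shared)");
/// ```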
pub fn transpile_reduction_call(
    intrinsic: ReductionIntrinsic,
    args: &[String],
    config: &ReductionCodegenConfig,
) -> Result<String, String> {
    match intrinsic {
        ReductionIntrinsic::BlockReduceSum => {
            validate_args(args, 2, "block_reduce_sum")?;
            Ok(format!("__block_reduce_sum({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::BlockReduceMin => {
            validate_args(args, 2, "block_reduce_min")?;
            Ok(format!("__block_reduce_min({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::BlockReduceMax => {
            validate_args(args, 2, "block_reduce_max")?;
            Ok(format!("__block_reduce_max({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::BlockReduceAnd => {
            validate_args(args, 2, "block_reduce_and")?;
            Ok(format!("__block_reduce_and({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::BlockReduceOr => {
            validate_args(args, 2, "block_reduce_or")?;
            Ok(format!("__block_reduce_or({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::GridReduceSum => {
            validate_args(args, 3, "grid_reduce_sum")?;
            Ok(format!(
                "__grid_reduce_sum({}, {}, {})",
                args[0], args[1], args[2]
            ))
        }
        ReductionIntrinsic::GridReduceMin => {
            validate_args(args, 3, "grid_reduce_min")?;
            Ok(format!(
                "__grid_reduce_min({}, {}, {})",
                args[0], args[1], args[2]
            ))
        }
        ReductionIntrinsic::GridReduceMax => {
            validate_args(args, 3, "grid_reduce_max")?;
            Ok(format!(
                "__grid_reduce_max({}, {}, {})",
                args[0], args[1], args[2]
            ))
        }
        ReductionIntrinsic::AtomicAccumulate => {
            validate_args(args, 2, "atomic_accumulate")?;
            Ok(format!("atomicAdd({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::AtomicAccumulateFetch => {
            validate_args(args, 2, "atomic_accumulate_fetch")?;
            Ok(format!("atomicAdd({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::BroadcastRead => {
            validate_args(args, 1, "broadcast_read")?;
            Ok(format!("(*{})", args[0]))
        }
        ReductionIntrinsic::ReduceAndBroadcast => {
            // The helper name tracks the configured reduction op, matching
            // generate_reduce_and_broadcast_fn.
            if config.use_cooperative {
                validate_args(args, 3, "reduce_and_broadcast")?;
                Ok(format!(
                    "__reduce_and_broadcast_{}({}, {}, {})",
                    config.op, args[0], args[1], args[2]
                ))
            } else {
                validate_args(args, 5, "reduce_and_broadcast (software barrier)")?;
                Ok(format!(
                    "__reduce_and_broadcast_{}({}, {}, {}, {}, {})",
                    config.op, args[0], args[1], args[2], args[3], args[4]
                ))
            }
        }
    }
}

fn validate_args(args: &[String], expected: usize, name: &str) -> Result<(), String> {
    if args.len() != expected {
        Err(format!(
            "{} requires {} arguments, got {}",
            name,
            expected,
            args.len()
        ))
    } else {
        Ok(())
    }
}

/// Generate inline reduction code (without helper function call).
///
/// Use this when you want the reduction logic inlined in a specific location.
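///
/// A hedged sketch of typical usage (the identifiers and the `kernel_body` buffer
/// are placeholders chosen for illustration):
///
/// ```ignore
/// let stmt = generate_inline_block_reduce(
///     "local_val",    // value expression
///     "s_scratch",    // shared-memory array name
///     "block_total",  // result variable to declare
///     "float",
///     256,
///     &ReductionOp::Sum,
/// );
/// kernel_body.push_str(&stmt);
/// ```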
pub fn generate_inline_block_reduce(
    value_expr: &str,
    shared_array: &str,
    result_var: &str,
    ty: &str,
    block_size: u32,
    op: &ReductionOp,
) -> String {
    let combine = op.combine_expr(
        &format!("{}[__tid]", shared_array),
        &format!("{}[__tid + __s]", shared_array),
    );

    // Declare the result variable before the braced block so it remains in
    // scope for the code that follows the reduction.
    format!(
        r#"{ty} {result_var};
{{
    int __tid = threadIdx.x;
    {shared_array}[__tid] = {value_expr};
    __syncthreads();

    #pragma unroll
    for (int __s = {block_size} / 2; __s > 0; __s >>= 1) {{
        if (__tid < __s) {{
            {shared_array}[__tid] = {combine};
        }}
        __syncthreads();
    }}

    {result_var} = {shared_array}[0];
}}"#,
        value_expr = value_expr,
        shared_array = shared_array,
        result_var = result_var,
        ty = ty,
        block_size = block_size,
        combine = combine
    )
}

/// Generate inline grid reduction with atomic accumulation.
pub fn generate_inline_grid_reduce(
    value_expr: &str,
    shared_array: &str,
    accumulator: &str,
    _ty: &str,
    block_size: u32,
    op: &ReductionOp,
) -> String {
    let atomic_fn = op.atomic_fn();
    let combine = op.combine_expr(
        &format!("{}[__tid]", shared_array),
        &format!("{}[__tid + __s]", shared_array),
    );

    format!(
        r#"{{
    int __tid = threadIdx.x;
    {shared_array}[__tid] = {value_expr};
    __syncthreads();

    #pragma unroll
    for (int __s = {block_size} / 2; __s > 0; __s >>= 1) {{
        if (__tid < __s) {{
            {shared_array}[__tid] = {combine};
        }}
        __syncthreads();
    }}

    if (__tid == 0) {{
        {atomic_fn}({accumulator}, {shared_array}[0]);
    }}
}}"#,
        value_expr = value_expr,
        shared_array = shared_array,
        accumulator = accumulator,
        block_size = block_size,
        combine = combine,
        atomic_fn = atomic_fn
    )
}

/// Generate inline reduce-and-broadcast code.
#[allow(clippy::too_many_arguments)]
pub fn generate_inline_reduce_and_broadcast(
    value_expr: &str,
    shared_array: &str,
    accumulator: &str,
    result_var: &str,
    ty: &str,
    block_size: u32,
    op: &ReductionOp,
    use_cooperative: bool,
) -> String {
    let atomic_fn = op.atomic_fn();
    let combine = op.combine_expr(
        &format!("{}[__tid]", shared_array),
        &format!("{}[__tid + __s]", shared_array),
    );

    let grid_sync = if use_cooperative {
        "    cg::grid_group __grid = cg::this_grid();\n    __grid.sync();"
    } else {
        // The barrier variables are expected to be pointers in the enclosing
        // kernel scope, so they are passed directly.
        "    __software_grid_sync(__barrier_counter, __barrier_gen, gridDim.x * gridDim.y * gridDim.z);"
    };

    // Declare the result variable before the braced block so it remains in
    // scope for the code that follows the reduction.
    format!(
        r#"{ty} {result_var};
{{
    int __tid = threadIdx.x;
    {shared_array}[__tid] = {value_expr};
    __syncthreads();

    // Block-level reduction
    #pragma unroll
    for (int __s = {block_size} / 2; __s > 0; __s >>= 1) {{
        if (__tid < __s) {{
            {shared_array}[__tid] = {combine};
        }}
        __syncthreads();
    }}

    // Atomic accumulation
    if (__tid == 0) {{
        {atomic_fn}({accumulator}, {shared_array}[0]);
    }}

    // Grid synchronization
{grid_sync}

    // Broadcast: all threads read result
    {result_var} = *{accumulator};
}}"#,
        value_expr = value_expr,
        shared_array = shared_array,
        accumulator = accumulator,
        result_var = result_var,
        ty = ty,
        block_size = block_size,
        combine = combine,
        atomic_fn = atomic_fn,
        grid_sync = grid_sync
    )
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reduction_op_combine() {
        assert_eq!(ReductionOp::Sum.combine_expr("a", "b"), "a + b");
        assert_eq!(ReductionOp::Min.combine_expr("a", "b"), "min(a, b)");
        assert_eq!(ReductionOp::Max.combine_expr("a", "b"), "max(a, b)");
        assert_eq!(ReductionOp::And.combine_expr("a", "b"), "a & b");
    }

    #[test]
    fn test_reduction_op_identity() {
        assert_eq!(ReductionOp::Sum.identity("float"), "0");
        assert_eq!(ReductionOp::Min.identity("float"), "INFINITY");
        assert_eq!(ReductionOp::Max.identity("float"), "-INFINITY");
        assert_eq!(ReductionOp::Product.identity("int"), "1");
    }

    #[test]
    fn test_intrinsic_from_name() {
        assert_eq!(
            ReductionIntrinsic::from_name("block_reduce_sum"),
            Some(ReductionIntrinsic::BlockReduceSum)
        );
        assert_eq!(
            ReductionIntrinsic::from_name("grid_reduce_min"),
            Some(ReductionIntrinsic::GridReduceMin)
        );
        assert_eq!(
            ReductionIntrinsic::from_name("reduce_and_broadcast"),
            Some(ReductionIntrinsic::ReduceAndBroadcast)
        );
        assert_eq!(ReductionIntrinsic::from_name("not_an_intrinsic"), None);
    }

    #[test]
    fn test_generate_block_reduce() {
        let code = generate_block_reduce_fn("float", 256, &ReductionOp::Sum);
        assert!(code.contains("__device__ float __block_reduce_sum"));
        assert!(code.contains("shared[tid] = shared[tid] + shared[tid + s]"));
        assert!(code.contains("__syncthreads()"));
    }

    #[test]
    fn test_generate_grid_reduce() {
        let code = generate_grid_reduce_fn("double", 128, &ReductionOp::Max);
        assert!(code.contains("__device__ void __grid_reduce_max"));
        assert!(code.contains("atomicMax(accumulator, shared[0])"));
    }

    #[test]
    fn test_generate_reduce_and_broadcast_cooperative() {
        let code = generate_reduce_and_broadcast_fn("float", 256, &ReductionOp::Sum, true);
        assert!(code.contains("cg::grid_group grid = cg::this_grid()"));
        assert!(code.contains("grid.sync()"));
    }

    #[test]
    fn test_generate_reduce_and_broadcast_software() {
        let code = generate_reduce_and_broadcast_fn("float", 256, &ReductionOp::Sum, false);
        assert!(code.contains("__barrier_counter"));
        assert!(code.contains("__software_grid_sync"));
    }

    #[test]
    fn test_inline_block_reduce() {
        let code = generate_inline_block_reduce(
            "my_value",
            "shared_mem",
            "result",
            "float",
            256,
            &ReductionOp::Sum,
        );
        assert!(code.contains("shared_mem[__tid] = my_value"));
        assert!(code.contains("float result;"));
        assert!(code.contains("result = shared_mem[0]"));
    }

    #[test]
    fn test_transpile_reduction_call() {
        let config = ReductionCodegenConfig::default();

        let result = transpile_reduction_call(
            ReductionIntrinsic::BlockReduceSum,
            &["val".to_string(), "shared".to_string()],
            &config,
        );
        assert_eq!(result.unwrap(), "__block_reduce_sum(val, shared)");

        let result = transpile_reduction_call(
            ReductionIntrinsic::AtomicAccumulate,
            &["&accum".to_string(), "val".to_string()],
            &config,
        );
        assert_eq!(result.unwrap(), "atomicAdd(&accum, val)");
    }

    #[test]
    fn test_reduction_helpers() {
        let config = ReductionCodegenConfig::new(256)
            .with_type("float")
            .with_op(ReductionOp::Sum)
            .with_cooperative(true);

        let helpers = generate_reduction_helpers(&config);
        assert!(helpers.contains("#include <cooperative_groups.h>"));
        assert!(helpers.contains("__block_reduce_sum"));
        assert!(helpers.contains("__grid_reduce_sum"));
        assert!(helpers.contains("__reduce_and_broadcast_sum"));
    }
}