ringkernel_cuda_codegen/reduction_intrinsics.rs

//! Reduction Intrinsics for CUDA Code Generation
//!
//! This module provides DSL intrinsics for efficient parallel reductions that
//! transpile to optimized CUDA code. Reductions aggregate values across threads
//! using operations like sum, min, max, etc.
//!
//! # Reduction Hierarchy
//!
//! ```text
//! ┌─────────────────────────────────────────────────────────────────┐
//! │                     Grid-Level Reduction                        │
//! │  ┌───────────────────────────────────────────────────────────┐  │
//! │  │                    Block 0                                │  │
//! │  │  ┌─────────┐ ┌─────────┐ ┌─────────┐ ┌─────────┐         │  │
//! │  │  │ Warp 0  │ │ Warp 1  │ │ Warp 2  │ │ Warp N  │         │  │
//! │  │  │ shuffle │ │ shuffle │ │ shuffle │ │ shuffle │         │  │
//! │  │  └────┬────┘ └────┬────┘ └────┬────┘ └────┬────┘         │  │
//! │  │       └──────────────┴──────────┴──────────┘             │  │
//! │  │                      │ shared memory                     │  │
//! │  │                      ▼                                   │  │
//! │  │              block_reduce_sum()                          │  │
//! │  │                      │                                   │  │
//! │  └──────────────────────┼───────────────────────────────────┘  │
//! │                         ▼ atomicAdd                            │
//! │  ┌─────────────────────────────────────────────────────────┐  │
//! │  │              Global Accumulator (mapped memory)          │  │
//! │  └─────────────────────────────────────────────────────────┘  │
//! │                         │ grid.sync() / barrier                │
//! │                         ▼                                      │
//! │              All threads read final result                     │
//! └─────────────────────────────────────────────────────────────────┘
//! ```
//!
//! # Example
//!
//! ```ignore
//! // Rust DSL
//! fn pagerank_phase1(ranks: &[f64], out_degree: &[u32], dangling_sum: &mut f64, n: u32) {
//!     let idx = global_thread_idx();
//!     if idx >= n { return; }
//!
//!     let contrib = if out_degree[idx] == 0 { ranks[idx] } else { 0.0 };
//!     reduce_and_broadcast(contrib, dangling_sum);  // All threads get sum
//! }
//! ```
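//!
//! Conceptually, the `reduce_and_broadcast` call above lowers to a call of the
//! generated `__reduce_and_broadcast_sum` CUDA helper, roughly as sketched below
//! (illustrative only; `__shared_buf` stands in for a hypothetical per-block
//! shared-memory scratch array):
//!
//! ```text
//! double __sum = __reduce_and_broadcast_sum(contrib, __shared_buf, dangling_sum);
//! ```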

use std::fmt;

/// Reduction operation types for code generation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum ReductionOp {
    /// Sum: `a + b`
    #[default]
    Sum,
    /// Minimum: `min(a, b)`
    Min,
    /// Maximum: `max(a, b)`
    Max,
    /// Bitwise AND: `a & b`
    And,
    /// Bitwise OR: `a | b`
    Or,
    /// Bitwise XOR: `a ^ b`
    Xor,
    /// Product: `a * b`
    Product,
}

impl ReductionOp {
    /// Get the binary operator (or function name) used to combine two values.
    pub fn operator(&self) -> &'static str {
        match self {
            ReductionOp::Sum => "+",
            ReductionOp::Min => "min",
            ReductionOp::Max => "max",
            ReductionOp::And => "&",
            ReductionOp::Or => "|",
            ReductionOp::Xor => "^",
            ReductionOp::Product => "*",
        }
    }

    /// Get the CUDA atomic function name.
    pub fn atomic_fn(&self) -> &'static str {
        match self {
            ReductionOp::Sum => "atomicAdd",
            ReductionOp::Min => "atomicMin",
            ReductionOp::Max => "atomicMax",
            ReductionOp::And => "atomicAnd",
            ReductionOp::Or => "atomicOr",
            ReductionOp::Xor => "atomicXor",
            ReductionOp::Product => "atomicMul", // Not a CUDA builtin; requires a custom implementation (e.g. an atomicCAS loop)
        }
    }

    /// Get the identity value expression for this operation.
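    ///
    /// For example (mirroring the unit tests), `ReductionOp::Min.identity("float")`
    /// returns `"INFINITY"` and `ReductionOp::Max.identity("int")` returns `"INT_MIN"`,
    /// so inactive lanes do not perturb the reduction.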
    pub fn identity(&self, ty: &str) -> String {
        match (self, ty) {
            (ReductionOp::Sum, _) => "0".to_string(),
            (ReductionOp::Min, "float") => "INFINITY".to_string(),
            (ReductionOp::Min, "double") => "HUGE_VAL".to_string(),
            (ReductionOp::Min, "int") => "INT_MAX".to_string(),
            (ReductionOp::Min, "long long") => "LLONG_MAX".to_string(),
            (ReductionOp::Min, "unsigned int") => "UINT_MAX".to_string(),
            (ReductionOp::Min, "unsigned long long") => "ULLONG_MAX".to_string(),
            (ReductionOp::Max, "float") => "-INFINITY".to_string(),
            (ReductionOp::Max, "double") => "-HUGE_VAL".to_string(),
            (ReductionOp::Max, "int") => "INT_MIN".to_string(),
            (ReductionOp::Max, "long long") => "LLONG_MIN".to_string(),
            (ReductionOp::Max, _) => "0".to_string(),
            (ReductionOp::And, _) => "~0".to_string(),
            (ReductionOp::Or, _) => "0".to_string(),
            (ReductionOp::Xor, _) => "0".to_string(),
            (ReductionOp::Product, _) => "1".to_string(),
            _ => "0".to_string(),
        }
    }

    /// Get the combination expression for two values.
    pub fn combine_expr(&self, a: &str, b: &str) -> String {
        match self {
            ReductionOp::Sum => format!("{} + {}", a, b),
            ReductionOp::Min => format!("min({}, {})", a, b),
            ReductionOp::Max => format!("max({}, {})", a, b),
            ReductionOp::And => format!("{} & {}", a, b),
            ReductionOp::Or => format!("{} | {}", a, b),
            ReductionOp::Xor => format!("{} ^ {}", a, b),
            ReductionOp::Product => format!("{} * {}", a, b),
        }
    }
}

impl fmt::Display for ReductionOp {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            ReductionOp::Sum => write!(f, "sum"),
            ReductionOp::Min => write!(f, "min"),
            ReductionOp::Max => write!(f, "max"),
            ReductionOp::And => write!(f, "and"),
            ReductionOp::Or => write!(f, "or"),
            ReductionOp::Xor => write!(f, "xor"),
            ReductionOp::Product => write!(f, "product"),
        }
    }
}

/// Configuration for reduction code generation.
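///
/// # Example
///
/// A minimal sketch using the builder methods defined below (fields left unset
/// keep the values from `Default`):
///
/// ```ignore
/// let config = ReductionCodegenConfig::new(128)
///     .with_type("double")
///     .with_op(ReductionOp::Max)
///     .with_cooperative(false);
/// ```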
#[derive(Debug, Clone)]
pub struct ReductionCodegenConfig {
    /// Thread block size (must be power of 2).
    pub block_size: u32,
    /// Value type (CUDA type name).
    pub value_type: String,
    /// Reduction operation.
    pub op: ReductionOp,
    /// Use cooperative groups for grid-wide sync.
    pub use_cooperative: bool,
    /// Generate helper function definitions.
    pub generate_helpers: bool,
}

impl Default for ReductionCodegenConfig {
    fn default() -> Self {
        Self {
            block_size: 256,
            value_type: "float".to_string(),
            op: ReductionOp::Sum,
            use_cooperative: true,
            generate_helpers: true,
        }
    }
}

impl ReductionCodegenConfig {
    /// Create a new config with the given block size.
    pub fn new(block_size: u32) -> Self {
        Self {
            block_size,
            ..Default::default()
        }
    }

    /// Set the value type.
    pub fn with_type(mut self, ty: &str) -> Self {
        self.value_type = ty.to_string();
        self
    }

    /// Set the reduction operation.
    pub fn with_op(mut self, op: ReductionOp) -> Self {
        self.op = op;
        self
    }

    /// Enable or disable cooperative groups.
    pub fn with_cooperative(mut self, enabled: bool) -> Self {
        self.use_cooperative = enabled;
        self
    }
}

/// Reduction intrinsic types for the DSL.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReductionIntrinsic {
    /// Block-level sum reduction: `block_reduce_sum(value, shared_mem)`
    BlockReduceSum,
    /// Block-level min reduction: `block_reduce_min(value, shared_mem)`
    BlockReduceMin,
    /// Block-level max reduction: `block_reduce_max(value, shared_mem)`
    BlockReduceMax,
    /// Block-level AND reduction: `block_reduce_and(value, shared_mem)`
    BlockReduceAnd,
    /// Block-level OR reduction: `block_reduce_or(value, shared_mem)`
    BlockReduceOr,

    /// Grid-level sum with atomic accumulation: `grid_reduce_sum(value, shared, accumulator)`
    GridReduceSum,
    /// Grid-level min with atomic accumulation: `grid_reduce_min(value, shared, accumulator)`
    GridReduceMin,
    /// Grid-level max with atomic accumulation: `grid_reduce_max(value, shared, accumulator)`
    GridReduceMax,

    /// Atomic accumulate to buffer: `atomic_accumulate(accumulator, value)`
    AtomicAccumulate,
    /// Atomic accumulate with fetch: `atomic_accumulate_fetch(accumulator, value)`
    AtomicAccumulateFetch,

    /// Read from reduction buffer (broadcast): `broadcast_read(buffer)`
    BroadcastRead,

    /// Full reduce-and-broadcast: `reduce_and_broadcast(value, shared, accumulator)`
    /// Returns the global result to all threads.
    ReduceAndBroadcast,
}

impl ReductionIntrinsic {
    /// Parse intrinsic from DSL function name.
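    ///
    /// For example, `ReductionIntrinsic::from_name("grid_reduce_min")` returns
    /// `Some(ReductionIntrinsic::GridReduceMin)`; unrecognized names return `None`.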
    pub fn from_name(name: &str) -> Option<Self> {
        match name {
            "block_reduce_sum" => Some(ReductionIntrinsic::BlockReduceSum),
            "block_reduce_min" => Some(ReductionIntrinsic::BlockReduceMin),
            "block_reduce_max" => Some(ReductionIntrinsic::BlockReduceMax),
            "block_reduce_and" => Some(ReductionIntrinsic::BlockReduceAnd),
            "block_reduce_or" => Some(ReductionIntrinsic::BlockReduceOr),
            "grid_reduce_sum" => Some(ReductionIntrinsic::GridReduceSum),
            "grid_reduce_min" => Some(ReductionIntrinsic::GridReduceMin),
            "grid_reduce_max" => Some(ReductionIntrinsic::GridReduceMax),
            "atomic_accumulate" => Some(ReductionIntrinsic::AtomicAccumulate),
            "atomic_accumulate_fetch" => Some(ReductionIntrinsic::AtomicAccumulateFetch),
            "broadcast_read" => Some(ReductionIntrinsic::BroadcastRead),
            "reduce_and_broadcast" => Some(ReductionIntrinsic::ReduceAndBroadcast),
            _ => None,
        }
    }

    /// Get the reduction operation for this intrinsic.
    pub fn op(&self) -> Option<ReductionOp> {
        match self {
            ReductionIntrinsic::BlockReduceSum | ReductionIntrinsic::GridReduceSum => {
                Some(ReductionOp::Sum)
            }
            ReductionIntrinsic::BlockReduceMin | ReductionIntrinsic::GridReduceMin => {
                Some(ReductionOp::Min)
            }
            ReductionIntrinsic::BlockReduceMax | ReductionIntrinsic::GridReduceMax => {
                Some(ReductionOp::Max)
            }
            ReductionIntrinsic::BlockReduceAnd => Some(ReductionOp::And),
            ReductionIntrinsic::BlockReduceOr => Some(ReductionOp::Or),
            _ => None,
        }
    }
}

// =============================================================================
// Code Generation Functions
// =============================================================================

/// Generate CUDA helper functions for reduction operations.
///
/// This should be included at the top of kernels that use reduction intrinsics.
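///
/// # Example
///
/// A sketch of emitting the preamble ahead of a kernel definition (illustrative;
/// `kernel_body` is a hypothetical string holding the rest of the kernel source):
///
/// ```ignore
/// let config = ReductionCodegenConfig::new(256).with_type("float");
/// let mut source = generate_reduction_helpers(&config);
/// source.push_str(&kernel_body);
/// ```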
pub fn generate_reduction_helpers(config: &ReductionCodegenConfig) -> String {
    let mut code = String::new();

    // Includes
    if config.use_cooperative {
        code.push_str("#include <cooperative_groups.h>\n");
        code.push_str("namespace cg = cooperative_groups;\n\n");
    }
    code.push_str("#include <limits.h>\n");
    code.push_str("#include <math.h>\n\n");

    // Software barrier (if not using cooperative groups). Emitted first so it
    // is defined before the reduce-and-broadcast helper that calls it.
    if !config.use_cooperative {
        code.push_str(&generate_software_barrier());
    }

    // Block-level reduction function
    code.push_str(&generate_block_reduce_fn(
        &config.value_type,
        config.block_size,
        &config.op,
    ));

    // Grid-level reduction function
    code.push_str(&generate_grid_reduce_fn(
        &config.value_type,
        config.block_size,
        &config.op,
    ));

    // Reduce-and-broadcast function
    code.push_str(&generate_reduce_and_broadcast_fn(
        &config.value_type,
        config.block_size,
        &config.op,
        config.use_cooperative,
    ));

    code
}

/// Generate the warp-shuffle combine expression for a reduction operation.
fn shfl_combine_expr(op: &ReductionOp, val: &str, offset: &str) -> String {
    let shfl = format!("__shfl_down_sync(0xFFFFFFFF, {}, {})", val, offset);
    op.combine_expr(val, &shfl)
}

/// Generate a block-level reduction function using two-phase warp-shuffle.
fn generate_block_reduce_fn(ty: &str, block_size: u32, op: &ReductionOp) -> String {
    let warp_combine = shfl_combine_expr(op, "val", "offset");
    let cross_warp_combine = shfl_combine_expr(op, "val", "offset");
    let num_warps = block_size / 32;
    let identity = op.identity(ty);

    format!(
        r#"
// Block-level {op} reduction using warp-shuffle + shared memory
__device__ {ty} __block_reduce_{op}({ty} val, {ty}* shared) {{
    int tid = threadIdx.x;
    int warp_id = tid / 32;
    int lane_id = tid % 32;

    // Phase 1: Intra-warp reduction via shuffle
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {{
        val = {warp_combine};
    }}

    // Warp leaders store partial results
    if (lane_id == 0) {{
        shared[warp_id] = val;
    }}
    __syncthreads();

    // Phase 2: First warp reduces across warp results
    val = (tid < {num_warps}) ? shared[tid] : {identity};
    if (warp_id == 0) {{
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {{
            val = {cross_warp_combine};
        }}
    }}

    // Note: the fully reduced value is valid only in thread 0; other threads
    // hold partial results.
    return val;
}}

"#,
        ty = ty,
        op = op,
        num_warps = num_warps,
        identity = identity,
        warp_combine = warp_combine,
        cross_warp_combine = cross_warp_combine
    )
}

/// Generate a grid-level reduction function with atomic accumulation.
fn generate_grid_reduce_fn(ty: &str, block_size: u32, op: &ReductionOp) -> String {
    let atomic_fn = op.atomic_fn();
    let warp_combine = shfl_combine_expr(op, "val", "offset");
    let cross_warp_combine = shfl_combine_expr(op, "val", "offset");
    let num_warps = block_size / 32;
    let identity = op.identity(ty);

    format!(
        r#"
// Grid-level {op} reduction with warp-shuffle + atomic accumulation
__device__ void __grid_reduce_{op}({ty} val, {ty}* shared, {ty}* accumulator) {{
    int tid = threadIdx.x;
    int warp_id = tid / 32;
    int lane_id = tid % 32;

    // Phase 1: Intra-warp reduction via shuffle
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {{
        val = {warp_combine};
    }}

    // Warp leaders store partial results
    if (lane_id == 0) {{
        shared[warp_id] = val;
    }}
    __syncthreads();

    // Phase 2: First warp reduces across warp results
    val = (tid < {num_warps}) ? shared[tid] : {identity};
    if (warp_id == 0) {{
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {{
            val = {cross_warp_combine};
        }}
    }}

    // Block leader atomically accumulates
    if (tid == 0) {{
        {atomic_fn}(accumulator, val);
    }}
}}

"#,
        ty = ty,
        op = op,
        num_warps = num_warps,
        identity = identity,
        warp_combine = warp_combine,
        cross_warp_combine = cross_warp_combine,
        atomic_fn = atomic_fn
    )
}

/// Generate a reduce-and-broadcast function.
fn generate_reduce_and_broadcast_fn(
    ty: &str,
    block_size: u32,
    op: &ReductionOp,
    use_cooperative: bool,
) -> String {
    let atomic_fn = op.atomic_fn();
    let warp_combine = shfl_combine_expr(op, "val", "offset");
    let cross_warp_combine = shfl_combine_expr(op, "val", "offset");
    let num_warps = block_size / 32;
    let identity = op.identity(ty);

    let grid_sync = if use_cooperative {
        "    cg::grid_group grid = cg::this_grid();\n    grid.sync();"
    } else {
        "    __software_grid_sync(__barrier_counter, __barrier_gen, gridDim.x * gridDim.y * gridDim.z);"
    };

    let params = if use_cooperative {
        format!("{} val, {}* shared, {}* accumulator", ty, ty, ty)
    } else {
        format!(
            "{} val, {}* shared, {}* accumulator, unsigned int* __barrier_counter, unsigned int* __barrier_gen",
            ty, ty, ty
        )
    };

    format!(
        r#"
// Reduce-and-broadcast: all threads get the global {op} result (warp-shuffle)
__device__ {ty} __reduce_and_broadcast_{op}({params}) {{
    int tid = threadIdx.x;
    int warp_id = tid / 32;
    int lane_id = tid % 32;

    // Phase 1: Intra-warp reduction via shuffle
    #pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1) {{
        val = {warp_combine};
    }}

    // Warp leaders store partial results
    if (lane_id == 0) {{
        shared[warp_id] = val;
    }}
    __syncthreads();

    // Phase 2: First warp reduces across warp results
    val = (tid < {num_warps}) ? shared[tid] : {identity};
    if (warp_id == 0) {{
        #pragma unroll
        for (int offset = 16; offset > 0; offset >>= 1) {{
            val = {cross_warp_combine};
        }}
    }}

    // Phase 3: Block leader atomically accumulates
    if (tid == 0) {{
        {atomic_fn}(accumulator, val);
    }}

    // Phase 4: Grid-wide synchronization
{grid_sync}

    // Phase 5: All threads read the result
    return *accumulator;
}}

"#,
        ty = ty,
        op = op,
        num_warps = num_warps,
        identity = identity,
        warp_combine = warp_combine,
        cross_warp_combine = cross_warp_combine,
        atomic_fn = atomic_fn,
        params = params,
        grid_sync = grid_sync
    )
}

/// Generate a software grid barrier for non-cooperative kernels.
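///
/// The emitted `__software_grid_sync` expects `barrier_counter` and `barrier_gen`
/// to live in global memory and to be zero-initialized before launch. Every
/// launched block must reach the barrier (and all blocks must be co-resident on
/// the device), otherwise the spin-wait never completes.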
fn generate_software_barrier() -> String {
    r#"
// Software grid barrier using global memory atomics
// Use when cooperative groups are unavailable
__device__ void __software_grid_sync(
    volatile unsigned int* barrier_counter,
    volatile unsigned int* barrier_gen,
    unsigned int num_blocks
) {
    __syncthreads();

    if (threadIdx.x == 0) {
        unsigned int gen = *barrier_gen;
        unsigned int arrived = atomicAdd((unsigned int*)barrier_counter, 1) + 1;

        if (arrived == num_blocks) {
            // Last block to arrive: reset counter and increment generation
            *barrier_counter = 0;
            __threadfence();
            atomicAdd((unsigned int*)barrier_gen, 1);
        } else {
            // Wait for all blocks
            while (atomicAdd((unsigned int*)barrier_gen, 0) == gen) {
                __threadfence();
            }
        }
    }

    __syncthreads();
}

"#
    .to_string()
}

/// Transpile a reduction intrinsic call to CUDA code.
///
/// # Arguments
///
/// * `intrinsic` - The reduction intrinsic being called
/// * `args` - The transpiled argument expressions
/// * `config` - Reduction configuration
///
/// # Returns
///
/// The CUDA expression or statement for this intrinsic.
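///
/// # Example
///
/// A minimal sketch (mirrors the unit tests below):
///
/// ```ignore
/// let config = ReductionCodegenConfig::default();
/// let cuda = transpile_reduction_call(
///     ReductionIntrinsic::GridReduceSum,
///     &["val".to_string(), "shared".to_string(), "acc".to_string()],
///     &config,
/// )
/// .unwrap();
/// assert_eq!(cuda, "__grid_reduce_sum(val, shared, acc)");
/// ```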
pub fn transpile_reduction_call(
    intrinsic: ReductionIntrinsic,
    args: &[String],
    config: &ReductionCodegenConfig,
) -> Result<String, String> {
    match intrinsic {
        ReductionIntrinsic::BlockReduceSum => {
            validate_args(args, 2, "block_reduce_sum")?;
            Ok(format!("__block_reduce_sum({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::BlockReduceMin => {
            validate_args(args, 2, "block_reduce_min")?;
            Ok(format!("__block_reduce_min({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::BlockReduceMax => {
            validate_args(args, 2, "block_reduce_max")?;
            Ok(format!("__block_reduce_max({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::BlockReduceAnd => {
            validate_args(args, 2, "block_reduce_and")?;
            Ok(format!("__block_reduce_and({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::BlockReduceOr => {
            validate_args(args, 2, "block_reduce_or")?;
            Ok(format!("__block_reduce_or({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::GridReduceSum => {
            validate_args(args, 3, "grid_reduce_sum")?;
            Ok(format!(
                "__grid_reduce_sum({}, {}, {})",
                args[0], args[1], args[2]
            ))
        }
        ReductionIntrinsic::GridReduceMin => {
            validate_args(args, 3, "grid_reduce_min")?;
            Ok(format!(
                "__grid_reduce_min({}, {}, {})",
                args[0], args[1], args[2]
            ))
        }
        ReductionIntrinsic::GridReduceMax => {
            validate_args(args, 3, "grid_reduce_max")?;
            Ok(format!(
                "__grid_reduce_max({}, {}, {})",
                args[0], args[1], args[2]
            ))
        }
        ReductionIntrinsic::AtomicAccumulate => {
            validate_args(args, 2, "atomic_accumulate")?;
            Ok(format!("atomicAdd({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::AtomicAccumulateFetch => {
            validate_args(args, 2, "atomic_accumulate_fetch")?;
            Ok(format!("atomicAdd({}, {})", args[0], args[1]))
        }
        ReductionIntrinsic::BroadcastRead => {
            validate_args(args, 1, "broadcast_read")?;
            Ok(format!("(*{})", args[0]))
        }
        ReductionIntrinsic::ReduceAndBroadcast => {
            // Dispatch to the helper generated for the configured reduction op,
            // so a non-Sum config links against the matching helper name.
            if config.use_cooperative {
                validate_args(args, 3, "reduce_and_broadcast")?;
                Ok(format!(
                    "__reduce_and_broadcast_{}({}, {}, {})",
                    config.op, args[0], args[1], args[2]
                ))
            } else {
                validate_args(args, 5, "reduce_and_broadcast (software barrier)")?;
                Ok(format!(
                    "__reduce_and_broadcast_{}({}, {}, {}, {}, {})",
                    config.op, args[0], args[1], args[2], args[3], args[4]
                ))
            }
        }
    }
}

fn validate_args(args: &[String], expected: usize, name: &str) -> Result<(), String> {
    if args.len() != expected {
        Err(format!(
            "{} requires {} arguments, got {}",
            name,
            expected,
            args.len()
        ))
    } else {
        Ok(())
    }
}

/// Generate inline reduction code (without helper function call).
///
/// Use this when you want the reduction logic inlined in a specific location.
/// Uses two-phase warp-shuffle reduction for efficiency.
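///
/// # Example
///
/// A sketch of splicing an inline block sum into a kernel body being assembled
/// (illustrative; `body` is a hypothetical `String` holding the kernel source):
///
/// ```ignore
/// let stmt = generate_inline_block_reduce(
///     "contrib", "shared_mem", "block_total", "float", 256, &ReductionOp::Sum,
/// );
/// body.push_str(&stmt);
/// ```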
pub fn generate_inline_block_reduce(
    value_expr: &str,
    shared_array: &str,
    result_var: &str,
    ty: &str,
    block_size: u32,
    op: &ReductionOp,
) -> String {
    let warp_combine = shfl_combine_expr(op, "__val", "__offset");
    let cross_warp_combine = shfl_combine_expr(op, "__val", "__offset");
    let num_warps = block_size / 32;
    let identity = op.identity(ty);

    format!(
        r#"{{
    int __tid = threadIdx.x;
    int __warp_id = __tid / 32;
    int __lane_id = __tid % 32;
    {ty} __val = {value_expr};

    // Phase 1: Intra-warp reduction via shuffle
    #pragma unroll
    for (int __offset = 16; __offset > 0; __offset >>= 1) {{
        __val = {warp_combine};
    }}

    // Warp leaders store partial results
    if (__lane_id == 0) {{
        {shared_array}[__warp_id] = __val;
    }}
    __syncthreads();

    // Phase 2: First warp reduces across warp results
    __val = (__tid < {num_warps}) ? {shared_array}[__tid] : {identity};
    if (__warp_id == 0) {{
        #pragma unroll
        for (int __offset = 16; __offset > 0; __offset >>= 1) {{
            __val = {cross_warp_combine};
        }}
    }}

    if (__tid == 0) {{
        {shared_array}[0] = __val;
    }}
    // Make the leader's store visible before all threads read the result
    __syncthreads();

    {ty} {result_var} = {shared_array}[0];
}}"#,
        value_expr = value_expr,
        shared_array = shared_array,
        result_var = result_var,
        ty = ty,
        num_warps = num_warps,
        identity = identity,
        warp_combine = warp_combine,
        cross_warp_combine = cross_warp_combine
    )
}

/// Generate inline grid reduction with atomic accumulation.
pub fn generate_inline_grid_reduce(
    value_expr: &str,
    shared_array: &str,
    accumulator: &str,
    ty: &str,
    block_size: u32,
    op: &ReductionOp,
) -> String {
    let atomic_fn = op.atomic_fn();
    let warp_combine = shfl_combine_expr(op, "__val", "__offset");
    let cross_warp_combine = shfl_combine_expr(op, "__val", "__offset");
    let num_warps = block_size / 32;
    let identity = op.identity(ty);

    format!(
        r#"{{
    int __tid = threadIdx.x;
    int __warp_id = __tid / 32;
    int __lane_id = __tid % 32;
    {ty} __val = {value_expr};

    // Phase 1: Intra-warp reduction via shuffle
    #pragma unroll
    for (int __offset = 16; __offset > 0; __offset >>= 1) {{
        __val = {warp_combine};
    }}

    // Warp leaders store partial results
    if (__lane_id == 0) {{
        {shared_array}[__warp_id] = __val;
    }}
    __syncthreads();

    // Phase 2: First warp reduces across warp results
    __val = (__tid < {num_warps}) ? {shared_array}[__tid] : {identity};
    if (__warp_id == 0) {{
        #pragma unroll
        for (int __offset = 16; __offset > 0; __offset >>= 1) {{
            __val = {cross_warp_combine};
        }}
    }}

    if (__tid == 0) {{
        {atomic_fn}({accumulator}, __val);
    }}
}}"#,
        ty = ty,
        value_expr = value_expr,
        shared_array = shared_array,
        accumulator = accumulator,
        num_warps = num_warps,
        identity = identity,
        warp_combine = warp_combine,
        cross_warp_combine = cross_warp_combine,
        atomic_fn = atomic_fn
    )
}

/// Generate inline reduce-and-broadcast code.
#[allow(clippy::too_many_arguments)]
pub fn generate_inline_reduce_and_broadcast(
    value_expr: &str,
    shared_array: &str,
    accumulator: &str,
    result_var: &str,
    ty: &str,
    block_size: u32,
    op: &ReductionOp,
    use_cooperative: bool,
) -> String {
    let atomic_fn = op.atomic_fn();
    let warp_combine = shfl_combine_expr(op, "__val", "__offset");
    let cross_warp_combine = shfl_combine_expr(op, "__val", "__offset");
    let num_warps = block_size / 32;
    let identity = op.identity(ty);

    // In the non-cooperative case, `__barrier_counter` and `__barrier_gen` must be
    // `unsigned int*` pointers to global memory available in the enclosing scope.
    let grid_sync = if use_cooperative {
        "    cg::grid_group __grid = cg::this_grid();\n    __grid.sync();"
    } else {
        "    __software_grid_sync(__barrier_counter, __barrier_gen, gridDim.x * gridDim.y * gridDim.z);"
    };

    format!(
        r#"{{
    int __tid = threadIdx.x;
    int __warp_id = __tid / 32;
    int __lane_id = __tid % 32;
    {ty} __val = {value_expr};

    // Phase 1: Intra-warp reduction via shuffle
    #pragma unroll
    for (int __offset = 16; __offset > 0; __offset >>= 1) {{
        __val = {warp_combine};
    }}

    // Warp leaders store partial results
    if (__lane_id == 0) {{
        {shared_array}[__warp_id] = __val;
    }}
    __syncthreads();

    // Phase 2: First warp reduces across warp results
    __val = (__tid < {num_warps}) ? {shared_array}[__tid] : {identity};
    if (__warp_id == 0) {{
        #pragma unroll
        for (int __offset = 16; __offset > 0; __offset >>= 1) {{
            __val = {cross_warp_combine};
        }}
    }}

    // Atomic accumulation
    if (__tid == 0) {{
        {atomic_fn}({accumulator}, __val);
    }}

    // Grid synchronization
{grid_sync}

    // Broadcast: all threads read result
    {ty} {result_var} = *{accumulator};
}}"#,
        value_expr = value_expr,
        shared_array = shared_array,
        accumulator = accumulator,
        result_var = result_var,
        ty = ty,
        num_warps = num_warps,
        identity = identity,
        warp_combine = warp_combine,
        cross_warp_combine = cross_warp_combine,
        atomic_fn = atomic_fn,
        grid_sync = grid_sync
    )
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_reduction_op_combine() {
        assert_eq!(ReductionOp::Sum.combine_expr("a", "b"), "a + b");
        assert_eq!(ReductionOp::Min.combine_expr("a", "b"), "min(a, b)");
        assert_eq!(ReductionOp::Max.combine_expr("a", "b"), "max(a, b)");
        assert_eq!(ReductionOp::And.combine_expr("a", "b"), "a & b");
    }

    #[test]
    fn test_reduction_op_identity() {
        assert_eq!(ReductionOp::Sum.identity("float"), "0");
        assert_eq!(ReductionOp::Min.identity("float"), "INFINITY");
        assert_eq!(ReductionOp::Max.identity("float"), "-INFINITY");
        assert_eq!(ReductionOp::Product.identity("int"), "1");
    }

    #[test]
    fn test_intrinsic_from_name() {
        assert_eq!(
            ReductionIntrinsic::from_name("block_reduce_sum"),
            Some(ReductionIntrinsic::BlockReduceSum)
        );
        assert_eq!(
            ReductionIntrinsic::from_name("grid_reduce_min"),
            Some(ReductionIntrinsic::GridReduceMin)
        );
        assert_eq!(
            ReductionIntrinsic::from_name("reduce_and_broadcast"),
            Some(ReductionIntrinsic::ReduceAndBroadcast)
        );
        assert_eq!(ReductionIntrinsic::from_name("not_an_intrinsic"), None);
    }

    #[test]
    fn test_generate_block_reduce() {
        let code = generate_block_reduce_fn("float", 256, &ReductionOp::Sum);
        assert!(code.contains("__device__ float __block_reduce_sum"));
        assert!(code.contains("__shfl_down_sync(0xFFFFFFFF, val, offset)"));
        assert!(code.contains("__syncthreads()"));
        assert!(code.contains("warp_id"));
        assert!(code.contains("lane_id"));
    }

    #[test]
    fn test_generate_grid_reduce() {
        let code = generate_grid_reduce_fn("double", 128, &ReductionOp::Max);
        assert!(code.contains("__device__ void __grid_reduce_max"));
        assert!(code.contains("__shfl_down_sync(0xFFFFFFFF, val, offset)"));
        assert!(code.contains("atomicMax(accumulator, val)"));
    }

    #[test]
    fn test_generate_reduce_and_broadcast_cooperative() {
        let code = generate_reduce_and_broadcast_fn("float", 256, &ReductionOp::Sum, true);
        assert!(code.contains("cg::grid_group grid = cg::this_grid()"));
        assert!(code.contains("grid.sync()"));
    }

    #[test]
    fn test_generate_reduce_and_broadcast_software() {
        let code = generate_reduce_and_broadcast_fn("float", 256, &ReductionOp::Sum, false);
        assert!(code.contains("__barrier_counter"));
        assert!(code.contains("__software_grid_sync"));
    }

    #[test]
    fn test_inline_block_reduce() {
        let code = generate_inline_block_reduce(
            "my_value",
            "shared_mem",
            "result",
            "float",
            256,
            &ReductionOp::Sum,
        );
        assert!(code.contains("__shfl_down_sync(0xFFFFFFFF, __val, __offset)"));
        assert!(code.contains("shared_mem[__warp_id] = __val"));
        assert!(code.contains("float result = shared_mem[0]"));
    }

    #[test]
    fn test_transpile_reduction_call() {
        let config = ReductionCodegenConfig::default();

        let result = transpile_reduction_call(
            ReductionIntrinsic::BlockReduceSum,
            &["val".to_string(), "shared".to_string()],
            &config,
        );
        assert_eq!(result.unwrap(), "__block_reduce_sum(val, shared)");

        let result = transpile_reduction_call(
            ReductionIntrinsic::AtomicAccumulate,
            &["&accum".to_string(), "val".to_string()],
            &config,
        );
        assert_eq!(result.unwrap(), "atomicAdd(&accum, val)");
    }

    #[test]
    fn test_reduction_helpers() {
        let config = ReductionCodegenConfig::new(256)
            .with_type("float")
            .with_op(ReductionOp::Sum)
            .with_cooperative(true);

        let helpers = generate_reduction_helpers(&config);
        assert!(helpers.contains("#include <cooperative_groups.h>"));
        assert!(helpers.contains("__block_reduce_sum"));
        assert!(helpers.contains("__grid_reduce_sum"));
        assert!(helpers.contains("__reduce_and_broadcast_sum"));
    }
}
965}