// oxilean_codegen/cuda_backend/types.rs
1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use std::collections::{HashMap, HashSet, VecDeque};
6
7#[allow(dead_code)]
8#[derive(Debug, Clone)]
9pub struct CUDAWorklist {
10    pub(super) items: std::collections::VecDeque<u32>,
11    pub(super) in_worklist: std::collections::HashSet<u32>,
12}
13impl CUDAWorklist {
14    #[allow(dead_code)]
15    pub fn new() -> Self {
16        CUDAWorklist {
17            items: std::collections::VecDeque::new(),
18            in_worklist: std::collections::HashSet::new(),
19        }
20    }
21    #[allow(dead_code)]
22    pub fn push(&mut self, item: u32) -> bool {
23        if self.in_worklist.insert(item) {
24            self.items.push_back(item);
25            true
26        } else {
27            false
28        }
29    }
30    #[allow(dead_code)]
31    pub fn pop(&mut self) -> Option<u32> {
32        let item = self.items.pop_front()?;
33        self.in_worklist.remove(&item);
34        Some(item)
35    }
36    #[allow(dead_code)]
37    pub fn is_empty(&self) -> bool {
38        self.items.is_empty()
39    }
40    #[allow(dead_code)]
41    pub fn len(&self) -> usize {
42        self.items.len()
43    }
44    #[allow(dead_code)]
45    pub fn contains(&self, item: u32) -> bool {
46        self.in_worklist.contains(&item)
47    }
48}
/// A CUDA kernel (`__global__` function).
///
/// Built incrementally via the builder-style methods on `impl CudaKernel`
/// and serialized to source text by `CudaBackend::emit_kernel`.
#[derive(Debug, Clone, PartialEq)]
pub struct CudaKernel {
    /// Kernel name
    pub name: String,
    /// Parameter list
    pub params: Vec<CudaParam>,
    /// Shared memory declarations (emitted at the top of the kernel body)
    pub shared_mem_decls: Vec<SharedMemDecl>,
    /// Kernel body statements
    pub body: Vec<CudaStmt>,
    /// Optional `__launch_bounds__` annotation
    pub launch_bounds: Option<LaunchBounds>,
}
63impl CudaKernel {
64    /// Create a new kernel with no launch bounds.
65    pub fn new(name: impl Into<String>) -> Self {
66        CudaKernel {
67            name: name.into(),
68            params: Vec::new(),
69            shared_mem_decls: Vec::new(),
70            body: Vec::new(),
71            launch_bounds: None,
72        }
73    }
74    /// Append a parameter.
75    pub fn add_param(mut self, p: CudaParam) -> Self {
76        self.params.push(p);
77        self
78    }
79    /// Append a shared-memory declaration.
80    pub fn add_shared(mut self, s: SharedMemDecl) -> Self {
81        self.shared_mem_decls.push(s);
82        self
83    }
84    /// Append a body statement.
85    pub fn add_stmt(mut self, s: CudaStmt) -> Self {
86        self.body.push(s);
87        self
88    }
89    /// Set launch bounds.
90    pub fn with_launch_bounds(mut self, lb: LaunchBounds) -> Self {
91        self.launch_bounds = Some(lb);
92        self
93    }
94}
/// CUDA kernel launch configuration.
///
/// The four fields map directly onto the `<<<grid, block, shared_mem,
/// stream>>>` execution configuration emitted by
/// `CudaBackend::emit_kernel_launch`.
#[derive(Debug, Clone, PartialEq)]
pub struct LaunchConfig {
    /// Grid dimensions (number of blocks)
    pub grid: CudaExpr,
    /// Block dimensions (threads per block)
    pub block: CudaExpr,
    /// Dynamic shared memory bytes (0 if none)
    pub shared_mem: CudaExpr,
    /// CUDA stream (None → default stream)
    pub stream: Option<CudaExpr>,
}
107impl LaunchConfig {
108    /// Create a simple 1-D launch config with no dynamic shared memory.
109    pub fn simple_1d(grid: CudaExpr, block: CudaExpr) -> Self {
110        LaunchConfig {
111            grid,
112            block,
113            shared_mem: CudaExpr::LitInt(0),
114            stream: None,
115        }
116    }
117}
/// Top-level CUDA module representing a single `.cu` file.
///
/// Serialized by `CudaBackend::emit_module` in this order: includes,
/// `CUDA_CHECK` macro, constants, device functions, kernels, host code.
#[derive(Debug, Clone, PartialEq)]
pub struct CudaModule {
    /// `#include` directives (just the header names, e.g. `"cuda_runtime.h"`)
    pub includes: Vec<String>,
    /// `__constant__` memory declarations at file scope
    /// (type, variable name, optional initializer expression)
    pub constant_decls: Vec<(CudaType, String, Option<CudaExpr>)>,
    /// `__device__` (or `__host__ __device__`) helper functions
    pub device_functions: Vec<DeviceFunction>,
    /// `__global__` kernels
    pub kernels: Vec<CudaKernel>,
    /// Host-side code (helper functions, `main`, etc.) as raw strings
    pub host_code: Vec<String>,
}
132impl CudaModule {
133    /// Create an empty module with standard CUDA includes.
134    pub fn new() -> Self {
135        CudaModule {
136            includes: vec![
137                "cuda_runtime.h".to_string(),
138                "device_launch_parameters.h".to_string(),
139            ],
140            constant_decls: Vec::new(),
141            device_functions: Vec::new(),
142            kernels: Vec::new(),
143            host_code: Vec::new(),
144        }
145    }
146    /// Add an `#include` (just the name; angle brackets / quotes are added by emitter).
147    pub fn add_include(mut self, header: impl Into<String>) -> Self {
148        self.includes.push(header.into());
149        self
150    }
151    /// Declare a `__constant__` variable at file scope.
152    pub fn add_constant(
153        mut self,
154        ty: CudaType,
155        name: impl Into<String>,
156        init: Option<CudaExpr>,
157    ) -> Self {
158        self.constant_decls.push((ty, name.into(), init));
159        self
160    }
161    /// Add a device function.
162    pub fn add_device_function(mut self, f: DeviceFunction) -> Self {
163        self.device_functions.push(f);
164        self
165    }
166    /// Add a kernel.
167    pub fn add_kernel(mut self, k: CudaKernel) -> Self {
168        self.kernels.push(k);
169        self
170    }
171    /// Append raw host-side C++ code.
172    pub fn add_host_code(mut self, code: impl Into<String>) -> Self {
173        self.host_code.push(code.into());
174        self
175    }
176}
/// Aggregate statistics for repeated executions of a CUDA optimization pass.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct CUDAPassStats {
    /// How many times the pass has been run.
    pub total_runs: u32,
    /// How many of those runs were recorded as successful.
    pub successful_runs: u32,
    /// Sum of changes reported across all runs.
    pub total_changes: u64,
    /// Sum of wall-clock milliseconds across all runs.
    pub time_ms: u64,
    /// Iteration count of the most recent run (snapshot, not a sum).
    pub iterations_used: u32,
}
impl CUDAPassStats {
    /// Fresh, all-zero statistics.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self::default()
    }
    /// Record one (successful) pass execution.
    #[allow(dead_code)]
    pub fn record_run(&mut self, changes: u64, time_ms: u64, iterations: u32) {
        self.total_runs += 1;
        self.successful_runs += 1;
        self.total_changes += changes;
        self.time_ms += time_ms;
        // Last run wins: iterations_used is a snapshot, not an accumulator.
        self.iterations_used = iterations;
    }
    /// Mean number of changes per run; 0.0 when nothing has run yet.
    #[allow(dead_code)]
    pub fn average_changes_per_run(&self) -> f64 {
        match self.total_runs {
            0 => 0.0,
            n => self.total_changes as f64 / n as f64,
        }
    }
    /// Fraction of runs recorded successful; 0.0 when nothing has run yet.
    #[allow(dead_code)]
    pub fn success_rate(&self) -> f64 {
        match self.total_runs {
            0 => 0.0,
            n => self.successful_runs as f64 / n as f64,
        }
    }
    /// Human-readable one-line summary of the accumulated statistics.
    #[allow(dead_code)]
    pub fn format_summary(&self) -> String {
        format!(
            "Runs: {}/{}, Changes: {}, Time: {}ms",
            self.successful_runs, self.total_runs, self.total_changes, self.time_ms
        )
    }
}
221/// Worklist for CUDAExt.
222#[allow(dead_code)]
223#[derive(Debug, Clone)]
224pub struct CUDAExtWorklist {
225    pub(super) items: std::collections::VecDeque<usize>,
226    pub(super) present: Vec<bool>,
227}
228impl CUDAExtWorklist {
229    #[allow(dead_code)]
230    pub fn new(capacity: usize) -> Self {
231        Self {
232            items: std::collections::VecDeque::new(),
233            present: vec![false; capacity],
234        }
235    }
236    #[allow(dead_code)]
237    pub fn push(&mut self, id: usize) {
238        if id < self.present.len() && !self.present[id] {
239            self.present[id] = true;
240            self.items.push_back(id);
241        }
242    }
243    #[allow(dead_code)]
244    pub fn push_front(&mut self, id: usize) {
245        if id < self.present.len() && !self.present[id] {
246            self.present[id] = true;
247            self.items.push_front(id);
248        }
249    }
250    #[allow(dead_code)]
251    pub fn pop(&mut self) -> Option<usize> {
252        let id = self.items.pop_front()?;
253        if id < self.present.len() {
254            self.present[id] = false;
255        }
256        Some(id)
257    }
258    #[allow(dead_code)]
259    pub fn is_empty(&self) -> bool {
260        self.items.is_empty()
261    }
262    #[allow(dead_code)]
263    pub fn len(&self) -> usize {
264        self.items.len()
265    }
266    #[allow(dead_code)]
267    pub fn contains(&self, id: usize) -> bool {
268        id < self.present.len() && self.present[id]
269    }
270    #[allow(dead_code)]
271    pub fn drain_all(&mut self) -> Vec<usize> {
272        let v: Vec<usize> = self.items.drain(..).collect();
273        for &id in &v {
274            if id < self.present.len() {
275                self.present[id] = false;
276            }
277        }
278        v
279    }
280}
/// Emitter state for producing CUDA `.cu` source code.
pub struct CudaBackend {
    /// Number of spaces per indentation level (4 via `new()`).
    pub(super) indent_width: usize,
}
285impl CudaBackend {
286    /// Create a new backend with 4-space indentation.
287    pub fn new() -> Self {
288        CudaBackend { indent_width: 4 }
289    }
290    /// Create a backend with a custom indent width.
291    pub fn with_indent(indent_width: usize) -> Self {
292        CudaBackend { indent_width }
293    }
294    pub(super) fn indent(&self, depth: usize) -> String {
295        " ".repeat(self.indent_width * depth)
296    }
297    /// Emit a CUDA expression to a string.
298    pub fn emit_expr(&self, expr: &CudaExpr) -> String {
299        expr.emit()
300    }
301    /// Emit a single statement at the given indentation depth.
302    pub fn emit_stmt(&self, stmt: &CudaStmt, depth: usize) -> String {
303        let ind = self.indent(depth);
304        match stmt {
305            CudaStmt::VarDecl { ty, name, init } => match init {
306                Some(expr) => format!("{}{} {} = {};", ind, ty, name, expr.emit()),
307                None => format!("{}{} {};", ind, ty, name),
308            },
309            CudaStmt::Assign { lhs, rhs } => {
310                format!("{}{} = {};", ind, lhs.emit(), rhs.emit())
311            }
312            CudaStmt::CompoundAssign { lhs, op, rhs } => {
313                format!("{}{} {}= {};", ind, lhs.emit(), op, rhs.emit())
314            }
315            CudaStmt::IfElse {
316                cond,
317                then_body,
318                else_body,
319            } => self.emit_if_else(cond, then_body, else_body.as_deref(), depth),
320            CudaStmt::ForLoop {
321                init,
322                cond,
323                step,
324                body,
325            } => self.emit_for_loop(init, cond, step, body, depth),
326            CudaStmt::WhileLoop { cond, body } => self.emit_while(cond, body, depth),
327            CudaStmt::KernelLaunch { name, config, args } => {
328                self.emit_kernel_launch(name, config, args, depth)
329            }
330            CudaStmt::CudaMalloc { ptr, size } => {
331                format!("{}cudaMalloc((void**)&{}, {});", ind, ptr, size.emit())
332            }
333            CudaStmt::CudaMemcpy {
334                dst,
335                src,
336                size,
337                kind,
338            } => {
339                format!(
340                    "{}cudaMemcpy({}, {}, {}, {});",
341                    ind,
342                    dst.emit(),
343                    src.emit(),
344                    size.emit(),
345                    kind
346                )
347            }
348            CudaStmt::CudaFree(ptr) => format!("{}cudaFree({});", ind, ptr.emit()),
349            CudaStmt::Return(Some(expr)) => format!("{}return {};", ind, expr.emit()),
350            CudaStmt::Return(None) => format!("{}return;", ind),
351            CudaStmt::Expr(expr) => format!("{}{};", ind, expr.emit()),
352            CudaStmt::DeviceSync => format!("{}cudaDeviceSynchronize();", ind),
353            CudaStmt::CheckError(expr) => format!("{}CUDA_CHECK({});", ind, expr.emit()),
354            CudaStmt::Block(stmts) => {
355                let mut out = format!("{}{{\n", ind);
356                for s in stmts {
357                    out.push_str(&self.emit_stmt(s, depth + 1));
358                    out.push('\n');
359                }
360                out.push_str(&format!("{}}}", ind));
361                out
362            }
363            CudaStmt::Break => format!("{}break;", ind),
364            CudaStmt::Continue => format!("{}continue;", ind),
365        }
366    }
367    pub(super) fn emit_if_else(
368        &self,
369        cond: &CudaExpr,
370        then_body: &[CudaStmt],
371        else_body: Option<&[CudaStmt]>,
372        depth: usize,
373    ) -> String {
374        let ind = self.indent(depth);
375        let inner = self.indent(depth + 1);
376        let mut out = format!("{}if ({}) {{\n", ind, cond.emit());
377        for s in then_body {
378            out.push_str(&self.emit_stmt(s, depth + 1));
379            out.push('\n');
380        }
381        out.push_str(&format!("{}}}", ind));
382        if let Some(eb) = else_body {
383            out.push_str(" else {\n");
384            for s in eb {
385                out.push_str(&self.emit_stmt(s, depth + 1));
386                out.push('\n');
387            }
388            out.push_str(&format!("{}}}", ind));
389        }
390        let _ = inner;
391        out
392    }
393    pub(super) fn emit_for_loop(
394        &self,
395        init: &CudaStmt,
396        cond: &CudaExpr,
397        step: &CudaExpr,
398        body: &[CudaStmt],
399        depth: usize,
400    ) -> String {
401        let ind = self.indent(depth);
402        let init_str = self.emit_stmt(init, 0).trim().to_string();
403        let init_header = init_str.trim_end_matches(';');
404        let mut out = format!(
405            "{}for ({}; {}; {}) {{\n",
406            ind,
407            init_header,
408            cond.emit(),
409            step.emit()
410        );
411        for s in body {
412            out.push_str(&self.emit_stmt(s, depth + 1));
413            out.push('\n');
414        }
415        out.push_str(&format!("{}}}", ind));
416        out
417    }
418    pub(super) fn emit_while(&self, cond: &CudaExpr, body: &[CudaStmt], depth: usize) -> String {
419        let ind = self.indent(depth);
420        let mut out = format!("{}while ({}) {{\n", ind, cond.emit());
421        for s in body {
422            out.push_str(&self.emit_stmt(s, depth + 1));
423            out.push('\n');
424        }
425        out.push_str(&format!("{}}}", ind));
426        out
427    }
428    pub(super) fn emit_kernel_launch(
429        &self,
430        name: &str,
431        config: &LaunchConfig,
432        args: &[CudaExpr],
433        depth: usize,
434    ) -> String {
435        let ind = self.indent(depth);
436        let grid = config.grid.emit();
437        let block = config.block.emit();
438        let shmem = config.shared_mem.emit();
439        let stream = config
440            .stream
441            .as_ref()
442            .map(|s| s.emit())
443            .unwrap_or_else(|| "0".to_string());
444        let arg_strs: Vec<String> = args.iter().map(|a| a.emit()).collect();
445        format!(
446            "{}{}<<<{}, {}, {}, {}>>>({});",
447            ind,
448            name,
449            grid,
450            block,
451            shmem,
452            stream,
453            arg_strs.join(", ")
454        )
455    }
456    pub(super) fn emit_device_function(&self, f: &DeviceFunction) -> String {
457        let quals: Vec<String> = f.qualifiers.iter().map(|q| format!("{}", q)).collect();
458        let inline_str = if f.is_inline { "inline " } else { "" };
459        let qual_str = quals.join(" ");
460        let params: Vec<String> = f.params.iter().map(|p| p.emit()).collect();
461        let mut out = format!(
462            "{}{} {} {}({}) {{\n",
463            inline_str,
464            qual_str,
465            f.ret,
466            f.name,
467            params.join(", ")
468        );
469        for s in &f.body {
470            out.push_str(&self.emit_stmt(s, 1));
471            out.push('\n');
472        }
473        out.push('}');
474        out
475    }
476    pub(super) fn emit_kernel(&self, k: &CudaKernel) -> String {
477        let lb = k
478            .launch_bounds
479            .as_ref()
480            .map(|lb| format!("{} ", lb.emit()))
481            .unwrap_or_default();
482        let params: Vec<String> = k.params.iter().map(|p| p.emit()).collect();
483        let mut out = format!(
484            "__global__ {}void {}({}) {{\n",
485            lb,
486            k.name,
487            params.join(", ")
488        );
489        for smd in &k.shared_mem_decls {
490            out.push_str(&format!("    {}\n", smd.emit()));
491        }
492        for s in &k.body {
493            out.push_str(&self.emit_stmt(s, 1));
494            out.push('\n');
495        }
496        out.push('}');
497        out
498    }
499    /// Emit the full `.cu` file as a `String`.
500    pub fn emit_module(&self, module: &CudaModule) -> String {
501        let mut out = String::new();
502        for inc in &module.includes {
503            let is_std = !inc.contains('/') && !inc.ends_with(".cuh");
504            if is_std {
505                out.push_str(&format!("#include <{}>\n", inc));
506            } else {
507                out.push_str(&format!("#include \"{}\"\n", inc));
508            }
509        }
510        if !module.includes.is_empty() {
511            out.push('\n');
512        }
513        out.push_str(
514            "#define CUDA_CHECK(err) \\\n\
515             do { \\\n\
516             \tcudaError_t _err = (err); \\\n\
517             \tif (_err != cudaSuccess) { \\\n\
518             \t\tfprintf(stderr, \"CUDA error %s:%d: %s\\n\", \\\n\
519             \t\t\t__FILE__, __LINE__, cudaGetErrorString(_err)); \\\n\
520             \t\texit(EXIT_FAILURE); \\\n\
521             \t} \\\n\
522             } while(0)\n\n",
523        );
524        for (ty, name, init) in &module.constant_decls {
525            match init {
526                Some(expr) => out.push_str(&format!(
527                    "__constant__ {} {} = {};\n",
528                    ty,
529                    name,
530                    expr.emit()
531                )),
532                None => out.push_str(&format!("__constant__ {} {};\n", ty, name)),
533            }
534        }
535        if !module.constant_decls.is_empty() {
536            out.push('\n');
537        }
538        for f in &module.device_functions {
539            out.push_str(&self.emit_device_function(f));
540            out.push_str("\n\n");
541        }
542        for k in &module.kernels {
543            out.push_str(&self.emit_kernel(k));
544            out.push_str("\n\n");
545        }
546        for block in &module.host_code {
547            out.push_str(block);
548            out.push_str("\n\n");
549        }
550        out
551    }
552}
/// A `__shared__` memory declaration inside a kernel.
#[derive(Debug, Clone, PartialEq)]
pub struct SharedMemDecl {
    /// Element type
    pub ty: CudaType,
    /// Variable name
    pub name: String,
    /// Array size; `None` selects dynamic shared memory, which is
    /// emitted as `extern __shared__ T name[];`
    pub size: Option<CudaExpr>,
}
563impl SharedMemDecl {
564    pub(super) fn emit(&self) -> String {
565        match &self.size {
566            Some(sz) => format!("__shared__ {} {}[{}];", self.ty, self.name, sz.emit()),
567            None => format!("extern __shared__ {} {}[];", self.ty, self.name),
568        }
569    }
570}
/// Unary prefix operators.
///
/// Rendered via a `Display` impl (outside this file) when an expression
/// `CudaExpr::UnOp` is emitted.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CudaUnOp {
    /// Arithmetic negation
    Neg,
    /// Logical not
    Not,
    /// Bitwise complement
    BitNot,
    /// Pointer dereference
    Deref,
    /// Address-of
    AddrOf,
}
/// CUDA C++ expression AST node.
///
/// Serialized to source text by `CudaExpr::emit`; binary, unary, cast and
/// ternary forms are emitted fully parenthesized.
#[derive(Debug, Clone, PartialEq)]
pub enum CudaExpr {
    /// Integer literal: `42`
    LitInt(i64),
    /// Float literal: `3.14f` (emitted with 6 decimal places)
    LitFloat(f64),
    /// Boolean literal: `true` / `false`
    LitBool(bool),
    /// Named variable or parameter: `x`
    Var(String),
    /// `threadIdx.x`, `threadIdx.y`, `threadIdx.z`
    ThreadIdx(char),
    /// `blockIdx.x`, `blockIdx.y`, `blockIdx.z`
    BlockIdx(char),
    /// `blockDim.x`, `blockDim.y`, `blockDim.z`
    BlockDim(char),
    /// `gridDim.x`, `gridDim.y`, `gridDim.z`
    GridDim(char),
    /// `__syncthreads()`
    SyncThreads,
    /// `atomicAdd(addr, val)` — atomic addition
    AtomicAdd(Box<CudaExpr>, Box<CudaExpr>),
    /// `atomicSub(addr, val)`
    AtomicSub(Box<CudaExpr>, Box<CudaExpr>),
    /// `atomicExch(addr, val)`
    AtomicExch(Box<CudaExpr>, Box<CudaExpr>),
    /// `atomicCAS(addr, compare, val)`
    AtomicCas(Box<CudaExpr>, Box<CudaExpr>, Box<CudaExpr>),
    /// `atomicMax(addr, val)`
    AtomicMax(Box<CudaExpr>, Box<CudaExpr>),
    /// `atomicMin(addr, val)`
    AtomicMin(Box<CudaExpr>, Box<CudaExpr>),
    /// Binary operation: `a + b`
    BinOp(Box<CudaExpr>, CudaBinOp, Box<CudaExpr>),
    /// Unary operation: `!a`
    UnOp(CudaUnOp, Box<CudaExpr>),
    /// Array subscript: `arr[idx]`
    Index(Box<CudaExpr>, Box<CudaExpr>),
    /// Struct member access: `s.field`
    Member(Box<CudaExpr>, String),
    /// Pointer member access: `p->field`
    PtrMember(Box<CudaExpr>, String),
    /// C-style cast: `(T)expr`
    Cast(CudaType, Box<CudaExpr>),
    /// Function call: `func(args...)`
    Call(String, Vec<CudaExpr>),
    /// Ternary conditional: `cond ? then : else`
    Ternary(Box<CudaExpr>, Box<CudaExpr>, Box<CudaExpr>),
    /// `__ldg(&x)` — read-only cache load
    Ldg(Box<CudaExpr>),
    /// `__shfl_down_sync(mask, var, delta)`
    ShflDownSync(Box<CudaExpr>, Box<CudaExpr>, Box<CudaExpr>),
    /// `__shfl_xor_sync(mask, var, laneMask)`
    ShflXorSync(Box<CudaExpr>, Box<CudaExpr>, Box<CudaExpr>),
    /// `warpSize` builtin
    WarpSize,
    /// `__ballot_sync(mask, predicate)`
    BallotSync(Box<CudaExpr>, Box<CudaExpr>),
    /// `__popc(x)` — popcount
    Popc(Box<CudaExpr>),
}
642impl CudaExpr {
643    pub(super) fn emit(&self) -> String {
644        match self {
645            CudaExpr::LitInt(n) => n.to_string(),
646            CudaExpr::LitFloat(f) => {
647                let s = format!("{:.6}", f);
648                format!("{}f", s)
649            }
650            CudaExpr::LitBool(b) => if *b { "true" } else { "false" }.to_string(),
651            CudaExpr::Var(name) => name.clone(),
652            CudaExpr::ThreadIdx(c) => format!("threadIdx.{}", c),
653            CudaExpr::BlockIdx(c) => format!("blockIdx.{}", c),
654            CudaExpr::BlockDim(c) => format!("blockDim.{}", c),
655            CudaExpr::GridDim(c) => format!("gridDim.{}", c),
656            CudaExpr::SyncThreads => "__syncthreads()".to_string(),
657            CudaExpr::AtomicAdd(addr, val) => {
658                format!("atomicAdd({}, {})", addr.emit(), val.emit())
659            }
660            CudaExpr::AtomicSub(addr, val) => {
661                format!("atomicSub({}, {})", addr.emit(), val.emit())
662            }
663            CudaExpr::AtomicExch(addr, val) => {
664                format!("atomicExch({}, {})", addr.emit(), val.emit())
665            }
666            CudaExpr::AtomicCas(addr, cmp, val) => {
667                format!("atomicCAS({}, {}, {})", addr.emit(), cmp.emit(), val.emit())
668            }
669            CudaExpr::AtomicMax(addr, val) => {
670                format!("atomicMax({}, {})", addr.emit(), val.emit())
671            }
672            CudaExpr::AtomicMin(addr, val) => {
673                format!("atomicMin({}, {})", addr.emit(), val.emit())
674            }
675            CudaExpr::BinOp(lhs, op, rhs) => {
676                format!("({} {} {})", lhs.emit(), op, rhs.emit())
677            }
678            CudaExpr::UnOp(op, expr) => format!("({}{})", op, expr.emit()),
679            CudaExpr::Index(base, idx) => format!("{}[{}]", base.emit(), idx.emit()),
680            CudaExpr::Member(base, field) => format!("{}.{}", base.emit(), field),
681            CudaExpr::PtrMember(base, field) => format!("{}->{}", base.emit(), field),
682            CudaExpr::Cast(ty, expr) => format!("(({}){})", ty, expr.emit()),
683            CudaExpr::Call(name, args) => {
684                let arg_strs: Vec<String> = args.iter().map(|a| a.emit()).collect();
685                format!("{}({})", name, arg_strs.join(", "))
686            }
687            CudaExpr::Ternary(cond, then, els) => {
688                format!("({} ? {} : {})", cond.emit(), then.emit(), els.emit())
689            }
690            CudaExpr::Ldg(addr) => format!("__ldg({})", addr.emit()),
691            CudaExpr::ShflDownSync(mask, var, delta) => {
692                format!(
693                    "__shfl_down_sync({}, {}, {})",
694                    mask.emit(),
695                    var.emit(),
696                    delta.emit()
697                )
698            }
699            CudaExpr::ShflXorSync(mask, var, lane_mask) => {
700                format!(
701                    "__shfl_xor_sync({}, {}, {})",
702                    mask.emit(),
703                    var.emit(),
704                    lane_mask.emit()
705                )
706            }
707            CudaExpr::WarpSize => "warpSize".to_string(),
708            CudaExpr::BallotSync(mask, pred) => {
709                format!("__ballot_sync({}, {})", mask.emit(), pred.emit())
710            }
711            CudaExpr::Popc(x) => format!("__popc({})", x.emit()),
712        }
713    }
714}
/// Per-basic-block liveness sets.
///
/// All four vectors are indexed by basic-block id. Out-of-range block
/// indices are silently ignored by the mutators and reported as
/// "not live" by the queries.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct CUDALivenessInfo {
    /// Variables live on entry to each block.
    pub live_in: Vec<std::collections::HashSet<u32>>,
    /// Variables live on exit from each block.
    pub live_out: Vec<std::collections::HashSet<u32>>,
    /// Variables defined (written) in each block.
    pub defs: Vec<std::collections::HashSet<u32>>,
    /// Variables used (read) in each block.
    pub uses: Vec<std::collections::HashSet<u32>>,
}
impl CUDALivenessInfo {
    /// Allocate empty sets for `block_count` basic blocks.
    #[allow(dead_code)]
    pub fn new(block_count: usize) -> Self {
        let empty = || vec![std::collections::HashSet::new(); block_count];
        CUDALivenessInfo {
            live_in: empty(),
            live_out: empty(),
            defs: empty(),
            uses: empty(),
        }
    }
    /// Mark `var` as defined in `block`; out-of-range blocks are ignored.
    #[allow(dead_code)]
    pub fn add_def(&mut self, block: usize, var: u32) {
        if let Some(set) = self.defs.get_mut(block) {
            set.insert(var);
        }
    }
    /// Mark `var` as used in `block`; out-of-range blocks are ignored.
    #[allow(dead_code)]
    pub fn add_use(&mut self, block: usize, var: u32) {
        if let Some(set) = self.uses.get_mut(block) {
            set.insert(var);
        }
    }
    /// Whether `var` is live on entry to `block` (false for bad indices).
    #[allow(dead_code)]
    pub fn is_live_in(&self, block: usize, var: u32) -> bool {
        self.live_in.get(block).map_or(false, |s| s.contains(&var))
    }
    /// Whether `var` is live on exit from `block` (false for bad indices).
    #[allow(dead_code)]
    pub fn is_live_out(&self, block: usize, var: u32) -> bool {
        self.live_out.get(block).map_or(false, |s| s.contains(&var))
    }
}
/// CUDA statement AST node.
///
/// Rendered to source text by `CudaBackend::emit_stmt`.
#[derive(Debug, Clone, PartialEq)]
pub enum CudaStmt {
    /// Variable declaration with optional initializer:
    /// `CudaType name [ = init ];`
    VarDecl {
        ty: CudaType,
        name: String,
        init: Option<CudaExpr>,
    },
    /// Simple assignment: `lhs = rhs;`
    Assign { lhs: CudaExpr, rhs: CudaExpr },
    /// Compound assignment: `lhs += rhs;` etc.
    CompoundAssign {
        lhs: CudaExpr,
        op: CudaBinOp,
        rhs: CudaExpr,
    },
    /// If with optional else:
    /// `if (cond) { then_body } [ else { else_body } ]`
    IfElse {
        cond: CudaExpr,
        then_body: Vec<CudaStmt>,
        else_body: Option<Vec<CudaStmt>>,
    },
    /// C-style for loop:
    /// `for (init; cond; step) { body }`
    ForLoop {
        init: Box<CudaStmt>,
        cond: CudaExpr,
        step: CudaExpr,
        body: Vec<CudaStmt>,
    },
    /// While loop: `while (cond) { body }`
    WhileLoop { cond: CudaExpr, body: Vec<CudaStmt> },
    /// CUDA kernel launch: `name<<<grid, block, shmem, stream>>>(args...);`
    KernelLaunch {
        name: String,
        config: LaunchConfig,
        args: Vec<CudaExpr>,
    },
    /// `cudaMalloc((void**)&ptr, size);`
    CudaMalloc { ptr: String, size: CudaExpr },
    /// `cudaMemcpy(dst, src, size, kind);`
    CudaMemcpy {
        dst: CudaExpr,
        src: CudaExpr,
        size: CudaExpr,
        kind: MemcpyKind,
    },
    /// `cudaFree(ptr);`
    CudaFree(CudaExpr),
    /// `return expr;` (`Some`) or bare `return;` (`None`)
    Return(Option<CudaExpr>),
    /// Raw expression statement: `expr;`
    Expr(CudaExpr),
    /// `cudaDeviceSynchronize();`
    DeviceSync,
    /// `CUDA_CHECK(expr);` macro invocation (the macro itself is defined
    /// at the top of every module by `CudaBackend::emit_module`)
    CheckError(CudaExpr),
    /// Block of statements grouped with `{}`
    Block(Vec<CudaStmt>),
    /// `break;`
    Break,
    /// `continue;`
    Continue,
}
/// A parameter in a CUDA kernel or device function.
///
/// Emitted as `[const] [qualifier] type name` by `CudaParam::emit`.
#[derive(Debug, Clone, PartialEq)]
pub struct CudaParam {
    /// CUDA type
    pub ty: CudaType,
    /// Parameter name
    pub name: String,
    /// Whether the parameter is `const`
    pub is_const: bool,
    /// Optional qualifier such as `__restrict__`
    pub qualifier: Option<CudaQualifier>,
}
838impl CudaParam {
839    /// Create a plain parameter.
840    pub fn new(ty: CudaType, name: impl Into<String>) -> Self {
841        CudaParam {
842            ty,
843            name: name.into(),
844            is_const: false,
845            qualifier: None,
846        }
847    }
848    /// Mark parameter as `const`.
849    pub fn with_const(mut self) -> Self {
850        self.is_const = true;
851        self
852    }
853    /// Add a CUDA qualifier (e.g. `__restrict__`).
854    pub fn with_qualifier(mut self, q: CudaQualifier) -> Self {
855        self.qualifier = Some(q);
856        self
857    }
858    pub(super) fn emit(&self) -> String {
859        let mut parts = Vec::new();
860        if self.is_const {
861            parts.push("const".to_string());
862        }
863        if let Some(q) = &self.qualifier {
864            parts.push(format!("{}", q));
865        }
866        parts.push(format!("{}", self.ty));
867        parts.push(self.name.clone());
868        parts.join(" ")
869    }
870}
/// A `__device__` (or `__host__ __device__`) helper function.
///
/// Emitted by `CudaBackend::emit_device_function`.
#[derive(Debug, Clone, PartialEq)]
pub struct DeviceFunction {
    /// Function name
    pub name: String,
    /// Qualifiers (should include at least `Device`)
    pub qualifiers: Vec<CudaQualifier>,
    /// Return type
    pub ret: CudaType,
    /// Parameter list
    pub params: Vec<CudaParam>,
    /// Body statements
    pub body: Vec<CudaStmt>,
    /// Whether the function is `inline`
    pub is_inline: bool,
}
887impl DeviceFunction {
888    /// Create a plain `__device__` function.
889    pub fn new(name: impl Into<String>, ret: CudaType) -> Self {
890        DeviceFunction {
891            name: name.into(),
892            qualifiers: vec![CudaQualifier::Device],
893            ret,
894            params: Vec::new(),
895            body: Vec::new(),
896            is_inline: false,
897        }
898    }
899    /// Create a `__host__ __device__` function.
900    pub fn host_device(name: impl Into<String>, ret: CudaType) -> Self {
901        DeviceFunction {
902            name: name.into(),
903            qualifiers: vec![CudaQualifier::Host, CudaQualifier::Device],
904            ret,
905            params: Vec::new(),
906            body: Vec::new(),
907            is_inline: false,
908        }
909    }
910    /// Mark as `inline`.
911    pub fn with_inline(mut self) -> Self {
912        self.is_inline = true;
913        self
914    }
915    /// Append a parameter.
916    pub fn add_param(mut self, p: CudaParam) -> Self {
917        self.params.push(p);
918        self
919    }
920    /// Append a body statement.
921    pub fn add_stmt(mut self, s: CudaStmt) -> Self {
922        self.body.push(s);
923        self
924    }
925}
/// CUDA / C++ type representation used in generated `.cu` files.
///
/// The `Shared` / `Constant` / `Device` wrappers carry a memory-space
/// qualifier around an inner type; `Pointer` can nest (e.g.
/// `Pointer(Pointer(Float))` for `float**`).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CudaType {
    /// `int`
    Int,
    /// `long`
    Long,
    /// `float`
    Float,
    /// `double`
    Double,
    /// `__half` (CUDA half-precision float)
    Half,
    /// `bool`
    Bool,
    /// `dim3` (three-component grid/block dimension)
    Dim3,
    /// `size_t`
    // NOTE(review): variant is named `DimT` but documented as `size_t`;
    // confirm whether `SizeT` was the intended name.
    DimT,
    /// `cudaError_t`
    CudaErrorT,
    /// Pointer to inner type: `T*`
    Pointer(Box<CudaType>),
    /// `__shared__` qualified type (used internally for shared-mem decls)
    Shared(Box<CudaType>),
    /// `__constant__` qualified type
    Constant(Box<CudaType>),
    /// `__device__` qualified type
    Device(Box<CudaType>),
    /// Void: `void`
    Void,
    /// Unsigned int: `unsigned int`
    UInt,
    /// Named struct or typedef
    Named(String),
}
/// A single entry in `CUDAAnalysisCache`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct CUDACacheEntry {
    /// Cache key (duplicated from the map key by `CUDAAnalysisCache::insert`).
    pub key: String,
    /// Cached payload bytes.
    pub data: Vec<u8>,
    /// Insertion timestamp; `CUDAAnalysisCache::insert` always writes 0.
    pub timestamp: u64,
    /// Cleared by `CUDAAnalysisCache::invalidate`.
    pub valid: bool,
}
/// Constant folding helper for CUDAExt.
///
/// Tracks how many fold operations succeeded (`folds`) and how many were
/// rejected (`failures`); see the `impl` for the individual operations.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct CUDAExtConstFolder {
    /// Number of fold operations counted.
    pub(super) folds: usize,
    /// Number of rejected folds (e.g. division by zero, shift out of range).
    pub(super) failures: usize,
    // NOTE(review): `new()` sets this to `true`, but `#[derive(Default)]`
    // yields `false` — confirm which default is intended.
    pub(super) enabled: bool,
}
978impl CUDAExtConstFolder {
979    #[allow(dead_code)]
980    pub fn new() -> Self {
981        Self {
982            folds: 0,
983            failures: 0,
984            enabled: true,
985        }
986    }
987    #[allow(dead_code)]
988    pub fn add_i64(&mut self, a: i64, b: i64) -> Option<i64> {
989        self.folds += 1;
990        a.checked_add(b)
991    }
992    #[allow(dead_code)]
993    pub fn sub_i64(&mut self, a: i64, b: i64) -> Option<i64> {
994        self.folds += 1;
995        a.checked_sub(b)
996    }
997    #[allow(dead_code)]
998    pub fn mul_i64(&mut self, a: i64, b: i64) -> Option<i64> {
999        self.folds += 1;
1000        a.checked_mul(b)
1001    }
1002    #[allow(dead_code)]
1003    pub fn div_i64(&mut self, a: i64, b: i64) -> Option<i64> {
1004        if b == 0 {
1005            self.failures += 1;
1006            None
1007        } else {
1008            self.folds += 1;
1009            a.checked_div(b)
1010        }
1011    }
1012    #[allow(dead_code)]
1013    pub fn rem_i64(&mut self, a: i64, b: i64) -> Option<i64> {
1014        if b == 0 {
1015            self.failures += 1;
1016            None
1017        } else {
1018            self.folds += 1;
1019            a.checked_rem(b)
1020        }
1021    }
1022    #[allow(dead_code)]
1023    pub fn neg_i64(&mut self, a: i64) -> Option<i64> {
1024        self.folds += 1;
1025        a.checked_neg()
1026    }
1027    #[allow(dead_code)]
1028    pub fn shl_i64(&mut self, a: i64, s: u32) -> Option<i64> {
1029        if s >= 64 {
1030            self.failures += 1;
1031            None
1032        } else {
1033            self.folds += 1;
1034            a.checked_shl(s)
1035        }
1036    }
1037    #[allow(dead_code)]
1038    pub fn shr_i64(&mut self, a: i64, s: u32) -> Option<i64> {
1039        if s >= 64 {
1040            self.failures += 1;
1041            None
1042        } else {
1043            self.folds += 1;
1044            a.checked_shr(s)
1045        }
1046    }
1047    #[allow(dead_code)]
1048    pub fn and_i64(&mut self, a: i64, b: i64) -> i64 {
1049        self.folds += 1;
1050        a & b
1051    }
1052    #[allow(dead_code)]
1053    pub fn or_i64(&mut self, a: i64, b: i64) -> i64 {
1054        self.folds += 1;
1055        a | b
1056    }
1057    #[allow(dead_code)]
1058    pub fn xor_i64(&mut self, a: i64, b: i64) -> i64 {
1059        self.folds += 1;
1060        a ^ b
1061    }
1062    #[allow(dead_code)]
1063    pub fn not_i64(&mut self, a: i64) -> i64 {
1064        self.folds += 1;
1065        !a
1066    }
1067    #[allow(dead_code)]
1068    pub fn fold_count(&self) -> usize {
1069        self.folds
1070    }
1071    #[allow(dead_code)]
1072    pub fn failure_count(&self) -> usize {
1073        self.failures
1074    }
1075    #[allow(dead_code)]
1076    pub fn enable(&mut self) {
1077        self.enabled = true;
1078    }
1079    #[allow(dead_code)]
1080    pub fn disable(&mut self) {
1081        self.enabled = false;
1082    }
1083    #[allow(dead_code)]
1084    pub fn is_enabled(&self) -> bool {
1085        self.enabled
1086    }
1087}
/// Pass execution phase for CUDAExt.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CUDAExtPassPhase {
    Early,
    Middle,
    Late,
    Finalize,
}
impl CUDAExtPassPhase {
    /// `true` for the `Early` phase.
    #[allow(dead_code)]
    pub fn is_early(&self) -> bool {
        *self == Self::Early
    }
    /// `true` for the `Middle` phase.
    #[allow(dead_code)]
    pub fn is_middle(&self) -> bool {
        *self == Self::Middle
    }
    /// `true` for the `Late` phase.
    #[allow(dead_code)]
    pub fn is_late(&self) -> bool {
        *self == Self::Late
    }
    /// `true` for the `Finalize` phase.
    #[allow(dead_code)]
    pub fn is_finalize(&self) -> bool {
        *self == Self::Finalize
    }
    /// Numeric position of this phase in the pipeline (0 = first).
    #[allow(dead_code)]
    pub fn order(&self) -> u32 {
        match self {
            CUDAExtPassPhase::Early => 0,
            CUDAExtPassPhase::Middle => 1,
            CUDAExtPassPhase::Late => 2,
            CUDAExtPassPhase::Finalize => 3,
        }
    }
    /// Inverse of [`order`](Self::order); `None` for out-of-range values.
    #[allow(dead_code)]
    pub fn from_order(n: u32) -> Option<Self> {
        [Self::Early, Self::Middle, Self::Late, Self::Finalize]
            .into_iter()
            .nth(n as usize)
    }
}
/// Kind of `cudaMemcpy` transfer.
///
/// Each variant corresponds to one `cudaMemcpyKind` enumerator in the
/// CUDA runtime API, named in its doc comment.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemcpyKind {
    /// `cudaMemcpyHostToDevice`
    HostToDevice,
    /// `cudaMemcpyDeviceToHost`
    DeviceToHost,
    /// `cudaMemcpyDeviceToDevice`
    DeviceToDevice,
    /// `cudaMemcpyHostToHost`
    HostToHost,
}
/// Analysis cache for CUDAExt.
///
/// A small linear-scan cache: entries live in insertion order in a `Vec`
/// and are matched by key on lookup.
#[allow(dead_code)]
#[derive(Debug)]
pub struct CUDAExtCache {
    /// Each entry is `(key, payload, valid flag, per-entry hit count)`.
    pub(super) entries: Vec<(u64, Vec<u8>, bool, u32)>,
    /// Capacity enforced by `put`.
    pub(super) cap: usize,
    /// Lifetime hit counter across all lookups.
    pub(super) total_hits: u64,
    /// Lifetime miss counter across all lookups.
    pub(super) total_misses: u64,
}
1155impl CUDAExtCache {
1156    #[allow(dead_code)]
1157    pub fn new(cap: usize) -> Self {
1158        Self {
1159            entries: Vec::new(),
1160            cap,
1161            total_hits: 0,
1162            total_misses: 0,
1163        }
1164    }
1165    #[allow(dead_code)]
1166    pub fn get(&mut self, key: u64) -> Option<&[u8]> {
1167        for e in self.entries.iter_mut() {
1168            if e.0 == key && e.2 {
1169                e.3 += 1;
1170                self.total_hits += 1;
1171                return Some(&e.1);
1172            }
1173        }
1174        self.total_misses += 1;
1175        None
1176    }
1177    #[allow(dead_code)]
1178    pub fn put(&mut self, key: u64, data: Vec<u8>) {
1179        if self.entries.len() >= self.cap {
1180            self.entries.retain(|e| e.2);
1181            if self.entries.len() >= self.cap {
1182                self.entries.remove(0);
1183            }
1184        }
1185        self.entries.push((key, data, true, 0));
1186    }
1187    #[allow(dead_code)]
1188    pub fn invalidate(&mut self) {
1189        for e in self.entries.iter_mut() {
1190            e.2 = false;
1191        }
1192    }
1193    #[allow(dead_code)]
1194    pub fn hit_rate(&self) -> f64 {
1195        let t = self.total_hits + self.total_misses;
1196        if t == 0 {
1197            0.0
1198        } else {
1199            self.total_hits as f64 / t as f64
1200        }
1201    }
1202    #[allow(dead_code)]
1203    pub fn live_count(&self) -> usize {
1204        self.entries.iter().filter(|e| e.2).count()
1205    }
1206}
/// Liveness analysis for CUDAExt.
///
/// Per-basic-block live-in / live-out / def / use sets, stored as plain
/// vectors of variable indices (no dedicated set type).
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct CUDAExtLiveness {
    pub live_in: Vec<Vec<usize>>,
    pub live_out: Vec<Vec<usize>>,
    pub defs: Vec<Vec<usize>>,
    pub uses: Vec<Vec<usize>>,
}
impl CUDAExtLiveness {
    /// Allocate empty sets for `n` basic blocks.
    #[allow(dead_code)]
    pub fn new(n: usize) -> Self {
        Self {
            live_in: vec![Vec::new(); n],
            live_out: vec![Vec::new(); n],
            defs: vec![Vec::new(); n],
            uses: vec![Vec::new(); n],
        }
    }
    /// Membership test shared by the query methods: is `v` recorded in
    /// block `b` of `sets`? Out-of-range blocks answer `false`.
    fn set_contains(sets: &[Vec<usize>], b: usize, v: usize) -> bool {
        sets.get(b).map(|s| s.contains(&v)).unwrap_or(false)
    }
    /// Record `v` into block `b` of `sets`, ignoring duplicates and
    /// out-of-range block indices.
    fn set_insert(sets: &mut [Vec<usize>], b: usize, v: usize) {
        if let Some(s) = sets.get_mut(b) {
            if !s.contains(&v) {
                s.push(v);
            }
        }
    }
    /// Is `v` live on entry to block `b`?
    #[allow(dead_code)]
    pub fn live_in(&self, b: usize, v: usize) -> bool {
        Self::set_contains(&self.live_in, b, v)
    }
    /// Is `v` live on exit from block `b`?
    #[allow(dead_code)]
    pub fn live_out(&self, b: usize, v: usize) -> bool {
        Self::set_contains(&self.live_out, b, v)
    }
    /// Record a definition of `v` in block `b`.
    #[allow(dead_code)]
    pub fn add_def(&mut self, b: usize, v: usize) {
        Self::set_insert(&mut self.defs, b, v);
    }
    /// Record a use of `v` in block `b`.
    #[allow(dead_code)]
    pub fn add_use(&mut self, b: usize, v: usize) {
        Self::set_insert(&mut self.uses, b, v);
    }
    /// Was `v` used in block `b`?
    #[allow(dead_code)]
    pub fn var_is_used_in_block(&self, b: usize, v: usize) -> bool {
        Self::set_contains(&self.uses, b, v)
    }
    /// Was `v` defined in block `b`?
    #[allow(dead_code)]
    pub fn var_is_def_in_block(&self, b: usize, v: usize) -> bool {
        Self::set_contains(&self.defs, b, v)
    }
}
/// Registry of CUDA optimization passes and their accumulated statistics.
#[allow(dead_code)]
pub struct CUDAPassRegistry {
    /// Registered pass configurations, in registration order.
    pub(super) configs: Vec<CUDAPassConfig>,
    /// Per-pass stats, keyed by `CUDAPassConfig::pass_name`.
    pub(super) stats: std::collections::HashMap<String, CUDAPassStats>,
}
1267impl CUDAPassRegistry {
1268    #[allow(dead_code)]
1269    pub fn new() -> Self {
1270        CUDAPassRegistry {
1271            configs: Vec::new(),
1272            stats: std::collections::HashMap::new(),
1273        }
1274    }
1275    #[allow(dead_code)]
1276    pub fn register(&mut self, config: CUDAPassConfig) {
1277        self.stats
1278            .insert(config.pass_name.clone(), CUDAPassStats::new());
1279        self.configs.push(config);
1280    }
1281    #[allow(dead_code)]
1282    pub fn enabled_passes(&self) -> Vec<&CUDAPassConfig> {
1283        self.configs.iter().filter(|c| c.enabled).collect()
1284    }
1285    #[allow(dead_code)]
1286    pub fn get_stats(&self, name: &str) -> Option<&CUDAPassStats> {
1287        self.stats.get(name)
1288    }
1289    #[allow(dead_code)]
1290    pub fn total_passes(&self) -> usize {
1291        self.configs.len()
1292    }
1293    #[allow(dead_code)]
1294    pub fn enabled_count(&self) -> usize {
1295        self.enabled_passes().len()
1296    }
1297    #[allow(dead_code)]
1298    pub fn update_stats(&mut self, name: &str, changes: u64, time_ms: u64, iter: u32) {
1299        if let Some(stats) = self.stats.get_mut(name) {
1300            stats.record_run(changes, time_ms, iter);
1301        }
1302    }
1303}
/// Size-bounded analysis cache keyed by string.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct CUDAAnalysisCache {
    /// Cached entries, keyed by the same string stored in
    /// `CUDACacheEntry::key`.
    pub(super) entries: std::collections::HashMap<String, CUDACacheEntry>,
    /// Maximum number of entries before `insert` evicts.
    pub(super) max_size: usize,
    /// Lifetime lookup hit counter.
    pub(super) hits: u64,
    /// Lifetime lookup miss counter.
    pub(super) misses: u64,
}
1312impl CUDAAnalysisCache {
1313    #[allow(dead_code)]
1314    pub fn new(max_size: usize) -> Self {
1315        CUDAAnalysisCache {
1316            entries: std::collections::HashMap::new(),
1317            max_size,
1318            hits: 0,
1319            misses: 0,
1320        }
1321    }
1322    #[allow(dead_code)]
1323    pub fn get(&mut self, key: &str) -> Option<&CUDACacheEntry> {
1324        if self.entries.contains_key(key) {
1325            self.hits += 1;
1326            self.entries.get(key)
1327        } else {
1328            self.misses += 1;
1329            None
1330        }
1331    }
1332    #[allow(dead_code)]
1333    pub fn insert(&mut self, key: String, data: Vec<u8>) {
1334        if self.entries.len() >= self.max_size {
1335            if let Some(oldest) = self.entries.keys().next().cloned() {
1336                self.entries.remove(&oldest);
1337            }
1338        }
1339        self.entries.insert(
1340            key.clone(),
1341            CUDACacheEntry {
1342                key,
1343                data,
1344                timestamp: 0,
1345                valid: true,
1346            },
1347        );
1348    }
1349    #[allow(dead_code)]
1350    pub fn invalidate(&mut self, key: &str) {
1351        if let Some(entry) = self.entries.get_mut(key) {
1352            entry.valid = false;
1353        }
1354    }
1355    #[allow(dead_code)]
1356    pub fn clear(&mut self) {
1357        self.entries.clear();
1358    }
1359    #[allow(dead_code)]
1360    pub fn hit_rate(&self) -> f64 {
1361        let total = self.hits + self.misses;
1362        if total == 0 {
1363            return 0.0;
1364        }
1365        self.hits as f64 / total as f64
1366    }
1367    #[allow(dead_code)]
1368    pub fn size(&self) -> usize {
1369        self.entries.len()
1370    }
1371}
/// Dependency graph for CUDAExt.
///
/// Fixed node count with forward and reverse adjacency lists kept in
/// sync by `add_edge`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct CUDAExtDepGraph {
    /// Number of nodes; valid node ids are `0..n`.
    pub(super) n: usize,
    /// Forward adjacency: `adj[u]` lists successors of `u`.
    pub(super) adj: Vec<Vec<usize>>,
    /// Reverse adjacency: `rev[v]` lists predecessors of `v`.
    pub(super) rev: Vec<Vec<usize>>,
    /// Count of unique edges added.
    pub(super) edge_count: usize,
}
impl CUDAExtDepGraph {
    /// Create a graph with `n` nodes and no edges.
    #[allow(dead_code)]
    pub fn new(n: usize) -> Self {
        Self {
            n,
            adj: vec![Vec::new(); n],
            rev: vec![Vec::new(); n],
            edge_count: 0,
        }
    }
    /// Add the directed edge `from -> to`.
    ///
    /// Out-of-range endpoints and duplicate edges are silently ignored,
    /// so `adj`/`rev` never contain parallel edges.
    #[allow(dead_code)]
    pub fn add_edge(&mut self, from: usize, to: usize) {
        if from < self.n && to < self.n {
            if !self.adj[from].contains(&to) {
                self.adj[from].push(to);
                self.rev[to].push(from);
                self.edge_count += 1;
            }
        }
    }
    /// Successors of `n` (empty slice when out of range).
    #[allow(dead_code)]
    pub fn succs(&self, n: usize) -> &[usize] {
        self.adj.get(n).map(|v| v.as_slice()).unwrap_or(&[])
    }
    /// Predecessors of `n` (empty slice when out of range).
    #[allow(dead_code)]
    pub fn preds(&self, n: usize) -> &[usize] {
        self.rev.get(n).map(|v| v.as_slice()).unwrap_or(&[])
    }
    /// Kahn's-algorithm topological sort; `None` if the graph is cyclic.
    #[allow(dead_code)]
    pub fn topo_sort(&self) -> Option<Vec<usize>> {
        // In-degree of each node, seeded from the reverse adjacency lists.
        let mut deg: Vec<usize> = (0..self.n).map(|i| self.rev[i].len()).collect();
        let mut q: std::collections::VecDeque<usize> =
            (0..self.n).filter(|&i| deg[i] == 0).collect();
        let mut out = Vec::with_capacity(self.n);
        while let Some(u) = q.pop_front() {
            out.push(u);
            for &v in &self.adj[u] {
                deg[v] -= 1;
                if deg[v] == 0 {
                    q.push_back(v);
                }
            }
        }
        // Nodes on cycles never reach in-degree 0, so a short result
        // means the graph is cyclic.
        if out.len() == self.n {
            Some(out)
        } else {
            None
        }
    }
    /// `true` iff the graph contains a directed cycle.
    #[allow(dead_code)]
    pub fn has_cycle(&self) -> bool {
        self.topo_sort().is_none()
    }
    /// All nodes reachable from `start` (including `start` itself when it
    /// is in range), discovered by iterative DFS.
    #[allow(dead_code)]
    pub fn reachable(&self, start: usize) -> Vec<usize> {
        let mut vis = vec![false; self.n];
        let mut stk = vec![start];
        let mut out = Vec::new();
        while let Some(u) = stk.pop() {
            if u < self.n && !vis[u] {
                vis[u] = true;
                out.push(u);
                for &v in &self.adj[u] {
                    if !vis[v] {
                        stk.push(v);
                    }
                }
            }
        }
        out
    }
    /// Strongly connected components via an iterative Kosaraju-style
    /// two-pass algorithm; each inner `Vec` holds one component's nodes.
    #[allow(dead_code)]
    pub fn scc(&self) -> Vec<Vec<usize>> {
        // Pass 1: explicit-stack DFS over the forward graph computing a
        // post-order (`order`). Each frame is (node, next-edge-index);
        // a node is emitted only after all its edges are exhausted.
        let mut visited = vec![false; self.n];
        let mut order = Vec::new();
        for i in 0..self.n {
            if !visited[i] {
                let mut stk = vec![(i, 0usize)];
                while let Some((u, idx)) = stk.last_mut() {
                    if !visited[*u] {
                        visited[*u] = true;
                    }
                    if *idx < self.adj[*u].len() {
                        // Advance this frame's edge cursor, then descend
                        // into the child if it is still unvisited.
                        let v = self.adj[*u][*idx];
                        *idx += 1;
                        if !visited[v] {
                            stk.push((v, 0));
                        }
                    } else {
                        // All edges done: emit in post-order and pop.
                        order.push(*u);
                        stk.pop();
                    }
                }
            }
        }
        // Pass 2: sweep nodes in reverse post-order, flood-filling over
        // the REVERSE graph; each flood fill is exactly one component.
        // `comp[u] == usize::MAX` marks "not yet assigned".
        let mut comp = vec![usize::MAX; self.n];
        let mut components: Vec<Vec<usize>> = Vec::new();
        for &start in order.iter().rev() {
            if comp[start] == usize::MAX {
                let cid = components.len();
                let mut component = Vec::new();
                let mut stk = vec![start];
                while let Some(u) = stk.pop() {
                    if comp[u] == usize::MAX {
                        comp[u] = cid;
                        component.push(u);
                        for &v in &self.rev[u] {
                            if comp[v] == usize::MAX {
                                stk.push(v);
                            }
                        }
                    }
                }
                components.push(component);
            }
        }
        components
    }
    /// Number of nodes.
    #[allow(dead_code)]
    pub fn node_count(&self) -> usize {
        self.n
    }
    /// Number of unique edges.
    #[allow(dead_code)]
    pub fn edge_count(&self) -> usize {
        self.edge_count
    }
}
/// Statistics for CUDAExt passes.
#[allow(dead_code)]
#[derive(Debug, Clone, Default)]
pub struct CUDAExtPassStats {
    pub iterations: usize,
    pub changed: bool,
    pub nodes_visited: usize,
    pub nodes_modified: usize,
    pub time_ms: u64,
    pub memory_bytes: usize,
    pub errors: usize,
}
impl CUDAExtPassStats {
    /// Fresh all-zero statistics.
    #[allow(dead_code)]
    pub fn new() -> Self {
        Self::default()
    }
    /// Count one visited node.
    #[allow(dead_code)]
    pub fn visit(&mut self) {
        self.nodes_visited += 1;
    }
    /// Count one modified node and flag the pass as having made changes.
    #[allow(dead_code)]
    pub fn modify(&mut self) {
        self.nodes_modified += 1;
        self.changed = true;
    }
    /// Count one pass iteration.
    #[allow(dead_code)]
    pub fn iterate(&mut self) {
        self.iterations += 1;
    }
    /// Count one error.
    #[allow(dead_code)]
    pub fn error(&mut self) {
        self.errors += 1;
    }
    /// Modified-to-visited ratio; 0.0 when nothing was visited.
    #[allow(dead_code)]
    pub fn efficiency(&self) -> f64 {
        match self.nodes_visited {
            0 => 0.0,
            visited => self.nodes_modified as f64 / visited as f64,
        }
    }
    /// Accumulate another run's stats into this one: counters add up,
    /// `changed` ORs, and `memory_bytes` keeps the peak.
    #[allow(dead_code)]
    pub fn merge(&mut self, o: &CUDAExtPassStats) {
        self.iterations += o.iterations;
        self.changed = self.changed || o.changed;
        self.nodes_visited += o.nodes_visited;
        self.nodes_modified += o.nodes_modified;
        self.time_ms += o.time_ms;
        if o.memory_bytes > self.memory_bytes {
            self.memory_bytes = o.memory_bytes;
        }
        self.errors += o.errors;
    }
}
/// Lifecycle phase of a CUDA pass.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub enum CUDAPassPhase {
    Analysis,
    Transformation,
    Verification,
    Cleanup,
}
impl CUDAPassPhase {
    /// Lower-case display name of the phase.
    #[allow(dead_code)]
    pub fn name(&self) -> &str {
        match self {
            Self::Analysis => "analysis",
            Self::Transformation => "transformation",
            Self::Verification => "verification",
            Self::Cleanup => "cleanup",
        }
    }
    /// Whether passes in this phase may mutate the IR
    /// (transformation and cleanup phases).
    #[allow(dead_code)]
    pub fn is_modifying(&self) -> bool {
        match self {
            Self::Transformation | Self::Cleanup => true,
            Self::Analysis | Self::Verification => false,
        }
    }
}
/// Configuration for a single CUDA pass.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct CUDAPassConfig {
    /// Pipeline phase this pass runs in.
    pub phase: CUDAPassPhase,
    /// Whether the pass should run at all (default `true`, see `new`).
    pub enabled: bool,
    /// Iteration cap for the pass (default 10, see `new`).
    pub max_iterations: u32,
    /// Emit debug output while running (default `false`).
    pub debug_output: bool,
    /// Pass name; also the key into `CUDAPassRegistry::stats`.
    pub pass_name: String,
}
1593impl CUDAPassConfig {
1594    #[allow(dead_code)]
1595    pub fn new(name: impl Into<String>, phase: CUDAPassPhase) -> Self {
1596        CUDAPassConfig {
1597            phase,
1598            enabled: true,
1599            max_iterations: 10,
1600            debug_output: false,
1601            pass_name: name.into(),
1602        }
1603    }
1604    #[allow(dead_code)]
1605    pub fn disabled(mut self) -> Self {
1606        self.enabled = false;
1607        self
1608    }
1609    #[allow(dead_code)]
1610    pub fn with_debug(mut self) -> Self {
1611        self.debug_output = true;
1612        self
1613    }
1614    #[allow(dead_code)]
1615    pub fn max_iter(mut self, n: u32) -> Self {
1616        self.max_iterations = n;
1617        self
1618    }
1619}
/// Dependency graph over `u32` node ids, stored as an edge list.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct CUDADepGraph {
    /// All known node ids, in insertion order (no duplicates).
    pub(super) nodes: Vec<u32>,
    /// Directed edges `(dependency, dependent)`.
    pub(super) edges: Vec<(u32, u32)>,
}
1626impl CUDADepGraph {
1627    #[allow(dead_code)]
1628    pub fn new() -> Self {
1629        CUDADepGraph {
1630            nodes: Vec::new(),
1631            edges: Vec::new(),
1632        }
1633    }
1634    #[allow(dead_code)]
1635    pub fn add_node(&mut self, id: u32) {
1636        if !self.nodes.contains(&id) {
1637            self.nodes.push(id);
1638        }
1639    }
1640    #[allow(dead_code)]
1641    pub fn add_dep(&mut self, dep: u32, dependent: u32) {
1642        self.add_node(dep);
1643        self.add_node(dependent);
1644        self.edges.push((dep, dependent));
1645    }
1646    #[allow(dead_code)]
1647    pub fn dependents_of(&self, node: u32) -> Vec<u32> {
1648        self.edges
1649            .iter()
1650            .filter(|(d, _)| *d == node)
1651            .map(|(_, dep)| *dep)
1652            .collect()
1653    }
1654    #[allow(dead_code)]
1655    pub fn dependencies_of(&self, node: u32) -> Vec<u32> {
1656        self.edges
1657            .iter()
1658            .filter(|(_, dep)| *dep == node)
1659            .map(|(d, _)| *d)
1660            .collect()
1661    }
1662    #[allow(dead_code)]
1663    pub fn topological_sort(&self) -> Vec<u32> {
1664        let mut in_degree: std::collections::HashMap<u32, u32> = std::collections::HashMap::new();
1665        for &n in &self.nodes {
1666            in_degree.insert(n, 0);
1667        }
1668        for (_, dep) in &self.edges {
1669            *in_degree.entry(*dep).or_insert(0) += 1;
1670        }
1671        let mut queue: std::collections::VecDeque<u32> = self
1672            .nodes
1673            .iter()
1674            .filter(|&&n| in_degree[&n] == 0)
1675            .copied()
1676            .collect();
1677        let mut result = Vec::new();
1678        while let Some(node) = queue.pop_front() {
1679            result.push(node);
1680            for dep in self.dependents_of(node) {
1681                let cnt = in_degree.entry(dep).or_insert(0);
1682                *cnt = cnt.saturating_sub(1);
1683                if *cnt == 0 {
1684                    queue.push_back(dep);
1685                }
1686            }
1687        }
1688        result
1689    }
1690    #[allow(dead_code)]
1691    pub fn has_cycle(&self) -> bool {
1692        self.topological_sort().len() < self.nodes.len()
1693    }
1694}
/// Stateless constant-folding helpers, grouped on a unit struct.
///
/// All fallible integer folds return `None` instead of overflowing or
/// panicking.
#[allow(dead_code)]
pub struct CUDAConstantFoldingHelper;
impl CUDAConstantFoldingHelper {
    /// `a + b`; `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_add_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_add(b)
    }
    /// `a - b`; `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_sub_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_sub(b)
    }
    /// `a * b`; `None` on overflow.
    #[allow(dead_code)]
    pub fn fold_mul_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_mul(b)
    }
    /// `a / b`; `None` on division by zero or `i64::MIN / -1` overflow.
    #[allow(dead_code)]
    pub fn fold_div_i64(a: i64, b: i64) -> Option<i64> {
        // checked_div covers both b == 0 and the MIN / -1 overflow case,
        // so no separate zero check is needed.
        a.checked_div(b)
    }
    /// `a + b` in IEEE-754 `f64` (never fails).
    #[allow(dead_code)]
    pub fn fold_add_f64(a: f64, b: f64) -> f64 {
        a + b
    }
    /// `a * b` in IEEE-754 `f64` (never fails).
    #[allow(dead_code)]
    pub fn fold_mul_f64(a: f64, b: f64) -> f64 {
        a * b
    }
    /// `-a`; `None` for `i64::MIN`.
    #[allow(dead_code)]
    pub fn fold_neg_i64(a: i64) -> Option<i64> {
        a.checked_neg()
    }
    /// Logical NOT.
    #[allow(dead_code)]
    pub fn fold_not_bool(a: bool) -> bool {
        !a
    }
    /// Logical AND.
    #[allow(dead_code)]
    pub fn fold_and_bool(a: bool, b: bool) -> bool {
        a && b
    }
    /// Logical OR.
    #[allow(dead_code)]
    pub fn fold_or_bool(a: bool, b: bool) -> bool {
        a || b
    }
    /// `a << b`; `None` when `b >= 64`.
    #[allow(dead_code)]
    pub fn fold_shl_i64(a: i64, b: u32) -> Option<i64> {
        a.checked_shl(b)
    }
    /// `a >> b` (arithmetic shift); `None` when `b >= 64`.
    #[allow(dead_code)]
    pub fn fold_shr_i64(a: i64, b: u32) -> Option<i64> {
        a.checked_shr(b)
    }
    /// `a % b`; `None` on division by zero or `i64::MIN % -1` overflow.
    ///
    /// Previously this computed `a % b` directly, which panics on
    /// `(i64::MIN, -1)` (remainder overflow); `checked_rem` reports that
    /// case as `None`, consistent with `fold_div_i64`.
    #[allow(dead_code)]
    pub fn fold_rem_i64(a: i64, b: i64) -> Option<i64> {
        a.checked_rem(b)
    }
    /// Bitwise AND.
    #[allow(dead_code)]
    pub fn fold_bitand_i64(a: i64, b: i64) -> i64 {
        a & b
    }
    /// Bitwise OR.
    #[allow(dead_code)]
    pub fn fold_bitor_i64(a: i64, b: i64) -> i64 {
        a | b
    }
    /// Bitwise XOR.
    #[allow(dead_code)]
    pub fn fold_bitxor_i64(a: i64, b: i64) -> i64 {
        a ^ b
    }
    /// Bitwise NOT.
    #[allow(dead_code)]
    pub fn fold_bitnot_i64(a: i64) -> i64 {
        !a
    }
}
/// Configuration for CUDAExt passes.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct CUDAExtPassConfig {
    /// Pass name.
    pub name: String,
    /// Pipeline phase (defaults to `Middle`, see `new`).
    pub phase: CUDAExtPassPhase,
    /// Whether the pass should run (default `true`).
    pub enabled: bool,
    /// Iteration cap (default 100, see `new`).
    pub max_iterations: usize,
    /// Debug verbosity level; 0 means disabled (see `is_debug_enabled`).
    pub debug: u32,
    /// Optional wall-clock budget in milliseconds.
    pub timeout_ms: Option<u64>,
}
1786impl CUDAExtPassConfig {
1787    #[allow(dead_code)]
1788    pub fn new(name: impl Into<String>) -> Self {
1789        Self {
1790            name: name.into(),
1791            phase: CUDAExtPassPhase::Middle,
1792            enabled: true,
1793            max_iterations: 100,
1794            debug: 0,
1795            timeout_ms: None,
1796        }
1797    }
1798    #[allow(dead_code)]
1799    pub fn with_phase(mut self, phase: CUDAExtPassPhase) -> Self {
1800        self.phase = phase;
1801        self
1802    }
1803    #[allow(dead_code)]
1804    pub fn with_max_iter(mut self, n: usize) -> Self {
1805        self.max_iterations = n;
1806        self
1807    }
1808    #[allow(dead_code)]
1809    pub fn with_debug(mut self, d: u32) -> Self {
1810        self.debug = d;
1811        self
1812    }
1813    #[allow(dead_code)]
1814    pub fn disabled(mut self) -> Self {
1815        self.enabled = false;
1816        self
1817    }
1818    #[allow(dead_code)]
1819    pub fn with_timeout(mut self, ms: u64) -> Self {
1820        self.timeout_ms = Some(ms);
1821        self
1822    }
1823    #[allow(dead_code)]
1824    pub fn is_debug_enabled(&self) -> bool {
1825        self.debug > 0
1826    }
1827}
/// CUDA function / variable qualifiers.
///
/// Consumed e.g. by `CudaParam::emit` and `DeviceFunction::qualifiers`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CudaQualifier {
    /// `__global__` — kernel callable from host, runs on device
    Global,
    /// `__device__` — callable/usable only on device
    Device,
    /// `__host__` — callable only from host (default)
    Host,
    /// `__shared__` — shared memory within a thread block
    Shared,
    /// `__constant__` — read-only constant memory
    Constant,
    /// `__managed__` — accessible from both host and device
    Managed,
    /// `__restrict__` — pointer alias hint
    Restrict,
    /// `volatile` — volatile memory access
    Volatile,
}
/// Optional launch-bounds hint: `__launch_bounds__(maxThreads[, minBlocks])`.
///
/// Rendered by `LaunchBounds::emit`; attached to kernels via
/// `CudaKernel::launch_bounds`.
#[derive(Debug, Clone, PartialEq)]
pub struct LaunchBounds {
    /// Maximum threads per block
    pub max_threads: u32,
    /// Minimum blocks per multiprocessor (optional)
    pub min_blocks: Option<u32>,
}
1856impl LaunchBounds {
1857    /// Create launch bounds with only a max-thread count.
1858    pub fn new(max_threads: u32) -> Self {
1859        LaunchBounds {
1860            max_threads,
1861            min_blocks: None,
1862        }
1863    }
1864    /// Create launch bounds with both max-threads and min-blocks.
1865    pub fn with_min_blocks(max_threads: u32, min_blocks: u32) -> Self {
1866        LaunchBounds {
1867            max_threads,
1868            min_blocks: Some(min_blocks),
1869        }
1870    }
1871    pub(super) fn emit(&self) -> String {
1872        match self.min_blocks {
1873            Some(mb) => format!("__launch_bounds__({}, {})", self.max_threads, mb),
1874            None => format!("__launch_bounds__({})", self.max_threads),
1875        }
1876    }
1877}
/// Dominator tree for CUDAExt.
///
/// Built incrementally via `set_idom`; queried with `dominates` / `lca`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct CUDAExtDomTree {
    /// Immediate dominator of each node (`None` for roots / unset nodes).
    pub(super) idom: Vec<Option<usize>>,
    /// Dominator-tree children of each node.
    pub(super) children: Vec<Vec<usize>>,
    /// Depth in the dominator tree (roots at 0).
    pub(super) depth: Vec<usize>,
}
impl CUDAExtDomTree {
    /// Create a tree for `n` nodes with no dominator links set.
    #[allow(dead_code)]
    pub fn new(n: usize) -> Self {
        Self {
            idom: vec![None; n],
            children: vec![Vec::new(); n],
            depth: vec![0; n],
        }
    }
    /// Set `dom` as the immediate dominator of `node`, updating `dom`'s
    /// children list and `node`'s depth. Out-of-range `node` is ignored.
    ///
    /// NOTE(review): calling this twice for the same `node` leaves the
    /// stale entry in the previous dominator's `children` — confirm
    /// callers set each idom exactly once.
    #[allow(dead_code)]
    pub fn set_idom(&mut self, node: usize, dom: usize) {
        if node < self.idom.len() {
            self.idom[node] = Some(dom);
            if dom < self.children.len() {
                self.children[dom].push(node);
            }
            // Depth is one below the dominator; an out-of-range dominator
            // leaves the node at depth 1.
            self.depth[node] = if dom < self.depth.len() {
                self.depth[dom] + 1
            } else {
                1
            };
        }
    }
    /// Does `a` dominate `b`? Walks `b`'s idom chain upward; the loop is
    /// bounded by the node count so malformed (cyclic) chains terminate.
    #[allow(dead_code)]
    pub fn dominates(&self, a: usize, mut b: usize) -> bool {
        if a == b {
            return true;
        }
        let n = self.idom.len();
        for _ in 0..n {
            match self.idom.get(b).copied().flatten() {
                None => return false,
                Some(p) if p == a => return true,
                // Self-loop in the chain: treat as the top being reached.
                Some(p) if p == b => return false,
                Some(p) => b = p,
            }
        }
        false
    }
    /// Dominator-tree children of `n` (empty slice when out of range).
    #[allow(dead_code)]
    pub fn children_of(&self, n: usize) -> &[usize] {
        self.children.get(n).map(|v| v.as_slice()).unwrap_or(&[])
    }
    /// Depth of `n` in the tree (0 for roots and out-of-range nodes).
    #[allow(dead_code)]
    pub fn depth_of(&self, n: usize) -> usize {
        self.depth.get(n).copied().unwrap_or(0)
    }
    /// Lowest common ancestor of `a` and `b` in the dominator tree.
    ///
    /// Repeatedly lifts the deeper node to its idom (a root without an
    /// idom stays put via `unwrap_or`); the `2 * n` bound guards against
    /// malformed chains, falling back to node 0 on non-convergence.
    #[allow(dead_code)]
    pub fn lca(&self, mut a: usize, mut b: usize) -> usize {
        let n = self.idom.len();
        for _ in 0..(2 * n) {
            if a == b {
                return a;
            }
            if self.depth_of(a) > self.depth_of(b) {
                a = self.idom.get(a).and_then(|x| *x).unwrap_or(a);
            } else {
                b = self.idom.get(b).and_then(|x| *x).unwrap_or(b);
            }
        }
        0
    }
}
/// Binary operators available in CUDA C++ expressions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CudaBinOp {
    /// Addition.
    Add,
    /// Subtraction.
    Sub,
    /// Multiplication.
    Mul,
    /// Division.
    Div,
    /// Remainder.
    Mod,
    /// Equality comparison.
    Eq,
    /// Inequality comparison.
    Neq,
    /// Less-than comparison.
    Lt,
    /// Less-than-or-equal comparison.
    Le,
    /// Greater-than comparison.
    Gt,
    /// Greater-than-or-equal comparison.
    Ge,
    /// Logical AND.
    And,
    /// Logical OR.
    Or,
    /// Bitwise AND.
    BitAnd,
    /// Bitwise OR.
    BitOr,
    /// Bitwise XOR.
    BitXor,
    /// Left shift.
    Shl,
    /// Right shift.
    Shr,
}
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct CUDADominatorTree {
    /// Immediate dominator of each node, indexed by node id
    /// (`None` for the root or unset nodes).
    pub idom: Vec<Option<u32>>,
    /// Dominator-tree children per node.
    /// NOTE(review): never written by the methods visible here;
    /// presumably populated directly by callers — confirm.
    pub dom_children: Vec<Vec<u32>>,
    /// Depth per node, consulted by `depth()`.
    /// NOTE(review): likewise never written after `new` in the visible
    /// code; confirm callers fill it in.
    pub dom_depth: Vec<u32>,
}
1978impl CUDADominatorTree {
1979    #[allow(dead_code)]
1980    pub fn new(size: usize) -> Self {
1981        CUDADominatorTree {
1982            idom: vec![None; size],
1983            dom_children: vec![Vec::new(); size],
1984            dom_depth: vec![0; size],
1985        }
1986    }
1987    #[allow(dead_code)]
1988    pub fn set_idom(&mut self, node: usize, idom: u32) {
1989        self.idom[node] = Some(idom);
1990    }
1991    #[allow(dead_code)]
1992    pub fn dominates(&self, a: usize, b: usize) -> bool {
1993        if a == b {
1994            return true;
1995        }
1996        let mut cur = b;
1997        loop {
1998            match self.idom[cur] {
1999                Some(parent) if parent as usize == a => return true,
2000                Some(parent) if parent as usize == cur => return false,
2001                Some(parent) => cur = parent as usize,
2002                None => return false,
2003            }
2004        }
2005    }
2006    #[allow(dead_code)]
2007    pub fn depth(&self, node: usize) -> u32 {
2008        self.dom_depth.get(node).copied().unwrap_or(0)
2009    }
2010}
/// Pass registry for CUDAExt.
#[allow(dead_code)]
#[derive(Debug, Default)]
pub struct CUDAExtPassRegistry {
    /// Registered pass configurations, in registration order.
    pub(super) configs: Vec<CUDAExtPassConfig>,
    /// Per-pass statistics; `register` pushes to both vectors so `stats`
    /// stays index-aligned with `configs`.
    pub(super) stats: Vec<CUDAExtPassStats>,
}
2018impl CUDAExtPassRegistry {
2019    #[allow(dead_code)]
2020    pub fn new() -> Self {
2021        Self::default()
2022    }
2023    #[allow(dead_code)]
2024    pub fn register(&mut self, c: CUDAExtPassConfig) {
2025        self.stats.push(CUDAExtPassStats::new());
2026        self.configs.push(c);
2027    }
2028    #[allow(dead_code)]
2029    pub fn len(&self) -> usize {
2030        self.configs.len()
2031    }
2032    #[allow(dead_code)]
2033    pub fn is_empty(&self) -> bool {
2034        self.configs.is_empty()
2035    }
2036    #[allow(dead_code)]
2037    pub fn get(&self, i: usize) -> Option<&CUDAExtPassConfig> {
2038        self.configs.get(i)
2039    }
2040    #[allow(dead_code)]
2041    pub fn get_stats(&self, i: usize) -> Option<&CUDAExtPassStats> {
2042        self.stats.get(i)
2043    }
2044    #[allow(dead_code)]
2045    pub fn enabled_passes(&self) -> Vec<&CUDAExtPassConfig> {
2046        self.configs.iter().filter(|c| c.enabled).collect()
2047    }
2048    #[allow(dead_code)]
2049    pub fn passes_in_phase(&self, ph: &CUDAExtPassPhase) -> Vec<&CUDAExtPassConfig> {
2050        self.configs
2051            .iter()
2052            .filter(|c| c.enabled && &c.phase == ph)
2053            .collect()
2054    }
2055    #[allow(dead_code)]
2056    pub fn total_nodes_visited(&self) -> usize {
2057        self.stats.iter().map(|s| s.nodes_visited).sum()
2058    }
2059    #[allow(dead_code)]
2060    pub fn any_changed(&self) -> bool {
2061        self.stats.iter().any(|s| s.changed)
2062    }
2063}