// oxilean_codegen/cuda_backend/functions.rs
1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use super::types::{
6    CUDAAnalysisCache, CUDAConstantFoldingHelper, CUDADepGraph, CUDADominatorTree, CUDAExtCache,
7    CUDAExtConstFolder, CUDAExtDepGraph, CUDAExtDomTree, CUDAExtLiveness, CUDAExtPassConfig,
8    CUDAExtPassPhase, CUDAExtPassRegistry, CUDAExtPassStats, CUDAExtWorklist, CUDALivenessInfo,
9    CUDAPassConfig, CUDAPassPhase, CUDAPassRegistry, CUDAPassStats, CUDAWorklist, CudaBackend,
10    CudaBinOp, CudaExpr, CudaKernel, CudaModule, CudaParam, CudaQualifier, CudaStmt, CudaType,
11    DeviceFunction, LaunchBounds, LaunchConfig, MemcpyKind,
12};
13
#[cfg(test)]
mod tests {
    use super::*;

    /// `CudaType`'s `Display` impl must render the canonical CUDA C spellings,
    /// including qualifier-wrapped forms (`__shared__`, `__constant__`, `__device__`).
    #[test]
    fn test_cuda_type_display() {
        assert_eq!(format!("{}", CudaType::Int), "int");
        assert_eq!(format!("{}", CudaType::Float), "float");
        assert_eq!(format!("{}", CudaType::Double), "double");
        assert_eq!(format!("{}", CudaType::Half), "__half");
        assert_eq!(format!("{}", CudaType::Dim3), "dim3");
        assert_eq!(format!("{}", CudaType::CudaErrorT), "cudaError_t");
        assert_eq!(
            format!("{}", CudaType::Pointer(Box::new(CudaType::Float))),
            "float*"
        );
        assert_eq!(
            format!("{}", CudaType::Shared(Box::new(CudaType::Float))),
            "__shared__ float"
        );
        assert_eq!(
            format!("{}", CudaType::Constant(Box::new(CudaType::Int))),
            "__constant__ int"
        );
        assert_eq!(
            format!("{}", CudaType::Device(Box::new(CudaType::Double))),
            "__device__ double"
        );
    }

    /// Each `CudaQualifier` variant maps to its exact CUDA keyword.
    #[test]
    fn test_cuda_qualifier_display() {
        assert_eq!(format!("{}", CudaQualifier::Global), "__global__");
        assert_eq!(format!("{}", CudaQualifier::Device), "__device__");
        assert_eq!(format!("{}", CudaQualifier::Host), "__host__");
        assert_eq!(format!("{}", CudaQualifier::Shared), "__shared__");
        assert_eq!(format!("{}", CudaQualifier::Constant), "__constant__");
        assert_eq!(format!("{}", CudaQualifier::Managed), "__managed__");
        assert_eq!(format!("{}", CudaQualifier::Restrict), "__restrict__");
        assert_eq!(format!("{}", CudaQualifier::Volatile), "volatile");
    }

    /// Expression emission: builtin index variables, intrinsics, binary ops
    /// (parenthesized), atomics, and C-style casts.
    #[test]
    fn test_cuda_expr_emit() {
        let backend = CudaBackend::new();
        assert_eq!(backend.emit_expr(&CudaExpr::ThreadIdx('x')), "threadIdx.x");
        assert_eq!(backend.emit_expr(&CudaExpr::BlockIdx('y')), "blockIdx.y");
        assert_eq!(backend.emit_expr(&CudaExpr::BlockDim('z')), "blockDim.z");
        assert_eq!(backend.emit_expr(&CudaExpr::GridDim('x')), "gridDim.x");
        assert_eq!(backend.emit_expr(&CudaExpr::SyncThreads), "__syncthreads()");
        assert_eq!(backend.emit_expr(&CudaExpr::WarpSize), "warpSize");
        let add = CudaExpr::BinOp(
            Box::new(CudaExpr::Var("a".into())),
            CudaBinOp::Add,
            Box::new(CudaExpr::Var("b".into())),
        );
        // Binary operations are emitted fully parenthesized.
        assert_eq!(backend.emit_expr(&add), "(a + b)");
        let atomic = CudaExpr::AtomicAdd(
            Box::new(CudaExpr::Var("ptr".into())),
            Box::new(CudaExpr::LitInt(1)),
        );
        assert_eq!(backend.emit_expr(&atomic), "atomicAdd(ptr, 1)");
        let cast = CudaExpr::Cast(CudaType::Float, Box::new(CudaExpr::Var("n".into())));
        assert_eq!(backend.emit_expr(&cast), "((float)n)");
    }

    /// Statement emission: declarations, cudaMalloc/cudaMemcpy wrappers,
    /// `return`, and device synchronization.
    #[test]
    fn test_cuda_stmt_emit() {
        let backend = CudaBackend::new();
        let decl = CudaStmt::VarDecl {
            ty: CudaType::Int,
            name: "idx".into(),
            init: Some(CudaExpr::LitInt(0)),
        };
        assert_eq!(backend.emit_stmt(&decl, 0), "int idx = 0;");
        let malloc = CudaStmt::CudaMalloc {
            ptr: "d_data".into(),
            size: CudaExpr::LitInt(1024),
        };
        let s = backend.emit_stmt(&malloc, 0);
        assert!(s.contains("cudaMalloc"));
        // cudaMalloc takes a void** out-parameter; the emitter must cast.
        assert!(s.contains("(void**)&d_data"));
        assert!(s.contains("1024"));
        let memcpy = CudaStmt::CudaMemcpy {
            dst: CudaExpr::Var("d_data".into()),
            src: CudaExpr::Var("h_data".into()),
            size: CudaExpr::LitInt(1024),
            kind: MemcpyKind::HostToDevice,
        };
        let s = backend.emit_stmt(&memcpy, 0);
        assert!(s.contains("cudaMemcpy"));
        assert!(s.contains("cudaMemcpyHostToDevice"));
        let ret = CudaStmt::Return(Some(CudaExpr::LitInt(0)));
        assert_eq!(backend.emit_stmt(&ret, 0), "return 0;");
        let sync = CudaStmt::DeviceSync;
        assert!(backend
            .emit_stmt(&sync, 0)
            .contains("cudaDeviceSynchronize"));
    }

    /// End-to-end kernel emission for a vector-add kernel built via the
    /// `CudaKernel` builder API, including `__launch_bounds__`.
    #[test]
    fn test_kernel_emit() {
        let backend = CudaBackend::new();
        // idx = blockIdx.x * blockDim.x + threadIdx.x
        let idx_decl = CudaStmt::VarDecl {
            ty: CudaType::Int,
            name: "idx".into(),
            init: Some(CudaExpr::BinOp(
                Box::new(CudaExpr::BinOp(
                    Box::new(CudaExpr::BlockIdx('x')),
                    CudaBinOp::Mul,
                    Box::new(CudaExpr::BlockDim('x')),
                )),
                CudaBinOp::Add,
                Box::new(CudaExpr::ThreadIdx('x')),
            )),
        };
        // if (idx < n) { c[idx] = a[idx] + b[idx]; }
        let guard = CudaStmt::IfElse {
            cond: CudaExpr::BinOp(
                Box::new(CudaExpr::Var("idx".into())),
                CudaBinOp::Lt,
                Box::new(CudaExpr::Var("n".into())),
            ),
            then_body: vec![CudaStmt::Assign {
                lhs: CudaExpr::Index(
                    Box::new(CudaExpr::Var("c".into())),
                    Box::new(CudaExpr::Var("idx".into())),
                ),
                rhs: CudaExpr::BinOp(
                    Box::new(CudaExpr::Index(
                        Box::new(CudaExpr::Var("a".into())),
                        Box::new(CudaExpr::Var("idx".into())),
                    )),
                    CudaBinOp::Add,
                    Box::new(CudaExpr::Index(
                        Box::new(CudaExpr::Var("b".into())),
                        Box::new(CudaExpr::Var("idx".into())),
                    )),
                ),
            }],
            else_body: None,
        };
        let kernel = CudaKernel::new("vec_add")
            .add_param(CudaParam::new(
                CudaType::Pointer(Box::new(CudaType::Float)),
                "a",
            ))
            .add_param(CudaParam::new(
                CudaType::Pointer(Box::new(CudaType::Float)),
                "b",
            ))
            .add_param(CudaParam::new(
                CudaType::Pointer(Box::new(CudaType::Float)),
                "c",
            ))
            .add_param(CudaParam::new(CudaType::Int, "n"))
            .add_stmt(idx_decl)
            .add_stmt(guard)
            .with_launch_bounds(LaunchBounds::new(256));
        let cu = backend.emit_kernel(&kernel);
        assert!(cu.contains("__global__"));
        assert!(cu.contains("vec_add"));
        assert!(cu.contains("__launch_bounds__(256)"));
        assert!(cu.contains("threadIdx.x"));
        assert!(cu.contains("blockIdx.x"));
        assert!(cu.contains("blockDim.x"));
    }

    /// Kernel launch statements use the `<<<grid, block>>>` chevron syntax
    /// and forward the argument list verbatim.
    #[test]
    fn test_kernel_launch_stmt() {
        let backend = CudaBackend::new();
        let config =
            LaunchConfig::simple_1d(CudaExpr::Var("grid".into()), CudaExpr::Var("block".into()));
        let launch = CudaStmt::KernelLaunch {
            name: "my_kernel".into(),
            config,
            args: vec![
                CudaExpr::Var("d_a".into()),
                CudaExpr::Var("d_b".into()),
                CudaExpr::LitInt(1024),
            ],
        };
        let s = backend.emit_stmt(&launch, 0);
        assert!(s.contains("my_kernel<<<"));
        assert!(s.contains("grid"));
        assert!(s.contains("block"));
        assert!(s.contains("d_a"));
        assert!(s.contains("d_b"));
        assert!(s.contains("1024"));
    }

    /// Module emission: runtime include, `__constant__` globals, kernels,
    /// and the CUDA_CHECK error-checking helper.
    #[test]
    fn test_module_emit() {
        let backend = CudaBackend::new();
        let module = CudaModule::new()
            .add_constant(CudaType::Int, "BLOCK_SIZE", Some(CudaExpr::LitInt(256)))
            .add_kernel(CudaKernel::new("dummy_kernel").add_stmt(CudaStmt::Return(None)));
        let cu = backend.emit_module(&module);
        assert!(cu.contains("#include <cuda_runtime.h>"));
        assert!(cu.contains("__constant__ int BLOCK_SIZE = 256;"));
        assert!(cu.contains("__global__"));
        assert!(cu.contains("dummy_kernel"));
        assert!(cu.contains("CUDA_CHECK"));
    }

    /// `DeviceFunction` emission (`__host__ __device__ inline`) plus warp
    /// intrinsics: `__shfl_down_sync`, `__ballot_sync`, `__popc`.
    #[test]
    fn test_device_function_and_warp_intrinsics() {
        let backend = CudaBackend::new();
        let shfl = CudaExpr::ShflDownSync(
            Box::new(CudaExpr::LitInt(0xffffffff)),
            Box::new(CudaExpr::Var("val".into())),
            Box::new(CudaExpr::LitInt(16)),
        );
        let f = DeviceFunction::host_device("warp_reduce_sum", CudaType::Float)
            .with_inline()
            .add_param(CudaParam::new(CudaType::Float, "val"))
            .add_stmt(CudaStmt::Expr(shfl))
            .add_stmt(CudaStmt::Return(Some(CudaExpr::Var("val".into()))));
        let src = backend.emit_device_function(&f);
        assert!(src.contains("__host__"));
        assert!(src.contains("__device__"));
        assert!(src.contains("inline"));
        assert!(src.contains("warp_reduce_sum"));
        assert!(src.contains("__shfl_down_sync"));
        let ballot = CudaExpr::BallotSync(
            Box::new(CudaExpr::LitInt(0xffffffff)),
            Box::new(CudaExpr::Var("pred".into())),
        );
        assert!(backend.emit_expr(&ballot).contains("__ballot_sync"));
        let popc = CudaExpr::Popc(Box::new(CudaExpr::Var("mask".into())));
        assert!(backend.emit_expr(&popc).contains("__popc"));
    }
}
#[cfg(test)]
// Renamed from `CUDA_infra_tests`: module names must be snake_case or rustc
// emits a `non_snake_case` warning. Test modules have no external callers,
// so the rename is safe.
mod cuda_infra_tests {
    use super::*;

    /// New pass configs start enabled; transformation phases report as modifying.
    #[test]
    fn test_pass_config() {
        let config = CUDAPassConfig::new("test_pass", CUDAPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }

    /// Stats aggregate across runs: averages, success rate, and summary text.
    #[test]
    fn test_pass_stats() {
        let mut stats = CUDAPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }

    /// Registry tracks total vs. enabled passes and per-pass stats.
    #[test]
    fn test_pass_registry() {
        let mut reg = CUDAPassRegistry::new();
        reg.register(CUDAPassConfig::new("pass_a", CUDAPassPhase::Analysis));
        reg.register(CUDAPassConfig::new("pass_b", CUDAPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }

    /// Cache hit-rate accounting and invalidation (entry stays, marked invalid).
    #[test]
    fn test_analysis_cache() {
        let mut cache = CUDAAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        // One hit, one miss so far -> 50% hit rate.
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        cache.invalidate("key1");
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }

    /// Worklist deduplicates pushes and pops in FIFO order.
    #[test]
    fn test_worklist() {
        let mut wl = CUDAWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }

    /// Dominance is transitive through idom chains and reflexive.
    #[test]
    fn test_dominator_tree() {
        let mut dt = CUDADominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        assert!(dt.dominates(3, 3));
    }

    /// def/use sets are recorded per block.
    #[test]
    fn test_liveness() {
        let mut liveness = CUDALivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }

    /// Constant folding helpers; division by zero folds to `None`.
    #[test]
    fn test_constant_folding() {
        assert_eq!(CUDAConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(CUDAConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(CUDAConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            CUDAConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(CUDAConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }

    /// Dependency graph: reverse deps, acyclicity, and a valid topological order.
    #[test]
    fn test_dep_graph() {
        let mut g = CUDADepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        // Verify topo order respects every edge by comparing positions.
        let pos: std::collections::HashMap<u32, usize> =
            topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}
#[cfg(test)]
mod cudaext_pass_tests {
    use super::*;

    /// Phases have a fixed ordering Early < Middle < Late < Finalize.
    #[test]
    fn test_cudaext_phase_order() {
        assert_eq!(CUDAExtPassPhase::Early.order(), 0);
        assert_eq!(CUDAExtPassPhase::Middle.order(), 1);
        assert_eq!(CUDAExtPassPhase::Late.order(), 2);
        assert_eq!(CUDAExtPassPhase::Finalize.order(), 3);
        assert!(CUDAExtPassPhase::Early.is_early());
        assert!(!CUDAExtPassPhase::Early.is_late());
    }

    /// Builder-style config setters chain; `disabled()` returns a disabled copy.
    #[test]
    fn test_cudaext_config_builder() {
        let c = CUDAExtPassConfig::new("p")
            .with_phase(CUDAExtPassPhase::Late)
            .with_max_iter(50)
            .with_debug(1);
        assert_eq!(c.name, "p");
        assert_eq!(c.max_iterations, 50);
        assert!(c.is_debug_enabled());
        assert!(c.enabled);
        let c2 = c.disabled();
        assert!(!c2.enabled);
    }

    /// Visit/modify/iterate counters and efficiency = modified / visited.
    #[test]
    fn test_cudaext_stats() {
        let mut s = CUDAExtPassStats::new();
        s.visit();
        s.visit();
        s.modify();
        s.iterate();
        assert_eq!(s.nodes_visited, 2);
        assert_eq!(s.nodes_modified, 1);
        assert!(s.changed);
        assert_eq!(s.iterations, 1);
        let e = s.efficiency();
        assert!((e - 0.5).abs() < 1e-9);
    }

    /// Registry filters by enabled flag and by phase.
    #[test]
    fn test_cudaext_registry() {
        let mut r = CUDAExtPassRegistry::new();
        r.register(CUDAExtPassConfig::new("a").with_phase(CUDAExtPassPhase::Early));
        r.register(CUDAExtPassConfig::new("b").disabled());
        assert_eq!(r.len(), 2);
        assert_eq!(r.enabled_passes().len(), 1);
        assert_eq!(r.passes_in_phase(&CUDAExtPassPhase::Early).len(), 1);
    }

    /// Cache miss-then-hit sequence updates hit rate and live count.
    #[test]
    fn test_cudaext_cache() {
        let mut c = CUDAExtCache::new(4);
        assert!(c.get(99).is_none());
        c.put(99, vec![1, 2, 3]);
        let v = c.get(99).expect("v should be present in map");
        assert_eq!(v, &[1u8, 2, 3]);
        assert!(c.hit_rate() > 0.0);
        assert_eq!(c.live_count(), 1);
    }

    /// Worklist deduplicates; popped items are no longer contained.
    #[test]
    fn test_cudaext_worklist() {
        let mut w = CUDAExtWorklist::new(10);
        w.push(5);
        w.push(3);
        w.push(5);
        assert_eq!(w.len(), 2);
        assert!(w.contains(5));
        let first = w.pop().expect("first should be available to pop");
        assert!(!w.contains(first));
    }

    /// Dominance queries and node depth in the idom tree.
    #[test]
    fn test_cudaext_dom_tree() {
        let mut dt = CUDAExtDomTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        dt.set_idom(4, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 4));
        assert!(!dt.dominates(2, 3));
        assert_eq!(dt.depth_of(3), 2);
    }

    /// Per-block def/use membership queries.
    #[test]
    fn test_cudaext_liveness() {
        let mut lv = CUDAExtLiveness::new(3);
        lv.add_def(0, 1);
        lv.add_use(1, 1);
        assert!(lv.var_is_def_in_block(0, 1));
        assert!(lv.var_is_used_in_block(1, 1));
        assert!(!lv.var_is_def_in_block(1, 1));
    }

    /// Const folder tracks successful folds and failures (div by zero).
    #[test]
    fn test_cudaext_const_folder() {
        let mut cf = CUDAExtConstFolder::new();
        assert_eq!(cf.add_i64(3, 4), Some(7));
        assert_eq!(cf.div_i64(10, 0), None);
        assert_eq!(cf.mul_i64(6, 7), Some(42));
        assert_eq!(cf.and_i64(0b1100, 0b1010), 0b1000);
        assert_eq!(cf.fold_count(), 3);
        assert_eq!(cf.failure_count(), 1);
    }

    /// Linear chain 0->1->2->3: acyclic, unique topo order, full reachability,
    /// and four singleton SCCs.
    #[test]
    fn test_cudaext_dep_graph() {
        let mut g = CUDAExtDepGraph::new(4);
        g.add_edge(0, 1);
        g.add_edge(1, 2);
        g.add_edge(2, 3);
        assert!(!g.has_cycle());
        assert_eq!(g.topo_sort(), Some(vec![0, 1, 2, 3]));
        assert_eq!(g.reachable(0).len(), 4);
        let sccs = g.scc();
        assert_eq!(sccs.len(), 4);
    }
}