1use super::types::{
6 CUDAAnalysisCache, CUDAConstantFoldingHelper, CUDADepGraph, CUDADominatorTree, CUDAExtCache,
7 CUDAExtConstFolder, CUDAExtDepGraph, CUDAExtDomTree, CUDAExtLiveness, CUDAExtPassConfig,
8 CUDAExtPassPhase, CUDAExtPassRegistry, CUDAExtPassStats, CUDAExtWorklist, CUDALivenessInfo,
9 CUDAPassConfig, CUDAPassPhase, CUDAPassRegistry, CUDAPassStats, CUDAWorklist, CudaBackend,
10 CudaBinOp, CudaExpr, CudaKernel, CudaModule, CudaParam, CudaQualifier, CudaStmt, CudaType,
11 DeviceFunction, LaunchBounds, LaunchConfig, MemcpyKind,
12};
13
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar, builtin, pointer, and qualified types must render as their
    /// exact CUDA C++ spellings (`__half`, `dim3`, `__shared__ float`, ...).
    #[test]
    fn test_cuda_type_display() {
        assert_eq!(format!("{}", CudaType::Int), "int");
        assert_eq!(format!("{}", CudaType::Float), "float");
        assert_eq!(format!("{}", CudaType::Double), "double");
        assert_eq!(format!("{}", CudaType::Half), "__half");
        assert_eq!(format!("{}", CudaType::Dim3), "dim3");
        assert_eq!(format!("{}", CudaType::CudaErrorT), "cudaError_t");
        assert_eq!(
            format!("{}", CudaType::Pointer(Box::new(CudaType::Float))),
            "float*"
        );
        assert_eq!(
            format!("{}", CudaType::Shared(Box::new(CudaType::Float))),
            "__shared__ float"
        );
        assert_eq!(
            format!("{}", CudaType::Constant(Box::new(CudaType::Int))),
            "__constant__ int"
        );
        assert_eq!(
            format!("{}", CudaType::Device(Box::new(CudaType::Double))),
            "__device__ double"
        );
    }

    /// Every qualifier variant prints the matching CUDA keyword
    /// (double-underscore decorated, except plain `volatile`).
    #[test]
    fn test_cuda_qualifier_display() {
        assert_eq!(format!("{}", CudaQualifier::Global), "__global__");
        assert_eq!(format!("{}", CudaQualifier::Device), "__device__");
        assert_eq!(format!("{}", CudaQualifier::Host), "__host__");
        assert_eq!(format!("{}", CudaQualifier::Shared), "__shared__");
        assert_eq!(format!("{}", CudaQualifier::Constant), "__constant__");
        assert_eq!(format!("{}", CudaQualifier::Managed), "__managed__");
        assert_eq!(format!("{}", CudaQualifier::Restrict), "__restrict__");
        assert_eq!(format!("{}", CudaQualifier::Volatile), "volatile");
    }

    /// Expression emission: builtin index variables, sync/warp intrinsics,
    /// parenthesized binary ops, atomics, and C-style casts.
    #[test]
    fn test_cuda_expr_emit() {
        let backend = CudaBackend::new();
        assert_eq!(backend.emit_expr(&CudaExpr::ThreadIdx('x')), "threadIdx.x");
        assert_eq!(backend.emit_expr(&CudaExpr::BlockIdx('y')), "blockIdx.y");
        assert_eq!(backend.emit_expr(&CudaExpr::BlockDim('z')), "blockDim.z");
        assert_eq!(backend.emit_expr(&CudaExpr::GridDim('x')), "gridDim.x");
        assert_eq!(backend.emit_expr(&CudaExpr::SyncThreads), "__syncthreads()");
        assert_eq!(backend.emit_expr(&CudaExpr::WarpSize), "warpSize");
        // Binary ops are fully parenthesized to avoid precedence surprises.
        let add = CudaExpr::BinOp(
            Box::new(CudaExpr::Var("a".into())),
            CudaBinOp::Add,
            Box::new(CudaExpr::Var("b".into())),
        );
        assert_eq!(backend.emit_expr(&add), "(a + b)");
        let atomic = CudaExpr::AtomicAdd(
            Box::new(CudaExpr::Var("ptr".into())),
            Box::new(CudaExpr::LitInt(1)),
        );
        assert_eq!(backend.emit_expr(&atomic), "atomicAdd(ptr, 1)");
        // Casts emit the double-parenthesized C form: ((type)expr).
        let cast = CudaExpr::Cast(CudaType::Float, Box::new(CudaExpr::Var("n".into())));
        assert_eq!(backend.emit_expr(&cast), "((float)n)");
    }

    /// Statement emission: declarations, cudaMalloc/cudaMemcpy runtime calls,
    /// return statements, and device synchronization.
    #[test]
    fn test_cuda_stmt_emit() {
        let backend = CudaBackend::new();
        let decl = CudaStmt::VarDecl {
            ty: CudaType::Int,
            name: "idx".into(),
            init: Some(CudaExpr::LitInt(0)),
        };
        assert_eq!(backend.emit_stmt(&decl, 0), "int idx = 0;");
        // cudaMalloc takes the address of the pointer cast to void**.
        let malloc = CudaStmt::CudaMalloc {
            ptr: "d_data".into(),
            size: CudaExpr::LitInt(1024),
        };
        let s = backend.emit_stmt(&malloc, 0);
        assert!(s.contains("cudaMalloc"));
        assert!(s.contains("(void**)&d_data"));
        assert!(s.contains("1024"));
        let memcpy = CudaStmt::CudaMemcpy {
            dst: CudaExpr::Var("d_data".into()),
            src: CudaExpr::Var("h_data".into()),
            size: CudaExpr::LitInt(1024),
            kind: MemcpyKind::HostToDevice,
        };
        let s = backend.emit_stmt(&memcpy, 0);
        assert!(s.contains("cudaMemcpy"));
        assert!(s.contains("cudaMemcpyHostToDevice"));
        let ret = CudaStmt::Return(Some(CudaExpr::LitInt(0)));
        assert_eq!(backend.emit_stmt(&ret, 0), "return 0;");
        let sync = CudaStmt::DeviceSync;
        assert!(backend
            .emit_stmt(&sync, 0)
            .contains("cudaDeviceSynchronize"));
    }

    /// Builds a classic guarded vector-add kernel
    /// (`idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) ...`)
    /// and checks the emitted source contains the expected decorations.
    #[test]
    fn test_kernel_emit() {
        let backend = CudaBackend::new();
        let idx_decl = CudaStmt::VarDecl {
            ty: CudaType::Int,
            name: "idx".into(),
            init: Some(CudaExpr::BinOp(
                Box::new(CudaExpr::BinOp(
                    Box::new(CudaExpr::BlockIdx('x')),
                    CudaBinOp::Mul,
                    Box::new(CudaExpr::BlockDim('x')),
                )),
                CudaBinOp::Add,
                Box::new(CudaExpr::ThreadIdx('x')),
            )),
        };
        // Bounds guard so out-of-range threads do nothing.
        let guard = CudaStmt::IfElse {
            cond: CudaExpr::BinOp(
                Box::new(CudaExpr::Var("idx".into())),
                CudaBinOp::Lt,
                Box::new(CudaExpr::Var("n".into())),
            ),
            then_body: vec![CudaStmt::Assign {
                lhs: CudaExpr::Index(
                    Box::new(CudaExpr::Var("c".into())),
                    Box::new(CudaExpr::Var("idx".into())),
                ),
                rhs: CudaExpr::BinOp(
                    Box::new(CudaExpr::Index(
                        Box::new(CudaExpr::Var("a".into())),
                        Box::new(CudaExpr::Var("idx".into())),
                    )),
                    CudaBinOp::Add,
                    Box::new(CudaExpr::Index(
                        Box::new(CudaExpr::Var("b".into())),
                        Box::new(CudaExpr::Var("idx".into())),
                    )),
                ),
            }],
            else_body: None,
        };
        let kernel = CudaKernel::new("vec_add")
            .add_param(CudaParam::new(
                CudaType::Pointer(Box::new(CudaType::Float)),
                "a",
            ))
            .add_param(CudaParam::new(
                CudaType::Pointer(Box::new(CudaType::Float)),
                "b",
            ))
            .add_param(CudaParam::new(
                CudaType::Pointer(Box::new(CudaType::Float)),
                "c",
            ))
            .add_param(CudaParam::new(CudaType::Int, "n"))
            .add_stmt(idx_decl)
            .add_stmt(guard)
            .with_launch_bounds(LaunchBounds::new(256));
        let cu = backend.emit_kernel(&kernel);
        assert!(cu.contains("__global__"));
        assert!(cu.contains("vec_add"));
        assert!(cu.contains("__launch_bounds__(256)"));
        assert!(cu.contains("threadIdx.x"));
        assert!(cu.contains("blockIdx.x"));
        assert!(cu.contains("blockDim.x"));
    }

    /// A kernel launch emits the triple-chevron syntax with the configured
    /// grid/block expressions and all arguments.
    #[test]
    fn test_kernel_launch_stmt() {
        let backend = CudaBackend::new();
        let config =
            LaunchConfig::simple_1d(CudaExpr::Var("grid".into()), CudaExpr::Var("block".into()));
        let launch = CudaStmt::KernelLaunch {
            name: "my_kernel".into(),
            config,
            args: vec![
                CudaExpr::Var("d_a".into()),
                CudaExpr::Var("d_b".into()),
                CudaExpr::LitInt(1024),
            ],
        };
        let s = backend.emit_stmt(&launch, 0);
        assert!(s.contains("my_kernel<<<"));
        assert!(s.contains("grid"));
        assert!(s.contains("block"));
        assert!(s.contains("d_a"));
        assert!(s.contains("d_b"));
        assert!(s.contains("1024"));
    }

    /// A full module emission includes the runtime header, the constant
    /// definition, the kernel, and the CUDA_CHECK error-handling macro.
    #[test]
    fn test_module_emit() {
        let backend = CudaBackend::new();
        let module = CudaModule::new()
            .add_constant(CudaType::Int, "BLOCK_SIZE", Some(CudaExpr::LitInt(256)))
            .add_kernel(CudaKernel::new("dummy_kernel").add_stmt(CudaStmt::Return(None)));
        let cu = backend.emit_module(&module);
        assert!(cu.contains("#include <cuda_runtime.h>"));
        assert!(cu.contains("__constant__ int BLOCK_SIZE = 256;"));
        assert!(cu.contains("__global__"));
        assert!(cu.contains("dummy_kernel"));
        assert!(cu.contains("CUDA_CHECK"));
    }

    /// Device functions carry `__host__ __device__ inline` decorations, and
    /// warp intrinsics (`__shfl_down_sync`, `__ballot_sync`, `__popc`) emit
    /// under their CUDA names.
    #[test]
    fn test_device_function_and_warp_intrinsics() {
        let backend = CudaBackend::new();
        let shfl = CudaExpr::ShflDownSync(
            Box::new(CudaExpr::LitInt(0xffffffff)),
            Box::new(CudaExpr::Var("val".into())),
            Box::new(CudaExpr::LitInt(16)),
        );
        let f = DeviceFunction::host_device("warp_reduce_sum", CudaType::Float)
            .with_inline()
            .add_param(CudaParam::new(CudaType::Float, "val"))
            .add_stmt(CudaStmt::Expr(shfl))
            .add_stmt(CudaStmt::Return(Some(CudaExpr::Var("val".into()))));
        let src = backend.emit_device_function(&f);
        assert!(src.contains("__host__"));
        assert!(src.contains("__device__"));
        assert!(src.contains("inline"));
        assert!(src.contains("warp_reduce_sum"));
        assert!(src.contains("__shfl_down_sync"));
        let ballot = CudaExpr::BallotSync(
            Box::new(CudaExpr::LitInt(0xffffffff)),
            Box::new(CudaExpr::Var("pred".into())),
        );
        assert!(backend.emit_expr(&ballot).contains("__ballot_sync"));
        let popc = CudaExpr::Popc(Box::new(CudaExpr::Var("mask".into())));
        assert!(backend.emit_expr(&popc).contains("__popc"));
    }
}
#[cfg(test)]
// NOTE(review): module name kept as-is so existing test-filter invocations
// (e.g. `cargo test CUDA_infra_tests`) keep working; the lint is silenced
// explicitly rather than leaving a compiler warning.
#[allow(non_snake_case)]
mod CUDA_infra_tests {
    use super::*;

    /// A freshly-built pass config is enabled and reports its phase metadata.
    #[test]
    fn test_pass_config() {
        let config = CUDAPassConfig::new("test_pass", CUDAPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }

    /// Stats aggregate across runs: averages, success rate, and summary text.
    #[test]
    fn test_pass_stats() {
        let mut stats = CUDAPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }

    /// Registry counts total vs. enabled passes and routes stats by name.
    #[test]
    fn test_pass_registry() {
        let mut reg = CUDAPassRegistry::new();
        reg.register(CUDAPassConfig::new("pass_a", CUDAPassPhase::Analysis));
        reg.register(CUDAPassConfig::new("pass_b", CUDAPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }

    /// Cache hit/miss accounting; invalidation marks an entry stale without
    /// removing it (size stays 1).
    #[test]
    fn test_analysis_cache() {
        let mut cache = CUDAAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        // One hit, one miss so far.
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        cache.invalidate("key1");
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }

    /// Worklist is FIFO with duplicate suppression while an item is queued.
    #[test]
    fn test_worklist() {
        let mut wl = CUDAWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }

    /// Dominance follows idom chains; every node dominates itself.
    #[test]
    fn test_dominator_tree() {
        let mut dt = CUDADominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        assert!(dt.dominates(3, 3));
    }

    /// Def/use recording is per-block and per-variable.
    #[test]
    fn test_liveness() {
        let mut liveness = CUDALivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }

    /// Folding helpers: add/div/bitand/bitnot; division by zero folds to None.
    #[test]
    fn test_constant_folding() {
        assert_eq!(CUDAConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(CUDAConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(CUDAConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            CUDAConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(CUDAConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }

    /// Acyclic dependency graph: dependency lookup, cycle check, and a
    /// topological order that respects every edge.
    #[test]
    fn test_dep_graph() {
        let mut g = CUDADepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        // Verify edge order via each node's position in the topo sort.
        let pos: std::collections::HashMap<u32, usize> =
            topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}
#[cfg(test)]
mod cudaext_pass_tests {
    use super::*;

    /// Phases order Early < Middle < Late < Finalize, with predicate helpers.
    #[test]
    fn test_cudaext_phase_order() {
        assert_eq!(CUDAExtPassPhase::Early.order(), 0);
        assert_eq!(CUDAExtPassPhase::Middle.order(), 1);
        assert_eq!(CUDAExtPassPhase::Late.order(), 2);
        assert_eq!(CUDAExtPassPhase::Finalize.order(), 3);
        assert!(CUDAExtPassPhase::Early.is_early());
        assert!(!CUDAExtPassPhase::Early.is_late());
    }

    /// Builder chain sets phase/iteration/debug fields; `disabled()` returns
    /// a copy with `enabled` cleared.
    #[test]
    fn test_cudaext_config_builder() {
        let c = CUDAExtPassConfig::new("p")
            .with_phase(CUDAExtPassPhase::Late)
            .with_max_iter(50)
            .with_debug(1);
        assert_eq!(c.name, "p");
        assert_eq!(c.max_iterations, 50);
        assert!(c.is_debug_enabled());
        assert!(c.enabled);
        let c2 = c.disabled();
        assert!(!c2.enabled);
    }

    /// Stats counters: visits, modifications, iterations, and the derived
    /// efficiency ratio (modified / visited).
    #[test]
    fn test_cudaext_stats() {
        let mut s = CUDAExtPassStats::new();
        s.visit();
        s.visit();
        s.modify();
        s.iterate();
        assert_eq!(s.nodes_visited, 2);
        assert_eq!(s.nodes_modified, 1);
        assert!(s.changed);
        assert_eq!(s.iterations, 1);
        let e = s.efficiency();
        assert!((e - 0.5).abs() < 1e-9);
    }

    /// Registry filters by enabled flag and by phase.
    #[test]
    fn test_cudaext_registry() {
        let mut r = CUDAExtPassRegistry::new();
        r.register(CUDAExtPassConfig::new("a").with_phase(CUDAExtPassPhase::Early));
        r.register(CUDAExtPassConfig::new("b").disabled());
        assert_eq!(r.len(), 2);
        assert_eq!(r.enabled_passes().len(), 1);
        assert_eq!(r.passes_in_phase(&CUDAExtPassPhase::Early).len(), 1);
    }

    /// Cache: miss then hit on the same key, hit-rate and live-entry count.
    #[test]
    fn test_cudaext_cache() {
        let mut c = CUDAExtCache::new(4);
        assert!(c.get(99).is_none());
        c.put(99, vec![1, 2, 3]);
        let v = c.get(99).expect("v should be present in map");
        assert_eq!(v, &[1u8, 2, 3]);
        assert!(c.hit_rate() > 0.0);
        assert_eq!(c.live_count(), 1);
    }

    /// Worklist deduplicates pushes and drops membership on pop.
    #[test]
    fn test_cudaext_worklist() {
        let mut w = CUDAExtWorklist::new(10);
        w.push(5);
        w.push(3);
        w.push(5);
        assert_eq!(w.len(), 2);
        assert!(w.contains(5));
        let first = w.pop().expect("first should be available to pop");
        assert!(!w.contains(first));
    }

    /// Dominance via idom chains plus depth queries (root depth 0).
    #[test]
    fn test_cudaext_dom_tree() {
        let mut dt = CUDAExtDomTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        dt.set_idom(4, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 4));
        assert!(!dt.dominates(2, 3));
        assert_eq!(dt.depth_of(3), 2);
    }

    /// Per-block def/use queries for a single variable.
    #[test]
    fn test_cudaext_liveness() {
        let mut lv = CUDAExtLiveness::new(3);
        lv.add_def(0, 1);
        lv.add_use(1, 1);
        assert!(lv.var_is_def_in_block(0, 1));
        assert!(lv.var_is_used_in_block(1, 1));
        assert!(!lv.var_is_def_in_block(1, 1));
    }

    /// Stateful folder counts successful folds and failures (div by zero).
    #[test]
    fn test_cudaext_const_folder() {
        let mut cf = CUDAExtConstFolder::new();
        assert_eq!(cf.add_i64(3, 4), Some(7));
        assert_eq!(cf.div_i64(10, 0), None);
        assert_eq!(cf.mul_i64(6, 7), Some(42));
        assert_eq!(cf.and_i64(0b1100, 0b1010), 0b1000);
        assert_eq!(cf.fold_count(), 3);
        assert_eq!(cf.failure_count(), 1);
    }

    /// A 4-node chain: acyclic, topo-sortable, fully reachable from the root,
    /// and decomposes into one SCC per node.
    #[test]
    fn test_cudaext_dep_graph() {
        let mut g = CUDAExtDepGraph::new(4);
        g.add_edge(0, 1);
        g.add_edge(1, 2);
        g.add_edge(2, 3);
        assert!(!g.has_cycle());
        assert_eq!(g.topo_sort(), Some(vec![0, 1, 2, 3]));
        assert_eq!(g.reachable(0).len(), 4);
        let sccs = g.scc();
        assert_eq!(sccs.len(), 4);
    }
}