oxilean_codegen/spirv_backend/
functions.rs1use std::collections::HashMap;
6
7use super::types::{
8 AddressingModel, Decoration, MemoryModel, SPIRVAnalysisCache, SPIRVConstantFoldingHelper,
9 SPIRVDepGraph, SPIRVDominatorTree, SPIRVLivenessInfo, SPIRVPassConfig, SPIRVPassPhase,
10 SPIRVPassRegistry, SPIRVPassStats, SPIRVWorklist, SpirVBackend, SpirVBasicBlock,
11 SpirVCapability, SpirVFunction, SpirVInstruction, SpirVModule, SpirVOp, SpirVType,
12 StorageClass,
13};
14
#[cfg(test)]
mod tests {
    //! Unit tests for the SPIR-V data model: type/op/storage-class display,
    //! instruction/block/function/module text emission, and the backend's
    //! type interning, function building and kernel emission.
    use super::*;

    /// The 32-bit float type shared by several tests below.
    fn f32_type() -> SpirVType {
        SpirVType::Float { width: 32 }
    }

    /// Asserts that every string in `needles` occurs somewhere in `haystack`.
    fn assert_contains_all(haystack: &str, needles: &[&str]) {
        for needle in needles {
            assert!(
                haystack.contains(*needle),
                "expected {:?} in emitted text:\n{}",
                needle,
                haystack
            );
        }
    }

    /// Scalar and opaque types render with their short textual names.
    #[test]
    pub(super) fn test_spirv_type_display() {
        let expectations = [
            (SpirVType::Void, "void"),
            (SpirVType::Bool, "bool"),
            (SpirVType::Int { width: 32, signed: true }, "i32"),
            (SpirVType::Int { width: 64, signed: false }, "u64"),
            (f32_type(), "f32"),
            (SpirVType::Sampler, "sampler"),
        ];
        for (ty, rendered) in expectations.iter() {
            assert_eq!(ty.to_string(), *rendered);
        }
    }

    /// Vector, matrix and array types render their element types recursively.
    #[test]
    pub(super) fn test_spirv_type_display_compound() {
        let vec4_of_f32 = || SpirVType::Vector {
            element: Box::new(f32_type()),
            count: 4,
        };
        assert_eq!(vec4_of_f32().to_string(), "vec4<f32>");

        let mat4 = SpirVType::Matrix {
            column_type: Box::new(vec4_of_f32()),
            column_count: 4,
        };
        assert!(mat4.to_string().contains("mat4x"));

        let i32_array = SpirVType::Array {
            element: Box::new(SpirVType::Int {
                width: 32,
                signed: true,
            }),
            length: 16,
        };
        assert_eq!(i32_array.to_string(), "[i32; 16]");
    }

    /// An instruction with a result mentions its result id, opcode and operands.
    #[test]
    pub(super) fn test_spirv_instruction_emit_text() {
        let fadd = SpirVInstruction::with_result(5, 3, SpirVOp::FAdd, vec![6, 7]);
        assert_contains_all(&fadd.emit_text(), &["%5 =", "OpFAdd", "%6", "%7"]);
    }

    /// A result-less instruction has no result id but still emits its opcode.
    #[test]
    pub(super) fn test_spirv_instruction_no_result() {
        let ret = SpirVInstruction::no_result(SpirVOp::Return, vec![]);
        assert!(ret.result_id.is_none());
        assert!(ret.emit_text().contains("OpReturn"));
    }

    /// Word counts cover the opcode word plus type, result and operand words.
    #[test]
    pub(super) fn test_spirv_instruction_word_count() {
        let binary_op = SpirVInstruction::with_result(1, 2, SpirVOp::FAdd, vec![3, 4]);
        assert_eq!(binary_op.word_count(), 5);
        let bare_return = SpirVInstruction::no_result(SpirVOp::Return, vec![]);
        assert_eq!(bare_return.word_count(), 1);
    }

    /// A basic block emits its label followed by each pushed instruction.
    #[test]
    pub(super) fn test_spirv_basic_block() {
        let mut entry = SpirVBasicBlock::new(10);
        entry.push(SpirVInstruction::with_result(11, 2, SpirVOp::FAdd, vec![3, 4]));
        entry.push(SpirVInstruction::no_result(SpirVOp::Return, vec![]));
        assert_eq!(entry.instr_count(), 2);
        assert_contains_all(&entry.emit_text(), &["%10 = OpLabel", "OpFAdd", "OpReturn"]);
    }

    /// A named function emits its header, parameters, end marker and name comment.
    #[test]
    pub(super) fn test_spirv_function_emit() {
        let mut main_fn = SpirVFunction::new(1, Some(String::from("main")), 2, 3);
        main_fn.add_param(4, 5);
        let mut entry = SpirVBasicBlock::new(6);
        entry.push(SpirVInstruction::no_result(SpirVOp::Return, vec![]));
        main_fn.add_block(entry);
        assert_contains_all(
            &main_fn.emit_text(),
            &["OpFunction", "OpFunctionParameter", "OpFunctionEnd", "; main"],
        );
    }

    /// A new module starts with bound 1 and no capabilities or functions.
    #[test]
    pub(super) fn test_spirv_module_new() {
        let fresh = SpirVModule::new();
        assert_eq!(fresh.bound, 1);
        assert!(fresh.capabilities.is_empty());
        assert!(fresh.functions.is_empty());
    }

    /// Capabilities and the memory model appear in the module's text form.
    #[test]
    pub(super) fn test_spirv_module_emit_text() {
        let mut module = SpirVModule::new();
        module.add_capability(SpirVCapability::Shader);
        module.add_capability(SpirVCapability::Float64);
        module.memory_model = (AddressingModel::Logical, MemoryModel::GLSL450);
        assert_contains_all(
            &module.emit_text(),
            &[
                "OpCapability Shader",
                "OpCapability Float64",
                "OpMemoryModel Logical GLSL450",
            ],
        );
    }

    /// Vulkan configuration enables Shader, selects GLSL450 and imports the GLSL ext.
    #[test]
    pub(super) fn test_spirv_backend_configure_vulkan() {
        let mut backend = SpirVBackend::new();
        backend.configure_for_vulkan();
        let module = &backend.module;
        assert!(module.capabilities.contains(&SpirVCapability::Shader));
        assert_eq!(module.memory_model.1, MemoryModel::GLSL450);
        assert!(backend.glsl_ext_id.is_some());
    }

    /// Identical type declarations are interned; differing signedness is not merged.
    #[test]
    pub(super) fn test_spirv_backend_type_declarations() {
        let mut backend = SpirVBackend::new();
        let float32 = backend.declare_float_type(32);
        assert_eq!(float32, backend.declare_float_type(32));
        let signed32 = backend.declare_int_type(32, true);
        let unsigned32 = backend.declare_int_type(32, false);
        assert_ne!(signed32, unsigned32);
        let float4 = backend.declare_vector_type(float32, 4);
        assert_eq!(float4, backend.declare_vector_type(float32, 4));
    }

    /// Building and finishing a function registers it and its symbol.
    #[test]
    pub(super) fn test_spirv_backend_begin_function() {
        let mut backend = SpirVBackend::new();
        let float32 = backend.declare_float_type(32);
        let add = backend.begin_function("add_f32", float32, vec![float32; 2]);
        assert_eq!(add.params.len(), 2);
        assert_eq!(add.name.as_deref(), Some("add_f32"));
        backend.finish_function(add);
        assert_eq!(backend.function_count(), 1);
        assert!(backend.lookup_symbol("add_f32").is_some());
    }

    /// A compute kernel is emitted with an entry point and a local-size mode.
    #[test]
    pub(super) fn test_spirv_backend_compute_kernel() {
        let mut backend = SpirVBackend::new();
        let kernel_id = backend.emit_compute_kernel("fill_buffer", 64, 1, 1);
        assert!(kernel_id > 0);
        assert_eq!(backend.function_count(), 1);
        assert_contains_all(
            &backend.emit_text(),
            &["fill_buffer", "OpEntryPoint", "LocalSize"],
        );
    }

    /// The binary header is five words, starting with the SPIR-V magic number.
    #[test]
    pub(super) fn test_spirv_binary_header() {
        let backend = SpirVBackend::new();
        let words = backend.emit_binary_header();
        assert_eq!(words.len(), 5);
        assert_eq!(words[0], 0x0723_0203);
        assert_eq!(words[4], 0);
    }

    /// The word-count estimate of a module with a kernel is at least the header size.
    #[test]
    pub(super) fn test_spirv_module_word_count() {
        let mut backend = SpirVBackend::new();
        backend.emit_compute_kernel("test_kernel", 32, 1, 1);
        let estimate = backend.module.estimate_word_count();
        assert!(estimate >= 5);
    }

    /// Debug names and decorations are stored per id and emitted as text.
    #[test]
    pub(super) fn test_spirv_decoration_and_names() {
        let mut module = SpirVModule::new();
        module.set_name(1, "my_var");
        module.decorate(1, Decoration::Binding(0));
        module.decorate(1, Decoration::DescriptorSet(0));

        let stored_name = module
            .debug_names
            .get(&1)
            .expect("value should be present in map");
        assert_eq!(stored_name, "my_var");

        let decoration_count = module
            .decorations
            .get(&1)
            .expect("value should be present in map")
            .len();
        assert_eq!(decoration_count, 2);

        assert_contains_all(
            &module.emit_text(),
            &["OpName %1 \"my_var\"", "OpDecorate %1 Binding(0)"],
        );
    }

    /// Opcodes render as their `Op`-prefixed SPIR-V mnemonics.
    #[test]
    pub(super) fn test_spirv_op_display() {
        let mnemonics = [
            (SpirVOp::FAdd, "OpFAdd"),
            (SpirVOp::IMul, "OpIMul"),
            (SpirVOp::MatrixTimesVector, "OpMatrixTimesVector"),
            (SpirVOp::CompositeConstruct, "OpCompositeConstruct"),
            (SpirVOp::Return, "OpReturn"),
            (SpirVOp::Load, "OpLoad"),
        ];
        for (op, mnemonic) in mnemonics.iter() {
            assert_eq!(op.to_string(), *mnemonic);
        }
    }

    /// Storage classes render as their bare SPIR-V names.
    #[test]
    pub(super) fn test_storage_class_display() {
        let names = [
            (StorageClass::Uniform, "Uniform"),
            (StorageClass::Function, "Function"),
            (StorageClass::Workgroup, "Workgroup"),
            (StorageClass::Input, "Input"),
            (StorageClass::Output, "Output"),
        ];
        for (class, name) in names.iter() {
            assert_eq!(class.to_string(), *name);
        }
    }
}
#[cfg(test)]
// The historical module name is kept so existing test paths and filters
// (e.g. `cargo test SPIRV_infra_tests`) keep matching; the allow silences the
// `non_snake_case` warning rustc would otherwise emit for this name.
#[allow(non_snake_case)]
mod SPIRV_infra_tests {
    //! Unit tests for the SPIR-V pass infrastructure: pass configuration,
    //! statistics and registry, the analysis cache, the worklist, the dominator
    //! tree, liveness info, constant folding and the dependency graph.
    use super::*;

    /// A newly created pass config is enabled and reports its phase's properties.
    #[test]
    pub(super) fn test_pass_config() {
        let config = SPIRVPassConfig::new("test_pass", SPIRVPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }

    /// Two recorded runs are counted, averaged and summarised.
    #[test]
    pub(super) fn test_pass_stats() {
        let mut stats = SPIRVPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        // (10 + 20) / 2 runs: the first `record_run` argument is what gets averaged.
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let summary = stats.format_summary();
        assert!(summary.contains("Runs: 2/2"), "unexpected summary: {}", summary);
    }

    /// Disabled passes count toward the total but not the enabled count;
    /// per-pass stats are updated by name.
    #[test]
    pub(super) fn test_pass_registry() {
        let mut reg = SPIRVPassRegistry::new();
        reg.register(SPIRVPassConfig::new("pass_a", SPIRVPassPhase::Analysis));
        reg.register(SPIRVPassConfig::new("pass_b", SPIRVPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }

    /// One hit plus one miss gives a 50% hit rate; invalidation marks the entry
    /// invalid without removing it.
    #[test]
    pub(super) fn test_analysis_cache() {
        let mut cache = SPIRVAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        cache.invalidate("key1");
        // Look the entry up with `get` so a missing entry fails with a clear
        // assertion message instead of an index panic.
        assert_eq!(
            cache.entries.get("key1").map(|entry| entry.valid),
            Some(false),
            "invalidated entry should remain present but be marked invalid"
        );
        assert_eq!(cache.size(), 1);
    }

    /// Pushing an already-queued item is rejected, and `pop` yields the item
    /// that was pushed first.
    #[test]
    pub(super) fn test_worklist() {
        let mut wl = SPIRVWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }

    /// Dominance follows the immediate-dominator chain and is reflexive.
    #[test]
    pub(super) fn test_dominator_tree() {
        let mut dt = SPIRVDominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        assert!(dt.dominates(3, 3));
    }

    /// Definitions and uses are recorded in the per-index `defs`/`uses` sets.
    #[test]
    pub(super) fn test_liveness() {
        let mut liveness = SPIRVLivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }

    /// Integer folding helpers compute results and refuse division by zero.
    #[test]
    pub(super) fn test_constant_folding() {
        assert_eq!(SPIRVConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(SPIRVConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(SPIRVConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            SPIRVConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(SPIRVConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }

    /// `add_dep(a, b)` makes `b` depend on `a`, and the topological order
    /// places every node after its dependencies.
    #[test]
    pub(super) fn test_dep_graph() {
        let mut g = SPIRVDepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        // Map each node to its position in the order to check the edge constraints.
        let pos: HashMap<u32, usize> = topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}