Skip to main content

oxilean_codegen/spirv_backend/
functions.rs

//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)

5use std::collections::HashMap;
6
7use super::types::{
8    AddressingModel, Decoration, MemoryModel, SPIRVAnalysisCache, SPIRVConstantFoldingHelper,
9    SPIRVDepGraph, SPIRVDominatorTree, SPIRVLivenessInfo, SPIRVPassConfig, SPIRVPassPhase,
10    SPIRVPassRegistry, SPIRVPassStats, SPIRVWorklist, SpirVBackend, SpirVBasicBlock,
11    SpirVCapability, SpirVFunction, SpirVInstruction, SpirVModule, SpirVOp, SpirVType,
12    StorageClass,
13};
14
/// Unit tests for the core SPIR-V backend data model.
///
/// Covers the `Display` impls of `SpirVType`, `SpirVOp` and `StorageClass`, the textual
/// emission of instructions, basic blocks, functions and modules, and the `SpirVBackend`
/// builder API (type interning, function creation, compute kernels, binary header).
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_spirv_type_display() {
        assert_eq!(SpirVType::Void.to_string(), "void");
        assert_eq!(SpirVType::Bool.to_string(), "bool");
        // Integer types render as a signedness prefix (`i`/`u`) followed by the bit width.
        assert_eq!(
            SpirVType::Int {
                width: 32,
                signed: true
            }
            .to_string(),
            "i32"
        );
        assert_eq!(
            SpirVType::Int {
                width: 64,
                signed: false
            }
            .to_string(),
            "u64"
        );
        assert_eq!(SpirVType::Float { width: 32 }.to_string(), "f32");
        assert_eq!(SpirVType::Sampler.to_string(), "sampler");
    }

    #[test]
    fn test_spirv_type_display_compound() {
        let vec4 = SpirVType::Vector {
            element: Box::new(SpirVType::Float { width: 32 }),
            count: 4,
        };
        assert_eq!(vec4.to_string(), "vec4<f32>");
        let mat4 = SpirVType::Matrix {
            column_type: Box::new(SpirVType::Vector {
                element: Box::new(SpirVType::Float { width: 32 }),
                count: 4,
            }),
            column_count: 4,
        };
        // Only the `mat4x` prefix is pinned; the remainder of the rendering is not checked.
        assert!(mat4.to_string().contains("mat4x"));
        let arr = SpirVType::Array {
            element: Box::new(SpirVType::Int {
                width: 32,
                signed: true,
            }),
            length: 16,
        };
        assert_eq!(arr.to_string(), "[i32; 16]");
    }

    #[test]
    fn test_spirv_instruction_emit_text() {
        // The first argument of `with_result` is the result id (rendered as `%5 =`).
        let instr = SpirVInstruction::with_result(5, 3, SpirVOp::FAdd, vec![6, 7]);
        let text = instr.emit_text();
        assert!(text.contains("%5 ="));
        assert!(text.contains("OpFAdd"));
        assert!(text.contains("%6"));
        assert!(text.contains("%7"));
    }

    #[test]
    fn test_spirv_instruction_no_result() {
        let instr = SpirVInstruction::no_result(SpirVOp::Return, vec![]);
        assert!(instr.result_id.is_none());
        let text = instr.emit_text();
        assert!(text.contains("OpReturn"));
    }

    #[test]
    fn test_spirv_instruction_word_count() {
        // SPIR-V encoding: 1 leading (word-count | opcode) word + result type + result id
        // + 2 operands = 5 words.
        let instr = SpirVInstruction::with_result(1, 2, SpirVOp::FAdd, vec![3, 4]);
        assert_eq!(instr.word_count(), 5);
        // OpReturn carries no result and no operands: only the leading word remains.
        let void_ret = SpirVInstruction::no_result(SpirVOp::Return, vec![]);
        assert_eq!(void_ret.word_count(), 1);
    }

    #[test]
    fn test_spirv_basic_block() {
        let mut block = SpirVBasicBlock::new(10);
        block.push(SpirVInstruction::with_result(
            11,
            2,
            SpirVOp::FAdd,
            vec![3, 4],
        ));
        block.push(SpirVInstruction::no_result(SpirVOp::Return, vec![]));
        assert_eq!(block.instr_count(), 2);
        let text = block.emit_text();
        // The block's own id is emitted as its label.
        assert!(text.contains("%10 = OpLabel"));
        assert!(text.contains("OpFAdd"));
        assert!(text.contains("OpReturn"));
    }

    #[test]
    fn test_spirv_function_emit() {
        let mut func = SpirVFunction::new(1, Some("main".to_string()), 2, 3);
        func.add_param(4, 5);
        let mut block = SpirVBasicBlock::new(6);
        block.push(SpirVInstruction::no_result(SpirVOp::Return, vec![]));
        func.add_block(block);
        let text = func.emit_text();
        assert!(text.contains("OpFunction"));
        assert!(text.contains("OpFunctionParameter"));
        assert!(text.contains("OpFunctionEnd"));
        // The optional debug name is emitted as a `;` comment.
        assert!(text.contains("; main"));
    }

    #[test]
    fn test_spirv_module_new() {
        let module = SpirVModule::new();
        // Id 0 is never a valid SPIR-V id, so the id bound starts at 1.
        assert_eq!(module.bound, 1);
        assert!(module.capabilities.is_empty());
        assert!(module.functions.is_empty());
    }

    #[test]
    fn test_spirv_module_emit_text() {
        let mut module = SpirVModule::new();
        module.add_capability(SpirVCapability::Shader);
        module.add_capability(SpirVCapability::Float64);
        module.memory_model = (AddressingModel::Logical, MemoryModel::GLSL450);
        let text = module.emit_text();
        assert!(text.contains("OpCapability Shader"));
        assert!(text.contains("OpCapability Float64"));
        assert!(text.contains("OpMemoryModel Logical GLSL450"));
    }

    #[test]
    fn test_spirv_backend_configure_vulkan() {
        let mut backend = SpirVBackend::new();
        backend.configure_for_vulkan();
        assert!(backend
            .module
            .capabilities
            .contains(&SpirVCapability::Shader));
        assert_eq!(backend.module.memory_model.1, MemoryModel::GLSL450);
        // Vulkan configuration also imports the GLSL extended instruction set.
        assert!(backend.glsl_ext_id.is_some());
    }

    #[test]
    fn test_spirv_backend_type_declarations() {
        let mut backend = SpirVBackend::new();
        // Type declarations are interned: repeating a request returns the same id.
        let f32_id = backend.declare_float_type(32);
        let f32_id2 = backend.declare_float_type(32);
        assert_eq!(f32_id, f32_id2);
        // Signedness is part of the integer type's identity.
        let i32_id = backend.declare_int_type(32, true);
        let u32_id = backend.declare_int_type(32, false);
        assert_ne!(i32_id, u32_id);
        let vec4_id = backend.declare_vector_type(f32_id, 4);
        let vec4_id2 = backend.declare_vector_type(f32_id, 4);
        assert_eq!(vec4_id, vec4_id2);
    }

    #[test]
    fn test_spirv_backend_begin_function() {
        let mut backend = SpirVBackend::new();
        let f32_id = backend.declare_float_type(32);
        let func = backend.begin_function("add_f32", f32_id, vec![f32_id, f32_id]);
        assert_eq!(func.params.len(), 2);
        assert_eq!(func.name.as_deref(), Some("add_f32"));
        // Finishing the function registers it in the module and in the symbol table.
        backend.finish_function(func);
        assert_eq!(backend.function_count(), 1);
        assert!(backend.lookup_symbol("add_f32").is_some());
    }

    #[test]
    fn test_spirv_backend_compute_kernel() {
        let mut backend = SpirVBackend::new();
        let func_id = backend.emit_compute_kernel("fill_buffer", 64, 1, 1);
        assert!(func_id > 0);
        assert_eq!(backend.function_count(), 1);
        let text = backend.emit_text();
        assert!(text.contains("fill_buffer"));
        assert!(text.contains("OpEntryPoint"));
        assert!(text.contains("LocalSize"));
    }

    #[test]
    fn test_spirv_binary_header() {
        let backend = SpirVBackend::new();
        let header = backend.emit_binary_header();
        // The SPIR-V header is 5 words: magic, version, generator, bound, schema.
        assert_eq!(header.len(), 5);
        assert_eq!(header[0], 0x0723_0203);
        // The schema word is reserved and must be 0.
        assert_eq!(header[4], 0);
    }

    #[test]
    fn test_spirv_module_word_count() {
        let mut backend = SpirVBackend::new();
        backend.emit_compute_kernel("test_kernel", 32, 1, 1);
        let wc = backend.module.estimate_word_count();
        // Lower bound only; presumably the estimate includes the 5-word header.
        assert!(wc >= 5);
    }

    #[test]
    fn test_spirv_decoration_and_names() {
        let mut module = SpirVModule::new();
        module.set_name(1, "my_var");
        module.decorate(1, Decoration::Binding(0));
        module.decorate(1, Decoration::DescriptorSet(0));
        assert_eq!(
            module
                .debug_names
                .get(&1)
                .expect("value should be present in map"),
            "my_var"
        );
        // Multiple decorations on the same id accumulate rather than overwrite.
        assert_eq!(
            module
                .decorations
                .get(&1)
                .expect("value should be present in map")
                .len(),
            2
        );
        let text = module.emit_text();
        assert!(text.contains("OpName %1 \"my_var\""));
        assert!(text.contains("OpDecorate %1 Binding(0)"));
    }

    #[test]
    fn test_spirv_op_display() {
        // Opcodes display with the canonical `Op` prefix used in SPIR-V assembly.
        assert_eq!(SpirVOp::FAdd.to_string(), "OpFAdd");
        assert_eq!(SpirVOp::IMul.to_string(), "OpIMul");
        assert_eq!(
            SpirVOp::MatrixTimesVector.to_string(),
            "OpMatrixTimesVector"
        );
        assert_eq!(
            SpirVOp::CompositeConstruct.to_string(),
            "OpCompositeConstruct"
        );
        assert_eq!(SpirVOp::Return.to_string(), "OpReturn");
        assert_eq!(SpirVOp::Load.to_string(), "OpLoad");
    }

    #[test]
    fn test_storage_class_display() {
        assert_eq!(StorageClass::Uniform.to_string(), "Uniform");
        assert_eq!(StorageClass::Function.to_string(), "Function");
        assert_eq!(StorageClass::Workgroup.to_string(), "Workgroup");
        assert_eq!(StorageClass::Input.to_string(), "Input");
        assert_eq!(StorageClass::Output.to_string(), "Output");
    }
}
/// Unit tests for the generic pass infrastructure used by the SPIR-V backend.
///
/// Covers pass configuration, statistics and registry, the analysis cache, the worklist,
/// the dominator tree, liveness sets, constant folding and the dependency graph.
#[cfg(test)]
mod spirv_infra_tests {
    use super::*;

    #[test]
    fn test_pass_config() {
        let config = SPIRVPassConfig::new("test_pass", SPIRVPassPhase::Transformation);
        // Newly created pass configs are enabled by default.
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }

    #[test]
    fn test_pass_stats() {
        let mut stats = SPIRVPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        // (10 + 20) changes over 2 runs.
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }

    #[test]
    fn test_pass_registry() {
        let mut reg = SPIRVPassRegistry::new();
        reg.register(SPIRVPassConfig::new("pass_a", SPIRVPassPhase::Analysis));
        reg.register(SPIRVPassConfig::new("pass_b", SPIRVPassPhase::Transformation).disabled());
        // Disabled passes are still registered but not counted as enabled.
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }

    #[test]
    fn test_analysis_cache() {
        let mut cache = SPIRVAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        // One hit (`key1`) and one miss (`key2`) give a 50% hit rate.
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        // Invalidation marks the entry stale but keeps it in the cache.
        cache.invalidate("key1");
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }

    #[test]
    fn test_worklist() {
        let mut wl = SPIRVWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        // Pushing an item that is already pending is rejected.
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        // The oldest item comes out first and is no longer tracked as pending.
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }

    #[test]
    fn test_dominator_tree() {
        // Tree: 0 -> {1, 2}, 1 -> 3.
        let mut dt = SPIRVDominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        // Dominance is reflexive.
        assert!(dt.dominates(3, 3));
    }

    #[test]
    fn test_liveness() {
        let mut liveness = SPIRVLivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }

    #[test]
    fn test_constant_folding() {
        assert_eq!(SPIRVConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        // Division by zero is not folded.
        assert_eq!(SPIRVConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(SPIRVConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            SPIRVConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(SPIRVConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }

    #[test]
    fn test_dep_graph() {
        let mut g = SPIRVDepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        // `add_dep(a, b)` records that `b` depends on `a`.
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        // Every dependency must precede its dependents in the topological order.
        let pos: HashMap<u32, usize> = topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}