Skip to main content

oxilean_codegen/native_backend/
functions.rs

//! Auto-generated module
//!
//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)

5use crate::lcnf::*;
6use std::collections::{HashMap, HashSet};
7
8use super::types::{
9    BasicBlock, BlockId, CondCode, NatAnalysisCache, NatConstantFoldingHelper, NatDepGraph,
10    NatDominatorTree, NatLivenessInfo, NatPassConfig, NatPassPhase, NatPassRegistry, NatPassStats,
11    NatWorklist, NativeBackend, NativeEmitConfig, NativeEmitStats, NativeInst, NativeModule,
12    NativeType, NativeValue, Register, RegisterAllocator,
13};
14
15/// Map an LCNF type to a native type.
16pub(super) fn lcnf_type_to_native(ty: &LcnfType) -> NativeType {
17    match ty {
18        LcnfType::Nat => NativeType::I64,
19        LcnfType::LcnfString => NativeType::Ptr,
20        LcnfType::Object => NativeType::Ptr,
21        LcnfType::Var(_) => NativeType::Ptr,
22        LcnfType::Fun(_, _) => NativeType::Ptr,
23        LcnfType::Ctor(_, _) => NativeType::Ptr,
24        LcnfType::Erased | LcnfType::Irrelevant | LcnfType::Unit => NativeType::Void,
25    }
26}
27/// Compile an LCNF module to native IR with default settings.
28pub fn compile_to_native(module: &LcnfModule) -> NativeModule {
29    let mut backend = NativeBackend::default_backend();
30    backend.compile_module(module)
31}
32/// Compile and perform register allocation.
33pub fn compile_and_regalloc(
34    module: &LcnfModule,
35    num_regs: usize,
36) -> (NativeModule, Vec<HashMap<Register, Register>>) {
37    let mut backend = NativeBackend::default_backend();
38    let native_module = backend.compile_module(module);
39    let mut allocations = Vec::new();
40    for func in &native_module.functions {
41        let mut allocator = RegisterAllocator::new(num_regs);
42        let alloc = allocator.allocate(func);
43        allocations.push(alloc);
44    }
45    (native_module, allocations)
46}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for constructing an LCNF variable id.
    fn vid(n: u64) -> LcnfVarId {
        LcnfVarId(n)
    }

    /// A `Nat` parameter with the given id and name that is neither erased
    /// nor borrowed.
    fn mk_param(n: u64, name: &str) -> LcnfParam {
        LcnfParam {
            id: vid(n),
            name: name.to_string(),
            ty: LcnfType::Nat,
            erased: false,
            borrowed: false,
        }
    }

    /// A non-recursive, non-lifted declaration taking a single `Nat`
    /// parameter `n` (variable id 0) and returning `Nat`.
    fn mk_fun_decl(name: &str, body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: name.to_string(),
            original_name: None,
            params: vec![mk_param(0, "n")],
            ret_type: LcnfType::Nat,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 1,
        }
    }

    /// A module named `test_mod` containing `decls` and no extern declarations.
    fn mk_module(decls: Vec<LcnfFunDecl>) -> LcnfModule {
        LcnfModule {
            fun_decls: decls,
            extern_decls: vec![],
            name: "test_mod".to_string(),
            metadata: LcnfModuleMetadata::default(),
        }
    }

    #[test]
    fn test_native_type_size() {
        assert_eq!(NativeType::I8.size_bytes(), 1);
        assert_eq!(NativeType::I16.size_bytes(), 2);
        assert_eq!(NativeType::I32.size_bytes(), 4);
        assert_eq!(NativeType::I64.size_bytes(), 8);
        assert_eq!(NativeType::Ptr.size_bytes(), 8);
        assert_eq!(NativeType::Void.size_bytes(), 0);
    }

    #[test]
    fn test_native_type_display() {
        assert_eq!(NativeType::I64.to_string(), "i64");
        assert_eq!(NativeType::Ptr.to_string(), "ptr");
        assert_eq!(NativeType::Void.to_string(), "void");
    }

    #[test]
    fn test_native_type_properties() {
        assert!(NativeType::I32.is_integer());
        assert!(!NativeType::I32.is_float());
        assert!(NativeType::F64.is_float());
        assert!(NativeType::Ptr.is_pointer());
    }

    #[test]
    fn test_register_virtual_physical() {
        let vr = Register::virt(5);
        assert!(vr.is_virtual());
        assert!(!vr.is_physical());
        let pr = Register::phys(3);
        assert!(!pr.is_virtual());
        assert!(pr.is_physical());
    }

    #[test]
    fn test_register_display() {
        assert_eq!(Register::virt(5).to_string(), "v5");
        assert_eq!(Register::phys(3).to_string(), "r3");
    }

    #[test]
    fn test_native_value_display() {
        assert_eq!(NativeValue::Reg(Register::virt(0)).to_string(), "v0");
        assert_eq!(NativeValue::Imm(42).to_string(), "#42");
        assert_eq!(NativeValue::FRef("foo".to_string()).to_string(), "@foo");
        assert_eq!(NativeValue::StackSlot(3).to_string(), "ss3");
    }

    #[test]
    fn test_block_id_display() {
        assert_eq!(BlockId(0).to_string(), "bb0");
        assert_eq!(BlockId(5).to_string(), "bb5");
    }

    #[test]
    fn test_basic_block_successors() {
        let mut block = BasicBlock::new(BlockId(0));
        block.push_inst(NativeInst::Br { target: BlockId(1) });
        assert_eq!(block.successors(), vec![BlockId(1)]);
        let mut block2 = BasicBlock::new(BlockId(1));
        block2.push_inst(NativeInst::CondBr {
            cond: NativeValue::Reg(Register::virt(0)),
            then_target: BlockId(2),
            else_target: BlockId(3),
        });
        assert_eq!(block2.successors(), vec![BlockId(2), BlockId(3)]);
    }

    #[test]
    fn test_inst_is_terminator() {
        assert!(NativeInst::Br { target: BlockId(0) }.is_terminator());
        assert!(NativeInst::Ret { value: None }.is_terminator());
        assert!(!NativeInst::Nop.is_terminator());
        assert!(!NativeInst::LoadImm {
            dst: Register::virt(0),
            ty: NativeType::I64,
            value: 0
        }
        .is_terminator());
    }

    #[test]
    fn test_inst_dst_reg() {
        let inst = NativeInst::Add {
            dst: Register::virt(5),
            ty: NativeType::I64,
            lhs: NativeValue::Reg(Register::virt(0)),
            rhs: NativeValue::Imm(1),
        };
        assert_eq!(inst.dst_reg(), Some(Register::virt(5)));
        let inst2 = NativeInst::Ret { value: None };
        assert_eq!(inst2.dst_reg(), None);
    }

    #[test]
    fn test_inst_src_regs() {
        let inst = NativeInst::Add {
            dst: Register::virt(5),
            ty: NativeType::I64,
            lhs: NativeValue::Reg(Register::virt(1)),
            rhs: NativeValue::Reg(Register::virt(2)),
        };
        let srcs = inst.src_regs();
        assert_eq!(srcs.len(), 2);
        assert!(srcs.contains(&Register::virt(1)));
        assert!(srcs.contains(&Register::virt(2)));
    }

    #[test]
    fn test_compile_simple_function() {
        let body = LcnfExpr::Return(LcnfArg::Var(vid(0)));
        let decl = mk_fun_decl("identity", body);
        let module = mk_module(vec![decl]);
        let mut backend = NativeBackend::default_backend();
        let native_module = backend.compile_module(&module);
        assert_eq!(native_module.functions.len(), 1);
        let func = &native_module.functions[0];
        assert_eq!(func.name, "identity");
        assert!(!func.blocks.is_empty());
    }

    /// The free-function entry point must behave like driving a default
    /// backend by hand (see `test_compile_simple_function`).
    #[test]
    fn test_compile_to_native_entry_point() {
        let body = LcnfExpr::Return(LcnfArg::Var(vid(0)));
        let module = mk_module(vec![mk_fun_decl("identity", body)]);
        let native_module = compile_to_native(&module);
        assert_eq!(native_module.functions.len(), 1);
        assert_eq!(native_module.functions[0].name, "identity");
        assert!(!native_module.functions[0].blocks.is_empty());
    }

    #[test]
    fn test_compile_case_expression() {
        let body = LcnfExpr::Case {
            scrutinee: vid(0),
            scrutinee_ty: LcnfType::Ctor("Bool".into(), vec![]),
            alts: vec![
                LcnfAlt {
                    ctor_name: "False".into(),
                    ctor_tag: 0,
                    params: vec![],
                    body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0))),
                },
                LcnfAlt {
                    ctor_name: "True".into(),
                    ctor_tag: 1,
                    params: vec![],
                    body: LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(1))),
                },
            ],
            default: Some(Box::new(LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(99))))),
        };
        let decl = mk_fun_decl("to_nat", body);
        let module = mk_module(vec![decl]);
        let mut backend = NativeBackend::default_backend();
        let native_module = backend.compile_module(&module);
        let func = &native_module.functions[0];
        // A multi-way case must be split into more than one basic block.
        assert!(func.blocks.len() > 1);
    }

    #[test]
    fn test_register_allocation() {
        let body = LcnfExpr::Let {
            id: vid(1),
            name: "a".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(42)),
            body: Box::new(LcnfExpr::Let {
                id: vid(2),
                name: "b".to_string(),
                ty: LcnfType::Nat,
                value: LcnfLetValue::Lit(LcnfLit::Nat(10)),
                body: Box::new(LcnfExpr::Return(LcnfArg::Var(vid(1)))),
            }),
        };
        let decl = mk_fun_decl("test_alloc", body);
        let module = mk_module(vec![decl]);
        let (native_module, allocations) = compile_and_regalloc(&module, 8);
        assert_eq!(native_module.functions.len(), 1);
        assert_eq!(allocations.len(), 1);
        let alloc = &allocations[0];
        for (vreg, phys) in alloc {
            assert!(vreg.is_virtual());
            assert!(phys.is_physical());
        }
    }

    #[test]
    fn test_native_module_display() {
        let module = NativeModule::new("test");
        let s = module.to_string();
        assert!(s.contains("module: test"));
    }

    #[test]
    fn test_native_emit_config_default() {
        let cfg = NativeEmitConfig::default();
        assert_eq!(cfg.opt_level, 1);
        assert!(!cfg.debug_info);
        assert_eq!(cfg.target_arch, "x86_64");
    }

    #[test]
    fn test_native_emit_stats_display() {
        let stats = NativeEmitStats {
            functions_compiled: 3,
            blocks_generated: 10,
            ..Default::default()
        };
        let s = stats.to_string();
        assert!(s.contains("fns=3"));
        assert!(s.contains("blocks=10"));
    }

    /// Covers every category of `lcnf_type_to_native`: integer, pointer-like
    /// and zero-sized types.
    #[test]
    fn test_lcnf_type_to_native() {
        assert_eq!(lcnf_type_to_native(&LcnfType::Nat), NativeType::I64);
        assert_eq!(lcnf_type_to_native(&LcnfType::LcnfString), NativeType::Ptr);
        assert_eq!(lcnf_type_to_native(&LcnfType::Object), NativeType::Ptr);
        assert_eq!(
            lcnf_type_to_native(&LcnfType::Ctor("Bool".into(), vec![])),
            NativeType::Ptr
        );
        assert_eq!(lcnf_type_to_native(&LcnfType::Unit), NativeType::Void);
        assert_eq!(lcnf_type_to_native(&LcnfType::Erased), NativeType::Void);
        assert_eq!(lcnf_type_to_native(&LcnfType::Irrelevant), NativeType::Void);
    }

    #[test]
    fn test_compile_let_chain() {
        let body = LcnfExpr::Let {
            id: vid(1),
            name: "a".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(42)),
            body: Box::new(LcnfExpr::Let {
                id: vid(2),
                name: "b".to_string(),
                ty: LcnfType::Nat,
                value: LcnfLetValue::App(LcnfArg::Var(vid(99)), vec![LcnfArg::Var(vid(1))]),
                body: Box::new(LcnfExpr::Return(LcnfArg::Var(vid(2)))),
            }),
        };
        let decl = mk_fun_decl("chain", body);
        let mut backend = NativeBackend::default_backend();
        let func = backend.compile_fun_decl(&decl);
        assert!(!func.blocks.is_empty());
        assert!(func.instruction_count() > 0);
    }

    #[test]
    fn test_virtual_registers() {
        let body = LcnfExpr::Let {
            id: vid(1),
            name: "a".to_string(),
            ty: LcnfType::Nat,
            value: LcnfLetValue::Lit(LcnfLit::Nat(1)),
            body: Box::new(LcnfExpr::Return(LcnfArg::Var(vid(1)))),
        };
        let decl = mk_fun_decl("test", body);
        let mut backend = NativeBackend::default_backend();
        let func = backend.compile_fun_decl(&decl);
        let vregs = func.virtual_registers();
        assert!(!vregs.is_empty());
    }

    #[test]
    fn test_cond_code_display() {
        assert_eq!(CondCode::Eq.to_string(), "eq");
        assert_eq!(CondCode::Lt.to_string(), "lt");
        assert_eq!(CondCode::Uge.to_string(), "uge");
    }

    #[test]
    fn test_native_func_display() {
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let decl = mk_fun_decl("display_test", body);
        let mut backend = NativeBackend::default_backend();
        let func = backend.compile_fun_decl(&decl);
        let s = func.to_string();
        assert!(s.contains("func @display_test"));
    }
}
/// Tests for the optimisation-pass infrastructure types (pass config,
/// statistics, registry, caches, worklist and graph helpers).
#[cfg(test)]
mod nat_infra_tests {
    use super::*;

    #[test]
    fn test_pass_config() {
        let config = NatPassConfig::new("test_pass", NatPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }

    #[test]
    fn test_pass_stats() {
        let mut stats = NatPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }

    #[test]
    fn test_pass_registry() {
        let mut reg = NatPassRegistry::new();
        reg.register(NatPassConfig::new("pass_a", NatPassPhase::Analysis));
        reg.register(NatPassConfig::new("pass_b", NatPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }

    #[test]
    fn test_analysis_cache() {
        let mut cache = NatAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        // One hit ("key1") and one miss ("key2").
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        cache.invalidate("key1");
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }

    #[test]
    fn test_worklist() {
        let mut wl = NatWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }

    #[test]
    fn test_dominator_tree() {
        let mut dt = NatDominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        assert!(dt.dominates(3, 3));
    }

    #[test]
    fn test_liveness() {
        let mut liveness = NatLivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }

    #[test]
    fn test_constant_folding() {
        assert_eq!(NatConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(NatConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(NatConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            NatConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(NatConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }

    #[test]
    fn test_dep_graph() {
        let mut g = NatDepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        let pos: std::collections::HashMap<u32, usize> =
            topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}