Skip to main content

oxilean_codegen/opt_specialize/
functions.rs

1//! Auto-generated module
2//!
3//! 🤖 Generated with [SplitRS](https://github.com/cool-japan/splitrs)
4
5use crate::lcnf::*;
6use std::collections::{HashMap, HashSet};
7
8use super::types::{
9    NumericSpecializer, SizeBudget, SpecAnalysisCache, SpecCallSite, SpecClosureArg, SpecConstArg,
10    SpecConstantFoldingHelper, SpecDepGraph, SpecDominatorTree, SpecExtCache, SpecExtConstFolder,
11    SpecExtDepGraph, SpecExtDomTree, SpecExtLiveness, SpecExtPassConfig, SpecExtPassPhase,
12    SpecExtPassRegistry, SpecExtPassStats, SpecExtWorklist, SpecLivenessInfo, SpecPassConfig,
13    SpecPassPhase, SpecPassRegistry, SpecPassStats, SpecTypeArg, SpecWorklist, SpecializationCache,
14    SpecializationConfig, SpecializationKey, SpecializationPass, SpecializationStats,
15};
16
17/// Generate a short suffix for a type
18pub(super) fn type_suffix(ty: &LcnfType) -> String {
19    match ty {
20        LcnfType::Nat => "nat".to_string(),
21        LcnfType::Object => "obj".to_string(),
22        LcnfType::Unit => "unit".to_string(),
23        LcnfType::Erased => "e".to_string(),
24        LcnfType::LcnfString => "str".to_string(),
25        LcnfType::Var(name) => name.clone(),
26        LcnfType::Ctor(name, _) => name.clone(),
27        LcnfType::Fun(_, _) => "fn".to_string(),
28        LcnfType::Irrelevant => "irr".to_string(),
29    }
30}
31/// Analyze call sites in a function body to find specialization opportunities
32pub(super) fn find_specialization_sites(
33    expr: &LcnfExpr,
34    known_constants: &HashMap<LcnfVarId, LcnfLit>,
35    known_functions: &HashMap<LcnfVarId, String>,
36    decl_names: &HashSet<String>,
37) -> Vec<SpecCallSite> {
38    let mut sites = Vec::new();
39    let mut call_idx = 0;
40    find_spec_sites_inner(
41        expr,
42        known_constants,
43        known_functions,
44        decl_names,
45        &mut sites,
46        &mut call_idx,
47    );
48    sites
49}
50pub(super) fn find_spec_sites_inner(
51    expr: &LcnfExpr,
52    known_constants: &HashMap<LcnfVarId, LcnfLit>,
53    known_functions: &HashMap<LcnfVarId, String>,
54    decl_names: &HashSet<String>,
55    sites: &mut Vec<SpecCallSite>,
56    call_idx: &mut usize,
57) {
58    match expr {
59        LcnfExpr::Let {
60            id, value, body, ..
61        } => {
62            let mut extended_consts = known_constants.clone();
63            if let LcnfLetValue::Lit(lit) = value {
64                extended_consts.insert(*id, lit.clone());
65            }
66            let mut extended_fns = known_functions.clone();
67            if let LcnfLetValue::FVar(fvar) = value {
68                if let Some(fname) = known_functions.get(fvar) {
69                    extended_fns.insert(*id, fname.clone());
70                }
71            }
72            if let LcnfLetValue::App(func, args) = value {
73                let callee_name = match func {
74                    LcnfArg::Var(v) => known_functions.get(v).cloned(),
75                    _ => None,
76                };
77                if let Some(ref callee) = callee_name {
78                    if decl_names.contains(callee.as_str()) {
79                        let const_args: Vec<SpecConstArg> = args
80                            .iter()
81                            .map(|arg| match arg {
82                                LcnfArg::Lit(LcnfLit::Nat(n)) => SpecConstArg::Nat(*n),
83                                LcnfArg::Lit(LcnfLit::Str(s)) => SpecConstArg::Str(s.clone()),
84                                LcnfArg::Var(v) => {
85                                    if let Some(lit) = extended_consts.get(v) {
86                                        match lit {
87                                            LcnfLit::Nat(n) => SpecConstArg::Nat(*n),
88                                            LcnfLit::Str(s) => SpecConstArg::Str(s.clone()),
89                                        }
90                                    } else {
91                                        SpecConstArg::Unknown
92                                    }
93                                }
94                                _ => SpecConstArg::Unknown,
95                            })
96                            .collect();
97                        let closure_args: Vec<SpecClosureArg> = args
98                            .iter()
99                            .enumerate()
100                            .map(|(i, arg)| {
101                                let known_fn = match arg {
102                                    LcnfArg::Var(v) => extended_fns.get(v).cloned(),
103                                    _ => None,
104                                };
105                                SpecClosureArg {
106                                    known_fn,
107                                    param_idx: i,
108                                }
109                            })
110                            .collect();
111                        let callee_var = match func {
112                            LcnfArg::Var(v) => Some(*v),
113                            _ => None,
114                        };
115                        sites.push(SpecCallSite {
116                            callee: callee.clone(),
117                            call_idx: *call_idx,
118                            type_args: vec![],
119                            const_args,
120                            closure_args,
121                            callee_var,
122                        });
123                        *call_idx += 1;
124                    }
125                }
126            }
127            find_spec_sites_inner(
128                body,
129                &extended_consts,
130                &extended_fns,
131                decl_names,
132                sites,
133                call_idx,
134            );
135        }
136        LcnfExpr::Case { alts, default, .. } => {
137            for alt in alts {
138                find_spec_sites_inner(
139                    &alt.body,
140                    known_constants,
141                    known_functions,
142                    decl_names,
143                    sites,
144                    call_idx,
145                );
146            }
147            if let Some(def) = default {
148                find_spec_sites_inner(
149                    def,
150                    known_constants,
151                    known_functions,
152                    decl_names,
153                    sites,
154                    call_idx,
155                );
156            }
157        }
158        LcnfExpr::TailCall(func, args) => {
159            let callee_name = match func {
160                LcnfArg::Var(v) => known_functions.get(v).cloned(),
161                _ => None,
162            };
163            if let Some(callee) = callee_name {
164                if decl_names.contains(callee.as_str()) {
165                    let const_args: Vec<SpecConstArg> = args
166                        .iter()
167                        .map(|arg| match arg {
168                            LcnfArg::Lit(LcnfLit::Nat(n)) => SpecConstArg::Nat(*n),
169                            LcnfArg::Lit(LcnfLit::Str(s)) => SpecConstArg::Str(s.clone()),
170                            LcnfArg::Var(v) => {
171                                if let Some(lit) = known_constants.get(v) {
172                                    match lit {
173                                        LcnfLit::Nat(n) => SpecConstArg::Nat(*n),
174                                        LcnfLit::Str(s) => SpecConstArg::Str(s.clone()),
175                                    }
176                                } else {
177                                    SpecConstArg::Unknown
178                                }
179                            }
180                            _ => SpecConstArg::Unknown,
181                        })
182                        .collect();
183                    let closure_args: Vec<SpecClosureArg> = args
184                        .iter()
185                        .enumerate()
186                        .map(|(i, arg)| {
187                            let known_fn = match arg {
188                                LcnfArg::Var(v) => known_functions.get(v).cloned(),
189                                _ => None,
190                            };
191                            SpecClosureArg {
192                                known_fn,
193                                param_idx: i,
194                            }
195                        })
196                        .collect();
197                    let callee_var = match func {
198                        LcnfArg::Var(v) => Some(*v),
199                        _ => None,
200                    };
201                    sites.push(SpecCallSite {
202                        callee,
203                        call_idx: *call_idx,
204                        type_args: vec![],
205                        const_args,
206                        closure_args,
207                        callee_var,
208                    });
209                    *call_idx += 1;
210                }
211            }
212        }
213        LcnfExpr::Return(_) | LcnfExpr::Unreachable => {}
214    }
215}
216/// Count the number of LCNF instructions in an expression
217pub(super) fn count_instructions(expr: &LcnfExpr) -> usize {
218    match expr {
219        LcnfExpr::Let { body, .. } => 1 + count_instructions(body),
220        LcnfExpr::Case { alts, default, .. } => {
221            let alts_size: usize = alts.iter().map(|a| count_instructions(&a.body)).sum();
222            let def_size = default.as_ref().map(|d| count_instructions(d)).unwrap_or(0);
223            1 + alts_size + def_size
224        }
225        LcnfExpr::Return(_) | LcnfExpr::TailCall(_, _) | LcnfExpr::Unreachable => 1,
226    }
227}
228/// Analyze whether a function parameter is always called with the same closure
229pub(super) fn analyze_closure_uniformity(
230    decl: &LcnfFunDecl,
231    param_idx: usize,
232    sites: &[SpecCallSite],
233) -> Option<String> {
234    let mut known_fn: Option<String> = None;
235    for site in sites {
236        if site.callee != decl.name {
237            continue;
238        }
239        if param_idx >= site.closure_args.len() {
240            return None;
241        }
242        match &site.closure_args[param_idx].known_fn {
243            Some(fn_name) => {
244                if let Some(ref existing) = known_fn {
245                    if existing != fn_name {
246                        return None;
247                    }
248                } else {
249                    known_fn = Some(fn_name.clone());
250                }
251            }
252            None => return None,
253        }
254    }
255    known_fn
256}
257/// Check whether a parameter is used as a function (called) in the body
258pub(super) fn is_called_as_function(expr: &LcnfExpr, param_id: LcnfVarId) -> bool {
259    match expr {
260        LcnfExpr::Let { value, body, .. } => {
261            let called_here = matches!(
262                value, LcnfLetValue::App(LcnfArg::Var(v), _) if * v == param_id
263            );
264            called_here || is_called_as_function(body, param_id)
265        }
266        LcnfExpr::Case { alts, default, .. } => {
267            alts.iter()
268                .any(|a| is_called_as_function(&a.body, param_id))
269                || default
270                    .as_ref()
271                    .is_some_and(|d| is_called_as_function(d, param_id))
272        }
273        LcnfExpr::TailCall(LcnfArg::Var(v), _) => *v == param_id,
274        _ => false,
275    }
276}
277/// Main entry point: specialize functions in a module
278pub fn specialize_module(module: &mut LcnfModule, config: &SpecializationConfig) {
279    let mut pass = SpecializationPass::new(config.clone());
280    pass.run(module);
281}
282/// Specialize a single function for numeric operations
283pub fn specialize_numeric(decl: &LcnfFunDecl) -> Option<LcnfFunDecl> {
284    let specializer = NumericSpecializer::new();
285    if !specializer.is_numeric_op(&decl.name) {
286        return None;
287    }
288    let mut spec = decl.clone();
289    spec.name = format!("{}_u64", decl.name);
290    for param in &mut spec.params {
291        param.ty = specializer.specialize_nat_to_u64(&param.ty);
292    }
293    spec.ret_type = specializer.specialize_nat_to_u64(&spec.ret_type);
294    Some(spec)
295}
296/// Check if a function is worth specializing based on heuristics
297pub fn is_worth_specializing(decl: &LcnfFunDecl, config: &SpecializationConfig) -> bool {
298    let size = count_instructions(&decl.body);
299    if size > config.size_threshold {
300        return false;
301    }
302    let has_poly = decl
303        .params
304        .iter()
305        .any(|p| matches!(p.ty, LcnfType::Var(_) | LcnfType::Object | LcnfType::Erased));
306    let has_fn_param = decl
307        .params
308        .iter()
309        .any(|p| matches!(p.ty, LcnfType::Fun(_, _)));
310    has_poly || (has_fn_param && config.specialize_closures)
311}
#[cfg(test)]
mod tests {
    // Unit tests for the specialization helpers in this module and for the
    // key / cache / budget / pass types they rely on.
    use super::*;

    /// Build a variable id from a raw number.
    pub(super) fn make_var(n: u64) -> LcnfVarId {
        LcnfVarId(n)
    }

    /// Build a non-erased, non-borrowed parameter.
    pub(super) fn make_param(n: u64, name: &str, ty: LcnfType) -> LcnfParam {
        LcnfParam {
            id: LcnfVarId(n),
            name: name.to_string(),
            ty,
            erased: false,
            borrowed: false,
        }
    }

    /// Build `let x<id> : Nat := value; body`.
    pub(super) fn make_simple_let(id: u64, value: LcnfLetValue, body: LcnfExpr) -> LcnfExpr {
        LcnfExpr::Let {
            id: LcnfVarId(id),
            name: format!("x{}", id),
            ty: LcnfType::Nat,
            value,
            body: Box::new(body),
        }
    }

    /// Build a non-recursive, non-lifted declaration returning `Nat`.
    pub(super) fn make_decl(name: &str, params: Vec<LcnfParam>, body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: name.to_string(),
            original_name: None,
            params,
            ret_type: LcnfType::Nat,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 1,
        }
    }

    /// Call site of `apply` whose first argument is the known function `known_fn`.
    fn apply_site(call_idx: usize, known_fn: &str) -> SpecCallSite {
        SpecCallSite {
            callee: "apply".to_string(),
            call_idx,
            type_args: vec![],
            const_args: vec![],
            closure_args: vec![SpecClosureArg {
                known_fn: Some(known_fn.to_string()),
                param_idx: 0,
            }],
            callee_var: None,
        }
    }

    #[test]
    fn test_config_default() {
        let config = SpecializationConfig::default();
        assert_eq!(config.max_specializations, 8);
        assert!(config.specialize_closures);
        assert!(config.specialize_numerics);
        assert_eq!(config.size_threshold, 200);
    }

    #[test]
    fn test_spec_key_trivial() {
        let key = SpecializationKey {
            original: "foo".to_string(),
            type_args: vec![SpecTypeArg::Poly],
            const_args: vec![SpecConstArg::Unknown],
            closure_args: vec![SpecClosureArg {
                known_fn: None,
                param_idx: 0,
            }],
        };
        assert!(key.is_trivial());
    }

    #[test]
    fn test_spec_key_non_trivial_type() {
        let key = SpecializationKey {
            original: "foo".to_string(),
            type_args: vec![SpecTypeArg::Concrete(LcnfType::Nat)],
            const_args: vec![],
            closure_args: vec![],
        };
        assert!(!key.is_trivial());
    }

    #[test]
    fn test_spec_key_non_trivial_const() {
        let key = SpecializationKey {
            original: "foo".to_string(),
            type_args: vec![],
            const_args: vec![SpecConstArg::Nat(42)],
            closure_args: vec![],
        };
        assert!(!key.is_trivial());
    }

    #[test]
    fn test_spec_key_mangled_name() {
        let key = SpecializationKey {
            original: "List.map".to_string(),
            type_args: vec![SpecTypeArg::Concrete(LcnfType::Nat)],
            const_args: vec![SpecConstArg::Unknown],
            closure_args: vec![],
        };
        let name = key.mangled_name();
        assert!(name.starts_with("List.map"));
        assert!(name.contains("_T0_nat"));
    }

    #[test]
    fn test_spec_key_mangled_name_with_const() {
        let key = SpecializationKey {
            original: "repeat".to_string(),
            type_args: vec![],
            const_args: vec![SpecConstArg::Nat(3)],
            closure_args: vec![],
        };
        let name = key.mangled_name();
        assert!(name.contains("_C0_N3"));
    }

    #[test]
    fn test_spec_key_mangled_name_with_closure() {
        let key = SpecializationKey {
            original: "List.map".to_string(),
            type_args: vec![],
            const_args: vec![],
            closure_args: vec![SpecClosureArg {
                known_fn: Some("double".to_string()),
                param_idx: 0,
            }],
        };
        let name = key.mangled_name();
        assert!(name.contains("_Fdouble"));
    }

    #[test]
    fn test_type_suffix() {
        assert_eq!(type_suffix(&LcnfType::Nat), "nat");
        assert_eq!(type_suffix(&LcnfType::Object), "obj");
        assert_eq!(type_suffix(&LcnfType::Unit), "unit");
        assert_eq!(type_suffix(&LcnfType::LcnfString), "str");
    }

    #[test]
    fn test_cache_operations() {
        let mut cache = SpecializationCache::new();
        let key = SpecializationKey {
            original: "foo".to_string(),
            type_args: vec![SpecTypeArg::Concrete(LcnfType::Nat)],
            const_args: vec![],
            closure_args: vec![],
        };
        assert!(cache.lookup(&key).is_none());
        cache.insert(key.clone(), "foo_nat".to_string(), 10);
        assert_eq!(cache.lookup(&key), Some("foo_nat"));
        assert_eq!(cache.specialization_count("foo"), 1);
        assert_eq!(cache.total_growth, 10);
    }

    #[test]
    fn test_size_budget() {
        let mut budget = SizeBudget::new(100, 2.0);
        assert!(budget.can_afford(50));
        assert!(budget.can_afford(100));
        assert!(!budget.can_afford(101));
        budget.spend(50);
        assert!(budget.can_afford(50));
        assert!(!budget.can_afford(51));
        assert_eq!(budget.remaining(), 50);
    }

    #[test]
    fn test_numeric_specializer() {
        let specializer = NumericSpecializer::new();
        assert!(specializer.is_numeric_op("Nat.add"));
        assert!(specializer.is_numeric_op("Nat.mul"));
        assert!(!specializer.is_numeric_op("List.map"));
    }

    #[test]
    fn test_numeric_type_specialization() {
        let specializer = NumericSpecializer::new();
        let ty = LcnfType::Fun(vec![LcnfType::Nat, LcnfType::Nat], Box::new(LcnfType::Nat));
        let spec = specializer.specialize_nat_to_u64(&ty);
        assert_eq!(
            spec,
            LcnfType::Fun(vec![LcnfType::Nat, LcnfType::Nat], Box::new(LcnfType::Nat))
        );
    }

    #[test]
    fn test_specialize_numeric() {
        let body = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let decl = make_decl(
            "Nat.add",
            vec![
                make_param(0, "a", LcnfType::Nat),
                make_param(1, "b", LcnfType::Nat),
            ],
            body,
        );
        let spec = specialize_numeric(&decl).expect("Nat.add is numeric and should specialize");
        assert_eq!(spec.name, "Nat.add_u64");
    }

    #[test]
    fn test_specialize_numeric_non_numeric() {
        let body = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let decl = make_decl("List.map", vec![make_param(0, "f", LcnfType::Object)], body);
        assert!(specialize_numeric(&decl).is_none());
    }

    #[test]
    fn test_is_worth_specializing_polymorphic() {
        let body = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let decl = make_decl(
            "id",
            vec![make_param(0, "x", LcnfType::Var("a".to_string()))],
            body,
        );
        let config = SpecializationConfig::default();
        assert!(is_worth_specializing(&decl, &config));
    }

    #[test]
    fn test_is_worth_specializing_concrete() {
        let body = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let decl = make_decl(
            "add",
            vec![
                make_param(0, "a", LcnfType::Nat),
                make_param(1, "b", LcnfType::Nat),
            ],
            body,
        );
        let config = SpecializationConfig::default();
        assert!(!is_worth_specializing(&decl, &config));
    }

    #[test]
    fn test_is_worth_specializing_higher_order() {
        let body = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let fn_ty = LcnfType::Fun(vec![LcnfType::Nat], Box::new(LcnfType::Nat));
        let decl = make_decl("apply", vec![make_param(0, "f", fn_ty)], body);
        let config = SpecializationConfig::default();
        assert!(is_worth_specializing(&decl, &config));
    }

    #[test]
    fn test_is_called_as_function() {
        let body = make_simple_let(
            5,
            LcnfLetValue::App(
                LcnfArg::Var(make_var(0)),
                vec![LcnfArg::Lit(LcnfLit::Nat(1))],
            ),
            LcnfExpr::Return(LcnfArg::Var(make_var(5))),
        );
        assert!(is_called_as_function(&body, make_var(0)));
        assert!(!is_called_as_function(&body, make_var(1)));
    }

    #[test]
    fn test_count_instructions() {
        let expr = make_simple_let(
            1,
            LcnfLetValue::Lit(LcnfLit::Nat(42)),
            make_simple_let(
                2,
                LcnfLetValue::Lit(LcnfLit::Nat(10)),
                LcnfExpr::Return(LcnfArg::Var(make_var(2))),
            ),
        );
        assert_eq!(count_instructions(&expr), 3);
    }

    #[test]
    fn test_substitute_constant() {
        let mut expr = make_simple_let(
            1,
            LcnfLetValue::FVar(make_var(0)),
            LcnfExpr::Return(LcnfArg::Var(make_var(1))),
        );
        let pass = SpecializationPass::new(SpecializationConfig::default());
        pass.substitute_constant(&mut expr, make_var(0), &LcnfLit::Nat(42));
        let LcnfExpr::Let { value, .. } = &expr else {
            panic!("Expected Let");
        };
        assert_eq!(*value, LcnfLetValue::Lit(LcnfLit::Nat(42)));
    }

    #[test]
    fn test_substitute_constant_in_return() {
        let mut expr = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let pass = SpecializationPass::new(SpecializationConfig::default());
        pass.substitute_constant(&mut expr, make_var(0), &LcnfLit::Nat(99));
        assert_eq!(expr, LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(99))));
    }

    #[test]
    fn test_specialize_module_empty() {
        let mut module = LcnfModule::default();
        let config = SpecializationConfig::default();
        specialize_module(&mut module, &config);
        assert!(module.fun_decls.is_empty());
    }

    #[test]
    fn test_specialize_module_simple() {
        let body = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let decl = make_decl(
            "id",
            vec![make_param(0, "x", LcnfType::Var("a".to_string()))],
            body,
        );
        let mut module = LcnfModule {
            fun_decls: vec![decl],
            extern_decls: vec![],
            name: "test".to_string(),
            metadata: LcnfModuleMetadata::default(),
        };
        let config = SpecializationConfig::default();
        specialize_module(&mut module, &config);
        assert!(!module.fun_decls.is_empty());
    }

    #[test]
    fn test_closure_uniformity_analysis() {
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let decl = make_decl("apply", vec![make_param(0, "f", LcnfType::Object)], body);
        let sites = vec![apply_site(0, "double"), apply_site(1, "double")];
        let result = analyze_closure_uniformity(&decl, 0, &sites);
        assert_eq!(result, Some("double".to_string()));
    }

    #[test]
    fn test_closure_uniformity_non_uniform() {
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let decl = make_decl("apply", vec![make_param(0, "f", LcnfType::Object)], body);
        let sites = vec![apply_site(0, "double"), apply_site(1, "triple")];
        let result = analyze_closure_uniformity(&decl, 0, &sites);
        assert!(result.is_none());
    }

    #[test]
    fn test_find_specialization_sites() {
        let body = make_simple_let(
            1,
            LcnfLetValue::App(
                LcnfArg::Var(make_var(10)),
                vec![LcnfArg::Lit(LcnfLit::Nat(42))],
            ),
            LcnfExpr::Return(LcnfArg::Var(make_var(1))),
        );
        let known_consts: HashMap<LcnfVarId, LcnfLit> = HashMap::new();
        let mut known_fns: HashMap<LcnfVarId, String> = HashMap::new();
        known_fns.insert(make_var(10), "target_fn".to_string());
        let mut decl_names = HashSet::new();
        decl_names.insert("target_fn".to_string());
        let sites = find_specialization_sites(&body, &known_consts, &known_fns, &decl_names);
        assert_eq!(sites.len(), 1);
        assert_eq!(sites[0].callee, "target_fn");
        assert!(matches!(sites[0].const_args[0], SpecConstArg::Nat(42)));
    }

    #[test]
    fn test_create_specialization() {
        let body = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let decl = make_decl(
            "my_fn",
            vec![
                make_param(0, "x", LcnfType::Nat),
                make_param(1, "y", LcnfType::Nat),
            ],
            body,
        );
        let key = SpecializationKey {
            original: "my_fn".to_string(),
            type_args: vec![],
            const_args: vec![SpecConstArg::Nat(10), SpecConstArg::Unknown],
            closure_args: vec![],
        };
        let mut pass = SpecializationPass::new(SpecializationConfig::default());
        let spec = pass
            .create_specialization(&decl, &key)
            .expect("constant specialization of my_fn should succeed");
        assert!(spec.decl.name.contains("my_fn"));
        assert!(spec.decl.name.contains("_C0_N10"));
        assert_eq!(spec.decl.params.len(), 1);
    }

    #[test]
    fn test_stats_default() {
        let stats = SpecializationStats::default();
        assert_eq!(stats.type_specializations, 0);
        assert_eq!(stats.const_specializations, 0);
        assert_eq!(stats.closure_specializations, 0);
    }

    #[test]
    fn test_pass_fresh_id() {
        let mut pass = SpecializationPass::new(SpecializationConfig::default());
        let id1 = pass.fresh_id();
        let id2 = pass.fresh_id();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_substitute_in_tailcall() {
        let mut expr = LcnfExpr::TailCall(
            LcnfArg::Var(make_var(10)),
            vec![LcnfArg::Var(make_var(0)), LcnfArg::Var(make_var(1))],
        );
        let pass = SpecializationPass::new(SpecializationConfig::default());
        pass.substitute_constant(&mut expr, make_var(0), &LcnfLit::Nat(7));
        let LcnfExpr::TailCall(_, args) = &expr else {
            panic!("Expected TailCall");
        };
        assert_eq!(args[0], LcnfArg::Lit(LcnfLit::Nat(7)));
        assert_eq!(args[1], LcnfArg::Var(make_var(1)));
    }

    #[test]
    fn test_is_called_in_case() {
        let body = LcnfExpr::Case {
            scrutinee: make_var(1),
            scrutinee_ty: LcnfType::Nat,
            alts: vec![LcnfAlt {
                ctor_name: "True".to_string(),
                ctor_tag: 0,
                params: vec![],
                body: make_simple_let(
                    5,
                    LcnfLetValue::App(
                        LcnfArg::Var(make_var(0)),
                        vec![LcnfArg::Lit(LcnfLit::Nat(1))],
                    ),
                    LcnfExpr::Return(LcnfArg::Var(make_var(5))),
                ),
            }],
            default: None,
        };
        assert!(is_called_as_function(&body, make_var(0)));
        assert!(!is_called_as_function(&body, make_var(2)));
    }

    #[test]
    fn test_tailcall_specialization_site() {
        let expr = LcnfExpr::TailCall(
            LcnfArg::Var(make_var(10)),
            vec![LcnfArg::Lit(LcnfLit::Nat(5))],
        );
        let mut known_fns: HashMap<LcnfVarId, String> = HashMap::new();
        known_fns.insert(make_var(10), "recurse".to_string());
        let mut decl_names = HashSet::new();
        decl_names.insert("recurse".to_string());
        let known_consts: HashMap<LcnfVarId, LcnfLit> = HashMap::new();
        let sites = find_specialization_sites(&expr, &known_consts, &known_fns, &decl_names);
        assert_eq!(sites.len(), 1);
        assert_eq!(sites[0].callee, "recurse");
    }

    #[test]
    fn test_recursive_specialization_disabled() {
        let body = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let mut decl = make_decl("rec_fn", vec![make_param(0, "n", LcnfType::Nat)], body);
        decl.is_recursive = true;
        let key = SpecializationKey {
            original: "rec_fn".to_string(),
            type_args: vec![],
            const_args: vec![SpecConstArg::Nat(5)],
            closure_args: vec![],
        };
        let mut pass = SpecializationPass::new(SpecializationConfig {
            allow_recursive: false,
            ..SpecializationConfig::default()
        });
        let result = pass.create_specialization(&decl, &key);
        assert!(result.is_none());
    }
}
/// Unit tests for the generic specialization-pass infrastructure: pass
/// config/stats/registry, analysis cache, worklist, dominator tree,
/// liveness sets, constant-folding helpers and the dependency graph.
#[cfg(test)]
mod spec_infra_tests {
    use super::*;
    /// A freshly built config is enabled and reports its phase's name.
    #[test]
    pub(super) fn test_pass_config() {
        let config = SpecPassConfig::new("test_pass", SpecPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }
    /// Stats aggregate across runs. The expected 15.0 average implies the
    /// first `record_run` argument is the change count ((10 + 20) / 2).
    #[test]
    pub(super) fn test_pass_stats() {
        let mut stats = SpecPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }
    /// Disabled passes are registered but not counted as enabled, and stats
    /// can be updated and read back by pass name.
    #[test]
    pub(super) fn test_pass_registry() {
        let mut reg = SpecPassRegistry::new();
        reg.register(SpecPassConfig::new("pass_a", SpecPassPhase::Analysis));
        reg.register(SpecPassConfig::new("pass_b", SpecPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }
    /// One hit (`key1`) and one miss (`key2`) give a 0.5 hit rate.
    /// Invalidation marks the entry invalid but keeps it stored.
    #[test]
    pub(super) fn test_analysis_cache() {
        let mut cache = SpecAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        cache.invalidate("key1");
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }
    /// Duplicate pushes are rejected. Here the first pushed item is popped
    /// first, and popping removes it from the membership set.
    #[test]
    pub(super) fn test_worklist() {
        let mut wl = SpecWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }
    /// Tree 0 -> {1, 2}, 1 -> 3. Node 2 is not on 3's idom chain, so it does
    /// not dominate 3; every node dominates itself.
    #[test]
    pub(super) fn test_dominator_tree() {
        let mut dt = SpecDominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        assert!(dt.dominates(3, 3));
    }
    /// Defs and uses are tracked per block index.
    #[test]
    pub(super) fn test_liveness() {
        let mut liveness = SpecLivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }
    /// Division by zero does not fold (`None`); bitwise ops are infallible.
    #[test]
    pub(super) fn test_constant_folding() {
        assert_eq!(SpecConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(SpecConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(SpecConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            SpecConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(SpecConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }
    /// `add_dep(a, b)` records that `b` depends on `a` (see
    /// `dependencies_of(2) == [1]`), so `a` must precede `b` in the
    /// topological order.
    #[test]
    pub(super) fn test_dep_graph() {
        let mut g = SpecDepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        let pos: std::collections::HashMap<u32, usize> =
            topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}
/// Unit tests for the extended (`SpecExt*`) pass infrastructure: phases,
/// config builder, stats, registry, cache, worklist, dominator tree,
/// liveness, constant folder and dependency graph.
#[cfg(test)]
mod specext_pass_tests {
    use super::*;
    /// Phases are ranked Early < Middle < Late < Finalize.
    #[test]
    pub(super) fn test_specext_phase_order() {
        let ranked = [
            (SpecExtPassPhase::Early, 0),
            (SpecExtPassPhase::Middle, 1),
            (SpecExtPassPhase::Late, 2),
            (SpecExtPassPhase::Finalize, 3),
        ];
        for (phase, rank) in ranked {
            assert_eq!(phase.order(), rank);
        }
        assert!(SpecExtPassPhase::Early.is_early());
        assert!(!SpecExtPassPhase::Early.is_late());
    }
    /// Builder setters are reflected in the config; `disabled()` turns it off.
    #[test]
    pub(super) fn test_specext_config_builder() {
        let base = SpecExtPassConfig::new("p");
        let late = base.with_phase(SpecExtPassPhase::Late);
        let cfg = late.with_max_iter(50).with_debug(1);
        assert_eq!(cfg.name, "p");
        assert_eq!(cfg.max_iterations, 50);
        assert!(cfg.is_debug_enabled());
        assert!(cfg.enabled);
        let switched_off = cfg.disabled();
        assert!(!switched_off.enabled);
    }
    /// Two visits and one modification give 0.5 efficiency.
    #[test]
    pub(super) fn test_specext_stats() {
        let mut stats = SpecExtPassStats::new();
        stats.visit();
        stats.visit();
        stats.modify();
        stats.iterate();
        assert_eq!(stats.nodes_visited, 2);
        assert_eq!(stats.nodes_modified, 1);
        assert!(stats.changed);
        assert_eq!(stats.iterations, 1);
        assert!((stats.efficiency() - 0.5).abs() < 1e-9);
    }
    /// Disabled passes are excluded from `enabled_passes`; phase filtering works.
    #[test]
    pub(super) fn test_specext_registry() {
        let mut registry = SpecExtPassRegistry::new();
        registry.register(SpecExtPassConfig::new("a").with_phase(SpecExtPassPhase::Early));
        registry.register(SpecExtPassConfig::new("b").disabled());
        assert_eq!(registry.len(), 2);
        assert_eq!(registry.enabled_passes().len(), 1);
        assert_eq!(registry.passes_in_phase(&SpecExtPassPhase::Early).len(), 1);
    }
    /// A miss followed by a put and a hit yields a positive hit rate.
    #[test]
    pub(super) fn test_specext_cache() {
        let mut cache = SpecExtCache::new(4);
        assert!(cache.get(99).is_none());
        cache.put(99, vec![1, 2, 3]);
        assert_eq!(
            cache.get(99).expect("v should be present in map"),
            &[1u8, 2, 3]
        );
        assert!(cache.hit_rate() > 0.0);
        assert_eq!(cache.live_count(), 1);
    }
    /// Duplicate pushes are ignored; a popped item is no longer contained.
    #[test]
    pub(super) fn test_specext_worklist() {
        let mut queue = SpecExtWorklist::new(10);
        queue.push(5);
        queue.push(3);
        queue.push(5);
        assert_eq!(queue.len(), 2);
        assert!(queue.contains(5));
        let popped = queue.pop().expect("first should be available to pop");
        assert!(!queue.contains(popped));
    }
    /// Tree 0 -> {1, 2}, 1 -> {3, 4}; node 3 sits two levels below the root.
    #[test]
    pub(super) fn test_specext_dom_tree() {
        let mut tree = SpecExtDomTree::new(5);
        for (node, idom) in [(1, 0), (2, 0), (3, 1), (4, 1)] {
            tree.set_idom(node, idom);
        }
        assert!(tree.dominates(0, 3));
        assert!(tree.dominates(1, 4));
        assert!(!tree.dominates(2, 3));
        assert_eq!(tree.depth_of(3), 2);
    }
    /// Def/use queries are per block and do not leak across blocks.
    #[test]
    pub(super) fn test_specext_liveness() {
        let mut live = SpecExtLiveness::new(3);
        live.add_def(0, 1);
        live.add_use(1, 1);
        assert!(live.var_is_def_in_block(0, 1));
        assert!(live.var_is_used_in_block(1, 1));
        assert!(!live.var_is_def_in_block(1, 1));
    }
    /// Successful folds and failures (division by zero) are counted separately.
    #[test]
    pub(super) fn test_specext_const_folder() {
        let mut folder = SpecExtConstFolder::new();
        assert_eq!(folder.add_i64(3, 4), Some(7));
        assert_eq!(folder.div_i64(10, 0), None);
        assert_eq!(folder.mul_i64(6, 7), Some(42));
        assert_eq!(folder.and_i64(0b1100, 0b1010), 0b1000);
        assert_eq!(folder.fold_count(), 3);
        assert_eq!(folder.failure_count(), 1);
    }
    /// A simple chain 0 -> 1 -> 2 -> 3 is acyclic, sorts in order, is fully
    /// reachable from 0, and has one SCC per node.
    #[test]
    pub(super) fn test_specext_dep_graph() {
        let mut graph = SpecExtDepGraph::new(4);
        for (from, to) in [(0, 1), (1, 2), (2, 3)] {
            graph.add_edge(from, to);
        }
        assert!(!graph.has_cycle());
        assert_eq!(graph.topo_sort(), Some(vec![0, 1, 2, 3]));
        assert_eq!(graph.reachable(0).len(), 4);
        assert_eq!(graph.scc().len(), 4);
    }
}