1use crate::lcnf::*;
6use std::collections::{HashMap, HashSet};
7
8use super::types::{
9 NumericSpecializer, SizeBudget, SpecAnalysisCache, SpecCallSite, SpecClosureArg, SpecConstArg,
10 SpecConstantFoldingHelper, SpecDepGraph, SpecDominatorTree, SpecExtCache, SpecExtConstFolder,
11 SpecExtDepGraph, SpecExtDomTree, SpecExtLiveness, SpecExtPassConfig, SpecExtPassPhase,
12 SpecExtPassRegistry, SpecExtPassStats, SpecExtWorklist, SpecLivenessInfo, SpecPassConfig,
13 SpecPassPhase, SpecPassRegistry, SpecPassStats, SpecTypeArg, SpecWorklist, SpecializationCache,
14 SpecializationConfig, SpecializationKey, SpecializationPass, SpecializationStats,
15};
16
17pub(super) fn type_suffix(ty: &LcnfType) -> String {
19 match ty {
20 LcnfType::Nat => "nat".to_string(),
21 LcnfType::Object => "obj".to_string(),
22 LcnfType::Unit => "unit".to_string(),
23 LcnfType::Erased => "e".to_string(),
24 LcnfType::LcnfString => "str".to_string(),
25 LcnfType::Var(name) => name.clone(),
26 LcnfType::Ctor(name, _) => name.clone(),
27 LcnfType::Fun(_, _) => "fn".to_string(),
28 LcnfType::Irrelevant => "irr".to_string(),
29 }
30}
31pub(super) fn find_specialization_sites(
33 expr: &LcnfExpr,
34 known_constants: &HashMap<LcnfVarId, LcnfLit>,
35 known_functions: &HashMap<LcnfVarId, String>,
36 decl_names: &HashSet<String>,
37) -> Vec<SpecCallSite> {
38 let mut sites = Vec::new();
39 let mut call_idx = 0;
40 find_spec_sites_inner(
41 expr,
42 known_constants,
43 known_functions,
44 decl_names,
45 &mut sites,
46 &mut call_idx,
47 );
48 sites
49}
50pub(super) fn find_spec_sites_inner(
51 expr: &LcnfExpr,
52 known_constants: &HashMap<LcnfVarId, LcnfLit>,
53 known_functions: &HashMap<LcnfVarId, String>,
54 decl_names: &HashSet<String>,
55 sites: &mut Vec<SpecCallSite>,
56 call_idx: &mut usize,
57) {
58 match expr {
59 LcnfExpr::Let {
60 id, value, body, ..
61 } => {
62 let mut extended_consts = known_constants.clone();
63 if let LcnfLetValue::Lit(lit) = value {
64 extended_consts.insert(*id, lit.clone());
65 }
66 let mut extended_fns = known_functions.clone();
67 if let LcnfLetValue::FVar(fvar) = value {
68 if let Some(fname) = known_functions.get(fvar) {
69 extended_fns.insert(*id, fname.clone());
70 }
71 }
72 if let LcnfLetValue::App(func, args) = value {
73 let callee_name = match func {
74 LcnfArg::Var(v) => known_functions.get(v).cloned(),
75 _ => None,
76 };
77 if let Some(ref callee) = callee_name {
78 if decl_names.contains(callee.as_str()) {
79 let const_args: Vec<SpecConstArg> = args
80 .iter()
81 .map(|arg| match arg {
82 LcnfArg::Lit(LcnfLit::Nat(n)) => SpecConstArg::Nat(*n),
83 LcnfArg::Lit(LcnfLit::Str(s)) => SpecConstArg::Str(s.clone()),
84 LcnfArg::Var(v) => {
85 if let Some(lit) = extended_consts.get(v) {
86 match lit {
87 LcnfLit::Nat(n) => SpecConstArg::Nat(*n),
88 LcnfLit::Str(s) => SpecConstArg::Str(s.clone()),
89 }
90 } else {
91 SpecConstArg::Unknown
92 }
93 }
94 _ => SpecConstArg::Unknown,
95 })
96 .collect();
97 let closure_args: Vec<SpecClosureArg> = args
98 .iter()
99 .enumerate()
100 .map(|(i, arg)| {
101 let known_fn = match arg {
102 LcnfArg::Var(v) => extended_fns.get(v).cloned(),
103 _ => None,
104 };
105 SpecClosureArg {
106 known_fn,
107 param_idx: i,
108 }
109 })
110 .collect();
111 let callee_var = match func {
112 LcnfArg::Var(v) => Some(*v),
113 _ => None,
114 };
115 sites.push(SpecCallSite {
116 callee: callee.clone(),
117 call_idx: *call_idx,
118 type_args: vec![],
119 const_args,
120 closure_args,
121 callee_var,
122 });
123 *call_idx += 1;
124 }
125 }
126 }
127 find_spec_sites_inner(
128 body,
129 &extended_consts,
130 &extended_fns,
131 decl_names,
132 sites,
133 call_idx,
134 );
135 }
136 LcnfExpr::Case { alts, default, .. } => {
137 for alt in alts {
138 find_spec_sites_inner(
139 &alt.body,
140 known_constants,
141 known_functions,
142 decl_names,
143 sites,
144 call_idx,
145 );
146 }
147 if let Some(def) = default {
148 find_spec_sites_inner(
149 def,
150 known_constants,
151 known_functions,
152 decl_names,
153 sites,
154 call_idx,
155 );
156 }
157 }
158 LcnfExpr::TailCall(func, args) => {
159 let callee_name = match func {
160 LcnfArg::Var(v) => known_functions.get(v).cloned(),
161 _ => None,
162 };
163 if let Some(callee) = callee_name {
164 if decl_names.contains(callee.as_str()) {
165 let const_args: Vec<SpecConstArg> = args
166 .iter()
167 .map(|arg| match arg {
168 LcnfArg::Lit(LcnfLit::Nat(n)) => SpecConstArg::Nat(*n),
169 LcnfArg::Lit(LcnfLit::Str(s)) => SpecConstArg::Str(s.clone()),
170 LcnfArg::Var(v) => {
171 if let Some(lit) = known_constants.get(v) {
172 match lit {
173 LcnfLit::Nat(n) => SpecConstArg::Nat(*n),
174 LcnfLit::Str(s) => SpecConstArg::Str(s.clone()),
175 }
176 } else {
177 SpecConstArg::Unknown
178 }
179 }
180 _ => SpecConstArg::Unknown,
181 })
182 .collect();
183 let closure_args: Vec<SpecClosureArg> = args
184 .iter()
185 .enumerate()
186 .map(|(i, arg)| {
187 let known_fn = match arg {
188 LcnfArg::Var(v) => known_functions.get(v).cloned(),
189 _ => None,
190 };
191 SpecClosureArg {
192 known_fn,
193 param_idx: i,
194 }
195 })
196 .collect();
197 let callee_var = match func {
198 LcnfArg::Var(v) => Some(*v),
199 _ => None,
200 };
201 sites.push(SpecCallSite {
202 callee,
203 call_idx: *call_idx,
204 type_args: vec![],
205 const_args,
206 closure_args,
207 callee_var,
208 });
209 *call_idx += 1;
210 }
211 }
212 }
213 LcnfExpr::Return(_) | LcnfExpr::Unreachable => {}
214 }
215}
216pub(super) fn count_instructions(expr: &LcnfExpr) -> usize {
218 match expr {
219 LcnfExpr::Let { body, .. } => 1 + count_instructions(body),
220 LcnfExpr::Case { alts, default, .. } => {
221 let alts_size: usize = alts.iter().map(|a| count_instructions(&a.body)).sum();
222 let def_size = default.as_ref().map(|d| count_instructions(d)).unwrap_or(0);
223 1 + alts_size + def_size
224 }
225 LcnfExpr::Return(_) | LcnfExpr::TailCall(_, _) | LcnfExpr::Unreachable => 1,
226 }
227}
228pub(super) fn analyze_closure_uniformity(
230 decl: &LcnfFunDecl,
231 param_idx: usize,
232 sites: &[SpecCallSite],
233) -> Option<String> {
234 let mut known_fn: Option<String> = None;
235 for site in sites {
236 if site.callee != decl.name {
237 continue;
238 }
239 if param_idx >= site.closure_args.len() {
240 return None;
241 }
242 match &site.closure_args[param_idx].known_fn {
243 Some(fn_name) => {
244 if let Some(ref existing) = known_fn {
245 if existing != fn_name {
246 return None;
247 }
248 } else {
249 known_fn = Some(fn_name.clone());
250 }
251 }
252 None => return None,
253 }
254 }
255 known_fn
256}
257pub(super) fn is_called_as_function(expr: &LcnfExpr, param_id: LcnfVarId) -> bool {
259 match expr {
260 LcnfExpr::Let { value, body, .. } => {
261 let called_here = matches!(
262 value, LcnfLetValue::App(LcnfArg::Var(v), _) if * v == param_id
263 );
264 called_here || is_called_as_function(body, param_id)
265 }
266 LcnfExpr::Case { alts, default, .. } => {
267 alts.iter()
268 .any(|a| is_called_as_function(&a.body, param_id))
269 || default
270 .as_ref()
271 .is_some_and(|d| is_called_as_function(d, param_id))
272 }
273 LcnfExpr::TailCall(LcnfArg::Var(v), _) => *v == param_id,
274 _ => false,
275 }
276}
277pub fn specialize_module(module: &mut LcnfModule, config: &SpecializationConfig) {
279 let mut pass = SpecializationPass::new(config.clone());
280 pass.run(module);
281}
282pub fn specialize_numeric(decl: &LcnfFunDecl) -> Option<LcnfFunDecl> {
284 let specializer = NumericSpecializer::new();
285 if !specializer.is_numeric_op(&decl.name) {
286 return None;
287 }
288 let mut spec = decl.clone();
289 spec.name = format!("{}_u64", decl.name);
290 for param in &mut spec.params {
291 param.ty = specializer.specialize_nat_to_u64(¶m.ty);
292 }
293 spec.ret_type = specializer.specialize_nat_to_u64(&spec.ret_type);
294 Some(spec)
295}
296pub fn is_worth_specializing(decl: &LcnfFunDecl, config: &SpecializationConfig) -> bool {
298 let size = count_instructions(&decl.body);
299 if size > config.size_threshold {
300 return false;
301 }
302 let has_poly = decl
303 .params
304 .iter()
305 .any(|p| matches!(p.ty, LcnfType::Var(_) | LcnfType::Object | LcnfType::Erased));
306 let has_fn_param = decl
307 .params
308 .iter()
309 .any(|p| matches!(p.ty, LcnfType::Fun(_, _)));
310 has_poly || (has_fn_param && config.specialize_closures)
311}
// Unit tests for the free helpers in this file and for the `SpecializationPass` entry
// points (key triviality/mangling, cache, size budget, numeric specialization, site
// discovery, constant substitution and specialization creation).
#[cfg(test)]
mod tests {
    use super::*;
    /// Wraps a raw number as an `LcnfVarId`.
    pub(super) fn make_var(n: u64) -> LcnfVarId {
        LcnfVarId(n)
    }
    /// Builds a parameter with id `n` that is neither erased nor borrowed.
    pub(super) fn make_param(n: u64, name: &str, ty: LcnfType) -> LcnfParam {
        LcnfParam {
            id: LcnfVarId(n),
            name: name.to_string(),
            ty,
            erased: false,
            borrowed: false,
        }
    }
    /// Builds `let x<id> : Nat := value; body`.
    pub(super) fn make_simple_let(id: u64, value: LcnfLetValue, body: LcnfExpr) -> LcnfExpr {
        LcnfExpr::Let {
            id: LcnfVarId(id),
            name: format!("x{}", id),
            ty: LcnfType::Nat,
            value,
            body: Box::new(body),
        }
    }
    /// Builds a non-recursive, non-lifted declaration returning `Nat` with inline cost 1.
    pub(super) fn make_decl(name: &str, params: Vec<LcnfParam>, body: LcnfExpr) -> LcnfFunDecl {
        LcnfFunDecl {
            name: name.to_string(),
            original_name: None,
            params,
            ret_type: LcnfType::Nat,
            body,
            is_recursive: false,
            is_lifted: false,
            inline_cost: 1,
        }
    }
    // Pins the default configuration values.
    #[test]
    pub(super) fn test_config_default() {
        let config = SpecializationConfig::default();
        assert_eq!(config.max_specializations, 8);
        assert!(config.specialize_closures);
        assert!(config.specialize_numerics);
        assert_eq!(config.size_threshold, 200);
    }
    // A key with only polymorphic / unknown arguments is trivial.
    #[test]
    pub(super) fn test_spec_key_trivial() {
        let key = SpecializationKey {
            original: "foo".to_string(),
            type_args: vec![SpecTypeArg::Poly],
            const_args: vec![SpecConstArg::Unknown],
            closure_args: vec![SpecClosureArg {
                known_fn: None,
                param_idx: 0,
            }],
        };
        assert!(key.is_trivial());
    }
    // A concrete type argument makes the key non-trivial.
    #[test]
    pub(super) fn test_spec_key_non_trivial_type() {
        let key = SpecializationKey {
            original: "foo".to_string(),
            type_args: vec![SpecTypeArg::Concrete(LcnfType::Nat)],
            const_args: vec![],
            closure_args: vec![],
        };
        assert!(!key.is_trivial());
    }
    // A known constant argument makes the key non-trivial.
    #[test]
    pub(super) fn test_spec_key_non_trivial_const() {
        let key = SpecializationKey {
            original: "foo".to_string(),
            type_args: vec![],
            const_args: vec![SpecConstArg::Nat(42)],
            closure_args: vec![],
        };
        assert!(!key.is_trivial());
    }
    // Mangled names keep the original prefix and encode type arg 0 as `_T0_nat`.
    #[test]
    pub(super) fn test_spec_key_mangled_name() {
        let key = SpecializationKey {
            original: "List.map".to_string(),
            type_args: vec![SpecTypeArg::Concrete(LcnfType::Nat)],
            const_args: vec![SpecConstArg::Unknown],
            closure_args: vec![],
        };
        let name = key.mangled_name();
        assert!(name.starts_with("List.map"));
        assert!(name.contains("_T0_nat"));
    }
    // Constant arg 0 = Nat 3 is encoded as `_C0_N3`.
    #[test]
    pub(super) fn test_spec_key_mangled_name_with_const() {
        let key = SpecializationKey {
            original: "repeat".to_string(),
            type_args: vec![],
            const_args: vec![SpecConstArg::Nat(3)],
            closure_args: vec![],
        };
        let name = key.mangled_name();
        assert!(name.contains("_C0_N3"));
    }
    // A known closure `double` is encoded as `_Fdouble`.
    #[test]
    pub(super) fn test_spec_key_mangled_name_with_closure() {
        let key = SpecializationKey {
            original: "List.map".to_string(),
            type_args: vec![],
            const_args: vec![],
            closure_args: vec![SpecClosureArg {
                known_fn: Some("double".to_string()),
                param_idx: 0,
            }],
        };
        let name = key.mangled_name();
        assert!(name.contains("_Fdouble"));
    }
    // Fixed abbreviations produced by `type_suffix`.
    #[test]
    pub(super) fn test_type_suffix() {
        assert_eq!(type_suffix(&LcnfType::Nat), "nat");
        assert_eq!(type_suffix(&LcnfType::Object), "obj");
        assert_eq!(type_suffix(&LcnfType::Unit), "unit");
        assert_eq!(type_suffix(&LcnfType::LcnfString), "str");
    }
    // Insert/lookup round-trip plus per-original count and total growth bookkeeping.
    #[test]
    pub(super) fn test_cache_operations() {
        let mut cache = SpecializationCache::new();
        let key = SpecializationKey {
            original: "foo".to_string(),
            type_args: vec![SpecTypeArg::Concrete(LcnfType::Nat)],
            const_args: vec![],
            closure_args: vec![],
        };
        assert!(cache.lookup(&key).is_none());
        cache.insert(key.clone(), "foo_nat".to_string(), 10);
        assert_eq!(cache.lookup(&key), Some("foo_nat"));
        assert_eq!(cache.specialization_count("foo"), 1);
        assert_eq!(cache.total_growth, 10);
    }
    // Budget of 100: affordability is inclusive, and spending reduces what remains.
    #[test]
    pub(super) fn test_size_budget() {
        let mut budget = SizeBudget::new(100, 2.0);
        assert!(budget.can_afford(50));
        assert!(budget.can_afford(100));
        assert!(!budget.can_afford(101));
        budget.spend(50);
        assert!(budget.can_afford(50));
        assert!(!budget.can_afford(51));
        assert_eq!(budget.remaining(), 50);
    }
    // `Nat.add` / `Nat.mul` are numeric ops; `List.map` is not.
    #[test]
    pub(super) fn test_numeric_specializer() {
        let specializer = NumericSpecializer::new();
        assert!(specializer.is_numeric_op("Nat.add"));
        assert!(specializer.is_numeric_op("Nat.mul"));
        assert!(!specializer.is_numeric_op("List.map"));
    }
    // Pins the current result: this `Fun(Nat, Nat) -> Nat` type comes back unchanged.
    // NOTE(review): the method name suggests a Nat -> u64 rewrite; confirm this
    // expected value is intentional.
    #[test]
    pub(super) fn test_numeric_type_specialization() {
        let specializer = NumericSpecializer::new();
        let ty = LcnfType::Fun(vec![LcnfType::Nat, LcnfType::Nat], Box::new(LcnfType::Nat));
        let spec = specializer.specialize_nat_to_u64(&ty);
        assert_eq!(
            spec,
            LcnfType::Fun(vec![LcnfType::Nat, LcnfType::Nat], Box::new(LcnfType::Nat))
        );
    }
    // A numeric op is specialized and renamed with a `_u64` suffix.
    #[test]
    pub(super) fn test_specialize_numeric() {
        let body = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let decl = make_decl(
            "Nat.add",
            vec![
                make_param(0, "a", LcnfType::Nat),
                make_param(1, "b", LcnfType::Nat),
            ],
            body,
        );
        let result = specialize_numeric(&decl);
        assert!(result.is_some());
        let spec = result.expect("spec should be Some/Ok");
        assert_eq!(spec.name, "Nat.add_u64");
    }
    // Non-numeric declarations are not specialized.
    #[test]
    pub(super) fn test_specialize_numeric_non_numeric() {
        let body = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let decl = make_decl("List.map", vec![make_param(0, "f", LcnfType::Object)], body);
        let result = specialize_numeric(&decl);
        assert!(result.is_none());
    }
    // A type-variable parameter makes a small function worth specializing.
    #[test]
    pub(super) fn test_is_worth_specializing_polymorphic() {
        let body = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let decl = make_decl(
            "id",
            vec![make_param(0, "x", LcnfType::Var("a".to_string()))],
            body,
        );
        let config = SpecializationConfig::default();
        assert!(is_worth_specializing(&decl, &config));
    }
    // All-`Nat` parameters: nothing to specialize.
    #[test]
    pub(super) fn test_is_worth_specializing_concrete() {
        let body = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let decl = make_decl(
            "add",
            vec![
                make_param(0, "a", LcnfType::Nat),
                make_param(1, "b", LcnfType::Nat),
            ],
            body,
        );
        let config = SpecializationConfig::default();
        assert!(!is_worth_specializing(&decl, &config));
    }
    // A function-typed parameter qualifies when closure specialization is on (default).
    #[test]
    pub(super) fn test_is_worth_specializing_higher_order() {
        let body = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let fn_ty = LcnfType::Fun(vec![LcnfType::Nat], Box::new(LcnfType::Nat));
        let decl = make_decl("apply", vec![make_param(0, "f", fn_ty)], body);
        let config = SpecializationConfig::default();
        assert!(is_worth_specializing(&decl, &config));
    }
    // `let x5 := x0 1` calls var 0 but not var 1.
    #[test]
    pub(super) fn test_is_called_as_function() {
        let body = make_simple_let(
            5,
            LcnfLetValue::App(
                LcnfArg::Var(make_var(0)),
                vec![LcnfArg::Lit(LcnfLit::Nat(1))],
            ),
            LcnfExpr::Return(LcnfArg::Var(make_var(5))),
        );
        assert!(is_called_as_function(&body, make_var(0)));
        assert!(!is_called_as_function(&body, make_var(1)));
    }
    // Two lets plus a return count as 3.
    #[test]
    pub(super) fn test_count_instructions() {
        let expr = make_simple_let(
            1,
            LcnfLetValue::Lit(LcnfLit::Nat(42)),
            make_simple_let(
                2,
                LcnfLetValue::Lit(LcnfLit::Nat(10)),
                LcnfExpr::Return(LcnfArg::Var(make_var(2))),
            ),
        );
        assert_eq!(count_instructions(&expr), 3);
    }
    // Substituting var 0 := 42 turns `let x1 := x0` into `let x1 := 42`.
    #[test]
    pub(super) fn test_substitute_constant() {
        let mut expr = make_simple_let(
            1,
            LcnfLetValue::FVar(make_var(0)),
            LcnfExpr::Return(LcnfArg::Var(make_var(1))),
        );
        let pass = SpecializationPass::new(SpecializationConfig::default());
        pass.substitute_constant(&mut expr, make_var(0), &LcnfLit::Nat(42));
        if let LcnfExpr::Let { value, .. } = &expr {
            assert_eq!(*value, LcnfLetValue::Lit(LcnfLit::Nat(42)));
        } else {
            panic!("Expected Let");
        }
    }
    // Substitution also rewrites the returned argument.
    #[test]
    pub(super) fn test_substitute_constant_in_return() {
        let mut expr = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let pass = SpecializationPass::new(SpecializationConfig::default());
        pass.substitute_constant(&mut expr, make_var(0), &LcnfLit::Nat(99));
        assert_eq!(expr, LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(99))));
    }
    // Running the pass on an empty module leaves it empty.
    #[test]
    pub(super) fn test_specialize_module_empty() {
        let mut module = LcnfModule::default();
        let config = SpecializationConfig::default();
        specialize_module(&mut module, &config);
        assert!(module.fun_decls.is_empty());
    }
    // Running the pass keeps the module's declarations non-empty.
    #[test]
    pub(super) fn test_specialize_module_simple() {
        let body = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let decl = make_decl(
            "id",
            vec![make_param(0, "x", LcnfType::Var("a".to_string()))],
            body,
        );
        let mut module = LcnfModule {
            fun_decls: vec![decl],
            extern_decls: vec![],
            name: "test".to_string(),
            metadata: LcnfModuleMetadata::default(),
        };
        let config = SpecializationConfig::default();
        specialize_module(&mut module, &config);
        assert!(!module.fun_decls.is_empty());
    }
    // Two sites passing `double` at param 0 agree on `double`.
    #[test]
    pub(super) fn test_closure_uniformity_analysis() {
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let decl = make_decl("apply", vec![make_param(0, "f", LcnfType::Object)], body);
        let sites = vec![
            SpecCallSite {
                callee: "apply".to_string(),
                call_idx: 0,
                type_args: vec![],
                const_args: vec![],
                closure_args: vec![SpecClosureArg {
                    known_fn: Some("double".to_string()),
                    param_idx: 0,
                }],
                callee_var: None,
            },
            SpecCallSite {
                callee: "apply".to_string(),
                call_idx: 1,
                type_args: vec![],
                const_args: vec![],
                closure_args: vec![SpecClosureArg {
                    known_fn: Some("double".to_string()),
                    param_idx: 0,
                }],
                callee_var: None,
            },
        ];
        let result = analyze_closure_uniformity(&decl, 0, &sites);
        assert_eq!(result, Some("double".to_string()));
    }
    // `double` vs `triple` at the same position: not uniform.
    #[test]
    pub(super) fn test_closure_uniformity_non_uniform() {
        let body = LcnfExpr::Return(LcnfArg::Lit(LcnfLit::Nat(0)));
        let decl = make_decl("apply", vec![make_param(0, "f", LcnfType::Object)], body);
        let sites = vec![
            SpecCallSite {
                callee: "apply".to_string(),
                call_idx: 0,
                type_args: vec![],
                const_args: vec![],
                closure_args: vec![SpecClosureArg {
                    known_fn: Some("double".to_string()),
                    param_idx: 0,
                }],
                callee_var: None,
            },
            SpecCallSite {
                callee: "apply".to_string(),
                call_idx: 1,
                type_args: vec![],
                const_args: vec![],
                closure_args: vec![SpecClosureArg {
                    known_fn: Some("triple".to_string()),
                    param_idx: 0,
                }],
                callee_var: None,
            },
        ];
        let result = analyze_closure_uniformity(&decl, 0, &sites);
        assert!(result.is_none());
    }
    // `let x1 := x10 42` with x10 known as `target_fn` yields one site with const arg 42.
    #[test]
    pub(super) fn test_find_specialization_sites() {
        let body = make_simple_let(
            1,
            LcnfLetValue::App(
                LcnfArg::Var(make_var(10)),
                vec![LcnfArg::Lit(LcnfLit::Nat(42))],
            ),
            LcnfExpr::Return(LcnfArg::Var(make_var(1))),
        );
        let known_consts: HashMap<LcnfVarId, LcnfLit> = HashMap::new();
        let mut known_fns: HashMap<LcnfVarId, String> = HashMap::new();
        known_fns.insert(make_var(10), "target_fn".to_string());
        let mut decl_names = HashSet::new();
        decl_names.insert("target_fn".to_string());
        let sites = find_specialization_sites(&body, &known_consts, &known_fns, &decl_names);
        assert_eq!(sites.len(), 1);
        assert_eq!(sites[0].callee, "target_fn");
        assert!(matches!(sites[0].const_args[0], SpecConstArg::Nat(42)));
    }
    // Specializing param 0 to the constant 10 mangles the name and drops that parameter.
    #[test]
    pub(super) fn test_create_specialization() {
        let body = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let decl = make_decl(
            "my_fn",
            vec![
                make_param(0, "x", LcnfType::Nat),
                make_param(1, "y", LcnfType::Nat),
            ],
            body,
        );
        let key = SpecializationKey {
            original: "my_fn".to_string(),
            type_args: vec![],
            const_args: vec![SpecConstArg::Nat(10), SpecConstArg::Unknown],
            closure_args: vec![],
        };
        let mut pass = SpecializationPass::new(SpecializationConfig::default());
        let result = pass.create_specialization(&decl, &key);
        assert!(result.is_some());
        let spec = result.expect("spec should be Some/Ok");
        assert!(spec.decl.name.contains("my_fn"));
        assert!(spec.decl.name.contains("_C0_N10"));
        assert_eq!(spec.decl.params.len(), 1);
    }
    // Default stats start at zero.
    #[test]
    pub(super) fn test_stats_default() {
        let stats = SpecializationStats::default();
        assert_eq!(stats.type_specializations, 0);
        assert_eq!(stats.const_specializations, 0);
        assert_eq!(stats.closure_specializations, 0);
    }
    // Consecutive fresh ids differ.
    #[test]
    pub(super) fn test_pass_fresh_id() {
        let mut pass = SpecializationPass::new(SpecializationConfig::default());
        let id1 = pass.fresh_id();
        let id2 = pass.fresh_id();
        assert_ne!(id1, id2);
    }
    // Substitution rewrites only the matching tail-call argument.
    #[test]
    pub(super) fn test_substitute_in_tailcall() {
        let mut expr = LcnfExpr::TailCall(
            LcnfArg::Var(make_var(10)),
            vec![LcnfArg::Var(make_var(0)), LcnfArg::Var(make_var(1))],
        );
        let pass = SpecializationPass::new(SpecializationConfig::default());
        pass.substitute_constant(&mut expr, make_var(0), &LcnfLit::Nat(7));
        if let LcnfExpr::TailCall(_, args) = &expr {
            assert_eq!(args[0], LcnfArg::Lit(LcnfLit::Nat(7)));
            assert_eq!(args[1], LcnfArg::Var(make_var(1)));
        } else {
            panic!("Expected TailCall");
        }
    }
    // A call inside a `case` alternative is found.
    #[test]
    pub(super) fn test_is_called_in_case() {
        let body = LcnfExpr::Case {
            scrutinee: make_var(1),
            scrutinee_ty: LcnfType::Nat,
            alts: vec![LcnfAlt {
                ctor_name: "True".to_string(),
                ctor_tag: 0,
                params: vec![],
                body: make_simple_let(
                    5,
                    LcnfLetValue::App(
                        LcnfArg::Var(make_var(0)),
                        vec![LcnfArg::Lit(LcnfLit::Nat(1))],
                    ),
                    LcnfExpr::Return(LcnfArg::Var(make_var(5))),
                ),
            }],
            default: None,
        };
        assert!(is_called_as_function(&body, make_var(0)));
        assert!(!is_called_as_function(&body, make_var(2)));
    }
    // A tail call to a known declaration is recorded as a site.
    #[test]
    pub(super) fn test_tailcall_specialization_site() {
        let expr = LcnfExpr::TailCall(
            LcnfArg::Var(make_var(10)),
            vec![LcnfArg::Lit(LcnfLit::Nat(5))],
        );
        let mut known_fns: HashMap<LcnfVarId, String> = HashMap::new();
        known_fns.insert(make_var(10), "recurse".to_string());
        let mut decl_names = HashSet::new();
        decl_names.insert("recurse".to_string());
        let known_consts: HashMap<LcnfVarId, LcnfLit> = HashMap::new();
        let sites = find_specialization_sites(&expr, &known_consts, &known_fns, &decl_names);
        assert_eq!(sites.len(), 1);
        assert_eq!(sites[0].callee, "recurse");
    }
    // With `allow_recursive: false`, a recursive declaration is not specialized.
    #[test]
    pub(super) fn test_recursive_specialization_disabled() {
        let body = LcnfExpr::Return(LcnfArg::Var(make_var(0)));
        let mut decl = make_decl("rec_fn", vec![make_param(0, "n", LcnfType::Nat)], body);
        decl.is_recursive = true;
        let key = SpecializationKey {
            original: "rec_fn".to_string(),
            type_args: vec![],
            const_args: vec![SpecConstArg::Nat(5)],
            closure_args: vec![],
        };
        let mut pass = SpecializationPass::new(SpecializationConfig {
            allow_recursive: false,
            ..SpecializationConfig::default()
        });
        let result = pass.create_specialization(&decl, &key);
        assert!(result.is_none());
    }
}
// Tests for the generic pass infrastructure types re-exported from `super::types`
// (pass config/stats/registry, analysis cache, worklist, dominator tree, liveness,
// constant folding and dependency graph).
// Module name is snake_case so rustc's `non_snake_case` lint stays quiet.
#[cfg(test)]
mod spec_infra_tests {
    use super::*;
    // A new config is enabled, and the transformation phase is a modifying phase.
    #[test]
    pub(super) fn test_pass_config() {
        let config = SpecPassConfig::new("test_pass", SpecPassPhase::Transformation);
        assert!(config.enabled);
        assert!(config.phase.is_modifying());
        assert_eq!(config.phase.name(), "transformation");
    }
    // Two runs with 10 and 20 changes average to 15.
    #[test]
    pub(super) fn test_pass_stats() {
        let mut stats = SpecPassStats::new();
        stats.record_run(10, 100, 3);
        stats.record_run(20, 200, 5);
        assert_eq!(stats.total_runs, 2);
        assert!((stats.average_changes_per_run() - 15.0).abs() < 0.01);
        assert!((stats.success_rate() - 1.0).abs() < 0.01);
        let s = stats.format_summary();
        assert!(s.contains("Runs: 2/2"));
    }
    // Disabled passes are registered but not counted as enabled; stats update by name.
    #[test]
    pub(super) fn test_pass_registry() {
        let mut reg = SpecPassRegistry::new();
        reg.register(SpecPassConfig::new("pass_a", SpecPassPhase::Analysis));
        reg.register(SpecPassConfig::new("pass_b", SpecPassPhase::Transformation).disabled());
        assert_eq!(reg.total_passes(), 2);
        assert_eq!(reg.enabled_count(), 1);
        reg.update_stats("pass_a", 5, 50, 2);
        let stats = reg.get_stats("pass_a").expect("stats should exist");
        assert_eq!(stats.total_changes, 5);
    }
    // One hit and one miss give a 0.5 hit rate; invalidation keeps the entry but marks it
    // invalid.
    #[test]
    pub(super) fn test_analysis_cache() {
        let mut cache = SpecAnalysisCache::new(10);
        cache.insert("key1".to_string(), vec![1, 2, 3]);
        assert!(cache.get("key1").is_some());
        assert!(cache.get("key2").is_none());
        assert!((cache.hit_rate() - 0.5).abs() < 0.01);
        cache.invalidate("key1");
        assert!(!cache.entries["key1"].valid);
        assert_eq!(cache.size(), 1);
    }
    // Duplicate pushes are rejected; pop returns the first pushed item.
    #[test]
    pub(super) fn test_worklist() {
        let mut wl = SpecWorklist::new();
        assert!(wl.push(1));
        assert!(wl.push(2));
        assert!(!wl.push(1));
        assert_eq!(wl.len(), 2);
        assert_eq!(wl.pop(), Some(1));
        assert!(!wl.contains(1));
        assert!(wl.contains(2));
    }
    // Dominance follows the idom chain and is reflexive.
    #[test]
    pub(super) fn test_dominator_tree() {
        let mut dt = SpecDominatorTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 3));
        assert!(!dt.dominates(2, 3));
        assert!(dt.dominates(3, 3));
    }
    // Defs and uses are recorded per block.
    #[test]
    pub(super) fn test_liveness() {
        let mut liveness = SpecLivenessInfo::new(3);
        liveness.add_def(0, 1);
        liveness.add_use(1, 1);
        assert!(liveness.defs[0].contains(&1));
        assert!(liveness.uses[1].contains(&1));
    }
    // Division by zero does not fold; bitwise ops fold directly.
    #[test]
    pub(super) fn test_constant_folding() {
        assert_eq!(SpecConstantFoldingHelper::fold_add_i64(3, 4), Some(7));
        assert_eq!(SpecConstantFoldingHelper::fold_div_i64(10, 0), None);
        assert_eq!(SpecConstantFoldingHelper::fold_div_i64(10, 2), Some(5));
        assert_eq!(
            SpecConstantFoldingHelper::fold_bitand_i64(0b1100, 0b1010),
            0b1000
        );
        assert_eq!(SpecConstantFoldingHelper::fold_bitnot_i64(0), -1);
    }
    // An acyclic graph sorts topologically with every edge respected.
    #[test]
    pub(super) fn test_dep_graph() {
        let mut g = SpecDepGraph::new();
        g.add_dep(1, 2);
        g.add_dep(2, 3);
        g.add_dep(1, 3);
        assert_eq!(g.dependencies_of(2), vec![1]);
        let topo = g.topological_sort();
        assert_eq!(topo.len(), 3);
        assert!(!g.has_cycle());
        let pos: HashMap<u32, usize> = topo.iter().enumerate().map(|(i, &n)| (n, i)).collect();
        assert!(pos[&1] < pos[&2]);
        assert!(pos[&1] < pos[&3]);
        assert!(pos[&2] < pos[&3]);
    }
}
// Tests for the extended (`SpecExt*`) pass infrastructure types: phase ordering, config
// builder, stats, registry, cache, worklist, dominator tree, liveness, constant folder
// and dependency graph.
#[cfg(test)]
mod specext_pass_tests {
    use super::*;
    // Phases are ordered Early < Middle < Late < Finalize.
    #[test]
    pub(super) fn test_specext_phase_order() {
        assert_eq!(SpecExtPassPhase::Early.order(), 0);
        assert_eq!(SpecExtPassPhase::Middle.order(), 1);
        assert_eq!(SpecExtPassPhase::Late.order(), 2);
        assert_eq!(SpecExtPassPhase::Finalize.order(), 3);
        assert!(SpecExtPassPhase::Early.is_early());
        assert!(!SpecExtPassPhase::Early.is_late());
    }
    // Builder methods set the fields they name; `disabled()` turns `enabled` off.
    #[test]
    pub(super) fn test_specext_config_builder() {
        let c = SpecExtPassConfig::new("p")
            .with_phase(SpecExtPassPhase::Late)
            .with_max_iter(50)
            .with_debug(1);
        assert_eq!(c.name, "p");
        assert_eq!(c.max_iterations, 50);
        assert!(c.is_debug_enabled());
        assert!(c.enabled);
        let c2 = c.disabled();
        assert!(!c2.enabled);
    }
    // 2 visits and 1 modification give an efficiency of 0.5.
    #[test]
    pub(super) fn test_specext_stats() {
        let mut s = SpecExtPassStats::new();
        s.visit();
        s.visit();
        s.modify();
        s.iterate();
        assert_eq!(s.nodes_visited, 2);
        assert_eq!(s.nodes_modified, 1);
        assert!(s.changed);
        assert_eq!(s.iterations, 1);
        let e = s.efficiency();
        assert!((e - 0.5).abs() < 1e-9);
    }
    // Disabled passes are excluded from `enabled_passes`; phase filtering works.
    #[test]
    pub(super) fn test_specext_registry() {
        let mut r = SpecExtPassRegistry::new();
        r.register(SpecExtPassConfig::new("a").with_phase(SpecExtPassPhase::Early));
        r.register(SpecExtPassConfig::new("b").disabled());
        assert_eq!(r.len(), 2);
        assert_eq!(r.enabled_passes().len(), 1);
        assert_eq!(r.passes_in_phase(&SpecExtPassPhase::Early).len(), 1);
    }
    // Miss, then put, then hit; the hit rate becomes positive.
    #[test]
    pub(super) fn test_specext_cache() {
        let mut c = SpecExtCache::new(4);
        assert!(c.get(99).is_none());
        c.put(99, vec![1, 2, 3]);
        let v = c.get(99).expect("v should be present in map");
        assert_eq!(v, &[1u8, 2, 3]);
        assert!(c.hit_rate() > 0.0);
        assert_eq!(c.live_count(), 1);
    }
    // Duplicate pushes are ignored; a popped item is no longer contained.
    #[test]
    pub(super) fn test_specext_worklist() {
        let mut w = SpecExtWorklist::new(10);
        w.push(5);
        w.push(3);
        w.push(5);
        assert_eq!(w.len(), 2);
        assert!(w.contains(5));
        let first = w.pop().expect("first should be available to pop");
        assert!(!w.contains(first));
    }
    // Dominance and depth along the idom chain 0 -> 1 -> {3, 4}.
    #[test]
    pub(super) fn test_specext_dom_tree() {
        let mut dt = SpecExtDomTree::new(5);
        dt.set_idom(1, 0);
        dt.set_idom(2, 0);
        dt.set_idom(3, 1);
        dt.set_idom(4, 1);
        assert!(dt.dominates(0, 3));
        assert!(dt.dominates(1, 4));
        assert!(!dt.dominates(2, 3));
        assert_eq!(dt.depth_of(3), 2);
    }
    // Def/use queries are per (block, var).
    #[test]
    pub(super) fn test_specext_liveness() {
        let mut lv = SpecExtLiveness::new(3);
        lv.add_def(0, 1);
        lv.add_use(1, 1);
        assert!(lv.var_is_def_in_block(0, 1));
        assert!(lv.var_is_used_in_block(1, 1));
        assert!(!lv.var_is_def_in_block(1, 1));
    }
    // Counters depend on call order: three folds so far, one failure (division by zero).
    #[test]
    pub(super) fn test_specext_const_folder() {
        let mut cf = SpecExtConstFolder::new();
        assert_eq!(cf.add_i64(3, 4), Some(7));
        assert_eq!(cf.div_i64(10, 0), None);
        assert_eq!(cf.mul_i64(6, 7), Some(42));
        assert_eq!(cf.and_i64(0b1100, 0b1010), 0b1000);
        assert_eq!(cf.fold_count(), 3);
        assert_eq!(cf.failure_count(), 1);
    }
    // A simple chain is acyclic, sorts in order, and has one SCC per node.
    #[test]
    pub(super) fn test_specext_dep_graph() {
        let mut g = SpecExtDepGraph::new(4);
        g.add_edge(0, 1);
        g.add_edge(1, 2);
        g.add_edge(2, 3);
        assert!(!g.has_cycle());
        assert_eq!(g.topo_sort(), Some(vec![0, 1, 2, 3]));
        assert_eq!(g.reachable(0).len(), 4);
        let sccs = g.scc();
        assert_eq!(sccs.len(), 4);
    }
}