1use std::collections::{HashMap, HashSet, VecDeque};
11
12use crate::cps_ir::{
13 CpsContinuationId, CpsHandlerEnv, CpsRecordField, CpsShotKind, CpsStmt, CpsTerminator,
14 CpsValueId,
15};
16use crate::cps_repr_abi::{
17 CpsReprAbiContinuation, CpsReprAbiFunction, CpsReprAbiModule, CpsReprAbiValue,
18};
19use yulang_typed_ir as typed_ir;
20
/// Result of [`optimize_cps_repr_abi_module`]: the optimized module together with the
/// statistics gathered while producing it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CpsOptimizationOutput {
    /// The module after all simplification rounds have been applied (starts as a clone of
    /// the input module).
    pub module: CpsReprAbiModule,
    /// Size measurements and per-pass counters collected during optimization.
    pub profile: CpsOptimizationProfile,
}
26
/// Size measurements and per-pass statistics collected by [`optimize_cps_repr_abi_module`].
///
/// Pass counters are only ever incremented by the passes in this module, so after
/// optimization they hold totals over every simplification round that ran.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct CpsOptimizationProfile {
    // The next five fields are filled in by `CpsOptimizationProfile::measure` on the input
    // module (defined outside this section); presumably they count the module's functions,
    // roots, continuations, handlers and statements — TODO confirm against `measure`.
    pub functions: usize,
    pub roots: usize,
    pub continuations: usize,
    pub handlers: usize,
    pub statements: usize,
    // Filled in by `record_optimized_size` once the rounds are done (defined outside this
    // section); presumably the continuation/statement counts of the optimized module.
    pub optimized_continuations: usize,
    pub optimized_statements: usize,
    /// Number of pass invocations; every pass increments it once per call.
    pub passes_run: usize,
    /// Terminators rewritten by `rewrite_terminator_forwarders`.
    pub forwarded_continuation_calls: usize,
    /// Terminators rewritten by `rewrite_terminator_returners`.
    pub returned_continuation_calls: usize,
    /// `Branch` terminators on a local boolean literal folded into a `Continue`.
    pub folded_constant_branches: usize,
    /// `EffectfulCall`s to pure functions turned into a `DirectCall` plus `Continue`.
    pub rewritten_pure_effectful_calls: usize,
    /// `DirectCall`s to primitive wrapper functions replaced by `Primitive` statements.
    pub reified_primitive_calls: usize,
    /// Closure applications reified using closures created in the same continuation.
    pub reified_partial_closure_calls: usize,
    /// Closure applications reified through parameters known to carry a closure wrapper.
    pub reified_known_closure_parameter_calls: usize,
    /// Continuation parameters removed because they were never used.
    pub removed_unused_continuation_params: usize,
    /// Count reported by `fold_structural_projections_in_continuation`.
    pub folded_structural_projections: usize,
    /// `DirectCall`s replaced by the callee's pure statement body.
    pub inlined_pure_direct_calls: usize,
    /// Count reported by `inline_continuation_call_at`.
    pub inlined_continuation_calls: usize,
    /// Continuations dropped because `reachable_continuations` did not include them.
    pub removed_unreachable_continuations: usize,
    /// Count reported by `eliminate_dead_pure_statements_in_continuation`.
    pub removed_dead_pure_statements: usize,
    /// Number of direct-style islands found after optimization.
    pub direct_style_islands: usize,
    /// Total number of continuations contained in those islands.
    pub direct_style_continuations: usize,
    /// Set by the passes from their counters: true once any pass has reported a rewrite.
    pub changed: bool,
}
54
55pub fn optimize_cps_repr_abi_module(module: &CpsReprAbiModule) -> CpsOptimizationOutput {
56 let mut output = CpsOptimizationOutput {
57 module: module.clone(),
58 profile: CpsOptimizationProfile::measure(module),
59 };
60
61 for _ in 0..4 {
62 if !run_simplification_round(&mut output) {
63 break;
64 }
65 }
66 output.profile.record_optimized_size(&output.module);
67 analyze_direct_style_islands(&mut output);
68 maybe_trace_profile(&output.profile);
69 output
70}
71
72fn run_simplification_round(output: &mut CpsOptimizationOutput) -> bool {
73 let before = output.profile;
74 rewrite_forwarding_continuation_calls(output);
75 rewrite_returning_continuation_calls(output);
76 fold_constant_branches(output);
77 rewrite_pure_effectful_calls(output);
78 reify_direct_primitive_calls(output);
79 reify_local_partial_closure_calls(output);
80 reify_known_closure_parameter_calls(output);
81 remove_unused_continuation_params(output);
82 fold_structural_projections(output);
83 inline_pure_direct_calls(output);
84 inline_single_use_continuation_calls(output);
85 reify_local_partial_closure_calls(output);
86 reify_known_closure_parameter_calls(output);
87 remove_unused_continuation_params(output);
88 prune_unreachable_continuations(output);
89 eliminate_dead_pure_statements(output);
90 prune_unreachable_continuations(output);
91 output.profile.has_more_changes_than(before)
92}
93
94fn rewrite_forwarding_continuation_calls(output: &mut CpsOptimizationOutput) {
95 output.profile.passes_run += 1;
96 for function in output
97 .module
98 .functions
99 .iter_mut()
100 .chain(&mut output.module.roots)
101 {
102 let forwarders = forwarding_continuations(function);
103 if forwarders.is_empty() {
104 continue;
105 }
106 for continuation in &mut function.continuations {
107 output.profile.forwarded_continuation_calls +=
108 rewrite_terminator_forwarders(&mut continuation.terminator, &forwarders);
109 }
110 }
111 output.profile.changed = output.profile.forwarded_continuation_calls > 0;
112}
113
114fn rewrite_returning_continuation_calls(output: &mut CpsOptimizationOutput) {
115 output.profile.passes_run += 1;
116 for function in output
117 .module
118 .functions
119 .iter_mut()
120 .chain(&mut output.module.roots)
121 {
122 let returners = returning_continuations(function);
123 if returners.is_empty() {
124 continue;
125 }
126 for continuation in &mut function.continuations {
127 output.profile.returned_continuation_calls +=
128 rewrite_terminator_returners(&mut continuation.terminator, &returners);
129 }
130 }
131 output.profile.changed |= output.profile.returned_continuation_calls > 0;
132}
133
134fn fold_constant_branches(output: &mut CpsOptimizationOutput) {
135 output.profile.passes_run += 1;
136 for function in output
137 .module
138 .functions
139 .iter_mut()
140 .chain(&mut output.module.roots)
141 {
142 let empty_param_continuations = function
143 .continuations
144 .iter()
145 .filter(|continuation| continuation.params.is_empty())
146 .map(|continuation| continuation.id)
147 .collect::<HashSet<_>>();
148 for continuation in &mut function.continuations {
149 output.profile.folded_constant_branches +=
150 fold_constant_branch_in_continuation(continuation, &empty_param_continuations);
151 }
152 }
153 output.profile.changed |= output.profile.folded_constant_branches > 0;
154}
155
156fn rewrite_pure_effectful_calls(output: &mut CpsOptimizationOutput) {
157 output.profile.passes_run += 1;
158 let pure_functions = pure_callable_functions(&output.module);
159 if pure_functions.is_empty() {
160 return;
161 }
162 for function in output
163 .module
164 .functions
165 .iter_mut()
166 .chain(&mut output.module.roots)
167 {
168 output.profile.rewritten_pure_effectful_calls +=
169 rewrite_pure_effectful_calls_in_function(function, &pure_functions);
170 }
171 output.profile.changed |= output.profile.rewritten_pure_effectful_calls > 0;
172}
173
174fn reify_direct_primitive_calls(output: &mut CpsOptimizationOutput) {
175 output.profile.passes_run += 1;
176 let primitives = primitive_wrappers(&output.module);
177 if primitives.is_empty() {
178 return;
179 }
180 for function in output
181 .module
182 .functions
183 .iter_mut()
184 .chain(&mut output.module.roots)
185 {
186 for continuation in &mut function.continuations {
187 for stmt in &mut continuation.stmts {
188 output.profile.reified_primitive_calls +=
189 reify_direct_primitive_stmt(stmt, &primitives);
190 }
191 }
192 }
193 output.profile.changed |= output.profile.reified_primitive_calls > 0;
194}
195
196fn reify_local_partial_closure_calls(output: &mut CpsOptimizationOutput) {
197 output.profile.passes_run += 1;
198 for function in output
199 .module
200 .functions
201 .iter_mut()
202 .chain(&mut output.module.roots)
203 {
204 let wrappers = partial_closure_wrappers(function);
205 if wrappers.is_empty() {
206 continue;
207 }
208 let resumable = scalar_resume_continuations(function);
209 let mut next_value = next_function_value_id(function);
210 for continuation in &mut function.continuations {
211 output.profile.reified_partial_closure_calls +=
212 reify_local_partial_closure_calls_in_continuation(
213 continuation,
214 &wrappers,
215 &resumable,
216 &mut next_value,
217 );
218 }
219 }
220 output.profile.changed |= output.profile.reified_partial_closure_calls > 0;
221}
222
223fn reify_known_closure_parameter_calls(output: &mut CpsOptimizationOutput) {
224 output.profile.passes_run += 1;
225 for function in output
226 .module
227 .functions
228 .iter_mut()
229 .chain(&mut output.module.roots)
230 {
231 let wrappers = partial_closure_wrappers(function);
232 if wrappers.is_empty() {
233 continue;
234 }
235 output.profile.reified_known_closure_parameter_calls +=
236 reify_known_closure_parameter_calls_in_function(function, &wrappers);
237 }
238 output.profile.changed |= output.profile.reified_known_closure_parameter_calls > 0;
239}
240
241fn remove_unused_continuation_params(output: &mut CpsOptimizationOutput) {
242 output.profile.passes_run += 1;
243 for function in output
244 .module
245 .functions
246 .iter_mut()
247 .chain(&mut output.module.roots)
248 {
249 output.profile.removed_unused_continuation_params +=
250 remove_unused_continuation_params_in_function(function);
251 }
252 output.profile.changed |= output.profile.removed_unused_continuation_params > 0;
253}
254
255fn fold_structural_projections(output: &mut CpsOptimizationOutput) {
256 output.profile.passes_run += 1;
257 for function in output
258 .module
259 .functions
260 .iter_mut()
261 .chain(&mut output.module.roots)
262 {
263 for continuation in &mut function.continuations {
264 output.profile.folded_structural_projections +=
265 fold_structural_projections_in_continuation(continuation);
266 }
267 }
268 output.profile.changed |= output.profile.folded_structural_projections > 0;
269}
270
271fn inline_pure_direct_calls(output: &mut CpsOptimizationOutput) {
272 output.profile.passes_run += 1;
273 let candidates = pure_direct_inline_candidates(&output.module);
274 if candidates.is_empty() {
275 return;
276 }
277 for function in output
278 .module
279 .functions
280 .iter_mut()
281 .chain(&mut output.module.roots)
282 {
283 output.profile.inlined_pure_direct_calls +=
284 inline_pure_direct_calls_in_function(function, &candidates);
285 }
286 output.profile.changed |= output.profile.inlined_pure_direct_calls > 0;
287}
288
289fn remove_unused_continuation_params_in_function(function: &mut CpsReprAbiFunction) -> usize {
290 let unused_slots = unused_continuation_param_slots(function);
291 if unused_slots.is_empty() {
292 return 0;
293 }
294
295 let mut removed = 0;
296 for continuation in &mut function.continuations {
297 if let Some(slots) = unused_slots.get(&continuation.id) {
298 removed += remove_indexed_items(&mut continuation.params, slots);
299 }
300 if let CpsTerminator::Continue { target, args } = &mut continuation.terminator {
301 if let Some(slots) = unused_slots.get(target) {
302 remove_indexed_items(args, slots);
303 }
304 }
305 }
306 removed
307}
308
309fn unused_continuation_param_slots(
310 function: &CpsReprAbiFunction,
311) -> HashMap<CpsContinuationId, HashSet<usize>> {
312 let references = continuation_references(function);
313 let protected = protected_continuations(function);
314 let used_values = function_used_values(function);
315
316 function
317 .continuations
318 .iter()
319 .filter(|continuation| !protected.contains(&continuation.id))
320 .filter(|continuation| {
321 references
322 .get(&continuation.id)
323 .is_some_and(|reference| reference.total == reference.continue_calls)
324 })
325 .filter_map(|continuation| {
326 let slots = continuation
327 .params
328 .iter()
329 .enumerate()
330 .filter_map(|(index, param)| (!used_values.contains(¶m.value)).then_some(index))
331 .collect::<HashSet<_>>();
332 (!slots.is_empty()).then_some((continuation.id, slots))
333 })
334 .collect()
335}
336
337fn function_used_values(function: &CpsReprAbiFunction) -> HashSet<CpsValueId> {
338 let mut used = HashSet::new();
339 for continuation in &function.continuations {
340 used.extend(continuation.environment.iter().map(|slot| slot.value));
341 for stmt in &continuation.stmts {
342 used.extend(stmt_operands(stmt));
343 }
344 used.extend(terminator_values(&continuation.terminator));
345 }
346 used
347}
348
/// Removes from `items` every element whose position is listed in `slots`, keeping the
/// relative order of the survivors.
///
/// Indices in `slots` that are out of range are ignored. Returns the number of elements
/// removed.
fn remove_indexed_items<T>(items: &mut Vec<T>, slots: &HashSet<usize>) -> usize {
    let original = std::mem::take(items);
    let original_len = original.len();
    items.extend(
        original
            .into_iter()
            .enumerate()
            .filter(|(position, _)| !slots.contains(position))
            .map(|(_, item)| item),
    );
    original_len - items.len()
}
359
360fn inline_single_use_continuation_calls(output: &mut CpsOptimizationOutput) {
361 output.profile.passes_run += 1;
362 for function in output
363 .module
364 .functions
365 .iter_mut()
366 .chain(&mut output.module.roots)
367 {
368 let candidates = inline_candidates(function);
369 if candidates.is_empty() {
370 continue;
371 }
372 for index in 0..function.continuations.len() {
373 output.profile.inlined_continuation_calls +=
374 inline_continuation_call_at(function, index, &candidates);
375 }
376 }
377 output.profile.changed |= output.profile.inlined_continuation_calls > 0;
378}
379
380fn prune_unreachable_continuations(output: &mut CpsOptimizationOutput) {
381 output.profile.passes_run += 1;
382 for function in output
383 .module
384 .functions
385 .iter_mut()
386 .chain(&mut output.module.roots)
387 {
388 let reachable = reachable_continuations(function);
389 let before = function.continuations.len();
390 function
391 .continuations
392 .retain(|continuation| reachable.contains(&continuation.id));
393 output.profile.removed_unreachable_continuations += before - function.continuations.len();
394 }
395 output.profile.changed |= output.profile.removed_unreachable_continuations > 0;
396}
397
398fn eliminate_dead_pure_statements(output: &mut CpsOptimizationOutput) {
399 output.profile.passes_run += 1;
400 for function in output
401 .module
402 .functions
403 .iter_mut()
404 .chain(&mut output.module.roots)
405 {
406 let captured_values = function_captured_values(function);
407 for continuation in &mut function.continuations {
408 output.profile.removed_dead_pure_statements +=
409 eliminate_dead_pure_statements_in_continuation(continuation, &captured_values);
410 }
411 }
412 output.profile.changed |= output.profile.removed_dead_pure_statements > 0;
413}
414
415fn analyze_direct_style_islands(output: &mut CpsOptimizationOutput) {
416 output.profile.direct_style_islands = 0;
417 output.profile.direct_style_continuations = 0;
418 for function in output.module.functions.iter().chain(&output.module.roots) {
419 let islands = direct_style_islands(function);
420 output.profile.direct_style_islands += islands.len();
421 output.profile.direct_style_continuations += islands
422 .iter()
423 .map(|island| island.continuations.len())
424 .sum::<usize>();
425 }
426}
427
428fn maybe_trace_profile(profile: &CpsOptimizationProfile) {
429 if std::env::var_os("YULANG_CPS_OPT_TRACE").is_none() {
430 return;
431 }
432 eprintln!(
433 "[CPS-OPT] functions={} roots={} continuations={} optimized_continuations={} handlers={} statements={} optimized_statements={} passes={} forwarded_continuation_calls={} returned_continuation_calls={} folded_constant_branches={} rewritten_pure_effectful_calls={} reified_primitive_calls={} reified_partial_closure_calls={} reified_known_closure_parameter_calls={} removed_unused_continuation_params={} folded_structural_projections={} inlined_pure_direct_calls={} inlined_continuation_calls={} removed_unreachable_continuations={} removed_dead_pure_statements={} direct_style_islands={} direct_style_continuations={} changed={}",
434 profile.functions,
435 profile.roots,
436 profile.continuations,
437 profile.optimized_continuations,
438 profile.handlers,
439 profile.statements,
440 profile.optimized_statements,
441 profile.passes_run,
442 profile.forwarded_continuation_calls,
443 profile.returned_continuation_calls,
444 profile.folded_constant_branches,
445 profile.rewritten_pure_effectful_calls,
446 profile.reified_primitive_calls,
447 profile.reified_partial_closure_calls,
448 profile.reified_known_closure_parameter_calls,
449 profile.removed_unused_continuation_params,
450 profile.folded_structural_projections,
451 profile.inlined_pure_direct_calls,
452 profile.inlined_continuation_calls,
453 profile.removed_unreachable_continuations,
454 profile.removed_dead_pure_statements,
455 profile.direct_style_islands,
456 profile.direct_style_continuations,
457 profile.changed
458 );
459}
460
461fn primitive_wrappers(module: &CpsReprAbiModule) -> HashMap<String, PrimitiveWrapper> {
462 module
463 .functions
464 .iter()
465 .chain(&module.roots)
466 .filter_map(primitive_wrapper)
467 .collect()
468}
469
470fn primitive_wrapper(function: &CpsReprAbiFunction) -> Option<(String, PrimitiveWrapper)> {
471 if !function.handlers.is_empty() {
472 return None;
473 }
474 let continuation = function
475 .continuations
476 .iter()
477 .find(|continuation| continuation.id == function.entry)?;
478 if !continuation.environment.is_empty() || continuation.stmts.len() != 1 {
479 return None;
480 }
481 let [CpsStmt::Primitive { dest, op, args }] = continuation.stmts.as_slice() else {
482 return None;
483 };
484 if !matches!(continuation.terminator, CpsTerminator::Return(value) if value == *dest) {
485 return None;
486 }
487 let params = continuation
488 .params
489 .iter()
490 .map(|param| param.value)
491 .collect::<Vec<_>>();
492 if function
493 .params
494 .iter()
495 .map(|param| param.value)
496 .collect::<Vec<_>>()
497 != params
498 {
499 return None;
500 }
501 if *args != params {
502 return None;
503 }
504 Some((
505 function.name.clone(),
506 PrimitiveWrapper {
507 op: *op,
508 arity: params.len(),
509 },
510 ))
511}
512
513fn reify_direct_primitive_stmt(
514 stmt: &mut CpsStmt,
515 primitives: &HashMap<String, PrimitiveWrapper>,
516) -> usize {
517 let CpsStmt::DirectCall { dest, target, args } = stmt else {
518 return 0;
519 };
520 let Some(primitive) = primitives.get(target) else {
521 return 0;
522 };
523 if primitive.arity != args.len() {
524 return 0;
525 }
526 *stmt = CpsStmt::Primitive {
527 dest: *dest,
528 op: primitive.op,
529 args: args.clone(),
530 };
531 1
532}
533
/// Description of a function recognized by `primitive_wrapper`. Such a function's entry
/// continuation applies a single primitive to its parameters and returns the result.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct PrimitiveWrapper {
    /// The primitive the wrapper applies.
    op: typed_ir::PrimitiveOp,
    /// Number of parameters, which equals the primitive's argument count.
    arity: usize,
}
539
540fn pure_callable_functions(module: &CpsReprAbiModule) -> HashSet<String> {
541 module
542 .functions
543 .iter()
544 .filter(|function| function_is_pure_callable(function))
545 .map(|function| function.name.clone())
546 .collect()
547}
548
549fn function_is_pure_callable(function: &CpsReprAbiFunction) -> bool {
550 function.handlers.is_empty()
551 && function
552 .continuations
553 .iter()
554 .all(|continuation| continuation.environment.is_empty())
555 && function
556 .continuations
557 .iter()
558 .flat_map(|continuation| &continuation.stmts)
559 .all(stmt_is_direct_call_safe)
560 && function
561 .continuations
562 .iter()
563 .all(|continuation| terminator_is_direct_call_safe(&continuation.terminator))
564}
565
566fn stmt_is_direct_call_safe(stmt: &CpsStmt) -> bool {
567 matches!(
568 stmt,
569 CpsStmt::Literal { .. }
570 | CpsStmt::Tuple { .. }
571 | CpsStmt::Record { .. }
572 | CpsStmt::RecordWithoutFields { .. }
573 | CpsStmt::Variant { .. }
574 | CpsStmt::Select { .. }
575 | CpsStmt::SelectWithDefault { .. }
576 | CpsStmt::RecordHasField { .. }
577 | CpsStmt::TupleGet { .. }
578 | CpsStmt::VariantTagEq { .. }
579 | CpsStmt::VariantPayload { .. }
580 | CpsStmt::Primitive { .. }
581 | CpsStmt::DirectCall { .. }
582 )
583}
584
585fn terminator_is_direct_call_safe(terminator: &CpsTerminator) -> bool {
586 matches!(
587 terminator,
588 CpsTerminator::Return(_) | CpsTerminator::Continue { .. } | CpsTerminator::Branch { .. }
589 )
590}
591
592fn rewrite_pure_effectful_calls_in_function(
593 function: &mut CpsReprAbiFunction,
594 pure_functions: &HashSet<String>,
595) -> usize {
596 let resumable = scalar_resume_continuations(function);
597 let mut next_value = next_function_value_id(function);
598 let mut count = 0;
599
600 for continuation in &mut function.continuations {
601 let CpsTerminator::EffectfulCall {
602 target,
603 args,
604 resume,
605 } = &continuation.terminator
606 else {
607 continue;
608 };
609 if !pure_functions.contains(target) || !resumable.contains(resume) {
610 continue;
611 }
612 let dest = next_value;
613 next_value.0 += 1;
614 continuation.stmts.push(CpsStmt::DirectCall {
615 dest,
616 target: target.clone(),
617 args: args.clone(),
618 });
619 continuation.terminator = CpsTerminator::Continue {
620 target: *resume,
621 args: vec![dest],
622 };
623 count += 1;
624 }
625
626 count
627}
628
629fn fold_constant_branch_in_continuation(
630 continuation: &mut CpsReprAbiContinuation,
631 empty_param_continuations: &HashSet<CpsContinuationId>,
632) -> usize {
633 let (cond, then_cont, else_cont) = match &continuation.terminator {
634 CpsTerminator::Branch {
635 cond,
636 then_cont,
637 else_cont,
638 } => (*cond, *then_cont, *else_cont),
639 _ => return 0,
640 };
641 let Some(value) = local_bool_literal(continuation, cond) else {
642 return 0;
643 };
644 let target = if value { then_cont } else { else_cont };
645 if !empty_param_continuations.contains(&target) {
646 return 0;
647 }
648 continuation.terminator = CpsTerminator::Continue {
649 target,
650 args: Vec::new(),
651 };
652 1
653}
654
655fn local_bool_literal(continuation: &CpsReprAbiContinuation, value: CpsValueId) -> Option<bool> {
656 continuation.stmts.iter().find_map(|stmt| match stmt {
657 CpsStmt::Literal {
658 dest,
659 literal: crate::cps_ir::CpsLiteral::Bool(bool_value),
660 } if *dest == value => Some(*bool_value),
661 _ => None,
662 })
663}
664
665fn scalar_resume_continuations(function: &CpsReprAbiFunction) -> HashSet<CpsContinuationId> {
666 function
667 .continuations
668 .iter()
669 .filter(|continuation| {
670 continuation.environment.is_empty() && continuation.params.len() == 1
671 })
672 .map(|continuation| continuation.id)
673 .collect()
674}
675
676fn partial_closure_wrappers(
677 function: &CpsReprAbiFunction,
678) -> HashMap<CpsContinuationId, PartialClosureWrapper> {
679 function
680 .continuations
681 .iter()
682 .filter_map(partial_closure_wrapper)
683 .collect()
684}
685
686fn partial_closure_wrapper(
687 continuation: &CpsReprAbiContinuation,
688) -> Option<(CpsContinuationId, PartialClosureWrapper)> {
689 if continuation.params.len() != 1 || continuation.stmts.len() != 1 {
690 return None;
691 }
692 let [stmt] = continuation.stmts.as_slice() else {
693 return None;
694 };
695 let Some((dest, call, args)) = partial_closure_call_shape(stmt) else {
696 return None;
697 };
698 if !matches!(continuation.terminator, CpsTerminator::Return(value) if value == dest) {
699 return None;
700 }
701 let captured = continuation
702 .environment
703 .iter()
704 .map(|slot| slot.value)
705 .collect::<Vec<_>>();
706 let param = continuation.params[0].value;
707 if args.len() != captured.len() + 1 {
708 return None;
709 }
710 if args[..captured.len()] != captured {
711 return None;
712 }
713 if args[captured.len()] != param {
714 return None;
715 }
716 Some((continuation.id, PartialClosureWrapper { call, captured }))
717}
718
719fn partial_closure_call_shape(
720 stmt: &CpsStmt,
721) -> Option<(CpsValueId, PartialClosureCall, &[CpsValueId])> {
722 match stmt {
723 CpsStmt::DirectCall { dest, target, args } => Some((
724 *dest,
725 PartialClosureCall::Direct {
726 target: target.clone(),
727 },
728 args,
729 )),
730 CpsStmt::Primitive { dest, op, args } => {
731 Some((*dest, PartialClosureCall::Primitive { op: *op }, args))
732 }
733 _ => None,
734 }
735}
736
737fn reify_local_partial_closure_calls_in_continuation(
738 continuation: &mut CpsReprAbiContinuation,
739 wrappers: &HashMap<CpsContinuationId, PartialClosureWrapper>,
740 resumable: &HashSet<CpsContinuationId>,
741 next_value: &mut CpsValueId,
742) -> usize {
743 reify_partial_closure_calls_in_continuation(
744 continuation,
745 wrappers,
746 &HashMap::new(),
747 resumable,
748 next_value,
749 )
750}
751
752fn reify_known_closure_parameter_calls_in_function(
753 function: &mut CpsReprAbiFunction,
754 wrappers: &HashMap<CpsContinuationId, PartialClosureWrapper>,
755) -> usize {
756 let closure_values = local_closure_values(function, wrappers);
757 if closure_values.is_empty() {
758 return 0;
759 }
760 let parameter_wrappers = known_closure_parameter_wrappers(function, &closure_values);
761 if parameter_wrappers.is_empty() {
762 return 0;
763 }
764
765 let resumable = scalar_resume_continuations(function);
766 let mut next_value = next_function_value_id(function);
767 let mut count = 0;
768 for continuation in &mut function.continuations {
769 let Some(initial_closures) = parameter_wrappers.get(&continuation.id) else {
770 continue;
771 };
772 count += reify_partial_closure_calls_in_continuation(
773 continuation,
774 wrappers,
775 initial_closures,
776 &resumable,
777 &mut next_value,
778 );
779 }
780 count
781}
782
783fn reify_partial_closure_calls_in_continuation(
784 continuation: &mut CpsReprAbiContinuation,
785 wrappers: &HashMap<CpsContinuationId, PartialClosureWrapper>,
786 initial_closures: &HashMap<CpsValueId, PartialClosureWrapper>,
787 resumable: &HashSet<CpsContinuationId>,
788 next_value: &mut CpsValueId,
789) -> usize {
790 let mut closures = initial_closures.clone();
791 let mut count = 0;
792 for stmt in &mut continuation.stmts {
793 match stmt {
794 CpsStmt::MakeClosure { dest, entry } => {
795 if let Some(wrapper) = wrappers.get(entry) {
796 closures.insert(*dest, wrapper.clone());
797 }
798 }
799 CpsStmt::MakeRecursiveClosure { dest, .. } => {
800 closures.remove(dest);
801 }
802 CpsStmt::ApplyClosure { dest, closure, arg } => {
803 let Some(wrapper) = closures.get(closure) else {
804 continue;
805 };
806 let mut args = wrapper.captured.clone();
807 args.push(*arg);
808 *stmt = wrapper.call.to_stmt(*dest, args);
809 count += 1;
810 }
811 _ => {}
812 }
813 }
814 count += reify_partial_closure_terminator(
815 &mut continuation.stmts,
816 &mut continuation.terminator,
817 &closures,
818 resumable,
819 next_value,
820 );
821 count
822}
823
824fn reify_partial_closure_terminator(
825 stmts: &mut Vec<CpsStmt>,
826 terminator: &mut CpsTerminator,
827 closures: &HashMap<CpsValueId, PartialClosureWrapper>,
828 resumable: &HashSet<CpsContinuationId>,
829 next_value: &mut CpsValueId,
830) -> usize {
831 let (closure, arg, resume) = match terminator {
832 CpsTerminator::EffectfulApply {
833 closure,
834 arg,
835 resume,
836 } => (*closure, *arg, *resume),
837 _ => return 0,
838 };
839 let Some(wrapper) = closures.get(&closure) else {
840 return 0;
841 };
842 let mut args = wrapper.captured.clone();
843 args.push(arg);
844 match &wrapper.call {
845 PartialClosureCall::Direct { target } => {
846 *terminator = CpsTerminator::EffectfulCall {
847 target: target.clone(),
848 args,
849 resume,
850 };
851 1
852 }
853 PartialClosureCall::Primitive { op } if resumable.contains(&resume) => {
854 let dest = *next_value;
855 next_value.0 += 1;
856 stmts.push(CpsStmt::Primitive {
857 dest,
858 op: *op,
859 args,
860 });
861 *terminator = CpsTerminator::Continue {
862 target: resume,
863 args: vec![dest],
864 };
865 1
866 }
867 PartialClosureCall::Primitive { .. } => 0,
868 }
869}
870
871fn local_closure_values(
872 function: &CpsReprAbiFunction,
873 wrappers: &HashMap<CpsContinuationId, PartialClosureWrapper>,
874) -> HashMap<CpsValueId, PartialClosureWrapper> {
875 let mut closures = HashMap::new();
876 for continuation in &function.continuations {
877 for stmt in &continuation.stmts {
878 match stmt {
879 CpsStmt::MakeClosure { dest, entry } => {
880 let Some(wrapper) = wrappers.get(entry) else {
881 continue;
882 };
883 closures.insert(*dest, wrapper.clone());
884 }
885 CpsStmt::MakeRecursiveClosure { dest, .. } => {
886 closures.remove(dest);
887 }
888 _ => {}
889 }
890 }
891 }
892 closures
893}
894
/// For each continuation reached only through `Continue` terminators, determines which of
/// its parameters always receive the same known partial-closure value.
///
/// For every unprotected `Continue` target, each argument is looked up in
/// `closure_values`. A match is re-expressed in terms of the target's parameters via
/// `PartialClosureWrapper::rebase_for_continue`. Per parameter slot, all call sites are
/// merged with `KnownClosureParameterCandidate::merge`, so a slot stays known only if
/// every call site agrees on the same rebased wrapper.
///
/// Returns, per target continuation, a map from parameter value to its known wrapper.
/// Targets with no known parameter are omitted.
fn known_closure_parameter_wrappers(
    function: &CpsReprAbiFunction,
    closure_values: &HashMap<CpsValueId, PartialClosureWrapper>,
) -> HashMap<CpsContinuationId, HashMap<CpsValueId, PartialClosureWrapper>> {
    let continuations = function
        .continuations
        .iter()
        .map(|continuation| (continuation.id, continuation))
        .collect::<HashMap<_, _>>();
    let references = continuation_references(function);
    let protected = protected_continuations(function);
    let mut candidates: HashMap<CpsContinuationId, Vec<KnownClosureParameterCandidate>> =
        HashMap::new();
    // Targets referenced other than via `Continue`, or called with a mismatched argument
    // count. Their candidates are discarded after the scan.
    let mut blocked = HashSet::new();

    for continuation in &function.continuations {
        let CpsTerminator::Continue { target, args } = &continuation.terminator else {
            continue;
        };
        if protected.contains(target) {
            continue;
        }
        let Some(target_continuation) = continuations.get(target) else {
            continue;
        };
        let Some(reference) = references.get(target) else {
            continue;
        };
        // Only specialize targets whose every reference is a `Continue` with matching
        // arity; any other reference could supply a different value for the parameter.
        if reference.total != reference.continue_calls
            || args.len() != target_continuation.params.len()
        {
            blocked.insert(*target);
            continue;
        }

        let slots = candidates.entry(*target).or_insert_with(|| {
            vec![KnownClosureParameterCandidate::Unseen; target_continuation.params.len()]
        });
        for (index, arg) in args.iter().enumerate() {
            // Arguments that are not known closures, and closures whose captures are not
            // forwarded to the target, yield `None`; `merge(None)` marks the slot Conflict.
            let adapted = closure_values
                .get(arg)
                .and_then(|wrapper| wrapper.rebase_for_continue(args, &target_continuation.params));
            slots[index].merge(adapted);
        }
    }

    blocked.into_iter().for_each(|target| {
        candidates.remove(&target);
    });

    candidates
        .into_iter()
        .filter_map(|(target, slots)| {
            let continuation = continuations.get(&target)?;
            // Keep only slots where every call site agreed on one wrapper.
            let known = continuation
                .params
                .iter()
                .zip(slots)
                .filter_map(|(param, slot)| match slot {
                    KnownClosureParameterCandidate::Known(wrapper) => Some((param.value, wrapper)),
                    KnownClosureParameterCandidate::Unseen
                    | KnownClosureParameterCandidate::Conflict => None,
                })
                .collect::<HashMap<_, _>>();
            (!known.is_empty()).then_some((target, known))
        })
        .collect()
}
963
/// Merge state for one continuation parameter while collecting the closure wrappers that
/// call sites pass into it.
#[derive(Debug, Clone, PartialEq, Eq)]
enum KnownClosureParameterCandidate {
    /// No call site has been merged into this slot yet.
    Unseen,
    /// Every call site merged so far passed this same wrapper.
    Known(PartialClosureWrapper),
    /// Call sites disagreed, or at least one passed something that is not a known wrapper.
    Conflict,
}
970
971impl KnownClosureParameterCandidate {
972 fn merge(&mut self, wrapper: Option<PartialClosureWrapper>) {
973 let Some(wrapper) = wrapper else {
974 *self = Self::Conflict;
975 return;
976 };
977 match self {
978 Self::Unseen => *self = Self::Known(wrapper),
979 Self::Known(current) if *current == wrapper => {}
980 Self::Known(_) | Self::Conflict => *self = Self::Conflict,
981 }
982 }
983}
984
/// Description of a continuation recognized by `partial_closure_wrapper`. Its single
/// statement calls `call` with `captured` followed by the closure argument, and it returns
/// that result.
#[derive(Debug, Clone, PartialEq, Eq)]
struct PartialClosureWrapper {
    /// The call performed by the wrapper body.
    call: PartialClosureCall,
    /// Values passed before the closure argument, in order. These are the wrapper's
    /// environment values, or the target's parameters after `rebase_for_continue`.
    captured: Vec<CpsValueId>,
}
990
991impl PartialClosureWrapper {
992 fn rebase_for_continue(
993 &self,
994 supplied_args: &[CpsValueId],
995 target_params: &[CpsReprAbiValue],
996 ) -> Option<Self> {
997 if supplied_args.len() != target_params.len() {
998 return None;
999 }
1000 let captured = self
1001 .captured
1002 .iter()
1003 .map(|captured| {
1004 supplied_args
1005 .iter()
1006 .position(|arg| arg == captured)
1007 .map(|index| target_params[index].value)
1008 })
1009 .collect::<Option<Vec<_>>>()?;
1010 Some(Self {
1011 call: self.call.clone(),
1012 captured,
1013 })
1014 }
1015}
1016
/// The call a partial-closure wrapper performs.
#[derive(Debug, Clone, PartialEq, Eq)]
enum PartialClosureCall {
    /// A `DirectCall` to the function named `target`.
    Direct { target: String },
    /// A `Primitive` statement applying `op`.
    Primitive { op: typed_ir::PrimitiveOp },
}
1022
1023impl PartialClosureCall {
1024 fn to_stmt(&self, dest: CpsValueId, args: Vec<CpsValueId>) -> CpsStmt {
1025 match self {
1026 PartialClosureCall::Direct { target } => CpsStmt::DirectCall {
1027 dest,
1028 target: target.clone(),
1029 args,
1030 },
1031 PartialClosureCall::Primitive { op } => CpsStmt::Primitive {
1032 dest,
1033 op: *op,
1034 args,
1035 },
1036 }
1037 }
1038}
1039
1040fn pure_direct_inline_candidates(module: &CpsReprAbiModule) -> HashMap<String, PureDirectInline> {
1041 module
1042 .functions
1043 .iter()
1044 .filter_map(pure_direct_inline_candidate)
1045 .collect()
1046}
1047
1048fn pure_direct_inline_candidate(
1049 function: &CpsReprAbiFunction,
1050) -> Option<(String, PureDirectInline)> {
1051 if !function.handlers.is_empty() || function.continuations.len() != 1 {
1052 return None;
1053 }
1054 let continuation = function
1055 .continuations
1056 .iter()
1057 .find(|continuation| continuation.id == function.entry)?;
1058 if !continuation.environment.is_empty() || continuation.stmts.len() > 16 {
1059 return None;
1060 }
1061 if continuation.params.len() != function.params.len() {
1062 return None;
1063 }
1064 if continuation
1065 .params
1066 .iter()
1067 .map(|param| param.value)
1068 .ne(function.params.iter().map(|param| param.value))
1069 {
1070 return None;
1071 }
1072 if !continuation.stmts.iter().all(pure_direct_inline_stmt) {
1073 return None;
1074 }
1075 let CpsTerminator::Return(result) = continuation.terminator else {
1076 return None;
1077 };
1078 if !continuation
1079 .stmts
1080 .iter()
1081 .any(|stmt| stmt_dest(stmt) == Some(result))
1082 {
1083 return None;
1084 }
1085 Some((
1086 function.name.clone(),
1087 PureDirectInline {
1088 params: continuation
1089 .params
1090 .iter()
1091 .map(|param| param.value)
1092 .collect(),
1093 stmts: continuation.stmts.clone(),
1094 result,
1095 },
1096 ))
1097}
1098
/// Statement kinds allowed in the body of a function selected for pure direct-call inlining
/// (checked by `pure_direct_inline_candidate`).
///
/// Accepts literals, tuple/record/variant construction and projection, and `Primitive` with
/// any operator. Every other kind is rejected, for example `DirectCall`, `ApplyClosure`,
/// `MakeClosure`, `MakeThunk`, `ForceThunk`, and the guard and handler statements.
fn pure_direct_inline_stmt(stmt: &CpsStmt) -> bool {
    matches!(
        stmt,
        CpsStmt::Literal { .. }
            | CpsStmt::Tuple { .. }
            | CpsStmt::Record { .. }
            | CpsStmt::RecordWithoutFields { .. }
            | CpsStmt::Variant { .. }
            | CpsStmt::Select { .. }
            | CpsStmt::SelectWithDefault { .. }
            | CpsStmt::RecordHasField { .. }
            | CpsStmt::TupleGet { .. }
            | CpsStmt::VariantTagEq { .. }
            | CpsStmt::VariantPayload { .. }
            | CpsStmt::Primitive { .. }
    )
}
1116
/// Inlines calls to pure single-block functions throughout `function`.
///
/// A `DirectCall` is replaced when its `target` is a key of `candidates` and its argument
/// count matches the candidate's parameter count. The replacement is a renamed copy of the
/// candidate's statements:
/// - candidate parameters become the call arguments;
/// - every value the candidate defines gets a fresh id, counting up from
///   `next_function_value_id`;
/// - the candidate's `result` is written straight into the call's `dest`.
///
/// All other statements are kept in order. Returns the number of call sites inlined.
fn inline_pure_direct_calls_in_function(
    function: &mut CpsReprAbiFunction,
    candidates: &HashMap<String, PureDirectInline>,
) -> usize {
    // Shared across all continuations so fresh ids never collide within the function.
    let mut next_value = next_function_value_id(function);
    let mut count = 0;
    for continuation in &mut function.continuations {
        let mut stmts = Vec::with_capacity(continuation.stmts.len());
        for stmt in continuation.stmts.drain(..) {
            let CpsStmt::DirectCall { dest, target, args } = &stmt else {
                stmts.push(stmt);
                continue;
            };
            let Some(candidate) = candidates.get(target) else {
                stmts.push(stmt);
                continue;
            };
            if candidate.params.len() != args.len() {
                stmts.push(stmt);
                continue;
            }
            // Callee parameter -> caller argument.
            let mut substitution = candidate
                .params
                .iter()
                .copied()
                .zip(args.iter().copied())
                .collect::<HashMap<_, _>>();
            // Callee-defined value -> fresh caller value. This also allocates an id for
            // `result`; that entry is overridden just below so the result lands in `dest`.
            for stmt in &candidate.stmts {
                if let Some(value) = stmt_dest(stmt) {
                    substitution.entry(value).or_insert_with(|| {
                        let fresh = next_value;
                        next_value.0 += 1;
                        fresh
                    });
                }
            }
            substitution.insert(candidate.result, *dest);
            stmts.extend(
                candidate
                    .stmts
                    .iter()
                    .cloned()
                    .map(|stmt| substitute_pure_inline_stmt_values(stmt, &substitution)),
            );
            count += 1;
        }
        continuation.stmts = stmts;
    }
    count
}
1167
/// Renames both the destination and the operands of a pure statement through `substitution`.
///
/// Inlining uses this to give callee-defined values fresh ids, so the destination is
/// rewritten as well as the operands. Only the kinds accepted by `pure_direct_inline_stmt`
/// are handled; any other statement is returned unchanged. Ids without an entry in
/// `substitution` presumably pass through unchanged (that depends on `subst_value`, which is
/// not shown here).
fn substitute_pure_inline_stmt_values(
    stmt: CpsStmt,
    substitution: &HashMap<CpsValueId, CpsValueId>,
) -> CpsStmt {
    match stmt {
        CpsStmt::Literal { dest, literal } => CpsStmt::Literal {
            dest: subst_value(dest, substitution),
            literal,
        },
        CpsStmt::Tuple { dest, items } => CpsStmt::Tuple {
            dest: subst_value(dest, substitution),
            items: subst_values(items, substitution),
        },
        CpsStmt::Record { dest, base, fields } => CpsStmt::Record {
            dest: subst_value(dest, substitution),
            base: base.map(|value| subst_value(value, substitution)),
            fields: fields
                .into_iter()
                .map(|field| CpsRecordField {
                    name: field.name,
                    value: subst_value(field.value, substitution),
                })
                .collect(),
        },
        CpsStmt::RecordWithoutFields { dest, base, fields } => CpsStmt::RecordWithoutFields {
            dest: subst_value(dest, substitution),
            base: subst_value(base, substitution),
            fields,
        },
        CpsStmt::Variant { dest, tag, value } => CpsStmt::Variant {
            dest: subst_value(dest, substitution),
            tag,
            value: value.map(|value| subst_value(value, substitution)),
        },
        CpsStmt::Select { dest, base, field } => CpsStmt::Select {
            dest: subst_value(dest, substitution),
            base: subst_value(base, substitution),
            field,
        },
        CpsStmt::SelectWithDefault {
            dest,
            base,
            field,
            default,
        } => CpsStmt::SelectWithDefault {
            dest: subst_value(dest, substitution),
            base: subst_value(base, substitution),
            field,
            default: subst_value(default, substitution),
        },
        CpsStmt::RecordHasField { dest, base, field } => CpsStmt::RecordHasField {
            dest: subst_value(dest, substitution),
            base: subst_value(base, substitution),
            field,
        },
        CpsStmt::TupleGet { dest, tuple, index } => CpsStmt::TupleGet {
            dest: subst_value(dest, substitution),
            tuple: subst_value(tuple, substitution),
            index,
        },
        CpsStmt::VariantTagEq { dest, variant, tag } => CpsStmt::VariantTagEq {
            dest: subst_value(dest, substitution),
            variant: subst_value(variant, substitution),
            tag,
        },
        CpsStmt::VariantPayload { dest, variant } => CpsStmt::VariantPayload {
            dest: subst_value(dest, substitution),
            variant: subst_value(variant, substitution),
        },
        CpsStmt::Primitive { dest, op, args } => CpsStmt::Primitive {
            dest: subst_value(dest, substitution),
            op,
            args: subst_values(args, substitution),
        },
        // Non-pure statements never reach here from inlining; pass them through untouched.
        stmt => stmt,
    }
}
1245
1246fn next_function_value_id(function: &CpsReprAbiFunction) -> CpsValueId {
1247 let max_value = function
1248 .params
1249 .iter()
1250 .map(|value| value.value)
1251 .chain(
1252 function
1253 .continuations
1254 .iter()
1255 .flat_map(continuation_value_ids),
1256 )
1257 .map(|value| value.0)
1258 .max()
1259 .unwrap_or(0);
1260 CpsValueId(max_value + 1)
1261}
1262
1263fn continuation_value_ids(
1264 continuation: &CpsReprAbiContinuation,
1265) -> impl Iterator<Item = CpsValueId> + '_ {
1266 continuation
1267 .params
1268 .iter()
1269 .map(|value| value.value)
1270 .chain(continuation.environment.iter().map(|slot| slot.value))
1271 .chain(continuation.stmts.iter().filter_map(stmt_dest))
1272}
1273
/// Folds `TupleGet` projections out of tuples built earlier in the same continuation.
///
/// A `TupleGet` is folded when both hold:
/// - it reads an element of a tuple built by a preceding `Tuple` statement in this
///   continuation;
/// - that element (after alias resolution) is known to be scalar, meaning it was produced
///   here by a statement accepted by `stmt_produces_scalar_value`, or is itself a folded
///   projection.
///
/// A folded `TupleGet` is dropped and its destination becomes an alias for the element.
/// Aliases are applied to every later statement and to the terminator. Returns the number of
/// projections removed.
///
/// NOTE(review): aliases are only applied inside this continuation. If a removed destination
/// is referenced from elsewhere (e.g. another continuation's environment slot), that
/// reference is not rewritten here. Confirm that callers rule this out.
fn fold_structural_projections_in_continuation(continuation: &mut CpsReprAbiContinuation) -> usize {
    // Folded projection destination -> the value it now stands for.
    let mut aliases = HashMap::<CpsValueId, CpsValueId>::new();
    // Tuple value -> its (alias-substituted) items, for tuples built in this continuation.
    let mut tuples = HashMap::<CpsValueId, Vec<CpsValueId>>::new();
    // Values known to be produced by scalar statements (or folded from them).
    let mut scalar_values = HashSet::<CpsValueId>::new();
    let mut stmts = Vec::with_capacity(continuation.stmts.len());
    let mut count = 0;

    for stmt in continuation.stmts.drain(..) {
        let stmt = substitute_stmt_values(stmt, &aliases);
        match stmt {
            CpsStmt::Tuple { dest, items } => {
                tuples.insert(dest, items.clone());
                stmts.push(CpsStmt::Tuple { dest, items });
            }
            CpsStmt::TupleGet { dest, tuple, index } => {
                if let Some(items) = tuples.get(&tuple) {
                    if let Some(value) = items.get(index).copied() {
                        let value = resolve_alias(value, &aliases);
                        if scalar_values.contains(&value) {
                            aliases.insert(dest, value);
                            scalar_values.insert(dest);
                            count += 1;
                            continue;
                        }
                    }
                }
                // Defensive: drop any tuple binding recorded under this destination id.
                tuples.remove(&dest);
                stmts.push(CpsStmt::TupleGet { dest, tuple, index });
            }
            stmt => {
                if let Some(dest) = stmt_dest(&stmt) {
                    tuples.remove(&dest);
                    if stmt_produces_scalar_value(&stmt) {
                        scalar_values.insert(dest);
                    }
                }
                stmts.push(stmt);
            }
        }
    }

    continuation.terminator =
        substitute_terminator_values(continuation.terminator.clone(), &aliases);
    continuation.stmts = stmts;
    count
}
1320
/// Statement kinds whose result `fold_structural_projections_in_continuation` treats as
/// scalar, so a tuple projection of that result may be replaced by the value itself.
///
/// Accepts `Literal`, `RecordHasField`, `VariantTagEq`, and `Primitive` only for the listed
/// bool/int/float arithmetic and comparison operators. String-producing primitives such as
/// `IntToString` are not listed.
fn stmt_produces_scalar_value(stmt: &CpsStmt) -> bool {
    matches!(
        stmt,
        CpsStmt::Literal { .. }
            | CpsStmt::RecordHasField { .. }
            | CpsStmt::VariantTagEq { .. }
            | CpsStmt::Primitive {
                op: typed_ir::PrimitiveOp::BoolNot
                    | typed_ir::PrimitiveOp::BoolEq
                    | typed_ir::PrimitiveOp::IntAdd
                    | typed_ir::PrimitiveOp::IntSub
                    | typed_ir::PrimitiveOp::IntMul
                    | typed_ir::PrimitiveOp::IntEq
                    | typed_ir::PrimitiveOp::IntLt
                    | typed_ir::PrimitiveOp::IntLe
                    | typed_ir::PrimitiveOp::IntGt
                    | typed_ir::PrimitiveOp::IntGe
                    | typed_ir::PrimitiveOp::FloatAdd
                    | typed_ir::PrimitiveOp::FloatSub
                    | typed_ir::PrimitiveOp::FloatMul
                    | typed_ir::PrimitiveOp::FloatEq
                    | typed_ir::PrimitiveOp::FloatLt
                    | typed_ir::PrimitiveOp::FloatLe
                    | typed_ir::PrimitiveOp::FloatGt
                    | typed_ir::PrimitiveOp::FloatGe,
                ..
            }
    )
}
1350
1351fn resolve_alias(mut value: CpsValueId, aliases: &HashMap<CpsValueId, CpsValueId>) -> CpsValueId {
1352 let mut seen = HashSet::new();
1353 while let Some(next) = aliases.get(&value).copied() {
1354 if !seen.insert(value) {
1355 break;
1356 }
1357 value = next;
1358 }
1359 value
1360}
1361
/// Body of a function that can be copied into its direct call sites; built by
/// `pure_direct_inline_candidate` and consumed by `inline_pure_direct_calls_in_function`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct PureDirectInline {
    /// Parameter ids of the callee's entry continuation, in order; replaced by call arguments.
    params: Vec<CpsValueId>,
    /// The callee's statements, all accepted by `pure_direct_inline_stmt`.
    stmts: Vec<CpsStmt>,
    /// The value the callee returns; always the destination of one of `stmts`.
    result: CpsValueId,
}
1368
/// A group of direct-style candidate continuations linked by `Continue`/`Branch` edges, as
/// computed by `direct_style_islands`.
#[derive(Debug, Clone, PartialEq, Eq)]
struct DirectStyleIsland {
    /// Member continuation ids, sorted ascending.
    continuations: Vec<CpsContinuationId>,
}
1373
1374fn direct_style_islands(function: &CpsReprAbiFunction) -> Vec<DirectStyleIsland> {
1375 let candidates = function
1376 .continuations
1377 .iter()
1378 .filter(|continuation| direct_style_candidate(continuation))
1379 .map(|continuation| continuation.id)
1380 .collect::<HashSet<_>>();
1381 if candidates.is_empty() {
1382 return Vec::new();
1383 }
1384
1385 let continuations = function
1386 .continuations
1387 .iter()
1388 .map(|continuation| (continuation.id, continuation))
1389 .collect::<HashMap<_, _>>();
1390 let mut visited = HashSet::new();
1391 let mut islands = Vec::new();
1392
1393 for start in candidates.iter().copied() {
1394 if visited.contains(&start) {
1395 continue;
1396 }
1397 let mut queue = VecDeque::from([start]);
1398 let mut island = Vec::new();
1399 visited.insert(start);
1400
1401 while let Some(id) = queue.pop_front() {
1402 island.push(id);
1403 let Some(continuation) = continuations.get(&id) else {
1404 continue;
1405 };
1406 for successor in direct_style_successors(&continuation.terminator) {
1407 if candidates.contains(&successor) && visited.insert(successor) {
1408 queue.push_back(successor);
1409 }
1410 }
1411 }
1412
1413 island.sort();
1414 islands.push(DirectStyleIsland {
1415 continuations: island,
1416 });
1417 }
1418
1419 islands.sort_by_key(|island| island.continuations.first().copied());
1420 islands
1421}
1422
1423fn direct_style_candidate(continuation: &CpsReprAbiContinuation) -> bool {
1424 if !continuation.environment.is_empty() {
1425 return false;
1426 }
1427 continuation.stmts.iter().all(direct_style_stmt)
1428 && matches!(
1429 continuation.terminator,
1430 CpsTerminator::Return(_)
1431 | CpsTerminator::Continue { .. }
1432 | CpsTerminator::Branch { .. }
1433 )
1434}
1435
/// Statement kinds allowed inside a direct-style island.
///
/// This is the set accepted by `pure_direct_inline_stmt` plus `DirectCall`. Closure/thunk
/// creation, closure application, thunk forcing, guards, resumptions and handler statements
/// are rejected.
fn direct_style_stmt(stmt: &CpsStmt) -> bool {
    matches!(
        stmt,
        CpsStmt::Literal { .. }
            | CpsStmt::Tuple { .. }
            | CpsStmt::Record { .. }
            | CpsStmt::RecordWithoutFields { .. }
            | CpsStmt::Variant { .. }
            | CpsStmt::Select { .. }
            | CpsStmt::SelectWithDefault { .. }
            | CpsStmt::RecordHasField { .. }
            | CpsStmt::TupleGet { .. }
            | CpsStmt::VariantTagEq { .. }
            | CpsStmt::VariantPayload { .. }
            | CpsStmt::Primitive { .. }
            | CpsStmt::DirectCall { .. }
    )
}
1454
1455fn direct_style_successors(terminator: &CpsTerminator) -> Vec<CpsContinuationId> {
1456 match terminator {
1457 CpsTerminator::Continue { target, .. } => vec![*target],
1458 CpsTerminator::Branch {
1459 then_cont,
1460 else_cont,
1461 ..
1462 } => vec![*then_cont, *else_cont],
1463 CpsTerminator::Return(_)
1464 | CpsTerminator::Perform { .. }
1465 | CpsTerminator::EffectfulCall { .. }
1466 | CpsTerminator::EffectfulApply { .. }
1467 | CpsTerminator::EffectfulForce { .. } => Vec::new(),
1468 }
1469}
1470
1471fn eliminate_dead_pure_statements_in_continuation(
1472 continuation: &mut CpsReprAbiContinuation,
1473 captured_values: &HashSet<CpsValueId>,
1474) -> usize {
1475 let mut live = terminator_values(&continuation.terminator)
1476 .into_iter()
1477 .collect::<HashSet<_>>();
1478 live.extend(captured_values.iter().copied());
1479 let mut kept = Vec::with_capacity(continuation.stmts.len());
1480 let mut removed = 0;
1481
1482 for stmt in continuation.stmts.iter().rev() {
1483 let dest = stmt_dest(stmt);
1484 if dest.is_some_and(|dest| !live.contains(&dest)) && stmt_is_pure(stmt) {
1485 removed += 1;
1486 continue;
1487 }
1488
1489 if let Some(dest) = dest {
1490 live.remove(&dest);
1491 }
1492 live.extend(stmt_operands(stmt));
1493 kept.push(stmt.clone());
1494 }
1495
1496 kept.reverse();
1497 continuation.stmts = kept;
1498 removed
1499}
1500
1501fn function_captured_values(function: &CpsReprAbiFunction) -> HashSet<CpsValueId> {
1502 function
1503 .continuations
1504 .iter()
1505 .flat_map(|continuation| continuation.environment.iter().map(|slot| slot.value))
1506 .collect()
1507}
1508
1509fn inline_candidates(
1510 function: &CpsReprAbiFunction,
1511) -> HashMap<CpsContinuationId, CpsReprAbiContinuation> {
1512 let references = continuation_references(function);
1513 let protected = protected_continuations(function);
1514 function
1515 .continuations
1516 .iter()
1517 .filter(|continuation| {
1518 if continuation.shot_kind != CpsShotKind::OneShot {
1519 return false;
1520 }
1521 if !continuation.environment.is_empty() {
1522 return false;
1523 }
1524 if continuation.stmts.len() > 12 {
1525 return false;
1526 }
1527 references
1528 .get(&continuation.id)
1529 .is_some_and(|reference| reference.total == 1 && reference.continue_calls == 1)
1530 })
1531 .filter(|continuation| !protected.contains(&continuation.id))
1532 .map(|continuation| (continuation.id, continuation.clone()))
1533 .collect()
1534}
1535
/// Inlines the `Continue` target of `function.continuations[index]` into that continuation.
///
/// Inlining only happens when all of these hold:
/// - the target is in `candidates`;
/// - the target is not the continuation itself;
/// - the target's parameter count matches the jump's argument count.
///
/// When it happens, the target's statements are appended with its parameters renamed to the
/// jump arguments, and the target's terminator (renamed the same way) replaces the jump. The
/// inlined statements keep their original destination ids. Returns 1 if inlined, 0 otherwise.
///
/// # Panics
/// Panics if `index` is out of bounds for `function.continuations`.
///
/// NOTE(review): only the target's statements and terminator are renamed. If some
/// continuation environment elsewhere names one of the target's parameter ids, it is not
/// updated; presumably `inline_candidates` rules this out. Confirm.
fn inline_continuation_call_at(
    function: &mut CpsReprAbiFunction,
    index: usize,
    candidates: &HashMap<CpsContinuationId, CpsReprAbiContinuation>,
) -> usize {
    let continuation = &mut function.continuations[index];
    let CpsTerminator::Continue { target, args } = &continuation.terminator else {
        return 0;
    };
    let Some(target_continuation) = candidates.get(target) else {
        return 0;
    };
    // Never splice a continuation into itself.
    if target_continuation.id == continuation.id {
        return 0;
    }
    if target_continuation.params.len() != args.len() {
        return 0;
    }

    // Target parameter -> jump argument.
    let substitution = target_continuation
        .params
        .iter()
        .zip(args.iter().copied())
        .map(|(param, arg)| (param.value, arg))
        .collect::<HashMap<_, _>>();
    continuation.stmts.extend(
        target_continuation
            .stmts
            .iter()
            .cloned()
            .map(|stmt| substitute_stmt_values(stmt, &substitution)),
    );
    continuation.terminator =
        substitute_terminator_values(target_continuation.terminator.clone(), &substitution);
    1
}
1572
/// How often a continuation is referenced within a function (see `continuation_references`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
struct ContinuationReferenceCount {
    /// All references: statements, `Continue`, `Branch` arms and resume continuations.
    total: usize,
    /// The subset of `total` that are `Continue` jumps.
    continue_calls: usize,
}
1578
1579fn continuation_references(
1580 function: &CpsReprAbiFunction,
1581) -> HashMap<CpsContinuationId, ContinuationReferenceCount> {
1582 let mut references = HashMap::new();
1583 for continuation in &function.continuations {
1584 for stmt in &continuation.stmts {
1585 collect_stmt_reference_counts(stmt, &mut references);
1586 }
1587 collect_terminator_reference_counts(&continuation.terminator, &mut references);
1588 }
1589 references
1590}
1591
1592fn protected_continuations(function: &CpsReprAbiFunction) -> HashSet<CpsContinuationId> {
1593 let mut protected = HashSet::new();
1594 protected.insert(function.entry);
1595 for handler in &function.handlers {
1596 for arm in &handler.arms {
1597 protected.insert(arm.entry);
1598 }
1599 }
1600 for continuation in &function.continuations {
1601 for stmt in &continuation.stmts {
1602 collect_protected_stmt_continuations(stmt, &mut protected);
1603 }
1604 }
1605 protected
1606}
1607
/// Adds the continuation references made by `stmt` to `references`.
///
/// Counted references are closure/thunk entries, a handler's `value` and `escape`
/// continuations, and handler environment entries. None of these counts as a `Continue`
/// call. Every other statement kind references no continuation.
fn collect_stmt_reference_counts(
    stmt: &CpsStmt,
    references: &mut HashMap<CpsContinuationId, ContinuationReferenceCount>,
) {
    match stmt {
        CpsStmt::MakeThunk { entry, .. }
        | CpsStmt::MakeClosure { entry, .. }
        | CpsStmt::MakeRecursiveClosure { entry, .. } => {
            count_reference(*entry, references, false);
        }
        CpsStmt::InstallHandler {
            value,
            escape,
            envs,
            ..
        } => {
            count_reference(*value, references, false);
            count_reference(*escape, references, false);
            for env in envs {
                count_reference(env.entry, references, false);
            }
        }
        CpsStmt::ResumeWithHandler { envs, .. } => {
            for env in envs {
                count_reference(env.entry, references, false);
            }
        }
        // Listed exhaustively (no `_`) so a new statement kind forces a decision here.
        CpsStmt::Literal { .. }
        | CpsStmt::FreshGuard { .. }
        | CpsStmt::PeekGuard { .. }
        | CpsStmt::FindGuard { .. }
        | CpsStmt::AddThunkBoundary { .. }
        | CpsStmt::ForceThunk { .. }
        | CpsStmt::Tuple { .. }
        | CpsStmt::Record { .. }
        | CpsStmt::RecordWithoutFields { .. }
        | CpsStmt::Variant { .. }
        | CpsStmt::Select { .. }
        | CpsStmt::SelectWithDefault { .. }
        | CpsStmt::RecordHasField { .. }
        | CpsStmt::TupleGet { .. }
        | CpsStmt::VariantTagEq { .. }
        | CpsStmt::VariantPayload { .. }
        | CpsStmt::Primitive { .. }
        | CpsStmt::DirectCall { .. }
        | CpsStmt::ApplyClosure { .. }
        | CpsStmt::CloneContinuation { .. }
        | CpsStmt::Resume { .. }
        | CpsStmt::UninstallHandler { .. } => {}
    }
}
1659
1660fn collect_terminator_reference_counts(
1661 terminator: &CpsTerminator,
1662 references: &mut HashMap<CpsContinuationId, ContinuationReferenceCount>,
1663) {
1664 match terminator {
1665 CpsTerminator::Continue { target, .. } => count_reference(*target, references, true),
1666 CpsTerminator::Branch {
1667 then_cont,
1668 else_cont,
1669 ..
1670 } => {
1671 count_reference(*then_cont, references, false);
1672 count_reference(*else_cont, references, false);
1673 }
1674 CpsTerminator::Perform { resume, .. }
1675 | CpsTerminator::EffectfulCall { resume, .. }
1676 | CpsTerminator::EffectfulApply { resume, .. }
1677 | CpsTerminator::EffectfulForce { resume, .. } => {
1678 count_reference(*resume, references, false)
1679 }
1680 CpsTerminator::Return(_) => {}
1681 }
1682}
1683
/// Inserts into `protected` every continuation that `stmt` names.
///
/// These are closure/thunk entries, a handler's `value` and `escape` continuations, and
/// handler environment entries. Every other statement kind names no continuation.
fn collect_protected_stmt_continuations(
    stmt: &CpsStmt,
    protected: &mut HashSet<CpsContinuationId>,
) {
    match stmt {
        CpsStmt::MakeThunk { entry, .. }
        | CpsStmt::MakeClosure { entry, .. }
        | CpsStmt::MakeRecursiveClosure { entry, .. } => {
            protected.insert(*entry);
        }
        CpsStmt::InstallHandler {
            value,
            escape,
            envs,
            ..
        } => {
            protected.insert(*value);
            protected.insert(*escape);
            for env in envs {
                protected.insert(env.entry);
            }
        }
        CpsStmt::ResumeWithHandler { envs, .. } => {
            for env in envs {
                protected.insert(env.entry);
            }
        }
        // Listed exhaustively (no `_`) so a new statement kind forces a decision here.
        CpsStmt::Literal { .. }
        | CpsStmt::FreshGuard { .. }
        | CpsStmt::PeekGuard { .. }
        | CpsStmt::FindGuard { .. }
        | CpsStmt::AddThunkBoundary { .. }
        | CpsStmt::ForceThunk { .. }
        | CpsStmt::Tuple { .. }
        | CpsStmt::Record { .. }
        | CpsStmt::RecordWithoutFields { .. }
        | CpsStmt::Variant { .. }
        | CpsStmt::Select { .. }
        | CpsStmt::SelectWithDefault { .. }
        | CpsStmt::RecordHasField { .. }
        | CpsStmt::TupleGet { .. }
        | CpsStmt::VariantTagEq { .. }
        | CpsStmt::VariantPayload { .. }
        | CpsStmt::Primitive { .. }
        | CpsStmt::DirectCall { .. }
        | CpsStmt::ApplyClosure { .. }
        | CpsStmt::CloneContinuation { .. }
        | CpsStmt::Resume { .. }
        | CpsStmt::UninstallHandler { .. } => {}
    }
}
1735
1736fn count_reference(
1737 id: CpsContinuationId,
1738 references: &mut HashMap<CpsContinuationId, ContinuationReferenceCount>,
1739 is_continue_call: bool,
1740) {
1741 let reference = references.entry(id).or_default();
1742 reference.total += 1;
1743 if is_continue_call {
1744 reference.continue_calls += 1;
1745 }
1746}
1747
/// Whether `stmt` may be deleted by `eliminate_dead_pure_statements_in_continuation` when
/// its destination is unused.
///
/// Accepts these kinds:
/// - literals;
/// - thunk/closure construction (`MakeThunk`, `AddThunkBoundary`, `MakeClosure`,
///   `MakeRecursiveClosure`);
/// - tuple/record/variant construction and projections;
/// - `Primitive`, but only for the listed operators.
///
/// Calls, forces, resumptions, guard and handler statements, and any unlisted primitive
/// operator are always kept.
///
/// NOTE(review): `VariantPayload` is not listed here although `pure_direct_inline_stmt`
/// and `direct_style_stmt` accept it. Confirm whether this is deliberate (e.g. the
/// projection may fail on a payload-less variant).
fn stmt_is_pure(stmt: &CpsStmt) -> bool {
    matches!(
        stmt,
        CpsStmt::Literal { .. }
            | CpsStmt::MakeThunk { .. }
            | CpsStmt::AddThunkBoundary { .. }
            | CpsStmt::MakeClosure { .. }
            | CpsStmt::MakeRecursiveClosure { .. }
            | CpsStmt::Tuple { .. }
            | CpsStmt::Record { .. }
            | CpsStmt::RecordWithoutFields { .. }
            | CpsStmt::Variant { .. }
            | CpsStmt::Select { .. }
            | CpsStmt::SelectWithDefault { .. }
            | CpsStmt::RecordHasField { .. }
            | CpsStmt::TupleGet { .. }
            | CpsStmt::VariantTagEq { .. }
            | CpsStmt::Primitive {
                op: typed_ir::PrimitiveOp::BoolNot
                    | typed_ir::PrimitiveOp::BoolEq
                    | typed_ir::PrimitiveOp::IntAdd
                    | typed_ir::PrimitiveOp::IntSub
                    | typed_ir::PrimitiveOp::IntMul
                    | typed_ir::PrimitiveOp::IntEq
                    | typed_ir::PrimitiveOp::IntLt
                    | typed_ir::PrimitiveOp::IntLe
                    | typed_ir::PrimitiveOp::IntGt
                    | typed_ir::PrimitiveOp::IntGe
                    | typed_ir::PrimitiveOp::IntToString
                    | typed_ir::PrimitiveOp::IntToHex
                    | typed_ir::PrimitiveOp::IntToUpperHex
                    | typed_ir::PrimitiveOp::FloatAdd
                    | typed_ir::PrimitiveOp::FloatSub
                    | typed_ir::PrimitiveOp::FloatMul
                    | typed_ir::PrimitiveOp::FloatEq
                    | typed_ir::PrimitiveOp::FloatLt
                    | typed_ir::PrimitiveOp::FloatLe
                    | typed_ir::PrimitiveOp::FloatGt
                    | typed_ir::PrimitiveOp::FloatGe
                    | typed_ir::PrimitiveOp::FloatToString
                    | typed_ir::PrimitiveOp::BoolToString
                    | typed_ir::PrimitiveOp::StringConcat
                    | typed_ir::PrimitiveOp::StringLen
                    | typed_ir::PrimitiveOp::StringEq,
                ..
            }
    )
}
1796
/// The value id defined by `stmt`, or `None` for `InstallHandler` and `UninstallHandler`,
/// which define no value.
fn stmt_dest(stmt: &CpsStmt) -> Option<CpsValueId> {
    match stmt {
        CpsStmt::Literal { dest, .. }
        | CpsStmt::FreshGuard { dest, .. }
        | CpsStmt::PeekGuard { dest }
        | CpsStmt::FindGuard { dest, .. }
        | CpsStmt::MakeThunk { dest, .. }
        | CpsStmt::AddThunkBoundary { dest, .. }
        | CpsStmt::MakeClosure { dest, .. }
        | CpsStmt::MakeRecursiveClosure { dest, .. }
        | CpsStmt::ForceThunk { dest, .. }
        | CpsStmt::Tuple { dest, .. }
        | CpsStmt::Record { dest, .. }
        | CpsStmt::RecordWithoutFields { dest, .. }
        | CpsStmt::Variant { dest, .. }
        | CpsStmt::Select { dest, .. }
        | CpsStmt::SelectWithDefault { dest, .. }
        | CpsStmt::RecordHasField { dest, .. }
        | CpsStmt::TupleGet { dest, .. }
        | CpsStmt::VariantTagEq { dest, .. }
        | CpsStmt::VariantPayload { dest, .. }
        | CpsStmt::Primitive { dest, .. }
        | CpsStmt::DirectCall { dest, .. }
        | CpsStmt::ApplyClosure { dest, .. }
        | CpsStmt::CloneContinuation { dest, .. }
        | CpsStmt::Resume { dest, .. }
        | CpsStmt::ResumeWithHandler { dest, .. } => Some(*dest),
        CpsStmt::InstallHandler { .. } | CpsStmt::UninstallHandler { .. } => None,
    }
}
1827
/// The value ids read by `stmt`, excluding its destination. Used for liveness in
/// dead-statement elimination.
///
/// For handler statements, only the values listed in their environments are reported.
/// Closure/thunk construction reports no operands; presumably their captures are tracked
/// through the entry continuation's environment instead (see `function_captured_values`).
fn stmt_operands(stmt: &CpsStmt) -> Vec<CpsValueId> {
    match stmt {
        CpsStmt::FindGuard { guard, .. } => vec![*guard],
        CpsStmt::AddThunkBoundary { thunk, guard, .. } => vec![*thunk, *guard],
        CpsStmt::ForceThunk { thunk, .. } => vec![*thunk],
        CpsStmt::Tuple { items, .. } => items.clone(),
        CpsStmt::Record { base, fields, .. } => base
            .iter()
            .copied()
            .chain(fields.iter().map(|field| field.value))
            .collect(),
        CpsStmt::RecordWithoutFields { base, .. } => vec![*base],
        CpsStmt::Variant { value, .. } => value.iter().copied().collect(),
        CpsStmt::Select { base, .. } | CpsStmt::RecordHasField { base, .. } => vec![*base],
        CpsStmt::SelectWithDefault { base, default, .. } => vec![*base, *default],
        CpsStmt::TupleGet { tuple, .. } => vec![*tuple],
        CpsStmt::VariantTagEq { variant, .. } | CpsStmt::VariantPayload { variant, .. } => {
            vec![*variant]
        }
        CpsStmt::Primitive { args, .. } | CpsStmt::DirectCall { args, .. } => args.clone(),
        CpsStmt::ApplyClosure { closure, arg, .. } => vec![*closure, *arg],
        CpsStmt::CloneContinuation { source, .. } => vec![*source],
        CpsStmt::Resume {
            resumption, arg, ..
        } => vec![*resumption, *arg],
        CpsStmt::ResumeWithHandler {
            resumption,
            arg,
            envs,
            ..
        } => std::iter::once(*resumption)
            .chain(std::iter::once(*arg))
            .chain(envs.iter().flat_map(|env| env.values.iter().copied()))
            .collect(),
        CpsStmt::InstallHandler { envs, .. } => envs
            .iter()
            .flat_map(|env| env.values.iter().copied())
            .collect(),
        CpsStmt::Literal { .. }
        | CpsStmt::FreshGuard { .. }
        | CpsStmt::PeekGuard { .. }
        | CpsStmt::MakeThunk { .. }
        | CpsStmt::MakeClosure { .. }
        | CpsStmt::MakeRecursiveClosure { .. }
        | CpsStmt::UninstallHandler { .. } => Vec::new(),
    }
}
1875
1876fn terminator_values(terminator: &CpsTerminator) -> Vec<CpsValueId> {
1877 match terminator {
1878 CpsTerminator::Return(value) => vec![*value],
1879 CpsTerminator::Continue { args, .. } => args.clone(),
1880 CpsTerminator::Branch { cond, .. } => vec![*cond],
1881 CpsTerminator::Perform {
1882 payload, blocked, ..
1883 } => std::iter::once(*payload)
1884 .chain(blocked.iter().copied())
1885 .collect(),
1886 CpsTerminator::EffectfulCall { args, .. } => args.clone(),
1887 CpsTerminator::EffectfulApply { closure, arg, .. } => vec![*closure, *arg],
1888 CpsTerminator::EffectfulForce { thunk, .. } => vec![*thunk],
1889 }
1890}
1891
1892fn reachable_continuations(function: &CpsReprAbiFunction) -> HashSet<CpsContinuationId> {
1893 let continuations = function
1894 .continuations
1895 .iter()
1896 .map(|continuation| (continuation.id, continuation))
1897 .collect::<HashMap<_, _>>();
1898 let mut reachable = HashSet::new();
1899 let mut work = VecDeque::new();
1900
1901 push_reachable(function.entry, &mut reachable, &mut work);
1902 for handler in &function.handlers {
1903 for arm in &handler.arms {
1904 push_reachable(arm.entry, &mut reachable, &mut work);
1905 }
1906 }
1907
1908 while let Some(id) = work.pop_front() {
1909 let Some(continuation) = continuations.get(&id) else {
1910 continue;
1911 };
1912 for stmt in &continuation.stmts {
1913 collect_stmt_continuations(stmt, &mut reachable, &mut work);
1914 }
1915 collect_terminator_continuations(&continuation.terminator, &mut reachable, &mut work);
1916 }
1917
1918 reachable
1919}
1920
1921fn push_reachable(
1922 id: CpsContinuationId,
1923 reachable: &mut HashSet<CpsContinuationId>,
1924 work: &mut VecDeque<CpsContinuationId>,
1925) {
1926 if reachable.insert(id) {
1927 work.push_back(id);
1928 }
1929}
1930
/// Marks as reachable, and queues, every continuation that `stmt` names.
///
/// These are closure/thunk entries, a handler's `value` and `escape` continuations, and
/// handler environment entries. Every other statement kind names no continuation.
fn collect_stmt_continuations(
    stmt: &CpsStmt,
    reachable: &mut HashSet<CpsContinuationId>,
    work: &mut VecDeque<CpsContinuationId>,
) {
    match stmt {
        CpsStmt::MakeThunk { entry, .. }
        | CpsStmt::MakeClosure { entry, .. }
        | CpsStmt::MakeRecursiveClosure { entry, .. } => {
            push_reachable(*entry, reachable, work);
        }
        CpsStmt::InstallHandler {
            value,
            escape,
            envs,
            ..
        } => {
            push_reachable(*value, reachable, work);
            push_reachable(*escape, reachable, work);
            for env in envs {
                push_reachable(env.entry, reachable, work);
            }
        }
        CpsStmt::ResumeWithHandler { envs, .. } => {
            for env in envs {
                push_reachable(env.entry, reachable, work);
            }
        }
        // Listed exhaustively (no `_`) so a new statement kind forces a decision here.
        CpsStmt::Literal { .. }
        | CpsStmt::FreshGuard { .. }
        | CpsStmt::PeekGuard { .. }
        | CpsStmt::FindGuard { .. }
        | CpsStmt::AddThunkBoundary { .. }
        | CpsStmt::ForceThunk { .. }
        | CpsStmt::Tuple { .. }
        | CpsStmt::Record { .. }
        | CpsStmt::RecordWithoutFields { .. }
        | CpsStmt::Variant { .. }
        | CpsStmt::Select { .. }
        | CpsStmt::SelectWithDefault { .. }
        | CpsStmt::RecordHasField { .. }
        | CpsStmt::TupleGet { .. }
        | CpsStmt::VariantTagEq { .. }
        | CpsStmt::VariantPayload { .. }
        | CpsStmt::Primitive { .. }
        | CpsStmt::DirectCall { .. }
        | CpsStmt::ApplyClosure { .. }
        | CpsStmt::CloneContinuation { .. }
        | CpsStmt::Resume { .. }
        | CpsStmt::UninstallHandler { .. } => {}
    }
}
1983
1984fn collect_terminator_continuations(
1985 terminator: &CpsTerminator,
1986 reachable: &mut HashSet<CpsContinuationId>,
1987 work: &mut VecDeque<CpsContinuationId>,
1988) {
1989 match terminator {
1990 CpsTerminator::Continue { target, .. } => push_reachable(*target, reachable, work),
1991 CpsTerminator::Branch {
1992 then_cont,
1993 else_cont,
1994 ..
1995 } => {
1996 push_reachable(*then_cont, reachable, work);
1997 push_reachable(*else_cont, reachable, work);
1998 }
1999 CpsTerminator::Perform { resume, .. }
2000 | CpsTerminator::EffectfulCall { resume, .. }
2001 | CpsTerminator::EffectfulApply { resume, .. }
2002 | CpsTerminator::EffectfulForce { resume, .. } => push_reachable(*resume, reachable, work),
2003 CpsTerminator::Return(_) => {}
2004 }
2005}
2006
2007fn forwarding_continuations(
2008 function: &CpsReprAbiFunction,
2009) -> HashMap<CpsContinuationId, ForwardingContinuation> {
2010 let mut forwarders = HashMap::new();
2011 for continuation in &function.continuations {
2012 if !continuation.stmts.is_empty() || !continuation.environment.is_empty() {
2013 continue;
2014 }
2015 let CpsTerminator::Continue { target, args } = &continuation.terminator else {
2016 continue;
2017 };
2018 if *target == continuation.id {
2019 continue;
2020 }
2021 if args
2022 .iter()
2023 .all(|arg| continuation.params.iter().any(|param| param.value == *arg))
2024 {
2025 forwarders.insert(
2026 continuation.id,
2027 ForwardingContinuation {
2028 params: continuation
2029 .params
2030 .iter()
2031 .map(|param| param.value)
2032 .collect(),
2033 target: *target,
2034 args: args.clone(),
2035 },
2036 );
2037 }
2038 }
2039 forwarders
2040}
2041
2042fn returning_continuations(
2043 function: &CpsReprAbiFunction,
2044) -> HashMap<CpsContinuationId, ReturningContinuation> {
2045 let mut returners = HashMap::new();
2046 for continuation in &function.continuations {
2047 if !continuation.stmts.is_empty() || !continuation.environment.is_empty() {
2048 continue;
2049 }
2050 let CpsTerminator::Return(value) = continuation.terminator else {
2051 continue;
2052 };
2053 if let Some(param_index) = continuation
2054 .params
2055 .iter()
2056 .position(|param| param.value == value)
2057 {
2058 returners.insert(continuation.id, ReturningContinuation { param_index });
2059 }
2060 }
2061 returners
2062}
2063
2064fn rewrite_terminator_forwarders(
2065 terminator: &mut CpsTerminator,
2066 forwarders: &HashMap<CpsContinuationId, ForwardingContinuation>,
2067) -> usize {
2068 match terminator {
2069 CpsTerminator::Continue { target, args } => {
2070 rewrite_continuation_call(target, args, forwarders)
2071 }
2072 CpsTerminator::Perform { resume, .. }
2073 | CpsTerminator::EffectfulCall { resume, .. }
2074 | CpsTerminator::EffectfulApply { resume, .. }
2075 | CpsTerminator::EffectfulForce { resume, .. } => {
2076 let mut args = Vec::new();
2077 rewrite_resume_target(resume, &mut args, forwarders)
2078 }
2079 CpsTerminator::Branch {
2080 then_cont,
2081 else_cont,
2082 ..
2083 } => {
2084 let mut count = 0;
2085 let mut args = Vec::new();
2086 count += rewrite_resume_target(then_cont, &mut args, forwarders);
2087 count += rewrite_resume_target(else_cont, &mut args, forwarders);
2088 count
2089 }
2090 CpsTerminator::Return(_) => 0,
2091 }
2092}
2093
2094fn rewrite_terminator_returners(
2095 terminator: &mut CpsTerminator,
2096 returners: &HashMap<CpsContinuationId, ReturningContinuation>,
2097) -> usize {
2098 let CpsTerminator::Continue { target, args } = terminator else {
2099 return 0;
2100 };
2101 let Some(returner) = returners.get(target) else {
2102 return 0;
2103 };
2104 let Some(value) = args.get(returner.param_index).copied() else {
2105 return 0;
2106 };
2107 *terminator = CpsTerminator::Return(value);
2108 1
2109}
2110
2111fn rewrite_continuation_call(
2112 target: &mut CpsContinuationId,
2113 args: &mut Vec<CpsValueId>,
2114 forwarders: &HashMap<CpsContinuationId, ForwardingContinuation>,
2115) -> usize {
2116 let mut count = 0;
2117 while let Some(forwarder) = forwarders.get(target) {
2118 let Some(remapped) = forwarder.remap_args(args) else {
2119 break;
2120 };
2121 *target = forwarder.target;
2122 *args = remapped;
2123 count += 1;
2124 }
2125 count
2126}
2127
2128fn rewrite_resume_target(
2129 target: &mut CpsContinuationId,
2130 args: &mut Vec<CpsValueId>,
2131 forwarders: &HashMap<CpsContinuationId, ForwardingContinuation>,
2132) -> usize {
2133 let mut count = 0;
2134 while let Some(forwarder) = forwarders.get(target) {
2135 if !forwarder.params.is_empty() {
2136 break;
2137 }
2138 if !forwarder.args.is_empty() {
2139 break;
2140 }
2141 *target = forwarder.target;
2142 args.clear();
2143 count += 1;
2144 }
2145 count
2146}
2147
2148fn substitute_stmt_values(
2149 stmt: CpsStmt,
2150 substitution: &HashMap<CpsValueId, CpsValueId>,
2151) -> CpsStmt {
2152 match stmt {
2153 CpsStmt::Literal { dest, literal } => CpsStmt::Literal { dest, literal },
2154 CpsStmt::FreshGuard { dest, var } => CpsStmt::FreshGuard { dest, var },
2155 CpsStmt::PeekGuard { dest } => CpsStmt::PeekGuard { dest },
2156 CpsStmt::FindGuard { dest, guard } => CpsStmt::FindGuard {
2157 dest,
2158 guard: subst_value(guard, substitution),
2159 },
2160 CpsStmt::MakeThunk { dest, entry } => CpsStmt::MakeThunk { dest, entry },
2161 CpsStmt::AddThunkBoundary {
2162 dest,
2163 thunk,
2164 guard,
2165 allowed,
2166 active,
2167 } => CpsStmt::AddThunkBoundary {
2168 dest,
2169 thunk: subst_value(thunk, substitution),
2170 guard: subst_value(guard, substitution),
2171 allowed,
2172 active,
2173 },
2174 CpsStmt::MakeClosure { dest, entry } => CpsStmt::MakeClosure { dest, entry },
2175 CpsStmt::MakeRecursiveClosure { dest, entry } => {
2176 CpsStmt::MakeRecursiveClosure { dest, entry }
2177 }
2178 CpsStmt::ForceThunk { dest, thunk } => CpsStmt::ForceThunk {
2179 dest,
2180 thunk: subst_value(thunk, substitution),
2181 },
2182 CpsStmt::Tuple { dest, items } => CpsStmt::Tuple {
2183 dest,
2184 items: subst_values(items, substitution),
2185 },
2186 CpsStmt::Record { dest, base, fields } => CpsStmt::Record {
2187 dest,
2188 base: base.map(|value| subst_value(value, substitution)),
2189 fields: fields
2190 .into_iter()
2191 .map(|field| CpsRecordField {
2192 name: field.name,
2193 value: subst_value(field.value, substitution),
2194 })
2195 .collect(),
2196 },
2197 CpsStmt::RecordWithoutFields { dest, base, fields } => CpsStmt::RecordWithoutFields {
2198 dest,
2199 base: subst_value(base, substitution),
2200 fields,
2201 },
2202 CpsStmt::Variant { dest, tag, value } => CpsStmt::Variant {
2203 dest,
2204 tag,
2205 value: value.map(|value| subst_value(value, substitution)),
2206 },
2207 CpsStmt::Select { dest, base, field } => CpsStmt::Select {
2208 dest,
2209 base: subst_value(base, substitution),
2210 field,
2211 },
2212 CpsStmt::SelectWithDefault {
2213 dest,
2214 base,
2215 field,
2216 default,
2217 } => CpsStmt::SelectWithDefault {
2218 dest,
2219 base: subst_value(base, substitution),
2220 field,
2221 default: subst_value(default, substitution),
2222 },
2223 CpsStmt::RecordHasField { dest, base, field } => CpsStmt::RecordHasField {
2224 dest,
2225 base: subst_value(base, substitution),
2226 field,
2227 },
2228 CpsStmt::TupleGet { dest, tuple, index } => CpsStmt::TupleGet {
2229 dest,
2230 tuple: subst_value(tuple, substitution),
2231 index,
2232 },
2233 CpsStmt::VariantTagEq { dest, variant, tag } => CpsStmt::VariantTagEq {
2234 dest,
2235 variant: subst_value(variant, substitution),
2236 tag,
2237 },
2238 CpsStmt::VariantPayload { dest, variant } => CpsStmt::VariantPayload {
2239 dest,
2240 variant: subst_value(variant, substitution),
2241 },
2242 CpsStmt::Primitive { dest, op, args } => CpsStmt::Primitive {
2243 dest,
2244 op,
2245 args: subst_values(args, substitution),
2246 },
2247 CpsStmt::DirectCall { dest, target, args } => CpsStmt::DirectCall {
2248 dest,
2249 target,
2250 args: subst_values(args, substitution),
2251 },
2252 CpsStmt::ApplyClosure { dest, closure, arg } => CpsStmt::ApplyClosure {
2253 dest,
2254 closure: subst_value(closure, substitution),
2255 arg: subst_value(arg, substitution),
2256 },
2257 CpsStmt::CloneContinuation { dest, source } => CpsStmt::CloneContinuation {
2258 dest,
2259 source: subst_value(source, substitution),
2260 },
2261 CpsStmt::Resume {
2262 dest,
2263 resumption,
2264 arg,
2265 } => CpsStmt::Resume {
2266 dest,
2267 resumption: subst_value(resumption, substitution),
2268 arg: subst_value(arg, substitution),
2269 },
2270 CpsStmt::ResumeWithHandler {
2271 dest,
2272 resumption,
2273 arg,
2274 handler,
2275 envs,
2276 } => CpsStmt::ResumeWithHandler {
2277 dest,
2278 resumption: subst_value(resumption, substitution),
2279 arg: subst_value(arg, substitution),
2280 handler,
2281 envs: subst_handler_envs(envs, substitution),
2282 },
2283 CpsStmt::InstallHandler {
2284 handler,
2285 envs,
2286 value,
2287 escape,
2288 } => CpsStmt::InstallHandler {
2289 handler,
2290 envs: subst_handler_envs(envs, substitution),
2291 value,
2292 escape,
2293 },
2294 CpsStmt::UninstallHandler { handler } => CpsStmt::UninstallHandler { handler },
2295 }
2296}
2297
2298fn substitute_terminator_values(
2299 terminator: CpsTerminator,
2300 substitution: &HashMap<CpsValueId, CpsValueId>,
2301) -> CpsTerminator {
2302 match terminator {
2303 CpsTerminator::Return(value) => CpsTerminator::Return(subst_value(value, substitution)),
2304 CpsTerminator::Continue { target, args } => CpsTerminator::Continue {
2305 target,
2306 args: subst_values(args, substitution),
2307 },
2308 CpsTerminator::Branch {
2309 cond,
2310 then_cont,
2311 else_cont,
2312 } => CpsTerminator::Branch {
2313 cond: subst_value(cond, substitution),
2314 then_cont,
2315 else_cont,
2316 },
2317 CpsTerminator::Perform {
2318 effect,
2319 payload,
2320 resume,
2321 handler,
2322 blocked,
2323 } => CpsTerminator::Perform {
2324 effect,
2325 payload: subst_value(payload, substitution),
2326 resume,
2327 handler,
2328 blocked: blocked.map(|value| subst_value(value, substitution)),
2329 },
2330 CpsTerminator::EffectfulCall {
2331 target,
2332 args,
2333 resume,
2334 } => CpsTerminator::EffectfulCall {
2335 target,
2336 args: subst_values(args, substitution),
2337 resume,
2338 },
2339 CpsTerminator::EffectfulApply {
2340 closure,
2341 arg,
2342 resume,
2343 } => CpsTerminator::EffectfulApply {
2344 closure: subst_value(closure, substitution),
2345 arg: subst_value(arg, substitution),
2346 resume,
2347 },
2348 CpsTerminator::EffectfulForce { thunk, resume } => CpsTerminator::EffectfulForce {
2349 thunk: subst_value(thunk, substitution),
2350 resume,
2351 },
2352 }
2353}
2354
2355fn subst_handler_envs(
2356 envs: Vec<CpsHandlerEnv>,
2357 substitution: &HashMap<CpsValueId, CpsValueId>,
2358) -> Vec<CpsHandlerEnv> {
2359 envs.into_iter()
2360 .map(|env| CpsHandlerEnv {
2361 entry: env.entry,
2362 values: subst_values(env.values, substitution),
2363 targets: subst_values(env.targets, substitution),
2364 })
2365 .collect()
2366}
2367
2368fn subst_values(
2369 values: Vec<CpsValueId>,
2370 substitution: &HashMap<CpsValueId, CpsValueId>,
2371) -> Vec<CpsValueId> {
2372 values
2373 .into_iter()
2374 .map(|value| subst_value(value, substitution))
2375 .collect()
2376}
2377
2378fn subst_value(value: CpsValueId, substitution: &HashMap<CpsValueId, CpsValueId>) -> CpsValueId {
2379 substitution.get(&value).copied().unwrap_or(value)
2380}
2381
/// A continuation whose calls can be redirected straight to `target`.
///
/// `args` are the values it passes on, expressed in terms of its own `params`. See
/// `ForwardingContinuation::remap_args` for how a caller's arguments are translated.
/// NOTE(review): how forwarders are detected is outside this chunk. Presumably they are
/// continuations with no statements and a `Continue` terminator; confirm at the builder.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ForwardingContinuation {
    /// The forwarder's own parameters, in declaration order.
    params: Vec<CpsValueId>,
    /// The continuation that calls to this forwarder are redirected to.
    target: CpsContinuationId,
    /// Arguments passed on to `target`. `remap_args` only succeeds when each one is one of
    /// `params`.
    args: Vec<CpsValueId>,
}
2388
/// Marks a continuation that `rewrite_terminator_returners` can bypass.
///
/// That function turns `Continue { target, args }` into `Return(args[param_index])`.
/// NOTE(review): this presumably describes a continuation that immediately returns its
/// `param_index`-th parameter; confirm where these are collected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct ReturningContinuation {
    /// Position, among the continuation's parameters, of the value that is returned.
    param_index: usize,
}
2393
2394impl ForwardingContinuation {
2395 fn remap_args(&self, supplied_args: &[CpsValueId]) -> Option<Vec<CpsValueId>> {
2396 if supplied_args.len() != self.params.len() {
2397 return None;
2398 }
2399 self.args
2400 .iter()
2401 .map(|forwarded| {
2402 self.params
2403 .iter()
2404 .position(|param| param == forwarded)
2405 .map(|index| supplied_args[index])
2406 })
2407 .collect()
2408 }
2409}
2410
2411impl CpsOptimizationProfile {
2412 fn record_optimized_size(&mut self, module: &CpsReprAbiModule) {
2413 self.optimized_continuations = module
2414 .functions
2415 .iter()
2416 .chain(&module.roots)
2417 .map(|function| function.continuations.len())
2418 .sum();
2419 self.optimized_statements = module
2420 .functions
2421 .iter()
2422 .chain(&module.roots)
2423 .flat_map(|function| &function.continuations)
2424 .map(|continuation| continuation.stmts.len())
2425 .sum();
2426 }
2427
2428 fn has_more_changes_than(self, before: Self) -> bool {
2429 self.forwarded_continuation_calls > before.forwarded_continuation_calls
2430 || self.returned_continuation_calls > before.returned_continuation_calls
2431 || self.folded_constant_branches > before.folded_constant_branches
2432 || self.rewritten_pure_effectful_calls > before.rewritten_pure_effectful_calls
2433 || self.reified_primitive_calls > before.reified_primitive_calls
2434 || self.reified_partial_closure_calls > before.reified_partial_closure_calls
2435 || self.reified_known_closure_parameter_calls
2436 > before.reified_known_closure_parameter_calls
2437 || self.removed_unused_continuation_params > before.removed_unused_continuation_params
2438 || self.folded_structural_projections > before.folded_structural_projections
2439 || self.inlined_pure_direct_calls > before.inlined_pure_direct_calls
2440 || self.inlined_continuation_calls > before.inlined_continuation_calls
2441 || self.removed_unreachable_continuations > before.removed_unreachable_continuations
2442 || self.removed_dead_pure_statements > before.removed_dead_pure_statements
2443 }
2444
2445 pub fn measure(module: &CpsReprAbiModule) -> Self {
2446 let functions = module.functions.len();
2447 let roots = module.roots.len();
2448 let continuations = module
2449 .functions
2450 .iter()
2451 .chain(&module.roots)
2452 .map(|function| function.continuations.len())
2453 .sum();
2454 let handlers = module
2455 .functions
2456 .iter()
2457 .chain(&module.roots)
2458 .map(|function| function.handlers.len())
2459 .sum();
2460 let statements = module
2461 .functions
2462 .iter()
2463 .chain(&module.roots)
2464 .flat_map(|function| &function.continuations)
2465 .map(|continuation| continuation.stmts.len())
2466 .sum();
2467
2468 Self {
2469 functions,
2470 roots,
2471 continuations,
2472 handlers,
2473 statements,
2474 optimized_continuations: continuations,
2475 optimized_statements: statements,
2476 passes_run: 0,
2477 forwarded_continuation_calls: 0,
2478 returned_continuation_calls: 0,
2479 folded_constant_branches: 0,
2480 rewritten_pure_effectful_calls: 0,
2481 reified_primitive_calls: 0,
2482 reified_partial_closure_calls: 0,
2483 reified_known_closure_parameter_calls: 0,
2484 removed_unused_continuation_params: 0,
2485 folded_structural_projections: 0,
2486 inlined_pure_direct_calls: 0,
2487 inlined_continuation_calls: 0,
2488 removed_unreachable_continuations: 0,
2489 removed_dead_pure_statements: 0,
2490 direct_style_islands: 0,
2491 direct_style_continuations: 0,
2492 changed: false,
2493 }
2494 }
2495}
2496
2497#[cfg(test)]
2498mod tests {
2499 use crate::cps_ir::{
2500 CpsContinuationId, CpsFunction, CpsLiteral, CpsModule, CpsShotKind, CpsStmt, CpsTerminator,
2501 CpsValueId,
2502 };
2503 use crate::cps_repr::lower_cps_repr_module;
2504 use crate::cps_repr_abi::lower_cps_repr_abi_module;
2505
2506 use super::*;
2507
2508 #[test]
2509 fn optimization_boundary_keeps_non_forwarding_module() {
2510 let abi = sample_abi_module();
2511 let optimized = optimize_cps_repr_abi_module(&abi);
2512
2513 assert_eq!(optimized.module, abi);
2514 assert_eq!(optimized.profile.roots, 1);
2515 assert_eq!(optimized.profile.continuations, 1);
2516 assert_eq!(optimized.profile.optimized_continuations, 1);
2517 assert_eq!(optimized.profile.statements, 1);
2518 assert_eq!(optimized.profile.optimized_statements, 1);
2519 assert_eq!(optimized.profile.passes_run, 17);
2520 assert_eq!(optimized.profile.forwarded_continuation_calls, 0);
2521 assert_eq!(optimized.profile.returned_continuation_calls, 0);
2522 assert_eq!(optimized.profile.folded_constant_branches, 0);
2523 assert_eq!(optimized.profile.rewritten_pure_effectful_calls, 0);
2524 assert_eq!(optimized.profile.reified_primitive_calls, 0);
2525 assert_eq!(optimized.profile.reified_partial_closure_calls, 0);
2526 assert_eq!(optimized.profile.reified_known_closure_parameter_calls, 0);
2527 assert_eq!(optimized.profile.removed_unused_continuation_params, 0);
2528 assert_eq!(optimized.profile.folded_structural_projections, 0);
2529 assert_eq!(optimized.profile.inlined_pure_direct_calls, 0);
2530 assert_eq!(optimized.profile.inlined_continuation_calls, 0);
2531 assert_eq!(optimized.profile.removed_unreachable_continuations, 0);
2532 assert_eq!(optimized.profile.removed_dead_pure_statements, 0);
2533 assert_eq!(optimized.profile.direct_style_islands, 1);
2534 assert_eq!(optimized.profile.direct_style_continuations, 1);
2535 assert!(!optimized.profile.changed);
2536 }
2537
2538 #[test]
2539 fn rewrites_empty_continue_forwarder_calls() {
2540 let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
2541 functions: Vec::new(),
2542 roots: vec![CpsFunction {
2543 name: "root".to_string(),
2544 params: Vec::new(),
2545 entry: CpsContinuationId(0),
2546 handlers: Vec::new(),
2547 continuations: vec![
2548 crate::cps_ir::CpsContinuation {
2549 id: CpsContinuationId(0),
2550 params: Vec::new(),
2551 captures: Vec::new(),
2552 shot_kind: CpsShotKind::OneShot,
2553 stmts: vec![CpsStmt::Literal {
2554 dest: CpsValueId(0),
2555 literal: CpsLiteral::Int("42".to_string()),
2556 }],
2557 terminator: CpsTerminator::Continue {
2558 target: CpsContinuationId(1),
2559 args: vec![CpsValueId(0)],
2560 },
2561 },
2562 crate::cps_ir::CpsContinuation {
2563 id: CpsContinuationId(1),
2564 params: vec![CpsValueId(1)],
2565 captures: Vec::new(),
2566 shot_kind: CpsShotKind::OneShot,
2567 stmts: Vec::new(),
2568 terminator: CpsTerminator::Continue {
2569 target: CpsContinuationId(2),
2570 args: vec![CpsValueId(1)],
2571 },
2572 },
2573 crate::cps_ir::CpsContinuation {
2574 id: CpsContinuationId(2),
2575 params: vec![CpsValueId(2)],
2576 captures: Vec::new(),
2577 shot_kind: CpsShotKind::OneShot,
2578 stmts: Vec::new(),
2579 terminator: CpsTerminator::Return(CpsValueId(2)),
2580 },
2581 ],
2582 }],
2583 }));
2584
2585 let optimized = optimize_cps_repr_abi_module(&abi);
2586 let entry = &optimized.module.roots[0].continuations[0];
2587
2588 assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(0)));
2589 assert_eq!(optimized.profile.forwarded_continuation_calls, 1);
2590 assert_eq!(optimized.profile.returned_continuation_calls, 2);
2591 assert_eq!(optimized.profile.reified_primitive_calls, 0);
2592 assert_eq!(optimized.profile.reified_partial_closure_calls, 0);
2593 assert_eq!(optimized.profile.inlined_pure_direct_calls, 0);
2594 assert_eq!(optimized.profile.inlined_continuation_calls, 0);
2595 assert_eq!(optimized.profile.removed_unreachable_continuations, 2);
2596 assert_eq!(optimized.profile.removed_dead_pure_statements, 0);
2597 assert_eq!(optimized.profile.direct_style_islands, 1);
2598 assert_eq!(optimized.profile.direct_style_continuations, 1);
2599 assert!(optimized.profile.changed);
2600 }
2601
2602 #[test]
2603 fn rewrites_empty_returning_continuation_calls() {
2604 let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
2605 functions: Vec::new(),
2606 roots: vec![CpsFunction {
2607 name: "root".to_string(),
2608 params: Vec::new(),
2609 entry: CpsContinuationId(0),
2610 handlers: Vec::new(),
2611 continuations: vec![
2612 crate::cps_ir::CpsContinuation {
2613 id: CpsContinuationId(0),
2614 params: Vec::new(),
2615 captures: Vec::new(),
2616 shot_kind: CpsShotKind::OneShot,
2617 stmts: vec![CpsStmt::Literal {
2618 dest: CpsValueId(0),
2619 literal: CpsLiteral::Int("42".to_string()),
2620 }],
2621 terminator: CpsTerminator::Continue {
2622 target: CpsContinuationId(1),
2623 args: vec![CpsValueId(0)],
2624 },
2625 },
2626 crate::cps_ir::CpsContinuation {
2627 id: CpsContinuationId(1),
2628 params: vec![CpsValueId(1)],
2629 captures: Vec::new(),
2630 shot_kind: CpsShotKind::OneShot,
2631 stmts: Vec::new(),
2632 terminator: CpsTerminator::Return(CpsValueId(1)),
2633 },
2634 ],
2635 }],
2636 }));
2637
2638 let optimized = optimize_cps_repr_abi_module(&abi);
2639 let entry = &optimized.module.roots[0].continuations[0];
2640
2641 assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(0)));
2642 assert_eq!(optimized.profile.returned_continuation_calls, 1);
2643 assert_eq!(optimized.profile.reified_primitive_calls, 0);
2644 assert_eq!(optimized.profile.reified_partial_closure_calls, 0);
2645 assert_eq!(optimized.profile.inlined_pure_direct_calls, 0);
2646 assert_eq!(optimized.profile.inlined_continuation_calls, 0);
2647 assert_eq!(optimized.profile.removed_unreachable_continuations, 1);
2648 assert_eq!(optimized.profile.removed_dead_pure_statements, 0);
2649 assert_eq!(optimized.profile.direct_style_islands, 1);
2650 assert_eq!(optimized.profile.direct_style_continuations, 1);
2651 assert!(optimized.profile.changed);
2652 }
2653
2654 #[test]
2655 fn inlines_single_use_one_shot_continuations() {
2656 let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
2657 functions: Vec::new(),
2658 roots: vec![CpsFunction {
2659 name: "root".to_string(),
2660 params: Vec::new(),
2661 entry: CpsContinuationId(0),
2662 handlers: Vec::new(),
2663 continuations: vec![
2664 crate::cps_ir::CpsContinuation {
2665 id: CpsContinuationId(0),
2666 params: Vec::new(),
2667 captures: Vec::new(),
2668 shot_kind: CpsShotKind::OneShot,
2669 stmts: vec![CpsStmt::Literal {
2670 dest: CpsValueId(0),
2671 literal: CpsLiteral::Int("41".to_string()),
2672 }],
2673 terminator: CpsTerminator::Continue {
2674 target: CpsContinuationId(1),
2675 args: vec![CpsValueId(0)],
2676 },
2677 },
2678 crate::cps_ir::CpsContinuation {
2679 id: CpsContinuationId(1),
2680 params: vec![CpsValueId(1)],
2681 captures: Vec::new(),
2682 shot_kind: CpsShotKind::OneShot,
2683 stmts: vec![
2684 CpsStmt::Literal {
2685 dest: CpsValueId(2),
2686 literal: CpsLiteral::Int("1".to_string()),
2687 },
2688 CpsStmt::Primitive {
2689 dest: CpsValueId(3),
2690 op: yulang_typed_ir::PrimitiveOp::IntAdd,
2691 args: vec![CpsValueId(1), CpsValueId(2)],
2692 },
2693 ],
2694 terminator: CpsTerminator::Return(CpsValueId(3)),
2695 },
2696 ],
2697 }],
2698 }));
2699
2700 let optimized = optimize_cps_repr_abi_module(&abi);
2701 let root = &optimized.module.roots[0];
2702 let entry = &root.continuations[0];
2703
2704 assert_eq!(root.continuations.len(), 1);
2705 assert_eq!(entry.stmts.len(), 3);
2706 assert_eq!(
2707 entry.stmts[2],
2708 CpsStmt::Primitive {
2709 dest: CpsValueId(3),
2710 op: yulang_typed_ir::PrimitiveOp::IntAdd,
2711 args: vec![CpsValueId(0), CpsValueId(2)],
2712 }
2713 );
2714 assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(3)));
2715 assert_eq!(optimized.profile.inlined_continuation_calls, 1);
2716 assert_eq!(optimized.profile.removed_unreachable_continuations, 1);
2717 assert_eq!(optimized.profile.direct_style_islands, 1);
2718 assert_eq!(optimized.profile.direct_style_continuations, 1);
2719 }
2720
2721 #[test]
2722 fn reifies_direct_calls_to_primitive_wrappers() {
2723 let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
2724 functions: vec![CpsFunction {
2725 name: "add".to_string(),
2726 params: vec![CpsValueId(0), CpsValueId(1)],
2727 entry: CpsContinuationId(0),
2728 handlers: Vec::new(),
2729 continuations: vec![crate::cps_ir::CpsContinuation {
2730 id: CpsContinuationId(0),
2731 params: vec![CpsValueId(0), CpsValueId(1)],
2732 captures: Vec::new(),
2733 shot_kind: CpsShotKind::MultiShot,
2734 stmts: vec![CpsStmt::Primitive {
2735 dest: CpsValueId(2),
2736 op: typed_ir::PrimitiveOp::IntAdd,
2737 args: vec![CpsValueId(0), CpsValueId(1)],
2738 }],
2739 terminator: CpsTerminator::Return(CpsValueId(2)),
2740 }],
2741 }],
2742 roots: vec![CpsFunction {
2743 name: "root".to_string(),
2744 params: Vec::new(),
2745 entry: CpsContinuationId(0),
2746 handlers: Vec::new(),
2747 continuations: vec![crate::cps_ir::CpsContinuation {
2748 id: CpsContinuationId(0),
2749 params: Vec::new(),
2750 captures: Vec::new(),
2751 shot_kind: CpsShotKind::OneShot,
2752 stmts: vec![
2753 CpsStmt::Literal {
2754 dest: CpsValueId(0),
2755 literal: CpsLiteral::Int("1".to_string()),
2756 },
2757 CpsStmt::Literal {
2758 dest: CpsValueId(1),
2759 literal: CpsLiteral::Int("2".to_string()),
2760 },
2761 CpsStmt::DirectCall {
2762 dest: CpsValueId(2),
2763 target: "add".to_string(),
2764 args: vec![CpsValueId(0), CpsValueId(1)],
2765 },
2766 ],
2767 terminator: CpsTerminator::Return(CpsValueId(2)),
2768 }],
2769 }],
2770 }));
2771
2772 let optimized = optimize_cps_repr_abi_module(&abi);
2773 let entry = &optimized.module.roots[0].continuations[0];
2774
2775 assert_eq!(
2776 entry.stmts[2],
2777 CpsStmt::Primitive {
2778 dest: CpsValueId(2),
2779 op: typed_ir::PrimitiveOp::IntAdd,
2780 args: vec![CpsValueId(0), CpsValueId(1)],
2781 }
2782 );
2783 assert_eq!(optimized.profile.reified_primitive_calls, 1);
2784 assert_eq!(optimized.profile.direct_style_islands, 2);
2785 assert_eq!(optimized.profile.direct_style_continuations, 2);
2786 }
2787
2788 #[test]
2789 fn inlines_small_pure_direct_calls() {
2790 let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
2791 functions: vec![CpsFunction {
2792 name: "plus_one".to_string(),
2793 params: vec![CpsValueId(0)],
2794 entry: CpsContinuationId(0),
2795 handlers: Vec::new(),
2796 continuations: vec![crate::cps_ir::CpsContinuation {
2797 id: CpsContinuationId(0),
2798 params: vec![CpsValueId(0)],
2799 captures: Vec::new(),
2800 shot_kind: CpsShotKind::OneShot,
2801 stmts: vec![
2802 CpsStmt::Literal {
2803 dest: CpsValueId(1),
2804 literal: CpsLiteral::Int("1".to_string()),
2805 },
2806 CpsStmt::Primitive {
2807 dest: CpsValueId(2),
2808 op: typed_ir::PrimitiveOp::IntAdd,
2809 args: vec![CpsValueId(0), CpsValueId(1)],
2810 },
2811 ],
2812 terminator: CpsTerminator::Return(CpsValueId(2)),
2813 }],
2814 }],
2815 roots: vec![CpsFunction {
2816 name: "root".to_string(),
2817 params: Vec::new(),
2818 entry: CpsContinuationId(0),
2819 handlers: Vec::new(),
2820 continuations: vec![crate::cps_ir::CpsContinuation {
2821 id: CpsContinuationId(0),
2822 params: Vec::new(),
2823 captures: Vec::new(),
2824 shot_kind: CpsShotKind::OneShot,
2825 stmts: vec![
2826 CpsStmt::Literal {
2827 dest: CpsValueId(0),
2828 literal: CpsLiteral::Int("41".to_string()),
2829 },
2830 CpsStmt::DirectCall {
2831 dest: CpsValueId(1),
2832 target: "plus_one".to_string(),
2833 args: vec![CpsValueId(0)],
2834 },
2835 ],
2836 terminator: CpsTerminator::Return(CpsValueId(1)),
2837 }],
2838 }],
2839 }));
2840
2841 let optimized = optimize_cps_repr_abi_module(&abi);
2842 let entry = &optimized.module.roots[0].continuations[0];
2843
2844 assert_eq!(entry.stmts.len(), 3);
2845 assert_eq!(
2846 entry.stmts[1],
2847 CpsStmt::Literal {
2848 dest: CpsValueId(2),
2849 literal: CpsLiteral::Int("1".to_string()),
2850 }
2851 );
2852 assert_eq!(
2853 entry.stmts[2],
2854 CpsStmt::Primitive {
2855 dest: CpsValueId(1),
2856 op: typed_ir::PrimitiveOp::IntAdd,
2857 args: vec![CpsValueId(0), CpsValueId(2)],
2858 }
2859 );
2860 assert_eq!(optimized.profile.inlined_pure_direct_calls, 1);
2861 assert_eq!(optimized.profile.removed_dead_pure_statements, 0);
2862 }
2863
2864 #[test]
2865 fn inlines_small_structural_pure_direct_calls() {
2866 let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
2867 functions: vec![CpsFunction {
2868 name: "pair".to_string(),
2869 params: vec![CpsValueId(0), CpsValueId(1)],
2870 entry: CpsContinuationId(0),
2871 handlers: Vec::new(),
2872 continuations: vec![crate::cps_ir::CpsContinuation {
2873 id: CpsContinuationId(0),
2874 params: vec![CpsValueId(0), CpsValueId(1)],
2875 captures: Vec::new(),
2876 shot_kind: CpsShotKind::OneShot,
2877 stmts: vec![CpsStmt::Tuple {
2878 dest: CpsValueId(2),
2879 items: vec![CpsValueId(0), CpsValueId(1)],
2880 }],
2881 terminator: CpsTerminator::Return(CpsValueId(2)),
2882 }],
2883 }],
2884 roots: vec![CpsFunction {
2885 name: "root".to_string(),
2886 params: Vec::new(),
2887 entry: CpsContinuationId(0),
2888 handlers: Vec::new(),
2889 continuations: vec![crate::cps_ir::CpsContinuation {
2890 id: CpsContinuationId(0),
2891 params: Vec::new(),
2892 captures: Vec::new(),
2893 shot_kind: CpsShotKind::OneShot,
2894 stmts: vec![
2895 CpsStmt::Literal {
2896 dest: CpsValueId(0),
2897 literal: CpsLiteral::Int("1".to_string()),
2898 },
2899 CpsStmt::Literal {
2900 dest: CpsValueId(1),
2901 literal: CpsLiteral::Int("2".to_string()),
2902 },
2903 CpsStmt::DirectCall {
2904 dest: CpsValueId(2),
2905 target: "pair".to_string(),
2906 args: vec![CpsValueId(0), CpsValueId(1)],
2907 },
2908 ],
2909 terminator: CpsTerminator::Return(CpsValueId(2)),
2910 }],
2911 }],
2912 }));
2913
2914 let optimized = optimize_cps_repr_abi_module(&abi);
2915 let entry = &optimized.module.roots[0].continuations[0];
2916
2917 assert_eq!(
2918 entry.stmts[2],
2919 CpsStmt::Tuple {
2920 dest: CpsValueId(2),
2921 items: vec![CpsValueId(0), CpsValueId(1)],
2922 }
2923 );
2924 assert_eq!(optimized.profile.inlined_pure_direct_calls, 1);
2925 }
2926
2927 #[test]
2928 fn rewrites_effectful_call_to_pure_callee() {
2929 let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
2930 functions: vec![CpsFunction {
2931 name: "plus_one".to_string(),
2932 params: vec![CpsValueId(0)],
2933 entry: CpsContinuationId(0),
2934 handlers: Vec::new(),
2935 continuations: vec![
2936 crate::cps_ir::CpsContinuation {
2937 id: CpsContinuationId(0),
2938 params: vec![CpsValueId(0)],
2939 captures: Vec::new(),
2940 shot_kind: CpsShotKind::OneShot,
2941 stmts: vec![
2942 CpsStmt::Literal {
2943 dest: CpsValueId(1),
2944 literal: CpsLiteral::Int("1".to_string()),
2945 },
2946 CpsStmt::Primitive {
2947 dest: CpsValueId(2),
2948 op: typed_ir::PrimitiveOp::IntAdd,
2949 args: vec![CpsValueId(0), CpsValueId(1)],
2950 },
2951 ],
2952 terminator: CpsTerminator::Continue {
2953 target: CpsContinuationId(1),
2954 args: vec![CpsValueId(2)],
2955 },
2956 },
2957 crate::cps_ir::CpsContinuation {
2958 id: CpsContinuationId(1),
2959 params: vec![CpsValueId(3)],
2960 captures: Vec::new(),
2961 shot_kind: CpsShotKind::OneShot,
2962 stmts: Vec::new(),
2963 terminator: CpsTerminator::Return(CpsValueId(3)),
2964 },
2965 ],
2966 }],
2967 roots: vec![CpsFunction {
2968 name: "root".to_string(),
2969 params: Vec::new(),
2970 entry: CpsContinuationId(0),
2971 handlers: Vec::new(),
2972 continuations: vec![
2973 crate::cps_ir::CpsContinuation {
2974 id: CpsContinuationId(0),
2975 params: Vec::new(),
2976 captures: Vec::new(),
2977 shot_kind: CpsShotKind::OneShot,
2978 stmts: vec![CpsStmt::Literal {
2979 dest: CpsValueId(0),
2980 literal: CpsLiteral::Int("41".to_string()),
2981 }],
2982 terminator: CpsTerminator::EffectfulCall {
2983 target: "plus_one".to_string(),
2984 args: vec![CpsValueId(0)],
2985 resume: CpsContinuationId(1),
2986 },
2987 },
2988 crate::cps_ir::CpsContinuation {
2989 id: CpsContinuationId(1),
2990 params: vec![CpsValueId(1)],
2991 captures: Vec::new(),
2992 shot_kind: CpsShotKind::OneShot,
2993 stmts: Vec::new(),
2994 terminator: CpsTerminator::Return(CpsValueId(1)),
2995 },
2996 ],
2997 }],
2998 }));
2999
3000 let optimized = optimize_cps_repr_abi_module(&abi);
3001 let entry = &optimized.module.roots[0].continuations[0];
3002
3003 assert_eq!(
3004 entry.stmts[1],
3005 CpsStmt::Literal {
3006 dest: CpsValueId(3),
3007 literal: CpsLiteral::Int("1".to_string()),
3008 }
3009 );
3010 assert_eq!(
3011 entry.stmts[2],
3012 CpsStmt::Primitive {
3013 dest: CpsValueId(2),
3014 op: typed_ir::PrimitiveOp::IntAdd,
3015 args: vec![CpsValueId(0), CpsValueId(3)],
3016 }
3017 );
3018 assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(2)));
3019 assert_eq!(optimized.profile.rewritten_pure_effectful_calls, 1);
3020 assert_eq!(optimized.profile.inlined_pure_direct_calls, 1);
3021 assert_eq!(optimized.profile.returned_continuation_calls, 1);
3022 }
3023
3024 #[test]
3025 fn reifies_local_partial_closure_apply_to_direct_call() {
3026 let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
3027 functions: vec![CpsFunction {
3028 name: "add".to_string(),
3029 params: vec![CpsValueId(0), CpsValueId(1)],
3030 entry: CpsContinuationId(0),
3031 handlers: Vec::new(),
3032 continuations: vec![crate::cps_ir::CpsContinuation {
3033 id: CpsContinuationId(0),
3034 params: vec![CpsValueId(0), CpsValueId(1)],
3035 captures: Vec::new(),
3036 shot_kind: CpsShotKind::MultiShot,
3037 stmts: vec![CpsStmt::Primitive {
3038 dest: CpsValueId(2),
3039 op: typed_ir::PrimitiveOp::IntAdd,
3040 args: vec![CpsValueId(0), CpsValueId(1)],
3041 }],
3042 terminator: CpsTerminator::Return(CpsValueId(2)),
3043 }],
3044 }],
3045 roots: vec![CpsFunction {
3046 name: "root".to_string(),
3047 params: Vec::new(),
3048 entry: CpsContinuationId(0),
3049 handlers: Vec::new(),
3050 continuations: vec![
3051 crate::cps_ir::CpsContinuation {
3052 id: CpsContinuationId(0),
3053 params: Vec::new(),
3054 captures: Vec::new(),
3055 shot_kind: CpsShotKind::OneShot,
3056 stmts: vec![
3057 CpsStmt::Literal {
3058 dest: CpsValueId(0),
3059 literal: CpsLiteral::Int("40".to_string()),
3060 },
3061 CpsStmt::MakeClosure {
3062 dest: CpsValueId(1),
3063 entry: CpsContinuationId(1),
3064 },
3065 CpsStmt::Literal {
3066 dest: CpsValueId(2),
3067 literal: CpsLiteral::Int("2".to_string()),
3068 },
3069 CpsStmt::ApplyClosure {
3070 dest: CpsValueId(3),
3071 closure: CpsValueId(1),
3072 arg: CpsValueId(2),
3073 },
3074 ],
3075 terminator: CpsTerminator::Return(CpsValueId(3)),
3076 },
3077 crate::cps_ir::CpsContinuation {
3078 id: CpsContinuationId(1),
3079 params: vec![CpsValueId(4)],
3080 captures: vec![CpsValueId(0)],
3081 shot_kind: CpsShotKind::OneShot,
3082 stmts: vec![CpsStmt::DirectCall {
3083 dest: CpsValueId(5),
3084 target: "add".to_string(),
3085 args: vec![CpsValueId(0), CpsValueId(4)],
3086 }],
3087 terminator: CpsTerminator::Return(CpsValueId(5)),
3088 },
3089 ],
3090 }],
3091 }));
3092
3093 let optimized = optimize_cps_repr_abi_module(&abi);
3094 let entry = &optimized.module.roots[0].continuations[0];
3095
3096 assert_eq!(entry.stmts.len(), 3);
3097 assert_eq!(
3098 entry.stmts[2],
3099 CpsStmt::Primitive {
3100 dest: CpsValueId(3),
3101 op: typed_ir::PrimitiveOp::IntAdd,
3102 args: vec![CpsValueId(0), CpsValueId(2)],
3103 }
3104 );
3105 assert_eq!(optimized.profile.reified_partial_closure_calls, 1);
3106 assert_eq!(optimized.profile.removed_unreachable_continuations, 1);
3107 assert_eq!(optimized.profile.removed_dead_pure_statements, 1);
3108 assert_eq!(optimized.profile.direct_style_islands, 2);
3109 assert_eq!(optimized.profile.direct_style_continuations, 2);
3110 }
3111
#[test]
fn reifies_partial_closure_apply_after_inline() {
    // Scenario: the entry builds a closure that captures `v0` and forwards it,
    // together with an integer, to continuation 2, which applies it. Once
    // continuation 2 is inlined into the entry the apply targets a locally
    // known closure, so it should collapse into the body of `add` (`IntAdd`).
    use crate::cps_ir::CpsContinuation as IrContinuation;

    // add(v0, v1): v2 = v0 + v1; return v2
    let add_body = IrContinuation {
        id: CpsContinuationId(0),
        params: vec![CpsValueId(0), CpsValueId(1)],
        captures: Vec::new(),
        shot_kind: CpsShotKind::MultiShot,
        stmts: vec![CpsStmt::Primitive {
            dest: CpsValueId(2),
            op: typed_ir::PrimitiveOp::IntAdd,
            args: vec![CpsValueId(0), CpsValueId(1)],
        }],
        terminator: CpsTerminator::Return(CpsValueId(2)),
    };
    let add = CpsFunction {
        name: "add".to_string(),
        params: vec![CpsValueId(0), CpsValueId(1)],
        entry: CpsContinuationId(0),
        handlers: Vec::new(),
        continuations: vec![add_body],
    };

    // k0: v0 = 40; v1 = closure(k1); v2 = 2; continue k2(v1, v2)
    let entry_block = IrContinuation {
        id: CpsContinuationId(0),
        params: Vec::new(),
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![
            CpsStmt::Literal {
                dest: CpsValueId(0),
                literal: CpsLiteral::Int("40".to_string()),
            },
            CpsStmt::MakeClosure {
                dest: CpsValueId(1),
                entry: CpsContinuationId(1),
            },
            CpsStmt::Literal {
                dest: CpsValueId(2),
                literal: CpsLiteral::Int("2".to_string()),
            },
        ],
        terminator: CpsTerminator::Continue {
            target: CpsContinuationId(2),
            args: vec![CpsValueId(1), CpsValueId(2)],
        },
    };
    // k1(v4) capturing v0: v5 = add(v0, v4); return v5
    let closure_block = IrContinuation {
        id: CpsContinuationId(1),
        params: vec![CpsValueId(4)],
        captures: vec![CpsValueId(0)],
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![CpsStmt::DirectCall {
            dest: CpsValueId(5),
            target: "add".to_string(),
            args: vec![CpsValueId(0), CpsValueId(4)],
        }],
        terminator: CpsTerminator::Return(CpsValueId(5)),
    };
    // k2(v6, v7): v8 = v6(v7); return v8
    let apply_block = IrContinuation {
        id: CpsContinuationId(2),
        params: vec![CpsValueId(6), CpsValueId(7)],
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![CpsStmt::ApplyClosure {
            dest: CpsValueId(8),
            closure: CpsValueId(6),
            arg: CpsValueId(7),
        }],
        terminator: CpsTerminator::Return(CpsValueId(8)),
    };
    let root = CpsFunction {
        name: "root".to_string(),
        params: Vec::new(),
        entry: CpsContinuationId(0),
        handlers: Vec::new(),
        continuations: vec![entry_block, closure_block, apply_block],
    };

    let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
        functions: vec![add],
        roots: vec![root],
    }));

    let optimized = optimize_cps_repr_abi_module(&abi);
    let entry = &optimized.module.roots[0].continuations[0];
    let profile = &optimized.profile;

    assert_eq!(entry.stmts.len(), 3);
    let reified_add = CpsStmt::Primitive {
        dest: CpsValueId(8),
        op: typed_ir::PrimitiveOp::IntAdd,
        args: vec![CpsValueId(0), CpsValueId(2)],
    };
    assert_eq!(entry.stmts[2], reified_add);
    assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(8)));
    assert_eq!(profile.inlined_continuation_calls, 1);
    assert_eq!(profile.reified_partial_closure_calls, 1);
    assert_eq!(profile.removed_unreachable_continuations, 2);
    assert_eq!(profile.removed_dead_pure_statements, 1);
    assert_eq!(profile.direct_style_islands, 2);
    assert_eq!(profile.direct_style_continuations, 2);
}
3211
#[test]
fn reifies_uncaptured_closure_apply_through_continuation_parameter() {
    // A capture-free closure is passed into continuation 2 as a parameter and
    // applied there. The optimizer should see through the parameter, reify the
    // apply into the closure body, and leave no `ApplyClosure` anywhere.
    use crate::cps_ir::CpsContinuation as IrContinuation;

    // k0: v0 = closure(k1); v1 = 7; continue k2(v0, v1)
    let entry_block = IrContinuation {
        id: CpsContinuationId(0),
        params: Vec::new(),
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![
            CpsStmt::MakeClosure {
                dest: CpsValueId(0),
                entry: CpsContinuationId(1),
            },
            CpsStmt::Literal {
                dest: CpsValueId(1),
                literal: CpsLiteral::Int("7".to_string()),
            },
        ],
        terminator: CpsTerminator::Continue {
            target: CpsContinuationId(2),
            args: vec![CpsValueId(0), CpsValueId(1)],
        },
    };
    // k1(v2): v3 = int_to_string(v2); return v3
    let closure_block = IrContinuation {
        id: CpsContinuationId(1),
        params: vec![CpsValueId(2)],
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![CpsStmt::Primitive {
            dest: CpsValueId(3),
            op: typed_ir::PrimitiveOp::IntToString,
            args: vec![CpsValueId(2)],
        }],
        terminator: CpsTerminator::Return(CpsValueId(3)),
    };
    // k2(v4, v5): v6 = v4(v5); return v6
    let apply_block = IrContinuation {
        id: CpsContinuationId(2),
        params: vec![CpsValueId(4), CpsValueId(5)],
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![CpsStmt::ApplyClosure {
            dest: CpsValueId(6),
            closure: CpsValueId(4),
            arg: CpsValueId(5),
        }],
        terminator: CpsTerminator::Return(CpsValueId(6)),
    };

    let module = CpsModule {
        functions: Vec::new(),
        roots: vec![CpsFunction {
            name: "root".to_string(),
            params: Vec::new(),
            entry: CpsContinuationId(0),
            handlers: Vec::new(),
            continuations: vec![entry_block, closure_block, apply_block],
        }],
    };
    let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&module));

    let optimized = optimize_cps_repr_abi_module(&abi);
    let root = &optimized.module.roots[0];
    let entry = root
        .continuations
        .iter()
        .find(|c| c.id == CpsContinuationId(0))
        .unwrap();

    // No continuation may still contain an `ApplyClosure` statement.
    let apply_survives = root
        .continuations
        .iter()
        .flat_map(|c| c.stmts.iter())
        .any(|stmt| matches!(stmt, CpsStmt::ApplyClosure { .. }));
    assert!(!apply_survives);

    // The closure body now runs directly on the literal `v1`.
    let has_reified_to_string = entry.stmts.iter().any(|stmt| {
        matches!(
            stmt,
            CpsStmt::Primitive {
                op: typed_ir::PrimitiveOp::IntToString,
                args,
                ..
            } if *args == [CpsValueId(1)]
        )
    });
    assert!(has_reified_to_string);

    let profile = &optimized.profile;
    assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(6)));
    assert_eq!(profile.reified_partial_closure_calls, 0);
    assert_eq!(profile.reified_known_closure_parameter_calls, 1);
    assert_eq!(profile.inlined_continuation_calls, 1);
    assert_eq!(profile.removed_unreachable_continuations, 2);
    assert_eq!(profile.removed_dead_pure_statements, 1);
}
3301
#[test]
fn reifies_captured_closure_apply_when_captures_are_continuation_parameters() {
    // The closure captures `v0`, and `v0` also travels into continuation 2 as a
    // parameter alongside the closure itself. Because every capture is
    // available at the apply site, the apply can be reified into `add`'s body.
    use crate::cps_ir::CpsContinuation as IrContinuation;

    let add = CpsFunction {
        name: "add".to_string(),
        params: vec![CpsValueId(0), CpsValueId(1)],
        entry: CpsContinuationId(0),
        handlers: Vec::new(),
        continuations: vec![IrContinuation {
            id: CpsContinuationId(0),
            params: vec![CpsValueId(0), CpsValueId(1)],
            captures: Vec::new(),
            shot_kind: CpsShotKind::MultiShot,
            stmts: vec![CpsStmt::Primitive {
                dest: CpsValueId(2),
                op: typed_ir::PrimitiveOp::IntAdd,
                args: vec![CpsValueId(0), CpsValueId(1)],
            }],
            terminator: CpsTerminator::Return(CpsValueId(2)),
        }],
    };

    // k0: v0 = 40; v1 = closure(k1); v2 = 2; continue k2(v1, v0, v2)
    let entry_block = IrContinuation {
        id: CpsContinuationId(0),
        params: Vec::new(),
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![
            CpsStmt::Literal {
                dest: CpsValueId(0),
                literal: CpsLiteral::Int("40".to_string()),
            },
            CpsStmt::MakeClosure {
                dest: CpsValueId(1),
                entry: CpsContinuationId(1),
            },
            CpsStmt::Literal {
                dest: CpsValueId(2),
                literal: CpsLiteral::Int("2".to_string()),
            },
        ],
        terminator: CpsTerminator::Continue {
            target: CpsContinuationId(2),
            args: vec![CpsValueId(1), CpsValueId(0), CpsValueId(2)],
        },
    };
    // k1(v4) capturing v0: v5 = add(v0, v4); return v5
    let closure_block = IrContinuation {
        id: CpsContinuationId(1),
        params: vec![CpsValueId(4)],
        captures: vec![CpsValueId(0)],
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![CpsStmt::DirectCall {
            dest: CpsValueId(5),
            target: "add".to_string(),
            args: vec![CpsValueId(0), CpsValueId(4)],
        }],
        terminator: CpsTerminator::Return(CpsValueId(5)),
    };
    // k2(v6, v7, v8): v9 = v6(v8); return v9
    let apply_block = IrContinuation {
        id: CpsContinuationId(2),
        params: vec![CpsValueId(6), CpsValueId(7), CpsValueId(8)],
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![CpsStmt::ApplyClosure {
            dest: CpsValueId(9),
            closure: CpsValueId(6),
            arg: CpsValueId(8),
        }],
        terminator: CpsTerminator::Return(CpsValueId(9)),
    };

    let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
        functions: vec![add],
        roots: vec![CpsFunction {
            name: "root".to_string(),
            params: Vec::new(),
            entry: CpsContinuationId(0),
            handlers: Vec::new(),
            continuations: vec![entry_block, closure_block, apply_block],
        }],
    }));

    let optimized = optimize_cps_repr_abi_module(&abi);
    let root = &optimized.module.roots[0];
    let continuation = |id: CpsContinuationId| {
        root.continuations
            .iter()
            .find(|candidate| candidate.id == id)
            .unwrap()
    };
    let entry = continuation(CpsContinuationId(0));

    let apply_free = root
        .continuations
        .iter()
        .flat_map(|c| &c.stmts)
        .all(|stmt| !matches!(stmt, CpsStmt::ApplyClosure { .. }));
    assert!(apply_free);

    let adds_captured_and_argument = entry.stmts.iter().any(|stmt| {
        matches!(
            stmt,
            CpsStmt::Primitive {
                op: typed_ir::PrimitiveOp::IntAdd,
                args,
                ..
            } if *args == [CpsValueId(0), CpsValueId(2)]
        )
    });
    assert!(adds_captured_and_argument);

    let profile = &optimized.profile;
    assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(9)));
    assert_eq!(profile.reified_partial_closure_calls, 0);
    assert_eq!(profile.reified_known_closure_parameter_calls, 1);
    assert_eq!(profile.inlined_continuation_calls, 1);
    assert_eq!(profile.removed_unreachable_continuations, 2);
    assert_eq!(profile.removed_dead_pure_statements, 1);
}
3412
#[test]
fn reifies_local_effectful_apply_to_known_primitive_closure() {
    // An `EffectfulApply` terminator whose closure is built locally and whose
    // body is a single pure primitive should be rewritten into a plain
    // statement in the entry, so no `EffectfulApply` remains.
    use crate::cps_ir::CpsContinuation as IrContinuation;

    // k0: v0 = closure(k1); v1 = 7; effectful v0(v1) resuming at k2
    let entry_block = IrContinuation {
        id: CpsContinuationId(0),
        params: Vec::new(),
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![
            CpsStmt::MakeClosure {
                dest: CpsValueId(0),
                entry: CpsContinuationId(1),
            },
            CpsStmt::Literal {
                dest: CpsValueId(1),
                literal: CpsLiteral::Int("7".to_string()),
            },
        ],
        terminator: CpsTerminator::EffectfulApply {
            closure: CpsValueId(0),
            arg: CpsValueId(1),
            resume: CpsContinuationId(2),
        },
    };
    // k1(v2): v3 = int_to_string(v2); return v3
    let closure_block = IrContinuation {
        id: CpsContinuationId(1),
        params: vec![CpsValueId(2)],
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![CpsStmt::Primitive {
            dest: CpsValueId(3),
            op: typed_ir::PrimitiveOp::IntToString,
            args: vec![CpsValueId(2)],
        }],
        terminator: CpsTerminator::Return(CpsValueId(3)),
    };
    // k2(v4): return v4
    let resume_block = IrContinuation {
        id: CpsContinuationId(2),
        params: vec![CpsValueId(4)],
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: Vec::new(),
        terminator: CpsTerminator::Return(CpsValueId(4)),
    };

    let module = CpsModule {
        functions: Vec::new(),
        roots: vec![CpsFunction {
            name: "root".to_string(),
            params: Vec::new(),
            entry: CpsContinuationId(0),
            handlers: Vec::new(),
            continuations: vec![entry_block, closure_block, resume_block],
        }],
    };
    let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&module));

    let optimized = optimize_cps_repr_abi_module(&abi);
    let root = &optimized.module.roots[0];
    let entry = root
        .continuations
        .iter()
        .find(|c| c.id == CpsContinuationId(0))
        .unwrap();

    let effectful_apply_left = root
        .continuations
        .iter()
        .any(|c| matches!(c.terminator, CpsTerminator::EffectfulApply { .. }));
    assert!(!effectful_apply_left);

    let converts_literal = entry.stmts.iter().any(|stmt| {
        matches!(
            stmt,
            CpsStmt::Primitive {
                op: typed_ir::PrimitiveOp::IntToString,
                args,
                ..
            } if *args == [CpsValueId(1)]
        )
    });
    assert!(converts_literal);

    let profile = &optimized.profile;
    assert!(matches!(entry.terminator, CpsTerminator::Return(_)));
    assert_eq!(profile.reified_partial_closure_calls, 1);
    assert_eq!(profile.removed_unreachable_continuations, 2);
    assert_eq!(profile.removed_dead_pure_statements, 1);
}
3497
#[test]
fn removes_dead_pure_value_statements() {
    // Only `v0` reaches the `Return`. The unused literal `v1` and the unused
    // tuple `v2` are side-effect free, so both must be removed.
    let one = CpsStmt::Literal {
        dest: CpsValueId(0),
        literal: CpsLiteral::Int("1".to_string()),
    };
    let two = CpsStmt::Literal {
        dest: CpsValueId(1),
        literal: CpsLiteral::Int("2".to_string()),
    };
    let pair = CpsStmt::Tuple {
        dest: CpsValueId(2),
        items: vec![CpsValueId(0), CpsValueId(1)],
    };
    let entry_block = crate::cps_ir::CpsContinuation {
        id: CpsContinuationId(0),
        params: Vec::new(),
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![one, two, pair],
        terminator: CpsTerminator::Return(CpsValueId(0)),
    };
    let module = CpsModule {
        functions: Vec::new(),
        roots: vec![CpsFunction {
            name: "root".to_string(),
            params: Vec::new(),
            entry: CpsContinuationId(0),
            handlers: Vec::new(),
            continuations: vec![entry_block],
        }],
    };
    let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&module));

    let optimized = optimize_cps_repr_abi_module(&abi);
    let entry = &optimized.module.roots[0].continuations[0];

    let only_returned_literal = vec![CpsStmt::Literal {
        dest: CpsValueId(0),
        literal: CpsLiteral::Int("1".to_string()),
    }];
    assert_eq!(entry.stmts, only_returned_literal);
    assert_eq!(optimized.profile.removed_dead_pure_statements, 2);
}
3543
#[test]
fn removes_dead_total_primitives_and_structural_projections() {
    // Everything except the returned literal `v0` is dead: a total `IntAdd`, a
    // tuple, and a projection from that tuple. The projection is folded first
    // and the remaining dead pure statements are then dropped.
    let int_literal = |dest: CpsValueId, text: &str| CpsStmt::Literal {
        dest,
        literal: CpsLiteral::Int(text.to_string()),
    };

    let sum = CpsStmt::Primitive {
        dest: CpsValueId(2),
        op: typed_ir::PrimitiveOp::IntAdd,
        args: vec![CpsValueId(0), CpsValueId(1)],
    };
    let pair = CpsStmt::Tuple {
        dest: CpsValueId(3),
        items: vec![CpsValueId(0), CpsValueId(1)],
    };
    let second = CpsStmt::TupleGet {
        dest: CpsValueId(4),
        tuple: CpsValueId(3),
        index: 1,
    };
    let entry_block = crate::cps_ir::CpsContinuation {
        id: CpsContinuationId(0),
        params: Vec::new(),
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![
            int_literal(CpsValueId(0), "1"),
            int_literal(CpsValueId(1), "2"),
            sum,
            pair,
            second,
        ],
        terminator: CpsTerminator::Return(CpsValueId(0)),
    };
    let module = CpsModule {
        functions: Vec::new(),
        roots: vec![CpsFunction {
            name: "root".to_string(),
            params: Vec::new(),
            entry: CpsContinuationId(0),
            handlers: Vec::new(),
            continuations: vec![entry_block],
        }],
    };
    let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&module));

    let optimized = optimize_cps_repr_abi_module(&abi);
    let entry = &optimized.module.roots[0].continuations[0];

    assert_eq!(entry.stmts, vec![int_literal(CpsValueId(0), "1")]);
    assert_eq!(optimized.profile.folded_structural_projections, 1);
    assert_eq!(optimized.profile.removed_dead_pure_statements, 3);
}
3600
#[test]
fn folds_tuple_get_from_local_tuple() {
    // `TupleGet(index 1)` on a tuple built in the same continuation resolves
    // to the tuple's second item (`v1`). Afterwards the tuple and the first
    // literal are dead and disappear, and the return is rewritten to `v1`.
    let int_literal = |dest: CpsValueId, text: &str| CpsStmt::Literal {
        dest,
        literal: CpsLiteral::Int(text.to_string()),
    };

    let pair = CpsStmt::Tuple {
        dest: CpsValueId(2),
        items: vec![CpsValueId(0), CpsValueId(1)],
    };
    let second = CpsStmt::TupleGet {
        dest: CpsValueId(3),
        tuple: CpsValueId(2),
        index: 1,
    };
    let entry_block = crate::cps_ir::CpsContinuation {
        id: CpsContinuationId(0),
        params: Vec::new(),
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![
            int_literal(CpsValueId(0), "1"),
            int_literal(CpsValueId(1), "2"),
            pair,
            second,
        ],
        terminator: CpsTerminator::Return(CpsValueId(3)),
    };
    let module = CpsModule {
        functions: Vec::new(),
        roots: vec![CpsFunction {
            name: "root".to_string(),
            params: Vec::new(),
            entry: CpsContinuationId(0),
            handlers: Vec::new(),
            continuations: vec![entry_block],
        }],
    };
    let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&module));

    let optimized = optimize_cps_repr_abi_module(&abi);
    let entry = &optimized.module.roots[0].continuations[0];
    let profile = &optimized.profile;

    assert_eq!(entry.stmts, vec![int_literal(CpsValueId(1), "2")]);
    assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(1)));
    assert_eq!(profile.folded_structural_projections, 1);
    assert_eq!(profile.removed_dead_pure_statements, 2);
}
3653
#[test]
fn removes_unused_multi_use_continuation_parameters() {
    // Continuation 3 is a join point reached from both branch arms. Its first
    // parameter `v4` is never read, so it should be removed from the join and
    // from every `Continue` that targets it. The condition is `BoolNot(false)`
    // rather than a literal so that constant-branch folding does not collapse
    // the two arms first.
    use crate::cps_ir::CpsContinuation as IrContinuation;

    let int_literal = |dest: CpsValueId, text: &str| CpsStmt::Literal {
        dest,
        literal: CpsLiteral::Int(text.to_string()),
    };

    // k0: v0 = 1; v8 = false; v9 = !v8; branch v9 ? k1 : k2
    let entry_block = IrContinuation {
        id: CpsContinuationId(0),
        params: Vec::new(),
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![
            int_literal(CpsValueId(0), "1"),
            CpsStmt::Literal {
                dest: CpsValueId(8),
                literal: CpsLiteral::Bool(false),
            },
            CpsStmt::Primitive {
                dest: CpsValueId(9),
                op: typed_ir::PrimitiveOp::BoolNot,
                args: vec![CpsValueId(8)],
            },
        ],
        terminator: CpsTerminator::Branch {
            cond: CpsValueId(9),
            then_cont: CpsContinuationId(1),
            else_cont: CpsContinuationId(2),
        },
    };
    // k1: v2 = 2; continue k3(v0, v2)
    let then_arm = IrContinuation {
        id: CpsContinuationId(1),
        params: Vec::new(),
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![int_literal(CpsValueId(2), "2")],
        terminator: CpsTerminator::Continue {
            target: CpsContinuationId(3),
            args: vec![CpsValueId(0), CpsValueId(2)],
        },
    };
    // k2: v3 = 3; continue k3(v0, v3)
    let else_arm = IrContinuation {
        id: CpsContinuationId(2),
        params: Vec::new(),
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![int_literal(CpsValueId(3), "3")],
        terminator: CpsTerminator::Continue {
            target: CpsContinuationId(3),
            args: vec![CpsValueId(0), CpsValueId(3)],
        },
    };
    // k3(v4, v5): v6 = 0; v7 = v5 + v6; return v7   (v4 is unused)
    let join_block = IrContinuation {
        id: CpsContinuationId(3),
        params: vec![CpsValueId(4), CpsValueId(5)],
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![
            int_literal(CpsValueId(6), "0"),
            CpsStmt::Primitive {
                dest: CpsValueId(7),
                op: typed_ir::PrimitiveOp::IntAdd,
                args: vec![CpsValueId(5), CpsValueId(6)],
            },
        ],
        terminator: CpsTerminator::Return(CpsValueId(7)),
    };

    let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
        functions: Vec::new(),
        roots: vec![CpsFunction {
            name: "root".to_string(),
            params: Vec::new(),
            entry: CpsContinuationId(0),
            handlers: Vec::new(),
            continuations: vec![entry_block, then_arm, else_arm, join_block],
        }],
    }));

    let optimized = optimize_cps_repr_abi_module(&abi);
    let root = &optimized.module.roots[0];
    let continuation = |id: CpsContinuationId| {
        root.continuations
            .iter()
            .find(|candidate| candidate.id == id)
            .unwrap()
    };

    let surviving_join_params: Vec<_> = continuation(CpsContinuationId(3))
        .params
        .iter()
        .map(|param| param.value)
        .collect();
    assert_eq!(surviving_join_params, vec![CpsValueId(5)]);

    // Both predecessors must now pass exactly one argument to the join.
    for arm_id in [CpsContinuationId(1), CpsContinuationId(2)] {
        let arm = continuation(arm_id);
        assert!(matches!(
            &arm.terminator,
            CpsTerminator::Continue { args, .. } if args.len() == 1
        ));
    }
    assert_eq!(optimized.profile.removed_unused_continuation_params, 1);
    assert_eq!(optimized.profile.removed_dead_pure_statements, 1);
}
3769
#[test]
fn folds_constant_bool_branches_before_pruning() {
    // A `Branch` on the literal `true` is rewritten into a jump to the
    // then-arm. That arm is then inlined into the entry, and both former arms
    // are pruned, leaving only `v1 = 1; return v1`.
    use crate::cps_ir::CpsContinuation as IrContinuation;

    // k0: v0 = true; branch v0 ? k1 : k2
    let entry_block = IrContinuation {
        id: CpsContinuationId(0),
        params: Vec::new(),
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![CpsStmt::Literal {
            dest: CpsValueId(0),
            literal: CpsLiteral::Bool(true),
        }],
        terminator: CpsTerminator::Branch {
            cond: CpsValueId(0),
            then_cont: CpsContinuationId(1),
            else_cont: CpsContinuationId(2),
        },
    };
    // k1: v1 = 1; return v1
    let then_arm = IrContinuation {
        id: CpsContinuationId(1),
        params: Vec::new(),
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![CpsStmt::Literal {
            dest: CpsValueId(1),
            literal: CpsLiteral::Int("1".to_string()),
        }],
        terminator: CpsTerminator::Return(CpsValueId(1)),
    };
    // k2: v2 = 2; return v2
    let else_arm = IrContinuation {
        id: CpsContinuationId(2),
        params: Vec::new(),
        captures: Vec::new(),
        shot_kind: CpsShotKind::OneShot,
        stmts: vec![CpsStmt::Literal {
            dest: CpsValueId(2),
            literal: CpsLiteral::Int("2".to_string()),
        }],
        terminator: CpsTerminator::Return(CpsValueId(2)),
    };

    let module = CpsModule {
        functions: Vec::new(),
        roots: vec![CpsFunction {
            name: "root".to_string(),
            params: Vec::new(),
            entry: CpsContinuationId(0),
            handlers: Vec::new(),
            continuations: vec![entry_block, then_arm, else_arm],
        }],
    };
    let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&module));

    let optimized = optimize_cps_repr_abi_module(&abi);
    let entry = &optimized.module.roots[0].continuations[0];
    let profile = &optimized.profile;

    let expected_stmts = vec![CpsStmt::Literal {
        dest: CpsValueId(1),
        literal: CpsLiteral::Int("1".to_string()),
    }];
    assert_eq!(entry.stmts, expected_stmts);
    assert_eq!(entry.terminator, CpsTerminator::Return(CpsValueId(1)));
    assert_eq!(profile.folded_constant_branches, 1);
    assert_eq!(profile.inlined_continuation_calls, 1);
    assert_eq!(profile.removed_unreachable_continuations, 2);
    assert_eq!(profile.removed_dead_pure_statements, 1);
}
3837
#[test]
fn keeps_handler_arm_entries_when_pruning_unreachable_continuations() {
    // Continuation 1 is never the target of ordinary control flow from the
    // entry. It is reachable only through the handler table (handler 0, arm
    // `ask`), so unreachable-continuation pruning must keep it. Continuation 2
    // is reachable in neither way and must be removed.
    //
    // Uses the file-wide `typed_ir` alias for `yulang_typed_ir`, consistent
    // with the rest of the module.
    let effect = typed_ir::Path::from_name(typed_ir::Name("ask".to_string()));
    let abi = lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
        functions: Vec::new(),
        roots: vec![CpsFunction {
            name: "root".to_string(),
            params: Vec::new(),
            entry: CpsContinuationId(0),
            handlers: vec![crate::cps_ir::CpsHandler {
                id: crate::cps_ir::CpsHandlerId(0),
                arms: vec![crate::cps_ir::CpsHandlerArm {
                    effect,
                    entry: CpsContinuationId(1),
                }],
            }],
            continuations: vec![
                // k0 (entry): v0 = 1; return v0
                crate::cps_ir::CpsContinuation {
                    id: CpsContinuationId(0),
                    params: Vec::new(),
                    captures: Vec::new(),
                    shot_kind: CpsShotKind::OneShot,
                    stmts: vec![CpsStmt::Literal {
                        dest: CpsValueId(0),
                        literal: CpsLiteral::Int("1".to_string()),
                    }],
                    terminator: CpsTerminator::Return(CpsValueId(0)),
                },
                // k1 (handler arm entry): params (v1, v2); return v1
                crate::cps_ir::CpsContinuation {
                    id: CpsContinuationId(1),
                    params: vec![CpsValueId(1), CpsValueId(2)],
                    captures: Vec::new(),
                    shot_kind: CpsShotKind::MultiShot,
                    stmts: Vec::new(),
                    terminator: CpsTerminator::Return(CpsValueId(1)),
                },
                // k2 (unreachable): defines and returns its own value instead of
                // reading the entry's `v0`, which it neither receives nor captures.
                crate::cps_ir::CpsContinuation {
                    id: CpsContinuationId(2),
                    params: Vec::new(),
                    captures: Vec::new(),
                    shot_kind: CpsShotKind::OneShot,
                    stmts: vec![CpsStmt::Literal {
                        dest: CpsValueId(3),
                        literal: CpsLiteral::Int("2".to_string()),
                    }],
                    terminator: CpsTerminator::Return(CpsValueId(3)),
                },
            ],
        }],
    }));

    let optimized = optimize_cps_repr_abi_module(&abi);
    let ids = optimized.module.roots[0]
        .continuations
        .iter()
        .map(|continuation| continuation.id)
        .collect::<Vec<_>>();

    assert_eq!(ids, vec![CpsContinuationId(0), CpsContinuationId(1)]);
    assert_eq!(optimized.profile.removed_unreachable_continuations, 1);
}
3896
3897 fn sample_abi_module() -> CpsReprAbiModule {
3898 lower_cps_repr_abi_module(&lower_cps_repr_module(&CpsModule {
3899 functions: Vec::new(),
3900 roots: vec![CpsFunction {
3901 name: "root".to_string(),
3902 params: Vec::new(),
3903 entry: CpsContinuationId(0),
3904 handlers: Vec::new(),
3905 continuations: vec![crate::cps_ir::CpsContinuation {
3906 id: CpsContinuationId(0),
3907 params: Vec::new(),
3908 captures: Vec::new(),
3909 shot_kind: CpsShotKind::OneShot,
3910 stmts: vec![CpsStmt::Literal {
3911 dest: CpsValueId(0),
3912 literal: CpsLiteral::Int("42".to_string()),
3913 }],
3914 terminator: CpsTerminator::Return(CpsValueId(0)),
3915 }],
3916 }],
3917 }))
3918 }
3919}