// yulang_native/backend_selection.rs

1use std::collections::{HashMap, HashSet};
2use std::fmt;
3
4use yulang_runtime as runtime;
5use yulang_typed_ir as typed_ir;
6
7#[derive(Debug, Clone, PartialEq, Eq)]
8pub struct NativeBackendPlan {
9    pub roots: Vec<NativeRootBackend>,
10}
11
12impl NativeBackendPlan {
13    pub fn module_backend(&self) -> NativeBackendSelection {
14        self.roots
15            .iter()
16            .find_map(|root| match &root.selection {
17                NativeBackendSelection::CpsMainline { reason } => {
18                    Some(NativeBackendSelection::CpsMainline {
19                        reason: reason.clone(),
20                    })
21                }
22                NativeBackendSelection::ValueFastPath => None,
23                NativeBackendSelection::Unsupported { reason } => {
24                    Some(NativeBackendSelection::Unsupported {
25                        reason: reason.clone(),
26                    })
27                }
28            })
29            .unwrap_or(NativeBackendSelection::ValueFastPath)
30    }
31}
32
33#[derive(Debug, Clone, PartialEq, Eq)]
34pub struct NativeRootBackend {
35    pub root: NativeRootLabel,
36    pub selection: NativeBackendSelection,
37}
38
39#[derive(Debug, Clone, PartialEq, Eq)]
40pub enum NativeRootLabel {
41    Binding(typed_ir::Path),
42    Expr(usize),
43}
44
45impl fmt::Display for NativeRootLabel {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        match self {
48            NativeRootLabel::Binding(path) => write!(f, "binding {:?}", path),
49            NativeRootLabel::Expr(index) => write!(f, "root expr {index}"),
50        }
51    }
52}
53
54#[derive(Debug, Clone, PartialEq, Eq)]
55pub enum NativeBackendSelection {
56    ValueFastPath,
57    CpsMainline { reason: NativeBackendReason },
58    Unsupported { reason: NativeBackendReason },
59}
60
61#[derive(Debug, Clone, PartialEq, Eq)]
62pub struct NativeBackendReason {
63    pub root: NativeRootLabel,
64    pub kind: NativeBackendReasonKind,
65}
66
67impl fmt::Display for NativeBackendReason {
68    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69        write!(f, "{} contains {}", self.root, self.kind)
70    }
71}
72
/// The kind of construct that forces a root onto the CPS mainline backend.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum NativeBackendReasonKind {
    /// An effect operation expression.
    EffectOperation,
    /// An effect handler expression.
    Handler,
    /// A thunk expression.
    Thunk,
    /// A thunk boundary (`BindHere` / `AddId`).
    ThunkBoundary,
    /// A closure (lambda) value.
    ClosureValue,
    /// A binding whose match arms rebind the binding's own name.
    StructuralPatternBinding,
    /// An effect-id scope (`LocalPushId`).
    EffectIdScope,
    /// An effect-id read (`PeekId` / `FindId`).
    EffectIdRead,
}

impl fmt::Display for NativeBackendReasonKind {
    /// Writes the lowercase, human-readable name used in diagnostics.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match *self {
            Self::EffectOperation => "effect operation",
            Self::Handler => "effect handler",
            Self::Thunk => "thunk",
            Self::ThunkBoundary => "thunk boundary",
            Self::ClosureValue => "closure value",
            Self::StructuralPatternBinding => "structural pattern binding",
            Self::EffectIdScope => "effect id scope",
            Self::EffectIdRead => "effect id read",
        })
    }
}
100
101pub fn select_native_backends(module: &runtime::Module) -> NativeBackendPlan {
102    let bindings = module
103        .bindings
104        .iter()
105        .map(|binding| (binding.name.clone(), &binding.body))
106        .collect::<HashMap<_, _>>();
107    let roots = module
108        .roots
109        .iter()
110        .map(|root| {
111            let label = root_label(root);
112            let reason = match root {
113                runtime::Root::Binding(path) => bindings
114                    .get(path)
115                    .and_then(|body| first_cps_reason(body, &bindings)),
116                runtime::Root::Expr(index) => module
117                    .root_exprs
118                    .get(*index)
119                    .and_then(|expr| first_cps_reason(expr, &bindings)),
120            };
121            NativeRootBackend {
122                root: label.clone(),
123                selection: reason
124                    .map(|kind| NativeBackendSelection::CpsMainline {
125                        reason: NativeBackendReason { root: label, kind },
126                    })
127                    .unwrap_or(NativeBackendSelection::ValueFastPath),
128            }
129        })
130        .collect();
131    NativeBackendPlan { roots }
132}
133
134fn root_label(root: &runtime::Root) -> NativeRootLabel {
135    match root {
136        runtime::Root::Binding(path) => NativeRootLabel::Binding(path.clone()),
137        runtime::Root::Expr(index) => NativeRootLabel::Expr(*index),
138    }
139}
140
141fn first_cps_reason(
142    root: &runtime::Expr,
143    bindings: &HashMap<typed_ir::Path, &runtime::Expr>,
144) -> Option<NativeBackendReasonKind> {
145    let mut seen_bindings = HashSet::new();
146    first_cps_reason_expr(root, bindings, &mut seen_bindings)
147}
148
149fn first_cps_reason_expr(
150    expr: &runtime::Expr,
151    bindings: &HashMap<typed_ir::Path, &runtime::Expr>,
152    seen_bindings: &mut HashSet<typed_ir::Path>,
153) -> Option<NativeBackendReasonKind> {
154    match &expr.kind {
155        runtime::ExprKind::EffectOp(_) => Some(NativeBackendReasonKind::EffectOperation),
156        runtime::ExprKind::Handle { .. } => Some(NativeBackendReasonKind::Handler),
157        runtime::ExprKind::Thunk { .. } => Some(NativeBackendReasonKind::Thunk),
158        runtime::ExprKind::BindHere { .. } | runtime::ExprKind::AddId { .. } => {
159            Some(NativeBackendReasonKind::ThunkBoundary)
160        }
161        runtime::ExprKind::LocalPushId { .. } => Some(NativeBackendReasonKind::EffectIdScope),
162        runtime::ExprKind::PeekId | runtime::ExprKind::FindId { .. } => {
163            Some(NativeBackendReasonKind::EffectIdRead)
164        }
165        runtime::ExprKind::Var(path) => {
166            if seen_bindings.insert(path.clone()) {
167                let reason = bindings.get(path).and_then(|body| {
168                    if binding_body_shadows_path(path, body) {
169                        Some(NativeBackendReasonKind::StructuralPatternBinding)
170                    } else {
171                        first_cps_reason_expr(body, bindings, seen_bindings)
172                    }
173                });
174                seen_bindings.remove(path);
175                reason
176            } else {
177                None
178            }
179        }
180        runtime::ExprKind::PrimitiveOp(_) | runtime::ExprKind::Lit(_) => None,
181        runtime::ExprKind::Lambda { .. } => Some(NativeBackendReasonKind::ClosureValue),
182        runtime::ExprKind::Apply { callee, arg, .. } => {
183            first_cps_reason_expr(callee, bindings, seen_bindings)
184                .or_else(|| first_cps_reason_expr(arg, bindings, seen_bindings))
185        }
186        runtime::ExprKind::If {
187            cond,
188            then_branch,
189            else_branch,
190            ..
191        } => first_cps_reason_expr(cond, bindings, seen_bindings)
192            .or_else(|| first_cps_reason_expr(then_branch, bindings, seen_bindings))
193            .or_else(|| first_cps_reason_expr(else_branch, bindings, seen_bindings)),
194        runtime::ExprKind::Tuple(items) => items
195            .iter()
196            .find_map(|item| first_cps_reason_expr(item, bindings, seen_bindings)),
197        runtime::ExprKind::Record { fields, spread } => fields
198            .iter()
199            .find_map(|field| first_cps_reason_expr(&field.value, bindings, seen_bindings))
200            .or_else(|| match spread {
201                Some(runtime::RecordSpreadExpr::Head(expr))
202                | Some(runtime::RecordSpreadExpr::Tail(expr)) => {
203                    first_cps_reason_expr(expr, bindings, seen_bindings)
204                }
205                None => None,
206            }),
207        runtime::ExprKind::Variant { value, .. } => value
208            .as_deref()
209            .and_then(|value| first_cps_reason_expr(value, bindings, seen_bindings)),
210        runtime::ExprKind::Select { base, .. } => {
211            first_cps_reason_expr(base, bindings, seen_bindings)
212        }
213        runtime::ExprKind::Match {
214            scrutinee, arms, ..
215        } => first_cps_reason_expr(scrutinee, bindings, seen_bindings).or_else(|| {
216            arms.iter().find_map(|arm| {
217                arm.guard
218                    .as_ref()
219                    .and_then(|guard| first_cps_reason_expr(guard, bindings, seen_bindings))
220                    .or_else(|| first_cps_reason_expr(&arm.body, bindings, seen_bindings))
221            })
222        }),
223        runtime::ExprKind::Block { stmts, tail } => stmts
224            .iter()
225            .find_map(|stmt| match stmt {
226                runtime::Stmt::Let { value, .. } | runtime::Stmt::Expr(value) => {
227                    first_cps_reason_expr(value, bindings, seen_bindings)
228                }
229                runtime::Stmt::Module { body, .. } => {
230                    first_cps_reason_expr(body, bindings, seen_bindings)
231                }
232            })
233            .or_else(|| {
234                tail.as_deref()
235                    .and_then(|tail| first_cps_reason_expr(tail, bindings, seen_bindings))
236            }),
237        runtime::ExprKind::Coerce { expr, .. } | runtime::ExprKind::Pack { expr, .. } => {
238            first_cps_reason_expr(expr, bindings, seen_bindings)
239        }
240    }
241}
242
243fn binding_body_shadows_path(path: &typed_ir::Path, body: &runtime::Expr) -> bool {
244    match &body.kind {
245        runtime::ExprKind::Match { arms, .. } => arms
246            .iter()
247            .any(|arm| pattern_binds_path(&arm.pattern, path)),
248        runtime::ExprKind::Coerce { expr, .. } | runtime::ExprKind::Pack { expr, .. } => {
249            binding_body_shadows_path(path, expr)
250        }
251        _ => false,
252    }
253}
254
255fn pattern_binds_path(pattern: &runtime::Pattern, path: &typed_ir::Path) -> bool {
256    match pattern {
257        runtime::Pattern::Bind { name, .. } => typed_ir::Path::from_name(name.clone()) == *path,
258        runtime::Pattern::Tuple { items, .. } => {
259            items.iter().any(|item| pattern_binds_path(item, path))
260        }
261        runtime::Pattern::List {
262            prefix,
263            spread,
264            suffix,
265            ..
266        } => {
267            prefix.iter().any(|item| pattern_binds_path(item, path))
268                || spread
269                    .as_deref()
270                    .is_some_and(|spread| pattern_binds_path(spread, path))
271                || suffix.iter().any(|item| pattern_binds_path(item, path))
272        }
273        runtime::Pattern::Record { fields, spread, .. } => {
274            fields
275                .iter()
276                .any(|field| pattern_binds_path(&field.pattern, path))
277                || spread.as_ref().is_some_and(|spread| match spread {
278                    runtime::RecordSpreadPattern::Head(pattern)
279                    | runtime::RecordSpreadPattern::Tail(pattern) => {
280                        pattern_binds_path(pattern, path)
281                    }
282                })
283        }
284        runtime::Pattern::Variant {
285            value: Some(value), ..
286        }
287        | runtime::Pattern::As { pattern: value, .. } => pattern_binds_path(value, path),
288        runtime::Pattern::Or { left, right, .. } => {
289            pattern_binds_path(left, path) || pattern_binds_path(right, path)
290        }
291        runtime::Pattern::Wildcard { .. }
292        | runtime::Pattern::Lit { .. }
293        | runtime::Pattern::Variant { value: None, .. } => false,
294    }
295}
296
#[cfg(test)]
mod tests {
    use super::*;

    /// Module with no bindings and a single root expression `Root::Expr(0)`.
    fn module_with_root(expr: runtime::Expr) -> runtime::Module {
        runtime::Module {
            path: typed_ir::Path::default(),
            bindings: Vec::new(),
            root_exprs: vec![expr],
            roots: vec![runtime::Root::Expr(0)],
            role_impls: Vec::new(),
        }
    }

    /// Module with one top-level binding and a single root expression.
    fn module_with_binding(
        name: &str,
        body: runtime::Expr,
        root: runtime::Expr,
    ) -> runtime::Module {
        module_with_bindings(vec![(name, body)], root)
    }

    /// Module with the given top-level bindings (in order) and a single root
    /// expression `Root::Expr(0)`.
    fn module_with_bindings(
        bindings: Vec<(&str, runtime::Expr)>,
        root: runtime::Expr,
    ) -> runtime::Module {
        runtime::Module {
            path: typed_ir::Path::default(),
            bindings: bindings
                .into_iter()
                .map(|(name, body)| runtime::Binding {
                    name: path(name),
                    type_params: Vec::new(),
                    scheme: typed_ir::Scheme {
                        requirements: Vec::new(),
                        body: typed_ir::Type::Unknown,
                    },
                    body,
                })
                .collect(),
            root_exprs: vec![root],
            roots: vec![runtime::Root::Expr(0)],
            role_impls: Vec::new(),
        }
    }

    fn path(name: &str) -> typed_ir::Path {
        typed_ir::Path::from_name(typed_ir::Name(name.to_string()))
    }

    fn lit_int(value: &str) -> runtime::Expr {
        runtime::Expr::typed(
            runtime::ExprKind::Lit(typed_ir::Lit::Int(value.to_string())),
            runtime::Type::unknown(),
        )
    }

    fn var(name: &str) -> runtime::Expr {
        runtime::Expr::typed(runtime::ExprKind::Var(path(name)), runtime::Type::unknown())
    }

    fn effect_op(name: &str) -> runtime::Expr {
        runtime::Expr::typed(runtime::ExprKind::EffectOp(path(name)), runtime::Type::unknown())
    }

    fn primitive(op: typed_ir::PrimitiveOp) -> runtime::Expr {
        runtime::Expr::typed(runtime::ExprKind::PrimitiveOp(op), runtime::Type::unknown())
    }

    fn apply(callee: runtime::Expr, arg: runtime::Expr) -> runtime::Expr {
        runtime::Expr::typed(
            runtime::ExprKind::Apply {
                callee: Box::new(callee),
                arg: Box::new(arg),
                evidence: None,
                instantiation: None,
            },
            runtime::Type::unknown(),
        )
    }

    /// A handler with a pure body and no arms.
    fn handler_expr() -> runtime::Expr {
        runtime::Expr::typed(
            runtime::ExprKind::Handle {
                body: Box::new(lit_int("1")),
                arms: Vec::new(),
                evidence: runtime::JoinEvidence {
                    result: typed_ir::Type::Unknown,
                },
                handler: runtime::HandleEffect {
                    consumes: Vec::new(),
                    residual_before: None,
                    residual_after: None,
                },
            },
            runtime::Type::unknown(),
        )
    }

    fn list_pattern(items: Vec<runtime::Pattern>) -> runtime::Pattern {
        runtime::Pattern::List {
            prefix: items,
            spread: None,
            suffix: Vec::new(),
            ty: runtime::Type::unknown(),
        }
    }

    fn bind_pattern(name: &str) -> runtime::Pattern {
        runtime::Pattern::Bind {
            name: typed_ir::Name(name.to_string()),
            ty: runtime::Type::unknown(),
        }
    }

    fn identity_lambda() -> runtime::Expr {
        runtime::Expr::typed(
            runtime::ExprKind::Lambda {
                param: typed_ir::Name("x".to_string()),
                param_effect_annotation: None,
                param_function_allowed_effects: None,
                body: Box::new(var("x")),
            },
            runtime::Type::unknown(),
        )
    }

    /// Expected `CpsMainline` selection for root expression 0.
    fn cps_at_expr0(kind: NativeBackendReasonKind) -> NativeBackendSelection {
        NativeBackendSelection::CpsMainline {
            reason: NativeBackendReason {
                root: NativeRootLabel::Expr(0),
                kind,
            },
        }
    }

    #[test]
    fn selects_value_fast_path_for_pure_root() {
        let plan = select_native_backends(&module_with_root(lit_int("42")));

        assert_eq!(plan.module_backend(), NativeBackendSelection::ValueFastPath);
    }

    #[test]
    fn selects_cps_mainline_for_effect_operation_root() {
        let plan = select_native_backends(&module_with_root(effect_op("yield")));

        assert_eq!(
            plan.module_backend(),
            cps_at_expr0(NativeBackendReasonKind::EffectOperation)
        );
    }

    #[test]
    fn follows_reachable_binding_when_selecting_backend() {
        let plan = select_native_backends(&module_with_binding("run", handler_expr(), var("run")));

        assert_eq!(
            plan.module_backend(),
            cps_at_expr0(NativeBackendReasonKind::Handler)
        );
    }

    #[test]
    fn selects_cps_mainline_for_closure_value_root() {
        let plan = select_native_backends(&module_with_root(identity_lambda()));

        assert_eq!(
            plan.module_backend(),
            cps_at_expr0(NativeBackendReasonKind::ClosureValue)
        );
    }

    #[test]
    fn selects_cps_mainline_for_closure_value_inside_record() {
        let expr = runtime::Expr::typed(
            runtime::ExprKind::Record {
                fields: vec![runtime::RecordExprField {
                    name: typed_ir::Name("f".to_string()),
                    value: identity_lambda(),
                }],
                spread: None,
            },
            runtime::Type::unknown(),
        );
        let plan = select_native_backends(&module_with_root(expr));

        assert_eq!(
            plan.module_backend(),
            cps_at_expr0(NativeBackendReasonKind::ClosureValue)
        );
    }

    #[test]
    fn selects_cps_mainline_for_closure_value_inside_list_primitive() {
        let expr = apply(
            primitive(typed_ir::PrimitiveOp::ListSingleton),
            identity_lambda(),
        );
        let plan = select_native_backends(&module_with_root(expr));

        assert_eq!(
            plan.module_backend(),
            cps_at_expr0(NativeBackendReasonKind::ClosureValue)
        );
    }

    #[test]
    fn selects_cps_mainline_for_self_shadowing_structural_binding() {
        let body = runtime::Expr::typed(
            runtime::ExprKind::Match {
                scrutinee: Box::new(lit_int("0")),
                arms: vec![runtime::MatchArm {
                    pattern: list_pattern(vec![bind_pattern("x"), bind_pattern("y")]),
                    guard: None,
                    body: var("x"),
                }],
                evidence: runtime::JoinEvidence {
                    result: typed_ir::Type::Unknown,
                },
            },
            runtime::Type::unknown(),
        );
        let plan = select_native_backends(&module_with_binding("x", body, var("x")));

        assert_eq!(
            plan.module_backend(),
            cps_at_expr0(NativeBackendReasonKind::StructuralPatternBinding)
        );
    }

    #[test]
    fn terminates_on_mutually_recursive_pure_bindings() {
        let module = module_with_bindings(vec![("a", var("b")), ("b", var("a"))], var("a"));
        let plan = select_native_backends(&module);

        assert_eq!(plan.module_backend(), NativeBackendSelection::ValueFastPath);
    }

    #[test]
    fn selects_value_fast_path_for_shared_pure_binding() {
        // `pair` and `leaf` are each referenced twice; both must be seen as pure.
        let module = module_with_bindings(
            vec![
                ("leaf", lit_int("1")),
                ("pair", apply(var("leaf"), var("leaf"))),
            ],
            apply(var("pair"), var("pair")),
        );
        let plan = select_native_backends(&module);

        assert_eq!(plan.module_backend(), NativeBackendSelection::ValueFastPath);
    }

    #[test]
    fn labels_binding_roots_with_their_path() {
        let mut module = module_with_binding("run", handler_expr(), lit_int("0"));
        module.roots = vec![runtime::Root::Binding(path("run"))];
        let plan = select_native_backends(&module);

        assert_eq!(
            plan.module_backend(),
            NativeBackendSelection::CpsMainline {
                reason: NativeBackendReason {
                    root: NativeRootLabel::Binding(path("run")),
                    kind: NativeBackendReasonKind::Handler,
                },
            }
        );
    }

    #[test]
    fn first_non_fast_path_root_decides_module_backend() {
        let mut module = module_with_root(lit_int("1"));
        module.root_exprs.push(effect_op("yield"));
        module.roots.push(runtime::Root::Expr(1));
        let plan = select_native_backends(&module);

        assert_eq!(plan.roots.len(), 2);
        assert_eq!(plan.roots[0].selection, NativeBackendSelection::ValueFastPath);
        assert_eq!(
            plan.module_backend(),
            NativeBackendSelection::CpsMainline {
                reason: NativeBackendReason {
                    root: NativeRootLabel::Expr(1),
                    kind: NativeBackendReasonKind::EffectOperation,
                },
            }
        );
    }
}