1use std::collections::{HashMap, HashSet};
2use std::fmt;
3
4use yulang_runtime as runtime;
5use yulang_typed_ir as typed_ir;
6
7use crate::control_ir::{
8 BlockId, NativeBlock, NativeFunction, NativeLiteral, NativeModule, NativeRecordField,
9 NativeStmt, NativeTerminator, ValueId,
10};
11
/// Result type returned by every native lowering entry point in this module.
pub type NativeLowerResult<T> = Result<T, NativeLowerError>;

/// Reasons the native backend refuses to lower a runtime module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum NativeLowerError {
    /// A module root that cannot be lowered, e.g. a `Root::Binding` whose binding is not in
    /// the function table or has no zero-argument direct target.
    UnsupportedRoot {
        root: runtime::Root,
    },
    /// A `Root::Expr` index with no matching entry in the module's `root_exprs`.
    MissingRootExpr {
        index: usize,
    },
    /// An expression form the lowering rejects; `kind` is a short human-readable label.
    UnsupportedExpr {
        kind: &'static str,
    },
    /// A pattern form the lowering rejects; `kind` is a short human-readable label.
    UnsupportedPattern {
        kind: &'static str,
    },
    /// A top-level binding rejected for `reason` (e.g. residual type parameters).
    UnsupportedBinding {
        path: typed_ir::Path,
        reason: &'static str,
    },
    /// A primitive applied to `actual` arguments while its arity is `expected`.
    PrimitiveArityMismatch {
        op: typed_ir::PrimitiveOp,
        expected: usize,
        actual: usize,
    },
    /// A call to `target` with a mismatched argument count.
    // NOTE(review): not constructed in the code visible here; presumably produced by a
    // call-lowering helper elsewhere in the file — confirm.
    CallArityMismatch {
        target: String,
        expected: usize,
        actual: usize,
    },
}
43
44impl fmt::Display for NativeLowerError {
45 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
46 match self {
47 NativeLowerError::UnsupportedRoot { root } => {
48 write!(f, "native backend does not support root {root:?} yet")
49 }
50 NativeLowerError::MissingRootExpr { index } => {
51 write!(f, "runtime module has no root expression at index {index}")
52 }
53 NativeLowerError::UnsupportedExpr { kind } => {
54 write!(f, "native backend does not support {kind} expressions yet")
55 }
56 NativeLowerError::UnsupportedPattern { kind } => {
57 write!(f, "native backend does not support {kind} patterns yet")
58 }
59 NativeLowerError::UnsupportedBinding { path, reason } => {
60 write!(
61 f,
62 "native backend does not support binding {} yet: {reason}",
63 path_name(path)
64 )
65 }
66 NativeLowerError::PrimitiveArityMismatch {
67 op,
68 expected,
69 actual,
70 } => write!(
71 f,
72 "native backend expected {expected} arguments for primitive {op:?}, got {actual}"
73 ),
74 NativeLowerError::CallArityMismatch {
75 target,
76 expected,
77 actual,
78 } => write!(
79 f,
80 "native backend expected {expected} arguments for call to {target}, got {actual}"
81 ),
82 }
83 }
84}
85
86impl std::error::Error for NativeLowerError {}
87
88pub fn lower_module(module: &runtime::Module) -> NativeLowerResult<NativeModule> {
89 let function_table = module
90 .bindings
91 .iter()
92 .map(|binding| {
93 let info = binding_function_info(binding);
94 (binding.name.clone(), info)
95 })
96 .collect::<HashMap<_, _>>();
97 let mut functions = Vec::new();
98 for binding in &module.bindings {
99 let lowered = lower_binding(binding, &function_table)?;
100 functions.push(lowered.function);
101 functions.extend(lowered.generated);
102 }
103
104 let mut roots = Vec::new();
105 for root in &module.roots {
106 match root {
107 runtime::Root::Expr(index) => {
108 let expr = module
109 .root_exprs
110 .get(*index)
111 .ok_or(NativeLowerError::MissingRootExpr { index: *index })?;
112 let lowered =
113 FunctionLowerer::new(format!("root_{index}"), &function_table, Vec::new())
114 .lower_root(expr)?;
115 roots.push(lowered.function);
116 functions.extend(lowered.generated);
117 }
118 runtime::Root::Binding(path) => {
119 let Some(info) = function_table.get(path) else {
120 return Err(NativeLowerError::UnsupportedRoot { root: root.clone() });
121 };
122 let Some(target) = info.direct_targets.get(&0) else {
123 return Err(NativeLowerError::UnsupportedRoot { root: root.clone() });
124 };
125 roots.push(root_binding_function(roots.len(), target.clone()));
126 }
127 }
128 }
129 Ok(NativeModule { functions, roots })
130}
131
132fn lower_binding(
133 binding: &runtime::Binding,
134 functions: &HashMap<typed_ir::Path, FunctionInfo>,
135) -> NativeLowerResult<LoweredFunction> {
136 if !binding.type_params.is_empty() {
137 return Err(NativeLowerError::UnsupportedBinding {
138 path: binding.name.clone(),
139 reason: "residual type parameters",
140 });
141 }
142 if let runtime::ExprKind::PrimitiveOp(op) = binding.body.kind {
143 let mut lowered = lower_primitive_binding(&binding.name, op)?;
144 if let Some(info) = functions.get(&binding.name) {
145 lowered
146 .generated
147 .extend(partial_application_functions(info));
148 }
149 return Ok(lowered);
150 }
151 if expr_pattern_binds_path(&binding.body, &binding.name) {
152 return Err(NativeLowerError::UnsupportedBinding {
153 path: binding.name.clone(),
154 reason: "top-level structural pattern binding",
155 });
156 }
157 let (params, body) = collect_lambda_params(&binding.body);
158 let mut lowered = FunctionLowerer::new(path_name(&binding.name), functions, params.clone())
159 .lower_root(body)?;
160 let (callable_params, callable_body) = collect_callable_params(&binding.body);
161 if callable_params.len() > params.len() {
162 let info = functions
163 .get(&binding.name)
164 .expect("binding function info was created before lowering");
165 let target = info
166 .direct_targets
167 .get(&callable_params.len())
168 .expect("callable arity target was created before lowering")
169 .clone();
170 let direct =
171 FunctionLowerer::new(target, functions, callable_params).lower_root(&callable_body)?;
172 lowered.generated.push(direct.function);
173 lowered.generated.extend(direct.generated);
174 }
175 if let Some(info) = functions.get(&binding.name) {
176 lowered
177 .generated
178 .extend(partial_application_functions(info));
179 }
180 Ok(lowered)
181}
182
183fn root_binding_function(index: usize, target: String) -> NativeFunction {
184 let dest = ValueId(0);
185 NativeFunction {
186 name: format!("root_binding_{index}"),
187 captures: Vec::new(),
188 params: Vec::new(),
189 blocks: vec![NativeBlock {
190 id: BlockId(0),
191 params: Vec::new(),
192 stmts: vec![NativeStmt::DirectCall {
193 dest,
194 target,
195 args: Vec::new(),
196 }],
197 terminator: NativeTerminator::Return(dest),
198 }],
199 }
200}
201
202fn binding_function_info(binding: &runtime::Binding) -> FunctionInfo {
203 let name = path_name(&binding.name);
204 if let runtime::ExprKind::PrimitiveOp(op) = binding.body.kind {
205 let arity = primitive_arity(op);
206 return FunctionInfo {
207 direct_targets: HashMap::from([(arity, name.clone())]),
208 partial_targets: partial_target_names(&name, arity),
209 name,
210 arity,
211 };
212 }
213 let arity = collect_lambda_params(&binding.body).0.len();
214 let callable_arity = collect_callable_params(&binding.body).0.len();
215 let mut direct_targets = HashMap::from([(arity, name.clone())]);
216 if callable_arity > arity {
217 direct_targets.insert(callable_arity, format!("{name}#direct{callable_arity}"));
218 }
219 FunctionInfo {
220 direct_targets,
221 partial_targets: partial_target_names(&name, arity),
222 name,
223 arity,
224 }
225}
226
/// Names the partial-application helpers of `name`: one `{name}#partial{k}` entry for each
/// prefix length `k` in `0..arity`. The map is empty when `arity` is 0.
fn partial_target_names(name: &str, arity: usize) -> HashMap<usize, String> {
    let mut names = HashMap::with_capacity(arity);
    for prefix_len in 0..arity {
        names.insert(prefix_len, format!("{name}#partial{prefix_len}"));
    }
    names
}
232
233fn partial_application_functions(info: &FunctionInfo) -> Vec<NativeFunction> {
234 (0..info.arity)
235 .filter_map(|prefix_len| partial_application_function(info, prefix_len))
236 .collect()
237}
238
239fn partial_application_function(info: &FunctionInfo, prefix_len: usize) -> Option<NativeFunction> {
240 let name = info.partial_targets.get(&prefix_len)?.clone();
241 let captures = (0..prefix_len).map(ValueId).collect::<Vec<_>>();
242 let params = (0..=prefix_len).map(ValueId).collect::<Vec<_>>();
243 let dest = ValueId(prefix_len + 1);
244 let prefix_args = (0..=prefix_len).map(ValueId).collect::<Vec<_>>();
245 let stmt = if prefix_len + 1 == info.arity {
246 NativeStmt::DirectCall {
247 dest,
248 target: info.direct_targets.get(&info.arity)?.clone(),
249 args: prefix_args,
250 }
251 } else {
252 NativeStmt::MakeClosure {
253 dest,
254 target: info.partial_targets.get(&(prefix_len + 1))?.clone(),
255 captures: prefix_args,
256 }
257 };
258 Some(NativeFunction {
259 name,
260 captures,
261 params: params.clone(),
262 blocks: vec![NativeBlock {
263 id: BlockId(0),
264 params,
265 stmts: vec![stmt],
266 terminator: NativeTerminator::Return(dest),
267 }],
268 })
269}
270
271fn lower_primitive_binding(
272 path: &typed_ir::Path,
273 op: typed_ir::PrimitiveOp,
274) -> NativeLowerResult<LoweredFunction> {
275 let arity = primitive_arity(op);
276 let params = (0..arity).map(ValueId).collect::<Vec<_>>();
277 let dest = ValueId(arity);
278 Ok(LoweredFunction {
279 function: NativeFunction {
280 name: path_name(path),
281 captures: Vec::new(),
282 params: params.clone(),
283 blocks: vec![NativeBlock {
284 id: BlockId(0),
285 params: params.clone(),
286 stmts: vec![NativeStmt::Primitive {
287 dest,
288 op,
289 args: params,
290 }],
291 terminator: NativeTerminator::Return(dest),
292 }],
293 },
294 generated: Vec::new(),
295 })
296}
297
/// Output of lowering one function: the function itself plus every helper function
/// (lambda bodies, wider direct entry points, partial-application helpers) produced while
/// lowering it.
struct LoweredFunction {
    function: NativeFunction,
    generated: Vec<NativeFunction>,
}
302
/// Call-target table entry for one top-level binding (built by `binding_function_info`).
#[derive(Debug, Clone, PartialEq, Eq)]
struct FunctionInfo {
    /// Argument count -> native function to `DirectCall` with exactly that many arguments.
    direct_targets: HashMap<usize, String>,
    /// Prefix length `k` (for `k < arity`) -> closure helper that has received `k` arguments.
    partial_targets: HashMap<usize, String>,
    /// Native name of the binding's main function (from `path_name`).
    name: String,
    /// Number of parameters of the main function.
    arity: usize,
}
310
/// Incremental builder that lowers one runtime expression tree into a single
/// `NativeFunction` made of basic blocks.
struct FunctionLowerer<'a> {
    /// Name of the function being built; also the prefix of generated `#lambda{n}` names.
    name: String,
    /// Table of top-level bindings, used to resolve direct calls and global references.
    functions: &'a HashMap<typed_ir::Path, FunctionInfo>,
    /// Index of the next unused `ValueId`.
    next_value: usize,
    /// Index of the next unused `BlockId` (block 0 is the entry block).
    next_block: usize,
    /// Completed blocks; presumably appended by `finish_current` (defined outside this view).
    blocks: Vec<NativeBlock>,
    /// Block currently receiving statements.
    current: BlockBuilder,
    /// Values bound to in-scope local names (parameters, captures, pattern bindings).
    locals: HashMap<typed_ir::Path, ValueId>,
    /// Parameter values in order (for closures: captures first, then the argument).
    params: Vec<ValueId>,
    /// Parameter values that are closure captures; empty for non-closure functions.
    captures: Vec<ValueId>,
    /// Helper functions (lambda bodies and their own helpers) produced while lowering.
    generated: Vec<NativeFunction>,
    /// Counter giving each lambda lowered in this function a unique suffix.
    next_lambda: usize,
}
324
325impl<'a> FunctionLowerer<'a> {
326 fn new(
327 name: String,
328 functions: &'a HashMap<typed_ir::Path, FunctionInfo>,
329 params: Vec<typed_ir::Name>,
330 ) -> Self {
331 let mut next_value = 0;
332 let mut param_values = Vec::with_capacity(params.len());
333 let mut locals = HashMap::new();
334 for param in params {
335 let value = ValueId(next_value);
336 next_value += 1;
337 locals.insert(typed_ir::Path::from_name(param), value);
338 param_values.push(value);
339 }
340 Self {
341 name,
342 functions,
343 next_value,
344 next_block: 1,
345 blocks: Vec::new(),
346 current: BlockBuilder::new(BlockId(0), param_values.clone()),
347 locals,
348 params: param_values,
349 captures: Vec::new(),
350 generated: Vec::new(),
351 next_lambda: 0,
352 }
353 }
354
355 fn new_closure(
356 name: String,
357 functions: &'a HashMap<typed_ir::Path, FunctionInfo>,
358 captures: Vec<typed_ir::Path>,
359 param: typed_ir::Name,
360 ) -> Self {
361 let mut next_value = 0;
362 let mut params = Vec::with_capacity(captures.len() + 1);
363 let mut locals = HashMap::new();
364 for capture in captures {
365 let value = ValueId(next_value);
366 next_value += 1;
367 locals.insert(capture, value);
368 params.push(value);
369 }
370 let captures = params.clone();
371 let param_value = ValueId(next_value);
372 next_value += 1;
373 locals.insert(typed_ir::Path::from_name(param), param_value);
374 params.push(param_value);
375 Self {
376 name,
377 functions,
378 next_value,
379 next_block: 1,
380 blocks: Vec::new(),
381 current: BlockBuilder::new(BlockId(0), params.clone()),
382 locals,
383 params,
384 captures,
385 generated: Vec::new(),
386 next_lambda: 0,
387 }
388 }
389
390 fn lower_root(mut self, expr: &runtime::Expr) -> NativeLowerResult<LoweredFunction> {
391 let value = self.lower_expr(expr)?;
392 self.terminate(NativeTerminator::Return(value));
393 self.finish_current();
394 Ok(LoweredFunction {
395 function: NativeFunction {
396 name: self.name,
397 captures: self.captures,
398 params: self.params,
399 blocks: self.blocks,
400 },
401 generated: self.generated,
402 })
403 }
404
405 fn lower_expr(&mut self, expr: &runtime::Expr) -> NativeLowerResult<ValueId> {
406 if let Some((op, args)) = primitive_apply(expr) {
407 let expected = primitive_arity(op);
408 if args.len() != expected {
409 return Err(NativeLowerError::PrimitiveArityMismatch {
410 op,
411 expected,
412 actual: args.len(),
413 });
414 }
415 let args = args
416 .into_iter()
417 .map(|arg| self.lower_expr(arg))
418 .collect::<NativeLowerResult<Vec<_>>>()?;
419 let dest = self.fresh_value();
420 self.current
421 .stmts
422 .push(NativeStmt::Primitive { dest, op, args });
423 return Ok(dest);
424 }
425
426 if let Some((target, args)) = direct_apply(expr, self.functions)? {
427 let args = args
428 .into_iter()
429 .map(|arg| self.lower_expr(arg))
430 .collect::<NativeLowerResult<Vec<_>>>()?;
431 let dest = self.fresh_value();
432 self.current
433 .stmts
434 .push(NativeStmt::DirectCall { dest, target, args });
435 return Ok(dest);
436 }
437
438 if let runtime::ExprKind::Apply { callee, arg, .. } = &expr.kind {
439 let callee = self.lower_expr(callee)?;
440 let arg = self.lower_expr(arg)?;
441 let dest = self.fresh_value();
442 self.current.stmts.push(NativeStmt::ClosureCall {
443 dest,
444 callee,
445 args: vec![arg],
446 });
447 return Ok(dest);
448 }
449
450 match &expr.kind {
451 runtime::ExprKind::Lit(lit) => {
452 let dest = self.fresh_value();
453 self.current.stmts.push(NativeStmt::Literal {
454 dest,
455 literal: lower_literal(lit),
456 });
457 Ok(dest)
458 }
459 runtime::ExprKind::PrimitiveOp(_) => Err(NativeLowerError::UnsupportedExpr {
460 kind: "bare primitive",
461 }),
462 runtime::ExprKind::Var(path) => {
463 if let Some(value) = self.locals.get(path).copied() {
464 return Ok(value);
465 }
466 if let Some(target) = self
467 .functions
468 .get(path)
469 .and_then(|info| info.partial_targets.get(&0))
470 {
471 let dest = self.fresh_value();
472 self.current.stmts.push(NativeStmt::MakeClosure {
473 dest,
474 target: target.clone(),
475 captures: Vec::new(),
476 });
477 return Ok(dest);
478 }
479 if let Some(target) = self
480 .functions
481 .get(path)
482 .and_then(|info| info.direct_targets.get(&0))
483 {
484 let dest = self.fresh_value();
485 self.current.stmts.push(NativeStmt::DirectCall {
486 dest,
487 target: target.clone(),
488 args: Vec::new(),
489 });
490 return Ok(dest);
491 }
492 Err(NativeLowerError::UnsupportedExpr { kind: "free var" })
493 }
494 runtime::ExprKind::EffectOp(_) => Err(NativeLowerError::UnsupportedExpr {
495 kind: "effect operation",
496 }),
497 runtime::ExprKind::Lambda { param, body, .. } => self.lower_lambda(param, body),
498 runtime::ExprKind::Apply { .. } => {
499 Err(NativeLowerError::UnsupportedExpr { kind: "apply" })
500 }
501 runtime::ExprKind::If {
502 cond,
503 then_branch,
504 else_branch,
505 ..
506 } => self.lower_if(cond, then_branch, else_branch),
507 runtime::ExprKind::Tuple(items) => self.lower_tuple(items),
508 runtime::ExprKind::Record { fields, spread } => self.lower_record(fields, spread),
509 runtime::ExprKind::Variant { tag, value } => self.lower_variant(tag, value.as_deref()),
510 runtime::ExprKind::Select { base, field } => self.lower_select(base, field),
511 runtime::ExprKind::Match {
512 scrutinee, arms, ..
513 } => self.lower_match(scrutinee, arms),
514 runtime::ExprKind::Block { stmts, tail } => self.lower_block(stmts, tail.as_deref()),
515 runtime::ExprKind::Handle { body, .. } => self.lower_expr(body),
516 runtime::ExprKind::BindHere { expr } => self.lower_expr(expr),
517 runtime::ExprKind::Thunk { expr, .. } => self.lower_expr(expr),
518 runtime::ExprKind::LocalPushId { body, .. } => self.lower_expr(body),
519 runtime::ExprKind::PeekId => Err(NativeLowerError::UnsupportedExpr { kind: "peek_id" }),
520 runtime::ExprKind::FindId { .. } => {
521 Err(NativeLowerError::UnsupportedExpr { kind: "find_id" })
522 }
523 runtime::ExprKind::AddId { thunk, .. } => self.lower_expr(thunk),
524 runtime::ExprKind::Coerce { expr, .. } => self.lower_expr(expr),
525 runtime::ExprKind::Pack { .. } => {
526 Err(NativeLowerError::UnsupportedExpr { kind: "pack" })
527 }
528 }
529 }
530
531 fn fresh_value(&mut self) -> ValueId {
532 let value = ValueId(self.next_value);
533 self.next_value += 1;
534 value
535 }
536
537 fn fresh_block(&mut self) -> BlockId {
538 let block = BlockId(self.next_block);
539 self.next_block += 1;
540 block
541 }
542
543 fn lower_tuple(&mut self, items: &[runtime::Expr]) -> NativeLowerResult<ValueId> {
544 let items = items
545 .iter()
546 .map(|item| self.lower_expr(item))
547 .collect::<NativeLowerResult<Vec<_>>>()?;
548 let dest = self.fresh_value();
549 self.current.stmts.push(NativeStmt::Tuple { dest, items });
550 Ok(dest)
551 }
552
553 fn lower_record(
554 &mut self,
555 fields: &[runtime::RecordExprField],
556 spread: &Option<runtime::RecordSpreadExpr>,
557 ) -> NativeLowerResult<ValueId> {
558 let base = spread
559 .as_ref()
560 .map(|spread| match spread {
561 runtime::RecordSpreadExpr::Head(expr) | runtime::RecordSpreadExpr::Tail(expr) => {
562 self.lower_expr(expr)
563 }
564 })
565 .transpose()?;
566 let fields = fields
567 .iter()
568 .map(|field| {
569 Ok(NativeRecordField {
570 name: field.name.clone(),
571 value: self.lower_expr(&field.value)?,
572 })
573 })
574 .collect::<NativeLowerResult<Vec<_>>>()?;
575 let dest = self.fresh_value();
576 self.current
577 .stmts
578 .push(NativeStmt::Record { dest, base, fields });
579 Ok(dest)
580 }
581
582 fn lower_variant(
583 &mut self,
584 tag: &typed_ir::Name,
585 value: Option<&runtime::Expr>,
586 ) -> NativeLowerResult<ValueId> {
587 let value = value.map(|value| self.lower_expr(value)).transpose()?;
588 let dest = self.fresh_value();
589 self.current.stmts.push(NativeStmt::Variant {
590 dest,
591 tag: tag.clone(),
592 value,
593 });
594 Ok(dest)
595 }
596
597 fn lower_select(
598 &mut self,
599 base: &runtime::Expr,
600 field: &typed_ir::Name,
601 ) -> NativeLowerResult<ValueId> {
602 let base = self.lower_expr(base)?;
603 let dest = self.fresh_value();
604 self.current.stmts.push(NativeStmt::Select {
605 dest,
606 base,
607 field: field.clone(),
608 });
609 Ok(dest)
610 }
611
612 fn lower_if(
613 &mut self,
614 cond: &runtime::Expr,
615 then_branch: &runtime::Expr,
616 else_branch: &runtime::Expr,
617 ) -> NativeLowerResult<ValueId> {
618 let cond = self.lower_expr(cond)?;
619 let saved_locals = self.locals.clone();
620 let then_block = self.fresh_block();
621 let else_block = self.fresh_block();
622 let merge_block = self.fresh_block();
623 let result = self.fresh_value();
624
625 self.terminate(NativeTerminator::Branch {
626 cond,
627 then_block,
628 else_block,
629 });
630 self.finish_current();
631
632 self.current = BlockBuilder::new(then_block, Vec::new());
633 self.locals = saved_locals.clone();
634 let then_value = self.lower_expr(then_branch)?;
635 self.terminate(NativeTerminator::Jump {
636 target: merge_block,
637 args: vec![then_value],
638 });
639 self.finish_current();
640
641 self.current = BlockBuilder::new(else_block, Vec::new());
642 self.locals = saved_locals.clone();
643 let else_value = self.lower_expr(else_branch)?;
644 self.terminate(NativeTerminator::Jump {
645 target: merge_block,
646 args: vec![else_value],
647 });
648 self.finish_current();
649
650 self.current = BlockBuilder::new(merge_block, vec![result]);
651 self.locals = saved_locals;
652 Ok(result)
653 }
654
655 fn lower_match(
656 &mut self,
657 scrutinee: &runtime::Expr,
658 arms: &[runtime::MatchArm],
659 ) -> NativeLowerResult<ValueId> {
660 if let Some((then_branch, else_branch)) = bool_literal_match_arms(arms) {
661 return self.lower_if(scrutinee, then_branch, else_branch);
662 }
663
664 let saved_locals = self.locals.clone();
665 let merge_block = self.fresh_block();
666 let result = self.fresh_value();
667 let fallback_block = self.fresh_block();
668 let arm_blocks = (0..arms.len())
669 .map(|_| self.fresh_block())
670 .collect::<Vec<_>>();
671
672 let mut current_test_block = None;
673 for (index, arm) in arms.iter().enumerate() {
674 if let Some(test_block) = current_test_block {
675 self.current = BlockBuilder::new(test_block, Vec::new());
676 self.locals = saved_locals.clone();
677 }
678 let scrutinee_value = self.lower_expr(scrutinee)?;
679 let next_block = if index + 1 == arms.len() {
680 fallback_block
681 } else {
682 let next = self.fresh_block();
683 current_test_block = Some(next);
684 next
685 };
686 let matched_block = if arm.guard.is_some() {
687 self.fresh_block()
688 } else {
689 arm_blocks[index]
690 };
691 self.lower_pattern_test(scrutinee_value, &arm.pattern, matched_block, next_block)?;
692 self.finish_current();
693
694 if let Some(guard) = &arm.guard {
695 self.current = BlockBuilder::new(matched_block, Vec::new());
696 self.locals = saved_locals.clone();
697 let scrutinee_value = self.lower_expr(scrutinee)?;
698 self.bind_matched_pattern(&arm.pattern, scrutinee_value)?;
699 let guard_value = self.lower_expr(guard)?;
700 self.terminate(NativeTerminator::Branch {
701 cond: guard_value,
702 then_block: arm_blocks[index],
703 else_block: next_block,
704 });
705 self.finish_current();
706 }
707 }
708
709 self.current = BlockBuilder::new(fallback_block, Vec::new());
710 let unit = self.fresh_value();
711 self.current.stmts.push(NativeStmt::Literal {
712 dest: unit,
713 literal: NativeLiteral::Unit,
714 });
715 self.terminate(NativeTerminator::Jump {
716 target: merge_block,
717 args: vec![unit],
718 });
719 self.finish_current();
720
721 for (arm, arm_block) in arms.iter().zip(arm_blocks) {
722 self.current = BlockBuilder::new(arm_block, Vec::new());
723 self.locals = saved_locals.clone();
724 let scrutinee_value = self.lower_expr(scrutinee)?;
725 self.bind_matched_pattern(&arm.pattern, scrutinee_value)?;
726 let value = self.lower_expr(&arm.body)?;
727 self.terminate(NativeTerminator::Jump {
728 target: merge_block,
729 args: vec![value],
730 });
731 self.finish_current();
732 }
733
734 self.current = BlockBuilder::new(merge_block, vec![result]);
735 self.locals = saved_locals;
736 Ok(result)
737 }
738
739 fn lower_pattern_test(
740 &mut self,
741 value: ValueId,
742 pattern: &runtime::Pattern,
743 then_block: BlockId,
744 else_block: BlockId,
745 ) -> NativeLowerResult<()> {
746 match self.emit_pattern_condition(value, pattern)? {
747 Some(cond) => {
748 self.terminate(NativeTerminator::Branch {
749 cond,
750 then_block,
751 else_block,
752 });
753 Ok(())
754 }
755 None => {
756 self.terminate(NativeTerminator::Jump {
757 target: then_block,
758 args: Vec::new(),
759 });
760 Ok(())
761 }
762 }
763 }
764
765 fn lower_block(
766 &mut self,
767 stmts: &[runtime::Stmt],
768 tail: Option<&runtime::Expr>,
769 ) -> NativeLowerResult<ValueId> {
770 let saved_locals = self.locals.clone();
771 for stmt in stmts {
772 match stmt {
773 runtime::Stmt::Let { pattern, value } => {
774 let value = self.lower_expr(value)?;
775 self.bind_pattern(pattern, value)?;
776 }
777 runtime::Stmt::Expr(expr) => {
778 self.lower_expr(expr)?;
779 }
780 runtime::Stmt::Module { .. } => {
781 self.locals = saved_locals;
782 return Err(NativeLowerError::UnsupportedExpr {
783 kind: "module statement",
784 });
785 }
786 }
787 }
788 let value = match tail {
789 Some(tail) => self.lower_expr(tail)?,
790 None => {
791 let value = self.fresh_value();
792 self.current.stmts.push(NativeStmt::Literal {
793 dest: value,
794 literal: NativeLiteral::Unit,
795 });
796 value
797 }
798 };
799 self.locals = saved_locals;
800 Ok(value)
801 }
802
803 fn lower_lambda(
804 &mut self,
805 param: &typed_ir::Name,
806 body: &runtime::Expr,
807 ) -> NativeLowerResult<ValueId> {
808 let mut bound = HashSet::new();
809 bound.insert(typed_ir::Path::from_name(param.clone()));
810 let mut capture_paths = free_vars(body, &mut bound)
811 .into_iter()
812 .filter(|path| self.locals.contains_key(path))
813 .collect::<Vec<_>>();
814 capture_paths.sort_by_key(path_name);
815 let captures = capture_paths
816 .iter()
817 .map(|path| {
818 self.locals
819 .get(path)
820 .copied()
821 .ok_or(NativeLowerError::UnsupportedExpr {
822 kind: "lambda capture",
823 })
824 })
825 .collect::<NativeLowerResult<Vec<_>>>()?;
826
827 let target = format!("{}#lambda{}", self.name, self.next_lambda);
828 self.next_lambda += 1;
829 let lowered = FunctionLowerer::new_closure(
830 target.clone(),
831 self.functions,
832 capture_paths,
833 param.clone(),
834 )
835 .lower_root(body)?;
836 self.generated.push(lowered.function);
837 self.generated.extend(lowered.generated);
838
839 let dest = self.fresh_value();
840 self.current.stmts.push(NativeStmt::MakeClosure {
841 dest,
842 target,
843 captures,
844 });
845 Ok(dest)
846 }
847
    /// Emits into the current block a boolean value that is true when `value` matches
    /// `pattern`.
    ///
    /// Returns `None` when the pattern needs no test. That is always the case for wildcards
    /// and binders, and presumably also for compound patterns whose parts are all
    /// irrefutable (this depends on `combine_optional_conditions`, defined elsewhere).
    /// Only the test is emitted; names are bound separately (see `bind_matched_pattern`).
    ///
    /// NOTE(review): extractions (`TupleGet`, `Select`, `VariantPayload`, and the list
    /// reads) are emitted unconditionally in the same block as the tag/length checks. For
    /// example, a variant's payload is extracted even when its tag test fails. This is only
    /// sound if those statements cannot trap on mismatched input — confirm against the
    /// native runtime.
    ///
    /// # Errors
    /// `UnsupportedPattern` for or-patterns, including nested ones.
    fn emit_pattern_condition(
        &mut self,
        value: ValueId,
        pattern: &runtime::Pattern,
    ) -> NativeLowerResult<Option<ValueId>> {
        match pattern {
            runtime::Pattern::Wildcard { .. } | runtime::Pattern::Bind { .. } => Ok(None),
            runtime::Pattern::Lit { lit, .. } => {
                // Materialize the literal and compare it with the scrutinee.
                let literal = self.fresh_value();
                self.current.stmts.push(NativeStmt::Literal {
                    dest: literal,
                    literal: lower_literal(lit),
                });
                Ok(Some(self.emit_value_eq(value, literal)))
            }
            runtime::Pattern::Tuple { items, .. } => {
                let mut cond = None;
                for (index, item) in items.iter().enumerate() {
                    let item_value = self.fresh_value();
                    self.current.stmts.push(NativeStmt::TupleGet {
                        dest: item_value,
                        tuple: value,
                        index,
                    });
                    let item_cond = self.emit_pattern_condition(item_value, item)?;
                    cond = self.combine_optional_conditions(cond, item_cond);
                }
                Ok(cond)
            }
            runtime::Pattern::List {
                prefix,
                spread,
                suffix,
                ..
            } => self.emit_list_pattern_condition(prefix, spread.as_deref(), suffix, value),
            runtime::Pattern::Record { fields, spread, .. } => {
                let mut cond = None;
                for field in fields {
                    let field_value = self.fresh_value();
                    self.current.stmts.push(NativeStmt::Select {
                        dest: field_value,
                        base: value,
                        field: field.name.clone(),
                    });
                    let field_cond = self.emit_pattern_condition(field_value, &field.pattern)?;
                    cond = self.combine_optional_conditions(cond, field_cond);
                }
                // The spread sub-pattern is tested against the value produced by
                // `emit_record_without_fields` (presumably the record minus the listed fields).
                if let Some(spread) = record_spread_pattern(spread.as_ref()) {
                    let rest = self.emit_record_without_fields(value, fields);
                    let spread_cond = self.emit_pattern_condition(rest, spread)?;
                    cond = self.combine_optional_conditions(cond, spread_cond);
                }
                Ok(cond)
            }
            runtime::Pattern::Variant {
                tag,
                value: payload,
                ..
            } => {
                // Variant patterns always test the tag, then (optionally) the payload.
                let tag_cond = self.fresh_value();
                self.current.stmts.push(NativeStmt::VariantTagEq {
                    dest: tag_cond,
                    variant: value,
                    tag: tag.clone(),
                });
                let mut cond = Some(tag_cond);
                if let Some(payload) = payload {
                    let payload_value = self.fresh_value();
                    self.current.stmts.push(NativeStmt::VariantPayload {
                        dest: payload_value,
                        variant: value,
                    });
                    let payload_cond = self.emit_pattern_condition(payload_value, payload)?;
                    cond = self.combine_optional_conditions(cond, payload_cond);
                }
                Ok(cond)
            }
            runtime::Pattern::Or { .. } => Err(NativeLowerError::UnsupportedPattern { kind: "or" }),
            // An `as` alias adds no test of its own; only the inner pattern is checked.
            runtime::Pattern::As { pattern, .. } => self.emit_pattern_condition(value, pattern),
        }
    }
929
930 fn emit_list_pattern_condition(
931 &mut self,
932 prefix: &[runtime::Pattern],
933 spread: Option<&runtime::Pattern>,
934 suffix: &[runtime::Pattern],
935 value: ValueId,
936 ) -> NativeLowerResult<Option<ValueId>> {
937 let len = self.emit_primitive(typed_ir::PrimitiveOp::ListLen, vec![value]);
938 let required = self.emit_int_literal((prefix.len() + suffix.len()) as i64);
939 let op = if spread.is_some() {
940 typed_ir::PrimitiveOp::IntGe
941 } else {
942 typed_ir::PrimitiveOp::IntEq
943 };
944 let mut cond = Some(self.emit_primitive(op, vec![len, required]));
945 for (index, item) in prefix.iter().enumerate() {
946 let index = self.emit_int_literal(index as i64);
947 let item_value =
948 self.emit_primitive(typed_ir::PrimitiveOp::ListIndex, vec![value, index]);
949 let item_cond = self.emit_pattern_condition(item_value, item)?;
950 cond = self.combine_optional_conditions(cond, item_cond);
951 }
952 if let Some(spread) = spread {
953 let start = self.emit_int_literal(prefix.len() as i64);
954 let suffix_len = self.emit_int_literal(suffix.len() as i64);
955 let end = self.emit_primitive(typed_ir::PrimitiveOp::IntSub, vec![len, suffix_len]);
956 let slice = self.emit_primitive(
957 typed_ir::PrimitiveOp::ListIndexRangeRaw,
958 vec![value, start, end],
959 );
960 let spread_cond = self.emit_pattern_condition(slice, spread)?;
961 cond = self.combine_optional_conditions(cond, spread_cond);
962 }
963 for (offset, item) in suffix.iter().enumerate() {
964 let suffix_len = self.emit_int_literal(suffix.len() as i64);
965 let suffix_start =
966 self.emit_primitive(typed_ir::PrimitiveOp::IntSub, vec![len, suffix_len]);
967 let offset = self.emit_int_literal(offset as i64);
968 let index =
969 self.emit_primitive(typed_ir::PrimitiveOp::IntAdd, vec![suffix_start, offset]);
970 let item_value =
971 self.emit_primitive(typed_ir::PrimitiveOp::ListIndex, vec![value, index]);
972 let item_cond = self.emit_pattern_condition(item_value, item)?;
973 cond = self.combine_optional_conditions(cond, item_cond);
974 }
975 Ok(cond)
976 }
977
    /// Binds the names introduced by `pattern` to parts of `value` without testing that the
    /// pattern matches. This is used for `let` statements and as the fallback of
    /// `bind_matched_pattern`.
    ///
    /// Extraction statements are emitted into the current block and the names are inserted
    /// into `self.locals`.
    ///
    /// # Errors
    /// `UnsupportedPattern` for literal patterns, or-patterns, and list patterns whose
    /// children are refutable (including any of these nested inside other patterns).
    fn bind_pattern(
        &mut self,
        pattern: &runtime::Pattern,
        value: ValueId,
    ) -> NativeLowerResult<()> {
        match pattern {
            runtime::Pattern::Wildcard { .. } => Ok(()),
            runtime::Pattern::Bind { name, .. } => {
                self.locals
                    .insert(typed_ir::Path::from_name(name.clone()), value);
                Ok(())
            }
            runtime::Pattern::Lit { .. } => {
                Err(NativeLowerError::UnsupportedPattern { kind: "literal" })
            }
            runtime::Pattern::Tuple { items, .. } => {
                for (index, item) in items.iter().enumerate() {
                    let item_value = self.fresh_value();
                    self.current.stmts.push(NativeStmt::TupleGet {
                        dest: item_value,
                        tuple: value,
                        index,
                    });
                    self.bind_pattern(item, item_value)?;
                }
                Ok(())
            }
            runtime::Pattern::List {
                prefix,
                spread,
                suffix,
                ..
            } if list_pattern_children_are_irrefutable(prefix, spread.as_deref(), suffix) => {
                self.bind_list_pattern(prefix, spread.as_deref(), suffix, value)
            }
            runtime::Pattern::List { .. } => Err(NativeLowerError::UnsupportedPattern {
                kind: "list nested refutable",
            }),
            runtime::Pattern::Record { fields, spread, .. } => {
                for field in fields {
                    let field_value = self.fresh_value();
                    self.current.stmts.push(NativeStmt::Select {
                        dest: field_value,
                        base: value,
                        field: field.name.clone(),
                    });
                    self.bind_pattern(&field.pattern, field_value)?;
                }
                if let Some(spread) = record_spread_pattern(spread.as_ref()) {
                    let rest = self.emit_record_without_fields(value, fields);
                    self.bind_pattern(spread, rest)?;
                }
                Ok(())
            }
            runtime::Pattern::Variant {
                value: Some(payload),
                ..
            } => {
                // NOTE(review): the tag is not checked here; the payload is extracted
                // unconditionally, so a refutable variant in `let` is trusted to match.
                let payload_value = self.fresh_value();
                self.current.stmts.push(NativeStmt::VariantPayload {
                    dest: payload_value,
                    variant: value,
                });
                self.bind_pattern(payload, payload_value)
            }
            runtime::Pattern::Variant { value: None, .. } => Ok(()),
            runtime::Pattern::Or { .. } => Err(NativeLowerError::UnsupportedPattern { kind: "or" }),
            runtime::Pattern::As { pattern, name, .. } => {
                // Bind the inner names first, then the alias for the whole value.
                self.bind_pattern(pattern, value)?;
                self.locals
                    .insert(typed_ir::Path::from_name(name.clone()), value);
                Ok(())
            }
        }
    }
1053
    /// Binds the names of `pattern` to parts of `value` in a block reached only after the
    /// pattern test has succeeded (match arm bodies and guard blocks).
    ///
    /// Differences from `bind_pattern`:
    /// - literal patterns are accepted and bind nothing;
    /// - list patterns go through `bind_matched_list_pattern`, which is defined elsewhere
    ///   and presumably tolerates refutable children.
    ///
    /// All remaining forms (wildcards, binders, payload-less variants, or-patterns) defer to
    /// `bind_pattern`.
    fn bind_matched_pattern(
        &mut self,
        pattern: &runtime::Pattern,
        value: ValueId,
    ) -> NativeLowerResult<()> {
        match pattern {
            // The literal was already checked by the pattern test; nothing to bind.
            runtime::Pattern::Lit { .. } => Ok(()),
            runtime::Pattern::Tuple { items, .. } => {
                for (index, item) in items.iter().enumerate() {
                    let item_value = self.fresh_value();
                    self.current.stmts.push(NativeStmt::TupleGet {
                        dest: item_value,
                        tuple: value,
                        index,
                    });
                    self.bind_matched_pattern(item, item_value)?;
                }
                Ok(())
            }
            runtime::Pattern::Record { fields, spread, .. } => {
                for field in fields {
                    let field_value = self.fresh_value();
                    self.current.stmts.push(NativeStmt::Select {
                        dest: field_value,
                        base: value,
                        field: field.name.clone(),
                    });
                    self.bind_matched_pattern(&field.pattern, field_value)?;
                }
                if let Some(spread) = record_spread_pattern(spread.as_ref()) {
                    let rest = self.emit_record_without_fields(value, fields);
                    self.bind_matched_pattern(spread, rest)?;
                }
                Ok(())
            }
            runtime::Pattern::Variant {
                value: Some(payload),
                ..
            } => {
                let payload_value = self.fresh_value();
                self.current.stmts.push(NativeStmt::VariantPayload {
                    dest: payload_value,
                    variant: value,
                });
                self.bind_matched_pattern(payload, payload_value)
            }
            runtime::Pattern::List {
                prefix,
                spread,
                suffix,
                ..
            } => self.bind_matched_list_pattern(prefix, spread.as_deref(), suffix, value),
            runtime::Pattern::As { pattern, name, .. } => {
                // Bind the inner names first, then the alias for the whole value.
                self.bind_matched_pattern(pattern, value)?;
                self.locals
                    .insert(typed_ir::Path::from_name(name.clone()), value);
                Ok(())
            }
            _ => self.bind_pattern(pattern, value),
        }
    }
1115
1116 fn bind_list_pattern(
1117 &mut self,
1118 prefix: &[runtime::Pattern],
1119 spread: Option<&runtime::Pattern>,
1120 suffix: &[runtime::Pattern],
1121 value: ValueId,
1122 ) -> NativeLowerResult<()> {
1123 let len = if spread.is_some() || !suffix.is_empty() {
1124 Some(self.emit_primitive(typed_ir::PrimitiveOp::ListLen, vec![value]))
1125 } else {
1126 None
1127 };
1128 for (index, item) in prefix.iter().enumerate() {
1129 let index = self.emit_int_literal(index as i64);
1130 let item_value =
1131 self.emit_primitive(typed_ir::PrimitiveOp::ListIndex, vec![value, index]);
1132 self.bind_pattern(item, item_value)?;
1133 }
1134 if let Some(spread) = spread {
1135 let start = self.emit_int_literal(prefix.len() as i64);
1136 let suffix_len = self.emit_int_literal(suffix.len() as i64);
1137 let end = self.emit_primitive(
1138 typed_ir::PrimitiveOp::IntSub,
1139 vec![len.expect("list spread requires len"), suffix_len],
1140 );
1141 let slice = self.emit_primitive(
1142 typed_ir::PrimitiveOp::ListIndexRangeRaw,
1143 vec![value, start, end],
1144 );
1145 self.bind_pattern(spread, slice)?;
1146 }
1147 for (offset, item) in suffix.iter().enumerate() {
1148 let suffix_len = self.emit_int_literal(suffix.len() as i64);
1149 let suffix_start = self.emit_primitive(
1150 typed_ir::PrimitiveOp::IntSub,
1151 vec![len.expect("list suffix requires len"), suffix_len],
1152 );
1153 let offset = self.emit_int_literal(offset as i64);
1154 let index =
1155 self.emit_primitive(typed_ir::PrimitiveOp::IntAdd, vec![suffix_start, offset]);
1156 let item_value =
1157 self.emit_primitive(typed_ir::PrimitiveOp::ListIndex, vec![value, index]);
1158 self.bind_pattern(item, item_value)?;
1159 }
1160 Ok(())
1161 }
1162
1163 fn bind_matched_list_pattern(
1164 &mut self,
1165 prefix: &[runtime::Pattern],
1166 spread: Option<&runtime::Pattern>,
1167 suffix: &[runtime::Pattern],
1168 value: ValueId,
1169 ) -> NativeLowerResult<()> {
1170 let len = if spread.is_some() || !suffix.is_empty() {
1171 Some(self.emit_primitive(typed_ir::PrimitiveOp::ListLen, vec![value]))
1172 } else {
1173 None
1174 };
1175 for (index, item) in prefix.iter().enumerate() {
1176 let index = self.emit_int_literal(index as i64);
1177 let item_value =
1178 self.emit_primitive(typed_ir::PrimitiveOp::ListIndex, vec![value, index]);
1179 self.bind_matched_pattern(item, item_value)?;
1180 }
1181 if let Some(spread) = spread {
1182 let start = self.emit_int_literal(prefix.len() as i64);
1183 let suffix_len = self.emit_int_literal(suffix.len() as i64);
1184 let end = self.emit_primitive(
1185 typed_ir::PrimitiveOp::IntSub,
1186 vec![len.expect("list spread requires len"), suffix_len],
1187 );
1188 let slice = self.emit_primitive(
1189 typed_ir::PrimitiveOp::ListIndexRangeRaw,
1190 vec![value, start, end],
1191 );
1192 self.bind_matched_pattern(spread, slice)?;
1193 }
1194 for (offset, item) in suffix.iter().enumerate() {
1195 let suffix_len = self.emit_int_literal(suffix.len() as i64);
1196 let suffix_start = self.emit_primitive(
1197 typed_ir::PrimitiveOp::IntSub,
1198 vec![len.expect("list suffix requires len"), suffix_len],
1199 );
1200 let offset = self.emit_int_literal(offset as i64);
1201 let index =
1202 self.emit_primitive(typed_ir::PrimitiveOp::IntAdd, vec![suffix_start, offset]);
1203 let item_value =
1204 self.emit_primitive(typed_ir::PrimitiveOp::ListIndex, vec![value, index]);
1205 self.bind_matched_pattern(item, item_value)?;
1206 }
1207 Ok(())
1208 }
1209
1210 fn emit_record_without_fields(
1211 &mut self,
1212 value: ValueId,
1213 fields: &[runtime::RecordPatternField],
1214 ) -> ValueId {
1215 let dest = self.fresh_value();
1216 self.current.stmts.push(NativeStmt::RecordWithoutFields {
1217 dest,
1218 base: value,
1219 fields: fields.iter().map(|field| field.name.clone()).collect(),
1220 });
1221 dest
1222 }
1223
1224 fn emit_int_literal(&mut self, value: i64) -> ValueId {
1225 let dest = self.fresh_value();
1226 self.current.stmts.push(NativeStmt::Literal {
1227 dest,
1228 literal: NativeLiteral::Int(value.to_string()),
1229 });
1230 dest
1231 }
1232
1233 fn emit_primitive(&mut self, op: typed_ir::PrimitiveOp, args: Vec<ValueId>) -> ValueId {
1234 let dest = self.fresh_value();
1235 self.current
1236 .stmts
1237 .push(NativeStmt::Primitive { dest, op, args });
1238 dest
1239 }
1240
1241 fn emit_value_eq(&mut self, left: ValueId, right: ValueId) -> ValueId {
1242 let dest = self.fresh_value();
1243 self.current
1244 .stmts
1245 .push(NativeStmt::ValueEq { dest, left, right });
1246 dest
1247 }
1248
1249 fn emit_bool_and(&mut self, left: ValueId, right: ValueId) -> ValueId {
1250 let dest = self.fresh_value();
1251 self.current
1252 .stmts
1253 .push(NativeStmt::BoolAnd { dest, left, right });
1254 dest
1255 }
1256
1257 fn combine_optional_conditions(
1258 &mut self,
1259 left: Option<ValueId>,
1260 right: Option<ValueId>,
1261 ) -> Option<ValueId> {
1262 match (left, right) {
1263 (Some(left), Some(right)) => Some(self.emit_bool_and(left, right)),
1264 (Some(cond), None) | (None, Some(cond)) => Some(cond),
1265 (None, None) => None,
1266 }
1267 }
1268
1269 fn terminate(&mut self, terminator: NativeTerminator) {
1270 self.current.terminator = Some(terminator);
1271 }
1272
1273 fn finish_current(&mut self) {
1274 let terminator = self
1275 .current
1276 .terminator
1277 .take()
1278 .expect("native lowerer finished an unterminated block");
1279 self.blocks.push(NativeBlock {
1280 id: self.current.id,
1281 params: std::mem::take(&mut self.current.params),
1282 stmts: std::mem::take(&mut self.current.stmts),
1283 terminator,
1284 });
1285 }
1286}
1287
1288fn bool_literal_match_arms(arms: &[runtime::MatchArm]) -> Option<(&runtime::Expr, &runtime::Expr)> {
1289 let mut then_branch = None;
1290 let mut else_branch = None;
1291 for arm in arms {
1292 if arm.guard.is_some() {
1293 return None;
1294 }
1295 match &arm.pattern {
1296 runtime::Pattern::Lit {
1297 lit: typed_ir::Lit::Bool(true),
1298 ..
1299 } if then_branch.is_none() => then_branch = Some(&arm.body),
1300 runtime::Pattern::Lit {
1301 lit: typed_ir::Lit::Bool(false),
1302 ..
1303 } if else_branch.is_none() => else_branch = Some(&arm.body),
1304 _ => return None,
1305 }
1306 }
1307 Some((then_branch?, else_branch?))
1308}
1309
1310fn pattern_has_refutable_child(pattern: &runtime::Pattern) -> bool {
1311 match pattern {
1312 runtime::Pattern::Wildcard { .. } | runtime::Pattern::Bind { .. } => false,
1313 runtime::Pattern::Lit { .. }
1314 | runtime::Pattern::List { .. }
1315 | runtime::Pattern::Variant { .. }
1316 | runtime::Pattern::Or { .. } => true,
1317 runtime::Pattern::Tuple { items, .. } => items.iter().any(pattern_has_refutable_child),
1318 runtime::Pattern::Record { fields, spread, .. } => {
1319 fields
1320 .iter()
1321 .any(|field| pattern_has_refutable_child(&field.pattern))
1322 || record_spread_pattern(spread.as_ref()).is_some_and(pattern_has_refutable_child)
1323 }
1324 runtime::Pattern::As { pattern, .. } => pattern_has_refutable_child(pattern),
1325 }
1326}
1327
1328fn list_pattern_children_are_irrefutable(
1329 prefix: &[runtime::Pattern],
1330 spread: Option<&runtime::Pattern>,
1331 suffix: &[runtime::Pattern],
1332) -> bool {
1333 prefix
1334 .iter()
1335 .chain(spread)
1336 .chain(suffix)
1337 .all(|pattern| !pattern_has_refutable_child(pattern))
1338}
1339
1340fn record_spread_pattern(
1341 spread: Option<&runtime::RecordSpreadPattern>,
1342) -> Option<&runtime::Pattern> {
1343 match spread {
1344 Some(runtime::RecordSpreadPattern::Head(pattern))
1345 | Some(runtime::RecordSpreadPattern::Tail(pattern)) => Some(pattern),
1346 None => None,
1347 }
1348}
1349
1350fn collect_lambda_params(expr: &runtime::Expr) -> (Vec<typed_ir::Name>, &runtime::Expr) {
1351 let mut params = Vec::new();
1352 let mut current = expr;
1353 while let runtime::ExprKind::Lambda { param, body, .. } = ¤t.kind {
1354 params.push(param.clone());
1355 current = body;
1356 }
1357 (params, current)
1358}
1359
1360fn collect_callable_params(expr: &runtime::Expr) -> (Vec<typed_ir::Name>, runtime::Expr) {
1361 let (mut params, body) = collect_lambda_params(expr);
1362 let mut body = body.clone();
1363 while let runtime::ExprKind::Block {
1364 stmts,
1365 tail: Some(tail),
1366 } = &body.kind
1367 {
1368 let (tail_params, tail_body) = collect_lambda_params(tail);
1369 if tail_params.is_empty() {
1370 break;
1371 }
1372 params.extend(tail_params);
1373 body = runtime::Expr::typed(
1374 runtime::ExprKind::Block {
1375 stmts: stmts.clone(),
1376 tail: Some(Box::new(tail_body.clone())),
1377 },
1378 body.ty.clone(),
1379 );
1380 }
1381 (params, body)
1382}
1383
/// Collects the variable paths that `expr` reads but does not bind itself.
///
/// `bound` holds the paths already in scope at `expr`; a `Var` whose path is in `bound` is not
/// reported.
///
/// Lambda parameters and block `let` pattern names are added to `bound` only while their scope
/// is walked. Exactly the entries this call added are removed again, so `bound` is left as it
/// was found.
///
/// Match and handler arms instead use a per-arm clone of `bound`. The clone adds the arm's
/// pattern names, plus the resume name for handler arms, and is used for the arm's guard and
/// body.
///
/// Binders are compared as single-name paths (`Path::from_name`). A multi-segment `Var` path is
/// therefore only treated as bound if that exact path is already in `bound`.
fn free_vars(expr: &runtime::Expr, bound: &mut HashSet<typed_ir::Path>) -> HashSet<typed_ir::Path> {
    match &expr.kind {
        runtime::ExprKind::Var(path) => {
            if bound.contains(path) {
                HashSet::new()
            } else {
                {
                    let mut set = HashSet::new();
                    set.insert(path.clone());
                    set
                }
            }
        }
        runtime::ExprKind::Lambda { param, body, .. } => {
            let path = typed_ir::Path::from_name(param.clone());
            // Only undo the insertion if this lambda added the name. If an enclosing binder
            // already had it, it must stay bound after we return.
            let inserted = bound.insert(path.clone());
            let vars = free_vars(body, bound);
            if inserted {
                bound.remove(&path);
            }
            vars
        }
        runtime::ExprKind::Apply { callee, arg, .. } => {
            let mut vars = free_vars(callee, bound);
            vars.extend(free_vars(arg, bound));
            vars
        }
        runtime::ExprKind::If {
            cond,
            then_branch,
            else_branch,
            ..
        } => {
            let mut vars = free_vars(cond, bound);
            vars.extend(free_vars(then_branch, bound));
            vars.extend(free_vars(else_branch, bound));
            vars
        }
        runtime::ExprKind::Block { stmts, tail } => {
            let mut vars = HashSet::new();
            // Names this block added to `bound`, removed again before returning.
            let mut inserted = Vec::new();
            for stmt in stmts {
                match stmt {
                    runtime::Stmt::Let { pattern, value } => {
                        // The value is analysed before the pattern's names are added, so a
                        // `let` does not see its own binders. Later statements and the tail do.
                        vars.extend(free_vars(value, bound));
                        for name in pattern_bind_names(pattern) {
                            let path = typed_ir::Path::from_name(name);
                            if bound.insert(path.clone()) {
                                inserted.push(path);
                            }
                        }
                    }
                    runtime::Stmt::Expr(expr) => vars.extend(free_vars(expr, bound)),
                    // Module bodies are analysed in the block's current scope.
                    runtime::Stmt::Module { body, .. } => vars.extend(free_vars(body, bound)),
                }
            }
            if let Some(tail) = tail {
                vars.extend(free_vars(tail, bound));
            }
            for path in inserted {
                bound.remove(&path);
            }
            vars
        }
        runtime::ExprKind::Tuple(items) => {
            let mut vars = HashSet::new();
            for item in items {
                vars.extend(free_vars(item, bound));
            }
            vars
        }
        runtime::ExprKind::Record { fields, spread } => {
            let mut vars = HashSet::new();
            for field in fields {
                vars.extend(free_vars(&field.value, bound));
            }
            if let Some(spread) = spread {
                match spread {
                    runtime::RecordSpreadExpr::Head(expr)
                    | runtime::RecordSpreadExpr::Tail(expr) => vars.extend(free_vars(expr, bound)),
                }
            }
            vars
        }
        runtime::ExprKind::Variant { value, .. } => value
            .as_deref()
            .map(|value| free_vars(value, bound))
            .unwrap_or_default(),
        runtime::ExprKind::Select { base, .. } => free_vars(base, bound),
        runtime::ExprKind::Match {
            scrutinee, arms, ..
        } => {
            // The scrutinee is outside every arm's scope.
            let mut vars = free_vars(scrutinee, bound);
            for arm in arms {
                // Each arm gets its own copy of the scope, so one arm's binders never leak into
                // another arm.
                let mut arm_bound = bound.clone();
                for name in pattern_bind_names(&arm.pattern) {
                    arm_bound.insert(typed_ir::Path::from_name(name));
                }
                if let Some(guard) = &arm.guard {
                    vars.extend(free_vars(guard, &mut arm_bound));
                }
                vars.extend(free_vars(&arm.body, &mut arm_bound));
            }
            vars
        }
        runtime::ExprKind::Handle { body, arms, .. } => {
            let mut vars = free_vars(body, bound);
            for arm in arms {
                // Per-arm scope: payload pattern names plus the optional resume name.
                let mut arm_bound = bound.clone();
                for name in pattern_bind_names(&arm.payload) {
                    arm_bound.insert(typed_ir::Path::from_name(name));
                }
                if let Some(resume) = &arm.resume {
                    arm_bound.insert(typed_ir::Path::from_name(resume.name.clone()));
                }
                if let Some(guard) = &arm.guard {
                    vars.extend(free_vars(guard, &mut arm_bound));
                }
                vars.extend(free_vars(&arm.body, &mut arm_bound));
            }
            vars
        }
        // Wrappers that introduce no binders: analyse the wrapped expression.
        runtime::ExprKind::BindHere { expr }
        | runtime::ExprKind::Thunk { expr, .. }
        | runtime::ExprKind::Coerce { expr, .. }
        | runtime::ExprKind::Pack { expr, .. } => free_vars(expr, bound),
        runtime::ExprKind::LocalPushId { body, .. } => free_vars(body, bound),
        runtime::ExprKind::AddId { thunk, .. } => free_vars(thunk, bound),
        // Leaves that reference no variables.
        runtime::ExprKind::PrimitiveOp(_)
        | runtime::ExprKind::EffectOp(_)
        | runtime::ExprKind::Lit(_)
        | runtime::ExprKind::PeekId
        | runtime::ExprKind::FindId { .. } => HashSet::new(),
    }
}
1519
1520fn pattern_bind_names(pattern: &runtime::Pattern) -> Vec<typed_ir::Name> {
1521 match pattern {
1522 runtime::Pattern::Bind { name, .. } => vec![name.clone()],
1523 runtime::Pattern::Tuple { items, .. } => {
1524 items.iter().flat_map(pattern_bind_names).collect()
1525 }
1526 runtime::Pattern::List {
1527 prefix,
1528 spread,
1529 suffix,
1530 ..
1531 } => {
1532 let mut names = prefix
1533 .iter()
1534 .flat_map(pattern_bind_names)
1535 .collect::<Vec<_>>();
1536 if let Some(spread) = spread {
1537 names.extend(pattern_bind_names(spread));
1538 }
1539 names.extend(suffix.iter().flat_map(pattern_bind_names));
1540 names
1541 }
1542 runtime::Pattern::Record { fields, spread, .. } => {
1543 let mut names = fields
1544 .iter()
1545 .flat_map(|field| pattern_bind_names(&field.pattern))
1546 .collect::<Vec<_>>();
1547 if let Some(spread) = spread {
1548 match spread {
1549 runtime::RecordSpreadPattern::Head(pattern)
1550 | runtime::RecordSpreadPattern::Tail(pattern) => {
1551 names.extend(pattern_bind_names(pattern));
1552 }
1553 }
1554 }
1555 names
1556 }
1557 runtime::Pattern::Variant { value, .. } => {
1558 value.as_deref().map(pattern_bind_names).unwrap_or_default()
1559 }
1560 runtime::Pattern::Or { left, right, .. } => {
1561 let mut names = pattern_bind_names(left);
1562 names.extend(pattern_bind_names(right));
1563 names
1564 }
1565 runtime::Pattern::As { pattern, name, .. } => {
1566 let mut names = pattern_bind_names(pattern);
1567 names.push(name.clone());
1568 names
1569 }
1570 runtime::Pattern::Wildcard { .. } | runtime::Pattern::Lit { .. } => Vec::new(),
1571 }
1572}
1573
1574fn expr_pattern_binds_path(expr: &runtime::Expr, path: &typed_ir::Path) -> bool {
1575 match &expr.kind {
1576 runtime::ExprKind::Match {
1577 scrutinee, arms, ..
1578 } => {
1579 expr_pattern_binds_path(scrutinee, path)
1580 || arms.iter().any(|arm| {
1581 pattern_binds_path(&arm.pattern, path)
1582 || arm
1583 .guard
1584 .as_ref()
1585 .is_some_and(|guard| expr_pattern_binds_path(guard, path))
1586 || expr_pattern_binds_path(&arm.body, path)
1587 })
1588 }
1589 runtime::ExprKind::Lambda { body, .. } => expr_pattern_binds_path(body, path),
1590 runtime::ExprKind::Apply { callee, arg, .. } => {
1591 expr_pattern_binds_path(callee, path) || expr_pattern_binds_path(arg, path)
1592 }
1593 runtime::ExprKind::If {
1594 cond,
1595 then_branch,
1596 else_branch,
1597 ..
1598 } => {
1599 expr_pattern_binds_path(cond, path)
1600 || expr_pattern_binds_path(then_branch, path)
1601 || expr_pattern_binds_path(else_branch, path)
1602 }
1603 runtime::ExprKind::Block { stmts, tail } => {
1604 stmts.iter().any(|stmt| match stmt {
1605 runtime::Stmt::Let { pattern, value } => {
1606 pattern_binds_path(pattern, path) || expr_pattern_binds_path(value, path)
1607 }
1608 runtime::Stmt::Expr(expr) | runtime::Stmt::Module { body: expr, .. } => {
1609 expr_pattern_binds_path(expr, path)
1610 }
1611 }) || tail
1612 .as_deref()
1613 .is_some_and(|tail| expr_pattern_binds_path(tail, path))
1614 }
1615 runtime::ExprKind::Tuple(items) => {
1616 items.iter().any(|item| expr_pattern_binds_path(item, path))
1617 }
1618 runtime::ExprKind::Record { fields, spread } => {
1619 fields
1620 .iter()
1621 .any(|field| expr_pattern_binds_path(&field.value, path))
1622 || spread.as_ref().is_some_and(|spread| match spread {
1623 runtime::RecordSpreadExpr::Head(expr)
1624 | runtime::RecordSpreadExpr::Tail(expr) => expr_pattern_binds_path(expr, path),
1625 })
1626 }
1627 runtime::ExprKind::Variant { value, .. } => value
1628 .as_deref()
1629 .is_some_and(|value| expr_pattern_binds_path(value, path)),
1630 runtime::ExprKind::Select { base, .. } => expr_pattern_binds_path(base, path),
1631 runtime::ExprKind::Handle { body, arms, .. } => {
1632 expr_pattern_binds_path(body, path)
1633 || arms.iter().any(|arm| {
1634 pattern_binds_path(&arm.payload, path)
1635 || arm
1636 .guard
1637 .as_ref()
1638 .is_some_and(|guard| expr_pattern_binds_path(guard, path))
1639 || expr_pattern_binds_path(&arm.body, path)
1640 })
1641 }
1642 runtime::ExprKind::BindHere { expr }
1643 | runtime::ExprKind::Thunk { expr, .. }
1644 | runtime::ExprKind::Coerce { expr, .. }
1645 | runtime::ExprKind::Pack { expr, .. }
1646 | runtime::ExprKind::LocalPushId { body: expr, .. }
1647 | runtime::ExprKind::AddId { thunk: expr, .. } => expr_pattern_binds_path(expr, path),
1648 runtime::ExprKind::Var(_)
1649 | runtime::ExprKind::Lit(_)
1650 | runtime::ExprKind::PrimitiveOp(_)
1651 | runtime::ExprKind::EffectOp(_)
1652 | runtime::ExprKind::PeekId
1653 | runtime::ExprKind::FindId { .. } => false,
1654 }
1655}
1656
1657fn pattern_binds_path(pattern: &runtime::Pattern, path: &typed_ir::Path) -> bool {
1658 match pattern {
1659 runtime::Pattern::Bind { name, .. } => typed_ir::Path::from_name(name.clone()) == *path,
1660 runtime::Pattern::Tuple { items, .. } => {
1661 items.iter().any(|item| pattern_binds_path(item, path))
1662 }
1663 runtime::Pattern::List {
1664 prefix,
1665 spread,
1666 suffix,
1667 ..
1668 } => {
1669 prefix.iter().any(|item| pattern_binds_path(item, path))
1670 || spread
1671 .as_deref()
1672 .is_some_and(|spread| pattern_binds_path(spread, path))
1673 || suffix.iter().any(|item| pattern_binds_path(item, path))
1674 }
1675 runtime::Pattern::Record { fields, spread, .. } => {
1676 fields
1677 .iter()
1678 .any(|field| pattern_binds_path(&field.pattern, path))
1679 || record_spread_pattern(spread.as_ref())
1680 .is_some_and(|spread| pattern_binds_path(spread, path))
1681 }
1682 runtime::Pattern::Variant { value, .. } => value
1683 .as_deref()
1684 .is_some_and(|value| pattern_binds_path(value, path)),
1685 runtime::Pattern::Or { left, right, .. } => {
1686 pattern_binds_path(left, path) || pattern_binds_path(right, path)
1687 }
1688 runtime::Pattern::As { pattern, name, .. } => {
1689 typed_ir::Path::from_name(name.clone()) == *path || pattern_binds_path(pattern, path)
1690 }
1691 runtime::Pattern::Wildcard { .. } | runtime::Pattern::Lit { .. } => false,
1692 }
1693}
1694
1695struct BlockBuilder {
1696 id: BlockId,
1697 params: Vec<ValueId>,
1698 stmts: Vec<NativeStmt>,
1699 terminator: Option<NativeTerminator>,
1700}
1701
1702impl BlockBuilder {
1703 fn new(id: BlockId, params: Vec<ValueId>) -> Self {
1704 Self {
1705 id,
1706 params,
1707 stmts: Vec::new(),
1708 terminator: None,
1709 }
1710 }
1711}
1712
1713fn lower_literal(lit: &typed_ir::Lit) -> NativeLiteral {
1714 match lit {
1715 typed_ir::Lit::Int(value) => NativeLiteral::Int(value.clone()),
1716 typed_ir::Lit::Float(value) => NativeLiteral::Float(value.clone()),
1717 typed_ir::Lit::String(value) => NativeLiteral::String(value.clone()),
1718 typed_ir::Lit::Bool(value) => NativeLiteral::Bool(*value),
1719 typed_ir::Lit::Unit => NativeLiteral::Unit,
1720 }
1721}
1722
1723fn primitive_apply(expr: &runtime::Expr) -> Option<(typed_ir::PrimitiveOp, Vec<&runtime::Expr>)> {
1724 let mut args = Vec::new();
1725 let mut current = expr;
1726 while let runtime::ExprKind::Apply { callee, arg, .. } = ¤t.kind {
1727 args.push(arg.as_ref());
1728 current = callee;
1729 }
1730 let runtime::ExprKind::PrimitiveOp(op) = ¤t.kind else {
1731 return None;
1732 };
1733 args.reverse();
1734 Some((*op, args))
1735}
1736
1737fn direct_apply<'expr>(
1738 expr: &'expr runtime::Expr,
1739 functions: &HashMap<typed_ir::Path, FunctionInfo>,
1740) -> NativeLowerResult<Option<(String, Vec<&'expr runtime::Expr>)>> {
1741 let mut args = Vec::new();
1742 let mut current = expr;
1743 while let runtime::ExprKind::Apply { callee, arg, .. } = ¤t.kind {
1744 args.push(arg.as_ref());
1745 current = callee;
1746 }
1747 let runtime::ExprKind::Var(path) = ¤t.kind else {
1748 return Ok(None);
1749 };
1750 let Some(target) = functions.get(path) else {
1751 return Ok(None);
1752 };
1753 let Some(target_name) = target.direct_targets.get(&args.len()) else {
1754 if args.len() < target.arity || target.arity == 0 && !args.is_empty() {
1755 return Ok(None);
1756 }
1757 return Err(NativeLowerError::CallArityMismatch {
1758 target: target.name.clone(),
1759 expected: target.arity,
1760 actual: args.len(),
1761 });
1762 };
1763 args.reverse();
1764 Ok(Some((target_name.clone(), args)))
1765}
1766
1767fn path_name(path: &typed_ir::Path) -> String {
1768 path.segments
1769 .iter()
1770 .map(|segment| segment.0.as_str())
1771 .collect::<Vec<_>>()
1772 .join("::")
1773}
1774
/// Returns how many arguments primitive `op` takes when fully applied.
///
/// The match has no wildcard arm, so adding a `PrimitiveOp` variant fails to compile until its
/// arity is listed here.
fn primitive_arity(op: typed_ir::PrimitiveOp) -> usize {
    use typed_ir::PrimitiveOp;
    match op {
        // One argument.
        PrimitiveOp::BoolNot
        | PrimitiveOp::ListEmpty
        | PrimitiveOp::ListSingleton
        | PrimitiveOp::ListLen
        | PrimitiveOp::ListViewRaw
        | PrimitiveOp::StringLen
        | PrimitiveOp::StringToBytes
        | PrimitiveOp::BytesLen
        | PrimitiveOp::BytesToUtf8Raw
        | PrimitiveOp::BytesToPath
        | PrimitiveOp::PathToBytes
        | PrimitiveOp::IntToString
        | PrimitiveOp::IntToHex
        | PrimitiveOp::IntToUpperHex
        | PrimitiveOp::FloatToString
        | PrimitiveOp::BoolToString => 1,
        // Two arguments.
        PrimitiveOp::BoolEq
        | PrimitiveOp::ListMerge
        | PrimitiveOp::ListIndex
        | PrimitiveOp::ListIndexRange
        | PrimitiveOp::StringIndex
        | PrimitiveOp::StringIndexRange
        | PrimitiveOp::IntAdd
        | PrimitiveOp::IntSub
        | PrimitiveOp::IntMul
        | PrimitiveOp::IntDiv
        | PrimitiveOp::IntEq
        | PrimitiveOp::IntLt
        | PrimitiveOp::IntLe
        | PrimitiveOp::IntGt
        | PrimitiveOp::IntGe
        | PrimitiveOp::FloatAdd
        | PrimitiveOp::FloatSub
        | PrimitiveOp::FloatMul
        | PrimitiveOp::FloatDiv
        | PrimitiveOp::FloatEq
        | PrimitiveOp::FloatLt
        | PrimitiveOp::FloatLe
        | PrimitiveOp::FloatGt
        | PrimitiveOp::FloatGe
        | PrimitiveOp::StringEq
        | PrimitiveOp::StringConcat
        | PrimitiveOp::BytesEq
        | PrimitiveOp::BytesConcat
        | PrimitiveOp::BytesIndex
        | PrimitiveOp::BytesIndexRange => 2,
        // Three arguments (e.g. `ListIndexRangeRaw` takes the list, start and end).
        PrimitiveOp::ListSplice
        | PrimitiveOp::ListIndexRangeRaw
        | PrimitiveOp::StringSplice
        | PrimitiveOp::StringIndexRangeRaw => 3,
        // Four arguments.
        PrimitiveOp::ListSpliceRaw | PrimitiveOp::StringSpliceRaw => 4,
    }
}
1831
1832#[cfg(test)]
1833mod tests {
1834 use super::*;
1835
1836 fn unknown_lit(lit: typed_ir::Lit) -> runtime::Expr {
1837 runtime::Expr::typed(runtime::ExprKind::Lit(lit), runtime::Type::unknown())
1838 }
1839
1840 fn primitive(op: typed_ir::PrimitiveOp) -> runtime::Expr {
1841 runtime::Expr::typed(runtime::ExprKind::PrimitiveOp(op), runtime::Type::unknown())
1842 }
1843
1844 fn apply(callee: runtime::Expr, arg: runtime::Expr) -> runtime::Expr {
1845 runtime::Expr::typed(
1846 runtime::ExprKind::Apply {
1847 callee: Box::new(callee),
1848 arg: Box::new(arg),
1849 evidence: None,
1850 instantiation: None,
1851 },
1852 runtime::Type::unknown(),
1853 )
1854 }
1855
1856 fn if_expr(
1857 cond: runtime::Expr,
1858 then_branch: runtime::Expr,
1859 else_branch: runtime::Expr,
1860 ) -> runtime::Expr {
1861 runtime::Expr::typed(
1862 runtime::ExprKind::If {
1863 cond: Box::new(cond),
1864 then_branch: Box::new(then_branch),
1865 else_branch: Box::new(else_branch),
1866 evidence: None,
1867 },
1868 runtime::Type::unknown(),
1869 )
1870 }
1871
1872 fn bool_match(
1873 scrutinee: runtime::Expr,
1874 then_branch: runtime::Expr,
1875 else_branch: runtime::Expr,
1876 ) -> runtime::Expr {
1877 runtime::Expr::typed(
1878 runtime::ExprKind::Match {
1879 scrutinee: Box::new(scrutinee),
1880 arms: vec![
1881 runtime::MatchArm {
1882 pattern: runtime::Pattern::Lit {
1883 lit: typed_ir::Lit::Bool(true),
1884 ty: runtime::Type::unknown(),
1885 },
1886 guard: None,
1887 body: then_branch,
1888 },
1889 runtime::MatchArm {
1890 pattern: runtime::Pattern::Lit {
1891 lit: typed_ir::Lit::Bool(false),
1892 ty: runtime::Type::unknown(),
1893 },
1894 guard: None,
1895 body: else_branch,
1896 },
1897 ],
1898 evidence: runtime::JoinEvidence {
1899 result: typed_ir::Type::Unknown,
1900 },
1901 },
1902 runtime::Type::unknown(),
1903 )
1904 }
1905
1906 fn thunk(expr: runtime::Expr) -> runtime::Expr {
1907 runtime::Expr::typed(
1908 runtime::ExprKind::Thunk {
1909 effect: typed_ir::Type::Never,
1910 value: runtime::Type::unknown(),
1911 expr: Box::new(expr),
1912 },
1913 runtime::Type::unknown(),
1914 )
1915 }
1916
1917 fn bind_here(expr: runtime::Expr) -> runtime::Expr {
1918 runtime::Expr::typed(
1919 runtime::ExprKind::BindHere {
1920 expr: Box::new(expr),
1921 },
1922 runtime::Type::unknown(),
1923 )
1924 }
1925
1926 fn handle(expr: runtime::Expr) -> runtime::Expr {
1927 runtime::Expr::typed(
1928 runtime::ExprKind::Handle {
1929 body: Box::new(expr),
1930 arms: Vec::new(),
1931 evidence: runtime::JoinEvidence {
1932 result: typed_ir::Type::Unknown,
1933 },
1934 handler: runtime::HandleEffect {
1935 consumes: Vec::new(),
1936 residual_before: None,
1937 residual_after: None,
1938 },
1939 },
1940 runtime::Type::unknown(),
1941 )
1942 }
1943
1944 fn var(name: &str) -> runtime::Expr {
1945 runtime::Expr::typed(
1946 runtime::ExprKind::Var(typed_ir::Path::from_name(typed_ir::Name(name.to_string()))),
1947 runtime::Type::unknown(),
1948 )
1949 }
1950
1951 fn bind_pattern(name: &str) -> runtime::Pattern {
1952 runtime::Pattern::Bind {
1953 name: typed_ir::Name(name.to_string()),
1954 ty: runtime::Type::unknown(),
1955 }
1956 }
1957
1958 fn block(stmts: Vec<runtime::Stmt>, tail: runtime::Expr) -> runtime::Expr {
1959 runtime::Expr::typed(
1960 runtime::ExprKind::Block {
1961 stmts,
1962 tail: Some(Box::new(tail)),
1963 },
1964 runtime::Type::unknown(),
1965 )
1966 }
1967
1968 fn lambda(param: &str, body: runtime::Expr) -> runtime::Expr {
1969 runtime::Expr::typed(
1970 runtime::ExprKind::Lambda {
1971 param: typed_ir::Name(param.to_string()),
1972 param_effect_annotation: None,
1973 param_function_allowed_effects: None,
1974 body: Box::new(body),
1975 },
1976 runtime::Type::unknown(),
1977 )
1978 }
1979
1980 fn binding(name: &str, body: runtime::Expr) -> runtime::Binding {
1981 runtime::Binding {
1982 name: typed_ir::Path::from_name(typed_ir::Name(name.to_string())),
1983 type_params: Vec::new(),
1984 scheme: typed_ir::Scheme {
1985 requirements: Vec::new(),
1986 body: typed_ir::Type::Unknown,
1987 },
1988 body,
1989 }
1990 }
1991
1992 fn module_with_binding_and_root(
1993 binding: runtime::Binding,
1994 expr: runtime::Expr,
1995 ) -> runtime::Module {
1996 module_with_bindings_and_root(vec![binding], expr)
1997 }
1998
1999 fn module_with_bindings_and_root(
2000 bindings: Vec<runtime::Binding>,
2001 expr: runtime::Expr,
2002 ) -> runtime::Module {
2003 runtime::Module {
2004 path: typed_ir::Path::default(),
2005 bindings,
2006 root_exprs: vec![expr],
2007 roots: vec![runtime::Root::Expr(0)],
2008 role_impls: Vec::new(),
2009 }
2010 }
2011
2012 fn module_with_root(expr: runtime::Expr) -> runtime::Module {
2013 runtime::Module {
2014 path: typed_ir::Path::default(),
2015 bindings: Vec::new(),
2016 root_exprs: vec![expr],
2017 roots: vec![runtime::Root::Expr(0)],
2018 role_impls: Vec::new(),
2019 }
2020 }
2021
2022 #[test]
2023 fn lowers_literal_root() {
2024 let module = module_with_root(unknown_lit(typed_ir::Lit::Int("42".to_string())));
2025 let lowered = lower_module(&module).expect("lowered");
2026
2027 assert_eq!(lowered.roots.len(), 1);
2028 assert_eq!(
2029 lowered.roots[0].blocks[0].stmts,
2030 vec![NativeStmt::Literal {
2031 dest: ValueId(0),
2032 literal: NativeLiteral::Int("42".to_string()),
2033 }]
2034 );
2035 assert_eq!(
2036 lowered.roots[0].blocks[0].terminator,
2037 NativeTerminator::Return(ValueId(0))
2038 );
2039 }
2040
2041 #[test]
2042 fn lowers_primitive_apply_root() {
2043 let expr = apply(
2044 apply(
2045 primitive(typed_ir::PrimitiveOp::IntAdd),
2046 unknown_lit(typed_ir::Lit::Int("1".to_string())),
2047 ),
2048 unknown_lit(typed_ir::Lit::Int("2".to_string())),
2049 );
2050 let module = module_with_root(expr);
2051 let lowered = lower_module(&module).expect("lowered");
2052
2053 assert_eq!(
2054 lowered.roots[0].blocks[0].stmts,
2055 vec![
2056 NativeStmt::Literal {
2057 dest: ValueId(0),
2058 literal: NativeLiteral::Int("1".to_string()),
2059 },
2060 NativeStmt::Literal {
2061 dest: ValueId(1),
2062 literal: NativeLiteral::Int("2".to_string()),
2063 },
2064 NativeStmt::Primitive {
2065 dest: ValueId(2),
2066 op: typed_ir::PrimitiveOp::IntAdd,
2067 args: vec![ValueId(0), ValueId(1)],
2068 },
2069 ]
2070 );
2071 }
2072
2073 #[test]
2074 fn rejects_effect_operation_root() {
2075 let expr = runtime::Expr::typed(
2076 runtime::ExprKind::EffectOp(typed_ir::Path::from_name(typed_ir::Name(
2077 "read".to_string(),
2078 ))),
2079 runtime::Type::unknown(),
2080 );
2081 let module = module_with_root(expr);
2082
2083 assert_eq!(
2084 lower_module(&module).expect_err("unsupported"),
2085 NativeLowerError::UnsupportedExpr {
2086 kind: "effect operation",
2087 }
2088 );
2089 }
2090
2091 #[test]
2092 fn rejects_effect_operation_under_handle_wrapper() {
2093 let expr = handle(runtime::Expr::typed(
2094 runtime::ExprKind::EffectOp(typed_ir::Path::from_name(typed_ir::Name(
2095 "read".to_string(),
2096 ))),
2097 runtime::Type::unknown(),
2098 ));
2099 let module = module_with_root(expr);
2100
2101 assert_eq!(
2102 lower_module(&module).expect_err("unsupported"),
2103 NativeLowerError::UnsupportedExpr {
2104 kind: "effect operation",
2105 }
2106 );
2107 }
2108
2109 #[test]
2110 fn lowers_if_with_branch_and_merge_blocks() {
2111 let module = module_with_root(if_expr(
2112 unknown_lit(typed_ir::Lit::Bool(true)),
2113 unknown_lit(typed_ir::Lit::Int("1".to_string())),
2114 unknown_lit(typed_ir::Lit::Int("2".to_string())),
2115 ));
2116 let lowered = lower_module(&module).expect("lowered");
2117 let blocks = &lowered.roots[0].blocks;
2118
2119 assert_eq!(blocks.len(), 4);
2120 assert_eq!(
2121 blocks[0].terminator,
2122 NativeTerminator::Branch {
2123 cond: ValueId(0),
2124 then_block: BlockId(1),
2125 else_block: BlockId(2),
2126 }
2127 );
2128 assert_eq!(
2129 blocks[1].terminator,
2130 NativeTerminator::Jump {
2131 target: BlockId(3),
2132 args: vec![ValueId(2)],
2133 }
2134 );
2135 assert_eq!(
2136 blocks[2].terminator,
2137 NativeTerminator::Jump {
2138 target: BlockId(3),
2139 args: vec![ValueId(3)],
2140 }
2141 );
2142 assert_eq!(blocks[3].params, vec![ValueId(1)]);
2143 assert_eq!(blocks[3].terminator, NativeTerminator::Return(ValueId(1)));
2144 }
2145
2146 #[test]
2147 fn lowers_bool_match_with_branch_and_merge_blocks() {
2148 let module = module_with_root(bool_match(
2149 unknown_lit(typed_ir::Lit::Bool(true)),
2150 unknown_lit(typed_ir::Lit::Int("1".to_string())),
2151 unknown_lit(typed_ir::Lit::Int("2".to_string())),
2152 ));
2153 let lowered = lower_module(&module).expect("lowered");
2154 let blocks = &lowered.roots[0].blocks;
2155
2156 assert_eq!(blocks.len(), 4);
2157 assert_eq!(
2158 blocks[0].terminator,
2159 NativeTerminator::Branch {
2160 cond: ValueId(0),
2161 then_block: BlockId(1),
2162 else_block: BlockId(2),
2163 }
2164 );
2165 assert_eq!(
2166 blocks[1].terminator,
2167 NativeTerminator::Jump {
2168 target: BlockId(3),
2169 args: vec![ValueId(2)],
2170 }
2171 );
2172 assert_eq!(
2173 blocks[2].terminator,
2174 NativeTerminator::Jump {
2175 target: BlockId(3),
2176 args: vec![ValueId(3)],
2177 }
2178 );
2179 assert_eq!(blocks[3].params, vec![ValueId(1)]);
2180 assert_eq!(blocks[3].terminator, NativeTerminator::Return(ValueId(1)));
2181 }
2182
2183 #[test]
2184 fn lowers_effect_free_execution_wrappers() {
2185 let module = module_with_root(handle(bind_here(thunk(unknown_lit(typed_ir::Lit::Int(
2186 "42".to_string(),
2187 ))))));
2188 let lowered = lower_module(&module).expect("lowered");
2189
2190 assert_eq!(
2191 lowered.roots[0].blocks[0].stmts,
2192 vec![NativeStmt::Literal {
2193 dest: ValueId(0),
2194 literal: NativeLiteral::Int("42".to_string()),
2195 }]
2196 );
2197 assert_eq!(
2198 lowered.roots[0].blocks[0].terminator,
2199 NativeTerminator::Return(ValueId(0))
2200 );
2201 }
2202
2203 #[test]
2204 fn lowers_simple_block_binding() {
2205 let module = module_with_root(block(
2206 vec![runtime::Stmt::Let {
2207 pattern: bind_pattern("x"),
2208 value: unknown_lit(typed_ir::Lit::Int("42".to_string())),
2209 }],
2210 var("x"),
2211 ));
2212 let lowered = lower_module(&module).expect("lowered");
2213
2214 assert_eq!(
2215 lowered.roots[0].blocks[0].stmts,
2216 vec![NativeStmt::Literal {
2217 dest: ValueId(0),
2218 literal: NativeLiteral::Int("42".to_string()),
2219 }]
2220 );
2221 assert_eq!(
2222 lowered.roots[0].blocks[0].terminator,
2223 NativeTerminator::Return(ValueId(0))
2224 );
2225 }
2226
2227 #[test]
2228 fn lowers_direct_monomorphic_call() {
2229 let inc = binding(
2230 "inc",
2231 lambda(
2232 "x",
2233 apply(
2234 apply(primitive(typed_ir::PrimitiveOp::IntAdd), var("x")),
2235 unknown_lit(typed_ir::Lit::Int("1".to_string())),
2236 ),
2237 ),
2238 );
2239 let root = apply(
2240 var("inc"),
2241 unknown_lit(typed_ir::Lit::Int("41".to_string())),
2242 );
2243 let module = module_with_binding_and_root(inc, root);
2244 let lowered = lower_module(&module).expect("lowered");
2245
2246 assert_eq!(lowered.functions[0].name, "inc");
2247 assert_eq!(lowered.functions[1].name, "inc#partial0");
2248 assert_eq!(lowered.functions[0].params, vec![ValueId(0)]);
2249 assert_eq!(
2250 lowered.roots[0].blocks[0].stmts,
2251 vec![
2252 NativeStmt::Literal {
2253 dest: ValueId(0),
2254 literal: NativeLiteral::Int("41".to_string()),
2255 },
2256 NativeStmt::DirectCall {
2257 dest: ValueId(1),
2258 target: "inc".to_string(),
2259 args: vec![ValueId(0)],
2260 },
2261 ]
2262 );
2263 assert_eq!(
2264 lowered.roots[0].blocks[0].terminator,
2265 NativeTerminator::Return(ValueId(1))
2266 );
2267 }
2268
2269 #[test]
2270 fn rejects_direct_call_arity_mismatch() {
2271 let inc = binding("inc", lambda("x", var("x")));
2272 let root = apply(
2273 apply(var("inc"), unknown_lit(typed_ir::Lit::Int("1".to_string()))),
2274 unknown_lit(typed_ir::Lit::Int("2".to_string())),
2275 );
2276 let module = module_with_binding_and_root(inc, root);
2277
2278 assert_eq!(
2279 lower_module(&module).expect_err("arity mismatch"),
2280 NativeLowerError::CallArityMismatch {
2281 target: "inc".to_string(),
2282 expected: 1,
2283 actual: 2,
2284 }
2285 );
2286 }
2287
2288 #[test]
2289 fn lowers_zero_arity_binding_apply_as_closure_call() {
2290 let choose = binding(
2291 "choose",
2292 if_expr(
2293 unknown_lit(typed_ir::Lit::Bool(true)),
2294 lambda("x", var("x")),
2295 lambda("x", var("x")),
2296 ),
2297 );
2298 let root = apply(
2299 var("choose"),
2300 unknown_lit(typed_ir::Lit::Int("42".to_string())),
2301 );
2302 let module = module_with_binding_and_root(choose, root);
2303 let lowered = lower_module(&module).expect("lowered");
2304
2305 assert_eq!(
2306 lowered.roots[0].blocks[0].stmts,
2307 vec![
2308 NativeStmt::DirectCall {
2309 dest: ValueId(0),
2310 target: "choose".to_string(),
2311 args: Vec::new(),
2312 },
2313 NativeStmt::Literal {
2314 dest: ValueId(1),
2315 literal: NativeLiteral::Int("42".to_string()),
2316 },
2317 NativeStmt::ClosureCall {
2318 dest: ValueId(2),
2319 callee: ValueId(0),
2320 args: vec![ValueId(1)],
2321 },
2322 ]
2323 );
2324 assert_eq!(
2325 lowered.roots[0].blocks[0].terminator,
2326 NativeTerminator::Return(ValueId(2))
2327 );
2328 }
2329
2330 #[test]
2331 fn lowers_block_tail_lambda_as_extra_direct_arity() {
2332 let add_after_let = binding(
2333 "add_after_let",
2334 lambda(
2335 "x",
2336 block(
2337 vec![runtime::Stmt::Let {
2338 pattern: bind_pattern("z"),
2339 value: var("x"),
2340 }],
2341 lambda(
2342 "y",
2343 apply(
2344 apply(primitive(typed_ir::PrimitiveOp::IntAdd), var("z")),
2345 var("y"),
2346 ),
2347 ),
2348 ),
2349 ),
2350 );
2351 let root = apply(
2352 apply(
2353 var("add_after_let"),
2354 unknown_lit(typed_ir::Lit::Int("20".to_string())),
2355 ),
2356 unknown_lit(typed_ir::Lit::Int("22".to_string())),
2357 );
2358 let module = module_with_binding_and_root(add_after_let, root);
2359 let lowered = lower_module(&module).expect("lowered");
2360
2361 assert_eq!(
2362 lowered
2363 .functions
2364 .iter()
2365 .map(|function| function.name.as_str())
2366 .collect::<Vec<_>>(),
2367 vec![
2368 "add_after_let",
2369 "add_after_let#lambda0",
2370 "add_after_let#direct2",
2371 "add_after_let#partial0"
2372 ]
2373 );
2374 assert_eq!(
2375 lowered.roots[0].blocks[0].stmts,
2376 vec![
2377 NativeStmt::Literal {
2378 dest: ValueId(0),
2379 literal: NativeLiteral::Int("20".to_string()),
2380 },
2381 NativeStmt::Literal {
2382 dest: ValueId(1),
2383 literal: NativeLiteral::Int("22".to_string()),
2384 },
2385 NativeStmt::DirectCall {
2386 dest: ValueId(2),
2387 target: "add_after_let#direct2".to_string(),
2388 args: vec![ValueId(0), ValueId(1)],
2389 },
2390 ]
2391 );
2392 }
2393
2394 #[test]
2395 fn lowers_partial_top_level_call_as_closure_chain() {
2396 let add = binding(
2397 "add",
2398 lambda(
2399 "x",
2400 lambda(
2401 "y",
2402 apply(
2403 apply(primitive(typed_ir::PrimitiveOp::IntAdd), var("x")),
2404 var("y"),
2405 ),
2406 ),
2407 ),
2408 );
2409 let root = block(
2410 vec![runtime::Stmt::Let {
2411 pattern: bind_pattern("f"),
2412 value: apply(
2413 var("add"),
2414 unknown_lit(typed_ir::Lit::Int("40".to_string())),
2415 ),
2416 }],
2417 apply(var("f"), unknown_lit(typed_ir::Lit::Int("2".to_string()))),
2418 );
2419 let module = module_with_binding_and_root(add, root);
2420 let lowered = lower_module(&module).expect("lowered");
2421
2422 assert!(
2423 lowered
2424 .functions
2425 .iter()
2426 .any(|function| function.name == "add#partial0")
2427 );
2428 assert!(
2429 lowered
2430 .functions
2431 .iter()
2432 .any(|function| function.name == "add#partial1")
2433 );
2434 assert_eq!(
2435 lowered.roots[0].blocks[0].stmts,
2436 vec![
2437 NativeStmt::MakeClosure {
2438 dest: ValueId(0),
2439 target: "add#partial0".to_string(),
2440 captures: Vec::new(),
2441 },
2442 NativeStmt::Literal {
2443 dest: ValueId(1),
2444 literal: NativeLiteral::Int("40".to_string()),
2445 },
2446 NativeStmt::ClosureCall {
2447 dest: ValueId(2),
2448 callee: ValueId(0),
2449 args: vec![ValueId(1)],
2450 },
2451 NativeStmt::Literal {
2452 dest: ValueId(3),
2453 literal: NativeLiteral::Int("2".to_string()),
2454 },
2455 NativeStmt::ClosureCall {
2456 dest: ValueId(4),
2457 callee: ValueId(2),
2458 args: vec![ValueId(3)],
2459 },
2460 ]
2461 );
2462 }
2463
2464 #[test]
2465 fn lowers_multiple_bindings() {
2466 let inc = binding(
2467 "inc",
2468 lambda(
2469 "x",
2470 apply(
2471 apply(primitive(typed_ir::PrimitiveOp::IntAdd), var("x")),
2472 unknown_lit(typed_ir::Lit::Int("1".to_string())),
2473 ),
2474 ),
2475 );
2476 let twice = binding(
2477 "twice",
2478 lambda("x", apply(var("inc"), apply(var("inc"), var("x")))),
2479 );
2480 let root = apply(
2481 var("twice"),
2482 unknown_lit(typed_ir::Lit::Int("40".to_string())),
2483 );
2484 let module = module_with_bindings_and_root(vec![inc, twice], root);
2485 let lowered = lower_module(&module).expect("lowered");
2486
2487 assert_eq!(
2488 lowered
2489 .functions
2490 .iter()
2491 .map(|function| function.name.as_str())
2492 .collect::<Vec<_>>(),
2493 vec!["inc", "inc#partial0", "twice", "twice#partial0"]
2494 );
2495 }
2496}