1use crate::types::*;
2use crate::ast::*;
3use crate::error::{KoreResult, KoreError};
4use std::collections::{HashMap, HashSet};
5
/// Result of [`monomorphize`]: the program's items after generic free functions have been
/// replaced by their concrete instantiations and async functions have been lowered.
pub struct MonomorphizedProgram {
    /// Items in emission order. Generic free-function templates are not included; their
    /// specialised copies (and any generated async state-machine items) are.
    pub items: Vec<TypedItem>,
}
10
/// Runs the monomorphization pass over a type-checked program.
///
/// Phase 1 sorts the input items:
/// - functions with generic parameters become templates in `ctx.generic_functions`;
/// - struct field types are recorded in `ctx.structs`;
/// - impl methods are flattened into standalone `<Type>_<method>` functions;
/// - everything else is kept as-is.
///
/// Phase 2 walks `ctx.concrete_items` by index. Async functions are lowered to state machines;
/// every other function is scanned so its calls to generic functions get instantiated.
///
/// # Errors
/// Propagates errors from `lower_async_fn` and `scan_function`.
pub fn monomorphize(program: &TypedProgram) -> KoreResult<MonomorphizedProgram> {
    let mut ctx = MonoContext::new();

    // Phase 1: classify and register every top-level item.
    for item in &program.items {
        match item {
            TypedItem::Function(func) => {
                // Generic functions are only templates; they are emitted later, once per
                // distinct set of type arguments, by `MonoContext::instantiate`.
                if !func.ast.generics.is_empty() {
                    ctx.generic_functions.insert(func.ast.name.clone(), func.clone());
                } else {
                    ctx.concrete_items.push(item.clone());
                }
            }
            TypedItem::Struct(s) => {
                // Record field types for later field-access typing. Fields whose type
                // fails to resolve are silently omitted from the map.
                let mut fields = HashMap::new();
                for f in &s.ast.fields {
                    if let Ok(ty) = resolve_ast_type(&f.ty) {
                        fields.insert(f.name.clone(), ty);
                    }
                }
                ctx.structs.insert(s.ast.name.clone(), fields);
                ctx.concrete_items.push(item.clone());
            }
            TypedItem::Impl(imp) => {
                // Only impls on a plain named type are handled; any other target type
                // skips the whole impl (including its methods).
                let type_name = match &imp.ast.target_type {
                    Type::Named { name, .. } => name.clone(),
                    _ => continue,
                };

                let target_ty = resolve_ast_type(&imp.ast.target_type).unwrap_or(ResolvedType::Unknown);

                // Trait impls are keyed by the `type_to_string` rendering so that
                // `infer_type_args` can check generic bounds with the same key.
                if let Some(trait_name) = &imp.ast.trait_name {
                    let type_name_str = type_to_string(&target_ty);
                    ctx.trait_impls.insert((trait_name.clone(), type_name_str));
                }

                for method in &imp.ast.methods {
                    // Each method becomes a free function named `<Type>_<method>`.
                    let mangled_name = format!("{}_{}", type_name, method.name);

                    let mut standalone_fn = method.clone();
                    standalone_fn.name = mangled_name.clone();

                    // A parameter literally named `self` takes the impl's target type.
                    let mut params = Vec::new();
                    for p in &method.params {
                        if p.name == "self" {
                            params.push(target_ty.clone());
                        } else {
                            params.push(resolve_ast_type(&p.ty).unwrap_or(ResolvedType::Unknown));
                        }
                    }
                    let ret = method.return_type.as_ref()
                        .map(|t| resolve_ast_type(t).unwrap_or(ResolvedType::Unknown))
                        .unwrap_or(ResolvedType::Unit);

                    // NOTE(review): the method's effects are reset to an empty set here, so
                    // the async check in phase 2 never fires for impl methods. Confirm
                    // whether async methods are possible in Kore.
                    let method_ty = ResolvedType::Function {
                        params,
                        ret: Box::new(ret),
                        effects: crate::effects::EffectSet::new(),
                    };

                    // NOTE(review): methods with their own generic parameters are pushed as
                    // concrete items, not registered as templates. Confirm that is intended.
                    let typed_method = TypedFunction {
                        ast: standalone_fn,
                        resolved_type: method_ty,
                        effects: crate::effects::EffectSet::new(),
                    };

                    ctx.methods.entry(type_name.clone()).or_default().insert(method.name.clone(), mangled_name.clone());
                    ctx.concrete_items.push(TypedItem::Function(typed_method));
                }
            }
            _ => {
                ctx.concrete_items.push(item.clone());
            }
        }
    }

    // Phase 2: process items by index. This is deliberately not an iterator loop:
    // `instantiate` and `lower_async_fn` append new items while we walk, and those new
    // items must be processed too.
    let mut i = 0;
    while i < ctx.concrete_items.len() {
        let item = ctx.concrete_items[i].clone();
        match item {
            TypedItem::Function(func) => {
                if func.effects.effects.contains(&crate::effects::Effect::Async) {
                    // Replaces the async fn with a synchronous entry fn; the state-machine
                    // struct and poll fn are appended and get scanned later in this loop.
                    let entry_fn = lower_async_fn(&mut ctx, &func)?;

                    ctx.concrete_items[i] = TypedItem::Function(entry_fn);
                } else {
                    let new_func = scan_function(&mut ctx, &func)?;
                    ctx.concrete_items[i] = TypedItem::Function(new_func);
                }
            }
            _ => {}
        }
        i += 1;
    }

    Ok(MonomorphizedProgram { items: ctx.concrete_items })
}
123
/// Mutable state shared by the whole monomorphization pass.
struct MonoContext {
    /// Generic function templates keyed by source name; never emitted directly.
    generic_functions: HashMap<String, TypedFunction>,
    /// Output items in emission order. The list grows while phase 2 walks it: new
    /// instantiations and async state-machine items are appended.
    concrete_items: Vec<TypedItem>,
    /// Mangled names already instantiated. Key and value are both the mangled name.
    instantiated: HashMap<String, String>,
    /// Type name -> (method name -> mangled standalone function name `<Type>_<method>`).
    methods: HashMap<String, HashMap<String, String>>,
    /// Struct name -> (field name -> resolved field type).
    structs: HashMap<String, HashMap<String, ResolvedType>>,
    /// `(trait name, type_to_string(type))` pairs for which an impl exists.
    trait_impls: HashSet<(String, String)>,
}
135
136impl MonoContext {
137 fn new() -> Self {
138 Self {
139 generic_functions: HashMap::new(),
140 concrete_items: Vec::new(),
141 instantiated: HashMap::new(),
142 methods: HashMap::new(),
143 structs: HashMap::new(),
144 trait_impls: HashSet::new(),
145 }
146 }
147
148 fn instantiate(&mut self, name: &str, type_args: &[ResolvedType]) -> KoreResult<String> {
149 let mangled_name = format!("{}_{}", name, mangle_types(type_args));
150
151 if self.instantiated.contains_key(&mangled_name) {
152 return Ok(mangled_name);
153 }
154
155 let generic_func = self.generic_functions.get(name)
156 .ok_or_else(|| KoreError::type_error(format!("Generic function {} not found", name), crate::span::Span::new(0,0)))?
157 .clone();
158
159 if generic_func.ast.generics.len() != type_args.len() {
160 return Err(KoreError::type_error(format!("Generic arg count mismatch for {}: expected {}, got {}", name, generic_func.ast.generics.len(), type_args.len()), generic_func.ast.span));
161 }
162
163 let mut mapping = HashMap::new();
164 for (i, param) in generic_func.ast.generics.iter().enumerate() {
165 mapping.insert(param.name.clone(), type_args[i].clone());
166 }
167
168 let mut new_func = generic_func.clone();
169 new_func.ast.name = mangled_name.clone();
170 new_func.ast.generics.clear();
171
172 if let ResolvedType::Function { params, ret, .. } = &mut new_func.resolved_type {
173 for p in params {
174 *p = substitute_type(p, &mapping);
175 }
176 *ret = Box::new(substitute_type(&ret, &mapping));
177 }
178
179 self.instantiated.insert(mangled_name.clone(), mangled_name.clone());
180
181 substitute_ast_types(&mut new_func.ast, &mapping);
182 self.concrete_items.push(TypedItem::Function(new_func));
183
184 Ok(mangled_name)
185 }
186}
187
188fn type_to_string(ty: &ResolvedType) -> String {
189 match ty {
190 ResolvedType::Int(_) => "Int".to_string(),
191 ResolvedType::Float(_) => "Float".to_string(),
192 ResolvedType::String => "String".to_string(),
193 ResolvedType::Bool => "Bool".to_string(),
194 ResolvedType::Unit => "Unit".to_string(),
195 ResolvedType::Struct(n, _) => n.clone(),
196 ResolvedType::Enum(n, _) => n.clone(),
197 ResolvedType::Tuple(ts) => format!("({})", ts.iter().map(type_to_string).collect::<Vec<_>>().join(", ")),
198 _ => "Any".to_string(),
199 }
200}
201
202fn mangle_types(types: &[ResolvedType]) -> String {
203 types.iter().map(type_to_string).collect::<Vec<_>>().join("_")
204}
205
206fn resolve_ast_type(ty: &Type) -> KoreResult<ResolvedType> {
207 crate::types::resolve_type(ty)
208}
209
210fn unify(
213 param_type: &ResolvedType,
214 arg_type: &ResolvedType,
215 bindings: &mut HashMap<String, ResolvedType>,
216) {
217 match (param_type, arg_type) {
218 (ResolvedType::Generic(name), concrete) => {
220 if let Some(existing) = bindings.get(name) {
221 let _ = existing;
223 } else {
224 bindings.insert(name.clone(), concrete.clone());
225 }
226 }
227
228 (ResolvedType::Function { params: p_params, ret: p_ret, .. },
230 ResolvedType::Function { params: a_params, ret: a_ret, .. }) => {
231 for (pp, ap) in p_params.iter().zip(a_params.iter()) {
233 unify(pp, ap, bindings);
234 }
235 unify(p_ret, a_ret, bindings);
237 }
238
239 (ResolvedType::Array(p_inner, _), ResolvedType::Array(a_inner, _)) => {
241 unify(p_inner, a_inner, bindings);
242 }
243
244 (ResolvedType::Tuple(p_elems), ResolvedType::Tuple(a_elems)) => {
246 for (pe, ae) in p_elems.iter().zip(a_elems.iter()) {
247 unify(pe, ae, bindings);
248 }
249 }
250
251 _ => {}
253 }
254}
255
256fn infer_type_args(
258 ctx: &MonoContext,
259 generic_func: &TypedFunction,
260 arg_types: &[ResolvedType],
261) -> KoreResult<Vec<ResolvedType>> {
262 let mut bindings: HashMap<String, ResolvedType> = HashMap::new();
263
264 let param_types: Vec<ResolvedType> = if let ResolvedType::Function { params, .. } = &generic_func.resolved_type {
266 params.clone()
267 } else {
268 generic_func.ast.params.iter()
270 .map(|p| resolve_ast_type(&p.ty).unwrap_or(ResolvedType::Unknown))
271 .collect()
272 };
273
274 for (param_ty, arg_ty) in param_types.iter().zip(arg_types.iter()) {
276 unify(param_ty, arg_ty, &mut bindings);
277 }
278
279 let mut inferred = Vec::new();
281 for generic in &generic_func.ast.generics {
282 if let Some(ty) = bindings.get(&generic.name) {
283 for bound in &generic.bounds {
285 let type_name = type_to_string(ty);
286 if !ctx.trait_impls.contains(&(bound.trait_name.clone(), type_name.clone())) {
287 return Err(KoreError::type_error(
288 format!("Type '{}' does not satisfy bound '{}'", type_name, bound.trait_name),
289 generic.span
290 ));
291 }
292 }
293 inferred.push(ty.clone());
294 } else {
295 inferred.push(ResolvedType::Unknown);
297 }
298 }
299
300 Ok(inferred)
301}
302
303fn substitute_type(ty: &ResolvedType, mapping: &HashMap<String, ResolvedType>) -> ResolvedType {
304 match ty {
305 ResolvedType::Generic(name) => mapping.get(name).cloned().unwrap_or(ty.clone()),
306 ResolvedType::Function { params, ret, effects } => {
307 ResolvedType::Function {
308 params: params.iter().map(|p| substitute_type(p, mapping)).collect(),
309 ret: Box::new(substitute_type(ret, mapping)),
310 effects: effects.clone()
311 }
312 }
313 ResolvedType::Array(inner, n) => ResolvedType::Array(Box::new(substitute_type(inner, mapping)), *n),
314 _ => ty.clone()
315 }
316}
317
318fn substitute_ast_types(func: &mut Function, mapping: &HashMap<String, ResolvedType>) {
319 for param in &mut func.params {
321 substitute_type_ast(&mut param.ty, mapping);
322 }
323
324 if let Some(ret) = &mut func.return_type {
326 substitute_type_ast(ret, mapping);
327 }
328
329 substitute_block(&mut func.body, mapping);
331}
332
333fn substitute_block(block: &mut Block, mapping: &HashMap<String, ResolvedType>) {
334 for stmt in &mut block.stmts {
335 substitute_stmt(stmt, mapping);
336 }
337}
338
339fn substitute_stmt(stmt: &mut Stmt, mapping: &HashMap<String, ResolvedType>) {
340 match stmt {
341 Stmt::Let { ty, value, .. } => {
342 if let Some(t) = ty {
343 substitute_type_ast(t, mapping);
344 }
345 if let Some(v) = value {
346 substitute_expr(v, mapping);
347 }
348 }
349 Stmt::Expr(e) => substitute_expr(e, mapping),
350 Stmt::Return(Some(e), _) => substitute_expr(e, mapping),
351 Stmt::For { iter, body, .. } => {
352 substitute_expr(iter, mapping);
353 substitute_block(body, mapping);
354 }
355 Stmt::While { condition, body, .. } => {
356 substitute_expr(condition, mapping);
357 substitute_block(body, mapping);
358 }
359 _ => {}
360 }
361}
362
363fn substitute_expr(expr: &mut Expr, mapping: &HashMap<String, ResolvedType>) {
364 match expr {
365 Expr::Cast { value, target, .. } => {
366 substitute_expr(value, mapping);
367 substitute_type_ast(target, mapping);
368 }
369 Expr::Binary { left, right, .. } => {
370 substitute_expr(left, mapping);
371 substitute_expr(right, mapping);
372 }
373 Expr::Unary { operand, .. } => substitute_expr(operand, mapping),
374 Expr::Call { callee, args, .. } => {
375 substitute_expr(callee, mapping);
376 for arg in args {
377 substitute_expr(&mut arg.value, mapping);
378 }
379 }
380 Expr::MethodCall { receiver, args, .. } => {
381 substitute_expr(receiver, mapping);
382 for arg in args {
383 substitute_expr(&mut arg.value, mapping);
384 }
385 }
386 Expr::Field { object, .. } => {
387 substitute_expr(object, mapping);
388 }
389 Expr::Index { object, index, .. } => {
390 substitute_expr(object, mapping);
391 substitute_expr(index, mapping);
392 }
393 Expr::Struct { fields, .. } => {
394 for (_, v) in fields {
395 substitute_expr(v, mapping);
396 }
397 }
398 Expr::Array(items, _) => {
399 for item in items {
400 substitute_expr(item, mapping);
401 }
402 }
403 Expr::Tuple(items, _) => {
404 for item in items {
405 substitute_expr(item, mapping);
406 }
407 }
408 Expr::Block(b, _) => substitute_block(b, mapping),
409 Expr::If { condition, then_branch, else_branch, .. } => {
410 substitute_expr(condition, mapping);
411 substitute_block(then_branch, mapping);
412 if let Some(br) = else_branch {
413 match br.as_mut() {
414 ElseBranch::Else(b) => substitute_block(b, mapping),
415 ElseBranch::ElseIf(c, t, _) => { substitute_expr(c, mapping);
417 substitute_block(t, mapping);
418 }
419 }
420 }
421 }
422 Expr::Match { scrutinee, arms, .. } => {
423 substitute_expr(scrutinee, mapping);
424 for arm in arms {
425 substitute_expr(&mut arm.body, mapping);
426 }
427 }
428 Expr::Lambda { params, body, return_type, .. } => {
429 for p in params {
430 substitute_type_ast(&mut p.ty, mapping);
431 }
432 if let Some(ret) = return_type {
433 substitute_type_ast(ret, mapping);
434 }
435 substitute_expr(body, mapping);
436 }
437 Expr::Await(inner, _) => {
438 substitute_expr(inner, mapping);
439 }
440 _ => {}
441 }
442}
443
444fn substitute_type_ast(ty: &mut Type, mapping: &HashMap<String, ResolvedType>) {
445 match ty {
446 Type::Named { name, generics, .. } => {
447 if let Some(concrete) = mapping.get(name) {
448 *ty = resolved_to_ast_type(concrete, ty.span());
449 } else {
450 for g in generics {
451 substitute_type_ast(g, mapping);
452 }
453 }
454 }
455 Type::Tuple(types, _) => {
456 for t in types {
457 substitute_type_ast(t, mapping);
458 }
459 }
460 Type::Function { params, return_type, .. } => {
461 for p in params {
462 substitute_type_ast(p, mapping);
463 }
464 substitute_type_ast(return_type, mapping);
465 }
466 Type::Array(inner, _, _) => {
467 substitute_type_ast(inner, mapping);
468 }
469 Type::Slice(inner, _) => {
470 substitute_type_ast(inner, mapping);
471 }
472 _ => {}
473 }
474}
475
476fn resolved_to_ast_type(res: &ResolvedType, span: crate::span::Span) -> Type {
477 match res {
478 ResolvedType::Int(_) => Type::Named { name: "Int".into(), generics: vec![], span },
479 ResolvedType::Float(_) => Type::Named { name: "Float".into(), generics: vec![], span },
480 ResolvedType::Bool => Type::Named { name: "Bool".into(), generics: vec![], span },
481 ResolvedType::String => Type::Named { name: "String".into(), generics: vec![], span },
482 ResolvedType::Unit => Type::Unit(span),
483 ResolvedType::Struct(n, _) => Type::Named { name: n.clone(), generics: vec![], span },
484 _ => Type::Named { name: "Any".into(), generics: vec![], span }, }
486}
487
488struct MonoTypeEnv {
489 scopes: Vec<HashMap<String, ResolvedType>>,
490}
491
492impl MonoTypeEnv {
493 fn new() -> Self {
494 Self { scopes: vec![HashMap::new()] }
495 }
496 fn push(&mut self) { self.scopes.push(HashMap::new()); }
497 fn pop(&mut self) { self.scopes.pop(); }
498
499 fn define(&mut self, name: String, ty: ResolvedType) {
500 if let Some(s) = self.scopes.last_mut() {
501 s.insert(name, ty);
502 }
503 }
504
505 fn get(&self, name: &str) -> ResolvedType {
506 for s in self.scopes.iter().rev() {
507 if let Some(t) = s.get(name) { return t.clone(); }
508 }
509 ResolvedType::Unknown
510 }
511}
512
513fn lower_async_fn(ctx: &mut MonoContext, func: &TypedFunction) -> KoreResult<TypedFunction> {
514 let state_machine_name = format!("{}_Future", func.ast.name);
515
516 let mut fields = HashMap::new();
519 fields.insert("state".to_string(), ResolvedType::Int(IntSize::I64));
520
521 for param in &func.ast.params {
523 fields.insert(param.name.clone(), resolve_ast_type(¶m.ty).unwrap_or(ResolvedType::Unknown));
524 }
525
526 let locals = collect_locals(&func.ast.body);
528 for (name, ty) in locals {
529 fields.entry(name).or_insert(ty);
530 }
531
532 let _struct_ty = ResolvedType::Struct(state_machine_name.clone(), fields.clone());
533
534 ctx.structs.insert(state_machine_name.clone(), fields.clone());
536
537 let struct_def = TypedItem::Struct(TypedStruct {
540 ast: Struct {
541 name: state_machine_name.clone(),
542 generics: vec![],
543 fields: fields.iter().map(|(n, t)| Field {
544 name: n.clone(),
545 ty: resolved_to_ast_type(t, func.ast.span),
546 visibility: Visibility::Public,
547 default: None,
548 weak: false,
549 span: func.ast.span
550 }).collect(),
551 visibility: Visibility::Public,
552 span: func.ast.span,
553 },
554 field_types: fields.clone(),
555 });
556 ctx.concrete_items.push(struct_def);
557
558 let poll_name = format!("{}_poll", state_machine_name);
561
562 let self_type = ResolvedType::Struct(state_machine_name.clone(), fields.clone());
564 let self_param = Param {
565 name: "self".to_string(),
566 ty: resolved_to_ast_type(&self_type, func.ast.span),
567 mutable: true,
568 default: None,
569 span: func.ast.span,
570 };
571
572 let await_points = collect_await_points(&func.ast.body);
576
577 for (i, _) in await_points.iter().enumerate() {
579 let field_name = format!("_await_{}", i);
580 fields.insert(field_name, ResolvedType::Unknown);
582
583 let res_name = format!("_await_{}_result", i);
585 fields.insert(res_name, ResolvedType::Unknown);
586 }
587
588 ctx.structs.insert(state_machine_name.clone(), fields.clone());
590
591 let mut arms = Vec::new();
593
594 if await_points.is_empty() {
595 let mut rewritten_body = func.ast.body.clone();
597 rewrite_access_to_self(&mut rewritten_body, &fields);
598
599 let body_with_ready = wrap_return_in_poll_ready(rewritten_body, func.ast.span);
601
602 let arm0 = MatchArm {
603 pattern: Pattern::Literal(Expr::Int(0, func.ast.span)),
604 guard: None,
605 body: body_with_ready,
606 span: func.ast.span,
607 };
608 arms.push(arm0);
609 } else {
610 let segments = split_at_awaits(&func.ast.body, &await_points);
612
613 for (state_idx, segment) in segments.iter().enumerate() {
614 let arm = generate_state_arm(
615 state_idx,
616 segment,
617 &await_points,
618 &fields,
619 &state_machine_name,
620 func.ast.span,
621 );
622 arms.push(arm);
623 }
624 }
625
626 let arm_wild = MatchArm {
628 pattern: Pattern::Wildcard(func.ast.span),
629 guard: None,
630 body: Expr::Call {
631 callee: Box::new(Expr::Ident("panic".to_string(), func.ast.span)),
632 args: vec![CallArg {
633 name: None,
634 value: Expr::String("polled after completion".to_string(), func.ast.span),
635 span: func.ast.span,
636 }],
637 span: func.ast.span,
638 },
639 span: func.ast.span,
640 };
641 arms.push(arm_wild);
642
643 let mut poll_body = Block { stmts: vec![], span: func.ast.span };
645
646 let match_expr = Expr::Match {
647 scrutinee: Box::new(Expr::Field {
648 object: Box::new(Expr::Ident("self".to_string(), func.ast.span)),
649 field: "state".to_string(),
650 span: func.ast.span
651 }),
652 arms,
653 span: func.ast.span,
654 };
655
656 poll_body.stmts.push(Stmt::Expr(match_expr));
657
658 let poll_fn = TypedItem::Function(TypedFunction {
659 ast: Function {
660 name: poll_name.clone(),
661 generics: vec![],
662 params: vec![self_param],
663 return_type: None, effects: vec![],
665 body: poll_body,
666 visibility: Visibility::Public,
667 span: func.ast.span,
668 },
669 resolved_type: ResolvedType::Function {
670 params: vec![self_type],
671 ret: Box::new(ResolvedType::Unit), effects: crate::effects::EffectSet::new(),
673 },
674 effects: crate::effects::EffectSet::new(),
675 });
676 ctx.concrete_items.push(poll_fn);
677
678 let mut entry_fn = func.clone();
681
682 let mut init_fields = Vec::new();
684 init_fields.push(("state".to_string(), Expr::Int(0, func.ast.span)));
685 for param in &func.ast.params {
686 init_fields.push((param.name.clone(), Expr::Ident(param.name.clone(), func.ast.span)));
687 }
688
689 for (i, _) in await_points.iter().enumerate() {
691 init_fields.push((format!("_await_{}", i), Expr::None(func.ast.span)));
692 init_fields.push((format!("_await_{}_result", i), Expr::None(func.ast.span)));
693 }
694
695 let captured_locals = collect_locals(&func.ast.body);
697 for (name, _) in captured_locals {
698 if func.ast.params.iter().any(|p| p.name == name) { continue; }
700 init_fields.push((name, Expr::None(func.ast.span)));
701 }
702
703 let body_expr = Expr::Struct {
704 name: state_machine_name.clone(),
705 fields: init_fields,
706 span: func.ast.span,
707 };
708
709 entry_fn.ast.body = Block {
710 stmts: vec![Stmt::Return(Some(body_expr), func.ast.span)],
711 span: func.ast.span,
712 };
713
714 entry_fn.resolved_type = ResolvedType::Function {
718 params: if let ResolvedType::Function{params, ..} = &func.resolved_type { params.clone() } else { vec![] },
719 ret: Box::new(ResolvedType::Struct(state_machine_name, fields)),
720 effects: crate::effects::EffectSet::new(), };
722
723 entry_fn.effects.effects.remove(&crate::effects::Effect::Async);
725 entry_fn.ast.effects.retain(|e| *e != crate::effects::Effect::Async);
726
727 Ok(entry_fn)
728}
729
730fn rewrite_access_to_self(block: &mut Block, fields: &HashMap<String, ResolvedType>) {
731 for stmt in &mut block.stmts {
732 rewrite_stmt(stmt, fields);
733 }
734}
735
736fn rewrite_stmt(stmt: &mut Stmt, fields: &HashMap<String, ResolvedType>) {
737 match stmt {
739 Stmt::Expr(e) => rewrite_expr(e, fields),
740 Stmt::Return(Some(e), _) => rewrite_expr(e, fields),
741 Stmt::Let { value: Some(e), .. } => rewrite_expr(e, fields),
742 Stmt::For { iter, body, .. } => {
743 rewrite_expr(iter, fields);
744 rewrite_access_to_self(body, fields);
745 }
746 Stmt::While { condition, body, .. } => {
747 rewrite_expr(condition, fields);
748 rewrite_access_to_self(body, fields);
749 }
750 _ => {}
751 }
752
753 let transform = if let Stmt::Let { pattern: Pattern::Binding { name, .. }, value: Some(e), span, .. } = stmt {
755 if fields.contains_key(name) {
756 Some((name.clone(), e.clone(), *span))
757 } else { None }
758 } else { None };
759
760 if let Some((name, val, span)) = transform {
761 *stmt = Stmt::Expr(Expr::Assign {
762 target: Box::new(Expr::Field {
763 object: Box::new(Expr::Ident("self".to_string(), span)),
764 field: name,
765 span,
766 }),
767 value: Box::new(val),
768 span,
769 });
770 }
771}
772
773fn rewrite_expr(expr: &mut Expr, fields: &HashMap<String, ResolvedType>) {
774 match expr {
775 Expr::Ident(name, span) => {
776 if fields.contains_key(name) {
777 *expr = Expr::Field {
779 object: Box::new(Expr::Ident("self".to_string(), *span)),
780 field: name.clone(),
781 span: *span,
782 };
783 }
784 }
785 Expr::Binary { left, right, .. } => {
786 rewrite_expr(left, fields);
787 rewrite_expr(right, fields);
788 }
789 Expr::Call { callee, args, .. } => {
790 rewrite_expr(callee, fields);
791 for arg in args {
792 rewrite_expr(&mut arg.value, fields);
793 }
794 }
795 Expr::Field { object, .. } => rewrite_expr(object, fields),
796 Expr::Await(inner, _) => rewrite_expr(inner, fields),
797 Expr::Block(b, _) => rewrite_access_to_self(b, fields),
798 _ => {}
800 }
801}
802
/// One `.await` found while walking an async function body.
#[derive(Clone, Debug)]
struct AwaitPoint {
    /// The operand of `.await`, i.e. the future expression, without the `.await` itself.
    awaited_expr: Expr,
    /// The name bound by `let <name> = fut.await;`. Set only when the await is the entire
    /// initializer of a simple binding pattern; `None` otherwise.
    result_binding: Option<String>,
    /// Discovery order. Used to name the `_await_{index}` and `_await_{index}_result` fields.
    index: usize,
}
815
816fn collect_await_points(block: &Block) -> Vec<AwaitPoint> {
818 let mut points = Vec::new();
819 collect_awaits_from_block(block, &mut points);
820 points
821}
822
823fn collect_awaits_from_block(block: &Block, points: &mut Vec<AwaitPoint>) {
824 for stmt in &block.stmts {
825 collect_awaits_from_stmt(stmt, points);
826 }
827}
828
829fn collect_awaits_from_stmt(stmt: &Stmt, points: &mut Vec<AwaitPoint>) {
830 match stmt {
831 Stmt::Let { pattern, value, .. } => {
832 let name = match pattern {
834 Pattern::Binding { name: n, .. } => Some(n.clone()),
835 _ => None,
836 };
837
838 if let Some(expr) = value {
840 if let Expr::Await(inner, _) = expr {
841 points.push(AwaitPoint {
842 awaited_expr: (**inner).clone(),
843 result_binding: name,
844 index: points.len(),
845 });
846 } else {
847 collect_awaits_from_expr(expr, points);
848 }
849 }
850 }
851 Stmt::Expr(expr) => {
852 if let Expr::Await(inner, _) = expr {
853 points.push(AwaitPoint {
854 awaited_expr: (**inner).clone(),
855 result_binding: None,
856 index: points.len(),
857 });
858 } else {
859 collect_awaits_from_expr(expr, points);
860 }
861 }
862 Stmt::Return(Some(expr), _) => {
863 if let Expr::Await(inner, _) = expr {
864 points.push(AwaitPoint {
865 awaited_expr: (**inner).clone(),
866 result_binding: None, index: points.len(),
868 });
869 } else {
870 collect_awaits_from_expr(expr, points);
871 }
872 }
873 Stmt::While { body, .. } | Stmt::Loop { body, .. } => {
876 collect_awaits_from_block(body, points);
877 }
878 Stmt::For { body, .. } => {
879 collect_awaits_from_block(body, points);
880 }
881 _ => {}
882 }
883}
884
885fn collect_awaits_from_expr(expr: &Expr, points: &mut Vec<AwaitPoint>) {
886 match expr {
887 Expr::Await(inner, _) => {
888 points.push(AwaitPoint {
889 awaited_expr: (**inner).clone(),
890 result_binding: None,
891 index: points.len(),
892 });
893 }
894 Expr::Binary { left, right, .. } => {
895 collect_awaits_from_expr(left, points);
896 collect_awaits_from_expr(right, points);
897 }
898 Expr::Call { callee, args, .. } => {
899 collect_awaits_from_expr(callee, points);
900 for arg in args {
901 collect_awaits_from_expr(&arg.value, points);
902 }
903 }
904 Expr::Block(block, _) => collect_awaits_from_block(block, points),
905 Expr::If { then_branch, else_branch, .. } => {
906 collect_awaits_from_block(then_branch, points);
907 if let Some(else_b) = else_branch {
908 match else_b.as_ref() {
909 ElseBranch::Else(b) => collect_awaits_from_block(b, points),
910 ElseBranch::ElseIf(_, then_b, _) => collect_awaits_from_block(then_b, points),
911 }
912 }
913 }
914 _ => {}
915 }
916}
917
/// A run of top-level statements executed by one state of the async state machine.
#[derive(Clone)]
struct CodeSegment {
    /// Statements this state runs before it suspends or finishes (await statements excluded).
    stmts_before: Vec<Stmt>,
    /// The await that ends this segment, or `None` for the final segment.
    await_point: Option<AwaitPoint>,
    /// True when the statement that ends the segment is a `return`. That is either the await
    /// statement itself or, for the final segment, its last statement.
    ends_with_return: bool,
}
928
929fn split_at_awaits(block: &Block, await_points: &[AwaitPoint]) -> Vec<CodeSegment> {
931 let mut segments = Vec::new();
932 let mut current_stmts = Vec::new();
933 let mut await_idx = 0;
934
935 for stmt in &block.stmts {
936 let contains_await = match stmt {
938 Stmt::Let { value: Some(Expr::Await(_, _)), .. } => true,
939 Stmt::Expr(Expr::Await(_, _)) => true,
940 Stmt::Return(Some(Expr::Await(_, _)), _) => true,
941 _ => false,
942 };
943
944 if contains_await && await_idx < await_points.len() {
945 segments.push(CodeSegment {
947 stmts_before: current_stmts.clone(),
948 await_point: Some(await_points[await_idx].clone()),
949 ends_with_return: matches!(stmt, Stmt::Return(_, _)),
950 });
951 current_stmts.clear();
952 await_idx += 1;
953 } else {
954 current_stmts.push(stmt.clone());
955 }
956 }
957
958 if !current_stmts.is_empty() || segments.is_empty() {
960 let ends_with_return = current_stmts.last()
961 .map(|s| matches!(s, Stmt::Return(_, _)))
962 .unwrap_or(false);
963 segments.push(CodeSegment {
964 stmts_before: current_stmts,
965 await_point: None,
966 ends_with_return,
967 });
968 }
969
970 segments
971}
972
973fn generate_state_arm(
975 state_idx: usize,
976 segment: &CodeSegment,
977 await_points: &[AwaitPoint],
978 fields: &HashMap<String, ResolvedType>,
979 _state_machine_name: &str,
980 span: crate::span::Span,
981) -> MatchArm {
982 let mut body_stmts = Vec::new();
983
984 if state_idx > 0 && state_idx <= await_points.len() {
986 let prev_await = &await_points[state_idx - 1];
987 let poll_field = format!("_await_{}", prev_await.index);
988 let res_field = format!("_await_{}_result", prev_await.index);
989
990 let poll_call = Expr::MethodCall {
993 receiver: Box::new(Expr::Field {
994 object: Box::new(Expr::Ident("self".to_string(), span)),
995 field: poll_field,
996 span,
997 }),
998 method: "poll".to_string(),
999 args: vec![],
1000 span,
1001 };
1002
1003 let pending_arm = MatchArm {
1005 pattern: Pattern::Variant {
1006 enum_name: Some("Poll".to_string()),
1007 variant: "Pending".to_string(),
1008 fields: VariantPatternFields::Unit,
1009 span,
1010 },
1011 guard: None,
1012 body: Expr::Return(
1013 Some(Box::new(Expr::EnumVariant {
1014 enum_name: "Poll".to_string(),
1015 variant: "Pending".to_string(),
1016 fields: EnumVariantFields::Unit,
1017 span,
1018 })),
1019 span,
1020 ),
1021 span,
1022 };
1023
1024 let val_name = "val".to_string();
1027 let ready_arm = MatchArm {
1028 pattern: Pattern::Variant {
1029 enum_name: Some("Poll".to_string()),
1030 variant: "Ready".to_string(),
1031 fields: VariantPatternFields::Tuple(vec![
1032 Pattern::Binding { name: val_name.clone(), mutable: false, span }
1033 ]),
1034 span,
1035 },
1036 guard: None,
1037 body: Expr::Assign {
1038 target: Box::new(Expr::Field {
1039 object: Box::new(Expr::Ident("self".to_string(), span)),
1040 field: res_field.clone(),
1041 span,
1042 }),
1043 value: Box::new(Expr::Ident(val_name, span)),
1044 span,
1045 },
1046 span,
1047 };
1048
1049 let poll_match = Expr::Match {
1050 scrutinee: Box::new(poll_call),
1051 arms: vec![pending_arm, ready_arm],
1052 span,
1053 };
1054
1055 body_stmts.push(Stmt::Expr(poll_match));
1056
1057 if let Some(binding) = &prev_await.result_binding {
1061 if fields.contains_key(binding) {
1062 body_stmts.push(Stmt::Expr(Expr::Assign {
1064 target: Box::new(Expr::Field {
1065 object: Box::new(Expr::Ident("self".to_string(), span)),
1066 field: binding.clone(),
1067 span,
1068 }),
1069 value: Box::new(Expr::Field {
1070 object: Box::new(Expr::Ident("self".to_string(), span)),
1071 field: res_field,
1072 span,
1073 }),
1074 span,
1075 }));
1076 } else {
1077 body_stmts.push(Stmt::Let {
1078 pattern: Pattern::Binding { name: binding.clone(), mutable: false, span },
1079 ty: None,
1080 value: Some(Expr::Field {
1081 object: Box::new(Expr::Ident("self".to_string(), span)),
1082 field: res_field,
1083 span,
1084 }),
1085 span,
1086 });
1087 }
1088 }
1089 }
1090
1091 for stmt in &segment.stmts_before {
1093 let mut rewritten_stmt = stmt.clone();
1094 rewrite_stmt(&mut rewritten_stmt, fields);
1095 body_stmts.push(rewritten_stmt);
1096 }
1097
1098 if let Some(await_point) = &segment.await_point {
1100 let store_field = format!("_await_{}", await_point.index);
1102 let mut awaited_expr = await_point.awaited_expr.clone();
1103 rewrite_expr(&mut awaited_expr, fields);
1104
1105 body_stmts.push(Stmt::Expr(Expr::Assign {
1106 target: Box::new(Expr::Field {
1107 object: Box::new(Expr::Ident("self".to_string(), span)),
1108 field: store_field,
1109 span,
1110 }),
1111 value: Box::new(awaited_expr),
1112 span,
1113 }));
1114
1115 body_stmts.push(Stmt::Expr(Expr::Assign {
1117 target: Box::new(Expr::Field {
1118 object: Box::new(Expr::Ident("self".to_string(), span)),
1119 field: "state".to_string(),
1120 span,
1121 }),
1122 value: Box::new(Expr::Int((state_idx + 1) as i64, span)),
1123 span,
1124 }));
1125
1126 body_stmts.push(Stmt::Return(
1128 Some(Expr::EnumVariant {
1129 enum_name: "Poll".to_string(),
1130 variant: "Pending".to_string(),
1131 fields: EnumVariantFields::Unit,
1132 span,
1133 }),
1134 span,
1135 ));
1136 } else if segment.ends_with_return {
1137 } else if state_idx == await_points.len() {
1140 body_stmts.push(Stmt::Return(
1142 Some(Expr::EnumVariant {
1143 enum_name: "Poll".to_string(),
1144 variant: "Ready".to_string(),
1145 fields: EnumVariantFields::Tuple(vec![Expr::None(span)]),
1146 span,
1147 }),
1148 span,
1149 ));
1150 }
1151
1152 MatchArm {
1153 pattern: Pattern::Literal(Expr::Int(state_idx as i64, span)),
1154 guard: None,
1155 body: Expr::Block(Block { stmts: body_stmts, span }, span),
1156 span,
1157 }
1158}
1159
1160fn wrap_return_in_poll_ready(mut block: Block, span: crate::span::Span) -> Expr {
1162 for stmt in &mut block.stmts {
1163 wrap_stmt_returns(stmt, span);
1164 }
1165 Expr::Block(block, span)
1166}
1167
1168fn wrap_stmt_returns(stmt: &mut Stmt, span: crate::span::Span) {
1169 match stmt {
1170 Stmt::Return(Some(expr), _) => {
1171 let inner = std::mem::replace(expr, Expr::None(span));
1173 *expr = Expr::EnumVariant {
1174 enum_name: "Poll".to_string(),
1175 variant: "Ready".to_string(),
1176 fields: EnumVariantFields::Tuple(vec![inner]),
1177 span,
1178 };
1179 }
1180 Stmt::Return(None, s) => {
1181 *stmt = Stmt::Return(
1183 Some(Expr::EnumVariant {
1184 enum_name: "Poll".to_string(),
1185 variant: "Ready".to_string(),
1186 fields: EnumVariantFields::Tuple(vec![Expr::None(span)]),
1187 span,
1188 }),
1189 *s,
1190 );
1191 }
1192 Stmt::While { body, .. } | Stmt::Loop { body, .. } => {
1195 for s in &mut body.stmts {
1196 wrap_stmt_returns(s, span);
1197 }
1198 }
1199 Stmt::For { body, .. } => {
1200 for s in &mut body.stmts {
1201 wrap_stmt_returns(s, span);
1202 }
1203 }
1204 _ => {}
1205 }
1206}
1207
1208fn scan_function(ctx: &mut MonoContext, func: &TypedFunction) -> KoreResult<TypedFunction> {
1209 let mut new_func = func.clone();
1210 let mut env = MonoTypeEnv::new();
1211
1212 if let ResolvedType::Function { params, .. } = &func.resolved_type {
1213 for (i, p) in params.iter().enumerate() {
1214 if i < func.ast.params.len() {
1215 env.define(func.ast.params[i].name.clone(), p.clone());
1216 }
1217 }
1218 }
1219
1220 scan_block(ctx, &mut env, &mut new_func.ast.body)?;
1221 Ok(new_func)
1222}
1223
1224fn scan_block(ctx: &mut MonoContext, env: &mut MonoTypeEnv, block: &mut Block) -> KoreResult<()> {
1225 env.push();
1226 for stmt in &mut block.stmts {
1227 scan_stmt(ctx, env, stmt)?;
1228 }
1229 env.pop();
1230 Ok(())
1231}
1232
1233fn scan_stmt(ctx: &mut MonoContext, env: &mut MonoTypeEnv, stmt: &mut Stmt) -> KoreResult<()> {
1234 match stmt {
1235 Stmt::Expr(e) => { scan_expr(ctx, env, e)?; }
1236 Stmt::Return(Some(e), _) => { scan_expr(ctx, env, e)?; }
1237 Stmt::Let { pattern, value, .. } => {
1238 if let Some(val_expr) = value {
1240 let ty = scan_expr(ctx, env, val_expr)?;
1241 if let Pattern::Binding { name, .. } = pattern {
1243 env.define(name.clone(), ty);
1244 }
1245 }
1246 }
1247 Stmt::For { binding, iter, body, .. } => {
1248 let iter_ty = scan_expr(ctx, env, iter)?;
1249 let elem_ty = match iter_ty {
1250 ResolvedType::Array(inner, _) => *inner,
1251 _ => ResolvedType::Int(IntSize::I64),
1252 };
1253
1254 env.push();
1255 if let Pattern::Binding { name, .. } = binding {
1256 env.define(name.clone(), elem_ty);
1257 }
1258 scan_block(ctx, env, body)?;
1259 env.pop();
1260 }
1261 Stmt::While { condition, body, .. } => {
1262 scan_expr(ctx, env, condition)?;
1263 scan_block(ctx, env, body)?;
1264 }
1265 _ => {}
1266 }
1267 Ok(())
1268}
1269
1270fn scan_expr(ctx: &mut MonoContext, env: &mut MonoTypeEnv, expr: &mut Expr) -> KoreResult<ResolvedType> {
1271 match expr {
1272 Expr::Int(_, _) => Ok(ResolvedType::Int(IntSize::I64)),
1273 Expr::Float(_, _) => Ok(ResolvedType::Float(FloatSize::F64)),
1274 Expr::String(_, _) => Ok(ResolvedType::String),
1275 Expr::Bool(_, _) => Ok(ResolvedType::Bool),
1276 Expr::Ident(name, _) => Ok(env.get(name)),
1277 Expr::Struct { name, fields, .. } => {
1278 for (_, val) in fields {
1279 scan_expr(ctx, env, val)?;
1280 }
1281 Ok(ResolvedType::Struct(name.clone(), HashMap::new()))
1284 },
1285 Expr::Field { object, field, span: _ } => {
1286 let obj_ty = scan_expr(ctx, env, object)?;
1287 match obj_ty {
1288 ResolvedType::Struct(name, _) => {
1289 if let Some(fields) = ctx.structs.get(&name) {
1290 if let Some(ty) = fields.get(field) {
1291 return Ok(ty.clone());
1292 }
1293 }
1294 Ok(ResolvedType::Unknown)
1297 }
1298 _ => Ok(ResolvedType::Unknown),
1299 }
1300 },
1301 Expr::MethodCall { receiver, method, args, span } => {
1302 let receiver_ty = scan_expr(ctx, env, receiver)?;
1303
1304 let type_name = match &receiver_ty {
1305 ResolvedType::Struct(name, _) => name.clone(),
1306 ResolvedType::Int(_) => "Int".to_string(),
1307 ResolvedType::Float(_) => "Float".to_string(),
1308 ResolvedType::String => "String".to_string(),
1309 _ => {
1310 if let ResolvedType::Unknown = receiver_ty {
1311 return Ok(ResolvedType::Unknown);
1313 }
1314 format!("{:?}", receiver_ty)
1315 }
1316 };
1317
1318 let mangled_target = {
1319 let methods = ctx.methods.get(&type_name);
1320 if let Some(lookup) = methods {
1321 lookup.get(method).cloned()
1322 } else {
1323 None
1324 }
1325 };
1326
1327 if let Some(target_name) = mangled_target {
1328 let mut new_args = args.clone();
1329 new_args.insert(0, CallArg { name: None, value: *receiver.clone(), span: receiver.span() });
1330
1331 for arg in &mut new_args {
1332 scan_expr(ctx, env, &mut arg.value)?;
1333 }
1334
1335 *expr = Expr::Call {
1336 callee: Box::new(Expr::Ident(target_name, *span)), args: new_args,
1338 span: *span
1339 };
1340
1341 return Ok(ResolvedType::Unknown);
1344 }
1345
1346 Ok(ResolvedType::Unknown)
1347 }
1348 Expr::Call { callee, args, .. } => {
1349 if let Expr::Ident(name, _) = callee.as_ref() {
1350 if let Some(generic_func) = ctx.generic_functions.get(name).cloned() {
1351 let mut arg_types = Vec::new();
1353 for arg in args {
1354 arg_types.push(scan_expr(ctx, env, &mut arg.value)?);
1355 }
1356
1357 let inferred_type_args = infer_type_args(ctx, &generic_func, &arg_types)?;
1359
1360 let new_name = ctx.instantiate(name, &inferred_type_args)?;
1361 *callee = Box::new(Expr::Ident(new_name, callee.span()));
1362 return Ok(ResolvedType::Unknown);
1363 }
1364
1365 }
1368 for arg in args {
1369 scan_expr(ctx, env, &mut arg.value)?;
1370 }
1371 Ok(ResolvedType::Unknown)
1372 }
1373 Expr::Binary { left, right, .. } => {
1374 let t = scan_expr(ctx, env, left)?;
1375 scan_expr(ctx, env, right)?;
1376 Ok(t)
1377 }
1378 Expr::Assign { value, .. } => scan_expr(ctx, env, value),
1379 Expr::Block(b, _) => {
1380 scan_block(ctx, env, b)?;
1381 Ok(ResolvedType::Unknown)
1382 }
1383 Expr::If { condition, then_branch, else_branch, .. } => {
1384 scan_expr(ctx, env, condition)?;
1385 scan_block(ctx, env, then_branch)?;
1386 if let Some(b) = else_branch {
1387 match b.as_mut() {
1388 ElseBranch::Else(blk) => { scan_block(ctx, env, blk)?; }
1389 ElseBranch::ElseIf(_, _, _) => {}
1390 }
1391 }
1392 Ok(ResolvedType::Unknown)
1393 }
1394 Expr::Await(inner, _) => {
1395 scan_expr(ctx, env, inner)
1397 }
1398 _ => Ok(ResolvedType::Unknown),
1399 }
1400}
1401
1402fn collect_locals(block: &Block) -> HashMap<String, ResolvedType> {
1403 let mut locals = HashMap::new();
1404 collect_locals_recursive(block, &mut locals);
1405 locals
1406}
1407
1408fn collect_locals_recursive(block: &Block, locals: &mut HashMap<String, ResolvedType>) {
1409 for stmt in &block.stmts {
1410 match stmt {
1411 Stmt::Let { pattern, .. } => collect_from_pattern(pattern, locals),
1412 Stmt::For { body, .. } => collect_locals_recursive(body, locals),
1413 Stmt::While { body, .. } => collect_locals_recursive(body, locals),
1414 Stmt::Expr(Expr::Block(b, _)) => collect_locals_recursive(b, locals),
1415 Stmt::Expr(Expr::If { then_branch, else_branch, .. }) => {
1416 collect_locals_recursive(then_branch, locals);
1417 if let Some(b) = else_branch {
1418 collect_from_else(b, locals);
1419 }
1420 }
1421 _ => {}
1422 }
1423 }
1424}
1425
1426fn collect_from_else(branch: &ElseBranch, locals: &mut HashMap<String, ResolvedType>) {
1427 match branch {
1428 ElseBranch::Else(block) => collect_locals_recursive(block, locals),
1429 ElseBranch::ElseIf(_, block, next) => {
1430 collect_locals_recursive(block, locals);
1431 if let Some(n) = next {
1432 collect_from_else(n, locals);
1433 }
1434 }
1435 }
1436}
1437
1438fn collect_from_pattern(pattern: &Pattern, locals: &mut HashMap<String, ResolvedType>) {
1439 match pattern {
1440 Pattern::Binding { name, .. } => { locals.insert(name.clone(), ResolvedType::Unknown); },
1441 Pattern::Tuple(pats, _) => {
1442 for p in pats { collect_from_pattern(p, locals); }
1443 }
1444 Pattern::Slice { patterns, rest, .. } => {
1445 for p in patterns { collect_from_pattern(p, locals); }
1446 if let Some(r) = rest {
1447 locals.insert(r.clone(), ResolvedType::Unknown);
1448 }
1449 }
1450 Pattern::Struct { fields, .. } => {
1451 for (_, p) in fields { collect_from_pattern(p, locals); }
1452 }
1453 Pattern::Variant { fields, .. } => {
1454 match fields {
1455 VariantPatternFields::Tuple(pats) => { for p in pats { collect_from_pattern(p, locals); } },
1456 VariantPatternFields::Struct(pats) => { for (_, p) in pats { collect_from_pattern(p, locals); } },
1457 _ => {}
1458 }
1459 }
1460 Pattern::Or(pats, _) => {
1461 for p in pats { collect_from_pattern(p, locals); }
1462 }
1463 _ => {}
1464 }
1465}
1466