1use crate::types::*;
2use crate::ast::*;
3use crate::error::{KainResult, KainError};
4use std::collections::{HashMap, HashSet};
5
6pub struct MonomorphizedProgram {
8 pub items: Vec<TypedItem>,
9}
10
11pub fn monomorphize(program: &TypedProgram) -> KainResult<MonomorphizedProgram> {
12 let mut ctx = MonoContext::new();
13
14
15
16 for item in &program.items {
18 match item {
19 TypedItem::Function(func) => {
20
21 if !func.ast.generics.is_empty() {
22 ctx.generic_functions.insert(func.ast.name.clone(), func.clone());
23
24 } else {
25 ctx.concrete_items.push(item.clone());
26
27 }
28 }
29 TypedItem::Struct(s) => {
30 let mut fields = HashMap::new();
31 for f in &s.ast.fields {
32 if let Ok(ty) = resolve_ast_type(&f.ty) {
33 fields.insert(f.name.clone(), ty);
34 }
35 }
36 ctx.structs.insert(s.ast.name.clone(), fields);
37 ctx.concrete_items.push(item.clone());
38 }
39 TypedItem::Impl(imp) => {
40 let type_name = match &imp.ast.target_type {
43 Type::Named { name, .. } => name.clone(),
44 _ => continue, };
46
47 let target_ty = resolve_ast_type(&imp.ast.target_type).unwrap_or(ResolvedType::Unknown);
48
49 if let Some(trait_name) = &imp.ast.trait_name {
51 let type_name_str = type_to_string(&target_ty);
52 ctx.trait_impls.insert((trait_name.clone(), type_name_str));
53 }
54
55 for method in &imp.ast.methods {
56 let mangled_name = format!("{}_{}", type_name, method.name);
57
58 let mut standalone_fn = method.clone();
59 standalone_fn.name = mangled_name.clone();
60
61 let mut params = Vec::new();
63 for p in &method.params {
64 if p.name == "self" {
65 params.push(target_ty.clone());
66 } else {
67 params.push(resolve_ast_type(&p.ty).unwrap_or(ResolvedType::Unknown));
68 }
69 }
70 let ret = method.return_type.as_ref()
71 .map(|t| resolve_ast_type(t).unwrap_or(ResolvedType::Unknown))
72 .unwrap_or(ResolvedType::Unit);
73
74 let method_ty = ResolvedType::Function {
75 params,
76 ret: Box::new(ret),
77 effects: crate::effects::EffectSet::new(), };
79
80 let typed_method = TypedFunction {
81 ast: standalone_fn,
82 resolved_type: method_ty,
83 effects: crate::effects::EffectSet::new(),
84 };
85
86 ctx.methods.entry(type_name.clone()).or_default().insert(method.name.clone(), mangled_name.clone());
87 ctx.concrete_items.push(TypedItem::Function(typed_method));
88 }
89 }
90 _ => {
91 ctx.concrete_items.push(item.clone());
92 }
93 }
94 }
95
96 let mut i = 0;
98 while i < ctx.concrete_items.len() {
99 let item = ctx.concrete_items[i].clone();
100 match item {
101 TypedItem::Function(func) => {
102 if func.effects.effects.contains(&crate::effects::Effect::Async) {
104 let entry_fn = lower_async_fn(&mut ctx, &func)?;
108
109 ctx.concrete_items[i] = TypedItem::Function(entry_fn);
111 } else {
112 let new_func = scan_function(&mut ctx, &func)?;
113 ctx.concrete_items[i] = TypedItem::Function(new_func);
114 }
115 }
116 _ => {}
117 }
118 i += 1;
119 }
120
121 Ok(MonomorphizedProgram { items: ctx.concrete_items })
122}
123
124struct MonoContext {
125 generic_functions: HashMap<String, TypedFunction>,
126 concrete_items: Vec<TypedItem>,
127 instantiated: HashMap<String, String>,
128 methods: HashMap<String, HashMap<String, String>>,
130 structs: HashMap<String, HashMap<String, ResolvedType>>,
132 trait_impls: HashSet<(String, String)>,
134}
135
136impl MonoContext {
137 fn new() -> Self {
138 Self {
139 generic_functions: HashMap::new(),
140 concrete_items: Vec::new(),
141 instantiated: HashMap::new(),
142 methods: HashMap::new(),
143 structs: HashMap::new(),
144 trait_impls: HashSet::new(),
145 }
146 }
147
148 fn instantiate(&mut self, name: &str, type_args: &[ResolvedType]) -> KainResult<String> {
149 let mangled_name = format!("{}_{}", name, mangle_types(type_args));
150
151 if self.instantiated.contains_key(&mangled_name) {
152 return Ok(mangled_name);
153 }
154
155 let generic_func = self.generic_functions.get(name)
156 .ok_or_else(|| KainError::type_error(format!("Generic function {} not found", name), crate::span::Span::new(0,0)))?
157 .clone();
158
159 if generic_func.ast.generics.len() != type_args.len() {
160 return Err(KainError::type_error(format!("Generic arg count mismatch for {}: expected {}, got {}", name, generic_func.ast.generics.len(), type_args.len()), generic_func.ast.span));
161 }
162
163 let mut mapping = HashMap::new();
164 for (i, param) in generic_func.ast.generics.iter().enumerate() {
165 mapping.insert(param.name.clone(), type_args[i].clone());
166 }
167
168 let mut new_func = generic_func.clone();
169 new_func.ast.name = mangled_name.clone();
170 new_func.ast.generics.clear();
171
172 if let ResolvedType::Function { params, ret, .. } = &mut new_func.resolved_type {
173 for p in params {
174 *p = substitute_type(p, &mapping);
175 }
176 *ret = Box::new(substitute_type(&ret, &mapping));
177 }
178
179 self.instantiated.insert(mangled_name.clone(), mangled_name.clone());
180
181 substitute_ast_types(&mut new_func.ast, &mapping);
182 self.concrete_items.push(TypedItem::Function(new_func));
183
184 Ok(mangled_name)
185 }
186}
187
188fn type_to_string(ty: &ResolvedType) -> String {
189 match ty {
190 ResolvedType::Int(_) => "Int".to_string(),
191 ResolvedType::Float(_) => "Float".to_string(),
192 ResolvedType::String => "String".to_string(),
193 ResolvedType::Bool => "Bool".to_string(),
194 ResolvedType::Unit => "Unit".to_string(),
195 ResolvedType::Struct(n, _) => n.clone(),
196 ResolvedType::Enum(n, _) => n.clone(),
197 ResolvedType::Tuple(ts) => format!("({})", ts.iter().map(type_to_string).collect::<Vec<_>>().join(", ")),
198 _ => "Any".to_string(),
199 }
200}
201
202fn mangle_types(types: &[ResolvedType]) -> String {
203 types.iter().map(type_to_string).collect::<Vec<_>>().join("_")
204}
205
/// Resolves an AST type annotation into a [`ResolvedType`].
///
/// Thin alias for `crate::types::resolve_type`. Callers in this module typically fall back
/// to `ResolvedType::Unknown` on error.
fn resolve_ast_type(ty: &Type) -> KainResult<ResolvedType> {
    crate::types::resolve_type(ty)
}
209
210fn unify(
213 param_type: &ResolvedType,
214 arg_type: &ResolvedType,
215 bindings: &mut HashMap<String, ResolvedType>,
216) {
217 match (param_type, arg_type) {
218 (ResolvedType::Generic(name), concrete) => {
220 if let Some(existing) = bindings.get(name) {
221 let _ = existing;
223 } else {
224 bindings.insert(name.clone(), concrete.clone());
225 }
226 }
227
228 (ResolvedType::Function { params: p_params, ret: p_ret, .. },
230 ResolvedType::Function { params: a_params, ret: a_ret, .. }) => {
231 for (pp, ap) in p_params.iter().zip(a_params.iter()) {
233 unify(pp, ap, bindings);
234 }
235 unify(p_ret, a_ret, bindings);
237 }
238
239 (ResolvedType::Array(p_inner, _), ResolvedType::Array(a_inner, _)) => {
241 unify(p_inner, a_inner, bindings);
242 }
243
244 (ResolvedType::Tuple(p_elems), ResolvedType::Tuple(a_elems)) => {
246 for (pe, ae) in p_elems.iter().zip(a_elems.iter()) {
247 unify(pe, ae, bindings);
248 }
249 }
250
251 _ => {}
253 }
254}
255
256fn infer_type_args(
258 ctx: &MonoContext,
259 generic_func: &TypedFunction,
260 arg_types: &[ResolvedType],
261) -> KainResult<Vec<ResolvedType>> {
262 let mut bindings: HashMap<String, ResolvedType> = HashMap::new();
263
264 let param_types: Vec<ResolvedType> = if let ResolvedType::Function { params, .. } = &generic_func.resolved_type {
266 params.clone()
267 } else {
268 generic_func.ast.params.iter()
270 .map(|p| resolve_ast_type(&p.ty).unwrap_or(ResolvedType::Unknown))
271 .collect()
272 };
273
274 for (param_ty, arg_ty) in param_types.iter().zip(arg_types.iter()) {
276 unify(param_ty, arg_ty, &mut bindings);
277 }
278
279 let mut inferred = Vec::new();
281 for generic in &generic_func.ast.generics {
282 if let Some(ty) = bindings.get(&generic.name) {
283 for bound in &generic.bounds {
285 let type_name = type_to_string(ty);
286 if !ctx.trait_impls.contains(&(bound.trait_name.clone(), type_name.clone())) {
287 return Err(KainError::type_error(
288 format!("Type '{}' does not satisfy bound '{}'", type_name, bound.trait_name),
289 generic.span
290 ));
291 }
292 }
293 inferred.push(ty.clone());
294 } else {
295 inferred.push(ResolvedType::Unknown);
297 }
298 }
299
300 Ok(inferred)
301}
302
303fn substitute_type(ty: &ResolvedType, mapping: &HashMap<String, ResolvedType>) -> ResolvedType {
304 match ty {
305 ResolvedType::Generic(name) => mapping.get(name).cloned().unwrap_or(ty.clone()),
306 ResolvedType::Function { params, ret, effects } => {
307 ResolvedType::Function {
308 params: params.iter().map(|p| substitute_type(p, mapping)).collect(),
309 ret: Box::new(substitute_type(ret, mapping)),
310 effects: effects.clone()
311 }
312 }
313 ResolvedType::Array(inner, n) => ResolvedType::Array(Box::new(substitute_type(inner, mapping)), *n),
314 _ => ty.clone()
315 }
316}
317
318fn substitute_ast_types(func: &mut Function, mapping: &HashMap<String, ResolvedType>) {
319 for param in &mut func.params {
321 substitute_type_ast(&mut param.ty, mapping);
322 }
323
324 if let Some(ret) = &mut func.return_type {
326 substitute_type_ast(ret, mapping);
327 }
328
329 substitute_block(&mut func.body, mapping);
331}
332
333fn substitute_block(block: &mut Block, mapping: &HashMap<String, ResolvedType>) {
334 for stmt in &mut block.stmts {
335 substitute_stmt(stmt, mapping);
336 }
337}
338
339fn substitute_stmt(stmt: &mut Stmt, mapping: &HashMap<String, ResolvedType>) {
340 match stmt {
341 Stmt::Let { ty, value, .. } => {
342 if let Some(t) = ty {
343 substitute_type_ast(t, mapping);
344 }
345 if let Some(v) = value {
346 substitute_expr(v, mapping);
347 }
348 }
349 Stmt::Expr(e) => substitute_expr(e, mapping),
350 Stmt::Return(Some(e), _) => substitute_expr(e, mapping),
351 Stmt::For { iter, body, .. } => {
352 substitute_expr(iter, mapping);
353 substitute_block(body, mapping);
354 }
355 Stmt::While { condition, body, .. } => {
356 substitute_expr(condition, mapping);
357 substitute_block(body, mapping);
358 }
359 _ => {}
360 }
361}
362
363fn substitute_expr(expr: &mut Expr, mapping: &HashMap<String, ResolvedType>) {
364 match expr {
365 Expr::Cast { value, target, .. } => {
366 substitute_expr(value, mapping);
367 substitute_type_ast(target, mapping);
368 }
369 Expr::Binary { left, right, .. } => {
370 substitute_expr(left, mapping);
371 substitute_expr(right, mapping);
372 }
373 Expr::Unary { operand, .. } => substitute_expr(operand, mapping),
374 Expr::Call { callee, args, .. } => {
375 substitute_expr(callee, mapping);
376 for arg in args {
377 substitute_expr(&mut arg.value, mapping);
378 }
379 }
380 Expr::MethodCall { receiver, args, .. } => {
381 substitute_expr(receiver, mapping);
382 for arg in args {
383 substitute_expr(&mut arg.value, mapping);
384 }
385 }
386 Expr::Field { object, .. } => {
387 substitute_expr(object, mapping);
388 }
389 Expr::Index { object, index, .. } => {
390 substitute_expr(object, mapping);
391 substitute_expr(index, mapping);
392 }
393 Expr::Struct { fields, .. } => {
394 for (_, v) in fields {
395 substitute_expr(v, mapping);
396 }
397 }
398 Expr::Array(items, _) => {
399 for item in items {
400 substitute_expr(item, mapping);
401 }
402 }
403 Expr::Tuple(items, _) => {
404 for item in items {
405 substitute_expr(item, mapping);
406 }
407 }
408 Expr::Block(b, _) => substitute_block(b, mapping),
409 Expr::If { condition, then_branch, else_branch, .. } => {
410 substitute_expr(condition, mapping);
411 substitute_block(then_branch, mapping);
412 if let Some(br) = else_branch {
413 match br.as_mut() {
414 ElseBranch::Else(b) => substitute_block(b, mapping),
415 ElseBranch::ElseIf(c, t, _) => { substitute_expr(c, mapping);
417 substitute_block(t, mapping);
418 }
419 }
420 }
421 }
422 Expr::Match { scrutinee, arms, .. } => {
423 substitute_expr(scrutinee, mapping);
424 for arm in arms {
425 substitute_expr(&mut arm.body, mapping);
426 }
427 }
428 Expr::Lambda { params, body, return_type, .. } => {
429 for p in params {
430 substitute_type_ast(&mut p.ty, mapping);
431 }
432 if let Some(ret) = return_type {
433 substitute_type_ast(ret, mapping);
434 }
435 substitute_expr(body, mapping);
436 }
437 Expr::Await(inner, _) => {
438 substitute_expr(inner, mapping);
439 }
440 _ => {}
441 }
442}
443
444fn substitute_type_ast(ty: &mut Type, mapping: &HashMap<String, ResolvedType>) {
445 match ty {
446 Type::Named { name, generics, .. } => {
447 if let Some(concrete) = mapping.get(name) {
448 *ty = resolved_to_ast_type(concrete, ty.span());
449 } else {
450 for g in generics {
451 substitute_type_ast(g, mapping);
452 }
453 }
454 }
455 Type::Tuple(types, _) => {
456 for t in types {
457 substitute_type_ast(t, mapping);
458 }
459 }
460 Type::Function { params, return_type, .. } => {
461 for p in params {
462 substitute_type_ast(p, mapping);
463 }
464 substitute_type_ast(return_type, mapping);
465 }
466 Type::Array(inner, _, _) => {
467 substitute_type_ast(inner, mapping);
468 }
469 Type::Slice(inner, _) => {
470 substitute_type_ast(inner, mapping);
471 }
472 _ => {}
473 }
474}
475
476fn resolved_to_ast_type(res: &ResolvedType, span: crate::span::Span) -> Type {
477 match res {
478 ResolvedType::Int(_) => Type::Named { name: "Int".into(), generics: vec![], span },
479 ResolvedType::Float(_) => Type::Named { name: "Float".into(), generics: vec![], span },
480 ResolvedType::Bool => Type::Named { name: "Bool".into(), generics: vec![], span },
481 ResolvedType::String => Type::Named { name: "String".into(), generics: vec![], span },
482 ResolvedType::Unit => Type::Unit(span),
483 ResolvedType::Struct(n, _) => Type::Named { name: n.clone(), generics: vec![], span },
484 _ => Type::Named { name: "Any".into(), generics: vec![], span }, }
486}
487
488struct MonoTypeEnv {
489 scopes: Vec<HashMap<String, ResolvedType>>,
490}
491
492impl MonoTypeEnv {
493 fn new() -> Self {
494 Self { scopes: vec![HashMap::new()] }
495 }
496 fn push(&mut self) { self.scopes.push(HashMap::new()); }
497 fn pop(&mut self) { self.scopes.pop(); }
498
499 fn define(&mut self, name: String, ty: ResolvedType) {
500 if let Some(s) = self.scopes.last_mut() {
501 s.insert(name, ty);
502 }
503 }
504
505 fn get(&self, name: &str) -> ResolvedType {
506 for s in self.scopes.iter().rev() {
507 if let Some(t) = s.get(name) { return t.clone(); }
508 }
509 ResolvedType::Unknown
510 }
511}
512
513fn lower_async_fn(ctx: &mut MonoContext, func: &TypedFunction) -> KainResult<TypedFunction> {
514 let state_machine_name = format!("{}_Future", func.ast.name);
515
516 let mut fields = HashMap::new();
519 fields.insert("state".to_string(), ResolvedType::Int(IntSize::I64));
520
521 for param in &func.ast.params {
523 fields.insert(param.name.clone(), resolve_ast_type(¶m.ty).unwrap_or(ResolvedType::Unknown));
524 }
525
526 let locals = collect_locals(&func.ast.body);
528 for (name, ty) in locals {
529 fields.entry(name).or_insert(ty);
530 }
531
532 let _struct_ty = ResolvedType::Struct(state_machine_name.clone(), fields.clone());
533
534 ctx.structs.insert(state_machine_name.clone(), fields.clone());
536
537 let struct_def = TypedItem::Struct(TypedStruct {
540 ast: Struct {
541 name: state_machine_name.clone(),
542 generics: vec![],
543 fields: fields.iter().map(|(n, t)| Field {
544 name: n.clone(),
545 ty: resolved_to_ast_type(t, func.ast.span),
546 visibility: Visibility::Public,
547 default: None,
548 weak: false,
549 span: func.ast.span
550 }).collect(),
551 visibility: Visibility::Public,
552 span: func.ast.span,
553 },
554 field_types: fields.clone(),
555 });
556 ctx.concrete_items.push(struct_def);
557
558 let poll_name = format!("{}_poll", state_machine_name);
561
562 let self_type = ResolvedType::Struct(state_machine_name.clone(), fields.clone());
564 let self_param = Param {
565 name: "self".to_string(),
566 ty: resolved_to_ast_type(&self_type, func.ast.span),
567 mutable: true,
568 default: None,
569 span: func.ast.span,
570 };
571
572 let await_points = collect_await_points(&func.ast.body);
576
577 for (i, _) in await_points.iter().enumerate() {
579 let field_name = format!("_await_{}", i);
580 fields.insert(field_name, ResolvedType::Unknown);
582
583 let res_name = format!("_await_{}_result", i);
585 fields.insert(res_name, ResolvedType::Unknown);
586 }
587
588 ctx.structs.insert(state_machine_name.clone(), fields.clone());
590
591 let mut arms = Vec::new();
593
594 if await_points.is_empty() {
595 let mut rewritten_body = func.ast.body.clone();
597 rewrite_access_to_self(&mut rewritten_body, &fields);
598
599 let body_with_ready = wrap_return_in_poll_ready(rewritten_body, func.ast.span);
601
602 let arm0 = MatchArm {
603 pattern: Pattern::Literal(Expr::Int(0, func.ast.span)),
604 guard: None,
605 body: body_with_ready,
606 span: func.ast.span,
607 };
608 arms.push(arm0);
609 } else {
610 let segments = split_at_awaits(&func.ast.body, &await_points);
612
613 for (state_idx, segment) in segments.iter().enumerate() {
614 let arm = generate_state_arm(
615 state_idx,
616 segment,
617 &await_points,
618 &fields,
619 &state_machine_name,
620 func.ast.span,
621 );
622 arms.push(arm);
623 }
624 }
625
626 let arm_wild = MatchArm {
628 pattern: Pattern::Wildcard(func.ast.span),
629 guard: None,
630 body: Expr::Call {
631 callee: Box::new(Expr::Ident("panic".to_string(), func.ast.span)),
632 args: vec![CallArg {
633 name: None,
634 value: Expr::String("polled after completion".to_string(), func.ast.span),
635 span: func.ast.span,
636 }],
637 span: func.ast.span,
638 },
639 span: func.ast.span,
640 };
641 arms.push(arm_wild);
642
643 let mut poll_body = Block { stmts: vec![], span: func.ast.span };
645
646 let match_expr = Expr::Match {
647 scrutinee: Box::new(Expr::Field {
648 object: Box::new(Expr::Ident("self".to_string(), func.ast.span)),
649 field: "state".to_string(),
650 span: func.ast.span
651 }),
652 arms,
653 span: func.ast.span,
654 };
655
656 poll_body.stmts.push(Stmt::Expr(match_expr));
657
658 let poll_fn = TypedItem::Function(TypedFunction {
659 ast: Function {
660 name: poll_name.clone(),
661 generics: vec![],
662 params: vec![self_param],
663 return_type: None, effects: vec![],
665 body: poll_body,
666 visibility: Visibility::Public,
667 attributes: vec![],
668 span: func.ast.span,
669 },
670 resolved_type: ResolvedType::Function {
671 params: vec![self_type],
672 ret: Box::new(ResolvedType::Unit), effects: crate::effects::EffectSet::new(),
674 },
675 effects: crate::effects::EffectSet::new(),
676 });
677 ctx.concrete_items.push(poll_fn);
678
679 let mut entry_fn = func.clone();
682
683 let mut init_fields = Vec::new();
685 init_fields.push(("state".to_string(), Expr::Int(0, func.ast.span)));
686 for param in &func.ast.params {
687 init_fields.push((param.name.clone(), Expr::Ident(param.name.clone(), func.ast.span)));
688 }
689
690 for (i, _) in await_points.iter().enumerate() {
692 init_fields.push((format!("_await_{}", i), Expr::None(func.ast.span)));
693 init_fields.push((format!("_await_{}_result", i), Expr::None(func.ast.span)));
694 }
695
696 let captured_locals = collect_locals(&func.ast.body);
698 for (name, _) in captured_locals {
699 if func.ast.params.iter().any(|p| p.name == name) { continue; }
701 init_fields.push((name, Expr::None(func.ast.span)));
702 }
703
704 let body_expr = Expr::Struct {
705 name: state_machine_name.clone(),
706 fields: init_fields,
707 span: func.ast.span,
708 };
709
710 entry_fn.ast.body = Block {
711 stmts: vec![Stmt::Return(Some(body_expr), func.ast.span)],
712 span: func.ast.span,
713 };
714
715 entry_fn.resolved_type = ResolvedType::Function {
719 params: if let ResolvedType::Function{params, ..} = &func.resolved_type { params.clone() } else { vec![] },
720 ret: Box::new(ResolvedType::Struct(state_machine_name, fields)),
721 effects: crate::effects::EffectSet::new(), };
723
724 entry_fn.effects.effects.remove(&crate::effects::Effect::Async);
726 entry_fn.ast.effects.retain(|e| *e != crate::effects::Effect::Async);
727
728 Ok(entry_fn)
729}
730
731fn rewrite_access_to_self(block: &mut Block, fields: &HashMap<String, ResolvedType>) {
732 for stmt in &mut block.stmts {
733 rewrite_stmt(stmt, fields);
734 }
735}
736
737fn rewrite_stmt(stmt: &mut Stmt, fields: &HashMap<String, ResolvedType>) {
738 match stmt {
740 Stmt::Expr(e) => rewrite_expr(e, fields),
741 Stmt::Return(Some(e), _) => rewrite_expr(e, fields),
742 Stmt::Let { value: Some(e), .. } => rewrite_expr(e, fields),
743 Stmt::For { iter, body, .. } => {
744 rewrite_expr(iter, fields);
745 rewrite_access_to_self(body, fields);
746 }
747 Stmt::While { condition, body, .. } => {
748 rewrite_expr(condition, fields);
749 rewrite_access_to_self(body, fields);
750 }
751 _ => {}
752 }
753
754 let transform = if let Stmt::Let { pattern: Pattern::Binding { name, .. }, value: Some(e), span, .. } = stmt {
756 if fields.contains_key(name) {
757 Some((name.clone(), e.clone(), *span))
758 } else { None }
759 } else { None };
760
761 if let Some((name, val, span)) = transform {
762 *stmt = Stmt::Expr(Expr::Assign {
763 target: Box::new(Expr::Field {
764 object: Box::new(Expr::Ident("self".to_string(), span)),
765 field: name,
766 span,
767 }),
768 value: Box::new(val),
769 span,
770 });
771 }
772}
773
774fn rewrite_expr(expr: &mut Expr, fields: &HashMap<String, ResolvedType>) {
775 match expr {
776 Expr::Ident(name, span) => {
777 if fields.contains_key(name) {
778 *expr = Expr::Field {
780 object: Box::new(Expr::Ident("self".to_string(), *span)),
781 field: name.clone(),
782 span: *span,
783 };
784 }
785 }
786 Expr::Binary { left, right, .. } => {
787 rewrite_expr(left, fields);
788 rewrite_expr(right, fields);
789 }
790 Expr::Call { callee, args, .. } => {
791 rewrite_expr(callee, fields);
792 for arg in args {
793 rewrite_expr(&mut arg.value, fields);
794 }
795 }
796 Expr::Field { object, .. } => rewrite_expr(object, fields),
797 Expr::Await(inner, _) => rewrite_expr(inner, fields),
798 Expr::Block(b, _) => rewrite_access_to_self(b, fields),
799 _ => {}
801 }
802}
803
804#[derive(Clone, Debug)]
808struct AwaitPoint {
809 awaited_expr: Expr,
811 result_binding: Option<String>,
813 index: usize,
815}
816
817fn collect_await_points(block: &Block) -> Vec<AwaitPoint> {
819 let mut points = Vec::new();
820 collect_awaits_from_block(block, &mut points);
821 points
822}
823
824fn collect_awaits_from_block(block: &Block, points: &mut Vec<AwaitPoint>) {
825 for stmt in &block.stmts {
826 collect_awaits_from_stmt(stmt, points);
827 }
828}
829
830fn collect_awaits_from_stmt(stmt: &Stmt, points: &mut Vec<AwaitPoint>) {
831 match stmt {
832 Stmt::Let { pattern, value, .. } => {
833 let name = match pattern {
835 Pattern::Binding { name: n, .. } => Some(n.clone()),
836 _ => None,
837 };
838
839 if let Some(expr) = value {
841 if let Expr::Await(inner, _) = expr {
842 points.push(AwaitPoint {
843 awaited_expr: (**inner).clone(),
844 result_binding: name,
845 index: points.len(),
846 });
847 } else {
848 collect_awaits_from_expr(expr, points);
849 }
850 }
851 }
852 Stmt::Expr(expr) => {
853 if let Expr::Await(inner, _) = expr {
854 points.push(AwaitPoint {
855 awaited_expr: (**inner).clone(),
856 result_binding: None,
857 index: points.len(),
858 });
859 } else {
860 collect_awaits_from_expr(expr, points);
861 }
862 }
863 Stmt::Return(Some(expr), _) => {
864 if let Expr::Await(inner, _) = expr {
865 points.push(AwaitPoint {
866 awaited_expr: (**inner).clone(),
867 result_binding: None, index: points.len(),
869 });
870 } else {
871 collect_awaits_from_expr(expr, points);
872 }
873 }
874 Stmt::While { body, .. } | Stmt::Loop { body, .. } => {
877 collect_awaits_from_block(body, points);
878 }
879 Stmt::For { body, .. } => {
880 collect_awaits_from_block(body, points);
881 }
882 _ => {}
883 }
884}
885
886fn collect_awaits_from_expr(expr: &Expr, points: &mut Vec<AwaitPoint>) {
887 match expr {
888 Expr::Await(inner, _) => {
889 points.push(AwaitPoint {
890 awaited_expr: (**inner).clone(),
891 result_binding: None,
892 index: points.len(),
893 });
894 }
895 Expr::Binary { left, right, .. } => {
896 collect_awaits_from_expr(left, points);
897 collect_awaits_from_expr(right, points);
898 }
899 Expr::Call { callee, args, .. } => {
900 collect_awaits_from_expr(callee, points);
901 for arg in args {
902 collect_awaits_from_expr(&arg.value, points);
903 }
904 }
905 Expr::Block(block, _) => collect_awaits_from_block(block, points),
906 Expr::If { then_branch, else_branch, .. } => {
907 collect_awaits_from_block(then_branch, points);
908 if let Some(else_b) = else_branch {
909 match else_b.as_ref() {
910 ElseBranch::Else(b) => collect_awaits_from_block(b, points),
911 ElseBranch::ElseIf(_, then_b, _) => collect_awaits_from_block(then_b, points),
912 }
913 }
914 }
915 _ => {}
916 }
917}
918
919#[derive(Clone)]
921struct CodeSegment {
922 stmts_before: Vec<Stmt>,
924 await_point: Option<AwaitPoint>,
926 ends_with_return: bool,
928}
929
930fn split_at_awaits(block: &Block, await_points: &[AwaitPoint]) -> Vec<CodeSegment> {
932 let mut segments = Vec::new();
933 let mut current_stmts = Vec::new();
934 let mut await_idx = 0;
935
936 for stmt in &block.stmts {
937 let contains_await = match stmt {
939 Stmt::Let { value: Some(Expr::Await(_, _)), .. } => true,
940 Stmt::Expr(Expr::Await(_, _)) => true,
941 Stmt::Return(Some(Expr::Await(_, _)), _) => true,
942 _ => false,
943 };
944
945 if contains_await && await_idx < await_points.len() {
946 segments.push(CodeSegment {
948 stmts_before: current_stmts.clone(),
949 await_point: Some(await_points[await_idx].clone()),
950 ends_with_return: matches!(stmt, Stmt::Return(_, _)),
951 });
952 current_stmts.clear();
953 await_idx += 1;
954 } else {
955 current_stmts.push(stmt.clone());
956 }
957 }
958
959 if !current_stmts.is_empty() || segments.is_empty() {
961 let ends_with_return = current_stmts.last()
962 .map(|s| matches!(s, Stmt::Return(_, _)))
963 .unwrap_or(false);
964 segments.push(CodeSegment {
965 stmts_before: current_stmts,
966 await_point: None,
967 ends_with_return,
968 });
969 }
970
971 segments
972}
973
974fn generate_state_arm(
976 state_idx: usize,
977 segment: &CodeSegment,
978 await_points: &[AwaitPoint],
979 fields: &HashMap<String, ResolvedType>,
980 _state_machine_name: &str,
981 span: crate::span::Span,
982) -> MatchArm {
983 let mut body_stmts = Vec::new();
984
985 if state_idx > 0 && state_idx <= await_points.len() {
987 let prev_await = &await_points[state_idx - 1];
988 let poll_field = format!("_await_{}", prev_await.index);
989 let res_field = format!("_await_{}_result", prev_await.index);
990
991 let poll_call = Expr::MethodCall {
994 receiver: Box::new(Expr::Field {
995 object: Box::new(Expr::Ident("self".to_string(), span)),
996 field: poll_field,
997 span,
998 }),
999 method: "poll".to_string(),
1000 args: vec![],
1001 span,
1002 };
1003
1004 let pending_arm = MatchArm {
1006 pattern: Pattern::Variant {
1007 enum_name: Some("Poll".to_string()),
1008 variant: "Pending".to_string(),
1009 fields: VariantPatternFields::Unit,
1010 span,
1011 },
1012 guard: None,
1013 body: Expr::Return(
1014 Some(Box::new(Expr::EnumVariant {
1015 enum_name: "Poll".to_string(),
1016 variant: "Pending".to_string(),
1017 fields: EnumVariantFields::Unit,
1018 span,
1019 })),
1020 span,
1021 ),
1022 span,
1023 };
1024
1025 let val_name = "val".to_string();
1028 let ready_arm = MatchArm {
1029 pattern: Pattern::Variant {
1030 enum_name: Some("Poll".to_string()),
1031 variant: "Ready".to_string(),
1032 fields: VariantPatternFields::Tuple(vec![
1033 Pattern::Binding { name: val_name.clone(), mutable: false, span }
1034 ]),
1035 span,
1036 },
1037 guard: None,
1038 body: Expr::Assign {
1039 target: Box::new(Expr::Field {
1040 object: Box::new(Expr::Ident("self".to_string(), span)),
1041 field: res_field.clone(),
1042 span,
1043 }),
1044 value: Box::new(Expr::Ident(val_name, span)),
1045 span,
1046 },
1047 span,
1048 };
1049
1050 let poll_match = Expr::Match {
1051 scrutinee: Box::new(poll_call),
1052 arms: vec![pending_arm, ready_arm],
1053 span,
1054 };
1055
1056 body_stmts.push(Stmt::Expr(poll_match));
1057
1058 if let Some(binding) = &prev_await.result_binding {
1062 if fields.contains_key(binding) {
1063 body_stmts.push(Stmt::Expr(Expr::Assign {
1065 target: Box::new(Expr::Field {
1066 object: Box::new(Expr::Ident("self".to_string(), span)),
1067 field: binding.clone(),
1068 span,
1069 }),
1070 value: Box::new(Expr::Field {
1071 object: Box::new(Expr::Ident("self".to_string(), span)),
1072 field: res_field,
1073 span,
1074 }),
1075 span,
1076 }));
1077 } else {
1078 body_stmts.push(Stmt::Let {
1079 pattern: Pattern::Binding { name: binding.clone(), mutable: false, span },
1080 ty: None,
1081 value: Some(Expr::Field {
1082 object: Box::new(Expr::Ident("self".to_string(), span)),
1083 field: res_field,
1084 span,
1085 }),
1086 span,
1087 });
1088 }
1089 }
1090 }
1091
1092 for stmt in &segment.stmts_before {
1094 let mut rewritten_stmt = stmt.clone();
1095 rewrite_stmt(&mut rewritten_stmt, fields);
1096 body_stmts.push(rewritten_stmt);
1097 }
1098
1099 if let Some(await_point) = &segment.await_point {
1101 let store_field = format!("_await_{}", await_point.index);
1103 let mut awaited_expr = await_point.awaited_expr.clone();
1104 rewrite_expr(&mut awaited_expr, fields);
1105
1106 body_stmts.push(Stmt::Expr(Expr::Assign {
1107 target: Box::new(Expr::Field {
1108 object: Box::new(Expr::Ident("self".to_string(), span)),
1109 field: store_field,
1110 span,
1111 }),
1112 value: Box::new(awaited_expr),
1113 span,
1114 }));
1115
1116 body_stmts.push(Stmt::Expr(Expr::Assign {
1118 target: Box::new(Expr::Field {
1119 object: Box::new(Expr::Ident("self".to_string(), span)),
1120 field: "state".to_string(),
1121 span,
1122 }),
1123 value: Box::new(Expr::Int((state_idx + 1) as i64, span)),
1124 span,
1125 }));
1126
1127 body_stmts.push(Stmt::Return(
1129 Some(Expr::EnumVariant {
1130 enum_name: "Poll".to_string(),
1131 variant: "Pending".to_string(),
1132 fields: EnumVariantFields::Unit,
1133 span,
1134 }),
1135 span,
1136 ));
1137 } else if segment.ends_with_return {
1138 } else if state_idx == await_points.len() {
1141 body_stmts.push(Stmt::Return(
1143 Some(Expr::EnumVariant {
1144 enum_name: "Poll".to_string(),
1145 variant: "Ready".to_string(),
1146 fields: EnumVariantFields::Tuple(vec![Expr::None(span)]),
1147 span,
1148 }),
1149 span,
1150 ));
1151 }
1152
1153 MatchArm {
1154 pattern: Pattern::Literal(Expr::Int(state_idx as i64, span)),
1155 guard: None,
1156 body: Expr::Block(Block { stmts: body_stmts, span }, span),
1157 span,
1158 }
1159}
1160
1161fn wrap_return_in_poll_ready(mut block: Block, span: crate::span::Span) -> Expr {
1163 for stmt in &mut block.stmts {
1164 wrap_stmt_returns(stmt, span);
1165 }
1166 Expr::Block(block, span)
1167}
1168
1169fn wrap_stmt_returns(stmt: &mut Stmt, span: crate::span::Span) {
1170 match stmt {
1171 Stmt::Return(Some(expr), _) => {
1172 let inner = std::mem::replace(expr, Expr::None(span));
1174 *expr = Expr::EnumVariant {
1175 enum_name: "Poll".to_string(),
1176 variant: "Ready".to_string(),
1177 fields: EnumVariantFields::Tuple(vec![inner]),
1178 span,
1179 };
1180 }
1181 Stmt::Return(None, s) => {
1182 *stmt = Stmt::Return(
1184 Some(Expr::EnumVariant {
1185 enum_name: "Poll".to_string(),
1186 variant: "Ready".to_string(),
1187 fields: EnumVariantFields::Tuple(vec![Expr::None(span)]),
1188 span,
1189 }),
1190 *s,
1191 );
1192 }
1193 Stmt::While { body, .. } | Stmt::Loop { body, .. } => {
1196 for s in &mut body.stmts {
1197 wrap_stmt_returns(s, span);
1198 }
1199 }
1200 Stmt::For { body, .. } => {
1201 for s in &mut body.stmts {
1202 wrap_stmt_returns(s, span);
1203 }
1204 }
1205 _ => {}
1206 }
1207}
1208
1209fn scan_function(ctx: &mut MonoContext, func: &TypedFunction) -> KainResult<TypedFunction> {
1210 let mut new_func = func.clone();
1211 let mut env = MonoTypeEnv::new();
1212
1213 if let ResolvedType::Function { params, .. } = &func.resolved_type {
1214 for (i, p) in params.iter().enumerate() {
1215 if i < func.ast.params.len() {
1216 env.define(func.ast.params[i].name.clone(), p.clone());
1217 }
1218 }
1219 }
1220
1221 scan_block(ctx, &mut env, &mut new_func.ast.body)?;
1222 Ok(new_func)
1223}
1224
1225fn scan_block(ctx: &mut MonoContext, env: &mut MonoTypeEnv, block: &mut Block) -> KainResult<()> {
1226 env.push();
1227 for stmt in &mut block.stmts {
1228 scan_stmt(ctx, env, stmt)?;
1229 }
1230 env.pop();
1231 Ok(())
1232}
1233
1234fn scan_stmt(ctx: &mut MonoContext, env: &mut MonoTypeEnv, stmt: &mut Stmt) -> KainResult<()> {
1235 match stmt {
1236 Stmt::Expr(e) => { scan_expr(ctx, env, e)?; }
1237 Stmt::Return(Some(e), _) => { scan_expr(ctx, env, e)?; }
1238 Stmt::Let { pattern, value, .. } => {
1239 if let Some(val_expr) = value {
1241 let ty = scan_expr(ctx, env, val_expr)?;
1242 if let Pattern::Binding { name, .. } = pattern {
1244 env.define(name.clone(), ty);
1245 }
1246 }
1247 }
1248 Stmt::For { binding, iter, body, .. } => {
1249 let iter_ty = scan_expr(ctx, env, iter)?;
1250 let elem_ty = match iter_ty {
1251 ResolvedType::Array(inner, _) => *inner,
1252 _ => ResolvedType::Int(IntSize::I64),
1253 };
1254
1255 env.push();
1256 if let Pattern::Binding { name, .. } = binding {
1257 env.define(name.clone(), elem_ty);
1258 }
1259 scan_block(ctx, env, body)?;
1260 env.pop();
1261 }
1262 Stmt::While { condition, body, .. } => {
1263 scan_expr(ctx, env, condition)?;
1264 scan_block(ctx, env, body)?;
1265 }
1266 _ => {}
1267 }
1268 Ok(())
1269}
1270
1271fn scan_expr(ctx: &mut MonoContext, env: &mut MonoTypeEnv, expr: &mut Expr) -> KainResult<ResolvedType> {
1272 match expr {
1273 Expr::Int(_, _) => Ok(ResolvedType::Int(IntSize::I64)),
1274 Expr::Float(_, _) => Ok(ResolvedType::Float(FloatSize::F64)),
1275 Expr::String(_, _) => Ok(ResolvedType::String),
1276 Expr::Bool(_, _) => Ok(ResolvedType::Bool),
1277 Expr::Ident(name, _) => Ok(env.get(name)),
1278 Expr::Struct { name, fields, .. } => {
1279 for (_, val) in fields {
1280 scan_expr(ctx, env, val)?;
1281 }
1282 Ok(ResolvedType::Struct(name.clone(), HashMap::new()))
1285 },
1286 Expr::Field { object, field, span: _ } => {
1287 let obj_ty = scan_expr(ctx, env, object)?;
1288 match obj_ty {
1289 ResolvedType::Struct(name, _) => {
1290 if let Some(fields) = ctx.structs.get(&name) {
1291 if let Some(ty) = fields.get(field) {
1292 return Ok(ty.clone());
1293 }
1294 }
1295 Ok(ResolvedType::Unknown)
1298 }
1299 _ => Ok(ResolvedType::Unknown),
1300 }
1301 },
1302 Expr::MethodCall { receiver, method, args, span } => {
1303 let receiver_ty = scan_expr(ctx, env, receiver)?;
1304
1305 let type_name = match &receiver_ty {
1306 ResolvedType::Struct(name, _) => name.clone(),
1307 ResolvedType::Int(_) => "Int".to_string(),
1308 ResolvedType::Float(_) => "Float".to_string(),
1309 ResolvedType::String => "String".to_string(),
1310 _ => {
1311 if let ResolvedType::Unknown = receiver_ty {
1312 return Ok(ResolvedType::Unknown);
1314 }
1315 format!("{:?}", receiver_ty)
1316 }
1317 };
1318
1319 let mangled_target = {
1320 let methods = ctx.methods.get(&type_name);
1321 if let Some(lookup) = methods {
1322 lookup.get(method).cloned()
1323 } else {
1324 None
1325 }
1326 };
1327
1328 if let Some(target_name) = mangled_target {
1329 let mut new_args = args.clone();
1330 new_args.insert(0, CallArg { name: None, value: *receiver.clone(), span: receiver.span() });
1331
1332 for arg in &mut new_args {
1333 scan_expr(ctx, env, &mut arg.value)?;
1334 }
1335
1336 *expr = Expr::Call {
1337 callee: Box::new(Expr::Ident(target_name, *span)), args: new_args,
1339 span: *span
1340 };
1341
1342 return Ok(ResolvedType::Unknown);
1345 }
1346
1347 Ok(ResolvedType::Unknown)
1348 }
1349 Expr::Call { callee, args, .. } => {
1350 if let Expr::Ident(name, _) = callee.as_ref() {
1351 if let Some(generic_func) = ctx.generic_functions.get(name).cloned() {
1352 let mut arg_types = Vec::new();
1354 for arg in args {
1355 arg_types.push(scan_expr(ctx, env, &mut arg.value)?);
1356 }
1357
1358 let inferred_type_args = infer_type_args(ctx, &generic_func, &arg_types)?;
1360
1361 let new_name = ctx.instantiate(name, &inferred_type_args)?;
1362 *callee = Box::new(Expr::Ident(new_name, callee.span()));
1363 return Ok(ResolvedType::Unknown);
1364 }
1365
1366 }
1369 for arg in args {
1370 scan_expr(ctx, env, &mut arg.value)?;
1371 }
1372 Ok(ResolvedType::Unknown)
1373 }
1374 Expr::Binary { left, right, .. } => {
1375 let t = scan_expr(ctx, env, left)?;
1376 scan_expr(ctx, env, right)?;
1377 Ok(t)
1378 }
1379 Expr::Assign { value, .. } => scan_expr(ctx, env, value),
1380 Expr::Block(b, _) => {
1381 scan_block(ctx, env, b)?;
1382 Ok(ResolvedType::Unknown)
1383 }
1384 Expr::If { condition, then_branch, else_branch, .. } => {
1385 scan_expr(ctx, env, condition)?;
1386 scan_block(ctx, env, then_branch)?;
1387 if let Some(b) = else_branch {
1388 match b.as_mut() {
1389 ElseBranch::Else(blk) => { scan_block(ctx, env, blk)?; }
1390 ElseBranch::ElseIf(_, _, _) => {}
1391 }
1392 }
1393 Ok(ResolvedType::Unknown)
1394 }
1395 Expr::Await(inner, _) => {
1396 scan_expr(ctx, env, inner)
1398 }
1399 _ => Ok(ResolvedType::Unknown),
1400 }
1401}
1402
1403fn collect_locals(block: &Block) -> HashMap<String, ResolvedType> {
1404 let mut locals = HashMap::new();
1405 collect_locals_recursive(block, &mut locals);
1406 locals
1407}
1408
1409fn collect_locals_recursive(block: &Block, locals: &mut HashMap<String, ResolvedType>) {
1410 for stmt in &block.stmts {
1411 match stmt {
1412 Stmt::Let { pattern, .. } => collect_from_pattern(pattern, locals),
1413 Stmt::For { body, .. } => collect_locals_recursive(body, locals),
1414 Stmt::While { body, .. } => collect_locals_recursive(body, locals),
1415 Stmt::Expr(Expr::Block(b, _)) => collect_locals_recursive(b, locals),
1416 Stmt::Expr(Expr::If { then_branch, else_branch, .. }) => {
1417 collect_locals_recursive(then_branch, locals);
1418 if let Some(b) = else_branch {
1419 collect_from_else(b, locals);
1420 }
1421 }
1422 _ => {}
1423 }
1424 }
1425}
1426
1427fn collect_from_else(branch: &ElseBranch, locals: &mut HashMap<String, ResolvedType>) {
1428 match branch {
1429 ElseBranch::Else(block) => collect_locals_recursive(block, locals),
1430 ElseBranch::ElseIf(_, block, next) => {
1431 collect_locals_recursive(block, locals);
1432 if let Some(n) = next {
1433 collect_from_else(n, locals);
1434 }
1435 }
1436 }
1437}
1438
1439fn collect_from_pattern(pattern: &Pattern, locals: &mut HashMap<String, ResolvedType>) {
1440 match pattern {
1441 Pattern::Binding { name, .. } => { locals.insert(name.clone(), ResolvedType::Unknown); },
1442 Pattern::Tuple(pats, _) => {
1443 for p in pats { collect_from_pattern(p, locals); }
1444 }
1445 Pattern::Slice { patterns, rest, .. } => {
1446 for p in patterns { collect_from_pattern(p, locals); }
1447 if let Some(r) = rest {
1448 locals.insert(r.clone(), ResolvedType::Unknown);
1449 }
1450 }
1451 Pattern::Struct { fields, .. } => {
1452 for (_, p) in fields { collect_from_pattern(p, locals); }
1453 }
1454 Pattern::Variant { fields, .. } => {
1455 match fields {
1456 VariantPatternFields::Tuple(pats) => { for p in pats { collect_from_pattern(p, locals); } },
1457 VariantPatternFields::Struct(pats) => { for (_, p) in pats { collect_from_pattern(p, locals); } },
1458 _ => {}
1459 }
1460 }
1461 Pattern::Or(pats, _) => {
1462 for p in pats { collect_from_pattern(p, locals); }
1463 }
1464 _ => {}
1465 }
1466}
1467