Skip to main content

kore/
monomorphize.rs

1use crate::types::*;
2use crate::ast::*;
3use crate::error::{KoreResult, KoreError};
4use std::collections::{HashMap, HashSet};
5
6/// Result of monomorphization
7pub struct MonomorphizedProgram {
8    pub items: Vec<TypedItem>,
9}
10
11pub fn monomorphize(program: &TypedProgram) -> KoreResult<MonomorphizedProgram> {
12    let mut ctx = MonoContext::new();
13    
14
15    
16    // 1. First Pass: Collect all global items: functions, impls
17    for item in &program.items {
18        match item {
19            TypedItem::Function(func) => {
20
21                if !func.ast.generics.is_empty() {
22                    ctx.generic_functions.insert(func.ast.name.clone(), func.clone());
23
24                } else {
25                    ctx.concrete_items.push(item.clone());
26
27                }
28            }
29            TypedItem::Struct(s) => {
30                let mut fields = HashMap::new();
31                for f in &s.ast.fields {
32                    if let Ok(ty) = resolve_ast_type(&f.ty) {
33                        fields.insert(f.name.clone(), ty);
34                    }
35                }
36                ctx.structs.insert(s.ast.name.clone(), fields);
37                ctx.concrete_items.push(item.clone());
38            }
39            TypedItem::Impl(imp) => {
40                // Register methods from impl blocks
41                // Mangle them as Type_method
42                let type_name = match &imp.ast.target_type {
43                    Type::Named { name, .. } => name.clone(),
44                    _ => continue, // Skip complex types for now
45                };
46                
47                let target_ty = resolve_ast_type(&imp.ast.target_type).unwrap_or(ResolvedType::Unknown);
48                
49                // Register trait implementation
50                if let Some(trait_name) = &imp.ast.trait_name {
51                    let type_name_str = type_to_string(&target_ty);
52                    ctx.trait_impls.insert((trait_name.clone(), type_name_str));
53                }
54                
55                for method in &imp.ast.methods {
56                    let mangled_name = format!("{}_{}", type_name, method.name);
57                    
58                    let mut standalone_fn = method.clone();
59                    standalone_fn.name = mangled_name.clone();
60                    
61                    // Resolve method type
62                    let mut params = Vec::new();
63                    for p in &method.params {
64                         if p.name == "self" {
65                             params.push(target_ty.clone());
66                         } else {
67                             params.push(resolve_ast_type(&p.ty).unwrap_or(ResolvedType::Unknown));
68                         }
69                    }
70                    let ret = method.return_type.as_ref()
71                        .map(|t| resolve_ast_type(t).unwrap_or(ResolvedType::Unknown))
72                        .unwrap_or(ResolvedType::Unit);
73                    
74                    let method_ty = ResolvedType::Function {
75                        params,
76                        ret: Box::new(ret),
77                        effects: crate::effects::EffectSet::new(), // Todo scan effects?
78                    };
79                    
80                    let typed_method = TypedFunction {
81                        ast: standalone_fn,
82                        resolved_type: method_ty,
83                        effects: crate::effects::EffectSet::new(),
84                    };
85                    
86                    ctx.methods.entry(type_name.clone()).or_default().insert(method.name.clone(), mangled_name.clone());
87                    ctx.concrete_items.push(TypedItem::Function(typed_method));
88                }
89            }
90            _ => {
91                ctx.concrete_items.push(item.clone());
92            }
93        }
94    }
95    
96    // 2. Scan concrete items for calls
97    let mut i = 0;
98    while i < ctx.concrete_items.len() {
99        let item = ctx.concrete_items[i].clone();
100        match item {
101            TypedItem::Function(func) => {
102                // Check if Async
103                if func.effects.effects.contains(&crate::effects::Effect::Async) {
104                    // Lower Async Function to State Machine
105                    // This returns the transformed entry function (synchronous, returns Future struct)
106                    // The State Machine Struct and Poll Function are pushed to ctx.concrete_items inside lower_async_fn
107                    let entry_fn = lower_async_fn(&mut ctx, &func)?;
108                    
109                    // Replace the original async function with the transformed entry function
110                    ctx.concrete_items[i] = TypedItem::Function(entry_fn);
111                } else {
112                    let new_func = scan_function(&mut ctx, &func)?;
113                    ctx.concrete_items[i] = TypedItem::Function(new_func);
114                }
115            }
116            _ => {}
117        }
118        i += 1;
119    }
120    
121    Ok(MonomorphizedProgram { items: ctx.concrete_items })
122}
123
124struct MonoContext {
125    generic_functions: HashMap<String, TypedFunction>,
126    concrete_items: Vec<TypedItem>,
127    instantiated: HashMap<String, String>,
128    /// Type -> MethodName -> MangledName
129    methods: HashMap<String, HashMap<String, String>>,
130    /// Struct Name -> Field Name -> Type
131    structs: HashMap<String, HashMap<String, ResolvedType>>,
132    /// (TraitName, TypeName) -> Implemented
133    trait_impls: HashSet<(String, String)>,
134}
135
136impl MonoContext {
137    fn new() -> Self {
138        Self {
139            generic_functions: HashMap::new(),
140            concrete_items: Vec::new(),
141            instantiated: HashMap::new(),
142            methods: HashMap::new(),
143            structs: HashMap::new(),
144            trait_impls: HashSet::new(),
145        }
146    }
147    
148    fn instantiate(&mut self, name: &str, type_args: &[ResolvedType]) -> KoreResult<String> {
149        let mangled_name = format!("{}_{}", name, mangle_types(type_args));
150        
151        if self.instantiated.contains_key(&mangled_name) {
152            return Ok(mangled_name);
153        }
154        
155        let generic_func = self.generic_functions.get(name)
156            .ok_or_else(|| KoreError::type_error(format!("Generic function {} not found", name), crate::span::Span::new(0,0)))?
157            .clone();
158            
159        if generic_func.ast.generics.len() != type_args.len() {
160             return Err(KoreError::type_error(format!("Generic arg count mismatch for {}: expected {}, got {}", name, generic_func.ast.generics.len(), type_args.len()), generic_func.ast.span));
161        }
162        
163        let mut mapping = HashMap::new();
164        for (i, param) in generic_func.ast.generics.iter().enumerate() {
165            mapping.insert(param.name.clone(), type_args[i].clone());
166        }
167        
168        let mut new_func = generic_func.clone();
169        new_func.ast.name = mangled_name.clone();
170        new_func.ast.generics.clear();
171        
172        if let ResolvedType::Function { params, ret, .. } = &mut new_func.resolved_type {
173            for p in params {
174                *p = substitute_type(p, &mapping);
175            }
176            *ret = Box::new(substitute_type(&ret, &mapping));
177        }
178        
179        self.instantiated.insert(mangled_name.clone(), mangled_name.clone());
180        
181        substitute_ast_types(&mut new_func.ast, &mapping);
182        self.concrete_items.push(TypedItem::Function(new_func));
183        
184        Ok(mangled_name)
185    }
186}
187
188fn type_to_string(ty: &ResolvedType) -> String {
189    match ty {
190        ResolvedType::Int(_) => "Int".to_string(),
191        ResolvedType::Float(_) => "Float".to_string(),
192        ResolvedType::String => "String".to_string(),
193        ResolvedType::Bool => "Bool".to_string(),
194        ResolvedType::Unit => "Unit".to_string(),
195        ResolvedType::Struct(n, _) => n.clone(),
196        ResolvedType::Enum(n, _) => n.clone(),
197        ResolvedType::Tuple(ts) => format!("({})", ts.iter().map(type_to_string).collect::<Vec<_>>().join(", ")),
198        _ => "Any".to_string(),
199    }
200}
201
202fn mangle_types(types: &[ResolvedType]) -> String {
203    types.iter().map(type_to_string).collect::<Vec<_>>().join("_")
204}
205
206fn resolve_ast_type(ty: &Type) -> KoreResult<ResolvedType> {
207    crate::types::resolve_type(ty)
208}
209
210/// Unify a parameter type with an argument type to extract generic bindings.
211/// For example, unifying `fn(T) -> T` with `fn(Int) -> Int` yields `{T: Int}`.
212fn unify(
213    param_type: &ResolvedType,
214    arg_type: &ResolvedType,
215    bindings: &mut HashMap<String, ResolvedType>,
216) {
217    match (param_type, arg_type) {
218        // If the parameter type is a generic, bind it to the argument type
219        (ResolvedType::Generic(name), concrete) => {
220            if let Some(existing) = bindings.get(name) {
221                // Already bound - ideally check for consistency, but for now just keep first binding
222                let _ = existing;
223            } else {
224                bindings.insert(name.clone(), concrete.clone());
225            }
226        }
227        
228        // Recursively unify function types
229        (ResolvedType::Function { params: p_params, ret: p_ret, .. }, 
230         ResolvedType::Function { params: a_params, ret: a_ret, .. }) => {
231            // Unify parameter types
232            for (pp, ap) in p_params.iter().zip(a_params.iter()) {
233                unify(pp, ap, bindings);
234            }
235            // Unify return type
236            unify(p_ret, a_ret, bindings);
237        }
238        
239        // Recursively unify array types
240        (ResolvedType::Array(p_inner, _), ResolvedType::Array(a_inner, _)) => {
241            unify(p_inner, a_inner, bindings);
242        }
243        
244        // Recursively unify tuple types
245        (ResolvedType::Tuple(p_elems), ResolvedType::Tuple(a_elems)) => {
246            for (pe, ae) in p_elems.iter().zip(a_elems.iter()) {
247                unify(pe, ae, bindings);
248            }
249        }
250        
251        // For concrete types that match, nothing to unify
252        _ => {}
253    }
254}
255
256/// Infer type arguments for a generic function call by unifying parameter types with argument types.
257fn infer_type_args(
258    ctx: &MonoContext,
259    generic_func: &TypedFunction,
260    arg_types: &[ResolvedType],
261) -> KoreResult<Vec<ResolvedType>> {
262    let mut bindings: HashMap<String, ResolvedType> = HashMap::new();
263    
264    // Get the parameter types from the function signature
265    let param_types: Vec<ResolvedType> = if let ResolvedType::Function { params, .. } = &generic_func.resolved_type {
266        params.clone()
267    } else {
268        // Fallback: resolve from AST
269        generic_func.ast.params.iter()
270            .map(|p| resolve_ast_type(&p.ty).unwrap_or(ResolvedType::Unknown))
271            .collect()
272    };
273    
274    // Unify each parameter type with the corresponding argument type
275    for (param_ty, arg_ty) in param_types.iter().zip(arg_types.iter()) {
276        unify(param_ty, arg_ty, &mut bindings);
277    }
278    
279    // Extract the inferred types in the order of the generic parameters
280    let mut inferred = Vec::new();
281    for generic in &generic_func.ast.generics {
282        if let Some(ty) = bindings.get(&generic.name) {
283             // Check Bounds!
284             for bound in &generic.bounds {
285                 let type_name = type_to_string(ty);
286                 if !ctx.trait_impls.contains(&(bound.trait_name.clone(), type_name.clone())) {
287                     return Err(KoreError::type_error(
288                         format!("Type '{}' does not satisfy bound '{}'", type_name, bound.trait_name), 
289                         generic.span
290                     ));
291                 }
292             }
293             inferred.push(ty.clone());
294        } else {
295            // Generic wasn't inferred - could be an error, but let's use Unknown for now
296            inferred.push(ResolvedType::Unknown);
297        }
298    }
299    
300    Ok(inferred)
301}
302
303fn substitute_type(ty: &ResolvedType, mapping: &HashMap<String, ResolvedType>) -> ResolvedType {
304    match ty {
305        ResolvedType::Generic(name) => mapping.get(name).cloned().unwrap_or(ty.clone()),
306        ResolvedType::Function { params, ret, effects } => {
307            ResolvedType::Function {
308                params: params.iter().map(|p| substitute_type(p, mapping)).collect(),
309                ret: Box::new(substitute_type(ret, mapping)),
310                effects: effects.clone()
311            }
312        }
313        ResolvedType::Array(inner, n) => ResolvedType::Array(Box::new(substitute_type(inner, mapping)), *n),
314        _ => ty.clone() 
315    }
316}
317
318fn substitute_ast_types(func: &mut Function, mapping: &HashMap<String, ResolvedType>) {
319    // 1. Substitute param types
320    for param in &mut func.params {
321        substitute_type_ast(&mut param.ty, mapping);
322    }
323    
324    // 2. Substitute return type
325    if let Some(ret) = &mut func.return_type {
326        substitute_type_ast(ret, mapping);
327    }
328    
329    // 3. Substitute body
330    substitute_block(&mut func.body, mapping);
331}
332
333fn substitute_block(block: &mut Block, mapping: &HashMap<String, ResolvedType>) {
334    for stmt in &mut block.stmts {
335        substitute_stmt(stmt, mapping);
336    }
337}
338
339fn substitute_stmt(stmt: &mut Stmt, mapping: &HashMap<String, ResolvedType>) {
340    match stmt {
341        Stmt::Let { ty, value, .. } => {
342            if let Some(t) = ty {
343                substitute_type_ast(t, mapping);
344            }
345            if let Some(v) = value {
346                substitute_expr(v, mapping);
347            }
348        }
349        Stmt::Expr(e) => substitute_expr(e, mapping),
350        Stmt::Return(Some(e), _) => substitute_expr(e, mapping),
351        Stmt::For { iter, body, .. } => {
352            substitute_expr(iter, mapping);
353            substitute_block(body, mapping);
354        }
355        Stmt::While { condition, body, .. } => {
356            substitute_expr(condition, mapping);
357            substitute_block(body, mapping);
358        }
359        _ => {}
360    }
361}
362
363fn substitute_expr(expr: &mut Expr, mapping: &HashMap<String, ResolvedType>) {
364    match expr {
365        Expr::Cast { value, target, .. } => {
366            substitute_expr(value, mapping);
367            substitute_type_ast(target, mapping);
368        }
369        Expr::Binary { left, right, .. } => {
370            substitute_expr(left, mapping);
371            substitute_expr(right, mapping);
372        }
373        Expr::Unary { operand, .. } => substitute_expr(operand, mapping),
374        Expr::Call { callee, args, .. } => {
375            substitute_expr(callee, mapping);
376            for arg in args {
377                substitute_expr(&mut arg.value, mapping);
378            }
379        }
380        Expr::MethodCall { receiver, args, .. } => {
381            substitute_expr(receiver, mapping);
382            for arg in args {
383                substitute_expr(&mut arg.value, mapping);
384            }
385        }
386        Expr::Field { object, .. } => {
387            substitute_expr(object, mapping);
388        }
389        Expr::Index { object, index, .. } => {
390            substitute_expr(object, mapping);
391            substitute_expr(index, mapping);
392        }
393        Expr::Struct { fields, .. } => {
394             for (_, v) in fields {
395                 substitute_expr(v, mapping);
396             }
397        }
398        Expr::Array(items, _) => {
399             for item in items {
400                 substitute_expr(item, mapping);
401             }
402        }
403        Expr::Tuple(items, _) => {
404             for item in items {
405                 substitute_expr(item, mapping);
406             }
407        }
408        Expr::Block(b, _) => substitute_block(b, mapping),
409        Expr::If { condition, then_branch, else_branch, .. } => {
410             substitute_expr(condition, mapping);
411             substitute_block(then_branch, mapping);
412             if let Some(br) = else_branch {
413                 match br.as_mut() {
414                     ElseBranch::Else(b) => substitute_block(b, mapping),
415                     ElseBranch::ElseIf(c, t, _) => { // Simplified recursion
416                         substitute_expr(c, mapping);
417                         substitute_block(t, mapping);
418                     }
419                 }
420             }
421        }
422        Expr::Match { scrutinee, arms, .. } => {
423            substitute_expr(scrutinee, mapping);
424            for arm in arms {
425                substitute_expr(&mut arm.body, mapping);
426            }
427        }
428        Expr::Lambda { params, body, return_type, .. } => {
429             for p in params {
430                 substitute_type_ast(&mut p.ty, mapping);
431             }
432             if let Some(ret) = return_type {
433                 substitute_type_ast(ret, mapping);
434             }
435             substitute_expr(body, mapping);
436        }
437        Expr::Await(inner, _) => {
438            substitute_expr(inner, mapping);
439        }
440        _ => {}
441    }
442}
443
444fn substitute_type_ast(ty: &mut Type, mapping: &HashMap<String, ResolvedType>) {
445    match ty {
446        Type::Named { name, generics, .. } => {
447            if let Some(concrete) = mapping.get(name) {
448                *ty = resolved_to_ast_type(concrete, ty.span());
449            } else {
450                for g in generics {
451                    substitute_type_ast(g, mapping);
452                }
453            }
454        }
455        Type::Tuple(types, _) => {
456            for t in types {
457                substitute_type_ast(t, mapping);
458            }
459        }
460        Type::Function { params, return_type, .. } => {
461            for p in params {
462                substitute_type_ast(p, mapping);
463            }
464            substitute_type_ast(return_type, mapping);
465        }
466        Type::Array(inner, _, _) => {
467             substitute_type_ast(inner, mapping);
468         }
469         Type::Slice(inner, _) => {
470             substitute_type_ast(inner, mapping);
471         }
472         _ => {}
473    }
474}
475
476fn resolved_to_ast_type(res: &ResolvedType, span: crate::span::Span) -> Type {
477    match res {
478        ResolvedType::Int(_) => Type::Named { name: "Int".into(), generics: vec![], span },
479        ResolvedType::Float(_) => Type::Named { name: "Float".into(), generics: vec![], span },
480        ResolvedType::Bool => Type::Named { name: "Bool".into(), generics: vec![], span },
481        ResolvedType::String => Type::Named { name: "String".into(), generics: vec![], span },
482        ResolvedType::Unit => Type::Unit(span),
483        ResolvedType::Struct(n, _) => Type::Named { name: n.clone(), generics: vec![], span },
484        _ => Type::Named { name: "Any".into(), generics: vec![], span }, // Fallback
485    }
486}
487
488struct MonoTypeEnv {
489    scopes: Vec<HashMap<String, ResolvedType>>,
490}
491
492impl MonoTypeEnv {
493    fn new() -> Self {
494        Self { scopes: vec![HashMap::new()] }
495    }
496    fn push(&mut self) { self.scopes.push(HashMap::new()); }
497    fn pop(&mut self) { self.scopes.pop(); }
498    
499    fn define(&mut self, name: String, ty: ResolvedType) {
500        if let Some(s) = self.scopes.last_mut() {
501            s.insert(name, ty);
502        }
503    }
504    
505    fn get(&self, name: &str) -> ResolvedType {
506        for s in self.scopes.iter().rev() {
507            if let Some(t) = s.get(name) { return t.clone(); }
508        }
509        ResolvedType::Unknown
510    }
511}
512
513fn lower_async_fn(ctx: &mut MonoContext, func: &TypedFunction) -> KoreResult<TypedFunction> {
514    let state_machine_name = format!("{}_Future", func.ast.name);
515    
516    // 1. Create State Machine Struct
517    // struct MyFn_Future { state: Int, ...args, ...locals }
518    let mut fields = HashMap::new();
519    fields.insert("state".to_string(), ResolvedType::Int(IntSize::I64));
520    
521    // Capture arguments
522    for param in &func.ast.params {
523        fields.insert(param.name.clone(), resolve_ast_type(&param.ty).unwrap_or(ResolvedType::Unknown));
524    }
525    
526    // Capture locals (lifted to struct fields)
527    let locals = collect_locals(&func.ast.body);
528    for (name, ty) in locals {
529        fields.entry(name).or_insert(ty);
530    }
531    
532    let _struct_ty = ResolvedType::Struct(state_machine_name.clone(), fields.clone());
533    
534    // Register Struct
535    ctx.structs.insert(state_machine_name.clone(), fields.clone());
536    
537    // Emit Struct Definition
538    // We need to create a TypedStruct and push it
539    let struct_def = TypedItem::Struct(TypedStruct {
540        ast: Struct {
541            name: state_machine_name.clone(),
542            generics: vec![],
543            fields: fields.iter().map(|(n, t)| Field {
544                name: n.clone(),
545                ty: resolved_to_ast_type(t, func.ast.span),
546                visibility: Visibility::Public,
547                default: None,
548                weak: false,
549                span: func.ast.span
550            }).collect(),
551            visibility: Visibility::Public,
552            span: func.ast.span,
553        },
554        field_types: fields.clone(),
555    });
556    ctx.concrete_items.push(struct_def);
557    
558    // 2. Generate Poll Function
559    // fn MyFn_Future_poll(self: &mut MyFn_Future) -> Poll<T>
560    let poll_name = format!("{}_poll", state_machine_name);
561    
562    // Create 'self' param
563    let self_type = ResolvedType::Struct(state_machine_name.clone(), fields.clone());
564    let self_param = Param {
565        name: "self".to_string(),
566        ty: resolved_to_ast_type(&self_type, func.ast.span),
567        mutable: true,
568        default: None,
569        span: func.ast.span,
570    };
571    
572    // === AWAIT CHOPPING: Split function body at await points ===
573    
574    // Step 1: Collect all await points and statements between them
575    let await_points = collect_await_points(&func.ast.body);
576    
577    // Step 2: Add storage fields for each await's pending future and its result
578    for (i, _) in await_points.iter().enumerate() {
579        let field_name = format!("_await_{}", i);
580        // Store futures as Unknown type (dynamic typing for interpreter)
581        fields.insert(field_name, ResolvedType::Unknown);
582        
583        // Store result of the future
584        let res_name = format!("_await_{}_result", i);
585        fields.insert(res_name, ResolvedType::Unknown);
586    }
587    
588    // Update struct with new fields
589    ctx.structs.insert(state_machine_name.clone(), fields.clone());
590    
591    // Step 3: Generate match arms for each state
592    let mut arms = Vec::new();
593    
594    if await_points.is_empty() {
595        // No awaits - just execute the whole body in state 0 and return Ready
596        let mut rewritten_body = func.ast.body.clone();
597        rewrite_access_to_self(&mut rewritten_body, &fields);
598        
599        // Wrap result in Poll::Ready
600        let body_with_ready = wrap_return_in_poll_ready(rewritten_body, func.ast.span);
601        
602        let arm0 = MatchArm {
603            pattern: Pattern::Literal(Expr::Int(0, func.ast.span)),
604            guard: None,
605            body: body_with_ready,
606            span: func.ast.span,
607        };
608        arms.push(arm0);
609    } else {
610        // Has awaits - generate state machine
611        let segments = split_at_awaits(&func.ast.body, &await_points);
612        
613        for (state_idx, segment) in segments.iter().enumerate() {
614            let arm = generate_state_arm(
615                state_idx,
616                segment,
617                &await_points,
618                &fields,
619                &state_machine_name,
620                func.ast.span,
621            );
622            arms.push(arm);
623        }
624    }
625    
626    // Fallback arm for completed/invalid states
627    let arm_wild = MatchArm {
628        pattern: Pattern::Wildcard(func.ast.span),
629        guard: None,
630        body: Expr::Call {
631            callee: Box::new(Expr::Ident("panic".to_string(), func.ast.span)),
632            args: vec![CallArg {
633                name: None,
634                value: Expr::String("polled after completion".to_string(), func.ast.span),
635                span: func.ast.span,
636            }],
637            span: func.ast.span,
638        },
639        span: func.ast.span,
640    };
641    arms.push(arm_wild);
642    
643    // Create poll body with the match expression
644    let mut poll_body = Block { stmts: vec![], span: func.ast.span };
645    
646    let match_expr = Expr::Match {
647        scrutinee: Box::new(Expr::Field {
648            object: Box::new(Expr::Ident("self".to_string(), func.ast.span)),
649            field: "state".to_string(),
650            span: func.ast.span
651        }),
652        arms,
653        span: func.ast.span,
654    };
655    
656    poll_body.stmts.push(Stmt::Expr(match_expr));
657    
658    let poll_fn = TypedItem::Function(TypedFunction {
659        ast: Function {
660            name: poll_name.clone(),
661            generics: vec![],
662            params: vec![self_param],
663            return_type: None, // Should be Poll<T>
664            effects: vec![],
665            body: poll_body,
666            visibility: Visibility::Public,
667            span: func.ast.span,
668        },
669        resolved_type: ResolvedType::Function {
670            params: vec![self_type],
671            ret: Box::new(ResolvedType::Unit), // Todo Poll
672            effects: crate::effects::EffectSet::new(),
673        },
674        effects: crate::effects::EffectSet::new(),
675    });
676    ctx.concrete_items.push(poll_fn);
677    
678    // 3. Rewrite Original Function
679    // fn MyFn(args) -> MyFn_Future
680    let mut entry_fn = func.clone();
681    
682    // Construct Struct Init
683    let mut init_fields = Vec::new();
684    init_fields.push(("state".to_string(), Expr::Int(0, func.ast.span)));
685    for param in &func.ast.params {
686        init_fields.push((param.name.clone(), Expr::Ident(param.name.clone(), func.ast.span)));
687    }
688    
689    // Initialize await fields
690    for (i, _) in await_points.iter().enumerate() {
691        init_fields.push((format!("_await_{}", i), Expr::None(func.ast.span)));
692        init_fields.push((format!("_await_{}_result", i), Expr::None(func.ast.span)));
693    }
694    
695    // Initialize captured locals
696    let captured_locals = collect_locals(&func.ast.body);
697    for (name, _) in captured_locals {
698        // Skip params (already initialized)
699        if func.ast.params.iter().any(|p| p.name == name) { continue; }
700        init_fields.push((name, Expr::None(func.ast.span)));
701    }
702    
703    let body_expr = Expr::Struct {
704        name: state_machine_name.clone(),
705        fields: init_fields,
706        span: func.ast.span,
707    };
708    
709    entry_fn.ast.body = Block {
710        stmts: vec![Stmt::Return(Some(body_expr), func.ast.span)],
711        span: func.ast.span,
712    };
713    
714    // Update return type to Future (Struct)
715    // Note: In real implementation this would be impl Future<Output=T>
716    // For now, we return the struct directly.
717    entry_fn.resolved_type = ResolvedType::Function {
718        params: if let ResolvedType::Function{params, ..} = &func.resolved_type { params.clone() } else { vec![] },
719        ret: Box::new(ResolvedType::Struct(state_machine_name, fields)),
720        effects: crate::effects::EffectSet::new(), // Entry function is synchronous (returns Future)
721    };
722    
723    // Clear async effect
724    entry_fn.effects.effects.remove(&crate::effects::Effect::Async);
725    entry_fn.ast.effects.retain(|e| *e != crate::effects::Effect::Async);
726    
727    Ok(entry_fn)
728}
729
730fn rewrite_access_to_self(block: &mut Block, fields: &HashMap<String, ResolvedType>) {
731    for stmt in &mut block.stmts {
732        rewrite_stmt(stmt, fields);
733    }
734}
735
736fn rewrite_stmt(stmt: &mut Stmt, fields: &HashMap<String, ResolvedType>) {
737    // 1. Rewrite expressions inside statements
738    match stmt {
739        Stmt::Expr(e) => rewrite_expr(e, fields),
740        Stmt::Return(Some(e), _) => rewrite_expr(e, fields),
741        Stmt::Let { value: Some(e), .. } => rewrite_expr(e, fields),
742        Stmt::For { iter, body, .. } => {
743            rewrite_expr(iter, fields);
744            rewrite_access_to_self(body, fields);
745        }
746        Stmt::While { condition, body, .. } => {
747            rewrite_expr(condition, fields);
748            rewrite_access_to_self(body, fields);
749        }
750        _ => {}
751    }
752    
753    // 2. Transform local bindings to struct assignments if captured
754    let transform = if let Stmt::Let { pattern: Pattern::Binding { name, .. }, value: Some(e), span, .. } = stmt {
755        if fields.contains_key(name) {
756             Some((name.clone(), e.clone(), *span))
757        } else { None }
758    } else { None };
759    
760    if let Some((name, val, span)) = transform {
761        *stmt = Stmt::Expr(Expr::Assign {
762             target: Box::new(Expr::Field {
763                 object: Box::new(Expr::Ident("self".to_string(), span)),
764                 field: name,
765                 span,
766             }),
767             value: Box::new(val),
768             span,
769        });
770    }
771}
772
773fn rewrite_expr(expr: &mut Expr, fields: &HashMap<String, ResolvedType>) {
774    match expr {
775        Expr::Ident(name, span) => {
776            if fields.contains_key(name) {
777                // Transform `x` -> `self.x`
778                *expr = Expr::Field {
779                    object: Box::new(Expr::Ident("self".to_string(), *span)),
780                    field: name.clone(),
781                    span: *span,
782                };
783            }
784        }
785        Expr::Binary { left, right, .. } => {
786            rewrite_expr(left, fields);
787            rewrite_expr(right, fields);
788        }
789        Expr::Call { callee, args, .. } => {
790            rewrite_expr(callee, fields);
791            for arg in args {
792                rewrite_expr(&mut arg.value, fields);
793            }
794        }
795        Expr::Field { object, .. } => rewrite_expr(object, fields),
796        Expr::Await(inner, _) => rewrite_expr(inner, fields),
797        Expr::Block(b, _) => rewrite_access_to_self(b, fields),
798        // Add other recursive cases...
799        _ => {}
800    }
801}
802
803// === AWAIT CHOPPING HELPERS ===
804
/// Represents one `.await` found in an async function body.
///
/// Instances are produced by `collect_await_points` and consumed by
/// `split_at_awaits` / `generate_state_arm` when building the state machine.
#[derive(Clone, Debug)]
struct AwaitPoint {
    /// The operand of the `.await` (the future expression, without the `.await` itself)
    awaited_expr: Expr,
    /// Name from `let <name> = <expr>.await` when the `let` pattern is a simple
    /// binding; `None` for every other form (expression statement, `return`,
    /// nested awaits, destructuring patterns)
    result_binding: Option<String>,
    /// 0-based position of this point in the collected list (order of appearance);
    /// `generate_state_arm` uses it to name the `_await_{index}` and
    /// `_await_{index}_result` fields
    index: usize,
}
815
816/// Collect all await points from a block, in order of appearance
817fn collect_await_points(block: &Block) -> Vec<AwaitPoint> {
818    let mut points = Vec::new();
819    collect_awaits_from_block(block, &mut points);
820    points
821}
822
823fn collect_awaits_from_block(block: &Block, points: &mut Vec<AwaitPoint>) {
824    for stmt in &block.stmts {
825        collect_awaits_from_stmt(stmt, points);
826    }
827}
828
829fn collect_awaits_from_stmt(stmt: &Stmt, points: &mut Vec<AwaitPoint>) {
830    match stmt {
831        Stmt::Let { pattern, value, .. } => {
832            // Extract name from pattern if it's a simple binding
833            let name = match pattern {
834                Pattern::Binding { name: n, .. } => Some(n.clone()),
835                _ => None,
836            };
837            
838            // Check if the value is an await expression
839            if let Some(expr) = value {
840                if let Expr::Await(inner, _) = expr {
841                    points.push(AwaitPoint {
842                        awaited_expr: (**inner).clone(),
843                        result_binding: name,
844                        index: points.len(),
845                    });
846                } else {
847                    collect_awaits_from_expr(expr, points);
848                }
849            }
850        }
851        Stmt::Expr(expr) => {
852            if let Expr::Await(inner, _) = expr {
853                points.push(AwaitPoint {
854                    awaited_expr: (**inner).clone(),
855                    result_binding: None,
856                    index: points.len(),
857                });
858            } else {
859                collect_awaits_from_expr(expr, points);
860            }
861        }
862        Stmt::Return(Some(expr), _) => {
863            if let Expr::Await(inner, _) = expr {
864                points.push(AwaitPoint {
865                    awaited_expr: (**inner).clone(),
866                    result_binding: None, // Return will use the value directly
867                    index: points.len(),
868                });
869            } else {
870                collect_awaits_from_expr(expr, points);
871            }
872        }
873        // Note: In KORE, if is an expression, not a statement. If used in Stmt::Expr, 
874        // collect_awaits_from_expr will handle it.
875        Stmt::While { body, .. } | Stmt::Loop { body, .. } => {
876            collect_awaits_from_block(body, points);
877        }
878        Stmt::For { body, .. } => {
879            collect_awaits_from_block(body, points);
880        }
881        _ => {}
882    }
883}
884
885fn collect_awaits_from_expr(expr: &Expr, points: &mut Vec<AwaitPoint>) {
886    match expr {
887        Expr::Await(inner, _) => {
888            points.push(AwaitPoint {
889                awaited_expr: (**inner).clone(),
890                result_binding: None,
891                index: points.len(),
892            });
893        }
894        Expr::Binary { left, right, .. } => {
895            collect_awaits_from_expr(left, points);
896            collect_awaits_from_expr(right, points);
897        }
898        Expr::Call { callee, args, .. } => {
899            collect_awaits_from_expr(callee, points);
900            for arg in args {
901                collect_awaits_from_expr(&arg.value, points);
902            }
903        }
904        Expr::Block(block, _) => collect_awaits_from_block(block, points),
905        Expr::If { then_branch, else_branch, .. } => {
906            collect_awaits_from_block(then_branch, points);
907            if let Some(else_b) = else_branch {
908                match else_b.as_ref() {
909                    ElseBranch::Else(b) => collect_awaits_from_block(b, points),
910                    ElseBranch::ElseIf(_, then_b, _) => collect_awaits_from_block(then_b, points),
911                }
912            }
913        }
914        _ => {}
915    }
916}
917
/// Represents a run of top-level statements produced by `split_at_awaits`,
/// optionally terminated by an await point.
#[derive(Clone)]
struct CodeSegment {
    /// Statements preceding the await (or all remaining statements when
    /// `await_point` is `None`); the awaiting statement itself is not included
    stmts_before: Vec<Stmt>,
    /// The await that ends this segment, if any
    await_point: Option<AwaitPoint>,
    /// For a segment ending in an await: whether the awaiting statement was
    /// `return <expr>.await`. For the trailing segment: whether its last
    /// statement is a `return`.
    ends_with_return: bool,
}
928
929/// Split the function body at await points into segments
930fn split_at_awaits(block: &Block, await_points: &[AwaitPoint]) -> Vec<CodeSegment> {
931    let mut segments = Vec::new();
932    let mut current_stmts = Vec::new();
933    let mut await_idx = 0;
934    
935    for stmt in &block.stmts {
936        // Check if this statement contains an await at the top level
937        let contains_await = match stmt {
938            Stmt::Let { value: Some(Expr::Await(_, _)), .. } => true,
939            Stmt::Expr(Expr::Await(_, _)) => true,
940            Stmt::Return(Some(Expr::Await(_, _)), _) => true,
941            _ => false,
942        };
943        
944        if contains_await && await_idx < await_points.len() {
945            // End current segment, start new one after await
946            segments.push(CodeSegment {
947                stmts_before: current_stmts.clone(),
948                await_point: Some(await_points[await_idx].clone()),
949                ends_with_return: matches!(stmt, Stmt::Return(_, _)),
950            });
951            current_stmts.clear();
952            await_idx += 1;
953        } else {
954            current_stmts.push(stmt.clone());
955        }
956    }
957    
958    // Add final segment (code after last await)
959    if !current_stmts.is_empty() || segments.is_empty() {
960        let ends_with_return = current_stmts.last()
961            .map(|s| matches!(s, Stmt::Return(_, _)))
962            .unwrap_or(false);
963        segments.push(CodeSegment {
964            stmts_before: current_stmts,
965            await_point: None,
966            ends_with_return,
967        });
968    }
969    
970    segments
971}
972
/// Generate a match arm for a specific state in the state machine.
///
/// The arm matches the integer literal `state_idx`. Its body is a block built,
/// in order, from:
///
/// 1. For a continuation state (`0 < state_idx <= await_points.len()`): a
///    `match self._await_N.poll()` on the future stored by the previous state,
///    where `N` is `await_points[state_idx - 1].index`. `Poll::Pending` returns
///    `Poll::Pending`; `Poll::Ready(val)` stores `val` in
///    `self._await_N_result`. If that await had a result binding, the result
///    is then copied into `self.<binding>` (when the binding is a key of
///    `fields`) or into a fresh `let <binding>`.
/// 2. The segment's statements, each passed through `rewrite_stmt`.
/// 3. If the segment ends in an await: the awaited expression (passed through
///    `rewrite_expr`) is stored in `self._await_M`, `self.state` is set to
///    `state_idx + 1`, and `Poll::Pending` is returned. Otherwise, when the
///    segment does not end in `return` and `state_idx == await_points.len()`,
///    a `return Poll::Ready(..)` wrapping `Expr::None` is appended.
///
/// `fields` holds the locals hoisted into the state-machine struct;
/// `_state_machine_name` is unused; `span` is attached to every generated node.
fn generate_state_arm(
    state_idx: usize,
    segment: &CodeSegment,
    await_points: &[AwaitPoint],
    fields: &HashMap<String, ResolvedType>,
    _state_machine_name: &str,
    span: crate::span::Span,
) -> MatchArm {
    let mut body_stmts = Vec::new();
    
    // If this is a continuation state (after an await), we must POLL the future from the previous step
    if state_idx > 0 && state_idx <= await_points.len() {
        let prev_await = &await_points[state_idx - 1];
        let poll_field = format!("_await_{}", prev_await.index);
        let res_field = format!("_await_{}_result", prev_await.index);
        
        // Match expression to check poll status
        // match self._await_N.poll() { ... }
        let poll_call = Expr::MethodCall {
            receiver: Box::new(Expr::Field {
                object: Box::new(Expr::Ident("self".to_string(), span)),
                field: poll_field,
                span,
            }),
            method: "poll".to_string(),
            args: vec![],
            span,
        };
        
        // Arm 1: Poll::Pending => return Poll::Pending
        // `self.state` is not changed here, so the next poll re-enters this
        // same arm and polls the same stored future again.
        let pending_arm = MatchArm {
            pattern: Pattern::Variant {
                enum_name: Some("Poll".to_string()),
                variant: "Pending".to_string(),
                fields: VariantPatternFields::Unit,
                span,
            },
            guard: None,
            body: Expr::Return(
                Some(Box::new(Expr::EnumVariant {
                    enum_name: "Poll".to_string(),
                    variant: "Pending".to_string(),
                    fields: EnumVariantFields::Unit,
                    span,
                })),
                span,
            ),
            span,
        };
        
        // Arm 2: Poll::Ready(val) => { self._await_N_result = val; }
        // The ready value is bound to the pattern variable `val`.
        let val_name = "val".to_string();
        let ready_arm = MatchArm {
            pattern: Pattern::Variant {
                enum_name: Some("Poll".to_string()),
                variant: "Ready".to_string(),
                fields: VariantPatternFields::Tuple(vec![
                    Pattern::Binding { name: val_name.clone(), mutable: false, span }
                ]),
                span,
            },
            guard: None,
            body: Expr::Assign {
                target: Box::new(Expr::Field {
                    object: Box::new(Expr::Ident("self".to_string(), span)),
                    field: res_field.clone(),
                    span,
                }),
                value: Box::new(Expr::Ident(val_name, span)),
                span,
            },
            span,
        };
        
        let poll_match = Expr::Match {
            scrutinee: Box::new(poll_call),
            arms: vec![pending_arm, ready_arm],
            span,
        };
        
        body_stmts.push(Stmt::Expr(poll_match));
        
        // Expose the awaited result under the user's binding name: assign to
        // `self.<binding>` when the binding was hoisted into a field, otherwise
        // introduce a local `let <binding> = self._await_N_result`.
        if let Some(binding) = &prev_await.result_binding {
            if fields.contains_key(binding) {
                 // self.binding = self._await_N_result
                 body_stmts.push(Stmt::Expr(Expr::Assign {
                     target: Box::new(Expr::Field {
                         object: Box::new(Expr::Ident("self".to_string(), span)),
                         field: binding.clone(),
                         span,
                     }),
                     value: Box::new(Expr::Field {
                         object: Box::new(Expr::Ident("self".to_string(), span)),
                         field: res_field,
                         span,
                     }),
                     span,
                 }));
            } else {
                body_stmts.push(Stmt::Let {
                    pattern: Pattern::Binding { name: binding.clone(), mutable: false, span },
                    ty: None,
                    value: Some(Expr::Field {
                        object: Box::new(Expr::Ident("self".to_string(), span)),
                        field: res_field,
                        span,
                    }),
                    span,
                });
            }
        }
    }
    
    // Add the segment's statements (rewritten to use self.x)
    for stmt in &segment.stmts_before {
        let mut rewritten_stmt = stmt.clone();
        rewrite_stmt(&mut rewritten_stmt, fields);
        body_stmts.push(rewritten_stmt);
    }
    
    // Handle the await point (if any)
    if let Some(await_point) = &segment.await_point {
        // 1. Evaluate the future expression and store it
        let store_field = format!("_await_{}", await_point.index);
        let mut awaited_expr = await_point.awaited_expr.clone();
        rewrite_expr(&mut awaited_expr, fields);
        
        body_stmts.push(Stmt::Expr(Expr::Assign {
            target: Box::new(Expr::Field {
                object: Box::new(Expr::Ident("self".to_string(), span)),
                field: store_field,
                span,
            }),
            value: Box::new(awaited_expr),
            span,
        }));
        
        // 2. Increment state
        body_stmts.push(Stmt::Expr(Expr::Assign {
            target: Box::new(Expr::Field {
                object: Box::new(Expr::Ident("self".to_string(), span)),
                field: "state".to_string(),
                span,
            }),
            value: Box::new(Expr::Int((state_idx + 1) as i64, span)),
            span,
        }));
        
        // 3. Return Poll::Pending
        body_stmts.push(Stmt::Return(
            Some(Expr::EnumVariant {
                enum_name: "Poll".to_string(),
                variant: "Pending".to_string(),
                fields: EnumVariantFields::Unit,
                span,
            }),
            span,
        ));
    } else if segment.ends_with_return {
        // The segment already ends in its own `return`, so nothing is appended.
        // NOTE(review): that return only passed through `rewrite_stmt` above;
        // confirm some later pass wraps its value in `Poll::Ready`.
    } else if state_idx == await_points.len() {
        // Final state after all awaits: return `Poll::Ready` wrapping `Expr::None`
        // (presumably the unit/none value).
        body_stmts.push(Stmt::Return(
            Some(Expr::EnumVariant {
                enum_name: "Poll".to_string(),
                variant: "Ready".to_string(),
                fields: EnumVariantFields::Tuple(vec![Expr::None(span)]),
                span,
            }),
            span,
        ));
    }
    
    MatchArm {
        pattern: Pattern::Literal(Expr::Int(state_idx as i64, span)),
        guard: None,
        body: Expr::Block(Block { stmts: body_stmts, span }, span),
        span,
    }
}
1159
1160/// Wrap the body's returns in Poll::Ready
1161fn wrap_return_in_poll_ready(mut block: Block, span: crate::span::Span) -> Expr {
1162    for stmt in &mut block.stmts {
1163        wrap_stmt_returns(stmt, span);
1164    }
1165    Expr::Block(block, span)
1166}
1167
1168fn wrap_stmt_returns(stmt: &mut Stmt, span: crate::span::Span) {
1169    match stmt {
1170        Stmt::Return(Some(expr), _) => {
1171            // Wrap: return x -> return Poll::Ready(x)
1172            let inner = std::mem::replace(expr, Expr::None(span));
1173            *expr = Expr::EnumVariant {
1174                enum_name: "Poll".to_string(),
1175                variant: "Ready".to_string(),
1176                fields: EnumVariantFields::Tuple(vec![inner]),
1177                span,
1178            };
1179        }
1180        Stmt::Return(None, s) => {
1181            // Wrap: return -> return Poll::Ready(())
1182            *stmt = Stmt::Return(
1183                Some(Expr::EnumVariant {
1184                    enum_name: "Poll".to_string(),
1185                    variant: "Ready".to_string(),
1186                    fields: EnumVariantFields::Tuple(vec![Expr::None(span)]),
1187                    span,
1188                }),
1189                *s,
1190            );
1191        }
1192        // Note: In KORE, if is an expression. Expr::If would be in Stmt::Expr, 
1193        // but we handle expressions separately if needed.
1194        Stmt::While { body, .. } | Stmt::Loop { body, .. } => {
1195            for s in &mut body.stmts {
1196                wrap_stmt_returns(s, span);
1197            }
1198        }
1199        Stmt::For { body, .. } => {
1200            for s in &mut body.stmts {
1201                wrap_stmt_returns(s, span);
1202            }
1203        }
1204        _ => {}
1205    }
1206}
1207
1208fn scan_function(ctx: &mut MonoContext, func: &TypedFunction) -> KoreResult<TypedFunction> {
1209    let mut new_func = func.clone();
1210    let mut env = MonoTypeEnv::new();
1211    
1212    if let ResolvedType::Function { params, .. } = &func.resolved_type {
1213        for (i, p) in params.iter().enumerate() {
1214            if i < func.ast.params.len() {
1215                env.define(func.ast.params[i].name.clone(), p.clone());
1216            }
1217        }
1218    }
1219    
1220    scan_block(ctx, &mut env, &mut new_func.ast.body)?;
1221    Ok(new_func)
1222}
1223
1224fn scan_block(ctx: &mut MonoContext, env: &mut MonoTypeEnv, block: &mut Block) -> KoreResult<()> {
1225    env.push();
1226    for stmt in &mut block.stmts {
1227        scan_stmt(ctx, env, stmt)?;
1228    }
1229    env.pop();
1230    Ok(())
1231}
1232
1233fn scan_stmt(ctx: &mut MonoContext, env: &mut MonoTypeEnv, stmt: &mut Stmt) -> KoreResult<()> {
1234    match stmt {
1235        Stmt::Expr(e) => { scan_expr(ctx, env, e)?; }
1236        Stmt::Return(Some(e), _) => { scan_expr(ctx, env, e)?; }
1237        Stmt::Let { pattern, value, .. } => {
1238            // Scan the value expression (may contain generic calls like identity(42))
1239            if let Some(val_expr) = value {
1240                let ty = scan_expr(ctx, env, val_expr)?;
1241                // Also define the binding in the environment for type inference
1242                if let Pattern::Binding { name, .. } = pattern {
1243                    env.define(name.clone(), ty);
1244                }
1245            }
1246        }
1247        Stmt::For { binding, iter, body, .. } => {
1248            let iter_ty = scan_expr(ctx, env, iter)?;
1249            let elem_ty = match iter_ty {
1250                ResolvedType::Array(inner, _) => *inner,
1251                _ => ResolvedType::Int(IntSize::I64),
1252            };
1253            
1254            env.push();
1255            if let Pattern::Binding { name, .. } = binding {
1256                env.define(name.clone(), elem_ty);
1257            }
1258            scan_block(ctx, env, body)?;
1259            env.pop();
1260        }
1261        Stmt::While { condition, body, .. } => {
1262            scan_expr(ctx, env, condition)?;
1263            scan_block(ctx, env, body)?;
1264        }
1265        _ => {}
1266    }
1267    Ok(())
1268}
1269
1270fn scan_expr(ctx: &mut MonoContext, env: &mut MonoTypeEnv, expr: &mut Expr) -> KoreResult<ResolvedType> {
1271    match expr {
1272        Expr::Int(_, _) => Ok(ResolvedType::Int(IntSize::I64)),
1273        Expr::Float(_, _) => Ok(ResolvedType::Float(FloatSize::F64)),
1274        Expr::String(_, _) => Ok(ResolvedType::String),
1275        Expr::Bool(_, _) => Ok(ResolvedType::Bool),
1276        Expr::Ident(name, _) => Ok(env.get(name)),
1277        Expr::Struct { name, fields, .. } => {
1278            for (_, val) in fields {
1279                scan_expr(ctx, env, val)?;
1280            }
1281            // Return struct type
1282            // Ideally we check fields against definition, but here we just return the type
1283            Ok(ResolvedType::Struct(name.clone(), HashMap::new()))
1284        },
1285        Expr::Field { object, field, span: _ } => {
1286            let obj_ty = scan_expr(ctx, env, object)?;
1287            match obj_ty {
1288                ResolvedType::Struct(name, _) => {
1289                    if let Some(fields) = ctx.structs.get(&name) {
1290                         if let Some(ty) = fields.get(field) {
1291                             return Ok(ty.clone());
1292                         }
1293                    }
1294                    // If struct logic isn't fully loaded or field missing, return Unknown but maybe warn?
1295                    // For now, if we can't find it, we can't infer proper type for chain calls.
1296                    Ok(ResolvedType::Unknown)
1297                }
1298                _ => Ok(ResolvedType::Unknown),
1299            }
1300        },
1301        Expr::MethodCall { receiver, method, args, span } => {
1302            let receiver_ty = scan_expr(ctx, env, receiver)?;
1303            
1304            let type_name = match &receiver_ty {
1305                ResolvedType::Struct(name, _) => name.clone(),
1306                ResolvedType::Int(_) => "Int".to_string(),
1307                ResolvedType::Float(_) => "Float".to_string(),
1308                ResolvedType::String => "String".to_string(),
1309                _ => {
1310                    if let ResolvedType::Unknown = receiver_ty {
1311                         // Don't error hard yet, as we might be in partial state
1312                         return Ok(ResolvedType::Unknown);
1313                    }
1314                    format!("{:?}", receiver_ty)
1315                }
1316            };
1317            
1318            let mangled_target = {
1319                let methods = ctx.methods.get(&type_name);
1320                if let Some(lookup) = methods {
1321                    lookup.get(method).cloned()
1322                } else {
1323                    None
1324                }
1325            };
1326            
1327            if let Some(target_name) = mangled_target {
1328                 let mut new_args = args.clone();
1329                 new_args.insert(0, CallArg { name: None, value: *receiver.clone(), span: receiver.span() });
1330                 
1331                 for arg in &mut new_args {
1332                     scan_expr(ctx, env, &mut arg.value)?;
1333                 }
1334
1335                 *expr = Expr::Call {
1336                     callee: Box::new(Expr::Ident(target_name, *span)), // No ctx borrow here
1337                     args: new_args,
1338                     span: *span
1339                 };
1340                 
1341                 // Ideally lookup return type of function
1342                 // For now Unknown is safe for logic
1343                 return Ok(ResolvedType::Unknown);
1344            }
1345
1346            Ok(ResolvedType::Unknown)
1347        }
1348        Expr::Call { callee, args, .. } => {
1349            if let Expr::Ident(name, _) = callee.as_ref() {
1350                if let Some(generic_func) = ctx.generic_functions.get(name).cloned() {
1351                    // First, scan all arguments to get their types
1352                    let mut arg_types = Vec::new();
1353                    for arg in args {
1354                        arg_types.push(scan_expr(ctx, env, &mut arg.value)?);
1355                    }
1356                    
1357                    // Infer type arguments through unification
1358                    let inferred_type_args = infer_type_args(ctx, &generic_func, &arg_types)?;
1359                    
1360                    let new_name = ctx.instantiate(name, &inferred_type_args)?;
1361                    *callee = Box::new(Expr::Ident(new_name, callee.span()));
1362                    return Ok(ResolvedType::Unknown); 
1363                }
1364                
1365                // If it's a standard function, we might want to lookup return type
1366                // But for now, just scan args
1367            }
1368             for arg in args {
1369                scan_expr(ctx, env, &mut arg.value)?;
1370            }
1371            Ok(ResolvedType::Unknown)
1372        }
1373        Expr::Binary { left, right, .. } => {
1374            let t = scan_expr(ctx, env, left)?;
1375            scan_expr(ctx, env, right)?;
1376            Ok(t) 
1377        }
1378        Expr::Assign { value, .. } => scan_expr(ctx, env, value),
1379        Expr::Block(b, _) => {
1380            scan_block(ctx, env, b)?;
1381            Ok(ResolvedType::Unknown)
1382        }
1383        Expr::If { condition, then_branch, else_branch, .. } => {
1384            scan_expr(ctx, env, condition)?;
1385            scan_block(ctx, env, then_branch)?;
1386            if let Some(b) = else_branch {
1387                 match b.as_mut() {
1388                     ElseBranch::Else(blk) => { scan_block(ctx, env, blk)?; }
1389                     ElseBranch::ElseIf(_, _, _) => {} 
1390                 }
1391            }
1392             Ok(ResolvedType::Unknown)
1393        }
1394        Expr::Await(inner, _) => {
1395            // Scan the inner future expression for generic calls
1396            scan_expr(ctx, env, inner)
1397        }
1398        _ => Ok(ResolvedType::Unknown),
1399    }
1400}
1401
1402fn collect_locals(block: &Block) -> HashMap<String, ResolvedType> {
1403    let mut locals = HashMap::new();
1404    collect_locals_recursive(block, &mut locals);
1405    locals
1406}
1407
1408fn collect_locals_recursive(block: &Block, locals: &mut HashMap<String, ResolvedType>) {
1409    for stmt in &block.stmts {
1410        match stmt {
1411            Stmt::Let { pattern, .. } => collect_from_pattern(pattern, locals),
1412            Stmt::For { body, .. } => collect_locals_recursive(body, locals),
1413            Stmt::While { body, .. } => collect_locals_recursive(body, locals),
1414            Stmt::Expr(Expr::Block(b, _)) => collect_locals_recursive(b, locals),
1415            Stmt::Expr(Expr::If { then_branch, else_branch, .. }) => {
1416                collect_locals_recursive(then_branch, locals);
1417                if let Some(b) = else_branch {
1418                    collect_from_else(b, locals);
1419                }
1420            }
1421            _ => {}
1422        }
1423    }
1424}
1425
1426fn collect_from_else(branch: &ElseBranch, locals: &mut HashMap<String, ResolvedType>) {
1427     match branch {
1428         ElseBranch::Else(block) => collect_locals_recursive(block, locals),
1429         ElseBranch::ElseIf(_, block, next) => {
1430             collect_locals_recursive(block, locals);
1431             if let Some(n) = next {
1432                 collect_from_else(n, locals);
1433             }
1434         }
1435     }
1436}
1437
1438fn collect_from_pattern(pattern: &Pattern, locals: &mut HashMap<String, ResolvedType>) {
1439    match pattern {
1440        Pattern::Binding { name, .. } => { locals.insert(name.clone(), ResolvedType::Unknown); },
1441        Pattern::Tuple(pats, _) => {
1442            for p in pats { collect_from_pattern(p, locals); }
1443        }
1444        Pattern::Slice { patterns, rest, .. } => {
1445             for p in patterns { collect_from_pattern(p, locals); }
1446             if let Some(r) = rest {
1447                 locals.insert(r.clone(), ResolvedType::Unknown);
1448             }
1449        }
1450        Pattern::Struct { fields, .. } => {
1451             for (_, p) in fields { collect_from_pattern(p, locals); }
1452        }
1453        Pattern::Variant { fields, .. } => {
1454              match fields {
1455                  VariantPatternFields::Tuple(pats) => { for p in pats { collect_from_pattern(p, locals); } },
1456                  VariantPatternFields::Struct(pats) => { for (_, p) in pats { collect_from_pattern(p, locals); } },
1457                  _ => {}
1458              }
1459        }
1460        Pattern::Or(pats, _) => {
1461             for p in pats { collect_from_pattern(p, locals); }
1462        }
1463        _ => {}
1464    }
1465}
1466