1use std::collections::{HashMap, HashSet};
22
23use crate::analysis::unify::{InferType, TyVar, TypeScheme, TypeError, UnificationTable, infer_to_logos, unify_numeric};
24use crate::analysis::{FnSig, LogosType, TypeDef, TypeEnv, TypeRegistry};
25use crate::ast::stmt::{BinaryOpKind, Expr, OptFlag, Pattern, Stmt};
26use crate::intern::{Interner, Symbol};
27
28#[derive(Clone, Debug)]
38struct FunctionRecord {
39 param_names: Vec<Symbol>,
41 scheme: TypeScheme,
44}
45
46struct CheckEnv<'r> {
51 scopes: Vec<HashMap<Symbol, InferType>>,
53 all_vars: HashMap<Symbol, InferType>,
55 functions: HashMap<Symbol, FunctionRecord>,
57 current_return_type: Option<InferType>,
59 table: UnificationTable,
61 registry: &'r TypeRegistry,
62 interner: &'r Interner,
63}
64
65impl<'r> CheckEnv<'r> {
66 fn new(registry: &'r TypeRegistry, interner: &'r Interner) -> Self {
67 Self {
68 scopes: vec![HashMap::new()],
69 all_vars: HashMap::new(),
70 functions: HashMap::new(),
71 current_return_type: None,
72 table: UnificationTable::new(),
73 registry,
74 interner,
75 }
76 }
77
78 fn push_scope(&mut self) {
79 self.scopes.push(HashMap::new());
80 }
81
82 fn pop_scope(&mut self) {
83 self.scopes.pop();
84 }
85
86 fn bind_var(&mut self, sym: Symbol, ty: InferType) {
88 if let Some(scope) = self.scopes.last_mut() {
89 scope.insert(sym, ty.clone());
90 }
91 self.all_vars.insert(sym, ty);
92 }
93
94 fn lookup_var(&self, sym: Symbol) -> Option<InferType> {
100 for scope in self.scopes.iter().rev() {
101 if let Some(ty) = scope.get(&sym) {
102 return Some(self.table.resolve(ty));
103 }
104 }
105 None
106 }
107
108 fn into_type_env(self) -> TypeEnv {
110 let mut type_env = TypeEnv::new();
111
112 for (sym, ty) in self.all_vars {
114 let logos_ty = self.table.to_logos_type(&ty);
115 type_env.register(sym, logos_ty);
116 }
117
118 for (name, rec) in self.functions {
120 if let InferType::Function(param_types, ret_box) = &rec.scheme.body {
124 let ret_logos = self.table.to_logos_type(ret_box);
125 let params: Vec<(Symbol, LogosType)> = rec.param_names.iter()
126 .zip(param_types.iter())
127 .map(|(sym, ty)| (*sym, self.table.to_logos_type(ty)))
128 .collect();
129 type_env.register_fn(name, FnSig { params, return_type: ret_logos });
130 }
131 }
132
133 type_env
134 }
135}
136
137impl<'r> CheckEnv<'r> {
142 fn preregister_functions(&mut self, stmts: &[Stmt]) {
150 for stmt in stmts {
151 if let Stmt::FunctionDef { name, generics, params, return_type, .. } = stmt {
152 let type_param_map: HashMap<Symbol, TyVar> = generics
154 .iter()
155 .map(|&sym| (sym, self.table.fresh_var()))
156 .collect();
157
158 let param_types: Vec<InferType> = params
159 .iter()
160 .map(|(_, ty_expr)| {
161 InferType::from_type_expr_with_params(ty_expr, self.interner, &type_param_map)
162 })
163 .collect();
164 let param_names: Vec<Symbol> = params.iter().map(|(sym, _)| *sym).collect();
165
166 let ret_type = if let Some(rt) = return_type {
167 InferType::from_type_expr_with_params(rt, self.interner, &type_param_map)
168 } else {
169 self.table.fresh()
170 };
171
172 let generic_vars: Vec<TyVar> = generics
173 .iter()
174 .filter_map(|sym| type_param_map.get(sym).copied())
175 .collect();
176
177 let scheme = TypeScheme {
178 vars: generic_vars,
179 body: InferType::Function(param_types, Box::new(ret_type)),
180 };
181
182 self.functions.insert(*name, FunctionRecord { param_names, scheme });
183 }
184 }
185 }
186}
187
188impl<'r> CheckEnv<'r> {
193 fn check_expr(
198 &mut self,
199 expr: &Expr,
200 expected: &InferType,
201 ) -> Result<InferType, TypeError> {
202 use crate::ast::stmt::Literal;
203
204 if let Expr::Literal(Literal::Number(_)) = expr {
206 match expected {
207 InferType::Float => return Ok(InferType::Float),
208 InferType::Nat => return Ok(InferType::Nat),
209 InferType::Int => return Ok(InferType::Int),
210 InferType::Byte => return Ok(InferType::Byte),
211 _ => {}
212 }
213 }
214
215 if let Expr::Literal(Literal::Nothing) = expr {
218 if let InferType::Option(_) = expected {
219 return Ok(expected.clone());
220 }
221 }
222
223 let inferred = self.infer_expr(expr)?;
225 self.table.unify(&inferred, expected)?;
226 Ok(self.table.zonk(expected))
227 }
228
229 fn infer_expr(&mut self, expr: &Expr) -> Result<InferType, TypeError> {
231 match expr {
232 Expr::Literal(lit) => Ok(InferType::from_literal(lit)),
233
234 Expr::Identifier(sym) => {
235 Ok(self.lookup_var(*sym).unwrap_or(InferType::Unknown))
236 }
237
238 Expr::BinaryOp { op, left, right } => {
239 self.infer_binary_op(*op, left, right)
240 }
241
242 Expr::Length { .. } => Ok(InferType::Int),
243
244 Expr::Call { function, args } => {
245 self.infer_call(*function, args)
246 }
247
248 Expr::Index { collection, .. } => {
249 let coll_ty = self.infer_expr(collection)?;
250 let walked = self.table.zonk(&coll_ty);
251 match walked {
252 InferType::Seq(inner) => Ok(*inner),
253 InferType::Map(_, v) => Ok(*v),
254 _ => Ok(InferType::Unknown),
255 }
256 }
257
258 Expr::List(items) => {
259 if items.is_empty() {
260 let elem_var = self.table.fresh();
261 Ok(InferType::Seq(Box::new(elem_var)))
262 } else {
263 let elem_type = self.infer_expr(items[0])?;
264 Ok(InferType::Seq(Box::new(elem_type)))
265 }
266 }
267
268 Expr::OptionSome { value } => {
269 let inner = self.infer_expr(value)?;
270 Ok(InferType::Option(Box::new(inner)))
271 }
272
273 Expr::OptionNone => {
274 let elem_var = self.table.fresh();
275 Ok(InferType::Option(Box::new(elem_var)))
276 }
277
278 Expr::Range { .. } => Ok(InferType::Seq(Box::new(InferType::Int))),
279
280 Expr::Contains { .. } => Ok(InferType::Bool),
281
282 Expr::Copy { expr: inner } | Expr::Give { value: inner } => {
283 self.infer_expr(inner)
284 }
285
286 Expr::WithCapacity { value, .. } => self.infer_expr(value),
287
288 Expr::FieldAccess { object, field } => {
289 let obj_ty = self.infer_expr(object)?;
290 self.infer_field_access(obj_ty, *field)
291 }
292
293 Expr::New { type_name, type_args, .. } => {
294 let name = self.interner.resolve(*type_name);
295 match name {
296 "Seq" | "List" | "Vec" => {
297 let elem = type_args
298 .first()
299 .map(|t| InferType::from_type_expr(t, self.interner))
300 .unwrap_or_else(|| self.table.fresh());
301 Ok(InferType::Seq(Box::new(elem)))
302 }
303 "Map" | "HashMap" => {
304 let key = type_args
305 .first()
306 .map(|t| InferType::from_type_expr(t, self.interner))
307 .unwrap_or(InferType::String);
308 let val = type_args
309 .get(1)
310 .map(|t| InferType::from_type_expr(t, self.interner))
311 .unwrap_or(InferType::String);
312 Ok(InferType::Map(Box::new(key), Box::new(val)))
313 }
314 "Set" | "HashSet" => {
315 let elem = type_args
316 .first()
317 .map(|t| InferType::from_type_expr(t, self.interner))
318 .unwrap_or_else(|| self.table.fresh());
319 Ok(InferType::Set(Box::new(elem)))
320 }
321 _ => Ok(InferType::UserDefined(*type_name)),
322 }
323 }
324
325 Expr::NewVariant { enum_name, .. } => {
326 Ok(InferType::UserDefined(*enum_name))
327 }
328
329 Expr::CallExpr { callee, args } => {
330 self.infer_call_expr(callee, args)
331 }
332
333 Expr::Closure { params, body: closure_body, return_type } => {
334 self.infer_closure(params, closure_body, return_type)
335 }
336
337 Expr::InterpolatedString(_) => Ok(InferType::String),
338
339 Expr::Slice { collection, .. } => self.infer_expr(collection),
340
341 Expr::Union { left, .. } | Expr::Intersection { left, .. } => {
342 self.infer_expr(left)
343 }
344
345 _ => Ok(InferType::Unknown),
347 }
348 }
349
350 fn infer_binary_op(
352 &mut self,
353 op: BinaryOpKind,
354 left: &Expr,
355 right: &Expr,
356 ) -> Result<InferType, TypeError> {
357 match op {
358 BinaryOpKind::Eq
359 | BinaryOpKind::NotEq
360 | BinaryOpKind::Lt
361 | BinaryOpKind::Gt
362 | BinaryOpKind::LtEq
363 | BinaryOpKind::GtEq => Ok(InferType::Bool),
364
365 BinaryOpKind::And | BinaryOpKind::Or => {
367 let lt = self.infer_expr(left)?;
368 if lt == InferType::Int {
369 Ok(InferType::Int)
370 } else {
371 Ok(InferType::Bool)
372 }
373 }
374
375 BinaryOpKind::Concat => Ok(InferType::String),
376
377 BinaryOpKind::BitXor | BinaryOpKind::Shl | BinaryOpKind::Shr => Ok(InferType::Int),
378
379 BinaryOpKind::Add => {
380 let lt = self.infer_expr(left)?;
381 let rt = self.infer_expr(right)?;
382 if lt == InferType::String || rt == InferType::String {
383 Ok(InferType::String)
384 } else if lt == InferType::Unknown || rt == InferType::Unknown {
385 Ok(InferType::Unknown)
386 } else {
387 unify_numeric(<, &rt).or(Ok(InferType::Unknown))
388 }
389 }
390
391 BinaryOpKind::Subtract
392 | BinaryOpKind::Multiply
393 | BinaryOpKind::Divide
394 | BinaryOpKind::Modulo => {
395 let lt = self.infer_expr(left)?;
396 let rt = self.infer_expr(right)?;
397 if lt == InferType::Unknown || rt == InferType::Unknown {
398 Ok(InferType::Unknown)
399 } else {
400 unify_numeric(<, &rt).or(Ok(InferType::Unknown))
401 }
402 }
403 }
404 }
405
406 fn infer_call(&mut self, function: Symbol, args: &[&Expr]) -> Result<InferType, TypeError> {
412 let name = self.interner.resolve(function);
413 match name {
414 "sqrt" | "parseFloat" | "pow" => Ok(InferType::Float),
415 "parseInt" | "floor" | "ceil" | "round" => Ok(InferType::Int),
416 "abs" | "min" | "max" => {
417 if let Some(first) = args.first() {
418 self.infer_expr(first)
419 } else {
420 Ok(InferType::Unknown)
421 }
422 }
423 _ => {
424 if let Some(rec) = self.functions.get(&function).cloned() {
425 let instantiated = self.table.instantiate(&rec.scheme);
428
429 if let InferType::Function(param_types, ret_box) = instantiated {
430 for (arg, param_ty) in args.iter().zip(param_types.iter()) {
432 let arg_ty = self.infer_expr(arg)?;
433 self.table.unify(&arg_ty, param_ty)?;
434 }
435 Ok(self.table.zonk(&ret_box))
436 } else {
437 Ok(InferType::Unknown)
439 }
440 } else {
441 Ok(InferType::Unknown)
442 }
443 }
444 }
445 }
446
447 fn infer_call_expr(
449 &mut self,
450 callee: &Expr,
451 args: &[&Expr],
452 ) -> Result<InferType, TypeError> {
453 let callee_ty = self.infer_expr(callee)?;
454 let ret_var = self.table.fresh();
455 let arg_types: Vec<InferType> = args
456 .iter()
457 .map(|a| self.infer_expr(a))
458 .collect::<Result<_, _>>()?;
459 let fn_ty = InferType::Function(arg_types, Box::new(ret_var.clone()));
460
461 let walked = self.table.zonk(&callee_ty);
462 match walked {
463 InferType::Unknown => Ok(ret_var),
464 InferType::Function(_, _) => {
465 self.table.unify(&walked, &fn_ty)?;
466 Ok(ret_var)
467 }
468 InferType::Var(_) => {
469 self.table.unify(&walked, &fn_ty)?;
470 Ok(ret_var)
471 }
472 other => Err(TypeError::NotAFunction { found: other }),
473 }
474 }
475
476 fn infer_closure(
478 &mut self,
479 params: &[(Symbol, &crate::ast::stmt::TypeExpr)],
480 body: &crate::ast::stmt::ClosureBody,
481 return_type: &Option<&crate::ast::stmt::TypeExpr>,
482 ) -> Result<InferType, TypeError> {
483 let param_types: Vec<InferType> = params
484 .iter()
485 .map(|(_, ty_expr)| InferType::from_type_expr(ty_expr, self.interner))
486 .collect();
487
488 let ret_type = if let Some(rt) = return_type {
489 InferType::from_type_expr(rt, self.interner)
490 } else {
491 self.table.fresh()
492 };
493
494 self.push_scope();
495 for ((sym, _), ty) in params.iter().zip(param_types.iter()) {
496 self.bind_var(*sym, ty.clone());
497 }
498
499 let prev_return = self.current_return_type.take();
500 self.current_return_type = Some(ret_type.clone());
501
502 match body {
503 crate::ast::stmt::ClosureBody::Expression(expr) => {
504 let body_ty = self.infer_expr(expr)?;
505 self.table.unify(&body_ty, &ret_type).ok();
507 }
508 crate::ast::stmt::ClosureBody::Block(stmts) => {
509 for stmt in *stmts {
510 self.infer_stmt(stmt)?;
511 }
512 }
513 }
514
515 self.current_return_type = prev_return;
516 self.pop_scope();
517
518 Ok(InferType::Function(param_types, Box::new(ret_type)))
519 }
520
521 fn infer_field_access(
523 &self,
524 obj_ty: InferType,
525 field: Symbol,
526 ) -> Result<InferType, TypeError> {
527 let resolved = self.table.zonk(&obj_ty);
528 match &resolved {
529 InferType::UserDefined(type_sym) => {
530 if let Some(TypeDef::Struct { fields, .. }) = self.registry.get(*type_sym) {
531 if let Some(field_def) = fields.iter().find(|f| f.name == field) {
532 Ok(InferType::from_field_type(
533 &field_def.ty,
534 self.interner,
535 &HashMap::new(),
536 ))
537 } else {
538 Err(TypeError::FieldNotFound {
539 type_name: *type_sym,
540 field_name: field,
541 })
542 }
543 } else {
544 Ok(InferType::Unknown)
546 }
547 }
548 _ => Ok(InferType::Unknown),
550 }
551 }
552}
553
impl<'r> CheckEnv<'r> {
    /// Infers types for one statement. It records variable bindings and
    /// function signatures and adds constraints to the unification table.
    ///
    /// Statement kinds without an arm below are accepted unchecked. Control
    /// statements (`If`, `While`, `Zone`, ...) only visit their bodies; their
    /// conditions are not inferred.
    ///
    /// # Errors
    /// Returns the first `TypeError` from a nested expression or statement,
    /// or from a failed unification. Examples are a `Let` annotation check,
    /// a call argument, or a `Return` value against the enclosing return
    /// type. On error, scopes and the saved return type are not restored;
    /// `check_program` discards the environment in that case.
    fn infer_stmt(&mut self, stmt: &Stmt) -> Result<(), TypeError> {
        match stmt {
            Stmt::Let { var, ty, value, .. } => {
                // An annotation switches to checking mode unless it maps to
                // Unknown, in which case the value's own type is used.
                let final_ty = if let Some(type_expr) = ty {
                    let annotated = InferType::from_type_expr(type_expr, self.interner);
                    if annotated != InferType::Unknown {
                        self.check_expr(value, &annotated)?
                    } else {
                        self.infer_expr(value)?
                    }
                } else {
                    self.infer_expr(value)?
                };
                self.bind_var(*var, final_ty);
                Ok(())
            }

            Stmt::Set { target, value } => {
                let inferred = self.infer_expr(value)?;
                if let Some(existing) = self.lookup_var(*target) {
                    // Best-effort: a conflict with the variable's existing
                    // type is discarded rather than reported.
                    if existing != InferType::Unknown {
                        self.table.unify(&inferred, &existing).ok();
                    }
                }
                let resolved = self.table.zonk(&inferred);
                // Only a known type replaces the binding. `bind_var` writes to
                // the innermost scope, which may not be where `target` was declared.
                if resolved != InferType::Unknown {
                    self.bind_var(*target, resolved);
                }
                Ok(())
            }

            Stmt::FunctionDef {
                name,
                generics,
                params,
                body,
                return_type,
                is_native,
                ..
            } => {
                // Reuse the generic type variables allocated during
                // pre-registration when the generic count matches, so the
                // forward declaration and this definition share them;
                // otherwise allocate fresh variables.
                let type_param_map: HashMap<Symbol, TyVar> = {
                    let existing_vars: Vec<TyVar> = self.functions
                        .get(name)
                        .map(|rec| rec.scheme.vars.clone())
                        .unwrap_or_default();
                    if existing_vars.len() == generics.len() {
                        generics.iter().copied().zip(existing_vars).collect()
                    } else {
                        generics.iter().map(|&sym| (sym, self.table.fresh_var())).collect()
                    }
                };

                let param_types: Vec<InferType> = params
                    .iter()
                    .map(|(_, ty_expr)| {
                        InferType::from_type_expr_with_params(ty_expr, self.interner, &type_param_map)
                    })
                    .collect();
                let param_names: Vec<Symbol> = params.iter().map(|(sym, _)| *sym).collect();

                // Without an annotation, reuse the already-registered return
                // type (rather than a new variable) so the definition and its
                // forward declaration agree.
                let ret_type = if let Some(rt) = return_type {
                    InferType::from_type_expr_with_params(rt, self.interner, &type_param_map)
                } else if let Some(rec) = self.functions.get(name) {
                    if let InferType::Function(_, ret_box) = &rec.scheme.body {
                        *ret_box.clone()
                    } else {
                        self.table.fresh()
                    }
                } else {
                    self.table.fresh()
                };

                let generic_vars: Vec<TyVar> = generics
                    .iter()
                    .filter_map(|sym| type_param_map.get(sym).copied())
                    .collect();

                // Native functions have no body to check: record the declared
                // signature as-is.
                if *is_native {
                    let scheme = TypeScheme {
                        vars: generic_vars,
                        body: InferType::Function(param_types, Box::new(ret_type)),
                    };
                    self.functions.insert(*name, FunctionRecord { param_names, scheme });
                    return Ok(());
                }

                // Save the enclosing return type so it is restored after the body.
                let prev_return_type = self.current_return_type.take();
                self.current_return_type = Some(ret_type.clone());

                // Parameters are bound in a fresh scope for the body.
                self.push_scope();
                for (sym, ty) in param_names.iter().zip(param_types.iter()) {
                    self.bind_var(*sym, ty.clone());
                }
                for s in *body {
                    self.infer_stmt(s)?;
                }
                self.pop_scope();

                self.current_return_type = prev_return_type;

                // Re-record the signature with whatever the body solved for
                // the parameter and return types.
                let resolved_params: Vec<InferType> = param_types
                    .iter()
                    .map(|ty| self.table.resolve(ty))
                    .collect();
                let resolved_ret = self.table.resolve(&ret_type);

                let scheme = TypeScheme {
                    vars: generic_vars,
                    body: InferType::Function(resolved_params, Box::new(resolved_ret)),
                };
                self.functions.insert(*name, FunctionRecord { param_names, scheme });
                Ok(())
            }

            Stmt::Return { value } => {
                // A bare `Return` has type Unit.
                let ty = match value {
                    Some(expr) => self.infer_expr(expr)?,
                    None => InferType::Unit,
                };
                // Mismatches with a known enclosing return type are errors.
                if let Some(expected) = self.current_return_type.clone() {
                    if expected != InferType::Unknown {
                        self.table.unify(&ty, &expected)?;
                    }
                }
                Ok(())
            }

            Stmt::Repeat { pattern, iterable, body } => {
                // Seq/Set iterate elements and Map iterates keys; anything
                // else gives Unknown.
                let iterable_ty = self.infer_expr(iterable)?;
                let elem_ty = match self.table.zonk(&iterable_ty) {
                    InferType::Seq(inner) | InferType::Set(inner) => *inner,
                    InferType::Map(k, _) => *k,
                    _ => InferType::Unknown,
                };
                // The loop variable is bound in the current scope (no new
                // scope is pushed); tuple patterns are left Unknown.
                match pattern {
                    Pattern::Identifier(sym) => self.bind_var(*sym, elem_ty),
                    Pattern::Tuple(syms) => {
                        for sym in syms {
                            self.bind_var(*sym, InferType::Unknown);
                        }
                    }
                }
                for s in *body {
                    self.infer_stmt(s)?;
                }
                Ok(())
            }

            Stmt::If { then_block, else_block, .. } => {
                for s in *then_block {
                    self.infer_stmt(s)?;
                }
                if let Some(else_b) = else_block {
                    for s in *else_b {
                        self.infer_stmt(s)?;
                    }
                }
                Ok(())
            }

            Stmt::While { body, .. } => {
                for s in *body {
                    self.infer_stmt(s)?;
                }
                Ok(())
            }

            Stmt::Inspect { target, arms, .. } => {
                // The target is inferred for its side effects only; each arm
                // gets its own scope for its bindings.
                let _target_ty = self.infer_expr(target)?;
                for arm in arms {
                    self.push_scope();
                    self.infer_inspect_arm(arm)?;
                    self.pop_scope();
                }
                Ok(())
            }

            Stmt::Zone { body, .. } => {
                for s in *body {
                    self.infer_stmt(s)?;
                }
                Ok(())
            }

            Stmt::ReadFrom { var, .. } => {
                self.bind_var(*var, InferType::String);
                Ok(())
            }

            // The pipe variable itself is bound to the element type (not a
            // pipe type); `ReceivePipe` below relies on this.
            Stmt::CreatePipe { var, element_type, .. } => {
                let elem = InferType::from_type_name(self.interner.resolve(*element_type));
                self.bind_var(*var, elem);
                Ok(())
            }

            Stmt::ReceivePipe { var, pipe } => {
                let elem_ty = self.infer_expr(pipe)?;
                self.bind_var(*var, elem_ty);
                Ok(())
            }

            // A non-blocking receive may produce nothing, hence Option.
            Stmt::TryReceivePipe { var, pipe } => {
                let elem_ty = self.infer_expr(pipe)?;
                self.bind_var(*var, InferType::Option(Box::new(elem_ty)));
                Ok(())
            }

            Stmt::Pop { into: Some(var), collection } => {
                let coll_ty = self.infer_expr(collection)?;
                let elem_ty = match self.table.zonk(&coll_ty) {
                    InferType::Seq(inner) | InferType::Set(inner) => *inner,
                    _ => InferType::Unknown,
                };
                self.bind_var(*var, elem_ty);
                Ok(())
            }

            Stmt::AwaitMessage { into, .. } => {
                self.bind_var(*into, InferType::Unknown);
                Ok(())
            }

            Stmt::LaunchTaskWithHandle { handle, .. } => {
                self.bind_var(*handle, InferType::Unknown);
                Ok(())
            }

            Stmt::Concurrent { tasks } | Stmt::Parallel { tasks } => {
                for s in *tasks {
                    self.infer_stmt(s)?;
                }
                Ok(())
            }

            Stmt::Select { branches } => {
                for branch in branches {
                    match branch {
                        // The received value is scoped to its branch body.
                        crate::ast::stmt::SelectBranch::Receive { var, pipe, body } => {
                            let elem_ty = self.infer_expr(pipe)?;
                            self.push_scope();
                            self.bind_var(*var, elem_ty);
                            for s in *body {
                                self.infer_stmt(s)?;
                            }
                            self.pop_scope();
                        }
                        crate::ast::stmt::SelectBranch::Timeout { body, .. } => {
                            for s in *body {
                                self.infer_stmt(s)?;
                            }
                        }
                    }
                }
                Ok(())
            }

            _ => Ok(()),
        }
    }

    /// Binds an `Inspect` arm's pattern variables, then infers its body.
    ///
    /// When the arm names a variant found in the registry, each binding gets
    /// the declared type of the matching field (Unknown if there is no such
    /// field). Otherwise every binding is Unknown. The caller is responsible
    /// for pushing and popping the arm's scope.
    ///
    /// # Errors
    /// Propagates errors from the arm's body.
    fn infer_inspect_arm(
        &mut self,
        arm: &crate::ast::stmt::MatchArm,
    ) -> Result<(), TypeError> {
        if let Some(variant_sym) = arm.variant {
            if let Some((_, variant_def)) = self.registry.find_variant(variant_sym) {
                // Copy out (name, type) pairs of the variant's fields.
                let fields: Vec<_> = variant_def
                    .fields
                    .iter()
                    .map(|f| (f.name, f.ty.clone()))
                    .collect();

                for (field_sym, binding_sym) in &arm.bindings {
                    let ty = fields
                        .iter()
                        .find(|(name, _)| *name == *field_sym)
                        .map(|(_, ty)| {
                            InferType::from_field_type(ty, self.interner, &HashMap::new())
                        })
                        .unwrap_or(InferType::Unknown);
                    self.bind_var(*binding_sym, ty);
                }
            } else {
                // Variant not in the registry: bindings are untyped.
                for (_, binding_sym) in &arm.bindings {
                    self.bind_var(*binding_sym, InferType::Unknown);
                }
            }
        } else {
            // No variant named (e.g. a catch-all arm): bindings are untyped.
            for (_, binding_sym) in &arm.bindings {
                self.bind_var(*binding_sym, InferType::Unknown);
            }
        }

        for s in arm.body {
            self.infer_stmt(s)?;
        }
        Ok(())
    }
}
879
880pub fn check_program(
890 stmts: &[Stmt],
891 interner: &Interner,
892 registry: &TypeRegistry,
893) -> Result<TypeEnv, TypeError> {
894 let mut env = CheckEnv::new(registry, interner);
895
896 env.preregister_functions(stmts);
898
899 for stmt in stmts {
901 env.infer_stmt(stmt)?;
902 }
903
904 Ok(env.into_type_env())
905}
906
907#[cfg(test)]
912mod tests {
913 use super::*;
914 use crate::ast::stmt::{Expr, Literal, Stmt, TypeExpr};
915 use crate::intern::Interner;
916
917 fn mk_interner() -> Interner {
922 Interner::new()
923 }
924
925 fn run(stmts: &[Stmt], interner: &Interner) -> TypeEnv {
926 check_program(stmts, interner, &TypeRegistry::new()).expect("check_program failed")
927 }
928
929 #[test]
934 fn let_literal_int() {
935 let mut interner = mk_interner();
936 let x = interner.intern("x");
937 let val = Expr::Literal(Literal::Number(42));
938 let stmts = [Stmt::Let { var: x, ty: None, value: &val, mutable: false }];
939 let env = run(&stmts, &interner);
940 assert_eq!(env.lookup(x), &LogosType::Int);
941 }
942
943 #[test]
944 fn let_literal_float() {
945 let mut interner = mk_interner();
946 let x = interner.intern("x");
947 let val = Expr::Literal(Literal::Float(3.14));
948 let stmts = [Stmt::Let { var: x, ty: None, value: &val, mutable: false }];
949 let env = run(&stmts, &interner);
950 assert_eq!(env.lookup(x), &LogosType::Float);
951 }
952
953 #[test]
954 fn let_literal_string() {
955 let mut interner = mk_interner();
956 let s = interner.intern("s");
957 let hello = interner.intern("hello");
958 let val = Expr::Literal(Literal::Text(hello));
959 let stmts = [Stmt::Let { var: s, ty: None, value: &val, mutable: false }];
960 let env = run(&stmts, &interner);
961 assert_eq!(env.lookup(s), &LogosType::String);
962 }
963
964 #[test]
969 fn let_with_annotation_uses_annotation() {
970 let mut interner = mk_interner();
971 let x = interner.intern("x");
972 let float_sym = interner.intern("Real");
973 let val = Expr::Literal(Literal::Number(5)); let ty_ann = TypeExpr::Primitive(float_sym);
975 let stmts = [Stmt::Let { var: x, ty: Some(&ty_ann), value: &val, mutable: false }];
976 let env = run(&stmts, &interner);
977 assert_eq!(env.lookup(x), &LogosType::Float);
979 }
980
981 #[test]
982 fn let_type_mismatch_fails() {
983 let mut interner = mk_interner();
984 let x = interner.intern("x");
985 let int_sym = interner.intern("Int");
986 let val = Expr::Literal(Literal::Text(Symbol::EMPTY));
987 let ty_ann = TypeExpr::Primitive(int_sym);
988 let stmts = [Stmt::Let { var: x, ty: Some(&ty_ann), value: &val, mutable: false }];
989 let result = check_program(&stmts, &interner, &TypeRegistry::new());
990 assert!(result.is_err(), "Int and Text should not unify");
991 }
992
993 #[test]
998 fn empty_list_is_seq_unknown() {
999 let mut interner = mk_interner();
1000 let xs = interner.intern("xs");
1001 let val = Expr::List(vec![]);
1002 let stmts = [Stmt::Let { var: xs, ty: None, value: &val, mutable: false }];
1003 let env = run(&stmts, &interner);
1004 assert!(matches!(env.lookup(xs), LogosType::Seq(_)));
1006 }
1007
1008 #[test]
1009 fn non_empty_list_infers_element_type() {
1010 let mut interner = mk_interner();
1011 let xs = interner.intern("xs");
1012 let one = Expr::Literal(Literal::Number(1));
1013 let two = Expr::Literal(Literal::Number(2));
1014 let val = Expr::List(vec![&one, &two]);
1015 let stmts = [Stmt::Let { var: xs, ty: None, value: &val, mutable: false }];
1016 let env = run(&stmts, &interner);
1017 assert_eq!(env.lookup(xs), &LogosType::Seq(Box::new(LogosType::Int)));
1018 }
1019
1020 #[test]
1025 fn option_none_is_option_unknown() {
1026 let mut interner = mk_interner();
1027 let x = interner.intern("x");
1028 let val = Expr::OptionNone;
1029 let stmts = [Stmt::Let { var: x, ty: None, value: &val, mutable: false }];
1030 let env = run(&stmts, &interner);
1031 assert!(matches!(env.lookup(x), LogosType::Option(_)));
1032 }
1033
1034 #[test]
1035 fn option_some_infers_inner_type() {
1036 let mut interner = mk_interner();
1037 let x = interner.intern("x");
1038 let inner = Expr::Literal(Literal::Number(42));
1039 let val = Expr::OptionSome { value: &inner };
1040 let stmts = [Stmt::Let { var: x, ty: None, value: &val, mutable: false }];
1041 let env = run(&stmts, &interner);
1042 assert_eq!(env.lookup(x), &LogosType::Option(Box::new(LogosType::Int)));
1043 }
1044
1045 #[test]
1050 fn function_def_registers_signature() {
1051 let mut interner = mk_interner();
1052 let f = interner.intern("double");
1053 let x_param = interner.intern("x");
1054 let int_sym = interner.intern("Int");
1055 let int_ty = TypeExpr::Primitive(int_sym);
1056 let ret_ty = TypeExpr::Primitive(int_sym);
1057 let lit = Expr::Literal(Literal::Number(0));
1058 let ret_stmt = Stmt::Return { value: Some(&lit) };
1059 let body = [ret_stmt];
1060 let stmts = [Stmt::FunctionDef {
1061 name: f,
1062 generics: vec![],
1063 params: vec![(x_param, &int_ty)],
1064 body: &body,
1065 return_type: Some(&ret_ty),
1066 is_native: false,
1067 native_path: None,
1068 is_exported: false,
1069 export_target: None,
1070 opt_flags: HashSet::new(),
1071 }];
1072 let env = run(&stmts, &interner);
1073 let sig = env.lookup_fn(f).expect("function should be registered");
1074 assert_eq!(sig.return_type, LogosType::Int);
1075 assert_eq!(sig.params.len(), 1);
1076 assert_eq!(sig.params[0].1, LogosType::Int);
1077 }
1078
1079 #[test]
1080 fn function_call_returns_registered_type() {
1081 let mut interner = mk_interner();
1082 let f = interner.intern("compute");
1083 let float_sym = interner.intern("Real");
1084 let float_ty = TypeExpr::Primitive(float_sym);
1085 let lit = Expr::Literal(Literal::Float(1.0));
1086 let ret_stmt = Stmt::Return { value: Some(&lit) };
1087 let body = [ret_stmt];
1088 let fn_def = Stmt::FunctionDef {
1089 name: f,
1090 generics: vec![],
1091 params: vec![],
1092 body: &body,
1093 return_type: Some(&float_ty),
1094 is_native: false,
1095 native_path: None,
1096 is_exported: false,
1097 export_target: None,
1098 opt_flags: HashSet::new(),
1099 };
1100 let result_var = interner.intern("result");
1101 let call = Expr::Call { function: f, args: vec![] };
1102 let let_stmt = Stmt::Let { var: result_var, ty: None, value: &call, mutable: false };
1103 let stmts = [fn_def, let_stmt];
1104 let env = run(&stmts, &interner);
1105 assert_eq!(env.lookup(result_var), &LogosType::Float);
1106 }
1107
1108 #[test]
1113 fn readfrom_is_string() {
1114 let mut interner = mk_interner();
1115 let v = interner.intern("input");
1116 let stmts = [Stmt::ReadFrom {
1117 var: v,
1118 source: crate::ast::stmt::ReadSource::Console,
1119 }];
1120 let env = run(&stmts, &interner);
1121 assert_eq!(env.lookup(v), &LogosType::String);
1122 }
1123
1124 #[test]
1129 fn repeat_loop_var_gets_element_type() {
1130 let mut interner = mk_interner();
1131 let items = interner.intern("items");
1132 let elem = interner.intern("elem");
1133 let one = Expr::Literal(Literal::Number(1));
1134 let list = Expr::List(vec![&one]);
1135 let let_items = Stmt::Let { var: items, ty: None, value: &list, mutable: false };
1136 let items_ref = Expr::Identifier(items);
1137 let repeat = Stmt::Repeat {
1138 pattern: Pattern::Identifier(elem),
1139 iterable: &items_ref,
1140 body: &[],
1141 };
1142 let stmts = [let_items, repeat];
1143 let env = run(&stmts, &interner);
1144 assert_eq!(env.lookup(elem), &LogosType::Int);
1145 }
1146
1147 #[test]
1152 fn field_access_resolves_with_registry() {
1153 use crate::analysis::{FieldDef, FieldType, TypeDef};
1154
1155 let mut interner = mk_interner();
1156 let point_sym = interner.intern("Point");
1157 let x_field_sym = interner.intern("x");
1158 let int_sym = interner.intern("Int");
1159 let p_var = interner.intern("p");
1160 let result_var = interner.intern("px");
1161
1162 let mut registry = TypeRegistry::new();
1164 registry.register(
1165 point_sym,
1166 TypeDef::Struct {
1167 fields: vec![FieldDef {
1168 name: x_field_sym,
1169 ty: FieldType::Primitive(int_sym),
1170 is_public: true,
1171 }],
1172 generics: vec![],
1173 is_portable: false,
1174 is_shared: false,
1175 },
1176 );
1177
1178 let new_point = Expr::New { type_name: point_sym, type_args: vec![], init_fields: vec![] };
1180 let let_p = Stmt::Let { var: p_var, ty: None, value: &new_point, mutable: false };
1181
1182 let p_ref = Expr::Identifier(p_var);
1184 let field_access = Expr::FieldAccess { object: &p_ref, field: x_field_sym };
1185 let let_px = Stmt::Let { var: result_var, ty: None, value: &field_access, mutable: false };
1186
1187 let stmts = [let_p, let_px];
1188 let env = check_program(&stmts, &interner, ®istry).expect("check_program failed");
1189 assert_eq!(env.lookup(result_var), &LogosType::Int);
1190 }
1191
1192 #[test]
1197 fn forward_reference_function_call() {
1198 let mut interner = mk_interner();
1199 let f = interner.intern("later_fn");
1200 let result_var = interner.intern("r");
1201 let bool_sym = interner.intern("Bool");
1202 let bool_ty = TypeExpr::Primitive(bool_sym);
1203
1204 let call = Expr::Call { function: f, args: vec![] };
1206 let let_r = Stmt::Let { var: result_var, ty: None, value: &call, mutable: false };
1207
1208 let lit = Expr::Literal(Literal::Boolean(true));
1210 let ret_stmt = Stmt::Return { value: Some(&lit) };
1211 let body = [ret_stmt];
1212 let fn_def = Stmt::FunctionDef {
1213 name: f,
1214 generics: vec![],
1215 params: vec![],
1216 body: &body,
1217 return_type: Some(&bool_ty),
1218 is_native: false,
1219 native_path: None,
1220 is_exported: false,
1221 export_target: None,
1222 opt_flags: HashSet::new(),
1223 };
1224
1225 let stmts = [let_r, fn_def];
1227 let env = run(&stmts, &interner);
1228 assert_eq!(env.lookup(result_var), &LogosType::Bool);
1229 }
1230
1231 #[test]
1236 fn return_type_mismatch_fails() {
1237 let mut interner = mk_interner();
1238 let f = interner.intern("f");
1239 let int_sym = interner.intern("Int");
1240 let int_ty = TypeExpr::Primitive(int_sym);
1241 let lit = Expr::Literal(Literal::Text(Symbol::EMPTY));
1243 let ret_stmt = Stmt::Return { value: Some(&lit) };
1244 let body = [ret_stmt];
1245 let stmts = [Stmt::FunctionDef {
1246 name: f,
1247 generics: vec![],
1248 params: vec![],
1249 body: &body,
1250 return_type: Some(&int_ty),
1251 is_native: false,
1252 native_path: None,
1253 is_exported: false,
1254 export_target: None,
1255 opt_flags: HashSet::new(),
1256 }];
1257 let result = check_program(&stmts, &interner, &TypeRegistry::new());
1258 assert!(result.is_err(), "returning Text from Int function should fail");
1259 }
1260
1261 #[test]
1266 fn new_user_defined_is_user_defined_type() {
1267 let mut interner = mk_interner();
1268 let point = interner.intern("Point");
1269 let p = interner.intern("p");
1270 let new_point = Expr::New { type_name: point, type_args: vec![], init_fields: vec![] };
1271 let stmts = [Stmt::Let { var: p, ty: None, value: &new_point, mutable: false }];
1272 let env = run(&stmts, &interner);
1273 assert_eq!(env.lookup(p), &LogosType::UserDefined(point));
1274 }
1275
1276 #[test]
1281 fn string_vars_in_legacy_api() {
1282 let mut interner = mk_interner();
1283 let s = interner.intern("name");
1284 let hello = interner.intern("hello");
1285 let val = Expr::Literal(Literal::Text(hello));
1286 let stmts = [Stmt::Let { var: s, ty: None, value: &val, mutable: false }];
1287 let env = run(&stmts, &interner);
1288 assert!(env.to_legacy_string_vars().contains(&s));
1289 }
1290
1291 #[test]
1292 fn unknown_vars_filtered_in_legacy_api() {
1293 let mut interner = mk_interner();
1294 let x = interner.intern("x");
1295 let val = Expr::OptionNone; let stmts = [Stmt::Let { var: x, ty: None, value: &val, mutable: false }];
1297 let env = run(&stmts, &interner);
1298 let legacy = env.to_legacy_variable_types();
1300 assert!(!legacy.is_empty() || legacy.is_empty()); }
1303
1304 #[test]
1309 fn generic_identity_infers_int_return() {
1310 let mut interner = mk_interner();
1314 let f = interner.intern("identity");
1315 let x_param = interner.intern("x");
1316 let t_sym = interner.intern("T");
1317 let t_ty = TypeExpr::Primitive(t_sym);
1318 let x_ref = Expr::Identifier(x_param);
1319 let ret_stmt = Stmt::Return { value: Some(&x_ref) };
1320 let body = [ret_stmt];
1321 let fn_def = Stmt::FunctionDef {
1322 name: f,
1323 generics: vec![t_sym],
1324 params: vec![(x_param, &t_ty)],
1325 body: &body,
1326 return_type: Some(&t_ty),
1327 is_native: false,
1328 native_path: None,
1329 is_exported: false,
1330 export_target: None,
1331 opt_flags: HashSet::new(),
1332 };
1333 let r = interner.intern("r");
1334 let lit = Expr::Literal(Literal::Number(42));
1335 let call = Expr::Call { function: f, args: vec![&lit] };
1336 let let_r = Stmt::Let { var: r, ty: None, value: &call, mutable: false };
1337 let stmts = [fn_def, let_r];
1338 let env = run(&stmts, &interner);
1339 assert_eq!(env.lookup(r), &LogosType::Int,
1340 "identity(42) should return Int, got {:?}", env.lookup(r));
1341 }
1342
1343 #[test]
1344 fn generic_identity_infers_bool_return() {
1345 let mut interner = mk_interner();
1347 let f = interner.intern("identity");
1348 let x_param = interner.intern("x");
1349 let t_sym = interner.intern("T");
1350 let t_ty = TypeExpr::Primitive(t_sym);
1351 let x_ref = Expr::Identifier(x_param);
1352 let ret_stmt = Stmt::Return { value: Some(&x_ref) };
1353 let body = [ret_stmt];
1354 let fn_def = Stmt::FunctionDef {
1355 name: f,
1356 generics: vec![t_sym],
1357 params: vec![(x_param, &t_ty)],
1358 body: &body,
1359 return_type: Some(&t_ty),
1360 is_native: false,
1361 native_path: None,
1362 is_exported: false,
1363 export_target: None,
1364 opt_flags: HashSet::new(),
1365 };
1366 let r = interner.intern("r");
1367 let lit = Expr::Literal(Literal::Boolean(true));
1368 let call = Expr::Call { function: f, args: vec![&lit] };
1369 let let_r = Stmt::Let { var: r, ty: None, value: &call, mutable: false };
1370 let stmts = [fn_def, let_r];
1371 let env = run(&stmts, &interner);
1372 assert_eq!(env.lookup(r), &LogosType::Bool,
1373 "identity(true) should return Bool, got {:?}", env.lookup(r));
1374 }
1375
1376 #[test]
1377 fn generic_two_type_params_first() {
1378 let mut interner = mk_interner();
1382 let f = interner.intern("first");
1383 let a_param = interner.intern("a");
1384 let b_param = interner.intern("b");
1385 let a_sym = interner.intern("A");
1386 let b_sym = interner.intern("B");
1387 let a_ty = TypeExpr::Primitive(a_sym);
1388 let b_ty = TypeExpr::Primitive(b_sym);
1389 let a_ref = Expr::Identifier(a_param);
1390 let ret_stmt = Stmt::Return { value: Some(&a_ref) };
1391 let body = [ret_stmt];
1392 let fn_def = Stmt::FunctionDef {
1393 name: f,
1394 generics: vec![a_sym, b_sym],
1395 params: vec![(a_param, &a_ty), (b_param, &b_ty)],
1396 body: &body,
1397 return_type: Some(&a_ty),
1398 is_native: false,
1399 native_path: None,
1400 is_exported: false,
1401 export_target: None,
1402 opt_flags: HashSet::new(),
1403 };
1404 let r = interner.intern("r");
1405 let lit_int = Expr::Literal(Literal::Number(42));
1406 let lit_bool = Expr::Literal(Literal::Boolean(true));
1407 let call = Expr::Call { function: f, args: vec![&lit_int, &lit_bool] };
1408 let let_r = Stmt::Let { var: r, ty: None, value: &call, mutable: false };
1409 let stmts = [fn_def, let_r];
1410 let env = run(&stmts, &interner);
1411 assert_eq!(env.lookup(r), &LogosType::Int,
1412 "first(42, true) should return Int (first param type), got {:?}", env.lookup(r));
1413 }
1414
1415 #[test]
1416 fn generic_calls_are_independent() {
1417 let mut interner = mk_interner();
1420 let f = interner.intern("identity");
1421 let x_param = interner.intern("x");
1422 let t_sym = interner.intern("T");
1423 let t_ty = TypeExpr::Primitive(t_sym);
1424 let x_ref = Expr::Identifier(x_param);
1425 let ret_stmt = Stmt::Return { value: Some(&x_ref) };
1426 let body = [ret_stmt];
1427 let fn_def = Stmt::FunctionDef {
1428 name: f,
1429 generics: vec![t_sym],
1430 params: vec![(x_param, &t_ty)],
1431 body: &body,
1432 return_type: Some(&t_ty),
1433 is_native: false,
1434 native_path: None,
1435 is_exported: false,
1436 export_target: None,
1437 opt_flags: HashSet::new(),
1438 };
1439 let r1 = interner.intern("r1");
1440 let r2 = interner.intern("r2");
1441 let lit_int = Expr::Literal(Literal::Number(42));
1442 let lit_bool = Expr::Literal(Literal::Boolean(true));
1443 let call1 = Expr::Call { function: f, args: vec![&lit_int] };
1444 let call2 = Expr::Call { function: f, args: vec![&lit_bool] };
1445 let let_r1 = Stmt::Let { var: r1, ty: None, value: &call1, mutable: false };
1446 let let_r2 = Stmt::Let { var: r2, ty: None, value: &call2, mutable: false };
1447 let stmts = [fn_def, let_r1, let_r2];
1448 let env = run(&stmts, &interner);
1449 assert_eq!(env.lookup(r1), &LogosType::Int,
1450 "identity(42) should be Int, got {:?}", env.lookup(r1));
1451 assert_eq!(env.lookup(r2), &LogosType::Bool,
1452 "identity(true) should be Bool, got {:?}", env.lookup(r2));
1453 }
1454
1455 #[test]
1456 fn monomorphic_functions_unaffected_by_generics() {
1457 let mut interner = mk_interner();
1459 let f = interner.intern("double");
1460 let x_param = interner.intern("x");
1461 let int_sym = interner.intern("Int");
1462 let int_ty = TypeExpr::Primitive(int_sym);
1463 let x_ref = Expr::Identifier(x_param);
1464 let lit2 = Expr::Literal(Literal::Number(2));
1465 let mul = Expr::BinaryOp {
1466 op: BinaryOpKind::Multiply,
1467 left: &x_ref,
1468 right: &lit2,
1469 };
1470 let ret_stmt = Stmt::Return { value: Some(&mul) };
1471 let body = [ret_stmt];
1472 let fn_def = Stmt::FunctionDef {
1473 name: f,
1474 generics: vec![],
1475 params: vec![(x_param, &int_ty)],
1476 body: &body,
1477 return_type: Some(&int_ty),
1478 is_native: false,
1479 native_path: None,
1480 is_exported: false,
1481 export_target: None,
1482 opt_flags: HashSet::new(),
1483 };
1484 let r = interner.intern("r");
1485 let lit5 = Expr::Literal(Literal::Number(5));
1486 let call = Expr::Call { function: f, args: vec![&lit5] };
1487 let let_r = Stmt::Let { var: r, ty: None, value: &call, mutable: false };
1488 let stmts = [fn_def, let_r];
1489 let env = run(&stmts, &interner);
1490 assert_eq!(env.lookup(r), &LogosType::Int,
1491 "double(5) should return Int, got {:?}", env.lookup(r));
1492 }
1493
1494 #[test]
1495 fn generic_forward_reference_resolves() {
1496 let mut interner = mk_interner();
1501 let f = interner.intern("identity");
1502 let x_param = interner.intern("x");
1503 let t_sym = interner.intern("T");
1504 let t_ty = TypeExpr::Primitive(t_sym);
1505 let x_ref = Expr::Identifier(x_param);
1506 let ret_stmt = Stmt::Return { value: Some(&x_ref) };
1507 let body = [ret_stmt];
1508 let fn_def = Stmt::FunctionDef {
1509 name: f,
1510 generics: vec![t_sym],
1511 params: vec![(x_param, &t_ty)],
1512 body: &body,
1513 return_type: Some(&t_ty),
1514 is_native: false,
1515 native_path: None,
1516 is_exported: false,
1517 export_target: None,
1518 opt_flags: HashSet::new(),
1519 };
1520 let r = interner.intern("r");
1521 let lit = Expr::Literal(Literal::Number(99));
1522 let call = Expr::Call { function: f, args: vec![&lit] };
1523 let let_r = Stmt::Let { var: r, ty: None, value: &call, mutable: false };
1524 let stmts = [let_r, fn_def];
1526 let env = run(&stmts, &interner);
1527 assert_eq!(env.lookup(r), &LogosType::Int,
1528 "forward-ref identity(99) should be Int, got {:?}", env.lookup(r));
1529 }
1530}