Skip to main content

fastc/typecheck/
mod.rs

1//! Type checking pass
2//!
3//! This pass:
4//! 1. Infers types for all expressions
5//! 2. Checks that operators are applied to correct types
6//! 3. Checks that function calls have correct arguments
7//! 4. Tracks unsafe context and enforces safety rules
8
9mod context;
10mod safety;
11
12pub use context::*;
13pub use safety::*;
14
15use crate::ast::{
16    BinOp, Block, ConstExpr, EnumDecl, Expr, ExternItem, File, FnDecl, Item, PrimitiveType, Repr,
17    Stmt, StructDecl, TypeExpr, UnaryOp,
18};
19use crate::diag::CompileError;
20use crate::lexer::Span;
21use crate::resolve::{Symbol, SymbolKind, SymbolTable};
22use std::collections::{HashMap, HashSet};
23
24/// Type checker
25pub struct TypeChecker<'a> {
26    source: &'a str,
27    symbols: SymbolTable,
28    safety: SafetyContext,
29    current_fn_return_type: Option<TypeExpr>,
30    errors: Vec<CompileError>,
31    enum_decls: HashMap<String, EnumDecl>,
32    struct_decls: HashMap<String, StructDecl>,
33}
34
35impl<'a> TypeChecker<'a> {
36    pub fn new(source: &'a str, symbols: SymbolTable) -> Self {
37        Self {
38            source,
39            symbols,
40            safety: SafetyContext::new(),
41            current_fn_return_type: None,
42            errors: Vec::new(),
43            enum_decls: HashMap::new(),
44            struct_decls: HashMap::new(),
45        }
46    }
47
48    pub fn check(&mut self, file: &File) -> Result<(), CompileError> {
49        // First pass: collect type declarations for validation
50        for item in &file.items {
51            match item {
52                Item::Enum(enum_decl) => {
53                    self.enum_decls
54                        .insert(enum_decl.name.clone(), enum_decl.clone());
55                }
56                Item::Struct(struct_decl) => {
57                    self.struct_decls
58                        .insert(struct_decl.name.clone(), struct_decl.clone());
59                }
60                _ => {}
61            }
62        }
63
64        // Second pass: type check items
65        for item in &file.items {
66            self.check_item(item);
67        }
68
69        // Return all errors collected during type checking
70        if !self.errors.is_empty() {
71            Err(CompileError::multiple(std::mem::take(&mut self.errors)))
72        } else {
73            Ok(())
74        }
75    }
76
77    fn check_item(&mut self, item: &Item) {
78        match item {
79            Item::Fn(fn_decl) => self.check_fn(fn_decl),
80            Item::Struct(_) => {} // Struct fields were checked during resolution
81            Item::Enum(_) => {}
82            Item::Const(_) => {} // Const type was declared
83            Item::Opaque(_) => {}
84            Item::Extern(extern_block) => {
85                // Validate FFI types in extern signatures
86                for extern_item in &extern_block.items {
87                    if let ExternItem::Fn(proto) = extern_item {
88                        // Check return type
89                        self.validate_ffi_type(&proto.return_type, &proto.span);
90                        // Check parameters
91                        for param in &proto.params {
92                            self.validate_ffi_type(&param.ty, &proto.span);
93                        }
94                    }
95                }
96            }
97            Item::Use(_) => {} // Module imports don't need type checking
98            Item::Mod(_) => {} // Module declarations handled separately
99        }
100    }
101
102    fn check_fn(&mut self, fn_decl: &FnDecl) {
103        // Enter function scope
104        self.symbols.enter_scope();
105
106        // Track if this is an unsafe function
107        if fn_decl.is_unsafe {
108            self.safety.enter_unsafe();
109        }
110
111        // Set current return type
112        self.current_fn_return_type = Some(fn_decl.return_type.clone());
113
114        // Define parameters in scope
115        for param in &fn_decl.params {
116            let symbol = Symbol {
117                name: param.name.clone(),
118                kind: SymbolKind::Variable,
119                ty: param.ty.clone(),
120                span: param.span.clone(),
121            };
122            let _ = self.symbols.define(symbol);
123        }
124
125        // Check body
126        self.check_block(&fn_decl.body);
127
128        // Reset state
129        self.current_fn_return_type = None;
130        if fn_decl.is_unsafe {
131            self.safety.exit_unsafe();
132        }
133        self.symbols.exit_scope();
134    }
135
136    fn check_block(&mut self, block: &Block) {
137        self.symbols.enter_scope();
138        for stmt in &block.stmts {
139            self.check_stmt(stmt);
140        }
141        self.symbols.exit_scope();
142    }
143
144    fn check_stmt(&mut self, stmt: &Stmt) {
145        match stmt {
146            Stmt::Let {
147                name,
148                ty,
149                init,
150                span,
151            } => {
152                let init_ty = self.infer_expr(init);
153                if !self.types_compatible(ty, &init_ty) {
154                    self.error_type_mismatch(ty, &init_ty, span);
155                }
156
157                // Define variable
158                let symbol = Symbol {
159                    name: name.clone(),
160                    kind: SymbolKind::Variable,
161                    ty: ty.clone(),
162                    span: span.clone(),
163                };
164                let _ = self.symbols.define(symbol);
165            }
166
167            Stmt::Assign { lhs, rhs, span } => {
168                let lhs_ty = self.infer_expr(lhs);
169                let rhs_ty = self.infer_expr(rhs);
170
171                if !self.types_compatible(&lhs_ty, &rhs_ty) {
172                    self.error_type_mismatch(&lhs_ty, &rhs_ty, span);
173                }
174
175                // Check that lhs is assignable
176                self.check_assignable(lhs, span);
177            }
178
179            Stmt::If {
180                cond,
181                then_block,
182                else_block,
183                span,
184            } => {
185                let cond_ty = self.infer_expr(cond);
186                if !self.is_bool(&cond_ty) {
187                    self.error(
188                        format!("condition must be bool, got {:?}", cond_ty),
189                        span.clone(),
190                    );
191                }
192
193                self.check_block(then_block);
194
195                if let Some(else_branch) = else_block {
196                    match else_branch {
197                        crate::ast::ElseBranch::ElseIf(if_stmt) => self.check_stmt(if_stmt),
198                        crate::ast::ElseBranch::Else(block) => self.check_block(block),
199                    }
200                }
201            }
202
203            Stmt::IfLet {
204                name,
205                expr,
206                then_block,
207                else_block,
208                span,
209            } => {
210                let expr_ty = self.infer_expr(expr);
211
212                // The expression should be opt(T) or res(T, E)
213                let inner_ty = match &expr_ty {
214                    TypeExpr::Opt(inner) => (**inner).clone(),
215                    TypeExpr::Res(ok, _) => (**ok).clone(),
216                    _ => {
217                        self.error(
218                            format!("if-let requires opt or res type, got {:?}", expr_ty),
219                            span.clone(),
220                        );
221                        TypeExpr::Void
222                    }
223                };
224
225                // Check then block with bound variable
226                self.symbols.enter_scope();
227                let symbol = Symbol {
228                    name: name.clone(),
229                    kind: SymbolKind::Variable,
230                    ty: inner_ty,
231                    span: span.clone(),
232                };
233                let _ = self.symbols.define(symbol);
234
235                for stmt in &then_block.stmts {
236                    self.check_stmt(stmt);
237                }
238                self.symbols.exit_scope();
239
240                if let Some(else_blk) = else_block {
241                    self.check_block(else_blk);
242                }
243            }
244
245            Stmt::While { cond, body, span } => {
246                let cond_ty = self.infer_expr(cond);
247                if !self.is_bool(&cond_ty) {
248                    self.error(
249                        format!("condition must be bool, got {:?}", cond_ty),
250                        span.clone(),
251                    );
252                }
253                self.check_block(body);
254            }
255
256            Stmt::For {
257                init,
258                cond,
259                step,
260                body,
261                ..
262            } => {
263                self.symbols.enter_scope();
264
265                if let Some(init) = init {
266                    match init {
267                        crate::ast::ForInit::Let { name, ty, init } => {
268                            let init_ty = self.infer_expr(init);
269                            if !self.types_compatible(ty, &init_ty) {
270                                self.error_type_mismatch(ty, &init_ty, &init.span());
271                            }
272                            let symbol = Symbol {
273                                name: name.clone(),
274                                kind: SymbolKind::Variable,
275                                ty: ty.clone(),
276                                span: 0..0,
277                            };
278                            let _ = self.symbols.define(symbol);
279                        }
280                        crate::ast::ForInit::Assign { lhs, rhs } => {
281                            let lhs_ty = self.infer_expr(lhs);
282                            let rhs_ty = self.infer_expr(rhs);
283                            if !self.types_compatible(&lhs_ty, &rhs_ty) {
284                                self.error_type_mismatch(&lhs_ty, &rhs_ty, &lhs.span());
285                            }
286                        }
287                        crate::ast::ForInit::Call(expr) => {
288                            self.infer_expr(expr);
289                        }
290                    }
291                }
292
293                if let Some(cond) = cond {
294                    let cond_ty = self.infer_expr(cond);
295                    if !self.is_bool(&cond_ty) {
296                        self.error(
297                            format!("for condition must be bool, got {:?}", cond_ty),
298                            cond.span(),
299                        );
300                    }
301                }
302
303                if let Some(step) = step {
304                    match step {
305                        crate::ast::ForStep::Assign { lhs, rhs } => {
306                            let lhs_ty = self.infer_expr(lhs);
307                            let rhs_ty = self.infer_expr(rhs);
308                            if !self.types_compatible(&lhs_ty, &rhs_ty) {
309                                self.error_type_mismatch(&lhs_ty, &rhs_ty, &lhs.span());
310                            }
311                        }
312                        crate::ast::ForStep::Call(expr) => {
313                            self.infer_expr(expr);
314                        }
315                    }
316                }
317
318                for stmt in &body.stmts {
319                    self.check_stmt(stmt);
320                }
321
322                self.symbols.exit_scope();
323            }
324
325            Stmt::Switch {
326                expr,
327                cases,
328                default,
329                span,
330            } => {
331                let expr_ty = self.infer_expr(expr);
332
333                // Switch must be on integer or enum type
334                if !self.is_integer(&expr_ty) && !matches!(expr_ty, TypeExpr::Named(_)) {
335                    self.error(
336                        format!("switch expression must be integer or enum, got {:?}", expr_ty),
337                        expr.span(),
338                    );
339                }
340
341                // Exhaustiveness check for enums
342                if let TypeExpr::Named(enum_name) = &expr_ty {
343                    if let Some(enum_decl) = self.enum_decls.get(enum_name).cloned() {
344                        let expected_variants: HashSet<String> = enum_decl
345                            .variants
346                            .iter()
347                            .map(|v| format!("{}_{}", enum_name, v.name))
348                            .collect();
349
350                        let mut covered_variants = HashSet::new();
351
352                        for case in cases {
353                            // Extract variant name from case value
354                            if let ConstExpr::Ident(name) = &case.value {
355                                covered_variants.insert(name.clone());
356                            }
357                        }
358
359                        let missing: Vec<_> = expected_variants
360                            .difference(&covered_variants)
361                            .cloned()
362                            .collect();
363
364                        if !missing.is_empty() && default.is_none() {
365                            self.error(
366                                format!(
367                                    "non-exhaustive switch on enum '{}': missing variants {:?}",
368                                    enum_name, missing
369                                ),
370                                span.clone(),
371                            );
372                        }
373                    }
374                }
375
376                for case in cases {
377                    for stmt in &case.stmts {
378                        self.check_stmt(stmt);
379                    }
380                }
381
382                // Check default block if present
383                if let Some(default_stmts) = default {
384                    for stmt in default_stmts {
385                        self.check_stmt(stmt);
386                    }
387                }
388            }
389
390            Stmt::Return { value, span } => {
391                let expected = self
392                    .current_fn_return_type
393                    .clone()
394                    .unwrap_or(TypeExpr::Void);
395
396                if let Some(value) = value {
397                    let actual = self.infer_expr(value);
398                    if !self.types_compatible(&expected, &actual) {
399                        self.error_type_mismatch(&expected, &actual, span);
400                    }
401                } else if !matches!(expected, TypeExpr::Void) {
402                    self.error(
403                        format!("expected return value of type {:?}", expected),
404                        span.clone(),
405                    );
406                }
407            }
408
409            Stmt::Break { .. } | Stmt::Continue { .. } => {}
410
411            Stmt::Defer { body, .. } => {
412                self.check_block(body);
413            }
414
415            Stmt::Expr { expr, .. } => {
416                self.infer_expr(expr);
417            }
418
419            Stmt::Discard { expr, .. } => {
420                self.infer_expr(expr);
421            }
422
423            Stmt::Unsafe { body, .. } => {
424                self.safety.enter_unsafe();
425                self.check_block(body);
426                self.safety.exit_unsafe();
427            }
428
429            Stmt::Block(block) => {
430                self.check_block(block);
431            }
432        }
433    }
434
435    fn infer_expr(&mut self, expr: &Expr) -> TypeExpr {
436        match expr {
437            Expr::IntLit { .. } => TypeExpr::Primitive(PrimitiveType::I32), // Default to i32
438            Expr::FloatLit { .. } => TypeExpr::Primitive(PrimitiveType::F64), // Default to f64
439            Expr::BoolLit { .. } => TypeExpr::Primitive(PrimitiveType::Bool),
440            Expr::CStr { .. } => TypeExpr::Raw(Box::new(TypeExpr::Primitive(PrimitiveType::U8))),
441            Expr::Bytes { .. } => TypeExpr::Slice(Box::new(TypeExpr::Primitive(PrimitiveType::U8))),
442
443            Expr::Ident { name, .. } => {
444                if let Some(sym) = self.symbols.lookup(name) {
445                    sym.ty.clone()
446                } else {
447                    TypeExpr::Void // Error already reported during resolution
448                }
449            }
450
451            Expr::Binary { op, lhs, rhs, span } => {
452                let lhs_ty = self.infer_expr(lhs);
453                let rhs_ty = self.infer_expr(rhs);
454
455                // Binary ops require same types
456                if !self.types_compatible(&lhs_ty, &rhs_ty) {
457                    self.error_type_mismatch(&lhs_ty, &rhs_ty, span);
458                }
459
460                match op {
461                    // Comparison operators return bool
462                    BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge => {
463                        TypeExpr::Primitive(PrimitiveType::Bool)
464                    }
465                    // Logical operators require bool and return bool
466                    BinOp::And | BinOp::Or => {
467                        if !self.is_bool(&lhs_ty) {
468                            self.error(
469                                format!("logical operator requires bool, got {:?}", lhs_ty),
470                                span.clone(),
471                            );
472                        }
473                        TypeExpr::Primitive(PrimitiveType::Bool)
474                    }
475                    // Arithmetic and bitwise operators return the same type
476                    _ => lhs_ty,
477                }
478            }
479
480            Expr::Unary { op, operand, span } => {
481                let operand_ty = self.infer_expr(operand);
482
483                match op {
484                    UnaryOp::Neg => {
485                        if !self.is_numeric(&operand_ty) {
486                            self.error(
487                                format!("negation requires numeric type, got {:?}", operand_ty),
488                                span.clone(),
489                            );
490                        }
491                        operand_ty
492                    }
493                    UnaryOp::Not => {
494                        if !self.is_bool(&operand_ty) {
495                            self.error(
496                                format!("logical not requires bool, got {:?}", operand_ty),
497                                span.clone(),
498                            );
499                        }
500                        TypeExpr::Primitive(PrimitiveType::Bool)
501                    }
502                    UnaryOp::BitNot => {
503                        if !self.is_integer(&operand_ty) {
504                            self.error(
505                                format!("bitwise not requires integer, got {:?}", operand_ty),
506                                span.clone(),
507                            );
508                        }
509                        operand_ty
510                    }
511                }
512            }
513
514            Expr::Paren { inner, .. } => self.infer_expr(inner),
515
516            Expr::Call { callee, args, span } => {
517                let callee_ty = self.infer_expr(callee);
518
519                match callee_ty {
520                    TypeExpr::Fn {
521                        is_unsafe,
522                        params,
523                        ret,
524                    } => {
525                        // Check unsafe
526                        if is_unsafe && !self.safety.is_unsafe() {
527                            self.error_with_hint(
528                                "call to unsafe function requires unsafe block".to_string(),
529                                span.clone(),
530                                "wrap the call in an unsafe block: unsafe { ... }",
531                            );
532                        }
533
534                        // Check argument count
535                        if args.len() != params.len() {
536                            self.error(
537                                format!(
538                                    "expected {} arguments, got {}",
539                                    params.len(),
540                                    args.len()
541                                ),
542                                span.clone(),
543                            );
544                        }
545
546                        // Check argument types
547                        for (arg, param_ty) in args.iter().zip(params.iter()) {
548                            let arg_ty = self.infer_expr(arg);
549                            if !self.types_compatible(param_ty, &arg_ty) {
550                                self.error_type_mismatch(param_ty, &arg_ty, &arg.span());
551                            }
552                        }
553
554                        *ret
555                    }
556                    _ => {
557                        self.error(
558                            format!("cannot call non-function type {:?}", callee_ty),
559                            span.clone(),
560                        );
561                        TypeExpr::Void
562                    }
563                }
564            }
565
566            Expr::Field { base, field, span } => {
567                let base_ty = self.infer_expr(base);
568
569                // Look up field in struct
570                if let TypeExpr::Named(struct_name) = &base_ty {
571                    // TODO: Look up struct definition and find field type
572                    // For now, return void as placeholder
573                    let _ = (struct_name, field, span);
574                    TypeExpr::Void
575                } else {
576                    self.error(
577                        format!("field access on non-struct type {:?}", base_ty),
578                        span.clone(),
579                    );
580                    TypeExpr::Void
581                }
582            }
583
584            Expr::Addr { operand, span } => {
585                let operand_ty = self.infer_expr(operand);
586                self.check_addressable(operand, span);
587                TypeExpr::Ref(Box::new(operand_ty))
588            }
589
590            Expr::Deref { operand, span } => {
591                let operand_ty = self.infer_expr(operand);
592
593                match operand_ty {
594                    TypeExpr::Ref(inner) | TypeExpr::Mref(inner) => *inner,
595                    TypeExpr::Raw(inner) | TypeExpr::Rawm(inner) => {
596                        // Deref of raw pointer requires unsafe
597                        if !self.safety.is_unsafe() {
598                            self.error(
599                                "dereference of raw pointer requires unsafe block".to_string(),
600                                span.clone(),
601                            );
602                        }
603                        *inner
604                    }
605                    _ => {
606                        self.error(
607                            format!("cannot dereference non-pointer type {:?}", operand_ty),
608                            span.clone(),
609                        );
610                        TypeExpr::Void
611                    }
612                }
613            }
614
615            Expr::At { base, index, span } => {
616                let base_ty = self.infer_expr(base);
617                let index_ty = self.infer_expr(index);
618
619                // Index must be usize
620                if !matches!(index_ty, TypeExpr::Primitive(PrimitiveType::Usize)) {
621                    // Allow any integer for now
622                    if !self.is_integer(&index_ty) {
623                        self.error(
624                            format!("index must be integer, got {:?}", index_ty),
625                            span.clone(),
626                        );
627                    }
628                }
629
630                match base_ty {
631                    TypeExpr::Slice(inner) => *inner,
632                    TypeExpr::Arr(inner, _) => *inner,
633                    _ => {
634                        self.error(
635                            format!("cannot index non-array type {:?}", base_ty),
636                            span.clone(),
637                        );
638                        TypeExpr::Void
639                    }
640                }
641            }
642
643            Expr::Cast { ty, expr, span } => {
644                let expr_ty = self.infer_expr(expr);
645
646                // Check that cast is valid
647                if !self.can_cast(&expr_ty, ty) {
648                    self.error(
649                        format!("cannot cast {:?} to {:?}", expr_ty, ty),
650                        span.clone(),
651                    );
652                }
653
654                ty.clone()
655            }
656
657            Expr::None { ty, .. } => TypeExpr::Opt(Box::new(ty.clone())),
658
659            Expr::Some { value, .. } => {
660                let inner_ty = self.infer_expr(value);
661                TypeExpr::Opt(Box::new(inner_ty))
662            }
663
664            Expr::Ok { value, .. } => {
665                let inner_ty = self.infer_expr(value);
666                // We don't know the error type, use void as placeholder
667                TypeExpr::Res(Box::new(inner_ty), Box::new(TypeExpr::Void))
668            }
669
670            Expr::Err { value, .. } => {
671                let inner_ty = self.infer_expr(value);
672                // We don't know the ok type, use void as placeholder
673                TypeExpr::Res(Box::new(TypeExpr::Void), Box::new(inner_ty))
674            }
675
676            Expr::StructLit { name, .. } => TypeExpr::Named(name.clone()),
677        }
678    }
679
680    // === Type compatibility checks ===
681
682    fn types_compatible(&self, expected: &TypeExpr, actual: &TypeExpr) -> bool {
683        match (expected, actual) {
684            (TypeExpr::Void, TypeExpr::Void) => true,
685            (TypeExpr::Primitive(a), TypeExpr::Primitive(b)) => a == b,
686            (TypeExpr::Named(a), TypeExpr::Named(b)) => a == b,
687            (TypeExpr::Ref(a), TypeExpr::Ref(b)) => self.types_compatible(a, b),
688            (TypeExpr::Mref(a), TypeExpr::Mref(b)) => self.types_compatible(a, b),
689            (TypeExpr::Raw(a), TypeExpr::Raw(b)) => self.types_compatible(a, b),
690            (TypeExpr::Rawm(a), TypeExpr::Rawm(b)) => self.types_compatible(a, b),
691            (TypeExpr::Own(a), TypeExpr::Own(b)) => self.types_compatible(a, b),
692            (TypeExpr::Slice(a), TypeExpr::Slice(b)) => self.types_compatible(a, b),
693            (TypeExpr::Arr(a, _), TypeExpr::Arr(b, _)) => self.types_compatible(a, b),
694            (TypeExpr::Opt(a), TypeExpr::Opt(b)) => self.types_compatible(a, b),
695            (TypeExpr::Res(a1, a2), TypeExpr::Res(b1, b2)) => {
696                self.types_compatible(a1, b1) && self.types_compatible(a2, b2)
697            }
698            (
699                TypeExpr::Fn {
700                    is_unsafe: u1,
701                    params: p1,
702                    ret: r1,
703                },
704                TypeExpr::Fn {
705                    is_unsafe: u2,
706                    params: p2,
707                    ret: r2,
708                },
709            ) => {
710                u1 == u2
711                    && p1.len() == p2.len()
712                    && p1.iter().zip(p2.iter()).all(|(a, b)| self.types_compatible(a, b))
713                    && self.types_compatible(r1, r2)
714            }
715            _ => false,
716        }
717    }
718
719    fn is_bool(&self, ty: &TypeExpr) -> bool {
720        matches!(ty, TypeExpr::Primitive(PrimitiveType::Bool))
721    }
722
723    fn is_integer(&self, ty: &TypeExpr) -> bool {
724        matches!(
725            ty,
726            TypeExpr::Primitive(
727                PrimitiveType::I8
728                    | PrimitiveType::I16
729                    | PrimitiveType::I32
730                    | PrimitiveType::I64
731                    | PrimitiveType::U8
732                    | PrimitiveType::U16
733                    | PrimitiveType::U32
734                    | PrimitiveType::U64
735                    | PrimitiveType::Usize
736                    | PrimitiveType::Isize
737            )
738        )
739    }
740
741    fn is_numeric(&self, ty: &TypeExpr) -> bool {
742        self.is_integer(ty)
743            || matches!(
744                ty,
745                TypeExpr::Primitive(PrimitiveType::F32 | PrimitiveType::F64)
746            )
747    }
748
749    fn can_cast(&self, from: &TypeExpr, to: &TypeExpr) -> bool {
750        // Allow casts between numeric types
751        if self.is_numeric(from) && self.is_numeric(to) {
752            return true;
753        }
754
755        // Allow casts between pointer types
756        match (from, to) {
757            (TypeExpr::Ref(_), TypeExpr::Raw(_))
758            | (TypeExpr::Mref(_), TypeExpr::Rawm(_))
759            | (TypeExpr::Raw(_), TypeExpr::Raw(_))
760            | (TypeExpr::Rawm(_), TypeExpr::Rawm(_)) => true,
761            _ => false,
762        }
763    }
764
765    fn check_assignable(&mut self, expr: &Expr, span: &Span) {
766        match expr {
767            Expr::Ident { .. } => {}
768            Expr::Deref { .. } => {}
769            Expr::At { .. } => {}
770            Expr::Field { .. } => {}
771            _ => {
772                self.error("expression is not assignable".to_string(), span.clone());
773            }
774        }
775    }
776
777    fn check_addressable(&mut self, expr: &Expr, span: &Span) {
778        match expr {
779            Expr::Ident { .. } => {}
780            Expr::Deref { .. } => {}
781            Expr::At { .. } => {}
782            Expr::Field { .. } => {}
783            _ => {
784                self.error("cannot take address of expression".to_string(), span.clone());
785            }
786        }
787    }
788
789    // === FFI validation ===
790
791    /// Validate that a type is allowed in FFI signatures
792    fn validate_ffi_type(&mut self, ty: &TypeExpr, span: &Span) {
793        match ty {
794            TypeExpr::Opt(_) => {
795                self.error(
796                    "opt(T) is not permitted in extern signatures".to_string(),
797                    span.clone(),
798                );
799            }
800            TypeExpr::Res(_, _) => {
801                self.error(
802                    "res(T, E) is not permitted in extern signatures".to_string(),
803                    span.clone(),
804                );
805            }
806            TypeExpr::Named(name) => {
807                // Check if it's a struct passed by value without @repr(C)
808                if let Some(struct_decl) = self.struct_decls.get(name) {
809                    if struct_decl.repr != Some(Repr::C) {
810                        self.error(
811                            format!(
812                                "struct '{}' passed by value in extern must have @repr(C)",
813                                name
814                            ),
815                            span.clone(),
816                        );
817                    }
818                }
819                // Note: enums are allowed without @repr as they default to i32
820            }
821            // Primitives, pointers, slices, void are OK in FFI
822            _ => {}
823        }
824    }
825
826    // === Error helpers ===
827
828    fn error(&mut self, message: String, span: Span) {
829        self.errors
830            .push(CompileError::type_error(message, span, self.source));
831    }
832
833    fn error_with_hint(&mut self, message: String, span: Span, hint: impl Into<String>) {
834        self.errors.push(CompileError::type_error_with_hint(
835            message,
836            span,
837            self.source,
838            hint,
839        ));
840    }
841
842    fn error_type_mismatch(&mut self, expected: &TypeExpr, actual: &TypeExpr, span: &Span) {
843        self.error(
844            format!("type mismatch: expected {:?}, got {:?}", expected, actual),
845            span.clone(),
846        );
847    }
848}
849
850impl Default for TypeChecker<'_> {
851    fn default() -> Self {
852        Self::new("", SymbolTable::new())
853    }
854}
855
#[cfg(test)]
mod tests {
    use crate::driver::compile;

    /// Compiles `source` (passing "test.fc" as the second argument to
    /// `compile`) and panics unless the result is an error whose `Debug`
    /// rendering contains `expected_substr`.
    fn check_error(source: &str, expected_substr: &str) {
        let err_msg = match compile(source, "test.fc") {
            Err(err) => format!("{:?}", err),
            Ok(_) => panic!("Expected error for: {}", source),
        };
        assert!(
            err_msg.contains(expected_substr),
            "Expected error containing '{}', got: {}",
            expected_substr,
            err_msg
        );
    }

    /// Compiles `source` (passing "test.fc" as the second argument to
    /// `compile`) and panics, printing the error, if the result is not `Ok`.
    fn check_ok(source: &str) {
        let outcome = compile(source, "test.fc");
        let succeeded = outcome.is_ok();
        assert!(
            succeeded,
            "Expected success for: {}\nGot error: {:?}",
            source,
            outcome.err()
        );
    }

    // === Type mismatch tests ===

    #[test]
    fn test_type_mismatch_let() {
        let program = "fn foo() -> void { let x: i32 = true; }";
        check_error(program, "type mismatch");
    }

    #[test]
    fn test_type_mismatch_return() {
        let program = "fn foo() -> i32 { return true; }";
        check_error(program, "type mismatch");
    }

    #[test]
    fn test_type_mismatch_binary() {
        let program = "fn foo() -> i32 { return (1 + true); }";
        check_error(program, "type mismatch");
    }

    #[test]
    fn test_type_mismatch_assignment() {
        let program = "fn foo() -> void { let x: i32 = 1; x = true; }";
        check_error(program, "type mismatch");
    }

    // === Operator type tests ===

    #[test]
    fn test_logical_requires_bool() {
        let program = "fn foo() -> bool { return (1 && 2); }";
        check_error(program, "logical operator requires bool");
    }

    #[test]
    fn test_not_requires_bool() {
        let program = "fn foo() -> bool { return !1; }";
        check_error(program, "logical not requires bool");
    }

    #[test]
    fn test_condition_requires_bool() {
        let program = "fn foo() -> void { if (1) { } }";
        check_error(program, "condition must be bool");
    }

    #[test]
    fn test_while_condition_requires_bool() {
        let program = "fn foo() -> void { while (1) { } }";
        check_error(program, "condition must be bool");
    }

    // === Unsafe context tests ===

    #[test]
    fn test_unsafe_function_call_requires_unsafe() {
        let program =
            "unsafe fn danger() -> i32 { return 1; } fn foo() -> i32 { return danger(); }";
        check_error(program, "call to unsafe function requires unsafe block");
    }

    #[test]
    fn test_unsafe_function_call_in_unsafe_block() {
        let program = "unsafe fn danger() -> i32 { return 1; } fn foo() -> i32 { unsafe { return danger(); } }";
        check_ok(program);
    }

    #[test]
    fn test_unsafe_function_can_call_unsafe() {
        let program =
            "unsafe fn danger() -> i32 { return 1; } unsafe fn foo() -> i32 { return danger(); }";
        check_ok(program);
    }

    // === Function call tests ===

    #[test]
    fn test_wrong_argument_count() {
        let program =
            "fn bar(x: i32, y: i32) -> i32 { return (x + y); } fn foo() -> i32 { return bar(1); }";
        check_error(program, "expected 2 arguments, got 1");
    }

    #[test]
    fn test_wrong_argument_type() {
        let program = "fn bar(x: i32) -> i32 { return x; } fn foo() -> i32 { return bar(true); }";
        check_error(program, "type mismatch");
    }

    #[test]
    fn test_call_non_function() {
        // Calling a non-function is reported as "cannot call non-function type";
        // a return type mismatch may be reported alongside it, so only the
        // first message is matched here.
        let program = "fn foo() -> i32 { let x: i32 = 1; return x(); }";
        check_error(program, "cannot call non-function");
    }

    // === Return type tests ===

    #[test]
    fn test_missing_return_value() {
        let program = "fn foo() -> i32 { return; }";
        check_error(program, "expected return value");
    }

    #[test]
    fn test_void_function_ok() {
        let program = "fn foo() -> void { return; }";
        check_ok(program);
    }

    #[test]
    fn test_void_function_implicit_return() {
        let program = "fn foo() -> void { let x: i32 = 1; }";
        check_ok(program);
    }

    // === Valid programs ===

    #[test]
    fn test_basic_arithmetic() {
        let program = "fn foo() -> i32 { return (1 + 2); }";
        check_ok(program);
    }

    #[test]
    fn test_comparison() {
        let program = "fn foo() -> bool { return (1 < 2); }";
        check_ok(program);
    }

    #[test]
    fn test_logical_ops() {
        let program = "fn foo() -> bool { return (true && false); }";
        check_ok(program);
    }

    #[test]
    fn test_function_call() {
        let program = "fn bar(x: i32) -> i32 { return x; } fn foo() -> i32 { return bar(1); }";
        check_ok(program);
    }

    #[test]
    fn test_nested_calls() {
        let program = "fn a(x: i32) -> i32 { return x; } fn b(x: i32) -> i32 { return a(x); } fn foo() -> i32 { return b(1); }";
        check_ok(program);
    }
}