expr_solver/
sema.rs

1use crate::ast::*;
2use crate::span::{Span, SpanError};
3use crate::symbol::{SymTable, Symbol};
4use thiserror::Error;
5
6/// Expression parsing and evaluation errors.
7#[derive(Error, Debug, Clone)]
8pub enum SemaError {
9    #[error("Undefined symbol '{name}'")]
10    UndefinedSymbol { name: String, span: Span },
11    #[error("Symbol '{name}' is not a constant")]
12    SymbolIsNotAConstant { name: String, span: Span },
13    #[error("Symbol '{name}' is not a function")]
14    SymbolIsNotAFunction { name: String, span: Span },
15    #[error("Function '{name}' expects exactly {expected} arguments but got {got}")]
16    ArgumentCountMismatch {
17        name: String,
18        expected: usize,
19        got: usize,
20        span: Span,
21    },
22    #[error("Function '{name}' expects at least {min} arguments but got {got}")]
23    InsufficientArguments {
24        name: String,
25        min: usize,
26        got: usize,
27        span: Span,
28    },
29}
30
31impl SpanError for SemaError {
32    fn span(&self) -> Span {
33        match self {
34            SemaError::UndefinedSymbol { span, .. } => *span,
35            SemaError::SymbolIsNotAConstant { span, .. } => *span,
36            SemaError::SymbolIsNotAFunction { span, .. } => *span,
37            SemaError::ArgumentCountMismatch { span, .. } => *span,
38            SemaError::InsufficientArguments { span, .. } => *span,
39        }
40    }
41}
42
43/// Semantic analyzer for type checking and symbol resolution.
44///
45/// Validates that identifiers reference valid symbols and that function
46/// calls have the correct number of arguments.
47#[derive(Debug)]
48pub struct Sema<'sym> {
49    table: &'sym SymTable,
50}
51
52impl<'src, 'sym> Sema<'sym> {
53    /// Creates a new semantic analyzer with the given symbol table.
54    pub fn new(table: &'sym SymTable) -> Self {
55        Self { table }
56    }
57
58    /// Analyzes an AST expression, resolving symbols and checking types.
59    pub fn visit(&mut self, ast: &mut Expr<'src, 'sym>) -> Result<(), SemaError> {
60        match &mut ast.kind {
61            ExprKind::Literal(_) => Ok(()),
62            ExprKind::Ident { name, sym } => self.visit_ident(name, sym, ast.span),
63            ExprKind::Unary { op: _, expr } => self.visit_unary(expr),
64            ExprKind::Binary { op: _, left, right } => self.visit_binary(left, right),
65            ExprKind::Call { name, args, sym } => self.visit_call(name, args, sym, ast.span),
66        }
67    }
68
69    fn visit_ident(
70        &mut self,
71        name: &str,
72        sym: &mut Option<&'sym Symbol>,
73        span: Span,
74    ) -> Result<(), SemaError> {
75        let s = self.get_symbol(name, span)?;
76
77        let Symbol::Const { .. } = s else {
78            return Err(SemaError::SymbolIsNotAConstant {
79                name: name.to_string(),
80                span,
81            });
82        };
83
84        *sym = Some(s);
85        Ok(())
86    }
87
88    fn visit_unary(&mut self, expr: &mut Expr<'src, 'sym>) -> Result<(), SemaError> {
89        self.visit(expr)
90    }
91
92    fn visit_binary(
93        &mut self,
94        left: &mut Expr<'src, 'sym>,
95        right: &mut Expr<'src, 'sym>,
96    ) -> Result<(), SemaError> {
97        self.visit(left)?;
98        self.visit(right)
99    }
100
101    fn visit_call(
102        &mut self,
103        name: &str,
104        args: &mut Vec<Expr<'src, 'sym>>,
105        sym: &mut Option<&'sym Symbol>,
106        span: Span,
107    ) -> Result<(), SemaError> {
108        // span here will include a whole call expression,
109        // but is guaranteed to start with the symbol
110        let sym_span = Span::new(span.start, span.start + name.len());
111        let s = self.get_symbol(name, sym_span)?;
112
113        let Symbol::Func {
114            args: min_args,
115            variadic,
116            ..
117        } = s
118        else {
119            return Err(SemaError::SymbolIsNotAFunction {
120                name: name.to_string(),
121                span: sym_span,
122            });
123        };
124
125        self.validate_arity(name, args.len(), *min_args, *variadic, span)?;
126        self.analyse_arguments(args)?;
127
128        *sym = Some(s);
129        Ok(())
130    }
131
132    fn validate_arity(
133        &self,
134        name: &str,
135        args: usize,
136        min_args: usize,
137        variadic: bool,
138        span: Span,
139    ) -> Result<(), SemaError> {
140        if args == min_args || variadic && args > min_args {
141            return Ok(());
142        }
143        if variadic {
144            Err(SemaError::InsufficientArguments {
145                name: name.to_string(),
146                min: min_args,
147                got: args,
148                span,
149            })
150        } else {
151            Err(SemaError::ArgumentCountMismatch {
152                name: name.to_string(),
153                expected: min_args,
154                got: args,
155                span,
156            })
157        }
158    }
159
160    fn analyse_arguments(&mut self, args: &mut [Expr<'src, 'sym>]) -> Result<(), SemaError> {
161        args.iter_mut().try_for_each(|a| self.visit(a))
162    }
163
164    fn get_symbol(&self, name: &str, span: Span) -> Result<&'sym Symbol, SemaError> {
165        self.table
166            .get(name)
167            .ok_or_else(|| SemaError::UndefinedSymbol {
168                name: name.to_string(),
169                span,
170            })
171    }
172}