expr_solver/
sema.rs

1use crate::ast::*;
2use crate::span::{Span, SpanError};
3use crate::symbol::{SymTable, Symbol};
4use thiserror::Error;
5
6/// Expression parsing and evaluation errors.
7#[derive(Error, Debug, Clone)]
8pub enum SemaError {
9    #[error("Undefined symbol '{name}'")]
10    UndefinedSymbol { name: String, span: Span },
11    #[error("Symbol '{name}' is not a constant")]
12    SymbolIsNotAConstant { name: String, span: Span },
13    #[error("Symbol '{name}' is not a function")]
14    SymbolIsNotAFunction { name: String, span: Span },
15    #[error("Function '{name}' expects exactly {expected} arguments but got {got}")]
16    ArgumentCountMismatch {
17        name: String,
18        expected: usize,
19        got: usize,
20        span: Span,
21    },
22    #[error("Function '{name}' expects at least {min} arguments but got {got}")]
23    InsufficientArguments {
24        name: String,
25        min: usize,
26        got: usize,
27        span: Span,
28    },
29}
30
31impl SpanError for SemaError {
32    fn span(&self) -> Span {
33        match self {
34            SemaError::UndefinedSymbol { span, .. } => *span,
35            SemaError::SymbolIsNotAConstant { span, .. } => *span,
36            SemaError::SymbolIsNotAFunction { span, .. } => *span,
37            SemaError::ArgumentCountMismatch { span, .. } => *span,
38            SemaError::InsufficientArguments { span, .. } => *span,
39        }
40    }
41}
42
43/// Semantic analyzer: resolves identifiers/calls to symbols and checks arity & types.
44#[derive(Debug)]
45pub struct Sema<'sym> {
46    table: &'sym SymTable,
47}
48
49impl<'src, 'sym> Sema<'sym> {
50    pub fn new(table: &'sym SymTable) -> Self {
51        Self { table }
52    }
53
54    pub fn visit(&mut self, ast: &mut Expr<'src, 'sym>) -> Result<(), SemaError> {
55        match &mut ast.kind {
56            ExprKind::Literal(_) => Ok(()),
57            ExprKind::Ident { name, sym } => self.visit_ident(name, sym, ast.span),
58            ExprKind::Unary { op: _, expr } => self.visit_unary(expr),
59            ExprKind::Binary { op: _, left, right } => self.visit_binary(left, right),
60            ExprKind::Call { name, args, sym } => self.visit_call(name, args, sym, ast.span),
61        }
62    }
63
64    fn visit_ident(
65        &mut self,
66        name: &str,
67        sym: &mut Option<&'sym Symbol>,
68        span: Span,
69    ) -> Result<(), SemaError> {
70        let s = self.get_symbol(name, span)?;
71
72        let Symbol::Const { .. } = s else {
73            return Err(SemaError::SymbolIsNotAConstant {
74                name: name.to_string(),
75                span,
76            });
77        };
78
79        *sym = Some(s);
80        Ok(())
81    }
82
83    fn visit_unary(&mut self, expr: &mut Expr<'src, 'sym>) -> Result<(), SemaError> {
84        self.visit(expr)
85    }
86
87    fn visit_binary(
88        &mut self,
89        left: &mut Expr<'src, 'sym>,
90        right: &mut Expr<'src, 'sym>,
91    ) -> Result<(), SemaError> {
92        self.visit(left)?;
93        self.visit(right)
94    }
95
96    fn visit_call(
97        &mut self,
98        name: &str,
99        args: &mut Vec<Expr<'src, 'sym>>,
100        sym: &mut Option<&'sym Symbol>,
101        span: Span,
102    ) -> Result<(), SemaError> {
103        // span here will include a whole call expression,
104        // but is guaranteed to start with the symbol
105        let sym_span = Span::new(span.start, span.start + name.len());
106        let s = self.get_symbol(name, sym_span)?;
107
108        let Symbol::Func {
109            args: min_args,
110            variadic,
111            ..
112        } = s
113        else {
114            return Err(SemaError::SymbolIsNotAFunction {
115                name: name.to_string(),
116                span: sym_span,
117            });
118        };
119
120        self.validate_arity(name, args.len(), *min_args, *variadic, span)?;
121        self.analyse_arguments(args)?;
122
123        *sym = Some(s);
124        Ok(())
125    }
126
127    fn validate_arity(
128        &self,
129        name: &str,
130        args: usize,
131        min_args: usize,
132        variadic: bool,
133        span: Span,
134    ) -> Result<(), SemaError> {
135        if args == min_args || variadic && args > min_args {
136            return Ok(());
137        }
138        if variadic {
139            Err(SemaError::InsufficientArguments {
140                name: name.to_string(),
141                min: min_args,
142                got: args,
143                span,
144            })
145        } else {
146            Err(SemaError::ArgumentCountMismatch {
147                name: name.to_string(),
148                expected: min_args,
149                got: args,
150                span,
151            })
152        }
153    }
154
155    fn analyse_arguments(&mut self, args: &mut [Expr<'src, 'sym>]) -> Result<(), SemaError> {
156        args.iter_mut().try_for_each(|a| self.visit(a))
157    }
158
159    fn get_symbol(&self, name: &str, span: Span) -> Result<&'sym Symbol, SemaError> {
160        self.table
161            .get(name)
162            .ok_or_else(|| SemaError::UndefinedSymbol {
163                name: name.to_string(),
164                span,
165            })
166    }
167}