somni_expr/
visitor.rs

1use somni_parser::{
2    ast::{
3        Body, Expression, Function, If, LeftHandExpression, LiteralValue, Loop,
4        RightHandExpression, Statement, TypeHint, VariableDefinition,
5    },
6    lexer,
7    parser::DefaultTypeSet,
8    Location,
9};
10
11use crate::{
12    value::LoadStore, EvalError, ExprContext, FunctionCallError, Type, TypeSet, TypedValue,
13};
14
15/// A visitor that can process an abstract syntax tree.
16pub struct ExpressionVisitor<'a, C, T = DefaultTypeSet> {
17    /// The context in which the expression is evaluated.
18    pub context: &'a mut C,
19    /// The source code from which the expression was parsed.
20    pub source: &'a str,
21    /// The types of the variables in the context.
22    pub _marker: std::marker::PhantomData<T>,
23}
24
25impl<'a, C, T> ExpressionVisitor<'a, C, T>
26where
27    C: ExprContext<T>,
28    T: TypeSet,
29{
30    fn visit_variable(&mut self, variable: &lexer::Token) -> Result<TypedValue<T>, EvalError> {
31        let name = variable.source(self.source);
32        self.context.try_load_variable(name).ok_or(EvalError {
33            message: format!("Variable {name} was not found").into_boxed_str(),
34            location: variable.location,
35        })
36    }
37
38    /// Visits an expression and evaluates it, returning the result as a `TypedValue`.
39    pub fn visit_expression(
40        &mut self,
41        expression: &Expression<T::Parser>,
42    ) -> Result<TypedValue<T>, EvalError> {
43        let result = match expression {
44            Expression::Expression { expression } => {
45                self.visit_right_hand_expression(expression)?
46            }
47            Expression::Assignment {
48                left_expr,
49                operator: _,
50                right_expr,
51            } => {
52                let rhs = self.visit_right_hand_expression(right_expr)?;
53                let assign_result = match left_expr {
54                    LeftHandExpression::Deref { name, .. } => {
55                        let address =
56                            self.visit_right_hand_expression(&RightHandExpression::Variable {
57                                variable: *name,
58                            })?;
59                        self.context.assign_address(address, &rhs)
60                    }
61                    LeftHandExpression::Name { variable } => {
62                        let name = variable.source(self.source);
63                        self.context.assign_variable(name, &rhs)
64                    }
65                };
66
67                if let Err(error) = assign_result {
68                    return Err(EvalError {
69                        message: error,
70                        location: expression.location(),
71                    });
72                }
73
74                TypedValue::Void
75            }
76        };
77
78        Ok(result)
79    }
80
81    /// Visits an expression and evaluates it, returning the result as a `TypedValue`.
82    pub fn visit_right_hand_expression(
83        &mut self,
84        expression: &RightHandExpression<T::Parser>,
85    ) -> Result<TypedValue<T>, EvalError> {
86        let result = match expression {
87            RightHandExpression::Variable { variable } => self.visit_variable(variable)?,
88            RightHandExpression::Literal { value } => match &value.value {
89                LiteralValue::Integer(value) => TypedValue::<T>::MaybeSignedInt(*value),
90                LiteralValue::Float(value) => TypedValue::<T>::Float(*value),
91                LiteralValue::String(value) => value.store(self.context.type_context()),
92                LiteralValue::Boolean(value) => TypedValue::<T>::Bool(*value),
93            },
94            RightHandExpression::UnaryOperator { name, operand } => {
95                match name.source(self.source) {
96                    "!" => {
97                        let operand = self.visit_right_hand_expression(operand)?;
98
99                        match TypedValue::<T>::not(self.context.type_context(), operand) {
100                            Ok(r) => r,
101                            Err(error) => {
102                                return Err(EvalError {
103                                    message: format!("Failed to evaluate expression: {error}")
104                                        .into_boxed_str(),
105                                    location: expression.location(),
106                                });
107                            }
108                        }
109                    }
110
111                    "-" => {
112                        let value = self.visit_right_hand_expression(operand)?;
113                        let ty = value.type_of();
114                        TypedValue::<T>::negate(self.context.type_context(), value).map_err(
115                            |e| EvalError {
116                                message: format!("Cannot negate {ty}: {e}").into_boxed_str(),
117                                location: operand.location(),
118                            },
119                        )?
120                    }
121
122                    "&" => {
123                        let RightHandExpression::Variable { variable } = operand.as_ref() else {
124                            return Err(EvalError {
125                                message: String::from(
126                                    "Cannot take address of non-variable expression",
127                                )
128                                .into_boxed_str(),
129                                location: operand.location(),
130                            });
131                        };
132
133                        let name = variable.source(self.source);
134                        self.context.address_of(name)
135                    }
136                    "*" => {
137                        let address = self.visit_right_hand_expression(operand)?;
138                        self.context.at_address(address).map_err(|e| EvalError {
139                            message: format!("Failed to load variable from address: {e}")
140                                .into_boxed_str(),
141                            location: operand.location(),
142                        })?
143                    }
144                    _ => {
145                        return Err(EvalError {
146                            message: format!(
147                                "Unknown unary operator: {}",
148                                name.source(self.source)
149                            )
150                            .into_boxed_str(),
151                            location: expression.location(),
152                        });
153                    }
154                }
155            }
156            RightHandExpression::BinaryOperator { name, operands } => {
157                let lhs = self.visit_right_hand_expression(&operands[0])?;
158
159                let short_circuiting = ["&&", "||"];
160                let operator = name.source(self.source);
161
162                // Special cases
163                if short_circuiting.contains(&operator) {
164                    return match operator {
165                        "&&" if lhs == TypedValue::<T>::Bool(false) => Ok(TypedValue::Bool(false)),
166                        "||" if lhs == TypedValue::<T>::Bool(true) => Ok(TypedValue::Bool(true)),
167                        _ => self.visit_right_hand_expression(&operands[1]),
168                    };
169                }
170
171                // "Normal" binary operators
172                let rhs = self.visit_right_hand_expression(&operands[1])?;
173                let type_context = self.context.type_context();
174                let result = match operator {
175                    "+" => TypedValue::<T>::add(type_context, lhs, rhs),
176                    "-" => TypedValue::<T>::subtract(type_context, lhs, rhs),
177                    "*" => TypedValue::<T>::multiply(type_context, lhs, rhs),
178                    "/" => TypedValue::<T>::divide(type_context, lhs, rhs),
179                    "%" => TypedValue::<T>::modulo(type_context, lhs, rhs),
180                    "<" => TypedValue::<T>::less_than(type_context, lhs, rhs),
181                    ">" => TypedValue::<T>::less_than(type_context, rhs, lhs),
182                    "<=" => TypedValue::<T>::less_than_or_equal(type_context, lhs, rhs),
183                    ">=" => TypedValue::<T>::less_than_or_equal(type_context, rhs, lhs),
184                    "==" => TypedValue::<T>::equals(type_context, lhs, rhs),
185                    "!=" => TypedValue::<T>::not_equals(type_context, lhs, rhs),
186                    "|" => TypedValue::<T>::bitwise_or(type_context, lhs, rhs),
187                    "^" => TypedValue::<T>::bitwise_xor(type_context, lhs, rhs),
188                    "&" => TypedValue::<T>::bitwise_and(type_context, lhs, rhs),
189                    "<<" => TypedValue::<T>::shift_left(type_context, lhs, rhs),
190                    ">>" => TypedValue::<T>::shift_right(type_context, lhs, rhs),
191
192                    other => {
193                        return Err(EvalError {
194                            message: format!("Unknown binary operator: {other}").into_boxed_str(),
195                            location: expression.location(),
196                        });
197                    }
198                };
199
200                match result {
201                    Ok(r) => r,
202                    Err(error) => {
203                        return Err(EvalError {
204                            message: format!("Failed to evaluate expression: {error}")
205                                .into_boxed_str(),
206                            location: expression.location(),
207                        });
208                    }
209                }
210            }
211            RightHandExpression::FunctionCall { name, arguments } => {
212                let function_name = name.source(self.source);
213                let mut args = Vec::with_capacity(arguments.len());
214                for arg in arguments {
215                    args.push(self.visit_right_hand_expression(arg)?);
216                }
217
218                match self.context.call_function(function_name, &args) {
219                    Ok(result) => result,
220                    Err(FunctionCallError::IncorrectArgumentCount { expected }) => {
221                        return Err(EvalError {
222                            message: format!(
223                                "{function_name} takes {expected} arguments, {} given",
224                                args.len()
225                            )
226                            .into_boxed_str(),
227                            location: expression.location(),
228                        });
229                    }
230                    Err(FunctionCallError::IncorrectArgumentType { idx, expected }) => {
231                        return Err(EvalError {
232                            message: format!(
233                                "{function_name} expects argument {idx} to be {expected}, got {}",
234                                args[idx].type_of()
235                            )
236                            .into_boxed_str(),
237                            location: arguments[idx].location(),
238                        });
239                    }
240                    Err(FunctionCallError::FunctionNotFound) => {
241                        return Err(EvalError {
242                            message: format!("Function {function_name} is not found")
243                                .into_boxed_str(),
244                            location: expression.location(),
245                        });
246                    }
247                    Err(FunctionCallError::Other(error)) => {
248                        return Err(EvalError {
249                            message: format!("Failed to call {function_name}: {error}")
250                                .into_boxed_str(),
251                            location: expression.location(),
252                        });
253                    }
254                }
255            }
256        };
257
258        Ok(result)
259    }
260
261    fn typecheck_with_hint(
262        &self,
263        value: TypedValue<T>,
264        hint: Option<TypeHint>,
265    ) -> Result<TypedValue<T>, EvalError> {
266        let Some(hint) = hint else {
267            // No hint
268            return Ok(value);
269        };
270
271        let ty =
272            Type::from_name(hint.type_name.source(self.source)).map_err(|message| EvalError {
273                message,
274                location: hint.type_name.location,
275            })?;
276
277        self.typecheck(value, ty, hint.type_name.location)
278    }
279
280    fn typecheck(
281        &self,
282        value: TypedValue<T>,
283        hint: Type,
284        location: Location,
285    ) -> Result<TypedValue<T>, EvalError> {
286        match (value, hint) {
287            (value, hint) if value.type_of() == hint => Ok(value),
288            (TypedValue::MaybeSignedInt(val), Type::Int) => Ok(TypedValue::Int(val)),
289            (TypedValue::MaybeSignedInt(val), Type::SignedInt) => Ok(TypedValue::<T>::SignedInt(
290                T::to_signed(val).map_err(|_| EvalError {
291                    message: format!("Failed to cast {val:?} to signed int").into_boxed_str(),
292                    location,
293                })?,
294            )),
295            (value, hint) => Err(EvalError {
296                message: format!("Expected {hint}, got {}", value.type_of()).into_boxed_str(),
297                location,
298            }),
299        }
300    }
301
302    /// Evaluates a function with the given arguments.
303    pub fn visit_function(
304        &mut self,
305        function: &Function<T::Parser>,
306        args: &[TypedValue<T>],
307    ) -> Result<TypedValue<T>, EvalError> {
308        for (arg, arg_value) in function.arguments.iter().zip(args.iter()) {
309            let arg_name = arg.name.source(self.source);
310
311            let arg_value = self.typecheck_with_hint(arg_value.clone(), Some(arg.arg_type))?;
312
313            self.context.declare(arg_name, arg_value);
314        }
315
316        let retval = match self.visit_body(&function.body)? {
317            StatementResult::Return(typed_value) | StatementResult::ImplicitReturn(typed_value) => {
318                typed_value
319            }
320            StatementResult::EndOfBody => TypedValue::Void,
321            StatementResult::LoopBreak | StatementResult::LoopContinue => todo!(),
322        };
323
324        let retval =
325            self.typecheck_with_hint(retval, function.return_decl.as_ref().map(|d| d.return_type))?;
326
327        Ok(retval)
328    }
329
330    fn visit_body(&mut self, body: &Body<T::Parser>) -> Result<StatementResult<T>, EvalError> {
331        self.context.open_scope();
332
333        let mut body_result = StatementResult::EndOfBody;
334        for statement in body.statements.iter() {
335            if let Some(retval) = self.visit_statement(statement)? {
336                body_result = retval;
337                match body_result {
338                    StatementResult::ImplicitReturn(_) => {}
339                    _ => break,
340                }
341            } else {
342                // Reset result if we have statements after implicit returns.
343                body_result = StatementResult::EndOfBody;
344            }
345        }
346
347        self.context.close_scope();
348        Ok(body_result)
349    }
350
351    fn visit_statement(
352        &mut self,
353        statement: &Statement<T::Parser>,
354    ) -> Result<Option<StatementResult<T>>, EvalError> {
355        match statement {
356            Statement::Return(return_with_value) => {
357                return self
358                    .visit_right_hand_expression(&return_with_value.expression)
359                    .map(|rv| Some(StatementResult::Return(rv)));
360            }
361            Statement::ImplicitReturn(expression) => {
362                return self
363                    .visit_right_hand_expression(expression)
364                    .map(|rv| Some(StatementResult::ImplicitReturn(rv)));
365            }
366            Statement::EmptyReturn(_) => {
367                return Ok(Some(StatementResult::Return(TypedValue::Void)));
368            }
369            Statement::If(if_statement) => return self.visit_if(if_statement),
370            Statement::Loop(loop_statement) => return self.visit_loop(loop_statement),
371            Statement::Break(_) => return Ok(Some(StatementResult::LoopBreak)),
372            Statement::Continue(_) => return Ok(Some(StatementResult::LoopContinue)),
373            Statement::Scope(body) => {
374                return self.visit_body(body).map(|r| match r {
375                    StatementResult::EndOfBody => None,
376                    r => Some(r),
377                })
378            }
379            Statement::VariableDefinition(variable_definition) => {
380                self.visit_declaration(variable_definition)?;
381            }
382            Statement::Expression { expression, .. } => {
383                self.visit_expression(expression)?;
384            }
385        }
386
387        Ok(None)
388    }
389
390    fn visit_declaration(&mut self, decl: &VariableDefinition<T::Parser>) -> Result<(), EvalError> {
391        let name = decl.identifier.source(self.source);
392        let value = self.visit_right_hand_expression(&decl.initializer)?;
393
394        let value = self.typecheck_with_hint(value, decl.type_token)?;
395
396        self.context.declare(name, value);
397
398        Ok(())
399    }
400
401    fn visit_if(
402        &mut self,
403        if_statement: &If<T::Parser>,
404    ) -> Result<Option<StatementResult<T>>, EvalError> {
405        let condition = self.visit_right_hand_expression(&if_statement.condition)?;
406
407        let condition = self.typecheck(condition, Type::Bool, if_statement.condition.location())?;
408
409        let body = if condition == TypedValue::Bool(true) {
410            &if_statement.body
411        } else if let Some(ref else_branch) = if_statement.else_branch {
412            &else_branch.else_body
413        } else {
414            // Condition is false, but there is no `else`
415            return Ok(None);
416        };
417
418        let retval = match self.visit_body(body)? {
419            StatementResult::EndOfBody => None,
420            other => Some(other),
421        };
422        Ok(retval)
423    }
424
425    fn visit_loop(
426        &mut self,
427        loop_statement: &Loop<T::Parser>,
428    ) -> Result<Option<StatementResult<T>>, EvalError> {
429        loop {
430            match self.visit_body(&loop_statement.body)? {
431                ret @ StatementResult::Return(_) => return Ok(Some(ret)),
432                StatementResult::LoopBreak => return Ok(None),
433                StatementResult::LoopContinue
434                | StatementResult::EndOfBody
435                | StatementResult::ImplicitReturn(_) => {}
436            }
437        }
438    }
439}
440
441enum StatementResult<T: TypeSet> {
442    Return(TypedValue<T>),
443    ImplicitReturn(TypedValue<T>),
444    LoopBreak,
445    LoopContinue,
446    EndOfBody,
447}