Skip to main content

t_ree/
validate.rs

1use std::collections::{HashMap, HashSet};
2
3use crate::declaration::{Declaration, Module};
4use crate::expression::{Block, Expression, ExpressionKind, Statement};
5use crate::operator::BinaryOperator;
6use crate::types::Type;
7
8/// Validates type correctness across a module.
9///
10/// Checks binary operation type compatibility, assignment type safety,
11/// and type construction field correctness.
12pub fn validate_module(module: &Module) -> Result<(), String> {
13    let mut newtypes: HashMap<String, Type> = HashMap::new();
14    for declaration in module {
15        if let Declaration::Type(newtype) = declaration {
16            newtypes.insert(newtype.name.clone(), newtype.inner_type.clone());
17        }
18    }
19    let context = ValidationContext { newtypes };
20    for declaration in module {
21        match declaration {
22            Declaration::Function(function) => {
23                context.validate_block(&function.body)?;
24            }
25            Declaration::Constant(constant) => {
26                context.validate_expression(&constant.value)?;
27            }
28            _ => {}
29        }
30    }
31    Ok(())
32}
33
34struct ValidationContext {
35    newtypes: HashMap<String, Type>,
36}
37
38impl ValidationContext {
39    fn validate_block(&self, block: &Block) -> Result<(), String> {
40        for statement in &block.statements {
41            self.validate_statement(statement)?;
42        }
43        if let Some(result) = &block.result {
44            self.validate_expression(result)?;
45        }
46        Ok(())
47    }
48
49    fn validate_statement(&self, statement: &Statement) -> Result<(), String> {
50        match statement {
51            Statement::Expression(expression) | Statement::Return(Some(expression)) => {
52                self.validate_expression(expression)?;
53            }
54            Statement::Let { value, .. } => {
55                self.validate_expression(value)?;
56            }
57            Statement::Assign(target, value) => {
58                self.validate_expression(target)?;
59                self.validate_expression(value)?;
60                Self::check_replace_types(target, value)?;
61            }
62            Statement::Label {
63                initial_arguments, ..
64            } => {
65                for argument in initial_arguments {
66                    self.validate_expression(argument)?;
67                }
68            }
69            Statement::Jump { arguments, .. } => {
70                for argument in arguments {
71                    self.validate_expression(argument)?;
72                }
73            }
74            Statement::MultiReplace {
75                targets, values, ..
76            } => {
77                for target in targets {
78                    self.validate_expression(target)?;
79                }
80                for value in values {
81                    self.validate_expression(value)?;
82                }
83            }
84            Statement::Defer(inner) => {
85                self.validate_statement(inner)?;
86            }
87            Statement::Return(None) => {}
88        }
89        Ok(())
90    }
91
92    fn validate_expression(&self, expression: &Expression) -> Result<(), String> {
93        match &expression.kind {
94            ExpressionKind::BinaryOperation(operator, left, right) => {
95                self.validate_expression(left)?;
96                self.validate_expression(right)?;
97                self.check_binary_operands(operator, left, right)?;
98            }
99            ExpressionKind::TypeConstruction(name, fields) => {
100                for (_, value) in fields {
101                    self.validate_expression(value)?;
102                }
103                self.check_construction_fields(name, fields)?;
104            }
105            ExpressionKind::Replace(target, value) | ExpressionKind::OpAssign(_, target, value) => {
106                self.validate_expression(target)?;
107                self.validate_expression(value)?;
108                Self::check_replace_types(target, value)?;
109            }
110            ExpressionKind::Call(callee, arguments) => {
111                self.validate_expression(callee)?;
112                for argument in arguments {
113                    self.validate_expression(argument)?;
114                }
115            }
116            ExpressionKind::UnaryOperation(_, operand)
117            | ExpressionKind::Dereference(operand)
118            | ExpressionKind::Convert(operand, _)
119            | ExpressionKind::Transmute(operand, _) => {
120                self.validate_expression(operand)?;
121            }
122            ExpressionKind::Field(object, _) => {
123                self.validate_expression(object)?;
124            }
125            ExpressionKind::Index(object, index) => {
126                self.validate_expression(object)?;
127                self.validate_expression(index)?;
128            }
129            ExpressionKind::ArrayLiteral(elements)
130            | ExpressionKind::TupleLiteral(elements)
131            | ExpressionKind::Print(elements) => {
132                for element in elements {
133                    self.validate_expression(element)?;
134                }
135            }
136            ExpressionKind::Block(block) => {
137                self.validate_block(block)?;
138            }
139            ExpressionKind::If {
140                condition,
141                then_branch,
142                else_branch,
143            } => {
144                self.validate_expression(condition)?;
145                self.validate_block(then_branch)?;
146                if let Some(else_branch) = else_branch {
147                    self.validate_block(else_branch)?;
148                }
149            }
150            ExpressionKind::Match { value, arms } => {
151                self.validate_expression(value)?;
152                for arm in arms {
153                    self.validate_block(&arm.body)?;
154                }
155            }
156            ExpressionKind::Slice(array, start, end) => {
157                self.validate_expression(array)?;
158                if let Some(start) = start {
159                    self.validate_expression(start)?;
160                }
161                if let Some(end) = end {
162                    self.validate_expression(end)?;
163                }
164            }
165            ExpressionKind::Literal(_)
166            | ExpressionKind::Variable(_)
167            | ExpressionKind::SizeOf(_) => {}
168        }
169        Ok(())
170    }
171
172    fn resolve_underlying(&self, resolved_type: &Type) -> Type {
173        match resolved_type {
174            Type::Named(name) => self.newtypes.get(name).map_or_else(
175                || resolved_type.clone(),
176                |inner| self.resolve_underlying(inner),
177            ),
178            Type::Pointer(mutability, inner) => {
179                Type::Pointer(*mutability, Box::new(self.resolve_underlying(inner)))
180            }
181            other => other.clone(),
182        }
183    }
184
185    fn check_binary_operands(
186        &self,
187        operator: &BinaryOperator,
188        left: &Expression,
189        right: &Expression,
190    ) -> Result<(), String> {
191        if matches!(operator, BinaryOperator::Logical(_)) {
192            return Ok(());
193        }
194        let (Some(left_type), Some(right_type)) = (&left.resolved_type, &right.resolved_type)
195        else {
196            return Ok(());
197        };
198        if left_type == right_type {
199            return Ok(());
200        }
201        let left_resolved = self.resolve_underlying(left_type);
202        let right_resolved = self.resolve_underlying(right_type);
203        if left_resolved != right_resolved {
204            return Err(format!(
205                "type mismatch in '{operator}': left is {left_type}, right is {right_type}",
206            ));
207        }
208        if matches!(left_type, Type::Named(_)) && matches!(right_type, Type::Named(_)) {
209            return Err(format!(
210                "cannot mix distinct types in '{operator}': left is {left_type}, right is {right_type}",
211            ));
212        }
213        Ok(())
214    }
215
216    fn check_replace_types(target: &Expression, value: &Expression) -> Result<(), String> {
217        let Some(target_resolved) = &target.resolved_type else {
218            return Ok(());
219        };
220        let Some(value_type) = &value.resolved_type else {
221            return Ok(());
222        };
223        let target_type = match target_resolved {
224            Type::Pointer(_, inner) => inner.as_ref(),
225            other => other,
226        };
227        if target_type == value_type {
228            return Ok(());
229        }
230        if matches!(target_type, Type::Named(_)) && matches!(value_type, Type::Named(_)) {
231            return Err(format!(
232                "type mismatch in assignment: target is {target_type}, value is {value_type}",
233            ));
234        }
235        Ok(())
236    }
237
238    fn check_construction_fields(
239        &self,
240        type_name: &str,
241        fields: &[(String, Expression)],
242    ) -> Result<(), String> {
243        let Some(inner) = self.newtypes.get(type_name) else {
244            return Ok(());
245        };
246        let expected_fields: Vec<&str> = match inner {
247            Type::Tuple(field_types) => field_types
248                .iter()
249                .filter_map(|field_type| match field_type {
250                    Type::Named(name) => Some(name.as_str()),
251                    _ => None,
252                })
253                .collect(),
254            Type::Named(name) => vec![name.as_str()],
255            _ => return Ok(()),
256        };
257
258        let mut seen = HashSet::new();
259        for (field_name, _) in fields {
260            if !expected_fields.contains(&field_name.as_str()) {
261                return Err(format!("'{type_name}' has no field '{field_name}'"));
262            }
263            if !seen.insert(field_name.as_str()) {
264                return Err(format!(
265                    "duplicate field '{field_name}' in '{type_name}' construction"
266                ));
267            }
268        }
269        for expected in &expected_fields {
270            if !seen.contains(expected) {
271                return Err(format!(
272                    "missing field '{expected}' in '{type_name}' construction"
273                ));
274            }
275        }
276        Ok(())
277    }
278}