quantrs2_tytan/problem_dsl/
types.rs

1//! Type system for the problem DSL.
2
3use super::ast::AST;
4use super::error::TypeError;
5use std::collections::HashMap;
6
/// Type of a DSL variable, parameter or set.
///
/// Scalar variants describe a single value; `Array` and `Matrix` wrap an
/// element type. The enum is `Eq` (not just `PartialEq`) because every field
/// is a `usize`, a `Vec<usize>` or a nested `VarType`, all of which have total
/// equality, so it can be used wherever `Eq` is required.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VarType {
    /// Binary variable (presumably 0/1-valued — confirm against the solver backend).
    Binary,
    /// Integer-valued variable.
    Integer,
    /// Real-valued variable.
    Continuous,
    /// Spin variable (presumably ±1-valued, Ising convention — confirm).
    Spin,
    /// Array of `element_type` values.
    Array {
        /// Type of each element.
        element_type: Box<Self>,
        /// Extent of each dimension. The type checker uses `vec![0]` both for
        /// empty arrays and as a placeholder for "length not known".
        dimensions: Vec<usize>,
    },
    /// Two-dimensional matrix of `element_type` values.
    Matrix {
        /// Type of each entry.
        element_type: Box<Self>,
        /// Number of rows.
        rows: usize,
        /// Number of columns.
        cols: usize,
    },
}
24
25/// Type checker
26#[derive(Debug, Clone)]
27pub struct TypeChecker {
28    /// Variable types
29    var_types: HashMap<String, VarType>,
30    /// Function signatures
31    func_signatures: HashMap<String, FunctionSignature>,
32    /// Type errors
33    errors: Vec<TypeError>,
34}
35
36/// Function signature
37#[derive(Debug, Clone)]
38pub struct FunctionSignature {
39    pub param_types: Vec<VarType>,
40    pub return_type: VarType,
41}
42
43impl TypeChecker {
44    /// Create a new type checker
45    pub fn new() -> Self {
46        let mut checker = Self {
47            var_types: HashMap::new(),
48            func_signatures: HashMap::new(),
49            errors: Vec::new(),
50        };
51
52        // Register built-in functions
53        checker.register_builtin_functions();
54        checker
55    }
56
57    /// Register built-in function signatures
58    fn register_builtin_functions(&mut self) {
59        // Mathematical functions
60        self.func_signatures.insert(
61            "abs".to_string(),
62            FunctionSignature {
63                param_types: vec![VarType::Continuous],
64                return_type: VarType::Continuous,
65            },
66        );
67
68        self.func_signatures.insert(
69            "sqrt".to_string(),
70            FunctionSignature {
71                param_types: vec![VarType::Continuous],
72                return_type: VarType::Continuous,
73            },
74        );
75
76        self.func_signatures.insert(
77            "exp".to_string(),
78            FunctionSignature {
79                param_types: vec![VarType::Continuous],
80                return_type: VarType::Continuous,
81            },
82        );
83
84        self.func_signatures.insert(
85            "log".to_string(),
86            FunctionSignature {
87                param_types: vec![VarType::Continuous],
88                return_type: VarType::Continuous,
89            },
90        );
91
92        // Aggregation functions
93        self.func_signatures.insert(
94            "sum".to_string(),
95            FunctionSignature {
96                param_types: vec![VarType::Array {
97                    element_type: Box::new(VarType::Continuous),
98                    dimensions: vec![0],
99                }],
100                return_type: VarType::Continuous,
101            },
102        );
103
104        self.func_signatures.insert(
105            "product".to_string(),
106            FunctionSignature {
107                param_types: vec![VarType::Array {
108                    element_type: Box::new(VarType::Continuous),
109                    dimensions: vec![0],
110                }],
111                return_type: VarType::Continuous,
112            },
113        );
114
115        self.func_signatures.insert(
116            "min".to_string(),
117            FunctionSignature {
118                param_types: vec![VarType::Array {
119                    element_type: Box::new(VarType::Continuous),
120                    dimensions: vec![0],
121                }],
122                return_type: VarType::Continuous,
123            },
124        );
125
126        self.func_signatures.insert(
127            "max".to_string(),
128            FunctionSignature {
129                param_types: vec![VarType::Array {
130                    element_type: Box::new(VarType::Continuous),
131                    dimensions: vec![0],
132                }],
133                return_type: VarType::Continuous,
134            },
135        );
136    }
137
138    /// Type check an AST
139    pub fn check(&mut self, ast: &AST) -> Result<(), TypeError> {
140        self.errors.clear();
141        self.check_ast(ast);
142
143        if self.errors.is_empty() {
144            Ok(())
145        } else {
146            Err(self.errors[0].clone())
147        }
148    }
149
150    /// Check AST node
151    fn check_ast(&mut self, ast: &AST) {
152        match ast {
153            AST::Program {
154                declarations,
155                objective,
156                constraints,
157            } => {
158                // Check declarations first to build symbol table
159                for decl in declarations {
160                    self.check_declaration(decl);
161                }
162
163                // Check objective
164                self.check_objective(objective);
165
166                // Check constraints
167                for constraint in constraints {
168                    self.check_constraint(constraint);
169                }
170            }
171            AST::VarDecl { name, var_type, .. } => {
172                self.var_types.insert(name.clone(), var_type.clone());
173            }
174            AST::Expr(expr) => {
175                self.check_expression(expr);
176            }
177            AST::Stmt(stmt) => {
178                self.check_statement(stmt);
179            }
180        }
181    }
182
183    /// Check declaration
184    fn check_declaration(&mut self, decl: &super::ast::Declaration) {
185        match decl {
186            super::ast::Declaration::Variable { name, var_type, .. } => {
187                self.var_types.insert(name.clone(), var_type.clone());
188            }
189            super::ast::Declaration::Parameter { name, value, .. } => {
190                let value_type = self.infer_value_type(value);
191                self.var_types.insert(name.clone(), value_type);
192            }
193            super::ast::Declaration::Set { name, elements } => {
194                if !elements.is_empty() {
195                    let element_type = self.infer_value_type(&elements[0]);
196                    let array_type = VarType::Array {
197                        element_type: Box::new(element_type),
198                        dimensions: vec![elements.len()],
199                    };
200                    self.var_types.insert(name.clone(), array_type);
201                }
202            }
203            super::ast::Declaration::Function { name, params, body } => {
204                // For now, assume functions return continuous values
205                let param_types = params.iter().map(|_| VarType::Continuous).collect();
206                let signature = FunctionSignature {
207                    param_types,
208                    return_type: VarType::Continuous,
209                };
210                self.func_signatures.insert(name.clone(), signature);
211
212                // Check function body
213                self.check_expression(body);
214            }
215        }
216    }
217
218    /// Check objective
219    fn check_objective(&mut self, obj: &super::ast::Objective) {
220        match obj {
221            super::ast::Objective::Minimize(expr) | super::ast::Objective::Maximize(expr) => {
222                self.check_expression(expr);
223            }
224            super::ast::Objective::MultiObjective { objectives } => {
225                for (_, expr, _) in objectives {
226                    self.check_expression(expr);
227                }
228            }
229        }
230    }
231
232    /// Check constraint
233    fn check_constraint(&mut self, constraint: &super::ast::Constraint) {
234        self.check_constraint_expression(&constraint.expression);
235    }
236
237    /// Check constraint expression
238    fn check_constraint_expression(&mut self, expr: &super::ast::ConstraintExpression) {
239        match expr {
240            super::ast::ConstraintExpression::Comparison { left, right, .. } => {
241                self.check_expression(left);
242                self.check_expression(right);
243            }
244            super::ast::ConstraintExpression::Logical { operands, .. } => {
245                for operand in operands {
246                    self.check_constraint_expression(operand);
247                }
248            }
249            super::ast::ConstraintExpression::Quantified { constraint, .. } => {
250                self.check_constraint_expression(constraint);
251            }
252            super::ast::ConstraintExpression::Implication {
253                condition,
254                consequence,
255            } => {
256                self.check_constraint_expression(condition);
257                self.check_constraint_expression(consequence);
258            }
259            super::ast::ConstraintExpression::Counting { count, .. } => {
260                self.check_expression(count);
261            }
262        }
263    }
264
265    /// Check expression
266    fn check_expression(&mut self, expr: &super::ast::Expression) {
267        match expr {
268            super::ast::Expression::Literal(_) => {
269                // Literals are always valid
270            }
271            super::ast::Expression::Variable(name) => {
272                if !self.var_types.contains_key(name) {
273                    self.errors.push(TypeError {
274                        message: format!("Undefined variable: {name}"),
275                        location: name.clone(),
276                    });
277                }
278            }
279            super::ast::Expression::IndexedVar { name, indices } => {
280                if !self.var_types.contains_key(name) {
281                    self.errors.push(TypeError {
282                        message: format!("Undefined variable: {name}"),
283                        location: name.clone(),
284                    });
285                }
286                for index in indices {
287                    self.check_expression(index);
288                }
289            }
290            super::ast::Expression::BinaryOp { left, right, .. } => {
291                self.check_expression(left);
292                self.check_expression(right);
293            }
294            super::ast::Expression::UnaryOp { operand, .. } => {
295                self.check_expression(operand);
296            }
297            super::ast::Expression::FunctionCall { name, args } => {
298                if let Some(signature) = self.func_signatures.get(name) {
299                    if args.len() != signature.param_types.len() {
300                        self.errors.push(TypeError {
301                            message: format!(
302                                "Function {} expects {} arguments, got {}",
303                                name,
304                                signature.param_types.len(),
305                                args.len()
306                            ),
307                            location: name.clone(),
308                        });
309                    }
310                } else {
311                    self.errors.push(TypeError {
312                        message: format!("Undefined function: {name}"),
313                        location: name.clone(),
314                    });
315                }
316
317                for arg in args {
318                    self.check_expression(arg);
319                }
320            }
321            super::ast::Expression::Aggregation { expression, .. } => {
322                self.check_expression(expression);
323            }
324            super::ast::Expression::Conditional {
325                condition,
326                then_expr,
327                else_expr,
328            } => {
329                self.check_constraint_expression(condition);
330                self.check_expression(then_expr);
331                self.check_expression(else_expr);
332            }
333        }
334    }
335
336    /// Check statement
337    fn check_statement(&mut self, stmt: &super::ast::Statement) {
338        match stmt {
339            super::ast::Statement::Assignment { target, value } => {
340                if !self.var_types.contains_key(target) {
341                    self.errors.push(TypeError {
342                        message: format!("Undefined variable: {target}"),
343                        location: target.clone(),
344                    });
345                }
346                self.check_expression(value);
347            }
348            super::ast::Statement::If {
349                condition,
350                then_branch,
351                else_branch,
352            } => {
353                self.check_constraint_expression(condition);
354                for stmt in then_branch {
355                    self.check_statement(stmt);
356                }
357                if let Some(else_stmts) = else_branch {
358                    for stmt in else_stmts {
359                        self.check_statement(stmt);
360                    }
361                }
362            }
363            super::ast::Statement::For { body, .. } => {
364                for stmt in body {
365                    self.check_statement(stmt);
366                }
367            }
368        }
369    }
370
371    /// Infer type from value
372    fn infer_value_type(&self, value: &super::ast::Value) -> VarType {
373        match value {
374            super::ast::Value::Number(_) => VarType::Continuous,
375            super::ast::Value::Boolean(_) => VarType::Binary,
376            super::ast::Value::String(_) => VarType::Continuous, // Treat as parameter
377            super::ast::Value::Array(elements) => {
378                if elements.is_empty() {
379                    VarType::Array {
380                        element_type: Box::new(VarType::Continuous),
381                        dimensions: vec![0],
382                    }
383                } else {
384                    let element_type = self.infer_value_type(&elements[0]);
385                    VarType::Array {
386                        element_type: Box::new(element_type),
387                        dimensions: vec![elements.len()],
388                    }
389                }
390            }
391            super::ast::Value::Tuple(elements) => {
392                if elements.is_empty() {
393                    VarType::Array {
394                        element_type: Box::new(VarType::Continuous),
395                        dimensions: vec![0],
396                    }
397                } else {
398                    let element_type = self.infer_value_type(&elements[0]);
399                    VarType::Array {
400                        element_type: Box::new(element_type),
401                        dimensions: vec![elements.len()],
402                    }
403                }
404            }
405        }
406    }
407
408    /// Get variable type
409    pub fn get_var_type(&self, name: &str) -> Option<&VarType> {
410        self.var_types.get(name)
411    }
412
413    /// Get function signature
414    pub fn get_function_signature(&self, name: &str) -> Option<&FunctionSignature> {
415        self.func_signatures.get(name)
416    }
417}
418
419impl Default for TypeChecker {
420    fn default() -> Self {
421        Self::new()
422    }
423}