// quantrs2_tytan/problem_dsl/types.rs
use super::ast::AST;
use super::error::TypeError;
use std::collections::HashMap;

/// Type assigned to a DSL variable, parameter or set.
///
/// Scalar kinds (`Binary`, `Integer`, `Continuous`, `Spin`) can be nested
/// inside the composite `Array` and `Matrix` kinds.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum VarType {
    Binary,
    Integer,
    Continuous,
    Spin,
    /// Array of `element_type`, with one extent per entry in `dimensions`.
    Array {
        element_type: Box<Self>,
        dimensions: Vec<usize>,
    },
    /// `rows` x `cols` matrix of `element_type`.
    Matrix {
        element_type: Box<Self>,
        rows: usize,
        cols: usize,
    },
}

25#[derive(Debug, Clone)]
27pub struct TypeChecker {
28 var_types: HashMap<String, VarType>,
30 func_signatures: HashMap<String, FunctionSignature>,
32 errors: Vec<TypeError>,
34}
35
36#[derive(Debug, Clone)]
38pub struct FunctionSignature {
39 pub param_types: Vec<VarType>,
40 pub return_type: VarType,
41}
42
43impl TypeChecker {
44 pub fn new() -> Self {
46 let mut checker = Self {
47 var_types: HashMap::new(),
48 func_signatures: HashMap::new(),
49 errors: Vec::new(),
50 };
51
52 checker.register_builtin_functions();
54 checker
55 }
56
57 fn register_builtin_functions(&mut self) {
59 self.func_signatures.insert(
61 "abs".to_string(),
62 FunctionSignature {
63 param_types: vec![VarType::Continuous],
64 return_type: VarType::Continuous,
65 },
66 );
67
68 self.func_signatures.insert(
69 "sqrt".to_string(),
70 FunctionSignature {
71 param_types: vec![VarType::Continuous],
72 return_type: VarType::Continuous,
73 },
74 );
75
76 self.func_signatures.insert(
77 "exp".to_string(),
78 FunctionSignature {
79 param_types: vec![VarType::Continuous],
80 return_type: VarType::Continuous,
81 },
82 );
83
84 self.func_signatures.insert(
85 "log".to_string(),
86 FunctionSignature {
87 param_types: vec![VarType::Continuous],
88 return_type: VarType::Continuous,
89 },
90 );
91
92 self.func_signatures.insert(
94 "sum".to_string(),
95 FunctionSignature {
96 param_types: vec![VarType::Array {
97 element_type: Box::new(VarType::Continuous),
98 dimensions: vec![0],
99 }],
100 return_type: VarType::Continuous,
101 },
102 );
103
104 self.func_signatures.insert(
105 "product".to_string(),
106 FunctionSignature {
107 param_types: vec![VarType::Array {
108 element_type: Box::new(VarType::Continuous),
109 dimensions: vec![0],
110 }],
111 return_type: VarType::Continuous,
112 },
113 );
114
115 self.func_signatures.insert(
116 "min".to_string(),
117 FunctionSignature {
118 param_types: vec![VarType::Array {
119 element_type: Box::new(VarType::Continuous),
120 dimensions: vec![0],
121 }],
122 return_type: VarType::Continuous,
123 },
124 );
125
126 self.func_signatures.insert(
127 "max".to_string(),
128 FunctionSignature {
129 param_types: vec![VarType::Array {
130 element_type: Box::new(VarType::Continuous),
131 dimensions: vec![0],
132 }],
133 return_type: VarType::Continuous,
134 },
135 );
136 }
137
138 pub fn check(&mut self, ast: &AST) -> Result<(), TypeError> {
140 self.errors.clear();
141 self.check_ast(ast);
142
143 if self.errors.is_empty() {
144 Ok(())
145 } else {
146 Err(self.errors[0].clone())
147 }
148 }
149
150 fn check_ast(&mut self, ast: &AST) {
152 match ast {
153 AST::Program {
154 declarations,
155 objective,
156 constraints,
157 } => {
158 for decl in declarations {
160 self.check_declaration(decl);
161 }
162
163 self.check_objective(objective);
165
166 for constraint in constraints {
168 self.check_constraint(constraint);
169 }
170 }
171 AST::VarDecl { name, var_type, .. } => {
172 self.var_types.insert(name.clone(), var_type.clone());
173 }
174 AST::Expr(expr) => {
175 self.check_expression(expr);
176 }
177 AST::Stmt(stmt) => {
178 self.check_statement(stmt);
179 }
180 }
181 }
182
183 fn check_declaration(&mut self, decl: &super::ast::Declaration) {
185 match decl {
186 super::ast::Declaration::Variable { name, var_type, .. } => {
187 self.var_types.insert(name.clone(), var_type.clone());
188 }
189 super::ast::Declaration::Parameter { name, value, .. } => {
190 let value_type = self.infer_value_type(value);
191 self.var_types.insert(name.clone(), value_type);
192 }
193 super::ast::Declaration::Set { name, elements } => {
194 if !elements.is_empty() {
195 let element_type = self.infer_value_type(&elements[0]);
196 let array_type = VarType::Array {
197 element_type: Box::new(element_type),
198 dimensions: vec![elements.len()],
199 };
200 self.var_types.insert(name.clone(), array_type);
201 }
202 }
203 super::ast::Declaration::Function { name, params, body } => {
204 let param_types = params.iter().map(|_| VarType::Continuous).collect();
206 let signature = FunctionSignature {
207 param_types,
208 return_type: VarType::Continuous,
209 };
210 self.func_signatures.insert(name.clone(), signature);
211
212 self.check_expression(body);
214 }
215 }
216 }
217
218 fn check_objective(&mut self, obj: &super::ast::Objective) {
220 match obj {
221 super::ast::Objective::Minimize(expr) | super::ast::Objective::Maximize(expr) => {
222 self.check_expression(expr);
223 }
224 super::ast::Objective::MultiObjective { objectives } => {
225 for (_, expr, _) in objectives {
226 self.check_expression(expr);
227 }
228 }
229 }
230 }
231
232 fn check_constraint(&mut self, constraint: &super::ast::Constraint) {
234 self.check_constraint_expression(&constraint.expression);
235 }
236
237 fn check_constraint_expression(&mut self, expr: &super::ast::ConstraintExpression) {
239 match expr {
240 super::ast::ConstraintExpression::Comparison { left, right, .. } => {
241 self.check_expression(left);
242 self.check_expression(right);
243 }
244 super::ast::ConstraintExpression::Logical { operands, .. } => {
245 for operand in operands {
246 self.check_constraint_expression(operand);
247 }
248 }
249 super::ast::ConstraintExpression::Quantified { constraint, .. } => {
250 self.check_constraint_expression(constraint);
251 }
252 super::ast::ConstraintExpression::Implication {
253 condition,
254 consequence,
255 } => {
256 self.check_constraint_expression(condition);
257 self.check_constraint_expression(consequence);
258 }
259 super::ast::ConstraintExpression::Counting { count, .. } => {
260 self.check_expression(count);
261 }
262 }
263 }
264
265 fn check_expression(&mut self, expr: &super::ast::Expression) {
267 match expr {
268 super::ast::Expression::Literal(_) => {
269 }
271 super::ast::Expression::Variable(name) => {
272 if !self.var_types.contains_key(name) {
273 self.errors.push(TypeError {
274 message: format!("Undefined variable: {name}"),
275 location: name.clone(),
276 });
277 }
278 }
279 super::ast::Expression::IndexedVar { name, indices } => {
280 if !self.var_types.contains_key(name) {
281 self.errors.push(TypeError {
282 message: format!("Undefined variable: {name}"),
283 location: name.clone(),
284 });
285 }
286 for index in indices {
287 self.check_expression(index);
288 }
289 }
290 super::ast::Expression::BinaryOp { left, right, .. } => {
291 self.check_expression(left);
292 self.check_expression(right);
293 }
294 super::ast::Expression::UnaryOp { operand, .. } => {
295 self.check_expression(operand);
296 }
297 super::ast::Expression::FunctionCall { name, args } => {
298 if let Some(signature) = self.func_signatures.get(name) {
299 if args.len() != signature.param_types.len() {
300 self.errors.push(TypeError {
301 message: format!(
302 "Function {} expects {} arguments, got {}",
303 name,
304 signature.param_types.len(),
305 args.len()
306 ),
307 location: name.clone(),
308 });
309 }
310 } else {
311 self.errors.push(TypeError {
312 message: format!("Undefined function: {name}"),
313 location: name.clone(),
314 });
315 }
316
317 for arg in args {
318 self.check_expression(arg);
319 }
320 }
321 super::ast::Expression::Aggregation { expression, .. } => {
322 self.check_expression(expression);
323 }
324 super::ast::Expression::Conditional {
325 condition,
326 then_expr,
327 else_expr,
328 } => {
329 self.check_constraint_expression(condition);
330 self.check_expression(then_expr);
331 self.check_expression(else_expr);
332 }
333 }
334 }
335
336 fn check_statement(&mut self, stmt: &super::ast::Statement) {
338 match stmt {
339 super::ast::Statement::Assignment { target, value } => {
340 if !self.var_types.contains_key(target) {
341 self.errors.push(TypeError {
342 message: format!("Undefined variable: {target}"),
343 location: target.clone(),
344 });
345 }
346 self.check_expression(value);
347 }
348 super::ast::Statement::If {
349 condition,
350 then_branch,
351 else_branch,
352 } => {
353 self.check_constraint_expression(condition);
354 for stmt in then_branch {
355 self.check_statement(stmt);
356 }
357 if let Some(else_stmts) = else_branch {
358 for stmt in else_stmts {
359 self.check_statement(stmt);
360 }
361 }
362 }
363 super::ast::Statement::For { body, .. } => {
364 for stmt in body {
365 self.check_statement(stmt);
366 }
367 }
368 }
369 }
370
371 fn infer_value_type(&self, value: &super::ast::Value) -> VarType {
373 match value {
374 super::ast::Value::Number(_) => VarType::Continuous,
375 super::ast::Value::Boolean(_) => VarType::Binary,
376 super::ast::Value::String(_) => VarType::Continuous, super::ast::Value::Array(elements) => {
378 if elements.is_empty() {
379 VarType::Array {
380 element_type: Box::new(VarType::Continuous),
381 dimensions: vec![0],
382 }
383 } else {
384 let element_type = self.infer_value_type(&elements[0]);
385 VarType::Array {
386 element_type: Box::new(element_type),
387 dimensions: vec![elements.len()],
388 }
389 }
390 }
391 super::ast::Value::Tuple(elements) => {
392 if elements.is_empty() {
393 VarType::Array {
394 element_type: Box::new(VarType::Continuous),
395 dimensions: vec![0],
396 }
397 } else {
398 let element_type = self.infer_value_type(&elements[0]);
399 VarType::Array {
400 element_type: Box::new(element_type),
401 dimensions: vec![elements.len()],
402 }
403 }
404 }
405 }
406 }
407
408 pub fn get_var_type(&self, name: &str) -> Option<&VarType> {
410 self.var_types.get(name)
411 }
412
413 pub fn get_function_signature(&self, name: &str) -> Option<&FunctionSignature> {
415 self.func_signatures.get(name)
416 }
417}
418
419impl Default for TypeChecker {
420 fn default() -> Self {
421 Self::new()
422 }
423}