quantrs2_tytan/problem_dsl/
compiler.rs

1//! Compiler for the problem DSL.
2
3use super::ast::{
4    AggregationOp, BinaryOperator, ComparisonOp, Constraint, ConstraintExpression, Declaration,
5    Expression, Objective, ObjectiveType, Value, AST,
6};
7use super::error::CompileError;
8use scirs2_core::ndarray::Array2;
9use std::collections::HashMap;
10
/// Options controlling how a DSL program is compiled.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CompilerOptions {
    /// Requested optimization level.
    pub optimization_level: OptimizationLevel,
    /// Requested target backend.
    pub target: TargetBackend,
    /// Whether debug information should be emitted.
    pub debug_info: bool,
    /// Whether warnings should be treated as errors.
    pub warnings_as_errors: bool,
}

/// Requested optimization level.
///
/// NOTE(review): the QUBO compiler in this file does not consult this value yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OptimizationLevel {
    /// No optimization requested.
    None,
    /// Basic optimization (the default).
    Basic,
    /// Full optimization.
    Full,
}

/// Requested compilation target.
///
/// NOTE(review): `compile_to_qubo` always produces a QUBO matrix regardless of
/// this value.
// Variant names are part of the public API, so the acronym spelling stays.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum TargetBackend {
    /// Quadratic unconstrained binary optimization matrix.
    QUBO,
    /// Ising model.
    Ising,
    /// Higher-order (beyond quadratic) model.
    HigherOrder,
}

impl Default for CompilerOptions {
    /// Basic optimization, QUBO target, no debug info, warnings not fatal.
    fn default() -> Self {
        Self {
            optimization_level: OptimizationLevel::Basic,
            target: TargetBackend::QUBO,
            debug_info: false,
            warnings_as_errors: false,
        }
    }
}
48
/// Variable registry for tracking variables during compilation.
///
/// Every registered variable (scalar or indexed) receives a unique index into
/// the QUBO matrix; indices are handed out densely in registration order.
#[derive(Debug, Clone)]
struct VariableRegistry {
    /// Maps scalar variable names to their indices in the QUBO matrix.
    var_indices: HashMap<String, usize>,
    /// Maps indexed variables (e.g. `x[i,j]`): base name -> index tuple -> QUBO index.
    indexed_var_indices: HashMap<String, HashMap<Vec<usize>, usize>>,
    /// Total number of variables registered so far (also the next free index).
    num_vars: usize,
    /// Domain of each registered variable, keyed by display name:
    /// `name` for scalars and `base[i,j,...]` for indexed variables.
    domains: HashMap<String, VariableDomain>,
}

/// Domain of a decision variable.
#[derive(Debug, Clone, PartialEq)]
enum VariableDomain {
    Binary,
    Integer { min: i32, max: i32 },
    Continuous { min: f64, max: f64 },
}

impl VariableRegistry {
    /// Creates an empty registry.
    fn new() -> Self {
        Self {
            var_indices: HashMap::new(),
            indexed_var_indices: HashMap::new(),
            num_vars: 0,
            domains: HashMap::new(),
        }
    }

    /// Registers the scalar variable `name` and returns its QUBO index.
    ///
    /// Re-registering an existing name returns the existing index; the first
    /// registered domain is kept.
    fn register_variable(&mut self, name: &str, domain: VariableDomain) -> usize {
        if let Some(&idx) = self.var_indices.get(name) {
            return idx;
        }
        let idx = self.allocate_index();
        self.var_indices.insert(name.to_string(), idx);
        self.domains.insert(name.to_string(), domain);
        idx
    }

    /// Registers `base_name[indices...]` and returns its QUBO index.
    ///
    /// Re-registering the same base name and index tuple returns the existing
    /// index; the first registered domain is kept.
    fn register_indexed_variable(
        &mut self,
        base_name: &str,
        indices: Vec<usize>,
        domain: VariableDomain,
    ) -> usize {
        if let Some(&idx) = self
            .indexed_var_indices
            .get(base_name)
            .and_then(|by_index| by_index.get(&indices))
        {
            return idx;
        }
        let idx = self.allocate_index();
        // Key the domain by the index tuple rather than the global QUBO index:
        // `[` cannot appear in a scalar name, so this cannot collide with
        // scalar variables such as the `name_N` ones created during substitution.
        self.domains
            .insert(Self::indexed_name(base_name, &indices), domain);
        self.indexed_var_indices
            .entry(base_name.to_string())
            .or_default()
            .insert(indices, idx);
        idx
    }

    /// Display name of an indexed variable, e.g. `x[1,2]`.
    fn indexed_name(base_name: &str, indices: &[usize]) -> String {
        let joined = indices
            .iter()
            .map(ToString::to_string)
            .collect::<Vec<_>>()
            .join(",");
        format!("{base_name}[{joined}]")
    }

    /// Reserves the next free QUBO index.
    fn allocate_index(&mut self) -> usize {
        let idx = self.num_vars;
        self.num_vars += 1;
        idx
    }
}
111
112/// Compile AST to QUBO matrix
113pub fn compile_to_qubo(ast: &AST, options: &CompilerOptions) -> Result<Array2<f64>, CompileError> {
114    match ast {
115        AST::Program {
116            declarations,
117            objective,
118            constraints,
119        } => {
120            let mut compiler = Compiler::new(options.clone());
121
122            // Process declarations
123            for decl in declarations {
124                compiler.process_declaration(decl)?;
125            }
126
127            // Build QUBO from objective
128            let mut qubo = compiler.build_objective_qubo(objective)?;
129
130            // Add constraint penalties
131            for constraint in constraints {
132                compiler.add_constraint_penalty(&mut qubo, constraint)?;
133            }
134
135            Ok(qubo)
136        }
137        _ => Err(CompileError {
138            message: "Can only compile program AST nodes".to_string(),
139            context: "compile_to_qubo".to_string(),
140        }),
141    }
142}
143
144/// Internal compiler state
145#[derive(Clone)]
146struct Compiler {
147    options: CompilerOptions,
148    registry: VariableRegistry,
149    parameters: HashMap<String, Value>,
150    penalty_weight: f64,
151}
152
153impl Compiler {
154    fn new(options: CompilerOptions) -> Self {
155        Self {
156            options,
157            registry: VariableRegistry::new(),
158            parameters: HashMap::new(),
159            penalty_weight: 1000.0, // Default penalty weight for constraints
160        }
161    }
162
163    fn process_declaration(&mut self, decl: &Declaration) -> Result<(), CompileError> {
164        match decl {
165            Declaration::Variable {
166                name,
167                var_type: _,
168                domain: _,
169                attributes: _,
170            } => {
171                // For now, assume all variables are binary
172                self.registry
173                    .register_variable(name, VariableDomain::Binary);
174                Ok(())
175            }
176            Declaration::Parameter { name, value, .. } => {
177                self.parameters.insert(name.clone(), value.clone());
178                Ok(())
179            }
180            Declaration::Set { name, elements } => {
181                // Register set as parameter for later use in aggregations
182                self.parameters
183                    .insert(name.clone(), Value::Array(elements.clone()));
184                Ok(())
185            }
186            Declaration::Function {
187                name,
188                params: _,
189                body: _,
190            } => {
191                // Store function definition for later expansion
192                // For now, treat as a complex parameter
193                self.parameters.insert(
194                    format!("func_{name}"),
195                    Value::String(format!("function_{name}")),
196                );
197                Ok(())
198            }
199        }
200    }
201
202    fn build_objective_qubo(&mut self, objective: &Objective) -> Result<Array2<f64>, CompileError> {
203        let num_vars = self.registry.num_vars;
204        let mut qubo = Array2::zeros((num_vars, num_vars));
205
206        match objective {
207            Objective::Minimize(expr) => {
208                self.add_expression_to_qubo(&mut qubo, expr, 1.0)?;
209            }
210            Objective::Maximize(expr) => {
211                self.add_expression_to_qubo(&mut qubo, expr, -1.0)?;
212            }
213            Objective::MultiObjective { objectives } => {
214                for (obj_type, expr, weight) in objectives {
215                    let sign = match obj_type {
216                        ObjectiveType::Minimize => 1.0,
217                        ObjectiveType::Maximize => -1.0,
218                    };
219                    self.add_expression_to_qubo(&mut qubo, expr, sign * weight)?;
220                }
221            }
222        }
223
224        Ok(qubo)
225    }
226
227    fn add_expression_to_qubo(
228        &mut self,
229        qubo: &mut Array2<f64>,
230        expr: &Expression,
231        coefficient: f64,
232    ) -> Result<(), CompileError> {
233        match expr {
234            Expression::Variable(name) => {
235                if let Some(&idx) = self.registry.var_indices.get(name) {
236                    qubo[[idx, idx]] += coefficient;
237                } else {
238                    return Err(CompileError {
239                        message: format!("Unknown variable: {name}"),
240                        context: "add_expression_to_qubo".to_string(),
241                    });
242                }
243            }
244            Expression::BinaryOp { op, left, right } => {
245                match op {
246                    BinaryOperator::Add => {
247                        self.add_expression_to_qubo(qubo, left, coefficient)?;
248                        self.add_expression_to_qubo(qubo, right, coefficient)?;
249                    }
250                    BinaryOperator::Subtract => {
251                        self.add_expression_to_qubo(qubo, left, coefficient)?;
252                        self.add_expression_to_qubo(qubo, right, -coefficient)?;
253                    }
254                    BinaryOperator::Multiply => {
255                        // Handle multiplication of two variables (creates quadratic term)
256                        if let (Expression::Variable(v1), Expression::Variable(v2)) =
257                            (left.as_ref(), right.as_ref())
258                        {
259                            if let (Some(&idx1), Some(&idx2)) = (
260                                self.registry.var_indices.get(v1),
261                                self.registry.var_indices.get(v2),
262                            ) {
263                                if idx1 == idx2 {
264                                    // x*x = x for binary variables
265                                    qubo[[idx1, idx1]] += coefficient;
266                                } else {
267                                    // Quadratic term
268                                    qubo[[idx1, idx2]] += coefficient / 2.0;
269                                    qubo[[idx2, idx1]] += coefficient / 2.0;
270                                }
271                            }
272                        } else {
273                            return Err(CompileError {
274                                message: "Complex multiplication not yet supported".to_string(),
275                                context: "add_expression_to_qubo".to_string(),
276                            });
277                        }
278                    }
279                    _ => {
280                        return Err(CompileError {
281                            message: format!("Unsupported binary operator: {op:?}"),
282                            context: "add_expression_to_qubo".to_string(),
283                        });
284                    }
285                }
286            }
287            Expression::Literal(Value::Number(_)) => {
288                // Constants don't affect the optimization, but we could track them
289                // for the objective value offset
290            }
291            Expression::Aggregation {
292                op,
293                variables,
294                expression,
295            } => {
296                match op {
297                    AggregationOp::Sum => {
298                        // Expand sum over index sets
299                        for (var_name, set_name) in variables {
300                            // Clone the elements to avoid borrowing conflicts
301                            let elements = if let Some(Value::Array(elements)) =
302                                self.parameters.get(set_name)
303                            {
304                                elements.clone()
305                            } else {
306                                return Err(CompileError {
307                                    message: format!("Unknown set for aggregation: {set_name}"),
308                                    context: "add_expression_to_qubo".to_string(),
309                                });
310                            };
311
312                            // For each element in the set, substitute and add expression
313                            for (i, element) in elements.iter().enumerate() {
314                                // Create substituted expression
315                                let substituted_expr = {
316                                    let mut compiler = self.clone();
317                                    compiler.substitute_variable_in_expression(
318                                        expression, var_name, element, i,
319                                    )?
320                                };
321                                let mut qubo_mut = qubo.clone();
322                                let mut compiler = self.clone();
323                                compiler.add_expression_to_qubo(
324                                    &mut qubo_mut,
325                                    &substituted_expr,
326                                    coefficient,
327                                )?;
328                                *qubo = qubo_mut;
329                            }
330                        }
331                    }
332                    AggregationOp::Product => {
333                        // Product expansion (multiply all terms)
334                        let mut product_expr = Expression::Literal(Value::Number(1.0));
335                        for (var_name, set_name) in variables {
336                            // Clone the elements to avoid borrowing conflicts
337                            let elements = if let Some(Value::Array(elements)) =
338                                self.parameters.get(set_name)
339                            {
340                                elements.clone()
341                            } else {
342                                continue; // Skip if set doesn't exist
343                            };
344
345                            for (i, element) in elements.iter().enumerate() {
346                                let substituted_expr = {
347                                    let mut compiler = self.clone();
348                                    compiler.substitute_variable_in_expression(
349                                        expression, var_name, element, i,
350                                    )?
351                                };
352                                product_expr = Expression::BinaryOp {
353                                    op: BinaryOperator::Multiply,
354                                    left: Box::new(product_expr),
355                                    right: Box::new(substituted_expr),
356                                };
357                            }
358                        }
359                        let mut qubo_mut = qubo.clone();
360                        let mut compiler = self.clone();
361                        compiler.add_expression_to_qubo(
362                            &mut qubo_mut,
363                            &product_expr,
364                            coefficient,
365                        )?;
366                        *qubo = qubo_mut;
367                    }
368                    _ => {
369                        return Err(CompileError {
370                            message: format!("Unsupported aggregation operator: {op:?}"),
371                            context: "add_expression_to_qubo".to_string(),
372                        });
373                    }
374                }
375            }
376            _ => {
377                return Err(CompileError {
378                    message: "Expression type not yet supported".to_string(),
379                    context: "add_expression_to_qubo".to_string(),
380                });
381            }
382        }
383        Ok(())
384    }
385
386    fn substitute_variable_in_expression(
387        &mut self,
388        expr: &Expression,
389        var_name: &str,
390        value: &Value,
391        index: usize,
392    ) -> Result<Expression, CompileError> {
393        match expr {
394            Expression::Variable(name) if name == var_name => {
395                // Replace with indexed variable or direct substitution
396                match value {
397                    Value::Number(_n) => {
398                        let indexed_name = format!("{var_name}_{index}");
399                        self.registry
400                            .register_variable(&indexed_name, VariableDomain::Binary);
401                        Ok(Expression::Variable(indexed_name))
402                    }
403                    _ => Ok(Expression::Literal(value.clone())),
404                }
405            }
406            Expression::Variable(name) => Ok(Expression::Variable(name.clone())),
407            Expression::BinaryOp { op, left, right } => {
408                let new_left =
409                    self.substitute_variable_in_expression(left, var_name, value, index)?;
410                let new_right =
411                    self.substitute_variable_in_expression(right, var_name, value, index)?;
412                Ok(Expression::BinaryOp {
413                    op: op.clone(),
414                    left: Box::new(new_left),
415                    right: Box::new(new_right),
416                })
417            }
418            Expression::IndexedVar { name, indices } => {
419                let new_indices = indices
420                    .iter()
421                    .map(|idx| self.substitute_variable_in_expression(idx, var_name, value, index))
422                    .collect::<Result<Vec<_>, _>>()?;
423                Ok(Expression::IndexedVar {
424                    name: name.clone(),
425                    indices: new_indices,
426                })
427            }
428            _ => Ok(expr.clone()),
429        }
430    }
431
432    fn add_constraint_penalty(
433        &mut self,
434        qubo: &mut Array2<f64>,
435        constraint: &Constraint,
436    ) -> Result<(), CompileError> {
437        match &constraint.expression {
438            ConstraintExpression::Comparison { left, op, right } => {
439                match op {
440                    ComparisonOp::Equal => {
441                        // For equality constraint: (left - right)^2
442                        // Expand: left^2 - 2*left*right + right^2
443                        self.add_expression_to_qubo(qubo, left, self.penalty_weight)?;
444                        self.add_expression_to_qubo(qubo, right, self.penalty_weight)?;
445
446                        // Cross term: -2*left*right
447                        if let (Expression::Variable(v1), Expression::Variable(v2)) = (left, right)
448                        {
449                            if let (Some(&idx1), Some(&idx2)) = (
450                                self.registry.var_indices.get(v1),
451                                self.registry.var_indices.get(v2),
452                            ) {
453                                qubo[[idx1, idx2]] -= self.penalty_weight;
454                                qubo[[idx2, idx1]] -= self.penalty_weight;
455                            }
456                        }
457                    }
458                    ComparisonOp::LessEqual => {
459                        // For a <= b, add slack variable: a + s = b, where s >= 0
460                        let slack_name = format!("slack_{}", self.registry.num_vars);
461                        let _slack_idx = self
462                            .registry
463                            .register_variable(&slack_name, VariableDomain::Binary);
464
465                        // Convert to equality: (a + s - b)^2
466                        let penalty_expr = Expression::BinaryOp {
467                            op: BinaryOperator::Subtract,
468                            left: Box::new(Expression::BinaryOp {
469                                op: BinaryOperator::Add,
470                                left: Box::new(left.clone()),
471                                right: Box::new(Expression::Variable(slack_name)),
472                            }),
473                            right: Box::new(right.clone()),
474                        };
475
476                        // Add squared penalty
477                        self.add_squared_penalty_to_qubo(qubo, &penalty_expr)?;
478                    }
479                    ComparisonOp::GreaterEqual => {
480                        // For a >= b, equivalent to b <= a
481                        let slack_name = format!("slack_{}", self.registry.num_vars);
482                        let _slack_idx = self
483                            .registry
484                            .register_variable(&slack_name, VariableDomain::Binary);
485
486                        // Convert to equality: (b + s - a)^2
487                        let penalty_expr = Expression::BinaryOp {
488                            op: BinaryOperator::Subtract,
489                            left: Box::new(Expression::BinaryOp {
490                                op: BinaryOperator::Add,
491                                left: Box::new(right.clone()),
492                                right: Box::new(Expression::Variable(slack_name)),
493                            }),
494                            right: Box::new(left.clone()),
495                        };
496
497                        // Add squared penalty
498                        self.add_squared_penalty_to_qubo(qubo, &penalty_expr)?;
499                    }
500                    _ => {
501                        return Err(CompileError {
502                            message: format!("Unsupported comparison operator: {op:?}"),
503                            context: "add_constraint_penalty".to_string(),
504                        });
505                    }
506                }
507            }
508            _ => {
509                return Err(CompileError {
510                    message: "Complex constraints not yet supported".to_string(),
511                    context: "add_constraint_penalty".to_string(),
512                });
513            }
514        }
515        Ok(())
516    }
517
518    fn add_squared_penalty_to_qubo(
519        &mut self,
520        qubo: &mut Array2<f64>,
521        expr: &Expression,
522    ) -> Result<(), CompileError> {
523        // For a squared penalty (expr)^2, we expand it and add to QUBO
524        // This is a simplified implementation for basic expressions
525        match expr {
526            Expression::Variable(name) => {
527                if let Some(&idx) = self.registry.var_indices.get(name) {
528                    qubo[[idx, idx]] += self.penalty_weight;
529                }
530            }
531            Expression::BinaryOp { op, left, right } => {
532                match op {
533                    BinaryOperator::Add => {
534                        // (a + b)^2 = a^2 + 2ab + b^2
535                        self.add_squared_penalty_to_qubo(qubo, left)?;
536                        self.add_squared_penalty_to_qubo(qubo, right)?;
537                        self.add_cross_term_penalty(qubo, left, right, 2.0)?;
538                    }
539                    BinaryOperator::Subtract => {
540                        // (a - b)^2 = a^2 - 2ab + b^2
541                        self.add_squared_penalty_to_qubo(qubo, left)?;
542                        self.add_squared_penalty_to_qubo(qubo, right)?;
543                        self.add_cross_term_penalty(qubo, left, right, -2.0)?;
544                    }
545                    _ => {
546                        return Err(CompileError {
547                            message: "Complex penalty expressions not yet supported".to_string(),
548                            context: "add_squared_penalty_to_qubo".to_string(),
549                        });
550                    }
551                }
552            }
553            _ => {
554                return Err(CompileError {
555                    message: "Unsupported penalty expression type".to_string(),
556                    context: "add_squared_penalty_to_qubo".to_string(),
557                });
558            }
559        }
560        Ok(())
561    }
562
563    fn add_cross_term_penalty(
564        &mut self,
565        qubo: &mut Array2<f64>,
566        left: &Expression,
567        right: &Expression,
568        coefficient: f64,
569    ) -> Result<(), CompileError> {
570        if let (Expression::Variable(v1), Expression::Variable(v2)) = (left, right) {
571            if let (Some(&idx1), Some(&idx2)) = (
572                self.registry.var_indices.get(v1),
573                self.registry.var_indices.get(v2),
574            ) {
575                let penalty = self.penalty_weight * coefficient;
576                qubo[[idx1, idx2]] += penalty / 2.0;
577                qubo[[idx2, idx1]] += penalty / 2.0;
578            }
579        }
580        Ok(())
581    }
582}
583
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem_dsl::types::VarType;

    /// Declares a binary variable with no domain and no attributes.
    fn binary_var(name: &str) -> Declaration {
        Declaration::Variable {
            name: name.to_string(),
            var_type: VarType::Binary,
            domain: None,
            attributes: HashMap::new(),
        }
    }

    /// Expression referring to the scalar variable `name`.
    fn var(name: &str) -> Expression {
        Expression::Variable(name.to_string())
    }

    /// Builds `left <op> right`.
    fn binop(op: BinaryOperator, left: Expression, right: Expression) -> Expression {
        Expression::BinaryOp {
            op,
            left: Box::new(left),
            right: Box::new(right),
        }
    }

    /// A program over the binary variables `x` and `y`, declared in that order.
    fn xy_program(objective: Objective, constraints: Vec<Constraint>) -> AST {
        AST::Program {
            declarations: vec![binary_var("x"), binary_var("y")],
            objective,
            constraints,
        }
    }

    /// Compiles `ast` with default options, panicking with `what` on failure.
    fn compile_ok(ast: &AST, what: &str) -> Array2<f64> {
        compile_to_qubo(ast, &CompilerOptions::default())
            .unwrap_or_else(|err| panic!("compilation should succeed for {what}: {err:?}"))
    }

    #[test]
    fn test_simple_binary_compilation() {
        // minimize x + y: each variable contributes a unit diagonal term.
        let ast = xy_program(
            Objective::Minimize(binop(BinaryOperator::Add, var("x"), var("y"))),
            vec![],
        );
        let qubo = compile_ok(&ast, "valid binary program");

        assert_eq!(qubo.shape(), &[2, 2]);
        assert_eq!(qubo[[0, 0]], 1.0); // x coefficient
        assert_eq!(qubo[[1, 1]], 1.0); // y coefficient
    }

    #[test]
    fn test_quadratic_term_compilation() {
        // minimize x * y
        let ast = xy_program(
            Objective::Minimize(binop(BinaryOperator::Multiply, var("x"), var("y"))),
            vec![],
        );
        let qubo = compile_ok(&ast, "quadratic term");

        assert_eq!(qubo.shape(), &[2, 2]);
        // The x*y coefficient is split evenly across both off-diagonal entries.
        assert_eq!(qubo[[0, 1]], 0.5);
        assert_eq!(qubo[[1, 0]], 0.5);
    }

    #[test]
    fn test_equality_constraint() {
        let x_equals_y = Constraint {
            name: None,
            expression: ConstraintExpression::Comparison {
                left: var("x"),
                op: ComparisonOp::Equal,
                right: var("y"),
            },
            tags: vec![],
        };
        let ast = xy_program(
            Objective::Minimize(Expression::Literal(Value::Number(0.0))),
            vec![x_equals_y],
        );
        let qubo = compile_ok(&ast, "equality constraint");

        assert_eq!(qubo.shape(), &[2, 2]);
        // Penalty 1000 * (x - y)^2 = 1000 * (x + y - 2xy) for binary x, y.
        assert_eq!(qubo[[0, 0]], 1000.0);
        assert_eq!(qubo[[1, 1]], 1000.0);
        assert_eq!(qubo[[0, 1]], -1000.0);
        assert_eq!(qubo[[1, 0]], -1000.0);
    }
}