1use super::ast::{
4 AggregationOp, BinaryOperator, ComparisonOp, Constraint, ConstraintExpression, Declaration,
5 Expression, Objective, ObjectiveType, Value, AST,
6};
7use super::error::CompileError;
8use scirs2_core::ndarray::Array2;
9use std::collections::HashMap;
10
/// Settings controlling how a program AST is compiled.
///
/// NOTE(review): the compiler in this module stores these options but does not
/// currently consult them; `compile_to_qubo` always produces a QUBO matrix.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CompilerOptions {
    /// Requested optimization effort.
    pub optimization_level: OptimizationLevel,
    /// Requested output model family.
    pub target: TargetBackend,
    /// Whether debug information was requested.
    pub debug_info: bool,
    /// Whether warnings should be treated as errors.
    pub warnings_as_errors: bool,
}

/// Optimization effort selected through [`CompilerOptions::optimization_level`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum OptimizationLevel {
    /// No optimization.
    None,
    /// Basic optimization (the default).
    Basic,
    /// Full optimization.
    Full,
}

/// Output model family selected through [`CompilerOptions::target`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum TargetBackend {
    /// Quadratic unconstrained binary optimization matrix (the default).
    QUBO,
    /// Ising model.
    Ising,
    /// Model with higher-than-quadratic interactions.
    HigherOrder,
}

impl Default for CompilerOptions {
    /// Basic optimization, QUBO target, no debug info, warnings not fatal.
    fn default() -> Self {
        Self {
            optimization_level: OptimizationLevel::Basic,
            target: TargetBackend::QUBO,
            debug_info: false,
            warnings_as_errors: false,
        }
    }
}
48
/// Hands out dense matrix indices to decision variables and records their domains.
#[derive(Debug, Clone)]
struct VariableRegistry {
    /// Index assigned to each scalar variable, keyed by name.
    var_indices: HashMap<String, usize>,
    /// Index assigned to each indexed variable, keyed by base name, then index tuple.
    indexed_var_indices: HashMap<String, HashMap<Vec<usize>, usize>>,
    /// How many indices have been handed out; also the next index to assign.
    num_vars: usize,
    /// Domain recorded for each registered variable name.
    domains: HashMap<String, VariableDomain>,
}

/// Value domain attached to a registered variable.
#[derive(Debug, Clone)]
enum VariableDomain {
    Binary,
    Integer { min: i32, max: i32 },
    Continuous { min: f64, max: f64 },
}

impl VariableRegistry {
    /// Builds a registry with no variables.
    fn new() -> Self {
        VariableRegistry {
            num_vars: 0,
            var_indices: HashMap::new(),
            indexed_var_indices: HashMap::new(),
            domains: HashMap::new(),
        }
    }

    /// Reserves the next free index and advances the counter.
    fn allocate_index(&mut self) -> usize {
        let fresh = self.num_vars;
        self.num_vars += 1;
        fresh
    }

    /// Returns the index of scalar variable `name`, assigning a new one (and
    /// recording `domain`) only on first registration.
    fn register_variable(&mut self, name: &str, domain: VariableDomain) -> usize {
        let known = self.var_indices.get(name).copied();
        known.unwrap_or_else(|| {
            let fresh = self.allocate_index();
            self.var_indices.insert(name.to_owned(), fresh);
            self.domains.insert(name.to_owned(), domain);
            fresh
        })
    }

    /// Returns the index of `base_name[indices]`, assigning a new one on first
    /// registration. The domain is stored under `"{base_name}_{index}"`, where
    /// `index` is the flat index just assigned.
    fn register_indexed_variable(
        &mut self,
        base_name: &str,
        indices: Vec<usize>,
        domain: VariableDomain,
    ) -> usize {
        let next = self.num_vars;
        let slot = self
            .indexed_var_indices
            .entry(base_name.to_owned())
            .or_default();
        if let Some(&existing) = slot.get(&indices) {
            return existing;
        }
        slot.insert(indices, next);
        self.domains.insert(format!("{}_{}", base_name, next), domain);
        self.num_vars += 1;
        next
    }
}
111
112pub fn compile_to_qubo(ast: &AST, options: &CompilerOptions) -> Result<Array2<f64>, CompileError> {
114 match ast {
115 AST::Program {
116 declarations,
117 objective,
118 constraints,
119 } => {
120 let mut compiler = Compiler::new(options.clone());
121
122 for decl in declarations {
124 compiler.process_declaration(decl)?;
125 }
126
127 let mut qubo = compiler.build_objective_qubo(objective)?;
129
130 for constraint in constraints {
132 compiler.add_constraint_penalty(&mut qubo, constraint)?;
133 }
134
135 Ok(qubo)
136 }
137 _ => Err(CompileError {
138 message: "Can only compile program AST nodes".to_string(),
139 context: "compile_to_qubo".to_string(),
140 }),
141 }
142}
143
144#[derive(Clone)]
146struct Compiler {
147 options: CompilerOptions,
148 registry: VariableRegistry,
149 parameters: HashMap<String, Value>,
150 penalty_weight: f64,
151}
152
153impl Compiler {
154 fn new(options: CompilerOptions) -> Self {
155 Self {
156 options,
157 registry: VariableRegistry::new(),
158 parameters: HashMap::new(),
159 penalty_weight: 1000.0, }
161 }
162
163 fn process_declaration(&mut self, decl: &Declaration) -> Result<(), CompileError> {
164 match decl {
165 Declaration::Variable {
166 name,
167 var_type: _,
168 domain: _,
169 attributes: _,
170 } => {
171 self.registry
173 .register_variable(name, VariableDomain::Binary);
174 Ok(())
175 }
176 Declaration::Parameter { name, value, .. } => {
177 self.parameters.insert(name.clone(), value.clone());
178 Ok(())
179 }
180 Declaration::Set { name, elements } => {
181 self.parameters
183 .insert(name.clone(), Value::Array(elements.clone()));
184 Ok(())
185 }
186 Declaration::Function {
187 name,
188 params: _,
189 body: _,
190 } => {
191 self.parameters.insert(
194 format!("func_{name}"),
195 Value::String(format!("function_{name}")),
196 );
197 Ok(())
198 }
199 }
200 }
201
202 fn build_objective_qubo(&mut self, objective: &Objective) -> Result<Array2<f64>, CompileError> {
203 let num_vars = self.registry.num_vars;
204 let mut qubo = Array2::zeros((num_vars, num_vars));
205
206 match objective {
207 Objective::Minimize(expr) => {
208 self.add_expression_to_qubo(&mut qubo, expr, 1.0)?;
209 }
210 Objective::Maximize(expr) => {
211 self.add_expression_to_qubo(&mut qubo, expr, -1.0)?;
212 }
213 Objective::MultiObjective { objectives } => {
214 for (obj_type, expr, weight) in objectives {
215 let sign = match obj_type {
216 ObjectiveType::Minimize => 1.0,
217 ObjectiveType::Maximize => -1.0,
218 };
219 self.add_expression_to_qubo(&mut qubo, expr, sign * weight)?;
220 }
221 }
222 }
223
224 Ok(qubo)
225 }
226
227 fn add_expression_to_qubo(
228 &mut self,
229 qubo: &mut Array2<f64>,
230 expr: &Expression,
231 coefficient: f64,
232 ) -> Result<(), CompileError> {
233 match expr {
234 Expression::Variable(name) => {
235 if let Some(&idx) = self.registry.var_indices.get(name) {
236 qubo[[idx, idx]] += coefficient;
237 } else {
238 return Err(CompileError {
239 message: format!("Unknown variable: {name}"),
240 context: "add_expression_to_qubo".to_string(),
241 });
242 }
243 }
244 Expression::BinaryOp { op, left, right } => {
245 match op {
246 BinaryOperator::Add => {
247 self.add_expression_to_qubo(qubo, left, coefficient)?;
248 self.add_expression_to_qubo(qubo, right, coefficient)?;
249 }
250 BinaryOperator::Subtract => {
251 self.add_expression_to_qubo(qubo, left, coefficient)?;
252 self.add_expression_to_qubo(qubo, right, -coefficient)?;
253 }
254 BinaryOperator::Multiply => {
255 if let (Expression::Variable(v1), Expression::Variable(v2)) =
257 (left.as_ref(), right.as_ref())
258 {
259 if let (Some(&idx1), Some(&idx2)) = (
260 self.registry.var_indices.get(v1),
261 self.registry.var_indices.get(v2),
262 ) {
263 if idx1 == idx2 {
264 qubo[[idx1, idx1]] += coefficient;
266 } else {
267 qubo[[idx1, idx2]] += coefficient / 2.0;
269 qubo[[idx2, idx1]] += coefficient / 2.0;
270 }
271 }
272 } else {
273 return Err(CompileError {
274 message: "Complex multiplication not yet supported".to_string(),
275 context: "add_expression_to_qubo".to_string(),
276 });
277 }
278 }
279 _ => {
280 return Err(CompileError {
281 message: format!("Unsupported binary operator: {op:?}"),
282 context: "add_expression_to_qubo".to_string(),
283 });
284 }
285 }
286 }
287 Expression::Literal(Value::Number(_)) => {
288 }
291 Expression::Aggregation {
292 op,
293 variables,
294 expression,
295 } => {
296 match op {
297 AggregationOp::Sum => {
298 for (var_name, set_name) in variables {
300 let elements = if let Some(Value::Array(elements)) =
302 self.parameters.get(set_name)
303 {
304 elements.clone()
305 } else {
306 return Err(CompileError {
307 message: format!("Unknown set for aggregation: {set_name}"),
308 context: "add_expression_to_qubo".to_string(),
309 });
310 };
311
312 for (i, element) in elements.iter().enumerate() {
314 let substituted_expr = {
316 let mut compiler = self.clone();
317 compiler.substitute_variable_in_expression(
318 expression, var_name, element, i,
319 )?
320 };
321 let mut qubo_mut = qubo.clone();
322 let mut compiler = self.clone();
323 compiler.add_expression_to_qubo(
324 &mut qubo_mut,
325 &substituted_expr,
326 coefficient,
327 )?;
328 *qubo = qubo_mut;
329 }
330 }
331 }
332 AggregationOp::Product => {
333 let mut product_expr = Expression::Literal(Value::Number(1.0));
335 for (var_name, set_name) in variables {
336 let elements = if let Some(Value::Array(elements)) =
338 self.parameters.get(set_name)
339 {
340 elements.clone()
341 } else {
342 continue; };
344
345 for (i, element) in elements.iter().enumerate() {
346 let substituted_expr = {
347 let mut compiler = self.clone();
348 compiler.substitute_variable_in_expression(
349 expression, var_name, element, i,
350 )?
351 };
352 product_expr = Expression::BinaryOp {
353 op: BinaryOperator::Multiply,
354 left: Box::new(product_expr),
355 right: Box::new(substituted_expr),
356 };
357 }
358 }
359 let mut qubo_mut = qubo.clone();
360 let mut compiler = self.clone();
361 compiler.add_expression_to_qubo(
362 &mut qubo_mut,
363 &product_expr,
364 coefficient,
365 )?;
366 *qubo = qubo_mut;
367 }
368 _ => {
369 return Err(CompileError {
370 message: format!("Unsupported aggregation operator: {op:?}"),
371 context: "add_expression_to_qubo".to_string(),
372 });
373 }
374 }
375 }
376 _ => {
377 return Err(CompileError {
378 message: "Expression type not yet supported".to_string(),
379 context: "add_expression_to_qubo".to_string(),
380 });
381 }
382 }
383 Ok(())
384 }
385
386 fn substitute_variable_in_expression(
387 &mut self,
388 expr: &Expression,
389 var_name: &str,
390 value: &Value,
391 index: usize,
392 ) -> Result<Expression, CompileError> {
393 match expr {
394 Expression::Variable(name) if name == var_name => {
395 match value {
397 Value::Number(_n) => {
398 let indexed_name = format!("{var_name}_{index}");
399 self.registry
400 .register_variable(&indexed_name, VariableDomain::Binary);
401 Ok(Expression::Variable(indexed_name))
402 }
403 _ => Ok(Expression::Literal(value.clone())),
404 }
405 }
406 Expression::Variable(name) => Ok(Expression::Variable(name.clone())),
407 Expression::BinaryOp { op, left, right } => {
408 let new_left =
409 self.substitute_variable_in_expression(left, var_name, value, index)?;
410 let new_right =
411 self.substitute_variable_in_expression(right, var_name, value, index)?;
412 Ok(Expression::BinaryOp {
413 op: op.clone(),
414 left: Box::new(new_left),
415 right: Box::new(new_right),
416 })
417 }
418 Expression::IndexedVar { name, indices } => {
419 let new_indices = indices
420 .iter()
421 .map(|idx| self.substitute_variable_in_expression(idx, var_name, value, index))
422 .collect::<Result<Vec<_>, _>>()?;
423 Ok(Expression::IndexedVar {
424 name: name.clone(),
425 indices: new_indices,
426 })
427 }
428 _ => Ok(expr.clone()),
429 }
430 }
431
432 fn add_constraint_penalty(
433 &mut self,
434 qubo: &mut Array2<f64>,
435 constraint: &Constraint,
436 ) -> Result<(), CompileError> {
437 match &constraint.expression {
438 ConstraintExpression::Comparison { left, op, right } => {
439 match op {
440 ComparisonOp::Equal => {
441 self.add_expression_to_qubo(qubo, left, self.penalty_weight)?;
444 self.add_expression_to_qubo(qubo, right, self.penalty_weight)?;
445
446 if let (Expression::Variable(v1), Expression::Variable(v2)) = (left, right)
448 {
449 if let (Some(&idx1), Some(&idx2)) = (
450 self.registry.var_indices.get(v1),
451 self.registry.var_indices.get(v2),
452 ) {
453 qubo[[idx1, idx2]] -= self.penalty_weight;
454 qubo[[idx2, idx1]] -= self.penalty_weight;
455 }
456 }
457 }
458 ComparisonOp::LessEqual => {
459 let slack_name = format!("slack_{}", self.registry.num_vars);
461 let _slack_idx = self
462 .registry
463 .register_variable(&slack_name, VariableDomain::Binary);
464
465 let penalty_expr = Expression::BinaryOp {
467 op: BinaryOperator::Subtract,
468 left: Box::new(Expression::BinaryOp {
469 op: BinaryOperator::Add,
470 left: Box::new(left.clone()),
471 right: Box::new(Expression::Variable(slack_name)),
472 }),
473 right: Box::new(right.clone()),
474 };
475
476 self.add_squared_penalty_to_qubo(qubo, &penalty_expr)?;
478 }
479 ComparisonOp::GreaterEqual => {
480 let slack_name = format!("slack_{}", self.registry.num_vars);
482 let _slack_idx = self
483 .registry
484 .register_variable(&slack_name, VariableDomain::Binary);
485
486 let penalty_expr = Expression::BinaryOp {
488 op: BinaryOperator::Subtract,
489 left: Box::new(Expression::BinaryOp {
490 op: BinaryOperator::Add,
491 left: Box::new(right.clone()),
492 right: Box::new(Expression::Variable(slack_name)),
493 }),
494 right: Box::new(left.clone()),
495 };
496
497 self.add_squared_penalty_to_qubo(qubo, &penalty_expr)?;
499 }
500 _ => {
501 return Err(CompileError {
502 message: format!("Unsupported comparison operator: {op:?}"),
503 context: "add_constraint_penalty".to_string(),
504 });
505 }
506 }
507 }
508 _ => {
509 return Err(CompileError {
510 message: "Complex constraints not yet supported".to_string(),
511 context: "add_constraint_penalty".to_string(),
512 });
513 }
514 }
515 Ok(())
516 }
517
518 fn add_squared_penalty_to_qubo(
519 &mut self,
520 qubo: &mut Array2<f64>,
521 expr: &Expression,
522 ) -> Result<(), CompileError> {
523 match expr {
526 Expression::Variable(name) => {
527 if let Some(&idx) = self.registry.var_indices.get(name) {
528 qubo[[idx, idx]] += self.penalty_weight;
529 }
530 }
531 Expression::BinaryOp { op, left, right } => {
532 match op {
533 BinaryOperator::Add => {
534 self.add_squared_penalty_to_qubo(qubo, left)?;
536 self.add_squared_penalty_to_qubo(qubo, right)?;
537 self.add_cross_term_penalty(qubo, left, right, 2.0)?;
538 }
539 BinaryOperator::Subtract => {
540 self.add_squared_penalty_to_qubo(qubo, left)?;
542 self.add_squared_penalty_to_qubo(qubo, right)?;
543 self.add_cross_term_penalty(qubo, left, right, -2.0)?;
544 }
545 _ => {
546 return Err(CompileError {
547 message: "Complex penalty expressions not yet supported".to_string(),
548 context: "add_squared_penalty_to_qubo".to_string(),
549 });
550 }
551 }
552 }
553 _ => {
554 return Err(CompileError {
555 message: "Unsupported penalty expression type".to_string(),
556 context: "add_squared_penalty_to_qubo".to_string(),
557 });
558 }
559 }
560 Ok(())
561 }
562
563 fn add_cross_term_penalty(
564 &mut self,
565 qubo: &mut Array2<f64>,
566 left: &Expression,
567 right: &Expression,
568 coefficient: f64,
569 ) -> Result<(), CompileError> {
570 if let (Expression::Variable(v1), Expression::Variable(v2)) = (left, right) {
571 if let (Some(&idx1), Some(&idx2)) = (
572 self.registry.var_indices.get(v1),
573 self.registry.var_indices.get(v2),
574 ) {
575 let penalty = self.penalty_weight * coefficient;
576 qubo[[idx1, idx2]] += penalty / 2.0;
577 qubo[[idx2, idx1]] += penalty / 2.0;
578 }
579 }
580 Ok(())
581 }
582}
583
#[cfg(test)]
mod tests {
    use super::*;
    use crate::problem_dsl::types::VarType;

    /// Declares a binary decision variable with no domain and no attributes.
    fn binary(name: &str) -> Declaration {
        Declaration::Variable {
            name: name.to_string(),
            var_type: VarType::Binary,
            domain: None,
            attributes: HashMap::new(),
        }
    }

    /// Shorthand for a variable reference expression.
    fn var(name: &str) -> Expression {
        Expression::Variable(name.to_string())
    }

    /// Compiles a program over binary variables `x` and `y` with default options.
    fn compile_xy(
        objective: Objective,
        constraints: Vec<Constraint>,
    ) -> Result<Array2<f64>, CompileError> {
        let program = AST::Program {
            declarations: vec![binary("x"), binary("y")],
            objective,
            constraints,
        };
        compile_to_qubo(&program, &CompilerOptions::default())
    }

    #[test]
    fn test_simple_binary_compilation() {
        let objective = Objective::Minimize(Expression::BinaryOp {
            op: BinaryOperator::Add,
            left: Box::new(var("x")),
            right: Box::new(var("y")),
        });

        let qubo = compile_xy(objective, vec![])
            .expect("compilation should succeed for valid binary program");

        assert_eq!(qubo.shape(), &[2, 2]);
        // Each linear term lands on its own diagonal entry.
        assert_eq!(qubo[[0, 0]], 1.0);
        assert_eq!(qubo[[1, 1]], 1.0);
    }

    #[test]
    fn test_quadratic_term_compilation() {
        let objective = Objective::Minimize(Expression::BinaryOp {
            op: BinaryOperator::Multiply,
            left: Box::new(var("x")),
            right: Box::new(var("y")),
        });

        let qubo = compile_xy(objective, vec![])
            .expect("compilation should succeed for quadratic term");

        assert_eq!(qubo.shape(), &[2, 2]);
        // The x*y coefficient is split evenly across the symmetric pair.
        assert_eq!(qubo[[0, 1]], 0.5);
        assert_eq!(qubo[[1, 0]], 0.5);
    }

    #[test]
    fn test_equality_constraint() {
        let x_equals_y = Constraint {
            name: None,
            expression: ConstraintExpression::Comparison {
                left: var("x"),
                op: ComparisonOp::Equal,
                right: var("y"),
            },
            tags: vec![],
        };

        let qubo = compile_xy(
            Objective::Minimize(Expression::Literal(Value::Number(0.0))),
            vec![x_equals_y],
        )
        .expect("compilation should succeed for equality constraint");

        assert_eq!(qubo.shape(), &[2, 2]);
        // 1000 * (x - y)^2 = 1000x + 1000y - 2000xy
        assert_eq!(qubo[[0, 0]], 1000.0);
        assert_eq!(qubo[[1, 1]], 1000.0);
        assert_eq!(qubo[[0, 1]], -1000.0);
        assert_eq!(qubo[[1, 0]], -1000.0);
    }
}