1#![allow(dead_code)]
8
9use scirs2_core::ndarray::Array;
10use std::collections::{HashMap, HashSet};
11
12#[cfg(feature = "scirs")]
13use crate::scirs_stub;
14
15#[cfg(feature = "dwave")]
16use quantrs2_symengine_pure::Expression as SymEngineExpression;
17
// Crate-local alias for the symbolic expression type used by the `dwave`
// compilation path (`Model`, `Compile`, and the coefficient helpers below).
#[cfg(feature = "dwave")]
type Expr = SymEngineExpression;
20use thiserror::Error;
21
22use quantrs2_anneal::QuboError;
23
/// Helpers for building symbolic expressions when the `dwave` feature is enabled.
#[cfg(feature = "dwave")]
pub mod expr {
    use quantrs2_symengine_pure::Expression as SymEngineExpression;

    /// Expression type returned by [`constant`] and [`var`].
    pub type Expr = SymEngineExpression;

    /// Returns a numeric constant expression holding `value`.
    pub fn constant(value: f64) -> Expr {
        Expr::from(value)
    }

    /// Returns a symbol expression named `name`.
    pub fn var(name: &str) -> Expr {
        Expr::symbol(name)
    }
}
39
40#[cfg(not(feature = "dwave"))]
41pub mod expr {
42 use super::SimpleExpr;
43
44 pub type Expr = SimpleExpr;
45
46 pub const fn constant(value: f64) -> Expr {
47 SimpleExpr::constant(value)
48 }
49
50 pub fn var(name: &str) -> Expr {
51 SimpleExpr::var(name)
52 }
53}
54
/// Errors produced while building or compiling a model.
#[derive(Error, Debug)]
pub enum CompileError {
    /// The expression has a shape the compiler cannot handle; the payload describes it.
    #[error("Invalid expression: {0}")]
    InvalidExpression(String),

    /// A term's degree (first field) exceeds the supported maximum (second field).
    #[error("Term has degree {0}, but maximum supported is {1}")]
    DegreeTooHigh(usize, usize),

    /// Error propagated from `quantrs2_anneal` (converted automatically via `From`).
    #[error("QUBO error: {0}")]
    QuboError(#[from] QuboError),

    /// Error reported by the symbolic expression backend.
    /// NOTE(review): not constructed anywhere in the visible code.
    #[error("Symengine error: {0}")]
    SymengineError(String),
}
74
/// Result alias used by every fallible operation in this module.
pub type CompileResult<T> = Result<T, CompileError>;
77
/// Minimal symbolic expression tree used when the `dwave` feature is disabled.
///
/// `+`, `*`, and `Sum` build nodes as-is; nothing is simplified at construction time.
#[cfg(not(feature = "dwave"))]
#[derive(Debug, Clone)]
pub enum SimpleExpr {
    /// A named variable.
    Var(String),
    /// A numeric constant.
    Const(f64),
    /// Sum of two sub-expressions.
    Add(Box<Self>, Box<Self>),
    /// Product of two sub-expressions.
    Mul(Box<Self>, Box<Self>),
    /// Base raised to an integer exponent.
    Pow(Box<Self>, i32),
}
93
94#[cfg(not(feature = "dwave"))]
95impl SimpleExpr {
96 pub fn var(name: &str) -> Self {
98 Self::Var(name.to_string())
99 }
100
101 pub const fn constant(value: f64) -> Self {
103 Self::Const(value)
104 }
105}
106
107#[cfg(not(feature = "dwave"))]
108impl std::ops::Add for SimpleExpr {
109 type Output = Self;
110
111 fn add(self, rhs: Self) -> Self::Output {
112 Self::Add(Box::new(self), Box::new(rhs))
113 }
114}
115
116#[cfg(not(feature = "dwave"))]
117impl std::ops::Mul for SimpleExpr {
118 type Output = Self;
119
120 fn mul(self, rhs: Self) -> Self::Output {
121 Self::Mul(Box::new(self), Box::new(rhs))
122 }
123}
124
125#[cfg(not(feature = "dwave"))]
126impl std::iter::Sum for SimpleExpr {
127 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
128 iter.fold(Self::Const(0.0), |acc, x| acc + x)
129 }
130}
131
/// Constrained binary optimisation model built from symbolic expressions (`dwave` build).
#[cfg(feature = "dwave")]
#[derive(Debug, Clone)]
pub struct Model {
    /// Names registered via `add_variable`.
    variables: HashSet<String>,
    /// Objective expression; `None` until `set_objective` is called.
    objective: Option<Expr>,
    /// Constraints, turned into penalty terms when the model is compiled.
    constraints: Vec<Constraint>,
}
143
/// A constraint stored on a `Model` until compilation (`dwave` build).
#[cfg(feature = "dwave")]
#[derive(Debug, Clone)]
enum Constraint {
    /// `expr == value`, penalised as `(expr - value)^2`.
    Equality {
        /// Caller-supplied label; not read during compilation.
        name: String,
        expr: Expr,
        value: f64,
    },
    /// `expr <= value`.
    /// NOTE(review): no `Model` method in the visible code constructs this variant.
    LessEqual {
        /// Caller-supplied label; not read during compilation.
        name: String,
        expr: Expr,
        value: f64,
    },
    /// At most one of `variables` may be 1.
    AtMostOne { name: String, variables: Vec<Expr> },
    /// If any of `conditions` is 1, then `result` must be 1.
    ImpliesAny {
        /// Caller-supplied label; not read during compilation.
        name: String,
        conditions: Vec<Expr>,
        result: Expr,
    },
}
169
#[cfg(feature = "dwave")]
impl Default for Model {
    /// Same as `Model::new`: an empty model.
    fn default() -> Self {
        Model::new()
    }
}
176
#[cfg(feature = "dwave")]
impl Model {
    /// Penalty weight applied to every constraint by [`Model::compile`].
    pub const DEFAULT_PENALTY_WEIGHT: f64 = 10.0;

    /// Creates an empty model with no variables, objective, or constraints.
    pub fn new() -> Self {
        Self {
            variables: HashSet::new(),
            objective: None,
            constraints: Vec::new(),
        }
    }

    /// Registers the variable `name` and returns a symbolic reference to it.
    ///
    /// # Errors
    /// Never fails in this implementation.
    pub fn add_variable(&mut self, name: &str) -> CompileResult<Expr> {
        self.variables.insert(name.to_string());
        Ok(SymEngineExpression::symbol(name))
    }

    /// Sets (or replaces) the objective; constraint penalties are added to it on compilation.
    pub fn set_objective(&mut self, expr: Expr) {
        self.objective = Some(expr);
    }

    /// Adds the constraint `sum(variables) == 1`.
    ///
    /// # Errors
    /// Never fails in this implementation.
    pub fn add_constraint_eq_one(&mut self, name: &str, variables: Vec<Expr>) -> CompileResult<()> {
        let sum_expr = variables
            .iter()
            .fold(Expr::from(0), |acc, v| acc + v.clone());
        self.constraints.push(Constraint::Equality {
            name: name.to_string(),
            expr: sum_expr,
            value: 1.0,
        });
        Ok(())
    }

    /// Adds the constraint that at most one of `variables` is 1.
    ///
    /// # Errors
    /// Never fails in this implementation.
    pub fn add_constraint_at_most_one(
        &mut self,
        name: &str,
        variables: Vec<Expr>,
    ) -> CompileResult<()> {
        self.constraints.push(Constraint::AtMostOne {
            name: name.to_string(),
            variables,
        });
        Ok(())
    }

    /// Adds the constraint "if any of `conditions` is 1, then `result` must be 1".
    ///
    /// # Errors
    /// Never fails in this implementation.
    pub fn add_constraint_implies_any(
        &mut self,
        name: &str,
        conditions: Vec<Expr>,
        result: Expr,
    ) -> CompileResult<()> {
        self.constraints.push(Constraint::ImpliesAny {
            name: name.to_string(),
            conditions,
            result,
        });
        Ok(())
    }

    /// Compiles the model using [`Model::DEFAULT_PENALTY_WEIGHT`].
    ///
    /// # Errors
    /// See [`Model::compile_with_penalty`].
    pub fn compile(&self) -> CompileResult<CompiledModel> {
        self.compile_with_penalty(Self::DEFAULT_PENALTY_WEIGHT)
    }

    /// Compiles objective plus constraint penalties (each scaled by `penalty_weight`) into a QUBO.
    ///
    /// Penalties added to the objective:
    /// * `Equality`: `(expr - value)^2`
    /// * `LessEqual`: `(expr - value)^2` (see the note in the match arm)
    /// * `AtMostOne`: `v_i * v_j` for every pair `i < j`
    /// * `ImpliesAny`: `sum(conditions) * (1 - result)`
    ///
    /// The weight must dominate the objective's coefficient scale, otherwise the
    /// minimum-energy state may violate constraints.
    ///
    /// # Errors
    /// Propagates any error from `Compile::get_qubo` (e.g. `DegreeTooHigh` for terms of
    /// degree > 2).
    pub fn compile_with_penalty(&self, penalty_weight: f64) -> CompileResult<CompiledModel> {
        let mut final_expr = self.objective.clone().unwrap_or_else(|| Expr::from(0));

        for constraint in &self.constraints {
            match constraint {
                Constraint::Equality { expr, value, .. } => {
                    let diff = expr.clone() - Expr::from(*value);
                    final_expr = final_expr + Expr::from(penalty_weight) * diff.clone() * diff;
                }
                Constraint::LessEqual { expr, value, .. } => {
                    // NOTE(review): this is the equality penalty; it also penalises
                    // `expr < value`. An exact `<=` encoding needs slack variables.
                    let excess = expr.clone() - Expr::from(*value);
                    final_expr = final_expr + Expr::from(penalty_weight) * excess.clone() * excess;
                }
                Constraint::AtMostOne { variables, .. } => {
                    for (i, first) in variables.iter().enumerate() {
                        for second in &variables[i + 1..] {
                            final_expr = final_expr
                                + Expr::from(penalty_weight) * first.clone() * second.clone();
                        }
                    }
                }
                Constraint::ImpliesAny {
                    conditions, result, ..
                } => {
                    let conditions_sum = conditions
                        .iter()
                        .fold(Expr::from(0), |acc, c| acc + c.clone());
                    final_expr = final_expr
                        + Expr::from(penalty_weight)
                            * conditions_sum
                            * (Expr::from(1) - result.clone());
                }
            }
        }

        let compiler = Compile::new(final_expr);
        let ((qubo_matrix, var_map), offset) = compiler.get_qubo()?;

        Ok(CompiledModel {
            qubo_matrix,
            var_map,
            offset,
            constraints: self.constraints.clone(),
        })
    }
}
304
/// Result of compiling a `Model` (`dwave` build): QUBO matrix, variable index map, and offset.
#[cfg(feature = "dwave")]
#[derive(Debug, Clone)]
pub struct CompiledModel {
    /// Coefficient matrix from `Compile::get_qubo`. Without the `scirs` feature,
    /// `build_qubo_matrix` stores each pair once in the upper triangle.
    /// NOTE(review): with `scirs` the matrix passes through `scirs_stub::enhance_qubo_matrix`.
    pub qubo_matrix: Array<f64, scirs2_core::ndarray::Ix2>,
    /// Maps each variable name to its row/column in `qubo_matrix`.
    pub var_map: HashMap<String, usize>,
    /// Constant energy term.
    pub offset: f64,
    /// Constraints copied from the source model, used for violation counting.
    constraints: Vec<Constraint>,
}
318
#[cfg(feature = "dwave")]
impl CompiledModel {
    /// Builds a `QuboModel` from the diagonal (linear terms) and the strict upper
    /// triangle (pair terms) of `qubo_matrix`.
    ///
    /// Entries with magnitude `<= 1e-10` are skipped.
    pub fn to_qubo(&self) -> quantrs2_anneal::ising::QuboModel {
        use quantrs2_anneal::ising::QuboModel;

        let mut model = QuboModel::new(self.var_map.len());
        model.offset = self.offset;

        for row in 0..self.qubo_matrix.nrows() {
            let nonzero = (row..self.qubo_matrix.ncols())
                .map(|col| (col, self.qubo_matrix[[row, col]]))
                .filter(|&(_, weight)| weight.abs() > 1e-10);
            for (col, weight) in nonzero {
                if col == row {
                    model
                        .set_linear(row, weight)
                        .expect("index within bounds from matrix dimensions");
                } else {
                    model
                        .set_quadratic(row, col, weight)
                        .expect("indices within bounds from matrix dimensions");
                }
            }
        }

        model
    }

    /// Counts how many stored constraints `assignments` violates (`true` = 1, `false` = 0).
    ///
    /// A constraint whose expression cannot be evaluated (e.g. a missing variable) is
    /// not counted as violated.
    pub fn count_constraint_violations(&self, assignments: &HashMap<String, bool>) -> usize {
        let values: HashMap<String, f64> = assignments
            .iter()
            .map(|(name, &bit)| (name.clone(), if bit { 1.0 } else { 0.0 }))
            .collect();

        self.constraints
            .iter()
            .filter(|constraint| Self::is_violated(constraint, &values))
            .count()
    }

    /// Decides whether a single constraint is violated under `values` (tolerance `1e-6`;
    /// a value above `0.5` counts as "set").
    fn is_violated(constraint: &Constraint, values: &HashMap<String, f64>) -> bool {
        match constraint {
            Constraint::Equality { expr, value, .. } => expr
                .eval(values)
                .is_ok_and(|actual| (actual - *value).abs() > 1e-6),
            Constraint::LessEqual { expr, value, .. } => expr
                .eval(values)
                .is_ok_and(|actual| actual > *value + 1e-6),
            Constraint::AtMostOne { variables, .. } => {
                let active = variables
                    .iter()
                    .filter_map(|v| v.eval(values).ok())
                    .filter(|&v| v > 0.5)
                    .count();
                active > 1
            }
            Constraint::ImpliesAny {
                conditions, result, ..
            } => {
                let triggered = conditions
                    .iter()
                    .any(|c| c.eval(values).is_ok_and(|v| v > 0.5));
                triggered && result.eval(values).is_ok_and(|v| v < 0.5)
            }
        }
    }

    /// Number of constraints carried over from the source model.
    pub fn num_constraints(&self) -> usize {
        self.constraints.len()
    }
}
412
/// Constrained binary optimisation model built from `SimpleExpr` (non-`dwave` build).
#[cfg(not(feature = "dwave"))]
#[derive(Debug, Clone)]
pub struct Model {
    /// Names registered via `add_variable`.
    variables: HashSet<String>,
    /// Objective expression; `None` until `set_objective` is called.
    objective: Option<SimpleExpr>,
    /// Constraints, turned into penalty terms when the model is compiled.
    constraints: Vec<Constraint>,
}
424
/// A constraint stored on a `Model` until compilation (non-`dwave` build).
#[cfg(not(feature = "dwave"))]
#[derive(Debug, Clone)]
enum Constraint {
    /// `expr == value`, penalised as `(expr - value)^2`.
    Equality {
        /// Caller-supplied label; not read during compilation.
        name: String,
        expr: SimpleExpr,
        value: f64,
    },
    /// At most one of `variables` may be 1.
    AtMostOne {
        /// Caller-supplied label; not read during compilation.
        name: String,
        variables: Vec<SimpleExpr>,
    },
    /// If any of `conditions` is 1, then `result` must be 1.
    ImpliesAny {
        /// Caller-supplied label; not read during compilation.
        name: String,
        conditions: Vec<SimpleExpr>,
        result: SimpleExpr,
    },
}
447
448#[cfg(not(feature = "dwave"))]
449impl Default for Model {
450 fn default() -> Self {
451 Self::new()
452 }
453}
454
455#[cfg(not(feature = "dwave"))]
456impl Model {
457 pub fn new() -> Self {
459 Self {
460 variables: HashSet::new(),
461 objective: None,
462 constraints: Vec::new(),
463 }
464 }
465
466 pub fn add_variable(&mut self, name: &str) -> CompileResult<SimpleExpr> {
468 self.variables.insert(name.to_string());
469 Ok(SimpleExpr::var(name))
470 }
471
472 pub fn set_objective(&mut self, expr: SimpleExpr) {
474 self.objective = Some(expr);
475 }
476
477 pub fn add_constraint_eq_one(
479 &mut self,
480 name: &str,
481 variables: Vec<SimpleExpr>,
482 ) -> CompileResult<()> {
483 let sum_expr = variables.into_iter().sum();
484 self.constraints.push(Constraint::Equality {
485 name: name.to_string(),
486 expr: sum_expr,
487 value: 1.0,
488 });
489 Ok(())
490 }
491
492 pub fn add_constraint_at_most_one(
494 &mut self,
495 name: &str,
496 variables: Vec<SimpleExpr>,
497 ) -> CompileResult<()> {
498 self.constraints.push(Constraint::AtMostOne {
499 name: name.to_string(),
500 variables,
501 });
502 Ok(())
503 }
504
505 pub fn add_constraint_implies_any(
507 &mut self,
508 name: &str,
509 conditions: Vec<SimpleExpr>,
510 result: SimpleExpr,
511 ) -> CompileResult<()> {
512 self.constraints.push(Constraint::ImpliesAny {
513 name: name.to_string(),
514 conditions,
515 result,
516 });
517 Ok(())
518 }
519
520 pub fn compile(&self) -> CompileResult<CompiledModel> {
522 let mut qubo_terms: HashMap<(String, String), f64> = HashMap::new();
524 let mut offset = 0.0;
525 let penalty_weight = 10.0;
526
527 if let Some(ref obj) = self.objective {
529 self.add_expr_to_qubo(obj, 1.0, &mut qubo_terms, &mut offset)?;
530 }
531
532 for constraint in &self.constraints {
534 match constraint {
535 Constraint::Equality { expr, value, .. } => {
536 self.add_expr_squared_to_qubo(
539 expr,
540 penalty_weight,
541 &mut qubo_terms,
542 &mut offset,
543 )?;
544 self.add_expr_to_qubo(
545 expr,
546 -2.0 * penalty_weight * value,
547 &mut qubo_terms,
548 &mut offset,
549 )?;
550 offset += penalty_weight * value * value;
551 }
552 Constraint::AtMostOne { variables, .. } => {
553 for i in 0..variables.len() {
555 for j in (i + 1)..variables.len() {
556 if let (SimpleExpr::Var(vi), SimpleExpr::Var(vj)) =
557 (&variables[i], &variables[j])
558 {
559 let key = if vi < vj {
560 (vi.clone(), vj.clone())
561 } else {
562 (vj.clone(), vi.clone())
563 };
564 *qubo_terms.entry(key).or_insert(0.0) += penalty_weight;
565 }
566 }
567 }
568 }
569 Constraint::ImpliesAny {
570 conditions, result, ..
571 } => {
572 for cond in conditions {
574 if let (SimpleExpr::Var(c), SimpleExpr::Var(r)) = (cond, result) {
575 let key = if c < r {
576 (c.clone(), r.clone())
577 } else {
578 (r.clone(), c.clone())
579 };
580 *qubo_terms.entry(key).or_insert(0.0) -= penalty_weight;
581 }
582 if let SimpleExpr::Var(c) = cond {
584 *qubo_terms.entry((c.clone(), c.clone())).or_insert(0.0) +=
585 penalty_weight;
586 }
587 }
588 }
589 }
590 }
591
592 let all_vars: HashSet<String> = qubo_terms
594 .keys()
595 .flat_map(|(v1, v2)| vec![v1.clone(), v2.clone()])
596 .collect();
597 let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
598 sorted_vars.sort();
599
600 let var_map: HashMap<String, usize> = sorted_vars
601 .iter()
602 .enumerate()
603 .map(|(i, v)| (v.clone(), i))
604 .collect();
605
606 let n = var_map.len();
607 let mut matrix = Array::zeros((n, n));
608
609 for ((v1, v2), coeff) in qubo_terms {
610 let i = var_map[&v1];
611 let j = var_map[&v2];
612 if i == j {
613 matrix[[i, i]] += coeff;
614 } else {
615 matrix[[i, j]] += coeff / 2.0;
616 matrix[[j, i]] += coeff / 2.0;
617 }
618 }
619
620 Ok(CompiledModel {
621 qubo_matrix: matrix,
622 var_map,
623 offset,
624 constraints: self.constraints.clone(),
625 })
626 }
627
628 fn add_expr_to_qubo(
630 &self,
631 expr: &SimpleExpr,
632 coeff: f64,
633 terms: &mut HashMap<(String, String), f64>,
634 offset: &mut f64,
635 ) -> CompileResult<()> {
636 match expr {
637 SimpleExpr::Var(name) => {
638 *terms.entry((name.clone(), name.clone())).or_insert(0.0) += coeff;
639 }
640 SimpleExpr::Const(val) => {
641 *offset += coeff * val;
642 }
643 SimpleExpr::Add(left, right) => {
644 self.add_expr_to_qubo(left, coeff, terms, offset)?;
645 self.add_expr_to_qubo(right, coeff, terms, offset)?;
646 }
647 SimpleExpr::Mul(left, right) => {
648 if let (SimpleExpr::Var(v1), SimpleExpr::Var(v2)) = (left.as_ref(), right.as_ref())
649 {
650 let key = if v1 < v2 {
651 (v1.clone(), v2.clone())
652 } else {
653 (v2.clone(), v1.clone())
654 };
655 *terms.entry(key).or_insert(0.0) += coeff;
656 } else if let (SimpleExpr::Const(c), var) | (var, SimpleExpr::Const(c)) =
657 (left.as_ref(), right.as_ref())
658 {
659 self.add_expr_to_qubo(var, coeff * c, terms, offset)?;
660 }
661 }
662 SimpleExpr::Pow(base, exp) => {
663 if *exp == 2 && matches!(base.as_ref(), SimpleExpr::Var(_)) {
664 self.add_expr_to_qubo(base, coeff, terms, offset)?;
666 }
667 }
668 }
669 Ok(())
670 }
671
672 fn add_expr_squared_to_qubo(
674 &self,
675 expr: &SimpleExpr,
676 coeff: f64,
677 terms: &mut HashMap<(String, String), f64>,
678 offset: &mut f64,
679 ) -> CompileResult<()> {
680 match expr {
682 SimpleExpr::Var(name) => {
683 *terms.entry((name.clone(), name.clone())).or_insert(0.0) += coeff;
685 }
686 SimpleExpr::Add(left, right) => {
687 self.add_expr_squared_to_qubo(left, coeff, terms, offset)?;
689 self.add_expr_squared_to_qubo(right, coeff, terms, offset)?;
690 if let (SimpleExpr::Var(v1), SimpleExpr::Var(v2)) = (left.as_ref(), right.as_ref())
692 {
693 let key = if v1 < v2 {
694 (v1.clone(), v2.clone())
695 } else {
696 (v2.clone(), v1.clone())
697 };
698 *terms.entry(key).or_insert(0.0) += 2.0 * coeff;
699 }
700 }
701 _ => {}
702 }
703 Ok(())
704 }
705}
706
/// Result of compiling a `Model` (non-`dwave` build): QUBO matrix, variable index map, and offset.
#[cfg(not(feature = "dwave"))]
#[derive(Debug, Clone)]
pub struct CompiledModel {
    /// Symmetric coefficient matrix. `Model::compile` puts linear terms on the diagonal
    /// and splits each pairwise coefficient evenly between `[i, j]` and `[j, i]`.
    pub qubo_matrix: Array<f64, scirs2_core::ndarray::Ix2>,
    /// Maps each variable name to its row/column in `qubo_matrix`.
    pub var_map: HashMap<String, usize>,
    /// Constant energy term.
    pub offset: f64,
    /// Constraints copied from the source model.
    constraints: Vec<Constraint>,
}
720
721#[cfg(not(feature = "dwave"))]
722impl CompiledModel {
723 pub fn to_qubo(&self) -> quantrs2_anneal::ising::QuboModel {
725 use quantrs2_anneal::ising::QuboModel;
726
727 let mut qubo = QuboModel::new(self.var_map.len());
728
729 qubo.offset = self.offset;
731
732 for i in 0..self.qubo_matrix.nrows() {
734 for j in i..self.qubo_matrix.ncols() {
735 let value = self.qubo_matrix[[i, j]];
736 if value.abs() > 1e-10 {
737 if i == j {
738 qubo.set_linear(i, value)
741 .expect("index within bounds from matrix dimensions");
742 } else {
743 qubo.set_quadratic(i, j, value)
746 .expect("indices within bounds from matrix dimensions");
747 }
748 }
749 }
750 }
751
752 qubo
753 }
754}
755
/// Compiles a single symbolic expression over binary variables into QUBO or HOBO form.
#[cfg(feature = "dwave")]
pub struct Compile {
    /// Expression to compile.
    expr: Expr,
}
765
#[cfg(feature = "dwave")]
impl Compile {
    /// Creates a compiler for `expr` (anything convertible into the symbolic `Expr`).
    pub fn new<T: Into<Expr>>(expr: T) -> Self {
        Self { expr: expr.into() }
    }

    /// Compiles the expression into a quadratic model.
    ///
    /// Returns `((matrix, var_map), offset)`. `var_map` assigns each variable
    /// (sorted by name) its matrix index, and `offset` is the constant term. With the
    /// `scirs` feature the matrix is additionally passed through
    /// `scirs_stub::enhance_qubo_matrix`.
    ///
    /// # Errors
    /// `CompileError::DegreeTooHigh` if a term has degree > 2, or
    /// `CompileError::InvalidExpression` for terms that cannot be decomposed.
    pub fn get_qubo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        #[cfg(feature = "scirs")]
        {
            self.get_qubo_scirs()
        }
        #[cfg(not(feature = "scirs"))]
        {
            self.get_qubo_standard()
        }
    }

    /// QUBO pipeline: expand, reduce powers of binary variables, extract coefficients,
    /// reject degree > 2, then build the matrix.
    fn get_qubo_standard(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        // Distribute products so every term is a monomial.
        let expr = self.expr.expand();

        // Binary variables satisfy x^k == x, which can lower the degree.
        let expr = replace_squared_terms(&expr)?;

        let (coeffs, offset) = extract_coefficients(&expr)?;

        let max_degree = coeffs.keys().map(Vec::len).max().unwrap_or(0);
        if max_degree > 2 {
            return Err(CompileError::DegreeTooHigh(max_degree, 2));
        }

        let (matrix, var_map) = build_qubo_matrix(&coeffs)?;

        Ok(((matrix, var_map), offset))
    }

    /// Standard pipeline followed by `scirs_stub::enhance_qubo_matrix`.
    #[cfg(feature = "scirs")]
    fn get_qubo_scirs(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let ((matrix, var_map), offset) = self.get_qubo_standard()?;

        let enhanced_matrix = crate::scirs_stub::enhance_qubo_matrix(&matrix);

        Ok(((enhanced_matrix, var_map), offset))
    }

    /// Compiles the expression into a higher-order (HOBO) coefficient tensor.
    ///
    /// The tensor rank equals the degree reported by `calc_highest_degree` on the
    /// expanded expression, measured before binary power reduction. Lower-degree terms
    /// are left-padded with their smallest index (see `build_hobo_tensor`).
    ///
    /// # Errors
    /// Propagates errors from degree calculation, power reduction, coefficient
    /// extraction, and tensor construction.
    pub fn get_hobo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::IxDyn>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let expanded = self.expr.expand();

        let max_degree = calc_highest_degree(&expanded)?;

        let reduced = replace_squared_terms(&expanded)?;
        // Re-expand: power reduction may leave products that can be redistributed.
        let reduced = reduced.expand();

        let (coeffs, offset) = extract_coefficients(&reduced)?;

        let (tensor, var_map) = build_hobo_tensor(&coeffs, max_degree)?;

        Ok(((tensor, var_map), offset))
    }
}
894
/// Estimates the highest polynomial degree of `expr`, counting each symbol as degree 1.
///
/// Used by `Compile::get_hobo` (on the already-expanded expression) to size the HOBO
/// tensor.
///
/// # Errors
/// `CompileError::InvalidExpression` if a symbol power's numeric exponent cannot be
/// converted to `f64`, or if no rule below applies to `expr`.
#[cfg(feature = "dwave")]
fn calc_highest_degree(expr: &Expr) -> CompileResult<usize> {
    // A bare symbol has degree 1.
    if expr.is_symbol() {
        return Ok(1);
    }

    // Numeric constants have degree 0.
    if expr.is_number() {
        return Ok(0);
    }

    // Negation does not change the degree.
    if expr.is_neg() {
        let inner = expr.as_neg().expect("is_neg() was true");
        return calc_highest_degree(&inner);
    }

    if expr.is_pow() {
        let (base, exp) = expr.as_pow().expect("is_pow() was true");

        // symbol^n with a non-negative integer n has degree n.
        if base.is_symbol() && exp.is_number() {
            let exp_val = match exp.to_f64() {
                Some(n) => n,
                None => {
                    return Err(CompileError::InvalidExpression(
                        "Invalid exponent".to_string(),
                    ))
                }
            };

            if exp_val.is_sign_positive() && exp_val.fract() == 0.0 {
                return Ok(exp_val as usize);
            }
        }

        // General case: degree(base) * exponent. A non-numeric, negative, or fractional
        // exponent counts as 0, so such a power reports degree 0.
        let base_degree = calc_highest_degree(&base)?;
        let exp_degree = if exp.is_number() {
            match exp.to_f64() {
                Some(n) => {
                    if n.is_sign_positive() && n.fract() == 0.0 {
                        n as usize
                    } else {
                        0
                    }
                }
                None => 0,
            }
        } else {
            0
        };

        return Ok(base_degree * exp_degree);
    }

    // The degree of a product is the sum of its factors' degrees.
    if expr.is_mul() {
        let mut total_degree = 0;
        for factor in expr.as_mul().expect("is_mul() was true") {
            total_degree += calc_highest_degree(&factor)?;
        }
        return Ok(total_degree);
    }

    // The degree of a sum is the maximum over its terms.
    if expr.is_add() {
        let mut max_degree = 0;
        for term in expr.as_add().expect("is_add() was true") {
            let term_degree = calc_highest_degree(&term)?;
            max_degree = std::cmp::max(max_degree, term_degree);
        }
        return Ok(max_degree);
    }

    // Fallback: heuristically parse the printed form, term by term.
    // NOTE(review): splitting on '-' also splits numbers like "1e-5" and negative exponents.
    let expr_str = format!("{expr}");
    if expr_str.contains('+') || expr_str.contains('-') {
        let mut max_degree = 0;

        let parts: Vec<&str> = expr_str.split(['+', '-']).collect();

        for part in parts {
            let part = part.trim();
            if part.is_empty() {
                continue;
            }

            let degree = if part.contains("**") || part.contains('^') {
                // Power: use the text after "**" (or '^') as the degree, defaulting to 2
                // when it does not parse as an integer.
                let exp_str = part
                    .split("**")
                    .nth(1)
                    .or_else(|| part.split('^').nth(1))
                    .unwrap_or("2")
                    .trim();
                exp_str.parse::<usize>().unwrap_or(2)
            } else if part.contains('*') {
                // Product: count factors that are not numeric literals.
                let factors: Vec<&str> = part.split('*').collect();
                let mut var_count = 0;
                for factor in factors {
                    let factor = factor.trim();
                    if !factor.is_empty() && factor.parse::<f64>().is_err() {
                        var_count += 1;
                    }
                }
                var_count
            } else if part.parse::<f64>().is_err() && !part.is_empty() {
                // A single non-numeric token is treated as a symbol.
                1
            } else {
                0
            };

            max_degree = std::cmp::max(max_degree, degree);
        }

        return Ok(max_degree);
    }

    Err(CompileError::InvalidExpression(format!(
        "Can't determine degree of: {expr}"
    )))
}
1038
/// Rewrites `expr` using the idempotence of binary variables.
///
/// * `x^k` becomes `x` for every symbol `x` and integer `k >= 1`. Previously only
///   `k == 2` was handled, so expanding e.g. `(x + y)^3` left a bare `x^3` that failed
///   extraction.
/// * Repeated symbol factors in a product are kept once (`x * y * x` becomes `x * y`).
///
/// Negations, sums, products, and other powers are rebuilt recursively; anything else
/// is returned unchanged.
///
/// # Errors
/// `CompileError::InvalidExpression` if a symbol power has a numeric exponent that
/// cannot be converted to `f64`.
#[cfg(feature = "dwave")]
fn replace_squared_terms(expr: &Expr) -> CompileResult<Expr> {
    if expr.is_symbol() || expr.is_number() {
        return Ok(expr.clone());
    }

    if expr.is_neg() {
        let inner = expr.as_neg().expect("is_neg() was true");
        let new_inner = replace_squared_terms(&inner)?;
        return Ok(-new_inner);
    }

    if expr.is_pow() {
        let (base, exp) = expr.as_pow().expect("is_pow() was true");

        if base.is_symbol() && exp.is_number() {
            let exp_val = exp
                .to_f64()
                .ok_or_else(|| CompileError::InvalidExpression("Invalid exponent".to_string()))?;

            // Binary variables: x^k == x for every integer k >= 1.
            if exp_val >= 1.0 && exp_val.fract() == 0.0 {
                return Ok(base);
            }
        }

        let new_base = replace_squared_terms(&base)?;
        return Ok(new_base.pow(&exp));
    }

    if expr.is_mul() {
        let mut seen_symbols: HashSet<String> = HashSet::new();
        let mut factors: Vec<Expr> = Vec::new();
        for factor in expr.as_mul().expect("is_mul() was true") {
            let reduced = replace_squared_terms(&factor)?;
            // x * x == x for binary variables: keep only the first occurrence.
            let is_repeat = reduced
                .as_symbol()
                .map_or(false, |name| !seen_symbols.insert(name.to_string()));
            if !is_repeat {
                factors.push(reduced);
            }
        }

        let mut remaining = factors.into_iter();
        return Ok(match remaining.next() {
            Some(first) => remaining.fold(first, |acc, factor| acc * factor),
            None => Expr::from(1),
        });
    }

    if expr.is_add() {
        let mut terms: Vec<Expr> = Vec::new();
        for term in expr.as_add().expect("is_add() was true") {
            terms.push(replace_squared_terms(&term)?);
        }

        let mut remaining = terms.into_iter();
        return Ok(match remaining.next() {
            Some(first) => remaining.fold(first, |acc, term| acc + term),
            None => Expr::from(0),
        });
    }

    Ok(expr.clone())
}
1137
/// Collects monomial coefficients and the constant offset of an expanded expression.
///
/// Returns `(coeffs, offset)`. `coeffs` maps each monomial, given as a sorted list of
/// variable names, to its coefficient. A sum is processed term by term with
/// `extract_term_coefficients`.
///
/// A non-sum expression is also decomposed structurally first. The string-based parser
/// is only a last resort, used when structural extraction fails and the printed form
/// contains `+`/`-`. Running the parser first, as before, mangled inputs such as
/// `-x*y**3` (producing a bogus variable `"x*y"`) and `1/2*x`.
///
/// # Errors
/// Propagates `extract_term_coefficients` errors for sums. For a non-sum, returns its
/// error when the string fallback does not apply.
#[cfg(feature = "dwave")]
fn extract_coefficients(expr: &Expr) -> CompileResult<(HashMap<Vec<String>, f64>, f64)> {
    if expr.is_add() {
        let mut coeffs = HashMap::new();
        let mut offset = 0.0;
        for term in expr.as_add().expect("is_add() was true") {
            let (term_coeffs, term_offset) = extract_term_coefficients(&term)?;
            for (vars, coeff) in term_coeffs {
                *coeffs.entry(vars).or_insert(0.0) += coeff;
            }
            offset += term_offset;
        }
        return Ok((coeffs, offset));
    }

    match extract_term_coefficients(expr) {
        Ok(extracted) => Ok(extracted),
        Err(err) => {
            let expr_str = format!("{expr}");
            if expr_str.contains('+') || expr_str.contains('-') {
                Ok(parse_coefficients_from_str(&expr_str))
            } else {
                Err(err)
            }
        }
    }
}

/// Heuristic, string-based coefficient parser, used only as a fallback by
/// `extract_coefficients`.
///
/// Splits on `+`/`-` signs. A power term `a**k` (or `a^k`) maps to the single variable
/// `a`, honouring one numeric factor. A product term multiplies its numeric factors and
/// collects the rest as (sorted) variables. Numbers go to the offset; any other token
/// is a variable with coefficient ±1.
#[cfg(feature = "dwave")]
fn parse_coefficients_from_str(expr_str: &str) -> (HashMap<Vec<String>, f64>, f64) {
    use regex::Regex;

    let mut coeffs: HashMap<Vec<String>, f64> = HashMap::new();
    let mut offset = 0.0;

    let re = Regex::new(r"([+-]?)([^+-]+)").expect("static regex pattern is valid");

    for caps in re.captures_iter(expr_str) {
        let sign = caps.get(1).map_or("", |m| m.as_str());
        let term = caps.get(2).map_or("", |m| m.as_str()).trim();

        if term.is_empty() {
            continue;
        }

        let sign_mult = if sign == "-" { -1.0 } else { 1.0 };

        if term.contains("**") || term.contains('^') {
            // Binary variables: x^k == x, so only the base matters.
            let base = if term.contains("**") {
                term.split("**").next().unwrap_or(term)
            } else {
                term.split('^').next().unwrap_or(term)
            }
            .trim();

            let (coeff_mult, var_name) = if base.contains('*') {
                let parts: Vec<&str> = base.split('*').collect();
                if parts.len() == 2 {
                    if let Ok(num) = parts[0].trim().parse::<f64>() {
                        (num, parts[1].trim().to_string())
                    } else if let Ok(num) = parts[1].trim().parse::<f64>() {
                        (num, parts[0].trim().to_string())
                    } else {
                        (1.0, base.to_string())
                    }
                } else {
                    (1.0, base.to_string())
                }
            } else {
                (1.0, base.to_string())
            };

            *coeffs.entry(vec![var_name]).or_insert(0.0) += sign_mult * coeff_mult;
        } else if term.contains('*') {
            let mut coeff = sign_mult;
            let mut vars = Vec::new();

            for part in term.split('*') {
                let part = part.trim();
                if let Ok(num) = part.parse::<f64>() {
                    coeff *= num;
                } else {
                    vars.push(part.to_string());
                }
            }

            vars.sort();
            *coeffs.entry(vars).or_insert(0.0) += coeff;
        } else if let Ok(num) = term.parse::<f64>() {
            offset += sign_mult * num;
        } else {
            *coeffs.entry(vec![term.to_string()]).or_insert(0.0) += sign_mult;
        }
    }

    (coeffs, offset)
}
1255
/// Decomposes a single term (or nested sum) into monomial coefficients plus a constant.
///
/// Returns `(coeffs, offset)`. `coeffs` maps each monomial, given as sorted,
/// duplicate-free variable names, to its coefficient. Binary-variable rules apply:
/// `x^k == x` for integer `k >= 1` and `x * x == x`. These rules hold both for bare
/// powers and for powers inside products; previously a bare `x^k` was rejected while
/// `2*x^k` was accepted. A negated factor inside a product flips the coefficient's sign.
///
/// # Errors
/// `CompileError::InvalidExpression` for unconvertible numbers, unsupported powers, or
/// any other term shape not listed above.
#[cfg(feature = "dwave")]
fn extract_term_coefficients(term: &Expr) -> CompileResult<(HashMap<Vec<String>, f64>, f64)> {
    let mut coeffs = HashMap::new();

    if term.is_number() {
        let value = term
            .to_f64()
            .ok_or_else(|| CompileError::InvalidExpression("Invalid number".to_string()))?;
        return Ok((coeffs, value));
    }

    if term.is_add() {
        let mut offset = 0.0;
        for sub_term in term.as_add().expect("is_add() was true") {
            let (sub_coeffs, sub_offset) = extract_term_coefficients(&sub_term)?;
            for (vars, coeff) in sub_coeffs {
                *coeffs.entry(vars).or_insert(0.0) += coeff;
            }
            offset += sub_offset;
        }
        return Ok((coeffs, offset));
    }

    if term.is_neg() {
        let inner = term.as_neg().expect("is_neg() was true");
        let (inner_coeffs, inner_offset) = extract_term_coefficients(&inner)?;
        coeffs.extend(
            inner_coeffs
                .into_iter()
                .map(|(vars, coeff)| (vars, -coeff)),
        );
        return Ok((coeffs, -inner_offset));
    }

    if term.is_symbol() {
        let var_name = term.as_symbol().expect("is_symbol() was true");
        coeffs.insert(vec![var_name.to_string()], 1.0);
        return Ok((coeffs, 0.0));
    }

    if term.is_mul() {
        let mut coeff = 1.0;
        let mut vars: Vec<String> = Vec::new();

        // Flatten nested products with an explicit stack.
        let mut factor_stack: Vec<_> = term
            .as_mul()
            .expect("is_mul() was true")
            .into_iter()
            .collect();
        while let Some(factor) = factor_stack.pop() {
            if factor.is_number() {
                let value = factor.to_f64().ok_or_else(|| {
                    CompileError::InvalidExpression("Invalid number in product".to_string())
                })?;
                coeff *= value;
            } else if factor.is_symbol() {
                let var_name = factor.as_symbol().expect("is_symbol() was true");
                vars.push(var_name.to_string());
            } else if factor.is_mul() {
                factor_stack.extend(factor.as_mul().expect("is_mul() was true"));
            } else if factor.is_neg() {
                // -(f) contributes a sign flip and the factor f.
                coeff = -coeff;
                factor_stack.push(factor.as_neg().expect("is_neg() was true").clone());
            } else if factor.is_pow() {
                let (base, exp) = factor.as_pow().expect("is_pow() was true");
                if base.is_symbol() && exp.is_number() {
                    let exp_val = exp.to_f64().unwrap_or(0.0);
                    if exp_val.is_sign_positive() && exp_val.fract() == 0.0 && exp_val >= 1.0 {
                        // x^k == x for binary variables.
                        let var_name = base.as_symbol().expect("is_symbol() was true");
                        vars.push(var_name.to_string());
                    } else {
                        return Err(CompileError::InvalidExpression(format!(
                            "Unsupported power in product: {factor}"
                        )));
                    }
                } else {
                    return Err(CompileError::InvalidExpression(format!(
                        "Unsupported power term in product: {factor}"
                    )));
                }
            } else {
                return Err(CompileError::InvalidExpression(format!(
                    "Unsupported term in product: {factor}"
                )));
            }
        }

        vars.sort();
        // x * x == x for binary variables.
        vars.dedup();

        if vars.is_empty() {
            return Ok((coeffs, coeff));
        }
        coeffs.insert(vars, coeff);

        return Ok((coeffs, 0.0));
    }

    if term.is_pow() {
        let (base, exp) = term.as_pow().expect("is_pow() was true");
        if base.is_symbol() && exp.is_number() {
            let exp_val = exp.to_f64().unwrap_or(0.0);
            if exp_val >= 1.0 && exp_val.fract() == 0.0 {
                // A bare x^k with integer k >= 1 is just x for binary variables.
                let var_name = base.as_symbol().expect("is_symbol() was true");
                coeffs.insert(vec![var_name.to_string()], 1.0);
                return Ok((coeffs, 0.0));
            }
        }
        return Err(CompileError::InvalidExpression(format!(
            "Unexpected power term after simplification: {term}"
        )));
    }

    Err(CompileError::InvalidExpression(format!(
        "Unsupported term: {term}"
    )))
}
1393
1394#[allow(dead_code)]
1396fn build_qubo_matrix(
1397 coeffs: &HashMap<Vec<String>, f64>,
1398) -> CompileResult<(
1399 Array<f64, scirs2_core::ndarray::Ix2>,
1400 HashMap<String, usize>,
1401)> {
1402 let mut all_vars = HashSet::new();
1404 for vars in coeffs.keys() {
1405 for var in vars {
1406 all_vars.insert(var.clone());
1407 }
1408 }
1409
1410 let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
1412 sorted_vars.sort();
1413
1414 let var_map: HashMap<String, usize> = sorted_vars
1416 .iter()
1417 .enumerate()
1418 .map(|(i, var)| (var.clone(), i))
1419 .collect();
1420
1421 let n = var_map.len();
1423
1424 let mut matrix = Array::zeros((n, n));
1426
1427 for (vars, &coeff) in coeffs {
1429 match vars.len() {
1430 0 => {
1431 }
1433 1 => {
1434 let i = *var_map
1437 .get(&vars[0])
1438 .expect("variable exists in var_map built from coeffs");
1439 matrix[[i, i]] += coeff;
1440 }
1441 2 => {
1442 let i = *var_map
1445 .get(&vars[0])
1446 .expect("variable exists in var_map built from coeffs");
1447 let j = *var_map
1448 .get(&vars[1])
1449 .expect("variable exists in var_map built from coeffs");
1450
1451 if i == j {
1453 matrix[[i, i]] += coeff;
1455 } else {
1456 if i <= j {
1458 matrix[[i, j]] += coeff;
1459 } else {
1460 matrix[[j, i]] += coeff;
1461 }
1462 }
1463 }
1464 _ => {
1465 return Err(CompileError::DegreeTooHigh(vars.len(), 2));
1467 }
1468 }
1469 }
1470
1471 Ok((matrix, var_map))
1472}
1473
1474#[allow(dead_code)]
1476fn build_hobo_tensor(
1477 coeffs: &HashMap<Vec<String>, f64>,
1478 max_degree: usize,
1479) -> CompileResult<(
1480 Array<f64, scirs2_core::ndarray::IxDyn>,
1481 HashMap<String, usize>,
1482)> {
1483 let mut all_vars = HashSet::new();
1485 for vars in coeffs.keys() {
1486 for var in vars {
1487 all_vars.insert(var.clone());
1488 }
1489 }
1490
1491 let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
1493 sorted_vars.sort();
1494
1495 let var_map: HashMap<String, usize> = sorted_vars
1497 .iter()
1498 .enumerate()
1499 .map(|(i, var)| (var.clone(), i))
1500 .collect();
1501
1502 let n = var_map.len();
1504
1505 let shape: Vec<usize> = vec![n; max_degree];
1507
1508 let mut tensor = Array::zeros(scirs2_core::ndarray::IxDyn(&shape));
1510
1511 for (vars, &coeff) in coeffs {
1513 let degree = vars.len();
1514
1515 if degree == 0 {
1516 continue;
1518 }
1519
1520 if degree > max_degree {
1521 return Err(CompileError::DegreeTooHigh(degree, max_degree));
1522 }
1523
1524 let mut indices: Vec<usize> = vars
1527 .iter()
1528 .map(|var| {
1529 *var_map
1530 .get(var)
1531 .expect("variable exists in var_map built from coeffs")
1532 })
1533 .collect();
1534
1535 indices.sort_unstable();
1537
1538 while indices.len() < max_degree {
1540 indices.insert(0, indices[0]); }
1542
1543 let idx = scirs2_core::ndarray::IxDyn(&indices);
1545 tensor[idx] += coeff;
1546 }
1547
1548 Ok((tensor, var_map))
1549}
1550
/// Compiler front-end with a verbosity flag; compilation is delegated to `Compile`.
#[cfg(feature = "dwave")]
pub struct PieckCompile {
    /// Expression to compile.
    expr: Expr,
    /// Verbosity flag.
    /// NOTE(review): not read anywhere in the visible code.
    verbose: bool,
}
1562
#[cfg(feature = "dwave")]
impl PieckCompile {
    /// Wraps `expr` (anything convertible into the symbolic `Expr`) with a verbosity flag.
    pub fn new<T: Into<Expr>>(expr: T, verbose: bool) -> Self {
        let expr: Expr = expr.into();
        Self { expr, verbose }
    }

    /// Compiles the wrapped expression to QUBO form by delegating to `Compile::get_qubo`.
    ///
    /// The `verbose` flag is not consulted.
    ///
    /// # Errors
    /// Whatever `Compile::get_qubo` returns.
    pub fn get_qubo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let delegate = Compile::new(self.expr.clone());
        delegate.get_qubo()
    }
}