1#![allow(dead_code)]
8
9use scirs2_core::ndarray::Array;
10use std::collections::{HashMap, HashSet};
11
12#[cfg(feature = "scirs")]
13use crate::scirs_stub;
14
15#[cfg(feature = "dwave")]
16use quantrs2_symengine_pure::Expression as SymEngineExpression;
17
18#[cfg(feature = "dwave")]
19type Expr = SymEngineExpression;
20use thiserror::Error;
21
22use quantrs2_anneal::QuboError;
23
/// Expression-building helpers backed by `quantrs2_symengine_pure` (`dwave` feature).
#[cfg(feature = "dwave")]
pub mod expr {
    use quantrs2_symengine_pure::Expression as SymEngineExpression;

    /// Symbolic expression type used by the compiler in this configuration.
    pub type Expr = SymEngineExpression;

    /// Builds a numeric constant expression holding `value`.
    pub fn constant(value: f64) -> Expr {
        Expr::from(value)
    }

    /// Builds a symbolic variable called `name`.
    pub fn var(name: &str) -> Expr {
        Expr::symbol(name)
    }
}
39
40#[cfg(not(feature = "dwave"))]
41pub mod expr {
42 use super::SimpleExpr;
43
44 pub type Expr = SimpleExpr;
45
46 pub const fn constant(value: f64) -> Expr {
47 SimpleExpr::constant(value)
48 }
49
50 pub fn var(name: &str) -> Expr {
51 SimpleExpr::var(name)
52 }
53}
54
55#[derive(Error, Debug)]
57#[non_exhaustive]
58pub enum CompileError {
59 #[error("Invalid expression: {0}")]
61 InvalidExpression(String),
62
63 #[error("Term has degree {0}, but maximum supported is {1}")]
65 DegreeTooHigh(usize, usize),
66
67 #[error("QUBO error: {0}")]
69 QuboError(#[from] QuboError),
70
71 #[error("Symengine error: {0}")]
73 SymengineError(String),
74}
75
76pub type CompileResult<T> = Result<T, CompileError>;
78
/// Lightweight expression tree used when the `dwave` (SymEngine) backend is disabled.
///
/// Build nodes with [`SimpleExpr::var`], [`SimpleExpr::constant`], the `+` / `*`
/// operators, or by summing an iterator of expressions.
#[cfg(not(feature = "dwave"))]
#[derive(Debug, Clone, PartialEq)]
pub enum SimpleExpr {
    /// A named variable.
    Var(String),
    /// A numeric constant.
    Const(f64),
    /// Sum of two sub-expressions.
    Add(Box<Self>, Box<Self>),
    /// Product of two sub-expressions.
    Mul(Box<Self>, Box<Self>),
    /// A sub-expression raised to an integer power.
    Pow(Box<Self>, i32),
}

#[cfg(not(feature = "dwave"))]
impl SimpleExpr {
    /// Creates a variable node named `name`.
    pub fn var(name: &str) -> Self {
        Self::Var(name.to_string())
    }

    /// Creates a constant node holding `value`.
    pub const fn constant(value: f64) -> Self {
        Self::Const(value)
    }
}

#[cfg(not(feature = "dwave"))]
impl std::ops::Add for SimpleExpr {
    type Output = Self;

    /// Builds `self + rhs` as an `Add` node.
    fn add(self, rhs: Self) -> Self::Output {
        Self::Add(Box::new(self), Box::new(rhs))
    }
}

#[cfg(not(feature = "dwave"))]
impl std::ops::Mul for SimpleExpr {
    type Output = Self;

    /// Builds `self * rhs` as a `Mul` node.
    fn mul(self, rhs: Self) -> Self::Output {
        Self::Mul(Box::new(self), Box::new(rhs))
    }
}

#[cfg(not(feature = "dwave"))]
impl std::iter::Sum for SimpleExpr {
    /// Sums the expressions left to right; an empty iterator yields `Const(0.0)`.
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        // `reduce` avoids seeding the tree with a redundant `Const(0.0)` node, which
        // otherwise sits between sibling variables and hides their cross terms from
        // pattern-based expansions.
        iter.reduce(|acc, x| acc + x).unwrap_or(Self::Const(0.0))
    }
}
132
/// Constrained binary optimisation model built on the SymEngine backend.
#[cfg(feature = "dwave")]
#[derive(Debug, Clone)]
pub struct Model {
    /// Variable names registered through `add_variable`.
    variables: HashSet<String>,
    /// Objective to minimise; `None` is compiled as the constant 0.
    objective: Option<Expr>,
    /// Constraints that `compile` turns into penalty terms.
    constraints: Vec<Constraint>,
}

/// A constraint stored on the model (SymEngine backend).
#[cfg(feature = "dwave")]
#[derive(Debug, Clone)]
enum Constraint {
    /// Requires `expr == value`.
    Equality {
        name: String,
        expr: Expr,
        value: f64,
    },
    /// Requires `expr <= value`.
    LessEqual {
        name: String,
        expr: Expr,
        value: f64,
    },
    /// At most one of `variables` may take the value 1.
    AtMostOne { name: String, variables: Vec<Expr> },
    /// Whenever any of `conditions` is 1, `result` must be 1.
    ImpliesAny {
        name: String,
        conditions: Vec<Expr>,
        result: Expr,
    },
}

#[cfg(feature = "dwave")]
impl Default for Model {
    /// Equivalent to [`Model::new`].
    fn default() -> Self {
        Model::new()
    }
}
177
#[cfg(feature = "dwave")]
impl Model {
    /// Penalty weight that [`Model::compile`] applies to every constraint term.
    pub const DEFAULT_PENALTY_WEIGHT: f64 = 10.0;

    /// Creates an empty model with no variables, objective, or constraints.
    pub fn new() -> Self {
        Self {
            variables: HashSet::new(),
            objective: None,
            constraints: Vec::new(),
        }
    }

    /// Registers a binary variable called `name` and returns its symbol.
    ///
    /// Registering the same name again is harmless and returns a symbol with the same name.
    pub fn add_variable(&mut self, name: &str) -> CompileResult<Expr> {
        self.variables.insert(name.to_string());
        Ok(SymEngineExpression::symbol(name))
    }

    /// Sets (or replaces) the objective to minimise.
    pub fn set_objective(&mut self, expr: Expr) {
        self.objective = Some(expr);
    }

    /// Adds an exactly-one constraint: `sum(variables) == 1`.
    pub fn add_constraint_eq_one(&mut self, name: &str, variables: Vec<Expr>) -> CompileResult<()> {
        let sum_expr = variables
            .into_iter()
            .fold(Expr::from(0), |acc, v| acc + v);
        self.constraints.push(Constraint::Equality {
            name: name.to_string(),
            expr: sum_expr,
            value: 1.0,
        });
        Ok(())
    }

    /// Adds a constraint that at most one of `variables` is 1.
    pub fn add_constraint_at_most_one(
        &mut self,
        name: &str,
        variables: Vec<Expr>,
    ) -> CompileResult<()> {
        self.constraints.push(Constraint::AtMostOne {
            name: name.to_string(),
            variables,
        });
        Ok(())
    }

    /// Adds a constraint that `result` must be 1 whenever any of `conditions` is 1.
    pub fn add_constraint_implies_any(
        &mut self,
        name: &str,
        conditions: Vec<Expr>,
        result: Expr,
    ) -> CompileResult<()> {
        self.constraints.push(Constraint::ImpliesAny {
            name: name.to_string(),
            conditions,
            result,
        });
        Ok(())
    }

    /// Compiles the model using [`Model::DEFAULT_PENALTY_WEIGHT`] for all constraints.
    ///
    /// # Errors
    /// Propagates any error from [`Compile::get_qubo`].
    pub fn compile(&self) -> CompileResult<CompiledModel> {
        self.compile_with_penalty(Self::DEFAULT_PENALTY_WEIGHT)
    }

    /// Compiles objective + `penalty_weight` × constraint penalties into a QUBO.
    ///
    /// Penalties used:
    /// * equality: `P * (expr - value)^2`
    /// * at-most-one: `P * Σ_{i<j} v_i v_j`
    /// * implies-any: `P * (Σ conditions) * (1 - result)`
    ///
    /// # Errors
    /// [`CompileError::InvalidExpression`] if `penalty_weight` is negative or not finite;
    /// otherwise propagates any error from [`Compile::get_qubo`].
    pub fn compile_with_penalty(&self, penalty_weight: f64) -> CompileResult<CompiledModel> {
        if !penalty_weight.is_finite() || penalty_weight < 0.0 {
            return Err(CompileError::InvalidExpression(format!(
                "penalty weight must be finite and non-negative, got {penalty_weight}"
            )));
        }

        let mut final_expr = self.objective.clone().unwrap_or_else(|| Expr::from(0));

        for constraint in &self.constraints {
            match constraint {
                Constraint::Equality { expr, value, .. } => {
                    // Zero exactly when the equality holds, positive otherwise.
                    let diff = expr.clone() - Expr::from(*value);
                    final_expr = final_expr + Expr::from(penalty_weight) * diff.clone() * diff;
                }
                Constraint::LessEqual { expr, value, .. } => {
                    // NOTE(review): this squares the excess, so strictly satisfied
                    // `expr < value` is penalised too; it behaves like an equality.
                    // No method in this impl creates this variant — a faithful
                    // inequality would need slack variables.
                    let excess = expr.clone() - Expr::from(*value);
                    final_expr = final_expr + Expr::from(penalty_weight) * excess.clone() * excess;
                }
                Constraint::AtMostOne { variables, .. } => {
                    for (i, first) in variables.iter().enumerate() {
                        for second in &variables[i + 1..] {
                            final_expr = final_expr
                                + Expr::from(penalty_weight) * first.clone() * second.clone();
                        }
                    }
                }
                Constraint::ImpliesAny {
                    conditions, result, ..
                } => {
                    let conditions_sum = conditions
                        .iter()
                        .fold(Expr::from(0), |acc, c| acc + c.clone());
                    final_expr = final_expr
                        + Expr::from(penalty_weight)
                            * conditions_sum
                            * (Expr::from(1) - result.clone());
                }
            }
        }

        let compiler = Compile::new(final_expr);
        let ((qubo_matrix, var_map), offset) = compiler.get_qubo()?;

        Ok(CompiledModel {
            qubo_matrix,
            var_map,
            offset,
            constraints: self.constraints.clone(),
        })
    }
}
305
/// QUBO produced by [`Model::compile`] together with the constraints it came from.
#[cfg(feature = "dwave")]
#[derive(Debug, Clone)]
pub struct CompiledModel {
    /// QUBO coefficient matrix, indexed by the positions stored in `var_map`.
    pub qubo_matrix: Array<f64, scirs2_core::ndarray::Ix2>,
    /// Maps each variable name to its row/column in `qubo_matrix`.
    pub var_map: HashMap<String, usize>,
    /// Constant energy term.
    pub offset: f64,
    /// Source constraints, kept for violation counting.
    constraints: Vec<Constraint>,
}

#[cfg(feature = "dwave")]
impl CompiledModel {
    /// Builds an annealer `QuboModel` from the upper triangle (diagonal included)
    /// of `qubo_matrix`, skipping entries whose magnitude is at most `1e-10`.
    pub fn to_qubo(&self) -> quantrs2_anneal::ising::QuboModel {
        use quantrs2_anneal::ising::QuboModel;

        let rows = self.qubo_matrix.nrows();
        let cols = self.qubo_matrix.ncols();

        let mut model = QuboModel::new(self.var_map.len());
        model.offset = self.offset;

        for row in 0..rows {
            for col in row..cols {
                let weight = self.qubo_matrix[[row, col]];
                let significant = weight.abs() > 1e-10;
                if significant && row == col {
                    model
                        .set_linear(row, weight)
                        .expect("index within bounds from matrix dimensions");
                } else if significant {
                    model
                        .set_quadratic(row, col, weight)
                        .expect("indices within bounds from matrix dimensions");
                }
            }
        }

        model
    }

    /// Counts the stored constraints violated by `assignments` (`true` = 1, `false` = 0).
    ///
    /// A constraint whose expressions cannot be evaluated (e.g. a missing variable)
    /// is not counted as violated. Tolerance for numeric comparisons is `1e-6`.
    pub fn count_constraint_violations(&self, assignments: &HashMap<String, bool>) -> usize {
        let values: HashMap<String, f64> = assignments
            .iter()
            .map(|(name, &bit)| (name.clone(), f64::from(u8::from(bit))))
            .collect();

        self.constraints
            .iter()
            .filter(|constraint| Self::is_violated(constraint, &values))
            .count()
    }

    /// Number of constraints carried over from the source model.
    pub fn num_constraints(&self) -> usize {
        self.constraints.len()
    }

    /// Checks a single constraint against numeric variable values.
    fn is_violated(constraint: &Constraint, values: &HashMap<String, f64>) -> bool {
        match constraint {
            Constraint::Equality { expr, value, .. } => expr
                .eval(values)
                .map_or(false, |lhs| (lhs - value).abs() > 1e-6),
            Constraint::LessEqual { expr, value, .. } => expr
                .eval(values)
                .map_or(false, |lhs| lhs > value + 1e-6),
            Constraint::AtMostOne { variables, .. } => {
                let active = variables
                    .iter()
                    .filter_map(|v| v.eval(values).ok())
                    .filter(|&x| x > 0.5)
                    .count();
                active > 1
            }
            Constraint::ImpliesAny {
                conditions, result, ..
            } => {
                let triggered = conditions
                    .iter()
                    .any(|c| c.eval(values).map(|x| x > 0.5).unwrap_or(false));
                triggered && result.eval(values).map_or(false, |x| x < 0.5)
            }
        }
    }
}
413
414#[cfg(not(feature = "dwave"))]
416#[derive(Debug, Clone)]
417pub struct Model {
418 variables: HashSet<String>,
420 objective: Option<SimpleExpr>,
422 constraints: Vec<Constraint>,
424}
425
426#[cfg(not(feature = "dwave"))]
428#[derive(Debug, Clone)]
429enum Constraint {
430 Equality {
432 name: String,
433 expr: SimpleExpr,
434 value: f64,
435 },
436 AtMostOne {
438 name: String,
439 variables: Vec<SimpleExpr>,
440 },
441 ImpliesAny {
443 name: String,
444 conditions: Vec<SimpleExpr>,
445 result: SimpleExpr,
446 },
447}
448
449#[cfg(not(feature = "dwave"))]
450impl Default for Model {
451 fn default() -> Self {
452 Self::new()
453 }
454}
455
456#[cfg(not(feature = "dwave"))]
457impl Model {
458 pub fn new() -> Self {
460 Self {
461 variables: HashSet::new(),
462 objective: None,
463 constraints: Vec::new(),
464 }
465 }
466
467 pub fn add_variable(&mut self, name: &str) -> CompileResult<SimpleExpr> {
469 self.variables.insert(name.to_string());
470 Ok(SimpleExpr::var(name))
471 }
472
473 pub fn set_objective(&mut self, expr: SimpleExpr) {
475 self.objective = Some(expr);
476 }
477
478 pub fn add_constraint_eq_one(
480 &mut self,
481 name: &str,
482 variables: Vec<SimpleExpr>,
483 ) -> CompileResult<()> {
484 let sum_expr = variables.into_iter().sum();
485 self.constraints.push(Constraint::Equality {
486 name: name.to_string(),
487 expr: sum_expr,
488 value: 1.0,
489 });
490 Ok(())
491 }
492
493 pub fn add_constraint_at_most_one(
495 &mut self,
496 name: &str,
497 variables: Vec<SimpleExpr>,
498 ) -> CompileResult<()> {
499 self.constraints.push(Constraint::AtMostOne {
500 name: name.to_string(),
501 variables,
502 });
503 Ok(())
504 }
505
506 pub fn add_constraint_implies_any(
508 &mut self,
509 name: &str,
510 conditions: Vec<SimpleExpr>,
511 result: SimpleExpr,
512 ) -> CompileResult<()> {
513 self.constraints.push(Constraint::ImpliesAny {
514 name: name.to_string(),
515 conditions,
516 result,
517 });
518 Ok(())
519 }
520
521 pub fn compile(&self) -> CompileResult<CompiledModel> {
523 let mut qubo_terms: HashMap<(String, String), f64> = HashMap::new();
525 let mut offset = 0.0;
526 let penalty_weight = 10.0;
527
528 if let Some(ref obj) = self.objective {
530 self.add_expr_to_qubo(obj, 1.0, &mut qubo_terms, &mut offset)?;
531 }
532
533 for constraint in &self.constraints {
535 match constraint {
536 Constraint::Equality { expr, value, .. } => {
537 self.add_expr_squared_to_qubo(
540 expr,
541 penalty_weight,
542 &mut qubo_terms,
543 &mut offset,
544 )?;
545 self.add_expr_to_qubo(
546 expr,
547 -2.0 * penalty_weight * value,
548 &mut qubo_terms,
549 &mut offset,
550 )?;
551 offset += penalty_weight * value * value;
552 }
553 Constraint::AtMostOne { variables, .. } => {
554 for i in 0..variables.len() {
556 for j in (i + 1)..variables.len() {
557 if let (SimpleExpr::Var(vi), SimpleExpr::Var(vj)) =
558 (&variables[i], &variables[j])
559 {
560 let key = if vi < vj {
561 (vi.clone(), vj.clone())
562 } else {
563 (vj.clone(), vi.clone())
564 };
565 *qubo_terms.entry(key).or_insert(0.0) += penalty_weight;
566 }
567 }
568 }
569 }
570 Constraint::ImpliesAny {
571 conditions, result, ..
572 } => {
573 for cond in conditions {
575 if let (SimpleExpr::Var(c), SimpleExpr::Var(r)) = (cond, result) {
576 let key = if c < r {
577 (c.clone(), r.clone())
578 } else {
579 (r.clone(), c.clone())
580 };
581 *qubo_terms.entry(key).or_insert(0.0) -= penalty_weight;
582 }
583 if let SimpleExpr::Var(c) = cond {
585 *qubo_terms.entry((c.clone(), c.clone())).or_insert(0.0) +=
586 penalty_weight;
587 }
588 }
589 }
590 }
591 }
592
593 let all_vars: HashSet<String> = qubo_terms
595 .keys()
596 .flat_map(|(v1, v2)| vec![v1.clone(), v2.clone()])
597 .collect();
598 let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
599 sorted_vars.sort();
600
601 let var_map: HashMap<String, usize> = sorted_vars
602 .iter()
603 .enumerate()
604 .map(|(i, v)| (v.clone(), i))
605 .collect();
606
607 let n = var_map.len();
608 let mut matrix = Array::zeros((n, n));
609
610 for ((v1, v2), coeff) in qubo_terms {
611 let i = var_map[&v1];
612 let j = var_map[&v2];
613 if i == j {
614 matrix[[i, i]] += coeff;
615 } else {
616 matrix[[i, j]] += coeff / 2.0;
617 matrix[[j, i]] += coeff / 2.0;
618 }
619 }
620
621 Ok(CompiledModel {
622 qubo_matrix: matrix,
623 var_map,
624 offset,
625 constraints: self.constraints.clone(),
626 })
627 }
628
629 fn add_expr_to_qubo(
631 &self,
632 expr: &SimpleExpr,
633 coeff: f64,
634 terms: &mut HashMap<(String, String), f64>,
635 offset: &mut f64,
636 ) -> CompileResult<()> {
637 match expr {
638 SimpleExpr::Var(name) => {
639 *terms.entry((name.clone(), name.clone())).or_insert(0.0) += coeff;
640 }
641 SimpleExpr::Const(val) => {
642 *offset += coeff * val;
643 }
644 SimpleExpr::Add(left, right) => {
645 self.add_expr_to_qubo(left, coeff, terms, offset)?;
646 self.add_expr_to_qubo(right, coeff, terms, offset)?;
647 }
648 SimpleExpr::Mul(left, right) => {
649 if let (SimpleExpr::Var(v1), SimpleExpr::Var(v2)) = (left.as_ref(), right.as_ref())
650 {
651 let key = if v1 < v2 {
652 (v1.clone(), v2.clone())
653 } else {
654 (v2.clone(), v1.clone())
655 };
656 *terms.entry(key).or_insert(0.0) += coeff;
657 } else if let (SimpleExpr::Const(c), var) | (var, SimpleExpr::Const(c)) =
658 (left.as_ref(), right.as_ref())
659 {
660 self.add_expr_to_qubo(var, coeff * c, terms, offset)?;
661 }
662 }
663 SimpleExpr::Pow(base, exp) => {
664 if *exp == 2 && matches!(base.as_ref(), SimpleExpr::Var(_)) {
665 self.add_expr_to_qubo(base, coeff, terms, offset)?;
667 }
668 }
669 }
670 Ok(())
671 }
672
673 fn add_expr_squared_to_qubo(
675 &self,
676 expr: &SimpleExpr,
677 coeff: f64,
678 terms: &mut HashMap<(String, String), f64>,
679 offset: &mut f64,
680 ) -> CompileResult<()> {
681 match expr {
683 SimpleExpr::Var(name) => {
684 *terms.entry((name.clone(), name.clone())).or_insert(0.0) += coeff;
686 }
687 SimpleExpr::Add(left, right) => {
688 self.add_expr_squared_to_qubo(left, coeff, terms, offset)?;
690 self.add_expr_squared_to_qubo(right, coeff, terms, offset)?;
691 if let (SimpleExpr::Var(v1), SimpleExpr::Var(v2)) = (left.as_ref(), right.as_ref())
693 {
694 let key = if v1 < v2 {
695 (v1.clone(), v2.clone())
696 } else {
697 (v2.clone(), v1.clone())
698 };
699 *terms.entry(key).or_insert(0.0) += 2.0 * coeff;
700 }
701 }
702 _ => {}
703 }
704 Ok(())
705 }
706}
707
708#[cfg(not(feature = "dwave"))]
710#[derive(Debug, Clone)]
711pub struct CompiledModel {
712 pub qubo_matrix: Array<f64, scirs2_core::ndarray::Ix2>,
714 pub var_map: HashMap<String, usize>,
716 pub offset: f64,
718 constraints: Vec<Constraint>,
720}
721
722#[cfg(not(feature = "dwave"))]
723impl CompiledModel {
724 pub fn to_qubo(&self) -> quantrs2_anneal::ising::QuboModel {
726 use quantrs2_anneal::ising::QuboModel;
727
728 let mut qubo = QuboModel::new(self.var_map.len());
729
730 qubo.offset = self.offset;
732
733 for i in 0..self.qubo_matrix.nrows() {
735 for j in i..self.qubo_matrix.ncols() {
736 let value = self.qubo_matrix[[i, j]];
737 if value.abs() > 1e-10 {
738 if i == j {
739 qubo.set_linear(i, value)
742 .expect("index within bounds from matrix dimensions");
743 } else {
744 qubo.set_quadratic(i, j, value)
747 .expect("indices within bounds from matrix dimensions");
748 }
749 }
750 }
751 }
752
753 qubo
754 }
755}
756
/// Compiles a single symbolic expression over binary variables into QUBO/HOBO form.
#[cfg(feature = "dwave")]
pub struct Compile {
    /// Expression to compile.
    expr: Expr,
}

#[cfg(feature = "dwave")]
impl Compile {
    /// Wraps anything convertible into an [`Expr`].
    pub fn new<T: Into<Expr>>(expr: T) -> Self {
        Self { expr: expr.into() }
    }

    /// Returns `((matrix, var_map), offset)` for the expression.
    ///
    /// With the `scirs` feature the matrix is post-processed by
    /// `scirs_stub::enhance_qubo_matrix`.
    ///
    /// # Errors
    /// [`CompileError::DegreeTooHigh`] if a term has degree above 2 after
    /// binary reduction, or [`CompileError::InvalidExpression`] for unsupported terms.
    pub fn get_qubo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        #[cfg(feature = "scirs")]
        {
            self.get_qubo_scirs()
        }
        #[cfg(not(feature = "scirs"))]
        {
            self.get_qubo_standard()
        }
    }

    /// Expands, reduces powers via binary idempotence, extracts coefficients, and
    /// builds the upper-triangular QUBO matrix.
    fn get_qubo_standard(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let expr = self.expr.expand();
        let expr = replace_squared_terms(&expr)?;
        let (coeffs, offset) = extract_coefficients(&expr)?;

        let max_degree = coeffs.keys().map(|vars| vars.len()).max().unwrap_or(0);
        if max_degree > 2 {
            return Err(CompileError::DegreeTooHigh(max_degree, 2));
        }

        let (matrix, var_map) = build_qubo_matrix(&coeffs)?;
        Ok(((matrix, var_map), offset))
    }

    /// Standard QUBO followed by `scirs_stub::enhance_qubo_matrix`.
    #[cfg(feature = "scirs")]
    fn get_qubo_scirs(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let ((matrix, var_map), offset) = self.get_qubo_standard()?;
        let enhanced_matrix = crate::scirs_stub::enhance_qubo_matrix(&matrix);
        Ok(((enhanced_matrix, var_map), offset))
    }

    /// Returns `((tensor, var_map), offset)` for a higher-order expression.
    ///
    /// The tensor rank is the degree measured on the expanded expression before
    /// powers are reduced; lower-degree terms are padded by repeating their
    /// smallest variable index.
    ///
    /// # Errors
    /// Propagates errors from degree calculation, reduction, and coefficient
    /// extraction, or [`CompileError::DegreeTooHigh`] from tensor construction.
    pub fn get_hobo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::IxDyn>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let expanded = self.expr.expand();
        let max_degree = calc_highest_degree(&expanded)?;

        let reduced = replace_squared_terms(&expanded)?;
        // Re-expand so products rebuilt during reduction are in canonical form.
        let normalized = reduced.expand();

        let (coeffs, offset) = extract_coefficients(&normalized)?;
        let (tensor, var_map) = build_hobo_tensor(&coeffs, max_degree)?;

        Ok(((tensor, var_map), offset))
    }
}
895
/// Estimates the highest polynomial degree of `expr`, before any binary reduction.
///
/// * symbols count as degree 1, numbers as degree 0, and negation is transparent;
/// * a power of a symbol with a non-negative integer exponent has degree equal to
///   the exponent; any other power has `base_degree * exponent`, where a
///   non-integer, negative, or non-numeric exponent counts as 0;
/// * products sum the degrees of their factors; sums take the maximum over terms.
///
/// Expressions not recognised by these structural checks fall back to a heuristic
/// parse of their `Display` text, split on `+`/`-`.
///
/// # Errors
/// [`CompileError::InvalidExpression`] if a symbol's numeric exponent cannot be
/// converted to `f64`, or if the degree cannot be determined at all.
#[cfg(feature = "dwave")]
fn calc_highest_degree(expr: &Expr) -> CompileResult<usize> {
    if expr.is_symbol() {
        return Ok(1);
    }

    if expr.is_number() {
        return Ok(0);
    }

    if expr.is_neg() {
        let inner = expr.as_neg().expect("is_neg() was true");
        return calc_highest_degree(&inner);
    }

    if expr.is_pow() {
        let (base, exp) = expr.as_pow().expect("is_pow() was true");

        if base.is_symbol() && exp.is_number() {
            let exp_val = match exp.to_f64() {
                Some(n) => n,
                None => {
                    return Err(CompileError::InvalidExpression(
                        "Invalid exponent".to_string(),
                    ))
                }
            };

            // `x^k` for a symbol: the degree is the exponent itself.
            if exp_val.is_sign_positive() && exp_val.fract() == 0.0 {
                return Ok(exp_val as usize);
            }
        }

        // General power: degree(base) * exponent, where a non-integer, negative,
        // or non-numeric exponent contributes 0.
        let base_degree = calc_highest_degree(&base)?;
        let exp_degree = if exp.is_number() {
            match exp.to_f64() {
                Some(n) => {
                    if n.is_sign_positive() && n.fract() == 0.0 {
                        n as usize
                    } else {
                        0
                    }
                }
                None => 0,
            }
        } else {
            0
        };

        return Ok(base_degree * exp_degree);
    }

    if expr.is_mul() {
        let mut total_degree = 0;
        for factor in expr.as_mul().expect("is_mul() was true") {
            total_degree += calc_highest_degree(&factor)?;
        }
        return Ok(total_degree);
    }

    if expr.is_add() {
        let mut max_degree = 0;
        for term in expr.as_add().expect("is_add() was true") {
            let term_degree = calc_highest_degree(&term)?;
            max_degree = std::cmp::max(max_degree, term_degree);
        }
        return Ok(max_degree);
    }

    // Fallback for expression kinds not matched above: inspect the printed form.
    // NOTE(review): heuristic only. Splitting on '-' also breaks scientific
    // notation such as `1e-5`, and for terms containing `**` only the text after
    // the first `**` is parsed (a term like `x**2*y` is reported as degree 2).
    let expr_str = format!("{expr}");
    if expr_str.contains('+') || expr_str.contains('-') {
        let mut max_degree = 0;

        let parts: Vec<&str> = expr_str.split(['+', '-']).collect();

        for part in parts {
            let part = part.trim();
            if part.is_empty() {
                continue;
            }

            let degree = if part.contains("**") || part.contains('^') {
                // Power term: take the exponent text, defaulting to 2 if unparsable.
                let exp_str = part
                    .split("**")
                    .nth(1)
                    .or_else(|| part.split('^').nth(1))
                    .unwrap_or("2")
                    .trim();
                exp_str.parse::<usize>().unwrap_or(2)
            } else if part.contains('*') {
                // Product term: count the factors that are not numeric literals.
                let factors: Vec<&str> = part.split('*').collect();
                let mut var_count = 0;
                for factor in factors {
                    let factor = factor.trim();
                    if !factor.is_empty() && factor.parse::<f64>().is_err() {
                        var_count += 1;
                    }
                }
                var_count
            } else if part.parse::<f64>().is_err() && !part.is_empty() {
                // A lone non-numeric token is treated as a single variable.
                1
            } else {
                0
            };

            max_degree = std::cmp::max(max_degree, degree);
        }

        return Ok(max_degree);
    }

    Err(CompileError::InvalidExpression(format!(
        "Can't determine degree of: {expr}"
    )))
}
1039
/// Rewrites `expr` using binary-variable idempotence (`x^k == x`, `x * x == x`).
///
/// * `symbol^k` with a positive integer `k` becomes `symbol` (previously only
///   `k == 2` was reduced, so `x^3` later failed coefficient extraction);
/// * in a product, repeated symbol factors are kept only once;
/// * other nodes are rebuilt from their recursively rewritten children.
///
/// # Errors
/// [`CompileError::InvalidExpression`] if a symbol's numeric exponent cannot be
/// converted to `f64`; errors from children are propagated.
#[cfg(feature = "dwave")]
fn replace_squared_terms(expr: &Expr) -> CompileResult<Expr> {
    if expr.is_symbol() || expr.is_number() {
        return Ok(expr.clone());
    }

    if expr.is_neg() {
        let inner = expr.as_neg().expect("is_neg() was true");
        let new_inner = replace_squared_terms(&inner)?;
        return Ok(-new_inner);
    }

    if expr.is_pow() {
        let (base, exp) = expr.as_pow().expect("is_pow() was true");

        if base.is_symbol() && exp.is_number() {
            let exp_val = exp
                .to_f64()
                .ok_or_else(|| CompileError::InvalidExpression("Invalid exponent".to_string()))?;

            // For binary x, x^k == x for every positive integer k.
            if exp_val >= 1.0 && exp_val.fract() == 0.0 {
                return Ok(base);
            }
        }

        let new_base = replace_squared_terms(&base)?;
        return Ok(new_base.pow(&exp));
    }

    if expr.is_mul() {
        let mut seen_symbols: HashSet<String> = HashSet::new();
        let mut new_terms = Vec::new();
        for factor in expr.as_mul().expect("is_mul() was true") {
            let replaced = replace_squared_terms(&factor)?;
            // x * x == x for binary x: keep only the first occurrence of each symbol.
            if let Some(name) = replaced.as_symbol() {
                if !seen_symbols.insert(name.to_string()) {
                    continue;
                }
            }
            new_terms.push(replaced);
        }

        let mut terms = new_terms.into_iter();
        return Ok(match terms.next() {
            Some(first) => terms.fold(first, |acc, term| acc * term),
            None => Expr::from(1),
        });
    }

    if expr.is_add() {
        let mut new_terms = Vec::new();
        for term in expr.as_add().expect("is_add() was true") {
            new_terms.push(replace_squared_terms(&term)?);
        }

        let mut terms = new_terms.into_iter();
        return Ok(match terms.next() {
            Some(first) => terms.fold(first, |acc, term| acc + term),
            None => Expr::from(0),
        });
    }

    Ok(expr.clone())
}
1138
/// Splits an expanded, power-reduced expression into monomial coefficients.
///
/// Returns `(coeffs, offset)`, where `coeffs` maps a sorted variable list to its
/// coefficient and `offset` is the constant term.
///
/// Sums are walked term by term. Any other expression is decomposed structurally
/// first. Only if that fails and the printed form contains `+`/`-` does it fall back
/// to text parsing. Previously the text parser ran first for such expressions and
/// mis-split inputs like `1e-5*x`.
///
/// # Errors
/// Propagates [`CompileError::InvalidExpression`] from structural extraction when
/// no text fallback applies.
#[cfg(feature = "dwave")]
fn extract_coefficients(expr: &Expr) -> CompileResult<(HashMap<Vec<String>, f64>, f64)> {
    if expr.is_add() {
        let mut coeffs: HashMap<Vec<String>, f64> = HashMap::new();
        let mut offset = 0.0;
        for term in expr.as_add().expect("is_add() was true") {
            let (term_coeffs, term_offset) = extract_term_coefficients(&term)?;
            for (vars, coeff) in term_coeffs {
                *coeffs.entry(vars).or_insert(0.0) += coeff;
            }
            offset += term_offset;
        }
        return Ok((coeffs, offset));
    }

    match extract_term_coefficients(expr) {
        Ok(result) => Ok(result),
        Err(err) => {
            let expr_str = format!("{expr}");
            if expr_str.contains('+') || expr_str.contains('-') {
                Ok(parse_coefficients_from_str(&expr_str))
            } else {
                Err(err)
            }
        }
    }
}

/// Best-effort text parser for expressions that structural extraction rejected.
///
/// Splits on `+`/`-`; handles `coeff*var**k`, `a*b*...` products, bare numbers, and
/// bare variables. NOTE(review): heuristic only — scientific notation and
/// parenthesised sub-expressions are not understood.
#[cfg(feature = "dwave")]
fn parse_coefficients_from_str(expr_str: &str) -> (HashMap<Vec<String>, f64>, f64) {
    use regex::Regex;

    let mut coeffs: HashMap<Vec<String>, f64> = HashMap::new();
    let mut offset = 0.0;

    let re = Regex::new(r"([+-]?)([^+-]+)").expect("static regex pattern is valid");

    for caps in re.captures_iter(expr_str) {
        let sign = caps.get(1).map_or("", |m| m.as_str());
        let term = caps.get(2).map_or("", |m| m.as_str()).trim();

        if term.is_empty() {
            continue;
        }

        let sign_mult = if sign == "-" { -1.0 } else { 1.0 };

        if term.contains("**") || term.contains('^') {
            // Power of a (possibly scaled) variable: binary x^k == x, keep the base.
            let base = if term.contains("**") {
                term.split("**").next().unwrap_or(term)
            } else {
                term.split('^').next().unwrap_or(term)
            }
            .trim();

            let (coeff_mult, var_name) = if base.contains('*') {
                let parts: Vec<&str> = base.split('*').collect();
                if parts.len() == 2 {
                    if let Ok(num) = parts[0].trim().parse::<f64>() {
                        (num, parts[1].trim().to_string())
                    } else if let Ok(num) = parts[1].trim().parse::<f64>() {
                        (num, parts[0].trim().to_string())
                    } else {
                        (1.0, base.to_string())
                    }
                } else {
                    (1.0, base.to_string())
                }
            } else {
                (1.0, base.to_string())
            };

            *coeffs.entry(vec![var_name]).or_insert(0.0) += sign_mult * coeff_mult;
        } else if term.contains('*') {
            let mut coeff = sign_mult;
            let mut vars = Vec::new();

            for part in term.split('*') {
                let part = part.trim();
                if let Ok(num) = part.parse::<f64>() {
                    coeff *= num;
                } else {
                    vars.push(part.to_string());
                }
            }

            vars.sort();
            *coeffs.entry(vars).or_insert(0.0) += coeff;
        } else if let Ok(num) = term.parse::<f64>() {
            offset += sign_mult * num;
        } else {
            *coeffs.entry(vec![term.to_string()]).or_insert(0.0) += sign_mult;
        }
    }

    (coeffs, offset)
}
1256
/// Decomposes a single term into `(coeffs, offset)`.
///
/// Supports numbers, nested sums, negation, symbols, and products of numbers,
/// symbols, nested products, and `symbol^k` with `k` a positive integer (treated as
/// `symbol`). Repeated variables in a product collapse to one, because `x * x == x`
/// for binary variables; without this, a key like `[x, x, y]` would be reported
/// as degree 3.
///
/// # Errors
/// [`CompileError::InvalidExpression`] for non-convertible numbers, unsupported
/// powers, or any other unsupported term.
#[cfg(feature = "dwave")]
fn extract_term_coefficients(term: &Expr) -> CompileResult<(HashMap<Vec<String>, f64>, f64)> {
    let mut coeffs = HashMap::new();

    if term.is_number() {
        let value = term
            .to_f64()
            .ok_or_else(|| CompileError::InvalidExpression("Invalid number".to_string()))?;
        return Ok((coeffs, value));
    }

    if term.is_add() {
        let mut offset = 0.0;
        for sub_term in term.as_add().expect("is_add() was true") {
            let (sub_coeffs, sub_offset) = extract_term_coefficients(&sub_term)?;
            for (vars, coeff) in sub_coeffs {
                *coeffs.entry(vars).or_insert(0.0) += coeff;
            }
            offset += sub_offset;
        }
        return Ok((coeffs, offset));
    }

    if term.is_neg() {
        let inner = term.as_neg().expect("is_neg() was true");
        let (inner_coeffs, inner_offset) = extract_term_coefficients(&inner)?;
        for (vars, coeff) in inner_coeffs {
            coeffs.insert(vars, -coeff);
        }
        return Ok((coeffs, -inner_offset));
    }

    if term.is_symbol() {
        let var_name = term.as_symbol().expect("is_symbol() was true");
        coeffs.insert(vec![var_name.to_string()], 1.0);
        return Ok((coeffs, 0.0));
    }

    if term.is_mul() {
        let mut coeff = 1.0;
        let mut vars = Vec::new();

        let factors = term.as_mul().expect("is_mul() was true");
        // Explicit stack so nested products are flattened without recursion.
        let mut factor_stack: Vec<_> = factors.into_iter().collect();
        while let Some(factor) = factor_stack.pop() {
            if factor.is_number() {
                coeff *= factor.to_f64().ok_or_else(|| {
                    CompileError::InvalidExpression("Invalid number in product".to_string())
                })?;
            } else if factor.is_symbol() {
                let var_name = factor.as_symbol().expect("is_symbol() was true");
                vars.push(var_name.to_string());
            } else if factor.is_mul() {
                let sub_factors = factor.as_mul().expect("is_mul() was true");
                factor_stack.extend(sub_factors);
            } else if factor.is_pow() {
                let (base, exp) = factor.as_pow().expect("is_pow() was true");
                if base.is_symbol() && exp.is_number() {
                    let exp_val = exp.to_f64().unwrap_or(0.0);
                    if exp_val.is_sign_positive() && exp_val.fract() == 0.0 && exp_val >= 1.0 {
                        // Binary x^k == x.
                        let var_name = base.as_symbol().expect("is_symbol() was true");
                        vars.push(var_name.to_string());
                    } else {
                        return Err(CompileError::InvalidExpression(format!(
                            "Unsupported power in product: {factor}"
                        )));
                    }
                } else {
                    return Err(CompileError::InvalidExpression(format!(
                        "Unsupported power term in product: {factor}"
                    )));
                }
            } else {
                return Err(CompileError::InvalidExpression(format!(
                    "Unsupported term in product: {factor}"
                )));
            }
        }

        // Canonical key: sorted, and x * x == x for binary variables.
        vars.sort();
        vars.dedup();

        if vars.is_empty() {
            return Ok((coeffs, coeff));
        }
        coeffs.insert(vars, coeff);
        return Ok((coeffs, 0.0));
    }

    if term.is_pow() {
        return Err(CompileError::InvalidExpression(format!(
            "Unexpected power term after simplification: {term}"
        )));
    }

    Err(CompileError::InvalidExpression(format!(
        "Unsupported term: {term}"
    )))
}
1394
1395#[allow(dead_code)]
1397fn build_qubo_matrix(
1398 coeffs: &HashMap<Vec<String>, f64>,
1399) -> CompileResult<(
1400 Array<f64, scirs2_core::ndarray::Ix2>,
1401 HashMap<String, usize>,
1402)> {
1403 let mut all_vars = HashSet::new();
1405 for vars in coeffs.keys() {
1406 for var in vars {
1407 all_vars.insert(var.clone());
1408 }
1409 }
1410
1411 let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
1413 sorted_vars.sort();
1414
1415 let var_map: HashMap<String, usize> = sorted_vars
1417 .iter()
1418 .enumerate()
1419 .map(|(i, var)| (var.clone(), i))
1420 .collect();
1421
1422 let n = var_map.len();
1424
1425 let mut matrix = Array::zeros((n, n));
1427
1428 for (vars, &coeff) in coeffs {
1430 match vars.len() {
1431 0 => {
1432 }
1434 1 => {
1435 let i = *var_map
1438 .get(&vars[0])
1439 .expect("variable exists in var_map built from coeffs");
1440 matrix[[i, i]] += coeff;
1441 }
1442 2 => {
1443 let i = *var_map
1446 .get(&vars[0])
1447 .expect("variable exists in var_map built from coeffs");
1448 let j = *var_map
1449 .get(&vars[1])
1450 .expect("variable exists in var_map built from coeffs");
1451
1452 if i == j {
1454 matrix[[i, i]] += coeff;
1456 } else {
1457 if i <= j {
1459 matrix[[i, j]] += coeff;
1460 } else {
1461 matrix[[j, i]] += coeff;
1462 }
1463 }
1464 }
1465 _ => {
1466 return Err(CompileError::DegreeTooHigh(vars.len(), 2));
1468 }
1469 }
1470 }
1471
1472 Ok((matrix, var_map))
1473}
1474
1475#[allow(dead_code)]
1477fn build_hobo_tensor(
1478 coeffs: &HashMap<Vec<String>, f64>,
1479 max_degree: usize,
1480) -> CompileResult<(
1481 Array<f64, scirs2_core::ndarray::IxDyn>,
1482 HashMap<String, usize>,
1483)> {
1484 let mut all_vars = HashSet::new();
1486 for vars in coeffs.keys() {
1487 for var in vars {
1488 all_vars.insert(var.clone());
1489 }
1490 }
1491
1492 let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
1494 sorted_vars.sort();
1495
1496 let var_map: HashMap<String, usize> = sorted_vars
1498 .iter()
1499 .enumerate()
1500 .map(|(i, var)| (var.clone(), i))
1501 .collect();
1502
1503 let n = var_map.len();
1505
1506 let shape: Vec<usize> = vec![n; max_degree];
1508
1509 let mut tensor = Array::zeros(scirs2_core::ndarray::IxDyn(&shape));
1511
1512 for (vars, &coeff) in coeffs {
1514 let degree = vars.len();
1515
1516 if degree == 0 {
1517 continue;
1519 }
1520
1521 if degree > max_degree {
1522 return Err(CompileError::DegreeTooHigh(degree, max_degree));
1523 }
1524
1525 let mut indices: Vec<usize> = vars
1528 .iter()
1529 .map(|var| {
1530 *var_map
1531 .get(var)
1532 .expect("variable exists in var_map built from coeffs")
1533 })
1534 .collect();
1535
1536 indices.sort_unstable();
1538
1539 while indices.len() < max_degree {
1541 indices.insert(0, indices[0]); }
1543
1544 let idx = scirs2_core::ndarray::IxDyn(&indices);
1546 tensor[idx] += coeff;
1547 }
1548
1549 Ok((tensor, var_map))
1550}
1551
/// Thin wrapper that compiles an expression through [`Compile`].
///
/// NOTE(review): `verbose` is stored but not read by any method in this impl.
#[cfg(feature = "dwave")]
pub struct PieckCompile {
    /// Expression to compile.
    expr: Expr,
    /// Verbosity flag supplied at construction.
    verbose: bool,
}

#[cfg(feature = "dwave")]
impl PieckCompile {
    /// Wraps `expr` together with a verbosity flag.
    pub fn new<T: Into<Expr>>(expr: T, verbose: bool) -> Self {
        let expr = expr.into();
        Self { expr, verbose }
    }

    /// Delegates to [`Compile::get_qubo`] on a copy of the wrapped expression.
    pub fn get_qubo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let delegate = Compile::new(self.expr.clone());
        delegate.get_qubo()
    }
}