1#![allow(dead_code)]
8
9use scirs2_core::ndarray::Array;
10use std::collections::{HashMap, HashSet};
11
12#[cfg(feature = "scirs")]
13use crate::scirs_stub;
14
15#[cfg(feature = "dwave")]
16use quantrs2_symengine::Expression as SymEngineExpression;
17
/// Symbolic expression type used by the SymEngine-backed (`dwave`) code paths in this module.
#[cfg(feature = "dwave")]
type Expr = quantrs2_symengine::Expression;
20use thiserror::Error;
21
22use quantrs2_anneal::QuboError;
23
#[cfg(feature = "dwave")]
pub mod expr {
    //! Small constructors for SymEngine-backed expressions.

    use quantrs2_symengine::Expression as SymEngineExpression;

    /// Expression type exposed to callers when the `dwave` feature is enabled.
    pub type Expr = SymEngineExpression;

    /// Builds a numeric constant expression from `value`.
    pub fn constant(value: f64) -> Expr {
        Expr::from_f64(value)
    }

    /// Builds a symbol expression named `name`.
    pub fn var(name: &str) -> Expr {
        Expr::symbol(name)
    }
}
39
40#[cfg(not(feature = "dwave"))]
41pub mod expr {
42 use super::SimpleExpr;
43
44 pub type Expr = SimpleExpr;
45
46 pub const fn constant(value: f64) -> Expr {
47 SimpleExpr::constant(value)
48 }
49
50 pub fn var(name: &str) -> Expr {
51 SimpleExpr::var(name)
52 }
53}
54
55#[derive(Error, Debug)]
57pub enum CompileError {
58 #[error("Invalid expression: {0}")]
60 InvalidExpression(String),
61
62 #[error("Term has degree {0}, but maximum supported is {1}")]
64 DegreeTooHigh(usize, usize),
65
66 #[error("QUBO error: {0}")]
68 QuboError(#[from] QuboError),
69
70 #[error("Symengine error: {0}")]
72 SymengineError(String),
73}
74
75pub type CompileResult<T> = Result<T, CompileError>;
77
/// Minimal expression tree used when the `dwave` feature (SymEngine) is disabled.
///
/// Values are normally built with `SimpleExpr::var`, `SimpleExpr::constant`, and the
/// `+` / `*` operators implemented for this type.
#[cfg(not(feature = "dwave"))]
#[derive(Debug, Clone)]
pub enum SimpleExpr {
    /// Reference to a named variable.
    Var(String),
    /// Floating-point constant.
    Const(f64),
    /// Sum of two sub-expressions.
    Add(Box<SimpleExpr>, Box<SimpleExpr>),
    /// Product of two sub-expressions.
    Mul(Box<SimpleExpr>, Box<SimpleExpr>),
    /// Sub-expression raised to an integer power.
    Pow(Box<SimpleExpr>, i32),
}
93
94#[cfg(not(feature = "dwave"))]
95impl SimpleExpr {
96 pub fn var(name: &str) -> Self {
98 Self::Var(name.to_string())
99 }
100
101 pub const fn constant(value: f64) -> Self {
103 Self::Const(value)
104 }
105}
106
107#[cfg(not(feature = "dwave"))]
108impl std::ops::Add for SimpleExpr {
109 type Output = Self;
110
111 fn add(self, rhs: Self) -> Self::Output {
112 Self::Add(Box::new(self), Box::new(rhs))
113 }
114}
115
116#[cfg(not(feature = "dwave"))]
117impl std::ops::Mul for SimpleExpr {
118 type Output = Self;
119
120 fn mul(self, rhs: Self) -> Self::Output {
121 Self::Mul(Box::new(self), Box::new(rhs))
122 }
123}
124
125#[cfg(not(feature = "dwave"))]
126impl std::iter::Sum for SimpleExpr {
127 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
128 iter.fold(Self::Const(0.0), |acc, x| acc + x)
129 }
130}
131
/// Constrained binary optimisation model backed by SymEngine expressions.
///
/// Stores the registered variable names, an optional objective, and the constraints
/// that are converted to penalty terms when the model is compiled.
#[cfg(feature = "dwave")]
#[derive(Debug, Clone)]
pub struct Model {
    /// Names registered through `add_variable`.
    variables: HashSet<String>,
    /// Objective to minimise; `None` compiles as the constant 0.
    objective: Option<Expr>,
    /// Constraints in insertion order.
    constraints: Vec<Constraint>,
}
143
/// A constraint recorded on a SymEngine-backed `Model`; each one becomes a penalty term
/// during compilation. The `name` fields are kept for identification only.
#[cfg(feature = "dwave")]
#[derive(Debug, Clone)]
enum Constraint {
    /// Requires `expr == value`.
    Equality { name: String, expr: Expr, value: f64 },
    /// Intended to express `expr <= value`.
    LessEqual { name: String, expr: Expr, value: f64 },
    /// At most one of `variables` may be 1.
    AtMostOne { name: String, variables: Vec<Expr> },
    /// If any of `conditions` is 1, then `result` must be 1.
    ImpliesAny {
        name: String,
        conditions: Vec<Expr>,
        result: Expr,
    },
}
169
/// The default model is empty, exactly as returned by `Model::new`.
#[cfg(feature = "dwave")]
impl Default for Model {
    fn default() -> Self {
        Model::new()
    }
}
176
#[cfg(feature = "dwave")]
impl Model {
    /// Penalty weight applied to every constraint by `Model::compile`.
    const DEFAULT_PENALTY_WEIGHT: f64 = 10.0;

    /// Creates an empty model with no variables, objective, or constraints.
    pub fn new() -> Self {
        Self {
            variables: HashSet::new(),
            objective: None,
            constraints: Vec::new(),
        }
    }

    /// Registers the variable `name` and returns a SymEngine symbol referring to it.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` keeps the API uniform.
    pub fn add_variable(&mut self, name: &str) -> CompileResult<Expr> {
        self.variables.insert(name.to_string());
        Ok(SymEngineExpression::symbol(name))
    }

    /// Sets the objective to minimise, replacing any previous objective.
    pub fn set_objective(&mut self, expr: Expr) {
        self.objective = Some(expr);
    }

    /// Adds the constraint `sum(variables) == 1`.
    ///
    /// # Errors
    ///
    /// Currently never fails.
    pub fn add_constraint_eq_one(&mut self, name: &str, variables: Vec<Expr>) -> CompileResult<()> {
        let sum_expr = variables
            .iter()
            .fold(Expr::from(0), |acc, v| acc + v.clone());
        self.constraints.push(Constraint::Equality {
            name: name.to_string(),
            expr: sum_expr,
            value: 1.0,
        });
        Ok(())
    }

    /// Adds the constraint "at most one of `variables` is 1".
    ///
    /// # Errors
    ///
    /// Currently never fails.
    pub fn add_constraint_at_most_one(
        &mut self,
        name: &str,
        variables: Vec<Expr>,
    ) -> CompileResult<()> {
        self.constraints.push(Constraint::AtMostOne {
            name: name.to_string(),
            variables,
        });
        Ok(())
    }

    /// Adds the constraint "if any of `conditions` is 1, then `result` is 1".
    ///
    /// # Errors
    ///
    /// Currently never fails.
    pub fn add_constraint_implies_any(
        &mut self,
        name: &str,
        conditions: Vec<Expr>,
        result: Expr,
    ) -> CompileResult<()> {
        self.constraints.push(Constraint::ImpliesAny {
            name: name.to_string(),
            conditions,
            result,
        });
        Ok(())
    }

    /// Compiles the model to a QUBO using the default penalty weight (10.0).
    ///
    /// # Errors
    ///
    /// Propagates any error from `Compile::get_qubo`, for example `DegreeTooHigh` when a
    /// penalised expression is of degree > 2.
    pub fn compile(&self) -> CompileResult<CompiledModel> {
        self.compile_with_penalty_weight(Self::DEFAULT_PENALTY_WEIGHT)
    }

    /// Compiles the model to a QUBO, multiplying every constraint penalty by
    /// `penalty_weight`.
    ///
    /// Larger weights make constraint violations more expensive relative to the
    /// objective. The fixed default used by `compile` may be too small when
    /// objective coefficients are large.
    ///
    /// Penalties added to the objective (`p` = `penalty_weight`):
    /// - `Equality`: `p * (expr - value)^2`
    /// - `LessEqual`: `p * (expr - value)^2` (see the note in the match arm)
    /// - `AtMostOne`: `p * sum_{i<j} v_i * v_j`
    /// - `ImpliesAny`: `p * sum(conditions) * (1 - result)`
    ///
    /// # Errors
    ///
    /// Propagates any error from `Compile::get_qubo`.
    pub fn compile_with_penalty_weight(&self, penalty_weight: f64) -> CompileResult<CompiledModel> {
        let mut final_expr = self.objective.clone().unwrap_or_else(|| Expr::from(0));

        for constraint in &self.constraints {
            match constraint {
                Constraint::Equality { expr, value, .. } => {
                    let diff = expr.clone() - Expr::from(*value);
                    final_expr = final_expr + Expr::from(penalty_weight) * diff.clone() * diff;
                }
                Constraint::LessEqual { expr, value, .. } => {
                    // NOTE(review): squaring the excess penalises `expr < value` as well as
                    // `expr > value`, so this encodes equality rather than `<=`. A faithful
                    // encoding needs slack variables. No method in this impl creates this
                    // variant.
                    let excess = expr.clone() - Expr::from(*value);
                    final_expr = final_expr + Expr::from(penalty_weight) * excess.clone() * excess;
                }
                Constraint::AtMostOne { variables, .. } => {
                    // One penalty per unordered pair; zero iff at most one variable is set.
                    for (i, first) in variables.iter().enumerate() {
                        for second in &variables[i + 1..] {
                            final_expr = final_expr
                                + Expr::from(penalty_weight) * first.clone() * second.clone();
                        }
                    }
                }
                Constraint::ImpliesAny {
                    conditions, result, ..
                } => {
                    let conditions_sum = conditions
                        .iter()
                        .fold(Expr::from(0), |acc, c| acc + c.clone());
                    final_expr = final_expr
                        + Expr::from(penalty_weight)
                            * conditions_sum
                            * (Expr::from(1) - result.clone());
                }
            }
        }

        let ((qubo_matrix, var_map), offset) = Compile::new(final_expr).get_qubo()?;

        Ok(CompiledModel {
            qubo_matrix,
            var_map,
            offset,
            constraints: self.constraints.clone(),
        })
    }
}
304
/// Result of compiling a SymEngine-backed `Model`: a QUBO coefficient matrix, the
/// mapping from variable names to matrix indices, and a constant energy offset.
#[cfg(feature = "dwave")]
#[derive(Debug, Clone)]
pub struct CompiledModel {
    /// QUBO coefficients, indexed by the positions stored in `var_map`.
    pub qubo_matrix: Array<f64, scirs2_core::ndarray::Ix2>,
    /// Variable name to row/column index in `qubo_matrix`.
    pub var_map: HashMap<String, usize>,
    /// Constant term of the energy function.
    pub offset: f64,
    /// Copy of the source model's constraints.
    constraints: Vec<Constraint>,
}
318
#[cfg(feature = "dwave")]
impl CompiledModel {
    /// Converts the compiled matrix into a `quantrs2_anneal` `QuboModel`.
    ///
    /// - Diagonal entries become linear coefficients.
    /// - For `i < j`, the coupling is `Q[i][j] + Q[j][i]`. For a matrix `Q`, the energy
    ///   `x^T Q x` weights the pair `x_i x_j` by the sum of both triangles. This is
    ///   correct for upper-triangular storage (lower triangle zero) and for symmetric
    ///   storage (each triangle holds half). Reading only the upper triangle would
    ///   silently halve the couplings of any symmetric matrix.
    /// - Entries whose magnitude is at most `1e-10` are skipped.
    /// - The constant `offset` is copied across.
    ///
    /// # Panics
    ///
    /// Panics (via `expect`) if `QuboModel` rejects an index, which would mean the
    /// matrix is larger than `var_map`.
    pub fn to_qubo(&self) -> quantrs2_anneal::ising::QuboModel {
        use quantrs2_anneal::ising::QuboModel;

        let mut qubo = QuboModel::new(self.var_map.len());
        qubo.offset = self.offset;

        // The matrix is square by construction; using the smaller dimension keeps the
        // transposed read below in bounds even if that ever stops holding.
        let n = self.qubo_matrix.nrows().min(self.qubo_matrix.ncols());
        for i in 0..n {
            let linear = self.qubo_matrix[[i, i]];
            if linear.abs() > 1e-10 {
                qubo.set_linear(i, linear)
                    .expect("index within bounds from matrix dimensions");
            }
            for j in (i + 1)..n {
                let coupling = self.qubo_matrix[[i, j]] + self.qubo_matrix[[j, i]];
                if coupling.abs() > 1e-10 {
                    qubo.set_quadratic(i, j, coupling)
                        .expect("indices within bounds from matrix dimensions");
                }
            }
        }

        qubo
    }
}
353
354#[cfg(not(feature = "dwave"))]
356#[derive(Debug, Clone)]
357pub struct Model {
358 variables: HashSet<String>,
360 objective: Option<SimpleExpr>,
362 constraints: Vec<Constraint>,
364}
365
366#[cfg(not(feature = "dwave"))]
368#[derive(Debug, Clone)]
369enum Constraint {
370 Equality {
372 name: String,
373 expr: SimpleExpr,
374 value: f64,
375 },
376 AtMostOne {
378 name: String,
379 variables: Vec<SimpleExpr>,
380 },
381 ImpliesAny {
383 name: String,
384 conditions: Vec<SimpleExpr>,
385 result: SimpleExpr,
386 },
387}
388
389#[cfg(not(feature = "dwave"))]
390impl Default for Model {
391 fn default() -> Self {
392 Self::new()
393 }
394}
395
396#[cfg(not(feature = "dwave"))]
397impl Model {
398 pub fn new() -> Self {
400 Self {
401 variables: HashSet::new(),
402 objective: None,
403 constraints: Vec::new(),
404 }
405 }
406
407 pub fn add_variable(&mut self, name: &str) -> CompileResult<SimpleExpr> {
409 self.variables.insert(name.to_string());
410 Ok(SimpleExpr::var(name))
411 }
412
413 pub fn set_objective(&mut self, expr: SimpleExpr) {
415 self.objective = Some(expr);
416 }
417
418 pub fn add_constraint_eq_one(
420 &mut self,
421 name: &str,
422 variables: Vec<SimpleExpr>,
423 ) -> CompileResult<()> {
424 let sum_expr = variables.into_iter().sum();
425 self.constraints.push(Constraint::Equality {
426 name: name.to_string(),
427 expr: sum_expr,
428 value: 1.0,
429 });
430 Ok(())
431 }
432
433 pub fn add_constraint_at_most_one(
435 &mut self,
436 name: &str,
437 variables: Vec<SimpleExpr>,
438 ) -> CompileResult<()> {
439 self.constraints.push(Constraint::AtMostOne {
440 name: name.to_string(),
441 variables,
442 });
443 Ok(())
444 }
445
446 pub fn add_constraint_implies_any(
448 &mut self,
449 name: &str,
450 conditions: Vec<SimpleExpr>,
451 result: SimpleExpr,
452 ) -> CompileResult<()> {
453 self.constraints.push(Constraint::ImpliesAny {
454 name: name.to_string(),
455 conditions,
456 result,
457 });
458 Ok(())
459 }
460
461 pub fn compile(&self) -> CompileResult<CompiledModel> {
463 let mut qubo_terms: HashMap<(String, String), f64> = HashMap::new();
465 let mut offset = 0.0;
466 let penalty_weight = 10.0;
467
468 if let Some(ref obj) = self.objective {
470 self.add_expr_to_qubo(obj, 1.0, &mut qubo_terms, &mut offset)?;
471 }
472
473 for constraint in &self.constraints {
475 match constraint {
476 Constraint::Equality { expr, value, .. } => {
477 self.add_expr_squared_to_qubo(
480 expr,
481 penalty_weight,
482 &mut qubo_terms,
483 &mut offset,
484 )?;
485 self.add_expr_to_qubo(
486 expr,
487 -2.0 * penalty_weight * value,
488 &mut qubo_terms,
489 &mut offset,
490 )?;
491 offset += penalty_weight * value * value;
492 }
493 Constraint::AtMostOne { variables, .. } => {
494 for i in 0..variables.len() {
496 for j in (i + 1)..variables.len() {
497 if let (SimpleExpr::Var(vi), SimpleExpr::Var(vj)) =
498 (&variables[i], &variables[j])
499 {
500 let key = if vi < vj {
501 (vi.clone(), vj.clone())
502 } else {
503 (vj.clone(), vi.clone())
504 };
505 *qubo_terms.entry(key).or_insert(0.0) += penalty_weight;
506 }
507 }
508 }
509 }
510 Constraint::ImpliesAny {
511 conditions, result, ..
512 } => {
513 for cond in conditions {
515 if let (SimpleExpr::Var(c), SimpleExpr::Var(r)) = (cond, result) {
516 let key = if c < r {
517 (c.clone(), r.clone())
518 } else {
519 (r.clone(), c.clone())
520 };
521 *qubo_terms.entry(key).or_insert(0.0) -= penalty_weight;
522 }
523 if let SimpleExpr::Var(c) = cond {
525 *qubo_terms.entry((c.clone(), c.clone())).or_insert(0.0) +=
526 penalty_weight;
527 }
528 }
529 }
530 }
531 }
532
533 let all_vars: HashSet<String> = qubo_terms
535 .keys()
536 .flat_map(|(v1, v2)| vec![v1.clone(), v2.clone()])
537 .collect();
538 let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
539 sorted_vars.sort();
540
541 let var_map: HashMap<String, usize> = sorted_vars
542 .iter()
543 .enumerate()
544 .map(|(i, v)| (v.clone(), i))
545 .collect();
546
547 let n = var_map.len();
548 let mut matrix = Array::zeros((n, n));
549
550 for ((v1, v2), coeff) in qubo_terms {
551 let i = var_map[&v1];
552 let j = var_map[&v2];
553 if i == j {
554 matrix[[i, i]] += coeff;
555 } else {
556 matrix[[i, j]] += coeff / 2.0;
557 matrix[[j, i]] += coeff / 2.0;
558 }
559 }
560
561 Ok(CompiledModel {
562 qubo_matrix: matrix,
563 var_map,
564 offset,
565 constraints: self.constraints.clone(),
566 })
567 }
568
569 fn add_expr_to_qubo(
571 &self,
572 expr: &SimpleExpr,
573 coeff: f64,
574 terms: &mut HashMap<(String, String), f64>,
575 offset: &mut f64,
576 ) -> CompileResult<()> {
577 match expr {
578 SimpleExpr::Var(name) => {
579 *terms.entry((name.clone(), name.clone())).or_insert(0.0) += coeff;
580 }
581 SimpleExpr::Const(val) => {
582 *offset += coeff * val;
583 }
584 SimpleExpr::Add(left, right) => {
585 self.add_expr_to_qubo(left, coeff, terms, offset)?;
586 self.add_expr_to_qubo(right, coeff, terms, offset)?;
587 }
588 SimpleExpr::Mul(left, right) => {
589 if let (SimpleExpr::Var(v1), SimpleExpr::Var(v2)) = (left.as_ref(), right.as_ref())
590 {
591 let key = if v1 < v2 {
592 (v1.clone(), v2.clone())
593 } else {
594 (v2.clone(), v1.clone())
595 };
596 *terms.entry(key).or_insert(0.0) += coeff;
597 } else if let (SimpleExpr::Const(c), var) | (var, SimpleExpr::Const(c)) =
598 (left.as_ref(), right.as_ref())
599 {
600 self.add_expr_to_qubo(var, coeff * c, terms, offset)?;
601 }
602 }
603 SimpleExpr::Pow(base, exp) => {
604 if *exp == 2 && matches!(base.as_ref(), SimpleExpr::Var(_)) {
605 self.add_expr_to_qubo(base, coeff, terms, offset)?;
607 }
608 }
609 }
610 Ok(())
611 }
612
613 fn add_expr_squared_to_qubo(
615 &self,
616 expr: &SimpleExpr,
617 coeff: f64,
618 terms: &mut HashMap<(String, String), f64>,
619 offset: &mut f64,
620 ) -> CompileResult<()> {
621 match expr {
623 SimpleExpr::Var(name) => {
624 *terms.entry((name.clone(), name.clone())).or_insert(0.0) += coeff;
626 }
627 SimpleExpr::Add(left, right) => {
628 self.add_expr_squared_to_qubo(left, coeff, terms, offset)?;
630 self.add_expr_squared_to_qubo(right, coeff, terms, offset)?;
631 if let (SimpleExpr::Var(v1), SimpleExpr::Var(v2)) = (left.as_ref(), right.as_ref())
633 {
634 let key = if v1 < v2 {
635 (v1.clone(), v2.clone())
636 } else {
637 (v2.clone(), v1.clone())
638 };
639 *terms.entry(key).or_insert(0.0) += 2.0 * coeff;
640 }
641 }
642 _ => {}
643 }
644 Ok(())
645 }
646}
647
648#[cfg(not(feature = "dwave"))]
650#[derive(Debug, Clone)]
651pub struct CompiledModel {
652 pub qubo_matrix: Array<f64, scirs2_core::ndarray::Ix2>,
654 pub var_map: HashMap<String, usize>,
656 pub offset: f64,
658 constraints: Vec<Constraint>,
660}
661
662#[cfg(not(feature = "dwave"))]
663impl CompiledModel {
664 pub fn to_qubo(&self) -> quantrs2_anneal::ising::QuboModel {
666 use quantrs2_anneal::ising::QuboModel;
667
668 let mut qubo = QuboModel::new(self.var_map.len());
669
670 qubo.offset = self.offset;
672
673 for i in 0..self.qubo_matrix.nrows() {
675 for j in i..self.qubo_matrix.ncols() {
676 let value = self.qubo_matrix[[i, j]];
677 if value.abs() > 1e-10 {
678 if i == j {
679 qubo.set_linear(i, value)
682 .expect("index within bounds from matrix dimensions");
683 } else {
684 qubo.set_quadratic(i, j, value)
687 .expect("indices within bounds from matrix dimensions");
688 }
689 }
690 }
691 }
692
693 qubo
694 }
695}
696
/// Compiler that turns a single symbolic expression into QUBO or HOBO coefficients.
#[cfg(feature = "dwave")]
pub struct Compile {
    /// Expression to compile.
    expr: Expr,
}
706
#[cfg(feature = "dwave")]
impl Compile {
    /// Wraps any value convertible into a symbolic expression.
    pub fn new<T: Into<Expr>>(expr: T) -> Self {
        Self { expr: expr.into() }
    }

    /// Compiles the expression to `((matrix, var_map), offset)`.
    ///
    /// With the `scirs` feature, the matrix is post-processed by
    /// `scirs_stub::enhance_qubo_matrix`.
    ///
    /// # Errors
    ///
    /// `DegreeTooHigh` when a term still has degree > 2 after binary reduction, or
    /// `InvalidExpression` when a term cannot be interpreted.
    pub fn get_qubo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        #[cfg(feature = "scirs")]
        {
            self.get_qubo_scirs()
        }
        #[cfg(not(feature = "scirs"))]
        {
            self.get_qubo_standard()
        }
    }

    /// Expand, reduce binary powers, check degree, then build the QUBO matrix.
    fn get_qubo_standard(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        // Reduce powers of binary variables *before* measuring degree. For example,
        // x^2 * y equals the quadratic x * y for binary x and must not be rejected
        // as cubic.
        let reduced = replace_squared_terms(&self.expr.expand())?.expand();

        let max_degree = calc_highest_degree(&reduced)?;
        if max_degree > 2 {
            return Err(CompileError::DegreeTooHigh(max_degree, 2));
        }

        let (coeffs, offset) = extract_coefficients(&reduced)?;
        let (matrix, var_map) = build_qubo_matrix(&coeffs)?;

        Ok(((matrix, var_map), offset))
    }

    /// Standard compilation followed by the `scirs` matrix enhancement.
    #[cfg(feature = "scirs")]
    fn get_qubo_scirs(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let ((matrix, var_map), offset) = self.get_qubo_standard()?;
        let enhanced_matrix = crate::scirs_stub::enhance_qubo_matrix(&matrix);
        Ok(((enhanced_matrix, var_map), offset))
    }

    /// Compiles the expression to a higher-order tensor `((tensor, var_map), offset)`.
    ///
    /// The tensor rank is the degree measured *before* binary power reduction.
    /// Lower-degree terms are padded into that rank by `build_hobo_tensor`.
    ///
    /// # Errors
    ///
    /// `InvalidExpression` when a term cannot be interpreted.
    pub fn get_hobo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::IxDyn>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let expanded = self.expr.expand();
        let max_degree = calc_highest_degree(&expanded)?;

        let reduced = replace_squared_terms(&expanded)?.expand();

        let (coeffs, offset) = extract_coefficients(&reduced)?;
        let (tensor, var_map) = build_hobo_tensor(&coeffs, max_degree)?;

        Ok(((tensor, var_map), offset))
    }
}
837
/// Estimates the highest polynomial degree among the terms of `expr`.
///
/// Degree rules:
/// - Symbols have degree 1; numbers have degree 0.
/// - A power whose base is a symbol and whose exponent is a numeric, sign-positive
///   integer `k` has degree `k`.
/// - Any other power has degree `degree(base) * k`. Here an exponent that is not a
///   sign-positive integer, or is not a number at all, counts as 0.
/// - A product sums the degrees of its factors.
/// - A sum takes the maximum over its terms.
///
/// If none of those structural checks match, the degree is estimated from the
/// expression's `Display` string.
///
/// # Errors
///
/// Returns `CompileError::InvalidExpression` if:
/// - a symbol's numeric exponent cannot be converted to `f64`, or
/// - the expression matches none of the handled shapes and its string form contains
///   neither `+` nor `-`.
///
/// Errors from recursive calls are propagated.
#[cfg(feature = "dwave")]
fn calc_highest_degree(expr: &Expr) -> CompileResult<usize> {
    // Leaf cases.
    if expr.is_symbol() {
        return Ok(1);
    }

    if expr.is_number() {
        return Ok(0);
    }

    if expr.is_pow() {
        let (base, exp) = expr.as_pow().expect("is_pow() was true");

        // Fast path: `symbol^k` with an integral, sign-positive numeric exponent has degree k.
        if base.is_symbol() && exp.is_number() {
            let exp_val = match exp.to_f64() {
                Some(n) => n,
                None => {
                    return Err(CompileError::InvalidExpression(
                        "Invalid exponent".to_string(),
                    ))
                }
            };

            if exp_val.is_sign_positive() && exp_val.fract() == 0.0 {
                return Ok(exp_val as usize);
            }
        }

        // General power: degree(base) * exponent. Non-integral, negative, or
        // non-numeric exponents count as 0, so such powers report degree 0 here.
        let base_degree = calc_highest_degree(&base)?;
        let exp_degree = if exp.is_number() {
            match exp.to_f64() {
                Some(n) => {
                    if n.is_sign_positive() && n.fract() == 0.0 {
                        n as usize
                    } else {
                        0
                    }
                }
                None => 0,
            }
        } else {
            0
        };

        return Ok(base_degree * exp_degree);
    }

    // A product's degree is the sum of its factors' degrees.
    if expr.is_mul() {
        let mut total_degree = 0;
        for factor in expr.as_mul().expect("is_mul() was true") {
            total_degree += calc_highest_degree(&factor)?;
        }
        return Ok(total_degree);
    }

    // A sum's degree is the maximum degree over its terms.
    if expr.is_add() {
        let mut max_degree = 0;
        for term in expr.as_add().expect("is_add() was true") {
            let term_degree = calc_highest_degree(&term)?;
            max_degree = std::cmp::max(max_degree, term_degree);
        }
        return Ok(max_degree);
    }

    // Fallback: no structural predicate matched, so estimate the degree textually by
    // splitting the printed form on `+`/`-` and inspecting each part.
    // NOTE(review): this is a heuristic.
    // - Splitting on '-' also splits negative exponents and scientific-notation
    //   literals (e.g. `1e-05`).
    // - A part containing `**`/`^` reports only its exponent, defaulting to 2 when
    //   that cannot be parsed, and ignores other factors in the same part.
    let expr_str = format!("{expr}");
    if expr_str.contains('+') || expr_str.contains('-') {
        let mut max_degree = 0;

        let parts: Vec<&str> = expr_str.split(['+', '-']).collect();

        for part in parts {
            let part = part.trim();
            if part.is_empty() {
                continue;
            }

            let degree = if part.contains("**") || part.contains('^') {
                // Power: take the text after `**` (or `^`) as the degree.
                let exp_str = part
                    .split("**")
                    .nth(1)
                    .or_else(|| part.split('^').nth(1))
                    .unwrap_or("2")
                    .trim();
                exp_str.parse::<usize>().unwrap_or(2)
            } else if part.contains('*') {
                // Product: count the factors that do not parse as numbers.
                let factors: Vec<&str> = part.split('*').collect();
                let mut var_count = 0;
                for factor in factors {
                    let factor = factor.trim();
                    if !factor.is_empty() && factor.parse::<f64>().is_err() {
                        var_count += 1;
                    }
                }
                var_count
            } else if part.parse::<f64>().is_err() && !part.is_empty() {
                // A lone non-numeric token is treated as a single variable.
                1
            } else {
                0
            };

            max_degree = std::cmp::max(max_degree, degree);
        }

        return Ok(max_degree);
    }

    Err(CompileError::InvalidExpression(format!(
        "Can't determine degree of: {expr}"
    )))
}
974
/// Rewrites every positive integer power of a single variable, `x^k` with `k >= 1`,
/// as `x`. Recurses through powers, products, and sums.
///
/// For binary variables (x in {0, 1}), `x^k == x` for every positive integer `k`, so
/// the rewrite preserves the expression's value under every assignment. Reducing
/// only `x^2` would leave terms such as `x^3` that later coefficient extraction
/// cannot interpret.
///
/// Powers whose base is not a symbol, or whose exponent is not a positive integer,
/// are rebuilt with a rewritten base. Products and sums are rebuilt from their
/// rewritten children. Symbols, numbers, and any other node kinds are returned
/// unchanged.
///
/// # Errors
///
/// Returns `CompileError::InvalidExpression("Invalid exponent")` if a symbol's
/// numeric exponent cannot be converted to `f64`. Errors from recursive calls are
/// propagated.
#[cfg(feature = "dwave")]
fn replace_squared_terms(expr: &Expr) -> CompileResult<Expr> {
    if expr.is_symbol() || expr.is_number() {
        return Ok(expr.clone());
    }

    if expr.is_pow() {
        let (base, exp) = expr.as_pow().expect("is_pow() was true");

        if base.is_symbol() && exp.is_number() {
            let exp_val = exp
                .to_f64()
                .ok_or_else(|| CompileError::InvalidExpression("Invalid exponent".to_string()))?;

            // x^k == x for binary x and any positive integer k.
            if exp_val >= 1.0 && exp_val.fract() == 0.0 {
                return Ok(base);
            }
        }

        let new_base = replace_squared_terms(&base)?;
        return Ok(new_base.pow(&exp));
    }

    if expr.is_mul() {
        let mut result = Expr::from(1);
        for factor in expr.as_mul().expect("is_mul() was true") {
            result = result * replace_squared_terms(&factor)?;
        }
        return Ok(result);
    }

    if expr.is_add() {
        let mut result = Expr::from(0);
        for term in expr.as_add().expect("is_add() was true") {
            result = result + replace_squared_terms(&term)?;
        }
        return Ok(result);
    }

    Ok(expr.clone())
}
1047
/// Collects the polynomial coefficients of `expr`.
///
/// Returns `(coeffs, offset)`:
/// - `coeffs` maps the variable names of each monomial to its summed coefficient;
///   products have their names sorted.
/// - `offset` is the accumulated constant term.
///
/// Callers in this file pass an already-expanded expression with binary powers
/// reduced.
///
/// Sums are handled structurally, term by term, via `extract_term_coefficients`.
/// Any other expression whose printed form contains `+` or `-` is parsed textually
/// (see the note on that branch). Everything else is handled as a single term.
///
/// # Errors
///
/// Propagates `CompileError::InvalidExpression` from `extract_term_coefficients`.
/// The textual fallback does not itself return errors.
#[cfg(feature = "dwave")]
fn extract_coefficients(expr: &Expr) -> CompileResult<(HashMap<Vec<String>, f64>, f64)> {
    let mut coeffs = HashMap::new();
    let mut offset = 0.0;

    if expr.is_add() {
        // Structural path: merge each term's coefficients into the running totals.
        for term in expr.as_add().expect("is_add() was true") {
            let (term_coeffs, term_offset) = extract_term_coefficients(&term)?;

            for (vars, coeff) in term_coeffs {
                *coeffs.entry(vars).or_insert(0.0) += coeff;
            }

            offset += term_offset;
        }
    } else {
        let expr_str = format!("{expr}");
        // Textual fallback for non-sum expressions whose printed form has signs.
        // NOTE(review): this is heuristic parsing.
        // - Splitting on '+'/'-' breaks scientific-notation literals such as `1e-05`.
        // - A power term like `x*y**2` yields the single "variable" name `x*y`.
        // - The regex is compiled on every call.
        if expr_str.contains('+') || expr_str.contains('-') {
            use regex::Regex;
            let re = Regex::new(r"([+-]?)([^+-]+)").expect("static regex pattern is valid");

            for caps in re.captures_iter(&expr_str) {
                let sign = caps.get(1).map_or("", |m| m.as_str());
                let term = caps.get(2).map_or("", |m| m.as_str()).trim();

                if term.is_empty() {
                    continue;
                }

                let sign_mult = if sign == "-" { -1.0 } else { 1.0 };

                if term.contains("**") || term.contains('^') {
                    // Power term: keep only the base, dropping the exponent, which
                    // matches x^k == x for binary variables.
                    let base = if term.contains("**") {
                        term.split("**").next().unwrap_or(term)
                    } else {
                        term.split('^').next().unwrap_or(term)
                    }
                    .trim();

                    // Peel off a numeric coefficient if the base is `num*var` or `var*num`.
                    let (coeff_mult, var_name) = if base.contains('*') {
                        let parts: Vec<&str> = base.split('*').collect();
                        if parts.len() == 2 {
                            if let Ok(num) = parts[0].trim().parse::<f64>() {
                                (num, parts[1].trim().to_string())
                            } else if let Ok(num) = parts[1].trim().parse::<f64>() {
                                (num, parts[0].trim().to_string())
                            } else {
                                (1.0, base.to_string())
                            }
                        } else {
                            (1.0, base.to_string())
                        }
                    } else {
                        (1.0, base.to_string())
                    };

                    let vars = vec![var_name.clone()];
                    *coeffs.entry(vars).or_insert(0.0) += sign_mult * coeff_mult;
                } else if term.contains('*') {
                    // Product term: numeric factors multiply the coefficient, all
                    // other factors are treated as variable names.
                    let parts: Vec<&str> = term.split('*').collect();
                    let mut coeff = sign_mult;
                    let mut vars = Vec::new();

                    for part in parts {
                        let part = part.trim();
                        if let Ok(num) = part.parse::<f64>() {
                            coeff *= num;
                        } else {
                            vars.push(part.to_string());
                        }
                    }

                    vars.sort();
                    *coeffs.entry(vars).or_insert(0.0) += coeff;
                } else if let Ok(num) = term.parse::<f64>() {
                    // Bare number: part of the constant offset.
                    offset += sign_mult * num;
                } else {
                    // Bare name: a linear term with unit coefficient.
                    let vars = vec![term.to_string()];
                    *coeffs.entry(vars).or_insert(0.0) += sign_mult;
                }
            }
            return Ok((coeffs, offset));
        }

        // Single structural term (coeffs is always empty at this point).
        if coeffs.is_empty() {
            let (term_coeffs, term_offset) = extract_term_coefficients(expr)?;

            for (vars, coeff) in term_coeffs {
                *coeffs.entry(vars).or_insert(0.0) += coeff;
            }

            offset += term_offset;
        }
    }

    Ok((coeffs, offset))
}
1165
/// Extracts the coefficient of a single term.
///
/// Returns `(coeffs, offset)`:
/// - A number becomes the offset, with an empty map.
/// - A symbol becomes `{[name]: 1.0}`.
/// - A product of numbers and symbols becomes `{sorted_names: product_of_numbers}`,
///   or just an offset when it contains no symbols.
///
/// # Errors
///
/// `CompileError::InvalidExpression` in these cases:
/// - a number cannot be converted to `f64`;
/// - a product contains a factor that is neither a number nor a symbol;
/// - the term is a power;
/// - the term is of any other kind.
#[cfg(feature = "dwave")]
fn extract_term_coefficients(term: &Expr) -> CompileResult<(HashMap<Vec<String>, f64>, f64)> {
    let mut table = HashMap::new();

    if term.is_number() {
        let constant = term
            .to_f64()
            .ok_or_else(|| CompileError::InvalidExpression("Invalid number".to_string()))?;
        return Ok((table, constant));
    }

    if term.is_symbol() {
        let symbol = term.as_symbol().expect("is_symbol() was true");
        table.insert(vec![symbol], 1.0);
        return Ok((table, 0.0));
    }

    if term.is_mul() {
        let mut scale = 1.0;
        let mut names = Vec::new();

        for factor in term.as_mul().expect("is_mul() was true") {
            if factor.is_number() {
                scale *= factor.to_f64().ok_or_else(|| {
                    CompileError::InvalidExpression("Invalid number in product".to_string())
                })?;
            } else if factor.is_symbol() {
                names.push(factor.as_symbol().expect("is_symbol() was true"));
            } else {
                return Err(CompileError::InvalidExpression(format!(
                    "Unsupported term in product: {factor}"
                )));
            }
        }

        names.sort();

        return if names.is_empty() {
            Ok((table, scale))
        } else {
            table.insert(names, scale);
            Ok((table, 0.0))
        };
    }

    let message = if term.is_pow() {
        format!("Unexpected power term after simplification: {term}")
    } else {
        format!("Unsupported term: {term}")
    };
    Err(CompileError::InvalidExpression(message))
}
1248
1249#[allow(dead_code)]
1251fn build_qubo_matrix(
1252 coeffs: &HashMap<Vec<String>, f64>,
1253) -> CompileResult<(
1254 Array<f64, scirs2_core::ndarray::Ix2>,
1255 HashMap<String, usize>,
1256)> {
1257 let mut all_vars = HashSet::new();
1259 for vars in coeffs.keys() {
1260 for var in vars {
1261 all_vars.insert(var.clone());
1262 }
1263 }
1264
1265 let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
1267 sorted_vars.sort();
1268
1269 let var_map: HashMap<String, usize> = sorted_vars
1271 .iter()
1272 .enumerate()
1273 .map(|(i, var)| (var.clone(), i))
1274 .collect();
1275
1276 let n = var_map.len();
1278
1279 let mut matrix = Array::zeros((n, n));
1281
1282 for (vars, &coeff) in coeffs {
1284 match vars.len() {
1285 0 => {
1286 }
1288 1 => {
1289 let i = *var_map
1292 .get(&vars[0])
1293 .expect("variable exists in var_map built from coeffs");
1294 matrix[[i, i]] += coeff;
1295 }
1296 2 => {
1297 let i = *var_map
1300 .get(&vars[0])
1301 .expect("variable exists in var_map built from coeffs");
1302 let j = *var_map
1303 .get(&vars[1])
1304 .expect("variable exists in var_map built from coeffs");
1305
1306 if i == j {
1308 matrix[[i, i]] += coeff;
1310 } else {
1311 if i <= j {
1313 matrix[[i, j]] += coeff;
1314 } else {
1315 matrix[[j, i]] += coeff;
1316 }
1317 }
1318 }
1319 _ => {
1320 return Err(CompileError::DegreeTooHigh(vars.len(), 2));
1322 }
1323 }
1324 }
1325
1326 Ok((matrix, var_map))
1327}
1328
1329#[allow(dead_code)]
1331fn build_hobo_tensor(
1332 coeffs: &HashMap<Vec<String>, f64>,
1333 max_degree: usize,
1334) -> CompileResult<(
1335 Array<f64, scirs2_core::ndarray::IxDyn>,
1336 HashMap<String, usize>,
1337)> {
1338 let mut all_vars = HashSet::new();
1340 for vars in coeffs.keys() {
1341 for var in vars {
1342 all_vars.insert(var.clone());
1343 }
1344 }
1345
1346 let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
1348 sorted_vars.sort();
1349
1350 let var_map: HashMap<String, usize> = sorted_vars
1352 .iter()
1353 .enumerate()
1354 .map(|(i, var)| (var.clone(), i))
1355 .collect();
1356
1357 let n = var_map.len();
1359
1360 let shape: Vec<usize> = vec![n; max_degree];
1362
1363 let mut tensor = Array::zeros(scirs2_core::ndarray::IxDyn(&shape));
1365
1366 for (vars, &coeff) in coeffs {
1368 let degree = vars.len();
1369
1370 if degree == 0 {
1371 continue;
1373 }
1374
1375 if degree > max_degree {
1376 return Err(CompileError::DegreeTooHigh(degree, max_degree));
1377 }
1378
1379 let mut indices: Vec<usize> = vars
1382 .iter()
1383 .map(|var| {
1384 *var_map
1385 .get(var)
1386 .expect("variable exists in var_map built from coeffs")
1387 })
1388 .collect();
1389
1390 indices.sort_unstable();
1392
1393 while indices.len() < max_degree {
1395 indices.insert(0, indices[0]); }
1397
1398 let idx = scirs2_core::ndarray::IxDyn(&indices);
1400 tensor[idx] += coeff;
1401 }
1402
1403 Ok((tensor, var_map))
1404}
1405
/// Alternative compiler front-end that delegates QUBO generation to `Compile`.
#[cfg(feature = "dwave")]
pub struct PieckCompile {
    /// Expression to compile.
    expr: Expr,
    /// Verbosity flag; stored, but not consulted by `get_qubo`.
    verbose: bool,
}
1417
#[cfg(feature = "dwave")]
impl PieckCompile {
    /// Wraps `expr` (anything convertible into an expression) with a verbosity flag.
    pub fn new<T: Into<Expr>>(expr: T, verbose: bool) -> Self {
        let expr = expr.into();
        Self { expr, verbose }
    }

    /// Compiles a copy of the stored expression with `Compile::get_qubo`.
    ///
    /// # Errors
    ///
    /// Propagates every error from `Compile::get_qubo`.
    pub fn get_qubo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let delegate = Compile::new(self.expr.clone());
        delegate.get_qubo()
    }
}