1#![allow(dead_code)]
8
9use scirs2_core::ndarray::Array;
10use std::collections::{HashMap, HashSet};
11
12#[cfg(feature = "scirs")]
13use crate::scirs_stub;
14
15#[cfg(feature = "dwave")]
16use quantrs2_symengine_pure::Expression as SymEngineExpression;
17
/// Crate-local shorthand for the SymEngine expression type used by the `dwave` build.
#[cfg(feature = "dwave")]
type Expr = SymEngineExpression;
20use thiserror::Error;
21
22use quantrs2_anneal::QuboError;
23
/// Expression constructors backed by SymEngine (enabled by the `dwave` feature).
#[cfg(feature = "dwave")]
pub mod expr {
    use quantrs2_symengine_pure::Expression as SymEngineExpression;

    /// Symbolic expression type used by the compiler in the `dwave` build.
    pub type Expr = SymEngineExpression;

    /// Build a numeric constant expression holding `value`.
    pub fn constant(value: f64) -> Expr {
        Expr::from(value)
    }

    /// Build a symbol expression named `name`.
    pub fn var(name: &str) -> Expr {
        Expr::symbol(name)
    }
}
39
40#[cfg(not(feature = "dwave"))]
41pub mod expr {
42 use super::SimpleExpr;
43
44 pub type Expr = SimpleExpr;
45
46 pub const fn constant(value: f64) -> Expr {
47 SimpleExpr::constant(value)
48 }
49
50 pub fn var(name: &str) -> Expr {
51 SimpleExpr::var(name)
52 }
53}
54
/// Errors produced while building or compiling a model.
#[derive(Error, Debug)]
pub enum CompileError {
    /// The expression (or one of its terms) could not be interpreted;
    /// the payload describes the offending part.
    #[error("Invalid expression: {0}")]
    InvalidExpression(String),

    /// A term's degree exceeds what the target form supports.
    /// Fields are `(found_degree, max_supported_degree)`.
    #[error("Term has degree {0}, but maximum supported is {1}")]
    DegreeTooHigh(usize, usize),

    /// Error forwarded from `quantrs2_anneal` (converted via `From`).
    #[error("QUBO error: {0}")]
    QuboError(#[from] QuboError),

    /// Error reported by the SymEngine backend. Not constructed by any code
    /// visible in this file.
    #[error("Symengine error: {0}")]
    SymengineError(String),
}
74
/// Result alias used by the fallible functions in this module.
pub type CompileResult<T> = Result<T, CompileError>;
77
/// Minimal symbolic expression tree used when the `dwave` feature is disabled.
///
/// The model compiler treats every variable as binary (0/1), e.g. it reduces
/// `x^2` to `x`. `PartialEq` is derived so expression trees can be compared
/// structurally (note: `Const` compares with `f64` semantics, so `NaN != NaN`).
#[cfg(not(feature = "dwave"))]
#[derive(Debug, Clone, PartialEq)]
pub enum SimpleExpr {
    /// A named variable.
    Var(String),
    /// A numeric constant.
    Const(f64),
    /// Sum of two sub-expressions.
    Add(Box<Self>, Box<Self>),
    /// Product of two sub-expressions.
    Mul(Box<Self>, Box<Self>),
    /// A sub-expression raised to an integer power.
    Pow(Box<Self>, i32),
}
93
94#[cfg(not(feature = "dwave"))]
95impl SimpleExpr {
96 pub fn var(name: &str) -> Self {
98 Self::Var(name.to_string())
99 }
100
101 pub const fn constant(value: f64) -> Self {
103 Self::Const(value)
104 }
105}
106
107#[cfg(not(feature = "dwave"))]
108impl std::ops::Add for SimpleExpr {
109 type Output = Self;
110
111 fn add(self, rhs: Self) -> Self::Output {
112 Self::Add(Box::new(self), Box::new(rhs))
113 }
114}
115
116#[cfg(not(feature = "dwave"))]
117impl std::ops::Mul for SimpleExpr {
118 type Output = Self;
119
120 fn mul(self, rhs: Self) -> Self::Output {
121 Self::Mul(Box::new(self), Box::new(rhs))
122 }
123}
124
125#[cfg(not(feature = "dwave"))]
126impl std::iter::Sum for SimpleExpr {
127 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
128 iter.fold(Self::Const(0.0), |acc, x| acc + x)
129 }
130}
131
/// Constrained binary optimisation model built on SymEngine expressions
/// (`dwave` feature). Call `compile` to turn it into a QUBO.
#[cfg(feature = "dwave")]
#[derive(Debug, Clone)]
pub struct Model {
    /// Names registered through `add_variable`; not read by any other code visible here.
    variables: HashSet<String>,
    /// Objective expression; `compile` treats `None` as the constant 0.
    objective: Option<Expr>,
    /// Constraints, each turned into a penalty term by `compile`.
    constraints: Vec<Constraint>,
}
143
/// A constraint recorded on a `dwave` [`Model`]; `compile` converts each one
/// into a penalty added to the objective. The `name` fields are stored but not
/// read by the visible compile path.
#[cfg(feature = "dwave")]
#[derive(Debug, Clone)]
enum Constraint {
    /// `expr == value`.
    Equality {
        name: String,
        expr: Expr,
        value: f64,
    },
    /// `expr <= value`. No public `Model` method in the visible code creates this variant.
    LessEqual {
        name: String,
        expr: Expr,
        value: f64,
    },
    /// At most one of `variables` may be 1.
    AtMostOne { name: String, variables: Vec<Expr> },
    /// If any of `conditions` is 1, then `result` must be 1.
    ImpliesAny {
        name: String,
        conditions: Vec<Expr>,
        result: Expr,
    },
}
169
/// Same as [`Model::new`]: no variables, no objective, no constraints.
#[cfg(feature = "dwave")]
impl Default for Model {
    fn default() -> Self {
        Model::new()
    }
}
176
#[cfg(feature = "dwave")]
impl Model {
    /// Penalty weight applied to every constraint by [`Model::compile`].
    const DEFAULT_PENALTY_WEIGHT: f64 = 10.0;

    /// Create an empty model with no variables, no objective and no constraints.
    pub fn new() -> Self {
        Self {
            variables: HashSet::new(),
            objective: None,
            constraints: Vec::new(),
        }
    }

    /// Register a variable named `name` and return it as a symbol expression.
    ///
    /// # Errors
    /// Never fails in the current implementation; the `Result` is part of the API.
    pub fn add_variable(&mut self, name: &str) -> CompileResult<Expr> {
        self.variables.insert(name.to_string());
        Ok(SymEngineExpression::symbol(name))
    }

    /// Set (or replace) the objective expression.
    pub fn set_objective(&mut self, expr: Expr) {
        self.objective = Some(expr);
    }

    /// Require `sum(variables) == 1` (exactly one set, for binary variables).
    ///
    /// # Errors
    /// Never fails in the current implementation.
    pub fn add_constraint_eq_one(&mut self, name: &str, variables: Vec<Expr>) -> CompileResult<()> {
        let sum_expr = variables
            .iter()
            .fold(Expr::from(0), |acc, v| acc + v.clone());
        self.constraints.push(Constraint::Equality {
            name: name.to_string(),
            expr: sum_expr,
            value: 1.0,
        });
        Ok(())
    }

    /// Require that at most one of `variables` is 1.
    ///
    /// # Errors
    /// Never fails in the current implementation.
    pub fn add_constraint_at_most_one(
        &mut self,
        name: &str,
        variables: Vec<Expr>,
    ) -> CompileResult<()> {
        self.constraints.push(Constraint::AtMostOne {
            name: name.to_string(),
            variables,
        });
        Ok(())
    }

    /// Require that `result` is 1 whenever any of `conditions` is 1.
    ///
    /// # Errors
    /// Never fails in the current implementation.
    pub fn add_constraint_implies_any(
        &mut self,
        name: &str,
        conditions: Vec<Expr>,
        result: Expr,
    ) -> CompileResult<()> {
        self.constraints.push(Constraint::ImpliesAny {
            name: name.to_string(),
            conditions,
            result,
        });
        Ok(())
    }

    /// Compile the model to a QUBO using the default penalty weight (10.0).
    ///
    /// # Errors
    /// See [`Model::compile_with_penalty`].
    pub fn compile(&self) -> CompileResult<CompiledModel> {
        self.compile_with_penalty(Self::DEFAULT_PENALTY_WEIGHT)
    }

    /// Compile the model to a QUBO, scaling every constraint penalty by `penalty_weight`.
    ///
    /// The objective (or 0 when unset) is combined with one penalty per constraint:
    /// - `Equality`: `P * (expr - value)^2`
    /// - `LessEqual`: `P * (expr - value)^2` (see the note in the match arm)
    /// - `AtMostOne`: `P * Σ_{i<j} v_i * v_j`
    /// - `ImpliesAny`: `P * (Σ conditions) * (1 - result)`
    ///
    /// The combined expression is then compiled by [`Compile::get_qubo`].
    ///
    /// # Errors
    /// Propagates errors from [`Compile::get_qubo`], e.g. `DegreeTooHigh` when a
    /// penalty term has degree above 2.
    pub fn compile_with_penalty(&self, penalty_weight: f64) -> CompileResult<CompiledModel> {
        let mut final_expr = self.objective.clone().unwrap_or_else(|| Expr::from(0));

        for constraint in &self.constraints {
            match constraint {
                Constraint::Equality { expr, value, .. } => {
                    // Zero exactly when the equality holds, positive otherwise.
                    let diff = expr.clone() - Expr::from(*value);
                    final_expr = final_expr + Expr::from(penalty_weight) * diff.clone() * diff;
                }
                Constraint::LessEqual { expr, value, .. } => {
                    // NOTE(review): a squared penalty also penalises `expr < value`,
                    // so this behaves like an equality. A faithful `<=` needs slack
                    // variables; no visible code constructs this variant yet.
                    let excess = expr.clone() - Expr::from(*value);
                    final_expr = final_expr + Expr::from(penalty_weight) * excess.clone() * excess;
                }
                Constraint::AtMostOne { variables, .. } => {
                    // Each pair that is simultaneously 1 pays the penalty.
                    for (i, first) in variables.iter().enumerate() {
                        for second in &variables[i + 1..] {
                            final_expr = final_expr
                                + Expr::from(penalty_weight) * first.clone() * second.clone();
                        }
                    }
                }
                Constraint::ImpliesAny {
                    conditions, result, ..
                } => {
                    // Penalised for every true condition while `result` is 0.
                    let conditions_sum = conditions
                        .iter()
                        .fold(Expr::from(0), |acc, c| acc + c.clone());
                    final_expr = final_expr
                        + Expr::from(penalty_weight)
                            * conditions_sum
                            * (Expr::from(1) - result.clone());
                }
            }
        }

        let compiler = Compile::new(final_expr);
        let ((qubo_matrix, var_map), offset) = compiler.get_qubo()?;

        Ok(CompiledModel {
            qubo_matrix,
            var_map,
            offset,
            constraints: self.constraints.clone(),
        })
    }
}
304
/// Output of `Model::compile` in the `dwave` build.
#[cfg(feature = "dwave")]
#[derive(Debug, Clone)]
pub struct CompiledModel {
    /// QUBO coefficient matrix. Without the `scirs` feature, `build_qubo_matrix`
    /// fills only the diagonal and upper triangle; with `scirs`, the matrix is
    /// post-processed by `scirs_stub::enhance_qubo_matrix` (effect not visible here).
    pub qubo_matrix: Array<f64, scirs2_core::ndarray::Ix2>,
    /// Variable name -> row/column index in `qubo_matrix` (sorted name order).
    pub var_map: HashMap<String, usize>,
    /// Constant energy offset.
    pub offset: f64,
    /// Copy of the model's constraints; not read by any code visible in this file.
    constraints: Vec<Constraint>,
}
318
#[cfg(feature = "dwave")]
impl CompiledModel {
    /// Convert the compiled matrix into a [`quantrs2_anneal::ising::QuboModel`].
    ///
    /// Diagonal entries become linear coefficients. For every pair `i < j`, the
    /// quadratic coefficient is `Q[i][j] + Q[j][i]`. That gives the same energy
    /// `xᵀQx` whether `qubo_matrix` is stored upper-triangular or symmetric.
    /// The previous version read only `Q[i][j]`, which halved every coupling of
    /// a symmetric matrix.
    ///
    /// Entries with magnitude `<= 1e-10` are skipped. The offset is copied unchanged.
    ///
    /// # Panics
    /// Panics if `QuboModel` rejects an index. This should not happen while the
    /// matrix is `var_map.len()` square, which is how `compile` builds it.
    pub fn to_qubo(&self) -> quantrs2_anneal::ising::QuboModel {
        use quantrs2_anneal::ising::QuboModel;

        // Coefficients at or below this magnitude are treated as zero.
        const ZERO_TOLERANCE: f64 = 1e-10;

        let mut qubo = QuboModel::new(self.var_map.len());
        qubo.offset = self.offset;

        let n = self.qubo_matrix.nrows().min(self.qubo_matrix.ncols());
        for i in 0..n {
            let linear = self.qubo_matrix[[i, i]];
            if linear.abs() > ZERO_TOLERANCE {
                qubo.set_linear(i, linear)
                    .expect("index within bounds from matrix dimensions");
            }

            for j in (i + 1)..n {
                // Combine both triangles so either storage convention is honoured.
                let quadratic = self.qubo_matrix[[i, j]] + self.qubo_matrix[[j, i]];
                if quadratic.abs() > ZERO_TOLERANCE {
                    qubo.set_quadratic(i, j, quadratic)
                        .expect("indices within bounds from matrix dimensions");
                }
            }
        }

        qubo
    }
}
353
/// Constrained binary optimisation model built on [`SimpleExpr`]
/// (used when the `dwave` feature is disabled). Call `compile` to get a QUBO.
#[cfg(not(feature = "dwave"))]
#[derive(Debug, Clone)]
pub struct Model {
    /// Names registered through `add_variable`; not read by any other code visible here.
    variables: HashSet<String>,
    /// Objective expression; `compile` adds nothing for it when `None`.
    objective: Option<SimpleExpr>,
    /// Constraints, each turned into a penalty term by `compile`.
    constraints: Vec<Constraint>,
}
365
/// A constraint recorded on a non-`dwave` [`Model`]; `compile` converts each
/// one into a penalty term. The `name` fields are stored but not read by the
/// visible compile path.
#[cfg(not(feature = "dwave"))]
#[derive(Debug, Clone)]
enum Constraint {
    /// `expr == value`.
    Equality {
        name: String,
        expr: SimpleExpr,
        value: f64,
    },
    /// At most one of `variables` may be 1.
    AtMostOne {
        name: String,
        variables: Vec<SimpleExpr>,
    },
    /// If any of `conditions` is 1, then `result` must be 1.
    ImpliesAny {
        name: String,
        conditions: Vec<SimpleExpr>,
        result: SimpleExpr,
    },
}
388
389#[cfg(not(feature = "dwave"))]
390impl Default for Model {
391 fn default() -> Self {
392 Self::new()
393 }
394}
395
396#[cfg(not(feature = "dwave"))]
397impl Model {
398 pub fn new() -> Self {
400 Self {
401 variables: HashSet::new(),
402 objective: None,
403 constraints: Vec::new(),
404 }
405 }
406
407 pub fn add_variable(&mut self, name: &str) -> CompileResult<SimpleExpr> {
409 self.variables.insert(name.to_string());
410 Ok(SimpleExpr::var(name))
411 }
412
413 pub fn set_objective(&mut self, expr: SimpleExpr) {
415 self.objective = Some(expr);
416 }
417
418 pub fn add_constraint_eq_one(
420 &mut self,
421 name: &str,
422 variables: Vec<SimpleExpr>,
423 ) -> CompileResult<()> {
424 let sum_expr = variables.into_iter().sum();
425 self.constraints.push(Constraint::Equality {
426 name: name.to_string(),
427 expr: sum_expr,
428 value: 1.0,
429 });
430 Ok(())
431 }
432
433 pub fn add_constraint_at_most_one(
435 &mut self,
436 name: &str,
437 variables: Vec<SimpleExpr>,
438 ) -> CompileResult<()> {
439 self.constraints.push(Constraint::AtMostOne {
440 name: name.to_string(),
441 variables,
442 });
443 Ok(())
444 }
445
446 pub fn add_constraint_implies_any(
448 &mut self,
449 name: &str,
450 conditions: Vec<SimpleExpr>,
451 result: SimpleExpr,
452 ) -> CompileResult<()> {
453 self.constraints.push(Constraint::ImpliesAny {
454 name: name.to_string(),
455 conditions,
456 result,
457 });
458 Ok(())
459 }
460
461 pub fn compile(&self) -> CompileResult<CompiledModel> {
463 let mut qubo_terms: HashMap<(String, String), f64> = HashMap::new();
465 let mut offset = 0.0;
466 let penalty_weight = 10.0;
467
468 if let Some(ref obj) = self.objective {
470 self.add_expr_to_qubo(obj, 1.0, &mut qubo_terms, &mut offset)?;
471 }
472
473 for constraint in &self.constraints {
475 match constraint {
476 Constraint::Equality { expr, value, .. } => {
477 self.add_expr_squared_to_qubo(
480 expr,
481 penalty_weight,
482 &mut qubo_terms,
483 &mut offset,
484 )?;
485 self.add_expr_to_qubo(
486 expr,
487 -2.0 * penalty_weight * value,
488 &mut qubo_terms,
489 &mut offset,
490 )?;
491 offset += penalty_weight * value * value;
492 }
493 Constraint::AtMostOne { variables, .. } => {
494 for i in 0..variables.len() {
496 for j in (i + 1)..variables.len() {
497 if let (SimpleExpr::Var(vi), SimpleExpr::Var(vj)) =
498 (&variables[i], &variables[j])
499 {
500 let key = if vi < vj {
501 (vi.clone(), vj.clone())
502 } else {
503 (vj.clone(), vi.clone())
504 };
505 *qubo_terms.entry(key).or_insert(0.0) += penalty_weight;
506 }
507 }
508 }
509 }
510 Constraint::ImpliesAny {
511 conditions, result, ..
512 } => {
513 for cond in conditions {
515 if let (SimpleExpr::Var(c), SimpleExpr::Var(r)) = (cond, result) {
516 let key = if c < r {
517 (c.clone(), r.clone())
518 } else {
519 (r.clone(), c.clone())
520 };
521 *qubo_terms.entry(key).or_insert(0.0) -= penalty_weight;
522 }
523 if let SimpleExpr::Var(c) = cond {
525 *qubo_terms.entry((c.clone(), c.clone())).or_insert(0.0) +=
526 penalty_weight;
527 }
528 }
529 }
530 }
531 }
532
533 let all_vars: HashSet<String> = qubo_terms
535 .keys()
536 .flat_map(|(v1, v2)| vec![v1.clone(), v2.clone()])
537 .collect();
538 let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
539 sorted_vars.sort();
540
541 let var_map: HashMap<String, usize> = sorted_vars
542 .iter()
543 .enumerate()
544 .map(|(i, v)| (v.clone(), i))
545 .collect();
546
547 let n = var_map.len();
548 let mut matrix = Array::zeros((n, n));
549
550 for ((v1, v2), coeff) in qubo_terms {
551 let i = var_map[&v1];
552 let j = var_map[&v2];
553 if i == j {
554 matrix[[i, i]] += coeff;
555 } else {
556 matrix[[i, j]] += coeff / 2.0;
557 matrix[[j, i]] += coeff / 2.0;
558 }
559 }
560
561 Ok(CompiledModel {
562 qubo_matrix: matrix,
563 var_map,
564 offset,
565 constraints: self.constraints.clone(),
566 })
567 }
568
569 fn add_expr_to_qubo(
571 &self,
572 expr: &SimpleExpr,
573 coeff: f64,
574 terms: &mut HashMap<(String, String), f64>,
575 offset: &mut f64,
576 ) -> CompileResult<()> {
577 match expr {
578 SimpleExpr::Var(name) => {
579 *terms.entry((name.clone(), name.clone())).or_insert(0.0) += coeff;
580 }
581 SimpleExpr::Const(val) => {
582 *offset += coeff * val;
583 }
584 SimpleExpr::Add(left, right) => {
585 self.add_expr_to_qubo(left, coeff, terms, offset)?;
586 self.add_expr_to_qubo(right, coeff, terms, offset)?;
587 }
588 SimpleExpr::Mul(left, right) => {
589 if let (SimpleExpr::Var(v1), SimpleExpr::Var(v2)) = (left.as_ref(), right.as_ref())
590 {
591 let key = if v1 < v2 {
592 (v1.clone(), v2.clone())
593 } else {
594 (v2.clone(), v1.clone())
595 };
596 *terms.entry(key).or_insert(0.0) += coeff;
597 } else if let (SimpleExpr::Const(c), var) | (var, SimpleExpr::Const(c)) =
598 (left.as_ref(), right.as_ref())
599 {
600 self.add_expr_to_qubo(var, coeff * c, terms, offset)?;
601 }
602 }
603 SimpleExpr::Pow(base, exp) => {
604 if *exp == 2 && matches!(base.as_ref(), SimpleExpr::Var(_)) {
605 self.add_expr_to_qubo(base, coeff, terms, offset)?;
607 }
608 }
609 }
610 Ok(())
611 }
612
613 fn add_expr_squared_to_qubo(
615 &self,
616 expr: &SimpleExpr,
617 coeff: f64,
618 terms: &mut HashMap<(String, String), f64>,
619 offset: &mut f64,
620 ) -> CompileResult<()> {
621 match expr {
623 SimpleExpr::Var(name) => {
624 *terms.entry((name.clone(), name.clone())).or_insert(0.0) += coeff;
626 }
627 SimpleExpr::Add(left, right) => {
628 self.add_expr_squared_to_qubo(left, coeff, terms, offset)?;
630 self.add_expr_squared_to_qubo(right, coeff, terms, offset)?;
631 if let (SimpleExpr::Var(v1), SimpleExpr::Var(v2)) = (left.as_ref(), right.as_ref())
633 {
634 let key = if v1 < v2 {
635 (v1.clone(), v2.clone())
636 } else {
637 (v2.clone(), v1.clone())
638 };
639 *terms.entry(key).or_insert(0.0) += 2.0 * coeff;
640 }
641 }
642 _ => {}
643 }
644 Ok(())
645 }
646}
647
/// Output of `Model::compile` in the non-`dwave` build.
#[cfg(not(feature = "dwave"))]
#[derive(Debug, Clone)]
pub struct CompiledModel {
    /// QUBO coefficient matrix as built by `compile`: linear coefficients on the
    /// diagonal, each quadratic coefficient split equally between `[i][j]` and `[j][i]`.
    pub qubo_matrix: Array<f64, scirs2_core::ndarray::Ix2>,
    /// Variable name -> row/column index in `qubo_matrix` (sorted name order).
    pub var_map: HashMap<String, usize>,
    /// Constant energy offset.
    pub offset: f64,
    /// Copy of the model's constraints; not read by any code visible in this file.
    constraints: Vec<Constraint>,
}
661
662#[cfg(not(feature = "dwave"))]
663impl CompiledModel {
664 pub fn to_qubo(&self) -> quantrs2_anneal::ising::QuboModel {
666 use quantrs2_anneal::ising::QuboModel;
667
668 let mut qubo = QuboModel::new(self.var_map.len());
669
670 qubo.offset = self.offset;
672
673 for i in 0..self.qubo_matrix.nrows() {
675 for j in i..self.qubo_matrix.ncols() {
676 let value = self.qubo_matrix[[i, j]];
677 if value.abs() > 1e-10 {
678 if i == j {
679 qubo.set_linear(i, value)
682 .expect("index within bounds from matrix dimensions");
683 } else {
684 qubo.set_quadratic(i, j, value)
687 .expect("indices within bounds from matrix dimensions");
688 }
689 }
690 }
691 }
692
693 qubo
694 }
695}
696
/// Compiler from a single SymEngine expression to QUBO (`get_qubo`) or
/// HOBO (`get_hobo`) form.
#[cfg(feature = "dwave")]
pub struct Compile {
    /// Expression to compile; it is expanded afresh on each `get_qubo`/`get_hobo` call.
    expr: Expr,
}
706
#[cfg(feature = "dwave")]
impl Compile {
    /// Wrap `expr` (anything convertible into [`Expr`]) for compilation.
    pub fn new<T: Into<Expr>>(expr: T) -> Self {
        Self { expr: expr.into() }
    }

    /// Compile the expression into a QUBO matrix.
    ///
    /// Returns `((matrix, var_map), offset)`. `var_map` maps each variable name
    /// to its row/column index and `offset` is the constant term. With the
    /// `scirs` feature, the matrix is additionally passed through
    /// `crate::scirs_stub::enhance_qubo_matrix`.
    ///
    /// # Errors
    /// `DegreeTooHigh` if the expanded expression has a term of degree above 2,
    /// or `InvalidExpression` if a term cannot be interpreted.
    pub fn get_qubo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        #[cfg(feature = "scirs")]
        {
            self.get_qubo_scirs()
        }
        #[cfg(not(feature = "scirs"))]
        {
            self.get_qubo_standard()
        }
    }

    /// Feature-independent QUBO pipeline.
    ///
    /// Steps: expand the expression, reject degree > 2, reduce `x^2` to `x`,
    /// extract coefficients, then build the (upper-triangular) matrix.
    fn get_qubo_standard(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let expr = self.expr.expand();

        let max_degree = calc_highest_degree(&expr)?;
        if max_degree > 2 {
            return Err(CompileError::DegreeTooHigh(max_degree, 2));
        }

        let expr = replace_squared_terms(&expr)?;

        let (coeffs, offset) = extract_coefficients(&expr)?;

        let (matrix, var_map) = build_qubo_matrix(&coeffs)?;

        Ok(((matrix, var_map), offset))
    }

    /// Standard pipeline followed by `scirs_stub::enhance_qubo_matrix`.
    #[cfg(feature = "scirs")]
    fn get_qubo_scirs(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let ((matrix, var_map), offset) = self.get_qubo_standard()?;

        let enhanced_matrix = crate::scirs_stub::enhance_qubo_matrix(&matrix);

        Ok(((enhanced_matrix, var_map), offset))
    }

    /// Compile the expression into a higher-order (HOBO) coefficient tensor.
    ///
    /// Returns `((tensor, var_map), offset)`. The tensor has rank equal to the
    /// highest term degree measured *before* power reduction; lower-degree
    /// terms are padded by `build_hobo_tensor`.
    ///
    /// # Errors
    /// `InvalidExpression` if a term cannot be interpreted, or `DegreeTooHigh`
    /// from `build_hobo_tensor`.
    pub fn get_hobo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::IxDyn>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let expanded = self.expr.expand();

        let max_degree = calc_highest_degree(&expanded)?;

        // Reducing powers can leave unexpanded products, so expand once more.
        let reduced = replace_squared_terms(&expanded)?.expand();

        let (coeffs, offset) = extract_coefficients(&reduced)?;

        let (tensor, var_map) = build_hobo_tensor(&coeffs, max_degree)?;

        Ok(((tensor, var_map), offset))
    }
}
834
/// Compute the highest total polynomial degree of any term in `expr`.
///
/// Structural rules, checked in this order:
/// - symbol -> 1; number -> 0; negation -> degree of its operand;
/// - power: a symbol raised to a number that is a non-negative integer yields
///   that integer. Otherwise the result is `degree(base) * exponent` when the
///   exponent is a non-negative integer number, and 0 in every other case;
/// - product -> sum of the factor degrees;
/// - sum -> maximum over its terms.
///
/// If none of these predicates match, the function falls back to parsing the
/// `Display` string of `expr`, split on `+` and `-`.
/// NOTE(review): that fallback would mis-split numbers printed in scientific
/// notation (e.g. `1e-5`). For a `**`/`^` term it returns the exponent alone,
/// ignoring other factors (e.g. `x*y**2` -> 2). Verify whether this path is
/// reachable with the SymEngine backend.
///
/// # Errors
/// `InvalidExpression` if a symbol's numeric exponent cannot be read as `f64`,
/// or if the fallback finds neither `+` nor `-` in the string.
#[cfg(feature = "dwave")]
fn calc_highest_degree(expr: &Expr) -> CompileResult<usize> {
    if expr.is_symbol() {
        return Ok(1);
    }

    if expr.is_number() {
        return Ok(0);
    }

    if expr.is_neg() {
        let inner = expr.as_neg().expect("is_neg() was true");
        return calc_highest_degree(&inner);
    }

    if expr.is_pow() {
        let (base, exp) = expr.as_pow().expect("is_pow() was true");

        // Fast path: `symbol ^ k` for a non-negative integer `k` has degree `k`.
        if base.is_symbol() && exp.is_number() {
            let exp_val = match exp.to_f64() {
                Some(n) => n,
                None => {
                    return Err(CompileError::InvalidExpression(
                        "Invalid exponent".to_string(),
                    ))
                }
            };

            if exp_val.is_sign_positive() && exp_val.fract() == 0.0 {
                return Ok(exp_val as usize);
            }
        }

        // General case; any non-integer, negative or symbolic exponent counts as 0.
        let base_degree = calc_highest_degree(&base)?;
        let exp_degree = if exp.is_number() {
            match exp.to_f64() {
                Some(n) => {
                    if n.is_sign_positive() && n.fract() == 0.0 {
                        n as usize
                    } else {
                        0
                    }
                }
                None => 0,
            }
        } else {
            0
        };

        return Ok(base_degree * exp_degree);
    }

    if expr.is_mul() {
        let mut total_degree = 0;
        for factor in expr.as_mul().expect("is_mul() was true") {
            total_degree += calc_highest_degree(&factor)?;
        }
        return Ok(total_degree);
    }

    if expr.is_add() {
        let mut max_degree = 0;
        for term in expr.as_add().expect("is_add() was true") {
            let term_degree = calc_highest_degree(&term)?;
            max_degree = std::cmp::max(max_degree, term_degree);
        }
        return Ok(max_degree);
    }

    // String fallback for shapes the structural predicates did not recognise.
    let expr_str = format!("{expr}");
    if expr_str.contains('+') || expr_str.contains('-') {
        let mut max_degree = 0;

        let parts: Vec<&str> = expr_str.split(['+', '-']).collect();

        for part in parts {
            let part = part.trim();
            if part.is_empty() {
                continue;
            }

            let degree = if part.contains("**") || part.contains('^') {
                // Degree is taken to be the exponent text; defaults to 2 if unparseable.
                let exp_str = part
                    .split("**")
                    .nth(1)
                    .or_else(|| part.split('^').nth(1))
                    .unwrap_or("2")
                    .trim();
                exp_str.parse::<usize>().unwrap_or(2)
            } else if part.contains('*') {
                // Count non-numeric factors as variables.
                let factors: Vec<&str> = part.split('*').collect();
                let mut var_count = 0;
                for factor in factors {
                    let factor = factor.trim();
                    if !factor.is_empty() && factor.parse::<f64>().is_err() {
                        var_count += 1;
                    }
                }
                var_count
            } else if part.parse::<f64>().is_err() && !part.is_empty() {
                1
            } else {
                0
            };

            max_degree = std::cmp::max(max_degree, degree);
        }

        return Ok(max_degree);
    }

    Err(CompileError::InvalidExpression(format!(
        "Can't determine degree of: {expr}"
    )))
}
978
/// Reduce powers of binary symbols: `x^n` becomes `x` for any integer `n >= 1`.
///
/// Every variable is binary (0 or 1), so `x^n == x` for positive integer `n`.
/// The previous version only handled `n == 2`. Higher powers such as `x^3`
/// survived and later made `extract_term_coefficients` reject the term
/// (relevant to `get_hobo`; `get_qubo` rejects degree > 2 before this runs).
///
/// Also collapses a two-factor product of the same symbol (`x*x`) to `x`.
/// Recurses into negations, power bases, products and sums; anything else is
/// returned unchanged.
///
/// # Errors
/// `InvalidExpression` if a symbol's numeric exponent cannot be converted to `f64`.
#[cfg(feature = "dwave")]
fn replace_squared_terms(expr: &Expr) -> CompileResult<Expr> {
    if expr.is_symbol() || expr.is_number() {
        return Ok(expr.clone());
    }

    if expr.is_neg() {
        let inner = expr.as_neg().expect("is_neg() was true");
        let new_inner = replace_squared_terms(&inner)?;
        return Ok(-new_inner);
    }

    if expr.is_pow() {
        let (base, exp) = expr.as_pow().expect("is_pow() was true");

        if base.is_symbol() && exp.is_number() {
            let exp_val = match exp.to_f64() {
                Some(n) => n,
                None => {
                    return Err(CompileError::InvalidExpression(
                        "Invalid exponent".to_string(),
                    ))
                }
            };

            // x^n == x for binary x and any positive integer n.
            if exp_val >= 1.0 && exp_val.fract() == 0.0 {
                return Ok(base);
            }
        }

        let new_base = replace_squared_terms(&base)?;
        return Ok(new_base.pow(&exp));
    }

    if expr.is_mul() {
        let mut new_terms = Vec::new();
        for factor in expr.as_mul().expect("is_mul() was true") {
            new_terms.push(replace_squared_terms(&factor)?);
        }

        // x * x == x for binary x.
        if new_terms.len() == 2 {
            if let (Some(name1), Some(name2)) = (new_terms[0].as_symbol(), new_terms[1].as_symbol())
            {
                if name1 == name2 {
                    return Ok(new_terms.remove(0));
                }
            }
        }

        if new_terms.is_empty() {
            return Ok(Expr::from(1));
        }
        let mut result = new_terms.remove(0);
        for term in new_terms {
            result = result * term;
        }
        return Ok(result);
    }

    if expr.is_add() {
        let mut new_terms = Vec::new();
        for term in expr.as_add().expect("is_add() was true") {
            new_terms.push(replace_squared_terms(&term)?);
        }

        if new_terms.is_empty() {
            return Ok(Expr::from(0));
        }
        let mut result = new_terms.remove(0);
        for term in new_terms {
            result = result + term;
        }
        return Ok(result);
    }

    Ok(expr.clone())
}
1077
/// Split an (expanded, power-reduced) expression into monomial coefficients
/// and a constant offset.
///
/// Returns `(coeffs, offset)`. Each key of `coeffs` is a list of variable
/// names forming one monomial and each value is that monomial's coefficient.
/// Keys from products are sorted; other keys hold a single name.
///
/// - A sum is handled term by term via `extract_term_coefficients`.
/// - Any other expression whose `Display` text contains `+` or `-` is parsed
///   textually with the regex `([+-]?)([^+-]+)`.
/// - Everything else goes to `extract_term_coefficients`.
///
/// NOTE(review): the textual path is fragile. It would mis-split scientific
/// notation (`1e-5`) and negative exponents. In a `**`/`^` term it keeps only
/// the text before the operator, so `x*y**2` would become a single variable
/// named `"x*y"`. The regex is also recompiled on every call. Confirm which
/// inputs reach it.
///
/// # Errors
/// Propagates `InvalidExpression` from `extract_term_coefficients`.
#[cfg(feature = "dwave")]
fn extract_coefficients(expr: &Expr) -> CompileResult<(HashMap<Vec<String>, f64>, f64)> {
    let mut coeffs = HashMap::new();
    let mut offset = 0.0;

    if expr.is_add() {
        for term in expr.as_add().expect("is_add() was true") {
            let (term_coeffs, term_offset) = extract_term_coefficients(&term)?;

            for (vars, coeff) in term_coeffs {
                *coeffs.entry(vars).or_insert(0.0) += coeff;
            }

            offset += term_offset;
        }
    } else {
        let expr_str = format!("{expr}");
        if expr_str.contains('+') || expr_str.contains('-') {
            use regex::Regex;
            // Each match is an optional sign followed by a run without '+'/'-'.
            let re = Regex::new(r"([+-]?)([^+-]+)").expect("static regex pattern is valid");

            for caps in re.captures_iter(&expr_str) {
                let sign = caps.get(1).map_or("", |m| m.as_str());
                let term = caps.get(2).map_or("", |m| m.as_str()).trim();

                if term.is_empty() {
                    continue;
                }

                let sign_mult = if sign == "-" { -1.0 } else { 1.0 };

                if term.contains("**") || term.contains('^') {
                    // Power term: keep only the base (the exponent is dropped,
                    // consistent with x^n == x for binary x).
                    let base = if term.contains("**") {
                        term.split("**").next().unwrap_or(term)
                    } else {
                        term.split('^').next().unwrap_or(term)
                    }
                    .trim();

                    // Peel off one numeric factor from a two-factor base, if present.
                    let (coeff_mult, var_name) = if base.contains('*') {
                        let parts: Vec<&str> = base.split('*').collect();
                        if parts.len() == 2 {
                            if let Ok(num) = parts[0].trim().parse::<f64>() {
                                (num, parts[1].trim().to_string())
                            } else if let Ok(num) = parts[1].trim().parse::<f64>() {
                                (num, parts[0].trim().to_string())
                            } else {
                                (1.0, base.to_string())
                            }
                        } else {
                            (1.0, base.to_string())
                        }
                    } else {
                        (1.0, base.to_string())
                    };

                    let vars = vec![var_name.clone()];
                    *coeffs.entry(vars).or_insert(0.0) += sign_mult * coeff_mult;
                } else if term.contains('*') {
                    // Product term: numeric factors multiply the coefficient,
                    // everything else is treated as a variable name.
                    let parts: Vec<&str> = term.split('*').collect();
                    let mut coeff = sign_mult;
                    let mut vars = Vec::new();

                    for part in parts {
                        let part = part.trim();
                        if let Ok(num) = part.parse::<f64>() {
                            coeff *= num;
                        } else {
                            vars.push(part.to_string());
                        }
                    }

                    vars.sort();
                    *coeffs.entry(vars).or_insert(0.0) += coeff;
                } else if let Ok(num) = term.parse::<f64>() {
                    offset += sign_mult * num;
                } else {
                    let vars = vec![term.to_string()];
                    *coeffs.entry(vars).or_insert(0.0) += sign_mult;
                }
            }
            return Ok((coeffs, offset));
        }

        // `coeffs` is still empty here, so this branch always runs.
        if coeffs.is_empty() {
            let (term_coeffs, term_offset) = extract_term_coefficients(expr)?;

            for (vars, coeff) in term_coeffs {
                *coeffs.entry(vars).or_insert(0.0) += coeff;
            }

            offset += term_offset;
        }
    }

    Ok((coeffs, offset))
}
1195
/// Extract monomial coefficients and a constant offset from a single term.
///
/// Checked in this order:
/// - number -> `(empty, value)`;
/// - sum -> recursive merge of its terms;
/// - negation -> operand with every coefficient and the offset negated;
/// - symbol -> `{[name]: 1.0}`;
/// - product of numbers and symbols -> one sorted-name monomial (or only an
///   offset if there are no symbols).
///
/// # Errors
/// `InvalidExpression` for an unreadable number, a product factor that is
/// neither a number nor a symbol, a leftover power, or any other term shape.
#[cfg(feature = "dwave")]
fn extract_term_coefficients(term: &Expr) -> CompileResult<(HashMap<Vec<String>, f64>, f64)> {
    if term.is_number() {
        let value = term
            .to_f64()
            .ok_or_else(|| CompileError::InvalidExpression("Invalid number".to_string()))?;
        Ok((HashMap::new(), value))
    } else if term.is_add() {
        let mut merged: HashMap<Vec<String>, f64> = HashMap::new();
        let mut constant = 0.0;
        for piece in term.as_add().expect("is_add() was true") {
            let (piece_coeffs, piece_offset) = extract_term_coefficients(&piece)?;
            for (names, weight) in piece_coeffs {
                *merged.entry(names).or_insert(0.0) += weight;
            }
            constant += piece_offset;
        }
        Ok((merged, constant))
    } else if term.is_neg() {
        let operand = term.as_neg().expect("is_neg() was true");
        let (operand_coeffs, operand_offset) = extract_term_coefficients(&operand)?;
        let negated: HashMap<Vec<String>, f64> = operand_coeffs
            .into_iter()
            .map(|(names, weight)| (names, -weight))
            .collect();
        Ok((negated, -operand_offset))
    } else if term.is_symbol() {
        let symbol = term.as_symbol().expect("is_symbol() was true");
        let mut single: HashMap<Vec<String>, f64> = HashMap::new();
        single.insert(vec![symbol.to_string()], 1.0);
        Ok((single, 0.0))
    } else if term.is_mul() {
        let mut scale = 1.0;
        let mut names: Vec<String> = Vec::new();

        for factor in term.as_mul().expect("is_mul() was true") {
            if factor.is_number() {
                scale *= factor.to_f64().ok_or_else(|| {
                    CompileError::InvalidExpression("Invalid number in product".to_string())
                })?;
            } else if factor.is_symbol() {
                names.push(
                    factor
                        .as_symbol()
                        .expect("is_symbol() was true")
                        .to_string(),
                );
            } else {
                return Err(CompileError::InvalidExpression(format!(
                    "Unsupported term in product: {factor}"
                )));
            }
        }

        names.sort();

        let mut monomials: HashMap<Vec<String>, f64> = HashMap::new();
        if names.is_empty() {
            // A purely numeric product is just a constant.
            Ok((monomials, scale))
        } else {
            monomials.insert(names, scale);
            Ok((monomials, 0.0))
        }
    } else if term.is_pow() {
        Err(CompileError::InvalidExpression(format!(
            "Unexpected power term after simplification: {term}"
        )))
    } else {
        Err(CompileError::InvalidExpression(format!(
            "Unsupported term: {term}"
        )))
    }
}
1306
1307#[allow(dead_code)]
1309fn build_qubo_matrix(
1310 coeffs: &HashMap<Vec<String>, f64>,
1311) -> CompileResult<(
1312 Array<f64, scirs2_core::ndarray::Ix2>,
1313 HashMap<String, usize>,
1314)> {
1315 let mut all_vars = HashSet::new();
1317 for vars in coeffs.keys() {
1318 for var in vars {
1319 all_vars.insert(var.clone());
1320 }
1321 }
1322
1323 let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
1325 sorted_vars.sort();
1326
1327 let var_map: HashMap<String, usize> = sorted_vars
1329 .iter()
1330 .enumerate()
1331 .map(|(i, var)| (var.clone(), i))
1332 .collect();
1333
1334 let n = var_map.len();
1336
1337 let mut matrix = Array::zeros((n, n));
1339
1340 for (vars, &coeff) in coeffs {
1342 match vars.len() {
1343 0 => {
1344 }
1346 1 => {
1347 let i = *var_map
1350 .get(&vars[0])
1351 .expect("variable exists in var_map built from coeffs");
1352 matrix[[i, i]] += coeff;
1353 }
1354 2 => {
1355 let i = *var_map
1358 .get(&vars[0])
1359 .expect("variable exists in var_map built from coeffs");
1360 let j = *var_map
1361 .get(&vars[1])
1362 .expect("variable exists in var_map built from coeffs");
1363
1364 if i == j {
1366 matrix[[i, i]] += coeff;
1368 } else {
1369 if i <= j {
1371 matrix[[i, j]] += coeff;
1372 } else {
1373 matrix[[j, i]] += coeff;
1374 }
1375 }
1376 }
1377 _ => {
1378 return Err(CompileError::DegreeTooHigh(vars.len(), 2));
1380 }
1381 }
1382 }
1383
1384 Ok((matrix, var_map))
1385}
1386
1387#[allow(dead_code)]
1389fn build_hobo_tensor(
1390 coeffs: &HashMap<Vec<String>, f64>,
1391 max_degree: usize,
1392) -> CompileResult<(
1393 Array<f64, scirs2_core::ndarray::IxDyn>,
1394 HashMap<String, usize>,
1395)> {
1396 let mut all_vars = HashSet::new();
1398 for vars in coeffs.keys() {
1399 for var in vars {
1400 all_vars.insert(var.clone());
1401 }
1402 }
1403
1404 let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
1406 sorted_vars.sort();
1407
1408 let var_map: HashMap<String, usize> = sorted_vars
1410 .iter()
1411 .enumerate()
1412 .map(|(i, var)| (var.clone(), i))
1413 .collect();
1414
1415 let n = var_map.len();
1417
1418 let shape: Vec<usize> = vec![n; max_degree];
1420
1421 let mut tensor = Array::zeros(scirs2_core::ndarray::IxDyn(&shape));
1423
1424 for (vars, &coeff) in coeffs {
1426 let degree = vars.len();
1427
1428 if degree == 0 {
1429 continue;
1431 }
1432
1433 if degree > max_degree {
1434 return Err(CompileError::DegreeTooHigh(degree, max_degree));
1435 }
1436
1437 let mut indices: Vec<usize> = vars
1440 .iter()
1441 .map(|var| {
1442 *var_map
1443 .get(var)
1444 .expect("variable exists in var_map built from coeffs")
1445 })
1446 .collect();
1447
1448 indices.sort_unstable();
1450
1451 while indices.len() < max_degree {
1453 indices.insert(0, indices[0]); }
1455
1456 let idx = scirs2_core::ndarray::IxDyn(&indices);
1458 tensor[idx] += coeff;
1459 }
1460
1461 Ok((tensor, var_map))
1462}
1463
/// Alternative compiler entry point that currently delegates to [`Compile`].
#[cfg(feature = "dwave")]
pub struct PieckCompile {
    /// Expression to compile.
    expr: Expr,
    /// Verbosity flag; stored but not read by any code visible in this file.
    verbose: bool,
}
1475
#[cfg(feature = "dwave")]
impl PieckCompile {
    /// Wrap `expr` for compilation; `verbose` is recorded on the instance.
    pub fn new<T: Into<Expr>>(expr: T, verbose: bool) -> Self {
        let expr: Expr = expr.into();
        Self { expr, verbose }
    }

    /// Compile to QUBO form by delegating to [`Compile::get_qubo`] on a clone
    /// of the stored expression.
    ///
    /// # Errors
    /// Whatever [`Compile::get_qubo`] returns.
    pub fn get_qubo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let delegate = Compile::new(self.expr.clone());
        delegate.get_qubo()
    }
}