quantrs2_tytan/
compile.rs

1//! Compilation of symbolic expressions to QUBO/HOBO models.
2//!
3//! This module provides utilities for compiling symbolic expressions
4//! into QUBO (Quadratic Unconstrained Binary Optimization) and
5//! HOBO (Higher-Order Binary Optimization) models.
6
7#![allow(dead_code)]
8
9use scirs2_core::ndarray::Array;
10use std::collections::{HashMap, HashSet};
11
12#[cfg(feature = "scirs")]
13use crate::scirs_stub;
14
15#[cfg(feature = "dwave")]
16use quantrs2_symengine::Expression as SymEngineExpression;
17
/// Symbolic expression type used throughout the `dwave` build (SymEngine-backed).
#[cfg(feature = "dwave")]
type Expr = quantrs2_symengine::Expression;
20use thiserror::Error;
21
22use quantrs2_anneal::QuboError;
23
/// Unified expression interface for examples
#[cfg(feature = "dwave")]
pub mod expr {
    /// Expression type handed out to examples when `dwave` is enabled.
    pub type Expr = quantrs2_symengine::Expression;

    /// Builds a numeric constant expression holding `value`.
    pub fn constant(value: f64) -> Expr {
        Expr::from_f64(value)
    }

    /// Builds a symbolic variable called `name`.
    pub fn var(name: &str) -> Expr {
        Expr::symbol(name)
    }
}
39
40#[cfg(not(feature = "dwave"))]
41pub mod expr {
42    use super::SimpleExpr;
43
44    pub type Expr = SimpleExpr;
45
46    pub const fn constant(value: f64) -> Expr {
47        SimpleExpr::constant(value)
48    }
49
50    pub fn var(name: &str) -> Expr {
51        SimpleExpr::var(name)
52    }
53}
54
55/// Errors that can occur during compilation
56#[derive(Error, Debug)]
57pub enum CompileError {
58    /// Error when the expression is invalid
59    #[error("Invalid expression: {0}")]
60    InvalidExpression(String),
61
62    /// Error when a term has too high a degree
63    #[error("Term has degree {0}, but maximum supported is {1}")]
64    DegreeTooHigh(usize, usize),
65
66    /// Error in the underlying QUBO model
67    #[error("QUBO error: {0}")]
68    QuboError(#[from] QuboError),
69
70    /// Error in Symengine operations
71    #[error("Symengine error: {0}")]
72    SymengineError(String),
73}
74
75/// Result type for compilation operations
76pub type CompileResult<T> = Result<T, CompileError>;
77
// Minimal expression tree used in place of SymEngine when `dwave` is disabled.
#[cfg(not(feature = "dwave"))]
#[derive(Clone, Debug)]
pub enum SimpleExpr {
    /// A named variable.
    Var(String),
    /// A numeric literal.
    Const(f64),
    /// The sum of two sub-expressions.
    Add(Box<SimpleExpr>, Box<SimpleExpr>),
    /// The product of two sub-expressions.
    Mul(Box<SimpleExpr>, Box<SimpleExpr>),
    /// A sub-expression raised to an integer exponent.
    Pow(Box<SimpleExpr>, i32),
}
93
94#[cfg(not(feature = "dwave"))]
95impl SimpleExpr {
96    /// Create a variable
97    pub fn var(name: &str) -> Self {
98        Self::Var(name.to_string())
99    }
100
101    /// Create a constant
102    pub const fn constant(value: f64) -> Self {
103        Self::Const(value)
104    }
105}
106
107#[cfg(not(feature = "dwave"))]
108impl std::ops::Add for SimpleExpr {
109    type Output = Self;
110
111    fn add(self, rhs: Self) -> Self::Output {
112        Self::Add(Box::new(self), Box::new(rhs))
113    }
114}
115
116#[cfg(not(feature = "dwave"))]
117impl std::ops::Mul for SimpleExpr {
118    type Output = Self;
119
120    fn mul(self, rhs: Self) -> Self::Output {
121        Self::Mul(Box::new(self), Box::new(rhs))
122    }
123}
124
125#[cfg(not(feature = "dwave"))]
126impl std::iter::Sum for SimpleExpr {
127    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
128        iter.fold(Self::Const(0.0), |acc, x| acc + x)
129    }
130}
131
/// High-level model for constraint optimization problems
#[cfg(feature = "dwave")]
#[derive(Clone, Debug)]
pub struct Model {
    /// Names registered through `add_variable`.
    variables: HashSet<String>,
    /// Expression to minimize, if one has been set.
    objective: Option<Expr>,
    /// Constraints that `compile` turns into penalty terms.
    constraints: Vec<Constraint>,
}
143
/// Constraint types
#[cfg(feature = "dwave")]
#[derive(Clone, Debug)]
enum Constraint {
    /// `expr == value`; `name` is a user-supplied label.
    Equality {
        name: String,
        expr: Expr,
        value: f64,
    },
    /// `expr <= value`.
    ///
    /// NOTE: `Model::compile` currently penalizes this as `(expr - value)^2`,
    /// which also punishes `expr < value`; a one-sided penalty would need
    /// slack variables.
    LessEqual {
        name: String,
        expr: Expr,
        value: f64,
    },
    /// At most one of `variables` may be 1.
    AtMostOne { name: String, variables: Vec<Expr> },
    /// If any of `conditions` is 1, then `result` must be 1 as well.
    ImpliesAny {
        name: String,
        conditions: Vec<Expr>,
        result: Expr,
    },
}
169
#[cfg(feature = "dwave")]
impl Default for Model {
    /// Same as [`Model::new`]: an empty model.
    fn default() -> Self {
        Model::new()
    }
}
176
#[cfg(feature = "dwave")]
impl Model {
    /// Creates a model with no variables, no objective and no constraints.
    pub fn new() -> Self {
        let variables = HashSet::new();
        let constraints = Vec::new();
        Self {
            variables,
            objective: None,
            constraints,
        }
    }

    /// Records `name` as a model variable and returns the matching SymEngine symbol.
    ///
    /// # Errors
    ///
    /// This currently always succeeds.
    pub fn add_variable(&mut self, name: &str) -> CompileResult<Expr> {
        self.variables.insert(name.to_owned());
        Ok(SymEngineExpression::symbol(name))
    }

    /// Sets (or replaces) the expression to minimize.
    pub fn set_objective(&mut self, expr: Expr) {
        self.objective = Some(expr);
    }

    /// Requires exactly one of `variables` to be 1, stored as `sum(variables) == 1`.
    ///
    /// # Errors
    ///
    /// This currently always succeeds.
    pub fn add_constraint_eq_one(&mut self, name: &str, variables: Vec<Expr>) -> CompileResult<()> {
        // With binary variables, "the sum equals 1" is "exactly one is set".
        let mut total = Expr::from(0);
        for variable in &variables {
            total = total + variable.clone();
        }
        let constraint = Constraint::Equality {
            name: name.to_string(),
            expr: total,
            value: 1.0,
        };
        self.constraints.push(constraint);
        Ok(())
    }

    /// Allows at most one of `variables` to be 1.
    ///
    /// # Errors
    ///
    /// This currently always succeeds.
    pub fn add_constraint_at_most_one(
        &mut self,
        name: &str,
        variables: Vec<Expr>,
    ) -> CompileResult<()> {
        let constraint = Constraint::AtMostOne {
            name: name.to_string(),
            variables,
        };
        self.constraints.push(constraint);
        Ok(())
    }

    /// Requires `result` to be 1 whenever any of `conditions` is 1.
    ///
    /// # Errors
    ///
    /// This currently always succeeds.
    pub fn add_constraint_implies_any(
        &mut self,
        name: &str,
        conditions: Vec<Expr>,
        result: Expr,
    ) -> CompileResult<()> {
        let constraint = Constraint::ImpliesAny {
            name: name.to_string(),
            conditions,
            result,
        };
        self.constraints.push(constraint);
        Ok(())
    }

    /// Lowers the model to a [`CompiledModel`].
    ///
    /// Penalty terms are added to the objective (or to `0` when no objective was
    /// set), and the result is handed to [`Compile::get_qubo`]. Every penalty is
    /// scaled by a fixed weight `w = 10.0`:
    /// - `Equality` / `LessEqual`: `w * (expr - value)^2`. The inequality is
    ///   therefore enforced like an equality for now.
    /// - `AtMostOne`: `w * v_i * v_j` for every pair `i < j`.
    /// - `ImpliesAny`: `w * (sum of conditions) * (1 - result)`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Compile::get_qubo`].
    pub fn compile(&self) -> CompileResult<CompiledModel> {
        let penalty_weight = 10.0;

        let mut energy = match &self.objective {
            Some(objective) => objective.clone(),
            None => Expr::from(0),
        };

        for constraint in &self.constraints {
            match constraint {
                // Both variants currently share the same quadratic penalty.
                Constraint::Equality { expr, value, .. }
                | Constraint::LessEqual { expr, value, .. } => {
                    let gap = expr.clone() - Expr::from(*value);
                    energy = energy + Expr::from(penalty_weight) * gap.clone() * gap;
                }
                Constraint::AtMostOne { variables, .. } => {
                    for (i, first) in variables.iter().enumerate() {
                        for second in &variables[i + 1..] {
                            energy = energy
                                + Expr::from(penalty_weight) * first.clone() * second.clone();
                        }
                    }
                }
                Constraint::ImpliesAny {
                    conditions, result, ..
                } => {
                    // Non-zero only when some condition holds while `result` is 0.
                    let mut triggered = Expr::from(0);
                    for condition in conditions {
                        triggered = triggered + condition.clone();
                    }
                    energy = energy
                        + Expr::from(penalty_weight)
                            * triggered
                            * (Expr::from(1) - result.clone());
                }
            }
        }

        let ((qubo_matrix, var_map), offset) = Compile::new(energy).get_qubo()?;

        Ok(CompiledModel {
            qubo_matrix,
            var_map,
            offset,
            constraints: self.constraints.clone(),
        })
    }
}
304
/// Compiled model ready for sampling
#[cfg(feature = "dwave")]
#[derive(Clone, Debug)]
pub struct CompiledModel {
    /// Dense QUBO coefficient matrix, indexed through `var_map`.
    pub qubo_matrix: Array<f64, scirs2_core::ndarray::Ix2>,
    /// Maps each variable name to its row/column in `qubo_matrix`.
    pub var_map: HashMap<String, usize>,
    /// Constant energy term not captured by the matrix.
    pub offset: f64,
    /// The model's constraints, kept for later analysis.
    constraints: Vec<Constraint>,
}
318
#[cfg(feature = "dwave")]
impl CompiledModel {
    /// Builds a `quantrs2_anneal::ising::QuboModel` from the compiled matrix.
    ///
    /// Only entries with `j >= i` are read:
    /// - Diagonal entries become linear biases.
    /// - Entries above the diagonal become couplings.
    /// - Entries whose magnitude is not above `1e-10` are skipped.
    ///
    /// NOTE(review): entries below the diagonal are ignored, so this assumes
    /// `build_qubo_matrix` stores each pairwise coefficient entirely in the upper
    /// triangle. Confirm, because a symmetric split would halve every coupling.
    ///
    /// # Panics
    ///
    /// Panics if setting a coefficient fails. That is presumably only possible
    /// when the matrix is larger than `var_map.len()`.
    pub fn to_qubo(&self) -> quantrs2_anneal::ising::QuboModel {
        let mut model = quantrs2_anneal::ising::QuboModel::new(self.var_map.len());
        model.offset = self.offset;

        let rows = self.qubo_matrix.nrows();
        let cols = self.qubo_matrix.ncols();
        for row in 0..rows {
            for col in row..cols {
                let coefficient = self.qubo_matrix[[row, col]];
                if coefficient.abs() > 1e-10 {
                    if row == col {
                        // SAFETY: `row` comes from the matrix dimensions, which match the model size.
                        model
                            .set_linear(row, coefficient)
                            .expect("index within bounds from matrix dimensions");
                    } else {
                        // SAFETY: `row`/`col` come from the matrix dimensions, which match the model size.
                        model
                            .set_quadratic(row, col, coefficient)
                            .expect("indices within bounds from matrix dimensions");
                    }
                }
            }
        }

        model
    }
}
353
354/// High-level model for constraint optimization problems (non-dwave version)
355#[cfg(not(feature = "dwave"))]
356#[derive(Debug, Clone)]
357pub struct Model {
358    /// Variables in the model
359    variables: HashSet<String>,
360    /// Objective function expression
361    objective: Option<SimpleExpr>,
362    /// Constraints
363    constraints: Vec<Constraint>,
364}
365
366/// Constraint types (non-dwave version)
367#[cfg(not(feature = "dwave"))]
368#[derive(Debug, Clone)]
369enum Constraint {
370    /// Equality constraint: sum of variables equals value
371    Equality {
372        name: String,
373        expr: SimpleExpr,
374        value: f64,
375    },
376    /// At most one constraint: at most one variable can be 1
377    AtMostOne {
378        name: String,
379        variables: Vec<SimpleExpr>,
380    },
381    /// Implication constraint: if any condition is true, then result must be true
382    ImpliesAny {
383        name: String,
384        conditions: Vec<SimpleExpr>,
385        result: SimpleExpr,
386    },
387}
388
389#[cfg(not(feature = "dwave"))]
390impl Default for Model {
391    fn default() -> Self {
392        Self::new()
393    }
394}
395
396#[cfg(not(feature = "dwave"))]
397impl Model {
398    /// Create a new empty model
399    pub fn new() -> Self {
400        Self {
401            variables: HashSet::new(),
402            objective: None,
403            constraints: Vec::new(),
404        }
405    }
406
407    /// Add a variable to the model
408    pub fn add_variable(&mut self, name: &str) -> CompileResult<SimpleExpr> {
409        self.variables.insert(name.to_string());
410        Ok(SimpleExpr::var(name))
411    }
412
413    /// Set the objective function
414    pub fn set_objective(&mut self, expr: SimpleExpr) {
415        self.objective = Some(expr);
416    }
417
418    /// Add constraint: exactly one of the variables must be 1
419    pub fn add_constraint_eq_one(
420        &mut self,
421        name: &str,
422        variables: Vec<SimpleExpr>,
423    ) -> CompileResult<()> {
424        let sum_expr = variables.into_iter().sum();
425        self.constraints.push(Constraint::Equality {
426            name: name.to_string(),
427            expr: sum_expr,
428            value: 1.0,
429        });
430        Ok(())
431    }
432
433    /// Add constraint: at most one of the variables can be 1
434    pub fn add_constraint_at_most_one(
435        &mut self,
436        name: &str,
437        variables: Vec<SimpleExpr>,
438    ) -> CompileResult<()> {
439        self.constraints.push(Constraint::AtMostOne {
440            name: name.to_string(),
441            variables,
442        });
443        Ok(())
444    }
445
446    /// Add constraint: if any condition is true, then result must be true
447    pub fn add_constraint_implies_any(
448        &mut self,
449        name: &str,
450        conditions: Vec<SimpleExpr>,
451        result: SimpleExpr,
452    ) -> CompileResult<()> {
453        self.constraints.push(Constraint::ImpliesAny {
454            name: name.to_string(),
455            conditions,
456            result,
457        });
458        Ok(())
459    }
460
461    /// Compile the model to a CompiledModel
462    pub fn compile(&self) -> CompileResult<CompiledModel> {
463        // Build QUBO directly from constraints
464        let mut qubo_terms: HashMap<(String, String), f64> = HashMap::new();
465        let mut offset = 0.0;
466        let penalty_weight = 10.0;
467
468        // Process objective if present
469        if let Some(ref obj) = self.objective {
470            self.add_expr_to_qubo(obj, 1.0, &mut qubo_terms, &mut offset)?;
471        }
472
473        // Process constraints
474        for constraint in &self.constraints {
475            match constraint {
476                Constraint::Equality { expr, value, .. } => {
477                    // (expr - value)^2 penalty
478                    // Expand: expr^2 - 2*expr*value + value^2
479                    self.add_expr_squared_to_qubo(
480                        expr,
481                        penalty_weight,
482                        &mut qubo_terms,
483                        &mut offset,
484                    )?;
485                    self.add_expr_to_qubo(
486                        expr,
487                        -2.0 * penalty_weight * value,
488                        &mut qubo_terms,
489                        &mut offset,
490                    )?;
491                    offset += penalty_weight * value * value;
492                }
493                Constraint::AtMostOne { variables, .. } => {
494                    // Penalty: sum(xi * xj) for all i < j
495                    for i in 0..variables.len() {
496                        for j in (i + 1)..variables.len() {
497                            if let (SimpleExpr::Var(vi), SimpleExpr::Var(vj)) =
498                                (&variables[i], &variables[j])
499                            {
500                                let key = if vi < vj {
501                                    (vi.clone(), vj.clone())
502                                } else {
503                                    (vj.clone(), vi.clone())
504                                };
505                                *qubo_terms.entry(key).or_insert(0.0) += penalty_weight;
506                            }
507                        }
508                    }
509                }
510                Constraint::ImpliesAny {
511                    conditions, result, ..
512                } => {
513                    // Penalty: sum(conditions) * (1 - result)
514                    for cond in conditions {
515                        if let (SimpleExpr::Var(c), SimpleExpr::Var(r)) = (cond, result) {
516                            let key = if c < r {
517                                (c.clone(), r.clone())
518                            } else {
519                                (r.clone(), c.clone())
520                            };
521                            *qubo_terms.entry(key).or_insert(0.0) -= penalty_weight;
522                        }
523                        // Also add linear term for condition
524                        if let SimpleExpr::Var(c) = cond {
525                            *qubo_terms.entry((c.clone(), c.clone())).or_insert(0.0) +=
526                                penalty_weight;
527                        }
528                    }
529                }
530            }
531        }
532
533        // Convert to matrix form
534        let all_vars: HashSet<String> = qubo_terms
535            .keys()
536            .flat_map(|(v1, v2)| vec![v1.clone(), v2.clone()])
537            .collect();
538        let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
539        sorted_vars.sort();
540
541        let var_map: HashMap<String, usize> = sorted_vars
542            .iter()
543            .enumerate()
544            .map(|(i, v)| (v.clone(), i))
545            .collect();
546
547        let n = var_map.len();
548        let mut matrix = Array::zeros((n, n));
549
550        for ((v1, v2), coeff) in qubo_terms {
551            let i = var_map[&v1];
552            let j = var_map[&v2];
553            if i == j {
554                matrix[[i, i]] += coeff;
555            } else {
556                matrix[[i, j]] += coeff / 2.0;
557                matrix[[j, i]] += coeff / 2.0;
558            }
559        }
560
561        Ok(CompiledModel {
562            qubo_matrix: matrix,
563            var_map,
564            offset,
565            constraints: self.constraints.clone(),
566        })
567    }
568
569    /// Add expression to QUBO terms
570    fn add_expr_to_qubo(
571        &self,
572        expr: &SimpleExpr,
573        coeff: f64,
574        terms: &mut HashMap<(String, String), f64>,
575        offset: &mut f64,
576    ) -> CompileResult<()> {
577        match expr {
578            SimpleExpr::Var(name) => {
579                *terms.entry((name.clone(), name.clone())).or_insert(0.0) += coeff;
580            }
581            SimpleExpr::Const(val) => {
582                *offset += coeff * val;
583            }
584            SimpleExpr::Add(left, right) => {
585                self.add_expr_to_qubo(left, coeff, terms, offset)?;
586                self.add_expr_to_qubo(right, coeff, terms, offset)?;
587            }
588            SimpleExpr::Mul(left, right) => {
589                if let (SimpleExpr::Var(v1), SimpleExpr::Var(v2)) = (left.as_ref(), right.as_ref())
590                {
591                    let key = if v1 < v2 {
592                        (v1.clone(), v2.clone())
593                    } else {
594                        (v2.clone(), v1.clone())
595                    };
596                    *terms.entry(key).or_insert(0.0) += coeff;
597                } else if let (SimpleExpr::Const(c), var) | (var, SimpleExpr::Const(c)) =
598                    (left.as_ref(), right.as_ref())
599                {
600                    self.add_expr_to_qubo(var, coeff * c, terms, offset)?;
601                }
602            }
603            SimpleExpr::Pow(base, exp) => {
604                if *exp == 2 && matches!(base.as_ref(), SimpleExpr::Var(_)) {
605                    // x^2 = x for binary variables
606                    self.add_expr_to_qubo(base, coeff, terms, offset)?;
607                }
608            }
609        }
610        Ok(())
611    }
612
613    /// Add expression squared to QUBO terms
614    fn add_expr_squared_to_qubo(
615        &self,
616        expr: &SimpleExpr,
617        coeff: f64,
618        terms: &mut HashMap<(String, String), f64>,
619        offset: &mut f64,
620    ) -> CompileResult<()> {
621        // For simplicity, only handle simple cases
622        match expr {
623            SimpleExpr::Var(name) => {
624                // x^2 = x for binary
625                *terms.entry((name.clone(), name.clone())).or_insert(0.0) += coeff;
626            }
627            SimpleExpr::Add(left, right) => {
628                // (a + b)^2 = a^2 + 2ab + b^2
629                self.add_expr_squared_to_qubo(left, coeff, terms, offset)?;
630                self.add_expr_squared_to_qubo(right, coeff, terms, offset)?;
631                // Cross term
632                if let (SimpleExpr::Var(v1), SimpleExpr::Var(v2)) = (left.as_ref(), right.as_ref())
633                {
634                    let key = if v1 < v2 {
635                        (v1.clone(), v2.clone())
636                    } else {
637                        (v2.clone(), v1.clone())
638                    };
639                    *terms.entry(key).or_insert(0.0) += 2.0 * coeff;
640                }
641            }
642            _ => {}
643        }
644        Ok(())
645    }
646}
647
648/// Compiled model ready for sampling (non-dwave version)
649#[cfg(not(feature = "dwave"))]
650#[derive(Debug, Clone)]
651pub struct CompiledModel {
652    /// QUBO matrix
653    pub qubo_matrix: Array<f64, scirs2_core::ndarray::Ix2>,
654    /// Variable name to index mapping
655    pub var_map: HashMap<String, usize>,
656    /// Constant offset
657    pub offset: f64,
658    /// Original constraints (for analysis)
659    constraints: Vec<Constraint>,
660}
661
662#[cfg(not(feature = "dwave"))]
663impl CompiledModel {
664    /// Convert to QUBO format
665    pub fn to_qubo(&self) -> quantrs2_anneal::ising::QuboModel {
666        use quantrs2_anneal::ising::QuboModel;
667
668        let mut qubo = QuboModel::new(self.var_map.len());
669
670        // Set the offset
671        qubo.offset = self.offset;
672
673        // Set all the QUBO coefficients
674        for i in 0..self.qubo_matrix.nrows() {
675            for j in i..self.qubo_matrix.ncols() {
676                let value = self.qubo_matrix[[i, j]];
677                if value.abs() > 1e-10 {
678                    if i == j {
679                        // Diagonal term (linear)
680                        // SAFETY: index i is derived from matrix dimensions which match QuboModel size
681                        qubo.set_linear(i, value)
682                            .expect("index within bounds from matrix dimensions");
683                    } else {
684                        // Off-diagonal term (quadratic)
685                        // SAFETY: indices i,j are derived from matrix dimensions which match QuboModel size
686                        qubo.set_quadratic(i, j, value)
687                            .expect("indices within bounds from matrix dimensions");
688                    }
689                }
690            }
691        }
692
693        qubo
694    }
695}
696
/// Compiler for converting symbolic expressions to QUBO models
///
/// Wraps a single symbolic expression. `get_qubo` and `get_hobo` expand it
/// and lower it to a QUBO matrix or a HOBO tensor, respectively.
#[cfg(feature = "dwave")]
pub struct Compile {
    /// The expression that will be lowered.
    expr: Expr,
}
706
#[cfg(feature = "dwave")]
impl Compile {
    /// Create a new compiler for `expr` (anything convertible into `Expr`).
    pub fn new<T: Into<Expr>>(expr: T) -> Self {
        Self { expr: expr.into() }
    }

    /// Compile the expression to a QUBO model
    ///
    /// This method compiles the symbolic expression to a QUBO model,
    /// which can then be passed to a sampler for solving.
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// - A tuple with the QUBO matrix and a mapping of variable names to indices
    /// - An offset value that should be added to all energy values
    ///
    /// # Errors
    ///
    /// Returns [`CompileError::DegreeTooHigh`] when a term still has degree
    /// greater than 2 after the binary reduction `x^2 -> x`. Errors from the
    /// expansion, coefficient-extraction and matrix-building helpers are
    /// propagated.
    pub fn get_qubo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        #[cfg(feature = "scirs")]
        {
            self.get_qubo_scirs()
        }
        #[cfg(not(feature = "scirs"))]
        {
            self.get_qubo_standard()
        }
    }

    /// Standard QUBO compilation without SciRS2
    fn get_qubo_standard(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        // Expand so every term is a flat product of symbols and powers.
        let expanded = self.expr.expand();

        // Binary variables satisfy x^2 = x. Reduce *before* checking the degree,
        // so that e.g. `x^2 * y` is accepted as the quadratic term `x * y`.
        // Re-expand afterwards to collect like terms.
        let reduced = replace_squared_terms(&expanded)?.expand();

        let max_degree = calc_highest_degree(&reduced)?;
        if max_degree > 2 {
            return Err(CompileError::DegreeTooHigh(max_degree, 2));
        }

        let (coeffs, offset) = extract_coefficients(&reduced)?;
        let (matrix, var_map) = build_qubo_matrix(&coeffs)?;

        Ok(((matrix, var_map), offset))
    }

    /// QUBO compilation with SciRS2 optimization
    #[cfg(feature = "scirs")]
    fn get_qubo_scirs(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        // Start from the standard result, then apply the SciRS2 matrix enhancement.
        let ((matrix, var_map), offset) = self.get_qubo_standard()?;
        let enhanced_matrix = crate::scirs_stub::enhance_qubo_matrix(&matrix);

        Ok(((enhanced_matrix, var_map), offset))
    }

    /// Compile the expression to a HOBO model
    ///
    /// This method compiles the symbolic expression to a Higher-Order Binary Optimization model,
    /// which can handle terms of degree higher than 2.
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// - A tuple with the HOBO tensor and a mapping of variable names to indices
    /// - An offset value that should be added to all energy values
    ///
    /// # Errors
    ///
    /// Propagates errors from the degree calculation and the expansion,
    /// coefficient-extraction and tensor-building helpers.
    pub fn get_hobo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::IxDyn>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let expanded = self.expr.expand();

        // The tensor rank is taken from the degree *before* the x^2 -> x reduction.
        let max_degree = calc_highest_degree(&expanded)?;

        // Replace squared terms with linear ones, then re-expand to collect like terms.
        let reduced = replace_squared_terms(&expanded)?.expand();

        let (coeffs, offset) = extract_coefficients(&reduced)?;
        let (tensor, var_map) = build_hobo_tensor(&coeffs, max_degree)?;

        Ok(((tensor, var_map), offset))
    }
}
837
// Helper function to calculate the highest degree in the expression
#[cfg(feature = "dwave")]
fn calc_highest_degree(expr: &Expr) -> CompileResult<usize> {
    // A single variable has degree 1.
    if expr.is_symbol() {
        return Ok(1);
    }

    // A numeric constant has degree 0.
    if expr.is_number() {
        return Ok(0);
    }

    // Power operation such as x^2.
    if expr.is_pow() {
        // SAFETY: is_pow() check guarantees as_pow() will succeed
        let (base, exp) = expr.as_pow().expect("is_pow() was true");

        // Symbol raised to a numeric exponent.
        if base.is_symbol() && exp.is_number() {
            let exp_val = match exp.to_f64() {
                Some(n) => n,
                None => {
                    return Err(CompileError::InvalidExpression(
                        "Invalid exponent".to_string(),
                    ))
                }
            };

            // Non-negative integer exponents give the degree directly.
            if exp_val.is_sign_positive() && exp_val.fract() == 0.0 {
                return Ok(exp_val as usize);
            }
        }

        // Other powers: degree(base) * exponent. Non-integer, negative or
        // symbolic exponents contribute 0.
        let base_degree = calc_highest_degree(&base)?;
        let exp_degree = if exp.is_number() {
            match exp.to_f64() {
                Some(n) if n.is_sign_positive() && n.fract() == 0.0 => n as usize,
                _ => 0,
            }
        } else {
            0
        };

        return Ok(base_degree * exp_degree);
    }

    // Product such as x*y or x*x: degrees add.
    if expr.is_mul() {
        let mut total_degree = 0;
        // SAFETY: is_mul() check guarantees as_mul() will succeed
        for factor in expr.as_mul().expect("is_mul() was true") {
            total_degree += calc_highest_degree(&factor)?;
        }
        return Ok(total_degree);
    }

    // Sum such as x + y: the maximum term degree.
    if expr.is_add() {
        let mut max_degree = 0;
        // SAFETY: is_add() check guarantees as_add() will succeed
        for term in expr.as_add().expect("is_add() was true") {
            max_degree = std::cmp::max(max_degree, calc_highest_degree(&term)?);
        }
        return Ok(max_degree);
    }

    // Workaround for symengine type-detection issues: if the printed form looks
    // like a sum, split it into terms and take the largest per-term degree.
    let expr_str = format!("{expr}");
    if expr_str.contains('+') || expr_str.contains('-') {
        let max_degree = expr_str
            .split(['+', '-'])
            .map(term_degree_from_str)
            .max()
            .unwrap_or(0);
        return Ok(max_degree);
    }

    // The expression form is not recognized.
    Err(CompileError::InvalidExpression(format!(
        "Can't determine degree of: {expr}"
    )))
}

/// Estimate the polynomial degree of one printed product term such as `2*x**2*y`.
///
/// `**` is normalized to `^`, and factors are split on `*` and summed:
/// - Numeric factors contribute 0.
/// - Bare symbols contribute 1.
/// - `base^k` contributes `k`, or 2 when `k` is not a non-negative integer
///   literal (the previous heuristic's default).
/// - A power whose base is numeric contributes 0.
///
/// Empty input yields 0.
#[cfg(any(feature = "dwave", test))]
fn term_degree_from_str(term: &str) -> usize {
    let normalized = term.replace("**", "^");
    normalized
        .split('*')
        .map(str::trim)
        .filter(|factor| !factor.is_empty())
        .map(|factor| match factor.split_once('^') {
            Some((base, _)) if base.trim().parse::<f64>().is_ok() => 0,
            Some((_, exp)) => exp.trim().parse::<usize>().unwrap_or(2),
            None => usize::from(factor.parse::<f64>().is_err()),
        })
        .sum()
}
974
// Helper function to replace powers of binary variables with the variables themselves
#[cfg(feature = "dwave")]
fn replace_squared_terms(expr: &Expr) -> CompileResult<Expr> {
    // For binary variables x ∈ {0,1}, x^n = x for every positive integer n,
    // so any such power of a symbol collapses to the symbol itself. Other
    // powers (fractional, negative, or of compound bases) are rebuilt around a
    // recursively simplified base.

    // Leaves are already in simplest form.
    if expr.is_symbol() || expr.is_number() {
        return Ok(expr.clone());
    }

    if expr.is_pow() {
        // SAFETY: is_pow() check guarantees as_pow() will succeed
        let (base, exp) = expr.as_pow().expect("is_pow() was true");

        if base.is_symbol() && exp.is_number() {
            let exp_val = exp
                .to_f64()
                .ok_or_else(|| CompileError::InvalidExpression("Invalid exponent".to_string()))?;

            // Any positive integer exponent (x^2, x^3, ...) collapses to x.
            // `fract()` of a non-finite value is NaN, so inf/NaN never match.
            if exp_val >= 1.0 && exp_val.fract() == 0.0 {
                return Ok(base);
            }
        }

        // Non-collapsible power: simplify inside the base, keep the exponent.
        let new_base = replace_squared_terms(&base)?;
        return Ok(new_base.pow(&exp));
    }

    // Products: simplify every factor and multiply them back together.
    if expr.is_mul() {
        let mut result = Expr::from(1);
        // SAFETY: is_mul() check guarantees as_mul() will succeed
        for factor in expr.as_mul().expect("is_mul() was true") {
            result = result * replace_squared_terms(&factor)?;
        }
        return Ok(result);
    }

    // Sums: simplify every summand and add them back together.
    if expr.is_add() {
        let mut result = Expr::from(0);
        // SAFETY: is_add() check guarantees as_add() will succeed
        for term in expr.as_add().expect("is_add() was true") {
            result = result + replace_squared_terms(&term)?;
        }
        return Ok(result);
    }

    // Any other expression kind is returned unchanged.
    Ok(expr.clone())
}
1047
// Helper function to extract coefficients and variables from the expression
//
// Returns a map from sorted variable-name lists (one entry per monomial) to
// their coefficients, plus the constant offset.
#[cfg(feature = "dwave")]
fn extract_coefficients(expr: &Expr) -> CompileResult<(HashMap<Vec<String>, f64>, f64)> {
    // Structured sum: delegate each summand to the symbolic term extractor.
    if expr.is_add() {
        let mut coeffs = HashMap::new();
        let mut offset = 0.0;
        // SAFETY: is_add() check guarantees as_add() will succeed
        for term in expr.as_add().expect("is_add() was true") {
            let (term_coeffs, term_offset) = extract_term_coefficients(&term)?;
            for (vars, coeff) in term_coeffs {
                *coeffs.entry(vars).or_insert(0.0) += coeff;
            }
            offset += term_offset;
        }
        return Ok((coeffs, offset));
    }

    // Workaround for symengine type detection issues: a sum-like expression
    // that was not reported as ADD is parsed from its printed form instead.
    let expr_str = format!("{expr}");
    if expr_str.contains('+') || expr_str.contains('-') {
        let mut coeffs = HashMap::new();
        let mut offset = 0.0;
        for (sign, term) in split_top_level_terms(&expr_str) {
            let (coeff, vars) = parse_monomial_term(term);
            if vars.is_empty() {
                // Purely numeric term (including products like `2*3`).
                offset += sign * coeff;
            } else {
                *coeffs.entry(vars).or_insert(0.0) += sign * coeff;
            }
        }
        return Ok((coeffs, offset));
    }

    // Otherwise the whole expression is a single term.
    extract_term_coefficients(expr)
}

/// Splits a printed expression such as `x**2 + 2*x*y - 3` into `(sign, term)`
/// pairs, e.g. `[(1.0, "x**2"), (1.0, "2*x*y"), (-1.0, "3")]`.
///
/// A `+`/`-` acts as a separator only when it is a binary operator or a leading
/// sign. Signs of an operand (`2*-x`, `x**(-1)`) and of a float exponent
/// (`1.5e-3`) stay inside their term. Consecutive signs compose (`a - -b`).
fn split_top_level_terms(expr_str: &str) -> Vec<(f64, &str)> {
    let mut terms = Vec::new();
    let mut sign = 1.0;
    let mut start = 0;

    for (i, c) in expr_str.char_indices() {
        if c != '+' && c != '-' {
            continue;
        }
        let pending = &expr_str[start..i];
        if sign_is_unary(pending) {
            continue;
        }
        let term = pending.trim();
        if !term.is_empty() {
            terms.push((sign, term));
            sign = 1.0;
        }
        if c == '-' {
            sign = -sign;
        }
        start = i + c.len_utf8();
    }

    let tail = expr_str[start..].trim();
    if !tail.is_empty() {
        terms.push((sign, tail));
    }
    terms
}

/// Returns true when a `+`/`-` that follows `pending` belongs to an operand
/// (after `*`, `^`, `/`, `(`) or to a float exponent such as `1.5e-3`.
fn sign_is_unary(pending: &str) -> bool {
    if pending
        .trim_end()
        .ends_with(|c: char| matches!(c, '*' | '^' | '/' | '('))
    {
        return true;
    }

    // A float exponent requires the `e`/`E` to touch the sign directly.
    let body = match pending.strip_suffix(|c: char| c == 'e' || c == 'E') {
        Some(body) => body,
        None => return false,
    };
    // The mantissa characters are ASCII, so the char count equals the byte length.
    let mantissa_len = body
        .chars()
        .rev()
        .take_while(|ch| ch.is_ascii_digit() || *ch == '.')
        .count();
    let (prefix, mantissa) = body.split_at(body.len() - mantissa_len);
    // `x1e-y` is an identifier followed by a minus, not a float literal.
    let glued_to_name = prefix
        .chars()
        .next_back()
        .map_or(false, |ch| ch.is_alphanumeric() || ch == '_');
    !glued_to_name && mantissa.bytes().any(|b| b.is_ascii_digit())
}

/// Parses one printed product term (e.g. `3*x*y**2`) into its numeric
/// coefficient and its sorted, de-duplicated variable names.
///
/// Binary semantics apply: `v**n` with a positive integer `n` is `v`, `v**0`
/// is `1`, and a repeated variable counts once. Numeric factors (including
/// numeric powers) multiply into the coefficient. Any other factor is kept
/// verbatim as an opaque name so no information is silently dropped.
fn parse_monomial_term(term: &str) -> (f64, Vec<String>) {
    let normalized = term.replace("**", "^");
    let mut coeff = 1.0;
    let mut vars = Vec::new();

    for raw in normalized.split('*') {
        let mut factor = raw.trim();
        // Unary signs left attached by the term splitter, e.g. the `-` in `2*-x`.
        while let Some(rest) = factor.strip_prefix(|c: char| c == '-' || c == '+') {
            if factor.starts_with('-') {
                coeff = -coeff;
            }
            factor = rest.trim_start();
        }
        if factor.is_empty() {
            continue;
        }

        let (base, exponent) = match factor.split_once('^') {
            Some((base, exponent)) => (base.trim(), Some(exponent.trim())),
            None => (factor, None),
        };

        match (base.parse::<f64>(), exponent.map(str::parse::<f64>)) {
            (Ok(num), None) => coeff *= num,
            (Ok(num), Some(Ok(power))) => coeff *= num.powf(power),
            (Err(_), None) => vars.push(base.to_string()),
            // v^0 == 1 contributes nothing.
            (Err(_), Some(Ok(power))) if power == 0.0 => {}
            // v^n == v for binary v and positive integer n.
            (Err(_), Some(Ok(power))) if power >= 1.0 && power.fract() == 0.0 => {
                vars.push(base.to_string())
            }
            _ => vars.push(factor.to_string()),
        }
    }

    // Canonical key: sorted, and x*x == x for binary variables.
    vars.sort();
    vars.dedup();
    (coeff, vars)
}
1165
// Helper function to extract coefficient and variables from a single term
//
// Returns `(coeffs, offset)`: a numeric term goes entirely into `offset`; a
// symbol or product of symbols yields one `coeffs` entry keyed by its sorted,
// de-duplicated variable names.
#[cfg(feature = "dwave")]
fn extract_term_coefficients(term: &Expr) -> CompileResult<(HashMap<Vec<String>, f64>, f64)> {
    let mut coeffs = HashMap::new();

    // A number constant is pure offset.
    if term.is_number() {
        let value = term
            .to_f64()
            .ok_or_else(|| CompileError::InvalidExpression("Invalid number".to_string()))?;
        return Ok((coeffs, value));
    }

    // A symbol is a linear term with coefficient 1.
    if term.is_symbol() {
        // SAFETY: is_symbol() check guarantees as_symbol() will succeed
        let var_name = term.as_symbol().expect("is_symbol() was true");
        coeffs.insert(vec![var_name], 1.0);
        return Ok((coeffs, 0.0));
    }

    // A product: numeric factors form the coefficient, symbols the monomial.
    if term.is_mul() {
        let mut coeff = 1.0;
        let mut vars = Vec::new();

        // SAFETY: is_mul() check guarantees as_mul() will succeed
        for factor in term.as_mul().expect("is_mul() was true") {
            if factor.is_number() {
                coeff *= factor.to_f64().ok_or_else(|| {
                    CompileError::InvalidExpression("Invalid number in product".to_string())
                })?;
            } else if factor.is_symbol() {
                // SAFETY: is_symbol() check guarantees as_symbol() will succeed
                vars.push(factor.as_symbol().expect("is_symbol() was true"));
            } else if let Some(var_name) = binary_power_symbol(&factor) {
                // x^n == x for binary x and positive integer n.
                vars.push(var_name);
            } else {
                return Err(CompileError::InvalidExpression(format!(
                    "Unsupported term in product: {factor}"
                )));
            }
        }

        // Canonical key: sorted, and x*x == x for binary variables.
        vars.sort();
        vars.dedup();

        if vars.is_empty() {
            // No variables: the product is a constant.
            return Ok((coeffs, coeff));
        }
        coeffs.insert(vars, coeff);
        return Ok((coeffs, 0.0));
    }

    // A power should normally have been simplified earlier; a positive
    // integer power of a symbol is still accepted as the symbol itself.
    if term.is_pow() {
        if let Some(var_name) = binary_power_symbol(term) {
            coeffs.insert(vec![var_name], 1.0);
            return Ok((coeffs, 0.0));
        }
        return Err(CompileError::InvalidExpression(format!(
            "Unexpected power term after simplification: {term}"
        )));
    }

    // Unsupported term type
    Err(CompileError::InvalidExpression(format!(
        "Unsupported term: {term}"
    )))
}

/// If `expr` is `s^n` with `s` a symbol and `n` a positive integer, returns
/// the name of `s` (for binary variables `s^n == s`); otherwise `None`.
#[cfg(feature = "dwave")]
fn binary_power_symbol(expr: &Expr) -> Option<String> {
    if !expr.is_pow() {
        return None;
    }
    // SAFETY: is_pow() check guarantees as_pow() will succeed
    let (base, exp) = expr.as_pow().expect("is_pow() was true");
    if !(base.is_symbol() && exp.is_number()) {
        return None;
    }
    let n = exp.to_f64()?;
    if n >= 1.0 && n.fract() == 0.0 {
        // SAFETY: is_symbol() check guarantees as_symbol() will succeed
        Some(base.as_symbol().expect("is_symbol() was true"))
    } else {
        None
    }
}
1248
1249// Helper function to build the QUBO matrix
1250#[allow(dead_code)]
1251fn build_qubo_matrix(
1252    coeffs: &HashMap<Vec<String>, f64>,
1253) -> CompileResult<(
1254    Array<f64, scirs2_core::ndarray::Ix2>,
1255    HashMap<String, usize>,
1256)> {
1257    // Collect all unique variable names
1258    let mut all_vars = HashSet::new();
1259    for vars in coeffs.keys() {
1260        for var in vars {
1261            all_vars.insert(var.clone());
1262        }
1263    }
1264
1265    // Convert to a sorted vector
1266    let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
1267    sorted_vars.sort();
1268
1269    // Create the variable-to-index mapping
1270    let var_map: HashMap<String, usize> = sorted_vars
1271        .iter()
1272        .enumerate()
1273        .map(|(i, var)| (var.clone(), i))
1274        .collect();
1275
1276    // Size of the matrix
1277    let n = var_map.len();
1278
1279    // Create an empty matrix
1280    let mut matrix = Array::zeros((n, n));
1281
1282    // Fill the matrix with coefficients
1283    for (vars, &coeff) in coeffs {
1284        match vars.len() {
1285            0 => {
1286                // Should never happen since constants are handled in offset
1287            }
1288            1 => {
1289                // Linear term: var * coeff
1290                // SAFETY: var_map was built from the same variables in coeffs
1291                let i = *var_map
1292                    .get(&vars[0])
1293                    .expect("variable exists in var_map built from coeffs");
1294                matrix[[i, i]] += coeff;
1295            }
1296            2 => {
1297                // Quadratic term: var1 * var2 * coeff
1298                // SAFETY: var_map was built from the same variables in coeffs
1299                let i = *var_map
1300                    .get(&vars[0])
1301                    .expect("variable exists in var_map built from coeffs");
1302                let j = *var_map
1303                    .get(&vars[1])
1304                    .expect("variable exists in var_map built from coeffs");
1305
1306                // QUBO format requires i <= j
1307                if i == j {
1308                    // Diagonal term
1309                    matrix[[i, i]] += coeff;
1310                } else {
1311                    // Off-diagonal term - store full coefficient in upper triangular, zero in lower
1312                    if i <= j {
1313                        matrix[[i, j]] += coeff;
1314                    } else {
1315                        matrix[[j, i]] += coeff;
1316                    }
1317                }
1318            }
1319            _ => {
1320                // Higher-order terms are not supported in QUBO
1321                return Err(CompileError::DegreeTooHigh(vars.len(), 2));
1322            }
1323        }
1324    }
1325
1326    Ok((matrix, var_map))
1327}
1328
1329// Helper function to build the HOBO tensor
1330#[allow(dead_code)]
1331fn build_hobo_tensor(
1332    coeffs: &HashMap<Vec<String>, f64>,
1333    max_degree: usize,
1334) -> CompileResult<(
1335    Array<f64, scirs2_core::ndarray::IxDyn>,
1336    HashMap<String, usize>,
1337)> {
1338    // Collect all unique variable names
1339    let mut all_vars = HashSet::new();
1340    for vars in coeffs.keys() {
1341        for var in vars {
1342            all_vars.insert(var.clone());
1343        }
1344    }
1345
1346    // Convert to a sorted vector
1347    let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
1348    sorted_vars.sort();
1349
1350    // Create the variable-to-index mapping
1351    let var_map: HashMap<String, usize> = sorted_vars
1352        .iter()
1353        .enumerate()
1354        .map(|(i, var)| (var.clone(), i))
1355        .collect();
1356
1357    // Size of each dimension
1358    let n = var_map.len();
1359
1360    // Create shape vector for the tensor
1361    let shape: Vec<usize> = vec![n; max_degree];
1362
1363    // Create an empty tensor
1364    let mut tensor = Array::zeros(scirs2_core::ndarray::IxDyn(&shape));
1365
1366    // Fill the tensor with coefficients
1367    for (vars, &coeff) in coeffs {
1368        let degree = vars.len();
1369
1370        if degree == 0 {
1371            // Should never happen since constants are handled in offset
1372            continue;
1373        }
1374
1375        if degree > max_degree {
1376            return Err(CompileError::DegreeTooHigh(degree, max_degree));
1377        }
1378
1379        // Convert variable names to indices
1380        // SAFETY: var_map was built from the same variables in coeffs
1381        let mut indices: Vec<usize> = vars
1382            .iter()
1383            .map(|var| {
1384                *var_map
1385                    .get(var)
1386                    .expect("variable exists in var_map built from coeffs")
1387            })
1388            .collect();
1389
1390        // Sort indices (canonical ordering)
1391        indices.sort_unstable();
1392
1393        // Pad indices to match tensor order if necessary
1394        while indices.len() < max_degree {
1395            indices.insert(0, indices[0]); // Padding with first index
1396        }
1397
1398        // Set the coefficient in the tensor
1399        let idx = scirs2_core::ndarray::IxDyn(&indices);
1400        tensor[idx] += coeff;
1401    }
1402
1403    Ok((tensor, var_map))
1404}
1405
/// Special compiler for problems with one-hot constraints
///
/// This is a specialized compiler intended for problems
/// with one-hot constraints, common in many optimization problems.
///
/// NOTE(review): `get_qubo` currently delegates to the generic `Compile`
/// pipeline, so no one-hot-specific optimization is applied yet.
#[cfg(feature = "dwave")]
pub struct PieckCompile {
    /// The symbolic expression to compile
    expr: Expr,
    /// Whether to show verbose output (currently not consulted by `get_qubo`)
    verbose: bool,
}
1417
#[cfg(feature = "dwave")]
impl PieckCompile {
    /// Creates a Pieck compiler for `expr`.
    ///
    /// `expr` may be anything convertible into the symbolic expression type;
    /// `verbose` is stored for future diagnostic output.
    pub fn new<T: Into<Expr>>(expr: T, verbose: bool) -> Self {
        let expr: Expr = expr.into();
        Self { expr, verbose }
    }

    /// Compiles the stored expression to a QUBO model.
    ///
    /// Returns `((matrix, variable_index_map), offset)`. This currently runs
    /// the generic `Compile` pipeline on a clone of the stored expression; the
    /// one-hot specialization is not implemented yet and `verbose` is not read.
    pub fn get_qubo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let delegate = Compile::new(self.expr.clone());
        delegate.get_qubo()
    }
}