Skip to main content

quantrs2_tytan/
compile.rs

1//! Compilation of symbolic expressions to QUBO/HOBO models.
2//!
3//! This module provides utilities for compiling symbolic expressions
4//! into QUBO (Quadratic Unconstrained Binary Optimization) and
5//! HOBO (Higher-Order Binary Optimization) models.
6
7#![allow(dead_code)]
8
9use scirs2_core::ndarray::Array;
10use std::collections::{HashMap, HashSet};
11
12#[cfg(feature = "scirs")]
13use crate::scirs_stub;
14
15#[cfg(feature = "dwave")]
16use quantrs2_symengine_pure::Expression as SymEngineExpression;
17
/// Symbolic expression type used by the `dwave`-enabled model builder and compiler.
#[cfg(feature = "dwave")]
type Expr = quantrs2_symengine_pure::Expression;
20use thiserror::Error;
21
22use quantrs2_anneal::QuboError;
23
/// Feature-independent expression constructors, intended for examples.
///
/// With the `dwave` feature enabled, expressions are backed by
/// `quantrs2_symengine_pure::Expression`.
#[cfg(feature = "dwave")]
pub mod expr {
    use quantrs2_symengine_pure::Expression as SymEngineExpression;

    /// Expression type produced by [`constant`] and [`var`].
    pub type Expr = SymEngineExpression;

    /// Builds a numeric constant expression holding `value`.
    pub fn constant(value: f64) -> Expr {
        Expr::from(value)
    }

    /// Builds a symbol expression named `name`.
    pub fn var(name: &str) -> Expr {
        Expr::symbol(name)
    }
}
39
40#[cfg(not(feature = "dwave"))]
41pub mod expr {
42    use super::SimpleExpr;
43
44    pub type Expr = SimpleExpr;
45
46    pub const fn constant(value: f64) -> Expr {
47        SimpleExpr::constant(value)
48    }
49
50    pub fn var(name: &str) -> Expr {
51        SimpleExpr::var(name)
52    }
53}
54
55/// Errors that can occur during compilation
56#[derive(Error, Debug)]
57#[non_exhaustive]
58pub enum CompileError {
59    /// Error when the expression is invalid
60    #[error("Invalid expression: {0}")]
61    InvalidExpression(String),
62
63    /// Error when a term has too high a degree
64    #[error("Term has degree {0}, but maximum supported is {1}")]
65    DegreeTooHigh(usize, usize),
66
67    /// Error in the underlying QUBO model
68    #[error("QUBO error: {0}")]
69    QuboError(#[from] QuboError),
70
71    /// Error in Symengine operations
72    #[error("Symengine error: {0}")]
73    SymengineError(String),
74}
75
76/// Result type for compilation operations
77pub type CompileResult<T> = Result<T, CompileError>;
78
/// Minimal symbolic expression tree, used when the `dwave` feature is disabled.
///
/// Build expressions with [`SimpleExpr::var`], [`SimpleExpr::constant`] and the `+` / `*`
/// operators. `Iterator::sum` folds a sequence of expressions with `+`.
#[cfg(not(feature = "dwave"))]
#[derive(Debug, Clone, PartialEq)]
pub enum SimpleExpr {
    /// A named variable.
    Var(String),
    /// A numeric constant.
    Const(f64),
    /// Sum of two sub-expressions.
    Add(Box<Self>, Box<Self>),
    /// Product of two sub-expressions.
    Mul(Box<Self>, Box<Self>),
    /// A sub-expression raised to an integer power.
    Pow(Box<Self>, i32),
}

#[cfg(not(feature = "dwave"))]
impl SimpleExpr {
    /// Creates a variable expression named `name`.
    pub fn var(name: &str) -> Self {
        Self::Var(name.to_string())
    }

    /// Creates a constant expression holding `value`.
    pub const fn constant(value: f64) -> Self {
        Self::Const(value)
    }
}

#[cfg(not(feature = "dwave"))]
impl std::ops::Add for SimpleExpr {
    type Output = Self;

    /// Builds the node `self + rhs` without simplification.
    fn add(self, rhs: Self) -> Self::Output {
        Self::Add(Box::new(self), Box::new(rhs))
    }
}

#[cfg(not(feature = "dwave"))]
impl std::ops::Mul for SimpleExpr {
    type Output = Self;

    /// Builds the node `self * rhs` without simplification.
    fn mul(self, rhs: Self) -> Self::Output {
        Self::Mul(Box::new(self), Box::new(rhs))
    }
}

#[cfg(not(feature = "dwave"))]
impl std::iter::Sum for SimpleExpr {
    /// Sums the expressions with `+`, left to right.
    ///
    /// An empty iterator yields `Const(0.0)`. A non-empty one is folded directly, without a
    /// leading `0 +` node, so summing `[x, y]` gives exactly `x + y`.
    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
        iter.reduce(|acc, x| acc + x).unwrap_or(Self::Const(0.0))
    }
}
132
/// Builder for a constrained binary optimisation problem (`dwave` build).
///
/// Holds an optional objective and an ordered list of constraints. [`Model::compile`]
/// folds the constraints into the objective as penalty terms.
#[cfg(feature = "dwave")]
#[derive(Debug, Clone)]
pub struct Model {
    /// Names registered through [`Model::add_variable`].
    variables: HashSet<String>,
    /// Objective expression; `None` until [`Model::set_objective`] is called.
    objective: Option<Expr>,
    /// Constraints, in the order they were added.
    constraints: Vec<Constraint>,
}
144
/// One constraint recorded on a `dwave` [`Model`]; each becomes a penalty term when compiled.
#[cfg(feature = "dwave")]
#[derive(Debug, Clone)]
enum Constraint {
    /// Requires `expr == value`.
    Equality {
        /// Caller-supplied label.
        name: String,
        /// Left-hand side.
        expr: Expr,
        /// Required value.
        value: f64,
    },
    /// Requires `expr <= value`.
    LessEqual {
        /// Caller-supplied label.
        name: String,
        /// Left-hand side.
        expr: Expr,
        /// Upper bound.
        value: f64,
    },
    /// At most one of `variables` may be 1.
    AtMostOne {
        /// Caller-supplied label.
        name: String,
        /// Variables of which at most one may be set.
        variables: Vec<Expr>,
    },
    /// If any of `conditions` is 1, then `result` must be 1 as well.
    ImpliesAny {
        /// Caller-supplied label.
        name: String,
        /// Triggering expressions.
        conditions: Vec<Expr>,
        /// Expression that must hold when any condition does.
        result: Expr,
    },
}
170
/// An empty model; identical to [`Model::new`].
#[cfg(feature = "dwave")]
impl Default for Model {
    fn default() -> Self {
        Model::new()
    }
}
177
#[cfg(feature = "dwave")]
impl Model {
    /// Penalty weight applied to every constraint by [`Model::compile`].
    pub const DEFAULT_PENALTY_WEIGHT: f64 = 10.0;

    /// Creates a model with no variables, objective or constraints.
    pub fn new() -> Self {
        Self {
            variables: HashSet::new(),
            objective: None,
            constraints: Vec::new(),
        }
    }

    /// Registers `name` as a model variable and returns it as a symbol expression.
    ///
    /// # Errors
    /// Never fails at present.
    pub fn add_variable(&mut self, name: &str) -> CompileResult<Expr> {
        self.variables.insert(name.to_string());
        Ok(SymEngineExpression::symbol(name))
    }

    /// Sets (or replaces) the objective expression.
    pub fn set_objective(&mut self, expr: Expr) {
        self.objective = Some(expr);
    }

    /// Requires exactly one of `variables` to be 1, recorded as `sum(variables) == 1`.
    ///
    /// # Errors
    /// Never fails at present.
    pub fn add_constraint_eq_one(&mut self, name: &str, variables: Vec<Expr>) -> CompileResult<()> {
        // For binary variables, sum == 1 means exactly one of them is 1.
        let sum_expr = variables
            .iter()
            .fold(Expr::from(0), |acc, v| acc + v.clone());
        self.constraints.push(Constraint::Equality {
            name: name.to_string(),
            expr: sum_expr,
            value: 1.0,
        });
        Ok(())
    }

    /// Requires at most one of `variables` to be 1.
    ///
    /// # Errors
    /// Never fails at present.
    pub fn add_constraint_at_most_one(
        &mut self,
        name: &str,
        variables: Vec<Expr>,
    ) -> CompileResult<()> {
        self.constraints.push(Constraint::AtMostOne {
            name: name.to_string(),
            variables,
        });
        Ok(())
    }

    /// Requires `result` to be 1 whenever any of `conditions` is 1.
    ///
    /// # Errors
    /// Never fails at present.
    pub fn add_constraint_implies_any(
        &mut self,
        name: &str,
        conditions: Vec<Expr>,
        result: Expr,
    ) -> CompileResult<()> {
        self.constraints.push(Constraint::ImpliesAny {
            name: name.to_string(),
            conditions,
            result,
        });
        Ok(())
    }

    /// Compiles the model into a QUBO, using [`Model::DEFAULT_PENALTY_WEIGHT`] for every
    /// constraint.
    ///
    /// # Errors
    /// Propagates any error from [`Compile::get_qubo`], for example a term of degree > 2.
    pub fn compile(&self) -> CompileResult<CompiledModel> {
        self.compile_with_penalty_weight(Self::DEFAULT_PENALTY_WEIGHT)
    }

    /// Compiles the model into a QUBO, scaling every constraint penalty by `penalty_weight`.
    ///
    /// With `P = penalty_weight`, the following terms are added to the objective:
    /// - `expr == v`: `P * (expr - v)^2`
    /// - `expr <= v`: `P * (expr - v)^2` (also penalises `expr < v`; see the body)
    /// - at most one of `x_1..x_n`: `P * sum_{i<j} x_i * x_j`
    /// - implies-any: `P * (sum of conditions) * (1 - result)`
    ///
    /// If `P` is small relative to the objective's coefficients, a violating assignment can
    /// have lower energy than a feasible one.
    ///
    /// # Errors
    /// - [`CompileError::InvalidExpression`] if `penalty_weight` is NaN, infinite or negative.
    /// - Any error from [`Compile::get_qubo`].
    pub fn compile_with_penalty_weight(&self, penalty_weight: f64) -> CompileResult<CompiledModel> {
        if !penalty_weight.is_finite() || penalty_weight < 0.0 {
            return Err(CompileError::InvalidExpression(format!(
                "penalty weight must be finite and non-negative, got {penalty_weight}"
            )));
        }

        let mut final_expr = self.objective.clone().unwrap_or_else(|| Expr::from(0));

        for constraint in &self.constraints {
            match constraint {
                Constraint::Equality { expr, value, .. } => {
                    let diff = expr.clone() - Expr::from(*value);
                    final_expr = final_expr + Expr::from(penalty_weight) * diff.clone() * diff;
                }
                Constraint::LessEqual { expr, value, .. } => {
                    // NOTE: the squared deviation also penalises `expr < value`, which makes this
                    // stricter than `<=`. An exact encoding needs slack variables.
                    let excess = expr.clone() - Expr::from(*value);
                    final_expr = final_expr + Expr::from(penalty_weight) * excess.clone() * excess;
                }
                Constraint::AtMostOne { variables, .. } => {
                    // One penalty per unordered pair that is simultaneously set.
                    for (i, first) in variables.iter().enumerate() {
                        for second in &variables[i + 1..] {
                            final_expr = final_expr
                                + Expr::from(penalty_weight) * first.clone() * second.clone();
                        }
                    }
                }
                Constraint::ImpliesAny {
                    conditions, result, ..
                } => {
                    // Each true condition costs P unless the result is also true.
                    let conditions_sum = conditions
                        .iter()
                        .fold(Expr::from(0), |acc, c| acc + c.clone());
                    final_expr = final_expr
                        + Expr::from(penalty_weight)
                            * conditions_sum
                            * (Expr::from(1) - result.clone());
                }
            }
        }

        let compiler = Compile::new(final_expr);
        let ((qubo_matrix, var_map), offset) = compiler.get_qubo()?;

        Ok(CompiledModel {
            qubo_matrix,
            var_map,
            offset,
            constraints: self.constraints.clone(),
        })
    }
}
305
/// Output of [`Model::compile`] (`dwave` build): QUBO coefficients plus the original
/// constraints, so candidate solutions can be checked afterwards.
#[cfg(feature = "dwave")]
#[derive(Debug, Clone)]
pub struct CompiledModel {
    /// QUBO coefficient matrix as returned by [`Compile::get_qubo`]. Row/column `i`
    /// belongs to the variable that `var_map` maps to `i`.
    pub qubo_matrix: Array<f64, scirs2_core::ndarray::Ix2>,
    /// Variable name → matrix index.
    pub var_map: HashMap<String, usize>,
    /// Constant term to add to every energy value.
    pub offset: f64,
    /// Constraints copied from the source model, used by constraint checks.
    constraints: Vec<Constraint>,
}
319
#[cfg(feature = "dwave")]
impl CompiledModel {
    /// Builds a `quantrs2_anneal` [`QuboModel`](quantrs2_anneal::ising::QuboModel) from the
    /// compiled matrix.
    ///
    /// Diagonal entries become linear coefficients. Entries strictly above the diagonal become
    /// quadratic coefficients. Entries with magnitude at most `1e-10` are skipped. The offset
    /// is copied unchanged.
    // NOTE(review): the lower triangle is never read, so this assumes `build_qubo_matrix`
    // stores each pair only above the diagonal. Confirm against that function; a
    // symmetric-halves matrix would lose half of each coupling here.
    pub fn to_qubo(&self) -> quantrs2_anneal::ising::QuboModel {
        use quantrs2_anneal::ising::QuboModel;

        let matrix = &self.qubo_matrix;
        let mut model = QuboModel::new(self.var_map.len());
        model.offset = self.offset;

        for row in 0..matrix.nrows() {
            for col in row..matrix.ncols() {
                let coefficient = matrix[[row, col]];
                if coefficient.abs() > 1e-10 {
                    if row == col {
                        // Linear term. `row` is below the matrix size, which is
                        // expected to equal `var_map.len()`.
                        model
                            .set_linear(row, coefficient)
                            .expect("index within bounds from matrix dimensions");
                    } else {
                        // Quadratic coupling between `row` and `col`.
                        model
                            .set_quadratic(row, col, coefficient)
                            .expect("indices within bounds from matrix dimensions");
                    }
                }
            }
        }

        model
    }

    /// Counts how many stored constraints the given assignment violates.
    ///
    /// `assignments` maps variable names to binary values (`true` = 1, `false` = 0).
    /// A constraint whose expression fails to evaluate is not counted as violated.
    pub fn count_constraint_violations(&self, assignments: &HashMap<String, bool>) -> usize {
        let values: HashMap<String, f64> = assignments
            .iter()
            .map(|(name, &bit)| (name.clone(), if bit { 1.0 } else { 0.0 }))
            .collect();

        self.constraints
            .iter()
            .filter(|constraint| Self::is_violated(constraint, &values))
            .count()
    }

    /// Returns how many constraints this model carries.
    pub fn num_constraints(&self) -> usize {
        self.constraints.len()
    }

    /// Decides whether one constraint is violated under `values` (tolerance `1e-6`;
    /// a value above 0.5 counts as a set bit).
    fn is_violated(constraint: &Constraint, values: &HashMap<String, f64>) -> bool {
        const TOLERANCE: f64 = 1e-6;
        match constraint {
            Constraint::Equality { expr, value, .. } => {
                matches!(expr.eval(values), Ok(lhs) if (lhs - value).abs() > TOLERANCE)
            }
            Constraint::LessEqual { expr, value, .. } => {
                matches!(expr.eval(values), Ok(lhs) if lhs > value + TOLERANCE)
            }
            Constraint::AtMostOne { variables, .. } => {
                let set_count = variables
                    .iter()
                    .filter(|v| matches!(v.eval(values), Ok(x) if x > 0.5))
                    .count();
                set_count > 1
            }
            Constraint::ImpliesAny {
                conditions, result, ..
            } => {
                let triggered = conditions
                    .iter()
                    .any(|cond| matches!(cond.eval(values), Ok(x) if x > 0.5));
                // `&&` keeps `result` unevaluated when nothing triggered, as before.
                triggered && matches!(result.eval(values), Ok(x) if x < 0.5)
            }
        }
    }
}
413
414/// High-level model for constraint optimization problems (non-dwave version)
415#[cfg(not(feature = "dwave"))]
416#[derive(Debug, Clone)]
417pub struct Model {
418    /// Variables in the model
419    variables: HashSet<String>,
420    /// Objective function expression
421    objective: Option<SimpleExpr>,
422    /// Constraints
423    constraints: Vec<Constraint>,
424}
425
426/// Constraint types (non-dwave version)
427#[cfg(not(feature = "dwave"))]
428#[derive(Debug, Clone)]
429enum Constraint {
430    /// Equality constraint: sum of variables equals value
431    Equality {
432        name: String,
433        expr: SimpleExpr,
434        value: f64,
435    },
436    /// At most one constraint: at most one variable can be 1
437    AtMostOne {
438        name: String,
439        variables: Vec<SimpleExpr>,
440    },
441    /// Implication constraint: if any condition is true, then result must be true
442    ImpliesAny {
443        name: String,
444        conditions: Vec<SimpleExpr>,
445        result: SimpleExpr,
446    },
447}
448
449#[cfg(not(feature = "dwave"))]
450impl Default for Model {
451    fn default() -> Self {
452        Self::new()
453    }
454}
455
456#[cfg(not(feature = "dwave"))]
457impl Model {
458    /// Create a new empty model
459    pub fn new() -> Self {
460        Self {
461            variables: HashSet::new(),
462            objective: None,
463            constraints: Vec::new(),
464        }
465    }
466
467    /// Add a variable to the model
468    pub fn add_variable(&mut self, name: &str) -> CompileResult<SimpleExpr> {
469        self.variables.insert(name.to_string());
470        Ok(SimpleExpr::var(name))
471    }
472
473    /// Set the objective function
474    pub fn set_objective(&mut self, expr: SimpleExpr) {
475        self.objective = Some(expr);
476    }
477
478    /// Add constraint: exactly one of the variables must be 1
479    pub fn add_constraint_eq_one(
480        &mut self,
481        name: &str,
482        variables: Vec<SimpleExpr>,
483    ) -> CompileResult<()> {
484        let sum_expr = variables.into_iter().sum();
485        self.constraints.push(Constraint::Equality {
486            name: name.to_string(),
487            expr: sum_expr,
488            value: 1.0,
489        });
490        Ok(())
491    }
492
493    /// Add constraint: at most one of the variables can be 1
494    pub fn add_constraint_at_most_one(
495        &mut self,
496        name: &str,
497        variables: Vec<SimpleExpr>,
498    ) -> CompileResult<()> {
499        self.constraints.push(Constraint::AtMostOne {
500            name: name.to_string(),
501            variables,
502        });
503        Ok(())
504    }
505
506    /// Add constraint: if any condition is true, then result must be true
507    pub fn add_constraint_implies_any(
508        &mut self,
509        name: &str,
510        conditions: Vec<SimpleExpr>,
511        result: SimpleExpr,
512    ) -> CompileResult<()> {
513        self.constraints.push(Constraint::ImpliesAny {
514            name: name.to_string(),
515            conditions,
516            result,
517        });
518        Ok(())
519    }
520
521    /// Compile the model to a CompiledModel
522    pub fn compile(&self) -> CompileResult<CompiledModel> {
523        // Build QUBO directly from constraints
524        let mut qubo_terms: HashMap<(String, String), f64> = HashMap::new();
525        let mut offset = 0.0;
526        let penalty_weight = 10.0;
527
528        // Process objective if present
529        if let Some(ref obj) = self.objective {
530            self.add_expr_to_qubo(obj, 1.0, &mut qubo_terms, &mut offset)?;
531        }
532
533        // Process constraints
534        for constraint in &self.constraints {
535            match constraint {
536                Constraint::Equality { expr, value, .. } => {
537                    // (expr - value)^2 penalty
538                    // Expand: expr^2 - 2*expr*value + value^2
539                    self.add_expr_squared_to_qubo(
540                        expr,
541                        penalty_weight,
542                        &mut qubo_terms,
543                        &mut offset,
544                    )?;
545                    self.add_expr_to_qubo(
546                        expr,
547                        -2.0 * penalty_weight * value,
548                        &mut qubo_terms,
549                        &mut offset,
550                    )?;
551                    offset += penalty_weight * value * value;
552                }
553                Constraint::AtMostOne { variables, .. } => {
554                    // Penalty: sum(xi * xj) for all i < j
555                    for i in 0..variables.len() {
556                        for j in (i + 1)..variables.len() {
557                            if let (SimpleExpr::Var(vi), SimpleExpr::Var(vj)) =
558                                (&variables[i], &variables[j])
559                            {
560                                let key = if vi < vj {
561                                    (vi.clone(), vj.clone())
562                                } else {
563                                    (vj.clone(), vi.clone())
564                                };
565                                *qubo_terms.entry(key).or_insert(0.0) += penalty_weight;
566                            }
567                        }
568                    }
569                }
570                Constraint::ImpliesAny {
571                    conditions, result, ..
572                } => {
573                    // Penalty: sum(conditions) * (1 - result)
574                    for cond in conditions {
575                        if let (SimpleExpr::Var(c), SimpleExpr::Var(r)) = (cond, result) {
576                            let key = if c < r {
577                                (c.clone(), r.clone())
578                            } else {
579                                (r.clone(), c.clone())
580                            };
581                            *qubo_terms.entry(key).or_insert(0.0) -= penalty_weight;
582                        }
583                        // Also add linear term for condition
584                        if let SimpleExpr::Var(c) = cond {
585                            *qubo_terms.entry((c.clone(), c.clone())).or_insert(0.0) +=
586                                penalty_weight;
587                        }
588                    }
589                }
590            }
591        }
592
593        // Convert to matrix form
594        let all_vars: HashSet<String> = qubo_terms
595            .keys()
596            .flat_map(|(v1, v2)| vec![v1.clone(), v2.clone()])
597            .collect();
598        let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
599        sorted_vars.sort();
600
601        let var_map: HashMap<String, usize> = sorted_vars
602            .iter()
603            .enumerate()
604            .map(|(i, v)| (v.clone(), i))
605            .collect();
606
607        let n = var_map.len();
608        let mut matrix = Array::zeros((n, n));
609
610        for ((v1, v2), coeff) in qubo_terms {
611            let i = var_map[&v1];
612            let j = var_map[&v2];
613            if i == j {
614                matrix[[i, i]] += coeff;
615            } else {
616                matrix[[i, j]] += coeff / 2.0;
617                matrix[[j, i]] += coeff / 2.0;
618            }
619        }
620
621        Ok(CompiledModel {
622            qubo_matrix: matrix,
623            var_map,
624            offset,
625            constraints: self.constraints.clone(),
626        })
627    }
628
629    /// Add expression to QUBO terms
630    fn add_expr_to_qubo(
631        &self,
632        expr: &SimpleExpr,
633        coeff: f64,
634        terms: &mut HashMap<(String, String), f64>,
635        offset: &mut f64,
636    ) -> CompileResult<()> {
637        match expr {
638            SimpleExpr::Var(name) => {
639                *terms.entry((name.clone(), name.clone())).or_insert(0.0) += coeff;
640            }
641            SimpleExpr::Const(val) => {
642                *offset += coeff * val;
643            }
644            SimpleExpr::Add(left, right) => {
645                self.add_expr_to_qubo(left, coeff, terms, offset)?;
646                self.add_expr_to_qubo(right, coeff, terms, offset)?;
647            }
648            SimpleExpr::Mul(left, right) => {
649                if let (SimpleExpr::Var(v1), SimpleExpr::Var(v2)) = (left.as_ref(), right.as_ref())
650                {
651                    let key = if v1 < v2 {
652                        (v1.clone(), v2.clone())
653                    } else {
654                        (v2.clone(), v1.clone())
655                    };
656                    *terms.entry(key).or_insert(0.0) += coeff;
657                } else if let (SimpleExpr::Const(c), var) | (var, SimpleExpr::Const(c)) =
658                    (left.as_ref(), right.as_ref())
659                {
660                    self.add_expr_to_qubo(var, coeff * c, terms, offset)?;
661                }
662            }
663            SimpleExpr::Pow(base, exp) => {
664                if *exp == 2 && matches!(base.as_ref(), SimpleExpr::Var(_)) {
665                    // x^2 = x for binary variables
666                    self.add_expr_to_qubo(base, coeff, terms, offset)?;
667                }
668            }
669        }
670        Ok(())
671    }
672
673    /// Add expression squared to QUBO terms
674    fn add_expr_squared_to_qubo(
675        &self,
676        expr: &SimpleExpr,
677        coeff: f64,
678        terms: &mut HashMap<(String, String), f64>,
679        offset: &mut f64,
680    ) -> CompileResult<()> {
681        // For simplicity, only handle simple cases
682        match expr {
683            SimpleExpr::Var(name) => {
684                // x^2 = x for binary
685                *terms.entry((name.clone(), name.clone())).or_insert(0.0) += coeff;
686            }
687            SimpleExpr::Add(left, right) => {
688                // (a + b)^2 = a^2 + 2ab + b^2
689                self.add_expr_squared_to_qubo(left, coeff, terms, offset)?;
690                self.add_expr_squared_to_qubo(right, coeff, terms, offset)?;
691                // Cross term
692                if let (SimpleExpr::Var(v1), SimpleExpr::Var(v2)) = (left.as_ref(), right.as_ref())
693                {
694                    let key = if v1 < v2 {
695                        (v1.clone(), v2.clone())
696                    } else {
697                        (v2.clone(), v1.clone())
698                    };
699                    *terms.entry(key).or_insert(0.0) += 2.0 * coeff;
700                }
701            }
702            _ => {}
703        }
704        Ok(())
705    }
706}
707
708/// Compiled model ready for sampling (non-dwave version)
709#[cfg(not(feature = "dwave"))]
710#[derive(Debug, Clone)]
711pub struct CompiledModel {
712    /// QUBO matrix
713    pub qubo_matrix: Array<f64, scirs2_core::ndarray::Ix2>,
714    /// Variable name to index mapping
715    pub var_map: HashMap<String, usize>,
716    /// Constant offset
717    pub offset: f64,
718    /// Original constraints (for analysis)
719    constraints: Vec<Constraint>,
720}
721
722#[cfg(not(feature = "dwave"))]
723impl CompiledModel {
724    /// Convert to QUBO format
725    pub fn to_qubo(&self) -> quantrs2_anneal::ising::QuboModel {
726        use quantrs2_anneal::ising::QuboModel;
727
728        let mut qubo = QuboModel::new(self.var_map.len());
729
730        // Set the offset
731        qubo.offset = self.offset;
732
733        // Set all the QUBO coefficients
734        for i in 0..self.qubo_matrix.nrows() {
735            for j in i..self.qubo_matrix.ncols() {
736                let value = self.qubo_matrix[[i, j]];
737                if value.abs() > 1e-10 {
738                    if i == j {
739                        // Diagonal term (linear)
740                        // SAFETY: index i is derived from matrix dimensions which match QuboModel size
741                        qubo.set_linear(i, value)
742                            .expect("index within bounds from matrix dimensions");
743                    } else {
744                        // Off-diagonal term (quadratic)
745                        // SAFETY: indices i,j are derived from matrix dimensions which match QuboModel size
746                        qubo.set_quadratic(i, j, value)
747                            .expect("indices within bounds from matrix dimensions");
748                    }
749                }
750            }
751        }
752
753        qubo
754    }
755}
756
/// Compiles a symbolic expression into QUBO or HOBO coefficient form.
///
/// Construct with [`Compile::new`], then call [`Compile::get_qubo`] (terms of degree at
/// most 2) or [`Compile::get_hobo`] (higher-order terms allowed). The results can be handed
/// to an annealing sampler.
#[cfg(feature = "dwave")]
pub struct Compile {
    /// Expression to compile.
    expr: Expr,
}
766
#[cfg(feature = "dwave")]
impl Compile {
    /// Creates a compiler for `expr`.
    pub fn new<T: Into<Expr>>(expr: T) -> Self {
        Self { expr: expr.into() }
    }

    /// Compiles the expression to a QUBO model.
    ///
    /// With the `scirs` feature the matrix is additionally post-processed by
    /// `scirs_stub::enhance_qubo_matrix`.
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// - A tuple with the QUBO matrix and a mapping of variable names to indices
    /// - An offset value that should be added to all energy values
    ///
    /// # Errors
    /// [`CompileError::DegreeTooHigh`] if a term has degree > 2 after `x^2 -> x` reduction,
    /// plus any error from the expression helpers.
    pub fn get_qubo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        #[cfg(feature = "scirs")]
        {
            self.get_qubo_scirs()
        }
        #[cfg(not(feature = "scirs"))]
        {
            self.get_qubo_standard()
        }
    }

    /// Standard QUBO compilation without SciRS2.
    fn get_qubo_standard(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let expr = self.expr.expand();

        // Replace x^2 and x*x with x (binary variables) BEFORE the degree check, so that
        // squared terms count as degree 1.
        let expr = replace_squared_terms(&expr)?;

        let (coeffs, offset) = extract_coefficients(&expr)?;

        // Measure the degree on the extracted monomials, which reflects the reduction above.
        let max_degree = coeffs.keys().map(|vars| vars.len()).max().unwrap_or(0);
        if max_degree > 2 {
            return Err(CompileError::DegreeTooHigh(max_degree, 2));
        }

        let (matrix, var_map) = build_qubo_matrix(&coeffs)?;

        Ok(((matrix, var_map), offset))
    }

    /// QUBO compilation followed by SciRS2 matrix enhancement.
    #[cfg(feature = "scirs")]
    fn get_qubo_scirs(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let ((matrix, var_map), offset) = self.get_qubo_standard()?;
        let enhanced_matrix = crate::scirs_stub::enhance_qubo_matrix(&matrix);
        Ok(((enhanced_matrix, var_map), offset))
    }

    /// Compiles the expression to a HOBO (higher-order binary optimisation) model.
    ///
    /// The tensor's rank equals the highest degree among the monomials left after the
    /// `x^2 -> x` reduction. It is not the raw degree of the unreduced expression, which would
    /// overstate it (e.g. `x^3 * y^2` has raw degree 5 but reduces to a degree-2 term).
    ///
    /// # Returns
    ///
    /// A tuple containing:
    /// - A tuple with the HOBO tensor and a mapping of variable names to indices
    /// - An offset value that should be added to all energy values
    ///
    /// # Errors
    /// Propagates any error from the expression helpers.
    pub fn get_hobo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::IxDyn>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        let expr = self.expr.expand();

        // Replace squared terms (x^2) with x, since x^2 = x for binary variables.
        let expr = replace_squared_terms(&expr)?;

        // Expand again so like terms produced by the reduction are collected.
        let expr = expr.expand();

        let (coeffs, offset) = extract_coefficients(&expr)?;

        // Size the tensor from the reduced monomials, as `get_qubo_standard` does for its
        // degree check; measuring before reduction can inflate the rank (and memory, n^rank).
        let max_degree = coeffs.keys().map(|vars| vars.len()).max().unwrap_or(0);

        let (tensor, var_map) = build_hobo_tensor(&coeffs, max_degree)?;

        Ok(((tensor, var_map), offset))
    }
}
895
/// Estimates the highest polynomial degree of any monomial in `expr`.
///
/// `get_hobo` uses the result as the order (number of dimensions) of the HOBO
/// tensor. Degrees are computed structurally, *before* the binary identity
/// `x^k = x` is applied: `x^3` counts as 3 and `x*x` as 2. The estimate can
/// therefore exceed the degree of the simplified polynomial.
///
/// # Errors
///
/// Returns [`CompileError::InvalidExpression`] in two cases:
/// - the numeric exponent of a `symbol^number` power cannot be converted to `f64`;
/// - the expression is none of the recognised node kinds and its printed form
///   contains neither `+` nor `-`.
#[cfg(feature = "dwave")]
fn calc_highest_degree(expr: &Expr) -> CompileResult<usize> {
    // If the expression is a single variable, it's degree 1
    if expr.is_symbol() {
        return Ok(1);
    }

    // If it's a number constant, degree is 0
    if expr.is_number() {
        return Ok(0);
    }

    // If it's a negation, recursively calculate the degree of the inner expression
    if expr.is_neg() {
        // SAFETY: is_neg() check guarantees as_neg() will succeed
        let inner = expr.as_neg().expect("is_neg() was true");
        return calc_highest_degree(&inner);
    }

    // If it's a power operation (like x^2)
    if expr.is_pow() {
        // SAFETY: is_pow() check guarantees as_pow() will succeed
        let (base, exp) = expr.as_pow().expect("is_pow() was true");

        // If the base is a symbol and exponent is a number
        if base.is_symbol() && exp.is_number() {
            let exp_val = match exp.to_f64() {
                Some(n) => n,
                None => {
                    return Err(CompileError::InvalidExpression(
                        "Invalid exponent".to_string(),
                    ))
                }
            };

            // Check if exponent is a non-negative integer (the sign-bit check also
            // admits +0.0, giving degree 0). The `as usize` cast saturates for huge values.
            if exp_val.is_sign_positive() && exp_val.fract() == 0.0 {
                return Ok(exp_val as usize);
            }
        }

        // For other power expressions, recursively calculate the degree:
        // degree(base) * integer exponent, or 0 when the exponent is not a
        // non-negative integer constant.
        let base_degree = calc_highest_degree(&base)?;
        let exp_degree = if exp.is_number() {
            match exp.to_f64() {
                Some(n) => {
                    if n.is_sign_positive() && n.fract() == 0.0 {
                        n as usize
                    } else {
                        0 // Non-integer or negative exponents don't contribute to degree
                    }
                }
                None => 0,
            }
        } else {
            0 // Non-constant exponents don't contribute to degree
        };

        return Ok(base_degree * exp_degree);
    }

    // If it's a product (like x*y or x*x), the degrees of the factors add up
    if expr.is_mul() {
        let mut total_degree = 0;
        // SAFETY: is_mul() check guarantees as_mul() will succeed
        for factor in expr.as_mul().expect("is_mul() was true") {
            total_degree += calc_highest_degree(&factor)?;
        }
        return Ok(total_degree);
    }

    // If it's a sum (like x + y), the degree is the maximum over the summands
    if expr.is_add() {
        let mut max_degree = 0;
        // SAFETY: is_add() check guarantees as_add() will succeed
        for term in expr.as_add().expect("is_add() was true") {
            let term_degree = calc_highest_degree(&term)?;
            max_degree = std::cmp::max(max_degree, term_degree);
        }
        return Ok(max_degree);
    }

    // Check for other compound expressions by trying to parse them.
    // This fallback inspects the printed form and is heuristic. NOTE(review):
    // - splitting on '-' also splits inside tokens such as scientific-notation
    //   numbers (`1e-5`);
    // - parentheses are not understood;
    // - a power term's degree is read from the text after the first `**`/`^`, so
    //   `x**2*y` yields "2*y", which fails to parse and defaults to 2. This can
    //   under-estimate the real degree.
    let expr_str = format!("{expr}");
    if expr_str.contains('+') || expr_str.contains('-') {
        // It's a sum-like expression but not recognized as ADD
        // Parse the string to find the highest degree term
        // This is a workaround for symengine type detection issues
        let mut max_degree = 0;

        // Split by + and - (the signs themselves are discarded; only degrees matter)
        let parts: Vec<&str> = expr_str.split(['+', '-']).collect();

        for part in parts {
            let part = part.trim();
            if part.is_empty() {
                continue;
            }

            // Count degree based on what the term contains
            let degree = if part.contains("**") || part.contains('^') {
                // Power term like x**2 or y**3
                // Extract the exponent (falls back to 2 when it cannot be parsed)
                let exp_str = part
                    .split("**")
                    .nth(1)
                    .or_else(|| part.split('^').nth(1))
                    .unwrap_or("2")
                    .trim();
                exp_str.parse::<usize>().unwrap_or(2)
            } else if part.contains('*') {
                // Product term - count the number of non-numeric factors
                let factors: Vec<&str> = part.split('*').collect();
                let mut var_count = 0;
                for factor in factors {
                    let factor = factor.trim();
                    // Check if it's a variable (not a number)
                    if !factor.is_empty() && factor.parse::<f64>().is_err() {
                        var_count += 1;
                    }
                }
                var_count
            } else if part.parse::<f64>().is_err() && !part.is_empty() {
                // Single variable
                1
            } else {
                // Constant
                0
            };

            max_degree = std::cmp::max(max_degree, degree);
        }

        return Ok(max_degree);
    }

    // Default case: the node kind is not recognised and the printed form does not
    // look like a sum, so report it as an error rather than guess a degree.
    Err(CompileError::InvalidExpression(format!(
        "Can't determine degree of: {expr}"
    )))
}
1039
/// Applies the binary-variable identities `x^k = x` (integer `k >= 1`) and `x * x = x`.
///
/// Handling by node kind:
/// - Symbols, numbers and any other unrecognised node kinds are returned unchanged.
/// - Negations, sums and products are rebuilt from their recursively processed
///   children. Within a product, a symbol factor that occurs more than once is
///   kept only once.
/// - Powers other than `symbol^k` keep their exponent and get a recursively
///   processed base.
///
/// # Errors
///
/// Returns [`CompileError::InvalidExpression`] if the numeric exponent of a
/// `symbol^number` power cannot be converted to `f64`.
#[cfg(feature = "dwave")]
fn replace_squared_terms(expr: &Expr) -> CompileResult<Expr> {
    // Leaves: nothing to rewrite.
    if expr.is_symbol() || expr.is_number() {
        return Ok(expr.clone());
    }

    // Negation: rewrite the inner expression and negate it again.
    if expr.is_neg() {
        // SAFETY: is_neg() check guarantees as_neg() will succeed
        let inner = expr.as_neg().expect("is_neg() was true");
        return Ok(-replace_squared_terms(&inner)?);
    }

    // Power: for a binary variable x^k = x for every positive integer k
    // (previously only k == 2 was collapsed, leaving e.g. x^3 in place).
    if expr.is_pow() {
        // SAFETY: is_pow() check guarantees as_pow() will succeed
        let (base, exp) = expr.as_pow().expect("is_pow() was true");

        if base.is_symbol() && exp.is_number() {
            let exp_val = exp.to_f64().ok_or_else(|| {
                CompileError::InvalidExpression("Invalid exponent".to_string())
            })?;
            if exp_val >= 1.0 && exp_val.fract() == 0.0 {
                return Ok(base);
            }
        }

        // Any other power keeps its exponent; only the base is rewritten.
        let new_base = replace_squared_terms(&base)?;
        return Ok(new_base.pow(&exp));
    }

    // Product: rewrite each factor and drop repeated symbols (x * x = x),
    // regardless of how many factors the product has.
    if expr.is_mul() {
        let mut seen_symbols: HashSet<String> = HashSet::new();
        let mut factors = Vec::new();
        // SAFETY: is_mul() check guarantees as_mul() will succeed
        for factor in expr.as_mul().expect("is_mul() was true") {
            let new_factor = replace_squared_terms(&factor)?;
            if let Some(name) = new_factor.as_symbol() {
                if !seen_symbols.insert(name.to_string()) {
                    continue;
                }
            }
            factors.push(new_factor);
        }

        // Rebuild the product; an empty product is the multiplicative identity.
        let mut remaining = factors.into_iter();
        return match remaining.next() {
            Some(first) => Ok(remaining.fold(first, |acc, factor| acc * factor)),
            None => Ok(Expr::from(1)),
        };
    }

    // Sum: rewrite each summand and add them back together.
    if expr.is_add() {
        let mut terms = Vec::new();
        // SAFETY: is_add() check guarantees as_add() will succeed
        for term in expr.as_add().expect("is_add() was true") {
            terms.push(replace_squared_terms(&term)?);
        }

        // Rebuild the sum; an empty sum is the additive identity.
        let mut remaining = terms.into_iter();
        return match remaining.next() {
            Some(first) => Ok(remaining.fold(first, |acc, term| acc + term)),
            None => Ok(Expr::from(0)),
        };
    }

    // Any other kind of expression is returned unchanged.
    Ok(expr.clone())
}
1138
/// Splits an expanded expression into `{sorted variable names -> coefficient}` plus a
/// constant offset.
///
/// Sums are decomposed term by term via `extract_term_coefficients`. If the expression
/// is not a sum node but its printed form contains `+` or `-`, the printed form is
/// parsed as a signed sum of `*`-separated monomials instead. This is a workaround for
/// symengine node-type detection. Anything else is treated as a single term.
///
/// # Errors
///
/// Propagates errors from `extract_term_coefficients`; the string fallback never fails.
#[cfg(feature = "dwave")]
fn extract_coefficients(expr: &Expr) -> CompileResult<(HashMap<Vec<String>, f64>, f64)> {
    // Parses one printed monomial such as `2*x**2*y` and accumulates it into
    // `coeffs` / `offset`. Binary identities apply: `x^k` is read as `x`, and repeated
    // variables collapse, so the key is the sorted, de-duplicated variable list.
    // Numeric factors (including numeric powers such as `2^3`) multiply the
    // coefficient; a monomial without variables goes to the offset.
    fn accumulate_printed_term(
        term: &str,
        sign: f64,
        coeffs: &mut HashMap<Vec<String>, f64>,
        offset: &mut f64,
    ) {
        // Normalise `**` to `^` so that `*` only ever separates factors; this keeps
        // e.g. `x*y**2` and `x**2*y` from being mis-split.
        let normalized = term.replace("**", "^");
        let mut coeff = sign;
        let mut vars: Vec<String> = Vec::new();

        for factor in normalized.split('*').map(str::trim).filter(|f| !f.is_empty()) {
            if let Ok(num) = factor.parse::<f64>() {
                coeff *= num;
            } else if let Some((base, exp)) = factor.split_once('^') {
                let base = base.trim();
                match (base.parse::<f64>(), exp.trim().parse::<f64>()) {
                    // Purely numeric power such as `2^3`.
                    (Ok(b), Ok(e)) => coeff *= b.powf(e),
                    // Binary variable: x^k = x, the exponent is irrelevant.
                    _ => vars.push(base.to_string()),
                }
            } else {
                vars.push(factor.to_string());
            }
        }

        vars.sort();
        vars.dedup();
        if vars.is_empty() {
            *offset += coeff;
        } else {
            *coeffs.entry(vars).or_insert(0.0) += coeff;
        }
    }

    let mut coeffs: HashMap<Vec<String>, f64> = HashMap::new();
    let mut offset = 0.0;

    // Regular case: the expression is a sum node.
    if expr.is_add() {
        // SAFETY: is_add() check guarantees as_add() will succeed
        for term in expr.as_add().expect("is_add() was true") {
            let (term_coeffs, term_offset) = extract_term_coefficients(&term)?;
            for (vars, coeff) in term_coeffs {
                *coeffs.entry(vars).or_insert(0.0) += coeff;
            }
            offset += term_offset;
        }
        return Ok((coeffs, offset));
    }

    // Sum-like expression that was not detected as ADD: parse its printed form.
    let expr_str = expr.to_string();
    if expr_str.contains('+') || expr_str.contains('-') {
        use regex::Regex;
        // SAFETY: Static regex pattern is known to be valid at compile time
        let re = Regex::new(r"([+-]?)([^+-]+)").expect("static regex pattern is valid");

        for caps in re.captures_iter(&expr_str) {
            let term = caps.get(2).map_or("", |m| m.as_str()).trim();
            if term.is_empty() {
                continue;
            }
            let sign = if caps.get(1).map_or("", |m| m.as_str()) == "-" {
                -1.0
            } else {
                1.0
            };
            accumulate_printed_term(term, sign, &mut coeffs, &mut offset);
        }
        return Ok((coeffs, offset));
    }

    // Anything else is a single term.
    extract_term_coefficients(expr)
}
1256
/// Decomposes a single term into `{sorted variable names -> coefficient}` plus a
/// constant offset.
///
/// Variables are treated as binary: repeated factors collapse (`x*x*y` → `x*y`), and
/// `x^k` with integer `k >= 1` is read as `x`. This applies both inside products and
/// when the power is the whole term.
///
/// # Errors
///
/// Returns [`CompileError::InvalidExpression`] for:
/// - numbers that cannot be converted to `f64`;
/// - powers other than `symbol^k` (integer `k >= 1`);
/// - any node that is not a number, sum, negation, symbol, product or power.
#[cfg(feature = "dwave")]
fn extract_term_coefficients(term: &Expr) -> CompileResult<(HashMap<Vec<String>, f64>, f64)> {
    // Variable name of `base^exp` when it is a symbol raised to an integer power >= 1
    // (for binary variables such a power equals the variable itself).
    fn binary_power_var(base: &Expr, exp: &Expr) -> Option<String> {
        if !(base.is_symbol() && exp.is_number()) {
            return None;
        }
        let exp_val = exp.to_f64()?;
        if exp_val >= 1.0 && exp_val.fract() == 0.0 {
            base.as_symbol().map(|name| name.to_string())
        } else {
            None
        }
    }

    let mut coeffs: HashMap<Vec<String>, f64> = HashMap::new();

    // A bare number only contributes to the offset.
    if term.is_number() {
        let value = term
            .to_f64()
            .ok_or_else(|| CompileError::InvalidExpression("Invalid number".to_string()))?;
        return Ok((coeffs, value));
    }

    // A (nested) sum: merge the contributions of every summand.
    if term.is_add() {
        let mut offset = 0.0;
        // SAFETY: is_add() check guarantees as_add() will succeed
        for sub_term in term.as_add().expect("is_add() was true") {
            let (sub_coeffs, sub_offset) = extract_term_coefficients(&sub_term)?;
            for (vars, coeff) in sub_coeffs {
                *coeffs.entry(vars).or_insert(0.0) += coeff;
            }
            offset += sub_offset;
        }
        return Ok((coeffs, offset));
    }

    // A negation: extract the inner term and flip every sign.
    if term.is_neg() {
        // SAFETY: is_neg() check guarantees as_neg() will succeed
        let inner = term.as_neg().expect("is_neg() was true");
        let (inner_coeffs, inner_offset) = extract_term_coefficients(&inner)?;
        let negated: HashMap<Vec<String>, f64> = inner_coeffs
            .into_iter()
            .map(|(vars, coeff)| (vars, -coeff))
            .collect();
        return Ok((negated, -inner_offset));
    }

    // A symbol is a linear term with coefficient 1.
    if term.is_symbol() {
        // SAFETY: is_symbol() check guarantees as_symbol() will succeed
        let var_name = term.as_symbol().expect("is_symbol() was true");
        coeffs.insert(vec![var_name.to_string()], 1.0);
        return Ok((coeffs, 0.0));
    }

    // A product: numeric factors form the coefficient, symbols (and binary powers)
    // form the variable list.
    if term.is_mul() {
        let mut coeff = 1.0;
        let mut vars: Vec<String> = Vec::new();

        // SAFETY: is_mul() check guarantees as_mul() will succeed
        let factors = term.as_mul().expect("is_mul() was true");
        // Use a stack to iteratively flatten nested products (handles symengine's
        // internal representation where x*y*z may appear as (* (* x y) z))
        let mut factor_stack: Vec<_> = factors.into_iter().collect();
        while let Some(factor) = factor_stack.pop() {
            if factor.is_number() {
                coeff *= factor.to_f64().ok_or_else(|| {
                    CompileError::InvalidExpression("Invalid number in product".to_string())
                })?;
            } else if factor.is_symbol() {
                // SAFETY: is_symbol() check guarantees as_symbol() will succeed
                let var_name = factor.as_symbol().expect("is_symbol() was true");
                vars.push(var_name.to_string());
            } else if factor.is_mul() {
                // Nested product — flatten by pushing sub-factors back onto the stack
                factor_stack.extend(factor.as_mul().expect("is_mul() was true"));
            } else if factor.is_pow() {
                let (base, exp) = factor.as_pow().expect("is_pow() was true");
                if let Some(var_name) = binary_power_var(&base, &exp) {
                    vars.push(var_name);
                } else if base.is_symbol() && exp.is_number() {
                    return Err(CompileError::InvalidExpression(format!(
                        "Unsupported power in product: {factor}"
                    )));
                } else {
                    return Err(CompileError::InvalidExpression(format!(
                        "Unsupported power term in product: {factor}"
                    )));
                }
            } else {
                // More complex factors not supported
                return Err(CompileError::InvalidExpression(format!(
                    "Unsupported term in product: {factor}"
                )));
            }
        }

        // Canonical key: sorted, and x * x = x for binary variables.
        vars.sort();
        vars.dedup();

        if vars.is_empty() {
            // Only numeric factors: it's a constant term
            return Ok((coeffs, coeff));
        }
        coeffs.insert(vars, coeff);
        return Ok((coeffs, 0.0));
    }

    // A standalone power: symbol^k (k >= 1) equals the symbol for binary variables.
    if term.is_pow() {
        // SAFETY: is_pow() check guarantees as_pow() will succeed
        let (base, exp) = term.as_pow().expect("is_pow() was true");
        if let Some(var_name) = binary_power_var(&base, &exp) {
            coeffs.insert(vec![var_name], 1.0);
            return Ok((coeffs, 0.0));
        }
        return Err(CompileError::InvalidExpression(format!(
            "Unexpected power term after simplification: {term}"
        )));
    }

    // Unsupported term type
    Err(CompileError::InvalidExpression(format!(
        "Unsupported term: {term}"
    )))
}
1394
1395// Helper function to build the QUBO matrix
1396#[allow(dead_code)]
1397fn build_qubo_matrix(
1398    coeffs: &HashMap<Vec<String>, f64>,
1399) -> CompileResult<(
1400    Array<f64, scirs2_core::ndarray::Ix2>,
1401    HashMap<String, usize>,
1402)> {
1403    // Collect all unique variable names
1404    let mut all_vars = HashSet::new();
1405    for vars in coeffs.keys() {
1406        for var in vars {
1407            all_vars.insert(var.clone());
1408        }
1409    }
1410
1411    // Convert to a sorted vector
1412    let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
1413    sorted_vars.sort();
1414
1415    // Create the variable-to-index mapping
1416    let var_map: HashMap<String, usize> = sorted_vars
1417        .iter()
1418        .enumerate()
1419        .map(|(i, var)| (var.clone(), i))
1420        .collect();
1421
1422    // Size of the matrix
1423    let n = var_map.len();
1424
1425    // Create an empty matrix
1426    let mut matrix = Array::zeros((n, n));
1427
1428    // Fill the matrix with coefficients
1429    for (vars, &coeff) in coeffs {
1430        match vars.len() {
1431            0 => {
1432                // Should never happen since constants are handled in offset
1433            }
1434            1 => {
1435                // Linear term: var * coeff
1436                // SAFETY: var_map was built from the same variables in coeffs
1437                let i = *var_map
1438                    .get(&vars[0])
1439                    .expect("variable exists in var_map built from coeffs");
1440                matrix[[i, i]] += coeff;
1441            }
1442            2 => {
1443                // Quadratic term: var1 * var2 * coeff
1444                // SAFETY: var_map was built from the same variables in coeffs
1445                let i = *var_map
1446                    .get(&vars[0])
1447                    .expect("variable exists in var_map built from coeffs");
1448                let j = *var_map
1449                    .get(&vars[1])
1450                    .expect("variable exists in var_map built from coeffs");
1451
1452                // QUBO format requires i <= j
1453                if i == j {
1454                    // Diagonal term
1455                    matrix[[i, i]] += coeff;
1456                } else {
1457                    // Off-diagonal term - store full coefficient in upper triangular, zero in lower
1458                    if i <= j {
1459                        matrix[[i, j]] += coeff;
1460                    } else {
1461                        matrix[[j, i]] += coeff;
1462                    }
1463                }
1464            }
1465            _ => {
1466                // Higher-order terms are not supported in QUBO
1467                return Err(CompileError::DegreeTooHigh(vars.len(), 2));
1468            }
1469        }
1470    }
1471
1472    Ok((matrix, var_map))
1473}
1474
1475// Helper function to build the HOBO tensor
1476#[allow(dead_code)]
1477fn build_hobo_tensor(
1478    coeffs: &HashMap<Vec<String>, f64>,
1479    max_degree: usize,
1480) -> CompileResult<(
1481    Array<f64, scirs2_core::ndarray::IxDyn>,
1482    HashMap<String, usize>,
1483)> {
1484    // Collect all unique variable names
1485    let mut all_vars = HashSet::new();
1486    for vars in coeffs.keys() {
1487        for var in vars {
1488            all_vars.insert(var.clone());
1489        }
1490    }
1491
1492    // Convert to a sorted vector
1493    let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
1494    sorted_vars.sort();
1495
1496    // Create the variable-to-index mapping
1497    let var_map: HashMap<String, usize> = sorted_vars
1498        .iter()
1499        .enumerate()
1500        .map(|(i, var)| (var.clone(), i))
1501        .collect();
1502
1503    // Size of each dimension
1504    let n = var_map.len();
1505
1506    // Create shape vector for the tensor
1507    let shape: Vec<usize> = vec![n; max_degree];
1508
1509    // Create an empty tensor
1510    let mut tensor = Array::zeros(scirs2_core::ndarray::IxDyn(&shape));
1511
1512    // Fill the tensor with coefficients
1513    for (vars, &coeff) in coeffs {
1514        let degree = vars.len();
1515
1516        if degree == 0 {
1517            // Should never happen since constants are handled in offset
1518            continue;
1519        }
1520
1521        if degree > max_degree {
1522            return Err(CompileError::DegreeTooHigh(degree, max_degree));
1523        }
1524
1525        // Convert variable names to indices
1526        // SAFETY: var_map was built from the same variables in coeffs
1527        let mut indices: Vec<usize> = vars
1528            .iter()
1529            .map(|var| {
1530                *var_map
1531                    .get(var)
1532                    .expect("variable exists in var_map built from coeffs")
1533            })
1534            .collect();
1535
1536        // Sort indices (canonical ordering)
1537        indices.sort_unstable();
1538
1539        // Pad indices to match tensor order if necessary
1540        while indices.len() < max_degree {
1541            indices.insert(0, indices[0]); // Padding with first index
1542        }
1543
1544        // Set the coefficient in the tensor
1545        let idx = scirs2_core::ndarray::IxDyn(&indices);
1546        tensor[idx] += coeff;
1547    }
1548
1549    Ok((tensor, var_map))
1550}
1551
/// Compiler intended for problems dominated by one-hot constraints.
///
/// Currently it stores the expression and a verbosity flag and delegates all
/// compilation to [`Compile`]; the `verbose` flag is not consulted yet.
#[cfg(feature = "dwave")]
pub struct PieckCompile {
    /// Symbolic expression that will be compiled.
    expr: Expr,
    /// Verbosity flag (stored, but not read by `get_qubo`).
    verbose: bool,
}

#[cfg(feature = "dwave")]
impl PieckCompile {
    /// Creates a compiler from anything convertible into an [`Expr`].
    pub fn new<T: Into<Expr>>(expr: T, verbose: bool) -> Self {
        let expr: Expr = expr.into();
        Self { expr, verbose }
    }

    /// Compiles the stored expression into a QUBO matrix, a variable-to-index map
    /// and an energy offset.
    ///
    /// # Errors
    ///
    /// Propagates any [`CompileError`] produced by [`Compile`].
    pub fn get_qubo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        // No one-hot specific lowering exists yet, so reuse the regular compiler.
        let regular = Compile::new(self.expr.clone());
        regular.get_qubo()
    }
}