Skip to main content

quantrs2_tytan/
compile.rs

1//! Compilation of symbolic expressions to QUBO/HOBO models.
2//!
3//! This module provides utilities for compiling symbolic expressions
4//! into QUBO (Quadratic Unconstrained Binary Optimization) and
5//! HOBO (Higher-Order Binary Optimization) models.
6
7#![allow(dead_code)]
8
9use scirs2_core::ndarray::Array;
10use std::collections::{HashMap, HashSet};
11
12#[cfg(feature = "scirs")]
13use crate::scirs_stub;
14
15#[cfg(feature = "dwave")]
16use quantrs2_symengine_pure::Expression as SymEngineExpression;
17
18#[cfg(feature = "dwave")]
19type Expr = SymEngineExpression;
20use thiserror::Error;
21
22use quantrs2_anneal::QuboError;
23
24/// Unified expression interface for examples
25#[cfg(feature = "dwave")]
26pub mod expr {
27    use quantrs2_symengine_pure::Expression as SymEngineExpression;
28
29    pub type Expr = SymEngineExpression;
30
31    pub fn constant(value: f64) -> Expr {
32        SymEngineExpression::from(value)
33    }
34
35    pub fn var(name: &str) -> Expr {
36        SymEngineExpression::symbol(name)
37    }
38}
39
40#[cfg(not(feature = "dwave"))]
41pub mod expr {
42    use super::SimpleExpr;
43
44    pub type Expr = SimpleExpr;
45
46    pub const fn constant(value: f64) -> Expr {
47        SimpleExpr::constant(value)
48    }
49
50    pub fn var(name: &str) -> Expr {
51        SimpleExpr::var(name)
52    }
53}
54
55/// Errors that can occur during compilation
56#[derive(Error, Debug)]
57pub enum CompileError {
58    /// Error when the expression is invalid
59    #[error("Invalid expression: {0}")]
60    InvalidExpression(String),
61
62    /// Error when a term has too high a degree
63    #[error("Term has degree {0}, but maximum supported is {1}")]
64    DegreeTooHigh(usize, usize),
65
66    /// Error in the underlying QUBO model
67    #[error("QUBO error: {0}")]
68    QuboError(#[from] QuboError),
69
70    /// Error in Symengine operations
71    #[error("Symengine error: {0}")]
72    SymengineError(String),
73}
74
75/// Result type for compilation operations
76pub type CompileResult<T> = Result<T, CompileError>;
77
78// Simple expression type for when dwave feature is not enabled
79#[cfg(not(feature = "dwave"))]
80#[derive(Debug, Clone)]
81pub enum SimpleExpr {
82    /// Variable
83    Var(String),
84    /// Constant
85    Const(f64),
86    /// Addition
87    Add(Box<Self>, Box<Self>),
88    /// Multiplication
89    Mul(Box<Self>, Box<Self>),
90    /// Power
91    Pow(Box<Self>, i32),
92}
93
94#[cfg(not(feature = "dwave"))]
95impl SimpleExpr {
96    /// Create a variable
97    pub fn var(name: &str) -> Self {
98        Self::Var(name.to_string())
99    }
100
101    /// Create a constant
102    pub const fn constant(value: f64) -> Self {
103        Self::Const(value)
104    }
105}
106
107#[cfg(not(feature = "dwave"))]
108impl std::ops::Add for SimpleExpr {
109    type Output = Self;
110
111    fn add(self, rhs: Self) -> Self::Output {
112        Self::Add(Box::new(self), Box::new(rhs))
113    }
114}
115
116#[cfg(not(feature = "dwave"))]
117impl std::ops::Mul for SimpleExpr {
118    type Output = Self;
119
120    fn mul(self, rhs: Self) -> Self::Output {
121        Self::Mul(Box::new(self), Box::new(rhs))
122    }
123}
124
125#[cfg(not(feature = "dwave"))]
126impl std::iter::Sum for SimpleExpr {
127    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
128        iter.fold(Self::Const(0.0), |acc, x| acc + x)
129    }
130}
131
132/// High-level model for constraint optimization problems
133#[cfg(feature = "dwave")]
134#[derive(Debug, Clone)]
135pub struct Model {
136    /// Variables in the model
137    variables: HashSet<String>,
138    /// Objective function expression
139    objective: Option<Expr>,
140    /// Constraints
141    constraints: Vec<Constraint>,
142}
143
144/// Constraint types
145#[cfg(feature = "dwave")]
146#[derive(Debug, Clone)]
147enum Constraint {
148    /// Equality constraint: sum of variables equals value
149    Equality {
150        name: String,
151        expr: Expr,
152        value: f64,
153    },
154    /// Inequality constraint: sum of variables <= value
155    LessEqual {
156        name: String,
157        expr: Expr,
158        value: f64,
159    },
160    /// At most one constraint: at most one variable can be 1
161    AtMostOne { name: String, variables: Vec<Expr> },
162    /// Implication constraint: if any condition is true, then result must be true
163    ImpliesAny {
164        name: String,
165        conditions: Vec<Expr>,
166        result: Expr,
167    },
168}
169
170#[cfg(feature = "dwave")]
171impl Default for Model {
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
177#[cfg(feature = "dwave")]
178impl Model {
179    /// Create a new empty model
180    pub fn new() -> Self {
181        Self {
182            variables: HashSet::new(),
183            objective: None,
184            constraints: Vec::new(),
185        }
186    }
187
188    /// Add a variable to the model
189    pub fn add_variable(&mut self, name: &str) -> CompileResult<Expr> {
190        self.variables.insert(name.to_string());
191        Ok(SymEngineExpression::symbol(name))
192    }
193
194    /// Set the objective function
195    pub fn set_objective(&mut self, expr: Expr) {
196        self.objective = Some(expr);
197    }
198
199    /// Add constraint: exactly one of the variables must be 1
200    pub fn add_constraint_eq_one(&mut self, name: &str, variables: Vec<Expr>) -> CompileResult<()> {
201        // For binary variables, sum = 1 means exactly one is 1
202        let sum_expr = variables
203            .iter()
204            .fold(Expr::from(0), |acc, v| acc + v.clone());
205        self.constraints.push(Constraint::Equality {
206            name: name.to_string(),
207            expr: sum_expr,
208            value: 1.0,
209        });
210        Ok(())
211    }
212
213    /// Add constraint: at most one of the variables can be 1
214    pub fn add_constraint_at_most_one(
215        &mut self,
216        name: &str,
217        variables: Vec<Expr>,
218    ) -> CompileResult<()> {
219        self.constraints.push(Constraint::AtMostOne {
220            name: name.to_string(),
221            variables,
222        });
223        Ok(())
224    }
225
226    /// Add constraint: if any condition is true, then result must be true
227    pub fn add_constraint_implies_any(
228        &mut self,
229        name: &str,
230        conditions: Vec<Expr>,
231        result: Expr,
232    ) -> CompileResult<()> {
233        self.constraints.push(Constraint::ImpliesAny {
234            name: name.to_string(),
235            conditions,
236            result,
237        });
238        Ok(())
239    }
240
241    /// Compile the model to a CompiledModel
242    pub fn compile(&self) -> CompileResult<CompiledModel> {
243        // Build the final expression with penalty terms
244        let mut final_expr = self.objective.clone().unwrap_or_else(|| Expr::from(0));
245
246        // Default penalty weight
247        let penalty_weight = 10.0;
248
249        // Add penalty terms for constraints
250        for constraint in &self.constraints {
251            match constraint {
252                Constraint::Equality { expr, value, .. } => {
253                    // (expr - value)^2 penalty
254                    let diff = expr.clone() - Expr::from(*value);
255                    final_expr = final_expr + Expr::from(penalty_weight) * diff.clone() * diff;
256                }
257                #[cfg(feature = "dwave")]
258                Constraint::LessEqual { expr, value, .. } => {
259                    // max(0, expr - value)^2 penalty
260                    // For simplicity, we'll use a quadratic penalty
261                    let excess = expr.clone() - Expr::from(*value);
262                    final_expr = final_expr + Expr::from(penalty_weight) * excess.clone() * excess;
263                }
264                Constraint::AtMostOne { variables, .. } => {
265                    // Penalty: sum(xi * xj) for all i < j
266                    for i in 0..variables.len() {
267                        for j in (i + 1)..variables.len() {
268                            final_expr = final_expr
269                                + Expr::from(penalty_weight)
270                                    * variables[i].clone()
271                                    * variables[j].clone();
272                        }
273                    }
274                }
275                Constraint::ImpliesAny {
276                    conditions, result, ..
277                } => {
278                    // If any condition is true, result must be true
279                    // Penalty: (max(conditions) - result)^2 where max is approximated by sum
280                    let conditions_sum = conditions
281                        .iter()
282                        .fold(Expr::from(0), |acc, c| acc + c.clone());
283                    // Penalty when conditions_sum > 0 and result = 0
284                    final_expr = final_expr
285                        + Expr::from(penalty_weight)
286                            * conditions_sum
287                            * (Expr::from(1) - result.clone());
288                }
289            }
290        }
291
292        // Use the standard compiler
293        let mut compiler = Compile::new(final_expr);
294        let ((qubo_matrix, var_map), offset) = compiler.get_qubo()?;
295
296        Ok(CompiledModel {
297            qubo_matrix,
298            var_map,
299            offset,
300            constraints: self.constraints.clone(),
301        })
302    }
303}
304
305/// Compiled model ready for sampling
306#[cfg(feature = "dwave")]
307#[derive(Debug, Clone)]
308pub struct CompiledModel {
309    /// QUBO matrix
310    pub qubo_matrix: Array<f64, scirs2_core::ndarray::Ix2>,
311    /// Variable name to index mapping
312    pub var_map: HashMap<String, usize>,
313    /// Constant offset
314    pub offset: f64,
315    /// Original constraints (for analysis)
316    constraints: Vec<Constraint>,
317}
318
319#[cfg(feature = "dwave")]
320impl CompiledModel {
321    /// Convert to QUBO format
322    pub fn to_qubo(&self) -> quantrs2_anneal::ising::QuboModel {
323        use quantrs2_anneal::ising::QuboModel;
324
325        let mut qubo = QuboModel::new(self.var_map.len());
326
327        // Set the offset
328        qubo.offset = self.offset;
329
330        // Set all the QUBO coefficients
331        for i in 0..self.qubo_matrix.nrows() {
332            for j in i..self.qubo_matrix.ncols() {
333                let value = self.qubo_matrix[[i, j]];
334                if value.abs() > 1e-10 {
335                    if i == j {
336                        // Diagonal term (linear)
337                        // SAFETY: index i is derived from matrix dimensions which match QuboModel size
338                        qubo.set_linear(i, value)
339                            .expect("index within bounds from matrix dimensions");
340                    } else {
341                        // Off-diagonal term (quadratic)
342                        // SAFETY: indices i,j are derived from matrix dimensions which match QuboModel size
343                        qubo.set_quadratic(i, j, value)
344                            .expect("indices within bounds from matrix dimensions");
345                    }
346                }
347            }
348        }
349
350        qubo
351    }
352
353    /// Count the number of constraint violations for a given variable assignment.
354    ///
355    /// Returns the count of constraints that the assignment violates.
356    /// `assignments` maps variable names to their binary values (true = 1, false = 0).
357    pub fn count_constraint_violations(&self, assignments: &HashMap<String, bool>) -> usize {
358        let float_vals: HashMap<String, f64> = assignments
359            .iter()
360            .map(|(k, &v)| (k.clone(), if v { 1.0 } else { 0.0 }))
361            .collect();
362
363        let mut violations = 0usize;
364
365        for constraint in &self.constraints {
366            let violated = match constraint {
367                Constraint::Equality { expr, value, .. } => match expr.eval(&float_vals) {
368                    Ok(result) => (result - value).abs() > 1e-6,
369                    Err(_) => false,
370                },
371                Constraint::LessEqual { expr, value, .. } => match expr.eval(&float_vals) {
372                    Ok(result) => result > value + 1e-6,
373                    Err(_) => false,
374                },
375                Constraint::AtMostOne { variables, .. } => {
376                    let count: f64 = variables
377                        .iter()
378                        .filter_map(|v| v.eval(&float_vals).ok())
379                        .filter(|&val| val > 0.5)
380                        .count() as f64;
381                    count > 1.0 + 1e-6
382                }
383                Constraint::ImpliesAny {
384                    conditions, result, ..
385                } => {
386                    let any_condition_true = conditions
387                        .iter()
388                        .any(|c| c.eval(&float_vals).map(|val| val > 0.5).unwrap_or(false));
389                    if any_condition_true {
390                        match result.eval(&float_vals) {
391                            Ok(val) => val < 0.5,
392                            Err(_) => false,
393                        }
394                    } else {
395                        false
396                    }
397                }
398            };
399            if violated {
400                violations += 1;
401            }
402        }
403
404        violations
405    }
406
407    /// Return the total number of constraints in this model.
408    pub fn num_constraints(&self) -> usize {
409        self.constraints.len()
410    }
411}
412
413/// High-level model for constraint optimization problems (non-dwave version)
414#[cfg(not(feature = "dwave"))]
415#[derive(Debug, Clone)]
416pub struct Model {
417    /// Variables in the model
418    variables: HashSet<String>,
419    /// Objective function expression
420    objective: Option<SimpleExpr>,
421    /// Constraints
422    constraints: Vec<Constraint>,
423}
424
425/// Constraint types (non-dwave version)
426#[cfg(not(feature = "dwave"))]
427#[derive(Debug, Clone)]
428enum Constraint {
429    /// Equality constraint: sum of variables equals value
430    Equality {
431        name: String,
432        expr: SimpleExpr,
433        value: f64,
434    },
435    /// At most one constraint: at most one variable can be 1
436    AtMostOne {
437        name: String,
438        variables: Vec<SimpleExpr>,
439    },
440    /// Implication constraint: if any condition is true, then result must be true
441    ImpliesAny {
442        name: String,
443        conditions: Vec<SimpleExpr>,
444        result: SimpleExpr,
445    },
446}
447
448#[cfg(not(feature = "dwave"))]
449impl Default for Model {
450    fn default() -> Self {
451        Self::new()
452    }
453}
454
455#[cfg(not(feature = "dwave"))]
456impl Model {
457    /// Create a new empty model
458    pub fn new() -> Self {
459        Self {
460            variables: HashSet::new(),
461            objective: None,
462            constraints: Vec::new(),
463        }
464    }
465
466    /// Add a variable to the model
467    pub fn add_variable(&mut self, name: &str) -> CompileResult<SimpleExpr> {
468        self.variables.insert(name.to_string());
469        Ok(SimpleExpr::var(name))
470    }
471
472    /// Set the objective function
473    pub fn set_objective(&mut self, expr: SimpleExpr) {
474        self.objective = Some(expr);
475    }
476
477    /// Add constraint: exactly one of the variables must be 1
478    pub fn add_constraint_eq_one(
479        &mut self,
480        name: &str,
481        variables: Vec<SimpleExpr>,
482    ) -> CompileResult<()> {
483        let sum_expr = variables.into_iter().sum();
484        self.constraints.push(Constraint::Equality {
485            name: name.to_string(),
486            expr: sum_expr,
487            value: 1.0,
488        });
489        Ok(())
490    }
491
492    /// Add constraint: at most one of the variables can be 1
493    pub fn add_constraint_at_most_one(
494        &mut self,
495        name: &str,
496        variables: Vec<SimpleExpr>,
497    ) -> CompileResult<()> {
498        self.constraints.push(Constraint::AtMostOne {
499            name: name.to_string(),
500            variables,
501        });
502        Ok(())
503    }
504
505    /// Add constraint: if any condition is true, then result must be true
506    pub fn add_constraint_implies_any(
507        &mut self,
508        name: &str,
509        conditions: Vec<SimpleExpr>,
510        result: SimpleExpr,
511    ) -> CompileResult<()> {
512        self.constraints.push(Constraint::ImpliesAny {
513            name: name.to_string(),
514            conditions,
515            result,
516        });
517        Ok(())
518    }
519
520    /// Compile the model to a CompiledModel
521    pub fn compile(&self) -> CompileResult<CompiledModel> {
522        // Build QUBO directly from constraints
523        let mut qubo_terms: HashMap<(String, String), f64> = HashMap::new();
524        let mut offset = 0.0;
525        let penalty_weight = 10.0;
526
527        // Process objective if present
528        if let Some(ref obj) = self.objective {
529            self.add_expr_to_qubo(obj, 1.0, &mut qubo_terms, &mut offset)?;
530        }
531
532        // Process constraints
533        for constraint in &self.constraints {
534            match constraint {
535                Constraint::Equality { expr, value, .. } => {
536                    // (expr - value)^2 penalty
537                    // Expand: expr^2 - 2*expr*value + value^2
538                    self.add_expr_squared_to_qubo(
539                        expr,
540                        penalty_weight,
541                        &mut qubo_terms,
542                        &mut offset,
543                    )?;
544                    self.add_expr_to_qubo(
545                        expr,
546                        -2.0 * penalty_weight * value,
547                        &mut qubo_terms,
548                        &mut offset,
549                    )?;
550                    offset += penalty_weight * value * value;
551                }
552                Constraint::AtMostOne { variables, .. } => {
553                    // Penalty: sum(xi * xj) for all i < j
554                    for i in 0..variables.len() {
555                        for j in (i + 1)..variables.len() {
556                            if let (SimpleExpr::Var(vi), SimpleExpr::Var(vj)) =
557                                (&variables[i], &variables[j])
558                            {
559                                let key = if vi < vj {
560                                    (vi.clone(), vj.clone())
561                                } else {
562                                    (vj.clone(), vi.clone())
563                                };
564                                *qubo_terms.entry(key).or_insert(0.0) += penalty_weight;
565                            }
566                        }
567                    }
568                }
569                Constraint::ImpliesAny {
570                    conditions, result, ..
571                } => {
572                    // Penalty: sum(conditions) * (1 - result)
573                    for cond in conditions {
574                        if let (SimpleExpr::Var(c), SimpleExpr::Var(r)) = (cond, result) {
575                            let key = if c < r {
576                                (c.clone(), r.clone())
577                            } else {
578                                (r.clone(), c.clone())
579                            };
580                            *qubo_terms.entry(key).or_insert(0.0) -= penalty_weight;
581                        }
582                        // Also add linear term for condition
583                        if let SimpleExpr::Var(c) = cond {
584                            *qubo_terms.entry((c.clone(), c.clone())).or_insert(0.0) +=
585                                penalty_weight;
586                        }
587                    }
588                }
589            }
590        }
591
592        // Convert to matrix form
593        let all_vars: HashSet<String> = qubo_terms
594            .keys()
595            .flat_map(|(v1, v2)| vec![v1.clone(), v2.clone()])
596            .collect();
597        let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
598        sorted_vars.sort();
599
600        let var_map: HashMap<String, usize> = sorted_vars
601            .iter()
602            .enumerate()
603            .map(|(i, v)| (v.clone(), i))
604            .collect();
605
606        let n = var_map.len();
607        let mut matrix = Array::zeros((n, n));
608
609        for ((v1, v2), coeff) in qubo_terms {
610            let i = var_map[&v1];
611            let j = var_map[&v2];
612            if i == j {
613                matrix[[i, i]] += coeff;
614            } else {
615                matrix[[i, j]] += coeff / 2.0;
616                matrix[[j, i]] += coeff / 2.0;
617            }
618        }
619
620        Ok(CompiledModel {
621            qubo_matrix: matrix,
622            var_map,
623            offset,
624            constraints: self.constraints.clone(),
625        })
626    }
627
628    /// Add expression to QUBO terms
629    fn add_expr_to_qubo(
630        &self,
631        expr: &SimpleExpr,
632        coeff: f64,
633        terms: &mut HashMap<(String, String), f64>,
634        offset: &mut f64,
635    ) -> CompileResult<()> {
636        match expr {
637            SimpleExpr::Var(name) => {
638                *terms.entry((name.clone(), name.clone())).or_insert(0.0) += coeff;
639            }
640            SimpleExpr::Const(val) => {
641                *offset += coeff * val;
642            }
643            SimpleExpr::Add(left, right) => {
644                self.add_expr_to_qubo(left, coeff, terms, offset)?;
645                self.add_expr_to_qubo(right, coeff, terms, offset)?;
646            }
647            SimpleExpr::Mul(left, right) => {
648                if let (SimpleExpr::Var(v1), SimpleExpr::Var(v2)) = (left.as_ref(), right.as_ref())
649                {
650                    let key = if v1 < v2 {
651                        (v1.clone(), v2.clone())
652                    } else {
653                        (v2.clone(), v1.clone())
654                    };
655                    *terms.entry(key).or_insert(0.0) += coeff;
656                } else if let (SimpleExpr::Const(c), var) | (var, SimpleExpr::Const(c)) =
657                    (left.as_ref(), right.as_ref())
658                {
659                    self.add_expr_to_qubo(var, coeff * c, terms, offset)?;
660                }
661            }
662            SimpleExpr::Pow(base, exp) => {
663                if *exp == 2 && matches!(base.as_ref(), SimpleExpr::Var(_)) {
664                    // x^2 = x for binary variables
665                    self.add_expr_to_qubo(base, coeff, terms, offset)?;
666                }
667            }
668        }
669        Ok(())
670    }
671
672    /// Add expression squared to QUBO terms
673    fn add_expr_squared_to_qubo(
674        &self,
675        expr: &SimpleExpr,
676        coeff: f64,
677        terms: &mut HashMap<(String, String), f64>,
678        offset: &mut f64,
679    ) -> CompileResult<()> {
680        // For simplicity, only handle simple cases
681        match expr {
682            SimpleExpr::Var(name) => {
683                // x^2 = x for binary
684                *terms.entry((name.clone(), name.clone())).or_insert(0.0) += coeff;
685            }
686            SimpleExpr::Add(left, right) => {
687                // (a + b)^2 = a^2 + 2ab + b^2
688                self.add_expr_squared_to_qubo(left, coeff, terms, offset)?;
689                self.add_expr_squared_to_qubo(right, coeff, terms, offset)?;
690                // Cross term
691                if let (SimpleExpr::Var(v1), SimpleExpr::Var(v2)) = (left.as_ref(), right.as_ref())
692                {
693                    let key = if v1 < v2 {
694                        (v1.clone(), v2.clone())
695                    } else {
696                        (v2.clone(), v1.clone())
697                    };
698                    *terms.entry(key).or_insert(0.0) += 2.0 * coeff;
699                }
700            }
701            _ => {}
702        }
703        Ok(())
704    }
705}
706
707/// Compiled model ready for sampling (non-dwave version)
708#[cfg(not(feature = "dwave"))]
709#[derive(Debug, Clone)]
710pub struct CompiledModel {
711    /// QUBO matrix
712    pub qubo_matrix: Array<f64, scirs2_core::ndarray::Ix2>,
713    /// Variable name to index mapping
714    pub var_map: HashMap<String, usize>,
715    /// Constant offset
716    pub offset: f64,
717    /// Original constraints (for analysis)
718    constraints: Vec<Constraint>,
719}
720
721#[cfg(not(feature = "dwave"))]
722impl CompiledModel {
723    /// Convert to QUBO format
724    pub fn to_qubo(&self) -> quantrs2_anneal::ising::QuboModel {
725        use quantrs2_anneal::ising::QuboModel;
726
727        let mut qubo = QuboModel::new(self.var_map.len());
728
729        // Set the offset
730        qubo.offset = self.offset;
731
732        // Set all the QUBO coefficients
733        for i in 0..self.qubo_matrix.nrows() {
734            for j in i..self.qubo_matrix.ncols() {
735                let value = self.qubo_matrix[[i, j]];
736                if value.abs() > 1e-10 {
737                    if i == j {
738                        // Diagonal term (linear)
739                        // SAFETY: index i is derived from matrix dimensions which match QuboModel size
740                        qubo.set_linear(i, value)
741                            .expect("index within bounds from matrix dimensions");
742                    } else {
743                        // Off-diagonal term (quadratic)
744                        // SAFETY: indices i,j are derived from matrix dimensions which match QuboModel size
745                        qubo.set_quadratic(i, j, value)
746                            .expect("indices within bounds from matrix dimensions");
747                    }
748                }
749            }
750        }
751
752        qubo
753    }
754}
755
756/// Compiler for converting symbolic expressions to QUBO models
757///
758/// This struct provides methods for converting symbolic expressions
759/// to QUBO models, which can then be solved using quantum annealing.
760#[cfg(feature = "dwave")]
761pub struct Compile {
762    /// The symbolic expression to compile
763    expr: Expr,
764}
765
766#[cfg(feature = "dwave")]
767impl Compile {
768    /// Create a new compiler with the given expression
769    pub fn new<T: Into<Expr>>(expr: T) -> Self {
770        Self { expr: expr.into() }
771    }
772
773    /// Compile the expression to a QUBO model
774    ///
775    /// This method compiles the symbolic expression to a QUBO model,
776    /// which can then be passed to a sampler for solving.
777    ///
778    /// # Returns
779    ///
780    /// A tuple containing:
781    /// - A tuple with the QUBO matrix and a mapping of variable names to indices
782    /// - An offset value that should be added to all energy values
783    pub fn get_qubo(
784        &self,
785    ) -> CompileResult<(
786        (
787            Array<f64, scirs2_core::ndarray::Ix2>,
788            HashMap<String, usize>,
789        ),
790        f64,
791    )> {
792        #[cfg(feature = "scirs")]
793        {
794            self.get_qubo_scirs()
795        }
796        #[cfg(not(feature = "scirs"))]
797        {
798            self.get_qubo_standard()
799        }
800    }
801
802    /// Standard QUBO compilation without SciRS2
803    fn get_qubo_standard(
804        &self,
805    ) -> CompileResult<(
806        (
807            Array<f64, scirs2_core::ndarray::Ix2>,
808            HashMap<String, usize>,
809        ),
810        f64,
811    )> {
812        // Expand the expression to simplify
813        let expr = self.expr.expand();
814
815        // Replace all second-degree terms (x^2 and x*x) with x, since x^2 = x for binary variables
816        // Do this BEFORE degree checking so that x^2 terms correctly appear as degree-1 after reduction
817        let expr = replace_squared_terms(&expr)?;
818
819        // Extract the coefficients and variables
820        let (coeffs, offset) = extract_coefficients(&expr)?;
821
822        // Check the actual degree using the extracted coefficient map (reliable, symbolic-expression-agnostic)
823        let max_degree = coeffs.keys().map(|vars| vars.len()).max().unwrap_or(0);
824        if max_degree > 2 {
825            return Err(CompileError::DegreeTooHigh(max_degree, 2));
826        }
827
828        // Convert to a QUBO matrix
829        let (matrix, var_map) = build_qubo_matrix(&coeffs)?;
830
831        Ok(((matrix, var_map), offset))
832    }
833
834    /// QUBO compilation with SciRS2 optimization
835    #[cfg(feature = "scirs")]
836    fn get_qubo_scirs(
837        &self,
838    ) -> CompileResult<(
839        (
840            Array<f64, scirs2_core::ndarray::Ix2>,
841            HashMap<String, usize>,
842        ),
843        f64,
844    )> {
845        // Get standard result
846        let ((matrix, var_map), offset) = self.get_qubo_standard()?;
847
848        // Apply SciRS2 enhancements
849        let enhanced_matrix = crate::scirs_stub::enhance_qubo_matrix(&matrix);
850
851        Ok(((enhanced_matrix, var_map), offset))
852    }
853
854    /// Compile the expression to a HOBO model
855    ///
856    /// This method compiles the symbolic expression to a Higher-Order Binary Optimization model,
857    /// which can handle terms of degree higher than 2.
858    ///
859    /// # Returns
860    ///
861    /// A tuple containing:
862    /// - A tuple with the HOBO tensor and a mapping of variable names to indices
863    /// - An offset value that should be added to all energy values
864    pub fn get_hobo(
865        &self,
866    ) -> CompileResult<(
867        (
868            Array<f64, scirs2_core::ndarray::IxDyn>,
869            HashMap<String, usize>,
870        ),
871        f64,
872    )> {
873        // Expand the expression to simplify
874        let mut expr = self.expr.expand();
875
876        // Calculate highest degree (dimension of the tensor)
877        let max_degree = calc_highest_degree(&expr)?;
878
879        // Replace all squared terms (x^2) with x, since x^2 = x for binary variables
880        let mut expr = replace_squared_terms(&expr)?;
881
882        // Expand again to collect like terms
883        let mut expr = expr.expand();
884
885        // Extract the coefficients and variables
886        let (coeffs, offset) = extract_coefficients(&expr)?;
887
888        // Build the HOBO tensor
889        let (tensor, var_map) = build_hobo_tensor(&coeffs, max_degree)?;
890
891        Ok(((tensor, var_map), offset))
892    }
893}
894
// Helper function to calculate the highest degree in the expression
//
// Returns the largest total degree of any monomial, counting exponents as they
// are written (so `x^2` counts as 2 even though it is later reduced to `x`).
// Recognised symbolic node kinds are handled structurally; otherwise, if the
// printed expression looks like a sum, each textual term is scored with
// `estimate_term_degree`.
//
// Errors with `InvalidExpression` when a symbol's numeric exponent cannot be
// read as f64, or when the expression shape is not recognised at all.
#[cfg(feature = "dwave")]
fn calc_highest_degree(expr: &Expr) -> CompileResult<usize> {
    // A single variable has degree 1.
    if expr.is_symbol() {
        return Ok(1);
    }

    // A numeric constant has degree 0.
    if expr.is_number() {
        return Ok(0);
    }

    // Negation does not change the degree.
    if expr.is_neg() {
        // SAFETY: is_neg() check guarantees as_neg() will succeed
        let inner = expr.as_neg().expect("is_neg() was true");
        return calc_highest_degree(&inner);
    }

    // Power such as x^2 or (x*y)^2.
    if expr.is_pow() {
        // SAFETY: is_pow() check guarantees as_pow() will succeed
        let (base, exp) = expr.as_pow().expect("is_pow() was true");

        // Symbol raised to a numeric exponent: the exponent itself is the degree.
        if base.is_symbol() && exp.is_number() {
            let exp_val = match exp.to_f64() {
                Some(n) => n,
                None => {
                    return Err(CompileError::InvalidExpression(
                        "Invalid exponent".to_string(),
                    ))
                }
            };

            // Check if exponent is a non-negative integer
            if exp_val.is_sign_positive() && exp_val.fract() == 0.0 {
                return Ok(exp_val as usize);
            }
        }

        // General case: degree(base) * exponent. Only non-negative integer
        // constant exponents contribute; anything else counts as 0.
        let base_degree = calc_highest_degree(&base)?;
        let exp_degree = if exp.is_number() {
            match exp.to_f64() {
                Some(n) if n.is_sign_positive() && n.fract() == 0.0 => n as usize,
                _ => 0,
            }
        } else {
            0
        };

        return Ok(base_degree * exp_degree);
    }

    // Product (like x*y or x*x): degrees of the factors add up.
    if expr.is_mul() {
        let mut total_degree = 0;
        // SAFETY: is_mul() check guarantees as_mul() will succeed
        for factor in expr.as_mul().expect("is_mul() was true") {
            total_degree += calc_highest_degree(&factor)?;
        }
        return Ok(total_degree);
    }

    // Sum (like x + y): the highest degree of any summand.
    if expr.is_add() {
        let mut max_degree = 0;
        // SAFETY: is_add() check guarantees as_add() will succeed
        for term in expr.as_add().expect("is_add() was true") {
            let term_degree = calc_highest_degree(&term)?;
            max_degree = std::cmp::max(max_degree, term_degree);
        }
        return Ok(max_degree);
    }

    // Workaround for symengine type detection issues: a sum-like expression that
    // was not recognised as ADD is scored term by term from its printed form.
    let expr_str = format!("{expr}");
    if expr_str.contains('+') || expr_str.contains('-') {
        let max_degree = expr_str
            .split(['+', '-'])
            .map(str::trim)
            .filter(|part| !part.is_empty())
            .map(estimate_term_degree)
            .max()
            .unwrap_or(0);
        return Ok(max_degree);
    }

    Err(CompileError::InvalidExpression(format!(
        "Can't determine degree of: {expr}"
    )))
}

/// Estimates the polynomial degree of one textual product term, e.g. `3*x**2*y`.
///
/// The term is split into `*`-separated factors after normalising `**` to `^`.
/// Numeric factors, and numeric bases such as `2**3`, contribute 0. A variable
/// contributes its exponent, or 1 when it has none. An exponent that does not
/// parse as a non-negative integer is assumed to be 2, matching the previous
/// heuristic.
#[cfg(any(feature = "dwave", test))]
fn estimate_term_degree(term: &str) -> usize {
    term.replace("**", "^")
        .split('*')
        .map(str::trim)
        .filter(|factor| !factor.is_empty())
        .map(|factor| match factor.split_once('^') {
            Some((base, exponent)) => {
                if base.trim().parse::<f64>().is_ok() {
                    // Numeric power: a constant, no variable degree.
                    0
                } else {
                    exponent.trim().parse::<usize>().unwrap_or(2)
                }
            }
            None if factor.parse::<f64>().is_ok() => 0,
            None => 1,
        })
        .sum()
}
1038
// Helper function to replace squared terms with linear terms
//
// For binary variables x ∈ {0,1}, x^k = x for every integer k >= 1 and
// x*x = x. This reduces every symbol raised to a positive integer power to the
// symbol itself, and drops repeated symbol factors inside a product. This
// matches how `extract_term_coefficients` already treats `x^k` (k >= 1)
// inside products. Previously only k == 2 was reduced, so a standalone `x^3`
// later failed extraction, and `x*x` was reduced only in two-factor products.
//
// Errors with `InvalidExpression` when a symbol's numeric exponent cannot be
// read as f64.
#[cfg(feature = "dwave")]
fn replace_squared_terms(expr: &Expr) -> CompileResult<Expr> {
    // Symbols and numbers are already in reduced form.
    if expr.is_symbol() || expr.is_number() {
        return Ok(expr.clone());
    }

    // Negation: reduce the inner expression and negate it again.
    if expr.is_neg() {
        // SAFETY: is_neg() check guarantees as_neg() will succeed
        let inner = expr.as_neg().expect("is_neg() was true");
        let new_inner = replace_squared_terms(&inner)?;
        return Ok(-new_inner);
    }

    // Power operation (like x^2 or x^3)
    if expr.is_pow() {
        // SAFETY: is_pow() check guarantees as_pow() will succeed
        let (base, exp) = expr.as_pow().expect("is_pow() was true");

        if base.is_symbol() && exp.is_number() {
            let exp_val = match exp.to_f64() {
                Some(n) => n,
                None => {
                    return Err(CompileError::InvalidExpression(
                        "Invalid exponent".to_string(),
                    ))
                }
            };

            // x^k == x for binary x whenever k is a positive integer.
            // (NaN/infinite exponents fail the `fract` test and are kept as-is.)
            if exp_val >= 1.0 && exp_val.fract() == 0.0 {
                return Ok(base);
            }
        }

        // Any other power: reduce inside the base and keep the exponent.
        let new_base = replace_squared_terms(&base)?;
        return Ok(new_base.pow(&exp));
    }

    // Product (like x*y or x*x)
    if expr.is_mul() {
        let mut new_terms = Vec::new();
        let mut seen_symbols = HashSet::new();
        // SAFETY: is_mul() check guarantees as_mul() will succeed
        for factor in expr.as_mul().expect("is_mul() was true") {
            let reduced = replace_squared_terms(&factor)?;
            // x*x == x for binary variables: keep only the first occurrence of
            // each symbol (this also covers x^2*x once x^2 has become x).
            if let Some(name) = reduced.as_symbol() {
                if !seen_symbols.insert(name.to_string()) {
                    continue;
                }
            }
            new_terms.push(reduced);
        }

        // Combine the terms back into a product (without identity element)
        if new_terms.is_empty() {
            return Ok(Expr::from(1));
        }
        let mut result = new_terms.remove(0);
        for term in new_terms {
            result = result * term;
        }
        return Ok(result);
    }

    // Sum (like x + y): reduce each summand independently.
    if expr.is_add() {
        let mut new_terms = Vec::new();
        // SAFETY: is_add() check guarantees as_add() will succeed
        for term in expr.as_add().expect("is_add() was true") {
            new_terms.push(replace_squared_terms(&term)?);
        }

        // Combine the terms back into a sum (without identity element)
        if new_terms.is_empty() {
            return Ok(Expr::from(0));
        }
        let mut result = new_terms.remove(0);
        for term in new_terms {
            result = result + term;
        }
        return Ok(result);
    }

    // For any other type of expression, just return it unchanged
    Ok(expr.clone())
}
1137
// Helper function to extract coefficients and variables from the expression
//
// Returns a map from the sorted variable list of each monomial to its
// coefficient, plus the constant offset. ADD expressions are processed summand
// by summand via `extract_term_coefficients`. If the printed form contains a
// `+` or `-` but symengine did not report ADD, the printed form is parsed
// textually instead. Otherwise the whole expression is treated as one term.
#[cfg(feature = "dwave")]
fn extract_coefficients(expr: &Expr) -> CompileResult<(HashMap<Vec<String>, f64>, f64)> {
    let mut coeffs = HashMap::new();
    let mut offset = 0.0;

    // Process expression as a sum of terms
    if expr.is_add() {
        // SAFETY: is_add() check guarantees as_add() will succeed
        for term in expr.as_add().expect("is_add() was true") {
            let (term_coeffs, term_offset) = extract_term_coefficients(&term)?;
            for (vars, coeff) in term_coeffs {
                *coeffs.entry(vars).or_insert(0.0) += coeff;
            }
            offset += term_offset;
        }
        return Ok((coeffs, offset));
    }

    // Workaround for symengine type detection issues: a sum-like expression that
    // was not recognised as ADD is parsed from its printed form.
    let expr_str = format!("{expr}");
    if expr_str.contains('+') || expr_str.contains('-') {
        for (sign, term) in split_signed_terms(&expr_str) {
            let (coeff, vars) = parse_product_term(term);
            if vars.is_empty() {
                // Purely numeric term (e.g. `3` or `2*3`) belongs to the offset.
                offset += sign * coeff;
            } else {
                *coeffs.entry(vars).or_insert(0.0) += sign * coeff;
            }
        }
        return Ok((coeffs, offset));
    }

    // Anything else is a single term.
    extract_term_coefficients(expr)
}

/// Splits the printed form of a sum, e.g. `2*x - y**2 + 3`, into signed terms.
///
/// Each returned pair is `(±1.0, trimmed_term)`. A `+`/`-` directly after `*`,
/// `^` or `(` is treated as a unary sign (`2*-x`, `x**(-1)`). A sign following
/// the `e`/`E` of a numeric literal is treated as an exponent sign (`1e-05`).
/// Neither kind starts a new term. Empty terms are skipped.
#[cfg(any(feature = "dwave", test))]
fn split_signed_terms(expr_str: &str) -> Vec<(f64, &str)> {
    let mut terms = Vec::new();
    let mut sign = 1.0;
    let mut start = 0;
    let mut prev: Option<char> = None;
    let mut prev_prev: Option<char> = None;

    for (pos, ch) in expr_str.char_indices() {
        if matches!(ch, '+' | '-') {
            let embedded = match prev {
                Some('*' | '^' | '(') => true,
                Some('e' | 'E') => {
                    matches!(prev_prev, Some(c) if c.is_ascii_digit() || c == '.')
                }
                _ => false,
            };
            if !embedded {
                let term = expr_str[start..pos].trim();
                if !term.is_empty() {
                    terms.push((sign, term));
                }
                sign = if ch == '-' { -1.0 } else { 1.0 };
                start = pos + ch.len_utf8();
            }
        }
        prev_prev = prev;
        prev = Some(ch);
    }

    let tail = expr_str[start..].trim();
    if !tail.is_empty() {
        terms.push((sign, tail));
    }
    terms
}

/// Parses one textual product term into `(coefficient, sorted variables)`.
///
/// Factors are separated by `*` after `**` is normalised to `^`. Behaviour per
/// factor:
/// - Numeric factors multiply the coefficient.
/// - A leading unary sign on a factor flips the coefficient.
/// - A numeric power such as `2**3` is folded into the coefficient.
/// - A variable raised to any power contributes the bare variable, since
///   x^k == x for binary x.
#[cfg(any(feature = "dwave", test))]
fn parse_product_term(term: &str) -> (f64, Vec<String>) {
    let normalised = term.replace("**", "^");
    let mut coeff = 1.0;
    let mut vars = Vec::new();

    for raw in normalised.split('*').map(str::trim).filter(|f| !f.is_empty()) {
        let factor = match raw.strip_prefix('-') {
            Some(rest) => {
                coeff = -coeff;
                rest.trim_start()
            }
            None => raw.strip_prefix('+').unwrap_or(raw).trim_start(),
        };
        if factor.is_empty() {
            continue;
        }
        if let Ok(num) = factor.parse::<f64>() {
            coeff *= num;
            continue;
        }
        match factor.split_once('^') {
            Some((base, exponent)) => {
                let base = base.trim();
                match (base.parse::<f64>(), exponent.trim().parse::<f64>()) {
                    (Ok(b), Ok(e)) => coeff *= b.powf(e),
                    _ => vars.push(base.to_string()),
                }
            }
            None => vars.push(factor.to_string()),
        }
    }

    // Sort variables for consistent ordering
    vars.sort();
    (coeff, vars)
}
1255
// Helper function to extract coefficient and variables from a single term
//
// Returns `(monomials, offset)`, where `monomials` maps the sorted variable
// names of each monomial to its weight, and `offset` collects numeric
// constants. The cases are:
// - Numbers go to the offset.
// - Sums and negations are processed recursively.
// - A symbol becomes a unit-weight linear monomial.
// - Products are flattened, and `x^k` (k a positive integer) is read as `x`.
// - Anything else is rejected with `InvalidExpression`.
#[cfg(feature = "dwave")]
fn extract_term_coefficients(term: &Expr) -> CompileResult<(HashMap<Vec<String>, f64>, f64)> {
    // Bare numeric constant: contributes only to the offset.
    if term.is_number() {
        let value = term
            .to_f64()
            .ok_or_else(|| CompileError::InvalidExpression("Invalid number".to_string()))?;
        return Ok((HashMap::new(), value));
    }

    // Sum: handle each summand and merge monomials that coincide.
    if term.is_add() {
        let mut merged: HashMap<Vec<String>, f64> = HashMap::new();
        let mut constant = 0.0;
        // SAFETY: is_add() check guarantees as_add() will succeed
        for summand in term.as_add().expect("is_add() was true") {
            let (part, part_constant) = extract_term_coefficients(&summand)?;
            for (monomial, weight) in part {
                *merged.entry(monomial).or_insert(0.0) += weight;
            }
            constant += part_constant;
        }
        return Ok((merged, constant));
    }

    // Negation: flip the sign of every weight and of the offset.
    if term.is_neg() {
        // SAFETY: is_neg() check guarantees as_neg() will succeed
        let inner = term.as_neg().expect("is_neg() was true");
        let (inner_coeffs, inner_offset) = extract_term_coefficients(&inner)?;
        let negated: HashMap<Vec<String>, f64> = inner_coeffs
            .into_iter()
            .map(|(monomial, weight)| (monomial, -weight))
            .collect();
        return Ok((negated, -inner_offset));
    }

    // Lone symbol: a linear monomial with unit weight.
    if term.is_symbol() {
        // SAFETY: is_symbol() check guarantees as_symbol() will succeed
        let name = term.as_symbol().expect("is_symbol() was true");
        return Ok((HashMap::from([(vec![name.to_string()], 1.0)]), 0.0));
    }

    // Product: multiply numeric factors into the weight, collect variables.
    if term.is_mul() {
        let mut weight = 1.0;
        let mut monomial = Vec::new();

        // SAFETY: is_mul() check guarantees as_mul() will succeed
        let factors = term.as_mul().expect("is_mul() was true");
        // Work list that flattens nested products such as (* (* x y) z).
        let mut pending: Vec<_> = factors.into_iter().collect();
        while let Some(factor) = pending.pop() {
            if factor.is_number() {
                weight *= factor.to_f64().ok_or_else(|| {
                    CompileError::InvalidExpression("Invalid number in product".to_string())
                })?;
            } else if factor.is_symbol() {
                // SAFETY: is_symbol() check guarantees as_symbol() will succeed
                let name = factor.as_symbol().expect("is_symbol() was true");
                monomial.push(name.to_string());
            } else if factor.is_mul() {
                pending.extend(factor.as_mul().expect("is_mul() was true"));
            } else if factor.is_pow() {
                let (base, exp) = factor.as_pow().expect("is_pow() was true");
                if !(base.is_symbol() && exp.is_number()) {
                    return Err(CompileError::InvalidExpression(format!(
                        "Unsupported power term in product: {factor}"
                    )));
                }
                let exp_val = exp.to_f64().unwrap_or(0.0);
                let positive_integer =
                    exp_val.is_sign_positive() && exp_val.fract() == 0.0 && exp_val >= 1.0;
                if !positive_integer {
                    return Err(CompileError::InvalidExpression(format!(
                        "Unsupported power in product: {factor}"
                    )));
                }
                // Binary variable: x^k = x for k >= 1
                let name = base.as_symbol().expect("is_symbol() was true");
                monomial.push(name.to_string());
            } else {
                return Err(CompileError::InvalidExpression(format!(
                    "Unsupported term in product: {factor}"
                )));
            }
        }

        // Canonical (sorted) variable order so equal monomials share a key.
        monomial.sort();

        // A product of numbers only is a constant.
        if monomial.is_empty() {
            return Ok((HashMap::new(), weight));
        }
        return Ok((HashMap::from([(monomial, weight)]), 0.0));
    }

    // Powers should already have been linearised by the caller.
    if term.is_pow() {
        return Err(CompileError::InvalidExpression(format!(
            "Unexpected power term after simplification: {term}"
        )));
    }

    Err(CompileError::InvalidExpression(format!(
        "Unsupported term: {term}"
    )))
}
1393
1394// Helper function to build the QUBO matrix
1395#[allow(dead_code)]
1396fn build_qubo_matrix(
1397    coeffs: &HashMap<Vec<String>, f64>,
1398) -> CompileResult<(
1399    Array<f64, scirs2_core::ndarray::Ix2>,
1400    HashMap<String, usize>,
1401)> {
1402    // Collect all unique variable names
1403    let mut all_vars = HashSet::new();
1404    for vars in coeffs.keys() {
1405        for var in vars {
1406            all_vars.insert(var.clone());
1407        }
1408    }
1409
1410    // Convert to a sorted vector
1411    let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
1412    sorted_vars.sort();
1413
1414    // Create the variable-to-index mapping
1415    let var_map: HashMap<String, usize> = sorted_vars
1416        .iter()
1417        .enumerate()
1418        .map(|(i, var)| (var.clone(), i))
1419        .collect();
1420
1421    // Size of the matrix
1422    let n = var_map.len();
1423
1424    // Create an empty matrix
1425    let mut matrix = Array::zeros((n, n));
1426
1427    // Fill the matrix with coefficients
1428    for (vars, &coeff) in coeffs {
1429        match vars.len() {
1430            0 => {
1431                // Should never happen since constants are handled in offset
1432            }
1433            1 => {
1434                // Linear term: var * coeff
1435                // SAFETY: var_map was built from the same variables in coeffs
1436                let i = *var_map
1437                    .get(&vars[0])
1438                    .expect("variable exists in var_map built from coeffs");
1439                matrix[[i, i]] += coeff;
1440            }
1441            2 => {
1442                // Quadratic term: var1 * var2 * coeff
1443                // SAFETY: var_map was built from the same variables in coeffs
1444                let i = *var_map
1445                    .get(&vars[0])
1446                    .expect("variable exists in var_map built from coeffs");
1447                let j = *var_map
1448                    .get(&vars[1])
1449                    .expect("variable exists in var_map built from coeffs");
1450
1451                // QUBO format requires i <= j
1452                if i == j {
1453                    // Diagonal term
1454                    matrix[[i, i]] += coeff;
1455                } else {
1456                    // Off-diagonal term - store full coefficient in upper triangular, zero in lower
1457                    if i <= j {
1458                        matrix[[i, j]] += coeff;
1459                    } else {
1460                        matrix[[j, i]] += coeff;
1461                    }
1462                }
1463            }
1464            _ => {
1465                // Higher-order terms are not supported in QUBO
1466                return Err(CompileError::DegreeTooHigh(vars.len(), 2));
1467            }
1468        }
1469    }
1470
1471    Ok((matrix, var_map))
1472}
1473
1474// Helper function to build the HOBO tensor
1475#[allow(dead_code)]
1476fn build_hobo_tensor(
1477    coeffs: &HashMap<Vec<String>, f64>,
1478    max_degree: usize,
1479) -> CompileResult<(
1480    Array<f64, scirs2_core::ndarray::IxDyn>,
1481    HashMap<String, usize>,
1482)> {
1483    // Collect all unique variable names
1484    let mut all_vars = HashSet::new();
1485    for vars in coeffs.keys() {
1486        for var in vars {
1487            all_vars.insert(var.clone());
1488        }
1489    }
1490
1491    // Convert to a sorted vector
1492    let mut sorted_vars: Vec<String> = all_vars.into_iter().collect();
1493    sorted_vars.sort();
1494
1495    // Create the variable-to-index mapping
1496    let var_map: HashMap<String, usize> = sorted_vars
1497        .iter()
1498        .enumerate()
1499        .map(|(i, var)| (var.clone(), i))
1500        .collect();
1501
1502    // Size of each dimension
1503    let n = var_map.len();
1504
1505    // Create shape vector for the tensor
1506    let shape: Vec<usize> = vec![n; max_degree];
1507
1508    // Create an empty tensor
1509    let mut tensor = Array::zeros(scirs2_core::ndarray::IxDyn(&shape));
1510
1511    // Fill the tensor with coefficients
1512    for (vars, &coeff) in coeffs {
1513        let degree = vars.len();
1514
1515        if degree == 0 {
1516            // Should never happen since constants are handled in offset
1517            continue;
1518        }
1519
1520        if degree > max_degree {
1521            return Err(CompileError::DegreeTooHigh(degree, max_degree));
1522        }
1523
1524        // Convert variable names to indices
1525        // SAFETY: var_map was built from the same variables in coeffs
1526        let mut indices: Vec<usize> = vars
1527            .iter()
1528            .map(|var| {
1529                *var_map
1530                    .get(var)
1531                    .expect("variable exists in var_map built from coeffs")
1532            })
1533            .collect();
1534
1535        // Sort indices (canonical ordering)
1536        indices.sort_unstable();
1537
1538        // Pad indices to match tensor order if necessary
1539        while indices.len() < max_degree {
1540            indices.insert(0, indices[0]); // Padding with first index
1541        }
1542
1543        // Set the coefficient in the tensor
1544        let idx = scirs2_core::ndarray::IxDyn(&indices);
1545        tensor[idx] += coeff;
1546    }
1547
1548    Ok((tensor, var_map))
1549}
1550
/// Special compiler for problems with one-hot constraints
///
/// This is intended as a specialized compiler for problems with one-hot
/// constraints, which are common in many optimization problems.
///
/// NOTE(review): in the `impl` that follows, `get_qubo` simply delegates to
/// `Compile::get_qubo`, so no one-hot-specific optimisation is applied yet.
#[cfg(feature = "dwave")]
pub struct PieckCompile {
    /// The symbolic expression to compile
    expr: Expr,
    /// Whether to show verbose output
    ///
    /// NOTE(review): stored by `new` but not read by the methods of the
    /// `impl` that follows; confirm whether any other code consults it.
    verbose: bool,
}
1562
#[cfg(feature = "dwave")]
impl PieckCompile {
    /// Builds a Pieck compiler around `expr`.
    ///
    /// `expr` is converted into `Expr` via `Into`. `verbose` is stored as given.
    pub fn new<T: Into<Expr>>(expr: T, verbose: bool) -> Self {
        let expr: Expr = expr.into();
        Self { expr, verbose }
    }

    /// Compiles the wrapped expression to a QUBO model.
    ///
    /// Returns `((matrix, variable_to_index), offset)`, exactly as produced by
    /// `Compile::get_qubo`. Errors are whatever that call reports.
    pub fn get_qubo(
        &self,
    ) -> CompileResult<(
        (
            Array<f64, scirs2_core::ndarray::Ix2>,
            HashMap<String, usize>,
        ),
        f64,
    )> {
        // One-hot-specific techniques are not implemented yet; fall back to the
        // general-purpose compiler on a copy of the expression.
        let fallback = Compile::new(self.expr.clone());
        fallback.get_qubo()
    }
}