quantrs2_tytan/optimization/constraints.rs

1//! Constraint handling for quantum annealing
2//!
3//! This module provides comprehensive constraint management including
4//! automatic penalty term generation and constraint analysis.
5
6// Optimization penalty types
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// Constraint types
11#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12pub enum ConstraintType {
13    /// Equality constraint: expr = target
14    Equality { target: f64 },
15    /// Inequality constraint: expr <= bound
16    LessThanOrEqual { bound: f64 },
17    /// Inequality constraint: expr >= bound
18    GreaterThanOrEqual { bound: f64 },
19    /// Range constraint: lower <= expr <= upper
20    Range { lower: f64, upper: f64 },
21    /// One-hot constraint: exactly one variable true
22    OneHot,
23    /// Cardinality constraint: exactly k variables true
24    Cardinality { k: usize },
25    /// Integer encoding constraint
26    IntegerEncoding { min: i32, max: i32 },
27}
28
29/// Constraint definition
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct Constraint {
32    pub name: String,
33    pub constraint_type: ConstraintType,
34    pub expression: Expression,
35    pub variables: Vec<String>,
36    pub penalty_weight: Option<f64>,
37    pub slack_variables: Vec<String>,
38}
39
40/// Constraint handler for automatic penalty generation
41pub struct ConstraintHandler {
42    constraints: Vec<Constraint>,
43    slack_variable_counter: usize,
44    encoding_cache: HashMap<String, EncodingInfo>,
45}
46
47/// Encoding information for integer variables
48#[derive(Debug, Clone, Serialize, Deserialize)]
49struct EncodingInfo {
50    pub variable_name: String,
51    pub bit_variables: Vec<String>,
52    pub min_value: i32,
53    pub max_value: i32,
54    pub encoding_type: EncodingType,
55}
56
57/// Integer encoding types
58#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
59pub enum EncodingType {
60    Binary,
61    Unary,
62    OneHot,
63    Gray,
64}
65
66impl Default for ConstraintHandler {
67    fn default() -> Self {
68        Self::new()
69    }
70}
71
72impl ConstraintHandler {
73    /// Create new constraint handler
74    pub fn new() -> Self {
75        Self {
76            constraints: Vec::new(),
77            slack_variable_counter: 0,
78            encoding_cache: HashMap::new(),
79        }
80    }
81
82    /// Add constraint
83    pub fn add_constraint(&mut self, constraint: Constraint) {
84        self.constraints.push(constraint);
85    }
86
87    /// Add equality constraint
88    pub fn add_equality(
89        &mut self,
90        name: String,
91        expression: Expression,
92        target: f64,
93    ) -> Result<(), Box<dyn std::error::Error>> {
94        let variables = expression.get_variables();
95
96        self.add_constraint(Constraint {
97            name,
98            constraint_type: ConstraintType::Equality { target },
99            expression,
100            variables,
101            penalty_weight: None,
102            slack_variables: Vec::new(),
103        });
104
105        Ok(())
106    }
107
108    /// Add inequality constraint
109    pub fn add_inequality(
110        &mut self,
111        name: String,
112        expression: Expression,
113        bound: f64,
114        less_than: bool,
115    ) -> Result<(), Box<dyn std::error::Error>> {
116        let variables = expression.get_variables();
117        let mut constraint = Constraint {
118            name: name.clone(),
119            constraint_type: if less_than {
120                ConstraintType::LessThanOrEqual { bound }
121            } else {
122                ConstraintType::GreaterThanOrEqual { bound }
123            },
124            expression,
125            variables,
126            penalty_weight: None,
127            slack_variables: Vec::new(),
128        };
129
130        // Add slack variables for inequality constraints
131        if less_than {
132            // expr + slack = bound, slack >= 0
133            let slack_var = self.create_slack_variable(&name);
134            constraint.slack_variables.push(slack_var);
135        } else {
136            // expr - slack = bound, slack >= 0
137            let slack_var = self.create_slack_variable(&name);
138            constraint.slack_variables.push(slack_var);
139        }
140
141        self.add_constraint(constraint);
142        Ok(())
143    }
144
145    /// Add one-hot constraint
146    pub fn add_one_hot(
147        &mut self,
148        name: String,
149        variables: Vec<String>,
150    ) -> Result<(), Box<dyn std::error::Error>> {
151        // Create expression: (sum_i x_i - 1)^2
152        let mut expr = Expression::zero();
153        for var in &variables {
154            expr = expr + Variable::new(var.clone()).into();
155        }
156        expr = expr - 1.0.into();
157
158        self.add_constraint(Constraint {
159            name,
160            constraint_type: ConstraintType::OneHot,
161            expression: expr,
162            variables,
163            penalty_weight: None,
164            slack_variables: Vec::new(),
165        });
166
167        Ok(())
168    }
169
170    /// Add cardinality constraint
171    pub fn add_cardinality(
172        &mut self,
173        name: String,
174        variables: Vec<String>,
175        k: usize,
176    ) -> Result<(), Box<dyn std::error::Error>> {
177        // Create expression: (sum_i x_i - k)^2
178        let mut expr = Expression::zero();
179        for var in &variables {
180            expr = expr + Variable::new(var.clone()).into();
181        }
182        expr = expr - (k as f64).into();
183
184        self.add_constraint(Constraint {
185            name,
186            constraint_type: ConstraintType::Cardinality { k },
187            expression: expr,
188            variables,
189            penalty_weight: None,
190            slack_variables: Vec::new(),
191        });
192
193        Ok(())
194    }
195
196    /// Add integer encoding constraint
197    pub fn add_integer_encoding(
198        &mut self,
199        name: String,
200        base_name: String,
201        min: i32,
202        max: i32,
203        encoding_type: EncodingType,
204    ) -> Result<Vec<String>, Box<dyn std::error::Error>> {
205        let num_bits = ((max - min + 1) as f64).log2().ceil() as usize;
206        let mut bit_variables = Vec::new();
207
208        // Create bit variables
209        for i in 0..num_bits {
210            bit_variables.push(format!("{base_name}_{i}"));
211        }
212
213        // Store encoding info
214        self.encoding_cache.insert(
215            base_name.clone(),
216            EncodingInfo {
217                variable_name: base_name,
218                bit_variables: bit_variables.clone(),
219                min_value: min,
220                max_value: max,
221                encoding_type,
222            },
223        );
224
225        // Add encoding-specific constraints
226        match encoding_type {
227            EncodingType::Binary => {
228                // No additional constraints for binary encoding
229            }
230            EncodingType::Unary => {
231                // Unary: if x_i = 1, then x_{i-1} = 1
232                for i in 1..bit_variables.len() {
233                    let expr: Expression = Variable::new(bit_variables[i].clone()).into();
234                    let prev_expr: Expression = Variable::new(bit_variables[i - 1].clone()).into();
235                    let constraint_expr = expr - prev_expr;
236
237                    self.add_inequality(format!("{name}_unary_{i}"), constraint_expr, 0.0, true)?;
238                }
239            }
240            EncodingType::OneHot => {
241                // Exactly one bit active
242                self.add_one_hot(format!("{name}_onehot"), bit_variables.clone())?;
243            }
244            EncodingType::Gray => {
245                // Gray code constraints are implicit in the mapping
246            }
247        }
248
249        self.add_constraint(Constraint {
250            name,
251            constraint_type: ConstraintType::IntegerEncoding { min, max },
252            expression: Expression::zero(), // Placeholder
253            variables: bit_variables.clone(),
254            penalty_weight: None,
255            slack_variables: Vec::new(),
256        });
257
258        Ok(bit_variables)
259    }
260
261    /// Generate penalty terms for all constraints
262    pub fn generate_penalty_terms(
263        &self,
264        penalty_weights: &HashMap<String, f64>,
265    ) -> Result<Expression, Box<dyn std::error::Error>> {
266        let mut total_penalty = Expression::zero();
267
268        for constraint in &self.constraints {
269            let weight = penalty_weights
270                .get(&constraint.name)
271                .or(constraint.penalty_weight.as_ref())
272                .copied()
273                .unwrap_or(1.0);
274
275            let penalty_expr = match &constraint.constraint_type {
276                ConstraintType::Equality { target } => {
277                    // (expr - target)^2
278                    let diff = constraint.expression.clone() - (*target).into();
279                    diff.clone() * diff
280                }
281                ConstraintType::LessThanOrEqual { bound } => {
282                    // expr + slack = bound => (expr + slack - bound)^2
283                    if let Some(slack_var) = constraint.slack_variables.first() {
284                        let expr_with_slack =
285                            constraint.expression.clone() + Variable::new(slack_var.clone()).into();
286                        let diff = expr_with_slack - (*bound).into();
287                        diff.clone() * diff
288                    } else {
289                        // Penalty for violation: max(0, expr - bound)^2
290                        self.generate_inequality_penalty(&constraint.expression, *bound, true)?
291                    }
292                }
293                ConstraintType::GreaterThanOrEqual { bound } => {
294                    // expr - slack = bound => (expr - slack - bound)^2
295                    if let Some(slack_var) = constraint.slack_variables.first() {
296                        let expr_with_slack =
297                            constraint.expression.clone() - Variable::new(slack_var.clone()).into();
298                        let diff = expr_with_slack - (*bound).into();
299                        diff.clone() * diff
300                    } else {
301                        // Penalty for violation: max(0, bound - expr)^2
302                        self.generate_inequality_penalty(&constraint.expression, *bound, false)?
303                    }
304                }
305                ConstraintType::Range { lower, upper } => {
306                    // Combine two inequality penalties
307                    let lower_penalty =
308                        self.generate_inequality_penalty(&constraint.expression, *lower, false)?;
309                    let upper_penalty =
310                        self.generate_inequality_penalty(&constraint.expression, *upper, true)?;
311                    lower_penalty + upper_penalty
312                }
313                ConstraintType::OneHot => {
314                    // (sum_i x_i - 1)^2
315                    let expr = constraint.expression.clone();
316                    expr.clone() * expr
317                }
318                ConstraintType::Cardinality { k: _ } => {
319                    // (sum_i x_i - k)^2
320                    let expr = constraint.expression.clone();
321                    expr.clone() * expr
322                }
323                ConstraintType::IntegerEncoding { .. } => {
324                    // Encoding constraints are handled separately
325                    Expression::zero()
326                }
327            };
328
329            total_penalty = total_penalty + weight * penalty_expr;
330        }
331
332        Ok(total_penalty)
333    }
334
335    /// Generate inequality penalty using auxiliary binary expansion
336    fn generate_inequality_penalty(
337        &self,
338        _expression: &Expression,
339        _bound: f64,
340        less_than: bool,
341    ) -> Result<Expression, Box<dyn std::error::Error>> {
342        // For now, return a quadratic penalty
343        // In a full implementation, this would use binary expansion
344        // to exactly encode the inequality
345
346        if less_than {
347            // max(0, expr - bound)^2
348            Ok(Expression::zero()) // Placeholder
349        } else {
350            // max(0, bound - expr)^2
351            Ok(Expression::zero()) // Placeholder
352        }
353    }
354
355    /// Create slack variable
356    fn create_slack_variable(&mut self, constraint_name: &str) -> String {
357        let var_name = format!("_slack_{}_{}", constraint_name, self.slack_variable_counter);
358        self.slack_variable_counter += 1;
359        var_name
360    }
361
362    /// Get all variables including slack
363    pub fn get_all_variables(&self) -> Vec<String> {
364        let mut variables = Vec::new();
365
366        for constraint in &self.constraints {
367            variables.extend(constraint.variables.clone());
368            variables.extend(constraint.slack_variables.clone());
369        }
370
371        // Include integer encoding bit variables
372        for encoding in self.encoding_cache.values() {
373            variables.extend(encoding.bit_variables.clone());
374        }
375
376        // Remove duplicates
377        variables.sort();
378        variables.dedup();
379
380        variables
381    }
382
383    /// Decode integer value from bit assignment
384    pub fn decode_integer(
385        &self,
386        variable_name: &str,
387        assignment: &HashMap<String, bool>,
388    ) -> Option<i32> {
389        let encoding = self.encoding_cache.get(variable_name)?;
390
391        match encoding.encoding_type {
392            EncodingType::Binary => {
393                let mut value = 0;
394                for (i, bit_var) in encoding.bit_variables.iter().enumerate() {
395                    if *assignment.get(bit_var).unwrap_or(&false) {
396                        value += 1 << i;
397                    }
398                }
399                Some(encoding.min_value + value)
400            }
401            EncodingType::Unary => {
402                let mut count = 0;
403                for bit_var in &encoding.bit_variables {
404                    if *assignment.get(bit_var).unwrap_or(&false) {
405                        count += 1;
406                    } else {
407                        break;
408                    }
409                }
410                Some(encoding.min_value + count)
411            }
412            EncodingType::OneHot => {
413                for (i, bit_var) in encoding.bit_variables.iter().enumerate() {
414                    if *assignment.get(bit_var).unwrap_or(&false) {
415                        return Some(encoding.min_value + i as i32);
416                    }
417                }
418                None
419            }
420            EncodingType::Gray => {
421                // Convert Gray code to binary
422                let mut gray_value = 0;
423                for (i, bit_var) in encoding.bit_variables.iter().enumerate() {
424                    if *assignment.get(bit_var).unwrap_or(&false) {
425                        gray_value |= 1 << i;
426                    }
427                }
428
429                // Gray to binary conversion
430                let mut binary_value = gray_value;
431                binary_value ^= binary_value >> 16;
432                binary_value ^= binary_value >> 8;
433                binary_value ^= binary_value >> 4;
434                binary_value ^= binary_value >> 2;
435                binary_value ^= binary_value >> 1;
436
437                Some(encoding.min_value + binary_value)
438            }
439        }
440    }
441
442    /// Analyze constraint structure
443    pub fn analyze_constraints(&self) -> ConstraintAnalysis {
444        let total_constraints = self.constraints.len();
445        let total_variables = self.get_all_variables().len();
446
447        let mut type_counts = HashMap::new();
448        let mut avg_variables_per_constraint = 0.0;
449        let mut max_variables_in_constraint = 0;
450
451        for constraint in &self.constraints {
452            let type_name = match constraint.constraint_type {
453                ConstraintType::Equality { .. } => "equality",
454                ConstraintType::LessThanOrEqual { .. } => "less_than",
455                ConstraintType::GreaterThanOrEqual { .. } => "greater_than",
456                ConstraintType::Range { .. } => "range",
457                ConstraintType::OneHot => "one_hot",
458                ConstraintType::Cardinality { .. } => "cardinality",
459                ConstraintType::IntegerEncoding { .. } => "integer",
460            };
461
462            *type_counts.entry(type_name.to_string()).or_insert(0) += 1;
463
464            let var_count = constraint.variables.len();
465            avg_variables_per_constraint += var_count as f64;
466            max_variables_in_constraint = max_variables_in_constraint.max(var_count);
467        }
468
469        if total_constraints > 0 {
470            avg_variables_per_constraint /= total_constraints as f64;
471        }
472
473        ConstraintAnalysis {
474            total_constraints,
475            total_variables,
476            slack_variables: self.slack_variable_counter,
477            constraint_types: type_counts,
478            avg_variables_per_constraint,
479            max_variables_in_constraint,
480            encoding_info: self.encoding_cache.len(),
481        }
482    }
483}
484
485/// Constraint analysis results
486#[derive(Debug, Clone, Serialize, Deserialize)]
487pub struct ConstraintAnalysis {
488    pub total_constraints: usize,
489    pub total_variables: usize,
490    pub slack_variables: usize,
491    pub constraint_types: HashMap<String, usize>,
492    pub avg_variables_per_constraint: f64,
493    pub max_variables_in_constraint: usize,
494    pub encoding_info: usize,
495}
496
497// Helper trait implementations for Expression
498trait ExpressionExt {
499    fn zero() -> Self;
500    fn get_variables(&self) -> Vec<String>;
501}
502
503impl ExpressionExt for Expression {
504    fn zero() -> Self {
505        // Placeholder implementation
506        Self::Constant(0.0)
507    }
508
509    fn get_variables(&self) -> Vec<String> {
510        // Placeholder implementation
511        Vec::new()
512    }
513}
514
/// A named decision variable; convert it into an [`Expression`] leaf with `.into()`.
#[derive(Debug, Clone)]
pub struct Variable {
    name: String,
}

impl Variable {
    /// Wrap `name` as a variable. No validation is performed.
    pub const fn new(name: String) -> Self {
        Variable { name }
    }
}
526
527/// Expression type placeholder
528#[derive(Debug, Clone, Serialize, Deserialize)]
529pub enum Expression {
530    Constant(f64),
531    Variable(String),
532    Add(Box<Self>, Box<Self>),
533    Multiply(Box<Self>, Box<Self>),
534}
535
536impl From<f64> for Expression {
537    fn from(value: f64) -> Self {
538        Self::Constant(value)
539    }
540}
541
542impl From<Variable> for Expression {
543    fn from(var: Variable) -> Self {
544        Self::Variable(var.name)
545    }
546}
547
548impl std::ops::Add for Expression {
549    type Output = Self;
550
551    fn add(self, rhs: Self) -> Self::Output {
552        Self::Add(Box::new(self), Box::new(rhs))
553    }
554}
555
556impl std::ops::Sub for Expression {
557    type Output = Self;
558
559    fn sub(self, rhs: Self) -> Self::Output {
560        Self::Add(
561            Box::new(self),
562            Box::new(Self::Multiply(
563                Box::new(Self::Constant(-1.0)),
564                Box::new(rhs),
565            )),
566        )
567    }
568}
569
570impl std::ops::Mul for Expression {
571    type Output = Self;
572
573    fn mul(self, rhs: Self) -> Self::Output {
574        Self::Multiply(Box::new(self), Box::new(rhs))
575    }
576}
577
578impl std::ops::Mul<Expression> for f64 {
579    type Output = Expression;
580
581    fn mul(self, rhs: Expression) -> Self::Output {
582        Expression::Multiply(Box::new(Expression::Constant(self)), Box::new(rhs))
583    }
584}