use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
9
10#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
12pub enum ConstraintType {
13 Equality { target: f64 },
15 LessThanOrEqual { bound: f64 },
17 GreaterThanOrEqual { bound: f64 },
19 Range { lower: f64, upper: f64 },
21 OneHot,
23 Cardinality { k: usize },
25 IntegerEncoding { min: i32, max: i32 },
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct Constraint {
32 pub name: String,
33 pub constraint_type: ConstraintType,
34 pub expression: Expression,
35 pub variables: Vec<String>,
36 pub penalty_weight: Option<f64>,
37 pub slack_variables: Vec<String>,
38}
39
40pub struct ConstraintHandler {
42 constraints: Vec<Constraint>,
43 slack_variable_counter: usize,
44 encoding_cache: HashMap<String, EncodingInfo>,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
49struct EncodingInfo {
50 pub variable_name: String,
51 pub bit_variables: Vec<String>,
52 pub min_value: i32,
53 pub max_value: i32,
54 pub encoding_type: EncodingType,
55}
56
57#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
59pub enum EncodingType {
60 Binary,
61 Unary,
62 OneHot,
63 Gray,
64}
65
66impl Default for ConstraintHandler {
67 fn default() -> Self {
68 Self::new()
69 }
70}
71
72impl ConstraintHandler {
73 pub fn new() -> Self {
75 Self {
76 constraints: Vec::new(),
77 slack_variable_counter: 0,
78 encoding_cache: HashMap::new(),
79 }
80 }
81
82 pub fn add_constraint(&mut self, constraint: Constraint) {
84 self.constraints.push(constraint);
85 }
86
87 pub fn add_equality(
89 &mut self,
90 name: String,
91 expression: Expression,
92 target: f64,
93 ) -> Result<(), Box<dyn std::error::Error>> {
94 let variables = expression.get_variables();
95
96 self.add_constraint(Constraint {
97 name,
98 constraint_type: ConstraintType::Equality { target },
99 expression,
100 variables,
101 penalty_weight: None,
102 slack_variables: Vec::new(),
103 });
104
105 Ok(())
106 }
107
108 pub fn add_inequality(
110 &mut self,
111 name: String,
112 expression: Expression,
113 bound: f64,
114 less_than: bool,
115 ) -> Result<(), Box<dyn std::error::Error>> {
116 let variables = expression.get_variables();
117 let mut constraint = Constraint {
118 name: name.clone(),
119 constraint_type: if less_than {
120 ConstraintType::LessThanOrEqual { bound }
121 } else {
122 ConstraintType::GreaterThanOrEqual { bound }
123 },
124 expression,
125 variables,
126 penalty_weight: None,
127 slack_variables: Vec::new(),
128 };
129
130 if less_than {
132 let slack_var = self.create_slack_variable(&name);
134 constraint.slack_variables.push(slack_var);
135 } else {
136 let slack_var = self.create_slack_variable(&name);
138 constraint.slack_variables.push(slack_var);
139 }
140
141 self.add_constraint(constraint);
142 Ok(())
143 }
144
145 pub fn add_one_hot(
147 &mut self,
148 name: String,
149 variables: Vec<String>,
150 ) -> Result<(), Box<dyn std::error::Error>> {
151 let mut expr = Expression::zero();
153 for var in &variables {
154 expr = expr + Variable::new(var.clone()).into();
155 }
156 expr = expr - 1.0.into();
157
158 self.add_constraint(Constraint {
159 name,
160 constraint_type: ConstraintType::OneHot,
161 expression: expr,
162 variables,
163 penalty_weight: None,
164 slack_variables: Vec::new(),
165 });
166
167 Ok(())
168 }
169
170 pub fn add_cardinality(
172 &mut self,
173 name: String,
174 variables: Vec<String>,
175 k: usize,
176 ) -> Result<(), Box<dyn std::error::Error>> {
177 let mut expr = Expression::zero();
179 for var in &variables {
180 expr = expr + Variable::new(var.clone()).into();
181 }
182 expr = expr - (k as f64).into();
183
184 self.add_constraint(Constraint {
185 name,
186 constraint_type: ConstraintType::Cardinality { k },
187 expression: expr,
188 variables,
189 penalty_weight: None,
190 slack_variables: Vec::new(),
191 });
192
193 Ok(())
194 }
195
196 pub fn add_integer_encoding(
198 &mut self,
199 name: String,
200 base_name: String,
201 min: i32,
202 max: i32,
203 encoding_type: EncodingType,
204 ) -> Result<Vec<String>, Box<dyn std::error::Error>> {
205 let num_bits = ((max - min + 1) as f64).log2().ceil() as usize;
206 let mut bit_variables = Vec::new();
207
208 for i in 0..num_bits {
210 bit_variables.push(format!("{base_name}_{i}"));
211 }
212
213 self.encoding_cache.insert(
215 base_name.clone(),
216 EncodingInfo {
217 variable_name: base_name,
218 bit_variables: bit_variables.clone(),
219 min_value: min,
220 max_value: max,
221 encoding_type,
222 },
223 );
224
225 match encoding_type {
227 EncodingType::Binary => {
228 }
230 EncodingType::Unary => {
231 for i in 1..bit_variables.len() {
233 let expr: Expression = Variable::new(bit_variables[i].clone()).into();
234 let prev_expr: Expression = Variable::new(bit_variables[i - 1].clone()).into();
235 let constraint_expr = expr - prev_expr;
236
237 self.add_inequality(format!("{name}_unary_{i}"), constraint_expr, 0.0, true)?;
238 }
239 }
240 EncodingType::OneHot => {
241 self.add_one_hot(format!("{name}_onehot"), bit_variables.clone())?;
243 }
244 EncodingType::Gray => {
245 }
247 }
248
249 self.add_constraint(Constraint {
250 name,
251 constraint_type: ConstraintType::IntegerEncoding { min, max },
252 expression: Expression::zero(), variables: bit_variables.clone(),
254 penalty_weight: None,
255 slack_variables: Vec::new(),
256 });
257
258 Ok(bit_variables)
259 }
260
261 pub fn generate_penalty_terms(
263 &self,
264 penalty_weights: &HashMap<String, f64>,
265 ) -> Result<Expression, Box<dyn std::error::Error>> {
266 let mut total_penalty = Expression::zero();
267
268 for constraint in &self.constraints {
269 let weight = penalty_weights
270 .get(&constraint.name)
271 .or(constraint.penalty_weight.as_ref())
272 .copied()
273 .unwrap_or(1.0);
274
275 let penalty_expr = match &constraint.constraint_type {
276 ConstraintType::Equality { target } => {
277 let diff = constraint.expression.clone() - (*target).into();
279 diff.clone() * diff
280 }
281 ConstraintType::LessThanOrEqual { bound } => {
282 if let Some(slack_var) = constraint.slack_variables.first() {
284 let expr_with_slack =
285 constraint.expression.clone() + Variable::new(slack_var.clone()).into();
286 let diff = expr_with_slack - (*bound).into();
287 diff.clone() * diff
288 } else {
289 self.generate_inequality_penalty(&constraint.expression, *bound, true)?
291 }
292 }
293 ConstraintType::GreaterThanOrEqual { bound } => {
294 if let Some(slack_var) = constraint.slack_variables.first() {
296 let expr_with_slack =
297 constraint.expression.clone() - Variable::new(slack_var.clone()).into();
298 let diff = expr_with_slack - (*bound).into();
299 diff.clone() * diff
300 } else {
301 self.generate_inequality_penalty(&constraint.expression, *bound, false)?
303 }
304 }
305 ConstraintType::Range { lower, upper } => {
306 let lower_penalty =
308 self.generate_inequality_penalty(&constraint.expression, *lower, false)?;
309 let upper_penalty =
310 self.generate_inequality_penalty(&constraint.expression, *upper, true)?;
311 lower_penalty + upper_penalty
312 }
313 ConstraintType::OneHot => {
314 let expr = constraint.expression.clone();
316 expr.clone() * expr
317 }
318 ConstraintType::Cardinality { k: _ } => {
319 let expr = constraint.expression.clone();
321 expr.clone() * expr
322 }
323 ConstraintType::IntegerEncoding { .. } => {
324 Expression::zero()
326 }
327 };
328
329 total_penalty = total_penalty + weight * penalty_expr;
330 }
331
332 Ok(total_penalty)
333 }
334
335 fn generate_inequality_penalty(
337 &self,
338 _expression: &Expression,
339 _bound: f64,
340 less_than: bool,
341 ) -> Result<Expression, Box<dyn std::error::Error>> {
342 if less_than {
347 Ok(Expression::zero()) } else {
350 Ok(Expression::zero()) }
353 }
354
355 fn create_slack_variable(&mut self, constraint_name: &str) -> String {
357 let var_name = format!("_slack_{}_{}", constraint_name, self.slack_variable_counter);
358 self.slack_variable_counter += 1;
359 var_name
360 }
361
362 pub fn get_all_variables(&self) -> Vec<String> {
364 let mut variables = Vec::new();
365
366 for constraint in &self.constraints {
367 variables.extend(constraint.variables.clone());
368 variables.extend(constraint.slack_variables.clone());
369 }
370
371 for encoding in self.encoding_cache.values() {
373 variables.extend(encoding.bit_variables.clone());
374 }
375
376 variables.sort();
378 variables.dedup();
379
380 variables
381 }
382
383 pub fn decode_integer(
385 &self,
386 variable_name: &str,
387 assignment: &HashMap<String, bool>,
388 ) -> Option<i32> {
389 let encoding = self.encoding_cache.get(variable_name)?;
390
391 match encoding.encoding_type {
392 EncodingType::Binary => {
393 let mut value = 0;
394 for (i, bit_var) in encoding.bit_variables.iter().enumerate() {
395 if *assignment.get(bit_var).unwrap_or(&false) {
396 value += 1 << i;
397 }
398 }
399 Some(encoding.min_value + value)
400 }
401 EncodingType::Unary => {
402 let mut count = 0;
403 for bit_var in &encoding.bit_variables {
404 if *assignment.get(bit_var).unwrap_or(&false) {
405 count += 1;
406 } else {
407 break;
408 }
409 }
410 Some(encoding.min_value + count)
411 }
412 EncodingType::OneHot => {
413 for (i, bit_var) in encoding.bit_variables.iter().enumerate() {
414 if *assignment.get(bit_var).unwrap_or(&false) {
415 return Some(encoding.min_value + i as i32);
416 }
417 }
418 None
419 }
420 EncodingType::Gray => {
421 let mut gray_value = 0;
423 for (i, bit_var) in encoding.bit_variables.iter().enumerate() {
424 if *assignment.get(bit_var).unwrap_or(&false) {
425 gray_value |= 1 << i;
426 }
427 }
428
429 let mut binary_value = gray_value;
431 binary_value ^= binary_value >> 16;
432 binary_value ^= binary_value >> 8;
433 binary_value ^= binary_value >> 4;
434 binary_value ^= binary_value >> 2;
435 binary_value ^= binary_value >> 1;
436
437 Some(encoding.min_value + binary_value)
438 }
439 }
440 }
441
442 pub fn analyze_constraints(&self) -> ConstraintAnalysis {
444 let total_constraints = self.constraints.len();
445 let total_variables = self.get_all_variables().len();
446
447 let mut type_counts = HashMap::new();
448 let mut avg_variables_per_constraint = 0.0;
449 let mut max_variables_in_constraint = 0;
450
451 for constraint in &self.constraints {
452 let type_name = match constraint.constraint_type {
453 ConstraintType::Equality { .. } => "equality",
454 ConstraintType::LessThanOrEqual { .. } => "less_than",
455 ConstraintType::GreaterThanOrEqual { .. } => "greater_than",
456 ConstraintType::Range { .. } => "range",
457 ConstraintType::OneHot => "one_hot",
458 ConstraintType::Cardinality { .. } => "cardinality",
459 ConstraintType::IntegerEncoding { .. } => "integer",
460 };
461
462 *type_counts.entry(type_name.to_string()).or_insert(0) += 1;
463
464 let var_count = constraint.variables.len();
465 avg_variables_per_constraint += var_count as f64;
466 max_variables_in_constraint = max_variables_in_constraint.max(var_count);
467 }
468
469 if total_constraints > 0 {
470 avg_variables_per_constraint /= total_constraints as f64;
471 }
472
473 ConstraintAnalysis {
474 total_constraints,
475 total_variables,
476 slack_variables: self.slack_variable_counter,
477 constraint_types: type_counts,
478 avg_variables_per_constraint,
479 max_variables_in_constraint,
480 encoding_info: self.encoding_cache.len(),
481 }
482 }
483}
484
485#[derive(Debug, Clone, Serialize, Deserialize)]
487pub struct ConstraintAnalysis {
488 pub total_constraints: usize,
489 pub total_variables: usize,
490 pub slack_variables: usize,
491 pub constraint_types: HashMap<String, usize>,
492 pub avg_variables_per_constraint: f64,
493 pub max_variables_in_constraint: usize,
494 pub encoding_info: usize,
495}
496
/// Helper operations on expression types that the constraint builder relies on.
trait ExpressionExt {
    /// Names of the variables referenced by the expression.
    fn get_variables(&self) -> Vec<String>;

    /// The additive identity expression.
    fn zero() -> Self;
}
502
503impl ExpressionExt for Expression {
504 fn zero() -> Self {
505 Self::Constant(0.0)
507 }
508
509 fn get_variables(&self) -> Vec<String> {
510 Vec::new()
512 }
513}
514
/// A named decision variable; converts into an `Expression::Variable` leaf.
#[derive(Clone)]
#[derive(Debug)]
pub struct Variable {
    /// The variable's identifier.
    name: String,
}
520
521impl Variable {
522 pub const fn new(name: String) -> Self {
523 Self { name }
524 }
525}
526
527#[derive(Debug, Clone, Serialize, Deserialize)]
529pub enum Expression {
530 Constant(f64),
531 Variable(String),
532 Add(Box<Self>, Box<Self>),
533 Multiply(Box<Self>, Box<Self>),
534}
535
536impl From<f64> for Expression {
537 fn from(value: f64) -> Self {
538 Self::Constant(value)
539 }
540}
541
542impl From<Variable> for Expression {
543 fn from(var: Variable) -> Self {
544 Self::Variable(var.name)
545 }
546}
547
548impl std::ops::Add for Expression {
549 type Output = Self;
550
551 fn add(self, rhs: Self) -> Self::Output {
552 Self::Add(Box::new(self), Box::new(rhs))
553 }
554}
555
556impl std::ops::Sub for Expression {
557 type Output = Self;
558
559 fn sub(self, rhs: Self) -> Self::Output {
560 Self::Add(
561 Box::new(self),
562 Box::new(Self::Multiply(
563 Box::new(Self::Constant(-1.0)),
564 Box::new(rhs),
565 )),
566 )
567 }
568}
569
570impl std::ops::Mul for Expression {
571 type Output = Self;
572
573 fn mul(self, rhs: Self) -> Self::Output {
574 Self::Multiply(Box::new(self), Box::new(rhs))
575 }
576}
577
578impl std::ops::Mul<Expression> for f64 {
579 type Output = Expression;
580
581 fn mul(self, rhs: Expression) -> Self::Output {
582 Expression::Multiply(Box::new(Expression::Constant(self)), Box::new(rhs))
583 }
584}