use std::collections::{HashMap, HashSet};
use std::fmt;
use thiserror::Error;
use crate::ising::{IsingError, IsingModel, QuboModel};
#[derive(Error, Debug)]
pub enum DslError {
#[error("Variable not found: {0}")]
VariableNotFound(String),
#[error("Invalid constraint: {0}")]
InvalidConstraint(String),
#[error("Invalid objective: {0}")]
InvalidObjective(String),
#[error("Compilation error: {0}")]
CompilationError(String),
#[error("Type mismatch: {0}")]
TypeMismatch(String),
#[error("Ising error: {0}")]
IsingError(#[from] IsingError),
#[error("Dimension mismatch: expected {expected}, got {actual}")]
DimensionMismatch { expected: usize, actual: usize },
#[error("Invalid range: {0}")]
InvalidRange(String),
}
pub type DslResult<T> = Result<T, DslError>;
#[derive(Debug, Clone, PartialEq)]
pub enum VariableType {
Binary,
Integer { min: i32, max: i32 },
Spin,
Categorical { categories: Vec<String> },
Continuous { min: f64, max: f64, steps: usize },
}
#[derive(Debug, Clone)]
pub struct Variable {
pub id: String,
pub var_type: VariableType,
pub qubit_indices: Vec<usize>,
pub description: Option<String>,
}
#[derive(Debug, Clone)]
pub struct VariableVector {
pub variables: Vec<Variable>,
pub name: String,
}
#[derive(Debug, Clone)]
pub enum Expression {
Constant(f64),
Variable(Variable),
Sum(Vec<Self>),
Product(Box<Self>, Box<Self>),
LinearCombination { weights: Vec<f64>, terms: Vec<Self> },
Quadratic {
var1: Variable,
var2: Variable,
coefficient: f64,
},
Power(Box<Self>, i32),
Negate(Box<Self>),
Abs(Box<Self>),
Conditional {
condition: Box<BooleanExpression>,
if_true: Box<Self>,
if_false: Box<Self>,
},
}
#[derive(Debug, Clone)]
pub enum BooleanExpression {
True,
False,
Equal(Expression, Expression),
LessThan(Expression, Expression),
LessThanOrEqual(Expression, Expression),
GreaterThan(Expression, Expression),
GreaterThanOrEqual(Expression, Expression),
And(Box<Self>, Box<Self>),
Or(Box<Self>, Box<Self>),
Not(Box<Self>),
Xor(Box<Self>, Box<Self>),
Implies(Box<Self>, Box<Self>),
AllDifferent(Vec<Variable>),
AtMostOne(Vec<Variable>),
ExactlyOne(Vec<Variable>),
}
#[derive(Debug, Clone)]
pub struct Constraint {
pub expression: BooleanExpression,
pub name: Option<String>,
pub penalty_weight: Option<f64>,
pub is_hard: bool,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ObjectiveDirection {
Minimize,
Maximize,
}
#[derive(Debug, Clone)]
pub struct Objective {
pub expression: Expression,
pub direction: ObjectiveDirection,
pub name: Option<String>,
}
pub struct OptimizationModel {
pub name: String,
variables: HashMap<String, Variable>,
variable_vectors: HashMap<String, VariableVector>,
constraints: Vec<Constraint>,
objectives: Vec<Objective>,
next_qubit: usize,
metadata: HashMap<String, String>,
}
impl OptimizationModel {
pub fn new(name: impl Into<String>) -> Self {
Self {
name: name.into(),
variables: HashMap::new(),
variable_vectors: HashMap::new(),
constraints: Vec::new(),
objectives: Vec::new(),
next_qubit: 0,
metadata: HashMap::new(),
}
}
pub fn add_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
self.metadata.insert(key.into(), value.into());
}
pub fn add_binary(&mut self, name: impl Into<String>) -> DslResult<Variable> {
let var_name = name.into();
if self.variables.contains_key(&var_name) {
return Err(DslError::InvalidConstraint(format!(
"Variable {var_name} already exists"
)));
}
let var = Variable {
id: var_name.clone(),
var_type: VariableType::Binary,
qubit_indices: vec![self.next_qubit],
description: None,
};
self.next_qubit += 1;
self.variables.insert(var_name, var.clone());
Ok(var)
}
pub fn add_binary_vector(
&mut self,
name: impl Into<String>,
size: usize,
) -> DslResult<VariableVector> {
let vec_name = name.into();
let mut variables = Vec::new();
for i in 0..size {
let var_name = format!("{vec_name}[{i}]");
let var = self.add_binary(var_name)?;
variables.push(var);
}
let var_vec = VariableVector {
variables,
name: vec_name.clone(),
};
self.variable_vectors.insert(vec_name, var_vec.clone());
Ok(var_vec)
}
pub fn add_integer(
&mut self,
name: impl Into<String>,
min: i32,
max: i32,
) -> DslResult<Variable> {
let var_name = name.into();
if self.variables.contains_key(&var_name) {
return Err(DslError::InvalidConstraint(format!(
"Variable {var_name} already exists"
)));
}
if min > max {
return Err(DslError::InvalidRange(format!(
"Invalid range [{min}, {max}]"
)));
}
let range = (max - min) as u32;
let num_bits = (range + 1).next_power_of_two().trailing_zeros() as usize;
let qubit_indices: Vec<usize> = (0..num_bits)
.map(|_| {
let idx = self.next_qubit;
self.next_qubit += 1;
idx
})
.collect();
let var = Variable {
id: var_name.clone(),
var_type: VariableType::Integer { min, max },
qubit_indices,
description: None,
};
self.variables.insert(var_name, var.clone());
Ok(var)
}
pub fn add_spin(&mut self, name: impl Into<String>) -> DslResult<Variable> {
let var_name = name.into();
if self.variables.contains_key(&var_name) {
return Err(DslError::InvalidConstraint(format!(
"Variable {var_name} already exists"
)));
}
let var = Variable {
id: var_name.clone(),
var_type: VariableType::Spin,
qubit_indices: vec![self.next_qubit],
description: None,
};
self.next_qubit += 1;
self.variables.insert(var_name, var.clone());
Ok(var)
}
pub fn add_constraint(&mut self, constraint: impl Into<Constraint>) -> DslResult<()> {
let constraint = constraint.into();
self.constraints.push(constraint);
Ok(())
}
pub fn minimize(&mut self, expression: impl Into<Expression>) -> DslResult<()> {
let objective = Objective {
expression: expression.into(),
direction: ObjectiveDirection::Minimize,
name: None,
};
self.objectives.push(objective);
Ok(())
}
pub fn maximize(&mut self, expression: impl Into<Expression>) -> DslResult<()> {
let objective = Objective {
expression: expression.into(),
direction: ObjectiveDirection::Maximize,
name: None,
};
self.objectives.push(objective);
Ok(())
}
pub fn compile_to_qubo(&self) -> DslResult<QuboModel> {
let mut model = QuboModel::new(self.next_qubit);
for objective in &self.objectives {
let sign = match objective.direction {
ObjectiveDirection::Minimize => 1.0,
ObjectiveDirection::Maximize => -1.0,
};
self.add_expression_to_qubo(&mut model, &objective.expression, sign)?;
}
for constraint in &self.constraints {
let penalty_weight = if constraint.is_hard {
1000.0 } else {
constraint.penalty_weight.unwrap_or(1.0)
};
self.add_constraint_to_qubo(&mut model, &constraint.expression, penalty_weight)?;
}
Ok(model)
}
pub fn compile_to_ising(&self) -> DslResult<IsingModel> {
let qubo = self.compile_to_qubo()?;
let mut ising = IsingModel::new(self.next_qubit);
let mut offset = 0.0;
for i in 0..self.next_qubit {
let q_val = qubo.get_linear(i)?;
if q_val.abs() > 1e-10 {
let h_i = q_val / 2.0; let current_bias = ising.get_bias(i)?;
ising.set_bias(i, current_bias + h_i)?;
offset += q_val / 4.0; }
}
for (i, j, q_val) in qubo.quadratic_terms() {
if q_val.abs() > 1e-10 {
let j_ij = q_val / 4.0; ising.set_coupling(i, j, j_ij)?;
let bias_i = ising.get_bias(i)?;
ising.set_bias(i, bias_i + q_val / 4.0)?;
let bias_j = ising.get_bias(j)?;
ising.set_bias(j, bias_j + q_val / 4.0)?;
offset += q_val / 4.0;
}
}
Ok(ising)
}
fn add_expression_to_qubo(
&self,
model: &mut QuboModel,
expr: &Expression,
coefficient: f64,
) -> DslResult<()> {
match expr {
Expression::Constant(_c) => {
Ok(())
}
Expression::Variable(var) => {
if let Some(&qubit_idx) = var.qubit_indices.first() {
model.add_linear(qubit_idx, coefficient)?;
}
Ok(())
}
Expression::Sum(terms) => {
for term in terms {
self.add_expression_to_qubo(model, term, coefficient)?;
}
Ok(())
}
Expression::Product(e1, e2) => {
if let (Expression::Variable(v1), Expression::Variable(v2)) =
(e1.as_ref(), e2.as_ref())
{
if let (Some(&q1), Some(&q2)) =
(v1.qubit_indices.first(), v2.qubit_indices.first())
{
if q1 == q2 {
model.add_linear(q1, coefficient)?;
} else {
let current = model.get_quadratic(q1, q2)?;
model.set_quadratic(q1, q2, current + coefficient)?;
}
}
}
Ok(())
}
Expression::Quadratic {
var1,
var2,
coefficient: coef,
} => {
if let (Some(&q1), Some(&q2)) =
(var1.qubit_indices.first(), var2.qubit_indices.first())
{
if q1 == q2 {
model.add_linear(q1, coefficient * coef)?;
} else {
let current = model.get_quadratic(q1, q2)?;
model.set_quadratic(q1, q2, current + coefficient * coef)?;
}
}
Ok(())
}
Expression::LinearCombination { weights, terms } => {
for (weight, term) in weights.iter().zip(terms.iter()) {
self.add_expression_to_qubo(model, term, coefficient * weight)?;
}
Ok(())
}
Expression::Negate(e) => {
self.add_expression_to_qubo(model, e, -coefficient)?;
Ok(())
}
_ => {
Err(DslError::CompilationError(
"Unsupported expression type in QUBO compilation".to_string(),
))
}
}
}
fn add_constraint_to_qubo(
&self,
model: &mut QuboModel,
constraint: &BooleanExpression,
penalty: f64,
) -> DslResult<()> {
match constraint {
BooleanExpression::True => Ok(()),
BooleanExpression::False => {
Err(DslError::InvalidConstraint(
"Unsatisfiable constraint (False)".to_string(),
))
}
BooleanExpression::Equal(e1, e2) => {
let diff =
Expression::Sum(vec![e1.clone(), Expression::Negate(Box::new(e2.clone()))]);
let penalty_expr = Expression::Product(Box::new(diff.clone()), Box::new(diff));
self.add_expression_to_qubo(model, &penalty_expr, penalty)
}
BooleanExpression::ExactlyOne(vars) => {
let sum_expr = Expression::Sum(
vars.iter()
.map(|v| Expression::Variable(v.clone()))
.collect(),
);
let one = Expression::Constant(1.0);
let diff = Expression::Sum(vec![sum_expr, Expression::Negate(Box::new(one))]);
let penalty_expr = Expression::Product(Box::new(diff.clone()), Box::new(diff));
self.add_expression_to_qubo(model, &penalty_expr, penalty)
}
BooleanExpression::AtMostOne(vars) => {
for i in 0..vars.len() {
for j in (i + 1)..vars.len() {
if let (Some(&q1), Some(&q2)) =
(vars[i].qubit_indices.first(), vars[j].qubit_indices.first())
{
let current = model.get_quadratic(q1, q2)?;
model.set_quadratic(q1, q2, current + penalty)?;
}
}
}
Ok(())
}
BooleanExpression::And(b1, b2) => {
self.add_constraint_to_qubo(model, b1, penalty)?;
self.add_constraint_to_qubo(model, b2, penalty)?;
Ok(())
}
_ => {
Err(DslError::CompilationError(
"Unsupported constraint type in QUBO compilation".to_string(),
))
}
}
}
#[must_use]
pub fn summary(&self) -> ModelSummary {
ModelSummary {
name: self.name.clone(),
num_variables: self.variables.len(),
num_qubits: self.next_qubit,
num_constraints: self.constraints.len(),
num_objectives: self.objectives.len(),
variable_types: self.count_variable_types(),
}
}
fn count_variable_types(&self) -> HashMap<String, usize> {
let mut counts = HashMap::new();
for var in self.variables.values() {
let type_name = match &var.var_type {
VariableType::Binary => "binary",
VariableType::Integer { .. } => "integer",
VariableType::Spin => "spin",
VariableType::Categorical { .. } => "categorical",
VariableType::Continuous { .. } => "continuous",
};
*counts.entry(type_name.to_string()).or_insert(0) += 1;
}
counts
}
}
#[derive(Debug)]
pub struct ModelSummary {
pub name: String,
pub num_variables: usize,
pub num_qubits: usize,
pub num_constraints: usize,
pub num_objectives: usize,
pub variable_types: HashMap<String, usize>,
}
impl fmt::Display for ModelSummary {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
writeln!(f, "Model: {}", self.name)?;
writeln!(f, " Variables: {}", self.num_variables)?;
writeln!(f, " Qubits: {}", self.num_qubits)?;
writeln!(f, " Constraints: {}", self.num_constraints)?;
writeln!(f, " Objectives: {}", self.num_objectives)?;
writeln!(f, " Variable types:")?;
for (var_type, count) in &self.variable_types {
writeln!(f, " {var_type}: {count}")?;
}
Ok(())
}
}
impl Expression {
#[must_use]
pub const fn constant(value: f64) -> Self {
Self::Constant(value)
}
#[must_use]
pub const fn sum(terms: Vec<Self>) -> Self {
Self::Sum(terms)
}
#[must_use]
pub fn add(self, other: Self) -> Self {
match (self, other) {
(Self::Sum(mut terms), Self::Sum(other_terms)) => {
terms.extend(other_terms);
Self::Sum(terms)
}
(Self::Sum(mut terms), other) => {
terms.push(other);
Self::Sum(terms)
}
(expr, Self::Sum(mut terms)) => {
terms.insert(0, expr);
Self::Sum(terms)
}
(expr1, expr2) => Self::Sum(vec![expr1, expr2]),
}
}
#[must_use]
pub fn scale(self, factor: f64) -> Self {
match self {
Self::Constant(value) => Self::Constant(value * factor),
Self::LinearCombination { weights, terms } => Self::LinearCombination {
weights: weights.into_iter().map(|w| w * factor).collect(),
terms,
},
expr => Self::LinearCombination {
weights: vec![factor],
terms: vec![expr],
},
}
}
#[must_use]
pub fn negate(self) -> Self {
Self::Negate(Box::new(self))
}
}
impl VariableVector {
#[must_use]
pub fn sum(&self) -> Expression {
Expression::Sum(
self.variables
.iter()
.map(|v| Expression::Variable(v.clone()))
.collect(),
)
}
#[must_use]
pub fn weighted_sum(&self, weights: &[f64]) -> Expression {
if weights.len() != self.variables.len() {
return Expression::Constant(0.0);
}
Expression::LinearCombination {
weights: weights.to_vec(),
terms: self
.variables
.iter()
.map(|v| Expression::Variable(v.clone()))
.collect(),
}
}
#[must_use]
pub fn get(&self, index: usize) -> Option<&Variable> {
self.variables.get(index)
}
#[must_use]
pub fn len(&self) -> usize {
self.variables.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.variables.is_empty()
}
}
impl Expression {
pub fn equals(self, other: impl Into<Self>) -> Constraint {
Constraint {
expression: BooleanExpression::Equal(self, other.into()),
name: None,
penalty_weight: None,
is_hard: true,
}
}
pub fn less_than(self, other: impl Into<Self>) -> Constraint {
Constraint {
expression: BooleanExpression::LessThan(self, other.into()),
name: None,
penalty_weight: None,
is_hard: true,
}
}
pub fn less_than_or_equal(self, other: impl Into<Self>) -> Constraint {
Constraint {
expression: BooleanExpression::LessThanOrEqual(self, other.into()),
name: None,
penalty_weight: None,
is_hard: true,
}
}
pub fn greater_than(self, other: impl Into<Self>) -> Constraint {
Constraint {
expression: BooleanExpression::GreaterThan(self, other.into()),
name: None,
penalty_weight: None,
is_hard: true,
}
}
pub fn greater_than_or_equal(self, other: impl Into<Self>) -> Constraint {
Constraint {
expression: BooleanExpression::GreaterThanOrEqual(self, other.into()),
name: None,
penalty_weight: None,
is_hard: true,
}
}
}
impl From<f64> for Expression {
fn from(value: f64) -> Self {
Self::Constant(value)
}
}
impl From<i32> for Expression {
fn from(value: i32) -> Self {
Self::Constant(f64::from(value))
}
}
impl From<Variable> for Expression {
fn from(var: Variable) -> Self {
Self::Variable(var)
}
}
pub mod patterns {
use super::{
BooleanExpression, Constraint, DslError, DslResult, Expression, OptimizationModel, Variable,
};
pub fn knapsack(
items: &[String],
values: &[f64],
weights: &[f64],
capacity: f64,
) -> DslResult<OptimizationModel> {
let n = items.len();
if values.len() != n || weights.len() != n {
return Err(DslError::DimensionMismatch {
expected: n,
actual: values.len().min(weights.len()),
});
}
let mut model = OptimizationModel::new("Knapsack Problem");
let selection = model.add_binary_vector("select", n)?;
model.add_constraint(selection.weighted_sum(weights).less_than_or_equal(capacity))?;
model.maximize(selection.weighted_sum(values))?;
Ok(model)
}
pub fn graph_coloring(
vertices: &[String],
edges: &[(usize, usize)],
num_colors: usize,
) -> DslResult<OptimizationModel> {
let n = vertices.len();
let mut model = OptimizationModel::new("Graph Coloring");
let mut x = Vec::new();
for v in 0..n {
let colors = model.add_binary_vector(format!("vertex_{v}_color"), num_colors)?;
x.push(colors);
}
for v in 0..n {
let vertex_vars: Vec<Variable> = (0..num_colors)
.filter_map(|c| x[v].get(c).cloned())
.collect();
model.add_constraint(Constraint {
expression: BooleanExpression::ExactlyOne(vertex_vars),
name: Some(format!("vertex_{v}_one_color")),
penalty_weight: None,
is_hard: true,
})?;
}
for &(u, v) in edges {
for c in 0..num_colors {
if let (Some(var_u), Some(var_v)) = (x[u].get(c), x[v].get(c)) {
model.add_constraint(Constraint {
expression: BooleanExpression::AtMostOne(vec![
var_u.clone(),
var_v.clone(),
]),
name: Some(format!("edge_{u}_{v}_color_{c}")),
penalty_weight: None,
is_hard: true,
})?;
}
}
}
let mut color_used = Vec::new();
for c in 0..num_colors {
let color_var = model.add_binary(format!("color_{c}_used"))?;
color_used.push(color_var.clone());
for v in 0..n {
if let Some(var_vc) = x[v].get(c) {
model.add_constraint(
Expression::Variable(var_vc.clone())
.less_than_or_equal(Expression::Variable(color_var.clone())),
)?;
}
}
}
model.minimize(Expression::Sum(
color_used.into_iter().map(Expression::Variable).collect(),
))?;
Ok(model)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_binary_variable_creation() {
let mut model = OptimizationModel::new("Test Model");
let var = model
.add_binary("x")
.expect("Failed to add binary variable");
assert_eq!(var.id, "x");
assert_eq!(var.qubit_indices.len(), 1);
assert!(matches!(var.var_type, VariableType::Binary));
}
#[test]
fn test_binary_vector_creation() {
let mut model = OptimizationModel::new("Test Model");
let vec = model
.add_binary_vector("x", 5)
.expect("Failed to add binary vector");
assert_eq!(vec.name, "x");
assert_eq!(vec.len(), 5);
assert_eq!(vec.variables[0].id, "x[0]");
assert_eq!(vec.variables[4].id, "x[4]");
}
#[test]
fn test_integer_variable_creation() {
let mut model = OptimizationModel::new("Test Model");
let var = model
.add_integer("i", 0, 7)
.expect("Failed to add integer variable");
assert_eq!(var.id, "i");
assert_eq!(var.qubit_indices.len(), 3); assert!(matches!(
var.var_type,
VariableType::Integer { min: 0, max: 7 }
));
}
#[test]
fn test_expression_building() {
let expr1 = Expression::constant(5.0);
let expr2 = Expression::constant(3.0);
let sum = expr1.add(expr2);
assert!(matches!(sum, Expression::Sum(_)));
let scaled = Expression::constant(2.0).scale(3.0);
if let Expression::Constant(value) = scaled {
assert_eq!(value, 6.0);
} else {
panic!("Expected constant expression");
}
}
#[test]
fn test_knapsack_pattern() {
let items = vec![
"Item1".to_string(),
"Item2".to_string(),
"Item3".to_string(),
];
let values = vec![10.0, 20.0, 15.0];
let weights = vec![5.0, 10.0, 7.0];
let capacity = 15.0;
let model = patterns::knapsack(&items, &values, &weights, capacity)
.expect("Failed to create knapsack model");
assert_eq!(model.name, "Knapsack Problem");
assert_eq!(model.summary().num_variables, 3);
assert_eq!(model.summary().num_constraints, 1);
assert_eq!(model.summary().num_objectives, 1);
}
}