use scirs2_core::ndarray::Array2;
use std::collections::{HashMap, HashSet};
#[cfg(feature = "dwave")]
use crate::symbol::Expression;
/// A global constraint over named CSP variables.
///
/// NOTE(review): within this file only `AllDifferent` is constructed (by
/// `ConstraintLibrary`), and no mapping from this enum to a propagator is
/// visible; the per-variant notes below follow the conventional meaning of
/// each constraint name and should be confirmed against the consumers.
#[derive(Debug, Clone)]
pub enum GlobalConstraint {
    /// The listed variables must take pairwise distinct values.
    AllDifferent { variables: Vec<String> },
    /// Tasks (see [`Task`]) sharing a single resource of size `capacity`.
    Cumulative { tasks: Vec<Task>, capacity: i32 },
    /// Occurrence bounds per value.
    ///
    /// NOTE(review): `values`, `min_occurrences` and `max_occurrences` are
    /// presumably parallel vectors; their lengths are not checked here.
    GlobalCardinality {
        variables: Vec<String>,
        values: Vec<i32>,
        min_occurrences: Vec<i32>,
        max_occurrences: Vec<i32>,
    },
    /// Presumably: the sequence of values of `variables` must be accepted by
    /// `automaton`.
    Regular {
        variables: Vec<String>,
        automaton: FiniteAutomaton,
    },
    /// Presumably: `value_var == array[index_var]`.
    Element {
        index_var: String,
        array: Vec<i32>,
        value_var: String,
    },
    /// Extensional constraint over `variables`.
    ///
    /// NOTE(review): `positive` presumably selects allowed (`true`) versus
    /// forbidden (`false`) `tuples` — confirm.
    Table {
        variables: Vec<String>,
        tuples: Vec<Vec<i32>>,
        positive: bool,
    },
    /// Presumably: the variables (successor encoding) form a single circuit.
    Circuit { variables: Vec<String> },
    /// Items (see [`Item`]) packed into bins (see [`Bin`]).
    BinPacking { items: Vec<Item>, bins: Vec<Bin> },
}
/// A task for the cumulative (shared-resource) constraint.
#[derive(Debug, Clone)]
pub struct Task {
    /// Name of the variable holding the task's start time.
    pub start_var: String,
    /// Number of time units the task runs; it occupies `[start, start + duration)`.
    pub duration: i32,
    /// Resource units consumed while the task runs.
    pub resource_usage: i32,
}
/// An item for the bin-packing constraint.
#[derive(Debug, Clone)]
pub struct Item {
    /// Size of the item (presumably in the same units as `Bin::capacity`).
    pub size: i32,
    /// Name of the variable selecting this item's bin
    /// (NOTE(review): presumably an index into the `bins` list — confirm).
    pub bin_var: String,
}
/// A bin for the bin-packing constraint.
#[derive(Debug, Clone)]
pub struct Bin {
    /// Capacity of the bin (presumably the maximum total `Item::size` it can hold).
    pub capacity: i32,
}
/// Deterministic finite automaton used by `GlobalConstraint::Regular`.
///
/// Determinism follows from `transitions` mapping each key to exactly one
/// successor state.
#[derive(Debug, Clone)]
pub struct FiniteAutomaton {
    /// All states of the automaton.
    pub states: Vec<i32>,
    /// State the automaton starts in.
    pub initial_state: i32,
    /// Accepting states.
    pub final_states: HashSet<i32>,
    /// Transition table.
    ///
    /// NOTE(review): keys are presumably `(state, symbol)` with the value being
    /// the next state — confirm the key order with the consumer.
    pub transitions: HashMap<(i32, i32), i32>,
}
/// A soft constraint: presumably allowed to be violated at the cost given by
/// `penalty`.
///
/// NOTE(review): in this file only `constraint` is read (by
/// `constraints_to_penalties`); `penalty` and `priority` are not used here.
#[derive(Debug, Clone)]
pub struct SoftConstraint {
    /// The condition being softened.
    pub constraint: ConstraintExpression,
    /// How a violation is turned into a cost.
    pub penalty: PenaltyFunction,
    /// Relative priority (ordering convention not defined in this file).
    pub priority: i32,
}
/// The condition expressed by a [`SoftConstraint`].
#[derive(Debug, Clone)]
pub enum ConstraintExpression {
    /// Linear form `Σ coeff · var` compared against `bound`
    /// (presumably `≤ bound`, per the name — confirm).
    LinearInequality {
        /// `(variable name, coefficient)` terms.
        coefficients: Vec<(String, f64)>,
        /// Right-hand side of the inequality.
        bound: f64,
    },
    /// A Boolean formula over named variables.
    Logical(LogicalExpression),
    /// An arbitrary symbolic expression; only available with the `dwave` feature.
    #[cfg(feature = "dwave")]
    Custom(Expression),
}
/// Boolean formula tree over named variables.
///
/// NOTE(review): no evaluator exists in this file; the operator meanings below
/// are the conventional ones implied by the variant names.
#[derive(Debug, Clone)]
pub enum LogicalExpression {
    /// Reference to the named variable.
    Var(String),
    /// Negation.
    Not(Box<Self>),
    /// Conjunction of all operands.
    And(Vec<Self>),
    /// Disjunction of all operands.
    Or(Vec<Self>),
    /// Material implication `lhs → rhs`.
    Implies(Box<Self>, Box<Self>),
    /// Equivalence `lhs ↔ rhs`.
    Iff(Box<Self>, Box<Self>),
}
/// Shape of the cost charged for violating a soft constraint.
///
/// NOTE(review): these functions are not evaluated anywhere in this file; the
/// descriptions are inferred from the variant names and should be confirmed.
#[derive(Debug, Clone)]
pub enum PenaltyFunction {
    /// Presumably `weight * violation`.
    Linear { weight: f64 },
    /// Presumably `weight * violation²`.
    Quadratic { weight: f64 },
    /// Presumably grows exponentially with the violation, scaled by `weight`.
    Exponential { weight: f64 },
    /// Presumably a constant `weight` whenever the constraint is violated.
    Step { weight: f64 },
    /// Presumably interpolates between `(violation, penalty)` breakpoints.
    PiecewiseLinear { points: Vec<(f64, f64)> },
}
/// A constraint that can reason over variable domains.
pub trait ConstraintPropagator {
    /// Propagates the constraint over `domains`, possibly narrowing them in place.
    ///
    /// The implementations in this file return `Ok(true)` when some domain was
    /// narrowed, `Ok(false)` otherwise, and `Err(message)` when infeasibility is
    /// detected (for example an emptied domain).
    fn propagate(&mut self, domains: &mut HashMap<String, Domain>) -> Result<bool, String>;
    /// Whether `assignment` (variable name → value) satisfies the constraint.
    ///
    /// NOTE(review): implementations differ on unassigned variables —
    /// `AllDifferentPropagator` ignores them, `CumulativePropagator` returns `false`.
    fn is_satisfied(&self, assignment: &HashMap<String, i32>) -> bool;
    /// Names of the variables this constraint involves.
    fn variables(&self) -> Vec<String>;
}
/// Finite integer domain of a variable.
///
/// `min` and `max` cache the bounds of `values`. `Domain::remove` and
/// `Domain::intersect` refresh them, but they keep their previous values when
/// the set becomes empty. Writing `values` directly does not update them.
#[derive(Debug, Clone)]
pub struct Domain {
    /// Values still possible for the variable.
    pub values: HashSet<i32>,
    /// Cached smallest value (stale once `values` is empty).
    pub min: i32,
    /// Cached largest value (stale once `values` is empty).
    pub max: i32,
}
impl Domain {
    /// Builds a domain holding every integer in the inclusive range `min..=max`.
    ///
    /// If `min > max` the value set is empty, while `min`/`max` still store the
    /// arguments as given.
    pub fn new(min: i32, max: i32) -> Self {
        Self {
            values: (min..=max).collect(),
            min,
            max,
        }
    }

    /// Builds a domain from an explicit list of values; duplicates collapse.
    ///
    /// An empty list yields an empty domain with `min == max == 0`.
    pub fn from_values(values: Vec<i32>) -> Self {
        let (min, max) = match (values.iter().min(), values.iter().max()) {
            (Some(&lowest), Some(&highest)) => (lowest, highest),
            _ => (0, 0),
        };
        Self {
            values: values.into_iter().collect(),
            min,
            max,
        }
    }

    /// Removes `value`, returning whether it was present.
    ///
    /// The cached bounds are refreshed only when something was removed.
    pub fn remove(&mut self, value: i32) -> bool {
        if !self.values.remove(&value) {
            return false;
        }
        self.update_bounds();
        true
    }

    /// Keeps only the values that also appear in `values`, then refreshes the bounds.
    pub fn intersect(&mut self, values: &HashSet<i32>) {
        self.values.retain(|candidate| values.contains(candidate));
        self.update_bounds();
    }

    /// Recomputes `min`/`max` from `values` in a single pass.
    ///
    /// Leaves both bounds untouched when the domain is empty.
    fn update_bounds(&mut self) {
        let mut remaining = self.values.iter().copied();
        if let Some(first) = remaining.next() {
            let (lowest, highest) = remaining.fold((first, first), |(lo, hi), v| (lo.min(v), hi.max(v)));
            self.min = lowest;
            self.max = highest;
        }
    }

    /// `true` when no value is left.
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }

    /// Number of values still in the domain.
    pub fn size(&self) -> usize {
        self.values.len()
    }
}
/// Propagator for the all-different constraint: every listed variable must take
/// a distinct value.
pub struct AllDifferentPropagator {
    /// Constrained variables, in the order passed to `new`.
    variables: Vec<String>,
}
impl AllDifferentPropagator {
    /// Creates a propagator over `variables`; the list is stored unchanged.
    pub const fn new(variables: Vec<String>) -> Self {
        AllDifferentPropagator { variables }
    }
}
impl ConstraintPropagator for AllDifferentPropagator {
    /// Prunes the domains of the constrained variables in place.
    ///
    /// 1. Every variable whose domain has exactly one value is treated as fixed.
    ///    Two variables fixed to the same value are reported as a conflict right
    ///    away, so the clash is caught even when the Hall check below is skipped.
    /// 2. Each fixed value is removed from every domain holding more than one value.
    /// 3. With at most 10 variables, `hall_propagation` runs as well.
    ///
    /// Variables without an entry in `domains` are skipped. A single round is
    /// performed: a domain that becomes a singleton here is only used for pruning
    /// on the next call.
    ///
    /// Returns `Ok(true)` if any domain was narrowed.
    ///
    /// # Errors
    /// Returns `Err` if two variables are fixed to the same value, if a domain
    /// becomes empty, or if `hall_propagation` reports infeasibility.
    fn propagate(&mut self, domains: &mut HashMap<String, Domain>) -> Result<bool, String> {
        // The Hall check enumerates 2^n subsets, so it is only run for small n.
        const HALL_CHECK_MAX_VARS: usize = 10;

        // Fixed value -> variable holding it; a second holder is a conflict.
        let mut fixed_by: HashMap<i32, &String> = HashMap::new();
        for var in &self.variables {
            let Some(domain) = domains.get(var) else {
                continue;
            };
            if domain.size() != 1 {
                continue;
            }
            if let Some(&value) = domain.values.iter().next() {
                if let Some(other) = fixed_by.insert(value, var) {
                    return Err(format!(
                        "Variables {other} and {var} are both fixed to {value}"
                    ));
                }
            }
        }

        let mut changed = false;
        for var in &self.variables {
            let Some(domain) = domains.get_mut(var) else {
                continue;
            };
            if domain.size() <= 1 {
                continue;
            }
            for &value in fixed_by.keys() {
                // `|=` rather than `||` so every removal is attempted.
                changed |= domain.remove(value);
            }
            if domain.is_empty() {
                return Err(format!("Domain of {var} became empty"));
            }
        }

        if self.variables.len() <= HALL_CHECK_MAX_VARS {
            changed |= self.hall_propagation(domains)?;
        }
        Ok(changed)
    }

    /// Returns `true` when no two *assigned* variables share a value.
    ///
    /// Variables missing from `assignment` are ignored, so a clash-free partial
    /// assignment is accepted. A variable listed twice always clashes with itself.
    fn is_satisfied(&self, assignment: &HashMap<String, i32>) -> bool {
        let mut seen = HashSet::new();
        self.variables
            .iter()
            .filter_map(|var| assignment.get(var))
            .all(|&value| seen.insert(value))
    }

    /// Returns a copy of the constrained variable names.
    fn variables(&self) -> Vec<String> {
        self.variables.clone()
    }
}
impl AllDifferentPropagator {
    /// Exhaustive Hall-set reasoning over every non-empty subset of the variables
    /// that have an entry in `domains`.
    ///
    /// For a subset `S` whose domains' union is `U`:
    /// * `|S| > |U|`: pigeonhole, so the constraint is unsatisfiable.
    /// * `|S| == |U|`: `S` uses up all of `U`, so those values are removed from
    ///   every variable outside `S`.
    ///
    /// Variables with no entry in `domains` are ignored, as in `propagate`. They
    /// must not be treated as empty domains, which would report infeasibility.
    ///
    /// Cost is `O(2^n · n)` for `n` participating variables. `propagate` only
    /// calls this for at most 10 variables, and `n` must stay below 64 because
    /// subsets are `u64` bit masks.
    ///
    /// Returns `Ok(true)` if any domain was narrowed.
    ///
    /// # Errors
    /// Returns `Err` when Hall's condition is violated or a domain becomes empty.
    fn hall_propagation(&self, domains: &mut HashMap<String, Domain>) -> Result<bool, String> {
        // Only variables that actually have a domain take part in the search.
        let present: Vec<&String> = self
            .variables
            .iter()
            .filter(|var| domains.contains_key(*var))
            .collect();
        let n = present.len();
        let mut changed = false;
        for subset_bits in 1u64..(1u64 << n) {
            let mut subset_vars: Vec<&String> = Vec::new();
            let mut union_values: HashSet<i32> = HashSet::new();
            for (i, &var) in present.iter().enumerate() {
                if (subset_bits >> i) & 1 == 1 {
                    subset_vars.push(var);
                    if let Some(domain) = domains.get(var) {
                        union_values.extend(&domain.values);
                    }
                }
            }
            if subset_vars.len() > union_values.len() {
                return Err("Unsatisfiable: Hall's condition violated".to_string());
            }
            if subset_vars.len() == union_values.len() {
                for var in &self.variables {
                    // Membership is by name, so repeated names count as inside `S`.
                    if subset_vars.contains(&var) {
                        continue;
                    }
                    let Some(domain) = domains.get_mut(var) else {
                        continue;
                    };
                    let old_size = domain.size();
                    domain.values.retain(|v| !union_values.contains(v));
                    domain.update_bounds();
                    if domain.size() < old_size {
                        changed = true;
                    }
                    if domain.is_empty() {
                        return Err(format!("Domain of {var} became empty"));
                    }
                }
            }
        }
        Ok(changed)
    }
}
/// Propagator for the cumulative scheduling constraint.
///
/// Each task occupies `resource_usage` units of a shared resource during
/// `[start, start + duration)`; total usage must never exceed `capacity`.
pub struct CumulativePropagator {
    tasks: Vec<Task>,
    capacity: i32,
}
impl CumulativePropagator {
    /// Creates a propagator for `tasks` sharing a resource of size `capacity`.
    pub const fn new(tasks: Vec<Task>, capacity: i32) -> Self {
        Self { tasks, capacity }
    }

    /// Time-tabling consistency check based on compulsory parts.
    ///
    /// A task whose start lies in `[est, lst]` (its domain's `min`/`max`) is
    /// certain to run during `[lst, est + duration)` whenever that interval is
    /// non-empty. The summed usage of these compulsory parts must not exceed
    /// `capacity` at any time.
    ///
    /// Only the instants where the usage profile changes are visited (a sweep
    /// over interval endpoints), so the cost is `O(k log k)` in the number of
    /// tasks and independent of the horizon width. Times are computed in `i64`,
    /// so `max + duration` cannot overflow. Tasks without a domain are skipped.
    ///
    /// # Errors
    /// Returns `Err("Resource overload at time {t}")` for the earliest time `t`
    /// at which the compulsory usage exceeds `capacity`.
    fn time_tabling(&self, domains: &HashMap<String, Domain>) -> Result<(), String> {
        // (time, usage delta) at the start and end of each compulsory part.
        let mut events: Vec<(i64, i64)> = Vec::new();
        for task in &self.tasks {
            let Some(domain) = domains.get(&task.start_var) else {
                continue;
            };
            let latest_start = i64::from(domain.max);
            let earliest_end = i64::from(domain.min) + i64::from(task.duration);
            if latest_start < earliest_end {
                let usage = i64::from(task.resource_usage);
                events.push((latest_start, usage));
                events.push((earliest_end, -usage));
            }
        }
        events.sort_unstable_by_key(|&(time, _)| time);

        let capacity = i64::from(self.capacity);
        let mut usage: i64 = 0;
        let mut idx = 0;
        while idx < events.len() {
            let time = events[idx].0;
            // Apply every change at `time` before checking: a part ending at
            // `time` frees the resource for one starting at `time`.
            while idx < events.len() && events[idx].0 == time {
                usage += events[idx].1;
                idx += 1;
            }
            if usage > capacity {
                return Err(format!("Resource overload at time {time}"));
            }
        }
        Ok(())
    }
}
impl ConstraintPropagator for CumulativePropagator {
    /// Runs the time-tabling check. No domain is pruned, so this returns
    /// `Ok(false)` whenever the check passes.
    ///
    /// # Errors
    /// Propagates the overload error reported by `time_tabling`.
    fn propagate(&mut self, domains: &mut HashMap<String, Domain>) -> Result<bool, String> {
        self.time_tabling(domains)?;
        Ok(false)
    }

    /// Checks a complete assignment of task start times.
    ///
    /// Returns `false` if any task's start variable is missing from `assignment`.
    /// Otherwise each task occupies `[start, start + duration)`, and the summed
    /// usage must never exceed `capacity`. A task ending exactly when another
    /// starts is not an overlap, whatever order the tasks are listed in.
    fn is_satisfied(&self, assignment: &HashMap<String, i32>) -> bool {
        let mut events: Vec<(i64, i64)> = Vec::with_capacity(2 * self.tasks.len());
        for task in &self.tasks {
            let Some(&start) = assignment.get(&task.start_var) else {
                return false;
            };
            // i64 so `start + duration` and the running sum cannot overflow.
            let start = i64::from(start);
            let usage = i64::from(task.resource_usage);
            events.push((start, usage));
            events.push((start + i64::from(task.duration), -usage));
        }
        events.sort_unstable_by_key(|&(time, _)| time);

        let capacity = i64::from(self.capacity);
        let mut usage: i64 = 0;
        let mut idx = 0;
        while idx < events.len() {
            let time = events[idx].0;
            // Settle every start/end at this instant before comparing.
            while idx < events.len() && events[idx].0 == time {
                usage += events[idx].1;
                idx += 1;
            }
            if usage > capacity {
                return false;
            }
        }
        true
    }

    /// Start-time variable names, one per task in task order.
    fn variables(&self) -> Vec<String> {
        self.tasks.iter().map(|t| t.start_var.clone()).collect()
    }
}
/// Symmetry-breaking strategies for a constraint model.
///
/// NOTE(review): not consumed anywhere in this file; the variant descriptions
/// follow the conventional meaning of each technique and should be confirmed.
#[derive(Debug, Clone)]
pub enum SymmetryBreaking {
    /// Presumably: successive variable groups must be lexicographically ordered.
    LexOrdering { variable_groups: Vec<Vec<String>> },
    /// Presumably: among `variables`, each value in `values` may only appear
    /// after the value preceding it in the list has appeared.
    ValuePrecedence {
        values: Vec<i32>,
        variables: Vec<String>,
    },
    /// Presumably: fixes the `(variable, value)` pairs in `representative` to pick
    /// one member of each orbit of `symmetry_group`.
    OrbitFixing {
        symmetry_group: SymmetryGroup,
        representative: Vec<(String, i32)>,
    },
}
/// Description of a symmetry group acting on a model.
///
/// NOTE(review): not interpreted anywhere in this file; the meanings below are
/// inferred from the variant names.
#[derive(Debug, Clone)]
pub enum SymmetryGroup {
    /// Presumably the symmetric group on the given number of elements.
    Symmetric(usize),
    /// Presumably the cyclic group of the given order.
    Cyclic(usize),
    /// Presumably the direct product of two groups.
    Product(Box<Self>, Box<Self>),
}
/// Namespace for ready-made constraint models, exposed as associated functions.
pub struct ConstraintLibrary;
impl ConstraintLibrary {
    /// N-queens model over variables `queen_0 .. queen_{n-1}`.
    ///
    /// Only one `AllDifferent` over all queen variables is returned; no diagonal
    /// constraints are generated. NOTE(review): presumably `queen_i` is the column
    /// of the queen in row `i`, with diagonals handled elsewhere — confirm.
    pub fn n_queens(n: usize) -> Vec<GlobalConstraint> {
        let queens: Vec<String> = (0..n).map(|row| format!("queen_{row}")).collect();
        vec![GlobalConstraint::AllDifferent { variables: queens }]
    }

    /// One two-variable `AllDifferent` per edge `(u, v)`, over `color_u` and
    /// `color_v`, in edge order.
    ///
    /// `_num_colors` is not used: the palette size is not encoded in the
    /// returned constraints (presumably the caller sets it via the domains).
    pub fn graph_coloring(edges: &[(usize, usize)], _num_colors: usize) -> Vec<GlobalConstraint> {
        edges
            .iter()
            .map(|&(u, v)| GlobalConstraint::AllDifferent {
                variables: vec![format!("color_{}", u), format!("color_{}", v)],
            })
            .collect()
    }

    /// Standard 9×9 Sudoku over variables `cell_{row}_{col}` (0-based).
    ///
    /// Returns 27 `AllDifferent` constraints: 9 rows, then 9 columns, then the
    /// 9 boxes in row-major order. Each box lists its cells row by row.
    pub fn sudoku() -> Vec<GlobalConstraint> {
        let all_different = |variables: Vec<String>| GlobalConstraint::AllDifferent { variables };
        let rows = (0..9).map(|row| {
            (0..9)
                .map(|col| format!("cell_{row}_{col}"))
                .collect::<Vec<String>>()
        });
        let cols = (0..9).map(|col| {
            (0..9)
                .map(|row| format!("cell_{row}_{col}"))
                .collect::<Vec<String>>()
        });
        let boxes = (0..9).map(|b| {
            let (box_row, box_col) = (b / 3, b % 3);
            (0..9)
                .map(|k| format!("cell_{}_{}", box_row * 3 + k / 3, box_col * 3 + k % 3))
                .collect::<Vec<String>>()
        });
        rows.chain(cols).chain(boxes).map(all_different).collect()
    }
}
/// Builds a symmetric `n × n` quadratic penalty matrix from soft constraints,
/// where `n = variables.len()` and `variables` maps each name to its row/column
/// index.
///
/// Each `LinearInequality` with coefficient vector `c` adds the outer product
/// `c cᵀ`, i.e. the coefficients of the quadratic form `(Σ cᵢ xᵢ)²`. Only terms
/// whose variable appears in `variables` are included. A variable listed more
/// than once in `coefficients` gets all of its cross terms, so the expansion
/// stays exact.
///
/// NOTE(review): the constraint's `bound`, its `penalty` function and its
/// `priority` are not used, and non-linear constraint kinds are skipped —
/// confirm this is the intended encoding.
///
/// # Panics
/// Panics if an index stored in `variables` is `>= variables.len()`.
pub fn constraints_to_penalties(
    constraints: &[SoftConstraint],
    variables: &HashMap<String, usize>,
) -> Array2<f64> {
    let n = variables.len();
    let mut penalty_matrix = Array2::zeros((n, n));
    for constraint in constraints {
        let ConstraintExpression::LinearInequality { coefficients, .. } = &constraint.constraint
        else {
            continue;
        };
        // Resolve names to indices once; unknown variables are ignored.
        let terms: Vec<(usize, f64)> = coefficients
            .iter()
            .filter_map(|(name, coeff)| variables.get(name).map(|&idx| (idx, *coeff)))
            .collect();
        for &(row, c_row) in &terms {
            for &(col, c_col) in &terms {
                penalty_matrix[[row, col]] += c_row * c_col;
            }
        }
    }
    penalty_matrix
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Removal and intersection shrink the domain and keep the bounds in sync.
    #[test]
    fn test_domain_operations() {
        let mut domain = Domain::new(1, 5);
        assert_eq!(domain.size(), 5);
        assert!(domain.remove(3));
        assert_eq!(domain.size(), 4);
        assert!(!domain.values.contains(&3));
        let keep: HashSet<i32> = vec![1, 2, 5].into_iter().collect();
        domain.intersect(&keep);
        assert_eq!(domain.size(), 3);
        assert_eq!(domain.min, 1);
        assert_eq!(domain.max, 5);
    }

    /// Duplicates collapse and an empty list gives an empty domain.
    #[test]
    fn test_domain_from_values() {
        let domain = Domain::from_values(vec![4, -2, 4]);
        assert_eq!(domain.size(), 2);
        assert_eq!((domain.min, domain.max), (-2, 4));
        assert!(Domain::from_values(Vec::new()).is_empty());
    }

    /// A fixed variable's value is removed from the other domains.
    #[test]
    fn test_alldifferent_propagation() {
        let mut propagator =
            AllDifferentPropagator::new(vec!["x".to_string(), "y".to_string(), "z".to_string()]);
        let mut domains = HashMap::new();
        domains.insert("x".to_string(), Domain::from_values(vec![1]));
        domains.insert("y".to_string(), Domain::from_values(vec![1, 2, 3]));
        domains.insert("z".to_string(), Domain::from_values(vec![1, 2, 3]));
        let changed = propagator
            .propagate(&mut domains)
            .expect("AllDifferent propagation should succeed with valid domains");
        assert!(changed);
        assert!(!domains["y"].values.contains(&1));
        assert!(!domains["z"].values.contains(&1));
    }

    /// Two variables sharing {1, 2} force the third onto its remaining value.
    #[test]
    fn test_alldifferent_hall_pruning() {
        let mut propagator =
            AllDifferentPropagator::new(vec!["x".to_string(), "y".to_string(), "z".to_string()]);
        let mut domains = HashMap::new();
        domains.insert("x".to_string(), Domain::from_values(vec![1, 2]));
        domains.insert("y".to_string(), Domain::from_values(vec![1, 2]));
        domains.insert("z".to_string(), Domain::from_values(vec![1, 2, 3]));
        let changed = propagator
            .propagate(&mut domains)
            .expect("a Hall set of size 2 leaves z with value 3");
        assert!(changed);
        assert_eq!(domains["z"].size(), 1);
        assert!(domains["z"].values.contains(&3));
    }

    #[test]
    fn test_alldifferent_is_satisfied() {
        let propagator = AllDifferentPropagator::new(vec!["x".to_string(), "y".to_string()]);
        let mut assignment = HashMap::new();
        assignment.insert("x".to_string(), 1);
        assignment.insert("y".to_string(), 2);
        assert!(propagator.is_satisfied(&assignment));
        assignment.insert("y".to_string(), 1);
        assert!(!propagator.is_satisfied(&assignment));
    }

    #[test]
    fn test_constraint_library() {
        let queens = ConstraintLibrary::n_queens(8);
        assert!(!queens.is_empty());
        let edges = vec![(0, 1), (1, 2), (2, 0)];
        let coloring = ConstraintLibrary::graph_coloring(&edges, 3);
        assert_eq!(coloring.len(), 3);
        let sudoku = ConstraintLibrary::sudoku();
        assert_eq!(sudoku.len(), 27);
        assert!(sudoku.iter().all(
            |c| matches!(c, GlobalConstraint::AllDifferent { variables } if variables.len() == 9)
        ));
    }
}