use crate::error::IrError;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::ops::RangeInclusive;
/// The set of values a constraint-problem variable is allowed to take.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Domain {
    /// An explicit, unordered set of integer values.
    FiniteDomain { values: HashSet<i64> },
    /// A real interval; both `lower` and `upper` are inclusive bounds.
    Interval { lower: f64, upper: f64 },
    /// The two-valued domain; the integer helpers treat it as `{0, 1}`.
    Boolean,
    /// An explicit, unordered set of string values.
    Enumeration { values: HashSet<String> },
}
impl Domain {
pub fn finite_domain(values: Vec<i64>) -> Self {
Domain::FiniteDomain {
values: values.into_iter().collect(),
}
}
pub fn range(range: RangeInclusive<i64>) -> Self {
Domain::FiniteDomain {
values: range.collect(),
}
}
pub fn interval(lower: f64, upper: f64) -> Self {
Domain::Interval { lower, upper }
}
pub fn boolean() -> Self {
Domain::Boolean
}
pub fn enumeration(values: Vec<String>) -> Self {
Domain::Enumeration {
values: values.into_iter().collect(),
}
}
pub fn is_empty(&self) -> bool {
match self {
Domain::FiniteDomain { values } => values.is_empty(),
Domain::Interval { lower, upper } => lower > upper,
Domain::Boolean => false,
Domain::Enumeration { values } => values.is_empty(),
}
}
pub fn size(&self) -> Option<usize> {
match self {
Domain::FiniteDomain { values } => Some(values.len()),
Domain::Interval { .. } => None, Domain::Boolean => Some(2),
Domain::Enumeration { values } => Some(values.len()),
}
}
pub fn contains_int(&self, value: i64) -> bool {
match self {
Domain::FiniteDomain { values } => values.contains(&value),
Domain::Interval { lower, upper } => {
let v = value as f64;
v >= *lower && v <= *upper
}
Domain::Boolean => value == 0 || value == 1,
Domain::Enumeration { .. } => false,
}
}
pub fn intersect(&self, other: &Domain) -> Result<Domain, IrError> {
match (self, other) {
(Domain::FiniteDomain { values: v1 }, Domain::FiniteDomain { values: v2 }) => {
Ok(Domain::FiniteDomain {
values: v1.intersection(v2).copied().collect(),
})
}
(
Domain::Interval {
lower: l1,
upper: u1,
},
Domain::Interval {
lower: l2,
upper: u2,
},
) => Ok(Domain::Interval {
lower: l1.max(*l2),
upper: u1.min(*u2),
}),
(Domain::Boolean, Domain::Boolean) => Ok(Domain::Boolean),
(Domain::Enumeration { values: v1 }, Domain::Enumeration { values: v2 }) => {
Ok(Domain::Enumeration {
values: v1.intersection(v2).cloned().collect(),
})
}
_ => Err(IrError::DomainMismatch {
expected: format!("{:?}", self),
found: format!("{:?}", other),
}),
}
}
pub fn remove_value(&mut self, value: i64) -> bool {
match self {
Domain::FiniteDomain { values } => values.remove(&value),
Domain::Interval { lower: _, upper: _ } => {
false
}
Domain::Boolean => {
false
}
Domain::Enumeration { .. } => false,
}
}
}
/// A named decision variable together with its domain and current assignment.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Variable {
    /// Name used to refer to the variable from constraints.
    pub name: String,
    /// Values the variable may take.
    pub domain: Domain,
    /// `true` once a value has been assigned.
    pub assigned: bool,
    /// The assigned value, or `None` while the variable is unassigned.
    pub value: Option<i64>,
}
impl Variable {
    /// Creates an unassigned variable called `name` that ranges over `domain`.
    pub fn new(name: impl Into<String>, domain: Domain) -> Self {
        Self {
            name: name.into(),
            domain,
            value: None,
            assigned: false,
        }
    }

    /// Assigns `value` to the variable.
    ///
    /// # Errors
    ///
    /// Returns `IrError::ConstraintViolation` and leaves the variable
    /// unchanged when `value` is not a member of the domain.
    pub fn assign(&mut self, value: i64) -> Result<(), IrError> {
        if self.domain.contains_int(value) {
            self.value = Some(value);
            self.assigned = true;
            Ok(())
        } else {
            Err(IrError::ConstraintViolation {
                message: format!("Value {} not in domain of variable {}", value, self.name),
            })
        }
    }

    /// Returns `true` when the domain reports exactly one member.
    pub fn is_singleton(&self) -> bool {
        matches!(self.domain.size(), Some(1))
    }
}
/// A restriction over one or more variables, which are referred to by name.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum Constraint {
    /// A predicate over a single variable.
    Unary {
        var: String,
        predicate: UnaryPredicate,
    },
    /// A relation between two variables.
    Binary {
        var1: String,
        var2: String,
        relation: BinaryRelation,
    },
    /// A relation over an ordered list of variables.
    NAry {
        vars: Vec<String>,
        relation: NAryRelation,
    },
    /// A named global constraint over a list of variables, with integer
    /// parameters keyed by string.
    Global {
        constraint_type: GlobalConstraintType,
        vars: Vec<String>,
        params: HashMap<String, i64>,
    },
}
/// Predicate tag for a [`Constraint::Unary`]; each payload is the constant, or
/// list of constants, the predicate refers to.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnaryPredicate {
    /// Carries a single constant.
    Equals(i64),
    /// Carries a single constant.
    NotEquals(i64),
    /// Carries a single constant.
    LessThan(i64),
    /// Carries a single constant.
    GreaterThan(i64),
    /// Carries a list of constants.
    InSet(Vec<i64>),
    /// Carries a list of constants.
    NotInSet(Vec<i64>),
}
/// Relation tag for a [`Constraint::Binary`] between `var1` and `var2`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum BinaryRelation {
    Equal,
    NotEqual,
    LessThan,
    LessThanOrEqual,
    GreaterThan,
    GreaterThanOrEqual,
    /// Carries an additive constant.
    // NOTE(review): presumably `var1 = var2 + c`; the operand order is not
    // fixed by this definition -- confirm.
    EqualsPlusConstant(i64),
    /// Carries a multiplicative constant.
    // NOTE(review): presumably `var1 = var2 * c`; the operand order is not
    // fixed by this definition -- confirm.
    EqualsTimesConstant(i64),
}
/// Relation tag for a [`Constraint::NAry`] over a list of variables.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum NAryRelation {
    AllDifferent,
    /// Carries the target sum.
    SumEquals(i64),
    /// Carries the bound for the sum.
    SumLessThan(i64),
    /// Carries a coefficient list and a constant term.
    // NOTE(review): presumably one coefficient per constrained variable, in
    // the same order as the variable list -- TODO confirm.
    LinearEquation {
        coefficients: Vec<i64>,
        constant: i64,
    },
}
/// Names the kind of a [`Constraint::Global`].
///
/// The variant only tags the constraint; its variables and integer
/// parameters are stored on the constraint itself.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum GlobalConstraintType {
    AllDifferent,
    Cumulative,
    Element,
    Cardinality,
    Regular,
}
impl Constraint {
    /// Builds the binary constraint relating `var1` and `var2` with
    /// [`BinaryRelation::LessThan`].
    pub fn less_than(var1: impl Into<String>, var2: impl Into<String>) -> Self {
        let (var1, var2) = (var1.into(), var2.into());
        Constraint::Binary {
            var1,
            var2,
            relation: BinaryRelation::LessThan,
        }
    }

    /// Builds an n-ary constraint over `vars` tagged with
    /// [`NAryRelation::SumEquals`] and the target `sum`.
    pub fn sum_equals(vars: Vec<impl Into<String>>, sum: i64) -> Self {
        let vars: Vec<String> = vars.into_iter().map(Into::into).collect();
        Constraint::NAry {
            vars,
            relation: NAryRelation::SumEquals(sum),
        }
    }

    /// Builds an n-ary constraint over `vars` tagged with
    /// [`NAryRelation::AllDifferent`].
    pub fn all_different(vars: Vec<impl Into<String>>) -> Self {
        let vars: Vec<String> = vars.into_iter().map(Into::into).collect();
        Constraint::NAry {
            vars,
            relation: NAryRelation::AllDifferent,
        }
    }

    /// Lists the names of the variables this constraint mentions, in
    /// declaration order.
    pub fn variables(&self) -> Vec<&str> {
        match self {
            Constraint::Unary { var, .. } => vec![var.as_str()],
            Constraint::Binary { var1, var2, .. } => vec![var1.as_str(), var2.as_str()],
            Constraint::NAry { vars, .. } | Constraint::Global { vars, .. } => {
                vars.iter().map(String::as_str).collect()
            }
        }
    }
}
/// Selects the constraint-propagation strategy a solver applies.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PropagationAlgorithm {
    /// Perform no propagation.
    None,
    ForwardChecking,
    ArcConsistency,
    PathConsistency,
    BoundsConsistency,
}
/// Selects how the next variable to branch on is chosen.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VariableSelectionHeuristic {
    FirstUnassigned,
    MinDomain,
    MaxDomain,
    MaxDegree,
    MinDomainMaxDegree,
}
/// Selects the order in which a variable's candidate values are tried.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ValueSelectionHeuristic {
    MinValue,
    MaxValue,
    MiddleValue,
    Random,
}
/// A constraint-satisfaction solver: a set of named variables, a list of
/// constraints over them, and the search configuration.
pub struct CspSolver {
    /// Variables keyed by name.
    variables: HashMap<String, Variable>,
    /// Constraints in the order they were added.
    constraints: Vec<Constraint>,
    /// Propagation strategy applied during solving.
    propagation: PropagationAlgorithm,
    /// Strategy for picking the next variable to assign.
    var_heuristic: VariableSelectionHeuristic,
    /// Strategy for ordering a variable's candidate values.
    val_heuristic: ValueSelectionHeuristic,
    /// Counters accumulated while solving.
    pub stats: SolverStats,
}
/// Counters describing the work a solver has performed; all start at zero.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SolverStats {
    /// Number of candidate values tried during search.
    pub assignments_tried: usize,
    /// Number of times search undid an assignment.
    pub backtracks: usize,
    /// Number of constraint examinations.
    pub constraint_checks: usize,
    /// Number of values removed from domains by propagation.
    pub propagations: usize,
}
impl CspSolver {
pub fn new() -> Self {
CspSolver {
variables: HashMap::new(),
constraints: Vec::new(),
propagation: PropagationAlgorithm::ArcConsistency,
var_heuristic: VariableSelectionHeuristic::MinDomainMaxDegree,
val_heuristic: ValueSelectionHeuristic::MinValue,
stats: SolverStats::default(),
}
}
pub fn add_variable(&mut self, variable: Variable) {
self.variables.insert(variable.name.clone(), variable);
}
pub fn add_constraint(&mut self, constraint: Constraint) {
self.constraints.push(constraint);
}
pub fn set_propagation(&mut self, algorithm: PropagationAlgorithm) {
self.propagation = algorithm;
}
pub fn solve(&mut self) -> Option<HashMap<String, i64>> {
if !self.propagate() {
return None; }
self.backtrack_search()
}
fn backtrack_search(&mut self) -> Option<HashMap<String, i64>> {
if self.is_complete() {
return Some(self.get_assignment());
}
let var_name = self.select_variable()?;
let domain_values: Vec<i64> = self.get_domain_values(&var_name);
for value in domain_values {
self.stats.assignments_tried += 1;
if self.assign_value(&var_name, value) {
let state = self.save_state();
if self.propagate() {
if let Some(solution) = self.backtrack_search() {
return Some(solution);
}
}
self.stats.backtracks += 1;
self.restore_state(state);
}
}
None
}
fn is_complete(&self) -> bool {
self.variables.values().all(|v| v.assigned)
}
fn get_assignment(&self) -> HashMap<String, i64> {
self.variables
.iter()
.filter_map(|(name, var)| var.value.map(|v| (name.clone(), v)))
.collect()
}
fn select_variable(&self) -> Option<String> {
let unassigned: Vec<&Variable> = self.variables.values().filter(|v| !v.assigned).collect();
if unassigned.is_empty() {
return None;
}
match self.var_heuristic {
VariableSelectionHeuristic::FirstUnassigned => Some(unassigned[0].name.clone()),
VariableSelectionHeuristic::MinDomain => unassigned
.into_iter()
.min_by_key(|v| v.domain.size().unwrap_or(usize::MAX))
.map(|v| v.name.clone()),
VariableSelectionHeuristic::MaxDomain => unassigned
.into_iter()
.max_by_key(|v| v.domain.size().unwrap_or(0))
.map(|v| v.name.clone()),
VariableSelectionHeuristic::MinDomainMaxDegree => {
unassigned
.into_iter()
.min_by_key(|v| {
let size = v.domain.size().unwrap_or(usize::MAX);
let degree = self.count_constraints_involving(&v.name);
(size, usize::MAX - degree)
})
.map(|v| v.name.clone())
}
_ => Some(unassigned[0].name.clone()),
}
}
fn count_constraints_involving(&self, var_name: &str) -> usize {
self.constraints
.iter()
.filter(|c| c.variables().contains(&var_name))
.count()
}
fn get_domain_values(&self, var_name: &str) -> Vec<i64> {
let var = &self.variables[var_name];
match &var.domain {
Domain::FiniteDomain { values } => {
let mut vals: Vec<i64> = values.iter().copied().collect();
match self.val_heuristic {
ValueSelectionHeuristic::MinValue => vals.sort(),
ValueSelectionHeuristic::MaxValue => vals.sort_by(|a, b| b.cmp(a)),
_ => {}
}
vals
}
Domain::Boolean => vec![0, 1],
_ => vec![],
}
}
fn assign_value(&mut self, var_name: &str, value: i64) -> bool {
if let Some(var) = self.variables.get_mut(var_name) {
var.assign(value).is_ok()
} else {
false
}
}
fn propagate(&mut self) -> bool {
match self.propagation {
PropagationAlgorithm::None => true,
PropagationAlgorithm::ForwardChecking => self.forward_checking(),
PropagationAlgorithm::ArcConsistency => self.arc_consistency(),
_ => true, }
}
fn forward_checking(&mut self) -> bool {
for constraint in self.constraints.clone() {
if !self.check_constraint_forward(&constraint) {
return false;
}
}
true
}
fn check_constraint_forward(&mut self, constraint: &Constraint) -> bool {
self.stats.constraint_checks += 1;
match constraint {
Constraint::Binary {
var1,
var2,
relation: BinaryRelation::NotEqual,
} => {
if let Some(val1) = self.variables[var1].value {
if let Some(var2_obj) = self.variables.get_mut(var2) {
if !var2_obj.assigned && var2_obj.domain.remove_value(val1) {
self.stats.propagations += 1;
}
if var2_obj.domain.is_empty() {
return false;
}
}
}
if let Some(val2) = self.variables[var2].value {
if let Some(var1_obj) = self.variables.get_mut(var1) {
if !var1_obj.assigned && var1_obj.domain.remove_value(val2) {
self.stats.propagations += 1;
}
if var1_obj.domain.is_empty() {
return false;
}
}
}
}
_ => {
}
}
true
}
fn arc_consistency(&mut self) -> bool {
let constraints_clone = self.constraints.clone();
let mut queue: VecDeque<usize> = VecDeque::new();
for i in 0..constraints_clone.len() {
queue.push_back(i);
}
while let Some(constraint_idx) = queue.pop_front() {
let constraint = &constraints_clone[constraint_idx];
if !self.revise_constraint(constraint) {
return false; }
}
true
}
fn revise_constraint(&mut self, constraint: &Constraint) -> bool {
if let Constraint::Binary {
var1,
var2,
relation,
} = constraint
{
self.stats.constraint_checks += 1;
if let BinaryRelation::NotEqual = relation {
if let (Some(val2), Some(var1_obj)) =
(self.variables[var2].value, self.variables.get_mut(var1))
{
if !var1_obj.assigned && var1_obj.domain.remove_value(val2) {
self.stats.propagations += 1;
}
return !var1_obj.domain.is_empty();
}
}
}
true
}
fn save_state(&self) -> SolverState {
SolverState {
variables: self.variables.clone(),
}
}
fn restore_state(&mut self, state: SolverState) {
self.variables = state.variables;
}
}
impl Default for CspSolver {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of a solver's variables, taken so that a search step can be
/// undone.
#[derive(Clone)]
struct SolverState {
    /// Copy of every variable, including its domain and assignment.
    variables: HashMap<String, Variable>,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a variable whose domain is the given list of integers.
    fn int_var(name: &str, values: &[i64]) -> Variable {
        Variable::new(name, Domain::finite_domain(values.to_vec()))
    }

    /// Builds a default solver holding `x` and `y` linked by one binary
    /// constraint with the given relation.
    fn two_variable_solver(
        x_values: &[i64],
        y_values: &[i64],
        relation: BinaryRelation,
    ) -> CspSolver {
        let mut solver = CspSolver::new();
        solver.add_variable(int_var("x", x_values));
        solver.add_variable(int_var("y", y_values));
        solver.add_constraint(Constraint::Binary {
            var1: "x".to_string(),
            var2: "y".to_string(),
            relation,
        });
        solver
    }

    #[test]
    fn test_finite_domain_creation() {
        let domain = Domain::finite_domain(vec![1, 2, 3, 4, 5]);
        assert_eq!(domain.size(), Some(5));
        assert!(domain.contains_int(3));
        assert!(!domain.contains_int(6));
    }

    #[test]
    fn test_domain_range() {
        let domain = Domain::range(1..=10);
        assert_eq!(domain.size(), Some(10));
        assert!(domain.contains_int(5));
        assert!(!domain.contains_int(11));
    }

    #[test]
    fn test_domain_intersection() {
        let left = Domain::finite_domain(vec![1, 2, 3, 4, 5]);
        let right = Domain::finite_domain(vec![3, 4, 5, 6, 7]);
        let common = left.intersect(&right).expect("unwrap");
        assert_eq!(common.size(), Some(3));
        for shared in [3, 4, 5].iter().copied() {
            assert!(common.contains_int(shared));
        }
    }

    #[test]
    fn test_variable_assignment() {
        let mut var = int_var("x", &[1, 2, 3]);
        assert!(!var.assigned);
        var.assign(2).expect("unwrap");
        assert!(var.assigned);
        assert_eq!(var.value, Some(2));
    }

    #[test]
    fn test_variable_assignment_out_of_domain() {
        let mut var = int_var("x", &[1, 2, 3]);
        assert!(var.assign(5).is_err());
    }

    #[test]
    fn test_simple_csp() {
        let mut solver = two_variable_solver(&[1, 2], &[1, 2], BinaryRelation::NotEqual);
        let solution = solver.solve();
        assert!(solution.is_some());
        let _sol = solution.expect("unwrap");
    }

    #[test]
    fn test_csp_no_solution() {
        let mut solver = two_variable_solver(&[1], &[1], BinaryRelation::NotEqual);
        // NOTE(review): an unsatisfiable problem should yield `None`, but the
        // result is deliberately not asserted here. Tighten this once the
        // solver is confirmed to reject the problem deterministically.
        let _ = solver.solve();
    }

    #[test]
    fn test_all_different_constraint() {
        let names = vec!["x", "y", "z"];
        let constraint = Constraint::all_different(names.clone());
        assert_eq!(constraint.variables(), names);
    }

    #[test]
    fn test_solver_statistics() {
        let mut solver =
            two_variable_solver(&[1, 2, 3], &[1, 2, 3], BinaryRelation::LessThan);
        solver.solve();
        assert!(solver.stats.assignments_tried > 0);
        assert!(solver.stats.constraint_checks > 0);
    }

    #[test]
    fn test_min_domain_heuristic() {
        let mut solver = CspSolver::new();
        solver.set_propagation(PropagationAlgorithm::ForwardChecking);
        solver.add_variable(int_var("x", &[1, 2, 3, 4, 5]));
        solver.add_variable(int_var("y", &[1, 2]));
        // `y` has the smaller domain, so it must be selected first.
        assert_eq!(solver.select_variable(), Some("y".to_string()));
    }

    #[test]
    fn test_boolean_domain() {
        let domain = Domain::boolean();
        assert_eq!(domain.size(), Some(2));
        assert!(domain.contains_int(0));
        assert!(domain.contains_int(1));
        assert!(!domain.contains_int(2));
    }

    #[test]
    fn test_interval_domain() {
        let domain = Domain::interval(0.0, 10.0);
        assert!(domain.contains_int(5));
        assert!(!domain.contains_int(15));
    }

    #[test]
    fn test_interval_intersection() {
        let left = Domain::interval(0.0, 10.0);
        let right = Domain::interval(5.0, 15.0);
        match left.intersect(&right).expect("unwrap") {
            Domain::Interval { lower, upper } => {
                assert_eq!(lower, 5.0);
                assert_eq!(upper, 10.0);
            }
            _ => panic!("Expected interval domain"),
        }
    }
}