use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
/// A finite, ordered set of candidate `f64` values for one variable.
///
/// Values keep their insertion order. De-duplication and removal compare exact
/// bit patterns (`f64::to_bits`), so `0.0` and `-0.0` are distinct members while
/// two bit-identical NaNs count as one.
#[derive(Debug, Clone, Default)]
pub struct Domain {
    values: Vec<f64>,
}

impl Domain {
    /// Creates a domain from `values`, keeping the first occurrence of each
    /// bit-identical value and preserving order.
    pub fn new(mut values: Vec<f64>) -> Self {
        let mut seen: HashSet<u64> = HashSet::with_capacity(values.len());
        values.retain(|v| seen.insert(v.to_bits()));
        Self { values }
    }

    /// Creates the arithmetic progression `start, start + step, start + 2*step, ...`
    /// containing every term `<= end`, with a tolerance of `step * 1e-9` so an
    /// endpoint reached only up to rounding is still included.
    ///
    /// Each term is computed as `start + i * step` instead of by repeated
    /// addition, so rounding error does not accumulate: `from_range(0.0, 1.0, 0.1)`
    /// ends at exactly `1.0` rather than `0.9999999999999999`.
    ///
    /// Degenerate inputs:
    /// * `step <= 0.0`: `[start]` when `start <= end`, otherwise empty.
    /// * otherwise, a NaN or infinite `start`, `end` or `step` (or an `end` so
    ///   large that the tolerance overflows) yields an empty domain, because the
    ///   progression could not be enumerated in finite time.
    pub fn from_range(start: f64, end: f64, step: f64) -> Self {
        if step <= 0.0 {
            let values = if start <= end { vec![start] } else { Vec::new() };
            return Self { values };
        }
        let limit = end + step * 1e-9;
        // `limit` is non-finite whenever `end` or `step` is NaN or infinite.
        if !start.is_finite() || !limit.is_finite() {
            return Self::default();
        }
        let terms: Vec<f64> = (0u64..)
            .map(|i| start + i as f64 * step)
            .take_while(|&v| v <= limit)
            .collect();
        // When `step` is smaller than the float spacing near `start`, consecutive
        // terms can round to the same value; `new` drops those duplicates.
        Self::new(terms)
    }

    /// Returns the two-valued domain `[0.0, 1.0]`.
    pub fn boolean() -> Self {
        Self {
            values: vec![0.0, 1.0],
        }
    }

    /// Returns `true` when no candidate values remain.
    pub fn is_empty(&self) -> bool {
        self.values.is_empty()
    }

    /// Returns the number of candidate values.
    pub fn len(&self) -> usize {
        self.values.len()
    }

    /// Returns the candidate values in order.
    pub fn values(&self) -> &[f64] {
        &self.values
    }

    /// Removes the value bit-identical to `val`; returns whether one was found.
    pub fn remove_value(&mut self, val: f64) -> bool {
        let bits = val.to_bits();
        match self.values.iter().position(|v| v.to_bits() == bits) {
            Some(pos) => {
                self.values.remove(pos);
                true
            }
            None => false,
        }
    }

    /// Keeps only the values for which `predicate` returns `true` and returns
    /// how many values were removed.
    pub fn retain<F: Fn(f64) -> bool>(&mut self, predicate: F) -> usize {
        let before = self.values.len();
        self.values.retain(|&v| predicate(v));
        before - self.values.len()
    }

    /// Returns the values of `self` that also occur (bitwise) in `other`, in
    /// `self`'s order.
    pub fn intersect(&self, other: &Domain) -> Domain {
        let other_bits: HashSet<u64> = other.values.iter().map(|v| v.to_bits()).collect();
        let values = self
            .values
            .iter()
            .copied()
            .filter(|v| other_bits.contains(&v.to_bits()))
            .collect();
        Domain { values }
    }

    /// Returns the values of `self` followed by the values of `other` that are
    /// not already present (bitwise).
    pub fn union(&self, other: &Domain) -> Domain {
        let mut seen: HashSet<u64> = self.values.iter().map(|v| v.to_bits()).collect();
        let mut values = self.values.clone();
        values.extend(other.values.iter().copied().filter(|v| seen.insert(v.to_bits())));
        Domain { values }
    }
}
/// Relation that a binary constraint requires between two values `x` and `y`.
#[derive(Clone)]
pub enum ConstraintRelation {
    /// `x == y`, exactly or within an absolute tolerance of `f64::EPSILON`.
    Equal,
    /// Logical negation of [`ConstraintRelation::Equal`].
    NotEqual,
    /// `x < y`.
    LessThan,
    /// `x <= y`.
    LessOrEqual,
    /// `x > y`.
    GreaterThan,
    /// `x >= y`.
    GreaterOrEqual,
    /// `|x - y| <= delta`.
    Difference(f64),
    /// Arbitrary predicate, evaluated as `f(x, y)`.
    Custom(Arc<dyn Fn(f64, f64) -> bool + Send + Sync>),
}

impl std::fmt::Debug for ConstraintRelation {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ConstraintRelation::Equal => write!(f, "Equal"),
            ConstraintRelation::NotEqual => write!(f, "NotEqual"),
            ConstraintRelation::LessThan => write!(f, "LessThan"),
            ConstraintRelation::LessOrEqual => write!(f, "LessOrEqual"),
            ConstraintRelation::GreaterThan => write!(f, "GreaterThan"),
            ConstraintRelation::GreaterOrEqual => write!(f, "GreaterOrEqual"),
            ConstraintRelation::Difference(d) => write!(f, "Difference({})", d),
            // Closures have no useful debug representation.
            ConstraintRelation::Custom(_) => write!(f, "Custom(<fn>)"),
        }
    }
}

impl ConstraintRelation {
    /// Returns whether the pair `(x, y)` satisfies this relation.
    pub fn holds(&self, x: f64, y: f64) -> bool {
        match self {
            ConstraintRelation::Equal => Self::approx_eq(x, y),
            ConstraintRelation::NotEqual => !Self::approx_eq(x, y),
            ConstraintRelation::LessThan => x < y,
            ConstraintRelation::LessOrEqual => x <= y,
            ConstraintRelation::GreaterThan => x > y,
            ConstraintRelation::GreaterOrEqual => x >= y,
            ConstraintRelation::Difference(delta) => (x - y).abs() <= *delta,
            ConstraintRelation::Custom(f) => f(x, y),
        }
    }

    /// Returns the relation with its arguments swapped, so that
    /// `r.reversed().holds(y, x) == r.holds(x, y)`.
    pub fn reversed(&self) -> ConstraintRelation {
        match self {
            ConstraintRelation::Equal => ConstraintRelation::Equal,
            ConstraintRelation::NotEqual => ConstraintRelation::NotEqual,
            ConstraintRelation::LessThan => ConstraintRelation::GreaterThan,
            ConstraintRelation::LessOrEqual => ConstraintRelation::GreaterOrEqual,
            ConstraintRelation::GreaterThan => ConstraintRelation::LessThan,
            ConstraintRelation::GreaterOrEqual => ConstraintRelation::LessOrEqual,
            ConstraintRelation::Difference(d) => ConstraintRelation::Difference(*d),
            ConstraintRelation::Custom(f) => {
                let f_clone = Arc::clone(f);
                ConstraintRelation::Custom(Arc::new(move |x, y| f_clone(y, x)))
            }
        }
    }

    /// Tolerant equality shared by `Equal` and `NotEqual`, which keeps the two
    /// relations exact complements of each other.
    ///
    /// The exact `x == y` check is required for infinities: `inf - inf` is NaN,
    /// so the tolerance test alone would report equal infinities as unequal.
    fn approx_eq(x: f64, y: f64) -> bool {
        x == y || (x - y).abs() < f64::EPSILON
    }
}
/// A constraint requiring `relation.holds(value_of(var_x), value_of(var_y))`.
#[derive(Debug, Clone)]
pub struct BinaryConstraint {
    /// Variable supplying the first argument of `relation`.
    pub var_x: String,
    /// Variable supplying the second argument of `relation`.
    pub var_y: String,
    /// Relation evaluated as `relation.holds(x, y)`.
    pub relation: ConstraintRelation,
}

impl BinaryConstraint {
    /// Builds a constraint between two variables, taking ownership of their names.
    pub fn new(
        var_x: impl Into<String>,
        var_y: impl Into<String>,
        relation: ConstraintRelation,
    ) -> Self {
        let (var_x, var_y): (String, String) = (var_x.into(), var_y.into());
        BinaryConstraint {
            var_x,
            var_y,
            relation,
        }
    }
}
/// Named variables, each with a finite [`Domain`], plus binary constraints
/// between them.
///
/// Constraints may name variables that were never added. The propagation and
/// search routines in this module skip such constraints.
#[derive(Debug, Clone, Default)]
pub struct ConstraintNetwork {
    variables: HashMap<String, Domain>,
    constraints: Vec<BinaryConstraint>,
}

impl ConstraintNetwork {
    /// Creates a network with no variables and no constraints.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds variable `name` with `domain`, replacing its domain if it already exists.
    pub fn add_variable(&mut self, name: impl Into<String>, domain: Domain) {
        self.variables.insert(name.into(), domain);
    }

    /// Appends `constraint`; no check is made that its variables exist.
    pub fn add_constraint(&mut self, constraint: BinaryConstraint) {
        self.constraints.push(constraint);
    }

    /// Returns the number of distinct variables.
    pub fn variable_count(&self) -> usize {
        self.variables.len()
    }

    /// Returns the number of constraints added, counting duplicates.
    pub fn constraint_count(&self) -> usize {
        self.constraints.len()
    }

    /// Returns the current domain of `var`, or `None` if it is not a variable.
    pub fn domain(&self, var: &str) -> Option<&Domain> {
        self.variables.get(var)
    }

    /// Iterates over the variable names in unspecified (hash-map) order.
    pub fn variable_names(&self) -> impl Iterator<Item = &str> {
        self.variables.keys().map(|s| s.as_str())
    }
}
/// Summary of an arc-consistency propagation pass.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PropagationResult {
    /// `false` when at least one domain was emptied, meaning no solution exists.
    pub consistent: bool,
    /// Number of arcs taken off the work queue.
    pub iterations: usize,
    /// Total number of values removed across all domains.
    pub pruned: usize,
    /// Variables whose domains were emptied, without duplicates.
    pub empty_domains: Vec<String>,
}

/// Statistics gathered by a solver run.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SolveStats {
    /// Result of the preprocessing propagation pass.
    pub propagation_result: PropagationResult,
    /// Number of complete assignments recorded.
    pub solutions_found: usize,
    /// Number of candidate values rejected by the consistency check.
    pub backtrack_count: usize,
    /// Number of candidate values tried during search.
    pub nodes_explored: usize,
}

/// Strategy for choosing the next variable to assign.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VarOrdering {
    /// Smallest variable name first.
    Lexicographic,
    /// Variable with the fewest remaining domain values first.
    MinRemainingValues,
    /// Variable involved in the most constraints with other unassigned variables first.
    DegreeHeuristic,
}

/// Solver configuration.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CspConfig {
    /// Maximum number of solutions to collect; `0` means no limit.
    pub max_solutions: usize,
    /// Run arc-consistency propagation before searching.
    pub use_arc_consistency: bool,
    /// Prune neighbours' domains after every assignment during search.
    pub forward_checking: bool,
    /// Variable-selection strategy used by the search.
    pub variable_ordering: VarOrdering,
}

impl Default for CspConfig {
    /// Unlimited solutions, arc consistency and forward checking enabled, and
    /// the minimum-remaining-values ordering.
    fn default() -> Self {
        Self {
            max_solutions: 0,
            use_arc_consistency: true,
            forward_checking: true,
            variable_ordering: VarOrdering::MinRemainingValues,
        }
    }
}
#[derive(Debug)]
struct ArcEntry {
xi: String,
xj: String,
relation: ConstraintRelation,
}
pub fn propagate_arc_consistency(network: &mut ConstraintNetwork) -> PropagationResult {
let mut queue: VecDeque<ArcEntry> = VecDeque::new();
for constraint in &network.constraints {
queue.push_back(ArcEntry {
xi: constraint.var_x.clone(),
xj: constraint.var_y.clone(),
relation: constraint.relation.clone(),
});
queue.push_back(ArcEntry {
xi: constraint.var_y.clone(),
xj: constraint.var_x.clone(),
relation: constraint.relation.reversed(),
});
}
let mut iterations: usize = 0;
let mut pruned: usize = 0;
let mut empty_domains: Vec<String> = Vec::new();
while let Some(arc) = queue.pop_front() {
iterations += 1;
let xi_values: Vec<f64> = match network.variables.get(&arc.xi) {
Some(d) => d.values().to_vec(),
None => continue,
};
let xj_values: Vec<f64> = match network.variables.get(&arc.xj) {
Some(d) => d.values().to_vec(),
None => continue,
};
let to_remove: Vec<f64> = xi_values
.iter()
.copied()
.filter(|&a| !xj_values.iter().any(|&b| arc.relation.holds(a, b)))
.collect();
if to_remove.is_empty() {
continue;
}
let xi_domain = network
.variables
.get_mut(&arc.xi)
.expect("variable must exist after earlier borrow");
for val in &to_remove {
xi_domain.remove_value(*val);
}
pruned += to_remove.len();
if xi_domain.is_empty() {
if !empty_domains.contains(&arc.xi) {
empty_domains.push(arc.xi.clone());
}
continue;
}
let xi_name = arc.xi.clone();
let xj_name = arc.xj.clone();
let new_arcs: Vec<ArcEntry> = network
.constraints
.iter()
.flat_map(|c| {
let mut arcs = Vec::new();
if c.var_x == xi_name && c.var_y != xj_name {
arcs.push(ArcEntry {
xi: c.var_y.clone(),
xj: xi_name.clone(),
relation: c.relation.reversed(),
});
}
if c.var_y == xi_name && c.var_x != xj_name {
arcs.push(ArcEntry {
xi: c.var_x.clone(),
xj: xi_name.clone(),
relation: c.relation.clone(),
});
}
arcs
})
.collect();
for new_arc in new_arcs {
queue.push_back(new_arc);
}
}
let consistent = empty_domains.is_empty();
PropagationResult {
consistent,
iterations,
pruned,
empty_domains,
}
}
fn select_variable(
unassigned: &[String],
domains: &HashMap<String, Domain>,
constraints: &[BinaryConstraint],
ordering: &VarOrdering,
) -> String {
match ordering {
VarOrdering::Lexicographic => {
let mut sorted = unassigned.to_vec();
sorted.sort();
sorted
.into_iter()
.next()
.expect("select_variable called with non-empty unassigned list")
}
VarOrdering::MinRemainingValues => unassigned
.iter()
.min_by_key(|v| domains.get(*v).map(|d| d.len()).unwrap_or(usize::MAX))
.cloned()
.expect("select_variable called with non-empty unassigned list"),
VarOrdering::DegreeHeuristic => {
let unassigned_set: HashSet<&String> = unassigned.iter().collect();
unassigned
.iter()
.max_by_key(|v| {
constraints
.iter()
.filter(|c| {
(&c.var_x == *v && unassigned_set.contains(&c.var_y))
|| (&c.var_y == *v && unassigned_set.contains(&c.var_x))
})
.count()
})
.cloned()
.expect("select_variable called with non-empty unassigned list")
}
}
}
/// Applies forward checking after `var` has been bound to `value`.
///
/// For every constraint linking `var` to an unassigned neighbour, the
/// neighbour's domain is reduced to values compatible with `value`.
///
/// Returns `false` as soon as a neighbour's domain empties, or if a
/// constraint on `var` alone (`var_x == var_y == var`) rejects `value`.
/// Otherwise returns `true`. Pruning already done is not undone on failure,
/// so the caller must restore `domains` itself.
///
/// Constraints against already-assigned neighbours are not re-checked here.
/// Those neighbours pruned `var`'s domain when they were assigned.
fn forward_check(
    var: &str,
    value: f64,
    domains: &mut HashMap<String, Domain>,
    constraints: &[BinaryConstraint],
    assigned: &HashMap<String, f64>,
) -> bool {
    for constraint in constraints {
        let on_x = constraint.var_x == var;
        let on_y = constraint.var_y == var;
        if on_x && on_y {
            // Self-constraint: there is no neighbour to prune, so check it
            // directly. Skipping it would let forward-checking search accept
            // values that the plain consistency check rejects.
            if !constraint.relation.holds(value, value) {
                return false;
            }
            continue;
        }
        // Orient the relation so that `value` is always its first argument.
        let (other_var, rel) = if on_x {
            (constraint.var_y.as_str(), constraint.relation.clone())
        } else if on_y {
            (constraint.var_x.as_str(), constraint.relation.reversed())
        } else {
            continue;
        };
        if assigned.contains_key(other_var) {
            continue;
        }
        let other_domain = match domains.get_mut(other_var) {
            Some(d) => d,
            None => continue,
        };
        other_domain.retain(|b| rel.holds(value, b));
        if other_domain.is_empty() {
            return false;
        }
    }
    true
}
/// Depth-first search over the remaining variables, collecting complete
/// assignments into `solutions`.
///
/// When `unassigned` is empty, `assigned` is recorded as a solution. Otherwise
/// one variable is chosen with `select_variable` and each value of its current
/// domain is tried in turn. With forward checking, `forward_check` prunes the
/// neighbours. Without it, every constraint whose two variables are both
/// assigned must hold.
///
/// Before returning, `unassigned`, `assigned` and `domains` are put back in
/// their entry state, apart from the chosen variable being moved to the end
/// of `unassigned`. Search stops once `config.max_solutions` (if non-zero)
/// solutions exist.
fn backtrack(
    unassigned: &mut Vec<String>,
    domains: &mut HashMap<String, Domain>,
    assigned: &mut HashMap<String, f64>,
    constraints: &[BinaryConstraint],
    config: &CspConfig,
    solutions: &mut Vec<HashMap<String, f64>>,
    stats: &mut SolveStats,
) {
    let limit_hit = |found: usize| config.max_solutions != 0 && found >= config.max_solutions;

    if limit_hit(solutions.len()) {
        return;
    }
    if unassigned.is_empty() {
        stats.solutions_found += 1;
        solutions.push(assigned.clone());
        return;
    }

    let var = select_variable(unassigned, domains, constraints, &config.variable_ordering);
    unassigned.retain(|name| *name != var);
    let saved_domain = domains.get(&var).cloned().unwrap_or_default();
    let candidates = saved_domain.values().to_vec();

    for value in candidates {
        if limit_hit(solutions.len()) {
            break;
        }
        stats.nodes_explored += 1;
        assigned.insert(var.clone(), value);

        // Forward checking mutates `domains`; keep a copy to roll back to.
        let snapshot = config.forward_checking.then(|| domains.clone());
        let consistent = if config.forward_checking {
            forward_check(&var, value, domains, constraints, assigned)
        } else {
            constraints.iter().all(|c| {
                match (assigned.get(&c.var_x), assigned.get(&c.var_y)) {
                    (Some(&xv), Some(&yv)) => c.relation.holds(xv, yv),
                    _ => true,
                }
            })
        };

        if consistent {
            backtrack(
                unassigned,
                domains,
                assigned,
                constraints,
                config,
                solutions,
                stats,
            );
        } else {
            stats.backtrack_count += 1;
        }

        assigned.remove(&var);
        if let Some(saved) = snapshot {
            domains.extend(saved);
        }
    }

    unassigned.push(var.clone());
    domains.insert(var, saved_domain);
}
/// Solves `network` with optional arc-consistency preprocessing followed by
/// depth-first backtracking search.
///
/// The caller's network is not modified. At most `config.max_solutions`
/// solutions are returned (`0` means no limit), together with statistics about
/// the run. If propagation empties a domain, no search is done and no
/// solutions are returned.
///
/// Variables enter the search sorted by name rather than in hash-map order.
/// This makes variable-selection tie-breaks, the order of solutions, the first
/// solution under a cap, and the reported statistics reproducible across runs.
pub fn solve(
    network: &ConstraintNetwork,
    config: &CspConfig,
) -> (Vec<HashMap<String, f64>>, SolveStats) {
    // Private working copy: propagation and search prune domains in place.
    let mut working = ConstraintNetwork {
        variables: network.variables.clone(),
        constraints: network.constraints.clone(),
    };
    let propagation_result = if config.use_arc_consistency {
        propagate_arc_consistency(&mut working)
    } else {
        PropagationResult {
            consistent: true,
            iterations: 0,
            pruned: 0,
            empty_domains: Vec::new(),
        }
    };
    let mut stats = SolveStats {
        propagation_result,
        solutions_found: 0,
        backtrack_count: 0,
        nodes_explored: 0,
    };
    let mut solutions: Vec<HashMap<String, f64>> = Vec::new();
    if !stats.propagation_result.consistent {
        return (solutions, stats);
    }

    let ConstraintNetwork {
        variables: mut domains,
        constraints,
    } = working;
    let mut unassigned: Vec<String> = domains.keys().cloned().collect();
    unassigned.sort();
    let mut assigned: HashMap<String, f64> = HashMap::new();
    backtrack(
        &mut unassigned,
        &mut domains,
        &mut assigned,
        &constraints,
        config,
        &mut solutions,
        &mut stats,
    );
    (solutions, stats)
}
// Unit tests for Domain construction and set operations, relation semantics,
// arc-consistency propagation, variable-selection heuristics, and the
// backtracking solver.
#[cfg(test)]
mod tests {
use super::*;
// Builds a network with variables "x" and "y" and the single constraint
// `relation(x, y)`.
fn two_var_network(
domain_x: Domain,
domain_y: Domain,
relation: ConstraintRelation,
) -> ConstraintNetwork {
let mut net = ConstraintNetwork::new();
net.add_variable("x", domain_x);
net.add_variable("y", domain_y);
net.add_constraint(BinaryConstraint::new("x", "y", relation));
net
}
#[test]
fn test_domain_boolean_has_two_values() {
let d = Domain::boolean();
assert_eq!(d.len(), 2);
let vals = d.values();
assert!((vals[0] - 0.0).abs() < f64::EPSILON);
assert!((vals[1] - 1.0).abs() < f64::EPSILON);
}
// Inclusive integer range: 0, 1, 2, 3, 4.
#[test]
fn test_domain_from_range_basic() {
let d = Domain::from_range(0.0, 4.0, 1.0);
assert_eq!(d.len(), 5); assert!((d.values()[0] - 0.0).abs() < 1e-9);
assert!((d.values()[4] - 4.0).abs() < 1e-9);
}
// Fractional step with an exactly reachable endpoint: 0.0, 0.5, 1.0.
#[test]
fn test_domain_from_range_fractional_step() {
let d = Domain::from_range(0.0, 1.0, 0.5);
assert_eq!(d.len(), 3); }
// `retain` reports how many values it removed.
#[test]
fn test_domain_retain_removes_correct_values() {
let mut d = Domain::new(vec![1.0, 2.0, 3.0, 4.0]);
let removed = d.retain(|v| v < 3.0);
assert_eq!(removed, 2);
assert_eq!(d.len(), 2);
assert!(d.values().iter().all(|&v| v < 3.0));
}
// Removing a missing value returns false and leaves the domain unchanged.
#[test]
fn test_domain_remove_value() {
let mut d = Domain::new(vec![1.0, 2.0, 3.0]);
assert!(d.remove_value(2.0));
assert!(!d.remove_value(99.0));
assert_eq!(d.len(), 2);
}
#[test]
fn test_domain_intersect() {
let a = Domain::new(vec![1.0, 2.0, 3.0]);
let b = Domain::new(vec![2.0, 3.0, 4.0]);
let c = a.intersect(&b);
assert_eq!(c.len(), 2);
assert!(c.values().contains(&2.0));
assert!(c.values().contains(&3.0));
}
// The shared value 2.0 appears only once in the union.
#[test]
fn test_domain_union() {
let a = Domain::new(vec![1.0, 2.0]);
let b = Domain::new(vec![2.0, 3.0]);
let c = a.union(&b);
assert_eq!(c.len(), 3);
}
#[test]
fn test_domain_deduplication_on_new() {
let d = Domain::new(vec![1.0, 1.0, 2.0]);
assert_eq!(d.len(), 2);
}
// With no constraints there are no arcs, so the queue is never processed.
#[test]
fn test_empty_network_is_consistent() {
let mut net = ConstraintNetwork::new();
let result = propagate_arc_consistency(&mut net);
assert!(result.consistent);
assert_eq!(result.pruned, 0);
assert_eq!(result.iterations, 0);
}
#[test]
fn test_single_variable_no_constraints_consistent() {
let mut net = ConstraintNetwork::new();
net.add_variable("x", Domain::from_range(0.0, 5.0, 1.0));
let result = propagate_arc_consistency(&mut net);
assert!(result.consistent);
assert_eq!(result.pruned, 0);
}
// x == y over {1,2,3} and {2,3,4}: both domains shrink to {2,3}.
#[test]
fn test_equal_constraint_prunes_to_intersection() {
let mut net = two_var_network(
Domain::new(vec![1.0, 2.0, 3.0]),
Domain::new(vec![2.0, 3.0, 4.0]),
ConstraintRelation::Equal,
);
let result = propagate_arc_consistency(&mut net);
assert!(result.consistent);
let dx = net.domain("x").expect("x must exist");
let dy = net.domain("y").expect("y must exist");
assert_eq!(dx.len(), 2);
assert_eq!(dy.len(), 2);
assert!(result.pruned > 0);
}
// Every boolean value has a different partner, so nothing is pruned.
#[test]
fn test_not_equal_boolean_domain() {
let mut net = two_var_network(
Domain::boolean(),
Domain::boolean(),
ConstraintRelation::NotEqual,
);
let result = propagate_arc_consistency(&mut net);
assert!(result.consistent);
assert_eq!(net.domain("x").expect("x").len(), 2);
assert_eq!(net.domain("y").expect("y").len(), 2);
}
// x is fixed to 1, so y = 1 has no support and is removed.
#[test]
fn test_not_equal_single_value_forces_removal() {
let mut net = two_var_network(
Domain::new(vec![1.0]),
Domain::new(vec![1.0, 2.0]),
ConstraintRelation::NotEqual,
);
let result = propagate_arc_consistency(&mut net);
assert!(result.consistent);
let dy = net.domain("y").expect("y");
assert_eq!(dy.len(), 1);
assert!((dy.values()[0] - 2.0).abs() < f64::EPSILON);
}
// x < y over {1,2,3}: x cannot be 3 and y cannot be 1.
#[test]
fn test_less_than_prunes_domains() {
let mut net = two_var_network(
Domain::new(vec![1.0, 2.0, 3.0]),
Domain::new(vec![1.0, 2.0, 3.0]),
ConstraintRelation::LessThan,
);
let result = propagate_arc_consistency(&mut net);
assert!(result.consistent);
let dx = net.domain("x").expect("x");
let dy = net.domain("y").expect("y");
assert!(!dx.values().contains(&3.0));
assert!(!dy.values().contains(&1.0));
}
// Every x value has some y >= it, so x keeps all three values.
#[test]
fn test_less_or_equal_overlapping_domains() {
let mut net = two_var_network(
Domain::new(vec![1.0, 2.0, 3.0]),
Domain::new(vec![2.0, 3.0, 4.0]),
ConstraintRelation::LessOrEqual,
);
let result = propagate_arc_consistency(&mut net);
assert!(result.consistent);
let dx = net.domain("x").expect("x");
assert_eq!(dx.len(), 3);
}
// Pruning propagates along the chain x < y < z.
#[test]
fn test_chain_propagation_x_lt_y_lt_z() {
let mut net = ConstraintNetwork::new();
net.add_variable("x", Domain::new(vec![1.0, 2.0, 3.0]));
net.add_variable("y", Domain::new(vec![1.0, 2.0, 3.0]));
net.add_variable("z", Domain::new(vec![1.0, 2.0, 3.0]));
net.add_constraint(BinaryConstraint::new(
"x",
"y",
ConstraintRelation::LessThan,
));
net.add_constraint(BinaryConstraint::new(
"y",
"z",
ConstraintRelation::LessThan,
));
let result = propagate_arc_consistency(&mut net);
assert!(result.consistent);
assert!(!net.domain("x").expect("x").is_empty());
assert!(!net.domain("y").expect("y").is_empty());
assert!(!net.domain("z").expect("z").is_empty());
let dx = net.domain("x").expect("x");
assert!(!dx.values().contains(&3.0));
}
// x = 3 can never be less than y in {1,2}, so x's domain is emptied.
#[test]
fn test_inconsistency_detected_empty_domain() {
let mut net = two_var_network(
Domain::new(vec![3.0]),
Domain::new(vec![1.0, 2.0]),
ConstraintRelation::LessThan,
);
let result = propagate_arc_consistency(&mut net);
assert!(!result.consistent);
assert!(!result.empty_domains.is_empty());
}
// x loses 1 and 2; y loses 5. Three removals in total.
#[test]
fn test_propagation_result_pruned_counts_correctly() {
let mut net = two_var_network(
Domain::new(vec![1.0, 2.0, 3.0, 4.0]),
Domain::new(vec![3.0, 4.0, 5.0]),
ConstraintRelation::Equal,
);
let result = propagate_arc_consistency(&mut net);
assert!(result.consistent);
assert_eq!(result.pruned, 3);
}
// Each value is supported by itself (difference 0 <= 0.5).
#[test]
fn test_difference_constraint_within_tolerance() {
let mut net = two_var_network(
Domain::new(vec![1.0, 2.0, 3.0]),
Domain::new(vec![1.0, 2.0, 3.0]),
ConstraintRelation::Difference(0.5),
);
let result = propagate_arc_consistency(&mut net);
assert!(result.consistent);
assert_eq!(result.pruned, 0);
}
// y = 5 is 4 away from the only x value, beyond the 0.5 tolerance.
#[test]
fn test_difference_constraint_prunes_far_values() {
let mut net = two_var_network(
Domain::new(vec![1.0]),
Domain::new(vec![1.0, 5.0]),
ConstraintRelation::Difference(0.5),
);
let result = propagate_arc_consistency(&mut net);
assert!(result.consistent);
assert_eq!(net.domain("y").expect("y").len(), 1);
}
// x + y == 3: the value 3 would need a partner of 0, which is absent.
#[test]
fn test_custom_constraint_function() {
let rel = ConstraintRelation::Custom(Arc::new(|x, y| (x + y - 3.0).abs() < 1e-9));
let mut net = two_var_network(
Domain::new(vec![1.0, 2.0, 3.0]),
Domain::new(vec![1.0, 2.0, 3.0]),
rel,
);
let result = propagate_arc_consistency(&mut net);
assert!(result.consistent);
let dx = net.domain("x").expect("x");
let dy = net.domain("y").expect("y");
assert!(!dx.values().contains(&3.0));
assert!(!dy.values().contains(&3.0));
}
#[test]
fn test_solver_finds_single_solution_equality() {
let net = two_var_network(
Domain::new(vec![42.0]),
Domain::new(vec![42.0]),
ConstraintRelation::Equal,
);
let config = CspConfig::default();
let (solutions, stats) = solve(&net, &config);
assert_eq!(solutions.len(), 1);
assert_eq!(stats.solutions_found, 1);
let sol = &solutions[0];
assert!((sol["x"] - 42.0).abs() < f64::EPSILON);
assert!((sol["y"] - 42.0).abs() < f64::EPSILON);
}
// Solutions are (0, 1) and (1, 0); max_solutions = 0 means unlimited.
#[test]
fn test_solver_finds_all_solutions_not_equal_boolean() {
let net = two_var_network(
Domain::boolean(),
Domain::boolean(),
ConstraintRelation::NotEqual,
);
let config = CspConfig {
max_solutions: 0,
..CspConfig::default()
};
let (solutions, stats) = solve(&net, &config);
assert_eq!(solutions.len(), 2, "expected exactly 2 solutions");
assert_eq!(stats.solutions_found, 2);
}
#[test]
fn test_solver_max_solutions_stops_early() {
let net = two_var_network(
Domain::boolean(),
Domain::boolean(),
ConstraintRelation::NotEqual,
);
let config = CspConfig {
max_solutions: 1,
..CspConfig::default()
};
let (solutions, _stats) = solve(&net, &config);
assert_eq!(solutions.len(), 1);
}
// "x" has a single value, so MRV picks it over "y".
#[test]
fn test_mrv_ordering_selects_smallest_domain() {
let domains: HashMap<String, Domain> = [
("x".to_string(), Domain::new(vec![1.0])),
("y".to_string(), Domain::new(vec![1.0, 2.0, 3.0])),
]
.into_iter()
.collect();
let constraints: Vec<BinaryConstraint> = Vec::new();
let unassigned = vec!["x".to_string(), "y".to_string()];
let chosen = select_variable(
&unassigned,
&domains,
&constraints,
&VarOrdering::MinRemainingValues,
);
assert_eq!(chosen, "x");
}
// "x" shares constraints with both "y" and "z" (degree 2); the others have degree 1.
#[test]
fn test_degree_heuristic_selects_most_constrained() {
let domains: HashMap<String, Domain> = [
("x".to_string(), Domain::new(vec![1.0, 2.0])),
("y".to_string(), Domain::new(vec![1.0, 2.0])),
("z".to_string(), Domain::new(vec![1.0, 2.0])),
]
.into_iter()
.collect();
let constraints = vec![
BinaryConstraint::new("x", "y", ConstraintRelation::NotEqual),
BinaryConstraint::new("x", "z", ConstraintRelation::NotEqual),
];
let unassigned = vec!["x".to_string(), "y".to_string(), "z".to_string()];
let chosen = select_variable(
&unassigned,
&domains,
&constraints,
&VarOrdering::DegreeHeuristic,
);
assert_eq!(chosen, "x");
}
// With propagation and forward checking disabled, y values that violate
// x + y == 3 are only rejected during search.
#[test]
fn test_solver_backtrack_count_nonzero_for_conflicted_search() {
let rel = ConstraintRelation::Custom(Arc::new(|x, y| (x + y - 3.0).abs() < 1e-9));
let net = two_var_network(
Domain::new(vec![1.0, 2.0, 3.0]),
Domain::new(vec![1.0, 2.0, 3.0]),
rel,
);
let config = CspConfig {
use_arc_consistency: false,
forward_checking: false,
max_solutions: 0,
variable_ordering: VarOrdering::Lexicographic,
};
let (solutions, stats) = solve(&net, &config);
assert!(!solutions.is_empty());
assert!(stats.backtrack_count > 0 || stats.nodes_explored > solutions.len());
}
#[test]
fn test_solve_stats_nodes_explored() {
let net = two_var_network(
Domain::new(vec![1.0, 2.0]),
Domain::new(vec![1.0, 2.0]),
ConstraintRelation::LessThan,
);
let config = CspConfig::default();
let (_solutions, stats) = solve(&net, &config);
assert!(stats.nodes_explored > 0);
}
// `reversed()` swaps argument order; symmetric relations map to themselves.
#[test]
fn test_constraint_relation_reversed() {
let rel = ConstraintRelation::LessThan;
let rev = rel.reversed();
assert!(rel.holds(1.0, 3.0));
assert!(!rel.holds(3.0, 1.0));
assert!(rev.holds(3.0, 1.0));
assert!(!rev.holds(1.0, 3.0));
let loe = ConstraintRelation::LessOrEqual;
let goe = loe.reversed();
assert!(goe.holds(3.0, 1.0)); assert!(goe.holds(2.0, 2.0)); assert!(!goe.holds(1.0, 3.0));
let eq = ConstraintRelation::Equal;
let eq_rev = eq.reversed();
assert!(eq_rev.holds(5.0, 5.0));
assert!(!eq_rev.holds(5.0, 6.0));
let ne = ConstraintRelation::NotEqual;
let ne_rev = ne.reversed();
assert!(ne_rev.holds(1.0, 2.0));
assert!(!ne_rev.holds(1.0, 1.0));
}
#[test]
fn test_constraint_network_counts() {
let mut net = ConstraintNetwork::new();
net.add_variable("a", Domain::boolean());
net.add_variable("b", Domain::boolean());
net.add_constraint(BinaryConstraint::new(
"a",
"b",
ConstraintRelation::NotEqual,
));
assert_eq!(net.variable_count(), 2);
assert_eq!(net.constraint_count(), 1);
}
// Lexicographic ordering ignores domain sizes: "apple" < "zebra".
#[test]
fn test_lexicographic_ordering() {
let domains: HashMap<String, Domain> = [
("zebra".to_string(), Domain::new(vec![1.0])),
("apple".to_string(), Domain::new(vec![1.0, 2.0, 3.0])),
]
.into_iter()
.collect();
let constraints: Vec<BinaryConstraint> = Vec::new();
let unassigned = vec!["zebra".to_string(), "apple".to_string()];
let chosen = select_variable(
&unassigned,
&domains,
&constraints,
&VarOrdering::Lexicographic,
);
assert_eq!(chosen, "apple");
}
}