use std::collections::HashMap;
use crate::constraints::gac_hybrid::Variable;
use crate::variables::domain::bitset_domain::BitSetDomain;
/// Returns every `k`-element combination of `items`, preserving the relative
/// order of elements inside each combination.
///
/// Combinations are produced in lexicographic order of their element indices
/// (e.g. for `[a, b, c]` and `k = 2`: `[a, b]`, `[a, c]`, `[b, c]`).
/// `k == 0` yields a single empty combination; `k > items.len()` yields none.
fn combinations<T: Clone>(items: &[T], k: usize) -> Vec<Vec<T>> {
    let n = items.len();
    if k == 0 {
        return vec![vec![]];
    }
    if n < k {
        return Vec::new();
    }

    // `picks` holds the indices of the current combination, strictly increasing.
    let mut picks: Vec<usize> = (0..k).collect();
    let mut out: Vec<Vec<T>> = Vec::new();
    loop {
        out.push(picks.iter().map(|&i| items[i].clone()).collect());

        // The rightmost slot that has not yet reached its maximum index
        // (slot `p` can go up to `p + n - k`) is the one to advance.
        let pos = match (0..k).rev().find(|&p| picks[p] < p + n - k) {
            Some(p) => p,
            None => return out,
        };
        picks[pos] += 1;

        // Reset every slot to the right to the smallest increasing run.
        let base = picks[pos];
        for (step, slot) in picks[pos + 1..].iter_mut().enumerate() {
            *slot = base + step + 1;
        }
    }
}
/// Store of per-variable integer domains, each held as a [`BitSetDomain`],
/// with helpers to narrow domains and to run all-different propagation.
pub struct BitSetGAC {
    /// Current domain of every registered variable. Because this field is
    /// public, callers can edit it directly; such edits do not touch the
    /// change flag below.
    pub domains: HashMap<Variable, BitSetDomain>,
    /// Set by variable registration and by the domain-narrowing helper methods
    /// when they report a change; cleared by `reset_changed_flag`.
    domains_changed: bool,
}
impl Default for BitSetGAC {
fn default() -> Self {
Self::new()
}
}
impl BitSetGAC {
pub fn new() -> Self {
Self {
domains: HashMap::with_capacity(32),
domains_changed: false,
}
}
pub fn add_variable(&mut self, var: Variable, min_val: i32, max_val: i32) {
let domain = BitSetDomain::new(min_val, max_val);
self.domains.insert(var, domain);
self.domains_changed = true;
}
pub fn add_variable_with_values(&mut self, var: Variable, values: Vec<i32>) {
let domain = BitSetDomain::new_from_values(values);
self.domains.insert(var, domain);
self.domains_changed = true;
}
pub fn remove_value(&mut self, var: Variable, val: i32) -> bool {
if let Some(domain) = self.domains.get_mut(&var) {
let changed = domain.remove(val);
if changed {
self.domains_changed = true;
}
changed
} else {
false
}
}
pub fn assign_variable(&mut self, var: Variable, val: i32) -> bool {
if let Some(domain) = self.domains.get_mut(&var) {
let old_size = domain.size();
domain.remove_all_but(val);
let changed = domain.size() != old_size;
if changed {
self.domains_changed = true;
}
changed
} else {
false
}
}
pub fn remove_above(&mut self, var: Variable, threshold: i32) -> bool {
if let Some(domain) = self.domains.get_mut(&var) {
let old_size = domain.size();
domain.remove_above(threshold);
let changed = domain.size() != old_size;
if changed {
self.domains_changed = true;
}
changed
} else {
false
}
}
pub fn remove_below(&mut self, var: Variable, threshold: i32) -> bool {
if let Some(domain) = self.domains.get_mut(&var) {
let old_size = domain.size();
domain.remove_below(threshold);
let changed = domain.size() != old_size;
if changed {
self.domains_changed = true;
}
changed
} else {
false
}
}
pub fn get_domain_values(&self, var: Variable) -> Vec<i32> {
if let Some(domain) = self.domains.get(&var) {
domain.to_vec()
} else {
Vec::new()
}
}
pub fn domain_size(&self, var: Variable) -> usize {
self.domains.get(&var).map_or(0, |d| d.size())
}
pub fn is_assigned(&self, var: Variable) -> bool {
self.domains.get(&var).map_or(false, |d| d.is_fixed())
}
pub fn assigned_value(&self, var: Variable) -> Option<i32> {
let domain = self.domains.get(&var)?;
if domain.is_fixed() {
domain.fixed_value()
} else {
None
}
}
pub fn is_inconsistent(&self, var: Variable) -> bool {
self.domains.get(&var).map_or(true, |d| d.is_empty())
}
pub fn get_bounds(&self, var: Variable) -> Option<(i32, i32)> {
if let Some(domain) = self.domains.get(&var) {
if let (Some(min), Some(max)) = (domain.min(), domain.max()) {
Some((min, max))
} else {
None
}
} else {
None
}
}
pub fn propagate_alldiff(&mut self, variables: &[Variable]) -> (bool, bool) {
if variables.len() <= 1 {
return (false, true); }
let mut changed = false;
let assigned_values: Vec<(Variable, i32)> = variables
.iter()
.filter_map(|&var| {
let domain = self.domains.get(&var)?;
if domain.is_fixed() {
Some((var, domain.fixed_value()?))
} else {
None
}
})
.collect();
for (assigned_var, assigned_val) in assigned_values {
for &var in variables {
if var != assigned_var {
if self.remove_value(var, assigned_val) {
changed = true;
if self.is_inconsistent(var) {
return (changed, false); }
}
}
}
}
let (hall_changed, hall_consistent) = self.propagate_hall_sets(variables);
if !hall_consistent {
return (changed, false); }
changed |= hall_changed;
(changed, true)
}
fn propagate_hall_sets(&mut self, variables: &[Variable]) -> (bool, bool) {
let mut changed = false;
if variables.len() <= 6 { for subset_size in 2..=variables.len().min(4) { let indices: Vec<usize> = (0..variables.len()).collect();
for subset_indices in combinations(&indices, subset_size) {
let subset: Vec<Variable> = subset_indices.iter()
.map(|&i| variables[i])
.collect();
let mut union_values = std::collections::HashSet::new();
for &var in &subset {
if let Some(domain) = self.domains.get(&var) {
let var_values: Vec<i32> = domain.into_iter().collect();
for val in &var_values {
union_values.insert(*val);
}
}
}
let union_size = union_values.len();
if subset.len() == union_size {
for &var in variables {
if !subset.contains(&var) {
if let Some(domain) = self.domains.get_mut(&var) {
let mut removed_any = false;
for &value in &union_values {
if domain.remove(value) {
removed_any = true;
}
}
if removed_any {
changed = true;
if domain.is_empty() {
return (changed, false); }
}
}
}
}
}
}
}
}
(changed, true)
}
pub fn variables(&self) -> impl Iterator<Item = Variable> + '_ {
self.domains.keys().copied()
}
pub fn domains_changed(&self) -> bool {
self.domains_changed
}
pub fn reset_changed_flag(&mut self) {
self.domains_changed = false;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_combinations() {
        assert_eq!(
            combinations(&[1, 2, 3], 2),
            vec![vec![1, 2], vec![1, 3], vec![2, 3]]
        );
        assert_eq!(combinations(&[1, 2, 3], 0), vec![Vec::<i32>::new()]);
        assert!(combinations(&[1, 2], 3).is_empty());
        assert_eq!(combinations(&[1, 2, 3, 4, 5, 6], 4).len(), 15);
    }

    #[test]
    fn test_bitset_domain_basic() {
        let domain = BitSetDomain::new(1, 5);
        assert_eq!(domain.size(), 5);
        assert!(domain.contains(1));
        assert!(domain.contains(5));
        assert!(!domain.contains(0));
        assert!(!domain.contains(6));
    }

    #[test]
    fn test_bitset_domain_operations() {
        let mut domain = BitSetDomain::new(1, 10);
        assert!(domain.remove(5));
        assert!(!domain.contains(5));
        assert_eq!(domain.size(), 9);
        assert_eq!(domain.min(), Some(1));
        assert_eq!(domain.max(), Some(10));
        domain.remove_above(7);
        assert!(!domain.contains(8));
        assert!(!domain.contains(9));
        assert!(!domain.contains(10));
        assert!(domain.contains(7));
        domain.remove_below(3);
        assert!(!domain.contains(1));
        assert!(!domain.contains(2));
        assert!(domain.contains(3));
    }

    #[test]
    fn test_bitset_gac_basic() {
        let mut gac = BitSetGAC::new();
        gac.add_variable(Variable(0), 1, 3);
        gac.add_variable(Variable(1), 1, 3);
        gac.add_variable(Variable(2), 1, 3);
        gac.assign_variable(Variable(0), 1);
        let (changed, consistent) =
            gac.propagate_alldiff(&[Variable(0), Variable(1), Variable(2)]);
        assert!(consistent);
        assert!(changed);
        assert!(!gac.domains.get(&Variable(1)).unwrap().contains(1));
        assert!(!gac.domains.get(&Variable(2)).unwrap().contains(1));
    }

    #[test]
    fn test_bitset_gac_change_tracking() {
        let mut gac = BitSetGAC::new();
        assert!(!gac.domains_changed());
        gac.add_variable(Variable(0), 1, 3);
        assert!(gac.domains_changed());

        gac.reset_changed_flag();
        assert!(gac.assign_variable(Variable(0), 2));
        assert!(gac.domains_changed());
        assert_eq!(gac.assigned_value(Variable(0)), Some(2));

        // Re-assigning the same value does not shrink the domain.
        gac.reset_changed_flag();
        assert!(!gac.assign_variable(Variable(0), 2));
        assert!(!gac.domains_changed());
    }

    #[test]
    fn test_bitset_gac_unknown_variable() {
        let mut gac = BitSetGAC::new();
        assert!(!gac.remove_value(Variable(7), 1));
        assert!(gac.is_inconsistent(Variable(7)));
        assert_eq!(gac.domain_size(Variable(7)), 0);
        assert_eq!(gac.get_bounds(Variable(7)), None);
        assert!(gac.get_domain_values(Variable(7)).is_empty());
        assert!(!gac.domains_changed());
    }

    #[test]
    fn test_bitset_intersection_union() {
        let domain1 = BitSetDomain::new_from_values(vec![1, 2, 3, 4]);
        let domain2 = BitSetDomain::new_from_values(vec![3, 4, 5, 6]);
        assert!(domain1.union_mask(&domain2).is_none());

        let domain1 = BitSetDomain::new(1, 6);
        let mut domain2 = BitSetDomain::new(1, 6);
        domain2.remove(1);
        domain2.remove(2);
        let union_mask = domain1.union_mask(&domain2).unwrap();
        let expected_size = (domain1.get_mask() | domain2.get_mask()).count_ones() as usize;
        assert_eq!(union_mask.count_ones() as usize, expected_size);

        let mut intersection = domain1.clone();
        let _ = intersection.intersect_with(&domain2);
        assert_eq!(intersection.size(), 4);
        assert!(!intersection.contains(1));
        assert!(!intersection.contains(2));
        assert!(intersection.contains(3));
    }
}