Skip to main content

csp_solver/constraint/
cardinality.rs

//! Cardinality constraint: at-least / at-most / exactly N variables
//! take a particular value.
//!
//! `CardinalityConstraint { scope, value, lo, hi }` enforces that the number
//! of variables in `scope` taking exactly `value` lies in `[lo, hi]`.
//! Used by recognizer-tier CSPs to enforce hoisting thresholds:
//! "at least N peer sites must commit to `Hoist` for the shared helper
//! to be eligible."
8//! to be eligible."
9
10use crate::domain::Domain;
11use crate::variable::Variable;
12
13use super::traits::{Constraint, Revision, VarId};
14
15#[derive(Debug, Clone)]
16pub struct CardinalityConstraint<V: Clone + PartialEq + std::fmt::Debug> {
17    pub(crate) scope: Vec<VarId>,
18    pub(crate) value: V,
19    pub(crate) lo: usize,
20    pub(crate) hi: usize,
21}
22
23impl<V: Clone + PartialEq + std::fmt::Debug> CardinalityConstraint<V> {
24    pub fn new(vars: Vec<VarId>, value: V, lo: usize, hi: usize) -> Self {
25        debug_assert!(lo <= hi);
26        Self {
27            scope: vars,
28            value,
29            lo,
30            hi,
31        }
32    }
33
34    /// Convenience: at-least-`lo`.
35    pub fn at_least(vars: Vec<VarId>, value: V, lo: usize) -> Self {
36        let hi = vars.len();
37        Self::new(vars, value, lo, hi)
38    }
39
40    /// Convenience: at-most-`hi`.
41    pub fn at_most(vars: Vec<VarId>, value: V, hi: usize) -> Self {
42        Self::new(vars, value, 0, hi)
43    }
44}
45
46impl<D: Domain> Constraint<D> for CardinalityConstraint<D::Value>
47where
48    D::Value: PartialEq,
49{
50    fn scope(&self) -> &[VarId] {
51        &self.scope
52    }
53
54    fn check(&self, assignment: &[Option<D::Value>]) -> bool {
55        let mut taken = 0usize;
56        let mut unbound = 0usize;
57        for &v in &self.scope {
58            match &assignment[v as usize] {
59                Some(val) if *val == self.value => taken += 1,
60                Some(_) => {}
61                None => unbound += 1,
62            }
63        }
64        // The full constraint can still be satisfied as long as
65        // [taken, taken + unbound] intersects [lo, hi].
66        let max = taken + unbound;
67        max >= self.lo && taken <= self.hi
68    }
69
70    fn revise(&self, vars: &mut [Variable<D>], depth: usize) -> Revision {
71        // Count how many vars are forced to take `value`, how many are
72        // forced not to, and how many are still flexible.
73        let mut forced_yes = 0usize;
74        let mut flexible: Vec<VarId> = Vec::new();
75
76        for &v in &self.scope {
77            let dom = &vars[v as usize].domain;
78            let contains_value = dom.contains(&self.value);
79            let only_value = dom
80                .singleton_value()
81                .map(|sv| sv == self.value)
82                .unwrap_or(false);
83            if only_value {
84                forced_yes += 1;
85            } else if contains_value {
86                flexible.push(v);
87            }
88            // else: domain doesn't contain `value` at all — neutral.
89        }
90
91        // Total possible takers = forced_yes + flexible.len()
92        let max_takers = forced_yes + flexible.len();
93        // Minimum forced takers
94        let min_takers = forced_yes;
95
96        if max_takers < self.lo || min_takers > self.hi {
97            return Revision::Unsatisfiable;
98        }
99
100        let mut changed = false;
101
102        // If forced_yes already at upper bound, prune `value` from
103        // every flexible variable.
104        if forced_yes == self.hi {
105            for v in &flexible {
106                if vars[*v as usize].prune(&self.value, depth) {
107                    changed = true;
108                }
109                if vars[*v as usize].domain.is_empty() {
110                    return Revision::Unsatisfiable;
111                }
112            }
113        }
114
115        // If we need more takers than the flexible set can provide,
116        // every flexible variable must be `value`. Restrict each
117        // flexible domain to the singleton.
118        if min_takers + flexible.len() == self.lo && self.lo > min_takers {
119            for v in &flexible {
120                vars[*v as usize].restrict_to(&self.value, depth);
121                if vars[*v as usize].domain.is_empty() {
122                    return Revision::Unsatisfiable;
123                }
124                changed = true;
125            }
126        }
127
128        if changed {
129            Revision::Changed
130        } else {
131            Revision::Unchanged
132        }
133    }
134}