Skip to main content

oxiz_sat/
cardinality.rs

1//! Cardinality constraint encoding
2//!
3//! This module implements efficient encoding of cardinality constraints into CNF.
4//! Cardinality constraints express conditions like:
5//! - At-most-k: at most k of the given literals can be true
6//! - At-least-k: at least k of the given literals must be true
7//! - Exactly-k: exactly k of the given literals must be true
8//!
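//! Constraints over more than four literals use the Totalizer encoding, which provides:
//! - Good propagation (a unary counter over the literals)
//! - A reasonable clause count
//!
//! Constraints over at most four literals are encoded directly, by forbidding every
//! subset of `k + 1` literals. (The totalizer's counter outputs could support incremental
//! strengthening, but this module does not currently expose them.)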
13
14use crate::literal::{Lit, Var};
15#[allow(unused_imports)]
16use crate::prelude::*;
17use crate::solver::Solver;
18use smallvec::SmallVec;
19
/// Cardinality constraint encoder.
///
/// A stateless namespace: all functionality is provided by associated functions that add
/// clauses to a [`Solver`].
#[derive(Debug, Clone, Copy, Default)]
pub struct CardinalityEncoder;
22
23impl CardinalityEncoder {
24    /// Encode an at-most-k constraint: sum(lits) <= k
25    ///
26    /// # Arguments
27    ///
28    /// * `solver` - The SAT solver
29    /// * `lits` - The literals in the constraint
30    /// * `k` - The upper bound
31    ///
32    /// Returns true if the constraint was successfully encoded
33    pub fn encode_at_most_k(solver: &mut Solver, lits: &[Lit], k: usize) -> bool {
34        if k >= lits.len() {
35            return true; // Constraint is trivially satisfied
36        }
37
38        if k == 0 {
39            // None of the literals can be true
40            for &lit in lits {
41                solver.add_clause([lit.negate()]);
42            }
43            return true;
44        }
45
46        if lits.len() <= 4 {
47            // For small constraints, use direct encoding
48            Self::encode_at_most_k_direct(solver, lits, k)
49        } else {
50            // For larger constraints, use totalizer encoding
51            Self::encode_at_most_k_totalizer(solver, lits, k)
52        }
53    }
54
55    /// Encode an at-least-k constraint: sum(lits) >= k
56    ///
57    /// Equivalent to: at-most-(n-k) of the negations
58    pub fn encode_at_least_k(solver: &mut Solver, lits: &[Lit], k: usize) -> bool {
59        if k == 0 {
60            return true; // Trivially satisfied
61        }
62
63        if k > lits.len() {
64            return false; // Unsatisfiable
65        }
66
67        if k == 1 {
68            // At least one must be true - simple clause
69            solver.add_clause(lits.iter().copied());
70            return true;
71        }
72
73        // Transform to at-most constraint on negations
74        let negated: Vec<Lit> = lits.iter().map(|&l| l.negate()).collect();
75        Self::encode_at_most_k(solver, &negated, lits.len() - k)
76    }
77
78    /// Encode an exactly-k constraint: sum(lits) == k
79    pub fn encode_exactly_k(solver: &mut Solver, lits: &[Lit], k: usize) -> bool {
80        if k > lits.len() {
81            return false;
82        }
83
84        // Combine at-most-k and at-least-k
85        Self::encode_at_most_k(solver, lits, k) && Self::encode_at_least_k(solver, lits, k)
86    }
87
88    /// Direct encoding for small at-most-k constraints
89    fn encode_at_most_k_direct(solver: &mut Solver, lits: &[Lit], k: usize) -> bool {
90        // Generate all subsets of size k+1 and forbid them
91        let n = lits.len();
92        if k >= n {
93            return true;
94        }
95
96        // Generate all combinations of k+1 literals
97        Self::generate_combinations(lits, k + 1, &mut |combo| {
98            // Add clause: at least one of these must be false
99            let negated: SmallVec<[Lit; 8]> = combo.iter().map(|&&l| l.negate()).collect();
100            solver.add_clause(negated.iter().copied());
101        });
102
103        true
104    }
105
106    /// Helper function to generate all k-combinations
107    fn generate_combinations<F>(lits: &[Lit], k: usize, callback: &mut F)
108    where
109        F: FnMut(&[&Lit]),
110    {
111        let mut indices = vec![0; k];
112        let n = lits.len();
113
114        if k > n {
115            return;
116        }
117
118        // Initialize first combination
119        for (i, item) in indices.iter_mut().enumerate().take(k) {
120            *item = i;
121        }
122
123        loop {
124            // Call callback with current combination
125            let combo: Vec<&Lit> = indices.iter().map(|&i| &lits[i]).collect();
126            callback(&combo);
127
128            // Find the rightmost index that can be incremented
129            let mut i = k;
130            loop {
131                if i == 0 {
132                    return; // No more combinations
133                }
134                i -= 1;
135                if indices[i] < n - k + i {
136                    break;
137                }
138            }
139
140            // Increment this index and reset all following indices
141            indices[i] += 1;
142            for j in (i + 1)..k {
143                indices[j] = indices[j - 1] + 1;
144            }
145        }
146    }
147
148    /// Totalizer encoding for at-most-k constraints
149    ///
150    /// The totalizer builds a tree of adder circuits that count the number
151    /// of true literals. It introduces auxiliary variables representing
152    /// "at least i literals are true" for various i.
153    fn encode_at_most_k_totalizer(solver: &mut Solver, lits: &[Lit], k: usize) -> bool {
154        if lits.is_empty() || k >= lits.len() {
155            return true;
156        }
157
158        // Build totalizer tree
159        let root_vars = Self::build_totalizer_tree(solver, lits, k);
160
161        // Add constraint: at most k can be true
162        // This means the (k+1)-th totalizer variable must be false
163        if k < root_vars.len() {
164            solver.add_clause([Lit::neg(root_vars[k])]);
165        }
166
167        true
168    }
169
170    /// Build the totalizer tree and return the root counting variables
171    ///
172    /// Returns a vector where output[i] represents "at least i+1 literals are true"
173    fn build_totalizer_tree(solver: &mut Solver, lits: &[Lit], bound: usize) -> Vec<Var> {
174        if lits.len() == 1 {
175            // Leaf node: return the variable of the literal
176            return vec![lits[0].var()];
177        }
178
179        // Split the literals into two halves
180        let mid = lits.len() / 2;
181        let left_lits = &lits[..mid];
182        let right_lits = &lits[mid..];
183
184        // Recursively build subtrees
185        let left_vars = Self::build_totalizer_tree(solver, left_lits, bound);
186        let right_vars = Self::build_totalizer_tree(solver, right_lits, bound);
187
188        // Merge the two subtrees
189        let max_count = (left_vars.len() + right_vars.len()).min(bound + 1);
190        let mut output = Vec::with_capacity(max_count);
191
192        for _ in 0..max_count {
193            output.push(solver.new_var());
194        }
195
196        // Add clauses for the totalizer merge
197        Self::add_totalizer_clauses(solver, &left_vars, &right_vars, &output);
198
199        output
200    }
201
202    /// Add clauses for merging two totalizer trees
203    ///
204    /// Implements the totalizer merge operation:
205    /// output[i] is true iff at least i+1 of the input literals are true
206    fn add_totalizer_clauses(solver: &mut Solver, left: &[Var], right: &[Var], output: &[Var]) {
207        // For each output position i (representing "at least i+1 are true")
208        for (i, &out_var) in output.iter().enumerate() {
209            let count = i + 1; // Number of true literals needed
210
211            // If left has >= j and right has >= k where j+k >= count, then output[i] is true
212            for j in 0..=left.len() {
213                for k in 0..=right.len() {
214                    if j + k >= count && j + k > 0 {
215                        let mut clause = SmallVec::<[Lit; 4]>::new();
216
217                        // If left[j-1] is true and right[k-1] is true, then output[i] is true
218                        if j > 0 && j <= left.len() {
219                            clause.push(Lit::neg(left[j - 1]));
220                        }
221                        if k > 0 && k <= right.len() {
222                            clause.push(Lit::neg(right[k - 1]));
223                        }
224
225                        if !clause.is_empty() {
226                            clause.push(Lit::pos(out_var));
227                            solver.add_clause(clause.iter().copied());
228                        }
229                    }
230                }
231            }
232
233            // Reverse direction: if output[i] is true, then sufficient input must be true
234            // output[i] => (left[j] or right[k]) for all valid j, k where j+k+2 == count+1
235            // This is captured by the contrapositive of the above clauses
236        }
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::SolverResult;

    /// Allocate `n` fresh variables in `solver` and return their positive literals.
    fn fresh_lits(solver: &mut Solver, n: usize) -> Vec<Lit> {
        (0..n).map(|_| Lit::pos(solver.new_var())).collect()
    }

    /// Fix every literal in `lits` to `value` by adding unit clauses.
    fn force(solver: &mut Solver, lits: &[Lit], value: bool) {
        for &lit in lits {
            let unit = if value { lit } else { lit.negate() };
            solver.add_clause([unit]);
        }
    }

    #[test]
    fn test_at_most_0() {
        let mut solver = Solver::new();
        let vars: Vec<Var> = (0..3).map(|_| solver.new_var()).collect();
        let lits: Vec<Lit> = vars.iter().map(|&v| Lit::pos(v)).collect();

        CardinalityEncoder::encode_at_most_k(&mut solver, &lits, 0);

        let result = solver.solve();
        // At most 0 means all must be false, which is satisfiable
        assert_eq!(result, SolverResult::Sat);
    }

    #[test]
    fn test_at_most_0_rejects_any_true() {
        let mut solver = Solver::new();
        let lits = fresh_lits(&mut solver, 3);

        CardinalityEncoder::encode_at_most_k(&mut solver, &lits, 0);
        force(&mut solver, &lits[..1], true);

        assert_ne!(solver.solve(), SolverResult::Sat);
    }

    #[test]
    fn test_at_most_1() {
        let mut solver = Solver::new();
        let vars: Vec<Var> = (0..3).map(|_| solver.new_var()).collect();
        let lits: Vec<Lit> = vars.iter().map(|&v| Lit::pos(v)).collect();

        CardinalityEncoder::encode_at_most_k(&mut solver, &lits, 1);

        let result = solver.solve();
        assert_eq!(result, SolverResult::Sat);
    }

    #[test]
    fn test_at_most_1_rejects_two_true() {
        let mut solver = Solver::new();
        let lits = fresh_lits(&mut solver, 3);

        CardinalityEncoder::encode_at_most_k(&mut solver, &lits, 1);
        force(&mut solver, &lits[..2], true);

        assert_ne!(solver.solve(), SolverResult::Sat);
    }

    #[test]
    fn test_at_least_1() {
        let mut solver = Solver::new();
        let vars: Vec<Var> = (0..3).map(|_| solver.new_var()).collect();
        let lits: Vec<Lit> = vars.iter().map(|&v| Lit::pos(v)).collect();

        CardinalityEncoder::encode_at_least_k(&mut solver, &lits, 1);

        let result = solver.solve();
        assert_eq!(result, SolverResult::Sat);
    }

    #[test]
    fn test_exactly_2() {
        let mut solver = Solver::new();
        let vars: Vec<Var> = (0..3).map(|_| solver.new_var()).collect();
        let lits: Vec<Lit> = vars.iter().map(|&v| Lit::pos(v)).collect();

        CardinalityEncoder::encode_exactly_k(&mut solver, &lits, 2);

        let result = solver.solve();
        assert_eq!(result, SolverResult::Sat);
    }

    #[test]
    fn test_at_most_exceeds_length() {
        let mut solver = Solver::new();
        let vars: Vec<Var> = (0..3).map(|_| solver.new_var()).collect();
        let lits: Vec<Lit> = vars.iter().map(|&v| Lit::pos(v)).collect();

        let success = CardinalityEncoder::encode_at_most_k(&mut solver, &lits, 5);
        assert!(success);

        let result = solver.solve();
        assert_eq!(result, SolverResult::Sat);
    }

    #[test]
    fn test_at_least_exceeds_length_reports_failure() {
        let mut solver = Solver::new();
        let lits = fresh_lits(&mut solver, 3);

        assert!(!CardinalityEncoder::encode_at_least_k(&mut solver, &lits, 4));
        assert!(!CardinalityEncoder::encode_exactly_k(&mut solver, &lits, 4));
    }

    #[test]
    fn test_at_most_k_totalizer_enforces_bound() {
        // Six literals exceed the direct-encoding threshold, so the totalizer is used.
        let mut at_bound = Solver::new();
        let lits = fresh_lits(&mut at_bound, 6);
        CardinalityEncoder::encode_at_most_k(&mut at_bound, &lits, 2);
        force(&mut at_bound, &lits[..2], true);
        assert_eq!(at_bound.solve(), SolverResult::Sat);

        let mut over_bound = Solver::new();
        let lits = fresh_lits(&mut over_bound, 6);
        CardinalityEncoder::encode_at_most_k(&mut over_bound, &lits, 2);
        force(&mut over_bound, &lits[..3], true);
        assert_ne!(over_bound.solve(), SolverResult::Sat);
    }

    #[test]
    fn test_at_least_k_totalizer_respects_polarity() {
        // At-least-5 of 6 is reduced to at-most-1 over the *negated* literals, which
        // exercises the totalizer on negative literals.
        let mut all_true = Solver::new();
        let lits = fresh_lits(&mut all_true, 6);
        assert!(CardinalityEncoder::encode_at_least_k(&mut all_true, &lits, 5));
        force(&mut all_true, &lits, true);
        assert_eq!(all_true.solve(), SolverResult::Sat);

        let mut two_false = Solver::new();
        let lits = fresh_lits(&mut two_false, 6);
        CardinalityEncoder::encode_at_least_k(&mut two_false, &lits, 5);
        force(&mut two_false, &lits[..2], false);
        assert_ne!(two_false.solve(), SolverResult::Sat);
    }

    #[test]
    fn test_exactly_k_totalizer() {
        let mut exact = Solver::new();
        let lits = fresh_lits(&mut exact, 6);
        CardinalityEncoder::encode_exactly_k(&mut exact, &lits, 3);
        force(&mut exact, &lits[..3], true);
        force(&mut exact, &lits[3..], false);
        assert_eq!(exact.solve(), SolverResult::Sat);

        let mut too_many = Solver::new();
        let lits = fresh_lits(&mut too_many, 6);
        CardinalityEncoder::encode_exactly_k(&mut too_many, &lits, 3);
        force(&mut too_many, &lits[..4], true);
        assert_ne!(too_many.solve(), SolverResult::Sat);

        let mut too_few = Solver::new();
        let lits = fresh_lits(&mut too_few, 6);
        CardinalityEncoder::encode_exactly_k(&mut too_few, &lits, 3);
        force(&mut too_few, &lits[..4], false);
        assert_ne!(too_few.solve(), SolverResult::Sat);
    }
}