1use crate::literal::{Lit, Var};
15#[allow(unused_imports)]
16use crate::prelude::*;
17use crate::solver::Solver;
18use smallvec::SmallVec;
19
/// Stateless namespace (a unit struct with no fields) for CNF encodings of
/// cardinality constraints — at-most-k, at-least-k and exactly-k — over a
/// slice of literals. The encodings are added as clauses to a [`Solver`].
pub struct CardinalityEncoder;
22
23impl CardinalityEncoder {
24 pub fn encode_at_most_k(solver: &mut Solver, lits: &[Lit], k: usize) -> bool {
34 if k >= lits.len() {
35 return true; }
37
38 if k == 0 {
39 for &lit in lits {
41 solver.add_clause([lit.negate()]);
42 }
43 return true;
44 }
45
46 if lits.len() <= 4 {
47 Self::encode_at_most_k_direct(solver, lits, k)
49 } else {
50 Self::encode_at_most_k_totalizer(solver, lits, k)
52 }
53 }
54
55 pub fn encode_at_least_k(solver: &mut Solver, lits: &[Lit], k: usize) -> bool {
59 if k == 0 {
60 return true; }
62
63 if k > lits.len() {
64 return false; }
66
67 if k == 1 {
68 solver.add_clause(lits.iter().copied());
70 return true;
71 }
72
73 let negated: Vec<Lit> = lits.iter().map(|&l| l.negate()).collect();
75 Self::encode_at_most_k(solver, &negated, lits.len() - k)
76 }
77
78 pub fn encode_exactly_k(solver: &mut Solver, lits: &[Lit], k: usize) -> bool {
80 if k > lits.len() {
81 return false;
82 }
83
84 Self::encode_at_most_k(solver, lits, k) && Self::encode_at_least_k(solver, lits, k)
86 }
87
88 fn encode_at_most_k_direct(solver: &mut Solver, lits: &[Lit], k: usize) -> bool {
90 let n = lits.len();
92 if k >= n {
93 return true;
94 }
95
96 Self::generate_combinations(lits, k + 1, &mut |combo| {
98 let negated: SmallVec<[Lit; 8]> = combo.iter().map(|&&l| l.negate()).collect();
100 solver.add_clause(negated.iter().copied());
101 });
102
103 true
104 }
105
106 fn generate_combinations<F>(lits: &[Lit], k: usize, callback: &mut F)
108 where
109 F: FnMut(&[&Lit]),
110 {
111 let mut indices = vec![0; k];
112 let n = lits.len();
113
114 if k > n {
115 return;
116 }
117
118 for (i, item) in indices.iter_mut().enumerate().take(k) {
120 *item = i;
121 }
122
123 loop {
124 let combo: Vec<&Lit> = indices.iter().map(|&i| &lits[i]).collect();
126 callback(&combo);
127
128 let mut i = k;
130 loop {
131 if i == 0 {
132 return; }
134 i -= 1;
135 if indices[i] < n - k + i {
136 break;
137 }
138 }
139
140 indices[i] += 1;
142 for j in (i + 1)..k {
143 indices[j] = indices[j - 1] + 1;
144 }
145 }
146 }
147
148 fn encode_at_most_k_totalizer(solver: &mut Solver, lits: &[Lit], k: usize) -> bool {
154 if lits.is_empty() || k >= lits.len() {
155 return true;
156 }
157
158 let root_vars = Self::build_totalizer_tree(solver, lits, k);
160
161 if k < root_vars.len() {
164 solver.add_clause([Lit::neg(root_vars[k])]);
165 }
166
167 true
168 }
169
170 fn build_totalizer_tree(solver: &mut Solver, lits: &[Lit], bound: usize) -> Vec<Var> {
174 if lits.len() == 1 {
175 return vec![lits[0].var()];
177 }
178
179 let mid = lits.len() / 2;
181 let left_lits = &lits[..mid];
182 let right_lits = &lits[mid..];
183
184 let left_vars = Self::build_totalizer_tree(solver, left_lits, bound);
186 let right_vars = Self::build_totalizer_tree(solver, right_lits, bound);
187
188 let max_count = (left_vars.len() + right_vars.len()).min(bound + 1);
190 let mut output = Vec::with_capacity(max_count);
191
192 for _ in 0..max_count {
193 output.push(solver.new_var());
194 }
195
196 Self::add_totalizer_clauses(solver, &left_vars, &right_vars, &output);
198
199 output
200 }
201
202 fn add_totalizer_clauses(solver: &mut Solver, left: &[Var], right: &[Var], output: &[Var]) {
207 for (i, &out_var) in output.iter().enumerate() {
209 let count = i + 1; for j in 0..=left.len() {
213 for k in 0..=right.len() {
214 if j + k >= count && j + k > 0 {
215 let mut clause = SmallVec::<[Lit; 4]>::new();
216
217 if j > 0 && j <= left.len() {
219 clause.push(Lit::neg(left[j - 1]));
220 }
221 if k > 0 && k <= right.len() {
222 clause.push(Lit::neg(right[k - 1]));
223 }
224
225 if !clause.is_empty() {
226 clause.push(Lit::pos(out_var));
227 solver.add_clause(clause.iter().copied());
228 }
229 }
230 }
231 }
232
233 }
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use crate::solver::SolverResult;

    /// Allocates `count` fresh variables in `solver`, in order, and returns
    /// them as positive literals.
    fn positive_lits(solver: &mut Solver, count: usize) -> Vec<Lit> {
        (0..count).map(|_| Lit::pos(solver.new_var())).collect()
    }

    #[test]
    fn test_at_most_0() {
        let mut solver = Solver::new();
        let lits = positive_lits(&mut solver, 3);

        CardinalityEncoder::encode_at_most_k(&mut solver, &lits, 0);

        assert_eq!(solver.solve(), SolverResult::Sat);
    }

    #[test]
    fn test_at_most_1() {
        let mut solver = Solver::new();
        let lits = positive_lits(&mut solver, 3);

        CardinalityEncoder::encode_at_most_k(&mut solver, &lits, 1);

        assert_eq!(solver.solve(), SolverResult::Sat);
    }

    #[test]
    fn test_at_least_1() {
        let mut solver = Solver::new();
        let lits = positive_lits(&mut solver, 3);

        CardinalityEncoder::encode_at_least_k(&mut solver, &lits, 1);

        assert_eq!(solver.solve(), SolverResult::Sat);
    }

    #[test]
    fn test_exactly_2() {
        let mut solver = Solver::new();
        let lits = positive_lits(&mut solver, 3);

        CardinalityEncoder::encode_exactly_k(&mut solver, &lits, 2);

        assert_eq!(solver.solve(), SolverResult::Sat);
    }

    #[test]
    fn test_at_most_exceeds_length() {
        let mut solver = Solver::new();
        let lits = positive_lits(&mut solver, 3);

        // A bound larger than the input is trivially satisfied.
        assert!(CardinalityEncoder::encode_at_most_k(&mut solver, &lits, 5));

        assert_eq!(solver.solve(), SolverResult::Sat);
    }
}