1use crate::literal::{LBool, Lit, Var};
38#[allow(unused_imports)]
39use crate::prelude::*;
40use crate::solver::{Solver, SolverResult};
41use smallvec::SmallVec;
42
/// A model as a list of literals.
///
/// Models produced by [`AllSatEnumerator`] contain one literal per assigned
/// (and, when projecting, projected) variable in ascending variable order;
/// variables the solver left unassigned are omitted, so a model may be partial.
pub type Model = Vec<Lit>;
45
/// Options controlling how an [`AllSatEnumerator`] enumerates models.
///
/// The `Default` value enumerates every model with no limit, no projection and
/// all flags off.
#[derive(Debug, Clone, Default)]
pub struct EnumerationConfig {
    /// Stop after this many models have been collected; `None` means no limit.
    pub max_models: Option<usize>,
    /// If set, extracted models (and therefore blocking clauses) only mention
    /// these variables, i.e. models are enumerated modulo the other variables.
    pub project_vars: Option<HashSet<Var>>,
    /// Request minimal models only.
    /// NOTE(review): the enumerator in this file does not act on this flag yet.
    pub minimal_models: bool,
    /// Request maximal models only.
    /// NOTE(review): the enumerator in this file does not act on this flag yet.
    pub maximal_models: bool,
    /// When set, blocking clauses negate only the model's positive literals, so
    /// each clause forbids every assignment that makes all of the model's true
    /// variables true (not just the exact model).
    pub block_positive_only: bool,
}
60
61impl EnumerationConfig {
62 #[must_use]
64 pub fn all() -> Self {
65 Self::default()
66 }
67
68 #[must_use]
70 pub fn limited(max_models: usize) -> Self {
71 Self {
72 max_models: Some(max_models),
73 ..Default::default()
74 }
75 }
76
77 #[must_use]
79 pub fn projected(vars: HashSet<Var>) -> Self {
80 Self {
81 project_vars: Some(vars),
82 ..Default::default()
83 }
84 }
85
86 pub fn with_max_models(mut self, max: usize) -> Self {
88 self.max_models = Some(max);
89 self
90 }
91
92 pub fn with_projection(mut self, vars: HashSet<Var>) -> Self {
94 self.project_vars = Some(vars);
95 self
96 }
97
98 #[must_use]
100 pub const fn minimal(mut self) -> Self {
101 self.minimal_models = true;
102 self
103 }
104
105 #[must_use]
107 pub const fn maximal(mut self) -> Self {
108 self.maximal_models = true;
109 self
110 }
111}
112
/// Counters collected by [`AllSatEnumerator`] during one enumeration run.
#[derive(Debug, Default, Clone)]
pub struct EnumerationStats {
    /// Number of models accepted and stored.
    pub models_found: usize,
    /// Number of `Solver::solve` invocations.
    pub solver_calls: usize,
    /// Number of blocking clauses handed to `Solver::add_clause`, including one
    /// the solver may have rejected.
    pub blocking_clauses: usize,
    /// Sum of the lengths of all stored models (used by `avg_model_size`).
    pub total_literals: usize,
}
125
126impl EnumerationStats {
127 #[must_use]
129 pub fn avg_model_size(&self) -> f64 {
130 if self.models_found == 0 {
131 0.0
132 } else {
133 self.total_literals as f64 / self.models_found as f64
134 }
135 }
136}
137
/// Outcome of an [`AllSatEnumerator`] run.
///
/// See `AllSatEnumerator::enumerate` for exactly when each variant is returned.
#[derive(Debug, Clone)]
pub enum EnumerationResult {
    /// Enumeration finished; holds the collected models.
    Complete(Vec<Model>),
    /// Enumeration stopped before the search space was exhausted (for example
    /// the solver returned `Unknown`); holds the models found so far.
    Incomplete(Vec<Model>),
    /// The solver reported `Unsat` before any model was collected.
    Unsat,
}
148
149impl EnumerationResult {
150 #[must_use]
152 pub fn models(&self) -> &[Model] {
153 match self {
154 EnumerationResult::Complete(models) | EnumerationResult::Incomplete(models) => models,
155 EnumerationResult::Unsat => &[],
156 }
157 }
158
159 #[must_use]
161 pub const fn is_complete(&self) -> bool {
162 matches!(self, EnumerationResult::Complete(_))
163 }
164
165 #[must_use]
167 pub fn count(&self) -> usize {
168 self.models().len()
169 }
170}
171
/// Enumerates satisfying assignments of a [`Solver`]'s formula by repeatedly
/// solving and adding a clause that blocks each model found.
pub struct AllSatEnumerator {
    /// Options applied to every enumeration run.
    config: EnumerationConfig,
    /// Counters for the most recent run (cleared by `reset`).
    stats: EnumerationStats,
    /// Models collected by the most recent run (cleared by `reset`).
    models: Vec<Model>,
}
181
182impl AllSatEnumerator {
183 #[must_use]
185 pub fn new(config: EnumerationConfig) -> Self {
186 Self {
187 config,
188 stats: EnumerationStats::default(),
189 models: Vec::new(),
190 }
191 }
192
193 #[must_use]
195 pub fn default_config() -> Self {
196 Self::new(EnumerationConfig::default())
197 }
198
199 #[must_use]
201 pub const fn stats(&self) -> &EnumerationStats {
202 &self.stats
203 }
204
205 #[must_use]
207 pub fn models(&self) -> &[Model] {
208 &self.models
209 }
210
211 pub fn reset(&mut self) {
213 self.models.clear();
214 self.stats = EnumerationStats::default();
215 }
216
217 pub fn enumerate(&mut self, solver: &mut Solver, num_vars: usize) -> EnumerationResult {
228 self.reset();
229
230 loop {
231 if let Some(max) = self.config.max_models
233 && self.models.len() >= max
234 {
235 return EnumerationResult::Complete(self.models.clone());
236 }
237
238 self.stats.solver_calls += 1;
240 let result = solver.solve();
241
242 match result {
243 SolverResult::Sat => {
244 let model = self.extract_model(solver, num_vars);
246
247 solver.backtrack_to_root();
250
251 if self.is_valid_model(&model) {
253 self.stats.models_found += 1;
254 self.stats.total_literals += model.len();
255 self.models.push(model.clone());
256
257 let blocking_clause = self.create_blocking_clause(&model);
259 self.stats.blocking_clauses += 1;
260 if !solver.add_clause(blocking_clause.iter().copied()) {
261 return EnumerationResult::Complete(self.models.clone());
263 }
264 } else {
265 let blocking_clause = self.create_blocking_clause(&model);
267 self.stats.blocking_clauses += 1;
268 if !solver.add_clause(blocking_clause.iter().copied()) {
269 return EnumerationResult::Complete(self.models.clone());
270 }
271 }
272 }
273 SolverResult::Unsat => {
274 if self.models.is_empty() {
276 return EnumerationResult::Unsat;
277 }
278 return EnumerationResult::Complete(self.models.clone());
279 }
280 SolverResult::Unknown => {
281 return EnumerationResult::Incomplete(self.models.clone());
283 }
284 }
285 }
286 }
287
288 fn extract_model(&self, solver: &Solver, num_vars: usize) -> Model {
290 let mut model = Vec::new();
291
292 for i in 0..num_vars {
293 let var = Var(i as u32);
294 let value = solver.model_value(var);
295
296 if value == LBool::Undef {
298 continue;
299 }
300
301 if let Some(ref project_vars) = self.config.project_vars
303 && !project_vars.contains(&var)
304 {
305 continue;
306 }
307
308 let lit = if value == LBool::True {
310 Lit::pos(var)
311 } else {
312 Lit::neg(var)
313 };
314
315 model.push(lit);
316 }
317
318 model
319 }
320
321 fn is_valid_model(&self, _model: &Model) -> bool {
323 if self.config.minimal_models {
324 }
330
331 if self.config.maximal_models {
332 }
335
336 true
337 }
338
339 fn create_blocking_clause(&self, model: &Model) -> SmallVec<[Lit; 32]> {
341 let mut clause = SmallVec::new();
342
343 for &lit in model {
344 if self.config.block_positive_only && lit.is_neg() {
348 continue;
350 }
351 clause.push(lit.negate());
352 }
353
354 if clause.is_empty() && !model.is_empty() {
356 clause.push(model[0].negate());
359 }
360
361 clause
362 }
363}
364
365impl AllSatEnumerator {
367 #[must_use]
378 pub fn enumerate_all(solver: &mut Solver, num_vars: usize) -> Vec<Model> {
379 let mut enumerator = Self::new(EnumerationConfig::all());
380 enumerator.enumerate(solver, num_vars).models().to_vec()
381 }
382
383 #[must_use]
391 pub fn enumerate_limited(
392 solver: &mut Solver,
393 num_vars: usize,
394 max_models: usize,
395 ) -> Vec<Model> {
396 let mut enumerator = Self::new(EnumerationConfig::limited(max_models));
397 enumerator.enumerate(solver, num_vars).models().to_vec()
398 }
399
400 pub fn count_models(solver: &mut Solver, num_vars: usize, max_count: Option<usize>) -> usize {
412 let config = if let Some(max) = max_count {
413 EnumerationConfig::limited(max)
414 } else {
415 EnumerationConfig::all()
416 };
417
418 let mut enumerator = Self::new(config);
419 let result = enumerator.enumerate(solver, num_vars);
420 result.count()
421 }
422
423 #[must_use]
431 pub fn enumerate_projected(
432 solver: &mut Solver,
433 num_vars: usize,
434 project_vars: HashSet<Var>,
435 ) -> Vec<Model> {
436 let mut enumerator = Self::new(EnumerationConfig::projected(project_vars));
437 enumerator.enumerate(solver, num_vars).models().to_vec()
438 }
439}
440
#[cfg(test)]
mod tests {
    // Unit tests for the enumerator. Several assertions are bounds
    // (`!is_empty()`, `<=`) rather than exact model counts, so they do not pin
    // down how the solver assigns or orders variables.
    use super::*;
    use crate::solver::Solver;

    // x0 ∨ x1 is satisfiable, so at least one model must be returned.
    #[test]
    fn test_enumerate_simple() {
        let mut solver = Solver::new();
        solver.add_clause([Lit::pos(Var(0)), Lit::pos(Var(1))]);

        let models = AllSatEnumerator::enumerate_all(&mut solver, 2);
        assert!(!models.is_empty());
    }

    // x0 ∧ ¬x0 has no model: the result must be `Unsat` with zero models.
    #[test]
    fn test_enumerate_unsat() {
        let mut solver = Solver::new();
        solver.add_clause([Lit::pos(Var(0))]);
        solver.add_clause([Lit::neg(Var(0))]);

        let mut enumerator = AllSatEnumerator::new(EnumerationConfig::all());
        let result = enumerator.enumerate(&mut solver, 1);

        assert!(matches!(result, EnumerationResult::Unsat));
        assert_eq!(result.count(), 0);
    }

    // The limit of 2 must never be exceeded.
    #[test]
    fn test_enumerate_limited() {
        let mut solver = Solver::new();
        solver.add_clause([Lit::pos(Var(0)), Lit::pos(Var(1))]);

        let models = AllSatEnumerator::enumerate_limited(&mut solver, 2, 2);
        assert!(models.len() <= 2);
    }

    // A tautological clause leaves x0 unconstrained; enumeration must still
    // report at least one model.
    #[test]
    fn test_enumerate_single_var() {
        let mut solver = Solver::new();
        solver.add_clause([Lit::pos(Var(0)), Lit::neg(Var(0))]);

        let models = AllSatEnumerator::enumerate_all(&mut solver, 1);
        assert!(!models.is_empty());
    }

    // Counting with a cap of 5 returns something in 1..=5.
    #[test]
    fn test_count_models() {
        let mut solver = Solver::new();
        solver.add_clause([Lit::pos(Var(0)), Lit::pos(Var(1))]);

        let count = AllSatEnumerator::count_models(&mut solver, 2, Some(5));
        assert!(count >= 1);
        assert!(count <= 5);
    }

    // Every recorded model is followed by exactly one blocking clause.
    #[test]
    fn test_enumerator_stats() {
        let mut solver = Solver::new();
        solver.add_clause([Lit::pos(Var(0))]);

        let mut enumerator = AllSatEnumerator::new(EnumerationConfig::limited(10));
        enumerator.enumerate(&mut solver, 1);

        let stats = enumerator.stats();
        assert!(stats.models_found >= 1);
        assert!(stats.solver_calls >= 1);
        assert_eq!(stats.models_found, stats.blocking_clauses);
    }

    // With projection onto {x0, x1}, no literal over x2 may appear in a model.
    #[test]
    fn test_projected_enumeration() {
        let mut solver = Solver::new();
        solver.add_clause([Lit::pos(Var(0))]);
        solver.add_clause([Lit::pos(Var(1)), Lit::pos(Var(2))]);

        let mut project_vars = HashSet::new();
        project_vars.insert(Var(0));
        project_vars.insert(Var(1));

        let models = AllSatEnumerator::enumerate_projected(&mut solver, 3, project_vars);
        assert!(!models.is_empty());

        for model in &models {
            for lit in model {
                assert!(lit.var() == Var(0) || lit.var() == Var(1));
            }
        }
    }

    // Default blocking negates every literal of the model.
    #[test]
    fn test_blocking_clause_creation() {
        let enumerator = AllSatEnumerator::new(EnumerationConfig::all());
        let model = vec![Lit::pos(Var(0)), Lit::neg(Var(1)), Lit::pos(Var(2))];

        let blocking = enumerator.create_blocking_clause(&model);

        assert_eq!(blocking.len(), 3);
        assert!(blocking.contains(&Lit::neg(Var(0))));
        assert!(blocking.contains(&Lit::pos(Var(1))));
        assert!(blocking.contains(&Lit::neg(Var(2))));
    }

    // Accessors on hand-built results, including the empty `Unsat` case.
    #[test]
    fn test_enumeration_result_methods() {
        let models = vec![vec![Lit::pos(Var(0))], vec![Lit::neg(Var(0))]];
        let result = EnumerationResult::Complete(models.clone());

        assert!(result.is_complete());
        assert_eq!(result.count(), 2);
        assert_eq!(result.models().len(), 2);

        let unsat = EnumerationResult::Unsat;
        assert_eq!(unsat.count(), 0);
        assert!(unsat.models().is_empty());
    }

    // Constructors and builder methods set the expected fields.
    #[test]
    fn test_config_builders() {
        let config = EnumerationConfig::all();
        assert!(config.max_models.is_none());

        let config = EnumerationConfig::limited(10);
        assert_eq!(config.max_models, Some(10));

        let mut vars = HashSet::new();
        vars.insert(Var(0));
        let config = EnumerationConfig::projected(vars.clone());
        assert!(config.project_vars.is_some());

        let config = EnumerationConfig::all().minimal().maximal();
        assert!(config.minimal_models);
        assert!(config.maximal_models);
    }

    // Average is 0.0 with no models, otherwise total_literals / models_found.
    #[test]
    fn test_stats_avg_model_size() {
        let mut stats = EnumerationStats::default();
        assert_eq!(stats.avg_model_size(), 0.0);

        stats.models_found = 3;
        stats.total_literals = 12;
        assert_eq!(stats.avg_model_size(), 4.0);
    }

    // `reset` clears stored models and statistics after a run.
    #[test]
    fn test_reset() {
        let mut solver = Solver::new();
        solver.add_clause([Lit::pos(Var(0))]);

        let mut enumerator = AllSatEnumerator::new(EnumerationConfig::limited(5));
        enumerator.enumerate(&mut solver, 1);

        assert!(!enumerator.models().is_empty());

        enumerator.reset();
        assert!(enumerator.models().is_empty());
        assert_eq!(enumerator.stats().models_found, 0);
    }
}