1use std::{
2 collections::{BTreeSet, BTreeMap},
3 mem, fmt,
4};
5use varisat::{Var, Lit, ExtendFormula, solver::SolverError};
6use crate::{
7 ContextHandle, Contextual, NodeID, AtomID, ForkID, JoinID, Harc, AcesError, AcesErrorKind,
8 atom::Atom,
9 sat::{CEVar, CELit, Encoding, Search, Clause, Formula},
10};
11
12#[derive(Clone, Default, Debug)]
13pub(crate) struct Props {
14 pub(crate) sat_encoding: Option<Encoding>,
15 pub(crate) sat_search: Option<Search>,
16}
17
18impl Props {
19 pub(crate) fn clear(&mut self) {
20 *self = Default::default();
21 }
22}
23
24enum ModelSearchResult {
25 Reset,
26 Found(Vec<Lit>),
27 Done,
28 Failed(SolverError),
29}
30
31impl ModelSearchResult {
32 #[allow(dead_code)]
33 fn get_model(&self) -> Option<&[Lit]> {
34 match self {
35 ModelSearchResult::Found(ref v) => Some(v.as_slice()),
36 _ => None,
37 }
38 }
39
40 #[inline]
41 #[allow(dead_code)]
42 fn take(&mut self) -> Self {
43 mem::replace(self, ModelSearchResult::Reset)
44 }
45
46 fn take_error(&mut self) -> Option<SolverError> {
47 let old_result = match self {
48 ModelSearchResult::Failed(_) => mem::replace(self, ModelSearchResult::Reset),
49 _ => return None,
50 };
51
52 if let ModelSearchResult::Failed(err) = old_result {
53 Some(err)
54 } else {
55 unreachable!()
56 }
57 }
58}
59
60impl Default for ModelSearchResult {
61 fn default() -> Self {
62 ModelSearchResult::Reset
63 }
64}
65
66#[derive(Default, Debug)]
67struct Assumptions {
68 literals: Vec<Lit>,
69 permanent_len: usize,
70}
71
72impl Assumptions {
73 fn block_variable(&mut self, var: Var) {
74 let lit = Lit::from_var(var, false);
75
76 let pos = if self.permanent_len > 0 {
77 match self.literals[..self.permanent_len].binary_search(&lit) {
78 Ok(_) => return,
79 Err(pos) => pos,
80 }
81 } else {
82 0
83 };
84
85 self.literals.insert(pos, lit);
86 self.permanent_len += 1;
87 }
88
89 fn unblock_variable(&mut self, var: Var) -> bool {
90 let lit = Lit::from_var(var, false);
91
92 if self.permanent_len > 0 {
93 match self.literals[..self.permanent_len].binary_search(&lit) {
94 Ok(pos) => {
95 self.literals.remove(pos);
96 self.permanent_len -= 1;
97 true
98 }
99 Err(_) => false,
100 }
101 } else {
102 false
103 }
104 }
105
106 fn unblock_all_variables(&mut self) {
107 if self.permanent_len > 0 {
108 let new_literals = self.literals.split_off(self.permanent_len);
109
110 self.literals = new_literals;
111 self.permanent_len = 0;
112 }
113 }
114
115 #[inline]
116 fn is_empty(&self) -> bool {
117 self.literals.len() == self.permanent_len
118 }
119
120 #[inline]
121 fn reset(&mut self) {
122 self.literals.truncate(self.permanent_len);
123 }
124
125 #[inline]
126 fn add(&mut self, lit: Lit) {
127 self.literals.push(lit);
128 }
129
130 #[inline]
131 fn get_literals(&self) -> &[Lit] {
132 assert!(self.literals.len() >= self.permanent_len);
133
134 self.literals.as_slice()
135 }
136}
137
138impl Contextual for Assumptions {
139 fn format(&self, ctx: &ContextHandle) -> Result<String, AcesError> {
140 self.literals.format(ctx)
141 }
142}
143
144pub struct Solver<'a> {
145 context: ContextHandle,
146 engine: varisat::Solver<'a>,
147 all_vars: BTreeSet<Var>,
148 is_sat: Option<bool>,
149 last_model: ModelSearchResult,
150 min_residue: BTreeSet<Var>,
151 assumptions: Assumptions,
152}
153
154impl<'a> Solver<'a> {
155 pub fn new(ctx: &ContextHandle) -> Self {
156 Self {
157 context: ctx.clone(),
158 engine: Default::default(),
159 all_vars: Default::default(),
160 is_sat: None,
161 last_model: Default::default(),
162 min_residue: Default::default(),
163 assumptions: Default::default(),
164 }
165 }
166
167 pub fn reset(&mut self) -> Result<(), SolverError> {
168 self.is_sat = None;
169 self.last_model = ModelSearchResult::Reset;
170 self.min_residue.clear();
171 self.assumptions.reset();
172 self.engine.close_proof()
173 }
174
175 pub fn block_atom_id(&mut self, atom_id: AtomID) {
176 let var = Var::from_atom_id(atom_id);
177
178 self.assumptions.block_variable(var);
179 }
180
181 pub fn unblock_atom_id(&mut self, atom_id: AtomID) -> bool {
182 let var = Var::from_atom_id(atom_id);
183
184 self.assumptions.unblock_variable(var)
185 }
186
187 pub fn unblock_all_atoms(&mut self) {
188 self.assumptions.unblock_all_variables();
189 }
190
191 fn add_clause(&mut self, clause: Clause) -> Result<(), AcesError> {
193 if clause.is_empty() {
194 Err(AcesErrorKind::EmptyClauseRejectedBySolver(clause.get_info().to_owned())
195 .with_context(&self.context))
196 } else {
197 debug!("Add (to solver) {} clause: {}", clause.get_info(), clause.with(&self.context));
198
199 self.engine.add_clause(clause.get_literals());
200
201 Ok(())
202 }
203 }
204
205 pub fn add_formula(&mut self, formula: &Formula) -> Result<(), AcesError> {
206 self.engine.add_formula(formula.get_cnf());
207
208 let new_vars = formula.get_variables();
209 self.all_vars.extend(new_vars);
210
211 Ok(())
212 }
213
214 pub fn inhibit_empty_solution(&mut self) -> Result<(), AcesError> {
224 let clause = {
225 let ctx = self.context.lock().unwrap();
226 let mut all_lits: Vec<_> = self
227 .all_vars
228 .iter()
229 .filter_map(|&var| {
230 ctx.is_port(var.into_atom_id()).then(|| Lit::from_var(var, true))
231 })
232 .collect();
233 let mut fork_lits: Vec<_> = self
234 .all_vars
235 .iter()
236 .filter_map(|&var| {
237 ctx.is_fork(var.into_atom_id()).then(|| Lit::from_var(var, true))
238 })
239 .collect();
240 let mut join_lits: Vec<_> = self
241 .all_vars
242 .iter()
243 .filter_map(|&var| {
244 ctx.is_join(var.into_atom_id()).then(|| Lit::from_var(var, true))
245 })
246 .collect();
247
248 if fork_lits.len() > join_lits.len() {
251 if join_lits.is_empty() {
252 return Err(AcesErrorKind::IncoherencyLeak.with_context(&self.context))
253 } else {
254 all_lits.append(&mut join_lits);
255 }
256 } else if !fork_lits.is_empty() {
257 all_lits.append(&mut fork_lits);
258 } else if !join_lits.is_empty() {
259 return Err(AcesErrorKind::IncoherencyLeak.with_context(&self.context))
260 }
261
262 Clause::from_vec(all_lits, "void inhibition")
263 };
264
265 self.add_clause(clause)
266 }
267
268 pub fn inhibit_model(&mut self, model: &[Lit]) -> Result<(), AcesError> {
279 let anti_lits =
280 model.iter().filter_map(|&lit| self.all_vars.contains(&lit.var()).then(|| !lit));
281 let clause = Clause::from_literals(anti_lits, "model inhibition");
282
283 self.add_clause(clause)
284 }
285
286 fn inhibit_last_model(&mut self) -> Result<(), AcesError> {
287 if let ModelSearchResult::Found(ref model) = self.last_model {
288 let anti_lits =
289 model.iter().filter_map(|&lit| self.all_vars.contains(&lit.var()).then(|| !lit));
290 let clause = Clause::from_literals(anti_lits, "model inhibition");
291
292 self.add_clause(clause)
293 } else {
294 Err(AcesErrorKind::NoModelToInhibit.with_context(&self.context))
295 }
296 }
297
298 fn reduce_model(&mut self, model: &[Lit]) -> Result<bool, AcesError> {
299 let mut reducing_lits = Vec::new();
300
301 for &lit in model.iter() {
302 if !self.min_residue.contains(&lit.var()) {
303 if lit.is_positive() {
304 reducing_lits.push(!lit);
305 } else {
306 self.assumptions.add(lit);
307 self.min_residue.insert(lit.var());
308 }
309 }
310 }
311
312 if reducing_lits.is_empty() {
313 Ok(false)
314 } else {
315 let clause = Clause::from_literals(reducing_lits.into_iter(), "model reduction");
316 self.add_clause(clause)?;
317
318 Ok(true)
319 }
320 }
321
322 fn solve(&mut self) -> Option<bool> {
323 if !self.assumptions.is_empty() {
324 debug!("Solving under assumptions: {}", self.assumptions.with(&self.context));
325 }
326
327 self.engine.assume(self.assumptions.get_literals());
328
329 let result = self.engine.solve();
330
331 if self.is_sat.is_none() {
332 self.is_sat = result.as_ref().ok().copied();
333 }
334
335 match result {
336 Ok(is_sat) => {
337 if is_sat {
338 if let Some(model) = self.engine.model() {
339 self.last_model = ModelSearchResult::Found(model);
340 Some(true)
341 } else {
342 warn!("Solver reported SAT without a model");
343
344 self.last_model = ModelSearchResult::Done;
345 Some(false)
346 }
347 } else {
348 self.last_model = ModelSearchResult::Done;
349 Some(false)
350 }
351 }
352 Err(err) => {
353 self.last_model = ModelSearchResult::Failed(err);
354 None
355 }
356 }
357 }
358
359 pub(crate) fn is_sat(&self) -> Option<bool> {
360 self.is_sat
361 }
362
363 pub fn was_interrupted(&self) -> bool {
374 if let ModelSearchResult::Failed(ref err) = self.last_model {
375 err.is_recoverable()
376 } else {
377 false
378 }
379 }
380
381 pub fn last_solution(&self) -> Option<Solution> {
382 self.engine.model().and_then(|model| match Solution::from_model(&self.context, model) {
383 Ok(solution) => Some(solution),
384 Err(err) => {
385 warn!("{} in solver's solution ctor", err);
386 None
387 }
388 })
389 }
390
391 pub(crate) fn take_last_error(&mut self) -> Option<SolverError> {
400 self.last_model.take_error()
401 }
402
403 fn next_solution(&mut self) -> Option<Solution> {
404 self.solve();
405
406 if let ModelSearchResult::Found(ref model) = self.last_model {
407 match Solution::from_model(&self.context, model.iter().copied()) {
408 Ok(solution) => {
409 if let Err(err) = self.inhibit_last_model() {
410 warn!("{} in solver's iteration", err);
411
412 None
413 } else {
414 Some(solution)
415 }
416 }
417 Err(err) => {
418 warn!("{} in solver's iteration", err);
419 None
420 }
421 }
422 } else {
423 None
424 }
425 }
426
427 fn next_minimal_solution(&mut self) -> Option<Solution> {
428 self.assumptions.reset();
429
430 self.solve();
431
432 if let ModelSearchResult::Found(ref top_model) = self.last_model {
433 let top_model = top_model.clone();
434
435 trace!("Top model: {:?}", top_model);
436
437 self.min_residue.clear();
438 self.assumptions.reset();
439
440 let mut model = top_model.clone();
441
442 loop {
443 match self.reduce_model(&model) {
444 Ok(true) => {}
445 Ok(false) => break,
446 Err(err) => {
447 warn!("{} in solver's iteration", err);
448 return None
449 }
450 }
451
452 self.solve();
453
454 if let ModelSearchResult::Found(ref reduced_model) = self.last_model {
455 trace!("Reduced model: {:?}", reduced_model);
456 model = reduced_model.clone();
457 } else {
458 break
459 }
460 }
461
462 let min_model = top_model
463 .iter()
464 .map(|lit| Lit::from_var(lit.var(), !self.min_residue.contains(&lit.var())));
465
466 match Solution::from_model(&self.context, min_model) {
467 Ok(solution) => Some(solution),
468 Err(err) => {
469 warn!("{} in solver's iteration", err);
470 None
471 }
472 }
473 } else {
474 None
475 }
476 }
477}
478
479impl Iterator for Solver<'_> {
480 type Item = Solution;
481
482 fn next(&mut self) -> Option<Self::Item> {
483 let search = self.context.lock().unwrap().get_search().unwrap_or(Search::MinSolutions);
484
485 match search {
486 Search::MinSolutions => self.next_minimal_solution(),
487 Search::AllSolutions => self.next_solution(),
488 }
489 }
490}
491
492pub struct Solution {
493 context: ContextHandle,
494 model: Vec<Lit>,
495 pre_set: Vec<NodeID>,
496 post_set: Vec<NodeID>,
497 fork_set: Vec<ForkID>,
498 join_set: Vec<JoinID>,
499}
500
501impl Solution {
502 fn new(ctx: &ContextHandle) -> Self {
503 Self {
504 context: ctx.clone(),
505 model: Default::default(),
506 pre_set: Default::default(),
507 post_set: Default::default(),
508 fork_set: Default::default(),
509 join_set: Default::default(),
510 }
511 }
512
513 fn from_model<I: IntoIterator<Item = Lit>>(
514 ctx: &ContextHandle,
515 model: I,
516 ) -> Result<Self, AcesError> {
517 let mut solution = Self::new(ctx);
518
519 let mut pre_set: BTreeSet<NodeID> = BTreeSet::new();
520 let mut post_set: BTreeSet<NodeID> = BTreeSet::new();
521 let mut fork_map: BTreeMap<NodeID, BTreeSet<NodeID>> = BTreeMap::new();
522 let mut join_map: BTreeMap<NodeID, BTreeSet<NodeID>> = BTreeMap::new();
523 let mut fork_set: BTreeSet<ForkID> = BTreeSet::new();
524 let mut join_set: BTreeSet<JoinID> = BTreeSet::new();
525
526 for lit in model {
527 solution.model.push(lit);
528
529 if lit.is_positive() {
530 let (atom_id, _) = lit.into_atom_id();
531 let ctx = solution.context.lock().unwrap();
532
533 if let Some(atom) = ctx.get_atom(atom_id) {
534 match atom {
535 Atom::Tx(port) => {
536 pre_set.insert(port.get_node_id());
537 }
538 Atom::Rx(port) => {
539 post_set.insert(port.get_node_id());
540 }
541 Atom::Link(link) => {
542 let tx_node_id = link.get_tx_node_id();
543 let rx_node_id = link.get_rx_node_id();
544
545 fork_map
546 .entry(tx_node_id)
547 .or_insert_with(BTreeSet::new)
548 .insert(rx_node_id);
549 join_map
550 .entry(rx_node_id)
551 .or_insert_with(BTreeSet::new)
552 .insert(tx_node_id);
553 }
554 Atom::Fork(fork) => {
555 if let Some(fork_id) = fork.get_fork_id() {
556 pre_set.insert(fork.get_host_id());
557 fork_set.insert(fork_id);
558 } else if let Some(join_id) = fork.get_join_id() {
559 return Err(AcesErrorKind::HarcNotAForkMismatch(join_id)
560 .with_context(&solution.context))
561 } else {
562 unreachable!()
563 }
564 }
565 Atom::Join(join) => {
566 if let Some(join_id) = join.get_join_id() {
567 post_set.insert(join.get_host_id());
568 join_set.insert(join_id);
569 } else if let Some(fork_id) = join.get_fork_id() {
570 return Err(AcesErrorKind::HarcNotAJoinMismatch(fork_id)
571 .with_context(&solution.context))
572 } else {
573 unreachable!()
574 }
575 }
576 Atom::Bottom => {
577 return Err(
578 AcesErrorKind::BottomAtomAccess.with_context(&solution.context)
579 )
580 }
581 }
582 } else {
583 return Err(
584 AcesErrorKind::AtomMissingForID(atom_id).with_context(&solution.context)
585 )
586 }
587 }
588 }
589
590 fork_set.extend(fork_map.into_iter().map(|(host, suit)| {
591 let mut fork = Harc::new_fork_unchecked(host, suit);
592 solution.context.lock().unwrap().share_fork(&mut fork)
593 }));
594
595 join_set.extend(join_map.into_iter().map(|(host, suit)| {
596 let mut join = Harc::new_join_unchecked(host, suit);
597 solution.context.lock().unwrap().share_join(&mut join)
598 }));
599
600 solution.pre_set.extend(pre_set.into_iter());
601 solution.post_set.extend(post_set.into_iter());
602 solution.fork_set.extend(fork_set.into_iter());
603 solution.join_set.extend(join_set.into_iter());
604
605 Ok(solution)
606 }
607
608 pub fn get_context(&self) -> &ContextHandle {
609 &self.context
610 }
611
612 pub fn get_model(&self) -> &[Lit] {
613 self.model.as_slice()
614 }
615
616 pub fn get_pre_set(&self) -> &[NodeID] {
617 self.pre_set.as_slice()
618 }
619
620 pub fn get_post_set(&self) -> &[NodeID] {
621 self.post_set.as_slice()
622 }
623
624 pub fn get_fork_set(&self) -> &[ForkID] {
625 self.fork_set.as_slice()
626 }
627
628 pub fn get_join_set(&self) -> &[JoinID] {
629 self.join_set.as_slice()
630 }
631}
632
633impl fmt::Debug for Solution {
634 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
635 write!(
636 f,
637 "Solution {{ model: {:?}, pre_set: {}, post_set: {}, fork_set: {}, join_set: {} }}",
638 self.model,
639 self.pre_set.with(&self.context),
640 self.post_set.with(&self.context),
641 self.fork_set.with(&self.context),
642 self.join_set.with(&self.context),
643 )
644 }
645}
646
647impl fmt::Display for Solution {
648 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
649 if self.pre_set.is_empty() {
650 write!(f, "{{}} => {{")?;
651 } else {
652 write!(f, "{{")?;
653
654 for node_id in self.pre_set.iter() {
655 write!(f, " {}", node_id.with(&self.context))?;
656 }
657
658 write!(f, " }} => {{")?;
659 }
660
661 if self.post_set.is_empty() {
662 write!(f, "}}")?;
663 } else {
664 for node_id in self.post_set.iter() {
665 write!(f, " {}", node_id.with(&self.context))?;
666 }
667
668 write!(f, " }}")?;
669 }
670
671 Ok(())
672 }
673}