aces/
solver.rs

1use std::{
2    collections::{BTreeSet, BTreeMap},
3    mem, fmt,
4};
5use varisat::{Var, Lit, ExtendFormula, solver::SolverError};
6use crate::{
7    ContextHandle, Contextual, NodeID, AtomID, ForkID, JoinID, Harc, AcesError, AcesErrorKind,
8    atom::Atom,
9    sat::{CEVar, CELit, Encoding, Search, Clause, Formula},
10};
11
12#[derive(Clone, Default, Debug)]
13pub(crate) struct Props {
14    pub(crate) sat_encoding: Option<Encoding>,
15    pub(crate) sat_search:   Option<Search>,
16}
17
18impl Props {
19    pub(crate) fn clear(&mut self) {
20        *self = Default::default();
21    }
22}
23
24enum ModelSearchResult {
25    Reset,
26    Found(Vec<Lit>),
27    Done,
28    Failed(SolverError),
29}
30
31impl ModelSearchResult {
32    #[allow(dead_code)]
33    fn get_model(&self) -> Option<&[Lit]> {
34        match self {
35            ModelSearchResult::Found(ref v) => Some(v.as_slice()),
36            _ => None,
37        }
38    }
39
40    #[inline]
41    #[allow(dead_code)]
42    fn take(&mut self) -> Self {
43        mem::replace(self, ModelSearchResult::Reset)
44    }
45
46    fn take_error(&mut self) -> Option<SolverError> {
47        let old_result = match self {
48            ModelSearchResult::Failed(_) => mem::replace(self, ModelSearchResult::Reset),
49            _ => return None,
50        };
51
52        if let ModelSearchResult::Failed(err) = old_result {
53            Some(err)
54        } else {
55            unreachable!()
56        }
57    }
58}
59
60impl Default for ModelSearchResult {
61    fn default() -> Self {
62        ModelSearchResult::Reset
63    }
64}
65
66#[derive(Default, Debug)]
67struct Assumptions {
68    literals:      Vec<Lit>,
69    permanent_len: usize,
70}
71
72impl Assumptions {
73    fn block_variable(&mut self, var: Var) {
74        let lit = Lit::from_var(var, false);
75
76        let pos = if self.permanent_len > 0 {
77            match self.literals[..self.permanent_len].binary_search(&lit) {
78                Ok(_) => return,
79                Err(pos) => pos,
80            }
81        } else {
82            0
83        };
84
85        self.literals.insert(pos, lit);
86        self.permanent_len += 1;
87    }
88
89    fn unblock_variable(&mut self, var: Var) -> bool {
90        let lit = Lit::from_var(var, false);
91
92        if self.permanent_len > 0 {
93            match self.literals[..self.permanent_len].binary_search(&lit) {
94                Ok(pos) => {
95                    self.literals.remove(pos);
96                    self.permanent_len -= 1;
97                    true
98                }
99                Err(_) => false,
100            }
101        } else {
102            false
103        }
104    }
105
106    fn unblock_all_variables(&mut self) {
107        if self.permanent_len > 0 {
108            let new_literals = self.literals.split_off(self.permanent_len);
109
110            self.literals = new_literals;
111            self.permanent_len = 0;
112        }
113    }
114
115    #[inline]
116    fn is_empty(&self) -> bool {
117        self.literals.len() == self.permanent_len
118    }
119
120    #[inline]
121    fn reset(&mut self) {
122        self.literals.truncate(self.permanent_len);
123    }
124
125    #[inline]
126    fn add(&mut self, lit: Lit) {
127        self.literals.push(lit);
128    }
129
130    #[inline]
131    fn get_literals(&self) -> &[Lit] {
132        assert!(self.literals.len() >= self.permanent_len);
133
134        self.literals.as_slice()
135    }
136}
137
138impl Contextual for Assumptions {
139    fn format(&self, ctx: &ContextHandle) -> Result<String, AcesError> {
140        self.literals.format(ctx)
141    }
142}
143
144pub struct Solver<'a> {
145    context:     ContextHandle,
146    engine:      varisat::Solver<'a>,
147    all_vars:    BTreeSet<Var>,
148    is_sat:      Option<bool>,
149    last_model:  ModelSearchResult,
150    min_residue: BTreeSet<Var>,
151    assumptions: Assumptions,
152}
153
154impl<'a> Solver<'a> {
155    pub fn new(ctx: &ContextHandle) -> Self {
156        Self {
157            context:     ctx.clone(),
158            engine:      Default::default(),
159            all_vars:    Default::default(),
160            is_sat:      None,
161            last_model:  Default::default(),
162            min_residue: Default::default(),
163            assumptions: Default::default(),
164        }
165    }
166
167    pub fn reset(&mut self) -> Result<(), SolverError> {
168        self.is_sat = None;
169        self.last_model = ModelSearchResult::Reset;
170        self.min_residue.clear();
171        self.assumptions.reset();
172        self.engine.close_proof()
173    }
174
175    pub fn block_atom_id(&mut self, atom_id: AtomID) {
176        let var = Var::from_atom_id(atom_id);
177
178        self.assumptions.block_variable(var);
179    }
180
181    pub fn unblock_atom_id(&mut self, atom_id: AtomID) -> bool {
182        let var = Var::from_atom_id(atom_id);
183
184        self.assumptions.unblock_variable(var)
185    }
186
187    pub fn unblock_all_atoms(&mut self) {
188        self.assumptions.unblock_all_variables();
189    }
190
191    /// Only for internal use.
192    fn add_clause(&mut self, clause: Clause) -> Result<(), AcesError> {
193        if clause.is_empty() {
194            Err(AcesErrorKind::EmptyClauseRejectedBySolver(clause.get_info().to_owned())
195                .with_context(&self.context))
196        } else {
197            debug!("Add (to solver) {} clause: {}", clause.get_info(), clause.with(&self.context));
198
199            self.engine.add_clause(clause.get_literals());
200
201            Ok(())
202        }
203    }
204
205    pub fn add_formula(&mut self, formula: &Formula) -> Result<(), AcesError> {
206        self.engine.add_formula(formula.get_cnf());
207
208        let new_vars = formula.get_variables();
209        self.all_vars.extend(new_vars);
210
211        Ok(())
212    }
213
214    /// Blocks empty solution models by adding a _void inhibition_
215    /// clause.
216    ///
217    /// A model represents an empty solution iff it contains only
218    /// negative [`Port`] literals (hence no [`Port`] variable
219    /// evaluates to `true`).  Thus, the blocking clause is the
220    /// disjunction of all [`Port`] variables known by the solver.
221    ///
222    /// [`Port`]: crate::Port
223    pub fn inhibit_empty_solution(&mut self) -> Result<(), AcesError> {
224        let clause = {
225            let ctx = self.context.lock().unwrap();
226            let mut all_lits: Vec<_> = self
227                .all_vars
228                .iter()
229                .filter_map(|&var| {
230                    ctx.is_port(var.into_atom_id()).then(|| Lit::from_var(var, true))
231                })
232                .collect();
233            let mut fork_lits: Vec<_> = self
234                .all_vars
235                .iter()
236                .filter_map(|&var| {
237                    ctx.is_fork(var.into_atom_id()).then(|| Lit::from_var(var, true))
238                })
239                .collect();
240            let mut join_lits: Vec<_> = self
241                .all_vars
242                .iter()
243                .filter_map(|&var| {
244                    ctx.is_join(var.into_atom_id()).then(|| Lit::from_var(var, true))
245                })
246                .collect();
247
248            // Include all fork variables or all join variables,
249            // depending on in which case the clause grows less.
250            if fork_lits.len() > join_lits.len() {
251                if join_lits.is_empty() {
252                    return Err(AcesErrorKind::IncoherencyLeak.with_context(&self.context))
253                } else {
254                    all_lits.append(&mut join_lits);
255                }
256            } else if !fork_lits.is_empty() {
257                all_lits.append(&mut fork_lits);
258            } else if !join_lits.is_empty() {
259                return Err(AcesErrorKind::IncoherencyLeak.with_context(&self.context))
260            }
261
262            Clause::from_vec(all_lits, "void inhibition")
263        };
264
265        self.add_clause(clause)
266    }
267
268    /// Adds a _model inhibition_ clause which will remove a specific
269    /// `model` from solution space.
270    ///
271    /// The blocking clause is constructed by negating the given
272    /// `model`, i.e. by taking the disjunction of all _explicit_
273    /// literals and reversing polarity of each.  A literal is
274    /// explicit iff its variable has been _registered_ by occurring
275    /// in a formula passed to a call to [`add_formula()`].
276    ///
277    /// [`add_formula()`]: Solver::add_formula()
278    pub fn inhibit_model(&mut self, model: &[Lit]) -> Result<(), AcesError> {
279        let anti_lits =
280            model.iter().filter_map(|&lit| self.all_vars.contains(&lit.var()).then(|| !lit));
281        let clause = Clause::from_literals(anti_lits, "model inhibition");
282
283        self.add_clause(clause)
284    }
285
286    fn inhibit_last_model(&mut self) -> Result<(), AcesError> {
287        if let ModelSearchResult::Found(ref model) = self.last_model {
288            let anti_lits =
289                model.iter().filter_map(|&lit| self.all_vars.contains(&lit.var()).then(|| !lit));
290            let clause = Clause::from_literals(anti_lits, "model inhibition");
291
292            self.add_clause(clause)
293        } else {
294            Err(AcesErrorKind::NoModelToInhibit.with_context(&self.context))
295        }
296    }
297
298    fn reduce_model(&mut self, model: &[Lit]) -> Result<bool, AcesError> {
299        let mut reducing_lits = Vec::new();
300
301        for &lit in model.iter() {
302            if !self.min_residue.contains(&lit.var()) {
303                if lit.is_positive() {
304                    reducing_lits.push(!lit);
305                } else {
306                    self.assumptions.add(lit);
307                    self.min_residue.insert(lit.var());
308                }
309            }
310        }
311
312        if reducing_lits.is_empty() {
313            Ok(false)
314        } else {
315            let clause = Clause::from_literals(reducing_lits.into_iter(), "model reduction");
316            self.add_clause(clause)?;
317
318            Ok(true)
319        }
320    }
321
322    fn solve(&mut self) -> Option<bool> {
323        if !self.assumptions.is_empty() {
324            debug!("Solving under assumptions: {}", self.assumptions.with(&self.context));
325        }
326
327        self.engine.assume(self.assumptions.get_literals());
328
329        let result = self.engine.solve();
330
331        if self.is_sat.is_none() {
332            self.is_sat = result.as_ref().ok().copied();
333        }
334
335        match result {
336            Ok(is_sat) => {
337                if is_sat {
338                    if let Some(model) = self.engine.model() {
339                        self.last_model = ModelSearchResult::Found(model);
340                        Some(true)
341                    } else {
342                        warn!("Solver reported SAT without a model");
343
344                        self.last_model = ModelSearchResult::Done;
345                        Some(false)
346                    }
347                } else {
348                    self.last_model = ModelSearchResult::Done;
349                    Some(false)
350                }
351            }
352            Err(err) => {
353                self.last_model = ModelSearchResult::Failed(err);
354                None
355            }
356        }
357    }
358
359    pub(crate) fn is_sat(&self) -> Option<bool> {
360        self.is_sat
361    }
362
363    /// Returns `true` if last call to [`solve()`] was interrupted.
364    /// Returns `false` if [`solve()`] either failed, or succeeded, or
365    /// hasn't been called yet.
366    ///
367    /// Note, that even if last call to [`solve()`] was indeed
368    /// interrupted, a subsequent invocation of [`take_last_error()`]
369    /// resets this to return `false` until next [`solve()`].
370    ///
371    /// [`solve()`]: Solver::solve()
372    /// [`take_last_error()`]: Solver::take_last_error()
373    pub fn was_interrupted(&self) -> bool {
374        if let ModelSearchResult::Failed(ref err) = self.last_model {
375            err.is_recoverable()
376        } else {
377            false
378        }
379    }
380
381    pub fn last_solution(&self) -> Option<Solution> {
382        self.engine.model().and_then(|model| match Solution::from_model(&self.context, model) {
383            Ok(solution) => Some(solution),
384            Err(err) => {
385                warn!("{} in solver's solution ctor", err);
386                None
387            }
388        })
389    }
390
391    /// Returns the error reported by last call to [`solve()`], if
392    /// solving has failed; otherwise returns `None`.
393    ///
394    /// Note, that this may be invoked only once for every
395    /// unsuccessful call to [`solve()`], because, in varisat 0.2,
396    /// `varisat::solver::SolverError` can't be cloned.
397    ///
398    /// [`solve()`]: Solver::solve()
399    pub(crate) fn take_last_error(&mut self) -> Option<SolverError> {
400        self.last_model.take_error()
401    }
402
403    fn next_solution(&mut self) -> Option<Solution> {
404        self.solve();
405
406        if let ModelSearchResult::Found(ref model) = self.last_model {
407            match Solution::from_model(&self.context, model.iter().copied()) {
408                Ok(solution) => {
409                    if let Err(err) = self.inhibit_last_model() {
410                        warn!("{} in solver's iteration", err);
411
412                        None
413                    } else {
414                        Some(solution)
415                    }
416                }
417                Err(err) => {
418                    warn!("{} in solver's iteration", err);
419                    None
420                }
421            }
422        } else {
423            None
424        }
425    }
426
427    fn next_minimal_solution(&mut self) -> Option<Solution> {
428        self.assumptions.reset();
429
430        self.solve();
431
432        if let ModelSearchResult::Found(ref top_model) = self.last_model {
433            let top_model = top_model.clone();
434
435            trace!("Top model: {:?}", top_model);
436
437            self.min_residue.clear();
438            self.assumptions.reset();
439
440            let mut model = top_model.clone();
441
442            loop {
443                match self.reduce_model(&model) {
444                    Ok(true) => {}
445                    Ok(false) => break,
446                    Err(err) => {
447                        warn!("{} in solver's iteration", err);
448                        return None
449                    }
450                }
451
452                self.solve();
453
454                if let ModelSearchResult::Found(ref reduced_model) = self.last_model {
455                    trace!("Reduced model: {:?}", reduced_model);
456                    model = reduced_model.clone();
457                } else {
458                    break
459                }
460            }
461
462            let min_model = top_model
463                .iter()
464                .map(|lit| Lit::from_var(lit.var(), !self.min_residue.contains(&lit.var())));
465
466            match Solution::from_model(&self.context, min_model) {
467                Ok(solution) => Some(solution),
468                Err(err) => {
469                    warn!("{} in solver's iteration", err);
470                    None
471                }
472            }
473        } else {
474            None
475        }
476    }
477}
478
479impl Iterator for Solver<'_> {
480    type Item = Solution;
481
482    fn next(&mut self) -> Option<Self::Item> {
483        let search = self.context.lock().unwrap().get_search().unwrap_or(Search::MinSolutions);
484
485        match search {
486            Search::MinSolutions => self.next_minimal_solution(),
487            Search::AllSolutions => self.next_solution(),
488        }
489    }
490}
491
492pub struct Solution {
493    context:  ContextHandle,
494    model:    Vec<Lit>,
495    pre_set:  Vec<NodeID>,
496    post_set: Vec<NodeID>,
497    fork_set: Vec<ForkID>,
498    join_set: Vec<JoinID>,
499}
500
501impl Solution {
502    fn new(ctx: &ContextHandle) -> Self {
503        Self {
504            context:  ctx.clone(),
505            model:    Default::default(),
506            pre_set:  Default::default(),
507            post_set: Default::default(),
508            fork_set: Default::default(),
509            join_set: Default::default(),
510        }
511    }
512
513    fn from_model<I: IntoIterator<Item = Lit>>(
514        ctx: &ContextHandle,
515        model: I,
516    ) -> Result<Self, AcesError> {
517        let mut solution = Self::new(ctx);
518
519        let mut pre_set: BTreeSet<NodeID> = BTreeSet::new();
520        let mut post_set: BTreeSet<NodeID> = BTreeSet::new();
521        let mut fork_map: BTreeMap<NodeID, BTreeSet<NodeID>> = BTreeMap::new();
522        let mut join_map: BTreeMap<NodeID, BTreeSet<NodeID>> = BTreeMap::new();
523        let mut fork_set: BTreeSet<ForkID> = BTreeSet::new();
524        let mut join_set: BTreeSet<JoinID> = BTreeSet::new();
525
526        for lit in model {
527            solution.model.push(lit);
528
529            if lit.is_positive() {
530                let (atom_id, _) = lit.into_atom_id();
531                let ctx = solution.context.lock().unwrap();
532
533                if let Some(atom) = ctx.get_atom(atom_id) {
534                    match atom {
535                        Atom::Tx(port) => {
536                            pre_set.insert(port.get_node_id());
537                        }
538                        Atom::Rx(port) => {
539                            post_set.insert(port.get_node_id());
540                        }
541                        Atom::Link(link) => {
542                            let tx_node_id = link.get_tx_node_id();
543                            let rx_node_id = link.get_rx_node_id();
544
545                            fork_map
546                                .entry(tx_node_id)
547                                .or_insert_with(BTreeSet::new)
548                                .insert(rx_node_id);
549                            join_map
550                                .entry(rx_node_id)
551                                .or_insert_with(BTreeSet::new)
552                                .insert(tx_node_id);
553                        }
554                        Atom::Fork(fork) => {
555                            if let Some(fork_id) = fork.get_fork_id() {
556                                pre_set.insert(fork.get_host_id());
557                                fork_set.insert(fork_id);
558                            } else if let Some(join_id) = fork.get_join_id() {
559                                return Err(AcesErrorKind::HarcNotAForkMismatch(join_id)
560                                    .with_context(&solution.context))
561                            } else {
562                                unreachable!()
563                            }
564                        }
565                        Atom::Join(join) => {
566                            if let Some(join_id) = join.get_join_id() {
567                                post_set.insert(join.get_host_id());
568                                join_set.insert(join_id);
569                            } else if let Some(fork_id) = join.get_fork_id() {
570                                return Err(AcesErrorKind::HarcNotAJoinMismatch(fork_id)
571                                    .with_context(&solution.context))
572                            } else {
573                                unreachable!()
574                            }
575                        }
576                        Atom::Bottom => {
577                            return Err(
578                                AcesErrorKind::BottomAtomAccess.with_context(&solution.context)
579                            )
580                        }
581                    }
582                } else {
583                    return Err(
584                        AcesErrorKind::AtomMissingForID(atom_id).with_context(&solution.context)
585                    )
586                }
587            }
588        }
589
590        fork_set.extend(fork_map.into_iter().map(|(host, suit)| {
591            let mut fork = Harc::new_fork_unchecked(host, suit);
592            solution.context.lock().unwrap().share_fork(&mut fork)
593        }));
594
595        join_set.extend(join_map.into_iter().map(|(host, suit)| {
596            let mut join = Harc::new_join_unchecked(host, suit);
597            solution.context.lock().unwrap().share_join(&mut join)
598        }));
599
600        solution.pre_set.extend(pre_set.into_iter());
601        solution.post_set.extend(post_set.into_iter());
602        solution.fork_set.extend(fork_set.into_iter());
603        solution.join_set.extend(join_set.into_iter());
604
605        Ok(solution)
606    }
607
608    pub fn get_context(&self) -> &ContextHandle {
609        &self.context
610    }
611
612    pub fn get_model(&self) -> &[Lit] {
613        self.model.as_slice()
614    }
615
616    pub fn get_pre_set(&self) -> &[NodeID] {
617        self.pre_set.as_slice()
618    }
619
620    pub fn get_post_set(&self) -> &[NodeID] {
621        self.post_set.as_slice()
622    }
623
624    pub fn get_fork_set(&self) -> &[ForkID] {
625        self.fork_set.as_slice()
626    }
627
628    pub fn get_join_set(&self) -> &[JoinID] {
629        self.join_set.as_slice()
630    }
631}
632
633impl fmt::Debug for Solution {
634    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
635        write!(
636            f,
637            "Solution {{ model: {:?}, pre_set: {}, post_set: {}, fork_set: {}, join_set: {} }}",
638            self.model,
639            self.pre_set.with(&self.context),
640            self.post_set.with(&self.context),
641            self.fork_set.with(&self.context),
642            self.join_set.with(&self.context),
643        )
644    }
645}
646
647impl fmt::Display for Solution {
648    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
649        if self.pre_set.is_empty() {
650            write!(f, "{{}} => {{")?;
651        } else {
652            write!(f, "{{")?;
653
654            for node_id in self.pre_set.iter() {
655                write!(f, " {}", node_id.with(&self.context))?;
656            }
657
658            write!(f, " }} => {{")?;
659        }
660
661        if self.post_set.is_empty() {
662            write!(f, "}}")?;
663        } else {
664            for node_id in self.post_set.iter() {
665                write!(f, " {}", node_id.with(&self.context))?;
666            }
667
668            write!(f, " }}")?;
669        }
670
671        Ok(())
672    }
673}