logic_form/dagcnf/
simplify.rs

1use super::{DagCnf, occur::Occurs};
2use crate::{LitMap, LitOrdVec, LitVec, LitVvec, Var, lemmas_subsume_simplify};
3use giputils::{allocator::Gallocator, grc::Grc, hash::GHashSet, heap::BinaryHeap};
4use std::iter::once;
5
6pub struct DagCnfSimplify {
7    cdb: Grc<Gallocator<LitOrdVec>>,
8    max_var: Var,
9    cnf: LitMap<Vec<u32>>,
10    occur: Grc<Occurs>,
11    frozen: GHashSet<Var>,
12    qbve: BinaryHeap<Var, Occurs>,
13}
14
15impl DagCnfSimplify {
16    pub fn new(dagcnf: &DagCnf) -> Self {
17        let cdb = Grc::new(Gallocator::new());
18        let max_var = dagcnf.max_var;
19        let occur = Grc::new(Occurs::new_with(max_var, cdb.clone()));
20        let cnf = LitMap::new_with(max_var);
21        let qbve = BinaryHeap::new(occur.clone());
22        let mut res = Self {
23            cdb,
24            occur,
25            max_var,
26            cnf,
27            frozen: Default::default(),
28            qbve,
29        };
30        for v in Var::CONST..=max_var {
31            for mut cls in dagcnf.cnf[v].clone() {
32                cls.cls_simp();
33                if cls.is_empty() {
34                    continue;
35                }
36                assert!(cls.last().var().eq(&v));
37                res.add_rel(cls);
38            }
39        }
40        for v in Var::CONST..=max_var {
41            res.qbve.push(v);
42        }
43        res
44    }
45
46    pub fn froze(&mut self, v: Var) {
47        self.frozen.insert(v);
48    }
49
50    fn add_rel(&mut self, rel: LitVec) {
51        let rel = LitOrdVec::new(rel);
52        let n = rel.last();
53        let relid = self.cdb.alloc(rel);
54        self.cnf[n].push(relid);
55        for &l in self.cdb[relid].iter() {
56            let lv = l.var();
57            if lv != n.var() {
58                self.occur.add(l, relid);
59                self.qbve.down(lv);
60            }
61        }
62    }
63
64    fn remove_rel(&mut self, rels: Vec<u32>) {
65        let relset = GHashSet::from_iter(rels.iter().copied());
66        let outs = GHashSet::from_iter(rels.iter().map(|&cls| self.cdb[cls].last()));
67        for o in outs {
68            let mut i = 0;
69            while i < self.cnf[o].len() {
70                if relset.contains(&self.cnf[o][i]) {
71                    let cls = self.cnf[o].swap_remove(i);
72                    for &l in self.cdb[cls].iter() {
73                        let lv = l.var();
74                        if lv != o.var() {
75                            self.occur.del(l, cls);
76                        }
77                    }
78                    self.cdb.dealloc(cls);
79                } else {
80                    i += 1;
81                }
82            }
83        }
84    }
85
86    #[inline]
87    fn remove_node(&mut self, n: Var) {
88        let ln = n.lit();
89        assert!(self.occur.num_occur(ln) == 0);
90        assert!(self.occur.num_occur(!ln) == 0);
91        for &cls in self.cnf[ln].iter().chain(self.cnf[!ln].iter()) {
92            for &l in self.cdb[cls].iter() {
93                let lv = l.var();
94                if lv != n {
95                    self.occur.del(l, cls);
96                    self.qbve.up(lv);
97                }
98            }
99            self.cdb.dealloc(cls);
100        }
101        self.cnf[ln].clear();
102        self.cnf[!ln].clear();
103    }
104
105    fn resolvent(&self, pcnf: &[u32], ncnf: &[u32], pivot: Var, limit: usize) -> Option<LitVvec> {
106        let mut res = LitVvec::new();
107        for &pcls in pcnf {
108            for &ncls in ncnf {
109                if let Some(resolvent) = self.cdb[pcls].ordered_resolvent(&self.cdb[ncls], pivot) {
110                    res.push(resolvent);
111                }
112                if res.len() > limit {
113                    return None;
114                }
115            }
116        }
117        Some(res)
118    }
119
120    fn eliminate(&mut self, v: Var) {
121        if self.frozen.contains(&v) {
122            return;
123        }
124        let lv = v.lit();
125        let ocost = self.occur.num_occur(lv)
126            + self.occur.num_occur(!lv)
127            + self.cnf[lv].len()
128            + self.cnf[!lv].len();
129        if ocost == 0 || ocost > 2000 {
130            return;
131        }
132        let (pos, neg) = (self.cnf[lv].clone(), self.cnf[!lv].clone());
133        let mut ncost = 0;
134        let mut opos = self.occur.get(lv).to_vec();
135        let oneg = self.occur.get(!lv).to_vec();
136        let Some(respn) = self.resolvent(&pos, &oneg, v, ocost - ncost) else {
137            return;
138        };
139        ncost += respn.len();
140        if ncost > ocost {
141            return;
142        }
143        let Some(resnp) = self.resolvent(&neg, &opos, v, ocost - ncost) else {
144            return;
145        };
146        ncost += resnp.len();
147        if ncost > ocost {
148            return;
149        }
150        let mut res = respn;
151        res.extend(resnp);
152        let res = clause_subsume_simplify(res);
153        opos.extend(oneg);
154        self.remove_rel(opos);
155        self.remove_node(v);
156        for r in res {
157            self.add_rel(r);
158        }
159    }
160
161    pub fn bve_simplify(&mut self) {
162        while let Some(v) = self.qbve.pop() {
163            self.eliminate(v);
164        }
165    }
166
167    fn cls_subsume_check(&mut self, ci: u32) {
168        if self.cdb.is_removed(ci) {
169            return;
170        }
171        let best_lit = *self.cdb[ci]
172            .iter()
173            .min_by_key(|&&l| self.occur.num_occur(l) + self.cnf[l].len())
174            .unwrap();
175        let mut occur = self.occur.get(best_lit).to_vec();
176        occur.extend(self.cnf[best_lit].iter());
177        for cj in occur {
178            if self.cdb.is_removed(cj) {
179                continue;
180            }
181            if cj == ci {
182                continue;
183            }
184            let (res, diff) = self.cdb[ci].subsume_execpt_one(&self.cdb[cj]);
185            if res {
186                self.cnf[self.cdb[cj].last()].retain(|&c| c != cj);
187                self.cdb.dealloc(cj);
188                continue;
189            } else if let Some(diff) = diff {
190                if self.cdb[ci].len() == self.cdb[cj].len() {
191                    let mut cube = self.cdb[ci].cube().clone();
192                    cube.retain(|l| *l != diff);
193                    assert!(cube.last() == self.cdb[ci].last());
194                    self.cdb[ci] = LitOrdVec::new(cube);
195                    self.cnf[self.cdb[cj].last()].retain(|&c| c != cj);
196                    self.cdb.dealloc(cj);
197                } else {
198                    let mut cube = self.cdb[cj].cube().clone();
199                    assert!(cube.last() == self.cdb[cj].last());
200                    cube.retain(|l| *l != !diff);
201                    self.cdb[cj] = LitOrdVec::new(cube);
202                }
203            }
204        }
205    }
206
207    pub fn subsume_simplify(&mut self) {
208        for v in Var::CONST..=self.max_var {
209            for cls in self.cnf[v.lit()].clone() {
210                self.cls_subsume_check(cls);
211            }
212            for cls in self.cnf[!v.lit()].clone() {
213                self.cls_subsume_check(cls);
214            }
215        }
216    }
217
218    pub fn simplify(&mut self) -> DagCnf {
219        self.bve_simplify();
220        self.subsume_simplify();
221        let mut dagcnf = DagCnf::new();
222        dagcnf.new_var_to(self.max_var);
223        for v in Var(1)..=self.max_var {
224            let cnf: Vec<_> = self.cnf[v.lit()]
225                .iter()
226                .chain(self.cnf[!v.lit()].iter())
227                .filter(|&&cls| !self.cdb.is_removed(cls))
228                .map(|&cls| self.cdb[cls].cube().clone())
229                .collect();
230            assert!(cnf.iter().all(|cls| cls.last().var().eq(&v)));
231            dagcnf.add_rel(v, &cnf);
232        }
233        dagcnf
234    }
235}
236
237fn clause_subsume_simplify(lemmas: LitVvec) -> LitVvec {
238    let lemmas: Vec<LitOrdVec> = lemmas.into_iter().map(LitOrdVec::new).collect();
239    let lemmas = lemmas_subsume_simplify(lemmas);
240    lemmas
241        .into_iter()
242        .map(|l| LitVec::from(l.cube().as_slice()))
243        .collect()
244}
245
246impl DagCnf {
247    pub fn simplify(&self, frozen: impl Iterator<Item = Var>) -> Self {
248        let mut simp = DagCnfSimplify::new(self);
249        for v in frozen.chain(once(Var::CONST)) {
250            simp.froze(v);
251        }
252        simp.simplify()
253    }
254}