rustsat/encodings/pb/
gte.rs

1//! # Generalized Totalizer Encoding
2//!
//! Implementation of the binary adder tree generalized totalizer encoding
//! \[1\]. The implementation is incremental. The implementation is recursive.
//! This encoding only supports upper bounding. Lower bounding can be achieved by
//! negating the input literals. This is implemented in
//! [`super::simulators::Inverted`].
//! The implementation is based on a node database.
9//!
10//! # References
11//!
12//! - \[1\] Saurabh Joshi and Ruben Martins and Vasco Manquinho: _Generalized
13//!   Totalizer Encoding for Pseudo-Boolean Constraints_, CP 2015.
14
15use std::ops::RangeBounds;
16
17use crate::{
18    encodings::{
19        nodedb::{NodeById, NodeCon, NodeLike},
20        totdb, CollectClauses, EncodeStats, EnforceError, Monotone,
21    },
22    instances::ManageVars,
23    types::{Lit, RsHashMap},
24};
25
26use super::{BoundUpper, BoundUpperIncremental, Encode, EncodeIncremental};
27
28/// Implementation of the binary adder tree generalized totalizer encoding
29/// \[1\]. The implementation is incremental. The implementation is recursive.
30/// This encoding only support upper bounding. Lower bounding can be achieved by
31/// negating the input literals. This is implemented in
32/// [`super::simulators::Inverted`].
33/// The implementation is based on a node database.
34///
35/// # References
36///
37/// - \[1\] Saurabh Joshi and Ruben Martins and Vasco Manquinho: _Generalized
38///   Totalizer Encoding for Pseudo-Boolean Constraints_, CP 2015.
39#[derive(Default, Debug)]
40#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
41pub struct GeneralizedTotalizer {
42    /// Input literals and weights not yet in the tree
43    lit_buffer: RsHashMap<Lit, usize>,
44    /// The root of the tree, if constructed
45    root: Option<NodeCon>,
46    /// Maximum weight of a leaf, needed for computing how much more than
47    /// `max_rhs` to encode
48    max_leaf_weight: usize,
49    /// The number of variables in the totalizer
50    n_vars: u32,
51    /// The number of clauses in the totalizer
52    n_clauses: usize,
53    /// The node database of the totalizer
54    db: totdb::Db,
55}
56
57impl GeneralizedTotalizer {
58    /// Creates a generalized totalizer from its internal parts
59    #[cfg(feature = "_internals")]
60    #[must_use]
61    pub fn from_raw(root: NodeCon, db: totdb::Db, max_leaf_weight: usize) -> Self {
62        Self {
63            root: Some(root),
64            max_leaf_weight,
65            db,
66            ..Default::default()
67        }
68    }
69
70    fn extend_tree(&mut self, max_weight: usize) {
71        if !self.lit_buffer.is_empty() {
72            let mut new_lits: Vec<(Lit, usize)> = self
73                .lit_buffer
74                .iter()
75                .filter_map(|(&l, &w)| {
76                    if w <= max_weight {
77                        if w > self.max_leaf_weight {
78                            self.max_leaf_weight = w;
79                        }
80                        Some((l, w))
81                    } else {
82                        None
83                    }
84                })
85                .collect();
86            if !new_lits.is_empty() {
87                // add nodes in sorted fashion to minimize clauses
88                new_lits.sort_by_key(|(_, w)| *w);
89                // Detect sequences of literals of equal weight and merge them
90                let mut seg_begin = 0;
91                let mut seg_end = 0;
92                let mut cons = vec![];
93                loop {
94                    seg_end += 1;
95                    if seg_end < new_lits.len() && new_lits[seg_end].1 == new_lits[seg_begin].1 {
96                        continue;
97                    }
98                    // merge lits of equal weight
99                    let seg: Vec<_> = new_lits[seg_begin..seg_end]
100                        .iter()
101                        .map(|(lit, _)| *lit)
102                        .collect();
103                    let id = self.db.lit_tree(&seg);
104                    cons.push(NodeCon::weighted(id, new_lits[seg_begin].1));
105                    seg_begin = seg_end;
106                    if seg_end >= new_lits.len() {
107                        break;
108                    }
109                }
110                if let Some(con) = self.root {
111                    cons.push(con);
112                }
113                self.root = Some(self.db.merge_balanced(&cons));
114                self.lit_buffer.retain(|_, w| *w > max_weight);
115            }
116        }
117    }
118
119    /// Gets the depth of the encoding, i.e., the longest path from the root to a leaf
120    #[must_use]
121    pub fn depth(&self) -> usize {
122        self.root.map_or(0, |con| self.db[con.id].depth())
123    }
124
125    /// Gets the details of a generalized totalizer output related to proof logging
126    ///
127    /// # Errors
128    ///
129    /// If the requested output is not encoded
130    #[cfg(feature = "proof-logging")]
131    pub fn output_proof_details(
132        &self,
133        value: usize,
134    ) -> Result<(Lit, totdb::cert::SemDefs), crate::encodings::NotEncoded> {
135        match self.root {
136            None => Err(crate::encodings::NotEncoded),
137            Some(root) => {
138                if !root.is_possible(value) {
139                    return Err(crate::encodings::NotEncoded);
140                }
141                self.db
142                    .get_semantics(root.id, root.offset, root.rev_map(value))
143                    .map(|sem| (self.db[root.id][root.rev_map(value)], sem))
144                    .ok_or(crate::encodings::NotEncoded)
145            }
146        }
147    }
148
149    /// Gets the number of output literals in the generalized totalizer
150    #[must_use]
151    pub fn n_output_lits(&self) -> usize {
152        match self.root {
153            Some(root) => self.db[root.id].len(),
154            None => 0,
155        }
156    }
157
158    /// Checks if the input literal buffer is empty, i.e., all input literals are included in the
159    /// encoding.
160    ///
161    /// Even after encodings, this might not be the case, if an input literal has a higher weight
162    /// than the bound encoded for
163    #[must_use]
164    pub fn is_buffer_empty(&self) -> bool {
165        self.lit_buffer.is_empty()
166    }
167
168    /// From an assignment to the input literals, generates an assignment over the totalizer
169    /// variables following strict semantics, i.e., `sum >= k <-> olit`
170    ///
171    /// # Panics
172    ///
173    /// If `assign` does not assign all input literals
174    pub fn strictly_extend_assignment<'slf>(
175        &'slf self,
176        assign: &'slf crate::types::Assignment,
177    ) -> std::iter::Flatten<std::option::IntoIter<totdb::AssignIter<'slf>>> {
178        self.root
179            .map(|root| self.db.strictly_extend_assignment(root.id, assign))
180            .into_iter()
181            .flatten()
182    }
183}
184
185impl Encode for GeneralizedTotalizer {
186    fn weight_sum(&self) -> usize {
187        self.lit_buffer.iter().fold(0, |sum, (_, w)| sum + w)
188            + if let Some(root) = self.root {
189                root.map(self.db[root.id].max_val())
190            } else {
191                0
192            }
193    }
194
195    fn next_higher(&self, val: usize) -> usize {
196        if let Some(con) = self.root {
197            self.db[con.id]
198                .vals(con.rev_map_round_up(val + 1)..)
199                .next()
200                .map_or(val + 1, |val| con.map(val))
201        } else {
202            val + 1
203        }
204    }
205
206    fn next_lower(&self, val: usize) -> usize {
207        if val == 0 {
208            return 0;
209        }
210        if let Some(con) = self.root {
211            return self.db[con.id]
212                .vals(con.offset()..con.rev_map_round_up(val))
213                .next_back()
214                .map_or(0, |val| con.map(val));
215        }
216        val - 1
217    }
218}
219
220impl EncodeIncremental for GeneralizedTotalizer {
221    fn reserve(&mut self, var_manager: &mut dyn ManageVars) {
222        self.extend_tree(usize::MAX);
223        if let Some(con) = self.root {
224            self.db.reserve_vars(con, var_manager);
225        }
226    }
227}
228
229impl BoundUpper for GeneralizedTotalizer {
230    fn encode_ub<Col, R>(
231        &mut self,
232        range: R,
233        collector: &mut Col,
234        var_manager: &mut dyn ManageVars,
235    ) -> Result<(), crate::OutOfMemory>
236    where
237        Col: CollectClauses,
238        R: RangeBounds<usize>,
239    {
240        self.db.reset_encoded(totdb::Semantics::If);
241        self.encode_ub_change(range, collector, var_manager)
242    }
243
244    fn enforce_ub(&self, ub: usize) -> Result<Vec<Lit>, EnforceError> {
245        if ub >= self.weight_sum() {
246            return Ok(vec![]);
247        }
248
249        let mut assumps = vec![];
250        self.lit_buffer.iter().try_for_each(|(&l, &w)| {
251            if w <= ub {
252                Err(EnforceError::NotEncoded)
253            } else {
254                assumps.push(!l);
255                Ok(())
256            }
257        })?;
258        // Enforce bound on internal tree
259        if let Some(con) = self.root {
260            self.db[con.id]
261                .vals(con.rev_map_round_up(ub + 1)..=con.rev_map(ub + self.max_leaf_weight))
262                .try_for_each(|val| {
263                    match &self.db[con.id] {
264                        totdb::Node::Leaf(lit) => {
265                            assumps.push(!*lit);
266                            return Ok(());
267                        }
268                        totdb::Node::Unit(node) => {
269                            if let totdb::LitData::Lit {
270                                lit,
271                                semantics: Some(semantics),
272                            } = node.lits[val - 1]
273                            {
274                                if semantics.has_if() {
275                                    assumps.push(!lit);
276                                    return Ok(());
277                                }
278                            }
279                        }
280                        totdb::Node::General(node) => {
281                            if let Some(totdb::LitData::Lit {
282                                lit,
283                                semantics: Some(semantics),
284                            }) = node.lit_data(val)
285                            {
286                                if semantics.has_if() {
287                                    assumps.push(!lit);
288                                    return Ok(());
289                                }
290                            }
291                        }
292                        totdb::Node::Dummy => unreachable!(),
293                    }
294                    Err(EnforceError::NotEncoded)
295                })?;
296        }
297        Ok(assumps)
298    }
299}
300
301impl BoundUpperIncremental for GeneralizedTotalizer {
302    fn encode_ub_change<Col, R>(
303        &mut self,
304        range: R,
305        collector: &mut Col,
306        var_manager: &mut dyn ManageVars,
307    ) -> Result<(), crate::OutOfMemory>
308    where
309        Col: CollectClauses,
310        R: RangeBounds<usize>,
311    {
312        let range = super::prepare_ub_range(self, range);
313        if range.is_empty() {
314            return Ok(());
315        }
316        let n_vars_before = var_manager.n_used();
317        let n_clauses_before = collector.n_clauses();
318        self.extend_tree(range.end - 1);
319        if let Some(con) = self.root {
320            self.db[con.id]
321                .vals(
322                    con.rev_map_round_up(range.start + 1)
323                        ..=con.rev_map(range.end + self.max_leaf_weight),
324                )
325                .try_for_each(|val| {
326                    self.db
327                        .define_weighted(con.id, val, collector, var_manager)?
328                        .unwrap();
329                    Ok::<(), crate::OutOfMemory>(())
330                })?;
331        }
332        self.n_clauses += collector.n_clauses() - n_clauses_before;
333        self.n_vars += var_manager.n_used() - n_vars_before;
334        Ok(())
335    }
336}
337
338impl Monotone for GeneralizedTotalizer {}
339
340impl EncodeStats for GeneralizedTotalizer {
341    fn n_clauses(&self) -> usize {
342        self.n_clauses
343    }
344
345    fn n_vars(&self) -> u32 {
346        self.n_vars
347    }
348}
349
350impl From<RsHashMap<Lit, usize>> for GeneralizedTotalizer {
351    fn from(lits: RsHashMap<Lit, usize>) -> Self {
352        Self {
353            lit_buffer: lits,
354            ..Default::default()
355        }
356    }
357}
358
359impl FromIterator<(Lit, usize)> for GeneralizedTotalizer {
360    fn from_iter<T: IntoIterator<Item = (Lit, usize)>>(iter: T) -> Self {
361        let lits: RsHashMap<Lit, usize> = RsHashMap::from_iter(iter);
362        Self::from(lits)
363    }
364}
365
366impl Extend<(Lit, usize)> for GeneralizedTotalizer {
367    fn extend<T: IntoIterator<Item = (Lit, usize)>>(&mut self, iter: T) {
368        iter.into_iter().for_each(|(l, w)| {
369            // Insert into buffer to be added to tree
370            match self.lit_buffer.get_mut(&l) {
371                Some(old_w) => *old_w += w,
372                None => {
373                    self.lit_buffer.insert(l, w);
374                }
375            }
376        });
377    }
378}
379
#[cfg(feature = "proof-logging")]
impl super::cert::BoundUpper for GeneralizedTotalizer {
    /// Proof-logging counterpart of the non-incremental `encode_ub`.
    ///
    /// It resets the database's record of encoded `If` semantics. It then encodes `range`
    /// incrementally via `encode_ub_change_cert`, logging derivations to `proof`.
    fn encode_ub_cert<Col, R, W>(
        &mut self,
        range: R,
        collector: &mut Col,
        var_manager: &mut dyn ManageVars,
        proof: &mut pigeons::Proof<W>,
    ) -> Result<(), crate::encodings::cert::EncodingError>
    where
        Col: crate::encodings::cert::CollectClauses,
        R: RangeBounds<usize>,
        W: std::io::Write,
    {
        use super::cert::BoundUpperIncremental;
        self.db.reset_encoded(totdb::Semantics::If);
        self.encode_ub_change_cert(range, collector, var_manager, proof)
    }

    /// Encodes a single PB upper-bound constraint (with proof id `constr.1`) with proof
    /// logging.
    ///
    /// The steps are:
    /// - A temporary encoding is built for exactly the bound `ub`.
    /// - Every unit clause returned by `enforce_ub` is derived in the proof and added to
    ///   `collector`.
    /// - The semantic definitions of the temporary encoding are deleted from the proof.
    ///
    /// A constraint whose bound exceeds the sum of all weights adds nothing.
    ///
    /// # Errors
    ///
    /// [`crate::encodings::cert::ConstraintEncodingError::Unsat`] if the bound is negative.
    /// Otherwise, errors from encoding, proof writing, or the collector.
    fn encode_ub_constr_cert<Col, W>(
        constr: (
            crate::types::constraints::PbUbConstr,
            pigeons::AbsConstraintId,
        ),
        collector: &mut Col,
        var_manager: &mut dyn ManageVars,
        proof: &mut pigeons::Proof<W>,
    ) -> Result<(), crate::encodings::cert::ConstraintEncodingError>
    where
        Col: crate::encodings::cert::CollectClauses,
        W: std::io::Write,
        Self: FromIterator<(Lit, usize)> + Sized,
    {
        use pigeons::{OperationSequence, VarLike};

        use crate::types::Var;

        // TODO: properly take care of constraints where no structure is built

        // `id` refers to the (progressively simplified) main constraint in the proof
        let (constr, mut id) = constr;
        let (lits, ub) = constr.decompose();
        if ub < 0 {
            return Err(crate::encodings::cert::ConstraintEncodingError::Unsat);
        }
        let ub = ub.unsigned_abs();
        // Bound larger than the weight sum: the constraint is trivially satisfied
        if ub > lits.iter().fold(0, |sum, (_, w)| sum + *w) {
            return Ok(());
        }
        let mut enc = Self::from_iter(lits);
        enc.encode_ub_cert(ub..=ub, collector, var_manager, proof)?;
        // Smallest value above the bound. Its output literal is the first one to be
        // forced to false.
        let mut val = enc.next_higher(ub);
        for unit in enc
            .enforce_ub(ub)
            .expect("should have caught special case before here")
        {
            let (olit, sem_defs) = enc
                .output_proof_details(val)
                .expect("encoded just before, so should be fine");
            let unit_cl = crate::clause![unit];
            // NOTE(review): input and output literals are told apart by variable index. This
            // presumably relies on input variables being allocated before the encoding's
            // output variables — confirm.
            let unit_id = if unit.var() < olit.var() {
                // input literal with weight larger than bound
                let weight = *enc.lit_buffer.get(&!unit).unwrap();
                let unit_id = proof.reverse_unit_prop(&unit_cl, [id.into()])?;
                // simplify main constraint
                #[cfg(feature = "verbose-proofs")]
                proof.comment(&"rewritten main constraint")?;
                id = proof.operations(
                    &(OperationSequence::<Var>::from(unit.var().axiom(!unit.is_neg())) * weight
                        + id),
                )?;
                unit_id
            } else {
                // output literal
                // NOTE: by the time we're here, all buffered literals have been removed from `id`
                debug_assert_eq!(!unit, olit);
                // Main constraint plus the `only if` definition of the output, divided by
                // `val`, yields the unit clause
                let unit_id = proof.operations(
                    &((OperationSequence::<Var>::from(id) + sem_defs.only_if_def.unwrap()) / val),
                )?;
                #[cfg(feature = "verbose-proofs")]
                proof.equals(&unit_cl, Some(unit_id.into()))?;
                // The next output unit corresponds to the next higher value
                val = enc.next_higher(val);
                unit_id
            };
            collector.add_cert_clause(unit_cl, unit_id)?;
        }
        // The temporary encoding is discarded, so its semantic definitions are removed
        enc.db.delete_semantics(proof)?;
        Ok(())
    }
}
469
#[cfg(feature = "proof-logging")]
impl super::cert::BoundUpperIncremental for GeneralizedTotalizer {
    /// Proof-logging counterpart of `encode_ub_change`. It incrementally encodes the upper
    /// bounds in `range`, logs the derivations to `proof`, and updates the statistics.
    fn encode_ub_change_cert<Col, R, W>(
        &mut self,
        range: R,
        collector: &mut Col,
        var_manager: &mut dyn ManageVars,
        proof: &mut pigeons::Proof<W>,
    ) -> Result<(), crate::encodings::cert::EncodingError>
    where
        Col: crate::encodings::cert::CollectClauses,
        R: RangeBounds<usize>,
        W: std::io::Write,
    {
        let range = super::prepare_ub_range(self, range);
        if range.is_empty() {
            return Ok(());
        }
        let (vars_before, clauses_before) = (var_manager.n_used(), collector.n_clauses());

        // Only literals that fit within the largest requested bound join the tree
        self.extend_tree(range.end - 1);
        if let Some(con) = self.root {
            // Scratch buffer with one entry per leaf, shared by all `define_weighted_cert`
            // calls. The returned flag tells whether it has been initialized.
            let mut leaves = vec![(crate::lit![0], 0); self.db[con.id].n_leaves()];
            let mut leaves_init = false;
            let first = con.rev_map_round_up(range.start + 1);
            let last = con.rev_map(range.end + self.max_leaf_weight);
            let vals = self.db[con.id].vals(first..=last);
            for val in vals {
                let (_, now_init) = self
                    .db
                    .define_weighted_cert(
                        con.id,
                        val,
                        collector,
                        var_manager,
                        proof,
                        (&mut leaves, leaves_init, false),
                    )?
                    .unwrap();
                leaves_init = now_init;
            }
        }

        self.n_clauses += collector.n_clauses() - clauses_before;
        self.n_vars += var_manager.n_used() - vars_before;
        Ok(())
    }
}
519
/// Generalized totalizer encoding types that do not own but reference their [`totdb::Db`]
#[cfg(feature = "_internals")]
pub mod referenced {
    use std::{cell::RefCell, ops::RangeBounds};

    use crate::{
        encodings::{
            nodedb::{NodeCon, NodeLike},
            pb::{BoundUpper, BoundUpperIncremental, Encode, EncodeIncremental},
            totdb, CollectClauses, EnforceError,
        },
        instances::ManageVars,
        types::Lit,
    };

    /// Generalized totalizer encoding \[1\] with a _mutable reference_ to a totalizer
    /// database rather than owning it.
    ///
    /// Unlike [`super::GeneralizedTotalizer`], this type has no input literal buffer. The
    /// encoding consists of the tree below `root` in the referenced database.
    ///
    /// ## References
    ///
    /// - \[1\] Saurabh Joshi and Ruben Martins and Vasco Manquinho: _Generalized
    ///   Totalizer Encoding for Pseudo-Boolean Constraints_, CP 2015.
    #[derive(Debug)]
    pub struct Gte<'totdb> {
        /// A node connection to the root
        root: NodeCon,
        /// The maximum weight of any leaf
        max_leaf_weight: usize,
        /// The node database of the totalizer
        db: &'totdb mut totdb::Db,
    }

    /// Generalized totalizer encoding \[1\] with a [`RefCell`] to a totalizer
    /// database rather than owning it.
    ///
    /// Unlike [`super::GeneralizedTotalizer`], this type has no input literal buffer. The
    /// encoding consists of the tree below `root` in the referenced database.
    ///
    /// ## References
    ///
    /// - \[1\] Saurabh Joshi and Ruben Martins and Vasco Manquinho: _Generalized
    ///   Totalizer Encoding for Pseudo-Boolean Constraints_, CP 2015.
    #[derive(Debug)]
    pub struct GteCell<'totdb> {
        /// A node connection to the root
        root: NodeCon,
        /// The maximum weight of any leaf
        max_leaf_weight: usize,
        /// The node database of the totalizer
        db: &'totdb RefCell<&'totdb mut totdb::Db>,
    }

    impl<'totdb> Gte<'totdb> {
        /// Constructs a new GTE encoding referencing a totalizer database
        #[must_use]
        pub fn new(root: NodeCon, max_leaf_weight: usize, db: &'totdb mut totdb::Db) -> Self {
            Self {
                root,
                max_leaf_weight,
                db,
            }
        }

        /// Gets the maximum depth of the tree
        #[must_use]
        pub fn depth(&self) -> usize {
            self.db[self.root.id].depth()
        }
    }

    impl<'totdb> GteCell<'totdb> {
        /// Constructs a new GTE encoding referencing a totalizer database
        #[must_use]
        pub fn new(
            root: NodeCon,
            max_leaf_weight: usize,
            db: &'totdb RefCell<&'totdb mut totdb::Db>,
        ) -> Self {
            Self {
                root,
                max_leaf_weight,
                db,
            }
        }

        /// Gets the maximum depth of the tree
        #[must_use]
        pub fn depth(&self) -> usize {
            self.db.borrow()[self.root.id].depth()
        }
    }

    impl Encode for Gte<'_> {
        fn weight_sum(&self) -> usize {
            self.root.map(self.db[self.root.id].max_val())
        }

        fn next_higher(&self, val: usize) -> usize {
            self.db[self.root.id]
                .vals(self.root.rev_map_round_up(val + 1)..)
                .next()
                .map_or(val + 1, |val| self.root.map(val))
        }

        fn next_lower(&self, val: usize) -> usize {
            // Nothing is below zero. This also prevents the eager `val - 1` below from
            // underflowing, and matches the owning `GeneralizedTotalizer`.
            if val == 0 {
                return 0;
            }
            self.db[self.root.id]
                .vals(self.root.offset()..self.root.rev_map_round_up(val))
                .next_back()
                .map_or(val - 1, |val| self.root.map(val))
        }
    }

    impl Encode for GteCell<'_> {
        fn weight_sum(&self) -> usize {
            self.root.map(self.db.borrow()[self.root.id].max_val())
        }

        fn next_higher(&self, val: usize) -> usize {
            self.db.borrow()[self.root.id]
                .vals(self.root.rev_map_round_up(val + 1)..)
                .next()
                .map_or(val + 1, |val| self.root.map(val))
        }

        fn next_lower(&self, val: usize) -> usize {
            // Nothing is below zero. This also prevents the eager `val - 1` below from
            // underflowing, and matches the owning `GeneralizedTotalizer`.
            if val == 0 {
                return 0;
            }
            self.db.borrow()[self.root.id]
                .vals(self.root.offset()..self.root.rev_map_round_up(val))
                .next_back()
                .map_or(val - 1, |val| self.root.map(val))
        }
    }

    impl EncodeIncremental for Gte<'_> {
        fn reserve(&mut self, var_manager: &mut dyn ManageVars) {
            self.db.reserve_vars(self.root, var_manager);
        }
    }

    impl EncodeIncremental for GteCell<'_> {
        fn reserve(&mut self, var_manager: &mut dyn ManageVars) {
            self.db.borrow_mut().reserve_vars(self.root, var_manager);
        }
    }

    impl BoundUpper for Gte<'_> {
        fn encode_ub<Col, R>(
            &mut self,
            range: R,
            collector: &mut Col,
            var_manager: &mut dyn ManageVars,
        ) -> Result<(), crate::OutOfMemory>
        where
            Col: CollectClauses,
            R: RangeBounds<usize>,
        {
            self.db.reset_encoded(totdb::Semantics::If);
            self.encode_ub_change(range, collector, var_manager)
        }

        fn enforce_ub(&self, ub: usize) -> Result<Vec<Lit>, EnforceError> {
            if ub >= self.weight_sum() {
                return Ok(vec![]);
            }

            let mut assumps = vec![];
            // Enforce bound on internal tree
            self.db[self.root.id]
                .vals(
                    self.root.rev_map_round_up(ub + 1)
                        ..=self.root.rev_map(ub + self.max_leaf_weight),
                )
                .try_for_each(|val| {
                    match &self.db[self.root.id] {
                        totdb::Node::Leaf(lit) => {
                            assumps.push(!*lit);
                            return Ok(());
                        }
                        totdb::Node::Unit(node) => {
                            if let totdb::LitData::Lit {
                                lit,
                                semantics: Some(semantics),
                            } = node.lits[val - 1]
                            {
                                if semantics.has_if() {
                                    assumps.push(!lit);
                                    return Ok(());
                                }
                            }
                        }
                        totdb::Node::General(node) => {
                            if let Some(totdb::LitData::Lit {
                                lit,
                                semantics: Some(semantics),
                            }) = node.lit_data(val)
                            {
                                if semantics.has_if() {
                                    assumps.push(!lit);
                                    return Ok(());
                                }
                            }
                        }
                        totdb::Node::Dummy => unreachable!(),
                    }
                    Err(EnforceError::NotEncoded)
                })?;
            Ok(assumps)
        }
    }

    impl BoundUpper for GteCell<'_> {
        fn encode_ub<Col, R>(
            &mut self,
            range: R,
            collector: &mut Col,
            var_manager: &mut dyn ManageVars,
        ) -> Result<(), crate::OutOfMemory>
        where
            Col: CollectClauses,
            R: RangeBounds<usize>,
        {
            self.db.borrow_mut().reset_encoded(totdb::Semantics::If);
            self.encode_ub_change(range, collector, var_manager)
        }

        fn enforce_ub(&self, ub: usize) -> Result<Vec<Lit>, EnforceError> {
            if ub >= self.weight_sum() {
                return Ok(vec![]);
            }

            let mut assumps = vec![];
            // Enforce bound on internal tree
            self.db.borrow()[self.root.id]
                .vals(
                    self.root.rev_map_round_up(ub + 1)
                        ..=self.root.rev_map(ub + self.max_leaf_weight),
                )
                .try_for_each(|val| {
                    match &self.db.borrow()[self.root.id] {
                        totdb::Node::Leaf(lit) => {
                            assumps.push(!*lit);
                            return Ok(());
                        }
                        totdb::Node::Unit(node) => {
                            if let totdb::LitData::Lit {
                                lit,
                                semantics: Some(semantics),
                            } = node.lits[val - 1]
                            {
                                if semantics.has_if() {
                                    assumps.push(!lit);
                                    return Ok(());
                                }
                            }
                        }
                        totdb::Node::General(node) => {
                            if let Some(totdb::LitData::Lit {
                                lit,
                                semantics: Some(semantics),
                            }) = node.lit_data(val)
                            {
                                if semantics.has_if() {
                                    assumps.push(!lit);
                                    return Ok(());
                                }
                            }
                        }
                        totdb::Node::Dummy => unreachable!(),
                    }
                    Err(EnforceError::NotEncoded)
                })?;
            Ok(assumps)
        }
    }

    impl BoundUpperIncremental for Gte<'_> {
        fn encode_ub_change<Col, R>(
            &mut self,
            range: R,
            collector: &mut Col,
            var_manager: &mut dyn ManageVars,
        ) -> Result<(), crate::OutOfMemory>
        where
            Col: CollectClauses,
            R: RangeBounds<usize>,
        {
            let range = super::super::prepare_ub_range(self, range);
            if range.is_empty() {
                return Ok(());
            }
            self.db[self.root.id]
                .vals(
                    self.root.rev_map_round_up(range.start + 1)
                        ..=self.root.rev_map(range.end + self.max_leaf_weight),
                )
                .try_for_each(|val| {
                    self.db
                        .define_weighted(self.root.id, val, collector, var_manager)?
                        .unwrap();
                    Ok::<(), crate::OutOfMemory>(())
                })?;
            Ok(())
        }
    }

    impl BoundUpperIncremental for GteCell<'_> {
        fn encode_ub_change<Col, R>(
            &mut self,
            range: R,
            collector: &mut Col,
            var_manager: &mut dyn ManageVars,
        ) -> Result<(), crate::OutOfMemory>
        where
            Col: CollectClauses,
            R: RangeBounds<usize>,
        {
            let range = super::super::prepare_ub_range(self, range);
            if range.is_empty() {
                return Ok(());
            }
            // The immutable borrow ends with this statement, so the mutable borrows below
            // do not conflict with it
            let mut vals = self.db.borrow()[self.root.id].vals(
                self.root.rev_map_round_up(range.start + 1)
                    ..=self.root.rev_map(range.end + self.max_leaf_weight),
            );
            vals.try_for_each(|val| {
                self.db
                    .borrow_mut()
                    .define_weighted(self.root.id, val, collector, var_manager)?
                    .unwrap();
                Ok::<(), crate::OutOfMemory>(())
            })?;
            Ok(())
        }
    }
}
849
#[cfg(test)]
mod tests {
    use super::GeneralizedTotalizer;
    use crate::{
        encodings::{
            card,
            pb::{BoundUpper, BoundUpperIncremental, EncodeIncremental},
            EncodeStats, EnforceError,
        },
        instances::{BasicVarManager, Cnf, ManageVars},
        lit,
        types::RsHashMap,
        var,
    };

    /// Enforcing a bound before encoding must fail with `NotEncoded`; after
    /// encoding, the tree depth and variable count are pinned.
    #[test]
    fn ub_gte_functions() {
        let mut gte = GeneralizedTotalizer::default();
        let mut lits = RsHashMap::default();
        lits.insert(lit![0], 5);
        lits.insert(lit![1], 5);
        lits.insert(lit![2], 3);
        lits.insert(lit![3], 3);
        gte.extend(lits);
        assert_eq!(gte.enforce_ub(4), Err(EnforceError::NotEncoded));
        let mut var_manager = BasicVarManager::default();
        gte.encode_ub(0..7, &mut Cnf::new(), &mut var_manager)
            .unwrap();
        assert_eq!(gte.depth(), 3);
        assert_eq!(gte.n_vars(), 10);
    }

    /// Encoding `0..5` in one go must produce as many clauses as encoding
    /// `0..3` first and then extending to `0..5` incrementally.
    #[test]
    fn ub_gte_incremental_building() {
        let mut gte1 = GeneralizedTotalizer::default();
        let mut lits = RsHashMap::default();
        lits.insert(lit![0], 5);
        lits.insert(lit![1], 5);
        lits.insert(lit![2], 3);
        lits.insert(lit![3], 3);
        gte1.extend(lits.clone());
        let mut var_manager = BasicVarManager::default();
        let mut cnf1 = Cnf::new();
        gte1.encode_ub(0..5, &mut cnf1, &mut var_manager).unwrap();
        let mut gte2 = GeneralizedTotalizer::default();
        gte2.extend(lits);
        let mut var_manager = BasicVarManager::default();
        let mut cnf2 = Cnf::new();
        gte2.encode_ub(0..3, &mut cnf2, &mut var_manager).unwrap();
        gte2.encode_ub_change(0..5, &mut cnf2, &mut var_manager)
            .unwrap();
        assert_eq!(cnf1.len(), cnf2.len());
        assert_eq!(cnf1.len(), gte1.n_clauses());
        assert_eq!(cnf2.len(), gte2.n_clauses());
    }

    /// Pins the number of used variables and generated clauses for weights
    /// `1..=4` encoded with bounds `0..=6`.
    #[test]
    fn from_capi() {
        let mut gte1 = GeneralizedTotalizer::default();
        let mut lits = RsHashMap::default();
        lits.insert(lit![0], 1);
        lits.insert(lit![1], 2);
        lits.insert(lit![2], 3);
        lits.insert(lit![3], 4);
        gte1.extend(lits);
        let mut var_manager = BasicVarManager::from_next_free(var![4]);
        let mut cnf = Cnf::new();
        gte1.encode_ub(0..=6, &mut cnf, &mut var_manager).unwrap();
        // Plain `assert_eq!`: `debug_assert_eq!` is compiled out when tests
        // run without debug assertions (e.g. `cargo test --release`), which
        // would silently turn this test into a no-op.
        assert_eq!(var_manager.n_used(), 24);
        assert_eq!(cnf.len(), 25);
    }

    /// Doubling all weights (and the bound range accordingly) must not change
    /// the size of the encoding.
    #[test]
    fn ub_gte_multiplication() {
        let mut gte1 = GeneralizedTotalizer::default();
        let mut lits = RsHashMap::default();
        lits.insert(lit![0], 5);
        lits.insert(lit![1], 5);
        lits.insert(lit![2], 3);
        lits.insert(lit![3], 3);
        gte1.extend(lits);
        let mut var_manager = BasicVarManager::default();
        let mut cnf1 = Cnf::new();
        gte1.encode_ub(0..5, &mut cnf1, &mut var_manager).unwrap();
        let mut gte2 = GeneralizedTotalizer::default();
        let mut lits = RsHashMap::default();
        lits.insert(lit![0], 10);
        lits.insert(lit![1], 10);
        lits.insert(lit![2], 6);
        lits.insert(lit![3], 6);
        gte2.extend(lits);
        let mut var_manager = BasicVarManager::default();
        let mut cnf2 = Cnf::new();
        gte2.encode_ub(0..9, &mut cnf2, &mut var_manager).unwrap();
        assert_eq!(cnf1.len(), cnf2.len());
        assert_eq!(cnf1.len(), gte1.n_clauses());
        assert_eq!(cnf2.len(), gte2.n_clauses());
    }

    /// With unit weights, the GTE must use as many variables and clauses as
    /// the cardinality totalizer over the same literals.
    #[test]
    fn ub_gte_equals_tot() {
        let mut var_manager_gte = BasicVarManager::default();
        var_manager_gte.increase_next_free(var![7]);
        let mut var_manager_tot = var_manager_gte.clone();
        // Set up GTE
        let mut gte = GeneralizedTotalizer::default();
        let mut lits = RsHashMap::default();
        lits.insert(lit![0], 1);
        lits.insert(lit![1], 1);
        lits.insert(lit![2], 1);
        lits.insert(lit![3], 1);
        lits.insert(lit![4], 1);
        lits.insert(lit![5], 1);
        lits.insert(lit![6], 1);
        gte.extend(lits);
        let mut gte_cnf = Cnf::new();
        gte.encode_ub(3..8, &mut gte_cnf, &mut var_manager_gte)
            .unwrap();
        // Set up Tot
        let mut tot = card::Totalizer::default();
        tot.extend(vec![
            lit![0],
            lit![1],
            lit![2],
            lit![3],
            lit![4],
            lit![5],
            lit![6],
        ]);
        let mut tot_cnf = Cnf::new();
        card::BoundUpper::encode_ub(&mut tot, 3..8, &mut tot_cnf, &mut var_manager_tot).unwrap();
        // Printed output is captured by the harness and only shown when the
        // test fails, which helps compare the two encodings.
        println!("{gte_cnf:?}");
        println!("{tot_cnf:?}");
        assert_eq!(var_manager_gte.new_var(), var_manager_tot.new_var());
        assert_eq!(gte_cnf.len(), tot_cnf.len());
        assert_eq!(gte_cnf.len(), gte.n_clauses());
        assert_eq!(tot_cnf.len(), tot.n_clauses());
    }

    /// After `reserve`, encoding must not allocate any further variables.
    #[test]
    fn reserve() {
        let mut gte = GeneralizedTotalizer::default();
        gte.extend(vec![(lit![0], 1), (lit![1], 2), (lit![2], 3), (lit![3], 4)]);
        let mut var_manager = BasicVarManager::from_next_free(var![4]);
        gte.reserve(&mut var_manager);
        assert_eq!(var_manager.n_used(), 24);
        let mut cnf = Cnf::new();
        gte.encode_ub(0..3, &mut cnf, &mut var_manager).unwrap();
        assert_eq!(var_manager.n_used(), 24);
    }

    // Proof-logging tests; proofs are only checked when the
    // `VERIPB_CHECKER` environment variable names a checker executable.
    #[cfg(feature = "proof-logging")]
    mod proofs {
        use std::{
            fs::File,
            io::{BufRead, BufReader},
            path::Path,
            process::Command,
        };

        use crate::{
            encodings::pb::cert::BoundUpper,
            instances::{Cnf, SatInstance},
            types::{constraints::PbConstraint, Var},
        };

        /// Dumps the file at `path` to stdout, framed by blank lines.
        fn print_file<P: AsRef<Path>>(path: P) {
            println!();
            for line in BufReader::new(File::open(path).expect("could not open file")).lines() {
                println!("{}", line.unwrap());
            }
            println!();
        }

        /// Runs the checker named by `$VERIPB_CHECKER` on `instance` and
        /// `proof`, panicking if verification fails. Does nothing (besides a
        /// notice) when the variable is not set.
        fn verify_proof<P1: AsRef<Path>, P2: AsRef<Path>>(instance: P1, proof: P2) {
            let Ok(veripb) = std::env::var("VERIPB_CHECKER") else {
                println!("`$VERIPB_CHECKER` not set, omitting proof checking");
                return;
            };
            println!("start checking proof");
            let out = Command::new(veripb)
                .arg(instance.as_ref())
                .arg(proof.as_ref())
                .output()
                .expect("failed to run veripb");
            print_file(proof);
            assert!(out.status.success(), "verification failed: {out:?}");
        }

        /// Starts a proof written to a fresh temporary file.
        fn new_proof(
            num_constraints: usize,
            optimization: bool,
        ) -> pigeons::Proof<tempfile::NamedTempFile> {
            let file =
                tempfile::NamedTempFile::new().expect("failed to create temporary proof file");
            pigeons::Proof::new(file, num_constraints, optimization).expect("failed to start proof")
        }

        /// Certified encoding of the single constraint in `data/single-ub.opb`.
        #[test]
        fn constraint() {
            let manifest = std::env::var("CARGO_MANIFEST_DIR").unwrap();
            let inst_path = format!("{manifest}/data/single-ub.opb");
            let constr: SatInstance = SatInstance::from_opb_path(
                &inst_path,
                crate::instances::fio::opb::Options::default(),
            )
            .unwrap();
            let (constr, mut vm) = constr.into_pbs();
            assert_eq!(constr.len(), 1);
            // The constraint is expected to be read as a lower bound; it is
            // inverted before being handed to the upper-bound encoder.
            let Some(PbConstraint::Lb(constr)) = constr.into_iter().next() else {
                panic!("expected the instance to contain a single lower-bound constraint")
            };
            let constr = constr.invert();
            let mut cnf = Cnf::new();
            let mut proof = new_proof(1, false);
            super::GeneralizedTotalizer::encode_ub_constr_cert(
                (constr, pigeons::AbsConstraintId::new(1)),
                &mut cnf,
                &mut vm,
                &mut proof,
            )
            .unwrap();
            let proof_file = proof
                .conclude::<Var>(pigeons::OutputGuarantee::None, &pigeons::Conclusion::None)
                .unwrap();
            verify_proof(&inst_path, proof_file.path());
        }
    }
}