oxidd_rules_mtbdd/
lib.rs

1//! Rules and other basic definitions for multi-terminal binary decision
2//! diagrams
3//!
4//! ## Feature flags
5#![doc = document_features::document_features!()]
6#![forbid(unsafe_code)]
7// `'id` lifetimes may make the code easier to understand
8#![allow(clippy::needless_lifetimes)]
9
10use std::borrow::Borrow;
11use std::cmp::Ordering;
12use std::hash::Hash;
13
14use oxidd_core::function::NumberBase;
15use oxidd_core::util::{AllocResult, Borrowed};
16use oxidd_core::{DiagramRules, Edge, InnerNode, LevelNo, Manager, Node, ReducedOrNew};
17use oxidd_derive::Countable;
18
19pub mod terminal;
20
21mod apply_rec;
22
23// --- Reduction Rules ---------------------------------------------------------
24
25/// [`DiagramRules`] for (multi-terminal) binary decision diagrams
26pub struct MTBDDRules;
27
28impl<E: Edge, N: InnerNode<E>, T> DiagramRules<E, N, T> for MTBDDRules {
29    type Cofactors<'a>
30        = N::ChildrenIter<'a>
31    where
32        N: 'a,
33        E: 'a;
34
35    #[inline]
36    fn reduce<M: Manager<Edge = E, InnerNode = N, Terminal = T>>(
37        manager: &M,
38        level: LevelNo,
39        children: impl IntoIterator<Item = E>,
40    ) -> ReducedOrNew<E, N> {
41        let mut it = children.into_iter();
42        let t = it.next().unwrap();
43        let e = it.next().unwrap();
44        debug_assert!(it.next().is_none());
45
46        if t == e {
47            manager.drop_edge(e);
48            ReducedOrNew::Reduced(t)
49        } else {
50            ReducedOrNew::New(N::new(level, [t, e]), Default::default())
51        }
52    }
53
54    #[inline]
55    fn cofactors(_tag: E::Tag, node: &N) -> Self::Cofactors<'_> {
56        node.children()
57    }
58}
59
60#[inline(always)]
61fn reduce<M: Manager>(
62    manager: &M,
63    level: LevelNo,
64    t: M::Edge,
65    e: M::Edge,
66    op: MTBDDOp,
67) -> AllocResult<M::Edge> {
68    let _ = op;
69    let tmp = <MTBDDRules as DiagramRules<_, _, _>>::reduce(manager, level, [t, e]);
70    if let ReducedOrNew::Reduced(..) = &tmp {
71        stat!(reduced op);
72    }
73    tmp.then_insert(manager, level)
74}
75
76// --- Operations & Apply Implementation ---------------------------------------
77
78/// Native operations of this MTBDD implementation
79#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Countable, Debug)]
80#[repr(u8)]
81#[allow(missing_docs)]
82pub enum MTBDDOp {
83    Add,
84    Sub,
85    Mul,
86    Div,
87    Min,
88    Max,
89
90    /// If-then-else
91    Ite,
92}
93
94/// Collect the two children of a binary node
95#[inline]
96#[must_use]
97fn collect_children<E: Edge, N: InnerNode<E>>(node: &N) -> (Borrowed<'_, E>, Borrowed<'_, E>) {
98    debug_assert_eq!(N::ARITY, 2);
99    let mut it = node.children();
100    let t = it.next().unwrap();
101    let e = it.next().unwrap();
102    debug_assert!(it.next().is_none());
103    (t, e)
104}
105
106enum Operation<'a, E: 'a + Edge> {
107    Binary(MTBDDOp, Borrowed<'a, E>, Borrowed<'a, E>),
108    Done(E),
109}
110
111/// Terminal case for binary operators
112#[inline]
113fn terminal_bin<'a, M: Manager<Terminal = T>, T: NumberBase, const OP: u8>(
114    m: &M,
115    f: &'a M::Edge,
116    g: &'a M::Edge,
117) -> AllocResult<Operation<'a, M::Edge>> {
118    use Node::*;
119    use Operation::*;
120
121    Ok(if OP == MTBDDOp::Add as u8 {
122        match (m.get_node(f), m.get_node(g)) {
123            (Terminal(tf), Terminal(tg)) => {
124                let val = tf.borrow().add(tg.borrow());
125                Done(m.get_terminal(val)?)
126            }
127            (Terminal(t), _) if t.borrow().is_zero() => Done(m.clone_edge(g)),
128            (_, Terminal(t)) if t.borrow().is_zero() => Done(m.clone_edge(f)),
129            (Terminal(t), _) | (_, Terminal(t)) if t.borrow().is_nan() => {
130                Done(m.get_terminal(T::nan())?)
131            }
132            _ if f > g => Binary(MTBDDOp::Add, g.borrowed(), f.borrowed()),
133            _ => Binary(MTBDDOp::Add, f.borrowed(), g.borrowed()),
134        }
135    } else if OP == MTBDDOp::Sub as u8 {
136        match (m.get_node(f), m.get_node(g)) {
137            (Terminal(tf), Terminal(tg)) => {
138                let val = tf.borrow().sub(tg.borrow());
139                Done(m.get_terminal(val)?)
140            }
141            (Terminal(t), _) if t.borrow().is_zero() => Done(m.clone_edge(g)),
142            (_, Terminal(t)) if t.borrow().is_zero() => Done(m.clone_edge(f)),
143            (Terminal(t), _) | (_, Terminal(t)) if t.borrow().is_nan() => {
144                Done(m.get_terminal(T::nan())?)
145            }
146            _ => Binary(MTBDDOp::Sub, f.borrowed(), g.borrowed()),
147        }
148    } else if OP == MTBDDOp::Mul as u8 {
149        match (m.get_node(f), m.get_node(g)) {
150            (Terminal(tf), Terminal(tg)) => {
151                let val = tf.borrow().mul(tg.borrow());
152                Done(m.get_terminal(val)?)
153            }
154            (Terminal(t), _) if t.borrow().is_one() => Done(m.clone_edge(g)),
155            (_, Terminal(t)) if t.borrow().is_one() => Done(m.clone_edge(f)),
156            (Terminal(t), _) | (_, Terminal(t)) if t.borrow().is_nan() => {
157                Done(m.get_terminal(T::nan())?)
158            }
159            // Don't optimize the case where one of the operands is 0. 0 * NaN
160            // is still NaN.
161            _ if f > g => Binary(MTBDDOp::Mul, g.borrowed(), f.borrowed()),
162            _ => Binary(MTBDDOp::Mul, f.borrowed(), g.borrowed()),
163        }
164    } else if OP == MTBDDOp::Div as u8 {
165        match (m.get_node(f), m.get_node(g)) {
166            (Terminal(tf), Terminal(tg)) => {
167                let val = tf.borrow().div(tg.borrow());
168                Done(m.get_terminal(val)?)
169            }
170            (_, Terminal(t)) if t.borrow().is_one() => Done(m.clone_edge(f)),
171            (Terminal(t), _) | (_, Terminal(t)) if t.borrow().is_nan() => {
172                Done(m.get_terminal(T::nan())?)
173            }
174            _ => Binary(MTBDDOp::Div, f.borrowed(), g.borrowed()),
175        }
176    } else if OP == MTBDDOp::Min as u8 {
177        if f == g {
178            return Ok(Done(m.clone_edge(f)));
179        }
180        match (m.get_node(f), m.get_node(g)) {
181            (Terminal(tf), Terminal(tg)) => Done(match tf.borrow().partial_cmp(tg.borrow()) {
182                Some(Ordering::Less | Ordering::Equal) => m.clone_edge(f),
183                Some(Ordering::Greater) => m.clone_edge(g),
184                None => m.get_terminal(T::nan())?,
185            }),
186            (Terminal(t), _) | (_, Terminal(t)) if t.borrow().is_nan() => {
187                Done(m.get_terminal(T::nan())?)
188            }
189            _ if f > g => Binary(MTBDDOp::Min, g.borrowed(), f.borrowed()),
190            _ => Binary(MTBDDOp::Min, f.borrowed(), g.borrowed()),
191        }
192    } else if OP == MTBDDOp::Max as u8 {
193        if f == g {
194            return Ok(Done(m.clone_edge(f)));
195        }
196        match (m.get_node(f), m.get_node(g)) {
197            (Terminal(tf), Terminal(tg)) => Done(match tf.borrow().partial_cmp(tg.borrow()) {
198                Some(Ordering::Greater | Ordering::Equal) => m.clone_edge(f),
199                Some(Ordering::Less) => m.clone_edge(g),
200                None => m.get_terminal(T::nan())?,
201            }),
202            (Terminal(t), _) | (_, Terminal(t)) if t.borrow().is_nan() => {
203                Done(m.get_terminal(T::nan())?)
204            }
205            _ if f > g => Binary(MTBDDOp::Min, g.borrowed(), f.borrowed()),
206            _ => Binary(MTBDDOp::Min, f.borrowed(), g.borrowed()),
207        }
208    } else {
209        unreachable!("invalid binary operator")
210    })
211}
212
213// --- Function Interface ------------------------------------------------------
214
215//#[cfg(feature = "multi-threading")]
216//pub use apply_rec::mt::MTBDDFunctionMT;
217pub use apply_rec::MTBDDFunction;
218
219// --- Statistics --------------------------------------------------------------
220
/// Statistics counters for a single [`MTBDDOp`]
#[cfg(feature = "statistics")]
struct StatCounters {
    /// Number of calls of the operation
    calls: std::sync::atomic::AtomicI64,
    /// Number of calls that queried the cache (the remaining calls are
    /// reported as terminal cases)
    cache_queries: std::sync::atomic::AtomicI64,
    /// Number of cache queries that were hits
    cache_hits: std::sync::atomic::AtomicI64,
    /// Number of times `reduce()` found a node to be redundant
    reduced: std::sync::atomic::AtomicI64,
}

#[cfg(feature = "statistics")]
impl StatCounters {
    /// All counters set to zero
    ///
    /// Only used as the repeat operand of the `STAT_COUNTERS` initializer.
    /// Every use of a `const` yields a fresh value, so the lint's concern
    /// (mutating a shared `const`) does not apply here.
    #[allow(clippy::declare_interior_mutable_const)]
    const INIT: StatCounters = StatCounters {
        calls: std::sync::atomic::AtomicI64::new(0),
        cache_queries: std::sync::atomic::AtomicI64::new(0),
        cache_hits: std::sync::atomic::AtomicI64::new(0),
        reduced: std::sync::atomic::AtomicI64::new(0),
    };

    /// Print the counters of each operation to stderr and reset them to zero
    ///
    /// `counters[i]` belongs to the operation `MTBDDOp::from_usize(i)`.
    /// Operations without any calls are skipped.
    fn print(counters: &[Self]) {
        use std::sync::atomic::Ordering::Relaxed;

        // spell-checker:ignore ctrs
        for (i, ctrs) in counters.iter().enumerate() {
            let calls = ctrs.calls.swap(0, Relaxed);
            let cache_queries = ctrs.cache_queries.swap(0, Relaxed);
            let cache_hits = ctrs.cache_hits.swap(0, Relaxed);
            let reduced = ctrs.reduced.swap(0, Relaxed);

            if calls == 0 {
                continue;
            }

            let terminal_percent = (calls - cache_queries) as f32 / calls as f32 * 100.0;
            // If all calls were terminal cases, the cache was never queried.
            // Avoid printing `NaN` (0 / 0) in that case.
            let cache_hit_percent = if cache_queries == 0 {
                0.0
            } else {
                cache_hits as f32 / cache_queries as f32 * 100.0
            };
            let op = <MTBDDOp as oxidd_core::Countable>::from_usize(i);
            eprintln!("  {op:?}: calls: {calls}, cache queries: {cache_queries} ({terminal_percent} % terminal cases), cache hits: {cache_hits} ({cache_hit_percent} %), reduced: {reduced}");
        }
    }
}
262
/// One set of counters per [`MTBDDOp`], indexed by `op as usize`
#[cfg(feature = "statistics")]
static STAT_COUNTERS: [StatCounters; <MTBDDOp as oxidd_core::Countable>::MAX_VALUE + 1] =
    [StatCounters::INIT; <MTBDDOp as oxidd_core::Countable>::MAX_VALUE + 1];

/// Print statistics to stderr
///
/// Prints a crate header followed by one line per operation that was called
/// since the last print. The counters are reset to zero afterwards.
#[cfg(feature = "statistics")]
pub fn print_stats() {
    let header = "[oxidd_rules_mtbdd]";
    eprintln!("{header}");
    StatCounters::print(&STAT_COUNTERS[..]);
}
273
/// Increment a statistics counter of an operation
///
/// The macro accepts four forms: `stat!(call op)`, `stat!(cache_query op)`,
/// `stat!(cache_hit op)` and `stat!(reduced op)`. They increment the `calls`,
/// `cache_queries`, `cache_hits` and `reduced` counter of
/// `STAT_COUNTERS[op as usize]`, respectively. `op` is expected to be an
/// [`MTBDDOp`]. With the `statistics` feature, `op` is evaluated twice, so it
/// should be free of side effects.
///
/// Without the `statistics` feature, only `let _ = op as usize;` remains. This
/// keeps `op` used (no unused-variable warnings) and type-checked.
///
/// NOTE(review): `STAT_COUNTERS` is not `$crate::`-qualified. Items in
/// `macro_rules!` bodies resolve at the invocation site, so invoking modules
/// presumably have `STAT_COUNTERS` in scope when `statistics` is enabled. Confirm.
macro_rules! stat {
    (call $op:expr) => {
        let _ = $op as usize;
        #[cfg(feature = "statistics")]
        STAT_COUNTERS[$op as usize]
            .calls
            .fetch_add(1, ::std::sync::atomic::Ordering::Relaxed);
    };
    (cache_query $op:expr) => {
        let _ = $op as usize;
        #[cfg(feature = "statistics")]
        STAT_COUNTERS[$op as usize]
            .cache_queries
            .fetch_add(1, ::std::sync::atomic::Ordering::Relaxed);
    };
    (cache_hit $op:expr) => {
        let _ = $op as usize;
        #[cfg(feature = "statistics")]
        STAT_COUNTERS[$op as usize]
            .cache_hits
            .fetch_add(1, ::std::sync::atomic::Ordering::Relaxed);
    };
    (reduced $op:expr) => {
        let _ = $op as usize;
        #[cfg(feature = "statistics")]
        STAT_COUNTERS[$op as usize]
            .reduced
            .fetch_add(1, ::std::sync::atomic::Ordering::Relaxed);
    };
}

// Re-export the macro by path. Submodules can then import it (e.g.,
// `use crate::stat;`), and code in this file can invoke it before its textual
// definition (`reduce()` above does so).
pub(crate) use stat;