libreda_logic/networks/mig/
mod.rs

1// SPDX-FileCopyrightText: 2022 Thomas Kramer <code@tkramer.ch>
2//
3// SPDX-License-Identifier: AGPL-3.0-or-later
4
5//! Majority-inverter graph (MIG) data structures.
6
7use crate::{
8    helpers::sort3,
9    network::*,
10    networks::generic_network::*,
11    truth_table::small_lut::{truth_table_library, SmallTruthTable},
12};
13
14use super::SimplifyResult;
15
16mod transforms;
17
18/// Majority-inverter graph.
19///
20/// A logic network consisting of three-input majority nodes.
21///
22/// # Examples
23/// ```
24/// use libreda_logic::traits::*;
25/// use libreda_logic::network::*;
26/// use libreda_logic::networks::mig::*;
27///
28/// // Create an empty network.
29/// let mut mig = Mig::new();
30///
31/// // Create two inputs.
32/// let a = mig.create_primary_input();
33/// let b = mig.create_primary_input();
34///
35/// // Create the boolean AND of the two inputs.
36/// let a_and_b = mig.create_and(a, b);
37/// ```
38pub type Mig = LogicNetwork<Maj3Node>;
39
40/// Identifier for nodes in the majority-inverter graph.
41pub type MigNodeId = NodeId;
42
43/// Node in the majority-inverter graph.
44#[derive(Clone, Copy, Debug, Hash, PartialEq, PartialOrd, Eq, Ord)]
45pub struct Maj3Node {
46    pub(super) a: MigNodeId,
47    pub(super) b: MigNodeId,
48    pub(super) c: MigNodeId,
49    /// Number of nodes which reference to this node.
50    num_references: usize,
51}
52
53impl Maj3Node {
54    /// Create a new three-input majority node.
55    fn new([a, b, c]: [MigNodeId; 3]) -> Self {
56        Self {
57            a,
58            b,
59            c,
60            num_references: 0,
61        }
62    }
63
64    /// Get the child node IDs as an array.
65    pub(crate) fn to_array(self) -> [MigNodeId; 3] {
66        [self.a, self.b, self.c]
67    }
68
69    /// Get an array of mutable references to the node inputs.
70    pub(crate) fn to_array_mut(&mut self) -> [&mut MigNodeId; 3] {
71        [&mut self.a, &mut self.b, &mut self.c]
72    }
73}
74
75impl NetworkNode for Maj3Node {
76    type NodeId = MigNodeId;
77
78    fn num_inputs(&self) -> usize {
79        3
80    }
81
82    fn get_input(&self, i: usize) -> Self::NodeId {
83        match i {
84            0 => self.a,
85            1 => self.b,
86            2 => self.c,
87            _ => panic!("index out of bounds"),
88        }
89    }
90
91    fn function(&self) -> SmallTruthTable {
92        truth_table_library::maj3()
93    }
94
95    fn normalized(self) -> SimplifyResult<Self, Self::NodeId> {
96        Mig::simplify_node(self)
97    }
98}
99
100impl MutNetworkNode for Maj3Node {
101    fn set_input(&mut self, i: usize, signal: Self::NodeId) {
102        let input = match i {
103            0 => &mut self.a,
104            1 => &mut self.b,
105            2 => &mut self.c,
106            _ => panic!("index out of bounds"),
107        };
108
109        *input = signal;
110    }
111}
112
113impl NetworkNodeWithReferenceCount for Maj3Node {
114    fn num_references(&self) -> usize {
115        self.num_references
116    }
117}
118
119impl MutNetworkNodeWithReferenceCount for Maj3Node {
120    fn reference(&mut self) {
121        self.num_references += 1;
122    }
123
124    fn dereference(&mut self) {
125        self.num_references -= 1;
126    }
127}
128
129impl IntoIterator for Maj3Node {
130    type Item = MigNodeId;
131
132    type IntoIter = std::array::IntoIter<MigNodeId, 3>;
133
134    fn into_iter(self) -> Self::IntoIter {
135        // Sanity check: inputs must be sorted.
136        debug_assert_eq!(
137            self.to_array(),
138            sort3(self.to_array()),
139            "inputs must be sorted"
140        );
141
142        self.to_array().into_iter()
143    }
144}
145
146impl Mig {
147    /// Simplify the node without making changes to the graph.
148    ///
149    /// Either returns the ID of an existing node which is equivalent to the given node,
150    /// or returns a node with simplified input, or returns the unmodified node.
151    fn simplify_node(node: Maj3Node) -> SimplifyResult<Maj3Node, NodeId> {
152        SimplifyResult::new_node(node)
153            // Use majority rule
154            .and_then(Self::simplify_node_by_majority)
155            .and_then(Self::normalize_by_input_inversions)
156            .map_unsimplified(Self::normalize_node_by_commutativity) // TODO: Is this necessary? Might be done in simplify_node_with_hashtable
157    }
158
159    /// Invert the inputs such that
160    /// the majority of inputs is not inverted.
161    /// Returns a tuple `(n, need_inversion)`. Where `n` is the modified node and `need_inversion` is set to true
162    /// iff the node inputs have been inverted.
163    fn normalize_by_input_inversions(node: Maj3Node) -> SimplifyResult<Maj3Node, NodeId> {
164        let [a, b, c] = node.to_array();
165
166        // Convert inputs into a unique form.
167        // * sort them
168        // * enventually invert this signal output such that the majority of the inputs is not inverted
169        let (a, b, c, invert_output) = {
170            let num_inversions =
171                (a.is_inverted() as u8) + (b.is_inverted() as u8) + (c.is_inverted() as u8);
172
173            let (a, b, c, invert_output) = if num_inversions >= 2 {
174                (a.invert(), b.invert(), c.invert(), true)
175            } else {
176                (a, b, c, false)
177            };
178            (a, b, c, invert_output)
179        };
180
181        let node = Maj3Node::new([a, b, c]);
182        SimplifyResult::Node(node, invert_output)
183    }
184
185    /// Sort the inputs of the node.
186    /// Use the rule `M(x, y, z) == M(y, z, x) == M(z, y, x)`.
187    fn normalize_node_by_commutativity(node: Maj3Node) -> Maj3Node {
188        let [a, b, c] = sort3(node.to_array());
189        Maj3Node { a, b, c, ..node }
190    }
191
192    /// Simplify the node to a single signal if there are either two equal inputs or an input `x` and another input `y'`.
193    /// Both cases decide the majority function.
194    /// Returns a node ID if simplification was successful, otherwise returns the original node.
195    ///
196    /// Majority rule:
197    /// * if (x == y): M(x, y, z) = x = y
198    /// * if (x == y'): M(x, y, z) = z
199    fn simplify_node_by_majority(node: Maj3Node) -> SimplifyResult<Maj3Node, NodeId> {
200        let [a, b, c] = [node.a, node.b, node.c];
201        match [a, b, c] {
202            // M(x, x, _) => x
203            [x, y, _] | [y, _, x] | [_, x, y] if x == y => SimplifyResult::new_id(x),
204            // M(x, x', z) => z
205            [x, y, z] | [y, z, x] | [z, x, y] if x == y.invert() => SimplifyResult::new_id(z),
206            _ => SimplifyResult::new_node(node),
207        }
208    }
209}
210
211impl HomogeneousNetwork for Mig {
212    const NUM_NODE_INPUTS: usize = 3;
213
214    fn function(&self) -> Self::NodeFunction {
215        crate::truth_table::small_lut::truth_table_library::maj3()
216    }
217}
218
219impl SubstituteInNode for Mig {
220    fn substitute_in_node(
221        &mut self,
222        node: Self::NodeId,
223        old_signal: Self::Signal,
224        new_signal: Self::Signal,
225    ) {
226        if let Some(n) = self.get_logic_node_mut(node) {
227            // Replace the each occurrence of the old signal with the new signal.
228            n.to_array_mut()
229                .into_iter()
230                .filter(|input| **input == old_signal)
231                .for_each(|input| *input = new_signal);
232        }
233    }
234}
235
236impl UnaryOp for Mig {
237    fn create_not(&mut self, signal: Self::Signal) -> Self::Signal {
238        signal.invert()
239    }
240}
241
242impl BinaryOp for Mig {
243    fn create_and(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
244        self.create_maj3(MigNodeId::zero(), a, b)
245    }
246
247    fn create_or(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
248        self.create_maj3(MigNodeId::zero().invert(), a, b)
249    }
250
251    fn create_nand(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
252        self.create_and(a, b).invert()
253    }
254
255    fn create_nor(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
256        self.create_or(a, b).invert()
257    }
258
259    fn create_xor(&mut self, a: Self::Signal, b: Self::Signal) -> Self::Signal {
260        let x = self.create_and(a, b.invert());
261        let y = self.create_and(a.invert(), b);
262        self.create_or(x, y)
263    }
264}
265
266impl TernaryOp for Mig {
267    fn create_maj3(&mut self, a: Self::Signal, b: Self::Signal, c: Self::Signal) -> Self::Signal {
268        let node = Maj3Node::new([a, b, c]);
269
270        // TODO: doing double the work. Normalization is already called by `create_node`
271        match Self::simplify_node(node) {
272            SimplifyResult::Node(n, invert) => self.create_node(n).invert_if(invert),
273            SimplifyResult::Simplified(s, invert) => s.invert_if(invert),
274        }
275    }
276}
277
#[test]
fn test_mig_create_constants() {
    let mig = Mig::new();

    let [zero, one] = [false, true].map(|value| mig.get_constant(value));

    // Both logic levels are constants, and `zero` is the canonical zero.
    for constant in [zero, one] {
        assert!(constant.is_constant());
    }
    assert!(zero.is_zero());

    // The two constants differ and are each other's inverse.
    assert_ne!(zero, one);
    assert_eq!(zero.invert(), one);
}
293
#[test]
fn test_mig_create_primary_inputs() {
    let mut mig = Mig::new();
    let a = mig.create_primary_input();
    let b = mig.create_primary_input();
    let c = mig.create_primary_input();

    // Primary inputs must be pairwise distinct (the original test never compared `a` and `c`).
    assert_ne!(a, b);
    assert_ne!(b, c);
    assert_ne!(a, c);

    // Every primary input is recognized as an input and is not a constant.
    for input in [a, b, c] {
        assert!(mig.is_input(input));
        assert!(!mig.is_constant(input));
    }
}
310
#[test]
fn test_mig_node_deduplication() {
    let mut mig = Mig::new();

    let a = mig.create_primary_input();
    let b = mig.create_primary_input();
    let c = mig.create_primary_input();

    // Creating the very same majority node twice must yield the same signal.
    let first = mig.create_maj3(a, b, c);
    let second = mig.create_maj3(a, b, c);
    assert_eq!(first, second);
}
324
#[test]
fn test_mig_node_simplification_by_commutativity() {
    let mut mig = Mig::new();

    let a = mig.create_primary_input();
    let b = mig.create_primary_input();
    let c = mig.create_primary_input();

    let reference = mig.create_maj3(a, b, c);

    // Reordering the inputs must be deduplicated to the same node.
    let permuted = mig.create_maj3(c, b, a);
    assert_eq!(reference, permuted);

    // Reordering combined with self-duality: M(a, b, c) == M(c', b', a')'.
    let dual = mig.create_maj3(c.invert(), b.invert(), a.invert()).invert();
    assert_eq!(reference, dual);
}
345
#[test]
fn test_mig_node_simplification_by_majority() {
    let mut mig = Mig::new();

    let a = mig.create_primary_input();
    let b = mig.create_primary_input();

    // Two equal inputs decide the output, wherever they are placed: M(x, x, z) == x.
    assert_eq!(mig.create_maj3(a, a, b), a);
    assert_eq!(mig.create_maj3(a, b, a), a);
    assert_eq!(mig.create_maj3(b, a, a), a);

    // Complementary inputs cancel out, wherever they are placed: M(x, x', z) == z.
    assert_eq!(mig.create_maj3(a, a.invert(), b), b);
    assert_eq!(mig.create_maj3(a, b, a.invert()), b);
    assert_eq!(mig.create_maj3(b, a.invert(), a), b);
}
356
#[test]
fn test_mig_node_simplification_with_constants() {
    let mut mig = Mig::new();

    let a = mig.create_primary_input();

    let zero = mig.get_constant(false);
    let one = mig.get_constant(true);

    // Constant zero: AND yields zero, OR is the identity.
    assert_eq!(mig.create_and(a, zero), zero);
    assert_eq!(mig.create_or(a, zero), a);

    // Constant one: OR yields one, AND is the identity.
    assert_eq!(mig.create_or(a, one), one);
    assert_eq!(mig.create_and(a, one), a);
}
371
#[test]
fn test_mig_simulation() {
    use crate::native_boolean_functions::NativeBooleanFunction;
    use crate::traits::BooleanSystem;

    // Construct a one-bit full adder.
    let mut mig = Mig::new();
    let [in1, in2, carry_in] = mig.create_primary_inputs();

    let sum = mig.create_xor3(in1, in2, carry_in);
    let carry = mig.create_maj3(in1, in2, carry_in);

    let output_sum = mig.create_primary_output(sum);
    let output_carry = mig.create_primary_output(carry);

    let simulator = crate::network_simulator::RecursiveSim::new(&mig);

    // Reference model of the full adder, returning `[sum, carry]`.
    fn full_adder([a, b, c]: [bool; 3]) -> [bool; 2] {
        let count = (a as usize) + (b as usize) + (c as usize);
        [
            (count & 0b1) == 1,     // sum
            (count & 0b10) == 0b10, // carry
        ]
    }

    let reference = NativeBooleanFunction::new(full_adder);

    // Exhaustively compare all 2^3 input combinations against the reference model.
    for i in 0..(1 << 3) {
        let inputs = [0, 1, 2].map(|idx| (i >> idx) & 1 == 1);

        let expected_output = [0, 1].map(|out| reference.evaluate_term(&out, &inputs));
        let actual_output: Vec<_> = simulator
            .simulate(&[output_sum, output_carry], &inputs)
            .collect();

        assert_eq!(
            expected_output.as_slice(),
            actual_output.as_slice(),
            "simulation mismatch for inputs {:?}",
            inputs
        );
    }
}