boolean_circuit/
deep_copy.rs

1use std::collections::HashMap;
2
3use itertools::Itertools;
4
5use crate::{Circuit, Gate, Operation};
6
/// Supplies the names that input and output gates receive when a circuit is copied.
pub trait NameMapper {
    /// Returns the name that the input called `input_name` in the original
    /// circuit should carry in the copy.
    fn map_input_name(&mut self, input_name: &str) -> String;
    /// Returns the name that the output called `output_name` in the original
    /// circuit should carry in the copy.
    ///
    /// An empty string, whether passed in or returned, stands for an
    /// unnamed output.
    fn map_output_name(&mut self, output_name: &str) -> String;
}
20
21/// Creates a deep / independent copy of the given circuit.
22/// The `name_mapper` is responsible for returning new names for the inputs and outputs.
23/// If the `name_mapper` returns existing names for inputs, these gates will still be shared,
24/// since input nodes are identified by their names.
25///
26/// The name mapper will only be called once for each input or output name.
27///
28/// Returns the new circuit.
29pub fn deep_copy(circuit: &Circuit, mut name_mapper: impl NameMapper) -> Result<Circuit, String> {
30    let output_names = circuit
31        .output_names()
32        .iter()
33        .map(|n| name_mapper.map_output_name(n))
34        .collect_vec();
35    let mut deep_copy = DeepCopy::new(name_mapper);
36    let output_gates = circuit
37        .outputs()
38        .iter()
39        .map(|o| deep_copy.copy(o))
40        .collect_vec();
41    let input_names = circuit
42        .input_names()
43        .map(|n| deep_copy.copy_of_input(n).to_string_as_tree())
44        .unique()
45        .collect_vec();
46
47    Circuit::from_named_outputs(output_gates.into_iter().zip(output_names))
48        .with_input_order(input_names)
49}
50
51pub fn deep_copy_of_gate(gate: &Gate, name_mapper: impl NameMapper) -> Gate {
52    DeepCopy::new(name_mapper).copy(gate)
53}
54
55struct DeepCopy<'a, N> {
56    name_mapper: N,
57    input_name_substitutions: HashMap<&'a str, Gate>,
58    gate_substitutions: HashMap<usize, Gate>,
59}
60
61impl<'a, N: NameMapper> DeepCopy<'a, N> {
62    fn new(name_mapper: N) -> Self {
63        Self {
64            name_mapper,
65            input_name_substitutions: Default::default(),
66            gate_substitutions: HashMap::new(),
67        }
68    }
69
70    fn copy(&mut self, gate: &'a Gate) -> Gate {
71        for n in gate.post_visit_iter() {
72            let substitution = match n.operation() {
73                Operation::Variable(name) => self.copy_of_input(name.as_str()),
74                Operation::Constant(value) => Gate::from(*value),
75                Operation::Negation(inner) => !self.sub(inner),
76                Operation::Conjunction(left, right) => self.sub(left) & self.sub(right),
77                Operation::Disjunction(left, right) => self.sub(left) | self.sub(right),
78                Operation::Xor(left, right) => self.sub(left) ^ self.sub(right),
79            };
80            self.gate_substitutions.insert(n.id(), substitution);
81        }
82        self.sub(gate)
83    }
84
85    fn copy_of_input(&mut self, name: &'a str) -> Gate {
86        self.input_name_substitutions
87            .entry(name)
88            .or_insert_with(|| Gate::from(self.name_mapper.map_input_name(name)))
89            .clone()
90    }
91
92    fn sub(&self, node: &'a Gate) -> Gate {
93        self.gate_substitutions.get(&node.id()).unwrap().clone()
94    }
95}
96
#[cfg(test)]
mod test {
    use super::*;

    /// Name mapper that ignores the original name and hands out `copy_1`,
    /// `copy_2`, ... from a single counter shared by inputs and outputs.
    #[derive(Default)]
    struct CountedNames {
        counter: usize,
    }

    impl CountedNames {
        fn next_name(&mut self) -> String {
            self.counter += 1;
            format!("copy_{}", self.counter)
        }
    }

    impl NameMapper for CountedNames {
        fn map_input_name(&mut self, _: &str) -> String {
            self.next_name()
        }

        fn map_output_name(&mut self, _: &str) -> String {
            self.next_name()
        }
    }

    // Table-driven mapper: every name must have an entry (indexing panics otherwise).
    impl NameMapper for HashMap<&str, String> {
        fn map_input_name(&mut self, n: &str) -> String {
            self[n].clone()
        }
        fn map_output_name(&mut self, n: &str) -> String {
            self[n].clone()
        }
    }

    /// Circuit with outputs `a = v1 ^ v3`, `b = v1 | v2` and input order `v2, v3, v1`.
    fn sample_circuit() -> Circuit {
        let out_b = Gate::from("v1") | Gate::from("v2");
        let out_a = Gate::from("v1") ^ Gate::from("v3");
        Circuit::from_named_outputs([(out_a, "a"), (out_b, "b")])
            .with_input_order(["v2", "v3", "v1"])
            .unwrap()
    }

    #[test]
    fn simple() {
        let copy = deep_copy_of_gate(&Gate::from("v1"), CountedNames::default());
        assert_eq!(copy.to_string_as_tree(), "copy_1");
    }

    #[test]
    fn intermediate1() {
        let original = (Gate::from("v1") & Gate::from("v2")) | !Gate::from("v1");
        let copy = deep_copy_of_gate(&original, CountedNames::default());
        assert_eq!(copy.to_string_as_tree(), "((copy_1 & copy_2) | !copy_1)");
    }

    #[test]
    fn intermediate2() {
        let original =
            (Gate::from("v3") ^ Gate::from("v3")) & Gate::from(true) | Gate::from(false);
        let copy = deep_copy_of_gate(&original, CountedNames::default());
        assert_eq!(
            copy.to_string_as_tree(),
            "(((copy_1 ^ copy_1) & true) | false)"
        );
    }

    #[test]
    fn with_input_repetitions() {
        let renames = HashMap::from([("v1", "x".to_string()), ("v2", "x".to_string())]);
        let original = Gate::from("v1") & Gate::from("v2");
        let copy = deep_copy_of_gate(&original, renames);
        assert_eq!(copy.to_string_as_tree(), "(x & x)");
    }

    #[test]
    fn circuit_copy() {
        let original = sample_circuit();
        let renames = HashMap::from([
            ("v1", "r1".to_string()),
            ("v2", "r2".to_string()),
            ("v3", "r3".to_string()),
            ("a", "x".to_string()),
            ("b", "y".to_string()),
        ]);
        let copy = deep_copy(&original, renames).unwrap();
        assert_eq!(copy.input_names().collect_vec(), vec!["r2", "r3", "r1"]);
        assert_eq!(copy.output_names(), vec!["x".to_string(), "y".to_string()]);
    }

    #[test]
    fn input_order_with_repetitions() {
        let original = sample_circuit();
        let renames = HashMap::from([
            ("v1", "r2".to_string()),
            ("v2", "r2".to_string()),
            ("v3", "r1".to_string()),
            ("a", "x".to_string()),
            ("b", "y".to_string()),
        ]);
        let copy = deep_copy(&original, renames).unwrap();
        assert_eq!(copy.input_names().collect_vec(), vec!["r2", "r1"]);
        assert_eq!(copy.output_names(), vec!["x".to_string(), "y".to_string()]);
    }
}