1use std::collections::HashMap;
2
3use itertools::Itertools;
4
5use crate::{Circuit, Gate, Operation};
6
/// Chooses the names that a deep copy gives to the copied inputs and outputs.
///
/// Both methods take `&mut self`, so an implementation may keep state between
/// calls (for example a counter that hands out fresh names).
pub trait NameMapper {
    /// Returns the name the copy should use for the input variable `name`.
    ///
    /// Returning the same new name for two different inputs merges them into a
    /// single input of the copy.
    fn map_input_name(&mut self, name: &str) -> String;

    /// Returns the name the copy should use for the output called `name`.
    fn map_output_name(&mut self, name: &str) -> String;
}
20
21pub fn deep_copy(circuit: &Circuit, mut name_mapper: impl NameMapper) -> Result<Circuit, String> {
30 let output_names = circuit
31 .output_names()
32 .iter()
33 .map(|n| name_mapper.map_output_name(n))
34 .collect_vec();
35 let mut deep_copy = DeepCopy::new(name_mapper);
36 let output_gates = circuit
37 .outputs()
38 .iter()
39 .map(|o| deep_copy.copy(o))
40 .collect_vec();
41 let input_names = circuit
42 .input_names()
43 .map(|n| deep_copy.copy_of_input(n).to_string_as_tree())
44 .unique()
45 .collect_vec();
46
47 Circuit::from_named_outputs(output_gates.into_iter().zip(output_names))
48 .with_input_order(input_names)
49}
50
51pub fn deep_copy_of_gate(gate: &Gate, name_mapper: impl NameMapper) -> Gate {
52 DeepCopy::new(name_mapper).copy(gate)
53}
54
55struct DeepCopy<'a, N> {
56 name_mapper: N,
57 input_name_substitutions: HashMap<&'a str, Gate>,
58 gate_substitutions: HashMap<usize, Gate>,
59}
60
61impl<'a, N: NameMapper> DeepCopy<'a, N> {
62 fn new(name_mapper: N) -> Self {
63 Self {
64 name_mapper,
65 input_name_substitutions: Default::default(),
66 gate_substitutions: HashMap::new(),
67 }
68 }
69
70 fn copy(&mut self, gate: &'a Gate) -> Gate {
71 for n in gate.post_visit_iter() {
72 let substitution = match n.operation() {
73 Operation::Variable(name) => self.copy_of_input(name.as_str()),
74 Operation::Constant(value) => Gate::from(*value),
75 Operation::Negation(inner) => !self.sub(inner),
76 Operation::Conjunction(left, right) => self.sub(left) & self.sub(right),
77 Operation::Disjunction(left, right) => self.sub(left) | self.sub(right),
78 Operation::Xor(left, right) => self.sub(left) ^ self.sub(right),
79 };
80 self.gate_substitutions.insert(n.id(), substitution);
81 }
82 self.sub(gate)
83 }
84
85 fn copy_of_input(&mut self, name: &'a str) -> Gate {
86 self.input_name_substitutions
87 .entry(name)
88 .or_insert_with(|| Gate::from(self.name_mapper.map_input_name(name)))
89 .clone()
90 }
91
92 fn sub(&self, node: &'a Gate) -> Gate {
93 self.gate_substitutions.get(&node.id()).unwrap().clone()
94 }
95}
96
#[cfg(test)]
mod test {
    use super::*;

    /// Name mapper that ignores the original name and hands out `copy_1`,
    /// `copy_2`, ... in call order. Tests use it to observe how often, and in
    /// which order, the mapper is consulted.
    #[derive(Default)]
    struct CountedNames {
        counter: usize,
    }

    impl NameMapper for CountedNames {
        fn map_input_name(&mut self, _: &str) -> String {
            self.counter += 1;
            format!("copy_{}", self.counter)
        }

        fn map_output_name(&mut self, _: &str) -> String {
            self.counter += 1;
            format!("copy_{}", self.counter)
        }
    }

    // A single variable is replaced by the mapper's first name.
    #[test]
    fn simple() {
        let copied_circuit = deep_copy_of_gate(&Gate::from("v1"), CountedNames::default());
        assert_eq!(copied_circuit.to_string_as_tree(), "copy_1");
    }

    // Both occurrences of `v1` become `copy_1`: the mapper is consulted once
    // per distinct input name, not once per occurrence.
    #[test]
    fn intermediate1() {
        let gate = (Gate::from("v1") & Gate::from("v2")) | !Gate::from("v1");
        let copied_circuit = deep_copy_of_gate(&gate, CountedNames::default());
        assert_eq!(
            copied_circuit.to_string_as_tree(),
            "((copy_1 & copy_2) | !copy_1)"
        );
    }

    // Constants are copied unchanged, and the repeated `v3` maps to one name.
    #[test]
    fn intermediate2() {
        let gate = (Gate::from("v3") ^ Gate::from("v3")) & Gate::from(true) | Gate::from(false);
        let copied_circuit = deep_copy_of_gate(&gate, CountedNames::default());
        assert_eq!(
            copied_circuit.to_string_as_tree(),
            "(((copy_1 ^ copy_2) & true) | false)".replace("copy_2", "copy_1")
        );
    }

    // Table-driven mapper: both input and output names are looked up in the
    // map. Indexing panics on a missing key, which flags an unexpected lookup.
    impl NameMapper for HashMap<&str, String> {
        fn map_input_name(&mut self, n: &str) -> String {
            self[n].clone()
        }
        fn map_output_name(&mut self, n: &str) -> String {
            self[n].clone()
        }
    }

    // Two distinct inputs mapped to the same new name become one variable.
    #[test]
    fn with_input_repetitions() {
        let substitutions = HashMap::from([("v1", "x".to_string()), ("v2", "x".to_string())]);
        let circuit = Gate::from("v1") & Gate::from("v2");
        let copied_circuit = deep_copy_of_gate(&circuit, substitutions);
        assert_eq!(copied_circuit.to_string_as_tree(), "(x & x)");
    }

    // Renaming keeps the declared input order and the output order.
    #[test]
    fn circuit_copy() {
        let out_b = Gate::from("v1") | Gate::from("v2");
        let out_a = Gate::from("v1") ^ Gate::from("v3");
        let circuit = Circuit::from_named_outputs([(out_a, "a"), (out_b, "b")])
            .with_input_order(["v2", "v3", "v1"])
            .unwrap();
        let substitutions = HashMap::from([
            ("v1", "r1".to_string()),
            ("v2", "r2".to_string()),
            ("v3", "r3".to_string()),
            ("a", "x".to_string()),
            ("b", "y".to_string()),
        ]);
        let copy = deep_copy(&circuit, substitutions).unwrap();
        assert_eq!(copy.input_names().collect_vec(), vec!["r2", "r3", "r1"]);
        assert_eq!(copy.output_names(), vec!["x".to_string(), "y".to_string()]);
    }

    // When inputs collapse onto the same name, the input order keeps the first
    // occurrence: [v2, v3, v1] -> [r2, r1, r2] -> [r2, r1].
    #[test]
    fn input_order_with_repetitions() {
        let out_b = Gate::from("v1") | Gate::from("v2");
        let out_a = Gate::from("v1") ^ Gate::from("v3");
        let circuit = Circuit::from_named_outputs([(out_a, "a"), (out_b, "b")])
            .with_input_order(["v2", "v3", "v1"])
            .unwrap();
        let substitutions = HashMap::from([
            ("v1", "r2".to_string()),
            ("v2", "r2".to_string()),
            ("v3", "r1".to_string()),
            ("a", "x".to_string()),
            ("b", "y".to_string()),
        ]);
        let copy = deep_copy(&circuit, substitutions).unwrap();
        assert_eq!(copy.input_names().collect_vec(), vec!["r2", "r1"]);
        assert_eq!(copy.output_names(), vec!["x".to_string(), "y".to_string()]);
    }
}