boolean_circuit/
disjoint_union.rs1use std::collections::HashSet;
2
3use itertools::Itertools;
4
5use crate::{
6 deep_copy::{deep_copy, NameMapper},
7 Circuit,
8};
9
10pub fn disjoint_union<'a>(circuits: impl Iterator<Item = &'a Circuit>) -> Circuit {
14 let mut dispenser = NameDispenser::default();
15
16 let circuits = circuits
17 .map(|circuit| deep_copy(circuit, &mut dispenser).unwrap())
18 .collect_vec();
19
20 Circuit::from_named_outputs(
21 circuits
22 .iter()
23 .flat_map(|circuit| circuit.named_outputs())
24 .map(|(g, n)| (g.clone(), n)),
25 )
26 .with_input_order(circuits.iter().flat_map(|circuit| circuit.input_names()))
27 .unwrap()
28}
29
/// Name mapper state that never hands out the same name twice.
///
/// A requested name is returned unchanged the first time it is seen; later
/// requests for an already used name receive a generated `v_<n>` name instead
/// (see `map_name`).
#[derive(Default)]
struct NameDispenser {
    /// Every name that has been returned by `map_name` so far.
    seen_names: HashSet<String>,
    /// Number used for the most recently generated `v_<n>` candidate.
    counter: u64,
}
37
38impl NameMapper for &mut NameDispenser {
39 fn map_input_name(&mut self, input_name: &str) -> String {
40 self.map_name(input_name)
41 }
42
43 fn map_output_name(&mut self, output_name: &str) -> String {
44 if output_name.is_empty() {
45 String::default()
46 } else {
47 self.map_name(output_name)
48 }
49 }
50}
51
52impl NameDispenser {
53 fn map_name(&mut self, name: &str) -> String {
57 let mut name = name.to_string();
58 loop {
59 if self.seen_names.insert(name.clone()) {
60 return name;
61 }
62 self.counter += 1;
63 name = format!("v_{}", self.counter);
64 }
65 }
66}
67
#[cfg(test)]
mod test {
    use itertools::Itertools;

    use crate::Gate;

    use super::*;

    /// Repeated and already generated names are replaced by fresh `v_<n>` names.
    #[test]
    fn name_dispenser() {
        let mut dispenser = &mut NameDispenser::default();
        let requested = ["a", "b", "v_3", "b", "b", "v_1", "c", "v_5", "a"];

        let assigned = requested.map(|name| dispenser.map_input_name(name));

        assert_eq!(&assigned.join(" "), "a b v_3 v_1 v_2 v_4 c v_5 v_6");
    }

    /// The union of a circuit with itself keeps the first copy's names and
    /// renames the clashing names of the second copy.
    #[test]
    fn disjoint_union_test() {
        let or_gate = Gate::from("v_1") | Gate::from("v_2");
        let xor_gate = Gate::from("v_1") ^ Gate::from("v_3");
        let single = Circuit::from_named_outputs([(xor_gate, "xor"), (or_gate, "or")])
            .with_input_order(["v_2", "v_3", "v_1"])
            .unwrap();

        let union = disjoint_union([&single, &single].into_iter());

        let output_names = union.output_names().iter().join(", ");
        assert_eq!(output_names, "xor, or, v_4, v_5");

        let input_names = union.input_names().join(", ");
        assert_eq!(input_names, "v_2, v_3, v_1, v_8, v_7, v_6");

        let output_trees = union
            .outputs()
            .iter()
            .map(|gate| gate.to_string_as_tree())
            .join("\n");
        assert_eq!(
            output_trees,
            "(v_1 ^ v_3)\n(v_1 | v_2)\n(v_6 ^ v_7)\n(v_6 | v_8)"
        );
    }
}