boolean_circuit/
circuit.rs1use std::collections::HashSet;
2
3use itertools::Itertools;
4
5use crate::{
6    gate::{GraphIterator, PostVisitIterator},
7    Gate, Operation,
8};
9
10#[derive(Default)]
14pub struct Circuit {
15    outputs: Vec<Gate>,
17    input_names: Option<Vec<String>>,
22    output_names: Vec<String>,
24}
25
26impl From<Gate> for Circuit {
27    fn from(gate: Gate) -> Self {
28        Circuit {
29            outputs: vec![gate],
30            input_names: None,
31            output_names: vec![String::new()],
32        }
33    }
34}
35
36impl Circuit {
37    pub fn from_unnamed_outputs(outputs: impl IntoIterator<Item = Gate>) -> Self {
41        Self::from_named_outputs(outputs.into_iter().map(|n| (n, String::new())))
42    }
43
44    pub fn with_input_order(
46        self,
47        input_names: impl IntoIterator<Item = impl ToString>,
48    ) -> Result<Self, String> {
49        let input_names = input_names.into_iter().map(|n| n.to_string()).collect_vec();
50        let inputs_in_circuit = self.input_names_from_traversal().collect::<HashSet<_>>();
51        let input_names_set = input_names
52            .iter()
53            .map(|n| n.as_str())
54            .collect::<HashSet<_>>();
55        if input_names_set.len() != input_names.len() {
56            return Err("Duplicate input names in list.".to_string());
57        }
58        if inputs_in_circuit != input_names_set {
59            return Err(format!(
60                "Input names do not match circuit inputs:\n{}\n  !=\n{}",
61                input_names.iter().sorted().format(", "),
62                inputs_in_circuit.iter().sorted().format(", ")
63            ));
64        }
65
66        Ok(Circuit {
67            input_names: Some(input_names),
68            ..self
69        })
70    }
71
72    pub fn from_named_outputs(items: impl IntoIterator<Item = (Gate, impl ToString)>) -> Self {
80        let mut seen_names: HashSet<_> = Default::default();
81        let mut circuit = Self::default();
82        for (gate, name) in items {
83            let name = name.to_string();
84            if !name.is_empty() && !seen_names.insert(name.clone()) {
85                panic!("Duplicate output name {name}");
86            }
87            circuit.outputs.push(gate);
88            circuit.output_names.push(name);
89        }
90        circuit
91    }
92
93    pub fn input_names(&self) -> impl Iterator<Item = &str> + '_ {
96        if let Some(names) = &self.input_names {
97            Box::new(names.iter().map(|n| n.as_str())) as Box<dyn Iterator<Item = &str>>
98        } else {
99            Box::new(self.input_names_from_traversal())
100        }
101    }
102
103    pub fn outputs(&self) -> &[Gate] {
105        &self.outputs
106    }
107
108    pub fn output_names(&self) -> &[String] {
110        &self.output_names
111    }
112
113    pub fn named_outputs(&self) -> impl Iterator<Item = (&Gate, &String)> {
116        self.outputs().iter().zip_eq(self.output_names())
117    }
118
119    pub fn iter(&self) -> impl Iterator<Item = &Gate> {
122        GraphIterator::new(&self.outputs)
123    }
124
125    pub fn post_visit_iter(&self) -> impl Iterator<Item = &Gate> {
129        PostVisitIterator::new(self.outputs.iter())
130    }
131
132    fn input_names_from_traversal(&self) -> impl Iterator<Item = &str> {
135        self.iter()
136            .filter_map(|gate| match gate.operation() {
137                Operation::Variable(name) => Some(name.as_str()),
138                _ => None,
139            })
140            .unique()
141    }
142}
143
#[cfg(test)]
mod test {
    use super::*;

    /// Without an explicit order the inputs are reported as encountered by
    /// the traversal; after `with_input_order` they follow the given order.
    #[test]
    fn input_order() {
        let conjunction = Gate::from("a") & Gate::from("b");
        let circuit = Circuit::from(conjunction);
        let default_order: Vec<_> = circuit.input_names().collect();
        assert_eq!(default_order, vec!["a", "b"]);

        let circuit = circuit.with_input_order(["b", "a"]).unwrap();
        let explicit_order: Vec<_> = circuit.input_names().collect();
        assert_eq!(explicit_order, vec!["b".to_string(), "a".to_string()]);
    }
}