circomspect_program_analysis/
constraint_analysis.rs

1use log::{debug, trace};
2use std::collections::{HashMap, HashSet};
3
4use program_structure::cfg::Cfg;
5use program_structure::intermediate_representation::variable_meta::VariableMeta;
6use program_structure::intermediate_representation::AssignOp;
7use program_structure::ir::variable_meta::VariableUse;
8use program_structure::ir::{Statement, VariableName};
9
10/// This analysis computes the transitive closure of the constraint relation.
11/// (Note that the resulting relation will be symmetric, but not reflexive in
12/// general.)
13#[derive(Clone, Default)]
14pub struct ConstraintAnalysis {
15    constraint_map: HashMap<VariableName, HashSet<VariableName>>,
16    declarations: HashMap<VariableName, VariableUse>,
17    definitions: HashMap<VariableName, VariableUse>,
18}
19
20impl ConstraintAnalysis {
21    fn new() -> ConstraintAnalysis {
22        ConstraintAnalysis::default()
23    }
24
25    /// Add the variable use corresponding to the definition of the variable.
26    fn add_definition(&mut self, var: &VariableUse) {
27        // TODO: Since we don't version components and signals, we may end up
28        // overwriting component initializations here. For example, in the
29        // following case the component initialization will be clobbered.
30        //
31        //   component c[2];
32        //   ...
33        //   c[0].in[0] <== 0;
34        //   c[1].in[1] <== 1;
35        //
36        // The constraint map should probably track VariableAccesses rather
37        // than VariableNames.
38        self.definitions.insert(var.name().clone(), var.clone());
39    }
40
41    /// Get the variable use corresponding to the definition of the variable.
42    pub fn get_definition(&self, var: &VariableName) -> Option<VariableUse> {
43        self.definitions.get(var).cloned()
44    }
45
46    pub fn definitions(&self) -> impl Iterator<Item = &VariableUse> {
47        self.definitions.values()
48    }
49
50    /// Add the variable use corresponding to the declaration of the variable.
51    fn add_declaration(&mut self, var: &VariableUse) {
52        self.declarations.insert(var.name().clone(), var.clone());
53    }
54
55    /// Get the variable use corresponding to the declaration of the variable.
56    pub fn get_declaration(&self, var: &VariableName) -> Option<VariableUse> {
57        self.declarations.get(var).cloned()
58    }
59
60    pub fn declarations(&self) -> impl Iterator<Item = &VariableUse> {
61        self.declarations.values()
62    }
63
64    /// Add a constraint from source to sink.
65    fn add_constraint_step(&mut self, source: &VariableName, sink: &VariableName) {
66        let sinks = self.constraint_map.entry(source.clone()).or_default();
67        sinks.insert(sink.clone());
68    }
69
70    /// Returns variables constrained in a single step by `source`.
71    pub fn single_step_constraint(&self, source: &VariableName) -> HashSet<VariableName> {
72        self.constraint_map.get(source).cloned().unwrap_or_default()
73    }
74
75    /// Returns variables constrained in one or more steps by `source`.
76    pub fn multi_step_constraint(&self, source: &VariableName) -> HashSet<VariableName> {
77        let mut result = HashSet::new();
78        let mut update = self.single_step_constraint(source);
79        while !update.is_subset(&result) {
80            result.extend(update.iter().cloned());
81            update = update.iter().flat_map(|source| self.single_step_constraint(source)).collect();
82        }
83        result
84    }
85
86    /// Returns true if the source constrains any of the sinks.
87    pub fn constrains_any(&self, source: &VariableName, sinks: &HashSet<VariableName>) -> bool {
88        self.multi_step_constraint(source).iter().any(|sink| sinks.contains(sink))
89    }
90
91    /// Returns the set of variables occurring in a constraint together with at
92    /// least one other variable.
93    pub fn constrained_variables(&self) -> HashSet<VariableName> {
94        self.constraint_map.keys().cloned().collect::<HashSet<_>>()
95    }
96}
97
98pub fn run_constraint_analysis(cfg: &Cfg) -> ConstraintAnalysis {
99    debug!("running constraint analysis pass");
100    let mut result = ConstraintAnalysis::new();
101
102    use AssignOp::*;
103    use Statement::*;
104    for basic_block in cfg.iter() {
105        for stmt in basic_block.iter() {
106            trace!("visiting statement `{stmt:?}`");
107            // Add definitions to the result.
108            for var in stmt.variables_written() {
109                result.add_definition(var);
110            }
111            match stmt {
112                Declaration { meta, names, .. } => {
113                    // Add declarations to the result.
114                    for sink in names {
115                        result.add_declaration(&VariableUse::new(meta, sink, &Vec::new()));
116                    }
117                }
118                ConstraintEquality { .. } | Substitution { op: AssignConstraintSignal, .. } => {
119                    for source in stmt.variables_used() {
120                        for sink in stmt.variables_used() {
121                            if source.name() != sink.name() {
122                                trace!(
123                                    "adding constraint step with source `{:?}` and sink `{:?}`",
124                                    source.name(),
125                                    sink.name()
126                                );
127                                result.add_constraint_step(source.name(), sink.name());
128                            }
129                        }
130                    }
131                }
132                _ => {}
133            }
134        }
135    }
136    result
137}
138
#[cfg(test)]
mod tests {
    use parser::parse_definition;
    use program_structure::cfg::IntoCfg;
    use program_structure::constants::Curve;
    use program_structure::report::ReportCollection;

    use super::*;

    /// `in` is related to `tmp` and to `out` through `<==` constraints.
    const SUBSTITUTION_SRC: &str = r#"
        template T(n) {
            signal input in;
            signal output out;
            signal tmp;

            tmp <== 2 * in;
            out <== in * in;

        }
    "#;

    /// Same relation as `SUBSTITUTION_SRC`, but `tmp` is related to `in`
    /// through a `===` constraint.
    const EQUALITY_SRC: &str = r#"
        template T(n) {
            signal input in;
            signal output out;
            signal tmp;

            tmp === 2 * in;
            out <== in * in;

        }
    "#;

    #[test]
    fn test_single_step_constraint() {
        let sources = names(&["in", "out", "tmp"]);
        let sinks = [2, 1, 1];
        validate_constraints(SUBSTITUTION_SRC, &sources, &sinks);
        validate_constraints(EQUALITY_SRC, &sources, &sinks);
    }

    #[test]
    fn test_multi_step_constraint() {
        for src in [SUBSTITUTION_SRC, EQUALITY_SRC] {
            let analysis = run_analysis(src);
            let constrained: HashSet<VariableName> =
                names(&["in", "out", "tmp"]).into_iter().collect();

            // `out` and `tmp` are related through `in`, and since steps are
            // recorded in both directions every constrained variable also
            // reaches itself.
            for var in &constrained {
                assert_eq!(analysis.multi_step_constraint(var), constrained);
            }
            assert!(analysis.constrains_any(
                &VariableName::from_string("out"),
                &HashSet::from([VariableName::from_string("tmp")]),
            ));

            // The parameter `n` does not occur in any constraint.
            let n = VariableName::from_string("n");
            assert!(analysis.multi_step_constraint(&n).is_empty());
            assert!(!analysis.constrains_any(&n, &constrained));
        }
    }

    /// Converts a list of plain names into variable names.
    fn names(names: &[&str]) -> Vec<VariableName> {
        names.iter().map(|name| VariableName::from_string(*name)).collect()
    }

    /// Builds the SSA CFG for `src` and runs the constraint analysis on it.
    fn run_analysis(src: &str) -> ConstraintAnalysis {
        let mut reports = ReportCollection::new();
        let cfg = parse_definition(src)
            .unwrap()
            .into_cfg(&Curve::default(), &mut reports)
            .unwrap()
            .into_ssa()
            .unwrap();
        assert!(reports.is_empty());
        run_constraint_analysis(&cfg)
    }

    /// Checks that each `sources[i]` constrains exactly `sinks[i]` variables
    /// in a single step.
    fn validate_constraints(src: &str, sources: &[VariableName], sinks: &[usize]) {
        // `zip` would silently ignore any unmatched expectations.
        assert_eq!(sources.len(), sinks.len());

        let constraint_analysis = run_analysis(src);
        for (source, sinks) in sources.iter().zip(sinks) {
            assert_eq!(constraint_analysis.single_step_constraint(source).len(), *sinks)
        }
    }
}