circomspect_program_analysis/
constraint_analysis.rs1use log::{debug, trace};
2use std::collections::{HashMap, HashSet};
3
4use program_structure::cfg::Cfg;
5use program_structure::intermediate_representation::variable_meta::VariableMeta;
6use program_structure::intermediate_representation::AssignOp;
7use program_structure::ir::variable_meta::VariableUse;
8use program_structure::ir::{Statement, VariableName};
9
10#[derive(Clone, Default)]
14pub struct ConstraintAnalysis {
15 constraint_map: HashMap<VariableName, HashSet<VariableName>>,
16 declarations: HashMap<VariableName, VariableUse>,
17 definitions: HashMap<VariableName, VariableUse>,
18}
19
20impl ConstraintAnalysis {
21 fn new() -> ConstraintAnalysis {
22 ConstraintAnalysis::default()
23 }
24
25 fn add_definition(&mut self, var: &VariableUse) {
27 self.definitions.insert(var.name().clone(), var.clone());
39 }
40
41 pub fn get_definition(&self, var: &VariableName) -> Option<VariableUse> {
43 self.definitions.get(var).cloned()
44 }
45
46 pub fn definitions(&self) -> impl Iterator<Item = &VariableUse> {
47 self.definitions.values()
48 }
49
50 fn add_declaration(&mut self, var: &VariableUse) {
52 self.declarations.insert(var.name().clone(), var.clone());
53 }
54
55 pub fn get_declaration(&self, var: &VariableName) -> Option<VariableUse> {
57 self.declarations.get(var).cloned()
58 }
59
60 pub fn declarations(&self) -> impl Iterator<Item = &VariableUse> {
61 self.declarations.values()
62 }
63
64 fn add_constraint_step(&mut self, source: &VariableName, sink: &VariableName) {
66 let sinks = self.constraint_map.entry(source.clone()).or_default();
67 sinks.insert(sink.clone());
68 }
69
70 pub fn single_step_constraint(&self, source: &VariableName) -> HashSet<VariableName> {
72 self.constraint_map.get(source).cloned().unwrap_or_default()
73 }
74
75 pub fn multi_step_constraint(&self, source: &VariableName) -> HashSet<VariableName> {
77 let mut result = HashSet::new();
78 let mut update = self.single_step_constraint(source);
79 while !update.is_subset(&result) {
80 result.extend(update.iter().cloned());
81 update = update.iter().flat_map(|source| self.single_step_constraint(source)).collect();
82 }
83 result
84 }
85
86 pub fn constrains_any(&self, source: &VariableName, sinks: &HashSet<VariableName>) -> bool {
88 self.multi_step_constraint(source).iter().any(|sink| sinks.contains(sink))
89 }
90
91 pub fn constrained_variables(&self) -> HashSet<VariableName> {
94 self.constraint_map.keys().cloned().collect::<HashSet<_>>()
95 }
96}
97
98pub fn run_constraint_analysis(cfg: &Cfg) -> ConstraintAnalysis {
99 debug!("running constraint analysis pass");
100 let mut result = ConstraintAnalysis::new();
101
102 use AssignOp::*;
103 use Statement::*;
104 for basic_block in cfg.iter() {
105 for stmt in basic_block.iter() {
106 trace!("visiting statement `{stmt:?}`");
107 for var in stmt.variables_written() {
109 result.add_definition(var);
110 }
111 match stmt {
112 Declaration { meta, names, .. } => {
113 for sink in names {
115 result.add_declaration(&VariableUse::new(meta, sink, &Vec::new()));
116 }
117 }
118 ConstraintEquality { .. } | Substitution { op: AssignConstraintSignal, .. } => {
119 for source in stmt.variables_used() {
120 for sink in stmt.variables_used() {
121 if source.name() != sink.name() {
122 trace!(
123 "adding constraint step with source `{:?}` and sink `{:?}`",
124 source.name(),
125 sink.name()
126 );
127 result.add_constraint_step(source.name(), sink.name());
128 }
129 }
130 }
131 }
132 _ => {}
133 }
134 }
135 }
136 result
137}
138
#[cfg(test)]
mod tests {
    use parser::parse_definition;
    use program_structure::cfg::IntoCfg;
    use program_structure::constants::Curve;
    use program_structure::report::ReportCollection;

    use super::*;

    #[test]
    fn test_single_step_constraint() {
        let src = r#"
            template T(n) {
                signal input in;
                signal output out;
                signal tmp;

                tmp <== 2 * in;
                out <== in * in;

            }
        "#;
        let sources = [
            VariableName::from_string("in"),
            VariableName::from_string("out"),
            VariableName::from_string("tmp"),
        ];
        let sinks = [2, 1, 1];
        validate_constraints(src, &sources, &sinks);

        let src = r#"
            template T(n) {
                signal input in;
                signal output out;
                signal tmp;

                tmp === 2 * in;
                out <== in * in;

            }
        "#;
        let sources = [
            VariableName::from_string("in"),
            VariableName::from_string("out"),
            VariableName::from_string("tmp"),
        ];
        let sinks = [2, 1, 1];
        validate_constraints(src, &sources, &sinks);
    }

    #[test]
    fn test_multi_step_constraint() {
        let src = r#"
            template T(n) {
                signal input in;
                signal output out;
                signal tmp;

                tmp <== 2 * in;
                out <== tmp * tmp;

            }
        "#;
        let constraint_analysis = run_analysis(src);
        let in_var = VariableName::from_string("in");
        let out_var = VariableName::from_string("out");
        let tmp_var = VariableName::from_string("tmp");

        // `in` is only constrained with `out` indirectly, through `tmp`.
        assert!(!constraint_analysis.single_step_constraint(&in_var).contains(&out_var));

        let reachable = constraint_analysis.multi_step_constraint(&in_var);
        assert!(reachable.contains(&tmp_var));
        assert!(reachable.contains(&out_var));
        // Edges are added in both directions, so `in` reaches itself via `tmp`.
        assert!(reachable.contains(&in_var));
        assert_eq!(reachable.len(), 3);

        let sinks = HashSet::from([out_var.clone()]);
        assert!(constraint_analysis.constrains_any(&in_var, &sinks));
        assert!(!constraint_analysis.constrains_any(&in_var, &HashSet::new()));
    }

    /// Parses `src`, lowers it to an SSA CFG, and runs the constraint analysis on it.
    fn run_analysis(src: &str) -> ConstraintAnalysis {
        let mut reports = ReportCollection::new();
        let cfg = parse_definition(src)
            .unwrap()
            .into_cfg(&Curve::default(), &mut reports)
            .unwrap()
            .into_ssa()
            .unwrap();
        assert!(reports.is_empty());
        run_constraint_analysis(&cfg)
    }

    /// Asserts that each variable in `sources` has the corresponding number of
    /// single-step constraint sinks in `expected_counts`.
    fn validate_constraints(src: &str, sources: &[VariableName], expected_counts: &[usize]) {
        let constraint_analysis = run_analysis(src);
        for (source, expected) in sources.iter().zip(expected_counts) {
            assert_eq!(
                constraint_analysis.single_step_constraint(source).len(),
                *expected,
                "unexpected number of single-step sinks for `{source:?}`"
            );
        }
    }
}