// circomspect_program_structure/program_library/template_data.rs

1use super::ast;
2use super::ast::{FillMeta, Statement};
3use super::file_definition::FileID;
4use crate::file_definition::FileLocation;
5use std::collections::{HashMap, HashSet, BTreeMap};
6
7pub type TagInfo = HashSet<String>;
8pub type TemplateInfo = HashMap<String, TemplateData>;
9type SignalInfo = BTreeMap<String, (usize, TagInfo)>;
10type SignalDeclarationOrder = Vec<(String, usize)>;
11
12#[derive(Clone)]
13pub struct TemplateData {
14    file_id: FileID,
15    name: String,
16    body: Statement,
17    num_of_params: usize,
18    name_of_params: Vec<String>,
19    param_location: FileLocation,
20    input_signals: SignalInfo,
21    output_signals: SignalInfo,
22    is_parallel: bool,
23    is_custom_gate: bool,
24    // Only used to know the order in which signals are declared.
25    input_declarations: SignalDeclarationOrder,
26    output_declarations: SignalDeclarationOrder,
27}
28
29impl TemplateData {
30    #[allow(clippy::too_many_arguments)]
31    pub fn new(
32        name: String,
33        file_id: FileID,
34        mut body: Statement,
35        num_of_params: usize,
36        name_of_params: Vec<String>,
37        param_location: FileLocation,
38        elem_id: &mut usize,
39        is_parallel: bool,
40        is_custom_gate: bool,
41    ) -> TemplateData {
42        body.fill(file_id, elem_id);
43        let mut input_signals = SignalInfo::new();
44        let mut output_signals = SignalInfo::new();
45        let mut input_declarations = SignalDeclarationOrder::new();
46        let mut output_declarations = SignalDeclarationOrder::new();
47        fill_inputs_and_outputs(
48            &body,
49            &mut input_signals,
50            &mut output_signals,
51            &mut input_declarations,
52            &mut output_declarations,
53        );
54        TemplateData {
55            name,
56            file_id,
57            body,
58            num_of_params,
59            name_of_params,
60            param_location,
61            input_signals,
62            output_signals,
63            is_parallel,
64            is_custom_gate,
65            input_declarations,
66            output_declarations,
67        }
68    }
69
70    pub fn get_file_id(&self) -> FileID {
71        self.file_id
72    }
73
74    pub fn get_body(&self) -> &Statement {
75        &self.body
76    }
77
78    pub fn get_body_as_vec(&self) -> &Vec<Statement> {
79        match &self.body {
80            Statement::Block { stmts, .. } => stmts,
81            _ => panic!("Function body should be a block"),
82        }
83    }
84
85    pub fn get_mut_body(&mut self) -> &mut Statement {
86        &mut self.body
87    }
88
89    pub fn get_mut_body_as_vec(&mut self) -> &mut Vec<Statement> {
90        match &mut self.body {
91            Statement::Block { stmts, .. } => stmts,
92            _ => panic!("Function body should be a block"),
93        }
94    }
95
96    pub fn get_num_of_params(&self) -> usize {
97        self.num_of_params
98    }
99
100    pub fn get_param_location(&self) -> FileLocation {
101        self.param_location.clone()
102    }
103
104    pub fn get_name_of_params(&self) -> &Vec<String> {
105        &self.name_of_params
106    }
107
108    pub fn get_input_info(&self, name: &str) -> Option<&(usize, TagInfo)> {
109        self.input_signals.get(name)
110    }
111
112    pub fn get_output_info(&self, name: &str) -> Option<&(usize, TagInfo)> {
113        self.output_signals.get(name)
114    }
115    pub fn get_inputs(&self) -> &SignalInfo {
116        &self.input_signals
117    }
118    pub fn get_outputs(&self) -> &SignalInfo {
119        &self.output_signals
120    }
121    pub fn get_declaration_inputs(&self) -> &SignalDeclarationOrder {
122        &self.input_declarations
123    }
124    pub fn get_declaration_outputs(&self) -> &SignalDeclarationOrder {
125        &self.output_declarations
126    }
127    pub fn get_name(&self) -> &str {
128        &self.name
129    }
130    pub fn is_parallel(&self) -> bool {
131        self.is_parallel
132    }
133    pub fn is_custom_gate(&self) -> bool {
134        self.is_custom_gate
135    }
136}
137
138fn fill_inputs_and_outputs(
139    template_statement: &Statement,
140    input_signals: &mut SignalInfo,
141    output_signals: &mut SignalInfo,
142    input_declarations: &mut SignalDeclarationOrder,
143    output_declarations: &mut SignalDeclarationOrder,
144) {
145    match template_statement {
146        Statement::IfThenElse { if_case, else_case, .. } => {
147            fill_inputs_and_outputs(
148                if_case,
149                input_signals,
150                output_signals,
151                input_declarations,
152                output_declarations,
153            );
154            if let Option::Some(else_value) = else_case {
155                fill_inputs_and_outputs(
156                    else_value,
157                    input_signals,
158                    output_signals,
159                    input_declarations,
160                    output_declarations,
161                );
162            }
163        }
164        Statement::Block { stmts, .. } => {
165            for stmt in stmts.iter() {
166                fill_inputs_and_outputs(
167                    stmt,
168                    input_signals,
169                    output_signals,
170                    input_declarations,
171                    output_declarations,
172                );
173            }
174        }
175        Statement::While { stmt, .. } => {
176            fill_inputs_and_outputs(
177                stmt,
178                input_signals,
179                output_signals,
180                input_declarations,
181                output_declarations,
182            );
183        }
184        Statement::InitializationBlock { initializations, .. } => {
185            for initialization in initializations.iter() {
186                fill_inputs_and_outputs(
187                    initialization,
188                    input_signals,
189                    output_signals,
190                    input_declarations,
191                    output_declarations,
192                );
193            }
194        }
195        Statement::Declaration {
196            xtype: ast::VariableType::Signal(stype, tag_list),
197            name,
198            dimensions,
199            ..
200        } => {
201            let signal_name = name.clone();
202            let dimensions = dimensions.len();
203            let mut tag_info = HashSet::new();
204            for tag in tag_list {
205                tag_info.insert(tag.clone());
206            }
207
208            match stype {
209                ast::SignalType::Input => {
210                    input_signals.insert(signal_name.clone(), (dimensions, tag_info));
211                    input_declarations.push((signal_name, dimensions));
212                }
213                ast::SignalType::Output => {
214                    output_signals.insert(signal_name.clone(), (dimensions, tag_info));
215                    output_declarations.push((signal_name, dimensions));
216                }
217                _ => {} //no need to deal with intermediate signals
218            }
219        }
220        _ => {}
221    }
222}