leo_ast/passes/reconstructor.rs

// Copyright (C) 2019-2025 Provable Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.

//! This module contains the reconstructor traits for the AST
//! (`ExpressionReconstructor`, `StatementReconstructor`, and `ProgramReconstructor`).
//! Their default methods consume an existing node and build a new node
//! from the information held in the old one.

21use crate::*;
22
23/// A Reconstructor trait for expressions in the AST.
24pub trait ExpressionReconstructor {
25    type AdditionalOutput: Default;
26
27    fn reconstruct_expression(&mut self, input: Expression) -> (Expression, Self::AdditionalOutput) {
28        match input {
29            Expression::AssociatedConstant(constant) => self.reconstruct_associated_constant(constant),
30            Expression::AssociatedFunction(function) => self.reconstruct_associated_function(function),
31            Expression::Array(array) => self.reconstruct_array(array),
32            Expression::ArrayAccess(access) => self.reconstruct_array_access(*access),
33            Expression::Binary(binary) => self.reconstruct_binary(*binary),
34            Expression::Call(call) => self.reconstruct_call(*call),
35            Expression::Cast(cast) => self.reconstruct_cast(*cast),
36            Expression::Struct(struct_) => self.reconstruct_struct_init(struct_),
37            Expression::Err(err) => self.reconstruct_err(err),
38            Expression::Identifier(identifier) => self.reconstruct_identifier(identifier),
39            Expression::Literal(value) => self.reconstruct_literal(value),
40            Expression::Locator(locator) => self.reconstruct_locator(locator),
41            Expression::MemberAccess(access) => self.reconstruct_member_access(*access),
42            Expression::Ternary(ternary) => self.reconstruct_ternary(*ternary),
43            Expression::Tuple(tuple) => self.reconstruct_tuple(tuple),
44            Expression::TupleAccess(access) => self.reconstruct_tuple_access(*access),
45            Expression::Unary(unary) => self.reconstruct_unary(*unary),
46            Expression::Unit(unit) => self.reconstruct_unit(unit),
47        }
48    }
49
50    fn reconstruct_array_access(&mut self, input: ArrayAccess) -> (Expression, Self::AdditionalOutput) {
51        (
52            ArrayAccess {
53                array: self.reconstruct_expression(input.array).0,
54                index: self.reconstruct_expression(input.index).0,
55                ..input
56            }
57            .into(),
58            Default::default(),
59        )
60    }
61
62    fn reconstruct_associated_constant(
63        &mut self,
64        input: AssociatedConstantExpression,
65    ) -> (Expression, Self::AdditionalOutput) {
66        (input.into(), Default::default())
67    }
68
69    fn reconstruct_associated_function(
70        &mut self,
71        input: AssociatedFunctionExpression,
72    ) -> (Expression, Self::AdditionalOutput) {
73        (
74            AssociatedFunctionExpression {
75                arguments: input.arguments.into_iter().map(|arg| self.reconstruct_expression(arg).0).collect(),
76                ..input
77            }
78            .into(),
79            Default::default(),
80        )
81    }
82
83    fn reconstruct_member_access(&mut self, input: MemberAccess) -> (Expression, Self::AdditionalOutput) {
84        (MemberAccess { inner: self.reconstruct_expression(input.inner).0, ..input }.into(), Default::default())
85    }
86
87    fn reconstruct_tuple_access(&mut self, input: TupleAccess) -> (Expression, Self::AdditionalOutput) {
88        (TupleAccess { tuple: self.reconstruct_expression(input.tuple).0, ..input }.into(), Default::default())
89    }
90
91    fn reconstruct_array(&mut self, input: ArrayExpression) -> (Expression, Self::AdditionalOutput) {
92        (
93            ArrayExpression {
94                elements: input.elements.into_iter().map(|element| self.reconstruct_expression(element).0).collect(),
95                ..input
96            }
97            .into(),
98            Default::default(),
99        )
100    }
101
102    fn reconstruct_binary(&mut self, input: BinaryExpression) -> (Expression, Self::AdditionalOutput) {
103        (
104            BinaryExpression {
105                left: self.reconstruct_expression(input.left).0,
106                right: self.reconstruct_expression(input.right).0,
107                ..input
108            }
109            .into(),
110            Default::default(),
111        )
112    }
113
114    fn reconstruct_call(&mut self, input: CallExpression) -> (Expression, Self::AdditionalOutput) {
115        (
116            CallExpression {
117                arguments: input.arguments.into_iter().map(|arg| self.reconstruct_expression(arg).0).collect(),
118                ..input
119            }
120            .into(),
121            Default::default(),
122        )
123    }
124
125    fn reconstruct_cast(&mut self, input: CastExpression) -> (Expression, Self::AdditionalOutput) {
126        (
127            CastExpression { expression: self.reconstruct_expression(input.expression).0, ..input }.into(),
128            Default::default(),
129        )
130    }
131
132    fn reconstruct_struct_init(&mut self, input: StructExpression) -> (Expression, Self::AdditionalOutput) {
133        (
134            StructExpression {
135                members: input
136                    .members
137                    .into_iter()
138                    .map(|member| StructVariableInitializer {
139                        identifier: member.identifier,
140                        expression: member.expression.map(|expr| self.reconstruct_expression(expr).0),
141                        span: member.span,
142                        id: member.id,
143                    })
144                    .collect(),
145                ..input
146            }
147            .into(),
148            Default::default(),
149        )
150    }
151
152    fn reconstruct_err(&mut self, _input: ErrExpression) -> (Expression, Self::AdditionalOutput) {
153        panic!("`ErrExpression`s should not be in the AST at this phase of compilation.")
154    }
155
156    fn reconstruct_identifier(&mut self, input: Identifier) -> (Expression, Self::AdditionalOutput) {
157        (input.into(), Default::default())
158    }
159
160    fn reconstruct_literal(&mut self, input: Literal) -> (Expression, Self::AdditionalOutput) {
161        (input.into(), Default::default())
162    }
163
164    fn reconstruct_locator(&mut self, input: LocatorExpression) -> (Expression, Self::AdditionalOutput) {
165        (input.into(), Default::default())
166    }
167
168    fn reconstruct_ternary(&mut self, input: TernaryExpression) -> (Expression, Self::AdditionalOutput) {
169        (
170            TernaryExpression {
171                condition: self.reconstruct_expression(input.condition).0,
172                if_true: self.reconstruct_expression(input.if_true).0,
173                if_false: self.reconstruct_expression(input.if_false).0,
174                span: input.span,
175                id: input.id,
176            }
177            .into(),
178            Default::default(),
179        )
180    }
181
182    fn reconstruct_tuple(&mut self, input: TupleExpression) -> (Expression, Self::AdditionalOutput) {
183        (
184            TupleExpression {
185                elements: input.elements.into_iter().map(|element| self.reconstruct_expression(element).0).collect(),
186                ..input
187            }
188            .into(),
189            Default::default(),
190        )
191    }
192
193    fn reconstruct_unary(&mut self, input: UnaryExpression) -> (Expression, Self::AdditionalOutput) {
194        (
195            UnaryExpression { receiver: self.reconstruct_expression(input.receiver).0, ..input }.into(),
196            Default::default(),
197        )
198    }
199
200    fn reconstruct_unit(&mut self, input: UnitExpression) -> (Expression, Self::AdditionalOutput) {
201        (input.into(), Default::default())
202    }
203}
204
205/// A Reconstructor trait for statements in the AST.
206pub trait StatementReconstructor: ExpressionReconstructor {
207    fn reconstruct_statement(&mut self, input: Statement) -> (Statement, Self::AdditionalOutput) {
208        match input {
209            Statement::Assert(assert) => self.reconstruct_assert(assert),
210            Statement::Assign(stmt) => self.reconstruct_assign(*stmt),
211            Statement::Block(stmt) => {
212                let (stmt, output) = self.reconstruct_block(stmt);
213                (stmt.into(), output)
214            }
215            Statement::Conditional(stmt) => self.reconstruct_conditional(stmt),
216            Statement::Const(stmt) => self.reconstruct_const(stmt),
217            Statement::Definition(stmt) => self.reconstruct_definition(stmt),
218            Statement::Expression(stmt) => self.reconstruct_expression_statement(stmt),
219            Statement::Iteration(stmt) => self.reconstruct_iteration(*stmt),
220            Statement::Return(stmt) => self.reconstruct_return(stmt),
221        }
222    }
223
224    fn reconstruct_assert(&mut self, input: AssertStatement) -> (Statement, Self::AdditionalOutput) {
225        (
226            AssertStatement {
227                variant: match input.variant {
228                    AssertVariant::Assert(expr) => AssertVariant::Assert(self.reconstruct_expression(expr).0),
229                    AssertVariant::AssertEq(left, right) => AssertVariant::AssertEq(
230                        self.reconstruct_expression(left).0,
231                        self.reconstruct_expression(right).0,
232                    ),
233                    AssertVariant::AssertNeq(left, right) => AssertVariant::AssertNeq(
234                        self.reconstruct_expression(left).0,
235                        self.reconstruct_expression(right).0,
236                    ),
237                },
238                ..input
239            }
240            .into(),
241            Default::default(),
242        )
243    }
244
245    fn reconstruct_assign(&mut self, input: AssignStatement) -> (Statement, Self::AdditionalOutput) {
246        (AssignStatement { value: self.reconstruct_expression(input.value).0, ..input }.into(), Default::default())
247    }
248
249    fn reconstruct_block(&mut self, input: Block) -> (Block, Self::AdditionalOutput) {
250        (
251            Block {
252                statements: input.statements.into_iter().map(|s| self.reconstruct_statement(s).0).collect(),
253                span: input.span,
254                id: input.id,
255            },
256            Default::default(),
257        )
258    }
259
260    fn reconstruct_conditional(&mut self, input: ConditionalStatement) -> (Statement, Self::AdditionalOutput) {
261        (
262            ConditionalStatement {
263                condition: self.reconstruct_expression(input.condition).0,
264                then: self.reconstruct_block(input.then).0,
265                otherwise: input.otherwise.map(|n| Box::new(self.reconstruct_statement(*n).0)),
266                ..input
267            }
268            .into(),
269            Default::default(),
270        )
271    }
272
273    fn reconstruct_const(&mut self, input: ConstDeclaration) -> (Statement, Self::AdditionalOutput) {
274        (ConstDeclaration { value: self.reconstruct_expression(input.value).0, ..input }.into(), Default::default())
275    }
276
277    fn reconstruct_definition(&mut self, input: DefinitionStatement) -> (Statement, Self::AdditionalOutput) {
278        (DefinitionStatement { value: self.reconstruct_expression(input.value).0, ..input }.into(), Default::default())
279    }
280
281    fn reconstruct_expression_statement(&mut self, input: ExpressionStatement) -> (Statement, Self::AdditionalOutput) {
282        (
283            ExpressionStatement { expression: self.reconstruct_expression(input.expression).0, ..input }.into(),
284            Default::default(),
285        )
286    }
287
288    fn reconstruct_iteration(&mut self, input: IterationStatement) -> (Statement, Self::AdditionalOutput) {
289        (
290            IterationStatement {
291                start: self.reconstruct_expression(input.start).0,
292                stop: self.reconstruct_expression(input.stop).0,
293                block: self.reconstruct_block(input.block).0,
294                ..input
295            }
296            .into(),
297            Default::default(),
298        )
299    }
300
301    fn reconstruct_return(&mut self, input: ReturnStatement) -> (Statement, Self::AdditionalOutput) {
302        (
303            ReturnStatement { expression: self.reconstruct_expression(input.expression).0, ..input }.into(),
304            Default::default(),
305        )
306    }
307}
308
309/// A Reconstructor trait for the program represented by the AST.
310pub trait ProgramReconstructor: StatementReconstructor {
311    fn reconstruct_program(&mut self, input: Program) -> Program {
312        Program {
313            imports: input
314                .imports
315                .into_iter()
316                .map(|(id, import)| (id, (self.reconstruct_import(import.0), import.1)))
317                .collect(),
318            stubs: input.stubs.into_iter().map(|(id, stub)| (id, self.reconstruct_stub(stub))).collect(),
319            program_scopes: input
320                .program_scopes
321                .into_iter()
322                .map(|(id, scope)| (id, self.reconstruct_program_scope(scope)))
323                .collect(),
324        }
325    }
326
327    fn reconstruct_stub(&mut self, input: Stub) -> Stub {
328        Stub {
329            imports: input.imports,
330            stub_id: input.stub_id,
331            consts: input.consts,
332            structs: input.structs,
333            mappings: input.mappings,
334            span: input.span,
335            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function_stub(f))).collect(),
336        }
337    }
338
339    fn reconstruct_program_scope(&mut self, input: ProgramScope) -> ProgramScope {
340        ProgramScope {
341            program_id: input.program_id,
342            structs: input.structs.into_iter().map(|(i, c)| (i, self.reconstruct_struct(c))).collect(),
343            mappings: input.mappings.into_iter().map(|(id, mapping)| (id, self.reconstruct_mapping(mapping))).collect(),
344            functions: input.functions.into_iter().map(|(i, f)| (i, self.reconstruct_function(f))).collect(),
345            consts: input
346                .consts
347                .into_iter()
348                .map(|(i, c)| match self.reconstruct_const(c) {
349                    (Statement::Const(declaration), _) => (i, declaration),
350                    _ => panic!("`reconstruct_const` can only return `Statement::Const`"),
351                })
352                .collect(),
353            span: input.span,
354        }
355    }
356
357    fn reconstruct_function(&mut self, input: Function) -> Function {
358        Function {
359            annotations: input.annotations,
360            variant: input.variant,
361            identifier: input.identifier,
362            input: input.input,
363            output: input.output,
364            output_type: input.output_type,
365            block: self.reconstruct_block(input.block).0,
366            span: input.span,
367            id: input.id,
368        }
369    }
370
371    fn reconstruct_function_stub(&mut self, input: FunctionStub) -> FunctionStub {
372        input
373    }
374
375    fn reconstruct_struct(&mut self, input: Composite) -> Composite {
376        input
377    }
378
379    fn reconstruct_import(&mut self, input: Program) -> Program {
380        self.reconstruct_program(input)
381    }
382
383    fn reconstruct_mapping(&mut self, input: Mapping) -> Mapping {
384        input
385    }
386}