Skip to main content

leo_passes/common/
block_to_function_rewriter.rs

// Copyright (C) 2019-2026 Provable Inc.
// This file is part of the Leo library.

// The Leo library is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.

// The Leo library is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.

// You should have received a copy of the GNU General Public License
// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
//! Transforms a captured `Block` into a standalone `Function` of the requested
//! `Variant` (for example, an async function) plus a corresponding call expression.
//!
//! This pass analyzes symbol accesses inside the block, determines which
//! variables must become parameters, and synthesizes the necessary `Input`s
//! and call-site arguments. Tuple and tuple-field accesses are normalized so
//! that each accessed element becomes a unique parameter, with full-tuple
//! reconstruction when needed.
//!
//! The original block is then reconstructed with all symbol references
//! replaced by these synthesized parameters. The result is a function
//! that encapsulates the block's logic and a call expression that invokes it.
//!
//! # Example
//! ```leo
//! // Original code (the braced block is the one being captured)
//! let a: i32 = 1;
//! let b: i32 = 2;
//! let c: (i32, i32) = (3, 4);
//! {
//!     let x = a + b;
//!     let y = c.0 + c.1;
//!     ..
//! }
//!
//! // Rewritten as a function + call expression (assuming `function_variant` is `AsyncFunction`)
//! async function generated_async(a: i32, b: i32, "c.0": i32, "c.1": i32) {
//!     let x = a + b;
//!     let y = "c.0" + "c.1";
//!     ..
//! }
//!
//! // Call
//! generated_async(a, b, c.0, c.1);
//! ```
52
53use crate::{CompilerState, Replacer, SymbolAccessCollector};
54
55use leo_ast::{
56    AstReconstructor,
57    AstVisitor,
58    Block,
59    CallExpression,
60    Expression,
61    Function,
62    Identifier,
63    Input,
64    Location,
65    Node,
66    Path,
67    TupleAccess,
68    TupleExpression,
69    TupleType,
70    Type,
71    Variant,
72};
73use leo_span::{Span, Symbol};
74
75use indexmap::IndexMap;
76
77pub struct BlockToFunctionRewriter<'a> {
78    state: &'a mut CompilerState,
79    current_program: Symbol,
80}
81
82impl<'a> BlockToFunctionRewriter<'a> {
83    pub fn new(state: &'a mut CompilerState, current_program: Symbol) -> Self {
84        Self { state, current_program }
85    }
86}
87
88impl BlockToFunctionRewriter<'_> {
89    pub fn rewrite_block(
90        &mut self,
91        input: &Block,
92        function_name: Symbol,
93        function_variant: Variant,
94    ) -> (Function, Expression) {
95        // Collect all symbol accesses in the block.
96        let mut access_collector = SymbolAccessCollector::new(self.state);
97        access_collector.visit_block(input);
98
99        // Stores mapping from accessed symbol (and optional index) to the expression used in replacement.
100        let mut replacements: IndexMap<(Symbol, Option<usize>), Expression> = IndexMap::new();
101
102        // Helper to create a fresh `Identifier`.
103        let make_identifier = |slf: &mut Self, symbol: Symbol| Identifier {
104            name: symbol,
105            span: Span::default(),
106            id: slf.state.node_builder.next_id(),
107        };
108
109        // Generates a set of `Input`s and corresponding call-site `Expression`s for a given symbol access.
110        //
111        // This function handles both:
112        // - Direct variable accesses (e.g., `foo`)
113        // - Tuple element accesses (e.g., `foo.0`)
114        //
115        // For tuple accesses:
116        // - If a single element (e.g. `foo.0`) is accessed, it generates a synthetic input like `"foo.0"`.
117        // - If the whole tuple (e.g. `foo`) is accessed, it ensures all elements are covered by:
118        //     - Reusing existing inputs from `replacements` if already generated via prior field access.
119        //     - Creating new inputs and arguments for any missing elements.
120        // - The entire tuple is reconstructed in `replacements` using the individual elements as a `TupleExpression`.
121        //
122        // This function also ensures deduplication by consulting the `replacements` map:
123        // - If a given `(symbol, index)` has already been processed, no duplicate input or argument is generated.
124        // - This prevents repeated parameters for accesses like both `foo` and `foo.0`.
125        //
126        // # Parameters
127        // - `symbol`: The symbol being accessed.
128        // - `var_type`: The type of the symbol (may be a tuple or base type).
129        // - `index_opt`: `Some(index)` for a tuple field (e.g., `.0`), or `None` for full-variable access.
130        //
131        // # Returns
132        // A `Vec<(Input, Expression)>`, where:
133        // - `Input` is a parameter for the generated function.
134        // - `Expression` is the call-site argument expression used to invoke that parameter.
135        let mut make_inputs_and_arguments =
136            |slf: &mut Self, symbol: Symbol, var_type: &Type, index_opt: Option<usize>| -> Vec<(Input, Expression)> {
137                if replacements.contains_key(&(symbol, index_opt)) {
138                    return vec![]; // No new input needed; argument already exists
139                }
140
141                match index_opt {
142                    Some(index) => {
143                        let Type::Tuple(TupleType { elements }) = var_type else {
144                            // The type checker has already emitted an error for this invalid access;
145                            // return no inputs so compilation can continue to report all diagnostics.
146                            return vec![];
147                        };
148
149                        // The type checker has already emitted an error for this out-of-bounds access;
150                        // return no inputs so compilation can continue to report all diagnostics.
151                        if index >= elements.len() {
152                            return vec![];
153                        }
154
155                        let synthetic_name = format!("\"{symbol}.{index}\"");
156                        let synthetic_symbol = Symbol::intern(&synthetic_name);
157                        let identifier = make_identifier(slf, synthetic_symbol);
158
159                        let input = Input {
160                            identifier,
161                            mode: leo_ast::Mode::None,
162                            type_: elements[index].clone(),
163                            span: Span::default(),
164                            id: slf.state.node_builder.next_id(),
165                        };
166
167                        replacements.insert((symbol, Some(index)), Path::from(identifier).to_local().into());
168
169                        vec![(
170                            input,
171                            TupleAccess {
172                                tuple: Path::from(make_identifier(slf, symbol)).to_local().into(),
173                                index: index.into(),
174                                span: Span::default(),
175                                id: slf.state.node_builder.next_id(),
176                            }
177                            .into(),
178                        )]
179                    }
180
181                    None => match var_type {
182                        Type::Tuple(TupleType { elements }) => {
183                            let mut inputs_and_arguments = Vec::with_capacity(elements.len());
184                            let mut tuple_elements = Vec::with_capacity(elements.len());
185
186                            for (i, element_type) in elements.iter().enumerate() {
187                                let key = (symbol, Some(i));
188
189                                // Skip if this field is already handled
190                                if let Some(existing_expr) = replacements.get(&key) {
191                                    tuple_elements.push(existing_expr.clone());
192                                    continue;
193                                }
194
195                                // Otherwise, synthesize identifier and input
196                                let synthetic_name = format!("\"{symbol}.{i}\"");
197                                let synthetic_symbol = Symbol::intern(&synthetic_name);
198                                let identifier = make_identifier(slf, synthetic_symbol);
199
200                                let input = Input {
201                                    identifier,
202                                    mode: leo_ast::Mode::None,
203                                    type_: element_type.clone(),
204                                    span: Span::default(),
205                                    id: slf.state.node_builder.next_id(),
206                                };
207
208                                let expr: Expression = Path::from(identifier).to_local().into();
209
210                                replacements.insert(key, expr.clone());
211                                tuple_elements.push(expr.clone());
212                                inputs_and_arguments.push((
213                                    input,
214                                    TupleAccess {
215                                        tuple: Path::from(make_identifier(slf, symbol)).to_local().into(),
216                                        index: i.into(),
217                                        span: Span::default(),
218                                        id: slf.state.node_builder.next_id(),
219                                    }
220                                    .into(),
221                                ));
222                            }
223
224                            // Now insert the full tuple (even if all fields were already there).
225                            replacements.insert(
226                                (symbol, None),
227                                Expression::Tuple(TupleExpression {
228                                    elements: tuple_elements,
229                                    span: Span::default(),
230                                    id: slf.state.node_builder.next_id(),
231                                }),
232                            );
233
234                            inputs_and_arguments
235                        }
236
237                        _ => {
238                            let identifier = make_identifier(slf, symbol);
239                            let input = Input {
240                                identifier,
241                                mode: leo_ast::Mode::None,
242                                type_: var_type.clone(),
243                                span: Span::default(),
244                                id: slf.state.node_builder.next_id(),
245                            };
246
247                            replacements.insert((symbol, None), Path::from(identifier).to_local().into());
248
249                            let argument = Path::from(make_identifier(slf, symbol)).to_local().into();
250                            vec![(input, argument)]
251                        }
252                    },
253                }
254            };
255
256        // Resolve symbol accesses into inputs and call arguments.
257        let (inputs, arguments): (Vec<_>, Vec<_>) = access_collector
258            .symbol_accesses
259            .iter()
260            .filter_map(|(path, index)| {
261                // Skip globals and variables that are local to this block or to one of its children.
262
263                // Skip globals.
264                if path.is_global() {
265                    return None;
266                }
267
268                // Skip variables that are local to this block or to one of its children.
269                let local_var_name = path.expect_local_symbol(); // Not global, so must be local 
270                if self.state.symbol_table.is_local_to_or_in_child_scope(input.id(), local_var_name) {
271                    return None;
272                }
273
274                // All other variables become parameters to the function being built.
275                let var = self.state.symbol_table.lookup_local(local_var_name)?;
276                Some(make_inputs_and_arguments(self, local_var_name, &var.type_.expect("must be known by now"), *index))
277            })
278            .flatten()
279            .unzip();
280
281        // Replacement logic used to patch the block.
282        let replace_expr = |expr: &Expression| -> Expression {
283            match expr {
284                Expression::Path(path) => {
285                    replacements.get(&(path.identifier().name, None)).cloned().unwrap_or_else(|| expr.clone())
286                }
287
288                Expression::TupleAccess(ta) => {
289                    if let Expression::Path(path) = &ta.tuple {
290                        replacements
291                            .get(&(path.identifier().name, Some(ta.index.value())))
292                            .cloned()
293                            .unwrap_or_else(|| expr.clone())
294                    } else {
295                        expr.clone()
296                    }
297                }
298
299                _ => expr.clone(),
300            }
301        };
302
303        // Reconstruct the block with replaced references.
304        let mut replacer = Replacer::new(replace_expr, true /* refresh IDs */, self.state);
305        let new_block = replacer.reconstruct_block(input.clone()).0;
306
307        // Define the new function.
308        let function = Function {
309            annotations: vec![],
310            variant: function_variant,
311            identifier: make_identifier(self, function_name),
312            const_parameters: vec![],
313            input: inputs,
314            output: vec![],          // No returns supported yet.
315            output_type: Type::Unit, // No returns supported yet.
316            block: new_block,
317            span: input.span,
318            id: self.state.node_builder.next_id(),
319        };
320
321        // Create the call expression to invoke the function.
322        let call_to_function = CallExpression {
323            function: Path::from(make_identifier(self, function_name))
324                .to_global(Location::new(self.current_program, vec![function_name])),
325            const_arguments: vec![],
326            arguments,
327            span: input.span,
328            id: self.state.node_builder.next_id(),
329        };
330
331        (function, call_to_function.into())
332    }
333}