leo_passes/common/block_to_function_rewriter.rs
1// Copyright (C) 2019-2026 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
//! Transforms a captured `Block` into a standalone `Function` (of a caller-chosen
//! `Variant`, e.g. an async function) plus a corresponding call expression.
19//!
20//! This pass analyzes symbol accesses inside the block, determines which
21//! variables must become parameters, and synthesizes the necessary `Input`s
22//! and call-site arguments. Tuple and tuple-field accesses are normalized so
23//! that each accessed element becomes a unique parameter, with full-tuple
24//! reconstruction when needed.
25//!
26//! The original block is then reconstructed with all symbol references
27//! replaced by these synthesized parameters. The result is a function
28//! that encapsulates the block's logic and a call expression that invokes it.
29//!
30//! # Example
31//! ```leo
32//! // Original block
33//! let a: i32 = 1;
34//! let b: i32 = 2;
35//! let c: (i32, i32) = (3, 4);
36//! {
37//! let x = a + b;
38//! let y = c.0 + c.1;
39//! ..
40//! }
41//!
42//! // Rewritten as a function + call expression (assume `variant` is `AsyncFunction` here)
43//! async function generated_async(a: i32, b: i32, "c.0": i32, "c.1": i32) {
44//! let x = a + b;
45//! let y = "c.0" + "c.1";
46//! ..
47//! }
48//!
49//! // Call
50//! generated_async(a, b, c.0, c.1);
51//! ```
52
53use crate::{CompilerState, Replacer, SymbolAccessCollector};
54
55use leo_ast::{
56 AstReconstructor,
57 AstVisitor,
58 Block,
59 CallExpression,
60 Expression,
61 Function,
62 Identifier,
63 Input,
64 Location,
65 Node,
66 Path,
67 TupleAccess,
68 TupleExpression,
69 TupleType,
70 Type,
71 Variant,
72};
73use leo_span::{Span, Symbol};
74
75use indexmap::IndexMap;
76
77pub struct BlockToFunctionRewriter<'a> {
78 state: &'a mut CompilerState,
79 current_program: Symbol,
80}
81
82impl<'a> BlockToFunctionRewriter<'a> {
83 pub fn new(state: &'a mut CompilerState, current_program: Symbol) -> Self {
84 Self { state, current_program }
85 }
86}
87
88impl BlockToFunctionRewriter<'_> {
89 pub fn rewrite_block(
90 &mut self,
91 input: &Block,
92 function_name: Symbol,
93 function_variant: Variant,
94 ) -> (Function, Expression) {
95 // Collect all symbol accesses in the block.
96 let mut access_collector = SymbolAccessCollector::new(self.state);
97 access_collector.visit_block(input);
98
99 // Stores mapping from accessed symbol (and optional index) to the expression used in replacement.
100 let mut replacements: IndexMap<(Symbol, Option<usize>), Expression> = IndexMap::new();
101
102 // Helper to create a fresh `Identifier`.
103 let make_identifier = |slf: &mut Self, symbol: Symbol| Identifier {
104 name: symbol,
105 span: Span::default(),
106 id: slf.state.node_builder.next_id(),
107 };
108
109 // Generates a set of `Input`s and corresponding call-site `Expression`s for a given symbol access.
110 //
111 // This function handles both:
112 // - Direct variable accesses (e.g., `foo`)
113 // - Tuple element accesses (e.g., `foo.0`)
114 //
115 // For tuple accesses:
116 // - If a single element (e.g. `foo.0`) is accessed, it generates a synthetic input like `"foo.0"`.
117 // - If the whole tuple (e.g. `foo`) is accessed, it ensures all elements are covered by:
118 // - Reusing existing inputs from `replacements` if already generated via prior field access.
119 // - Creating new inputs and arguments for any missing elements.
120 // - The entire tuple is reconstructed in `replacements` using the individual elements as a `TupleExpression`.
121 //
122 // This function also ensures deduplication by consulting the `replacements` map:
123 // - If a given `(symbol, index)` has already been processed, no duplicate input or argument is generated.
124 // - This prevents repeated parameters for accesses like both `foo` and `foo.0`.
125 //
126 // # Parameters
127 // - `symbol`: The symbol being accessed.
128 // - `var_type`: The type of the symbol (may be a tuple or base type).
129 // - `index_opt`: `Some(index)` for a tuple field (e.g., `.0`), or `None` for full-variable access.
130 //
131 // # Returns
132 // A `Vec<(Input, Expression)>`, where:
133 // - `Input` is a parameter for the generated function.
134 // - `Expression` is the call-site argument expression used to invoke that parameter.
135 let mut make_inputs_and_arguments =
136 |slf: &mut Self, symbol: Symbol, var_type: &Type, index_opt: Option<usize>| -> Vec<(Input, Expression)> {
137 if replacements.contains_key(&(symbol, index_opt)) {
138 return vec![]; // No new input needed; argument already exists
139 }
140
141 match index_opt {
142 Some(index) => {
143 let Type::Tuple(TupleType { elements }) = var_type else {
144 // The type checker has already emitted an error for this invalid access;
145 // return no inputs so compilation can continue to report all diagnostics.
146 return vec![];
147 };
148
149 // The type checker has already emitted an error for this out-of-bounds access;
150 // return no inputs so compilation can continue to report all diagnostics.
151 if index >= elements.len() {
152 return vec![];
153 }
154
155 let synthetic_name = format!("\"{symbol}.{index}\"");
156 let synthetic_symbol = Symbol::intern(&synthetic_name);
157 let identifier = make_identifier(slf, synthetic_symbol);
158
159 let input = Input {
160 identifier,
161 mode: leo_ast::Mode::None,
162 type_: elements[index].clone(),
163 span: Span::default(),
164 id: slf.state.node_builder.next_id(),
165 };
166
167 replacements.insert((symbol, Some(index)), Path::from(identifier).to_local().into());
168
169 vec![(
170 input,
171 TupleAccess {
172 tuple: Path::from(make_identifier(slf, symbol)).to_local().into(),
173 index: index.into(),
174 span: Span::default(),
175 id: slf.state.node_builder.next_id(),
176 }
177 .into(),
178 )]
179 }
180
181 None => match var_type {
182 Type::Tuple(TupleType { elements }) => {
183 let mut inputs_and_arguments = Vec::with_capacity(elements.len());
184 let mut tuple_elements = Vec::with_capacity(elements.len());
185
186 for (i, element_type) in elements.iter().enumerate() {
187 let key = (symbol, Some(i));
188
189 // Skip if this field is already handled
190 if let Some(existing_expr) = replacements.get(&key) {
191 tuple_elements.push(existing_expr.clone());
192 continue;
193 }
194
195 // Otherwise, synthesize identifier and input
196 let synthetic_name = format!("\"{symbol}.{i}\"");
197 let synthetic_symbol = Symbol::intern(&synthetic_name);
198 let identifier = make_identifier(slf, synthetic_symbol);
199
200 let input = Input {
201 identifier,
202 mode: leo_ast::Mode::None,
203 type_: element_type.clone(),
204 span: Span::default(),
205 id: slf.state.node_builder.next_id(),
206 };
207
208 let expr: Expression = Path::from(identifier).to_local().into();
209
210 replacements.insert(key, expr.clone());
211 tuple_elements.push(expr.clone());
212 inputs_and_arguments.push((
213 input,
214 TupleAccess {
215 tuple: Path::from(make_identifier(slf, symbol)).to_local().into(),
216 index: i.into(),
217 span: Span::default(),
218 id: slf.state.node_builder.next_id(),
219 }
220 .into(),
221 ));
222 }
223
224 // Now insert the full tuple (even if all fields were already there).
225 replacements.insert(
226 (symbol, None),
227 Expression::Tuple(TupleExpression {
228 elements: tuple_elements,
229 span: Span::default(),
230 id: slf.state.node_builder.next_id(),
231 }),
232 );
233
234 inputs_and_arguments
235 }
236
237 _ => {
238 let identifier = make_identifier(slf, symbol);
239 let input = Input {
240 identifier,
241 mode: leo_ast::Mode::None,
242 type_: var_type.clone(),
243 span: Span::default(),
244 id: slf.state.node_builder.next_id(),
245 };
246
247 replacements.insert((symbol, None), Path::from(identifier).to_local().into());
248
249 let argument = Path::from(make_identifier(slf, symbol)).to_local().into();
250 vec![(input, argument)]
251 }
252 },
253 }
254 };
255
256 // Resolve symbol accesses into inputs and call arguments.
257 let (inputs, arguments): (Vec<_>, Vec<_>) = access_collector
258 .symbol_accesses
259 .iter()
260 .filter_map(|(path, index)| {
261 // Skip globals and variables that are local to this block or to one of its children.
262
263 // Skip globals.
264 if path.is_global() {
265 return None;
266 }
267
268 // Skip variables that are local to this block or to one of its children.
269 let local_var_name = path.expect_local_symbol(); // Not global, so must be local
270 if self.state.symbol_table.is_local_to_or_in_child_scope(input.id(), local_var_name) {
271 return None;
272 }
273
274 // All other variables become parameters to the function being built.
275 let var = self.state.symbol_table.lookup_local(local_var_name)?;
276 Some(make_inputs_and_arguments(self, local_var_name, &var.type_.expect("must be known by now"), *index))
277 })
278 .flatten()
279 .unzip();
280
281 // Replacement logic used to patch the block.
282 let replace_expr = |expr: &Expression| -> Expression {
283 match expr {
284 Expression::Path(path) => {
285 replacements.get(&(path.identifier().name, None)).cloned().unwrap_or_else(|| expr.clone())
286 }
287
288 Expression::TupleAccess(ta) => {
289 if let Expression::Path(path) = &ta.tuple {
290 replacements
291 .get(&(path.identifier().name, Some(ta.index.value())))
292 .cloned()
293 .unwrap_or_else(|| expr.clone())
294 } else {
295 expr.clone()
296 }
297 }
298
299 _ => expr.clone(),
300 }
301 };
302
303 // Reconstruct the block with replaced references.
304 let mut replacer = Replacer::new(replace_expr, true /* refresh IDs */, self.state);
305 let new_block = replacer.reconstruct_block(input.clone()).0;
306
307 // Define the new function.
308 let function = Function {
309 annotations: vec![],
310 variant: function_variant,
311 identifier: make_identifier(self, function_name),
312 const_parameters: vec![],
313 input: inputs,
314 output: vec![], // No returns supported yet.
315 output_type: Type::Unit, // No returns supported yet.
316 block: new_block,
317 span: input.span,
318 id: self.state.node_builder.next_id(),
319 };
320
321 // Create the call expression to invoke the function.
322 let call_to_function = CallExpression {
323 function: Path::from(make_identifier(self, function_name))
324 .to_global(Location::new(self.current_program, vec![function_name])),
325 const_arguments: vec![],
326 arguments,
327 span: input.span,
328 id: self.state.node_builder.next_id(),
329 };
330
331 (function, call_to_function.into())
332 }
333}