Skip to main content

leo_ast/stub/
function_stub.rs

1// Copyright (C) 2019-2026 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use crate::{
18    Annotation,
19    CompositeType,
20    Function,
21    FutureType,
22    Identifier,
23    Input,
24    Location,
25    Mode,
26    Node,
27    NodeID,
28    Output,
29    Path,
30    ProgramId,
31    TupleType,
32    Type,
33    Variant,
34};
35use leo_span::{Span, Symbol, sym};
36
37use itertools::Itertools;
38use serde::{Deserialize, Serialize};
39use snarkvm::{
40    console::program::{
41        FinalizeType::{Future as FutureFinalizeType, Plaintext as PlaintextFinalizeType},
42        RegisterType::{ExternalRecord, Future, Plaintext, Record},
43    },
44    prelude::{Network, ValueType},
45    synthesizer::program::{ClosureCore, FunctionCore},
46};
47use std::fmt;
48
49/// A function stub definition.
50#[derive(Clone, Serialize, Deserialize)]
51pub struct FunctionStub {
52    /// Annotations on the function.
53    pub annotations: Vec<Annotation>,
54    /// Is this function a transition, inlined, or a regular function?.
55    pub variant: Variant,
56    /// The function identifier, e.g., `foo` in `function foo(...) { ... }`.
57    pub identifier: Identifier,
58    /// The function's input parameters.
59    pub input: Vec<Input>,
60    /// The function's output declarations.
61    pub output: Vec<Output>,
62    /// The function's output type.
63    pub output_type: Type,
64    /// The entire span of the function definition.
65    pub span: Span,
66    /// The ID of the node.
67    pub id: NodeID,
68}
69
70impl PartialEq for FunctionStub {
71    fn eq(&self, other: &Self) -> bool {
72        self.identifier == other.identifier
73    }
74}
75
76impl Eq for FunctionStub {}
77
78impl FunctionStub {
79    /// Initialize a new function.
80    #[allow(clippy::too_many_arguments)]
81    pub fn new(
82        annotations: Vec<Annotation>,
83        _is_async: bool,
84        variant: Variant,
85        identifier: Identifier,
86        input: Vec<Input>,
87        output: Vec<Output>,
88        span: Span,
89        id: NodeID,
90    ) -> Self {
91        let output_type = match output.len() {
92            0 => Type::Unit,
93            1 => output[0].type_.clone(),
94            _ => Type::Tuple(TupleType::new(output.iter().map(|o| o.type_.clone()).collect())),
95        };
96
97        FunctionStub { annotations, variant, identifier, input, output, output_type, span, id }
98    }
99
100    /// Returns function name.
101    pub fn name(&self) -> Symbol {
102        self.identifier.name
103    }
104
105    /// Returns `true` if the function name is `main`.
106    pub fn is_main(&self) -> bool {
107        self.name() == sym::main
108    }
109
110    /// Private formatting method used for optimizing [fmt::Debug] and [fmt::Display] implementations.
111    fn format(&self, f: &mut fmt::Formatter) -> fmt::Result {
112        match self.variant {
113            Variant::Inline => write!(f, "inline ")?,
114            Variant::Script => write!(f, "script ")?,
115            Variant::Function | Variant::AsyncFunction => write!(f, "function ")?,
116            Variant::Transition | Variant::AsyncTransition => write!(f, "transition ")?,
117        }
118        write!(f, "{}", self.identifier)?;
119
120        let parameters = self.input.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(",");
121        let returns = match self.output.len() {
122            0 => "()".to_string(),
123            1 => self.output[0].to_string(),
124            _ => self.output.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(","),
125        };
126        write!(f, "({parameters}) -> {returns}")?;
127
128        Ok(())
129    }
130
131    /// Converts from snarkvm function type to leo FunctionStub, while also carrying the parent program name.
132    pub fn from_function_core<N: Network>(function: &FunctionCore<N>, program: Symbol) -> Self {
133        let outputs = function
134            .outputs()
135            .iter()
136            .map(|output| match output.value_type() {
137                ValueType::Constant(val) => vec![Output {
138                    mode: Mode::Constant,
139                    type_: Type::from_snarkvm(val, program),
140                    span: Default::default(),
141                    id: Default::default(),
142                }],
143                ValueType::Public(val) => vec![Output {
144                    mode: Mode::Public,
145                    type_: Type::from_snarkvm(val, program),
146                    span: Default::default(),
147                    id: Default::default(),
148                }],
149                ValueType::Private(val) => vec![Output {
150                    mode: Mode::Private,
151                    type_: Type::from_snarkvm(val, program),
152                    span: Default::default(),
153                    id: Default::default(),
154                }],
155                ValueType::Record(id) => vec![Output {
156                    mode: Mode::None,
157                    type_: Type::Composite(CompositeType {
158                        path: {
159                            let ident = Identifier::from(id);
160                            Path::from(ident)
161                                .to_global(Location::new(program, vec![ident.name]))
162                                .with_user_program(Identifier::new(program, Default::default()))
163                        },
164                        const_arguments: Vec::new(),
165                    }),
166                    span: Default::default(),
167                    id: Default::default(),
168                }],
169                ValueType::ExternalRecord(loc) => {
170                    let external_program = ProgramId::from(loc.program_id()).name.name;
171                    vec![Output {
172                        mode: Mode::None,
173                        span: Default::default(),
174                        id: Default::default(),
175                        type_: Type::Composite(CompositeType {
176                            path: {
177                                let ident = Identifier::from(loc.resource());
178                                Path::from(ident)
179                                    .to_global(Location::new(external_program, vec![ident.name]))
180                                    .with_user_program(Identifier::new(external_program, Default::default()))
181                            },
182                            const_arguments: Vec::new(),
183                        }),
184                    }]
185                }
186                ValueType::Future(_) => vec![Output {
187                    mode: Mode::None,
188                    span: Default::default(),
189                    id: Default::default(),
190                    type_: Type::Future(FutureType::new(
191                        Vec::new(),
192                        Some(Location::new(program, vec![Symbol::intern(&function.name().to_string())])),
193                        false,
194                    )),
195                }],
196            })
197            .collect_vec()
198            .concat();
199        let output_vec = outputs.iter().map(|output| output.type_.clone()).collect_vec();
200        let output_type = match output_vec.len() {
201            0 => Type::Unit,
202            1 => output_vec[0].clone(),
203            _ => Type::Tuple(TupleType::new(output_vec)),
204        };
205
206        Self {
207            annotations: Vec::new(),
208            variant: match function.finalize_logic().is_some() {
209                true => Variant::AsyncTransition,
210                false => Variant::Transition,
211            },
212            identifier: Identifier::from(function.name()),
213            input: function
214                .inputs()
215                .iter()
216                .enumerate()
217                .map(|(index, input)| {
218                    let arg_name = Identifier::new(Symbol::intern(&format!("arg{}", index + 1)), Default::default());
219                    match input.value_type() {
220                        ValueType::Constant(val) => Input {
221                            identifier: arg_name,
222                            mode: Mode::Constant,
223                            type_: Type::from_snarkvm(val, program),
224                            span: Default::default(),
225                            id: Default::default(),
226                        },
227                        ValueType::Public(val) => Input {
228                            identifier: arg_name,
229                            mode: Mode::Public,
230                            type_: Type::from_snarkvm(val, program),
231                            span: Default::default(),
232                            id: Default::default(),
233                        },
234                        ValueType::Private(val) => Input {
235                            identifier: arg_name,
236                            mode: Mode::Private,
237                            type_: Type::from_snarkvm(val, program),
238                            span: Default::default(),
239                            id: Default::default(),
240                        },
241                        ValueType::Record(id) => Input {
242                            identifier: arg_name,
243                            mode: Mode::None,
244                            type_: Type::Composite(CompositeType {
245                                path: {
246                                    let ident = Identifier::from(id);
247                                    Path::from(ident)
248                                        .to_global(Location::new(program, vec![ident.name]))
249                                        .with_user_program(Identifier::new(program, Default::default()))
250                                },
251                                const_arguments: Vec::new(),
252                            }),
253                            span: Default::default(),
254                            id: Default::default(),
255                        },
256                        ValueType::ExternalRecord(loc) => {
257                            let external_program = ProgramId::from(loc.program_id()).name.name;
258                            Input {
259                                identifier: arg_name,
260                                mode: Mode::None,
261                                span: Default::default(),
262                                id: Default::default(),
263                                type_: Type::Composite(CompositeType {
264                                    path: {
265                                        let ident = Identifier::from(loc.resource());
266                                        Path::from(ident)
267                                            .to_global(Location::new(external_program, vec![ident.name]))
268                                            .with_user_program(Identifier::new(external_program, Default::default()))
269                                    },
270                                    const_arguments: Vec::new(),
271                                }),
272                            }
273                        }
274                        ValueType::Future(_) => panic!("Functions do not contain futures as inputs"),
275                    }
276                })
277                .collect_vec(),
278            output: outputs,
279            output_type,
280            span: Default::default(),
281            id: Default::default(),
282        }
283    }
284
285    pub fn from_finalize<N: Network>(function: &FunctionCore<N>, key_name: Symbol, program: Symbol) -> Self {
286        Self {
287            annotations: Vec::new(),
288            variant: Variant::AsyncFunction,
289            identifier: Identifier::new(key_name, Default::default()),
290            input: function
291                .finalize_logic()
292                .unwrap()
293                .inputs()
294                .iter()
295                .enumerate()
296                .map(|(index, input)| Input {
297                    identifier: Identifier::new(Symbol::intern(&format!("arg{}", index + 1)), Default::default()),
298                    mode: Mode::None,
299                    type_: match input.finalize_type() {
300                        PlaintextFinalizeType(val) => Type::from_snarkvm(val, program),
301                        FutureFinalizeType(val) => Type::Future(FutureType::new(
302                            Vec::new(),
303                            Some(Location::new(Identifier::from(val.program_id().name()).name, vec![Symbol::intern(
304                                &format!("finalize/{}", val.resource()),
305                            )])),
306                            false,
307                        )),
308                    },
309                    span: Default::default(),
310                    id: Default::default(),
311                })
312                .collect_vec(),
313            output: Vec::new(),
314            output_type: Type::Unit,
315            span: Default::default(),
316            id: 0,
317        }
318    }
319
320    pub fn from_closure<N: Network>(closure: &ClosureCore<N>, program: Symbol) -> Self {
321        let outputs = closure
322            .outputs()
323            .iter()
324            .map(|output| match output.register_type() {
325                Plaintext(val) => Output {
326                    mode: Mode::None,
327                    type_: Type::from_snarkvm(val, program),
328                    span: Default::default(),
329                    id: Default::default(),
330                },
331                Record(_) => panic!("Closures do not return records"),
332                ExternalRecord(_) => panic!("Closures do not return external records"),
333                Future(_) => panic!("Closures do not return futures"),
334            })
335            .collect_vec();
336        let output_vec = outputs.iter().map(|output| output.type_.clone()).collect_vec();
337        let output_type = match output_vec.len() {
338            0 => Type::Unit,
339            1 => output_vec[0].clone(),
340            _ => Type::Tuple(TupleType::new(output_vec)),
341        };
342        Self {
343            annotations: Vec::new(),
344            variant: Variant::Function,
345            identifier: Identifier::from(closure.name()),
346            input: closure
347                .inputs()
348                .iter()
349                .enumerate()
350                .map(|(index, input)| {
351                    let arg_name = Identifier::new(Symbol::intern(&format!("arg{}", index + 1)), Default::default());
352                    match input.register_type() {
353                        Plaintext(val) => Input {
354                            identifier: arg_name,
355                            mode: Mode::None,
356                            type_: Type::from_snarkvm(val, program),
357                            span: Default::default(),
358                            id: Default::default(),
359                        },
360                        Record(_) => panic!("Closures do not contain records as inputs"),
361                        ExternalRecord(_) => panic!("Closures do not contain external records as inputs"),
362                        Future(_) => panic!("Closures do not contain futures as inputs"),
363                    }
364                })
365                .collect_vec(),
366            output: outputs,
367            output_type,
368            span: Default::default(),
369            id: Default::default(),
370        }
371    }
372}
373
374impl From<Function> for FunctionStub {
375    fn from(function: Function) -> Self {
376        Self {
377            annotations: function.annotations,
378            variant: function.variant,
379            identifier: function.identifier,
380            input: function.input,
381            output: function.output,
382            output_type: function.output_type,
383            span: function.span,
384            id: function.id,
385        }
386    }
387}
388
389impl fmt::Debug for FunctionStub {
390    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
391        self.format(f)
392    }
393}
394
395impl fmt::Display for FunctionStub {
396    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
397        self.format(f)
398    }
399}
400
401crate::simple_node_impl!(FunctionStub);