leo_ast/stub/function_stub.rs

1// Copyright (C) 2019-2025 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use crate::{
18    Annotation,
19    CompositeType,
20    Function,
21    FutureType,
22    Identifier,
23    Input,
24    Location,
25    Mode,
26    Node,
27    NodeID,
28    Output,
29    ProgramId,
30    TupleType,
31    Type,
32    Variant,
33};
34use leo_span::{Span, Symbol, sym};
35
36use itertools::Itertools;
37use serde::{Deserialize, Serialize};
38use snarkvm::{
39    console::program::{
40        FinalizeType::{Future as FutureFinalizeType, Plaintext as PlaintextFinalizeType},
41        RegisterType::{ExternalRecord, Future, Plaintext, Record},
42    },
43    prelude::{Network, ValueType},
44    synthesizer::program::{ClosureCore, CommandTrait, FunctionCore, InstructionTrait},
45};
46use std::fmt;
47
48/// A function stub definition.
49#[derive(Clone, Serialize, Deserialize)]
50pub struct FunctionStub {
51    /// Annotations on the function.
52    pub annotations: Vec<Annotation>,
53    /// Is this function a transition, inlined, or a regular function?.
54    pub variant: Variant,
55    /// The function identifier, e.g., `foo` in `function foo(...) { ... }`.
56    pub identifier: Identifier,
57    /// The function's input parameters.
58    pub input: Vec<Input>,
59    /// The function's output declarations.
60    pub output: Vec<Output>,
61    /// The function's output type.
62    pub output_type: Type,
63    /// The entire span of the function definition.
64    pub span: Span,
65    /// The ID of the node.
66    pub id: NodeID,
67}
68
69impl PartialEq for FunctionStub {
70    fn eq(&self, other: &Self) -> bool {
71        self.identifier == other.identifier
72    }
73}
74
75impl Eq for FunctionStub {}
76
77impl FunctionStub {
78    /// Initialize a new function.
79    #[allow(clippy::too_many_arguments)]
80    pub fn new(
81        annotations: Vec<Annotation>,
82        _is_async: bool,
83        variant: Variant,
84        identifier: Identifier,
85        input: Vec<Input>,
86        output: Vec<Output>,
87        span: Span,
88        id: NodeID,
89    ) -> Self {
90        let output_type = match output.len() {
91            0 => Type::Unit,
92            1 => output[0].type_.clone(),
93            _ => Type::Tuple(TupleType::new(output.iter().map(|o| o.type_.clone()).collect())),
94        };
95
96        FunctionStub { annotations, variant, identifier, input, output, output_type, span, id }
97    }
98
99    /// Returns function name.
100    pub fn name(&self) -> Symbol {
101        self.identifier.name
102    }
103
104    /// Returns `true` if the function name is `main`.
105    pub fn is_main(&self) -> bool {
106        self.name() == sym::main
107    }
108
109    /// Private formatting method used for optimizing [fmt::Debug] and [fmt::Display] implementations.
110    fn format(&self, f: &mut fmt::Formatter) -> fmt::Result {
111        match self.variant {
112            Variant::Inline => write!(f, "inline ")?,
113            Variant::Script => write!(f, "script ")?,
114            Variant::Function | Variant::AsyncFunction => write!(f, "function ")?,
115            Variant::Transition | Variant::AsyncTransition => write!(f, "transition ")?,
116        }
117        write!(f, "{}", self.identifier)?;
118
119        let parameters = self.input.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(",");
120        let returns = match self.output.len() {
121            0 => "()".to_string(),
122            1 => self.output[0].to_string(),
123            _ => self.output.iter().map(|x| x.to_string()).collect::<Vec<_>>().join(","),
124        };
125        write!(f, "({parameters}) -> {returns}")?;
126
127        Ok(())
128    }
129
130    /// Converts from snarkvm function type to leo FunctionStub, while also carrying the parent program name.
131    pub fn from_function_core<N: Network, Instruction: InstructionTrait<N>, Command: CommandTrait<N>>(
132        function: &FunctionCore<N, Instruction, Command>,
133        program: Symbol,
134    ) -> Self {
135        let outputs = function
136            .outputs()
137            .iter()
138            .map(|output| match output.value_type() {
139                ValueType::Constant(val) => vec![Output {
140                    mode: Mode::Constant,
141                    type_: Type::from_snarkvm(val, None),
142                    span: Default::default(),
143                    id: Default::default(),
144                }],
145                ValueType::Public(val) => vec![Output {
146                    mode: Mode::Public,
147                    type_: Type::from_snarkvm(val, None),
148                    span: Default::default(),
149                    id: Default::default(),
150                }],
151                ValueType::Private(val) => vec![Output {
152                    mode: Mode::Private,
153                    type_: Type::from_snarkvm(val, None),
154                    span: Default::default(),
155                    id: Default::default(),
156                }],
157                ValueType::Record(id) => vec![Output {
158                    mode: Mode::None,
159                    type_: Type::Composite(CompositeType {
160                        id: Identifier::from(id),
161                        const_arguments: Vec::new(),
162                        program: Some(program),
163                    }),
164                    span: Default::default(),
165                    id: Default::default(),
166                }],
167                ValueType::ExternalRecord(loc) => {
168                    vec![Output {
169                        mode: Mode::None,
170                        span: Default::default(),
171                        id: Default::default(),
172                        type_: Type::Composite(CompositeType {
173                            id: Identifier::from(loc.resource()),
174                            const_arguments: Vec::new(),
175                            program: Some(ProgramId::from(loc.program_id()).name.name),
176                        }),
177                    }]
178                }
179                ValueType::Future(_) => vec![Output {
180                    mode: Mode::None,
181                    span: Default::default(),
182                    id: Default::default(),
183                    type_: Type::Future(FutureType::new(
184                        Vec::new(),
185                        Some(Location::new(program, Identifier::from(function.name()).name)),
186                        false,
187                    )),
188                }],
189            })
190            .collect_vec()
191            .concat();
192        let output_vec = outputs.iter().map(|output| output.type_.clone()).collect_vec();
193        let output_type = match output_vec.len() {
194            0 => Type::Unit,
195            1 => output_vec[0].clone(),
196            _ => Type::Tuple(TupleType::new(output_vec)),
197        };
198
199        Self {
200            annotations: Vec::new(),
201            variant: match function.finalize_logic().is_some() {
202                true => Variant::AsyncTransition,
203                false => Variant::Transition,
204            },
205            identifier: Identifier::from(function.name()),
206            input: function
207                .inputs()
208                .iter()
209                .enumerate()
210                .map(|(index, input)| {
211                    let arg_name = Identifier::new(Symbol::intern(&format!("arg{}", index + 1)), Default::default());
212                    match input.value_type() {
213                        ValueType::Constant(val) => Input {
214                            identifier: arg_name,
215                            mode: Mode::Constant,
216                            type_: Type::from_snarkvm(val, None),
217                            span: Default::default(),
218                            id: Default::default(),
219                        },
220                        ValueType::Public(val) => Input {
221                            identifier: arg_name,
222                            mode: Mode::Public,
223                            type_: Type::from_snarkvm(val, None),
224                            span: Default::default(),
225                            id: Default::default(),
226                        },
227                        ValueType::Private(val) => Input {
228                            identifier: arg_name,
229                            mode: Mode::Private,
230                            type_: Type::from_snarkvm(val, None),
231                            span: Default::default(),
232                            id: Default::default(),
233                        },
234                        ValueType::Record(id) => Input {
235                            identifier: arg_name,
236                            mode: Mode::None,
237                            type_: Type::Composite(CompositeType {
238                                id: Identifier::from(id),
239                                const_arguments: Vec::new(),
240                                program: Some(program),
241                            }),
242                            span: Default::default(),
243                            id: Default::default(),
244                        },
245                        ValueType::ExternalRecord(loc) => Input {
246                            identifier: arg_name,
247                            mode: Mode::None,
248                            span: Default::default(),
249                            id: Default::default(),
250                            type_: Type::Composite(CompositeType {
251                                id: Identifier::from(loc.resource()),
252                                const_arguments: Vec::new(),
253                                program: Some(ProgramId::from(loc.program_id()).name.name),
254                            }),
255                        },
256                        ValueType::Future(_) => panic!("Functions do not contain futures as inputs"),
257                    }
258                })
259                .collect_vec(),
260            output: outputs,
261            output_type,
262            span: Default::default(),
263            id: Default::default(),
264        }
265    }
266
267    pub fn from_finalize<N: Network, Instruction: InstructionTrait<N>, Command: CommandTrait<N>>(
268        function: &FunctionCore<N, Instruction, Command>,
269        key_name: Symbol,
270        program: Symbol,
271    ) -> Self {
272        Self {
273            annotations: Vec::new(),
274            variant: Variant::AsyncFunction,
275            identifier: Identifier::new(key_name, Default::default()),
276            input: function
277                .finalize_logic()
278                .unwrap()
279                .inputs()
280                .iter()
281                .enumerate()
282                .map(|(index, input)| Input {
283                    identifier: Identifier::new(Symbol::intern(&format!("arg{}", index + 1)), Default::default()),
284                    mode: Mode::None,
285                    type_: match input.finalize_type() {
286                        PlaintextFinalizeType(val) => Type::from_snarkvm(val, Some(program)),
287                        FutureFinalizeType(val) => Type::Future(FutureType::new(
288                            Vec::new(),
289                            Some(Location::new(
290                                Identifier::from(val.program_id().name()).name,
291                                Symbol::intern(&format!("finalize/{}", val.resource())),
292                            )),
293                            false,
294                        )),
295                    },
296                    span: Default::default(),
297                    id: Default::default(),
298                })
299                .collect_vec(),
300            output: Vec::new(),
301            output_type: Type::Unit,
302            span: Default::default(),
303            id: 0,
304        }
305    }
306
307    pub fn from_closure<N: Network, Instruction: InstructionTrait<N>>(
308        closure: &ClosureCore<N, Instruction>,
309        program: Symbol,
310    ) -> Self {
311        let outputs = closure
312            .outputs()
313            .iter()
314            .map(|output| match output.register_type() {
315                Plaintext(val) => Output {
316                    mode: Mode::None,
317                    type_: Type::from_snarkvm(val, Some(program)),
318                    span: Default::default(),
319                    id: Default::default(),
320                },
321                Record(_) => panic!("Closures do not return records"),
322                ExternalRecord(_) => panic!("Closures do not return external records"),
323                Future(_) => panic!("Closures do not return futures"),
324            })
325            .collect_vec();
326        let output_vec = outputs.iter().map(|output| output.type_.clone()).collect_vec();
327        let output_type = match output_vec.len() {
328            0 => Type::Unit,
329            1 => output_vec[0].clone(),
330            _ => Type::Tuple(TupleType::new(output_vec)),
331        };
332        Self {
333            annotations: Vec::new(),
334            variant: Variant::Function,
335            identifier: Identifier::from(closure.name()),
336            input: closure
337                .inputs()
338                .iter()
339                .enumerate()
340                .map(|(index, input)| {
341                    let arg_name = Identifier::new(Symbol::intern(&format!("arg{}", index + 1)), Default::default());
342                    match input.register_type() {
343                        Plaintext(val) => Input {
344                            identifier: arg_name,
345                            mode: Mode::None,
346                            type_: Type::from_snarkvm(val, None),
347                            span: Default::default(),
348                            id: Default::default(),
349                        },
350                        Record(_) => panic!("Closures do not contain records as inputs"),
351                        ExternalRecord(_) => panic!("Closures do not contain external records as inputs"),
352                        Future(_) => panic!("Closures do not contain futures as inputs"),
353                    }
354                })
355                .collect_vec(),
356            output: outputs,
357            output_type,
358            span: Default::default(),
359            id: Default::default(),
360        }
361    }
362}
363
364impl From<Function> for FunctionStub {
365    fn from(function: Function) -> Self {
366        Self {
367            annotations: function.annotations,
368            variant: function.variant,
369            identifier: function.identifier,
370            input: function.input,
371            output: function.output,
372            output_type: function.output_type,
373            span: function.span,
374            id: function.id,
375        }
376    }
377}
378
379impl fmt::Debug for FunctionStub {
380    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
381        self.format(f)
382    }
383}
384
385impl fmt::Display for FunctionStub {
386    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
387        self.format(f)
388    }
389}
390
391crate::simple_node_impl!(FunctionStub);