Skip to main content

leo_ast/constructor/
mod.rs

1// Copyright (C) 2019-2026 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use crate::{Annotation, Block, Indent, IntegerType, Location, NetworkName, Node, NodeID, Type};
18use leo_span::{Span, sym};
19
20use anyhow::{anyhow, bail};
21use serde::{Deserialize, Serialize};
22use snarkvm::prelude::{Address, Literal, Locator, Network};
23use std::{fmt, str::FromStr};
24
25/// A constructor definition.
26#[derive(Clone, Default, Eq, PartialEq, Serialize, Deserialize)]
27pub struct Constructor {
28    /// Annotations on the constructor.
29    pub annotations: Vec<Annotation>,
30    /// The body of the constructor.
31    pub block: Block,
32    /// The entire span of the constructor definition.
33    pub span: Span,
34    /// The ID of the node.
35    pub id: NodeID,
36}
37
38/// The upgrade variant.
39#[derive(Clone, Debug, Eq, PartialEq)]
40pub enum UpgradeVariant {
41    Admin { address: String },
42    Custom,
43    Checksum { mapping: Location, key: String, key_type: Type },
44    NoUpgrade,
45}
46
47impl Constructor {
48    pub fn get_upgrade_variant_with_network(&self, network: NetworkName) -> anyhow::Result<UpgradeVariant> {
49        match network {
50            NetworkName::MainnetV0 => self.get_upgrade_variant::<snarkvm::prelude::MainnetV0>(),
51            NetworkName::TestnetV0 => self.get_upgrade_variant::<snarkvm::prelude::TestnetV0>(),
52            NetworkName::CanaryV0 => self.get_upgrade_variant::<snarkvm::prelude::CanaryV0>(),
53        }
54    }
55
56    /// Checks that the constructor's annotations are valid and returns the upgrade variant.
57    pub fn get_upgrade_variant<N: Network>(&self) -> anyhow::Result<UpgradeVariant> {
58        // Check that there is exactly one annotation.
59        if self.annotations.len() != 1 {
60            bail!(
61                "A constructor must have exactly one of the following annotations: `@admin`, `@checksum`, `@custom`, or `@noupgrade`."
62            );
63        }
64        // Get the annotation.
65        let annotation = &self.annotations[0];
66        match annotation.identifier.name {
67            sym::admin => {
68                // Parse the address string from the annotation.
69                let Some(address_string) = annotation.map.get(&sym::address) else {
70                    bail!("An `@admin` annotation must have an 'address' key.")
71                };
72                // Parse the address.
73                let address = Address::<N>::from_str(address_string)
74                    .map_err(|e| anyhow!("Invalid address in `@admin` annotation: `{e}`."))?;
75                Ok(UpgradeVariant::Admin { address: address.to_string() })
76            }
77            sym::checksum => {
78                // Parse the mapping string from the annotation.
79                let Some(mapping_string) = annotation.map.get(&sym::mapping) else {
80                    bail!("A `@checksum` annotation must have a 'mapping' key.")
81                };
82                // Parse the mapping string as a locator. Accept both `prog.aleo::name` (Leo
83                // syntax) and `prog.aleo/name` (Aleo protocol syntax) by normalizing `::` to `/`.
84                let normalized = mapping_string.replace(".aleo::", ".aleo/");
85                let mapping = Locator::<N>::from_str(&normalized)
86                    .map_err(|e| anyhow!("Invalid mapping in `@checksum` annotation: `{e}`."))?;
87
88                // Parse the key string from the annotation.
89                let Some(key_string) = annotation.map.get(&sym::key) else {
90                    bail!("A `@checksum` annotation must have a 'key' key.")
91                };
92                // Parse the key as a plaintext value.
93                let key = Literal::<N>::from_str(key_string)
94                    .map_err(|e| anyhow!("Invalid key in `@checksum` annotation: `{e}`."))?;
95                // Get the literal type.
96                let key_type = get_type_from_snarkvm_literal(&key);
97                Ok(UpgradeVariant::Checksum { mapping: mapping.into(), key: key.to_string(), key_type })
98            }
99            sym::custom => Ok(UpgradeVariant::Custom),
100            sym::noupgrade => Ok(UpgradeVariant::NoUpgrade),
101            _ => bail!(
102                "Invalid annotation on constructor: `{}`. Expected one of `@admin`, `@checksum`, `@custom`, or `@noupgrade`.",
103                annotation.identifier.name
104            ),
105        }
106    }
107}
108
109impl fmt::Display for Constructor {
110    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
111        for annotation in &self.annotations {
112            writeln!(f, "{annotation}")?;
113        }
114
115        writeln!(f, "async constructor() {{")?;
116        for stmt in self.block.statements.iter() {
117            writeln!(f, "{}{}", Indent(stmt), stmt.semicolon())?;
118        }
119        write!(f, "}}")
120    }
121}
122
123impl fmt::Debug for Constructor {
124    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
125        write!(f, "{self}")
126    }
127}
128
129crate::simple_node_impl!(Constructor);
130
131// A helper function to get the type from a snarkVM literal.
132fn get_type_from_snarkvm_literal<N: Network>(literal: &Literal<N>) -> Type {
133    match literal {
134        Literal::Field(_) => Type::Field,
135        Literal::Group(_) => Type::Group,
136        Literal::Address(_) => Type::Address,
137        Literal::Scalar(_) => Type::Scalar,
138        Literal::Boolean(_) => Type::Boolean,
139        Literal::String(_) => Type::String,
140        Literal::I8(_) => Type::Integer(IntegerType::I8),
141        Literal::I16(_) => Type::Integer(IntegerType::I16),
142        Literal::I32(_) => Type::Integer(IntegerType::I32),
143        Literal::I64(_) => Type::Integer(IntegerType::I64),
144        Literal::I128(_) => Type::Integer(IntegerType::I128),
145        Literal::U8(_) => Type::Integer(IntegerType::U8),
146        Literal::U16(_) => Type::Integer(IntegerType::U16),
147        Literal::U32(_) => Type::Integer(IntegerType::U32),
148        Literal::U64(_) => Type::Integer(IntegerType::U64),
149        Literal::U128(_) => Type::Integer(IntegerType::U128),
150        Literal::Signature(_) => Type::Signature,
151        Literal::Identifier(_) => Type::Identifier,
152    }
153}