Skip to main content

leo_ast/types/
type_.rs

1// Copyright (C) 2019-2026 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use crate::{
18    ArrayType,
19    CompositeType,
20    FutureType,
21    Identifier,
22    IntegerType,
23    Location,
24    MappingType,
25    OptionalType,
26    Path,
27    TupleType,
28    VectorType,
29};
30
31use itertools::Itertools;
32use leo_span::Symbol;
33use serde::{Deserialize, Serialize};
34use snarkvm::prelude::{
35    LiteralType,
36    Network,
37    PlaintextType,
38    PlaintextType::{Array, ExternalStruct, Literal, Struct},
39};
40use std::fmt;
41
/// Explicit type used for defining a variable or expression type.
///
/// Besides the types a user can write, this enum also carries compiler-internal placeholders:
/// `Numeric` (not yet resolved to a concrete numeric type) and `Err` (a type that could not be
/// resolved). `Err` is the `Default` value.
///
/// NOTE: `Serialize`/`Deserialize` are derived, and serde identifies variants by their declaration
/// order in index-based formats, so reordering variants may change the serialized encoding.
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum Type {
    /// The `address` type.
    Address,
    /// The array type; see [`ArrayType`].
    Array(ArrayType),
    /// The `bool` type.
    Boolean,
    /// The composite type; see [`CompositeType`].
    Composite(CompositeType),
    /// The `field` type.
    Field,
    /// The `future` type; see [`FutureType`].
    Future(FutureType),
    /// The `group` type.
    Group,
    /// A type referred to by a bare identifier (documented upstream as a reference to a built-in type).
    Identifier(Identifier),
    /// An integer type; see [`IntegerType`].
    Integer(IntegerType),
    /// A mapping type; see [`MappingType`].
    Mapping(MappingType),
    /// An optional (nullable) type; see [`OptionalType`].
    Optional(OptionalType),
    /// The `scalar` type.
    Scalar,
    /// The `signature` type.
    Signature,
    /// The `string` type.
    String,
    /// A static tuple of at least one type.
    Tuple(TupleType),
    /// The vector type; see [`VectorType`].
    Vector(VectorType),
    /// Numeric type which should be resolved to `Field`, `Group`, `Integer(_)`, or `Scalar`.
    Numeric,
    /// The `unit` type, displayed as `()`.
    Unit,
    /// Placeholder for a type that could not be resolved or was not well-formed.
    /// Will eventually lead to a compile error.
    #[default]
    Err,
}
86
87impl Type {
88    /// Are the types considered equal as far as the Leo user is concerned?
89    ///
90    /// In particular, any comparison involving an `Err` is `true`, and Futures which aren't explicit compare equal to
91    /// other Futures.
92    ///
93    /// An array with an undetermined length (e.g., one that depends on a `const`) is considered equal to other arrays
94    /// if their element types match. This allows const propagation to potentially resolve the length before type
95    /// checking is performed again.
96    ///
97    /// Composite types are considered equal if their names and resolved program names match. If either side still has
98    /// const generic arguments, they are treated as equal unconditionally since monomorphization and other passes of
99    /// type-checking will handle mismatches later.
100    pub fn eq_user(&self, other: &Type) -> bool {
101        match (self, other) {
102            (Type::Err, _)
103            | (_, Type::Err)
104            | (Type::Address, Type::Address)
105            | (Type::Boolean, Type::Boolean)
106            | (Type::Field, Type::Field)
107            | (Type::Group, Type::Group)
108            | (Type::Scalar, Type::Scalar)
109            | (Type::Signature, Type::Signature)
110            | (Type::String, Type::String)
111            | (Type::Unit, Type::Unit) => true,
112            (Type::Array(left), Type::Array(right)) => {
113                (match (left.length.as_u32(), right.length.as_u32()) {
114                    (Some(l1), Some(l2)) => l1 == l2,
115                    _ => {
116                        // An array with an undetermined length (e.g., one that depends on a `const`) is considered
117                        // equal to other arrays because their lengths _may_ eventually be proven equal.
118                        true
119                    }
120                }) && left.element_type().eq_user(right.element_type())
121            }
122            (Type::Identifier(left), Type::Identifier(right)) => left.name == right.name,
123            (Type::Integer(left), Type::Integer(right)) => left == right,
124            (Type::Mapping(left), Type::Mapping(right)) => {
125                left.key.eq_user(&right.key) && left.value.eq_user(&right.value)
126            }
127            (Type::Optional(left), Type::Optional(right)) => left.inner.eq_user(&right.inner),
128            (Type::Tuple(left), Type::Tuple(right)) if left.length() == right.length() => left
129                .elements()
130                .iter()
131                .zip_eq(right.elements().iter())
132                .all(|(left_type, right_type)| left_type.eq_user(right_type)),
133            (Type::Vector(left), Type::Vector(right)) => left.element_type.eq_user(&right.element_type),
134            (Type::Composite(left), Type::Composite(right)) => {
135                // If either composite still has const generic arguments, treat them as equal.
136                if !left.const_arguments.is_empty() || !right.const_arguments.is_empty() {
137                    return true;
138                }
139
140                // Two composite types are the same if their global locations match.
141                match (&left.path.try_global_location(), &right.path.try_global_location()) {
142                    (Some(l), Some(r)) => l == r,
143                    _ => false,
144                }
145            }
146
147            (Type::Future(left), Type::Future(right)) if !left.is_explicit || !right.is_explicit => true,
148            (Type::Future(left), Type::Future(right)) if left.inputs.len() == right.inputs.len() => left
149                .inputs()
150                .iter()
151                .zip_eq(right.inputs().iter())
152                .all(|(left_type, right_type)| left_type.eq_user(right_type)),
153            _ => false,
154        }
155    }
156
157    /// Returns `true` if the self `Type` is equal to the other `Type` in all aspects besides composite program of origin.
158    ///
159    /// In the case of futures, it also makes sure that if both are not explicit, they are equal.
160    ///
161    /// Flattens array syntax: `[[u8; 1]; 2] == [u8; (2, 1)] == true`
162    ///
163    /// Composite types are considered equal if their names match. If either side still has const generic arguments,
164    /// they are treated as equal unconditionally since monomorphization and other passes of type-checking will handle
165    /// mismatches later.
166    pub fn eq_flat_relaxed(&self, other: &Self) -> bool {
167        match (self, other) {
168            (Type::Address, Type::Address)
169            | (Type::Boolean, Type::Boolean)
170            | (Type::Field, Type::Field)
171            | (Type::Group, Type::Group)
172            | (Type::Scalar, Type::Scalar)
173            | (Type::Signature, Type::Signature)
174            | (Type::String, Type::String)
175            | (Type::Unit, Type::Unit) => true,
176            (Type::Array(left), Type::Array(right)) => {
177                // Two arrays are equal if their element types are the same and if their lengths
178                // are the same, assuming the lengths can be extracted as `u32`.
179                (match (left.length.as_u32(), right.length.as_u32()) {
180                    (Some(l1), Some(l2)) => l1 == l2,
181                    _ => {
182                        // An array with an undetermined length (e.g., one that depends on a `const`) is considered
183                        // equal to other arrays because their lengths _may_ eventually be proven equal.
184                        true
185                    }
186                }) && left.element_type().eq_flat_relaxed(right.element_type())
187            }
188            (Type::Identifier(left), Type::Identifier(right)) => left.matches(right),
189            (Type::Integer(left), Type::Integer(right)) => left.eq(right),
190            (Type::Mapping(left), Type::Mapping(right)) => {
191                left.key.eq_flat_relaxed(&right.key) && left.value.eq_flat_relaxed(&right.value)
192            }
193            (Type::Optional(left), Type::Optional(right)) => left.inner.eq_flat_relaxed(&right.inner),
194            (Type::Tuple(left), Type::Tuple(right)) if left.length() == right.length() => left
195                .elements()
196                .iter()
197                .zip_eq(right.elements().iter())
198                .all(|(left_type, right_type)| left_type.eq_flat_relaxed(right_type)),
199            (Type::Vector(left), Type::Vector(right)) => left.element_type.eq_flat_relaxed(&right.element_type),
200            (Type::Composite(left), Type::Composite(right)) => {
201                // If either composite still has const generic arguments, treat them as equal.
202                // Type checking will run again after monomorphization.
203                if !left.const_arguments.is_empty() || !right.const_arguments.is_empty() {
204                    return true;
205                }
206
207                // Two composite types are the same if their global paths match.
208                // If the absolute paths are not available, then we really can't compare the two
209                // types and we just return `false` to be conservative.
210                match (&left.path.try_global_location(), &right.path.try_global_location()) {
211                    (Some(l), Some(r)) => l.path == r.path,
212                    _ => false,
213                }
214            }
215            // Don't type check when type hasn't been explicitly defined.
216            (Type::Future(left), Type::Future(right)) if !left.is_explicit || !right.is_explicit => true,
217            (Type::Future(left), Type::Future(right)) if left.inputs.len() == right.inputs.len() => left
218                .inputs()
219                .iter()
220                .zip_eq(right.inputs().iter())
221                .all(|(left_type, right_type)| left_type.eq_flat_relaxed(right_type)),
222            _ => false,
223        }
224    }
225
226    pub fn from_snarkvm<N: Network>(t: &PlaintextType<N>, program: Symbol) -> Self {
227        match t {
228            Literal(lit) => (*lit).into(),
229            Struct(s) => Type::Composite(CompositeType {
230                path: {
231                    let ident = Identifier::from(s);
232                    Path::from(ident).to_global(Location::new(program, vec![ident.name]))
233                },
234                const_arguments: Vec::new(),
235            }),
236            ExternalStruct(l) => Type::Composite(CompositeType {
237                path: {
238                    let external_program = Identifier::from(l.program_id().name());
239                    let name = Identifier::from(l.resource());
240                    Path::from(name)
241                        .with_user_program(external_program)
242                        .to_global(Location::new(external_program.name, vec![name.name]))
243                },
244                const_arguments: Vec::new(),
245            }),
246            Array(array) => Type::Array(ArrayType::from_snarkvm(array, program)),
247        }
248    }
249
250    // Attempts to convert `self` to a snarkVM `PlaintextType`.
251    pub fn to_snarkvm<N: Network>(&self) -> anyhow::Result<PlaintextType<N>> {
252        match self {
253            Type::Address => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Address)),
254            Type::Boolean => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Boolean)),
255            Type::Field => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Field)),
256            Type::Group => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Group)),
257            Type::Integer(int_type) => match int_type {
258                IntegerType::U8 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U8)),
259                IntegerType::U16 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U16)),
260                IntegerType::U32 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U32)),
261                IntegerType::U64 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U64)),
262                IntegerType::U128 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U128)),
263                IntegerType::I8 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I8)),
264                IntegerType::I16 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I16)),
265                IntegerType::I32 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I32)),
266                IntegerType::I64 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I64)),
267                IntegerType::I128 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I128)),
268            },
269            Type::Scalar => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Scalar)),
270            Type::Signature => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Signature)),
271            Type::Array(array_type) => Ok(PlaintextType::<N>::Array(array_type.to_snarkvm()?)),
272            _ => anyhow::bail!("Converting from type {self} to snarkVM type is not supported"),
273        }
274    }
275
276    // A helper function to get the size in bits of the input type.
277    pub fn size_in_bits<N: Network, F0, F1>(
278        &self,
279        is_raw: bool,
280        get_structs: F0,
281        get_external_structs: F1,
282    ) -> anyhow::Result<usize>
283    where
284        F0: Fn(&snarkvm::prelude::Identifier<N>) -> anyhow::Result<snarkvm::prelude::StructType<N>>,
285        F1: Fn(&snarkvm::prelude::Locator<N>) -> anyhow::Result<snarkvm::prelude::StructType<N>>,
286    {
287        match is_raw {
288            false => self.to_snarkvm::<N>()?.size_in_bits(&get_structs, &get_external_structs),
289            true => self.to_snarkvm::<N>()?.size_in_bits_raw(&get_structs, &get_external_structs),
290        }
291    }
292
293    /// Determines whether `self` can be coerced to the `expected` type.
294    ///
295    /// This method checks if the current type can be implicitly coerced to the expected type
296    /// according to specific rules:
297    /// - `Optional<T>` can be coerced to `Optional<T>`.
298    /// - `T` can be coerced to `Optional<T>`.
299    /// - Arrays `[T; N]` can be coerced to `[Optional<T>; N]` if lengths match or are unknown,
300    ///   and element types are coercible.
301    /// - Falls back to an equality check for other types.
302    ///
303    /// # Arguments
304    /// * `expected` - The type to which `self` is being coerced.
305    ///
306    /// # Returns
307    /// `true` if coercion is allowed; `false` otherwise.
308    pub fn can_coerce_to(&self, expected: &Type) -> bool {
309        use Type::*;
310
311        match (self, expected) {
312            // Allow Optional<T> → Optional<T>
313            (Optional(actual_opt), Optional(expected_opt)) => actual_opt.inner.can_coerce_to(&expected_opt.inner),
314
315            // Allow T → Optional<T>
316            (a, Optional(opt)) => a.can_coerce_to(&opt.inner),
317
318            // Allow [T; N] → [Optional<T>; N]
319            (Array(a_arr), Array(e_arr)) => {
320                let lengths_equal = match (a_arr.length.as_u32(), e_arr.length.as_u32()) {
321                    (Some(l1), Some(l2)) => l1 == l2,
322                    _ => true,
323                };
324
325                lengths_equal && a_arr.element_type().can_coerce_to(e_arr.element_type())
326            }
327
328            // Fallback: check for exact match
329            _ => self.eq_user(expected),
330        }
331    }
332
333    pub fn is_optional(&self) -> bool {
334        matches!(self, Self::Optional(_))
335    }
336
337    pub fn is_vector(&self) -> bool {
338        matches!(self, Self::Vector(_))
339    }
340
341    pub fn is_mapping(&self) -> bool {
342        matches!(self, Self::Mapping(_))
343    }
344
345    pub fn to_optional(&self) -> Type {
346        Type::Optional(OptionalType { inner: Box::new(self.clone()) })
347    }
348
349    pub fn is_empty(&self) -> bool {
350        match self {
351            Type::Unit => true,
352            Type::Array(array_type) => {
353                if let Some(length) = array_type.length.as_u32() {
354                    length == 0
355                } else {
356                    false
357                }
358            }
359            _ => false,
360        }
361    }
362}
363
364impl From<LiteralType> for Type {
365    fn from(value: LiteralType) -> Self {
366        match value {
367            LiteralType::Address => Type::Address,
368            LiteralType::Boolean => Type::Boolean,
369            LiteralType::Field => Type::Field,
370            LiteralType::Group => Type::Group,
371            LiteralType::U8 => Type::Integer(IntegerType::U8),
372            LiteralType::U16 => Type::Integer(IntegerType::U16),
373            LiteralType::U32 => Type::Integer(IntegerType::U32),
374            LiteralType::U64 => Type::Integer(IntegerType::U64),
375            LiteralType::U128 => Type::Integer(IntegerType::U128),
376            LiteralType::I8 => Type::Integer(IntegerType::I8),
377            LiteralType::I16 => Type::Integer(IntegerType::I16),
378            LiteralType::I32 => Type::Integer(IntegerType::I32),
379            LiteralType::I64 => Type::Integer(IntegerType::I64),
380            LiteralType::I128 => Type::Integer(IntegerType::I128),
381            LiteralType::Scalar => Type::Scalar,
382            LiteralType::Signature => Type::Signature,
383            LiteralType::String => Type::String,
384        }
385    }
386}
387
388impl fmt::Display for Type {
389    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
390        match *self {
391            Type::Address => write!(f, "address"),
392            Type::Array(ref array_type) => write!(f, "{array_type}"),
393            Type::Boolean => write!(f, "bool"),
394            Type::Field => write!(f, "field"),
395            Type::Future(ref future_type) => write!(f, "{future_type}"),
396            Type::Group => write!(f, "group"),
397            Type::Identifier(ref variable) => write!(f, "{variable}"),
398            Type::Integer(ref integer_type) => write!(f, "{integer_type}"),
399            Type::Mapping(ref mapping_type) => write!(f, "{mapping_type}"),
400            Type::Optional(ref optional_type) => write!(f, "{optional_type}"),
401            Type::Scalar => write!(f, "scalar"),
402            Type::Signature => write!(f, "signature"),
403            Type::String => write!(f, "string"),
404            Type::Composite(ref composite_type) => write!(f, "{composite_type}"),
405            Type::Tuple(ref tuple) => write!(f, "{tuple}"),
406            Type::Vector(ref vector_type) => write!(f, "{vector_type}"),
407            Type::Numeric => write!(f, "numeric"),
408            Type::Unit => write!(f, "()"),
409            Type::Err => write!(f, "error"),
410        }
411    }
412}