Skip to main content

leo_ast/types/
type_.rs

1// Copyright (C) 2019-2026 Provable Inc.
2// This file is part of the Leo library.
3
4// The Leo library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The Leo library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the Leo library. If not, see <https://www.gnu.org/licenses/>.
16
17use crate::{
18    ArrayType,
19    CompositeType,
20    FutureType,
21    Identifier,
22    IntegerType,
23    Location,
24    MappingType,
25    OptionalType,
26    Path,
27    ProgramId,
28    TupleType,
29    VectorType,
30};
31
32use itertools::Itertools;
33use serde::{Deserialize, Serialize};
34use snarkvm::prelude::{
35    LiteralType,
36    Network,
37    PlaintextType,
38    PlaintextType::{Array, ExternalStruct, Literal, Struct},
39};
40use std::fmt;
41
/// Explicit type used for defining a variable or expression type.
///
/// The derived `PartialEq` is strict structural equality. The looser comparisons used during
/// type checking are [`Type::eq_user`] and [`Type::eq_flat_relaxed`].
///
/// The [`Default`] value is [`Type::Err`].
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum Type {
    /// The `address` type.
    Address,
    /// The array type, `[T; N]`. The length may not yet be known as a `u32` (e.g. when it
    /// depends on a `const`).
    Array(ArrayType),
    /// The `bool` type.
    Boolean,
    /// The composite type. snarkVM `Struct` and `ExternalStruct` plaintext types convert into
    /// this variant (see `Type::from_snarkvm`).
    Composite(CompositeType),
    /// The `field` type.
    Field,
    /// The `future` type. A future that is not explicit compares equal to any other future
    /// under [`Type::eq_user`].
    Future(FutureType),
    /// The `group` type.
    Group,
    /// The `identifier` type.
    Identifier,
    /// The `dyn record` type.
    DynRecord,
    /// A type named by a single identifier.
    /// NOTE(review): documented as "a reference to a built in type"; confirm which types
    /// actually reach this variant.
    Ident(Identifier),
    /// An integer type.
    Integer(IntegerType),
    /// A mapping type, with a key type and a value type.
    Mapping(MappingType),
    /// A nullable (optional) type wrapping an inner type.
    Optional(OptionalType),
    /// The `scalar` type.
    Scalar,
    /// The `signature` type.
    Signature,
    /// The `string` type.
    String,
    /// A static tuple of at least one type.
    Tuple(TupleType),
    /// The vector type.
    Vector(VectorType),
    /// Numeric type which should be resolved to `Field`, `Group`, `Integer(_)`, or `Scalar`.
    Numeric,
    /// The `unit` type, displayed as `()`.
    Unit,
    /// Placeholder for a type that could not be resolved or was not well-formed.
    /// Will eventually lead to a compile error.
    /// Compares equal to every type under [`Type::eq_user`].
    #[default]
    Err,
}
90
91impl Type {
92    /// Are the types considered equal as far as the Leo user is concerned?
93    ///
94    /// In particular, any comparison involving an `Err` is `true`, and Futures which aren't explicit compare equal to
95    /// other Futures.
96    ///
97    /// An array with an undetermined length (e.g., one that depends on a `const`) is considered equal to other arrays
98    /// if their element types match. This allows const propagation to potentially resolve the length before type
99    /// checking is performed again.
100    ///
101    /// Composite types are considered equal if their names and resolved program names match. If either side still has
102    /// const generic arguments, they are treated as equal unconditionally since monomorphization and other passes of
103    /// type-checking will handle mismatches later.
104    pub fn eq_user(&self, other: &Type) -> bool {
105        match (self, other) {
106            (Type::Err, _)
107            | (_, Type::Err)
108            | (Type::Address, Type::Address)
109            | (Type::Boolean, Type::Boolean)
110            | (Type::Field, Type::Field)
111            | (Type::Group, Type::Group)
112            | (Type::Scalar, Type::Scalar)
113            | (Type::Signature, Type::Signature)
114            | (Type::String, Type::String)
115            | (Type::Identifier, Type::Identifier)
116            | (Type::DynRecord, Type::DynRecord)
117            | (Type::Unit, Type::Unit) => true,
118            (Type::Array(left), Type::Array(right)) => {
119                (match (left.length.as_u32(), right.length.as_u32()) {
120                    (Some(l1), Some(l2)) => l1 == l2,
121                    _ => {
122                        // An array with an undetermined length (e.g., one that depends on a `const`) is considered
123                        // equal to other arrays because their lengths _may_ eventually be proven equal.
124                        true
125                    }
126                }) && left.element_type().eq_user(right.element_type())
127            }
128            (Type::Ident(left), Type::Ident(right)) => left.name == right.name,
129            (Type::Integer(left), Type::Integer(right)) => left == right,
130            (Type::Mapping(left), Type::Mapping(right)) => {
131                left.key.eq_user(&right.key) && left.value.eq_user(&right.value)
132            }
133            (Type::Optional(left), Type::Optional(right)) => left.inner.eq_user(&right.inner),
134            (Type::Tuple(left), Type::Tuple(right)) if left.length() == right.length() => left
135                .elements()
136                .iter()
137                .zip_eq(right.elements().iter())
138                .all(|(left_type, right_type)| left_type.eq_user(right_type)),
139            (Type::Vector(left), Type::Vector(right)) => left.element_type.eq_user(&right.element_type),
140            (Type::Composite(left), Type::Composite(right)) => {
141                // If either composite still has const generic arguments, treat them as equal;
142                // monomorphization and a subsequent type-checking pass will handle mismatches.
143                if !left.const_arguments.is_empty() || !right.const_arguments.is_empty() {
144                    return true;
145                }
146
147                // Two composite types are the same if their global locations match.
148                match (&left.path.try_global_location(), &right.path.try_global_location()) {
149                    (Some(l), Some(r)) => l == r,
150                    _ => false,
151                }
152            }
153
154            (Type::Future(left), Type::Future(right)) if !left.is_explicit || !right.is_explicit => true,
155            (Type::Future(left), Type::Future(right)) if left.inputs.len() == right.inputs.len() => left
156                .inputs()
157                .iter()
158                .zip_eq(right.inputs().iter())
159                .all(|(left_type, right_type)| left_type.eq_user(right_type)),
160            _ => false,
161        }
162    }
163
164    /// Returns `true` if the self `Type` is equal to the other `Type` in all aspects besides composite program of origin.
165    ///
166    /// In the case of futures, it also makes sure that if both are not explicit, they are equal.
167    ///
168    /// Flattens array syntax: `[[u8; 1]; 2] == [u8; (2, 1)] == true`
169    ///
170    /// Composite types are considered equal if their names match. If either side still has const generic arguments,
171    /// they are treated as equal unconditionally since monomorphization and other passes of type-checking will handle
172    /// mismatches later.
173    pub fn eq_flat_relaxed(&self, other: &Self) -> bool {
174        match (self, other) {
175            (Type::Address, Type::Address)
176            | (Type::Boolean, Type::Boolean)
177            | (Type::Field, Type::Field)
178            | (Type::Group, Type::Group)
179            | (Type::Scalar, Type::Scalar)
180            | (Type::Signature, Type::Signature)
181            | (Type::String, Type::String)
182            | (Type::Identifier, Type::Identifier)
183            | (Type::DynRecord, Type::DynRecord)
184            | (Type::Unit, Type::Unit) => true,
185            (Type::Array(left), Type::Array(right)) => {
186                // Two arrays are equal if their element types are the same and if their lengths
187                // are the same, assuming the lengths can be extracted as `u32`.
188                (match (left.length.as_u32(), right.length.as_u32()) {
189                    (Some(l1), Some(l2)) => l1 == l2,
190                    _ => {
191                        // An array with an undetermined length (e.g., one that depends on a `const`) is considered
192                        // equal to other arrays because their lengths _may_ eventually be proven equal.
193                        true
194                    }
195                }) && left.element_type().eq_flat_relaxed(right.element_type())
196            }
197            (Type::Ident(left), Type::Ident(right)) => left.matches(right),
198            (Type::Integer(left), Type::Integer(right)) => left.eq(right),
199            (Type::Mapping(left), Type::Mapping(right)) => {
200                left.key.eq_flat_relaxed(&right.key) && left.value.eq_flat_relaxed(&right.value)
201            }
202            (Type::Optional(left), Type::Optional(right)) => left.inner.eq_flat_relaxed(&right.inner),
203            (Type::Tuple(left), Type::Tuple(right)) if left.length() == right.length() => left
204                .elements()
205                .iter()
206                .zip_eq(right.elements().iter())
207                .all(|(left_type, right_type)| left_type.eq_flat_relaxed(right_type)),
208            (Type::Vector(left), Type::Vector(right)) => left.element_type.eq_flat_relaxed(&right.element_type),
209            (Type::Composite(left), Type::Composite(right)) => {
210                // If either composite still has const generic arguments, treat them as equal.
211                // Type checking will run again after monomorphization.
212                if !left.const_arguments.is_empty() || !right.const_arguments.is_empty() {
213                    return true;
214                }
215
216                // Two composite types are the same if their global paths match.
217                // If the absolute paths are not available, then we really can't compare the two
218                // types and we just return `false` to be conservative.
219                match (&left.path.try_global_location(), &right.path.try_global_location()) {
220                    (Some(l), Some(r)) => l.path == r.path,
221                    _ => false,
222                }
223            }
224            // Don't type check when type hasn't been explicitly defined.
225            (Type::Future(left), Type::Future(right)) if !left.is_explicit || !right.is_explicit => true,
226            (Type::Future(left), Type::Future(right)) if left.inputs.len() == right.inputs.len() => left
227                .inputs()
228                .iter()
229                .zip_eq(right.inputs().iter())
230                .all(|(left_type, right_type)| left_type.eq_flat_relaxed(right_type)),
231            _ => false,
232        }
233    }
234
235    pub fn from_snarkvm<N: Network>(t: &PlaintextType<N>, program_id: ProgramId) -> Self {
236        match t {
237            Literal(lit) => (*lit).into(),
238            Struct(s) => Type::Composite(CompositeType {
239                path: {
240                    let ident = Identifier::from(s);
241                    Path::from(ident).to_global(Location::new(program_id.as_symbol(), vec![ident.name]))
242                },
243                const_arguments: Vec::new(),
244            }),
245            ExternalStruct(l) => Type::Composite(CompositeType {
246                path: {
247                    let external_program = ProgramId::from(l.program_id());
248                    let name = Identifier::from(l.resource());
249                    Path::from(name)
250                        .with_user_program(external_program)
251                        .to_global(Location::new(external_program.as_symbol(), vec![name.name]))
252                },
253                const_arguments: Vec::new(),
254            }),
255            Array(array) => Type::Array(ArrayType::from_snarkvm(array, program_id)),
256        }
257    }
258
259    // Attempts to convert `self` to a snarkVM `PlaintextType`.
260    pub fn to_snarkvm<N: Network>(&self) -> anyhow::Result<PlaintextType<N>> {
261        match self {
262            Type::Address => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Address)),
263            Type::Boolean => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Boolean)),
264            Type::Field => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Field)),
265            Type::Group => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Group)),
266            Type::Integer(int_type) => match int_type {
267                IntegerType::U8 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U8)),
268                IntegerType::U16 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U16)),
269                IntegerType::U32 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U32)),
270                IntegerType::U64 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U64)),
271                IntegerType::U128 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U128)),
272                IntegerType::I8 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I8)),
273                IntegerType::I16 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I16)),
274                IntegerType::I32 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I32)),
275                IntegerType::I64 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I64)),
276                IntegerType::I128 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I128)),
277            },
278            Type::Scalar => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Scalar)),
279            Type::Signature => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Signature)),
280            Type::Array(array_type) => Ok(PlaintextType::<N>::Array(array_type.to_snarkvm()?)),
281            _ => anyhow::bail!("Converting from type {self} to snarkVM type is not supported"),
282        }
283    }
284
285    // A helper function to get the size in bits of the input type.
286    pub fn size_in_bits<N: Network, F0, F1>(
287        &self,
288        is_raw: bool,
289        get_structs: F0,
290        get_external_structs: F1,
291    ) -> anyhow::Result<usize>
292    where
293        F0: Fn(&snarkvm::prelude::Identifier<N>) -> anyhow::Result<snarkvm::prelude::StructType<N>>,
294        F1: Fn(&snarkvm::prelude::Locator<N>) -> anyhow::Result<snarkvm::prelude::StructType<N>>,
295    {
296        match is_raw {
297            false => self.to_snarkvm::<N>()?.size_in_bits(&get_structs, &get_external_structs),
298            true => self.to_snarkvm::<N>()?.size_in_bits_raw(&get_structs, &get_external_structs),
299        }
300    }
301
302    /// Determines whether `self` can be coerced to the `expected` type.
303    ///
304    /// This method checks if the current type can be implicitly coerced to the expected type
305    /// according to specific rules:
306    /// - `Optional<T>` can be coerced to `Optional<T>`.
307    /// - `T` can be coerced to `Optional<T>`.
308    /// - Arrays `[T; N]` can be coerced to `[Optional<T>; N]` if lengths match or are unknown,
309    ///   and element types are coercible.
310    /// - Falls back to an equality check for other types.
311    ///
312    /// # Arguments
313    /// * `expected` - The type to which `self` is being coerced.
314    ///
315    /// # Returns
316    /// `true` if coercion is allowed; `false` otherwise.
317    pub fn can_coerce_to(&self, expected: &Type) -> bool {
318        use Type::*;
319
320        match (self, expected) {
321            // Allow Optional<T> → Optional<T>
322            (Optional(actual_opt), Optional(expected_opt)) => actual_opt.inner.can_coerce_to(&expected_opt.inner),
323
324            // Allow T → Optional<T>
325            (a, Optional(opt)) => a.can_coerce_to(&opt.inner),
326
327            // Allow [T; N] → [Optional<T>; N]
328            (Array(a_arr), Array(e_arr)) => {
329                let lengths_equal = match (a_arr.length.as_u32(), e_arr.length.as_u32()) {
330                    (Some(l1), Some(l2)) => l1 == l2,
331                    _ => true,
332                };
333
334                lengths_equal && a_arr.element_type().can_coerce_to(e_arr.element_type())
335            }
336
337            // Fallback: check for exact match
338            _ => self.eq_user(expected),
339        }
340    }
341
342    pub fn is_optional(&self) -> bool {
343        matches!(self, Self::Optional(_))
344    }
345
346    pub fn is_vector(&self) -> bool {
347        matches!(self, Self::Vector(_))
348    }
349
350    pub fn is_mapping(&self) -> bool {
351        matches!(self, Self::Mapping(_))
352    }
353
354    pub fn to_optional(&self) -> Type {
355        Type::Optional(OptionalType { inner: Box::new(self.clone()) })
356    }
357
358    pub fn is_empty(&self) -> bool {
359        match self {
360            Type::Unit => true,
361            Type::Array(array_type) => {
362                if let Some(length) = array_type.length.as_u32() {
363                    length == 0
364                } else {
365                    false
366                }
367            }
368            _ => false,
369        }
370    }
371}
372
373impl From<LiteralType> for Type {
374    fn from(value: LiteralType) -> Self {
375        match value {
376            LiteralType::Identifier => Type::Identifier,
377            LiteralType::Address => Type::Address,
378            LiteralType::Boolean => Type::Boolean,
379            LiteralType::Field => Type::Field,
380            LiteralType::Group => Type::Group,
381            LiteralType::U8 => Type::Integer(IntegerType::U8),
382            LiteralType::U16 => Type::Integer(IntegerType::U16),
383            LiteralType::U32 => Type::Integer(IntegerType::U32),
384            LiteralType::U64 => Type::Integer(IntegerType::U64),
385            LiteralType::U128 => Type::Integer(IntegerType::U128),
386            LiteralType::I8 => Type::Integer(IntegerType::I8),
387            LiteralType::I16 => Type::Integer(IntegerType::I16),
388            LiteralType::I32 => Type::Integer(IntegerType::I32),
389            LiteralType::I64 => Type::Integer(IntegerType::I64),
390            LiteralType::I128 => Type::Integer(IntegerType::I128),
391            LiteralType::Scalar => Type::Scalar,
392            LiteralType::Signature => Type::Signature,
393            LiteralType::String => Type::String,
394        }
395    }
396}
397
398impl fmt::Display for Type {
399    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
400        match *self {
401            Type::Address => write!(f, "address"),
402            Type::Identifier => write!(f, "identifier"),
403            Type::DynRecord => write!(f, "dyn record"),
404            Type::Array(ref array_type) => write!(f, "{array_type}"),
405            Type::Boolean => write!(f, "bool"),
406            Type::Field => write!(f, "field"),
407            Type::Future(ref future_type) => write!(f, "{future_type}"),
408            Type::Group => write!(f, "group"),
409            Type::Ident(ref variable) => write!(f, "{variable}"),
410            Type::Integer(ref integer_type) => write!(f, "{integer_type}"),
411            Type::Mapping(ref mapping_type) => write!(f, "{mapping_type}"),
412            Type::Optional(ref optional_type) => write!(f, "{optional_type}"),
413            Type::Scalar => write!(f, "scalar"),
414            Type::Signature => write!(f, "signature"),
415            Type::String => write!(f, "string"),
416            Type::Composite(ref composite_type) => write!(f, "{composite_type}"),
417            Type::Tuple(ref tuple) => write!(f, "{tuple}"),
418            Type::Vector(ref vector_type) => write!(f, "{vector_type}"),
419            Type::Numeric => write!(f, "numeric"),
420            Type::Unit => write!(f, "()"),
421            Type::Err => write!(f, "error"),
422        }
423    }
424}