1use crate::{
18 ArrayType,
19 CompositeType,
20 FutureType,
21 Identifier,
22 IntegerType,
23 Location,
24 MappingType,
25 OptionalType,
26 Path,
27 TupleType,
28 VectorType,
29};
30
31use itertools::Itertools;
32use leo_span::Symbol;
33use serde::{Deserialize, Serialize};
34use snarkvm::prelude::{
35 LiteralType,
36 Network,
37 PlaintextType,
38 PlaintextType::{Array, ExternalStruct, Literal, Struct},
39};
40use std::fmt;
41
42#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
44pub enum Type {
45 Address,
47 Array(ArrayType),
49 Boolean,
51 Composite(CompositeType),
53 Field,
55 Future(FutureType),
57 Group,
59 Identifier(Identifier),
61 Integer(IntegerType),
63 Mapping(MappingType),
65 Optional(OptionalType),
67 Scalar,
69 Signature,
71 String,
73 Tuple(TupleType),
75 Vector(VectorType),
77 Numeric,
79 Unit,
81 #[default]
84 Err,
85}
86
87impl Type {
88 pub fn eq_user(&self, other: &Type) -> bool {
101 match (self, other) {
102 (Type::Err, _)
103 | (_, Type::Err)
104 | (Type::Address, Type::Address)
105 | (Type::Boolean, Type::Boolean)
106 | (Type::Field, Type::Field)
107 | (Type::Group, Type::Group)
108 | (Type::Scalar, Type::Scalar)
109 | (Type::Signature, Type::Signature)
110 | (Type::String, Type::String)
111 | (Type::Unit, Type::Unit) => true,
112 (Type::Array(left), Type::Array(right)) => {
113 (match (left.length.as_u32(), right.length.as_u32()) {
114 (Some(l1), Some(l2)) => l1 == l2,
115 _ => {
116 true
119 }
120 }) && left.element_type().eq_user(right.element_type())
121 }
122 (Type::Identifier(left), Type::Identifier(right)) => left.name == right.name,
123 (Type::Integer(left), Type::Integer(right)) => left == right,
124 (Type::Mapping(left), Type::Mapping(right)) => {
125 left.key.eq_user(&right.key) && left.value.eq_user(&right.value)
126 }
127 (Type::Optional(left), Type::Optional(right)) => left.inner.eq_user(&right.inner),
128 (Type::Tuple(left), Type::Tuple(right)) if left.length() == right.length() => left
129 .elements()
130 .iter()
131 .zip_eq(right.elements().iter())
132 .all(|(left_type, right_type)| left_type.eq_user(right_type)),
133 (Type::Vector(left), Type::Vector(right)) => left.element_type.eq_user(&right.element_type),
134 (Type::Composite(left), Type::Composite(right)) => {
135 if !left.const_arguments.is_empty() || !right.const_arguments.is_empty() {
137 return true;
138 }
139
140 match (&left.path.try_global_location(), &right.path.try_global_location()) {
142 (Some(l), Some(r)) => l == r,
143 _ => false,
144 }
145 }
146
147 (Type::Future(left), Type::Future(right)) if !left.is_explicit || !right.is_explicit => true,
148 (Type::Future(left), Type::Future(right)) if left.inputs.len() == right.inputs.len() => left
149 .inputs()
150 .iter()
151 .zip_eq(right.inputs().iter())
152 .all(|(left_type, right_type)| left_type.eq_user(right_type)),
153 _ => false,
154 }
155 }
156
157 pub fn eq_flat_relaxed(&self, other: &Self) -> bool {
167 match (self, other) {
168 (Type::Address, Type::Address)
169 | (Type::Boolean, Type::Boolean)
170 | (Type::Field, Type::Field)
171 | (Type::Group, Type::Group)
172 | (Type::Scalar, Type::Scalar)
173 | (Type::Signature, Type::Signature)
174 | (Type::String, Type::String)
175 | (Type::Unit, Type::Unit) => true,
176 (Type::Array(left), Type::Array(right)) => {
177 (match (left.length.as_u32(), right.length.as_u32()) {
180 (Some(l1), Some(l2)) => l1 == l2,
181 _ => {
182 true
185 }
186 }) && left.element_type().eq_flat_relaxed(right.element_type())
187 }
188 (Type::Identifier(left), Type::Identifier(right)) => left.matches(right),
189 (Type::Integer(left), Type::Integer(right)) => left.eq(right),
190 (Type::Mapping(left), Type::Mapping(right)) => {
191 left.key.eq_flat_relaxed(&right.key) && left.value.eq_flat_relaxed(&right.value)
192 }
193 (Type::Optional(left), Type::Optional(right)) => left.inner.eq_flat_relaxed(&right.inner),
194 (Type::Tuple(left), Type::Tuple(right)) if left.length() == right.length() => left
195 .elements()
196 .iter()
197 .zip_eq(right.elements().iter())
198 .all(|(left_type, right_type)| left_type.eq_flat_relaxed(right_type)),
199 (Type::Vector(left), Type::Vector(right)) => left.element_type.eq_flat_relaxed(&right.element_type),
200 (Type::Composite(left), Type::Composite(right)) => {
201 if !left.const_arguments.is_empty() || !right.const_arguments.is_empty() {
204 return true;
205 }
206
207 match (&left.path.try_global_location(), &right.path.try_global_location()) {
211 (Some(l), Some(r)) => l.path == r.path,
212 _ => false,
213 }
214 }
215 (Type::Future(left), Type::Future(right)) if !left.is_explicit || !right.is_explicit => true,
217 (Type::Future(left), Type::Future(right)) if left.inputs.len() == right.inputs.len() => left
218 .inputs()
219 .iter()
220 .zip_eq(right.inputs().iter())
221 .all(|(left_type, right_type)| left_type.eq_flat_relaxed(right_type)),
222 _ => false,
223 }
224 }
225
226 pub fn from_snarkvm<N: Network>(t: &PlaintextType<N>, program: Symbol) -> Self {
227 match t {
228 Literal(lit) => (*lit).into(),
229 Struct(s) => Type::Composite(CompositeType {
230 path: {
231 let ident = Identifier::from(s);
232 Path::from(ident).to_global(Location::new(program, vec![ident.name]))
233 },
234 const_arguments: Vec::new(),
235 }),
236 ExternalStruct(l) => Type::Composite(CompositeType {
237 path: {
238 let external_program = Identifier::from(l.program_id().name());
239 let name = Identifier::from(l.resource());
240 Path::from(name)
241 .with_user_program(external_program)
242 .to_global(Location::new(external_program.name, vec![name.name]))
243 },
244 const_arguments: Vec::new(),
245 }),
246 Array(array) => Type::Array(ArrayType::from_snarkvm(array, program)),
247 }
248 }
249
250 pub fn to_snarkvm<N: Network>(&self) -> anyhow::Result<PlaintextType<N>> {
252 match self {
253 Type::Address => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Address)),
254 Type::Boolean => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Boolean)),
255 Type::Field => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Field)),
256 Type::Group => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Group)),
257 Type::Integer(int_type) => match int_type {
258 IntegerType::U8 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U8)),
259 IntegerType::U16 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U16)),
260 IntegerType::U32 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U32)),
261 IntegerType::U64 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U64)),
262 IntegerType::U128 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U128)),
263 IntegerType::I8 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I8)),
264 IntegerType::I16 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I16)),
265 IntegerType::I32 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I32)),
266 IntegerType::I64 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I64)),
267 IntegerType::I128 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I128)),
268 },
269 Type::Scalar => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Scalar)),
270 Type::Signature => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Signature)),
271 Type::Array(array_type) => Ok(PlaintextType::<N>::Array(array_type.to_snarkvm()?)),
272 _ => anyhow::bail!("Converting from type {self} to snarkVM type is not supported"),
273 }
274 }
275
276 pub fn size_in_bits<N: Network, F0, F1>(
278 &self,
279 is_raw: bool,
280 get_structs: F0,
281 get_external_structs: F1,
282 ) -> anyhow::Result<usize>
283 where
284 F0: Fn(&snarkvm::prelude::Identifier<N>) -> anyhow::Result<snarkvm::prelude::StructType<N>>,
285 F1: Fn(&snarkvm::prelude::Locator<N>) -> anyhow::Result<snarkvm::prelude::StructType<N>>,
286 {
287 match is_raw {
288 false => self.to_snarkvm::<N>()?.size_in_bits(&get_structs, &get_external_structs),
289 true => self.to_snarkvm::<N>()?.size_in_bits_raw(&get_structs, &get_external_structs),
290 }
291 }
292
293 pub fn can_coerce_to(&self, expected: &Type) -> bool {
309 use Type::*;
310
311 match (self, expected) {
312 (Optional(actual_opt), Optional(expected_opt)) => actual_opt.inner.can_coerce_to(&expected_opt.inner),
314
315 (a, Optional(opt)) => a.can_coerce_to(&opt.inner),
317
318 (Array(a_arr), Array(e_arr)) => {
320 let lengths_equal = match (a_arr.length.as_u32(), e_arr.length.as_u32()) {
321 (Some(l1), Some(l2)) => l1 == l2,
322 _ => true,
323 };
324
325 lengths_equal && a_arr.element_type().can_coerce_to(e_arr.element_type())
326 }
327
328 _ => self.eq_user(expected),
330 }
331 }
332
333 pub fn is_optional(&self) -> bool {
334 matches!(self, Self::Optional(_))
335 }
336
337 pub fn is_vector(&self) -> bool {
338 matches!(self, Self::Vector(_))
339 }
340
341 pub fn is_mapping(&self) -> bool {
342 matches!(self, Self::Mapping(_))
343 }
344
345 pub fn to_optional(&self) -> Type {
346 Type::Optional(OptionalType { inner: Box::new(self.clone()) })
347 }
348
349 pub fn is_empty(&self) -> bool {
350 match self {
351 Type::Unit => true,
352 Type::Array(array_type) => {
353 if let Some(length) = array_type.length.as_u32() {
354 length == 0
355 } else {
356 false
357 }
358 }
359 _ => false,
360 }
361 }
362}
363
364impl From<LiteralType> for Type {
365 fn from(value: LiteralType) -> Self {
366 match value {
367 LiteralType::Address => Type::Address,
368 LiteralType::Boolean => Type::Boolean,
369 LiteralType::Field => Type::Field,
370 LiteralType::Group => Type::Group,
371 LiteralType::U8 => Type::Integer(IntegerType::U8),
372 LiteralType::U16 => Type::Integer(IntegerType::U16),
373 LiteralType::U32 => Type::Integer(IntegerType::U32),
374 LiteralType::U64 => Type::Integer(IntegerType::U64),
375 LiteralType::U128 => Type::Integer(IntegerType::U128),
376 LiteralType::I8 => Type::Integer(IntegerType::I8),
377 LiteralType::I16 => Type::Integer(IntegerType::I16),
378 LiteralType::I32 => Type::Integer(IntegerType::I32),
379 LiteralType::I64 => Type::Integer(IntegerType::I64),
380 LiteralType::I128 => Type::Integer(IntegerType::I128),
381 LiteralType::Scalar => Type::Scalar,
382 LiteralType::Signature => Type::Signature,
383 LiteralType::String => Type::String,
384 }
385 }
386}
387
388impl fmt::Display for Type {
389 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
390 match *self {
391 Type::Address => write!(f, "address"),
392 Type::Array(ref array_type) => write!(f, "{array_type}"),
393 Type::Boolean => write!(f, "bool"),
394 Type::Field => write!(f, "field"),
395 Type::Future(ref future_type) => write!(f, "{future_type}"),
396 Type::Group => write!(f, "group"),
397 Type::Identifier(ref variable) => write!(f, "{variable}"),
398 Type::Integer(ref integer_type) => write!(f, "{integer_type}"),
399 Type::Mapping(ref mapping_type) => write!(f, "{mapping_type}"),
400 Type::Optional(ref optional_type) => write!(f, "{optional_type}"),
401 Type::Scalar => write!(f, "scalar"),
402 Type::Signature => write!(f, "signature"),
403 Type::String => write!(f, "string"),
404 Type::Composite(ref composite_type) => write!(f, "{composite_type}"),
405 Type::Tuple(ref tuple) => write!(f, "{tuple}"),
406 Type::Vector(ref vector_type) => write!(f, "{vector_type}"),
407 Type::Numeric => write!(f, "numeric"),
408 Type::Unit => write!(f, "()"),
409 Type::Err => write!(f, "error"),
410 }
411 }
412}