1use crate::{
18 ArrayType,
19 CompositeType,
20 FutureType,
21 Identifier,
22 IntegerType,
23 Location,
24 MappingType,
25 OptionalType,
26 Path,
27 ProgramId,
28 TupleType,
29 VectorType,
30};
31
32use itertools::Itertools;
33use serde::{Deserialize, Serialize};
34use snarkvm::prelude::{
35 LiteralType,
36 Network,
37 PlaintextType,
38 PlaintextType::{Array, ExternalStruct, Literal, Struct},
39};
40use std::fmt;
41
42#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
44pub enum Type {
45 Address,
47 Array(ArrayType),
49 Boolean,
51 Composite(CompositeType),
53 Field,
55 Future(FutureType),
57 Group,
59 Identifier,
61 DynRecord,
63 Ident(Identifier),
65 Integer(IntegerType),
67 Mapping(MappingType),
69 Optional(OptionalType),
71 Scalar,
73 Signature,
75 String,
77 Tuple(TupleType),
79 Vector(VectorType),
81 Numeric,
83 Unit,
85 #[default]
88 Err,
89}
90
91impl Type {
92 pub fn eq_user(&self, other: &Type) -> bool {
105 match (self, other) {
106 (Type::Err, _)
107 | (_, Type::Err)
108 | (Type::Address, Type::Address)
109 | (Type::Boolean, Type::Boolean)
110 | (Type::Field, Type::Field)
111 | (Type::Group, Type::Group)
112 | (Type::Scalar, Type::Scalar)
113 | (Type::Signature, Type::Signature)
114 | (Type::String, Type::String)
115 | (Type::Identifier, Type::Identifier)
116 | (Type::DynRecord, Type::DynRecord)
117 | (Type::Unit, Type::Unit) => true,
118 (Type::Array(left), Type::Array(right)) => {
119 (match (left.length.as_u32(), right.length.as_u32()) {
120 (Some(l1), Some(l2)) => l1 == l2,
121 _ => {
122 true
125 }
126 }) && left.element_type().eq_user(right.element_type())
127 }
128 (Type::Ident(left), Type::Ident(right)) => left.name == right.name,
129 (Type::Integer(left), Type::Integer(right)) => left == right,
130 (Type::Mapping(left), Type::Mapping(right)) => {
131 left.key.eq_user(&right.key) && left.value.eq_user(&right.value)
132 }
133 (Type::Optional(left), Type::Optional(right)) => left.inner.eq_user(&right.inner),
134 (Type::Tuple(left), Type::Tuple(right)) if left.length() == right.length() => left
135 .elements()
136 .iter()
137 .zip_eq(right.elements().iter())
138 .all(|(left_type, right_type)| left_type.eq_user(right_type)),
139 (Type::Vector(left), Type::Vector(right)) => left.element_type.eq_user(&right.element_type),
140 (Type::Composite(left), Type::Composite(right)) => {
141 if !left.const_arguments.is_empty() || !right.const_arguments.is_empty() {
144 return true;
145 }
146
147 match (&left.path.try_global_location(), &right.path.try_global_location()) {
149 (Some(l), Some(r)) => l == r,
150 _ => false,
151 }
152 }
153
154 (Type::Future(left), Type::Future(right)) if !left.is_explicit || !right.is_explicit => true,
155 (Type::Future(left), Type::Future(right)) if left.inputs.len() == right.inputs.len() => left
156 .inputs()
157 .iter()
158 .zip_eq(right.inputs().iter())
159 .all(|(left_type, right_type)| left_type.eq_user(right_type)),
160 _ => false,
161 }
162 }
163
164 pub fn eq_flat_relaxed(&self, other: &Self) -> bool {
174 match (self, other) {
175 (Type::Address, Type::Address)
176 | (Type::Boolean, Type::Boolean)
177 | (Type::Field, Type::Field)
178 | (Type::Group, Type::Group)
179 | (Type::Scalar, Type::Scalar)
180 | (Type::Signature, Type::Signature)
181 | (Type::String, Type::String)
182 | (Type::Identifier, Type::Identifier)
183 | (Type::DynRecord, Type::DynRecord)
184 | (Type::Unit, Type::Unit) => true,
185 (Type::Array(left), Type::Array(right)) => {
186 (match (left.length.as_u32(), right.length.as_u32()) {
189 (Some(l1), Some(l2)) => l1 == l2,
190 _ => {
191 true
194 }
195 }) && left.element_type().eq_flat_relaxed(right.element_type())
196 }
197 (Type::Ident(left), Type::Ident(right)) => left.matches(right),
198 (Type::Integer(left), Type::Integer(right)) => left.eq(right),
199 (Type::Mapping(left), Type::Mapping(right)) => {
200 left.key.eq_flat_relaxed(&right.key) && left.value.eq_flat_relaxed(&right.value)
201 }
202 (Type::Optional(left), Type::Optional(right)) => left.inner.eq_flat_relaxed(&right.inner),
203 (Type::Tuple(left), Type::Tuple(right)) if left.length() == right.length() => left
204 .elements()
205 .iter()
206 .zip_eq(right.elements().iter())
207 .all(|(left_type, right_type)| left_type.eq_flat_relaxed(right_type)),
208 (Type::Vector(left), Type::Vector(right)) => left.element_type.eq_flat_relaxed(&right.element_type),
209 (Type::Composite(left), Type::Composite(right)) => {
210 if !left.const_arguments.is_empty() || !right.const_arguments.is_empty() {
213 return true;
214 }
215
216 match (&left.path.try_global_location(), &right.path.try_global_location()) {
220 (Some(l), Some(r)) => l.path == r.path,
221 _ => false,
222 }
223 }
224 (Type::Future(left), Type::Future(right)) if !left.is_explicit || !right.is_explicit => true,
226 (Type::Future(left), Type::Future(right)) if left.inputs.len() == right.inputs.len() => left
227 .inputs()
228 .iter()
229 .zip_eq(right.inputs().iter())
230 .all(|(left_type, right_type)| left_type.eq_flat_relaxed(right_type)),
231 _ => false,
232 }
233 }
234
235 pub fn from_snarkvm<N: Network>(t: &PlaintextType<N>, program_id: ProgramId) -> Self {
236 match t {
237 Literal(lit) => (*lit).into(),
238 Struct(s) => Type::Composite(CompositeType {
239 path: {
240 let ident = Identifier::from(s);
241 Path::from(ident).to_global(Location::new(program_id.as_symbol(), vec![ident.name]))
242 },
243 const_arguments: Vec::new(),
244 }),
245 ExternalStruct(l) => Type::Composite(CompositeType {
246 path: {
247 let external_program = ProgramId::from(l.program_id());
248 let name = Identifier::from(l.resource());
249 Path::from(name)
250 .with_user_program(external_program)
251 .to_global(Location::new(external_program.as_symbol(), vec![name.name]))
252 },
253 const_arguments: Vec::new(),
254 }),
255 Array(array) => Type::Array(ArrayType::from_snarkvm(array, program_id)),
256 }
257 }
258
259 pub fn to_snarkvm<N: Network>(&self) -> anyhow::Result<PlaintextType<N>> {
261 match self {
262 Type::Address => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Address)),
263 Type::Boolean => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Boolean)),
264 Type::Field => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Field)),
265 Type::Group => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Group)),
266 Type::Integer(int_type) => match int_type {
267 IntegerType::U8 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U8)),
268 IntegerType::U16 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U16)),
269 IntegerType::U32 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U32)),
270 IntegerType::U64 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U64)),
271 IntegerType::U128 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U128)),
272 IntegerType::I8 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I8)),
273 IntegerType::I16 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I16)),
274 IntegerType::I32 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I32)),
275 IntegerType::I64 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I64)),
276 IntegerType::I128 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I128)),
277 },
278 Type::Scalar => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Scalar)),
279 Type::Signature => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Signature)),
280 Type::Array(array_type) => Ok(PlaintextType::<N>::Array(array_type.to_snarkvm()?)),
281 _ => anyhow::bail!("Converting from type {self} to snarkVM type is not supported"),
282 }
283 }
284
285 pub fn size_in_bits<N: Network, F0, F1>(
287 &self,
288 is_raw: bool,
289 get_structs: F0,
290 get_external_structs: F1,
291 ) -> anyhow::Result<usize>
292 where
293 F0: Fn(&snarkvm::prelude::Identifier<N>) -> anyhow::Result<snarkvm::prelude::StructType<N>>,
294 F1: Fn(&snarkvm::prelude::Locator<N>) -> anyhow::Result<snarkvm::prelude::StructType<N>>,
295 {
296 match is_raw {
297 false => self.to_snarkvm::<N>()?.size_in_bits(&get_structs, &get_external_structs),
298 true => self.to_snarkvm::<N>()?.size_in_bits_raw(&get_structs, &get_external_structs),
299 }
300 }
301
302 pub fn can_coerce_to(&self, expected: &Type) -> bool {
318 use Type::*;
319
320 match (self, expected) {
321 (Optional(actual_opt), Optional(expected_opt)) => actual_opt.inner.can_coerce_to(&expected_opt.inner),
323
324 (a, Optional(opt)) => a.can_coerce_to(&opt.inner),
326
327 (Array(a_arr), Array(e_arr)) => {
329 let lengths_equal = match (a_arr.length.as_u32(), e_arr.length.as_u32()) {
330 (Some(l1), Some(l2)) => l1 == l2,
331 _ => true,
332 };
333
334 lengths_equal && a_arr.element_type().can_coerce_to(e_arr.element_type())
335 }
336
337 _ => self.eq_user(expected),
339 }
340 }
341
342 pub fn is_optional(&self) -> bool {
343 matches!(self, Self::Optional(_))
344 }
345
346 pub fn is_vector(&self) -> bool {
347 matches!(self, Self::Vector(_))
348 }
349
350 pub fn is_mapping(&self) -> bool {
351 matches!(self, Self::Mapping(_))
352 }
353
354 pub fn to_optional(&self) -> Type {
355 Type::Optional(OptionalType { inner: Box::new(self.clone()) })
356 }
357
358 pub fn is_empty(&self) -> bool {
359 match self {
360 Type::Unit => true,
361 Type::Array(array_type) => {
362 if let Some(length) = array_type.length.as_u32() {
363 length == 0
364 } else {
365 false
366 }
367 }
368 _ => false,
369 }
370 }
371}
372
373impl From<LiteralType> for Type {
374 fn from(value: LiteralType) -> Self {
375 match value {
376 LiteralType::Identifier => Type::Identifier,
377 LiteralType::Address => Type::Address,
378 LiteralType::Boolean => Type::Boolean,
379 LiteralType::Field => Type::Field,
380 LiteralType::Group => Type::Group,
381 LiteralType::U8 => Type::Integer(IntegerType::U8),
382 LiteralType::U16 => Type::Integer(IntegerType::U16),
383 LiteralType::U32 => Type::Integer(IntegerType::U32),
384 LiteralType::U64 => Type::Integer(IntegerType::U64),
385 LiteralType::U128 => Type::Integer(IntegerType::U128),
386 LiteralType::I8 => Type::Integer(IntegerType::I8),
387 LiteralType::I16 => Type::Integer(IntegerType::I16),
388 LiteralType::I32 => Type::Integer(IntegerType::I32),
389 LiteralType::I64 => Type::Integer(IntegerType::I64),
390 LiteralType::I128 => Type::Integer(IntegerType::I128),
391 LiteralType::Scalar => Type::Scalar,
392 LiteralType::Signature => Type::Signature,
393 LiteralType::String => Type::String,
394 }
395 }
396}
397
398impl fmt::Display for Type {
399 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
400 match *self {
401 Type::Address => write!(f, "address"),
402 Type::Identifier => write!(f, "identifier"),
403 Type::DynRecord => write!(f, "dyn record"),
404 Type::Array(ref array_type) => write!(f, "{array_type}"),
405 Type::Boolean => write!(f, "bool"),
406 Type::Field => write!(f, "field"),
407 Type::Future(ref future_type) => write!(f, "{future_type}"),
408 Type::Group => write!(f, "group"),
409 Type::Ident(ref variable) => write!(f, "{variable}"),
410 Type::Integer(ref integer_type) => write!(f, "{integer_type}"),
411 Type::Mapping(ref mapping_type) => write!(f, "{mapping_type}"),
412 Type::Optional(ref optional_type) => write!(f, "{optional_type}"),
413 Type::Scalar => write!(f, "scalar"),
414 Type::Signature => write!(f, "signature"),
415 Type::String => write!(f, "string"),
416 Type::Composite(ref composite_type) => write!(f, "{composite_type}"),
417 Type::Tuple(ref tuple) => write!(f, "{tuple}"),
418 Type::Vector(ref vector_type) => write!(f, "{vector_type}"),
419 Type::Numeric => write!(f, "numeric"),
420 Type::Unit => write!(f, "()"),
421 Type::Err => write!(f, "error"),
422 }
423 }
424}