use crate::{
ArrayType,
CompositeType,
FutureType,
Identifier,
IntegerType,
Location,
MappingType,
OptionalType,
Path,
ProgramId,
TupleType,
VectorType,
};
use itertools::Itertools;
use serde::{Deserialize, Serialize};
use snarkvm::prelude::{
LiteralType,
Network,
PlaintextType,
PlaintextType::{Array, ExternalStruct, Literal, Struct},
};
use std::fmt;
/// A type in the Leo AST.
///
/// Leaf variants such as `Address`, `Boolean`, `Field`, `Group`, `Scalar`, `Signature`,
/// `String` and `Identifier` correspond to snarkVM literal types (see the
/// `From<LiteralType>` impl in this file). The remaining variants wrap a dedicated type node.
/// The derived `Default` is [`Type::Err`].
#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum Type {
    /// The `address` type.
    Address,
    /// An array type (element type plus length).
    Array(ArrayType),
    /// The `bool` type.
    Boolean,
    /// A composite (struct-like) type, identified by a path and optional const arguments.
    Composite(CompositeType),
    /// The `field` type.
    Field,
    /// A future type.
    Future(FutureType),
    /// The `group` type.
    Group,
    /// The `identifier` type (snarkVM `LiteralType::Identifier`).
    Identifier,
    /// The `dyn record` type.
    DynRecord,
    /// A type referred to by a bare identifier; `eq_user` compares these by name.
    Ident(Identifier),
    /// One of the signed or unsigned integer types.
    Integer(IntegerType),
    /// A mapping type with a key type and a value type.
    Mapping(MappingType),
    /// An optional type wrapping an inner type.
    Optional(OptionalType),
    /// The `scalar` type.
    Scalar,
    /// The `signature` type.
    Signature,
    /// The `string` type.
    String,
    /// A tuple type.
    Tuple(TupleType),
    /// A vector type with an element type.
    Vector(VectorType),
    /// Displayed as `numeric`. NOTE(review): appears to be a placeholder for a not-yet-resolved
    /// numeric type — confirm. It never compares equal under `eq_user`/`eq_flat_relaxed`.
    Numeric,
    /// The unit type `()`.
    Unit,
    /// The error type; also the `Default`. `eq_user` treats it as compatible with every type.
    #[default]
    Err,
}
impl Type {
pub fn eq_user(&self, other: &Type) -> bool {
match (self, other) {
(Type::Err, _)
| (_, Type::Err)
| (Type::Address, Type::Address)
| (Type::Boolean, Type::Boolean)
| (Type::Field, Type::Field)
| (Type::Group, Type::Group)
| (Type::Scalar, Type::Scalar)
| (Type::Signature, Type::Signature)
| (Type::String, Type::String)
| (Type::Identifier, Type::Identifier)
| (Type::DynRecord, Type::DynRecord)
| (Type::Unit, Type::Unit) => true,
(Type::Array(left), Type::Array(right)) => {
(match (left.length.as_u32(), right.length.as_u32()) {
(Some(l1), Some(l2)) => l1 == l2,
_ => {
true
}
}) && left.element_type().eq_user(right.element_type())
}
(Type::Ident(left), Type::Ident(right)) => left.name == right.name,
(Type::Integer(left), Type::Integer(right)) => left == right,
(Type::Mapping(left), Type::Mapping(right)) => {
left.key.eq_user(&right.key) && left.value.eq_user(&right.value)
}
(Type::Optional(left), Type::Optional(right)) => left.inner.eq_user(&right.inner),
(Type::Tuple(left), Type::Tuple(right)) if left.length() == right.length() => left
.elements()
.iter()
.zip_eq(right.elements().iter())
.all(|(left_type, right_type)| left_type.eq_user(right_type)),
(Type::Vector(left), Type::Vector(right)) => left.element_type.eq_user(&right.element_type),
(Type::Composite(left), Type::Composite(right)) => {
if !left.const_arguments.is_empty() || !right.const_arguments.is_empty() {
return true;
}
match (&left.path.try_global_location(), &right.path.try_global_location()) {
(Some(l), Some(r)) => l == r,
_ => false,
}
}
(Type::Future(left), Type::Future(right)) if !left.is_explicit || !right.is_explicit => true,
(Type::Future(left), Type::Future(right)) if left.inputs.len() == right.inputs.len() => left
.inputs()
.iter()
.zip_eq(right.inputs().iter())
.all(|(left_type, right_type)| left_type.eq_user(right_type)),
_ => false,
}
}
pub fn eq_flat_relaxed(&self, other: &Self) -> bool {
match (self, other) {
(Type::Address, Type::Address)
| (Type::Boolean, Type::Boolean)
| (Type::Field, Type::Field)
| (Type::Group, Type::Group)
| (Type::Scalar, Type::Scalar)
| (Type::Signature, Type::Signature)
| (Type::String, Type::String)
| (Type::Identifier, Type::Identifier)
| (Type::DynRecord, Type::DynRecord)
| (Type::Unit, Type::Unit) => true,
(Type::Array(left), Type::Array(right)) => {
(match (left.length.as_u32(), right.length.as_u32()) {
(Some(l1), Some(l2)) => l1 == l2,
_ => {
true
}
}) && left.element_type().eq_flat_relaxed(right.element_type())
}
(Type::Ident(left), Type::Ident(right)) => left.matches(right),
(Type::Integer(left), Type::Integer(right)) => left.eq(right),
(Type::Mapping(left), Type::Mapping(right)) => {
left.key.eq_flat_relaxed(&right.key) && left.value.eq_flat_relaxed(&right.value)
}
(Type::Optional(left), Type::Optional(right)) => left.inner.eq_flat_relaxed(&right.inner),
(Type::Tuple(left), Type::Tuple(right)) if left.length() == right.length() => left
.elements()
.iter()
.zip_eq(right.elements().iter())
.all(|(left_type, right_type)| left_type.eq_flat_relaxed(right_type)),
(Type::Vector(left), Type::Vector(right)) => left.element_type.eq_flat_relaxed(&right.element_type),
(Type::Composite(left), Type::Composite(right)) => {
if !left.const_arguments.is_empty() || !right.const_arguments.is_empty() {
return true;
}
match (&left.path.try_global_location(), &right.path.try_global_location()) {
(Some(l), Some(r)) => l.path == r.path,
_ => false,
}
}
(Type::Future(left), Type::Future(right)) if !left.is_explicit || !right.is_explicit => true,
(Type::Future(left), Type::Future(right)) if left.inputs.len() == right.inputs.len() => left
.inputs()
.iter()
.zip_eq(right.inputs().iter())
.all(|(left_type, right_type)| left_type.eq_flat_relaxed(right_type)),
_ => false,
}
}
pub fn from_snarkvm<N: Network>(t: &PlaintextType<N>, program_id: ProgramId) -> Self {
match t {
Literal(lit) => (*lit).into(),
Struct(s) => Type::Composite(CompositeType {
path: {
let ident = Identifier::from(s);
Path::from(ident).to_global(Location::new(program_id.as_symbol(), vec![ident.name]))
},
const_arguments: Vec::new(),
}),
ExternalStruct(l) => Type::Composite(CompositeType {
path: {
let external_program = ProgramId::from(l.program_id());
let name = Identifier::from(l.resource());
Path::from(name)
.with_user_program(external_program)
.to_global(Location::new(external_program.as_symbol(), vec![name.name]))
},
const_arguments: Vec::new(),
}),
Array(array) => Type::Array(ArrayType::from_snarkvm(array, program_id)),
}
}
pub fn to_snarkvm<N: Network>(&self) -> anyhow::Result<PlaintextType<N>> {
match self {
Type::Address => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Address)),
Type::Boolean => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Boolean)),
Type::Field => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Field)),
Type::Group => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Group)),
Type::Integer(int_type) => match int_type {
IntegerType::U8 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U8)),
IntegerType::U16 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U16)),
IntegerType::U32 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U32)),
IntegerType::U64 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U64)),
IntegerType::U128 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::U128)),
IntegerType::I8 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I8)),
IntegerType::I16 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I16)),
IntegerType::I32 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I32)),
IntegerType::I64 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I64)),
IntegerType::I128 => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::I128)),
},
Type::Scalar => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Scalar)),
Type::Signature => Ok(PlaintextType::Literal(snarkvm::prelude::LiteralType::Signature)),
Type::Array(array_type) => Ok(PlaintextType::<N>::Array(array_type.to_snarkvm()?)),
_ => anyhow::bail!("Converting from type {self} to snarkVM type is not supported"),
}
}
pub fn size_in_bits<N: Network, F0, F1>(
&self,
is_raw: bool,
get_structs: F0,
get_external_structs: F1,
) -> anyhow::Result<usize>
where
F0: Fn(&snarkvm::prelude::Identifier<N>) -> anyhow::Result<snarkvm::prelude::StructType<N>>,
F1: Fn(&snarkvm::prelude::Locator<N>) -> anyhow::Result<snarkvm::prelude::StructType<N>>,
{
match is_raw {
false => self.to_snarkvm::<N>()?.size_in_bits(&get_structs, &get_external_structs),
true => self.to_snarkvm::<N>()?.size_in_bits_raw(&get_structs, &get_external_structs),
}
}
pub fn can_coerce_to(&self, expected: &Type) -> bool {
use Type::*;
match (self, expected) {
(Optional(actual_opt), Optional(expected_opt)) => actual_opt.inner.can_coerce_to(&expected_opt.inner),
(a, Optional(opt)) => a.can_coerce_to(&opt.inner),
(Array(a_arr), Array(e_arr)) => {
let lengths_equal = match (a_arr.length.as_u32(), e_arr.length.as_u32()) {
(Some(l1), Some(l2)) => l1 == l2,
_ => true,
};
lengths_equal && a_arr.element_type().can_coerce_to(e_arr.element_type())
}
_ => self.eq_user(expected),
}
}
pub fn is_optional(&self) -> bool {
matches!(self, Self::Optional(_))
}
pub fn is_vector(&self) -> bool {
matches!(self, Self::Vector(_))
}
pub fn is_mapping(&self) -> bool {
matches!(self, Self::Mapping(_))
}
pub fn to_optional(&self) -> Type {
Type::Optional(OptionalType { inner: Box::new(self.clone()) })
}
pub fn is_empty(&self) -> bool {
match self {
Type::Unit => true,
Type::Array(array_type) => {
if let Some(length) = array_type.length.as_u32() {
length == 0
} else {
false
}
}
_ => false,
}
}
}
impl From<LiteralType> for Type {
    /// Maps a snarkVM literal type onto the matching Leo [`Type`].
    ///
    /// Every integer literal becomes a [`Type::Integer`] of the same width and signedness.
    /// Every other literal maps to the Leo variant of the same name.
    fn from(literal: LiteralType) -> Self {
        let integer_kind = match literal {
            LiteralType::Address => return Type::Address,
            LiteralType::Boolean => return Type::Boolean,
            LiteralType::Field => return Type::Field,
            LiteralType::Group => return Type::Group,
            LiteralType::Identifier => return Type::Identifier,
            LiteralType::Scalar => return Type::Scalar,
            LiteralType::Signature => return Type::Signature,
            LiteralType::String => return Type::String,
            LiteralType::U8 => IntegerType::U8,
            LiteralType::U16 => IntegerType::U16,
            LiteralType::U32 => IntegerType::U32,
            LiteralType::U64 => IntegerType::U64,
            LiteralType::U128 => IntegerType::U128,
            LiteralType::I8 => IntegerType::I8,
            LiteralType::I16 => IntegerType::I16,
            LiteralType::I32 => IntegerType::I32,
            LiteralType::I64 => IntegerType::I64,
            LiteralType::I128 => IntegerType::I128,
        };
        Type::Integer(integer_kind)
    }
}
impl fmt::Display for Type {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match *self {
Type::Address => write!(f, "address"),
Type::Identifier => write!(f, "identifier"),
Type::DynRecord => write!(f, "dyn record"),
Type::Array(ref array_type) => write!(f, "{array_type}"),
Type::Boolean => write!(f, "bool"),
Type::Field => write!(f, "field"),
Type::Future(ref future_type) => write!(f, "{future_type}"),
Type::Group => write!(f, "group"),
Type::Ident(ref variable) => write!(f, "{variable}"),
Type::Integer(ref integer_type) => write!(f, "{integer_type}"),
Type::Mapping(ref mapping_type) => write!(f, "{mapping_type}"),
Type::Optional(ref optional_type) => write!(f, "{optional_type}"),
Type::Scalar => write!(f, "scalar"),
Type::Signature => write!(f, "signature"),
Type::String => write!(f, "string"),
Type::Composite(ref composite_type) => write!(f, "{composite_type}"),
Type::Tuple(ref tuple) => write!(f, "{tuple}"),
Type::Vector(ref vector_type) => write!(f, "{vector_type}"),
Type::Numeric => write!(f, "numeric"),
Type::Unit => write!(f, "()"),
Type::Err => write!(f, "error"),
}
}
}