mod check;
pub mod custom;
mod poly_func;
mod row_var;
mod serialize;
mod signature;
pub mod type_param;
pub mod type_row;
pub(crate) use row_var::MaybeRV;
pub use row_var::{NoRV, RowVariable};
pub use crate::ops::constant::{ConstTypeError, CustomCheckFailure};
use crate::types::type_param::check_type_arg;
use crate::utils::display_list_with_separator;
pub use check::SumTypeError;
pub use custom::CustomType;
pub(crate) use poly_func::PolyFuncTypeBase;
pub use poly_func::{PolyFuncType, PolyFuncTypeRV};
pub(crate) use signature::FuncTypeBase;
pub use signature::{FuncValueType, Signature};
use smol_str::SmolStr;
pub use type_param::TypeArg;
pub use type_row::{TypeRow, TypeRowRV};
use itertools::FoldWhile::{Continue, Done};
use itertools::{repeat_n, Itertools};
#[cfg(test)]
use proptest_derive::Arbitrary;
use serde::{Deserialize, Serialize};
use crate::extension::{ExtensionRegistry, SignatureError};
use crate::ops::AliasDecl;
use self::type_param::TypeParam;
use self::type_row::TypeRowBase;
pub type TypeName = SmolStr;
pub type TypeNameRef = str;
/// The kind of an edge: what (if anything) it carries.
///
/// `Const` and `Function` edges are the "static" kinds (see [`EdgeKind::is_static`]).
#[derive(Clone, PartialEq, Eq, Debug, serde::Serialize, serde::Deserialize)]
#[non_exhaustive]
pub enum EdgeKind {
/// A control-flow edge.
ControlFlow,
/// An edge carrying a value of the given type.
Value(Type),
/// A static edge carrying a constant of the given type.
Const(Type),
/// A static edge carrying a function with the given (polymorphic) signature.
Function(PolyFuncType),
/// An edge carrying no data; presumably an explicit ordering constraint — confirm.
StateOrder,
}
impl EdgeKind {
    /// Returns `true` if this edge carries a value whose type is not copyable.
    ///
    /// All non-`Value` edges are reported as non-linear.
    pub fn is_linear(&self) -> bool {
        match self {
            EdgeKind::Value(ty) => !ty.copyable(),
            _ => false,
        }
    }

    /// Returns `true` for the static edge kinds: constants and functions.
    pub fn is_static(&self) -> bool {
        match self {
            EdgeKind::Const(_) | EdgeKind::Function(_) => true,
            EdgeKind::ControlFlow | EdgeKind::Value(_) | EdgeKind::StateOrder => false,
        }
    }
}
/// A bound on the capabilities of a type.
///
/// [`TypeBound::Any`] is the least restrictive bound and contains
/// [`TypeBound::Copyable`] (see [`TypeBound::contains`]). `Any` is the default.
#[derive(
Copy, Default, Clone, PartialEq, Eq, Hash, Debug, derive_more::Display, Serialize, Deserialize,
)]
#[cfg_attr(test, derive(Arbitrary))]
pub enum TypeBound {
/// Types that can be copied. Serialized as `"C"`; `"E"` is also accepted when
/// deserializing.
#[serde(rename = "C", alias = "E")] Copyable,
/// No constraint beyond being a type. Serialized as `"A"`.
#[serde(rename = "A")]
#[default]
Any,
}
impl TypeBound {
pub fn union(self, other: Self) -> Self {
if self.contains(other) {
self
} else {
debug_assert!(other.contains(self));
other
}
}
pub const fn contains(&self, other: TypeBound) -> bool {
use TypeBound::*;
matches!((self, other), (Any, _) | (_, Copyable))
}
}
/// Computes the least bound that contains every bound yielded by `tags`.
///
/// An empty iterator gives [`TypeBound::Copyable`]. Iteration stops at the first
/// [`TypeBound::Any`], since nothing can raise the result further.
pub(crate) fn least_upper_bound(mut tags: impl Iterator<Item = TypeBound>) -> TypeBound {
    let outcome = tags.fold_while(TypeBound::Copyable, |so_far, next| match (so_far, next) {
        // `Any` absorbs everything, so there is no need to look further.
        (TypeBound::Any, _) | (_, TypeBound::Any) => Done(TypeBound::Any),
        _ => Continue(so_far.union(next)),
    });
    outcome.into_inner()
}
/// A sum type: a choice between variants, each of which is a row of types.
///
/// Serialized with an internal tag field `"s"` naming the representation used.
#[derive(Clone, PartialEq, Debug, Eq, Hash, Serialize, Deserialize)]
#[serde(tag = "s")]
#[non_exhaustive]
pub enum SumType {
/// Compact form: `size` variants, every one of them an empty row.
#[allow(missing_docs)]
Unit { size: u8 },
/// General form: one row of types per variant.
#[allow(missing_docs)]
General { rows: Vec<TypeRowRV> },
}
impl std::fmt::Display for SumType {
    /// Writes the variants separated by `+`. A sum with no variants is written as `⊥`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if self.num_variants() == 0 {
            return write!(f, "⊥");
        }
        match self {
            SumType::General { rows } => display_list_with_separator(rows.iter(), f, "+"),
            SumType::Unit { size } => {
                // Each variant of a unary sum is an empty row.
                let empty_rows = repeat_n("[]", usize::from(*size));
                display_list_with_separator(empty_rows, f, "+")
            }
        }
    }
}
impl SumType {
    /// Builds a sum type from its variants.
    ///
    /// If every variant is an empty row and there are at most `u8::MAX` of them, the
    /// compact [`SumType::Unit`] form is used. Otherwise the result is [`SumType::General`].
    pub fn new<V>(variants: impl IntoIterator<Item = V>) -> Self
    where
        V: Into<TypeRowRV>,
    {
        let rows: Vec<TypeRowRV> = variants.into_iter().map(Into::into).collect_vec();
        match u8::try_from(rows.len()) {
            Ok(size) if rows.iter().all(TypeRowRV::is_empty) => Self::new_unary(size),
            _ => Self::General { rows },
        }
    }

    /// Creates a sum of `size` empty variants.
    pub const fn new_unary(size: u8) -> Self {
        Self::Unit { size }
    }

    /// Creates a sum with a single variant holding `types`.
    pub fn new_tuple(types: impl Into<TypeRow>) -> Self {
        let only_row: TypeRow = types.into();
        Self::new([only_row])
    }

    /// Returns the row of variant `tag`, or `None` if `tag` is out of range.
    pub fn get_variant(&self, tag: usize) -> Option<&TypeRowRV> {
        match self {
            SumType::General { rows } => rows.get(tag),
            SumType::Unit { size } => {
                (tag < usize::from(*size)).then_some(TypeRV::EMPTY_TYPEROW_REF)
            }
        }
    }

    /// Returns how many variants this sum has.
    pub fn num_variants(&self) -> usize {
        match self {
            SumType::General { rows } => rows.len(),
            SumType::Unit { size } => usize::from(*size),
        }
    }

    /// If this sum has exactly one variant, returns that variant's row.
    pub fn as_tuple(&self) -> Option<&TypeRowRV> {
        match self {
            SumType::Unit { size: 1 } => Some(TypeRV::EMPTY_TYPEROW_REF),
            SumType::Unit { .. } => None,
            SumType::General { rows } => match rows.as_slice() {
                [only] => Some(only),
                _ => None,
            },
        }
    }
}
impl<RV: MaybeRV> From<SumType> for TypeBase<RV> {
    /// Wraps a sum type using the public constructors.
    ///
    /// A `General` sum goes through `new_sum`, which re-normalises it via
    /// [`SumType::new`] and computes its bound.
    fn from(sum: SumType) -> Self {
        match sum {
            SumType::General { rows } => Self::new_sum(rows),
            SumType::Unit { size } => Self::new_unit_sum(size),
        }
    }
}
/// The underlying representation of a type.
///
/// `RV` determines what a [`TypeEnum::RowVar`] holds; see [`NoRV`] and [`RowVariable`].
#[derive(Clone, Debug, Eq, Hash, derive_more::Display)]
pub enum TypeEnum<RV: MaybeRV> {
/// An opaque type supplied by an extension.
#[allow(missing_docs)]
Extension(CustomType),
/// A type alias; displayed as `Alias(<name>)`.
#[allow(missing_docs)]
#[display("Alias({})", _0.name())]
Alias(AliasDecl),
/// A function value type.
#[allow(missing_docs)]
#[display("Function({_0})")]
Function(Box<FuncValueType>),
/// Use of a type variable: its index and its declared bound.
/// Note that only the index is shown by `Display`.
#[allow(missing_docs)]
#[display("Variable({_0})")]
Variable(usize, TypeBound),
/// Use of a row variable. For `RV = NoRV` this presumably cannot be constructed —
/// confirm in `row_var`.
#[display("RowVar({_0})")]
RowVar(RV),
/// A sum type.
#[allow(missing_docs)]
Sum(SumType),
}
impl<RV: MaybeRV> TypeEnum<RV> {
    /// Computes the least [`TypeBound`] this type satisfies.
    ///
    /// For a general sum, this is the least upper bound of all the types in all its rows.
    fn least_upper_bound(&self) -> TypeBound {
        match self {
            TypeEnum::Function(_) | TypeEnum::Sum(SumType::Unit { .. }) => TypeBound::Copyable,
            TypeEnum::Extension(custom) => custom.bound(),
            TypeEnum::Alias(alias) => alias.bound,
            TypeEnum::Variable(_, bound) => *bound,
            TypeEnum::RowVar(var) => var.bound(),
            TypeEnum::Sum(SumType::General { rows }) => {
                let element_bounds = rows
                    .iter()
                    .flat_map(TypeRowRV::iter)
                    .map(TypeRV::least_upper_bound);
                least_upper_bound(element_bounds)
            }
        }
    }
}
/// A type paired with the least [`TypeBound`] it satisfies.
///
/// The bound is computed by the constructors and then cached in the second field.
/// `RV` determines whether the type may be a row variable.
/// Serialization goes through `serialize::SerSimpleType`.
#[derive(Clone, Debug, Eq, Hash, derive_more::Display, serde::Serialize, serde::Deserialize)]
#[display("{_0}")]
#[serde(
into = "serialize::SerSimpleType",
try_from = "serialize::SerSimpleType"
)]
pub struct TypeBase<RV: MaybeRV>(TypeEnum<RV>, TypeBound);
/// A type that is not a row variable (its `RowVar` case is parameterised by [`NoRV`]).
pub type Type = TypeBase<NoRV>;
/// A type that may be a row variable.
pub type TypeRV = TypeBase<RowVariable>;
impl<RV1: MaybeRV, RV2: MaybeRV> PartialEq<TypeEnum<RV1>> for TypeEnum<RV2> {
    /// Structural equality across row-variable flavours.
    ///
    /// The variants must match. Row variables are compared through `as_rv`, and every
    /// other payload is compared directly.
    fn eq(&self, other: &TypeEnum<RV1>) -> bool {
        match (self, other) {
            (TypeEnum::Sum(lhs), TypeEnum::Sum(rhs)) => lhs == rhs,
            (TypeEnum::RowVar(lhs), TypeEnum::RowVar(rhs)) => lhs.as_rv() == rhs.as_rv(),
            (TypeEnum::Variable(lhs_idx, lhs_bound), TypeEnum::Variable(rhs_idx, rhs_bound)) => {
                lhs_idx == rhs_idx && lhs_bound == rhs_bound
            }
            (TypeEnum::Function(lhs), TypeEnum::Function(rhs)) => lhs == rhs,
            (TypeEnum::Alias(lhs), TypeEnum::Alias(rhs)) => lhs == rhs,
            (TypeEnum::Extension(lhs), TypeEnum::Extension(rhs)) => lhs == rhs,
            // Values of different variants are never equal.
            _ => false,
        }
    }
}
impl<RV1: MaybeRV, RV2: MaybeRV> PartialEq<TypeBase<RV1>> for TypeBase<RV2> {
fn eq(&self, other: &TypeBase<RV1>) -> bool {
self.0 == other.0 && self.1 == other.1
}
}
impl<RV: MaybeRV> TypeBase<RV> {
pub const EMPTY_TYPEROW: TypeRowBase<RV> = TypeRowBase::<RV>::new();
pub const UNIT: Self = Self(
TypeEnum::Sum(SumType::Unit { size: 1 }),
TypeBound::Copyable,
);
const EMPTY_TYPEROW_REF: &'static TypeRowBase<RV> = &Self::EMPTY_TYPEROW;
pub fn new_function(fun_ty: impl Into<FuncValueType>) -> Self {
Self::new(TypeEnum::Function(Box::new(fun_ty.into())))
}
#[inline(always)]
pub fn new_tuple(types: impl Into<TypeRowRV>) -> Self {
let row = types.into();
match row.len() {
0 => Self::UNIT,
_ => Self::new_sum([row]),
}
}
#[inline(always)]
pub fn new_sum<R>(variants: impl IntoIterator<Item = R>) -> Self
where
R: Into<TypeRowRV>,
{
Self::new(TypeEnum::Sum(SumType::new(variants)))
}
pub const fn new_extension(opaque: CustomType) -> Self {
let bound = opaque.bound();
TypeBase(TypeEnum::Extension(opaque), bound)
}
pub fn new_alias(alias: AliasDecl) -> Self {
Self::new(TypeEnum::Alias(alias))
}
fn new(type_e: TypeEnum<RV>) -> Self {
let bound = type_e.least_upper_bound();
Self(type_e, bound)
}
pub const fn new_unit_sum(size: u8) -> Self {
Self(TypeEnum::Sum(SumType::new_unary(size)), TypeBound::Copyable)
}
pub const fn new_var_use(idx: usize, bound: TypeBound) -> Self {
Self(TypeEnum::Variable(idx, bound), bound)
}
#[inline(always)]
pub const fn least_upper_bound(&self) -> TypeBound {
self.1
}
#[inline(always)]
pub const fn as_type_enum(&self) -> &TypeEnum<RV> {
&self.0
}
pub const fn copyable(&self) -> bool {
TypeBound::Copyable.contains(self.least_upper_bound())
}
pub(crate) fn validate(
&self,
extension_registry: &ExtensionRegistry,
var_decls: &[TypeParam],
) -> Result<(), SignatureError> {
match &self.0 {
TypeEnum::Sum(SumType::General { rows }) => rows
.iter()
.try_for_each(|row| row.validate(extension_registry, var_decls)),
TypeEnum::Sum(SumType::Unit { .. }) => Ok(()), TypeEnum::Alias(_) => Ok(()),
TypeEnum::Extension(custy) => custy.validate(extension_registry, var_decls),
TypeEnum::Function(ft) => ft.validate(extension_registry, var_decls),
TypeEnum::Variable(idx, bound) => check_typevar_decl(var_decls, *idx, &(*bound).into()),
TypeEnum::RowVar(rv) => rv.validate(var_decls),
}
}
fn substitute(&self, t: &Substitution) -> Vec<Self> {
match &self.0 {
TypeEnum::RowVar(rv) => rv.substitute(t),
TypeEnum::Alias(_) | TypeEnum::Sum(SumType::Unit { .. }) => vec![self.clone()],
TypeEnum::Variable(idx, bound) => {
let TypeArg::Type { ty } = t.apply_var(*idx, &((*bound).into())) else {
panic!("Variable was not a type - try validate() first")
};
vec![ty.into_()]
}
TypeEnum::Extension(cty) => vec![TypeBase::new_extension(cty.substitute(t))],
TypeEnum::Function(bf) => vec![TypeBase::new_function(bf.substitute(t))],
TypeEnum::Sum(SumType::General { rows }) => {
vec![TypeBase::new_sum(rows.iter().map(|r| r.substitute(t)))]
}
}
}
}
impl Type {
    /// Applies `s` to a [`Type`] and returns the single resulting type.
    ///
    /// `Type` is `TypeBase<NoRV>`, so the `RowVar` case cannot arise here. That case is the
    /// only one in `substitute` that may return something other than one element.
    ///
    /// # Panics
    ///
    /// If `substitute` panics, or if it does not produce exactly one type. The latter
    /// would indicate a broken invariant.
    fn substitute1(&self, s: &Substitution) -> Self {
        let [single]: [Self; 1] = self
            .substitute(s)
            .try_into()
            .expect("a Type has no row variables, so substitution yields exactly one Type");
        single
    }
}
impl TypeRV {
    /// Returns `true` if this type is a use of a row variable.
    pub fn is_row_var(&self) -> bool {
        matches!(self.as_type_enum(), TypeEnum::RowVar(_))
    }

    /// Creates a use of row variable `idx`, declared with `bound`.
    pub const fn new_row_var_use(idx: usize, bound: TypeBound) -> Self {
        let var = RowVariable(idx, bound);
        Self(TypeEnum::RowVar(var), bound)
    }
}
impl<RV: MaybeRV> TypeBase<RV> {
pub fn try_into_type(self) -> Result<Type, RowVariable> {
Ok(TypeBase(
match self.0 {
TypeEnum::Extension(e) => TypeEnum::Extension(e),
TypeEnum::Alias(a) => TypeEnum::Alias(a),
TypeEnum::Function(f) => TypeEnum::Function(f),
TypeEnum::Variable(idx, bound) => TypeEnum::Variable(idx, bound),
TypeEnum::RowVar(rv) => Err(rv.as_rv().clone())?,
TypeEnum::Sum(s) => TypeEnum::Sum(s),
},
self.1,
))
}
}
impl TryFrom<TypeRV> for Type {
type Error = RowVariable;
fn try_from(value: TypeRV) -> Result<Self, RowVariable> {
value.try_into_type()
}
}
impl<RV1: MaybeRV> TypeBase<RV1> {
fn into_<RV2: MaybeRV>(self) -> TypeBase<RV2>
where
RV1: Into<RV2>,
{
TypeBase(
match self.0 {
TypeEnum::Extension(e) => TypeEnum::Extension(e),
TypeEnum::Alias(a) => TypeEnum::Alias(a),
TypeEnum::Function(f) => TypeEnum::Function(f),
TypeEnum::Variable(idx, bound) => TypeEnum::Variable(idx, bound),
TypeEnum::RowVar(rv) => TypeEnum::RowVar(rv.into()),
TypeEnum::Sum(s) => TypeEnum::Sum(s),
},
self.1,
)
}
}
impl From<Type> for TypeRV {
fn from(value: Type) -> Self {
value.into_()
}
}
/// Arguments to substitute for variables (looked up by variable index), together with
/// an [`ExtensionRegistry`].
pub(crate) struct Substitution<'a>(&'a [TypeArg], &'a ExtensionRegistry);
impl<'a> Substitution<'a> {
    /// Returns a clone of the argument for type variable `idx`.
    ///
    /// # Panics
    ///
    /// If there is no argument at `idx`. In debug builds, also if the argument does not
    /// satisfy `decl`.
    pub(crate) fn apply_var(&self, idx: usize, decl: &TypeParam) -> TypeArg {
        let arg = self.0.get(idx).expect("Undeclared type variable - call validate() ?");
        debug_assert_eq!(check_type_arg(arg, decl), Ok(()));
        arg.clone()
    }

    /// Expands row variable `idx` into the types it stands for.
    ///
    /// A list argument yields one entry per element, each being either a type or a
    /// row-variable use. A lone type argument is accepted only if it is itself a row
    /// variable.
    ///
    /// # Panics
    ///
    /// If there is no argument at `idx`, or if the argument has any other shape.
    fn apply_rowvar(&self, idx: usize, bound: TypeBound) -> Vec<TypeRV> {
        let arg = self.0.get(idx).expect("Undeclared type variable - call validate() ?");
        debug_assert!(check_type_arg(arg, &TypeParam::new_list(bound)).is_ok());
        match arg {
            TypeArg::Sequence { elems } => elems
                .iter()
                .map(|elem| -> TypeRV {
                    match elem {
                        TypeArg::Type { ty } => ty.clone().into(),
                        TypeArg::Variable { v } => match v.bound_if_row_var() {
                            Some(row_bound) => TypeRV::new_row_var_use(v.index(), row_bound),
                            None => panic!("Not a list of types - call validate() ?"),
                        },
                        _ => panic!("Not a list of types - call validate() ?"),
                    }
                })
                .collect(),
            // NOTE(review): if `TypeArg::Type` holds a `Type` (= `TypeBase<NoRV>`), this
            // guard can never match — confirm against the `TypeArg` definition.
            TypeArg::Type { ty } if matches!(ty.0, TypeEnum::RowVar(_)) => {
                vec![ty.clone().into()]
            }
            _ => panic!("Not a type or list of types - call validate() ?"),
        }
    }

    /// Returns the extension registry this substitution was built with.
    fn extension_registry(&self) -> &ExtensionRegistry {
        self.1
    }
}
/// Checks that variable `idx` is declared in `decls` with a declaration equal to
/// `cached_decl` (the declaration recorded at the use, e.g. from a variable's bound).
///
/// # Errors
///
/// - [`SignatureError::FreeTypeVar`] if `idx` is out of range for `decls`.
/// - [`SignatureError::TypeVarDoesNotMatchDeclaration`] if the declarations differ.
pub(crate) fn check_typevar_decl(
    decls: &[TypeParam],
    idx: usize,
    cached_decl: &TypeParam,
) -> Result<(), SignatureError> {
    let Some(actual) = decls.get(idx) else {
        return Err(SignatureError::FreeTypeVar {
            idx,
            num_decls: decls.len(),
        });
    };
    if actual == cached_decl {
        return Ok(());
    }
    Err(SignatureError::TypeVarDoesNotMatchDeclaration {
        cached: cached_decl.clone(),
        actual: actual.clone(),
    })
}
#[cfg(test)]
pub(crate) mod test {
use super::*;
use crate::extension::prelude::USIZE_T;
use crate::type_row;
// Builds a tuple of assorted kinds of type (prelude, function, extension, alias)
// and checks its `Display` rendering.
#[test]
fn construct() {
let t: Type = Type::new_tuple(vec![
USIZE_T,
Type::new_function(Signature::new_endo(vec![])),
Type::new_extension(CustomType::new(
"my_custom",
[],
"my_extension".try_into().unwrap(),
TypeBound::Copyable,
)),
Type::new_alias(AliasDecl::new("my_alias", TypeBound::Copyable)),
]);
assert_eq!(
&t.to_string(),
"[usize, Function([[]][]), my_custom, Alias(my_alias)]"
);
}
// A two-variant sum of empty rows must compare equal however it is built: via
// `new_sum`, via `new_unit_sum` (as a `TypeRV`, exercising the cross-flavour
// `PartialEq`), or from a `SumType::Unit` directly.
#[rstest::rstest]
fn sum_construct() {
let pred1 = Type::new_sum([type_row![], type_row![]]);
let pred2 = TypeRV::new_unit_sum(2);
assert_eq!(pred1, pred2);
let pred_direct = SumType::Unit { size: 2 };
assert_eq!(pred1, Type::from(pred_direct));
}
// `proptest` strategies for generating arbitrary sum types and types.
mod proptest {
use crate::proptest::RecursionDepth;
use super::{AliasDecl, MaybeRV, TypeBase, TypeBound, TypeEnum};
use crate::types::{CustomType, FuncValueType, SumType, TypeRowRV};
use ::proptest::prelude::*;
impl Arbitrary for super::SumType {
type Parameters = RecursionDepth;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy {
use proptest::collection::vec;
// At the recursion limit, only generate unary sums (no nested types).
// Otherwise generate between zero and two arbitrary rows.
if depth.leaf() {
any::<u8>().prop_map(Self::new_unary).boxed()
} else {
vec(any_with::<TypeRowRV>(depth), 0..3)
.prop_map(SumType::new)
.boxed()
}
}
}
impl<RV: MaybeRV> Arbitrary for TypeBase<RV> {
type Parameters = RecursionDepth;
type Strategy = BoxedStrategy<Self>;
fn arbitrary_with(depth: Self::Parameters) -> Self::Strategy {
// Go one level deeper for nested components.
let depth = depth.descend();
// Each constructor is weighted 1. Row variables get `RV::weight()`, presumably
// zero when `RV` cannot represent row variables — confirm in `row_var`.
prop_oneof![
1 => any::<AliasDecl>().prop_map(TypeBase::new_alias),
1 => any_with::<CustomType>(depth.into()).prop_map(TypeBase::new_extension),
1 => any_with::<FuncValueType>(depth).prop_map(TypeBase::new_function),
1 => any_with::<SumType>(depth).prop_map(TypeBase::from),
1 => (any::<usize>(), any::<TypeBound>()).prop_map(|(i,b)| TypeBase::new_var_use(i,b)),
RV::weight() => RV::arb().prop_map(|rv| TypeBase::new(TypeEnum::RowVar(rv)))
]
.boxed()
}
}
}
}