use self::union_bound::{PointerLike, UbElement, WithGhostToken};
use crate::dag::{DagLike, NoSharing};
use crate::Tmr;
use std::fmt;
use std::sync::Arc;
pub mod arrow;
mod context;
mod final_data;
mod incomplete;
mod precomputed;
mod union_bound;
mod variable;
pub use context::{BoundRef, Context};
pub use final_data::{CompleteBound, Final};
pub use incomplete::Incomplete;
/// Errors produced while inferring or finalizing types.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum Error {
    /// A new bound could not be applied to a type that already had a bound.
    Bind {
        /// The bound the type already carried.
        existing_bound: Arc<Incomplete>,
        /// The bound that was being applied.
        new_bound: Arc<Incomplete>,
        /// Short human-readable explanation, appended to the `Display` message.
        hint: &'static str,
    },
    /// Two complete types that were required to be equal were not.
    CompleteTypeMismatch {
        /// The first of the two unequal types.
        type1: Arc<Final>,
        /// The second of the two unequal types.
        type2: Arc<Final>,
        /// Short human-readable explanation, appended to the `Display` message.
        hint: &'static str,
    },
    /// A type would be infinitely sized; returned by [`Type::finalize`].
    OccursCheck { infinite_bound: Arc<Incomplete> },
    /// Two nodes built in different type-inference contexts were combined.
    InferenceContextMismatch,
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match self {
Error::Bind {
ref existing_bound,
ref new_bound,
hint,
} => {
write!(
f,
"failed to apply bound `{}` to existing bound `{}`: {}",
new_bound, existing_bound, hint,
)
}
Error::CompleteTypeMismatch {
ref type1,
ref type2,
hint,
} => {
write!(
f,
"attempted to unify unequal types `{}` and `{}`: {}",
type1, type2, hint,
)
}
Error::OccursCheck { infinite_bound } => {
write!(f, "infinitely-sized type {}", infinite_bound,)
}
Error::InferenceContextMismatch => {
f.write_str("attempted to combine two nodes with different type inference contexts")
}
}
}
}
impl std::error::Error for Error {}
/// The bound on a type variable, as stored in and looked up from a [`Context`].
#[derive(Clone)]
enum Bound<'brand> {
    /// A free type variable. The string is its name, which the `Debug` and
    /// `Display` impls of [`Type`] print. [`Type::finalize`] turns free
    /// variables into the unit type.
    Free(String),
    /// A fully-determined type.
    Complete(Arc<Final>),
    /// The sum of two types, displayed as `a + b`.
    Sum(TypeInner<'brand>, TypeInner<'brand>),
    /// The product of two types, displayed as `a × b`.
    Product(TypeInner<'brand>, TypeInner<'brand>),
}
impl Bound<'_> {
    /// Returns a copy of this bound, produced by its `Clone` impl.
    ///
    /// The name matches the `shallow_clone` methods on `Type` and `TypeInner`,
    /// which also just delegate to `Clone`.
    pub fn shallow_clone(&self) -> Self {
        Clone::clone(self)
    }
}
/// A Simplicity type that may still be incomplete (i.e. contain free variables).
///
/// The type's structure lives in the type-inference [`Context`] `ctx`. `inner`
/// holds the union-bound handle from which the current root bound is looked up
/// (via `Context::get_root_ref`).
#[derive(Clone)]
pub struct Type<'brand> {
    /// The inference context that owns this type's bounds.
    ctx: Context<'brand>,
    /// Handle to this type's bound within `ctx`.
    inner: TypeInner<'brand>,
}
/// The context-free part of a [`Type`]: just its union-bound element.
///
/// This is also the child representation used by `Bound::Sum` and
/// `Bound::Product`.
#[derive(Clone)]
struct TypeInner<'brand> {
    /// Union-bound element whose root refers to this type's current bound.
    bound: UbElement<'brand, BoundRef<'brand>>,
}
impl TypeInner<'_> {
    /// Returns a copy of this handle, produced by its `Clone` impl.
    fn shallow_clone(&self) -> Self {
        Clone::clone(self)
    }
}
impl<'brand> Type<'brand> {
    /// Creates a new free type variable named `name`, allocated in `ctx`.
    pub fn free(ctx: &Context<'brand>, name: String) -> Self {
        Self::wrap_bound(ctx, ctx.alloc_free(name))
    }
    /// Creates the unit type, allocated in `ctx`.
    pub fn unit(ctx: &Context<'brand>) -> Self {
        Self::wrap_bound(ctx, ctx.alloc_unit())
    }
    /// Creates a complete type from `precomputed::nth_power_of_2(n)`.
    // NOTE(review): judging by the name `two_two_n`, this is presumably the
    // type 2^(2^n). Confirm against `precomputed`.
    pub fn two_two_n(ctx: &Context<'brand>, n: usize) -> Self {
        Self::complete(ctx, precomputed::nth_power_of_2(n))
    }
    /// Creates the sum type `left + right`, allocated in `ctx`.
    pub fn sum(ctx: &Context<'brand>, left: Self, right: Self) -> Self {
        Self::wrap_bound(ctx, ctx.alloc_sum(left, right))
    }
    /// Creates the product type `left × right`, allocated in `ctx`.
    pub fn product(ctx: &Context<'brand>, left: Self, right: Self) -> Self {
        Self::wrap_bound(ctx, ctx.alloc_product(left, right))
    }
    /// Creates a type that is already complete, with the given final data.
    pub fn complete(ctx: &Context<'brand>, final_data: Arc<Final>) -> Self {
        Self::wrap_bound(ctx, ctx.alloc_complete(final_data))
    }
    /// Wraps a freshly allocated bound in a new union-bound element, paired
    /// with a handle to `ctx`.
    fn wrap_bound(ctx: &Context<'brand>, bound: BoundRef<'brand>) -> Self {
        Type {
            ctx: ctx.shallow_clone(),
            inner: TypeInner {
                bound: UbElement::new(bound),
            },
        }
    }
    /// Returns a copy of this type handle, produced by its `Clone` impl.
    // NOTE(review): the copy presumably refers to the same underlying bound in
    // the context, rather than being a deep copy. Confirm in `union_bound`.
    pub fn shallow_clone(&self) -> Self {
        self.clone()
    }
    /// Returns the TMR of this type if it is already complete, else `None`.
    pub fn tmr(&self) -> Option<Tmr> {
        self.final_data().map(|data| data.tmr())
    }
    /// Returns the final data if the root bound is already `Complete`.
    ///
    /// Unlike [`Type::finalize`], this never fills in free variables or
    /// modifies the context; an incomplete type just yields `None`.
    pub fn final_data(&self) -> Option<Arc<Final>> {
        let root = self.ctx.get_root_ref(&self.inner.bound);
        let bound = self.ctx.get(&root);
        if let Bound::Complete(ref data) = bound {
            Some(Arc::clone(data))
        } else {
            None
        }
    }
    /// Whether the root bound is already `Complete`.
    pub fn is_final(&self) -> bool {
        self.final_data().is_some()
    }
    /// Converts the current, possibly partial, state of this type into an
    /// [`Incomplete`] representation.
    pub fn to_incomplete(&self) -> Arc<Incomplete> {
        let root = self.ctx.get_root_ref(&self.inner.bound);
        Incomplete::from_bound_ref(&self.ctx, root)
    }
    /// Finalizes this type and returns its complete form.
    ///
    /// Every free variable becomes the unit type. Every non-complete node
    /// visited is reassigned to `Bound::Complete` in the context, so later
    /// lookups see the finalized form.
    ///
    /// # Errors
    ///
    /// Returns [`Error::OccursCheck`] if the type would be infinitely sized.
    pub fn finalize(&self) -> Result<Arc<Final>, Error> {
        let root = self.ctx.get_root_ref(&self.inner.bound);
        let bound = self.ctx.get(&root);
        // Fast path: nothing to do if the root is already complete.
        if let Bound::Complete(ref data) = bound {
            return Ok(Arc::clone(data));
        }
        // Reject infinitely-sized types up front, so that the post-order walk
        // below only ever sees a finite type.
        if let Some(infinite_bound) = Incomplete::occurs_check(&self.ctx, root.shallow_clone()) {
            return Err(Error::OccursCheck { infinite_bound });
        }
        // `finalized[i]` holds the final type of the i-th node in post-order,
        // so each node's children are finalized before the node itself.
        let mut finalized = vec![];
        for data in (&self.ctx, root).post_order_iter::<NoSharing>() {
            let bound_get = data.node.0.get(&data.node.1);
            let final_data = match bound_get {
                Bound::Free(_) => Final::unit(),
                Bound::Complete(ref arc) => Arc::clone(arc),
                // `left_index`/`right_index` point at the children's entries
                // in `finalized`, which post-order has already filled in.
                Bound::Sum(..) => Final::sum(
                    Arc::clone(&finalized[data.left_index.unwrap()]),
                    Arc::clone(&finalized[data.right_index.unwrap()]),
                ),
                Bound::Product(..) => Final::product(
                    Arc::clone(&finalized[data.left_index.unwrap()]),
                    Arc::clone(&finalized[data.right_index.unwrap()]),
                ),
            };
            // Write the result back into the context. Nodes that were
            // already complete are skipped, which is the case the
            // `reassign_non_complete` helper name excludes.
            if !matches!(bound_get, Bound::Complete(..)) {
                self.ctx
                    .reassign_non_complete(data.node.1, Bound::Complete(Arc::clone(&final_data)));
            }
            finalized.push(final_data);
        }
        // Post-order yields the root last, so the last entry is the result.
        Ok(finalized.pop().unwrap())
    }
    /// Returns `n` types: `2`, `2^2`, `2^4`, …, `2^(2^(n-1))`.
    ///
    /// `2` is `1 + 1`. Each later element is the product of the previous
    /// element with itself, and that previous element is used as both
    /// children.
    pub fn powers_of_two(ctx: &Context<'brand>, n: usize) -> Vec<Self> {
        let mut ret = Vec::with_capacity(n);
        let unit = Type::unit(ctx);
        let mut two = Type::sum(ctx, unit.shallow_clone(), unit);
        for _ in 0..n {
            ret.push(two.shallow_clone());
            two = Type::product(ctx, two.shallow_clone(), two);
        }
        ret
    }
}
const MAX_DISPLAY_DEPTH: usize = 64;
const MAX_DISPLAY_LENGTH: usize = 10000;
impl fmt::Debug for Type<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let root = self.ctx.get_root_ref(&self.inner.bound);
for data in (&self.ctx, root).verbose_pre_order_iter::<NoSharing>(Some(MAX_DISPLAY_DEPTH)) {
if data.index > MAX_DISPLAY_LENGTH {
write!(f, "... [truncated type after {} nodes]", MAX_DISPLAY_LENGTH)?;
return Ok(());
}
if data.depth == MAX_DISPLAY_DEPTH {
if data.n_children_yielded == 0 {
f.write_str("...")?;
}
continue;
}
let bound = data.node.0.get(&data.node.1);
match (bound, data.n_children_yielded) {
(Bound::Free(ref s), _) => f.write_str(s)?,
(Bound::Complete(ref comp), _) => fmt::Debug::fmt(comp, f)?,
(Bound::Sum(..), 0) | (Bound::Product(..), 0) => {
if data.index > 0 {
f.write_str("(")?;
}
}
(Bound::Sum(..), 2) | (Bound::Product(..), 2) => {
if data.index > 0 {
f.write_str(")")?
}
}
(Bound::Sum(..), _) => f.write_str(" + ")?,
(Bound::Product(..), _) => f.write_str(" × ")?,
}
}
Ok(())
}
}
impl fmt::Display for Type<'_> {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
let root = self.ctx.get_root_ref(&self.inner.bound);
for data in (&self.ctx, root).verbose_pre_order_iter::<NoSharing>(Some(MAX_DISPLAY_DEPTH)) {
if data.index > MAX_DISPLAY_LENGTH {
write!(f, "... [truncated type after {} nodes]", MAX_DISPLAY_LENGTH)?;
return Ok(());
}
if data.depth == MAX_DISPLAY_DEPTH {
if data.n_children_yielded == 0 {
f.write_str("...")?;
}
continue;
}
let bound = data.node.0.get(&data.node.1);
match (bound, data.n_children_yielded) {
(Bound::Free(ref s), _) => f.write_str(s)?,
(Bound::Complete(ref comp), _) => fmt::Display::fmt(comp, f)?,
(Bound::Sum(..), 0) | (Bound::Product(..), 0) => {
if data.index > 0 {
f.write_str("(")?;
}
}
(Bound::Sum(..), 2) | (Bound::Product(..), 2) => {
if data.index > 0 {
f.write_str(")")?
}
}
(Bound::Sum(..), _) => f.write_str(" + ")?,
(Bound::Product(..), _) => f.write_str(" × ")?,
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::jet::Core;
    use crate::node::{ConstructNode, CoreConstructible};

    // A unification that cannot succeed must be reported as an error, and a
    // repeated attempt must fail again rather than succeed.
    #[test]
    fn inference_failure() {
        Context::with_context(|ctx| {
            // unit: A -> 1
            let unit = Arc::<ConstructNode<Core>>::unit(&ctx);
            // Composing `unit` with itself unifies its source with its
            // target, forcing it to be 1 -> 1.
            Arc::<ConstructNode<Core>>::comp(&unit, &unit).unwrap();
            // take unit: 1 × B -> 1
            let take_unit = Arc::<ConstructNode<Core>>::take(&unit);
            // `pair` requires both children to have the same source,
            // i.e. 1 = 1 × B, which cannot be unified.
            Arc::<ConstructNode<Core>>::pair(&unit, &take_unit).unwrap_err();
            // Deliberately repeated: the failed attempt above must not leave
            // the context in a state where a second attempt succeeds.
            Arc::<ConstructNode<Core>>::pair(&unit, &take_unit).unwrap_err();
        });
    }

    // Builds a small `case` program and formats its source type with `Debug`.
    // NOTE(review): by its name this is a regression test for a memory leak
    // (presumably a reference cycle among types). Confirm against history.
    #[test]
    fn memory_leak() {
        Context::with_context(|ctx| {
            let iden = Arc::<ConstructNode<Core>>::iden(&ctx);
            let drop = Arc::<ConstructNode<Core>>::drop_(&iden);
            let case = Arc::<ConstructNode<Core>>::case(&iden, &drop).unwrap();
            let _ = format!("{:?}", case.arrow().source);
        });
    }
}