use std::fmt;
use std::marker::PhantomData;
use std::sync::{Arc, Mutex, MutexGuard};
use ghost_cell::GhostToken;
use crate::dag::{Dag, DagLike};
use super::{
Bound, CompleteBound, Error, Final, Incomplete, Type, TypeInner, UbElement, WithGhostToken,
};
/// Zero-sized marker that makes the containing type *invariant* in `'brand`
/// (`fn(&'brand ()) -> &'brand ()` is contravariant in its argument and
/// covariant in its result), so a branded value cannot be coerced to another brand.
type InvariantLifetime<'brand> = PhantomData<fn(&'brand ()) -> &'brand ()>;

/// Shared handle to a type-inference context.
///
/// The derived `Clone` clones the inner `Arc`, so every clone refers to the same
/// underlying state; equality between handles is by identity of that state.
#[derive(Clone)]
pub struct Context<'brand> {
    // The ghost token and the bound slab, guarded together by one mutex.
    inner: Arc<Mutex<WithGhostToken<'brand, ContextInner<'brand>>>>,
}

/// Mutable state of a [`Context`], always accessed through its mutex.
struct ContextInner<'brand> {
    // Every bound allocated in this context; a `BoundRef` is an index into it.
    slab: Vec<Bound<'brand>>,
}
impl fmt::Debug for Context<'_> {
    /// Prints the context as the address of its shared state, so handles to the
    /// same context print identically and distinct live contexts differ.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let addr = Arc::as_ptr(&self.inner) as usize;
        f.write_fmt(format_args!("inference_ctx_{:08x}", addr))
    }
}

impl PartialEq for Context<'_> {
    /// Contexts compare by identity: two handles are equal exactly when they
    /// share the same underlying state.
    fn eq(&self, other: &Self) -> bool {
        let lhs = Arc::as_ptr(&self.inner);
        let rhs = Arc::as_ptr(&other.inner);
        lhs == rhs
    }
}

// Identity comparison is reflexive, symmetric and transitive, so `Eq` holds.
impl Eq for Context<'_> {}
impl<'brand> Context<'brand> {
pub fn with_context<R, F>(fun: F) -> R
where
F: for<'new_brand> FnOnce(Context<'new_brand>) -> R,
{
GhostToken::new(|token| {
let ctx = Context::new(token);
fun(ctx)
})
}
pub fn new(token: GhostToken<'brand>) -> Self {
Context {
inner: Arc::new(Mutex::new(WithGhostToken {
token,
inner: ContextInner { slab: vec![] },
})),
}
}
fn alloc_bound(&self, bound: Bound<'brand>) -> BoundRef<'brand> {
let mut lock = self.lock();
lock.alloc_bound(bound)
}
pub fn alloc_free(&self, name: String) -> BoundRef<'brand> {
self.alloc_bound(Bound::Free(name))
}
pub fn alloc_unit(&self) -> BoundRef<'brand> {
self.alloc_bound(Bound::Complete(Final::unit()))
}
pub fn alloc_complete(&self, data: Arc<Final>) -> BoundRef<'brand> {
self.alloc_bound(Bound::Complete(data))
}
pub fn alloc_sum(&self, left: Type<'brand>, right: Type<'brand>) -> BoundRef<'brand> {
let mut lock = self.lock();
if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) {
lock.alloc_bound(Bound::Complete(Final::sum(data1, data2)))
} else {
lock.alloc_bound(Bound::Sum(left.inner, right.inner))
}
}
pub fn alloc_product(&self, left: Type<'brand>, right: Type<'brand>) -> BoundRef<'brand> {
let mut lock = self.lock();
if let Some((data1, data2)) = lock.complete_pair_data(&left.inner, &right.inner) {
lock.alloc_bound(Bound::Complete(Final::product(data1, data2)))
} else {
lock.alloc_bound(Bound::Product(left.inner, right.inner))
}
}
pub fn shallow_clone(&self) -> Self {
Self {
inner: Arc::clone(&self.inner),
}
}
pub fn check_eq(&self, other: &Self) -> Result<(), super::Error> {
if self == other {
Ok(())
} else {
Err(super::Error::InferenceContextMismatch)
}
}
pub(super) fn get(&self, bound: &BoundRef<'brand>) -> Bound<'brand> {
let lock = self.lock();
lock.inner.slab[bound.index].shallow_clone()
}
pub(super) fn get_root_ref(
&self,
bound: &UbElement<'brand, BoundRef<'brand>>,
) -> BoundRef<'brand> {
let mut lock = self.lock();
bound.root(&mut lock.token)
}
pub(super) fn reassign_non_complete(&self, bound: BoundRef<'brand>, new: Bound<'brand>) {
let mut lock = self.lock();
lock.reassign_non_complete(bound, new);
}
pub fn bind_product(
&self,
existing: &Type<'brand>,
prod_l: &Type<'brand>,
prod_r: &Type<'brand>,
hint: &'static str,
) -> Result<(), Error> {
let mut lock = self.lock();
let existing_root = existing.inner.bound.root(&mut lock.token);
let new_bound = Bound::Product(prod_l.inner.shallow_clone(), prod_r.inner.shallow_clone());
lock.bind(existing_root, new_bound).map_err(|e| {
let new_bound = lock.alloc_bound(e.new);
drop(lock);
Error::Bind {
existing_bound: Incomplete::from_bound_ref(self, e.existing),
new_bound: Incomplete::from_bound_ref(self, new_bound),
hint,
}
})
}
pub fn unify(
&self,
ty1: &Type<'brand>,
ty2: &Type<'brand>,
hint: &'static str,
) -> Result<(), Error> {
let mut lock = self.lock();
lock.unify(&ty1.inner, &ty2.inner).map_err(|e| {
let new_bound = lock.alloc_bound(e.new);
drop(lock);
Error::Bind {
existing_bound: Incomplete::from_bound_ref(self, e.existing),
new_bound: Incomplete::from_bound_ref(self, new_bound),
hint,
}
})
}
fn lock(&self) -> MutexGuard<'_, WithGhostToken<'brand, ContextInner<'brand>>> {
self.inner.lock().unwrap()
}
}
#[derive(Debug, Clone)]
pub struct BoundRef<'brand> {
phantom: InvariantLifetime<'brand>,
index: usize,
}
impl<'brand> BoundRef<'brand> {
pub fn occurs_check_id(&self) -> OccursCheckId<'brand> {
OccursCheckId {
phantom: InvariantLifetime::default(),
index: self.index,
}
}
}
impl super::PointerLike for BoundRef<'_> {
fn ptr_eq(&self, other: &Self) -> bool {
self.index == other.index
}
fn shallow_clone(&self) -> Self {
BoundRef {
phantom: InvariantLifetime::default(),
index: self.index,
}
}
}
impl<'brand> DagLike for (&'_ Context<'brand>, BoundRef<'brand>) {
    type Node = BoundRef<'brand>;

    fn data(&self) -> &BoundRef<'brand> {
        let (_, node) = self;
        node
    }

    /// Views a bound as a DAG node.
    ///
    /// Free and complete bounds are leaves. Sum and product bounds have two
    /// children, each resolved to its root reference within the same context.
    fn as_dag_node(&self) -> Dag<Self> {
        let ctx = self.0;
        let bound = ctx.get(&self.1);
        let (left, right) = match bound {
            Bound::Sum(ref ty1, ref ty2) | Bound::Product(ref ty1, ref ty2) => (ty1, ty2),
            Bound::Free(..) | Bound::Complete(..) => return Dag::Nullary,
        };
        let left_root = ctx.get_root_ref(&left.bound);
        let right_root = ctx.get_root_ref(&right.bound);
        Dag::Binary((ctx, left_root), (ctx, right_root))
    }
}
/// Copyable, hashable identifier of a slab slot, obtained from
/// `BoundRef::occurs_check_id`.
///
/// Unlike `BoundRef` it is `Copy` and `Hash`, so it can serve as a set/map key.
/// Presumably it is used for visited-tracking in the occurs check; the caller
/// is not shown here.
#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)]
pub struct OccursCheckId<'brand> {
    phantom: InvariantLifetime<'brand>,
    // Same slab index as the originating `BoundRef`.
    index: usize,
}

/// Internal error produced when a bound cannot be made compatible with another.
struct BindError<'brand> {
    // The slab entry that was being constrained.
    existing: BoundRef<'brand>,
    // The bound it could not be reconciled with (not necessarily in the slab).
    new: Bound<'brand>,
}
impl<'brand> ContextInner<'brand> {
fn alloc_bound(&mut self, bound: Bound<'brand>) -> BoundRef<'brand> {
self.slab.push(bound);
let index = self.slab.len() - 1;
BoundRef {
phantom: InvariantLifetime::default(),
index,
}
}
fn reassign_non_complete(&mut self, bound: BoundRef<'brand>, new: Bound<'brand>) {
assert!(
!matches!(self.slab[bound.index], Bound::Complete(..)),
"tried to modify finalized type",
);
self.slab[bound.index] = new;
}
}
impl<'brand> WithGhostToken<'brand, ContextInner<'brand>> {
    /// Returns the final data of `inn1` and `inn2` if the root bounds of both are
    /// complete, or `None` otherwise.
    fn complete_pair_data(
        &mut self,
        inn1: &TypeInner<'brand>,
        inn2: &TypeInner<'brand>,
    ) -> Option<(Arc<Final>, Arc<Final>)> {
        let idx1 = inn1.bound.root(&mut self.token).index;
        let idx2 = inn2.bound.root(&mut self.token).index;
        let bound1 = &self.slab[idx1];
        let bound2 = &self.slab[idx2];
        if let (Bound::Complete(ref data1), Bound::Complete(ref data2)) = (bound1, bound2) {
            Some((Arc::clone(data1), Arc::clone(data2)))
        } else {
            None
        }
    }

    /// Unifies `existing` with `other`.
    ///
    /// Delegates to `UbElement::unify`. Whenever two bounds must be reconciled,
    /// the callback binds the first (`x_bound`) to a copy of the slab entry of
    /// the second (`y_bound`).
    fn unify(
        &mut self,
        existing: &TypeInner<'brand>,
        other: &TypeInner<'brand>,
    ) -> Result<(), BindError<'brand>> {
        existing
            .bound
            .unify(self, &other.bound, |self_, x_bound, y_bound| {
                self_.bind(x_bound, self_.slab[y_bound.index].shallow_clone())
            })
    }

    /// Constrains the slab entry at `existing` to be compatible with `new`.
    ///
    /// * If `new` is free, nothing changes. If `existing` is free, it is
    ///   replaced by `new`.
    /// * Two complete bounds must be equal.
    /// * A complete bound and a sum/product bound must have the same shape. The
    ///   structural side's children are bound to the complete children, and an
    ///   incomplete `existing` is then itself recorded as complete.
    /// * Two sums (or two products) are unified child-wise. `existing` is
    ///   recorded as complete if both children now are.
    ///
    /// # Errors
    ///
    /// Returns a [`BindError`] on any shape mismatch. Failures in children are
    /// propagated unchanged, so the error may describe a child rather than
    /// `existing` itself.
    fn bind(
        &mut self,
        existing: BoundRef<'brand>,
        new: Bound<'brand>,
    ) -> Result<(), BindError<'brand>> {
        let existing_bound = self.slab[existing.index].shallow_clone();
        let bind_error = || BindError {
            existing: existing.clone(),
            new: new.shallow_clone(),
        };
        match (&existing_bound, &new) {
            (_, Bound::Free(_)) => Ok(()),
            (Bound::Free(_), _) => {
                self.reassign_non_complete(existing, new);
                Ok(())
            }
            (Bound::Complete(ref existing_final), Bound::Complete(ref new_final)) => {
                if existing_final == new_final {
                    Ok(())
                } else {
                    Err(bind_error())
                }
            }
            (Bound::Complete(complete), incomplete) | (incomplete, Bound::Complete(complete)) => {
                match (complete.bound(), incomplete) {
                    // Free/complete partners were handled above, so a unit cannot
                    // match the remaining sum/product shapes.
                    (CompleteBound::Unit, _) => Err(bind_error()),
                    (
                        CompleteBound::Product(ref comp1, ref comp2),
                        Bound::Product(ref ty1, ref ty2),
                    )
                    | (CompleteBound::Sum(ref comp1, ref comp2), Bound::Sum(ref ty1, ref ty2)) => {
                        let bound1 = ty1.bound.root(&mut self.token);
                        let bound2 = ty2.bound.root(&mut self.token);
                        self.bind(bound1, Bound::Complete(Arc::clone(comp1)))?;
                        self.bind(bound2, Bound::Complete(Arc::clone(comp2)))?;
                        // Mirror the eager finalization done in the sum/product arm
                        // below. If `existing` was the structural side, its children
                        // are now complete, so record it as complete too. Later
                        // `complete_pair_data` lookups on its parents can then see
                        // it. A cycle through `existing` cannot reach this point,
                        // because it can never match a finite `Final`.
                        if !matches!(existing_bound, Bound::Complete(..)) {
                            self.reassign_non_complete(
                                existing,
                                Bound::Complete(Arc::clone(complete)),
                            );
                        }
                        Ok(())
                    }
                    _ => Err(bind_error()),
                }
            }
            (Bound::Sum(ref x1, ref x2), Bound::Sum(ref y1, ref y2))
            | (Bound::Product(ref x1, ref x2), Bound::Product(ref y1, ref y2)) => {
                self.unify(x1, y1)?;
                self.unify(x2, y2)?;
                // Unification may have completed both children. If so, finalize
                // `existing` eagerly so that "all children complete" and
                // "stored as Bound::Complete" coincide.
                if let Some((data1, data2)) = self.complete_pair_data(y1, y2) {
                    self.reassign_non_complete(
                        existing,
                        Bound::Complete(if let Bound::Sum(..) = existing_bound {
                            Final::sum(data1, data2)
                        } else {
                            Final::product(data1, data2)
                        }),
                    );
                }
                Ok(())
            }
            // Sum vs. product (in either order) can never be reconciled.
            (_, _) => Err(bind_error()),
        }
    }
}