use std::convert::Infallible;
use std::mem;
use std::sync::Arc;
use rustc_index::{Idx, IndexVec};
use thin_vec::ThinVec;
use tracing::{debug, instrument};
use crate::inherent::*;
use crate::visit::{TypeVisitable, TypeVisitableExt as _};
use crate::{self as ty, BoundVarIndexKind, Interner, TypeFlags};
/// A value that can be transformed by a type folder.
///
/// Both methods consume `self` and hand back a value of the same type. The
/// supertrait bounds require every implementor to also be `TypeVisitable<I>`
/// and `Clone`.
pub trait TypeFoldable<I: Interner>: TypeVisitable<I> + Clone {
/// Folds `self` with a fallible folder.
///
/// Returns `Err(F::Error)` when the folder reports a failure; otherwise
/// returns the folded value.
fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error>;
/// Folds `self` with an infallible folder and returns the folded value.
fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self;
}
/// Companion to `TypeFoldable` that supplies the `super_` folding methods.
///
/// The default hooks on `TypeFolder` and `FallibleTypeFolder` delegate to
/// these methods, so a folder that overrides a hook can still call the
/// `super_` method to fall back to the default behaviour.
///
/// NOTE(review): presumably the `super_` methods fold the contents of the
/// value rather than dispatching back to the folder hook for the value
/// itself; the implementations are not visible here, so confirm there.
pub trait TypeSuperFoldable<I: Interner>: TypeFoldable<I> {
/// Fallible counterpart of `super_fold_with`; propagates `F::Error`.
fn try_super_fold_with<F: FallibleTypeFolder<I>>(
self,
folder: &mut F,
) -> Result<Self, F::Error>;
/// Infallible default folding step used by the `TypeFolder` hooks.
fn super_fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self;
}
/// An infallible folder over the interner's types, regions, consts,
/// predicates and clauses.
///
/// Only `cx` is required. Every `fold_*` hook has a default, so an
/// implementor overrides just the hooks it cares about.
pub trait TypeFolder<I: Interner>: Sized {
/// Returns the interner this folder works with.
fn cx(&self) -> I;
/// Hook invoked for a binder; by default delegates to `super_fold_with`.
fn fold_binder<T>(&mut self, t: ty::Binder<I, T>) -> ty::Binder<I, T>
where
T: TypeFoldable<I>,
{
t.super_fold_with(self)
}
/// Hook invoked for a type; by default delegates to `super_fold_with`.
fn fold_ty(&mut self, t: I::Ty) -> I::Ty {
t.super_fold_with(self)
}
/// Hook invoked for a region; by default returns the region unchanged.
fn fold_region(&mut self, r: I::Region) -> I::Region {
r
}
/// Hook invoked for a const; by default delegates to `super_fold_with`.
fn fold_const(&mut self, c: I::Const) -> I::Const {
c.super_fold_with(self)
}
/// Hook invoked for a predicate; by default delegates to `super_fold_with`.
fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
p.super_fold_with(self)
}
/// Hook invoked for a clause list; by default delegates to `super_fold_with`.
fn fold_clauses(&mut self, c: I::Clauses) -> I::Clauses {
c.super_fold_with(self)
}
}
/// A folder whose hooks may fail with `Self::Error`.
///
/// Mirrors `TypeFolder` hook for hook: only `cx` and `Error` are required,
/// and every `try_fold_*` hook has a default.
pub trait FallibleTypeFolder<I: Interner>: Sized {
/// Error type produced when a fold fails.
type Error;
/// Returns the interner this folder works with.
fn cx(&self) -> I;
/// Hook invoked for a binder; by default delegates to `try_super_fold_with`.
fn try_fold_binder<T>(&mut self, t: ty::Binder<I, T>) -> Result<ty::Binder<I, T>, Self::Error>
where
T: TypeFoldable<I>,
{
t.try_super_fold_with(self)
}
/// Hook invoked for a type; by default delegates to `try_super_fold_with`.
fn try_fold_ty(&mut self, t: I::Ty) -> Result<I::Ty, Self::Error> {
t.try_super_fold_with(self)
}
/// Hook invoked for a region; by default succeeds with the region unchanged.
fn try_fold_region(&mut self, r: I::Region) -> Result<I::Region, Self::Error> {
Ok(r)
}
/// Hook invoked for a const; by default delegates to `try_super_fold_with`.
fn try_fold_const(&mut self, c: I::Const) -> Result<I::Const, Self::Error> {
c.try_super_fold_with(self)
}
/// Hook invoked for a predicate; by default delegates to `try_super_fold_with`.
fn try_fold_predicate(&mut self, p: I::Predicate) -> Result<I::Predicate, Self::Error> {
p.try_super_fold_with(self)
}
/// Hook invoked for a clause list; by default delegates to `try_super_fold_with`.
fn try_fold_clauses(&mut self, c: I::Clauses) -> Result<I::Clauses, Self::Error> {
c.try_super_fold_with(self)
}
}
/// Pairs fold component-wise, first element before second.
impl<I: Interner, T: TypeFoldable<I>, U: TypeFoldable<I>> TypeFoldable<I> for (T, U) {
    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<(T, U), F::Error> {
        let (first, second) = self;
        // A failure on `first` returns immediately; `second` is then never folded.
        let first = first.try_fold_with(folder)?;
        let second = second.try_fold_with(folder)?;
        Ok((first, second))
    }

    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
        let (first, second) = self;
        let first = first.fold_with(folder);
        let second = second.fold_with(folder);
        (first, second)
    }
}
/// Triples fold component-wise, in field order.
impl<I: Interner, A: TypeFoldable<I>, B: TypeFoldable<I>, C: TypeFoldable<I>> TypeFoldable<I>
    for (A, B, C)
{
    fn try_fold_with<F: FallibleTypeFolder<I>>(
        self,
        folder: &mut F,
    ) -> Result<(A, B, C), F::Error> {
        let (first, second, third) = self;
        // The first failing component aborts the fold; later ones are not visited.
        let first = first.try_fold_with(folder)?;
        let second = second.try_fold_with(folder)?;
        let third = third.try_fold_with(folder)?;
        Ok((first, second, third))
    }

    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
        let (first, second, third) = self;
        let first = first.fold_with(folder);
        let second = second.fold_with(folder);
        let third = third.fold_with(folder);
        (first, second, third)
    }
}
/// `None` folds to `None`; `Some(v)` folds to `Some` of the folded `v`.
impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Option<T> {
    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
        // `Option<Result<T, _>>` -> `Result<Option<T>, _>`.
        self.map(|inner| inner.try_fold_with(folder)).transpose()
    }

    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
        self.map(|inner| inner.fold_with(folder))
    }
}
/// Folds whichever payload is present and keeps the `Ok` / `Err` variant.
impl<I: Interner, T: TypeFoldable<I>, E: TypeFoldable<I>> TypeFoldable<I> for Result<T, E> {
    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
        // The outer `Result` carries the folder's error; the inner one is the
        // value being folded and is rebuilt with its original variant.
        match self {
            Ok(value) => value.try_fold_with(folder).map(Ok),
            Err(error) => error.try_fold_with(folder).map(Err),
        }
    }

    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
        self.map(|value| value.fold_with(folder)).map_err(|error| error.fold_with(folder))
    }
}
/// Replaces the value inside `arc` with `fold(value)`, reusing the allocation
/// whenever `arc` is not shared.
///
/// If other strong or weak pointers to the allocation exist, `Arc::make_mut`
/// first moves or clones the value into a fresh, uniquely owned allocation.
/// When `fold` returns `Err`, the error is propagated and the allocation is
/// freed without dropping the already moved-out value a second time.
fn fold_arc<T: Clone, E>(
    mut arc: Arc<T>,
    fold: impl FnOnce(T) -> Result<T, E>,
) -> Result<Arc<T>, E> {
    // Establish unique ownership while the pointer is still an `Arc<T>`, so a
    // panic while cloning drops the contained `T` normally.
    Arc::make_mut(&mut arc);

    let raw = Arc::into_raw(arc).cast::<mem::ManuallyDrop<T>>();
    // SAFETY: `raw` was produced by `Arc::into_raw` on the line above, and
    // `ManuallyDrop<T>` is `repr(transparent)` over `T`, so size and
    // alignment are identical.
    let mut sole = unsafe { Arc::from_raw(raw) };

    // SAFETY: `make_mut` left this as the only pointer to the allocation and
    // no new pointer has been created since, so `get_mut` cannot yield `None`.
    let cell = unsafe { Arc::get_mut(&mut sole).unwrap_unchecked() };

    // SAFETY: `cell` holds an initialised `T` that is taken exactly once. It
    // is either overwritten below before being read again or, if `fold`
    // fails or unwinds, released as a `ManuallyDrop`, which never runs `T`'s
    // destructor on the moved-out slot.
    let value = unsafe { mem::ManuallyDrop::take(cell) };

    *cell = mem::ManuallyDrop::new(fold(value)?);

    let raw = Arc::into_raw(sole).cast::<T>();
    // SAFETY: reverse of the cast above; the slot once again holds a live `T`.
    Ok(unsafe { Arc::from_raw(raw) })
}
/// Folds the value behind the `Arc` through `fold_arc`.
impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Arc<T> {
    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
        let fold_inner = |inner: T| inner.try_fold_with(folder);
        fold_arc(self, fold_inner)
    }

    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
        // The error type is uninhabited, so the `Ok` pattern is irrefutable.
        let Ok(folded) = fold_arc::<T, Infallible>(self, |inner| Ok(inner.fold_with(folder)));
        folded
    }
}
/// Folds the boxed value and writes the result back into the same box.
impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Box<T> {
    fn try_fold_with<F: FallibleTypeFolder<I>>(mut self, folder: &mut F) -> Result<Self, F::Error> {
        // Move the value out of the box, fold it, then store the result back.
        let folded = (*self).try_fold_with(folder)?;
        *self = folded;
        Ok(self)
    }

    fn fold_with<F: TypeFolder<I>>(mut self, folder: &mut F) -> Self {
        let folded = (*self).fold_with(folder);
        *self = folded;
        self
    }
}
/// Folds every element in order and collects the results into a `Vec`.
impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Vec<T> {
    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
        // Collecting into `Result` stops at the first element that fails.
        let folded = self.into_iter().map(|elem| elem.try_fold_with(folder));
        folded.collect()
    }

    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
        let folded = self.into_iter().map(|elem| elem.fold_with(folder));
        folded.collect()
    }
}
/// Folds every element in order and collects the results into a `ThinVec`.
impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for ThinVec<T> {
    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
        // Collecting into `Result` stops at the first element that fails.
        let folded = self.into_iter().map(|elem| elem.try_fold_with(folder));
        folded.collect()
    }

    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
        let folded = self.into_iter().map(|elem| elem.fold_with(folder));
        folded.collect()
    }
}
/// Folds a boxed slice by converting it to a `Vec`, folding that, and
/// converting back.
impl<I: Interner, T: TypeFoldable<I>> TypeFoldable<I> for Box<[T]> {
    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
        let folded = Vec::from(self).try_fold_with(folder)?;
        Ok(folded.into_boxed_slice())
    }

    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
        let folded = Vec::from(self).fold_with(folder);
        folded.into_boxed_slice()
    }
}
/// Folds the underlying `raw` vector and rewraps it as an `IndexVec`.
impl<I: Interner, T: TypeFoldable<I>, Ix: Idx> TypeFoldable<I> for IndexVec<Ix, T> {
    fn try_fold_with<F: FallibleTypeFolder<I>>(self, folder: &mut F) -> Result<Self, F::Error> {
        let raw = self.raw.try_fold_with(folder)?;
        Ok(IndexVec::from_raw(raw))
    }

    fn fold_with<F: TypeFolder<I>>(self, folder: &mut F) -> Self {
        let raw = self.raw.fold_with(folder);
        IndexVec::from_raw(raw)
    }
}
/// Folder state for shifting the De Bruijn indices of bound variables by a
/// fixed amount.
struct Shifter<I: Interner> {
/// Interner returned from `cx()` and passed to the `new_bound` constructors.
cx: I,
/// Binder depth of the traversal: starts at `ty::INNERMOST` and is shifted
/// in by one while the contents of a binder are being folded.
current_index: ty::DebruijnIndex,
/// Number of binder levels by which qualifying bound variables are shifted.
amount: u32,
}
impl<I: Interner> Shifter<I> {
    /// Creates a shifter that starts at the innermost binder depth and shifts
    /// by `amount`.
    fn new(cx: I, amount: u32) -> Self {
        let current_index = ty::INNERMOST;
        Self { cx, current_index, amount }
    }
}
/// Shifts bound regions, types and consts whose De Bruijn index is at or
/// above the current binder depth by `self.amount`; everything else is left
/// as it was.
impl<I: Interner> TypeFolder<I> for Shifter<I> {
    fn cx(&self) -> I {
        self.cx
    }

    /// Tracks binder depth: one level deeper while folding the binder's
    /// contents, restored afterwards.
    fn fold_binder<T: TypeFoldable<I>>(&mut self, t: ty::Binder<I, T>) -> ty::Binder<I, T> {
        self.current_index.shift_in(1);
        let shifted = t.super_fold_with(self);
        self.current_index.shift_out(1);
        shifted
    }

    fn fold_region(&mut self, r: I::Region) -> I::Region {
        let ty::ReBound(ty::BoundVarIndexKind::Bound(debruijn), br) = r.kind() else {
            return r;
        };
        if debruijn >= self.current_index {
            Region::new_bound(self.cx, debruijn.shifted_in(self.amount), br)
        } else {
            // Bound by a binder entered during this traversal: leave alone.
            r
        }
    }

    fn fold_ty(&mut self, ty: I::Ty) -> I::Ty {
        if let ty::Bound(BoundVarIndexKind::Bound(debruijn), bound_ty) = ty.kind() {
            if debruijn >= self.current_index {
                return Ty::new_bound(self.cx, debruijn.shifted_in(self.amount), bound_ty);
            }
        }

        // Only descend when something inside can actually be shifted.
        if ty.has_vars_bound_at_or_above(self.current_index) {
            ty.super_fold_with(self)
        } else {
            ty
        }
    }

    fn fold_const(&mut self, ct: I::Const) -> I::Const {
        if let ty::ConstKind::Bound(ty::BoundVarIndexKind::Bound(debruijn), bound_ct) = ct.kind() {
            if debruijn >= self.current_index {
                return Const::new_bound(self.cx, debruijn.shifted_in(self.amount), bound_ct);
            }
        }

        ct.super_fold_with(self)
    }

    fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
        if !p.has_vars_bound_at_or_above(self.current_index) {
            return p;
        }
        p.super_fold_with(self)
    }
}
/// Shifts a single bound region by `amount` binder levels.
///
/// Regions that are not `ReBound` with a `Bound` index, and every region
/// when `amount` is zero, are returned unchanged. Unlike `Shifter`, no binder
/// depth is tracked here: any `Bound` index is shifted.
pub fn shift_region<I: Interner>(cx: I, region: I::Region, amount: u32) -> I::Region {
    let ty::ReBound(ty::BoundVarIndexKind::Bound(debruijn), br) = region.kind() else {
        return region;
    };
    if amount > 0 { Region::new_bound(cx, debruijn.shifted_in(amount), br) } else { region }
}
/// Shifts the escaping bound variables of `value` by `amount` binder levels.
///
/// Returns `value` untouched when `amount` is zero or when it has no
/// escaping bound variables, so no fold is run in those cases.
#[instrument(level = "trace", skip(cx), ret)]
pub fn shift_vars<I: Interner, T>(cx: I, value: T, amount: u32) -> T
where
    T: TypeFoldable<I>,
{
    if amount != 0 && value.has_escaping_bound_vars() {
        let mut shifter = Shifter::new(cx, amount);
        value.fold_with(&mut shifter)
    } else {
        value
    }
}
/// Folds `value` with a `RegionFolder` built from `f`.
///
/// `f` receives each region the folder does not skip, together with the
/// binder depth at which it was encountered, and returns its replacement.
pub fn fold_regions<I: Interner, T>(
    cx: I,
    value: T,
    f: impl FnMut(I::Region, ty::DebruijnIndex) -> I::Region,
) -> T
where
    T: TypeFoldable<I>,
{
    let mut region_folder = RegionFolder::new(cx, f);
    value.fold_with(&mut region_folder)
}
/// Folder state for rewriting regions with a caller-supplied callback.
pub struct RegionFolder<I, F> {
/// Interner returned from `cx()`.
cx: I,
/// Binder depth of the traversal: starts at `ty::INNERMOST` and is shifted
/// in by one while the contents of a binder are being folded.
current_index: ty::DebruijnIndex,
/// Callback invoked with a region and the current binder depth; its result
/// replaces the region.
fold_region_fn: F,
}
impl<I, F> RegionFolder<I, F> {
    /// Creates a region folder that starts at the innermost binder depth and
    /// uses `fold_region_fn` as its region callback.
    #[inline]
    pub fn new(cx: I, fold_region_fn: F) -> RegionFolder<I, F> {
        let current_index = ty::INNERMOST;
        Self { cx, current_index, fold_region_fn }
    }
}
/// Applies the callback to every region except those bound below the current
/// binder depth and those with a `Canonical` bound index. Types, consts and
/// predicates whose flags show none of the relevant region kinds are
/// returned without being traversed.
impl<I, F> TypeFolder<I> for RegionFolder<I, F>
where
    I: Interner,
    F: FnMut(I::Region, ty::DebruijnIndex) -> I::Region,
{
    fn cx(&self) -> I {
        self.cx
    }

    /// Tracks binder depth: one level deeper while folding the binder's
    /// contents, restored afterwards.
    fn fold_binder<T: TypeFoldable<I>>(&mut self, t: ty::Binder<I, T>) -> ty::Binder<I, T> {
        self.current_index.shift_in(1);
        let folded = t.super_fold_with(self);
        self.current_index.shift_out(1);
        folded
    }

    #[instrument(skip(self), level = "debug", ret)]
    fn fold_region(&mut self, r: I::Region) -> I::Region {
        // Decide first whether the region is handed to the callback at all.
        let skip = match r.kind() {
            ty::ReBound(ty::BoundVarIndexKind::Bound(debruijn), _) => {
                debruijn < self.current_index
            }
            ty::ReBound(ty::BoundVarIndexKind::Canonical, _) => true,
            _ => false,
        };

        if skip {
            debug!(?self.current_index, "skipped bound region");
            r
        } else {
            debug!(?self.current_index, "folding free region");
            (self.fold_region_fn)(r, self.current_index)
        }
    }

    fn fold_ty(&mut self, t: I::Ty) -> I::Ty {
        let region_flags =
            TypeFlags::HAS_FREE_REGIONS | TypeFlags::HAS_RE_BOUND | TypeFlags::HAS_RE_ERASED;
        if !t.has_type_flags(region_flags) {
            return t;
        }
        t.super_fold_with(self)
    }

    fn fold_const(&mut self, ct: I::Const) -> I::Const {
        let region_flags =
            TypeFlags::HAS_FREE_REGIONS | TypeFlags::HAS_RE_BOUND | TypeFlags::HAS_RE_ERASED;
        if !ct.has_type_flags(region_flags) {
            return ct;
        }
        ct.super_fold_with(self)
    }

    fn fold_predicate(&mut self, p: I::Predicate) -> I::Predicate {
        let region_flags =
            TypeFlags::HAS_FREE_REGIONS | TypeFlags::HAS_RE_BOUND | TypeFlags::HAS_RE_ERASED;
        if !p.has_type_flags(region_flags) {
            return p;
        }
        p.super_fold_with(self)
    }
}