use {
crate::{
multiset::Multiset,
reflection::{
AlgebraicTypeFormer, Erased, PrecomputedTypeFormer, TermsOfVariousTypes, Type, info,
type_of,
},
search,
size::{Size, Sizes},
},
core::{fmt, mem, num::NonZero, ops::Deref, ptr},
std::collections::BTreeSet,
wyrand::WyRand,
};
#[non_exhaustive]
#[derive(Clone, Copy, Hash)]
pub struct CtorFn<T> {
pub call: for<'terms> fn(&'terms mut TermsOfVariousTypes) -> T,
}
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Hash)]
pub struct IndexedCtorFn<T> {
pub arbitrary_fields: for<'prng> fn(&'prng mut WyRand, Sizes) -> TermsOfVariousTypes,
pub call: CtorFn<T>,
pub index: NonZero<usize>,
pub n_big: usize,
}
#[non_exhaustive]
#[repr(transparent)]
#[derive(Clone, Copy, Hash)]
pub struct ElimFn<T> {
pub call: fn(T) -> Decomposition,
}
/// An algebraic type former: the constructors (introduction rules) that build a
/// value, plus one eliminator that takes any value apart again.
#[derive(Debug, Clone)]
#[expect(clippy::exhaustive_structs, reason = "constructed in macros")]
pub struct Algebraic<T> {
    /// Splits a `T` into the constructor that built it and that constructor's fields.
    pub elimination_rule: ElimFn<T>,
    /// The ways to build a `T`, one entry per constructor.
    pub introduction_rules: Vec<IntroductionRule<T>>,
}
/// A literal (non-algebraic) type former: values are generated directly instead of
/// being assembled from constructors.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct Literal<T> {
    /// Draws a pseudorandom `T` from the PRNG.
    pub generate: for<'prng> fn(&'prng mut WyRand) -> T,
    /// Produces candidate values derived from a given `T` (presumably "smaller"
    /// values used for shrinking — TODO confirm).
    pub shrink: fn(T) -> Box<dyn Iterator<Item = T>>,
}
/// How the values of a type are formed.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub enum TypeFormer<T> {
    /// Built from constructors and taken apart by an eliminator.
    Algebraic(Algebraic<T>),
    /// Generated directly by a `generate` function.
    Literal(Literal<T>),
}
/// The result of eliminating a value: which constructor built it, and the fields
/// that constructor was applied to.
#[derive(Debug)]
#[expect(clippy::exhaustive_structs, reason = "constructed in macros")]
pub struct Decomposition {
    /// 1-based index of the constructor in the type's constructor table.
    pub ctor_idx: NonZero<usize>,
    /// The constructor's fields, ready to be handed back to that constructor.
    pub fields: TermsOfVariousTypes,
}
/// One constructor of an algebraic type.
#[derive(Debug, Clone)]
#[expect(clippy::exhaustive_structs, reason = "constructed in macros")]
pub struct IntroductionRule<T> {
    /// Generates this constructor's fields from a PRNG and a partitioned size budget.
    pub arbitrary_fields: for<'prng> fn(&'prng mut WyRand, Sizes) -> TermsOfVariousTypes,
    /// Builds a `T` from those fields.
    pub call: CtorFn<T>,
    /// The types this constructor directly depends on (presumably its field types,
    /// counted with multiplicity — TODO confirm).
    pub immediate_dependencies: Multiset<Type>,
}
/// A type that `pbt` knows how to build, take apart, and traverse.
pub trait Construct: Clone + Eq + fmt::Debug + 'static {
    /// Record into `visited` the types this type depends on (presumably following
    /// dependencies transitively, as the name suggests — TODO confirm).
    fn register_all_immediate_dependencies(visited: &mut BTreeSet<Type>);

    /// Describe how values of this type are formed: algebraically (constructors plus
    /// an eliminator) or as a literal (a generator plus a shrinker).
    fn type_former() -> TypeFormer<Self>;

    /// Iterate over values of type `V` found inside `self` (presumably including `self`
    /// itself when `V` is `Self`; compare `visit_self`).
    fn visit_deep<V: Construct>(&self) -> impl Iterator<Item = V>;
}
impl<T> CtorFn<T> {
    /// Wrap a constructor function.
    #[inline]
    pub const fn new(call: for<'terms> fn(&'terms mut TermsOfVariousTypes) -> T) -> Self {
        Self { call }
    }

    /// Forget the result type `T`, so this constructor can be stored next to
    /// constructors of other types. Recover it with `CtorFn::unerase` using the same `T`.
    #[inline]
    #[must_use]
    pub const fn erase(self) -> CtorFn<Erased> {
        // SAFETY: only the type parameter changes; both sides wrap a single fn pointer,
        // and fn pointers have the same size whatever their signature (which
        // `transmute` also checks at compile time).
        unsafe { mem::transmute::<Self, CtorFn<Erased>>(self) }
    }
}
impl CtorFn<Erased> {
    /// Recover the typed constructor function from a type-erased `CtorFn`.
    ///
    /// # Safety
    ///
    /// `T` must be exactly the type this constructor was erased from (via
    /// `CtorFn::erase`). Calling the returned function under any other `T` is
    /// undefined behavior.
    #[inline]
    #[must_use]
    pub const unsafe fn unerase<T>(self) -> for<'terms> fn(&'terms mut TermsOfVariousTypes) -> T {
        // SAFETY: layout is unchanged (a single fn pointer); the caller guarantees `T`
        // is the original result type, so the recovered pointer has its true signature.
        let typed = unsafe { mem::transmute::<Self, CtorFn<T>>(self) };
        typed.call
    }
}
impl<T> fmt::Debug for CtorFn<T> {
    /// Print a fixed placeholder; the function pointer's address is not informative.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "(|terms| ...)")
    }
}
impl<T> Deref for CtorFn<T> {
    type Target = for<'terms> fn(&'terms mut TermsOfVariousTypes) -> T;

    /// Borrow the wrapped constructor function pointer.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let Self { call } = self;
        call
    }
}
impl<T> Deref for IndexedCtorFn<T> {
    type Target = CtorFn<T>;

    /// Borrow the underlying constructor, so `CtorFn` methods such as `unerase` can
    /// be called directly on an `IndexedCtorFn`.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let Self { call, .. } = self;
        call
    }
}
impl<T> ElimFn<T> {
    /// Wrap an eliminator function.
    #[inline]
    pub const fn new(call: fn(T) -> Decomposition) -> Self {
        Self { call }
    }

    /// Forget the input type `T`, so this eliminator can be stored next to those of
    /// other types. Recover it with `ElimFn::unerase` using the same `T`.
    #[inline]
    #[must_use]
    pub const fn erase(self) -> ElimFn<Erased> {
        // SAFETY: `ElimFn` is `#[repr(transparent)]` over a fn pointer, and fn pointers
        // have the same size whatever their signature; only the type parameter changes.
        unsafe { mem::transmute::<Self, ElimFn<Erased>>(self) }
    }
}
impl<T> fmt::Debug for ElimFn<T> {
    /// Print a fixed placeholder; the function pointer's address is not informative.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "(|ctor| ...)")
    }
}
impl ElimFn<Erased> {
    /// Recover the typed eliminator function from a type-erased `ElimFn`.
    ///
    /// # Safety
    ///
    /// `T` must be exactly the type this eliminator was erased from (via
    /// `ElimFn::erase`). Calling the returned function under any other `T` is
    /// undefined behavior.
    #[inline]
    #[must_use]
    pub const unsafe fn unerase<T>(self) -> fn(T) -> Decomposition {
        // SAFETY: `ElimFn` is `#[repr(transparent)]`, so only the type changes; the
        // caller guarantees `T` is the original input type.
        let typed = unsafe { mem::transmute::<Self, ElimFn<T>>(self) };
        typed.call
    }
}
impl<T> Deref for ElimFn<T> {
    type Target = fn(T) -> Decomposition;

    /// Borrow the wrapped eliminator, so an `ElimFn` can be called like a function.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let Self { call } = self;
        call
    }
}
/// Draw a pseudorandom `T`, letting `size` bound how large the result may grow.
///
/// * Literal types: call the `generate` function registered for `T`.
/// * Algebraic types: if `size.should_recurse` says so and `T` has at least one
///   potentially-recursive constructor, pick one of those (approximately uniformly).
///   Otherwise pick among the non-recursive ("leaf") constructors. The size budget
///   is then partitioned across the chosen constructor's fields, the fields are
///   generated, and the constructor is applied to them.
///
/// Returns `None` when an algebraic type has to fall back to a leaf constructor but
/// has none.
#[inline]
pub fn arbitrary<T: Construct>(prng: &mut WyRand, size: Size) -> Option<T> {
    /// Random index in `0..bound` (one PRNG draw, reduced modulo `bound`).
    fn random_index(prng: &mut WyRand, bound: NonZero<usize>) -> usize {
        #[expect(
            clippy::as_conversions,
            clippy::cast_possible_truncation,
            reason = "fine: definitely not > `u64::MAX` constructors"
        )]
        let raw = prng.rand() as usize;
        raw % bound
    }

    let info = info::<T>();
    let adt = match info.type_former {
        PrecomputedTypeFormer::Literal { generate, .. } => {
            // SAFETY: `generate` comes from `info::<T>()`, so it was registered for `T`
            // and really returns a `T`.
            let generate = unsafe {
                mem::transmute::<fn(&mut WyRand) -> Erased, fn(&mut WyRand) -> T>(generate)
            };
            return Some(generate(prng));
        }
        PrecomputedTypeFormer::Algebraic(ref adt) => adt,
    };

    let loops = adt.potential_loops();
    // Only consider recursing when `size` allows it AND a recursive constructor exists.
    let recursive_pick = if size.should_recurse(prng) {
        NonZero::new(loops.len())
    } else {
        None
    };
    let (ctor, minus_one) = if let Some(n) = recursive_pick {
        let i = random_index(prng, n);
        // SAFETY: `i < n == loops.len()`.
        (unsafe { loops.get_unchecked(i) }, true)
    } else {
        let leaves = adt.potential_leaves();
        let n = NonZero::new(leaves.len())?;
        let i = random_index(prng, n);
        // SAFETY: `i < n == leaves.len()`.
        (unsafe { leaves.get_unchecked(i) }, false)
    };

    let sizes = size.partition_into(ctor.n_big, prng, minus_one);
    let mut fields = (ctor.arbitrary_fields)(prng, sizes);
    // SAFETY: `ctor` was taken from `T`'s own constructor lists, so it builds a `T`.
    let build = unsafe { ctor.unerase::<T>() };
    let value = build(&mut fields);
    debug_assert!(
        fields.is_empty(),
        "internal `pbt` error: leftover terms after applying a constructor: {fields:#?}",
    );
    Some(value)
}
/// Check the eta law for `T`: eliminating a value into a [`Decomposition`] and
/// re-applying the constructor it names to its fields must rebuild an equal value.
///
/// The comparison is delegated to `search::assert_eq`. Presumably it runs the closure
/// on many generated values of `T` and fails on the first pair that differs (TODO
/// confirm). Non-algebraic (literal) type formers are skipped.
///
/// # Panics
///
/// Panics if the eliminator reports a constructor index outside `T`'s constructor
/// table, or if a constructor leaves fields unconsumed. Presumably it also panics,
/// via `search::assert_eq`, when a rebuilt value differs from the original.
#[inline]
pub fn check_eta_expansion<T: Construct>() {
    let info = info::<T>();
    let PrecomputedTypeFormer::Algebraic(AlgebraicTypeFormer {
        ref all_constructors,
        eliminator,
        ..
    }) = info.type_former
    else {
        return;
    };
    // SAFETY: `eliminator` comes from `info::<T>()`, i.e. it was erased from an `ElimFn<T>`.
    let eliminator = unsafe { eliminator.unerase::<T>() };
    let () = search::assert_eq(32, |orig: &T| {
        let Decomposition {
            ctor_idx,
            mut fields,
        } = eliminator(orig.clone());
        // `ctor_idx` is 1-based and comes from the eliminator under test, which may be
        // user-written. A safe function must not hit UB because an eliminator is wrong,
        // so bounds-check it before the unchecked lookup.
        // SAFETY: `ctor_idx` is nonzero, so subtracting one cannot underflow.
        let idx = unsafe { ctor_idx.get().unchecked_sub(1) };
        let n_constructors = all_constructors.len();
        assert!(
            idx < n_constructors,
            "internal `pbt` error: constructor index {ctor_idx} out of range for {n_constructors} constructors",
        );
        // SAFETY: `idx < all_constructors.len()` was asserted just above.
        let (ctor, _) = *unsafe { all_constructors.get_unchecked(idx) };
        // SAFETY: `all_constructors` comes from `info::<T>()`, so each entry builds a `T`.
        let construct = unsafe { ctor.unerase::<T>() };
        let constructed = construct(&mut fields);
        assert!(
            fields.is_empty(),
            "internal `pbt` error: leftover terms after applying a constructor: {fields:#?}",
        );
        (constructed, orig.clone())
    });
}
/// Yield a clone of `s` as a `V` if `V` and `S` are the same type, and nothing otherwise.
#[inline]
pub fn visit_self<V: Construct, S: Construct>(s: &S) -> impl Iterator<Item = V> {
    // Clone eagerly so the returned iterator does not borrow `s`.
    let found: Option<V> = visit_self_opt::<V, S>(s).cloned();
    found.into_iter()
}
/// Reinterpret `s` as a `&V` when `V` and `S` are the same type (as reported by
/// `type_of`); otherwise return `None`.
#[inline]
pub fn visit_self_opt<V: Construct, S: Construct>(s: &S) -> Option<&V> {
    if type_of::<V>() != type_of::<S>() {
        return None;
    }
    let as_v = ptr::from_ref(s).cast::<V>();
    // SAFETY: `type_of` reported `V` and `S` as the same type (assumed to mean they are
    // identical types — TODO confirm `type_of` is `TypeId`-like), so `as_v` points to a
    // valid `V` borrowed for as long as `s`.
    Some(unsafe { &*as_v })
}
/// Move `s` out as a `V` when `V` and `S` are the same type (as reported by
/// `type_of`); otherwise drop `s` and return `None`.
#[inline]
pub fn visit_self_owned<V: Construct, S: Construct>(s: S) -> Option<V> {
    if type_of::<V>() != type_of::<S>() {
        return None;
    }
    // Ownership moves into the returned `V`, so `s` itself must never be dropped.
    let s = mem::ManuallyDrop::new(s);
    let src: *const V = ptr::from_ref::<S>(&s).cast();
    // SAFETY: `type_of` reported `V` and `S` as the same type (assumed to mean identical
    // types), so `src` points to a valid, aligned `V`; `ManuallyDrop` prevents a double drop.
    Some(unsafe { ptr::read(src) })
}