use {
crate::{
cache,
multiset::Multiset,
reflection::{
AlgebraicTypeFormer, Erased, ErasedTermBuckets, PrecomputedTypeFormer, Type, info,
type_of,
},
scc::StronglyConnectedComponents,
search,
size::{Size, Sizes},
},
core::{fmt, mem, num::NonZero, ops::Deref, ptr},
std::collections::BTreeSet,
wyrand::WyRand,
};
/// A type-erasable constructor for `T`.
///
/// Calling it consumes this constructor's fields from an [`ErasedTermBuckets`]
/// and builds a `T`. `None` means the constructor did not produce a value; in
/// that case `try_arbitrary` picks a constructor again.
///
/// `#[repr(transparent)]` guarantees that every instantiation `CtorFn<T>` has
/// exactly the layout of its single function-pointer field. That guarantee is
/// what makes the `mem::transmute` between `CtorFn<T>` and `CtorFn<Erased>`
/// (in `erase` / `unerase`) layout-sound; `ElimFn` carries the same attribute
/// for the same reason.
#[non_exhaustive]
#[repr(transparent)]
#[derive(Clone, Copy, Hash)]
pub struct CtorFn<T> {
    /// Consumes this constructor's fields from `terms` and applies the constructor.
    pub call: for<'terms> fn(&'terms mut ErasedTermBuckets) -> Option<T>,
}
/// A constructor bundled with the function that generates its fields.
///
/// Dereferences to its [`CtorFn`], so it can be used wherever the bare
/// constructor is expected.
#[non_exhaustive]
#[derive(Clone, Copy, Debug, Hash)]
pub struct IndexedCtorFn<T> {
    /// Generates this constructor's fields from `prng` within the given `Sizes`
    /// budget, or reports why it could not.
    pub arbitrary_fields:
        for<'prng> fn(&'prng mut WyRand, Sizes) -> Result<ErasedTermBuckets, MaybeUninstantiable>,
    /// The constructor itself.
    pub call: CtorFn<T>,
    /// This constructor's index. NOTE(review): presumably 1-based and matching
    /// `Decomposition::ctor_idx` — confirm where it is assigned.
    pub index: NonZero<usize>,
    /// NOTE(review): presumably the count passed to `Size::_partition_into` in
    /// `try_arbitrary`, i.e. how many fields share the size budget — confirm.
    pub n_big: usize,
}
/// A type-erasable eliminator for `T`.
///
/// It takes a value apart into the index of the constructor that built it plus
/// that constructor's fields.
///
/// `#[repr(transparent)]` gives every `ElimFn<T>` exactly the layout of its
/// function pointer. The `ElimFn<T>` <-> `ElimFn<Erased>` transmutes rely on this.
#[non_exhaustive]
#[repr(transparent)]
#[derive(Clone, Copy, Hash)]
pub struct ElimFn<T> {
    /// Decomposes a value into its constructor index and fields.
    pub call: fn(T) -> Decomposition,
}
/// Describes an algebraic (constructor-based) type.
///
/// It has one elimination rule that takes values apart and a list of
/// introduction rules that build them.
#[derive(Clone, Debug)]
#[expect(clippy::exhaustive_structs, reason = "constructed in macros")]
pub struct Algebraic<T> {
    /// Decomposes a value into its constructor index and fields.
    pub elimination_rule: ElimFn<T>,
    /// The ways to build a value. NOTE(review): presumably one per constructor,
    /// ordered consistently with `Decomposition::ctor_idx` — confirm.
    pub introduction_rules: Vec<IntroductionRule<T>>,
}
/// Describes a literal (non-constructor) type.
///
/// Values are generated, shrunk and (de)serialized directly rather than being
/// assembled from constructors.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct Literal<T> {
    /// Parses a value from a string; `None` if the string is not a valid `T`.
    pub deserialize: fn(&str) -> Option<T>,
    /// Draws a value using `prng`.
    pub generate: for<'prng> fn(&'prng mut WyRand) -> T,
    /// Renders a value as a string (presumably the inverse of `deserialize`).
    pub serialize: fn(&T) -> String,
    /// Yields candidate replacement values for `T` (presumably "smaller" values
    /// used when shrinking a failing input — confirm).
    pub shrink: fn(T) -> Box<dyn Iterator<Item = T>>,
}
/// How values of a type are formed.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub enum TypeFormer<T> {
    /// Built from constructors and taken apart by an eliminator.
    Algebraic(Algebraic<T>),
    /// Generated and (de)serialized directly.
    Literal(Literal<T>),
}
/// Why an attempt to generate a value failed.
#[derive(Clone, Debug)]
#[expect(clippy::exhaustive_enums, reason = "used internally")]
pub enum MaybeUninstantiable {
    /// This attempt gave up, but another may succeed: `arbitrary` grows the
    /// size budget and tries again.
    Retry,
    /// No value can be generated: `arbitrary` returns `None`.
    Uninstantiable,
}
/// The result of eliminating a value.
///
/// It records which constructor built the value and the fields that
/// constructor was applied to.
#[derive(Debug)]
#[expect(clippy::exhaustive_structs, reason = "constructed in macros")]
pub struct Decomposition {
    /// 1-based index of the constructor. `check_eta_expansion` subtracts one
    /// before looking it up among all constructors.
    pub ctor_idx: NonZero<usize>,
    /// The constructor's fields, ready to be fed back into that constructor.
    pub fields: ErasedTermBuckets,
}
/// One way to build a `T`.
///
/// It bundles a constructor, a generator for that constructor's fields, and
/// the types those fields depend on.
#[derive(Clone, Debug)]
#[expect(clippy::exhaustive_structs, reason = "constructed in macros")]
pub struct IntroductionRule<T> {
    /// Generates this constructor's fields from `prng` within the given `Sizes`
    /// budget, or reports why it could not.
    pub arbitrary_fields:
        for<'prng> fn(&'prng mut WyRand, Sizes) -> Result<ErasedTermBuckets, MaybeUninstantiable>,
    /// The constructor, applied to the generated fields.
    pub call: CtorFn<T>,
    /// The types this constructor directly depends on, kept as a multiset.
    /// NOTE(review): presumably one entry per field type — confirm in the macros.
    pub immediate_dependencies: Multiset<Type>,
}
/// A type that this crate can generate, take apart, and traverse.
///
/// The supertraits let values be cloned, debug-printed, and compared for equality.
pub trait Construct: 'static + Clone + fmt::Debug + Eq {
    /// NOTE(review): presumably records this type's immediate dependencies in
    /// `sccs`, using `visited` to skip types already seen — confirm against the
    /// (likely macro-generated) implementations.
    fn register_all_immediate_dependencies(
        visited: &mut BTreeSet<Type>,
        sccs: &mut StronglyConnectedComponents,
    );
    /// Describes how values of this type are formed: algebraically (introduction
    /// and elimination rules) or as a literal (generate/shrink/(de)serialize).
    fn type_former() -> TypeFormer<Self>;
    /// Yields values of type `V` found inside `self`.
    /// NOTE(review): presumably recursive ("deep") — confirm against implementations.
    fn visit_deep<V: Construct>(&self) -> impl Iterator<Item = V>;
}
impl<T> CtorFn<T> {
    /// Forgets the concrete result type.
    ///
    /// This produces a `CtorFn<Erased>` that can be stored alongside
    /// constructors of other types. The erased value must only be called after
    /// converting it back with `unerase::<T>` for this same `T`.
    #[inline]
    #[must_use]
    pub const fn erase(self) -> CtorFn<Erased> {
        // SAFETY: both `CtorFn<T>` and `CtorFn<Erased>` wrap a single function
        // pointer, and function pointers have the same size regardless of their
        // signature. Only the static signature changes; callers must restore it
        // with `unerase::<T>` before calling.
        unsafe { mem::transmute::<CtorFn<T>, CtorFn<Erased>>(self) }
    }

    /// Wraps a constructor function.
    ///
    /// `#[must_use]` for consistency with `erase`: building a `CtorFn` and
    /// dropping it has no effect.
    #[inline]
    #[must_use]
    pub const fn new(call: for<'terms> fn(&'terms mut ErasedTermBuckets) -> Option<T>) -> Self {
        Self { call }
    }
}
impl CtorFn<Erased> {
    /// Recovers the typed constructor function from an erased `CtorFn`.
    ///
    /// # Safety
    ///
    /// `self` must have been produced by `CtorFn::<T>::erase` for this exact `T`.
    /// Otherwise the returned function pointer has the wrong signature, and
    /// calling it is undefined behavior.
    #[inline]
    #[must_use]
    pub const unsafe fn unerase<T>(
        self,
    ) -> for<'terms> fn(&'terms mut ErasedTermBuckets) -> Option<T> {
        // SAFETY: per this function's contract, `self` was erased from a
        // `CtorFn<T>`, so transmuting back restores its original type.
        unsafe { mem::transmute::<CtorFn<Erased>, CtorFn<T>>(self) }.call
    }
}
impl<T> fmt::Debug for CtorFn<T> {
    /// Prints an opaque closure placeholder; the function pointer itself is
    /// never shown.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "(|terms| ...)")
    }
}
impl<T> Deref for CtorFn<T> {
    type Target = for<'terms> fn(&'terms mut ErasedTermBuckets) -> Option<T>;

    /// Exposes the wrapped function pointer, so a `CtorFn` can be used (and
    /// called) wherever that pointer is expected.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let Self { call } = self;
        call
    }
}
impl<T> Deref for IndexedCtorFn<T> {
    type Target = CtorFn<T>;

    /// Exposes the wrapped [`CtorFn`], so an `IndexedCtorFn` can be used
    /// wherever the bare constructor is expected.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let Self { call, .. } = self;
        call
    }
}
impl<T> ElimFn<T> {
    /// Forgets the concrete input type.
    ///
    /// This produces an `ElimFn<Erased>` that can be stored alongside
    /// eliminators of other types. The erased value must only be called after
    /// converting it back with `unerase::<T>` for this same `T`.
    #[inline]
    #[must_use]
    pub const fn erase(self) -> ElimFn<Erased> {
        // SAFETY: `ElimFn` is `#[repr(transparent)]` over a single function
        // pointer, and function pointers have the same size regardless of their
        // signature. Only the static signature changes; callers must restore it
        // with `unerase::<T>` before calling.
        unsafe { mem::transmute::<ElimFn<T>, ElimFn<Erased>>(self) }
    }

    /// Wraps an elimination function.
    ///
    /// `#[must_use]` for consistency with `erase`: building an `ElimFn` and
    /// dropping it has no effect.
    #[inline]
    #[must_use]
    pub const fn new(call: fn(T) -> Decomposition) -> Self {
        Self { call }
    }
}
impl<T> fmt::Debug for ElimFn<T> {
    /// Prints an opaque closure placeholder; the function pointer itself is
    /// never shown.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "(|ctor| ...)")
    }
}
impl ElimFn<Erased> {
    /// Recovers the typed elimination function from an erased `ElimFn`.
    ///
    /// # Safety
    ///
    /// `self` must have been produced by `ElimFn::<T>::erase` for this exact `T`.
    /// Otherwise the returned function pointer has the wrong signature, and
    /// calling it is undefined behavior.
    #[inline]
    #[must_use]
    pub const unsafe fn unerase<T>(self) -> fn(T) -> Decomposition {
        // SAFETY: per this function's contract, `self` was erased from an
        // `ElimFn<T>`, so transmuting back restores its original type.
        unsafe { mem::transmute::<ElimFn<Erased>, ElimFn<T>>(self) }.call
    }
}
impl<T> Deref for ElimFn<T> {
    type Target = fn(T) -> Decomposition;

    /// Exposes the wrapped function pointer, so an `ElimFn` can be used (and
    /// called) wherever that pointer is expected.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let Self { call } = self;
        call
    }
}
/// Generates an arbitrary `T`, starting from the size budget `size`.
///
/// Calls [`try_arbitrary`] repeatedly with a copy of the current budget.
/// Whenever an attempt fails with [`MaybeUninstantiable::Retry`], the budget is
/// grown with `Size::_increment` before the next attempt. This keeps looping
/// for as long as attempts report `Retry`.
///
/// Returns `Some(value)` for the first successful attempt, and `None` once an
/// attempt reports [`MaybeUninstantiable::Uninstantiable`].
#[inline]
pub fn arbitrary<T: Construct>(prng: &mut WyRand, mut size: Size) -> Option<T> {
    let outcome = loop {
        let attempt = try_arbitrary::<T>(prng, size._copy());
        if !matches!(attempt, Err(MaybeUninstantiable::Retry)) {
            break attempt;
        }
        size._increment();
    };
    outcome.ok()
}
/// Generates one field of type `T` using `sizes` and `prng`, and appends it to
/// `fields`.
///
/// # Errors
///
/// Returns the error reported by `Sizes::try_arbitrary`. Before returning it,
/// the rest of the budget is dropped with `Sizes::_discard_remaining`.
#[inline]
pub fn push_arbitrary_field<T: Construct>(
    fields: &mut ErasedTermBuckets,
    sizes: &mut Sizes,
    prng: &mut WyRand,
) -> Result<(), MaybeUninstantiable> {
    let term = sizes.try_arbitrary::<T>(prng).inspect_err(|_| {
        sizes._discard_remaining();
    })?;
    fields.push(term);
    Ok(())
}
/// Makes one attempt at generating a `T` within the size budget `size`.
///
/// * Algebraic types: each iteration first calls `size.should_recurse(prng)`.
///   If that returns `true` and `T` has at least one "potential loop"
///   constructor, one of those is picked (index `prng.rand() % len`).
///   Otherwise a "potential leaf" constructor is picked the same way. Next, the
///   budget is split with `Size::_partition_into` and the constructor's fields
///   are generated with its `arbitrary_fields`. Finally the constructor is
///   applied. If it returns `None`, another constructor is picked, for at most
///   256 attempts in total.
/// * Literal types: the type's stored generator is called once.
///
/// # Errors
///
/// * [`MaybeUninstantiable::Uninstantiable`] if a leaf constructor is needed
///   but `T` has none.
/// * [`MaybeUninstantiable::Retry`] if 256 consecutive constructor
///   applications returned `None`.
/// * Any error returned by the chosen constructor's `arbitrary_fields`.
///
/// # Panics
///
/// With debug assertions enabled, panics if a constructor leaves terms unconsumed.
#[inline]
#[expect(
    clippy::needless_pass_by_value,
    reason = "`Size` is intentionally consumed as the total budget for one generation attempt"
)]
pub fn try_arbitrary<T: Construct>(
    prng: &mut WyRand,
    size: Size,
) -> Result<T, MaybeUninstantiable> {
    let info = info::<T>();
    match info.type_former {
        PrecomputedTypeFormer::Algebraic(ref adt) => {
            let potential_loops = adt.potential_loops();
            // Counts constructor applications that returned `None`. Once it would
            // overflow `u8` (after 256 attempts), the attempt is abandoned.
            let mut canary = 0_u8;
            loop {
                // `minus_one` is `true` iff a potential-loop constructor was picked;
                // it is forwarded to `_partition_into` below.
                let (ctor, minus_one) = if size.should_recurse(prng)
                    && let Some(n) = NonZero::new(potential_loops.len())
                {
                    #[expect(
                        clippy::as_conversions,
                        clippy::cast_possible_truncation,
                        reason = "fine: definitely not > `u64::MAX` constructors"
                    )]
                    let i = prng.rand() as usize % n;
                    // SAFETY: `i = _ % n` and `n == potential_loops.len()`, so `i` is in bounds.
                    (unsafe { potential_loops.get_unchecked(i) }, true)
                } else {
                    let potential_leaves = adt.potential_leaves();
                    let Some(n) = NonZero::new(potential_leaves.len()) else {
                        return Err(MaybeUninstantiable::Uninstantiable);
                    };
                    #[expect(
                        clippy::as_conversions,
                        clippy::cast_possible_truncation,
                        reason = "fine: definitely not > `u64::MAX` constructors"
                    )]
                    let i = prng.rand() as usize % n;
                    // SAFETY: `i = _ % n` and `n == potential_leaves.len()`, so `i` is in bounds.
                    (unsafe { potential_leaves.get_unchecked(i) }, false)
                };
                let sizes = size._partition_into(ctor.n_big, prng, minus_one);
                // An error from field generation aborts the whole attempt immediately.
                let mut fields = (ctor.arbitrary_fields)(prng, sizes)?;
                // SAFETY: assumes every constructor in `info::<T>()`'s tables was
                // erased from a `CtorFn<T>` — TODO confirm where `info` is built.
                let result = unsafe { ctor.unerase::<T>() }(&mut fields);
                debug_assert!(
                    fields.is_empty(),
                    "internal `pbt` error: leftover terms after applying a constructor: {fields:#?}",
                );
                if let Some(result) = result {
                    return Ok(result);
                }
                let Some(next_canary) = canary.checked_add(1) else {
                    return Err(MaybeUninstantiable::Retry);
                };
                canary = next_canary;
            }
        }
        PrecomputedTypeFormer::Literal { generate, .. } => {
            // SAFETY: assumes the `generate` stored for `T` is `T`'s own generator
            // with its return type erased to `Erased`; this restores the real type.
            let generate = unsafe {
                mem::transmute::<fn(&mut WyRand) -> Erased, fn(&mut WyRand) -> T>(generate)
            };
            Ok(generate(prng))
        }
    }
}
/// Checks η-expansion for `T`: eliminating a value and re-applying the
/// constructor it came from must give back an equal value.
///
/// Literal types have no constructors and are skipped. For algebraic types, the
/// property is passed to `search::assert_eq` with `32` as its first argument
/// (NOTE(review): presumably a trial count or size bound — confirm in `search`).
/// The reconstructed value is compared against `Some(original)`.
///
/// # Panics
///
/// Panics if the eliminator reports a constructor index outside the registered
/// constructors, or if re-applying the constructor leaves terms unconsumed.
/// Reporting of a value mismatch is up to `search::assert_eq`.
#[inline]
pub fn check_eta_expansion<T: Construct>() {
    let info = info::<T>();
    let PrecomputedTypeFormer::Algebraic(AlgebraicTypeFormer {
        ref all_constructors,
        eliminator,
        ..
    }) = info.type_former
    else {
        return;
    };
    // SAFETY: assumes `info::<T>()` stores the eliminator erased from an
    // `ElimFn<T>`, the same assumption the constructor lookups below make.
    let eliminator = unsafe { eliminator.unerase::<T>() };
    let () = search::assert_eq(32, |orig: &T| {
        let Decomposition {
            ctor_idx,
            mut fields,
        } = eliminator(orig.clone());
        // SAFETY: `ctor_idx` is `NonZero`, so subtracting one cannot underflow.
        let i = unsafe { ctor_idx.get().unchecked_sub(1) };
        // The index comes from the eliminator under test, so validate it rather
        // than trusting it. A bad index must be a clear panic, not an
        // out-of-bounds read.
        assert!(
            i < all_constructors.len(),
            "internal `pbt` error: eliminator returned constructor index {ctor_idx}, but only {} constructors are registered",
            all_constructors.len(),
        );
        // SAFETY: `i` was bounds-checked by the `assert!` directly above.
        let (ctor, _) = *unsafe { all_constructors.get_unchecked(i) };
        // SAFETY: `ctor` comes from `T`'s own constructor table, so it is assumed
        // to have been erased from a `CtorFn<T>`.
        let f = unsafe { ctor.unerase::<T>() };
        let constructed = f(&mut fields);
        assert!(
            fields.is_empty(),
            "internal `pbt` error: leftover terms after applying a constructor: {fields:#?}",
        );
        (constructed, Some(orig.clone()))
    });
}
/// Yields a clone of `s` viewed as a `V` when `V` and `S` are the same type (as
/// decided by `visit_self_opt`), and yields nothing otherwise.
///
/// The clone is made eagerly, before the iterator is returned.
#[inline]
pub fn visit_self<V: Construct, S: Construct>(s: &S) -> impl Iterator<Item = V> {
    let cloned: Option<V> = visit_self_opt::<V, S>(s).cloned();
    cloned.into_iter()
}
/// Reinterprets `s` as a `&V` when `V` and `S` are the same type.
///
/// Returns `None` when `type_of::<V>()` and `type_of::<S>()` differ.
#[inline]
pub fn visit_self_opt<V: Construct, S: Construct>(s: &S) -> Option<&V> {
    if type_of::<V>() != type_of::<S>() {
        return None;
    }
    let as_v = ptr::from_ref(s).cast::<V>();
    // SAFETY: relies on `type_of` reporting equal values only for identical
    // types. Then `s` already points to a valid, aligned `V` that lives as long
    // as `s`.
    Some(unsafe { &*as_v })
}
/// Moves `s` out as a `V` when `V` and `S` are the same type.
///
/// Returns `None` (dropping `s`) when `type_of::<V>()` and `type_of::<S>()` differ.
#[inline]
pub fn visit_self_owned<V: Construct, S: Construct>(s: S) -> Option<V> {
    if type_of::<V>() != type_of::<S>() {
        return None;
    }
    // Ownership of the bytes moves into the returned `V`, so `s` itself must
    // never be dropped.
    let guard = mem::ManuallyDrop::new(s);
    let src = ptr::from_ref(&*guard).cast::<V>();
    // SAFETY: relies on `type_of` reporting equal values only for identical
    // types, so `src` points to a valid, aligned `V`. `guard` is never dropped,
    // so the value is not duplicated.
    Some(unsafe { ptr::read(src) })
}
/// Deserializes `term` as a `T` and, on success, pushes it onto `terms`.
///
/// Returns `true` if the term was deserialized and pushed. Returns `false` if
/// `cache::deserialize_term` produced nothing, in which case `terms` is left
/// untouched.
#[inline]
pub(crate) fn deserialize_cached_term_into_buckets<T: Construct>(
    term: &cache::CachedTerm,
    terms: &mut ErasedTermBuckets,
) -> bool {
    if let Some(value) = cache::deserialize_term::<T>(term) {
        terms.push(value);
        true
    } else {
        false
    }
}