use rustc_hash::FxHashSet;
use std::hash::Hash;
/// Named presets of recursion limits.
///
/// Every built-in preset shares the same iteration budget (100,000) and differs
/// only in its maximum depth; [`RecursionProfile::Custom`] supplies both limits
/// explicitly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecursionProfile {
    /// Depth limit 100.
    SubtypeCheck,
    /// Depth limit 50.
    TypeEvaluation,
    /// Depth limit 50.
    TypeApplication,
    /// Depth limit 50.
    PropertyAccess,
    /// Depth limit 50.
    Variance,
    /// Depth limit 50.
    ShapeExtraction,
    /// Depth limit 20.
    ShallowTraversal,
    /// Depth limit 50.
    ConstAssertion,
    /// Depth limit 500.
    ExpressionCheck,
    /// Depth limit 500.
    TypeNodeCheck,
    /// Depth limit 20.
    CallResolution,
    /// Depth limit 50.
    CheckerRecursion,
    /// Caller-chosen limits, returned verbatim by the accessors.
    Custom { max_depth: u32, max_iterations: u32 },
}

impl RecursionProfile {
    /// Iteration budget shared by every built-in (non-`Custom`) preset.
    const BUILTIN_MAX_ITERATIONS: u32 = 100_000;

    /// Maximum nesting depth permitted under this preset.
    pub const fn max_depth(self) -> u32 {
        match self {
            Self::Custom { max_depth, .. } => max_depth,
            Self::ExpressionCheck | Self::TypeNodeCheck => 500,
            Self::SubtypeCheck => 100,
            Self::TypeEvaluation
            | Self::TypeApplication
            | Self::PropertyAccess
            | Self::Variance
            | Self::ShapeExtraction
            | Self::ConstAssertion
            | Self::CheckerRecursion => 50,
            Self::ShallowTraversal | Self::CallResolution => 20,
        }
    }

    /// Maximum number of entry attempts permitted under this preset.
    pub const fn max_iterations(self) -> u32 {
        match self {
            Self::Custom { max_iterations, .. } => max_iterations,
            // Listed exhaustively so a new variant forces an explicit decision here.
            Self::SubtypeCheck
            | Self::TypeEvaluation
            | Self::TypeApplication
            | Self::PropertyAccess
            | Self::Variance
            | Self::ShapeExtraction
            | Self::ShallowTraversal
            | Self::ConstAssertion
            | Self::ExpressionCheck
            | Self::TypeNodeCheck
            | Self::CallResolution
            | Self::CheckerRecursion => Self::BUILTIN_MAX_ITERATIONS,
        }
    }
}
/// Outcome of an attempt to enter a key in a `RecursionGuard`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecursionResult {
    /// The key was entered; the caller must later leave it with the same key.
    Entered,
    /// The key is already in the set of keys currently being visited.
    Cycle,
    /// The depth limit, or the visiting-set cap, was reached.
    DepthExceeded,
    /// The iteration budget was exhausted.
    IterationExceeded,
}

impl RecursionResult {
    /// `true` only for [`RecursionResult::Entered`].
    #[inline]
    pub const fn is_entered(self) -> bool {
        matches!(self, RecursionResult::Entered)
    }

    /// `true` only for [`RecursionResult::Cycle`].
    #[inline]
    pub const fn is_cycle(self) -> bool {
        matches!(self, RecursionResult::Cycle)
    }

    /// `true` when a limit (depth or iterations) was hit; a cycle is not a limit.
    #[inline]
    pub const fn is_exceeded(self) -> bool {
        matches!(
            self,
            RecursionResult::IterationExceeded | RecursionResult::DepthExceeded
        )
    }

    /// `true` for every outcome other than [`RecursionResult::Entered`].
    #[inline]
    pub const fn is_denied(self) -> bool {
        !matches!(self, RecursionResult::Entered)
    }
}
/// Tracks recursive traversal keyed by `K`: detects cycles (re-entering a key
/// that is still active) and enforces depth, iteration, and visiting-set limits.
///
/// Only when `enter(key)` returns `RecursionResult::Entered` must the caller
/// later call `leave(key)`. In debug builds, dropping a guard that still has
/// active entries panics (unless the thread is already panicking).
pub struct RecursionGuard<K: Hash + Eq + Copy> {
    /// Keys currently entered (inserted by `enter`, removed by `leave`).
    visiting: FxHashSet<K>,
    /// Current nesting depth (incremented on a successful `enter`, decremented by `leave`).
    depth: u32,
    /// Total `enter` calls since creation or `reset`, including refused ones (saturating).
    iterations: u32,
    /// `enter` is refused once `depth` reaches this value.
    max_depth: u32,
    /// `enter` is refused once `iterations` exceeds this value.
    max_iterations: u32,
    /// `enter` is refused once the visiting set holds this many keys (10_000 via `new`).
    max_visiting: u32,
    /// Sticky flag set when a limit is hit (or via `mark_exceeded`); cleared by `reset`.
    exceeded: bool,
}
impl<K: Hash + Eq + Copy> RecursionGuard<K> {
    /// Creates a guard with explicit depth and iteration limits.
    ///
    /// The visiting-set cap defaults to 10,000 keys; override it with
    /// [`with_max_visiting`](Self::with_max_visiting).
    pub fn new(max_depth: u32, max_iterations: u32) -> Self {
        Self {
            visiting: FxHashSet::default(),
            depth: 0,
            iterations: 0,
            max_depth,
            max_iterations,
            max_visiting: 10_000,
            exceeded: false,
        }
    }

    /// Creates a guard using the depth and iteration limits of `profile`.
    pub fn with_profile(profile: RecursionProfile) -> Self {
        Self::new(profile.max_depth(), profile.max_iterations())
    }

    /// Builder-style override of the maximum number of simultaneously active keys.
    pub const fn with_max_visiting(mut self, max_visiting: u32) -> Self {
        self.max_visiting = max_visiting;
        self
    }

    /// Attempts to enter `key`.
    ///
    /// Every call counts toward the iteration budget, whether or not it succeeds.
    /// The checks run in this order:
    /// 1. iteration budget exhausted -> [`RecursionResult::IterationExceeded`]
    /// 2. depth limit reached -> [`RecursionResult::DepthExceeded`]
    /// 3. `key` already active -> [`RecursionResult::Cycle`]
    /// 4. visiting-set cap reached -> [`RecursionResult::DepthExceeded`]
    ///
    /// Limit violations (1, 2, 4) also set the sticky exceeded flag; a cycle does not.
    /// Only an [`RecursionResult::Entered`] result must be paired with [`leave`](Self::leave).
    pub fn enter(&mut self, key: K) -> RecursionResult {
        self.iterations = self.iterations.saturating_add(1);
        if self.iterations > self.max_iterations {
            self.exceeded = true;
            return RecursionResult::IterationExceeded;
        }
        if self.depth >= self.max_depth {
            self.exceeded = true;
            return RecursionResult::DepthExceeded;
        }
        if self.visiting.contains(&key) {
            return RecursionResult::Cycle;
        }
        // Compare in `usize` rather than truncating `len()` with `as u32`, which could
        // wrap around and slip past the cap. `u32 -> usize` is lossless on 32/64-bit targets.
        if self.visiting.len() >= self.max_visiting as usize {
            self.exceeded = true;
            // There is no dedicated variant for the visiting cap; it is reported as depth.
            return RecursionResult::DepthExceeded;
        }
        self.visiting.insert(key);
        self.depth += 1;
        RecursionResult::Entered
    }

    /// Leaves `key`, undoing a successful [`enter`](Self::enter).
    ///
    /// The depth is decremented only if `key` was actually active. A stray or
    /// duplicated `leave` in a release build therefore cannot desynchronise `depth`
    /// from the visiting set. Debug builds assert on such misuse.
    pub fn leave(&mut self, key: K) {
        let was_present = self.visiting.remove(&key);
        debug_assert!(
            was_present,
            "RecursionGuard::leave() called with a key that is not in the visiting set. \
             This indicates a double-leave or a leave without a matching enter()."
        );
        if was_present {
            self.depth = self.depth.saturating_sub(1);
        }
    }

    /// Runs `f` between `enter(key)` and `leave(key)`.
    ///
    /// # Errors
    /// Returns the denial reason without calling `f` if `enter` does not return
    /// [`RecursionResult::Entered`].
    ///
    /// The entry is released even if `f` panics, so a caught panic cannot leave
    /// `key` permanently marked as visiting.
    pub fn scope<T>(&mut self, key: K, f: impl FnOnce() -> T) -> Result<T, RecursionResult> {
        /// Calls `leave(key)` on drop, including during unwinding.
        struct LeaveOnDrop<'a, K: Hash + Eq + Copy> {
            guard: &'a mut RecursionGuard<K>,
            key: K,
        }

        impl<K: Hash + Eq + Copy> Drop for LeaveOnDrop<'_, K> {
            fn drop(&mut self) {
                self.guard.leave(self.key);
            }
        }

        match self.enter(key) {
            RecursionResult::Entered => {
                // Dropped after `f()` returns (or while unwinding), so `leave` runs after `f`.
                let _leave = LeaveOnDrop { guard: self, key };
                Ok(f())
            }
            denied => Err(denied),
        }
    }

    /// Returns `true` if `key` is currently entered.
    #[inline]
    pub fn is_visiting(&self, key: &K) -> bool {
        self.visiting.contains(key)
    }

    /// Returns `true` if any currently entered key satisfies `predicate`.
    pub fn is_visiting_any(&self, predicate: impl Fn(&K) -> bool) -> bool {
        self.visiting.iter().any(predicate)
    }

    /// Current nesting depth.
    #[inline]
    pub const fn depth(&self) -> u32 {
        self.depth
    }

    /// Number of `enter` calls since creation or the last `reset`, refused ones included.
    #[inline]
    pub const fn iterations(&self) -> u32 {
        self.iterations
    }

    /// Number of keys currently entered.
    #[inline]
    pub fn visiting_count(&self) -> usize {
        self.visiting.len()
    }

    /// Returns `true` while at least one entry is active.
    #[inline]
    pub const fn is_active(&self) -> bool {
        self.depth > 0
    }

    /// Configured depth limit.
    #[inline]
    pub const fn max_depth(&self) -> u32 {
        self.max_depth
    }

    /// Configured iteration budget.
    #[inline]
    pub const fn max_iterations(&self) -> u32 {
        self.max_iterations
    }

    /// Returns `true` if any limit has been hit since creation or the last `reset`.
    #[inline]
    pub const fn is_exceeded(&self) -> bool {
        self.exceeded
    }

    /// Sets the sticky exceeded flag manually.
    #[inline]
    pub const fn mark_exceeded(&mut self) {
        self.exceeded = true;
    }

    /// Clears all traversal state (visiting set, depth, iteration count, exceeded flag).
    /// The configured limits are kept.
    pub fn reset(&mut self) {
        self.visiting.clear();
        self.depth = 0;
        self.iterations = 0;
        self.exceeded = false;
    }
}
#[cfg(debug_assertions)]
impl<K: Hash + Eq + Copy> Drop for RecursionGuard<K> {
fn drop(&mut self) {
if !std::thread::panicking() && !self.visiting.is_empty() {
panic!(
"RecursionGuard dropped with {} active entries still in the visiting set. \
This indicates leaked enter() calls without matching leave() calls.",
self.visiting.len(),
);
}
}
}
/// Depth-only limiter: counts nesting levels against a maximum, without cycle
/// detection or an iteration budget.
#[derive(Debug)]
pub struct DepthCounter {
    /// Current depth.
    depth: u32,
    /// `enter` fails once `depth` reaches this value.
    max_depth: u32,
    /// Sticky flag set when `enter` hits the limit (or via `mark_exceeded`); cleared by `reset`.
    exceeded: bool,
    /// Depth the counter started at; `reset` restores it and the debug drop check compares against it.
    base_depth: u32,
}
impl DepthCounter {
    /// Creates a counter at depth 0 with the given depth limit.
    pub const fn new(max_depth: u32) -> Self {
        Self {
            depth: 0,
            max_depth,
            exceeded: false,
            base_depth: 0,
        }
    }

    /// Creates a counter using the depth limit of `profile`; its iteration budget is unused.
    pub const fn with_profile(profile: RecursionProfile) -> Self {
        Self::new(profile.max_depth())
    }

    /// Creates a counter that starts at, and resets to, `initial_depth`.
    ///
    /// `initial_depth` acts as the counter's floor: `reset` returns to it and the
    /// debug leak check on drop compares against it. If `initial_depth >= max_depth`,
    /// every `enter` fails.
    pub const fn with_initial_depth(max_depth: u32, initial_depth: u32) -> Self {
        Self {
            depth: initial_depth,
            max_depth,
            exceeded: false,
            base_depth: initial_depth,
        }
    }

    /// Tries to go one level deeper.
    ///
    /// Returns `false` and sets the sticky exceeded flag when the limit is reached;
    /// in that case the caller must not call [`leave`](Self::leave).
    #[inline]
    pub const fn enter(&mut self) -> bool {
        if self.depth >= self.max_depth {
            self.exceeded = true;
            return false;
        }
        self.depth += 1;
        true
    }

    /// Goes back one level, undoing a successful [`enter`](Self::enter).
    ///
    /// Debug builds assert that the counter is above its base depth. A leave at
    /// the base is unmatched even when the base is non-zero, which the old
    /// `depth > 0` check missed.
    #[inline]
    pub fn leave(&mut self) {
        debug_assert!(
            self.depth > self.base_depth,
            "DepthCounter::leave() called at depth {} (base_depth {}). \
             This indicates a leave without a matching enter().",
            self.depth,
            self.base_depth,
        );
        self.depth = self.depth.saturating_sub(1);
    }

    /// Current depth (including any initial depth).
    #[inline]
    pub const fn depth(&self) -> u32 {
        self.depth
    }

    /// Configured depth limit.
    #[inline]
    pub const fn max_depth(&self) -> u32 {
        self.max_depth
    }

    /// Returns `true` if the limit has been hit since creation or the last `reset`.
    #[inline]
    pub const fn is_exceeded(&self) -> bool {
        self.exceeded
    }

    /// Sets the sticky exceeded flag manually.
    #[inline]
    pub const fn mark_exceeded(&mut self) {
        self.exceeded = true;
    }

    /// Restores the starting depth and clears the exceeded flag.
    pub const fn reset(&mut self) {
        self.depth = self.base_depth;
        self.exceeded = false;
    }
}
#[cfg(debug_assertions)]
impl Drop for DepthCounter {
    /// Debug-only leak check: the counter must be back at (or below) its starting
    /// depth when dropped. The check is skipped while the thread is already panicking.
    fn drop(&mut self) {
        let (depth, base) = (self.depth, self.base_depth);
        if depth <= base || std::thread::panicking() {
            return;
        }
        panic!(
            "DepthCounter dropped with depth {depth} > base_depth {base}. \
             This indicates leaked enter() calls without matching leave() calls."
        );
    }
}
#[cfg(test)]
#[path = "../tests/recursion_tests.rs"]
mod tests;