Skip to main content

tsz_solver/
recursion.rs

1//! Unified recursion guard for cycle detection, depth limiting,
2//! and iteration bounding in recursive type computations.
3//!
4//! # Design
5//!
6//! `RecursionGuard` replaces the scattered `in_progress` / `visiting` / `depth` /
7//! `total_checks` fields that were manually reimplemented across `SubtypeChecker`,
8//! `TypeEvaluator`, `PropertyAccessEvaluator`, and others.
9//!
10//! It combines three safety mechanisms:
11//! 1. **Cycle detection** via a visiting set (`FxHashSet<K>`)
12//! 2. **Depth limiting** to prevent stack overflow
13//! 3. **Iteration bounding** to prevent infinite loops
14//!
15//! # Profiles
16//!
17//! [`RecursionProfile`] provides named presets that eliminate magic numbers and
18//! make the intent of each guard clear at the call site:
19//!
20//! ```ignore
21//! // Before (what does 50, 100_000 mean?)
22//! let guard = RecursionGuard::new(50, 100_000);
23//!
24//! // After (intent is clear, limits are centralized)
25//! let guard = RecursionGuard::with_profile(RecursionProfile::TypeEvaluation);
26//! ```
27//!
28//! # Safety
29//!
30//! - **Debug leak detection**: In debug builds, dropping a guard with active entries
31//!   triggers a panic, catching forgotten `leave()` calls.
32//! - **Debug double-leave detection**: In debug builds, leaving a key that isn't in
33//!   the visiting set triggers a panic.
34//! - **Overflow protection**: Iteration counting uses saturating arithmetic.
35//! - **Encapsulated exceeded state**: The `exceeded` flag is private; use
36//!   [`is_exceeded()`](RecursionGuard::is_exceeded) and
37//!   [`mark_exceeded()`](RecursionGuard::mark_exceeded).
38
39use rustc_hash::FxHashSet;
40use std::hash::Hash;
41
42// ---------------------------------------------------------------------------
43// RecursionProfile
44// ---------------------------------------------------------------------------
45
/// Named presets for recursion limits.
///
/// Every profile resolves to a `(max_depth, max_iterations)` pair tuned for
/// one kind of recursive computation. Naming the preset at the call site:
/// - records *why* the guard exists,
/// - keeps every limit value in one place so they can be retuned together,
/// - avoids scattering unexplained numeric literals across the codebase.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecursionProfile {
    /// Structural subtype comparison of (possibly recursive) types.
    ///
    /// Consumers: `SubtypeChecker`, `SubtypeTracer`.
    /// Gets the deepest limit of the solver profiles, because comparing
    /// recursive structures may legitimately nest a long way before a cycle
    /// is detected.
    ///
    /// depth = 100, iterations = 100,000
    SubtypeCheck,

    /// Evaluation of conditional, mapped, and indexed-access types.
    ///
    /// Consumer: `TypeEvaluator` (for both its `TypeId` and `DefId` guards).
    ///
    /// depth = 50, iterations = 100,000
    TypeEvaluation,

    /// Instantiation / application of generic types.
    ///
    /// Consumer: `TypeApplicationEvaluator`.
    /// Intended to line up with TypeScript's TS2589 ("excessively deep")
    /// instantiation limit. NOTE(review): confirm the depth value below
    /// against tsc's own instantiation-depth check, which may differ.
    ///
    /// depth = 50, iterations = 100,000
    TypeApplication,

    /// Resolving property accesses on complex types.
    ///
    /// Consumer: `PropertyAccessEvaluator`.
    ///
    /// depth = 50, iterations = 100,000
    PropertyAccess,

    /// Computing type-parameter variance.
    ///
    /// Consumer: `VarianceVisitor`.
    ///
    /// depth = 50, iterations = 100,000
    Variance,

    /// Extracting shapes for compatibility checks.
    ///
    /// Consumer: `ShapeExtractor`.
    ///
    /// depth = 50, iterations = 100,000
    ShapeExtraction,

    /// Shallow walks such as contains-type checks and type collection.
    ///
    /// Consumers: `RecursiveTypeCollector`, `ContainsTypeChecker`.
    /// Kept deliberately shallow: these only look at top-level structure.
    ///
    /// depth = 20, iterations = 100,000
    ShallowTraversal,

    /// Processing of const assertions.
    ///
    /// Consumer: `ConstAssertionVisitor`.
    ///
    /// depth = 50, iterations = 100,000
    ConstAssertion,

    // ----- Checker profiles -----
    /// Depth of expression type checking.
    ///
    /// Consumer: `ExpressionChecker`.
    /// Generous, so deeply nested expressions are still accepted.
    ///
    /// depth = 500, iterations = 100,000
    ExpressionCheck,

    /// Depth of type-node resolution.
    ///
    /// Consumer: `TypeNodeChecker`.
    /// Generous, so deeply nested type annotations are still accepted.
    ///
    /// depth = 500, iterations = 100,000
    TypeNodeCheck,

    /// Depth of function-call resolution.
    ///
    /// Consumer: `get_type_of_call_expression`.
    /// Kept low so runaway recursion during overload resolution is cut off early.
    ///
    /// depth = 20, iterations = 100,000
    CallResolution,

    /// General-purpose checker recursion depth.
    ///
    /// Consumers: the `enter_recursion` / `leave_recursion` pair on checker functions.
    ///
    /// depth = 50, iterations = 100,000
    CheckerRecursion,

    /// Caller-supplied limits for one-off or test scenarios.
    Custom { max_depth: u32, max_iterations: u32 },
}

impl RecursionProfile {
    /// Iteration budget shared by every built-in (non-`Custom`) profile.
    const STANDARD_ITERATIONS: u32 = 100_000;

    /// Maximum recursion depth for this profile.
    pub const fn max_depth(self) -> u32 {
        match self {
            Self::Custom { max_depth, .. } => max_depth,
            // Checker-side expression / type-node walks: generous.
            Self::ExpressionCheck | Self::TypeNodeCheck => 500,
            Self::SubtypeCheck => 100,
            Self::TypeEvaluation
            | Self::TypeApplication
            | Self::PropertyAccess
            | Self::Variance
            | Self::ShapeExtraction
            | Self::ConstAssertion
            | Self::CheckerRecursion => 50,
            // Shallow walks and call resolution: deliberately tight.
            Self::ShallowTraversal | Self::CallResolution => 20,
        }
    }

    /// Maximum iteration count for this profile.
    pub const fn max_iterations(self) -> u32 {
        match self {
            Self::Custom { max_iterations, .. } => max_iterations,
            Self::SubtypeCheck
            | Self::TypeEvaluation
            | Self::TypeApplication
            | Self::PropertyAccess
            | Self::Variance
            | Self::ShapeExtraction
            | Self::ShallowTraversal
            | Self::ConstAssertion
            | Self::ExpressionCheck
            | Self::TypeNodeCheck
            | Self::CallResolution
            | Self::CheckerRecursion => Self::STANDARD_ITERATIONS,
        }
    }
}
189
190// ---------------------------------------------------------------------------
191// RecursionResult
192// ---------------------------------------------------------------------------
193
/// Outcome of [`RecursionGuard::enter`]: either permission to proceed, or
/// the reason the computation was refused.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RecursionResult {
    /// Entry granted; the computation may proceed.
    Entered,
    /// The key is already being visited, so a cycle was found.
    Cycle,
    /// The nesting-depth limit was reached. The guard also reports this
    /// when its visiting-set size cap is reached.
    DepthExceeded,
    /// The total iteration budget has been spent.
    IterationExceeded,
}

impl RecursionResult {
    /// `true` only for [`Entered`](Self::Entered).
    #[inline]
    pub const fn is_entered(self) -> bool {
        matches!(self, Self::Entered)
    }

    /// `true` only for [`Cycle`](Self::Cycle).
    #[inline]
    pub const fn is_cycle(self) -> bool {
        matches!(self, Self::Cycle)
    }

    /// `true` when a limit (iterations or depth) stopped the entry.
    #[inline]
    pub const fn is_exceeded(self) -> bool {
        matches!(self, Self::IterationExceeded | Self::DepthExceeded)
    }

    /// `true` for every outcome other than [`Entered`](Self::Entered).
    #[inline]
    pub const fn is_denied(self) -> bool {
        !matches!(self, Self::Entered)
    }
}
232
233// ---------------------------------------------------------------------------
234// RecursionGuard
235// ---------------------------------------------------------------------------
236
237/// Tracks recursion state for cycle detection, depth limiting,
238/// and iteration bounding.
239///
240/// # Usage
241///
242/// ```ignore
243/// use crate::recursion::{RecursionGuard, RecursionProfile, RecursionResult};
244///
245/// let mut guard = RecursionGuard::with_profile(RecursionProfile::TypeEvaluation);
246///
247/// match guard.enter(key) {
248///     RecursionResult::Entered => {
249///         let result = do_work();
250///         guard.leave(key);
251///         result
252///     }
253///     RecursionResult::Cycle => handle_cycle(),
254///     RecursionResult::DepthExceeded
255///     | RecursionResult::IterationExceeded => handle_exceeded(),
256/// }
257/// ```
258///
259/// # Debug-mode safety
260///
261/// In debug builds (`#[cfg(debug_assertions)]`):
262/// - Dropping a guard with entries still in the visiting set panics.
263/// - Calling `leave(key)` with a key not in the visiting set panics.
264pub struct RecursionGuard<K: Hash + Eq + Copy> {
265    visiting: FxHashSet<K>,
266    depth: u32,
267    iterations: u32,
268    max_depth: u32,
269    max_iterations: u32,
270    max_visiting: u32,
271    exceeded: bool,
272}
273
274impl<K: Hash + Eq + Copy> RecursionGuard<K> {
275    /// Create a guard with explicit limits.
276    ///
277    /// Prefer [`with_profile`](Self::with_profile) for standard use cases.
278    pub fn new(max_depth: u32, max_iterations: u32) -> Self {
279        Self {
280            visiting: FxHashSet::default(),
281            depth: 0,
282            iterations: 0,
283            max_depth,
284            max_iterations,
285            max_visiting: 10_000,
286            exceeded: false,
287        }
288    }
289
290    /// Create a guard from a named [`RecursionProfile`].
291    pub fn with_profile(profile: RecursionProfile) -> Self {
292        Self::new(profile.max_depth(), profile.max_iterations())
293    }
294
295    /// Builder: set a custom max visiting-set size.
296    pub const fn with_max_visiting(mut self, max_visiting: u32) -> Self {
297        self.max_visiting = max_visiting;
298        self
299    }
300
301    // -----------------------------------------------------------------------
302    // Core enter / leave API
303    // -----------------------------------------------------------------------
304
305    /// Try to enter a recursive computation for `key`.
306    ///
307    /// Returns [`RecursionResult::Entered`] if the computation may proceed.
308    /// On success the caller **must** call [`leave`](Self::leave) with the
309    /// same key when done.
310    ///
311    /// The other variants indicate why entry was denied:
312    /// - [`Cycle`](RecursionResult::Cycle): `key` is already being visited.
313    /// - [`DepthExceeded`](RecursionResult::DepthExceeded): nesting is too deep.
314    /// - [`IterationExceeded`](RecursionResult::IterationExceeded): total work budget exhausted.
315    pub fn enter(&mut self, key: K) -> RecursionResult {
316        // Saturating add prevents overflow with very high max_iterations.
317        self.iterations = self.iterations.saturating_add(1);
318
319        if self.iterations > self.max_iterations {
320            self.exceeded = true;
321            return RecursionResult::IterationExceeded;
322        }
323        if self.depth >= self.max_depth {
324            self.exceeded = true;
325            return RecursionResult::DepthExceeded;
326        }
327        if self.visiting.contains(&key) {
328            return RecursionResult::Cycle;
329        }
330        if self.visiting.len() as u32 >= self.max_visiting {
331            self.exceeded = true;
332            return RecursionResult::DepthExceeded;
333        }
334
335        self.visiting.insert(key);
336        self.depth += 1;
337        RecursionResult::Entered
338    }
339
340    /// Leave a recursive computation for `key`.
341    ///
342    /// **Must** be called exactly once after every successful [`enter`](Self::enter).
343    ///
344    /// # Debug panics
345    ///
346    /// In debug builds, panics if `key` is not in the visiting set (double-leave
347    /// or leave without matching enter).
348    pub fn leave(&mut self, key: K) {
349        let was_present = self.visiting.remove(&key);
350
351        debug_assert!(
352            was_present,
353            "RecursionGuard::leave() called with a key that is not in the visiting set. \
354             This indicates a double-leave or a leave without a matching enter()."
355        );
356
357        self.depth = self.depth.saturating_sub(1);
358    }
359
360    // -----------------------------------------------------------------------
361    // Closure-based RAII helper
362    // -----------------------------------------------------------------------
363
364    /// Execute `f` inside a guarded scope.
365    ///
366    /// Calls `enter(key)`, runs `f` if entered, then calls `leave(key)`.
367    /// Returns `Ok(value)` on success or `Err(reason)` if entry was denied.
368    ///
369    /// This is the safest API when the guard is standalone (not a field of a
370    /// struct that `f` also needs to mutate).
371    ///
372    /// # Panic safety
373    ///
374    /// If `f` panics, `leave()` is **not** called — the entry leaks. This is
375    /// safe because the guard's `Drop` impl (debug builds) checks
376    /// `std::thread::panicking()` and suppresses the leak-detection panic
377    /// during unwinding.
378    pub fn scope<T>(&mut self, key: K, f: impl FnOnce() -> T) -> Result<T, RecursionResult> {
379        match self.enter(key) {
380            RecursionResult::Entered => {
381                let result = f();
382                self.leave(key);
383                Ok(result)
384            }
385            denied => Err(denied),
386        }
387    }
388
389    // -----------------------------------------------------------------------
390    // Query API
391    // -----------------------------------------------------------------------
392
393    /// Check if `key` is currently being visited (without entering).
394    #[inline]
395    pub fn is_visiting(&self, key: &K) -> bool {
396        self.visiting.contains(key)
397    }
398
399    /// Check if any currently-visiting key satisfies the predicate.
400    ///
401    /// Used for symbol-level cycle detection: the same interface may appear
402    /// with different `DefIds` in different checker contexts, so we need to
403    /// check all visiting entries for symbol-level matches.
404    pub fn is_visiting_any(&self, predicate: impl Fn(&K) -> bool) -> bool {
405        self.visiting.iter().any(predicate)
406    }
407
408    /// Current recursion depth (number of active entries on the stack).
409    #[inline]
410    pub const fn depth(&self) -> u32 {
411        self.depth
412    }
413
414    /// Total enter attempts so far (successful or not).
415    #[inline]
416    pub const fn iterations(&self) -> u32 {
417        self.iterations
418    }
419
420    /// Number of keys currently in the visiting set.
421    #[inline]
422    pub fn visiting_count(&self) -> usize {
423        self.visiting.len()
424    }
425
426    /// Returns `true` if the guard has any active entries.
427    #[inline]
428    pub const fn is_active(&self) -> bool {
429        self.depth > 0
430    }
431
432    /// The configured maximum depth.
433    #[inline]
434    pub const fn max_depth(&self) -> u32 {
435        self.max_depth
436    }
437
438    /// The configured maximum iterations.
439    #[inline]
440    pub const fn max_iterations(&self) -> u32 {
441        self.max_iterations
442    }
443
444    // -----------------------------------------------------------------------
445    // Exceeded-state management
446    // -----------------------------------------------------------------------
447
448    /// Returns `true` if any limit was previously exceeded.
449    ///
450    /// Once set, this flag stays `true` until [`reset()`](Self::reset) is called.
451    /// This is sticky: even if depth later decreases below the limit, the flag
452    /// remains set. This is intentional — callers use it to bail out early on
453    /// subsequent calls (e.g. TS2589 "excessively deep" diagnostics).
454    #[inline]
455    pub const fn is_exceeded(&self) -> bool {
456        self.exceeded
457    }
458
459    /// Manually mark the guard as exceeded.
460    ///
461    /// Useful when an external condition (e.g. distribution size limit) means
462    /// further recursion should be blocked.
463    #[inline]
464    pub const fn mark_exceeded(&mut self) {
465        self.exceeded = true;
466    }
467
468    // -----------------------------------------------------------------------
469    // Reset
470    // -----------------------------------------------------------------------
471
472    /// Reset all state while preserving configured limits.
473    ///
474    /// After reset the guard behaves as if freshly constructed.
475    pub fn reset(&mut self) {
476        self.visiting.clear();
477        self.depth = 0;
478        self.iterations = 0;
479        self.exceeded = false;
480    }
481}
482
483// ---------------------------------------------------------------------------
484// Debug-mode leak detection
485// ---------------------------------------------------------------------------
486
487#[cfg(debug_assertions)]
488impl<K: Hash + Eq + Copy> Drop for RecursionGuard<K> {
489    fn drop(&mut self) {
490        if !std::thread::panicking() && !self.visiting.is_empty() {
491            panic!(
492                "RecursionGuard dropped with {} active entries still in the visiting set. \
493                 This indicates leaked enter() calls without matching leave() calls.",
494                self.visiting.len(),
495            );
496        }
497    }
498}
499
500// ---------------------------------------------------------------------------
501// DepthCounter — depth-only guard (no cycle detection)
502// ---------------------------------------------------------------------------
503
504/// A lightweight depth counter for stack overflow protection.
505///
506/// Unlike [`RecursionGuard`], `DepthCounter` does not track which keys are
507/// being visited — it only limits nesting depth. Use this when:
508/// - The same node/key may be legitimately revisited (e.g., expression
509///   re-checking with different contextual types)
510/// - You only need stack overflow protection, not cycle detection
511///
512/// # Safety
513///
514/// Shares the same debug-mode safety features as `RecursionGuard`:
515/// - **Debug leak detection**: Dropping with depth > 0 panics.
516/// - **Debug underflow detection**: Calling `leave()` at depth 0 panics.
517///
518/// # Usage
519///
520/// ```ignore
521/// let mut counter = DepthCounter::with_profile(RecursionProfile::ExpressionCheck);
522///
523/// if !counter.enter() {
524///     return TypeId::ERROR; // depth exceeded
525/// }
526/// let result = do_work();
527/// counter.leave();
528/// result
529/// ```
530#[derive(Debug)]
531pub struct DepthCounter {
532    depth: u32,
533    max_depth: u32,
534    exceeded: bool,
535    /// The depth at construction time. Used to distinguish inherited depth
536    /// from depth added by this counter's own `enter()` calls.
537    /// Debug leak detection only fires if `depth > base_depth`.
538    base_depth: u32,
539}
540
541impl DepthCounter {
542    /// Create a counter with an explicit max depth.
543    ///
544    /// Prefer [`with_profile`](Self::with_profile) for standard use cases.
545    pub const fn new(max_depth: u32) -> Self {
546        Self {
547            depth: 0,
548            max_depth,
549            exceeded: false,
550            base_depth: 0,
551        }
552    }
553
554    /// Create a counter from a named [`RecursionProfile`].
555    ///
556    /// Only the profile's `max_depth` is used (iterations are not relevant
557    /// for a depth-only counter).
558    pub const fn with_profile(profile: RecursionProfile) -> Self {
559        Self::new(profile.max_depth())
560    }
561
562    /// Create a counter with an initial depth already set.
563    ///
564    /// Used when inheriting depth from a parent context to maintain
565    /// the overall depth limit across context boundaries. The inherited
566    /// depth is treated as the "base" — debug leak detection only fires
567    /// if depth exceeds this base at drop time.
568    pub const fn with_initial_depth(max_depth: u32, initial_depth: u32) -> Self {
569        Self {
570            depth: initial_depth,
571            max_depth,
572            exceeded: false,
573            base_depth: initial_depth,
574        }
575    }
576
577    /// Try to enter a deeper level.
578    ///
579    /// Returns `true` if the depth limit has not been reached and entry
580    /// is allowed. The caller **must** call [`leave`](Self::leave) when done.
581    ///
582    /// Returns `false` if the depth limit has been reached. The `exceeded`
583    /// flag is set and the depth is **not** incremented — do **not** call
584    /// `leave()` in this case.
585    #[inline]
586    pub const fn enter(&mut self) -> bool {
587        if self.depth >= self.max_depth {
588            self.exceeded = true;
589            return false;
590        }
591        self.depth += 1;
592        true
593    }
594
595    /// Leave the current depth level.
596    ///
597    /// **Must** be called exactly once after every successful [`enter`](Self::enter).
598    ///
599    /// # Debug panics
600    ///
601    /// In debug builds, panics if depth is already 0 (leave without enter).
602    #[inline]
603    pub fn leave(&mut self) {
604        debug_assert!(
605            self.depth > 0,
606            "DepthCounter::leave() called at depth 0. \
607             This indicates a leave without a matching enter()."
608        );
609        self.depth = self.depth.saturating_sub(1);
610    }
611
612    /// Current depth.
613    #[inline]
614    pub const fn depth(&self) -> u32 {
615        self.depth
616    }
617
618    /// The configured maximum depth.
619    #[inline]
620    pub const fn max_depth(&self) -> u32 {
621        self.max_depth
622    }
623
624    /// Returns `true` if the depth limit was previously exceeded.
625    ///
626    /// Sticky — stays `true` until [`reset`](Self::reset).
627    #[inline]
628    pub const fn is_exceeded(&self) -> bool {
629        self.exceeded
630    }
631
632    /// Manually mark as exceeded.
633    #[inline]
634    pub const fn mark_exceeded(&mut self) {
635        self.exceeded = true;
636    }
637
638    /// Reset to initial state, preserving the max depth and base depth.
639    pub const fn reset(&mut self) {
640        self.depth = self.base_depth;
641        self.exceeded = false;
642    }
643}
644
645#[cfg(debug_assertions)]
646impl Drop for DepthCounter {
647    fn drop(&mut self) {
648        if !std::thread::panicking() && self.depth > self.base_depth {
649            panic!(
650                "DepthCounter dropped with depth {} > base_depth {}. \
651                 This indicates leaked enter() calls without matching leave() calls.",
652                self.depth, self.base_depth,
653            );
654        }
655    }
656}
657
658// ---------------------------------------------------------------------------
659// Tests
660// ---------------------------------------------------------------------------
661
// Unit tests for this module live in a separate file (see the `#[path]`
// attribute) and are compiled only for test builds.
662#[cfg(test)]
663#[path = "../tests/recursion_tests.rs"]
664mod tests;