// alopex_sql/parser/recursion.rs

1use crate::error::ParserError;
2
/// Default nesting limit for a [`RecursionCounter`].
///
/// NOTE(review): presumably the `max_depth` the parser passes to
/// [`RecursionCounter::new`]; that call site is outside this module — confirm.
pub const DEFAULT_RECURSION_LIMIT: usize = 50;
4
5/// Prevents runaway recursion during parsing.
6#[derive(Debug, Clone)]
7pub struct RecursionCounter {
8    max_depth: usize,
9    remaining_depth: std::cell::Cell<usize>,
10}
11
12impl RecursionCounter {
13    pub fn new(max_depth: usize) -> Self {
14        Self {
15            max_depth,
16            remaining_depth: std::cell::Cell::new(max_depth),
17        }
18    }
19
20    pub fn try_decrease(&self) -> Result<DepthGuard, ParserError> {
21        let remaining = self.remaining_depth.get();
22        if remaining == 0 {
23            Err(ParserError::RecursionLimitExceeded {
24                depth: self.max_depth + 1, // actual depth attempted
25            })
26        } else {
27            self.remaining_depth.set(remaining - 1);
28            Ok(DepthGuard {
29                counter: self as *const RecursionCounter,
30            })
31        }
32    }
33
34    pub fn current_depth(&self) -> usize {
35        self.max_depth - self.remaining_depth.get()
36    }
37}
38
39#[derive(Debug)]
40pub struct DepthGuard {
41    counter: *const RecursionCounter,
42}
43
44impl Drop for DepthGuard {
45    fn drop(&mut self) {
46        // Safety: counter points to self.recursion inside Parser; guard does not outlive Parser.
47        let counter = unsafe { &*self.counter };
48        let old_value = counter.remaining_depth.get();
49        counter.remaining_depth.set(old_value + 1);
50    }
51}
52
#[cfg(test)]
mod tests {
    use super::*;

    /// With a limit of 2, the third nested entry must fail and report depth 3.
    #[test]
    fn reports_overflow_depth_as_limit_plus_one() {
        let limit = 2;
        let counter = RecursionCounter::new(limit);
        let _outer = counter.try_decrease().unwrap();
        let _inner = counter.try_decrease().unwrap();

        let err = counter.try_decrease().unwrap_err();
        if let ParserError::RecursionLimitExceeded { depth } = &err {
            assert_eq!(*depth, limit + 1);
        } else {
            panic!("expected recursion limit error, got {:?}", err);
        }
    }
}
68}