alopex_sql/parser/
recursion.rs1use crate::error::ParserError;
2
3pub const DEFAULT_RECURSION_LIMIT: usize = 50;
4
5#[derive(Debug, Clone)]
7pub struct RecursionCounter {
8 max_depth: usize,
9 remaining_depth: std::cell::Cell<usize>,
10}
11
12impl RecursionCounter {
13 pub fn new(max_depth: usize) -> Self {
14 Self {
15 max_depth,
16 remaining_depth: std::cell::Cell::new(max_depth),
17 }
18 }
19
20 pub fn try_decrease(&self) -> Result<DepthGuard, ParserError> {
21 let remaining = self.remaining_depth.get();
22 if remaining == 0 {
23 Err(ParserError::RecursionLimitExceeded {
24 depth: self.max_depth + 1, })
26 } else {
27 self.remaining_depth.set(remaining - 1);
28 Ok(DepthGuard {
29 counter: self as *const RecursionCounter,
30 })
31 }
32 }
33
34 pub fn current_depth(&self) -> usize {
35 self.max_depth - self.remaining_depth.get()
36 }
37}
38
39#[derive(Debug)]
40pub struct DepthGuard {
41 counter: *const RecursionCounter,
42}
43
44impl Drop for DepthGuard {
45 fn drop(&mut self) {
46 let counter = unsafe { &*self.counter };
48 let old_value = counter.remaining_depth.get();
49 counter.remaining_depth.set(old_value + 1);
50 }
51}
52
#[cfg(test)]
mod tests {
    use super::*;

    /// Once a limit of two is exhausted, asking for a third level must fail
    /// and report depth 3 (the limit plus one).
    #[test]
    fn reports_overflow_depth_as_limit_plus_one() {
        let counter = RecursionCounter::new(2);
        let _outer = counter.try_decrease().unwrap();
        let _inner = counter.try_decrease().unwrap();

        let reported = match counter.try_decrease().unwrap_err() {
            ParserError::RecursionLimitExceeded { depth } => depth,
            unexpected => panic!("expected recursion limit error, got {:?}", unexpected),
        };
        assert_eq!(reported, 3);
    }
}
68}