use super::{ParseError, ParseResult};
use std::collections::HashSet;
/// Default nesting limit used by [`StackSafeContext::new`] as its `max_depth`.
pub const MAX_RECURSION_DEPTH: usize = 1000;
/// Default wall-clock budget, in seconds, used by [`StackSafeContext::new`]; enforced by
/// [`StackSafeContext::check_timeout`] on every target except `wasm32`.
pub const PARSING_TIMEOUT_SECS: u64 = 120;
/// Bookkeeping carried through a recursive parse to bound nesting depth, detect
/// circular object references and (outside `wasm32`) enforce a wall-clock timeout.
#[derive(Debug)]
pub struct StackSafeContext {
/// Current nesting depth: raised by `enter`, lowered by `exit`.
pub depth: usize,
/// `enter` fails instead of raising `depth` past this value.
pub max_depth: usize,
/// `(obj_num, gen_num)` references currently being resolved, most recent last;
/// `push_ref` rejects a reference that is already present here.
pub active_stack: Vec<(u32, u16)>,
/// References that have been popped off `active_stack` by `pop_ref`.
/// NOTE(review): only written and cloned in this file; readers, if any, live elsewhere.
pub completed_refs: HashSet<(u32, u16)>,
/// Instant the timeout is measured from (set at construction, copied by `child`).
#[cfg(not(target_arch = "wasm32"))]
pub start_time: std::time::Instant,
/// Maximum time allowed to elapse since `start_time` before `check_timeout` fails.
#[cfg(not(target_arch = "wasm32"))]
pub timeout: std::time::Duration,
}
impl Default for StackSafeContext {
fn default() -> Self {
Self::new()
}
}
impl StackSafeContext {
    /// Creates a context with the default limits: [`MAX_RECURSION_DEPTH`] nesting levels and
    /// a [`PARSING_TIMEOUT_SECS`]-second timeout whose clock starts now.
    pub fn new() -> Self {
        Self::with_limits(MAX_RECURSION_DEPTH, PARSING_TIMEOUT_SECS)
    }

    /// Creates a context with a custom nesting limit and timeout (in seconds).
    ///
    /// The timeout clock starts when this constructor runs. On `wasm32` no clock is
    /// tracked, so `timeout_secs` is ignored there.
    // `timeout_secs` is only read on non-wasm targets; silence the lint exactly where it is unused.
    #[cfg_attr(target_arch = "wasm32", allow(unused_variables))]
    pub fn with_limits(max_depth: usize, timeout_secs: u64) -> Self {
        Self {
            depth: 0,
            max_depth,
            active_stack: Vec::new(),
            completed_refs: HashSet::new(),
            #[cfg(not(target_arch = "wasm32"))]
            start_time: std::time::Instant::now(),
            #[cfg(not(target_arch = "wasm32"))]
            timeout: std::time::Duration::from_secs(timeout_secs),
        }
    }

    /// Enters one more level of nesting.
    ///
    /// # Errors
    ///
    /// Returns [`ParseError::SyntaxError`] if entering would exceed `max_depth`, or if the
    /// parsing timeout has elapsed. In either case `depth` is left unchanged, so a failed
    /// `enter` must not be paired with an `exit`.
    pub fn enter(&mut self) -> ParseResult<()> {
        // Equivalent to `depth + 1 > max_depth`, but cannot overflow.
        if self.depth >= self.max_depth {
            return Err(ParseError::SyntaxError {
                position: 0,
                message: format!(
                    "Maximum recursion depth exceeded: {} (limit: {})",
                    self.depth.saturating_add(1),
                    self.max_depth
                ),
            });
        }
        // Check the timeout *before* committing the increment: callers (including
        // `RecursionGuard::new`) never call `exit` after a failed `enter`, so bumping
        // `depth` first would leak a level on every timeout error.
        self.check_timeout()?;
        self.depth += 1;
        Ok(())
    }

    /// Leaves one level of nesting. Saturates at zero, so an unmatched call is a no-op.
    pub fn exit(&mut self) {
        self.depth = self.depth.saturating_sub(1);
    }

    /// Marks the object reference `obj_num gen_num R` as currently being resolved.
    ///
    /// # Errors
    ///
    /// Returns [`ParseError::SyntaxError`] if the same reference is already on the active
    /// stack, i.e. resolving it would recurse into itself. The stack is unchanged then.
    pub fn push_ref(&mut self, obj_num: u32, gen_num: u16) -> ParseResult<()> {
        let ref_key = (obj_num, gen_num);
        if self.active_stack.contains(&ref_key) {
            return Err(ParseError::SyntaxError {
                position: 0,
                message: format!("Circular reference detected: {obj_num} {gen_num} R"),
            });
        }
        self.active_stack.push(ref_key);
        Ok(())
    }

    /// Pops the most recently pushed reference and records it in `completed_refs`.
    /// Does nothing when the active stack is empty.
    pub fn pop_ref(&mut self) {
        if let Some(ref_key) = self.active_stack.pop() {
            self.completed_refs.insert(ref_key);
        }
    }

    /// Fails once more than `timeout` has elapsed since `start_time`.
    ///
    /// Always succeeds on `wasm32`, where no clock is tracked.
    ///
    /// # Errors
    ///
    /// Returns [`ParseError::SyntaxError`] when the timeout has been exceeded.
    pub fn check_timeout(&self) -> ParseResult<()> {
        #[cfg(not(target_arch = "wasm32"))]
        {
            if self.start_time.elapsed() > self.timeout {
                return Err(ParseError::SyntaxError {
                    position: 0,
                    message: format!("Parsing timeout exceeded: {}s", self.timeout.as_secs()),
                });
            }
        }
        Ok(())
    }

    /// Returns an independent copy of this context with the same depth, limits and
    /// reference sets. The child keeps the parent's `start_time`, so both measure their
    /// timeout from the same instant.
    pub fn child(&self) -> Self {
        Self {
            depth: self.depth,
            max_depth: self.max_depth,
            active_stack: self.active_stack.clone(),
            completed_refs: self.completed_refs.clone(),
            #[cfg(not(target_arch = "wasm32"))]
            start_time: self.start_time,
            #[cfg(not(target_arch = "wasm32"))]
            timeout: self.timeout,
        }
    }
}
/// RAII guard for one level of nesting on a [`StackSafeContext`].
///
/// Construction calls [`StackSafeContext::enter`]; dropping the guard calls
/// [`StackSafeContext::exit`]. The guard dereferences to the context it borrows, so the
/// context stays usable while the guard is alive. Nested recursion can take a new guard
/// from the current one: `RecursionGuard::new(&mut guard)`.
#[must_use = "dropping the guard immediately exits the nesting level it just entered"]
pub struct RecursionGuard<'a> {
    context: &'a mut StackSafeContext,
}

impl<'a> RecursionGuard<'a> {
    /// Enters one nesting level on `context` and returns a guard that exits it on drop.
    ///
    /// # Errors
    ///
    /// Propagates the error from [`StackSafeContext::enter`]; no guard is created in that
    /// case, so no `exit` happens later.
    pub fn new(context: &'a mut StackSafeContext) -> ParseResult<Self> {
        context.enter()?;
        Ok(Self { context })
    }
}

impl std::ops::Deref for RecursionGuard<'_> {
    type Target = StackSafeContext;

    fn deref(&self) -> &StackSafeContext {
        self.context
    }
}

impl std::ops::DerefMut for RecursionGuard<'_> {
    fn deref_mut(&mut self) -> &mut StackSafeContext {
        self.context
    }
}

impl Drop for RecursionGuard<'_> {
    fn drop(&mut self) {
        self.context.exit();
    }
}
/// RAII guard that keeps one object reference on a [`StackSafeContext`]'s active stack.
///
/// Construction calls [`StackSafeContext::push_ref`]. Dropping the guard removes that same
/// reference from the active stack and records it in `completed_refs`. The guard
/// dereferences to the context it borrows, so nested resolution can take a new guard from
/// the current one: `ReferenceStackGuard::new(&mut guard, ..)`.
#[must_use = "dropping the guard immediately releases the reference it just pushed"]
pub struct ReferenceStackGuard<'a> {
    context: &'a mut StackSafeContext,
    /// The `(obj_num, gen_num)` pair this guard pushed and must release.
    ref_key: (u32, u16),
}

impl<'a> ReferenceStackGuard<'a> {
    /// Pushes `obj_num gen_num R` onto `context`'s active stack and returns a guard that
    /// releases it on drop.
    ///
    /// # Errors
    ///
    /// Propagates the circular-reference error from [`StackSafeContext::push_ref`]; no
    /// guard is created in that case, so nothing is removed later.
    pub fn new(
        context: &'a mut StackSafeContext,
        obj_num: u32,
        gen_num: u16,
    ) -> ParseResult<Self> {
        context.push_ref(obj_num, gen_num)?;
        Ok(Self {
            context,
            ref_key: (obj_num, gen_num),
        })
    }
}

impl std::ops::Deref for ReferenceStackGuard<'_> {
    type Target = StackSafeContext;

    fn deref(&self) -> &StackSafeContext {
        self.context
    }
}

impl std::ops::DerefMut for ReferenceStackGuard<'_> {
    fn deref_mut(&mut self) -> &mut StackSafeContext {
        self.context
    }
}

impl Drop for ReferenceStackGuard<'_> {
    fn drop(&mut self) {
        // Release this guard's own entry instead of blindly popping the top. Through
        // `DerefMut`, other references may have been pushed (and not popped) above it.
        // With strictly nested guards the entry *is* the top, so this matches `pop_ref`.
        // `push_ref` rejects duplicates, so the key is expected to appear at most once.
        let key = self.ref_key;
        if let Some(pos) = self.context.active_stack.iter().rposition(|r| *r == key) {
            self.context.active_stack.remove(pos);
            self.context.completed_refs.insert(key);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_stack_safe_context_new() {
        let context = StackSafeContext::new();
        assert_eq!(context.depth, 0);
        assert_eq!(context.max_depth, MAX_RECURSION_DEPTH);
        assert!(context.active_stack.is_empty());
        assert!(context.completed_refs.is_empty());
        #[cfg(not(target_arch = "wasm32"))]
        assert_eq!(context.timeout.as_secs(), PARSING_TIMEOUT_SECS);
    }

    #[test]
    fn test_stack_safe_context_default() {
        let context = StackSafeContext::default();
        assert_eq!(context.depth, 0);
        assert_eq!(context.max_depth, MAX_RECURSION_DEPTH);
    }

    #[test]
    fn test_stack_safe_context_with_limits() {
        let context = StackSafeContext::with_limits(50, 30);
        assert_eq!(context.depth, 0);
        assert_eq!(context.max_depth, 50);
        #[cfg(not(target_arch = "wasm32"))]
        assert_eq!(context.timeout.as_secs(), 30);
    }

    #[test]
    fn test_recursion_limits() {
        let mut context = StackSafeContext::with_limits(3, 60);
        assert!(context.enter().is_ok());
        assert_eq!(context.depth, 1);
        assert!(context.enter().is_ok());
        assert_eq!(context.depth, 2);
        assert!(context.enter().is_ok());
        assert_eq!(context.depth, 3);
        assert!(context.enter().is_err());
        // A rejected `enter` must leave the depth untouched.
        assert_eq!(context.depth, 3);
        context.exit();
        assert_eq!(context.depth, 2);
    }

    #[test]
    fn test_enter_increments_depth() {
        let mut context = StackSafeContext::new();
        assert_eq!(context.depth, 0);
        context.enter().unwrap();
        assert_eq!(context.depth, 1);
        context.enter().unwrap();
        assert_eq!(context.depth, 2);
    }

    #[test]
    fn test_exit_decrements_depth() {
        let mut context = StackSafeContext::new();
        context.enter().unwrap();
        context.enter().unwrap();
        assert_eq!(context.depth, 2);
        context.exit();
        assert_eq!(context.depth, 1);
        context.exit();
        assert_eq!(context.depth, 0);
    }

    #[test]
    fn test_exit_at_zero_does_not_underflow() {
        let mut context = StackSafeContext::new();
        assert_eq!(context.depth, 0);
        context.exit();
        assert_eq!(context.depth, 0);
        context.exit();
        assert_eq!(context.depth, 0);
    }

    #[test]
    fn test_cycle_detection() {
        let mut context = StackSafeContext::new();
        assert!(context.push_ref(1, 0).is_ok());
        assert!(context.push_ref(1, 0).is_err());
        // The rejected push must not have been recorded.
        assert_eq!(context.active_stack, vec![(1, 0)]);
        assert!(context.push_ref(2, 0).is_ok());
        context.pop_ref();
        context.pop_ref();
        assert!(context.push_ref(1, 0).is_ok());
    }

    #[test]
    fn test_push_ref_adds_to_active_stack() {
        let mut context = StackSafeContext::new();
        assert!(context.active_stack.is_empty());
        context.push_ref(10, 5).unwrap();
        assert_eq!(context.active_stack.len(), 1);
        assert!(context.active_stack.contains(&(10, 5)));
    }

    #[test]
    fn test_pop_ref_marks_as_completed() {
        let mut context = StackSafeContext::new();
        context.push_ref(7, 3).unwrap();
        assert!(context.completed_refs.is_empty());
        context.pop_ref();
        assert!(context.active_stack.is_empty());
        assert!(context.completed_refs.contains(&(7, 3)));
    }

    #[test]
    fn test_pop_ref_on_empty_stack() {
        let mut context = StackSafeContext::new();
        assert!(context.active_stack.is_empty());
        context.pop_ref();
        assert!(context.active_stack.is_empty());
        assert!(context.completed_refs.is_empty());
    }

    #[test]
    fn test_multiple_refs_stack_order() {
        let mut context = StackSafeContext::new();
        context.push_ref(1, 0).unwrap();
        context.push_ref(2, 0).unwrap();
        context.push_ref(3, 0).unwrap();
        assert_eq!(context.active_stack.len(), 3);

        context.pop_ref();
        assert!(context.completed_refs.contains(&(3, 0)));
        assert!(!context.completed_refs.contains(&(2, 0)));
        assert!(!context.completed_refs.contains(&(1, 0)));

        context.pop_ref();
        assert!(context.completed_refs.contains(&(2, 0)));

        context.pop_ref();
        assert!(context.completed_refs.contains(&(1, 0)));
    }

    #[test]
    fn test_check_timeout_within_limit() {
        let context = StackSafeContext::with_limits(100, 60);
        assert!(context.check_timeout().is_ok());
    }

    #[test]
    #[cfg(not(target_arch = "wasm32"))]
    fn test_check_timeout_exceeded() {
        // A zero-second budget is exceeded as soon as any time has passed.
        let context = StackSafeContext::with_limits(100, 0);
        std::thread::sleep(std::time::Duration::from_millis(5));
        assert!(context.check_timeout().is_err());
    }

    #[test]
    fn test_child_context() {
        let mut context = StackSafeContext::with_limits(50, 30);
        context.enter().unwrap();
        context.enter().unwrap();
        context.push_ref(5, 0).unwrap();
        context.pop_ref();
        let child = context.child();
        assert_eq!(child.depth, context.depth);
        assert_eq!(child.max_depth, context.max_depth);
        assert!(child.active_stack.is_empty());
        assert!(child.completed_refs.contains(&(5, 0)));
        #[cfg(not(target_arch = "wasm32"))]
        {
            assert_eq!(child.start_time, context.start_time);
            assert_eq!(child.timeout, context.timeout);
        }
    }

    #[test]
    fn test_child_context_is_independent() {
        let mut context = StackSafeContext::new();
        context.enter().unwrap();
        context.push_ref(9, 0).unwrap();
        let mut child = context.child();
        assert_eq!(child.depth, 1);

        // Changes to the parent do not leak into the child...
        context.exit();
        context.pop_ref();
        assert_eq!(context.depth, 0);
        assert_eq!(child.depth, 1);
        assert_eq!(child.active_stack, vec![(9, 0)]);

        // ...and changes to the child do not leak into the parent.
        child.enter().unwrap();
        child.push_ref(10, 0).unwrap();
        assert_eq!(child.depth, 2);
        assert_eq!(context.depth, 0);
        assert!(context.active_stack.is_empty());
        assert!(!context.completed_refs.contains(&(10, 0)));
    }

    #[test]
    fn test_different_generation_numbers() {
        let mut context = StackSafeContext::new();
        context.push_ref(1, 0).unwrap();
        context.push_ref(1, 1).unwrap();
        context.push_ref(1, 2).unwrap();
        assert_eq!(context.active_stack.len(), 3);
        // The full (object, generation) pair is what identifies a cycle.
        assert!(context.push_ref(1, 1).is_err());
    }

    #[test]
    fn test_recursion_guard() {
        let mut context = StackSafeContext::new();
        assert_eq!(context.depth, 0);
        {
            let _guard = RecursionGuard::new(&mut context).unwrap();
        }
        assert_eq!(context.depth, 0);
    }

    #[test]
    fn test_recursion_guard_releases_depth() {
        let mut context = StackSafeContext::with_limits(10, 60);
        {
            let _guard = RecursionGuard::new(&mut context).unwrap();
        }
        assert_eq!(context.depth, 0);
        context.enter().unwrap();
        context.enter().unwrap();
        assert_eq!(context.depth, 2);
    }

    #[test]
    fn test_recursion_guard_fails_at_limit() {
        let mut context = StackSafeContext::with_limits(1, 60);
        context.enter().unwrap();
        assert!(RecursionGuard::new(&mut context).is_err());
        // No guard was created, so the depth is neither raised nor lowered.
        assert_eq!(context.depth, 1);
    }

    #[test]
    fn test_reference_stack_guard() {
        let mut context = StackSafeContext::new();
        {
            let _guard = ReferenceStackGuard::new(&mut context, 1, 0).unwrap();
        }
        assert_eq!(context.active_stack.len(), 0);
        assert!(context.completed_refs.contains(&(1, 0)));
        assert!(context.push_ref(1, 0).is_ok());
    }

    #[test]
    fn test_reference_stack_guard_circular_detection() {
        let mut context = StackSafeContext::new();
        context.push_ref(5, 0).unwrap();
        assert!(ReferenceStackGuard::new(&mut context, 5, 0).is_err());
        // The failed guard was never constructed, so nothing was released.
        assert_eq!(context.active_stack, vec![(5, 0)]);
        assert!(context.completed_refs.is_empty());
    }

    #[test]
    fn test_constants() {
        assert_eq!(MAX_RECURSION_DEPTH, 1000);
        assert_eq!(PARSING_TIMEOUT_SECS, 120);
    }

    #[test]
    fn test_stack_safe_context_debug() {
        let context = StackSafeContext::new();
        let debug_str = format!("{:?}", context);
        assert!(debug_str.contains("StackSafeContext"));
        assert!(debug_str.contains("depth: 0"));
    }

    #[test]
    fn test_deep_recursion_simulation() {
        let mut context = StackSafeContext::with_limits(100, 60);
        for i in 0..100 {
            assert!(context.enter().is_ok(), "Failed at depth {}", i);
        }
        assert_eq!(context.depth, 100);
        assert!(context.enter().is_err());
        for _ in 0..100 {
            context.exit();
        }
        assert_eq!(context.depth, 0);
    }

    #[test]
    fn test_complex_reference_scenario() {
        let mut context = StackSafeContext::new();
        context.push_ref(1, 0).unwrap();
        context.push_ref(2, 0).unwrap();
        context.push_ref(3, 0).unwrap();
        assert!(context.push_ref(3, 0).is_err());
        context.push_ref(4, 0).unwrap();
        for _ in 0..4 {
            context.pop_ref();
        }
        assert!(context.active_stack.is_empty());
        assert!(context.completed_refs.contains(&(1, 0)));
        assert!(context.completed_refs.contains(&(2, 0)));
        assert!(context.completed_refs.contains(&(3, 0)));
        assert!(context.completed_refs.contains(&(4, 0)));
        context.push_ref(3, 0).unwrap();
    }
}