Skip to main content

pdf_ast/performance/
limits.rs

1use std::sync::{Arc, Mutex};
2use std::time::{Duration, Instant};
3
/// Resource ceilings for parsing and querying, plus switches that turn
/// individual categories of checks on or off.
///
/// Three presets exist: [`Default::default`], [`PerformanceLimits::conservative`]
/// and [`PerformanceLimits::permissive`].
#[derive(Debug, Clone)]
pub struct PerformanceLimits {
    /// Maximum number of nodes.
    pub max_nodes: usize,
    /// Maximum number of edges.
    pub max_edges: usize,
    /// Memory budget, in megabytes.
    pub max_memory_mb: usize,
    /// Time budget for a parse.
    pub max_parse_time: Duration,
    /// Time budget for a query.
    pub max_query_time: Duration,
    /// Maximum recursion depth.
    pub max_depth: usize,
    /// Largest accepted file, in megabytes.
    pub max_file_size_mb: usize,
    /// Largest accepted single object, in megabytes.
    pub max_object_size_mb: usize,
    /// Cap on the stream decode ratio.
    // NOTE(review): nothing in this file reads this field; presumably it is
    // decoded size / encoded size and is enforced by the stream decoder — confirm.
    pub max_stream_decode_ratio: usize,
    /// Maximum number of parsers allowed to run at the same time.
    pub max_concurrent_parsers: usize,
    /// When `false`, timeout checks always succeed.
    pub enable_timeout_checks: bool,
    /// When `false`, memory allocation tracking always succeeds.
    pub enable_memory_checks: bool,
    /// When `false`, recursion depth is neither tracked nor limited.
    pub enable_recursion_checks: bool,
}

impl Default for PerformanceLimits {
    /// The middle-of-the-road preset, with every check enabled.
    fn default() -> Self {
        let five_minutes = Duration::from_secs(300);
        let thirty_seconds = Duration::from_secs(30);

        Self {
            // Check switches
            enable_timeout_checks: true,
            enable_memory_checks: true,
            enable_recursion_checks: true,
            // Time budgets
            max_parse_time: five_minutes,
            max_query_time: thirty_seconds,
            // Graph size and shape
            max_nodes: 1_000_000,
            max_edges: 5_000_000,
            max_depth: 1000,
            // Memory and input sizes
            max_memory_mb: 1024,
            max_file_size_mb: 100,
            max_object_size_mb: 50,
            max_stream_decode_ratio: 100,
            // Parallelism
            max_concurrent_parsers: 4,
        }
    }
}

impl PerformanceLimits {
    /// A tight preset: every numeric ceiling is lower than the default.
    /// The `enable_*` switches are inherited from [`Default::default`].
    pub fn conservative() -> Self {
        let inherited = Self::default();
        Self {
            // Time budgets
            max_parse_time: Duration::from_secs(60),
            max_query_time: Duration::from_secs(10),
            // Graph size and shape
            max_nodes: 100_000,
            max_edges: 500_000,
            max_depth: 100,
            // Memory and input sizes
            max_memory_mb: 256,
            max_file_size_mb: 10,
            max_object_size_mb: 5,
            max_stream_decode_ratio: 50,
            // Parallelism
            max_concurrent_parsers: 2,
            ..inherited
        }
    }

    /// A generous preset: every numeric ceiling is higher than the default.
    /// The `enable_*` switches are inherited from [`Default::default`].
    pub fn permissive() -> Self {
        let inherited = Self::default();
        Self {
            // Time budgets
            max_parse_time: Duration::from_secs(1800), // 30 minutes
            max_query_time: Duration::from_secs(120),
            // Graph size and shape
            max_nodes: 10_000_000,
            max_edges: 50_000_000,
            max_depth: 10000,
            // Memory and input sizes
            max_memory_mb: 4096,
            max_file_size_mb: 1000,
            max_object_size_mb: 500,
            max_stream_decode_ratio: 200,
            // Parallelism
            max_concurrent_parsers: 8,
            ..inherited
        }
    }
}
74
/// Tracks elapsed time, node and edge counts, recursion depth and tracked
/// memory for one named operation, and checks them against a set of
/// [`PerformanceLimits`].
#[derive(Debug)]
pub struct PerformanceGuard {
    /// The limits every check on this guard compares against.
    limits: PerformanceLimits,
    /// Captured when the guard is constructed; timeout checks and stats
    /// measure elapsed time from this instant.
    start_time: Instant,
    /// Number of nodes counted so far via `increment_nodes`.
    node_count: usize,
    /// Number of edges counted so far via `increment_edges`.
    edge_count: usize,
    /// Recursion depth currently entered and not yet exited.
    current_depth: usize,
    /// Highest value `current_depth` has taken during this guard's lifetime.
    max_depth_reached: usize,
    /// Running total of tracked bytes (allocations minus deallocations).
    memory_usage: Arc<Mutex<usize>>,
    /// Label copied into the stats produced by `get_stats`.
    operation_name: String,
}
86
/// A limit that was exceeded.
///
/// Every variant carries two values in the same order: the observed value
/// first, then the limit it was compared against.
#[derive(Debug, Clone)]
pub enum PerformanceViolation {
    /// Node count above the node limit.
    TooManyNodes(usize, usize),
    /// Edge count above the edge limit.
    TooManyEdges(usize, usize),
    /// Tracked memory above the memory limit (both in MB).
    ExcessiveMemory(usize, usize),
    /// Elapsed time above the time budget.
    Timeout(Duration, Duration),
    /// Recursion depth above the depth limit.
    ExcessiveDepth(usize, usize),
    /// File size above the file size limit (both in MB).
    FileTooLarge(usize, usize),
    /// Object size above the object size limit (both in MB).
    ObjectTooLarge(usize, usize),
    /// Active parser count at or above the concurrency limit.
    TooManyConcurrentParsers(usize, usize),
}

impl std::fmt::Display for PerformanceViolation {
    /// Renders the violation as `"<description>: <observed> > <limit>"`,
    /// with an `MB` suffix on size values and `Debug` formatting for durations.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::TooManyNodes(observed, limit) => {
                write!(f, "Too many nodes: {} > {}", observed, limit)
            }
            Self::TooManyEdges(observed, limit) => {
                write!(f, "Too many edges: {} > {}", observed, limit)
            }
            Self::ExcessiveMemory(observed_mb, limit_mb) => {
                write!(f, "Excessive memory usage: {}MB > {}MB", observed_mb, limit_mb)
            }
            Self::Timeout(ref elapsed, ref budget) => {
                write!(f, "Operation timeout: {:?} > {:?}", elapsed, budget)
            }
            Self::ExcessiveDepth(observed, limit) => {
                write!(f, "Excessive recursion depth: {} > {}", observed, limit)
            }
            Self::FileTooLarge(observed_mb, limit_mb) => {
                write!(f, "File too large: {}MB > {}MB", observed_mb, limit_mb)
            }
            Self::ObjectTooLarge(observed_mb, limit_mb) => {
                write!(f, "Object too large: {}MB > {}MB", observed_mb, limit_mb)
            }
            Self::TooManyConcurrentParsers(observed, limit) => {
                write!(f, "Too many concurrent parsers: {} > {}", observed, limit)
            }
        }
    }
}

impl std::error::Error for PerformanceViolation {}
131
132impl PerformanceGuard {
133    pub fn new(limits: PerformanceLimits, operation_name: &str) -> Self {
134        Self {
135            limits,
136            start_time: Instant::now(),
137            node_count: 0,
138            edge_count: 0,
139            current_depth: 0,
140            max_depth_reached: 0,
141            memory_usage: Arc::new(Mutex::new(0)),
142            operation_name: operation_name.to_string(),
143        }
144    }
145
146    pub fn check_file_size(&self, size_bytes: usize) -> Result<(), PerformanceViolation> {
147        let size_mb = size_bytes / (1024 * 1024);
148        if size_mb > self.limits.max_file_size_mb {
149            return Err(PerformanceViolation::FileTooLarge(
150                size_mb,
151                self.limits.max_file_size_mb,
152            ));
153        }
154        Ok(())
155    }
156
157    pub fn check_object_size(&self, size_bytes: usize) -> Result<(), PerformanceViolation> {
158        let size_mb = size_bytes / (1024 * 1024);
159        if size_mb > self.limits.max_object_size_mb {
160            return Err(PerformanceViolation::ObjectTooLarge(
161                size_mb,
162                self.limits.max_object_size_mb,
163            ));
164        }
165        Ok(())
166    }
167
168    pub fn check_nodes(&self, count: usize) -> Result<(), PerformanceViolation> {
169        if count > self.limits.max_nodes {
170            return Err(PerformanceViolation::TooManyNodes(
171                count,
172                self.limits.max_nodes,
173            ));
174        }
175        Ok(())
176    }
177
178    pub fn check_edges(&self, count: usize) -> Result<(), PerformanceViolation> {
179        if count > self.limits.max_edges {
180            return Err(PerformanceViolation::TooManyEdges(
181                count,
182                self.limits.max_edges,
183            ));
184        }
185        Ok(())
186    }
187
188    pub fn check_timeout(&self, max_duration: Duration) -> Result<(), PerformanceViolation> {
189        if !self.limits.enable_timeout_checks {
190            return Ok(());
191        }
192
193        let elapsed = self.start_time.elapsed();
194        if elapsed > max_duration {
195            return Err(PerformanceViolation::Timeout(elapsed, max_duration));
196        }
197        Ok(())
198    }
199
200    pub fn check_parse_timeout(&self) -> Result<(), PerformanceViolation> {
201        self.check_timeout(self.limits.max_parse_time)
202    }
203
204    pub fn check_query_timeout(&self) -> Result<(), PerformanceViolation> {
205        self.check_timeout(self.limits.max_query_time)
206    }
207
208    pub fn enter_recursion(&mut self) -> Result<RecursionGuard<'_>, PerformanceViolation> {
209        if !self.limits.enable_recursion_checks {
210            return Ok(RecursionGuard::new(self, 0));
211        }
212
213        self.current_depth += 1;
214        if self.current_depth > self.max_depth_reached {
215            self.max_depth_reached = self.current_depth;
216        }
217
218        if self.current_depth > self.limits.max_depth {
219            return Err(PerformanceViolation::ExcessiveDepth(
220                self.current_depth,
221                self.limits.max_depth,
222            ));
223        }
224
225        Ok(RecursionGuard::new(self, self.current_depth))
226    }
227
228    pub fn track_memory_allocation(&self, bytes: usize) -> Result<(), PerformanceViolation> {
229        if !self.limits.enable_memory_checks {
230            return Ok(());
231        }
232
233        if let Ok(mut usage) = self.memory_usage.lock() {
234            *usage += bytes;
235            let usage_mb = *usage / (1024 * 1024);
236            if usage_mb > self.limits.max_memory_mb {
237                return Err(PerformanceViolation::ExcessiveMemory(
238                    usage_mb,
239                    self.limits.max_memory_mb,
240                ));
241            }
242        }
243        Ok(())
244    }
245
246    pub fn track_memory_deallocation(&self, bytes: usize) {
247        if let Ok(mut usage) = self.memory_usage.lock() {
248            *usage = usage.saturating_sub(bytes);
249        }
250    }
251
252    pub fn increment_nodes(&mut self) -> Result<(), PerformanceViolation> {
253        self.node_count += 1;
254        self.check_nodes(self.node_count)
255    }
256
257    pub fn increment_edges(&mut self) -> Result<(), PerformanceViolation> {
258        self.edge_count += 1;
259        self.check_edges(self.edge_count)
260    }
261
262    pub fn get_stats(&self) -> PerformanceStats {
263        PerformanceStats {
264            operation_name: self.operation_name.clone(),
265            elapsed_time: self.start_time.elapsed(),
266            node_count: self.node_count,
267            edge_count: self.edge_count,
268            max_depth_reached: self.max_depth_reached,
269            memory_usage_mb: self.memory_usage.lock().map(|guard| *guard).unwrap_or(0)
270                / (1024 * 1024),
271        }
272    }
273
274    fn exit_recursion(&mut self) {
275        if self.current_depth > 0 {
276            self.current_depth -= 1;
277        }
278    }
279}
280
281pub struct RecursionGuard<'a> {
282    guard: &'a mut PerformanceGuard,
283    depth: usize,
284}
285
286impl<'a> RecursionGuard<'a> {
287    fn new(guard: &'a mut PerformanceGuard, depth: usize) -> Self {
288        Self { guard, depth }
289    }
290
291    pub fn depth(&self) -> usize {
292        self.depth
293    }
294}
295
296impl<'a> Drop for RecursionGuard<'a> {
297    fn drop(&mut self) {
298        self.guard.exit_recursion();
299    }
300}
301
/// Snapshot of what an operation has consumed so far.
#[derive(Debug, Clone)]
pub struct PerformanceStats {
    /// Label of the operation the snapshot belongs to.
    pub operation_name: String,
    /// Time elapsed when the snapshot was taken.
    pub elapsed_time: Duration,
    /// Nodes counted so far.
    pub node_count: usize,
    /// Edges counted so far.
    pub edge_count: usize,
    /// Deepest recursion level recorded.
    pub max_depth_reached: usize,
    /// Tracked memory in megabytes.
    pub memory_usage_mb: usize,
}

impl PerformanceStats {
    /// Average number of nodes per second of elapsed time.
    ///
    /// Falls back to the raw node count when no time at all has elapsed,
    /// since the rate is undefined in that case.
    pub fn nodes_per_second(&self) -> f64 {
        self.per_second(self.node_count)
    }

    /// Average number of edges per second of elapsed time.
    ///
    /// Falls back to the raw edge count when no time at all has elapsed,
    /// since the rate is undefined in that case.
    pub fn edges_per_second(&self) -> f64 {
        self.per_second(self.edge_count)
    }

    /// Divides `count` by the elapsed time in fractional seconds.
    ///
    /// The zero test uses the fractional value, not whole seconds, so runs
    /// shorter than one second still produce a real rate.
    fn per_second(&self, count: usize) -> f64 {
        let seconds = self.elapsed_time.as_secs_f64();
        if seconds > 0.0 {
            count as f64 / seconds
        } else {
            count as f64
        }
    }
}
327
328pub struct ConcurrencyGuard {
329    limits: PerformanceLimits,
330    active_parsers: Arc<Mutex<usize>>,
331}
332
333impl ConcurrencyGuard {
334    pub fn new(limits: PerformanceLimits) -> Self {
335        Self {
336            limits,
337            active_parsers: Arc::new(Mutex::new(0)),
338        }
339    }
340
341    pub fn acquire_parser_slot(&self) -> Result<ParserSlot, PerformanceViolation> {
342        let mut active = self
343            .active_parsers
344            .lock()
345            .map_err(|_| PerformanceViolation::TooManyConcurrentParsers(0, 0))?;
346
347        if *active >= self.limits.max_concurrent_parsers {
348            return Err(PerformanceViolation::TooManyConcurrentParsers(
349                *active,
350                self.limits.max_concurrent_parsers,
351            ));
352        }
353
354        *active += 1;
355        Ok(ParserSlot {
356            active_parsers: self.active_parsers.clone(),
357        })
358    }
359}
360
/// One reserved parser slot. Dropping it gives the slot back.
#[derive(Debug)]
pub struct ParserSlot {
    /// Counter of active slots, shared with the guard that issued this slot.
    active_parsers: Arc<Mutex<usize>>,
}

impl Drop for ParserSlot {
    /// Decrements the shared counter, stopping at zero.
    ///
    /// A poisoned lock is recovered rather than skipped. Skipping the
    /// decrement would leave the slot counted as in use forever.
    fn drop(&mut self) {
        let mut active = self
            .active_parsers
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        *active = active.saturating_sub(1);
    }
}
372
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// The conservative preset accepts exactly 100_000 nodes and rejects the next one.
    #[test]
    fn test_performance_limits() {
        let mut guard = PerformanceGuard::new(PerformanceLimits::conservative(), "test");

        // Use up the whole node budget.
        (0..100_000).for_each(|_| guard.increment_nodes().unwrap());

        // One more node is over the limit.
        let one_too_many = guard.increment_nodes();
        assert!(one_too_many.is_err());
    }

    /// A recursion level can be entered, reports depth 1, and can be entered
    /// again once the first guard is gone.
    #[test]
    fn test_recursion_guard() {
        let mut guard = PerformanceGuard::new(PerformanceLimits::conservative(), "test");

        {
            let first_level = guard
                .enter_recursion()
                .expect("Should be able to enter recursion");
            assert_eq!(first_level.depth(), 1);
            // `first_level` goes out of scope here, which leaves the level.
        }

        let _second_level = guard
            .enter_recursion()
            .expect("Should be able to enter recursion again");
    }

    /// With the conservative preset (two parsers) a third slot is refused
    /// while two are held.
    #[test]
    fn test_concurrency_guard() {
        let guard = ConcurrencyGuard::new(PerformanceLimits::conservative());

        // Hold both available slots for the rest of the test.
        let _held: Vec<ParserSlot> = (0..2)
            .map(|_| guard.acquire_parser_slot().unwrap())
            .collect();

        let third = guard.acquire_parser_slot();
        assert!(third.is_err());
    }

    /// A 10 ms parse budget is exceeded after sleeping for 20 ms.
    #[test]
    fn test_timeout_check() {
        let short_budget = PerformanceLimits {
            max_parse_time: Duration::from_millis(10),
            ..PerformanceLimits::default()
        };
        let guard = PerformanceGuard::new(short_budget, "test");

        thread::sleep(Duration::from_millis(20));

        let outcome = guard.check_parse_timeout();
        assert!(outcome.is_err());
    }
}