grit_pattern_matcher/pattern/
limit.rs1use super::{
2 patterns::{Matcher, Pattern, PatternName},
3 state::State,
4};
5use crate::{context::ExecContext, context::QueryContext};
6use grit_util::{error::GritResult, AnalysisLogs};
7use std::sync::{
8 atomic::{AtomicUsize, Ordering},
9 Arc,
10};
11
12#[derive(Debug, Clone)]
13pub struct Limit<Q: QueryContext> {
14 pub pattern: Pattern<Q>,
15 pub limit: usize,
16 pub invocation_count: Arc<AtomicUsize>,
17}
18
19impl<Q: QueryContext> Limit<Q> {
20 pub fn new(pattern: Pattern<Q>, limit: usize) -> Self {
21 Self {
22 pattern,
23 limit,
24 invocation_count: Arc::new(AtomicUsize::new(0)),
25 }
26 }
27}
28
29impl<Q: QueryContext> PatternName for Limit<Q> {
30 fn name(&self) -> &'static str {
31 "LIMIT"
32 }
33}
34
35impl<Q: QueryContext> Matcher<Q> for Limit<Q> {
36 fn execute<'a>(
37 &'a self,
38 binding: &Q::ResolvedPattern<'a>,
39 state: &mut State<'a, Q>,
40 context: &'a Q::ExecContext<'a>,
41 logs: &mut AnalysisLogs,
42 ) -> GritResult<bool> {
43 if context.ignore_limit_pattern() {
44 let res = self.pattern.execute(binding, state, context, logs)?;
45 return Ok(res);
46 }
47 if self.invocation_count.load(Ordering::Relaxed) >= self.limit {
48 return Ok(false);
49 }
50 let res = self.pattern.execute(binding, state, context, logs)?;
51 if !res {
52 return Ok(false);
53 }
54 loop {
55 let current_count = self.invocation_count.load(Ordering::SeqCst);
56 if current_count >= self.limit {
57 return Ok(false);
58 }
59 let attempt_increment = self.invocation_count.compare_exchange(
60 current_count,
61 current_count + 1,
62 Ordering::SeqCst,
63 Ordering::Relaxed,
64 );
65 if attempt_increment.is_ok() {
66 break;
67 }
68 }
69 Ok(true)
70 }
71}