grit_pattern_matcher/pattern/limit.rs

1use super::{
2    patterns::{Matcher, Pattern, PatternName},
3    state::State,
4};
5use crate::{context::ExecContext, context::QueryContext};
6use grit_util::{error::GritResult, AnalysisLogs};
7use std::sync::{
8    atomic::{AtomicUsize, Ordering},
9    Arc,
10};
11
12#[derive(Debug, Clone)]
13pub struct Limit<Q: QueryContext> {
14    pub pattern: Pattern<Q>,
15    pub limit: usize,
16    pub invocation_count: Arc<AtomicUsize>,
17}
18
19impl<Q: QueryContext> Limit<Q> {
20    pub fn new(pattern: Pattern<Q>, limit: usize) -> Self {
21        Self {
22            pattern,
23            limit,
24            invocation_count: Arc::new(AtomicUsize::new(0)),
25        }
26    }
27}
28
29impl<Q: QueryContext> PatternName for Limit<Q> {
30    fn name(&self) -> &'static str {
31        "LIMIT"
32    }
33}
34
35impl<Q: QueryContext> Matcher<Q> for Limit<Q> {
36    fn execute<'a>(
37        &'a self,
38        binding: &Q::ResolvedPattern<'a>,
39        state: &mut State<'a, Q>,
40        context: &'a Q::ExecContext<'a>,
41        logs: &mut AnalysisLogs,
42    ) -> GritResult<bool> {
43        if context.ignore_limit_pattern() {
44            let res = self.pattern.execute(binding, state, context, logs)?;
45            return Ok(res);
46        }
47        if self.invocation_count.load(Ordering::Relaxed) >= self.limit {
48            return Ok(false);
49        }
50        let res = self.pattern.execute(binding, state, context, logs)?;
51        if !res {
52            return Ok(false);
53        }
54        loop {
55            let current_count = self.invocation_count.load(Ordering::SeqCst);
56            if current_count >= self.limit {
57                return Ok(false);
58            }
59            let attempt_increment = self.invocation_count.compare_exchange(
60                current_count,
61                current_count + 1,
62                Ordering::SeqCst,
63                Ordering::Relaxed,
64            );
65            if attempt_increment.is_ok() {
66                break;
67            }
68        }
69        Ok(true)
70    }
71}