// grit_pattern_matcher/pattern/within.rs

1use super::{
2    patterns::{Matcher, Pattern, PatternName},
3    resolved_pattern::ResolvedPattern,
4    State,
5};
6use crate::{binding::Binding, context::QueryContext};
7use core::fmt::Debug;
8use grit_util::{error::GritResult, AnalysisLogs, AstNode};
9
10#[derive(Debug, Clone)]
11pub struct Within<Q: QueryContext> {
12    pub pattern: Pattern<Q>,
13    until: Option<Pattern<Q>>,
14}
15
16impl<Q: QueryContext> Within<Q> {
17    pub fn new(pattern: Pattern<Q>, until: Option<Pattern<Q>>) -> Self {
18        Self { pattern, until }
19    }
20}
21
22impl<Q: QueryContext> PatternName for Within<Q> {
23    fn name(&self) -> &'static str {
24        "WITHIN"
25    }
26}
27
28impl<Q: QueryContext> Matcher<Q> for Within<Q> {
29    fn execute<'a>(
30        &'a self,
31        binding: &Q::ResolvedPattern<'a>,
32        init_state: &mut State<'a, Q>,
33        context: &'a Q::ExecContext<'a>,
34        logs: &mut AnalysisLogs,
35    ) -> GritResult<bool> {
36        let mut did_match = false;
37        let mut cur_state = init_state.clone();
38
39        let state = cur_state.clone();
40        if self
41            .pattern
42            .execute(binding, &mut cur_state, context, logs)?
43        {
44            did_match = true;
45        } else {
46            cur_state = state;
47        }
48
49        let Some(node) = binding.get_last_binding().and_then(Binding::parent_node) else {
50            return Ok(did_match);
51        };
52        for n in node.ancestors() {
53            let state = cur_state.clone();
54            let resolved = ResolvedPattern::from_node_binding(n);
55            if self
56                .pattern
57                .execute(&resolved, &mut cur_state, context, logs)?
58            {
59                did_match = true;
60                // We still traverse upwards, so side effects can be applied to all ancestors
61            } else {
62                cur_state = state;
63
64                if let Some(until) = &self.until {
65                    if until.execute(&resolved, &mut cur_state, context, logs)? {
66                        break;
67                    }
68                }
69            }
70        }
71        if did_match {
72            *init_state = cur_state;
73            Ok(true)
74        } else {
75            Ok(false)
76        }
77    }
78}