grit_pattern_matcher/pattern/
accessor.rs

1use super::{
2    container::{Container, PatternOrResolved, PatternOrResolvedMut},
3    map::GritMap,
4    patterns::{Matcher, Pattern, PatternName},
5    resolved_pattern::ResolvedPattern,
6    state::State,
7    variable::Variable,
8};
9use crate::{
10    binding::Binding,
11    context::{ExecContext, QueryContext},
12};
13use grit_util::{
14    error::{GritPatternError, GritResult},
15    AnalysisLogs,
16};
17use std::borrow::Cow;
18
19#[derive(Debug, Clone)]
20pub enum AccessorMap<Q: QueryContext> {
21    Container(Container<Q>),
22    Map(GritMap<Q>),
23}
24
25#[derive(Debug, Clone)]
26pub struct Accessor<Q: QueryContext> {
27    pub map: AccessorMap<Q>,
28    pub key: AccessorKey,
29}
30
31#[derive(Debug, Clone)]
32pub enum AccessorKey {
33    String(String),
34    Variable(Variable),
35}
36
37impl<Q: QueryContext> Accessor<Q> {
38    pub fn new(map: AccessorMap<Q>, key: AccessorKey) -> Self {
39        Self { map, key }
40    }
41
42    fn get_key<'a>(
43        &'a self,
44        state: &State<'a, Q>,
45        lang: &Q::Language<'a>,
46    ) -> GritResult<Cow<'a, str>> {
47        match &self.key {
48            AccessorKey::String(s) => Ok(Cow::Borrowed(s)),
49            AccessorKey::Variable(v) => v.text(state, lang),
50        }
51    }
52
53    pub fn get<'a, 'b>(
54        &'a self,
55        state: &'b State<'a, Q>,
56        lang: &Q::Language<'a>,
57    ) -> GritResult<Option<PatternOrResolved<'a, 'b, Q>>> {
58        let key = self.get_key(state, lang)?;
59        match &self.map {
60            AccessorMap::Container(c) => match c.get_pattern_or_resolved(state, lang)? {
61                None => Ok(None),
62                Some(PatternOrResolved::Pattern(Pattern::Map(m))) => {
63                    Ok(m.get(&key).map(PatternOrResolved::Pattern))
64                }
65                Some(PatternOrResolved::Resolved(resolved)) => match resolved.get_map() {
66                    Some(m) => Ok(m.get(key.as_ref()).map(PatternOrResolved::Resolved)),
67                    None => Err(GritPatternError::new(
68                        "left side of an accessor must be a map",
69                    )),
70                },
71                Some(_) => Err(GritPatternError::new(
72                    "left side of an accessor must be a map",
73                )),
74            },
75            AccessorMap::Map(m) => Ok(m.get(&key).map(PatternOrResolved::Pattern)),
76        }
77    }
78
79    pub fn get_mut<'a, 'b>(
80        &'a self,
81        state: &'b mut State<'a, Q>,
82        lang: &Q::Language<'a>,
83    ) -> GritResult<Option<PatternOrResolvedMut<'a, 'b, Q>>> {
84        let key = self.get_key(state, lang)?;
85        match &self.map {
86            AccessorMap::Container(c) => match c.get_pattern_or_resolved_mut(state, lang)? {
87                None => Ok(None),
88                Some(PatternOrResolvedMut::Pattern(Pattern::Map(m))) => {
89                    Ok(m.get(&key).map(PatternOrResolvedMut::Pattern))
90                }
91                Some(PatternOrResolvedMut::Resolved(resolved)) => match resolved.get_map_mut() {
92                    Some(m) => Ok(m.get_mut(key.as_ref()).map(PatternOrResolvedMut::Resolved)),
93                    None => Err(GritPatternError::new(
94                        "left side of an accessor must be a map",
95                    )),
96                },
97                Some(_) => Err(GritPatternError::new(
98                    "left side of an accessor must be a map",
99                )),
100            },
101            AccessorMap::Map(m) => Ok(m.get(&key).map(PatternOrResolvedMut::Pattern)),
102        }
103    }
104
105    pub fn set_resolved<'a>(
106        &'a self,
107        state: &mut State<'a, Q>,
108        lang: &Q::Language<'a>,
109        value: Q::ResolvedPattern<'a>,
110    ) -> GritResult<bool> {
111        match &self.map {
112            AccessorMap::Container(c) => {
113                let key = self.get_key(state, lang)?;
114                match c.get_pattern_or_resolved_mut(state, lang)? {
115                    None => Ok(false),
116                    Some(PatternOrResolvedMut::Resolved(resolved)) => {
117                        if let Some(m) = resolved.get_map_mut() {
118                            m.insert(key.to_string(), value);
119                            Ok(true)
120                        } else {
121                            Err(GritPatternError::new(
122                                "accessor can only mutate a resolved map",
123                            ))
124                        }
125                    }
126                    Some(_) => Err(GritPatternError::new(
127                        "accessor can only mutate a resolved map",
128                    )),
129                }
130            }
131            AccessorMap::Map(_) => Err(GritPatternError::new("cannot mutate a map literal")),
132        }
133    }
134}
135
136impl<Q: QueryContext> PatternName for Accessor<Q> {
137    fn name(&self) -> &'static str {
138        "ACCESSOR"
139    }
140}
141
142impl<Q: QueryContext> Matcher<Q> for Accessor<Q> {
143    fn execute<'a>(
144        &'a self,
145        binding: &Q::ResolvedPattern<'a>,
146        state: &mut State<'a, Q>,
147        context: &'a Q::ExecContext<'a>,
148        logs: &mut AnalysisLogs,
149    ) -> GritResult<bool> {
150        match self.get(state, context.language())? {
151            Some(PatternOrResolved::Resolved(r)) => {
152                execute_resolved_with_binding(r, binding, state, context.language())
153            }
154            Some(PatternOrResolved::ResolvedBinding(r)) => {
155                execute_resolved_with_binding(&r, binding, state, context.language())
156            }
157            Some(PatternOrResolved::Pattern(p)) => p.execute(binding, state, context, logs),
158            None => Ok(binding.matches_false_or_undefined()),
159        }
160    }
161}
162
163pub fn execute_resolved_with_binding<'a, Q: QueryContext>(
164    r: &Q::ResolvedPattern<'a>,
165    binding: &Q::ResolvedPattern<'a>,
166    state: &State<'a, Q>,
167    language: &Q::Language<'a>,
168) -> GritResult<bool> {
169    if let (Some(r), Some(b)) = (r.get_last_binding(), binding.get_last_binding()) {
170        Ok(r.is_equivalent_to(b, language))
171    } else {
172        Ok(r.text(&state.files, language)? == binding.text(&state.files, language)?)
173    }
174}