// ast_grep_core/meta_var.rs

1use crate::match_tree::does_node_match_exactly;
2use crate::matcher::Matcher;
3use crate::source::Content;
4use crate::{Doc, Node};
5use std::borrow::Cow;
6use std::collections::HashMap;
7
8use crate::replacer::formatted_slice;
9
10pub type MetaVariableID = String;
11
12pub type Underlying<D> = Vec<<<D as Doc>::Source as Content>::Underlying>;
13
14/// a dictionary that stores metavariable instantiation
15/// const a = 123 matched with const a = $A will produce env: $A => 123
16#[derive(Clone)]
17pub struct MetaVarEnv<'tree, D: Doc> {
18  single_matched: HashMap<MetaVariableID, Node<'tree, D>>,
19  multi_matched: HashMap<MetaVariableID, Vec<Node<'tree, D>>>,
20  transformed_var: HashMap<MetaVariableID, Underlying<D>>,
21}
22
23impl<'t, D: Doc> MetaVarEnv<'t, D> {
24  pub fn new() -> Self {
25    Self {
26      single_matched: HashMap::new(),
27      multi_matched: HashMap::new(),
28      transformed_var: HashMap::new(),
29    }
30  }
31
32  pub fn insert(&mut self, id: &str, ret: Node<'t, D>) -> Option<&mut Self> {
33    if self.match_variable(id, &ret) {
34      self.single_matched.insert(id.to_string(), ret);
35      Some(self)
36    } else {
37      None
38    }
39  }
40
41  pub fn insert_multi(&mut self, id: &str, ret: Vec<Node<'t, D>>) -> Option<&mut Self> {
42    if self.match_multi_var(id, &ret) {
43      self.multi_matched.insert(id.to_string(), ret);
44      Some(self)
45    } else {
46      None
47    }
48  }
49
50  pub fn get_match(&self, var: &str) -> Option<&'_ Node<'t, D>> {
51    self.single_matched.get(var)
52  }
53
54  pub fn get_multiple_matches(&self, var: &str) -> Vec<Node<'t, D>> {
55    self.multi_matched.get(var).cloned().unwrap_or_default()
56  }
57
58  pub fn add_label(&mut self, label: &str, node: Node<'t, D>) {
59    self
60      .multi_matched
61      .entry(label.into())
62      .or_default()
63      .push(node);
64  }
65
66  pub fn get_labels(&self, label: &str) -> Option<&Vec<Node<'t, D>>> {
67    self.multi_matched.get(label)
68  }
69
70  pub fn get_matched_variables(&self) -> impl Iterator<Item = MetaVariable> + use<'_, 't, D> {
71    let single = self
72      .single_matched
73      .keys()
74      .cloned()
75      .map(|n| MetaVariable::Capture(n, false));
76    let transformed = self
77      .transformed_var
78      .keys()
79      .cloned()
80      .map(|n| MetaVariable::Capture(n, false));
81    let multi = self
82      .multi_matched
83      .keys()
84      .cloned()
85      .map(MetaVariable::MultiCapture);
86    single.chain(multi).chain(transformed)
87  }
88
89  fn match_variable(&self, id: &str, candidate: &Node<'t, D>) -> bool {
90    if let Some(m) = self.single_matched.get(id) {
91      return does_node_match_exactly(m, candidate);
92    }
93    true
94  }
95  fn match_multi_var(&self, id: &str, cands: &[Node<'t, D>]) -> bool {
96    let Some(nodes) = self.multi_matched.get(id) else {
97      return true;
98    };
99    let mut named_nodes = nodes.iter().filter(|n| n.is_named());
100    let mut named_cands = cands.iter().filter(|n| n.is_named());
101    loop {
102      if let Some(node) = named_nodes.next() {
103        let Some(cand) = named_cands.next() else {
104          // cand is done but node is not
105          break false;
106        };
107        if !does_node_match_exactly(node, cand) {
108          break false;
109        }
110      } else if named_cands.next().is_some() {
111        // node is done but cand is not
112        break false;
113      } else {
114        // both None, matches
115        break true;
116      }
117    }
118  }
119
120  pub fn match_constraints<M: Matcher>(
121    &mut self,
122    var_matchers: &HashMap<MetaVariableID, M>,
123  ) -> bool {
124    let mut env = Cow::Borrowed(self);
125    for (var_id, candidate) in &self.single_matched {
126      if let Some(m) = var_matchers.get(var_id) {
127        if m.match_node_with_env(candidate.clone(), &mut env).is_none() {
128          return false;
129        }
130      }
131    }
132    if let Cow::Owned(env) = env {
133      *self = env;
134    }
135    true
136  }
137
138  pub fn insert_transformation(&mut self, var: &MetaVariable, name: &str, slice: Underlying<D>) {
139    let node = match var {
140      MetaVariable::Capture(v, _) => self.single_matched.get(v),
141      MetaVariable::MultiCapture(vs) => self.multi_matched.get(vs).and_then(|vs| vs.first()),
142      _ => None,
143    };
144    let deindented = if let Some(v) = node {
145      formatted_slice(&slice, v.get_doc().get_source(), v.range().start).to_vec()
146    } else {
147      slice
148    };
149    self.transformed_var.insert(name.to_string(), deindented);
150  }
151
152  pub fn get_transformed(&self, var: &str) -> Option<&Underlying<D>> {
153    self.transformed_var.get(var)
154  }
155  pub fn get_var_bytes<'s>(
156    &'s self,
157    var: &MetaVariable,
158  ) -> Option<&'s [<D::Source as Content>::Underlying]> {
159    get_var_bytes_impl(self, var)
160  }
161}
162
163impl<D: Doc> MetaVarEnv<'_, D> {
164  /// internal for readopt NodeMatch in pinned.rs
165  /// readopt node and env when sending them to other threads
166  pub(crate) fn visit_nodes<F>(&mut self, mut f: F)
167  where
168    F: FnMut(&mut Node<'_, D>),
169  {
170    for n in self.single_matched.values_mut() {
171      f(n)
172    }
173    for ns in self.multi_matched.values_mut() {
174      for n in ns {
175        f(n)
176      }
177    }
178  }
179}
180
181fn get_var_bytes_impl<'e, 't, C, D>(
182  env: &'e MetaVarEnv<'t, D>,
183  var: &MetaVariable,
184) -> Option<&'e [C::Underlying]>
185where
186  D: Doc<Source = C> + 't,
187  C: Content + 't,
188{
189  match var {
190    MetaVariable::Capture(n, _) => {
191      if let Some(node) = env.get_match(n) {
192        let bytes = node.get_doc().get_source().get_range(node.range());
193        Some(bytes)
194      } else if let Some(bytes) = env.get_transformed(n) {
195        Some(bytes)
196      } else {
197        None
198      }
199    }
200    MetaVariable::MultiCapture(n) => {
201      let nodes = env.get_multiple_matches(n);
202      if nodes.is_empty() {
203        None
204      } else {
205        // NOTE: start_byte is not always index range of source's slice.
206        // e.g. start_byte is still byte_offset in utf_16 (napi). start_byte
207        // so we need to call source's get_range method
208        let start = nodes[0].range().start;
209        let end = nodes[nodes.len() - 1].range().end;
210        Some(nodes[0].get_doc().get_source().get_range(start..end))
211      }
212    }
213    _ => None,
214  }
215}
216
217impl<D: Doc> Default for MetaVarEnv<'_, D> {
218  fn default() -> Self {
219    Self::new()
220  }
221}
222
223#[derive(Clone, Debug, PartialEq, Eq)]
224pub enum MetaVariable {
225  /// $A for captured meta var
226  Capture(MetaVariableID, bool),
227  /// $_ for non-captured meta var
228  Dropped(bool),
229  /// $$$ for non-captured multi var
230  Multiple,
231  /// $$$A for captured ellipsis
232  MultiCapture(MetaVariableID),
233}
234
235pub(crate) fn extract_meta_var(src: &str, meta_char: char) -> Option<MetaVariable> {
236  use MetaVariable::*;
237  let ellipsis: String = std::iter::repeat(meta_char).take(3).collect();
238  if src == ellipsis {
239    return Some(Multiple);
240  }
241  if let Some(trimmed) = src.strip_prefix(&ellipsis) {
242    if !trimmed.chars().all(is_valid_meta_var_char) {
243      return None;
244    }
245    if trimmed.starts_with('_') {
246      return Some(Multiple);
247    } else {
248      return Some(MultiCapture(trimmed.to_owned()));
249    }
250  }
251  if !src.starts_with(meta_char) {
252    return None;
253  }
254  let trimmed = &src[meta_char.len_utf8()..];
255  let (trimmed, named) = if let Some(t) = trimmed.strip_prefix(meta_char) {
256    (t, false)
257  } else {
258    (trimmed, true)
259  };
260  if !trimmed.starts_with(is_valid_first_char) || // empty or started with number
261    !trimmed.chars().all(is_valid_meta_var_char)
262  // not in form of $A or $_
263  {
264    return None;
265  }
266  if trimmed.starts_with('_') {
267    Some(Dropped(named))
268  } else {
269    Some(Capture(trimmed.to_owned(), named))
270  }
271}
272
273#[inline]
274fn is_valid_first_char(c: char) -> bool {
275  matches!(c, 'A'..='Z' | '_')
276}
277
278#[inline]
279pub(crate) fn is_valid_meta_var_char(c: char) -> bool {
280  is_valid_first_char(c) || c.is_ascii_digit()
281}
282
283impl<'tree, D: Doc> From<MetaVarEnv<'tree, D>> for HashMap<String, String> {
284  fn from(env: MetaVarEnv<'tree, D>) -> Self {
285    let mut ret = HashMap::new();
286    for (id, node) in env.single_matched {
287      ret.insert(id, node.text().into());
288    }
289    for (id, bytes) in env.transformed_var {
290      ret.insert(id, <D::Source as Content>::encode_bytes(&bytes).to_string());
291    }
292    for (id, nodes) in env.multi_matched {
293      let s: Vec<_> = nodes.iter().map(|n| n.text()).collect();
294      let s = s.join(", ");
295      ret.insert(id, format!("[{s}]"));
296    }
297    ret
298  }
299}
300
#[cfg(test)]
mod test {
  use super::*;
  use crate::language::Tsx;
  use crate::tree_sitter::LanguageExt;
  use crate::Pattern;

  /// Pattern requiring both branches of an `if`/`else` to contain the same statements.
  const SAME_BRANCHES: &str = "if (true) { $$$A } else { $$$A }";

  /// Parses `s` with the default `$` meta character.
  fn extract_var(s: &str) -> Option<MetaVariable> {
    extract_meta_var(s, '$')
  }

  #[test]
  fn test_match_var() {
    use MetaVariable::*;
    let expectations = [
      ("$$$", Multiple),
      ("$ABC", Capture("ABC".into(), true)),
      ("$$ABC", Capture("ABC".into(), false)),
      ("$MATCH1", Capture("MATCH1".into(), true)),
      ("$$$ABC", MultiCapture("ABC".into())),
      ("$_", Dropped(true)),
      ("$_123", Dropped(true)),
      ("$$_", Dropped(false)),
    ];
    for (src, expected) in expectations {
      assert_eq!(extract_var(src), Some(expected), "parsing {src:?}");
    }
  }

  #[test]
  fn test_not_meta_var() {
    for src in ["$123", "$", "$$", "abc", "$abc"] {
      assert_eq!(extract_var(src), None, "{src:?} should not be a meta variable");
    }
  }

  /// Binds `$A` to the first child of the root's first child parsed from `source`,
  /// then reports whether that binding satisfies the constraint `pattern`.
  fn match_constraints(pattern: &str, source: &str) -> bool {
    let matchers = HashMap::from([("A".to_string(), Pattern::new(pattern, Tsx))]);
    let grep = Tsx.ast_grep(source);
    let target = grep.root().child(0).unwrap().child(0).unwrap();
    let mut env = MetaVarEnv::new();
    env.insert("A", target);
    env.match_constraints(&matchers)
  }

  #[test]
  fn test_non_ascii_meta_var() {
    use MetaVariable::*;
    let expectations = [
      ("µµµ", Some(Multiple)),
      ("µABC", Some(Capture("ABC".into(), true))),
      ("µµABC", Some(Capture("ABC".into(), false))),
      ("µµµABC", Some(MultiCapture("ABC".into()))),
      ("µ_", Some(Dropped(true))),
      ("abc", None),
      ("µabc", None),
    ];
    for (src, expected) in expectations {
      assert_eq!(extract_meta_var(src, 'µ'), expected, "parsing {src:?}");
    }
  }

  #[test]
  fn test_match_constraints() {
    assert!(match_constraints("a + b", "a + b"));
  }

  #[test]
  fn test_match_not_constraints() {
    assert!(!match_constraints("a - b", "a + b"));
  }

  #[test]
  fn test_multi_var_match() {
    let identical = Tsx.ast_grep("if (true) { a += 1; b += 1 } else { a += 1; b += 1 }");
    assert!(identical.root().find(SAME_BRANCHES).is_some());
    let different = Tsx.ast_grep("if (true) { a += 1 } else { b += 1 }");
    assert!(different.root().find(SAME_BRANCHES).is_none());
  }

  #[test]
  fn test_multi_var_match_with_trailing() {
    let longer_else = Tsx.ast_grep("if (true) { a += 1; } else { a += 1; b += 1 }");
    assert!(longer_else.root().find(SAME_BRANCHES).is_none());
    let longer_then = Tsx.ast_grep("if (true) { a += 1; b += 1; } else { a += 1 }");
    assert!(longer_then.root().find(SAME_BRANCHES).is_none());
  }
}