grit_pattern_matcher/pattern/
state.rs

1use super::{
2    patterns::Pattern,
3    variable::{Variable, VariableScope},
4    variable_content::VariableContent,
5};
6use crate::{
7    binding::Binding,
8    constants::MATCH_VAR,
9    context::QueryContext,
10    effects::Effect,
11    file_owners::FileOwner,
12    intervals::{earliest_deadline_sort, get_top_level_intervals_in_range, Interval},
13    pattern::resolved_pattern::ResolvedPattern,
14};
15use grit_util::{
16    error::{GritPatternError, GritResult},
17    AnalysisLogs, CodeRange, Range, VariableMatch,
18};
19use rand::SeedableRng;
20use std::ops::Range as StdRange;
21use std::{collections::HashMap, path::Path};
22
23#[derive(Debug, Clone)]
24pub struct EffectRange<'a, Q: QueryContext> {
25    range: StdRange<u32>,
26    pub effect: Effect<'a, Q>,
27}
28
29impl<Q: QueryContext> Interval for EffectRange<'_, Q> {
30    fn interval(&self) -> (u32, u32) {
31        (self.range.start, self.range.end)
32    }
33}
34
35#[derive(Clone, Debug)]
36pub struct FileRegistry<'a, Q: QueryContext> {
37    /// The number of versions for each file
38    version_count: Vec<u16>,
39    /// Original file paths, for lazy loading
40    file_paths: Vec<&'a Path>,
41    /// The actual FileOwner, which has the full file available
42    owners: Vec<Vec<&'a FileOwner<Q::Tree<'a>>>>,
43}
44
45impl<'a, Q: QueryContext> FileRegistry<'a, Q> {
46    pub fn get_file_owner(&self, pointer: FilePtr) -> &'a FileOwner<Q::Tree<'a>> {
47        #[cfg(debug_assertions)]
48        {
49            if pointer.file as usize >= self.owners.len() {
50                panic!(
51                    "File index out of bounds: file={}, owners.len()={}",
52                    pointer.file,
53                    self.owners.len()
54                );
55            }
56            let file_owners = &self.owners[pointer.file as usize];
57            if pointer.version as usize >= file_owners.len() {
58                let name = self.get_file_name(pointer);
59                panic!(
60                    "File ({}) does not have version ({}) available. Only {} versions available. Make sure load_file is called before accessing file owners.",
61                    name.to_string_lossy(),
62                    pointer.version,
63                    file_owners.len()
64                );
65            }
66        }
67
68        self.owners[pointer.file as usize][pointer.version as usize]
69    }
70
71    pub fn get_file_name(&self, pointer: FilePtr) -> &'a Path {
72        let file_index = pointer.file as usize;
73        let version_index = pointer.version as usize;
74        if let Some(owners) = self.owners.get(file_index) {
75            if let Some(owner) = owners.get(version_index) {
76                return &owner.name;
77            }
78        }
79        self.file_paths
80            .get(file_index)
81            .expect("File path should exist for given file index.")
82    }
83
84    pub fn get_absolute_path(&self, pointer: FilePtr) -> GritResult<&'a Path> {
85        let file_index = pointer.file as usize;
86        let version_index = pointer.version as usize;
87        if let Some(owners) = self.owners.get(file_index) {
88            if let Some(owner) = owners.get(version_index) {
89                return Ok(&owner.absolute_path);
90            }
91        }
92        Err(GritPatternError::new(
93            "Absolute file path accessed before file was loaded.",
94        ))
95    }
96
97    /// If only the paths are available, create a FileRegistry with empty owners
98    /// This is a logic error if you do not later insert the appropriate owners before get_file_owner is called
99    pub fn new_from_paths(file_paths: Vec<&'a Path>) -> Self {
100        Self {
101            version_count: file_paths.iter().map(|_| 0).collect(),
102            owners: file_paths.iter().map(|_| Vec::new()).collect(),
103            file_paths,
104        }
105    }
106
107    /// Confirms a file is already fully loaded
108    pub fn is_loaded(&self, pointer: &FilePtr) -> bool {
109        self.version_count
110            .get(pointer.file as usize)
111            .map_or(false, |&v| v > 0)
112    }
113
114    /// Load a file in
115    pub fn load_file(&mut self, pointer: &FilePtr, file: &'a FileOwner<Q::Tree<'a>>) {
116        self.push_revision(pointer, file)
117    }
118
119    /// Returns the latest revision of a given filepointer
120    /// If none exists, returns the file pointer itself
121    pub fn latest_revision(&self, pointer: &FilePtr) -> FilePtr {
122        match self.version_count.get(pointer.file as usize) {
123            Some(&version_count) => {
124                if version_count == 0 {
125                    *pointer
126                } else {
127                    FilePtr {
128                        file: pointer.file,
129                        version: version_count - 1,
130                    }
131                }
132            }
133            None => *pointer,
134        }
135    }
136
137    pub fn files(&self) -> &Vec<Vec<&'a FileOwner<Q::Tree<'a>>>> {
138        &self.owners
139    }
140
141    pub fn push_revision(&mut self, pointer: &FilePtr, file: &'a FileOwner<Q::Tree<'a>>) {
142        self.version_count[pointer.file as usize] += 1;
143        self.owners[pointer.file as usize].push(file)
144    }
145
146    pub fn push_new_file(&mut self, file: &'a FileOwner<Q::Tree<'a>>) -> FilePtr {
147        self.version_count.push(1);
148        self.file_paths.push(&file.name);
149        self.owners.push(vec![file]);
150        FilePtr {
151            file: (self.owners.len() - 1) as u16,
152            version: 0,
153        }
154    }
155}
156
157// todo: we don't want to clone pattern definitions when cloning State
158#[derive(Clone, Debug)]
159pub struct State<'a, Q: QueryContext> {
160    pub bindings: VarRegistry<'a, Q>,
161    pub effects: Vec<Effect<'a, Q>>,
162    pub files: FileRegistry<'a, Q>,
163    rng: rand::rngs::StdRng,
164    current_scope: usize,
165    // Track dynamic pattern scope names
166    pattern_scopes: HashMap<String, usize>,
167}
168
169fn get_top_level_effect_ranges<'a, Q: QueryContext>(
170    effects: &[Effect<'a, Q>],
171    memo: &HashMap<CodeRange, Option<String>>,
172    range: &CodeRange,
173    language: &Q::Language<'a>,
174    logs: &mut AnalysisLogs,
175) -> GritResult<Vec<EffectRange<'a, Q>>> {
176    let mut effects: Vec<EffectRange<Q>> = effects
177        .iter()
178        .filter(|effect| {
179            let binding = &effect.binding;
180            if let Some(src) = binding.source() {
181                if let Some(binding_range) = binding.code_range(language) {
182                    range.applies_to(src) && !matches!(memo.get(&binding_range), Some(None))
183                } else {
184                    let _ = binding.log_empty_field_rewrite_error(language, logs);
185                    false
186                }
187            } else {
188                false
189            }
190        })
191        .map(|effect| {
192            let binding = &effect.binding;
193            let byte_range = binding
194                .range(language)
195                .ok_or_else(|| GritPatternError::new("binding has no range"))?;
196            let end_byte = byte_range.end as u32;
197            let start_byte = byte_range.start as u32;
198            Ok(EffectRange {
199                range: start_byte..end_byte,
200                effect: effect.clone(),
201            })
202        })
203        .collect::<GritResult<Vec<EffectRange<Q>>>>()?;
204    if !earliest_deadline_sort(&mut effects) {
205        return Err(GritPatternError::new("effects have overlapping ranges"));
206    }
207    Ok(get_top_level_intervals_in_range(
208        effects,
209        range.start,
210        range.end,
211    ))
212}
213
214pub fn get_top_level_effects<'a, Q: QueryContext>(
215    effects: &[Effect<'a, Q>],
216    memo: &HashMap<CodeRange, Option<String>>,
217    range: &CodeRange,
218    language: &Q::Language<'a>,
219    logs: &mut AnalysisLogs,
220) -> GritResult<Vec<Effect<'a, Q>>> {
221    let top_level = get_top_level_effect_ranges(effects, memo, range, language, logs)?;
222    let top_level: Vec<Effect<'a, Q>> = top_level
223        .into_iter()
224        .map(|e| {
225            assert!(e.range.start >= range.start);
226            assert!(e.range.end <= range.end);
227            e.effect
228        })
229        .collect();
230    Ok(top_level)
231}
232
/// Identifies one version of one file in a [`FileRegistry`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub struct FilePtr {
    /// Index of the file within the registry.
    pub file: u16,
    /// Index of the loaded version of that file.
    pub version: u16,
}

impl FilePtr {
    /// Builds a pointer to version `version` of file `file`.
    pub fn new(file: u16, version: u16) -> Self {
        FilePtr { file, version }
    }
}
244
/// Remembers which scope was current before [`State::enter_scope`] switched scopes.
///
/// It must be handed back to [`State::exit_scope`] to restore that scope. Dropping it instead
/// leaves the state in the entered scope, so the type is `#[must_use]`.
#[must_use = "pass the tracker to `State::exit_scope` to restore the previous scope"]
pub struct ScopeTracker {
    /// Scope index that was current before the scope was entered.
    previous_scope: usize,
}
248
249impl<'a, Q: QueryContext> State<'a, Q> {
250    pub fn new(bindings: VarRegistry<'a, Q>, registry: FileRegistry<'a, Q>) -> Self {
251        Self {
252            rng: rand::rngs::StdRng::seed_from_u64(32),
253            current_scope: 0,
254            bindings,
255            effects: vec![],
256            files: registry,
257            pattern_scopes: HashMap::new(),
258        }
259    }
260
261    pub fn get_files<'b>(&'b self) -> &'b FileRegistry<Q>
262    where
263        'b: 'a,
264    {
265        &self.files
266    }
267
268    // Grit uses a fixed seed RNG for reproducibility
269    pub fn get_rng(&mut self) -> &mut rand::rngs::StdRng {
270        &mut self.rng
271    }
272
273    /// Enter a scope by copying the current scope and adding the new variables
274    /// When you are done with a scope, you *must* call exit_scope
275    ///
276    /// # Parameters
277    ///
278    /// * `scope` - The scope to enter
279    /// * `args` - The arguments to the scope
280    pub(crate) fn enter_scope(
281        &mut self,
282        scope: usize,
283        args: &'a [Option<Pattern<Q>>],
284    ) -> ScopeTracker {
285        let old_scope = self.bindings[scope].last().unwrap();
286        let new_scope: Vec<Box<VariableContent<Q>>> = old_scope
287            .iter()
288            .enumerate()
289            .map(|(index, content)| {
290                let mut content = content.clone();
291                let pattern = args.get(index).and_then(Option::as_ref);
292                if let Some(Pattern::Variable(v)) = pattern {
293                    content.mirrors.push(v)
294                };
295                Box::new(VariableContent {
296                    pattern,
297                    value: None,
298                    value_history: Vec::new(),
299                    ..*content
300                })
301            })
302            .collect();
303        self.bindings[scope].push(new_scope);
304
305        let old_scope_index = self.current_scope;
306        self.current_scope = scope;
307
308        ScopeTracker {
309            previous_scope: old_scope_index,
310        }
311    }
312
313    pub(crate) fn exit_scope(&mut self, tracker: ScopeTracker) {
314        self.current_scope = tracker.previous_scope;
315    }
316
317    pub fn register_pattern_definition(&mut self, name: &str) -> usize {
318        if let Some(scope) = self.pattern_scopes.get(name) {
319            *scope
320        } else {
321            // The dynamic pattern definition is always in a *new* scope
322            let registered_scope = self.bindings.len();
323            self.bindings.push(vec![vec![]]);
324            self.pattern_scopes
325                .insert(name.to_string(), registered_scope);
326            registered_scope
327        }
328    }
329
330    // unfortunately these accessor functions are not as useful as they
331    // could be due to the inability of rust to split borrows across functions
332    // within a function you could mutably borrow bindings, and immutably borrow
333    // src simultaneously, but you can't do that across functions.
334    // see:
335    // https://stackoverflow.com/questions/61699010/rust-not-allowing-mutable-borrow-when-splitting-properly
336    // https://doc.rust-lang.org/nomicon/borrow-splitting.html
337    // todo split State in a sensible way.
338    pub fn get_name(&self, var: &Variable) -> &str {
339        &self.bindings[var.try_scope().unwrap() as usize]
340            .last()
341            .unwrap()[var.try_index().unwrap() as usize]
342            .name
343    }
344
345    /// Attempt to find a variable by name in any scope
346    /// This is inefficient and should only be used when we haven't pre-allocated a Variable reference
347    ///
348    /// If you have a Variable reference, use `trace_var` instead to find the latest binding
349    pub fn find_var(&self, name: &str) -> Option<Variable> {
350        if let Some(scope) = self.find_var_scope(name) {
351            return Some(Variable::new(scope.scope as usize, scope.index as usize));
352        }
353        None
354    }
355
356    /// Find a variable's registered scope, by name in any scope
357    fn find_var_scope(&self, name: &str) -> Option<VariableScope> {
358        for (scope_index, scope) in self.bindings.iter().enumerate().rev() {
359            for (index, content) in scope.last().unwrap().iter().enumerate() {
360                if content.name == name {
361                    return Some(VariableScope::new(scope_index, index));
362                }
363            }
364        }
365        None
366    }
367
368    pub(crate) fn register_var(&mut self, name: &str) -> VariableScope {
369        if let Some(existing) = self.find_var_scope(name) {
370            return existing;
371        };
372
373        let scope = self.current_scope;
374        let the_scope = self.bindings[self.current_scope].last_mut().unwrap();
375        let index = the_scope.len();
376
377        the_scope.push(Box::new(VariableContent::new(name.to_string())));
378        VariableScope::new(scope, index)
379    }
380
381    /// Attempt to find a variable by name in the current scope
382    pub fn find_var_in_scope(&mut self, name: &str) -> Option<Variable> {
383        for (index, content) in self.bindings[self.current_scope]
384            .last()
385            .unwrap()
386            .iter()
387            .enumerate()
388        {
389            if content.name == name {
390                return Some(Variable::new(self.current_scope, index));
391            }
392        }
393        None
394    }
395
396    /// Trace a variable to the root binding
397    /// Where possible, prefer trace_var_mut
398    pub fn trace_var(&self, var: &Variable) -> Variable {
399        if let Ok(scope) = var.try_scope() {
400            if let Ok(index) = var.try_index() {
401                if let Some(Pattern::Variable(v)) =
402                    &self.bindings[scope as usize].last().unwrap()[index as usize].pattern
403                {
404                    return self.trace_var(v);
405                }
406            }
407        }
408        var.clone()
409    }
410
411    pub fn trace_var_mut(&mut self, var: &Variable) -> Variable {
412        if let Ok(scope) = var.get_scope(self) {
413            if let Ok(index) = var.get_index(self) {
414                if let Some(Pattern::Variable(v)) =
415                    &self.bindings[scope as usize].last().unwrap()[index as usize].pattern
416                {
417                    return self.trace_var_mut(v);
418                }
419            }
420        }
421        var.clone()
422    }
423
424    pub fn bindings_history_to_ranges(
425        &self,
426        language: &Q::Language<'a>,
427        current_name: Option<&str>,
428    ) -> (Vec<VariableMatch>, Vec<Range>, bool) {
429        let mut matches = vec![];
430        let mut top_level_matches = vec![];
431        let mut suppressed = false;
432        for (i, scope) in self.bindings.iter().enumerate() {
433            for (j, content) in scope.last().unwrap().iter().enumerate() {
434                let name = content.name.clone();
435                let mut var_ranges = vec![];
436                let mut bindings_count = 0;
437                let mut suppressed_count = 0;
438                for value in content.value_history.iter() {
439                    if let Some(bindings) = value.get_bindings() {
440                        for binding in bindings {
441                            bindings_count += 1;
442                            if binding.is_suppressed(language, current_name) {
443                                suppressed_count += 1;
444                                continue;
445                            }
446                            if let Some(match_position) = binding.position(language) {
447                                // TODO, this check only needs to be done at the global scope right?
448                                if name == MATCH_VAR {
449                                    // apply_match = true;
450                                    top_level_matches.push(match_position);
451                                }
452                                var_ranges.push(match_position);
453                            }
454                        }
455                    }
456                }
457                if suppressed_count > 0 && suppressed_count == bindings_count {
458                    suppressed = true;
459                    continue;
460                }
461                let scoped_name = format!("{}_{}_{}", i, j, name);
462                let var_match = VariableMatch::new(name, scoped_name, var_ranges);
463                matches.push(var_match);
464            }
465        }
466        suppressed = suppressed && top_level_matches.is_empty();
467        (matches, top_level_matches, suppressed)
468    }
469}
470
471pub type VarRegistry<'a, P> = Vec<Vec<Vec<Box<VariableContent<'a, P>>>>>;