1use super::{
2 patterns::Pattern,
3 variable::{Variable, VariableScope},
4 variable_content::VariableContent,
5};
6use crate::{
7 binding::Binding,
8 constants::MATCH_VAR,
9 context::QueryContext,
10 effects::Effect,
11 file_owners::FileOwner,
12 intervals::{earliest_deadline_sort, get_top_level_intervals_in_range, Interval},
13 pattern::resolved_pattern::ResolvedPattern,
14};
15use grit_util::{
16 error::{GritPatternError, GritResult},
17 AnalysisLogs, CodeRange, Range, VariableMatch,
18};
19use rand::SeedableRng;
20use std::ops::Range as StdRange;
21use std::{collections::HashMap, path::Path};
22
23#[derive(Debug, Clone)]
24pub struct EffectRange<'a, Q: QueryContext> {
25 range: StdRange<u32>,
26 pub effect: Effect<'a, Q>,
27}
28
29impl<Q: QueryContext> Interval for EffectRange<'_, Q> {
30 fn interval(&self) -> (u32, u32) {
31 (self.range.start, self.range.end)
32 }
33}
34
35#[derive(Clone, Debug)]
36pub struct FileRegistry<'a, Q: QueryContext> {
37 version_count: Vec<u16>,
39 file_paths: Vec<&'a Path>,
41 owners: Vec<Vec<&'a FileOwner<Q::Tree<'a>>>>,
43}
44
45impl<'a, Q: QueryContext> FileRegistry<'a, Q> {
46 pub fn get_file_owner(&self, pointer: FilePtr) -> &'a FileOwner<Q::Tree<'a>> {
47 #[cfg(debug_assertions)]
48 {
49 if pointer.file as usize >= self.owners.len() {
50 panic!(
51 "File index out of bounds: file={}, owners.len()={}",
52 pointer.file,
53 self.owners.len()
54 );
55 }
56 let file_owners = &self.owners[pointer.file as usize];
57 if pointer.version as usize >= file_owners.len() {
58 let name = self.get_file_name(pointer);
59 panic!(
60 "File ({}) does not have version ({}) available. Only {} versions available. Make sure load_file is called before accessing file owners.",
61 name.to_string_lossy(),
62 pointer.version,
63 file_owners.len()
64 );
65 }
66 }
67
68 self.owners[pointer.file as usize][pointer.version as usize]
69 }
70
71 pub fn get_file_name(&self, pointer: FilePtr) -> &'a Path {
72 let file_index = pointer.file as usize;
73 let version_index = pointer.version as usize;
74 if let Some(owners) = self.owners.get(file_index) {
75 if let Some(owner) = owners.get(version_index) {
76 return &owner.name;
77 }
78 }
79 self.file_paths
80 .get(file_index)
81 .expect("File path should exist for given file index.")
82 }
83
84 pub fn get_absolute_path(&self, pointer: FilePtr) -> GritResult<&'a Path> {
85 let file_index = pointer.file as usize;
86 let version_index = pointer.version as usize;
87 if let Some(owners) = self.owners.get(file_index) {
88 if let Some(owner) = owners.get(version_index) {
89 return Ok(&owner.absolute_path);
90 }
91 }
92 Err(GritPatternError::new(
93 "Absolute file path accessed before file was loaded.",
94 ))
95 }
96
97 pub fn new_from_paths(file_paths: Vec<&'a Path>) -> Self {
100 Self {
101 version_count: file_paths.iter().map(|_| 0).collect(),
102 owners: file_paths.iter().map(|_| Vec::new()).collect(),
103 file_paths,
104 }
105 }
106
107 pub fn is_loaded(&self, pointer: &FilePtr) -> bool {
109 self.version_count
110 .get(pointer.file as usize)
111 .map_or(false, |&v| v > 0)
112 }
113
114 pub fn load_file(&mut self, pointer: &FilePtr, file: &'a FileOwner<Q::Tree<'a>>) {
116 self.push_revision(pointer, file)
117 }
118
119 pub fn latest_revision(&self, pointer: &FilePtr) -> FilePtr {
122 match self.version_count.get(pointer.file as usize) {
123 Some(&version_count) => {
124 if version_count == 0 {
125 *pointer
126 } else {
127 FilePtr {
128 file: pointer.file,
129 version: version_count - 1,
130 }
131 }
132 }
133 None => *pointer,
134 }
135 }
136
137 pub fn files(&self) -> &Vec<Vec<&'a FileOwner<Q::Tree<'a>>>> {
138 &self.owners
139 }
140
141 pub fn push_revision(&mut self, pointer: &FilePtr, file: &'a FileOwner<Q::Tree<'a>>) {
142 self.version_count[pointer.file as usize] += 1;
143 self.owners[pointer.file as usize].push(file)
144 }
145
146 pub fn push_new_file(&mut self, file: &'a FileOwner<Q::Tree<'a>>) -> FilePtr {
147 self.version_count.push(1);
148 self.file_paths.push(&file.name);
149 self.owners.push(vec![file]);
150 FilePtr {
151 file: (self.owners.len() - 1) as u16,
152 version: 0,
153 }
154 }
155}
156
157#[derive(Clone, Debug)]
159pub struct State<'a, Q: QueryContext> {
160 pub bindings: VarRegistry<'a, Q>,
161 pub effects: Vec<Effect<'a, Q>>,
162 pub files: FileRegistry<'a, Q>,
163 rng: rand::rngs::StdRng,
164 current_scope: usize,
165 pattern_scopes: HashMap<String, usize>,
167}
168
169fn get_top_level_effect_ranges<'a, Q: QueryContext>(
170 effects: &[Effect<'a, Q>],
171 memo: &HashMap<CodeRange, Option<String>>,
172 range: &CodeRange,
173 language: &Q::Language<'a>,
174 logs: &mut AnalysisLogs,
175) -> GritResult<Vec<EffectRange<'a, Q>>> {
176 let mut effects: Vec<EffectRange<Q>> = effects
177 .iter()
178 .filter(|effect| {
179 let binding = &effect.binding;
180 if let Some(src) = binding.source() {
181 if let Some(binding_range) = binding.code_range(language) {
182 range.applies_to(src) && !matches!(memo.get(&binding_range), Some(None))
183 } else {
184 let _ = binding.log_empty_field_rewrite_error(language, logs);
185 false
186 }
187 } else {
188 false
189 }
190 })
191 .map(|effect| {
192 let binding = &effect.binding;
193 let byte_range = binding
194 .range(language)
195 .ok_or_else(|| GritPatternError::new("binding has no range"))?;
196 let end_byte = byte_range.end as u32;
197 let start_byte = byte_range.start as u32;
198 Ok(EffectRange {
199 range: start_byte..end_byte,
200 effect: effect.clone(),
201 })
202 })
203 .collect::<GritResult<Vec<EffectRange<Q>>>>()?;
204 if !earliest_deadline_sort(&mut effects) {
205 return Err(GritPatternError::new("effects have overlapping ranges"));
206 }
207 Ok(get_top_level_intervals_in_range(
208 effects,
209 range.start,
210 range.end,
211 ))
212}
213
214pub fn get_top_level_effects<'a, Q: QueryContext>(
215 effects: &[Effect<'a, Q>],
216 memo: &HashMap<CodeRange, Option<String>>,
217 range: &CodeRange,
218 language: &Q::Language<'a>,
219 logs: &mut AnalysisLogs,
220) -> GritResult<Vec<Effect<'a, Q>>> {
221 let top_level = get_top_level_effect_ranges(effects, memo, range, language, logs)?;
222 let top_level: Vec<Effect<'a, Q>> = top_level
223 .into_iter()
224 .map(|e| {
225 assert!(e.range.start >= range.start);
226 assert!(e.range.end <= range.end);
227 e.effect
228 })
229 .collect();
230 Ok(top_level)
231}
232
/// Addresses one version of one registered file.
///
/// `file` is the file's index in the `FileRegistry`, and `version` selects
/// among that file's loaded versions, starting at 0.
#[derive(Debug, Clone, Copy, Eq, Hash, PartialEq)]
pub struct FilePtr {
    /// Index of the file in the registry.
    pub file: u16,
    /// Index of the loaded version of that file.
    pub version: u16,
}
238
239impl FilePtr {
240 pub fn new(file: u16, version: u16) -> Self {
241 Self { file, version }
242 }
243}
244
/// Token returned by `State::enter_scope` that remembers which scope was
/// current before entering. Pass it to `State::exit_scope` to restore it.
pub struct ScopeTracker {
    /// The value of `current_scope` before `enter_scope` was called.
    previous_scope: usize,
}
248
249impl<'a, Q: QueryContext> State<'a, Q> {
250 pub fn new(bindings: VarRegistry<'a, Q>, registry: FileRegistry<'a, Q>) -> Self {
251 Self {
252 rng: rand::rngs::StdRng::seed_from_u64(32),
253 current_scope: 0,
254 bindings,
255 effects: vec![],
256 files: registry,
257 pattern_scopes: HashMap::new(),
258 }
259 }
260
261 pub fn get_files<'b>(&'b self) -> &'b FileRegistry<Q>
262 where
263 'b: 'a,
264 {
265 &self.files
266 }
267
268 pub fn get_rng(&mut self) -> &mut rand::rngs::StdRng {
270 &mut self.rng
271 }
272
273 pub(crate) fn enter_scope(
281 &mut self,
282 scope: usize,
283 args: &'a [Option<Pattern<Q>>],
284 ) -> ScopeTracker {
285 let old_scope = self.bindings[scope].last().unwrap();
286 let new_scope: Vec<Box<VariableContent<Q>>> = old_scope
287 .iter()
288 .enumerate()
289 .map(|(index, content)| {
290 let mut content = content.clone();
291 let pattern = args.get(index).and_then(Option::as_ref);
292 if let Some(Pattern::Variable(v)) = pattern {
293 content.mirrors.push(v)
294 };
295 Box::new(VariableContent {
296 pattern,
297 value: None,
298 value_history: Vec::new(),
299 ..*content
300 })
301 })
302 .collect();
303 self.bindings[scope].push(new_scope);
304
305 let old_scope_index = self.current_scope;
306 self.current_scope = scope;
307
308 ScopeTracker {
309 previous_scope: old_scope_index,
310 }
311 }
312
313 pub(crate) fn exit_scope(&mut self, tracker: ScopeTracker) {
314 self.current_scope = tracker.previous_scope;
315 }
316
317 pub fn register_pattern_definition(&mut self, name: &str) -> usize {
318 if let Some(scope) = self.pattern_scopes.get(name) {
319 *scope
320 } else {
321 let registered_scope = self.bindings.len();
323 self.bindings.push(vec![vec![]]);
324 self.pattern_scopes
325 .insert(name.to_string(), registered_scope);
326 registered_scope
327 }
328 }
329
330 pub fn get_name(&self, var: &Variable) -> &str {
339 &self.bindings[var.try_scope().unwrap() as usize]
340 .last()
341 .unwrap()[var.try_index().unwrap() as usize]
342 .name
343 }
344
345 pub fn find_var(&self, name: &str) -> Option<Variable> {
350 if let Some(scope) = self.find_var_scope(name) {
351 return Some(Variable::new(scope.scope as usize, scope.index as usize));
352 }
353 None
354 }
355
356 fn find_var_scope(&self, name: &str) -> Option<VariableScope> {
358 for (scope_index, scope) in self.bindings.iter().enumerate().rev() {
359 for (index, content) in scope.last().unwrap().iter().enumerate() {
360 if content.name == name {
361 return Some(VariableScope::new(scope_index, index));
362 }
363 }
364 }
365 None
366 }
367
368 pub(crate) fn register_var(&mut self, name: &str) -> VariableScope {
369 if let Some(existing) = self.find_var_scope(name) {
370 return existing;
371 };
372
373 let scope = self.current_scope;
374 let the_scope = self.bindings[self.current_scope].last_mut().unwrap();
375 let index = the_scope.len();
376
377 the_scope.push(Box::new(VariableContent::new(name.to_string())));
378 VariableScope::new(scope, index)
379 }
380
381 pub fn find_var_in_scope(&mut self, name: &str) -> Option<Variable> {
383 for (index, content) in self.bindings[self.current_scope]
384 .last()
385 .unwrap()
386 .iter()
387 .enumerate()
388 {
389 if content.name == name {
390 return Some(Variable::new(self.current_scope, index));
391 }
392 }
393 None
394 }
395
396 pub fn trace_var(&self, var: &Variable) -> Variable {
399 if let Ok(scope) = var.try_scope() {
400 if let Ok(index) = var.try_index() {
401 if let Some(Pattern::Variable(v)) =
402 &self.bindings[scope as usize].last().unwrap()[index as usize].pattern
403 {
404 return self.trace_var(v);
405 }
406 }
407 }
408 var.clone()
409 }
410
411 pub fn trace_var_mut(&mut self, var: &Variable) -> Variable {
412 if let Ok(scope) = var.get_scope(self) {
413 if let Ok(index) = var.get_index(self) {
414 if let Some(Pattern::Variable(v)) =
415 &self.bindings[scope as usize].last().unwrap()[index as usize].pattern
416 {
417 return self.trace_var_mut(v);
418 }
419 }
420 }
421 var.clone()
422 }
423
424 pub fn bindings_history_to_ranges(
425 &self,
426 language: &Q::Language<'a>,
427 current_name: Option<&str>,
428 ) -> (Vec<VariableMatch>, Vec<Range>, bool) {
429 let mut matches = vec![];
430 let mut top_level_matches = vec![];
431 let mut suppressed = false;
432 for (i, scope) in self.bindings.iter().enumerate() {
433 for (j, content) in scope.last().unwrap().iter().enumerate() {
434 let name = content.name.clone();
435 let mut var_ranges = vec![];
436 let mut bindings_count = 0;
437 let mut suppressed_count = 0;
438 for value in content.value_history.iter() {
439 if let Some(bindings) = value.get_bindings() {
440 for binding in bindings {
441 bindings_count += 1;
442 if binding.is_suppressed(language, current_name) {
443 suppressed_count += 1;
444 continue;
445 }
446 if let Some(match_position) = binding.position(language) {
447 if name == MATCH_VAR {
449 top_level_matches.push(match_position);
451 }
452 var_ranges.push(match_position);
453 }
454 }
455 }
456 }
457 if suppressed_count > 0 && suppressed_count == bindings_count {
458 suppressed = true;
459 continue;
460 }
461 let scoped_name = format!("{}_{}_{}", i, j, name);
462 let var_match = VariableMatch::new(name, scoped_name, var_ranges);
463 matches.push(var_match);
464 }
465 }
466 suppressed = suppressed && top_level_matches.is_empty();
467 (matches, top_level_matches, suppressed)
468 }
469}
470
471pub type VarRegistry<'a, P> = Vec<Vec<Vec<Box<VariableContent<'a, P>>>>>;