// squawk_ide/find_references.rs

1use crate::binder::{self, Binder};
2use crate::builtins::BUILTINS_SQL;
3use crate::goto_definition::{FileId, Location};
4use crate::offsets::token_from_offset;
5use crate::resolve;
6use rowan::TextSize;
7use smallvec::{SmallVec, smallvec};
8use squawk_syntax::{
9    SyntaxNodePtr,
10    ast::{self, AstNode},
11    match_ast,
12};
13
14pub fn find_references(file: &ast::SourceFile, offset: TextSize) -> Vec<Location> {
15    let current_binder = binder::bind(file);
16
17    let builtins_tree = ast::SourceFile::parse(BUILTINS_SQL).tree();
18    let builtins_binder = binder::bind(&builtins_tree);
19
20    let Some((target_file, target_defs)) = find_target_defs(
21        file,
22        offset,
23        &current_binder,
24        &builtins_tree,
25        &builtins_binder,
26    ) else {
27        return vec![];
28    };
29
30    let (binder, root) = match target_file {
31        FileId::Current => (&current_binder, file.syntax()),
32        FileId::Builtins => (&builtins_binder, builtins_tree.syntax()),
33    };
34
35    let mut refs: Vec<Location> = vec![];
36
37    if target_file == FileId::Builtins {
38        for ptr in &target_defs {
39            refs.push(Location {
40                file: FileId::Builtins,
41                range: ptr.to_node(builtins_tree.syntax()).text_range(),
42            });
43        }
44    }
45
46    for node in file.syntax().descendants() {
47        match_ast! {
48            match node {
49                ast::NameRef(name_ref) => {
50                    // Check if the ref matches one of the defs
51                    if let Some(found_defs) = resolve::resolve_name_ref_ptrs(binder, root, &name_ref)
52                        && found_defs.iter().any(|def| target_defs.contains(def))
53                    {
54                        refs.push(Location {
55                            file: FileId::Current,
56                            range: name_ref.syntax().text_range(),
57                        });
58                    }
59                },
60                ast::Name(name) => {
61                    // Find refs also includes the defs so we have to check.
62                    let found = SyntaxNodePtr::new(name.syntax());
63                    if target_defs.contains(&found) {
64                        refs.push(Location {
65                            file: FileId::Current,
66                            range: name.syntax().text_range(),
67                        });
68                    }
69                },
70                _ => (),
71            }
72        }
73    }
74
75    refs.sort_by_key(|loc| (loc.file, loc.range.start()));
76    refs
77}
78
79fn find_target_defs(
80    file: &ast::SourceFile,
81    offset: TextSize,
82    current_binder: &Binder,
83    builtins_tree: &ast::SourceFile,
84    builtins_binder: &Binder,
85) -> Option<(FileId, SmallVec<[SyntaxNodePtr; 1]>)> {
86    let token = token_from_offset(file, offset)?;
87    let parent = token.parent()?;
88
89    if let Some(name) = ast::Name::cast(parent.clone()) {
90        return Some((
91            FileId::Current,
92            smallvec![SyntaxNodePtr::new(name.syntax())],
93        ));
94    }
95
96    if let Some(name_ref) = ast::NameRef::cast(parent.clone()) {
97        for file_id in [FileId::Current, FileId::Builtins] {
98            let (binder, root) = match file_id {
99                FileId::Current => (current_binder, file.syntax()),
100                FileId::Builtins => (builtins_binder, builtins_tree.syntax()),
101            };
102            if let Some(ptrs) = resolve::resolve_name_ref_ptrs(binder, root, &name_ref) {
103                return Some((file_id, ptrs));
104            }
105        }
106    }
107
108    None
109}
110
// Snapshot tests for `find_references`: each test renders the query position and every
// returned location as annotations over the SQL source.
#[cfg(test)]
mod test {
    use crate::builtins::BUILTINS_SQL;
    use crate::find_references::find_references;
    use crate::goto_definition::FileId;
    use crate::test_utils::fixture;
    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
    use insta::assert_snapshot;
    use rowan::TextRange;
    use squawk_syntax::ast;

    /// Runs `find_references` on `sql` and renders the result as annotated snippets.
    ///
    /// `sql` must contain a `$0` cursor marker (handled by `fixture`) and must parse
    /// without errors. The query position is labelled `0. query`, and each returned
    /// location is labelled `N. reference`, numbered in the order `find_references`
    /// returned them. When any location is in the builtins, a second snippet over
    /// `BUILTINS_SQL` is appended, and both snippets are given a file path.
    #[track_caller]
    fn find_refs(sql: &str) -> String {
        let (mut offset, sql) = fixture(sql);
        // Presumably `fixture` returns the position of the `$0` marker, which the tests
        // place just after an identifier. Stepping back one character puts the cursor
        // on that identifier's last character (where the `0. query` label is drawn).
        // This saturates at 0.
        offset = offset.checked_sub(1.into()).unwrap_or_default();
        let parse = ast::SourceFile::parse(&sql);
        assert_eq!(parse.errors(), vec![]);
        let file: ast::SourceFile = parse.tree();

        let references = find_references(&file, offset);

        let offset_usize: usize = offset.into();

        // Split the results by file, numbering them from 1 in result order so the labels
        // reflect the ordering `find_references` produced.
        let mut current_refs = vec![];
        let mut builtin_refs = vec![];
        for (i, location) in references.iter().enumerate() {
            let label_index = i + 1;
            match location.file {
                FileId::Current => current_refs.push((label_index, location.range)),
                FileId::Builtins => builtin_refs.push((label_index, location.range)),
            }
        }

        let has_builtins = !builtin_refs.is_empty();

        let mut snippet = Snippet::source(&sql).fold(true);
        // Paths are only needed to tell the two snippets apart when builtins are shown.
        if has_builtins {
            snippet = snippet.path("current.sql");
        }
        snippet = snippet.annotation(
            AnnotationKind::Context
                .span(offset_usize..offset_usize + 1)
                .label("0. query"),
        );
        snippet = annotate_refs(snippet, current_refs);

        let mut groups = vec![Level::INFO.primary_title("references").element(snippet)];

        if has_builtins {
            let builtins_snippet = Snippet::source(BUILTINS_SQL).path("builtin.sql").fold(true);
            let builtins_snippet = annotate_refs(builtins_snippet, builtin_refs);
            groups.push(
                Level::INFO
                    .primary_title("references")
                    .element(builtins_snippet),
            );
        }

        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
        // Strip the group titles so the snapshots only contain the annotated source.
        renderer
            .render(&groups)
            .to_string()
            .replace("info: references", "")
    }

    /// Adds one `N. reference` context annotation to `snippet` for each `(N, range)` pair.
    fn annotate_refs<'a>(
        mut snippet: Snippet<'a, annotate_snippets::Annotation<'a>>,
        refs: Vec<(usize, TextRange)>,
    ) -> Snippet<'a, annotate_snippets::Annotation<'a>> {
        for (label_index, range) in refs {
            snippet = snippet.annotation(
                AnnotationKind::Context
                    .span(range.into())
                    .label(format!("{}. reference", label_index)),
            );
        }
        snippet
    }

    #[test]
    fn simple_table_reference() {
        assert_snapshot!(find_refs("
create table t();
drop table t$0;
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ─ 1. reference
        3 │ drop table t;
          │            ┬
          │            │
          │            0. query
          ╰╴           2. reference
        ");
    }

    #[test]
    fn multiple_references() {
        assert_snapshot!(find_refs("
create table users();
drop table users$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table users();
          │              ───── 1. reference
        3 │ drop table users;
          │            ┬───┬
          │            │   │
          │            │   0. query
          │            2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    // A `using (col)` column refers to the matching column of both joined tables.
    #[test]
    fn join_using_column() {
        assert_snapshot!(find_refs("
create table t(id int);
create table u(id int);
select * from t join u using (id$0);
"), @r"
          ╭▸ 
        2 │ create table t(id int);
          │                ── 1. reference
        3 │ create table u(id int);
          │                ── 2. reference
        4 │ select * from t join u using (id);
          │                               ┬┬
          │                               ││
          │                               │0. query
          ╰╴                              3. reference
        ");
    }

    // Querying a definition's name also finds the references to it.
    #[test]
    fn find_from_definition() {
        assert_snapshot!(find_refs("
create table t$0();
drop table t;
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ┬
          │              │
          │              0. query
          │              1. reference
        3 │ drop table t;
          ╰╴           ─ 2. reference
        ");
    }

    #[test]
    fn with_schema_qualified() {
        assert_snapshot!(find_refs("
create table public.users();
drop table public.users$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table public.users();
          │                     ───── 1. reference
        3 │ drop table public.users;
          │                   ┬───┬
          │                   │   │
          │                   │   0. query
          │                   2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    // The unqualified `drop table t` is not reported as a reference to the temp table.
    #[test]
    fn temp_table_do_not_shadows_public() {
        assert_snapshot!(find_refs("
create table t();
create temp table t$0();
drop table t;
"), @r"
          ╭▸ 
        3 │ create temp table t();
          │                   ┬
          │                   │
          │                   0. query
          ╰╴                  1. reference
        ");
    }

    // Same table name in different schemas: the two definitions are unrelated.
    #[test]
    fn different_schema_no_match() {
        assert_snapshot!(find_refs("
create table foo.t();
create table bar.t$0();
"), @r"
          ╭▸ 
        3 │ create table bar.t();
          │                  ┬
          │                  │
          │                  0. query
          ╰╴                 1. reference
        ");
    }

    // After `set search_path`, the unqualified `users` refers to `myschema.users`.
    #[test]
    fn with_search_path() {
        assert_snapshot!(find_refs("
set search_path to myschema;
create table myschema.users$0();
drop table users;
"), @r"
          ╭▸ 
        3 │ create table myschema.users();
          │                       ┬───┬
          │                       │   │
          │                       │   0. query
          │                       1. reference
        4 │ drop table users;
          ╰╴           ───── 2. reference
        ");
    }

    // `pg_temp.t` refers to a table created with `create temp table t`.
    #[test]
    fn temp_table_with_pg_temp_schema() {
        assert_snapshot!(find_refs("
create temp table t();
drop table pg_temp.t$0;
"), @r"
          ╭▸ 
        2 │ create temp table t();
          │                   ─ 1. reference
        3 │ drop table pg_temp.t;
          │                    ┬
          │                    │
          │                    0. query
          ╰╴                   2. reference
        ");
    }

    // Unquoted identifiers match regardless of case.
    #[test]
    fn case_insensitive() {
        assert_snapshot!(find_refs("
create table Users();
drop table USERS$0;
table users;
"), @r"
          ╭▸ 
        2 │ create table Users();
          │              ───── 1. reference
        3 │ drop table USERS;
          │            ┬───┬
          │            │   │
          │            │   0. query
          │            2. reference
        4 │ table users;
          ╰╴      ───── 3. reference
        ");
    }

    #[test]
    fn case_insensitive_part_2() {
        // Unquoted `ACTORS` and `actors` both refer to `actors`. The quoted `"Actors"` is
        // a different table, so only the `drop table` and `table` statements are
        // references.
        assert_snapshot!(find_refs(r#"
create table actors();
create table "Actors"();
drop table ACTORS$0;
table actors;
"#), @r#"
          ╭▸ 
        2 │ create table actors();
          │              ────── 1. reference
        3 │ create table "Actors"();
        4 │ drop table ACTORS;
          │            ┬────┬
          │            │    │
          │            │    0. query
          │            2. reference
        5 │ table actors;
          ╰╴      ────── 3. reference
        "#);
    }

    // Case-insensitivity applies to the schema part of a qualified name as well.
    #[test]
    fn case_insensitive_with_schema() {
        assert_snapshot!(find_refs("
create table Public.Users();
drop table PUBLIC.USERS$0;
table public.users;
"), @r"
          ╭▸ 
        2 │ create table Public.Users();
          │                     ───── 1. reference
        3 │ drop table PUBLIC.USERS;
          │                   ┬───┬
          │                   │   │
          │                   │   0. query
          │                   2. reference
        4 │ table public.users;
          ╰╴             ───── 3. reference
        ");
    }

    // A name contained in another identifier (`temp_t`) is not a reference.
    #[test]
    fn no_partial_match() {
        assert_snapshot!(find_refs("
create table t$0();
create table temp_t();
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ┬
          │              │
          │              0. query
          ╰╴             1. reference
        ");
    }

    // Only whole-identifier matches count: `foo1`, `barfoo` and `foo_bar` are not `foo`.
    #[test]
    fn identifier_boundaries() {
        assert_snapshot!(find_refs("
create table foo$0();
drop table foo;
drop table foo1;
drop table barfoo;
drop table foo_bar;
"), @r"
          ╭▸ 
        2 │ create table foo();
          │              ┬─┬
          │              │ │
          │              │ 0. query
          │              1. reference
        3 │ drop table foo;
          ╰╴           ─── 2. reference
        ");
    }

    // References to a builtin include its definition in the builtins file.
    #[test]
    fn builtin_function_references() {
        // NOTE(review): the `│` gutter lines directly under each `╭▸ <path>` header below
        // were reconstructed from a damaged copy of this file. Confirm with
        // `cargo insta test`.
        assert_snapshot!(find_refs("
select now$0();
select now();
"), @r"
              ╭▸ current.sql:2:8
              │
            2 │ select now();
              │        ┬─┬
              │        │ │
              │        │ 0. query
              │        1. reference
            3 │ select now();
              │        ─── 2. reference
              ╰╴

              ╭▸ builtin.sql:10798:28
              │
        10798 │ create function pg_catalog.now() returns timestamp with time zone
              ╰╴                           ─── 3. reference
        ");
    }
}