Skip to main content

squawk_ide/
find_references.rs

1use crate::binder::Binder;
2use crate::builtins::{builtins_binder, parse_builtins};
3use crate::db::{File, bind, parse};
4use crate::goto_definition::{FileId, Location};
5use crate::offsets::token_from_offset;
6use crate::resolve;
7use rowan::TextSize;
8use salsa::Database as Db;
9use smallvec::{SmallVec, smallvec};
10use squawk_syntax::{
11    SyntaxNodePtr,
12    ast::{self, AstNode},
13    match_ast,
14};
15
16#[salsa::tracked]
17pub fn find_references(db: &dyn Db, file: File, offset: TextSize) -> Vec<Location> {
18    let parse = parse(db, file);
19    let source_file = parse.tree();
20
21    let current_binder = bind(db, file);
22
23    let builtins_tree = parse_builtins(db).tree();
24    let builtins_binder = builtins_binder(db);
25
26    let Some((target_file, target_defs)) = find_target_defs(
27        &source_file,
28        offset,
29        &current_binder,
30        &builtins_tree,
31        &builtins_binder,
32    ) else {
33        return vec![];
34    };
35
36    let (binder, root) = match target_file {
37        FileId::Current => (&current_binder, source_file.syntax()),
38        FileId::Builtins => (&builtins_binder, builtins_tree.syntax()),
39    };
40
41    let mut refs: Vec<Location> = vec![];
42
43    if target_file == FileId::Builtins {
44        for ptr in &target_defs {
45            refs.push(Location {
46                file: FileId::Builtins,
47                range: ptr.to_node(builtins_tree.syntax()).text_range(),
48            });
49        }
50    }
51
52    for node in source_file.syntax().descendants() {
53        match_ast! {
54            match node {
55                ast::NameRef(name_ref) => {
56                    // Check if the ref matches one of the defs
57                    if let Some(found_defs) = resolve::resolve_name_ref_ptrs(binder, root, &name_ref)
58                        && found_defs.iter().any(|def| target_defs.contains(def))
59                    {
60                        refs.push(Location {
61                            file: FileId::Current,
62                            range: name_ref.syntax().text_range(),
63                        });
64                    }
65                },
66                ast::Name(name) => {
67                    // Find refs also includes the defs so we have to check.
68                    let found = SyntaxNodePtr::new(name.syntax());
69                    if target_defs.contains(&found) {
70                        refs.push(Location {
71                            file: FileId::Current,
72                            range: name.syntax().text_range(),
73                        });
74                    }
75                },
76                _ => (),
77            }
78        }
79    }
80
81    refs.sort_by_key(|loc| (loc.file, loc.range.start()));
82    refs
83}
84
85fn find_target_defs(
86    file: &ast::SourceFile,
87    offset: TextSize,
88    current_binder: &Binder,
89    builtins_tree: &ast::SourceFile,
90    builtins_binder: &Binder,
91) -> Option<(FileId, SmallVec<[SyntaxNodePtr; 1]>)> {
92    let token = token_from_offset(file, offset)?;
93    let parent = token.parent()?;
94
95    if let Some(name) = ast::Name::cast(parent.clone()) {
96        return Some((
97            FileId::Current,
98            smallvec![SyntaxNodePtr::new(name.syntax())],
99        ));
100    }
101
102    if let Some(name_ref) = ast::NameRef::cast(parent.clone()) {
103        for file_id in [FileId::Current, FileId::Builtins] {
104            let (binder, root) = match file_id {
105                FileId::Current => (current_binder, file.syntax()),
106                FileId::Builtins => (builtins_binder, builtins_tree.syntax()),
107            };
108            if let Some(ptrs) = resolve::resolve_name_ref_ptrs(binder, root, &name_ref) {
109                return Some((file_id, ptrs));
110            }
111        }
112    }
113
114    None
115}
116
117#[cfg(test)]
118mod test {
119    use crate::builtins::BUILTINS_SQL;
120    use crate::db::{Database, File};
121    use crate::find_references::find_references;
122    use crate::goto_definition::FileId;
123    use crate::test_utils::fixture;
124    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
125    use insta::assert_snapshot;
126    use rowan::TextRange;
127
128    #[track_caller]
129    fn find_refs(sql: &str) -> String {
130        let (mut offset, sql) = fixture(sql);
131        offset = offset.checked_sub(1.into()).unwrap_or_default();
132        let db = Database::default();
133        let file = File::new(&db, sql.clone().into());
134        assert_eq!(crate::db::parse(&db, file).errors(), vec![]);
135
136        let references = find_references(&db, file, offset);
137
138        let offset_usize: usize = offset.into();
139
140        let mut current_refs = vec![];
141        let mut builtin_refs = vec![];
142        for (i, location) in references.iter().enumerate() {
143            let label_index = i + 1;
144            match location.file {
145                FileId::Current => current_refs.push((label_index, location.range)),
146                FileId::Builtins => builtin_refs.push((label_index, location.range)),
147            }
148        }
149
150        let has_builtins = !builtin_refs.is_empty();
151
152        let mut snippet = Snippet::source(&sql).fold(true);
153        if has_builtins {
154            snippet = snippet.path("current.sql");
155        }
156        snippet = snippet.annotation(
157            AnnotationKind::Context
158                .span(offset_usize..offset_usize + 1)
159                .label("0. query"),
160        );
161        snippet = annotate_refs(snippet, current_refs);
162
163        let mut groups = vec![Level::INFO.primary_title("references").element(snippet)];
164
165        if has_builtins {
166            let builtins_snippet = Snippet::source(BUILTINS_SQL).path("builtin.sql").fold(true);
167            let builtins_snippet = annotate_refs(builtins_snippet, builtin_refs);
168            groups.push(
169                Level::INFO
170                    .primary_title("references")
171                    .element(builtins_snippet),
172            );
173        }
174
175        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
176        renderer
177            .render(&groups)
178            .to_string()
179            .replace("info: references", "")
180    }
181
182    fn annotate_refs<'a>(
183        mut snippet: Snippet<'a, annotate_snippets::Annotation<'a>>,
184        refs: Vec<(usize, TextRange)>,
185    ) -> Snippet<'a, annotate_snippets::Annotation<'a>> {
186        for (label_index, range) in refs {
187            snippet = snippet.annotation(
188                AnnotationKind::Context
189                    .span(range.into())
190                    .label(format!("{}. reference", label_index)),
191            );
192        }
193        snippet
194    }
195
196    #[test]
197    fn simple_table_reference() {
198        assert_snapshot!(find_refs("
199create table t();
200drop table t$0;
201"), @r"
202          ╭▸ 
203        2 │ create table t();
204          │              ─ 1. reference
205        3 │ drop table t;
206          │            ┬
207          │            │
208          │            0. query
209          ╰╴           2. reference
210        ");
211    }
212
213    #[test]
214    fn multiple_references() {
215        assert_snapshot!(find_refs("
216create table users();
217drop table users$0;
218table users;
219"), @r"
220          ╭▸ 
221        2 │ create table users();
222          │              ───── 1. reference
223        3 │ drop table users;
224          │            ┬───┬
225          │            │   │
226          │            │   0. query
227          │            2. reference
228        4 │ table users;
229          ╰╴      ───── 3. reference
230        ");
231    }
232
233    #[test]
234    fn join_using_column() {
235        assert_snapshot!(find_refs("
236create table t(id int);
237create table u(id int);
238select * from t join u using (id$0);
239"), @r"
240          ╭▸ 
241        2 │ create table t(id int);
242          │                ── 1. reference
243        3 │ create table u(id int);
244          │                ── 2. reference
245        4 │ select * from t join u using (id);
246          │                               ┬┬
247          │                               ││
248          │                               │0. query
249          ╰╴                              3. reference
250        ");
251    }
252
253    #[test]
254    fn find_from_definition() {
255        assert_snapshot!(find_refs("
256create table t$0();
257drop table t;
258"), @r"
259          ╭▸ 
260        2 │ create table t();
261          │              ┬
262          │              │
263          │              0. query
264          │              1. reference
265        3 │ drop table t;
266          ╰╴           ─ 2. reference
267        ");
268    }
269
270    #[test]
271    fn with_schema_qualified() {
272        assert_snapshot!(find_refs("
273create table public.users();
274drop table public.users$0;
275table users;
276"), @r"
277          ╭▸ 
278        2 │ create table public.users();
279          │                     ───── 1. reference
280        3 │ drop table public.users;
281          │                   ┬───┬
282          │                   │   │
283          │                   │   0. query
284          │                   2. reference
285        4 │ table users;
286          ╰╴      ───── 3. reference
287        ");
288    }
289
290    #[test]
291    fn temp_table_do_not_shadows_public() {
292        assert_snapshot!(find_refs("
293create table t();
294create temp table t$0();
295drop table t;
296"), @r"
297          ╭▸ 
298        3 │ create temp table t();
299          │                   ┬
300          │                   │
301          │                   0. query
302          ╰╴                  1. reference
303        ");
304    }
305
306    #[test]
307    fn different_schema_no_match() {
308        assert_snapshot!(find_refs("
309create table foo.t();
310create table bar.t$0();
311"), @r"
312          ╭▸ 
313        3 │ create table bar.t();
314          │                  ┬
315          │                  │
316          │                  0. query
317          ╰╴                 1. reference
318        ");
319    }
320
321    #[test]
322    fn with_search_path() {
323        assert_snapshot!(find_refs("
324set search_path to myschema;
325create table myschema.users$0();
326drop table users;
327"), @r"
328          ╭▸ 
329        3 │ create table myschema.users();
330          │                       ┬───┬
331          │                       │   │
332          │                       │   0. query
333          │                       1. reference
334        4 │ drop table users;
335          ╰╴           ───── 2. reference
336        ");
337    }
338
339    #[test]
340    fn temp_table_with_pg_temp_schema() {
341        assert_snapshot!(find_refs("
342create temp table t();
343drop table pg_temp.t$0;
344"), @r"
345          ╭▸ 
346        2 │ create temp table t();
347          │                   ─ 1. reference
348        3 │ drop table pg_temp.t;
349          │                    ┬
350          │                    │
351          │                    0. query
352          ╰╴                   2. reference
353        ");
354    }
355
356    #[test]
357    fn case_insensitive() {
358        assert_snapshot!(find_refs("
359create table Users();
360drop table USERS$0;
361table users;
362"), @r"
363          ╭▸ 
364        2 │ create table Users();
365          │              ───── 1. reference
366        3 │ drop table USERS;
367          │            ┬───┬
368          │            │   │
369          │            │   0. query
370          │            2. reference
371        4 │ table users;
372          ╰╴      ───── 3. reference
373        ");
374    }
375    #[test]
376    fn case_insensitive_part_2() {
377        // we should see refs for `drop table` and `table`
378        assert_snapshot!(find_refs(r#"
379create table actors();
380create table "Actors"();
381drop table ACTORS$0;
382table actors;
383"#), @r#"
384          ╭▸ 
385        2 │ create table actors();
386          │              ────── 1. reference
387        3 │ create table "Actors"();
388        4 │ drop table ACTORS;
389          │            ┬────┬
390          │            │    │
391          │            │    0. query
392          │            2. reference
393        5 │ table actors;
394          ╰╴      ────── 3. reference
395        "#);
396    }
397
398    #[test]
399    fn case_insensitive_with_schema() {
400        assert_snapshot!(find_refs("
401create table Public.Users();
402drop table PUBLIC.USERS$0;
403table public.users;
404"), @r"
405          ╭▸ 
406        2 │ create table Public.Users();
407          │                     ───── 1. reference
408        3 │ drop table PUBLIC.USERS;
409          │                   ┬───┬
410          │                   │   │
411          │                   │   0. query
412          │                   2. reference
413        4 │ table public.users;
414          ╰╴             ───── 3. reference
415        ");
416    }
417
418    #[test]
419    fn no_partial_match() {
420        assert_snapshot!(find_refs("
421create table t$0();
422create table temp_t();
423"), @r"
424          ╭▸ 
425        2 │ create table t();
426          │              ┬
427          │              │
428          │              0. query
429          ╰╴             1. reference
430        ");
431    }
432
433    #[test]
434    fn identifier_boundaries() {
435        assert_snapshot!(find_refs("
436create table foo$0();
437drop table foo;
438drop table foo1;
439drop table barfoo;
440drop table foo_bar;
441"), @r"
442          ╭▸ 
443        2 │ create table foo();
444          │              ┬─┬
445          │              │ │
446          │              │ 0. query
447          │              1. reference
448        3 │ drop table foo;
449          ╰╴           ─── 2. reference
450        ");
451    }
452
453    #[test]
454    fn builtin_function_references() {
455        assert_snapshot!(find_refs("
456select now$0();
457select now();
458"), @"
459              ╭▸ current.sql:2:8
460461            2 │ select now();
462              │        ┬─┬
463              │        │ │
464              │        │ 0. query
465              │        1. reference
466            3 │ select now();
467              │        ─── 2. reference
468              ╰╴
469
470              ╭▸ builtin.sql:11089:28
471472        11089 │ create function pg_catalog.now() returns timestamp with time zone
473              ╰╴                           ─── 3. reference
474        ");
475    }
476}