Skip to main content

squawk_ide/
find_references.rs

1use crate::binder::{self, Binder};
2use crate::builtins::BUILTINS_SQL;
3use crate::goto_definition::{FileId, Location};
4use crate::offsets::token_from_offset;
5use crate::resolve;
6use rowan::TextSize;
7use smallvec::{SmallVec, smallvec};
8use squawk_syntax::{
9    SyntaxNodePtr,
10    ast::{self, AstNode},
11    match_ast,
12};
13
14pub fn find_references(file: &ast::SourceFile, offset: TextSize) -> Vec<Location> {
15    // TODO: we should salsa this
16    let current_binder = binder::bind(file);
17
18    // TODO: we should salsa this
19    let builtins_tree = ast::SourceFile::parse(BUILTINS_SQL).tree();
20    // TODO: we should salsa this
21    let builtins_binder = binder::bind(&builtins_tree);
22
23    let Some((target_file, target_defs)) = find_target_defs(
24        file,
25        offset,
26        &current_binder,
27        &builtins_tree,
28        &builtins_binder,
29    ) else {
30        return vec![];
31    };
32
33    let (binder, root) = match target_file {
34        FileId::Current => (&current_binder, file.syntax()),
35        FileId::Builtins => (&builtins_binder, builtins_tree.syntax()),
36    };
37
38    let mut refs: Vec<Location> = vec![];
39
40    if target_file == FileId::Builtins {
41        for ptr in &target_defs {
42            refs.push(Location {
43                file: FileId::Builtins,
44                range: ptr.to_node(builtins_tree.syntax()).text_range(),
45            });
46        }
47    }
48
49    for node in file.syntax().descendants() {
50        match_ast! {
51            match node {
52                ast::NameRef(name_ref) => {
53                    // Check if the ref matches one of the defs
54                    if let Some(found_defs) = resolve::resolve_name_ref_ptrs(binder, root, &name_ref)
55                        && found_defs.iter().any(|def| target_defs.contains(def))
56                    {
57                        refs.push(Location {
58                            file: FileId::Current,
59                            range: name_ref.syntax().text_range(),
60                        });
61                    }
62                },
63                ast::Name(name) => {
64                    // Find refs also includes the defs so we have to check.
65                    let found = SyntaxNodePtr::new(name.syntax());
66                    if target_defs.contains(&found) {
67                        refs.push(Location {
68                            file: FileId::Current,
69                            range: name.syntax().text_range(),
70                        });
71                    }
72                },
73                _ => (),
74            }
75        }
76    }
77
78    refs.sort_by_key(|loc| (loc.file, loc.range.start()));
79    refs
80}
81
82fn find_target_defs(
83    file: &ast::SourceFile,
84    offset: TextSize,
85    current_binder: &Binder,
86    builtins_tree: &ast::SourceFile,
87    builtins_binder: &Binder,
88) -> Option<(FileId, SmallVec<[SyntaxNodePtr; 1]>)> {
89    let token = token_from_offset(file, offset)?;
90    let parent = token.parent()?;
91
92    if let Some(name) = ast::Name::cast(parent.clone()) {
93        return Some((
94            FileId::Current,
95            smallvec![SyntaxNodePtr::new(name.syntax())],
96        ));
97    }
98
99    if let Some(name_ref) = ast::NameRef::cast(parent.clone()) {
100        for file_id in [FileId::Current, FileId::Builtins] {
101            let (binder, root) = match file_id {
102                FileId::Current => (current_binder, file.syntax()),
103                FileId::Builtins => (builtins_binder, builtins_tree.syntax()),
104            };
105            if let Some(ptrs) = resolve::resolve_name_ref_ptrs(binder, root, &name_ref) {
106                return Some((file_id, ptrs));
107            }
108        }
109    }
110
111    None
112}
113
// Snapshot tests for `find_references`. Each test feeds a SQL fixture with a
// `$0` cursor marker to `find_refs` and compares the rendered annotations
// against an inline `insta` snapshot.
114#[cfg(test)]
115mod test {
116    use crate::builtins::BUILTINS_SQL;
117    use crate::find_references::find_references;
118    use crate::goto_definition::FileId;
119    use crate::test_utils::fixture;
120    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
121    use insta::assert_snapshot;
122    use rowan::TextRange;
123    use squawk_syntax::ast;
124
    /// Runs `find_references` on the fixture `sql` and renders the result as
    /// an annotated snippet string.
    ///
    /// The query position is labelled `0. query`; each returned location is
    /// labelled `N. reference`, numbered from 1 in the order returned.
    /// Locations in the builtins file are rendered in a second snippet with
    /// the path `builtin.sql` (and the fixture snippet is then given the path
    /// `current.sql`).
    ///
    /// Panics (via `assert_eq!`) if the fixture SQL has parse errors.
125    #[track_caller]
126    fn find_refs(sql: &str) -> String {
127        let (mut offset, sql) = fixture(sql);
        // Step back one character (clamped at zero) so the offset lands on the
        // character just before the marker — presumably because `fixture`
        // reports the position where `$0` was; confirm against `test_utils`.
128        offset = offset.checked_sub(1.into()).unwrap_or_default();
129        let parse = ast::SourceFile::parse(&sql);
130        assert_eq!(parse.errors(), vec![]);
131        let file: ast::SourceFile = parse.tree();
132
133        let references = find_references(&file, offset);
134
135        let offset_usize: usize = offset.into();
136
        // Split the locations by file while keeping a single shared numbering.
137        let mut current_refs = vec![];
138        let mut builtin_refs = vec![];
139        for (i, location) in references.iter().enumerate() {
140            let label_index = i + 1;
141            match location.file {
142                FileId::Current => current_refs.push((label_index, location.range)),
143                FileId::Builtins => builtin_refs.push((label_index, location.range)),
144            }
145        }
146
147        let has_builtins = !builtin_refs.is_empty();
148
149        let mut snippet = Snippet::source(&sql).fold(true);
        // Only name the fixture file when a second (builtins) snippet is shown.
150        if has_builtins {
151            snippet = snippet.path("current.sql");
152        }
153        snippet = snippet.annotation(
154            AnnotationKind::Context
155                .span(offset_usize..offset_usize + 1)
156                .label("0. query"),
157        );
158        snippet = annotate_refs(snippet, current_refs);
159
160        let mut groups = vec![Level::INFO.primary_title("references").element(snippet)];
161
162        if has_builtins {
163            let builtins_snippet = Snippet::source(BUILTINS_SQL).path("builtin.sql").fold(true);
164            let builtins_snippet = annotate_refs(builtins_snippet, builtin_refs);
165            groups.push(
166                Level::INFO
167                    .primary_title("references")
168                    .element(builtins_snippet),
169            );
170        }
171
172        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
        // Remove the group title text from the rendered output so snapshots
        // contain only the annotated source.
173        renderer
174            .render(&groups)
175            .to_string()
176            .replace("info: references", "")
177    }
178
    /// Adds one context annotation labelled `"{label_index}. reference"` to
    /// `snippet` for every `(label_index, range)` pair and returns the snippet.
179    fn annotate_refs<'a>(
180        mut snippet: Snippet<'a, annotate_snippets::Annotation<'a>>,
181        refs: Vec<(usize, TextRange)>,
182    ) -> Snippet<'a, annotate_snippets::Annotation<'a>> {
183        for (label_index, range) in refs {
184            snippet = snippet.annotation(
185                AnnotationKind::Context
186                    .span(range.into())
187                    .label(format!("{}. reference", label_index)),
188            );
189        }
190        snippet
191    }
192
    /// Querying a table usage reports both the definition and the usage.
193    #[test]
194    fn simple_table_reference() {
195        assert_snapshot!(find_refs("
196create table t();
197drop table t$0;
198"), @r"
199          ╭▸ 
200        2 │ create table t();
201          │              ─ 1. reference
202        3 │ drop table t;
203          │            ┬
204          │            │
205          │            0. query
206          ╰╴           2. reference
207        ");
208    }
209
    /// Every usage of the table is reported, not just the queried one.
210    #[test]
211    fn multiple_references() {
212        assert_snapshot!(find_refs("
213create table users();
214drop table users$0;
215table users;
216"), @r"
217          ╭▸ 
218        2 │ create table users();
219          │              ───── 1. reference
220        3 │ drop table users;
221          │            ┬───┬
222          │            │   │
223          │            │   0. query
224          │            2. reference
225        4 │ table users;
226          ╰╴      ───── 3. reference
227        ");
228    }
229
    /// A `using (id)` column reports the `id` column of both joined tables.
230    #[test]
231    fn join_using_column() {
232        assert_snapshot!(find_refs("
233create table t(id int);
234create table u(id int);
235select * from t join u using (id$0);
236"), @r"
237          ╭▸ 
238        2 │ create table t(id int);
239          │                ── 1. reference
240        3 │ create table u(id int);
241          │                ── 2. reference
242        4 │ select * from t join u using (id);
243          │                               ┬┬
244          │                               ││
245          │                               │0. query
246          ╰╴                              3. reference
247        ");
248    }
249
    /// Querying from the definition itself also finds its usages.
250    #[test]
251    fn find_from_definition() {
252        assert_snapshot!(find_refs("
253create table t$0();
254drop table t;
255"), @r"
256          ╭▸ 
257        2 │ create table t();
258          │              ┬
259          │              │
260          │              0. query
261          │              1. reference
262        3 │ drop table t;
263          ╰╴           ─ 2. reference
264        ");
265    }
266
    /// A `public.`-qualified table matches both qualified and unqualified usages.
267    #[test]
268    fn with_schema_qualified() {
269        assert_snapshot!(find_refs("
270create table public.users();
271drop table public.users$0;
272table users;
273"), @r"
274          ╭▸ 
275        2 │ create table public.users();
276          │                     ───── 1. reference
277        3 │ drop table public.users;
278          │                   ┬───┬
279          │                   │   │
280          │                   │   0. query
281          │                   2. reference
282        4 │ table users;
283          ╰╴      ───── 3. reference
284        ");
285    }
286
    /// Querying the temp table definition reports only that definition; the
    /// later `drop table t` is not listed as a reference to it.
287    #[test]
288    fn temp_table_do_not_shadows_public() {
289        assert_snapshot!(find_refs("
290create table t();
291create temp table t$0();
292drop table t;
293"), @r"
294          ╭▸ 
295        3 │ create temp table t();
296          │                   ┬
297          │                   │
298          │                   0. query
299          ╰╴                  1. reference
300        ");
301    }
302
    /// Same-named tables in different schemas are not references to each other.
303    #[test]
304    fn different_schema_no_match() {
305        assert_snapshot!(find_refs("
306create table foo.t();
307create table bar.t$0();
308"), @r"
309          ╭▸ 
310        3 │ create table bar.t();
311          │                  ┬
312          │                  │
313          │                  0. query
314          ╰╴                 1. reference
315        ");
316    }
317
    /// With `search_path` set to `myschema`, an unqualified usage is reported
    /// as a reference to `myschema.users`.
318    #[test]
319    fn with_search_path() {
320        assert_snapshot!(find_refs("
321set search_path to myschema;
322create table myschema.users$0();
323drop table users;
324"), @r"
325          ╭▸ 
326        3 │ create table myschema.users();
327          │                       ┬───┬
328          │                       │   │
329          │                       │   0. query
330          │                       1. reference
331        4 │ drop table users;
332          ╰╴           ───── 2. reference
333        ");
334    }
335
    /// A `pg_temp.`-qualified usage resolves to the temp table definition.
336    #[test]
337    fn temp_table_with_pg_temp_schema() {
338        assert_snapshot!(find_refs("
339create temp table t();
340drop table pg_temp.t$0;
341"), @r"
342          ╭▸ 
343        2 │ create temp table t();
344          │                   ─ 1. reference
345        3 │ drop table pg_temp.t;
346          │                    ┬
347          │                    │
348          │                    0. query
349          ╰╴                   2. reference
350        ");
351    }
352
    /// Unquoted identifiers match regardless of letter case.
353    #[test]
354    fn case_insensitive() {
355        assert_snapshot!(find_refs("
356create table Users();
357drop table USERS$0;
358table users;
359"), @r"
360          ╭▸ 
361        2 │ create table Users();
362          │              ───── 1. reference
363        3 │ drop table USERS;
364          │            ┬───┬
365          │            │   │
366          │            │   0. query
367          │            2. reference
368        4 │ table users;
369          ╰╴      ───── 3. reference
370        ");
371    }
    /// A quoted `"Actors"` table is distinct from unquoted `actors`/`ACTORS`
    /// and is not reported.
372    #[test]
373    fn case_insensitive_part_2() {
374        // we should see refs for `drop table` and `table`
375        assert_snapshot!(find_refs(r#"
376create table actors();
377create table "Actors"();
378drop table ACTORS$0;
379table actors;
380"#), @r#"
381          ╭▸ 
382        2 │ create table actors();
383          │              ────── 1. reference
384        3 │ create table "Actors"();
385        4 │ drop table ACTORS;
386          │            ┬────┬
387          │            │    │
388          │            │    0. query
389          │            2. reference
390        5 │ table actors;
391          ╰╴      ────── 3. reference
392        "#);
393    }
394
    /// Case-insensitive matching also applies to the schema qualifier.
395    #[test]
396    fn case_insensitive_with_schema() {
397        assert_snapshot!(find_refs("
398create table Public.Users();
399drop table PUBLIC.USERS$0;
400table public.users;
401"), @r"
402          ╭▸ 
403        2 │ create table Public.Users();
404          │                     ───── 1. reference
405        3 │ drop table PUBLIC.USERS;
406          │                   ┬───┬
407          │                   │   │
408          │                   │   0. query
409          │                   2. reference
410        4 │ table public.users;
411          ╰╴             ───── 3. reference
412        ");
413    }
414
    /// A table whose name merely contains the queried name is not reported.
415    #[test]
416    fn no_partial_match() {
417        assert_snapshot!(find_refs("
418create table t$0();
419create table temp_t();
420"), @r"
421          ╭▸ 
422        2 │ create table t();
423          │              ┬
424          │              │
425          │              0. query
426          ╰╴             1. reference
427        ");
428    }
429
    /// Names that only share a prefix or suffix with `foo` are not reported.
430    #[test]
431    fn identifier_boundaries() {
432        assert_snapshot!(find_refs("
433create table foo$0();
434drop table foo;
435drop table foo1;
436drop table barfoo;
437drop table foo_bar;
438"), @r"
439          ╭▸ 
440        2 │ create table foo();
441          │              ┬─┬
442          │              │ │
443          │              │ 0. query
444          │              1. reference
445        3 │ drop table foo;
446          ╰╴           ─── 2. reference
447        ");
448    }
449
    /// Querying a builtin function reports its usages in the current file and
    /// its definition in the builtins file.
450    #[test]
451    fn builtin_function_references() {
452        assert_snapshot!(find_refs("
453select now$0();
454select now();
455"), @"
456              ╭▸ current.sql:2:8
457458            2 │ select now();
459              │        ┬─┬
460              │        │ │
461              │        │ 0. query
462              │        1. reference
463            3 │ select now();
464              │        ─── 2. reference
465              ╰╴
466
467              ╭▸ builtin.sql:11089:28
468469        11089 │ create function pg_catalog.now() returns timestamp with time zone
470              ╰╴                           ─── 3. reference
471        ");
472    }
473}