1use crate::binder::{self, Binder};
2use crate::builtins::BUILTINS_SQL;
3use crate::goto_definition::{FileId, Location};
4use crate::offsets::token_from_offset;
5use crate::resolve;
6use rowan::TextSize;
7use smallvec::{SmallVec, smallvec};
8use squawk_syntax::{
9 SyntaxNodePtr,
10 ast::{self, AstNode},
11 match_ast,
12};
13
14pub fn find_references(file: &ast::SourceFile, offset: TextSize) -> Vec<Location> {
15 let current_binder = binder::bind(file);
17
18 let builtins_tree = ast::SourceFile::parse(BUILTINS_SQL).tree();
20 let builtins_binder = binder::bind(&builtins_tree);
22
23 let Some((target_file, target_defs)) = find_target_defs(
24 file,
25 offset,
26 ¤t_binder,
27 &builtins_tree,
28 &builtins_binder,
29 ) else {
30 return vec![];
31 };
32
33 let (binder, root) = match target_file {
34 FileId::Current => (¤t_binder, file.syntax()),
35 FileId::Builtins => (&builtins_binder, builtins_tree.syntax()),
36 };
37
38 let mut refs: Vec<Location> = vec![];
39
40 if target_file == FileId::Builtins {
41 for ptr in &target_defs {
42 refs.push(Location {
43 file: FileId::Builtins,
44 range: ptr.to_node(builtins_tree.syntax()).text_range(),
45 });
46 }
47 }
48
49 for node in file.syntax().descendants() {
50 match_ast! {
51 match node {
52 ast::NameRef(name_ref) => {
53 if let Some(found_defs) = resolve::resolve_name_ref_ptrs(binder, root, &name_ref)
55 && found_defs.iter().any(|def| target_defs.contains(def))
56 {
57 refs.push(Location {
58 file: FileId::Current,
59 range: name_ref.syntax().text_range(),
60 });
61 }
62 },
63 ast::Name(name) => {
64 let found = SyntaxNodePtr::new(name.syntax());
66 if target_defs.contains(&found) {
67 refs.push(Location {
68 file: FileId::Current,
69 range: name.syntax().text_range(),
70 });
71 }
72 },
73 _ => (),
74 }
75 }
76 }
77
78 refs.sort_by_key(|loc| (loc.file, loc.range.start()));
79 refs
80}
81
82fn find_target_defs(
83 file: &ast::SourceFile,
84 offset: TextSize,
85 current_binder: &Binder,
86 builtins_tree: &ast::SourceFile,
87 builtins_binder: &Binder,
88) -> Option<(FileId, SmallVec<[SyntaxNodePtr; 1]>)> {
89 let token = token_from_offset(file, offset)?;
90 let parent = token.parent()?;
91
92 if let Some(name) = ast::Name::cast(parent.clone()) {
93 return Some((
94 FileId::Current,
95 smallvec![SyntaxNodePtr::new(name.syntax())],
96 ));
97 }
98
99 if let Some(name_ref) = ast::NameRef::cast(parent.clone()) {
100 for file_id in [FileId::Current, FileId::Builtins] {
101 let (binder, root) = match file_id {
102 FileId::Current => (current_binder, file.syntax()),
103 FileId::Builtins => (builtins_binder, builtins_tree.syntax()),
104 };
105 if let Some(ptrs) = resolve::resolve_name_ref_ptrs(binder, root, &name_ref) {
106 return Some((file_id, ptrs));
107 }
108 }
109 }
110
111 None
112}
113
114#[cfg(test)]
115mod test {
116 use crate::builtins::BUILTINS_SQL;
117 use crate::find_references::find_references;
118 use crate::goto_definition::FileId;
119 use crate::test_utils::fixture;
120 use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
121 use insta::assert_snapshot;
122 use rowan::TextRange;
123 use squawk_syntax::ast;
124
125 #[track_caller]
126 fn find_refs(sql: &str) -> String {
127 let (mut offset, sql) = fixture(sql);
128 offset = offset.checked_sub(1.into()).unwrap_or_default();
129 let parse = ast::SourceFile::parse(&sql);
130 assert_eq!(parse.errors(), vec![]);
131 let file: ast::SourceFile = parse.tree();
132
133 let references = find_references(&file, offset);
134
135 let offset_usize: usize = offset.into();
136
137 let mut current_refs = vec![];
138 let mut builtin_refs = vec![];
139 for (i, location) in references.iter().enumerate() {
140 let label_index = i + 1;
141 match location.file {
142 FileId::Current => current_refs.push((label_index, location.range)),
143 FileId::Builtins => builtin_refs.push((label_index, location.range)),
144 }
145 }
146
147 let has_builtins = !builtin_refs.is_empty();
148
149 let mut snippet = Snippet::source(&sql).fold(true);
150 if has_builtins {
151 snippet = snippet.path("current.sql");
152 }
153 snippet = snippet.annotation(
154 AnnotationKind::Context
155 .span(offset_usize..offset_usize + 1)
156 .label("0. query"),
157 );
158 snippet = annotate_refs(snippet, current_refs);
159
160 let mut groups = vec![Level::INFO.primary_title("references").element(snippet)];
161
162 if has_builtins {
163 let builtins_snippet = Snippet::source(BUILTINS_SQL).path("builtin.sql").fold(true);
164 let builtins_snippet = annotate_refs(builtins_snippet, builtin_refs);
165 groups.push(
166 Level::INFO
167 .primary_title("references")
168 .element(builtins_snippet),
169 );
170 }
171
172 let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
173 renderer
174 .render(&groups)
175 .to_string()
176 .replace("info: references", "")
177 }
178
179 fn annotate_refs<'a>(
180 mut snippet: Snippet<'a, annotate_snippets::Annotation<'a>>,
181 refs: Vec<(usize, TextRange)>,
182 ) -> Snippet<'a, annotate_snippets::Annotation<'a>> {
183 for (label_index, range) in refs {
184 snippet = snippet.annotation(
185 AnnotationKind::Context
186 .span(range.into())
187 .label(format!("{}. reference", label_index)),
188 );
189 }
190 snippet
191 }
192
193 #[test]
194 fn simple_table_reference() {
195 assert_snapshot!(find_refs("
196create table t();
197drop table t$0;
198"), @r"
199 ╭▸
200 2 │ create table t();
201 │ ─ 1. reference
202 3 │ drop table t;
203 │ ┬
204 │ │
205 │ 0. query
206 ╰╴ 2. reference
207 ");
208 }
209
210 #[test]
211 fn multiple_references() {
212 assert_snapshot!(find_refs("
213create table users();
214drop table users$0;
215table users;
216"), @r"
217 ╭▸
218 2 │ create table users();
219 │ ───── 1. reference
220 3 │ drop table users;
221 │ ┬───┬
222 │ │ │
223 │ │ 0. query
224 │ 2. reference
225 4 │ table users;
226 ╰╴ ───── 3. reference
227 ");
228 }
229
230 #[test]
231 fn join_using_column() {
232 assert_snapshot!(find_refs("
233create table t(id int);
234create table u(id int);
235select * from t join u using (id$0);
236"), @r"
237 ╭▸
238 2 │ create table t(id int);
239 │ ── 1. reference
240 3 │ create table u(id int);
241 │ ── 2. reference
242 4 │ select * from t join u using (id);
243 │ ┬┬
244 │ ││
245 │ │0. query
246 ╰╴ 3. reference
247 ");
248 }
249
250 #[test]
251 fn find_from_definition() {
252 assert_snapshot!(find_refs("
253create table t$0();
254drop table t;
255"), @r"
256 ╭▸
257 2 │ create table t();
258 │ ┬
259 │ │
260 │ 0. query
261 │ 1. reference
262 3 │ drop table t;
263 ╰╴ ─ 2. reference
264 ");
265 }
266
267 #[test]
268 fn with_schema_qualified() {
269 assert_snapshot!(find_refs("
270create table public.users();
271drop table public.users$0;
272table users;
273"), @r"
274 ╭▸
275 2 │ create table public.users();
276 │ ───── 1. reference
277 3 │ drop table public.users;
278 │ ┬───┬
279 │ │ │
280 │ │ 0. query
281 │ 2. reference
282 4 │ table users;
283 ╰╴ ───── 3. reference
284 ");
285 }
286
287 #[test]
288 fn temp_table_do_not_shadows_public() {
289 assert_snapshot!(find_refs("
290create table t();
291create temp table t$0();
292drop table t;
293"), @r"
294 ╭▸
295 3 │ create temp table t();
296 │ ┬
297 │ │
298 │ 0. query
299 ╰╴ 1. reference
300 ");
301 }
302
303 #[test]
304 fn different_schema_no_match() {
305 assert_snapshot!(find_refs("
306create table foo.t();
307create table bar.t$0();
308"), @r"
309 ╭▸
310 3 │ create table bar.t();
311 │ ┬
312 │ │
313 │ 0. query
314 ╰╴ 1. reference
315 ");
316 }
317
318 #[test]
319 fn with_search_path() {
320 assert_snapshot!(find_refs("
321set search_path to myschema;
322create table myschema.users$0();
323drop table users;
324"), @r"
325 ╭▸
326 3 │ create table myschema.users();
327 │ ┬───┬
328 │ │ │
329 │ │ 0. query
330 │ 1. reference
331 4 │ drop table users;
332 ╰╴ ───── 2. reference
333 ");
334 }
335
336 #[test]
337 fn temp_table_with_pg_temp_schema() {
338 assert_snapshot!(find_refs("
339create temp table t();
340drop table pg_temp.t$0;
341"), @r"
342 ╭▸
343 2 │ create temp table t();
344 │ ─ 1. reference
345 3 │ drop table pg_temp.t;
346 │ ┬
347 │ │
348 │ 0. query
349 ╰╴ 2. reference
350 ");
351 }
352
353 #[test]
354 fn case_insensitive() {
355 assert_snapshot!(find_refs("
356create table Users();
357drop table USERS$0;
358table users;
359"), @r"
360 ╭▸
361 2 │ create table Users();
362 │ ───── 1. reference
363 3 │ drop table USERS;
364 │ ┬───┬
365 │ │ │
366 │ │ 0. query
367 │ 2. reference
368 4 │ table users;
369 ╰╴ ───── 3. reference
370 ");
371 }
372 #[test]
373 fn case_insensitive_part_2() {
374 assert_snapshot!(find_refs(r#"
376create table actors();
377create table "Actors"();
378drop table ACTORS$0;
379table actors;
380"#), @r#"
381 ╭▸
382 2 │ create table actors();
383 │ ────── 1. reference
384 3 │ create table "Actors"();
385 4 │ drop table ACTORS;
386 │ ┬────┬
387 │ │ │
388 │ │ 0. query
389 │ 2. reference
390 5 │ table actors;
391 ╰╴ ────── 3. reference
392 "#);
393 }
394
395 #[test]
396 fn case_insensitive_with_schema() {
397 assert_snapshot!(find_refs("
398create table Public.Users();
399drop table PUBLIC.USERS$0;
400table public.users;
401"), @r"
402 ╭▸
403 2 │ create table Public.Users();
404 │ ───── 1. reference
405 3 │ drop table PUBLIC.USERS;
406 │ ┬───┬
407 │ │ │
408 │ │ 0. query
409 │ 2. reference
410 4 │ table public.users;
411 ╰╴ ───── 3. reference
412 ");
413 }
414
415 #[test]
416 fn no_partial_match() {
417 assert_snapshot!(find_refs("
418create table t$0();
419create table temp_t();
420"), @r"
421 ╭▸
422 2 │ create table t();
423 │ ┬
424 │ │
425 │ 0. query
426 ╰╴ 1. reference
427 ");
428 }
429
430 #[test]
431 fn identifier_boundaries() {
432 assert_snapshot!(find_refs("
433create table foo$0();
434drop table foo;
435drop table foo1;
436drop table barfoo;
437drop table foo_bar;
438"), @r"
439 ╭▸
440 2 │ create table foo();
441 │ ┬─┬
442 │ │ │
443 │ │ 0. query
444 │ 1. reference
445 3 │ drop table foo;
446 ╰╴ ─── 2. reference
447 ");
448 }
449
450 #[test]
451 fn builtin_function_references() {
452 assert_snapshot!(find_refs("
453select now$0();
454select now();
455"), @"
456 ╭▸ current.sql:2:8
457 │
458 2 │ select now();
459 │ ┬─┬
460 │ │ │
461 │ │ 0. query
462 │ 1. reference
463 3 │ select now();
464 │ ─── 2. reference
465 ╰╴
466
467 ╭▸ builtin.sql:11089:28
468 │
469 11089 │ create function pg_catalog.now() returns timestamp with time zone
470 ╰╴ ─── 3. reference
471 ");
472 }
473}