squawk_ide/
goto_definition.rs

1use crate::binder;
2use crate::offsets::token_from_offset;
3use crate::resolve;
4use rowan::{TextRange, TextSize};
5use squawk_syntax::{
6    SyntaxKind,
7    ast::{self, AstNode},
8};
9
10pub fn goto_definition(file: ast::SourceFile, offset: TextSize) -> Option<TextRange> {
11    let token = token_from_offset(&file, offset)?;
12    let parent = token.parent()?;
13
14    // goto def on case exprs
15    if (token.kind() == SyntaxKind::WHEN_KW && parent.kind() == SyntaxKind::WHEN_CLAUSE)
16        || (token.kind() == SyntaxKind::ELSE_KW && parent.kind() == SyntaxKind::ELSE_CLAUSE)
17        || (token.kind() == SyntaxKind::END_KW && parent.kind() == SyntaxKind::CASE_EXPR)
18    {
19        for parent in token.parent_ancestors() {
20            if let Some(case_expr) = ast::CaseExpr::cast(parent)
21                && let Some(case_token) = case_expr.case_token()
22            {
23                return Some(case_token.text_range());
24            }
25        }
26    }
27
28    // goto def on COMMIT -> BEGIN/START TRANSACTION
29    if ast::Commit::can_cast(parent.kind()) {
30        if let Some(begin_range) = find_preceding_begin(&file, token.text_range().start()) {
31            return Some(begin_range);
32        }
33    }
34
35    // goto def on ROLLBACK -> BEGIN/START TRANSACTION
36    if ast::Rollback::can_cast(parent.kind()) {
37        if let Some(begin_range) = find_preceding_begin(&file, token.text_range().start()) {
38            return Some(begin_range);
39        }
40    }
41
42    // goto def on BEGIN/START TRANSACTION -> COMMIT or ROLLBACK
43    if ast::Begin::can_cast(parent.kind()) {
44        if let Some(end_range) = find_following_commit_or_rollback(&file, token.text_range().end())
45        {
46            return Some(end_range);
47        }
48    }
49
50    if let Some(name_ref) = ast::NameRef::cast(parent.clone()) {
51        let binder_output = binder::bind(&file);
52        if let Some(ptr) = resolve::resolve_name_ref(&binder_output, &name_ref) {
53            let node = ptr.to_node(file.syntax());
54            return Some(node.text_range());
55        }
56    }
57
58    return None;
59}
60
61fn find_preceding_begin(file: &ast::SourceFile, before: TextSize) -> Option<TextRange> {
62    let mut last_begin: Option<TextRange> = None;
63    for stmt in file.stmts() {
64        if let ast::Stmt::Begin(begin) = stmt {
65            let range = begin.syntax().text_range();
66            if range.end() <= before {
67                last_begin = Some(range);
68            }
69        }
70    }
71    last_begin
72}
73
74fn find_following_commit_or_rollback(file: &ast::SourceFile, after: TextSize) -> Option<TextRange> {
75    for stmt in file.stmts() {
76        let range = match &stmt {
77            ast::Stmt::Commit(commit) => commit.syntax().text_range(),
78            ast::Stmt::Rollback(rollback) => rollback.syntax().text_range(),
79            _ => continue,
80        };
81        if range.start() >= after {
82            return Some(range);
83        }
84    }
85    None
86}
87
#[cfg(test)]
mod test {
    use crate::goto_definition::goto_definition;
    use crate::test_utils::fixture;
    use annotate_snippets::{AnnotationKind, Level, Renderer, Snippet, renderer::DecorStyle};
    use insta::assert_snapshot;
    use squawk_syntax::ast;

    /// Runs go-to-definition at the `$0` marker and renders the result.
    /// Panics, at the caller's location, if no definition is found.
    #[track_caller]
    fn goto(sql: &str) -> String {
        goto_(sql).expect("should always find a definition")
    }

    /// Runs go-to-definition at the `$0` marker in `sql`.
    ///
    /// On success, returns a rendered snippet that annotates the cursor
    /// position (`1. source`) and the definition range (`2. destination`).
    /// Returns `None` when `goto_definition` finds nothing. Panics if `sql`
    /// does not parse without errors.
    #[track_caller]
    fn goto_(sql: &str) -> Option<String> {
        let (offset, sql) = fixture(sql);
        // For go to def we want the previous character since we usually put the
        // marker after the item we're trying to go to def on.
        let offset = offset.checked_sub(1.into()).unwrap_or_default();
        let parse = ast::SourceFile::parse(&sql);
        assert_eq!(parse.errors(), vec![]);
        let file: ast::SourceFile = parse.tree();
        let result = goto_definition(file, offset)?;

        let offset: usize = offset.into();
        let group = Level::INFO.primary_title("definition").element(
            Snippet::source(&sql)
                .fold(true)
                .annotation(
                    AnnotationKind::Context
                        .span(result.into())
                        .label("2. destination"),
                )
                .annotation(
                    AnnotationKind::Context
                        .span(offset..offset + 1)
                        .label("1. source"),
                ),
        );
        let renderer = Renderer::plain().decor_style(DecorStyle::Unicode);
        Some(
            renderer
                .render(&[group])
                .to_string()
                // hacky cleanup to make the text shorter
                .replace("info: definition", ""),
        )
    }

    /// Asserts that go-to-definition at the `$0` marker finds nothing.
    /// `#[track_caller]` makes a failure point at the calling test, not at this helper.
    #[track_caller]
    fn goto_not_found(sql: &str) {
        assert!(goto_(sql).is_none(), "Should not find a definition");
    }

    #[test]
    fn goto_case_when() {
        assert_snapshot!(goto("
select case when$0 x > 1 then 1 else 2 end;
"), @r"
          ╭▸ 
        2 │ select case when x > 1 then 1 else 2 end;
          │        ┬───    ─ 1. source
          │        │
          ╰╴       2. destination
        ");
    }

    #[test]
    fn goto_case_else() {
        assert_snapshot!(goto("
select case when x > 1 then 1 else$0 2 end;
"), @r"
          ╭▸ 
        2 │ select case when x > 1 then 1 else 2 end;
          ╰╴       ──── 2. destination       ─ 1. source
        ");
    }

    #[test]
    fn goto_case_end() {
        assert_snapshot!(goto("
select case when x > 1 then 1 else 2 end$0;
"), @r"
          ╭▸ 
        2 │ select case when x > 1 then 1 else 2 end;
          ╰╴       ──── 2. destination             ─ 1. source
        ");
    }

    #[test]
    fn goto_case_end_trailing_semi() {
        assert_snapshot!(goto("
select case when x > 1 then 1 else 2 end;$0
"), @r"
          ╭▸ 
        2 │ select case when x > 1 then 1 else 2 end;
          ╰╴       ──── 2. destination              ─ 1. source
        ");
    }

    #[test]
    fn goto_case_then_not_found() {
        goto_not_found(
            "
select case when x > 1 then$0 1 else 2 end;
",
        )
    }

    #[test]
    fn rollback_to_begin() {
        assert_snapshot!(goto(
            "
begin;
select 1;
rollback$0;
",
        ), @r"
          ╭▸ 
        2 │ begin;
          │ ───── 2. destination
        3 │ select 1;
        4 │ rollback;
          ╰╴       ─ 1. source
        ");
    }

    #[test]
    fn goto_drop_table() {
        assert_snapshot!(goto("
create table t();
drop table t$0;
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ─ 2. destination
        3 │ drop table t;
          ╰╴           ─ 1. source
        ");
    }

    #[test]
    fn goto_drop_table_with_schema() {
        assert_snapshot!(goto("
create table public.t();
drop table t$0;
"), @r"
          ╭▸ 
        2 │ create table public.t();
          │                     ─ 2. destination
        3 │ drop table t;
          ╰╴           ─ 1. source
        ");

        assert_snapshot!(goto("
create table foo.t();
drop table foo.t$0;
"), @r"
          ╭▸ 
        2 │ create table foo.t();
          │                  ─ 2. destination
        3 │ drop table foo.t;
          ╰╴               ─ 1. source
        ");

        goto_not_found(
            "
-- defaults to public schema
create table t();
drop table foo.t$0;
",
        );
    }

    #[test]
    fn goto_drop_temp_table() {
        assert_snapshot!(goto("
create temp table t();
drop table t$0;
"), @r"
          ╭▸ 
        2 │ create temp table t();
          │                   ─ 2. destination
        3 │ drop table t;
          ╰╴           ─ 1. source
        ");
    }

    #[test]
    fn goto_drop_temporary_table() {
        assert_snapshot!(goto("
create temporary table t();
drop table t$0;
"), @r"
          ╭▸ 
        2 │ create temporary table t();
          │                        ─ 2. destination
        3 │ drop table t;
          ╰╴           ─ 1. source
        ");
    }

    #[test]
    fn goto_drop_temp_table_with_pg_temp_schema() {
        assert_snapshot!(goto("
create temp table t();
drop table pg_temp.t$0;
"), @r"
          ╭▸ 
        2 │ create temp table t();
          │                   ─ 2. destination
        3 │ drop table pg_temp.t;
          ╰╴                   ─ 1. source
        ");
    }

    #[test]
    fn goto_drop_temp_table_shadows_public() {
        // temp tables shadow public tables when no schema is specified
        assert_snapshot!(goto("
create table t();
create temp table t();
drop table t$0;
"), @r"
          ╭▸ 
        3 │ create temp table t();
          │                   ─ 2. destination
        4 │ drop table t;
          ╰╴           ─ 1. source
        ");
    }

    #[test]
    fn goto_drop_public_table_when_temp_exists() {
        // can still access public table explicitly
        assert_snapshot!(goto("
create table t();
create temp table t();
drop table public.t$0;
"), @r"
          ╭▸ 
        2 │ create table t();
          │              ─ 2. destination
        3 │ create temp table t();
        4 │ drop table public.t;
          ╰╴                  ─ 1. source
        ");
    }

    #[test]
    fn goto_drop_table_defined_after() {
        assert_snapshot!(goto("
drop table t$0;
create table t();
"), @r"
          ╭▸ 
        2 │ drop table t;
          │            ─ 1. source
        3 │ create table t();
          ╰╴             ─ 2. destination
        ");
    }

    #[test]
    fn begin_to_rollback() {
        assert_snapshot!(goto(
            "
begin$0;
select 1;
rollback;
commit;
",
        ), @r"
          ╭▸ 
        2 │ begin;
          │     ─ 1. source
        3 │ select 1;
        4 │ rollback;
          ╰╴──────── 2. destination
        ");
    }

    #[test]
    fn commit_to_begin() {
        assert_snapshot!(goto(
            "
begin;
select 1;
commit$0;
",
        ), @r"
          ╭▸ 
        2 │ begin;
          │ ───── 2. destination
        3 │ select 1;
        4 │ commit;
          ╰╴     ─ 1. source
        ");
    }

    #[test]
    fn begin_to_commit() {
        assert_snapshot!(goto(
            "
begin$0;
select 1;
commit;
",
        ), @r"
          ╭▸ 
        2 │ begin;
          │     ─ 1. source
        3 │ select 1;
        4 │ commit;
          ╰╴────── 2. destination
        ");
    }

    #[test]
    fn commit_to_start_transaction() {
        assert_snapshot!(goto(
            "
start transaction;
select 1;
commit$0;
",
        ), @r"
          ╭▸ 
        2 │ start transaction;
          │ ───────────────── 2. destination
        3 │ select 1;
        4 │ commit;
          ╰╴     ─ 1. source
        ");
    }

    #[test]
    fn start_transaction_to_commit() {
        assert_snapshot!(goto(
            "
start$0 transaction;
select 1;
commit;
",
        ), @r"
          ╭▸ 
        2 │ start transaction;
          │     ─ 1. source
        3 │ select 1;
        4 │ commit;
          ╰╴────── 2. destination
        ");
    }
}