squawk_ide/
document_symbols.rs

1use rowan::TextRange;
2use squawk_syntax::ast::{self, AstNode};
3
4use crate::binder::{self, extract_string_literal};
5use crate::resolve::{resolve_function_info, resolve_table_info, resolve_type_info};
6
/// The category of definition a [`DocumentSymbol`] describes.
///
/// The enum is fieldless, so it derives the cheap value-type traits
/// (`Clone`, `Copy`, `PartialEq`, `Eq`, `Hash`) in addition to `Debug`; this
/// lets callers compare kinds, copy them out of a borrowed symbol, or use them
/// as map/set keys without extra boilerplate.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DocumentSymbolKind {
    /// Symbol produced for a `create table` statement.
    Table,
    /// Symbol produced for a `create function` statement.
    Function,
    /// Symbol produced for a `create type` statement.
    Type,
    /// A column nested under a table or a composite type.
    Column,
    /// A variant nested under a type that declares a variant list.
    Variant,
}
15
16#[derive(Debug)]
17pub struct DocumentSymbol {
18    pub name: String,
19    pub detail: Option<String>,
20    pub kind: DocumentSymbolKind,
21    /// Range used for determining when cursor is inside the symbol for showing
22    /// in the UI
23    pub full_range: TextRange,
24    /// Range selected when symbol is selected
25    pub focus_range: TextRange,
26    pub children: Vec<DocumentSymbol>,
27}
28
29pub fn document_symbols(file: &ast::SourceFile) -> Vec<DocumentSymbol> {
30    let binder = binder::bind(file);
31    let mut symbols = vec![];
32
33    for stmt in file.stmts() {
34        match stmt {
35            ast::Stmt::CreateTable(create_table) => {
36                if let Some(symbol) = create_table_symbol(&binder, create_table) {
37                    symbols.push(symbol);
38                }
39            }
40            ast::Stmt::CreateFunction(create_function) => {
41                if let Some(symbol) = create_function_symbol(&binder, create_function) {
42                    symbols.push(symbol);
43                }
44            }
45            ast::Stmt::CreateType(create_type) => {
46                if let Some(symbol) = create_type_symbol(&binder, create_type) {
47                    symbols.push(symbol);
48                }
49            }
50            _ => {}
51        }
52    }
53
54    symbols
55}
56
57fn create_table_symbol(
58    binder: &binder::Binder,
59    create_table: ast::CreateTable,
60) -> Option<DocumentSymbol> {
61    let path = create_table.path()?;
62    let segment = path.segment()?;
63    let name_node = segment.name()?;
64
65    let (schema, table_name) = resolve_table_info(binder, &path)?;
66    let name = format!("{}.{}", schema.0, table_name);
67
68    let full_range = create_table.syntax().text_range();
69    let focus_range = name_node.syntax().text_range();
70
71    let mut children = vec![];
72    if let Some(table_arg_list) = create_table.table_arg_list() {
73        for arg in table_arg_list.args() {
74            if let ast::TableArg::Column(column) = arg
75                && let Some(column_symbol) = create_column_symbol(column)
76            {
77                children.push(column_symbol);
78            }
79        }
80    }
81
82    Some(DocumentSymbol {
83        name,
84        detail: None,
85        kind: DocumentSymbolKind::Table,
86        full_range,
87        focus_range,
88        children,
89    })
90}
91
92fn create_function_symbol(
93    binder: &binder::Binder,
94    create_function: ast::CreateFunction,
95) -> Option<DocumentSymbol> {
96    let path = create_function.path()?;
97    let segment = path.segment()?;
98    let name_node = segment.name()?;
99
100    let (schema, function_name) = resolve_function_info(binder, &path)?;
101    let name = format!("{}.{}", schema.0, function_name);
102
103    let full_range = create_function.syntax().text_range();
104    let focus_range = name_node.syntax().text_range();
105
106    Some(DocumentSymbol {
107        name,
108        detail: None,
109        kind: DocumentSymbolKind::Function,
110        full_range,
111        focus_range,
112        children: vec![],
113    })
114}
115
116fn create_type_symbol(
117    binder: &binder::Binder,
118    create_type: ast::CreateType,
119) -> Option<DocumentSymbol> {
120    let path = create_type.path()?;
121    let segment = path.segment()?;
122    let name_node = segment.name()?;
123
124    let (schema, type_name) = resolve_type_info(binder, &path)?;
125    let name = format!("{}.{}", schema.0, type_name);
126
127    let full_range = create_type.syntax().text_range();
128    let focus_range = name_node.syntax().text_range();
129
130    let mut children = vec![];
131    if let Some(variant_list) = create_type.variant_list() {
132        for variant in variant_list.variants() {
133            if let Some(variant_symbol) = create_variant_symbol(variant) {
134                children.push(variant_symbol);
135            }
136        }
137    } else if let Some(column_list) = create_type.column_list() {
138        for column in column_list.columns() {
139            if let Some(column_symbol) = create_column_symbol(column) {
140                children.push(column_symbol);
141            }
142        }
143    }
144
145    Some(DocumentSymbol {
146        name,
147        detail: None,
148        kind: DocumentSymbolKind::Type,
149        full_range,
150        focus_range,
151        children,
152    })
153}
154
155fn create_column_symbol(column: ast::Column) -> Option<DocumentSymbol> {
156    let name_node = column.name()?;
157    let name = name_node.syntax().text().to_string();
158
159    let detail = column.ty().map(|t| t.syntax().text().to_string());
160
161    let full_range = column.syntax().text_range();
162    let focus_range = name_node.syntax().text_range();
163
164    Some(DocumentSymbol {
165        name,
166        detail,
167        kind: DocumentSymbolKind::Column,
168        full_range,
169        focus_range,
170        children: vec![],
171    })
172}
173
174fn create_variant_symbol(variant: ast::Variant) -> Option<DocumentSymbol> {
175    let literal = variant.literal()?;
176    let name = extract_string_literal(&literal)?;
177
178    let full_range = variant.syntax().text_range();
179    let focus_range = literal.syntax().text_range();
180
181    Some(DocumentSymbol {
182        name,
183        detail: None,
184        kind: DocumentSymbolKind::Variant,
185        full_range,
186        focus_range,
187        children: vec![],
188    })
189}
190
191#[cfg(test)]
192mod tests {
193    use super::*;
194    use annotate_snippets::{
195        AnnotationKind, Group, Level, Renderer, Snippet, renderer::DecorStyle,
196    };
197    use insta::assert_snapshot;
198
199    fn symbols_not_found(sql: &str) {
200        let parse = ast::SourceFile::parse(sql);
201        let file = parse.tree();
202        let symbols = document_symbols(&file);
203        if !symbols.is_empty() {
204            panic!("Symbols found. If this is expected, use `symbols` instead.")
205        }
206    }
207
208    fn symbols(sql: &str) -> String {
209        let parse = ast::SourceFile::parse(sql);
210        let file = parse.tree();
211        let symbols = document_symbols(&file);
212        if symbols.is_empty() {
213            panic!("No symbols found. If this is expected, use `symbols_not_found` instead.")
214        }
215
216        let mut output = vec![];
217        for symbol in symbols {
218            let group = symbol_to_group(&symbol, sql);
219            output.push(group);
220        }
221        Renderer::plain()
222            .decor_style(DecorStyle::Unicode)
223            .render(&output)
224            .to_string()
225    }
226
227    fn symbol_to_group<'a>(symbol: &DocumentSymbol, sql: &'a str) -> Group<'a> {
228        let kind = match symbol.kind {
229            DocumentSymbolKind::Table => "table",
230            DocumentSymbolKind::Function => "function",
231            DocumentSymbolKind::Type => "type",
232            DocumentSymbolKind::Column => "column",
233            DocumentSymbolKind::Variant => "variant",
234        };
235
236        let title = if let Some(detail) = &symbol.detail {
237            format!("{}: {} {}", kind, symbol.name, detail)
238        } else {
239            format!("{}: {}", kind, symbol.name)
240        };
241
242        let snippet = Snippet::source(sql)
243            .fold(true)
244            .annotation(
245                AnnotationKind::Primary
246                    .span(symbol.focus_range.into())
247                    .label("focus range"),
248            )
249            .annotation(
250                AnnotationKind::Context
251                    .span(symbol.full_range.into())
252                    .label("full range"),
253            );
254
255        let mut group = Level::INFO.primary_title(title.clone()).element(snippet);
256
257        if !symbol.children.is_empty() {
258            let child_labels: Vec<String> = symbol
259                .children
260                .iter()
261                .map(|child| {
262                    let kind = match child.kind {
263                        DocumentSymbolKind::Column => "column",
264                        DocumentSymbolKind::Variant => "variant",
265                        _ => unreachable!("only columns and variants can be children"),
266                    };
267                    if let Some(detail) = &child.detail {
268                        format!("{}: {} {}", kind, child.name, detail)
269                    } else {
270                        format!("{}: {}", kind, child.name)
271                    }
272                })
273                .collect();
274
275            let mut children_snippet = Snippet::source(sql).fold(true);
276
277            for (i, child) in symbol.children.iter().enumerate() {
278                children_snippet = children_snippet
279                    .annotation(
280                        AnnotationKind::Context
281                            .span(child.full_range.into())
282                            .label(format!("full range for `{}`", child_labels[i].clone())),
283                    )
284                    .annotation(
285                        AnnotationKind::Primary
286                            .span(child.focus_range.into())
287                            .label("focus range"),
288                    );
289            }
290
291            group = group.element(children_snippet);
292        }
293
294        group
295    }
296
297    #[test]
    // Snapshot: a multi-line `create table` with two columns renders one table
    // symbol (focus on the table name, full range over the statement) followed
    // by a second snippet marking each column with its type as detail.
    // NOTE(review): the snapshot line starting `314315316` looks like three
    // source lines fused during extraction; the separator between the two
    // rendered snippets may be missing — confirm against the upstream snapshot.
298    fn create_table() {
299        assert_snapshot!(symbols("
300create table users (
301  id int,
302  email citext
303);"), @r"
304        info: table: public.users
305          ╭▸ 
306        2 │   create table users (
307          │   │            ━━━━━ focus range
308          │ ┌─┘
309          │ │
310        3 │ │   id int,
311        4 │ │   email citext
312        5 │ │ );
313          │ └─┘ full range
314315316        3 │     id int,
317          │     ┯━────
318          │     │
319          │     full range for `column: id int`
320          │     focus range
321        4 │     email citext
322          │     ┯━━━━───────
323          │     │
324          │     full range for `column: email citext`
325          ╰╴    focus range
326        ");
327    }
328
329    #[test]
    // Snapshot: an unqualified `create function` renders a single function
    // symbol named `public.hello`, focused on the function name, with no
    // children snippet.
330    fn create_function() {
331        assert_snapshot!(
332            symbols("create function hello() returns void as $$ select 1; $$ language sql;"),
333            @r"
334        info: function: public.hello
335          ╭▸ 
336        1 │ create function hello() returns void as $$ select 1; $$ language sql;
337          │ ┬───────────────┯━━━━───────────────────────────────────────────────
338          │ │               │
339          │ │               focus range
340          ╰╴full range
341        "
342        );
343    }
344
345    #[test]
    // Snapshot: two tables and one function in the same document each render
    // their own group, in statement order; the tables also list their column.
    // NOTE(review): the snapshot lines starting `359360361` and `374375376`
    // look like fused source lines from extraction; the separators between
    // snippets may be missing — confirm against the upstream snapshot.
346    fn multiple_symbols() {
347        assert_snapshot!(symbols("
348create table users (id int);
349create table posts (id int);
350create function get_user(user_id int) returns void as $$ select 1; $$ language sql;
351"), @r"
352        info: table: public.users
353          ╭▸ 
354        2 │ create table users (id int);
355          │ ┬────────────┯━━━━─────────
356          │ │            │
357          │ │            focus range
358          │ full range
359360361        2 │ create table users (id int);
362          │                     ┯━────
363          │                     │
364          │                     full range for `column: id int`
365          │                     focus range
366          ╰╴
367        info: table: public.posts
368          ╭▸ 
369        3 │ create table posts (id int);
370          │ ┬────────────┯━━━━─────────
371          │ │            │
372          │ │            focus range
373          │ full range
374375376        3 │ create table posts (id int);
377          │                     ┯━────
378          │                     │
379          │                     full range for `column: id int`
380          ╰╴                    focus range
381        info: function: public.get_user
382          ╭▸ 
383        4 │ create function get_user(user_id int) returns void as $$ select 1; $$ language sql;
384          │ ┬───────────────┯━━━━━━━──────────────────────────────────────────────────────────
385          │ │               │
386          │ │               focus range
387          ╰╴full range
388        ");
389    }
390
391    #[test]
    // Snapshot: schema-qualified definitions keep their written schema in the
    // symbol name (`public.users`, `my_schema.hello`) and the focus range
    // covers only the final name segment, not the schema prefix.
    // NOTE(review): the snapshot line starting `404405406` looks like fused
    // source lines from extraction; the separator between snippets may be
    // missing — confirm against the upstream snapshot.
392    fn qualified_names() {
393        assert_snapshot!(symbols("
394create table public.users (id int);
395create function my_schema.hello() returns void as $$ select 1; $$ language sql;
396"), @r"
397        info: table: public.users
398          ╭▸ 
399        2 │ create table public.users (id int);
400          │ ┬───────────────────┯━━━━─────────
401          │ │                   │
402          │ │                   focus range
403          │ full range
404405406        2 │ create table public.users (id int);
407          │                            ┯━────
408          │                            │
409          │                            full range for `column: id int`
410          │                            focus range
411          ╰╴
412        info: function: my_schema.hello
413          ╭▸ 
414        3 │ create function my_schema.hello() returns void as $$ select 1; $$ language sql;
415          │ ┬─────────────────────────┯━━━━───────────────────────────────────────────────
416          │ │                         │
417          │ │                         focus range
418          ╰╴full range
419        ");
420    }
421
422    #[test]
    // Snapshot: an enum type renders a type symbol plus one `variant` child per
    // enum label; variant names are shown without their quotes.
    // NOTE(review): the snapshot line starting `434435436` looks like fused
    // source lines from extraction; the separator between snippets may be
    // missing — confirm against the upstream snapshot.
423    fn create_type() {
424        assert_snapshot!(
425            symbols("create type status as enum ('active', 'inactive');"),
426            @r"
427        info: type: public.status
428          ╭▸ 
429        1 │ create type status as enum ('active', 'inactive');
430          │ ┬───────────┯━━━━━───────────────────────────────
431          │ │           │
432          │ │           focus range
433          │ full range
434435436        1 │ create type status as enum ('active', 'inactive');
437          │                             ┯━━━━━━━  ┯━━━━━━━━━
438          │                             │         │
439          │                             │         full range for `variant: inactive`
440          │                             │         focus range
441          │                             full range for `variant: active`
442          ╰╴                            focus range
443        "
444        );
445    }
446
447    #[test]
    // Snapshot: a composite type renders a type symbol plus one `column` child
    // per attribute, each carrying its type text as detail.
    // NOTE(review): the snapshot line starting `459460461` looks like fused
    // source lines from extraction; the separator between snippets may be
    // missing — confirm against the upstream snapshot.
448    fn create_type_composite() {
449        assert_snapshot!(
450            symbols("create type person as (name text, age int);"),
451            @r"
452        info: type: public.person
453          ╭▸ 
454        1 │ create type person as (name text, age int);
455          │ ┬───────────┯━━━━━────────────────────────
456          │ │           │
457          │ │           focus range
458          │ full range
459460461        1 │ create type person as (name text, age int);
462          │                        ┯━━━─────  ┯━━────
463          │                        │          │
464          │                        │          full range for `column: age int`
465          │                        │          focus range
466          │                        full range for `column: name text`
467          ╰╴                       focus range
468        "
469        );
470    }
471
472    #[test]
    // Snapshot: a composite type with three attributes, including a
    // parameterised type (`varchar(10)`), lists every attribute as a `column`
    // child with the full type text as detail.
    // NOTE(review): the snapshot line starting `484485486` looks like fused
    // source lines from extraction; the separator between snippets may be
    // missing — confirm against the upstream snapshot.
473    fn create_type_composite_multiple_columns() {
474        assert_snapshot!(
475            symbols("create type address as (street text, city text, zip varchar(10));"),
476            @r"
477        info: type: public.address
478          ╭▸ 
479        1 │ create type address as (street text, city text, zip varchar(10));
480          │ ┬───────────┯━━━━━━─────────────────────────────────────────────
481          │ │           │
482          │ │           focus range
483          │ full range
484485486        1 │ create type address as (street text, city text, zip varchar(10));
487          │                         ┯━━━━━─────  ┯━━━─────  ┯━━────────────
488          │                         │            │          │
489          │                         │            │          full range for `column: zip varchar(10)`
490          │                         │            │          focus range
491          │                         │            full range for `column: city text`
492          │                         │            focus range
493          │                         full range for `column: street text`
494          ╰╴                        focus range
495        "
496        );
497    }
498
499    #[test]
    // Snapshot: a schema-qualified enum type keeps its written schema in the
    // symbol name (`myschema.status`) and still lists its variants as children.
    // NOTE(review): the snapshot line starting `511512513` looks like fused
    // source lines from extraction; the separator between snippets may be
    // missing — confirm against the upstream snapshot.
500    fn create_type_with_schema() {
501        assert_snapshot!(
502            symbols("create type myschema.status as enum ('active', 'inactive');"),
503            @r"
504        info: type: myschema.status
505          ╭▸ 
506        1 │ create type myschema.status as enum ('active', 'inactive');
507          │ ┬────────────────────┯━━━━━───────────────────────────────
508          │ │                    │
509          │ │                    focus range
510          │ full range
511512513        1 │ create type myschema.status as enum ('active', 'inactive');
514          │                                      ┯━━━━━━━  ┯━━━━━━━━━
515          │                                      │         │
516          │                                      │         full range for `variant: inactive`
517          │                                      │         focus range
518          │                                      full range for `variant: active`
519          ╰╴                                     focus range
520        "
521        );
522    }
523
524    #[test]
    // Snapshot: an enum type with four labels lists all four as `variant`
    // children of the type symbol.
    // NOTE(review): the snapshot line starting `536537538` looks like fused
    // source lines from extraction; the separator between snippets may be
    // missing — confirm against the upstream snapshot.
525    fn create_type_enum_multiple_variants() {
526        assert_snapshot!(
527            symbols("create type priority as enum ('low', 'medium', 'high', 'urgent');"),
528            @r"
529        info: type: public.priority
530          ╭▸ 
531        1 │ create type priority as enum ('low', 'medium', 'high', 'urgent');
532          │ ┬───────────┯━━━━━━━────────────────────────────────────────────
533          │ │           │
534          │ │           focus range
535          │ full range
536537538        1 │ create type priority as enum ('low', 'medium', 'high', 'urgent');
539          │                               ┯━━━━  ┯━━━━━━━  ┯━━━━━  ┯━━━━━━━
540          │                               │      │         │       │
541          │                               │      │         │       full range for `variant: urgent`
542          │                               │      │         │       focus range
543          │                               │      │         full range for `variant: high`
544          │                               │      │         focus range
545          │                               │      full range for `variant: medium`
546          │                               │      focus range
547          │                               full range for `variant: low`
548          ╰╴                              focus range
549        "
550        );
551    }
552
553    #[test]
554    fn empty_file() {
555        symbols_not_found("")
556    }
557
558    #[test]
559    fn non_create_statements() {
560        symbols_not_found("select * from users;")
561    }
562}