Skip to main content

squawk_syntax/
lib.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/syntax/src/lib.rs
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27pub mod ast;
28mod generated;
29pub mod identifier;
30mod parsing;
31mod ptr;
32pub mod quote;
33pub mod syntax_error;
34mod syntax_node;
35mod token_text;
36mod validation;
37
38#[cfg(test)]
39mod test;
40
41use std::{marker::PhantomData, sync::Arc};
42
43pub use squawk_parser::SyntaxKind;
44
45use ast::AstNode;
46pub use ptr::{AstPtr, SyntaxNodePtr};
47use rowan::GreenNode;
48use syntax_error::SyntaxError;
49pub use syntax_node::{SyntaxElement, SyntaxNode, SyntaxToken};
50pub use token_text::TokenText;
51
52/// `Parse` is the result of the parsing: a syntax tree and a collection of
53/// errors.
54///
55/// Note that we always produce a syntax tree, even for completely invalid
56/// files.
57#[derive(Debug, PartialEq, Eq)]
58pub struct Parse<T> {
59    green: GreenNode,
60    errors: Option<Arc<[SyntaxError]>>,
61    _ty: PhantomData<fn() -> T>,
62}
63
64impl<T> Clone for Parse<T> {
65    fn clone(&self) -> Parse<T> {
66        Parse {
67            green: self.green.clone(),
68            errors: self.errors.clone(),
69            _ty: PhantomData,
70        }
71    }
72}
73
74impl<T> Parse<T> {
75    fn new(green: GreenNode, errors: Vec<SyntaxError>) -> Parse<T> {
76        Parse {
77            green,
78            errors: if errors.is_empty() {
79                None
80            } else {
81                Some(errors.into())
82            },
83            _ty: PhantomData,
84        }
85    }
86
87    pub fn syntax_node(&self) -> SyntaxNode {
88        SyntaxNode::new_root(self.green.clone())
89    }
90
91    pub fn errors(&self) -> Vec<SyntaxError> {
92        let mut errors = if let Some(e) = self.errors.as_deref() {
93            e.to_vec()
94        } else {
95            vec![]
96        };
97        validation::validate(&self.syntax_node(), &mut errors);
98        errors
99    }
100}
101
102impl<T: AstNode> Parse<T> {
103    /// Converts this parse result into a parse result for an untyped syntax tree.
104    pub fn to_syntax(self) -> Parse<SyntaxNode> {
105        Parse {
106            green: self.green,
107            errors: self.errors,
108            _ty: PhantomData,
109        }
110    }
111
112    /// Gets the parsed syntax tree as a typed ast node.
113    ///
114    /// # Panics
115    ///
116    /// Panics if the root node cannot be casted into the typed ast node
117    /// (e.g. if it's an `ERROR` node).
118    pub fn tree(&self) -> T {
119        T::cast(self.syntax_node()).unwrap()
120    }
121
122    /// Converts from `Parse<T>` to [`Result<T, Vec<SyntaxError>>`].
123    pub fn ok(self) -> Result<T, Vec<SyntaxError>> {
124        match self.errors() {
125            errors if !errors.is_empty() => Err(errors),
126            _ => Ok(self.tree()),
127        }
128    }
129}
130
131impl Parse<SyntaxNode> {
132    pub fn cast<N: AstNode>(self) -> Option<Parse<N>> {
133        if N::cast(self.syntax_node()).is_some() {
134            Some(Parse {
135                green: self.green,
136                errors: self.errors,
137                _ty: PhantomData,
138            })
139        } else {
140            None
141        }
142    }
143}
144
145/// `SourceFile` represents a parse tree for a single SQL file.
146pub use crate::ast::SourceFile;
147
148impl SourceFile {
149    pub fn parse(text: &str) -> Parse<SourceFile> {
150        let (green, errors) = parsing::parse_text(text);
151        let root = SyntaxNode::new_root(green.clone());
152
153        assert_eq!(root.kind(), SyntaxKind::SOURCE_FILE);
154        Parse::new(green, errors)
155    }
156}
157
/// Matches a `SyntaxNode` against a sequence of `ast` types, running the arm
/// of the first type whose `cast` succeeds.
///
/// Either a bare identifier (`match node { .. }`) or a parenthesized
/// expression (`match (expr) { .. }`) can be matched. Each arm has the form
/// `path::to::Type(pattern) => expr,`. The arms are tried in order as
/// `Type::cast(node.clone())`, and the first `Some` is bound to `pattern`.
/// The trailing `_ => expr` arm is required and runs when no cast succeeds.
///
/// The node expression is cloned, and re-evaluated, once for every arm that
/// is tried, so prefer passing a plain variable.
///
/// # Example:
///
/// ```ignore
/// match_ast! {
///     match node {
///         ast::CreateFunction(it) => { ... },
///         ast::FuncOption(it) => { ... },
///         ast::Literal(it) => { ... },
///         _ => None,
///     }
/// }
/// ```
#[macro_export]
macro_rules! match_ast {
    (match $node:ident { $($tt:tt)* }) => { $crate::match_ast!(match ($node) { $($tt)* }) };

    (match ($node:expr) {
        $( $( $path:ident )::+ ($it:pat) => $res:expr, )*
        _ => $catch_all:expr $(,)?
    }) => {{
        $( if let Some($it) = $($path::)+cast($node.clone()) { $res } else )*
        { $catch_all }
    }};
}
184
/// Walks through the crate's public API on a small `create function`
/// statement. The assertions double as documentation: each one pins down what
/// a particular accessor or traversal returns for this input.
#[test]
fn api_walkthrough() {
    use ast::SourceFile;
    use rowan::{Direction, NodeOrToken, SyntaxText, TextRange, WalkEvent};
    use std::fmt::Write;

    let source_code = "
        create function foo(p int8)
        returns int
        as 'select 1 + 1'
        language sql;
    ";
    // `SourceFile` is the main entry point.
    //
    // The `parse` method returns a `Parse` -- a pair of a syntax tree and a
    // list of errors. That is, a syntax tree is constructed even in the
    // presence of errors.
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());

    // The `tree` method returns an owned syntax node of type `SourceFile`.
    // Owned nodes are cheap: inside, they are reference-counted handles to the
    // underlying data.
    let file: SourceFile = parse.tree();

    // `SourceFile` is the root of the syntax tree. We can iterate over the
    // file's statements. Let's fetch the `foo` function.
    let mut func = None;
    for stmt in file.stmts() {
        match stmt {
            ast::Stmt::CreateFunction(f) => func = Some(f),
            _ => unreachable!(),
        }
    }
    let func: ast::CreateFunction = func.unwrap();

    // Each AST node has a bunch of getters for children. All getters return
    // `Option`s though, to account for incomplete code. By convention, all ast
    // types should be used with the `ast::` qualifier.
    let path: Option<ast::Path> = func.path();
    let name: ast::Name = path.unwrap().segment().unwrap().name().unwrap();
    assert_eq!(name.text(), "foo");

    // The return type.
    let ret_type: Option<ast::RetType> = func.ret_type();
    let r_ty: ast::Type = ret_type.unwrap().ty().unwrap();
    let type_: &ast::PathType = match &r_ty {
        ast::Type::PathType(r) => r,
        _ => unreachable!(),
    };
    let type_path: ast::Path = type_.path().unwrap();
    assert_eq!(type_path.syntax().to_string(), "int");

    // The parameters.
    let param_list: ast::ParamList = func.param_list().unwrap();
    let param: ast::Param = param_list.params().next().unwrap();

    let param_name: ast::Name = param.name().unwrap();
    assert_eq!(param_name.syntax().to_string(), "p");

    let param_ty: ast::Type = param.ty().unwrap();
    assert_eq!(param_ty.syntax().to_string(), "int8");

    let func_option_list: ast::FuncOptionList = func.option_list().unwrap();

    // Enums are used to group related ast nodes together, and can be used for
    // matching. However, because there are no public fields, it's possible to
    // match only the top level enum: that is the price we pay for increased API
    // flexibility.
    let func_option = func_option_list.options().next().unwrap();
    let option: &ast::AsFuncOption = match &func_option {
        ast::FuncOption::AsFuncOption(o) => o,
        _ => unreachable!(),
    };
    let definition: ast::Literal = option.definition().unwrap();
    assert_eq!(definition.syntax().to_string(), "'select 1 + 1'");

    // Besides the "typed" AST API, there's an untyped CST one as well.
    // To switch from AST to CST, call the `.syntax()` method:
    let func_option_syntax = func_option.syntax();

    // Note how `func_option_syntax` and `option` are in fact the same node
    // underneath:
    assert!(func_option_syntax == option.syntax());

    // To go from CST to AST, the `AstNode::cast` function is used:
    let _option: ast::FuncOption = match ast::FuncOption::cast(func_option_syntax.clone()) {
        Some(o) => o,
        None => unreachable!(),
    };

    // Each syntax node has two basic properties. The first is its `SyntaxKind`:
    assert_eq!(func_option_syntax.kind(), SyntaxKind::AS_FUNC_OPTION);

    // The second is its text range:
    assert_eq!(
        func_option_syntax.text_range(),
        TextRange::new(65.into(), 82.into())
    );

    // You can get a node's text as a `SyntaxText` object, which traverses the
    // tree collecting the tokens' text:
    let text: SyntaxText = func_option_syntax.text();
    assert_eq!(text.to_string(), "as 'select 1 + 1'");

    // There's a bunch of traversal methods on `SyntaxNode`:
    assert_eq!(
        func_option_syntax.parent().as_ref(),
        Some(func_option_list.syntax())
    );
    assert_eq!(
        param_list
            .syntax()
            .first_child_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::L_PAREN)
    );
    assert_eq!(
        func_option_syntax
            .next_sibling_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::WHITESPACE)
    );

    // As well as some iterator helpers:
    let f = func_option_syntax
        .ancestors()
        .find_map(ast::CreateFunction::cast);
    assert_eq!(f, Some(func));
    assert!(
        param
            .syntax()
            .siblings_with_tokens(Direction::Next)
            .any(|it| it.kind() == SyntaxKind::R_PAREN)
    );
    // `descendants_with_tokens` includes the node itself, so this counts the
    // `AS_FUNC_OPTION` node, the `as` keyword, the whitespace, the `LITERAL`
    // node and its `STRING` token (the same five entries as the preorder
    // dump below).
    assert_eq!(func_option_syntax.descendants_with_tokens().count(), 5);

    // There's also a `preorder` method with a more fine-grained iteration control:
    let mut buf = String::new();
    let mut indent = 0;
    for event in func_option_syntax.preorder_with_tokens() {
        match event {
            WalkEvent::Enter(node) => {
                let text = match &node {
                    NodeOrToken::Node(it) => it.text().to_string(),
                    NodeOrToken::Token(it) => it.text().to_owned(),
                };
                writeln!(
                    buf,
                    "{:indent$}{:?} {:?}",
                    " ",
                    text,
                    node.kind(),
                    indent = indent
                )
                .unwrap();
                indent += 2;
            }
            WalkEvent::Leave(_) => indent -= 2,
        }
    }
    assert_eq!(indent, 0);
    assert_eq!(
        buf.trim(),
        r#"
"as 'select 1 + 1'" AS_FUNC_OPTION
  "as" AS_KW
  " " WHITESPACE
  "'select 1 + 1'" LITERAL
    "'select 1 + 1'" STRING
    "#
        .trim()
    );

    // To recursively process the tree, there are three approaches:
    // 1. explicitly call getter methods on AST nodes.
    // 2. use descendants and `AstNode::cast`.
    // 3. use descendants and `match_ast!`.
    //
    // Everything above used the first approach. Here's the second one:
    let options_cast: Vec<String> = file
        .syntax()
        .descendants()
        .filter_map(ast::FuncOption::cast)
        .map(|option| option.syntax().text().to_string())
        .collect();
    // Guard against the comparison below passing vacuously: the `as` option
    // inspected above must be among the results.
    assert!(options_cast.iter().any(|text| text == "as 'select 1 + 1'"));

    // And the third one, using the macro:
    let mut options_visit = Vec::new();
    for node in file.syntax().descendants() {
        match_ast! {
            match node {
                ast::FuncOption(it) => {
                    let res = it.syntax().text().to_string();
                    options_visit.push(res);
                },
                _ => (),
            }
        }
    }
    assert_eq!(options_cast, options_visit);
}
392
#[test]
fn create_table() {
    use insta::assert_debug_snapshot;

    let source_code = "
        create table users (
            id int8 primary key,
            name varchar(255) not null,
            email text,
            created_at timestamp default now()
        );
        
        create table posts (
            id serial primary key,
            title varchar(500),
            content text,
            user_id int8 references users(id)
        );
    ";

    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    // Collect `(table name, [(column name, column type)])` for every
    // `create table` statement; table constraints and `like` clauses are
    // skipped, as are statements of any other kind.
    let tables: Vec<(String, Vec<(String, String)>)> = file
        .stmts()
        .filter_map(|stmt| match stmt {
            ast::Stmt::CreateTable(table) => Some(table),
            _ => None,
        })
        .map(|table| {
            let name = table.path().unwrap().syntax().to_string();
            let cols: Vec<(String, String)> = table
                .table_arg_list()
                .unwrap()
                .args()
                .filter_map(|arg| match arg {
                    ast::TableArg::Column(col) => {
                        let col_name = col.name().unwrap();
                        let col_type = col.ty().unwrap();
                        Some((
                            col_name.syntax().to_string(),
                            col_type.syntax().to_string(),
                        ))
                    }
                    ast::TableArg::TableConstraint(_) | ast::TableArg::LikeClause(_) => None,
                })
                .collect();
            (name, cols)
        })
        .collect();

    assert_debug_snapshot!(tables, @r#"
    [
        (
            "users",
            [
                (
                    "id",
                    "int8",
                ),
                (
                    "name",
                    "varchar(255)",
                ),
                (
                    "email",
                    "text",
                ),
                (
                    "created_at",
                    "timestamp",
                ),
            ],
        ),
        (
            "posts",
            [
                (
                    "id",
                    "serial",
                ),
                (
                    "title",
                    "varchar(500)",
                ),
                (
                    "content",
                    "text",
                ),
                (
                    "user_id",
                    "int8",
                ),
            ],
        ),
    ]
    "#)
}
487
#[test]
fn bin_expr() {
    use insta::assert_debug_snapshot;

    let parse = SourceFile::parse("select 1 is not null;");
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    // The only statement is a `select` whose single target is `1 is not null`.
    let first_stmt = file.stmts().next().unwrap();
    let select_stmt = match first_stmt {
        ast::Stmt::Select(select_stmt) => select_stmt,
        _ => unreachable!(),
    };
    let first_target = select_stmt
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let binary = match first_target.expr().unwrap() {
        ast::Expr::BinExpr(binary) => binary,
        _ => unreachable!(),
    };

    let (lhs, op, rhs) = (binary.lhs(), binary.op(), binary.rhs());

    assert_debug_snapshot!(lhs, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@7..8
                  INT_NUMBER@7..8 "1"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(op, @r#"
    Some(
        IsNot(
            IsNot {
                syntax: IS_NOT@9..15
                  IS_KW@9..11 "is"
                  WHITESPACE@11..12 " "
                  NOT_KW@12..15 "not"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(rhs, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@16..20
                  NULL_KW@16..20 "null"
                ,
            },
        ),
    )
    "#);
}