Skip to main content

squawk_syntax/
lib.rs

1// via https://github.com/rust-lang/rust-analyzer/blob/d8887c0758bbd2d5f752d5bd405d4491e90e7ed6/crates/syntax/src/lib.rs
2//
3// Permission is hereby granted, free of charge, to any
4// person obtaining a copy of this software and associated
5// documentation files (the "Software"), to deal in the
6// Software without restriction, including without
7// limitation the rights to use, copy, modify, merge,
8// publish, distribute, sublicense, and/or sell copies of
9// the Software, and to permit persons to whom the Software
10// is furnished to do so, subject to the following
11// conditions:
12//
13// The above copyright notice and this permission notice
14// shall be included in all copies or substantial portions
15// of the Software.
16//
17// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF
18// ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED
19// TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
20// PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT
21// SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
22// CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
23// OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR
24// IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
25// DEALINGS IN THE SOFTWARE.
26
27pub mod ast;
28pub mod identifier;
29mod parsing;
30mod ptr;
31pub mod syntax_error;
32mod syntax_node;
33mod token_text;
34mod validation;
35
36#[cfg(test)]
37mod test;
38
39use std::{marker::PhantomData, sync::Arc};
40
41pub use squawk_parser::SyntaxKind;
42
43use ast::AstNode;
44pub use ptr::{AstPtr, SyntaxNodePtr};
45use rowan::GreenNode;
46use syntax_error::SyntaxError;
47pub use syntax_node::{SyntaxElement, SyntaxNode, SyntaxToken};
48pub use token_text::TokenText;
49
50/// `Parse` is the result of the parsing: a syntax tree and a collection of
51/// errors.
52///
53/// Note that we always produce a syntax tree, even for completely invalid
54/// files.
55#[derive(Debug, PartialEq, Eq)]
56pub struct Parse<T> {
57    green: GreenNode,
58    errors: Option<Arc<[SyntaxError]>>,
59    _ty: PhantomData<fn() -> T>,
60}
61
62impl<T> Clone for Parse<T> {
63    fn clone(&self) -> Parse<T> {
64        Parse {
65            green: self.green.clone(),
66            errors: self.errors.clone(),
67            _ty: PhantomData,
68        }
69    }
70}
71
72impl<T> Parse<T> {
73    fn new(green: GreenNode, errors: Vec<SyntaxError>) -> Parse<T> {
74        Parse {
75            green,
76            errors: if errors.is_empty() {
77                None
78            } else {
79                Some(errors.into())
80            },
81            _ty: PhantomData,
82        }
83    }
84
85    pub fn syntax_node(&self) -> SyntaxNode {
86        SyntaxNode::new_root(self.green.clone())
87    }
88
89    pub fn errors(&self) -> Vec<SyntaxError> {
90        let mut errors = if let Some(e) = self.errors.as_deref() {
91            e.to_vec()
92        } else {
93            vec![]
94        };
95        validation::validate(&self.syntax_node(), &mut errors);
96        errors
97    }
98}
99
100impl<T: AstNode> Parse<T> {
101    /// Converts this parse result into a parse result for an untyped syntax tree.
102    pub fn to_syntax(self) -> Parse<SyntaxNode> {
103        Parse {
104            green: self.green,
105            errors: self.errors,
106            _ty: PhantomData,
107        }
108    }
109
110    /// Gets the parsed syntax tree as a typed ast node.
111    ///
112    /// # Panics
113    ///
114    /// Panics if the root node cannot be casted into the typed ast node
115    /// (e.g. if it's an `ERROR` node).
116    pub fn tree(&self) -> T {
117        T::cast(self.syntax_node()).unwrap()
118    }
119
120    /// Converts from `Parse<T>` to [`Result<T, Vec<SyntaxError>>`].
121    pub fn ok(self) -> Result<T, Vec<SyntaxError>> {
122        match self.errors() {
123            errors if !errors.is_empty() => Err(errors),
124            _ => Ok(self.tree()),
125        }
126    }
127}
128
129impl Parse<SyntaxNode> {
130    pub fn cast<N: AstNode>(self) -> Option<Parse<N>> {
131        if N::cast(self.syntax_node()).is_some() {
132            Some(Parse {
133                green: self.green,
134                errors: self.errors,
135                _ty: PhantomData,
136            })
137        } else {
138            None
139        }
140    }
141}
142
143/// `SourceFile` represents a parse tree for a single SQL file.
144pub use crate::ast::SourceFile;
145
146impl SourceFile {
147    pub fn parse(text: &str) -> Parse<SourceFile> {
148        let (green, errors) = parsing::parse_text(text);
149        let root = SyntaxNode::new_root(green.clone());
150
151        assert_eq!(root.kind(), SyntaxKind::SOURCE_FILE);
152        Parse::new(green, errors)
153    }
154}
155
/// Dispatches on a `SyntaxNode` by trying to cast it to a list of `ast` types.
///
/// Arms are tried top to bottom. For each arm the node is cloned and passed to
/// that type's `cast` function; the first arm whose cast returns `Some` binds
/// the result to the arm's pattern and evaluates the arm's expression. The
/// mandatory trailing `_` arm is evaluated when no cast succeeds.
///
/// The node may be a plain identifier or a parenthesised expression. An
/// expression is cloned (and therefore evaluated) once per arm that is tried.
///
/// # Example:
///
/// ```ignore
/// match_ast! {
///     match node {
///         ast::CreateFunction(it) => { ... },
///         ast::FuncOption(it) => { ... },
///         _ => None,
///     }
/// }
/// ```
#[macro_export]
macro_rules! match_ast {
    // Identifier form: wrap the node in parentheses and defer to the expression form.
    (match $scrutinee:ident { $($arms:tt)* }) => {
        $crate::match_ast!(match ($scrutinee) { $($arms)* })
    };

    // Expression form: expands to an `if let … else if let … else` chain.
    (match ($scrutinee:expr) {
        $( $( $ty_path:ident )::+ ($binding:pat) => $body:expr, )*
        _ => $fallback:expr $(,)?
    }) => {{
        $( if let Some($binding) = $($ty_path::)+cast($scrutinee.clone()) { $body } else )*
        { $fallback }
    }};
}
182
/// A guided tour of the crate's API.
///
/// The test parses a small `create function` statement and then exercises the
/// typed AST getters, the untyped CST accessors, the traversal helpers and
/// `match_ast!`. The assertions document what each call returns for this
/// input.
#[test]
fn api_walkthrough() {
    use ast::SourceFile;
    use rowan::{Direction, NodeOrToken, SyntaxText, TextRange, WalkEvent};
    use std::fmt::Write;

    let source_code = "
        create function foo(p int8)
        returns int
        as 'select 1 + 1'
        language sql;
    ";
    // `SourceFile` is the main entry point.
    //
    // The `parse` method returns a `Parse` -- a pair of syntax tree and a list
    // of errors. That is, syntax tree is constructed even in presence of errors.
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());

    // The `tree` method returns an owned syntax node of type `SourceFile`.
    // Owned nodes are cheap: inside, they are `Rc` handles to the underlying data.
    let file: SourceFile = parse.tree();

    // `SourceFile` is the root of the syntax tree. We can iterate the file's
    // statements. Let's fetch the `foo` function; the input contains nothing
    // but that one statement, so any other kind is unreachable.
    let mut func = None;
    for stmt in file.stmts() {
        match stmt {
            ast::Stmt::CreateFunction(f) => func = Some(f),
            _ => unreachable!(),
        }
    }
    let func: ast::CreateFunction = func.unwrap();

    // Each AST node has a bunch of getters for children. All getters return
    // `Option`s though, to account for incomplete code. By convention, all ast
    // types should be used with the `ast::` qualifier.
    let path: Option<ast::Path> = func.path();
    let name: ast::Name = path.unwrap().segment().unwrap().name().unwrap();
    assert_eq!(name.text(), "foo");

    // Return type
    let ret_type: Option<ast::RetType> = func.ret_type();
    let r_ty = &ret_type.unwrap().ty().unwrap();
    let type_: &ast::PathType = match &r_ty {
        ast::Type::PathType(r) => r,
        _ => unreachable!(),
    };
    let type_path: ast::Path = type_.path().unwrap();
    assert_eq!(type_path.syntax().to_string(), "int");

    // Parameters
    let param_list: ast::ParamList = func.param_list().unwrap();
    let param: ast::Param = param_list.params().next().unwrap();

    let param_name: ast::Name = param.name().unwrap();
    assert_eq!(param_name.syntax().to_string(), "p");

    let param_ty: ast::Type = param.ty().unwrap();
    assert_eq!(param_ty.syntax().to_string(), "int8");

    let func_option_list: ast::FuncOptionList = func.option_list().unwrap();

    // Enums are used to group related ast nodes together, and can be used for
    // matching. However, because there are no public fields, it's possible to
    // match only the top level enum: that is the price we pay for increased API
    // flexibility.
    let func_option = func_option_list.options().next().unwrap();
    let option: &ast::AsFuncOption = match &func_option {
        ast::FuncOption::AsFuncOption(o) => o,
        _ => unreachable!(),
    };
    let definition: ast::Literal = option.definition().unwrap();
    assert_eq!(definition.syntax().to_string(), "'select 1 + 1'");

    // Besides the "typed" AST API, there's an untyped CST one as well.
    // To switch from AST to CST, call the `.syntax()` method:
    let func_option_syntax = func_option.syntax();

    // Note how `func_option_syntax` and `option` are in fact the same node underneath:
    assert!(func_option_syntax == option.syntax());

    // To go from CST to AST, the `AstNode::cast` function is used:
    let _expr: ast::FuncOption = match ast::FuncOption::cast(func_option_syntax.clone()) {
        Some(e) => e,
        None => unreachable!(),
    };

    // Each syntax node has two properties. The first is a `SyntaxKind`:
    assert_eq!(func_option_syntax.kind(), SyntaxKind::AS_FUNC_OPTION);

    // The second is a text range (offsets into `source_code`, which is why the
    // indentation of the literal above matters):
    assert_eq!(
        func_option_syntax.text_range(),
        TextRange::new(65.into(), 82.into())
    );

    // You can get node's text as a `SyntaxText` object, which will traverse the
    // tree collecting token's text:
    let text: SyntaxText = func_option_syntax.text();
    assert_eq!(text.to_string(), "as 'select 1 + 1'");

    // There's a bunch of traversal methods on `SyntaxNode`:
    assert_eq!(
        func_option_syntax.parent().as_ref(),
        Some(func_option_list.syntax())
    );
    assert_eq!(
        param_list
            .syntax()
            .first_child_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::L_PAREN)
    );
    assert_eq!(
        func_option_syntax
            .next_sibling_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::WHITESPACE)
    );

    // As well as some iterator helpers:
    let f = func_option_syntax
        .ancestors()
        .find_map(ast::CreateFunction::cast);
    assert_eq!(f, Some(func));
    assert!(
        param
            .syntax()
            .siblings_with_tokens(Direction::Next)
            .any(|it| it.kind() == SyntaxKind::R_PAREN)
    );
    assert_eq!(
        func_option_syntax.descendants_with_tokens().count(),
        5, // 2 nodes: `AS_FUNC_OPTION` (the node itself) and `LITERAL`
           // 3 tokens: `as`, ` `, `'select 1 + 1'`
           // (the same five elements are dumped by the `preorder` walk below)
    );

    // There's also a `preorder` method with a more fine-grained iteration control:
    let mut buf = String::new();
    let mut indent = 0;
    for event in func_option_syntax.preorder_with_tokens() {
        match event {
            WalkEvent::Enter(node) => {
                let text = match &node {
                    NodeOrToken::Node(it) => it.text().to_string(),
                    NodeOrToken::Token(it) => it.text().to_owned(),
                };
                // `" "` padded to `indent` columns is the indentation. At
                // depth 0 it still emits one space, which `buf.trim()` removes
                // below.
                writeln!(
                    buf,
                    "{:indent$}{:?} {:?}",
                    " ",
                    text,
                    node.kind(),
                    indent = indent
                )
                .unwrap();
                indent += 2;
            }
            WalkEvent::Leave(_) => indent -= 2,
        }
    }
    assert_eq!(indent, 0);
    assert_eq!(
        buf.trim(),
        r#"
"as 'select 1 + 1'" AS_FUNC_OPTION
  "as" AS_KW
  " " WHITESPACE
  "'select 1 + 1'" LITERAL
    "'select 1 + 1'" STRING
    "#
        .trim()
    );

    // To recursively process the tree, there are three approaches:
    // 1. explicitly call getter methods on AST nodes.
    // 2. use descendants and `AstNode::cast`.
    // 3. use descendants and `match_ast!`.
    //
    // The first approach is what the getter calls above do. Here's what the
    // second one looks like:
    let exprs_cast: Vec<String> = file
        .syntax()
        .descendants()
        .filter_map(ast::FuncOption::cast)
        .map(|expr| expr.syntax().text().to_string())
        .collect();

    // And the third one, using the macro:
    let mut exprs_visit = Vec::new();
    for node in file.syntax().descendants() {
        match_ast! {
            match node {
                ast::FuncOption(it) => {
                    let res = it.syntax().text().to_string();
                    exprs_visit.push(res);
                },
                _ => (),
            }
        }
    }
    assert_eq!(exprs_cast, exprs_visit);
}
390
/// Parses two `create table` statements and snapshots, for each table, its
/// name and the `(column name, column type)` pairs of its column definitions.
#[test]
fn create_table() {
    use insta::assert_debug_snapshot;

    let source_code = "
        create table users (
            id int8 primary key,
            name varchar(255) not null,
            email text,
            created_at timestamp default now()
        );
        
        create table posts (
            id serial primary key,
            title varchar(500),
            content text,
            user_id int8 references users(id)
        );
    ";

    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    // Statements other than `create table` are skipped, as are table
    // arguments that are not column definitions.
    let tables: Vec<(String, Vec<(String, String)>)> = file
        .stmts()
        .filter_map(|stmt| match stmt {
            ast::Stmt::CreateTable(create_table) => Some(create_table),
            _ => None,
        })
        .map(|create_table| {
            let table_name = create_table.path().unwrap().syntax().to_string();
            let columns: Vec<(String, String)> = create_table
                .table_arg_list()
                .unwrap()
                .args()
                .into_iter()
                .filter_map(|arg| match arg {
                    ast::TableArg::Column(column) => {
                        let column_name = column.name().unwrap();
                        let column_type = column.ty().unwrap();
                        Some((
                            column_name.syntax().to_string(),
                            column_type.syntax().to_string(),
                        ))
                    }
                    ast::TableArg::TableConstraint(_) | ast::TableArg::LikeClause(_) => None,
                })
                .collect();
            (table_name, columns)
        })
        .collect();

    assert_debug_snapshot!(tables, @r#"
    [
        (
            "users",
            [
                (
                    "id",
                    "int8",
                ),
                (
                    "name",
                    "varchar(255)",
                ),
                (
                    "email",
                    "text",
                ),
                (
                    "created_at",
                    "timestamp",
                ),
            ],
        ),
        (
            "posts",
            [
                (
                    "id",
                    "serial",
                ),
                (
                    "title",
                    "varchar(500)",
                ),
                (
                    "content",
                    "text",
                ),
                (
                    "user_id",
                    "int8",
                ),
            ],
        ),
    ]
    "#);
}
485
/// Parses `select 1 is not null;` and snapshots the left operand, operator and
/// right operand of the binary expression in the first select target.
#[test]
fn bin_expr() {
    use insta::assert_debug_snapshot;

    let source_code = "select 1 is not null;";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    // The input is a single `select` statement.
    let first_stmt = file.stmts().next().unwrap();
    let select = match first_stmt {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };

    // Its first target is expected to be a binary expression.
    let first_target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let bin_expr = match first_target.expr().unwrap() {
        ast::Expr::BinExpr(bin_expr) => bin_expr,
        _ => unreachable!(),
    };

    let lhs = bin_expr.lhs();
    let op = bin_expr.op();
    let rhs = bin_expr.rhs();

    assert_debug_snapshot!(lhs, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@7..8
                  INT_NUMBER@7..8 "1"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(op, @r#"
    Some(
        IsNot(
            IsNot {
                syntax: IS_NOT@9..15
                  IS_KW@9..11 "is"
                  WHITESPACE@11..12 " "
                  NOT_KW@12..15 "not"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(rhs, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@16..20
                  NULL_KW@16..20 "null"
                ,
            },
        ),
    )
    "#);
}
544}