pub mod ast;
pub mod identifier;
mod parsing;
mod ptr;
pub mod syntax_error;
mod syntax_node;
mod token_text;
mod validation;
#[cfg(test)]
mod test;
use std::{marker::PhantomData, sync::Arc};
pub use squawk_parser::SyntaxKind;
use ast::AstNode;
pub use ptr::{AstPtr, SyntaxNodePtr};
use rowan::GreenNode;
use syntax_error::SyntaxError;
pub use syntax_node::{SyntaxElement, SyntaxNode, SyntaxToken};
pub use token_text::TokenText;
/// The outcome of parsing: a green tree plus the errors the parser reported.
///
/// Parsing always yields a tree; problems are recorded as errors rather than
/// aborting. The type parameter `T` is only a marker for the typed root the
/// tree is expected to cast to (see `Parse::tree`). No `T` is stored.
///
/// Note: the derived `Debug`/`PartialEq`/`Eq` impls carry `T: Debug` /
/// `T: PartialEq` / `T: Eq` bounds even though `T` is phantom.
#[derive(Debug, PartialEq, Eq)]
pub struct Parse<T> {
    /// Root of the syntax tree; `syntax_node()` builds a `SyntaxNode` view over it.
    green: GreenNode,
    /// Errors from the parser only, or `None` when it reported none.
    /// Validation errors are not stored here; `errors()` computes them on demand.
    /// Held in an `Arc` so cloning a `Parse` shares rather than copies them.
    errors: Option<Arc<[SyntaxError]>>,
    /// `fn() -> T` makes `Parse<T>` covariant in `T` and keeps its
    /// `Send`/`Sync`-ness independent of `T`, since no `T` value is owned.
    _ty: PhantomData<fn() -> T>,
}
/// Written by hand because `#[derive(Clone)]` would demand `T: Clone`,
/// although only the green tree and the shared error slice are copied.
impl<T> Clone for Parse<T> {
    fn clone(&self) -> Self {
        let Parse { green, errors, .. } = self;
        Self {
            green: green.clone(),
            errors: errors.clone(),
            _ty: PhantomData,
        }
    }
}
impl<T> Parse<T> {
fn new(green: GreenNode, errors: Vec<SyntaxError>) -> Parse<T> {
Parse {
green,
errors: if errors.is_empty() {
None
} else {
Some(errors.into())
},
_ty: PhantomData,
}
}
pub fn syntax_node(&self) -> SyntaxNode {
SyntaxNode::new_root(self.green.clone())
}
pub fn errors(&self) -> Vec<SyntaxError> {
let mut errors = if let Some(e) = self.errors.as_deref() {
e.to_vec()
} else {
vec![]
};
validation::validate(&self.syntax_node(), &mut errors);
errors
}
}
impl<T: AstNode> Parse<T> {
    /// Forgets the typed root, keeping the same green tree and parser errors.
    pub fn to_syntax(self) -> Parse<SyntaxNode> {
        let Parse { green, errors, .. } = self;
        Parse {
            green,
            errors,
            _ty: PhantomData,
        }
    }

    /// Returns the typed root node.
    ///
    /// # Panics
    ///
    /// Panics if the root `SyntaxNode` does not cast to `T`.
    pub fn tree(&self) -> T {
        let root = self.syntax_node();
        T::cast(root).unwrap()
    }

    /// Returns the typed root when `errors()` (parser + validation) is empty,
    /// otherwise returns all of those errors.
    pub fn ok(self) -> Result<T, Vec<SyntaxError>> {
        let errors = self.errors();
        if errors.is_empty() {
            Ok(self.tree())
        } else {
            Err(errors)
        }
    }
}
impl Parse<SyntaxNode> {
    /// Re-types an untyped parse as `Parse<N>`.
    ///
    /// Returns `None` when the root node does not cast to `N`; otherwise the
    /// green tree and parser errors are moved into the typed result unchanged.
    pub fn cast<N: AstNode>(self) -> Option<Parse<N>> {
        let root_is_n = N::cast(self.syntax_node()).is_some();
        if !root_is_n {
            return None;
        }
        let Parse { green, errors, .. } = self;
        Some(Parse {
            green,
            errors,
            _ty: PhantomData,
        })
    }
}
pub use crate::ast::SourceFile;
impl SourceFile {
    /// Parses `text` into a `Parse<SourceFile>`.
    ///
    /// This never returns an error directly: parser errors are stored on the
    /// result and surfaced through `Parse::errors`.
    ///
    /// # Panics
    ///
    /// Panics if the parser's root node is not of kind `SOURCE_FILE`.
    pub fn parse(text: &str) -> Parse<SourceFile> {
        let (green, parse_errors) = parsing::parse_text(text);
        let root_kind = SyntaxNode::new_root(green.clone()).kind();
        assert_eq!(root_kind, SyntaxKind::SOURCE_FILE);
        Parse::new(green, parse_errors)
    }
}
/// Dispatches on an untyped node by trying a sequence of AST casts.
///
/// Arms are tried top to bottom. For each `Path(pat) => expr,` arm the macro
/// evaluates `Path::cast(node.clone())`. The first arm whose cast returns
/// `Some` binds the typed node to `pat` and evaluates `expr`. The final
/// `_ => expr` arm is mandatory and runs when no cast succeeds. Every
/// non-final arm must end with a comma.
///
/// ```ignore
/// match_ast! {
///     match node {
///         ast::FuncOption(it) => { /* use `it` */ },
///         _ => (),
///     }
/// }
/// ```
#[macro_export]
macro_rules! match_ast {
    // `match node { ... }` with a bare identifier: parenthesise it and
    // forward to the general `match (expr) { ... }` arm below.
    (match $node:ident { $($tt:tt)* }) => { $crate::match_ast!(match ($node) { $($tt)* }) };
    (match ($node:expr) {
        $( $( $path:ident )::+ ($it:pat) => $res:expr, )*
        _ => $catch_all:expr $(,)?
    }) => {{
        // Expands to `if let .. { .. } else if let .. { .. } else { catch_all }`.
        // `$node` is substituted into every arm, so it is re-evaluated and
        // cloned once per attempted cast.
        $( if let Some($it) = $($path::)+cast($node.clone()) { $res } else )*
        { $catch_all }
    }};
}
/// Tour of the typed AST and the untyped syntax-tree APIs on one
/// `create function` statement; doubles as executable documentation.
#[test]
fn api_walkthrough() {
    use ast::SourceFile;
    use rowan::{Direction, NodeOrToken, SyntaxText, TextRange, WalkEvent};
    use std::fmt::Write;

    // The indentation inside this literal is significant: the `TextRange`
    // assertion further down (65..82) counts the 8 leading spaces on each of
    // the first three lines (1 newline + 3 * 8 spaces + 28 + 12 chars = 65).
    let source_code = "
        create function foo(p int8)
        returns int
        as 'select 1 + 1'
        language sql;
    ";
    // Parsing always yields a tree; problems are reported via `errors()`.
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());

    // `tree()` returns the typed root.
    let file: SourceFile = parse.tree();

    // Statements are exposed through the `ast::Stmt` enum.
    let mut func = None;
    for stmt in file.stmts() {
        match stmt {
            ast::Stmt::CreateFunction(f) => func = Some(f),
            _ => unreachable!(),
        }
    }
    let func: ast::CreateFunction = func.unwrap();

    // Child accessors return `Option`, so missing pieces are the caller's to handle.
    let path: Option<ast::Path> = func.path();
    let name: ast::Name = path.unwrap().segment().unwrap().name().unwrap();
    assert_eq!(name.text(), "foo");

    let ret_type: Option<ast::RetType> = func.ret_type();
    let r_ty = &ret_type.unwrap().ty().unwrap();
    let type_: &ast::PathType = match &r_ty {
        ast::Type::PathType(r) => r,
        _ => unreachable!(),
    };
    let type_path: ast::Path = type_.path().unwrap();
    assert_eq!(type_path.syntax().to_string(), "int");

    let param_list: ast::ParamList = func.param_list().unwrap();
    let param: ast::Param = param_list.params().next().unwrap();
    let param_name: ast::Name = param.name().unwrap();
    assert_eq!(param_name.syntax().to_string(), "p");
    let param_ty: ast::Type = param.ty().unwrap();
    assert_eq!(param_ty.syntax().to_string(), "int8");

    let func_option_list: ast::FuncOptionList = func.option_list().unwrap();
    let func_option = func_option_list.options().next().unwrap();
    let option: &ast::AsFuncOption = match &func_option {
        ast::FuncOption::AsFuncOption(o) => o,
        _ => unreachable!(),
    };
    let definition: ast::Literal = option.definition().unwrap();
    assert_eq!(definition.syntax().to_string(), "'select 1 + 1'");

    // A typed enum and its variant wrap the very same untyped node...
    let func_option_syntax = func_option.syntax();
    assert!(func_option_syntax == option.syntax());

    // ...and an untyped node can be cast back into the typed layer.
    let _expr: ast::FuncOption = match ast::FuncOption::cast(func_option_syntax.clone()) {
        Some(e) => e,
        None => unreachable!(),
    };

    // Untyped nodes carry a kind, an absolute text range, and their text.
    assert_eq!(func_option_syntax.kind(), SyntaxKind::AS_FUNC_OPTION);
    assert_eq!(
        func_option_syntax.text_range(),
        TextRange::new(65.into(), 82.into())
    );
    let text: SyntaxText = func_option_syntax.text();
    assert_eq!(text.to_string(), "as 'select 1 + 1'");

    // Navigation: parent, first child, next sibling, ancestors, siblings.
    assert_eq!(
        func_option_syntax.parent().as_ref(),
        Some(func_option_list.syntax())
    );
    assert_eq!(
        param_list
            .syntax()
            .first_child_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::L_PAREN)
    );
    assert_eq!(
        func_option_syntax
            .next_sibling_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::WHITESPACE)
    );
    let f = func_option_syntax
        .ancestors()
        .find_map(ast::CreateFunction::cast);
    assert_eq!(f, Some(func));
    assert!(
        param
            .syntax()
            .siblings_with_tokens(Direction::Next)
            .any(|it| it.kind() == SyntaxKind::R_PAREN)
    );
    assert_eq!(
        func_option_syntax.descendants_with_tokens().count(),
        5, // 3 tokens: `as`, ` `, `'select 1 + 1'`
           // 1 child node: the LITERAL
           // 1 the node itself: AS_FUNC_OPTION
    );

    // Preorder traversal with explicit enter/leave events, rendered as an
    // indented dump (2 spaces per nesting level).
    let mut buf = String::new();
    let mut indent = 0;
    for event in func_option_syntax.preorder_with_tokens() {
        match event {
            WalkEvent::Enter(node) => {
                let text = match &node {
                    NodeOrToken::Node(it) => it.text().to_string(),
                    NodeOrToken::Token(it) => it.text().to_owned(),
                };
                buf.write_fmt(format_args!(
                    "{:indent$}{:?} {:?}\n",
                    " ",
                    text,
                    node.kind(),
                    indent = indent
                ))
                .unwrap();
                indent += 2;
            }
            WalkEvent::Leave(_) => indent -= 2,
        }
    }
    assert_eq!(indent, 0);
    // The leading spaces in the expected dump are part of the assertion:
    // they mirror the `indent` width used above.
    assert_eq!(
        buf.trim(),
        r#"
"as 'select 1 + 1'" AS_FUNC_OPTION
  "as" AS_KW
  " " WHITESPACE
  "'select 1 + 1'" LITERAL
    "'select 1 + 1'" STRING
"#
        .trim()
    );

    // `match_ast!` must visit exactly the nodes that direct casting finds.
    let exprs_cast: Vec<String> = file
        .syntax()
        .descendants()
        .filter_map(ast::FuncOption::cast)
        .map(|expr| expr.syntax().text().to_string())
        .collect();
    let mut exprs_visit = Vec::new();
    for node in file.syntax().descendants() {
        match_ast! {
            match node {
                ast::FuncOption(it) => {
                    let res = it.syntax().text().to_string();
                    exprs_visit.push(res);
                },
                _ => (),
            }
        }
    }
    assert_eq!(exprs_cast, exprs_visit);
}
/// Collects `(table, [(column, type)])` from two `create table` statements
/// and checks the result against an inline snapshot.
#[test]
fn create_table() {
    use insta::assert_debug_snapshot;

    // Whitespace in this SQL is trivia and does not affect the extracted
    // names/types below.
    let source_code = "
        create table users (
            id int8 primary key,
            name varchar(255) not null,
            email text,
            created_at timestamp default now()
        );
        create table posts (
            id serial primary key,
            title varchar(500),
            content text,
            user_id int8 references users(id)
        );
    ";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let mut tables: Vec<(String, Vec<(String, String)>)> = vec![];
    for stmt in file.stmts() {
        if let ast::Stmt::CreateTable(create_table) = stmt {
            let table_name = create_table.path().unwrap().syntax().to_string();
            let mut columns = vec![];
            for arg in create_table.table_arg_list().unwrap().args() {
                match arg {
                    ast::TableArg::Column(column) => {
                        let column_name = column.name().unwrap();
                        let column_type = column.ty().unwrap();
                        columns.push((
                            column_name.syntax().to_string(),
                            column_type.syntax().to_string(),
                        ));
                    }
                    // Only plain column definitions are collected.
                    ast::TableArg::TableConstraint(_) | ast::TableArg::LikeClause(_) => (),
                }
            }
            tables.push((table_name, columns));
        }
    }
    // Inline snapshots are compared after insta strips the common indent, so
    // the relative nesting below must match `{:#?}` output exactly.
    assert_debug_snapshot!(tables, @r#"
    [
        (
            "users",
            [
                (
                    "id",
                    "int8",
                ),
                (
                    "name",
                    "varchar(255)",
                ),
                (
                    "email",
                    "text",
                ),
                (
                    "created_at",
                    "timestamp",
                ),
            ],
        ),
        (
            "posts",
            [
                (
                    "id",
                    "serial",
                ),
                (
                    "title",
                    "varchar(500)",
                ),
                (
                    "content",
                    "text",
                ),
                (
                    "user_id",
                    "int8",
                ),
            ],
        ),
    ]
    "#)
}
/// `1 is not null` parses as a binary expression whose operator is the
/// two-keyword `IS_NOT` node (which keeps its inner whitespace token).
#[test]
fn bin_expr() {
    use insta::assert_debug_snapshot;

    let source_code = "select 1 is not null;";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };
    let target_list = select.select_clause().unwrap().target_list().unwrap();
    let target = target_list.targets().next().unwrap();
    let ast::Expr::BinExpr(bin_expr) = target.expr().unwrap() else {
        unreachable!()
    };

    let lhs = bin_expr.lhs();
    let op = bin_expr.op();
    let rhs = bin_expr.rhs();

    // Ranges are byte offsets into `source_code`. The nested tree lines are
    // indented by the pretty-`Debug` field padding plus 2 per tree level, so
    // the relative indentation below is part of each expectation.
    assert_debug_snapshot!(lhs, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@7..8
                  INT_NUMBER@7..8 "1"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(op, @r#"
    Some(
        IsNot(
            IsNot {
                syntax: IS_NOT@9..15
                  IS_KW@9..11 "is"
                  WHITESPACE@11..12 " "
                  NOT_KW@12..15 "not"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(rhs, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@16..20
                  NULL_KW@16..20 "null"
                ,
            },
        ),
    )
    "#);
}