1pub mod ast;
28mod generated;
29pub mod identifier;
30mod parsing;
31mod ptr;
32pub mod quote;
33pub mod syntax_error;
34mod syntax_node;
35mod token_text;
36mod unescape;
37mod validation;
38
39#[cfg(test)]
40mod test;
41
42use std::{marker::PhantomData, sync::Arc};
43
44pub use squawk_parser::SyntaxKind;
45
46use ast::AstNode;
47pub use ptr::{AstPtr, SyntaxNodePtr};
48use rowan::GreenNode;
49use syntax_error::SyntaxError;
50pub use syntax_node::{SyntaxElement, SyntaxNode, SyntaxToken};
51pub use token_text::TokenText;
52
/// Result of parsing some text: the green syntax tree plus the errors the
/// parser reported while building it.
///
/// `T` records the typed AST node the root is known to cast to (enforced by
/// `SourceFile::parse` and `Parse::cast`); no value of type `T` is stored.
// NOTE: `derive` adds `T: Debug` / `T: PartialEq` / `T: Eq` bounds to these
// impls even though `T` only appears inside `PhantomData`. `Clone` is
// implemented by hand below to avoid the same extra bound.
#[derive(Debug, PartialEq, Eq)]
pub struct Parse<T> {
    /// Root of the syntax tree; `SyntaxNode::new_root` builds a navigable
    /// node over it.
    green: GreenNode,
    /// Errors reported by the parser, or `None` when there were none.
    /// Validation errors are not stored here; `Parse::errors` computes them
    /// on demand.
    errors: Option<Arc<[SyntaxError]>>,
    /// Type-level marker for `T`. `fn() -> T` is used instead of `T` so the
    /// struct does not act as if it owns a `T` (e.g. it stays `Send`/`Sync`
    /// regardless of `T`).
    _ty: PhantomData<fn() -> T>,
}
64
65impl<T> Clone for Parse<T> {
66 fn clone(&self) -> Parse<T> {
67 Parse {
68 green: self.green.clone(),
69 errors: self.errors.clone(),
70 _ty: PhantomData,
71 }
72 }
73}
74
75impl<T> Parse<T> {
76 fn new(green: GreenNode, errors: Vec<SyntaxError>) -> Parse<T> {
77 Parse {
78 green,
79 errors: if errors.is_empty() {
80 None
81 } else {
82 Some(errors.into())
83 },
84 _ty: PhantomData,
85 }
86 }
87
88 pub fn syntax_node(&self) -> SyntaxNode {
89 SyntaxNode::new_root(self.green.clone())
90 }
91
92 pub fn errors(&self) -> Vec<SyntaxError> {
93 let mut errors = if let Some(e) = self.errors.as_deref() {
94 e.to_vec()
95 } else {
96 vec![]
97 };
98 validation::validate(&self.syntax_node(), &mut errors);
99 errors.sort_by_key(|error| error.range().start());
100 errors
101 }
102}
103
104impl<T: AstNode> Parse<T> {
105 pub fn to_syntax(self) -> Parse<SyntaxNode> {
107 Parse {
108 green: self.green,
109 errors: self.errors,
110 _ty: PhantomData,
111 }
112 }
113
114 pub fn tree(&self) -> T {
121 T::cast(self.syntax_node()).unwrap()
122 }
123
124 pub fn ok(self) -> Result<T, Vec<SyntaxError>> {
126 match self.errors() {
127 errors if !errors.is_empty() => Err(errors),
128 _ => Ok(self.tree()),
129 }
130 }
131}
132
133impl Parse<SyntaxNode> {
134 pub fn cast<N: AstNode>(self) -> Option<Parse<N>> {
135 if N::cast(self.syntax_node()).is_some() {
136 Some(Parse {
137 green: self.green,
138 errors: self.errors,
139 _ty: PhantomData,
140 })
141 } else {
142 None
143 }
144 }
145}
146
147pub use crate::ast::SourceFile;
149
150impl SourceFile {
151 pub fn parse(text: &str) -> Parse<SourceFile> {
152 let (green, errors) = parsing::parse_text(text);
153 let root = SyntaxNode::new_root(green.clone());
154
155 assert_eq!(root.kind(), SyntaxKind::SOURCE_FILE);
156 Parse::new(green, errors)
157 }
158}
159
/// Matches a syntax node against a list of AST node types, running the arm
/// for the first type whose `cast` succeeds.
///
/// Arms are tried top to bottom as `Path::cast(node.clone())`. The trailing
/// `_` arm is required and runs when no cast succeeds. The node itself is
/// never consumed, only cloned.
///
/// Two forms are accepted:
/// - `match ident { ... }` — the identifier is forwarded to the second form;
/// - `match (expr) { ... }` — the expression is pasted into every arm, so it
///   is re-evaluated for each arm that is tried.
///
/// ```ignore
/// match_ast! {
///     match node {
///         ast::FuncOption(it) => { /* use `it` */ },
///         _ => (),
///     }
/// }
/// ```
#[macro_export]
macro_rules! match_ast {
    (match $node:ident { $($tt:tt)* }) => { $crate::match_ast!(match ($node) { $($tt)* }) };

    (match ($node:expr) {
        $( $( $path:ident )::+ ($it:pat) => $res:expr, )*
        _ => $catch_all:expr $(,)?
    }) => {{
        // Expands to an `if let ... else if let ... else { catch_all }` chain.
        $( if let Some($it) = $($path::)+cast($node.clone()) { $res } else )*
        { $catch_all }
    }};
}
186
#[test]
fn api_walkthrough() {
    use ast::SourceFile;
    use rowan::{Direction, NodeOrToken, SyntaxText, TextRange, TextSize, WalkEvent};
    use std::fmt::Write;

    let source_code = "
        create function foo(p int8)
        returns int
        as 'select 1 + 1'
        language sql;
    ";
    // Parsing always yields a tree; problems are reported via `errors()`
    // instead of failing the parse.
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());

    // `tree()` returns the typed root node.
    let file: SourceFile = parse.tree();

    // This input holds exactly one statement, a `create function`.
    let mut func = None;
    for stmt in file.stmts() {
        match stmt {
            ast::Stmt::CreateFunction(f) => func = Some(f),
            _ => unreachable!(),
        }
    }
    let func: ast::CreateFunction = func.unwrap();

    // Typed accessors return `Option`: a tree built from invalid input may be
    // missing any child.
    let path: Option<ast::Path> = func.path();
    let name: ast::Name = path.unwrap().segment().unwrap().name().unwrap();
    assert_eq!(name.text(), "foo");

    // Return type: `returns int`.
    let ret_type: Option<ast::RetType> = func.ret_type();
    let r_ty = &ret_type.unwrap().ty().unwrap();
    let type_: &ast::PathType = match &r_ty {
        ast::Type::PathType(r) => r,
        _ => unreachable!(),
    };
    let type_path: ast::Path = type_.path().unwrap();
    assert_eq!(type_path.syntax().to_string(), "int");

    // Parameters: `(p int8)`.
    let param_list: ast::ParamList = func.param_list().unwrap();
    let param: ast::Param = param_list.params().next().unwrap();

    let param_name: ast::Name = param.name().unwrap();
    assert_eq!(param_name.syntax().to_string(), "p");

    let param_ty: ast::Type = param.ty().unwrap();
    assert_eq!(param_ty.syntax().to_string(), "int8");

    // Function options: the first one is the `as '...'` definition.
    let func_option_list: ast::FuncOptionList = func.option_list().unwrap();

    let func_option = func_option_list.options().next().unwrap();
    let option: &ast::AsFuncOption = match &func_option {
        ast::FuncOption::AsFuncOption(o) => o,
        _ => unreachable!(),
    };
    let definition: ast::Literal = option.definition().unwrap();
    assert_eq!(definition.syntax().to_string(), "'select 1 + 1'");

    // The enum node and its variant wrap the very same syntax node...
    let func_option_syntax = func_option.syntax();
    assert_eq!(func_option_syntax, option.syntax());

    // ...and the untyped node can be cast back to the typed enum.
    let Some(_expr) = ast::FuncOption::cast(func_option_syntax.clone()) else {
        unreachable!()
    };

    assert_eq!(func_option_syntax.kind(), SyntaxKind::AS_FUNC_OPTION);

    // Derive the expected range from the source text rather than hard-coding
    // offsets, so re-indenting `source_code` cannot break this assertion.
    let option_text = "as 'select 1 + 1'";
    let (before_option, _) = source_code
        .split_once(option_text)
        .expect("option text occurs in the source");
    assert_eq!(
        func_option_syntax.text_range(),
        TextRange::at(TextSize::of(before_option), TextSize::of(option_text))
    );

    // `text()` returns a `SyntaxText`, not a `String`; convert to compare.
    let text: SyntaxText = func_option_syntax.text();
    assert_eq!(text.to_string(), option_text);

    // Tree navigation: parent, children, siblings, ancestors, descendants.
    assert_eq!(
        func_option_syntax.parent().as_ref(),
        Some(func_option_list.syntax())
    );
    assert_eq!(
        param_list
            .syntax()
            .first_child_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::L_PAREN)
    );
    assert_eq!(
        func_option_syntax
            .next_sibling_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::WHITESPACE)
    );

    let f = func_option_syntax
        .ancestors()
        .find_map(ast::CreateFunction::cast);
    assert_eq!(f, Some(func));
    assert!(
        param
            .syntax()
            .siblings_with_tokens(Direction::Next)
            .any(|it| it.kind() == SyntaxKind::R_PAREN)
    );
    assert_eq!(func_option_syntax.descendants_with_tokens().count(), 5);

    // A preorder walk yields `Enter`/`Leave` events; render them as an
    // indented outline of text and kind.
    let mut buf = String::new();
    let mut indent = 0;
    for event in func_option_syntax.preorder_with_tokens() {
        match event {
            WalkEvent::Enter(node) => {
                let text = match &node {
                    NodeOrToken::Node(it) => it.text().to_string(),
                    NodeOrToken::Token(it) => it.text().to_owned(),
                };
                writeln!(
                    buf,
                    "{:indent$}{:?} {:?}",
                    " ",
                    text,
                    node.kind(),
                    indent = indent
                )
                .unwrap();
                indent += 2;
            }
            WalkEvent::Leave(_) => indent -= 2,
        }
    }
    assert_eq!(indent, 0);
    assert_eq!(
        buf.trim(),
        r#"
"as 'select 1 + 1'" AS_FUNC_OPTION
  "as" AS_KW
  " " WHITESPACE
  "'select 1 + 1'" LITERAL
    "'select 1 + 1'" STRING
    "#
        .trim()
    );

    // Two equivalent ways to find nodes of one type: `filter_map` + `cast`,
    // or the `match_ast!` macro.
    let exprs_cast: Vec<String> = file
        .syntax()
        .descendants()
        .filter_map(ast::FuncOption::cast)
        .map(|expr| expr.syntax().text().to_string())
        .collect();

    let mut exprs_visit = Vec::new();
    for node in file.syntax().descendants() {
        match_ast! {
            match node {
                ast::FuncOption(it) => {
                    let res = it.syntax().text().to_string();
                    exprs_visit.push(res);
                },
                _ => (),
            }
        }
    }
    assert_eq!(exprs_cast, exprs_visit);
}
394
#[test]
fn create_table() {
    use insta::assert_debug_snapshot;

    let source_code = "
        create table users (
            id int8 primary key,
            name varchar(255) not null,
            email text,
            created_at timestamp default now()
        );

        create table posts (
            id serial primary key,
            title varchar(500),
            content text,
            user_id int8 references users(id)
        );
    ";

    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    // For every `create table`, collect its name and the `(name, type)` of
    // each column. Constraints and `like` clauses are skipped.
    let tables: Vec<(String, Vec<(String, String)>)> = file
        .stmts()
        .filter_map(|stmt| match stmt {
            ast::Stmt::CreateTable(create_table) => Some(create_table),
            _ => None,
        })
        .map(|create_table| {
            let table_name = create_table.path().unwrap().syntax().to_string();
            let columns = create_table
                .table_arg_list()
                .unwrap()
                .args()
                .filter_map(|arg| match arg {
                    ast::TableArg::Column(column) => {
                        let column_name = column.name().unwrap();
                        let column_type = column.ty().unwrap();
                        Some((
                            column_name.syntax().to_string(),
                            column_type.syntax().to_string(),
                        ))
                    }
                    ast::TableArg::TableConstraint(_) | ast::TableArg::LikeClause(_) => None,
                })
                .collect();
            (table_name, columns)
        })
        .collect();

    assert_debug_snapshot!(tables, @r#"
    [
        (
            "users",
            [
                (
                    "id",
                    "int8",
                ),
                (
                    "name",
                    "varchar(255)",
                ),
                (
                    "email",
                    "text",
                ),
                (
                    "created_at",
                    "timestamp",
                ),
            ],
        ),
        (
            "posts",
            [
                (
                    "id",
                    "serial",
                ),
                (
                    "title",
                    "varchar(500)",
                ),
                (
                    "content",
                    "text",
                ),
                (
                    "user_id",
                    "int8",
                ),
            ],
        ),
    ]
    "#)
}
489
#[test]
fn bin_expr() {
    use insta::assert_debug_snapshot;

    let parse = SourceFile::parse("select 1 is not null;");
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    // The only statement is a `select` whose single target is a binary
    // expression: `1 is not null`.
    let select = match file.stmts().next().unwrap() {
        ast::Stmt::Select(select) => select,
        _ => unreachable!(),
    };
    let target = select
        .select_clause()
        .unwrap()
        .target_list()
        .unwrap()
        .targets()
        .next()
        .unwrap();
    let bin_expr = match target.expr().unwrap() {
        ast::Expr::BinExpr(bin_expr) => bin_expr,
        _ => unreachable!(),
    };

    let (lhs, op, rhs) = (bin_expr.lhs(), bin_expr.op(), bin_expr.rhs());

    assert_debug_snapshot!(lhs, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@7..8
                  INT_NUMBER@7..8 "1"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(op, @r#"
    Some(
        IsNot(
            IsNot {
                syntax: IS_NOT@9..15
                  IS_KW@9..11 "is"
                  WHITESPACE@11..12 " "
                  NOT_KW@12..15 "not"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(rhs, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@16..20
                  NULL_KW@16..20 "null"
                ,
            },
        ),
    )
    "#);
}