1pub mod ast;
28mod generated;
29pub mod identifier;
30mod parsing;
31mod ptr;
32pub mod quote;
33pub mod syntax_error;
34mod syntax_node;
35mod token_text;
36mod validation;
37
38#[cfg(test)]
39mod test;
40
41use std::{marker::PhantomData, sync::Arc};
42
43pub use squawk_parser::SyntaxKind;
44
45use ast::AstNode;
46pub use ptr::{AstPtr, SyntaxNodePtr};
47use rowan::GreenNode;
48use syntax_error::SyntaxError;
49pub use syntax_node::{SyntaxElement, SyntaxNode, SyntaxToken};
50pub use token_text::TokenText;
51
/// The result of parsing: the green tree built by the parser plus the errors it reported.
///
/// `T` is the typed AST node the root is expected to be (for example [`SourceFile`]);
/// it is only a type-level marker and no `T` is stored.
///
/// Note that `#[derive]` makes the `Debug`/`PartialEq`/`Eq` impls require the same traits
/// on `T`; `Clone` is implemented manually, without a `T: Clone` bound.
#[derive(Debug, PartialEq, Eq)]
pub struct Parse<T> {
    /// The immutable green tree; views over it are created with `SyntaxNode::new_root`.
    green: GreenNode,
    /// Errors reported by the parser; `None` when there were none (see `Parse::new`).
    errors: Option<Arc<[SyntaxError]>>,
    /// Type marker. Using `fn() -> T` keeps `Parse<T>` covariant in `T` without making
    /// its auto traits (`Send`/`Sync`) depend on `T`.
    _ty: PhantomData<fn() -> T>,
}
63
64impl<T> Clone for Parse<T> {
65 fn clone(&self) -> Parse<T> {
66 Parse {
67 green: self.green.clone(),
68 errors: self.errors.clone(),
69 _ty: PhantomData,
70 }
71 }
72}
73
74impl<T> Parse<T> {
75 fn new(green: GreenNode, errors: Vec<SyntaxError>) -> Parse<T> {
76 Parse {
77 green,
78 errors: if errors.is_empty() {
79 None
80 } else {
81 Some(errors.into())
82 },
83 _ty: PhantomData,
84 }
85 }
86
87 pub fn syntax_node(&self) -> SyntaxNode {
88 SyntaxNode::new_root(self.green.clone())
89 }
90
91 pub fn errors(&self) -> Vec<SyntaxError> {
92 let mut errors = if let Some(e) = self.errors.as_deref() {
93 e.to_vec()
94 } else {
95 vec![]
96 };
97 validation::validate(&self.syntax_node(), &mut errors);
98 errors
99 }
100}
101
102impl<T: AstNode> Parse<T> {
103 pub fn to_syntax(self) -> Parse<SyntaxNode> {
105 Parse {
106 green: self.green,
107 errors: self.errors,
108 _ty: PhantomData,
109 }
110 }
111
112 pub fn tree(&self) -> T {
119 T::cast(self.syntax_node()).unwrap()
120 }
121
122 pub fn ok(self) -> Result<T, Vec<SyntaxError>> {
124 match self.errors() {
125 errors if !errors.is_empty() => Err(errors),
126 _ => Ok(self.tree()),
127 }
128 }
129}
130
131impl Parse<SyntaxNode> {
132 pub fn cast<N: AstNode>(self) -> Option<Parse<N>> {
133 if N::cast(self.syntax_node()).is_some() {
134 Some(Parse {
135 green: self.green,
136 errors: self.errors,
137 _ty: PhantomData,
138 })
139 } else {
140 None
141 }
142 }
143}
144
145pub use crate::ast::SourceFile;
147
148impl SourceFile {
149 pub fn parse(text: &str) -> Parse<SourceFile> {
150 let (green, errors) = parsing::parse_text(text);
151 let root = SyntaxNode::new_root(green.clone());
152
153 assert_eq!(root.kind(), SyntaxKind::SOURCE_FILE);
154 Parse::new(green, errors)
155 }
156}
157
/// Dispatches on a syntax node by trying to cast it to each listed AST type in turn.
///
/// Each arm names an AST type by path (e.g. `ast::FuncOption`) and a pattern that
/// receives the result of that type's `cast`. Arms are tried top to bottom and the first
/// successful cast wins; the mandatory trailing `_` arm runs when every cast fails.
/// Every typed arm must end with a comma.
///
/// The node expression is re-evaluated and `.clone()`d for each arm that is tried.
///
/// ```ignore
/// match_ast! {
///     match node {
///         ast::FuncOption(it) => { /* use `it` */ },
///         _ => (),
///     }
/// }
/// ```
#[macro_export]
macro_rules! match_ast {
    // `match ident { ... }` is shorthand for the parenthesized-expression form below.
    (match $node:ident { $($tt:tt)* }) => { $crate::match_ast!(match ($node) { $($tt)* }) };

    (match ($node:expr) {
        $( $( $path:ident )::+ ($it:pat) => $res:expr, )*
        _ => $catch_all:expr $(,)?
    }) => {{
        // Expands to an `if let Some(..) = Path::cast(..) { .. } else ...` chain ending in
        // the catch-all block.
        $( if let Some($it) = $($path::)+cast($node.clone()) { $res } else )*
        { $catch_all }
    }};
}
184
/// Tour of the crate's API on a single `CREATE FUNCTION` statement: typed AST accessors,
/// the untyped syntax tree (kinds, ranges, text, parents/siblings/ancestors), a preorder
/// walk, and `match_ast!`.
#[test]
fn api_walkthrough() {
    use ast::SourceFile;
    use rowan::{Direction, NodeOrToken, SyntaxText, TextRange, WalkEvent};
    use std::fmt::Write;

    // The 8-space indentation inside this literal is load-bearing: the
    // `TextRange::new(65.into(), 82.into())` assertion below is computed from it.
    let source_code = "
        create function foo(p int8)
        returns int
        as 'select 1 + 1'
        language sql;
    ";
    let parse = SourceFile::parse(source_code);
    // Valid input: neither the parser nor validation reports anything.
    assert!(parse.errors().is_empty());

    let file: SourceFile = parse.tree();

    // Pick out the only statement; anything other than CREATE FUNCTION is a test bug.
    let mut func = None;
    for stmt in file.stmts() {
        match stmt {
            ast::Stmt::CreateFunction(f) => func = Some(f),
            _ => unreachable!(),
        }
    }
    let func: ast::CreateFunction = func.unwrap();

    // Typed accessors return `Option`s; this input is valid, so they can be unwrapped.
    let path: Option<ast::Path> = func.path();
    let name: ast::Name = path.unwrap().segment().unwrap().name().unwrap();
    assert_eq!(name.text(), "foo");

    let ret_type: Option<ast::RetType> = func.ret_type();
    let r_ty = &ret_type.unwrap().ty().unwrap();
    let type_: &ast::PathType = match &r_ty {
        ast::Type::PathType(r) => r,
        _ => unreachable!(),
    };
    let type_path: ast::Path = type_.path().unwrap();
    assert_eq!(type_path.syntax().to_string(), "int");

    let param_list: ast::ParamList = func.param_list().unwrap();
    let param: ast::Param = param_list.params().next().unwrap();

    let param_name: ast::Name = param.name().unwrap();
    assert_eq!(param_name.syntax().to_string(), "p");

    let param_ty: ast::Type = param.ty().unwrap();
    assert_eq!(param_ty.syntax().to_string(), "int8");

    let func_option_list: ast::FuncOptionList = func.option_list().unwrap();

    // `FuncOption` is an enum over option kinds; the first option here is `as '...'`.
    let func_option = func_option_list.options().next().unwrap();
    let option: &ast::AsFuncOption = match &func_option {
        ast::FuncOption::AsFuncOption(o) => o,
        _ => unreachable!(),
    };
    let definition: ast::Literal = option.definition().unwrap();
    assert_eq!(definition.syntax().to_string(), "'select 1 + 1'");

    // Every AST wrapper exposes its untyped syntax node; the enum and the variant it
    // wraps share the same node.
    let func_option_syntax = func_option.syntax();

    assert!(func_option_syntax == option.syntax());

    // An untyped node can be cast back into the typed enum.
    let _expr: ast::FuncOption = match ast::FuncOption::cast(func_option_syntax.clone()) {
        Some(e) => e,
        None => unreachable!(),
    };

    assert_eq!(func_option_syntax.kind(), SyntaxKind::AS_FUNC_OPTION);

    // Ranges are offsets into `source_code`: "as" starts at 65 and the option text is
    // 17 characters long.
    assert_eq!(
        func_option_syntax.text_range(),
        TextRange::new(65.into(), 82.into())
    );

    let text: SyntaxText = func_option_syntax.text();
    assert_eq!(text.to_string(), "as 'select 1 + 1'");

    // Navigation: parent, first child or token, next sibling, ancestors, siblings.
    assert_eq!(
        func_option_syntax.parent().as_ref(),
        Some(func_option_list.syntax())
    );
    assert_eq!(
        param_list
            .syntax()
            .first_child_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::L_PAREN)
    );
    assert_eq!(
        func_option_syntax
            .next_sibling_or_token()
            .map(|it| it.kind()),
        Some(SyntaxKind::WHITESPACE)
    );

    let f = func_option_syntax
        .ancestors()
        .find_map(ast::CreateFunction::cast);
    assert_eq!(f, Some(func));
    assert!(
        param
            .syntax()
            .siblings_with_tokens(Direction::Next)
            .any(|it| it.kind() == SyntaxKind::R_PAREN)
    );
    // AS_FUNC_OPTION itself plus AS_KW, WHITESPACE, LITERAL and STRING (see the walk below).
    assert_eq!(
        func_option_syntax.descendants_with_tokens().count(),
        5,
    );

    // Render the subtree as an indented outline. `{:indent$}` pads the one-space string
    // to `indent` columns, so the root line gets a single space that `trim()` removes.
    let mut buf = String::new();
    let mut indent = 0;
    for event in func_option_syntax.preorder_with_tokens() {
        match event {
            WalkEvent::Enter(node) => {
                let text = match &node {
                    NodeOrToken::Node(it) => it.text().to_string(),
                    NodeOrToken::Token(it) => it.text().to_owned(),
                };
                buf.write_fmt(format_args!(
                    "{:indent$}{:?} {:?}\n",
                    " ",
                    text,
                    node.kind(),
                    indent = indent
                ))
                .unwrap();
                indent += 2;
            }
            WalkEvent::Leave(_) => indent -= 2,
        }
    }
    // Every Enter was matched by a Leave.
    assert_eq!(indent, 0);
    assert_eq!(
        buf.trim(),
        r#"
"as 'select 1 + 1'" AS_FUNC_OPTION
  "as" AS_KW
  " " WHITESPACE
  "'select 1 + 1'" LITERAL
    "'select 1 + 1'" STRING
    "#
        .trim()
    );

    // Two equivalent ways to find every `FuncOption` in the file: cast each descendant,
    // or dispatch with `match_ast!`.
    let exprs_cast: Vec<String> = file
        .syntax()
        .descendants()
        .filter_map(ast::FuncOption::cast)
        .map(|expr| expr.syntax().text().to_string())
        .collect();

    let mut exprs_visit = Vec::new();
    for node in file.syntax().descendants() {
        match_ast! {
            match node {
                ast::FuncOption(it) => {
                    let res = it.syntax().text().to_string();
                    exprs_visit.push(res);
                },
                _ => (),
            }
        }
    }
    assert_eq!(exprs_cast, exprs_visit);
}
392
/// Extracts `(table name, [(column name, column type)])` from two CREATE TABLE
/// statements, skipping table constraints and LIKE clauses.
#[test]
fn create_table() {
    use insta::assert_debug_snapshot;

    let source_code = "
        create table users (
            id int8 primary key,
            name varchar(255) not null,
            email text,
            created_at timestamp default now()
        );

        create table posts (
            id serial primary key,
            title varchar(500),
            content text,
            user_id int8 references users(id)
        );
    ";

    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let tables: Vec<(String, Vec<(String, String)>)> = file
        .stmts()
        .filter_map(|stmt| match stmt {
            ast::Stmt::CreateTable(create_table) => Some(create_table),
            _ => None,
        })
        .map(|create_table| {
            let table_name = create_table.path().unwrap().syntax().to_string();
            // Only plain column definitions contribute; constraints and LIKE are skipped.
            let columns: Vec<(String, String)> = create_table
                .table_arg_list()
                .unwrap()
                .args()
                .filter_map(|arg| match arg {
                    ast::TableArg::Column(column) => Some(column),
                    ast::TableArg::TableConstraint(_) | ast::TableArg::LikeClause(_) => None,
                })
                .map(|column| {
                    (
                        column.name().unwrap().syntax().to_string(),
                        column.ty().unwrap().syntax().to_string(),
                    )
                })
                .collect();
            (table_name, columns)
        })
        .collect();

    assert_debug_snapshot!(tables, @r#"
    [
        (
            "users",
            [
                (
                    "id",
                    "int8",
                ),
                (
                    "name",
                    "varchar(255)",
                ),
                (
                    "email",
                    "text",
                ),
                (
                    "created_at",
                    "timestamp",
                ),
            ],
        ),
        (
            "posts",
            [
                (
                    "id",
                    "serial",
                ),
                (
                    "title",
                    "varchar(500)",
                ),
                (
                    "content",
                    "text",
                ),
                (
                    "user_id",
                    "int8",
                ),
            ],
        ),
    ]
    "#)
}
487
/// Parses `select 1 is not null;` and checks that the target is a binary expression whose
/// operator is the two-keyword `IS NOT` node, with a literal on either side.
#[test]
fn bin_expr() {
    use insta::assert_debug_snapshot;

    let source_code = "select 1 is not null;";
    let parse = SourceFile::parse(source_code);
    assert!(parse.errors().is_empty());
    let file: SourceFile = parse.tree();

    let ast::Stmt::Select(select) = file.stmts().next().unwrap() else {
        unreachable!()
    };

    // SELECT clause -> target list -> first target -> its expression.
    let target_list = select.select_clause().unwrap().target_list().unwrap();
    let target = target_list.targets().next().unwrap();
    let ast::Expr::BinExpr(bin_expr) = target.expr().unwrap() else {
        unreachable!()
    };

    let lhs = bin_expr.lhs();
    let op = bin_expr.op();
    let rhs = bin_expr.rhs();

    // Ranges in the snapshots are offsets into `source_code`: `1` is 7..8,
    // `is not` is 9..15 (whitespace included), and `null` is 16..20.
    assert_debug_snapshot!(lhs, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@7..8
                  INT_NUMBER@7..8 "1"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(op, @r#"
    Some(
        IsNot(
            IsNot {
                syntax: IS_NOT@9..15
                  IS_KW@9..11 "is"
                  WHITESPACE@11..12 " "
                  NOT_KW@12..15 "not"
                ,
            },
        ),
    )
    "#);
    assert_debug_snapshot!(rhs, @r#"
    Some(
        Literal(
            Literal {
                syntax: LITERAL@16..20
                  NULL_KW@16..20 "null"
                ,
            },
        ),
    )
    "#);
}