1use rigsql_core::Segment;
2use rigsql_lexer::{Lexer, LexerConfig, LexerError};
3use thiserror::Error;
4
5use crate::context::ParseContext;
6use crate::grammar::Grammar;
7
8#[derive(Debug, Error)]
9pub enum ParseError {
10 #[error("Lexer error: {0}")]
11 Lexer(#[from] LexerError),
12}
13
14pub struct Parser {
16 lexer_config: LexerConfig,
17}
18
19impl Parser {
20 pub fn new(lexer_config: LexerConfig) -> Self {
21 Self { lexer_config }
22 }
23
24 pub fn parse(&self, source: &str) -> Result<Segment, ParseError> {
26 let mut lexer = Lexer::new(source, self.lexer_config.clone());
27 let tokens = lexer.tokenize()?;
28 let mut ctx = ParseContext::new(&tokens, source);
29 let file = Grammar::parse_file(&mut ctx);
30 Ok(file)
31 }
32}
33
34impl Default for Parser {
35 fn default() -> Self {
36 Self::new(LexerConfig::ansi())
37 }
38}
39
#[cfg(test)]
mod tests {
    use super::*;
    use rigsql_core::SegmentType;

    /// Parses `sql` with the default parser, panicking with the offending SQL
    /// and the error if lexing fails.
    fn parse(sql: &str) -> Segment {
        Parser::default()
            .parse(sql)
            .unwrap_or_else(|e| panic!("failed to parse {:?}: {}", sql, e))
    }

    /// Asserts that `seg` has type `expected`, reporting the segment's raw text on failure.
    fn assert_type(seg: &Segment, expected: SegmentType) {
        assert_eq!(
            seg.segment_type(),
            expected,
            "Expected {:?} but got {:?} for raw: {:?}",
            expected,
            seg.segment_type(),
            seg.raw()
        );
    }

    /// Returns the first segment of type `ty` found by a depth-first,
    /// pre-order search that starts with `seg` itself.
    ///
    /// Implemented as a plain recursive search over `children()` so the
    /// returned reference borrows from `seg` without any raw-pointer tricks.
    fn find_type(seg: &Segment, ty: SegmentType) -> Option<&Segment> {
        // Takes `ty` by reference so no `Copy` bound on `SegmentType` is needed.
        fn search<'a>(seg: &'a Segment, ty: &SegmentType) -> Option<&'a Segment> {
            if seg.segment_type() == *ty {
                return Some(seg);
            }
            seg.children().iter().find_map(|child| search(child, ty))
        }
        search(seg, &ty)
    }

    #[test]
    fn test_simple_select() {
        let cst = parse("SELECT 1");
        assert_type(&cst, SegmentType::File);
        let stmt = &cst.children()[0];
        assert_type(stmt, SegmentType::Statement);
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_select_from_where() {
        let cst = parse("SELECT name FROM users WHERE id = 1");
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
        assert!(find_type(&cst, SegmentType::FromClause).is_some());
        assert!(find_type(&cst, SegmentType::WhereClause).is_some());
    }

    #[test]
    fn test_join() {
        let cst = parse("SELECT a.id FROM a INNER JOIN b ON a.id = b.id");
        assert!(find_type(&cst, SegmentType::JoinClause).is_some());
        assert!(find_type(&cst, SegmentType::OnClause).is_some());
    }

    #[test]
    fn test_group_by_having_order_by() {
        let cst = parse(
            "SELECT dept, COUNT(*) FROM emp GROUP BY dept HAVING COUNT(*) > 5 ORDER BY dept ASC",
        );
        assert!(find_type(&cst, SegmentType::GroupByClause).is_some());
        assert!(find_type(&cst, SegmentType::HavingClause).is_some());
        assert!(find_type(&cst, SegmentType::OrderByClause).is_some());
    }

    #[test]
    fn test_insert_values() {
        let cst = parse("INSERT INTO users (name, email) VALUES ('Alice', 'a@b.com')");
        assert!(find_type(&cst, SegmentType::InsertStatement).is_some());
        assert!(find_type(&cst, SegmentType::ValuesClause).is_some());
    }

    #[test]
    fn test_update_set_where() {
        let cst = parse("UPDATE users SET name = 'Bob' WHERE id = 1");
        assert!(find_type(&cst, SegmentType::UpdateStatement).is_some());
        assert!(find_type(&cst, SegmentType::SetClause).is_some());
        assert!(find_type(&cst, SegmentType::WhereClause).is_some());
    }

    #[test]
    fn test_delete() {
        let cst = parse("DELETE FROM users WHERE id = 1");
        assert!(find_type(&cst, SegmentType::DeleteStatement).is_some());
    }

    #[test]
    fn test_create_table() {
        let cst = parse("CREATE TABLE users (id INT, name VARCHAR(100))");
        assert!(find_type(&cst, SegmentType::CreateTableStatement).is_some());
    }

    #[test]
    fn test_with_cte() {
        let cst =
            parse("WITH active AS (SELECT * FROM users WHERE active = TRUE) SELECT * FROM active");
        assert!(find_type(&cst, SegmentType::WithClause).is_some());
        assert!(find_type(&cst, SegmentType::CteDefinition).is_some());
    }

    #[test]
    fn test_case_expression() {
        let cst = parse("SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END FROM t");
        assert!(find_type(&cst, SegmentType::CaseExpression).is_some());
        assert!(find_type(&cst, SegmentType::WhenClause).is_some());
        assert!(find_type(&cst, SegmentType::ElseClause).is_some());
    }

    #[test]
    fn test_subquery() {
        let cst = parse("SELECT * FROM (SELECT 1) AS sub");
        assert!(find_type(&cst, SegmentType::Subquery).is_some());
    }

    #[test]
    fn test_function_call() {
        let cst = parse("SELECT COUNT(*) FROM users");
        assert!(find_type(&cst, SegmentType::FunctionCall).is_some());
    }

    #[test]
    fn test_roundtrip() {
        let sql = "SELECT a, b FROM t WHERE x = 1 ORDER BY a;";
        let cst = parse(sql);
        assert_eq!(
            cst.raw(),
            sql,
            "CST roundtrip must preserve source text exactly"
        );
    }

    #[test]
    fn test_multiple_statements() {
        let sql = "SELECT 1; SELECT 2;";
        let cst = parse(sql);
        // Only direct children of the file node are counted.
        let statement_count = cst
            .children()
            .iter()
            .filter(|s| s.segment_type() == SegmentType::Statement)
            .count();
        assert_eq!(statement_count, 2);
    }

    #[test]
    fn test_roundtrip_complex() {
        let sql = "WITH cte AS (\n SELECT id, name\n FROM users\n WHERE active = TRUE\n)\nSELECT cte.id, cte.name\nFROM cte\nINNER JOIN orders ON cte.id = orders.user_id\nWHERE orders.total > 100\nORDER BY cte.name ASC\nLIMIT 10;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
    }
}