Skip to main content

rigsql_parser/
parser.rs

1use rigsql_core::Segment;
2use rigsql_lexer::{Lexer, LexerConfig, LexerError};
3use thiserror::Error;
4
5use crate::context::ParseContext;
6use crate::grammar::Grammar;
7
/// Errors that [`Parser::parse`] can return.
#[derive(Debug, Error)]
pub enum ParseError {
    /// Tokenization failed; wraps the underlying [`LexerError`].
    ///
    /// The `#[from]` attribute generates the `From<LexerError>` conversion,
    /// so `?` on a lexer result produces this variant.
    #[error("Lexer error: {0}")]
    Lexer(#[from] LexerError),
}
13
/// High-level SQL parser: source text → CST.
///
/// Stores the lexer configuration, which is cloned for each
/// [`Parser::parse`] call.
pub struct Parser {
    /// Configuration handed to the lexer on every parse.
    lexer_config: LexerConfig,
}
18
19impl Parser {
20    pub fn new(lexer_config: LexerConfig) -> Self {
21        Self { lexer_config }
22    }
23
24    /// Parse SQL source into a CST rooted at a File segment.
25    pub fn parse(&self, source: &str) -> Result<Segment, ParseError> {
26        let mut lexer = Lexer::new(source, self.lexer_config.clone());
27        let tokens = lexer.tokenize()?;
28        let mut ctx = ParseContext::new(&tokens, source);
29        let file = Grammar::parse_file(&mut ctx);
30        Ok(file)
31    }
32}
33
34impl Default for Parser {
35    fn default() -> Self {
36        Self::new(LexerConfig::ansi())
37    }
38}
39
#[cfg(test)]
mod tests {
    use super::*;
    use rigsql_core::SegmentType;

    /// Parse `sql` with the default parser, panicking if parsing fails.
    fn parse(sql: &str) -> Segment {
        Parser::default().parse(sql).unwrap()
    }

    /// Assert that `seg` has type `expected`, including the segment's raw
    /// text in the failure message.
    fn assert_type(seg: &Segment, expected: SegmentType) {
        assert_eq!(
            seg.segment_type(),
            expected,
            "Expected {:?} but got {:?} for raw: {:?}",
            expected,
            seg.segment_type(),
            seg.raw()
        );
    }

    /// Return the first segment of type `ty` in the tree rooted at `seg`
    /// (`seg` itself included), searching depth-first in pre-order.
    ///
    /// Implemented as a recursive descent over `children()` so that the
    /// returned borrow is tied to `seg` by the type system; no raw pointers
    /// or `unsafe` are required.
    fn find_type(seg: &Segment, ty: SegmentType) -> Option<&Segment> {
        // The helper borrows the wanted type so that recursion does not
        // require `SegmentType: Copy`.
        fn search<'a>(node: &'a Segment, wanted: &SegmentType) -> Option<&'a Segment> {
            if node.segment_type() == *wanted {
                return Some(node);
            }
            node.children()
                .iter()
                .find_map(|child| search(child, wanted))
        }

        search(seg, &ty)
    }

    #[test]
    fn test_simple_select() {
        let cst = parse("SELECT 1");
        assert_type(&cst, SegmentType::File);
        let stmt = &cst.children()[0];
        assert_type(stmt, SegmentType::Statement);
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
    }

    #[test]
    fn test_select_from_where() {
        let cst = parse("SELECT name FROM users WHERE id = 1");
        assert!(find_type(&cst, SegmentType::SelectClause).is_some());
        assert!(find_type(&cst, SegmentType::FromClause).is_some());
        assert!(find_type(&cst, SegmentType::WhereClause).is_some());
    }

    #[test]
    fn test_join() {
        let cst = parse("SELECT a.id FROM a INNER JOIN b ON a.id = b.id");
        assert!(find_type(&cst, SegmentType::JoinClause).is_some());
        assert!(find_type(&cst, SegmentType::OnClause).is_some());
    }

    #[test]
    fn test_group_by_having_order_by() {
        let cst = parse(
            "SELECT dept, COUNT(*) FROM emp GROUP BY dept HAVING COUNT(*) > 5 ORDER BY dept ASC",
        );
        assert!(find_type(&cst, SegmentType::GroupByClause).is_some());
        assert!(find_type(&cst, SegmentType::HavingClause).is_some());
        assert!(find_type(&cst, SegmentType::OrderByClause).is_some());
    }

    #[test]
    fn test_insert_values() {
        let cst = parse("INSERT INTO users (name, email) VALUES ('Alice', 'a@b.com')");
        assert!(find_type(&cst, SegmentType::InsertStatement).is_some());
        assert!(find_type(&cst, SegmentType::ValuesClause).is_some());
    }

    #[test]
    fn test_update_set_where() {
        let cst = parse("UPDATE users SET name = 'Bob' WHERE id = 1");
        assert!(find_type(&cst, SegmentType::UpdateStatement).is_some());
        assert!(find_type(&cst, SegmentType::SetClause).is_some());
        assert!(find_type(&cst, SegmentType::WhereClause).is_some());
    }

    #[test]
    fn test_delete() {
        let cst = parse("DELETE FROM users WHERE id = 1");
        assert!(find_type(&cst, SegmentType::DeleteStatement).is_some());
    }

    #[test]
    fn test_create_table() {
        let cst = parse("CREATE TABLE users (id INT, name VARCHAR(100))");
        assert!(find_type(&cst, SegmentType::CreateTableStatement).is_some());
    }

    #[test]
    fn test_with_cte() {
        let cst =
            parse("WITH active AS (SELECT * FROM users WHERE active = TRUE) SELECT * FROM active");
        assert!(find_type(&cst, SegmentType::WithClause).is_some());
        assert!(find_type(&cst, SegmentType::CteDefinition).is_some());
    }

    #[test]
    fn test_case_expression() {
        let cst = parse("SELECT CASE WHEN x > 0 THEN 'pos' ELSE 'neg' END FROM t");
        assert!(find_type(&cst, SegmentType::CaseExpression).is_some());
        assert!(find_type(&cst, SegmentType::WhenClause).is_some());
        assert!(find_type(&cst, SegmentType::ElseClause).is_some());
    }

    #[test]
    fn test_subquery() {
        let cst = parse("SELECT * FROM (SELECT 1) AS sub");
        assert!(find_type(&cst, SegmentType::Subquery).is_some());
    }

    #[test]
    fn test_function_call() {
        let cst = parse("SELECT COUNT(*) FROM users");
        assert!(find_type(&cst, SegmentType::FunctionCall).is_some());
    }

    /// The CST is lossless: concatenated raw text must equal the input.
    #[test]
    fn test_roundtrip() {
        let sql = "SELECT a, b FROM t WHERE x = 1 ORDER BY a;";
        let cst = parse(sql);
        assert_eq!(
            cst.raw(),
            sql,
            "CST roundtrip must preserve source text exactly"
        );
    }

    #[test]
    fn test_multiple_statements() {
        let sql = "SELECT 1; SELECT 2;";
        let cst = parse(sql);
        let stmts: Vec<_> = cst
            .children()
            .iter()
            .filter(|s| s.segment_type() == SegmentType::Statement)
            .collect();
        assert_eq!(stmts.len(), 2);
    }

    /// Roundtrip check on multi-line input with a CTE, join, and trailing clauses.
    #[test]
    fn test_roundtrip_complex() {
        let sql = "WITH cte AS (\n  SELECT id, name\n  FROM users\n  WHERE active = TRUE\n)\nSELECT cte.id, cte.name\nFROM cte\nINNER JOIN orders ON cte.id = orders.user_id\nWHERE orders.total > 100\nORDER BY cte.name ASC\nLIMIT 10;";
        let cst = parse(sql);
        assert_eq!(cst.raw(), sql);
    }
}