1use crate::error::{DbxError, DbxResult};
2use sqlparser::ast::Statement;
3use sqlparser::dialect::GenericDialect;
4use sqlparser::parser::Parser;
5
6pub struct SqlParser {
8 dialect: GenericDialect,
9}
10
11impl SqlParser {
12 pub fn new() -> Self {
14 Self {
15 dialect: GenericDialect {},
16 }
17 }
18
19 pub fn parse(&self, sql: &str) -> DbxResult<Vec<Statement>> {
21 Parser::parse_sql(&self.dialect, sql).map_err(|e| DbxError::SqlParse {
22 message: e.to_string(),
23 sql: sql.to_string(),
24 })
25 }
26}
27
28impl Default for SqlParser {
29 fn default() -> Self {
30 Self::new()
31 }
32}
33
#[cfg(test)]
mod tests {
    use super::*;
    use sqlparser::ast::{GroupByExpr, Query, Select, SelectItem, SetExpr};

    /// Parses `sql`, asserts that exactly one statement came back, and returns it.
    fn parse_single(sql: &str) -> Statement {
        let mut statements = SqlParser::new()
            .parse(sql)
            .unwrap_or_else(|e| panic!("failed to parse {:?}: {:?}", sql, e));
        assert_eq!(statements.len(), 1, "expected exactly one statement for {:?}", sql);
        statements.remove(0)
    }

    /// Returns the query of a `Statement::Query`, panicking on any other statement
    /// so a mis-parse fails the test instead of silently skipping its assertions.
    fn expect_query(statement: &Statement) -> &Query {
        match statement {
            Statement::Query(query) => query,
            other => panic!("Expected Query, got {:?}", other),
        }
    }

    /// Returns the `SELECT` body of a query statement, panicking on anything else
    /// (e.g. a set operation or VALUES body).
    fn expect_select(statement: &Statement) -> &Select {
        match expect_query(statement).body.as_ref() {
            SetExpr::Select(select) => select,
            other => panic!("Expected SELECT body, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_simple_select() {
        let statement = parse_single("SELECT * FROM users");
        let select = expect_select(&statement);
        assert_eq!(select.projection.len(), 1);
        assert!(matches!(select.projection[0], SelectItem::Wildcard(_)));
    }

    #[test]
    fn test_parse_select_with_where() {
        let statement = parse_single("SELECT id, name FROM users WHERE id = 1");
        let select = expect_select(&statement);
        assert_eq!(select.projection.len(), 2);
        assert!(select.selection.is_some());
    }

    #[test]
    fn test_parse_insert() {
        let statement = parse_single("INSERT INTO users (id, name) VALUES (1, 'Alice')");
        assert!(matches!(statement, Statement::Insert { .. }));
    }

    #[test]
    fn test_parse_update() {
        let statement = parse_single("UPDATE users SET name = 'Bob' WHERE id = 1");
        assert!(matches!(statement, Statement::Update { .. }));
    }

    #[test]
    fn test_parse_delete() {
        let statement = parse_single("DELETE FROM users WHERE id = 1");
        assert!(matches!(statement, Statement::Delete(_)));
    }

    #[test]
    fn test_parse_create_table() {
        let statement = parse_single("CREATE TABLE users (id INT PRIMARY KEY, name TEXT)");
        assert!(matches!(statement, Statement::CreateTable(_)));
    }

    #[test]
    fn test_parse_drop_table() {
        let statement = parse_single("DROP TABLE users");
        assert!(matches!(statement, Statement::Drop { .. }));
    }

    #[test]
    fn test_parse_select_with_join() {
        let sql = "SELECT u.id, o.total FROM users u INNER JOIN orders o ON u.id = o.user_id";
        let statement = parse_single(sql);
        let select = expect_select(&statement);
        assert_eq!(select.from.len(), 1);
        assert!(!select.from[0].joins.is_empty());
    }

    #[test]
    fn test_parse_select_with_group_by() {
        let sql = "SELECT category, COUNT(*) FROM products GROUP BY category";
        let statement = parse_single(sql);
        match &expect_select(&statement).group_by {
            GroupByExpr::Expressions(exprs, _) => assert!(!exprs.is_empty()),
            // `GROUP BY category` must parse as an expression list; any other form
            // (e.g. GROUP BY ALL) is a mis-parse, not a pass.
            other => panic!("Expected GROUP BY expressions, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_select_with_order_by() {
        let statement = parse_single("SELECT * FROM users ORDER BY name DESC");
        let order_by = expect_query(&statement)
            .order_by
            .as_ref()
            .expect("Expected ORDER BY clause");
        assert!(!order_by.exprs.is_empty());
    }

    #[test]
    fn test_parse_select_with_limit() {
        let statement = parse_single("SELECT * FROM users LIMIT 10");
        assert!(expect_query(&statement).limit.is_some());
    }

    #[test]
    fn test_parse_multiple_statements() {
        let statements = SqlParser::new()
            .parse("SELECT * FROM users; SELECT * FROM orders;")
            .unwrap();
        assert_eq!(statements.len(), 2);
        assert!(statements
            .iter()
            .all(|statement| matches!(statement, Statement::Query(_))));
    }

    #[test]
    fn test_parse_invalid_sql() {
        let sql = "SELECT * FROM";
        match SqlParser::new().parse(sql) {
            // The error must be the SqlParse variant and echo back the input text.
            Err(DbxError::SqlParse { sql: reported, .. }) => assert_eq!(reported, sql),
            other => panic!("Expected DbxError::SqlParse, got {:?}", other),
        }
    }
}