1use async_trait::async_trait;
2use crate::utils::{InternalQuery, QueryOperation, DataSource, Column, Predicate, PredicateOperator, PredicateValue, OrderBy, OrderColumn, OrderDirection};
3use crate::utils::error::{QueryParsingError, NirvResult};
4use sqlparser::ast::{Statement, Query, SelectItem, Expr, BinaryOperator, Value as SqlValue, OrderByExpr, FunctionArg, FunctionArgExpr};
5use sqlparser::dialect::{PostgreSqlDialect, MySqlDialect, SQLiteDialect, GenericDialect};
6use sqlparser::parser::Parser;
7use regex::Regex;
8
9#[async_trait]
11pub trait QueryParser: Send + Sync {
12 async fn parse_sql(&self, sql: &str) -> NirvResult<InternalQuery>;
14
15 async fn validate_syntax(&self, sql: &str) -> NirvResult<bool>;
17
18 async fn extract_sources(&self, sql: &str) -> NirvResult<Vec<String>>;
20}
21
22pub struct DefaultQueryParser {
24 postgres_dialect: PostgreSqlDialect,
25 mysql_dialect: MySqlDialect,
26 sqlite_dialect: SQLiteDialect,
27 generic_dialect: GenericDialect,
28 source_regex: Regex,
29}
30
31impl DefaultQueryParser {
32 pub fn new() -> NirvResult<Self> {
34 let source_regex = Regex::new(r#"source\s*\(\s*['"]([^'"]+)['"]\s*\)"#)
35 .map_err(|e| QueryParsingError::InvalidSyntax(format!("Failed to compile source regex: {}", e)))?;
36
37 Ok(Self {
38 postgres_dialect: PostgreSqlDialect {},
39 mysql_dialect: MySqlDialect {},
40 sqlite_dialect: SQLiteDialect {},
41 generic_dialect: GenericDialect {},
42 source_regex,
43 })
44 }
45
46 pub fn parse(&self, sql: &str) -> NirvResult<InternalQuery> {
48 let statement = self.try_parse_with_dialects(sql)?;
50
51 match statement {
52 Statement::Query(query) => self.convert_query(*query),
53 _ => Err(QueryParsingError::UnsupportedFeature("Only SELECT queries are currently supported".to_string()).into()),
54 }
55 }
56
57 fn try_parse_with_dialects(&self, sql: &str) -> NirvResult<Statement> {
59 if let Ok(statements) = Parser::parse_sql(&self.postgres_dialect, sql) {
61 if let Some(statement) = statements.into_iter().next() {
62 return Ok(statement);
63 }
64 }
65
66 if let Ok(statements) = Parser::parse_sql(&self.mysql_dialect, sql) {
68 if let Some(statement) = statements.into_iter().next() {
69 return Ok(statement);
70 }
71 }
72
73 if let Ok(statements) = Parser::parse_sql(&self.sqlite_dialect, sql) {
75 if let Some(statement) = statements.into_iter().next() {
76 return Ok(statement);
77 }
78 }
79
80 if let Ok(statements) = Parser::parse_sql(&self.generic_dialect, sql) {
82 if let Some(statement) = statements.into_iter().next() {
83 return Ok(statement);
84 }
85 }
86
87 Err(QueryParsingError::InvalidSyntax("Failed to parse SQL with any supported dialect".to_string()).into())
88 }
89
90 fn convert_query(&self, query: Query) -> NirvResult<InternalQuery> {
92 let mut internal_query = InternalQuery::new(QueryOperation::Select);
93
94 if let sqlparser::ast::SetExpr::Select(body) = query.body.as_ref() {
95 internal_query.projections = self.extract_projections(&body.projection)?;
97
98 internal_query.sources = self.extract_sources(&body.from)?;
100
101 if let Some(selection) = &body.selection {
103 internal_query.predicates = self.extract_predicates(selection)?;
104 }
105
106 if !query.order_by.is_empty() {
108 internal_query.ordering = Some(self.extract_order_by(&query.order_by)?);
109 }
110
111 if let Some(limit) = &query.limit {
113 internal_query.limit = Some(self.extract_limit(limit)?);
114 }
115 } else {
116 return Err(QueryParsingError::UnsupportedFeature("Only SELECT queries are supported".to_string()).into());
117 }
118
119 Ok(internal_query)
120 }
121
122 fn extract_projections(&self, projection: &[SelectItem]) -> NirvResult<Vec<Column>> {
124 let mut columns = Vec::new();
125
126 for item in projection {
127 match item {
128 SelectItem::UnnamedExpr(expr) => {
129 let column = self.extract_column_from_expr(expr, None)?;
130 columns.push(column);
131 }
132 SelectItem::ExprWithAlias { expr, alias } => {
133 let column = self.extract_column_from_expr(expr, Some(alias.value.clone()))?;
134 columns.push(column);
135 }
136 SelectItem::Wildcard(_) => {
137 columns.push(Column {
138 name: "*".to_string(),
139 alias: None,
140 source: None,
141 });
142 }
143 SelectItem::QualifiedWildcard(object_name, _) => {
144 columns.push(Column {
145 name: "*".to_string(),
146 alias: None,
147 source: Some(object_name.to_string()),
148 });
149 }
150 }
151 }
152
153 Ok(columns)
154 }
155
156 fn extract_column_from_expr(&self, expr: &Expr, alias: Option<String>) -> NirvResult<Column> {
158 match expr {
159 Expr::Identifier(ident) => {
160 Ok(Column {
161 name: ident.value.clone(),
162 alias,
163 source: None,
164 })
165 }
166 Expr::CompoundIdentifier(idents) => {
167 if idents.len() == 2 {
168 Ok(Column {
169 name: idents[1].value.clone(),
170 alias,
171 source: Some(idents[0].value.clone()),
172 })
173 } else {
174 Ok(Column {
175 name: idents.last().unwrap().value.clone(),
176 alias,
177 source: None,
178 })
179 }
180 }
181 Expr::Function(func) => {
182 if func.name.to_string().to_lowercase() == "source" {
184 return Err(QueryParsingError::InvalidSourceFormat("source() function should be used in FROM clause, not SELECT".to_string()).into());
185 }
186
187 Ok(Column {
188 name: func.name.to_string(),
189 alias,
190 source: None,
191 })
192 }
193 _ => {
194 Ok(Column {
195 name: "expr".to_string(),
196 alias,
197 source: None,
198 })
199 }
200 }
201 }
202
203 fn extract_sources(&self, from: &[sqlparser::ast::TableWithJoins]) -> NirvResult<Vec<DataSource>> {
205 let mut sources = Vec::new();
206
207 for table_with_joins in from {
208 let source = self.extract_source_from_table(&table_with_joins.relation)?;
209 sources.push(source);
210 }
211
212 if sources.is_empty() {
213 return Err(QueryParsingError::MissingSource.into());
214 }
215
216 Ok(sources)
217 }
218
219 fn extract_source_from_table(&self, table: &sqlparser::ast::TableFactor) -> NirvResult<DataSource> {
221 match table {
222 sqlparser::ast::TableFactor::Table { name, alias, args, .. } => {
223 let table_name = name.to_string();
224
225 if table_name.to_lowercase() == "source" && args.is_some() {
227 let source_spec = self.extract_source_from_function_args(args.as_ref().unwrap())?;
228 Ok(DataSource {
229 object_type: source_spec.0,
230 identifier: source_spec.1,
231 alias: alias.as_ref().map(|a| a.name.value.clone()),
232 })
233 } else {
234 Ok(DataSource {
236 object_type: "table".to_string(),
237 identifier: table_name,
238 alias: alias.as_ref().map(|a| a.name.value.clone()),
239 })
240 }
241 }
242 sqlparser::ast::TableFactor::Derived { alias, .. } => {
243 Ok(DataSource {
245 object_type: "subquery".to_string(),
246 identifier: "derived".to_string(),
247 alias: alias.as_ref().map(|a| a.name.value.clone()),
248 })
249 }
250 sqlparser::ast::TableFactor::Function { name, args, alias, .. } => {
251 if name.to_string().to_lowercase() == "source" {
253 let source_spec = self.extract_source_from_function_args(args)?;
254 Ok(DataSource {
255 object_type: source_spec.0,
256 identifier: source_spec.1,
257 alias: alias.as_ref().map(|a| a.name.value.clone()),
258 })
259 } else {
260 Err(QueryParsingError::UnsupportedFeature(format!("Function {} not supported in FROM clause", name)).into())
261 }
262 }
263 _ => Err(QueryParsingError::UnsupportedFeature("Unsupported table reference type".to_string()).into()),
264 }
265 }
266
267 fn extract_source_function(&self, table_name: &str) -> NirvResult<Option<(String, String)>> {
269 if let Some(captures) = self.source_regex.captures(table_name) {
270 if let Some(source_spec) = captures.get(1) {
271 let spec = source_spec.as_str();
272 if let Some(dot_pos) = spec.find('.') {
273 let object_type = spec[..dot_pos].to_string();
274 let identifier = spec[dot_pos + 1..].to_string();
275 Ok(Some((object_type, identifier)))
276 } else {
277 Ok(Some(("table".to_string(), spec.to_string())))
279 }
280 } else {
281 Err(QueryParsingError::InvalidSourceFormat("Empty source specification".to_string()).into())
282 }
283 } else {
284 Ok(None)
285 }
286 }
287
288 fn extract_source_from_function_args(&self, args: &[FunctionArg]) -> NirvResult<(String, String)> {
290 if args.len() != 1 {
291 return Err(QueryParsingError::InvalidSourceFormat("source() function requires exactly one argument".to_string()).into());
292 }
293
294 if let FunctionArg::Unnamed(FunctionArgExpr::Expr(Expr::Value(SqlValue::SingleQuotedString(spec)))) = &args[0] {
295 if let Some(dot_pos) = spec.find('.') {
296 let object_type = spec[..dot_pos].to_string();
297 let identifier = spec[dot_pos + 1..].to_string();
298 Ok((object_type, identifier))
299 } else {
300 Ok(("table".to_string(), spec.to_string()))
302 }
303 } else {
304 Err(QueryParsingError::InvalidSourceFormat("source() function argument must be a string literal".to_string()).into())
305 }
306 }
307
308 fn extract_predicates(&self, expr: &Expr) -> NirvResult<Vec<Predicate>> {
310 let mut predicates = Vec::new();
311 self.extract_predicates_recursive(expr, &mut predicates)?;
312 Ok(predicates)
313 }
314
315 fn extract_predicates_recursive(&self, expr: &Expr, predicates: &mut Vec<Predicate>) -> NirvResult<()> {
317 match expr {
318 Expr::BinaryOp { left, op, right } => {
319 match op {
320 BinaryOperator::And => {
321 self.extract_predicates_recursive(left, predicates)?;
323 self.extract_predicates_recursive(right, predicates)?;
324 }
325 BinaryOperator::Or => {
326 self.extract_predicates_recursive(left, predicates)?;
328 self.extract_predicates_recursive(right, predicates)?;
329 }
330 _ => {
331 let predicate = self.create_predicate_from_binary_op(left, op, right)?;
333 predicates.push(predicate);
334 }
335 }
336 }
337 Expr::IsNull(expr) => {
338 let column = self.extract_column_name_from_expr(expr)?;
339 predicates.push(Predicate {
340 column,
341 operator: PredicateOperator::IsNull,
342 value: PredicateValue::Null,
343 });
344 }
345 Expr::IsNotNull(expr) => {
346 let column = self.extract_column_name_from_expr(expr)?;
347 predicates.push(Predicate {
348 column,
349 operator: PredicateOperator::IsNotNull,
350 value: PredicateValue::Null,
351 });
352 }
353 _ => {
354 }
356 }
357 Ok(())
358 }
359
360 fn create_predicate_from_binary_op(&self, left: &Expr, op: &BinaryOperator, right: &Expr) -> NirvResult<Predicate> {
362 let column = self.extract_column_name_from_expr(left)?;
363 let operator = self.convert_binary_operator(op)?;
364 let value = self.extract_predicate_value_from_expr(right)?;
365
366 Ok(Predicate {
367 column,
368 operator,
369 value,
370 })
371 }
372
373 fn extract_column_name_from_expr(&self, expr: &Expr) -> NirvResult<String> {
375 match expr {
376 Expr::Identifier(ident) => Ok(ident.value.clone()),
377 Expr::CompoundIdentifier(idents) => {
378 if idents.len() >= 2 {
379 Ok(format!("{}.{}", idents[0].value, idents[1].value))
380 } else {
381 Ok(idents[0].value.clone())
382 }
383 }
384 _ => Err(QueryParsingError::InvalidSyntax("Expected column identifier in predicate".to_string()).into()),
385 }
386 }
387
388 fn convert_binary_operator(&self, op: &BinaryOperator) -> NirvResult<PredicateOperator> {
390 match op {
391 BinaryOperator::Eq => Ok(PredicateOperator::Equal),
392 BinaryOperator::NotEq => Ok(PredicateOperator::NotEqual),
393 BinaryOperator::Gt => Ok(PredicateOperator::GreaterThan),
394 BinaryOperator::GtEq => Ok(PredicateOperator::GreaterThanOrEqual),
395 BinaryOperator::Lt => Ok(PredicateOperator::LessThan),
396 BinaryOperator::LtEq => Ok(PredicateOperator::LessThanOrEqual),
397 _ => Err(QueryParsingError::UnsupportedFeature(format!("Operator {:?} not supported", op)).into()),
399 }
400 }
401
402 fn extract_predicate_value_from_expr(&self, expr: &Expr) -> NirvResult<PredicateValue> {
404 match expr {
405 Expr::Value(sql_value) => self.convert_sql_value(sql_value),
406 Expr::Identifier(ident) => Ok(PredicateValue::String(ident.value.clone())),
407 _ => Err(QueryParsingError::UnsupportedFeature("Complex expressions in predicates not yet supported".to_string()).into()),
408 }
409 }
410
411 fn convert_sql_value(&self, value: &SqlValue) -> NirvResult<PredicateValue> {
413 match value {
414 SqlValue::SingleQuotedString(s) | SqlValue::DoubleQuotedString(s) => {
415 Ok(PredicateValue::String(s.clone()))
416 }
417 SqlValue::Number(n, _) => {
418 if let Ok(int_val) = n.parse::<i64>() {
419 Ok(PredicateValue::Integer(int_val))
420 } else if let Ok(float_val) = n.parse::<f64>() {
421 Ok(PredicateValue::Number(float_val))
422 } else {
423 Err(QueryParsingError::InvalidSyntax(format!("Invalid number format: {}", n)).into())
424 }
425 }
426 SqlValue::Boolean(b) => Ok(PredicateValue::Boolean(*b)),
427 SqlValue::Null => Ok(PredicateValue::Null),
428 _ => Err(QueryParsingError::UnsupportedFeature(format!("Value type {:?} not supported", value)).into()),
429 }
430 }
431
432 fn extract_order_by(&self, order_by: &[OrderByExpr]) -> NirvResult<OrderBy> {
434 let mut columns = Vec::new();
435
436 for order_expr in order_by {
437 let column_name = self.extract_column_name_from_expr(&order_expr.expr)?;
438 let direction = if order_expr.asc.unwrap_or(true) {
439 OrderDirection::Ascending
440 } else {
441 OrderDirection::Descending
442 };
443
444 columns.push(OrderColumn {
445 column: column_name,
446 direction,
447 });
448 }
449
450 Ok(OrderBy { columns })
451 }
452
453 fn extract_limit(&self, limit_expr: &Expr) -> NirvResult<u64> {
455 match limit_expr {
456 Expr::Value(SqlValue::Number(n, _)) => {
457 n.parse::<u64>()
458 .map_err(|_| QueryParsingError::InvalidSyntax(format!("Invalid LIMIT value: {}", n)).into())
459 }
460 _ => Err(QueryParsingError::InvalidSyntax("LIMIT must be a number".to_string()).into()),
461 }
462 }
463}
464
465impl Default for DefaultQueryParser {
466 fn default() -> Self {
467 Self::new().expect("Failed to create default QueryParser")
468 }
469}
470
471#[async_trait]
472impl QueryParser for DefaultQueryParser {
473 async fn parse_sql(&self, sql: &str) -> NirvResult<InternalQuery> {
474 self.parse(sql)
475 }
476
477 async fn validate_syntax(&self, sql: &str) -> NirvResult<bool> {
478 match self.try_parse_with_dialects(sql) {
479 Ok(_) => Ok(true),
480 Err(_) => Ok(false),
481 }
482 }
483
484 async fn extract_sources(&self, sql: &str) -> NirvResult<Vec<String>> {
485 let query = self.parse(sql)?;
486 Ok(query.sources.into_iter()
487 .map(|source| format!("{}.{}", source.object_type, source.identifier))
488 .collect())
489 }
490}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::error::NirvError;

    /// Builds a fresh parser; construction only fails if the built-in regex is broken.
    fn create_parser() -> DefaultQueryParser {
        DefaultQueryParser::new().expect("Failed to create parser")
    }

    /// Parses `sql`, panicking with the parser's own error if it is rejected.
    fn parse_ok(sql: &str) -> InternalQuery {
        create_parser()
            .parse(sql)
            .unwrap_or_else(|err| panic!("Failed to parse {:?}: {:?}", sql, err))
    }

    /// Parses `sql`, expecting a `QueryParsing` error, and returns the inner error.
    fn parse_err(sql: &str) -> QueryParsingError {
        match create_parser().parse(sql) {
            Err(NirvError::QueryParsing(err)) => err,
            other => panic!("Expected a query parsing error for {:?}, got {:?}", sql, other),
        }
    }

    #[test]
    fn test_parser_creation() {
        assert!(DefaultQueryParser::new().is_ok());
    }

    #[test]
    fn test_default_parser_is_usable() {
        let parser = DefaultQueryParser::default();
        assert!(parser.parse("SELECT * FROM source('postgres.users')").is_ok());
    }

    #[test]
    fn test_simple_select_all() {
        let query = parse_ok("SELECT * FROM source('postgres.users')");

        assert_eq!(query.operation, QueryOperation::Select);
        assert_eq!(query.projections.len(), 1);
        assert_eq!(query.projections[0].name, "*");
        assert_eq!(query.sources.len(), 1);
        assert_eq!(query.sources[0].object_type, "postgres");
        assert_eq!(query.sources[0].identifier, "users");
    }

    #[test]
    fn test_select_with_columns() {
        let query = parse_ok("SELECT id, name, email FROM source('postgres.users')");

        let names: Vec<&str> = query.projections.iter().map(|c| c.name.as_str()).collect();
        assert_eq!(names, ["id", "name", "email"]);
    }

    #[test]
    fn test_select_with_aliases() {
        let query = parse_ok(
            "SELECT id as user_id, name as full_name FROM source('postgres.users') as u",
        );

        assert_eq!(query.projections.len(), 2);
        assert_eq!(query.projections[0].name, "id");
        assert_eq!(query.projections[0].alias.as_deref(), Some("user_id"));
        assert_eq!(query.projections[1].name, "name");
        assert_eq!(query.projections[1].alias.as_deref(), Some("full_name"));
        assert_eq!(query.sources[0].alias.as_deref(), Some("u"));
    }

    #[test]
    fn test_source_function_parsing() {
        let cases = [
            ("SELECT * FROM source('postgres.users')", "postgres", "users"),
            ("SELECT * FROM source('file.data.csv')", "file", "data.csv"),
            ("SELECT * FROM source('api.endpoint')", "api", "endpoint"),
            ("SELECT * FROM source('users')", "table", "users"),
        ];

        for (sql, expected_type, expected_id) in cases.iter() {
            let query = parse_ok(sql);
            assert_eq!(query.sources.len(), 1, "{}", sql);
            assert_eq!(query.sources[0].object_type, *expected_type, "{}", sql);
            assert_eq!(query.sources[0].identifier, *expected_id, "{}", sql);
        }
    }

    #[test]
    fn test_where_clause_parsing() {
        let query =
            parse_ok("SELECT * FROM source('postgres.users') WHERE age > 18 AND name = 'John'");

        assert_eq!(query.predicates.len(), 2);

        assert_eq!(query.predicates[0].column, "age");
        assert_eq!(query.predicates[0].operator, PredicateOperator::GreaterThan);
        assert_eq!(query.predicates[0].value, PredicateValue::Integer(18));

        assert_eq!(query.predicates[1].column, "name");
        assert_eq!(query.predicates[1].operator, PredicateOperator::Equal);
        assert_eq!(query.predicates[1].value, PredicateValue::String("John".to_string()));
    }

    #[test]
    fn test_various_operators() {
        let cases = vec![
            ("SELECT * FROM source('test') WHERE id = 1", PredicateOperator::Equal),
            ("SELECT * FROM source('test') WHERE id != 1", PredicateOperator::NotEqual),
            ("SELECT * FROM source('test') WHERE id > 1", PredicateOperator::GreaterThan),
            ("SELECT * FROM source('test') WHERE id >= 1", PredicateOperator::GreaterThanOrEqual),
            ("SELECT * FROM source('test') WHERE id < 1", PredicateOperator::LessThan),
            ("SELECT * FROM source('test') WHERE id <= 1", PredicateOperator::LessThanOrEqual),
        ];

        for (sql, expected_op) in cases {
            let query = parse_ok(sql);
            assert_eq!(query.predicates.len(), 1, "{}", sql);
            assert_eq!(query.predicates[0].operator, expected_op, "{}", sql);
        }
    }

    #[test]
    fn test_null_predicates() {
        let is_null = parse_ok("SELECT * FROM source('test') WHERE name IS NULL");
        assert_eq!(is_null.predicates.len(), 1);
        assert_eq!(is_null.predicates[0].operator, PredicateOperator::IsNull);

        let is_not_null = parse_ok("SELECT * FROM source('test') WHERE name IS NOT NULL");
        assert_eq!(is_not_null.predicates.len(), 1);
        assert_eq!(is_not_null.predicates[0].operator, PredicateOperator::IsNotNull);
    }

    #[test]
    fn test_order_by_clause() {
        let query = parse_ok("SELECT * FROM source('postgres.users') ORDER BY name ASC, age DESC");

        let ordering = query.ordering.expect("ORDER BY should populate ordering");
        assert_eq!(ordering.columns.len(), 2);

        assert_eq!(ordering.columns[0].column, "name");
        assert_eq!(ordering.columns[0].direction, OrderDirection::Ascending);

        assert_eq!(ordering.columns[1].column, "age");
        assert_eq!(ordering.columns[1].direction, OrderDirection::Descending);
    }

    #[test]
    fn test_limit_clause() {
        let query = parse_ok("SELECT * FROM source('postgres.users') LIMIT 10");
        assert_eq!(query.limit, Some(10));
    }

    #[test]
    fn test_complex_query() {
        let query = parse_ok(
            "SELECT id, name as full_name FROM source('postgres.users') as u \
             WHERE age >= 21 AND status = 'active' ORDER BY name ASC LIMIT 50",
        );

        assert_eq!(query.projections.len(), 2);
        assert_eq!(query.projections[0].name, "id");
        assert_eq!(query.projections[1].alias.as_deref(), Some("full_name"));

        assert_eq!(query.sources.len(), 1);
        assert_eq!(query.sources[0].object_type, "postgres");
        assert_eq!(query.sources[0].identifier, "users");
        assert_eq!(query.sources[0].alias.as_deref(), Some("u"));

        assert_eq!(query.predicates.len(), 2);
        assert!(query.ordering.is_some());
        assert_eq!(query.limit, Some(50));
    }

    #[test]
    fn test_postgresql_dialect() {
        parse_ok("SELECT id, name FROM source('postgres.users') WHERE created_at > '2023-01-01'");
    }

    #[test]
    fn test_mysql_dialect() {
        parse_ok("SELECT id, name FROM source('mysql.users') WHERE created_at > '2023-01-01'");
    }

    #[test]
    fn test_sqlite_dialect() {
        parse_ok("SELECT id, name FROM source('sqlite.users') WHERE created_at > '2023-01-01'");
    }

    #[test]
    fn test_invalid_sql_syntax() {
        let err = parse_err("INVALID SQL SYNTAX");
        assert!(
            matches!(err, QueryParsingError::InvalidSyntax(_)),
            "Expected InvalidSyntax error, got {:?}",
            err
        );
    }

    #[test]
    fn test_missing_source_function() {
        let query = parse_ok("SELECT * FROM users");
        assert_eq!(query.sources[0].object_type, "table");
        assert_eq!(query.sources[0].identifier, "users");
    }

    #[test]
    fn test_invalid_source_format() {
        let err = parse_err("SELECT * FROM source()");
        assert!(
            matches!(
                err,
                QueryParsingError::InvalidSourceFormat(_) | QueryParsingError::InvalidSyntax(_)
            ),
            "Expected InvalidSourceFormat or InvalidSyntax error, got {:?}",
            err
        );
    }

    #[test]
    fn test_unsupported_query_type() {
        let err = parse_err("INSERT INTO users (name) VALUES ('John')");
        assert!(
            matches!(err, QueryParsingError::UnsupportedFeature(_)),
            "Expected UnsupportedFeature error, got {:?}",
            err
        );
    }

    #[test]
    fn test_source_function_in_select_clause() {
        let err = parse_err("SELECT source('test') FROM users");
        assert!(
            matches!(err, QueryParsingError::InvalidSourceFormat(_)),
            "Expected InvalidSourceFormat error, got {:?}",
            err
        );
    }

    #[test]
    fn test_compound_identifiers() {
        let query =
            parse_ok("SELECT u.id, u.name FROM source('postgres.users') as u WHERE u.age > 18");

        assert_eq!(query.projections.len(), 2);
        assert_eq!(query.projections[0].name, "id");
        assert_eq!(query.projections[0].source.as_deref(), Some("u"));
        assert_eq!(query.projections[1].name, "name");
        assert_eq!(query.projections[1].source.as_deref(), Some("u"));

        assert_eq!(query.predicates.len(), 1);
        assert_eq!(query.predicates[0].column, "u.age");
    }

    #[test]
    fn test_various_value_types() {
        // 2.5 rather than 3.14: clippy's deny-by-default `approx_constant` rejects 3.14.
        let cases = vec![
            ("SELECT * FROM source('test') WHERE str_col = 'text'", PredicateValue::String("text".to_string())),
            ("SELECT * FROM source('test') WHERE int_col = 42", PredicateValue::Integer(42)),
            ("SELECT * FROM source('test') WHERE float_col = 2.5", PredicateValue::Number(2.5)),
            ("SELECT * FROM source('test') WHERE bool_col = true", PredicateValue::Boolean(true)),
            ("SELECT * FROM source('test') WHERE null_col = NULL", PredicateValue::Null),
        ];

        for (sql, expected_value) in cases {
            let query = parse_ok(sql);
            assert_eq!(query.predicates.len(), 1, "{}", sql);
            assert_eq!(query.predicates[0].value, expected_value, "{}", sql);
        }
    }

    #[test]
    fn test_double_quoted_strings() {
        let query = parse_ok(r#"SELECT * FROM source('postgres.users') WHERE name = "John""#);

        assert_eq!(query.sources[0].object_type, "postgres");
        assert_eq!(query.sources[0].identifier, "users");
        assert_eq!(query.predicates[0].value, PredicateValue::String("John".to_string()));
    }

    #[test]
    fn test_qualified_wildcard() {
        let query = parse_ok("SELECT u.* FROM source('postgres.users') as u");

        assert_eq!(query.projections.len(), 1);
        assert_eq!(query.projections[0].name, "*");
        assert_eq!(query.projections[0].source.as_deref(), Some("u"));
    }
}
870}