1pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21pub mod function_catalog;
22mod function_registry;
23pub mod generator;
24pub mod helper;
25pub mod lineage;
26pub mod optimizer;
27pub mod parser;
28pub mod planner;
29pub mod resolver;
30pub mod schema;
31pub mod scope;
32pub mod time;
33pub mod tokens;
34pub mod transforms;
35pub mod traversal;
36pub mod trie;
37pub mod validation;
38
39use serde::{Deserialize, Serialize};
40
41pub use ast_transforms::{
42 add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
43 get_identifiers, get_literals, get_output_column_names, get_subqueries, get_table_names,
44 get_window_functions, node_count, qualify_columns, remove_limit_offset, remove_nodes,
45 remove_select_columns, remove_where, rename_columns, rename_tables, replace_by_type,
46 replace_nodes, set_distinct, set_limit, set_offset,
47};
48pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
49pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
50pub use expressions::Expression;
51pub use function_catalog::{
52 FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
53};
54pub use generator::Generator;
55pub use helper::{
56 csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
57 name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
58};
59pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
60pub use parser::Parser;
61pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
62pub use schema::{
63 ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
64};
65pub use scope::{
66 build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
67 ScopeType, SourceInfo,
68};
69pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
70pub use tokens::{Token, TokenType, Tokenizer};
71pub use traversal::{
72 contains_aggregate,
73 contains_subquery,
74 contains_window_function,
75 find_ancestor,
76 find_parent,
77 get_columns,
78 get_tables,
79 is_add,
80 is_aggregate,
81 is_alias,
82 is_alter_table,
83 is_and,
84 is_arithmetic,
85 is_avg,
86 is_between,
87 is_boolean,
88 is_case,
89 is_cast,
90 is_coalesce,
91 is_column,
92 is_comparison,
93 is_concat,
94 is_count,
95 is_create_index,
96 is_create_table,
97 is_create_view,
98 is_cte,
99 is_ddl,
100 is_delete,
101 is_div,
102 is_drop_index,
103 is_drop_table,
104 is_drop_view,
105 is_eq,
106 is_except,
107 is_exists,
108 is_from,
109 is_function,
110 is_group_by,
111 is_gt,
112 is_gte,
113 is_having,
114 is_identifier,
115 is_ilike,
116 is_in,
117 is_insert,
119 is_intersect,
120 is_is_null,
121 is_join,
122 is_like,
123 is_limit,
124 is_literal,
125 is_logical,
126 is_lt,
127 is_lte,
128 is_max_func,
129 is_min_func,
130 is_mod,
131 is_mul,
132 is_neq,
133 is_not,
134 is_null_if,
135 is_null_literal,
136 is_offset,
137 is_or,
138 is_order_by,
139 is_ordered,
140 is_paren,
141 is_query,
143 is_safe_cast,
144 is_select,
145 is_set_operation,
146 is_star,
147 is_sub,
148 is_subquery,
149 is_sum,
150 is_table,
151 is_try_cast,
152 is_union,
153 is_update,
154 is_where,
155 is_window_function,
156 is_with,
157 transform,
158 transform_map,
159 BfsIter,
160 DfsIter,
161 ExpressionWalk,
162 ParentInfo,
163 TreeContext,
164};
165pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
166pub use validation::{
167 mapping_schema_from_validation_schema, validate_with_schema, SchemaColumn,
168 SchemaColumnReference, SchemaForeignKey, SchemaTable, SchemaTableReference,
169 SchemaValidationOptions, ValidationSchema,
170};
171
/// Default upper bound on the SQL input size, in bytes (16 MiB).
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 * 1024 * 1024;
/// Default upper bound on the number of tokens produced for one input.
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
/// Default upper bound on the total AST node count across all parsed statements.
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
/// Default upper bound on the number of set-operation tokens within one statement.
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;
176
177fn default_format_max_input_bytes() -> Option<usize> {
178 Some(DEFAULT_FORMAT_MAX_INPUT_BYTES)
179}
180
181fn default_format_max_tokens() -> Option<usize> {
182 Some(DEFAULT_FORMAT_MAX_TOKENS)
183}
184
185fn default_format_max_ast_nodes() -> Option<usize> {
186 Some(DEFAULT_FORMAT_MAX_AST_NODES)
187}
188
189fn default_format_max_set_op_chain() -> Option<usize> {
190 Some(DEFAULT_FORMAT_MAX_SET_OP_CHAIN)
191}
192
193#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
198#[serde(rename_all = "camelCase")]
199pub struct FormatGuardOptions {
200 #[serde(default = "default_format_max_input_bytes")]
203 pub max_input_bytes: Option<usize>,
204 #[serde(default = "default_format_max_tokens")]
207 pub max_tokens: Option<usize>,
208 #[serde(default = "default_format_max_ast_nodes")]
211 pub max_ast_nodes: Option<usize>,
212 #[serde(default = "default_format_max_set_op_chain")]
217 pub max_set_op_chain: Option<usize>,
218}
219
220impl Default for FormatGuardOptions {
221 fn default() -> Self {
222 Self {
223 max_input_bytes: default_format_max_input_bytes(),
224 max_tokens: default_format_max_tokens(),
225 max_ast_nodes: default_format_max_ast_nodes(),
226 max_set_op_chain: default_format_max_set_op_chain(),
227 }
228 }
229}
230
231fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
232 Error::generate(format!(
233 "{code}: value {actual} exceeds configured limit {limit}"
234 ))
235}
236
237fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
238 if let Some(max) = options.max_input_bytes {
239 let input_bytes = sql.len();
240 if input_bytes > max {
241 return Err(format_guard_error(
242 "E_GUARD_INPUT_TOO_LARGE",
243 input_bytes,
244 max,
245 ));
246 }
247 }
248 Ok(())
249}
250
251fn parse_with_token_guard(
252 sql: &str,
253 dialect: &Dialect,
254 options: &FormatGuardOptions,
255) -> Result<Vec<Expression>> {
256 let tokens = dialect.tokenize(sql)?;
257 if let Some(max) = options.max_tokens {
258 let token_count = tokens.len();
259 if token_count > max {
260 return Err(format_guard_error(
261 "E_GUARD_TOKEN_BUDGET_EXCEEDED",
262 token_count,
263 max,
264 ));
265 }
266 }
267 enforce_set_op_chain_guard(&tokens, options)?;
268
269 let config = crate::parser::ParserConfig {
270 dialect: Some(dialect.dialect_type()),
271 ..Default::default()
272 };
273 let mut parser = Parser::with_source(tokens, config, sql.to_string());
274 parser.parse()
275}
276
277fn is_trivia_token(token_type: TokenType) -> bool {
278 matches!(
279 token_type,
280 TokenType::Space | TokenType::Break | TokenType::LineComment | TokenType::BlockComment
281 )
282}
283
284fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
285 tokens
286 .iter()
287 .skip(start)
288 .find(|token| !is_trivia_token(token.token_type))
289}
290
291fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
292 let token = &tokens[idx];
293 match token.token_type {
294 TokenType::Union | TokenType::Intersect => true,
295 TokenType::Except => {
296 if token.text.eq_ignore_ascii_case("minus")
299 && matches!(
300 next_significant_token(tokens, idx + 1).map(|t| t.token_type),
301 Some(TokenType::LParen)
302 )
303 {
304 return false;
305 }
306 true
307 }
308 _ => false,
309 }
310}
311
312fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
313 let Some(max) = options.max_set_op_chain else {
314 return Ok(());
315 };
316
317 let mut set_op_count = 0usize;
318 for (idx, token) in tokens.iter().enumerate() {
319 if token.token_type == TokenType::Semicolon {
320 set_op_count = 0;
321 continue;
322 }
323
324 if is_set_operation_token(tokens, idx) {
325 set_op_count += 1;
326 if set_op_count > max {
327 return Err(format_guard_error(
328 "E_GUARD_SET_OP_CHAIN_EXCEEDED",
329 set_op_count,
330 max,
331 ));
332 }
333 }
334 }
335
336 Ok(())
337}
338
339fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
340 if let Some(max) = options.max_ast_nodes {
341 let ast_nodes: usize = expressions.iter().map(node_count).sum();
342 if ast_nodes > max {
343 return Err(format_guard_error(
344 "E_GUARD_AST_BUDGET_EXCEEDED",
345 ast_nodes,
346 max,
347 ));
348 }
349 }
350 Ok(())
351}
352
353fn format_with_dialect(
354 sql: &str,
355 dialect: &Dialect,
356 options: &FormatGuardOptions,
357) -> Result<Vec<String>> {
358 enforce_input_guard(sql, options)?;
359 let expressions = parse_with_token_guard(sql, dialect, options)?;
360 enforce_ast_guard(&expressions, options)?;
361
362 expressions
363 .iter()
364 .map(|expr| dialect.generate_pretty(expr))
365 .collect()
366}
367
368pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
389 let read_dialect = Dialect::get(read);
390 let write_dialect = Dialect::get(write);
391 let generic_identity = read == DialectType::Generic && write == DialectType::Generic;
392
393 let expressions = read_dialect.parse(sql)?;
394
395 expressions
396 .into_iter()
397 .map(|expr| {
398 if generic_identity {
399 write_dialect.generate_with_source(&expr, read)
400 } else {
401 let transformed = write_dialect.transform(expr)?;
402 write_dialect.generate_with_source(&transformed, read)
403 }
404 })
405 .collect()
406}
407
408pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
417 let d = Dialect::get(dialect);
418 d.parse(sql)
419}
420
421pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
430 let mut expressions = parse(sql, dialect)?;
431
432 if expressions.len() != 1 {
433 return Err(Error::parse(
434 format!("Expected 1 statement, found {}", expressions.len()),
435 0,
436 0,
437 0,
438 0,
439 ));
440 }
441
442 Ok(expressions.remove(0))
443}
444
445pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
454 let d = Dialect::get(dialect);
455 d.generate(expression)
456}
457
458pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
462 format_with_options(sql, dialect, &FormatGuardOptions::default())
463}
464
465pub fn format_with_options(
467 sql: &str,
468 dialect: DialectType,
469 options: &FormatGuardOptions,
470) -> Result<Vec<String>> {
471 let d = Dialect::get(dialect);
472 format_with_dialect(sql, &d, options)
473}
474
475pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
484 validate_with_options(sql, dialect, &ValidationOptions::default())
485}
486
487#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
489#[serde(rename_all = "camelCase")]
490pub struct ValidationOptions {
491 #[serde(default)]
494 pub strict_syntax: bool,
495}
496
497pub fn validate_with_options(
499 sql: &str,
500 dialect: DialectType,
501 options: &ValidationOptions,
502) -> ValidationResult {
503 let d = Dialect::get(dialect);
504 match d.parse(sql) {
505 Ok(expressions) => {
506 for expr in &expressions {
510 if !expr.is_statement() {
511 let msg = format!("Invalid expression / Unexpected token");
512 return ValidationResult::with_errors(vec![ValidationError::error(
513 msg, "E004",
514 )]);
515 }
516 }
517 if options.strict_syntax {
518 if let Some(error) = strict_syntax_error(sql, &d) {
519 return ValidationResult::with_errors(vec![error]);
520 }
521 }
522 ValidationResult::success()
523 }
524 Err(e) => {
525 let error = match &e {
526 Error::Syntax {
527 message,
528 line,
529 column,
530 start,
531 end,
532 } => ValidationError::error(message.clone(), "E001")
533 .with_location(*line, *column)
534 .with_span(Some(*start), Some(*end)),
535 Error::Tokenize {
536 message,
537 line,
538 column,
539 start,
540 end,
541 } => ValidationError::error(message.clone(), "E002")
542 .with_location(*line, *column)
543 .with_span(Some(*start), Some(*end)),
544 Error::Parse {
545 message,
546 line,
547 column,
548 start,
549 end,
550 } => ValidationError::error(message.clone(), "E003")
551 .with_location(*line, *column)
552 .with_span(Some(*start), Some(*end)),
553 _ => ValidationError::error(e.to_string(), "E000"),
554 };
555 ValidationResult::with_errors(vec![error])
556 }
557 }
558}
559
560fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
561 let tokens = dialect.tokenize(sql).ok()?;
562
563 for (idx, token) in tokens.iter().enumerate() {
564 if token.token_type != TokenType::Comma {
565 continue;
566 }
567
568 let next = tokens.get(idx + 1);
569 let (is_boundary, boundary_name) = match next.map(|t| t.token_type) {
570 Some(TokenType::From) => (true, "FROM"),
571 Some(TokenType::Where) => (true, "WHERE"),
572 Some(TokenType::GroupBy) => (true, "GROUP BY"),
573 Some(TokenType::Having) => (true, "HAVING"),
574 Some(TokenType::Order) | Some(TokenType::OrderBy) => (true, "ORDER BY"),
575 Some(TokenType::Limit) => (true, "LIMIT"),
576 Some(TokenType::Offset) => (true, "OFFSET"),
577 Some(TokenType::Union) => (true, "UNION"),
578 Some(TokenType::Intersect) => (true, "INTERSECT"),
579 Some(TokenType::Except) => (true, "EXCEPT"),
580 Some(TokenType::Qualify) => (true, "QUALIFY"),
581 Some(TokenType::Window) => (true, "WINDOW"),
582 Some(TokenType::Semicolon) | None => (true, "end of statement"),
583 _ => (false, ""),
584 };
585
586 if is_boundary {
587 let message = format!(
588 "Trailing comma before {} is not allowed in strict syntax mode",
589 boundary_name
590 );
591 return Some(
592 ValidationError::error(message, "E005")
593 .with_location(token.span.line, token.span.column),
594 );
595 }
596 }
597
598 None
599}
600
601pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
614 let read_dialect = Dialect::get_by_name(read)
615 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0, 0, 0))?;
616 let write_dialect = Dialect::get_by_name(write)
617 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0, 0, 0))?;
618 let generic_identity = read_dialect.dialect_type() == DialectType::Generic
619 && write_dialect.dialect_type() == DialectType::Generic;
620
621 let expressions = read_dialect.parse(sql)?;
622
623 expressions
624 .into_iter()
625 .map(|expr| {
626 if generic_identity {
627 write_dialect.generate_with_source(&expr, read_dialect.dialect_type())
628 } else {
629 let transformed = write_dialect.transform(expr)?;
630 write_dialect.generate_with_source(&transformed, read_dialect.dialect_type())
631 }
632 })
633 .collect()
634}
635
636pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
640 let d = Dialect::get_by_name(dialect)
641 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
642 d.parse(sql)
643}
644
645pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
649 let d = Dialect::get_by_name(dialect)
650 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
651 d.generate(expression)
652}
653
654pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
658 format_with_options_by_name(sql, dialect, &FormatGuardOptions::default())
659}
660
661pub fn format_with_options_by_name(
663 sql: &str,
664 dialect: &str,
665 options: &FormatGuardOptions,
666) -> Result<Vec<String>> {
667 let d = Dialect::get_by_name(dialect)
668 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
669 format_with_dialect(sql, &d, options)
670}
671
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// Validates `sql` against the generic dialect with strict syntax enabled.
    fn validate_strict(sql: &str) -> ValidationResult {
        let options = ValidationOptions {
            strict_syntax: true,
        };
        validate_with_options(sql, DialectType::Generic, &options)
    }

    /// Asserts that `result` is invalid and carries an `E005` error.
    fn assert_rejected_with_e005(result: &ValidationResult) {
        assert!(!result.valid, "Result should be invalid");
        assert!(
            result.errors.iter().any(|e| e.code == "E005"),
            "Expected E005, got: {:?}",
            result.errors
        );
    }

    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let result = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(result.valid, "Result: {:?}", result.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        let result = validate_strict("SELECT name, FROM employees");
        assert_rejected_with_e005(&result);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        let result = validate_strict("SELECT name FROM employees, WHERE salary > 10");
        assert_rejected_with_e005(&result);
    }
}
718
#[cfg(test)]
mod format_tests {
    use super::*;

    /// Options with every guard disabled; tests enable only the guard under test.
    fn no_limits() -> FormatGuardOptions {
        FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: None,
        }
    }

    /// Formats `sql` with the generic dialect and asserts that it fails with
    /// an error whose text contains `code`.
    fn assert_guard_trips(sql: &str, options: &FormatGuardOptions, code: &str) {
        let err = format_with_options(sql, DialectType::Generic, options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains(code));
    }

    #[test]
    fn format_basic_query() {
        let result = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
        assert_eq!(result.len(), 1);
        assert!(result[0].contains('\n'));
    }

    #[test]
    fn format_guard_rejects_large_input() {
        let options = FormatGuardOptions {
            max_input_bytes: Some(7),
            ..no_limits()
        };
        assert_guard_trips("SELECT 1", &options, "E_GUARD_INPUT_TOO_LARGE");
    }

    #[test]
    fn format_guard_rejects_token_budget() {
        let options = FormatGuardOptions {
            max_tokens: Some(1),
            ..no_limits()
        };
        assert_guard_trips("SELECT 1", &options, "E_GUARD_TOKEN_BUDGET_EXCEEDED");
    }

    #[test]
    fn format_guard_rejects_ast_budget() {
        let options = FormatGuardOptions {
            max_ast_nodes: Some(1),
            ..no_limits()
        };
        assert_guard_trips("SELECT 1", &options, "E_GUARD_AST_BUDGET_EXCEEDED");
    }

    #[test]
    fn format_guard_rejects_set_op_chain_budget() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(1),
            ..no_limits()
        };
        assert_guard_trips(
            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
            &options,
            "E_GUARD_SET_OP_CHAIN_EXCEEDED",
        );
    }

    #[test]
    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
        // With a limit of zero, any counted set operation would trip the guard.
        let options = FormatGuardOptions {
            max_set_op_chain: Some(0),
            ..no_limits()
        };
        let result = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
        assert!(result.is_ok(), "Result: {:?}", result);
    }

    #[test]
    fn issue57_invalid_ternary_returns_error() {
        // `cond ? a : b` must be rejected by every entry point for PostgreSQL.
        let sql = "SELECT x > 0 ? 1 : 0 FROM t";

        let parse_result = parse(sql, DialectType::PostgreSQL);
        assert!(
            parse_result.is_err(),
            "Expected parse error for invalid ternary SQL, got: {:?}",
            parse_result
        );

        let format_result = format(sql, DialectType::PostgreSQL);
        assert!(
            format_result.is_err(),
            "Expected format error for invalid ternary SQL, got: {:?}",
            format_result
        );

        let transpile_result = transpile(sql, DialectType::PostgreSQL, DialectType::PostgreSQL);
        assert!(
            transpile_result.is_err(),
            "Expected transpile error for invalid ternary SQL, got: {:?}",
            transpile_result
        );
    }

    #[test]
    fn format_default_guard_rejects_deep_union_chain_before_parse() {
        // 1100 chained UNION ALLs exceed the default set-operation limit.
        let base = "SELECT col0, col1 FROM t";
        let continuation = format!(" UNION ALL {base}");
        let sql = format!("{base}{}", continuation.repeat(1100));

        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }
}
838}