1pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21pub mod function_catalog;
22mod function_registry;
23pub mod generator;
24pub mod helper;
25pub mod lineage;
26pub mod optimizer;
27pub mod parser;
28pub mod planner;
29pub mod resolver;
30pub mod schema;
31pub mod scope;
32pub mod time;
33pub mod tokens;
34pub mod transforms;
35pub mod traversal;
36pub mod trie;
37pub mod validation;
38
39use serde::{Deserialize, Serialize};
40
41pub use ast_transforms::{
42 add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
43 get_identifiers, get_literals, get_output_column_names, get_subqueries, get_table_names,
44 get_window_functions, node_count, qualify_columns, remove_limit_offset, remove_nodes,
45 remove_select_columns, remove_where, rename_columns, rename_tables, replace_by_type,
46 replace_nodes, set_distinct, set_limit, set_offset,
47};
48pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
49pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
50pub use expressions::Expression;
51pub use function_catalog::{
52 FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
53};
54pub use generator::Generator;
55pub use helper::{
56 csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
57 name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
58};
59pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
60pub use parser::Parser;
61pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
62pub use schema::{
63 ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
64};
65pub use scope::{
66 build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
67 ScopeType, SourceInfo,
68};
69pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
70pub use tokens::{Token, TokenType, Tokenizer};
71pub use traversal::{
72 contains_aggregate,
73 contains_subquery,
74 contains_window_function,
75 find_ancestor,
76 find_parent,
77 get_all_tables,
78 get_columns,
79 get_merge_source,
80 get_merge_target,
81 get_tables,
82 is_add,
83 is_aggregate,
84 is_alias,
85 is_alter_table,
86 is_and,
87 is_arithmetic,
88 is_avg,
89 is_between,
90 is_boolean,
91 is_case,
92 is_cast,
93 is_coalesce,
94 is_column,
95 is_comparison,
96 is_concat,
97 is_count,
98 is_create_index,
99 is_create_table,
100 is_create_view,
101 is_cte,
102 is_ddl,
103 is_delete,
104 is_div,
105 is_drop_index,
106 is_drop_table,
107 is_drop_view,
108 is_eq,
109 is_except,
110 is_exists,
111 is_from,
112 is_function,
113 is_group_by,
114 is_gt,
115 is_gte,
116 is_having,
117 is_identifier,
118 is_ilike,
119 is_in,
120 is_insert,
122 is_intersect,
123 is_is_null,
124 is_join,
125 is_like,
126 is_limit,
127 is_literal,
128 is_logical,
129 is_lt,
130 is_lte,
131 is_max_func,
132 is_merge,
133 is_min_func,
134 is_mod,
135 is_mul,
136 is_neq,
137 is_not,
138 is_null_if,
139 is_null_literal,
140 is_offset,
141 is_or,
142 is_order_by,
143 is_ordered,
144 is_paren,
145 is_query,
147 is_safe_cast,
148 is_select,
149 is_set_operation,
150 is_star,
151 is_sub,
152 is_subquery,
153 is_sum,
154 is_table,
155 is_try_cast,
156 is_union,
157 is_update,
158 is_where,
159 is_window_function,
160 is_with,
161 transform,
162 transform_map,
163 BfsIter,
164 DfsIter,
165 ExpressionWalk,
166 ParentInfo,
167 TreeContext,
168};
169pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
170pub use validation::{
171 mapping_schema_from_validation_schema, validate_with_schema, SchemaColumn,
172 SchemaColumnReference, SchemaForeignKey, SchemaTable, SchemaTableReference,
173 SchemaValidationOptions, ValidationSchema,
174};
175
/// Default ceiling on the size of the SQL text, in bytes, accepted by the format guard (16 MiB).
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 * 1024 * 1024;
/// Default ceiling on the number of tokens produced by tokenizing the input.
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
/// Default ceiling on the total number of AST nodes across all parsed statements.
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
/// Default ceiling on the number of set-operation tokens allowed within one statement.
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;
180
181fn default_format_max_input_bytes() -> Option<usize> {
182 Some(DEFAULT_FORMAT_MAX_INPUT_BYTES)
183}
184
185fn default_format_max_tokens() -> Option<usize> {
186 Some(DEFAULT_FORMAT_MAX_TOKENS)
187}
188
189fn default_format_max_ast_nodes() -> Option<usize> {
190 Some(DEFAULT_FORMAT_MAX_AST_NODES)
191}
192
193fn default_format_max_set_op_chain() -> Option<usize> {
194 Some(DEFAULT_FORMAT_MAX_SET_OP_CHAIN)
195}
196
/// Resource limits applied while formatting SQL.
///
/// Each field is an optional ceiling: `Some(n)` makes the corresponding guard
/// fail once the measured value exceeds `n`, while `None` switches that guard
/// off. Field names are serialized in camelCase, and a field that is absent
/// from the serialized input falls back to the built-in default limit.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FormatGuardOptions {
    /// Maximum length of the input SQL in bytes (checked before tokenizing).
    #[serde(default = "default_format_max_input_bytes")]
    pub max_input_bytes: Option<usize>,
    /// Maximum number of tokens produced by the tokenizer.
    #[serde(default = "default_format_max_tokens")]
    pub max_tokens: Option<usize>,
    /// Maximum total AST node count summed over every parsed statement.
    #[serde(default = "default_format_max_ast_nodes")]
    pub max_ast_nodes: Option<usize>,
    /// Maximum number of set-operation tokens (`UNION`, `INTERSECT`, `EXCEPT`)
    /// counted between statement-separating semicolons.
    #[serde(default = "default_format_max_set_op_chain")]
    pub max_set_op_chain: Option<usize>,
}
223
224impl Default for FormatGuardOptions {
225 fn default() -> Self {
226 Self {
227 max_input_bytes: default_format_max_input_bytes(),
228 max_tokens: default_format_max_tokens(),
229 max_ast_nodes: default_format_max_ast_nodes(),
230 max_set_op_chain: default_format_max_set_op_chain(),
231 }
232 }
233}
234
235fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
236 Error::generate(format!(
237 "{code}: value {actual} exceeds configured limit {limit}"
238 ))
239}
240
241fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
242 if let Some(max) = options.max_input_bytes {
243 let input_bytes = sql.len();
244 if input_bytes > max {
245 return Err(format_guard_error(
246 "E_GUARD_INPUT_TOO_LARGE",
247 input_bytes,
248 max,
249 ));
250 }
251 }
252 Ok(())
253}
254
255fn parse_with_token_guard(
256 sql: &str,
257 dialect: &Dialect,
258 options: &FormatGuardOptions,
259) -> Result<Vec<Expression>> {
260 let tokens = dialect.tokenize(sql)?;
261 if let Some(max) = options.max_tokens {
262 let token_count = tokens.len();
263 if token_count > max {
264 return Err(format_guard_error(
265 "E_GUARD_TOKEN_BUDGET_EXCEEDED",
266 token_count,
267 max,
268 ));
269 }
270 }
271 enforce_set_op_chain_guard(&tokens, options)?;
272
273 let config = crate::parser::ParserConfig {
274 dialect: Some(dialect.dialect_type()),
275 ..Default::default()
276 };
277 let mut parser = Parser::with_source(tokens, config, sql.to_string());
278 parser.parse()
279}
280
281fn is_trivia_token(token_type: TokenType) -> bool {
282 matches!(
283 token_type,
284 TokenType::Space | TokenType::Break | TokenType::LineComment | TokenType::BlockComment
285 )
286}
287
288fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
289 tokens
290 .iter()
291 .skip(start)
292 .find(|token| !is_trivia_token(token.token_type))
293}
294
295fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
296 let token = &tokens[idx];
297 match token.token_type {
298 TokenType::Union | TokenType::Intersect => true,
299 TokenType::Except => {
300 if token.text.eq_ignore_ascii_case("minus")
303 && matches!(
304 next_significant_token(tokens, idx + 1).map(|t| t.token_type),
305 Some(TokenType::LParen)
306 )
307 {
308 return false;
309 }
310 true
311 }
312 _ => false,
313 }
314}
315
316fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
317 let Some(max) = options.max_set_op_chain else {
318 return Ok(());
319 };
320
321 let mut set_op_count = 0usize;
322 for (idx, token) in tokens.iter().enumerate() {
323 if token.token_type == TokenType::Semicolon {
324 set_op_count = 0;
325 continue;
326 }
327
328 if is_set_operation_token(tokens, idx) {
329 set_op_count += 1;
330 if set_op_count > max {
331 return Err(format_guard_error(
332 "E_GUARD_SET_OP_CHAIN_EXCEEDED",
333 set_op_count,
334 max,
335 ));
336 }
337 }
338 }
339
340 Ok(())
341}
342
343fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
344 if let Some(max) = options.max_ast_nodes {
345 let ast_nodes: usize = expressions.iter().map(node_count).sum();
346 if ast_nodes > max {
347 return Err(format_guard_error(
348 "E_GUARD_AST_BUDGET_EXCEEDED",
349 ast_nodes,
350 max,
351 ));
352 }
353 }
354 Ok(())
355}
356
357fn format_with_dialect(
358 sql: &str,
359 dialect: &Dialect,
360 options: &FormatGuardOptions,
361) -> Result<Vec<String>> {
362 enforce_input_guard(sql, options)?;
363 let expressions = parse_with_token_guard(sql, dialect, options)?;
364 enforce_ast_guard(&expressions, options)?;
365
366 expressions
367 .iter()
368 .map(|expr| dialect.generate_pretty(expr))
369 .collect()
370}
371
372pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
393 let read_dialect = Dialect::get(read);
394 let write_dialect = Dialect::get(write);
395 let generic_identity = read == DialectType::Generic && write == DialectType::Generic;
396
397 let expressions = read_dialect.parse(sql)?;
398
399 expressions
400 .into_iter()
401 .map(|expr| {
402 if generic_identity {
403 write_dialect.generate_with_source(&expr, read)
404 } else {
405 let transformed = write_dialect.transform(expr)?;
406 write_dialect.generate_with_source(&transformed, read)
407 }
408 })
409 .collect()
410}
411
412pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
421 let d = Dialect::get(dialect);
422 d.parse(sql)
423}
424
425pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
434 let mut expressions = parse(sql, dialect)?;
435
436 if expressions.len() != 1 {
437 return Err(Error::parse(
438 format!("Expected 1 statement, found {}", expressions.len()),
439 0,
440 0,
441 0,
442 0,
443 ));
444 }
445
446 Ok(expressions.remove(0))
447}
448
449pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
458 let d = Dialect::get(dialect);
459 d.generate(expression)
460}
461
462pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
466 format_with_options(sql, dialect, &FormatGuardOptions::default())
467}
468
469pub fn format_with_options(
471 sql: &str,
472 dialect: DialectType,
473 options: &FormatGuardOptions,
474) -> Result<Vec<String>> {
475 let d = Dialect::get(dialect);
476 format_with_dialect(sql, &d, options)
477}
478
479pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
488 validate_with_options(sql, dialect, &ValidationOptions::default())
489}
490
/// Options controlling how strictly SQL is validated.
///
/// Field names are serialized in camelCase.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ValidationOptions {
    /// When `true`, SQL that parses successfully is additionally checked for
    /// trailing commas before clause boundaries, which are reported as errors.
    /// Defaults to `false`, including when the field is absent from
    /// serialized input.
    #[serde(default)]
    pub strict_syntax: bool,
}
500
501pub fn validate_with_options(
503 sql: &str,
504 dialect: DialectType,
505 options: &ValidationOptions,
506) -> ValidationResult {
507 let d = Dialect::get(dialect);
508 match d.parse(sql) {
509 Ok(expressions) => {
510 for expr in &expressions {
514 if !expr.is_statement() {
515 let msg = format!("Invalid expression / Unexpected token");
516 return ValidationResult::with_errors(vec![ValidationError::error(
517 msg, "E004",
518 )]);
519 }
520 }
521 if options.strict_syntax {
522 if let Some(error) = strict_syntax_error(sql, &d) {
523 return ValidationResult::with_errors(vec![error]);
524 }
525 }
526 ValidationResult::success()
527 }
528 Err(e) => {
529 let error = match &e {
530 Error::Syntax {
531 message,
532 line,
533 column,
534 start,
535 end,
536 } => ValidationError::error(message.clone(), "E001")
537 .with_location(*line, *column)
538 .with_span(Some(*start), Some(*end)),
539 Error::Tokenize {
540 message,
541 line,
542 column,
543 start,
544 end,
545 } => ValidationError::error(message.clone(), "E002")
546 .with_location(*line, *column)
547 .with_span(Some(*start), Some(*end)),
548 Error::Parse {
549 message,
550 line,
551 column,
552 start,
553 end,
554 } => ValidationError::error(message.clone(), "E003")
555 .with_location(*line, *column)
556 .with_span(Some(*start), Some(*end)),
557 _ => ValidationError::error(e.to_string(), "E000"),
558 };
559 ValidationResult::with_errors(vec![error])
560 }
561 }
562}
563
564fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
565 let tokens = dialect.tokenize(sql).ok()?;
566
567 for (idx, token) in tokens.iter().enumerate() {
568 if token.token_type != TokenType::Comma {
569 continue;
570 }
571
572 let next = tokens.get(idx + 1);
573 let (is_boundary, boundary_name) = match next.map(|t| t.token_type) {
574 Some(TokenType::From) => (true, "FROM"),
575 Some(TokenType::Where) => (true, "WHERE"),
576 Some(TokenType::GroupBy) => (true, "GROUP BY"),
577 Some(TokenType::Having) => (true, "HAVING"),
578 Some(TokenType::Order) | Some(TokenType::OrderBy) => (true, "ORDER BY"),
579 Some(TokenType::Limit) => (true, "LIMIT"),
580 Some(TokenType::Offset) => (true, "OFFSET"),
581 Some(TokenType::Union) => (true, "UNION"),
582 Some(TokenType::Intersect) => (true, "INTERSECT"),
583 Some(TokenType::Except) => (true, "EXCEPT"),
584 Some(TokenType::Qualify) => (true, "QUALIFY"),
585 Some(TokenType::Window) => (true, "WINDOW"),
586 Some(TokenType::Semicolon) | None => (true, "end of statement"),
587 _ => (false, ""),
588 };
589
590 if is_boundary {
591 let message = format!(
592 "Trailing comma before {} is not allowed in strict syntax mode",
593 boundary_name
594 );
595 return Some(
596 ValidationError::error(message, "E005")
597 .with_location(token.span.line, token.span.column),
598 );
599 }
600 }
601
602 None
603}
604
605pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
618 let read_dialect = Dialect::get_by_name(read)
619 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0, 0, 0))?;
620 let write_dialect = Dialect::get_by_name(write)
621 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0, 0, 0))?;
622 let generic_identity = read_dialect.dialect_type() == DialectType::Generic
623 && write_dialect.dialect_type() == DialectType::Generic;
624
625 let expressions = read_dialect.parse(sql)?;
626
627 expressions
628 .into_iter()
629 .map(|expr| {
630 if generic_identity {
631 write_dialect.generate_with_source(&expr, read_dialect.dialect_type())
632 } else {
633 let transformed = write_dialect.transform(expr)?;
634 write_dialect.generate_with_source(&transformed, read_dialect.dialect_type())
635 }
636 })
637 .collect()
638}
639
640pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
644 let d = Dialect::get_by_name(dialect)
645 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
646 d.parse(sql)
647}
648
649pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
653 let d = Dialect::get_by_name(dialect)
654 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
655 d.generate(expression)
656}
657
658pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
662 format_with_options_by_name(sql, dialect, &FormatGuardOptions::default())
663}
664
665pub fn format_with_options_by_name(
667 sql: &str,
668 dialect: &str,
669 options: &FormatGuardOptions,
670) -> Result<Vec<String>> {
671 let d = Dialect::get_by_name(dialect)
672 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
673 format_with_dialect(sql, &d, options)
674}
675
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// Validates `sql` in strict syntax mode and asserts that it is rejected
    /// with an `E005` trailing-comma error.
    fn assert_strict_mode_reports_e005(sql: &str) {
        let strict = ValidationOptions {
            strict_syntax: true,
        };
        let result = validate_with_options(sql, DialectType::Generic, &strict);
        assert!(!result.valid, "Result should be invalid");
        assert!(
            result.errors.iter().any(|e| e.code == "E005"),
            "Expected E005, got: {:?}",
            result.errors
        );
    }

    // Without strict mode, a trailing comma before FROM is accepted.
    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let result = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(result.valid, "Result: {:?}", result.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        assert_strict_mode_reports_e005("SELECT name, FROM employees");
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        assert_strict_mode_reports_e005("SELECT name FROM employees, WHERE salary > 10");
    }
}
722
#[cfg(test)]
mod format_tests {
    use super::*;

    /// Guard options with every limit disabled; tests enable the one they exercise.
    const NO_LIMITS: FormatGuardOptions = FormatGuardOptions {
        max_input_bytes: None,
        max_tokens: None,
        max_ast_nodes: None,
        max_set_op_chain: None,
    };

    /// Formats `sql` with the Generic dialect under `options` and asserts that
    /// it fails with an error whose text contains `code`.
    fn assert_guard_rejects(sql: &str, options: &FormatGuardOptions, code: &str) {
        let err = format_with_options(sql, DialectType::Generic, options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains(code));
    }

    #[test]
    fn format_basic_query() {
        let result = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
        assert_eq!(result.len(), 1);
        assert!(result[0].contains('\n'));
    }

    #[test]
    fn format_guard_rejects_large_input() {
        let options = FormatGuardOptions {
            max_input_bytes: Some(7),
            ..NO_LIMITS
        };
        assert_guard_rejects("SELECT 1", &options, "E_GUARD_INPUT_TOO_LARGE");
    }

    #[test]
    fn format_guard_rejects_token_budget() {
        let options = FormatGuardOptions {
            max_tokens: Some(1),
            ..NO_LIMITS
        };
        assert_guard_rejects("SELECT 1", &options, "E_GUARD_TOKEN_BUDGET_EXCEEDED");
    }

    #[test]
    fn format_guard_rejects_ast_budget() {
        let options = FormatGuardOptions {
            max_ast_nodes: Some(1),
            ..NO_LIMITS
        };
        assert_guard_rejects("SELECT 1", &options, "E_GUARD_AST_BUDGET_EXCEEDED");
    }

    #[test]
    fn format_guard_rejects_set_op_chain_budget() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(1),
            ..NO_LIMITS
        };
        assert_guard_rejects(
            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
            &options,
            "E_GUARD_SET_OP_CHAIN_EXCEEDED",
        );
    }

    // `minus(...)` is a function call here, so even a zero-length chain limit must pass.
    #[test]
    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(0),
            ..NO_LIMITS
        };
        let result = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
        assert!(result.is_ok(), "Result: {:?}", result);
    }

    // The ternary operator is not valid SQL here; every entry point must report an error.
    #[test]
    fn issue57_invalid_ternary_returns_error() {
        let sql = "SELECT x > 0 ? 1 : 0 FROM t";
        let dialect = DialectType::PostgreSQL;

        let parse_result = parse(sql, dialect);
        assert!(
            parse_result.is_err(),
            "Expected parse error for invalid ternary SQL, got: {:?}",
            parse_result
        );

        let format_result = format(sql, dialect);
        assert!(
            format_result.is_err(),
            "Expected format error for invalid ternary SQL, got: {:?}",
            format_result
        );

        let transpile_result = transpile(sql, dialect, dialect);
        assert!(
            transpile_result.is_err(),
            "Expected transpile error for invalid ternary SQL, got: {:?}",
            transpile_result
        );
    }

    #[test]
    fn format_default_guard_rejects_deep_union_chain_before_parse() {
        // 1101 copies of the base query joined by 1100 UNION ALLs.
        let base = "SELECT col0, col1 FROM t";
        let sql = vec![base; 1101].join(" UNION ALL ");

        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }
}