1pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21pub mod function_catalog;
22mod function_registry;
23pub mod generator;
24pub mod helper;
25pub mod lineage;
26pub mod optimizer;
27pub mod parser;
28pub mod planner;
29pub mod resolver;
30pub mod schema;
31pub mod scope;
32pub mod time;
33pub mod tokens;
34pub mod transforms;
35pub mod traversal;
36pub mod trie;
37pub mod validation;
38
39use serde::{Deserialize, Serialize};
40
41pub use ast_transforms::{
42 add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
43 get_identifiers, get_literals, get_subqueries, get_table_names, get_window_functions,
44 node_count, qualify_columns, remove_limit_offset, remove_nodes, remove_select_columns,
45 remove_where, rename_columns, rename_tables, replace_by_type, replace_nodes, set_distinct,
46 set_limit, set_offset,
47};
48pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
49pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
50pub use expressions::Expression;
51pub use function_catalog::{
52 FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
53};
54pub use generator::Generator;
55pub use helper::{
56 csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
57 name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
58};
59pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
60pub use parser::Parser;
61pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
62pub use schema::{
63 ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
64};
65pub use scope::{
66 build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
67 ScopeType, SourceInfo,
68};
69pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
70pub use tokens::{Token, TokenType, Tokenizer};
71pub use traversal::{
72 contains_aggregate,
73 contains_subquery,
74 contains_window_function,
75 find_ancestor,
76 find_parent,
77 get_columns,
78 get_tables,
79 is_add,
80 is_aggregate,
81 is_alias,
82 is_alter_table,
83 is_and,
84 is_arithmetic,
85 is_avg,
86 is_between,
87 is_boolean,
88 is_case,
89 is_cast,
90 is_coalesce,
91 is_column,
92 is_comparison,
93 is_concat,
94 is_count,
95 is_create_index,
96 is_create_table,
97 is_create_view,
98 is_cte,
99 is_ddl,
100 is_delete,
101 is_div,
102 is_drop_index,
103 is_drop_table,
104 is_drop_view,
105 is_eq,
106 is_except,
107 is_exists,
108 is_from,
109 is_function,
110 is_group_by,
111 is_gt,
112 is_gte,
113 is_having,
114 is_identifier,
115 is_ilike,
116 is_in,
117 is_insert,
119 is_intersect,
120 is_is_null,
121 is_join,
122 is_like,
123 is_limit,
124 is_literal,
125 is_logical,
126 is_lt,
127 is_lte,
128 is_max_func,
129 is_min_func,
130 is_mod,
131 is_mul,
132 is_neq,
133 is_not,
134 is_null_if,
135 is_null_literal,
136 is_offset,
137 is_or,
138 is_order_by,
139 is_ordered,
140 is_paren,
141 is_query,
143 is_safe_cast,
144 is_select,
145 is_set_operation,
146 is_star,
147 is_sub,
148 is_subquery,
149 is_sum,
150 is_table,
151 is_try_cast,
152 is_union,
153 is_update,
154 is_where,
155 is_window_function,
156 is_with,
157 transform,
158 transform_map,
159 BfsIter,
160 DfsIter,
161 ExpressionWalk,
162 ParentInfo,
163 TreeContext,
164};
165pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
166pub use validation::{
167 validate_with_schema, SchemaColumn, SchemaColumnReference, SchemaForeignKey, SchemaTable,
168 SchemaTableReference, SchemaValidationOptions, ValidationSchema,
169};
170
/// Default upper bound on the SQL text length (in bytes) accepted for formatting: 16 MiB.
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 << 20;
/// Default upper bound on the number of tokens produced while formatting.
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
/// Default upper bound on the total AST node count across all parsed statements.
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
/// Default upper bound on set-operation keywords within one `;`-separated statement.
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;

// serde's `#[serde(default = "...")]` needs a function path, so each default
// limit gets a tiny wrapper that yields it as an enabled (`Some`) guard.

/// Serde/`Default` value for `FormatGuardOptions::max_input_bytes`.
fn default_format_max_input_bytes() -> Option<usize> {
    Option::Some(DEFAULT_FORMAT_MAX_INPUT_BYTES)
}

/// Serde/`Default` value for `FormatGuardOptions::max_tokens`.
fn default_format_max_tokens() -> Option<usize> {
    Option::Some(DEFAULT_FORMAT_MAX_TOKENS)
}

/// Serde/`Default` value for `FormatGuardOptions::max_ast_nodes`.
fn default_format_max_ast_nodes() -> Option<usize> {
    Option::Some(DEFAULT_FORMAT_MAX_AST_NODES)
}

/// Serde/`Default` value for `FormatGuardOptions::max_set_op_chain`.
fn default_format_max_set_op_chain() -> Option<usize> {
    Option::Some(DEFAULT_FORMAT_MAX_SET_OP_CHAIN)
}
191
/// Resource limits enforced by `format` / `format_with_options` (and their
/// `_by_name` variants) so that oversized or pathological SQL is rejected with
/// an `E_GUARD_*` error.
///
/// Every limit is optional: `None` disables that particular check. When
/// deserialized, field names are camelCase (e.g. `maxInputBytes`), and an
/// omitted field falls back to the same default limit used by `Default`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FormatGuardOptions {
    /// Maximum SQL input length in bytes; checked before tokenizing.
    #[serde(default = "default_format_max_input_bytes")]
    pub max_input_bytes: Option<usize>,
    /// Maximum number of tokens; checked after tokenizing, before parsing.
    #[serde(default = "default_format_max_tokens")]
    pub max_tokens: Option<usize>,
    /// Maximum total AST node count across all statements; checked after parsing.
    #[serde(default = "default_format_max_ast_nodes")]
    pub max_ast_nodes: Option<usize>,
    /// Maximum number of set-operation keywords (`UNION`, `INTERSECT`,
    /// `EXCEPT`/`MINUS`) within a single `;`-separated statement. It is checked
    /// on the token stream, so long chains are rejected before the parser runs.
    #[serde(default = "default_format_max_set_op_chain")]
    pub max_set_op_chain: Option<usize>,
}
218
219impl Default for FormatGuardOptions {
220 fn default() -> Self {
221 Self {
222 max_input_bytes: default_format_max_input_bytes(),
223 max_tokens: default_format_max_tokens(),
224 max_ast_nodes: default_format_max_ast_nodes(),
225 max_set_op_chain: default_format_max_set_op_chain(),
226 }
227 }
228}
229
230fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
231 Error::generate(format!(
232 "{code}: value {actual} exceeds configured limit {limit}"
233 ))
234}
235
236fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
237 if let Some(max) = options.max_input_bytes {
238 let input_bytes = sql.len();
239 if input_bytes > max {
240 return Err(format_guard_error(
241 "E_GUARD_INPUT_TOO_LARGE",
242 input_bytes,
243 max,
244 ));
245 }
246 }
247 Ok(())
248}
249
250fn parse_with_token_guard(
251 sql: &str,
252 dialect: &Dialect,
253 options: &FormatGuardOptions,
254) -> Result<Vec<Expression>> {
255 let tokens = dialect.tokenize(sql)?;
256 if let Some(max) = options.max_tokens {
257 let token_count = tokens.len();
258 if token_count > max {
259 return Err(format_guard_error(
260 "E_GUARD_TOKEN_BUDGET_EXCEEDED",
261 token_count,
262 max,
263 ));
264 }
265 }
266 enforce_set_op_chain_guard(&tokens, options)?;
267
268 let config = crate::parser::ParserConfig {
269 dialect: Some(dialect.dialect_type()),
270 ..Default::default()
271 };
272 let mut parser = Parser::with_source(tokens, config, sql.to_string());
273 parser.parse()
274}
275
276fn is_trivia_token(token_type: TokenType) -> bool {
277 matches!(
278 token_type,
279 TokenType::Space | TokenType::Break | TokenType::LineComment | TokenType::BlockComment
280 )
281}
282
283fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
284 tokens
285 .iter()
286 .skip(start)
287 .find(|token| !is_trivia_token(token.token_type))
288}
289
290fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
291 let token = &tokens[idx];
292 match token.token_type {
293 TokenType::Union | TokenType::Intersect => true,
294 TokenType::Except => {
295 if token.text.eq_ignore_ascii_case("minus")
298 && matches!(
299 next_significant_token(tokens, idx + 1).map(|t| t.token_type),
300 Some(TokenType::LParen)
301 )
302 {
303 return false;
304 }
305 true
306 }
307 _ => false,
308 }
309}
310
311fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
312 let Some(max) = options.max_set_op_chain else {
313 return Ok(());
314 };
315
316 let mut set_op_count = 0usize;
317 for (idx, token) in tokens.iter().enumerate() {
318 if token.token_type == TokenType::Semicolon {
319 set_op_count = 0;
320 continue;
321 }
322
323 if is_set_operation_token(tokens, idx) {
324 set_op_count += 1;
325 if set_op_count > max {
326 return Err(format_guard_error(
327 "E_GUARD_SET_OP_CHAIN_EXCEEDED",
328 set_op_count,
329 max,
330 ));
331 }
332 }
333 }
334
335 Ok(())
336}
337
338fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
339 if let Some(max) = options.max_ast_nodes {
340 let ast_nodes: usize = expressions.iter().map(node_count).sum();
341 if ast_nodes > max {
342 return Err(format_guard_error(
343 "E_GUARD_AST_BUDGET_EXCEEDED",
344 ast_nodes,
345 max,
346 ));
347 }
348 }
349 Ok(())
350}
351
352fn format_with_dialect(
353 sql: &str,
354 dialect: &Dialect,
355 options: &FormatGuardOptions,
356) -> Result<Vec<String>> {
357 enforce_input_guard(sql, options)?;
358 let expressions = parse_with_token_guard(sql, dialect, options)?;
359 enforce_ast_guard(&expressions, options)?;
360
361 expressions
362 .iter()
363 .map(|expr| dialect.generate_pretty(expr))
364 .collect()
365}
366
367pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
388 let read_dialect = Dialect::get(read);
389 let write_dialect = Dialect::get(write);
390 let generic_identity = read == DialectType::Generic && write == DialectType::Generic;
391
392 let expressions = read_dialect.parse(sql)?;
393
394 expressions
395 .into_iter()
396 .map(|expr| {
397 if generic_identity {
398 write_dialect.generate_with_source(&expr, read)
399 } else {
400 let transformed = write_dialect.transform(expr)?;
401 write_dialect.generate_with_source(&transformed, read)
402 }
403 })
404 .collect()
405}
406
407pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
416 let d = Dialect::get(dialect);
417 d.parse(sql)
418}
419
420pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
429 let mut expressions = parse(sql, dialect)?;
430
431 if expressions.len() != 1 {
432 return Err(Error::parse(
433 format!("Expected 1 statement, found {}", expressions.len()),
434 0,
435 0,
436 0,
437 0,
438 ));
439 }
440
441 Ok(expressions.remove(0))
442}
443
444pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
453 let d = Dialect::get(dialect);
454 d.generate(expression)
455}
456
457pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
461 format_with_options(sql, dialect, &FormatGuardOptions::default())
462}
463
464pub fn format_with_options(
466 sql: &str,
467 dialect: DialectType,
468 options: &FormatGuardOptions,
469) -> Result<Vec<String>> {
470 let d = Dialect::get(dialect);
471 format_with_dialect(sql, &d, options)
472}
473
474pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
483 validate_with_options(sql, dialect, &ValidationOptions::default())
484}
485
/// Options controlling `validate_with_options`.
///
/// Serialized with camelCase field names (e.g. `strictSyntax`). Omitted fields
/// take their `Default` value.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ValidationOptions {
    /// When `true`, input that parses successfully is additionally rejected
    /// (code `E005`) if it has a trailing comma before a clause boundary such
    /// as `FROM`, `WHERE`, or the end of the statement. Defaults to `false`.
    #[serde(default)]
    pub strict_syntax: bool,
}
495
496pub fn validate_with_options(
498 sql: &str,
499 dialect: DialectType,
500 options: &ValidationOptions,
501) -> ValidationResult {
502 let d = Dialect::get(dialect);
503 match d.parse(sql) {
504 Ok(expressions) => {
505 for expr in &expressions {
509 if !expr.is_statement() {
510 let msg = format!("Invalid expression / Unexpected token");
511 return ValidationResult::with_errors(vec![ValidationError::error(
512 msg, "E004",
513 )]);
514 }
515 }
516 if options.strict_syntax {
517 if let Some(error) = strict_syntax_error(sql, &d) {
518 return ValidationResult::with_errors(vec![error]);
519 }
520 }
521 ValidationResult::success()
522 }
523 Err(e) => {
524 let error = match &e {
525 Error::Syntax {
526 message,
527 line,
528 column,
529 start,
530 end,
531 } => ValidationError::error(message.clone(), "E001")
532 .with_location(*line, *column)
533 .with_span(Some(*start), Some(*end)),
534 Error::Tokenize {
535 message,
536 line,
537 column,
538 start,
539 end,
540 } => ValidationError::error(message.clone(), "E002")
541 .with_location(*line, *column)
542 .with_span(Some(*start), Some(*end)),
543 Error::Parse {
544 message,
545 line,
546 column,
547 start,
548 end,
549 } => ValidationError::error(message.clone(), "E003")
550 .with_location(*line, *column)
551 .with_span(Some(*start), Some(*end)),
552 _ => ValidationError::error(e.to_string(), "E000"),
553 };
554 ValidationResult::with_errors(vec![error])
555 }
556 }
557}
558
559fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
560 let tokens = dialect.tokenize(sql).ok()?;
561
562 for (idx, token) in tokens.iter().enumerate() {
563 if token.token_type != TokenType::Comma {
564 continue;
565 }
566
567 let next = tokens.get(idx + 1);
568 let (is_boundary, boundary_name) = match next.map(|t| t.token_type) {
569 Some(TokenType::From) => (true, "FROM"),
570 Some(TokenType::Where) => (true, "WHERE"),
571 Some(TokenType::GroupBy) => (true, "GROUP BY"),
572 Some(TokenType::Having) => (true, "HAVING"),
573 Some(TokenType::Order) | Some(TokenType::OrderBy) => (true, "ORDER BY"),
574 Some(TokenType::Limit) => (true, "LIMIT"),
575 Some(TokenType::Offset) => (true, "OFFSET"),
576 Some(TokenType::Union) => (true, "UNION"),
577 Some(TokenType::Intersect) => (true, "INTERSECT"),
578 Some(TokenType::Except) => (true, "EXCEPT"),
579 Some(TokenType::Qualify) => (true, "QUALIFY"),
580 Some(TokenType::Window) => (true, "WINDOW"),
581 Some(TokenType::Semicolon) | None => (true, "end of statement"),
582 _ => (false, ""),
583 };
584
585 if is_boundary {
586 let message = format!(
587 "Trailing comma before {} is not allowed in strict syntax mode",
588 boundary_name
589 );
590 return Some(
591 ValidationError::error(message, "E005")
592 .with_location(token.span.line, token.span.column),
593 );
594 }
595 }
596
597 None
598}
599
600pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
613 let read_dialect = Dialect::get_by_name(read)
614 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0, 0, 0))?;
615 let write_dialect = Dialect::get_by_name(write)
616 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0, 0, 0))?;
617 let generic_identity = read_dialect.dialect_type() == DialectType::Generic
618 && write_dialect.dialect_type() == DialectType::Generic;
619
620 let expressions = read_dialect.parse(sql)?;
621
622 expressions
623 .into_iter()
624 .map(|expr| {
625 if generic_identity {
626 write_dialect.generate_with_source(&expr, read_dialect.dialect_type())
627 } else {
628 let transformed = write_dialect.transform(expr)?;
629 write_dialect.generate_with_source(&transformed, read_dialect.dialect_type())
630 }
631 })
632 .collect()
633}
634
635pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
639 let d = Dialect::get_by_name(dialect)
640 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
641 d.parse(sql)
642}
643
644pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
648 let d = Dialect::get_by_name(dialect)
649 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
650 d.generate(expression)
651}
652
653pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
657 format_with_options_by_name(sql, dialect, &FormatGuardOptions::default())
658}
659
660pub fn format_with_options_by_name(
662 sql: &str,
663 dialect: &str,
664 options: &FormatGuardOptions,
665) -> Result<Vec<String>> {
666 let d = Dialect::get_by_name(dialect)
667 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
668 format_with_dialect(sql, &d, options)
669}
670
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// Validation options with the opt-in strict-syntax checks enabled.
    fn strict_options() -> ValidationOptions {
        ValidationOptions {
            strict_syntax: true,
        }
    }

    /// Asserts that `result` failed validation and reports an error with `code`.
    fn assert_rejected_with(result: &ValidationResult, code: &str) {
        assert!(!result.valid, "Result should be invalid");
        assert!(
            result.errors.iter().any(|e| e.code == code),
            "Expected {}, got: {:?}",
            code,
            result.errors
        );
    }

    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let outcome = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(outcome.valid, "Result: {:?}", outcome.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        let outcome = validate_with_options(
            "SELECT name, FROM employees",
            DialectType::Generic,
            &strict_options(),
        );
        assert_rejected_with(&outcome, "E005");
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        let outcome = validate_with_options(
            "SELECT name FROM employees, WHERE salary > 10",
            DialectType::Generic,
            &strict_options(),
        );
        assert_rejected_with(&outcome, "E005");
    }
}
717
#[cfg(test)]
mod format_tests {
    use super::*;

    /// Guard options with every limit disabled; each test enables only the
    /// guard it exercises.
    fn no_guards() -> FormatGuardOptions {
        FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: None,
        }
    }

    /// Formats `sql`, expects a guard failure, and returns its message.
    fn guard_error_message(
        sql: &str,
        dialect: DialectType,
        options: &FormatGuardOptions,
    ) -> String {
        format_with_options(sql, dialect, options)
            .expect_err("expected guard error")
            .to_string()
    }

    #[test]
    fn format_basic_query() {
        let formatted = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
        assert_eq!(formatted.len(), 1);
        assert!(formatted[0].contains('\n'));
    }

    #[test]
    fn format_guard_rejects_large_input() {
        let options = FormatGuardOptions {
            max_input_bytes: Some(7),
            ..no_guards()
        };
        let message = guard_error_message("SELECT 1", DialectType::Generic, &options);
        assert!(message.contains("E_GUARD_INPUT_TOO_LARGE"));
    }

    #[test]
    fn format_guard_rejects_token_budget() {
        let options = FormatGuardOptions {
            max_tokens: Some(1),
            ..no_guards()
        };
        let message = guard_error_message("SELECT 1", DialectType::Generic, &options);
        assert!(message.contains("E_GUARD_TOKEN_BUDGET_EXCEEDED"));
    }

    #[test]
    fn format_guard_rejects_ast_budget() {
        let options = FormatGuardOptions {
            max_ast_nodes: Some(1),
            ..no_guards()
        };
        let message = guard_error_message("SELECT 1", DialectType::Generic, &options);
        assert!(message.contains("E_GUARD_AST_BUDGET_EXCEEDED"));
    }

    #[test]
    fn format_guard_rejects_set_op_chain_budget() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(1),
            ..no_guards()
        };
        let message = guard_error_message(
            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
            DialectType::Generic,
            &options,
        );
        assert!(message.contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }

    #[test]
    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
        // With a zero budget, any token counted as a set op would fail the call.
        let options = FormatGuardOptions {
            max_set_op_chain: Some(0),
            ..no_guards()
        };
        let outcome = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
        assert!(outcome.is_ok(), "Result: {:?}", outcome);
    }

    #[test]
    fn format_default_guard_rejects_deep_union_chain_before_parse() {
        // 1101 copies joined by UNION ALL => 1100 set operations, far above the default limit.
        let sql = ["SELECT col0, col1 FROM t"; 1101].join(" UNION ALL ");

        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }
}