1pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21pub mod function_catalog;
22mod function_registry;
23pub mod generator;
24pub mod helper;
25pub mod lineage;
26pub mod optimizer;
27pub mod parser;
28pub mod planner;
29pub mod resolver;
30pub mod schema;
31pub mod scope;
32pub mod time;
33pub mod tokens;
34pub mod transforms;
35pub mod traversal;
36pub mod trie;
37pub mod validation;
38
39use serde::{Deserialize, Serialize};
40
41pub use ast_transforms::{
42 add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
43 get_identifiers, get_literals, get_subqueries, get_table_names, get_window_functions,
44 node_count, qualify_columns, remove_limit_offset, remove_nodes, remove_select_columns,
45 remove_where, rename_columns, rename_tables, replace_by_type, replace_nodes, set_distinct,
46 set_limit, set_offset,
47};
48pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
49pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
50pub use expressions::Expression;
51pub use function_catalog::{
52 FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
53};
54pub use generator::Generator;
55pub use helper::{
56 csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
57 name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
58};
59pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
60pub use parser::Parser;
61pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
62pub use schema::{
63 ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
64};
65pub use scope::{
66 build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
67 ScopeType, SourceInfo,
68};
69pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
70pub use tokens::{Token, TokenType, Tokenizer};
71pub use traversal::{
72 contains_aggregate,
73 contains_subquery,
74 contains_window_function,
75 find_ancestor,
76 find_parent,
77 get_columns,
78 get_tables,
79 is_add,
80 is_aggregate,
81 is_alias,
82 is_alter_table,
83 is_and,
84 is_arithmetic,
85 is_avg,
86 is_between,
87 is_boolean,
88 is_case,
89 is_cast,
90 is_coalesce,
91 is_column,
92 is_comparison,
93 is_concat,
94 is_count,
95 is_create_index,
96 is_create_table,
97 is_create_view,
98 is_cte,
99 is_ddl,
100 is_delete,
101 is_div,
102 is_drop_index,
103 is_drop_table,
104 is_drop_view,
105 is_eq,
106 is_except,
107 is_exists,
108 is_from,
109 is_function,
110 is_group_by,
111 is_gt,
112 is_gte,
113 is_having,
114 is_identifier,
115 is_ilike,
116 is_in,
117 is_insert,
119 is_intersect,
120 is_is_null,
121 is_join,
122 is_like,
123 is_limit,
124 is_literal,
125 is_logical,
126 is_lt,
127 is_lte,
128 is_max_func,
129 is_min_func,
130 is_mod,
131 is_mul,
132 is_neq,
133 is_not,
134 is_null_if,
135 is_null_literal,
136 is_offset,
137 is_or,
138 is_order_by,
139 is_ordered,
140 is_paren,
141 is_query,
143 is_safe_cast,
144 is_select,
145 is_set_operation,
146 is_star,
147 is_sub,
148 is_subquery,
149 is_sum,
150 is_table,
151 is_try_cast,
152 is_union,
153 is_update,
154 is_where,
155 is_window_function,
156 is_with,
157 transform,
158 transform_map,
159 BfsIter,
160 DfsIter,
161 ExpressionWalk,
162 ParentInfo,
163 TreeContext,
164};
165pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
166pub use validation::{
167 validate_with_schema, SchemaColumn, SchemaColumnReference, SchemaForeignKey, SchemaTable,
168 SchemaTableReference, SchemaValidationOptions, ValidationSchema,
169};
170
/// Default cap on the size of the SQL text handed to the formatter: 16 MiB.
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 << 20;
/// Default cap on the number of tokens produced when tokenizing the input.
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
/// Default cap on the total number of AST nodes across all parsed statements.
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
/// Default cap on the number of set-operation tokens seen between two semicolons.
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;
175
/// Default value for `FormatGuardOptions::max_input_bytes`.
///
/// Referenced by name from the field's `#[serde(default = "...")]` attribute and
/// from `FormatGuardOptions::default`, so both paths yield the same limit.
fn default_format_max_input_bytes() -> Option<usize> {
    Some(DEFAULT_FORMAT_MAX_INPUT_BYTES)
}
179
/// Default value for `FormatGuardOptions::max_tokens`.
///
/// Referenced by name from the field's `#[serde(default = "...")]` attribute and
/// from `FormatGuardOptions::default`, so both paths yield the same limit.
fn default_format_max_tokens() -> Option<usize> {
    Some(DEFAULT_FORMAT_MAX_TOKENS)
}
183
/// Default value for `FormatGuardOptions::max_ast_nodes`.
///
/// Referenced by name from the field's `#[serde(default = "...")]` attribute and
/// from `FormatGuardOptions::default`, so both paths yield the same limit.
fn default_format_max_ast_nodes() -> Option<usize> {
    Some(DEFAULT_FORMAT_MAX_AST_NODES)
}
187
/// Default value for `FormatGuardOptions::max_set_op_chain`.
///
/// Referenced by name from the field's `#[serde(default = "...")]` attribute and
/// from `FormatGuardOptions::default`, so both paths yield the same limit.
fn default_format_max_set_op_chain() -> Option<usize> {
    Some(DEFAULT_FORMAT_MAX_SET_OP_CHAIN)
}
191
/// Resource limits enforced by the formatting entry points.
///
/// Every limit is optional: `None` disables the corresponding check. Field names
/// are (de)serialized in camelCase, and a field that is missing from the
/// deserialized input falls back to the same default that [`Default`] uses.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FormatGuardOptions {
    /// Maximum length of the SQL input in bytes (compared against `str::len`).
    /// Exceeding it fails with `E_GUARD_INPUT_TOO_LARGE`. Default: 16 MiB.
    #[serde(default = "default_format_max_input_bytes")]
    pub max_input_bytes: Option<usize>,
    /// Maximum number of tokens the tokenizer may produce for the input.
    /// Exceeding it fails with `E_GUARD_TOKEN_BUDGET_EXCEEDED`. Default: 1,000,000.
    #[serde(default = "default_format_max_tokens")]
    pub max_tokens: Option<usize>,
    /// Maximum total node count summed over all parsed statements.
    /// Exceeding it fails with `E_GUARD_AST_BUDGET_EXCEEDED`. Default: 1,000,000.
    #[serde(default = "default_format_max_ast_nodes")]
    pub max_ast_nodes: Option<usize>,
    /// Maximum number of set-operation tokens (`UNION` / `INTERSECT` / `EXCEPT`)
    /// counted between semicolons; checked on the token stream before parsing.
    /// Exceeding it fails with `E_GUARD_SET_OP_CHAIN_EXCEEDED`. Default: 256.
    #[serde(default = "default_format_max_set_op_chain")]
    pub max_set_op_chain: Option<usize>,
}
218
219impl Default for FormatGuardOptions {
220 fn default() -> Self {
221 Self {
222 max_input_bytes: default_format_max_input_bytes(),
223 max_tokens: default_format_max_tokens(),
224 max_ast_nodes: default_format_max_ast_nodes(),
225 max_set_op_chain: default_format_max_set_op_chain(),
226 }
227 }
228}
229
230fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
231 Error::generate(format!(
232 "{code}: value {actual} exceeds configured limit {limit}"
233 ))
234}
235
236fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
237 if let Some(max) = options.max_input_bytes {
238 let input_bytes = sql.len();
239 if input_bytes > max {
240 return Err(format_guard_error(
241 "E_GUARD_INPUT_TOO_LARGE",
242 input_bytes,
243 max,
244 ));
245 }
246 }
247 Ok(())
248}
249
250fn parse_with_token_guard(
251 sql: &str,
252 dialect: &Dialect,
253 options: &FormatGuardOptions,
254) -> Result<Vec<Expression>> {
255 let tokens = dialect.tokenize(sql)?;
256 if let Some(max) = options.max_tokens {
257 let token_count = tokens.len();
258 if token_count > max {
259 return Err(format_guard_error(
260 "E_GUARD_TOKEN_BUDGET_EXCEEDED",
261 token_count,
262 max,
263 ));
264 }
265 }
266 enforce_set_op_chain_guard(&tokens, options)?;
267
268 let config = crate::parser::ParserConfig {
269 dialect: Some(dialect.dialect_type()),
270 ..Default::default()
271 };
272 let mut parser = Parser::with_source(tokens, config, sql.to_string());
273 parser.parse()
274}
275
276fn is_trivia_token(token_type: TokenType) -> bool {
277 matches!(
278 token_type,
279 TokenType::Space | TokenType::Break | TokenType::LineComment | TokenType::BlockComment
280 )
281}
282
283fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
284 tokens
285 .iter()
286 .skip(start)
287 .find(|token| !is_trivia_token(token.token_type))
288}
289
290fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
291 let token = &tokens[idx];
292 match token.token_type {
293 TokenType::Union | TokenType::Intersect => true,
294 TokenType::Except => {
295 if token.text.eq_ignore_ascii_case("minus")
298 && matches!(
299 next_significant_token(tokens, idx + 1).map(|t| t.token_type),
300 Some(TokenType::LParen)
301 )
302 {
303 return false;
304 }
305 true
306 }
307 _ => false,
308 }
309}
310
311fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
312 let Some(max) = options.max_set_op_chain else {
313 return Ok(());
314 };
315
316 let mut set_op_count = 0usize;
317 for (idx, token) in tokens.iter().enumerate() {
318 if token.token_type == TokenType::Semicolon {
319 set_op_count = 0;
320 continue;
321 }
322
323 if is_set_operation_token(tokens, idx) {
324 set_op_count += 1;
325 if set_op_count > max {
326 return Err(format_guard_error(
327 "E_GUARD_SET_OP_CHAIN_EXCEEDED",
328 set_op_count,
329 max,
330 ));
331 }
332 }
333 }
334
335 Ok(())
336}
337
338fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
339 if let Some(max) = options.max_ast_nodes {
340 let ast_nodes: usize = expressions.iter().map(node_count).sum();
341 if ast_nodes > max {
342 return Err(format_guard_error(
343 "E_GUARD_AST_BUDGET_EXCEEDED",
344 ast_nodes,
345 max,
346 ));
347 }
348 }
349 Ok(())
350}
351
352fn format_with_dialect(
353 sql: &str,
354 dialect: &Dialect,
355 options: &FormatGuardOptions,
356) -> Result<Vec<String>> {
357 enforce_input_guard(sql, options)?;
358 let expressions = parse_with_token_guard(sql, dialect, options)?;
359 enforce_ast_guard(&expressions, options)?;
360
361 expressions
362 .iter()
363 .map(|expr| dialect.generate_pretty(expr))
364 .collect()
365}
366
367pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
388 let read_dialect = Dialect::get(read);
389 let write_dialect = Dialect::get(write);
390 let generic_identity = read == DialectType::Generic && write == DialectType::Generic;
391
392 let expressions = read_dialect.parse(sql)?;
393
394 expressions
395 .into_iter()
396 .map(|expr| {
397 if generic_identity {
398 write_dialect.generate_with_source(&expr, read)
399 } else {
400 let transformed = write_dialect.transform(expr)?;
401 write_dialect.generate_with_source(&transformed, read)
402 }
403 })
404 .collect()
405}
406
407pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
416 let d = Dialect::get(dialect);
417 d.parse(sql)
418}
419
420pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
429 let mut expressions = parse(sql, dialect)?;
430
431 if expressions.len() != 1 {
432 return Err(Error::parse(
433 format!("Expected 1 statement, found {}", expressions.len()),
434 0,
435 0,
436 ));
437 }
438
439 Ok(expressions.remove(0))
440}
441
442pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
451 let d = Dialect::get(dialect);
452 d.generate(expression)
453}
454
455pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
459 format_with_options(sql, dialect, &FormatGuardOptions::default())
460}
461
462pub fn format_with_options(
464 sql: &str,
465 dialect: DialectType,
466 options: &FormatGuardOptions,
467) -> Result<Vec<String>> {
468 let d = Dialect::get(dialect);
469 format_with_dialect(sql, &d, options)
470}
471
472pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
481 validate_with_options(sql, dialect, &ValidationOptions::default())
482}
483
/// Options controlling how strictly [`validate_with_options`] checks its input.
///
/// Field names are (de)serialized in camelCase (`strictSyntax`).
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ValidationOptions {
    /// When `true`, SQL that parses successfully is additionally checked for
    /// trailing commas before clause boundaries (reported as `E005`).
    /// Defaults to `false`, also when the field is absent on deserialization.
    #[serde(default)]
    pub strict_syntax: bool,
}
493
494pub fn validate_with_options(
496 sql: &str,
497 dialect: DialectType,
498 options: &ValidationOptions,
499) -> ValidationResult {
500 let d = Dialect::get(dialect);
501 match d.parse(sql) {
502 Ok(expressions) => {
503 for expr in &expressions {
507 if !expr.is_statement() {
508 let msg = format!("Invalid expression / Unexpected token");
509 return ValidationResult::with_errors(vec![ValidationError::error(
510 msg, "E004",
511 )]);
512 }
513 }
514 if options.strict_syntax {
515 if let Some(error) = strict_syntax_error(sql, &d) {
516 return ValidationResult::with_errors(vec![error]);
517 }
518 }
519 ValidationResult::success()
520 }
521 Err(e) => {
522 let error = match &e {
523 Error::Syntax {
524 message,
525 line,
526 column,
527 } => ValidationError::error(message.clone(), "E001").with_location(*line, *column),
528 Error::Tokenize {
529 message,
530 line,
531 column,
532 } => ValidationError::error(message.clone(), "E002").with_location(*line, *column),
533 Error::Parse {
534 message,
535 line,
536 column,
537 } => ValidationError::error(message.clone(), "E003").with_location(*line, *column),
538 _ => ValidationError::error(e.to_string(), "E000"),
539 };
540 ValidationResult::with_errors(vec![error])
541 }
542 }
543}
544
545fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
546 let tokens = dialect.tokenize(sql).ok()?;
547
548 for (idx, token) in tokens.iter().enumerate() {
549 if token.token_type != TokenType::Comma {
550 continue;
551 }
552
553 let next = tokens.get(idx + 1);
554 let (is_boundary, boundary_name) = match next.map(|t| t.token_type) {
555 Some(TokenType::From) => (true, "FROM"),
556 Some(TokenType::Where) => (true, "WHERE"),
557 Some(TokenType::GroupBy) => (true, "GROUP BY"),
558 Some(TokenType::Having) => (true, "HAVING"),
559 Some(TokenType::Order) | Some(TokenType::OrderBy) => (true, "ORDER BY"),
560 Some(TokenType::Limit) => (true, "LIMIT"),
561 Some(TokenType::Offset) => (true, "OFFSET"),
562 Some(TokenType::Union) => (true, "UNION"),
563 Some(TokenType::Intersect) => (true, "INTERSECT"),
564 Some(TokenType::Except) => (true, "EXCEPT"),
565 Some(TokenType::Qualify) => (true, "QUALIFY"),
566 Some(TokenType::Window) => (true, "WINDOW"),
567 Some(TokenType::Semicolon) | None => (true, "end of statement"),
568 _ => (false, ""),
569 };
570
571 if is_boundary {
572 let message = format!(
573 "Trailing comma before {} is not allowed in strict syntax mode",
574 boundary_name
575 );
576 return Some(
577 ValidationError::error(message, "E005")
578 .with_location(token.span.line, token.span.column),
579 );
580 }
581 }
582
583 None
584}
585
586pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
599 let read_dialect = Dialect::get_by_name(read)
600 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0))?;
601 let write_dialect = Dialect::get_by_name(write)
602 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0))?;
603 let generic_identity = read_dialect.dialect_type() == DialectType::Generic
604 && write_dialect.dialect_type() == DialectType::Generic;
605
606 let expressions = read_dialect.parse(sql)?;
607
608 expressions
609 .into_iter()
610 .map(|expr| {
611 if generic_identity {
612 write_dialect.generate_with_source(&expr, read_dialect.dialect_type())
613 } else {
614 let transformed = write_dialect.transform(expr)?;
615 write_dialect.generate_with_source(&transformed, read_dialect.dialect_type())
616 }
617 })
618 .collect()
619}
620
621pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
625 let d = Dialect::get_by_name(dialect)
626 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0))?;
627 d.parse(sql)
628}
629
630pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
634 let d = Dialect::get_by_name(dialect)
635 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0))?;
636 d.generate(expression)
637}
638
639pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
643 format_with_options_by_name(sql, dialect, &FormatGuardOptions::default())
644}
645
646pub fn format_with_options_by_name(
648 sql: &str,
649 dialect: &str,
650 options: &FormatGuardOptions,
651) -> Result<Vec<String>> {
652 let d = Dialect::get_by_name(dialect)
653 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0))?;
654 format_with_dialect(sql, &d, options)
655}
656
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// Validates `sql` in strict syntax mode (Generic dialect) and asserts that it
    /// is rejected with an `E005` error.
    #[track_caller]
    fn assert_strict_mode_reports_e005(sql: &str) {
        let strict = ValidationOptions {
            strict_syntax: true,
        };
        let outcome = validate_with_options(sql, DialectType::Generic, &strict);

        assert!(!outcome.valid, "Result should be invalid");
        assert!(
            outcome.errors.iter().any(|e| e.code == "E005"),
            "Expected E005, got: {:?}",
            outcome.errors
        );
    }

    // Without strict mode, a trailing comma that still parses is accepted.
    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let outcome = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(outcome.valid, "Result: {:?}", outcome.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        assert_strict_mode_reports_e005("SELECT name, FROM employees");
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        assert_strict_mode_reports_e005("SELECT name FROM employees, WHERE salary > 10");
    }
}
703
#[cfg(test)]
mod format_tests {
    use super::*;

    /// Guard options with every limit switched off; each test enables only the
    /// limit it exercises.
    fn no_limits() -> FormatGuardOptions {
        FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: None,
        }
    }

    /// Formats `sql` with the Generic dialect and asserts that it fails with an
    /// error whose message contains `code`.
    #[track_caller]
    fn assert_guard_trips(sql: &str, options: &FormatGuardOptions, code: &str) {
        let err = format_with_options(sql, DialectType::Generic, options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains(code));
    }

    #[test]
    fn format_basic_query() {
        let formatted = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
        assert_eq!(formatted.len(), 1);
        assert!(formatted[0].contains('\n'));
    }

    #[test]
    fn format_guard_rejects_large_input() {
        // "SELECT 1" is 8 bytes, one more than the limit.
        let options = FormatGuardOptions {
            max_input_bytes: Some(7),
            ..no_limits()
        };
        assert_guard_trips("SELECT 1", &options, "E_GUARD_INPUT_TOO_LARGE");
    }

    #[test]
    fn format_guard_rejects_token_budget() {
        let options = FormatGuardOptions {
            max_tokens: Some(1),
            ..no_limits()
        };
        assert_guard_trips("SELECT 1", &options, "E_GUARD_TOKEN_BUDGET_EXCEEDED");
    }

    #[test]
    fn format_guard_rejects_ast_budget() {
        let options = FormatGuardOptions {
            max_ast_nodes: Some(1),
            ..no_limits()
        };
        assert_guard_trips("SELECT 1", &options, "E_GUARD_AST_BUDGET_EXCEEDED");
    }

    #[test]
    fn format_guard_rejects_set_op_chain_budget() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(1),
            ..no_limits()
        };
        assert_guard_trips(
            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
            &options,
            "E_GUARD_SET_OP_CHAIN_EXCEEDED",
        );
    }

    #[test]
    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
        // With a limit of zero, any counted set operation would fail the call.
        let options = FormatGuardOptions {
            max_set_op_chain: Some(0),
            ..no_limits()
        };
        let result = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
        assert!(result.is_ok(), "Result: {:?}", result);
    }

    #[test]
    fn format_default_guard_rejects_deep_union_chain_before_parse() {
        // One leading SELECT plus 1100 more, joined by UNION ALL.
        let base = "SELECT col0, col1 FROM t";
        let sql = vec![base; 1101].join(" UNION ALL ");

        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }
}
795}