1pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21pub mod function_catalog;
22mod function_registry;
23pub mod generator;
24pub mod helper;
25pub mod lineage;
26pub mod optimizer;
27pub mod parser;
28pub mod planner;
29pub mod resolver;
30pub mod schema;
31pub mod scope;
32pub mod time;
33pub mod tokens;
34pub mod transforms;
35pub mod traversal;
36pub mod trie;
37pub mod validation;
38
39use serde::{Deserialize, Serialize};
40
41pub use ast_transforms::{
42 add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
43 get_identifiers, get_literals, get_output_column_names, get_subqueries, get_table_names,
44 get_window_functions, node_count, qualify_columns, remove_limit_offset, remove_nodes,
45 remove_select_columns, remove_where, rename_columns, rename_tables, replace_by_type,
46 replace_nodes, set_distinct, set_limit, set_offset,
47};
48pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
49pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
50pub use expressions::Expression;
51pub use function_catalog::{
52 FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
53};
54pub use generator::Generator;
55pub use helper::{
56 csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
57 name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
58};
59pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
60pub use parser::Parser;
61pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
62pub use schema::{
63 ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
64};
65pub use scope::{
66 build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
67 ScopeType, SourceInfo,
68};
69pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
70pub use tokens::{Token, TokenType, Tokenizer};
71pub use traversal::{
72 contains_aggregate,
73 contains_subquery,
74 contains_window_function,
75 find_ancestor,
76 find_parent,
77 get_columns,
78 get_merge_source,
79 get_merge_target,
80 get_tables,
81 is_add,
82 is_aggregate,
83 is_alias,
84 is_alter_table,
85 is_and,
86 is_arithmetic,
87 is_avg,
88 is_between,
89 is_boolean,
90 is_case,
91 is_cast,
92 is_coalesce,
93 is_column,
94 is_comparison,
95 is_concat,
96 is_count,
97 is_create_index,
98 is_create_table,
99 is_create_view,
100 is_cte,
101 is_ddl,
102 is_delete,
103 is_div,
104 is_drop_index,
105 is_drop_table,
106 is_drop_view,
107 is_eq,
108 is_except,
109 is_exists,
110 is_from,
111 is_function,
112 is_group_by,
113 is_gt,
114 is_gte,
115 is_having,
116 is_identifier,
117 is_ilike,
118 is_in,
119 is_insert,
121 is_intersect,
122 is_is_null,
123 is_join,
124 is_like,
125 is_limit,
126 is_literal,
127 is_logical,
128 is_merge,
129 is_lt,
130 is_lte,
131 is_max_func,
132 is_min_func,
133 is_mod,
134 is_mul,
135 is_neq,
136 is_not,
137 is_null_if,
138 is_null_literal,
139 is_offset,
140 is_or,
141 is_order_by,
142 is_ordered,
143 is_paren,
144 is_query,
146 is_safe_cast,
147 is_select,
148 is_set_operation,
149 is_star,
150 is_sub,
151 is_subquery,
152 is_sum,
153 is_table,
154 is_try_cast,
155 is_union,
156 is_update,
157 is_where,
158 is_window_function,
159 is_with,
160 transform,
161 transform_map,
162 BfsIter,
163 DfsIter,
164 ExpressionWalk,
165 ParentInfo,
166 TreeContext,
167};
168pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
169pub use validation::{
170 mapping_schema_from_validation_schema, validate_with_schema, SchemaColumn,
171 SchemaColumnReference, SchemaForeignKey, SchemaTable, SchemaTableReference,
172 SchemaValidationOptions, ValidationSchema,
173};
174
/// Default cap on the size of SQL accepted by the formatter, in bytes (16 MiB).
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 * 1024 * 1024;
/// Default cap on the number of tokens produced by the tokenizer before parsing.
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
/// Default cap on the summed AST node count of all parsed statements.
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
/// Default cap on set-operation keywords counted within one statement.
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;

// The functions below exist because `#[serde(default = "...")]` requires a function path;
// each one simply wraps the matching constant in `Some`.

/// Serde/`Default` value for `FormatGuardOptions::max_input_bytes`.
fn default_format_max_input_bytes() -> Option<usize> {
    Option::from(DEFAULT_FORMAT_MAX_INPUT_BYTES)
}

/// Serde/`Default` value for `FormatGuardOptions::max_tokens`.
fn default_format_max_tokens() -> Option<usize> {
    Option::from(DEFAULT_FORMAT_MAX_TOKENS)
}

/// Serde/`Default` value for `FormatGuardOptions::max_ast_nodes`.
fn default_format_max_ast_nodes() -> Option<usize> {
    Option::from(DEFAULT_FORMAT_MAX_AST_NODES)
}

/// Serde/`Default` value for `FormatGuardOptions::max_set_op_chain`.
fn default_format_max_set_op_chain() -> Option<usize> {
    Option::from(DEFAULT_FORMAT_MAX_SET_OP_CHAIN)
}
195
/// Resource limits enforced by the `format*` entry points in this module.
///
/// Every limit is optional: `None` disables that particular guard. When a field is missing
/// during deserialization, its `default_format_*` function supplies the matching
/// `DEFAULT_FORMAT_*` constant. A partially specified options object therefore keeps the
/// remaining guards enabled. Keys are camelCase (e.g. `maxInputBytes`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FormatGuardOptions {
    /// Maximum accepted input length in bytes (`str::len`), checked before tokenizing.
    #[serde(default = "default_format_max_input_bytes")]
    pub max_input_bytes: Option<usize>,
    /// Maximum number of tokens produced by the dialect tokenizer, checked before parsing.
    #[serde(default = "default_format_max_tokens")]
    pub max_tokens: Option<usize>,
    /// Maximum total AST node count summed over all parsed statements, checked after parsing.
    #[serde(default = "default_format_max_ast_nodes")]
    pub max_ast_nodes: Option<usize>,
    /// Maximum number of set-operation tokens (`UNION`/`INTERSECT`/`EXCEPT`/`MINUS`) in one
    /// statement. The count resets at each `;` and is checked before parsing.
    #[serde(default = "default_format_max_set_op_chain")]
    pub max_set_op_chain: Option<usize>,
}
222
223impl Default for FormatGuardOptions {
224 fn default() -> Self {
225 Self {
226 max_input_bytes: default_format_max_input_bytes(),
227 max_tokens: default_format_max_tokens(),
228 max_ast_nodes: default_format_max_ast_nodes(),
229 max_set_op_chain: default_format_max_set_op_chain(),
230 }
231 }
232}
233
234fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
235 Error::generate(format!(
236 "{code}: value {actual} exceeds configured limit {limit}"
237 ))
238}
239
240fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
241 if let Some(max) = options.max_input_bytes {
242 let input_bytes = sql.len();
243 if input_bytes > max {
244 return Err(format_guard_error(
245 "E_GUARD_INPUT_TOO_LARGE",
246 input_bytes,
247 max,
248 ));
249 }
250 }
251 Ok(())
252}
253
254fn parse_with_token_guard(
255 sql: &str,
256 dialect: &Dialect,
257 options: &FormatGuardOptions,
258) -> Result<Vec<Expression>> {
259 let tokens = dialect.tokenize(sql)?;
260 if let Some(max) = options.max_tokens {
261 let token_count = tokens.len();
262 if token_count > max {
263 return Err(format_guard_error(
264 "E_GUARD_TOKEN_BUDGET_EXCEEDED",
265 token_count,
266 max,
267 ));
268 }
269 }
270 enforce_set_op_chain_guard(&tokens, options)?;
271
272 let config = crate::parser::ParserConfig {
273 dialect: Some(dialect.dialect_type()),
274 ..Default::default()
275 };
276 let mut parser = Parser::with_source(tokens, config, sql.to_string());
277 parser.parse()
278}
279
280fn is_trivia_token(token_type: TokenType) -> bool {
281 matches!(
282 token_type,
283 TokenType::Space | TokenType::Break | TokenType::LineComment | TokenType::BlockComment
284 )
285}
286
287fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
288 tokens
289 .iter()
290 .skip(start)
291 .find(|token| !is_trivia_token(token.token_type))
292}
293
294fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
295 let token = &tokens[idx];
296 match token.token_type {
297 TokenType::Union | TokenType::Intersect => true,
298 TokenType::Except => {
299 if token.text.eq_ignore_ascii_case("minus")
302 && matches!(
303 next_significant_token(tokens, idx + 1).map(|t| t.token_type),
304 Some(TokenType::LParen)
305 )
306 {
307 return false;
308 }
309 true
310 }
311 _ => false,
312 }
313}
314
315fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
316 let Some(max) = options.max_set_op_chain else {
317 return Ok(());
318 };
319
320 let mut set_op_count = 0usize;
321 for (idx, token) in tokens.iter().enumerate() {
322 if token.token_type == TokenType::Semicolon {
323 set_op_count = 0;
324 continue;
325 }
326
327 if is_set_operation_token(tokens, idx) {
328 set_op_count += 1;
329 if set_op_count > max {
330 return Err(format_guard_error(
331 "E_GUARD_SET_OP_CHAIN_EXCEEDED",
332 set_op_count,
333 max,
334 ));
335 }
336 }
337 }
338
339 Ok(())
340}
341
342fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
343 if let Some(max) = options.max_ast_nodes {
344 let ast_nodes: usize = expressions.iter().map(node_count).sum();
345 if ast_nodes > max {
346 return Err(format_guard_error(
347 "E_GUARD_AST_BUDGET_EXCEEDED",
348 ast_nodes,
349 max,
350 ));
351 }
352 }
353 Ok(())
354}
355
356fn format_with_dialect(
357 sql: &str,
358 dialect: &Dialect,
359 options: &FormatGuardOptions,
360) -> Result<Vec<String>> {
361 enforce_input_guard(sql, options)?;
362 let expressions = parse_with_token_guard(sql, dialect, options)?;
363 enforce_ast_guard(&expressions, options)?;
364
365 expressions
366 .iter()
367 .map(|expr| dialect.generate_pretty(expr))
368 .collect()
369}
370
371pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
392 let read_dialect = Dialect::get(read);
393 let write_dialect = Dialect::get(write);
394 let generic_identity = read == DialectType::Generic && write == DialectType::Generic;
395
396 let expressions = read_dialect.parse(sql)?;
397
398 expressions
399 .into_iter()
400 .map(|expr| {
401 if generic_identity {
402 write_dialect.generate_with_source(&expr, read)
403 } else {
404 let transformed = write_dialect.transform(expr)?;
405 write_dialect.generate_with_source(&transformed, read)
406 }
407 })
408 .collect()
409}
410
411pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
420 let d = Dialect::get(dialect);
421 d.parse(sql)
422}
423
424pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
433 let mut expressions = parse(sql, dialect)?;
434
435 if expressions.len() != 1 {
436 return Err(Error::parse(
437 format!("Expected 1 statement, found {}", expressions.len()),
438 0,
439 0,
440 0,
441 0,
442 ));
443 }
444
445 Ok(expressions.remove(0))
446}
447
448pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
457 let d = Dialect::get(dialect);
458 d.generate(expression)
459}
460
461pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
465 format_with_options(sql, dialect, &FormatGuardOptions::default())
466}
467
468pub fn format_with_options(
470 sql: &str,
471 dialect: DialectType,
472 options: &FormatGuardOptions,
473) -> Result<Vec<String>> {
474 let d = Dialect::get(dialect);
475 format_with_dialect(sql, &d, options)
476}
477
478pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
487 validate_with_options(sql, dialect, &ValidationOptions::default())
488}
489
/// Options for `validate_with_options`.
///
/// Deserializes from camelCase keys (`strictSyntax`); a missing field defaults to `false`.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ValidationOptions {
    /// When `true`, also reject input the parser itself tolerates. Currently this is a
    /// trailing comma before a clause keyword, set operator, `;` or end of input, reported
    /// as error code `E005`.
    #[serde(default)]
    pub strict_syntax: bool,
}
499
500pub fn validate_with_options(
502 sql: &str,
503 dialect: DialectType,
504 options: &ValidationOptions,
505) -> ValidationResult {
506 let d = Dialect::get(dialect);
507 match d.parse(sql) {
508 Ok(expressions) => {
509 for expr in &expressions {
513 if !expr.is_statement() {
514 let msg = format!("Invalid expression / Unexpected token");
515 return ValidationResult::with_errors(vec![ValidationError::error(
516 msg, "E004",
517 )]);
518 }
519 }
520 if options.strict_syntax {
521 if let Some(error) = strict_syntax_error(sql, &d) {
522 return ValidationResult::with_errors(vec![error]);
523 }
524 }
525 ValidationResult::success()
526 }
527 Err(e) => {
528 let error = match &e {
529 Error::Syntax {
530 message,
531 line,
532 column,
533 start,
534 end,
535 } => ValidationError::error(message.clone(), "E001")
536 .with_location(*line, *column)
537 .with_span(Some(*start), Some(*end)),
538 Error::Tokenize {
539 message,
540 line,
541 column,
542 start,
543 end,
544 } => ValidationError::error(message.clone(), "E002")
545 .with_location(*line, *column)
546 .with_span(Some(*start), Some(*end)),
547 Error::Parse {
548 message,
549 line,
550 column,
551 start,
552 end,
553 } => ValidationError::error(message.clone(), "E003")
554 .with_location(*line, *column)
555 .with_span(Some(*start), Some(*end)),
556 _ => ValidationError::error(e.to_string(), "E000"),
557 };
558 ValidationResult::with_errors(vec![error])
559 }
560 }
561}
562
563fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
564 let tokens = dialect.tokenize(sql).ok()?;
565
566 for (idx, token) in tokens.iter().enumerate() {
567 if token.token_type != TokenType::Comma {
568 continue;
569 }
570
571 let next = tokens.get(idx + 1);
572 let (is_boundary, boundary_name) = match next.map(|t| t.token_type) {
573 Some(TokenType::From) => (true, "FROM"),
574 Some(TokenType::Where) => (true, "WHERE"),
575 Some(TokenType::GroupBy) => (true, "GROUP BY"),
576 Some(TokenType::Having) => (true, "HAVING"),
577 Some(TokenType::Order) | Some(TokenType::OrderBy) => (true, "ORDER BY"),
578 Some(TokenType::Limit) => (true, "LIMIT"),
579 Some(TokenType::Offset) => (true, "OFFSET"),
580 Some(TokenType::Union) => (true, "UNION"),
581 Some(TokenType::Intersect) => (true, "INTERSECT"),
582 Some(TokenType::Except) => (true, "EXCEPT"),
583 Some(TokenType::Qualify) => (true, "QUALIFY"),
584 Some(TokenType::Window) => (true, "WINDOW"),
585 Some(TokenType::Semicolon) | None => (true, "end of statement"),
586 _ => (false, ""),
587 };
588
589 if is_boundary {
590 let message = format!(
591 "Trailing comma before {} is not allowed in strict syntax mode",
592 boundary_name
593 );
594 return Some(
595 ValidationError::error(message, "E005")
596 .with_location(token.span.line, token.span.column),
597 );
598 }
599 }
600
601 None
602}
603
604pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
617 let read_dialect = Dialect::get_by_name(read)
618 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0, 0, 0))?;
619 let write_dialect = Dialect::get_by_name(write)
620 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0, 0, 0))?;
621 let generic_identity = read_dialect.dialect_type() == DialectType::Generic
622 && write_dialect.dialect_type() == DialectType::Generic;
623
624 let expressions = read_dialect.parse(sql)?;
625
626 expressions
627 .into_iter()
628 .map(|expr| {
629 if generic_identity {
630 write_dialect.generate_with_source(&expr, read_dialect.dialect_type())
631 } else {
632 let transformed = write_dialect.transform(expr)?;
633 write_dialect.generate_with_source(&transformed, read_dialect.dialect_type())
634 }
635 })
636 .collect()
637}
638
639pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
643 let d = Dialect::get_by_name(dialect)
644 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
645 d.parse(sql)
646}
647
648pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
652 let d = Dialect::get_by_name(dialect)
653 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
654 d.generate(expression)
655}
656
657pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
661 format_with_options_by_name(sql, dialect, &FormatGuardOptions::default())
662}
663
664pub fn format_with_options_by_name(
666 sql: &str,
667 dialect: &str,
668 options: &FormatGuardOptions,
669) -> Result<Vec<String>> {
670 let d = Dialect::get_by_name(dialect)
671 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
672 format_with_dialect(sql, &d, options)
673}
674
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// Validation options with strict syntax checking turned on.
    fn strict() -> ValidationOptions {
        ValidationOptions {
            strict_syntax: true,
        }
    }

    /// Asserts that strict validation of `sql` (Generic dialect) fails with an `E005` error.
    fn assert_strict_rejects(sql: &str) {
        let result = validate_with_options(sql, DialectType::Generic, &strict());
        assert!(!result.valid, "Result should be invalid");
        assert!(
            result.errors.iter().any(|e| e.code == "E005"),
            "Expected E005, got: {:?}",
            result.errors
        );
    }

    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let result = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(result.valid, "Result: {:?}", result.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        assert_strict_rejects("SELECT name, FROM employees");
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        assert_strict_rejects("SELECT name FROM employees, WHERE salary > 10");
    }
}
721
#[cfg(test)]
mod format_tests {
    use super::*;

    /// Guard options with every limit disabled; each test enables only the guard it checks.
    fn unguarded() -> FormatGuardOptions {
        FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: None,
        }
    }

    /// Formats `sql` and asserts it fails with an error whose message contains `code`.
    fn assert_guard_rejects(
        sql: &str,
        dialect: DialectType,
        options: &FormatGuardOptions,
        code: &str,
    ) {
        let err = format_with_options(sql, dialect, options).expect_err("expected guard error");
        assert!(err.to_string().contains(code));
    }

    #[test]
    fn format_basic_query() {
        let statements = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
        assert_eq!(statements.len(), 1);
        assert!(statements[0].contains('\n'));
    }

    #[test]
    fn format_guard_rejects_large_input() {
        let options = FormatGuardOptions {
            max_input_bytes: Some(7),
            ..unguarded()
        };
        assert_guard_rejects(
            "SELECT 1",
            DialectType::Generic,
            &options,
            "E_GUARD_INPUT_TOO_LARGE",
        );
    }

    #[test]
    fn format_guard_rejects_token_budget() {
        let options = FormatGuardOptions {
            max_tokens: Some(1),
            ..unguarded()
        };
        assert_guard_rejects(
            "SELECT 1",
            DialectType::Generic,
            &options,
            "E_GUARD_TOKEN_BUDGET_EXCEEDED",
        );
    }

    #[test]
    fn format_guard_rejects_ast_budget() {
        let options = FormatGuardOptions {
            max_ast_nodes: Some(1),
            ..unguarded()
        };
        assert_guard_rejects(
            "SELECT 1",
            DialectType::Generic,
            &options,
            "E_GUARD_AST_BUDGET_EXCEEDED",
        );
    }

    #[test]
    fn format_guard_rejects_set_op_chain_budget() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(1),
            ..unguarded()
        };
        assert_guard_rejects(
            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
            DialectType::Generic,
            &options,
            "E_GUARD_SET_OP_CHAIN_EXCEEDED",
        );
    }

    #[test]
    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(0),
            ..unguarded()
        };
        let result = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
        assert!(result.is_ok(), "Result: {:?}", result);
    }

    #[test]
    fn issue57_invalid_ternary_returns_error() {
        // `?:` is not valid PostgreSQL; every entry point must report an error.
        let sql = "SELECT x > 0 ? 1 : 0 FROM t";

        let parse_result = parse(sql, DialectType::PostgreSQL);
        assert!(
            parse_result.is_err(),
            "Expected parse error for invalid ternary SQL, got: {:?}",
            parse_result
        );

        let format_result = format(sql, DialectType::PostgreSQL);
        assert!(
            format_result.is_err(),
            "Expected format error for invalid ternary SQL, got: {:?}",
            format_result
        );

        let transpile_result = transpile(sql, DialectType::PostgreSQL, DialectType::PostgreSQL);
        assert!(
            transpile_result.is_err(),
            "Expected transpile error for invalid ternary SQL, got: {:?}",
            transpile_result
        );
    }

    #[test]
    fn format_default_guard_rejects_deep_union_chain_before_parse() {
        // 1101 copies of the base query joined by 1100 `UNION ALL`s, far past the default limit.
        let base = "SELECT col0, col1 FROM t";
        let sql = vec![base; 1101].join(" UNION ALL ");

        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }
}
841}