1pub mod ast_json;
16#[cfg(any(feature = "ast-tools", feature = "generate", feature = "semantic"))]
17pub mod ast_transforms;
18#[cfg(feature = "builder")]
19pub mod builder;
20pub mod dialects;
21#[cfg(feature = "diff")]
22pub mod diff;
23pub mod error;
24pub mod expressions;
25#[cfg(feature = "semantic")]
26pub mod function_catalog;
27mod function_registry;
28#[cfg(feature = "generate")]
29pub mod generator;
30#[cfg(feature = "semantic")]
31pub mod helper;
32#[cfg(feature = "semantic")]
33pub mod lineage;
34#[cfg(feature = "openlineage")]
35pub mod openlineage;
36#[cfg(feature = "semantic")]
37pub mod optimizer;
38pub mod parser;
39#[cfg(feature = "planner")]
40pub mod planner;
41#[cfg(all(feature = "semantic", feature = "generate"))]
42pub mod query_analysis;
43#[cfg(feature = "semantic")]
44pub mod resolver;
45#[cfg(feature = "semantic")]
46pub mod schema;
47#[cfg(feature = "semantic")]
48pub mod scope;
49#[cfg(feature = "time")]
50pub mod time;
51pub mod tokens;
52#[cfg(feature = "transpile")]
53pub mod transforms;
54#[cfg(any(feature = "ast-tools", feature = "generate", feature = "semantic"))]
55pub mod traversal;
56#[cfg(any(feature = "semantic", feature = "time"))]
57pub mod trie;
58#[cfg(feature = "semantic")]
59pub mod validation;
60
61#[cfg(any(feature = "generate", feature = "semantic"))]
62use serde::{Deserialize, Serialize};
63
64#[cfg(feature = "ast-tools")]
65pub use ast_transforms::{
66 add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
67 get_identifiers, get_literals, get_output_column_names, get_subqueries, get_table_names,
68 get_window_functions, node_count, qualify_columns, remove_limit_offset, remove_nodes,
69 remove_select_columns, remove_where, rename_columns, rename_tables, rename_tables_with_options,
70 replace_by_type, replace_nodes, set_distinct, set_limit, set_limit_expr, set_offset,
71 set_offset_expr, set_order_by, RenameTablesOptions,
72};
73pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
74#[cfg(feature = "transpile")]
75pub use dialects::{TranspileOptions, TranspileTarget};
76pub use error::{Error, Result};
77#[cfg(feature = "semantic")]
78pub use error::{ValidationError, ValidationResult, ValidationSeverity};
79pub use expressions::{DataType, Expression};
80#[cfg(feature = "semantic")]
81pub use function_catalog::{
82 FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
83};
84#[cfg(feature = "generate")]
85pub use generator::{Generator, UnsupportedLevel};
86#[cfg(feature = "semantic")]
87pub use helper::{
88 csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
89 name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
90};
91#[cfg(feature = "semantic")]
92pub use optimizer::{
93 annotate_types, qualify_tables, QualifyTablesOptions, TypeAnnotator, TypeCoercionClass,
94};
95pub use parser::Parser;
96#[cfg(all(feature = "semantic", feature = "generate"))]
97pub use query_analysis::{
98 analyze_query, AnalyzeQueryOptions, ColumnReferenceFact, CteFact, ProjectionFact,
99 ProjectionNullability, QueryAnalysis, QueryShape, ReferenceConfidence, RelationFact,
100 SetOperationBranchFact, SetOperationFact, StarProjectionFact, TransformFunctionFact,
101 TransformKind,
102};
103#[cfg(feature = "semantic")]
104pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
105#[cfg(feature = "semantic")]
106pub use schema::{
107 ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
108};
109#[cfg(feature = "semantic")]
110pub use scope::{
111 build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
112 ScopeType, SourceInfo,
113};
114#[cfg(feature = "time")]
115pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
116pub use tokens::{Token, TokenType, Tokenizer};
117#[cfg(feature = "ast-tools")]
118pub use traversal::{
119 contains_aggregate,
120 contains_subquery,
121 contains_window_function,
122 find_ancestor,
123 find_parent,
124 get_all_tables,
125 get_columns,
126 get_merge_source,
127 get_merge_target,
128 get_tables,
129 is_add,
130 is_aggregate,
131 is_alias,
132 is_alter_table,
133 is_and,
134 is_arithmetic,
135 is_avg,
136 is_between,
137 is_boolean,
138 is_case,
139 is_cast,
140 is_coalesce,
141 is_column,
142 is_comparison,
143 is_concat,
144 is_count,
145 is_create_index,
146 is_create_table,
147 is_create_view,
148 is_cte,
149 is_ddl,
150 is_delete,
151 is_div,
152 is_drop_index,
153 is_drop_table,
154 is_drop_view,
155 is_eq,
156 is_except,
157 is_exists,
158 is_from,
159 is_function,
160 is_group_by,
161 is_gt,
162 is_gte,
163 is_having,
164 is_identifier,
165 is_ilike,
166 is_in,
167 is_insert,
169 is_intersect,
170 is_is_null,
171 is_join,
172 is_like,
173 is_limit,
174 is_literal,
175 is_logical,
176 is_lt,
177 is_lte,
178 is_max_func,
179 is_merge,
180 is_min_func,
181 is_mod,
182 is_mul,
183 is_neq,
184 is_not,
185 is_null_if,
186 is_null_literal,
187 is_offset,
188 is_or,
189 is_order_by,
190 is_ordered,
191 is_paren,
192 is_query,
194 is_safe_cast,
195 is_select,
196 is_set_operation,
197 is_star,
198 is_sub,
199 is_subquery,
200 is_sum,
201 is_table,
202 is_try_cast,
203 is_union,
204 is_update,
205 is_where,
206 is_window_function,
207 is_with,
208 transform,
209 transform_map,
210 BfsIter,
211 DfsIter,
212 ExpressionWalk,
213 ParentInfo,
214 TreeContext,
215};
216#[cfg(any(feature = "semantic", feature = "time"))]
217pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
218#[cfg(feature = "semantic")]
219pub use validation::{
220 mapping_schema_from_validation_schema, validate_with_schema, SchemaColumn,
221 SchemaColumnReference, SchemaForeignKey, SchemaTable, SchemaTableReference,
222 SchemaValidationOptions, ValidationSchema,
223};
224
/// Built-in limit on the byte length of SQL accepted by the formatting guard (16 MiB).
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 * 1024 * 1024;
/// Built-in limit on the number of tokens accepted by the formatting guard.
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
/// Built-in limit on the total number of AST nodes accepted by the formatting guard.
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
/// Built-in limit on the number of set-operation tokens allowed within one statement.
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;
233
/// Returns the built-in input-size limit wrapped in `Some`.
///
/// Referenced by name from the `#[serde(default = "...")]` attribute on
/// [`FormatGuardOptions::max_input_bytes`] and reused by its `Default` impl.
#[cfg(feature = "generate")]
fn default_format_max_input_bytes() -> Option<usize> {
    Some(DEFAULT_FORMAT_MAX_INPUT_BYTES)
}
238
/// Returns the built-in token-count limit wrapped in `Some`.
///
/// Referenced by name from the `#[serde(default = "...")]` attribute on
/// [`FormatGuardOptions::max_tokens`] and reused by its `Default` impl.
#[cfg(feature = "generate")]
fn default_format_max_tokens() -> Option<usize> {
    Some(DEFAULT_FORMAT_MAX_TOKENS)
}
243
/// Returns the built-in AST-node limit wrapped in `Some`.
///
/// Referenced by name from the `#[serde(default = "...")]` attribute on
/// [`FormatGuardOptions::max_ast_nodes`] and reused by its `Default` impl.
#[cfg(feature = "generate")]
fn default_format_max_ast_nodes() -> Option<usize> {
    Some(DEFAULT_FORMAT_MAX_AST_NODES)
}
248
/// Returns the built-in set-operation chain limit wrapped in `Some`.
///
/// Referenced by name from the `#[serde(default = "...")]` attribute on
/// [`FormatGuardOptions::max_set_op_chain`] and reused by its `Default` impl.
#[cfg(feature = "generate")]
fn default_format_max_set_op_chain() -> Option<usize> {
    Some(DEFAULT_FORMAT_MAX_SET_OP_CHAIN)
}
253
/// Resource limits checked by [`format_with_options`] and
/// [`format_with_options_by_name`] while formatting SQL.
///
/// Every limit is optional: a field set to `None` disables the corresponding
/// check. Field names are serialized in camelCase, and a field that is absent
/// from deserialized input takes its built-in default limit.
#[cfg(feature = "generate")]
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FormatGuardOptions {
    /// Maximum length of the input SQL in bytes (compared against `sql.len()`).
    #[serde(default = "default_format_max_input_bytes")]
    pub max_input_bytes: Option<usize>,
    /// Maximum number of tokens produced by the dialect's tokenizer.
    #[serde(default = "default_format_max_tokens")]
    pub max_tokens: Option<usize>,
    /// Maximum total AST node count, summed over all parsed expressions.
    #[serde(default = "default_format_max_ast_nodes")]
    pub max_ast_nodes: Option<usize>,
    /// Maximum number of set-operation tokens (`UNION`, `INTERSECT`, `EXCEPT`)
    /// allowed between two semicolons. The check runs on the token stream,
    /// before parsing.
    #[serde(default = "default_format_max_set_op_chain")]
    pub max_set_op_chain: Option<usize>,
}
281
#[cfg(feature = "generate")]
impl Default for FormatGuardOptions {
    /// Builds an option set with every guard enabled at its built-in limit.
    ///
    /// The values come from the same functions serde uses for missing fields,
    /// so both construction paths produce identical limits.
    fn default() -> Self {
        let max_input_bytes = default_format_max_input_bytes();
        let max_tokens = default_format_max_tokens();
        let max_ast_nodes = default_format_max_ast_nodes();
        let max_set_op_chain = default_format_max_set_op_chain();

        Self {
            max_input_bytes,
            max_tokens,
            max_ast_nodes,
            max_set_op_chain,
        }
    }
}
293
/// Builds the generation error reported when a formatting guard is exceeded.
///
/// The message reads `<code>: value <actual> exceeds configured limit <limit>`,
/// where `code` is the guard's identifier, `actual` the measured value and
/// `limit` the configured maximum.
#[cfg(feature = "generate")]
fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
    let message = format!("{code}: value {actual} exceeds configured limit {limit}");
    Error::generate(message)
}
300
/// Rejects input whose byte length exceeds `options.max_input_bytes`.
///
/// # Errors
///
/// Returns an `E_GUARD_INPUT_TOO_LARGE` guard error when the limit is set and
/// `sql.len()` is greater than it. Succeeds when the limit is `None`.
#[cfg(feature = "generate")]
fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
    match options.max_input_bytes {
        Some(limit) if sql.len() > limit => Err(format_guard_error(
            "E_GUARD_INPUT_TOO_LARGE",
            sql.len(),
            limit,
        )),
        _ => Ok(()),
    }
}
315
/// Tokenizes `sql`, applies the token-level guards, then parses the tokens.
///
/// The token budget and the set-operation chain guard are both checked on the
/// token stream, so over-limit input is rejected before the parser runs.
///
/// # Errors
///
/// Returns tokenizer or parser errors unchanged, an
/// `E_GUARD_TOKEN_BUDGET_EXCEEDED` guard error when the token count exceeds
/// `options.max_tokens`, or the error from the set-operation chain guard.
#[cfg(feature = "generate")]
fn parse_with_token_guard(
    sql: &str,
    dialect: &Dialect,
    options: &FormatGuardOptions,
) -> Result<Vec<Expression>> {
    let tokens = dialect.tokenize(sql)?;

    let produced = tokens.len();
    if let Some(budget) = options.max_tokens.filter(|&budget| produced > budget) {
        return Err(format_guard_error(
            "E_GUARD_TOKEN_BUDGET_EXCEEDED",
            produced,
            budget,
        ));
    }
    enforce_set_op_chain_guard(&tokens, options)?;

    // Parse with the same dialect that produced the tokens.
    let parser_config = crate::parser::ParserConfig {
        dialect: Some(dialect.dialect_type()),
        ..Default::default()
    };
    let mut statement_parser = Parser::with_source(tokens, parser_config, sql.to_string());
    statement_parser.parse()
}
342
/// Reports whether `token_type` is whitespace, a line break or a comment.
#[cfg(feature = "generate")]
fn is_trivia_token(token_type: TokenType) -> bool {
    const TRIVIA: [TokenType; 4] = [
        TokenType::Space,
        TokenType::Break,
        TokenType::LineComment,
        TokenType::BlockComment,
    ];
    TRIVIA.contains(&token_type)
}
350
/// Returns the first non-trivia token at or after index `start`.
///
/// Yields `None` when `start` is past the end of `tokens` or when only trivia
/// tokens remain.
#[cfg(feature = "generate")]
fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
    let remaining = tokens.get(start..)?;
    remaining
        .iter()
        .find(|candidate| !is_trivia_token(candidate.token_type))
}
358
/// Reports whether the token at `idx` starts a set operation.
///
/// `Union` and `Intersect` tokens always count. An `Except` token counts
/// unless its text is `minus` (case-insensitive) and the next non-trivia token
/// is an opening parenthesis; that shape is a call such as `minus(3, 2)`, not
/// a set operator.
///
/// # Panics
///
/// Panics if `idx` is out of bounds for `tokens`.
#[cfg(feature = "generate")]
fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
    let current = &tokens[idx];
    match current.token_type {
        TokenType::Union | TokenType::Intersect => true,
        TokenType::Except => {
            // Only look ahead when the token is spelled `minus`.
            let is_minus_call = current.text.eq_ignore_ascii_case("minus")
                && matches!(
                    next_significant_token(tokens, idx + 1).map(|following| following.token_type),
                    Some(TokenType::LParen)
                );
            !is_minus_call
        }
        _ => false,
    }
}
380
/// Rejects token streams with too many set operations in one statement.
///
/// Set-operation tokens are counted in order, and the count restarts at every
/// semicolon. The check fails as soon as the count exceeds
/// `options.max_set_op_chain`.
///
/// # Errors
///
/// Returns an `E_GUARD_SET_OP_CHAIN_EXCEEDED` guard error when the limit is
/// exceeded. Succeeds immediately when the limit is `None`.
#[cfg(feature = "generate")]
fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
    let limit = match options.max_set_op_chain {
        Some(limit) => limit,
        None => return Ok(()),
    };

    let mut chain_len = 0usize;
    for (position, current) in tokens.iter().enumerate() {
        if current.token_type == TokenType::Semicolon {
            // A semicolon ends the statement, so the chain starts over.
            chain_len = 0;
        } else if is_set_operation_token(tokens, position) {
            chain_len += 1;
            if chain_len > limit {
                return Err(format_guard_error(
                    "E_GUARD_SET_OP_CHAIN_EXCEEDED",
                    chain_len,
                    limit,
                ));
            }
        }
    }

    Ok(())
}
408
/// Rejects parsed input whose total AST node count exceeds `options.max_ast_nodes`.
///
/// The total is the sum of `node_count` over every expression in `expressions`.
///
/// # Errors
///
/// Returns an `E_GUARD_AST_BUDGET_EXCEEDED` guard error when the limit is set
/// and the total is greater than it. Succeeds when the limit is `None`.
#[cfg(feature = "generate")]
fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
    let Some(budget) = options.max_ast_nodes else {
        return Ok(());
    };

    let mut total_nodes = 0usize;
    for expression in expressions {
        total_nodes += crate::ast_transforms::node_count(expression);
    }

    if total_nodes <= budget {
        return Ok(());
    }
    Err(format_guard_error(
        "E_GUARD_AST_BUDGET_EXCEEDED",
        total_nodes,
        budget,
    ))
}
426
/// Runs the guarded formatting pipeline against an already-resolved dialect.
///
/// Steps, in order: input-size guard, tokenization with the token-level
/// guards, parsing, AST-size guard, then pretty generation of each parsed
/// expression.
///
/// # Errors
///
/// Returns the first guard, tokenizer, parser or generator error encountered.
#[cfg(feature = "generate")]
fn format_with_dialect(
    sql: &str,
    dialect: &Dialect,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    enforce_input_guard(sql, options)?;
    let statements = parse_with_token_guard(sql, dialect, options)?;
    enforce_ast_guard(&statements, options)?;

    let mut formatted = Vec::with_capacity(statements.len());
    for statement in &statements {
        formatted.push(dialect.generate_pretty(statement)?);
    }
    Ok(formatted)
}
442
/// Transpiles `sql` written for the `read` dialect into the `write` dialect.
///
/// This is a thin wrapper over `Dialect::transpile` on the dialect looked up
/// for `read`, so the output is exactly what that method returns.
///
/// # Errors
///
/// Returns whatever error `Dialect::transpile` reports.
#[cfg(feature = "transpile")]
pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
    let source_dialect = Dialect::get(read);
    source_dialect.transpile(sql, write)
}
471
472pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
481 let d = Dialect::get(dialect);
482 d.parse(sql)
483}
484
485pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
494 let mut expressions = parse(sql, dialect)?;
495
496 if expressions.len() != 1 {
497 return Err(Error::parse(
498 format!("Expected 1 statement, found {}", expressions.len()),
499 0,
500 0,
501 0,
502 0,
503 ));
504 }
505
506 Ok(expressions.remove(0))
507}
508
509pub fn parse_data_type(sql: &str, dialect: DialectType) -> Result<DataType> {
518 Dialect::get(dialect).parse_data_type(sql)
519}
520
/// Generates the SQL text of `data_type` for the given dialect.
///
/// The data type is cloned into an [`Expression::DataType`] so it can go
/// through the regular generator.
///
/// # Errors
///
/// Returns the error reported by the dialect's generator.
#[cfg(feature = "generate")]
pub fn generate_data_type(data_type: &DataType, dialect: DialectType) -> Result<String> {
    let target_dialect = Dialect::get(dialect);
    let wrapped = Expression::DataType(data_type.clone());
    target_dialect.generate(&wrapped)
}
533
/// Generates the SQL text of `expression` for the given dialect.
///
/// # Errors
///
/// Returns the error reported by the dialect's generator.
#[cfg(feature = "generate")]
pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
    Dialect::get(dialect).generate(expression)
}
547
/// Pretty-prints `sql` for the given dialect using the default guard limits.
///
/// Equivalent to [`format_with_options`] with `FormatGuardOptions::default()`.
///
/// # Errors
///
/// Returns any guard, tokenizer, parser or generator error from formatting.
#[cfg(feature = "generate")]
pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
    let default_guards = FormatGuardOptions::default();
    format_with_options(sql, dialect, &default_guards)
}
555
/// Pretty-prints `sql` for the given dialect under caller-supplied guard limits.
///
/// Returns one formatted string per parsed expression.
///
/// # Errors
///
/// Returns a guard error (`E_GUARD_*`) when a configured limit is exceeded,
/// or the tokenizer, parser or generator error otherwise.
#[cfg(feature = "generate")]
pub fn format_with_options(
    sql: &str,
    dialect: DialectType,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    format_with_dialect(sql, &Dialect::get(dialect), options)
}
566
/// Validates `sql` for the given dialect with default validation options.
///
/// The default options leave `strict_syntax` off, so only parse-level problems
/// are reported. See [`validate_with_options`] for the error codes.
#[cfg(feature = "semantic")]
pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
    let default_options = ValidationOptions::default();
    validate_with_options(sql, dialect, &default_options)
}
579
/// Options controlling [`validate_with_options`].
///
/// Field names are serialized in camelCase; absent fields take their
/// `Default` value.
#[cfg(feature = "semantic")]
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ValidationOptions {
    /// When `true`, SQL that parses successfully is additionally checked for
    /// trailing commas before clause boundaries; violations are reported with
    /// code `E005`. Defaults to `false`.
    #[serde(default)]
    pub strict_syntax: bool,
}
590
/// Validates `sql` for the given dialect and reports problems as a
/// [`ValidationResult`] instead of a `Result`.
///
/// Error codes produced here:
///
/// - `E001`: syntax error from the parser (with location and span)
/// - `E002`: tokenizer error (with location and span)
/// - `E003`: parse error (with location and span)
/// - `E000`: any other error, reported with its display text
/// - `E004`: the input parsed, but one of the expressions is not a statement
/// - `E005`: strict-syntax violation (only when `options.strict_syntax` is set)
///
/// Validation stops at the first problem, so the result carries at most one
/// error.
#[cfg(feature = "semantic")]
pub fn validate_with_options(
    sql: &str,
    dialect: DialectType,
    options: &ValidationOptions,
) -> ValidationResult {
    let d = Dialect::get(dialect);
    match d.parse(sql) {
        Ok(expressions) => {
            // The parser can accept bare expressions; only statements are valid here.
            if expressions.iter().any(|expr| !expr.is_statement()) {
                return ValidationResult::with_errors(vec![ValidationError::error(
                    "Invalid expression / Unexpected token".to_string(),
                    "E004",
                )]);
            }
            if options.strict_syntax {
                if let Some(error) = strict_syntax_error(sql, &d) {
                    return ValidationResult::with_errors(vec![error]);
                }
            }
            ValidationResult::success()
        }
        Err(e) => {
            // The error is consumed here, so its message is moved out rather than cloned.
            let error = match e {
                Error::Syntax {
                    message,
                    line,
                    column,
                    start,
                    end,
                } => ValidationError::error(message, "E001")
                    .with_location(line, column)
                    .with_span(Some(start), Some(end)),
                Error::Tokenize {
                    message,
                    line,
                    column,
                    start,
                    end,
                } => ValidationError::error(message, "E002")
                    .with_location(line, column)
                    .with_span(Some(start), Some(end)),
                Error::Parse {
                    message,
                    line,
                    column,
                    start,
                    end,
                } => ValidationError::error(message, "E003")
                    .with_location(line, column)
                    .with_span(Some(start), Some(end)),
                other => ValidationError::error(other.to_string(), "E000"),
            };
            ValidationResult::with_errors(vec![error])
        }
    }
}
654
/// Finds the first trailing comma that sits directly before a clause boundary.
///
/// A comma is a violation when the token right after it is one of `FROM`,
/// `WHERE`, `GROUP BY`, `HAVING`, `ORDER BY`, `LIMIT`, `OFFSET`, `UNION`,
/// `INTERSECT`, `EXCEPT`, `QUALIFY`, `WINDOW`, a semicolon, or when the comma
/// is the last token. The resulting error has code `E005` and the comma's
/// line and column.
///
/// Returns `None` when no such comma exists or when tokenization fails.
#[cfg(feature = "semantic")]
fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
    let tokens = dialect.tokenize(sql).ok()?;

    tokens.iter().enumerate().find_map(|(position, comma)| {
        if comma.token_type != TokenType::Comma {
            return None;
        }

        // Name of the boundary that follows the comma, if it is one.
        let boundary: Option<&str> = match tokens.get(position + 1).map(|t| t.token_type) {
            Some(TokenType::From) => Some("FROM"),
            Some(TokenType::Where) => Some("WHERE"),
            Some(TokenType::GroupBy) => Some("GROUP BY"),
            Some(TokenType::Having) => Some("HAVING"),
            Some(TokenType::Order) | Some(TokenType::OrderBy) => Some("ORDER BY"),
            Some(TokenType::Limit) => Some("LIMIT"),
            Some(TokenType::Offset) => Some("OFFSET"),
            Some(TokenType::Union) => Some("UNION"),
            Some(TokenType::Intersect) => Some("INTERSECT"),
            Some(TokenType::Except) => Some("EXCEPT"),
            Some(TokenType::Qualify) => Some("QUALIFY"),
            Some(TokenType::Window) => Some("WINDOW"),
            Some(TokenType::Semicolon) | None => Some("end of statement"),
            _ => None,
        };
        let boundary_name = boundary?;

        let message = format!(
            "Trailing comma before {} is not allowed in strict syntax mode",
            boundary_name
        );
        Some(
            ValidationError::error(message, "E005")
                .with_location(comma.span.line, comma.span.column),
        )
    })
}
696
/// Transpiles `sql` between two dialects identified by name, using default
/// transpile options.
///
/// Equivalent to [`transpile_with_by_name`] with `TranspileOptions::default()`.
///
/// # Errors
///
/// Returns a parse error with the message `Unknown dialect: <name>` when
/// either name cannot be resolved, or the error from transpilation itself.
#[cfg(feature = "transpile")]
pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
    let default_opts = TranspileOptions::default();
    transpile_with_by_name(sql, read, write, &default_opts)
}
713
/// Transpiles `sql` between two dialects identified by name, with explicit options.
///
/// The `read` name is resolved before the `write` name, so an unknown `read`
/// dialect is reported first.
///
/// # Errors
///
/// Returns a parse error with the message `Unknown dialect: <name>` (position
/// fields set to zero) when a name cannot be resolved, or the error reported
/// by `Dialect::transpile_with`.
#[cfg(feature = "transpile")]
pub fn transpile_with_by_name(
    sql: &str,
    read: &str,
    write: &str,
    opts: &TranspileOptions,
) -> Result<Vec<String>> {
    let unknown_dialect =
        |name: &str| Error::parse(format!("Unknown dialect: {}", name), 0, 0, 0, 0);

    let source = Dialect::get_by_name(read).ok_or_else(|| unknown_dialect(read))?;
    let target = Dialect::get_by_name(write).ok_or_else(|| unknown_dialect(write))?;
    source.transpile_with(sql, &target, opts.clone())
}
730
731pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
735 let d = Dialect::get_by_name(dialect)
736 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
737 d.parse(sql)
738}
739
/// Generates the SQL text of `expression` for the dialect identified by name.
///
/// # Errors
///
/// Returns a parse error with the message `Unknown dialect: <name>` (position
/// fields set to zero) when the name cannot be resolved, or the generator's
/// error otherwise.
#[cfg(feature = "generate")]
pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
    match Dialect::get_by_name(dialect) {
        Some(target) => target.generate(expression),
        None => Err(Error::parse(
            format!("Unknown dialect: {}", dialect),
            0,
            0,
            0,
            0,
        )),
    }
}
749
/// Pretty-prints `sql` for the dialect identified by name, using the default
/// guard limits.
///
/// Equivalent to [`format_with_options_by_name`] with
/// `FormatGuardOptions::default()`.
///
/// # Errors
///
/// Returns an `Unknown dialect: <name>` parse error for an unresolved name,
/// or any guard, tokenizer, parser or generator error from formatting.
#[cfg(feature = "generate")]
pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
    let default_guards = FormatGuardOptions::default();
    format_with_options_by_name(sql, dialect, &default_guards)
}
757
/// Pretty-prints `sql` for the dialect identified by name under
/// caller-supplied guard limits.
///
/// # Errors
///
/// Returns a parse error with the message `Unknown dialect: <name>` (position
/// fields set to zero) when the name cannot be resolved, a guard error
/// (`E_GUARD_*`) when a configured limit is exceeded, or the tokenizer,
/// parser or generator error otherwise.
#[cfg(feature = "generate")]
pub fn format_with_options_by_name(
    sql: &str,
    dialect: &str,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    Dialect::get_by_name(dialect)
        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))
        .and_then(|resolved| format_with_dialect(sql, &resolved, options))
}
769
#[cfg(all(test, feature = "semantic"))]
mod validation_tests {
    use super::*;

    /// Validates `sql` against the generic dialect with strict syntax enabled.
    fn validate_strict(sql: &str) -> ValidationResult {
        let strict = ValidationOptions {
            strict_syntax: true,
        };
        validate_with_options(sql, DialectType::Generic, &strict)
    }

    /// Asserts that `outcome` is invalid and carries an `E005` error.
    fn assert_rejected_with_e005(outcome: &ValidationResult) {
        assert!(!outcome.valid, "Result should be invalid");
        let has_e005 = outcome.errors.iter().any(|e| e.code == "E005");
        assert!(has_e005, "Expected E005, got: {:?}", outcome.errors);
    }

    // Without strict syntax, a trailing comma in the select list is accepted.
    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let outcome = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(outcome.valid, "Result: {:?}", outcome.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        let outcome = validate_strict("SELECT name, FROM employees");
        assert_rejected_with_e005(&outcome);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        let outcome = validate_strict("SELECT name FROM employees, WHERE salary > 10");
        assert_rejected_with_e005(&outcome);
    }
}
816
// Tests for the guarded formatting entry points.
//
// The module itself only requires the `generate` feature. Tests that call
// `transpile` / `transpile_by_name` are additionally gated on the `transpile`
// feature, because those functions are only compiled when it is enabled;
// without the extra gate a `generate`-only test build would not compile.
#[cfg(all(test, feature = "generate"))]
mod format_tests {
    use super::*;

    #[test]
    fn format_basic_query() {
        let result = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
        assert_eq!(result.len(), 1);
        assert!(result[0].contains('\n'));
    }

    #[test]
    fn format_guard_rejects_large_input() {
        // "SELECT 1" is 8 bytes, one more than the limit.
        let options = FormatGuardOptions {
            max_input_bytes: Some(7),
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: None,
        };
        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_INPUT_TOO_LARGE"));
    }

    #[test]
    fn format_guard_rejects_token_budget() {
        let options = FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: Some(1),
            max_ast_nodes: None,
            max_set_op_chain: None,
        };
        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_TOKEN_BUDGET_EXCEEDED"));
    }

    #[test]
    fn format_guard_rejects_ast_budget() {
        let options = FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: Some(1),
            max_set_op_chain: None,
        };
        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_AST_BUDGET_EXCEEDED"));
    }

    #[test]
    fn format_guard_rejects_set_op_chain_budget() {
        // Two UNIONs in one statement against a limit of one.
        let options = FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: Some(1),
        };
        let err = format_with_options(
            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
            DialectType::Generic,
            &options,
        )
        .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }

    #[test]
    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
        // With a limit of zero, any counted set operation would fail the guard.
        let options = FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: Some(0),
        };
        let result = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
        assert!(result.is_ok(), "Result: {:?}", result);
    }

    #[test]
    fn issue57_invalid_ternary_returns_error() {
        let sql = "SELECT x > 0 ? 1 : 0 FROM t";

        let parse_result = parse(sql, DialectType::PostgreSQL);
        assert!(
            parse_result.is_err(),
            "Expected parse error for invalid ternary SQL, got: {:?}",
            parse_result
        );

        let format_result = format(sql, DialectType::PostgreSQL);
        assert!(
            format_result.is_err(),
            "Expected format error for invalid ternary SQL, got: {:?}",
            format_result
        );
    }

    // Transpile half of the issue57 check, split out so it can be gated on
    // the `transpile` feature.
    #[cfg(feature = "transpile")]
    #[test]
    fn issue57_invalid_ternary_transpile_returns_error() {
        let sql = "SELECT x > 0 ? 1 : 0 FROM t";

        let transpile_result = transpile(sql, DialectType::PostgreSQL, DialectType::PostgreSQL);
        assert!(
            transpile_result.is_err(),
            "Expected transpile error for invalid ternary SQL, got: {:?}",
            transpile_result
        );
    }

    #[cfg(feature = "transpile")]
    #[test]
    fn transpile_applies_cross_dialect_rewrites() {
        let out = transpile(
            "SELECT to_timestamp(col) FROM t",
            DialectType::DuckDB,
            DialectType::Trino,
        )
        .expect("transpile failed");
        assert_eq!(out[0], "SELECT FROM_UNIXTIME(col) FROM t");

        let out = transpile(
            "SELECT CAST(col AS JSON) FROM t",
            DialectType::DuckDB,
            DialectType::Trino,
        )
        .expect("transpile failed");
        assert_eq!(out[0], "SELECT JSON_PARSE(col) FROM t");
    }

    // The free functions must produce exactly what `Dialect::transpile` does.
    #[cfg(feature = "transpile")]
    #[test]
    fn transpile_matches_dialect_method() {
        let cases: &[(DialectType, DialectType, &str, &str, &str)] = &[
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT to_timestamp(col) FROM t",
            ),
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT CAST(col AS JSON) FROM t",
            ),
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT json_valid(col) FROM t",
            ),
            (
                DialectType::Snowflake,
                DialectType::DuckDB,
                "snowflake",
                "duckdb",
                "SELECT DATEDIFF(day, a, b) FROM t",
            ),
            (
                DialectType::BigQuery,
                DialectType::DuckDB,
                "bigquery",
                "duckdb",
                "SELECT DATE_DIFF(a, b, DAY) FROM t",
            ),
            (
                DialectType::Generic,
                DialectType::Generic,
                "generic",
                "generic",
                "SELECT 1",
            ),
        ];
        for (read, write, read_name, write_name, sql) in cases {
            let via_lib = transpile(sql, *read, *write).expect("lib::transpile failed");
            let via_name = transpile_by_name(sql, read_name, write_name)
                .expect("lib::transpile_by_name failed");
            let via_dialect = Dialect::get(*read)
                .transpile(sql, *write)
                .expect("Dialect::transpile failed");
            assert_eq!(
                via_lib, via_dialect,
                "lib::transpile / Dialect::transpile diverged for {:?} -> {:?}: {sql}",
                read, write
            );
            assert_eq!(
                via_name, via_dialect,
                "lib::transpile_by_name / Dialect::transpile diverged for {read_name} -> {write_name}: {sql}"
            );
        }
    }

    #[test]
    fn format_default_guard_rejects_deep_union_chain_before_parse() {
        // 1100 UNIONs exceed the default chain limit of 256.
        let base = "SELECT col0, col1 FROM t";
        let mut sql = base.to_string();
        for _ in 0..1100 {
            sql.push_str(" UNION ALL ");
            sql.push_str(base);
        }

        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }
}
1029}