1#[cfg(any(feature = "ast-tools", feature = "generate", feature = "semantic"))]
16pub mod ast_transforms;
17#[cfg(feature = "builder")]
18pub mod builder;
19pub mod dialects;
20#[cfg(feature = "diff")]
21pub mod diff;
22pub mod error;
23pub mod expressions;
24#[cfg(feature = "semantic")]
25pub mod function_catalog;
26mod function_registry;
27#[cfg(feature = "generate")]
28pub mod generator;
29#[cfg(feature = "semantic")]
30pub mod helper;
31#[cfg(feature = "semantic")]
32pub mod lineage;
33#[cfg(feature = "openlineage")]
34pub mod openlineage;
35#[cfg(feature = "semantic")]
36pub mod optimizer;
37pub mod parser;
38#[cfg(feature = "planner")]
39pub mod planner;
40#[cfg(all(feature = "semantic", feature = "generate"))]
41pub mod query_analysis;
42#[cfg(feature = "semantic")]
43pub mod resolver;
44#[cfg(feature = "semantic")]
45pub mod schema;
46#[cfg(feature = "semantic")]
47pub mod scope;
48#[cfg(feature = "time")]
49pub mod time;
50pub mod tokens;
51#[cfg(feature = "transpile")]
52pub mod transforms;
53#[cfg(any(feature = "ast-tools", feature = "generate", feature = "semantic"))]
54pub mod traversal;
55#[cfg(any(feature = "semantic", feature = "time"))]
56pub mod trie;
57#[cfg(feature = "semantic")]
58pub mod validation;
59
60#[cfg(any(feature = "generate", feature = "semantic"))]
61use serde::{Deserialize, Serialize};
62
63#[cfg(feature = "ast-tools")]
64pub use ast_transforms::{
65 add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
66 get_identifiers, get_literals, get_output_column_names, get_subqueries, get_table_names,
67 get_window_functions, node_count, qualify_columns, remove_limit_offset, remove_nodes,
68 remove_select_columns, remove_where, rename_columns, rename_tables, rename_tables_with_options,
69 replace_by_type, replace_nodes, set_distinct, set_limit, set_offset, RenameTablesOptions,
70};
71pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
72#[cfg(feature = "transpile")]
73pub use dialects::{TranspileOptions, TranspileTarget};
74pub use error::{Error, Result};
75#[cfg(feature = "semantic")]
76pub use error::{ValidationError, ValidationResult, ValidationSeverity};
77pub use expressions::{DataType, Expression};
78#[cfg(feature = "semantic")]
79pub use function_catalog::{
80 FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
81};
82#[cfg(feature = "generate")]
83pub use generator::{Generator, UnsupportedLevel};
84#[cfg(feature = "semantic")]
85pub use helper::{
86 csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
87 name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
88};
89#[cfg(feature = "semantic")]
90pub use optimizer::{
91 annotate_types, qualify_tables, QualifyTablesOptions, TypeAnnotator, TypeCoercionClass,
92};
93pub use parser::Parser;
94#[cfg(all(feature = "semantic", feature = "generate"))]
95pub use query_analysis::{
96 analyze_query, AnalyzeQueryOptions, ColumnReferenceFact, ProjectionFact, QueryAnalysis,
97 QueryShape, ReferenceConfidence, RelationFact, SetOperationBranchFact, SetOperationFact,
98 TransformKind,
99};
100#[cfg(feature = "semantic")]
101pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
102#[cfg(feature = "semantic")]
103pub use schema::{
104 ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
105};
106#[cfg(feature = "semantic")]
107pub use scope::{
108 build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
109 ScopeType, SourceInfo,
110};
111#[cfg(feature = "time")]
112pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
113pub use tokens::{Token, TokenType, Tokenizer};
114#[cfg(feature = "ast-tools")]
115pub use traversal::{
116 contains_aggregate,
117 contains_subquery,
118 contains_window_function,
119 find_ancestor,
120 find_parent,
121 get_all_tables,
122 get_columns,
123 get_merge_source,
124 get_merge_target,
125 get_tables,
126 is_add,
127 is_aggregate,
128 is_alias,
129 is_alter_table,
130 is_and,
131 is_arithmetic,
132 is_avg,
133 is_between,
134 is_boolean,
135 is_case,
136 is_cast,
137 is_coalesce,
138 is_column,
139 is_comparison,
140 is_concat,
141 is_count,
142 is_create_index,
143 is_create_table,
144 is_create_view,
145 is_cte,
146 is_ddl,
147 is_delete,
148 is_div,
149 is_drop_index,
150 is_drop_table,
151 is_drop_view,
152 is_eq,
153 is_except,
154 is_exists,
155 is_from,
156 is_function,
157 is_group_by,
158 is_gt,
159 is_gte,
160 is_having,
161 is_identifier,
162 is_ilike,
163 is_in,
164 is_insert,
166 is_intersect,
167 is_is_null,
168 is_join,
169 is_like,
170 is_limit,
171 is_literal,
172 is_logical,
173 is_lt,
174 is_lte,
175 is_max_func,
176 is_merge,
177 is_min_func,
178 is_mod,
179 is_mul,
180 is_neq,
181 is_not,
182 is_null_if,
183 is_null_literal,
184 is_offset,
185 is_or,
186 is_order_by,
187 is_ordered,
188 is_paren,
189 is_query,
191 is_safe_cast,
192 is_select,
193 is_set_operation,
194 is_star,
195 is_sub,
196 is_subquery,
197 is_sum,
198 is_table,
199 is_try_cast,
200 is_union,
201 is_update,
202 is_where,
203 is_window_function,
204 is_with,
205 transform,
206 transform_map,
207 BfsIter,
208 DfsIter,
209 ExpressionWalk,
210 ParentInfo,
211 TreeContext,
212};
213#[cfg(any(feature = "semantic", feature = "time"))]
214pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
215#[cfg(feature = "semantic")]
216pub use validation::{
217 mapping_schema_from_validation_schema, validate_with_schema, SchemaColumn,
218 SchemaColumnReference, SchemaForeignKey, SchemaTable, SchemaTableReference,
219 SchemaValidationOptions, ValidationSchema,
220};
221
/// Default cap on the raw SQL input accepted by the `format*` functions: 16 MiB.
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 << 20;
/// Default cap on the number of tokens produced from the input.
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
/// Default cap on the total AST node count across all parsed statements.
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
/// Default cap on set-operation keywords within a single statement.
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;

// `#[serde(default = "...")]` needs a function path, so each limit gets a
// tiny constructor that wraps its constant as an enabled (`Some`) limit.

#[cfg(feature = "generate")]
fn default_format_max_input_bytes() -> Option<usize> {
    Option::from(DEFAULT_FORMAT_MAX_INPUT_BYTES)
}

#[cfg(feature = "generate")]
fn default_format_max_tokens() -> Option<usize> {
    Option::from(DEFAULT_FORMAT_MAX_TOKENS)
}

#[cfg(feature = "generate")]
fn default_format_max_ast_nodes() -> Option<usize> {
    Option::from(DEFAULT_FORMAT_MAX_AST_NODES)
}

#[cfg(feature = "generate")]
fn default_format_max_set_op_chain() -> Option<usize> {
    Option::from(DEFAULT_FORMAT_MAX_SET_OP_CHAIN)
}
250
#[cfg(feature = "generate")]
/// Resource limits enforced by the `format*` functions around parsing.
///
/// Each limit is optional: `None` disables that particular check. The
/// `Default` impl enables all four with the built-in limits (16 MiB of input,
/// 1,000,000 tokens, 1,000,000 AST nodes, 256 set operations per statement).
/// When deserializing (camelCase keys), a missing field takes the same
/// built-in default. A tripped limit surfaces as an error whose message
/// starts with the `E_GUARD_*` code noted on the corresponding field.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FormatGuardOptions {
    /// Maximum length of the raw SQL text in bytes, checked before
    /// tokenizing (`E_GUARD_INPUT_TOO_LARGE`).
    #[serde(default = "default_format_max_input_bytes")]
    pub max_input_bytes: Option<usize>,
    /// Maximum number of tokens produced by the dialect tokenizer, checked
    /// before parsing (`E_GUARD_TOKEN_BUDGET_EXCEEDED`).
    #[serde(default = "default_format_max_tokens")]
    pub max_tokens: Option<usize>,
    /// Maximum AST node count summed over all parsed statements, checked
    /// after parsing and before generation (`E_GUARD_AST_BUDGET_EXCEEDED`).
    #[serde(default = "default_format_max_ast_nodes")]
    pub max_ast_nodes: Option<usize>,
    /// Maximum number of set-operation keywords (`UNION`, `INTERSECT`,
    /// `EXCEPT`/`MINUS`) within one `;`-separated statement, counted on the
    /// token stream before parsing (`E_GUARD_SET_OP_CHAIN_EXCEEDED`). This is
    /// a flat per-statement count, not a nesting depth.
    #[serde(default = "default_format_max_set_op_chain")]
    pub max_set_op_chain: Option<usize>,
}
278
#[cfg(feature = "generate")]
impl Default for FormatGuardOptions {
    /// Every guard enabled at its built-in limit.
    fn default() -> Self {
        FormatGuardOptions {
            max_input_bytes: Some(DEFAULT_FORMAT_MAX_INPUT_BYTES),
            max_tokens: Some(DEFAULT_FORMAT_MAX_TOKENS),
            max_ast_nodes: Some(DEFAULT_FORMAT_MAX_AST_NODES),
            max_set_op_chain: Some(DEFAULT_FORMAT_MAX_SET_OP_CHAIN),
        }
    }
}
290
/// Builds the error returned when a format guard trips.
///
/// The message begins with the guard `code` (for example
/// `E_GUARD_INPUT_TOO_LARGE`), followed by the measured value and the limit
/// it exceeded, so callers can identify the guard from the message text.
#[cfg(feature = "generate")]
fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
    let message = format!("{code}: value {actual} exceeds configured limit {limit}");
    Error::generate(message)
}
297
/// Rejects `sql` when its byte length exceeds `options.max_input_bytes`.
///
/// A `None` limit disables the check.
///
/// # Errors
/// Returns an `E_GUARD_INPUT_TOO_LARGE` guard error when the input is too long.
#[cfg(feature = "generate")]
fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
    match options.max_input_bytes {
        Some(limit) if sql.len() > limit => Err(format_guard_error(
            "E_GUARD_INPUT_TOO_LARGE",
            sql.len(),
            limit,
        )),
        _ => Ok(()),
    }
}
312
/// Tokenizes `sql`, applies the token-count and set-operation guards to the
/// raw token stream, and only then hands the tokens to the parser.
///
/// Both token-level guards run before any AST is built, so an oversized or
/// heavily chained query is rejected without paying for a parse.
///
/// # Errors
/// Tokenizer and parser errors propagate unchanged. A tripped guard yields
/// `E_GUARD_TOKEN_BUDGET_EXCEEDED` or `E_GUARD_SET_OP_CHAIN_EXCEEDED`.
#[cfg(feature = "generate")]
fn parse_with_token_guard(
    sql: &str,
    dialect: &Dialect,
    options: &FormatGuardOptions,
) -> Result<Vec<Expression>> {
    let tokens = dialect.tokenize(sql)?;

    if let Some(limit) = options.max_tokens.filter(|&limit| tokens.len() > limit) {
        return Err(format_guard_error(
            "E_GUARD_TOKEN_BUDGET_EXCEEDED",
            tokens.len(),
            limit,
        ));
    }
    enforce_set_op_chain_guard(&tokens, options)?;

    let mut parser = Parser::with_source(
        tokens,
        crate::parser::ParserConfig {
            dialect: Some(dialect.dialect_type()),
            ..Default::default()
        },
        sql.to_owned(),
    );
    parser.parse()
}
339
/// Returns `true` for whitespace, line-break and comment tokens. These are
/// skipped when looking ahead for the next meaningful token.
#[cfg(feature = "generate")]
fn is_trivia_token(token_type: TokenType) -> bool {
    [
        TokenType::Space,
        TokenType::Break,
        TokenType::LineComment,
        TokenType::BlockComment,
    ]
    .contains(&token_type)
}
347
/// Returns the first non-trivia token at index `start` or later. Returns
/// `None` when only trivia remains or `start` is past the end.
#[cfg(feature = "generate")]
fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
    tokens
        .get(start..)?
        .iter()
        .find(|candidate| !is_trivia_token(candidate.token_type))
}
355
/// Returns `true` when the token at `idx` is a set-operation keyword: `UNION`,
/// `INTERSECT`, or an `Except`-typed token (`EXCEPT`, or `MINUS` in dialects
/// that tokenize it that way).
///
/// Two uses of an `Except`-typed token are *not* set operations. They are
/// excluded so the set-operation guard does not count them:
/// - `minus(...)`: a function call, e.g. ClickHouse's `minus(3, 2)`;
/// - `* EXCEPT (...)` / `t.* EXCEPT (...)`: star column exclusion, e.g.
///   BigQuery-style `SELECT * EXCEPT (col) FROM t`.
///
/// # Panics
/// Panics if `idx` is out of bounds; callers pass indices taken from `tokens`.
#[cfg(feature = "generate")]
fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
    let token = &tokens[idx];
    match token.token_type {
        TokenType::Union | TokenType::Intersect => true,
        TokenType::Except => {
            let minus_function_call = token.text.eq_ignore_ascii_case("minus")
                && matches!(
                    next_significant_token(tokens, idx + 1).map(|t| t.token_type),
                    Some(TokenType::LParen)
                );
            // `SELECT * EXCEPT (...)` drops columns from the star expansion; a
            // set operation never directly follows a bare `*`.
            let star_column_exclusion = matches!(
                tokens[..idx]
                    .iter()
                    .rev()
                    .find(|t| !is_trivia_token(t.token_type))
                    .map(|t| t.token_type),
                Some(TokenType::Star)
            );
            !(minus_function_call || star_column_exclusion)
        }
        _ => false,
    }
}
377
/// Rejects any statement that contains more than `options.max_set_op_chain`
/// set-operation keywords (see `is_set_operation_token`).
///
/// The count restarts at every `;` token, so each statement has its own
/// budget. It is a flat count over one linear pass of the tokens, not a
/// nesting depth. A `None` limit disables the check.
///
/// # Errors
/// Returns an `E_GUARD_SET_OP_CHAIN_EXCEEDED` guard error on the first
/// keyword that pushes a statement over the limit.
#[cfg(feature = "generate")]
fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
    let max = match options.max_set_op_chain {
        Some(max) => max,
        None => return Ok(()),
    };

    tokens
        .iter()
        .enumerate()
        .try_fold(0usize, |running, (idx, token)| {
            if token.token_type == TokenType::Semicolon {
                // A new statement starts with a fresh budget.
                return Ok(0);
            }
            if !is_set_operation_token(tokens, idx) {
                return Ok(running);
            }
            let running = running + 1;
            if running > max {
                Err(format_guard_error(
                    "E_GUARD_SET_OP_CHAIN_EXCEEDED",
                    running,
                    max,
                ))
            } else {
                Ok(running)
            }
        })
        .map(|_| ())
}
405
/// Rejects the parsed statements when their combined AST node count exceeds
/// `options.max_ast_nodes`. A `None` limit disables the check.
///
/// # Errors
/// Returns an `E_GUARD_AST_BUDGET_EXCEEDED` guard error when over budget.
#[cfg(feature = "generate")]
fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
    let Some(max) = options.max_ast_nodes else {
        return Ok(());
    };

    let mut total = 0usize;
    for expression in expressions {
        total += crate::ast_transforms::node_count(expression);
    }

    if total > max {
        Err(format_guard_error("E_GUARD_AST_BUDGET_EXCEEDED", total, max))
    } else {
        Ok(())
    }
}
423
/// Shared implementation behind every `format*` entry point.
///
/// It runs the guards in order: input size, then token budget and set-op
/// chain (inside `parse_with_token_guard`), then AST size. It then
/// pretty-prints each parsed statement with `dialect`, stopping at the first
/// generation error.
#[cfg(feature = "generate")]
fn format_with_dialect(
    sql: &str,
    dialect: &Dialect,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    enforce_input_guard(sql, options)?;
    let statements = parse_with_token_guard(sql, dialect, options)?;
    enforce_ast_guard(&statements, options)?;

    let mut formatted = Vec::with_capacity(statements.len());
    for statement in &statements {
        formatted.push(dialect.generate_pretty(statement)?);
    }
    Ok(formatted)
}
439
/// Transpiles `sql` written in the `read` dialect into the `write` dialect.
///
/// Delegates to `Dialect::transpile` on the `read` dialect.
#[cfg(feature = "transpile")]
pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
    let reader = Dialect::get(read);
    reader.transpile(sql, write)
}
468
469pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
478 let d = Dialect::get(dialect);
479 d.parse(sql)
480}
481
482pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
491 let mut expressions = parse(sql, dialect)?;
492
493 if expressions.len() != 1 {
494 return Err(Error::parse(
495 format!("Expected 1 statement, found {}", expressions.len()),
496 0,
497 0,
498 0,
499 0,
500 ));
501 }
502
503 Ok(expressions.remove(0))
504}
505
506pub fn parse_data_type(sql: &str, dialect: DialectType) -> Result<DataType> {
515 Dialect::get(dialect).parse_data_type(sql)
516}
517
/// Renders `data_type` as SQL text for `dialect` by wrapping it in an
/// `Expression::DataType` and generating that expression.
#[cfg(feature = "generate")]
pub fn generate_data_type(data_type: &DataType, dialect: DialectType) -> Result<String> {
    let d = Dialect::get(dialect);
    let wrapped = Expression::DataType(data_type.clone());
    d.generate(&wrapped)
}
530
/// Generates SQL text for `expression` in `dialect`.
#[cfg(feature = "generate")]
pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
    Dialect::get(dialect).generate(expression)
}
544
/// Pretty-prints `sql` in `dialect` using the default `FormatGuardOptions`.
#[cfg(feature = "generate")]
pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
    let defaults = FormatGuardOptions::default();
    format_with_options(sql, dialect, &defaults)
}
552
/// Pretty-prints `sql` in `dialect`, enforcing the limits in `options`.
#[cfg(feature = "generate")]
pub fn format_with_options(
    sql: &str,
    dialect: DialectType,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    format_with_dialect(sql, &Dialect::get(dialect), options)
}
563
/// Validates `sql` in `dialect` with the default (permissive) `ValidationOptions`.
#[cfg(feature = "semantic")]
pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
    let defaults = ValidationOptions::default();
    validate_with_options(sql, dialect, &defaults)
}
576
#[cfg(feature = "semantic")]
/// Options for `validate_with_options`.
///
/// Deserializes from camelCase keys (`strictSyntax`). A missing field falls
/// back to its `Default` value.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ValidationOptions {
    /// When `true`, SQL that parses successfully is still rejected with code
    /// `E005` if a comma directly precedes a clause boundary such as `FROM`
    /// or `WHERE`. Defaults to `false`, which accepts such trailing commas.
    #[serde(default)]
    pub strict_syntax: bool,
}
587
/// Validates `sql` in `dialect` and reports problems as a `ValidationResult`
/// rather than an error.
///
/// Checks run in order, and only the first problem found is reported:
/// 1. The SQL must parse. Failures map to `E001` (syntax), `E002` (tokenize),
///    `E003` (parse) with location and span, or `E000` for any other error.
/// 2. Every parsed top-level expression must be a statement (`E004`).
/// 3. With `options.strict_syntax`, no trailing comma may precede a clause
///    boundary (`E005`).
#[cfg(feature = "semantic")]
pub fn validate_with_options(
    sql: &str,
    dialect: DialectType,
    options: &ValidationOptions,
) -> ValidationResult {
    let d = Dialect::get(dialect);

    let expressions = match d.parse(sql) {
        Ok(expressions) => expressions,
        Err(e) => {
            return ValidationResult::with_errors(vec![validation_error_from_parse_error(&e)]);
        }
    };

    if expressions.iter().any(|expr| !expr.is_statement()) {
        return ValidationResult::with_errors(vec![ValidationError::error(
            String::from("Invalid expression / Unexpected token"),
            "E004",
        )]);
    }

    if options.strict_syntax {
        if let Some(error) = strict_syntax_error(sql, &d) {
            return ValidationResult::with_errors(vec![error]);
        }
    }

    ValidationResult::success()
}

/// Converts a parse-time `Error` into a `ValidationError`, keeping the line,
/// column and byte span for the positioned error kinds.
#[cfg(feature = "semantic")]
fn validation_error_from_parse_error(e: &Error) -> ValidationError {
    match e {
        Error::Syntax {
            message,
            line,
            column,
            start,
            end,
        } => ValidationError::error(message.clone(), "E001")
            .with_location(*line, *column)
            .with_span(Some(*start), Some(*end)),
        Error::Tokenize {
            message,
            line,
            column,
            start,
            end,
        } => ValidationError::error(message.clone(), "E002")
            .with_location(*line, *column)
            .with_span(Some(*start), Some(*end)),
        Error::Parse {
            message,
            line,
            column,
            start,
            end,
        } => ValidationError::error(message.clone(), "E003")
            .with_location(*line, *column)
            .with_span(Some(*start), Some(*end)),
        _ => ValidationError::error(e.to_string(), "E000"),
    }
}
651
/// Strict-mode trailing-comma check.
///
/// It finds the first comma token whose immediately following token is a
/// clause boundary: `FROM`, `WHERE`, `GROUP BY`, `HAVING`, `ORDER`/`ORDER BY`,
/// `LIMIT`, `OFFSET`, `UNION`, `INTERSECT`, `EXCEPT`, `QUALIFY`, `WINDOW`,
/// `;`, or end of input. That comma is reported as an `E005` error located
/// at the comma.
///
/// Returns `None` when no such comma exists or when `sql` fails to tokenize.
#[cfg(feature = "semantic")]
fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
    let tokens = dialect.tokenize(sql).ok()?;

    tokens.iter().enumerate().find_map(|(idx, comma)| {
        if comma.token_type != TokenType::Comma {
            return None;
        }

        let boundary = match tokens.get(idx + 1).map(|next| next.token_type) {
            Some(TokenType::From) => "FROM",
            Some(TokenType::Where) => "WHERE",
            Some(TokenType::GroupBy) => "GROUP BY",
            Some(TokenType::Having) => "HAVING",
            Some(TokenType::Order) | Some(TokenType::OrderBy) => "ORDER BY",
            Some(TokenType::Limit) => "LIMIT",
            Some(TokenType::Offset) => "OFFSET",
            Some(TokenType::Union) => "UNION",
            Some(TokenType::Intersect) => "INTERSECT",
            Some(TokenType::Except) => "EXCEPT",
            Some(TokenType::Qualify) => "QUALIFY",
            Some(TokenType::Window) => "WINDOW",
            Some(TokenType::Semicolon) | None => "end of statement",
            _ => return None,
        };

        let message = format!(
            "Trailing comma before {} is not allowed in strict syntax mode",
            boundary
        );
        Some(
            ValidationError::error(message, "E005")
                .with_location(comma.span.line, comma.span.column),
        )
    })
}
693
/// Like `transpile`, but takes dialect names and uses default `TranspileOptions`.
#[cfg(feature = "transpile")]
pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
    let opts = TranspileOptions::default();
    transpile_with_by_name(sql, read, write, &opts)
}
710
/// Transpiles `sql` between dialects identified by name, using `opts`.
///
/// # Errors
/// Returns a parse error `"Unknown dialect: <name>"` if either name is
/// unrecognised; `read` is checked first. Transpilation errors propagate.
#[cfg(feature = "transpile")]
pub fn transpile_with_by_name(
    sql: &str,
    read: &str,
    write: &str,
    opts: &TranspileOptions,
) -> Result<Vec<String>> {
    let lookup = |name: &str| {
        Dialect::get_by_name(name)
            .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", name), 0, 0, 0, 0))
    };
    let source = lookup(read)?;
    let target = lookup(write)?;
    source.transpile_with(sql, &target, opts.clone())
}
727
728pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
732 let d = Dialect::get_by_name(dialect)
733 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
734 d.parse(sql)
735}
736
/// Like `generate`, but looks the dialect up by name.
///
/// # Errors
/// Returns a parse error `"Unknown dialect: <name>"` for an unrecognised name.
#[cfg(feature = "generate")]
pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
    let Some(d) = Dialect::get_by_name(dialect) else {
        return Err(Error::parse(
            format!("Unknown dialect: {}", dialect),
            0,
            0,
            0,
            0,
        ));
    };
    d.generate(expression)
}
746
/// Like `format`, but looks the dialect up by name; uses default guard limits.
#[cfg(feature = "generate")]
pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
    let defaults = FormatGuardOptions::default();
    format_with_options_by_name(sql, dialect, &defaults)
}
754
/// Like `format_with_options`, but looks the dialect up by name.
///
/// # Errors
/// Returns a parse error `"Unknown dialect: <name>"` for an unrecognised name,
/// otherwise any guard, parse or generation error.
#[cfg(feature = "generate")]
pub fn format_with_options_by_name(
    sql: &str,
    dialect: &str,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    match Dialect::get_by_name(dialect) {
        Some(d) => format_with_dialect(sql, &d, options),
        None => Err(Error::parse(
            format!("Unknown dialect: {}", dialect),
            0,
            0,
            0,
            0,
        )),
    }
}
766
#[cfg(all(test, feature = "semantic"))]
mod validation_tests {
    use super::*;

    /// Validates `sql` against the Generic dialect with strict syntax enabled.
    fn validate_strict(sql: &str) -> ValidationResult {
        let options = ValidationOptions {
            strict_syntax: true,
        };
        validate_with_options(sql, DialectType::Generic, &options)
    }

    /// Asserts `result` is invalid and carries the strict trailing-comma code.
    fn assert_trailing_comma_rejected(result: &ValidationResult) {
        assert!(!result.valid, "Result should be invalid");
        assert!(
            result.errors.iter().any(|e| e.code == "E005"),
            "Expected E005, got: {:?}",
            result.errors
        );
    }

    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let result = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(result.valid, "Result: {:?}", result.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        let result = validate_strict("SELECT name, FROM employees");
        assert_trailing_comma_rejected(&result);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        let result = validate_strict("SELECT name FROM employees, WHERE salary > 10");
        assert_trailing_comma_rejected(&result);
    }
}
813
// This module compiles whenever `generate` is enabled. Tests that call
// `transpile`, `transpile_by_name` or `Dialect::transpile` are additionally
// gated on `transpile`, so `--features generate` alone still builds.
#[cfg(all(test, feature = "generate"))]
mod format_tests {
    use super::*;

    #[test]
    fn format_basic_query() {
        let result = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
        assert_eq!(result.len(), 1);
        assert!(result[0].contains('\n'));
    }

    #[test]
    fn format_guard_rejects_large_input() {
        let options = FormatGuardOptions {
            max_input_bytes: Some(7),
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: None,
        };
        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_INPUT_TOO_LARGE"));
    }

    #[test]
    fn format_guard_rejects_token_budget() {
        let options = FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: Some(1),
            max_ast_nodes: None,
            max_set_op_chain: None,
        };
        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_TOKEN_BUDGET_EXCEEDED"));
    }

    #[test]
    fn format_guard_rejects_ast_budget() {
        let options = FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: Some(1),
            max_set_op_chain: None,
        };
        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_AST_BUDGET_EXCEEDED"));
    }

    #[test]
    fn format_guard_rejects_set_op_chain_budget() {
        let options = FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: Some(1),
        };
        let err = format_with_options(
            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
            DialectType::Generic,
            &options,
        )
        .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }

    #[test]
    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
        let options = FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: Some(0),
        };
        let result = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
        assert!(result.is_ok(), "Result: {:?}", result);
    }

    /// Invalid `? :` ternary syntax must be rejected by `parse` and `format`.
    #[test]
    fn issue57_invalid_ternary_returns_error() {
        let sql = "SELECT x > 0 ? 1 : 0 FROM t";

        let parse_result = parse(sql, DialectType::PostgreSQL);
        assert!(
            parse_result.is_err(),
            "Expected parse error for invalid ternary SQL, got: {:?}",
            parse_result
        );

        let format_result = format(sql, DialectType::PostgreSQL);
        assert!(
            format_result.is_err(),
            "Expected format error for invalid ternary SQL, got: {:?}",
            format_result
        );
    }

    /// Transpile half of the issue-57 check; split out so it can be gated on
    /// the `transpile` feature.
    #[cfg(feature = "transpile")]
    #[test]
    fn issue57_invalid_ternary_transpile_returns_error() {
        let sql = "SELECT x > 0 ? 1 : 0 FROM t";

        let transpile_result = transpile(sql, DialectType::PostgreSQL, DialectType::PostgreSQL);
        assert!(
            transpile_result.is_err(),
            "Expected transpile error for invalid ternary SQL, got: {:?}",
            transpile_result
        );
    }

    #[cfg(feature = "transpile")]
    #[test]
    fn transpile_applies_cross_dialect_rewrites() {
        let out = transpile(
            "SELECT to_timestamp(col) FROM t",
            DialectType::DuckDB,
            DialectType::Trino,
        )
        .expect("transpile failed");
        assert_eq!(out[0], "SELECT FROM_UNIXTIME(col) FROM t");

        let out = transpile(
            "SELECT CAST(col AS JSON) FROM t",
            DialectType::DuckDB,
            DialectType::Trino,
        )
        .expect("transpile failed");
        assert_eq!(out[0], "SELECT JSON_PARSE(col) FROM t");
    }

    #[cfg(feature = "transpile")]
    #[test]
    fn transpile_matches_dialect_method() {
        let cases: &[(DialectType, DialectType, &str, &str, &str)] = &[
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT to_timestamp(col) FROM t",
            ),
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT CAST(col AS JSON) FROM t",
            ),
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT json_valid(col) FROM t",
            ),
            (
                DialectType::Snowflake,
                DialectType::DuckDB,
                "snowflake",
                "duckdb",
                "SELECT DATEDIFF(day, a, b) FROM t",
            ),
            (
                DialectType::BigQuery,
                DialectType::DuckDB,
                "bigquery",
                "duckdb",
                "SELECT DATE_DIFF(a, b, DAY) FROM t",
            ),
            (
                DialectType::Generic,
                DialectType::Generic,
                "generic",
                "generic",
                "SELECT 1",
            ),
        ];
        for (read, write, read_name, write_name, sql) in cases {
            let via_lib = transpile(sql, *read, *write).expect("lib::transpile failed");
            let via_name = transpile_by_name(sql, read_name, write_name)
                .expect("lib::transpile_by_name failed");
            let via_dialect = Dialect::get(*read)
                .transpile(sql, *write)
                .expect("Dialect::transpile failed");
            assert_eq!(
                via_lib, via_dialect,
                "lib::transpile / Dialect::transpile diverged for {:?} -> {:?}: {sql}",
                read, write
            );
            assert_eq!(
                via_name, via_dialect,
                "lib::transpile_by_name / Dialect::transpile diverged for {read_name} -> {write_name}: {sql}"
            );
        }
    }

    #[test]
    fn format_default_guard_rejects_deep_union_chain_before_parse() {
        let base = "SELECT col0, col1 FROM t";
        let mut sql = base.to_string();
        for _ in 0..1100 {
            sql.push_str(" UNION ALL ");
            sql.push_str(base);
        }

        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }
}