1#[cfg(any(feature = "ast-tools", feature = "generate", feature = "semantic"))]
16pub mod ast_transforms;
17#[cfg(feature = "builder")]
18pub mod builder;
19pub mod dialects;
20#[cfg(feature = "diff")]
21pub mod diff;
22pub mod error;
23pub mod expressions;
24#[cfg(feature = "semantic")]
25pub mod function_catalog;
26mod function_registry;
27#[cfg(feature = "generate")]
28pub mod generator;
29#[cfg(feature = "semantic")]
30pub mod helper;
31#[cfg(feature = "semantic")]
32pub mod lineage;
33#[cfg(feature = "openlineage")]
34pub mod openlineage;
35#[cfg(feature = "semantic")]
36pub mod optimizer;
37pub mod parser;
38#[cfg(feature = "planner")]
39pub mod planner;
40#[cfg(feature = "semantic")]
41pub mod resolver;
42#[cfg(feature = "semantic")]
43pub mod schema;
44#[cfg(feature = "semantic")]
45pub mod scope;
46#[cfg(feature = "time")]
47pub mod time;
48pub mod tokens;
49#[cfg(feature = "transpile")]
50pub mod transforms;
51#[cfg(any(feature = "ast-tools", feature = "generate", feature = "semantic"))]
52pub mod traversal;
53#[cfg(any(feature = "semantic", feature = "time"))]
54pub mod trie;
55#[cfg(feature = "semantic")]
56pub mod validation;
57
58#[cfg(any(feature = "generate", feature = "semantic"))]
59use serde::{Deserialize, Serialize};
60
61#[cfg(feature = "ast-tools")]
62pub use ast_transforms::{
63 add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
64 get_identifiers, get_literals, get_output_column_names, get_subqueries, get_table_names,
65 get_window_functions, node_count, qualify_columns, remove_limit_offset, remove_nodes,
66 remove_select_columns, remove_where, rename_columns, rename_tables, rename_tables_with_options,
67 replace_by_type, replace_nodes, set_distinct, set_limit, set_offset, RenameTablesOptions,
68};
69pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
70#[cfg(feature = "transpile")]
71pub use dialects::{TranspileOptions, TranspileTarget};
72pub use error::{Error, Result};
73#[cfg(feature = "semantic")]
74pub use error::{ValidationError, ValidationResult, ValidationSeverity};
75pub use expressions::Expression;
76#[cfg(feature = "semantic")]
77pub use function_catalog::{
78 FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
79};
80#[cfg(feature = "generate")]
81pub use generator::Generator;
82#[cfg(feature = "semantic")]
83pub use helper::{
84 csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
85 name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
86};
87#[cfg(feature = "semantic")]
88pub use optimizer::{
89 annotate_types, qualify_tables, QualifyTablesOptions, TypeAnnotator, TypeCoercionClass,
90};
91pub use parser::Parser;
92#[cfg(feature = "semantic")]
93pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
94#[cfg(feature = "semantic")]
95pub use schema::{
96 ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
97};
98#[cfg(feature = "semantic")]
99pub use scope::{
100 build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
101 ScopeType, SourceInfo,
102};
103#[cfg(feature = "time")]
104pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
105pub use tokens::{Token, TokenType, Tokenizer};
106#[cfg(feature = "ast-tools")]
107pub use traversal::{
108 contains_aggregate,
109 contains_subquery,
110 contains_window_function,
111 find_ancestor,
112 find_parent,
113 get_all_tables,
114 get_columns,
115 get_merge_source,
116 get_merge_target,
117 get_tables,
118 is_add,
119 is_aggregate,
120 is_alias,
121 is_alter_table,
122 is_and,
123 is_arithmetic,
124 is_avg,
125 is_between,
126 is_boolean,
127 is_case,
128 is_cast,
129 is_coalesce,
130 is_column,
131 is_comparison,
132 is_concat,
133 is_count,
134 is_create_index,
135 is_create_table,
136 is_create_view,
137 is_cte,
138 is_ddl,
139 is_delete,
140 is_div,
141 is_drop_index,
142 is_drop_table,
143 is_drop_view,
144 is_eq,
145 is_except,
146 is_exists,
147 is_from,
148 is_function,
149 is_group_by,
150 is_gt,
151 is_gte,
152 is_having,
153 is_identifier,
154 is_ilike,
155 is_in,
156 is_insert,
158 is_intersect,
159 is_is_null,
160 is_join,
161 is_like,
162 is_limit,
163 is_literal,
164 is_logical,
165 is_lt,
166 is_lte,
167 is_max_func,
168 is_merge,
169 is_min_func,
170 is_mod,
171 is_mul,
172 is_neq,
173 is_not,
174 is_null_if,
175 is_null_literal,
176 is_offset,
177 is_or,
178 is_order_by,
179 is_ordered,
180 is_paren,
181 is_query,
183 is_safe_cast,
184 is_select,
185 is_set_operation,
186 is_star,
187 is_sub,
188 is_subquery,
189 is_sum,
190 is_table,
191 is_try_cast,
192 is_union,
193 is_update,
194 is_where,
195 is_window_function,
196 is_with,
197 transform,
198 transform_map,
199 BfsIter,
200 DfsIter,
201 ExpressionWalk,
202 ParentInfo,
203 TreeContext,
204};
205#[cfg(any(feature = "semantic", feature = "time"))]
206pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
207#[cfg(feature = "semantic")]
208pub use validation::{
209 mapping_schema_from_validation_schema, validate_with_schema, SchemaColumn,
210 SchemaColumnReference, SchemaForeignKey, SchemaTable, SchemaTableReference,
211 SchemaValidationOptions, ValidationSchema,
212};
213
// Built-in limits used by `FormatGuardOptions::default()` and by serde when a
// field is missing from serialized options.

/// Default maximum SQL input size in bytes (16 MiB).
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 * 1024 * 1024;

/// Default maximum number of tokens produced by the tokenizer.
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;

/// Default maximum total AST node count across all parsed statements.
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;

/// Default maximum number of set-operation tokens within one statement.
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;

// serde's `#[serde(default = "path")]` needs a function path, so each limit is
// exposed through a tiny wrapper that returns it as an enabled (`Some`) guard.

#[cfg(feature = "generate")]
fn default_format_max_input_bytes() -> Option<usize> {
    let limit = DEFAULT_FORMAT_MAX_INPUT_BYTES;
    Some(limit)
}

#[cfg(feature = "generate")]
fn default_format_max_tokens() -> Option<usize> {
    let limit = DEFAULT_FORMAT_MAX_TOKENS;
    Some(limit)
}

#[cfg(feature = "generate")]
fn default_format_max_ast_nodes() -> Option<usize> {
    let limit = DEFAULT_FORMAT_MAX_AST_NODES;
    Some(limit)
}

#[cfg(feature = "generate")]
fn default_format_max_set_op_chain() -> Option<usize> {
    let limit = DEFAULT_FORMAT_MAX_SET_OP_CHAIN;
    Some(limit)
}
242
/// Resource limits enforced by [`format_with_options`].
///
/// Each limit is optional: `None` disables that particular check. Serialized
/// field names are camelCase, and a field missing during deserialization falls
/// back to the same built-in value that [`FormatGuardOptions::default`] uses.
#[cfg(feature = "generate")]
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FormatGuardOptions {
    /// Maximum length of the SQL input, measured in bytes.
    #[serde(default = "default_format_max_input_bytes")]
    pub max_input_bytes: Option<usize>,
    /// Maximum number of tokens the dialect tokenizer may produce.
    #[serde(default = "default_format_max_tokens")]
    pub max_tokens: Option<usize>,
    /// Maximum AST node count, summed over every parsed statement.
    #[serde(default = "default_format_max_ast_nodes")]
    pub max_ast_nodes: Option<usize>,
    /// Maximum number of set-operation tokens (`UNION`, `INTERSECT`, `EXCEPT`)
    /// allowed within a single `;`-delimited statement; checked on the token
    /// stream before parsing.
    #[serde(default = "default_format_max_set_op_chain")]
    pub max_set_op_chain: Option<usize>,
}

#[cfg(feature = "generate")]
impl Default for FormatGuardOptions {
    /// Enables every guard at its built-in default limit.
    fn default() -> Self {
        FormatGuardOptions {
            max_set_op_chain: default_format_max_set_op_chain(),
            max_ast_nodes: default_format_max_ast_nodes(),
            max_tokens: default_format_max_tokens(),
            max_input_bytes: default_format_max_input_bytes(),
        }
    }
}
282
/// Builds the generation error reported when a format guard limit is exceeded.
///
/// The message begins with the guard `code` (e.g. `E_GUARD_INPUT_TOO_LARGE`) so
/// it can be recognized by substring, followed by the observed and configured
/// values.
#[cfg(feature = "generate")]
fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
    let message = format!("{}: value {} exceeds configured limit {}", code, actual, limit);
    Error::generate(message)
}
289
/// Rejects `sql` when its length in bytes exceeds `options.max_input_bytes`.
///
/// `format_with_dialect` runs this first, before any tokenization work.
#[cfg(feature = "generate")]
fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
    match options.max_input_bytes {
        Some(limit) if sql.len() > limit => Err(format_guard_error(
            "E_GUARD_INPUT_TOO_LARGE",
            sql.len(),
            limit,
        )),
        _ => Ok(()),
    }
}
304
/// Tokenizes `sql` with `dialect`, applies the token-count and set-operation
/// guards to the token stream, and only then parses it into statements.
///
/// Checking tokens first means oversized or deeply chained input is rejected
/// before the parser is constructed.
#[cfg(feature = "generate")]
fn parse_with_token_guard(
    sql: &str,
    dialect: &Dialect,
    options: &FormatGuardOptions,
) -> Result<Vec<Expression>> {
    let tokens = dialect.tokenize(sql)?;

    if let Some(limit) = options.max_tokens.filter(|&limit| tokens.len() > limit) {
        return Err(format_guard_error(
            "E_GUARD_TOKEN_BUDGET_EXCEEDED",
            tokens.len(),
            limit,
        ));
    }
    enforce_set_op_chain_guard(&tokens, options)?;

    let mut parser = Parser::with_source(
        tokens,
        crate::parser::ParserConfig {
            dialect: Some(dialect.dialect_type()),
            ..Default::default()
        },
        sql.to_string(),
    );
    parser.parse()
}
331
/// Reports whether a token is trivia (whitespace, a line break, or a comment)
/// that look-ahead logic should skip over.
#[cfg(feature = "generate")]
fn is_trivia_token(token_type: TokenType) -> bool {
    const TRIVIA: [TokenType; 4] = [
        TokenType::Space,
        TokenType::Break,
        TokenType::LineComment,
        TokenType::BlockComment,
    ];
    TRIVIA.contains(&token_type)
}
339
/// Returns the first non-trivia token at index `start` or later, or `None` when
/// only trivia (or nothing at all) remains.
#[cfg(feature = "generate")]
fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
    let remaining = tokens.get(start..).unwrap_or(&[]);
    remaining
        .iter()
        .find(|candidate| !is_trivia_token(candidate.token_type))
}
347
/// Reports whether `tokens[idx]` introduces a set operation (`UNION`,
/// `INTERSECT`, or an `EXCEPT`-typed token).
///
/// An `EXCEPT`-typed token spelled `minus` and followed (past any trivia) by
/// `(` is treated as a `minus(...)` function call, as in ClickHouse, rather
/// than a set operation.
///
/// # Panics
///
/// Panics if `idx` is out of bounds for `tokens`.
#[cfg(feature = "generate")]
fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
    let current = &tokens[idx];
    match current.token_type {
        TokenType::Union | TokenType::Intersect => true,
        TokenType::Except => {
            let looks_like_minus_call = current.text.eq_ignore_ascii_case("minus")
                && matches!(
                    next_significant_token(tokens, idx + 1).map(|t| t.token_type),
                    Some(TokenType::LParen)
                );
            !looks_like_minus_call
        }
        _ => false,
    }
}
369
/// Rejects a token stream in which any single statement contains more than
/// `options.max_set_op_chain` set-operation tokens.
///
/// The counter restarts at every `;`, so each statement gets its own budget.
/// The check runs on raw tokens, before parsing begins.
#[cfg(feature = "generate")]
fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
    let limit = match options.max_set_op_chain {
        Some(limit) => limit,
        None => return Ok(()),
    };

    let mut ops_in_statement: usize = 0;
    for (position, token) in tokens.iter().enumerate() {
        if token.token_type == TokenType::Semicolon {
            ops_in_statement = 0;
        } else if is_set_operation_token(tokens, position) {
            ops_in_statement += 1;
            if ops_in_statement > limit {
                return Err(format_guard_error(
                    "E_GUARD_SET_OP_CHAIN_EXCEEDED",
                    ops_in_statement,
                    limit,
                ));
            }
        }
    }

    Ok(())
}
397
/// Rejects parsed statements whose combined AST node count exceeds
/// `options.max_ast_nodes`.
#[cfg(feature = "generate")]
fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
    let Some(limit) = options.max_ast_nodes else {
        return Ok(());
    };

    let total: usize = expressions
        .iter()
        .map(crate::ast_transforms::node_count)
        .sum();

    if total <= limit {
        Ok(())
    } else {
        Err(format_guard_error("E_GUARD_AST_BUDGET_EXCEEDED", total, limit))
    }
}
415
/// Pretty-prints every statement in `sql` with `dialect`.
///
/// Guards run in this order: input size; then token count and set-operation
/// chains (inside `parse_with_token_guard`); then AST size. The first failing
/// guard, tokenizer, parser, or generator error is returned.
#[cfg(feature = "generate")]
fn format_with_dialect(
    sql: &str,
    dialect: &Dialect,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    enforce_input_guard(sql, options)?;
    let statements = parse_with_token_guard(sql, dialect, options)?;
    enforce_ast_guard(&statements, options)?;

    let mut formatted = Vec::with_capacity(statements.len());
    for statement in &statements {
        formatted.push(dialect.generate_pretty(statement)?);
    }
    Ok(formatted)
}
431
/// Transpiles `sql` from the `read` dialect into the `write` dialect.
///
/// This is a convenience wrapper over `Dialect::get(read).transpile(sql, write)`.
#[cfg(feature = "transpile")]
pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
    let source_dialect = Dialect::get(read);
    source_dialect.transpile(sql, write)
}
460
461pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
470 let d = Dialect::get(dialect);
471 d.parse(sql)
472}
473
474pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
483 let mut expressions = parse(sql, dialect)?;
484
485 if expressions.len() != 1 {
486 return Err(Error::parse(
487 format!("Expected 1 statement, found {}", expressions.len()),
488 0,
489 0,
490 0,
491 0,
492 ));
493 }
494
495 Ok(expressions.remove(0))
496}
497
/// Renders `expression` as SQL text for `dialect`.
#[cfg(feature = "generate")]
pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
    Dialect::get(dialect).generate(expression)
}
511
/// Pretty-prints every statement in `sql` using the default guard limits.
///
/// Use [`format_with_options`] to tune or disable individual limits.
#[cfg(feature = "generate")]
pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
    let defaults = FormatGuardOptions::default();
    format_with_options(sql, dialect, &defaults)
}
519
/// Pretty-prints every statement in `sql`, enforcing the limits in `options`.
///
/// # Errors
///
/// Returns an error whose message starts with an `E_GUARD_*` code when a limit
/// is exceeded. Otherwise, returns any tokenizer, parser, or generator error.
#[cfg(feature = "generate")]
pub fn format_with_options(
    sql: &str,
    dialect: DialectType,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    format_with_dialect(sql, &Dialect::get(dialect), options)
}
530
/// Validates `sql` for `dialect` with the default (permissive) [`ValidationOptions`].
#[cfg(feature = "semantic")]
pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
    let options = ValidationOptions::default();
    validate_with_options(sql, dialect, &options)
}
543
/// Settings for [`validate_with_options`].
///
/// Deserializes from camelCase keys. Any key that is absent takes its value
/// from `ValidationOptions::default()`.
#[cfg(feature = "semantic")]
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase", default)]
pub struct ValidationOptions {
    /// When `true`, a successful parse is also scanned for trailing commas
    /// before clause boundaries, which are reported as `E005`.
    pub strict_syntax: bool,
}
554
/// Validates `sql` for `dialect`, returning every problem found as a
/// [`ValidationResult`] instead of an `Err`.
///
/// Error codes:
/// - `E001` / `E002` / `E003`: syntax, tokenize, and parse failures (with
///   location and span).
/// - `E000`: any other failure while parsing.
/// - `E004`: a parsed top-level expression is not a statement.
/// - `E005`: a trailing comma before a clause boundary (only when
///   `options.strict_syntax` is set).
#[cfg(feature = "semantic")]
pub fn validate_with_options(
    sql: &str,
    dialect: DialectType,
    options: &ValidationOptions,
) -> ValidationResult {
    let d = Dialect::get(dialect);
    let expressions = match d.parse(sql) {
        Ok(expressions) => expressions,
        Err(e) => {
            return ValidationResult::with_errors(vec![parse_failure_to_validation_error(&e)]);
        }
    };

    // The parser can accept bare expressions; validation requires full statements.
    if expressions.iter().any(|expr| !expr.is_statement()) {
        return ValidationResult::with_errors(vec![ValidationError::error(
            "Invalid expression / Unexpected token".to_string(),
            "E004",
        )]);
    }

    if options.strict_syntax {
        if let Some(error) = strict_syntax_error(sql, &d) {
            return ValidationResult::with_errors(vec![error]);
        }
    }

    ValidationResult::success()
}

/// Converts a parse failure into the single [`ValidationError`] reported for it.
///
/// Syntax, tokenize, and parse errors map to `E001`, `E002`, and `E003` and keep
/// their line/column location and byte span. Any other error becomes `E000` and
/// carries only its display text.
#[cfg(feature = "semantic")]
fn parse_failure_to_validation_error(e: &Error) -> ValidationError {
    match e {
        Error::Syntax {
            message,
            line,
            column,
            start,
            end,
        } => ValidationError::error(message.clone(), "E001")
            .with_location(*line, *column)
            .with_span(Some(*start), Some(*end)),
        Error::Tokenize {
            message,
            line,
            column,
            start,
            end,
        } => ValidationError::error(message.clone(), "E002")
            .with_location(*line, *column)
            .with_span(Some(*start), Some(*end)),
        Error::Parse {
            message,
            line,
            column,
            start,
            end,
        } => ValidationError::error(message.clone(), "E003")
            .with_location(*line, *column)
            .with_span(Some(*start), Some(*end)),
        _ => ValidationError::error(e.to_string(), "E000"),
    }
}
618
/// Strict-mode scan for trailing commas in `sql`.
///
/// Returns an `E005` error located at the first comma that is immediately
/// followed by one of the following, ignoring whitespace, line-break, and
/// comment tokens:
/// - a clause keyword (`FROM`, `WHERE`, `GROUP BY`, `HAVING`, `ORDER BY`,
///   `LIMIT`, `OFFSET`, `QUALIFY`, `WINDOW`);
/// - a set operation;
/// - `;`;
/// - the end of input.
///
/// Returns `None` when no such comma exists, or when `sql` fails to tokenize.
/// `validate_with_options` only calls this after `sql` has parsed successfully.
#[cfg(feature = "semantic")]
fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
    let tokens = dialect.tokenize(sql).ok()?;

    for (idx, token) in tokens.iter().enumerate() {
        if token.token_type != TokenType::Comma {
            continue;
        }

        // Look past trivia so `a, /* note */ FROM t` is caught just like
        // `a, FROM t` (mirrors `next_significant_token` on the format path).
        let next = tokens.iter().skip(idx + 1).find(|candidate| {
            !matches!(
                candidate.token_type,
                TokenType::Space
                    | TokenType::Break
                    | TokenType::LineComment
                    | TokenType::BlockComment
            )
        });

        let boundary_name = match next.map(|t| t.token_type) {
            Some(TokenType::From) => "FROM",
            Some(TokenType::Where) => "WHERE",
            Some(TokenType::GroupBy) => "GROUP BY",
            Some(TokenType::Having) => "HAVING",
            Some(TokenType::Order) | Some(TokenType::OrderBy) => "ORDER BY",
            Some(TokenType::Limit) => "LIMIT",
            Some(TokenType::Offset) => "OFFSET",
            Some(TokenType::Union) => "UNION",
            Some(TokenType::Intersect) => "INTERSECT",
            Some(TokenType::Except) => "EXCEPT",
            Some(TokenType::Qualify) => "QUALIFY",
            Some(TokenType::Window) => "WINDOW",
            Some(TokenType::Semicolon) | None => "end of statement",
            _ => continue,
        };

        let message = format!(
            "Trailing comma before {} is not allowed in strict syntax mode",
            boundary_name
        );
        return Some(
            ValidationError::error(message, "E005")
                .with_location(token.span.line, token.span.column),
        );
    }

    None
}
660
/// Transpiles `sql` between dialects identified by name (e.g. `"duckdb"`,
/// `"trino"`), using default [`TranspileOptions`].
///
/// # Errors
///
/// Returns `Unknown dialect: <name>` if either name is not recognized.
#[cfg(feature = "transpile")]
pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
    let defaults = TranspileOptions::default();
    transpile_with_by_name(sql, read, write, &defaults)
}
677
/// Transpiles `sql` between dialects identified by name, using `opts`.
///
/// # Errors
///
/// Returns a parse error reading `Unknown dialect: <name>` for the first
/// unrecognized name (the `read` name is checked before the `write` name).
#[cfg(feature = "transpile")]
pub fn transpile_with_by_name(
    sql: &str,
    read: &str,
    write: &str,
    opts: &TranspileOptions,
) -> Result<Vec<String>> {
    let lookup = |name: &str| {
        Dialect::get_by_name(name)
            .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", name), 0, 0, 0, 0))
    };

    let source = lookup(read)?;
    let target = lookup(write)?;
    source.transpile_with(sql, &target, opts.clone())
}
694
695pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
699 let d = Dialect::get_by_name(dialect)
700 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
701 d.parse(sql)
702}
703
/// Renders `expression` as SQL for a dialect identified by name.
///
/// # Errors
///
/// Returns a parse error reading `Unknown dialect: <name>` when `dialect` is
/// not recognized, or any error produced by the generator.
#[cfg(feature = "generate")]
pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
    match Dialect::get_by_name(dialect) {
        Some(d) => d.generate(expression),
        None => Err(Error::parse(
            format!("Unknown dialect: {}", dialect),
            0,
            0,
            0,
            0,
        )),
    }
}
713
/// Pretty-prints `sql` for a dialect identified by name, using the default
/// guard limits.
#[cfg(feature = "generate")]
pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
    let defaults = FormatGuardOptions::default();
    format_with_options_by_name(sql, dialect, &defaults)
}
721
/// Pretty-prints `sql` for a dialect identified by name, enforcing `options`.
///
/// # Errors
///
/// Returns a parse error reading `Unknown dialect: <name>` when `dialect` is
/// not recognized. Otherwise, returns the same errors as [`format_with_options`].
#[cfg(feature = "generate")]
pub fn format_with_options_by_name(
    sql: &str,
    dialect: &str,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    match Dialect::get_by_name(dialect) {
        Some(d) => format_with_dialect(sql, &d, options),
        None => Err(Error::parse(
            format!("Unknown dialect: {}", dialect),
            0,
            0,
            0,
            0,
        )),
    }
}
733
#[cfg(all(test, feature = "semantic"))]
mod validation_tests {
    use super::*;

    /// Validation options with strict syntax checking switched on.
    fn strict_options() -> ValidationOptions {
        ValidationOptions {
            strict_syntax: true,
        }
    }

    /// Asserts that `result` is invalid and reports at least one `E005` error.
    fn assert_trailing_comma_rejected(result: &ValidationResult) {
        assert!(!result.valid, "Result should be invalid");
        assert!(
            result.errors.iter().any(|e| e.code == "E005"),
            "Expected E005, got: {:?}",
            result.errors
        );
    }

    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let result = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(result.valid, "Result: {:?}", result.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        let result = validate_with_options(
            "SELECT name, FROM employees",
            DialectType::Generic,
            &strict_options(),
        );
        assert_trailing_comma_rejected(&result);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        let result = validate_with_options(
            "SELECT name FROM employees, WHERE salary > 10",
            DialectType::Generic,
            &strict_options(),
        );
        assert_trailing_comma_rejected(&result);
    }
}
780
#[cfg(all(test, feature = "generate"))]
mod format_tests {
    use super::*;

    /// SQL from issue 57 using an unsupported `? :` ternary. Parse, format, and
    /// transpile must all reject it.
    const ISSUE57_INVALID_TERNARY_SQL: &str = "SELECT x > 0 ? 1 : 0 FROM t";

    #[test]
    fn format_basic_query() {
        let result = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
        assert_eq!(result.len(), 1);
        assert!(result[0].contains('\n'));
    }

    #[test]
    fn format_guard_rejects_large_input() {
        let options = FormatGuardOptions {
            max_input_bytes: Some(7),
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: None,
        };
        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_INPUT_TOO_LARGE"));
    }

    #[test]
    fn format_guard_rejects_token_budget() {
        let options = FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: Some(1),
            max_ast_nodes: None,
            max_set_op_chain: None,
        };
        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_TOKEN_BUDGET_EXCEEDED"));
    }

    #[test]
    fn format_guard_rejects_ast_budget() {
        let options = FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: Some(1),
            max_set_op_chain: None,
        };
        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_AST_BUDGET_EXCEEDED"));
    }

    #[test]
    fn format_guard_rejects_set_op_chain_budget() {
        let options = FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: Some(1),
        };
        let err = format_with_options(
            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
            DialectType::Generic,
            &options,
        )
        .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }

    #[test]
    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
        let options = FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: Some(0),
        };
        let result = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
        assert!(result.is_ok(), "Result: {:?}", result);
    }

    #[test]
    fn issue57_invalid_ternary_returns_error() {
        let sql = ISSUE57_INVALID_TERNARY_SQL;

        let parse_result = parse(sql, DialectType::PostgreSQL);
        assert!(
            parse_result.is_err(),
            "Expected parse error for invalid ternary SQL, got: {:?}",
            parse_result
        );

        let format_result = format(sql, DialectType::PostgreSQL);
        assert!(
            format_result.is_err(),
            "Expected format error for invalid ternary SQL, got: {:?}",
            format_result
        );
    }

    // `transpile` only exists with the `transpile` feature, so this half of the
    // issue-57 check is gated separately from the parse/format half above.
    #[cfg(feature = "transpile")]
    #[test]
    fn issue57_invalid_ternary_transpile_returns_error() {
        let transpile_result = transpile(
            ISSUE57_INVALID_TERNARY_SQL,
            DialectType::PostgreSQL,
            DialectType::PostgreSQL,
        );
        assert!(
            transpile_result.is_err(),
            "Expected transpile error for invalid ternary SQL, got: {:?}",
            transpile_result
        );
    }

    #[cfg(feature = "transpile")]
    #[test]
    fn transpile_applies_cross_dialect_rewrites() {
        let out = transpile(
            "SELECT to_timestamp(col) FROM t",
            DialectType::DuckDB,
            DialectType::Trino,
        )
        .expect("transpile failed");
        assert_eq!(out[0], "SELECT FROM_UNIXTIME(col) FROM t");

        let out = transpile(
            "SELECT CAST(col AS JSON) FROM t",
            DialectType::DuckDB,
            DialectType::Trino,
        )
        .expect("transpile failed");
        assert_eq!(out[0], "SELECT JSON_PARSE(col) FROM t");
    }

    #[cfg(feature = "transpile")]
    #[test]
    fn transpile_matches_dialect_method() {
        let cases: &[(DialectType, DialectType, &str, &str, &str)] = &[
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT to_timestamp(col) FROM t",
            ),
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT CAST(col AS JSON) FROM t",
            ),
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT json_valid(col) FROM t",
            ),
            (
                DialectType::Snowflake,
                DialectType::DuckDB,
                "snowflake",
                "duckdb",
                "SELECT DATEDIFF(day, a, b) FROM t",
            ),
            (
                DialectType::BigQuery,
                DialectType::DuckDB,
                "bigquery",
                "duckdb",
                "SELECT DATE_DIFF(a, b, DAY) FROM t",
            ),
            (
                DialectType::Generic,
                DialectType::Generic,
                "generic",
                "generic",
                "SELECT 1",
            ),
        ];
        for (read, write, read_name, write_name, sql) in cases {
            let via_lib = transpile(sql, *read, *write).expect("lib::transpile failed");
            let via_name = transpile_by_name(sql, read_name, write_name)
                .expect("lib::transpile_by_name failed");
            let via_dialect = Dialect::get(*read)
                .transpile(sql, *write)
                .expect("Dialect::transpile failed");
            assert_eq!(
                via_lib, via_dialect,
                "lib::transpile / Dialect::transpile diverged for {:?} -> {:?}: {sql}",
                read, write
            );
            assert_eq!(
                via_name, via_dialect,
                "lib::transpile_by_name / Dialect::transpile diverged for {read_name} -> {write_name}: {sql}"
            );
        }
    }

    #[test]
    fn format_default_guard_rejects_deep_union_chain_before_parse() {
        let base = "SELECT col0, col1 FROM t";
        let mut sql = base.to_string();
        for _ in 0..1100 {
            sql.push_str(" UNION ALL ");
            sql.push_str(base);
        }

        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }
}