Skip to main content

polyglot_sql/
lib.rs

1//! Polyglot Core - SQL parsing and dialect translation library
2//!
3//! This library provides the core functionality for parsing SQL statements,
4//! building an abstract syntax tree (AST), and generating SQL in different dialects.
5//!
6//! # Architecture
7//!
8//! The library follows a pipeline architecture:
9//! 1. **Tokenizer** - Converts SQL string to token stream
10//! 2. **Parser** - Builds AST from tokens
11//! 3. **Generator** - Converts AST back to SQL string
12//!
13//! Each stage can be customized per dialect.
14
15pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21pub mod function_catalog;
22mod function_registry;
23pub mod generator;
24pub mod helper;
25pub mod lineage;
26pub mod optimizer;
27pub mod parser;
28pub mod planner;
29pub mod resolver;
30pub mod schema;
31pub mod scope;
32pub mod time;
33pub mod tokens;
34pub mod transforms;
35pub mod traversal;
36pub mod trie;
37pub mod validation;
38
39use serde::{Deserialize, Serialize};
40
41pub use ast_transforms::{
42    add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
43    get_identifiers, get_literals, get_output_column_names, get_subqueries, get_table_names,
44    get_window_functions, node_count, qualify_columns, remove_limit_offset, remove_nodes,
45    remove_select_columns, remove_where, rename_columns, rename_tables, rename_tables_with_options,
46    replace_by_type, replace_nodes, set_distinct, set_limit, set_offset, RenameTablesOptions,
47};
48pub use dialects::{
49    unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType, TranspileOptions,
50    TranspileTarget,
51};
52pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
53pub use expressions::Expression;
54pub use function_catalog::{
55    FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
56};
57pub use generator::Generator;
58pub use helper::{
59    csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
60    name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
61};
62pub use optimizer::{
63    annotate_types, qualify_tables, QualifyTablesOptions, TypeAnnotator, TypeCoercionClass,
64};
65pub use parser::Parser;
66pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
67pub use schema::{
68    ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
69};
70pub use scope::{
71    build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
72    ScopeType, SourceInfo,
73};
74pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
75pub use tokens::{Token, TokenType, Tokenizer};
76pub use traversal::{
77    contains_aggregate,
78    contains_subquery,
79    contains_window_function,
80    find_ancestor,
81    find_parent,
82    get_all_tables,
83    get_columns,
84    get_merge_source,
85    get_merge_target,
86    get_tables,
87    is_add,
88    is_aggregate,
89    is_alias,
90    is_alter_table,
91    is_and,
92    is_arithmetic,
93    is_avg,
94    is_between,
95    is_boolean,
96    is_case,
97    is_cast,
98    is_coalesce,
99    is_column,
100    is_comparison,
101    is_concat,
102    is_count,
103    is_create_index,
104    is_create_table,
105    is_create_view,
106    is_cte,
107    is_ddl,
108    is_delete,
109    is_div,
110    is_drop_index,
111    is_drop_table,
112    is_drop_view,
113    is_eq,
114    is_except,
115    is_exists,
116    is_from,
117    is_function,
118    is_group_by,
119    is_gt,
120    is_gte,
121    is_having,
122    is_identifier,
123    is_ilike,
124    is_in,
125    // Extended type predicates
126    is_insert,
127    is_intersect,
128    is_is_null,
129    is_join,
130    is_like,
131    is_limit,
132    is_literal,
133    is_logical,
134    is_lt,
135    is_lte,
136    is_max_func,
137    is_merge,
138    is_min_func,
139    is_mod,
140    is_mul,
141    is_neq,
142    is_not,
143    is_null_if,
144    is_null_literal,
145    is_offset,
146    is_or,
147    is_order_by,
148    is_ordered,
149    is_paren,
150    // Composite predicates
151    is_query,
152    is_safe_cast,
153    is_select,
154    is_set_operation,
155    is_star,
156    is_sub,
157    is_subquery,
158    is_sum,
159    is_table,
160    is_try_cast,
161    is_union,
162    is_update,
163    is_where,
164    is_window_function,
165    is_with,
166    transform,
167    transform_map,
168    BfsIter,
169    DfsIter,
170    ExpressionWalk,
171    ParentInfo,
172    TreeContext,
173};
174pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
175pub use validation::{
176    mapping_schema_from_validation_schema, validate_with_schema, SchemaColumn,
177    SchemaColumnReference, SchemaForeignKey, SchemaTable, SchemaTableReference,
178    SchemaValidationOptions, ValidationSchema,
179};
180
/// Default upper bound on formatter input size: 16 MiB.
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 << 20;
/// Default upper bound on the number of tokens produced by tokenization.
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
/// Default upper bound on the summed AST node count of all parsed statements.
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
/// Default upper bound on set-operation keywords counted within one statement.
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;

// These wrappers exist because `#[serde(default = "...")]` on
// `FormatGuardOptions` names a function rather than a constant; they also
// back its `Default` impl so both paths agree on the same values.

fn default_format_max_input_bytes() -> Option<usize> {
    DEFAULT_FORMAT_MAX_INPUT_BYTES.into()
}

fn default_format_max_tokens() -> Option<usize> {
    DEFAULT_FORMAT_MAX_TOKENS.into()
}

fn default_format_max_ast_nodes() -> Option<usize> {
    DEFAULT_FORMAT_MAX_AST_NODES.into()
}

fn default_format_max_set_op_chain() -> Option<usize> {
    DEFAULT_FORMAT_MAX_SET_OP_CHAIN.into()
}
201
202/// Guard options for SQL pretty-formatting.
203///
204/// These limits protect against extremely large/complex queries that can cause
205/// high memory pressure in constrained runtimes (for example browser WASM).
206#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
207#[serde(rename_all = "camelCase")]
208pub struct FormatGuardOptions {
209    /// Maximum allowed SQL input size in bytes.
210    /// `None` disables this check.
211    #[serde(default = "default_format_max_input_bytes")]
212    pub max_input_bytes: Option<usize>,
213    /// Maximum allowed number of tokens after tokenization.
214    /// `None` disables this check.
215    #[serde(default = "default_format_max_tokens")]
216    pub max_tokens: Option<usize>,
217    /// Maximum allowed AST node count after parsing.
218    /// `None` disables this check.
219    #[serde(default = "default_format_max_ast_nodes")]
220    pub max_ast_nodes: Option<usize>,
221    /// Maximum allowed count of set-operation operators (`UNION`/`INTERSECT`/`EXCEPT`)
222    /// observed in a statement before parsing.
223    ///
224    /// `None` disables this check.
225    #[serde(default = "default_format_max_set_op_chain")]
226    pub max_set_op_chain: Option<usize>,
227}
228
229impl Default for FormatGuardOptions {
230    fn default() -> Self {
231        Self {
232            max_input_bytes: default_format_max_input_bytes(),
233            max_tokens: default_format_max_tokens(),
234            max_ast_nodes: default_format_max_ast_nodes(),
235            max_set_op_chain: default_format_max_set_op_chain(),
236        }
237    }
238}
239
240fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
241    Error::generate(format!(
242        "{code}: value {actual} exceeds configured limit {limit}"
243    ))
244}
245
246fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
247    if let Some(max) = options.max_input_bytes {
248        let input_bytes = sql.len();
249        if input_bytes > max {
250            return Err(format_guard_error(
251                "E_GUARD_INPUT_TOO_LARGE",
252                input_bytes,
253                max,
254            ));
255        }
256    }
257    Ok(())
258}
259
260fn parse_with_token_guard(
261    sql: &str,
262    dialect: &Dialect,
263    options: &FormatGuardOptions,
264) -> Result<Vec<Expression>> {
265    let tokens = dialect.tokenize(sql)?;
266    if let Some(max) = options.max_tokens {
267        let token_count = tokens.len();
268        if token_count > max {
269            return Err(format_guard_error(
270                "E_GUARD_TOKEN_BUDGET_EXCEEDED",
271                token_count,
272                max,
273            ));
274        }
275    }
276    enforce_set_op_chain_guard(&tokens, options)?;
277
278    let config = crate::parser::ParserConfig {
279        dialect: Some(dialect.dialect_type()),
280        ..Default::default()
281    };
282    let mut parser = Parser::with_source(tokens, config, sql.to_string());
283    parser.parse()
284}
285
286fn is_trivia_token(token_type: TokenType) -> bool {
287    matches!(
288        token_type,
289        TokenType::Space | TokenType::Break | TokenType::LineComment | TokenType::BlockComment
290    )
291}
292
293fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
294    tokens
295        .iter()
296        .skip(start)
297        .find(|token| !is_trivia_token(token.token_type))
298}
299
300fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
301    let token = &tokens[idx];
302    match token.token_type {
303        TokenType::Union | TokenType::Intersect => true,
304        TokenType::Except => {
305            // MINUS is aliased to EXCEPT in the tokenizer, but in ClickHouse minus(...)
306            // is a function call rather than a set operation.
307            if token.text.eq_ignore_ascii_case("minus")
308                && matches!(
309                    next_significant_token(tokens, idx + 1).map(|t| t.token_type),
310                    Some(TokenType::LParen)
311                )
312            {
313                return false;
314            }
315            true
316        }
317        _ => false,
318    }
319}
320
321fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
322    let Some(max) = options.max_set_op_chain else {
323        return Ok(());
324    };
325
326    let mut set_op_count = 0usize;
327    for (idx, token) in tokens.iter().enumerate() {
328        if token.token_type == TokenType::Semicolon {
329            set_op_count = 0;
330            continue;
331        }
332
333        if is_set_operation_token(tokens, idx) {
334            set_op_count += 1;
335            if set_op_count > max {
336                return Err(format_guard_error(
337                    "E_GUARD_SET_OP_CHAIN_EXCEEDED",
338                    set_op_count,
339                    max,
340                ));
341            }
342        }
343    }
344
345    Ok(())
346}
347
348fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
349    if let Some(max) = options.max_ast_nodes {
350        let ast_nodes: usize = expressions.iter().map(node_count).sum();
351        if ast_nodes > max {
352            return Err(format_guard_error(
353                "E_GUARD_AST_BUDGET_EXCEEDED",
354                ast_nodes,
355                max,
356            ));
357        }
358    }
359    Ok(())
360}
361
362fn format_with_dialect(
363    sql: &str,
364    dialect: &Dialect,
365    options: &FormatGuardOptions,
366) -> Result<Vec<String>> {
367    enforce_input_guard(sql, options)?;
368    let expressions = parse_with_token_guard(sql, dialect, options)?;
369    enforce_ast_guard(&expressions, options)?;
370
371    expressions
372        .iter()
373        .map(|expr| dialect.generate_pretty(expr))
374        .collect()
375}
376
377/// Transpile SQL from one dialect to another.
378///
379/// # Arguments
380/// * `sql` - The SQL string to transpile
381/// * `read` - The source dialect to parse with
382/// * `write` - The target dialect to generate
383///
384/// # Returns
385/// A vector of transpiled SQL statements
386///
387/// # Example
388/// ```
389/// use polyglot_sql::{transpile, DialectType};
390///
391/// let result = transpile(
392///     "SELECT EPOCH_MS(1618088028295)",
393///     DialectType::DuckDB,
394///     DialectType::Hive
395/// );
396/// ```
397pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
398    // Delegate to Dialect::transpile so that the full cross-dialect rewrite
399    // pipeline (source+target-aware normalization in `cross_dialect_normalize`)
400    // runs here as well. This keeps Rust crate users on the same code path as
401    // the WASM/FFI/Python bindings and the playground.
402    Dialect::get(read).transpile(sql, write)
403}
404
405/// Parse SQL into an AST.
406///
407/// # Arguments
408/// * `sql` - The SQL string to parse
409/// * `dialect` - The dialect to use for parsing
410///
411/// # Returns
412/// A vector of parsed expressions
413pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
414    let d = Dialect::get(dialect);
415    d.parse(sql)
416}
417
418/// Parse a single SQL statement.
419///
420/// # Arguments
421/// * `sql` - The SQL string containing a single statement
422/// * `dialect` - The dialect to use for parsing
423///
424/// # Returns
425/// The parsed expression, or an error if multiple statements found
426pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
427    let mut expressions = parse(sql, dialect)?;
428
429    if expressions.len() != 1 {
430        return Err(Error::parse(
431            format!("Expected 1 statement, found {}", expressions.len()),
432            0,
433            0,
434            0,
435            0,
436        ));
437    }
438
439    Ok(expressions.remove(0))
440}
441
442/// Generate SQL from an AST.
443///
444/// # Arguments
445/// * `expression` - The expression to generate SQL from
446/// * `dialect` - The target dialect
447///
448/// # Returns
449/// The generated SQL string
450pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
451    let d = Dialect::get(dialect);
452    d.generate(expression)
453}
454
455/// Format/pretty-print SQL statements.
456///
457/// Uses [`FormatGuardOptions::default`] guards.
458pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
459    format_with_options(sql, dialect, &FormatGuardOptions::default())
460}
461
462/// Format/pretty-print SQL statements with configurable guard limits.
463pub fn format_with_options(
464    sql: &str,
465    dialect: DialectType,
466    options: &FormatGuardOptions,
467) -> Result<Vec<String>> {
468    let d = Dialect::get(dialect);
469    format_with_dialect(sql, &d, options)
470}
471
472/// Validate SQL syntax.
473///
474/// # Arguments
475/// * `sql` - The SQL string to validate
476/// * `dialect` - The dialect to use for validation
477///
478/// # Returns
479/// A validation result with any errors found
480pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
481    validate_with_options(sql, dialect, &ValidationOptions::default())
482}
483
484/// Options for syntax validation behavior.
485#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
486#[serde(rename_all = "camelCase")]
487pub struct ValidationOptions {
488    /// When enabled, validation rejects non-canonical trailing commas that the parser
489    /// would otherwise accept for compatibility (e.g. `SELECT a, FROM t`).
490    #[serde(default)]
491    pub strict_syntax: bool,
492}
493
494/// Validate SQL syntax with additional validation options.
495pub fn validate_with_options(
496    sql: &str,
497    dialect: DialectType,
498    options: &ValidationOptions,
499) -> ValidationResult {
500    let d = Dialect::get(dialect);
501    match d.parse(sql) {
502        Ok(expressions) => {
503            // Reject bare expressions that aren't valid SQL statements.
504            // The parser accepts any expression at the top level, but bare identifiers,
505            // literals, function calls, etc. are not valid statements.
506            for expr in &expressions {
507                if !expr.is_statement() {
508                    let msg = format!("Invalid expression / Unexpected token");
509                    return ValidationResult::with_errors(vec![ValidationError::error(
510                        msg, "E004",
511                    )]);
512                }
513            }
514            if options.strict_syntax {
515                if let Some(error) = strict_syntax_error(sql, &d) {
516                    return ValidationResult::with_errors(vec![error]);
517                }
518            }
519            ValidationResult::success()
520        }
521        Err(e) => {
522            let error = match &e {
523                Error::Syntax {
524                    message,
525                    line,
526                    column,
527                    start,
528                    end,
529                } => ValidationError::error(message.clone(), "E001")
530                    .with_location(*line, *column)
531                    .with_span(Some(*start), Some(*end)),
532                Error::Tokenize {
533                    message,
534                    line,
535                    column,
536                    start,
537                    end,
538                } => ValidationError::error(message.clone(), "E002")
539                    .with_location(*line, *column)
540                    .with_span(Some(*start), Some(*end)),
541                Error::Parse {
542                    message,
543                    line,
544                    column,
545                    start,
546                    end,
547                } => ValidationError::error(message.clone(), "E003")
548                    .with_location(*line, *column)
549                    .with_span(Some(*start), Some(*end)),
550                _ => ValidationError::error(e.to_string(), "E000"),
551            };
552            ValidationResult::with_errors(vec![error])
553        }
554    }
555}
556
557fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
558    let tokens = dialect.tokenize(sql).ok()?;
559
560    for (idx, token) in tokens.iter().enumerate() {
561        if token.token_type != TokenType::Comma {
562            continue;
563        }
564
565        let next = tokens.get(idx + 1);
566        let (is_boundary, boundary_name) = match next.map(|t| t.token_type) {
567            Some(TokenType::From) => (true, "FROM"),
568            Some(TokenType::Where) => (true, "WHERE"),
569            Some(TokenType::GroupBy) => (true, "GROUP BY"),
570            Some(TokenType::Having) => (true, "HAVING"),
571            Some(TokenType::Order) | Some(TokenType::OrderBy) => (true, "ORDER BY"),
572            Some(TokenType::Limit) => (true, "LIMIT"),
573            Some(TokenType::Offset) => (true, "OFFSET"),
574            Some(TokenType::Union) => (true, "UNION"),
575            Some(TokenType::Intersect) => (true, "INTERSECT"),
576            Some(TokenType::Except) => (true, "EXCEPT"),
577            Some(TokenType::Qualify) => (true, "QUALIFY"),
578            Some(TokenType::Window) => (true, "WINDOW"),
579            Some(TokenType::Semicolon) | None => (true, "end of statement"),
580            _ => (false, ""),
581        };
582
583        if is_boundary {
584            let message = format!(
585                "Trailing comma before {} is not allowed in strict syntax mode",
586                boundary_name
587            );
588            return Some(
589                ValidationError::error(message, "E005")
590                    .with_location(token.span.line, token.span.column),
591            );
592        }
593    }
594
595    None
596}
597
598/// Transpile SQL from one dialect to another, using string dialect names.
599///
600/// This supports both built-in dialect names (e.g., "postgresql", "mysql") and
601/// custom dialects registered via [`CustomDialectBuilder`].
602///
603/// # Arguments
604/// * `sql` - The SQL string to transpile
605/// * `read` - The source dialect name
606/// * `write` - The target dialect name
607///
608/// # Returns
609/// A vector of transpiled SQL statements, or an error if a dialect name is unknown.
610pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
611    transpile_with_by_name(sql, read, write, &TranspileOptions::default())
612}
613
614/// Transpile SQL with configurable [`TranspileOptions`], using string dialect names.
615///
616/// Same as [`transpile_by_name`] but accepts options (e.g., pretty-printing).
617pub fn transpile_with_by_name(
618    sql: &str,
619    read: &str,
620    write: &str,
621    opts: &TranspileOptions,
622) -> Result<Vec<String>> {
623    let read_dialect = Dialect::get_by_name(read)
624        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0, 0, 0))?;
625    let write_dialect = Dialect::get_by_name(write)
626        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0, 0, 0))?;
627    read_dialect.transpile_with(sql, &write_dialect, opts.clone())
628}
629
630/// Parse SQL into an AST using a string dialect name.
631///
632/// Supports both built-in and custom dialect names.
633pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
634    let d = Dialect::get_by_name(dialect)
635        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
636    d.parse(sql)
637}
638
639/// Generate SQL from an AST using a string dialect name.
640///
641/// Supports both built-in and custom dialect names.
642pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
643    let d = Dialect::get_by_name(dialect)
644        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
645    d.generate(expression)
646}
647
648/// Format SQL using a string dialect name.
649///
650/// Uses [`FormatGuardOptions::default`] guards.
651pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
652    format_with_options_by_name(sql, dialect, &FormatGuardOptions::default())
653}
654
655/// Format SQL using a string dialect name with configurable guard limits.
656pub fn format_with_options_by_name(
657    sql: &str,
658    dialect: &str,
659    options: &FormatGuardOptions,
660) -> Result<Vec<String>> {
661    let d = Dialect::get_by_name(dialect)
662        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
663    format_with_dialect(sql, &d, options)
664}
665
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// Validation options with strict trailing-comma checks enabled.
    fn strict() -> ValidationOptions {
        ValidationOptions {
            strict_syntax: true,
        }
    }

    /// Asserts that `result` is invalid and carries at least one error with `code`.
    fn assert_has_error_code(result: &ValidationResult, code: &str) {
        assert!(!result.valid, "Result should be invalid");
        assert!(
            result.errors.iter().any(|e| e.code == code),
            "Expected {}, got: {:?}",
            code,
            result.errors
        );
    }

    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let result = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(result.valid, "Result: {:?}", result.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        let result = validate_with_options(
            "SELECT name, FROM employees",
            DialectType::Generic,
            &strict(),
        );
        assert_has_error_code(&result, "E005");
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        let result = validate_with_options(
            "SELECT name FROM employees, WHERE salary > 10",
            DialectType::Generic,
            &strict(),
        );
        assert_has_error_code(&result, "E005");
    }

    /// Strict mode must not flag commas that separate real list items.
    #[test]
    fn validate_with_options_accepts_canonical_sql_in_strict_mode() {
        let result = validate_with_options(
            "SELECT name, salary FROM employees WHERE salary > 10",
            DialectType::Generic,
            &strict(),
        );
        assert!(result.valid, "Result: {:?}", result.errors);
    }
}
712
#[cfg(test)]
mod format_tests {
    use super::*;

    /// Guard options with every limit switched off; each test re-enables only
    /// the limit it exercises.
    fn unguarded() -> FormatGuardOptions {
        FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: None,
        }
    }

    /// Formats `sql`, requires it to fail, and returns the error text.
    fn guard_failure(sql: &str, dialect: DialectType, options: &FormatGuardOptions) -> String {
        let err = format_with_options(sql, dialect, options).expect_err("expected guard error");
        err.to_string()
    }

    #[test]
    fn format_basic_query() {
        let statements = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
        assert_eq!(statements.len(), 1);
        assert!(statements[0].contains('\n'));
    }

    #[test]
    fn format_guard_rejects_large_input() {
        let options = FormatGuardOptions {
            max_input_bytes: Some(7),
            ..unguarded()
        };
        let message = guard_failure("SELECT 1", DialectType::Generic, &options);
        assert!(message.contains("E_GUARD_INPUT_TOO_LARGE"));
    }

    #[test]
    fn format_guard_rejects_token_budget() {
        let options = FormatGuardOptions {
            max_tokens: Some(1),
            ..unguarded()
        };
        let message = guard_failure("SELECT 1", DialectType::Generic, &options);
        assert!(message.contains("E_GUARD_TOKEN_BUDGET_EXCEEDED"));
    }

    #[test]
    fn format_guard_rejects_ast_budget() {
        let options = FormatGuardOptions {
            max_ast_nodes: Some(1),
            ..unguarded()
        };
        let message = guard_failure("SELECT 1", DialectType::Generic, &options);
        assert!(message.contains("E_GUARD_AST_BUDGET_EXCEEDED"));
    }

    #[test]
    fn format_guard_rejects_set_op_chain_budget() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(1),
            ..unguarded()
        };
        let message = guard_failure(
            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
            DialectType::Generic,
            &options,
        );
        assert!(message.contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }

    #[test]
    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(0),
            ..unguarded()
        };
        let result = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
        assert!(result.is_ok(), "Result: {:?}", result);
    }

    #[test]
    fn issue57_invalid_ternary_returns_error() {
        // https://github.com/tobilg/polyglot/issues/57
        // Invalid SQL with ternary operator should return an error, not garbled output.
        let sql = "SELECT x > 0 ? 1 : 0 FROM t";
        let dialect = DialectType::PostgreSQL;

        let parsed = parse(sql, dialect);
        assert!(
            parsed.is_err(),
            "Expected parse error for invalid ternary SQL, got: {:?}",
            parsed
        );

        let formatted = format(sql, dialect);
        assert!(
            formatted.is_err(),
            "Expected format error for invalid ternary SQL, got: {:?}",
            formatted
        );

        let transpiled = transpile(sql, dialect, dialect);
        assert!(
            transpiled.is_err(),
            "Expected transpile error for invalid ternary SQL, got: {:?}",
            transpiled
        );
    }

    /// Regression guard: `lib::transpile()` must apply the full cross-dialect
    /// rewrite pipeline (same as `Dialect::transpile()`). If these two paths
    /// diverge again, Rust crate users silently get under-transformed SQL that
    /// differs from what WASM/FFI/Python bindings produce.
    #[test]
    fn transpile_applies_cross_dialect_rewrites() {
        let expectations: [(&str, &str); 2] = [
            // DuckDB to_timestamp → Trino FROM_UNIXTIME (different input semantics).
            (
                "SELECT to_timestamp(col) FROM t",
                "SELECT FROM_UNIXTIME(col) FROM t",
            ),
            // DuckDB CAST(x AS JSON) → Trino JSON_PARSE(x) (different CAST semantics).
            (
                "SELECT CAST(col AS JSON) FROM t",
                "SELECT JSON_PARSE(col) FROM t",
            ),
        ];
        for &(input, expected) in &expectations {
            let out = transpile(input, DialectType::DuckDB, DialectType::Trino)
                .expect("transpile failed");
            assert_eq!(out[0], expected);
        }
    }

    /// Regression guard: all three transpile entry points (lib::transpile,
    /// lib::transpile_by_name, Dialect::transpile) must produce identical
    /// output. transpile_by_name is the one used by Python and C FFI bindings.
    #[test]
    fn transpile_matches_dialect_method() {
        let cases: &[(DialectType, DialectType, &str, &str, &str)] = &[
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT to_timestamp(col) FROM t",
            ),
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT CAST(col AS JSON) FROM t",
            ),
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT json_valid(col) FROM t",
            ),
            (
                DialectType::Snowflake,
                DialectType::DuckDB,
                "snowflake",
                "duckdb",
                "SELECT DATEDIFF(day, a, b) FROM t",
            ),
            (
                DialectType::BigQuery,
                DialectType::DuckDB,
                "bigquery",
                "duckdb",
                "SELECT DATE_DIFF(a, b, DAY) FROM t",
            ),
            (
                DialectType::Generic,
                DialectType::Generic,
                "generic",
                "generic",
                "SELECT 1",
            ),
        ];
        for &(read, write, read_name, write_name, sql) in cases {
            let via_dialect = Dialect::get(read)
                .transpile(sql, write)
                .expect("Dialect::transpile failed");
            let via_lib = transpile(sql, read, write).expect("lib::transpile failed");
            let via_name = transpile_by_name(sql, read_name, write_name)
                .expect("lib::transpile_by_name failed");
            assert_eq!(
                via_lib, via_dialect,
                "lib::transpile / Dialect::transpile diverged for {:?} -> {:?}: {sql}",
                read, write
            );
            assert_eq!(
                via_name, via_dialect,
                "lib::transpile_by_name / Dialect::transpile diverged for {read_name} -> {write_name}: {sql}"
            );
        }
    }

    #[test]
    fn format_default_guard_rejects_deep_union_chain_before_parse() {
        // 1101 copies joined by 1100 `UNION ALL`s, far beyond the default limit.
        let base = "SELECT col0, col1 FROM t";
        let sql = vec![base; 1101].join(" UNION ALL ");

        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }
}
925}