Skip to main content

polyglot_sql/
lib.rs

1//! Polyglot Core - SQL parsing and dialect translation library
2//!
3//! This library provides the core functionality for parsing SQL statements,
4//! building an abstract syntax tree (AST), and generating SQL in different dialects.
5//!
6//! # Architecture
7//!
8//! The library follows a pipeline architecture:
//! 1. **Tokenizer** - Converts a SQL string into a token stream
//! 2. **Parser** - Builds an AST from the tokens
//! 3. **Generator** - Converts the AST back into a SQL string
12//!
13//! Each stage can be customized per dialect.
14
15pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21pub mod function_catalog;
22mod function_registry;
23pub mod generator;
24pub mod helper;
25pub mod lineage;
26pub mod optimizer;
27pub mod parser;
28pub mod planner;
29pub mod resolver;
30pub mod schema;
31pub mod scope;
32pub mod time;
33pub mod tokens;
34pub mod transforms;
35pub mod traversal;
36pub mod trie;
37pub mod validation;
38
39use serde::{Deserialize, Serialize};
40
41pub use ast_transforms::{
42    add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
43    get_identifiers, get_literals, get_output_column_names, get_subqueries, get_table_names,
44    get_window_functions, node_count, qualify_columns, remove_limit_offset, remove_nodes,
45    remove_select_columns, remove_where, rename_columns, rename_tables, replace_by_type,
46    replace_nodes, set_distinct, set_limit, set_offset,
47};
48pub use dialects::{
49    unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType, TranspileOptions,
50    TranspileTarget,
51};
52pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
53pub use expressions::Expression;
54pub use function_catalog::{
55    FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
56};
57pub use generator::Generator;
58pub use helper::{
59    csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
60    name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
61};
62pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
63pub use parser::Parser;
64pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
65pub use schema::{
66    ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
67};
68pub use scope::{
69    build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
70    ScopeType, SourceInfo,
71};
72pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
73pub use tokens::{Token, TokenType, Tokenizer};
74pub use traversal::{
75    contains_aggregate,
76    contains_subquery,
77    contains_window_function,
78    find_ancestor,
79    find_parent,
80    get_all_tables,
81    get_columns,
82    get_merge_source,
83    get_merge_target,
84    get_tables,
85    is_add,
86    is_aggregate,
87    is_alias,
88    is_alter_table,
89    is_and,
90    is_arithmetic,
91    is_avg,
92    is_between,
93    is_boolean,
94    is_case,
95    is_cast,
96    is_coalesce,
97    is_column,
98    is_comparison,
99    is_concat,
100    is_count,
101    is_create_index,
102    is_create_table,
103    is_create_view,
104    is_cte,
105    is_ddl,
106    is_delete,
107    is_div,
108    is_drop_index,
109    is_drop_table,
110    is_drop_view,
111    is_eq,
112    is_except,
113    is_exists,
114    is_from,
115    is_function,
116    is_group_by,
117    is_gt,
118    is_gte,
119    is_having,
120    is_identifier,
121    is_ilike,
122    is_in,
123    // Extended type predicates
124    is_insert,
125    is_intersect,
126    is_is_null,
127    is_join,
128    is_like,
129    is_limit,
130    is_literal,
131    is_logical,
132    is_lt,
133    is_lte,
134    is_max_func,
135    is_merge,
136    is_min_func,
137    is_mod,
138    is_mul,
139    is_neq,
140    is_not,
141    is_null_if,
142    is_null_literal,
143    is_offset,
144    is_or,
145    is_order_by,
146    is_ordered,
147    is_paren,
148    // Composite predicates
149    is_query,
150    is_safe_cast,
151    is_select,
152    is_set_operation,
153    is_star,
154    is_sub,
155    is_subquery,
156    is_sum,
157    is_table,
158    is_try_cast,
159    is_union,
160    is_update,
161    is_where,
162    is_window_function,
163    is_with,
164    transform,
165    transform_map,
166    BfsIter,
167    DfsIter,
168    ExpressionWalk,
169    ParentInfo,
170    TreeContext,
171};
172pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
173pub use validation::{
174    mapping_schema_from_validation_schema, validate_with_schema, SchemaColumn,
175    SchemaColumnReference, SchemaForeignKey, SchemaTable, SchemaTableReference,
176    SchemaValidationOptions, ValidationSchema,
177};
178
/// Default cap on the SQL input accepted for formatting: 16 MiB.
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 << 20;
/// Default cap on the number of tokens produced by tokenization.
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
/// Default cap on the total AST node count across all parsed statements.
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
/// Default cap on set-operation keywords counted within one `;`-separated statement.
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;

// `#[serde(default = "...")]` needs a function path rather than a constant, so
// each guard default is exposed through a tiny wrapper that returns it enabled.

/// Serde/`Default` value for `FormatGuardOptions::max_input_bytes`.
fn default_format_max_input_bytes() -> Option<usize> {
    let limit = DEFAULT_FORMAT_MAX_INPUT_BYTES;
    Some(limit)
}

/// Serde/`Default` value for `FormatGuardOptions::max_tokens`.
fn default_format_max_tokens() -> Option<usize> {
    let limit = DEFAULT_FORMAT_MAX_TOKENS;
    Some(limit)
}

/// Serde/`Default` value for `FormatGuardOptions::max_ast_nodes`.
fn default_format_max_ast_nodes() -> Option<usize> {
    let limit = DEFAULT_FORMAT_MAX_AST_NODES;
    Some(limit)
}

/// Serde/`Default` value for `FormatGuardOptions::max_set_op_chain`.
fn default_format_max_set_op_chain() -> Option<usize> {
    let limit = DEFAULT_FORMAT_MAX_SET_OP_CHAIN;
    Some(limit)
}
199
200/// Guard options for SQL pretty-formatting.
201///
202/// These limits protect against extremely large/complex queries that can cause
203/// high memory pressure in constrained runtimes (for example browser WASM).
204#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
205#[serde(rename_all = "camelCase")]
206pub struct FormatGuardOptions {
207    /// Maximum allowed SQL input size in bytes.
208    /// `None` disables this check.
209    #[serde(default = "default_format_max_input_bytes")]
210    pub max_input_bytes: Option<usize>,
211    /// Maximum allowed number of tokens after tokenization.
212    /// `None` disables this check.
213    #[serde(default = "default_format_max_tokens")]
214    pub max_tokens: Option<usize>,
215    /// Maximum allowed AST node count after parsing.
216    /// `None` disables this check.
217    #[serde(default = "default_format_max_ast_nodes")]
218    pub max_ast_nodes: Option<usize>,
219    /// Maximum allowed count of set-operation operators (`UNION`/`INTERSECT`/`EXCEPT`)
220    /// observed in a statement before parsing.
221    ///
222    /// `None` disables this check.
223    #[serde(default = "default_format_max_set_op_chain")]
224    pub max_set_op_chain: Option<usize>,
225}
226
227impl Default for FormatGuardOptions {
228    fn default() -> Self {
229        Self {
230            max_input_bytes: default_format_max_input_bytes(),
231            max_tokens: default_format_max_tokens(),
232            max_ast_nodes: default_format_max_ast_nodes(),
233            max_set_op_chain: default_format_max_set_op_chain(),
234        }
235    }
236}
237
238fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
239    Error::generate(format!(
240        "{code}: value {actual} exceeds configured limit {limit}"
241    ))
242}
243
244fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
245    if let Some(max) = options.max_input_bytes {
246        let input_bytes = sql.len();
247        if input_bytes > max {
248            return Err(format_guard_error(
249                "E_GUARD_INPUT_TOO_LARGE",
250                input_bytes,
251                max,
252            ));
253        }
254    }
255    Ok(())
256}
257
258fn parse_with_token_guard(
259    sql: &str,
260    dialect: &Dialect,
261    options: &FormatGuardOptions,
262) -> Result<Vec<Expression>> {
263    let tokens = dialect.tokenize(sql)?;
264    if let Some(max) = options.max_tokens {
265        let token_count = tokens.len();
266        if token_count > max {
267            return Err(format_guard_error(
268                "E_GUARD_TOKEN_BUDGET_EXCEEDED",
269                token_count,
270                max,
271            ));
272        }
273    }
274    enforce_set_op_chain_guard(&tokens, options)?;
275
276    let config = crate::parser::ParserConfig {
277        dialect: Some(dialect.dialect_type()),
278        ..Default::default()
279    };
280    let mut parser = Parser::with_source(tokens, config, sql.to_string());
281    parser.parse()
282}
283
284fn is_trivia_token(token_type: TokenType) -> bool {
285    matches!(
286        token_type,
287        TokenType::Space | TokenType::Break | TokenType::LineComment | TokenType::BlockComment
288    )
289}
290
291fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
292    tokens
293        .iter()
294        .skip(start)
295        .find(|token| !is_trivia_token(token.token_type))
296}
297
298fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
299    let token = &tokens[idx];
300    match token.token_type {
301        TokenType::Union | TokenType::Intersect => true,
302        TokenType::Except => {
303            // MINUS is aliased to EXCEPT in the tokenizer, but in ClickHouse minus(...)
304            // is a function call rather than a set operation.
305            if token.text.eq_ignore_ascii_case("minus")
306                && matches!(
307                    next_significant_token(tokens, idx + 1).map(|t| t.token_type),
308                    Some(TokenType::LParen)
309                )
310            {
311                return false;
312            }
313            true
314        }
315        _ => false,
316    }
317}
318
319fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
320    let Some(max) = options.max_set_op_chain else {
321        return Ok(());
322    };
323
324    let mut set_op_count = 0usize;
325    for (idx, token) in tokens.iter().enumerate() {
326        if token.token_type == TokenType::Semicolon {
327            set_op_count = 0;
328            continue;
329        }
330
331        if is_set_operation_token(tokens, idx) {
332            set_op_count += 1;
333            if set_op_count > max {
334                return Err(format_guard_error(
335                    "E_GUARD_SET_OP_CHAIN_EXCEEDED",
336                    set_op_count,
337                    max,
338                ));
339            }
340        }
341    }
342
343    Ok(())
344}
345
346fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
347    if let Some(max) = options.max_ast_nodes {
348        let ast_nodes: usize = expressions.iter().map(node_count).sum();
349        if ast_nodes > max {
350            return Err(format_guard_error(
351                "E_GUARD_AST_BUDGET_EXCEEDED",
352                ast_nodes,
353                max,
354            ));
355        }
356    }
357    Ok(())
358}
359
360fn format_with_dialect(
361    sql: &str,
362    dialect: &Dialect,
363    options: &FormatGuardOptions,
364) -> Result<Vec<String>> {
365    enforce_input_guard(sql, options)?;
366    let expressions = parse_with_token_guard(sql, dialect, options)?;
367    enforce_ast_guard(&expressions, options)?;
368
369    expressions
370        .iter()
371        .map(|expr| dialect.generate_pretty(expr))
372        .collect()
373}
374
375/// Transpile SQL from one dialect to another.
376///
377/// # Arguments
378/// * `sql` - The SQL string to transpile
379/// * `read` - The source dialect to parse with
380/// * `write` - The target dialect to generate
381///
382/// # Returns
383/// A vector of transpiled SQL statements
384///
385/// # Example
386/// ```
387/// use polyglot_sql::{transpile, DialectType};
388///
389/// let result = transpile(
390///     "SELECT EPOCH_MS(1618088028295)",
391///     DialectType::DuckDB,
392///     DialectType::Hive
393/// );
394/// ```
395pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
396    // Delegate to Dialect::transpile so that the full cross-dialect rewrite
397    // pipeline (source+target-aware normalization in `cross_dialect_normalize`)
398    // runs here as well. This keeps Rust crate users on the same code path as
399    // the WASM/FFI/Python bindings and the playground.
400    Dialect::get(read).transpile(sql, write)
401}
402
403/// Parse SQL into an AST.
404///
405/// # Arguments
406/// * `sql` - The SQL string to parse
407/// * `dialect` - The dialect to use for parsing
408///
409/// # Returns
410/// A vector of parsed expressions
411pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
412    let d = Dialect::get(dialect);
413    d.parse(sql)
414}
415
416/// Parse a single SQL statement.
417///
418/// # Arguments
419/// * `sql` - The SQL string containing a single statement
420/// * `dialect` - The dialect to use for parsing
421///
422/// # Returns
423/// The parsed expression, or an error if multiple statements found
424pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
425    let mut expressions = parse(sql, dialect)?;
426
427    if expressions.len() != 1 {
428        return Err(Error::parse(
429            format!("Expected 1 statement, found {}", expressions.len()),
430            0,
431            0,
432            0,
433            0,
434        ));
435    }
436
437    Ok(expressions.remove(0))
438}
439
440/// Generate SQL from an AST.
441///
442/// # Arguments
443/// * `expression` - The expression to generate SQL from
444/// * `dialect` - The target dialect
445///
446/// # Returns
447/// The generated SQL string
448pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
449    let d = Dialect::get(dialect);
450    d.generate(expression)
451}
452
453/// Format/pretty-print SQL statements.
454///
455/// Uses [`FormatGuardOptions::default`] guards.
456pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
457    format_with_options(sql, dialect, &FormatGuardOptions::default())
458}
459
460/// Format/pretty-print SQL statements with configurable guard limits.
461pub fn format_with_options(
462    sql: &str,
463    dialect: DialectType,
464    options: &FormatGuardOptions,
465) -> Result<Vec<String>> {
466    let d = Dialect::get(dialect);
467    format_with_dialect(sql, &d, options)
468}
469
470/// Validate SQL syntax.
471///
472/// # Arguments
473/// * `sql` - The SQL string to validate
474/// * `dialect` - The dialect to use for validation
475///
476/// # Returns
477/// A validation result with any errors found
478pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
479    validate_with_options(sql, dialect, &ValidationOptions::default())
480}
481
482/// Options for syntax validation behavior.
483#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
484#[serde(rename_all = "camelCase")]
485pub struct ValidationOptions {
486    /// When enabled, validation rejects non-canonical trailing commas that the parser
487    /// would otherwise accept for compatibility (e.g. `SELECT a, FROM t`).
488    #[serde(default)]
489    pub strict_syntax: bool,
490}
491
492/// Validate SQL syntax with additional validation options.
493pub fn validate_with_options(
494    sql: &str,
495    dialect: DialectType,
496    options: &ValidationOptions,
497) -> ValidationResult {
498    let d = Dialect::get(dialect);
499    match d.parse(sql) {
500        Ok(expressions) => {
501            // Reject bare expressions that aren't valid SQL statements.
502            // The parser accepts any expression at the top level, but bare identifiers,
503            // literals, function calls, etc. are not valid statements.
504            for expr in &expressions {
505                if !expr.is_statement() {
506                    let msg = format!("Invalid expression / Unexpected token");
507                    return ValidationResult::with_errors(vec![ValidationError::error(
508                        msg, "E004",
509                    )]);
510                }
511            }
512            if options.strict_syntax {
513                if let Some(error) = strict_syntax_error(sql, &d) {
514                    return ValidationResult::with_errors(vec![error]);
515                }
516            }
517            ValidationResult::success()
518        }
519        Err(e) => {
520            let error = match &e {
521                Error::Syntax {
522                    message,
523                    line,
524                    column,
525                    start,
526                    end,
527                } => ValidationError::error(message.clone(), "E001")
528                    .with_location(*line, *column)
529                    .with_span(Some(*start), Some(*end)),
530                Error::Tokenize {
531                    message,
532                    line,
533                    column,
534                    start,
535                    end,
536                } => ValidationError::error(message.clone(), "E002")
537                    .with_location(*line, *column)
538                    .with_span(Some(*start), Some(*end)),
539                Error::Parse {
540                    message,
541                    line,
542                    column,
543                    start,
544                    end,
545                } => ValidationError::error(message.clone(), "E003")
546                    .with_location(*line, *column)
547                    .with_span(Some(*start), Some(*end)),
548                _ => ValidationError::error(e.to_string(), "E000"),
549            };
550            ValidationResult::with_errors(vec![error])
551        }
552    }
553}
554
555fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
556    let tokens = dialect.tokenize(sql).ok()?;
557
558    for (idx, token) in tokens.iter().enumerate() {
559        if token.token_type != TokenType::Comma {
560            continue;
561        }
562
563        let next = tokens.get(idx + 1);
564        let (is_boundary, boundary_name) = match next.map(|t| t.token_type) {
565            Some(TokenType::From) => (true, "FROM"),
566            Some(TokenType::Where) => (true, "WHERE"),
567            Some(TokenType::GroupBy) => (true, "GROUP BY"),
568            Some(TokenType::Having) => (true, "HAVING"),
569            Some(TokenType::Order) | Some(TokenType::OrderBy) => (true, "ORDER BY"),
570            Some(TokenType::Limit) => (true, "LIMIT"),
571            Some(TokenType::Offset) => (true, "OFFSET"),
572            Some(TokenType::Union) => (true, "UNION"),
573            Some(TokenType::Intersect) => (true, "INTERSECT"),
574            Some(TokenType::Except) => (true, "EXCEPT"),
575            Some(TokenType::Qualify) => (true, "QUALIFY"),
576            Some(TokenType::Window) => (true, "WINDOW"),
577            Some(TokenType::Semicolon) | None => (true, "end of statement"),
578            _ => (false, ""),
579        };
580
581        if is_boundary {
582            let message = format!(
583                "Trailing comma before {} is not allowed in strict syntax mode",
584                boundary_name
585            );
586            return Some(
587                ValidationError::error(message, "E005")
588                    .with_location(token.span.line, token.span.column),
589            );
590        }
591    }
592
593    None
594}
595
596/// Transpile SQL from one dialect to another, using string dialect names.
597///
598/// This supports both built-in dialect names (e.g., "postgresql", "mysql") and
599/// custom dialects registered via [`CustomDialectBuilder`].
600///
601/// # Arguments
602/// * `sql` - The SQL string to transpile
603/// * `read` - The source dialect name
604/// * `write` - The target dialect name
605///
606/// # Returns
607/// A vector of transpiled SQL statements, or an error if a dialect name is unknown.
608pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
609    transpile_with_by_name(sql, read, write, &TranspileOptions::default())
610}
611
612/// Transpile SQL with configurable [`TranspileOptions`], using string dialect names.
613///
614/// Same as [`transpile_by_name`] but accepts options (e.g., pretty-printing).
615pub fn transpile_with_by_name(
616    sql: &str,
617    read: &str,
618    write: &str,
619    opts: &TranspileOptions,
620) -> Result<Vec<String>> {
621    let read_dialect = Dialect::get_by_name(read)
622        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0, 0, 0))?;
623    let write_dialect = Dialect::get_by_name(write)
624        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0, 0, 0))?;
625    read_dialect.transpile_with(sql, &write_dialect, opts.clone())
626}
627
628/// Parse SQL into an AST using a string dialect name.
629///
630/// Supports both built-in and custom dialect names.
631pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
632    let d = Dialect::get_by_name(dialect)
633        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
634    d.parse(sql)
635}
636
637/// Generate SQL from an AST using a string dialect name.
638///
639/// Supports both built-in and custom dialect names.
640pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
641    let d = Dialect::get_by_name(dialect)
642        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
643    d.generate(expression)
644}
645
646/// Format SQL using a string dialect name.
647///
648/// Uses [`FormatGuardOptions::default`] guards.
649pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
650    format_with_options_by_name(sql, dialect, &FormatGuardOptions::default())
651}
652
653/// Format SQL using a string dialect name with configurable guard limits.
654pub fn format_with_options_by_name(
655    sql: &str,
656    dialect: &str,
657    options: &FormatGuardOptions,
658) -> Result<Vec<String>> {
659    let d = Dialect::get_by_name(dialect)
660        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
661    format_with_dialect(sql, &d, options)
662}
663
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// Options with strict trailing-comma checking switched on.
    fn strict_options() -> ValidationOptions {
        ValidationOptions {
            strict_syntax: true,
        }
    }

    /// Assert that `sql` fails strict validation with at least one `E005` error.
    fn assert_strict_rejects(sql: &str) {
        let result = validate_with_options(sql, DialectType::Generic, &strict_options());
        assert!(!result.valid, "Result should be invalid");
        let has_e005 = result.errors.iter().any(|e| e.code == "E005");
        assert!(has_e005, "Expected E005, got: {:?}", result.errors);
    }

    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let outcome = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(outcome.valid, "Result: {:?}", outcome.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        assert_strict_rejects("SELECT name, FROM employees");
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        assert_strict_rejects("SELECT name FROM employees, WHERE salary > 10");
    }
}
710
// Tests for the pretty-formatting entry points and each `FormatGuardOptions`
// limit, plus regression checks that the public `transpile*` entry points
// stay on the same code path as `Dialect::transpile`.
#[cfg(test)]
mod format_tests {
    use super::*;

    // Default guards must not interfere with a small query; pretty output is multi-line.
    #[test]
    fn format_basic_query() {
        let result = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
        assert_eq!(result.len(), 1);
        assert!(result[0].contains('\n'));
    }

    // "SELECT 1" is 8 bytes, one over the 7-byte limit; other guards are disabled.
    #[test]
    fn format_guard_rejects_large_input() {
        let options = FormatGuardOptions {
            max_input_bytes: Some(7),
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: None,
        };
        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_INPUT_TOO_LARGE"));
    }

    // A one-token budget is exceeded by any real statement.
    #[test]
    fn format_guard_rejects_token_budget() {
        let options = FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: Some(1),
            max_ast_nodes: None,
            max_set_op_chain: None,
        };
        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_TOKEN_BUDGET_EXCEEDED"));
    }

    // A one-node AST budget is exceeded by any real statement.
    #[test]
    fn format_guard_rejects_ast_budget() {
        let options = FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: Some(1),
            max_set_op_chain: None,
        };
        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_AST_BUDGET_EXCEEDED"));
    }

    // Two UNION operators in one statement exceed a limit of 1.
    #[test]
    fn format_guard_rejects_set_op_chain_budget() {
        let options = FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: Some(1),
        };
        let err = format_with_options(
            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
            DialectType::Generic,
            &options,
        )
        .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }

    // With a limit of 0, any counted set operator would fail; `minus(...)` must not count.
    #[test]
    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
        let options = FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: Some(0),
        };
        let result = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
        assert!(result.is_ok(), "Result: {:?}", result);
    }

    #[test]
    fn issue57_invalid_ternary_returns_error() {
        // https://github.com/tobilg/polyglot/issues/57
        // Invalid SQL with ternary operator should return an error, not garbled output.
        let sql = "SELECT x > 0 ? 1 : 0 FROM t";

        let parse_result = parse(sql, DialectType::PostgreSQL);
        assert!(
            parse_result.is_err(),
            "Expected parse error for invalid ternary SQL, got: {:?}",
            parse_result
        );

        let format_result = format(sql, DialectType::PostgreSQL);
        assert!(
            format_result.is_err(),
            "Expected format error for invalid ternary SQL, got: {:?}",
            format_result
        );

        let transpile_result = transpile(sql, DialectType::PostgreSQL, DialectType::PostgreSQL);
        assert!(
            transpile_result.is_err(),
            "Expected transpile error for invalid ternary SQL, got: {:?}",
            transpile_result
        );
    }

    /// Regression guard: `lib::transpile()` must apply the full cross-dialect
    /// rewrite pipeline (same as `Dialect::transpile()`). If these two paths
    /// diverge again, Rust crate users silently get under-transformed SQL that
    /// differs from what WASM/FFI/Python bindings produce.
    #[test]
    fn transpile_applies_cross_dialect_rewrites() {
        // DuckDB to_timestamp → Trino FROM_UNIXTIME (different input semantics).
        let out = transpile(
            "SELECT to_timestamp(col) FROM t",
            DialectType::DuckDB,
            DialectType::Trino,
        )
        .expect("transpile failed");
        assert_eq!(out[0], "SELECT FROM_UNIXTIME(col) FROM t");

        // DuckDB CAST(x AS JSON) → Trino JSON_PARSE(x) (different CAST semantics).
        let out = transpile(
            "SELECT CAST(col AS JSON) FROM t",
            DialectType::DuckDB,
            DialectType::Trino,
        )
        .expect("transpile failed");
        assert_eq!(out[0], "SELECT JSON_PARSE(col) FROM t");
    }

    /// Regression guard: all three transpile entry points (lib::transpile,
    /// lib::transpile_by_name, Dialect::transpile) must produce identical
    /// output. transpile_by_name is the one used by Python and C FFI bindings.
    #[test]
    fn transpile_matches_dialect_method() {
        // Each case: (read type, write type, read name, write name, SQL input).
        let cases: &[(DialectType, DialectType, &str, &str, &str)] = &[
            (DialectType::DuckDB, DialectType::Trino, "duckdb", "trino",
             "SELECT to_timestamp(col) FROM t"),
            (DialectType::DuckDB, DialectType::Trino, "duckdb", "trino",
             "SELECT CAST(col AS JSON) FROM t"),
            (DialectType::DuckDB, DialectType::Trino, "duckdb", "trino",
             "SELECT json_valid(col) FROM t"),
            (DialectType::Snowflake, DialectType::DuckDB, "snowflake", "duckdb",
             "SELECT DATEDIFF(day, a, b) FROM t"),
            (DialectType::BigQuery, DialectType::DuckDB, "bigquery", "duckdb",
             "SELECT DATE_DIFF(a, b, DAY) FROM t"),
            (DialectType::Generic, DialectType::Generic, "generic", "generic", "SELECT 1"),
        ];
        for (read, write, read_name, write_name, sql) in cases {
            let via_lib = transpile(sql, *read, *write).expect("lib::transpile failed");
            let via_name = transpile_by_name(sql, read_name, write_name)
                .expect("lib::transpile_by_name failed");
            let via_dialect = Dialect::get(*read)
                .transpile(sql, *write)
                .expect("Dialect::transpile failed");
            assert_eq!(
                via_lib, via_dialect,
                "lib::transpile / Dialect::transpile diverged for {:?} -> {:?}: {sql}",
                read, write
            );
            assert_eq!(
                via_name, via_dialect,
                "lib::transpile_by_name / Dialect::transpile diverged for {read_name} -> {write_name}: {sql}"
            );
        }
    }

    // 1100 chained UNION ALLs exceed the default set-op limit. The guard must
    // reject the input from the token stream, before the parser builds the deep
    // chain.
    #[test]
    fn format_default_guard_rejects_deep_union_chain_before_parse() {
        let base = "SELECT col0, col1 FROM t";
        let mut sql = base.to_string();
        for _ in 0..1100 {
            sql.push_str(" UNION ALL ");
            sql.push_str(base);
        }

        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }
}
888
889        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
890        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
891    }
892}