Skip to main content

polyglot_sql/
lib.rs

1//! Polyglot Core - SQL parsing and dialect translation library
2//!
3//! This library provides the core functionality for parsing SQL statements,
4//! building an abstract syntax tree (AST), and generating SQL in different dialects.
5//!
6//! # Architecture
7//!
8//! The library follows a pipeline architecture:
9//! 1. **Tokenizer** - Converts SQL string to token stream
10//! 2. **Parser** - Builds AST from tokens
11//! 3. **Generator** - Converts AST back to SQL string
12//!
13//! Each stage can be customized per dialect.
14
15pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21pub mod function_catalog;
22mod function_registry;
23pub mod generator;
24pub mod helper;
25pub mod lineage;
26pub mod optimizer;
27pub mod parser;
28pub mod planner;
29pub mod resolver;
30pub mod schema;
31pub mod scope;
32pub mod time;
33pub mod tokens;
34pub mod transforms;
35pub mod traversal;
36pub mod trie;
37pub mod validation;
38
39use serde::{Deserialize, Serialize};
40
41pub use ast_transforms::{
42    add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
43    get_identifiers, get_literals, get_output_column_names, get_subqueries, get_table_names,
44    get_window_functions, node_count, qualify_columns, remove_limit_offset, remove_nodes,
45    remove_select_columns, remove_where, rename_columns, rename_tables, replace_by_type,
46    replace_nodes, set_distinct, set_limit, set_offset,
47};
48pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
49pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
50pub use expressions::Expression;
51pub use function_catalog::{
52    FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
53};
54pub use generator::Generator;
55pub use helper::{
56    csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
57    name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
58};
59pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
60pub use parser::Parser;
61pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
62pub use schema::{
63    ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
64};
65pub use scope::{
66    build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
67    ScopeType, SourceInfo,
68};
69pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
70pub use tokens::{Token, TokenType, Tokenizer};
71pub use traversal::{
72    contains_aggregate,
73    contains_subquery,
74    contains_window_function,
75    find_ancestor,
76    find_parent,
77    get_columns,
78    get_merge_source,
79    get_merge_target,
80    get_tables,
81    is_add,
82    is_aggregate,
83    is_alias,
84    is_alter_table,
85    is_and,
86    is_arithmetic,
87    is_avg,
88    is_between,
89    is_boolean,
90    is_case,
91    is_cast,
92    is_coalesce,
93    is_column,
94    is_comparison,
95    is_concat,
96    is_count,
97    is_create_index,
98    is_create_table,
99    is_create_view,
100    is_cte,
101    is_ddl,
102    is_delete,
103    is_div,
104    is_drop_index,
105    is_drop_table,
106    is_drop_view,
107    is_eq,
108    is_except,
109    is_exists,
110    is_from,
111    is_function,
112    is_group_by,
113    is_gt,
114    is_gte,
115    is_having,
116    is_identifier,
117    is_ilike,
118    is_in,
119    // Extended type predicates
120    is_insert,
121    is_intersect,
122    is_is_null,
123    is_join,
124    is_like,
125    is_limit,
126    is_literal,
127    is_logical,
128    is_merge,
129    is_lt,
130    is_lte,
131    is_max_func,
132    is_min_func,
133    is_mod,
134    is_mul,
135    is_neq,
136    is_not,
137    is_null_if,
138    is_null_literal,
139    is_offset,
140    is_or,
141    is_order_by,
142    is_ordered,
143    is_paren,
144    // Composite predicates
145    is_query,
146    is_safe_cast,
147    is_select,
148    is_set_operation,
149    is_star,
150    is_sub,
151    is_subquery,
152    is_sum,
153    is_table,
154    is_try_cast,
155    is_union,
156    is_update,
157    is_where,
158    is_window_function,
159    is_with,
160    transform,
161    transform_map,
162    BfsIter,
163    DfsIter,
164    ExpressionWalk,
165    ParentInfo,
166    TreeContext,
167};
168pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
169pub use validation::{
170    mapping_schema_from_validation_schema, validate_with_schema, SchemaColumn,
171    SchemaColumnReference, SchemaForeignKey, SchemaTable, SchemaTableReference,
172    SchemaValidationOptions, ValidationSchema,
173};
174
/// Default cap on SQL input size accepted by the formatter: 16 MiB.
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 << 20;
/// Default cap on the number of tokens produced by tokenization.
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
/// Default cap on the total AST node count across all parsed statements.
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
/// Default cap on set-operation keywords within a single statement.
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;

// `#[serde(default = "...")]` needs a function path, so each default limit is also
// exposed as a zero-argument function that returns it wrapped in `Some`.

fn default_format_max_input_bytes() -> Option<usize> {
    Option::from(DEFAULT_FORMAT_MAX_INPUT_BYTES)
}

fn default_format_max_tokens() -> Option<usize> {
    Option::from(DEFAULT_FORMAT_MAX_TOKENS)
}

fn default_format_max_ast_nodes() -> Option<usize> {
    Option::from(DEFAULT_FORMAT_MAX_AST_NODES)
}

fn default_format_max_set_op_chain() -> Option<usize> {
    Option::from(DEFAULT_FORMAT_MAX_SET_OP_CHAIN)
}
195
196/// Guard options for SQL pretty-formatting.
197///
198/// These limits protect against extremely large/complex queries that can cause
199/// high memory pressure in constrained runtimes (for example browser WASM).
200#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
201#[serde(rename_all = "camelCase")]
202pub struct FormatGuardOptions {
203    /// Maximum allowed SQL input size in bytes.
204    /// `None` disables this check.
205    #[serde(default = "default_format_max_input_bytes")]
206    pub max_input_bytes: Option<usize>,
207    /// Maximum allowed number of tokens after tokenization.
208    /// `None` disables this check.
209    #[serde(default = "default_format_max_tokens")]
210    pub max_tokens: Option<usize>,
211    /// Maximum allowed AST node count after parsing.
212    /// `None` disables this check.
213    #[serde(default = "default_format_max_ast_nodes")]
214    pub max_ast_nodes: Option<usize>,
215    /// Maximum allowed count of set-operation operators (`UNION`/`INTERSECT`/`EXCEPT`)
216    /// observed in a statement before parsing.
217    ///
218    /// `None` disables this check.
219    #[serde(default = "default_format_max_set_op_chain")]
220    pub max_set_op_chain: Option<usize>,
221}
222
223impl Default for FormatGuardOptions {
224    fn default() -> Self {
225        Self {
226            max_input_bytes: default_format_max_input_bytes(),
227            max_tokens: default_format_max_tokens(),
228            max_ast_nodes: default_format_max_ast_nodes(),
229            max_set_op_chain: default_format_max_set_op_chain(),
230        }
231    }
232}
233
234fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
235    Error::generate(format!(
236        "{code}: value {actual} exceeds configured limit {limit}"
237    ))
238}
239
240fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
241    if let Some(max) = options.max_input_bytes {
242        let input_bytes = sql.len();
243        if input_bytes > max {
244            return Err(format_guard_error(
245                "E_GUARD_INPUT_TOO_LARGE",
246                input_bytes,
247                max,
248            ));
249        }
250    }
251    Ok(())
252}
253
254fn parse_with_token_guard(
255    sql: &str,
256    dialect: &Dialect,
257    options: &FormatGuardOptions,
258) -> Result<Vec<Expression>> {
259    let tokens = dialect.tokenize(sql)?;
260    if let Some(max) = options.max_tokens {
261        let token_count = tokens.len();
262        if token_count > max {
263            return Err(format_guard_error(
264                "E_GUARD_TOKEN_BUDGET_EXCEEDED",
265                token_count,
266                max,
267            ));
268        }
269    }
270    enforce_set_op_chain_guard(&tokens, options)?;
271
272    let config = crate::parser::ParserConfig {
273        dialect: Some(dialect.dialect_type()),
274        ..Default::default()
275    };
276    let mut parser = Parser::with_source(tokens, config, sql.to_string());
277    parser.parse()
278}
279
280fn is_trivia_token(token_type: TokenType) -> bool {
281    matches!(
282        token_type,
283        TokenType::Space | TokenType::Break | TokenType::LineComment | TokenType::BlockComment
284    )
285}
286
287fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
288    tokens
289        .iter()
290        .skip(start)
291        .find(|token| !is_trivia_token(token.token_type))
292}
293
294fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
295    let token = &tokens[idx];
296    match token.token_type {
297        TokenType::Union | TokenType::Intersect => true,
298        TokenType::Except => {
299            // MINUS is aliased to EXCEPT in the tokenizer, but in ClickHouse minus(...)
300            // is a function call rather than a set operation.
301            if token.text.eq_ignore_ascii_case("minus")
302                && matches!(
303                    next_significant_token(tokens, idx + 1).map(|t| t.token_type),
304                    Some(TokenType::LParen)
305                )
306            {
307                return false;
308            }
309            true
310        }
311        _ => false,
312    }
313}
314
315fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
316    let Some(max) = options.max_set_op_chain else {
317        return Ok(());
318    };
319
320    let mut set_op_count = 0usize;
321    for (idx, token) in tokens.iter().enumerate() {
322        if token.token_type == TokenType::Semicolon {
323            set_op_count = 0;
324            continue;
325        }
326
327        if is_set_operation_token(tokens, idx) {
328            set_op_count += 1;
329            if set_op_count > max {
330                return Err(format_guard_error(
331                    "E_GUARD_SET_OP_CHAIN_EXCEEDED",
332                    set_op_count,
333                    max,
334                ));
335            }
336        }
337    }
338
339    Ok(())
340}
341
342fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
343    if let Some(max) = options.max_ast_nodes {
344        let ast_nodes: usize = expressions.iter().map(node_count).sum();
345        if ast_nodes > max {
346            return Err(format_guard_error(
347                "E_GUARD_AST_BUDGET_EXCEEDED",
348                ast_nodes,
349                max,
350            ));
351        }
352    }
353    Ok(())
354}
355
356fn format_with_dialect(
357    sql: &str,
358    dialect: &Dialect,
359    options: &FormatGuardOptions,
360) -> Result<Vec<String>> {
361    enforce_input_guard(sql, options)?;
362    let expressions = parse_with_token_guard(sql, dialect, options)?;
363    enforce_ast_guard(&expressions, options)?;
364
365    expressions
366        .iter()
367        .map(|expr| dialect.generate_pretty(expr))
368        .collect()
369}
370
371/// Transpile SQL from one dialect to another.
372///
373/// # Arguments
374/// * `sql` - The SQL string to transpile
375/// * `read` - The source dialect to parse with
376/// * `write` - The target dialect to generate
377///
378/// # Returns
379/// A vector of transpiled SQL statements
380///
381/// # Example
382/// ```
383/// use polyglot_sql::{transpile, DialectType};
384///
385/// let result = transpile(
386///     "SELECT EPOCH_MS(1618088028295)",
387///     DialectType::DuckDB,
388///     DialectType::Hive
389/// );
390/// ```
391pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
392    let read_dialect = Dialect::get(read);
393    let write_dialect = Dialect::get(write);
394    let generic_identity = read == DialectType::Generic && write == DialectType::Generic;
395
396    let expressions = read_dialect.parse(sql)?;
397
398    expressions
399        .into_iter()
400        .map(|expr| {
401            if generic_identity {
402                write_dialect.generate_with_source(&expr, read)
403            } else {
404                let transformed = write_dialect.transform(expr)?;
405                write_dialect.generate_with_source(&transformed, read)
406            }
407        })
408        .collect()
409}
410
411/// Parse SQL into an AST.
412///
413/// # Arguments
414/// * `sql` - The SQL string to parse
415/// * `dialect` - The dialect to use for parsing
416///
417/// # Returns
418/// A vector of parsed expressions
419pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
420    let d = Dialect::get(dialect);
421    d.parse(sql)
422}
423
424/// Parse a single SQL statement.
425///
426/// # Arguments
427/// * `sql` - The SQL string containing a single statement
428/// * `dialect` - The dialect to use for parsing
429///
430/// # Returns
431/// The parsed expression, or an error if multiple statements found
432pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
433    let mut expressions = parse(sql, dialect)?;
434
435    if expressions.len() != 1 {
436        return Err(Error::parse(
437            format!("Expected 1 statement, found {}", expressions.len()),
438            0,
439            0,
440            0,
441            0,
442        ));
443    }
444
445    Ok(expressions.remove(0))
446}
447
448/// Generate SQL from an AST.
449///
450/// # Arguments
451/// * `expression` - The expression to generate SQL from
452/// * `dialect` - The target dialect
453///
454/// # Returns
455/// The generated SQL string
456pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
457    let d = Dialect::get(dialect);
458    d.generate(expression)
459}
460
461/// Format/pretty-print SQL statements.
462///
463/// Uses [`FormatGuardOptions::default`] guards.
464pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
465    format_with_options(sql, dialect, &FormatGuardOptions::default())
466}
467
468/// Format/pretty-print SQL statements with configurable guard limits.
469pub fn format_with_options(
470    sql: &str,
471    dialect: DialectType,
472    options: &FormatGuardOptions,
473) -> Result<Vec<String>> {
474    let d = Dialect::get(dialect);
475    format_with_dialect(sql, &d, options)
476}
477
478/// Validate SQL syntax.
479///
480/// # Arguments
481/// * `sql` - The SQL string to validate
482/// * `dialect` - The dialect to use for validation
483///
484/// # Returns
485/// A validation result with any errors found
486pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
487    validate_with_options(sql, dialect, &ValidationOptions::default())
488}
489
490/// Options for syntax validation behavior.
491#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
492#[serde(rename_all = "camelCase")]
493pub struct ValidationOptions {
494    /// When enabled, validation rejects non-canonical trailing commas that the parser
495    /// would otherwise accept for compatibility (e.g. `SELECT a, FROM t`).
496    #[serde(default)]
497    pub strict_syntax: bool,
498}
499
500/// Validate SQL syntax with additional validation options.
501pub fn validate_with_options(
502    sql: &str,
503    dialect: DialectType,
504    options: &ValidationOptions,
505) -> ValidationResult {
506    let d = Dialect::get(dialect);
507    match d.parse(sql) {
508        Ok(expressions) => {
509            // Reject bare expressions that aren't valid SQL statements.
510            // The parser accepts any expression at the top level, but bare identifiers,
511            // literals, function calls, etc. are not valid statements.
512            for expr in &expressions {
513                if !expr.is_statement() {
514                    let msg = format!("Invalid expression / Unexpected token");
515                    return ValidationResult::with_errors(vec![ValidationError::error(
516                        msg, "E004",
517                    )]);
518                }
519            }
520            if options.strict_syntax {
521                if let Some(error) = strict_syntax_error(sql, &d) {
522                    return ValidationResult::with_errors(vec![error]);
523                }
524            }
525            ValidationResult::success()
526        }
527        Err(e) => {
528            let error = match &e {
529                Error::Syntax {
530                    message,
531                    line,
532                    column,
533                    start,
534                    end,
535                } => ValidationError::error(message.clone(), "E001")
536                    .with_location(*line, *column)
537                    .with_span(Some(*start), Some(*end)),
538                Error::Tokenize {
539                    message,
540                    line,
541                    column,
542                    start,
543                    end,
544                } => ValidationError::error(message.clone(), "E002")
545                    .with_location(*line, *column)
546                    .with_span(Some(*start), Some(*end)),
547                Error::Parse {
548                    message,
549                    line,
550                    column,
551                    start,
552                    end,
553                } => ValidationError::error(message.clone(), "E003")
554                    .with_location(*line, *column)
555                    .with_span(Some(*start), Some(*end)),
556                _ => ValidationError::error(e.to_string(), "E000"),
557            };
558            ValidationResult::with_errors(vec![error])
559        }
560    }
561}
562
563fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
564    let tokens = dialect.tokenize(sql).ok()?;
565
566    for (idx, token) in tokens.iter().enumerate() {
567        if token.token_type != TokenType::Comma {
568            continue;
569        }
570
571        let next = tokens.get(idx + 1);
572        let (is_boundary, boundary_name) = match next.map(|t| t.token_type) {
573            Some(TokenType::From) => (true, "FROM"),
574            Some(TokenType::Where) => (true, "WHERE"),
575            Some(TokenType::GroupBy) => (true, "GROUP BY"),
576            Some(TokenType::Having) => (true, "HAVING"),
577            Some(TokenType::Order) | Some(TokenType::OrderBy) => (true, "ORDER BY"),
578            Some(TokenType::Limit) => (true, "LIMIT"),
579            Some(TokenType::Offset) => (true, "OFFSET"),
580            Some(TokenType::Union) => (true, "UNION"),
581            Some(TokenType::Intersect) => (true, "INTERSECT"),
582            Some(TokenType::Except) => (true, "EXCEPT"),
583            Some(TokenType::Qualify) => (true, "QUALIFY"),
584            Some(TokenType::Window) => (true, "WINDOW"),
585            Some(TokenType::Semicolon) | None => (true, "end of statement"),
586            _ => (false, ""),
587        };
588
589        if is_boundary {
590            let message = format!(
591                "Trailing comma before {} is not allowed in strict syntax mode",
592                boundary_name
593            );
594            return Some(
595                ValidationError::error(message, "E005")
596                    .with_location(token.span.line, token.span.column),
597            );
598        }
599    }
600
601    None
602}
603
604/// Transpile SQL from one dialect to another, using string dialect names.
605///
606/// This supports both built-in dialect names (e.g., "postgresql", "mysql") and
607/// custom dialects registered via [`CustomDialectBuilder`].
608///
609/// # Arguments
610/// * `sql` - The SQL string to transpile
611/// * `read` - The source dialect name
612/// * `write` - The target dialect name
613///
614/// # Returns
615/// A vector of transpiled SQL statements, or an error if a dialect name is unknown.
616pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
617    let read_dialect = Dialect::get_by_name(read)
618        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0, 0, 0))?;
619    let write_dialect = Dialect::get_by_name(write)
620        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0, 0, 0))?;
621    let generic_identity = read_dialect.dialect_type() == DialectType::Generic
622        && write_dialect.dialect_type() == DialectType::Generic;
623
624    let expressions = read_dialect.parse(sql)?;
625
626    expressions
627        .into_iter()
628        .map(|expr| {
629            if generic_identity {
630                write_dialect.generate_with_source(&expr, read_dialect.dialect_type())
631            } else {
632                let transformed = write_dialect.transform(expr)?;
633                write_dialect.generate_with_source(&transformed, read_dialect.dialect_type())
634            }
635        })
636        .collect()
637}
638
639/// Parse SQL into an AST using a string dialect name.
640///
641/// Supports both built-in and custom dialect names.
642pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
643    let d = Dialect::get_by_name(dialect)
644        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
645    d.parse(sql)
646}
647
648/// Generate SQL from an AST using a string dialect name.
649///
650/// Supports both built-in and custom dialect names.
651pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
652    let d = Dialect::get_by_name(dialect)
653        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
654    d.generate(expression)
655}
656
657/// Format SQL using a string dialect name.
658///
659/// Uses [`FormatGuardOptions::default`] guards.
660pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
661    format_with_options_by_name(sql, dialect, &FormatGuardOptions::default())
662}
663
664/// Format SQL using a string dialect name with configurable guard limits.
665pub fn format_with_options_by_name(
666    sql: &str,
667    dialect: &str,
668    options: &FormatGuardOptions,
669) -> Result<Vec<String>> {
670    let d = Dialect::get_by_name(dialect)
671        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
672    format_with_dialect(sql, &d, options)
673}
674
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// Options with strict trailing-comma checking enabled.
    fn strict() -> ValidationOptions {
        ValidationOptions {
            strict_syntax: true,
        }
    }

    /// Asserts that strict validation of `sql` fails with an `E005` error.
    fn assert_strict_rejects_trailing_comma(sql: &str) {
        let result = validate_with_options(sql, DialectType::Generic, &strict());
        assert!(!result.valid, "Result should be invalid");
        assert!(
            result.errors.iter().any(|e| e.code == "E005"),
            "Expected E005, got: {:?}",
            result.errors
        );
    }

    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let result = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(result.valid, "Result: {:?}", result.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        assert_strict_rejects_trailing_comma("SELECT name, FROM employees");
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        assert_strict_rejects_trailing_comma("SELECT name FROM employees, WHERE salary > 10");
    }

    #[test]
    fn validate_with_options_accepts_clean_query_in_strict_mode() {
        // Commas between list items are fine; only commas directly before a clause
        // boundary are flagged.
        let result = validate_with_options(
            "SELECT name, salary FROM employees WHERE salary > 10",
            DialectType::Generic,
            &strict(),
        );
        assert!(result.valid, "Result: {:?}", result.errors);
    }
}
721
#[cfg(test)]
mod format_tests {
    use super::*;

    /// Guard options with every limit disabled; each test enables only the limit it
    /// exercises via struct-update syntax.
    fn no_guards() -> FormatGuardOptions {
        FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: None,
        }
    }

    /// Formats `sql` and returns the error text, panicking if formatting succeeds.
    fn guard_error(sql: &str, dialect: DialectType, options: &FormatGuardOptions) -> String {
        format_with_options(sql, dialect, options)
            .expect_err("expected guard error")
            .to_string()
    }

    #[test]
    fn format_basic_query() {
        let formatted = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
        assert_eq!(formatted.len(), 1);
        assert!(formatted[0].contains('\n'));
    }

    #[test]
    fn format_guard_rejects_large_input() {
        // "SELECT 1" is 8 bytes, one over the limit.
        let options = FormatGuardOptions {
            max_input_bytes: Some(7),
            ..no_guards()
        };
        let message = guard_error("SELECT 1", DialectType::Generic, &options);
        assert!(message.contains("E_GUARD_INPUT_TOO_LARGE"));
    }

    #[test]
    fn format_guard_rejects_token_budget() {
        let options = FormatGuardOptions {
            max_tokens: Some(1),
            ..no_guards()
        };
        let message = guard_error("SELECT 1", DialectType::Generic, &options);
        assert!(message.contains("E_GUARD_TOKEN_BUDGET_EXCEEDED"));
    }

    #[test]
    fn format_guard_rejects_ast_budget() {
        let options = FormatGuardOptions {
            max_ast_nodes: Some(1),
            ..no_guards()
        };
        let message = guard_error("SELECT 1", DialectType::Generic, &options);
        assert!(message.contains("E_GUARD_AST_BUDGET_EXCEEDED"));
    }

    #[test]
    fn format_guard_rejects_set_op_chain_budget() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(1),
            ..no_guards()
        };
        let message = guard_error(
            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
            DialectType::Generic,
            &options,
        );
        assert!(message.contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }

    #[test]
    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(0),
            ..no_guards()
        };
        let result = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
        assert!(result.is_ok(), "Result: {:?}", result);
    }

    #[test]
    fn issue57_invalid_ternary_returns_error() {
        // https://github.com/tobilg/polyglot/issues/57
        // A ternary operator is not valid SQL; every entry point must surface an error
        // rather than emit garbled output.
        let sql = "SELECT x > 0 ? 1 : 0 FROM t";

        let parsed = parse(sql, DialectType::PostgreSQL);
        assert!(
            parsed.is_err(),
            "Expected parse error for invalid ternary SQL, got: {:?}",
            parsed
        );

        let formatted = format(sql, DialectType::PostgreSQL);
        assert!(
            formatted.is_err(),
            "Expected format error for invalid ternary SQL, got: {:?}",
            formatted
        );

        let transpiled = transpile(sql, DialectType::PostgreSQL, DialectType::PostgreSQL);
        assert!(
            transpiled.is_err(),
            "Expected transpile error for invalid ternary SQL, got: {:?}",
            transpiled
        );
    }

    #[test]
    fn format_default_guard_rejects_deep_union_chain_before_parse() {
        // 1101 copies joined by 1100 `UNION ALL`s, far beyond the default limit of 256.
        let sql = ["SELECT col0, col1 FROM t"; 1101].join(" UNION ALL ");

        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }
}