Skip to main content

polyglot_sql/
lib.rs

1//! Polyglot Core - SQL parsing and dialect translation library
2//!
3//! This library provides the core functionality for parsing SQL statements,
4//! building an abstract syntax tree (AST), and generating SQL in different dialects.
5//!
6//! # Architecture
7//!
8//! The library follows a pipeline architecture:
9//! 1. **Tokenizer** - Converts SQL string to token stream
10//! 2. **Parser** - Builds AST from tokens
11//! 3. **Generator** - Converts AST back to SQL string
12//!
13//! Each stage can be customized per dialect.
14
15pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21pub mod function_catalog;
22mod function_registry;
23pub mod generator;
24pub mod helper;
25pub mod lineage;
26pub mod optimizer;
27pub mod parser;
28pub mod planner;
29pub mod resolver;
30pub mod schema;
31pub mod scope;
32pub mod time;
33pub mod tokens;
34pub mod transforms;
35pub mod traversal;
36pub mod trie;
37pub mod validation;
38
39use serde::{Deserialize, Serialize};
40
41pub use ast_transforms::{
42    add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
43    get_identifiers, get_literals, get_output_column_names, get_subqueries, get_table_names,
44    get_window_functions, node_count, qualify_columns, remove_limit_offset, remove_nodes,
45    remove_select_columns, remove_where, rename_columns, rename_tables, replace_by_type,
46    replace_nodes, set_distinct, set_limit, set_offset,
47};
48pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
49pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
50pub use expressions::Expression;
51pub use function_catalog::{
52    FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
53};
54pub use generator::Generator;
55pub use helper::{
56    csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
57    name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
58};
59pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
60pub use parser::Parser;
61pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
62pub use schema::{
63    ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
64};
65pub use scope::{
66    build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
67    ScopeType, SourceInfo,
68};
69pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
70pub use tokens::{Token, TokenType, Tokenizer};
71pub use traversal::{
72    contains_aggregate,
73    contains_subquery,
74    contains_window_function,
75    find_ancestor,
76    find_parent,
77    get_all_tables,
78    get_columns,
79    get_merge_source,
80    get_merge_target,
81    get_tables,
82    is_add,
83    is_aggregate,
84    is_alias,
85    is_alter_table,
86    is_and,
87    is_arithmetic,
88    is_avg,
89    is_between,
90    is_boolean,
91    is_case,
92    is_cast,
93    is_coalesce,
94    is_column,
95    is_comparison,
96    is_concat,
97    is_count,
98    is_create_index,
99    is_create_table,
100    is_create_view,
101    is_cte,
102    is_ddl,
103    is_delete,
104    is_div,
105    is_drop_index,
106    is_drop_table,
107    is_drop_view,
108    is_eq,
109    is_except,
110    is_exists,
111    is_from,
112    is_function,
113    is_group_by,
114    is_gt,
115    is_gte,
116    is_having,
117    is_identifier,
118    is_ilike,
119    is_in,
120    // Extended type predicates
121    is_insert,
122    is_intersect,
123    is_is_null,
124    is_join,
125    is_like,
126    is_limit,
127    is_literal,
128    is_logical,
129    is_lt,
130    is_lte,
131    is_max_func,
132    is_merge,
133    is_min_func,
134    is_mod,
135    is_mul,
136    is_neq,
137    is_not,
138    is_null_if,
139    is_null_literal,
140    is_offset,
141    is_or,
142    is_order_by,
143    is_ordered,
144    is_paren,
145    // Composite predicates
146    is_query,
147    is_safe_cast,
148    is_select,
149    is_set_operation,
150    is_star,
151    is_sub,
152    is_subquery,
153    is_sum,
154    is_table,
155    is_try_cast,
156    is_union,
157    is_update,
158    is_where,
159    is_window_function,
160    is_with,
161    transform,
162    transform_map,
163    BfsIter,
164    DfsIter,
165    ExpressionWalk,
166    ParentInfo,
167    TreeContext,
168};
169pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
170pub use validation::{
171    mapping_schema_from_validation_schema, validate_with_schema, SchemaColumn,
172    SchemaColumnReference, SchemaForeignKey, SchemaTable, SchemaTableReference,
173    SchemaValidationOptions, ValidationSchema,
174};
175
/// Default ceiling for the raw SQL input size handed to the formatter.
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 << 20; // 16 MiB (16 * 2^20 bytes)
/// Default ceiling for the number of tokens produced by tokenization.
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
/// Default ceiling for the number of AST nodes produced by parsing.
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
/// Default ceiling for set-operation operators counted within one statement.
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;

/// Serde default provider for `FormatGuardOptions::max_input_bytes`.
fn default_format_max_input_bytes() -> Option<usize> {
    let limit = DEFAULT_FORMAT_MAX_INPUT_BYTES;
    Some(limit)
}

/// Serde default provider for `FormatGuardOptions::max_tokens`.
fn default_format_max_tokens() -> Option<usize> {
    let limit = DEFAULT_FORMAT_MAX_TOKENS;
    Some(limit)
}

/// Serde default provider for `FormatGuardOptions::max_ast_nodes`.
fn default_format_max_ast_nodes() -> Option<usize> {
    let limit = DEFAULT_FORMAT_MAX_AST_NODES;
    Some(limit)
}

/// Serde default provider for `FormatGuardOptions::max_set_op_chain`.
fn default_format_max_set_op_chain() -> Option<usize> {
    let limit = DEFAULT_FORMAT_MAX_SET_OP_CHAIN;
    Some(limit)
}
196
/// Guard options for SQL pretty-formatting.
///
/// These limits protect against extremely large/complex queries that can cause
/// high memory pressure in constrained runtimes (for example browser WASM).
///
/// Each limit is optional: `Some(n)` enforces the limit, `None` switches the
/// corresponding check off. Field names are serialized in camelCase, and a
/// field that is absent from the serialized input is filled in by its
/// `default_format_*` provider function.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct FormatGuardOptions {
    /// Maximum allowed SQL input size in bytes.
    /// `None` disables this check.
    #[serde(default = "default_format_max_input_bytes")]
    pub max_input_bytes: Option<usize>,
    /// Maximum allowed number of tokens after tokenization.
    /// `None` disables this check.
    #[serde(default = "default_format_max_tokens")]
    pub max_tokens: Option<usize>,
    /// Maximum allowed AST node count after parsing.
    /// `None` disables this check.
    #[serde(default = "default_format_max_ast_nodes")]
    pub max_ast_nodes: Option<usize>,
    /// Maximum allowed count of set-operation operators (`UNION`/`INTERSECT`/`EXCEPT`)
    /// observed in a statement before parsing.
    ///
    /// The count is taken on the token stream and starts over at every `;`.
    ///
    /// `None` disables this check.
    #[serde(default = "default_format_max_set_op_chain")]
    pub max_set_op_chain: Option<usize>,
}
223
224impl Default for FormatGuardOptions {
225    fn default() -> Self {
226        Self {
227            max_input_bytes: default_format_max_input_bytes(),
228            max_tokens: default_format_max_tokens(),
229            max_ast_nodes: default_format_max_ast_nodes(),
230            max_set_op_chain: default_format_max_set_op_chain(),
231        }
232    }
233}
234
235fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
236    Error::generate(format!(
237        "{code}: value {actual} exceeds configured limit {limit}"
238    ))
239}
240
241fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
242    if let Some(max) = options.max_input_bytes {
243        let input_bytes = sql.len();
244        if input_bytes > max {
245            return Err(format_guard_error(
246                "E_GUARD_INPUT_TOO_LARGE",
247                input_bytes,
248                max,
249            ));
250        }
251    }
252    Ok(())
253}
254
255fn parse_with_token_guard(
256    sql: &str,
257    dialect: &Dialect,
258    options: &FormatGuardOptions,
259) -> Result<Vec<Expression>> {
260    let tokens = dialect.tokenize(sql)?;
261    if let Some(max) = options.max_tokens {
262        let token_count = tokens.len();
263        if token_count > max {
264            return Err(format_guard_error(
265                "E_GUARD_TOKEN_BUDGET_EXCEEDED",
266                token_count,
267                max,
268            ));
269        }
270    }
271    enforce_set_op_chain_guard(&tokens, options)?;
272
273    let config = crate::parser::ParserConfig {
274        dialect: Some(dialect.dialect_type()),
275        ..Default::default()
276    };
277    let mut parser = Parser::with_source(tokens, config, sql.to_string());
278    parser.parse()
279}
280
281fn is_trivia_token(token_type: TokenType) -> bool {
282    matches!(
283        token_type,
284        TokenType::Space | TokenType::Break | TokenType::LineComment | TokenType::BlockComment
285    )
286}
287
288fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
289    tokens
290        .iter()
291        .skip(start)
292        .find(|token| !is_trivia_token(token.token_type))
293}
294
295fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
296    let token = &tokens[idx];
297    match token.token_type {
298        TokenType::Union | TokenType::Intersect => true,
299        TokenType::Except => {
300            // MINUS is aliased to EXCEPT in the tokenizer, but in ClickHouse minus(...)
301            // is a function call rather than a set operation.
302            if token.text.eq_ignore_ascii_case("minus")
303                && matches!(
304                    next_significant_token(tokens, idx + 1).map(|t| t.token_type),
305                    Some(TokenType::LParen)
306                )
307            {
308                return false;
309            }
310            true
311        }
312        _ => false,
313    }
314}
315
316fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
317    let Some(max) = options.max_set_op_chain else {
318        return Ok(());
319    };
320
321    let mut set_op_count = 0usize;
322    for (idx, token) in tokens.iter().enumerate() {
323        if token.token_type == TokenType::Semicolon {
324            set_op_count = 0;
325            continue;
326        }
327
328        if is_set_operation_token(tokens, idx) {
329            set_op_count += 1;
330            if set_op_count > max {
331                return Err(format_guard_error(
332                    "E_GUARD_SET_OP_CHAIN_EXCEEDED",
333                    set_op_count,
334                    max,
335                ));
336            }
337        }
338    }
339
340    Ok(())
341}
342
343fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
344    if let Some(max) = options.max_ast_nodes {
345        let ast_nodes: usize = expressions.iter().map(node_count).sum();
346        if ast_nodes > max {
347            return Err(format_guard_error(
348                "E_GUARD_AST_BUDGET_EXCEEDED",
349                ast_nodes,
350                max,
351            ));
352        }
353    }
354    Ok(())
355}
356
357fn format_with_dialect(
358    sql: &str,
359    dialect: &Dialect,
360    options: &FormatGuardOptions,
361) -> Result<Vec<String>> {
362    enforce_input_guard(sql, options)?;
363    let expressions = parse_with_token_guard(sql, dialect, options)?;
364    enforce_ast_guard(&expressions, options)?;
365
366    expressions
367        .iter()
368        .map(|expr| dialect.generate_pretty(expr))
369        .collect()
370}
371
372/// Transpile SQL from one dialect to another.
373///
374/// # Arguments
375/// * `sql` - The SQL string to transpile
376/// * `read` - The source dialect to parse with
377/// * `write` - The target dialect to generate
378///
379/// # Returns
380/// A vector of transpiled SQL statements
381///
382/// # Example
383/// ```
384/// use polyglot_sql::{transpile, DialectType};
385///
386/// let result = transpile(
387///     "SELECT EPOCH_MS(1618088028295)",
388///     DialectType::DuckDB,
389///     DialectType::Hive
390/// );
391/// ```
392pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
393    let read_dialect = Dialect::get(read);
394    let write_dialect = Dialect::get(write);
395    let generic_identity = read == DialectType::Generic && write == DialectType::Generic;
396
397    let expressions = read_dialect.parse(sql)?;
398
399    expressions
400        .into_iter()
401        .map(|expr| {
402            if generic_identity {
403                write_dialect.generate_with_source(&expr, read)
404            } else {
405                let transformed = write_dialect.transform(expr)?;
406                write_dialect.generate_with_source(&transformed, read)
407            }
408        })
409        .collect()
410}
411
412/// Parse SQL into an AST.
413///
414/// # Arguments
415/// * `sql` - The SQL string to parse
416/// * `dialect` - The dialect to use for parsing
417///
418/// # Returns
419/// A vector of parsed expressions
420pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
421    let d = Dialect::get(dialect);
422    d.parse(sql)
423}
424
425/// Parse a single SQL statement.
426///
427/// # Arguments
428/// * `sql` - The SQL string containing a single statement
429/// * `dialect` - The dialect to use for parsing
430///
431/// # Returns
432/// The parsed expression, or an error if multiple statements found
433pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
434    let mut expressions = parse(sql, dialect)?;
435
436    if expressions.len() != 1 {
437        return Err(Error::parse(
438            format!("Expected 1 statement, found {}", expressions.len()),
439            0,
440            0,
441            0,
442            0,
443        ));
444    }
445
446    Ok(expressions.remove(0))
447}
448
449/// Generate SQL from an AST.
450///
451/// # Arguments
452/// * `expression` - The expression to generate SQL from
453/// * `dialect` - The target dialect
454///
455/// # Returns
456/// The generated SQL string
457pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
458    let d = Dialect::get(dialect);
459    d.generate(expression)
460}
461
462/// Format/pretty-print SQL statements.
463///
464/// Uses [`FormatGuardOptions::default`] guards.
465pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
466    format_with_options(sql, dialect, &FormatGuardOptions::default())
467}
468
469/// Format/pretty-print SQL statements with configurable guard limits.
470pub fn format_with_options(
471    sql: &str,
472    dialect: DialectType,
473    options: &FormatGuardOptions,
474) -> Result<Vec<String>> {
475    let d = Dialect::get(dialect);
476    format_with_dialect(sql, &d, options)
477}
478
479/// Validate SQL syntax.
480///
481/// # Arguments
482/// * `sql` - The SQL string to validate
483/// * `dialect` - The dialect to use for validation
484///
485/// # Returns
486/// A validation result with any errors found
487pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
488    validate_with_options(sql, dialect, &ValidationOptions::default())
489}
490
/// Options for syntax validation behavior.
///
/// Field names are serialized in camelCase. The derived `Default` leaves
/// `strict_syntax` off.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ValidationOptions {
    /// When enabled, validation rejects non-canonical trailing commas that the parser
    /// would otherwise accept for compatibility (e.g. `SELECT a, FROM t`).
    ///
    /// Falls back to `false` when absent from the serialized input.
    #[serde(default)]
    pub strict_syntax: bool,
}
500
501/// Validate SQL syntax with additional validation options.
502pub fn validate_with_options(
503    sql: &str,
504    dialect: DialectType,
505    options: &ValidationOptions,
506) -> ValidationResult {
507    let d = Dialect::get(dialect);
508    match d.parse(sql) {
509        Ok(expressions) => {
510            // Reject bare expressions that aren't valid SQL statements.
511            // The parser accepts any expression at the top level, but bare identifiers,
512            // literals, function calls, etc. are not valid statements.
513            for expr in &expressions {
514                if !expr.is_statement() {
515                    let msg = format!("Invalid expression / Unexpected token");
516                    return ValidationResult::with_errors(vec![ValidationError::error(
517                        msg, "E004",
518                    )]);
519                }
520            }
521            if options.strict_syntax {
522                if let Some(error) = strict_syntax_error(sql, &d) {
523                    return ValidationResult::with_errors(vec![error]);
524                }
525            }
526            ValidationResult::success()
527        }
528        Err(e) => {
529            let error = match &e {
530                Error::Syntax {
531                    message,
532                    line,
533                    column,
534                    start,
535                    end,
536                } => ValidationError::error(message.clone(), "E001")
537                    .with_location(*line, *column)
538                    .with_span(Some(*start), Some(*end)),
539                Error::Tokenize {
540                    message,
541                    line,
542                    column,
543                    start,
544                    end,
545                } => ValidationError::error(message.clone(), "E002")
546                    .with_location(*line, *column)
547                    .with_span(Some(*start), Some(*end)),
548                Error::Parse {
549                    message,
550                    line,
551                    column,
552                    start,
553                    end,
554                } => ValidationError::error(message.clone(), "E003")
555                    .with_location(*line, *column)
556                    .with_span(Some(*start), Some(*end)),
557                _ => ValidationError::error(e.to_string(), "E000"),
558            };
559            ValidationResult::with_errors(vec![error])
560        }
561    }
562}
563
564fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
565    let tokens = dialect.tokenize(sql).ok()?;
566
567    for (idx, token) in tokens.iter().enumerate() {
568        if token.token_type != TokenType::Comma {
569            continue;
570        }
571
572        let next = tokens.get(idx + 1);
573        let (is_boundary, boundary_name) = match next.map(|t| t.token_type) {
574            Some(TokenType::From) => (true, "FROM"),
575            Some(TokenType::Where) => (true, "WHERE"),
576            Some(TokenType::GroupBy) => (true, "GROUP BY"),
577            Some(TokenType::Having) => (true, "HAVING"),
578            Some(TokenType::Order) | Some(TokenType::OrderBy) => (true, "ORDER BY"),
579            Some(TokenType::Limit) => (true, "LIMIT"),
580            Some(TokenType::Offset) => (true, "OFFSET"),
581            Some(TokenType::Union) => (true, "UNION"),
582            Some(TokenType::Intersect) => (true, "INTERSECT"),
583            Some(TokenType::Except) => (true, "EXCEPT"),
584            Some(TokenType::Qualify) => (true, "QUALIFY"),
585            Some(TokenType::Window) => (true, "WINDOW"),
586            Some(TokenType::Semicolon) | None => (true, "end of statement"),
587            _ => (false, ""),
588        };
589
590        if is_boundary {
591            let message = format!(
592                "Trailing comma before {} is not allowed in strict syntax mode",
593                boundary_name
594            );
595            return Some(
596                ValidationError::error(message, "E005")
597                    .with_location(token.span.line, token.span.column),
598            );
599        }
600    }
601
602    None
603}
604
605/// Transpile SQL from one dialect to another, using string dialect names.
606///
607/// This supports both built-in dialect names (e.g., "postgresql", "mysql") and
608/// custom dialects registered via [`CustomDialectBuilder`].
609///
610/// # Arguments
611/// * `sql` - The SQL string to transpile
612/// * `read` - The source dialect name
613/// * `write` - The target dialect name
614///
615/// # Returns
616/// A vector of transpiled SQL statements, or an error if a dialect name is unknown.
617pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
618    let read_dialect = Dialect::get_by_name(read)
619        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0, 0, 0))?;
620    let write_dialect = Dialect::get_by_name(write)
621        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0, 0, 0))?;
622    let generic_identity = read_dialect.dialect_type() == DialectType::Generic
623        && write_dialect.dialect_type() == DialectType::Generic;
624
625    let expressions = read_dialect.parse(sql)?;
626
627    expressions
628        .into_iter()
629        .map(|expr| {
630            if generic_identity {
631                write_dialect.generate_with_source(&expr, read_dialect.dialect_type())
632            } else {
633                let transformed = write_dialect.transform(expr)?;
634                write_dialect.generate_with_source(&transformed, read_dialect.dialect_type())
635            }
636        })
637        .collect()
638}
639
640/// Parse SQL into an AST using a string dialect name.
641///
642/// Supports both built-in and custom dialect names.
643pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
644    let d = Dialect::get_by_name(dialect)
645        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
646    d.parse(sql)
647}
648
649/// Generate SQL from an AST using a string dialect name.
650///
651/// Supports both built-in and custom dialect names.
652pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
653    let d = Dialect::get_by_name(dialect)
654        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
655    d.generate(expression)
656}
657
658/// Format SQL using a string dialect name.
659///
660/// Uses [`FormatGuardOptions::default`] guards.
661pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
662    format_with_options_by_name(sql, dialect, &FormatGuardOptions::default())
663}
664
665/// Format SQL using a string dialect name with configurable guard limits.
666pub fn format_with_options_by_name(
667    sql: &str,
668    dialect: &str,
669    options: &FormatGuardOptions,
670) -> Result<Vec<String>> {
671    let d = Dialect::get_by_name(dialect)
672        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
673    format_with_dialect(sql, &d, options)
674}
675
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// Validates `sql` against the generic dialect with strict syntax enabled.
    fn validate_strict(sql: &str) -> ValidationResult {
        let options = ValidationOptions {
            strict_syntax: true,
        };
        validate_with_options(sql, DialectType::Generic, &options)
    }

    /// Asserts that `result` is invalid and carries an `E005` error.
    fn assert_trailing_comma_rejected(result: &ValidationResult) {
        assert!(!result.valid, "Result should be invalid");
        assert!(
            result.errors.iter().any(|e| e.code == "E005"),
            "Expected E005, got: {:?}",
            result.errors
        );
    }

    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let result = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(result.valid, "Result: {:?}", result.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        let result = validate_strict("SELECT name, FROM employees");
        assert_trailing_comma_rejected(&result);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        let result = validate_strict("SELECT name FROM employees, WHERE salary > 10");
        assert_trailing_comma_rejected(&result);
    }
}
722
#[cfg(test)]
mod format_tests {
    use super::*;

    /// Guard options with every limit switched off; tests enable one at a time.
    fn no_limits() -> FormatGuardOptions {
        FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: None,
        }
    }

    /// Formats `sql` with the generic dialect and asserts the guard `code` is reported.
    fn assert_guard_rejects(sql: &str, options: &FormatGuardOptions, code: &str) {
        let err = format_with_options(sql, DialectType::Generic, options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains(code));
    }

    #[test]
    fn format_basic_query() {
        let formatted = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
        assert_eq!(formatted.len(), 1);
        assert!(formatted[0].contains('\n'));
    }

    #[test]
    fn format_guard_rejects_large_input() {
        let options = FormatGuardOptions {
            max_input_bytes: Some(7),
            ..no_limits()
        };
        assert_guard_rejects("SELECT 1", &options, "E_GUARD_INPUT_TOO_LARGE");
    }

    #[test]
    fn format_guard_rejects_token_budget() {
        let options = FormatGuardOptions {
            max_tokens: Some(1),
            ..no_limits()
        };
        assert_guard_rejects("SELECT 1", &options, "E_GUARD_TOKEN_BUDGET_EXCEEDED");
    }

    #[test]
    fn format_guard_rejects_ast_budget() {
        let options = FormatGuardOptions {
            max_ast_nodes: Some(1),
            ..no_limits()
        };
        assert_guard_rejects("SELECT 1", &options, "E_GUARD_AST_BUDGET_EXCEEDED");
    }

    #[test]
    fn format_guard_rejects_set_op_chain_budget() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(1),
            ..no_limits()
        };
        assert_guard_rejects(
            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
            &options,
            "E_GUARD_SET_OP_CHAIN_EXCEEDED",
        );
    }

    #[test]
    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
        // With a chain limit of zero, any counted set operator would fail the guard.
        let options = FormatGuardOptions {
            max_set_op_chain: Some(0),
            ..no_limits()
        };
        let result = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
        assert!(result.is_ok(), "Result: {:?}", result);
    }

    #[test]
    fn issue57_invalid_ternary_returns_error() {
        // https://github.com/tobilg/polyglot/issues/57
        // Invalid SQL with ternary operator should return an error, not garbled output.
        let sql = "SELECT x > 0 ? 1 : 0 FROM t";
        let postgres = DialectType::PostgreSQL;

        let parse_result = parse(sql, postgres);
        assert!(
            parse_result.is_err(),
            "Expected parse error for invalid ternary SQL, got: {:?}",
            parse_result
        );

        let format_result = format(sql, postgres);
        assert!(
            format_result.is_err(),
            "Expected format error for invalid ternary SQL, got: {:?}",
            format_result
        );

        let transpile_result = transpile(sql, postgres, postgres);
        assert!(
            transpile_result.is_err(),
            "Expected transpile error for invalid ternary SQL, got: {:?}",
            transpile_result
        );
    }

    #[test]
    fn format_default_guard_rejects_deep_union_chain_before_parse() {
        // 1101 copies of the base query joined by 1100 UNION ALL operators.
        let base = "SELECT col0, col1 FROM t";
        let sql = vec![base; 1101].join(" UNION ALL ");

        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }
}