Skip to main content

polyglot_sql/
lib.rs

1//! Polyglot Core - SQL parsing and dialect translation library
2//!
3//! This library provides the core functionality for parsing SQL statements,
4//! building an abstract syntax tree (AST), and generating SQL in different dialects.
5//!
6//! # Architecture
7//!
8//! The library follows a pipeline architecture:
9//! 1. **Tokenizer** - Converts SQL string to token stream
10//! 2. **Parser** - Builds AST from tokens
11//! 3. **Generator** - Converts AST back to SQL string
12//!
13//! Each stage can be customized per dialect.
14
15pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21pub mod function_catalog;
22mod function_registry;
23pub mod generator;
24pub mod helper;
25pub mod lineage;
26pub mod optimizer;
27pub mod parser;
28pub mod planner;
29pub mod resolver;
30pub mod schema;
31pub mod scope;
32pub mod time;
33pub mod tokens;
34pub mod transforms;
35pub mod traversal;
36pub mod trie;
37pub mod validation;
38
39use serde::{Deserialize, Serialize};
40
41pub use ast_transforms::{
42    add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
43    get_identifiers, get_literals, get_output_column_names, get_subqueries, get_table_names,
44    get_window_functions, node_count, qualify_columns, remove_limit_offset, remove_nodes,
45    remove_select_columns, remove_where, rename_columns, rename_tables, replace_by_type,
46    replace_nodes, set_distinct, set_limit, set_offset,
47};
48pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
49pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
50pub use expressions::Expression;
51pub use function_catalog::{
52    FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
53};
54pub use generator::Generator;
55pub use helper::{
56    csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
57    name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
58};
59pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
60pub use parser::Parser;
61pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
62pub use schema::{
63    ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
64};
65pub use scope::{
66    build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
67    ScopeType, SourceInfo,
68};
69pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
70pub use tokens::{Token, TokenType, Tokenizer};
71pub use traversal::{
72    contains_aggregate,
73    contains_subquery,
74    contains_window_function,
75    find_ancestor,
76    find_parent,
77    get_columns,
78    get_tables,
79    is_add,
80    is_aggregate,
81    is_alias,
82    is_alter_table,
83    is_and,
84    is_arithmetic,
85    is_avg,
86    is_between,
87    is_boolean,
88    is_case,
89    is_cast,
90    is_coalesce,
91    is_column,
92    is_comparison,
93    is_concat,
94    is_count,
95    is_create_index,
96    is_create_table,
97    is_create_view,
98    is_cte,
99    is_ddl,
100    is_delete,
101    is_div,
102    is_drop_index,
103    is_drop_table,
104    is_drop_view,
105    is_eq,
106    is_except,
107    is_exists,
108    is_from,
109    is_function,
110    is_group_by,
111    is_gt,
112    is_gte,
113    is_having,
114    is_identifier,
115    is_ilike,
116    is_in,
117    // Extended type predicates
118    is_insert,
119    is_intersect,
120    is_is_null,
121    is_join,
122    is_like,
123    is_limit,
124    is_literal,
125    is_logical,
126    is_lt,
127    is_lte,
128    is_max_func,
129    is_min_func,
130    is_mod,
131    is_mul,
132    is_neq,
133    is_not,
134    is_null_if,
135    is_null_literal,
136    is_offset,
137    is_or,
138    is_order_by,
139    is_ordered,
140    is_paren,
141    // Composite predicates
142    is_query,
143    is_safe_cast,
144    is_select,
145    is_set_operation,
146    is_star,
147    is_sub,
148    is_subquery,
149    is_sum,
150    is_table,
151    is_try_cast,
152    is_union,
153    is_update,
154    is_where,
155    is_window_function,
156    is_with,
157    transform,
158    transform_map,
159    BfsIter,
160    DfsIter,
161    ExpressionWalk,
162    ParentInfo,
163    TreeContext,
164};
165pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
166pub use validation::{
167    mapping_schema_from_validation_schema, validate_with_schema, SchemaColumn,
168    SchemaColumnReference, SchemaForeignKey, SchemaTable, SchemaTableReference,
169    SchemaValidationOptions, ValidationSchema,
170};
171
/// Default cap on the SQL input size accepted by the formatter: 16 MiB.
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 << 20;
/// Default cap on the number of tokens produced by tokenization.
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
/// Default cap on the number of AST nodes produced by parsing.
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
/// Default cap on set-operation operators within a single statement.
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;

/// Serde default for [`FormatGuardOptions::max_input_bytes`].
fn default_format_max_input_bytes() -> Option<usize> {
    Option::from(DEFAULT_FORMAT_MAX_INPUT_BYTES)
}

/// Serde default for [`FormatGuardOptions::max_tokens`].
fn default_format_max_tokens() -> Option<usize> {
    Option::from(DEFAULT_FORMAT_MAX_TOKENS)
}

/// Serde default for [`FormatGuardOptions::max_ast_nodes`].
fn default_format_max_ast_nodes() -> Option<usize> {
    Option::from(DEFAULT_FORMAT_MAX_AST_NODES)
}

/// Serde default for [`FormatGuardOptions::max_set_op_chain`].
fn default_format_max_set_op_chain() -> Option<usize> {
    Option::from(DEFAULT_FORMAT_MAX_SET_OP_CHAIN)
}
192
193/// Guard options for SQL pretty-formatting.
194///
195/// These limits protect against extremely large/complex queries that can cause
196/// high memory pressure in constrained runtimes (for example browser WASM).
197#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
198#[serde(rename_all = "camelCase")]
199pub struct FormatGuardOptions {
200    /// Maximum allowed SQL input size in bytes.
201    /// `None` disables this check.
202    #[serde(default = "default_format_max_input_bytes")]
203    pub max_input_bytes: Option<usize>,
204    /// Maximum allowed number of tokens after tokenization.
205    /// `None` disables this check.
206    #[serde(default = "default_format_max_tokens")]
207    pub max_tokens: Option<usize>,
208    /// Maximum allowed AST node count after parsing.
209    /// `None` disables this check.
210    #[serde(default = "default_format_max_ast_nodes")]
211    pub max_ast_nodes: Option<usize>,
212    /// Maximum allowed count of set-operation operators (`UNION`/`INTERSECT`/`EXCEPT`)
213    /// observed in a statement before parsing.
214    ///
215    /// `None` disables this check.
216    #[serde(default = "default_format_max_set_op_chain")]
217    pub max_set_op_chain: Option<usize>,
218}
219
220impl Default for FormatGuardOptions {
221    fn default() -> Self {
222        Self {
223            max_input_bytes: default_format_max_input_bytes(),
224            max_tokens: default_format_max_tokens(),
225            max_ast_nodes: default_format_max_ast_nodes(),
226            max_set_op_chain: default_format_max_set_op_chain(),
227        }
228    }
229}
230
231fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
232    Error::generate(format!(
233        "{code}: value {actual} exceeds configured limit {limit}"
234    ))
235}
236
237fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
238    if let Some(max) = options.max_input_bytes {
239        let input_bytes = sql.len();
240        if input_bytes > max {
241            return Err(format_guard_error(
242                "E_GUARD_INPUT_TOO_LARGE",
243                input_bytes,
244                max,
245            ));
246        }
247    }
248    Ok(())
249}
250
251fn parse_with_token_guard(
252    sql: &str,
253    dialect: &Dialect,
254    options: &FormatGuardOptions,
255) -> Result<Vec<Expression>> {
256    let tokens = dialect.tokenize(sql)?;
257    if let Some(max) = options.max_tokens {
258        let token_count = tokens.len();
259        if token_count > max {
260            return Err(format_guard_error(
261                "E_GUARD_TOKEN_BUDGET_EXCEEDED",
262                token_count,
263                max,
264            ));
265        }
266    }
267    enforce_set_op_chain_guard(&tokens, options)?;
268
269    let config = crate::parser::ParserConfig {
270        dialect: Some(dialect.dialect_type()),
271        ..Default::default()
272    };
273    let mut parser = Parser::with_source(tokens, config, sql.to_string());
274    parser.parse()
275}
276
277fn is_trivia_token(token_type: TokenType) -> bool {
278    matches!(
279        token_type,
280        TokenType::Space | TokenType::Break | TokenType::LineComment | TokenType::BlockComment
281    )
282}
283
284fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
285    tokens
286        .iter()
287        .skip(start)
288        .find(|token| !is_trivia_token(token.token_type))
289}
290
291fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
292    let token = &tokens[idx];
293    match token.token_type {
294        TokenType::Union | TokenType::Intersect => true,
295        TokenType::Except => {
296            // MINUS is aliased to EXCEPT in the tokenizer, but in ClickHouse minus(...)
297            // is a function call rather than a set operation.
298            if token.text.eq_ignore_ascii_case("minus")
299                && matches!(
300                    next_significant_token(tokens, idx + 1).map(|t| t.token_type),
301                    Some(TokenType::LParen)
302                )
303            {
304                return false;
305            }
306            true
307        }
308        _ => false,
309    }
310}
311
312fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
313    let Some(max) = options.max_set_op_chain else {
314        return Ok(());
315    };
316
317    let mut set_op_count = 0usize;
318    for (idx, token) in tokens.iter().enumerate() {
319        if token.token_type == TokenType::Semicolon {
320            set_op_count = 0;
321            continue;
322        }
323
324        if is_set_operation_token(tokens, idx) {
325            set_op_count += 1;
326            if set_op_count > max {
327                return Err(format_guard_error(
328                    "E_GUARD_SET_OP_CHAIN_EXCEEDED",
329                    set_op_count,
330                    max,
331                ));
332            }
333        }
334    }
335
336    Ok(())
337}
338
339fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
340    if let Some(max) = options.max_ast_nodes {
341        let ast_nodes: usize = expressions.iter().map(node_count).sum();
342        if ast_nodes > max {
343            return Err(format_guard_error(
344                "E_GUARD_AST_BUDGET_EXCEEDED",
345                ast_nodes,
346                max,
347            ));
348        }
349    }
350    Ok(())
351}
352
353fn format_with_dialect(
354    sql: &str,
355    dialect: &Dialect,
356    options: &FormatGuardOptions,
357) -> Result<Vec<String>> {
358    enforce_input_guard(sql, options)?;
359    let expressions = parse_with_token_guard(sql, dialect, options)?;
360    enforce_ast_guard(&expressions, options)?;
361
362    expressions
363        .iter()
364        .map(|expr| dialect.generate_pretty(expr))
365        .collect()
366}
367
368/// Transpile SQL from one dialect to another.
369///
370/// # Arguments
371/// * `sql` - The SQL string to transpile
372/// * `read` - The source dialect to parse with
373/// * `write` - The target dialect to generate
374///
375/// # Returns
376/// A vector of transpiled SQL statements
377///
378/// # Example
379/// ```
380/// use polyglot_sql::{transpile, DialectType};
381///
382/// let result = transpile(
383///     "SELECT EPOCH_MS(1618088028295)",
384///     DialectType::DuckDB,
385///     DialectType::Hive
386/// );
387/// ```
388pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
389    let read_dialect = Dialect::get(read);
390    let write_dialect = Dialect::get(write);
391    let generic_identity = read == DialectType::Generic && write == DialectType::Generic;
392
393    let expressions = read_dialect.parse(sql)?;
394
395    expressions
396        .into_iter()
397        .map(|expr| {
398            if generic_identity {
399                write_dialect.generate_with_source(&expr, read)
400            } else {
401                let transformed = write_dialect.transform(expr)?;
402                write_dialect.generate_with_source(&transformed, read)
403            }
404        })
405        .collect()
406}
407
408/// Parse SQL into an AST.
409///
410/// # Arguments
411/// * `sql` - The SQL string to parse
412/// * `dialect` - The dialect to use for parsing
413///
414/// # Returns
415/// A vector of parsed expressions
416pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
417    let d = Dialect::get(dialect);
418    d.parse(sql)
419}
420
421/// Parse a single SQL statement.
422///
423/// # Arguments
424/// * `sql` - The SQL string containing a single statement
425/// * `dialect` - The dialect to use for parsing
426///
427/// # Returns
428/// The parsed expression, or an error if multiple statements found
429pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
430    let mut expressions = parse(sql, dialect)?;
431
432    if expressions.len() != 1 {
433        return Err(Error::parse(
434            format!("Expected 1 statement, found {}", expressions.len()),
435            0,
436            0,
437            0,
438            0,
439        ));
440    }
441
442    Ok(expressions.remove(0))
443}
444
445/// Generate SQL from an AST.
446///
447/// # Arguments
448/// * `expression` - The expression to generate SQL from
449/// * `dialect` - The target dialect
450///
451/// # Returns
452/// The generated SQL string
453pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
454    let d = Dialect::get(dialect);
455    d.generate(expression)
456}
457
458/// Format/pretty-print SQL statements.
459///
460/// Uses [`FormatGuardOptions::default`] guards.
461pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
462    format_with_options(sql, dialect, &FormatGuardOptions::default())
463}
464
465/// Format/pretty-print SQL statements with configurable guard limits.
466pub fn format_with_options(
467    sql: &str,
468    dialect: DialectType,
469    options: &FormatGuardOptions,
470) -> Result<Vec<String>> {
471    let d = Dialect::get(dialect);
472    format_with_dialect(sql, &d, options)
473}
474
475/// Validate SQL syntax.
476///
477/// # Arguments
478/// * `sql` - The SQL string to validate
479/// * `dialect` - The dialect to use for validation
480///
481/// # Returns
482/// A validation result with any errors found
483pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
484    validate_with_options(sql, dialect, &ValidationOptions::default())
485}
486
487/// Options for syntax validation behavior.
488#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
489#[serde(rename_all = "camelCase")]
490pub struct ValidationOptions {
491    /// When enabled, validation rejects non-canonical trailing commas that the parser
492    /// would otherwise accept for compatibility (e.g. `SELECT a, FROM t`).
493    #[serde(default)]
494    pub strict_syntax: bool,
495}
496
497/// Validate SQL syntax with additional validation options.
498pub fn validate_with_options(
499    sql: &str,
500    dialect: DialectType,
501    options: &ValidationOptions,
502) -> ValidationResult {
503    let d = Dialect::get(dialect);
504    match d.parse(sql) {
505        Ok(expressions) => {
506            // Reject bare expressions that aren't valid SQL statements.
507            // The parser accepts any expression at the top level, but bare identifiers,
508            // literals, function calls, etc. are not valid statements.
509            for expr in &expressions {
510                if !expr.is_statement() {
511                    let msg = format!("Invalid expression / Unexpected token");
512                    return ValidationResult::with_errors(vec![ValidationError::error(
513                        msg, "E004",
514                    )]);
515                }
516            }
517            if options.strict_syntax {
518                if let Some(error) = strict_syntax_error(sql, &d) {
519                    return ValidationResult::with_errors(vec![error]);
520                }
521            }
522            ValidationResult::success()
523        }
524        Err(e) => {
525            let error = match &e {
526                Error::Syntax {
527                    message,
528                    line,
529                    column,
530                    start,
531                    end,
532                } => ValidationError::error(message.clone(), "E001")
533                    .with_location(*line, *column)
534                    .with_span(Some(*start), Some(*end)),
535                Error::Tokenize {
536                    message,
537                    line,
538                    column,
539                    start,
540                    end,
541                } => ValidationError::error(message.clone(), "E002")
542                    .with_location(*line, *column)
543                    .with_span(Some(*start), Some(*end)),
544                Error::Parse {
545                    message,
546                    line,
547                    column,
548                    start,
549                    end,
550                } => ValidationError::error(message.clone(), "E003")
551                    .with_location(*line, *column)
552                    .with_span(Some(*start), Some(*end)),
553                _ => ValidationError::error(e.to_string(), "E000"),
554            };
555            ValidationResult::with_errors(vec![error])
556        }
557    }
558}
559
560fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
561    let tokens = dialect.tokenize(sql).ok()?;
562
563    for (idx, token) in tokens.iter().enumerate() {
564        if token.token_type != TokenType::Comma {
565            continue;
566        }
567
568        let next = tokens.get(idx + 1);
569        let (is_boundary, boundary_name) = match next.map(|t| t.token_type) {
570            Some(TokenType::From) => (true, "FROM"),
571            Some(TokenType::Where) => (true, "WHERE"),
572            Some(TokenType::GroupBy) => (true, "GROUP BY"),
573            Some(TokenType::Having) => (true, "HAVING"),
574            Some(TokenType::Order) | Some(TokenType::OrderBy) => (true, "ORDER BY"),
575            Some(TokenType::Limit) => (true, "LIMIT"),
576            Some(TokenType::Offset) => (true, "OFFSET"),
577            Some(TokenType::Union) => (true, "UNION"),
578            Some(TokenType::Intersect) => (true, "INTERSECT"),
579            Some(TokenType::Except) => (true, "EXCEPT"),
580            Some(TokenType::Qualify) => (true, "QUALIFY"),
581            Some(TokenType::Window) => (true, "WINDOW"),
582            Some(TokenType::Semicolon) | None => (true, "end of statement"),
583            _ => (false, ""),
584        };
585
586        if is_boundary {
587            let message = format!(
588                "Trailing comma before {} is not allowed in strict syntax mode",
589                boundary_name
590            );
591            return Some(
592                ValidationError::error(message, "E005")
593                    .with_location(token.span.line, token.span.column),
594            );
595        }
596    }
597
598    None
599}
600
601/// Transpile SQL from one dialect to another, using string dialect names.
602///
603/// This supports both built-in dialect names (e.g., "postgresql", "mysql") and
604/// custom dialects registered via [`CustomDialectBuilder`].
605///
606/// # Arguments
607/// * `sql` - The SQL string to transpile
608/// * `read` - The source dialect name
609/// * `write` - The target dialect name
610///
611/// # Returns
612/// A vector of transpiled SQL statements, or an error if a dialect name is unknown.
613pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
614    let read_dialect = Dialect::get_by_name(read)
615        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0, 0, 0))?;
616    let write_dialect = Dialect::get_by_name(write)
617        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0, 0, 0))?;
618    let generic_identity = read_dialect.dialect_type() == DialectType::Generic
619        && write_dialect.dialect_type() == DialectType::Generic;
620
621    let expressions = read_dialect.parse(sql)?;
622
623    expressions
624        .into_iter()
625        .map(|expr| {
626            if generic_identity {
627                write_dialect.generate_with_source(&expr, read_dialect.dialect_type())
628            } else {
629                let transformed = write_dialect.transform(expr)?;
630                write_dialect.generate_with_source(&transformed, read_dialect.dialect_type())
631            }
632        })
633        .collect()
634}
635
636/// Parse SQL into an AST using a string dialect name.
637///
638/// Supports both built-in and custom dialect names.
639pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
640    let d = Dialect::get_by_name(dialect)
641        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
642    d.parse(sql)
643}
644
645/// Generate SQL from an AST using a string dialect name.
646///
647/// Supports both built-in and custom dialect names.
648pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
649    let d = Dialect::get_by_name(dialect)
650        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
651    d.generate(expression)
652}
653
654/// Format SQL using a string dialect name.
655///
656/// Uses [`FormatGuardOptions::default`] guards.
657pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
658    format_with_options_by_name(sql, dialect, &FormatGuardOptions::default())
659}
660
661/// Format SQL using a string dialect name with configurable guard limits.
662pub fn format_with_options_by_name(
663    sql: &str,
664    dialect: &str,
665    options: &FormatGuardOptions,
666) -> Result<Vec<String>> {
667    let d = Dialect::get_by_name(dialect)
668        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
669    format_with_dialect(sql, &d, options)
670}
671
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// Validates `sql` in strict syntax mode with the generic dialect.
    fn validate_strict(sql: &str) -> ValidationResult {
        let strict = ValidationOptions {
            strict_syntax: true,
        };
        validate_with_options(sql, DialectType::Generic, &strict)
    }

    /// Asserts that `outcome` failed and contains an `E005` error.
    fn assert_trailing_comma_rejected(outcome: &ValidationResult) {
        assert!(!outcome.valid, "Result should be invalid");
        assert!(
            outcome.errors.iter().any(|e| e.code == "E005"),
            "Expected E005, got: {:?}",
            outcome.errors
        );
    }

    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let outcome = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(outcome.valid, "Result: {:?}", outcome.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        let outcome = validate_strict("SELECT name, FROM employees");
        assert_trailing_comma_rejected(&outcome);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        let outcome = validate_strict("SELECT name FROM employees, WHERE salary > 10");
        assert_trailing_comma_rejected(&outcome);
    }
}
718
#[cfg(test)]
mod format_tests {
    use super::*;

    /// Guard options with every limit disabled; each test enables one limit.
    fn no_limits() -> FormatGuardOptions {
        FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: None,
        }
    }

    /// Formats `sql` expecting a failure and returns the rendered error text.
    fn guard_failure(sql: &str, dialect: DialectType, options: &FormatGuardOptions) -> String {
        format_with_options(sql, dialect, options)
            .expect_err("expected guard error")
            .to_string()
    }

    #[test]
    fn format_basic_query() {
        let formatted = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
        assert_eq!(formatted.len(), 1);
        assert!(formatted[0].contains('\n'));
    }

    #[test]
    fn format_guard_rejects_large_input() {
        let options = FormatGuardOptions {
            max_input_bytes: Some(7),
            ..no_limits()
        };
        let message = guard_failure("SELECT 1", DialectType::Generic, &options);
        assert!(message.contains("E_GUARD_INPUT_TOO_LARGE"));
    }

    #[test]
    fn format_guard_rejects_token_budget() {
        let options = FormatGuardOptions {
            max_tokens: Some(1),
            ..no_limits()
        };
        let message = guard_failure("SELECT 1", DialectType::Generic, &options);
        assert!(message.contains("E_GUARD_TOKEN_BUDGET_EXCEEDED"));
    }

    #[test]
    fn format_guard_rejects_ast_budget() {
        let options = FormatGuardOptions {
            max_ast_nodes: Some(1),
            ..no_limits()
        };
        let message = guard_failure("SELECT 1", DialectType::Generic, &options);
        assert!(message.contains("E_GUARD_AST_BUDGET_EXCEEDED"));
    }

    #[test]
    fn format_guard_rejects_set_op_chain_budget() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(1),
            ..no_limits()
        };
        let message = guard_failure(
            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
            DialectType::Generic,
            &options,
        );
        assert!(message.contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }

    #[test]
    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
        // A zero budget means any counted set operation would trip the guard.
        let options = FormatGuardOptions {
            max_set_op_chain: Some(0),
            ..no_limits()
        };
        let outcome = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
        assert!(outcome.is_ok(), "Result: {:?}", outcome);
    }

    #[test]
    fn issue57_invalid_ternary_returns_error() {
        // https://github.com/tobilg/polyglot/issues/57
        // Invalid SQL with ternary operator should return an error, not garbled output.
        let sql = "SELECT x > 0 ? 1 : 0 FROM t";
        let postgres = DialectType::PostgreSQL;

        let parsed = parse(sql, postgres);
        assert!(
            parsed.is_err(),
            "Expected parse error for invalid ternary SQL, got: {:?}",
            parsed
        );

        let formatted = format(sql, postgres);
        assert!(
            formatted.is_err(),
            "Expected format error for invalid ternary SQL, got: {:?}",
            formatted
        );

        let transpiled = transpile(sql, postgres, postgres);
        assert!(
            transpiled.is_err(),
            "Expected transpile error for invalid ternary SQL, got: {:?}",
            transpiled
        );
    }

    #[test]
    fn format_default_guard_rejects_deep_union_chain_before_parse() {
        // 1100 UNION ALL operators: well above the default chain limit of 256.
        let base = "SELECT col0, col1 FROM t";
        let chained = format!(" UNION ALL {base}").repeat(1100);
        let sql = format!("{base}{chained}");

        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }
}
838}