Skip to main content

polyglot_sql/
lib.rs

1//! Polyglot Core - SQL parsing and dialect translation library
2//!
3//! This library provides the core functionality for parsing SQL statements,
4//! building an abstract syntax tree (AST), and generating SQL in different dialects.
5//!
6//! # Architecture
7//!
8//! The library follows a pipeline architecture:
9//! 1. **Tokenizer** - Converts SQL string to token stream
10//! 2. **Parser** - Builds AST from tokens
11//! 3. **Generator** - Converts AST back to SQL string
12//!
13//! Each stage can be customized per dialect.
14
15pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21pub mod function_catalog;
22mod function_registry;
23pub mod generator;
24pub mod helper;
25pub mod lineage;
26pub mod optimizer;
27pub mod parser;
28pub mod planner;
29pub mod resolver;
30pub mod schema;
31pub mod scope;
32pub mod time;
33pub mod tokens;
34pub mod transforms;
35pub mod traversal;
36pub mod trie;
37pub mod validation;
38
39use serde::{Deserialize, Serialize};
40
41pub use ast_transforms::{
42    add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
43    get_identifiers, get_literals, get_subqueries, get_table_names, get_window_functions,
44    node_count, qualify_columns, remove_limit_offset, remove_nodes, remove_select_columns,
45    remove_where, rename_columns, rename_tables, replace_by_type, replace_nodes, set_distinct,
46    set_limit, set_offset,
47};
48pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
49pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
50pub use expressions::Expression;
51pub use function_catalog::{
52    FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
53};
54pub use generator::Generator;
55pub use helper::{
56    csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
57    name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
58};
59pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
60pub use parser::Parser;
61pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
62pub use schema::{
63    ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
64};
65pub use scope::{
66    build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
67    ScopeType, SourceInfo,
68};
69pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
70pub use tokens::{Token, TokenType, Tokenizer};
71pub use traversal::{
72    contains_aggregate,
73    contains_subquery,
74    contains_window_function,
75    find_ancestor,
76    find_parent,
77    get_columns,
78    get_tables,
79    is_add,
80    is_aggregate,
81    is_alias,
82    is_alter_table,
83    is_and,
84    is_arithmetic,
85    is_avg,
86    is_between,
87    is_boolean,
88    is_case,
89    is_cast,
90    is_coalesce,
91    is_column,
92    is_comparison,
93    is_concat,
94    is_count,
95    is_create_index,
96    is_create_table,
97    is_create_view,
98    is_cte,
99    is_ddl,
100    is_delete,
101    is_div,
102    is_drop_index,
103    is_drop_table,
104    is_drop_view,
105    is_eq,
106    is_except,
107    is_exists,
108    is_from,
109    is_function,
110    is_group_by,
111    is_gt,
112    is_gte,
113    is_having,
114    is_identifier,
115    is_ilike,
116    is_in,
117    // Extended type predicates
118    is_insert,
119    is_intersect,
120    is_is_null,
121    is_join,
122    is_like,
123    is_limit,
124    is_literal,
125    is_logical,
126    is_lt,
127    is_lte,
128    is_max_func,
129    is_min_func,
130    is_mod,
131    is_mul,
132    is_neq,
133    is_not,
134    is_null_if,
135    is_null_literal,
136    is_offset,
137    is_or,
138    is_order_by,
139    is_ordered,
140    is_paren,
141    // Composite predicates
142    is_query,
143    is_safe_cast,
144    is_select,
145    is_set_operation,
146    is_star,
147    is_sub,
148    is_subquery,
149    is_sum,
150    is_table,
151    is_try_cast,
152    is_union,
153    is_update,
154    is_where,
155    is_window_function,
156    is_with,
157    transform,
158    transform_map,
159    BfsIter,
160    DfsIter,
161    ExpressionWalk,
162    ParentInfo,
163    TreeContext,
164};
165pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
166pub use validation::{
167    validate_with_schema, SchemaColumn, SchemaColumnReference, SchemaForeignKey, SchemaTable,
168    SchemaTableReference, SchemaValidationOptions, ValidationSchema,
169};
170
/// Built-in cap on the raw SQL input size accepted by the formatter: 16 MiB.
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 * 1024 * 1024;
/// Built-in cap on the number of tokens produced by tokenization.
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
/// Built-in cap on the total number of AST nodes after parsing.
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
/// Built-in cap on the number of set-operation operators in a single statement.
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;
175
176fn default_format_max_input_bytes() -> Option<usize> {
177    Some(DEFAULT_FORMAT_MAX_INPUT_BYTES)
178}
179
180fn default_format_max_tokens() -> Option<usize> {
181    Some(DEFAULT_FORMAT_MAX_TOKENS)
182}
183
184fn default_format_max_ast_nodes() -> Option<usize> {
185    Some(DEFAULT_FORMAT_MAX_AST_NODES)
186}
187
188fn default_format_max_set_op_chain() -> Option<usize> {
189    Some(DEFAULT_FORMAT_MAX_SET_OP_CHAIN)
190}
191
192/// Guard options for SQL pretty-formatting.
193///
194/// These limits protect against extremely large/complex queries that can cause
195/// high memory pressure in constrained runtimes (for example browser WASM).
196#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
197#[serde(rename_all = "camelCase")]
198pub struct FormatGuardOptions {
199    /// Maximum allowed SQL input size in bytes.
200    /// `None` disables this check.
201    #[serde(default = "default_format_max_input_bytes")]
202    pub max_input_bytes: Option<usize>,
203    /// Maximum allowed number of tokens after tokenization.
204    /// `None` disables this check.
205    #[serde(default = "default_format_max_tokens")]
206    pub max_tokens: Option<usize>,
207    /// Maximum allowed AST node count after parsing.
208    /// `None` disables this check.
209    #[serde(default = "default_format_max_ast_nodes")]
210    pub max_ast_nodes: Option<usize>,
211    /// Maximum allowed count of set-operation operators (`UNION`/`INTERSECT`/`EXCEPT`)
212    /// observed in a statement before parsing.
213    ///
214    /// `None` disables this check.
215    #[serde(default = "default_format_max_set_op_chain")]
216    pub max_set_op_chain: Option<usize>,
217}
218
219impl Default for FormatGuardOptions {
220    fn default() -> Self {
221        Self {
222            max_input_bytes: default_format_max_input_bytes(),
223            max_tokens: default_format_max_tokens(),
224            max_ast_nodes: default_format_max_ast_nodes(),
225            max_set_op_chain: default_format_max_set_op_chain(),
226        }
227    }
228}
229
230fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
231    Error::generate(format!(
232        "{code}: value {actual} exceeds configured limit {limit}"
233    ))
234}
235
236fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
237    if let Some(max) = options.max_input_bytes {
238        let input_bytes = sql.len();
239        if input_bytes > max {
240            return Err(format_guard_error(
241                "E_GUARD_INPUT_TOO_LARGE",
242                input_bytes,
243                max,
244            ));
245        }
246    }
247    Ok(())
248}
249
250fn parse_with_token_guard(
251    sql: &str,
252    dialect: &Dialect,
253    options: &FormatGuardOptions,
254) -> Result<Vec<Expression>> {
255    let tokens = dialect.tokenize(sql)?;
256    if let Some(max) = options.max_tokens {
257        let token_count = tokens.len();
258        if token_count > max {
259            return Err(format_guard_error(
260                "E_GUARD_TOKEN_BUDGET_EXCEEDED",
261                token_count,
262                max,
263            ));
264        }
265    }
266    enforce_set_op_chain_guard(&tokens, options)?;
267
268    let config = crate::parser::ParserConfig {
269        dialect: Some(dialect.dialect_type()),
270        ..Default::default()
271    };
272    let mut parser = Parser::with_source(tokens, config, sql.to_string());
273    parser.parse()
274}
275
276fn is_trivia_token(token_type: TokenType) -> bool {
277    matches!(
278        token_type,
279        TokenType::Space | TokenType::Break | TokenType::LineComment | TokenType::BlockComment
280    )
281}
282
283fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
284    tokens
285        .iter()
286        .skip(start)
287        .find(|token| !is_trivia_token(token.token_type))
288}
289
290fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
291    let token = &tokens[idx];
292    match token.token_type {
293        TokenType::Union | TokenType::Intersect => true,
294        TokenType::Except => {
295            // MINUS is aliased to EXCEPT in the tokenizer, but in ClickHouse minus(...)
296            // is a function call rather than a set operation.
297            if token.text.eq_ignore_ascii_case("minus")
298                && matches!(
299                    next_significant_token(tokens, idx + 1).map(|t| t.token_type),
300                    Some(TokenType::LParen)
301                )
302            {
303                return false;
304            }
305            true
306        }
307        _ => false,
308    }
309}
310
311fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
312    let Some(max) = options.max_set_op_chain else {
313        return Ok(());
314    };
315
316    let mut set_op_count = 0usize;
317    for (idx, token) in tokens.iter().enumerate() {
318        if token.token_type == TokenType::Semicolon {
319            set_op_count = 0;
320            continue;
321        }
322
323        if is_set_operation_token(tokens, idx) {
324            set_op_count += 1;
325            if set_op_count > max {
326                return Err(format_guard_error(
327                    "E_GUARD_SET_OP_CHAIN_EXCEEDED",
328                    set_op_count,
329                    max,
330                ));
331            }
332        }
333    }
334
335    Ok(())
336}
337
338fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
339    if let Some(max) = options.max_ast_nodes {
340        let ast_nodes: usize = expressions.iter().map(node_count).sum();
341        if ast_nodes > max {
342            return Err(format_guard_error(
343                "E_GUARD_AST_BUDGET_EXCEEDED",
344                ast_nodes,
345                max,
346            ));
347        }
348    }
349    Ok(())
350}
351
352fn format_with_dialect(
353    sql: &str,
354    dialect: &Dialect,
355    options: &FormatGuardOptions,
356) -> Result<Vec<String>> {
357    enforce_input_guard(sql, options)?;
358    let expressions = parse_with_token_guard(sql, dialect, options)?;
359    enforce_ast_guard(&expressions, options)?;
360
361    expressions
362        .iter()
363        .map(|expr| dialect.generate_pretty(expr))
364        .collect()
365}
366
367/// Transpile SQL from one dialect to another.
368///
369/// # Arguments
370/// * `sql` - The SQL string to transpile
371/// * `read` - The source dialect to parse with
372/// * `write` - The target dialect to generate
373///
374/// # Returns
375/// A vector of transpiled SQL statements
376///
377/// # Example
378/// ```
379/// use polyglot_sql::{transpile, DialectType};
380///
381/// let result = transpile(
382///     "SELECT EPOCH_MS(1618088028295)",
383///     DialectType::DuckDB,
384///     DialectType::Hive
385/// );
386/// ```
387pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
388    let read_dialect = Dialect::get(read);
389    let write_dialect = Dialect::get(write);
390    let generic_identity = read == DialectType::Generic && write == DialectType::Generic;
391
392    let expressions = read_dialect.parse(sql)?;
393
394    expressions
395        .into_iter()
396        .map(|expr| {
397            if generic_identity {
398                write_dialect.generate_with_source(&expr, read)
399            } else {
400                let transformed = write_dialect.transform(expr)?;
401                write_dialect.generate_with_source(&transformed, read)
402            }
403        })
404        .collect()
405}
406
407/// Parse SQL into an AST.
408///
409/// # Arguments
410/// * `sql` - The SQL string to parse
411/// * `dialect` - The dialect to use for parsing
412///
413/// # Returns
414/// A vector of parsed expressions
415pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
416    let d = Dialect::get(dialect);
417    d.parse(sql)
418}
419
420/// Parse a single SQL statement.
421///
422/// # Arguments
423/// * `sql` - The SQL string containing a single statement
424/// * `dialect` - The dialect to use for parsing
425///
426/// # Returns
427/// The parsed expression, or an error if multiple statements found
428pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
429    let mut expressions = parse(sql, dialect)?;
430
431    if expressions.len() != 1 {
432        return Err(Error::parse(
433            format!("Expected 1 statement, found {}", expressions.len()),
434            0,
435            0,
436            0,
437            0,
438        ));
439    }
440
441    Ok(expressions.remove(0))
442}
443
444/// Generate SQL from an AST.
445///
446/// # Arguments
447/// * `expression` - The expression to generate SQL from
448/// * `dialect` - The target dialect
449///
450/// # Returns
451/// The generated SQL string
452pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
453    let d = Dialect::get(dialect);
454    d.generate(expression)
455}
456
457/// Format/pretty-print SQL statements.
458///
459/// Uses [`FormatGuardOptions::default`] guards.
460pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
461    format_with_options(sql, dialect, &FormatGuardOptions::default())
462}
463
464/// Format/pretty-print SQL statements with configurable guard limits.
465pub fn format_with_options(
466    sql: &str,
467    dialect: DialectType,
468    options: &FormatGuardOptions,
469) -> Result<Vec<String>> {
470    let d = Dialect::get(dialect);
471    format_with_dialect(sql, &d, options)
472}
473
474/// Validate SQL syntax.
475///
476/// # Arguments
477/// * `sql` - The SQL string to validate
478/// * `dialect` - The dialect to use for validation
479///
480/// # Returns
481/// A validation result with any errors found
482pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
483    validate_with_options(sql, dialect, &ValidationOptions::default())
484}
485
486/// Options for syntax validation behavior.
487#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
488#[serde(rename_all = "camelCase")]
489pub struct ValidationOptions {
490    /// When enabled, validation rejects non-canonical trailing commas that the parser
491    /// would otherwise accept for compatibility (e.g. `SELECT a, FROM t`).
492    #[serde(default)]
493    pub strict_syntax: bool,
494}
495
496/// Validate SQL syntax with additional validation options.
497pub fn validate_with_options(
498    sql: &str,
499    dialect: DialectType,
500    options: &ValidationOptions,
501) -> ValidationResult {
502    let d = Dialect::get(dialect);
503    match d.parse(sql) {
504        Ok(expressions) => {
505            // Reject bare expressions that aren't valid SQL statements.
506            // The parser accepts any expression at the top level, but bare identifiers,
507            // literals, function calls, etc. are not valid statements.
508            for expr in &expressions {
509                if !expr.is_statement() {
510                    let msg = format!("Invalid expression / Unexpected token");
511                    return ValidationResult::with_errors(vec![ValidationError::error(
512                        msg, "E004",
513                    )]);
514                }
515            }
516            if options.strict_syntax {
517                if let Some(error) = strict_syntax_error(sql, &d) {
518                    return ValidationResult::with_errors(vec![error]);
519                }
520            }
521            ValidationResult::success()
522        }
523        Err(e) => {
524            let error = match &e {
525                Error::Syntax {
526                    message,
527                    line,
528                    column,
529                    start,
530                    end,
531                } => ValidationError::error(message.clone(), "E001")
532                    .with_location(*line, *column)
533                    .with_span(Some(*start), Some(*end)),
534                Error::Tokenize {
535                    message,
536                    line,
537                    column,
538                    start,
539                    end,
540                } => ValidationError::error(message.clone(), "E002")
541                    .with_location(*line, *column)
542                    .with_span(Some(*start), Some(*end)),
543                Error::Parse {
544                    message,
545                    line,
546                    column,
547                    start,
548                    end,
549                } => ValidationError::error(message.clone(), "E003")
550                    .with_location(*line, *column)
551                    .with_span(Some(*start), Some(*end)),
552                _ => ValidationError::error(e.to_string(), "E000"),
553            };
554            ValidationResult::with_errors(vec![error])
555        }
556    }
557}
558
559fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
560    let tokens = dialect.tokenize(sql).ok()?;
561
562    for (idx, token) in tokens.iter().enumerate() {
563        if token.token_type != TokenType::Comma {
564            continue;
565        }
566
567        let next = tokens.get(idx + 1);
568        let (is_boundary, boundary_name) = match next.map(|t| t.token_type) {
569            Some(TokenType::From) => (true, "FROM"),
570            Some(TokenType::Where) => (true, "WHERE"),
571            Some(TokenType::GroupBy) => (true, "GROUP BY"),
572            Some(TokenType::Having) => (true, "HAVING"),
573            Some(TokenType::Order) | Some(TokenType::OrderBy) => (true, "ORDER BY"),
574            Some(TokenType::Limit) => (true, "LIMIT"),
575            Some(TokenType::Offset) => (true, "OFFSET"),
576            Some(TokenType::Union) => (true, "UNION"),
577            Some(TokenType::Intersect) => (true, "INTERSECT"),
578            Some(TokenType::Except) => (true, "EXCEPT"),
579            Some(TokenType::Qualify) => (true, "QUALIFY"),
580            Some(TokenType::Window) => (true, "WINDOW"),
581            Some(TokenType::Semicolon) | None => (true, "end of statement"),
582            _ => (false, ""),
583        };
584
585        if is_boundary {
586            let message = format!(
587                "Trailing comma before {} is not allowed in strict syntax mode",
588                boundary_name
589            );
590            return Some(
591                ValidationError::error(message, "E005")
592                    .with_location(token.span.line, token.span.column),
593            );
594        }
595    }
596
597    None
598}
599
600/// Transpile SQL from one dialect to another, using string dialect names.
601///
602/// This supports both built-in dialect names (e.g., "postgresql", "mysql") and
603/// custom dialects registered via [`CustomDialectBuilder`].
604///
605/// # Arguments
606/// * `sql` - The SQL string to transpile
607/// * `read` - The source dialect name
608/// * `write` - The target dialect name
609///
610/// # Returns
611/// A vector of transpiled SQL statements, or an error if a dialect name is unknown.
612pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
613    let read_dialect = Dialect::get_by_name(read)
614        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0, 0, 0))?;
615    let write_dialect = Dialect::get_by_name(write)
616        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0, 0, 0))?;
617    let generic_identity = read_dialect.dialect_type() == DialectType::Generic
618        && write_dialect.dialect_type() == DialectType::Generic;
619
620    let expressions = read_dialect.parse(sql)?;
621
622    expressions
623        .into_iter()
624        .map(|expr| {
625            if generic_identity {
626                write_dialect.generate_with_source(&expr, read_dialect.dialect_type())
627            } else {
628                let transformed = write_dialect.transform(expr)?;
629                write_dialect.generate_with_source(&transformed, read_dialect.dialect_type())
630            }
631        })
632        .collect()
633}
634
635/// Parse SQL into an AST using a string dialect name.
636///
637/// Supports both built-in and custom dialect names.
638pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
639    let d = Dialect::get_by_name(dialect)
640        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
641    d.parse(sql)
642}
643
644/// Generate SQL from an AST using a string dialect name.
645///
646/// Supports both built-in and custom dialect names.
647pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
648    let d = Dialect::get_by_name(dialect)
649        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
650    d.generate(expression)
651}
652
653/// Format SQL using a string dialect name.
654///
655/// Uses [`FormatGuardOptions::default`] guards.
656pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
657    format_with_options_by_name(sql, dialect, &FormatGuardOptions::default())
658}
659
660/// Format SQL using a string dialect name with configurable guard limits.
661pub fn format_with_options_by_name(
662    sql: &str,
663    dialect: &str,
664    options: &FormatGuardOptions,
665) -> Result<Vec<String>> {
666    let d = Dialect::get_by_name(dialect)
667        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
668    format_with_dialect(sql, &d, options)
669}
670
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// Validates `sql` in strict syntax mode and asserts that it is rejected
    /// with an `E005` trailing-comma error.
    fn assert_strict_mode_reports_e005(sql: &str) {
        let strict = ValidationOptions {
            strict_syntax: true,
        };
        let outcome = validate_with_options(sql, DialectType::Generic, &strict);
        assert!(!outcome.valid, "Result should be invalid");

        let has_e005 = outcome.errors.iter().any(|e| e.code == "E005");
        assert!(has_e005, "Expected E005, got: {:?}", outcome.errors);
    }

    // Without strict mode, the parser's tolerance for trailing commas is kept.
    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let outcome = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(outcome.valid, "Result: {:?}", outcome.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        assert_strict_mode_reports_e005("SELECT name, FROM employees");
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        assert_strict_mode_reports_e005("SELECT name FROM employees, WHERE salary > 10");
    }
}
717
#[cfg(test)]
mod format_tests {
    use super::*;

    /// Guard set with every check disabled; tests enable one limit at a time.
    const NO_LIMITS: FormatGuardOptions = FormatGuardOptions {
        max_input_bytes: None,
        max_tokens: None,
        max_ast_nodes: None,
        max_set_op_chain: None,
    };

    /// Formats `sql` with the generic dialect and asserts that it fails with
    /// an error whose message contains `code`.
    fn assert_guard_trips(sql: &str, options: &FormatGuardOptions, code: &str) {
        let err = format_with_options(sql, DialectType::Generic, options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains(code));
    }

    #[test]
    fn format_basic_query() {
        let formatted = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
        assert_eq!(formatted.len(), 1);
        assert!(formatted[0].contains('\n'));
    }

    #[test]
    fn format_guard_rejects_large_input() {
        let options = FormatGuardOptions {
            max_input_bytes: Some(7),
            ..NO_LIMITS
        };
        assert_guard_trips("SELECT 1", &options, "E_GUARD_INPUT_TOO_LARGE");
    }

    #[test]
    fn format_guard_rejects_token_budget() {
        let options = FormatGuardOptions {
            max_tokens: Some(1),
            ..NO_LIMITS
        };
        assert_guard_trips("SELECT 1", &options, "E_GUARD_TOKEN_BUDGET_EXCEEDED");
    }

    #[test]
    fn format_guard_rejects_ast_budget() {
        let options = FormatGuardOptions {
            max_ast_nodes: Some(1),
            ..NO_LIMITS
        };
        assert_guard_trips("SELECT 1", &options, "E_GUARD_AST_BUDGET_EXCEEDED");
    }

    #[test]
    fn format_guard_rejects_set_op_chain_budget() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(1),
            ..NO_LIMITS
        };
        assert_guard_trips(
            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
            &options,
            "E_GUARD_SET_OP_CHAIN_EXCEEDED",
        );
    }

    #[test]
    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
        // A limit of zero would reject the query if minus(...) counted as a set operation.
        let options = FormatGuardOptions {
            max_set_op_chain: Some(0),
            ..NO_LIMITS
        };
        let result = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
        assert!(result.is_ok(), "Result: {:?}", result);
    }

    #[test]
    fn format_default_guard_rejects_deep_union_chain_before_parse() {
        // 1101 copies of the base query joined by 1100 UNION ALL operators.
        let base = "SELECT col0, col1 FROM t";
        let sql = vec![base; 1101].join(" UNION ALL ");

        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }
}
809}