Skip to main content

polyglot_sql/
lib.rs

1//! Polyglot Core - SQL parsing and dialect translation library
2//!
3//! This library provides the core functionality for parsing SQL statements,
4//! building an abstract syntax tree (AST), and generating SQL in different dialects.
5//!
6//! # Architecture
7//!
8//! The library follows a pipeline architecture:
9//! 1. **Tokenizer** - Converts SQL string to token stream
10//! 2. **Parser** - Builds AST from tokens
11//! 3. **Generator** - Converts AST back to SQL string
12//!
13//! Each stage can be customized per dialect.
14
15#[cfg(any(feature = "ast-tools", feature = "generate", feature = "semantic"))]
16pub mod ast_transforms;
17#[cfg(feature = "builder")]
18pub mod builder;
19pub mod dialects;
20#[cfg(feature = "diff")]
21pub mod diff;
22pub mod error;
23pub mod expressions;
24#[cfg(feature = "semantic")]
25pub mod function_catalog;
26mod function_registry;
27#[cfg(feature = "generate")]
28pub mod generator;
29#[cfg(feature = "semantic")]
30pub mod helper;
31#[cfg(feature = "semantic")]
32pub mod lineage;
33#[cfg(feature = "openlineage")]
34pub mod openlineage;
35#[cfg(feature = "semantic")]
36pub mod optimizer;
37pub mod parser;
38#[cfg(feature = "planner")]
39pub mod planner;
40#[cfg(feature = "semantic")]
41pub mod resolver;
42#[cfg(feature = "semantic")]
43pub mod schema;
44#[cfg(feature = "semantic")]
45pub mod scope;
46#[cfg(feature = "time")]
47pub mod time;
48pub mod tokens;
49#[cfg(feature = "transpile")]
50pub mod transforms;
51#[cfg(any(feature = "ast-tools", feature = "generate", feature = "semantic"))]
52pub mod traversal;
53#[cfg(any(feature = "semantic", feature = "time"))]
54pub mod trie;
55#[cfg(feature = "semantic")]
56pub mod validation;
57
58#[cfg(any(feature = "generate", feature = "semantic"))]
59use serde::{Deserialize, Serialize};
60
61#[cfg(feature = "ast-tools")]
62pub use ast_transforms::{
63    add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
64    get_identifiers, get_literals, get_output_column_names, get_subqueries, get_table_names,
65    get_window_functions, node_count, qualify_columns, remove_limit_offset, remove_nodes,
66    remove_select_columns, remove_where, rename_columns, rename_tables, rename_tables_with_options,
67    replace_by_type, replace_nodes, set_distinct, set_limit, set_offset, RenameTablesOptions,
68};
69pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
70#[cfg(feature = "transpile")]
71pub use dialects::{TranspileOptions, TranspileTarget};
72pub use error::{Error, Result};
73#[cfg(feature = "semantic")]
74pub use error::{ValidationError, ValidationResult, ValidationSeverity};
75pub use expressions::Expression;
76#[cfg(feature = "semantic")]
77pub use function_catalog::{
78    FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
79};
80#[cfg(feature = "generate")]
81pub use generator::Generator;
82#[cfg(feature = "semantic")]
83pub use helper::{
84    csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
85    name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
86};
87#[cfg(feature = "semantic")]
88pub use optimizer::{
89    annotate_types, qualify_tables, QualifyTablesOptions, TypeAnnotator, TypeCoercionClass,
90};
91pub use parser::Parser;
92#[cfg(feature = "semantic")]
93pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
94#[cfg(feature = "semantic")]
95pub use schema::{
96    ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
97};
98#[cfg(feature = "semantic")]
99pub use scope::{
100    build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
101    ScopeType, SourceInfo,
102};
103#[cfg(feature = "time")]
104pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
105pub use tokens::{Token, TokenType, Tokenizer};
106#[cfg(feature = "ast-tools")]
107pub use traversal::{
108    contains_aggregate,
109    contains_subquery,
110    contains_window_function,
111    find_ancestor,
112    find_parent,
113    get_all_tables,
114    get_columns,
115    get_merge_source,
116    get_merge_target,
117    get_tables,
118    is_add,
119    is_aggregate,
120    is_alias,
121    is_alter_table,
122    is_and,
123    is_arithmetic,
124    is_avg,
125    is_between,
126    is_boolean,
127    is_case,
128    is_cast,
129    is_coalesce,
130    is_column,
131    is_comparison,
132    is_concat,
133    is_count,
134    is_create_index,
135    is_create_table,
136    is_create_view,
137    is_cte,
138    is_ddl,
139    is_delete,
140    is_div,
141    is_drop_index,
142    is_drop_table,
143    is_drop_view,
144    is_eq,
145    is_except,
146    is_exists,
147    is_from,
148    is_function,
149    is_group_by,
150    is_gt,
151    is_gte,
152    is_having,
153    is_identifier,
154    is_ilike,
155    is_in,
156    // Extended type predicates
157    is_insert,
158    is_intersect,
159    is_is_null,
160    is_join,
161    is_like,
162    is_limit,
163    is_literal,
164    is_logical,
165    is_lt,
166    is_lte,
167    is_max_func,
168    is_merge,
169    is_min_func,
170    is_mod,
171    is_mul,
172    is_neq,
173    is_not,
174    is_null_if,
175    is_null_literal,
176    is_offset,
177    is_or,
178    is_order_by,
179    is_ordered,
180    is_paren,
181    // Composite predicates
182    is_query,
183    is_safe_cast,
184    is_select,
185    is_set_operation,
186    is_star,
187    is_sub,
188    is_subquery,
189    is_sum,
190    is_table,
191    is_try_cast,
192    is_union,
193    is_update,
194    is_where,
195    is_window_function,
196    is_with,
197    transform,
198    transform_map,
199    BfsIter,
200    DfsIter,
201    ExpressionWalk,
202    ParentInfo,
203    TreeContext,
204};
205#[cfg(any(feature = "semantic", feature = "time"))]
206pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
207#[cfg(feature = "semantic")]
208pub use validation::{
209    mapping_schema_from_validation_schema, validate_with_schema, SchemaColumn,
210    SchemaColumnReference, SchemaForeignKey, SchemaTable, SchemaTableReference,
211    SchemaValidationOptions, ValidationSchema,
212};
213
214#[cfg(feature = "generate")]
215const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 * 1024 * 1024; // 16 MiB
216#[cfg(feature = "generate")]
217const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
218#[cfg(feature = "generate")]
219const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
220#[cfg(feature = "generate")]
221const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;
222
223#[cfg(feature = "generate")]
224fn default_format_max_input_bytes() -> Option<usize> {
225    Some(DEFAULT_FORMAT_MAX_INPUT_BYTES)
226}
227
228#[cfg(feature = "generate")]
229fn default_format_max_tokens() -> Option<usize> {
230    Some(DEFAULT_FORMAT_MAX_TOKENS)
231}
232
233#[cfg(feature = "generate")]
234fn default_format_max_ast_nodes() -> Option<usize> {
235    Some(DEFAULT_FORMAT_MAX_AST_NODES)
236}
237
238#[cfg(feature = "generate")]
239fn default_format_max_set_op_chain() -> Option<usize> {
240    Some(DEFAULT_FORMAT_MAX_SET_OP_CHAIN)
241}
242
243/// Guard options for SQL pretty-formatting.
244///
245/// These limits protect against extremely large/complex queries that can cause
246/// high memory pressure in constrained runtimes (for example browser WASM).
247#[cfg(feature = "generate")]
248#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
249#[serde(rename_all = "camelCase")]
250pub struct FormatGuardOptions {
251    /// Maximum allowed SQL input size in bytes.
252    /// `None` disables this check.
253    #[serde(default = "default_format_max_input_bytes")]
254    pub max_input_bytes: Option<usize>,
255    /// Maximum allowed number of tokens after tokenization.
256    /// `None` disables this check.
257    #[serde(default = "default_format_max_tokens")]
258    pub max_tokens: Option<usize>,
259    /// Maximum allowed AST node count after parsing.
260    /// `None` disables this check.
261    #[serde(default = "default_format_max_ast_nodes")]
262    pub max_ast_nodes: Option<usize>,
263    /// Maximum allowed count of set-operation operators (`UNION`/`INTERSECT`/`EXCEPT`)
264    /// observed in a statement before parsing.
265    ///
266    /// `None` disables this check.
267    #[serde(default = "default_format_max_set_op_chain")]
268    pub max_set_op_chain: Option<usize>,
269}
270
271#[cfg(feature = "generate")]
272impl Default for FormatGuardOptions {
273    fn default() -> Self {
274        Self {
275            max_input_bytes: default_format_max_input_bytes(),
276            max_tokens: default_format_max_tokens(),
277            max_ast_nodes: default_format_max_ast_nodes(),
278            max_set_op_chain: default_format_max_set_op_chain(),
279        }
280    }
281}
282
283#[cfg(feature = "generate")]
284fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
285    Error::generate(format!(
286        "{code}: value {actual} exceeds configured limit {limit}"
287    ))
288}
289
290#[cfg(feature = "generate")]
291fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
292    if let Some(max) = options.max_input_bytes {
293        let input_bytes = sql.len();
294        if input_bytes > max {
295            return Err(format_guard_error(
296                "E_GUARD_INPUT_TOO_LARGE",
297                input_bytes,
298                max,
299            ));
300        }
301    }
302    Ok(())
303}
304
305#[cfg(feature = "generate")]
306fn parse_with_token_guard(
307    sql: &str,
308    dialect: &Dialect,
309    options: &FormatGuardOptions,
310) -> Result<Vec<Expression>> {
311    let tokens = dialect.tokenize(sql)?;
312    if let Some(max) = options.max_tokens {
313        let token_count = tokens.len();
314        if token_count > max {
315            return Err(format_guard_error(
316                "E_GUARD_TOKEN_BUDGET_EXCEEDED",
317                token_count,
318                max,
319            ));
320        }
321    }
322    enforce_set_op_chain_guard(&tokens, options)?;
323
324    let config = crate::parser::ParserConfig {
325        dialect: Some(dialect.dialect_type()),
326        ..Default::default()
327    };
328    let mut parser = Parser::with_source(tokens, config, sql.to_string());
329    parser.parse()
330}
331
332#[cfg(feature = "generate")]
333fn is_trivia_token(token_type: TokenType) -> bool {
334    matches!(
335        token_type,
336        TokenType::Space | TokenType::Break | TokenType::LineComment | TokenType::BlockComment
337    )
338}
339
340#[cfg(feature = "generate")]
341fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
342    tokens
343        .iter()
344        .skip(start)
345        .find(|token| !is_trivia_token(token.token_type))
346}
347
348#[cfg(feature = "generate")]
349fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
350    let token = &tokens[idx];
351    match token.token_type {
352        TokenType::Union | TokenType::Intersect => true,
353        TokenType::Except => {
354            // MINUS is aliased to EXCEPT in the tokenizer, but in ClickHouse minus(...)
355            // is a function call rather than a set operation.
356            if token.text.eq_ignore_ascii_case("minus")
357                && matches!(
358                    next_significant_token(tokens, idx + 1).map(|t| t.token_type),
359                    Some(TokenType::LParen)
360                )
361            {
362                return false;
363            }
364            true
365        }
366        _ => false,
367    }
368}
369
370#[cfg(feature = "generate")]
371fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
372    let Some(max) = options.max_set_op_chain else {
373        return Ok(());
374    };
375
376    let mut set_op_count = 0usize;
377    for (idx, token) in tokens.iter().enumerate() {
378        if token.token_type == TokenType::Semicolon {
379            set_op_count = 0;
380            continue;
381        }
382
383        if is_set_operation_token(tokens, idx) {
384            set_op_count += 1;
385            if set_op_count > max {
386                return Err(format_guard_error(
387                    "E_GUARD_SET_OP_CHAIN_EXCEEDED",
388                    set_op_count,
389                    max,
390                ));
391            }
392        }
393    }
394
395    Ok(())
396}
397
398#[cfg(feature = "generate")]
399fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
400    if let Some(max) = options.max_ast_nodes {
401        let ast_nodes: usize = expressions
402            .iter()
403            .map(crate::ast_transforms::node_count)
404            .sum();
405        if ast_nodes > max {
406            return Err(format_guard_error(
407                "E_GUARD_AST_BUDGET_EXCEEDED",
408                ast_nodes,
409                max,
410            ));
411        }
412    }
413    Ok(())
414}
415
416#[cfg(feature = "generate")]
417fn format_with_dialect(
418    sql: &str,
419    dialect: &Dialect,
420    options: &FormatGuardOptions,
421) -> Result<Vec<String>> {
422    enforce_input_guard(sql, options)?;
423    let expressions = parse_with_token_guard(sql, dialect, options)?;
424    enforce_ast_guard(&expressions, options)?;
425
426    expressions
427        .iter()
428        .map(|expr| dialect.generate_pretty(expr))
429        .collect()
430}
431
432/// Transpile SQL from one dialect to another.
433///
434/// # Arguments
435/// * `sql` - The SQL string to transpile
436/// * `read` - The source dialect to parse with
437/// * `write` - The target dialect to generate
438///
439/// # Returns
440/// A vector of transpiled SQL statements
441///
442/// # Example
443/// ```
444/// use polyglot_sql::{transpile, DialectType};
445///
446/// let result = transpile(
447///     "SELECT EPOCH_MS(1618088028295)",
448///     DialectType::DuckDB,
449///     DialectType::Hive
450/// );
451/// ```
452#[cfg(feature = "transpile")]
453pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
454    // Delegate to Dialect::transpile so that the full cross-dialect rewrite
455    // pipeline (source+target-aware normalization in `cross_dialect_normalize`)
456    // runs here as well. This keeps Rust crate users on the same code path as
457    // the WASM/FFI/Python bindings and the playground.
458    Dialect::get(read).transpile(sql, write)
459}
460
461/// Parse SQL into an AST.
462///
463/// # Arguments
464/// * `sql` - The SQL string to parse
465/// * `dialect` - The dialect to use for parsing
466///
467/// # Returns
468/// A vector of parsed expressions
469pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
470    let d = Dialect::get(dialect);
471    d.parse(sql)
472}
473
474/// Parse a single SQL statement.
475///
476/// # Arguments
477/// * `sql` - The SQL string containing a single statement
478/// * `dialect` - The dialect to use for parsing
479///
480/// # Returns
481/// The parsed expression, or an error if multiple statements found
482pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
483    let mut expressions = parse(sql, dialect)?;
484
485    if expressions.len() != 1 {
486        return Err(Error::parse(
487            format!("Expected 1 statement, found {}", expressions.len()),
488            0,
489            0,
490            0,
491            0,
492        ));
493    }
494
495    Ok(expressions.remove(0))
496}
497
498/// Generate SQL from an AST.
499///
500/// # Arguments
501/// * `expression` - The expression to generate SQL from
502/// * `dialect` - The target dialect
503///
504/// # Returns
505/// The generated SQL string
506#[cfg(feature = "generate")]
507pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
508    let d = Dialect::get(dialect);
509    d.generate(expression)
510}
511
512/// Format/pretty-print SQL statements.
513///
514/// Uses [`FormatGuardOptions::default`] guards.
515#[cfg(feature = "generate")]
516pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
517    format_with_options(sql, dialect, &FormatGuardOptions::default())
518}
519
520/// Format/pretty-print SQL statements with configurable guard limits.
521#[cfg(feature = "generate")]
522pub fn format_with_options(
523    sql: &str,
524    dialect: DialectType,
525    options: &FormatGuardOptions,
526) -> Result<Vec<String>> {
527    let d = Dialect::get(dialect);
528    format_with_dialect(sql, &d, options)
529}
530
531/// Validate SQL syntax.
532///
533/// # Arguments
534/// * `sql` - The SQL string to validate
535/// * `dialect` - The dialect to use for validation
536///
537/// # Returns
538/// A validation result with any errors found
539#[cfg(feature = "semantic")]
540pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
541    validate_with_options(sql, dialect, &ValidationOptions::default())
542}
543
544/// Options for syntax validation behavior.
545#[cfg(feature = "semantic")]
546#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
547#[serde(rename_all = "camelCase")]
548pub struct ValidationOptions {
549    /// When enabled, validation rejects non-canonical trailing commas that the parser
550    /// would otherwise accept for compatibility (e.g. `SELECT a, FROM t`).
551    #[serde(default)]
552    pub strict_syntax: bool,
553}
554
555/// Validate SQL syntax with additional validation options.
556#[cfg(feature = "semantic")]
557pub fn validate_with_options(
558    sql: &str,
559    dialect: DialectType,
560    options: &ValidationOptions,
561) -> ValidationResult {
562    let d = Dialect::get(dialect);
563    match d.parse(sql) {
564        Ok(expressions) => {
565            // Reject bare expressions that aren't valid SQL statements.
566            // The parser accepts any expression at the top level, but bare identifiers,
567            // literals, function calls, etc. are not valid statements.
568            for expr in &expressions {
569                if !expr.is_statement() {
570                    let msg = format!("Invalid expression / Unexpected token");
571                    return ValidationResult::with_errors(vec![ValidationError::error(
572                        msg, "E004",
573                    )]);
574                }
575            }
576            if options.strict_syntax {
577                if let Some(error) = strict_syntax_error(sql, &d) {
578                    return ValidationResult::with_errors(vec![error]);
579                }
580            }
581            ValidationResult::success()
582        }
583        Err(e) => {
584            let error = match &e {
585                Error::Syntax {
586                    message,
587                    line,
588                    column,
589                    start,
590                    end,
591                } => ValidationError::error(message.clone(), "E001")
592                    .with_location(*line, *column)
593                    .with_span(Some(*start), Some(*end)),
594                Error::Tokenize {
595                    message,
596                    line,
597                    column,
598                    start,
599                    end,
600                } => ValidationError::error(message.clone(), "E002")
601                    .with_location(*line, *column)
602                    .with_span(Some(*start), Some(*end)),
603                Error::Parse {
604                    message,
605                    line,
606                    column,
607                    start,
608                    end,
609                } => ValidationError::error(message.clone(), "E003")
610                    .with_location(*line, *column)
611                    .with_span(Some(*start), Some(*end)),
612                _ => ValidationError::error(e.to_string(), "E000"),
613            };
614            ValidationResult::with_errors(vec![error])
615        }
616    }
617}
618
619#[cfg(feature = "semantic")]
620fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
621    let tokens = dialect.tokenize(sql).ok()?;
622
623    for (idx, token) in tokens.iter().enumerate() {
624        if token.token_type != TokenType::Comma {
625            continue;
626        }
627
628        let next = tokens.get(idx + 1);
629        let (is_boundary, boundary_name) = match next.map(|t| t.token_type) {
630            Some(TokenType::From) => (true, "FROM"),
631            Some(TokenType::Where) => (true, "WHERE"),
632            Some(TokenType::GroupBy) => (true, "GROUP BY"),
633            Some(TokenType::Having) => (true, "HAVING"),
634            Some(TokenType::Order) | Some(TokenType::OrderBy) => (true, "ORDER BY"),
635            Some(TokenType::Limit) => (true, "LIMIT"),
636            Some(TokenType::Offset) => (true, "OFFSET"),
637            Some(TokenType::Union) => (true, "UNION"),
638            Some(TokenType::Intersect) => (true, "INTERSECT"),
639            Some(TokenType::Except) => (true, "EXCEPT"),
640            Some(TokenType::Qualify) => (true, "QUALIFY"),
641            Some(TokenType::Window) => (true, "WINDOW"),
642            Some(TokenType::Semicolon) | None => (true, "end of statement"),
643            _ => (false, ""),
644        };
645
646        if is_boundary {
647            let message = format!(
648                "Trailing comma before {} is not allowed in strict syntax mode",
649                boundary_name
650            );
651            return Some(
652                ValidationError::error(message, "E005")
653                    .with_location(token.span.line, token.span.column),
654            );
655        }
656    }
657
658    None
659}
660
661/// Transpile SQL from one dialect to another, using string dialect names.
662///
663/// This supports both built-in dialect names (e.g., "postgresql", "mysql") and
664/// custom dialects registered via [`CustomDialectBuilder`].
665///
666/// # Arguments
667/// * `sql` - The SQL string to transpile
668/// * `read` - The source dialect name
669/// * `write` - The target dialect name
670///
671/// # Returns
672/// A vector of transpiled SQL statements, or an error if a dialect name is unknown.
673#[cfg(feature = "transpile")]
674pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
675    transpile_with_by_name(sql, read, write, &TranspileOptions::default())
676}
677
678/// Transpile SQL with configurable [`TranspileOptions`], using string dialect names.
679///
680/// Same as [`transpile_by_name`] but accepts options (e.g., pretty-printing).
681#[cfg(feature = "transpile")]
682pub fn transpile_with_by_name(
683    sql: &str,
684    read: &str,
685    write: &str,
686    opts: &TranspileOptions,
687) -> Result<Vec<String>> {
688    let read_dialect = Dialect::get_by_name(read)
689        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0, 0, 0))?;
690    let write_dialect = Dialect::get_by_name(write)
691        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0, 0, 0))?;
692    read_dialect.transpile_with(sql, &write_dialect, opts.clone())
693}
694
695/// Parse SQL into an AST using a string dialect name.
696///
697/// Supports both built-in and custom dialect names.
698pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
699    let d = Dialect::get_by_name(dialect)
700        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
701    d.parse(sql)
702}
703
704/// Generate SQL from an AST using a string dialect name.
705///
706/// Supports both built-in and custom dialect names.
707#[cfg(feature = "generate")]
708pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
709    let d = Dialect::get_by_name(dialect)
710        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
711    d.generate(expression)
712}
713
714/// Format SQL using a string dialect name.
715///
716/// Uses [`FormatGuardOptions::default`] guards.
717#[cfg(feature = "generate")]
718pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
719    format_with_options_by_name(sql, dialect, &FormatGuardOptions::default())
720}
721
722/// Format SQL using a string dialect name with configurable guard limits.
723#[cfg(feature = "generate")]
724pub fn format_with_options_by_name(
725    sql: &str,
726    dialect: &str,
727    options: &FormatGuardOptions,
728) -> Result<Vec<String>> {
729    let d = Dialect::get_by_name(dialect)
730        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
731    format_with_dialect(sql, &d, options)
732}
733
734#[cfg(all(test, feature = "semantic"))]
735mod validation_tests {
736    use super::*;
737
738    #[test]
739    fn validate_is_permissive_by_default_for_trailing_commas() {
740        let result = validate("SELECT name, FROM employees", DialectType::Generic);
741        assert!(result.valid, "Result: {:?}", result.errors);
742    }
743
744    #[test]
745    fn validate_with_options_rejects_trailing_comma_before_from() {
746        let options = ValidationOptions {
747            strict_syntax: true,
748        };
749        let result = validate_with_options(
750            "SELECT name, FROM employees",
751            DialectType::Generic,
752            &options,
753        );
754        assert!(!result.valid, "Result should be invalid");
755        assert!(
756            result.errors.iter().any(|e| e.code == "E005"),
757            "Expected E005, got: {:?}",
758            result.errors
759        );
760    }
761
762    #[test]
763    fn validate_with_options_rejects_trailing_comma_before_where() {
764        let options = ValidationOptions {
765            strict_syntax: true,
766        };
767        let result = validate_with_options(
768            "SELECT name FROM employees, WHERE salary > 10",
769            DialectType::Generic,
770            &options,
771        );
772        assert!(!result.valid, "Result should be invalid");
773        assert!(
774            result.errors.iter().any(|e| e.code == "E005"),
775            "Expected E005, got: {:?}",
776            result.errors
777        );
778    }
779}
780
781#[cfg(all(test, feature = "generate"))]
782mod format_tests {
783    use super::*;
784
785    #[test]
786    fn format_basic_query() {
787        let result = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
788        assert_eq!(result.len(), 1);
789        assert!(result[0].contains('\n'));
790    }
791
792    #[test]
793    fn format_guard_rejects_large_input() {
794        let options = FormatGuardOptions {
795            max_input_bytes: Some(7),
796            max_tokens: None,
797            max_ast_nodes: None,
798            max_set_op_chain: None,
799        };
800        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
801            .expect_err("expected guard error");
802        assert!(err.to_string().contains("E_GUARD_INPUT_TOO_LARGE"));
803    }
804
805    #[test]
806    fn format_guard_rejects_token_budget() {
807        let options = FormatGuardOptions {
808            max_input_bytes: None,
809            max_tokens: Some(1),
810            max_ast_nodes: None,
811            max_set_op_chain: None,
812        };
813        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
814            .expect_err("expected guard error");
815        assert!(err.to_string().contains("E_GUARD_TOKEN_BUDGET_EXCEEDED"));
816    }
817
818    #[test]
819    fn format_guard_rejects_ast_budget() {
820        let options = FormatGuardOptions {
821            max_input_bytes: None,
822            max_tokens: None,
823            max_ast_nodes: Some(1),
824            max_set_op_chain: None,
825        };
826        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
827            .expect_err("expected guard error");
828        assert!(err.to_string().contains("E_GUARD_AST_BUDGET_EXCEEDED"));
829    }
830
831    #[test]
832    fn format_guard_rejects_set_op_chain_budget() {
833        let options = FormatGuardOptions {
834            max_input_bytes: None,
835            max_tokens: None,
836            max_ast_nodes: None,
837            max_set_op_chain: Some(1),
838        };
839        let err = format_with_options(
840            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
841            DialectType::Generic,
842            &options,
843        )
844        .expect_err("expected guard error");
845        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
846    }
847
848    #[test]
849    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
850        let options = FormatGuardOptions {
851            max_input_bytes: None,
852            max_tokens: None,
853            max_ast_nodes: None,
854            max_set_op_chain: Some(0),
855        };
856        let result = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
857        assert!(result.is_ok(), "Result: {:?}", result);
858    }
859
860    #[test]
861    fn issue57_invalid_ternary_returns_error() {
862        // https://github.com/tobilg/polyglot/issues/57
863        // Invalid SQL with ternary operator should return an error, not garbled output.
864        let sql = "SELECT x > 0 ? 1 : 0 FROM t";
865
866        let parse_result = parse(sql, DialectType::PostgreSQL);
867        assert!(
868            parse_result.is_err(),
869            "Expected parse error for invalid ternary SQL, got: {:?}",
870            parse_result
871        );
872
873        let format_result = format(sql, DialectType::PostgreSQL);
874        assert!(
875            format_result.is_err(),
876            "Expected format error for invalid ternary SQL, got: {:?}",
877            format_result
878        );
879
880        let transpile_result = transpile(sql, DialectType::PostgreSQL, DialectType::PostgreSQL);
881        assert!(
882            transpile_result.is_err(),
883            "Expected transpile error for invalid ternary SQL, got: {:?}",
884            transpile_result
885        );
886    }
887
888    /// Regression guard: `lib::transpile()` must apply the full cross-dialect
889    /// rewrite pipeline (same as `Dialect::transpile()`). If these two paths
890    /// diverge again, Rust crate users silently get under-transformed SQL that
891    /// differs from what WASM/FFI/Python bindings produce.
892    #[test]
893    fn transpile_applies_cross_dialect_rewrites() {
894        // DuckDB to_timestamp → Trino FROM_UNIXTIME (different input semantics).
895        let out = transpile(
896            "SELECT to_timestamp(col) FROM t",
897            DialectType::DuckDB,
898            DialectType::Trino,
899        )
900        .expect("transpile failed");
901        assert_eq!(out[0], "SELECT FROM_UNIXTIME(col) FROM t");
902
903        // DuckDB CAST(x AS JSON) → Trino JSON_PARSE(x) (different CAST semantics).
904        let out = transpile(
905            "SELECT CAST(col AS JSON) FROM t",
906            DialectType::DuckDB,
907            DialectType::Trino,
908        )
909        .expect("transpile failed");
910        assert_eq!(out[0], "SELECT JSON_PARSE(col) FROM t");
911    }
912
913    /// Regression guard: all three transpile entry points (lib::transpile,
914    /// lib::transpile_by_name, Dialect::transpile) must produce identical
915    /// output. transpile_by_name is the one used by Python and C FFI bindings.
916    #[test]
917    fn transpile_matches_dialect_method() {
918        let cases: &[(DialectType, DialectType, &str, &str, &str)] = &[
919            (
920                DialectType::DuckDB,
921                DialectType::Trino,
922                "duckdb",
923                "trino",
924                "SELECT to_timestamp(col) FROM t",
925            ),
926            (
927                DialectType::DuckDB,
928                DialectType::Trino,
929                "duckdb",
930                "trino",
931                "SELECT CAST(col AS JSON) FROM t",
932            ),
933            (
934                DialectType::DuckDB,
935                DialectType::Trino,
936                "duckdb",
937                "trino",
938                "SELECT json_valid(col) FROM t",
939            ),
940            (
941                DialectType::Snowflake,
942                DialectType::DuckDB,
943                "snowflake",
944                "duckdb",
945                "SELECT DATEDIFF(day, a, b) FROM t",
946            ),
947            (
948                DialectType::BigQuery,
949                DialectType::DuckDB,
950                "bigquery",
951                "duckdb",
952                "SELECT DATE_DIFF(a, b, DAY) FROM t",
953            ),
954            (
955                DialectType::Generic,
956                DialectType::Generic,
957                "generic",
958                "generic",
959                "SELECT 1",
960            ),
961        ];
962        for (read, write, read_name, write_name, sql) in cases {
963            let via_lib = transpile(sql, *read, *write).expect("lib::transpile failed");
964            let via_name = transpile_by_name(sql, read_name, write_name)
965                .expect("lib::transpile_by_name failed");
966            let via_dialect = Dialect::get(*read)
967                .transpile(sql, *write)
968                .expect("Dialect::transpile failed");
969            assert_eq!(
970                via_lib, via_dialect,
971                "lib::transpile / Dialect::transpile diverged for {:?} -> {:?}: {sql}",
972                read, write
973            );
974            assert_eq!(
975                via_name, via_dialect,
976                "lib::transpile_by_name / Dialect::transpile diverged for {read_name} -> {write_name}: {sql}"
977            );
978        }
979    }
980
981    #[test]
982    fn format_default_guard_rejects_deep_union_chain_before_parse() {
983        let base = "SELECT col0, col1 FROM t";
984        let mut sql = base.to_string();
985        for _ in 0..1100 {
986            sql.push_str(" UNION ALL ");
987            sql.push_str(base);
988        }
989
990        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
991        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
992    }
993}