Skip to main content

polyglot_sql/
lib.rs

1//! Polyglot Core - SQL parsing and dialect translation library
2//!
3//! This library provides the core functionality for parsing SQL statements,
4//! building an abstract syntax tree (AST), and generating SQL in different dialects.
5//!
6//! # Architecture
7//!
8//! The library follows a pipeline architecture:
9//! 1. **Tokenizer** - Converts SQL string to token stream
10//! 2. **Parser** - Builds AST from tokens
11//! 3. **Generator** - Converts AST back to SQL string
12//!
13//! Each stage can be customized per dialect.
14
15pub mod ast_json;
16#[cfg(any(feature = "ast-tools", feature = "generate", feature = "semantic"))]
17pub mod ast_transforms;
18#[cfg(feature = "builder")]
19pub mod builder;
20pub mod dialects;
21#[cfg(feature = "diff")]
22pub mod diff;
23pub mod error;
24pub mod expressions;
25#[cfg(feature = "semantic")]
26pub mod function_catalog;
27mod function_registry;
28#[cfg(feature = "generate")]
29pub mod generator;
30#[cfg(feature = "semantic")]
31pub mod helper;
32#[cfg(feature = "semantic")]
33pub mod lineage;
34#[cfg(feature = "openlineage")]
35pub mod openlineage;
36#[cfg(feature = "semantic")]
37pub mod optimizer;
38pub mod parser;
39#[cfg(feature = "planner")]
40pub mod planner;
41#[cfg(all(feature = "semantic", feature = "generate"))]
42pub mod query_analysis;
43#[cfg(feature = "semantic")]
44pub mod resolver;
45#[cfg(feature = "semantic")]
46pub mod schema;
47#[cfg(feature = "semantic")]
48pub mod scope;
49#[cfg(feature = "time")]
50pub mod time;
51pub mod tokens;
52#[cfg(feature = "transpile")]
53pub mod transforms;
54#[cfg(any(feature = "ast-tools", feature = "generate", feature = "semantic"))]
55pub mod traversal;
56#[cfg(any(feature = "semantic", feature = "time"))]
57pub mod trie;
58#[cfg(feature = "semantic")]
59pub mod validation;
60
61#[cfg(any(feature = "generate", feature = "semantic"))]
62use serde::{Deserialize, Serialize};
63
64#[cfg(feature = "ast-tools")]
65pub use ast_transforms::{
66    add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
67    get_identifiers, get_literals, get_output_column_names, get_subqueries, get_table_names,
68    get_window_functions, node_count, qualify_columns, remove_limit_offset, remove_nodes,
69    remove_select_columns, remove_where, rename_columns, rename_tables, rename_tables_with_options,
70    replace_by_type, replace_nodes, set_distinct, set_limit, set_limit_expr, set_offset,
71    set_offset_expr, set_order_by, RenameTablesOptions,
72};
73pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
74#[cfg(feature = "transpile")]
75pub use dialects::{TranspileOptions, TranspileTarget};
76pub use error::{Error, Result};
77#[cfg(feature = "semantic")]
78pub use error::{ValidationError, ValidationResult, ValidationSeverity};
79pub use expressions::{DataType, Expression};
80#[cfg(feature = "semantic")]
81pub use function_catalog::{
82    FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
83};
84#[cfg(feature = "generate")]
85pub use generator::{Generator, UnsupportedLevel};
86#[cfg(feature = "semantic")]
87pub use helper::{
88    csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
89    name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
90};
91#[cfg(feature = "semantic")]
92pub use optimizer::{
93    annotate_types, qualify_tables, QualifyTablesOptions, TypeAnnotator, TypeCoercionClass,
94};
95pub use parser::Parser;
96#[cfg(all(feature = "semantic", feature = "generate"))]
97pub use query_analysis::{
98    analyze_query, AnalyzeQueryOptions, ColumnReferenceFact, CteFact, ProjectionFact,
99    ProjectionNullability, QueryAnalysis, QueryShape, ReferenceConfidence, RelationFact,
100    SetOperationBranchFact, SetOperationFact, StarProjectionFact, TransformFunctionFact,
101    TransformKind,
102};
103#[cfg(feature = "semantic")]
104pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
105#[cfg(feature = "semantic")]
106pub use schema::{
107    ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
108};
109#[cfg(feature = "semantic")]
110pub use scope::{
111    build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
112    ScopeType, SourceInfo,
113};
114#[cfg(feature = "time")]
115pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
116pub use tokens::{Token, TokenType, Tokenizer};
117#[cfg(feature = "ast-tools")]
118pub use traversal::{
119    contains_aggregate,
120    contains_subquery,
121    contains_window_function,
122    find_ancestor,
123    find_parent,
124    get_all_tables,
125    get_columns,
126    get_merge_source,
127    get_merge_target,
128    get_tables,
129    is_add,
130    is_aggregate,
131    is_alias,
132    is_alter_table,
133    is_and,
134    is_arithmetic,
135    is_avg,
136    is_between,
137    is_boolean,
138    is_case,
139    is_cast,
140    is_coalesce,
141    is_column,
142    is_comparison,
143    is_concat,
144    is_count,
145    is_create_index,
146    is_create_table,
147    is_create_view,
148    is_cte,
149    is_ddl,
150    is_delete,
151    is_div,
152    is_drop_index,
153    is_drop_table,
154    is_drop_view,
155    is_eq,
156    is_except,
157    is_exists,
158    is_from,
159    is_function,
160    is_group_by,
161    is_gt,
162    is_gte,
163    is_having,
164    is_identifier,
165    is_ilike,
166    is_in,
167    // Extended type predicates
168    is_insert,
169    is_intersect,
170    is_is_null,
171    is_join,
172    is_like,
173    is_limit,
174    is_literal,
175    is_logical,
176    is_lt,
177    is_lte,
178    is_max_func,
179    is_merge,
180    is_min_func,
181    is_mod,
182    is_mul,
183    is_neq,
184    is_not,
185    is_null_if,
186    is_null_literal,
187    is_offset,
188    is_or,
189    is_order_by,
190    is_ordered,
191    is_paren,
192    // Composite predicates
193    is_query,
194    is_safe_cast,
195    is_select,
196    is_set_operation,
197    is_star,
198    is_sub,
199    is_subquery,
200    is_sum,
201    is_table,
202    is_try_cast,
203    is_union,
204    is_update,
205    is_where,
206    is_window_function,
207    is_with,
208    transform,
209    transform_map,
210    BfsIter,
211    DfsIter,
212    ExpressionWalk,
213    ParentInfo,
214    TreeContext,
215};
216#[cfg(any(feature = "semantic", feature = "time"))]
217pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
218#[cfg(feature = "semantic")]
219pub use validation::{
220    mapping_schema_from_validation_schema, validate_with_schema, SchemaColumn,
221    SchemaColumnReference, SchemaForeignKey, SchemaTable, SchemaTableReference,
222    SchemaValidationOptions, ValidationSchema,
223};
224
/// Default for [`FormatGuardOptions::max_input_bytes`]: 16 MiB of SQL text.
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 * 1024 * 1024; // 16 MiB
/// Default for [`FormatGuardOptions::max_tokens`].
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
/// Default for [`FormatGuardOptions::max_ast_nodes`].
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
/// Default for [`FormatGuardOptions::max_set_op_chain`].
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;

// The functions below exist because `#[serde(default = "...")]` requires a function
// path; each one wraps the matching constant above in `Some`, i.e. "limit enabled".

/// Default provider for `FormatGuardOptions::max_input_bytes`.
#[cfg(feature = "generate")]
fn default_format_max_input_bytes() -> Option<usize> {
    Some(DEFAULT_FORMAT_MAX_INPUT_BYTES)
}

/// Default provider for `FormatGuardOptions::max_tokens`.
#[cfg(feature = "generate")]
fn default_format_max_tokens() -> Option<usize> {
    Some(DEFAULT_FORMAT_MAX_TOKENS)
}

/// Default provider for `FormatGuardOptions::max_ast_nodes`.
#[cfg(feature = "generate")]
fn default_format_max_ast_nodes() -> Option<usize> {
    Some(DEFAULT_FORMAT_MAX_AST_NODES)
}

/// Default provider for `FormatGuardOptions::max_set_op_chain`.
#[cfg(feature = "generate")]
fn default_format_max_set_op_chain() -> Option<usize> {
    Some(DEFAULT_FORMAT_MAX_SET_OP_CHAIN)
}
253
/// Resource limits enforced by the SQL pretty-formatting entry points
/// ([`format_with_options`] and friends) around tokenizing and parsing.
///
/// Very large or pathological queries can push memory use high in constrained
/// runtimes such as browser WASM; each limit below bounds one stage of the pipeline.
/// Setting a field to `None` switches that particular check off. When deserializing,
/// any field missing from the input receives its built-in default limit.
#[cfg(feature = "generate")]
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct FormatGuardOptions {
    /// Upper bound on the raw SQL input length, measured in bytes.
    /// `None` leaves the input size unchecked.
    #[serde(default = "default_format_max_input_bytes")]
    pub max_input_bytes: Option<usize>,
    /// Upper bound on the number of tokens the tokenizer may produce.
    /// `None` leaves the token count unchecked.
    #[serde(default = "default_format_max_tokens")]
    pub max_tokens: Option<usize>,
    /// Upper bound on the AST node count summed over all parsed statements.
    /// `None` leaves the AST size unchecked.
    #[serde(default = "default_format_max_ast_nodes")]
    pub max_ast_nodes: Option<usize>,
    /// Upper bound on how many set-operation keywords (`UNION` / `INTERSECT` /
    /// `EXCEPT`) one `;`-separated statement may contain, checked on the token
    /// stream before parsing.
    ///
    /// `None` leaves set-operation chains unchecked.
    #[serde(default = "default_format_max_set_op_chain")]
    pub max_set_op_chain: Option<usize>,
}
281
#[cfg(feature = "generate")]
impl Default for FormatGuardOptions {
    /// Every guard enabled at its built-in limit, matching the values serde fills in
    /// for fields absent from a serialized options object.
    fn default() -> Self {
        FormatGuardOptions {
            max_set_op_chain: default_format_max_set_op_chain(),
            max_ast_nodes: default_format_max_ast_nodes(),
            max_tokens: default_format_max_tokens(),
            max_input_bytes: default_format_max_input_bytes(),
        }
    }
}
293
/// Build the error returned when a format guard trips.
///
/// The message starts with the stable `code` (e.g. `E_GUARD_INPUT_TOO_LARGE`),
/// followed by the observed value and the configured limit.
#[cfg(feature = "generate")]
fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
    let message = format!("{code}: value {actual} exceeds configured limit {limit}");
    Error::generate(message)
}
300
/// Reject `sql` with `E_GUARD_INPUT_TOO_LARGE` when its byte length exceeds
/// `options.max_input_bytes`; no-op when that limit is `None`.
#[cfg(feature = "generate")]
fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
    let input_bytes = sql.len();
    match options.max_input_bytes {
        Some(limit) if input_bytes > limit => Err(format_guard_error(
            "E_GUARD_INPUT_TOO_LARGE",
            input_bytes,
            limit,
        )),
        _ => Ok(()),
    }
}
315
/// Tokenize `sql` with `dialect`, apply the token-level guards, then parse.
///
/// Two checks run on the token stream before the parser sees it, in this order:
/// the overall token budget (`E_GUARD_TOKEN_BUDGET_EXCEEDED`) and the per-statement
/// set-operation budget enforced by `enforce_set_op_chain_guard`.
#[cfg(feature = "generate")]
fn parse_with_token_guard(
    sql: &str,
    dialect: &Dialect,
    options: &FormatGuardOptions,
) -> Result<Vec<Expression>> {
    let tokens = dialect.tokenize(sql)?;

    let token_count = tokens.len();
    if let Some(limit) = options.max_tokens.filter(|&limit| token_count > limit) {
        return Err(format_guard_error(
            "E_GUARD_TOKEN_BUDGET_EXCEEDED",
            token_count,
            limit,
        ));
    }
    enforce_set_op_chain_guard(&tokens, options)?;

    // Keep the original text so the parser can reference source spans.
    let parser_config = crate::parser::ParserConfig {
        dialect: Some(dialect.dialect_type()),
        ..Default::default()
    };
    let mut parser = Parser::with_source(tokens, parser_config, sql.to_string());
    parser.parse()
}
342
/// Returns `true` for whitespace, line-break and comment tokens, which are skipped
/// when looking for the next meaningful token.
#[cfg(feature = "generate")]
fn is_trivia_token(token_type: TokenType) -> bool {
    const TRIVIA: [TokenType; 4] = [
        TokenType::Space,
        TokenType::Break,
        TokenType::LineComment,
        TokenType::BlockComment,
    ];
    TRIVIA.contains(&token_type)
}
350
/// First non-trivia token at or after index `start`, or `None` if there is none
/// (including when `start` is past the end of `tokens`).
#[cfg(feature = "generate")]
fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
    let remaining = tokens.get(start..)?;
    remaining
        .iter()
        .find(|candidate| !is_trivia_token(candidate.token_type))
}
358
/// Returns `true` when `tokens[idx]` is a set-operation keyword
/// (`UNION` / `INTERSECT` / `EXCEPT`) used as a set operation.
///
/// Two `EXCEPT`-typed tokens are not counted:
/// * `minus(...)`: the tokenizer aliases MINUS to EXCEPT, but in ClickHouse
///   `minus(...)` is a function call;
/// * `* EXCEPT (...)`: a star column modifier (e.g. BigQuery), not a set operation.
///
/// # Panics
/// Panics if `idx` is out of bounds for `tokens`.
#[cfg(feature = "generate")]
fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
    let token = &tokens[idx];
    match token.token_type {
        TokenType::Union | TokenType::Intersect => true,
        TokenType::Except => {
            // MINUS is aliased to EXCEPT in the tokenizer, but in ClickHouse minus(...)
            // is a function call rather than a set operation.
            let is_minus_call = token.text.eq_ignore_ascii_case("minus")
                && matches!(
                    next_significant_token(tokens, idx + 1).map(|t| t.token_type),
                    Some(TokenType::LParen)
                );

            // `SELECT * EXCEPT (col)` / `t.* EXCEPT (col)` excludes columns from a star
            // projection; counting it would inflate the set-operation tally.
            let previous = tokens[..idx]
                .iter()
                .rev()
                .find(|t| !is_trivia_token(t.token_type))
                .map(|t| t.token_type);
            let is_star_modifier = matches!(previous, Some(TokenType::Star));

            !is_minus_call && !is_star_modifier
        }
        _ => false,
    }
}
380
/// Fails with `E_GUARD_SET_OP_CHAIN_EXCEEDED` once a single statement contains more
/// set-operation keywords than `options.max_set_op_chain` allows.
///
/// The tally restarts at every `;` token. Operators inside parenthesized
/// subqueries count toward the surrounding statement's total. No-op when the
/// limit is `None`.
#[cfg(feature = "generate")]
fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
    let limit = match options.max_set_op_chain {
        Some(limit) => limit,
        None => return Ok(()),
    };

    let mut ops_in_statement: usize = 0;
    for (position, token) in tokens.iter().enumerate() {
        match token.token_type {
            TokenType::Semicolon => ops_in_statement = 0,
            _ if is_set_operation_token(tokens, position) => {
                ops_in_statement += 1;
                if ops_in_statement > limit {
                    return Err(format_guard_error(
                        "E_GUARD_SET_OP_CHAIN_EXCEEDED",
                        ops_in_statement,
                        limit,
                    ));
                }
            }
            _ => {}
        }
    }

    Ok(())
}
408
/// Fails with `E_GUARD_AST_BUDGET_EXCEEDED` when the node count summed over all
/// parsed statements exceeds `options.max_ast_nodes`; no-op when that limit is `None`.
#[cfg(feature = "generate")]
fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
    let Some(limit) = options.max_ast_nodes else {
        return Ok(());
    };

    let mut total_nodes = 0usize;
    for expr in expressions {
        total_nodes += crate::ast_transforms::node_count(expr);
    }

    if total_nodes > limit {
        Err(format_guard_error(
            "E_GUARD_AST_BUDGET_EXCEEDED",
            total_nodes,
            limit,
        ))
    } else {
        Ok(())
    }
}
426
/// Shared implementation of the `format*` entry points.
///
/// Guards run in pipeline order: input size, then the token guards during parsing,
/// then the AST size. Every parsed statement is pretty-printed, and the first
/// generation error aborts the whole call.
#[cfg(feature = "generate")]
fn format_with_dialect(
    sql: &str,
    dialect: &Dialect,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    enforce_input_guard(sql, options)?;
    let statements = parse_with_token_guard(sql, dialect, options)?;
    enforce_ast_guard(&statements, options)?;

    let mut formatted = Vec::with_capacity(statements.len());
    for statement in &statements {
        formatted.push(dialect.generate_pretty(statement)?);
    }
    Ok(formatted)
}
442
/// Transpile SQL from one dialect to another.
///
/// # Arguments
/// * `sql` - The SQL string to transpile
/// * `read` - The source dialect to parse with
/// * `write` - The target dialect to generate
///
/// # Returns
/// A vector of transpiled SQL statements
///
/// # Errors
/// Propagates any error from `Dialect::transpile`, e.g. when `sql` cannot be
/// parsed in the `read` dialect.
///
/// # Example
/// ```
/// use polyglot_sql::{transpile, DialectType};
///
/// let result = transpile(
///     "SELECT EPOCH_MS(1618088028295)",
///     DialectType::DuckDB,
///     DialectType::Hive
/// );
/// ```
#[cfg(feature = "transpile")]
pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
    // Route through Dialect::transpile (not a hand-rolled parse + generate) so the
    // source+target-aware rewrites in `cross_dialect_normalize` run here too. That
    // keeps crate users on the same path as the WASM/FFI/Python bindings and the
    // playground.
    let source = Dialect::get(read);
    source.transpile(sql, write)
}
471
472/// Parse SQL into an AST.
473///
474/// # Arguments
475/// * `sql` - The SQL string to parse
476/// * `dialect` - The dialect to use for parsing
477///
478/// # Returns
479/// A vector of parsed expressions
480pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
481    let d = Dialect::get(dialect);
482    d.parse(sql)
483}
484
485/// Parse a single SQL statement.
486///
487/// # Arguments
488/// * `sql` - The SQL string containing a single statement
489/// * `dialect` - The dialect to use for parsing
490///
491/// # Returns
492/// The parsed expression, or an error if multiple statements found
493pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
494    let mut expressions = parse(sql, dialect)?;
495
496    if expressions.len() != 1 {
497        return Err(Error::parse(
498            format!("Expected 1 statement, found {}", expressions.len()),
499            0,
500            0,
501            0,
502            0,
503        ));
504    }
505
506    Ok(expressions.remove(0))
507}
508
509/// Parse a standalone SQL data type.
510///
511/// # Arguments
512/// * `sql` - The data type string to parse, e.g. `DECIMAL(10, 2)`
513/// * `dialect` - The dialect to use for parsing
514///
515/// # Returns
516/// The parsed data type
517pub fn parse_data_type(sql: &str, dialect: DialectType) -> Result<DataType> {
518    Dialect::get(dialect).parse_data_type(sql)
519}
520
/// Generate SQL from a standalone data type.
///
/// The type is wrapped in `Expression::DataType` and rendered through the
/// dialect's regular generator.
///
/// # Arguments
/// * `data_type` - The data type to render
/// * `dialect` - The target dialect
///
/// # Returns
/// The generated type SQL string
#[cfg(feature = "generate")]
pub fn generate_data_type(data_type: &DataType, dialect: DialectType) -> Result<String> {
    let wrapped = Expression::DataType(data_type.clone());
    Dialect::get(dialect).generate(&wrapped)
}
533
/// Generate SQL from an AST.
///
/// # Arguments
/// * `expression` - The AST node to render
/// * `dialect` - The target dialect
///
/// # Returns
/// The generated SQL string (not pretty-printed; see [`format_with_options`] for that)
#[cfg(feature = "generate")]
pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
    Dialect::get(dialect).generate(expression)
}
547
/// Format/pretty-print SQL statements, one output string per input statement.
///
/// Uses [`FormatGuardOptions::default`] guards.
#[cfg(feature = "generate")]
pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
    let default_guards = FormatGuardOptions::default();
    format_with_options(sql, dialect, &default_guards)
}
555
/// Format/pretty-print SQL statements with configurable guard limits.
///
/// Returns an `E_GUARD_*` error as soon as one of the limits in `options` is exceeded.
#[cfg(feature = "generate")]
pub fn format_with_options(
    sql: &str,
    dialect: DialectType,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    format_with_dialect(sql, &Dialect::get(dialect), options)
}
566
/// Validate SQL syntax using default [`ValidationOptions`], i.e. without strict
/// trailing-comma checking.
///
/// # Arguments
/// * `sql` - The SQL string to validate
/// * `dialect` - The dialect to use for validation
///
/// # Returns
/// A validation result with any errors found
#[cfg(feature = "semantic")]
pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
    let permissive = ValidationOptions::default();
    validate_with_options(sql, dialect, &permissive)
}
579
/// Switches controlling how strict syntax validation is.
///
/// The default (`strict_syntax: false`) accepts the trailing commas the parser
/// tolerates for compatibility.
#[cfg(feature = "semantic")]
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ValidationOptions {
    /// Reject non-canonical trailing commas that the parser would otherwise accept,
    /// such as the one in `SELECT a, FROM t`. Serialized as `strictSyntax`; a
    /// missing field means `false`.
    #[serde(default)]
    pub strict_syntax: bool,
}
590
/// Validate SQL syntax with additional validation options.
///
/// At most one error is reported. Its code identifies what went wrong:
///
/// * `E001` – syntax error ([`Error::Syntax`]), with location and span
/// * `E002` – tokenizer error ([`Error::Tokenize`]), with location and span
/// * `E003` – parser error ([`Error::Parse`]), with location and span
/// * `E000` – any other error variant
/// * `E004` – the input parsed, but a top-level expression is not a statement
///   (bare identifiers, literals, function calls, ...)
/// * `E005` – a trailing comma was found while
///   [`ValidationOptions::strict_syntax`] is enabled
#[cfg(feature = "semantic")]
pub fn validate_with_options(
    sql: &str,
    dialect: DialectType,
    options: &ValidationOptions,
) -> ValidationResult {
    let d = Dialect::get(dialect);
    match d.parse(sql) {
        Ok(expressions) => {
            // The parser accepts any expression at the top level, but bare identifiers,
            // literals, function calls, etc. are not valid statements.
            if expressions.iter().any(|expr| !expr.is_statement()) {
                return ValidationResult::with_errors(vec![ValidationError::error(
                    "Invalid expression / Unexpected token".to_string(),
                    "E004",
                )]);
            }
            if options.strict_syntax {
                if let Some(error) = strict_syntax_error(sql, &d) {
                    return ValidationResult::with_errors(vec![error]);
                }
            }
            ValidationResult::success()
        }
        Err(e) => {
            let error = match &e {
                Error::Syntax {
                    message,
                    line,
                    column,
                    start,
                    end,
                } => ValidationError::error(message.clone(), "E001")
                    .with_location(*line, *column)
                    .with_span(Some(*start), Some(*end)),
                Error::Tokenize {
                    message,
                    line,
                    column,
                    start,
                    end,
                } => ValidationError::error(message.clone(), "E002")
                    .with_location(*line, *column)
                    .with_span(Some(*start), Some(*end)),
                Error::Parse {
                    message,
                    line,
                    column,
                    start,
                    end,
                } => ValidationError::error(message.clone(), "E003")
                    .with_location(*line, *column)
                    .with_span(Some(*start), Some(*end)),
                _ => ValidationError::error(e.to_string(), "E000"),
            };
            ValidationResult::with_errors(vec![error])
        }
    }
}
654
/// Look for a comma directly followed by a clause keyword, a set operator, `;`,
/// or the end of input (e.g. `SELECT a, FROM t`).
///
/// The parser tolerates such trailing commas for compatibility; strict mode
/// reports the first one as an `E005` error located at the comma. Returns `None`
/// when no trailing comma is found or when `sql` fails to tokenize (tokenizer
/// failures surface through the regular parse path instead).
#[cfg(feature = "semantic")]
fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
    let tokens = dialect.tokenize(sql).ok()?;

    // Whitespace/comment tokens are skipped when looking ahead, so a comment between
    // the comma and the next keyword cannot hide the boundary. This is defensive; it
    // mirrors `is_trivia_token`, which is only compiled with the `generate` feature.
    let is_trivia = |token_type: TokenType| {
        matches!(
            token_type,
            TokenType::Space | TokenType::Break | TokenType::LineComment | TokenType::BlockComment
        )
    };
    let next_significant =
        |start: usize| (start..tokens.len()).find(|&i| !is_trivia(tokens[i].token_type));

    // MINUS is aliased to EXCEPT by the tokenizer, but `minus(...)` (e.g. in ClickHouse)
    // is a function call, so `SELECT a, minus(3, 2) FROM t` has no trailing comma.
    let is_minus_call = |idx: usize| {
        tokens[idx].text.eq_ignore_ascii_case("minus")
            && matches!(
                next_significant(idx + 1).map(|i| tokens[i].token_type),
                Some(TokenType::LParen)
            )
    };

    for (idx, token) in tokens.iter().enumerate() {
        if token.token_type != TokenType::Comma {
            continue;
        }

        let boundary_name = match next_significant(idx + 1) {
            None => "end of statement",
            Some(next_idx) => match tokens[next_idx].token_type {
                TokenType::From => "FROM",
                TokenType::Where => "WHERE",
                TokenType::GroupBy => "GROUP BY",
                TokenType::Having => "HAVING",
                TokenType::Order | TokenType::OrderBy => "ORDER BY",
                TokenType::Limit => "LIMIT",
                TokenType::Offset => "OFFSET",
                TokenType::Union => "UNION",
                TokenType::Intersect => "INTERSECT",
                TokenType::Except if is_minus_call(next_idx) => continue,
                TokenType::Except => "EXCEPT",
                TokenType::Qualify => "QUALIFY",
                TokenType::Window => "WINDOW",
                TokenType::Semicolon => "end of statement",
                _ => continue,
            },
        };

        let message =
            format!("Trailing comma before {boundary_name} is not allowed in strict syntax mode");
        return Some(
            ValidationError::error(message, "E005")
                .with_location(token.span.line, token.span.column),
        );
    }

    None
}
696
/// Transpile SQL from one dialect to another, using string dialect names.
///
/// Accepts built-in dialect names (e.g. "postgresql", "mysql") as well as
/// custom dialects registered through [`CustomDialectBuilder`]. Default
/// [`TranspileOptions`] are used.
///
/// # Arguments
/// * `sql` - The SQL string to transpile
/// * `read` - Name of the dialect to parse with
/// * `write` - Name of the dialect to generate
///
/// # Returns
/// A vector of transpiled SQL statements, or an error if a dialect name is unknown.
#[cfg(feature = "transpile")]
pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
    let default_options = TranspileOptions::default();
    transpile_with_by_name(sql, read, write, &default_options)
}
713
/// Transpile SQL with configurable [`TranspileOptions`], using string dialect names.
///
/// Same as [`transpile_by_name`] but accepts options (e.g., pretty-printing).
///
/// # Errors
/// Returns a parse error `Unknown dialect: <name>` if either name is not registered
/// (the `read` name is checked first), otherwise propagates transpilation errors.
#[cfg(feature = "transpile")]
pub fn transpile_with_by_name(
    sql: &str,
    read: &str,
    write: &str,
    opts: &TranspileOptions,
) -> Result<Vec<String>> {
    let lookup = |name: &str| {
        Dialect::get_by_name(name)
            .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", name), 0, 0, 0, 0))
    };
    let source = lookup(read)?;
    let target = lookup(write)?;
    source.transpile_with(sql, &target, opts.clone())
}
730
731/// Parse SQL into an AST using a string dialect name.
732///
733/// Supports both built-in and custom dialect names.
734pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
735    let d = Dialect::get_by_name(dialect)
736        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
737    d.parse(sql)
738}
739
/// Generate SQL from an AST using a string dialect name.
///
/// Built-in and custom dialect names are both accepted; an unknown name yields a
/// parse error `Unknown dialect: <name>`.
#[cfg(feature = "generate")]
pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
    let Some(d) = Dialect::get_by_name(dialect) else {
        return Err(Error::parse(
            format!("Unknown dialect: {}", dialect),
            0,
            0,
            0,
            0,
        ));
    };
    d.generate(expression)
}
749
/// Format SQL using a string dialect name.
///
/// Uses [`FormatGuardOptions::default`] guards.
#[cfg(feature = "generate")]
pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
    let default_guards = FormatGuardOptions::default();
    format_with_options_by_name(sql, dialect, &default_guards)
}
757
/// Format SQL using a string dialect name with configurable guard limits.
///
/// An unknown dialect name yields a parse error `Unknown dialect: <name>` before any
/// guard runs.
#[cfg(feature = "generate")]
pub fn format_with_options_by_name(
    sql: &str,
    dialect: &str,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    Dialect::get_by_name(dialect)
        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))
        .and_then(|d| format_with_dialect(sql, &d, options))
}
769
#[cfg(all(test, feature = "semantic"))]
mod validation_tests {
    use super::*;

    /// Validation options with strict trailing-comma checking switched on.
    fn strict() -> ValidationOptions {
        ValidationOptions {
            strict_syntax: true,
        }
    }

    /// Assert that `result` is invalid and reports at least one error with `code`.
    fn assert_rejected_with(result: &ValidationResult, code: &str) {
        assert!(!result.valid, "Result should be invalid");
        assert!(
            result.errors.iter().any(|e| e.code == code),
            "Expected {}, got: {:?}",
            code,
            result.errors
        );
    }

    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let result = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(result.valid, "Result: {:?}", result.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        let result = validate_with_options(
            "SELECT name, FROM employees",
            DialectType::Generic,
            &strict(),
        );
        assert_rejected_with(&result, "E005");
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        let result = validate_with_options(
            "SELECT name FROM employees, WHERE salary > 10",
            DialectType::Generic,
            &strict(),
        );
        assert_rejected_with(&result, "E005");
    }
}
816
817#[cfg(all(test, feature = "generate"))]
818mod format_tests {
819    use super::*;
820
821    #[test]
822    fn format_basic_query() {
823        let result = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
824        assert_eq!(result.len(), 1);
825        assert!(result[0].contains('\n'));
826    }
827
828    #[test]
829    fn format_guard_rejects_large_input() {
830        let options = FormatGuardOptions {
831            max_input_bytes: Some(7),
832            max_tokens: None,
833            max_ast_nodes: None,
834            max_set_op_chain: None,
835        };
836        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
837            .expect_err("expected guard error");
838        assert!(err.to_string().contains("E_GUARD_INPUT_TOO_LARGE"));
839    }
840
841    #[test]
842    fn format_guard_rejects_token_budget() {
843        let options = FormatGuardOptions {
844            max_input_bytes: None,
845            max_tokens: Some(1),
846            max_ast_nodes: None,
847            max_set_op_chain: None,
848        };
849        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
850            .expect_err("expected guard error");
851        assert!(err.to_string().contains("E_GUARD_TOKEN_BUDGET_EXCEEDED"));
852    }
853
854    #[test]
855    fn format_guard_rejects_ast_budget() {
856        let options = FormatGuardOptions {
857            max_input_bytes: None,
858            max_tokens: None,
859            max_ast_nodes: Some(1),
860            max_set_op_chain: None,
861        };
862        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
863            .expect_err("expected guard error");
864        assert!(err.to_string().contains("E_GUARD_AST_BUDGET_EXCEEDED"));
865    }
866
867    #[test]
868    fn format_guard_rejects_set_op_chain_budget() {
869        let options = FormatGuardOptions {
870            max_input_bytes: None,
871            max_tokens: None,
872            max_ast_nodes: None,
873            max_set_op_chain: Some(1),
874        };
875        let err = format_with_options(
876            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
877            DialectType::Generic,
878            &options,
879        )
880        .expect_err("expected guard error");
881        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
882    }
883
884    #[test]
885    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
886        let options = FormatGuardOptions {
887            max_input_bytes: None,
888            max_tokens: None,
889            max_ast_nodes: None,
890            max_set_op_chain: Some(0),
891        };
892        let result = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
893        assert!(result.is_ok(), "Result: {:?}", result);
894    }
895
896    #[test]
897    fn issue57_invalid_ternary_returns_error() {
898        // https://github.com/tobilg/polyglot/issues/57
899        // Invalid SQL with ternary operator should return an error, not garbled output.
900        let sql = "SELECT x > 0 ? 1 : 0 FROM t";
901
902        let parse_result = parse(sql, DialectType::PostgreSQL);
903        assert!(
904            parse_result.is_err(),
905            "Expected parse error for invalid ternary SQL, got: {:?}",
906            parse_result
907        );
908
909        let format_result = format(sql, DialectType::PostgreSQL);
910        assert!(
911            format_result.is_err(),
912            "Expected format error for invalid ternary SQL, got: {:?}",
913            format_result
914        );
915
916        let transpile_result = transpile(sql, DialectType::PostgreSQL, DialectType::PostgreSQL);
917        assert!(
918            transpile_result.is_err(),
919            "Expected transpile error for invalid ternary SQL, got: {:?}",
920            transpile_result
921        );
922    }
923
924    /// Regression guard: `lib::transpile()` must apply the full cross-dialect
925    /// rewrite pipeline (same as `Dialect::transpile()`). If these two paths
926    /// diverge again, Rust crate users silently get under-transformed SQL that
927    /// differs from what WASM/FFI/Python bindings produce.
928    #[test]
929    fn transpile_applies_cross_dialect_rewrites() {
930        // DuckDB to_timestamp → Trino FROM_UNIXTIME (different input semantics).
931        let out = transpile(
932            "SELECT to_timestamp(col) FROM t",
933            DialectType::DuckDB,
934            DialectType::Trino,
935        )
936        .expect("transpile failed");
937        assert_eq!(out[0], "SELECT FROM_UNIXTIME(col) FROM t");
938
939        // DuckDB CAST(x AS JSON) → Trino JSON_PARSE(x) (different CAST semantics).
940        let out = transpile(
941            "SELECT CAST(col AS JSON) FROM t",
942            DialectType::DuckDB,
943            DialectType::Trino,
944        )
945        .expect("transpile failed");
946        assert_eq!(out[0], "SELECT JSON_PARSE(col) FROM t");
947    }
948
949    /// Regression guard: all three transpile entry points (lib::transpile,
950    /// lib::transpile_by_name, Dialect::transpile) must produce identical
951    /// output. transpile_by_name is the one used by Python and C FFI bindings.
952    #[test]
953    fn transpile_matches_dialect_method() {
954        let cases: &[(DialectType, DialectType, &str, &str, &str)] = &[
955            (
956                DialectType::DuckDB,
957                DialectType::Trino,
958                "duckdb",
959                "trino",
960                "SELECT to_timestamp(col) FROM t",
961            ),
962            (
963                DialectType::DuckDB,
964                DialectType::Trino,
965                "duckdb",
966                "trino",
967                "SELECT CAST(col AS JSON) FROM t",
968            ),
969            (
970                DialectType::DuckDB,
971                DialectType::Trino,
972                "duckdb",
973                "trino",
974                "SELECT json_valid(col) FROM t",
975            ),
976            (
977                DialectType::Snowflake,
978                DialectType::DuckDB,
979                "snowflake",
980                "duckdb",
981                "SELECT DATEDIFF(day, a, b) FROM t",
982            ),
983            (
984                DialectType::BigQuery,
985                DialectType::DuckDB,
986                "bigquery",
987                "duckdb",
988                "SELECT DATE_DIFF(a, b, DAY) FROM t",
989            ),
990            (
991                DialectType::Generic,
992                DialectType::Generic,
993                "generic",
994                "generic",
995                "SELECT 1",
996            ),
997        ];
998        for (read, write, read_name, write_name, sql) in cases {
999            let via_lib = transpile(sql, *read, *write).expect("lib::transpile failed");
1000            let via_name = transpile_by_name(sql, read_name, write_name)
1001                .expect("lib::transpile_by_name failed");
1002            let via_dialect = Dialect::get(*read)
1003                .transpile(sql, *write)
1004                .expect("Dialect::transpile failed");
1005            assert_eq!(
1006                via_lib, via_dialect,
1007                "lib::transpile / Dialect::transpile diverged for {:?} -> {:?}: {sql}",
1008                read, write
1009            );
1010            assert_eq!(
1011                via_name, via_dialect,
1012                "lib::transpile_by_name / Dialect::transpile diverged for {read_name} -> {write_name}: {sql}"
1013            );
1014        }
1015    }
1016
1017    #[test]
1018    fn format_default_guard_rejects_deep_union_chain_before_parse() {
1019        let base = "SELECT col0, col1 FROM t";
1020        let mut sql = base.to_string();
1021        for _ in 0..1100 {
1022            sql.push_str(" UNION ALL ");
1023            sql.push_str(base);
1024        }
1025
1026        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
1027        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
1028    }
1029}