Skip to main content

polyglot_sql/
lib.rs

1//! Polyglot Core - SQL parsing and dialect translation library
2//!
3//! This library provides the core functionality for parsing SQL statements,
4//! building an abstract syntax tree (AST), and generating SQL in different dialects.
5//!
6//! # Architecture
7//!
8//! The library follows a pipeline architecture:
9//! 1. **Tokenizer** - Converts SQL string to token stream
10//! 2. **Parser** - Builds AST from tokens
11//! 3. **Generator** - Converts AST back to SQL string
12//!
13//! Each stage can be customized per dialect.
14
15#[cfg(any(feature = "ast-tools", feature = "generate", feature = "semantic"))]
16pub mod ast_transforms;
17#[cfg(feature = "builder")]
18pub mod builder;
19pub mod dialects;
20#[cfg(feature = "diff")]
21pub mod diff;
22pub mod error;
23pub mod expressions;
24#[cfg(feature = "semantic")]
25pub mod function_catalog;
26mod function_registry;
27#[cfg(feature = "generate")]
28pub mod generator;
29#[cfg(feature = "semantic")]
30pub mod helper;
31#[cfg(feature = "semantic")]
32pub mod lineage;
33#[cfg(feature = "openlineage")]
34pub mod openlineage;
35#[cfg(feature = "semantic")]
36pub mod optimizer;
37pub mod parser;
38#[cfg(feature = "planner")]
39pub mod planner;
40#[cfg(all(feature = "semantic", feature = "generate"))]
41pub mod query_analysis;
42#[cfg(feature = "semantic")]
43pub mod resolver;
44#[cfg(feature = "semantic")]
45pub mod schema;
46#[cfg(feature = "semantic")]
47pub mod scope;
48#[cfg(feature = "time")]
49pub mod time;
50pub mod tokens;
51#[cfg(feature = "transpile")]
52pub mod transforms;
53#[cfg(any(feature = "ast-tools", feature = "generate", feature = "semantic"))]
54pub mod traversal;
55#[cfg(any(feature = "semantic", feature = "time"))]
56pub mod trie;
57#[cfg(feature = "semantic")]
58pub mod validation;
59
60#[cfg(any(feature = "generate", feature = "semantic"))]
61use serde::{Deserialize, Serialize};
62
63#[cfg(feature = "ast-tools")]
64pub use ast_transforms::{
65    add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
66    get_identifiers, get_literals, get_output_column_names, get_subqueries, get_table_names,
67    get_window_functions, node_count, qualify_columns, remove_limit_offset, remove_nodes,
68    remove_select_columns, remove_where, rename_columns, rename_tables, rename_tables_with_options,
69    replace_by_type, replace_nodes, set_distinct, set_limit, set_offset, RenameTablesOptions,
70};
71pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
72#[cfg(feature = "transpile")]
73pub use dialects::{TranspileOptions, TranspileTarget};
74pub use error::{Error, Result};
75#[cfg(feature = "semantic")]
76pub use error::{ValidationError, ValidationResult, ValidationSeverity};
77pub use expressions::{DataType, Expression};
78#[cfg(feature = "semantic")]
79pub use function_catalog::{
80    FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
81};
82#[cfg(feature = "generate")]
83pub use generator::{Generator, UnsupportedLevel};
84#[cfg(feature = "semantic")]
85pub use helper::{
86    csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
87    name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
88};
89#[cfg(feature = "semantic")]
90pub use optimizer::{
91    annotate_types, qualify_tables, QualifyTablesOptions, TypeAnnotator, TypeCoercionClass,
92};
93pub use parser::Parser;
94#[cfg(all(feature = "semantic", feature = "generate"))]
95pub use query_analysis::{
96    analyze_query, AnalyzeQueryOptions, ColumnReferenceFact, ProjectionFact, QueryAnalysis,
97    QueryShape, ReferenceConfidence, RelationFact, SetOperationBranchFact, SetOperationFact,
98    TransformKind,
99};
100#[cfg(feature = "semantic")]
101pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
102#[cfg(feature = "semantic")]
103pub use schema::{
104    ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
105};
106#[cfg(feature = "semantic")]
107pub use scope::{
108    build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
109    ScopeType, SourceInfo,
110};
111#[cfg(feature = "time")]
112pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
113pub use tokens::{Token, TokenType, Tokenizer};
114#[cfg(feature = "ast-tools")]
115pub use traversal::{
116    contains_aggregate,
117    contains_subquery,
118    contains_window_function,
119    find_ancestor,
120    find_parent,
121    get_all_tables,
122    get_columns,
123    get_merge_source,
124    get_merge_target,
125    get_tables,
126    is_add,
127    is_aggregate,
128    is_alias,
129    is_alter_table,
130    is_and,
131    is_arithmetic,
132    is_avg,
133    is_between,
134    is_boolean,
135    is_case,
136    is_cast,
137    is_coalesce,
138    is_column,
139    is_comparison,
140    is_concat,
141    is_count,
142    is_create_index,
143    is_create_table,
144    is_create_view,
145    is_cte,
146    is_ddl,
147    is_delete,
148    is_div,
149    is_drop_index,
150    is_drop_table,
151    is_drop_view,
152    is_eq,
153    is_except,
154    is_exists,
155    is_from,
156    is_function,
157    is_group_by,
158    is_gt,
159    is_gte,
160    is_having,
161    is_identifier,
162    is_ilike,
163    is_in,
164    // Extended type predicates
165    is_insert,
166    is_intersect,
167    is_is_null,
168    is_join,
169    is_like,
170    is_limit,
171    is_literal,
172    is_logical,
173    is_lt,
174    is_lte,
175    is_max_func,
176    is_merge,
177    is_min_func,
178    is_mod,
179    is_mul,
180    is_neq,
181    is_not,
182    is_null_if,
183    is_null_literal,
184    is_offset,
185    is_or,
186    is_order_by,
187    is_ordered,
188    is_paren,
189    // Composite predicates
190    is_query,
191    is_safe_cast,
192    is_select,
193    is_set_operation,
194    is_star,
195    is_sub,
196    is_subquery,
197    is_sum,
198    is_table,
199    is_try_cast,
200    is_union,
201    is_update,
202    is_where,
203    is_window_function,
204    is_with,
205    transform,
206    transform_map,
207    BfsIter,
208    DfsIter,
209    ExpressionWalk,
210    ParentInfo,
211    TreeContext,
212};
213#[cfg(any(feature = "semantic", feature = "time"))]
214pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
215#[cfg(feature = "semantic")]
216pub use validation::{
217    mapping_schema_from_validation_schema, validate_with_schema, SchemaColumn,
218    SchemaColumnReference, SchemaForeignKey, SchemaTable, SchemaTableReference,
219    SchemaValidationOptions, ValidationSchema,
220};
221
// Built-in formatter limits. They are used by `FormatGuardOptions::default()`
// and by serde when a field is omitted while deserializing `FormatGuardOptions`.

/// Largest SQL input, in bytes, accepted by default (16 MiB).
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 << 20;
/// Largest token stream accepted by default.
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
/// Largest total AST node count accepted by default.
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
/// Most set-operation keywords allowed in one statement by default.
#[cfg(feature = "generate")]
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;

// `#[serde(default = "...")]` needs a function path rather than a constant,
// so every limit also gets a tiny accessor that wraps it in `Some`.

#[cfg(feature = "generate")]
fn default_format_max_input_bytes() -> Option<usize> {
    DEFAULT_FORMAT_MAX_INPUT_BYTES.into()
}

#[cfg(feature = "generate")]
fn default_format_max_tokens() -> Option<usize> {
    DEFAULT_FORMAT_MAX_TOKENS.into()
}

#[cfg(feature = "generate")]
fn default_format_max_ast_nodes() -> Option<usize> {
    DEFAULT_FORMAT_MAX_AST_NODES.into()
}

#[cfg(feature = "generate")]
fn default_format_max_set_op_chain() -> Option<usize> {
    DEFAULT_FORMAT_MAX_SET_OP_CHAIN.into()
}
250
/// Resource limits enforced by the SQL pretty-formatter.
///
/// Every limit is optional; `None` switches that particular check off. The
/// checks reject inputs that are large or complex enough to cause heavy memory
/// use in constrained runtimes (browser WASM, for example).
///
/// The struct (de)serializes with camelCase field names. A field that is
/// missing from the serialized form falls back to its built-in default limit.
#[cfg(feature = "generate")]
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct FormatGuardOptions {
    /// Upper bound on the SQL input length in bytes, checked before tokenizing.
    /// `None` disables this check.
    #[serde(default = "default_format_max_input_bytes")]
    pub max_input_bytes: Option<usize>,
    /// Upper bound on the number of tokens produced by tokenization.
    /// `None` disables this check.
    #[serde(default = "default_format_max_tokens")]
    pub max_tokens: Option<usize>,
    /// Upper bound on the total AST node count of all parsed statements.
    /// `None` disables this check.
    #[serde(default = "default_format_max_ast_nodes")]
    pub max_ast_nodes: Option<usize>,
    /// Upper bound on the number of set-operation keywords (`UNION` /
    /// `INTERSECT` / `EXCEPT`) in a single statement. The count is taken from
    /// the token stream before parsing.
    ///
    /// `None` disables this check.
    #[serde(default = "default_format_max_set_op_chain")]
    pub max_set_op_chain: Option<usize>,
}
278
#[cfg(feature = "generate")]
impl Default for FormatGuardOptions {
    /// Returns the built-in limits. These are the same per-field defaults that
    /// serde applies to omitted fields.
    fn default() -> Self {
        let max_input_bytes = default_format_max_input_bytes();
        let max_tokens = default_format_max_tokens();
        let max_ast_nodes = default_format_max_ast_nodes();
        let max_set_op_chain = default_format_max_set_op_chain();
        Self {
            max_input_bytes,
            max_tokens,
            max_ast_nodes,
            max_set_op_chain,
        }
    }
}
290
/// Builds the generator error reported when a formatting guard trips.
///
/// The message starts with the machine-readable `code` (for example
/// `E_GUARD_INPUT_TOO_LARGE`), followed by the observed value and the
/// configured limit.
#[cfg(feature = "generate")]
fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
    let message = format!("{code}: value {actual} exceeds configured limit {limit}");
    Error::generate(message)
}
297
/// Rejects `sql` when its length in bytes exceeds `options.max_input_bytes`.
///
/// # Errors
/// Returns an `E_GUARD_INPUT_TOO_LARGE` guard error when the limit is set and
/// exceeded.
#[cfg(feature = "generate")]
fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
    match options.max_input_bytes {
        Some(limit) if sql.len() > limit => Err(format_guard_error(
            "E_GUARD_INPUT_TOO_LARGE",
            sql.len(),
            limit,
        )),
        _ => Ok(()),
    }
}
312
/// Tokenizes and parses `sql` with `dialect`, enforcing the token-count and
/// set-operation guards on the token stream before parsing begins.
///
/// # Errors
/// - Tokenizer errors.
/// - `E_GUARD_TOKEN_BUDGET_EXCEEDED` when the token limit is exceeded.
/// - `E_GUARD_SET_OP_CHAIN_EXCEEDED` from the set-operation guard.
/// - Parser errors.
#[cfg(feature = "generate")]
fn parse_with_token_guard(
    sql: &str,
    dialect: &Dialect,
    options: &FormatGuardOptions,
) -> Result<Vec<Expression>> {
    let tokens = dialect.tokenize(sql)?;

    if let Some(limit) = options.max_tokens.filter(|&limit| tokens.len() > limit) {
        return Err(format_guard_error(
            "E_GUARD_TOKEN_BUDGET_EXCEEDED",
            tokens.len(),
            limit,
        ));
    }
    enforce_set_op_chain_guard(&tokens, options)?;

    // Only the token checks above can reject cheaply; from here on we pay for
    // a full parse.
    let parser_config = crate::parser::ParserConfig {
        dialect: Some(dialect.dialect_type()),
        ..Default::default()
    };
    Parser::with_source(tokens, parser_config, sql.to_owned()).parse()
}
339
/// Reports whether a token type is trivia: whitespace, a line break, or a
/// comment. Such tokens carry no syntax.
#[cfg(feature = "generate")]
fn is_trivia_token(token_type: TokenType) -> bool {
    match token_type {
        TokenType::Space | TokenType::Break => true,
        TokenType::LineComment | TokenType::BlockComment => true,
        _ => false,
    }
}
347
/// Returns the first non-trivia token at or after index `start`. Returns
/// `None` when there is none, including when `start` is past the end.
#[cfg(feature = "generate")]
fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
    tokens
        .get(start..)?
        .iter()
        .find(|candidate| !is_trivia_token(candidate.token_type))
}
355
/// Reports whether `tokens[idx]` is a set-operation keyword (`UNION`,
/// `INTERSECT` or `EXCEPT`/`MINUS`).
///
/// # Panics
/// Panics if `idx` is out of bounds.
///
/// NOTE(review): BigQuery's `SELECT * EXCEPT (col)` column modifier may also
/// tokenize as `Except` and would then be counted here. Confirm against the
/// tokenizer.
#[cfg(feature = "generate")]
fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
    let token = &tokens[idx];
    match token.token_type {
        TokenType::Union | TokenType::Intersect => true,
        TokenType::Except => {
            // The tokenizer maps MINUS onto EXCEPT. In ClickHouse, however,
            // `minus(...)` is an ordinary function call, not a set operation.
            let is_minus_function_call = token.text.eq_ignore_ascii_case("minus")
                && matches!(
                    next_significant_token(tokens, idx + 1).map(|next| next.token_type),
                    Some(TokenType::LParen)
                );
            !is_minus_function_call
        }
        _ => false,
    }
}
377
/// Rejects token streams in which any single statement contains more than
/// `options.max_set_op_chain` set-operation keywords.
///
/// The per-statement counter restarts at every `;` token. Scanning stops at
/// the first violation, so the reported value is the count at that moment
/// (limit + 1), not the statement's full total.
///
/// # Errors
/// Returns an `E_GUARD_SET_OP_CHAIN_EXCEEDED` guard error on violation.
#[cfg(feature = "generate")]
fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
    let Some(limit) = options.max_set_op_chain else {
        return Ok(());
    };

    tokens
        .iter()
        .enumerate()
        .try_fold(0usize, |count, (idx, token)| {
            if token.token_type == TokenType::Semicolon {
                return Ok(0);
            }
            if !is_set_operation_token(tokens, idx) {
                return Ok(count);
            }
            let count = count + 1;
            if count > limit {
                Err(format_guard_error(
                    "E_GUARD_SET_OP_CHAIN_EXCEEDED",
                    count,
                    limit,
                ))
            } else {
                Ok(count)
            }
        })
        .map(|_| ())
}
405
/// Rejects parse results whose combined node count, summed over all
/// statements, exceeds `options.max_ast_nodes`.
///
/// # Errors
/// Returns an `E_GUARD_AST_BUDGET_EXCEEDED` guard error on violation.
#[cfg(feature = "generate")]
fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
    let Some(limit) = options.max_ast_nodes else {
        return Ok(());
    };

    let mut total_nodes = 0usize;
    for expression in expressions {
        total_nodes += crate::ast_transforms::node_count(expression);
    }

    if total_nodes > limit {
        return Err(format_guard_error(
            "E_GUARD_AST_BUDGET_EXCEEDED",
            total_nodes,
            limit,
        ));
    }
    Ok(())
}
423
/// Shared implementation behind every `format*` entry point.
///
/// Runs the guards in cost order: input size, then token count and set-op
/// chain, then AST size. It then pretty-prints each statement with `dialect`
/// and stops at the first generation error.
#[cfg(feature = "generate")]
fn format_with_dialect(
    sql: &str,
    dialect: &Dialect,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    enforce_input_guard(sql, options)?;
    let statements = parse_with_token_guard(sql, dialect, options)?;
    enforce_ast_guard(&statements, options)?;

    let mut formatted = Vec::with_capacity(statements.len());
    for statement in &statements {
        formatted.push(dialect.generate_pretty(statement)?);
    }
    Ok(formatted)
}
439
/// Transpile SQL from one dialect to another.
///
/// This is a thin wrapper over [`Dialect::transpile`] on the `read` dialect. It
/// therefore runs the full cross-dialect rewrite pipeline (source- and
/// target-aware normalization in `cross_dialect_normalize`). That keeps Rust
/// crate users on the same code path as the WASM/FFI/Python bindings and the
/// playground.
///
/// # Arguments
/// * `sql` - The SQL string to transpile
/// * `read` - The source dialect to parse with
/// * `write` - The target dialect to generate
///
/// # Returns
/// A vector of transpiled SQL statements
///
/// # Example
/// ```
/// use polyglot_sql::{transpile, DialectType};
///
/// let result = transpile(
///     "SELECT EPOCH_MS(1618088028295)",
///     DialectType::DuckDB,
///     DialectType::Hive
/// );
/// ```
#[cfg(feature = "transpile")]
pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
    let source = Dialect::get(read);
    source.transpile(sql, write)
}
468
469/// Parse SQL into an AST.
470///
471/// # Arguments
472/// * `sql` - The SQL string to parse
473/// * `dialect` - The dialect to use for parsing
474///
475/// # Returns
476/// A vector of parsed expressions
477pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
478    let d = Dialect::get(dialect);
479    d.parse(sql)
480}
481
482/// Parse a single SQL statement.
483///
484/// # Arguments
485/// * `sql` - The SQL string containing a single statement
486/// * `dialect` - The dialect to use for parsing
487///
488/// # Returns
489/// The parsed expression, or an error if multiple statements found
490pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
491    let mut expressions = parse(sql, dialect)?;
492
493    if expressions.len() != 1 {
494        return Err(Error::parse(
495            format!("Expected 1 statement, found {}", expressions.len()),
496            0,
497            0,
498            0,
499            0,
500        ));
501    }
502
503    Ok(expressions.remove(0))
504}
505
506/// Parse a standalone SQL data type.
507///
508/// # Arguments
509/// * `sql` - The data type string to parse, e.g. `DECIMAL(10, 2)`
510/// * `dialect` - The dialect to use for parsing
511///
512/// # Returns
513/// The parsed data type
514pub fn parse_data_type(sql: &str, dialect: DialectType) -> Result<DataType> {
515    Dialect::get(dialect).parse_data_type(sql)
516}
517
/// Generate SQL from a standalone data type.
///
/// # Arguments
/// * `data_type` - The data type to render
/// * `dialect` - The target dialect
///
/// # Returns
/// The generated type SQL string
#[cfg(feature = "generate")]
pub fn generate_data_type(data_type: &DataType, dialect: DialectType) -> Result<String> {
    let d = Dialect::get(dialect);
    // The generator operates on expressions, so wrap the type in one first.
    let wrapped = Expression::DataType(data_type.clone());
    d.generate(&wrapped)
}
530
/// Generate SQL from an AST.
///
/// # Arguments
/// * `expression` - The expression to render
/// * `dialect` - The target dialect
///
/// # Returns
/// The generated SQL string
#[cfg(feature = "generate")]
pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
    Dialect::get(dialect).generate(expression)
}
544
/// Format/pretty-print SQL statements.
///
/// Applies the [`FormatGuardOptions::default`] limits; use
/// [`format_with_options`] to tune or disable them.
#[cfg(feature = "generate")]
pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
    let guards = FormatGuardOptions::default();
    format_with_options(sql, dialect, &guards)
}
552
/// Format/pretty-print SQL statements with configurable guard limits.
///
/// # Errors
/// Fails with an `E_GUARD_*` error when a limit in `options` is exceeded.
/// Also fails on any tokenize, parse, or generate error.
#[cfg(feature = "generate")]
pub fn format_with_options(
    sql: &str,
    dialect: DialectType,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    format_with_dialect(sql, &Dialect::get(dialect), options)
}
563
/// Validate SQL syntax.
///
/// Equivalent to [`validate_with_options`] with `ValidationOptions::default()`,
/// i.e. non-strict mode.
///
/// # Arguments
/// * `sql` - The SQL string to validate
/// * `dialect` - The dialect to use for validation
///
/// # Returns
/// A validation result with any errors found
#[cfg(feature = "semantic")]
pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
    let defaults = ValidationOptions::default();
    validate_with_options(sql, dialect, &defaults)
}
576
/// Options that tune [`validate_with_options`].
///
/// The default value (`strict_syntax: false`) is what [`validate`] uses.
#[cfg(feature = "semantic")]
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ValidationOptions {
    /// Reject trailing commas that the parser tolerates for compatibility,
    /// e.g. `SELECT a, FROM t`. Such commas are reported as `E005`.
    /// The value is `false` when the field is absent from serialized input.
    #[serde(default)]
    pub strict_syntax: bool,
}
587
/// Validate SQL syntax with additional validation options.
///
/// At most one error is reported. Codes:
/// - `E001` / `E002` / `E003`: syntax, tokenize, or parse errors, with location
///   and span.
/// - `E000`: any other error raised while parsing.
/// - `E004`: the input parsed, but a top-level item is a bare expression
///   (identifier, literal, function call, ...) rather than a statement.
/// - `E005`: a trailing comma was found. Reported only when
///   [`ValidationOptions::strict_syntax`] is enabled.
#[cfg(feature = "semantic")]
pub fn validate_with_options(
    sql: &str,
    dialect: DialectType,
    options: &ValidationOptions,
) -> ValidationResult {
    let d = Dialect::get(dialect);
    let expressions = match d.parse(sql) {
        Ok(expressions) => expressions,
        Err(e) => {
            return ValidationResult::with_errors(vec![validation_error_from_parse_failure(&e)])
        }
    };

    // The parser accepts any expression at the top level, but bare identifiers,
    // literals, function calls, etc. are not valid statements.
    if expressions.iter().any(|expr| !expr.is_statement()) {
        return ValidationResult::with_errors(vec![ValidationError::error(
            "Invalid expression / Unexpected token".to_string(),
            "E004",
        )]);
    }

    if options.strict_syntax {
        if let Some(error) = strict_syntax_error(sql, &d) {
            return ValidationResult::with_errors(vec![error]);
        }
    }

    ValidationResult::success()
}

/// Converts an error returned by `Dialect::parse` into a [`ValidationError`].
///
/// Syntax, tokenize, and parse errors map to `E001` / `E002` / `E003` and keep
/// their line/column location and byte span. Anything else becomes `E000`,
/// using the error's display text.
#[cfg(feature = "semantic")]
fn validation_error_from_parse_failure(e: &Error) -> ValidationError {
    match e {
        Error::Syntax {
            message,
            line,
            column,
            start,
            end,
        } => ValidationError::error(message.clone(), "E001")
            .with_location(*line, *column)
            .with_span(Some(*start), Some(*end)),
        Error::Tokenize {
            message,
            line,
            column,
            start,
            end,
        } => ValidationError::error(message.clone(), "E002")
            .with_location(*line, *column)
            .with_span(Some(*start), Some(*end)),
        Error::Parse {
            message,
            line,
            column,
            start,
            end,
        } => ValidationError::error(message.clone(), "E003")
            .with_location(*line, *column)
            .with_span(Some(*start), Some(*end)),
        _ => ValidationError::error(e.to_string(), "E000"),
    }
}
651
/// Looks for a trailing comma directly before a clause boundary: `FROM`,
/// `WHERE`, `GROUP BY`, `HAVING`, `ORDER BY`, `LIMIT`, `OFFSET`, a set
/// operation, `QUALIFY`, `WINDOW`, `;`, or the end of input.
///
/// Whitespace and comment tokens between the comma and the boundary are
/// skipped, so `SELECT a, /* note */ FROM t` is caught as well. Returns the
/// first offending comma as an `E005` error located at that comma.
///
/// This runs only after a successful parse in [`validate_with_options`]. A
/// tokenize failure here is therefore unexpected and is treated as "no finding".
#[cfg(feature = "semantic")]
fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
    let tokens = dialect.tokenize(sql).ok()?;

    // Same trivia set the formatter's set-operation guard skips; defined locally
    // because that helper is only compiled with the `generate` feature.
    let is_trivia = |token_type: TokenType| {
        matches!(
            token_type,
            TokenType::Space | TokenType::Break | TokenType::LineComment | TokenType::BlockComment
        )
    };

    for (idx, token) in tokens.iter().enumerate() {
        if token.token_type != TokenType::Comma {
            continue;
        }

        let next = tokens[idx + 1..]
            .iter()
            .map(|t| t.token_type)
            .find(|&token_type| !is_trivia(token_type));
        let boundary_name = match next {
            Some(TokenType::From) => "FROM",
            Some(TokenType::Where) => "WHERE",
            Some(TokenType::GroupBy) => "GROUP BY",
            Some(TokenType::Having) => "HAVING",
            Some(TokenType::Order | TokenType::OrderBy) => "ORDER BY",
            Some(TokenType::Limit) => "LIMIT",
            Some(TokenType::Offset) => "OFFSET",
            Some(TokenType::Union) => "UNION",
            Some(TokenType::Intersect) => "INTERSECT",
            Some(TokenType::Except) => "EXCEPT",
            Some(TokenType::Qualify) => "QUALIFY",
            Some(TokenType::Window) => "WINDOW",
            Some(TokenType::Semicolon) | None => "end of statement",
            // Anything else means the comma separates real list items.
            _ => continue,
        };

        let message = format!(
            "Trailing comma before {} is not allowed in strict syntax mode",
            boundary_name
        );
        return Some(
            ValidationError::error(message, "E005")
                .with_location(token.span.line, token.span.column),
        );
    }

    None
}
693
/// Transpile SQL from one dialect to another, using string dialect names.
///
/// Accepts both built-in dialect names (e.g., "postgresql", "mysql") and
/// custom dialects registered via [`CustomDialectBuilder`]. Uses
/// `TranspileOptions::default()`; see [`transpile_with_by_name`] to customize.
///
/// # Arguments
/// * `sql` - The SQL string to transpile
/// * `read` - The source dialect name
/// * `write` - The target dialect name
///
/// # Returns
/// A vector of transpiled SQL statements, or an error if a dialect name is unknown.
#[cfg(feature = "transpile")]
pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
    let defaults = TranspileOptions::default();
    transpile_with_by_name(sql, read, write, &defaults)
}
710
/// Transpile SQL with configurable [`TranspileOptions`], using string dialect names.
///
/// Same as [`transpile_by_name`] but accepts options (e.g., pretty-printing).
///
/// # Errors
/// Returns an `Unknown dialect: <name>` parse error when `read` or `write` is
/// not a known dialect; `read` is checked first. Otherwise propagates
/// transpilation errors.
#[cfg(feature = "transpile")]
pub fn transpile_with_by_name(
    sql: &str,
    read: &str,
    write: &str,
    opts: &TranspileOptions,
) -> Result<Vec<String>> {
    let unknown = |name: &str| Error::parse(format!("Unknown dialect: {name}"), 0, 0, 0, 0);

    let Some(source) = Dialect::get_by_name(read) else {
        return Err(unknown(read));
    };
    let Some(target) = Dialect::get_by_name(write) else {
        return Err(unknown(write));
    };
    source.transpile_with(sql, &target, opts.clone())
}
727
728/// Parse SQL into an AST using a string dialect name.
729///
730/// Supports both built-in and custom dialect names.
731pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
732    let d = Dialect::get_by_name(dialect)
733        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
734    d.parse(sql)
735}
736
/// Generate SQL from an AST using a string dialect name.
///
/// Supports both built-in and custom dialect names. An unrecognized name
/// yields an `Unknown dialect: <name>` parse error.
#[cfg(feature = "generate")]
pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
    let Some(d) = Dialect::get_by_name(dialect) else {
        return Err(Error::parse(
            format!("Unknown dialect: {dialect}"),
            0,
            0,
            0,
            0,
        ));
    };
    d.generate(expression)
}
746
/// Format SQL using a string dialect name.
///
/// Applies the [`FormatGuardOptions::default`] limits; see
/// [`format_with_options_by_name`] to tune them.
#[cfg(feature = "generate")]
pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
    let guards = FormatGuardOptions::default();
    format_with_options_by_name(sql, dialect, &guards)
}
754
/// Format SQL using a string dialect name with configurable guard limits.
///
/// # Errors
/// - An `Unknown dialect: <name>` parse error for an unrecognized `dialect`.
/// - Any guard violation.
/// - Any tokenize, parse, or generate error.
#[cfg(feature = "generate")]
pub fn format_with_options_by_name(
    sql: &str,
    dialect: &str,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    match Dialect::get_by_name(dialect) {
        Some(d) => format_with_dialect(sql, &d, options),
        None => Err(Error::parse(
            format!("Unknown dialect: {dialect}"),
            0,
            0,
            0,
            0,
        )),
    }
}
766
#[cfg(all(test, feature = "semantic"))]
mod validation_tests {
    use super::*;

    /// Validation options with strict trailing-comma checking switched on.
    fn strict_options() -> ValidationOptions {
        ValidationOptions {
            strict_syntax: true,
        }
    }

    /// Asserts that strict validation of `sql` (Generic dialect) fails with E005.
    fn assert_strict_rejects(sql: &str) {
        let result = validate_with_options(sql, DialectType::Generic, &strict_options());
        assert!(!result.valid, "Result should be invalid");
        assert!(
            result.errors.iter().any(|e| e.code == "E005"),
            "Expected E005, got: {:?}",
            result.errors
        );
    }

    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let result = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(result.valid, "Result: {:?}", result.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        assert_strict_rejects("SELECT name, FROM employees");
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        assert_strict_rejects("SELECT name FROM employees, WHERE salary > 10");
    }
}
813
// Requires only the `generate` feature. Tests that call `transpile`,
// `transpile_by_name` or `Dialect::transpile` are additionally gated on the
// `transpile` feature, so that a generate-only test build still compiles.
#[cfg(all(test, feature = "generate"))]
mod format_tests {
    use super::*;

    /// Guard options with every limit disabled. Each test overrides only the
    /// limit it exercises, via struct-update syntax.
    fn unguarded() -> FormatGuardOptions {
        FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: None,
        }
    }

    #[test]
    fn format_basic_query() {
        let result = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
        assert_eq!(result.len(), 1);
        assert!(result[0].contains('\n'));
    }

    #[test]
    fn format_guard_rejects_large_input() {
        let options = FormatGuardOptions {
            max_input_bytes: Some(7),
            ..unguarded()
        };
        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_INPUT_TOO_LARGE"));
    }

    #[test]
    fn format_guard_rejects_token_budget() {
        let options = FormatGuardOptions {
            max_tokens: Some(1),
            ..unguarded()
        };
        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_TOKEN_BUDGET_EXCEEDED"));
    }

    #[test]
    fn format_guard_rejects_ast_budget() {
        let options = FormatGuardOptions {
            max_ast_nodes: Some(1),
            ..unguarded()
        };
        let err = format_with_options("SELECT 1", DialectType::Generic, &options)
            .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_AST_BUDGET_EXCEEDED"));
    }

    #[test]
    fn format_guard_rejects_set_op_chain_budget() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(1),
            ..unguarded()
        };
        let err = format_with_options(
            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
            DialectType::Generic,
            &options,
        )
        .expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }

    #[test]
    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(0),
            ..unguarded()
        };
        let result = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
        assert!(result.is_ok(), "Result: {:?}", result);
    }

    #[test]
    fn issue57_invalid_ternary_returns_error() {
        // https://github.com/tobilg/polyglot/issues/57
        // Invalid SQL with ternary operator should return an error, not garbled output.
        let sql = "SELECT x > 0 ? 1 : 0 FROM t";

        let parse_result = parse(sql, DialectType::PostgreSQL);
        assert!(
            parse_result.is_err(),
            "Expected parse error for invalid ternary SQL, got: {:?}",
            parse_result
        );

        let format_result = format(sql, DialectType::PostgreSQL);
        assert!(
            format_result.is_err(),
            "Expected format error for invalid ternary SQL, got: {:?}",
            format_result
        );

        // `transpile` only exists with the `transpile` feature.
        #[cfg(feature = "transpile")]
        {
            let transpile_result =
                transpile(sql, DialectType::PostgreSQL, DialectType::PostgreSQL);
            assert!(
                transpile_result.is_err(),
                "Expected transpile error for invalid ternary SQL, got: {:?}",
                transpile_result
            );
        }
    }

    /// Regression guard: `lib::transpile()` must apply the full cross-dialect
    /// rewrite pipeline (same as `Dialect::transpile()`). If these two paths
    /// diverge again, Rust crate users silently get under-transformed SQL that
    /// differs from what WASM/FFI/Python bindings produce.
    #[cfg(feature = "transpile")]
    #[test]
    fn transpile_applies_cross_dialect_rewrites() {
        // DuckDB to_timestamp → Trino FROM_UNIXTIME (different input semantics).
        let out = transpile(
            "SELECT to_timestamp(col) FROM t",
            DialectType::DuckDB,
            DialectType::Trino,
        )
        .expect("transpile failed");
        assert_eq!(out[0], "SELECT FROM_UNIXTIME(col) FROM t");

        // DuckDB CAST(x AS JSON) → Trino JSON_PARSE(x) (different CAST semantics).
        let out = transpile(
            "SELECT CAST(col AS JSON) FROM t",
            DialectType::DuckDB,
            DialectType::Trino,
        )
        .expect("transpile failed");
        assert_eq!(out[0], "SELECT JSON_PARSE(col) FROM t");
    }

    /// Regression guard: all three transpile entry points (lib::transpile,
    /// lib::transpile_by_name, Dialect::transpile) must produce identical
    /// output. transpile_by_name is the one used by Python and C FFI bindings.
    #[cfg(feature = "transpile")]
    #[test]
    fn transpile_matches_dialect_method() {
        let cases: &[(DialectType, DialectType, &str, &str, &str)] = &[
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT to_timestamp(col) FROM t",
            ),
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT CAST(col AS JSON) FROM t",
            ),
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT json_valid(col) FROM t",
            ),
            (
                DialectType::Snowflake,
                DialectType::DuckDB,
                "snowflake",
                "duckdb",
                "SELECT DATEDIFF(day, a, b) FROM t",
            ),
            (
                DialectType::BigQuery,
                DialectType::DuckDB,
                "bigquery",
                "duckdb",
                "SELECT DATE_DIFF(a, b, DAY) FROM t",
            ),
            (
                DialectType::Generic,
                DialectType::Generic,
                "generic",
                "generic",
                "SELECT 1",
            ),
        ];
        for (read, write, read_name, write_name, sql) in cases {
            let via_lib = transpile(sql, *read, *write).expect("lib::transpile failed");
            let via_name = transpile_by_name(sql, read_name, write_name)
                .expect("lib::transpile_by_name failed");
            let via_dialect = Dialect::get(*read)
                .transpile(sql, *write)
                .expect("Dialect::transpile failed");
            assert_eq!(
                via_lib, via_dialect,
                "lib::transpile / Dialect::transpile diverged for {:?} -> {:?}: {sql}",
                read, write
            );
            assert_eq!(
                via_name, via_dialect,
                "lib::transpile_by_name / Dialect::transpile diverged for {read_name} -> {write_name}: {sql}"
            );
        }
    }

    #[test]
    fn format_default_guard_rejects_deep_union_chain_before_parse() {
        let base = "SELECT col0, col1 FROM t";
        let mut sql = base.to_string();
        for _ in 0..1100 {
            sql.push_str(" UNION ALL ");
            sql.push_str(base);
        }

        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }
}