// polyglot_sql crate root (lib.rs)
//! Polyglot Core - SQL parsing and dialect translation library
//!
//! This library provides the core functionality for parsing SQL statements,
//! building an abstract syntax tree (AST), and generating SQL in different dialects.
//!
//! # Architecture
//!
//! The library follows a pipeline architecture:
//! 1. **Tokenizer** - Converts SQL string to token stream
//! 2. **Parser** - Builds AST from tokens
//! 3. **Generator** - Converts AST back to SQL string
//!
//! Each stage can be customized per dialect.

15pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21mod function_registry;
22pub mod function_catalog;
23pub mod generator;
24pub mod helper;
25pub mod lineage;
26pub mod optimizer;
27pub mod parser;
28pub mod planner;
29pub mod resolver;
30pub mod schema;
31pub mod scope;
32pub mod time;
33pub mod tokens;
34pub mod transforms;
35pub mod traversal;
36pub mod trie;
37pub mod validation;
38
39use serde::{Deserialize, Serialize};
40
41pub use ast_transforms::{
42    add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
43    get_identifiers, get_literals, get_subqueries, get_table_names, get_window_functions,
44    node_count, qualify_columns, remove_limit_offset, remove_nodes, remove_select_columns,
45    remove_where, rename_columns, rename_tables, replace_by_type, replace_nodes, set_distinct,
46    set_limit, set_offset,
47};
48pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
49pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
50pub use expressions::Expression;
51pub use function_catalog::{
52    FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
53};
54pub use generator::Generator;
55pub use helper::{
56    csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
57    name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
58};
59pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
60pub use parser::Parser;
61pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
62pub use schema::{
63    ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
64};
65pub use scope::{
66    build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
67    ScopeType, SourceInfo,
68};
69pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
70pub use tokens::{Token, TokenType, Tokenizer};
71pub use traversal::{
72    contains_aggregate,
73    contains_subquery,
74    contains_window_function,
75    find_ancestor,
76    find_parent,
77    get_columns,
78    get_tables,
79    is_add,
80    is_aggregate,
81    is_alias,
82    is_alter_table,
83    is_and,
84    is_arithmetic,
85    is_avg,
86    is_between,
87    is_boolean,
88    is_case,
89    is_cast,
90    is_coalesce,
91    is_column,
92    is_comparison,
93    is_concat,
94    is_count,
95    is_create_index,
96    is_create_table,
97    is_create_view,
98    is_cte,
99    is_ddl,
100    is_delete,
101    is_div,
102    is_drop_index,
103    is_drop_table,
104    is_drop_view,
105    is_eq,
106    is_except,
107    is_exists,
108    is_from,
109    is_function,
110    is_group_by,
111    is_gt,
112    is_gte,
113    is_having,
114    is_identifier,
115    is_ilike,
116    is_in,
117    // Extended type predicates
118    is_insert,
119    is_intersect,
120    is_is_null,
121    is_join,
122    is_like,
123    is_limit,
124    is_literal,
125    is_logical,
126    is_lt,
127    is_lte,
128    is_max_func,
129    is_min_func,
130    is_mod,
131    is_mul,
132    is_neq,
133    is_not,
134    is_null_if,
135    is_null_literal,
136    is_offset,
137    is_or,
138    is_order_by,
139    is_ordered,
140    is_paren,
141    // Composite predicates
142    is_query,
143    is_safe_cast,
144    is_select,
145    is_set_operation,
146    is_star,
147    is_sub,
148    is_subquery,
149    is_sum,
150    is_table,
151    is_try_cast,
152    is_union,
153    is_update,
154    is_where,
155    is_window_function,
156    is_with,
157    transform,
158    transform_map,
159    BfsIter,
160    DfsIter,
161    ExpressionWalk,
162    ParentInfo,
163    TreeContext,
164};
165pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
166pub use validation::{
167    validate_with_schema, SchemaColumn, SchemaColumnReference, SchemaForeignKey, SchemaTable,
168    SchemaTableReference, SchemaValidationOptions, ValidationSchema,
169};
170
171/// Transpile SQL from one dialect to another.
172///
173/// # Arguments
174/// * `sql` - The SQL string to transpile
175/// * `read` - The source dialect to parse with
176/// * `write` - The target dialect to generate
177///
178/// # Returns
179/// A vector of transpiled SQL statements
180///
181/// # Example
182/// ```
183/// use polyglot_sql::{transpile, DialectType};
184///
185/// let result = transpile(
186///     "SELECT EPOCH_MS(1618088028295)",
187///     DialectType::DuckDB,
188///     DialectType::Hive
189/// );
190/// ```
191pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
192    let read_dialect = Dialect::get(read);
193    let write_dialect = Dialect::get(write);
194    let generic_identity = read == DialectType::Generic && write == DialectType::Generic;
195
196    let expressions = read_dialect.parse(sql)?;
197
198    expressions
199        .into_iter()
200        .map(|expr| {
201            if generic_identity {
202                write_dialect.generate_with_source(&expr, read)
203            } else {
204                let transformed = write_dialect.transform(expr)?;
205                write_dialect.generate_with_source(&transformed, read)
206            }
207        })
208        .collect()
209}
210
211/// Parse SQL into an AST.
212///
213/// # Arguments
214/// * `sql` - The SQL string to parse
215/// * `dialect` - The dialect to use for parsing
216///
217/// # Returns
218/// A vector of parsed expressions
219pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
220    let d = Dialect::get(dialect);
221    d.parse(sql)
222}
223
224/// Parse a single SQL statement.
225///
226/// # Arguments
227/// * `sql` - The SQL string containing a single statement
228/// * `dialect` - The dialect to use for parsing
229///
230/// # Returns
231/// The parsed expression, or an error if multiple statements found
232pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
233    let mut expressions = parse(sql, dialect)?;
234
235    if expressions.len() != 1 {
236        return Err(Error::parse(
237            format!("Expected 1 statement, found {}", expressions.len()),
238            0,
239            0,
240        ));
241    }
242
243    Ok(expressions.remove(0))
244}
245
246/// Generate SQL from an AST.
247///
248/// # Arguments
249/// * `expression` - The expression to generate SQL from
250/// * `dialect` - The target dialect
251///
252/// # Returns
253/// The generated SQL string
254pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
255    let d = Dialect::get(dialect);
256    d.generate(expression)
257}
258
259/// Validate SQL syntax.
260///
261/// # Arguments
262/// * `sql` - The SQL string to validate
263/// * `dialect` - The dialect to use for validation
264///
265/// # Returns
266/// A validation result with any errors found
267pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
268    validate_with_options(sql, dialect, &ValidationOptions::default())
269}
270
/// Options for syntax validation behavior.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct ValidationOptions {
    /// When enabled, validation rejects non-canonical trailing commas that the parser
    /// would otherwise accept for compatibility (e.g. `SELECT a, FROM t`).
    ///
    /// Defaults to `false` (permissive) — both via the derived `Default` impl and,
    /// through `#[serde(default)]`, when the field is absent from serialized input.
    #[serde(default)]
    pub strict_syntax: bool,
}
280
281/// Validate SQL syntax with additional validation options.
282pub fn validate_with_options(
283    sql: &str,
284    dialect: DialectType,
285    options: &ValidationOptions,
286) -> ValidationResult {
287    let d = Dialect::get(dialect);
288    match d.parse(sql) {
289        Ok(expressions) => {
290            // Reject bare expressions that aren't valid SQL statements.
291            // The parser accepts any expression at the top level, but bare identifiers,
292            // literals, function calls, etc. are not valid statements.
293            for expr in &expressions {
294                if !expr.is_statement() {
295                    let msg = format!("Invalid expression / Unexpected token");
296                    return ValidationResult::with_errors(vec![ValidationError::error(
297                        msg, "E004",
298                    )]);
299                }
300            }
301            if options.strict_syntax {
302                if let Some(error) = strict_syntax_error(sql, &d) {
303                    return ValidationResult::with_errors(vec![error]);
304                }
305            }
306            ValidationResult::success()
307        }
308        Err(e) => {
309            let error = match &e {
310                Error::Syntax {
311                    message,
312                    line,
313                    column,
314                } => ValidationError::error(message.clone(), "E001").with_location(*line, *column),
315                Error::Tokenize {
316                    message,
317                    line,
318                    column,
319                } => ValidationError::error(message.clone(), "E002").with_location(*line, *column),
320                Error::Parse {
321                    message,
322                    line,
323                    column,
324                } => ValidationError::error(message.clone(), "E003").with_location(*line, *column),
325                _ => ValidationError::error(e.to_string(), "E000"),
326            };
327            ValidationResult::with_errors(vec![error])
328        }
329    }
330}
331
332fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
333    let tokens = dialect.tokenize(sql).ok()?;
334
335    for (idx, token) in tokens.iter().enumerate() {
336        if token.token_type != TokenType::Comma {
337            continue;
338        }
339
340        let next = tokens.get(idx + 1);
341        let (is_boundary, boundary_name) = match next.map(|t| t.token_type) {
342            Some(TokenType::From) => (true, "FROM"),
343            Some(TokenType::Where) => (true, "WHERE"),
344            Some(TokenType::GroupBy) => (true, "GROUP BY"),
345            Some(TokenType::Having) => (true, "HAVING"),
346            Some(TokenType::Order) | Some(TokenType::OrderBy) => (true, "ORDER BY"),
347            Some(TokenType::Limit) => (true, "LIMIT"),
348            Some(TokenType::Offset) => (true, "OFFSET"),
349            Some(TokenType::Union) => (true, "UNION"),
350            Some(TokenType::Intersect) => (true, "INTERSECT"),
351            Some(TokenType::Except) => (true, "EXCEPT"),
352            Some(TokenType::Qualify) => (true, "QUALIFY"),
353            Some(TokenType::Window) => (true, "WINDOW"),
354            Some(TokenType::Semicolon) | None => (true, "end of statement"),
355            _ => (false, ""),
356        };
357
358        if is_boundary {
359            let message = format!(
360                "Trailing comma before {} is not allowed in strict syntax mode",
361                boundary_name
362            );
363            return Some(
364                ValidationError::error(message, "E005")
365                    .with_location(token.span.line, token.span.column),
366            );
367        }
368    }
369
370    None
371}
372
373/// Transpile SQL from one dialect to another, using string dialect names.
374///
375/// This supports both built-in dialect names (e.g., "postgresql", "mysql") and
376/// custom dialects registered via [`CustomDialectBuilder`].
377///
378/// # Arguments
379/// * `sql` - The SQL string to transpile
380/// * `read` - The source dialect name
381/// * `write` - The target dialect name
382///
383/// # Returns
384/// A vector of transpiled SQL statements, or an error if a dialect name is unknown.
385pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
386    let read_dialect = Dialect::get_by_name(read)
387        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0))?;
388    let write_dialect = Dialect::get_by_name(write)
389        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0))?;
390    let generic_identity = read_dialect.dialect_type() == DialectType::Generic
391        && write_dialect.dialect_type() == DialectType::Generic;
392
393    let expressions = read_dialect.parse(sql)?;
394
395    expressions
396        .into_iter()
397        .map(|expr| {
398            if generic_identity {
399                write_dialect.generate_with_source(&expr, read_dialect.dialect_type())
400            } else {
401                let transformed = write_dialect.transform(expr)?;
402                write_dialect.generate_with_source(&transformed, read_dialect.dialect_type())
403            }
404        })
405        .collect()
406}
407
408/// Parse SQL into an AST using a string dialect name.
409///
410/// Supports both built-in and custom dialect names.
411pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
412    let d = Dialect::get_by_name(dialect)
413        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0))?;
414    d.parse(sql)
415}
416
417/// Generate SQL from an AST using a string dialect name.
418///
419/// Supports both built-in and custom dialect names.
420pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
421    let d = Dialect::get_by_name(dialect)
422        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0))?;
423    d.generate(expression)
424}
425
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// Without strict mode, the parser tolerates a trailing comma before FROM.
    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let result = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(result.valid, "Result: {:?}", result.errors);
    }

    /// Strict mode flags `SELECT a, FROM t` as E005.
    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        let strict = ValidationOptions { strict_syntax: true };
        let outcome =
            validate_with_options("SELECT name, FROM employees", DialectType::Generic, &strict);
        assert!(!outcome.valid, "Result should be invalid");
        let has_e005 = outcome.errors.iter().any(|e| e.code == "E005");
        assert!(has_e005, "Expected E005, got: {:?}", outcome.errors);
    }

    /// Strict mode flags a trailing comma before WHERE as E005.
    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        let strict = ValidationOptions { strict_syntax: true };
        let outcome = validate_with_options(
            "SELECT name FROM employees, WHERE salary > 10",
            DialectType::Generic,
            &strict,
        );
        assert!(!outcome.valid, "Result should be invalid");
        let has_e005 = outcome.errors.iter().any(|e| e.code == "E005");
        assert!(has_e005, "Expected E005, got: {:?}", outcome.errors);
    }
}