Skip to main content

polyglot_sql/
lib.rs

1//! Polyglot Core - SQL parsing and dialect translation library
2//!
3//! This library provides the core functionality for parsing SQL statements,
4//! building an abstract syntax tree (AST), and generating SQL in different dialects.
5//!
6//! # Architecture
7//!
8//! The library follows a pipeline architecture:
9//! 1. **Tokenizer** - Converts SQL string to token stream
10//! 2. **Parser** - Builds AST from tokens
11//! 3. **Generator** - Converts AST back to SQL string
12//!
13//! Each stage can be customized per dialect.
14
15pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21mod function_registry;
22pub mod generator;
23pub mod helper;
24pub mod lineage;
25pub mod optimizer;
26pub mod parser;
27pub mod planner;
28pub mod resolver;
29pub mod schema;
30pub mod scope;
31pub mod time;
32pub mod tokens;
33pub mod transforms;
34pub mod traversal;
35pub mod trie;
36pub mod validation;
37
38use serde::{Deserialize, Serialize};
39
40pub use ast_transforms::{
41    add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
42    get_identifiers, get_literals, get_subqueries, get_table_names, get_window_functions,
43    node_count, qualify_columns, remove_limit_offset, remove_nodes, remove_select_columns,
44    remove_where, rename_columns, rename_tables, replace_by_type, replace_nodes, set_distinct,
45    set_limit, set_offset,
46};
47pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
48pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
49pub use expressions::Expression;
50pub use generator::Generator;
51pub use helper::{
52    csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
53    name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
54};
55pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
56pub use parser::Parser;
57pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
58pub use schema::{
59    ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
60};
61pub use scope::{
62    build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
63    ScopeType, SourceInfo,
64};
65pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
66pub use tokens::{Token, TokenType, Tokenizer};
67pub use traversal::{
68    contains_aggregate,
69    contains_subquery,
70    contains_window_function,
71    find_ancestor,
72    find_parent,
73    get_columns,
74    get_tables,
75    is_add,
76    is_aggregate,
77    is_alias,
78    is_alter_table,
79    is_and,
80    is_arithmetic,
81    is_avg,
82    is_between,
83    is_boolean,
84    is_case,
85    is_cast,
86    is_coalesce,
87    is_column,
88    is_comparison,
89    is_concat,
90    is_count,
91    is_create_index,
92    is_create_table,
93    is_create_view,
94    is_cte,
95    is_ddl,
96    is_delete,
97    is_div,
98    is_drop_index,
99    is_drop_table,
100    is_drop_view,
101    is_eq,
102    is_except,
103    is_exists,
104    is_from,
105    is_function,
106    is_group_by,
107    is_gt,
108    is_gte,
109    is_having,
110    is_identifier,
111    is_ilike,
112    is_in,
113    // Extended type predicates
114    is_insert,
115    is_intersect,
116    is_is_null,
117    is_join,
118    is_like,
119    is_limit,
120    is_literal,
121    is_logical,
122    is_lt,
123    is_lte,
124    is_max_func,
125    is_min_func,
126    is_mod,
127    is_mul,
128    is_neq,
129    is_not,
130    is_null_if,
131    is_null_literal,
132    is_offset,
133    is_or,
134    is_order_by,
135    is_ordered,
136    is_paren,
137    // Composite predicates
138    is_query,
139    is_safe_cast,
140    is_select,
141    is_set_operation,
142    is_star,
143    is_sub,
144    is_subquery,
145    is_sum,
146    is_table,
147    is_try_cast,
148    is_union,
149    is_update,
150    is_where,
151    is_window_function,
152    is_with,
153    transform,
154    transform_map,
155    BfsIter,
156    DfsIter,
157    ExpressionWalk,
158    ParentInfo,
159    TreeContext,
160};
161pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
162pub use validation::{
163    validate_with_schema, SchemaColumn, SchemaColumnReference, SchemaForeignKey, SchemaTable,
164    SchemaTableReference, SchemaValidationOptions, ValidationSchema,
165};
166
167/// Transpile SQL from one dialect to another.
168///
169/// # Arguments
170/// * `sql` - The SQL string to transpile
171/// * `read` - The source dialect to parse with
172/// * `write` - The target dialect to generate
173///
174/// # Returns
175/// A vector of transpiled SQL statements
176///
177/// # Example
178/// ```
179/// use polyglot_sql::{transpile, DialectType};
180///
181/// let result = transpile(
182///     "SELECT EPOCH_MS(1618088028295)",
183///     DialectType::DuckDB,
184///     DialectType::Hive
185/// );
186/// ```
187pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
188    let read_dialect = Dialect::get(read);
189    let write_dialect = Dialect::get(write);
190    let generic_identity = read == DialectType::Generic && write == DialectType::Generic;
191
192    let expressions = read_dialect.parse(sql)?;
193
194    expressions
195        .into_iter()
196        .map(|expr| {
197            if generic_identity {
198                write_dialect.generate_with_source(&expr, read)
199            } else {
200                let transformed = write_dialect.transform(expr)?;
201                write_dialect.generate_with_source(&transformed, read)
202            }
203        })
204        .collect()
205}
206
207/// Parse SQL into an AST.
208///
209/// # Arguments
210/// * `sql` - The SQL string to parse
211/// * `dialect` - The dialect to use for parsing
212///
213/// # Returns
214/// A vector of parsed expressions
215pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
216    let d = Dialect::get(dialect);
217    d.parse(sql)
218}
219
220/// Parse a single SQL statement.
221///
222/// # Arguments
223/// * `sql` - The SQL string containing a single statement
224/// * `dialect` - The dialect to use for parsing
225///
226/// # Returns
227/// The parsed expression, or an error if multiple statements found
228pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
229    let mut expressions = parse(sql, dialect)?;
230
231    if expressions.len() != 1 {
232        return Err(Error::parse(
233            format!("Expected 1 statement, found {}", expressions.len()),
234            0,
235            0,
236        ));
237    }
238
239    Ok(expressions.remove(0))
240}
241
242/// Generate SQL from an AST.
243///
244/// # Arguments
245/// * `expression` - The expression to generate SQL from
246/// * `dialect` - The target dialect
247///
248/// # Returns
249/// The generated SQL string
250pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
251    let d = Dialect::get(dialect);
252    d.generate(expression)
253}
254
255/// Validate SQL syntax.
256///
257/// # Arguments
258/// * `sql` - The SQL string to validate
259/// * `dialect` - The dialect to use for validation
260///
261/// # Returns
262/// A validation result with any errors found
263pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
264    validate_with_options(sql, dialect, &ValidationOptions::default())
265}
266
267/// Options for syntax validation behavior.
268#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
269#[serde(rename_all = "camelCase")]
270pub struct ValidationOptions {
271    /// When enabled, validation rejects non-canonical trailing commas that the parser
272    /// would otherwise accept for compatibility (e.g. `SELECT a, FROM t`).
273    #[serde(default)]
274    pub strict_syntax: bool,
275}
276
277/// Validate SQL syntax with additional validation options.
278pub fn validate_with_options(
279    sql: &str,
280    dialect: DialectType,
281    options: &ValidationOptions,
282) -> ValidationResult {
283    let d = Dialect::get(dialect);
284    match d.parse(sql) {
285        Ok(expressions) => {
286            // Reject bare expressions that aren't valid SQL statements.
287            // The parser accepts any expression at the top level, but bare identifiers,
288            // literals, function calls, etc. are not valid statements.
289            for expr in &expressions {
290                if !expr.is_statement() {
291                    let msg = format!("Invalid expression / Unexpected token");
292                    return ValidationResult::with_errors(vec![ValidationError::error(
293                        msg, "E004",
294                    )]);
295                }
296            }
297            if options.strict_syntax {
298                if let Some(error) = strict_syntax_error(sql, &d) {
299                    return ValidationResult::with_errors(vec![error]);
300                }
301            }
302            ValidationResult::success()
303        }
304        Err(e) => {
305            let error = match &e {
306                Error::Syntax {
307                    message,
308                    line,
309                    column,
310                } => ValidationError::error(message.clone(), "E001").with_location(*line, *column),
311                Error::Tokenize {
312                    message,
313                    line,
314                    column,
315                } => ValidationError::error(message.clone(), "E002").with_location(*line, *column),
316                Error::Parse {
317                    message,
318                    line,
319                    column,
320                } => ValidationError::error(message.clone(), "E003").with_location(*line, *column),
321                _ => ValidationError::error(e.to_string(), "E000"),
322            };
323            ValidationResult::with_errors(vec![error])
324        }
325    }
326}
327
328fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
329    let tokens = dialect.tokenize(sql).ok()?;
330
331    for (idx, token) in tokens.iter().enumerate() {
332        if token.token_type != TokenType::Comma {
333            continue;
334        }
335
336        let next = tokens.get(idx + 1);
337        let (is_boundary, boundary_name) = match next.map(|t| t.token_type) {
338            Some(TokenType::From) => (true, "FROM"),
339            Some(TokenType::Where) => (true, "WHERE"),
340            Some(TokenType::GroupBy) => (true, "GROUP BY"),
341            Some(TokenType::Having) => (true, "HAVING"),
342            Some(TokenType::Order) | Some(TokenType::OrderBy) => (true, "ORDER BY"),
343            Some(TokenType::Limit) => (true, "LIMIT"),
344            Some(TokenType::Offset) => (true, "OFFSET"),
345            Some(TokenType::Union) => (true, "UNION"),
346            Some(TokenType::Intersect) => (true, "INTERSECT"),
347            Some(TokenType::Except) => (true, "EXCEPT"),
348            Some(TokenType::Qualify) => (true, "QUALIFY"),
349            Some(TokenType::Window) => (true, "WINDOW"),
350            Some(TokenType::Semicolon) | None => (true, "end of statement"),
351            _ => (false, ""),
352        };
353
354        if is_boundary {
355            let message = format!(
356                "Trailing comma before {} is not allowed in strict syntax mode",
357                boundary_name
358            );
359            return Some(
360                ValidationError::error(message, "E005")
361                    .with_location(token.span.line, token.span.column),
362            );
363        }
364    }
365
366    None
367}
368
369/// Transpile SQL from one dialect to another, using string dialect names.
370///
371/// This supports both built-in dialect names (e.g., "postgresql", "mysql") and
372/// custom dialects registered via [`CustomDialectBuilder`].
373///
374/// # Arguments
375/// * `sql` - The SQL string to transpile
376/// * `read` - The source dialect name
377/// * `write` - The target dialect name
378///
379/// # Returns
380/// A vector of transpiled SQL statements, or an error if a dialect name is unknown.
381pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
382    let read_dialect = Dialect::get_by_name(read)
383        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0))?;
384    let write_dialect = Dialect::get_by_name(write)
385        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0))?;
386    let generic_identity = read_dialect.dialect_type() == DialectType::Generic
387        && write_dialect.dialect_type() == DialectType::Generic;
388
389    let expressions = read_dialect.parse(sql)?;
390
391    expressions
392        .into_iter()
393        .map(|expr| {
394            if generic_identity {
395                write_dialect.generate_with_source(&expr, read_dialect.dialect_type())
396            } else {
397                let transformed = write_dialect.transform(expr)?;
398                write_dialect.generate_with_source(&transformed, read_dialect.dialect_type())
399            }
400        })
401        .collect()
402}
403
404/// Parse SQL into an AST using a string dialect name.
405///
406/// Supports both built-in and custom dialect names.
407pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
408    let d = Dialect::get_by_name(dialect)
409        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0))?;
410    d.parse(sql)
411}
412
413/// Generate SQL from an AST using a string dialect name.
414///
415/// Supports both built-in and custom dialect names.
416pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
417    let d = Dialect::get_by_name(dialect)
418        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0))?;
419    d.generate(expression)
420}
421
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// Options with strict syntax checking turned on.
    fn strict() -> ValidationOptions {
        ValidationOptions {
            strict_syntax: true,
        }
    }

    /// Assert that `sql` fails strict validation with an `E005` error.
    fn assert_rejected_as_trailing_comma(sql: &str) {
        let result = validate_with_options(sql, DialectType::Generic, &strict());
        assert!(!result.valid, "Result should be invalid");
        assert!(
            result.errors.iter().any(|e| e.code == "E005"),
            "Expected E005, got: {:?}",
            result.errors
        );
    }

    #[test]
    fn validation_options_default_is_permissive() {
        assert!(!ValidationOptions::default().strict_syntax);
    }

    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let result = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(result.valid, "Result: {:?}", result.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        assert_rejected_as_trailing_comma("SELECT name, FROM employees");
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        assert_rejected_as_trailing_comma("SELECT name FROM employees, WHERE salary > 10");
    }

    #[test]
    fn validate_with_options_accepts_well_formed_column_list() {
        // Commas that separate real items must not be flagged in strict mode.
        let result = validate_with_options(
            "SELECT name, salary FROM employees",
            DialectType::Generic,
            &strict(),
        );
        assert!(result.valid, "Result: {:?}", result.errors);
    }
}
467}