Skip to main content

polyglot_sql/
lib.rs

1//! Polyglot Core - SQL parsing and dialect translation library
2//!
3//! This library provides the core functionality for parsing SQL statements,
4//! building an abstract syntax tree (AST), and generating SQL in different dialects.
5//!
6//! # Architecture
7//!
8//! The library follows a pipeline architecture:
9//! 1. **Tokenizer** - Converts SQL string to token stream
10//! 2. **Parser** - Builds AST from tokens
11//! 3. **Generator** - Converts AST back to SQL string
12//!
13//! Each stage can be customized per dialect.
14
15pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21pub mod generator;
22pub mod helper;
23pub mod lineage;
24pub mod optimizer;
25pub mod parser;
26pub mod planner;
27pub mod resolver;
28pub mod schema;
29pub mod scope;
30pub mod time;
31pub mod tokens;
32pub mod transforms;
33pub mod traversal;
34pub mod trie;
35
36pub use ast_transforms::{
37    add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
38    get_identifiers, get_literals, get_subqueries, get_table_names, get_window_functions,
39    node_count, qualify_columns, remove_limit_offset, remove_nodes, remove_select_columns,
40    remove_where, rename_columns, rename_tables, replace_by_type, replace_nodes, set_distinct,
41    set_limit, set_offset,
42};
43pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
44pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
45pub use expressions::Expression;
46pub use generator::Generator;
47pub use helper::{
48    csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
49    name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
50};
51pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
52pub use parser::Parser;
53pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
54pub use schema::{
55    ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
56};
57pub use scope::{
58    build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
59    ScopeType, SourceInfo,
60};
61pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
62pub use tokens::{Token, TokenType, Tokenizer};
63pub use traversal::{
64    contains_aggregate,
65    contains_subquery,
66    contains_window_function,
67    find_ancestor,
68    find_parent,
69    get_columns,
70    get_tables,
71    is_add,
72    is_aggregate,
73    is_alias,
74    is_alter_table,
75    is_and,
76    is_arithmetic,
77    is_avg,
78    is_between,
79    is_boolean,
80    is_case,
81    is_cast,
82    is_coalesce,
83    is_column,
84    is_comparison,
85    is_concat,
86    is_count,
87    is_create_index,
88    is_create_table,
89    is_create_view,
90    is_cte,
91    is_ddl,
92    is_delete,
93    is_div,
94    is_drop_index,
95    is_drop_table,
96    is_drop_view,
97    is_eq,
98    is_except,
99    is_exists,
100    is_from,
101    is_function,
102    is_group_by,
103    is_gt,
104    is_gte,
105    is_having,
106    is_identifier,
107    is_ilike,
108    is_in,
109    // Extended type predicates
110    is_insert,
111    is_intersect,
112    is_is_null,
113    is_join,
114    is_like,
115    is_limit,
116    is_literal,
117    is_logical,
118    is_lt,
119    is_lte,
120    is_max_func,
121    is_min_func,
122    is_mod,
123    is_mul,
124    is_neq,
125    is_not,
126    is_null_if,
127    is_null_literal,
128    is_offset,
129    is_or,
130    is_order_by,
131    is_ordered,
132    is_paren,
133    // Composite predicates
134    is_query,
135    is_safe_cast,
136    is_select,
137    is_set_operation,
138    is_star,
139    is_sub,
140    is_subquery,
141    is_sum,
142    is_table,
143    is_try_cast,
144    is_union,
145    is_update,
146    is_where,
147    is_window_function,
148    is_with,
149    transform,
150    transform_map,
151    BfsIter,
152    DfsIter,
153    ExpressionWalk,
154    ParentInfo,
155    TreeContext,
156};
157pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
158
159/// Transpile SQL from one dialect to another.
160///
161/// # Arguments
162/// * `sql` - The SQL string to transpile
163/// * `read` - The source dialect to parse with
164/// * `write` - The target dialect to generate
165///
166/// # Returns
167/// A vector of transpiled SQL statements
168///
169/// # Example
170/// ```
171/// use polyglot_sql::{transpile, DialectType};
172///
173/// let result = transpile(
174///     "SELECT EPOCH_MS(1618088028295)",
175///     DialectType::DuckDB,
176///     DialectType::Hive
177/// );
178/// ```
179pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
180    let read_dialect = Dialect::get(read);
181    let write_dialect = Dialect::get(write);
182
183    let expressions = read_dialect.parse(sql)?;
184
185    expressions
186        .into_iter()
187        .map(|expr| {
188            let transformed = write_dialect.transform(expr)?;
189            write_dialect.generate(&transformed)
190        })
191        .collect()
192}
193
194/// Parse SQL into an AST.
195///
196/// # Arguments
197/// * `sql` - The SQL string to parse
198/// * `dialect` - The dialect to use for parsing
199///
200/// # Returns
201/// A vector of parsed expressions
202pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
203    let d = Dialect::get(dialect);
204    d.parse(sql)
205}
206
207/// Parse a single SQL statement.
208///
209/// # Arguments
210/// * `sql` - The SQL string containing a single statement
211/// * `dialect` - The dialect to use for parsing
212///
213/// # Returns
214/// The parsed expression, or an error if multiple statements found
215pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
216    let mut expressions = parse(sql, dialect)?;
217
218    if expressions.len() != 1 {
219        return Err(Error::Parse(format!(
220            "Expected 1 statement, found {}",
221            expressions.len()
222        )));
223    }
224
225    Ok(expressions.remove(0))
226}
227
228/// Generate SQL from an AST.
229///
230/// # Arguments
231/// * `expression` - The expression to generate SQL from
232/// * `dialect` - The target dialect
233///
234/// # Returns
235/// The generated SQL string
236pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
237    let d = Dialect::get(dialect);
238    d.generate(expression)
239}
240
241/// Validate SQL syntax.
242///
243/// # Arguments
244/// * `sql` - The SQL string to validate
245/// * `dialect` - The dialect to use for validation
246///
247/// # Returns
248/// A validation result with any errors found
249pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
250    let d = Dialect::get(dialect);
251    match d.parse(sql) {
252        Ok(expressions) => {
253            // Reject bare expressions that aren't valid SQL statements.
254            // The parser accepts any expression at the top level, but bare identifiers,
255            // literals, function calls, etc. are not valid statements.
256            for expr in &expressions {
257                if !expr.is_statement() {
258                    let msg = format!("Invalid expression / Unexpected token");
259                    return ValidationResult::with_errors(vec![ValidationError::error(
260                        msg, "E004",
261                    )]);
262                }
263            }
264            ValidationResult::success()
265        }
266        Err(e) => {
267            let error = match &e {
268                Error::Syntax {
269                    message,
270                    line,
271                    column,
272                } => ValidationError::error(message.clone(), "E001").with_location(*line, *column),
273                Error::Tokenize {
274                    message,
275                    line,
276                    column,
277                } => ValidationError::error(message.clone(), "E002").with_location(*line, *column),
278                Error::Parse(msg) => ValidationError::error(msg.clone(), "E003"),
279                _ => ValidationError::error(e.to_string(), "E000"),
280            };
281            ValidationResult::with_errors(vec![error])
282        }
283    }
284}
285
286/// Transpile SQL from one dialect to another, using string dialect names.
287///
288/// This supports both built-in dialect names (e.g., "postgresql", "mysql") and
289/// custom dialects registered via [`CustomDialectBuilder`].
290///
291/// # Arguments
292/// * `sql` - The SQL string to transpile
293/// * `read` - The source dialect name
294/// * `write` - The target dialect name
295///
296/// # Returns
297/// A vector of transpiled SQL statements, or an error if a dialect name is unknown.
298pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
299    let read_dialect = Dialect::get_by_name(read)
300        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read)))?;
301    let write_dialect = Dialect::get_by_name(write)
302        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write)))?;
303
304    let expressions = read_dialect.parse(sql)?;
305
306    expressions
307        .into_iter()
308        .map(|expr| {
309            let transformed = write_dialect.transform(expr)?;
310            write_dialect.generate(&transformed)
311        })
312        .collect()
313}
314
315/// Parse SQL into an AST using a string dialect name.
316///
317/// Supports both built-in and custom dialect names.
318pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
319    let d = Dialect::get_by_name(dialect)
320        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect)))?;
321    d.parse(sql)
322}
323
324/// Generate SQL from an AST using a string dialect name.
325///
326/// Supports both built-in and custom dialect names.
327pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
328    let d = Dialect::get_by_name(dialect)
329        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect)))?;
330    d.generate(expression)
331}