Skip to main content

polyglot_sql/
lib.rs

1//! Polyglot Core - SQL parsing and dialect translation library
2//!
3//! This library provides the core functionality for parsing SQL statements,
4//! building an abstract syntax tree (AST), and generating SQL in different dialects.
5//!
6//! # Architecture
7//!
8//! The library follows a pipeline architecture:
9//! 1. **Tokenizer** - Converts SQL string to token stream
10//! 2. **Parser** - Builds AST from tokens
11//! 3. **Generator** - Converts AST back to SQL string
12//!
13//! Each stage can be customized per dialect.
14
15pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21mod function_registry;
22pub mod generator;
23pub mod helper;
24pub mod lineage;
25pub mod optimizer;
26pub mod parser;
27pub mod planner;
28pub mod resolver;
29pub mod schema;
30pub mod scope;
31pub mod time;
32pub mod tokens;
33pub mod transforms;
34pub mod traversal;
35pub mod trie;
36pub mod validation;
37
38pub use ast_transforms::{
39    add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
40    get_identifiers, get_literals, get_subqueries, get_table_names, get_window_functions,
41    node_count, qualify_columns, remove_limit_offset, remove_nodes, remove_select_columns,
42    remove_where, rename_columns, rename_tables, replace_by_type, replace_nodes, set_distinct,
43    set_limit, set_offset,
44};
45pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
46pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
47pub use expressions::Expression;
48pub use generator::Generator;
49pub use helper::{
50    csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
51    name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
52};
53pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
54pub use parser::Parser;
55pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
56pub use schema::{
57    ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
58};
59pub use scope::{
60    build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
61    ScopeType, SourceInfo,
62};
63pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
64pub use tokens::{Token, TokenType, Tokenizer};
65pub use traversal::{
66    contains_aggregate,
67    contains_subquery,
68    contains_window_function,
69    find_ancestor,
70    find_parent,
71    get_columns,
72    get_tables,
73    is_add,
74    is_aggregate,
75    is_alias,
76    is_alter_table,
77    is_and,
78    is_arithmetic,
79    is_avg,
80    is_between,
81    is_boolean,
82    is_case,
83    is_cast,
84    is_coalesce,
85    is_column,
86    is_comparison,
87    is_concat,
88    is_count,
89    is_create_index,
90    is_create_table,
91    is_create_view,
92    is_cte,
93    is_ddl,
94    is_delete,
95    is_div,
96    is_drop_index,
97    is_drop_table,
98    is_drop_view,
99    is_eq,
100    is_except,
101    is_exists,
102    is_from,
103    is_function,
104    is_group_by,
105    is_gt,
106    is_gte,
107    is_having,
108    is_identifier,
109    is_ilike,
110    is_in,
111    // Extended type predicates
112    is_insert,
113    is_intersect,
114    is_is_null,
115    is_join,
116    is_like,
117    is_limit,
118    is_literal,
119    is_logical,
120    is_lt,
121    is_lte,
122    is_max_func,
123    is_min_func,
124    is_mod,
125    is_mul,
126    is_neq,
127    is_not,
128    is_null_if,
129    is_null_literal,
130    is_offset,
131    is_or,
132    is_order_by,
133    is_ordered,
134    is_paren,
135    // Composite predicates
136    is_query,
137    is_safe_cast,
138    is_select,
139    is_set_operation,
140    is_star,
141    is_sub,
142    is_subquery,
143    is_sum,
144    is_table,
145    is_try_cast,
146    is_union,
147    is_update,
148    is_where,
149    is_window_function,
150    is_with,
151    transform,
152    transform_map,
153    BfsIter,
154    DfsIter,
155    ExpressionWalk,
156    ParentInfo,
157    TreeContext,
158};
159pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
160pub use validation::{
161    validate_with_schema, SchemaColumn, SchemaColumnReference, SchemaForeignKey, SchemaTable,
162    SchemaTableReference, SchemaValidationOptions, ValidationSchema,
163};
164
165/// Transpile SQL from one dialect to another.
166///
167/// # Arguments
168/// * `sql` - The SQL string to transpile
169/// * `read` - The source dialect to parse with
170/// * `write` - The target dialect to generate
171///
172/// # Returns
173/// A vector of transpiled SQL statements
174///
175/// # Example
176/// ```
177/// use polyglot_sql::{transpile, DialectType};
178///
179/// let result = transpile(
180///     "SELECT EPOCH_MS(1618088028295)",
181///     DialectType::DuckDB,
182///     DialectType::Hive
183/// );
184/// ```
185pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
186    let read_dialect = Dialect::get(read);
187    let write_dialect = Dialect::get(write);
188
189    let expressions = read_dialect.parse(sql)?;
190
191    expressions
192        .into_iter()
193        .map(|expr| {
194            let transformed = write_dialect.transform(expr)?;
195            write_dialect.generate_with_source(&transformed, read)
196        })
197        .collect()
198}
199
200/// Parse SQL into an AST.
201///
202/// # Arguments
203/// * `sql` - The SQL string to parse
204/// * `dialect` - The dialect to use for parsing
205///
206/// # Returns
207/// A vector of parsed expressions
208pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
209    let d = Dialect::get(dialect);
210    d.parse(sql)
211}
212
213/// Parse a single SQL statement.
214///
215/// # Arguments
216/// * `sql` - The SQL string containing a single statement
217/// * `dialect` - The dialect to use for parsing
218///
219/// # Returns
220/// The parsed expression, or an error if multiple statements found
221pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
222    let mut expressions = parse(sql, dialect)?;
223
224    if expressions.len() != 1 {
225        return Err(Error::parse(
226            format!("Expected 1 statement, found {}", expressions.len()),
227            0,
228            0,
229        ));
230    }
231
232    Ok(expressions.remove(0))
233}
234
235/// Generate SQL from an AST.
236///
237/// # Arguments
238/// * `expression` - The expression to generate SQL from
239/// * `dialect` - The target dialect
240///
241/// # Returns
242/// The generated SQL string
243pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
244    let d = Dialect::get(dialect);
245    d.generate(expression)
246}
247
248/// Validate SQL syntax.
249///
250/// # Arguments
251/// * `sql` - The SQL string to validate
252/// * `dialect` - The dialect to use for validation
253///
254/// # Returns
255/// A validation result with any errors found
256pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
257    let d = Dialect::get(dialect);
258    match d.parse(sql) {
259        Ok(expressions) => {
260            // Reject bare expressions that aren't valid SQL statements.
261            // The parser accepts any expression at the top level, but bare identifiers,
262            // literals, function calls, etc. are not valid statements.
263            for expr in &expressions {
264                if !expr.is_statement() {
265                    let msg = format!("Invalid expression / Unexpected token");
266                    return ValidationResult::with_errors(vec![ValidationError::error(
267                        msg, "E004",
268                    )]);
269                }
270            }
271            ValidationResult::success()
272        }
273        Err(e) => {
274            let error = match &e {
275                Error::Syntax {
276                    message,
277                    line,
278                    column,
279                } => ValidationError::error(message.clone(), "E001").with_location(*line, *column),
280                Error::Tokenize {
281                    message,
282                    line,
283                    column,
284                } => ValidationError::error(message.clone(), "E002").with_location(*line, *column),
285                Error::Parse {
286                    message,
287                    line,
288                    column,
289                } => ValidationError::error(message.clone(), "E003").with_location(*line, *column),
290                _ => ValidationError::error(e.to_string(), "E000"),
291            };
292            ValidationResult::with_errors(vec![error])
293        }
294    }
295}
296
297/// Transpile SQL from one dialect to another, using string dialect names.
298///
299/// This supports both built-in dialect names (e.g., "postgresql", "mysql") and
300/// custom dialects registered via [`CustomDialectBuilder`].
301///
302/// # Arguments
303/// * `sql` - The SQL string to transpile
304/// * `read` - The source dialect name
305/// * `write` - The target dialect name
306///
307/// # Returns
308/// A vector of transpiled SQL statements, or an error if a dialect name is unknown.
309pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
310    let read_dialect = Dialect::get_by_name(read)
311        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0))?;
312    let write_dialect = Dialect::get_by_name(write)
313        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0))?;
314
315    let expressions = read_dialect.parse(sql)?;
316
317    expressions
318        .into_iter()
319        .map(|expr| {
320            let transformed = write_dialect.transform(expr)?;
321            write_dialect.generate_with_source(&transformed, read_dialect.dialect_type())
322        })
323        .collect()
324}
325
326/// Parse SQL into an AST using a string dialect name.
327///
328/// Supports both built-in and custom dialect names.
329pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
330    let d = Dialect::get_by_name(dialect)
331        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0))?;
332    d.parse(sql)
333}
334
335/// Generate SQL from an AST using a string dialect name.
336///
337/// Supports both built-in and custom dialect names.
338pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
339    let d = Dialect::get_by_name(dialect)
340        .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0))?;
341    d.generate(expression)
342}