pub mod ast_transforms;
pub mod builder;
pub mod dialects;
pub mod diff;
pub mod error;
pub mod expressions;
pub mod function_catalog;
mod function_registry;
pub mod generator;
pub mod helper;
pub mod lineage;
pub mod optimizer;
pub mod parser;
pub mod planner;
pub mod resolver;
pub mod schema;
pub mod scope;
pub mod time;
pub mod tokens;
pub mod transforms;
pub mod traversal;
pub mod trie;
pub mod validation;
use serde::{Deserialize, Serialize};
pub use ast_transforms::{
add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
get_identifiers, get_literals, get_output_column_names, get_subqueries, get_table_names,
get_window_functions, node_count, qualify_columns, remove_limit_offset, remove_nodes,
remove_select_columns, remove_where, rename_columns, rename_tables, replace_by_type,
replace_nodes, set_distinct, set_limit, set_offset,
};
pub use dialects::{
unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType, TranspileOptions,
TranspileTarget,
};
pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
pub use expressions::Expression;
pub use function_catalog::{
FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
};
pub use generator::Generator;
pub use helper::{
csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
};
pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
pub use parser::Parser;
pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
pub use schema::{
ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
};
pub use scope::{
build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
ScopeType, SourceInfo,
};
pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
pub use tokens::{Token, TokenType, Tokenizer};
pub use traversal::{
contains_aggregate,
contains_subquery,
contains_window_function,
find_ancestor,
find_parent,
get_all_tables,
get_columns,
get_merge_source,
get_merge_target,
get_tables,
is_add,
is_aggregate,
is_alias,
is_alter_table,
is_and,
is_arithmetic,
is_avg,
is_between,
is_boolean,
is_case,
is_cast,
is_coalesce,
is_column,
is_comparison,
is_concat,
is_count,
is_create_index,
is_create_table,
is_create_view,
is_cte,
is_ddl,
is_delete,
is_div,
is_drop_index,
is_drop_table,
is_drop_view,
is_eq,
is_except,
is_exists,
is_from,
is_function,
is_group_by,
is_gt,
is_gte,
is_having,
is_identifier,
is_ilike,
is_in,
is_insert,
is_intersect,
is_is_null,
is_join,
is_like,
is_limit,
is_literal,
is_logical,
is_lt,
is_lte,
is_max_func,
is_merge,
is_min_func,
is_mod,
is_mul,
is_neq,
is_not,
is_null_if,
is_null_literal,
is_offset,
is_or,
is_order_by,
is_ordered,
is_paren,
is_query,
is_safe_cast,
is_select,
is_set_operation,
is_star,
is_sub,
is_subquery,
is_sum,
is_table,
is_try_cast,
is_union,
is_update,
is_where,
is_window_function,
is_with,
transform,
transform_map,
BfsIter,
DfsIter,
ExpressionWalk,
ParentInfo,
TreeContext,
};
pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
pub use validation::{
mapping_schema_from_validation_schema, validate_with_schema, SchemaColumn,
SchemaColumnReference, SchemaForeignKey, SchemaTable, SchemaTableReference,
SchemaValidationOptions, ValidationSchema,
};
/// Default upper bound on the SQL input size, in bytes (16 MiB).
const DEFAULT_FORMAT_MAX_INPUT_BYTES: usize = 16 * 1024 * 1024;
/// Default upper bound on the number of tokens produced by the tokenizer.
const DEFAULT_FORMAT_MAX_TOKENS: usize = 1_000_000;
/// Default upper bound on the total number of AST nodes across all statements.
const DEFAULT_FORMAT_MAX_AST_NODES: usize = 1_000_000;
/// Default upper bound on the number of set operations chained in one statement.
const DEFAULT_FORMAT_MAX_SET_OP_CHAIN: usize = 256;
/// Default for `FormatGuardOptions::max_input_bytes`: the guard is enabled
/// with [`DEFAULT_FORMAT_MAX_INPUT_BYTES`] as its limit.
fn default_format_max_input_bytes() -> Option<usize> {
    let limit = DEFAULT_FORMAT_MAX_INPUT_BYTES;
    Some(limit)
}
/// Default for `FormatGuardOptions::max_tokens`: the guard is enabled with
/// [`DEFAULT_FORMAT_MAX_TOKENS`] as its limit.
fn default_format_max_tokens() -> Option<usize> {
    let limit = DEFAULT_FORMAT_MAX_TOKENS;
    Some(limit)
}
/// Default for `FormatGuardOptions::max_ast_nodes`: the guard is enabled with
/// [`DEFAULT_FORMAT_MAX_AST_NODES`] as its limit.
fn default_format_max_ast_nodes() -> Option<usize> {
    let limit = DEFAULT_FORMAT_MAX_AST_NODES;
    Some(limit)
}
/// Default for `FormatGuardOptions::max_set_op_chain`: the guard is enabled
/// with [`DEFAULT_FORMAT_MAX_SET_OP_CHAIN`] as its limit.
fn default_format_max_set_op_chain() -> Option<usize> {
    let limit = DEFAULT_FORMAT_MAX_SET_OP_CHAIN;
    Some(limit)
}
/// Resource limits applied while formatting SQL.
///
/// Every limit is optional: `None` disables the corresponding check. When a
/// field is absent from deserialized input, serde fills in the built-in
/// default limit for it. Field names are serialized in `camelCase`.
#[derive(Clone, Copy, Debug, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct FormatGuardOptions {
    /// Maximum length of the SQL input in bytes.
    #[serde(default = "default_format_max_input_bytes")]
    pub max_input_bytes: Option<usize>,
    /// Maximum number of tokens the input may tokenize into.
    #[serde(default = "default_format_max_tokens")]
    pub max_tokens: Option<usize>,
    /// Maximum total number of AST nodes over all parsed statements.
    #[serde(default = "default_format_max_ast_nodes")]
    pub max_ast_nodes: Option<usize>,
    /// Maximum number of set-operation tokens between statement separators.
    #[serde(default = "default_format_max_set_op_chain")]
    pub max_set_op_chain: Option<usize>,
}
impl Default for FormatGuardOptions {
    /// Enables every guard with the same limits serde uses for missing fields.
    fn default() -> Self {
        let max_input_bytes = default_format_max_input_bytes();
        let max_tokens = default_format_max_tokens();
        let max_ast_nodes = default_format_max_ast_nodes();
        let max_set_op_chain = default_format_max_set_op_chain();
        Self {
            max_input_bytes,
            max_tokens,
            max_ast_nodes,
            max_set_op_chain,
        }
    }
}
fn format_guard_error(code: &str, actual: usize, limit: usize) -> Error {
Error::generate(format!(
"{code}: value {actual} exceeds configured limit {limit}"
))
}
/// Rejects `sql` when its byte length exceeds `options.max_input_bytes`.
///
/// # Errors
///
/// Returns an `E_GUARD_INPUT_TOO_LARGE` error when the limit is set and
/// exceeded. A `None` limit disables the check.
fn enforce_input_guard(sql: &str, options: &FormatGuardOptions) -> Result<()> {
    match options.max_input_bytes {
        Some(limit) if sql.len() > limit => Err(format_guard_error(
            "E_GUARD_INPUT_TOO_LARGE",
            sql.len(),
            limit,
        )),
        _ => Ok(()),
    }
}
/// Tokenizes and parses `sql`, applying the token-level guards in between.
///
/// The token budget and the set-operation chain limit are both checked on the
/// token stream, before the parser runs.
///
/// # Errors
///
/// Propagates tokenizer and parser errors, returns
/// `E_GUARD_TOKEN_BUDGET_EXCEEDED` when the token count exceeds
/// `options.max_tokens`, and forwards any error from the set-operation chain
/// guard.
fn parse_with_token_guard(
    sql: &str,
    dialect: &Dialect,
    options: &FormatGuardOptions,
) -> Result<Vec<Expression>> {
    let tokens = dialect.tokenize(sql)?;

    match options.max_tokens {
        Some(budget) if tokens.len() > budget => {
            return Err(format_guard_error(
                "E_GUARD_TOKEN_BUDGET_EXCEEDED",
                tokens.len(),
                budget,
            ));
        }
        _ => {}
    }

    enforce_set_op_chain_guard(&tokens, options)?;

    let parser_config = crate::parser::ParserConfig {
        dialect: Some(dialect.dialect_type()),
        ..Default::default()
    };
    let mut sql_parser = Parser::with_source(tokens, parser_config, sql.to_owned());
    sql_parser.parse()
}
/// Returns `true` for token types that carry no syntax: spaces, line breaks,
/// and line/block comments.
fn is_trivia_token(token_type: TokenType) -> bool {
    const TRIVIA: [TokenType; 4] = [
        TokenType::Space,
        TokenType::Break,
        TokenType::LineComment,
        TokenType::BlockComment,
    ];
    TRIVIA.contains(&token_type)
}
/// Finds the first non-trivia token at or after index `start`.
///
/// Returns `None` when `start` is past the end of `tokens` or only trivia
/// tokens remain.
fn next_significant_token(tokens: &[Token], start: usize) -> Option<&Token> {
    tokens
        .get(start..)?
        .iter()
        .find(|candidate| !is_trivia_token(candidate.token_type))
}
/// Reports whether the token at `idx` is a set-operation keyword.
///
/// `Union` and `Intersect` tokens always count. An `Except` token counts
/// unless its text is `minus` (case-insensitive) and the next significant
/// token is an opening parenthesis, i.e. it looks like a `minus(...)` call.
///
/// # Panics
///
/// Panics if `idx` is out of bounds for `tokens`.
fn is_set_operation_token(tokens: &[Token], idx: usize) -> bool {
    let current = &tokens[idx];

    if matches!(
        current.token_type,
        TokenType::Union | TokenType::Intersect
    ) {
        return true;
    }
    if current.token_type != TokenType::Except {
        return false;
    }

    // `minus(` is treated as a function call rather than the MINUS set operator.
    let is_minus_call = current.text.eq_ignore_ascii_case("minus")
        && matches!(
            next_significant_token(tokens, idx + 1).map(|following| following.token_type),
            Some(TokenType::LParen)
        );
    !is_minus_call
}
/// Limits how many set-operation tokens may appear within a single statement.
///
/// The running count is reset at every semicolon token, so each
/// semicolon-separated statement is budgeted independently.
///
/// # Errors
///
/// Returns `E_GUARD_SET_OP_CHAIN_EXCEEDED` as soon as the count exceeds
/// `options.max_set_op_chain`. A `None` limit disables the check.
fn enforce_set_op_chain_guard(tokens: &[Token], options: &FormatGuardOptions) -> Result<()> {
    let Some(limit) = options.max_set_op_chain else {
        return Ok(());
    };

    let mut chain_len: usize = 0;
    for (position, token) in tokens.iter().enumerate() {
        if token.token_type == TokenType::Semicolon {
            // A new statement starts; begin counting afresh.
            chain_len = 0;
        } else if is_set_operation_token(tokens, position) {
            chain_len += 1;
            if chain_len > limit {
                return Err(format_guard_error(
                    "E_GUARD_SET_OP_CHAIN_EXCEEDED",
                    chain_len,
                    limit,
                ));
            }
        }
    }
    Ok(())
}
/// Rejects parsed statements whose combined node count exceeds the AST budget.
///
/// Nodes are only counted when `options.max_ast_nodes` is set.
///
/// # Errors
///
/// Returns `E_GUARD_AST_BUDGET_EXCEEDED` when the sum of `node_count` over
/// `expressions` is greater than the configured limit.
fn enforce_ast_guard(expressions: &[Expression], options: &FormatGuardOptions) -> Result<()> {
    let Some(budget) = options.max_ast_nodes else {
        return Ok(());
    };

    let total_nodes = expressions.iter().map(node_count).sum::<usize>();
    if total_nodes <= budget {
        return Ok(());
    }
    Err(format_guard_error(
        "E_GUARD_AST_BUDGET_EXCEEDED",
        total_nodes,
        budget,
    ))
}
/// Shared implementation of the guarded formatting entry points.
///
/// Runs the input-size guard, parses with the token-level guards, runs the
/// AST guard, and then pretty-prints each statement with `dialect`. The
/// result holds one string per parsed statement, in input order.
///
/// # Errors
///
/// Returns the first guard, tokenizer, parser, or generator error encountered.
fn format_with_dialect(
    sql: &str,
    dialect: &Dialect,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    enforce_input_guard(sql, options)?;
    let statements = parse_with_token_guard(sql, dialect, options)?;
    enforce_ast_guard(&statements, options)?;

    let mut formatted = Vec::with_capacity(statements.len());
    for statement in &statements {
        formatted.push(dialect.generate_pretty(statement)?);
    }
    Ok(formatted)
}
/// Transpiles `sql` from the `read` dialect to the `write` dialect.
///
/// Delegates to `Dialect::transpile` on the dialect obtained for `read`.
///
/// # Errors
///
/// Returns whatever error `Dialect::transpile` reports.
pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
    let source_dialect = Dialect::get(read);
    source_dialect.transpile(sql, write)
}
/// Parses `sql` with the given dialect into a list of expressions.
///
/// # Errors
///
/// Returns whatever error the dialect's parser reports.
pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
    Dialect::get(dialect).parse(sql)
}
pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
let mut expressions = parse(sql, dialect)?;
if expressions.len() != 1 {
return Err(Error::parse(
format!("Expected 1 statement, found {}", expressions.len()),
0,
0,
0,
0,
));
}
Ok(expressions.remove(0))
}
/// Renders `expression` as SQL text for the given dialect.
///
/// # Errors
///
/// Returns whatever error the dialect's generator reports.
pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
    Dialect::get(dialect).generate(expression)
}
/// Pretty-prints `sql` using the default [`FormatGuardOptions`].
///
/// Returns one formatted string per statement.
///
/// # Errors
///
/// See [`format_with_options`].
pub fn format(sql: &str, dialect: DialectType) -> Result<Vec<String>> {
    let defaults = FormatGuardOptions::default();
    format_with_options(sql, dialect, &defaults)
}
/// Pretty-prints `sql` for `dialect`, enforcing the given guard limits.
///
/// Returns one formatted string per statement.
///
/// # Errors
///
/// Returns an `E_GUARD_*` error when a configured limit is exceeded, or the
/// underlying tokenizer, parser, or generator error.
pub fn format_with_options(
    sql: &str,
    dialect: DialectType,
    options: &FormatGuardOptions,
) -> Result<Vec<String>> {
    let target = Dialect::get(dialect);
    format_with_dialect(sql, &target, options)
}
/// Validates `sql` with the default [`ValidationOptions`], in which strict
/// syntax checking is off.
pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
    let defaults = ValidationOptions::default();
    validate_with_options(sql, dialect, &defaults)
}
/// Options controlling syntax validation.
///
/// Field names are serialized in `camelCase`; missing fields take their
/// `Default` value.
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ValidationOptions {
    /// When `true`, additionally rejects trailing commas that directly
    /// precede a clause keyword or the end of a statement.
    #[serde(default)]
    pub strict_syntax: bool,
}
pub fn validate_with_options(
sql: &str,
dialect: DialectType,
options: &ValidationOptions,
) -> ValidationResult {
let d = Dialect::get(dialect);
match d.parse(sql) {
Ok(expressions) => {
for expr in &expressions {
if !expr.is_statement() {
let msg = format!("Invalid expression / Unexpected token");
return ValidationResult::with_errors(vec![ValidationError::error(
msg, "E004",
)]);
}
}
if options.strict_syntax {
if let Some(error) = strict_syntax_error(sql, &d) {
return ValidationResult::with_errors(vec![error]);
}
}
ValidationResult::success()
}
Err(e) => {
let error = match &e {
Error::Syntax {
message,
line,
column,
start,
end,
} => ValidationError::error(message.clone(), "E001")
.with_location(*line, *column)
.with_span(Some(*start), Some(*end)),
Error::Tokenize {
message,
line,
column,
start,
end,
} => ValidationError::error(message.clone(), "E002")
.with_location(*line, *column)
.with_span(Some(*start), Some(*end)),
Error::Parse {
message,
line,
column,
start,
end,
} => ValidationError::error(message.clone(), "E003")
.with_location(*line, *column)
.with_span(Some(*start), Some(*end)),
_ => ValidationError::error(e.to_string(), "E000"),
};
ValidationResult::with_errors(vec![error])
}
}
}
/// Looks for a trailing comma that strict syntax mode forbids.
///
/// A comma is rejected when the token immediately after it is a clause
/// keyword (`FROM`, `WHERE`, `GROUP BY`, `HAVING`, `ORDER BY`, `LIMIT`,
/// `OFFSET`, `UNION`, `INTERSECT`, `EXCEPT`, `QUALIFY`, `WINDOW`), a
/// semicolon, or the end of the token stream. Only the first offending comma
/// is reported, as an `E005` error located at that comma.
///
/// Returns `None` when no such comma exists or when `sql` fails to tokenize.
fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
    let tokens = dialect.tokenize(sql).ok()?;

    tokens
        .iter()
        .enumerate()
        .filter(|(_, token)| token.token_type == TokenType::Comma)
        .find_map(|(idx, comma)| {
            let following = tokens.get(idx + 1).map(|t| t.token_type);
            let boundary_name = match following {
                Some(TokenType::From) => "FROM",
                Some(TokenType::Where) => "WHERE",
                Some(TokenType::GroupBy) => "GROUP BY",
                Some(TokenType::Having) => "HAVING",
                Some(TokenType::Order) | Some(TokenType::OrderBy) => "ORDER BY",
                Some(TokenType::Limit) => "LIMIT",
                Some(TokenType::Offset) => "OFFSET",
                Some(TokenType::Union) => "UNION",
                Some(TokenType::Intersect) => "INTERSECT",
                Some(TokenType::Except) => "EXCEPT",
                Some(TokenType::Qualify) => "QUALIFY",
                Some(TokenType::Window) => "WINDOW",
                Some(TokenType::Semicolon) | None => "end of statement",
                // Any other follower means the comma separates list items.
                _ => return None,
            };
            let message = format!(
                "Trailing comma before {} is not allowed in strict syntax mode",
                boundary_name
            );
            Some(
                ValidationError::error(message, "E005")
                    .with_location(comma.span.line, comma.span.column),
            )
        })
}
/// Transpiles `sql` between two dialects identified by name, using the
/// default [`TranspileOptions`].
///
/// # Errors
///
/// See [`transpile_with_by_name`].
pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
    let defaults = TranspileOptions::default();
    transpile_with_by_name(sql, read, write, &defaults)
}
pub fn transpile_with_by_name(
sql: &str,
read: &str,
write: &str,
opts: &TranspileOptions,
) -> Result<Vec<String>> {
let read_dialect = Dialect::get_by_name(read)
.ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0, 0, 0))?;
let write_dialect = Dialect::get_by_name(write)
.ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0, 0, 0))?;
read_dialect.transpile_with(sql, &write_dialect, opts.clone())
}
pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
let d = Dialect::get_by_name(dialect)
.ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
d.parse(sql)
}
pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
let d = Dialect::get_by_name(dialect)
.ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
d.generate(expression)
}
/// Pretty-prints `sql` for the dialect identified by name, using the default
/// [`FormatGuardOptions`].
///
/// # Errors
///
/// See [`format_with_options_by_name`].
pub fn format_by_name(sql: &str, dialect: &str) -> Result<Vec<String>> {
    let defaults = FormatGuardOptions::default();
    format_with_options_by_name(sql, dialect, &defaults)
}
pub fn format_with_options_by_name(
sql: &str,
dialect: &str,
options: &FormatGuardOptions,
) -> Result<Vec<String>> {
let d = Dialect::get_by_name(dialect)
.ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0, 0, 0))?;
format_with_dialect(sql, &d, options)
}
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// Validates `sql` against the generic dialect with strict syntax enabled.
    fn validate_strict(sql: &str) -> ValidationResult {
        let options = ValidationOptions {
            strict_syntax: true,
        };
        validate_with_options(sql, DialectType::Generic, &options)
    }

    /// Asserts that `result` is invalid and carries a trailing-comma error.
    fn assert_trailing_comma_rejected(result: &ValidationResult) {
        assert!(!result.valid, "Result should be invalid");
        assert!(
            result.errors.iter().any(|e| e.code == "E005"),
            "Expected E005, got: {:?}",
            result.errors
        );
    }

    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let result = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(result.valid, "Result: {:?}", result.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        let result = validate_strict("SELECT name, FROM employees");
        assert_trailing_comma_rejected(&result);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        let result = validate_strict("SELECT name FROM employees, WHERE salary > 10");
        assert_trailing_comma_rejected(&result);
    }
}
#[cfg(test)]
mod format_tests {
    use super::*;

    /// Guard options with every limit disabled; each test enables just one.
    fn no_limits() -> FormatGuardOptions {
        FormatGuardOptions {
            max_input_bytes: None,
            max_tokens: None,
            max_ast_nodes: None,
            max_set_op_chain: None,
        }
    }

    /// Formats `sql` and asserts it fails with an error mentioning `code`.
    fn assert_guard_error(
        sql: &str,
        dialect: DialectType,
        options: &FormatGuardOptions,
        code: &str,
    ) {
        let err = format_with_options(sql, dialect, options).expect_err("expected guard error");
        assert!(err.to_string().contains(code));
    }

    #[test]
    fn format_basic_query() {
        let result = format("SELECT a,b FROM t", DialectType::Generic).expect("format failed");
        assert_eq!(result.len(), 1);
        assert!(result[0].contains('\n'));
    }

    #[test]
    fn format_guard_rejects_large_input() {
        let options = FormatGuardOptions {
            max_input_bytes: Some(7),
            ..no_limits()
        };
        assert_guard_error(
            "SELECT 1",
            DialectType::Generic,
            &options,
            "E_GUARD_INPUT_TOO_LARGE",
        );
    }

    #[test]
    fn format_guard_rejects_token_budget() {
        let options = FormatGuardOptions {
            max_tokens: Some(1),
            ..no_limits()
        };
        assert_guard_error(
            "SELECT 1",
            DialectType::Generic,
            &options,
            "E_GUARD_TOKEN_BUDGET_EXCEEDED",
        );
    }

    #[test]
    fn format_guard_rejects_ast_budget() {
        let options = FormatGuardOptions {
            max_ast_nodes: Some(1),
            ..no_limits()
        };
        assert_guard_error(
            "SELECT 1",
            DialectType::Generic,
            &options,
            "E_GUARD_AST_BUDGET_EXCEEDED",
        );
    }

    #[test]
    fn format_guard_rejects_set_op_chain_budget() {
        let options = FormatGuardOptions {
            max_set_op_chain: Some(1),
            ..no_limits()
        };
        assert_guard_error(
            "SELECT 1 UNION ALL SELECT 2 UNION ALL SELECT 3",
            DialectType::Generic,
            &options,
            "E_GUARD_SET_OP_CHAIN_EXCEEDED",
        );
    }

    #[test]
    fn format_guard_does_not_treat_clickhouse_minus_function_as_set_op() {
        // With a zero budget, any token counted as a set operation would fail.
        let options = FormatGuardOptions {
            max_set_op_chain: Some(0),
            ..no_limits()
        };
        let result = format_with_options("SELECT minus(3, 2)", DialectType::ClickHouse, &options);
        assert!(result.is_ok(), "Result: {:?}", result);
    }

    #[test]
    fn issue57_invalid_ternary_returns_error() {
        let sql = "SELECT x > 0 ? 1 : 0 FROM t";

        let parse_result = parse(sql, DialectType::PostgreSQL);
        assert!(
            parse_result.is_err(),
            "Expected parse error for invalid ternary SQL, got: {:?}",
            parse_result
        );

        let format_result = format(sql, DialectType::PostgreSQL);
        assert!(
            format_result.is_err(),
            "Expected format error for invalid ternary SQL, got: {:?}",
            format_result
        );

        let transpile_result = transpile(sql, DialectType::PostgreSQL, DialectType::PostgreSQL);
        assert!(
            transpile_result.is_err(),
            "Expected transpile error for invalid ternary SQL, got: {:?}",
            transpile_result
        );
    }

    #[test]
    fn transpile_applies_cross_dialect_rewrites() {
        let timestamp_out = transpile(
            "SELECT to_timestamp(col) FROM t",
            DialectType::DuckDB,
            DialectType::Trino,
        )
        .expect("transpile failed");
        assert_eq!(timestamp_out[0], "SELECT FROM_UNIXTIME(col) FROM t");

        let json_out = transpile(
            "SELECT CAST(col AS JSON) FROM t",
            DialectType::DuckDB,
            DialectType::Trino,
        )
        .expect("transpile failed");
        assert_eq!(json_out[0], "SELECT JSON_PARSE(col) FROM t");
    }

    #[test]
    fn transpile_matches_dialect_method() {
        // (read, write, read name, write name, sql)
        let cases: &[(DialectType, DialectType, &str, &str, &str)] = &[
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT to_timestamp(col) FROM t",
            ),
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT CAST(col AS JSON) FROM t",
            ),
            (
                DialectType::DuckDB,
                DialectType::Trino,
                "duckdb",
                "trino",
                "SELECT json_valid(col) FROM t",
            ),
            (
                DialectType::Snowflake,
                DialectType::DuckDB,
                "snowflake",
                "duckdb",
                "SELECT DATEDIFF(day, a, b) FROM t",
            ),
            (
                DialectType::BigQuery,
                DialectType::DuckDB,
                "bigquery",
                "duckdb",
                "SELECT DATE_DIFF(a, b, DAY) FROM t",
            ),
            (
                DialectType::Generic,
                DialectType::Generic,
                "generic",
                "generic",
                "SELECT 1",
            ),
        ];

        for &(read, write, read_name, write_name, sql) in cases {
            let via_lib = transpile(sql, read, write).expect("lib::transpile failed");
            let via_name = transpile_by_name(sql, read_name, write_name)
                .expect("lib::transpile_by_name failed");
            let via_dialect = Dialect::get(read)
                .transpile(sql, write)
                .expect("Dialect::transpile failed");

            assert_eq!(
                via_lib, via_dialect,
                "lib::transpile / Dialect::transpile diverged for {:?} -> {:?}: {sql}",
                read, write
            );
            assert_eq!(
                via_name, via_dialect,
                "lib::transpile_by_name / Dialect::transpile diverged for {read_name} -> {write_name}: {sql}"
            );
        }
    }

    #[test]
    fn format_default_guard_rejects_deep_union_chain_before_parse() {
        let base = "SELECT col0, col1 FROM t";
        // 1101 selects joined by 1100 `UNION ALL` operators.
        let sql = vec![base; 1101].join(" UNION ALL ");

        let err = format(&sql, DialectType::Athena).expect_err("expected guard error");
        assert!(err.to_string().contains("E_GUARD_SET_OP_CHAIN_EXCEEDED"));
    }
}