1pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21mod function_registry;
22pub mod function_catalog;
23pub mod generator;
24pub mod helper;
25pub mod lineage;
26pub mod optimizer;
27pub mod parser;
28pub mod planner;
29pub mod resolver;
30pub mod schema;
31pub mod scope;
32pub mod time;
33pub mod tokens;
34pub mod transforms;
35pub mod traversal;
36pub mod trie;
37pub mod validation;
38
39use serde::{Deserialize, Serialize};
40
41pub use ast_transforms::{
42 add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
43 get_identifiers, get_literals, get_subqueries, get_table_names, get_window_functions,
44 node_count, qualify_columns, remove_limit_offset, remove_nodes, remove_select_columns,
45 remove_where, rename_columns, rename_tables, replace_by_type, replace_nodes, set_distinct,
46 set_limit, set_offset,
47};
48pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
49pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
50pub use expressions::Expression;
51pub use function_catalog::{
52 FunctionCatalog, FunctionNameCase, FunctionSignature, HashMapFunctionCatalog,
53};
54pub use generator::Generator;
55pub use helper::{
56 csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
57 name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
58};
59pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
60pub use parser::Parser;
61pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
62pub use schema::{
63 ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
64};
65pub use scope::{
66 build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
67 ScopeType, SourceInfo,
68};
69pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
70pub use tokens::{Token, TokenType, Tokenizer};
71pub use traversal::{
72 contains_aggregate,
73 contains_subquery,
74 contains_window_function,
75 find_ancestor,
76 find_parent,
77 get_columns,
78 get_tables,
79 is_add,
80 is_aggregate,
81 is_alias,
82 is_alter_table,
83 is_and,
84 is_arithmetic,
85 is_avg,
86 is_between,
87 is_boolean,
88 is_case,
89 is_cast,
90 is_coalesce,
91 is_column,
92 is_comparison,
93 is_concat,
94 is_count,
95 is_create_index,
96 is_create_table,
97 is_create_view,
98 is_cte,
99 is_ddl,
100 is_delete,
101 is_div,
102 is_drop_index,
103 is_drop_table,
104 is_drop_view,
105 is_eq,
106 is_except,
107 is_exists,
108 is_from,
109 is_function,
110 is_group_by,
111 is_gt,
112 is_gte,
113 is_having,
114 is_identifier,
115 is_ilike,
116 is_in,
117 is_insert,
119 is_intersect,
120 is_is_null,
121 is_join,
122 is_like,
123 is_limit,
124 is_literal,
125 is_logical,
126 is_lt,
127 is_lte,
128 is_max_func,
129 is_min_func,
130 is_mod,
131 is_mul,
132 is_neq,
133 is_not,
134 is_null_if,
135 is_null_literal,
136 is_offset,
137 is_or,
138 is_order_by,
139 is_ordered,
140 is_paren,
141 is_query,
143 is_safe_cast,
144 is_select,
145 is_set_operation,
146 is_star,
147 is_sub,
148 is_subquery,
149 is_sum,
150 is_table,
151 is_try_cast,
152 is_union,
153 is_update,
154 is_where,
155 is_window_function,
156 is_with,
157 transform,
158 transform_map,
159 BfsIter,
160 DfsIter,
161 ExpressionWalk,
162 ParentInfo,
163 TreeContext,
164};
165pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
166pub use validation::{
167 validate_with_schema, SchemaColumn, SchemaColumnReference, SchemaForeignKey, SchemaTable,
168 SchemaTableReference, SchemaValidationOptions, ValidationSchema,
169};
170
171pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
192 let read_dialect = Dialect::get(read);
193 let write_dialect = Dialect::get(write);
194 let generic_identity = read == DialectType::Generic && write == DialectType::Generic;
195
196 let expressions = read_dialect.parse(sql)?;
197
198 expressions
199 .into_iter()
200 .map(|expr| {
201 if generic_identity {
202 write_dialect.generate_with_source(&expr, read)
203 } else {
204 let transformed = write_dialect.transform(expr)?;
205 write_dialect.generate_with_source(&transformed, read)
206 }
207 })
208 .collect()
209}
210
211pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
220 let d = Dialect::get(dialect);
221 d.parse(sql)
222}
223
224pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
233 let mut expressions = parse(sql, dialect)?;
234
235 if expressions.len() != 1 {
236 return Err(Error::parse(
237 format!("Expected 1 statement, found {}", expressions.len()),
238 0,
239 0,
240 ));
241 }
242
243 Ok(expressions.remove(0))
244}
245
246pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
255 let d = Dialect::get(dialect);
256 d.generate(expression)
257}
258
259pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
268 validate_with_options(sql, dialect, &ValidationOptions::default())
269}
270
271#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
273#[serde(rename_all = "camelCase")]
274pub struct ValidationOptions {
275 #[serde(default)]
278 pub strict_syntax: bool,
279}
280
281pub fn validate_with_options(
283 sql: &str,
284 dialect: DialectType,
285 options: &ValidationOptions,
286) -> ValidationResult {
287 let d = Dialect::get(dialect);
288 match d.parse(sql) {
289 Ok(expressions) => {
290 for expr in &expressions {
294 if !expr.is_statement() {
295 let msg = format!("Invalid expression / Unexpected token");
296 return ValidationResult::with_errors(vec![ValidationError::error(
297 msg, "E004",
298 )]);
299 }
300 }
301 if options.strict_syntax {
302 if let Some(error) = strict_syntax_error(sql, &d) {
303 return ValidationResult::with_errors(vec![error]);
304 }
305 }
306 ValidationResult::success()
307 }
308 Err(e) => {
309 let error = match &e {
310 Error::Syntax {
311 message,
312 line,
313 column,
314 } => ValidationError::error(message.clone(), "E001").with_location(*line, *column),
315 Error::Tokenize {
316 message,
317 line,
318 column,
319 } => ValidationError::error(message.clone(), "E002").with_location(*line, *column),
320 Error::Parse {
321 message,
322 line,
323 column,
324 } => ValidationError::error(message.clone(), "E003").with_location(*line, *column),
325 _ => ValidationError::error(e.to_string(), "E000"),
326 };
327 ValidationResult::with_errors(vec![error])
328 }
329 }
330}
331
332fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
333 let tokens = dialect.tokenize(sql).ok()?;
334
335 for (idx, token) in tokens.iter().enumerate() {
336 if token.token_type != TokenType::Comma {
337 continue;
338 }
339
340 let next = tokens.get(idx + 1);
341 let (is_boundary, boundary_name) = match next.map(|t| t.token_type) {
342 Some(TokenType::From) => (true, "FROM"),
343 Some(TokenType::Where) => (true, "WHERE"),
344 Some(TokenType::GroupBy) => (true, "GROUP BY"),
345 Some(TokenType::Having) => (true, "HAVING"),
346 Some(TokenType::Order) | Some(TokenType::OrderBy) => (true, "ORDER BY"),
347 Some(TokenType::Limit) => (true, "LIMIT"),
348 Some(TokenType::Offset) => (true, "OFFSET"),
349 Some(TokenType::Union) => (true, "UNION"),
350 Some(TokenType::Intersect) => (true, "INTERSECT"),
351 Some(TokenType::Except) => (true, "EXCEPT"),
352 Some(TokenType::Qualify) => (true, "QUALIFY"),
353 Some(TokenType::Window) => (true, "WINDOW"),
354 Some(TokenType::Semicolon) | None => (true, "end of statement"),
355 _ => (false, ""),
356 };
357
358 if is_boundary {
359 let message = format!(
360 "Trailing comma before {} is not allowed in strict syntax mode",
361 boundary_name
362 );
363 return Some(
364 ValidationError::error(message, "E005")
365 .with_location(token.span.line, token.span.column),
366 );
367 }
368 }
369
370 None
371}
372
373pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
386 let read_dialect = Dialect::get_by_name(read)
387 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0))?;
388 let write_dialect = Dialect::get_by_name(write)
389 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0))?;
390 let generic_identity = read_dialect.dialect_type() == DialectType::Generic
391 && write_dialect.dialect_type() == DialectType::Generic;
392
393 let expressions = read_dialect.parse(sql)?;
394
395 expressions
396 .into_iter()
397 .map(|expr| {
398 if generic_identity {
399 write_dialect.generate_with_source(&expr, read_dialect.dialect_type())
400 } else {
401 let transformed = write_dialect.transform(expr)?;
402 write_dialect.generate_with_source(&transformed, read_dialect.dialect_type())
403 }
404 })
405 .collect()
406}
407
408pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
412 let d = Dialect::get_by_name(dialect)
413 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0))?;
414 d.parse(sql)
415}
416
417pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
421 let d = Dialect::get_by_name(dialect)
422 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0))?;
423 d.generate(expression)
424}
425
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// Options with strict syntax checking switched on.
    fn strict_options() -> ValidationOptions {
        ValidationOptions {
            strict_syntax: true,
        }
    }

    #[test]
    fn validation_options_default_is_not_strict() {
        assert!(!ValidationOptions::default().strict_syntax);
    }

    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let result = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(result.valid, "Result: {:?}", result.errors);
    }

    #[test]
    fn validate_with_options_accepts_clean_sql_in_strict_mode() {
        // Ordinary list commas must not be mistaken for trailing commas.
        let result = validate_with_options(
            "SELECT name, salary FROM employees WHERE salary > 10",
            DialectType::Generic,
            &strict_options(),
        );
        assert!(result.valid, "Result: {:?}", result.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        let result = validate_with_options(
            "SELECT name, FROM employees",
            DialectType::Generic,
            &strict_options(),
        );
        assert!(!result.valid, "Result should be invalid");
        assert!(
            result.errors.iter().any(|e| e.code == "E005"),
            "Expected E005, got: {:?}",
            result.errors
        );
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        let result = validate_with_options(
            "SELECT name FROM employees, WHERE salary > 10",
            DialectType::Generic,
            &strict_options(),
        );
        assert!(!result.valid, "Result should be invalid");
        assert!(
            result.errors.iter().any(|e| e.code == "E005"),
            "Expected E005, got: {:?}",
            result.errors
        );
    }
}