1pub mod ast_transforms;
16pub mod builder;
17pub mod dialects;
18pub mod diff;
19pub mod error;
20pub mod expressions;
21mod function_registry;
22pub mod generator;
23pub mod helper;
24pub mod lineage;
25pub mod optimizer;
26pub mod parser;
27pub mod planner;
28pub mod resolver;
29pub mod schema;
30pub mod scope;
31pub mod time;
32pub mod tokens;
33pub mod transforms;
34pub mod traversal;
35pub mod trie;
36pub mod validation;
37
38use serde::{Deserialize, Serialize};
39
40pub use ast_transforms::{
41 add_select_columns, add_where, get_aggregate_functions, get_column_names, get_functions,
42 get_identifiers, get_literals, get_subqueries, get_table_names, get_window_functions,
43 node_count, qualify_columns, remove_limit_offset, remove_nodes, remove_select_columns,
44 remove_where, rename_columns, rename_tables, replace_by_type, replace_nodes, set_distinct,
45 set_limit, set_offset,
46};
47pub use dialects::{unregister_custom_dialect, CustomDialectBuilder, Dialect, DialectType};
48pub use error::{Error, Result, ValidationError, ValidationResult, ValidationSeverity};
49pub use expressions::Expression;
50pub use generator::Generator;
51pub use helper::{
52 csv, find_new_name, is_date_unit, is_float, is_int, is_iso_date, is_iso_datetime, merge_ranges,
53 name_sequence, seq_get, split_num_words, tsort, while_changing, DATE_UNITS,
54};
55pub use optimizer::{annotate_types, TypeAnnotator, TypeCoercionClass};
56pub use parser::Parser;
57pub use resolver::{is_column_ambiguous, resolve_column, Resolver, ResolverError, ResolverResult};
58pub use schema::{
59 ensure_schema, from_simple_map, normalize_name, MappingSchema, Schema, SchemaError,
60};
61pub use scope::{
62 build_scope, find_all_in_scope, find_in_scope, traverse_scope, walk_in_scope, ColumnRef, Scope,
63 ScopeType, SourceInfo,
64};
65pub use time::{format_time, is_valid_timezone, subsecond_precision, TIMEZONES};
66pub use tokens::{Token, TokenType, Tokenizer};
67pub use traversal::{
68 contains_aggregate,
69 contains_subquery,
70 contains_window_function,
71 find_ancestor,
72 find_parent,
73 get_columns,
74 get_tables,
75 is_add,
76 is_aggregate,
77 is_alias,
78 is_alter_table,
79 is_and,
80 is_arithmetic,
81 is_avg,
82 is_between,
83 is_boolean,
84 is_case,
85 is_cast,
86 is_coalesce,
87 is_column,
88 is_comparison,
89 is_concat,
90 is_count,
91 is_create_index,
92 is_create_table,
93 is_create_view,
94 is_cte,
95 is_ddl,
96 is_delete,
97 is_div,
98 is_drop_index,
99 is_drop_table,
100 is_drop_view,
101 is_eq,
102 is_except,
103 is_exists,
104 is_from,
105 is_function,
106 is_group_by,
107 is_gt,
108 is_gte,
109 is_having,
110 is_identifier,
111 is_ilike,
112 is_in,
113 is_insert,
115 is_intersect,
116 is_is_null,
117 is_join,
118 is_like,
119 is_limit,
120 is_literal,
121 is_logical,
122 is_lt,
123 is_lte,
124 is_max_func,
125 is_min_func,
126 is_mod,
127 is_mul,
128 is_neq,
129 is_not,
130 is_null_if,
131 is_null_literal,
132 is_offset,
133 is_or,
134 is_order_by,
135 is_ordered,
136 is_paren,
137 is_query,
139 is_safe_cast,
140 is_select,
141 is_set_operation,
142 is_star,
143 is_sub,
144 is_subquery,
145 is_sum,
146 is_table,
147 is_try_cast,
148 is_union,
149 is_update,
150 is_where,
151 is_window_function,
152 is_with,
153 transform,
154 transform_map,
155 BfsIter,
156 DfsIter,
157 ExpressionWalk,
158 ParentInfo,
159 TreeContext,
160};
161pub use trie::{new_trie, new_trie_from_keys, Trie, TrieResult};
162pub use validation::{
163 validate_with_schema, SchemaColumn, SchemaColumnReference, SchemaForeignKey, SchemaTable,
164 SchemaTableReference, SchemaValidationOptions, ValidationSchema,
165};
166
167pub fn transpile(sql: &str, read: DialectType, write: DialectType) -> Result<Vec<String>> {
188 let read_dialect = Dialect::get(read);
189 let write_dialect = Dialect::get(write);
190 let generic_identity = read == DialectType::Generic && write == DialectType::Generic;
191
192 let expressions = read_dialect.parse(sql)?;
193
194 expressions
195 .into_iter()
196 .map(|expr| {
197 if generic_identity {
198 write_dialect.generate_with_source(&expr, read)
199 } else {
200 let transformed = write_dialect.transform(expr)?;
201 write_dialect.generate_with_source(&transformed, read)
202 }
203 })
204 .collect()
205}
206
207pub fn parse(sql: &str, dialect: DialectType) -> Result<Vec<Expression>> {
216 let d = Dialect::get(dialect);
217 d.parse(sql)
218}
219
220pub fn parse_one(sql: &str, dialect: DialectType) -> Result<Expression> {
229 let mut expressions = parse(sql, dialect)?;
230
231 if expressions.len() != 1 {
232 return Err(Error::parse(
233 format!("Expected 1 statement, found {}", expressions.len()),
234 0,
235 0,
236 ));
237 }
238
239 Ok(expressions.remove(0))
240}
241
242pub fn generate(expression: &Expression, dialect: DialectType) -> Result<String> {
251 let d = Dialect::get(dialect);
252 d.generate(expression)
253}
254
255pub fn validate(sql: &str, dialect: DialectType) -> ValidationResult {
264 validate_with_options(sql, dialect, &ValidationOptions::default())
265}
266
267#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize)]
269#[serde(rename_all = "camelCase")]
270pub struct ValidationOptions {
271 #[serde(default)]
274 pub strict_syntax: bool,
275}
276
277pub fn validate_with_options(
279 sql: &str,
280 dialect: DialectType,
281 options: &ValidationOptions,
282) -> ValidationResult {
283 let d = Dialect::get(dialect);
284 match d.parse(sql) {
285 Ok(expressions) => {
286 for expr in &expressions {
290 if !expr.is_statement() {
291 let msg = format!("Invalid expression / Unexpected token");
292 return ValidationResult::with_errors(vec![ValidationError::error(
293 msg, "E004",
294 )]);
295 }
296 }
297 if options.strict_syntax {
298 if let Some(error) = strict_syntax_error(sql, &d) {
299 return ValidationResult::with_errors(vec![error]);
300 }
301 }
302 ValidationResult::success()
303 }
304 Err(e) => {
305 let error = match &e {
306 Error::Syntax {
307 message,
308 line,
309 column,
310 } => ValidationError::error(message.clone(), "E001").with_location(*line, *column),
311 Error::Tokenize {
312 message,
313 line,
314 column,
315 } => ValidationError::error(message.clone(), "E002").with_location(*line, *column),
316 Error::Parse {
317 message,
318 line,
319 column,
320 } => ValidationError::error(message.clone(), "E003").with_location(*line, *column),
321 _ => ValidationError::error(e.to_string(), "E000"),
322 };
323 ValidationResult::with_errors(vec![error])
324 }
325 }
326}
327
328fn strict_syntax_error(sql: &str, dialect: &Dialect) -> Option<ValidationError> {
329 let tokens = dialect.tokenize(sql).ok()?;
330
331 for (idx, token) in tokens.iter().enumerate() {
332 if token.token_type != TokenType::Comma {
333 continue;
334 }
335
336 let next = tokens.get(idx + 1);
337 let (is_boundary, boundary_name) = match next.map(|t| t.token_type) {
338 Some(TokenType::From) => (true, "FROM"),
339 Some(TokenType::Where) => (true, "WHERE"),
340 Some(TokenType::GroupBy) => (true, "GROUP BY"),
341 Some(TokenType::Having) => (true, "HAVING"),
342 Some(TokenType::Order) | Some(TokenType::OrderBy) => (true, "ORDER BY"),
343 Some(TokenType::Limit) => (true, "LIMIT"),
344 Some(TokenType::Offset) => (true, "OFFSET"),
345 Some(TokenType::Union) => (true, "UNION"),
346 Some(TokenType::Intersect) => (true, "INTERSECT"),
347 Some(TokenType::Except) => (true, "EXCEPT"),
348 Some(TokenType::Qualify) => (true, "QUALIFY"),
349 Some(TokenType::Window) => (true, "WINDOW"),
350 Some(TokenType::Semicolon) | None => (true, "end of statement"),
351 _ => (false, ""),
352 };
353
354 if is_boundary {
355 let message = format!(
356 "Trailing comma before {} is not allowed in strict syntax mode",
357 boundary_name
358 );
359 return Some(
360 ValidationError::error(message, "E005")
361 .with_location(token.span.line, token.span.column),
362 );
363 }
364 }
365
366 None
367}
368
369pub fn transpile_by_name(sql: &str, read: &str, write: &str) -> Result<Vec<String>> {
382 let read_dialect = Dialect::get_by_name(read)
383 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", read), 0, 0))?;
384 let write_dialect = Dialect::get_by_name(write)
385 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", write), 0, 0))?;
386 let generic_identity = read_dialect.dialect_type() == DialectType::Generic
387 && write_dialect.dialect_type() == DialectType::Generic;
388
389 let expressions = read_dialect.parse(sql)?;
390
391 expressions
392 .into_iter()
393 .map(|expr| {
394 if generic_identity {
395 write_dialect.generate_with_source(&expr, read_dialect.dialect_type())
396 } else {
397 let transformed = write_dialect.transform(expr)?;
398 write_dialect.generate_with_source(&transformed, read_dialect.dialect_type())
399 }
400 })
401 .collect()
402}
403
404pub fn parse_by_name(sql: &str, dialect: &str) -> Result<Vec<Expression>> {
408 let d = Dialect::get_by_name(dialect)
409 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0))?;
410 d.parse(sql)
411}
412
413pub fn generate_by_name(expression: &Expression, dialect: &str) -> Result<String> {
417 let d = Dialect::get_by_name(dialect)
418 .ok_or_else(|| Error::parse(format!("Unknown dialect: {}", dialect), 0, 0))?;
419 d.generate(expression)
420}
421
#[cfg(test)]
mod validation_tests {
    use super::*;

    /// Builds options with the strict syntax check switched on.
    fn strict() -> ValidationOptions {
        ValidationOptions {
            strict_syntax: true,
        }
    }

    /// Validates `sql` in strict mode against the generic dialect and asserts
    /// that it is rejected with error code `E005`.
    fn assert_strict_rejects(sql: &str) {
        let outcome = validate_with_options(sql, DialectType::Generic, &strict());

        assert!(!outcome.valid, "Result should be invalid");

        let reports_trailing_comma = outcome.errors.iter().any(|e| e.code == "E005");
        assert!(
            reports_trailing_comma,
            "Expected E005, got: {:?}",
            outcome.errors
        );
    }

    // Default validation tolerates a trailing comma in the select list.
    #[test]
    fn validate_is_permissive_by_default_for_trailing_commas() {
        let outcome = validate("SELECT name, FROM employees", DialectType::Generic);
        assert!(outcome.valid, "Result: {:?}", outcome.errors);
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_from() {
        assert_strict_rejects("SELECT name, FROM employees");
    }

    #[test]
    fn validate_with_options_rejects_trailing_comma_before_where() {
        assert_strict_rejects("SELECT name FROM employees, WHERE salary > 10");
    }
}