1use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{
13 AggregateFunction, Cast, DataType, DateAddFunc, Expression, Function, IntervalUnit, Literal,
14 UnaryFunc, VarArgFunc,
15};
16#[cfg(feature = "generate")]
17use crate::generator::GeneratorConfig;
18use crate::tokens::TokenizerConfig;
19
/// Unit type representing the Databricks SQL dialect.
///
/// It carries no state; its behaviour comes from the `DialectImpl`
/// implementation and the private helper methods defined in this file.
pub struct DatabricksDialect;
22
23impl DialectImpl for DatabricksDialect {
    /// Identifies this implementation as [`DialectType::Databricks`].
    fn dialect_type(&self) -> DialectType {
        DialectType::Databricks
    }
27
28 fn tokenizer_config(&self) -> TokenizerConfig {
29 let mut config = TokenizerConfig::default();
30 config.identifiers.clear();
32 config.identifiers.insert('`', '`');
33 config.quotes.insert("\"".to_string(), "\"".to_string());
35 config.string_escapes.push('\\');
37 config
39 .keywords
40 .insert("DIV".to_string(), crate::tokens::TokenType::Div);
41 config
42 .keywords
43 .insert("REPAIR".to_string(), crate::tokens::TokenType::Command);
44 config
45 .keywords
46 .insert("MSCK".to_string(), crate::tokens::TokenType::Command);
47 config
49 .numeric_literals
50 .insert("L".to_string(), "BIGINT".to_string());
51 config
52 .numeric_literals
53 .insert("S".to_string(), "SMALLINT".to_string());
54 config
55 .numeric_literals
56 .insert("Y".to_string(), "TINYINT".to_string());
57 config
58 .numeric_literals
59 .insert("D".to_string(), "DOUBLE".to_string());
60 config
61 .numeric_literals
62 .insert("F".to_string(), "FLOAT".to_string());
63 config
64 .numeric_literals
65 .insert("BD".to_string(), "DECIMAL".to_string());
66 config.identifiers_can_start_with_digit = true;
68 config.string_escapes_allowed_in_raw_strings = false;
71 config
72 }
73
74 #[cfg(feature = "generate")]
75
76 fn generator_config(&self) -> GeneratorConfig {
77 use crate::generator::IdentifierQuoteStyle;
78 GeneratorConfig {
79 identifier_quote: '`',
80 identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
81 dialect: Some(DialectType::Databricks),
82 struct_field_sep: ": ",
83 create_function_return_as: false,
84 tablesample_seed_keyword: "REPEATABLE",
85 identifiers_can_start_with_digit: true,
86 schema_comment_with_eq: false,
88 ..Default::default()
89 }
90 }
91
92 #[cfg(feature = "transpile")]
93
94 fn transform_expr(&self, expr: Expression) -> Result<Expression> {
95 match expr {
96 Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
98 original_name: None,
99 expressions: vec![f.this, f.expression],
100 inferred_type: None,
101 }))),
102
103 Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
105 original_name: None,
106 expressions: vec![f.this, f.expression],
107 inferred_type: None,
108 }))),
109
110 Expression::TryCast(c) => Ok(Expression::TryCast(c)),
112
113 Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
115
116 Expression::ILike(op) => Ok(Expression::ILike(op)),
118
119 Expression::Unnest(f) => Ok(Expression::Explode(Box::new(UnaryFunc::new(f.this)))),
121
122 Expression::Explode(f) => Ok(Expression::Explode(f)),
124
125 Expression::ExplodeOuter(f) => Ok(Expression::ExplodeOuter(f)),
127
128 Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
130 seed: None,
131 lower: None,
132 upper: None,
133 }))),
134
135 Expression::Rand(r) => Ok(Expression::Rand(r)),
137
138 Expression::Concat(op) => Ok(Expression::Function(Box::new(Function::new(
140 "CONCAT".to_string(),
141 vec![op.left, op.right],
142 )))),
143
144 Expression::RegexpLike(op) => Ok(Expression::RegexpLike(op)),
146
147 Expression::Cast(c) => self.transform_cast(*c),
152
153 Expression::Function(f) => self.transform_function(*f),
155
156 Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
158
159 Expression::DateSub(f) => {
161 let val = match f.interval {
163 Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::String(s) if s.parse::<i64>().is_ok()) =>
164 {
165 let crate::expressions::Literal::String(s) = lit.as_ref() else {
166 unreachable!()
167 };
168 Expression::Literal(Box::new(crate::expressions::Literal::Number(
169 s.clone(),
170 )))
171 }
172 other => other,
173 };
174 let neg_val = Expression::Neg(Box::new(crate::expressions::UnaryOp {
175 this: val,
176 inferred_type: None,
177 }));
178 Ok(Expression::Function(Box::new(Function::new(
179 "DATE_ADD".to_string(),
180 vec![f.this, neg_val],
181 ))))
182 }
183
184 Expression::DateAdd(f) => Ok(Self::transform_date_add(*f)),
186
187 _ => Ok(expr),
189 }
190 }
191}
192
193#[cfg(feature = "transpile")]
194impl DatabricksDialect {
195 fn transform_function(&self, f: Function) -> Result<Expression> {
196 let name_upper = f.name.to_uppercase();
197 match name_upper.as_str() {
198 "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
200 original_name: None,
201 expressions: f.args,
202 inferred_type: None,
203 }))),
204
205 "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
207 original_name: None,
208 expressions: f.args,
209 inferred_type: None,
210 }))),
211
212 "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
214 original_name: None,
215 expressions: f.args,
216 inferred_type: None,
217 }))),
218
219 "ROW" => Ok(Expression::Function(Box::new(Function::new(
221 "STRUCT".to_string(),
222 f.args,
223 )))),
224
225 "NAMED_STRUCT" if f.args.len() % 2 == 0 => {
227 let original_args = f.args.clone();
228 let mut struct_args = Vec::new();
229 for pair in f.args.chunks(2) {
230 if let Expression::Literal(lit) = &pair[0] {
231 if let Literal::String(field_name) = lit.as_ref() {
232 struct_args.push(Expression::Alias(Box::new(
233 crate::expressions::Alias {
234 this: pair[1].clone(),
235 alias: crate::expressions::Identifier::new(field_name),
236 column_aliases: Vec::new(),
237 alias_explicit_as: false,
238 alias_keyword: None,
239 pre_alias_comments: Vec::new(),
240 trailing_comments: Vec::new(),
241 inferred_type: None,
242 },
243 )));
244 continue;
245 }
246 }
247 return Ok(Expression::Function(Box::new(Function::new(
248 "NAMED_STRUCT".to_string(),
249 original_args,
250 ))));
251 }
252 Ok(Expression::Function(Box::new(Function::new(
253 "STRUCT".to_string(),
254 struct_args,
255 ))))
256 }
257
258 "GETDATE" => Ok(Expression::CurrentTimestamp(
260 crate::expressions::CurrentTimestamp {
261 precision: None,
262 sysdate: false,
263 },
264 )),
265
266 "NOW" => Ok(Expression::CurrentTimestamp(
268 crate::expressions::CurrentTimestamp {
269 precision: None,
270 sysdate: false,
271 },
272 )),
273
274 "CURDATE" => Ok(Expression::CurrentDate(crate::expressions::CurrentDate)),
276
277 "CURRENT_DATE" if f.args.is_empty() => {
279 Ok(Expression::CurrentDate(crate::expressions::CurrentDate))
280 }
281
282 "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
284 seed: None,
285 lower: None,
286 upper: None,
287 }))),
288
289 "GROUP_CONCAT" if !f.args.is_empty() => {
291 let mut args = f.args;
292 let first = args.remove(0);
293 let separator = args.pop();
294 let collect_list = Expression::Function(Box::new(Function::new(
295 "COLLECT_LIST".to_string(),
296 vec![first],
297 )));
298 if let Some(sep) = separator {
299 Ok(Expression::Function(Box::new(Function::new(
300 "ARRAY_JOIN".to_string(),
301 vec![collect_list, sep],
302 ))))
303 } else {
304 Ok(Expression::Function(Box::new(Function::new(
305 "ARRAY_JOIN".to_string(),
306 vec![collect_list],
307 ))))
308 }
309 }
310
311 "STRING_AGG" if !f.args.is_empty() => {
313 let mut args = f.args;
314 let first = args.remove(0);
315 let separator = args.pop();
316 let collect_list = Expression::Function(Box::new(Function::new(
317 "COLLECT_LIST".to_string(),
318 vec![first],
319 )));
320 if let Some(sep) = separator {
321 Ok(Expression::Function(Box::new(Function::new(
322 "ARRAY_JOIN".to_string(),
323 vec![collect_list, sep],
324 ))))
325 } else {
326 Ok(Expression::Function(Box::new(Function::new(
327 "ARRAY_JOIN".to_string(),
328 vec![collect_list],
329 ))))
330 }
331 }
332
333 "LISTAGG" if !f.args.is_empty() => {
335 let mut args = f.args;
336 let first = args.remove(0);
337 let separator = args.pop();
338 let collect_list = Expression::Function(Box::new(Function::new(
339 "COLLECT_LIST".to_string(),
340 vec![first],
341 )));
342 if let Some(sep) = separator {
343 Ok(Expression::Function(Box::new(Function::new(
344 "ARRAY_JOIN".to_string(),
345 vec![collect_list, sep],
346 ))))
347 } else {
348 Ok(Expression::Function(Box::new(Function::new(
349 "ARRAY_JOIN".to_string(),
350 vec![collect_list],
351 ))))
352 }
353 }
354
355 "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
357 "COLLECT_LIST".to_string(),
358 f.args,
359 )))),
360
361 "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
363 "SUBSTRING".to_string(),
364 f.args,
365 )))),
366
367 "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
369 f.args.into_iter().next().unwrap(),
370 )))),
371
372 "CHARINDEX" if f.args.len() >= 2 => {
374 let mut args = f.args;
375 let substring = args.remove(0);
376 let string = args.remove(0);
377 Ok(Expression::Function(Box::new(Function::new(
379 "LOCATE".to_string(),
380 vec![substring, string],
381 ))))
382 }
383
384 "POSITION" if f.args.len() == 2 => {
386 let args = f.args;
387 Ok(Expression::Function(Box::new(Function::new(
388 "LOCATE".to_string(),
389 args,
390 ))))
391 }
392
393 "STRPOS" if f.args.len() == 2 => {
395 let args = f.args;
396 let string = args[0].clone();
397 let substring = args[1].clone();
398 Ok(Expression::Function(Box::new(Function::new(
400 "LOCATE".to_string(),
401 vec![substring, string],
402 ))))
403 }
404
405 "INSTR" => Ok(Expression::Function(Box::new(f))),
407
408 "LOCATE" => Ok(Expression::Function(Box::new(f))),
410
411 "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
413 Function::new("SIZE".to_string(), f.args),
414 ))),
415
416 "CARDINALITY" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
418 Function::new("SIZE".to_string(), f.args),
419 ))),
420
421 "SIZE" => Ok(Expression::Function(Box::new(f))),
423
424 "ARRAY_CONTAINS" => Ok(Expression::Function(Box::new(f))),
426
427 "CONTAINS" if f.args.len() == 2 => {
430 let is_string_contains = matches!(&f.args[0], Expression::Lower(_))
432 && matches!(&f.args[1], Expression::Lower(_));
433 if is_string_contains {
434 Ok(Expression::Function(Box::new(f)))
435 } else {
436 Ok(Expression::Function(Box::new(Function::new(
437 "ARRAY_CONTAINS".to_string(),
438 f.args,
439 ))))
440 }
441 }
442
443 "TO_DATE" => Ok(Expression::Function(Box::new(f))),
445
446 "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
448
449 "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
451
452 "STRFTIME" if f.args.len() >= 2 => {
454 let mut args = f.args;
455 let format = args.remove(0);
456 let date = args.remove(0);
457 Ok(Expression::Function(Box::new(Function::new(
458 "DATE_FORMAT".to_string(),
459 vec![date, format],
460 ))))
461 }
462
463 "TO_CHAR" => Ok(Expression::Function(Box::new(f))),
465
466 "DATE_TRUNC" => Ok(Expression::Function(Box::new(f))),
468
469 "DATEADD" => {
471 if f.args.len() == 2 {
472 Ok(Expression::Function(Box::new(Function::new(
473 "DATE_ADD".to_string(),
474 f.args,
475 ))))
476 } else {
477 let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
478 let function_name = if matches!(
479 transformed_args.first(),
480 Some(Expression::Identifier(unit))
481 if unit.name.eq_ignore_ascii_case("WEEK")
482 ) {
483 "DATEADD"
484 } else {
485 "DATE_ADD"
486 };
487 Ok(Expression::Function(Box::new(Function::new(
488 function_name.to_string(),
489 transformed_args,
490 ))))
491 }
492 }
493
494 "DATE_ADD" => {
499 if f.args.len() == 2 {
500 let is_simple_number = matches!(
501 &f.args[1],
502 Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::Number(_))
503 ) || matches!(&f.args[1], Expression::Neg(_));
504 if is_simple_number {
505 Ok(Expression::Function(Box::new(Function::new(
507 "DATE_ADD".to_string(),
508 f.args,
509 ))))
510 } else {
511 let mut args = f.args;
512 let date = args.remove(0);
513 let interval = args.remove(0);
514 let unit = Expression::Identifier(crate::expressions::Identifier {
515 name: "DAY".to_string(),
516 quoted: false,
517 trailing_comments: Vec::new(),
518 span: None,
519 });
520 Ok(Expression::Function(Box::new(Function::new(
521 "DATEADD".to_string(),
522 vec![unit, interval, date],
523 ))))
524 }
525 } else {
526 let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
527 Ok(Expression::Function(Box::new(Function::new(
528 "DATE_ADD".to_string(),
529 transformed_args,
530 ))))
531 }
532 }
533
534 "DATEDIFF" => {
538 if f.args.len() == 2 {
539 let mut args = f.args;
540 let end_date = args.remove(0);
541 let start_date = args.remove(0);
542 let unit = Expression::Identifier(crate::expressions::Identifier {
543 name: "DAY".to_string(),
544 quoted: false,
545 trailing_comments: Vec::new(),
546 span: None,
547 });
548 Ok(Expression::Function(Box::new(Function::new(
549 "DATEDIFF".to_string(),
550 vec![unit, start_date, end_date],
551 ))))
552 } else {
553 let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
554 Ok(Expression::Function(Box::new(Function::new(
555 "DATEDIFF".to_string(),
556 transformed_args,
557 ))))
558 }
559 }
560
561 "DATE_DIFF" => {
563 let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
564 Ok(Expression::Function(Box::new(Function::new(
565 "DATEDIFF".to_string(),
566 transformed_args,
567 ))))
568 }
569
570 "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
572
573 "JSON_EXTRACT_SCALAR" => Ok(Expression::Function(Box::new(f))),
575
576 "GET_JSON_OBJECT" if f.args.len() == 2 => {
578 let mut args = f.args;
579 let col = args.remove(0);
580 let path_arg = args.remove(0);
581
582 Ok(Expression::Function(Box::new(Function::new(
583 "GET_JSON_OBJECT".to_string(),
584 vec![col, self.normalize_get_json_object_path(path_arg)],
585 ))))
586 }
587
588 "FROM_JSON" => Ok(Expression::Function(Box::new(f))),
590
591 "PARSE_JSON" => Ok(Expression::Function(Box::new(f))),
593
594 "COLLECT_LIST" => Ok(Expression::Function(Box::new(f))),
596
597 "COLLECT_SET" => Ok(Expression::Function(Box::new(f))),
599
600 "RLIKE" => Ok(Expression::Function(Box::new(f))),
602
603 "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
605 "RLIKE".to_string(),
606 f.args,
607 )))),
608
609 "REGEXP_LIKE" => Ok(Expression::Function(Box::new(f))),
611
612 "LEVENSHTEIN" => Ok(Expression::Function(Box::new(f))),
614
615 "GENERATE_SERIES" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
617 Function::new("SEQUENCE".to_string(), f.args),
618 ))),
619
620 "SEQUENCE" => Ok(Expression::Function(Box::new(f))),
622
623 "FLATTEN" => Ok(Expression::Function(Box::new(f))),
625
626 "ARRAY_SORT" => Ok(Expression::Function(Box::new(f))),
628
629 "ARRAY_DISTINCT" => Ok(Expression::Function(Box::new(f))),
631
632 "TRANSFORM" => Ok(Expression::Function(Box::new(f))),
634
635 "FILTER" => Ok(Expression::Function(Box::new(f))),
637
638 "FROM_UTC_TIMESTAMP" if f.args.len() >= 2 => {
640 let mut args = f.args;
641 let first_arg = args.remove(0);
642
643 let wrapped_arg = if self.is_cast_to_timestamp(&first_arg) {
645 first_arg
646 } else {
647 Expression::Cast(Box::new(Cast {
649 this: first_arg,
650 to: DataType::Timestamp {
651 precision: None,
652 timezone: false,
653 },
654 trailing_comments: Vec::new(),
655 double_colon_syntax: false,
656 format: None,
657 default: None,
658 inferred_type: None,
659 }))
660 };
661
662 let mut new_args = vec![wrapped_arg];
663 new_args.extend(args);
664
665 Ok(Expression::Function(Box::new(Function::new(
666 "FROM_UTC_TIMESTAMP".to_string(),
667 new_args,
668 ))))
669 }
670
671 "UNIFORM" if f.args.len() == 3 => {
673 let mut args = f.args;
674 let low = args.remove(0);
675 let high = args.remove(0);
676 let gen = args.remove(0);
677 match gen {
678 Expression::Function(func) if func.name.to_uppercase() == "RANDOM" => {
679 if func.args.len() == 1 {
680 let seed = func.args.into_iter().next().unwrap();
682 Ok(Expression::Function(Box::new(Function::new(
683 "UNIFORM".to_string(),
684 vec![low, high, seed],
685 ))))
686 } else {
687 Ok(Expression::Function(Box::new(Function::new(
689 "UNIFORM".to_string(),
690 vec![low, high],
691 ))))
692 }
693 }
694 Expression::Rand(r) => {
695 if let Some(seed) = r.seed {
696 Ok(Expression::Function(Box::new(Function::new(
697 "UNIFORM".to_string(),
698 vec![low, high, *seed],
699 ))))
700 } else {
701 Ok(Expression::Function(Box::new(Function::new(
702 "UNIFORM".to_string(),
703 vec![low, high],
704 ))))
705 }
706 }
707 _ => Ok(Expression::Function(Box::new(Function::new(
708 "UNIFORM".to_string(),
709 vec![low, high, gen],
710 )))),
711 }
712 }
713
714 "REGEXP_SUBSTR" if f.args.len() >= 2 => {
716 let subject = f.args[0].clone();
717 let pattern = f.args[1].clone();
718 Ok(Expression::Function(Box::new(Function::new(
719 "REGEXP_EXTRACT".to_string(),
720 vec![subject, pattern],
721 ))))
722 }
723
724 "BIT_GET" => Ok(Expression::Function(Box::new(Function::new(
726 "GETBIT".to_string(),
727 f.args,
728 )))),
729
730 _ => Ok(Expression::Function(Box::new(f))),
732 }
733 }
734
735 fn transform_aggregate_function(
736 &self,
737 f: Box<crate::expressions::AggregateFunction>,
738 ) -> Result<Expression> {
739 let name_upper = f.name.to_uppercase();
740 match name_upper.as_str() {
741 "COUNT_IF" => Ok(Expression::AggregateFunction(f)),
743
744 "ANY_VALUE" => Ok(Expression::AggregateFunction(f)),
746
747 "GROUP_CONCAT" if !f.args.is_empty() => {
749 let mut args = f.args;
750 let first = args.remove(0);
751 let separator = args.pop();
752 let collect_list = Expression::Function(Box::new(Function::new(
753 "COLLECT_LIST".to_string(),
754 vec![first],
755 )));
756 if let Some(sep) = separator {
757 Ok(Expression::Function(Box::new(Function::new(
758 "ARRAY_JOIN".to_string(),
759 vec![collect_list, sep],
760 ))))
761 } else {
762 Ok(Expression::Function(Box::new(Function::new(
763 "ARRAY_JOIN".to_string(),
764 vec![collect_list],
765 ))))
766 }
767 }
768
769 "STRING_AGG" if !f.args.is_empty() => {
771 let mut args = f.args;
772 let first = args.remove(0);
773 let separator = args.pop();
774 let collect_list = Expression::Function(Box::new(Function::new(
775 "COLLECT_LIST".to_string(),
776 vec![first],
777 )));
778 if let Some(sep) = separator {
779 Ok(Expression::Function(Box::new(Function::new(
780 "ARRAY_JOIN".to_string(),
781 vec![collect_list, sep],
782 ))))
783 } else {
784 Ok(Expression::Function(Box::new(Function::new(
785 "ARRAY_JOIN".to_string(),
786 vec![collect_list],
787 ))))
788 }
789 }
790
791 "LISTAGG" if !f.args.is_empty() => {
793 let mut args = f.args;
794 let first = args.remove(0);
795 let separator = args.pop();
796 let collect_list = Expression::Function(Box::new(Function::new(
797 "COLLECT_LIST".to_string(),
798 vec![first],
799 )));
800 if let Some(sep) = separator {
801 Ok(Expression::Function(Box::new(Function::new(
802 "ARRAY_JOIN".to_string(),
803 vec![collect_list, sep],
804 ))))
805 } else {
806 Ok(Expression::Function(Box::new(Function::new(
807 "ARRAY_JOIN".to_string(),
808 vec![collect_list],
809 ))))
810 }
811 }
812
813 "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
815 "COLLECT_LIST".to_string(),
816 f.args,
817 )))),
818
819 "STDDEV" => Ok(Expression::AggregateFunction(f)),
821
822 "VARIANCE" => Ok(Expression::AggregateFunction(f)),
824
825 "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
827
828 "APPROX_DISTINCT" if !f.args.is_empty() => {
830 Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
831 name: "APPROX_COUNT_DISTINCT".to_string(),
832 args: f.args,
833 distinct: f.distinct,
834 filter: f.filter,
835 order_by: Vec::new(),
836 limit: None,
837 ignore_nulls: None,
838 inferred_type: None,
839 })))
840 }
841
842 _ => Ok(Expression::AggregateFunction(f)),
844 }
845 }
846
847 fn transform_cast(&self, c: Cast) -> Result<Expression> {
857 match &c.this {
859 Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Timestamp(_)) => {
861 let Literal::Timestamp(value) = lit.as_ref() else {
862 unreachable!()
863 };
864 let inner_cast = Expression::Cast(Box::new(Cast {
866 this: Expression::Literal(Box::new(Literal::String(value.clone()))),
867 to: c.to,
868 trailing_comments: Vec::new(),
869 double_colon_syntax: false,
870 format: None,
871 default: None,
872 inferred_type: None,
873 }));
874 Ok(Expression::Cast(Box::new(Cast {
876 this: inner_cast,
877 to: DataType::Timestamp {
878 precision: None,
879 timezone: false,
880 },
881 trailing_comments: c.trailing_comments,
882 double_colon_syntax: false,
883 format: None,
884 default: None,
885 inferred_type: None,
886 })))
887 }
888 Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Date(_)) => {
890 let Literal::Date(value) = lit.as_ref() else {
891 unreachable!()
892 };
893 let inner_cast = Expression::Cast(Box::new(Cast {
894 this: Expression::Literal(Box::new(Literal::String(value.clone()))),
895 to: c.to,
896 trailing_comments: Vec::new(),
897 double_colon_syntax: false,
898 format: None,
899 default: None,
900 inferred_type: None,
901 }));
902 Ok(Expression::Cast(Box::new(Cast {
903 this: inner_cast,
904 to: DataType::Date,
905 trailing_comments: c.trailing_comments,
906 double_colon_syntax: false,
907 format: None,
908 default: None,
909 inferred_type: None,
910 })))
911 }
912 Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Time(_)) => {
914 let Literal::Time(value) = lit.as_ref() else {
915 unreachable!()
916 };
917 let inner_cast = Expression::Cast(Box::new(Cast {
918 this: Expression::Literal(Box::new(Literal::String(value.clone()))),
919 to: c.to,
920 trailing_comments: Vec::new(),
921 double_colon_syntax: false,
922 format: None,
923 default: None,
924 inferred_type: None,
925 }));
926 Ok(Expression::Cast(Box::new(Cast {
927 this: inner_cast,
928 to: DataType::Time {
929 precision: None,
930 timezone: false,
931 },
932 trailing_comments: c.trailing_comments,
933 double_colon_syntax: false,
934 format: None,
935 default: None,
936 inferred_type: None,
937 })))
938 }
939 _ => Ok(Expression::Cast(Box::new(c))),
941 }
942 }
943
944 fn is_cast_to_timestamp(&self, expr: &Expression) -> bool {
946 if let Expression::Cast(cast) = expr {
947 matches!(cast.to, DataType::Timestamp { .. })
948 } else {
949 false
950 }
951 }
952
953 fn uppercase_first_arg_if_identifier(&self, mut args: Vec<Expression>) -> Vec<Expression> {
955 use crate::expressions::Identifier;
956 if !args.is_empty() {
957 match &args[0] {
958 Expression::Identifier(id) => {
959 args[0] = Expression::Identifier(Identifier {
960 name: id.name.to_uppercase(),
961 quoted: id.quoted,
962 trailing_comments: id.trailing_comments.clone(),
963 span: None,
964 });
965 }
966 Expression::Var(v) => {
967 args[0] = Expression::Identifier(Identifier {
968 name: v.this.to_uppercase(),
969 quoted: false,
970 trailing_comments: Vec::new(),
971 span: None,
972 });
973 }
974 Expression::Column(col) if col.table.is_none() => {
975 args[0] = Expression::Identifier(Identifier {
977 name: col.name.name.to_uppercase(),
978 quoted: col.name.quoted,
979 trailing_comments: col.name.trailing_comments.clone(),
980 span: None,
981 });
982 }
983 _ => {}
984 }
985 }
986 args
987 }
988
989 fn normalize_get_json_object_path(&self, path_arg: Expression) -> Expression {
990 let Expression::Literal(lit) = &path_arg else {
991 return path_arg;
992 };
993 let crate::expressions::Literal::String(path) = lit.as_ref() else {
994 return path_arg;
995 };
996
997 let Some(segment) = path.strip_prefix("$.") else {
998 return path_arg;
999 };
1000
1001 if segment.is_empty()
1002 || segment.contains('.')
1003 || segment.contains('[')
1004 || segment
1005 .chars()
1006 .all(|c| c.is_ascii_alphanumeric() || c == '_')
1007 {
1008 return path_arg;
1009 }
1010
1011 Expression::Literal(Box::new(crate::expressions::Literal::String(format!(
1012 "$[\"{}\"]",
1013 segment.replace('"', "\\\"")
1014 ))))
1015 }
1016
1017 fn transform_date_add(f: DateAddFunc) -> Expression {
1018 if f.unit == IntervalUnit::Day {
1019 Expression::Function(Box::new(Function::new(
1020 "DATE_ADD".to_string(),
1021 vec![f.this, f.interval],
1022 )))
1023 } else {
1024 Expression::Function(Box::new(Function::new(
1025 "DATE_ADD".to_string(),
1026 vec![
1027 Expression::Identifier(crate::expressions::Identifier {
1028 name: Self::interval_unit_name(f.unit).to_string(),
1029 quoted: false,
1030 trailing_comments: Vec::new(),
1031 span: None,
1032 }),
1033 f.interval,
1034 f.this,
1035 ],
1036 )))
1037 }
1038 }
1039
    /// Returns the upper-case keyword for an interval unit (for example
    /// `IntervalUnit::Month` yields `"MONTH"`).
    ///
    /// The match has no wildcard arm, so adding a variant to `IntervalUnit`
    /// forces this mapping to be extended.
    fn interval_unit_name(unit: IntervalUnit) -> &'static str {
        match unit {
            IntervalUnit::Year => "YEAR",
            IntervalUnit::Quarter => "QUARTER",
            IntervalUnit::Month => "MONTH",
            IntervalUnit::Week => "WEEK",
            IntervalUnit::Day => "DAY",
            IntervalUnit::Hour => "HOUR",
            IntervalUnit::Minute => "MINUTE",
            IntervalUnit::Second => "SECOND",
            IntervalUnit::Millisecond => "MILLISECOND",
            IntervalUnit::Microsecond => "MICROSECOND",
            IntervalUnit::Nanosecond => "NANOSECOND",
        }
    }
1055}
1056
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dialect;

    /// Parses `sql` with the Databricks dialect, transforms the first statement,
    /// generates SQL from it and asserts the result equals `expected`.
    fn assert_transpiles(sql: &str, expected: &str, failure_message: &str) {
        let dialect = Dialect::get(DialectType::Databricks);
        let statements = dialect.parse(sql).expect("Parse failed");
        let transformed = dialect
            .transform(statements[0].clone())
            .expect("Transform failed");
        let output = dialect.generate(&transformed).expect("Generate failed");

        assert_eq!(output, expected, "{}", failure_message);
    }

    /// Parses `sql` with the Databricks dialect and asserts that generating the
    /// first statement (without any transform) reproduces the input exactly.
    fn assert_roundtrip(sql: &str) {
        let dialect = Dialect::get(DialectType::Databricks);
        let statements = dialect.parse(sql).expect("Parse failed");
        let output = dialect.generate(&statements[0]).expect("Generate failed");

        assert_eq!(output, sql);
    }

    /// Runs `assert_roundtrip` over every statement in `cases`.
    fn assert_all_roundtrip(cases: &[&str]) {
        for sql in cases.iter().copied() {
            assert_roundtrip(sql);
        }
    }

    #[test]
    fn test_timestamp_literal_cast() {
        assert_transpiles(
            "SELECT TIMESTAMP '2025-04-29 18.47.18'::DATE",
            "SELECT CAST(CAST('2025-04-29 18.47.18' AS DATE) AS TIMESTAMP)",
            "Timestamp literal cast transformation failed",
        );
    }

    #[test]
    fn test_from_utc_timestamp_wraps_column() {
        assert_transpiles(
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(foo, 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "FROM_UTC_TIMESTAMP transformation failed",
        );
    }

    #[test]
    fn test_from_utc_timestamp_keeps_existing_cast() {
        assert_transpiles(
            "FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz)",
            "FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)",
            "FROM_UTC_TIMESTAMP with existing CAST failed",
        );
    }

    #[test]
    fn test_deep_clone_version_as_of() {
        assert_roundtrip("CREATE TABLE events_clone DEEP CLONE events VERSION AS OF 5");
    }

    #[test]
    fn test_deep_clone_timestamp_as_of() {
        assert_roundtrip("CREATE TABLE events_clone DEEP CLONE events TIMESTAMP AS OF '2024-01-01'");
    }

    #[test]
    fn test_shallow_clone_still_roundtrips() {
        assert_roundtrip("CREATE TABLE events_clone SHALLOW CLONE events");
    }

    #[test]
    fn test_repair_table_commands_roundtrip() {
        assert_all_roundtrip(&[
            "REPAIR TABLE events",
            "MSCK REPAIR TABLE events",
            "REPAIR TABLE events ADD PARTITIONS",
            "REPAIR TABLE events DROP PARTITIONS",
            "REPAIR TABLE events SYNC PARTITIONS",
            "REPAIR TABLE events SYNC METADATA",
        ]);
    }

    #[test]
    fn test_apply_changes_commands_roundtrip() {
        assert_all_roundtrip(&[
            "APPLY CHANGES INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) IGNORE NULL UPDATES SEQUENCE BY ts",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) APPLY AS DELETE WHEN operation = 'DELETE' SEQUENCE BY ts COLUMNS * EXCEPT (operation) STORED AS SCD TYPE 1",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) SEQUENCE BY ts STORED AS SCD TYPE 2 TRACK HISTORY ON * EXCEPT (updated_at)",
            "AUTO CDC INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
            "CREATE FLOW apply_cdc AS AUTO CDC INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
        ]);
    }

    #[test]
    fn test_generate_symlink_format_manifest_roundtrip() {
        assert_all_roundtrip(&[
            "GENERATE symlink_format_manifest FOR TABLE events",
            "GENERATE symlink_format_manifest FOR TABLE catalog.schema.events",
        ]);
    }

    #[test]
    fn test_convert_to_delta_roundtrip() {
        assert_all_roundtrip(&[
            "CONVERT TO DELTA parquet.`/mnt/data/events`",
            "CONVERT TO DELTA database_name.table_name",
            "CONVERT TO DELTA parquet.`s3://my-bucket/path/to/table` PARTITIONED BY (date DATE)",
            "CONVERT TO DELTA database_name.table_name NO STATISTICS",
        ]);
    }
}