Skip to main content

polyglot_sql/dialects/
databricks.rs

//! Databricks Dialect
//!
//! Databricks-specific transformations based on sqlglot patterns.
//! Databricks extends Spark SQL with additional features:
//! - Colon operator for JSON extraction (`col:path`)
//! - DATEADD/DATEDIFF with an explicit unit argument (e.g. `DATEDIFF(DAY, start, end)`)
//! - NULL type mapped to VOID
//! - Native REGEXP_LIKE and TRY_CAST support
9
10use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{
13    AggregateFunction, Cast, DataType, Expression, Function, JSONExtract, Literal, UnaryFunc,
14    VarArgFunc,
15};
16use crate::generator::GeneratorConfig;
17use crate::tokens::TokenizerConfig;
18
/// The Databricks SQL dialect.
///
/// A stateless unit type: its tokenizer/generator configuration and expression
/// rewrites are supplied by its `DialectImpl` implementation and private helpers.
pub struct DatabricksDialect;
21
22impl DialectImpl for DatabricksDialect {
23    fn dialect_type(&self) -> DialectType {
24        DialectType::Databricks
25    }
26
27    fn tokenizer_config(&self) -> TokenizerConfig {
28        let mut config = TokenizerConfig::default();
29        // Databricks uses backticks for identifiers (NOT double quotes)
30        config.identifiers.clear();
31        config.identifiers.insert('`', '`');
32        // Databricks (like Hive/Spark) uses double quotes as string delimiters
33        config.quotes.insert("\"".to_string(), "\"".to_string());
34        // Databricks uses backslash escapes in strings (inherited from Hive/Spark)
35        config.string_escapes.push('\\');
36        // Databricks supports DIV keyword for integer division
37        config
38            .keywords
39            .insert("DIV".to_string(), crate::tokens::TokenType::Div);
40        config
41            .keywords
42            .insert("REPAIR".to_string(), crate::tokens::TokenType::Command);
43        config
44            .keywords
45            .insert("MSCK".to_string(), crate::tokens::TokenType::Command);
46        // Databricks numeric literal suffixes (same as Hive/Spark)
47        config
48            .numeric_literals
49            .insert("L".to_string(), "BIGINT".to_string());
50        config
51            .numeric_literals
52            .insert("S".to_string(), "SMALLINT".to_string());
53        config
54            .numeric_literals
55            .insert("Y".to_string(), "TINYINT".to_string());
56        config
57            .numeric_literals
58            .insert("D".to_string(), "DOUBLE".to_string());
59        config
60            .numeric_literals
61            .insert("F".to_string(), "FLOAT".to_string());
62        config
63            .numeric_literals
64            .insert("BD".to_string(), "DECIMAL".to_string());
65        // Databricks allows identifiers to start with digits (like Hive/Spark)
66        config.identifiers_can_start_with_digit = true;
67        // Databricks (like Spark): STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
68        // Backslashes in raw strings are always literal (no escape processing)
69        config.string_escapes_allowed_in_raw_strings = false;
70        config
71    }
72
73    fn generator_config(&self) -> GeneratorConfig {
74        use crate::generator::IdentifierQuoteStyle;
75        GeneratorConfig {
76            identifier_quote: '`',
77            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
78            dialect: Some(DialectType::Databricks),
79            struct_field_sep: ": ",
80            create_function_return_as: false,
81            tablesample_seed_keyword: "REPEATABLE",
82            identifiers_can_start_with_digit: true,
83            // Databricks uses COMMENT 'value' without = sign
84            schema_comment_with_eq: false,
85            ..Default::default()
86        }
87    }
88
89    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
90        match expr {
91            // IFNULL -> COALESCE in Databricks
92            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
93                original_name: None,
94                expressions: vec![f.this, f.expression],
95                inferred_type: None,
96            }))),
97
98            // NVL -> COALESCE in Databricks
99            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
100                original_name: None,
101                expressions: vec![f.this, f.expression],
102                inferred_type: None,
103            }))),
104
105            // TryCast is native in Databricks
106            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
107
108            // SafeCast -> TRY_CAST in Databricks
109            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
110
111            // ILIKE is native in Databricks (Spark 3+)
112            Expression::ILike(op) => Ok(Expression::ILike(op)),
113
114            // UNNEST -> EXPLODE in Databricks
115            Expression::Unnest(f) => Ok(Expression::Explode(Box::new(UnaryFunc::new(f.this)))),
116
117            // EXPLODE is native to Databricks
118            Expression::Explode(f) => Ok(Expression::Explode(f)),
119
120            // ExplodeOuter is supported
121            Expression::ExplodeOuter(f) => Ok(Expression::ExplodeOuter(f)),
122
123            // RANDOM -> RAND in Databricks
124            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
125                seed: None,
126                lower: None,
127                upper: None,
128            }))),
129
130            // Rand is native
131            Expression::Rand(r) => Ok(Expression::Rand(r)),
132
133            // || (Concat) -> CONCAT in Databricks
134            Expression::Concat(op) => Ok(Expression::Function(Box::new(Function::new(
135                "CONCAT".to_string(),
136                vec![op.left, op.right],
137            )))),
138
139            // RegexpLike is native in Databricks
140            Expression::RegexpLike(op) => Ok(Expression::RegexpLike(op)),
141
142            // Cast with typed literal: TIMESTAMP 'x'::TYPE -> CAST(CAST('x' AS TYPE) AS TIMESTAMP)
143            // This is a complex sqlglot transformation where:
144            // 1. The inner typed literal (e.g., TIMESTAMP 'x') becomes CAST('x' AS <target_type>)
145            // 2. The outer result is wrapped in CAST(... AS <original_literal_type>)
146            Expression::Cast(c) => self.transform_cast(*c),
147
148            // Generic function transformations
149            Expression::Function(f) => self.transform_function(*f),
150
151            // Generic aggregate function transformations
152            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
153
154            // DateSub -> DATE_ADD(date, -val) in Databricks
155            Expression::DateSub(f) => {
156                // Convert string literals to numbers (interval values are often stored as strings)
157                let val = match f.interval {
158                    Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::String(s) if s.parse::<i64>().is_ok()) =>
159                    {
160                        let crate::expressions::Literal::String(s) = lit.as_ref() else {
161                            unreachable!()
162                        };
163                        Expression::Literal(Box::new(crate::expressions::Literal::Number(
164                            s.clone(),
165                        )))
166                    }
167                    other => other,
168                };
169                let neg_val = Expression::Neg(Box::new(crate::expressions::UnaryOp {
170                    this: val,
171                    inferred_type: None,
172                }));
173                Ok(Expression::Function(Box::new(Function::new(
174                    "DATE_ADD".to_string(),
175                    vec![f.this, neg_val],
176                ))))
177            }
178
179            // Pass through everything else
180            _ => Ok(expr),
181        }
182    }
183}
184
185impl DatabricksDialect {
186    fn transform_function(&self, f: Function) -> Result<Expression> {
187        let name_upper = f.name.to_uppercase();
188        match name_upper.as_str() {
189            // IFNULL -> COALESCE
190            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
191                original_name: None,
192                expressions: f.args,
193                inferred_type: None,
194            }))),
195
196            // NVL -> COALESCE
197            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
198                original_name: None,
199                expressions: f.args,
200                inferred_type: None,
201            }))),
202
203            // ISNULL -> COALESCE
204            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
205                original_name: None,
206                expressions: f.args,
207                inferred_type: None,
208            }))),
209
210            // ROW -> STRUCT (no auto-naming for cross-dialect conversion)
211            "ROW" => Ok(Expression::Function(Box::new(Function::new(
212                "STRUCT".to_string(),
213                f.args,
214            )))),
215
216            // NAMED_STRUCT('a', 1) -> STRUCT(1 AS a) for SQLGlot Databricks outputs
217            "NAMED_STRUCT" if f.args.len() % 2 == 0 => {
218                let original_args = f.args.clone();
219                let mut struct_args = Vec::new();
220                for pair in f.args.chunks(2) {
221                    if let Expression::Literal(lit) = &pair[0] {
222                        if let Literal::String(field_name) = lit.as_ref() {
223                            struct_args.push(Expression::Alias(Box::new(
224                                crate::expressions::Alias {
225                                    this: pair[1].clone(),
226                                    alias: crate::expressions::Identifier::new(field_name),
227                                    column_aliases: Vec::new(),
228                                    alias_explicit_as: false,
229                                    alias_keyword: None,
230                                    pre_alias_comments: Vec::new(),
231                                    trailing_comments: Vec::new(),
232                                    inferred_type: None,
233                                },
234                            )));
235                            continue;
236                        }
237                    }
238                    return Ok(Expression::Function(Box::new(Function::new(
239                        "NAMED_STRUCT".to_string(),
240                        original_args,
241                    ))));
242                }
243                Ok(Expression::Function(Box::new(Function::new(
244                    "STRUCT".to_string(),
245                    struct_args,
246                ))))
247            }
248
249            // GETDATE -> CURRENT_TIMESTAMP
250            "GETDATE" => Ok(Expression::CurrentTimestamp(
251                crate::expressions::CurrentTimestamp {
252                    precision: None,
253                    sysdate: false,
254                },
255            )),
256
257            // NOW -> CURRENT_TIMESTAMP
258            "NOW" => Ok(Expression::CurrentTimestamp(
259                crate::expressions::CurrentTimestamp {
260                    precision: None,
261                    sysdate: false,
262                },
263            )),
264
265            // CURDATE -> CURRENT_DATE
266            "CURDATE" => Ok(Expression::CurrentDate(crate::expressions::CurrentDate)),
267
268            // CURRENT_DATE() with parens -> CURRENT_DATE (no parens)
269            "CURRENT_DATE" if f.args.is_empty() => {
270                Ok(Expression::CurrentDate(crate::expressions::CurrentDate))
271            }
272
273            // RANDOM -> RAND
274            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
275                seed: None,
276                lower: None,
277                upper: None,
278            }))),
279
280            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
281            "GROUP_CONCAT" if !f.args.is_empty() => {
282                let mut args = f.args;
283                let first = args.remove(0);
284                let separator = args.pop();
285                let collect_list = Expression::Function(Box::new(Function::new(
286                    "COLLECT_LIST".to_string(),
287                    vec![first],
288                )));
289                if let Some(sep) = separator {
290                    Ok(Expression::Function(Box::new(Function::new(
291                        "ARRAY_JOIN".to_string(),
292                        vec![collect_list, sep],
293                    ))))
294                } else {
295                    Ok(Expression::Function(Box::new(Function::new(
296                        "ARRAY_JOIN".to_string(),
297                        vec![collect_list],
298                    ))))
299                }
300            }
301
302            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN in Databricks
303            "STRING_AGG" if !f.args.is_empty() => {
304                let mut args = f.args;
305                let first = args.remove(0);
306                let separator = args.pop();
307                let collect_list = Expression::Function(Box::new(Function::new(
308                    "COLLECT_LIST".to_string(),
309                    vec![first],
310                )));
311                if let Some(sep) = separator {
312                    Ok(Expression::Function(Box::new(Function::new(
313                        "ARRAY_JOIN".to_string(),
314                        vec![collect_list, sep],
315                    ))))
316                } else {
317                    Ok(Expression::Function(Box::new(Function::new(
318                        "ARRAY_JOIN".to_string(),
319                        vec![collect_list],
320                    ))))
321                }
322            }
323
324            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
325            "LISTAGG" if !f.args.is_empty() => {
326                let mut args = f.args;
327                let first = args.remove(0);
328                let separator = args.pop();
329                let collect_list = Expression::Function(Box::new(Function::new(
330                    "COLLECT_LIST".to_string(),
331                    vec![first],
332                )));
333                if let Some(sep) = separator {
334                    Ok(Expression::Function(Box::new(Function::new(
335                        "ARRAY_JOIN".to_string(),
336                        vec![collect_list, sep],
337                    ))))
338                } else {
339                    Ok(Expression::Function(Box::new(Function::new(
340                        "ARRAY_JOIN".to_string(),
341                        vec![collect_list],
342                    ))))
343                }
344            }
345
346            // ARRAY_AGG -> COLLECT_LIST in Databricks
347            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
348                "COLLECT_LIST".to_string(),
349                f.args,
350            )))),
351
352            // SUBSTR -> SUBSTRING
353            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
354                "SUBSTRING".to_string(),
355                f.args,
356            )))),
357
358            // LEN -> LENGTH
359            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
360                f.args.into_iter().next().unwrap(),
361            )))),
362
363            // CHARINDEX -> LOCATE (with swapped args, like Spark)
364            "CHARINDEX" if f.args.len() >= 2 => {
365                let mut args = f.args;
366                let substring = args.remove(0);
367                let string = args.remove(0);
368                // LOCATE(substring, string)
369                Ok(Expression::Function(Box::new(Function::new(
370                    "LOCATE".to_string(),
371                    vec![substring, string],
372                ))))
373            }
374
375            // POSITION -> LOCATE
376            "POSITION" if f.args.len() == 2 => {
377                let args = f.args;
378                Ok(Expression::Function(Box::new(Function::new(
379                    "LOCATE".to_string(),
380                    args,
381                ))))
382            }
383
384            // STRPOS -> LOCATE (with same arg order)
385            "STRPOS" if f.args.len() == 2 => {
386                let args = f.args;
387                let string = args[0].clone();
388                let substring = args[1].clone();
389                // STRPOS(string, substring) -> LOCATE(substring, string)
390                Ok(Expression::Function(Box::new(Function::new(
391                    "LOCATE".to_string(),
392                    vec![substring, string],
393                ))))
394            }
395
396            // INSTR is native in Databricks
397            "INSTR" => Ok(Expression::Function(Box::new(f))),
398
399            // LOCATE is native in Databricks
400            "LOCATE" => Ok(Expression::Function(Box::new(f))),
401
402            // ARRAY_LENGTH -> SIZE
403            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
404                Function::new("SIZE".to_string(), f.args),
405            ))),
406
407            // CARDINALITY -> SIZE
408            "CARDINALITY" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
409                Function::new("SIZE".to_string(), f.args),
410            ))),
411
412            // SIZE is native
413            "SIZE" => Ok(Expression::Function(Box::new(f))),
414
415            // ARRAY_CONTAINS is native in Databricks
416            "ARRAY_CONTAINS" => Ok(Expression::Function(Box::new(f))),
417
418            // CONTAINS -> ARRAY_CONTAINS in Databricks (for array operations)
419            // But keep CONTAINS for string contains (from CONTAINS_SUBSTR transpilation)
420            "CONTAINS" if f.args.len() == 2 => {
421                // Check if this is a string CONTAINS (LOWER() args pattern from CONTAINS_SUBSTR)
422                let is_string_contains = matches!(&f.args[0], Expression::Lower(_))
423                    && matches!(&f.args[1], Expression::Lower(_));
424                if is_string_contains {
425                    Ok(Expression::Function(Box::new(f)))
426                } else {
427                    Ok(Expression::Function(Box::new(Function::new(
428                        "ARRAY_CONTAINS".to_string(),
429                        f.args,
430                    ))))
431                }
432            }
433
434            // TO_DATE is native in Databricks
435            "TO_DATE" => Ok(Expression::Function(Box::new(f))),
436
437            // TO_TIMESTAMP is native in Databricks
438            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
439
440            // DATE_FORMAT is native in Databricks
441            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
442
443            // strftime -> DATE_FORMAT in Databricks
444            "STRFTIME" if f.args.len() >= 2 => {
445                let mut args = f.args;
446                let format = args.remove(0);
447                let date = args.remove(0);
448                Ok(Expression::Function(Box::new(Function::new(
449                    "DATE_FORMAT".to_string(),
450                    vec![date, format],
451                ))))
452            }
453
454            // TO_CHAR is supported natively in Databricks (unlike Spark)
455            "TO_CHAR" => Ok(Expression::Function(Box::new(f))),
456
457            // DATE_TRUNC is native in Databricks
458            "DATE_TRUNC" => Ok(Expression::Function(Box::new(f))),
459
460            // DATEADD is native in Databricks - uppercase the unit if present
461            "DATEADD" => {
462                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
463                Ok(Expression::Function(Box::new(Function::new(
464                    "DATEADD".to_string(),
465                    transformed_args,
466                ))))
467            }
468
469            // DATE_ADD -> DATEADD in Databricks (2-arg form only)
470            // 2-arg with interval: DATE_ADD(date, interval) -> DATEADD(DAY, interval, date)
471            // 2-arg with number: DATE_ADD(date, -2) -> keep as DATE_ADD(date, -2)
472            // 3-arg: DATE_ADD(unit, amount, date) -> keep as DATE_ADD(UNIT, amount, date)
473            "DATE_ADD" => {
474                if f.args.len() == 2 {
475                    let is_simple_number = matches!(
476                        &f.args[1],
477                        Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::Number(_))
478                    ) || matches!(&f.args[1], Expression::Neg(_));
479                    if is_simple_number {
480                        // Keep as DATE_ADD(date, num_days)
481                        Ok(Expression::Function(Box::new(Function::new(
482                            "DATE_ADD".to_string(),
483                            f.args,
484                        ))))
485                    } else {
486                        let mut args = f.args;
487                        let date = args.remove(0);
488                        let interval = args.remove(0);
489                        let unit = Expression::Identifier(crate::expressions::Identifier {
490                            name: "DAY".to_string(),
491                            quoted: false,
492                            trailing_comments: Vec::new(),
493                            span: None,
494                        });
495                        Ok(Expression::Function(Box::new(Function::new(
496                            "DATEADD".to_string(),
497                            vec![unit, interval, date],
498                        ))))
499                    }
500                } else {
501                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
502                    Ok(Expression::Function(Box::new(Function::new(
503                        "DATE_ADD".to_string(),
504                        transformed_args,
505                    ))))
506                }
507            }
508
509            // DATEDIFF is native in Databricks - uppercase the unit if present
510            // 2-arg: DATEDIFF(end, start) -> DATEDIFF(DAY, start, end)
511            // 3-arg: DATEDIFF(unit, start, end) -> DATEDIFF(UNIT, start, end)
512            "DATEDIFF" => {
513                if f.args.len() == 2 {
514                    let mut args = f.args;
515                    let end_date = args.remove(0);
516                    let start_date = args.remove(0);
517                    let unit = Expression::Identifier(crate::expressions::Identifier {
518                        name: "DAY".to_string(),
519                        quoted: false,
520                        trailing_comments: Vec::new(),
521                        span: None,
522                    });
523                    Ok(Expression::Function(Box::new(Function::new(
524                        "DATEDIFF".to_string(),
525                        vec![unit, start_date, end_date],
526                    ))))
527                } else {
528                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
529                    Ok(Expression::Function(Box::new(Function::new(
530                        "DATEDIFF".to_string(),
531                        transformed_args,
532                    ))))
533                }
534            }
535
536            // DATE_DIFF -> DATEDIFF with uppercased unit
537            "DATE_DIFF" => {
538                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
539                Ok(Expression::Function(Box::new(Function::new(
540                    "DATEDIFF".to_string(),
541                    transformed_args,
542                ))))
543            }
544
545            // JSON_EXTRACT -> Use colon operator in generation, but keep as function for now
546            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
547
548            // JSON_EXTRACT_SCALAR -> same handling
549            "JSON_EXTRACT_SCALAR" => Ok(Expression::Function(Box::new(f))),
550
551            // GET_JSON_OBJECT -> colon syntax in Databricks
552            // GET_JSON_OBJECT(col, '$.path') becomes col:path
553            "GET_JSON_OBJECT" if f.args.len() == 2 => {
554                let mut args = f.args;
555                let col = args.remove(0);
556                let path_arg = args.remove(0);
557
558                // Extract and strip the $. prefix from the path
559                let path_expr = match &path_arg {
560                    Expression::Literal(lit)
561                        if matches!(lit.as_ref(), crate::expressions::Literal::String(_)) =>
562                    {
563                        let crate::expressions::Literal::String(s) = lit.as_ref() else {
564                            unreachable!()
565                        };
566                        // Strip leading '$.' if present
567                        let stripped = if s.starts_with("$.") {
568                            &s[2..]
569                        } else if s.starts_with("$") {
570                            &s[1..]
571                        } else {
572                            s.as_str()
573                        };
574                        Expression::Literal(Box::new(crate::expressions::Literal::String(
575                            stripped.to_string(),
576                        )))
577                    }
578                    _ => path_arg,
579                };
580
581                Ok(Expression::JSONExtract(Box::new(JSONExtract {
582                    this: Box::new(col),
583                    expression: Box::new(path_expr),
584                    only_json_types: None,
585                    expressions: Vec::new(),
586                    variant_extract: Some(Box::new(Expression::true_())),
587                    json_query: None,
588                    option: None,
589                    quote: None,
590                    on_condition: None,
591                    requires_json: None,
592                })))
593            }
594
595            // FROM_JSON is native in Databricks
596            "FROM_JSON" => Ok(Expression::Function(Box::new(f))),
597
598            // PARSE_JSON is native in Databricks
599            "PARSE_JSON" => Ok(Expression::Function(Box::new(f))),
600
601            // COLLECT_LIST is native in Databricks
602            "COLLECT_LIST" => Ok(Expression::Function(Box::new(f))),
603
604            // COLLECT_SET is native in Databricks
605            "COLLECT_SET" => Ok(Expression::Function(Box::new(f))),
606
607            // RLIKE is native in Databricks
608            "RLIKE" => Ok(Expression::Function(Box::new(f))),
609
610            // REGEXP -> RLIKE in Databricks
611            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
612                "RLIKE".to_string(),
613                f.args,
614            )))),
615
616            // REGEXP_LIKE is native in Databricks
617            "REGEXP_LIKE" => Ok(Expression::Function(Box::new(f))),
618
619            // LEVENSHTEIN is native in Databricks
620            "LEVENSHTEIN" => Ok(Expression::Function(Box::new(f))),
621
622            // SEQUENCE is native (for GENERATE_SERIES)
623            "GENERATE_SERIES" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
624                Function::new("SEQUENCE".to_string(), f.args),
625            ))),
626
627            // SEQUENCE is native
628            "SEQUENCE" => Ok(Expression::Function(Box::new(f))),
629
630            // FLATTEN is native in Databricks
631            "FLATTEN" => Ok(Expression::Function(Box::new(f))),
632
633            // ARRAY_SORT is native
634            "ARRAY_SORT" => Ok(Expression::Function(Box::new(f))),
635
636            // ARRAY_DISTINCT is native
637            "ARRAY_DISTINCT" => Ok(Expression::Function(Box::new(f))),
638
639            // TRANSFORM is native (for array transformation)
640            "TRANSFORM" => Ok(Expression::Function(Box::new(f))),
641
642            // FILTER is native (for array filtering)
643            "FILTER" => Ok(Expression::Function(Box::new(f))),
644
645            // FROM_UTC_TIMESTAMP - wrap first argument in CAST(... AS TIMESTAMP) if not already
646            "FROM_UTC_TIMESTAMP" if f.args.len() >= 2 => {
647                let mut args = f.args;
648                let first_arg = args.remove(0);
649
650                // Check if first arg is already a Cast to TIMESTAMP
651                let wrapped_arg = if self.is_cast_to_timestamp(&first_arg) {
652                    first_arg
653                } else {
654                    // Wrap in CAST(... AS TIMESTAMP)
655                    Expression::Cast(Box::new(Cast {
656                        this: first_arg,
657                        to: DataType::Timestamp {
658                            precision: None,
659                            timezone: false,
660                        },
661                        trailing_comments: Vec::new(),
662                        double_colon_syntax: false,
663                        format: None,
664                        default: None,
665                        inferred_type: None,
666                    }))
667                };
668
669                let mut new_args = vec![wrapped_arg];
670                new_args.extend(args);
671
672                Ok(Expression::Function(Box::new(Function::new(
673                    "FROM_UTC_TIMESTAMP".to_string(),
674                    new_args,
675                ))))
676            }
677
678            // UNIFORM(low, high, RANDOM(seed)) -> UNIFORM(low, high, seed) or UNIFORM(low, high)
679            "UNIFORM" if f.args.len() == 3 => {
680                let mut args = f.args;
681                let low = args.remove(0);
682                let high = args.remove(0);
683                let gen = args.remove(0);
684                match gen {
685                    Expression::Function(func) if func.name.to_uppercase() == "RANDOM" => {
686                        if func.args.len() == 1 {
687                            // RANDOM(seed) -> extract seed
688                            let seed = func.args.into_iter().next().unwrap();
689                            Ok(Expression::Function(Box::new(Function::new(
690                                "UNIFORM".to_string(),
691                                vec![low, high, seed],
692                            ))))
693                        } else {
694                            // RANDOM() -> drop gen arg
695                            Ok(Expression::Function(Box::new(Function::new(
696                                "UNIFORM".to_string(),
697                                vec![low, high],
698                            ))))
699                        }
700                    }
701                    Expression::Rand(r) => {
702                        if let Some(seed) = r.seed {
703                            Ok(Expression::Function(Box::new(Function::new(
704                                "UNIFORM".to_string(),
705                                vec![low, high, *seed],
706                            ))))
707                        } else {
708                            Ok(Expression::Function(Box::new(Function::new(
709                                "UNIFORM".to_string(),
710                                vec![low, high],
711                            ))))
712                        }
713                    }
714                    _ => Ok(Expression::Function(Box::new(Function::new(
715                        "UNIFORM".to_string(),
716                        vec![low, high, gen],
717                    )))),
718                }
719            }
720
721            // REGEXP_SUBSTR(subject, pattern, ...) -> REGEXP_EXTRACT(subject, pattern)
722            "REGEXP_SUBSTR" if f.args.len() >= 2 => {
723                let subject = f.args[0].clone();
724                let pattern = f.args[1].clone();
725                Ok(Expression::Function(Box::new(Function::new(
726                    "REGEXP_EXTRACT".to_string(),
727                    vec![subject, pattern],
728                ))))
729            }
730
731            // BIT_GET -> GETBIT
732            "BIT_GET" => Ok(Expression::Function(Box::new(Function::new(
733                "GETBIT".to_string(),
734                f.args,
735            )))),
736
737            // Pass through everything else
738            _ => Ok(Expression::Function(Box::new(f))),
739        }
740    }
741
742    fn transform_aggregate_function(
743        &self,
744        f: Box<crate::expressions::AggregateFunction>,
745    ) -> Result<Expression> {
746        let name_upper = f.name.to_uppercase();
747        match name_upper.as_str() {
748            // COUNT_IF is native in Databricks (Spark 3+)
749            "COUNT_IF" => Ok(Expression::AggregateFunction(f)),
750
751            // ANY_VALUE is native in Databricks (Spark 3+)
752            "ANY_VALUE" => Ok(Expression::AggregateFunction(f)),
753
754            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
755            "GROUP_CONCAT" if !f.args.is_empty() => {
756                let mut args = f.args;
757                let first = args.remove(0);
758                let separator = args.pop();
759                let collect_list = Expression::Function(Box::new(Function::new(
760                    "COLLECT_LIST".to_string(),
761                    vec![first],
762                )));
763                if let Some(sep) = separator {
764                    Ok(Expression::Function(Box::new(Function::new(
765                        "ARRAY_JOIN".to_string(),
766                        vec![collect_list, sep],
767                    ))))
768                } else {
769                    Ok(Expression::Function(Box::new(Function::new(
770                        "ARRAY_JOIN".to_string(),
771                        vec![collect_list],
772                    ))))
773                }
774            }
775
776            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN
777            "STRING_AGG" if !f.args.is_empty() => {
778                let mut args = f.args;
779                let first = args.remove(0);
780                let separator = args.pop();
781                let collect_list = Expression::Function(Box::new(Function::new(
782                    "COLLECT_LIST".to_string(),
783                    vec![first],
784                )));
785                if let Some(sep) = separator {
786                    Ok(Expression::Function(Box::new(Function::new(
787                        "ARRAY_JOIN".to_string(),
788                        vec![collect_list, sep],
789                    ))))
790                } else {
791                    Ok(Expression::Function(Box::new(Function::new(
792                        "ARRAY_JOIN".to_string(),
793                        vec![collect_list],
794                    ))))
795                }
796            }
797
798            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
799            "LISTAGG" if !f.args.is_empty() => {
800                let mut args = f.args;
801                let first = args.remove(0);
802                let separator = args.pop();
803                let collect_list = Expression::Function(Box::new(Function::new(
804                    "COLLECT_LIST".to_string(),
805                    vec![first],
806                )));
807                if let Some(sep) = separator {
808                    Ok(Expression::Function(Box::new(Function::new(
809                        "ARRAY_JOIN".to_string(),
810                        vec![collect_list, sep],
811                    ))))
812                } else {
813                    Ok(Expression::Function(Box::new(Function::new(
814                        "ARRAY_JOIN".to_string(),
815                        vec![collect_list],
816                    ))))
817                }
818            }
819
820            // ARRAY_AGG -> COLLECT_LIST
821            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
822                "COLLECT_LIST".to_string(),
823                f.args,
824            )))),
825
826            // STDDEV is native in Databricks
827            "STDDEV" => Ok(Expression::AggregateFunction(f)),
828
829            // VARIANCE is native in Databricks
830            "VARIANCE" => Ok(Expression::AggregateFunction(f)),
831
832            // APPROX_COUNT_DISTINCT is native in Databricks
833            "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
834
835            // APPROX_DISTINCT -> APPROX_COUNT_DISTINCT
836            "APPROX_DISTINCT" if !f.args.is_empty() => {
837                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
838                    name: "APPROX_COUNT_DISTINCT".to_string(),
839                    args: f.args,
840                    distinct: f.distinct,
841                    filter: f.filter,
842                    order_by: Vec::new(),
843                    limit: None,
844                    ignore_nulls: None,
845                    inferred_type: None,
846                })))
847            }
848
849            // Pass through everything else
850            _ => Ok(Expression::AggregateFunction(f)),
851        }
852    }
853
854    /// Transform Cast expressions - handles typed literals being cast
855    ///
856    /// When we have a typed literal (TIMESTAMP 'x', DATE 'x', TIME 'x') being cast to another type,
857    /// Databricks/Spark transforms it as follows:
858    ///
859    /// `TIMESTAMP 'x'::TYPE` -> `CAST(CAST('x' AS TYPE) AS TIMESTAMP)`
860    ///
861    /// This reverses the types - the inner cast is to the target type,
862    /// the outer cast is to the original literal type.
863    fn transform_cast(&self, c: Cast) -> Result<Expression> {
864        // Check if the inner expression is a typed literal
865        match &c.this {
866            // TIMESTAMP 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIMESTAMP)
867            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Timestamp(_)) => {
868                let Literal::Timestamp(value) = lit.as_ref() else {
869                    unreachable!()
870                };
871                // Create inner cast: CAST('value' AS target_type)
872                let inner_cast = Expression::Cast(Box::new(Cast {
873                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
874                    to: c.to,
875                    trailing_comments: Vec::new(),
876                    double_colon_syntax: false,
877                    format: None,
878                    default: None,
879                    inferred_type: None,
880                }));
881                // Create outer cast: CAST(inner_cast AS TIMESTAMP)
882                Ok(Expression::Cast(Box::new(Cast {
883                    this: inner_cast,
884                    to: DataType::Timestamp {
885                        precision: None,
886                        timezone: false,
887                    },
888                    trailing_comments: c.trailing_comments,
889                    double_colon_syntax: false,
890                    format: None,
891                    default: None,
892                    inferred_type: None,
893                })))
894            }
895            // DATE 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS DATE)
896            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Date(_)) => {
897                let Literal::Date(value) = lit.as_ref() else {
898                    unreachable!()
899                };
900                let inner_cast = Expression::Cast(Box::new(Cast {
901                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
902                    to: c.to,
903                    trailing_comments: Vec::new(),
904                    double_colon_syntax: false,
905                    format: None,
906                    default: None,
907                    inferred_type: None,
908                }));
909                Ok(Expression::Cast(Box::new(Cast {
910                    this: inner_cast,
911                    to: DataType::Date,
912                    trailing_comments: c.trailing_comments,
913                    double_colon_syntax: false,
914                    format: None,
915                    default: None,
916                    inferred_type: None,
917                })))
918            }
919            // TIME 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIME)
920            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Time(_)) => {
921                let Literal::Time(value) = lit.as_ref() else {
922                    unreachable!()
923                };
924                let inner_cast = Expression::Cast(Box::new(Cast {
925                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
926                    to: c.to,
927                    trailing_comments: Vec::new(),
928                    double_colon_syntax: false,
929                    format: None,
930                    default: None,
931                    inferred_type: None,
932                }));
933                Ok(Expression::Cast(Box::new(Cast {
934                    this: inner_cast,
935                    to: DataType::Time {
936                        precision: None,
937                        timezone: false,
938                    },
939                    trailing_comments: c.trailing_comments,
940                    double_colon_syntax: false,
941                    format: None,
942                    default: None,
943                    inferred_type: None,
944                })))
945            }
946            // For all other cases, pass through the Cast unchanged
947            _ => Ok(Expression::Cast(Box::new(c))),
948        }
949    }
950
951    /// Check if an expression is a CAST to TIMESTAMP
952    fn is_cast_to_timestamp(&self, expr: &Expression) -> bool {
953        if let Expression::Cast(cast) = expr {
954            matches!(cast.to, DataType::Timestamp { .. })
955        } else {
956            false
957        }
958    }
959
960    /// Helper to uppercase the first argument if it's an identifier or column (for DATEDIFF, DATEADD units)
961    fn uppercase_first_arg_if_identifier(&self, mut args: Vec<Expression>) -> Vec<Expression> {
962        use crate::expressions::Identifier;
963        if !args.is_empty() {
964            match &args[0] {
965                Expression::Identifier(id) => {
966                    args[0] = Expression::Identifier(Identifier {
967                        name: id.name.to_uppercase(),
968                        quoted: id.quoted,
969                        trailing_comments: id.trailing_comments.clone(),
970                        span: None,
971                    });
972                }
973                Expression::Var(v) => {
974                    args[0] = Expression::Identifier(Identifier {
975                        name: v.this.to_uppercase(),
976                        quoted: false,
977                        trailing_comments: Vec::new(),
978                        span: None,
979                    });
980                }
981                Expression::Column(col) if col.table.is_none() => {
982                    // Unqualified column name like "day" should be treated as a unit
983                    args[0] = Expression::Identifier(Identifier {
984                        name: col.name.name.to_uppercase(),
985                        quoted: col.name.quoted,
986                        trailing_comments: col.name.trailing_comments.clone(),
987                        span: None,
988                    });
989                }
990                _ => {}
991            }
992        }
993        args
994    }
995}
996
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dialect;

    /// Parse `sql` with the Databricks dialect, apply the dialect transform to the first
    /// statement, and generate SQL from the result.
    fn transpile(sql: &str) -> String {
        let dialect = Dialect::get(DialectType::Databricks);
        let statements = dialect.parse(sql).expect("Parse failed");
        let rewritten = dialect
            .transform(statements[0].clone())
            .expect("Transform failed");
        dialect.generate(&rewritten).expect("Generate failed")
    }

    /// Assert that each statement in `cases` parses and regenerates to exactly its original
    /// text. The transform step is not applied.
    fn assert_roundtrips(cases: &[&str]) {
        let dialect = Dialect::get(DialectType::Databricks);
        for &sql in cases {
            let statements = dialect.parse(sql).expect("Parse failed");
            let regenerated = dialect.generate(&statements[0]).expect("Generate failed");
            assert_eq!(regenerated, sql);
        }
    }

    #[test]
    fn test_timestamp_literal_cast() {
        // TIMESTAMP 'value'::DATE -> CAST(CAST('value' AS DATE) AS TIMESTAMP)
        // This is test [47] in the Databricks dialect identity fixtures
        assert_eq!(
            transpile("SELECT TIMESTAMP '2025-04-29 18.47.18'::DATE"),
            "SELECT CAST(CAST('2025-04-29 18.47.18' AS DATE) AS TIMESTAMP)",
            "Timestamp literal cast transformation failed"
        );
    }

    #[test]
    fn test_from_utc_timestamp_wraps_column() {
        // Test [48]: FROM_UTC_TIMESTAMP(foo, 'timezone') -> FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'timezone')
        assert_eq!(
            transpile("SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(foo, 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t"),
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "FROM_UTC_TIMESTAMP transformation failed"
        );
    }

    #[test]
    fn test_from_utc_timestamp_keeps_existing_cast() {
        // Test [50]: FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz) -> FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)
        // When already cast to TIMESTAMP, keep it but convert :: syntax to CAST()
        assert_eq!(
            transpile("FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz)"),
            "FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)",
            "FROM_UTC_TIMESTAMP with existing CAST failed"
        );
    }

    #[test]
    fn test_deep_clone_version_as_of() {
        assert_roundtrips(&["CREATE TABLE events_clone DEEP CLONE events VERSION AS OF 5"]);
    }

    #[test]
    fn test_deep_clone_timestamp_as_of() {
        assert_roundtrips(&[
            "CREATE TABLE events_clone DEEP CLONE events TIMESTAMP AS OF '2024-01-01'",
        ]);
    }

    #[test]
    fn test_shallow_clone_still_roundtrips() {
        assert_roundtrips(&["CREATE TABLE events_clone SHALLOW CLONE events"]);
    }

    #[test]
    fn test_repair_table_commands_roundtrip() {
        assert_roundtrips(&[
            "REPAIR TABLE events",
            "MSCK REPAIR TABLE events",
            "REPAIR TABLE events ADD PARTITIONS",
            "REPAIR TABLE events DROP PARTITIONS",
            "REPAIR TABLE events SYNC PARTITIONS",
            "REPAIR TABLE events SYNC METADATA",
        ]);
    }

    #[test]
    fn test_apply_changes_commands_roundtrip() {
        assert_roundtrips(&[
            "APPLY CHANGES INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) IGNORE NULL UPDATES SEQUENCE BY ts",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) APPLY AS DELETE WHEN operation = 'DELETE' SEQUENCE BY ts COLUMNS * EXCEPT (operation) STORED AS SCD TYPE 1",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) SEQUENCE BY ts STORED AS SCD TYPE 2 TRACK HISTORY ON * EXCEPT (updated_at)",
            "AUTO CDC INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
            "CREATE FLOW apply_cdc AS AUTO CDC INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
        ]);
    }

    #[test]
    fn test_generate_symlink_format_manifest_roundtrip() {
        assert_roundtrips(&[
            "GENERATE symlink_format_manifest FOR TABLE events",
            "GENERATE symlink_format_manifest FOR TABLE catalog.schema.events",
        ]);
    }

    #[test]
    fn test_convert_to_delta_roundtrip() {
        assert_roundtrips(&[
            "CONVERT TO DELTA parquet.`/mnt/data/events`",
            "CONVERT TO DELTA database_name.table_name",
            "CONVERT TO DELTA parquet.`s3://my-bucket/path/to/table` PARTITIONED BY (date DATE)",
            "CONVERT TO DELTA database_name.table_name NO STATISTICS",
        ]);
    }
}