Skip to main content

polyglot_sql/dialects/
databricks.rs

1//! Databricks Dialect
2//!
3//! Databricks-specific transformations based on sqlglot patterns.
4//! Databricks extends Spark SQL with additional features:
5//! - Colon operator for JSON extraction (col:path)
6//! - DATEADD/DATEDIFF with specific syntax
7//! - NULL type mapped to VOID
8//! - Native REGEXP_LIKE and TRY_CAST support
9
10use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{
13    AggregateFunction, Cast, DataType, Expression, Function, JSONExtract, Literal, UnaryFunc,
14    VarArgFunc,
15};
16use crate::generator::GeneratorConfig;
17use crate::tokens::TokenizerConfig;
18
/// Databricks SQL dialect.
///
/// A stateless unit struct: its tokenizer settings, generator settings and
/// expression rewrites are supplied by its `DialectImpl` implementation.
pub struct DatabricksDialect;
21
22impl DialectImpl for DatabricksDialect {
23    fn dialect_type(&self) -> DialectType {
24        DialectType::Databricks
25    }
26
27    fn tokenizer_config(&self) -> TokenizerConfig {
28        let mut config = TokenizerConfig::default();
29        // Databricks uses backticks for identifiers (NOT double quotes)
30        config.identifiers.clear();
31        config.identifiers.insert('`', '`');
32        // Databricks (like Hive/Spark) uses double quotes as string delimiters
33        config.quotes.insert("\"".to_string(), "\"".to_string());
34        // Databricks uses backslash escapes in strings (inherited from Hive/Spark)
35        config.string_escapes.push('\\');
36        // Databricks supports DIV keyword for integer division
37        config
38            .keywords
39            .insert("DIV".to_string(), crate::tokens::TokenType::Div);
40        config
41            .keywords
42            .insert("REPAIR".to_string(), crate::tokens::TokenType::Command);
43        config
44            .keywords
45            .insert("MSCK".to_string(), crate::tokens::TokenType::Command);
46        // Databricks numeric literal suffixes (same as Hive/Spark)
47        config
48            .numeric_literals
49            .insert("L".to_string(), "BIGINT".to_string());
50        config
51            .numeric_literals
52            .insert("S".to_string(), "SMALLINT".to_string());
53        config
54            .numeric_literals
55            .insert("Y".to_string(), "TINYINT".to_string());
56        config
57            .numeric_literals
58            .insert("D".to_string(), "DOUBLE".to_string());
59        config
60            .numeric_literals
61            .insert("F".to_string(), "FLOAT".to_string());
62        config
63            .numeric_literals
64            .insert("BD".to_string(), "DECIMAL".to_string());
65        // Databricks allows identifiers to start with digits (like Hive/Spark)
66        config.identifiers_can_start_with_digit = true;
67        // Databricks (like Spark): STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
68        // Backslashes in raw strings are always literal (no escape processing)
69        config.string_escapes_allowed_in_raw_strings = false;
70        config
71    }
72
73    fn generator_config(&self) -> GeneratorConfig {
74        use crate::generator::IdentifierQuoteStyle;
75        GeneratorConfig {
76            identifier_quote: '`',
77            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
78            dialect: Some(DialectType::Databricks),
79            struct_field_sep: ": ",
80            create_function_return_as: false,
81            tablesample_seed_keyword: "REPEATABLE",
82            identifiers_can_start_with_digit: true,
83            // Databricks uses COMMENT 'value' without = sign
84            schema_comment_with_eq: false,
85            ..Default::default()
86        }
87    }
88
89    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
90        match expr {
91            // IFNULL -> COALESCE in Databricks
92            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
93                original_name: None,
94                expressions: vec![f.this, f.expression],
95                inferred_type: None,
96            }))),
97
98            // NVL -> COALESCE in Databricks
99            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
100                original_name: None,
101                expressions: vec![f.this, f.expression],
102                inferred_type: None,
103            }))),
104
105            // TryCast is native in Databricks
106            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
107
108            // SafeCast -> TRY_CAST in Databricks
109            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
110
111            // ILIKE is native in Databricks (Spark 3+)
112            Expression::ILike(op) => Ok(Expression::ILike(op)),
113
114            // UNNEST -> EXPLODE in Databricks
115            Expression::Unnest(f) => Ok(Expression::Explode(Box::new(UnaryFunc::new(f.this)))),
116
117            // EXPLODE is native to Databricks
118            Expression::Explode(f) => Ok(Expression::Explode(f)),
119
120            // ExplodeOuter is supported
121            Expression::ExplodeOuter(f) => Ok(Expression::ExplodeOuter(f)),
122
123            // RANDOM -> RAND in Databricks
124            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
125                seed: None,
126                lower: None,
127                upper: None,
128            }))),
129
130            // Rand is native
131            Expression::Rand(r) => Ok(Expression::Rand(r)),
132
133            // || (Concat) -> CONCAT in Databricks
134            Expression::Concat(op) => Ok(Expression::Function(Box::new(Function::new(
135                "CONCAT".to_string(),
136                vec![op.left, op.right],
137            )))),
138
139            // RegexpLike is native in Databricks
140            Expression::RegexpLike(op) => Ok(Expression::RegexpLike(op)),
141
142            // Cast with typed literal: TIMESTAMP 'x'::TYPE -> CAST(CAST('x' AS TYPE) AS TIMESTAMP)
143            // This is a complex sqlglot transformation where:
144            // 1. The inner typed literal (e.g., TIMESTAMP 'x') becomes CAST('x' AS <target_type>)
145            // 2. The outer result is wrapped in CAST(... AS <original_literal_type>)
146            Expression::Cast(c) => self.transform_cast(*c),
147
148            // Generic function transformations
149            Expression::Function(f) => self.transform_function(*f),
150
151            // Generic aggregate function transformations
152            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
153
154            // DateSub -> DATE_ADD(date, -val) in Databricks
155            Expression::DateSub(f) => {
156                // Convert string literals to numbers (interval values are often stored as strings)
157                let val = match f.interval {
158                    Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::String(s) if s.parse::<i64>().is_ok()) =>
159                    {
160                        let crate::expressions::Literal::String(s) = lit.as_ref() else {
161                            unreachable!()
162                        };
163                        Expression::Literal(Box::new(crate::expressions::Literal::Number(
164                            s.clone(),
165                        )))
166                    }
167                    other => other,
168                };
169                let neg_val = Expression::Neg(Box::new(crate::expressions::UnaryOp {
170                    this: val,
171                    inferred_type: None,
172                }));
173                Ok(Expression::Function(Box::new(Function::new(
174                    "DATE_ADD".to_string(),
175                    vec![f.this, neg_val],
176                ))))
177            }
178
179            // Pass through everything else
180            _ => Ok(expr),
181        }
182    }
183}
184
185impl DatabricksDialect {
186    fn transform_function(&self, f: Function) -> Result<Expression> {
187        let name_upper = f.name.to_uppercase();
188        match name_upper.as_str() {
189            // IFNULL -> COALESCE
190            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
191                original_name: None,
192                expressions: f.args,
193                inferred_type: None,
194            }))),
195
196            // NVL -> COALESCE
197            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
198                original_name: None,
199                expressions: f.args,
200                inferred_type: None,
201            }))),
202
203            // ISNULL -> COALESCE
204            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
205                original_name: None,
206                expressions: f.args,
207                inferred_type: None,
208            }))),
209
210            // ROW -> STRUCT (no auto-naming for cross-dialect conversion)
211            "ROW" => Ok(Expression::Function(Box::new(Function::new(
212                "STRUCT".to_string(),
213                f.args,
214            )))),
215
216            // NAMED_STRUCT('a', 1) -> STRUCT(1 AS a) for SQLGlot Databricks outputs
217            "NAMED_STRUCT" if f.args.len() % 2 == 0 => {
218                let original_args = f.args.clone();
219                let mut struct_args = Vec::new();
220                for pair in f.args.chunks(2) {
221                    if let Expression::Literal(lit) = &pair[0] {
222                        if let Literal::String(field_name) = lit.as_ref() {
223                            struct_args.push(Expression::Alias(Box::new(
224                                crate::expressions::Alias {
225                                    this: pair[1].clone(),
226                                    alias: crate::expressions::Identifier::new(field_name),
227                                    column_aliases: Vec::new(),
228                                    pre_alias_comments: Vec::new(),
229                                    trailing_comments: Vec::new(),
230                                    inferred_type: None,
231                                },
232                            )));
233                            continue;
234                        }
235                    }
236                    return Ok(Expression::Function(Box::new(Function::new(
237                        "NAMED_STRUCT".to_string(),
238                        original_args,
239                    ))));
240                }
241                Ok(Expression::Function(Box::new(Function::new(
242                    "STRUCT".to_string(),
243                    struct_args,
244                ))))
245            }
246
247            // GETDATE -> CURRENT_TIMESTAMP
248            "GETDATE" => Ok(Expression::CurrentTimestamp(
249                crate::expressions::CurrentTimestamp {
250                    precision: None,
251                    sysdate: false,
252                },
253            )),
254
255            // NOW -> CURRENT_TIMESTAMP
256            "NOW" => Ok(Expression::CurrentTimestamp(
257                crate::expressions::CurrentTimestamp {
258                    precision: None,
259                    sysdate: false,
260                },
261            )),
262
263            // CURDATE -> CURRENT_DATE
264            "CURDATE" => Ok(Expression::CurrentDate(crate::expressions::CurrentDate)),
265
266            // CURRENT_DATE() with parens -> CURRENT_DATE (no parens)
267            "CURRENT_DATE" if f.args.is_empty() => {
268                Ok(Expression::CurrentDate(crate::expressions::CurrentDate))
269            }
270
271            // RANDOM -> RAND
272            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
273                seed: None,
274                lower: None,
275                upper: None,
276            }))),
277
278            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
279            "GROUP_CONCAT" if !f.args.is_empty() => {
280                let mut args = f.args;
281                let first = args.remove(0);
282                let separator = args.pop();
283                let collect_list = Expression::Function(Box::new(Function::new(
284                    "COLLECT_LIST".to_string(),
285                    vec![first],
286                )));
287                if let Some(sep) = separator {
288                    Ok(Expression::Function(Box::new(Function::new(
289                        "ARRAY_JOIN".to_string(),
290                        vec![collect_list, sep],
291                    ))))
292                } else {
293                    Ok(Expression::Function(Box::new(Function::new(
294                        "ARRAY_JOIN".to_string(),
295                        vec![collect_list],
296                    ))))
297                }
298            }
299
300            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN in Databricks
301            "STRING_AGG" if !f.args.is_empty() => {
302                let mut args = f.args;
303                let first = args.remove(0);
304                let separator = args.pop();
305                let collect_list = Expression::Function(Box::new(Function::new(
306                    "COLLECT_LIST".to_string(),
307                    vec![first],
308                )));
309                if let Some(sep) = separator {
310                    Ok(Expression::Function(Box::new(Function::new(
311                        "ARRAY_JOIN".to_string(),
312                        vec![collect_list, sep],
313                    ))))
314                } else {
315                    Ok(Expression::Function(Box::new(Function::new(
316                        "ARRAY_JOIN".to_string(),
317                        vec![collect_list],
318                    ))))
319                }
320            }
321
322            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
323            "LISTAGG" if !f.args.is_empty() => {
324                let mut args = f.args;
325                let first = args.remove(0);
326                let separator = args.pop();
327                let collect_list = Expression::Function(Box::new(Function::new(
328                    "COLLECT_LIST".to_string(),
329                    vec![first],
330                )));
331                if let Some(sep) = separator {
332                    Ok(Expression::Function(Box::new(Function::new(
333                        "ARRAY_JOIN".to_string(),
334                        vec![collect_list, sep],
335                    ))))
336                } else {
337                    Ok(Expression::Function(Box::new(Function::new(
338                        "ARRAY_JOIN".to_string(),
339                        vec![collect_list],
340                    ))))
341                }
342            }
343
344            // ARRAY_AGG -> COLLECT_LIST in Databricks
345            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
346                "COLLECT_LIST".to_string(),
347                f.args,
348            )))),
349
350            // SUBSTR -> SUBSTRING
351            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
352                "SUBSTRING".to_string(),
353                f.args,
354            )))),
355
356            // LEN -> LENGTH
357            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
358                f.args.into_iter().next().unwrap(),
359            )))),
360
361            // CHARINDEX -> LOCATE (with swapped args, like Spark)
362            "CHARINDEX" if f.args.len() >= 2 => {
363                let mut args = f.args;
364                let substring = args.remove(0);
365                let string = args.remove(0);
366                // LOCATE(substring, string)
367                Ok(Expression::Function(Box::new(Function::new(
368                    "LOCATE".to_string(),
369                    vec![substring, string],
370                ))))
371            }
372
373            // POSITION -> LOCATE
374            "POSITION" if f.args.len() == 2 => {
375                let args = f.args;
376                Ok(Expression::Function(Box::new(Function::new(
377                    "LOCATE".to_string(),
378                    args,
379                ))))
380            }
381
382            // STRPOS -> LOCATE (with same arg order)
383            "STRPOS" if f.args.len() == 2 => {
384                let args = f.args;
385                let string = args[0].clone();
386                let substring = args[1].clone();
387                // STRPOS(string, substring) -> LOCATE(substring, string)
388                Ok(Expression::Function(Box::new(Function::new(
389                    "LOCATE".to_string(),
390                    vec![substring, string],
391                ))))
392            }
393
394            // INSTR is native in Databricks
395            "INSTR" => Ok(Expression::Function(Box::new(f))),
396
397            // LOCATE is native in Databricks
398            "LOCATE" => Ok(Expression::Function(Box::new(f))),
399
400            // ARRAY_LENGTH -> SIZE
401            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
402                Function::new("SIZE".to_string(), f.args),
403            ))),
404
405            // CARDINALITY -> SIZE
406            "CARDINALITY" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
407                Function::new("SIZE".to_string(), f.args),
408            ))),
409
410            // SIZE is native
411            "SIZE" => Ok(Expression::Function(Box::new(f))),
412
413            // ARRAY_CONTAINS is native in Databricks
414            "ARRAY_CONTAINS" => Ok(Expression::Function(Box::new(f))),
415
416            // CONTAINS -> ARRAY_CONTAINS in Databricks (for array operations)
417            // But keep CONTAINS for string contains (from CONTAINS_SUBSTR transpilation)
418            "CONTAINS" if f.args.len() == 2 => {
419                // Check if this is a string CONTAINS (LOWER() args pattern from CONTAINS_SUBSTR)
420                let is_string_contains = matches!(&f.args[0], Expression::Lower(_))
421                    && matches!(&f.args[1], Expression::Lower(_));
422                if is_string_contains {
423                    Ok(Expression::Function(Box::new(f)))
424                } else {
425                    Ok(Expression::Function(Box::new(Function::new(
426                        "ARRAY_CONTAINS".to_string(),
427                        f.args,
428                    ))))
429                }
430            }
431
432            // TO_DATE is native in Databricks
433            "TO_DATE" => Ok(Expression::Function(Box::new(f))),
434
435            // TO_TIMESTAMP is native in Databricks
436            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
437
438            // DATE_FORMAT is native in Databricks
439            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
440
441            // strftime -> DATE_FORMAT in Databricks
442            "STRFTIME" if f.args.len() >= 2 => {
443                let mut args = f.args;
444                let format = args.remove(0);
445                let date = args.remove(0);
446                Ok(Expression::Function(Box::new(Function::new(
447                    "DATE_FORMAT".to_string(),
448                    vec![date, format],
449                ))))
450            }
451
452            // TO_CHAR is supported natively in Databricks (unlike Spark)
453            "TO_CHAR" => Ok(Expression::Function(Box::new(f))),
454
455            // DATE_TRUNC is native in Databricks
456            "DATE_TRUNC" => Ok(Expression::Function(Box::new(f))),
457
458            // DATEADD is native in Databricks - uppercase the unit if present
459            "DATEADD" => {
460                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
461                Ok(Expression::Function(Box::new(Function::new(
462                    "DATEADD".to_string(),
463                    transformed_args,
464                ))))
465            }
466
467            // DATE_ADD -> DATEADD in Databricks (2-arg form only)
468            // 2-arg with interval: DATE_ADD(date, interval) -> DATEADD(DAY, interval, date)
469            // 2-arg with number: DATE_ADD(date, -2) -> keep as DATE_ADD(date, -2)
470            // 3-arg: DATE_ADD(unit, amount, date) -> keep as DATE_ADD(UNIT, amount, date)
471            "DATE_ADD" => {
472                if f.args.len() == 2 {
473                    let is_simple_number = matches!(
474                        &f.args[1],
475                        Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::Number(_))
476                    ) || matches!(&f.args[1], Expression::Neg(_));
477                    if is_simple_number {
478                        // Keep as DATE_ADD(date, num_days)
479                        Ok(Expression::Function(Box::new(Function::new(
480                            "DATE_ADD".to_string(),
481                            f.args,
482                        ))))
483                    } else {
484                        let mut args = f.args;
485                        let date = args.remove(0);
486                        let interval = args.remove(0);
487                        let unit = Expression::Identifier(crate::expressions::Identifier {
488                            name: "DAY".to_string(),
489                            quoted: false,
490                            trailing_comments: Vec::new(),
491                            span: None,
492                        });
493                        Ok(Expression::Function(Box::new(Function::new(
494                            "DATEADD".to_string(),
495                            vec![unit, interval, date],
496                        ))))
497                    }
498                } else {
499                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
500                    Ok(Expression::Function(Box::new(Function::new(
501                        "DATE_ADD".to_string(),
502                        transformed_args,
503                    ))))
504                }
505            }
506
507            // DATEDIFF is native in Databricks - uppercase the unit if present
508            // 2-arg: DATEDIFF(end, start) -> DATEDIFF(DAY, start, end)
509            // 3-arg: DATEDIFF(unit, start, end) -> DATEDIFF(UNIT, start, end)
510            "DATEDIFF" => {
511                if f.args.len() == 2 {
512                    let mut args = f.args;
513                    let end_date = args.remove(0);
514                    let start_date = args.remove(0);
515                    let unit = Expression::Identifier(crate::expressions::Identifier {
516                        name: "DAY".to_string(),
517                        quoted: false,
518                        trailing_comments: Vec::new(),
519                        span: None,
520                    });
521                    Ok(Expression::Function(Box::new(Function::new(
522                        "DATEDIFF".to_string(),
523                        vec![unit, start_date, end_date],
524                    ))))
525                } else {
526                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
527                    Ok(Expression::Function(Box::new(Function::new(
528                        "DATEDIFF".to_string(),
529                        transformed_args,
530                    ))))
531                }
532            }
533
534            // DATE_DIFF -> DATEDIFF with uppercased unit
535            "DATE_DIFF" => {
536                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
537                Ok(Expression::Function(Box::new(Function::new(
538                    "DATEDIFF".to_string(),
539                    transformed_args,
540                ))))
541            }
542
543            // JSON_EXTRACT -> Use colon operator in generation, but keep as function for now
544            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
545
546            // JSON_EXTRACT_SCALAR -> same handling
547            "JSON_EXTRACT_SCALAR" => Ok(Expression::Function(Box::new(f))),
548
549            // GET_JSON_OBJECT -> colon syntax in Databricks
550            // GET_JSON_OBJECT(col, '$.path') becomes col:path
551            "GET_JSON_OBJECT" if f.args.len() == 2 => {
552                let mut args = f.args;
553                let col = args.remove(0);
554                let path_arg = args.remove(0);
555
556                // Extract and strip the $. prefix from the path
557                let path_expr = match &path_arg {
558                    Expression::Literal(lit)
559                        if matches!(lit.as_ref(), crate::expressions::Literal::String(_)) =>
560                    {
561                        let crate::expressions::Literal::String(s) = lit.as_ref() else {
562                            unreachable!()
563                        };
564                        // Strip leading '$.' if present
565                        let stripped = if s.starts_with("$.") {
566                            &s[2..]
567                        } else if s.starts_with("$") {
568                            &s[1..]
569                        } else {
570                            s.as_str()
571                        };
572                        Expression::Literal(Box::new(crate::expressions::Literal::String(
573                            stripped.to_string(),
574                        )))
575                    }
576                    _ => path_arg,
577                };
578
579                Ok(Expression::JSONExtract(Box::new(JSONExtract {
580                    this: Box::new(col),
581                    expression: Box::new(path_expr),
582                    only_json_types: None,
583                    expressions: Vec::new(),
584                    variant_extract: Some(Box::new(Expression::true_())),
585                    json_query: None,
586                    option: None,
587                    quote: None,
588                    on_condition: None,
589                    requires_json: None,
590                })))
591            }
592
593            // FROM_JSON is native in Databricks
594            "FROM_JSON" => Ok(Expression::Function(Box::new(f))),
595
596            // PARSE_JSON is native in Databricks
597            "PARSE_JSON" => Ok(Expression::Function(Box::new(f))),
598
599            // COLLECT_LIST is native in Databricks
600            "COLLECT_LIST" => Ok(Expression::Function(Box::new(f))),
601
602            // COLLECT_SET is native in Databricks
603            "COLLECT_SET" => Ok(Expression::Function(Box::new(f))),
604
605            // RLIKE is native in Databricks
606            "RLIKE" => Ok(Expression::Function(Box::new(f))),
607
608            // REGEXP -> RLIKE in Databricks
609            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
610                "RLIKE".to_string(),
611                f.args,
612            )))),
613
614            // REGEXP_LIKE is native in Databricks
615            "REGEXP_LIKE" => Ok(Expression::Function(Box::new(f))),
616
617            // LEVENSHTEIN is native in Databricks
618            "LEVENSHTEIN" => Ok(Expression::Function(Box::new(f))),
619
620            // SEQUENCE is native (for GENERATE_SERIES)
621            "GENERATE_SERIES" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
622                Function::new("SEQUENCE".to_string(), f.args),
623            ))),
624
625            // SEQUENCE is native
626            "SEQUENCE" => Ok(Expression::Function(Box::new(f))),
627
628            // FLATTEN is native in Databricks
629            "FLATTEN" => Ok(Expression::Function(Box::new(f))),
630
631            // ARRAY_SORT is native
632            "ARRAY_SORT" => Ok(Expression::Function(Box::new(f))),
633
634            // ARRAY_DISTINCT is native
635            "ARRAY_DISTINCT" => Ok(Expression::Function(Box::new(f))),
636
637            // TRANSFORM is native (for array transformation)
638            "TRANSFORM" => Ok(Expression::Function(Box::new(f))),
639
640            // FILTER is native (for array filtering)
641            "FILTER" => Ok(Expression::Function(Box::new(f))),
642
643            // FROM_UTC_TIMESTAMP - wrap first argument in CAST(... AS TIMESTAMP) if not already
644            "FROM_UTC_TIMESTAMP" if f.args.len() >= 2 => {
645                let mut args = f.args;
646                let first_arg = args.remove(0);
647
648                // Check if first arg is already a Cast to TIMESTAMP
649                let wrapped_arg = if self.is_cast_to_timestamp(&first_arg) {
650                    first_arg
651                } else {
652                    // Wrap in CAST(... AS TIMESTAMP)
653                    Expression::Cast(Box::new(Cast {
654                        this: first_arg,
655                        to: DataType::Timestamp {
656                            precision: None,
657                            timezone: false,
658                        },
659                        trailing_comments: Vec::new(),
660                        double_colon_syntax: false,
661                        format: None,
662                        default: None,
663                        inferred_type: None,
664                    }))
665                };
666
667                let mut new_args = vec![wrapped_arg];
668                new_args.extend(args);
669
670                Ok(Expression::Function(Box::new(Function::new(
671                    "FROM_UTC_TIMESTAMP".to_string(),
672                    new_args,
673                ))))
674            }
675
676            // UNIFORM(low, high, RANDOM(seed)) -> UNIFORM(low, high, seed) or UNIFORM(low, high)
677            "UNIFORM" if f.args.len() == 3 => {
678                let mut args = f.args;
679                let low = args.remove(0);
680                let high = args.remove(0);
681                let gen = args.remove(0);
682                match gen {
683                    Expression::Function(func) if func.name.to_uppercase() == "RANDOM" => {
684                        if func.args.len() == 1 {
685                            // RANDOM(seed) -> extract seed
686                            let seed = func.args.into_iter().next().unwrap();
687                            Ok(Expression::Function(Box::new(Function::new(
688                                "UNIFORM".to_string(),
689                                vec![low, high, seed],
690                            ))))
691                        } else {
692                            // RANDOM() -> drop gen arg
693                            Ok(Expression::Function(Box::new(Function::new(
694                                "UNIFORM".to_string(),
695                                vec![low, high],
696                            ))))
697                        }
698                    }
699                    Expression::Rand(r) => {
700                        if let Some(seed) = r.seed {
701                            Ok(Expression::Function(Box::new(Function::new(
702                                "UNIFORM".to_string(),
703                                vec![low, high, *seed],
704                            ))))
705                        } else {
706                            Ok(Expression::Function(Box::new(Function::new(
707                                "UNIFORM".to_string(),
708                                vec![low, high],
709                            ))))
710                        }
711                    }
712                    _ => Ok(Expression::Function(Box::new(Function::new(
713                        "UNIFORM".to_string(),
714                        vec![low, high, gen],
715                    )))),
716                }
717            }
718
719            // REGEXP_SUBSTR(subject, pattern, ...) -> REGEXP_EXTRACT(subject, pattern)
720            "REGEXP_SUBSTR" if f.args.len() >= 2 => {
721                let subject = f.args[0].clone();
722                let pattern = f.args[1].clone();
723                Ok(Expression::Function(Box::new(Function::new(
724                    "REGEXP_EXTRACT".to_string(),
725                    vec![subject, pattern],
726                ))))
727            }
728
729            // BIT_GET -> GETBIT
730            "BIT_GET" => Ok(Expression::Function(Box::new(Function::new(
731                "GETBIT".to_string(),
732                f.args,
733            )))),
734
735            // Pass through everything else
736            _ => Ok(Expression::Function(Box::new(f))),
737        }
738    }
739
740    fn transform_aggregate_function(
741        &self,
742        f: Box<crate::expressions::AggregateFunction>,
743    ) -> Result<Expression> {
744        let name_upper = f.name.to_uppercase();
745        match name_upper.as_str() {
746            // COUNT_IF is native in Databricks (Spark 3+)
747            "COUNT_IF" => Ok(Expression::AggregateFunction(f)),
748
749            // ANY_VALUE is native in Databricks (Spark 3+)
750            "ANY_VALUE" => Ok(Expression::AggregateFunction(f)),
751
752            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
753            "GROUP_CONCAT" if !f.args.is_empty() => {
754                let mut args = f.args;
755                let first = args.remove(0);
756                let separator = args.pop();
757                let collect_list = Expression::Function(Box::new(Function::new(
758                    "COLLECT_LIST".to_string(),
759                    vec![first],
760                )));
761                if let Some(sep) = separator {
762                    Ok(Expression::Function(Box::new(Function::new(
763                        "ARRAY_JOIN".to_string(),
764                        vec![collect_list, sep],
765                    ))))
766                } else {
767                    Ok(Expression::Function(Box::new(Function::new(
768                        "ARRAY_JOIN".to_string(),
769                        vec![collect_list],
770                    ))))
771                }
772            }
773
774            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN
775            "STRING_AGG" if !f.args.is_empty() => {
776                let mut args = f.args;
777                let first = args.remove(0);
778                let separator = args.pop();
779                let collect_list = Expression::Function(Box::new(Function::new(
780                    "COLLECT_LIST".to_string(),
781                    vec![first],
782                )));
783                if let Some(sep) = separator {
784                    Ok(Expression::Function(Box::new(Function::new(
785                        "ARRAY_JOIN".to_string(),
786                        vec![collect_list, sep],
787                    ))))
788                } else {
789                    Ok(Expression::Function(Box::new(Function::new(
790                        "ARRAY_JOIN".to_string(),
791                        vec![collect_list],
792                    ))))
793                }
794            }
795
796            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
797            "LISTAGG" if !f.args.is_empty() => {
798                let mut args = f.args;
799                let first = args.remove(0);
800                let separator = args.pop();
801                let collect_list = Expression::Function(Box::new(Function::new(
802                    "COLLECT_LIST".to_string(),
803                    vec![first],
804                )));
805                if let Some(sep) = separator {
806                    Ok(Expression::Function(Box::new(Function::new(
807                        "ARRAY_JOIN".to_string(),
808                        vec![collect_list, sep],
809                    ))))
810                } else {
811                    Ok(Expression::Function(Box::new(Function::new(
812                        "ARRAY_JOIN".to_string(),
813                        vec![collect_list],
814                    ))))
815                }
816            }
817
818            // ARRAY_AGG -> COLLECT_LIST
819            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
820                "COLLECT_LIST".to_string(),
821                f.args,
822            )))),
823
824            // STDDEV is native in Databricks
825            "STDDEV" => Ok(Expression::AggregateFunction(f)),
826
827            // VARIANCE is native in Databricks
828            "VARIANCE" => Ok(Expression::AggregateFunction(f)),
829
830            // APPROX_COUNT_DISTINCT is native in Databricks
831            "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
832
833            // APPROX_DISTINCT -> APPROX_COUNT_DISTINCT
834            "APPROX_DISTINCT" if !f.args.is_empty() => {
835                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
836                    name: "APPROX_COUNT_DISTINCT".to_string(),
837                    args: f.args,
838                    distinct: f.distinct,
839                    filter: f.filter,
840                    order_by: Vec::new(),
841                    limit: None,
842                    ignore_nulls: None,
843                    inferred_type: None,
844                })))
845            }
846
847            // Pass through everything else
848            _ => Ok(Expression::AggregateFunction(f)),
849        }
850    }
851
852    /// Transform Cast expressions - handles typed literals being cast
853    ///
854    /// When we have a typed literal (TIMESTAMP 'x', DATE 'x', TIME 'x') being cast to another type,
855    /// Databricks/Spark transforms it as follows:
856    ///
857    /// `TIMESTAMP 'x'::TYPE` -> `CAST(CAST('x' AS TYPE) AS TIMESTAMP)`
858    ///
859    /// This reverses the types - the inner cast is to the target type,
860    /// the outer cast is to the original literal type.
861    fn transform_cast(&self, c: Cast) -> Result<Expression> {
862        // Check if the inner expression is a typed literal
863        match &c.this {
864            // TIMESTAMP 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIMESTAMP)
865            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Timestamp(_)) => {
866                let Literal::Timestamp(value) = lit.as_ref() else {
867                    unreachable!()
868                };
869                // Create inner cast: CAST('value' AS target_type)
870                let inner_cast = Expression::Cast(Box::new(Cast {
871                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
872                    to: c.to,
873                    trailing_comments: Vec::new(),
874                    double_colon_syntax: false,
875                    format: None,
876                    default: None,
877                    inferred_type: None,
878                }));
879                // Create outer cast: CAST(inner_cast AS TIMESTAMP)
880                Ok(Expression::Cast(Box::new(Cast {
881                    this: inner_cast,
882                    to: DataType::Timestamp {
883                        precision: None,
884                        timezone: false,
885                    },
886                    trailing_comments: c.trailing_comments,
887                    double_colon_syntax: false,
888                    format: None,
889                    default: None,
890                    inferred_type: None,
891                })))
892            }
893            // DATE 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS DATE)
894            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Date(_)) => {
895                let Literal::Date(value) = lit.as_ref() else {
896                    unreachable!()
897                };
898                let inner_cast = Expression::Cast(Box::new(Cast {
899                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
900                    to: c.to,
901                    trailing_comments: Vec::new(),
902                    double_colon_syntax: false,
903                    format: None,
904                    default: None,
905                    inferred_type: None,
906                }));
907                Ok(Expression::Cast(Box::new(Cast {
908                    this: inner_cast,
909                    to: DataType::Date,
910                    trailing_comments: c.trailing_comments,
911                    double_colon_syntax: false,
912                    format: None,
913                    default: None,
914                    inferred_type: None,
915                })))
916            }
917            // TIME 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIME)
918            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Time(_)) => {
919                let Literal::Time(value) = lit.as_ref() else {
920                    unreachable!()
921                };
922                let inner_cast = Expression::Cast(Box::new(Cast {
923                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
924                    to: c.to,
925                    trailing_comments: Vec::new(),
926                    double_colon_syntax: false,
927                    format: None,
928                    default: None,
929                    inferred_type: None,
930                }));
931                Ok(Expression::Cast(Box::new(Cast {
932                    this: inner_cast,
933                    to: DataType::Time {
934                        precision: None,
935                        timezone: false,
936                    },
937                    trailing_comments: c.trailing_comments,
938                    double_colon_syntax: false,
939                    format: None,
940                    default: None,
941                    inferred_type: None,
942                })))
943            }
944            // For all other cases, pass through the Cast unchanged
945            _ => Ok(Expression::Cast(Box::new(c))),
946        }
947    }
948
949    /// Check if an expression is a CAST to TIMESTAMP
950    fn is_cast_to_timestamp(&self, expr: &Expression) -> bool {
951        if let Expression::Cast(cast) = expr {
952            matches!(cast.to, DataType::Timestamp { .. })
953        } else {
954            false
955        }
956    }
957
958    /// Helper to uppercase the first argument if it's an identifier or column (for DATEDIFF, DATEADD units)
959    fn uppercase_first_arg_if_identifier(&self, mut args: Vec<Expression>) -> Vec<Expression> {
960        use crate::expressions::Identifier;
961        if !args.is_empty() {
962            match &args[0] {
963                Expression::Identifier(id) => {
964                    args[0] = Expression::Identifier(Identifier {
965                        name: id.name.to_uppercase(),
966                        quoted: id.quoted,
967                        trailing_comments: id.trailing_comments.clone(),
968                        span: None,
969                    });
970                }
971                Expression::Var(v) => {
972                    args[0] = Expression::Identifier(Identifier {
973                        name: v.this.to_uppercase(),
974                        quoted: false,
975                        trailing_comments: Vec::new(),
976                        span: None,
977                    });
978                }
979                Expression::Column(col) if col.table.is_none() => {
980                    // Unqualified column name like "day" should be treated as a unit
981                    args[0] = Expression::Identifier(Identifier {
982                        name: col.name.name.to_uppercase(),
983                        quoted: col.name.quoted,
984                        trailing_comments: col.name.trailing_comments.clone(),
985                        span: None,
986                    });
987                }
988                _ => {}
989            }
990        }
991        args
992    }
993}
994
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dialect;

    /// Parse `sql` with the Databricks dialect, run the dialect transform on the
    /// first statement, and render the result back to SQL.
    fn transpile(sql: &str) -> String {
        let dialect = Dialect::get(DialectType::Databricks);
        let statements = dialect.parse(sql).expect("Parse failed");
        let rewritten = dialect
            .transform(statements[0].clone())
            .expect("Transform failed");
        dialect.generate(&rewritten).expect("Generate failed")
    }

    /// Parse and regenerate each statement in `cases` without transforming it,
    /// and assert the output equals the input text exactly.
    fn assert_roundtrips(cases: &[&str]) {
        let dialect = Dialect::get(DialectType::Databricks);
        for &sql in cases {
            let statements = dialect.parse(sql).expect("Parse failed");
            let regenerated = dialect.generate(&statements[0]).expect("Generate failed");
            assert_eq!(regenerated, sql);
        }
    }

    #[test]
    fn test_timestamp_literal_cast() {
        // Identity fixture [47]: a TIMESTAMP literal cast to DATE is emitted as
        // nested CASTs with the two types swapped.
        assert_eq!(
            transpile("SELECT TIMESTAMP '2025-04-29 18.47.18'::DATE"),
            "SELECT CAST(CAST('2025-04-29 18.47.18' AS DATE) AS TIMESTAMP)",
            "Timestamp literal cast transformation failed"
        );
    }

    #[test]
    fn test_from_utc_timestamp_wraps_column() {
        // Fixture [48]: a bare column argument to FROM_UTC_TIMESTAMP gets wrapped
        // in CAST(... AS TIMESTAMP).
        assert_eq!(
            transpile("SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(foo, 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t"),
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "FROM_UTC_TIMESTAMP transformation failed"
        );
    }

    #[test]
    fn test_from_utc_timestamp_keeps_existing_cast() {
        // Fixture [50]: an argument already cast to TIMESTAMP is kept (not
        // double-wrapped), but `::` is rendered as CAST().
        assert_eq!(
            transpile("FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz)"),
            "FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)",
            "FROM_UTC_TIMESTAMP with existing CAST failed"
        );
    }

    #[test]
    fn test_deep_clone_version_as_of() {
        assert_roundtrips(&["CREATE TABLE events_clone DEEP CLONE events VERSION AS OF 5"]);
    }

    #[test]
    fn test_deep_clone_timestamp_as_of() {
        assert_roundtrips(&[
            "CREATE TABLE events_clone DEEP CLONE events TIMESTAMP AS OF '2024-01-01'",
        ]);
    }

    #[test]
    fn test_shallow_clone_still_roundtrips() {
        assert_roundtrips(&["CREATE TABLE events_clone SHALLOW CLONE events"]);
    }

    #[test]
    fn test_repair_table_commands_roundtrip() {
        assert_roundtrips(&[
            "REPAIR TABLE events",
            "MSCK REPAIR TABLE events",
            "REPAIR TABLE events ADD PARTITIONS",
            "REPAIR TABLE events DROP PARTITIONS",
            "REPAIR TABLE events SYNC PARTITIONS",
            "REPAIR TABLE events SYNC METADATA",
        ]);
    }

    #[test]
    fn test_apply_changes_commands_roundtrip() {
        assert_roundtrips(&[
            "APPLY CHANGES INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) IGNORE NULL UPDATES SEQUENCE BY ts",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) APPLY AS DELETE WHEN operation = 'DELETE' SEQUENCE BY ts COLUMNS * EXCEPT (operation) STORED AS SCD TYPE 1",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) SEQUENCE BY ts STORED AS SCD TYPE 2 TRACK HISTORY ON * EXCEPT (updated_at)",
            "AUTO CDC INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
            "CREATE FLOW apply_cdc AS AUTO CDC INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
        ]);
    }

    #[test]
    fn test_generate_symlink_format_manifest_roundtrip() {
        assert_roundtrips(&[
            "GENERATE symlink_format_manifest FOR TABLE events",
            "GENERATE symlink_format_manifest FOR TABLE catalog.schema.events",
        ]);
    }

    #[test]
    fn test_convert_to_delta_roundtrip() {
        assert_roundtrips(&[
            "CONVERT TO DELTA parquet.`/mnt/data/events`",
            "CONVERT TO DELTA database_name.table_name",
            "CONVERT TO DELTA parquet.`s3://my-bucket/path/to/table` PARTITIONED BY (date DATE)",
            "CONVERT TO DELTA database_name.table_name NO STATISTICS",
        ]);
    }
}