Skip to main content

polyglot_sql/dialects/
databricks.rs

1//! Databricks Dialect
2//!
3//! Databricks-specific transformations based on sqlglot patterns.
4//! Databricks extends Spark SQL with additional features:
5//! - Colon operator for JSON extraction (col:path)
6//! - DATEADD/DATEDIFF with specific syntax
7//! - NULL type mapped to VOID
8//! - Native REGEXP_LIKE and TRY_CAST support
9
10use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{
13    AggregateFunction, Cast, DataType, DateAddFunc, Expression, Function, IntervalUnit, Literal,
14    UnaryFunc, VarArgFunc,
15};
16#[cfg(feature = "generate")]
17use crate::generator::GeneratorConfig;
18use crate::tokens::TokenizerConfig;
19
/// Databricks dialect
///
/// Field-less marker type: it carries no state of its own. Its behavior comes
/// from the `DialectImpl` implementation in this file (tokenizer settings,
/// generator settings behind the `generate` feature, and expression rewriting
/// behind the `transpile` feature).
pub struct DatabricksDialect;
22
23impl DialectImpl for DatabricksDialect {
24    fn dialect_type(&self) -> DialectType {
25        DialectType::Databricks
26    }
27
28    fn tokenizer_config(&self) -> TokenizerConfig {
29        let mut config = TokenizerConfig::default();
30        // Databricks uses backticks for identifiers (NOT double quotes)
31        config.identifiers.clear();
32        config.identifiers.insert('`', '`');
33        // Databricks (like Hive/Spark) uses double quotes as string delimiters
34        config.quotes.insert("\"".to_string(), "\"".to_string());
35        // Databricks uses backslash escapes in strings (inherited from Hive/Spark)
36        config.string_escapes.push('\\');
37        // Databricks supports DIV keyword for integer division
38        config
39            .keywords
40            .insert("DIV".to_string(), crate::tokens::TokenType::Div);
41        config
42            .keywords
43            .insert("REPAIR".to_string(), crate::tokens::TokenType::Command);
44        config
45            .keywords
46            .insert("MSCK".to_string(), crate::tokens::TokenType::Command);
47        // Databricks numeric literal suffixes (same as Hive/Spark)
48        config
49            .numeric_literals
50            .insert("L".to_string(), "BIGINT".to_string());
51        config
52            .numeric_literals
53            .insert("S".to_string(), "SMALLINT".to_string());
54        config
55            .numeric_literals
56            .insert("Y".to_string(), "TINYINT".to_string());
57        config
58            .numeric_literals
59            .insert("D".to_string(), "DOUBLE".to_string());
60        config
61            .numeric_literals
62            .insert("F".to_string(), "FLOAT".to_string());
63        config
64            .numeric_literals
65            .insert("BD".to_string(), "DECIMAL".to_string());
66        // Databricks allows identifiers to start with digits (like Hive/Spark)
67        config.identifiers_can_start_with_digit = true;
68        // Databricks (like Spark): STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
69        // Backslashes in raw strings are always literal (no escape processing)
70        config.string_escapes_allowed_in_raw_strings = false;
71        config
72    }
73
74    #[cfg(feature = "generate")]
75
76    fn generator_config(&self) -> GeneratorConfig {
77        use crate::generator::IdentifierQuoteStyle;
78        GeneratorConfig {
79            identifier_quote: '`',
80            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
81            dialect: Some(DialectType::Databricks),
82            struct_field_sep: ": ",
83            create_function_return_as: false,
84            tablesample_seed_keyword: "REPEATABLE",
85            identifiers_can_start_with_digit: true,
86            // Databricks uses COMMENT 'value' without = sign
87            schema_comment_with_eq: false,
88            ..Default::default()
89        }
90    }
91
92    #[cfg(feature = "transpile")]
93
94    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
95        match expr {
96            // IFNULL -> COALESCE in Databricks
97            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
98                original_name: None,
99                expressions: vec![f.this, f.expression],
100                inferred_type: None,
101            }))),
102
103            // NVL -> COALESCE in Databricks
104            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
105                original_name: None,
106                expressions: vec![f.this, f.expression],
107                inferred_type: None,
108            }))),
109
110            // TryCast is native in Databricks
111            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
112
113            // SafeCast -> TRY_CAST in Databricks
114            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
115
116            // ILIKE is native in Databricks (Spark 3+)
117            Expression::ILike(op) => Ok(Expression::ILike(op)),
118
119            // UNNEST -> EXPLODE in Databricks
120            Expression::Unnest(f) => Ok(Expression::Explode(Box::new(UnaryFunc::new(f.this)))),
121
122            // EXPLODE is native to Databricks
123            Expression::Explode(f) => Ok(Expression::Explode(f)),
124
125            // ExplodeOuter is supported
126            Expression::ExplodeOuter(f) => Ok(Expression::ExplodeOuter(f)),
127
128            // RANDOM -> RAND in Databricks
129            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
130                seed: None,
131                lower: None,
132                upper: None,
133            }))),
134
135            // Rand is native
136            Expression::Rand(r) => Ok(Expression::Rand(r)),
137
138            // || (Concat) -> CONCAT in Databricks
139            Expression::Concat(op) => Ok(Expression::Function(Box::new(Function::new(
140                "CONCAT".to_string(),
141                vec![op.left, op.right],
142            )))),
143
144            // RegexpLike is native in Databricks
145            Expression::RegexpLike(op) => Ok(Expression::RegexpLike(op)),
146
147            // Cast with typed literal: TIMESTAMP 'x'::TYPE -> CAST(CAST('x' AS TYPE) AS TIMESTAMP)
148            // This is a complex sqlglot transformation where:
149            // 1. The inner typed literal (e.g., TIMESTAMP 'x') becomes CAST('x' AS <target_type>)
150            // 2. The outer result is wrapped in CAST(... AS <original_literal_type>)
151            Expression::Cast(c) => self.transform_cast(*c),
152
153            // Generic function transformations
154            Expression::Function(f) => self.transform_function(*f),
155
156            // Generic aggregate function transformations
157            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
158
159            // DateSub -> DATE_ADD(date, -val) in Databricks
160            Expression::DateSub(f) => {
161                // Convert string literals to numbers (interval values are often stored as strings)
162                let val = match f.interval {
163                    Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::String(s) if s.parse::<i64>().is_ok()) =>
164                    {
165                        let crate::expressions::Literal::String(s) = lit.as_ref() else {
166                            unreachable!()
167                        };
168                        Expression::Literal(Box::new(crate::expressions::Literal::Number(
169                            s.clone(),
170                        )))
171                    }
172                    other => other,
173                };
174                let neg_val = Expression::Neg(Box::new(crate::expressions::UnaryOp {
175                    this: val,
176                    inferred_type: None,
177                }));
178                Ok(Expression::Function(Box::new(Function::new(
179                    "DATE_ADD".to_string(),
180                    vec![f.this, neg_val],
181                ))))
182            }
183
184            // DateAdd -> native Databricks DATE_ADD forms.
185            Expression::DateAdd(f) => Ok(Self::transform_date_add(*f)),
186
187            // Pass through everything else
188            _ => Ok(expr),
189        }
190    }
191}
192
193#[cfg(feature = "transpile")]
194impl DatabricksDialect {
195    fn transform_function(&self, f: Function) -> Result<Expression> {
196        let name_upper = f.name.to_uppercase();
197        match name_upper.as_str() {
198            // IFNULL -> COALESCE
199            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
200                original_name: None,
201                expressions: f.args,
202                inferred_type: None,
203            }))),
204
205            // NVL -> COALESCE
206            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
207                original_name: None,
208                expressions: f.args,
209                inferred_type: None,
210            }))),
211
212            // ISNULL -> COALESCE
213            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
214                original_name: None,
215                expressions: f.args,
216                inferred_type: None,
217            }))),
218
219            // ROW -> STRUCT (no auto-naming for cross-dialect conversion)
220            "ROW" => Ok(Expression::Function(Box::new(Function::new(
221                "STRUCT".to_string(),
222                f.args,
223            )))),
224
225            // NAMED_STRUCT('a', 1) -> STRUCT(1 AS a) for SQLGlot Databricks outputs
226            "NAMED_STRUCT" if f.args.len() % 2 == 0 => {
227                let original_args = f.args.clone();
228                let mut struct_args = Vec::new();
229                for pair in f.args.chunks(2) {
230                    if let Expression::Literal(lit) = &pair[0] {
231                        if let Literal::String(field_name) = lit.as_ref() {
232                            struct_args.push(Expression::Alias(Box::new(
233                                crate::expressions::Alias {
234                                    this: pair[1].clone(),
235                                    alias: crate::expressions::Identifier::new(field_name),
236                                    column_aliases: Vec::new(),
237                                    alias_explicit_as: false,
238                                    alias_keyword: None,
239                                    pre_alias_comments: Vec::new(),
240                                    trailing_comments: Vec::new(),
241                                    inferred_type: None,
242                                },
243                            )));
244                            continue;
245                        }
246                    }
247                    return Ok(Expression::Function(Box::new(Function::new(
248                        "NAMED_STRUCT".to_string(),
249                        original_args,
250                    ))));
251                }
252                Ok(Expression::Function(Box::new(Function::new(
253                    "STRUCT".to_string(),
254                    struct_args,
255                ))))
256            }
257
258            // GETDATE -> CURRENT_TIMESTAMP
259            "GETDATE" => Ok(Expression::CurrentTimestamp(
260                crate::expressions::CurrentTimestamp {
261                    precision: None,
262                    sysdate: false,
263                },
264            )),
265
266            // NOW -> CURRENT_TIMESTAMP
267            "NOW" => Ok(Expression::CurrentTimestamp(
268                crate::expressions::CurrentTimestamp {
269                    precision: None,
270                    sysdate: false,
271                },
272            )),
273
274            // CURDATE -> CURRENT_DATE
275            "CURDATE" => Ok(Expression::CurrentDate(crate::expressions::CurrentDate)),
276
277            // CURRENT_DATE() with parens -> CURRENT_DATE (no parens)
278            "CURRENT_DATE" if f.args.is_empty() => {
279                Ok(Expression::CurrentDate(crate::expressions::CurrentDate))
280            }
281
282            // RANDOM -> RAND
283            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
284                seed: None,
285                lower: None,
286                upper: None,
287            }))),
288
289            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
290            "GROUP_CONCAT" if !f.args.is_empty() => {
291                let mut args = f.args;
292                let first = args.remove(0);
293                let separator = args.pop();
294                let collect_list = Expression::Function(Box::new(Function::new(
295                    "COLLECT_LIST".to_string(),
296                    vec![first],
297                )));
298                if let Some(sep) = separator {
299                    Ok(Expression::Function(Box::new(Function::new(
300                        "ARRAY_JOIN".to_string(),
301                        vec![collect_list, sep],
302                    ))))
303                } else {
304                    Ok(Expression::Function(Box::new(Function::new(
305                        "ARRAY_JOIN".to_string(),
306                        vec![collect_list],
307                    ))))
308                }
309            }
310
311            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN in Databricks
312            "STRING_AGG" if !f.args.is_empty() => {
313                let mut args = f.args;
314                let first = args.remove(0);
315                let separator = args.pop();
316                let collect_list = Expression::Function(Box::new(Function::new(
317                    "COLLECT_LIST".to_string(),
318                    vec![first],
319                )));
320                if let Some(sep) = separator {
321                    Ok(Expression::Function(Box::new(Function::new(
322                        "ARRAY_JOIN".to_string(),
323                        vec![collect_list, sep],
324                    ))))
325                } else {
326                    Ok(Expression::Function(Box::new(Function::new(
327                        "ARRAY_JOIN".to_string(),
328                        vec![collect_list],
329                    ))))
330                }
331            }
332
333            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
334            "LISTAGG" if !f.args.is_empty() => {
335                let mut args = f.args;
336                let first = args.remove(0);
337                let separator = args.pop();
338                let collect_list = Expression::Function(Box::new(Function::new(
339                    "COLLECT_LIST".to_string(),
340                    vec![first],
341                )));
342                if let Some(sep) = separator {
343                    Ok(Expression::Function(Box::new(Function::new(
344                        "ARRAY_JOIN".to_string(),
345                        vec![collect_list, sep],
346                    ))))
347                } else {
348                    Ok(Expression::Function(Box::new(Function::new(
349                        "ARRAY_JOIN".to_string(),
350                        vec![collect_list],
351                    ))))
352                }
353            }
354
355            // ARRAY_AGG -> COLLECT_LIST in Databricks
356            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
357                "COLLECT_LIST".to_string(),
358                f.args,
359            )))),
360
361            // SUBSTR -> SUBSTRING
362            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
363                "SUBSTRING".to_string(),
364                f.args,
365            )))),
366
367            // LEN -> LENGTH
368            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
369                f.args.into_iter().next().unwrap(),
370            )))),
371
372            // CHARINDEX -> LOCATE (with swapped args, like Spark)
373            "CHARINDEX" if f.args.len() >= 2 => {
374                let mut args = f.args;
375                let substring = args.remove(0);
376                let string = args.remove(0);
377                // LOCATE(substring, string)
378                Ok(Expression::Function(Box::new(Function::new(
379                    "LOCATE".to_string(),
380                    vec![substring, string],
381                ))))
382            }
383
384            // POSITION -> LOCATE
385            "POSITION" if f.args.len() == 2 => {
386                let args = f.args;
387                Ok(Expression::Function(Box::new(Function::new(
388                    "LOCATE".to_string(),
389                    args,
390                ))))
391            }
392
393            // STRPOS -> LOCATE (with same arg order)
394            "STRPOS" if f.args.len() == 2 => {
395                let args = f.args;
396                let string = args[0].clone();
397                let substring = args[1].clone();
398                // STRPOS(string, substring) -> LOCATE(substring, string)
399                Ok(Expression::Function(Box::new(Function::new(
400                    "LOCATE".to_string(),
401                    vec![substring, string],
402                ))))
403            }
404
405            // INSTR is native in Databricks
406            "INSTR" => Ok(Expression::Function(Box::new(f))),
407
408            // LOCATE is native in Databricks
409            "LOCATE" => Ok(Expression::Function(Box::new(f))),
410
411            // ARRAY_LENGTH -> SIZE
412            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
413                Function::new("SIZE".to_string(), f.args),
414            ))),
415
416            // CARDINALITY -> SIZE
417            "CARDINALITY" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
418                Function::new("SIZE".to_string(), f.args),
419            ))),
420
421            // SIZE is native
422            "SIZE" => Ok(Expression::Function(Box::new(f))),
423
424            // ARRAY_CONTAINS is native in Databricks
425            "ARRAY_CONTAINS" => Ok(Expression::Function(Box::new(f))),
426
427            // CONTAINS -> ARRAY_CONTAINS in Databricks (for array operations)
428            // But keep CONTAINS for string contains (from CONTAINS_SUBSTR transpilation)
429            "CONTAINS" if f.args.len() == 2 => {
430                // Check if this is a string CONTAINS (LOWER() args pattern from CONTAINS_SUBSTR)
431                let is_string_contains = matches!(&f.args[0], Expression::Lower(_))
432                    && matches!(&f.args[1], Expression::Lower(_));
433                if is_string_contains {
434                    Ok(Expression::Function(Box::new(f)))
435                } else {
436                    Ok(Expression::Function(Box::new(Function::new(
437                        "ARRAY_CONTAINS".to_string(),
438                        f.args,
439                    ))))
440                }
441            }
442
443            // TO_DATE is native in Databricks
444            "TO_DATE" => Ok(Expression::Function(Box::new(f))),
445
446            // TO_TIMESTAMP is native in Databricks
447            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
448
449            // DATE_FORMAT is native in Databricks
450            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
451
452            // strftime -> DATE_FORMAT in Databricks
453            "STRFTIME" if f.args.len() >= 2 => {
454                let mut args = f.args;
455                let format = args.remove(0);
456                let date = args.remove(0);
457                Ok(Expression::Function(Box::new(Function::new(
458                    "DATE_FORMAT".to_string(),
459                    vec![date, format],
460                ))))
461            }
462
463            // TO_CHAR is supported natively in Databricks (unlike Spark)
464            "TO_CHAR" => Ok(Expression::Function(Box::new(f))),
465
466            // DATE_TRUNC is native in Databricks
467            "DATE_TRUNC" => Ok(Expression::Function(Box::new(f))),
468
469            // DATEADD normalizes to Databricks DATE_ADD with an uppercased unit.
470            "DATEADD" => {
471                if f.args.len() == 2 {
472                    Ok(Expression::Function(Box::new(Function::new(
473                        "DATE_ADD".to_string(),
474                        f.args,
475                    ))))
476                } else {
477                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
478                    let function_name = if matches!(
479                        transformed_args.first(),
480                        Some(Expression::Identifier(unit))
481                            if unit.name.eq_ignore_ascii_case("WEEK")
482                    ) {
483                        "DATEADD"
484                    } else {
485                        "DATE_ADD"
486                    };
487                    Ok(Expression::Function(Box::new(Function::new(
488                        function_name.to_string(),
489                        transformed_args,
490                    ))))
491                }
492            }
493
494            // DATE_ADD -> DATEADD in Databricks (2-arg form only)
495            // 2-arg with interval: DATE_ADD(date, interval) -> DATEADD(DAY, interval, date)
496            // 2-arg with number: DATE_ADD(date, -2) -> keep as DATE_ADD(date, -2)
497            // 3-arg: DATE_ADD(unit, amount, date) -> keep as DATE_ADD(UNIT, amount, date)
498            "DATE_ADD" => {
499                if f.args.len() == 2 {
500                    let is_simple_number = matches!(
501                        &f.args[1],
502                        Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::Number(_))
503                    ) || matches!(&f.args[1], Expression::Neg(_));
504                    if is_simple_number {
505                        // Keep as DATE_ADD(date, num_days)
506                        Ok(Expression::Function(Box::new(Function::new(
507                            "DATE_ADD".to_string(),
508                            f.args,
509                        ))))
510                    } else {
511                        let mut args = f.args;
512                        let date = args.remove(0);
513                        let interval = args.remove(0);
514                        let unit = Expression::Identifier(crate::expressions::Identifier {
515                            name: "DAY".to_string(),
516                            quoted: false,
517                            trailing_comments: Vec::new(),
518                            span: None,
519                        });
520                        Ok(Expression::Function(Box::new(Function::new(
521                            "DATEADD".to_string(),
522                            vec![unit, interval, date],
523                        ))))
524                    }
525                } else {
526                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
527                    Ok(Expression::Function(Box::new(Function::new(
528                        "DATE_ADD".to_string(),
529                        transformed_args,
530                    ))))
531                }
532            }
533
534            // DATEDIFF is native in Databricks - uppercase the unit if present
535            // 2-arg: DATEDIFF(end, start) -> DATEDIFF(DAY, start, end)
536            // 3-arg: DATEDIFF(unit, start, end) -> DATEDIFF(UNIT, start, end)
537            "DATEDIFF" => {
538                if f.args.len() == 2 {
539                    let mut args = f.args;
540                    let end_date = args.remove(0);
541                    let start_date = args.remove(0);
542                    let unit = Expression::Identifier(crate::expressions::Identifier {
543                        name: "DAY".to_string(),
544                        quoted: false,
545                        trailing_comments: Vec::new(),
546                        span: None,
547                    });
548                    Ok(Expression::Function(Box::new(Function::new(
549                        "DATEDIFF".to_string(),
550                        vec![unit, start_date, end_date],
551                    ))))
552                } else {
553                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
554                    Ok(Expression::Function(Box::new(Function::new(
555                        "DATEDIFF".to_string(),
556                        transformed_args,
557                    ))))
558                }
559            }
560
561            // DATE_DIFF -> DATEDIFF with uppercased unit
562            "DATE_DIFF" => {
563                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
564                Ok(Expression::Function(Box::new(Function::new(
565                    "DATEDIFF".to_string(),
566                    transformed_args,
567                ))))
568            }
569
570            // JSON_EXTRACT -> Use colon operator in generation, but keep as function for now
571            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
572
573            // JSON_EXTRACT_SCALAR -> same handling
574            "JSON_EXTRACT_SCALAR" => Ok(Expression::Function(Box::new(f))),
575
576            // GET_JSON_OBJECT is native in Databricks.
577            "GET_JSON_OBJECT" if f.args.len() == 2 => {
578                let mut args = f.args;
579                let col = args.remove(0);
580                let path_arg = args.remove(0);
581
582                Ok(Expression::Function(Box::new(Function::new(
583                    "GET_JSON_OBJECT".to_string(),
584                    vec![col, self.normalize_get_json_object_path(path_arg)],
585                ))))
586            }
587
588            // FROM_JSON is native in Databricks
589            "FROM_JSON" => Ok(Expression::Function(Box::new(f))),
590
591            // PARSE_JSON is native in Databricks
592            "PARSE_JSON" => Ok(Expression::Function(Box::new(f))),
593
594            // COLLECT_LIST is native in Databricks
595            "COLLECT_LIST" => Ok(Expression::Function(Box::new(f))),
596
597            // COLLECT_SET is native in Databricks
598            "COLLECT_SET" => Ok(Expression::Function(Box::new(f))),
599
600            // RLIKE is native in Databricks
601            "RLIKE" => Ok(Expression::Function(Box::new(f))),
602
603            // REGEXP -> RLIKE in Databricks
604            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
605                "RLIKE".to_string(),
606                f.args,
607            )))),
608
609            // REGEXP_LIKE is native in Databricks
610            "REGEXP_LIKE" => Ok(Expression::Function(Box::new(f))),
611
612            // LEVENSHTEIN is native in Databricks
613            "LEVENSHTEIN" => Ok(Expression::Function(Box::new(f))),
614
615            // SEQUENCE is native (for GENERATE_SERIES)
616            "GENERATE_SERIES" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
617                Function::new("SEQUENCE".to_string(), f.args),
618            ))),
619
620            // SEQUENCE is native
621            "SEQUENCE" => Ok(Expression::Function(Box::new(f))),
622
623            // FLATTEN is native in Databricks
624            "FLATTEN" => Ok(Expression::Function(Box::new(f))),
625
626            // ARRAY_SORT is native
627            "ARRAY_SORT" => Ok(Expression::Function(Box::new(f))),
628
629            // ARRAY_DISTINCT is native
630            "ARRAY_DISTINCT" => Ok(Expression::Function(Box::new(f))),
631
632            // TRANSFORM is native (for array transformation)
633            "TRANSFORM" => Ok(Expression::Function(Box::new(f))),
634
635            // FILTER is native (for array filtering)
636            "FILTER" => Ok(Expression::Function(Box::new(f))),
637
638            // FROM_UTC_TIMESTAMP - wrap first argument in CAST(... AS TIMESTAMP) if not already
639            "FROM_UTC_TIMESTAMP" if f.args.len() >= 2 => {
640                let mut args = f.args;
641                let first_arg = args.remove(0);
642
643                // Check if first arg is already a Cast to TIMESTAMP
644                let wrapped_arg = if self.is_cast_to_timestamp(&first_arg) {
645                    first_arg
646                } else {
647                    // Wrap in CAST(... AS TIMESTAMP)
648                    Expression::Cast(Box::new(Cast {
649                        this: first_arg,
650                        to: DataType::Timestamp {
651                            precision: None,
652                            timezone: false,
653                        },
654                        trailing_comments: Vec::new(),
655                        double_colon_syntax: false,
656                        format: None,
657                        default: None,
658                        inferred_type: None,
659                    }))
660                };
661
662                let mut new_args = vec![wrapped_arg];
663                new_args.extend(args);
664
665                Ok(Expression::Function(Box::new(Function::new(
666                    "FROM_UTC_TIMESTAMP".to_string(),
667                    new_args,
668                ))))
669            }
670
671            // UNIFORM(low, high, RANDOM(seed)) -> UNIFORM(low, high, seed) or UNIFORM(low, high)
672            "UNIFORM" if f.args.len() == 3 => {
673                let mut args = f.args;
674                let low = args.remove(0);
675                let high = args.remove(0);
676                let gen = args.remove(0);
677                match gen {
678                    Expression::Function(func) if func.name.to_uppercase() == "RANDOM" => {
679                        if func.args.len() == 1 {
680                            // RANDOM(seed) -> extract seed
681                            let seed = func.args.into_iter().next().unwrap();
682                            Ok(Expression::Function(Box::new(Function::new(
683                                "UNIFORM".to_string(),
684                                vec![low, high, seed],
685                            ))))
686                        } else {
687                            // RANDOM() -> drop gen arg
688                            Ok(Expression::Function(Box::new(Function::new(
689                                "UNIFORM".to_string(),
690                                vec![low, high],
691                            ))))
692                        }
693                    }
694                    Expression::Rand(r) => {
695                        if let Some(seed) = r.seed {
696                            Ok(Expression::Function(Box::new(Function::new(
697                                "UNIFORM".to_string(),
698                                vec![low, high, *seed],
699                            ))))
700                        } else {
701                            Ok(Expression::Function(Box::new(Function::new(
702                                "UNIFORM".to_string(),
703                                vec![low, high],
704                            ))))
705                        }
706                    }
707                    _ => Ok(Expression::Function(Box::new(Function::new(
708                        "UNIFORM".to_string(),
709                        vec![low, high, gen],
710                    )))),
711                }
712            }
713
714            // REGEXP_SUBSTR(subject, pattern, ...) -> REGEXP_EXTRACT(subject, pattern)
715            "REGEXP_SUBSTR" if f.args.len() >= 2 => {
716                let subject = f.args[0].clone();
717                let pattern = f.args[1].clone();
718                Ok(Expression::Function(Box::new(Function::new(
719                    "REGEXP_EXTRACT".to_string(),
720                    vec![subject, pattern],
721                ))))
722            }
723
724            // BIT_GET -> GETBIT
725            "BIT_GET" => Ok(Expression::Function(Box::new(Function::new(
726                "GETBIT".to_string(),
727                f.args,
728            )))),
729
730            // Pass through everything else
731            _ => Ok(Expression::Function(Box::new(f))),
732        }
733    }
734
735    fn transform_aggregate_function(
736        &self,
737        f: Box<crate::expressions::AggregateFunction>,
738    ) -> Result<Expression> {
739        let name_upper = f.name.to_uppercase();
740        match name_upper.as_str() {
741            // COUNT_IF is native in Databricks (Spark 3+)
742            "COUNT_IF" => Ok(Expression::AggregateFunction(f)),
743
744            // ANY_VALUE is native in Databricks (Spark 3+)
745            "ANY_VALUE" => Ok(Expression::AggregateFunction(f)),
746
747            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
748            "GROUP_CONCAT" if !f.args.is_empty() => {
749                let mut args = f.args;
750                let first = args.remove(0);
751                let separator = args.pop();
752                let collect_list = Expression::Function(Box::new(Function::new(
753                    "COLLECT_LIST".to_string(),
754                    vec![first],
755                )));
756                if let Some(sep) = separator {
757                    Ok(Expression::Function(Box::new(Function::new(
758                        "ARRAY_JOIN".to_string(),
759                        vec![collect_list, sep],
760                    ))))
761                } else {
762                    Ok(Expression::Function(Box::new(Function::new(
763                        "ARRAY_JOIN".to_string(),
764                        vec![collect_list],
765                    ))))
766                }
767            }
768
769            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN
770            "STRING_AGG" if !f.args.is_empty() => {
771                let mut args = f.args;
772                let first = args.remove(0);
773                let separator = args.pop();
774                let collect_list = Expression::Function(Box::new(Function::new(
775                    "COLLECT_LIST".to_string(),
776                    vec![first],
777                )));
778                if let Some(sep) = separator {
779                    Ok(Expression::Function(Box::new(Function::new(
780                        "ARRAY_JOIN".to_string(),
781                        vec![collect_list, sep],
782                    ))))
783                } else {
784                    Ok(Expression::Function(Box::new(Function::new(
785                        "ARRAY_JOIN".to_string(),
786                        vec![collect_list],
787                    ))))
788                }
789            }
790
791            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
792            "LISTAGG" if !f.args.is_empty() => {
793                let mut args = f.args;
794                let first = args.remove(0);
795                let separator = args.pop();
796                let collect_list = Expression::Function(Box::new(Function::new(
797                    "COLLECT_LIST".to_string(),
798                    vec![first],
799                )));
800                if let Some(sep) = separator {
801                    Ok(Expression::Function(Box::new(Function::new(
802                        "ARRAY_JOIN".to_string(),
803                        vec![collect_list, sep],
804                    ))))
805                } else {
806                    Ok(Expression::Function(Box::new(Function::new(
807                        "ARRAY_JOIN".to_string(),
808                        vec![collect_list],
809                    ))))
810                }
811            }
812
813            // ARRAY_AGG -> COLLECT_LIST
814            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
815                "COLLECT_LIST".to_string(),
816                f.args,
817            )))),
818
819            // STDDEV is native in Databricks
820            "STDDEV" => Ok(Expression::AggregateFunction(f)),
821
822            // VARIANCE is native in Databricks
823            "VARIANCE" => Ok(Expression::AggregateFunction(f)),
824
825            // APPROX_COUNT_DISTINCT is native in Databricks
826            "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
827
828            // APPROX_DISTINCT -> APPROX_COUNT_DISTINCT
829            "APPROX_DISTINCT" if !f.args.is_empty() => {
830                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
831                    name: "APPROX_COUNT_DISTINCT".to_string(),
832                    args: f.args,
833                    distinct: f.distinct,
834                    filter: f.filter,
835                    order_by: Vec::new(),
836                    limit: None,
837                    ignore_nulls: None,
838                    inferred_type: None,
839                })))
840            }
841
842            // Pass through everything else
843            _ => Ok(Expression::AggregateFunction(f)),
844        }
845    }
846
847    /// Transform Cast expressions - handles typed literals being cast
848    ///
849    /// When we have a typed literal (TIMESTAMP 'x', DATE 'x', TIME 'x') being cast to another type,
850    /// Databricks/Spark transforms it as follows:
851    ///
852    /// `TIMESTAMP 'x'::TYPE` -> `CAST(CAST('x' AS TYPE) AS TIMESTAMP)`
853    ///
854    /// This reverses the types - the inner cast is to the target type,
855    /// the outer cast is to the original literal type.
856    fn transform_cast(&self, c: Cast) -> Result<Expression> {
857        // Check if the inner expression is a typed literal
858        match &c.this {
859            // TIMESTAMP 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIMESTAMP)
860            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Timestamp(_)) => {
861                let Literal::Timestamp(value) = lit.as_ref() else {
862                    unreachable!()
863                };
864                // Create inner cast: CAST('value' AS target_type)
865                let inner_cast = Expression::Cast(Box::new(Cast {
866                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
867                    to: c.to,
868                    trailing_comments: Vec::new(),
869                    double_colon_syntax: false,
870                    format: None,
871                    default: None,
872                    inferred_type: None,
873                }));
874                // Create outer cast: CAST(inner_cast AS TIMESTAMP)
875                Ok(Expression::Cast(Box::new(Cast {
876                    this: inner_cast,
877                    to: DataType::Timestamp {
878                        precision: None,
879                        timezone: false,
880                    },
881                    trailing_comments: c.trailing_comments,
882                    double_colon_syntax: false,
883                    format: None,
884                    default: None,
885                    inferred_type: None,
886                })))
887            }
888            // DATE 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS DATE)
889            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Date(_)) => {
890                let Literal::Date(value) = lit.as_ref() else {
891                    unreachable!()
892                };
893                let inner_cast = Expression::Cast(Box::new(Cast {
894                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
895                    to: c.to,
896                    trailing_comments: Vec::new(),
897                    double_colon_syntax: false,
898                    format: None,
899                    default: None,
900                    inferred_type: None,
901                }));
902                Ok(Expression::Cast(Box::new(Cast {
903                    this: inner_cast,
904                    to: DataType::Date,
905                    trailing_comments: c.trailing_comments,
906                    double_colon_syntax: false,
907                    format: None,
908                    default: None,
909                    inferred_type: None,
910                })))
911            }
912            // TIME 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIME)
913            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Time(_)) => {
914                let Literal::Time(value) = lit.as_ref() else {
915                    unreachable!()
916                };
917                let inner_cast = Expression::Cast(Box::new(Cast {
918                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
919                    to: c.to,
920                    trailing_comments: Vec::new(),
921                    double_colon_syntax: false,
922                    format: None,
923                    default: None,
924                    inferred_type: None,
925                }));
926                Ok(Expression::Cast(Box::new(Cast {
927                    this: inner_cast,
928                    to: DataType::Time {
929                        precision: None,
930                        timezone: false,
931                    },
932                    trailing_comments: c.trailing_comments,
933                    double_colon_syntax: false,
934                    format: None,
935                    default: None,
936                    inferred_type: None,
937                })))
938            }
939            // For all other cases, pass through the Cast unchanged
940            _ => Ok(Expression::Cast(Box::new(c))),
941        }
942    }
943
944    /// Check if an expression is a CAST to TIMESTAMP
945    fn is_cast_to_timestamp(&self, expr: &Expression) -> bool {
946        if let Expression::Cast(cast) = expr {
947            matches!(cast.to, DataType::Timestamp { .. })
948        } else {
949            false
950        }
951    }
952
953    /// Helper to uppercase the first argument if it's an identifier or column (for DATEDIFF, DATEADD units)
954    fn uppercase_first_arg_if_identifier(&self, mut args: Vec<Expression>) -> Vec<Expression> {
955        use crate::expressions::Identifier;
956        if !args.is_empty() {
957            match &args[0] {
958                Expression::Identifier(id) => {
959                    args[0] = Expression::Identifier(Identifier {
960                        name: id.name.to_uppercase(),
961                        quoted: id.quoted,
962                        trailing_comments: id.trailing_comments.clone(),
963                        span: None,
964                    });
965                }
966                Expression::Var(v) => {
967                    args[0] = Expression::Identifier(Identifier {
968                        name: v.this.to_uppercase(),
969                        quoted: false,
970                        trailing_comments: Vec::new(),
971                        span: None,
972                    });
973                }
974                Expression::Column(col) if col.table.is_none() => {
975                    // Unqualified column name like "day" should be treated as a unit
976                    args[0] = Expression::Identifier(Identifier {
977                        name: col.name.name.to_uppercase(),
978                        quoted: col.name.quoted,
979                        trailing_comments: col.name.trailing_comments.clone(),
980                        span: None,
981                    });
982                }
983                _ => {}
984            }
985        }
986        args
987    }
988
989    fn normalize_get_json_object_path(&self, path_arg: Expression) -> Expression {
990        let Expression::Literal(lit) = &path_arg else {
991            return path_arg;
992        };
993        let crate::expressions::Literal::String(path) = lit.as_ref() else {
994            return path_arg;
995        };
996
997        let Some(segment) = path.strip_prefix("$.") else {
998            return path_arg;
999        };
1000
1001        if segment.is_empty()
1002            || segment.contains('.')
1003            || segment.contains('[')
1004            || segment
1005                .chars()
1006                .all(|c| c.is_ascii_alphanumeric() || c == '_')
1007        {
1008            return path_arg;
1009        }
1010
1011        Expression::Literal(Box::new(crate::expressions::Literal::String(format!(
1012            "$[\"{}\"]",
1013            segment.replace('"', "\\\"")
1014        ))))
1015    }
1016
1017    fn transform_date_add(f: DateAddFunc) -> Expression {
1018        if f.unit == IntervalUnit::Day {
1019            Expression::Function(Box::new(Function::new(
1020                "DATE_ADD".to_string(),
1021                vec![f.this, f.interval],
1022            )))
1023        } else {
1024            Expression::Function(Box::new(Function::new(
1025                "DATE_ADD".to_string(),
1026                vec![
1027                    Expression::Identifier(crate::expressions::Identifier {
1028                        name: Self::interval_unit_name(f.unit).to_string(),
1029                        quoted: false,
1030                        trailing_comments: Vec::new(),
1031                        span: None,
1032                    }),
1033                    f.interval,
1034                    f.this,
1035                ],
1036            )))
1037        }
1038    }
1039
1040    fn interval_unit_name(unit: IntervalUnit) -> &'static str {
1041        match unit {
1042            IntervalUnit::Year => "YEAR",
1043            IntervalUnit::Quarter => "QUARTER",
1044            IntervalUnit::Month => "MONTH",
1045            IntervalUnit::Week => "WEEK",
1046            IntervalUnit::Day => "DAY",
1047            IntervalUnit::Hour => "HOUR",
1048            IntervalUnit::Minute => "MINUTE",
1049            IntervalUnit::Second => "SECOND",
1050            IntervalUnit::Millisecond => "MILLISECOND",
1051            IntervalUnit::Microsecond => "MICROSECOND",
1052            IntervalUnit::Nanosecond => "NANOSECOND",
1053        }
1054    }
1055}
1056
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dialect;

    /// Parses `$sql` with the Databricks dialect, applies the dialect transform to the
    /// first statement, generates SQL again and asserts it equals `$expected`.
    macro_rules! assert_transforms_to {
        ($sql:expr, $expected:expr, $message:literal) => {{
            let dialect = Dialect::get(DialectType::Databricks);
            let statements = dialect.parse($sql).expect("Parse failed");
            let rewritten = dialect
                .transform(statements[0].clone())
                .expect("Transform failed");
            let rendered = dialect.generate(&rewritten).expect("Generate failed");
            assert_eq!(rendered, $expected, $message);
        }};
    }

    /// Parses each statement with the Databricks dialect and asserts that generating
    /// the first parsed statement reproduces the input text exactly (no transform step).
    macro_rules! assert_roundtrips {
        ($($sql:expr),+ $(,)?) => {{
            let dialect = Dialect::get(DialectType::Databricks);
            for sql in [$($sql),+] {
                let statements = dialect.parse(sql).expect("Parse failed");
                let rendered = dialect.generate(&statements[0]).expect("Generate failed");
                assert_eq!(rendered, sql);
            }
        }};
    }

    #[test]
    fn test_timestamp_literal_cast() {
        // TIMESTAMP 'value'::DATE -> CAST(CAST('value' AS DATE) AS TIMESTAMP)
        // This is test [47] in the Databricks dialect identity fixtures
        assert_transforms_to!(
            "SELECT TIMESTAMP '2025-04-29 18.47.18'::DATE",
            "SELECT CAST(CAST('2025-04-29 18.47.18' AS DATE) AS TIMESTAMP)",
            "Timestamp literal cast transformation failed"
        );
    }

    #[test]
    fn test_from_utc_timestamp_wraps_column() {
        // Test [48]: FROM_UTC_TIMESTAMP(foo, 'timezone') -> FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'timezone')
        assert_transforms_to!(
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(foo, 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "FROM_UTC_TIMESTAMP transformation failed"
        );
    }

    #[test]
    fn test_from_utc_timestamp_keeps_existing_cast() {
        // Test [50]: FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz) -> FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)
        // When already cast to TIMESTAMP, keep it but convert :: syntax to CAST()
        assert_transforms_to!(
            "FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz)",
            "FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)",
            "FROM_UTC_TIMESTAMP with existing CAST failed"
        );
    }

    #[test]
    fn test_deep_clone_version_as_of() {
        assert_roundtrips!("CREATE TABLE events_clone DEEP CLONE events VERSION AS OF 5");
    }

    #[test]
    fn test_deep_clone_timestamp_as_of() {
        assert_roundtrips!(
            "CREATE TABLE events_clone DEEP CLONE events TIMESTAMP AS OF '2024-01-01'"
        );
    }

    #[test]
    fn test_shallow_clone_still_roundtrips() {
        assert_roundtrips!("CREATE TABLE events_clone SHALLOW CLONE events");
    }

    #[test]
    fn test_repair_table_commands_roundtrip() {
        assert_roundtrips!(
            "REPAIR TABLE events",
            "MSCK REPAIR TABLE events",
            "REPAIR TABLE events ADD PARTITIONS",
            "REPAIR TABLE events DROP PARTITIONS",
            "REPAIR TABLE events SYNC PARTITIONS",
            "REPAIR TABLE events SYNC METADATA",
        );
    }

    #[test]
    fn test_apply_changes_commands_roundtrip() {
        assert_roundtrips!(
            "APPLY CHANGES INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) IGNORE NULL UPDATES SEQUENCE BY ts",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) APPLY AS DELETE WHEN operation = 'DELETE' SEQUENCE BY ts COLUMNS * EXCEPT (operation) STORED AS SCD TYPE 1",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) SEQUENCE BY ts STORED AS SCD TYPE 2 TRACK HISTORY ON * EXCEPT (updated_at)",
            "AUTO CDC INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
            "CREATE FLOW apply_cdc AS AUTO CDC INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
        );
    }

    #[test]
    fn test_generate_symlink_format_manifest_roundtrip() {
        assert_roundtrips!(
            "GENERATE symlink_format_manifest FOR TABLE events",
            "GENERATE symlink_format_manifest FOR TABLE catalog.schema.events",
        );
    }

    #[test]
    fn test_convert_to_delta_roundtrip() {
        assert_roundtrips!(
            "CONVERT TO DELTA parquet.`/mnt/data/events`",
            "CONVERT TO DELTA database_name.table_name",
            "CONVERT TO DELTA parquet.`s3://my-bucket/path/to/table` PARTITIONED BY (date DATE)",
            "CONVERT TO DELTA database_name.table_name NO STATISTICS",
        );
    }
}