Skip to main content

polyglot_sql/dialects/
databricks.rs

1//! Databricks Dialect
2//!
3//! Databricks-specific transformations based on sqlglot patterns.
4//! Databricks extends Spark SQL with additional features:
5//! - Colon operator for JSON extraction (col:path)
6//! - DATEADD/DATEDIFF with specific syntax
7//! - NULL type mapped to VOID
8//! - Native REGEXP_LIKE and TRY_CAST support
9
10use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{
13    AggregateFunction, Cast, DataType, Expression, Function, JSONExtract, Literal, UnaryFunc,
14    VarArgFunc,
15};
16use crate::generator::GeneratorConfig;
17use crate::tokens::TokenizerConfig;
18
/// Databricks SQL dialect.
///
/// Extends Spark SQL with Databricks specifics (see module docs): backtick
/// identifiers, the colon operator for JSON extraction, DATEADD/DATEDIFF
/// forms, and native TRY_CAST / REGEXP_LIKE support.
pub struct DatabricksDialect;
21
22impl DialectImpl for DatabricksDialect {
23    fn dialect_type(&self) -> DialectType {
24        DialectType::Databricks
25    }
26
27    fn tokenizer_config(&self) -> TokenizerConfig {
28        let mut config = TokenizerConfig::default();
29        // Databricks uses backticks for identifiers (NOT double quotes)
30        config.identifiers.clear();
31        config.identifiers.insert('`', '`');
32        // Databricks (like Hive/Spark) uses double quotes as string delimiters
33        config.quotes.insert("\"".to_string(), "\"".to_string());
34        // Databricks uses backslash escapes in strings (inherited from Hive/Spark)
35        config.string_escapes.push('\\');
36        // Databricks supports DIV keyword for integer division
37        config
38            .keywords
39            .insert("DIV".to_string(), crate::tokens::TokenType::Div);
40        config
41            .keywords
42            .insert("REPAIR".to_string(), crate::tokens::TokenType::Command);
43        config
44            .keywords
45            .insert("MSCK".to_string(), crate::tokens::TokenType::Command);
46        // Databricks numeric literal suffixes (same as Hive/Spark)
47        config
48            .numeric_literals
49            .insert("L".to_string(), "BIGINT".to_string());
50        config
51            .numeric_literals
52            .insert("S".to_string(), "SMALLINT".to_string());
53        config
54            .numeric_literals
55            .insert("Y".to_string(), "TINYINT".to_string());
56        config
57            .numeric_literals
58            .insert("D".to_string(), "DOUBLE".to_string());
59        config
60            .numeric_literals
61            .insert("F".to_string(), "FLOAT".to_string());
62        config
63            .numeric_literals
64            .insert("BD".to_string(), "DECIMAL".to_string());
65        // Databricks allows identifiers to start with digits (like Hive/Spark)
66        config.identifiers_can_start_with_digit = true;
67        // Databricks (like Spark): STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
68        // Backslashes in raw strings are always literal (no escape processing)
69        config.string_escapes_allowed_in_raw_strings = false;
70        config
71    }
72
73    fn generator_config(&self) -> GeneratorConfig {
74        use crate::generator::IdentifierQuoteStyle;
75        GeneratorConfig {
76            identifier_quote: '`',
77            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
78            dialect: Some(DialectType::Databricks),
79            struct_field_sep: ": ",
80            create_function_return_as: false,
81            tablesample_seed_keyword: "REPEATABLE",
82            identifiers_can_start_with_digit: true,
83            // Databricks uses COMMENT 'value' without = sign
84            schema_comment_with_eq: false,
85            ..Default::default()
86        }
87    }
88
89    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
90        match expr {
91            // IFNULL -> COALESCE in Databricks
92            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
93                original_name: None,
94                expressions: vec![f.this, f.expression],
95                inferred_type: None,
96            }))),
97
98            // NVL -> COALESCE in Databricks
99            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
100                original_name: None,
101                expressions: vec![f.this, f.expression],
102                inferred_type: None,
103            }))),
104
105            // TryCast is native in Databricks
106            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
107
108            // SafeCast -> TRY_CAST in Databricks
109            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
110
111            // ILIKE is native in Databricks (Spark 3+)
112            Expression::ILike(op) => Ok(Expression::ILike(op)),
113
114            // UNNEST -> EXPLODE in Databricks
115            Expression::Unnest(f) => Ok(Expression::Explode(Box::new(UnaryFunc::new(f.this)))),
116
117            // EXPLODE is native to Databricks
118            Expression::Explode(f) => Ok(Expression::Explode(f)),
119
120            // ExplodeOuter is supported
121            Expression::ExplodeOuter(f) => Ok(Expression::ExplodeOuter(f)),
122
123            // RANDOM -> RAND in Databricks
124            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
125                seed: None,
126                lower: None,
127                upper: None,
128            }))),
129
130            // Rand is native
131            Expression::Rand(r) => Ok(Expression::Rand(r)),
132
133            // || (Concat) -> CONCAT in Databricks
134            Expression::Concat(op) => Ok(Expression::Function(Box::new(Function::new(
135                "CONCAT".to_string(),
136                vec![op.left, op.right],
137            )))),
138
139            // RegexpLike is native in Databricks
140            Expression::RegexpLike(op) => Ok(Expression::RegexpLike(op)),
141
142            // Cast with typed literal: TIMESTAMP 'x'::TYPE -> CAST(CAST('x' AS TYPE) AS TIMESTAMP)
143            // This is a complex sqlglot transformation where:
144            // 1. The inner typed literal (e.g., TIMESTAMP 'x') becomes CAST('x' AS <target_type>)
145            // 2. The outer result is wrapped in CAST(... AS <original_literal_type>)
146            Expression::Cast(c) => self.transform_cast(*c),
147
148            // Generic function transformations
149            Expression::Function(f) => self.transform_function(*f),
150
151            // Generic aggregate function transformations
152            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
153
154            // DateSub -> DATE_ADD(date, -val) in Databricks
155            Expression::DateSub(f) => {
156                // Convert string literals to numbers (interval values are often stored as strings)
157                let val = match f.interval {
158                    Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::String(s) if s.parse::<i64>().is_ok()) =>
159                    {
160                        let crate::expressions::Literal::String(s) = lit.as_ref() else {
161                            unreachable!()
162                        };
163                        Expression::Literal(Box::new(crate::expressions::Literal::Number(
164                            s.clone(),
165                        )))
166                    }
167                    other => other,
168                };
169                let neg_val = Expression::Neg(Box::new(crate::expressions::UnaryOp {
170                    this: val,
171                    inferred_type: None,
172                }));
173                Ok(Expression::Function(Box::new(Function::new(
174                    "DATE_ADD".to_string(),
175                    vec![f.this, neg_val],
176                ))))
177            }
178
179            // Pass through everything else
180            _ => Ok(expr),
181        }
182    }
183}
184
185impl DatabricksDialect {
186    fn transform_function(&self, f: Function) -> Result<Expression> {
187        let name_upper = f.name.to_uppercase();
188        match name_upper.as_str() {
189            // IFNULL -> COALESCE
190            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
191                original_name: None,
192                expressions: f.args,
193                inferred_type: None,
194            }))),
195
196            // NVL -> COALESCE
197            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
198                original_name: None,
199                expressions: f.args,
200                inferred_type: None,
201            }))),
202
203            // ISNULL -> COALESCE
204            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
205                original_name: None,
206                expressions: f.args,
207                inferred_type: None,
208            }))),
209
210            // ROW -> STRUCT (no auto-naming for cross-dialect conversion)
211            "ROW" => Ok(Expression::Function(Box::new(Function::new(
212                "STRUCT".to_string(),
213                f.args,
214            )))),
215
216            // GETDATE -> CURRENT_TIMESTAMP
217            "GETDATE" => Ok(Expression::CurrentTimestamp(
218                crate::expressions::CurrentTimestamp {
219                    precision: None,
220                    sysdate: false,
221                },
222            )),
223
224            // NOW -> CURRENT_TIMESTAMP
225            "NOW" => Ok(Expression::CurrentTimestamp(
226                crate::expressions::CurrentTimestamp {
227                    precision: None,
228                    sysdate: false,
229                },
230            )),
231
232            // CURDATE -> CURRENT_DATE
233            "CURDATE" => Ok(Expression::CurrentDate(crate::expressions::CurrentDate)),
234
235            // CURRENT_DATE() with parens -> CURRENT_DATE (no parens)
236            "CURRENT_DATE" if f.args.is_empty() => {
237                Ok(Expression::CurrentDate(crate::expressions::CurrentDate))
238            }
239
240            // RANDOM -> RAND
241            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
242                seed: None,
243                lower: None,
244                upper: None,
245            }))),
246
247            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
248            "GROUP_CONCAT" if !f.args.is_empty() => {
249                let mut args = f.args;
250                let first = args.remove(0);
251                let separator = args.pop();
252                let collect_list = Expression::Function(Box::new(Function::new(
253                    "COLLECT_LIST".to_string(),
254                    vec![first],
255                )));
256                if let Some(sep) = separator {
257                    Ok(Expression::Function(Box::new(Function::new(
258                        "ARRAY_JOIN".to_string(),
259                        vec![collect_list, sep],
260                    ))))
261                } else {
262                    Ok(Expression::Function(Box::new(Function::new(
263                        "ARRAY_JOIN".to_string(),
264                        vec![collect_list],
265                    ))))
266                }
267            }
268
269            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN in Databricks
270            "STRING_AGG" if !f.args.is_empty() => {
271                let mut args = f.args;
272                let first = args.remove(0);
273                let separator = args.pop();
274                let collect_list = Expression::Function(Box::new(Function::new(
275                    "COLLECT_LIST".to_string(),
276                    vec![first],
277                )));
278                if let Some(sep) = separator {
279                    Ok(Expression::Function(Box::new(Function::new(
280                        "ARRAY_JOIN".to_string(),
281                        vec![collect_list, sep],
282                    ))))
283                } else {
284                    Ok(Expression::Function(Box::new(Function::new(
285                        "ARRAY_JOIN".to_string(),
286                        vec![collect_list],
287                    ))))
288                }
289            }
290
291            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
292            "LISTAGG" if !f.args.is_empty() => {
293                let mut args = f.args;
294                let first = args.remove(0);
295                let separator = args.pop();
296                let collect_list = Expression::Function(Box::new(Function::new(
297                    "COLLECT_LIST".to_string(),
298                    vec![first],
299                )));
300                if let Some(sep) = separator {
301                    Ok(Expression::Function(Box::new(Function::new(
302                        "ARRAY_JOIN".to_string(),
303                        vec![collect_list, sep],
304                    ))))
305                } else {
306                    Ok(Expression::Function(Box::new(Function::new(
307                        "ARRAY_JOIN".to_string(),
308                        vec![collect_list],
309                    ))))
310                }
311            }
312
313            // ARRAY_AGG -> COLLECT_LIST in Databricks
314            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
315                "COLLECT_LIST".to_string(),
316                f.args,
317            )))),
318
319            // SUBSTR -> SUBSTRING
320            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
321                "SUBSTRING".to_string(),
322                f.args,
323            )))),
324
325            // LEN -> LENGTH
326            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
327                f.args.into_iter().next().unwrap(),
328            )))),
329
330            // CHARINDEX -> LOCATE (with swapped args, like Spark)
331            "CHARINDEX" if f.args.len() >= 2 => {
332                let mut args = f.args;
333                let substring = args.remove(0);
334                let string = args.remove(0);
335                // LOCATE(substring, string)
336                Ok(Expression::Function(Box::new(Function::new(
337                    "LOCATE".to_string(),
338                    vec![substring, string],
339                ))))
340            }
341
342            // POSITION -> LOCATE
343            "POSITION" if f.args.len() == 2 => {
344                let args = f.args;
345                Ok(Expression::Function(Box::new(Function::new(
346                    "LOCATE".to_string(),
347                    args,
348                ))))
349            }
350
351            // STRPOS -> LOCATE (with same arg order)
352            "STRPOS" if f.args.len() == 2 => {
353                let args = f.args;
354                let string = args[0].clone();
355                let substring = args[1].clone();
356                // STRPOS(string, substring) -> LOCATE(substring, string)
357                Ok(Expression::Function(Box::new(Function::new(
358                    "LOCATE".to_string(),
359                    vec![substring, string],
360                ))))
361            }
362
363            // INSTR is native in Databricks
364            "INSTR" => Ok(Expression::Function(Box::new(f))),
365
366            // LOCATE is native in Databricks
367            "LOCATE" => Ok(Expression::Function(Box::new(f))),
368
369            // ARRAY_LENGTH -> SIZE
370            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
371                Function::new("SIZE".to_string(), f.args),
372            ))),
373
374            // CARDINALITY -> SIZE
375            "CARDINALITY" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
376                Function::new("SIZE".to_string(), f.args),
377            ))),
378
379            // SIZE is native
380            "SIZE" => Ok(Expression::Function(Box::new(f))),
381
382            // ARRAY_CONTAINS is native in Databricks
383            "ARRAY_CONTAINS" => Ok(Expression::Function(Box::new(f))),
384
385            // CONTAINS -> ARRAY_CONTAINS in Databricks (for array operations)
386            // But keep CONTAINS for string contains (from CONTAINS_SUBSTR transpilation)
387            "CONTAINS" if f.args.len() == 2 => {
388                // Check if this is a string CONTAINS (LOWER() args pattern from CONTAINS_SUBSTR)
389                let is_string_contains = matches!(&f.args[0], Expression::Lower(_))
390                    && matches!(&f.args[1], Expression::Lower(_));
391                if is_string_contains {
392                    Ok(Expression::Function(Box::new(f)))
393                } else {
394                    Ok(Expression::Function(Box::new(Function::new(
395                        "ARRAY_CONTAINS".to_string(),
396                        f.args,
397                    ))))
398                }
399            }
400
401            // TO_DATE is native in Databricks
402            "TO_DATE" => Ok(Expression::Function(Box::new(f))),
403
404            // TO_TIMESTAMP is native in Databricks
405            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
406
407            // DATE_FORMAT is native in Databricks
408            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
409
410            // strftime -> DATE_FORMAT in Databricks
411            "STRFTIME" if f.args.len() >= 2 => {
412                let mut args = f.args;
413                let format = args.remove(0);
414                let date = args.remove(0);
415                Ok(Expression::Function(Box::new(Function::new(
416                    "DATE_FORMAT".to_string(),
417                    vec![date, format],
418                ))))
419            }
420
421            // TO_CHAR is supported natively in Databricks (unlike Spark)
422            "TO_CHAR" => Ok(Expression::Function(Box::new(f))),
423
424            // DATE_TRUNC is native in Databricks
425            "DATE_TRUNC" => Ok(Expression::Function(Box::new(f))),
426
427            // DATEADD is native in Databricks - uppercase the unit if present
428            "DATEADD" => {
429                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
430                Ok(Expression::Function(Box::new(Function::new(
431                    "DATEADD".to_string(),
432                    transformed_args,
433                ))))
434            }
435
436            // DATE_ADD -> DATEADD in Databricks (2-arg form only)
437            // 2-arg with interval: DATE_ADD(date, interval) -> DATEADD(DAY, interval, date)
438            // 2-arg with number: DATE_ADD(date, -2) -> keep as DATE_ADD(date, -2)
439            // 3-arg: DATE_ADD(unit, amount, date) -> keep as DATE_ADD(UNIT, amount, date)
440            "DATE_ADD" => {
441                if f.args.len() == 2 {
442                    let is_simple_number = matches!(
443                        &f.args[1],
444                        Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::Number(_))
445                    ) || matches!(&f.args[1], Expression::Neg(_));
446                    if is_simple_number {
447                        // Keep as DATE_ADD(date, num_days)
448                        Ok(Expression::Function(Box::new(Function::new(
449                            "DATE_ADD".to_string(),
450                            f.args,
451                        ))))
452                    } else {
453                        let mut args = f.args;
454                        let date = args.remove(0);
455                        let interval = args.remove(0);
456                        let unit = Expression::Identifier(crate::expressions::Identifier {
457                            name: "DAY".to_string(),
458                            quoted: false,
459                            trailing_comments: Vec::new(),
460                            span: None,
461                        });
462                        Ok(Expression::Function(Box::new(Function::new(
463                            "DATEADD".to_string(),
464                            vec![unit, interval, date],
465                        ))))
466                    }
467                } else {
468                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
469                    Ok(Expression::Function(Box::new(Function::new(
470                        "DATE_ADD".to_string(),
471                        transformed_args,
472                    ))))
473                }
474            }
475
476            // DATEDIFF is native in Databricks - uppercase the unit if present
477            // 2-arg: DATEDIFF(end, start) -> DATEDIFF(DAY, start, end)
478            // 3-arg: DATEDIFF(unit, start, end) -> DATEDIFF(UNIT, start, end)
479            "DATEDIFF" => {
480                if f.args.len() == 2 {
481                    let mut args = f.args;
482                    let end_date = args.remove(0);
483                    let start_date = args.remove(0);
484                    let unit = Expression::Identifier(crate::expressions::Identifier {
485                        name: "DAY".to_string(),
486                        quoted: false,
487                        trailing_comments: Vec::new(),
488                        span: None,
489                    });
490                    Ok(Expression::Function(Box::new(Function::new(
491                        "DATEDIFF".to_string(),
492                        vec![unit, start_date, end_date],
493                    ))))
494                } else {
495                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
496                    Ok(Expression::Function(Box::new(Function::new(
497                        "DATEDIFF".to_string(),
498                        transformed_args,
499                    ))))
500                }
501            }
502
503            // DATE_DIFF -> DATEDIFF with uppercased unit
504            "DATE_DIFF" => {
505                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
506                Ok(Expression::Function(Box::new(Function::new(
507                    "DATEDIFF".to_string(),
508                    transformed_args,
509                ))))
510            }
511
512            // JSON_EXTRACT -> Use colon operator in generation, but keep as function for now
513            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
514
515            // JSON_EXTRACT_SCALAR -> same handling
516            "JSON_EXTRACT_SCALAR" => Ok(Expression::Function(Box::new(f))),
517
518            // GET_JSON_OBJECT -> colon syntax in Databricks
519            // GET_JSON_OBJECT(col, '$.path') becomes col:path
520            "GET_JSON_OBJECT" if f.args.len() == 2 => {
521                let mut args = f.args;
522                let col = args.remove(0);
523                let path_arg = args.remove(0);
524
525                // Extract and strip the $. prefix from the path
526                let path_expr = match &path_arg {
527                    Expression::Literal(lit)
528                        if matches!(lit.as_ref(), crate::expressions::Literal::String(_)) =>
529                    {
530                        let crate::expressions::Literal::String(s) = lit.as_ref() else {
531                            unreachable!()
532                        };
533                        // Strip leading '$.' if present
534                        let stripped = if s.starts_with("$.") {
535                            &s[2..]
536                        } else if s.starts_with("$") {
537                            &s[1..]
538                        } else {
539                            s.as_str()
540                        };
541                        Expression::Literal(Box::new(crate::expressions::Literal::String(
542                            stripped.to_string(),
543                        )))
544                    }
545                    _ => path_arg,
546                };
547
548                Ok(Expression::JSONExtract(Box::new(JSONExtract {
549                    this: Box::new(col),
550                    expression: Box::new(path_expr),
551                    only_json_types: None,
552                    expressions: Vec::new(),
553                    variant_extract: Some(Box::new(Expression::true_())),
554                    json_query: None,
555                    option: None,
556                    quote: None,
557                    on_condition: None,
558                    requires_json: None,
559                })))
560            }
561
562            // FROM_JSON is native in Databricks
563            "FROM_JSON" => Ok(Expression::Function(Box::new(f))),
564
565            // PARSE_JSON is native in Databricks
566            "PARSE_JSON" => Ok(Expression::Function(Box::new(f))),
567
568            // COLLECT_LIST is native in Databricks
569            "COLLECT_LIST" => Ok(Expression::Function(Box::new(f))),
570
571            // COLLECT_SET is native in Databricks
572            "COLLECT_SET" => Ok(Expression::Function(Box::new(f))),
573
574            // RLIKE is native in Databricks
575            "RLIKE" => Ok(Expression::Function(Box::new(f))),
576
577            // REGEXP -> RLIKE in Databricks
578            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
579                "RLIKE".to_string(),
580                f.args,
581            )))),
582
583            // REGEXP_LIKE is native in Databricks
584            "REGEXP_LIKE" => Ok(Expression::Function(Box::new(f))),
585
586            // LEVENSHTEIN is native in Databricks
587            "LEVENSHTEIN" => Ok(Expression::Function(Box::new(f))),
588
589            // SEQUENCE is native (for GENERATE_SERIES)
590            "GENERATE_SERIES" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
591                Function::new("SEQUENCE".to_string(), f.args),
592            ))),
593
594            // SEQUENCE is native
595            "SEQUENCE" => Ok(Expression::Function(Box::new(f))),
596
597            // FLATTEN is native in Databricks
598            "FLATTEN" => Ok(Expression::Function(Box::new(f))),
599
600            // ARRAY_SORT is native
601            "ARRAY_SORT" => Ok(Expression::Function(Box::new(f))),
602
603            // ARRAY_DISTINCT is native
604            "ARRAY_DISTINCT" => Ok(Expression::Function(Box::new(f))),
605
606            // TRANSFORM is native (for array transformation)
607            "TRANSFORM" => Ok(Expression::Function(Box::new(f))),
608
609            // FILTER is native (for array filtering)
610            "FILTER" => Ok(Expression::Function(Box::new(f))),
611
612            // FROM_UTC_TIMESTAMP - wrap first argument in CAST(... AS TIMESTAMP) if not already
613            "FROM_UTC_TIMESTAMP" if f.args.len() >= 2 => {
614                let mut args = f.args;
615                let first_arg = args.remove(0);
616
617                // Check if first arg is already a Cast to TIMESTAMP
618                let wrapped_arg = if self.is_cast_to_timestamp(&first_arg) {
619                    first_arg
620                } else {
621                    // Wrap in CAST(... AS TIMESTAMP)
622                    Expression::Cast(Box::new(Cast {
623                        this: first_arg,
624                        to: DataType::Timestamp {
625                            precision: None,
626                            timezone: false,
627                        },
628                        trailing_comments: Vec::new(),
629                        double_colon_syntax: false,
630                        format: None,
631                        default: None,
632                        inferred_type: None,
633                    }))
634                };
635
636                let mut new_args = vec![wrapped_arg];
637                new_args.extend(args);
638
639                Ok(Expression::Function(Box::new(Function::new(
640                    "FROM_UTC_TIMESTAMP".to_string(),
641                    new_args,
642                ))))
643            }
644
645            // UNIFORM(low, high, RANDOM(seed)) -> UNIFORM(low, high, seed) or UNIFORM(low, high)
646            "UNIFORM" if f.args.len() == 3 => {
647                let mut args = f.args;
648                let low = args.remove(0);
649                let high = args.remove(0);
650                let gen = args.remove(0);
651                match gen {
652                    Expression::Function(func) if func.name.to_uppercase() == "RANDOM" => {
653                        if func.args.len() == 1 {
654                            // RANDOM(seed) -> extract seed
655                            let seed = func.args.into_iter().next().unwrap();
656                            Ok(Expression::Function(Box::new(Function::new(
657                                "UNIFORM".to_string(),
658                                vec![low, high, seed],
659                            ))))
660                        } else {
661                            // RANDOM() -> drop gen arg
662                            Ok(Expression::Function(Box::new(Function::new(
663                                "UNIFORM".to_string(),
664                                vec![low, high],
665                            ))))
666                        }
667                    }
668                    Expression::Rand(r) => {
669                        if let Some(seed) = r.seed {
670                            Ok(Expression::Function(Box::new(Function::new(
671                                "UNIFORM".to_string(),
672                                vec![low, high, *seed],
673                            ))))
674                        } else {
675                            Ok(Expression::Function(Box::new(Function::new(
676                                "UNIFORM".to_string(),
677                                vec![low, high],
678                            ))))
679                        }
680                    }
681                    _ => Ok(Expression::Function(Box::new(Function::new(
682                        "UNIFORM".to_string(),
683                        vec![low, high, gen],
684                    )))),
685                }
686            }
687
688            // REGEXP_SUBSTR(subject, pattern, ...) -> REGEXP_EXTRACT(subject, pattern)
689            "REGEXP_SUBSTR" if f.args.len() >= 2 => {
690                let subject = f.args[0].clone();
691                let pattern = f.args[1].clone();
692                Ok(Expression::Function(Box::new(Function::new(
693                    "REGEXP_EXTRACT".to_string(),
694                    vec![subject, pattern],
695                ))))
696            }
697
698            // BIT_GET -> GETBIT
699            "BIT_GET" => Ok(Expression::Function(Box::new(Function::new(
700                "GETBIT".to_string(),
701                f.args,
702            )))),
703
704            // Pass through everything else
705            _ => Ok(Expression::Function(Box::new(f))),
706        }
707    }
708
709    fn transform_aggregate_function(
710        &self,
711        f: Box<crate::expressions::AggregateFunction>,
712    ) -> Result<Expression> {
713        let name_upper = f.name.to_uppercase();
714        match name_upper.as_str() {
715            // COUNT_IF is native in Databricks (Spark 3+)
716            "COUNT_IF" => Ok(Expression::AggregateFunction(f)),
717
718            // ANY_VALUE is native in Databricks (Spark 3+)
719            "ANY_VALUE" => Ok(Expression::AggregateFunction(f)),
720
721            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
722            "GROUP_CONCAT" if !f.args.is_empty() => {
723                let mut args = f.args;
724                let first = args.remove(0);
725                let separator = args.pop();
726                let collect_list = Expression::Function(Box::new(Function::new(
727                    "COLLECT_LIST".to_string(),
728                    vec![first],
729                )));
730                if let Some(sep) = separator {
731                    Ok(Expression::Function(Box::new(Function::new(
732                        "ARRAY_JOIN".to_string(),
733                        vec![collect_list, sep],
734                    ))))
735                } else {
736                    Ok(Expression::Function(Box::new(Function::new(
737                        "ARRAY_JOIN".to_string(),
738                        vec![collect_list],
739                    ))))
740                }
741            }
742
743            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN
744            "STRING_AGG" if !f.args.is_empty() => {
745                let mut args = f.args;
746                let first = args.remove(0);
747                let separator = args.pop();
748                let collect_list = Expression::Function(Box::new(Function::new(
749                    "COLLECT_LIST".to_string(),
750                    vec![first],
751                )));
752                if let Some(sep) = separator {
753                    Ok(Expression::Function(Box::new(Function::new(
754                        "ARRAY_JOIN".to_string(),
755                        vec![collect_list, sep],
756                    ))))
757                } else {
758                    Ok(Expression::Function(Box::new(Function::new(
759                        "ARRAY_JOIN".to_string(),
760                        vec![collect_list],
761                    ))))
762                }
763            }
764
765            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
766            "LISTAGG" if !f.args.is_empty() => {
767                let mut args = f.args;
768                let first = args.remove(0);
769                let separator = args.pop();
770                let collect_list = Expression::Function(Box::new(Function::new(
771                    "COLLECT_LIST".to_string(),
772                    vec![first],
773                )));
774                if let Some(sep) = separator {
775                    Ok(Expression::Function(Box::new(Function::new(
776                        "ARRAY_JOIN".to_string(),
777                        vec![collect_list, sep],
778                    ))))
779                } else {
780                    Ok(Expression::Function(Box::new(Function::new(
781                        "ARRAY_JOIN".to_string(),
782                        vec![collect_list],
783                    ))))
784                }
785            }
786
787            // ARRAY_AGG -> COLLECT_LIST
788            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
789                "COLLECT_LIST".to_string(),
790                f.args,
791            )))),
792
793            // STDDEV is native in Databricks
794            "STDDEV" => Ok(Expression::AggregateFunction(f)),
795
796            // VARIANCE is native in Databricks
797            "VARIANCE" => Ok(Expression::AggregateFunction(f)),
798
799            // APPROX_COUNT_DISTINCT is native in Databricks
800            "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
801
802            // APPROX_DISTINCT -> APPROX_COUNT_DISTINCT
803            "APPROX_DISTINCT" if !f.args.is_empty() => {
804                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
805                    name: "APPROX_COUNT_DISTINCT".to_string(),
806                    args: f.args,
807                    distinct: f.distinct,
808                    filter: f.filter,
809                    order_by: Vec::new(),
810                    limit: None,
811                    ignore_nulls: None,
812                    inferred_type: None,
813                })))
814            }
815
816            // Pass through everything else
817            _ => Ok(Expression::AggregateFunction(f)),
818        }
819    }
820
821    /// Transform Cast expressions - handles typed literals being cast
822    ///
823    /// When we have a typed literal (TIMESTAMP 'x', DATE 'x', TIME 'x') being cast to another type,
824    /// Databricks/Spark transforms it as follows:
825    ///
826    /// `TIMESTAMP 'x'::TYPE` -> `CAST(CAST('x' AS TYPE) AS TIMESTAMP)`
827    ///
828    /// This reverses the types - the inner cast is to the target type,
829    /// the outer cast is to the original literal type.
830    fn transform_cast(&self, c: Cast) -> Result<Expression> {
831        // Check if the inner expression is a typed literal
832        match &c.this {
833            // TIMESTAMP 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIMESTAMP)
834            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Timestamp(_)) => {
835                let Literal::Timestamp(value) = lit.as_ref() else {
836                    unreachable!()
837                };
838                // Create inner cast: CAST('value' AS target_type)
839                let inner_cast = Expression::Cast(Box::new(Cast {
840                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
841                    to: c.to,
842                    trailing_comments: Vec::new(),
843                    double_colon_syntax: false,
844                    format: None,
845                    default: None,
846                    inferred_type: None,
847                }));
848                // Create outer cast: CAST(inner_cast AS TIMESTAMP)
849                Ok(Expression::Cast(Box::new(Cast {
850                    this: inner_cast,
851                    to: DataType::Timestamp {
852                        precision: None,
853                        timezone: false,
854                    },
855                    trailing_comments: c.trailing_comments,
856                    double_colon_syntax: false,
857                    format: None,
858                    default: None,
859                    inferred_type: None,
860                })))
861            }
862            // DATE 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS DATE)
863            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Date(_)) => {
864                let Literal::Date(value) = lit.as_ref() else {
865                    unreachable!()
866                };
867                let inner_cast = Expression::Cast(Box::new(Cast {
868                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
869                    to: c.to,
870                    trailing_comments: Vec::new(),
871                    double_colon_syntax: false,
872                    format: None,
873                    default: None,
874                    inferred_type: None,
875                }));
876                Ok(Expression::Cast(Box::new(Cast {
877                    this: inner_cast,
878                    to: DataType::Date,
879                    trailing_comments: c.trailing_comments,
880                    double_colon_syntax: false,
881                    format: None,
882                    default: None,
883                    inferred_type: None,
884                })))
885            }
886            // TIME 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIME)
887            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Time(_)) => {
888                let Literal::Time(value) = lit.as_ref() else {
889                    unreachable!()
890                };
891                let inner_cast = Expression::Cast(Box::new(Cast {
892                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
893                    to: c.to,
894                    trailing_comments: Vec::new(),
895                    double_colon_syntax: false,
896                    format: None,
897                    default: None,
898                    inferred_type: None,
899                }));
900                Ok(Expression::Cast(Box::new(Cast {
901                    this: inner_cast,
902                    to: DataType::Time {
903                        precision: None,
904                        timezone: false,
905                    },
906                    trailing_comments: c.trailing_comments,
907                    double_colon_syntax: false,
908                    format: None,
909                    default: None,
910                    inferred_type: None,
911                })))
912            }
913            // For all other cases, pass through the Cast unchanged
914            _ => Ok(Expression::Cast(Box::new(c))),
915        }
916    }
917
918    /// Check if an expression is a CAST to TIMESTAMP
919    fn is_cast_to_timestamp(&self, expr: &Expression) -> bool {
920        if let Expression::Cast(cast) = expr {
921            matches!(cast.to, DataType::Timestamp { .. })
922        } else {
923            false
924        }
925    }
926
927    /// Helper to uppercase the first argument if it's an identifier or column (for DATEDIFF, DATEADD units)
928    fn uppercase_first_arg_if_identifier(&self, mut args: Vec<Expression>) -> Vec<Expression> {
929        use crate::expressions::Identifier;
930        if !args.is_empty() {
931            match &args[0] {
932                Expression::Identifier(id) => {
933                    args[0] = Expression::Identifier(Identifier {
934                        name: id.name.to_uppercase(),
935                        quoted: id.quoted,
936                        trailing_comments: id.trailing_comments.clone(),
937                        span: None,
938                    });
939                }
940                Expression::Var(v) => {
941                    args[0] = Expression::Identifier(Identifier {
942                        name: v.this.to_uppercase(),
943                        quoted: false,
944                        trailing_comments: Vec::new(),
945                        span: None,
946                    });
947                }
948                Expression::Column(col) if col.table.is_none() => {
949                    // Unqualified column name like "day" should be treated as a unit
950                    args[0] = Expression::Identifier(Identifier {
951                        name: col.name.name.to_uppercase(),
952                        quoted: col.name.quoted,
953                        trailing_comments: col.name.trailing_comments.clone(),
954                        span: None,
955                    });
956                }
957                _ => {}
958            }
959        }
960        args
961    }
962}
963
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dialect;

    /// Parse `sql` with the Databricks dialect, apply dialect transforms to
    /// the first statement, and return the regenerated SQL. Panics (failing
    /// the test) if any stage errors.
    fn transpile(sql: &str) -> String {
        let d = Dialect::get(DialectType::Databricks);
        let ast = d.parse(sql).expect("Parse failed");
        let transformed = d.transform(ast[0].clone()).expect("Transform failed");
        d.generate(&transformed).expect("Generate failed")
    }

    /// Parse `sql` (no transforms) and assert it regenerates byte-for-byte.
    fn assert_roundtrip(sql: &str) {
        let d = Dialect::get(DialectType::Databricks);
        let ast = d.parse(sql).expect("Parse failed");
        let output = d.generate(&ast[0]).expect("Generate failed");
        assert_eq!(output, sql, "Roundtrip failed for: {}", sql);
    }

    #[test]
    fn test_timestamp_literal_cast() {
        // TIMESTAMP 'value'::DATE -> CAST(CAST('value' AS DATE) AS TIMESTAMP)
        // This is test [47] in the Databricks dialect identity fixtures
        assert_eq!(
            transpile("SELECT TIMESTAMP '2025-04-29 18.47.18'::DATE"),
            "SELECT CAST(CAST('2025-04-29 18.47.18' AS DATE) AS TIMESTAMP)",
            "Timestamp literal cast transformation failed"
        );
    }

    #[test]
    fn test_from_utc_timestamp_wraps_column() {
        // Test [48]: FROM_UTC_TIMESTAMP(foo, 'timezone') ->
        // FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'timezone')
        assert_eq!(
            transpile("SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(foo, 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t"),
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "FROM_UTC_TIMESTAMP transformation failed"
        );
    }

    #[test]
    fn test_from_utc_timestamp_keeps_existing_cast() {
        // Test [50]: FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz) -> FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)
        // When already cast to TIMESTAMP, keep it but convert :: syntax to CAST()
        assert_eq!(
            transpile("FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz)"),
            "FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)",
            "FROM_UTC_TIMESTAMP with existing CAST failed"
        );
    }

    #[test]
    fn test_deep_clone_version_as_of() {
        assert_roundtrip("CREATE TABLE events_clone DEEP CLONE events VERSION AS OF 5");
    }

    #[test]
    fn test_deep_clone_timestamp_as_of() {
        assert_roundtrip("CREATE TABLE events_clone DEEP CLONE events TIMESTAMP AS OF '2024-01-01'");
    }

    #[test]
    fn test_shallow_clone_still_roundtrips() {
        assert_roundtrip("CREATE TABLE events_clone SHALLOW CLONE events");
    }

    #[test]
    fn test_repair_table_commands_roundtrip() {
        let cases = [
            "REPAIR TABLE events",
            "MSCK REPAIR TABLE events",
            "REPAIR TABLE events ADD PARTITIONS",
            "REPAIR TABLE events DROP PARTITIONS",
            "REPAIR TABLE events SYNC PARTITIONS",
            "REPAIR TABLE events SYNC METADATA",
        ];
        for sql in cases {
            assert_roundtrip(sql);
        }
    }

    #[test]
    fn test_apply_changes_commands_roundtrip() {
        let cases = [
            "APPLY CHANGES INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) IGNORE NULL UPDATES SEQUENCE BY ts",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) APPLY AS DELETE WHEN operation = 'DELETE' SEQUENCE BY ts COLUMNS * EXCEPT (operation) STORED AS SCD TYPE 1",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) SEQUENCE BY ts STORED AS SCD TYPE 2 TRACK HISTORY ON * EXCEPT (updated_at)",
            "AUTO CDC INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
            "CREATE FLOW apply_cdc AS AUTO CDC INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
        ];
        for sql in cases {
            assert_roundtrip(sql);
        }
    }

    #[test]
    fn test_generate_symlink_format_manifest_roundtrip() {
        let cases = [
            "GENERATE symlink_format_manifest FOR TABLE events",
            "GENERATE symlink_format_manifest FOR TABLE catalog.schema.events",
        ];
        for sql in cases {
            assert_roundtrip(sql);
        }
    }

    #[test]
    fn test_convert_to_delta_roundtrip() {
        let cases = [
            "CONVERT TO DELTA parquet.`/mnt/data/events`",
            "CONVERT TO DELTA database_name.table_name",
            "CONVERT TO DELTA parquet.`s3://my-bucket/path/to/table` PARTITIONED BY (date DATE)",
            "CONVERT TO DELTA database_name.table_name NO STATISTICS",
        ];
        for sql in cases {
            assert_roundtrip(sql);
        }
    }
}