Skip to main content

polyglot_sql/dialects/
databricks.rs

1//! Databricks Dialect
2//!
3//! Databricks-specific transformations based on sqlglot patterns.
4//! Databricks extends Spark SQL with additional features:
5//! - Colon operator for JSON extraction (col:path)
6//! - DATEADD/DATEDIFF with specific syntax
7//! - NULL type mapped to VOID
8//! - Native REGEXP_LIKE and TRY_CAST support
9
10use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{
13    AggregateFunction, Cast, DataType, Expression, Function, JSONExtract, Literal, UnaryFunc,
14    VarArgFunc,
15};
16use crate::generator::GeneratorConfig;
17use crate::tokens::TokenizerConfig;
18
19/// Databricks dialect
20pub struct DatabricksDialect;
21
22impl DialectImpl for DatabricksDialect {
23    fn dialect_type(&self) -> DialectType {
24        DialectType::Databricks
25    }
26
27    fn tokenizer_config(&self) -> TokenizerConfig {
28        let mut config = TokenizerConfig::default();
29        // Databricks uses backticks for identifiers (NOT double quotes)
30        config.identifiers.clear();
31        config.identifiers.insert('`', '`');
32        // Databricks (like Hive/Spark) uses double quotes as string delimiters
33        config.quotes.insert("\"".to_string(), "\"".to_string());
34        // Databricks uses backslash escapes in strings (inherited from Hive/Spark)
35        config.string_escapes.push('\\');
36        // Databricks supports DIV keyword for integer division
37        config
38            .keywords
39            .insert("DIV".to_string(), crate::tokens::TokenType::Div);
40        // Databricks numeric literal suffixes (same as Hive/Spark)
41        config
42            .numeric_literals
43            .insert("L".to_string(), "BIGINT".to_string());
44        config
45            .numeric_literals
46            .insert("S".to_string(), "SMALLINT".to_string());
47        config
48            .numeric_literals
49            .insert("Y".to_string(), "TINYINT".to_string());
50        config
51            .numeric_literals
52            .insert("D".to_string(), "DOUBLE".to_string());
53        config
54            .numeric_literals
55            .insert("F".to_string(), "FLOAT".to_string());
56        config
57            .numeric_literals
58            .insert("BD".to_string(), "DECIMAL".to_string());
59        // Databricks allows identifiers to start with digits (like Hive/Spark)
60        config.identifiers_can_start_with_digit = true;
61        // Databricks (like Spark): STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
62        // Backslashes in raw strings are always literal (no escape processing)
63        config.string_escapes_allowed_in_raw_strings = false;
64        config
65    }
66
67    fn generator_config(&self) -> GeneratorConfig {
68        use crate::generator::IdentifierQuoteStyle;
69        GeneratorConfig {
70            identifier_quote: '`',
71            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
72            dialect: Some(DialectType::Databricks),
73            struct_field_sep: ": ",
74            create_function_return_as: false,
75            tablesample_seed_keyword: "REPEATABLE",
76            identifiers_can_start_with_digit: true,
77            // Databricks uses COMMENT 'value' without = sign
78            schema_comment_with_eq: false,
79            ..Default::default()
80        }
81    }
82
83    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
84        match expr {
85            // IFNULL -> COALESCE in Databricks
86            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
87                original_name: None,
88                expressions: vec![f.this, f.expression],
89                inferred_type: None,
90            }))),
91
92            // NVL -> COALESCE in Databricks
93            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
94                original_name: None,
95                expressions: vec![f.this, f.expression],
96                inferred_type: None,
97            }))),
98
99            // TryCast is native in Databricks
100            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
101
102            // SafeCast -> TRY_CAST in Databricks
103            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
104
105            // ILIKE is native in Databricks (Spark 3+)
106            Expression::ILike(op) => Ok(Expression::ILike(op)),
107
108            // UNNEST -> EXPLODE in Databricks
109            Expression::Unnest(f) => Ok(Expression::Explode(Box::new(UnaryFunc::new(f.this)))),
110
111            // EXPLODE is native to Databricks
112            Expression::Explode(f) => Ok(Expression::Explode(f)),
113
114            // ExplodeOuter is supported
115            Expression::ExplodeOuter(f) => Ok(Expression::ExplodeOuter(f)),
116
117            // RANDOM -> RAND in Databricks
118            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
119                seed: None,
120                lower: None,
121                upper: None,
122            }))),
123
124            // Rand is native
125            Expression::Rand(r) => Ok(Expression::Rand(r)),
126
127            // || (Concat) -> CONCAT in Databricks
128            Expression::Concat(op) => Ok(Expression::Function(Box::new(Function::new(
129                "CONCAT".to_string(),
130                vec![op.left, op.right],
131            )))),
132
133            // RegexpLike is native in Databricks
134            Expression::RegexpLike(op) => Ok(Expression::RegexpLike(op)),
135
136            // Cast with typed literal: TIMESTAMP 'x'::TYPE -> CAST(CAST('x' AS TYPE) AS TIMESTAMP)
137            // This is a complex sqlglot transformation where:
138            // 1. The inner typed literal (e.g., TIMESTAMP 'x') becomes CAST('x' AS <target_type>)
139            // 2. The outer result is wrapped in CAST(... AS <original_literal_type>)
140            Expression::Cast(c) => self.transform_cast(*c),
141
142            // Generic function transformations
143            Expression::Function(f) => self.transform_function(*f),
144
145            // Generic aggregate function transformations
146            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
147
148            // DateSub -> DATE_ADD(date, -val) in Databricks
149            Expression::DateSub(f) => {
150                // Convert string literals to numbers (interval values are often stored as strings)
151                let val = match f.interval {
152                    Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::String(s) if s.parse::<i64>().is_ok())
153                        =>
154                    {
155                        let crate::expressions::Literal::String(s) = lit.as_ref() else { unreachable!() };
156                        Expression::Literal(Box::new(crate::expressions::Literal::Number(s.clone())))
157                    }
158                    other => other,
159                };
160                let neg_val = Expression::Neg(Box::new(crate::expressions::UnaryOp {
161                    this: val,
162                    inferred_type: None,
163                }));
164                Ok(Expression::Function(Box::new(Function::new(
165                    "DATE_ADD".to_string(),
166                    vec![f.this, neg_val],
167                ))))
168            }
169
170            // Pass through everything else
171            _ => Ok(expr),
172        }
173    }
174}
175
176impl DatabricksDialect {
177    fn transform_function(&self, f: Function) -> Result<Expression> {
178        let name_upper = f.name.to_uppercase();
179        match name_upper.as_str() {
180            // IFNULL -> COALESCE
181            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
182                original_name: None,
183                expressions: f.args,
184                inferred_type: None,
185            }))),
186
187            // NVL -> COALESCE
188            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
189                original_name: None,
190                expressions: f.args,
191                inferred_type: None,
192            }))),
193
194            // ISNULL -> COALESCE
195            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
196                original_name: None,
197                expressions: f.args,
198                inferred_type: None,
199            }))),
200
201            // ROW -> STRUCT (no auto-naming for cross-dialect conversion)
202            "ROW" => Ok(Expression::Function(Box::new(Function::new(
203                "STRUCT".to_string(),
204                f.args,
205            )))),
206
207            // GETDATE -> CURRENT_TIMESTAMP
208            "GETDATE" => Ok(Expression::CurrentTimestamp(
209                crate::expressions::CurrentTimestamp {
210                    precision: None,
211                    sysdate: false,
212                },
213            )),
214
215            // NOW -> CURRENT_TIMESTAMP
216            "NOW" => Ok(Expression::CurrentTimestamp(
217                crate::expressions::CurrentTimestamp {
218                    precision: None,
219                    sysdate: false,
220                },
221            )),
222
223            // CURDATE -> CURRENT_DATE
224            "CURDATE" => Ok(Expression::CurrentDate(crate::expressions::CurrentDate)),
225
226            // CURRENT_DATE() with parens -> CURRENT_DATE (no parens)
227            "CURRENT_DATE" if f.args.is_empty() => {
228                Ok(Expression::CurrentDate(crate::expressions::CurrentDate))
229            }
230
231            // RANDOM -> RAND
232            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
233                seed: None,
234                lower: None,
235                upper: None,
236            }))),
237
238            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
239            "GROUP_CONCAT" if !f.args.is_empty() => {
240                let mut args = f.args;
241                let first = args.remove(0);
242                let separator = args.pop();
243                let collect_list = Expression::Function(Box::new(Function::new(
244                    "COLLECT_LIST".to_string(),
245                    vec![first],
246                )));
247                if let Some(sep) = separator {
248                    Ok(Expression::Function(Box::new(Function::new(
249                        "ARRAY_JOIN".to_string(),
250                        vec![collect_list, sep],
251                    ))))
252                } else {
253                    Ok(Expression::Function(Box::new(Function::new(
254                        "ARRAY_JOIN".to_string(),
255                        vec![collect_list],
256                    ))))
257                }
258            }
259
260            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN in Databricks
261            "STRING_AGG" if !f.args.is_empty() => {
262                let mut args = f.args;
263                let first = args.remove(0);
264                let separator = args.pop();
265                let collect_list = Expression::Function(Box::new(Function::new(
266                    "COLLECT_LIST".to_string(),
267                    vec![first],
268                )));
269                if let Some(sep) = separator {
270                    Ok(Expression::Function(Box::new(Function::new(
271                        "ARRAY_JOIN".to_string(),
272                        vec![collect_list, sep],
273                    ))))
274                } else {
275                    Ok(Expression::Function(Box::new(Function::new(
276                        "ARRAY_JOIN".to_string(),
277                        vec![collect_list],
278                    ))))
279                }
280            }
281
282            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
283            "LISTAGG" if !f.args.is_empty() => {
284                let mut args = f.args;
285                let first = args.remove(0);
286                let separator = args.pop();
287                let collect_list = Expression::Function(Box::new(Function::new(
288                    "COLLECT_LIST".to_string(),
289                    vec![first],
290                )));
291                if let Some(sep) = separator {
292                    Ok(Expression::Function(Box::new(Function::new(
293                        "ARRAY_JOIN".to_string(),
294                        vec![collect_list, sep],
295                    ))))
296                } else {
297                    Ok(Expression::Function(Box::new(Function::new(
298                        "ARRAY_JOIN".to_string(),
299                        vec![collect_list],
300                    ))))
301                }
302            }
303
304            // ARRAY_AGG -> COLLECT_LIST in Databricks
305            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
306                "COLLECT_LIST".to_string(),
307                f.args,
308            )))),
309
310            // SUBSTR -> SUBSTRING
311            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
312                "SUBSTRING".to_string(),
313                f.args,
314            )))),
315
316            // LEN -> LENGTH
317            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
318                f.args.into_iter().next().unwrap(),
319            )))),
320
321            // CHARINDEX -> LOCATE (with swapped args, like Spark)
322            "CHARINDEX" if f.args.len() >= 2 => {
323                let mut args = f.args;
324                let substring = args.remove(0);
325                let string = args.remove(0);
326                // LOCATE(substring, string)
327                Ok(Expression::Function(Box::new(Function::new(
328                    "LOCATE".to_string(),
329                    vec![substring, string],
330                ))))
331            }
332
333            // POSITION -> LOCATE
334            "POSITION" if f.args.len() == 2 => {
335                let args = f.args;
336                Ok(Expression::Function(Box::new(Function::new(
337                    "LOCATE".to_string(),
338                    args,
339                ))))
340            }
341
342            // STRPOS -> LOCATE (with same arg order)
343            "STRPOS" if f.args.len() == 2 => {
344                let args = f.args;
345                let string = args[0].clone();
346                let substring = args[1].clone();
347                // STRPOS(string, substring) -> LOCATE(substring, string)
348                Ok(Expression::Function(Box::new(Function::new(
349                    "LOCATE".to_string(),
350                    vec![substring, string],
351                ))))
352            }
353
354            // INSTR is native in Databricks
355            "INSTR" => Ok(Expression::Function(Box::new(f))),
356
357            // LOCATE is native in Databricks
358            "LOCATE" => Ok(Expression::Function(Box::new(f))),
359
360            // ARRAY_LENGTH -> SIZE
361            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
362                Function::new("SIZE".to_string(), f.args),
363            ))),
364
365            // CARDINALITY -> SIZE
366            "CARDINALITY" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
367                Function::new("SIZE".to_string(), f.args),
368            ))),
369
370            // SIZE is native
371            "SIZE" => Ok(Expression::Function(Box::new(f))),
372
373            // ARRAY_CONTAINS is native in Databricks
374            "ARRAY_CONTAINS" => Ok(Expression::Function(Box::new(f))),
375
376            // CONTAINS -> ARRAY_CONTAINS in Databricks (for array operations)
377            // But keep CONTAINS for string contains (from CONTAINS_SUBSTR transpilation)
378            "CONTAINS" if f.args.len() == 2 => {
379                // Check if this is a string CONTAINS (LOWER() args pattern from CONTAINS_SUBSTR)
380                let is_string_contains = matches!(&f.args[0], Expression::Lower(_))
381                    && matches!(&f.args[1], Expression::Lower(_));
382                if is_string_contains {
383                    Ok(Expression::Function(Box::new(f)))
384                } else {
385                    Ok(Expression::Function(Box::new(Function::new(
386                        "ARRAY_CONTAINS".to_string(),
387                        f.args,
388                    ))))
389                }
390            }
391
392            // TO_DATE is native in Databricks
393            "TO_DATE" => Ok(Expression::Function(Box::new(f))),
394
395            // TO_TIMESTAMP is native in Databricks
396            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
397
398            // DATE_FORMAT is native in Databricks
399            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
400
401            // strftime -> DATE_FORMAT in Databricks
402            "STRFTIME" if f.args.len() >= 2 => {
403                let mut args = f.args;
404                let format = args.remove(0);
405                let date = args.remove(0);
406                Ok(Expression::Function(Box::new(Function::new(
407                    "DATE_FORMAT".to_string(),
408                    vec![date, format],
409                ))))
410            }
411
412            // TO_CHAR is supported natively in Databricks (unlike Spark)
413            "TO_CHAR" => Ok(Expression::Function(Box::new(f))),
414
415            // DATE_TRUNC is native in Databricks
416            "DATE_TRUNC" => Ok(Expression::Function(Box::new(f))),
417
418            // DATEADD is native in Databricks - uppercase the unit if present
419            "DATEADD" => {
420                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
421                Ok(Expression::Function(Box::new(Function::new(
422                    "DATEADD".to_string(),
423                    transformed_args,
424                ))))
425            }
426
427            // DATE_ADD -> DATEADD in Databricks (2-arg form only)
428            // 2-arg with interval: DATE_ADD(date, interval) -> DATEADD(DAY, interval, date)
429            // 2-arg with number: DATE_ADD(date, -2) -> keep as DATE_ADD(date, -2)
430            // 3-arg: DATE_ADD(unit, amount, date) -> keep as DATE_ADD(UNIT, amount, date)
431            "DATE_ADD" => {
432                if f.args.len() == 2 {
433                    let is_simple_number = matches!(
434                        &f.args[1],
435                        Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::Number(_))
436                    ) || matches!(&f.args[1], Expression::Neg(_));
437                    if is_simple_number {
438                        // Keep as DATE_ADD(date, num_days)
439                        Ok(Expression::Function(Box::new(Function::new(
440                            "DATE_ADD".to_string(),
441                            f.args,
442                        ))))
443                    } else {
444                        let mut args = f.args;
445                        let date = args.remove(0);
446                        let interval = args.remove(0);
447                        let unit = Expression::Identifier(crate::expressions::Identifier {
448                            name: "DAY".to_string(),
449                            quoted: false,
450                            trailing_comments: Vec::new(),
451                            span: None,
452                        });
453                        Ok(Expression::Function(Box::new(Function::new(
454                            "DATEADD".to_string(),
455                            vec![unit, interval, date],
456                        ))))
457                    }
458                } else {
459                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
460                    Ok(Expression::Function(Box::new(Function::new(
461                        "DATE_ADD".to_string(),
462                        transformed_args,
463                    ))))
464                }
465            }
466
467            // DATEDIFF is native in Databricks - uppercase the unit if present
468            // 2-arg: DATEDIFF(end, start) -> DATEDIFF(DAY, start, end)
469            // 3-arg: DATEDIFF(unit, start, end) -> DATEDIFF(UNIT, start, end)
470            "DATEDIFF" => {
471                if f.args.len() == 2 {
472                    let mut args = f.args;
473                    let end_date = args.remove(0);
474                    let start_date = args.remove(0);
475                    let unit = Expression::Identifier(crate::expressions::Identifier {
476                        name: "DAY".to_string(),
477                        quoted: false,
478                        trailing_comments: Vec::new(),
479                        span: None,
480                    });
481                    Ok(Expression::Function(Box::new(Function::new(
482                        "DATEDIFF".to_string(),
483                        vec![unit, start_date, end_date],
484                    ))))
485                } else {
486                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
487                    Ok(Expression::Function(Box::new(Function::new(
488                        "DATEDIFF".to_string(),
489                        transformed_args,
490                    ))))
491                }
492            }
493
494            // DATE_DIFF -> DATEDIFF with uppercased unit
495            "DATE_DIFF" => {
496                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
497                Ok(Expression::Function(Box::new(Function::new(
498                    "DATEDIFF".to_string(),
499                    transformed_args,
500                ))))
501            }
502
503            // JSON_EXTRACT -> Use colon operator in generation, but keep as function for now
504            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
505
506            // JSON_EXTRACT_SCALAR -> same handling
507            "JSON_EXTRACT_SCALAR" => Ok(Expression::Function(Box::new(f))),
508
509            // GET_JSON_OBJECT -> colon syntax in Databricks
510            // GET_JSON_OBJECT(col, '$.path') becomes col:path
511            "GET_JSON_OBJECT" if f.args.len() == 2 => {
512                let mut args = f.args;
513                let col = args.remove(0);
514                let path_arg = args.remove(0);
515
516                // Extract and strip the $. prefix from the path
517                let path_expr = match &path_arg {
518                    Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::String(_)) => {
519                        let crate::expressions::Literal::String(s) = lit.as_ref() else { unreachable!() };
520                        // Strip leading '$.' if present
521                        let stripped = if s.starts_with("$.") {
522                            &s[2..]
523                        } else if s.starts_with("$") {
524                            &s[1..]
525                        } else {
526                            s.as_str()
527                        };
528                        Expression::Literal(Box::new(crate::expressions::Literal::String(
529                            stripped.to_string(),
530                        )))
531                    }
532                    _ => path_arg,
533                };
534
535                Ok(Expression::JSONExtract(Box::new(JSONExtract {
536                    this: Box::new(col),
537                    expression: Box::new(path_expr),
538                    only_json_types: None,
539                    expressions: Vec::new(),
540                    variant_extract: Some(Box::new(Expression::true_())),
541                    json_query: None,
542                    option: None,
543                    quote: None,
544                    on_condition: None,
545                    requires_json: None,
546                })))
547            }
548
549            // FROM_JSON is native in Databricks
550            "FROM_JSON" => Ok(Expression::Function(Box::new(f))),
551
552            // PARSE_JSON is native in Databricks
553            "PARSE_JSON" => Ok(Expression::Function(Box::new(f))),
554
555            // COLLECT_LIST is native in Databricks
556            "COLLECT_LIST" => Ok(Expression::Function(Box::new(f))),
557
558            // COLLECT_SET is native in Databricks
559            "COLLECT_SET" => Ok(Expression::Function(Box::new(f))),
560
561            // RLIKE is native in Databricks
562            "RLIKE" => Ok(Expression::Function(Box::new(f))),
563
564            // REGEXP -> RLIKE in Databricks
565            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
566                "RLIKE".to_string(),
567                f.args,
568            )))),
569
570            // REGEXP_LIKE is native in Databricks
571            "REGEXP_LIKE" => Ok(Expression::Function(Box::new(f))),
572
573            // LEVENSHTEIN is native in Databricks
574            "LEVENSHTEIN" => Ok(Expression::Function(Box::new(f))),
575
576            // SEQUENCE is native (for GENERATE_SERIES)
577            "GENERATE_SERIES" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
578                Function::new("SEQUENCE".to_string(), f.args),
579            ))),
580
581            // SEQUENCE is native
582            "SEQUENCE" => Ok(Expression::Function(Box::new(f))),
583
584            // FLATTEN is native in Databricks
585            "FLATTEN" => Ok(Expression::Function(Box::new(f))),
586
587            // ARRAY_SORT is native
588            "ARRAY_SORT" => Ok(Expression::Function(Box::new(f))),
589
590            // ARRAY_DISTINCT is native
591            "ARRAY_DISTINCT" => Ok(Expression::Function(Box::new(f))),
592
593            // TRANSFORM is native (for array transformation)
594            "TRANSFORM" => Ok(Expression::Function(Box::new(f))),
595
596            // FILTER is native (for array filtering)
597            "FILTER" => Ok(Expression::Function(Box::new(f))),
598
599            // FROM_UTC_TIMESTAMP - wrap first argument in CAST(... AS TIMESTAMP) if not already
600            "FROM_UTC_TIMESTAMP" if f.args.len() >= 2 => {
601                let mut args = f.args;
602                let first_arg = args.remove(0);
603
604                // Check if first arg is already a Cast to TIMESTAMP
605                let wrapped_arg = if self.is_cast_to_timestamp(&first_arg) {
606                    first_arg
607                } else {
608                    // Wrap in CAST(... AS TIMESTAMP)
609                    Expression::Cast(Box::new(Cast {
610                        this: first_arg,
611                        to: DataType::Timestamp {
612                            precision: None,
613                            timezone: false,
614                        },
615                        trailing_comments: Vec::new(),
616                        double_colon_syntax: false,
617                        format: None,
618                        default: None,
619                        inferred_type: None,
620                    }))
621                };
622
623                let mut new_args = vec![wrapped_arg];
624                new_args.extend(args);
625
626                Ok(Expression::Function(Box::new(Function::new(
627                    "FROM_UTC_TIMESTAMP".to_string(),
628                    new_args,
629                ))))
630            }
631
632            // UNIFORM(low, high, RANDOM(seed)) -> UNIFORM(low, high, seed) or UNIFORM(low, high)
633            "UNIFORM" if f.args.len() == 3 => {
634                let mut args = f.args;
635                let low = args.remove(0);
636                let high = args.remove(0);
637                let gen = args.remove(0);
638                match gen {
639                    Expression::Function(func) if func.name.to_uppercase() == "RANDOM" => {
640                        if func.args.len() == 1 {
641                            // RANDOM(seed) -> extract seed
642                            let seed = func.args.into_iter().next().unwrap();
643                            Ok(Expression::Function(Box::new(Function::new(
644                                "UNIFORM".to_string(),
645                                vec![low, high, seed],
646                            ))))
647                        } else {
648                            // RANDOM() -> drop gen arg
649                            Ok(Expression::Function(Box::new(Function::new(
650                                "UNIFORM".to_string(),
651                                vec![low, high],
652                            ))))
653                        }
654                    }
655                    Expression::Rand(r) => {
656                        if let Some(seed) = r.seed {
657                            Ok(Expression::Function(Box::new(Function::new(
658                                "UNIFORM".to_string(),
659                                vec![low, high, *seed],
660                            ))))
661                        } else {
662                            Ok(Expression::Function(Box::new(Function::new(
663                                "UNIFORM".to_string(),
664                                vec![low, high],
665                            ))))
666                        }
667                    }
668                    _ => Ok(Expression::Function(Box::new(Function::new(
669                        "UNIFORM".to_string(),
670                        vec![low, high, gen],
671                    )))),
672                }
673            }
674
675            // REGEXP_SUBSTR(subject, pattern, ...) -> REGEXP_EXTRACT(subject, pattern)
676            "REGEXP_SUBSTR" if f.args.len() >= 2 => {
677                let subject = f.args[0].clone();
678                let pattern = f.args[1].clone();
679                Ok(Expression::Function(Box::new(Function::new(
680                    "REGEXP_EXTRACT".to_string(),
681                    vec![subject, pattern],
682                ))))
683            }
684
685            // BIT_GET -> GETBIT
686            "BIT_GET" => Ok(Expression::Function(Box::new(Function::new(
687                "GETBIT".to_string(),
688                f.args,
689            )))),
690
691            // Pass through everything else
692            _ => Ok(Expression::Function(Box::new(f))),
693        }
694    }
695
696    fn transform_aggregate_function(
697        &self,
698        f: Box<crate::expressions::AggregateFunction>,
699    ) -> Result<Expression> {
700        let name_upper = f.name.to_uppercase();
701        match name_upper.as_str() {
702            // COUNT_IF is native in Databricks (Spark 3+)
703            "COUNT_IF" => Ok(Expression::AggregateFunction(f)),
704
705            // ANY_VALUE is native in Databricks (Spark 3+)
706            "ANY_VALUE" => Ok(Expression::AggregateFunction(f)),
707
708            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
709            "GROUP_CONCAT" if !f.args.is_empty() => {
710                let mut args = f.args;
711                let first = args.remove(0);
712                let separator = args.pop();
713                let collect_list = Expression::Function(Box::new(Function::new(
714                    "COLLECT_LIST".to_string(),
715                    vec![first],
716                )));
717                if let Some(sep) = separator {
718                    Ok(Expression::Function(Box::new(Function::new(
719                        "ARRAY_JOIN".to_string(),
720                        vec![collect_list, sep],
721                    ))))
722                } else {
723                    Ok(Expression::Function(Box::new(Function::new(
724                        "ARRAY_JOIN".to_string(),
725                        vec![collect_list],
726                    ))))
727                }
728            }
729
730            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN
731            "STRING_AGG" if !f.args.is_empty() => {
732                let mut args = f.args;
733                let first = args.remove(0);
734                let separator = args.pop();
735                let collect_list = Expression::Function(Box::new(Function::new(
736                    "COLLECT_LIST".to_string(),
737                    vec![first],
738                )));
739                if let Some(sep) = separator {
740                    Ok(Expression::Function(Box::new(Function::new(
741                        "ARRAY_JOIN".to_string(),
742                        vec![collect_list, sep],
743                    ))))
744                } else {
745                    Ok(Expression::Function(Box::new(Function::new(
746                        "ARRAY_JOIN".to_string(),
747                        vec![collect_list],
748                    ))))
749                }
750            }
751
752            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
753            "LISTAGG" if !f.args.is_empty() => {
754                let mut args = f.args;
755                let first = args.remove(0);
756                let separator = args.pop();
757                let collect_list = Expression::Function(Box::new(Function::new(
758                    "COLLECT_LIST".to_string(),
759                    vec![first],
760                )));
761                if let Some(sep) = separator {
762                    Ok(Expression::Function(Box::new(Function::new(
763                        "ARRAY_JOIN".to_string(),
764                        vec![collect_list, sep],
765                    ))))
766                } else {
767                    Ok(Expression::Function(Box::new(Function::new(
768                        "ARRAY_JOIN".to_string(),
769                        vec![collect_list],
770                    ))))
771                }
772            }
773
774            // ARRAY_AGG -> COLLECT_LIST
775            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
776                "COLLECT_LIST".to_string(),
777                f.args,
778            )))),
779
780            // STDDEV is native in Databricks
781            "STDDEV" => Ok(Expression::AggregateFunction(f)),
782
783            // VARIANCE is native in Databricks
784            "VARIANCE" => Ok(Expression::AggregateFunction(f)),
785
786            // APPROX_COUNT_DISTINCT is native in Databricks
787            "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
788
789            // APPROX_DISTINCT -> APPROX_COUNT_DISTINCT
790            "APPROX_DISTINCT" if !f.args.is_empty() => {
791                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
792                    name: "APPROX_COUNT_DISTINCT".to_string(),
793                    args: f.args,
794                    distinct: f.distinct,
795                    filter: f.filter,
796                    order_by: Vec::new(),
797                    limit: None,
798                    ignore_nulls: None,
799                    inferred_type: None,
800                })))
801            }
802
803            // Pass through everything else
804            _ => Ok(Expression::AggregateFunction(f)),
805        }
806    }
807
808    /// Transform Cast expressions - handles typed literals being cast
809    ///
810    /// When we have a typed literal (TIMESTAMP 'x', DATE 'x', TIME 'x') being cast to another type,
811    /// Databricks/Spark transforms it as follows:
812    ///
813    /// `TIMESTAMP 'x'::TYPE` -> `CAST(CAST('x' AS TYPE) AS TIMESTAMP)`
814    ///
815    /// This reverses the types - the inner cast is to the target type,
816    /// the outer cast is to the original literal type.
817    fn transform_cast(&self, c: Cast) -> Result<Expression> {
818        // Check if the inner expression is a typed literal
819        match &c.this {
820            // TIMESTAMP 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIMESTAMP)
821            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Timestamp(_)) => {
822                let Literal::Timestamp(value) = lit.as_ref() else { unreachable!() };
823                // Create inner cast: CAST('value' AS target_type)
824                let inner_cast = Expression::Cast(Box::new(Cast {
825                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
826                    to: c.to,
827                    trailing_comments: Vec::new(),
828                    double_colon_syntax: false,
829                    format: None,
830                    default: None,
831                    inferred_type: None,
832                }));
833                // Create outer cast: CAST(inner_cast AS TIMESTAMP)
834                Ok(Expression::Cast(Box::new(Cast {
835                    this: inner_cast,
836                    to: DataType::Timestamp {
837                        precision: None,
838                        timezone: false,
839                    },
840                    trailing_comments: c.trailing_comments,
841                    double_colon_syntax: false,
842                    format: None,
843                    default: None,
844                    inferred_type: None,
845                })))
846            }
847            // DATE 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS DATE)
848            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Date(_)) => {
849                let Literal::Date(value) = lit.as_ref() else { unreachable!() };
850                let inner_cast = Expression::Cast(Box::new(Cast {
851                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
852                    to: c.to,
853                    trailing_comments: Vec::new(),
854                    double_colon_syntax: false,
855                    format: None,
856                    default: None,
857                    inferred_type: None,
858                }));
859                Ok(Expression::Cast(Box::new(Cast {
860                    this: inner_cast,
861                    to: DataType::Date,
862                    trailing_comments: c.trailing_comments,
863                    double_colon_syntax: false,
864                    format: None,
865                    default: None,
866                    inferred_type: None,
867                })))
868            }
869            // TIME 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIME)
870            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Time(_)) => {
871                let Literal::Time(value) = lit.as_ref() else { unreachable!() };
872                let inner_cast = Expression::Cast(Box::new(Cast {
873                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
874                    to: c.to,
875                    trailing_comments: Vec::new(),
876                    double_colon_syntax: false,
877                    format: None,
878                    default: None,
879                    inferred_type: None,
880                }));
881                Ok(Expression::Cast(Box::new(Cast {
882                    this: inner_cast,
883                    to: DataType::Time {
884                        precision: None,
885                        timezone: false,
886                    },
887                    trailing_comments: c.trailing_comments,
888                    double_colon_syntax: false,
889                    format: None,
890                    default: None,
891                    inferred_type: None,
892                })))
893            }
894            // For all other cases, pass through the Cast unchanged
895            _ => Ok(Expression::Cast(Box::new(c))),
896        }
897    }
898
899    /// Check if an expression is a CAST to TIMESTAMP
900    fn is_cast_to_timestamp(&self, expr: &Expression) -> bool {
901        if let Expression::Cast(cast) = expr {
902            matches!(cast.to, DataType::Timestamp { .. })
903        } else {
904            false
905        }
906    }
907
908    /// Helper to uppercase the first argument if it's an identifier or column (for DATEDIFF, DATEADD units)
909    fn uppercase_first_arg_if_identifier(&self, mut args: Vec<Expression>) -> Vec<Expression> {
910        use crate::expressions::Identifier;
911        if !args.is_empty() {
912            match &args[0] {
913                Expression::Identifier(id) => {
914                    args[0] = Expression::Identifier(Identifier {
915                        name: id.name.to_uppercase(),
916                        quoted: id.quoted,
917                        trailing_comments: id.trailing_comments.clone(),
918                        span: None,
919                    });
920                }
921                Expression::Column(col) if col.table.is_none() => {
922                    // Unqualified column name like "day" should be treated as a unit
923                    args[0] = Expression::Identifier(Identifier {
924                        name: col.name.name.to_uppercase(),
925                        quoted: col.name.quoted,
926                        trailing_comments: col.name.trailing_comments.clone(),
927                        span: None,
928                    });
929                }
930                _ => {}
931            }
932        }
933        args
934    }
935}
936
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dialect;

    /// Run one statement through the Databricks pipeline.
    ///
    /// Parses `sql` with the Databricks dialect, applies the Databricks
    /// transform to the first parsed statement, and generates SQL from the
    /// result.
    fn transpile(sql: &str) -> String {
        let dialect = Dialect::get(DialectType::Databricks);
        let parsed = dialect.parse(sql).expect("Parse failed");
        let transformed = dialect
            .transform(parsed[0].clone())
            .expect("Transform failed");
        dialect.generate(&transformed).expect("Generate failed")
    }

    #[test]
    fn test_timestamp_literal_cast() {
        // Identity fixture [47]:
        // TIMESTAMP 'value'::DATE -> CAST(CAST('value' AS DATE) AS TIMESTAMP)
        assert_eq!(
            transpile("SELECT TIMESTAMP '2025-04-29 18.47.18'::DATE"),
            "SELECT CAST(CAST('2025-04-29 18.47.18' AS DATE) AS TIMESTAMP)",
            "Timestamp literal cast transformation failed"
        );
    }

    #[test]
    fn test_from_utc_timestamp_wraps_column() {
        // Fixture [48]: a bare column argument gets wrapped in CAST(... AS TIMESTAMP).
        assert_eq!(
            transpile(
                "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(foo, 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t"
            ),
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "FROM_UTC_TIMESTAMP transformation failed"
        );
    }

    #[test]
    fn test_from_utc_timestamp_keeps_existing_cast() {
        // Fixture [50]: an existing `::TIMESTAMP` is kept (not double-wrapped)
        // and rendered with CAST() syntax.
        assert_eq!(
            transpile("FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz)"),
            "FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)",
            "FROM_UTC_TIMESTAMP with existing CAST failed"
        );
    }
}