Skip to main content

polyglot_sql/dialects/
databricks.rs

1//! Databricks Dialect
2//!
3//! Databricks-specific transformations based on sqlglot patterns.
4//! Databricks extends Spark SQL with additional features:
5//! - Colon operator for JSON extraction (col:path)
6//! - DATEADD/DATEDIFF with specific syntax
7//! - NULL type mapped to VOID
8//! - Native REGEXP_LIKE and TRY_CAST support
9
10use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{
13    AggregateFunction, Cast, DataType, Expression, Function, JSONExtract, Literal, UnaryFunc,
14    VarArgFunc,
15};
16use crate::generator::GeneratorConfig;
17use crate::tokens::TokenizerConfig;
18
/// The Databricks SQL dialect.
///
/// A stateless unit type: its tokenizer settings, generator settings and expression
/// rewrites live in the [`DialectImpl`] implementation, with supporting inherent methods
/// on `impl DatabricksDialect`.
pub struct DatabricksDialect;
21
22impl DialectImpl for DatabricksDialect {
23    fn dialect_type(&self) -> DialectType {
24        DialectType::Databricks
25    }
26
27    fn tokenizer_config(&self) -> TokenizerConfig {
28        let mut config = TokenizerConfig::default();
29        // Databricks uses backticks for identifiers (NOT double quotes)
30        config.identifiers.clear();
31        config.identifiers.insert('`', '`');
32        // Databricks (like Hive/Spark) uses double quotes as string delimiters
33        config.quotes.insert("\"".to_string(), "\"".to_string());
34        // Databricks uses backslash escapes in strings (inherited from Hive/Spark)
35        config.string_escapes.push('\\');
36        // Databricks supports DIV keyword for integer division
37        config
38            .keywords
39            .insert("DIV".to_string(), crate::tokens::TokenType::Div);
40        // Databricks numeric literal suffixes (same as Hive/Spark)
41        config
42            .numeric_literals
43            .insert("L".to_string(), "BIGINT".to_string());
44        config
45            .numeric_literals
46            .insert("S".to_string(), "SMALLINT".to_string());
47        config
48            .numeric_literals
49            .insert("Y".to_string(), "TINYINT".to_string());
50        config
51            .numeric_literals
52            .insert("D".to_string(), "DOUBLE".to_string());
53        config
54            .numeric_literals
55            .insert("F".to_string(), "FLOAT".to_string());
56        config
57            .numeric_literals
58            .insert("BD".to_string(), "DECIMAL".to_string());
59        // Databricks allows identifiers to start with digits (like Hive/Spark)
60        config.identifiers_can_start_with_digit = true;
61        // Databricks (like Spark): STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
62        // Backslashes in raw strings are always literal (no escape processing)
63        config.string_escapes_allowed_in_raw_strings = false;
64        config
65    }
66
67    fn generator_config(&self) -> GeneratorConfig {
68        use crate::generator::IdentifierQuoteStyle;
69        GeneratorConfig {
70            identifier_quote: '`',
71            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
72            dialect: Some(DialectType::Databricks),
73            struct_field_sep: ": ",
74            create_function_return_as: false,
75            tablesample_seed_keyword: "REPEATABLE",
76            identifiers_can_start_with_digit: true,
77            // Databricks uses COMMENT 'value' without = sign
78            schema_comment_with_eq: false,
79            ..Default::default()
80        }
81    }
82
83    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
84        match expr {
85            // IFNULL -> COALESCE in Databricks
86            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
87                original_name: None,
88                expressions: vec![f.this, f.expression],
89                inferred_type: None,
90            }))),
91
92            // NVL -> COALESCE in Databricks
93            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
94                original_name: None,
95                expressions: vec![f.this, f.expression],
96                inferred_type: None,
97            }))),
98
99            // TryCast is native in Databricks
100            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
101
102            // SafeCast -> TRY_CAST in Databricks
103            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
104
105            // ILIKE is native in Databricks (Spark 3+)
106            Expression::ILike(op) => Ok(Expression::ILike(op)),
107
108            // UNNEST -> EXPLODE in Databricks
109            Expression::Unnest(f) => Ok(Expression::Explode(Box::new(UnaryFunc::new(f.this)))),
110
111            // EXPLODE is native to Databricks
112            Expression::Explode(f) => Ok(Expression::Explode(f)),
113
114            // ExplodeOuter is supported
115            Expression::ExplodeOuter(f) => Ok(Expression::ExplodeOuter(f)),
116
117            // RANDOM -> RAND in Databricks
118            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
119                seed: None,
120                lower: None,
121                upper: None,
122            }))),
123
124            // Rand is native
125            Expression::Rand(r) => Ok(Expression::Rand(r)),
126
127            // || (Concat) -> CONCAT in Databricks
128            Expression::Concat(op) => Ok(Expression::Function(Box::new(Function::new(
129                "CONCAT".to_string(),
130                vec![op.left, op.right],
131            )))),
132
133            // RegexpLike is native in Databricks
134            Expression::RegexpLike(op) => Ok(Expression::RegexpLike(op)),
135
136            // Cast with typed literal: TIMESTAMP 'x'::TYPE -> CAST(CAST('x' AS TYPE) AS TIMESTAMP)
137            // This is a complex sqlglot transformation where:
138            // 1. The inner typed literal (e.g., TIMESTAMP 'x') becomes CAST('x' AS <target_type>)
139            // 2. The outer result is wrapped in CAST(... AS <original_literal_type>)
140            Expression::Cast(c) => self.transform_cast(*c),
141
142            // Generic function transformations
143            Expression::Function(f) => self.transform_function(*f),
144
145            // Generic aggregate function transformations
146            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
147
148            // DateSub -> DATE_ADD(date, -val) in Databricks
149            Expression::DateSub(f) => {
150                // Convert string literals to numbers (interval values are often stored as strings)
151                let val = match f.interval {
152                    Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::String(s) if s.parse::<i64>().is_ok()) =>
153                    {
154                        let crate::expressions::Literal::String(s) = lit.as_ref() else {
155                            unreachable!()
156                        };
157                        Expression::Literal(Box::new(crate::expressions::Literal::Number(
158                            s.clone(),
159                        )))
160                    }
161                    other => other,
162                };
163                let neg_val = Expression::Neg(Box::new(crate::expressions::UnaryOp {
164                    this: val,
165                    inferred_type: None,
166                }));
167                Ok(Expression::Function(Box::new(Function::new(
168                    "DATE_ADD".to_string(),
169                    vec![f.this, neg_val],
170                ))))
171            }
172
173            // Pass through everything else
174            _ => Ok(expr),
175        }
176    }
177}
178
179impl DatabricksDialect {
180    fn transform_function(&self, f: Function) -> Result<Expression> {
181        let name_upper = f.name.to_uppercase();
182        match name_upper.as_str() {
183            // IFNULL -> COALESCE
184            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
185                original_name: None,
186                expressions: f.args,
187                inferred_type: None,
188            }))),
189
190            // NVL -> COALESCE
191            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
192                original_name: None,
193                expressions: f.args,
194                inferred_type: None,
195            }))),
196
197            // ISNULL -> COALESCE
198            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
199                original_name: None,
200                expressions: f.args,
201                inferred_type: None,
202            }))),
203
204            // ROW -> STRUCT (no auto-naming for cross-dialect conversion)
205            "ROW" => Ok(Expression::Function(Box::new(Function::new(
206                "STRUCT".to_string(),
207                f.args,
208            )))),
209
210            // GETDATE -> CURRENT_TIMESTAMP
211            "GETDATE" => Ok(Expression::CurrentTimestamp(
212                crate::expressions::CurrentTimestamp {
213                    precision: None,
214                    sysdate: false,
215                },
216            )),
217
218            // NOW -> CURRENT_TIMESTAMP
219            "NOW" => Ok(Expression::CurrentTimestamp(
220                crate::expressions::CurrentTimestamp {
221                    precision: None,
222                    sysdate: false,
223                },
224            )),
225
226            // CURDATE -> CURRENT_DATE
227            "CURDATE" => Ok(Expression::CurrentDate(crate::expressions::CurrentDate)),
228
229            // CURRENT_DATE() with parens -> CURRENT_DATE (no parens)
230            "CURRENT_DATE" if f.args.is_empty() => {
231                Ok(Expression::CurrentDate(crate::expressions::CurrentDate))
232            }
233
234            // RANDOM -> RAND
235            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
236                seed: None,
237                lower: None,
238                upper: None,
239            }))),
240
241            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
242            "GROUP_CONCAT" if !f.args.is_empty() => {
243                let mut args = f.args;
244                let first = args.remove(0);
245                let separator = args.pop();
246                let collect_list = Expression::Function(Box::new(Function::new(
247                    "COLLECT_LIST".to_string(),
248                    vec![first],
249                )));
250                if let Some(sep) = separator {
251                    Ok(Expression::Function(Box::new(Function::new(
252                        "ARRAY_JOIN".to_string(),
253                        vec![collect_list, sep],
254                    ))))
255                } else {
256                    Ok(Expression::Function(Box::new(Function::new(
257                        "ARRAY_JOIN".to_string(),
258                        vec![collect_list],
259                    ))))
260                }
261            }
262
263            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN in Databricks
264            "STRING_AGG" if !f.args.is_empty() => {
265                let mut args = f.args;
266                let first = args.remove(0);
267                let separator = args.pop();
268                let collect_list = Expression::Function(Box::new(Function::new(
269                    "COLLECT_LIST".to_string(),
270                    vec![first],
271                )));
272                if let Some(sep) = separator {
273                    Ok(Expression::Function(Box::new(Function::new(
274                        "ARRAY_JOIN".to_string(),
275                        vec![collect_list, sep],
276                    ))))
277                } else {
278                    Ok(Expression::Function(Box::new(Function::new(
279                        "ARRAY_JOIN".to_string(),
280                        vec![collect_list],
281                    ))))
282                }
283            }
284
285            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
286            "LISTAGG" if !f.args.is_empty() => {
287                let mut args = f.args;
288                let first = args.remove(0);
289                let separator = args.pop();
290                let collect_list = Expression::Function(Box::new(Function::new(
291                    "COLLECT_LIST".to_string(),
292                    vec![first],
293                )));
294                if let Some(sep) = separator {
295                    Ok(Expression::Function(Box::new(Function::new(
296                        "ARRAY_JOIN".to_string(),
297                        vec![collect_list, sep],
298                    ))))
299                } else {
300                    Ok(Expression::Function(Box::new(Function::new(
301                        "ARRAY_JOIN".to_string(),
302                        vec![collect_list],
303                    ))))
304                }
305            }
306
307            // ARRAY_AGG -> COLLECT_LIST in Databricks
308            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
309                "COLLECT_LIST".to_string(),
310                f.args,
311            )))),
312
313            // SUBSTR -> SUBSTRING
314            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
315                "SUBSTRING".to_string(),
316                f.args,
317            )))),
318
319            // LEN -> LENGTH
320            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
321                f.args.into_iter().next().unwrap(),
322            )))),
323
324            // CHARINDEX -> LOCATE (with swapped args, like Spark)
325            "CHARINDEX" if f.args.len() >= 2 => {
326                let mut args = f.args;
327                let substring = args.remove(0);
328                let string = args.remove(0);
329                // LOCATE(substring, string)
330                Ok(Expression::Function(Box::new(Function::new(
331                    "LOCATE".to_string(),
332                    vec![substring, string],
333                ))))
334            }
335
336            // POSITION -> LOCATE
337            "POSITION" if f.args.len() == 2 => {
338                let args = f.args;
339                Ok(Expression::Function(Box::new(Function::new(
340                    "LOCATE".to_string(),
341                    args,
342                ))))
343            }
344
345            // STRPOS -> LOCATE (with same arg order)
346            "STRPOS" if f.args.len() == 2 => {
347                let args = f.args;
348                let string = args[0].clone();
349                let substring = args[1].clone();
350                // STRPOS(string, substring) -> LOCATE(substring, string)
351                Ok(Expression::Function(Box::new(Function::new(
352                    "LOCATE".to_string(),
353                    vec![substring, string],
354                ))))
355            }
356
357            // INSTR is native in Databricks
358            "INSTR" => Ok(Expression::Function(Box::new(f))),
359
360            // LOCATE is native in Databricks
361            "LOCATE" => Ok(Expression::Function(Box::new(f))),
362
363            // ARRAY_LENGTH -> SIZE
364            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
365                Function::new("SIZE".to_string(), f.args),
366            ))),
367
368            // CARDINALITY -> SIZE
369            "CARDINALITY" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
370                Function::new("SIZE".to_string(), f.args),
371            ))),
372
373            // SIZE is native
374            "SIZE" => Ok(Expression::Function(Box::new(f))),
375
376            // ARRAY_CONTAINS is native in Databricks
377            "ARRAY_CONTAINS" => Ok(Expression::Function(Box::new(f))),
378
379            // CONTAINS -> ARRAY_CONTAINS in Databricks (for array operations)
380            // But keep CONTAINS for string contains (from CONTAINS_SUBSTR transpilation)
381            "CONTAINS" if f.args.len() == 2 => {
382                // Check if this is a string CONTAINS (LOWER() args pattern from CONTAINS_SUBSTR)
383                let is_string_contains = matches!(&f.args[0], Expression::Lower(_))
384                    && matches!(&f.args[1], Expression::Lower(_));
385                if is_string_contains {
386                    Ok(Expression::Function(Box::new(f)))
387                } else {
388                    Ok(Expression::Function(Box::new(Function::new(
389                        "ARRAY_CONTAINS".to_string(),
390                        f.args,
391                    ))))
392                }
393            }
394
395            // TO_DATE is native in Databricks
396            "TO_DATE" => Ok(Expression::Function(Box::new(f))),
397
398            // TO_TIMESTAMP is native in Databricks
399            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
400
401            // DATE_FORMAT is native in Databricks
402            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
403
404            // strftime -> DATE_FORMAT in Databricks
405            "STRFTIME" if f.args.len() >= 2 => {
406                let mut args = f.args;
407                let format = args.remove(0);
408                let date = args.remove(0);
409                Ok(Expression::Function(Box::new(Function::new(
410                    "DATE_FORMAT".to_string(),
411                    vec![date, format],
412                ))))
413            }
414
415            // TO_CHAR is supported natively in Databricks (unlike Spark)
416            "TO_CHAR" => Ok(Expression::Function(Box::new(f))),
417
418            // DATE_TRUNC is native in Databricks
419            "DATE_TRUNC" => Ok(Expression::Function(Box::new(f))),
420
421            // DATEADD is native in Databricks - uppercase the unit if present
422            "DATEADD" => {
423                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
424                Ok(Expression::Function(Box::new(Function::new(
425                    "DATEADD".to_string(),
426                    transformed_args,
427                ))))
428            }
429
430            // DATE_ADD -> DATEADD in Databricks (2-arg form only)
431            // 2-arg with interval: DATE_ADD(date, interval) -> DATEADD(DAY, interval, date)
432            // 2-arg with number: DATE_ADD(date, -2) -> keep as DATE_ADD(date, -2)
433            // 3-arg: DATE_ADD(unit, amount, date) -> keep as DATE_ADD(UNIT, amount, date)
434            "DATE_ADD" => {
435                if f.args.len() == 2 {
436                    let is_simple_number = matches!(
437                        &f.args[1],
438                        Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::Number(_))
439                    ) || matches!(&f.args[1], Expression::Neg(_));
440                    if is_simple_number {
441                        // Keep as DATE_ADD(date, num_days)
442                        Ok(Expression::Function(Box::new(Function::new(
443                            "DATE_ADD".to_string(),
444                            f.args,
445                        ))))
446                    } else {
447                        let mut args = f.args;
448                        let date = args.remove(0);
449                        let interval = args.remove(0);
450                        let unit = Expression::Identifier(crate::expressions::Identifier {
451                            name: "DAY".to_string(),
452                            quoted: false,
453                            trailing_comments: Vec::new(),
454                            span: None,
455                        });
456                        Ok(Expression::Function(Box::new(Function::new(
457                            "DATEADD".to_string(),
458                            vec![unit, interval, date],
459                        ))))
460                    }
461                } else {
462                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
463                    Ok(Expression::Function(Box::new(Function::new(
464                        "DATE_ADD".to_string(),
465                        transformed_args,
466                    ))))
467                }
468            }
469
470            // DATEDIFF is native in Databricks - uppercase the unit if present
471            // 2-arg: DATEDIFF(end, start) -> DATEDIFF(DAY, start, end)
472            // 3-arg: DATEDIFF(unit, start, end) -> DATEDIFF(UNIT, start, end)
473            "DATEDIFF" => {
474                if f.args.len() == 2 {
475                    let mut args = f.args;
476                    let end_date = args.remove(0);
477                    let start_date = args.remove(0);
478                    let unit = Expression::Identifier(crate::expressions::Identifier {
479                        name: "DAY".to_string(),
480                        quoted: false,
481                        trailing_comments: Vec::new(),
482                        span: None,
483                    });
484                    Ok(Expression::Function(Box::new(Function::new(
485                        "DATEDIFF".to_string(),
486                        vec![unit, start_date, end_date],
487                    ))))
488                } else {
489                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
490                    Ok(Expression::Function(Box::new(Function::new(
491                        "DATEDIFF".to_string(),
492                        transformed_args,
493                    ))))
494                }
495            }
496
497            // DATE_DIFF -> DATEDIFF with uppercased unit
498            "DATE_DIFF" => {
499                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
500                Ok(Expression::Function(Box::new(Function::new(
501                    "DATEDIFF".to_string(),
502                    transformed_args,
503                ))))
504            }
505
506            // JSON_EXTRACT -> Use colon operator in generation, but keep as function for now
507            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
508
509            // JSON_EXTRACT_SCALAR -> same handling
510            "JSON_EXTRACT_SCALAR" => Ok(Expression::Function(Box::new(f))),
511
512            // GET_JSON_OBJECT -> colon syntax in Databricks
513            // GET_JSON_OBJECT(col, '$.path') becomes col:path
514            "GET_JSON_OBJECT" if f.args.len() == 2 => {
515                let mut args = f.args;
516                let col = args.remove(0);
517                let path_arg = args.remove(0);
518
519                // Extract and strip the $. prefix from the path
520                let path_expr = match &path_arg {
521                    Expression::Literal(lit)
522                        if matches!(lit.as_ref(), crate::expressions::Literal::String(_)) =>
523                    {
524                        let crate::expressions::Literal::String(s) = lit.as_ref() else {
525                            unreachable!()
526                        };
527                        // Strip leading '$.' if present
528                        let stripped = if s.starts_with("$.") {
529                            &s[2..]
530                        } else if s.starts_with("$") {
531                            &s[1..]
532                        } else {
533                            s.as_str()
534                        };
535                        Expression::Literal(Box::new(crate::expressions::Literal::String(
536                            stripped.to_string(),
537                        )))
538                    }
539                    _ => path_arg,
540                };
541
542                Ok(Expression::JSONExtract(Box::new(JSONExtract {
543                    this: Box::new(col),
544                    expression: Box::new(path_expr),
545                    only_json_types: None,
546                    expressions: Vec::new(),
547                    variant_extract: Some(Box::new(Expression::true_())),
548                    json_query: None,
549                    option: None,
550                    quote: None,
551                    on_condition: None,
552                    requires_json: None,
553                })))
554            }
555
556            // FROM_JSON is native in Databricks
557            "FROM_JSON" => Ok(Expression::Function(Box::new(f))),
558
559            // PARSE_JSON is native in Databricks
560            "PARSE_JSON" => Ok(Expression::Function(Box::new(f))),
561
562            // COLLECT_LIST is native in Databricks
563            "COLLECT_LIST" => Ok(Expression::Function(Box::new(f))),
564
565            // COLLECT_SET is native in Databricks
566            "COLLECT_SET" => Ok(Expression::Function(Box::new(f))),
567
568            // RLIKE is native in Databricks
569            "RLIKE" => Ok(Expression::Function(Box::new(f))),
570
571            // REGEXP -> RLIKE in Databricks
572            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
573                "RLIKE".to_string(),
574                f.args,
575            )))),
576
577            // REGEXP_LIKE is native in Databricks
578            "REGEXP_LIKE" => Ok(Expression::Function(Box::new(f))),
579
580            // LEVENSHTEIN is native in Databricks
581            "LEVENSHTEIN" => Ok(Expression::Function(Box::new(f))),
582
583            // SEQUENCE is native (for GENERATE_SERIES)
584            "GENERATE_SERIES" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
585                Function::new("SEQUENCE".to_string(), f.args),
586            ))),
587
588            // SEQUENCE is native
589            "SEQUENCE" => Ok(Expression::Function(Box::new(f))),
590
591            // FLATTEN is native in Databricks
592            "FLATTEN" => Ok(Expression::Function(Box::new(f))),
593
594            // ARRAY_SORT is native
595            "ARRAY_SORT" => Ok(Expression::Function(Box::new(f))),
596
597            // ARRAY_DISTINCT is native
598            "ARRAY_DISTINCT" => Ok(Expression::Function(Box::new(f))),
599
600            // TRANSFORM is native (for array transformation)
601            "TRANSFORM" => Ok(Expression::Function(Box::new(f))),
602
603            // FILTER is native (for array filtering)
604            "FILTER" => Ok(Expression::Function(Box::new(f))),
605
606            // FROM_UTC_TIMESTAMP - wrap first argument in CAST(... AS TIMESTAMP) if not already
607            "FROM_UTC_TIMESTAMP" if f.args.len() >= 2 => {
608                let mut args = f.args;
609                let first_arg = args.remove(0);
610
611                // Check if first arg is already a Cast to TIMESTAMP
612                let wrapped_arg = if self.is_cast_to_timestamp(&first_arg) {
613                    first_arg
614                } else {
615                    // Wrap in CAST(... AS TIMESTAMP)
616                    Expression::Cast(Box::new(Cast {
617                        this: first_arg,
618                        to: DataType::Timestamp {
619                            precision: None,
620                            timezone: false,
621                        },
622                        trailing_comments: Vec::new(),
623                        double_colon_syntax: false,
624                        format: None,
625                        default: None,
626                        inferred_type: None,
627                    }))
628                };
629
630                let mut new_args = vec![wrapped_arg];
631                new_args.extend(args);
632
633                Ok(Expression::Function(Box::new(Function::new(
634                    "FROM_UTC_TIMESTAMP".to_string(),
635                    new_args,
636                ))))
637            }
638
639            // UNIFORM(low, high, RANDOM(seed)) -> UNIFORM(low, high, seed) or UNIFORM(low, high)
640            "UNIFORM" if f.args.len() == 3 => {
641                let mut args = f.args;
642                let low = args.remove(0);
643                let high = args.remove(0);
644                let gen = args.remove(0);
645                match gen {
646                    Expression::Function(func) if func.name.to_uppercase() == "RANDOM" => {
647                        if func.args.len() == 1 {
648                            // RANDOM(seed) -> extract seed
649                            let seed = func.args.into_iter().next().unwrap();
650                            Ok(Expression::Function(Box::new(Function::new(
651                                "UNIFORM".to_string(),
652                                vec![low, high, seed],
653                            ))))
654                        } else {
655                            // RANDOM() -> drop gen arg
656                            Ok(Expression::Function(Box::new(Function::new(
657                                "UNIFORM".to_string(),
658                                vec![low, high],
659                            ))))
660                        }
661                    }
662                    Expression::Rand(r) => {
663                        if let Some(seed) = r.seed {
664                            Ok(Expression::Function(Box::new(Function::new(
665                                "UNIFORM".to_string(),
666                                vec![low, high, *seed],
667                            ))))
668                        } else {
669                            Ok(Expression::Function(Box::new(Function::new(
670                                "UNIFORM".to_string(),
671                                vec![low, high],
672                            ))))
673                        }
674                    }
675                    _ => Ok(Expression::Function(Box::new(Function::new(
676                        "UNIFORM".to_string(),
677                        vec![low, high, gen],
678                    )))),
679                }
680            }
681
682            // REGEXP_SUBSTR(subject, pattern, ...) -> REGEXP_EXTRACT(subject, pattern)
683            "REGEXP_SUBSTR" if f.args.len() >= 2 => {
684                let subject = f.args[0].clone();
685                let pattern = f.args[1].clone();
686                Ok(Expression::Function(Box::new(Function::new(
687                    "REGEXP_EXTRACT".to_string(),
688                    vec![subject, pattern],
689                ))))
690            }
691
692            // BIT_GET -> GETBIT
693            "BIT_GET" => Ok(Expression::Function(Box::new(Function::new(
694                "GETBIT".to_string(),
695                f.args,
696            )))),
697
698            // Pass through everything else
699            _ => Ok(Expression::Function(Box::new(f))),
700        }
701    }
702
703    fn transform_aggregate_function(
704        &self,
705        f: Box<crate::expressions::AggregateFunction>,
706    ) -> Result<Expression> {
707        let name_upper = f.name.to_uppercase();
708        match name_upper.as_str() {
709            // COUNT_IF is native in Databricks (Spark 3+)
710            "COUNT_IF" => Ok(Expression::AggregateFunction(f)),
711
712            // ANY_VALUE is native in Databricks (Spark 3+)
713            "ANY_VALUE" => Ok(Expression::AggregateFunction(f)),
714
715            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
716            "GROUP_CONCAT" if !f.args.is_empty() => {
717                let mut args = f.args;
718                let first = args.remove(0);
719                let separator = args.pop();
720                let collect_list = Expression::Function(Box::new(Function::new(
721                    "COLLECT_LIST".to_string(),
722                    vec![first],
723                )));
724                if let Some(sep) = separator {
725                    Ok(Expression::Function(Box::new(Function::new(
726                        "ARRAY_JOIN".to_string(),
727                        vec![collect_list, sep],
728                    ))))
729                } else {
730                    Ok(Expression::Function(Box::new(Function::new(
731                        "ARRAY_JOIN".to_string(),
732                        vec![collect_list],
733                    ))))
734                }
735            }
736
737            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN
738            "STRING_AGG" if !f.args.is_empty() => {
739                let mut args = f.args;
740                let first = args.remove(0);
741                let separator = args.pop();
742                let collect_list = Expression::Function(Box::new(Function::new(
743                    "COLLECT_LIST".to_string(),
744                    vec![first],
745                )));
746                if let Some(sep) = separator {
747                    Ok(Expression::Function(Box::new(Function::new(
748                        "ARRAY_JOIN".to_string(),
749                        vec![collect_list, sep],
750                    ))))
751                } else {
752                    Ok(Expression::Function(Box::new(Function::new(
753                        "ARRAY_JOIN".to_string(),
754                        vec![collect_list],
755                    ))))
756                }
757            }
758
759            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
760            "LISTAGG" if !f.args.is_empty() => {
761                let mut args = f.args;
762                let first = args.remove(0);
763                let separator = args.pop();
764                let collect_list = Expression::Function(Box::new(Function::new(
765                    "COLLECT_LIST".to_string(),
766                    vec![first],
767                )));
768                if let Some(sep) = separator {
769                    Ok(Expression::Function(Box::new(Function::new(
770                        "ARRAY_JOIN".to_string(),
771                        vec![collect_list, sep],
772                    ))))
773                } else {
774                    Ok(Expression::Function(Box::new(Function::new(
775                        "ARRAY_JOIN".to_string(),
776                        vec![collect_list],
777                    ))))
778                }
779            }
780
781            // ARRAY_AGG -> COLLECT_LIST
782            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
783                "COLLECT_LIST".to_string(),
784                f.args,
785            )))),
786
787            // STDDEV is native in Databricks
788            "STDDEV" => Ok(Expression::AggregateFunction(f)),
789
790            // VARIANCE is native in Databricks
791            "VARIANCE" => Ok(Expression::AggregateFunction(f)),
792
793            // APPROX_COUNT_DISTINCT is native in Databricks
794            "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
795
796            // APPROX_DISTINCT -> APPROX_COUNT_DISTINCT
797            "APPROX_DISTINCT" if !f.args.is_empty() => {
798                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
799                    name: "APPROX_COUNT_DISTINCT".to_string(),
800                    args: f.args,
801                    distinct: f.distinct,
802                    filter: f.filter,
803                    order_by: Vec::new(),
804                    limit: None,
805                    ignore_nulls: None,
806                    inferred_type: None,
807                })))
808            }
809
810            // Pass through everything else
811            _ => Ok(Expression::AggregateFunction(f)),
812        }
813    }
814
815    /// Transform Cast expressions - handles typed literals being cast
816    ///
817    /// When we have a typed literal (TIMESTAMP 'x', DATE 'x', TIME 'x') being cast to another type,
818    /// Databricks/Spark transforms it as follows:
819    ///
820    /// `TIMESTAMP 'x'::TYPE` -> `CAST(CAST('x' AS TYPE) AS TIMESTAMP)`
821    ///
822    /// This reverses the types - the inner cast is to the target type,
823    /// the outer cast is to the original literal type.
824    fn transform_cast(&self, c: Cast) -> Result<Expression> {
825        // Check if the inner expression is a typed literal
826        match &c.this {
827            // TIMESTAMP 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIMESTAMP)
828            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Timestamp(_)) => {
829                let Literal::Timestamp(value) = lit.as_ref() else {
830                    unreachable!()
831                };
832                // Create inner cast: CAST('value' AS target_type)
833                let inner_cast = Expression::Cast(Box::new(Cast {
834                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
835                    to: c.to,
836                    trailing_comments: Vec::new(),
837                    double_colon_syntax: false,
838                    format: None,
839                    default: None,
840                    inferred_type: None,
841                }));
842                // Create outer cast: CAST(inner_cast AS TIMESTAMP)
843                Ok(Expression::Cast(Box::new(Cast {
844                    this: inner_cast,
845                    to: DataType::Timestamp {
846                        precision: None,
847                        timezone: false,
848                    },
849                    trailing_comments: c.trailing_comments,
850                    double_colon_syntax: false,
851                    format: None,
852                    default: None,
853                    inferred_type: None,
854                })))
855            }
856            // DATE 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS DATE)
857            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Date(_)) => {
858                let Literal::Date(value) = lit.as_ref() else {
859                    unreachable!()
860                };
861                let inner_cast = Expression::Cast(Box::new(Cast {
862                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
863                    to: c.to,
864                    trailing_comments: Vec::new(),
865                    double_colon_syntax: false,
866                    format: None,
867                    default: None,
868                    inferred_type: None,
869                }));
870                Ok(Expression::Cast(Box::new(Cast {
871                    this: inner_cast,
872                    to: DataType::Date,
873                    trailing_comments: c.trailing_comments,
874                    double_colon_syntax: false,
875                    format: None,
876                    default: None,
877                    inferred_type: None,
878                })))
879            }
880            // TIME 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIME)
881            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Time(_)) => {
882                let Literal::Time(value) = lit.as_ref() else {
883                    unreachable!()
884                };
885                let inner_cast = Expression::Cast(Box::new(Cast {
886                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
887                    to: c.to,
888                    trailing_comments: Vec::new(),
889                    double_colon_syntax: false,
890                    format: None,
891                    default: None,
892                    inferred_type: None,
893                }));
894                Ok(Expression::Cast(Box::new(Cast {
895                    this: inner_cast,
896                    to: DataType::Time {
897                        precision: None,
898                        timezone: false,
899                    },
900                    trailing_comments: c.trailing_comments,
901                    double_colon_syntax: false,
902                    format: None,
903                    default: None,
904                    inferred_type: None,
905                })))
906            }
907            // For all other cases, pass through the Cast unchanged
908            _ => Ok(Expression::Cast(Box::new(c))),
909        }
910    }
911
912    /// Check if an expression is a CAST to TIMESTAMP
913    fn is_cast_to_timestamp(&self, expr: &Expression) -> bool {
914        if let Expression::Cast(cast) = expr {
915            matches!(cast.to, DataType::Timestamp { .. })
916        } else {
917            false
918        }
919    }
920
921    /// Helper to uppercase the first argument if it's an identifier or column (for DATEDIFF, DATEADD units)
922    fn uppercase_first_arg_if_identifier(&self, mut args: Vec<Expression>) -> Vec<Expression> {
923        use crate::expressions::Identifier;
924        if !args.is_empty() {
925            match &args[0] {
926                Expression::Identifier(id) => {
927                    args[0] = Expression::Identifier(Identifier {
928                        name: id.name.to_uppercase(),
929                        quoted: id.quoted,
930                        trailing_comments: id.trailing_comments.clone(),
931                        span: None,
932                    });
933                }
934                Expression::Var(v) => {
935                    args[0] = Expression::Identifier(Identifier {
936                        name: v.this.to_uppercase(),
937                        quoted: false,
938                        trailing_comments: Vec::new(),
939                        span: None,
940                    });
941                }
942                Expression::Column(col) if col.table.is_none() => {
943                    // Unqualified column name like "day" should be treated as a unit
944                    args[0] = Expression::Identifier(Identifier {
945                        name: col.name.name.to_uppercase(),
946                        quoted: col.name.quoted,
947                        trailing_comments: col.name.trailing_comments.clone(),
948                        span: None,
949                    });
950                }
951                _ => {}
952            }
953        }
954        args
955    }
956}
957
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dialect;

    /// Parses `sql` with the Databricks dialect, transforms the first statement
    /// and regenerates SQL from it. Panics with a stage-specific message if any
    /// step fails.
    fn transpile(sql: &str) -> String {
        let dialect = Dialect::get(DialectType::Databricks);
        let statements = dialect.parse(sql).expect("Parse failed");
        let rewritten = dialect
            .transform(statements[0].clone())
            .expect("Transform failed");
        dialect.generate(&rewritten).expect("Generate failed")
    }

    #[test]
    fn test_timestamp_literal_cast() {
        // TIMESTAMP 'value'::DATE -> CAST(CAST('value' AS DATE) AS TIMESTAMP)
        // This is test [47] in the Databricks dialect identity fixtures
        assert_eq!(
            transpile("SELECT TIMESTAMP '2025-04-29 18.47.18'::DATE"),
            "SELECT CAST(CAST('2025-04-29 18.47.18' AS DATE) AS TIMESTAMP)",
            "Timestamp literal cast transformation failed"
        );
    }

    #[test]
    fn test_from_utc_timestamp_wraps_column() {
        // Test [48]: FROM_UTC_TIMESTAMP(foo, 'timezone') -> FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'timezone')
        assert_eq!(
            transpile(
                "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(foo, 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t"
            ),
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "FROM_UTC_TIMESTAMP transformation failed"
        );
    }

    #[test]
    fn test_from_utc_timestamp_keeps_existing_cast() {
        // Test [50]: FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz) -> FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)
        // When already cast to TIMESTAMP, keep it but convert :: syntax to CAST()
        assert_eq!(
            transpile("FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz)"),
            "FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)",
            "FROM_UTC_TIMESTAMP with existing CAST failed"
        );
    }
}