Skip to main content

polyglot_sql/dialects/
databricks.rs

1//! Databricks Dialect
2//!
3//! Databricks-specific transformations based on sqlglot patterns.
4//! Databricks extends Spark SQL with additional features:
5//! - Colon operator for JSON extraction (col:path)
6//! - DATEADD/DATEDIFF with specific syntax
7//! - NULL type mapped to VOID
8//! - Native REGEXP_LIKE and TRY_CAST support
9
10use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{
13    AggregateFunction, Cast, DataType, Expression, Function, JSONExtract, Literal, UnaryFunc,
14    VarArgFunc,
15};
16use crate::generator::GeneratorConfig;
17use crate::tokens::TokenizerConfig;
18
/// Databricks dialect.
///
/// Unit struct with no fields; all behavior lives in its [`DialectImpl`]
/// implementation (tokenizer settings, generator settings, and expression
/// rewrites) and in its inherent helper methods.
pub struct DatabricksDialect;
21
22impl DialectImpl for DatabricksDialect {
23    fn dialect_type(&self) -> DialectType {
24        DialectType::Databricks
25    }
26
27    fn tokenizer_config(&self) -> TokenizerConfig {
28        let mut config = TokenizerConfig::default();
29        // Databricks uses backticks for identifiers (NOT double quotes)
30        config.identifiers.clear();
31        config.identifiers.insert('`', '`');
32        // Databricks (like Hive/Spark) uses double quotes as string delimiters
33        config.quotes.insert("\"".to_string(), "\"".to_string());
34        // Databricks uses backslash escapes in strings (inherited from Hive/Spark)
35        config.string_escapes.push('\\');
36        // Databricks supports DIV keyword for integer division
37        config
38            .keywords
39            .insert("DIV".to_string(), crate::tokens::TokenType::Div);
40        // Databricks numeric literal suffixes (same as Hive/Spark)
41        config
42            .numeric_literals
43            .insert("L".to_string(), "BIGINT".to_string());
44        config
45            .numeric_literals
46            .insert("S".to_string(), "SMALLINT".to_string());
47        config
48            .numeric_literals
49            .insert("Y".to_string(), "TINYINT".to_string());
50        config
51            .numeric_literals
52            .insert("D".to_string(), "DOUBLE".to_string());
53        config
54            .numeric_literals
55            .insert("F".to_string(), "FLOAT".to_string());
56        config
57            .numeric_literals
58            .insert("BD".to_string(), "DECIMAL".to_string());
59        // Databricks allows identifiers to start with digits (like Hive/Spark)
60        config.identifiers_can_start_with_digit = true;
61        // Databricks (like Spark): STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
62        // Backslashes in raw strings are always literal (no escape processing)
63        config.string_escapes_allowed_in_raw_strings = false;
64        config
65    }
66
67    fn generator_config(&self) -> GeneratorConfig {
68        use crate::generator::IdentifierQuoteStyle;
69        GeneratorConfig {
70            identifier_quote: '`',
71            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
72            dialect: Some(DialectType::Databricks),
73            struct_field_sep: ": ",
74            create_function_return_as: false,
75            tablesample_seed_keyword: "REPEATABLE",
76            identifiers_can_start_with_digit: true,
77            // Databricks uses COMMENT 'value' without = sign
78            schema_comment_with_eq: false,
79            ..Default::default()
80        }
81    }
82
83    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
84        match expr {
85            // IFNULL -> COALESCE in Databricks
86            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
87                original_name: None,
88                expressions: vec![f.this, f.expression],
89            }))),
90
91            // NVL -> COALESCE in Databricks
92            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
93                original_name: None,
94                expressions: vec![f.this, f.expression],
95            }))),
96
97            // TryCast is native in Databricks
98            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
99
100            // SafeCast -> TRY_CAST in Databricks
101            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
102
103            // ILIKE is native in Databricks (Spark 3+)
104            Expression::ILike(op) => Ok(Expression::ILike(op)),
105
106            // UNNEST -> EXPLODE in Databricks
107            Expression::Unnest(f) => Ok(Expression::Explode(Box::new(UnaryFunc::new(f.this)))),
108
109            // EXPLODE is native to Databricks
110            Expression::Explode(f) => Ok(Expression::Explode(f)),
111
112            // ExplodeOuter is supported
113            Expression::ExplodeOuter(f) => Ok(Expression::ExplodeOuter(f)),
114
115            // RANDOM -> RAND in Databricks
116            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
117                seed: None,
118                lower: None,
119                upper: None,
120            }))),
121
122            // Rand is native
123            Expression::Rand(r) => Ok(Expression::Rand(r)),
124
125            // || (Concat) -> CONCAT in Databricks
126            Expression::Concat(op) => Ok(Expression::Function(Box::new(Function::new(
127                "CONCAT".to_string(),
128                vec![op.left, op.right],
129            )))),
130
131            // RegexpLike is native in Databricks
132            Expression::RegexpLike(op) => Ok(Expression::RegexpLike(op)),
133
134            // Cast with typed literal: TIMESTAMP 'x'::TYPE -> CAST(CAST('x' AS TYPE) AS TIMESTAMP)
135            // This is a complex sqlglot transformation where:
136            // 1. The inner typed literal (e.g., TIMESTAMP 'x') becomes CAST('x' AS <target_type>)
137            // 2. The outer result is wrapped in CAST(... AS <original_literal_type>)
138            Expression::Cast(c) => self.transform_cast(*c),
139
140            // Generic function transformations
141            Expression::Function(f) => self.transform_function(*f),
142
143            // Generic aggregate function transformations
144            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
145
146            // DateSub -> DATE_ADD(date, -val) in Databricks
147            Expression::DateSub(f) => {
148                // Convert string literals to numbers (interval values are often stored as strings)
149                let val = match f.interval {
150                    Expression::Literal(crate::expressions::Literal::String(s))
151                        if s.parse::<i64>().is_ok() =>
152                    {
153                        Expression::Literal(crate::expressions::Literal::Number(s))
154                    }
155                    other => other,
156                };
157                let neg_val = Expression::Neg(Box::new(crate::expressions::UnaryOp { this: val }));
158                Ok(Expression::Function(Box::new(Function::new(
159                    "DATE_ADD".to_string(),
160                    vec![f.this, neg_val],
161                ))))
162            }
163
164            // Pass through everything else
165            _ => Ok(expr),
166        }
167    }
168}
169
170impl DatabricksDialect {
171    fn transform_function(&self, f: Function) -> Result<Expression> {
172        let name_upper = f.name.to_uppercase();
173        match name_upper.as_str() {
174            // IFNULL -> COALESCE
175            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
176                original_name: None,
177                expressions: f.args,
178            }))),
179
180            // NVL -> COALESCE
181            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
182                original_name: None,
183                expressions: f.args,
184            }))),
185
186            // ISNULL -> COALESCE
187            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
188                original_name: None,
189                expressions: f.args,
190            }))),
191
192            // ROW -> STRUCT (no auto-naming for cross-dialect conversion)
193            "ROW" => Ok(Expression::Function(Box::new(Function::new(
194                "STRUCT".to_string(),
195                f.args,
196            )))),
197
198            // GETDATE -> CURRENT_TIMESTAMP
199            "GETDATE" => Ok(Expression::CurrentTimestamp(
200                crate::expressions::CurrentTimestamp {
201                    precision: None,
202                    sysdate: false,
203                },
204            )),
205
206            // NOW -> CURRENT_TIMESTAMP
207            "NOW" => Ok(Expression::CurrentTimestamp(
208                crate::expressions::CurrentTimestamp {
209                    precision: None,
210                    sysdate: false,
211                },
212            )),
213
214            // CURDATE -> CURRENT_DATE
215            "CURDATE" => Ok(Expression::CurrentDate(crate::expressions::CurrentDate)),
216
217            // CURRENT_DATE() with parens -> CURRENT_DATE (no parens)
218            "CURRENT_DATE" if f.args.is_empty() => {
219                Ok(Expression::CurrentDate(crate::expressions::CurrentDate))
220            }
221
222            // RANDOM -> RAND
223            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
224                seed: None,
225                lower: None,
226                upper: None,
227            }))),
228
229            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
230            "GROUP_CONCAT" if !f.args.is_empty() => {
231                let mut args = f.args;
232                let first = args.remove(0);
233                let separator = args.pop();
234                let collect_list = Expression::Function(Box::new(Function::new(
235                    "COLLECT_LIST".to_string(),
236                    vec![first],
237                )));
238                if let Some(sep) = separator {
239                    Ok(Expression::Function(Box::new(Function::new(
240                        "ARRAY_JOIN".to_string(),
241                        vec![collect_list, sep],
242                    ))))
243                } else {
244                    Ok(Expression::Function(Box::new(Function::new(
245                        "ARRAY_JOIN".to_string(),
246                        vec![collect_list],
247                    ))))
248                }
249            }
250
251            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN in Databricks
252            "STRING_AGG" if !f.args.is_empty() => {
253                let mut args = f.args;
254                let first = args.remove(0);
255                let separator = args.pop();
256                let collect_list = Expression::Function(Box::new(Function::new(
257                    "COLLECT_LIST".to_string(),
258                    vec![first],
259                )));
260                if let Some(sep) = separator {
261                    Ok(Expression::Function(Box::new(Function::new(
262                        "ARRAY_JOIN".to_string(),
263                        vec![collect_list, sep],
264                    ))))
265                } else {
266                    Ok(Expression::Function(Box::new(Function::new(
267                        "ARRAY_JOIN".to_string(),
268                        vec![collect_list],
269                    ))))
270                }
271            }
272
273            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
274            "LISTAGG" if !f.args.is_empty() => {
275                let mut args = f.args;
276                let first = args.remove(0);
277                let separator = args.pop();
278                let collect_list = Expression::Function(Box::new(Function::new(
279                    "COLLECT_LIST".to_string(),
280                    vec![first],
281                )));
282                if let Some(sep) = separator {
283                    Ok(Expression::Function(Box::new(Function::new(
284                        "ARRAY_JOIN".to_string(),
285                        vec![collect_list, sep],
286                    ))))
287                } else {
288                    Ok(Expression::Function(Box::new(Function::new(
289                        "ARRAY_JOIN".to_string(),
290                        vec![collect_list],
291                    ))))
292                }
293            }
294
295            // ARRAY_AGG -> COLLECT_LIST in Databricks
296            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
297                "COLLECT_LIST".to_string(),
298                f.args,
299            )))),
300
301            // SUBSTR -> SUBSTRING
302            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
303                "SUBSTRING".to_string(),
304                f.args,
305            )))),
306
307            // LEN -> LENGTH
308            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
309                f.args.into_iter().next().unwrap(),
310            )))),
311
312            // CHARINDEX -> LOCATE (with swapped args, like Spark)
313            "CHARINDEX" if f.args.len() >= 2 => {
314                let mut args = f.args;
315                let substring = args.remove(0);
316                let string = args.remove(0);
317                // LOCATE(substring, string)
318                Ok(Expression::Function(Box::new(Function::new(
319                    "LOCATE".to_string(),
320                    vec![substring, string],
321                ))))
322            }
323
324            // POSITION -> LOCATE
325            "POSITION" if f.args.len() == 2 => {
326                let args = f.args;
327                Ok(Expression::Function(Box::new(Function::new(
328                    "LOCATE".to_string(),
329                    args,
330                ))))
331            }
332
333            // STRPOS -> LOCATE (with same arg order)
334            "STRPOS" if f.args.len() == 2 => {
335                let args = f.args;
336                let string = args[0].clone();
337                let substring = args[1].clone();
338                // STRPOS(string, substring) -> LOCATE(substring, string)
339                Ok(Expression::Function(Box::new(Function::new(
340                    "LOCATE".to_string(),
341                    vec![substring, string],
342                ))))
343            }
344
345            // INSTR is native in Databricks
346            "INSTR" => Ok(Expression::Function(Box::new(f))),
347
348            // LOCATE is native in Databricks
349            "LOCATE" => Ok(Expression::Function(Box::new(f))),
350
351            // ARRAY_LENGTH -> SIZE
352            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
353                Function::new("SIZE".to_string(), f.args),
354            ))),
355
356            // CARDINALITY -> SIZE
357            "CARDINALITY" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
358                Function::new("SIZE".to_string(), f.args),
359            ))),
360
361            // SIZE is native
362            "SIZE" => Ok(Expression::Function(Box::new(f))),
363
364            // ARRAY_CONTAINS is native in Databricks
365            "ARRAY_CONTAINS" => Ok(Expression::Function(Box::new(f))),
366
367            // CONTAINS -> ARRAY_CONTAINS in Databricks (for array operations)
368            // But keep CONTAINS for string contains (from CONTAINS_SUBSTR transpilation)
369            "CONTAINS" if f.args.len() == 2 => {
370                // Check if this is a string CONTAINS (LOWER() args pattern from CONTAINS_SUBSTR)
371                let is_string_contains = matches!(&f.args[0], Expression::Lower(_))
372                    && matches!(&f.args[1], Expression::Lower(_));
373                if is_string_contains {
374                    Ok(Expression::Function(Box::new(f)))
375                } else {
376                    Ok(Expression::Function(Box::new(Function::new(
377                        "ARRAY_CONTAINS".to_string(),
378                        f.args,
379                    ))))
380                }
381            }
382
383            // TO_DATE is native in Databricks
384            "TO_DATE" => Ok(Expression::Function(Box::new(f))),
385
386            // TO_TIMESTAMP is native in Databricks
387            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
388
389            // DATE_FORMAT is native in Databricks
390            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
391
392            // strftime -> DATE_FORMAT in Databricks
393            "STRFTIME" if f.args.len() >= 2 => {
394                let mut args = f.args;
395                let format = args.remove(0);
396                let date = args.remove(0);
397                Ok(Expression::Function(Box::new(Function::new(
398                    "DATE_FORMAT".to_string(),
399                    vec![date, format],
400                ))))
401            }
402
403            // TO_CHAR is supported natively in Databricks (unlike Spark)
404            "TO_CHAR" => Ok(Expression::Function(Box::new(f))),
405
406            // DATE_TRUNC is native in Databricks
407            "DATE_TRUNC" => Ok(Expression::Function(Box::new(f))),
408
409            // DATEADD is native in Databricks - uppercase the unit if present
410            "DATEADD" => {
411                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
412                Ok(Expression::Function(Box::new(Function::new(
413                    "DATEADD".to_string(),
414                    transformed_args,
415                ))))
416            }
417
418            // DATE_ADD -> DATEADD in Databricks (2-arg form only)
419            // 2-arg with interval: DATE_ADD(date, interval) -> DATEADD(DAY, interval, date)
420            // 2-arg with number: DATE_ADD(date, -2) -> keep as DATE_ADD(date, -2)
421            // 3-arg: DATE_ADD(unit, amount, date) -> keep as DATE_ADD(UNIT, amount, date)
422            "DATE_ADD" => {
423                if f.args.len() == 2 {
424                    let is_simple_number = matches!(
425                        &f.args[1],
426                        Expression::Literal(crate::expressions::Literal::Number(_))
427                            | Expression::Neg(_)
428                    );
429                    if is_simple_number {
430                        // Keep as DATE_ADD(date, num_days)
431                        Ok(Expression::Function(Box::new(Function::new(
432                            "DATE_ADD".to_string(),
433                            f.args,
434                        ))))
435                    } else {
436                        let mut args = f.args;
437                        let date = args.remove(0);
438                        let interval = args.remove(0);
439                        let unit = Expression::Identifier(crate::expressions::Identifier {
440                            name: "DAY".to_string(),
441                            quoted: false,
442                            trailing_comments: Vec::new(),
443                        });
444                        Ok(Expression::Function(Box::new(Function::new(
445                            "DATEADD".to_string(),
446                            vec![unit, interval, date],
447                        ))))
448                    }
449                } else {
450                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
451                    Ok(Expression::Function(Box::new(Function::new(
452                        "DATE_ADD".to_string(),
453                        transformed_args,
454                    ))))
455                }
456            }
457
458            // DATEDIFF is native in Databricks - uppercase the unit if present
459            // 2-arg: DATEDIFF(end, start) -> DATEDIFF(DAY, start, end)
460            // 3-arg: DATEDIFF(unit, start, end) -> DATEDIFF(UNIT, start, end)
461            "DATEDIFF" => {
462                if f.args.len() == 2 {
463                    let mut args = f.args;
464                    let end_date = args.remove(0);
465                    let start_date = args.remove(0);
466                    let unit = Expression::Identifier(crate::expressions::Identifier {
467                        name: "DAY".to_string(),
468                        quoted: false,
469                        trailing_comments: Vec::new(),
470                    });
471                    Ok(Expression::Function(Box::new(Function::new(
472                        "DATEDIFF".to_string(),
473                        vec![unit, start_date, end_date],
474                    ))))
475                } else {
476                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
477                    Ok(Expression::Function(Box::new(Function::new(
478                        "DATEDIFF".to_string(),
479                        transformed_args,
480                    ))))
481                }
482            }
483
484            // DATE_DIFF -> DATEDIFF with uppercased unit
485            "DATE_DIFF" => {
486                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
487                Ok(Expression::Function(Box::new(Function::new(
488                    "DATEDIFF".to_string(),
489                    transformed_args,
490                ))))
491            }
492
493            // JSON_EXTRACT -> Use colon operator in generation, but keep as function for now
494            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
495
496            // JSON_EXTRACT_SCALAR -> same handling
497            "JSON_EXTRACT_SCALAR" => Ok(Expression::Function(Box::new(f))),
498
499            // GET_JSON_OBJECT -> colon syntax in Databricks
500            // GET_JSON_OBJECT(col, '$.path') becomes col:path
501            "GET_JSON_OBJECT" if f.args.len() == 2 => {
502                let mut args = f.args;
503                let col = args.remove(0);
504                let path_arg = args.remove(0);
505
506                // Extract and strip the $. prefix from the path
507                let path_expr = match &path_arg {
508                    Expression::Literal(crate::expressions::Literal::String(s)) => {
509                        // Strip leading '$.' if present
510                        let stripped = if s.starts_with("$.") {
511                            &s[2..]
512                        } else if s.starts_with("$") {
513                            &s[1..]
514                        } else {
515                            s.as_str()
516                        };
517                        Expression::Literal(crate::expressions::Literal::String(
518                            stripped.to_string(),
519                        ))
520                    }
521                    _ => path_arg,
522                };
523
524                Ok(Expression::JSONExtract(Box::new(JSONExtract {
525                    this: Box::new(col),
526                    expression: Box::new(path_expr),
527                    only_json_types: None,
528                    expressions: Vec::new(),
529                    variant_extract: Some(Box::new(Expression::true_())),
530                    json_query: None,
531                    option: None,
532                    quote: None,
533                    on_condition: None,
534                    requires_json: None,
535                })))
536            }
537
538            // FROM_JSON is native in Databricks
539            "FROM_JSON" => Ok(Expression::Function(Box::new(f))),
540
541            // PARSE_JSON is native in Databricks
542            "PARSE_JSON" => Ok(Expression::Function(Box::new(f))),
543
544            // COLLECT_LIST is native in Databricks
545            "COLLECT_LIST" => Ok(Expression::Function(Box::new(f))),
546
547            // COLLECT_SET is native in Databricks
548            "COLLECT_SET" => Ok(Expression::Function(Box::new(f))),
549
550            // RLIKE is native in Databricks
551            "RLIKE" => Ok(Expression::Function(Box::new(f))),
552
553            // REGEXP -> RLIKE in Databricks
554            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
555                "RLIKE".to_string(),
556                f.args,
557            )))),
558
559            // REGEXP_LIKE is native in Databricks
560            "REGEXP_LIKE" => Ok(Expression::Function(Box::new(f))),
561
562            // LEVENSHTEIN is native in Databricks
563            "LEVENSHTEIN" => Ok(Expression::Function(Box::new(f))),
564
565            // SEQUENCE is native (for GENERATE_SERIES)
566            "GENERATE_SERIES" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
567                Function::new("SEQUENCE".to_string(), f.args),
568            ))),
569
570            // SEQUENCE is native
571            "SEQUENCE" => Ok(Expression::Function(Box::new(f))),
572
573            // FLATTEN is native in Databricks
574            "FLATTEN" => Ok(Expression::Function(Box::new(f))),
575
576            // ARRAY_SORT is native
577            "ARRAY_SORT" => Ok(Expression::Function(Box::new(f))),
578
579            // ARRAY_DISTINCT is native
580            "ARRAY_DISTINCT" => Ok(Expression::Function(Box::new(f))),
581
582            // TRANSFORM is native (for array transformation)
583            "TRANSFORM" => Ok(Expression::Function(Box::new(f))),
584
585            // FILTER is native (for array filtering)
586            "FILTER" => Ok(Expression::Function(Box::new(f))),
587
588            // FROM_UTC_TIMESTAMP - wrap first argument in CAST(... AS TIMESTAMP) if not already
589            "FROM_UTC_TIMESTAMP" if f.args.len() >= 2 => {
590                let mut args = f.args;
591                let first_arg = args.remove(0);
592
593                // Check if first arg is already a Cast to TIMESTAMP
594                let wrapped_arg = if self.is_cast_to_timestamp(&first_arg) {
595                    first_arg
596                } else {
597                    // Wrap in CAST(... AS TIMESTAMP)
598                    Expression::Cast(Box::new(Cast {
599                        this: first_arg,
600                        to: DataType::Timestamp {
601                            precision: None,
602                            timezone: false,
603                        },
604                        trailing_comments: Vec::new(),
605                        double_colon_syntax: false,
606                        format: None,
607                        default: None,
608                    }))
609                };
610
611                let mut new_args = vec![wrapped_arg];
612                new_args.extend(args);
613
614                Ok(Expression::Function(Box::new(Function::new(
615                    "FROM_UTC_TIMESTAMP".to_string(),
616                    new_args,
617                ))))
618            }
619
620            // UNIFORM(low, high, RANDOM(seed)) -> UNIFORM(low, high, seed) or UNIFORM(low, high)
621            "UNIFORM" if f.args.len() == 3 => {
622                let mut args = f.args;
623                let low = args.remove(0);
624                let high = args.remove(0);
625                let gen = args.remove(0);
626                match gen {
627                    Expression::Function(func) if func.name.to_uppercase() == "RANDOM" => {
628                        if func.args.len() == 1 {
629                            // RANDOM(seed) -> extract seed
630                            let seed = func.args.into_iter().next().unwrap();
631                            Ok(Expression::Function(Box::new(Function::new(
632                                "UNIFORM".to_string(),
633                                vec![low, high, seed],
634                            ))))
635                        } else {
636                            // RANDOM() -> drop gen arg
637                            Ok(Expression::Function(Box::new(Function::new(
638                                "UNIFORM".to_string(),
639                                vec![low, high],
640                            ))))
641                        }
642                    }
643                    Expression::Rand(r) => {
644                        if let Some(seed) = r.seed {
645                            Ok(Expression::Function(Box::new(Function::new(
646                                "UNIFORM".to_string(),
647                                vec![low, high, *seed],
648                            ))))
649                        } else {
650                            Ok(Expression::Function(Box::new(Function::new(
651                                "UNIFORM".to_string(),
652                                vec![low, high],
653                            ))))
654                        }
655                    }
656                    _ => Ok(Expression::Function(Box::new(Function::new(
657                        "UNIFORM".to_string(),
658                        vec![low, high, gen],
659                    )))),
660                }
661            }
662
663            // REGEXP_SUBSTR(subject, pattern, ...) -> REGEXP_EXTRACT(subject, pattern)
664            "REGEXP_SUBSTR" if f.args.len() >= 2 => {
665                let subject = f.args[0].clone();
666                let pattern = f.args[1].clone();
667                Ok(Expression::Function(Box::new(Function::new(
668                    "REGEXP_EXTRACT".to_string(),
669                    vec![subject, pattern],
670                ))))
671            }
672
673            // Pass through everything else
674            _ => Ok(Expression::Function(Box::new(f))),
675        }
676    }
677
678    fn transform_aggregate_function(
679        &self,
680        f: Box<crate::expressions::AggregateFunction>,
681    ) -> Result<Expression> {
682        let name_upper = f.name.to_uppercase();
683        match name_upper.as_str() {
684            // COUNT_IF is native in Databricks (Spark 3+)
685            "COUNT_IF" => Ok(Expression::AggregateFunction(f)),
686
687            // ANY_VALUE is native in Databricks (Spark 3+)
688            "ANY_VALUE" => Ok(Expression::AggregateFunction(f)),
689
690            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
691            "GROUP_CONCAT" if !f.args.is_empty() => {
692                let mut args = f.args;
693                let first = args.remove(0);
694                let separator = args.pop();
695                let collect_list = Expression::Function(Box::new(Function::new(
696                    "COLLECT_LIST".to_string(),
697                    vec![first],
698                )));
699                if let Some(sep) = separator {
700                    Ok(Expression::Function(Box::new(Function::new(
701                        "ARRAY_JOIN".to_string(),
702                        vec![collect_list, sep],
703                    ))))
704                } else {
705                    Ok(Expression::Function(Box::new(Function::new(
706                        "ARRAY_JOIN".to_string(),
707                        vec![collect_list],
708                    ))))
709                }
710            }
711
712            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN
713            "STRING_AGG" if !f.args.is_empty() => {
714                let mut args = f.args;
715                let first = args.remove(0);
716                let separator = args.pop();
717                let collect_list = Expression::Function(Box::new(Function::new(
718                    "COLLECT_LIST".to_string(),
719                    vec![first],
720                )));
721                if let Some(sep) = separator {
722                    Ok(Expression::Function(Box::new(Function::new(
723                        "ARRAY_JOIN".to_string(),
724                        vec![collect_list, sep],
725                    ))))
726                } else {
727                    Ok(Expression::Function(Box::new(Function::new(
728                        "ARRAY_JOIN".to_string(),
729                        vec![collect_list],
730                    ))))
731                }
732            }
733
734            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
735            "LISTAGG" if !f.args.is_empty() => {
736                let mut args = f.args;
737                let first = args.remove(0);
738                let separator = args.pop();
739                let collect_list = Expression::Function(Box::new(Function::new(
740                    "COLLECT_LIST".to_string(),
741                    vec![first],
742                )));
743                if let Some(sep) = separator {
744                    Ok(Expression::Function(Box::new(Function::new(
745                        "ARRAY_JOIN".to_string(),
746                        vec![collect_list, sep],
747                    ))))
748                } else {
749                    Ok(Expression::Function(Box::new(Function::new(
750                        "ARRAY_JOIN".to_string(),
751                        vec![collect_list],
752                    ))))
753                }
754            }
755
756            // ARRAY_AGG -> COLLECT_LIST
757            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
758                "COLLECT_LIST".to_string(),
759                f.args,
760            )))),
761
762            // STDDEV is native in Databricks
763            "STDDEV" => Ok(Expression::AggregateFunction(f)),
764
765            // VARIANCE is native in Databricks
766            "VARIANCE" => Ok(Expression::AggregateFunction(f)),
767
768            // APPROX_COUNT_DISTINCT is native in Databricks
769            "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
770
771            // APPROX_DISTINCT -> APPROX_COUNT_DISTINCT
772            "APPROX_DISTINCT" if !f.args.is_empty() => {
773                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
774                    name: "APPROX_COUNT_DISTINCT".to_string(),
775                    args: f.args,
776                    distinct: f.distinct,
777                    filter: f.filter,
778                    order_by: Vec::new(),
779                    limit: None,
780                    ignore_nulls: None,
781                })))
782            }
783
784            // Pass through everything else
785            _ => Ok(Expression::AggregateFunction(f)),
786        }
787    }
788
789    /// Transform Cast expressions - handles typed literals being cast
790    ///
791    /// When we have a typed literal (TIMESTAMP 'x', DATE 'x', TIME 'x') being cast to another type,
792    /// Databricks/Spark transforms it as follows:
793    ///
794    /// `TIMESTAMP 'x'::TYPE` -> `CAST(CAST('x' AS TYPE) AS TIMESTAMP)`
795    ///
796    /// This reverses the types - the inner cast is to the target type,
797    /// the outer cast is to the original literal type.
798    fn transform_cast(&self, c: Cast) -> Result<Expression> {
799        // Check if the inner expression is a typed literal
800        match &c.this {
801            // TIMESTAMP 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIMESTAMP)
802            Expression::Literal(Literal::Timestamp(value)) => {
803                // Create inner cast: CAST('value' AS target_type)
804                let inner_cast = Expression::Cast(Box::new(Cast {
805                    this: Expression::Literal(Literal::String(value.clone())),
806                    to: c.to,
807                    trailing_comments: Vec::new(),
808                    double_colon_syntax: false,
809                    format: None,
810                    default: None,
811                }));
812                // Create outer cast: CAST(inner_cast AS TIMESTAMP)
813                Ok(Expression::Cast(Box::new(Cast {
814                    this: inner_cast,
815                    to: DataType::Timestamp {
816                        precision: None,
817                        timezone: false,
818                    },
819                    trailing_comments: c.trailing_comments,
820                    double_colon_syntax: false,
821                    format: None,
822                    default: None,
823                })))
824            }
825            // DATE 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS DATE)
826            Expression::Literal(Literal::Date(value)) => {
827                let inner_cast = Expression::Cast(Box::new(Cast {
828                    this: Expression::Literal(Literal::String(value.clone())),
829                    to: c.to,
830                    trailing_comments: Vec::new(),
831                    double_colon_syntax: false,
832                    format: None,
833                    default: None,
834                }));
835                Ok(Expression::Cast(Box::new(Cast {
836                    this: inner_cast,
837                    to: DataType::Date,
838                    trailing_comments: c.trailing_comments,
839                    double_colon_syntax: false,
840                    format: None,
841                    default: None,
842                })))
843            }
844            // TIME 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIME)
845            Expression::Literal(Literal::Time(value)) => {
846                let inner_cast = Expression::Cast(Box::new(Cast {
847                    this: Expression::Literal(Literal::String(value.clone())),
848                    to: c.to,
849                    trailing_comments: Vec::new(),
850                    double_colon_syntax: false,
851                    format: None,
852                    default: None,
853                }));
854                Ok(Expression::Cast(Box::new(Cast {
855                    this: inner_cast,
856                    to: DataType::Time {
857                        precision: None,
858                        timezone: false,
859                    },
860                    trailing_comments: c.trailing_comments,
861                    double_colon_syntax: false,
862                    format: None,
863                    default: None,
864                })))
865            }
866            // For all other cases, pass through the Cast unchanged
867            _ => Ok(Expression::Cast(Box::new(c))),
868        }
869    }
870
871    /// Check if an expression is a CAST to TIMESTAMP
872    fn is_cast_to_timestamp(&self, expr: &Expression) -> bool {
873        if let Expression::Cast(cast) = expr {
874            matches!(cast.to, DataType::Timestamp { .. })
875        } else {
876            false
877        }
878    }
879
880    /// Helper to uppercase the first argument if it's an identifier or column (for DATEDIFF, DATEADD units)
881    fn uppercase_first_arg_if_identifier(&self, mut args: Vec<Expression>) -> Vec<Expression> {
882        use crate::expressions::Identifier;
883        if !args.is_empty() {
884            match &args[0] {
885                Expression::Identifier(id) => {
886                    args[0] = Expression::Identifier(Identifier {
887                        name: id.name.to_uppercase(),
888                        quoted: id.quoted,
889                        trailing_comments: id.trailing_comments.clone(),
890                    });
891                }
892                Expression::Column(col) if col.table.is_none() => {
893                    // Unqualified column name like "day" should be treated as a unit
894                    args[0] = Expression::Identifier(Identifier {
895                        name: col.name.name.to_uppercase(),
896                        quoted: col.name.quoted,
897                        trailing_comments: col.name.trailing_comments.clone(),
898                    });
899                }
900                _ => {}
901            }
902        }
903        args
904    }
905}
906
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dialect;

    /// Parses `sql` with the Databricks dialect, transforms the first parsed statement, and
    /// returns the SQL generated from the transformed tree. Panics if any stage fails.
    fn databricks_roundtrip(sql: &str) -> String {
        let dialect = Dialect::get(DialectType::Databricks);
        let statements = dialect.parse(sql).expect("Parse failed");
        let rewritten = dialect
            .transform(statements[0].clone())
            .expect("Transform failed");
        dialect.generate(&rewritten).expect("Generate failed")
    }

    /// Fixture [47]: a typed TIMESTAMP literal cast with `::` gets its two types reversed.
    #[test]
    fn test_timestamp_literal_cast() {
        let output = databricks_roundtrip("SELECT TIMESTAMP '2025-04-29 18.47.18'::DATE");

        assert_eq!(
            output, "SELECT CAST(CAST('2025-04-29 18.47.18' AS DATE) AS TIMESTAMP)",
            "Timestamp literal cast transformation failed"
        );
    }

    /// Fixture [48]: a plain column passed to FROM_UTC_TIMESTAMP is wrapped in CAST(.. AS TIMESTAMP).
    #[test]
    fn test_from_utc_timestamp_wraps_column() {
        let output = databricks_roundtrip(
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(foo, 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
        );

        assert_eq!(
            output,
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "FROM_UTC_TIMESTAMP transformation failed"
        );
    }

    /// Fixture [50]: an argument already cast to TIMESTAMP is not wrapped again; the `::`
    /// form is emitted as CAST().
    #[test]
    fn test_from_utc_timestamp_keeps_existing_cast() {
        let output = databricks_roundtrip("FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz)");

        assert_eq!(
            output, "FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)",
            "FROM_UTC_TIMESTAMP with existing CAST failed"
        );
    }
}