Skip to main content

polyglot_sql/dialects/
databricks.rs

//! Databricks Dialect
//!
//! Databricks-specific transformations based on sqlglot patterns.
//! Databricks extends Spark SQL with additional features:
//! - Colon operator for JSON extraction (`col:path`)
//! - DATEADD/DATEDIFF with an explicit unit argument (e.g. `DATEDIFF(DAY, start, end)`)
//! - NULL type mapped to VOID
//! - Native REGEXP_LIKE and TRY_CAST support
9
10use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{
13    AggregateFunction, Cast, DataType, Expression, Function, JSONExtract, Literal, UnaryFunc,
14    VarArgFunc,
15};
16use crate::generator::GeneratorConfig;
17use crate::tokens::TokenizerConfig;
18
19/// Databricks dialect
20pub struct DatabricksDialect;
21
22impl DialectImpl for DatabricksDialect {
23    fn dialect_type(&self) -> DialectType {
24        DialectType::Databricks
25    }
26
27    fn tokenizer_config(&self) -> TokenizerConfig {
28        let mut config = TokenizerConfig::default();
29        // Databricks uses backticks for identifiers (NOT double quotes)
30        config.identifiers.clear();
31        config.identifiers.insert('`', '`');
32        // Databricks (like Hive/Spark) uses double quotes as string delimiters
33        config.quotes.insert("\"".to_string(), "\"".to_string());
34        // Databricks uses backslash escapes in strings (inherited from Hive/Spark)
35        config.string_escapes.push('\\');
36        // Databricks supports DIV keyword for integer division
37        config
38            .keywords
39            .insert("DIV".to_string(), crate::tokens::TokenType::Div);
40        // Databricks numeric literal suffixes (same as Hive/Spark)
41        config
42            .numeric_literals
43            .insert("L".to_string(), "BIGINT".to_string());
44        config
45            .numeric_literals
46            .insert("S".to_string(), "SMALLINT".to_string());
47        config
48            .numeric_literals
49            .insert("Y".to_string(), "TINYINT".to_string());
50        config
51            .numeric_literals
52            .insert("D".to_string(), "DOUBLE".to_string());
53        config
54            .numeric_literals
55            .insert("F".to_string(), "FLOAT".to_string());
56        config
57            .numeric_literals
58            .insert("BD".to_string(), "DECIMAL".to_string());
59        // Databricks allows identifiers to start with digits (like Hive/Spark)
60        config.identifiers_can_start_with_digit = true;
61        // Databricks (like Spark): STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
62        // Backslashes in raw strings are always literal (no escape processing)
63        config.string_escapes_allowed_in_raw_strings = false;
64        config
65    }
66
67    fn generator_config(&self) -> GeneratorConfig {
68        use crate::generator::IdentifierQuoteStyle;
69        GeneratorConfig {
70            identifier_quote: '`',
71            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
72            dialect: Some(DialectType::Databricks),
73            struct_field_sep: ": ",
74            create_function_return_as: false,
75            tablesample_seed_keyword: "REPEATABLE",
76            identifiers_can_start_with_digit: true,
77            // Databricks uses COMMENT 'value' without = sign
78            schema_comment_with_eq: false,
79            ..Default::default()
80        }
81    }
82
83    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
84        match expr {
85            // IFNULL -> COALESCE in Databricks
86            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
87                original_name: None,
88                expressions: vec![f.this, f.expression],
89            }))),
90
91            // NVL -> COALESCE in Databricks
92            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
93                original_name: None,
94                expressions: vec![f.this, f.expression],
95            }))),
96
97            // TryCast is native in Databricks
98            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
99
100            // SafeCast -> TRY_CAST in Databricks
101            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
102
103            // ILIKE is native in Databricks (Spark 3+)
104            Expression::ILike(op) => Ok(Expression::ILike(op)),
105
106            // UNNEST -> EXPLODE in Databricks
107            Expression::Unnest(f) => Ok(Expression::Explode(Box::new(UnaryFunc::new(f.this)))),
108
109            // EXPLODE is native to Databricks
110            Expression::Explode(f) => Ok(Expression::Explode(f)),
111
112            // ExplodeOuter is supported
113            Expression::ExplodeOuter(f) => Ok(Expression::ExplodeOuter(f)),
114
115            // RANDOM -> RAND in Databricks
116            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
117                seed: None,
118                lower: None,
119                upper: None,
120            }))),
121
122            // Rand is native
123            Expression::Rand(r) => Ok(Expression::Rand(r)),
124
125            // || (Concat) -> CONCAT in Databricks
126            Expression::Concat(op) => Ok(Expression::Function(Box::new(Function::new(
127                "CONCAT".to_string(),
128                vec![op.left, op.right],
129            )))),
130
131            // RegexpLike is native in Databricks
132            Expression::RegexpLike(op) => Ok(Expression::RegexpLike(op)),
133
134            // Cast with typed literal: TIMESTAMP 'x'::TYPE -> CAST(CAST('x' AS TYPE) AS TIMESTAMP)
135            // This is a complex sqlglot transformation where:
136            // 1. The inner typed literal (e.g., TIMESTAMP 'x') becomes CAST('x' AS <target_type>)
137            // 2. The outer result is wrapped in CAST(... AS <original_literal_type>)
138            Expression::Cast(c) => self.transform_cast(*c),
139
140            // Generic function transformations
141            Expression::Function(f) => self.transform_function(*f),
142
143            // Generic aggregate function transformations
144            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
145
146            // DateSub -> DATE_ADD(date, -val) in Databricks
147            Expression::DateSub(f) => {
148                // Convert string literals to numbers (interval values are often stored as strings)
149                let val = match f.interval {
150                    Expression::Literal(crate::expressions::Literal::String(s))
151                        if s.parse::<i64>().is_ok() =>
152                    {
153                        Expression::Literal(crate::expressions::Literal::Number(s))
154                    }
155                    other => other,
156                };
157                let neg_val = Expression::Neg(Box::new(crate::expressions::UnaryOp { this: val }));
158                Ok(Expression::Function(Box::new(Function::new(
159                    "DATE_ADD".to_string(),
160                    vec![f.this, neg_val],
161                ))))
162            }
163
164            // Pass through everything else
165            _ => Ok(expr),
166        }
167    }
168}
169
170impl DatabricksDialect {
171    fn transform_function(&self, f: Function) -> Result<Expression> {
172        let name_upper = f.name.to_uppercase();
173        match name_upper.as_str() {
174            // IFNULL -> COALESCE
175            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
176                original_name: None,
177                expressions: f.args,
178            }))),
179
180            // NVL -> COALESCE
181            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
182                original_name: None,
183                expressions: f.args,
184            }))),
185
186            // ISNULL -> COALESCE
187            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
188                original_name: None,
189                expressions: f.args,
190            }))),
191
192            // ROW -> STRUCT (no auto-naming for cross-dialect conversion)
193            "ROW" => Ok(Expression::Function(Box::new(Function::new(
194                "STRUCT".to_string(),
195                f.args,
196            )))),
197
198            // GETDATE -> CURRENT_TIMESTAMP
199            "GETDATE" => Ok(Expression::CurrentTimestamp(
200                crate::expressions::CurrentTimestamp {
201                    precision: None,
202                    sysdate: false,
203                },
204            )),
205
206            // NOW -> CURRENT_TIMESTAMP
207            "NOW" => Ok(Expression::CurrentTimestamp(
208                crate::expressions::CurrentTimestamp {
209                    precision: None,
210                    sysdate: false,
211                },
212            )),
213
214            // CURDATE -> CURRENT_DATE
215            "CURDATE" => Ok(Expression::CurrentDate(crate::expressions::CurrentDate)),
216
217            // CURRENT_DATE() with parens -> CURRENT_DATE (no parens)
218            "CURRENT_DATE" if f.args.is_empty() => {
219                Ok(Expression::CurrentDate(crate::expressions::CurrentDate))
220            }
221
222            // RANDOM -> RAND
223            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
224                seed: None,
225                lower: None,
226                upper: None,
227            }))),
228
229            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
230            "GROUP_CONCAT" if !f.args.is_empty() => {
231                let mut args = f.args;
232                let first = args.remove(0);
233                let separator = args.pop();
234                let collect_list = Expression::Function(Box::new(Function::new(
235                    "COLLECT_LIST".to_string(),
236                    vec![first],
237                )));
238                if let Some(sep) = separator {
239                    Ok(Expression::Function(Box::new(Function::new(
240                        "ARRAY_JOIN".to_string(),
241                        vec![collect_list, sep],
242                    ))))
243                } else {
244                    Ok(Expression::Function(Box::new(Function::new(
245                        "ARRAY_JOIN".to_string(),
246                        vec![collect_list],
247                    ))))
248                }
249            }
250
251            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN in Databricks
252            "STRING_AGG" if !f.args.is_empty() => {
253                let mut args = f.args;
254                let first = args.remove(0);
255                let separator = args.pop();
256                let collect_list = Expression::Function(Box::new(Function::new(
257                    "COLLECT_LIST".to_string(),
258                    vec![first],
259                )));
260                if let Some(sep) = separator {
261                    Ok(Expression::Function(Box::new(Function::new(
262                        "ARRAY_JOIN".to_string(),
263                        vec![collect_list, sep],
264                    ))))
265                } else {
266                    Ok(Expression::Function(Box::new(Function::new(
267                        "ARRAY_JOIN".to_string(),
268                        vec![collect_list],
269                    ))))
270                }
271            }
272
273            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
274            "LISTAGG" if !f.args.is_empty() => {
275                let mut args = f.args;
276                let first = args.remove(0);
277                let separator = args.pop();
278                let collect_list = Expression::Function(Box::new(Function::new(
279                    "COLLECT_LIST".to_string(),
280                    vec![first],
281                )));
282                if let Some(sep) = separator {
283                    Ok(Expression::Function(Box::new(Function::new(
284                        "ARRAY_JOIN".to_string(),
285                        vec![collect_list, sep],
286                    ))))
287                } else {
288                    Ok(Expression::Function(Box::new(Function::new(
289                        "ARRAY_JOIN".to_string(),
290                        vec![collect_list],
291                    ))))
292                }
293            }
294
295            // ARRAY_AGG -> COLLECT_LIST in Databricks
296            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
297                "COLLECT_LIST".to_string(),
298                f.args,
299            )))),
300
301            // SUBSTR -> SUBSTRING
302            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
303                "SUBSTRING".to_string(),
304                f.args,
305            )))),
306
307            // LEN -> LENGTH
308            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
309                f.args.into_iter().next().unwrap(),
310            )))),
311
312            // CHARINDEX -> LOCATE (with swapped args, like Spark)
313            "CHARINDEX" if f.args.len() >= 2 => {
314                let mut args = f.args;
315                let substring = args.remove(0);
316                let string = args.remove(0);
317                // LOCATE(substring, string)
318                Ok(Expression::Function(Box::new(Function::new(
319                    "LOCATE".to_string(),
320                    vec![substring, string],
321                ))))
322            }
323
324            // POSITION -> LOCATE
325            "POSITION" if f.args.len() == 2 => {
326                let args = f.args;
327                Ok(Expression::Function(Box::new(Function::new(
328                    "LOCATE".to_string(),
329                    args,
330                ))))
331            }
332
333            // STRPOS -> LOCATE (with same arg order)
334            "STRPOS" if f.args.len() == 2 => {
335                let args = f.args;
336                let string = args[0].clone();
337                let substring = args[1].clone();
338                // STRPOS(string, substring) -> LOCATE(substring, string)
339                Ok(Expression::Function(Box::new(Function::new(
340                    "LOCATE".to_string(),
341                    vec![substring, string],
342                ))))
343            }
344
345            // INSTR is native in Databricks
346            "INSTR" => Ok(Expression::Function(Box::new(f))),
347
348            // LOCATE is native in Databricks
349            "LOCATE" => Ok(Expression::Function(Box::new(f))),
350
351            // ARRAY_LENGTH -> SIZE
352            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
353                Function::new("SIZE".to_string(), f.args),
354            ))),
355
356            // CARDINALITY -> SIZE
357            "CARDINALITY" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
358                Function::new("SIZE".to_string(), f.args),
359            ))),
360
361            // SIZE is native
362            "SIZE" => Ok(Expression::Function(Box::new(f))),
363
364            // ARRAY_CONTAINS is native in Databricks
365            "ARRAY_CONTAINS" => Ok(Expression::Function(Box::new(f))),
366
367            // CONTAINS -> ARRAY_CONTAINS in Databricks (for array operations)
368            // But keep CONTAINS for string contains (from CONTAINS_SUBSTR transpilation)
369            "CONTAINS" if f.args.len() == 2 => {
370                // Check if this is a string CONTAINS (LOWER() args pattern from CONTAINS_SUBSTR)
371                let is_string_contains = matches!(&f.args[0], Expression::Lower(_))
372                    && matches!(&f.args[1], Expression::Lower(_));
373                if is_string_contains {
374                    Ok(Expression::Function(Box::new(f)))
375                } else {
376                    Ok(Expression::Function(Box::new(Function::new(
377                        "ARRAY_CONTAINS".to_string(),
378                        f.args,
379                    ))))
380                }
381            }
382
383            // TO_DATE is native in Databricks
384            "TO_DATE" => Ok(Expression::Function(Box::new(f))),
385
386            // TO_TIMESTAMP is native in Databricks
387            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
388
389            // DATE_FORMAT is native in Databricks
390            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
391
392            // strftime -> DATE_FORMAT in Databricks
393            "STRFTIME" if f.args.len() >= 2 => {
394                let mut args = f.args;
395                let format = args.remove(0);
396                let date = args.remove(0);
397                Ok(Expression::Function(Box::new(Function::new(
398                    "DATE_FORMAT".to_string(),
399                    vec![date, format],
400                ))))
401            }
402
403            // TO_CHAR is supported natively in Databricks (unlike Spark)
404            "TO_CHAR" => Ok(Expression::Function(Box::new(f))),
405
406            // DATE_TRUNC is native in Databricks
407            "DATE_TRUNC" => Ok(Expression::Function(Box::new(f))),
408
409            // DATEADD is native in Databricks - uppercase the unit if present
410            "DATEADD" => {
411                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
412                Ok(Expression::Function(Box::new(Function::new(
413                    "DATEADD".to_string(),
414                    transformed_args,
415                ))))
416            }
417
418            // DATE_ADD -> DATEADD in Databricks (2-arg form only)
419            // 2-arg with interval: DATE_ADD(date, interval) -> DATEADD(DAY, interval, date)
420            // 2-arg with number: DATE_ADD(date, -2) -> keep as DATE_ADD(date, -2)
421            // 3-arg: DATE_ADD(unit, amount, date) -> keep as DATE_ADD(UNIT, amount, date)
422            "DATE_ADD" => {
423                if f.args.len() == 2 {
424                    let is_simple_number = matches!(
425                        &f.args[1],
426                        Expression::Literal(crate::expressions::Literal::Number(_))
427                            | Expression::Neg(_)
428                    );
429                    if is_simple_number {
430                        // Keep as DATE_ADD(date, num_days)
431                        Ok(Expression::Function(Box::new(Function::new(
432                            "DATE_ADD".to_string(),
433                            f.args,
434                        ))))
435                    } else {
436                        let mut args = f.args;
437                        let date = args.remove(0);
438                        let interval = args.remove(0);
439                        let unit = Expression::Identifier(crate::expressions::Identifier {
440                            name: "DAY".to_string(),
441                            quoted: false,
442                            trailing_comments: Vec::new(),
443                            span: None,
444                        });
445                        Ok(Expression::Function(Box::new(Function::new(
446                            "DATEADD".to_string(),
447                            vec![unit, interval, date],
448                        ))))
449                    }
450                } else {
451                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
452                    Ok(Expression::Function(Box::new(Function::new(
453                        "DATE_ADD".to_string(),
454                        transformed_args,
455                    ))))
456                }
457            }
458
459            // DATEDIFF is native in Databricks - uppercase the unit if present
460            // 2-arg: DATEDIFF(end, start) -> DATEDIFF(DAY, start, end)
461            // 3-arg: DATEDIFF(unit, start, end) -> DATEDIFF(UNIT, start, end)
462            "DATEDIFF" => {
463                if f.args.len() == 2 {
464                    let mut args = f.args;
465                    let end_date = args.remove(0);
466                    let start_date = args.remove(0);
467                    let unit = Expression::Identifier(crate::expressions::Identifier {
468                        name: "DAY".to_string(),
469                        quoted: false,
470                        trailing_comments: Vec::new(),
471                        span: None,
472                    });
473                    Ok(Expression::Function(Box::new(Function::new(
474                        "DATEDIFF".to_string(),
475                        vec![unit, start_date, end_date],
476                    ))))
477                } else {
478                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
479                    Ok(Expression::Function(Box::new(Function::new(
480                        "DATEDIFF".to_string(),
481                        transformed_args,
482                    ))))
483                }
484            }
485
486            // DATE_DIFF -> DATEDIFF with uppercased unit
487            "DATE_DIFF" => {
488                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
489                Ok(Expression::Function(Box::new(Function::new(
490                    "DATEDIFF".to_string(),
491                    transformed_args,
492                ))))
493            }
494
495            // JSON_EXTRACT -> Use colon operator in generation, but keep as function for now
496            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
497
498            // JSON_EXTRACT_SCALAR -> same handling
499            "JSON_EXTRACT_SCALAR" => Ok(Expression::Function(Box::new(f))),
500
501            // GET_JSON_OBJECT -> colon syntax in Databricks
502            // GET_JSON_OBJECT(col, '$.path') becomes col:path
503            "GET_JSON_OBJECT" if f.args.len() == 2 => {
504                let mut args = f.args;
505                let col = args.remove(0);
506                let path_arg = args.remove(0);
507
508                // Extract and strip the $. prefix from the path
509                let path_expr = match &path_arg {
510                    Expression::Literal(crate::expressions::Literal::String(s)) => {
511                        // Strip leading '$.' if present
512                        let stripped = if s.starts_with("$.") {
513                            &s[2..]
514                        } else if s.starts_with("$") {
515                            &s[1..]
516                        } else {
517                            s.as_str()
518                        };
519                        Expression::Literal(crate::expressions::Literal::String(
520                            stripped.to_string(),
521                        ))
522                    }
523                    _ => path_arg,
524                };
525
526                Ok(Expression::JSONExtract(Box::new(JSONExtract {
527                    this: Box::new(col),
528                    expression: Box::new(path_expr),
529                    only_json_types: None,
530                    expressions: Vec::new(),
531                    variant_extract: Some(Box::new(Expression::true_())),
532                    json_query: None,
533                    option: None,
534                    quote: None,
535                    on_condition: None,
536                    requires_json: None,
537                })))
538            }
539
540            // FROM_JSON is native in Databricks
541            "FROM_JSON" => Ok(Expression::Function(Box::new(f))),
542
543            // PARSE_JSON is native in Databricks
544            "PARSE_JSON" => Ok(Expression::Function(Box::new(f))),
545
546            // COLLECT_LIST is native in Databricks
547            "COLLECT_LIST" => Ok(Expression::Function(Box::new(f))),
548
549            // COLLECT_SET is native in Databricks
550            "COLLECT_SET" => Ok(Expression::Function(Box::new(f))),
551
552            // RLIKE is native in Databricks
553            "RLIKE" => Ok(Expression::Function(Box::new(f))),
554
555            // REGEXP -> RLIKE in Databricks
556            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
557                "RLIKE".to_string(),
558                f.args,
559            )))),
560
561            // REGEXP_LIKE is native in Databricks
562            "REGEXP_LIKE" => Ok(Expression::Function(Box::new(f))),
563
564            // LEVENSHTEIN is native in Databricks
565            "LEVENSHTEIN" => Ok(Expression::Function(Box::new(f))),
566
567            // SEQUENCE is native (for GENERATE_SERIES)
568            "GENERATE_SERIES" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
569                Function::new("SEQUENCE".to_string(), f.args),
570            ))),
571
572            // SEQUENCE is native
573            "SEQUENCE" => Ok(Expression::Function(Box::new(f))),
574
575            // FLATTEN is native in Databricks
576            "FLATTEN" => Ok(Expression::Function(Box::new(f))),
577
578            // ARRAY_SORT is native
579            "ARRAY_SORT" => Ok(Expression::Function(Box::new(f))),
580
581            // ARRAY_DISTINCT is native
582            "ARRAY_DISTINCT" => Ok(Expression::Function(Box::new(f))),
583
584            // TRANSFORM is native (for array transformation)
585            "TRANSFORM" => Ok(Expression::Function(Box::new(f))),
586
587            // FILTER is native (for array filtering)
588            "FILTER" => Ok(Expression::Function(Box::new(f))),
589
590            // FROM_UTC_TIMESTAMP - wrap first argument in CAST(... AS TIMESTAMP) if not already
591            "FROM_UTC_TIMESTAMP" if f.args.len() >= 2 => {
592                let mut args = f.args;
593                let first_arg = args.remove(0);
594
595                // Check if first arg is already a Cast to TIMESTAMP
596                let wrapped_arg = if self.is_cast_to_timestamp(&first_arg) {
597                    first_arg
598                } else {
599                    // Wrap in CAST(... AS TIMESTAMP)
600                    Expression::Cast(Box::new(Cast {
601                        this: first_arg,
602                        to: DataType::Timestamp {
603                            precision: None,
604                            timezone: false,
605                        },
606                        trailing_comments: Vec::new(),
607                        double_colon_syntax: false,
608                        format: None,
609                        default: None,
610                    }))
611                };
612
613                let mut new_args = vec![wrapped_arg];
614                new_args.extend(args);
615
616                Ok(Expression::Function(Box::new(Function::new(
617                    "FROM_UTC_TIMESTAMP".to_string(),
618                    new_args,
619                ))))
620            }
621
622            // UNIFORM(low, high, RANDOM(seed)) -> UNIFORM(low, high, seed) or UNIFORM(low, high)
623            "UNIFORM" if f.args.len() == 3 => {
624                let mut args = f.args;
625                let low = args.remove(0);
626                let high = args.remove(0);
627                let gen = args.remove(0);
628                match gen {
629                    Expression::Function(func) if func.name.to_uppercase() == "RANDOM" => {
630                        if func.args.len() == 1 {
631                            // RANDOM(seed) -> extract seed
632                            let seed = func.args.into_iter().next().unwrap();
633                            Ok(Expression::Function(Box::new(Function::new(
634                                "UNIFORM".to_string(),
635                                vec![low, high, seed],
636                            ))))
637                        } else {
638                            // RANDOM() -> drop gen arg
639                            Ok(Expression::Function(Box::new(Function::new(
640                                "UNIFORM".to_string(),
641                                vec![low, high],
642                            ))))
643                        }
644                    }
645                    Expression::Rand(r) => {
646                        if let Some(seed) = r.seed {
647                            Ok(Expression::Function(Box::new(Function::new(
648                                "UNIFORM".to_string(),
649                                vec![low, high, *seed],
650                            ))))
651                        } else {
652                            Ok(Expression::Function(Box::new(Function::new(
653                                "UNIFORM".to_string(),
654                                vec![low, high],
655                            ))))
656                        }
657                    }
658                    _ => Ok(Expression::Function(Box::new(Function::new(
659                        "UNIFORM".to_string(),
660                        vec![low, high, gen],
661                    )))),
662                }
663            }
664
665            // REGEXP_SUBSTR(subject, pattern, ...) -> REGEXP_EXTRACT(subject, pattern)
666            "REGEXP_SUBSTR" if f.args.len() >= 2 => {
667                let subject = f.args[0].clone();
668                let pattern = f.args[1].clone();
669                Ok(Expression::Function(Box::new(Function::new(
670                    "REGEXP_EXTRACT".to_string(),
671                    vec![subject, pattern],
672                ))))
673            }
674
675            // Pass through everything else
676            _ => Ok(Expression::Function(Box::new(f))),
677        }
678    }
679
680    fn transform_aggregate_function(
681        &self,
682        f: Box<crate::expressions::AggregateFunction>,
683    ) -> Result<Expression> {
684        let name_upper = f.name.to_uppercase();
685        match name_upper.as_str() {
686            // COUNT_IF is native in Databricks (Spark 3+)
687            "COUNT_IF" => Ok(Expression::AggregateFunction(f)),
688
689            // ANY_VALUE is native in Databricks (Spark 3+)
690            "ANY_VALUE" => Ok(Expression::AggregateFunction(f)),
691
692            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
693            "GROUP_CONCAT" if !f.args.is_empty() => {
694                let mut args = f.args;
695                let first = args.remove(0);
696                let separator = args.pop();
697                let collect_list = Expression::Function(Box::new(Function::new(
698                    "COLLECT_LIST".to_string(),
699                    vec![first],
700                )));
701                if let Some(sep) = separator {
702                    Ok(Expression::Function(Box::new(Function::new(
703                        "ARRAY_JOIN".to_string(),
704                        vec![collect_list, sep],
705                    ))))
706                } else {
707                    Ok(Expression::Function(Box::new(Function::new(
708                        "ARRAY_JOIN".to_string(),
709                        vec![collect_list],
710                    ))))
711                }
712            }
713
714            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN
715            "STRING_AGG" if !f.args.is_empty() => {
716                let mut args = f.args;
717                let first = args.remove(0);
718                let separator = args.pop();
719                let collect_list = Expression::Function(Box::new(Function::new(
720                    "COLLECT_LIST".to_string(),
721                    vec![first],
722                )));
723                if let Some(sep) = separator {
724                    Ok(Expression::Function(Box::new(Function::new(
725                        "ARRAY_JOIN".to_string(),
726                        vec![collect_list, sep],
727                    ))))
728                } else {
729                    Ok(Expression::Function(Box::new(Function::new(
730                        "ARRAY_JOIN".to_string(),
731                        vec![collect_list],
732                    ))))
733                }
734            }
735
736            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
737            "LISTAGG" if !f.args.is_empty() => {
738                let mut args = f.args;
739                let first = args.remove(0);
740                let separator = args.pop();
741                let collect_list = Expression::Function(Box::new(Function::new(
742                    "COLLECT_LIST".to_string(),
743                    vec![first],
744                )));
745                if let Some(sep) = separator {
746                    Ok(Expression::Function(Box::new(Function::new(
747                        "ARRAY_JOIN".to_string(),
748                        vec![collect_list, sep],
749                    ))))
750                } else {
751                    Ok(Expression::Function(Box::new(Function::new(
752                        "ARRAY_JOIN".to_string(),
753                        vec![collect_list],
754                    ))))
755                }
756            }
757
758            // ARRAY_AGG -> COLLECT_LIST
759            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
760                "COLLECT_LIST".to_string(),
761                f.args,
762            )))),
763
764            // STDDEV is native in Databricks
765            "STDDEV" => Ok(Expression::AggregateFunction(f)),
766
767            // VARIANCE is native in Databricks
768            "VARIANCE" => Ok(Expression::AggregateFunction(f)),
769
770            // APPROX_COUNT_DISTINCT is native in Databricks
771            "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
772
773            // APPROX_DISTINCT -> APPROX_COUNT_DISTINCT
774            "APPROX_DISTINCT" if !f.args.is_empty() => {
775                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
776                    name: "APPROX_COUNT_DISTINCT".to_string(),
777                    args: f.args,
778                    distinct: f.distinct,
779                    filter: f.filter,
780                    order_by: Vec::new(),
781                    limit: None,
782                    ignore_nulls: None,
783                })))
784            }
785
786            // Pass through everything else
787            _ => Ok(Expression::AggregateFunction(f)),
788        }
789    }
790
791    /// Transform Cast expressions - handles typed literals being cast
792    ///
793    /// When we have a typed literal (TIMESTAMP 'x', DATE 'x', TIME 'x') being cast to another type,
794    /// Databricks/Spark transforms it as follows:
795    ///
796    /// `TIMESTAMP 'x'::TYPE` -> `CAST(CAST('x' AS TYPE) AS TIMESTAMP)`
797    ///
798    /// This reverses the types - the inner cast is to the target type,
799    /// the outer cast is to the original literal type.
800    fn transform_cast(&self, c: Cast) -> Result<Expression> {
801        // Check if the inner expression is a typed literal
802        match &c.this {
803            // TIMESTAMP 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIMESTAMP)
804            Expression::Literal(Literal::Timestamp(value)) => {
805                // Create inner cast: CAST('value' AS target_type)
806                let inner_cast = Expression::Cast(Box::new(Cast {
807                    this: Expression::Literal(Literal::String(value.clone())),
808                    to: c.to,
809                    trailing_comments: Vec::new(),
810                    double_colon_syntax: false,
811                    format: None,
812                    default: None,
813                }));
814                // Create outer cast: CAST(inner_cast AS TIMESTAMP)
815                Ok(Expression::Cast(Box::new(Cast {
816                    this: inner_cast,
817                    to: DataType::Timestamp {
818                        precision: None,
819                        timezone: false,
820                    },
821                    trailing_comments: c.trailing_comments,
822                    double_colon_syntax: false,
823                    format: None,
824                    default: None,
825                })))
826            }
827            // DATE 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS DATE)
828            Expression::Literal(Literal::Date(value)) => {
829                let inner_cast = Expression::Cast(Box::new(Cast {
830                    this: Expression::Literal(Literal::String(value.clone())),
831                    to: c.to,
832                    trailing_comments: Vec::new(),
833                    double_colon_syntax: false,
834                    format: None,
835                    default: None,
836                }));
837                Ok(Expression::Cast(Box::new(Cast {
838                    this: inner_cast,
839                    to: DataType::Date,
840                    trailing_comments: c.trailing_comments,
841                    double_colon_syntax: false,
842                    format: None,
843                    default: None,
844                })))
845            }
846            // TIME 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIME)
847            Expression::Literal(Literal::Time(value)) => {
848                let inner_cast = Expression::Cast(Box::new(Cast {
849                    this: Expression::Literal(Literal::String(value.clone())),
850                    to: c.to,
851                    trailing_comments: Vec::new(),
852                    double_colon_syntax: false,
853                    format: None,
854                    default: None,
855                }));
856                Ok(Expression::Cast(Box::new(Cast {
857                    this: inner_cast,
858                    to: DataType::Time {
859                        precision: None,
860                        timezone: false,
861                    },
862                    trailing_comments: c.trailing_comments,
863                    double_colon_syntax: false,
864                    format: None,
865                    default: None,
866                })))
867            }
868            // For all other cases, pass through the Cast unchanged
869            _ => Ok(Expression::Cast(Box::new(c))),
870        }
871    }
872
873    /// Check if an expression is a CAST to TIMESTAMP
874    fn is_cast_to_timestamp(&self, expr: &Expression) -> bool {
875        if let Expression::Cast(cast) = expr {
876            matches!(cast.to, DataType::Timestamp { .. })
877        } else {
878            false
879        }
880    }
881
882    /// Helper to uppercase the first argument if it's an identifier or column (for DATEDIFF, DATEADD units)
883    fn uppercase_first_arg_if_identifier(&self, mut args: Vec<Expression>) -> Vec<Expression> {
884        use crate::expressions::Identifier;
885        if !args.is_empty() {
886            match &args[0] {
887                Expression::Identifier(id) => {
888                    args[0] = Expression::Identifier(Identifier {
889                        name: id.name.to_uppercase(),
890                        quoted: id.quoted,
891                        trailing_comments: id.trailing_comments.clone(),
892                        span: None,
893                    });
894                }
895                Expression::Column(col) if col.table.is_none() => {
896                    // Unqualified column name like "day" should be treated as a unit
897                    args[0] = Expression::Identifier(Identifier {
898                        name: col.name.name.to_uppercase(),
899                        quoted: col.name.quoted,
900                        trailing_comments: col.name.trailing_comments.clone(),
901                        span: None,
902                    });
903                }
904                _ => {}
905            }
906        }
907        args
908    }
909}
910
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dialect;

    /// Runs `sql` through the Databricks parse -> transform -> generate pipeline.
    ///
    /// Only the first parsed statement is transformed. Its generated SQL is returned.
    fn databricks_roundtrip(sql: &str) -> String {
        let dialect = Dialect::get(DialectType::Databricks);
        let statements = dialect.parse(sql).expect("Parse failed");
        let first = statements[0].clone();
        let transformed = dialect.transform(first).expect("Transform failed");
        dialect.generate(&transformed).expect("Generate failed")
    }

    #[test]
    fn test_timestamp_literal_cast() {
        // Fixture [47] of the Databricks identity suite: a typed TIMESTAMP literal cast with
        // `::` becomes nested CASTs. The requested type goes inside and TIMESTAMP outside.
        assert_eq!(
            databricks_roundtrip("SELECT TIMESTAMP '2025-04-29 18.47.18'::DATE"),
            "SELECT CAST(CAST('2025-04-29 18.47.18' AS DATE) AS TIMESTAMP)",
            "Timestamp literal cast transformation failed"
        );
    }

    #[test]
    fn test_from_utc_timestamp_wraps_column() {
        // Fixture [48]: a bare column passed to FROM_UTC_TIMESTAMP gets wrapped in
        // CAST(... AS TIMESTAMP).
        assert_eq!(
            databricks_roundtrip(
                "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(foo, 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t"
            ),
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "FROM_UTC_TIMESTAMP transformation failed"
        );
    }

    #[test]
    fn test_from_utc_timestamp_keeps_existing_cast() {
        // Fixture [50]: an argument that is already cast to TIMESTAMP is not re-wrapped.
        // Only the `::` shorthand is rendered as an explicit CAST().
        assert_eq!(
            databricks_roundtrip("FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz)"),
            "FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)",
            "FROM_UTC_TIMESTAMP with existing CAST failed"
        );
    }
}