Skip to main content

polyglot_sql/dialects/
databricks.rs

1//! Databricks Dialect
2//!
3//! Databricks-specific transformations based on sqlglot patterns.
4//! Databricks extends Spark SQL with additional features:
5//! - Colon operator for JSON extraction (col:path)
6//! - DATEADD/DATEDIFF with specific syntax
7//! - NULL type mapped to VOID
8//! - Native REGEXP_LIKE and TRY_CAST support
9
10use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{
13    AggregateFunction, Cast, DataType, Expression, Function, JSONExtract, Literal, UnaryFunc,
14    VarArgFunc,
15};
16use crate::generator::GeneratorConfig;
17use crate::tokens::TokenizerConfig;
18
/// Databricks SQL dialect.
///
/// A stateless unit (marker) type: all Databricks-specific behaviour is
/// provided by its `DialectImpl` implementation (tokenizer configuration,
/// generator configuration and expression rewrites) and by the inherent
/// helper methods defined on this type.
pub struct DatabricksDialect;
21
22impl DialectImpl for DatabricksDialect {
23    fn dialect_type(&self) -> DialectType {
24        DialectType::Databricks
25    }
26
27    fn tokenizer_config(&self) -> TokenizerConfig {
28        let mut config = TokenizerConfig::default();
29        // Databricks uses backticks for identifiers (NOT double quotes)
30        config.identifiers.clear();
31        config.identifiers.insert('`', '`');
32        // Databricks (like Hive/Spark) uses double quotes as string delimiters
33        config.quotes.insert("\"".to_string(), "\"".to_string());
34        // Databricks uses backslash escapes in strings (inherited from Hive/Spark)
35        config.string_escapes.push('\\');
36        // Databricks supports DIV keyword for integer division
37        config
38            .keywords
39            .insert("DIV".to_string(), crate::tokens::TokenType::Div);
40        // Databricks numeric literal suffixes (same as Hive/Spark)
41        config
42            .numeric_literals
43            .insert("L".to_string(), "BIGINT".to_string());
44        config
45            .numeric_literals
46            .insert("S".to_string(), "SMALLINT".to_string());
47        config
48            .numeric_literals
49            .insert("Y".to_string(), "TINYINT".to_string());
50        config
51            .numeric_literals
52            .insert("D".to_string(), "DOUBLE".to_string());
53        config
54            .numeric_literals
55            .insert("F".to_string(), "FLOAT".to_string());
56        config
57            .numeric_literals
58            .insert("BD".to_string(), "DECIMAL".to_string());
59        // Databricks allows identifiers to start with digits (like Hive/Spark)
60        config.identifiers_can_start_with_digit = true;
61        // Databricks (like Spark): STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
62        // Backslashes in raw strings are always literal (no escape processing)
63        config.string_escapes_allowed_in_raw_strings = false;
64        config
65    }
66
67    fn generator_config(&self) -> GeneratorConfig {
68        use crate::generator::IdentifierQuoteStyle;
69        GeneratorConfig {
70            identifier_quote: '`',
71            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
72            dialect: Some(DialectType::Databricks),
73            struct_field_sep: ": ",
74            create_function_return_as: false,
75            tablesample_seed_keyword: "REPEATABLE",
76            identifiers_can_start_with_digit: true,
77            // Databricks uses COMMENT 'value' without = sign
78            schema_comment_with_eq: false,
79            ..Default::default()
80        }
81    }
82
83    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
84        match expr {
85            // IFNULL -> COALESCE in Databricks
86            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
87                original_name: None,
88                expressions: vec![f.this, f.expression],
89                inferred_type: None,
90            }))),
91
92            // NVL -> COALESCE in Databricks
93            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
94                original_name: None,
95                expressions: vec![f.this, f.expression],
96                inferred_type: None,
97            }))),
98
99            // TryCast is native in Databricks
100            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
101
102            // SafeCast -> TRY_CAST in Databricks
103            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
104
105            // ILIKE is native in Databricks (Spark 3+)
106            Expression::ILike(op) => Ok(Expression::ILike(op)),
107
108            // UNNEST -> EXPLODE in Databricks
109            Expression::Unnest(f) => Ok(Expression::Explode(Box::new(UnaryFunc::new(f.this)))),
110
111            // EXPLODE is native to Databricks
112            Expression::Explode(f) => Ok(Expression::Explode(f)),
113
114            // ExplodeOuter is supported
115            Expression::ExplodeOuter(f) => Ok(Expression::ExplodeOuter(f)),
116
117            // RANDOM -> RAND in Databricks
118            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
119                seed: None,
120                lower: None,
121                upper: None,
122            }))),
123
124            // Rand is native
125            Expression::Rand(r) => Ok(Expression::Rand(r)),
126
127            // || (Concat) -> CONCAT in Databricks
128            Expression::Concat(op) => Ok(Expression::Function(Box::new(Function::new(
129                "CONCAT".to_string(),
130                vec![op.left, op.right],
131            )))),
132
133            // RegexpLike is native in Databricks
134            Expression::RegexpLike(op) => Ok(Expression::RegexpLike(op)),
135
136            // Cast with typed literal: TIMESTAMP 'x'::TYPE -> CAST(CAST('x' AS TYPE) AS TIMESTAMP)
137            // This is a complex sqlglot transformation where:
138            // 1. The inner typed literal (e.g., TIMESTAMP 'x') becomes CAST('x' AS <target_type>)
139            // 2. The outer result is wrapped in CAST(... AS <original_literal_type>)
140            Expression::Cast(c) => self.transform_cast(*c),
141
142            // Generic function transformations
143            Expression::Function(f) => self.transform_function(*f),
144
145            // Generic aggregate function transformations
146            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
147
148            // DateSub -> DATE_ADD(date, -val) in Databricks
149            Expression::DateSub(f) => {
150                // Convert string literals to numbers (interval values are often stored as strings)
151                let val = match f.interval {
152                    Expression::Literal(crate::expressions::Literal::String(s))
153                        if s.parse::<i64>().is_ok() =>
154                    {
155                        Expression::Literal(crate::expressions::Literal::Number(s))
156                    }
157                    other => other,
158                };
159                let neg_val = Expression::Neg(Box::new(crate::expressions::UnaryOp {
160                    this: val,
161                    inferred_type: None,
162                }));
163                Ok(Expression::Function(Box::new(Function::new(
164                    "DATE_ADD".to_string(),
165                    vec![f.this, neg_val],
166                ))))
167            }
168
169            // Pass through everything else
170            _ => Ok(expr),
171        }
172    }
173}
174
175impl DatabricksDialect {
176    fn transform_function(&self, f: Function) -> Result<Expression> {
177        let name_upper = f.name.to_uppercase();
178        match name_upper.as_str() {
179            // IFNULL -> COALESCE
180            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
181                original_name: None,
182                expressions: f.args,
183                inferred_type: None,
184            }))),
185
186            // NVL -> COALESCE
187            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
188                original_name: None,
189                expressions: f.args,
190                inferred_type: None,
191            }))),
192
193            // ISNULL -> COALESCE
194            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
195                original_name: None,
196                expressions: f.args,
197                inferred_type: None,
198            }))),
199
200            // ROW -> STRUCT (no auto-naming for cross-dialect conversion)
201            "ROW" => Ok(Expression::Function(Box::new(Function::new(
202                "STRUCT".to_string(),
203                f.args,
204            )))),
205
206            // GETDATE -> CURRENT_TIMESTAMP
207            "GETDATE" => Ok(Expression::CurrentTimestamp(
208                crate::expressions::CurrentTimestamp {
209                    precision: None,
210                    sysdate: false,
211                },
212            )),
213
214            // NOW -> CURRENT_TIMESTAMP
215            "NOW" => Ok(Expression::CurrentTimestamp(
216                crate::expressions::CurrentTimestamp {
217                    precision: None,
218                    sysdate: false,
219                },
220            )),
221
222            // CURDATE -> CURRENT_DATE
223            "CURDATE" => Ok(Expression::CurrentDate(crate::expressions::CurrentDate)),
224
225            // CURRENT_DATE() with parens -> CURRENT_DATE (no parens)
226            "CURRENT_DATE" if f.args.is_empty() => {
227                Ok(Expression::CurrentDate(crate::expressions::CurrentDate))
228            }
229
230            // RANDOM -> RAND
231            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
232                seed: None,
233                lower: None,
234                upper: None,
235            }))),
236
237            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
238            "GROUP_CONCAT" if !f.args.is_empty() => {
239                let mut args = f.args;
240                let first = args.remove(0);
241                let separator = args.pop();
242                let collect_list = Expression::Function(Box::new(Function::new(
243                    "COLLECT_LIST".to_string(),
244                    vec![first],
245                )));
246                if let Some(sep) = separator {
247                    Ok(Expression::Function(Box::new(Function::new(
248                        "ARRAY_JOIN".to_string(),
249                        vec![collect_list, sep],
250                    ))))
251                } else {
252                    Ok(Expression::Function(Box::new(Function::new(
253                        "ARRAY_JOIN".to_string(),
254                        vec![collect_list],
255                    ))))
256                }
257            }
258
259            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN in Databricks
260            "STRING_AGG" if !f.args.is_empty() => {
261                let mut args = f.args;
262                let first = args.remove(0);
263                let separator = args.pop();
264                let collect_list = Expression::Function(Box::new(Function::new(
265                    "COLLECT_LIST".to_string(),
266                    vec![first],
267                )));
268                if let Some(sep) = separator {
269                    Ok(Expression::Function(Box::new(Function::new(
270                        "ARRAY_JOIN".to_string(),
271                        vec![collect_list, sep],
272                    ))))
273                } else {
274                    Ok(Expression::Function(Box::new(Function::new(
275                        "ARRAY_JOIN".to_string(),
276                        vec![collect_list],
277                    ))))
278                }
279            }
280
281            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
282            "LISTAGG" if !f.args.is_empty() => {
283                let mut args = f.args;
284                let first = args.remove(0);
285                let separator = args.pop();
286                let collect_list = Expression::Function(Box::new(Function::new(
287                    "COLLECT_LIST".to_string(),
288                    vec![first],
289                )));
290                if let Some(sep) = separator {
291                    Ok(Expression::Function(Box::new(Function::new(
292                        "ARRAY_JOIN".to_string(),
293                        vec![collect_list, sep],
294                    ))))
295                } else {
296                    Ok(Expression::Function(Box::new(Function::new(
297                        "ARRAY_JOIN".to_string(),
298                        vec![collect_list],
299                    ))))
300                }
301            }
302
303            // ARRAY_AGG -> COLLECT_LIST in Databricks
304            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
305                "COLLECT_LIST".to_string(),
306                f.args,
307            )))),
308
309            // SUBSTR -> SUBSTRING
310            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
311                "SUBSTRING".to_string(),
312                f.args,
313            )))),
314
315            // LEN -> LENGTH
316            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
317                f.args.into_iter().next().unwrap(),
318            )))),
319
320            // CHARINDEX -> LOCATE (with swapped args, like Spark)
321            "CHARINDEX" if f.args.len() >= 2 => {
322                let mut args = f.args;
323                let substring = args.remove(0);
324                let string = args.remove(0);
325                // LOCATE(substring, string)
326                Ok(Expression::Function(Box::new(Function::new(
327                    "LOCATE".to_string(),
328                    vec![substring, string],
329                ))))
330            }
331
332            // POSITION -> LOCATE
333            "POSITION" if f.args.len() == 2 => {
334                let args = f.args;
335                Ok(Expression::Function(Box::new(Function::new(
336                    "LOCATE".to_string(),
337                    args,
338                ))))
339            }
340
341            // STRPOS -> LOCATE (with same arg order)
342            "STRPOS" if f.args.len() == 2 => {
343                let args = f.args;
344                let string = args[0].clone();
345                let substring = args[1].clone();
346                // STRPOS(string, substring) -> LOCATE(substring, string)
347                Ok(Expression::Function(Box::new(Function::new(
348                    "LOCATE".to_string(),
349                    vec![substring, string],
350                ))))
351            }
352
353            // INSTR is native in Databricks
354            "INSTR" => Ok(Expression::Function(Box::new(f))),
355
356            // LOCATE is native in Databricks
357            "LOCATE" => Ok(Expression::Function(Box::new(f))),
358
359            // ARRAY_LENGTH -> SIZE
360            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
361                Function::new("SIZE".to_string(), f.args),
362            ))),
363
364            // CARDINALITY -> SIZE
365            "CARDINALITY" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
366                Function::new("SIZE".to_string(), f.args),
367            ))),
368
369            // SIZE is native
370            "SIZE" => Ok(Expression::Function(Box::new(f))),
371
372            // ARRAY_CONTAINS is native in Databricks
373            "ARRAY_CONTAINS" => Ok(Expression::Function(Box::new(f))),
374
375            // CONTAINS -> ARRAY_CONTAINS in Databricks (for array operations)
376            // But keep CONTAINS for string contains (from CONTAINS_SUBSTR transpilation)
377            "CONTAINS" if f.args.len() == 2 => {
378                // Check if this is a string CONTAINS (LOWER() args pattern from CONTAINS_SUBSTR)
379                let is_string_contains = matches!(&f.args[0], Expression::Lower(_))
380                    && matches!(&f.args[1], Expression::Lower(_));
381                if is_string_contains {
382                    Ok(Expression::Function(Box::new(f)))
383                } else {
384                    Ok(Expression::Function(Box::new(Function::new(
385                        "ARRAY_CONTAINS".to_string(),
386                        f.args,
387                    ))))
388                }
389            }
390
391            // TO_DATE is native in Databricks
392            "TO_DATE" => Ok(Expression::Function(Box::new(f))),
393
394            // TO_TIMESTAMP is native in Databricks
395            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
396
397            // DATE_FORMAT is native in Databricks
398            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
399
400            // strftime -> DATE_FORMAT in Databricks
401            "STRFTIME" if f.args.len() >= 2 => {
402                let mut args = f.args;
403                let format = args.remove(0);
404                let date = args.remove(0);
405                Ok(Expression::Function(Box::new(Function::new(
406                    "DATE_FORMAT".to_string(),
407                    vec![date, format],
408                ))))
409            }
410
411            // TO_CHAR is supported natively in Databricks (unlike Spark)
412            "TO_CHAR" => Ok(Expression::Function(Box::new(f))),
413
414            // DATE_TRUNC is native in Databricks
415            "DATE_TRUNC" => Ok(Expression::Function(Box::new(f))),
416
417            // DATEADD is native in Databricks - uppercase the unit if present
418            "DATEADD" => {
419                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
420                Ok(Expression::Function(Box::new(Function::new(
421                    "DATEADD".to_string(),
422                    transformed_args,
423                ))))
424            }
425
426            // DATE_ADD -> DATEADD in Databricks (2-arg form only)
427            // 2-arg with interval: DATE_ADD(date, interval) -> DATEADD(DAY, interval, date)
428            // 2-arg with number: DATE_ADD(date, -2) -> keep as DATE_ADD(date, -2)
429            // 3-arg: DATE_ADD(unit, amount, date) -> keep as DATE_ADD(UNIT, amount, date)
430            "DATE_ADD" => {
431                if f.args.len() == 2 {
432                    let is_simple_number = matches!(
433                        &f.args[1],
434                        Expression::Literal(crate::expressions::Literal::Number(_))
435                            | Expression::Neg(_)
436                    );
437                    if is_simple_number {
438                        // Keep as DATE_ADD(date, num_days)
439                        Ok(Expression::Function(Box::new(Function::new(
440                            "DATE_ADD".to_string(),
441                            f.args,
442                        ))))
443                    } else {
444                        let mut args = f.args;
445                        let date = args.remove(0);
446                        let interval = args.remove(0);
447                        let unit = Expression::Identifier(crate::expressions::Identifier {
448                            name: "DAY".to_string(),
449                            quoted: false,
450                            trailing_comments: Vec::new(),
451                            span: None,
452                        });
453                        Ok(Expression::Function(Box::new(Function::new(
454                            "DATEADD".to_string(),
455                            vec![unit, interval, date],
456                        ))))
457                    }
458                } else {
459                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
460                    Ok(Expression::Function(Box::new(Function::new(
461                        "DATE_ADD".to_string(),
462                        transformed_args,
463                    ))))
464                }
465            }
466
467            // DATEDIFF is native in Databricks - uppercase the unit if present
468            // 2-arg: DATEDIFF(end, start) -> DATEDIFF(DAY, start, end)
469            // 3-arg: DATEDIFF(unit, start, end) -> DATEDIFF(UNIT, start, end)
470            "DATEDIFF" => {
471                if f.args.len() == 2 {
472                    let mut args = f.args;
473                    let end_date = args.remove(0);
474                    let start_date = args.remove(0);
475                    let unit = Expression::Identifier(crate::expressions::Identifier {
476                        name: "DAY".to_string(),
477                        quoted: false,
478                        trailing_comments: Vec::new(),
479                        span: None,
480                    });
481                    Ok(Expression::Function(Box::new(Function::new(
482                        "DATEDIFF".to_string(),
483                        vec![unit, start_date, end_date],
484                    ))))
485                } else {
486                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
487                    Ok(Expression::Function(Box::new(Function::new(
488                        "DATEDIFF".to_string(),
489                        transformed_args,
490                    ))))
491                }
492            }
493
494            // DATE_DIFF -> DATEDIFF with uppercased unit
495            "DATE_DIFF" => {
496                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
497                Ok(Expression::Function(Box::new(Function::new(
498                    "DATEDIFF".to_string(),
499                    transformed_args,
500                ))))
501            }
502
503            // JSON_EXTRACT -> Use colon operator in generation, but keep as function for now
504            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
505
506            // JSON_EXTRACT_SCALAR -> same handling
507            "JSON_EXTRACT_SCALAR" => Ok(Expression::Function(Box::new(f))),
508
509            // GET_JSON_OBJECT -> colon syntax in Databricks
510            // GET_JSON_OBJECT(col, '$.path') becomes col:path
511            "GET_JSON_OBJECT" if f.args.len() == 2 => {
512                let mut args = f.args;
513                let col = args.remove(0);
514                let path_arg = args.remove(0);
515
516                // Extract and strip the $. prefix from the path
517                let path_expr = match &path_arg {
518                    Expression::Literal(crate::expressions::Literal::String(s)) => {
519                        // Strip leading '$.' if present
520                        let stripped = if s.starts_with("$.") {
521                            &s[2..]
522                        } else if s.starts_with("$") {
523                            &s[1..]
524                        } else {
525                            s.as_str()
526                        };
527                        Expression::Literal(crate::expressions::Literal::String(
528                            stripped.to_string(),
529                        ))
530                    }
531                    _ => path_arg,
532                };
533
534                Ok(Expression::JSONExtract(Box::new(JSONExtract {
535                    this: Box::new(col),
536                    expression: Box::new(path_expr),
537                    only_json_types: None,
538                    expressions: Vec::new(),
539                    variant_extract: Some(Box::new(Expression::true_())),
540                    json_query: None,
541                    option: None,
542                    quote: None,
543                    on_condition: None,
544                    requires_json: None,
545                })))
546            }
547
548            // FROM_JSON is native in Databricks
549            "FROM_JSON" => Ok(Expression::Function(Box::new(f))),
550
551            // PARSE_JSON is native in Databricks
552            "PARSE_JSON" => Ok(Expression::Function(Box::new(f))),
553
554            // COLLECT_LIST is native in Databricks
555            "COLLECT_LIST" => Ok(Expression::Function(Box::new(f))),
556
557            // COLLECT_SET is native in Databricks
558            "COLLECT_SET" => Ok(Expression::Function(Box::new(f))),
559
560            // RLIKE is native in Databricks
561            "RLIKE" => Ok(Expression::Function(Box::new(f))),
562
563            // REGEXP -> RLIKE in Databricks
564            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
565                "RLIKE".to_string(),
566                f.args,
567            )))),
568
569            // REGEXP_LIKE is native in Databricks
570            "REGEXP_LIKE" => Ok(Expression::Function(Box::new(f))),
571
572            // LEVENSHTEIN is native in Databricks
573            "LEVENSHTEIN" => Ok(Expression::Function(Box::new(f))),
574
575            // SEQUENCE is native (for GENERATE_SERIES)
576            "GENERATE_SERIES" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
577                Function::new("SEQUENCE".to_string(), f.args),
578            ))),
579
580            // SEQUENCE is native
581            "SEQUENCE" => Ok(Expression::Function(Box::new(f))),
582
583            // FLATTEN is native in Databricks
584            "FLATTEN" => Ok(Expression::Function(Box::new(f))),
585
586            // ARRAY_SORT is native
587            "ARRAY_SORT" => Ok(Expression::Function(Box::new(f))),
588
589            // ARRAY_DISTINCT is native
590            "ARRAY_DISTINCT" => Ok(Expression::Function(Box::new(f))),
591
592            // TRANSFORM is native (for array transformation)
593            "TRANSFORM" => Ok(Expression::Function(Box::new(f))),
594
595            // FILTER is native (for array filtering)
596            "FILTER" => Ok(Expression::Function(Box::new(f))),
597
598            // FROM_UTC_TIMESTAMP - wrap first argument in CAST(... AS TIMESTAMP) if not already
599            "FROM_UTC_TIMESTAMP" if f.args.len() >= 2 => {
600                let mut args = f.args;
601                let first_arg = args.remove(0);
602
603                // Check if first arg is already a Cast to TIMESTAMP
604                let wrapped_arg = if self.is_cast_to_timestamp(&first_arg) {
605                    first_arg
606                } else {
607                    // Wrap in CAST(... AS TIMESTAMP)
608                    Expression::Cast(Box::new(Cast {
609                        this: first_arg,
610                        to: DataType::Timestamp {
611                            precision: None,
612                            timezone: false,
613                        },
614                        trailing_comments: Vec::new(),
615                        double_colon_syntax: false,
616                        format: None,
617                        default: None,
618                        inferred_type: None,
619                    }))
620                };
621
622                let mut new_args = vec![wrapped_arg];
623                new_args.extend(args);
624
625                Ok(Expression::Function(Box::new(Function::new(
626                    "FROM_UTC_TIMESTAMP".to_string(),
627                    new_args,
628                ))))
629            }
630
631            // UNIFORM(low, high, RANDOM(seed)) -> UNIFORM(low, high, seed) or UNIFORM(low, high)
632            "UNIFORM" if f.args.len() == 3 => {
633                let mut args = f.args;
634                let low = args.remove(0);
635                let high = args.remove(0);
636                let gen = args.remove(0);
637                match gen {
638                    Expression::Function(func) if func.name.to_uppercase() == "RANDOM" => {
639                        if func.args.len() == 1 {
640                            // RANDOM(seed) -> extract seed
641                            let seed = func.args.into_iter().next().unwrap();
642                            Ok(Expression::Function(Box::new(Function::new(
643                                "UNIFORM".to_string(),
644                                vec![low, high, seed],
645                            ))))
646                        } else {
647                            // RANDOM() -> drop gen arg
648                            Ok(Expression::Function(Box::new(Function::new(
649                                "UNIFORM".to_string(),
650                                vec![low, high],
651                            ))))
652                        }
653                    }
654                    Expression::Rand(r) => {
655                        if let Some(seed) = r.seed {
656                            Ok(Expression::Function(Box::new(Function::new(
657                                "UNIFORM".to_string(),
658                                vec![low, high, *seed],
659                            ))))
660                        } else {
661                            Ok(Expression::Function(Box::new(Function::new(
662                                "UNIFORM".to_string(),
663                                vec![low, high],
664                            ))))
665                        }
666                    }
667                    _ => Ok(Expression::Function(Box::new(Function::new(
668                        "UNIFORM".to_string(),
669                        vec![low, high, gen],
670                    )))),
671                }
672            }
673
674            // REGEXP_SUBSTR(subject, pattern, ...) -> REGEXP_EXTRACT(subject, pattern)
675            "REGEXP_SUBSTR" if f.args.len() >= 2 => {
676                let subject = f.args[0].clone();
677                let pattern = f.args[1].clone();
678                Ok(Expression::Function(Box::new(Function::new(
679                    "REGEXP_EXTRACT".to_string(),
680                    vec![subject, pattern],
681                ))))
682            }
683
684            // Pass through everything else
685            _ => Ok(Expression::Function(Box::new(f))),
686        }
687    }
688
689    fn transform_aggregate_function(
690        &self,
691        f: Box<crate::expressions::AggregateFunction>,
692    ) -> Result<Expression> {
693        let name_upper = f.name.to_uppercase();
694        match name_upper.as_str() {
695            // COUNT_IF is native in Databricks (Spark 3+)
696            "COUNT_IF" => Ok(Expression::AggregateFunction(f)),
697
698            // ANY_VALUE is native in Databricks (Spark 3+)
699            "ANY_VALUE" => Ok(Expression::AggregateFunction(f)),
700
701            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
702            "GROUP_CONCAT" if !f.args.is_empty() => {
703                let mut args = f.args;
704                let first = args.remove(0);
705                let separator = args.pop();
706                let collect_list = Expression::Function(Box::new(Function::new(
707                    "COLLECT_LIST".to_string(),
708                    vec![first],
709                )));
710                if let Some(sep) = separator {
711                    Ok(Expression::Function(Box::new(Function::new(
712                        "ARRAY_JOIN".to_string(),
713                        vec![collect_list, sep],
714                    ))))
715                } else {
716                    Ok(Expression::Function(Box::new(Function::new(
717                        "ARRAY_JOIN".to_string(),
718                        vec![collect_list],
719                    ))))
720                }
721            }
722
723            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN
724            "STRING_AGG" if !f.args.is_empty() => {
725                let mut args = f.args;
726                let first = args.remove(0);
727                let separator = args.pop();
728                let collect_list = Expression::Function(Box::new(Function::new(
729                    "COLLECT_LIST".to_string(),
730                    vec![first],
731                )));
732                if let Some(sep) = separator {
733                    Ok(Expression::Function(Box::new(Function::new(
734                        "ARRAY_JOIN".to_string(),
735                        vec![collect_list, sep],
736                    ))))
737                } else {
738                    Ok(Expression::Function(Box::new(Function::new(
739                        "ARRAY_JOIN".to_string(),
740                        vec![collect_list],
741                    ))))
742                }
743            }
744
745            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
746            "LISTAGG" if !f.args.is_empty() => {
747                let mut args = f.args;
748                let first = args.remove(0);
749                let separator = args.pop();
750                let collect_list = Expression::Function(Box::new(Function::new(
751                    "COLLECT_LIST".to_string(),
752                    vec![first],
753                )));
754                if let Some(sep) = separator {
755                    Ok(Expression::Function(Box::new(Function::new(
756                        "ARRAY_JOIN".to_string(),
757                        vec![collect_list, sep],
758                    ))))
759                } else {
760                    Ok(Expression::Function(Box::new(Function::new(
761                        "ARRAY_JOIN".to_string(),
762                        vec![collect_list],
763                    ))))
764                }
765            }
766
767            // ARRAY_AGG -> COLLECT_LIST
768            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
769                "COLLECT_LIST".to_string(),
770                f.args,
771            )))),
772
773            // STDDEV is native in Databricks
774            "STDDEV" => Ok(Expression::AggregateFunction(f)),
775
776            // VARIANCE is native in Databricks
777            "VARIANCE" => Ok(Expression::AggregateFunction(f)),
778
779            // APPROX_COUNT_DISTINCT is native in Databricks
780            "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
781
782            // APPROX_DISTINCT -> APPROX_COUNT_DISTINCT
783            "APPROX_DISTINCT" if !f.args.is_empty() => {
784                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
785                    name: "APPROX_COUNT_DISTINCT".to_string(),
786                    args: f.args,
787                    distinct: f.distinct,
788                    filter: f.filter,
789                    order_by: Vec::new(),
790                    limit: None,
791                    ignore_nulls: None,
792                    inferred_type: None,
793                })))
794            }
795
796            // Pass through everything else
797            _ => Ok(Expression::AggregateFunction(f)),
798        }
799    }
800
801    /// Transform Cast expressions - handles typed literals being cast
802    ///
803    /// When we have a typed literal (TIMESTAMP 'x', DATE 'x', TIME 'x') being cast to another type,
804    /// Databricks/Spark transforms it as follows:
805    ///
806    /// `TIMESTAMP 'x'::TYPE` -> `CAST(CAST('x' AS TYPE) AS TIMESTAMP)`
807    ///
808    /// This reverses the types - the inner cast is to the target type,
809    /// the outer cast is to the original literal type.
810    fn transform_cast(&self, c: Cast) -> Result<Expression> {
811        // Check if the inner expression is a typed literal
812        match &c.this {
813            // TIMESTAMP 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIMESTAMP)
814            Expression::Literal(Literal::Timestamp(value)) => {
815                // Create inner cast: CAST('value' AS target_type)
816                let inner_cast = Expression::Cast(Box::new(Cast {
817                    this: Expression::Literal(Literal::String(value.clone())),
818                    to: c.to,
819                    trailing_comments: Vec::new(),
820                    double_colon_syntax: false,
821                    format: None,
822                    default: None,
823                    inferred_type: None,
824                }));
825                // Create outer cast: CAST(inner_cast AS TIMESTAMP)
826                Ok(Expression::Cast(Box::new(Cast {
827                    this: inner_cast,
828                    to: DataType::Timestamp {
829                        precision: None,
830                        timezone: false,
831                    },
832                    trailing_comments: c.trailing_comments,
833                    double_colon_syntax: false,
834                    format: None,
835                    default: None,
836                    inferred_type: None,
837                })))
838            }
839            // DATE 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS DATE)
840            Expression::Literal(Literal::Date(value)) => {
841                let inner_cast = Expression::Cast(Box::new(Cast {
842                    this: Expression::Literal(Literal::String(value.clone())),
843                    to: c.to,
844                    trailing_comments: Vec::new(),
845                    double_colon_syntax: false,
846                    format: None,
847                    default: None,
848                    inferred_type: None,
849                }));
850                Ok(Expression::Cast(Box::new(Cast {
851                    this: inner_cast,
852                    to: DataType::Date,
853                    trailing_comments: c.trailing_comments,
854                    double_colon_syntax: false,
855                    format: None,
856                    default: None,
857                    inferred_type: None,
858                })))
859            }
860            // TIME 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIME)
861            Expression::Literal(Literal::Time(value)) => {
862                let inner_cast = Expression::Cast(Box::new(Cast {
863                    this: Expression::Literal(Literal::String(value.clone())),
864                    to: c.to,
865                    trailing_comments: Vec::new(),
866                    double_colon_syntax: false,
867                    format: None,
868                    default: None,
869                    inferred_type: None,
870                }));
871                Ok(Expression::Cast(Box::new(Cast {
872                    this: inner_cast,
873                    to: DataType::Time {
874                        precision: None,
875                        timezone: false,
876                    },
877                    trailing_comments: c.trailing_comments,
878                    double_colon_syntax: false,
879                    format: None,
880                    default: None,
881                    inferred_type: None,
882                })))
883            }
884            // For all other cases, pass through the Cast unchanged
885            _ => Ok(Expression::Cast(Box::new(c))),
886        }
887    }
888
889    /// Check if an expression is a CAST to TIMESTAMP
890    fn is_cast_to_timestamp(&self, expr: &Expression) -> bool {
891        if let Expression::Cast(cast) = expr {
892            matches!(cast.to, DataType::Timestamp { .. })
893        } else {
894            false
895        }
896    }
897
898    /// Helper to uppercase the first argument if it's an identifier or column (for DATEDIFF, DATEADD units)
899    fn uppercase_first_arg_if_identifier(&self, mut args: Vec<Expression>) -> Vec<Expression> {
900        use crate::expressions::Identifier;
901        if !args.is_empty() {
902            match &args[0] {
903                Expression::Identifier(id) => {
904                    args[0] = Expression::Identifier(Identifier {
905                        name: id.name.to_uppercase(),
906                        quoted: id.quoted,
907                        trailing_comments: id.trailing_comments.clone(),
908                        span: None,
909                    });
910                }
911                Expression::Column(col) if col.table.is_none() => {
912                    // Unqualified column name like "day" should be treated as a unit
913                    args[0] = Expression::Identifier(Identifier {
914                        name: col.name.name.to_uppercase(),
915                        quoted: col.name.quoted,
916                        trailing_comments: col.name.trailing_comments.clone(),
917                        span: None,
918                    });
919                }
920                _ => {}
921            }
922        }
923        args
924    }
925}
926
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dialect;

    /// Parses `sql` with the Databricks dialect, transforms the first statement,
    /// regenerates SQL from it, and asserts the result equals `expected`. On a mismatch
    /// the failure message is `context`.
    fn assert_databricks_roundtrip(sql: &str, expected: &str, context: &str) {
        let dialect = Dialect::get(DialectType::Databricks);
        let statements = dialect.parse(sql).expect("Parse failed");
        let rewritten = dialect
            .transform(statements[0].clone())
            .expect("Transform failed");
        let generated = dialect.generate(&rewritten).expect("Generate failed");
        assert_eq!(generated, expected, "{}", context);
    }

    #[test]
    fn test_timestamp_literal_cast() {
        // Identity fixture [47]: a `::` cast of a TIMESTAMP literal becomes nested CASTs —
        // inner to the requested type, outer back to TIMESTAMP.
        assert_databricks_roundtrip(
            "SELECT TIMESTAMP '2025-04-29 18.47.18'::DATE",
            "SELECT CAST(CAST('2025-04-29 18.47.18' AS DATE) AS TIMESTAMP)",
            "Timestamp literal cast transformation failed",
        );
    }

    #[test]
    fn test_from_utc_timestamp_wraps_column() {
        // Fixture [48]: a bare column passed to FROM_UTC_TIMESTAMP gets wrapped in
        // CAST(... AS TIMESTAMP).
        assert_databricks_roundtrip(
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(foo, 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "FROM_UTC_TIMESTAMP transformation failed",
        );
    }

    #[test]
    fn test_from_utc_timestamp_keeps_existing_cast() {
        // Fixture [50]: an argument already cast to TIMESTAMP is not double-wrapped; only
        // the `::` spelling is normalised to CAST(...).
        assert_databricks_roundtrip(
            "FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz)",
            "FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)",
            "FROM_UTC_TIMESTAMP with existing CAST failed",
        );
    }
}