Skip to main content

polyglot_sql/dialects/
databricks.rs

1//! Databricks Dialect
2//!
3//! Databricks-specific transformations based on sqlglot patterns.
4//! Databricks extends Spark SQL with additional features:
5//! - Colon operator for JSON extraction (col:path)
6//! - DATEADD/DATEDIFF with specific syntax
7//! - NULL type mapped to VOID
8//! - Native REGEXP_LIKE and TRY_CAST support
9
10use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{AggregateFunction, Cast, DataType, Expression, Function, JSONExtract, Literal, UnaryFunc, VarArgFunc};
13use crate::generator::GeneratorConfig;
14use crate::tokens::TokenizerConfig;
15
16/// Databricks dialect
17pub struct DatabricksDialect;
18
19impl DialectImpl for DatabricksDialect {
20    fn dialect_type(&self) -> DialectType {
21        DialectType::Databricks
22    }
23
24    fn tokenizer_config(&self) -> TokenizerConfig {
25        let mut config = TokenizerConfig::default();
26        // Databricks uses backticks for identifiers (NOT double quotes)
27        config.identifiers.clear();
28        config.identifiers.insert('`', '`');
29        // Databricks (like Hive/Spark) uses double quotes as string delimiters
30        config.quotes.insert("\"".to_string(), "\"".to_string());
31        // Databricks uses backslash escapes in strings (inherited from Hive/Spark)
32        config.string_escapes.push('\\');
33        // Databricks supports DIV keyword for integer division
34        config.keywords.insert("DIV".to_string(), crate::tokens::TokenType::Div);
35        // Databricks numeric literal suffixes (same as Hive/Spark)
36        config.numeric_literals.insert("L".to_string(), "BIGINT".to_string());
37        config.numeric_literals.insert("S".to_string(), "SMALLINT".to_string());
38        config.numeric_literals.insert("Y".to_string(), "TINYINT".to_string());
39        config.numeric_literals.insert("D".to_string(), "DOUBLE".to_string());
40        config.numeric_literals.insert("F".to_string(), "FLOAT".to_string());
41        config.numeric_literals.insert("BD".to_string(), "DECIMAL".to_string());
42        // Databricks allows identifiers to start with digits (like Hive/Spark)
43        config.identifiers_can_start_with_digit = true;
44        // Databricks (like Spark): STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
45        // Backslashes in raw strings are always literal (no escape processing)
46        config.string_escapes_allowed_in_raw_strings = false;
47        config
48    }
49
50    fn generator_config(&self) -> GeneratorConfig {
51        use crate::generator::IdentifierQuoteStyle;
52        GeneratorConfig {
53            identifier_quote: '`',
54            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
55            dialect: Some(DialectType::Databricks),
56            struct_field_sep: ": ",
57            create_function_return_as: false,
58            tablesample_seed_keyword: "REPEATABLE",
59            identifiers_can_start_with_digit: true,
60            // Databricks uses COMMENT 'value' without = sign
61            schema_comment_with_eq: false,
62            ..Default::default()
63        }
64    }
65
66    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
67        match expr {
68            // IFNULL -> COALESCE in Databricks
69            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
70                expressions: vec![f.this, f.expression],
71            }))),
72
73            // NVL -> COALESCE in Databricks
74            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
75                expressions: vec![f.this, f.expression],
76            }))),
77
78            // TryCast is native in Databricks
79            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
80
81            // SafeCast -> TRY_CAST in Databricks
82            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
83
84            // ILIKE is native in Databricks (Spark 3+)
85            Expression::ILike(op) => Ok(Expression::ILike(op)),
86
87            // UNNEST -> EXPLODE in Databricks
88            Expression::Unnest(f) => Ok(Expression::Explode(Box::new(UnaryFunc::new(f.this)))),
89
90            // EXPLODE is native to Databricks
91            Expression::Explode(f) => Ok(Expression::Explode(f)),
92
93            // ExplodeOuter is supported
94            Expression::ExplodeOuter(f) => Ok(Expression::ExplodeOuter(f)),
95
96            // RANDOM -> RAND in Databricks
97            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
98                seed: None, lower: None, upper: None,
99            }))),
100
101            // Rand is native
102            Expression::Rand(r) => Ok(Expression::Rand(r)),
103
104            // || (Concat) -> CONCAT in Databricks
105            Expression::Concat(op) => Ok(Expression::Function(Box::new(Function::new(
106                "CONCAT".to_string(),
107                vec![op.left, op.right],
108            )))),
109
110            // RegexpLike is native in Databricks
111            Expression::RegexpLike(op) => Ok(Expression::RegexpLike(op)),
112
113            // Cast with typed literal: TIMESTAMP 'x'::TYPE -> CAST(CAST('x' AS TYPE) AS TIMESTAMP)
114            // This is a complex sqlglot transformation where:
115            // 1. The inner typed literal (e.g., TIMESTAMP 'x') becomes CAST('x' AS <target_type>)
116            // 2. The outer result is wrapped in CAST(... AS <original_literal_type>)
117            Expression::Cast(c) => self.transform_cast(*c),
118
119            // Generic function transformations
120            Expression::Function(f) => self.transform_function(*f),
121
122            // Generic aggregate function transformations
123            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
124
125            // DateSub -> DATE_ADD(date, -val) in Databricks
126            Expression::DateSub(f) => {
127                // Convert string literals to numbers (interval values are often stored as strings)
128                let val = match f.interval {
129                    Expression::Literal(crate::expressions::Literal::String(s)) if s.parse::<i64>().is_ok() => {
130                        Expression::Literal(crate::expressions::Literal::Number(s))
131                    }
132                    other => other,
133                };
134                let neg_val = Expression::Neg(Box::new(crate::expressions::UnaryOp { this: val }));
135                Ok(Expression::Function(Box::new(Function::new(
136                    "DATE_ADD".to_string(),
137                    vec![f.this, neg_val],
138                ))))
139            }
140
141            // Pass through everything else
142            _ => Ok(expr),
143        }
144    }
145}
146
147impl DatabricksDialect {
148    fn transform_function(&self, f: Function) -> Result<Expression> {
149        let name_upper = f.name.to_uppercase();
150        match name_upper.as_str() {
151            // IFNULL -> COALESCE
152            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
153                expressions: f.args,
154            }))),
155
156            // NVL -> COALESCE
157            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
158                expressions: f.args,
159            }))),
160
161            // ISNULL -> COALESCE
162            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc { original_name: None,
163                expressions: f.args,
164            }))),
165
166            // ROW -> STRUCT (no auto-naming for cross-dialect conversion)
167            "ROW" => Ok(Expression::Function(Box::new(Function::new(
168                "STRUCT".to_string(),
169                f.args,
170            )))),
171
172            // GETDATE -> CURRENT_TIMESTAMP
173            "GETDATE" => Ok(Expression::CurrentTimestamp(
174                crate::expressions::CurrentTimestamp { precision: None, sysdate: false },
175            )),
176
177            // NOW -> CURRENT_TIMESTAMP
178            "NOW" => Ok(Expression::CurrentTimestamp(
179                crate::expressions::CurrentTimestamp { precision: None, sysdate: false },
180            )),
181
182            // CURDATE -> CURRENT_DATE
183            "CURDATE" => Ok(Expression::CurrentDate(crate::expressions::CurrentDate)),
184
185            // CURRENT_DATE() with parens -> CURRENT_DATE (no parens)
186            "CURRENT_DATE" if f.args.is_empty() => {
187                Ok(Expression::CurrentDate(crate::expressions::CurrentDate))
188            }
189
190            // RANDOM -> RAND
191            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
192                seed: None, lower: None, upper: None,
193            }))),
194
195            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
196            "GROUP_CONCAT" if !f.args.is_empty() => {
197                let mut args = f.args;
198                let first = args.remove(0);
199                let separator = args.pop();
200                let collect_list = Expression::Function(Box::new(Function::new(
201                    "COLLECT_LIST".to_string(),
202                    vec![first],
203                )));
204                if let Some(sep) = separator {
205                    Ok(Expression::Function(Box::new(Function::new(
206                        "ARRAY_JOIN".to_string(),
207                        vec![collect_list, sep],
208                    ))))
209                } else {
210                    Ok(Expression::Function(Box::new(Function::new(
211                        "ARRAY_JOIN".to_string(),
212                        vec![collect_list],
213                    ))))
214                }
215            }
216
217            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN in Databricks
218            "STRING_AGG" if !f.args.is_empty() => {
219                let mut args = f.args;
220                let first = args.remove(0);
221                let separator = args.pop();
222                let collect_list = Expression::Function(Box::new(Function::new(
223                    "COLLECT_LIST".to_string(),
224                    vec![first],
225                )));
226                if let Some(sep) = separator {
227                    Ok(Expression::Function(Box::new(Function::new(
228                        "ARRAY_JOIN".to_string(),
229                        vec![collect_list, sep],
230                    ))))
231                } else {
232                    Ok(Expression::Function(Box::new(Function::new(
233                        "ARRAY_JOIN".to_string(),
234                        vec![collect_list],
235                    ))))
236                }
237            }
238
239            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
240            "LISTAGG" if !f.args.is_empty() => {
241                let mut args = f.args;
242                let first = args.remove(0);
243                let separator = args.pop();
244                let collect_list = Expression::Function(Box::new(Function::new(
245                    "COLLECT_LIST".to_string(),
246                    vec![first],
247                )));
248                if let Some(sep) = separator {
249                    Ok(Expression::Function(Box::new(Function::new(
250                        "ARRAY_JOIN".to_string(),
251                        vec![collect_list, sep],
252                    ))))
253                } else {
254                    Ok(Expression::Function(Box::new(Function::new(
255                        "ARRAY_JOIN".to_string(),
256                        vec![collect_list],
257                    ))))
258                }
259            }
260
261            // ARRAY_AGG -> COLLECT_LIST in Databricks
262            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
263                Function::new("COLLECT_LIST".to_string(), f.args),
264            ))),
265
266            // SUBSTR -> SUBSTRING
267            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
268                "SUBSTRING".to_string(),
269                f.args,
270            )))),
271
272            // LEN -> LENGTH
273            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
274                f.args.into_iter().next().unwrap(),
275            )))),
276
277            // CHARINDEX -> LOCATE (with swapped args, like Spark)
278            "CHARINDEX" if f.args.len() >= 2 => {
279                let mut args = f.args;
280                let substring = args.remove(0);
281                let string = args.remove(0);
282                // LOCATE(substring, string)
283                Ok(Expression::Function(Box::new(Function::new(
284                    "LOCATE".to_string(),
285                    vec![substring, string],
286                ))))
287            }
288
289            // POSITION -> LOCATE
290            "POSITION" if f.args.len() == 2 => {
291                let args = f.args;
292                Ok(Expression::Function(Box::new(Function::new(
293                    "LOCATE".to_string(),
294                    args,
295                ))))
296            }
297
298            // STRPOS -> LOCATE (with same arg order)
299            "STRPOS" if f.args.len() == 2 => {
300                let args = f.args;
301                let string = args[0].clone();
302                let substring = args[1].clone();
303                // STRPOS(string, substring) -> LOCATE(substring, string)
304                Ok(Expression::Function(Box::new(Function::new(
305                    "LOCATE".to_string(),
306                    vec![substring, string],
307                ))))
308            }
309
310            // INSTR is native in Databricks
311            "INSTR" => Ok(Expression::Function(Box::new(f))),
312
313            // LOCATE is native in Databricks
314            "LOCATE" => Ok(Expression::Function(Box::new(f))),
315
316            // ARRAY_LENGTH -> SIZE
317            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
318                Function::new("SIZE".to_string(), f.args),
319            ))),
320
321            // CARDINALITY -> SIZE
322            "CARDINALITY" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
323                Function::new("SIZE".to_string(), f.args),
324            ))),
325
326            // SIZE is native
327            "SIZE" => Ok(Expression::Function(Box::new(f))),
328
329            // ARRAY_CONTAINS is native in Databricks
330            "ARRAY_CONTAINS" => Ok(Expression::Function(Box::new(f))),
331
332            // CONTAINS -> ARRAY_CONTAINS in Databricks (for array operations)
333            // But keep CONTAINS for string contains (from CONTAINS_SUBSTR transpilation)
334            "CONTAINS" if f.args.len() == 2 => {
335                // Check if this is a string CONTAINS (LOWER() args pattern from CONTAINS_SUBSTR)
336                let is_string_contains = matches!(&f.args[0], Expression::Lower(_)) && matches!(&f.args[1], Expression::Lower(_));
337                if is_string_contains {
338                    Ok(Expression::Function(Box::new(f)))
339                } else {
340                    Ok(Expression::Function(Box::new(Function::new(
341                        "ARRAY_CONTAINS".to_string(),
342                        f.args,
343                    ))))
344                }
345            }
346
347            // TO_DATE is native in Databricks
348            "TO_DATE" => Ok(Expression::Function(Box::new(f))),
349
350            // TO_TIMESTAMP is native in Databricks
351            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
352
353            // DATE_FORMAT is native in Databricks
354            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
355
356            // strftime -> DATE_FORMAT in Databricks
357            "STRFTIME" if f.args.len() >= 2 => {
358                let mut args = f.args;
359                let format = args.remove(0);
360                let date = args.remove(0);
361                Ok(Expression::Function(Box::new(Function::new(
362                    "DATE_FORMAT".to_string(),
363                    vec![date, format],
364                ))))
365            }
366
367            // TO_CHAR is supported natively in Databricks (unlike Spark)
368            "TO_CHAR" => Ok(Expression::Function(Box::new(f))),
369
370            // DATE_TRUNC is native in Databricks
371            "DATE_TRUNC" => Ok(Expression::Function(Box::new(f))),
372
373            // DATEADD is native in Databricks - uppercase the unit if present
374            "DATEADD" => {
375                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
376                Ok(Expression::Function(Box::new(Function::new(
377                    "DATEADD".to_string(),
378                    transformed_args,
379                ))))
380            }
381
382            // DATE_ADD -> DATEADD in Databricks (2-arg form only)
383            // 2-arg with interval: DATE_ADD(date, interval) -> DATEADD(DAY, interval, date)
384            // 2-arg with number: DATE_ADD(date, -2) -> keep as DATE_ADD(date, -2)
385            // 3-arg: DATE_ADD(unit, amount, date) -> keep as DATE_ADD(UNIT, amount, date)
386            "DATE_ADD" => {
387                if f.args.len() == 2 {
388                    let is_simple_number = matches!(&f.args[1],
389                        Expression::Literal(crate::expressions::Literal::Number(_))
390                        | Expression::Neg(_)
391                    );
392                    if is_simple_number {
393                        // Keep as DATE_ADD(date, num_days)
394                        Ok(Expression::Function(Box::new(Function::new(
395                            "DATE_ADD".to_string(),
396                            f.args,
397                        ))))
398                    } else {
399                        let mut args = f.args;
400                        let date = args.remove(0);
401                        let interval = args.remove(0);
402                        let unit = Expression::Identifier(crate::expressions::Identifier {
403                            name: "DAY".to_string(),
404                            quoted: false,
405                            trailing_comments: Vec::new(),
406                        });
407                        Ok(Expression::Function(Box::new(Function::new(
408                            "DATEADD".to_string(),
409                            vec![unit, interval, date],
410                        ))))
411                    }
412                } else {
413                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
414                    Ok(Expression::Function(Box::new(Function::new(
415                        "DATE_ADD".to_string(),
416                        transformed_args,
417                    ))))
418                }
419            }
420
421            // DATEDIFF is native in Databricks - uppercase the unit if present
422            // 2-arg: DATEDIFF(end, start) -> DATEDIFF(DAY, start, end)
423            // 3-arg: DATEDIFF(unit, start, end) -> DATEDIFF(UNIT, start, end)
424            "DATEDIFF" => {
425                if f.args.len() == 2 {
426                    let mut args = f.args;
427                    let end_date = args.remove(0);
428                    let start_date = args.remove(0);
429                    let unit = Expression::Identifier(crate::expressions::Identifier {
430                        name: "DAY".to_string(),
431                        quoted: false,
432                        trailing_comments: Vec::new(),
433                    });
434                    Ok(Expression::Function(Box::new(Function::new(
435                        "DATEDIFF".to_string(),
436                        vec![unit, start_date, end_date],
437                    ))))
438                } else {
439                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
440                    Ok(Expression::Function(Box::new(Function::new(
441                        "DATEDIFF".to_string(),
442                        transformed_args,
443                    ))))
444                }
445            }
446
447            // DATE_DIFF -> DATEDIFF with uppercased unit
448            "DATE_DIFF" => {
449                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
450                Ok(Expression::Function(Box::new(Function::new(
451                    "DATEDIFF".to_string(),
452                    transformed_args,
453                ))))
454            }
455
456            // JSON_EXTRACT -> Use colon operator in generation, but keep as function for now
457            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
458
459            // JSON_EXTRACT_SCALAR -> same handling
460            "JSON_EXTRACT_SCALAR" => Ok(Expression::Function(Box::new(f))),
461
462            // GET_JSON_OBJECT -> colon syntax in Databricks
463            // GET_JSON_OBJECT(col, '$.path') becomes col:path
464            "GET_JSON_OBJECT" if f.args.len() == 2 => {
465                let mut args = f.args;
466                let col = args.remove(0);
467                let path_arg = args.remove(0);
468
469                // Extract and strip the $. prefix from the path
470                let path_expr = match &path_arg {
471                    Expression::Literal(crate::expressions::Literal::String(s)) => {
472                        // Strip leading '$.' if present
473                        let stripped = if s.starts_with("$.") {
474                            &s[2..]
475                        } else if s.starts_with("$") {
476                            &s[1..]
477                        } else {
478                            s.as_str()
479                        };
480                        Expression::Literal(crate::expressions::Literal::String(stripped.to_string()))
481                    }
482                    _ => path_arg,
483                };
484
485                Ok(Expression::JSONExtract(Box::new(JSONExtract {
486                    this: Box::new(col),
487                    expression: Box::new(path_expr),
488                    only_json_types: None,
489                    expressions: Vec::new(),
490                    variant_extract: Some(Box::new(Expression::true_())),
491                    json_query: None,
492                    option: None,
493                    quote: None,
494                    on_condition: None,
495                    requires_json: None,
496                })))
497            }
498
499            // FROM_JSON is native in Databricks
500            "FROM_JSON" => Ok(Expression::Function(Box::new(f))),
501
502            // PARSE_JSON is native in Databricks
503            "PARSE_JSON" => Ok(Expression::Function(Box::new(f))),
504
505            // COLLECT_LIST is native in Databricks
506            "COLLECT_LIST" => Ok(Expression::Function(Box::new(f))),
507
508            // COLLECT_SET is native in Databricks
509            "COLLECT_SET" => Ok(Expression::Function(Box::new(f))),
510
511            // RLIKE is native in Databricks
512            "RLIKE" => Ok(Expression::Function(Box::new(f))),
513
514            // REGEXP -> RLIKE in Databricks
515            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
516                "RLIKE".to_string(),
517                f.args,
518            )))),
519
520            // REGEXP_LIKE is native in Databricks
521            "REGEXP_LIKE" => Ok(Expression::Function(Box::new(f))),
522
523            // LEVENSHTEIN is native in Databricks
524            "LEVENSHTEIN" => Ok(Expression::Function(Box::new(f))),
525
526            // SEQUENCE is native (for GENERATE_SERIES)
527            "GENERATE_SERIES" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
528                Function::new("SEQUENCE".to_string(), f.args),
529            ))),
530
531            // SEQUENCE is native
532            "SEQUENCE" => Ok(Expression::Function(Box::new(f))),
533
534            // FLATTEN is native in Databricks
535            "FLATTEN" => Ok(Expression::Function(Box::new(f))),
536
537            // ARRAY_SORT is native
538            "ARRAY_SORT" => Ok(Expression::Function(Box::new(f))),
539
540            // ARRAY_DISTINCT is native
541            "ARRAY_DISTINCT" => Ok(Expression::Function(Box::new(f))),
542
543            // TRANSFORM is native (for array transformation)
544            "TRANSFORM" => Ok(Expression::Function(Box::new(f))),
545
546            // FILTER is native (for array filtering)
547            "FILTER" => Ok(Expression::Function(Box::new(f))),
548
549            // FROM_UTC_TIMESTAMP - wrap first argument in CAST(... AS TIMESTAMP) if not already
550            "FROM_UTC_TIMESTAMP" if f.args.len() >= 2 => {
551                let mut args = f.args;
552                let first_arg = args.remove(0);
553
554                // Check if first arg is already a Cast to TIMESTAMP
555                let wrapped_arg = if self.is_cast_to_timestamp(&first_arg) {
556                    first_arg
557                } else {
558                    // Wrap in CAST(... AS TIMESTAMP)
559                    Expression::Cast(Box::new(Cast {
560                        this: first_arg,
561                        to: DataType::Timestamp { precision: None, timezone: false },
562                        trailing_comments: Vec::new(),
563                        double_colon_syntax: false,
564                        format: None,
565                        default: None,
566                    }))
567                };
568
569                let mut new_args = vec![wrapped_arg];
570                new_args.extend(args);
571
572                Ok(Expression::Function(Box::new(Function::new(
573                    "FROM_UTC_TIMESTAMP".to_string(),
574                    new_args,
575                ))))
576            }
577
578            // UNIFORM(low, high, RANDOM(seed)) -> UNIFORM(low, high, seed) or UNIFORM(low, high)
579            "UNIFORM" if f.args.len() == 3 => {
580                let mut args = f.args;
581                let low = args.remove(0);
582                let high = args.remove(0);
583                let gen = args.remove(0);
584                match gen {
585                    Expression::Function(func) if func.name.to_uppercase() == "RANDOM" => {
586                        if func.args.len() == 1 {
587                            // RANDOM(seed) -> extract seed
588                            let seed = func.args.into_iter().next().unwrap();
589                            Ok(Expression::Function(Box::new(Function::new(
590                                "UNIFORM".to_string(),
591                                vec![low, high, seed],
592                            ))))
593                        } else {
594                            // RANDOM() -> drop gen arg
595                            Ok(Expression::Function(Box::new(Function::new(
596                                "UNIFORM".to_string(),
597                                vec![low, high],
598                            ))))
599                        }
600                    }
601                    Expression::Rand(r) => {
602                        if let Some(seed) = r.seed {
603                            Ok(Expression::Function(Box::new(Function::new(
604                                "UNIFORM".to_string(),
605                                vec![low, high, *seed],
606                            ))))
607                        } else {
608                            Ok(Expression::Function(Box::new(Function::new(
609                                "UNIFORM".to_string(),
610                                vec![low, high],
611                            ))))
612                        }
613                    }
614                    _ => {
615                        Ok(Expression::Function(Box::new(Function::new(
616                            "UNIFORM".to_string(),
617                            vec![low, high, gen],
618                        ))))
619                    }
620                }
621            }
622
623            // REGEXP_SUBSTR(subject, pattern, ...) -> REGEXP_EXTRACT(subject, pattern)
624            "REGEXP_SUBSTR" if f.args.len() >= 2 => {
625                let subject = f.args[0].clone();
626                let pattern = f.args[1].clone();
627                Ok(Expression::Function(Box::new(Function::new(
628                    "REGEXP_EXTRACT".to_string(),
629                    vec![subject, pattern],
630                ))))
631            }
632
633            // Pass through everything else
634            _ => Ok(Expression::Function(Box::new(f))),
635        }
636    }
637
638    fn transform_aggregate_function(
639        &self,
640        f: Box<crate::expressions::AggregateFunction>,
641    ) -> Result<Expression> {
642        let name_upper = f.name.to_uppercase();
643        match name_upper.as_str() {
644            // COUNT_IF is native in Databricks (Spark 3+)
645            "COUNT_IF" => Ok(Expression::AggregateFunction(f)),
646
647            // ANY_VALUE is native in Databricks (Spark 3+)
648            "ANY_VALUE" => Ok(Expression::AggregateFunction(f)),
649
650            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
651            "GROUP_CONCAT" if !f.args.is_empty() => {
652                let mut args = f.args;
653                let first = args.remove(0);
654                let separator = args.pop();
655                let collect_list = Expression::Function(Box::new(Function::new(
656                    "COLLECT_LIST".to_string(),
657                    vec![first],
658                )));
659                if let Some(sep) = separator {
660                    Ok(Expression::Function(Box::new(Function::new(
661                        "ARRAY_JOIN".to_string(),
662                        vec![collect_list, sep],
663                    ))))
664                } else {
665                    Ok(Expression::Function(Box::new(Function::new(
666                        "ARRAY_JOIN".to_string(),
667                        vec![collect_list],
668                    ))))
669                }
670            }
671
672            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN
673            "STRING_AGG" if !f.args.is_empty() => {
674                let mut args = f.args;
675                let first = args.remove(0);
676                let separator = args.pop();
677                let collect_list = Expression::Function(Box::new(Function::new(
678                    "COLLECT_LIST".to_string(),
679                    vec![first],
680                )));
681                if let Some(sep) = separator {
682                    Ok(Expression::Function(Box::new(Function::new(
683                        "ARRAY_JOIN".to_string(),
684                        vec![collect_list, sep],
685                    ))))
686                } else {
687                    Ok(Expression::Function(Box::new(Function::new(
688                        "ARRAY_JOIN".to_string(),
689                        vec![collect_list],
690                    ))))
691                }
692            }
693
694            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
695            "LISTAGG" if !f.args.is_empty() => {
696                let mut args = f.args;
697                let first = args.remove(0);
698                let separator = args.pop();
699                let collect_list = Expression::Function(Box::new(Function::new(
700                    "COLLECT_LIST".to_string(),
701                    vec![first],
702                )));
703                if let Some(sep) = separator {
704                    Ok(Expression::Function(Box::new(Function::new(
705                        "ARRAY_JOIN".to_string(),
706                        vec![collect_list, sep],
707                    ))))
708                } else {
709                    Ok(Expression::Function(Box::new(Function::new(
710                        "ARRAY_JOIN".to_string(),
711                        vec![collect_list],
712                    ))))
713                }
714            }
715
716            // ARRAY_AGG -> COLLECT_LIST
717            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
718                Function::new("COLLECT_LIST".to_string(), f.args),
719            ))),
720
721            // STDDEV is native in Databricks
722            "STDDEV" => Ok(Expression::AggregateFunction(f)),
723
724            // VARIANCE is native in Databricks
725            "VARIANCE" => Ok(Expression::AggregateFunction(f)),
726
727            // APPROX_COUNT_DISTINCT is native in Databricks
728            "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
729
730            // APPROX_DISTINCT -> APPROX_COUNT_DISTINCT
731            "APPROX_DISTINCT" if !f.args.is_empty() => Ok(Expression::AggregateFunction(Box::new(
732                AggregateFunction {
733                    name: "APPROX_COUNT_DISTINCT".to_string(),
734                    args: f.args,
735                    distinct: f.distinct,
736                    filter: f.filter,
737                    order_by: Vec::new(),
738                    limit: None,
739                    ignore_nulls: None,
740                },
741            ))),
742
743            // Pass through everything else
744            _ => Ok(Expression::AggregateFunction(f)),
745        }
746    }
747
748    /// Transform Cast expressions - handles typed literals being cast
749    ///
750    /// When we have a typed literal (TIMESTAMP 'x', DATE 'x', TIME 'x') being cast to another type,
751    /// Databricks/Spark transforms it as follows:
752    ///
753    /// `TIMESTAMP 'x'::TYPE` -> `CAST(CAST('x' AS TYPE) AS TIMESTAMP)`
754    ///
755    /// This reverses the types - the inner cast is to the target type,
756    /// the outer cast is to the original literal type.
757    fn transform_cast(&self, c: Cast) -> Result<Expression> {
758        // Check if the inner expression is a typed literal
759        match &c.this {
760            // TIMESTAMP 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIMESTAMP)
761            Expression::Literal(Literal::Timestamp(value)) => {
762                // Create inner cast: CAST('value' AS target_type)
763                let inner_cast = Expression::Cast(Box::new(Cast {
764                    this: Expression::Literal(Literal::String(value.clone())),
765                    to: c.to,
766                    trailing_comments: Vec::new(),
767                    double_colon_syntax: false,
768                    format: None,
769                    default: None,
770                }));
771                // Create outer cast: CAST(inner_cast AS TIMESTAMP)
772                Ok(Expression::Cast(Box::new(Cast {
773                    this: inner_cast,
774                    to: DataType::Timestamp { precision: None, timezone: false },
775                    trailing_comments: c.trailing_comments,
776                    double_colon_syntax: false,
777                    format: None,
778                    default: None,
779                })))
780            }
781            // DATE 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS DATE)
782            Expression::Literal(Literal::Date(value)) => {
783                let inner_cast = Expression::Cast(Box::new(Cast {
784                    this: Expression::Literal(Literal::String(value.clone())),
785                    to: c.to,
786                    trailing_comments: Vec::new(),
787                    double_colon_syntax: false,
788                    format: None,
789                    default: None,
790                }));
791                Ok(Expression::Cast(Box::new(Cast {
792                    this: inner_cast,
793                    to: DataType::Date,
794                    trailing_comments: c.trailing_comments,
795                    double_colon_syntax: false,
796                    format: None,
797                    default: None,
798                })))
799            }
800            // TIME 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIME)
801            Expression::Literal(Literal::Time(value)) => {
802                let inner_cast = Expression::Cast(Box::new(Cast {
803                    this: Expression::Literal(Literal::String(value.clone())),
804                    to: c.to,
805                    trailing_comments: Vec::new(),
806                    double_colon_syntax: false,
807                    format: None,
808                    default: None,
809                }));
810                Ok(Expression::Cast(Box::new(Cast {
811                    this: inner_cast,
812                    to: DataType::Time { precision: None, timezone: false },
813                    trailing_comments: c.trailing_comments,
814                    double_colon_syntax: false,
815                    format: None,
816                    default: None,
817                })))
818            }
819            // For all other cases, pass through the Cast unchanged
820            _ => Ok(Expression::Cast(Box::new(c))),
821        }
822    }
823
824    /// Check if an expression is a CAST to TIMESTAMP
825    fn is_cast_to_timestamp(&self, expr: &Expression) -> bool {
826        if let Expression::Cast(cast) = expr {
827            matches!(cast.to, DataType::Timestamp { .. })
828        } else {
829            false
830        }
831    }
832
833    /// Helper to uppercase the first argument if it's an identifier or column (for DATEDIFF, DATEADD units)
834    fn uppercase_first_arg_if_identifier(&self, mut args: Vec<Expression>) -> Vec<Expression> {
835        use crate::expressions::Identifier;
836        if !args.is_empty() {
837            match &args[0] {
838                Expression::Identifier(id) => {
839                    args[0] = Expression::Identifier(Identifier {
840                        name: id.name.to_uppercase(),
841                        quoted: id.quoted,
842                        trailing_comments: id.trailing_comments.clone(),
843                    });
844                }
845                Expression::Column(col) if col.table.is_none() => {
846                    // Unqualified column name like "day" should be treated as a unit
847                    args[0] = Expression::Identifier(Identifier {
848                        name: col.name.name.to_uppercase(),
849                        quoted: col.name.quoted,
850                        trailing_comments: col.name.trailing_comments.clone(),
851                    });
852                }
853                _ => {}
854            }
855        }
856        args
857    }
858}
859
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dialect;

    /// Run `sql` through the Databricks dialect (parse, transform the first
    /// parsed statement, generate) and assert the generated SQL equals
    /// `expected`, reporting `failure` if it does not.
    fn assert_databricks_output(sql: &str, expected: &str, failure: &str) {
        let dialect = Dialect::get(DialectType::Databricks);
        let statements = dialect.parse(sql).expect("Parse failed");
        let rewritten = dialect
            .transform(statements[0].clone())
            .expect("Transform failed");
        let generated = dialect.generate(&rewritten).expect("Generate failed");

        assert_eq!(generated, expected, "{}", failure);
    }

    /// Test [47] of the Databricks dialect identity fixtures:
    /// TIMESTAMP 'value'::DATE -> CAST(CAST('value' AS DATE) AS TIMESTAMP)
    #[test]
    fn test_timestamp_literal_cast() {
        assert_databricks_output(
            "SELECT TIMESTAMP '2025-04-29 18.47.18'::DATE",
            "SELECT CAST(CAST('2025-04-29 18.47.18' AS DATE) AS TIMESTAMP)",
            "Timestamp literal cast transformation failed",
        );
    }

    /// Test [48]:
    /// FROM_UTC_TIMESTAMP(foo, 'timezone') -> FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'timezone')
    #[test]
    fn test_from_utc_timestamp_wraps_column() {
        assert_databricks_output(
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(foo, 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "FROM_UTC_TIMESTAMP transformation failed",
        );
    }

    /// Test [50]:
    /// FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz) -> FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)
    /// An argument already cast to TIMESTAMP is kept, with `::` rendered as CAST().
    #[test]
    fn test_from_utc_timestamp_keeps_existing_cast() {
        assert_databricks_output(
            "FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz)",
            "FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)",
            "FROM_UTC_TIMESTAMP with existing CAST failed",
        );
    }
}