Skip to main content

polyglot_sql/dialects/
databricks.rs

1//! Databricks Dialect
2//!
3//! Databricks-specific transformations based on sqlglot patterns.
4//! Databricks extends Spark SQL with additional features:
5//! - Colon operator for JSON extraction (col:path)
6//! - DATEADD/DATEDIFF with specific syntax
7//! - NULL type mapped to VOID
8//! - Native REGEXP_LIKE and TRY_CAST support
9
10use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{
13    AggregateFunction, Cast, DataType, Expression, Function, JSONExtract, Literal, UnaryFunc,
14    VarArgFunc,
15};
16#[cfg(feature = "generate")]
17use crate::generator::GeneratorConfig;
18use crate::tokens::TokenizerConfig;
19
/// Databricks dialect.
///
/// A field-less marker type: all dialect behaviour is supplied through its
/// [`DialectImpl`] implementation (tokenizer settings, and — behind the
/// `generate` / `transpile` features — generator settings and expression
/// rewrites).
pub struct DatabricksDialect;
22
23impl DialectImpl for DatabricksDialect {
    /// Identifies this implementation as [`DialectType::Databricks`].
    fn dialect_type(&self) -> DialectType {
        DialectType::Databricks
    }
27
28    fn tokenizer_config(&self) -> TokenizerConfig {
29        let mut config = TokenizerConfig::default();
30        // Databricks uses backticks for identifiers (NOT double quotes)
31        config.identifiers.clear();
32        config.identifiers.insert('`', '`');
33        // Databricks (like Hive/Spark) uses double quotes as string delimiters
34        config.quotes.insert("\"".to_string(), "\"".to_string());
35        // Databricks uses backslash escapes in strings (inherited from Hive/Spark)
36        config.string_escapes.push('\\');
37        // Databricks supports DIV keyword for integer division
38        config
39            .keywords
40            .insert("DIV".to_string(), crate::tokens::TokenType::Div);
41        config
42            .keywords
43            .insert("REPAIR".to_string(), crate::tokens::TokenType::Command);
44        config
45            .keywords
46            .insert("MSCK".to_string(), crate::tokens::TokenType::Command);
47        // Databricks numeric literal suffixes (same as Hive/Spark)
48        config
49            .numeric_literals
50            .insert("L".to_string(), "BIGINT".to_string());
51        config
52            .numeric_literals
53            .insert("S".to_string(), "SMALLINT".to_string());
54        config
55            .numeric_literals
56            .insert("Y".to_string(), "TINYINT".to_string());
57        config
58            .numeric_literals
59            .insert("D".to_string(), "DOUBLE".to_string());
60        config
61            .numeric_literals
62            .insert("F".to_string(), "FLOAT".to_string());
63        config
64            .numeric_literals
65            .insert("BD".to_string(), "DECIMAL".to_string());
66        // Databricks allows identifiers to start with digits (like Hive/Spark)
67        config.identifiers_can_start_with_digit = true;
68        // Databricks (like Spark): STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
69        // Backslashes in raw strings are always literal (no escape processing)
70        config.string_escapes_allowed_in_raw_strings = false;
71        config
72    }
73
74    #[cfg(feature = "generate")]
75
76    fn generator_config(&self) -> GeneratorConfig {
77        use crate::generator::IdentifierQuoteStyle;
78        GeneratorConfig {
79            identifier_quote: '`',
80            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
81            dialect: Some(DialectType::Databricks),
82            struct_field_sep: ": ",
83            create_function_return_as: false,
84            tablesample_seed_keyword: "REPEATABLE",
85            identifiers_can_start_with_digit: true,
86            // Databricks uses COMMENT 'value' without = sign
87            schema_comment_with_eq: false,
88            ..Default::default()
89        }
90    }
91
92    #[cfg(feature = "transpile")]
93
94    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
95        match expr {
96            // IFNULL -> COALESCE in Databricks
97            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
98                original_name: None,
99                expressions: vec![f.this, f.expression],
100                inferred_type: None,
101            }))),
102
103            // NVL -> COALESCE in Databricks
104            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
105                original_name: None,
106                expressions: vec![f.this, f.expression],
107                inferred_type: None,
108            }))),
109
110            // TryCast is native in Databricks
111            Expression::TryCast(c) => Ok(Expression::TryCast(c)),
112
113            // SafeCast -> TRY_CAST in Databricks
114            Expression::SafeCast(c) => Ok(Expression::TryCast(c)),
115
116            // ILIKE is native in Databricks (Spark 3+)
117            Expression::ILike(op) => Ok(Expression::ILike(op)),
118
119            // UNNEST -> EXPLODE in Databricks
120            Expression::Unnest(f) => Ok(Expression::Explode(Box::new(UnaryFunc::new(f.this)))),
121
122            // EXPLODE is native to Databricks
123            Expression::Explode(f) => Ok(Expression::Explode(f)),
124
125            // ExplodeOuter is supported
126            Expression::ExplodeOuter(f) => Ok(Expression::ExplodeOuter(f)),
127
128            // RANDOM -> RAND in Databricks
129            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
130                seed: None,
131                lower: None,
132                upper: None,
133            }))),
134
135            // Rand is native
136            Expression::Rand(r) => Ok(Expression::Rand(r)),
137
138            // || (Concat) -> CONCAT in Databricks
139            Expression::Concat(op) => Ok(Expression::Function(Box::new(Function::new(
140                "CONCAT".to_string(),
141                vec![op.left, op.right],
142            )))),
143
144            // RegexpLike is native in Databricks
145            Expression::RegexpLike(op) => Ok(Expression::RegexpLike(op)),
146
147            // Cast with typed literal: TIMESTAMP 'x'::TYPE -> CAST(CAST('x' AS TYPE) AS TIMESTAMP)
148            // This is a complex sqlglot transformation where:
149            // 1. The inner typed literal (e.g., TIMESTAMP 'x') becomes CAST('x' AS <target_type>)
150            // 2. The outer result is wrapped in CAST(... AS <original_literal_type>)
151            Expression::Cast(c) => self.transform_cast(*c),
152
153            // Generic function transformations
154            Expression::Function(f) => self.transform_function(*f),
155
156            // Generic aggregate function transformations
157            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
158
159            // DateSub -> DATE_ADD(date, -val) in Databricks
160            Expression::DateSub(f) => {
161                // Convert string literals to numbers (interval values are often stored as strings)
162                let val = match f.interval {
163                    Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::String(s) if s.parse::<i64>().is_ok()) =>
164                    {
165                        let crate::expressions::Literal::String(s) = lit.as_ref() else {
166                            unreachable!()
167                        };
168                        Expression::Literal(Box::new(crate::expressions::Literal::Number(
169                            s.clone(),
170                        )))
171                    }
172                    other => other,
173                };
174                let neg_val = Expression::Neg(Box::new(crate::expressions::UnaryOp {
175                    this: val,
176                    inferred_type: None,
177                }));
178                Ok(Expression::Function(Box::new(Function::new(
179                    "DATE_ADD".to_string(),
180                    vec![f.this, neg_val],
181                ))))
182            }
183
184            // Pass through everything else
185            _ => Ok(expr),
186        }
187    }
188}
189
190#[cfg(feature = "transpile")]
191impl DatabricksDialect {
192    fn transform_function(&self, f: Function) -> Result<Expression> {
193        let name_upper = f.name.to_uppercase();
194        match name_upper.as_str() {
195            // IFNULL -> COALESCE
196            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
197                original_name: None,
198                expressions: f.args,
199                inferred_type: None,
200            }))),
201
202            // NVL -> COALESCE
203            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
204                original_name: None,
205                expressions: f.args,
206                inferred_type: None,
207            }))),
208
209            // ISNULL -> COALESCE
210            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
211                original_name: None,
212                expressions: f.args,
213                inferred_type: None,
214            }))),
215
216            // ROW -> STRUCT (no auto-naming for cross-dialect conversion)
217            "ROW" => Ok(Expression::Function(Box::new(Function::new(
218                "STRUCT".to_string(),
219                f.args,
220            )))),
221
222            // NAMED_STRUCT('a', 1) -> STRUCT(1 AS a) for SQLGlot Databricks outputs
223            "NAMED_STRUCT" if f.args.len() % 2 == 0 => {
224                let original_args = f.args.clone();
225                let mut struct_args = Vec::new();
226                for pair in f.args.chunks(2) {
227                    if let Expression::Literal(lit) = &pair[0] {
228                        if let Literal::String(field_name) = lit.as_ref() {
229                            struct_args.push(Expression::Alias(Box::new(
230                                crate::expressions::Alias {
231                                    this: pair[1].clone(),
232                                    alias: crate::expressions::Identifier::new(field_name),
233                                    column_aliases: Vec::new(),
234                                    alias_explicit_as: false,
235                                    alias_keyword: None,
236                                    pre_alias_comments: Vec::new(),
237                                    trailing_comments: Vec::new(),
238                                    inferred_type: None,
239                                },
240                            )));
241                            continue;
242                        }
243                    }
244                    return Ok(Expression::Function(Box::new(Function::new(
245                        "NAMED_STRUCT".to_string(),
246                        original_args,
247                    ))));
248                }
249                Ok(Expression::Function(Box::new(Function::new(
250                    "STRUCT".to_string(),
251                    struct_args,
252                ))))
253            }
254
255            // GETDATE -> CURRENT_TIMESTAMP
256            "GETDATE" => Ok(Expression::CurrentTimestamp(
257                crate::expressions::CurrentTimestamp {
258                    precision: None,
259                    sysdate: false,
260                },
261            )),
262
263            // NOW -> CURRENT_TIMESTAMP
264            "NOW" => Ok(Expression::CurrentTimestamp(
265                crate::expressions::CurrentTimestamp {
266                    precision: None,
267                    sysdate: false,
268                },
269            )),
270
271            // CURDATE -> CURRENT_DATE
272            "CURDATE" => Ok(Expression::CurrentDate(crate::expressions::CurrentDate)),
273
274            // CURRENT_DATE() with parens -> CURRENT_DATE (no parens)
275            "CURRENT_DATE" if f.args.is_empty() => {
276                Ok(Expression::CurrentDate(crate::expressions::CurrentDate))
277            }
278
279            // RANDOM -> RAND
280            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
281                seed: None,
282                lower: None,
283                upper: None,
284            }))),
285
286            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
287            "GROUP_CONCAT" if !f.args.is_empty() => {
288                let mut args = f.args;
289                let first = args.remove(0);
290                let separator = args.pop();
291                let collect_list = Expression::Function(Box::new(Function::new(
292                    "COLLECT_LIST".to_string(),
293                    vec![first],
294                )));
295                if let Some(sep) = separator {
296                    Ok(Expression::Function(Box::new(Function::new(
297                        "ARRAY_JOIN".to_string(),
298                        vec![collect_list, sep],
299                    ))))
300                } else {
301                    Ok(Expression::Function(Box::new(Function::new(
302                        "ARRAY_JOIN".to_string(),
303                        vec![collect_list],
304                    ))))
305                }
306            }
307
308            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN in Databricks
309            "STRING_AGG" if !f.args.is_empty() => {
310                let mut args = f.args;
311                let first = args.remove(0);
312                let separator = args.pop();
313                let collect_list = Expression::Function(Box::new(Function::new(
314                    "COLLECT_LIST".to_string(),
315                    vec![first],
316                )));
317                if let Some(sep) = separator {
318                    Ok(Expression::Function(Box::new(Function::new(
319                        "ARRAY_JOIN".to_string(),
320                        vec![collect_list, sep],
321                    ))))
322                } else {
323                    Ok(Expression::Function(Box::new(Function::new(
324                        "ARRAY_JOIN".to_string(),
325                        vec![collect_list],
326                    ))))
327                }
328            }
329
330            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
331            "LISTAGG" if !f.args.is_empty() => {
332                let mut args = f.args;
333                let first = args.remove(0);
334                let separator = args.pop();
335                let collect_list = Expression::Function(Box::new(Function::new(
336                    "COLLECT_LIST".to_string(),
337                    vec![first],
338                )));
339                if let Some(sep) = separator {
340                    Ok(Expression::Function(Box::new(Function::new(
341                        "ARRAY_JOIN".to_string(),
342                        vec![collect_list, sep],
343                    ))))
344                } else {
345                    Ok(Expression::Function(Box::new(Function::new(
346                        "ARRAY_JOIN".to_string(),
347                        vec![collect_list],
348                    ))))
349                }
350            }
351
352            // ARRAY_AGG -> COLLECT_LIST in Databricks
353            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
354                "COLLECT_LIST".to_string(),
355                f.args,
356            )))),
357
358            // SUBSTR -> SUBSTRING
359            "SUBSTR" => Ok(Expression::Function(Box::new(Function::new(
360                "SUBSTRING".to_string(),
361                f.args,
362            )))),
363
364            // LEN -> LENGTH
365            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
366                f.args.into_iter().next().unwrap(),
367            )))),
368
369            // CHARINDEX -> LOCATE (with swapped args, like Spark)
370            "CHARINDEX" if f.args.len() >= 2 => {
371                let mut args = f.args;
372                let substring = args.remove(0);
373                let string = args.remove(0);
374                // LOCATE(substring, string)
375                Ok(Expression::Function(Box::new(Function::new(
376                    "LOCATE".to_string(),
377                    vec![substring, string],
378                ))))
379            }
380
381            // POSITION -> LOCATE
382            "POSITION" if f.args.len() == 2 => {
383                let args = f.args;
384                Ok(Expression::Function(Box::new(Function::new(
385                    "LOCATE".to_string(),
386                    args,
387                ))))
388            }
389
390            // STRPOS -> LOCATE (with same arg order)
391            "STRPOS" if f.args.len() == 2 => {
392                let args = f.args;
393                let string = args[0].clone();
394                let substring = args[1].clone();
395                // STRPOS(string, substring) -> LOCATE(substring, string)
396                Ok(Expression::Function(Box::new(Function::new(
397                    "LOCATE".to_string(),
398                    vec![substring, string],
399                ))))
400            }
401
402            // INSTR is native in Databricks
403            "INSTR" => Ok(Expression::Function(Box::new(f))),
404
405            // LOCATE is native in Databricks
406            "LOCATE" => Ok(Expression::Function(Box::new(f))),
407
408            // ARRAY_LENGTH -> SIZE
409            "ARRAY_LENGTH" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
410                Function::new("SIZE".to_string(), f.args),
411            ))),
412
413            // CARDINALITY -> SIZE
414            "CARDINALITY" if f.args.len() == 1 => Ok(Expression::Function(Box::new(
415                Function::new("SIZE".to_string(), f.args),
416            ))),
417
418            // SIZE is native
419            "SIZE" => Ok(Expression::Function(Box::new(f))),
420
421            // ARRAY_CONTAINS is native in Databricks
422            "ARRAY_CONTAINS" => Ok(Expression::Function(Box::new(f))),
423
424            // CONTAINS -> ARRAY_CONTAINS in Databricks (for array operations)
425            // But keep CONTAINS for string contains (from CONTAINS_SUBSTR transpilation)
426            "CONTAINS" if f.args.len() == 2 => {
427                // Check if this is a string CONTAINS (LOWER() args pattern from CONTAINS_SUBSTR)
428                let is_string_contains = matches!(&f.args[0], Expression::Lower(_))
429                    && matches!(&f.args[1], Expression::Lower(_));
430                if is_string_contains {
431                    Ok(Expression::Function(Box::new(f)))
432                } else {
433                    Ok(Expression::Function(Box::new(Function::new(
434                        "ARRAY_CONTAINS".to_string(),
435                        f.args,
436                    ))))
437                }
438            }
439
440            // TO_DATE is native in Databricks
441            "TO_DATE" => Ok(Expression::Function(Box::new(f))),
442
443            // TO_TIMESTAMP is native in Databricks
444            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
445
446            // DATE_FORMAT is native in Databricks
447            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
448
449            // strftime -> DATE_FORMAT in Databricks
450            "STRFTIME" if f.args.len() >= 2 => {
451                let mut args = f.args;
452                let format = args.remove(0);
453                let date = args.remove(0);
454                Ok(Expression::Function(Box::new(Function::new(
455                    "DATE_FORMAT".to_string(),
456                    vec![date, format],
457                ))))
458            }
459
460            // TO_CHAR is supported natively in Databricks (unlike Spark)
461            "TO_CHAR" => Ok(Expression::Function(Box::new(f))),
462
463            // DATE_TRUNC is native in Databricks
464            "DATE_TRUNC" => Ok(Expression::Function(Box::new(f))),
465
466            // DATEADD is native in Databricks - uppercase the unit if present
467            "DATEADD" => {
468                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
469                Ok(Expression::Function(Box::new(Function::new(
470                    "DATEADD".to_string(),
471                    transformed_args,
472                ))))
473            }
474
475            // DATE_ADD -> DATEADD in Databricks (2-arg form only)
476            // 2-arg with interval: DATE_ADD(date, interval) -> DATEADD(DAY, interval, date)
477            // 2-arg with number: DATE_ADD(date, -2) -> keep as DATE_ADD(date, -2)
478            // 3-arg: DATE_ADD(unit, amount, date) -> keep as DATE_ADD(UNIT, amount, date)
479            "DATE_ADD" => {
480                if f.args.len() == 2 {
481                    let is_simple_number = matches!(
482                        &f.args[1],
483                        Expression::Literal(lit) if matches!(lit.as_ref(), crate::expressions::Literal::Number(_))
484                    ) || matches!(&f.args[1], Expression::Neg(_));
485                    if is_simple_number {
486                        // Keep as DATE_ADD(date, num_days)
487                        Ok(Expression::Function(Box::new(Function::new(
488                            "DATE_ADD".to_string(),
489                            f.args,
490                        ))))
491                    } else {
492                        let mut args = f.args;
493                        let date = args.remove(0);
494                        let interval = args.remove(0);
495                        let unit = Expression::Identifier(crate::expressions::Identifier {
496                            name: "DAY".to_string(),
497                            quoted: false,
498                            trailing_comments: Vec::new(),
499                            span: None,
500                        });
501                        Ok(Expression::Function(Box::new(Function::new(
502                            "DATEADD".to_string(),
503                            vec![unit, interval, date],
504                        ))))
505                    }
506                } else {
507                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
508                    Ok(Expression::Function(Box::new(Function::new(
509                        "DATE_ADD".to_string(),
510                        transformed_args,
511                    ))))
512                }
513            }
514
515            // DATEDIFF is native in Databricks - uppercase the unit if present
516            // 2-arg: DATEDIFF(end, start) -> DATEDIFF(DAY, start, end)
517            // 3-arg: DATEDIFF(unit, start, end) -> DATEDIFF(UNIT, start, end)
518            "DATEDIFF" => {
519                if f.args.len() == 2 {
520                    let mut args = f.args;
521                    let end_date = args.remove(0);
522                    let start_date = args.remove(0);
523                    let unit = Expression::Identifier(crate::expressions::Identifier {
524                        name: "DAY".to_string(),
525                        quoted: false,
526                        trailing_comments: Vec::new(),
527                        span: None,
528                    });
529                    Ok(Expression::Function(Box::new(Function::new(
530                        "DATEDIFF".to_string(),
531                        vec![unit, start_date, end_date],
532                    ))))
533                } else {
534                    let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
535                    Ok(Expression::Function(Box::new(Function::new(
536                        "DATEDIFF".to_string(),
537                        transformed_args,
538                    ))))
539                }
540            }
541
542            // DATE_DIFF -> DATEDIFF with uppercased unit
543            "DATE_DIFF" => {
544                let transformed_args = self.uppercase_first_arg_if_identifier(f.args);
545                Ok(Expression::Function(Box::new(Function::new(
546                    "DATEDIFF".to_string(),
547                    transformed_args,
548                ))))
549            }
550
551            // JSON_EXTRACT -> Use colon operator in generation, but keep as function for now
552            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
553
554            // JSON_EXTRACT_SCALAR -> same handling
555            "JSON_EXTRACT_SCALAR" => Ok(Expression::Function(Box::new(f))),
556
557            // GET_JSON_OBJECT -> colon syntax in Databricks
558            // GET_JSON_OBJECT(col, '$.path') becomes col:path
559            "GET_JSON_OBJECT" if f.args.len() == 2 => {
560                let mut args = f.args;
561                let col = args.remove(0);
562                let path_arg = args.remove(0);
563
564                // Extract and strip the $. prefix from the path
565                let path_expr = match &path_arg {
566                    Expression::Literal(lit)
567                        if matches!(lit.as_ref(), crate::expressions::Literal::String(_)) =>
568                    {
569                        let crate::expressions::Literal::String(s) = lit.as_ref() else {
570                            unreachable!()
571                        };
572                        // Strip leading '$.' if present
573                        let stripped = if s.starts_with("$.") {
574                            &s[2..]
575                        } else if s.starts_with("$") {
576                            &s[1..]
577                        } else {
578                            s.as_str()
579                        };
580                        Expression::Literal(Box::new(crate::expressions::Literal::String(
581                            stripped.to_string(),
582                        )))
583                    }
584                    _ => path_arg,
585                };
586
587                Ok(Expression::JSONExtract(Box::new(JSONExtract {
588                    this: Box::new(col),
589                    expression: Box::new(path_expr),
590                    only_json_types: None,
591                    expressions: Vec::new(),
592                    variant_extract: Some(Box::new(Expression::true_())),
593                    json_query: None,
594                    option: None,
595                    quote: None,
596                    on_condition: None,
597                    requires_json: None,
598                })))
599            }
600
601            // FROM_JSON is native in Databricks
602            "FROM_JSON" => Ok(Expression::Function(Box::new(f))),
603
604            // PARSE_JSON is native in Databricks
605            "PARSE_JSON" => Ok(Expression::Function(Box::new(f))),
606
607            // COLLECT_LIST is native in Databricks
608            "COLLECT_LIST" => Ok(Expression::Function(Box::new(f))),
609
610            // COLLECT_SET is native in Databricks
611            "COLLECT_SET" => Ok(Expression::Function(Box::new(f))),
612
613            // RLIKE is native in Databricks
614            "RLIKE" => Ok(Expression::Function(Box::new(f))),
615
616            // REGEXP -> RLIKE in Databricks
617            "REGEXP" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
618                "RLIKE".to_string(),
619                f.args,
620            )))),
621
622            // REGEXP_LIKE is native in Databricks
623            "REGEXP_LIKE" => Ok(Expression::Function(Box::new(f))),
624
625            // LEVENSHTEIN is native in Databricks
626            "LEVENSHTEIN" => Ok(Expression::Function(Box::new(f))),
627
628            // SEQUENCE is native (for GENERATE_SERIES)
629            "GENERATE_SERIES" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
630                Function::new("SEQUENCE".to_string(), f.args),
631            ))),
632
633            // SEQUENCE is native
634            "SEQUENCE" => Ok(Expression::Function(Box::new(f))),
635
636            // FLATTEN is native in Databricks
637            "FLATTEN" => Ok(Expression::Function(Box::new(f))),
638
639            // ARRAY_SORT is native
640            "ARRAY_SORT" => Ok(Expression::Function(Box::new(f))),
641
642            // ARRAY_DISTINCT is native
643            "ARRAY_DISTINCT" => Ok(Expression::Function(Box::new(f))),
644
645            // TRANSFORM is native (for array transformation)
646            "TRANSFORM" => Ok(Expression::Function(Box::new(f))),
647
648            // FILTER is native (for array filtering)
649            "FILTER" => Ok(Expression::Function(Box::new(f))),
650
651            // FROM_UTC_TIMESTAMP - wrap first argument in CAST(... AS TIMESTAMP) if not already
652            "FROM_UTC_TIMESTAMP" if f.args.len() >= 2 => {
653                let mut args = f.args;
654                let first_arg = args.remove(0);
655
656                // Check if first arg is already a Cast to TIMESTAMP
657                let wrapped_arg = if self.is_cast_to_timestamp(&first_arg) {
658                    first_arg
659                } else {
660                    // Wrap in CAST(... AS TIMESTAMP)
661                    Expression::Cast(Box::new(Cast {
662                        this: first_arg,
663                        to: DataType::Timestamp {
664                            precision: None,
665                            timezone: false,
666                        },
667                        trailing_comments: Vec::new(),
668                        double_colon_syntax: false,
669                        format: None,
670                        default: None,
671                        inferred_type: None,
672                    }))
673                };
674
675                let mut new_args = vec![wrapped_arg];
676                new_args.extend(args);
677
678                Ok(Expression::Function(Box::new(Function::new(
679                    "FROM_UTC_TIMESTAMP".to_string(),
680                    new_args,
681                ))))
682            }
683
684            // UNIFORM(low, high, RANDOM(seed)) -> UNIFORM(low, high, seed) or UNIFORM(low, high)
685            "UNIFORM" if f.args.len() == 3 => {
686                let mut args = f.args;
687                let low = args.remove(0);
688                let high = args.remove(0);
689                let gen = args.remove(0);
690                match gen {
691                    Expression::Function(func) if func.name.to_uppercase() == "RANDOM" => {
692                        if func.args.len() == 1 {
693                            // RANDOM(seed) -> extract seed
694                            let seed = func.args.into_iter().next().unwrap();
695                            Ok(Expression::Function(Box::new(Function::new(
696                                "UNIFORM".to_string(),
697                                vec![low, high, seed],
698                            ))))
699                        } else {
700                            // RANDOM() -> drop gen arg
701                            Ok(Expression::Function(Box::new(Function::new(
702                                "UNIFORM".to_string(),
703                                vec![low, high],
704                            ))))
705                        }
706                    }
707                    Expression::Rand(r) => {
708                        if let Some(seed) = r.seed {
709                            Ok(Expression::Function(Box::new(Function::new(
710                                "UNIFORM".to_string(),
711                                vec![low, high, *seed],
712                            ))))
713                        } else {
714                            Ok(Expression::Function(Box::new(Function::new(
715                                "UNIFORM".to_string(),
716                                vec![low, high],
717                            ))))
718                        }
719                    }
720                    _ => Ok(Expression::Function(Box::new(Function::new(
721                        "UNIFORM".to_string(),
722                        vec![low, high, gen],
723                    )))),
724                }
725            }
726
727            // REGEXP_SUBSTR(subject, pattern, ...) -> REGEXP_EXTRACT(subject, pattern)
728            "REGEXP_SUBSTR" if f.args.len() >= 2 => {
729                let subject = f.args[0].clone();
730                let pattern = f.args[1].clone();
731                Ok(Expression::Function(Box::new(Function::new(
732                    "REGEXP_EXTRACT".to_string(),
733                    vec![subject, pattern],
734                ))))
735            }
736
737            // BIT_GET -> GETBIT
738            "BIT_GET" => Ok(Expression::Function(Box::new(Function::new(
739                "GETBIT".to_string(),
740                f.args,
741            )))),
742
743            // Pass through everything else
744            _ => Ok(Expression::Function(Box::new(f))),
745        }
746    }
747
748    fn transform_aggregate_function(
749        &self,
750        f: Box<crate::expressions::AggregateFunction>,
751    ) -> Result<Expression> {
752        let name_upper = f.name.to_uppercase();
753        match name_upper.as_str() {
754            // COUNT_IF is native in Databricks (Spark 3+)
755            "COUNT_IF" => Ok(Expression::AggregateFunction(f)),
756
757            // ANY_VALUE is native in Databricks (Spark 3+)
758            "ANY_VALUE" => Ok(Expression::AggregateFunction(f)),
759
760            // GROUP_CONCAT -> COLLECT_LIST + ARRAY_JOIN
761            "GROUP_CONCAT" if !f.args.is_empty() => {
762                let mut args = f.args;
763                let first = args.remove(0);
764                let separator = args.pop();
765                let collect_list = Expression::Function(Box::new(Function::new(
766                    "COLLECT_LIST".to_string(),
767                    vec![first],
768                )));
769                if let Some(sep) = separator {
770                    Ok(Expression::Function(Box::new(Function::new(
771                        "ARRAY_JOIN".to_string(),
772                        vec![collect_list, sep],
773                    ))))
774                } else {
775                    Ok(Expression::Function(Box::new(Function::new(
776                        "ARRAY_JOIN".to_string(),
777                        vec![collect_list],
778                    ))))
779                }
780            }
781
782            // STRING_AGG -> COLLECT_LIST + ARRAY_JOIN
783            "STRING_AGG" if !f.args.is_empty() => {
784                let mut args = f.args;
785                let first = args.remove(0);
786                let separator = args.pop();
787                let collect_list = Expression::Function(Box::new(Function::new(
788                    "COLLECT_LIST".to_string(),
789                    vec![first],
790                )));
791                if let Some(sep) = separator {
792                    Ok(Expression::Function(Box::new(Function::new(
793                        "ARRAY_JOIN".to_string(),
794                        vec![collect_list, sep],
795                    ))))
796                } else {
797                    Ok(Expression::Function(Box::new(Function::new(
798                        "ARRAY_JOIN".to_string(),
799                        vec![collect_list],
800                    ))))
801                }
802            }
803
804            // LISTAGG -> COLLECT_LIST + ARRAY_JOIN
805            "LISTAGG" if !f.args.is_empty() => {
806                let mut args = f.args;
807                let first = args.remove(0);
808                let separator = args.pop();
809                let collect_list = Expression::Function(Box::new(Function::new(
810                    "COLLECT_LIST".to_string(),
811                    vec![first],
812                )));
813                if let Some(sep) = separator {
814                    Ok(Expression::Function(Box::new(Function::new(
815                        "ARRAY_JOIN".to_string(),
816                        vec![collect_list, sep],
817                    ))))
818                } else {
819                    Ok(Expression::Function(Box::new(Function::new(
820                        "ARRAY_JOIN".to_string(),
821                        vec![collect_list],
822                    ))))
823                }
824            }
825
826            // ARRAY_AGG -> COLLECT_LIST
827            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
828                "COLLECT_LIST".to_string(),
829                f.args,
830            )))),
831
832            // STDDEV is native in Databricks
833            "STDDEV" => Ok(Expression::AggregateFunction(f)),
834
835            // VARIANCE is native in Databricks
836            "VARIANCE" => Ok(Expression::AggregateFunction(f)),
837
838            // APPROX_COUNT_DISTINCT is native in Databricks
839            "APPROX_COUNT_DISTINCT" => Ok(Expression::AggregateFunction(f)),
840
841            // APPROX_DISTINCT -> APPROX_COUNT_DISTINCT
842            "APPROX_DISTINCT" if !f.args.is_empty() => {
843                Ok(Expression::AggregateFunction(Box::new(AggregateFunction {
844                    name: "APPROX_COUNT_DISTINCT".to_string(),
845                    args: f.args,
846                    distinct: f.distinct,
847                    filter: f.filter,
848                    order_by: Vec::new(),
849                    limit: None,
850                    ignore_nulls: None,
851                    inferred_type: None,
852                })))
853            }
854
855            // Pass through everything else
856            _ => Ok(Expression::AggregateFunction(f)),
857        }
858    }
859
860    /// Transform Cast expressions - handles typed literals being cast
861    ///
862    /// When we have a typed literal (TIMESTAMP 'x', DATE 'x', TIME 'x') being cast to another type,
863    /// Databricks/Spark transforms it as follows:
864    ///
865    /// `TIMESTAMP 'x'::TYPE` -> `CAST(CAST('x' AS TYPE) AS TIMESTAMP)`
866    ///
867    /// This reverses the types - the inner cast is to the target type,
868    /// the outer cast is to the original literal type.
869    fn transform_cast(&self, c: Cast) -> Result<Expression> {
870        // Check if the inner expression is a typed literal
871        match &c.this {
872            // TIMESTAMP 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIMESTAMP)
873            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Timestamp(_)) => {
874                let Literal::Timestamp(value) = lit.as_ref() else {
875                    unreachable!()
876                };
877                // Create inner cast: CAST('value' AS target_type)
878                let inner_cast = Expression::Cast(Box::new(Cast {
879                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
880                    to: c.to,
881                    trailing_comments: Vec::new(),
882                    double_colon_syntax: false,
883                    format: None,
884                    default: None,
885                    inferred_type: None,
886                }));
887                // Create outer cast: CAST(inner_cast AS TIMESTAMP)
888                Ok(Expression::Cast(Box::new(Cast {
889                    this: inner_cast,
890                    to: DataType::Timestamp {
891                        precision: None,
892                        timezone: false,
893                    },
894                    trailing_comments: c.trailing_comments,
895                    double_colon_syntax: false,
896                    format: None,
897                    default: None,
898                    inferred_type: None,
899                })))
900            }
901            // DATE 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS DATE)
902            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Date(_)) => {
903                let Literal::Date(value) = lit.as_ref() else {
904                    unreachable!()
905                };
906                let inner_cast = Expression::Cast(Box::new(Cast {
907                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
908                    to: c.to,
909                    trailing_comments: Vec::new(),
910                    double_colon_syntax: false,
911                    format: None,
912                    default: None,
913                    inferred_type: None,
914                }));
915                Ok(Expression::Cast(Box::new(Cast {
916                    this: inner_cast,
917                    to: DataType::Date,
918                    trailing_comments: c.trailing_comments,
919                    double_colon_syntax: false,
920                    format: None,
921                    default: None,
922                    inferred_type: None,
923                })))
924            }
925            // TIME 'value'::TYPE -> CAST(CAST('value' AS TYPE) AS TIME)
926            Expression::Literal(lit) if matches!(lit.as_ref(), Literal::Time(_)) => {
927                let Literal::Time(value) = lit.as_ref() else {
928                    unreachable!()
929                };
930                let inner_cast = Expression::Cast(Box::new(Cast {
931                    this: Expression::Literal(Box::new(Literal::String(value.clone()))),
932                    to: c.to,
933                    trailing_comments: Vec::new(),
934                    double_colon_syntax: false,
935                    format: None,
936                    default: None,
937                    inferred_type: None,
938                }));
939                Ok(Expression::Cast(Box::new(Cast {
940                    this: inner_cast,
941                    to: DataType::Time {
942                        precision: None,
943                        timezone: false,
944                    },
945                    trailing_comments: c.trailing_comments,
946                    double_colon_syntax: false,
947                    format: None,
948                    default: None,
949                    inferred_type: None,
950                })))
951            }
952            // For all other cases, pass through the Cast unchanged
953            _ => Ok(Expression::Cast(Box::new(c))),
954        }
955    }
956
957    /// Check if an expression is a CAST to TIMESTAMP
958    fn is_cast_to_timestamp(&self, expr: &Expression) -> bool {
959        if let Expression::Cast(cast) = expr {
960            matches!(cast.to, DataType::Timestamp { .. })
961        } else {
962            false
963        }
964    }
965
966    /// Helper to uppercase the first argument if it's an identifier or column (for DATEDIFF, DATEADD units)
967    fn uppercase_first_arg_if_identifier(&self, mut args: Vec<Expression>) -> Vec<Expression> {
968        use crate::expressions::Identifier;
969        if !args.is_empty() {
970            match &args[0] {
971                Expression::Identifier(id) => {
972                    args[0] = Expression::Identifier(Identifier {
973                        name: id.name.to_uppercase(),
974                        quoted: id.quoted,
975                        trailing_comments: id.trailing_comments.clone(),
976                        span: None,
977                    });
978                }
979                Expression::Var(v) => {
980                    args[0] = Expression::Identifier(Identifier {
981                        name: v.this.to_uppercase(),
982                        quoted: false,
983                        trailing_comments: Vec::new(),
984                        span: None,
985                    });
986                }
987                Expression::Column(col) if col.table.is_none() => {
988                    // Unqualified column name like "day" should be treated as a unit
989                    args[0] = Expression::Identifier(Identifier {
990                        name: col.name.name.to_uppercase(),
991                        quoted: col.name.quoted,
992                        trailing_comments: col.name.trailing_comments.clone(),
993                        span: None,
994                    });
995                }
996                _ => {}
997            }
998        }
999        args
1000    }
1001}
1002
#[cfg(test)]
mod tests {
    use super::*;
    use crate::Dialect;

    /// Parses `sql` with the Databricks dialect, runs the dialect transform on the first
    /// statement, regenerates SQL and asserts it equals `expected`, reporting `failure`
    /// on mismatch.
    #[track_caller]
    fn assert_transforms_to(sql: &str, expected: &str, failure: &str) {
        let dialect = Dialect::get(DialectType::Databricks);
        let statements = dialect.parse(sql).expect("Parse failed");
        let rewritten = dialect
            .transform(statements[0].clone())
            .expect("Transform failed");
        let rendered = dialect.generate(&rewritten).expect("Generate failed");

        assert_eq!(rendered, expected, "{}", failure);
    }

    /// Parses each statement with the Databricks dialect and asserts that generating the
    /// first parsed statement (without transforming it) reproduces the input text exactly.
    #[track_caller]
    fn assert_roundtrips(cases: &[&str]) {
        let dialect = Dialect::get(DialectType::Databricks);
        for &sql in cases {
            let statements = dialect.parse(sql).expect("Parse failed");
            let rendered = dialect.generate(&statements[0]).expect("Generate failed");
            assert_eq!(rendered, sql);
        }
    }

    #[test]
    fn test_timestamp_literal_cast() {
        // TIMESTAMP 'value'::DATE -> CAST(CAST('value' AS DATE) AS TIMESTAMP)
        // This is test [47] in the Databricks dialect identity fixtures
        assert_transforms_to(
            "SELECT TIMESTAMP '2025-04-29 18.47.18'::DATE",
            "SELECT CAST(CAST('2025-04-29 18.47.18' AS DATE) AS TIMESTAMP)",
            "Timestamp literal cast transformation failed",
        );
    }

    #[test]
    fn test_from_utc_timestamp_wraps_column() {
        // Test [48]: FROM_UTC_TIMESTAMP(foo, 'timezone') -> FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'timezone')
        assert_transforms_to(
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(foo, 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "SELECT DATE_FORMAT(CAST(FROM_UTC_TIMESTAMP(CAST(foo AS TIMESTAMP), 'America/Los_Angeles') AS TIMESTAMP), 'yyyy-MM-dd HH:mm:ss') AS foo FROM t",
            "FROM_UTC_TIMESTAMP transformation failed",
        );
    }

    #[test]
    fn test_from_utc_timestamp_keeps_existing_cast() {
        // Test [50]: FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz) -> FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)
        // When already cast to TIMESTAMP, keep it but convert :: syntax to CAST()
        assert_transforms_to(
            "FROM_UTC_TIMESTAMP(x::TIMESTAMP, tz)",
            "FROM_UTC_TIMESTAMP(CAST(x AS TIMESTAMP), tz)",
            "FROM_UTC_TIMESTAMP with existing CAST failed",
        );
    }

    #[test]
    fn test_deep_clone_version_as_of() {
        assert_roundtrips(&["CREATE TABLE events_clone DEEP CLONE events VERSION AS OF 5"]);
    }

    #[test]
    fn test_deep_clone_timestamp_as_of() {
        assert_roundtrips(&[
            "CREATE TABLE events_clone DEEP CLONE events TIMESTAMP AS OF '2024-01-01'",
        ]);
    }

    #[test]
    fn test_shallow_clone_still_roundtrips() {
        assert_roundtrips(&["CREATE TABLE events_clone SHALLOW CLONE events"]);
    }

    #[test]
    fn test_repair_table_commands_roundtrip() {
        assert_roundtrips(&[
            "REPAIR TABLE events",
            "MSCK REPAIR TABLE events",
            "REPAIR TABLE events ADD PARTITIONS",
            "REPAIR TABLE events DROP PARTITIONS",
            "REPAIR TABLE events SYNC PARTITIONS",
            "REPAIR TABLE events SYNC METADATA",
        ]);
    }

    #[test]
    fn test_apply_changes_commands_roundtrip() {
        assert_roundtrips(&[
            "APPLY CHANGES INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) IGNORE NULL UPDATES SEQUENCE BY ts",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) APPLY AS DELETE WHEN operation = 'DELETE' SEQUENCE BY ts COLUMNS * EXCEPT (operation) STORED AS SCD TYPE 1",
            "APPLY CHANGES INTO LIVE.silver_orders FROM STREAM(LIVE.bronze_orders) KEYS (id) SEQUENCE BY ts STORED AS SCD TYPE 2 TRACK HISTORY ON * EXCEPT (updated_at)",
            "AUTO CDC INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
            "CREATE FLOW apply_cdc AS AUTO CDC INTO silver.orders FROM STREAM(bronze.orders) KEYS (id) SEQUENCE BY ts",
        ]);
    }

    #[test]
    fn test_generate_symlink_format_manifest_roundtrip() {
        assert_roundtrips(&[
            "GENERATE symlink_format_manifest FOR TABLE events",
            "GENERATE symlink_format_manifest FOR TABLE catalog.schema.events",
        ]);
    }

    #[test]
    fn test_convert_to_delta_roundtrip() {
        assert_roundtrips(&[
            "CONVERT TO DELTA parquet.`/mnt/data/events`",
            "CONVERT TO DELTA database_name.table_name",
            "CONVERT TO DELTA parquet.`s3://my-bucket/path/to/table` PARTITIONED BY (date DATE)",
            "CONVERT TO DELTA database_name.table_name NO STATISTICS",
        ]);
    }
}