// polyglot_sql/dialects/spark.rs
//! Spark SQL Dialect
//!
//! Spark SQL-specific transformations based on sqlglot patterns.
//! Key features (extends Hive with modern SQL):
//! - TRY_CAST is supported (Spark 3+)
//! - ILIKE is supported (Spark 3+)
//! - Uses backticks for identifiers
//! - ARRAY_AGG and COLLECT_LIST for array aggregation
//! - STRING_AGG / LISTAGG supported (Spark 4+)
//! - DATE_ADD with a unit parameter (Spark 3+)
//! - TIMESTAMPADD and TIMESTAMPDIFF (Spark 3+)
//! - More PostgreSQL-like syntax than Hive

14use super::{DialectImpl, DialectType};
15use crate::error::Result;
16use crate::expressions::{
17    CeilFunc, CurrentTimestamp, DataType, DateTimeField, Expression, ExtractFunc, Function,
18    Literal, StructField, UnaryFunc, VarArgFunc,
19};
20use crate::generator::GeneratorConfig;
21use crate::tokens::TokenizerConfig;
22
/// Spark SQL dialect.
///
/// Stateless marker type: all behavior lives in the `DialectImpl`
/// implementation below (tokenizer/generator configuration and
/// expression rewrites) plus the inherent helper methods.
pub struct SparkDialect;
25
26impl DialectImpl for SparkDialect {
27    fn dialect_type(&self) -> DialectType {
28        DialectType::Spark
29    }
30
31    fn tokenizer_config(&self) -> TokenizerConfig {
32        let mut config = TokenizerConfig::default();
33        // Spark uses backticks for identifiers (NOT double quotes)
34        config.identifiers.clear();
35        config.identifiers.insert('`', '`');
36        // Spark (like Hive) uses double quotes as string delimiters (QUOTES = ["'", '"'])
37        config.quotes.insert("\"".to_string(), "\"".to_string());
38        // Spark (like Hive) uses backslash escapes in strings (STRING_ESCAPES = ["\\"])
39        config.string_escapes.push('\\');
40        // Spark supports DIV keyword for integer division (inherited from Hive)
41        config
42            .keywords
43            .insert("DIV".to_string(), crate::tokens::TokenType::Div);
44        // Spark numeric literal suffixes (same as Hive): 1L -> BIGINT, 1S -> SMALLINT, etc.
45        config
46            .numeric_literals
47            .insert("L".to_string(), "BIGINT".to_string());
48        config
49            .numeric_literals
50            .insert("S".to_string(), "SMALLINT".to_string());
51        config
52            .numeric_literals
53            .insert("Y".to_string(), "TINYINT".to_string());
54        config
55            .numeric_literals
56            .insert("D".to_string(), "DOUBLE".to_string());
57        config
58            .numeric_literals
59            .insert("F".to_string(), "FLOAT".to_string());
60        config
61            .numeric_literals
62            .insert("BD".to_string(), "DECIMAL".to_string());
63        // Spark allows identifiers to start with digits (e.g., 1a, 1_a)
64        config.identifiers_can_start_with_digit = true;
65        // Spark: STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False
66        // Backslashes in raw strings are always literal (no escape processing)
67        config.string_escapes_allowed_in_raw_strings = false;
68        config
69    }
70
71    fn generator_config(&self) -> GeneratorConfig {
72        use crate::generator::IdentifierQuoteStyle;
73        GeneratorConfig {
74            identifier_quote: '`',
75            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
76            dialect: Some(DialectType::Spark),
77            // Spark uses colon separator in STRUCT field definitions: STRUCT<field_name: TYPE>
78            struct_field_sep: ": ",
79            // Spark doesn't use AS before RETURN in function definitions
80            create_function_return_as: false,
81            // Spark places alias after the TABLESAMPLE clause
82            alias_post_tablesample: true,
83            tablesample_seed_keyword: "REPEATABLE",
84            join_hints: false,
85            identifiers_can_start_with_digit: true,
86            // Spark uses COMMENT 'value' without = sign
87            schema_comment_with_eq: false,
88            ..Default::default()
89        }
90    }
91
92    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
93        match expr {
94            // IFNULL -> COALESCE in Spark
95            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
96                original_name: None,
97                expressions: vec![f.this, f.expression],
98            }))),
99
100            // NVL is supported in Spark (from Hive), but COALESCE is standard
101            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
102                original_name: None,
103                expressions: vec![f.this, f.expression],
104            }))),
105
106            // Cast: normalize VARCHAR(n) -> STRING, CHAR(n) -> STRING for Spark
107            Expression::Cast(mut c) => {
108                c.to = Self::normalize_spark_type(c.to);
109                Ok(Expression::Cast(c))
110            }
111
112            // TryCast stays as TryCast in Spark (Spark supports TRY_CAST natively)
113            Expression::TryCast(mut c) => {
114                c.to = Self::normalize_spark_type(c.to);
115                Ok(Expression::TryCast(c))
116            }
117
118            // SafeCast -> TRY_CAST
119            Expression::SafeCast(mut c) => {
120                c.to = Self::normalize_spark_type(c.to);
121                Ok(Expression::TryCast(c))
122            }
123
124            // TRIM: non-standard comma syntax -> standard FROM syntax
125            // TRIM('SL', 'SSparkSQLS') -> TRIM('SL' FROM 'SSparkSQLS')
126            Expression::Trim(mut t) => {
127                if !t.sql_standard_syntax && t.characters.is_some() {
128                    // Convert comma syntax to standard SQL syntax
129                    // Fields already have correct semantics: this=string, characters=chars
130                    t.sql_standard_syntax = true;
131                }
132                Ok(Expression::Trim(t))
133            }
134
135            // ILIKE is supported in Spark 3+
136            Expression::ILike(op) => Ok(Expression::ILike(op)),
137
138            // UNNEST -> EXPLODE in Spark (Hive compatibility)
139            Expression::Unnest(f) => Ok(Expression::Explode(Box::new(UnaryFunc::new(f.this)))),
140
141            // EXPLODE is native to Spark
142            Expression::Explode(f) => Ok(Expression::Explode(f)),
143
144            // ExplodeOuter is supported in Spark
145            Expression::ExplodeOuter(f) => Ok(Expression::ExplodeOuter(f)),
146
147            // RANDOM -> RAND in Spark
148            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
149                seed: None,
150                lower: None,
151                upper: None,
152            }))),
153
154            // Rand is native to Spark
155            Expression::Rand(r) => Ok(Expression::Rand(r)),
156
157            // || (Concat) -> CONCAT in Spark
158            Expression::Concat(op) => Ok(Expression::Function(Box::new(Function::new(
159                "CONCAT".to_string(),
160                vec![op.left, op.right],
161            )))),
162
163            // ParseJson: handled by generator (emits just the string literal for Spark)
164
165            // Generic function transformations
166            Expression::Function(f) => self.transform_function(*f),
167
168            // Generic aggregate function transformations
169            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
170
171            // $N parameters -> ${N} in Spark (DollarBrace style)
172            Expression::Parameter(mut p)
173                if p.style == crate::expressions::ParameterStyle::Dollar =>
174            {
175                p.style = crate::expressions::ParameterStyle::DollarBrace;
176                // Convert index to name for DollarBrace format
177                if let Some(idx) = p.index {
178                    p.name = Some(idx.to_string());
179                }
180                Ok(Expression::Parameter(p))
181            }
182
183            // JSONExtract with variant_extract (Databricks colon syntax) -> GET_JSON_OBJECT
184            Expression::JSONExtract(je) if je.variant_extract.is_some() => {
185                // Convert path: 'item[1].price' -> '$.item[1].price'
186                let path = match *je.expression {
187                    Expression::Literal(Literal::String(s)) => {
188                        Expression::Literal(Literal::String(format!("$.{}", s)))
189                    }
190                    other => other,
191                };
192                Ok(Expression::Function(Box::new(Function::new(
193                    "GET_JSON_OBJECT".to_string(),
194                    vec![*je.this, path],
195                ))))
196            }
197
198            // Pass through everything else
199            _ => Ok(expr),
200        }
201    }
202}
203
204impl SparkDialect {
205    /// Normalize a data type for Spark:
206    /// - VARCHAR/CHAR without length -> STRING
207    /// - VARCHAR(n)/CHAR(n) with length -> keep as-is
208    /// - TEXT -> STRING
209    fn normalize_spark_type(dt: DataType) -> DataType {
210        match dt {
211            DataType::VarChar { length: None, .. }
212            | DataType::Char { length: None }
213            | DataType::Text => DataType::Custom {
214                name: "STRING".to_string(),
215            },
216            // VARCHAR(n) and CHAR(n) with length are kept as-is
217            DataType::VarChar { .. } | DataType::Char { .. } => dt,
218            // Also normalize struct fields recursively
219            DataType::Struct { fields, nested } => {
220                let normalized_fields: Vec<StructField> = fields
221                    .into_iter()
222                    .map(|mut f| {
223                        f.data_type = Self::normalize_spark_type(f.data_type);
224                        f
225                    })
226                    .collect();
227                DataType::Struct {
228                    fields: normalized_fields,
229                    nested,
230                }
231            }
232            _ => dt,
233        }
234    }
235
236    fn transform_function(&self, f: Function) -> Result<Expression> {
237        let name_upper = f.name.to_uppercase();
238        match name_upper.as_str() {
239            // IFNULL -> COALESCE
240            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
241                original_name: None,
242                expressions: f.args,
243            }))),
244
245            // NVL -> COALESCE
246            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
247                original_name: None,
248                expressions: f.args,
249            }))),
250
251            // ISNULL -> COALESCE
252            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
253                original_name: None,
254                expressions: f.args,
255            }))),
256
257            // GROUP_CONCAT -> CONCAT_WS + COLLECT_LIST in older Spark
258            // In Spark 4+, STRING_AGG is available
259            "GROUP_CONCAT" if !f.args.is_empty() => {
260                // For simplicity, use COLLECT_LIST (array aggregation)
261                Ok(Expression::Function(Box::new(Function::new(
262                    "COLLECT_LIST".to_string(),
263                    f.args,
264                ))))
265            }
266
267            // STRING_AGG is supported in Spark 4+
268            // For older versions, fall back to CONCAT_WS + COLLECT_LIST
269            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
270                Function::new("COLLECT_LIST".to_string(), f.args),
271            ))),
272
273            // LISTAGG -> STRING_AGG in Spark 4+ (or COLLECT_LIST for older)
274            "LISTAGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
275                "COLLECT_LIST".to_string(),
276                f.args,
277            )))),
278
279            // SUBSTRING is native to Spark
280            "SUBSTRING" | "SUBSTR" => Ok(Expression::Function(Box::new(f))),
281
282            // LENGTH is native to Spark
283            "LENGTH" => Ok(Expression::Function(Box::new(f))),
284
285            // LEN -> LENGTH
286            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
287                f.args.into_iter().next().unwrap(),
288            )))),
289
290            // RANDOM -> RAND
291            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
292                seed: None,
293                lower: None,
294                upper: None,
295            }))),
296
297            // RAND is native to Spark
298            "RAND" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
299                seed: None,
300                lower: None,
301                upper: None,
302            }))),
303
304            // NOW -> CURRENT_TIMESTAMP
305            "NOW" => Ok(Expression::CurrentTimestamp(
306                crate::expressions::CurrentTimestamp {
307                    precision: None,
308                    sysdate: false,
309                },
310            )),
311
312            // GETDATE -> CURRENT_TIMESTAMP
313            "GETDATE" => Ok(Expression::CurrentTimestamp(
314                crate::expressions::CurrentTimestamp {
315                    precision: None,
316                    sysdate: false,
317                },
318            )),
319
320            // CURRENT_TIMESTAMP is native
321            "CURRENT_TIMESTAMP" => Ok(Expression::CurrentTimestamp(
322                crate::expressions::CurrentTimestamp {
323                    precision: None,
324                    sysdate: false,
325                },
326            )),
327
328            // CURRENT_DATE is native
329            "CURRENT_DATE" => Ok(Expression::CurrentDate(crate::expressions::CurrentDate)),
330
331            // TO_DATE is native to Spark; strip default format 'yyyy-MM-dd'
332            "TO_DATE" if f.args.len() == 2 => {
333                let is_default_format = matches!(&f.args[1], Expression::Literal(crate::expressions::Literal::String(s)) if s == "yyyy-MM-dd");
334                if is_default_format {
335                    Ok(Expression::Function(Box::new(Function::new(
336                        "TO_DATE".to_string(),
337                        vec![f.args.into_iter().next().unwrap()],
338                    ))))
339                } else {
340                    Ok(Expression::Function(Box::new(f)))
341                }
342            }
343            "TO_DATE" => Ok(Expression::Function(Box::new(f))),
344
345            // TO_TIMESTAMP is native to Spark
346            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
347
348            // DATE_FORMAT is native to Spark
349            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
350
351            // strftime -> DATE_FORMAT
352            "STRFTIME" => Ok(Expression::Function(Box::new(Function::new(
353                "DATE_FORMAT".to_string(),
354                f.args,
355            )))),
356
357            // TO_CHAR -> DATE_FORMAT
358            "TO_CHAR" => Ok(Expression::Function(Box::new(Function::new(
359                "DATE_FORMAT".to_string(),
360                f.args,
361            )))),
362
363            // DATE_TRUNC is native to Spark
364            "DATE_TRUNC" => Ok(Expression::Function(Box::new(f))),
365
366            // TRUNC is native to Spark
367            "TRUNC" => Ok(Expression::Function(Box::new(f))),
368
369            // EXTRACT is native to Spark
370            "EXTRACT" => Ok(Expression::Function(Box::new(f))),
371
372            // DATEPART -> EXTRACT
373            "DATEPART" => Ok(Expression::Function(Box::new(Function::new(
374                "EXTRACT".to_string(),
375                f.args,
376            )))),
377
378            // UNIX_TIMESTAMP is native to Spark
379            // When called with no args, add CURRENT_TIMESTAMP() as default
380            "UNIX_TIMESTAMP" => {
381                if f.args.is_empty() {
382                    Ok(Expression::Function(Box::new(Function::new(
383                        "UNIX_TIMESTAMP".to_string(),
384                        vec![Expression::CurrentTimestamp(CurrentTimestamp {
385                            precision: None,
386                            sysdate: false,
387                        })],
388                    ))))
389                } else {
390                    Ok(Expression::Function(Box::new(f)))
391                }
392            }
393
394            // FROM_UNIXTIME is native to Spark
395            "FROM_UNIXTIME" => Ok(Expression::Function(Box::new(f))),
396
397            // STR_TO_MAP is native to Spark
398            // When called with only one arg, add default delimiters ',' and ':'
399            "STR_TO_MAP" => {
400                if f.args.len() == 1 {
401                    let mut args = f.args;
402                    args.push(Expression::Literal(crate::expressions::Literal::String(
403                        ",".to_string(),
404                    )));
405                    args.push(Expression::Literal(crate::expressions::Literal::String(
406                        ":".to_string(),
407                    )));
408                    Ok(Expression::Function(Box::new(Function::new(
409                        "STR_TO_MAP".to_string(),
410                        args,
411                    ))))
412                } else {
413                    Ok(Expression::Function(Box::new(f)))
414                }
415            }
416
417            // POSITION is native to Spark (POSITION(substr IN str))
418            "POSITION" => Ok(Expression::Function(Box::new(f))),
419
420            // LOCATE is native to Spark
421            "LOCATE" => Ok(Expression::Function(Box::new(f))),
422
423            // STRPOS -> Use expression form or LOCATE
424            "STRPOS" if f.args.len() == 2 => {
425                let mut args = f.args;
426                let first = args.remove(0);
427                let second = args.remove(0);
428                // LOCATE(substr, str) in Spark
429                Ok(Expression::Function(Box::new(Function::new(
430                    "LOCATE".to_string(),
431                    vec![second, first],
432                ))))
433            }
434
435            // CHARINDEX -> LOCATE
436            "CHARINDEX" if f.args.len() >= 2 => {
437                let mut args = f.args;
438                let substring = args.remove(0);
439                let string = args.remove(0);
440                let mut locate_args = vec![substring, string];
441                if !args.is_empty() {
442                    locate_args.push(args.remove(0));
443                }
444                Ok(Expression::Function(Box::new(Function::new(
445                    "LOCATE".to_string(),
446                    locate_args,
447                ))))
448            }
449
450            // INSTR is native to Spark
451            "INSTR" => Ok(Expression::Function(Box::new(f))),
452
453            // CEILING -> CEIL
454            "CEILING" if f.args.len() == 1 => Ok(Expression::Ceil(Box::new(CeilFunc {
455                this: f.args.into_iter().next().unwrap(),
456                decimals: None,
457                to: None,
458            }))),
459
460            // CEIL is native to Spark
461            "CEIL" if f.args.len() == 1 => Ok(Expression::Ceil(Box::new(CeilFunc {
462                this: f.args.into_iter().next().unwrap(),
463                decimals: None,
464                to: None,
465            }))),
466
467            // UNNEST -> EXPLODE
468            "UNNEST" => Ok(Expression::Function(Box::new(Function::new(
469                "EXPLODE".to_string(),
470                f.args,
471            )))),
472
473            // FLATTEN -> FLATTEN is native to Spark (for nested arrays)
474            "FLATTEN" => Ok(Expression::Function(Box::new(f))),
475
476            // ARRAY_AGG -> COLLECT_LIST
477            "ARRAY_AGG" => Ok(Expression::Function(Box::new(Function::new(
478                "COLLECT_LIST".to_string(),
479                f.args,
480            )))),
481
482            // COLLECT_LIST is native to Spark
483            "COLLECT_LIST" => Ok(Expression::Function(Box::new(f))),
484
485            // COLLECT_SET is native to Spark
486            "COLLECT_SET" => Ok(Expression::Function(Box::new(f))),
487
488            // ARRAY_LENGTH -> SIZE in Spark
489            "ARRAY_LENGTH" | "CARDINALITY" => Ok(Expression::Function(Box::new(Function::new(
490                "SIZE".to_string(),
491                f.args,
492            )))),
493
494            // SIZE is native to Spark
495            "SIZE" => Ok(Expression::Function(Box::new(f))),
496
497            // SPLIT is native to Spark
498            "SPLIT" => Ok(Expression::Function(Box::new(f))),
499
500            // REGEXP_REPLACE: Spark supports up to 4 args (subject, pattern, replacement, position)
501            // Strip extra Snowflake args (occurrence, params) if present
502            "REGEXP_REPLACE" if f.args.len() > 4 => {
503                let mut args = f.args;
504                args.truncate(4);
505                Ok(Expression::Function(Box::new(Function::new(
506                    "REGEXP_REPLACE".to_string(),
507                    args,
508                ))))
509            }
510            "REGEXP_REPLACE" => Ok(Expression::Function(Box::new(f))),
511
512            // REGEXP_EXTRACT is native to Spark
513            "REGEXP_EXTRACT" => Ok(Expression::Function(Box::new(f))),
514
515            // REGEXP_EXTRACT_ALL is native to Spark
516            "REGEXP_EXTRACT_ALL" => Ok(Expression::Function(Box::new(f))),
517
518            // RLIKE is native to Spark
519            "RLIKE" | "REGEXP_LIKE" => Ok(Expression::Function(Box::new(Function::new(
520                "RLIKE".to_string(),
521                f.args,
522            )))),
523
524            // JSON_EXTRACT -> GET_JSON_OBJECT (Hive style) or :: operator
525            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(Function::new(
526                "GET_JSON_OBJECT".to_string(),
527                f.args,
528            )))),
529
530            // JSON_EXTRACT_SCALAR -> GET_JSON_OBJECT
531            "JSON_EXTRACT_SCALAR" => Ok(Expression::Function(Box::new(Function::new(
532                "GET_JSON_OBJECT".to_string(),
533                f.args,
534            )))),
535
536            // GET_JSON_OBJECT is native to Spark
537            "GET_JSON_OBJECT" => Ok(Expression::Function(Box::new(f))),
538
539            // FROM_JSON is native to Spark
540            "FROM_JSON" => Ok(Expression::Function(Box::new(f))),
541
542            // TO_JSON is native to Spark
543            "TO_JSON" => Ok(Expression::Function(Box::new(f))),
544
545            // PARSE_JSON -> strip for Spark (just keep the string argument)
546            "PARSE_JSON" if f.args.len() == 1 => Ok(f.args.into_iter().next().unwrap()),
547            "PARSE_JSON" => Ok(Expression::Function(Box::new(Function::new(
548                "FROM_JSON".to_string(),
549                f.args,
550            )))),
551
552            // DATEDIFF is native to Spark (supports unit in Spark 3+)
553            "DATEDIFF" | "DATE_DIFF" => Ok(Expression::Function(Box::new(Function::new(
554                "DATEDIFF".to_string(),
555                f.args,
556            )))),
557
558            // DATE_ADD is native to Spark
559            "DATE_ADD" | "DATEADD" => Ok(Expression::Function(Box::new(Function::new(
560                "DATE_ADD".to_string(),
561                f.args,
562            )))),
563
564            // DATE_SUB is native to Spark
565            "DATE_SUB" => Ok(Expression::Function(Box::new(f))),
566
567            // TIMESTAMPADD is native to Spark 3+
568            "TIMESTAMPADD" => Ok(Expression::Function(Box::new(f))),
569
570            // TIMESTAMPDIFF is native to Spark 3+
571            "TIMESTAMPDIFF" => Ok(Expression::Function(Box::new(f))),
572
573            // ADD_MONTHS is native to Spark
574            "ADD_MONTHS" => Ok(Expression::Function(Box::new(f))),
575
576            // MONTHS_BETWEEN is native to Spark
577            "MONTHS_BETWEEN" => Ok(Expression::Function(Box::new(f))),
578
579            // NVL is native to Spark
580            "NVL" => Ok(Expression::Function(Box::new(f))),
581
582            // NVL2 is native to Spark
583            "NVL2" => Ok(Expression::Function(Box::new(f))),
584
585            // MAP is native to Spark
586            "MAP" => Ok(Expression::Function(Box::new(f))),
587
588            // ARRAY is native to Spark
589            "ARRAY" => Ok(Expression::Function(Box::new(f))),
590
591            // ROW -> STRUCT for Spark (cross-dialect, no auto-naming)
592            "ROW" => Ok(Expression::Function(Box::new(Function::new(
593                "STRUCT".to_string(),
594                f.args,
595            )))),
596
597            // STRUCT is native to Spark - auto-name unnamed args as col1, col2, etc.
598            "STRUCT" => {
599                let mut col_idx = 1usize;
600                let named_args: Vec<Expression> = f
601                    .args
602                    .into_iter()
603                    .map(|arg| {
604                        let current_idx = col_idx;
605                        col_idx += 1;
606                        // Check if arg already has an alias (AS name) or is Star
607                        match &arg {
608                            Expression::Alias(_) => arg, // already named
609                            Expression::Star(_) => arg,  // STRUCT(*) - keep as-is
610                            Expression::Column(c) if c.table.is_none() => {
611                                // Column reference: use column name as the struct field name
612                                let name = c.name.name.clone();
613                                Expression::Alias(Box::new(crate::expressions::Alias {
614                                    this: arg,
615                                    alias: crate::expressions::Identifier::new(&name),
616                                    column_aliases: Vec::new(),
617                                    pre_alias_comments: Vec::new(),
618                                    trailing_comments: Vec::new(),
619                                }))
620                            }
621                            _ => {
622                                // Unnamed literal/expression: auto-name as colN
623                                let name = format!("col{}", current_idx);
624                                Expression::Alias(Box::new(crate::expressions::Alias {
625                                    this: arg,
626                                    alias: crate::expressions::Identifier::new(&name),
627                                    column_aliases: Vec::new(),
628                                    pre_alias_comments: Vec::new(),
629                                    trailing_comments: Vec::new(),
630                                }))
631                            }
632                        }
633                    })
634                    .collect();
635                Ok(Expression::Function(Box::new(Function {
636                    name: "STRUCT".to_string(),
637                    args: named_args,
638                    distinct: false,
639                    trailing_comments: Vec::new(),
640                    use_bracket_syntax: false,
641                    no_parens: false,
642                    quoted: false,
643                    span: None,
644                })))
645            }
646
647            // NAMED_STRUCT is native to Spark
648            "NAMED_STRUCT" => Ok(Expression::Function(Box::new(f))),
649
650            // MAP_FROM_ARRAYS is native to Spark
651            "MAP_FROM_ARRAYS" => Ok(Expression::Function(Box::new(f))),
652
653            // ARRAY_SORT is native to Spark
654            "ARRAY_SORT" => Ok(Expression::Function(Box::new(f))),
655
656            // ARRAY_DISTINCT is native to Spark
657            "ARRAY_DISTINCT" => Ok(Expression::Function(Box::new(f))),
658
659            // ARRAY_UNION is native to Spark
660            "ARRAY_UNION" => Ok(Expression::Function(Box::new(f))),
661
662            // ARRAY_INTERSECT is native to Spark
663            "ARRAY_INTERSECT" => Ok(Expression::Function(Box::new(f))),
664
665            // ARRAY_EXCEPT is native to Spark
666            "ARRAY_EXCEPT" => Ok(Expression::Function(Box::new(f))),
667
668            // ARRAY_CONTAINS is native to Spark
669            "ARRAY_CONTAINS" => Ok(Expression::Function(Box::new(f))),
670
671            // ELEMENT_AT is native to Spark
672            "ELEMENT_AT" => Ok(Expression::Function(Box::new(f))),
673
674            // TRY_ELEMENT_AT is native to Spark 3+
675            "TRY_ELEMENT_AT" => Ok(Expression::Function(Box::new(f))),
676
677            // TRANSFORM is native to Spark (array transformation)
678            "TRANSFORM" => Ok(Expression::Function(Box::new(f))),
679
680            // FILTER is native to Spark (array filtering)
681            "FILTER" => Ok(Expression::Function(Box::new(f))),
682
683            // AGGREGATE is native to Spark (array reduction)
684            "AGGREGATE" => Ok(Expression::Function(Box::new(f))),
685
686            // SEQUENCE is native to Spark (generate array)
687            "SEQUENCE" => Ok(Expression::Function(Box::new(f))),
688
689            // GENERATE_SERIES -> SEQUENCE
690            "GENERATE_SERIES" => Ok(Expression::Function(Box::new(Function::new(
691                "SEQUENCE".to_string(),
692                f.args,
693            )))),
694
695            // STARTSWITH is native to Spark 3+
696            "STARTSWITH" | "STARTS_WITH" => Ok(Expression::Function(Box::new(Function::new(
697                "STARTSWITH".to_string(),
698                f.args,
699            )))),
700
701            // ENDSWITH is native to Spark 3+
702            "ENDSWITH" | "ENDS_WITH" => Ok(Expression::Function(Box::new(Function::new(
703                "ENDSWITH".to_string(),
704                f.args,
705            )))),
706
707            // ARRAY_CONSTRUCT_COMPACT(1, null, 2) -> ARRAY_COMPACT(ARRAY(1, NULL, 2))
708            "ARRAY_CONSTRUCT_COMPACT" => {
709                let inner =
710                    Expression::Function(Box::new(Function::new("ARRAY".to_string(), f.args)));
711                Ok(Expression::Function(Box::new(Function::new(
712                    "ARRAY_COMPACT".to_string(),
713                    vec![inner],
714                ))))
715            }
716
717            // ARRAY_TO_STRING -> ARRAY_JOIN
718            "ARRAY_TO_STRING" => Ok(Expression::Function(Box::new(Function::new(
719                "ARRAY_JOIN".to_string(),
720                f.args,
721            )))),
722
723            // TO_ARRAY(x) -> IF(x IS NULL, NULL, ARRAY(x))
724            "TO_ARRAY" if f.args.len() == 1 => {
725                let x = f.args[0].clone();
726                // Check if arg is already an array constructor (bracket notation)
727                // In that case: TO_ARRAY(['test']) -> ARRAY('test')
728                match &x {
729                    Expression::ArrayFunc(arr) => {
730                        // Just convert to ARRAY(...) function
731                        Ok(Expression::Function(Box::new(Function::new(
732                            "ARRAY".to_string(),
733                            arr.expressions.clone(),
734                        ))))
735                    }
736                    _ => Ok(Expression::IfFunc(Box::new(crate::expressions::IfFunc {
737                        condition: Expression::IsNull(Box::new(crate::expressions::IsNull {
738                            this: x.clone(),
739                            not: false,
740                            postfix_form: false,
741                        })),
742                        true_value: Expression::Null(crate::expressions::Null),
743                        false_value: Some(Expression::Function(Box::new(Function::new(
744                            "ARRAY".to_string(),
745                            vec![x],
746                        )))),
747                        original_name: Some("IF".to_string()),
748                    }))),
749                }
750            }
751
752            // REGEXP_SUBSTR -> REGEXP_EXTRACT (strip extra args)
753            "REGEXP_SUBSTR" if f.args.len() >= 2 => {
754                let subject = f.args[0].clone();
755                let pattern = f.args[1].clone();
756                // For Spark: REGEXP_EXTRACT(subject, pattern, group)
757                // group defaults to 0 for full match, but sqlglot uses last arg if present
758                let group = if f.args.len() >= 6 {
759                    let g = &f.args[5];
760                    // If group is literal 1 (default), omit it
761                    if matches!(g, Expression::Literal(Literal::Number(n)) if n == "1") {
762                        None
763                    } else {
764                        Some(g.clone())
765                    }
766                } else {
767                    None
768                };
769                let mut args = vec![subject, pattern];
770                if let Some(g) = group {
771                    args.push(g);
772                }
773                Ok(Expression::Function(Box::new(Function::new(
774                    "REGEXP_EXTRACT".to_string(),
775                    args,
776                ))))
777            }
778
779            // UUID_STRING -> UUID()
780            "UUID_STRING" => Ok(Expression::Function(Box::new(Function::new(
781                "UUID".to_string(),
782                vec![],
783            )))),
784
785            // OBJECT_CONSTRUCT -> STRUCT in Spark
786            "OBJECT_CONSTRUCT" if f.args.len() >= 2 && f.args.len() % 2 == 0 => {
787                // Convert key-value pairs to named struct fields
788                // OBJECT_CONSTRUCT('Manitoba', 'Winnipeg', 'foo', 'bar')
789                // -> STRUCT('Winnipeg' AS Manitoba, 'bar' AS foo)
790                let mut struct_args = Vec::new();
791                for pair in f.args.chunks(2) {
792                    if let Expression::Literal(Literal::String(key)) = &pair[0] {
793                        struct_args.push(Expression::Alias(Box::new(crate::expressions::Alias {
794                            this: pair[1].clone(),
795                            alias: crate::expressions::Identifier::new(key.clone()),
796                            column_aliases: vec![],
797                            pre_alias_comments: vec![],
798                            trailing_comments: vec![],
799                        })));
800                    } else {
801                        struct_args.push(pair[1].clone());
802                    }
803                }
804                Ok(Expression::Function(Box::new(Function::new(
805                    "STRUCT".to_string(),
806                    struct_args,
807                ))))
808            }
809
810            // DATE_PART(part, expr) -> EXTRACT(part FROM expr)
811            "DATE_PART" if f.args.len() == 2 => {
812                let mut args = f.args;
813                let part = args.remove(0);
814                let expr = args.remove(0);
815                if let Some(field) = expr_to_datetime_field(&part) {
816                    Ok(Expression::Extract(Box::new(ExtractFunc {
817                        this: expr,
818                        field,
819                    })))
820                } else {
821                    // Can't parse the field, keep as function
822                    Ok(Expression::Function(Box::new(Function::new(
823                        "DATE_PART".to_string(),
824                        vec![part, expr],
825                    ))))
826                }
827            }
828
829            // GET_PATH(obj, path) -> GET_JSON_OBJECT(obj, json_path) in Spark
830            "GET_PATH" if f.args.len() == 2 => {
831                let mut args = f.args;
832                let this = args.remove(0);
833                let path = args.remove(0);
834                let json_path = match &path {
835                    Expression::Literal(Literal::String(s)) => {
836                        let normalized = if s.starts_with('$') {
837                            s.clone()
838                        } else if s.starts_with('[') {
839                            format!("${}", s)
840                        } else {
841                            format!("$.{}", s)
842                        };
843                        Expression::Literal(Literal::String(normalized))
844                    }
845                    _ => path,
846                };
847                Ok(Expression::Function(Box::new(Function::new(
848                    "GET_JSON_OBJECT".to_string(),
849                    vec![this, json_path],
850                ))))
851            }
852
853            // BITWISE_LEFT_SHIFT → SHIFTLEFT
854            "BITWISE_LEFT_SHIFT" => Ok(Expression::Function(Box::new(Function::new(
855                "SHIFTLEFT".to_string(),
856                f.args,
857            )))),
858
859            // BITWISE_RIGHT_SHIFT → SHIFTRIGHT
860            "BITWISE_RIGHT_SHIFT" => Ok(Expression::Function(Box::new(Function::new(
861                "SHIFTRIGHT".to_string(),
862                f.args,
863            )))),
864
865            // APPROX_DISTINCT → APPROX_COUNT_DISTINCT
866            "APPROX_DISTINCT" => Ok(Expression::Function(Box::new(Function::new(
867                "APPROX_COUNT_DISTINCT".to_string(),
868                f.args,
869            )))),
870
871            // ARRAY_SLICE → SLICE
872            "ARRAY_SLICE" => Ok(Expression::Function(Box::new(Function::new(
873                "SLICE".to_string(),
874                f.args,
875            )))),
876
877            // DATE_FROM_PARTS → MAKE_DATE
878            "DATE_FROM_PARTS" => Ok(Expression::Function(Box::new(Function::new(
879                "MAKE_DATE".to_string(),
880                f.args,
881            )))),
882
883            // DAYOFWEEK_ISO → DAYOFWEEK
884            "DAYOFWEEK_ISO" => Ok(Expression::Function(Box::new(Function::new(
885                "DAYOFWEEK".to_string(),
886                f.args,
887            )))),
888
889            // FORMAT → FORMAT_STRING
890            "FORMAT" => Ok(Expression::Function(Box::new(Function::new(
891                "FORMAT_STRING".to_string(),
892                f.args,
893            )))),
894
895            // LOGICAL_AND → BOOL_AND
896            "LOGICAL_AND" => Ok(Expression::Function(Box::new(Function::new(
897                "BOOL_AND".to_string(),
898                f.args,
899            )))),
900
901            // VARIANCE_POP → VAR_POP
902            "VARIANCE_POP" => Ok(Expression::Function(Box::new(Function::new(
903                "VAR_POP".to_string(),
904                f.args,
905            )))),
906
907            // WEEK_OF_YEAR → WEEKOFYEAR
908            "WEEK_OF_YEAR" => Ok(Expression::Function(Box::new(Function::new(
909                "WEEKOFYEAR".to_string(),
910                f.args,
911            )))),
912
913            // Pass through everything else
914            _ => Ok(Expression::Function(Box::new(f))),
915        }
916    }
917
918    fn transform_aggregate_function(
919        &self,
920        f: Box<crate::expressions::AggregateFunction>,
921    ) -> Result<Expression> {
922        let name_upper = f.name.to_uppercase();
923        match name_upper.as_str() {
924            // GROUP_CONCAT -> COLLECT_LIST (then CONCAT_WS for string)
925            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
926                Function::new("COLLECT_LIST".to_string(), f.args),
927            ))),
928
929            // STRING_AGG -> COLLECT_LIST (or STRING_AGG in Spark 4+)
930            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
931                Function::new("COLLECT_LIST".to_string(), f.args),
932            ))),
933
934            // LISTAGG -> COLLECT_LIST
935            "LISTAGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
936                "COLLECT_LIST".to_string(),
937                f.args,
938            )))),
939
940            // ARRAY_AGG -> COLLECT_LIST (preserve distinct and filter)
941            "ARRAY_AGG" if !f.args.is_empty() => {
942                let mut af = f;
943                af.name = "COLLECT_LIST".to_string();
944                Ok(Expression::AggregateFunction(af))
945            }
946
947            // LOGICAL_OR -> BOOL_OR in Spark
948            "LOGICAL_OR" if !f.args.is_empty() => {
949                let mut af = f;
950                af.name = "BOOL_OR".to_string();
951                Ok(Expression::AggregateFunction(af))
952            }
953
954            // Pass through everything else
955            _ => Ok(Expression::AggregateFunction(f)),
956        }
957    }
958}
959
960/// Convert an expression (string literal or identifier) to a DateTimeField
961fn expr_to_datetime_field(expr: &Expression) -> Option<DateTimeField> {
962    let name = match expr {
963        Expression::Literal(Literal::String(s)) => s.to_uppercase(),
964        Expression::Identifier(id) => id.name.to_uppercase(),
965        Expression::Column(col) if col.table.is_none() => col.name.name.to_uppercase(),
966        _ => return None,
967    };
968    match name.as_str() {
969        "YEAR" | "Y" | "YY" | "YYY" | "YYYY" | "YR" | "YEARS" | "YRS" => Some(DateTimeField::Year),
970        "MONTH" | "MM" | "MON" | "MONS" | "MONTHS" => Some(DateTimeField::Month),
971        "DAY" | "D" | "DD" | "DAYS" | "DAYOFMONTH" => Some(DateTimeField::Day),
972        "HOUR" | "H" | "HH" | "HR" | "HOURS" | "HRS" => Some(DateTimeField::Hour),
973        "MINUTE" | "MI" | "MIN" | "MINUTES" | "MINS" => Some(DateTimeField::Minute),
974        "SECOND" | "S" | "SEC" | "SECONDS" | "SECS" => Some(DateTimeField::Second),
975        "MILLISECOND" | "MS" | "MSEC" | "MILLISECONDS" => Some(DateTimeField::Millisecond),
976        "MICROSECOND" | "US" | "USEC" | "MICROSECONDS" => Some(DateTimeField::Microsecond),
977        "DOW" | "DAYOFWEEK" | "DAYOFWEEK_ISO" | "DW" => Some(DateTimeField::DayOfWeek),
978        "DOY" | "DAYOFYEAR" => Some(DateTimeField::DayOfYear),
979        "WEEK" | "W" | "WK" | "WEEKOFYEAR" | "WOY" => Some(DateTimeField::Week),
980        "QUARTER" | "Q" | "QTR" | "QTRS" | "QUARTERS" => Some(DateTimeField::Quarter),
981        "EPOCH" | "EPOCH_SECOND" | "EPOCH_SECONDS" => Some(DateTimeField::Epoch),
982        "TIMEZONE" | "TIMEZONE_HOUR" | "TZH" => Some(DateTimeField::TimezoneHour),
983        "TIMEZONE_MINUTE" | "TZM" => Some(DateTimeField::TimezoneMinute),
984        _ => Some(DateTimeField::Custom(name)),
985    }
986}