Skip to main content

polyglot_sql/dialects/
spark.rs

1//! Spark SQL Dialect
2//!
3//! Spark SQL-specific transformations based on sqlglot patterns.
4//! Key features (extends Hive with modern SQL):
5//! - TRY_CAST is supported (Spark 3+)
6//! - ILIKE is supported (Spark 3+)
7//! - Uses backticks for identifiers
8//! - ARRAY_AGG, COLLECT_LIST for array aggregation
9//! - STRING_AGG / LISTAGG supported (Spark 4+)
10//! - DATE_ADD with unit parameter (Spark 3+)
11//! - TIMESTAMPADD, TIMESTAMPDIFF (Spark 3+)
12//! - More PostgreSQL-like syntax than Hive
13
14use super::{DialectImpl, DialectType};
15use crate::error::Result;
16use crate::expressions::{
17    CeilFunc, CurrentTimestamp, DataType, DateTimeField, Expression, ExtractFunc, Function,
18    Literal, StructField, UnaryFunc, VarArgFunc,
19};
20use crate::generator::GeneratorConfig;
21use crate::tokens::TokenizerConfig;
22
/// Spark SQL dialect.
///
/// Stateless marker type: all behavior is provided through its
/// `DialectImpl` implementation (tokenizer/generator configuration and
/// expression transforms).
pub struct SparkDialect;
25
impl DialectImpl for SparkDialect {
    fn dialect_type(&self) -> DialectType {
        DialectType::Spark
    }

    /// Tokenizer settings for Spark SQL (largely Hive-compatible lexing).
    fn tokenizer_config(&self) -> TokenizerConfig {
        let mut config = TokenizerConfig::default();
        // Spark uses backticks for identifiers (NOT double quotes), so drop
        // whatever identifier pairs the default config carries first.
        config.identifiers.clear();
        config.identifiers.insert('`', '`');
        // Spark (like Hive) also accepts double quotes as string delimiters
        // (sqlglot: QUOTES = ["'", '"']).
        config.quotes.insert("\"".to_string(), "\"".to_string());
        // Spark (like Hive) uses backslash escapes in strings
        // (sqlglot: STRING_ESCAPES = ["\\"]).
        config.string_escapes.push('\\');
        // Spark supports the DIV keyword for integer division (inherited from Hive).
        config
            .keywords
            .insert("DIV".to_string(), crate::tokens::TokenType::Div);
        // Spark numeric literal suffixes (same as Hive): the suffix maps to the
        // SQL type name used when re-emitting the literal, e.g. 1L -> BIGINT.
        config
            .numeric_literals
            .insert("L".to_string(), "BIGINT".to_string());
        config
            .numeric_literals
            .insert("S".to_string(), "SMALLINT".to_string());
        config
            .numeric_literals
            .insert("Y".to_string(), "TINYINT".to_string());
        config
            .numeric_literals
            .insert("D".to_string(), "DOUBLE".to_string());
        config
            .numeric_literals
            .insert("F".to_string(), "FLOAT".to_string());
        config
            .numeric_literals
            .insert("BD".to_string(), "DECIMAL".to_string());
        // Spark allows identifiers to start with digits (e.g., 1a, 1_a).
        config.identifiers_can_start_with_digit = true;
        // Spark: STRING_ESCAPES_ALLOWED_IN_RAW_STRINGS = False — backslashes
        // in raw strings are always literal (no escape processing).
        config.string_escapes_allowed_in_raw_strings = false;
        config
    }

    /// SQL generation settings for Spark output.
    fn generator_config(&self) -> GeneratorConfig {
        use crate::generator::IdentifierQuoteStyle;
        GeneratorConfig {
            identifier_quote: '`',
            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
            dialect: Some(DialectType::Spark),
            // Spark uses a colon separator in STRUCT field definitions:
            // STRUCT<field_name: TYPE>
            struct_field_sep: ": ",
            // Spark doesn't use AS before RETURN in function definitions.
            create_function_return_as: false,
            // Spark places the alias after the TABLESAMPLE clause.
            alias_post_tablesample: true,
            tablesample_seed_keyword: "REPEATABLE",
            join_hints: false,
            identifiers_can_start_with_digit: true,
            // Spark uses COMMENT 'value' without an = sign.
            schema_comment_with_eq: false,
            ..Default::default()
        }
    }

    /// Rewrite a single expression node into its Spark-flavored form.
    ///
    /// Nodes with no Spark-specific handling fall through unchanged. Several
    /// arms are intentional identity mappings (e.g. `ILike`, `Explode`) that
    /// document native Spark support; removing them would not change behavior.
    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
        match expr {
            // IFNULL(a, b) -> COALESCE(a, b) in Spark.
            Expression::IfNull(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
                original_name: None,
                expressions: vec![f.this, f.expression],
            }))),

            // NVL is supported in Spark (from Hive), but COALESCE is standard.
            Expression::Nvl(f) => Ok(Expression::Coalesce(Box::new(VarArgFunc {
                original_name: None,
                expressions: vec![f.this, f.expression],
            }))),

            // Cast: normalize the target type (e.g. TEXT -> STRING) for Spark.
            Expression::Cast(mut c) => {
                c.to = Self::normalize_spark_type(c.to);
                Ok(Expression::Cast(c))
            }

            // TryCast stays as TryCast in Spark (TRY_CAST is native, Spark 3+);
            // only the target type is normalized.
            Expression::TryCast(mut c) => {
                c.to = Self::normalize_spark_type(c.to);
                Ok(Expression::TryCast(c))
            }

            // SafeCast (e.g. BigQuery SAFE_CAST) -> TRY_CAST.
            Expression::SafeCast(mut c) => {
                c.to = Self::normalize_spark_type(c.to);
                Ok(Expression::TryCast(c))
            }

            // TRIM: non-standard comma syntax -> standard FROM syntax, e.g.
            // TRIM('SL', 'SSparkSQLS') -> TRIM('SL' FROM 'SSparkSQLS').
            // Only the syntax flag is flipped; the fields already carry the
            // correct semantics (this = string, characters = chars to trim).
            Expression::Trim(mut t) => {
                if !t.sql_standard_syntax && t.characters.is_some() {
                    t.sql_standard_syntax = true;
                }
                Ok(Expression::Trim(t))
            }

            // ILIKE is supported natively in Spark 3+ — identity mapping.
            Expression::ILike(op) => Ok(Expression::ILike(op)),

            // UNNEST -> EXPLODE in Spark (Hive compatibility).
            Expression::Unnest(f) => Ok(Expression::Explode(Box::new(UnaryFunc::new(f.this)))),

            // EXPLODE is native to Spark — identity mapping.
            Expression::Explode(f) => Ok(Expression::Explode(f)),

            // EXPLODE_OUTER is supported in Spark — identity mapping.
            Expression::ExplodeOuter(f) => Ok(Expression::ExplodeOuter(f)),

            // RANDOM -> RAND in Spark. Note: any seed carried by the source
            // Random node is dropped here — NOTE(review): confirm Random has
            // no seed payload that should be forwarded.
            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
                seed: None,
                lower: None,
                upper: None,
            }))),

            // RAND is native to Spark — identity mapping.
            Expression::Rand(r) => Ok(Expression::Rand(r)),

            // The || concatenation operator -> CONCAT(left, right) in Spark.
            Expression::Concat(op) => Ok(Expression::Function(Box::new(Function::new(
                "CONCAT".to_string(),
                vec![op.left, op.right],
            )))),

            // ParseJson: handled by the generator (emits just the string
            // literal for Spark), so no arm is needed here.

            // Named function calls get their own dispatch table.
            Expression::Function(f) => self.transform_function(*f),

            // Aggregate function calls likewise (defined elsewhere in this impl's
            // companion; not visible in this chunk).
            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),

            // $N positional parameters -> ${N} in Spark (DollarBrace style).
            Expression::Parameter(mut p)
                if p.style == crate::expressions::ParameterStyle::Dollar =>
            {
                p.style = crate::expressions::ParameterStyle::DollarBrace;
                // DollarBrace renders from `name`, so copy the numeric index in.
                if let Some(idx) = p.index {
                    p.name = Some(idx.to_string());
                }
                Ok(Expression::Parameter(p))
            }

            // JSONExtract flagged as a variant extract (Databricks colon
            // syntax) -> GET_JSON_OBJECT with a JSONPath argument.
            Expression::JSONExtract(je) if je.variant_extract.is_some() => {
                // Prefix string paths with the JSONPath root:
                // 'item[1].price' -> '$.item[1].price'. Non-string path
                // expressions are passed through untouched.
                let path = match *je.expression {
                    Expression::Literal(Literal::String(s)) => {
                        Expression::Literal(Literal::String(format!("$.{}", s)))
                    }
                    other => other,
                };
                Ok(Expression::Function(Box::new(Function::new(
                    "GET_JSON_OBJECT".to_string(),
                    vec![*je.this, path],
                ))))
            }

            // Everything else passes through unchanged.
            _ => Ok(expr),
        }
    }
}
203
204impl SparkDialect {
205    /// Normalize a data type for Spark:
206    /// - VARCHAR/CHAR without length -> STRING
207    /// - VARCHAR(n)/CHAR(n) with length -> keep as-is
208    /// - TEXT -> STRING
209    fn normalize_spark_type(dt: DataType) -> DataType {
210        match dt {
211            DataType::VarChar { length: None, .. }
212            | DataType::Char { length: None }
213            | DataType::Text => DataType::Custom {
214                name: "STRING".to_string(),
215            },
216            // VARCHAR(n) and CHAR(n) with length are kept as-is
217            DataType::VarChar { .. } | DataType::Char { .. } => dt,
218            // Also normalize struct fields recursively
219            DataType::Struct { fields, nested } => {
220                let normalized_fields: Vec<StructField> = fields
221                    .into_iter()
222                    .map(|mut f| {
223                        f.data_type = Self::normalize_spark_type(f.data_type);
224                        f
225                    })
226                    .collect();
227                DataType::Struct {
228                    fields: normalized_fields,
229                    nested,
230                }
231            }
232            _ => dt,
233        }
234    }
235
236    fn transform_function(&self, f: Function) -> Result<Expression> {
237        let name_upper = f.name.to_uppercase();
238        match name_upper.as_str() {
239            // IFNULL -> COALESCE
240            "IFNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
241                original_name: None,
242                expressions: f.args,
243            }))),
244
245            // NVL -> COALESCE
246            "NVL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
247                original_name: None,
248                expressions: f.args,
249            }))),
250
251            // ISNULL -> COALESCE
252            "ISNULL" if f.args.len() == 2 => Ok(Expression::Coalesce(Box::new(VarArgFunc {
253                original_name: None,
254                expressions: f.args,
255            }))),
256
257            // GROUP_CONCAT -> CONCAT_WS + COLLECT_LIST in older Spark
258            // In Spark 4+, STRING_AGG is available
259            "GROUP_CONCAT" if !f.args.is_empty() => {
260                // For simplicity, use COLLECT_LIST (array aggregation)
261                Ok(Expression::Function(Box::new(Function::new(
262                    "COLLECT_LIST".to_string(),
263                    f.args,
264                ))))
265            }
266
267            // STRING_AGG is supported in Spark 4+
268            // For older versions, fall back to CONCAT_WS + COLLECT_LIST
269            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
270                Function::new("COLLECT_LIST".to_string(), f.args),
271            ))),
272
273            // LISTAGG -> STRING_AGG in Spark 4+ (or COLLECT_LIST for older)
274            "LISTAGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
275                "COLLECT_LIST".to_string(),
276                f.args,
277            )))),
278
279            // SUBSTRING is native to Spark
280            "SUBSTRING" | "SUBSTR" => Ok(Expression::Function(Box::new(f))),
281
282            // LENGTH is native to Spark
283            "LENGTH" => Ok(Expression::Function(Box::new(f))),
284
285            // LEN -> LENGTH
286            "LEN" if f.args.len() == 1 => Ok(Expression::Length(Box::new(UnaryFunc::new(
287                f.args.into_iter().next().unwrap(),
288            )))),
289
290            // RANDOM -> RAND
291            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
292                seed: None,
293                lower: None,
294                upper: None,
295            }))),
296
297            // RAND is native to Spark
298            "RAND" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
299                seed: None,
300                lower: None,
301                upper: None,
302            }))),
303
304            // NOW -> CURRENT_TIMESTAMP
305            "NOW" => Ok(Expression::CurrentTimestamp(
306                crate::expressions::CurrentTimestamp {
307                    precision: None,
308                    sysdate: false,
309                },
310            )),
311
312            // GETDATE -> CURRENT_TIMESTAMP
313            "GETDATE" => Ok(Expression::CurrentTimestamp(
314                crate::expressions::CurrentTimestamp {
315                    precision: None,
316                    sysdate: false,
317                },
318            )),
319
320            // CURRENT_TIMESTAMP is native
321            "CURRENT_TIMESTAMP" => Ok(Expression::CurrentTimestamp(
322                crate::expressions::CurrentTimestamp {
323                    precision: None,
324                    sysdate: false,
325                },
326            )),
327
328            // CURRENT_DATE is native
329            "CURRENT_DATE" => Ok(Expression::CurrentDate(crate::expressions::CurrentDate)),
330
331            // TO_DATE is native to Spark; strip default format 'yyyy-MM-dd'
332            "TO_DATE" if f.args.len() == 2 => {
333                let is_default_format = matches!(&f.args[1], Expression::Literal(crate::expressions::Literal::String(s)) if s == "yyyy-MM-dd");
334                if is_default_format {
335                    Ok(Expression::Function(Box::new(Function::new(
336                        "TO_DATE".to_string(),
337                        vec![f.args.into_iter().next().unwrap()],
338                    ))))
339                } else {
340                    Ok(Expression::Function(Box::new(f)))
341                }
342            }
343            "TO_DATE" => Ok(Expression::Function(Box::new(f))),
344
345            // TO_TIMESTAMP is native to Spark
346            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
347
348            // DATE_FORMAT is native to Spark
349            "DATE_FORMAT" => Ok(Expression::Function(Box::new(f))),
350
351            // strftime -> DATE_FORMAT
352            "STRFTIME" => Ok(Expression::Function(Box::new(Function::new(
353                "DATE_FORMAT".to_string(),
354                f.args,
355            )))),
356
357            // TO_CHAR -> DATE_FORMAT
358            "TO_CHAR" => Ok(Expression::Function(Box::new(Function::new(
359                "DATE_FORMAT".to_string(),
360                f.args,
361            )))),
362
363            // DATE_TRUNC is native to Spark
364            "DATE_TRUNC" => Ok(Expression::Function(Box::new(f))),
365
366            // TRUNC is native to Spark
367            "TRUNC" => Ok(Expression::Function(Box::new(f))),
368
369            // EXTRACT is native to Spark
370            "EXTRACT" => Ok(Expression::Function(Box::new(f))),
371
372            // DATEPART -> EXTRACT
373            "DATEPART" => Ok(Expression::Function(Box::new(Function::new(
374                "EXTRACT".to_string(),
375                f.args,
376            )))),
377
378            // UNIX_TIMESTAMP is native to Spark
379            // When called with no args, add CURRENT_TIMESTAMP() as default
380            "UNIX_TIMESTAMP" => {
381                if f.args.is_empty() {
382                    Ok(Expression::Function(Box::new(Function::new(
383                        "UNIX_TIMESTAMP".to_string(),
384                        vec![Expression::CurrentTimestamp(CurrentTimestamp {
385                            precision: None,
386                            sysdate: false,
387                        })],
388                    ))))
389                } else {
390                    Ok(Expression::Function(Box::new(f)))
391                }
392            }
393
394            // FROM_UNIXTIME is native to Spark
395            "FROM_UNIXTIME" => Ok(Expression::Function(Box::new(f))),
396
397            // STR_TO_MAP is native to Spark
398            // When called with only one arg, add default delimiters ',' and ':'
399            "STR_TO_MAP" => {
400                if f.args.len() == 1 {
401                    let mut args = f.args;
402                    args.push(Expression::Literal(crate::expressions::Literal::String(
403                        ",".to_string(),
404                    )));
405                    args.push(Expression::Literal(crate::expressions::Literal::String(
406                        ":".to_string(),
407                    )));
408                    Ok(Expression::Function(Box::new(Function::new(
409                        "STR_TO_MAP".to_string(),
410                        args,
411                    ))))
412                } else {
413                    Ok(Expression::Function(Box::new(f)))
414                }
415            }
416
417            // POSITION is native to Spark (POSITION(substr IN str))
418            "POSITION" => Ok(Expression::Function(Box::new(f))),
419
420            // LOCATE is native to Spark
421            "LOCATE" => Ok(Expression::Function(Box::new(f))),
422
423            // STRPOS -> Use expression form or LOCATE
424            "STRPOS" if f.args.len() == 2 => {
425                let mut args = f.args;
426                let first = args.remove(0);
427                let second = args.remove(0);
428                // LOCATE(substr, str) in Spark
429                Ok(Expression::Function(Box::new(Function::new(
430                    "LOCATE".to_string(),
431                    vec![second, first],
432                ))))
433            }
434
435            // CHARINDEX -> LOCATE
436            "CHARINDEX" if f.args.len() >= 2 => {
437                let mut args = f.args;
438                let substring = args.remove(0);
439                let string = args.remove(0);
440                let mut locate_args = vec![substring, string];
441                if !args.is_empty() {
442                    locate_args.push(args.remove(0));
443                }
444                Ok(Expression::Function(Box::new(Function::new(
445                    "LOCATE".to_string(),
446                    locate_args,
447                ))))
448            }
449
450            // INSTR is native to Spark
451            "INSTR" => Ok(Expression::Function(Box::new(f))),
452
453            // CEILING -> CEIL
454            "CEILING" if f.args.len() == 1 => Ok(Expression::Ceil(Box::new(CeilFunc {
455                this: f.args.into_iter().next().unwrap(),
456                decimals: None,
457                to: None,
458            }))),
459
460            // CEIL is native to Spark
461            "CEIL" if f.args.len() == 1 => Ok(Expression::Ceil(Box::new(CeilFunc {
462                this: f.args.into_iter().next().unwrap(),
463                decimals: None,
464                to: None,
465            }))),
466
467            // UNNEST -> EXPLODE
468            "UNNEST" => Ok(Expression::Function(Box::new(Function::new(
469                "EXPLODE".to_string(),
470                f.args,
471            )))),
472
473            // FLATTEN -> FLATTEN is native to Spark (for nested arrays)
474            "FLATTEN" => Ok(Expression::Function(Box::new(f))),
475
476            // ARRAY_AGG -> COLLECT_LIST
477            "ARRAY_AGG" => Ok(Expression::Function(Box::new(Function::new(
478                "COLLECT_LIST".to_string(),
479                f.args,
480            )))),
481
482            // COLLECT_LIST is native to Spark
483            "COLLECT_LIST" => Ok(Expression::Function(Box::new(f))),
484
485            // COLLECT_SET is native to Spark
486            "COLLECT_SET" => Ok(Expression::Function(Box::new(f))),
487
488            // ARRAY_LENGTH -> SIZE in Spark
489            "ARRAY_LENGTH" | "CARDINALITY" => Ok(Expression::Function(Box::new(Function::new(
490                "SIZE".to_string(),
491                f.args,
492            )))),
493
494            // SIZE is native to Spark
495            "SIZE" => Ok(Expression::Function(Box::new(f))),
496
497            // SPLIT is native to Spark
498            "SPLIT" => Ok(Expression::Function(Box::new(f))),
499
500            // REGEXP_REPLACE: Spark supports up to 4 args (subject, pattern, replacement, position)
501            // Strip extra Snowflake args (occurrence, params) if present
502            "REGEXP_REPLACE" if f.args.len() > 4 => {
503                let mut args = f.args;
504                args.truncate(4);
505                Ok(Expression::Function(Box::new(Function::new(
506                    "REGEXP_REPLACE".to_string(),
507                    args,
508                ))))
509            }
510            "REGEXP_REPLACE" => Ok(Expression::Function(Box::new(f))),
511
512            // REGEXP_EXTRACT is native to Spark
513            "REGEXP_EXTRACT" => Ok(Expression::Function(Box::new(f))),
514
515            // REGEXP_EXTRACT_ALL is native to Spark
516            "REGEXP_EXTRACT_ALL" => Ok(Expression::Function(Box::new(f))),
517
518            // RLIKE is native to Spark
519            "RLIKE" | "REGEXP_LIKE" => Ok(Expression::Function(Box::new(Function::new(
520                "RLIKE".to_string(),
521                f.args,
522            )))),
523
524            // JSON_EXTRACT -> GET_JSON_OBJECT (Hive style) or :: operator
525            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(Function::new(
526                "GET_JSON_OBJECT".to_string(),
527                f.args,
528            )))),
529
530            // JSON_EXTRACT_SCALAR -> GET_JSON_OBJECT
531            "JSON_EXTRACT_SCALAR" => Ok(Expression::Function(Box::new(Function::new(
532                "GET_JSON_OBJECT".to_string(),
533                f.args,
534            )))),
535
536            // GET_JSON_OBJECT is native to Spark
537            "GET_JSON_OBJECT" => Ok(Expression::Function(Box::new(f))),
538
539            // FROM_JSON is native to Spark
540            "FROM_JSON" => Ok(Expression::Function(Box::new(f))),
541
542            // TO_JSON is native to Spark
543            "TO_JSON" => Ok(Expression::Function(Box::new(f))),
544
545            // PARSE_JSON -> strip for Spark (just keep the string argument)
546            "PARSE_JSON" if f.args.len() == 1 => Ok(f.args.into_iter().next().unwrap()),
547            "PARSE_JSON" => Ok(Expression::Function(Box::new(Function::new(
548                "FROM_JSON".to_string(),
549                f.args,
550            )))),
551
552            // DATEDIFF is native to Spark (supports unit in Spark 3+)
553            "DATEDIFF" | "DATE_DIFF" => Ok(Expression::Function(Box::new(Function::new(
554                "DATEDIFF".to_string(),
555                f.args,
556            )))),
557
558            // DATE_ADD is native to Spark
559            "DATE_ADD" | "DATEADD" => Ok(Expression::Function(Box::new(Function::new(
560                "DATE_ADD".to_string(),
561                f.args,
562            )))),
563
564            // DATE_SUB is native to Spark
565            "DATE_SUB" => Ok(Expression::Function(Box::new(f))),
566
567            // TIMESTAMPADD is native to Spark 3+
568            "TIMESTAMPADD" => Ok(Expression::Function(Box::new(f))),
569
570            // TIMESTAMPDIFF is native to Spark 3+
571            "TIMESTAMPDIFF" => Ok(Expression::Function(Box::new(f))),
572
573            // ADD_MONTHS is native to Spark
574            "ADD_MONTHS" => Ok(Expression::Function(Box::new(f))),
575
576            // MONTHS_BETWEEN is native to Spark
577            "MONTHS_BETWEEN" => Ok(Expression::Function(Box::new(f))),
578
579            // NVL is native to Spark
580            "NVL" => Ok(Expression::Function(Box::new(f))),
581
582            // NVL2 is native to Spark
583            "NVL2" => Ok(Expression::Function(Box::new(f))),
584
585            // MAP is native to Spark
586            "MAP" => Ok(Expression::Function(Box::new(f))),
587
588            // ARRAY is native to Spark
589            "ARRAY" => Ok(Expression::Function(Box::new(f))),
590
591            // ROW -> STRUCT for Spark (cross-dialect, no auto-naming)
592            "ROW" => Ok(Expression::Function(Box::new(Function::new(
593                "STRUCT".to_string(),
594                f.args,
595            )))),
596
597            // STRUCT is native to Spark - auto-name unnamed args as col1, col2, etc.
598            "STRUCT" => {
599                let mut col_idx = 1usize;
600                let named_args: Vec<Expression> = f
601                    .args
602                    .into_iter()
603                    .map(|arg| {
604                        let current_idx = col_idx;
605                        col_idx += 1;
606                        // Check if arg already has an alias (AS name) or is Star
607                        match &arg {
608                            Expression::Alias(_) => arg, // already named
609                            Expression::Star(_) => arg,  // STRUCT(*) - keep as-is
610                            Expression::Column(c) if c.table.is_none() => {
611                                // Column reference: use column name as the struct field name
612                                let name = c.name.name.clone();
613                                Expression::Alias(Box::new(crate::expressions::Alias {
614                                    this: arg,
615                                    alias: crate::expressions::Identifier::new(&name),
616                                    column_aliases: Vec::new(),
617                                    pre_alias_comments: Vec::new(),
618                                    trailing_comments: Vec::new(),
619                                }))
620                            }
621                            _ => {
622                                // Unnamed literal/expression: auto-name as colN
623                                let name = format!("col{}", current_idx);
624                                Expression::Alias(Box::new(crate::expressions::Alias {
625                                    this: arg,
626                                    alias: crate::expressions::Identifier::new(&name),
627                                    column_aliases: Vec::new(),
628                                    pre_alias_comments: Vec::new(),
629                                    trailing_comments: Vec::new(),
630                                }))
631                            }
632                        }
633                    })
634                    .collect();
635                Ok(Expression::Function(Box::new(Function {
636                    name: "STRUCT".to_string(),
637                    args: named_args,
638                    distinct: false,
639                    trailing_comments: Vec::new(),
640                    use_bracket_syntax: false,
641                    no_parens: false,
642                    quoted: false,
643                })))
644            }
645
646            // NAMED_STRUCT is native to Spark
647            "NAMED_STRUCT" => Ok(Expression::Function(Box::new(f))),
648
649            // MAP_FROM_ARRAYS is native to Spark
650            "MAP_FROM_ARRAYS" => Ok(Expression::Function(Box::new(f))),
651
652            // ARRAY_SORT is native to Spark
653            "ARRAY_SORT" => Ok(Expression::Function(Box::new(f))),
654
655            // ARRAY_DISTINCT is native to Spark
656            "ARRAY_DISTINCT" => Ok(Expression::Function(Box::new(f))),
657
658            // ARRAY_UNION is native to Spark
659            "ARRAY_UNION" => Ok(Expression::Function(Box::new(f))),
660
661            // ARRAY_INTERSECT is native to Spark
662            "ARRAY_INTERSECT" => Ok(Expression::Function(Box::new(f))),
663
664            // ARRAY_EXCEPT is native to Spark
665            "ARRAY_EXCEPT" => Ok(Expression::Function(Box::new(f))),
666
667            // ARRAY_CONTAINS is native to Spark
668            "ARRAY_CONTAINS" => Ok(Expression::Function(Box::new(f))),
669
670            // ELEMENT_AT is native to Spark
671            "ELEMENT_AT" => Ok(Expression::Function(Box::new(f))),
672
673            // TRY_ELEMENT_AT is native to Spark 3+
674            "TRY_ELEMENT_AT" => Ok(Expression::Function(Box::new(f))),
675
676            // TRANSFORM is native to Spark (array transformation)
677            "TRANSFORM" => Ok(Expression::Function(Box::new(f))),
678
679            // FILTER is native to Spark (array filtering)
680            "FILTER" => Ok(Expression::Function(Box::new(f))),
681
682            // AGGREGATE is native to Spark (array reduction)
683            "AGGREGATE" => Ok(Expression::Function(Box::new(f))),
684
685            // SEQUENCE is native to Spark (generate array)
686            "SEQUENCE" => Ok(Expression::Function(Box::new(f))),
687
688            // GENERATE_SERIES -> SEQUENCE
689            "GENERATE_SERIES" => Ok(Expression::Function(Box::new(Function::new(
690                "SEQUENCE".to_string(),
691                f.args,
692            )))),
693
694            // STARTSWITH is native to Spark 3+
695            "STARTSWITH" | "STARTS_WITH" => Ok(Expression::Function(Box::new(Function::new(
696                "STARTSWITH".to_string(),
697                f.args,
698            )))),
699
700            // ENDSWITH is native to Spark 3+
701            "ENDSWITH" | "ENDS_WITH" => Ok(Expression::Function(Box::new(Function::new(
702                "ENDSWITH".to_string(),
703                f.args,
704            )))),
705
706            // ARRAY_CONSTRUCT_COMPACT(1, null, 2) -> ARRAY_COMPACT(ARRAY(1, NULL, 2))
707            "ARRAY_CONSTRUCT_COMPACT" => {
708                let inner =
709                    Expression::Function(Box::new(Function::new("ARRAY".to_string(), f.args)));
710                Ok(Expression::Function(Box::new(Function::new(
711                    "ARRAY_COMPACT".to_string(),
712                    vec![inner],
713                ))))
714            }
715
716            // ARRAY_TO_STRING -> ARRAY_JOIN
717            "ARRAY_TO_STRING" => Ok(Expression::Function(Box::new(Function::new(
718                "ARRAY_JOIN".to_string(),
719                f.args,
720            )))),
721
722            // TO_ARRAY(x) -> IF(x IS NULL, NULL, ARRAY(x))
723            "TO_ARRAY" if f.args.len() == 1 => {
724                let x = f.args[0].clone();
725                // Check if arg is already an array constructor (bracket notation)
726                // In that case: TO_ARRAY(['test']) -> ARRAY('test')
727                match &x {
728                    Expression::ArrayFunc(arr) => {
729                        // Just convert to ARRAY(...) function
730                        Ok(Expression::Function(Box::new(Function::new(
731                            "ARRAY".to_string(),
732                            arr.expressions.clone(),
733                        ))))
734                    }
735                    _ => Ok(Expression::IfFunc(Box::new(crate::expressions::IfFunc {
736                        condition: Expression::IsNull(Box::new(crate::expressions::IsNull {
737                            this: x.clone(),
738                            not: false,
739                            postfix_form: false,
740                        })),
741                        true_value: Expression::Null(crate::expressions::Null),
742                        false_value: Some(Expression::Function(Box::new(Function::new(
743                            "ARRAY".to_string(),
744                            vec![x],
745                        )))),
746                        original_name: Some("IF".to_string()),
747                    }))),
748                }
749            }
750
751            // REGEXP_SUBSTR -> REGEXP_EXTRACT (strip extra args)
752            "REGEXP_SUBSTR" if f.args.len() >= 2 => {
753                let subject = f.args[0].clone();
754                let pattern = f.args[1].clone();
755                // For Spark: REGEXP_EXTRACT(subject, pattern, group)
756                // group defaults to 0 for full match, but sqlglot uses last arg if present
757                let group = if f.args.len() >= 6 {
758                    let g = &f.args[5];
759                    // If group is literal 1 (default), omit it
760                    if matches!(g, Expression::Literal(Literal::Number(n)) if n == "1") {
761                        None
762                    } else {
763                        Some(g.clone())
764                    }
765                } else {
766                    None
767                };
768                let mut args = vec![subject, pattern];
769                if let Some(g) = group {
770                    args.push(g);
771                }
772                Ok(Expression::Function(Box::new(Function::new(
773                    "REGEXP_EXTRACT".to_string(),
774                    args,
775                ))))
776            }
777
778            // UUID_STRING -> UUID()
779            "UUID_STRING" => Ok(Expression::Function(Box::new(Function::new(
780                "UUID".to_string(),
781                vec![],
782            )))),
783
784            // OBJECT_CONSTRUCT -> STRUCT in Spark
785            "OBJECT_CONSTRUCT" if f.args.len() >= 2 && f.args.len() % 2 == 0 => {
786                // Convert key-value pairs to named struct fields
787                // OBJECT_CONSTRUCT('Manitoba', 'Winnipeg', 'foo', 'bar')
788                // -> STRUCT('Winnipeg' AS Manitoba, 'bar' AS foo)
789                let mut struct_args = Vec::new();
790                for pair in f.args.chunks(2) {
791                    if let Expression::Literal(Literal::String(key)) = &pair[0] {
792                        struct_args.push(Expression::Alias(Box::new(crate::expressions::Alias {
793                            this: pair[1].clone(),
794                            alias: crate::expressions::Identifier::new(key.clone()),
795                            column_aliases: vec![],
796                            pre_alias_comments: vec![],
797                            trailing_comments: vec![],
798                        })));
799                    } else {
800                        struct_args.push(pair[1].clone());
801                    }
802                }
803                Ok(Expression::Function(Box::new(Function::new(
804                    "STRUCT".to_string(),
805                    struct_args,
806                ))))
807            }
808
809            // DATE_PART(part, expr) -> EXTRACT(part FROM expr)
810            "DATE_PART" if f.args.len() == 2 => {
811                let mut args = f.args;
812                let part = args.remove(0);
813                let expr = args.remove(0);
814                if let Some(field) = expr_to_datetime_field(&part) {
815                    Ok(Expression::Extract(Box::new(ExtractFunc {
816                        this: expr,
817                        field,
818                    })))
819                } else {
820                    // Can't parse the field, keep as function
821                    Ok(Expression::Function(Box::new(Function::new(
822                        "DATE_PART".to_string(),
823                        vec![part, expr],
824                    ))))
825                }
826            }
827
828            // GET_PATH(obj, path) -> GET_JSON_OBJECT(obj, json_path) in Spark
829            "GET_PATH" if f.args.len() == 2 => {
830                let mut args = f.args;
831                let this = args.remove(0);
832                let path = args.remove(0);
833                let json_path = match &path {
834                    Expression::Literal(Literal::String(s)) => {
835                        let normalized = if s.starts_with('$') {
836                            s.clone()
837                        } else if s.starts_with('[') {
838                            format!("${}", s)
839                        } else {
840                            format!("$.{}", s)
841                        };
842                        Expression::Literal(Literal::String(normalized))
843                    }
844                    _ => path,
845                };
846                Ok(Expression::Function(Box::new(Function::new(
847                    "GET_JSON_OBJECT".to_string(),
848                    vec![this, json_path],
849                ))))
850            }
851
852            // BITWISE_LEFT_SHIFT → SHIFTLEFT
853            "BITWISE_LEFT_SHIFT" => Ok(Expression::Function(Box::new(Function::new(
854                "SHIFTLEFT".to_string(),
855                f.args,
856            )))),
857
858            // BITWISE_RIGHT_SHIFT → SHIFTRIGHT
859            "BITWISE_RIGHT_SHIFT" => Ok(Expression::Function(Box::new(Function::new(
860                "SHIFTRIGHT".to_string(),
861                f.args,
862            )))),
863
864            // APPROX_DISTINCT → APPROX_COUNT_DISTINCT
865            "APPROX_DISTINCT" => Ok(Expression::Function(Box::new(Function::new(
866                "APPROX_COUNT_DISTINCT".to_string(),
867                f.args,
868            )))),
869
870            // ARRAY_SLICE → SLICE
871            "ARRAY_SLICE" => Ok(Expression::Function(Box::new(Function::new(
872                "SLICE".to_string(),
873                f.args,
874            )))),
875
876            // DATE_FROM_PARTS → MAKE_DATE
877            "DATE_FROM_PARTS" => Ok(Expression::Function(Box::new(Function::new(
878                "MAKE_DATE".to_string(),
879                f.args,
880            )))),
881
882            // DAYOFWEEK_ISO → DAYOFWEEK
883            "DAYOFWEEK_ISO" => Ok(Expression::Function(Box::new(Function::new(
884                "DAYOFWEEK".to_string(),
885                f.args,
886            )))),
887
888            // FORMAT → FORMAT_STRING
889            "FORMAT" => Ok(Expression::Function(Box::new(Function::new(
890                "FORMAT_STRING".to_string(),
891                f.args,
892            )))),
893
894            // LOGICAL_AND → BOOL_AND
895            "LOGICAL_AND" => Ok(Expression::Function(Box::new(Function::new(
896                "BOOL_AND".to_string(),
897                f.args,
898            )))),
899
900            // VARIANCE_POP → VAR_POP
901            "VARIANCE_POP" => Ok(Expression::Function(Box::new(Function::new(
902                "VAR_POP".to_string(),
903                f.args,
904            )))),
905
906            // WEEK_OF_YEAR → WEEKOFYEAR
907            "WEEK_OF_YEAR" => Ok(Expression::Function(Box::new(Function::new(
908                "WEEKOFYEAR".to_string(),
909                f.args,
910            )))),
911
912            // Pass through everything else
913            _ => Ok(Expression::Function(Box::new(f))),
914        }
915    }
916
917    fn transform_aggregate_function(
918        &self,
919        f: Box<crate::expressions::AggregateFunction>,
920    ) -> Result<Expression> {
921        let name_upper = f.name.to_uppercase();
922        match name_upper.as_str() {
923            // GROUP_CONCAT -> COLLECT_LIST (then CONCAT_WS for string)
924            "GROUP_CONCAT" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
925                Function::new("COLLECT_LIST".to_string(), f.args),
926            ))),
927
928            // STRING_AGG -> COLLECT_LIST (or STRING_AGG in Spark 4+)
929            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
930                Function::new("COLLECT_LIST".to_string(), f.args),
931            ))),
932
933            // LISTAGG -> COLLECT_LIST
934            "LISTAGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
935                "COLLECT_LIST".to_string(),
936                f.args,
937            )))),
938
939            // ARRAY_AGG -> COLLECT_LIST (preserve distinct and filter)
940            "ARRAY_AGG" if !f.args.is_empty() => {
941                let mut af = f;
942                af.name = "COLLECT_LIST".to_string();
943                Ok(Expression::AggregateFunction(af))
944            }
945
946            // LOGICAL_OR -> BOOL_OR in Spark
947            "LOGICAL_OR" if !f.args.is_empty() => {
948                let mut af = f;
949                af.name = "BOOL_OR".to_string();
950                Ok(Expression::AggregateFunction(af))
951            }
952
953            // Pass through everything else
954            _ => Ok(Expression::AggregateFunction(f)),
955        }
956    }
957}
958
959/// Convert an expression (string literal or identifier) to a DateTimeField
960fn expr_to_datetime_field(expr: &Expression) -> Option<DateTimeField> {
961    let name = match expr {
962        Expression::Literal(Literal::String(s)) => s.to_uppercase(),
963        Expression::Identifier(id) => id.name.to_uppercase(),
964        Expression::Column(col) if col.table.is_none() => col.name.name.to_uppercase(),
965        _ => return None,
966    };
967    match name.as_str() {
968        "YEAR" | "Y" | "YY" | "YYY" | "YYYY" | "YR" | "YEARS" | "YRS" => Some(DateTimeField::Year),
969        "MONTH" | "MM" | "MON" | "MONS" | "MONTHS" => Some(DateTimeField::Month),
970        "DAY" | "D" | "DD" | "DAYS" | "DAYOFMONTH" => Some(DateTimeField::Day),
971        "HOUR" | "H" | "HH" | "HR" | "HOURS" | "HRS" => Some(DateTimeField::Hour),
972        "MINUTE" | "MI" | "MIN" | "MINUTES" | "MINS" => Some(DateTimeField::Minute),
973        "SECOND" | "S" | "SEC" | "SECONDS" | "SECS" => Some(DateTimeField::Second),
974        "MILLISECOND" | "MS" | "MSEC" | "MILLISECONDS" => Some(DateTimeField::Millisecond),
975        "MICROSECOND" | "US" | "USEC" | "MICROSECONDS" => Some(DateTimeField::Microsecond),
976        "DOW" | "DAYOFWEEK" | "DAYOFWEEK_ISO" | "DW" => Some(DateTimeField::DayOfWeek),
977        "DOY" | "DAYOFYEAR" => Some(DateTimeField::DayOfYear),
978        "WEEK" | "W" | "WK" | "WEEKOFYEAR" | "WOY" => Some(DateTimeField::Week),
979        "QUARTER" | "Q" | "QTR" | "QTRS" | "QUARTERS" => Some(DateTimeField::Quarter),
980        "EPOCH" | "EPOCH_SECOND" | "EPOCH_SECONDS" => Some(DateTimeField::Epoch),
981        "TIMEZONE" | "TIMEZONE_HOUR" | "TZH" => Some(DateTimeField::TimezoneHour),
982        "TIMEZONE_MINUTE" | "TZM" => Some(DateTimeField::TimezoneMinute),
983        _ => Some(DateTimeField::Custom(name)),
984    }
985}