Skip to main content

polyglot_sql/dialects/
mysql.rs

1//! MySQL Dialect
2//!
3//! MySQL-specific transformations based on sqlglot patterns.
4//! Key differences from standard SQL:
5//! - || is OR operator, not string concatenation (use CONCAT)
6//! - Uses backticks for identifiers
7//! - No TRY_CAST, no ILIKE
8//! - Different date/time function names
9
10use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{
13    BinaryFunc, BinaryOp, Cast, DataType, Expression, Function, Interval, IntervalUnit,
14    IntervalUnitSpec, JsonExtractFunc, LikeOp, Literal, MakeInterval, Paren, UnaryFunc,
15};
16#[cfg(feature = "generate")]
17use crate::generator::GeneratorConfig;
18use crate::tokens::TokenizerConfig;
19
20/// Helper to wrap JSON arrow expressions in parentheses when they appear
21/// in contexts that require it (Binary, In, Not expressions)
22/// This matches Python sqlglot's WRAPPED_JSON_EXTRACT_EXPRESSIONS behavior
23fn wrap_if_json_arrow(expr: Expression) -> Expression {
24    match &expr {
25        Expression::JsonExtract(f) if f.arrow_syntax => Expression::Paren(Box::new(Paren {
26            this: expr,
27            trailing_comments: Vec::new(),
28        })),
29        Expression::JsonExtractScalar(f) if f.arrow_syntax => Expression::Paren(Box::new(Paren {
30            this: expr,
31            trailing_comments: Vec::new(),
32        })),
33        _ => expr,
34    }
35}
36
37/// Convert JSON arrow expression (-> or ->>) to JSON_EXTRACT function form
38/// This is needed for contexts like MEMBER OF where arrow syntax must become function form
39fn json_arrow_to_function(expr: Expression) -> Expression {
40    match expr {
41        Expression::JsonExtract(f) if f.arrow_syntax => Expression::Function(Box::new(
42            Function::new("JSON_EXTRACT".to_string(), vec![f.this, f.path]),
43        )),
44        Expression::JsonExtractScalar(f) if f.arrow_syntax => {
45            // ->> becomes JSON_UNQUOTE(JSON_EXTRACT(...)) but can be simplified to JSON_EXTRACT_SCALAR
46            // For MySQL, use JSON_UNQUOTE(JSON_EXTRACT(...))
47            let json_extract = Expression::Function(Box::new(Function::new(
48                "JSON_EXTRACT".to_string(),
49                vec![f.this, f.path],
50            )));
51            Expression::Function(Box::new(Function::new(
52                "JSON_UNQUOTE".to_string(),
53                vec![json_extract],
54            )))
55        }
56        other => other,
57    }
58}
59
/// MySQL dialect.
///
/// A field-less unit struct: all MySQL-specific behavior lives in its `DialectImpl`
/// implementation (tokenizer/generator configuration and expression transforms) and in the
/// inherent helper methods defined on it.
pub struct MySQLDialect;
62
63impl DialectImpl for MySQLDialect {
64    fn dialect_type(&self) -> DialectType {
65        DialectType::MySQL
66    }
67
68    fn tokenizer_config(&self) -> TokenizerConfig {
69        use crate::tokens::TokenType;
70        let mut config = TokenizerConfig::default();
71        // MySQL uses backticks for identifiers
72        config.identifiers.insert('`', '`');
73        // Remove double quotes from identifiers - in MySQL they are string delimiters
74        // (unless ANSI_QUOTES mode is set, but default mode uses them as strings)
75        config.identifiers.remove(&'"');
76        // MySQL supports double quotes as string literals by default
77        config.quotes.insert("\"".to_string(), "\"".to_string());
78        // MySQL supports backslash escapes in strings
79        config.string_escapes.push('\\');
80        // MySQL has XOR as a logical operator keyword
81        config.keywords.insert("XOR".to_string(), TokenType::Xor);
82        // MySQL: backslash followed by chars NOT in this list -> discard backslash
83        // See: https://dev.mysql.com/doc/refman/8.4/en/string-literals.html
84        config.escape_follow_chars = vec!['0', 'b', 'n', 'r', 't', 'Z', '%', '_'];
85        // MySQL allows identifiers to start with digits (e.g., 1a, 1_a)
86        config.identifiers_can_start_with_digit = true;
87        config.hex_number_strings = true;
88        config
89    }
90
91    #[cfg(feature = "generate")]
92
93    fn generator_config(&self) -> GeneratorConfig {
94        use crate::generator::IdentifierQuoteStyle;
95        GeneratorConfig {
96            identifier_quote: '`',
97            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
98            dialect: Some(DialectType::MySQL),
99            // MySQL doesn't support null ordering in most contexts
100            null_ordering_supported: false,
101            // MySQL LIMIT only
102            limit_only_literals: true,
103            // MySQL doesn't support semi/anti join
104            semi_anti_join_with_side: false,
105            // MySQL doesn't support table alias columns in some contexts
106            supports_table_alias_columns: false,
107            // MySQL VALUES not used as table
108            values_as_table: false,
109            // MySQL doesn't support TABLESAMPLE
110            tablesample_requires_parens: false,
111            tablesample_with_method: false,
112            // MySQL doesn't support aggregate FILTER
113            aggregate_filter_supported: false,
114            // MySQL doesn't support TRY
115            try_supported: false,
116            // MySQL doesn't support CONVERT_TIMEZONE
117            supports_convert_timezone: false,
118            // MySQL doesn't support UESCAPE
119            supports_uescape: false,
120            // MySQL doesn't support BETWEEN flags
121            supports_between_flags: false,
122            // MySQL supports EXPLAIN but not query hints in standard way
123            query_hints: false,
124            // MySQL parameter token
125            parameter_token: "?",
126            // MySQL doesn't support window EXCLUDE
127            supports_window_exclude: false,
128            // MySQL doesn't support exploding projections
129            supports_exploding_projections: false,
130            identifiers_can_start_with_digit: true,
131            // MySQL supports FOR UPDATE/SHARE
132            locking_reads_supported: true,
133            ..Default::default()
134        }
135    }
136
137    #[cfg(feature = "transpile")]
138
139    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
140        match expr {
141            // ===== Data Type Mappings =====
142            Expression::DataType(dt) => self.transform_data_type(dt),
143
144            // NVL -> IFNULL in MySQL
145            Expression::Nvl(f) => Ok(Expression::IfNull(f)),
146
147            // Note: COALESCE is valid in MySQL and should be preserved.
148            // Unlike some other dialects, we do NOT convert COALESCE to IFNULL
149            // as this would break identity tests.
150
151            // TryCast -> CAST or TIMESTAMP() (MySQL doesn't support TRY_CAST)
152            Expression::TryCast(c) => self.transform_cast(*c),
153
154            // SafeCast -> CAST or TIMESTAMP() (MySQL doesn't support safe casts)
155            Expression::SafeCast(c) => self.transform_cast(*c),
156
157            // Cast -> Transform cast type according to MySQL restrictions
158            // CAST AS TIMESTAMP -> TIMESTAMP() function in MySQL
159            Expression::Cast(c) => self.transform_cast(*c),
160
161            // ILIKE -> LOWER() LIKE LOWER() in MySQL
162            Expression::ILike(op) => {
163                // Transform ILIKE to: LOWER(left) LIKE LOWER(right)
164                let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left)));
165                let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right)));
166                Ok(Expression::Like(Box::new(LikeOp {
167                    left: lower_left,
168                    right: lower_right,
169                    escape: op.escape,
170                    quantifier: op.quantifier,
171                    inferred_type: None,
172                })))
173            }
174
175            // Preserve semantic string concatenation expressions.
176            // MySQL generation renders these as CONCAT(...).
177            Expression::Concat(op) => Ok(Expression::Concat(op)),
178
179            // PostgreSQL MAKE_INTERVAL -> MySQL interval arithmetic.
180            Expression::Add(op) => self.transform_make_interval_binary(op, false),
181            Expression::Sub(op) => self.transform_make_interval_binary(op, true),
182
183            // RANDOM -> RAND in MySQL
184            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
185                seed: None,
186                lower: None,
187                upper: None,
188            }))),
189
190            // ArrayAgg -> GROUP_CONCAT in MySQL
191            Expression::ArrayAgg(f) => Ok(Expression::Function(Box::new(Function::new(
192                "GROUP_CONCAT".to_string(),
193                vec![f.this],
194            )))),
195
196            // StringAgg -> GROUP_CONCAT in MySQL
197            Expression::StringAgg(f) => {
198                let mut args = vec![f.this.clone()];
199                if let Some(separator) = &f.separator {
200                    args.push(separator.clone());
201                }
202                Ok(Expression::Function(Box::new(Function::new(
203                    "GROUP_CONCAT".to_string(),
204                    args,
205                ))))
206            }
207
208            // UNNEST -> Not directly supported in MySQL, use JSON_TABLE or inline
209            // For basic cases, pass through (may need manual handling)
210            Expression::Unnest(f) => {
211                // MySQL 8.0+ has JSON_TABLE which can be used for unnesting
212                // For now, pass through with a function call
213                Ok(Expression::Function(Box::new(Function::new(
214                    "JSON_TABLE".to_string(),
215                    vec![f.this],
216                ))))
217            }
218
219            // Substring: Use comma syntax (not FROM/FOR) in MySQL
220            Expression::Substring(mut f) => {
221                f.from_for_syntax = false;
222                Ok(Expression::Substring(f))
223            }
224
225            // ===== Bitwise operations =====
226            // BitwiseAndAgg -> BIT_AND
227            Expression::BitwiseAndAgg(f) => Ok(Expression::Function(Box::new(Function::new(
228                "BIT_AND".to_string(),
229                vec![f.this],
230            )))),
231
232            // BitwiseOrAgg -> BIT_OR
233            Expression::BitwiseOrAgg(f) => Ok(Expression::Function(Box::new(Function::new(
234                "BIT_OR".to_string(),
235                vec![f.this],
236            )))),
237
238            // BitwiseXorAgg -> BIT_XOR
239            Expression::BitwiseXorAgg(f) => Ok(Expression::Function(Box::new(Function::new(
240                "BIT_XOR".to_string(),
241                vec![f.this],
242            )))),
243
244            // BitwiseCount -> BIT_COUNT
245            Expression::BitwiseCount(f) => Ok(Expression::Function(Box::new(Function::new(
246                "BIT_COUNT".to_string(),
247                vec![f.this],
248            )))),
249
250            // TimeFromParts -> MAKETIME
251            Expression::TimeFromParts(f) => {
252                let mut args = Vec::new();
253                if let Some(h) = f.hour {
254                    args.push(*h);
255                }
256                if let Some(m) = f.min {
257                    args.push(*m);
258                }
259                if let Some(s) = f.sec {
260                    args.push(*s);
261                }
262                Ok(Expression::Function(Box::new(Function::new(
263                    "MAKETIME".to_string(),
264                    args,
265                ))))
266            }
267
268            // ===== Boolean aggregates =====
269            // In MySQL, there's no BOOL_AND/BOOL_OR, use MIN/MAX on boolean values
270            // LogicalAnd -> MIN (0 is false, non-0 is true)
271            Expression::LogicalAnd(f) => Ok(Expression::Function(Box::new(Function::new(
272                "MIN".to_string(),
273                vec![f.this],
274            )))),
275
276            // LogicalOr -> MAX
277            Expression::LogicalOr(f) => Ok(Expression::Function(Box::new(Function::new(
278                "MAX".to_string(),
279                vec![f.this],
280            )))),
281
282            // ===== Date/time functions =====
283            // DayOfMonth -> DAYOFMONTH
284            Expression::DayOfMonth(f) => Ok(Expression::Function(Box::new(Function::new(
285                "DAYOFMONTH".to_string(),
286                vec![f.this],
287            )))),
288
289            // DayOfWeek -> DAYOFWEEK
290            Expression::DayOfWeek(f) => Ok(Expression::Function(Box::new(Function::new(
291                "DAYOFWEEK".to_string(),
292                vec![f.this],
293            )))),
294
295            // DayOfYear -> DAYOFYEAR
296            Expression::DayOfYear(f) => Ok(Expression::Function(Box::new(Function::new(
297                "DAYOFYEAR".to_string(),
298                vec![f.this],
299            )))),
300
301            // WeekOfYear -> WEEKOFYEAR
302            Expression::WeekOfYear(f) => Ok(Expression::Function(Box::new(Function::new(
303                "WEEKOFYEAR".to_string(),
304                vec![f.this],
305            )))),
306
307            // DateDiff -> DATEDIFF
308            Expression::DateDiff(f) => Ok(Expression::Function(Box::new(Function::new(
309                "DATEDIFF".to_string(),
310                vec![f.this, f.expression],
311            )))),
312
313            // TimeStrToUnix -> UNIX_TIMESTAMP
314            Expression::TimeStrToUnix(f) => Ok(Expression::Function(Box::new(Function::new(
315                "UNIX_TIMESTAMP".to_string(),
316                vec![f.this],
317            )))),
318
319            // TimestampDiff -> TIMESTAMPDIFF
320            Expression::TimestampDiff(f) => Ok(Expression::Function(Box::new(Function::new(
321                "TIMESTAMPDIFF".to_string(),
322                vec![*f.this, *f.expression],
323            )))),
324
325            // ===== String functions =====
326            // StrPosition -> LOCATE in MySQL
327            // STRPOS(str, substr) -> LOCATE(substr, str) (args are swapped)
328            Expression::StrPosition(f) => {
329                let mut args = vec![];
330                if let Some(substr) = f.substr {
331                    args.push(*substr);
332                }
333                args.push(*f.this);
334                if let Some(pos) = f.position {
335                    args.push(*pos);
336                }
337                Ok(Expression::Function(Box::new(Function::new(
338                    "LOCATE".to_string(),
339                    args,
340                ))))
341            }
342
343            // Stuff -> INSERT in MySQL
344            Expression::Stuff(f) => {
345                let mut args = vec![*f.this];
346                if let Some(start) = f.start {
347                    args.push(*start);
348                }
349                if let Some(length) = f.length {
350                    args.push(Expression::number(length));
351                }
352                args.push(*f.expression);
353                Ok(Expression::Function(Box::new(Function::new(
354                    "INSERT".to_string(),
355                    args,
356                ))))
357            }
358
359            // ===== Session/User functions =====
360            // SessionUser -> SESSION_USER()
361            Expression::SessionUser(_) => Ok(Expression::Function(Box::new(Function::new(
362                "SESSION_USER".to_string(),
363                vec![],
364            )))),
365
366            // CurrentDate -> CURRENT_DATE (no parentheses in MySQL) - keep as CurrentDate
367            Expression::CurrentDate(_) => {
368                Ok(Expression::CurrentDate(crate::expressions::CurrentDate))
369            }
370
371            // ===== Null-safe comparison =====
372            // NullSafeNeq -> NOT (a <=> b) in MySQL
373            Expression::NullSafeNeq(op) => {
374                // Create: NOT (left <=> right)
375                let null_safe_eq = Expression::NullSafeEq(Box::new(crate::expressions::BinaryOp {
376                    left: op.left,
377                    right: op.right,
378                    left_comments: Vec::new(),
379                    operator_comments: Vec::new(),
380                    trailing_comments: Vec::new(),
381                    inferred_type: None,
382                }));
383                Ok(Expression::Not(Box::new(crate::expressions::UnaryOp {
384                    this: null_safe_eq,
385                    inferred_type: None,
386                })))
387            }
388
389            // ParseJson: handled by generator (emits just the string literal for MySQL)
390
391            // JSONExtract with variant_extract (Snowflake colon syntax) -> JSON_EXTRACT
392            Expression::JSONExtract(e) if e.variant_extract.is_some() => {
393                let path = match *e.expression {
394                    Expression::Literal(lit) if matches!(lit.as_ref(), Literal::String(_)) => {
395                        let Literal::String(s) = lit.as_ref() else {
396                            unreachable!()
397                        };
398                        // Convert bracket notation ["key"] to quoted dot notation ."key"
399                        let s = Self::convert_bracket_to_quoted_path(&s);
400                        let normalized = if s.starts_with('$') {
401                            s
402                        } else if s.starts_with('[') {
403                            format!("${}", s)
404                        } else {
405                            format!("$.{}", s)
406                        };
407                        Expression::Literal(Box::new(Literal::String(normalized)))
408                    }
409                    other => other,
410                };
411                Ok(Expression::Function(Box::new(Function::new(
412                    "JSON_EXTRACT".to_string(),
413                    vec![*e.this, path],
414                ))))
415            }
416
417            // Generic function transformations
418            Expression::Function(f) => self.transform_function(*f),
419
420            // Generic aggregate function transformations
421            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
422
423            // ===== Context-aware JSON arrow wrapping =====
424            // When JSON arrow expressions appear in Binary/In/Not contexts,
425            // they need to be wrapped in parentheses for correct precedence.
426            // This matches Python sqlglot's WRAPPED_JSON_EXTRACT_EXPRESSIONS behavior.
427
428            // Binary operators that need JSON wrapping
429            Expression::Eq(op) => Ok(Expression::Eq(Box::new(BinaryOp {
430                left: wrap_if_json_arrow(op.left),
431                right: wrap_if_json_arrow(op.right),
432                ..*op
433            }))),
434            Expression::Neq(op) => Ok(Expression::Neq(Box::new(BinaryOp {
435                left: wrap_if_json_arrow(op.left),
436                right: wrap_if_json_arrow(op.right),
437                ..*op
438            }))),
439            Expression::Lt(op) => Ok(Expression::Lt(Box::new(BinaryOp {
440                left: wrap_if_json_arrow(op.left),
441                right: wrap_if_json_arrow(op.right),
442                ..*op
443            }))),
444            Expression::Lte(op) => Ok(Expression::Lte(Box::new(BinaryOp {
445                left: wrap_if_json_arrow(op.left),
446                right: wrap_if_json_arrow(op.right),
447                ..*op
448            }))),
449            Expression::Gt(op) => Ok(Expression::Gt(Box::new(BinaryOp {
450                left: wrap_if_json_arrow(op.left),
451                right: wrap_if_json_arrow(op.right),
452                ..*op
453            }))),
454            Expression::Gte(op) => Ok(Expression::Gte(Box::new(BinaryOp {
455                left: wrap_if_json_arrow(op.left),
456                right: wrap_if_json_arrow(op.right),
457                ..*op
458            }))),
459
460            // In expression - wrap the this part if it's JSON arrow
461            Expression::In(mut i) => {
462                i.this = wrap_if_json_arrow(i.this);
463                Ok(Expression::In(i))
464            }
465
466            // Not expression - wrap the this part if it's JSON arrow
467            Expression::Not(mut n) => {
468                n.this = wrap_if_json_arrow(n.this);
469                Ok(Expression::Not(n))
470            }
471
472            // && in MySQL is logical AND, not array overlaps
473            // Transform ArrayOverlaps -> And for MySQL identity
474            Expression::ArrayOverlaps(op) => Ok(Expression::And(op)),
475
476            // MOD(x, y) -> x % y in MySQL
477            Expression::ModFunc(f) => Ok(Expression::Mod(Box::new(BinaryOp {
478                left: f.this,
479                right: f.expression,
480                left_comments: Vec::new(),
481                operator_comments: Vec::new(),
482                trailing_comments: Vec::new(),
483                inferred_type: None,
484            }))),
485
486            // SHOW SLAVE STATUS -> SHOW REPLICA STATUS
487            Expression::Show(mut s) => {
488                if s.this == "SLAVE STATUS" {
489                    s.this = "REPLICA STATUS".to_string();
490                }
491                if matches!(s.this.as_str(), "INDEX" | "COLUMNS") && s.db.is_none() {
492                    if let Some(Expression::Table(mut t)) = s.target.take() {
493                        if let Some(db_ident) = t.schema.take().or(t.catalog.take()) {
494                            s.db = Some(Expression::Identifier(db_ident));
495                            s.target = Some(Expression::Identifier(t.name));
496                        } else {
497                            s.target = Some(Expression::Table(t));
498                        }
499                    }
500                }
501                Ok(Expression::Show(s))
502            }
503
504            // AT TIME ZONE -> strip timezone (MySQL doesn't support AT TIME ZONE)
505            // But keep it for CURRENT_DATE/CURRENT_TIMESTAMP with timezone (transpiled from BigQuery)
506            Expression::AtTimeZone(atz) => {
507                let is_current = match &atz.this {
508                    Expression::CurrentDate(_) | Expression::CurrentTimestamp(_) => true,
509                    Expression::Function(f) => {
510                        let n = f.name.to_uppercase();
511                        (n == "CURRENT_DATE" || n == "CURRENT_TIMESTAMP") && f.no_parens
512                    }
513                    _ => false,
514                };
515                if is_current {
516                    Ok(Expression::AtTimeZone(atz)) // Keep AT TIME ZONE for CURRENT_DATE/CURRENT_TIMESTAMP
517                } else {
518                    Ok(atz.this) // Strip timezone for other expressions
519                }
520            }
521
522            // MEMBER OF with JSON arrow -> convert arrow to JSON_EXTRACT function
523            // MySQL's MEMBER OF requires JSON_EXTRACT function form, not arrow syntax
524            Expression::MemberOf(mut op) => {
525                op.right = json_arrow_to_function(op.right);
526                Ok(Expression::MemberOf(op))
527            }
528
529            // Pass through everything else
530            _ => Ok(expr),
531        }
532    }
533}
534
535#[cfg(feature = "transpile")]
536impl MySQLDialect {
537    fn normalize_mysql_date_format(fmt: &str) -> String {
538        fmt.replace("%H:%i:%s", "%T").replace("%H:%i:%S", "%T")
539    }
540
541    /// Convert bracket notation ["key with spaces"] to quoted dot notation ."key with spaces"
542    /// in JSON path strings.
543    fn convert_bracket_to_quoted_path(path: &str) -> String {
544        let mut result = String::new();
545        let mut chars = path.chars().peekable();
546        while let Some(c) = chars.next() {
547            if c == '[' && chars.peek() == Some(&'"') {
548                chars.next(); // consume "
549                let mut key = String::new();
550                while let Some(kc) = chars.next() {
551                    if kc == '"' && chars.peek() == Some(&']') {
552                        chars.next(); // consume ]
553                        break;
554                    }
555                    key.push(kc);
556                }
557                if !result.is_empty() && !result.ends_with('.') {
558                    result.push('.');
559                }
560                result.push('"');
561                result.push_str(&key);
562                result.push('"');
563            } else {
564                result.push(c);
565            }
566        }
567        result
568    }
569
570    /// Transform data types according to MySQL TYPE_MAPPING
571    /// Note: MySQL's TIMESTAMP is kept as TIMESTAMP (not converted to DATETIME)
572    /// because MySQL's TIMESTAMP has timezone awareness built-in
573    fn transform_data_type(&self, dt: crate::expressions::DataType) -> Result<Expression> {
574        use crate::expressions::DataType;
575        let transformed = match dt {
576            // All TIMESTAMP variants (with or without timezone) -> TIMESTAMP in MySQL
577            DataType::Timestamp {
578                precision,
579                timezone: _,
580            } => DataType::Timestamp {
581                precision,
582                timezone: false,
583            },
584            // TIMESTAMPTZ / TIMESTAMPLTZ parsed as Custom -> normalize to TIMESTAMP
585            DataType::Custom { name }
586                if name.to_uppercase() == "TIMESTAMPTZ"
587                    || name.to_uppercase() == "TIMESTAMPLTZ" =>
588            {
589                DataType::Timestamp {
590                    precision: None,
591                    timezone: false,
592                }
593            }
594            // Keep native MySQL types as-is
595            // MySQL supports TEXT, MEDIUMTEXT, LONGTEXT, BLOB, etc. natively
596            other => other,
597        };
598        Ok(Expression::DataType(transformed))
599    }
600
601    /// Transform CAST expression
602    /// MySQL uses TIMESTAMP() function instead of CAST(x AS TIMESTAMP)
603    /// For Generic->MySQL, TIMESTAMP (no tz) is pre-converted to DATETIME in cross_dialect_normalize
604    fn transform_cast(&self, cast: Cast) -> Result<Expression> {
605        // CAST AS TIMESTAMP/TIMESTAMPTZ/TIMESTAMPLTZ -> TIMESTAMP() function
606        match &cast.to {
607            DataType::Timestamp { .. } => Ok(Expression::Function(Box::new(Function::new(
608                "TIMESTAMP".to_string(),
609                vec![cast.this],
610            )))),
611            DataType::Custom { name }
612                if name.to_uppercase() == "TIMESTAMPTZ"
613                    || name.to_uppercase() == "TIMESTAMPLTZ" =>
614            {
615                Ok(Expression::Function(Box::new(Function::new(
616                    "TIMESTAMP".to_string(),
617                    vec![cast.this],
618                ))))
619            }
620            // All other casts go through normal type transformation
621            _ => Ok(Expression::Cast(Box::new(self.transform_cast_type(cast)))),
622        }
623    }
624
625    fn transform_make_interval_binary(
626        &self,
627        op: Box<BinaryOp>,
628        subtract: bool,
629    ) -> Result<Expression> {
630        if let Expression::MakeInterval(make_interval) = &op.right {
631            if let Some(expr) =
632                Self::expand_make_interval_binary(op.left.clone(), make_interval, subtract)
633            {
634                return Ok(expr);
635            }
636        }
637        if let Expression::Function(function) = &op.right {
638            if function.name.eq_ignore_ascii_case("MAKE_INTERVAL") {
639                if let Some(expr) =
640                    Self::expand_make_interval_function_binary(op.left.clone(), function, subtract)
641                {
642                    return Ok(expr);
643                }
644            }
645        }
646
647        if subtract {
648            Ok(Expression::Sub(op))
649        } else {
650            Ok(Expression::Add(op))
651        }
652    }
653
654    fn expand_make_interval_binary(
655        left: Expression,
656        make_interval: &MakeInterval,
657        subtract: bool,
658    ) -> Option<Expression> {
659        let parts = Self::make_interval_parts(make_interval);
660        if parts.is_empty() {
661            return None;
662        }
663
664        let mut expr = left;
665        for (value, unit) in parts {
666            let interval = Expression::Interval(Box::new(Interval {
667                this: Some(value),
668                unit: Some(IntervalUnitSpec::Simple {
669                    unit,
670                    use_plural: false,
671                }),
672            }));
673            expr = if subtract {
674                Expression::Sub(Box::new(BinaryOp::new(expr, interval)))
675            } else {
676                Expression::Add(Box::new(BinaryOp::new(expr, interval)))
677            };
678        }
679
680        Some(expr)
681    }
682
683    fn expand_make_interval_function_binary(
684        left: Expression,
685        function: &Function,
686        subtract: bool,
687    ) -> Option<Expression> {
688        let parts = Self::make_interval_function_parts(function);
689        if parts.is_empty() {
690            return None;
691        }
692
693        let mut expr = left;
694        for (value, unit) in parts {
695            let interval = Expression::Interval(Box::new(Interval {
696                this: Some(value),
697                unit: Some(IntervalUnitSpec::Simple {
698                    unit,
699                    use_plural: false,
700                }),
701            }));
702            expr = if subtract {
703                Expression::Sub(Box::new(BinaryOp::new(expr, interval)))
704            } else {
705                Expression::Add(Box::new(BinaryOp::new(expr, interval)))
706            };
707        }
708
709        Some(expr)
710    }
711
712    fn make_interval_parts(make_interval: &MakeInterval) -> Vec<(Expression, IntervalUnit)> {
713        let mut parts = Vec::new();
714        if let Some(year) = &make_interval.year {
715            parts.push(((**year).clone(), IntervalUnit::Year));
716        }
717        if let Some(month) = &make_interval.month {
718            parts.push(((**month).clone(), IntervalUnit::Month));
719        }
720        if let Some(week) = &make_interval.week {
721            parts.push(((**week).clone(), IntervalUnit::Week));
722        }
723        if let Some(day) = &make_interval.day {
724            parts.push(((**day).clone(), IntervalUnit::Day));
725        }
726        if let Some(hour) = &make_interval.hour {
727            parts.push(((**hour).clone(), IntervalUnit::Hour));
728        }
729        if let Some(minute) = &make_interval.minute {
730            parts.push(((**minute).clone(), IntervalUnit::Minute));
731        }
732        if let Some(second) = &make_interval.second {
733            parts.push(((**second).clone(), IntervalUnit::Second));
734        }
735        parts
736    }
737
738    fn make_interval_function_parts(function: &Function) -> Vec<(Expression, IntervalUnit)> {
739        let positional_units = [
740            IntervalUnit::Year,
741            IntervalUnit::Month,
742            IntervalUnit::Week,
743            IntervalUnit::Day,
744            IntervalUnit::Hour,
745            IntervalUnit::Minute,
746            IntervalUnit::Second,
747        ];
748        let mut positional_index = 0;
749        let mut parts = Vec::new();
750
751        for arg in &function.args {
752            if let Expression::NamedArgument(named) = arg {
753                if let Some(unit) = Self::make_interval_unit_from_name(&named.name.name) {
754                    parts.push((named.value.clone(), unit));
755                }
756            } else if let Some(unit) = positional_units.get(positional_index) {
757                parts.push((arg.clone(), *unit));
758                positional_index += 1;
759            }
760        }
761
762        parts
763    }
764
765    fn make_interval_unit_from_name(name: &str) -> Option<IntervalUnit> {
766        match name.to_ascii_lowercase().as_str() {
767            "year" | "years" => Some(IntervalUnit::Year),
768            "month" | "months" => Some(IntervalUnit::Month),
769            "week" | "weeks" => Some(IntervalUnit::Week),
770            "day" | "days" => Some(IntervalUnit::Day),
771            "hour" | "hours" => Some(IntervalUnit::Hour),
772            "min" | "mins" | "minute" | "minutes" => Some(IntervalUnit::Minute),
773            "sec" | "secs" | "second" | "seconds" => Some(IntervalUnit::Second),
774            _ => None,
775        }
776    }
777
778    /// Transform CAST type according to MySQL restrictions
779    /// MySQL doesn't support many types in CAST - they get mapped to CHAR or SIGNED
780    /// Based on Python sqlglot's CHAR_CAST_MAPPING and SIGNED_CAST_MAPPING
781    fn transform_cast_type(&self, cast: Cast) -> Cast {
782        let new_type = match &cast.to {
783            // CHAR_CAST_MAPPING: These types become CHAR in MySQL CAST, preserving length
784            DataType::VarChar { length, .. } => DataType::Char { length: *length },
785            DataType::Text => DataType::Char { length: None },
786
787            // SIGNED_CAST_MAPPING: These integer types become SIGNED in MySQL CAST
788            DataType::BigInt { .. } => DataType::Custom {
789                name: "SIGNED".to_string(),
790            },
791            DataType::Int { .. } => DataType::Custom {
792                name: "SIGNED".to_string(),
793            },
794            DataType::SmallInt { .. } => DataType::Custom {
795                name: "SIGNED".to_string(),
796            },
797            DataType::TinyInt { .. } => DataType::Custom {
798                name: "SIGNED".to_string(),
799            },
800            DataType::Boolean => DataType::Custom {
801                name: "SIGNED".to_string(),
802            },
803
804            // Custom types that need mapping
805            DataType::Custom { name } => {
806                let upper = name.to_uppercase();
807                match upper.as_str() {
808                    // Text/Blob types -> keep as Custom for cross-dialect mapping
809                    // MySQL generator will output CHAR for these in CAST context
810                    "LONGTEXT" | "MEDIUMTEXT" | "TINYTEXT" | "LONGBLOB" | "MEDIUMBLOB"
811                    | "TINYBLOB" => DataType::Custom { name: upper },
812                    // MEDIUMINT -> SIGNED in MySQL CAST
813                    "MEDIUMINT" => DataType::Custom {
814                        name: "SIGNED".to_string(),
815                    },
816                    // Unsigned integer types -> UNSIGNED
817                    "UBIGINT" | "UINT" | "USMALLINT" | "UTINYINT" | "UMEDIUMINT" => {
818                        DataType::Custom {
819                            name: "UNSIGNED".to_string(),
820                        }
821                    }
822                    // Keep other custom types
823                    _ => cast.to.clone(),
824                }
825            }
826
827            // Types that are valid in MySQL CAST - pass through
828            DataType::Binary { .. } => cast.to.clone(),
829            DataType::VarBinary { .. } => cast.to.clone(),
830            DataType::Date => cast.to.clone(),
831            DataType::Time { .. } => cast.to.clone(),
832            DataType::Decimal { .. } => cast.to.clone(),
833            DataType::Json => cast.to.clone(),
834            DataType::Float { .. } => cast.to.clone(),
835            DataType::Double { .. } => cast.to.clone(),
836            DataType::Char { .. } => cast.to.clone(),
837            DataType::CharacterSet { .. } => cast.to.clone(),
838            DataType::Enum { .. } => cast.to.clone(),
839            DataType::Set { .. } => cast.to.clone(),
840            DataType::Timestamp { .. } => cast.to.clone(),
841
842            // All other unsupported types -> CHAR
843            _ => DataType::Char { length: None },
844        };
845
846        Cast {
847            this: cast.this,
848            to: new_type,
849            trailing_comments: cast.trailing_comments,
850            double_colon_syntax: cast.double_colon_syntax,
851            format: cast.format,
852            default: cast.default,
853            inferred_type: None,
854        }
855    }
856
857    fn transform_function(&self, f: Function) -> Result<Expression> {
858        let name_upper = f.name.to_uppercase();
859        match name_upper.as_str() {
860            // Normalize DATE_FORMAT short-hands to canonical MySQL forms.
861            "DATE_FORMAT" if f.args.len() >= 2 => {
862                let mut f = f;
863                if let Some(Expression::Literal(lit)) = f.args.get(1) {
864                    if let Literal::String(fmt) = lit.as_ref() {
865                        let normalized = Self::normalize_mysql_date_format(fmt);
866                        if normalized != *fmt {
867                            f.args[1] = Expression::Literal(Box::new(Literal::String(normalized)));
868                        }
869                    }
870                }
871                Ok(Expression::Function(Box::new(f)))
872            }
873
874            // NVL -> IFNULL
875            "NVL" if f.args.len() == 2 => {
876                let mut args = f.args;
877                let second = args.pop().unwrap();
878                let first = args.pop().unwrap();
879                Ok(Expression::IfNull(Box::new(BinaryFunc {
880                    original_name: None,
881                    this: first,
882                    expression: second,
883                    inferred_type: None,
884                })))
885            }
886
887            // Note: COALESCE is native to MySQL. We do NOT convert it to IFNULL
888            // because this would break identity tests (Python SQLGlot preserves COALESCE).
889
890            // ARRAY_AGG -> GROUP_CONCAT
891            "ARRAY_AGG" if f.args.len() == 1 => {
892                let mut args = f.args;
893                Ok(Expression::Function(Box::new(Function::new(
894                    "GROUP_CONCAT".to_string(),
895                    vec![args.pop().unwrap()],
896                ))))
897            }
898
899            // STRING_AGG -> GROUP_CONCAT
900            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
901                Function::new("GROUP_CONCAT".to_string(), f.args),
902            ))),
903
904            // RANDOM -> RAND
905            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
906                seed: None,
907                lower: None,
908                upper: None,
909            }))),
910
911            // CURRENT_TIMESTAMP -> NOW() or CURRENT_TIMESTAMP (both work)
912            // Preserve precision if specified: CURRENT_TIMESTAMP(6)
913            "CURRENT_TIMESTAMP" => {
914                let precision = if let Some(Expression::Literal(lit)) = f.args.first() {
915                    if let crate::expressions::Literal::Number(n) = lit.as_ref() {
916                        n.parse::<u32>().ok()
917                    } else {
918                        None
919                    }
920                } else {
921                    None
922                };
923                Ok(Expression::CurrentTimestamp(
924                    crate::expressions::CurrentTimestamp {
925                        precision,
926                        sysdate: false,
927                    },
928                ))
929            }
930
931            // POSITION -> LOCATE in MySQL (argument order is different)
932            // POSITION(substr IN str) -> LOCATE(substr, str)
933            "POSITION" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
934                "LOCATE".to_string(),
935                f.args,
936            )))),
937
938            // LENGTH is native to MySQL (returns bytes, not characters)
939            // CHAR_LENGTH for character count
940            "LENGTH" => Ok(Expression::Function(Box::new(f))),
941
942            // CEIL -> CEILING in MySQL (both work)
943            "CEIL" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
944                "CEILING".to_string(),
945                f.args,
946            )))),
947
948            // STDDEV -> STD or STDDEV_POP in MySQL
949            "STDDEV" => Ok(Expression::Function(Box::new(Function::new(
950                "STD".to_string(),
951                f.args,
952            )))),
953
954            // STDDEV_SAMP -> STDDEV in MySQL
955            "STDDEV_SAMP" => Ok(Expression::Function(Box::new(Function::new(
956                "STDDEV".to_string(),
957                f.args,
958            )))),
959
960            // TO_DATE -> STR_TO_DATE in MySQL
961            "TO_DATE" => Ok(Expression::Function(Box::new(Function::new(
962                "STR_TO_DATE".to_string(),
963                f.args,
964            )))),
965
966            // TO_TIMESTAMP -> STR_TO_DATE in MySQL
967            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(Function::new(
968                "STR_TO_DATE".to_string(),
969                f.args,
970            )))),
971
972            // DATE_TRUNC -> Complex transformation
973            // Typically uses DATE() or DATE_FORMAT() depending on unit
974            "DATE_TRUNC" if f.args.len() >= 2 => {
975                // Simplified: DATE_TRUNC('day', x) -> DATE(x)
976                // Full implementation would handle different units
977                let mut args = f.args;
978                let _unit = args.remove(0);
979                let date = args.remove(0);
980                Ok(Expression::Function(Box::new(Function::new(
981                    "DATE".to_string(),
982                    vec![date],
983                ))))
984            }
985
986            // EXTRACT is native but syntax varies
987
988            // COALESCE is native to MySQL (keep as-is for more than 2 args)
989            "COALESCE" if f.args.len() > 2 => Ok(Expression::Function(Box::new(f))),
990
991            // DAYOFMONTH -> DAY (both work)
992            "DAY" => Ok(Expression::Function(Box::new(Function::new(
993                "DAYOFMONTH".to_string(),
994                f.args,
995            )))),
996
997            // DAYOFWEEK is native to MySQL
998            "DAYOFWEEK" => Ok(Expression::Function(Box::new(f))),
999
1000            // DAYOFYEAR is native to MySQL
1001            "DAYOFYEAR" => Ok(Expression::Function(Box::new(f))),
1002
1003            // WEEKOFYEAR is native to MySQL
1004            "WEEKOFYEAR" => Ok(Expression::Function(Box::new(f))),
1005
1006            // LAST_DAY is native to MySQL
1007            "LAST_DAY" => Ok(Expression::Function(Box::new(f))),
1008
1009            // TIMESTAMPADD -> DATE_ADD
1010            "TIMESTAMPADD" => Ok(Expression::Function(Box::new(Function::new(
1011                "DATE_ADD".to_string(),
1012                f.args,
1013            )))),
1014
1015            // TIMESTAMPDIFF is native to MySQL
1016            "TIMESTAMPDIFF" => Ok(Expression::Function(Box::new(f))),
1017
1018            // CONVERT_TIMEZONE(from_tz, to_tz, timestamp) -> CONVERT_TZ(timestamp, from_tz, to_tz) in MySQL
1019            "CONVERT_TIMEZONE" if f.args.len() == 3 => {
1020                let mut args = f.args;
1021                let from_tz = args.remove(0);
1022                let to_tz = args.remove(0);
1023                let timestamp = args.remove(0);
1024                Ok(Expression::Function(Box::new(Function::new(
1025                    "CONVERT_TZ".to_string(),
1026                    vec![timestamp, from_tz, to_tz],
1027                ))))
1028            }
1029
1030            // UTC_TIMESTAMP is native to MySQL
1031            "UTC_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
1032
1033            // UTC_TIME is native to MySQL
1034            "UTC_TIME" => Ok(Expression::Function(Box::new(f))),
1035
1036            // MAKETIME is native to MySQL (TimeFromParts)
1037            "MAKETIME" => Ok(Expression::Function(Box::new(f))),
1038
1039            // TIME_FROM_PARTS -> MAKETIME
1040            "TIME_FROM_PARTS" if f.args.len() == 3 => Ok(Expression::Function(Box::new(
1041                Function::new("MAKETIME".to_string(), f.args),
1042            ))),
1043
1044            // STUFF -> INSERT in MySQL
1045            "STUFF" if f.args.len() == 4 => Ok(Expression::Function(Box::new(Function::new(
1046                "INSERT".to_string(),
1047                f.args,
1048            )))),
1049
1050            // LOCATE is native to MySQL (reverse of POSITION args)
1051            "LOCATE" => Ok(Expression::Function(Box::new(f))),
1052
1053            // FIND_IN_SET is native to MySQL
1054            "FIND_IN_SET" => Ok(Expression::Function(Box::new(f))),
1055
1056            // FORMAT is native to MySQL (NumberToStr)
1057            "FORMAT" => Ok(Expression::Function(Box::new(f))),
1058
1059            // JSON_EXTRACT is native to MySQL
1060            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
1061
1062            // JSON_UNQUOTE is native to MySQL
1063            "JSON_UNQUOTE" => Ok(Expression::Function(Box::new(f))),
1064
1065            // JSON_EXTRACT_PATH_TEXT -> JSON_UNQUOTE(JSON_EXTRACT(...))
1066            "JSON_EXTRACT_PATH_TEXT" if f.args.len() >= 2 => {
1067                let extract = Expression::Function(Box::new(Function::new(
1068                    "JSON_EXTRACT".to_string(),
1069                    f.args,
1070                )));
1071                Ok(Expression::Function(Box::new(Function::new(
1072                    "JSON_UNQUOTE".to_string(),
1073                    vec![extract],
1074                ))))
1075            }
1076
1077            // GEN_RANDOM_UUID / UUID -> UUID()
1078            "GEN_RANDOM_UUID" | "GENERATE_UUID" => Ok(Expression::Function(Box::new(
1079                Function::new("UUID".to_string(), vec![]),
1080            ))),
1081
1082            // DATABASE() -> SCHEMA() in MySQL (both return current database name)
1083            "DATABASE" => Ok(Expression::Function(Box::new(Function::new(
1084                "SCHEMA".to_string(),
1085                f.args,
1086            )))),
1087
1088            // INSTR -> LOCATE in MySQL (with swapped arguments)
1089            // INSTR(str, substr) -> LOCATE(substr, str)
1090            "INSTR" if f.args.len() == 2 => {
1091                let mut args = f.args;
1092                let str_arg = args.remove(0);
1093                let substr_arg = args.remove(0);
1094                Ok(Expression::Function(Box::new(Function::new(
1095                    "LOCATE".to_string(),
1096                    vec![substr_arg, str_arg],
1097                ))))
1098            }
1099
1100            // TIME_STR_TO_UNIX -> UNIX_TIMESTAMP in MySQL
1101            "TIME_STR_TO_UNIX" => Ok(Expression::Function(Box::new(Function::new(
1102                "UNIX_TIMESTAMP".to_string(),
1103                f.args,
1104            )))),
1105
1106            // TIME_STR_TO_TIME -> CAST AS DATETIME(N) or TIMESTAMP() in MySQL
1107            "TIME_STR_TO_TIME" if f.args.len() >= 1 => {
1108                let mut args = f.args.into_iter();
1109                let arg = args.next().unwrap();
1110
1111                // If there's a timezone arg, use TIMESTAMP() function instead
1112                if args.next().is_some() {
1113                    return Ok(Expression::Function(Box::new(Function::new(
1114                        "TIMESTAMP".to_string(),
1115                        vec![arg],
1116                    ))));
1117                }
1118
1119                // Extract sub-second precision from the string literal
1120                let precision = if let Expression::Literal(ref lit) = arg {
1121                    if let crate::expressions::Literal::String(ref s) = lit.as_ref() {
1122                        // Find fractional seconds: look for .NNN pattern after HH:MM:SS
1123                        if let Some(dot_pos) = s.rfind('.') {
1124                            let after_dot = &s[dot_pos + 1..];
1125                            // Count digits until non-digit
1126                            let frac_digits =
1127                                after_dot.chars().take_while(|c| c.is_ascii_digit()).count();
1128                            if frac_digits > 0 {
1129                                // Round up: 1-3 digits → 3, 4-6 digits → 6
1130                                if frac_digits <= 3 {
1131                                    Some(3)
1132                                } else {
1133                                    Some(6)
1134                                }
1135                            } else {
1136                                None
1137                            }
1138                        } else {
1139                            None
1140                        }
1141                    } else {
1142                        None
1143                    }
1144                } else {
1145                    None
1146                };
1147
1148                let type_name = match precision {
1149                    Some(p) => format!("DATETIME({})", p),
1150                    None => "DATETIME".to_string(),
1151                };
1152
1153                Ok(Expression::Cast(Box::new(Cast {
1154                    this: arg,
1155                    to: DataType::Custom { name: type_name },
1156                    trailing_comments: Vec::new(),
1157                    double_colon_syntax: false,
1158                    format: None,
1159                    default: None,
1160                    inferred_type: None,
1161                })))
1162            }
1163
1164            // UCASE -> UPPER in MySQL
1165            "UCASE" => Ok(Expression::Function(Box::new(Function::new(
1166                "UPPER".to_string(),
1167                f.args,
1168            )))),
1169
1170            // LCASE -> LOWER in MySQL
1171            "LCASE" => Ok(Expression::Function(Box::new(Function::new(
1172                "LOWER".to_string(),
1173                f.args,
1174            )))),
1175
1176            // DAY_OF_MONTH -> DAYOFMONTH in MySQL
1177            "DAY_OF_MONTH" => Ok(Expression::Function(Box::new(Function::new(
1178                "DAYOFMONTH".to_string(),
1179                f.args,
1180            )))),
1181
1182            // DAY_OF_WEEK -> DAYOFWEEK in MySQL
1183            "DAY_OF_WEEK" => Ok(Expression::Function(Box::new(Function::new(
1184                "DAYOFWEEK".to_string(),
1185                f.args,
1186            )))),
1187
1188            // DAY_OF_YEAR -> DAYOFYEAR in MySQL
1189            "DAY_OF_YEAR" => Ok(Expression::Function(Box::new(Function::new(
1190                "DAYOFYEAR".to_string(),
1191                f.args,
1192            )))),
1193
1194            // WEEK_OF_YEAR -> WEEKOFYEAR in MySQL
1195            "WEEK_OF_YEAR" => Ok(Expression::Function(Box::new(Function::new(
1196                "WEEKOFYEAR".to_string(),
1197                f.args,
1198            )))),
1199
1200            // MOD(x, y) -> x % y in MySQL
1201            "MOD" if f.args.len() == 2 => {
1202                let mut args = f.args;
1203                let left = args.remove(0);
1204                let right = args.remove(0);
1205                Ok(Expression::Mod(Box::new(BinaryOp {
1206                    left,
1207                    right,
1208                    left_comments: Vec::new(),
1209                    operator_comments: Vec::new(),
1210                    trailing_comments: Vec::new(),
1211                    inferred_type: None,
1212                })))
1213            }
1214
1215            // PARSE_JSON -> strip in MySQL (just keep the string argument)
1216            "PARSE_JSON" if f.args.len() == 1 => Ok(f.args.into_iter().next().unwrap()),
1217
1218            // GET_PATH(obj, path) -> JSON_EXTRACT(obj, json_path) in MySQL
1219            "GET_PATH" if f.args.len() == 2 => {
1220                let mut args = f.args;
1221                let this = args.remove(0);
1222                let path = args.remove(0);
1223                let json_path = match &path {
1224                    Expression::Literal(lit) if matches!(lit.as_ref(), Literal::String(_)) => {
1225                        let Literal::String(s) = lit.as_ref() else {
1226                            unreachable!()
1227                        };
1228                        // Convert bracket notation ["key"] to quoted dot notation ."key"
1229                        let s = Self::convert_bracket_to_quoted_path(s);
1230                        let normalized = if s.starts_with('$') {
1231                            s
1232                        } else if s.starts_with('[') {
1233                            format!("${}", s)
1234                        } else {
1235                            format!("$.{}", s)
1236                        };
1237                        Expression::Literal(Box::new(Literal::String(normalized)))
1238                    }
1239                    _ => path,
1240                };
1241                Ok(Expression::JsonExtract(Box::new(JsonExtractFunc {
1242                    this,
1243                    path: json_path,
1244                    returning: None,
1245                    arrow_syntax: false,
1246                    hash_arrow_syntax: false,
1247                    wrapper_option: None,
1248                    quotes_option: None,
1249                    on_scalar_string: false,
1250                    on_error: None,
1251                })))
1252            }
1253
1254            // REGEXP -> REGEXP_LIKE (MySQL standard form)
1255            "REGEXP" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
1256                "REGEXP_LIKE".to_string(),
1257                f.args,
1258            )))),
1259
1260            // CURTIME -> CURRENT_TIME
1261            "CURTIME" => Ok(Expression::CurrentTime(crate::expressions::CurrentTime {
1262                precision: None,
1263            })),
1264
1265            // TRUNC -> TRUNCATE in MySQL
1266            "TRUNC" => Ok(Expression::Function(Box::new(Function::new(
1267                "TRUNCATE".to_string(),
1268                f.args,
1269            )))),
1270
1271            // Pass through everything else
1272            _ => Ok(Expression::Function(Box::new(f))),
1273        }
1274    }
1275
1276    fn transform_aggregate_function(
1277        &self,
1278        f: Box<crate::expressions::AggregateFunction>,
1279    ) -> Result<Expression> {
1280        let name_upper = f.name.to_uppercase();
1281        match name_upper.as_str() {
1282            // STRING_AGG -> GROUP_CONCAT
1283            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
1284                Function::new("GROUP_CONCAT".to_string(), f.args),
1285            ))),
1286
1287            // ARRAY_AGG -> GROUP_CONCAT
1288            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
1289                "GROUP_CONCAT".to_string(),
1290                f.args,
1291            )))),
1292
1293            // Pass through everything else
1294            _ => Ok(Expression::AggregateFunction(f)),
1295        }
1296    }
1297}
1298
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;

    /// Transpile `sql` from the generic dialect into MySQL and return the first
    /// generated statement.
    fn transpile_to_mysql(sql: &str) -> String {
        let generic = Dialect::get(DialectType::Generic);
        let statements = generic
            .transpile(sql, DialectType::MySQL)
            .expect("Transpile failed");
        statements[0].clone()
    }

    /// Assert that `output` contains `needle`; `what` describes the expectation
    /// in the failure message.
    fn assert_contains(output: &str, needle: &str, what: &str) {
        assert!(output.contains(needle), "Expected {}, got: {}", what, output);
    }

    /// Parse `sql` as MySQL, apply the MySQL transform, regenerate it, and
    /// compare the result with `expected`.
    fn assert_mysql_output(sql: &str, expected: &str) {
        let mysql = Dialect::get(DialectType::MySQL);
        let parsed = mysql.parse(sql).expect("Parse failed");
        let transformed = mysql.transform(parsed[0].clone()).expect("Transform failed");
        let generated = mysql.generate(&transformed).expect("Generate failed");
        assert_eq!(generated, expected, "SQL: {}", sql);
    }

    #[test]
    fn test_nvl_to_ifnull() {
        assert_contains(&transpile_to_mysql("SELECT NVL(a, b)"), "IFNULL", "IFNULL");
    }

    #[test]
    fn test_coalesce_preserved() {
        // COALESCE is native to MySQL, so it must survive transpilation.
        assert_contains(
            &transpile_to_mysql("SELECT COALESCE(a, b)"),
            "COALESCE",
            "COALESCE to be preserved",
        );
    }

    #[test]
    fn test_random_to_rand() {
        assert_contains(&transpile_to_mysql("SELECT RANDOM()"), "RAND", "RAND");
    }

    #[test]
    fn test_basic_select() {
        let result = transpile_to_mysql("SELECT a, b FROM users WHERE id = 1");
        assert!(result.contains("SELECT"));
        assert!(result.contains("FROM users"));
    }

    #[test]
    fn test_string_agg_to_group_concat() {
        assert_contains(
            &transpile_to_mysql("SELECT STRING_AGG(name)"),
            "GROUP_CONCAT",
            "GROUP_CONCAT",
        );
    }

    #[test]
    fn test_array_agg_to_group_concat() {
        assert_contains(
            &transpile_to_mysql("SELECT ARRAY_AGG(name)"),
            "GROUP_CONCAT",
            "GROUP_CONCAT",
        );
    }

    #[test]
    fn test_to_date_to_str_to_date() {
        assert_contains(
            &transpile_to_mysql("SELECT TO_DATE('2023-01-01')"),
            "STR_TO_DATE",
            "STR_TO_DATE",
        );
    }

    #[test]
    fn test_backtick_identifiers() {
        // Identifiers are quoted with backticks in MySQL.
        let config = MySQLDialect.generator_config();
        assert_eq!(config.identifier_quote, '`');
    }

    #[test]
    fn test_create_procedure_signal_sqlstate_regression() {
        let sql = "CREATE PROCEDURE CheckSalary(IN p_salary DECIMAL(10,2))
BEGIN
  IF p_salary <= 0 THEN
    SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'Salary must be positive';
  END IF;
  INSERT INTO employees (salary) VALUES (p_salary);
END;";

        let mysql = Dialect::get(DialectType::MySQL);
        let statements = mysql.parse(sql).expect("Parse failed");
        let output = mysql.generate(&statements[0]).expect("Generate failed");

        // Every fragment of the procedure body must survive a round trip.
        for expected in [
            "CREATE PROCEDURE CheckSalary",
            "IF p_salary <= 0 THEN",
            "SIGNAL SQLSTATE '45000' SET MESSAGE_TEXT = 'Salary must be positive'",
            "END IF",
            "INSERT INTO employees (salary) VALUES (p_salary)",
        ] {
            assert!(output.contains(expected));
        }
    }

    #[test]
    fn test_ucase_to_upper() {
        assert_mysql_output("SELECT UCASE('foo')", "SELECT UPPER('foo')");
    }

    #[test]
    fn test_lcase_to_lower() {
        assert_mysql_output("SELECT LCASE('foo')", "SELECT LOWER('foo')");
    }

    #[test]
    fn test_day_of_month() {
        assert_mysql_output(
            "SELECT DAY_OF_MONTH('2023-01-01')",
            "SELECT DAYOFMONTH('2023-01-01')",
        );
    }

    #[test]
    fn test_day_of_week() {
        assert_mysql_output(
            "SELECT DAY_OF_WEEK('2023-01-01')",
            "SELECT DAYOFWEEK('2023-01-01')",
        );
    }

    #[test]
    fn test_day_of_year() {
        assert_mysql_output(
            "SELECT DAY_OF_YEAR('2023-01-01')",
            "SELECT DAYOFYEAR('2023-01-01')",
        );
    }

    #[test]
    fn test_week_of_year() {
        assert_mysql_output(
            "SELECT WEEK_OF_YEAR('2023-01-01')",
            "SELECT WEEKOFYEAR('2023-01-01')",
        );
    }

    #[test]
    fn test_mod_func_to_percent() {
        // The MOD(x, y) function form is lowered to the `%` operator.
        assert_mysql_output("MOD(x, y)", "x % y");
    }

    #[test]
    fn test_database_to_schema() {
        assert_mysql_output("DATABASE()", "SCHEMA()");
    }

    #[test]
    fn test_and_operator() {
        assert_mysql_output("SELECT 1 && 0", "SELECT 1 AND 0");
    }

    #[test]
    fn test_or_operator() {
        assert_mysql_output("SELECT a || b", "SELECT a OR b");
    }
}