Skip to main content

polyglot_sql/dialects/
mysql.rs

1//! MySQL Dialect
2//!
3//! MySQL-specific transformations based on sqlglot patterns.
4//! Key differences from standard SQL:
5//! - || is OR operator, not string concatenation (use CONCAT)
6//! - Uses backticks for identifiers
7//! - No TRY_CAST, no ILIKE
8//! - Different date/time function names
9
10use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{
13    BinaryFunc, BinaryOp, Cast, DataType, Expression, Function, JsonExtractFunc, LikeOp, Literal,
14    Paren, UnaryFunc,
15};
16use crate::generator::GeneratorConfig;
17use crate::tokens::TokenizerConfig;
18
19/// Helper to wrap JSON arrow expressions in parentheses when they appear
20/// in contexts that require it (Binary, In, Not expressions)
21/// This matches Python sqlglot's WRAPPED_JSON_EXTRACT_EXPRESSIONS behavior
22fn wrap_if_json_arrow(expr: Expression) -> Expression {
23    match &expr {
24        Expression::JsonExtract(f) if f.arrow_syntax => Expression::Paren(Box::new(Paren {
25            this: expr,
26            trailing_comments: Vec::new(),
27        })),
28        Expression::JsonExtractScalar(f) if f.arrow_syntax => Expression::Paren(Box::new(Paren {
29            this: expr,
30            trailing_comments: Vec::new(),
31        })),
32        _ => expr,
33    }
34}
35
36/// Convert JSON arrow expression (-> or ->>) to JSON_EXTRACT function form
37/// This is needed for contexts like MEMBER OF where arrow syntax must become function form
38fn json_arrow_to_function(expr: Expression) -> Expression {
39    match expr {
40        Expression::JsonExtract(f) if f.arrow_syntax => Expression::Function(Box::new(
41            Function::new("JSON_EXTRACT".to_string(), vec![f.this, f.path]),
42        )),
43        Expression::JsonExtractScalar(f) if f.arrow_syntax => {
44            // ->> becomes JSON_UNQUOTE(JSON_EXTRACT(...)) but can be simplified to JSON_EXTRACT_SCALAR
45            // For MySQL, use JSON_UNQUOTE(JSON_EXTRACT(...))
46            let json_extract = Expression::Function(Box::new(Function::new(
47                "JSON_EXTRACT".to_string(),
48                vec![f.this, f.path],
49            )));
50            Expression::Function(Box::new(Function::new(
51                "JSON_UNQUOTE".to_string(),
52                vec![json_extract],
53            )))
54        }
55        other => other,
56    }
57}
58
/// MySQL dialect.
///
/// A stateless unit struct; its [`DialectImpl`] implementation supplies MySQL's tokenizer
/// settings, generator settings, and expression rewrites (`transform_expr`).
pub struct MySQLDialect;
61
62impl DialectImpl for MySQLDialect {
63    fn dialect_type(&self) -> DialectType {
64        DialectType::MySQL
65    }
66
67    fn tokenizer_config(&self) -> TokenizerConfig {
68        use crate::tokens::TokenType;
69        let mut config = TokenizerConfig::default();
70        // MySQL uses backticks for identifiers
71        config.identifiers.insert('`', '`');
72        // Remove double quotes from identifiers - in MySQL they are string delimiters
73        // (unless ANSI_QUOTES mode is set, but default mode uses them as strings)
74        config.identifiers.remove(&'"');
75        // MySQL supports double quotes as string literals by default
76        config.quotes.insert("\"".to_string(), "\"".to_string());
77        // MySQL supports backslash escapes in strings
78        config.string_escapes.push('\\');
79        // MySQL has XOR as a logical operator keyword
80        config.keywords.insert("XOR".to_string(), TokenType::Xor);
81        // MySQL: backslash followed by chars NOT in this list -> discard backslash
82        // See: https://dev.mysql.com/doc/refman/8.4/en/string-literals.html
83        config.escape_follow_chars = vec!['0', 'b', 'n', 'r', 't', 'Z', '%', '_'];
84        // MySQL allows identifiers to start with digits (e.g., 1a, 1_a)
85        config.identifiers_can_start_with_digit = true;
86        config
87    }
88
89    fn generator_config(&self) -> GeneratorConfig {
90        use crate::generator::IdentifierQuoteStyle;
91        GeneratorConfig {
92            identifier_quote: '`',
93            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
94            dialect: Some(DialectType::MySQL),
95            // MySQL doesn't support null ordering in most contexts
96            null_ordering_supported: false,
97            // MySQL LIMIT only
98            limit_only_literals: true,
99            // MySQL doesn't support semi/anti join
100            semi_anti_join_with_side: false,
101            // MySQL doesn't support table alias columns in some contexts
102            supports_table_alias_columns: false,
103            // MySQL VALUES not used as table
104            values_as_table: false,
105            // MySQL doesn't support TABLESAMPLE
106            tablesample_requires_parens: false,
107            tablesample_with_method: false,
108            // MySQL doesn't support aggregate FILTER
109            aggregate_filter_supported: false,
110            // MySQL doesn't support TRY
111            try_supported: false,
112            // MySQL doesn't support CONVERT_TIMEZONE
113            supports_convert_timezone: false,
114            // MySQL doesn't support UESCAPE
115            supports_uescape: false,
116            // MySQL doesn't support BETWEEN flags
117            supports_between_flags: false,
118            // MySQL supports EXPLAIN but not query hints in standard way
119            query_hints: false,
120            // MySQL parameter token
121            parameter_token: "?",
122            // MySQL doesn't support window EXCLUDE
123            supports_window_exclude: false,
124            // MySQL doesn't support exploding projections
125            supports_exploding_projections: false,
126            identifiers_can_start_with_digit: true,
127            // MySQL supports FOR UPDATE/SHARE
128            locking_reads_supported: true,
129            ..Default::default()
130        }
131    }
132
133    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
134        match expr {
135            // ===== Data Type Mappings =====
136            Expression::DataType(dt) => self.transform_data_type(dt),
137
138            // NVL -> IFNULL in MySQL
139            Expression::Nvl(f) => Ok(Expression::IfNull(f)),
140
141            // Note: COALESCE is valid in MySQL and should be preserved.
142            // Unlike some other dialects, we do NOT convert COALESCE to IFNULL
143            // as this would break identity tests.
144
145            // TryCast -> CAST or TIMESTAMP() (MySQL doesn't support TRY_CAST)
146            Expression::TryCast(c) => self.transform_cast(*c),
147
148            // SafeCast -> CAST or TIMESTAMP() (MySQL doesn't support safe casts)
149            Expression::SafeCast(c) => self.transform_cast(*c),
150
151            // Cast -> Transform cast type according to MySQL restrictions
152            // CAST AS TIMESTAMP -> TIMESTAMP() function in MySQL
153            Expression::Cast(c) => self.transform_cast(*c),
154
155            // ILIKE -> LOWER() LIKE LOWER() in MySQL
156            Expression::ILike(op) => {
157                // Transform ILIKE to: LOWER(left) LIKE LOWER(right)
158                let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left)));
159                let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right)));
160                Ok(Expression::Like(Box::new(LikeOp {
161                    left: lower_left,
162                    right: lower_right,
163                    escape: op.escape,
164                    quantifier: op.quantifier,
165                    inferred_type: None,
166                })))
167            }
168
169            // Preserve semantic string concatenation expressions.
170            // MySQL generation renders these as CONCAT(...).
171            Expression::Concat(op) => Ok(Expression::Concat(op)),
172
173            // RANDOM -> RAND in MySQL
174            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
175                seed: None,
176                lower: None,
177                upper: None,
178            }))),
179
180            // ArrayAgg -> GROUP_CONCAT in MySQL
181            Expression::ArrayAgg(f) => Ok(Expression::Function(Box::new(Function::new(
182                "GROUP_CONCAT".to_string(),
183                vec![f.this],
184            )))),
185
186            // StringAgg -> GROUP_CONCAT in MySQL
187            Expression::StringAgg(f) => {
188                let mut args = vec![f.this.clone()];
189                if let Some(separator) = &f.separator {
190                    args.push(separator.clone());
191                }
192                Ok(Expression::Function(Box::new(Function::new(
193                    "GROUP_CONCAT".to_string(),
194                    args,
195                ))))
196            }
197
198            // UNNEST -> Not directly supported in MySQL, use JSON_TABLE or inline
199            // For basic cases, pass through (may need manual handling)
200            Expression::Unnest(f) => {
201                // MySQL 8.0+ has JSON_TABLE which can be used for unnesting
202                // For now, pass through with a function call
203                Ok(Expression::Function(Box::new(Function::new(
204                    "JSON_TABLE".to_string(),
205                    vec![f.this],
206                ))))
207            }
208
209            // Substring: Use comma syntax (not FROM/FOR) in MySQL
210            Expression::Substring(mut f) => {
211                f.from_for_syntax = false;
212                Ok(Expression::Substring(f))
213            }
214
215            // ===== Bitwise operations =====
216            // BitwiseAndAgg -> BIT_AND
217            Expression::BitwiseAndAgg(f) => Ok(Expression::Function(Box::new(Function::new(
218                "BIT_AND".to_string(),
219                vec![f.this],
220            )))),
221
222            // BitwiseOrAgg -> BIT_OR
223            Expression::BitwiseOrAgg(f) => Ok(Expression::Function(Box::new(Function::new(
224                "BIT_OR".to_string(),
225                vec![f.this],
226            )))),
227
228            // BitwiseXorAgg -> BIT_XOR
229            Expression::BitwiseXorAgg(f) => Ok(Expression::Function(Box::new(Function::new(
230                "BIT_XOR".to_string(),
231                vec![f.this],
232            )))),
233
234            // BitwiseCount -> BIT_COUNT
235            Expression::BitwiseCount(f) => Ok(Expression::Function(Box::new(Function::new(
236                "BIT_COUNT".to_string(),
237                vec![f.this],
238            )))),
239
240            // TimeFromParts -> MAKETIME
241            Expression::TimeFromParts(f) => {
242                let mut args = Vec::new();
243                if let Some(h) = f.hour {
244                    args.push(*h);
245                }
246                if let Some(m) = f.min {
247                    args.push(*m);
248                }
249                if let Some(s) = f.sec {
250                    args.push(*s);
251                }
252                Ok(Expression::Function(Box::new(Function::new(
253                    "MAKETIME".to_string(),
254                    args,
255                ))))
256            }
257
258            // ===== Boolean aggregates =====
259            // In MySQL, there's no BOOL_AND/BOOL_OR, use MIN/MAX on boolean values
260            // LogicalAnd -> MIN (0 is false, non-0 is true)
261            Expression::LogicalAnd(f) => Ok(Expression::Function(Box::new(Function::new(
262                "MIN".to_string(),
263                vec![f.this],
264            )))),
265
266            // LogicalOr -> MAX
267            Expression::LogicalOr(f) => Ok(Expression::Function(Box::new(Function::new(
268                "MAX".to_string(),
269                vec![f.this],
270            )))),
271
272            // ===== Date/time functions =====
273            // DayOfMonth -> DAYOFMONTH
274            Expression::DayOfMonth(f) => Ok(Expression::Function(Box::new(Function::new(
275                "DAYOFMONTH".to_string(),
276                vec![f.this],
277            )))),
278
279            // DayOfWeek -> DAYOFWEEK
280            Expression::DayOfWeek(f) => Ok(Expression::Function(Box::new(Function::new(
281                "DAYOFWEEK".to_string(),
282                vec![f.this],
283            )))),
284
285            // DayOfYear -> DAYOFYEAR
286            Expression::DayOfYear(f) => Ok(Expression::Function(Box::new(Function::new(
287                "DAYOFYEAR".to_string(),
288                vec![f.this],
289            )))),
290
291            // WeekOfYear -> WEEKOFYEAR
292            Expression::WeekOfYear(f) => Ok(Expression::Function(Box::new(Function::new(
293                "WEEKOFYEAR".to_string(),
294                vec![f.this],
295            )))),
296
297            // DateDiff -> DATEDIFF
298            Expression::DateDiff(f) => Ok(Expression::Function(Box::new(Function::new(
299                "DATEDIFF".to_string(),
300                vec![f.this, f.expression],
301            )))),
302
303            // TimeStrToUnix -> UNIX_TIMESTAMP
304            Expression::TimeStrToUnix(f) => Ok(Expression::Function(Box::new(Function::new(
305                "UNIX_TIMESTAMP".to_string(),
306                vec![f.this],
307            )))),
308
309            // TimestampDiff -> TIMESTAMPDIFF
310            Expression::TimestampDiff(f) => Ok(Expression::Function(Box::new(Function::new(
311                "TIMESTAMPDIFF".to_string(),
312                vec![*f.this, *f.expression],
313            )))),
314
315            // ===== String functions =====
316            // StrPosition -> LOCATE in MySQL
317            // STRPOS(str, substr) -> LOCATE(substr, str) (args are swapped)
318            Expression::StrPosition(f) => {
319                let mut args = vec![];
320                if let Some(substr) = f.substr {
321                    args.push(*substr);
322                }
323                args.push(*f.this);
324                if let Some(pos) = f.position {
325                    args.push(*pos);
326                }
327                Ok(Expression::Function(Box::new(Function::new(
328                    "LOCATE".to_string(),
329                    args,
330                ))))
331            }
332
333            // Stuff -> INSERT in MySQL
334            Expression::Stuff(f) => {
335                let mut args = vec![*f.this];
336                if let Some(start) = f.start {
337                    args.push(*start);
338                }
339                if let Some(length) = f.length {
340                    args.push(Expression::number(length));
341                }
342                args.push(*f.expression);
343                Ok(Expression::Function(Box::new(Function::new(
344                    "INSERT".to_string(),
345                    args,
346                ))))
347            }
348
349            // ===== Session/User functions =====
350            // SessionUser -> SESSION_USER()
351            Expression::SessionUser(_) => Ok(Expression::Function(Box::new(Function::new(
352                "SESSION_USER".to_string(),
353                vec![],
354            )))),
355
356            // CurrentDate -> CURRENT_DATE (no parentheses in MySQL) - keep as CurrentDate
357            Expression::CurrentDate(_) => {
358                Ok(Expression::CurrentDate(crate::expressions::CurrentDate))
359            }
360
361            // ===== Null-safe comparison =====
362            // NullSafeNeq -> NOT (a <=> b) in MySQL
363            Expression::NullSafeNeq(op) => {
364                // Create: NOT (left <=> right)
365                let null_safe_eq = Expression::NullSafeEq(Box::new(crate::expressions::BinaryOp {
366                    left: op.left,
367                    right: op.right,
368                    left_comments: Vec::new(),
369                    operator_comments: Vec::new(),
370                    trailing_comments: Vec::new(),
371                    inferred_type: None,
372                }));
373                Ok(Expression::Not(Box::new(crate::expressions::UnaryOp {
374                    this: null_safe_eq,
375                    inferred_type: None,
376                })))
377            }
378
379            // ParseJson: handled by generator (emits just the string literal for MySQL)
380
381            // JSONExtract with variant_extract (Snowflake colon syntax) -> JSON_EXTRACT
382            Expression::JSONExtract(e) if e.variant_extract.is_some() => {
383                let path = match *e.expression {
384                    Expression::Literal(Literal::String(s)) => {
385                        // Convert bracket notation ["key"] to quoted dot notation ."key"
386                        let s = Self::convert_bracket_to_quoted_path(&s);
387                        let normalized = if s.starts_with('$') {
388                            s
389                        } else if s.starts_with('[') {
390                            format!("${}", s)
391                        } else {
392                            format!("$.{}", s)
393                        };
394                        Expression::Literal(Literal::String(normalized))
395                    }
396                    other => other,
397                };
398                Ok(Expression::Function(Box::new(Function::new(
399                    "JSON_EXTRACT".to_string(),
400                    vec![*e.this, path],
401                ))))
402            }
403
404            // Generic function transformations
405            Expression::Function(f) => self.transform_function(*f),
406
407            // Generic aggregate function transformations
408            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
409
410            // ===== Context-aware JSON arrow wrapping =====
411            // When JSON arrow expressions appear in Binary/In/Not contexts,
412            // they need to be wrapped in parentheses for correct precedence.
413            // This matches Python sqlglot's WRAPPED_JSON_EXTRACT_EXPRESSIONS behavior.
414
415            // Binary operators that need JSON wrapping
416            Expression::Eq(op) => Ok(Expression::Eq(Box::new(BinaryOp {
417                left: wrap_if_json_arrow(op.left),
418                right: wrap_if_json_arrow(op.right),
419                ..*op
420            }))),
421            Expression::Neq(op) => Ok(Expression::Neq(Box::new(BinaryOp {
422                left: wrap_if_json_arrow(op.left),
423                right: wrap_if_json_arrow(op.right),
424                ..*op
425            }))),
426            Expression::Lt(op) => Ok(Expression::Lt(Box::new(BinaryOp {
427                left: wrap_if_json_arrow(op.left),
428                right: wrap_if_json_arrow(op.right),
429                ..*op
430            }))),
431            Expression::Lte(op) => Ok(Expression::Lte(Box::new(BinaryOp {
432                left: wrap_if_json_arrow(op.left),
433                right: wrap_if_json_arrow(op.right),
434                ..*op
435            }))),
436            Expression::Gt(op) => Ok(Expression::Gt(Box::new(BinaryOp {
437                left: wrap_if_json_arrow(op.left),
438                right: wrap_if_json_arrow(op.right),
439                ..*op
440            }))),
441            Expression::Gte(op) => Ok(Expression::Gte(Box::new(BinaryOp {
442                left: wrap_if_json_arrow(op.left),
443                right: wrap_if_json_arrow(op.right),
444                ..*op
445            }))),
446
447            // In expression - wrap the this part if it's JSON arrow
448            Expression::In(mut i) => {
449                i.this = wrap_if_json_arrow(i.this);
450                Ok(Expression::In(i))
451            }
452
453            // Not expression - wrap the this part if it's JSON arrow
454            Expression::Not(mut n) => {
455                n.this = wrap_if_json_arrow(n.this);
456                Ok(Expression::Not(n))
457            }
458
459            // && in MySQL is logical AND, not array overlaps
460            // Transform ArrayOverlaps -> And for MySQL identity
461            Expression::ArrayOverlaps(op) => Ok(Expression::And(op)),
462
463            // MOD(x, y) -> x % y in MySQL
464            Expression::ModFunc(f) => Ok(Expression::Mod(Box::new(BinaryOp {
465                left: f.this,
466                right: f.expression,
467                left_comments: Vec::new(),
468                operator_comments: Vec::new(),
469                trailing_comments: Vec::new(),
470                inferred_type: None,
471            }))),
472
473            // SHOW SLAVE STATUS -> SHOW REPLICA STATUS
474            Expression::Show(mut s) => {
475                if s.this == "SLAVE STATUS" {
476                    s.this = "REPLICA STATUS".to_string();
477                }
478                if matches!(s.this.as_str(), "INDEX" | "COLUMNS") && s.db.is_none() {
479                    if let Some(Expression::Table(mut t)) = s.target.take() {
480                        if let Some(db_ident) = t.schema.take().or(t.catalog.take()) {
481                            s.db = Some(Expression::Identifier(db_ident));
482                            s.target = Some(Expression::Identifier(t.name));
483                        } else {
484                            s.target = Some(Expression::Table(t));
485                        }
486                    }
487                }
488                Ok(Expression::Show(s))
489            }
490
491            // AT TIME ZONE -> strip timezone (MySQL doesn't support AT TIME ZONE)
492            // But keep it for CURRENT_DATE/CURRENT_TIMESTAMP with timezone (transpiled from BigQuery)
493            Expression::AtTimeZone(atz) => {
494                let is_current = match &atz.this {
495                    Expression::CurrentDate(_) | Expression::CurrentTimestamp(_) => true,
496                    Expression::Function(f) => {
497                        let n = f.name.to_uppercase();
498                        (n == "CURRENT_DATE" || n == "CURRENT_TIMESTAMP") && f.no_parens
499                    }
500                    _ => false,
501                };
502                if is_current {
503                    Ok(Expression::AtTimeZone(atz)) // Keep AT TIME ZONE for CURRENT_DATE/CURRENT_TIMESTAMP
504                } else {
505                    Ok(atz.this) // Strip timezone for other expressions
506                }
507            }
508
509            // MEMBER OF with JSON arrow -> convert arrow to JSON_EXTRACT function
510            // MySQL's MEMBER OF requires JSON_EXTRACT function form, not arrow syntax
511            Expression::MemberOf(mut op) => {
512                op.right = json_arrow_to_function(op.right);
513                Ok(Expression::MemberOf(op))
514            }
515
516            // Pass through everything else
517            _ => Ok(expr),
518        }
519    }
520}
521
522impl MySQLDialect {
523    fn normalize_mysql_date_format(fmt: &str) -> String {
524        fmt.replace("%H:%i:%s", "%T").replace("%H:%i:%S", "%T")
525    }
526
527    /// Convert bracket notation ["key with spaces"] to quoted dot notation ."key with spaces"
528    /// in JSON path strings.
529    fn convert_bracket_to_quoted_path(path: &str) -> String {
530        let mut result = String::new();
531        let mut chars = path.chars().peekable();
532        while let Some(c) = chars.next() {
533            if c == '[' && chars.peek() == Some(&'"') {
534                chars.next(); // consume "
535                let mut key = String::new();
536                while let Some(kc) = chars.next() {
537                    if kc == '"' && chars.peek() == Some(&']') {
538                        chars.next(); // consume ]
539                        break;
540                    }
541                    key.push(kc);
542                }
543                if !result.is_empty() && !result.ends_with('.') {
544                    result.push('.');
545                }
546                result.push('"');
547                result.push_str(&key);
548                result.push('"');
549            } else {
550                result.push(c);
551            }
552        }
553        result
554    }
555
556    /// Transform data types according to MySQL TYPE_MAPPING
557    /// Note: MySQL's TIMESTAMP is kept as TIMESTAMP (not converted to DATETIME)
558    /// because MySQL's TIMESTAMP has timezone awareness built-in
559    fn transform_data_type(&self, dt: crate::expressions::DataType) -> Result<Expression> {
560        use crate::expressions::DataType;
561        let transformed = match dt {
562            // All TIMESTAMP variants (with or without timezone) -> TIMESTAMP in MySQL
563            DataType::Timestamp {
564                precision,
565                timezone: _,
566            } => DataType::Timestamp {
567                precision,
568                timezone: false,
569            },
570            // TIMESTAMPTZ / TIMESTAMPLTZ parsed as Custom -> normalize to TIMESTAMP
571            DataType::Custom { name }
572                if name.to_uppercase() == "TIMESTAMPTZ"
573                    || name.to_uppercase() == "TIMESTAMPLTZ" =>
574            {
575                DataType::Timestamp {
576                    precision: None,
577                    timezone: false,
578                }
579            }
580            // Keep native MySQL types as-is
581            // MySQL supports TEXT, MEDIUMTEXT, LONGTEXT, BLOB, etc. natively
582            other => other,
583        };
584        Ok(Expression::DataType(transformed))
585    }
586
587    /// Transform CAST expression
588    /// MySQL uses TIMESTAMP() function instead of CAST(x AS TIMESTAMP)
589    /// For Generic->MySQL, TIMESTAMP (no tz) is pre-converted to DATETIME in cross_dialect_normalize
590    fn transform_cast(&self, cast: Cast) -> Result<Expression> {
591        // CAST AS TIMESTAMP/TIMESTAMPTZ/TIMESTAMPLTZ -> TIMESTAMP() function
592        match &cast.to {
593            DataType::Timestamp { .. } => Ok(Expression::Function(Box::new(Function::new(
594                "TIMESTAMP".to_string(),
595                vec![cast.this],
596            )))),
597            DataType::Custom { name }
598                if name.to_uppercase() == "TIMESTAMPTZ"
599                    || name.to_uppercase() == "TIMESTAMPLTZ" =>
600            {
601                Ok(Expression::Function(Box::new(Function::new(
602                    "TIMESTAMP".to_string(),
603                    vec![cast.this],
604                ))))
605            }
606            // All other casts go through normal type transformation
607            _ => Ok(Expression::Cast(Box::new(self.transform_cast_type(cast)))),
608        }
609    }
610
611    /// Transform CAST type according to MySQL restrictions
612    /// MySQL doesn't support many types in CAST - they get mapped to CHAR or SIGNED
613    /// Based on Python sqlglot's CHAR_CAST_MAPPING and SIGNED_CAST_MAPPING
614    fn transform_cast_type(&self, cast: Cast) -> Cast {
615        let new_type = match &cast.to {
616            // CHAR_CAST_MAPPING: These types become CHAR in MySQL CAST, preserving length
617            DataType::VarChar { length, .. } => DataType::Char { length: *length },
618            DataType::Text => DataType::Char { length: None },
619
620            // SIGNED_CAST_MAPPING: These integer types become SIGNED in MySQL CAST
621            DataType::BigInt { .. } => DataType::Custom {
622                name: "SIGNED".to_string(),
623            },
624            DataType::Int { .. } => DataType::Custom {
625                name: "SIGNED".to_string(),
626            },
627            DataType::SmallInt { .. } => DataType::Custom {
628                name: "SIGNED".to_string(),
629            },
630            DataType::TinyInt { .. } => DataType::Custom {
631                name: "SIGNED".to_string(),
632            },
633            DataType::Boolean => DataType::Custom {
634                name: "SIGNED".to_string(),
635            },
636
637            // Custom types that need mapping
638            DataType::Custom { name } => {
639                let upper = name.to_uppercase();
640                match upper.as_str() {
641                    // Text/Blob types -> keep as Custom for cross-dialect mapping
642                    // MySQL generator will output CHAR for these in CAST context
643                    "LONGTEXT" | "MEDIUMTEXT" | "TINYTEXT" | "LONGBLOB" | "MEDIUMBLOB"
644                    | "TINYBLOB" => DataType::Custom { name: upper },
645                    // MEDIUMINT -> SIGNED in MySQL CAST
646                    "MEDIUMINT" => DataType::Custom {
647                        name: "SIGNED".to_string(),
648                    },
649                    // Unsigned integer types -> UNSIGNED
650                    "UBIGINT" | "UINT" | "USMALLINT" | "UTINYINT" | "UMEDIUMINT" => {
651                        DataType::Custom {
652                            name: "UNSIGNED".to_string(),
653                        }
654                    }
655                    // Keep other custom types
656                    _ => cast.to.clone(),
657                }
658            }
659
660            // Types that are valid in MySQL CAST - pass through
661            DataType::Binary { .. } => cast.to.clone(),
662            DataType::VarBinary { .. } => cast.to.clone(),
663            DataType::Date => cast.to.clone(),
664            DataType::Time { .. } => cast.to.clone(),
665            DataType::Decimal { .. } => cast.to.clone(),
666            DataType::Json => cast.to.clone(),
667            DataType::Float { .. } => cast.to.clone(),
668            DataType::Double { .. } => cast.to.clone(),
669            DataType::Char { .. } => cast.to.clone(),
670            DataType::CharacterSet { .. } => cast.to.clone(),
671            DataType::Enum { .. } => cast.to.clone(),
672            DataType::Set { .. } => cast.to.clone(),
673            DataType::Timestamp { .. } => cast.to.clone(),
674
675            // All other unsupported types -> CHAR
676            _ => DataType::Char { length: None },
677        };
678
679        Cast {
680            this: cast.this,
681            to: new_type,
682            trailing_comments: cast.trailing_comments,
683            double_colon_syntax: cast.double_colon_syntax,
684            format: cast.format,
685            default: cast.default,
686            inferred_type: None,
687        }
688    }
689
690    fn transform_function(&self, f: Function) -> Result<Expression> {
691        let name_upper = f.name.to_uppercase();
692        match name_upper.as_str() {
693            // Normalize DATE_FORMAT short-hands to canonical MySQL forms.
694            "DATE_FORMAT" if f.args.len() >= 2 => {
695                let mut f = f;
696                if let Some(Expression::Literal(Literal::String(fmt))) = f.args.get(1) {
697                    let normalized = Self::normalize_mysql_date_format(fmt);
698                    if normalized != *fmt {
699                        f.args[1] = Expression::Literal(Literal::String(normalized));
700                    }
701                }
702                Ok(Expression::Function(Box::new(f)))
703            }
704
705            // NVL -> IFNULL
706            "NVL" if f.args.len() == 2 => {
707                let mut args = f.args;
708                let second = args.pop().unwrap();
709                let first = args.pop().unwrap();
710                Ok(Expression::IfNull(Box::new(BinaryFunc {
711                    original_name: None,
712                    this: first,
713                    expression: second,
714                    inferred_type: None,
715                })))
716            }
717
718            // Note: COALESCE is native to MySQL. We do NOT convert it to IFNULL
719            // because this would break identity tests (Python SQLGlot preserves COALESCE).
720
721            // ARRAY_AGG -> GROUP_CONCAT
722            "ARRAY_AGG" if f.args.len() == 1 => {
723                let mut args = f.args;
724                Ok(Expression::Function(Box::new(Function::new(
725                    "GROUP_CONCAT".to_string(),
726                    vec![args.pop().unwrap()],
727                ))))
728            }
729
730            // STRING_AGG -> GROUP_CONCAT
731            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
732                Function::new("GROUP_CONCAT".to_string(), f.args),
733            ))),
734
735            // RANDOM -> RAND
736            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
737                seed: None,
738                lower: None,
739                upper: None,
740            }))),
741
742            // CURRENT_TIMESTAMP -> NOW() or CURRENT_TIMESTAMP (both work)
743            // Preserve precision if specified: CURRENT_TIMESTAMP(6)
744            "CURRENT_TIMESTAMP" => {
745                let precision =
746                    if let Some(Expression::Literal(crate::expressions::Literal::Number(n))) =
747                        f.args.first()
748                    {
749                        n.parse::<u32>().ok()
750                    } else {
751                        None
752                    };
753                Ok(Expression::CurrentTimestamp(
754                    crate::expressions::CurrentTimestamp {
755                        precision,
756                        sysdate: false,
757                    },
758                ))
759            }
760
761            // POSITION -> LOCATE in MySQL (argument order is different)
762            // POSITION(substr IN str) -> LOCATE(substr, str)
763            "POSITION" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
764                "LOCATE".to_string(),
765                f.args,
766            )))),
767
768            // LENGTH is native to MySQL (returns bytes, not characters)
769            // CHAR_LENGTH for character count
770            "LENGTH" => Ok(Expression::Function(Box::new(f))),
771
772            // CEIL -> CEILING in MySQL (both work)
773            "CEIL" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
774                "CEILING".to_string(),
775                f.args,
776            )))),
777
778            // STDDEV -> STD or STDDEV_POP in MySQL
779            "STDDEV" => Ok(Expression::Function(Box::new(Function::new(
780                "STD".to_string(),
781                f.args,
782            )))),
783
784            // STDDEV_SAMP -> STDDEV in MySQL
785            "STDDEV_SAMP" => Ok(Expression::Function(Box::new(Function::new(
786                "STDDEV".to_string(),
787                f.args,
788            )))),
789
790            // TO_DATE -> STR_TO_DATE in MySQL
791            "TO_DATE" => Ok(Expression::Function(Box::new(Function::new(
792                "STR_TO_DATE".to_string(),
793                f.args,
794            )))),
795
796            // TO_TIMESTAMP -> STR_TO_DATE in MySQL
797            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(Function::new(
798                "STR_TO_DATE".to_string(),
799                f.args,
800            )))),
801
802            // DATE_TRUNC -> Complex transformation
803            // Typically uses DATE() or DATE_FORMAT() depending on unit
804            "DATE_TRUNC" if f.args.len() >= 2 => {
805                // Simplified: DATE_TRUNC('day', x) -> DATE(x)
806                // Full implementation would handle different units
807                let mut args = f.args;
808                let _unit = args.remove(0);
809                let date = args.remove(0);
810                Ok(Expression::Function(Box::new(Function::new(
811                    "DATE".to_string(),
812                    vec![date],
813                ))))
814            }
815
816            // EXTRACT is native but syntax varies
817
818            // COALESCE is native to MySQL (keep as-is for more than 2 args)
819            "COALESCE" if f.args.len() > 2 => Ok(Expression::Function(Box::new(f))),
820
821            // DAYOFMONTH -> DAY (both work)
822            "DAY" => Ok(Expression::Function(Box::new(Function::new(
823                "DAYOFMONTH".to_string(),
824                f.args,
825            )))),
826
827            // DAYOFWEEK is native to MySQL
828            "DAYOFWEEK" => Ok(Expression::Function(Box::new(f))),
829
830            // DAYOFYEAR is native to MySQL
831            "DAYOFYEAR" => Ok(Expression::Function(Box::new(f))),
832
833            // WEEKOFYEAR is native to MySQL
834            "WEEKOFYEAR" => Ok(Expression::Function(Box::new(f))),
835
836            // LAST_DAY is native to MySQL
837            "LAST_DAY" => Ok(Expression::Function(Box::new(f))),
838
839            // TIMESTAMPADD -> DATE_ADD
840            "TIMESTAMPADD" => Ok(Expression::Function(Box::new(Function::new(
841                "DATE_ADD".to_string(),
842                f.args,
843            )))),
844
845            // TIMESTAMPDIFF is native to MySQL
846            "TIMESTAMPDIFF" => Ok(Expression::Function(Box::new(f))),
847
848            // CONVERT_TIMEZONE(from_tz, to_tz, timestamp) -> CONVERT_TZ(timestamp, from_tz, to_tz) in MySQL
849            "CONVERT_TIMEZONE" if f.args.len() == 3 => {
850                let mut args = f.args;
851                let from_tz = args.remove(0);
852                let to_tz = args.remove(0);
853                let timestamp = args.remove(0);
854                Ok(Expression::Function(Box::new(Function::new(
855                    "CONVERT_TZ".to_string(),
856                    vec![timestamp, from_tz, to_tz],
857                ))))
858            }
859
860            // UTC_TIMESTAMP is native to MySQL
861            "UTC_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
862
863            // UTC_TIME is native to MySQL
864            "UTC_TIME" => Ok(Expression::Function(Box::new(f))),
865
866            // MAKETIME is native to MySQL (TimeFromParts)
867            "MAKETIME" => Ok(Expression::Function(Box::new(f))),
868
869            // TIME_FROM_PARTS -> MAKETIME
870            "TIME_FROM_PARTS" if f.args.len() == 3 => Ok(Expression::Function(Box::new(
871                Function::new("MAKETIME".to_string(), f.args),
872            ))),
873
874            // STUFF -> INSERT in MySQL
875            "STUFF" if f.args.len() == 4 => Ok(Expression::Function(Box::new(Function::new(
876                "INSERT".to_string(),
877                f.args,
878            )))),
879
880            // LOCATE is native to MySQL (reverse of POSITION args)
881            "LOCATE" => Ok(Expression::Function(Box::new(f))),
882
883            // FIND_IN_SET is native to MySQL
884            "FIND_IN_SET" => Ok(Expression::Function(Box::new(f))),
885
886            // FORMAT is native to MySQL (NumberToStr)
887            "FORMAT" => Ok(Expression::Function(Box::new(f))),
888
889            // JSON_EXTRACT is native to MySQL
890            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
891
892            // JSON_UNQUOTE is native to MySQL
893            "JSON_UNQUOTE" => Ok(Expression::Function(Box::new(f))),
894
895            // JSON_EXTRACT_PATH_TEXT -> JSON_UNQUOTE(JSON_EXTRACT(...))
896            "JSON_EXTRACT_PATH_TEXT" if f.args.len() >= 2 => {
897                let extract = Expression::Function(Box::new(Function::new(
898                    "JSON_EXTRACT".to_string(),
899                    f.args,
900                )));
901                Ok(Expression::Function(Box::new(Function::new(
902                    "JSON_UNQUOTE".to_string(),
903                    vec![extract],
904                ))))
905            }
906
907            // GEN_RANDOM_UUID / UUID -> UUID()
908            "GEN_RANDOM_UUID" | "GENERATE_UUID" => Ok(Expression::Function(Box::new(
909                Function::new("UUID".to_string(), vec![]),
910            ))),
911
912            // DATABASE() -> SCHEMA() in MySQL (both return current database name)
913            "DATABASE" => Ok(Expression::Function(Box::new(Function::new(
914                "SCHEMA".to_string(),
915                f.args,
916            )))),
917
918            // INSTR -> LOCATE in MySQL (with swapped arguments)
919            // INSTR(str, substr) -> LOCATE(substr, str)
920            "INSTR" if f.args.len() == 2 => {
921                let mut args = f.args;
922                let str_arg = args.remove(0);
923                let substr_arg = args.remove(0);
924                Ok(Expression::Function(Box::new(Function::new(
925                    "LOCATE".to_string(),
926                    vec![substr_arg, str_arg],
927                ))))
928            }
929
930            // TIME_STR_TO_UNIX -> UNIX_TIMESTAMP in MySQL
931            "TIME_STR_TO_UNIX" => Ok(Expression::Function(Box::new(Function::new(
932                "UNIX_TIMESTAMP".to_string(),
933                f.args,
934            )))),
935
936            // TIME_STR_TO_TIME -> CAST AS DATETIME(N) or TIMESTAMP() in MySQL
937            "TIME_STR_TO_TIME" if f.args.len() >= 1 => {
938                let mut args = f.args.into_iter();
939                let arg = args.next().unwrap();
940
941                // If there's a timezone arg, use TIMESTAMP() function instead
942                if args.next().is_some() {
943                    return Ok(Expression::Function(Box::new(Function::new(
944                        "TIMESTAMP".to_string(),
945                        vec![arg],
946                    ))));
947                }
948
949                // Extract sub-second precision from the string literal
950                let precision =
951                    if let Expression::Literal(crate::expressions::Literal::String(ref s)) = arg {
952                        // Find fractional seconds: look for .NNN pattern after HH:MM:SS
953                        if let Some(dot_pos) = s.rfind('.') {
954                            let after_dot = &s[dot_pos + 1..];
955                            // Count digits until non-digit
956                            let frac_digits =
957                                after_dot.chars().take_while(|c| c.is_ascii_digit()).count();
958                            if frac_digits > 0 {
959                                // Round up: 1-3 digits → 3, 4-6 digits → 6
960                                if frac_digits <= 3 {
961                                    Some(3)
962                                } else {
963                                    Some(6)
964                                }
965                            } else {
966                                None
967                            }
968                        } else {
969                            None
970                        }
971                    } else {
972                        None
973                    };
974
975                let type_name = match precision {
976                    Some(p) => format!("DATETIME({})", p),
977                    None => "DATETIME".to_string(),
978                };
979
980                Ok(Expression::Cast(Box::new(Cast {
981                    this: arg,
982                    to: DataType::Custom { name: type_name },
983                    trailing_comments: Vec::new(),
984                    double_colon_syntax: false,
985                    format: None,
986                    default: None,
987                    inferred_type: None,
988                })))
989            }
990
991            // UCASE -> UPPER in MySQL
992            "UCASE" => Ok(Expression::Function(Box::new(Function::new(
993                "UPPER".to_string(),
994                f.args,
995            )))),
996
997            // LCASE -> LOWER in MySQL
998            "LCASE" => Ok(Expression::Function(Box::new(Function::new(
999                "LOWER".to_string(),
1000                f.args,
1001            )))),
1002
1003            // DAY_OF_MONTH -> DAYOFMONTH in MySQL
1004            "DAY_OF_MONTH" => Ok(Expression::Function(Box::new(Function::new(
1005                "DAYOFMONTH".to_string(),
1006                f.args,
1007            )))),
1008
1009            // DAY_OF_WEEK -> DAYOFWEEK in MySQL
1010            "DAY_OF_WEEK" => Ok(Expression::Function(Box::new(Function::new(
1011                "DAYOFWEEK".to_string(),
1012                f.args,
1013            )))),
1014
1015            // DAY_OF_YEAR -> DAYOFYEAR in MySQL
1016            "DAY_OF_YEAR" => Ok(Expression::Function(Box::new(Function::new(
1017                "DAYOFYEAR".to_string(),
1018                f.args,
1019            )))),
1020
1021            // WEEK_OF_YEAR -> WEEKOFYEAR in MySQL
1022            "WEEK_OF_YEAR" => Ok(Expression::Function(Box::new(Function::new(
1023                "WEEKOFYEAR".to_string(),
1024                f.args,
1025            )))),
1026
1027            // MOD(x, y) -> x % y in MySQL
1028            "MOD" if f.args.len() == 2 => {
1029                let mut args = f.args;
1030                let left = args.remove(0);
1031                let right = args.remove(0);
1032                Ok(Expression::Mod(Box::new(BinaryOp {
1033                    left,
1034                    right,
1035                    left_comments: Vec::new(),
1036                    operator_comments: Vec::new(),
1037                    trailing_comments: Vec::new(),
1038                    inferred_type: None,
1039                })))
1040            }
1041
1042            // PARSE_JSON -> strip in MySQL (just keep the string argument)
1043            "PARSE_JSON" if f.args.len() == 1 => Ok(f.args.into_iter().next().unwrap()),
1044
1045            // GET_PATH(obj, path) -> JSON_EXTRACT(obj, json_path) in MySQL
1046            "GET_PATH" if f.args.len() == 2 => {
1047                let mut args = f.args;
1048                let this = args.remove(0);
1049                let path = args.remove(0);
1050                let json_path = match &path {
1051                    Expression::Literal(Literal::String(s)) => {
1052                        // Convert bracket notation ["key"] to quoted dot notation ."key"
1053                        let s = Self::convert_bracket_to_quoted_path(s);
1054                        let normalized = if s.starts_with('$') {
1055                            s
1056                        } else if s.starts_with('[') {
1057                            format!("${}", s)
1058                        } else {
1059                            format!("$.{}", s)
1060                        };
1061                        Expression::Literal(Literal::String(normalized))
1062                    }
1063                    _ => path,
1064                };
1065                Ok(Expression::JsonExtract(Box::new(JsonExtractFunc {
1066                    this,
1067                    path: json_path,
1068                    returning: None,
1069                    arrow_syntax: false,
1070                    hash_arrow_syntax: false,
1071                    wrapper_option: None,
1072                    quotes_option: None,
1073                    on_scalar_string: false,
1074                    on_error: None,
1075                })))
1076            }
1077
1078            // REGEXP -> REGEXP_LIKE (MySQL standard form)
1079            "REGEXP" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
1080                "REGEXP_LIKE".to_string(),
1081                f.args,
1082            )))),
1083
1084            // Pass through everything else
1085            _ => Ok(Expression::Function(Box::new(f))),
1086        }
1087    }
1088
1089    fn transform_aggregate_function(
1090        &self,
1091        f: Box<crate::expressions::AggregateFunction>,
1092    ) -> Result<Expression> {
1093        let name_upper = f.name.to_uppercase();
1094        match name_upper.as_str() {
1095            // STRING_AGG -> GROUP_CONCAT
1096            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
1097                Function::new("GROUP_CONCAT".to_string(), f.args),
1098            ))),
1099
1100            // ARRAY_AGG -> GROUP_CONCAT
1101            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
1102                "GROUP_CONCAT".to_string(),
1103                f.args,
1104            )))),
1105
1106            // Pass through everything else
1107            _ => Ok(Expression::AggregateFunction(f)),
1108        }
1109    }
1110}
1111
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;

    /// Transpile `sql` from the generic dialect to MySQL and return the first statement.
    fn transpile_to_mysql(sql: &str) -> String {
        let statements = Dialect::get(DialectType::Generic)
            .transpile_to(sql, DialectType::MySQL)
            .expect("Transpile failed");
        statements[0].clone()
    }

    /// Assert that the MySQL transpilation of `sql` contains `needle`.
    /// `what` fills in the failure message "Expected {what}, got: {output}".
    fn assert_mysql_contains(sql: &str, needle: &str, what: &str) {
        let output = transpile_to_mysql(sql);
        assert!(output.contains(needle), "Expected {}, got: {}", what, output);
    }

    /// Parse, transform and regenerate `sql` with the MySQL dialect and compare the
    /// generated text against `expected`.
    fn mysql_identity(sql: &str, expected: &str) {
        let mysql = Dialect::get(DialectType::MySQL);
        let parsed = mysql.parse(sql).expect("Parse failed");
        let first = parsed[0].clone();
        let rewritten = mysql.transform(first).expect("Transform failed");
        let generated = mysql.generate(&rewritten).expect("Generate failed");
        assert_eq!(generated, expected, "SQL: {}", sql);
    }

    #[test]
    fn test_nvl_to_ifnull() {
        assert_mysql_contains("SELECT NVL(a, b)", "IFNULL", "IFNULL");
    }

    #[test]
    fn test_coalesce_preserved() {
        // COALESCE should be preserved in MySQL (it's a native function)
        assert_mysql_contains(
            "SELECT COALESCE(a, b)",
            "COALESCE",
            "COALESCE to be preserved",
        );
    }

    #[test]
    fn test_random_to_rand() {
        assert_mysql_contains("SELECT RANDOM()", "RAND", "RAND");
    }

    #[test]
    fn test_basic_select() {
        let output = transpile_to_mysql("SELECT a, b FROM users WHERE id = 1");
        assert!(output.contains("SELECT"));
        assert!(output.contains("FROM users"));
    }

    #[test]
    fn test_string_agg_to_group_concat() {
        assert_mysql_contains("SELECT STRING_AGG(name)", "GROUP_CONCAT", "GROUP_CONCAT");
    }

    #[test]
    fn test_array_agg_to_group_concat() {
        assert_mysql_contains("SELECT ARRAY_AGG(name)", "GROUP_CONCAT", "GROUP_CONCAT");
    }

    #[test]
    fn test_to_date_to_str_to_date() {
        assert_mysql_contains("SELECT TO_DATE('2023-01-01')", "STR_TO_DATE", "STR_TO_DATE");
    }

    #[test]
    fn test_backtick_identifiers() {
        // MySQL quotes identifiers with backticks.
        let dialect = MySQLDialect;
        let config = dialect.generator_config();
        assert_eq!(config.identifier_quote, '`');
    }

    #[test]
    fn test_ucase_to_upper() {
        mysql_identity("SELECT UCASE('foo')", "SELECT UPPER('foo')");
    }

    #[test]
    fn test_lcase_to_lower() {
        mysql_identity("SELECT LCASE('foo')", "SELECT LOWER('foo')");
    }

    #[test]
    fn test_day_of_month() {
        mysql_identity(
            "SELECT DAY_OF_MONTH('2023-01-01')",
            "SELECT DAYOFMONTH('2023-01-01')",
        );
    }

    #[test]
    fn test_day_of_week() {
        mysql_identity(
            "SELECT DAY_OF_WEEK('2023-01-01')",
            "SELECT DAYOFWEEK('2023-01-01')",
        );
    }

    #[test]
    fn test_day_of_year() {
        mysql_identity(
            "SELECT DAY_OF_YEAR('2023-01-01')",
            "SELECT DAYOFYEAR('2023-01-01')",
        );
    }

    #[test]
    fn test_week_of_year() {
        mysql_identity(
            "SELECT WEEK_OF_YEAR('2023-01-01')",
            "SELECT WEEKOFYEAR('2023-01-01')",
        );
    }

    #[test]
    fn test_mod_func_to_percent() {
        // The MOD(x, y) function call becomes the x % y operator in MySQL.
        mysql_identity("MOD(x, y)", "x % y");
    }

    #[test]
    fn test_database_to_schema() {
        mysql_identity("DATABASE()", "SCHEMA()");
    }

    #[test]
    fn test_and_operator() {
        mysql_identity("SELECT 1 && 0", "SELECT 1 AND 0");
    }

    #[test]
    fn test_or_operator() {
        mysql_identity("SELECT a || b", "SELECT a OR b");
    }
}