Skip to main content

polyglot_sql/dialects/
mysql.rs

1//! MySQL Dialect
2//!
3//! MySQL-specific transformations based on sqlglot patterns.
4//! Key differences from standard SQL:
5//! - || is OR operator, not string concatenation (use CONCAT)
6//! - Uses backticks for identifiers
7//! - No TRY_CAST, no ILIKE
8//! - Different date/time function names
9
10use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{BinaryFunc, BinaryOp, Cast, DataType, Expression, Function, JsonExtractFunc, LikeOp, Literal, Paren, UnaryFunc};
13use crate::generator::GeneratorConfig;
14use crate::tokens::TokenizerConfig;
15
16/// Helper to wrap JSON arrow expressions in parentheses when they appear
17/// in contexts that require it (Binary, In, Not expressions)
18/// This matches Python sqlglot's WRAPPED_JSON_EXTRACT_EXPRESSIONS behavior
19fn wrap_if_json_arrow(expr: Expression) -> Expression {
20    match &expr {
21        Expression::JsonExtract(f) if f.arrow_syntax => {
22            Expression::Paren(Box::new(Paren {
23                this: expr,
24                trailing_comments: Vec::new(),
25            }))
26        }
27        Expression::JsonExtractScalar(f) if f.arrow_syntax => {
28            Expression::Paren(Box::new(Paren {
29                this: expr,
30                trailing_comments: Vec::new(),
31            }))
32        }
33        _ => expr,
34    }
35}
36
37/// Convert JSON arrow expression (-> or ->>) to JSON_EXTRACT function form
38/// This is needed for contexts like MEMBER OF where arrow syntax must become function form
39fn json_arrow_to_function(expr: Expression) -> Expression {
40    match expr {
41        Expression::JsonExtract(f) if f.arrow_syntax => {
42            Expression::Function(Box::new(Function::new(
43                "JSON_EXTRACT".to_string(),
44                vec![f.this, f.path],
45            )))
46        }
47        Expression::JsonExtractScalar(f) if f.arrow_syntax => {
48            // ->> becomes JSON_UNQUOTE(JSON_EXTRACT(...)) but can be simplified to JSON_EXTRACT_SCALAR
49            // For MySQL, use JSON_UNQUOTE(JSON_EXTRACT(...))
50            let json_extract = Expression::Function(Box::new(Function::new(
51                "JSON_EXTRACT".to_string(),
52                vec![f.this, f.path],
53            )));
54            Expression::Function(Box::new(Function::new(
55                "JSON_UNQUOTE".to_string(),
56                vec![json_extract],
57            )))
58        }
59        other => other,
60    }
61}
62
/// The MySQL SQL dialect.
///
/// A stateless marker type; its `DialectImpl` implementation supplies the MySQL
/// tokenizer and generator settings and the expression rewrites applied when
/// producing MySQL SQL.
pub struct MySQLDialect;
65
66impl DialectImpl for MySQLDialect {
67    fn dialect_type(&self) -> DialectType {
68        DialectType::MySQL
69    }
70
71    fn tokenizer_config(&self) -> TokenizerConfig {
72        use crate::tokens::TokenType;
73        let mut config = TokenizerConfig::default();
74        // MySQL uses backticks for identifiers
75        config.identifiers.insert('`', '`');
76        // Remove double quotes from identifiers - in MySQL they are string delimiters
77        // (unless ANSI_QUOTES mode is set, but default mode uses them as strings)
78        config.identifiers.remove(&'"');
79        // MySQL supports double quotes as string literals by default
80        config.quotes.insert("\"".to_string(), "\"".to_string());
81        // MySQL supports backslash escapes in strings
82        config.string_escapes.push('\\');
83        // MySQL has XOR as a logical operator keyword
84        config.keywords.insert("XOR".to_string(), TokenType::Xor);
85        // MySQL: backslash followed by chars NOT in this list -> discard backslash
86        // See: https://dev.mysql.com/doc/refman/8.4/en/string-literals.html
87        config.escape_follow_chars = vec!['0', 'b', 'n', 'r', 't', 'Z', '%', '_'];
88        // MySQL allows identifiers to start with digits (e.g., 1a, 1_a)
89        config.identifiers_can_start_with_digit = true;
90        config
91    }
92
93    fn generator_config(&self) -> GeneratorConfig {
94        use crate::generator::IdentifierQuoteStyle;
95        GeneratorConfig {
96            identifier_quote: '`',
97            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
98            dialect: Some(DialectType::MySQL),
99            // MySQL doesn't support null ordering in most contexts
100            null_ordering_supported: false,
101            // MySQL LIMIT only
102            limit_only_literals: true,
103            // MySQL doesn't support semi/anti join
104            semi_anti_join_with_side: false,
105            // MySQL doesn't support table alias columns in some contexts
106            supports_table_alias_columns: false,
107            // MySQL VALUES not used as table
108            values_as_table: false,
109            // MySQL doesn't support TABLESAMPLE
110            tablesample_requires_parens: false,
111            tablesample_with_method: false,
112            // MySQL doesn't support aggregate FILTER
113            aggregate_filter_supported: false,
114            // MySQL doesn't support TRY
115            try_supported: false,
116            // MySQL doesn't support CONVERT_TIMEZONE
117            supports_convert_timezone: false,
118            // MySQL doesn't support UESCAPE
119            supports_uescape: false,
120            // MySQL doesn't support BETWEEN flags
121            supports_between_flags: false,
122            // MySQL supports EXPLAIN but not query hints in standard way
123            query_hints: false,
124            // MySQL parameter token
125            parameter_token: "?",
126            // MySQL doesn't support window EXCLUDE
127            supports_window_exclude: false,
128            // MySQL doesn't support exploding projections
129            supports_exploding_projections: false,
130            identifiers_can_start_with_digit: true,
131            // MySQL supports FOR UPDATE/SHARE
132            locking_reads_supported: true,
133            ..Default::default()
134        }
135    }
136
137    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
138        match expr {
139            // ===== Data Type Mappings =====
140            Expression::DataType(dt) => self.transform_data_type(dt),
141
142            // NVL -> IFNULL in MySQL
143            Expression::Nvl(f) => Ok(Expression::IfNull(f)),
144
145            // Note: COALESCE is valid in MySQL and should be preserved.
146            // Unlike some other dialects, we do NOT convert COALESCE to IFNULL
147            // as this would break identity tests.
148
149            // TryCast -> CAST or TIMESTAMP() (MySQL doesn't support TRY_CAST)
150            Expression::TryCast(c) => self.transform_cast(*c),
151
152            // SafeCast -> CAST or TIMESTAMP() (MySQL doesn't support safe casts)
153            Expression::SafeCast(c) => self.transform_cast(*c),
154
155            // Cast -> Transform cast type according to MySQL restrictions
156            // CAST AS TIMESTAMP -> TIMESTAMP() function in MySQL
157            Expression::Cast(c) => self.transform_cast(*c),
158
159            // ILIKE -> LOWER() LIKE LOWER() in MySQL
160            Expression::ILike(op) => {
161                // Transform ILIKE to: LOWER(left) LIKE LOWER(right)
162                let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left)));
163                let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right)));
164                Ok(Expression::Like(Box::new(LikeOp {
165                    left: lower_left,
166                    right: lower_right,
167                    escape: op.escape,
168                    quantifier: op.quantifier,
169                })))
170            }
171
172            // || (Concat operator) -> OR in MySQL
173            // MySQL uses || as OR by default (unless PIPES_AS_CONCAT mode is set)
174            // For identity preservation, we transform to OR
175            Expression::Concat(op) => {
176                Ok(Expression::Or(op))
177            }
178
179            // RANDOM -> RAND in MySQL
180            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
181                seed: None, lower: None, upper: None,
182            }))),
183
184            // ArrayAgg -> GROUP_CONCAT in MySQL
185            Expression::ArrayAgg(f) => Ok(Expression::Function(Box::new(Function::new(
186                "GROUP_CONCAT".to_string(),
187                vec![f.this],
188            )))),
189
190            // StringAgg -> GROUP_CONCAT in MySQL
191            Expression::StringAgg(f) => {
192                let mut args = vec![f.this.clone()];
193                if let Some(separator) = &f.separator {
194                    args.push(separator.clone());
195                }
196                Ok(Expression::Function(Box::new(Function::new(
197                    "GROUP_CONCAT".to_string(),
198                    args,
199                ))))
200            }
201
202            // UNNEST -> Not directly supported in MySQL, use JSON_TABLE or inline
203            // For basic cases, pass through (may need manual handling)
204            Expression::Unnest(f) => {
205                // MySQL 8.0+ has JSON_TABLE which can be used for unnesting
206                // For now, pass through with a function call
207                Ok(Expression::Function(Box::new(Function::new(
208                    "JSON_TABLE".to_string(),
209                    vec![f.this],
210                ))))
211            }
212
213            // Substring: Use comma syntax (not FROM/FOR) in MySQL
214            Expression::Substring(mut f) => {
215                f.from_for_syntax = false;
216                Ok(Expression::Substring(f))
217            }
218
219            // ===== Bitwise operations =====
220            // BitwiseAndAgg -> BIT_AND
221            Expression::BitwiseAndAgg(f) => Ok(Expression::Function(Box::new(Function::new(
222                "BIT_AND".to_string(),
223                vec![f.this],
224            )))),
225
226            // BitwiseOrAgg -> BIT_OR
227            Expression::BitwiseOrAgg(f) => Ok(Expression::Function(Box::new(Function::new(
228                "BIT_OR".to_string(),
229                vec![f.this],
230            )))),
231
232            // BitwiseXorAgg -> BIT_XOR
233            Expression::BitwiseXorAgg(f) => Ok(Expression::Function(Box::new(Function::new(
234                "BIT_XOR".to_string(),
235                vec![f.this],
236            )))),
237
238            // BitwiseCount -> BIT_COUNT
239            Expression::BitwiseCount(f) => Ok(Expression::Function(Box::new(Function::new(
240                "BIT_COUNT".to_string(),
241                vec![f.this],
242            )))),
243
244            // TimeFromParts -> MAKETIME
245            Expression::TimeFromParts(f) => {
246                let mut args = Vec::new();
247                if let Some(h) = f.hour {
248                    args.push(*h);
249                }
250                if let Some(m) = f.min {
251                    args.push(*m);
252                }
253                if let Some(s) = f.sec {
254                    args.push(*s);
255                }
256                Ok(Expression::Function(Box::new(Function::new(
257                    "MAKETIME".to_string(),
258                    args,
259                ))))
260            }
261
262            // ===== Boolean aggregates =====
263            // In MySQL, there's no BOOL_AND/BOOL_OR, use MIN/MAX on boolean values
264            // LogicalAnd -> MIN (0 is false, non-0 is true)
265            Expression::LogicalAnd(f) => Ok(Expression::Function(Box::new(Function::new(
266                "MIN".to_string(),
267                vec![f.this],
268            )))),
269
270            // LogicalOr -> MAX
271            Expression::LogicalOr(f) => Ok(Expression::Function(Box::new(Function::new(
272                "MAX".to_string(),
273                vec![f.this],
274            )))),
275
276            // ===== Date/time functions =====
277            // DayOfMonth -> DAYOFMONTH
278            Expression::DayOfMonth(f) => Ok(Expression::Function(Box::new(Function::new(
279                "DAYOFMONTH".to_string(),
280                vec![f.this],
281            )))),
282
283            // DayOfWeek -> DAYOFWEEK
284            Expression::DayOfWeek(f) => Ok(Expression::Function(Box::new(Function::new(
285                "DAYOFWEEK".to_string(),
286                vec![f.this],
287            )))),
288
289            // DayOfYear -> DAYOFYEAR
290            Expression::DayOfYear(f) => Ok(Expression::Function(Box::new(Function::new(
291                "DAYOFYEAR".to_string(),
292                vec![f.this],
293            )))),
294
295            // WeekOfYear -> WEEKOFYEAR
296            Expression::WeekOfYear(f) => Ok(Expression::Function(Box::new(Function::new(
297                "WEEKOFYEAR".to_string(),
298                vec![f.this],
299            )))),
300
301            // DateDiff -> DATEDIFF
302            Expression::DateDiff(f) => Ok(Expression::Function(Box::new(Function::new(
303                "DATEDIFF".to_string(),
304                vec![f.this, f.expression],
305            )))),
306
307            // TimeStrToUnix -> UNIX_TIMESTAMP
308            Expression::TimeStrToUnix(f) => Ok(Expression::Function(Box::new(Function::new(
309                "UNIX_TIMESTAMP".to_string(),
310                vec![f.this],
311            )))),
312
313            // TimestampDiff -> TIMESTAMPDIFF
314            Expression::TimestampDiff(f) => Ok(Expression::Function(Box::new(Function::new(
315                "TIMESTAMPDIFF".to_string(),
316                vec![*f.this, *f.expression],
317            )))),
318
319            // ===== String functions =====
320            // StrPosition -> LOCATE in MySQL
321            // STRPOS(str, substr) -> LOCATE(substr, str) (args are swapped)
322            Expression::StrPosition(f) => {
323                let mut args = vec![];
324                if let Some(substr) = f.substr {
325                    args.push(*substr);
326                }
327                args.push(*f.this);
328                if let Some(pos) = f.position {
329                    args.push(*pos);
330                }
331                Ok(Expression::Function(Box::new(Function::new(
332                    "LOCATE".to_string(),
333                    args,
334                ))))
335            }
336
337            // Stuff -> INSERT in MySQL
338            Expression::Stuff(f) => {
339                let mut args = vec![*f.this];
340                if let Some(start) = f.start {
341                    args.push(*start);
342                }
343                if let Some(length) = f.length {
344                    args.push(Expression::number(length));
345                }
346                args.push(*f.expression);
347                Ok(Expression::Function(Box::new(Function::new(
348                    "INSERT".to_string(),
349                    args,
350                ))))
351            }
352
353            // ===== Session/User functions =====
354            // SessionUser -> SESSION_USER()
355            Expression::SessionUser(_) => Ok(Expression::Function(Box::new(Function::new(
356                "SESSION_USER".to_string(),
357                vec![],
358            )))),
359
360            // CurrentDate -> CURRENT_DATE (no parentheses in MySQL) - keep as CurrentDate
361            Expression::CurrentDate(_) => Ok(Expression::CurrentDate(crate::expressions::CurrentDate)),
362
363            // ===== Null-safe comparison =====
364            // NullSafeNeq -> NOT (a <=> b) in MySQL
365            Expression::NullSafeNeq(op) => {
366                // Create: NOT (left <=> right)
367                let null_safe_eq = Expression::NullSafeEq(Box::new(crate::expressions::BinaryOp {
368                    left: op.left,
369                    right: op.right,
370                    left_comments: Vec::new(),
371                    operator_comments: Vec::new(),
372                    trailing_comments: Vec::new(),
373                }));
374                Ok(Expression::Not(Box::new(crate::expressions::UnaryOp {
375                    this: null_safe_eq,
376                })))
377            }
378
379            // ParseJson: handled by generator (emits just the string literal for MySQL)
380
381            // JSONExtract with variant_extract (Snowflake colon syntax) -> JSON_EXTRACT
382            Expression::JSONExtract(e) if e.variant_extract.is_some() => {
383                let path = match *e.expression {
384                    Expression::Literal(Literal::String(s)) => {
385                        // Convert bracket notation ["key"] to quoted dot notation ."key"
386                        let s = Self::convert_bracket_to_quoted_path(&s);
387                        let normalized = if s.starts_with('$') {
388                            s
389                        } else if s.starts_with('[') {
390                            format!("${}", s)
391                        } else {
392                            format!("$.{}", s)
393                        };
394                        Expression::Literal(Literal::String(normalized))
395                    }
396                    other => other,
397                };
398                Ok(Expression::Function(Box::new(Function::new(
399                    "JSON_EXTRACT".to_string(),
400                    vec![*e.this, path],
401                ))))
402            }
403
404            // Generic function transformations
405            Expression::Function(f) => self.transform_function(*f),
406
407            // Generic aggregate function transformations
408            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
409
410            // ===== Context-aware JSON arrow wrapping =====
411            // When JSON arrow expressions appear in Binary/In/Not contexts,
412            // they need to be wrapped in parentheses for correct precedence.
413            // This matches Python sqlglot's WRAPPED_JSON_EXTRACT_EXPRESSIONS behavior.
414
415            // Binary operators that need JSON wrapping
416            Expression::Eq(op) => Ok(Expression::Eq(Box::new(BinaryOp {
417                left: wrap_if_json_arrow(op.left),
418                right: wrap_if_json_arrow(op.right),
419                ..*op
420            }))),
421            Expression::Neq(op) => Ok(Expression::Neq(Box::new(BinaryOp {
422                left: wrap_if_json_arrow(op.left),
423                right: wrap_if_json_arrow(op.right),
424                ..*op
425            }))),
426            Expression::Lt(op) => Ok(Expression::Lt(Box::new(BinaryOp {
427                left: wrap_if_json_arrow(op.left),
428                right: wrap_if_json_arrow(op.right),
429                ..*op
430            }))),
431            Expression::Lte(op) => Ok(Expression::Lte(Box::new(BinaryOp {
432                left: wrap_if_json_arrow(op.left),
433                right: wrap_if_json_arrow(op.right),
434                ..*op
435            }))),
436            Expression::Gt(op) => Ok(Expression::Gt(Box::new(BinaryOp {
437                left: wrap_if_json_arrow(op.left),
438                right: wrap_if_json_arrow(op.right),
439                ..*op
440            }))),
441            Expression::Gte(op) => Ok(Expression::Gte(Box::new(BinaryOp {
442                left: wrap_if_json_arrow(op.left),
443                right: wrap_if_json_arrow(op.right),
444                ..*op
445            }))),
446
447            // In expression - wrap the this part if it's JSON arrow
448            Expression::In(mut i) => {
449                i.this = wrap_if_json_arrow(i.this);
450                Ok(Expression::In(i))
451            }
452
453            // Not expression - wrap the this part if it's JSON arrow
454            Expression::Not(mut n) => {
455                n.this = wrap_if_json_arrow(n.this);
456                Ok(Expression::Not(n))
457            }
458
459            // && in MySQL is logical AND, not array overlaps
460            // Transform ArrayOverlaps -> And for MySQL identity
461            Expression::ArrayOverlaps(op) => Ok(Expression::And(op)),
462
463            // MOD(x, y) -> x % y in MySQL
464            Expression::ModFunc(f) => Ok(Expression::Mod(Box::new(BinaryOp {
465                left: f.this,
466                right: f.expression,
467                left_comments: Vec::new(),
468                operator_comments: Vec::new(),
469                trailing_comments: Vec::new(),
470            }))),
471
472            // SHOW SLAVE STATUS -> SHOW REPLICA STATUS
473            Expression::Show(mut s) => {
474                if s.this == "SLAVE STATUS" {
475                    s.this = "REPLICA STATUS".to_string();
476                }
477                if matches!(s.this.as_str(), "INDEX" | "COLUMNS") && s.db.is_none() {
478                    if let Some(Expression::Table(mut t)) = s.target.take() {
479                        if let Some(db_ident) = t.schema.take().or(t.catalog.take()) {
480                            s.db = Some(Expression::Identifier(db_ident));
481                            s.target = Some(Expression::Identifier(t.name));
482                        } else {
483                            s.target = Some(Expression::Table(t));
484                        }
485                    }
486                }
487                Ok(Expression::Show(s))
488            }
489
490            // AT TIME ZONE -> strip timezone (MySQL doesn't support AT TIME ZONE)
491            // But keep it for CURRENT_DATE/CURRENT_TIMESTAMP with timezone (transpiled from BigQuery)
492            Expression::AtTimeZone(atz) => {
493                let is_current = match &atz.this {
494                    Expression::CurrentDate(_) | Expression::CurrentTimestamp(_) => true,
495                    Expression::Function(f) => {
496                        let n = f.name.to_uppercase();
497                        (n == "CURRENT_DATE" || n == "CURRENT_TIMESTAMP") && f.no_parens
498                    }
499                    _ => false,
500                };
501                if is_current {
502                    Ok(Expression::AtTimeZone(atz)) // Keep AT TIME ZONE for CURRENT_DATE/CURRENT_TIMESTAMP
503                } else {
504                    Ok(atz.this) // Strip timezone for other expressions
505                }
506            }
507
508            // MEMBER OF with JSON arrow -> convert arrow to JSON_EXTRACT function
509            // MySQL's MEMBER OF requires JSON_EXTRACT function form, not arrow syntax
510            Expression::MemberOf(mut op) => {
511                op.right = json_arrow_to_function(op.right);
512                Ok(Expression::MemberOf(op))
513            }
514
515            // Pass through everything else
516            _ => Ok(expr),
517        }
518    }
519}
520
521impl MySQLDialect {
522    fn normalize_mysql_date_format(fmt: &str) -> String {
523        fmt
524            .replace("%H:%i:%s", "%T")
525            .replace("%H:%i:%S", "%T")
526    }
527
528    /// Convert bracket notation ["key with spaces"] to quoted dot notation ."key with spaces"
529    /// in JSON path strings.
530    fn convert_bracket_to_quoted_path(path: &str) -> String {
531        let mut result = String::new();
532        let mut chars = path.chars().peekable();
533        while let Some(c) = chars.next() {
534            if c == '[' && chars.peek() == Some(&'"') {
535                chars.next(); // consume "
536                let mut key = String::new();
537                while let Some(kc) = chars.next() {
538                    if kc == '"' && chars.peek() == Some(&']') {
539                        chars.next(); // consume ]
540                        break;
541                    }
542                    key.push(kc);
543                }
544                if !result.is_empty() && !result.ends_with('.') {
545                    result.push('.');
546                }
547                result.push('"');
548                result.push_str(&key);
549                result.push('"');
550            } else {
551                result.push(c);
552            }
553        }
554        result
555    }
556
557    /// Transform data types according to MySQL TYPE_MAPPING
558    /// Note: MySQL's TIMESTAMP is kept as TIMESTAMP (not converted to DATETIME)
559    /// because MySQL's TIMESTAMP has timezone awareness built-in
560    fn transform_data_type(&self, dt: crate::expressions::DataType) -> Result<Expression> {
561        use crate::expressions::DataType;
562        let transformed = match dt {
563            // All TIMESTAMP variants (with or without timezone) -> TIMESTAMP in MySQL
564            DataType::Timestamp { precision, timezone: _ } => DataType::Timestamp {
565                precision,
566                timezone: false,
567            },
568            // TIMESTAMPTZ / TIMESTAMPLTZ parsed as Custom -> normalize to TIMESTAMP
569            DataType::Custom { name } if name.to_uppercase() == "TIMESTAMPTZ"
570                || name.to_uppercase() == "TIMESTAMPLTZ" => {
571                DataType::Timestamp {
572                    precision: None,
573                    timezone: false,
574                }
575            },
576            // Keep native MySQL types as-is
577            // MySQL supports TEXT, MEDIUMTEXT, LONGTEXT, BLOB, etc. natively
578            other => other,
579        };
580        Ok(Expression::DataType(transformed))
581    }
582
583    /// Transform CAST expression
584    /// MySQL uses TIMESTAMP() function instead of CAST(x AS TIMESTAMP)
585    fn transform_cast(&self, cast: Cast) -> Result<Expression> {
586        // CAST AS TIMESTAMP/TIMESTAMPTZ/TIMESTAMPLTZ -> TIMESTAMP() function
587        match &cast.to {
588            DataType::Timestamp { .. } => {
589                Ok(Expression::Function(Box::new(Function::new(
590                    "TIMESTAMP".to_string(),
591                    vec![cast.this],
592                ))))
593            }
594            DataType::Custom { name } if name.to_uppercase() == "TIMESTAMPTZ"
595                || name.to_uppercase() == "TIMESTAMPLTZ" => {
596                Ok(Expression::Function(Box::new(Function::new(
597                    "TIMESTAMP".to_string(),
598                    vec![cast.this],
599                ))))
600            }
601            // All other casts go through normal type transformation
602            _ => Ok(Expression::Cast(Box::new(self.transform_cast_type(cast)))),
603        }
604    }
605
606    /// Transform CAST type according to MySQL restrictions
607    /// MySQL doesn't support many types in CAST - they get mapped to CHAR or SIGNED
608    /// Based on Python sqlglot's CHAR_CAST_MAPPING and SIGNED_CAST_MAPPING
609    fn transform_cast_type(&self, cast: Cast) -> Cast {
610        let new_type = match &cast.to {
611            // CHAR_CAST_MAPPING: These types become CHAR in MySQL CAST
612            DataType::VarChar { .. } => DataType::Char { length: None },
613            DataType::Text => DataType::Char { length: None },
614
615            // SIGNED_CAST_MAPPING: These integer types become SIGNED in MySQL CAST
616            DataType::BigInt { .. } => DataType::Custom { name: "SIGNED".to_string() },
617            DataType::Int { .. } => DataType::Custom { name: "SIGNED".to_string() },
618            DataType::SmallInt { .. } => DataType::Custom { name: "SIGNED".to_string() },
619            DataType::TinyInt { .. } => DataType::Custom { name: "SIGNED".to_string() },
620            DataType::Boolean => DataType::Custom { name: "SIGNED".to_string() },
621
622            // Custom types that need mapping
623            DataType::Custom { name } => {
624                let upper = name.to_uppercase();
625                match upper.as_str() {
626                    // Text/Blob types -> keep as Custom for cross-dialect mapping
627                    // MySQL generator will output CHAR for these in CAST context
628                    "LONGTEXT" | "MEDIUMTEXT" | "TINYTEXT" | "LONGBLOB" | "MEDIUMBLOB" | "TINYBLOB" => {
629                        DataType::Custom { name: upper }
630                    }
631                    // MEDIUMINT -> SIGNED in MySQL CAST
632                    "MEDIUMINT" => DataType::Custom { name: "SIGNED".to_string() },
633                    // Unsigned integer types -> UNSIGNED
634                    "UBIGINT" | "UINT" | "USMALLINT" | "UTINYINT" | "UMEDIUMINT" => {
635                        DataType::Custom { name: "UNSIGNED".to_string() }
636                    }
637                    // Keep other custom types
638                    _ => cast.to.clone(),
639                }
640            }
641
642            // Keep other types as-is
643            _ => cast.to.clone(),
644        };
645
646        Cast {
647            this: cast.this,
648            to: new_type,
649            trailing_comments: cast.trailing_comments,
650            double_colon_syntax: cast.double_colon_syntax,
651            format: cast.format,
652            default: cast.default,
653        }
654    }
655
656    fn transform_function(&self, f: Function) -> Result<Expression> {
657        let name_upper = f.name.to_uppercase();
658        match name_upper.as_str() {
659            // Normalize DATE_FORMAT short-hands to canonical MySQL forms.
660            "DATE_FORMAT" if f.args.len() >= 2 => {
661                let mut f = f;
662                if let Some(Expression::Literal(Literal::String(fmt))) = f.args.get(1) {
663                    let normalized = Self::normalize_mysql_date_format(fmt);
664                    if normalized != *fmt {
665                        f.args[1] = Expression::Literal(Literal::String(normalized));
666                    }
667                }
668                Ok(Expression::Function(Box::new(f)))
669            }
670
671            // NVL -> IFNULL
672            "NVL" if f.args.len() == 2 => {
673                let mut args = f.args;
674                let second = args.pop().unwrap();
675                let first = args.pop().unwrap();
676                Ok(Expression::IfNull(Box::new(BinaryFunc { original_name: None,
677                    this: first,
678                    expression: second,
679                })))
680            }
681
682            // Note: COALESCE is native to MySQL. We do NOT convert it to IFNULL
683            // because this would break identity tests (Python SQLGlot preserves COALESCE).
684
685            // ARRAY_AGG -> GROUP_CONCAT
686            "ARRAY_AGG" if f.args.len() == 1 => {
687                let mut args = f.args;
688                Ok(Expression::Function(Box::new(Function::new(
689                    "GROUP_CONCAT".to_string(),
690                    vec![args.pop().unwrap()],
691                ))))
692            }
693
694            // STRING_AGG -> GROUP_CONCAT
695            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
696                "GROUP_CONCAT".to_string(),
697                f.args,
698            )))),
699
700            // RANDOM -> RAND
701            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
702                seed: None, lower: None, upper: None,
703            }))),
704
705            // CURRENT_TIMESTAMP -> NOW() or CURRENT_TIMESTAMP (both work)
706            // Preserve precision if specified: CURRENT_TIMESTAMP(6)
707            "CURRENT_TIMESTAMP" => {
708                let precision = if let Some(Expression::Literal(crate::expressions::Literal::Number(n))) = f.args.first() {
709                    n.parse::<u32>().ok()
710                } else {
711                    None
712                };
713                Ok(Expression::CurrentTimestamp(
714                    crate::expressions::CurrentTimestamp { precision, sysdate: false },
715                ))
716            }
717
718            // POSITION -> LOCATE in MySQL (argument order is different)
719            // POSITION(substr IN str) -> LOCATE(substr, str)
720            "POSITION" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
721                "LOCATE".to_string(),
722                f.args,
723            )))),
724
725            // LENGTH is native to MySQL (returns bytes, not characters)
726            // CHAR_LENGTH for character count
727            "LENGTH" => Ok(Expression::Function(Box::new(f))),
728
729            // CEIL -> CEILING in MySQL (both work)
730            "CEIL" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
731                "CEILING".to_string(),
732                f.args,
733            )))),
734
735            // STDDEV -> STD or STDDEV_POP in MySQL
736            "STDDEV" => Ok(Expression::Function(Box::new(Function::new(
737                "STD".to_string(),
738                f.args,
739            )))),
740
741            // STDDEV_SAMP -> STDDEV in MySQL
742            "STDDEV_SAMP" => Ok(Expression::Function(Box::new(Function::new(
743                "STDDEV".to_string(),
744                f.args,
745            )))),
746
747            // TO_DATE -> STR_TO_DATE in MySQL
748            "TO_DATE" => Ok(Expression::Function(Box::new(Function::new(
749                "STR_TO_DATE".to_string(),
750                f.args,
751            )))),
752
753            // TO_TIMESTAMP -> STR_TO_DATE in MySQL
754            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(Function::new(
755                "STR_TO_DATE".to_string(),
756                f.args,
757            )))),
758
759            // DATE_TRUNC -> Complex transformation
760            // Typically uses DATE() or DATE_FORMAT() depending on unit
761            "DATE_TRUNC" if f.args.len() >= 2 => {
762                // Simplified: DATE_TRUNC('day', x) -> DATE(x)
763                // Full implementation would handle different units
764                let mut args = f.args;
765                let _unit = args.remove(0);
766                let date = args.remove(0);
767                Ok(Expression::Function(Box::new(Function::new(
768                    "DATE".to_string(),
769                    vec![date],
770                ))))
771            }
772
773            // EXTRACT is native but syntax varies
774
775            // COALESCE is native to MySQL (keep as-is for more than 2 args)
776            "COALESCE" if f.args.len() > 2 => Ok(Expression::Function(Box::new(f))),
777
778            // DAYOFMONTH -> DAY (both work)
779            "DAY" => Ok(Expression::Function(Box::new(Function::new(
780                "DAYOFMONTH".to_string(),
781                f.args,
782            )))),
783
784            // DAYOFWEEK is native to MySQL
785            "DAYOFWEEK" => Ok(Expression::Function(Box::new(f))),
786
787            // DAYOFYEAR is native to MySQL
788            "DAYOFYEAR" => Ok(Expression::Function(Box::new(f))),
789
790            // WEEKOFYEAR is native to MySQL
791            "WEEKOFYEAR" => Ok(Expression::Function(Box::new(f))),
792
793            // LAST_DAY is native to MySQL
794            "LAST_DAY" => Ok(Expression::Function(Box::new(f))),
795
796            // TIMESTAMPADD -> DATE_ADD
797            "TIMESTAMPADD" => Ok(Expression::Function(Box::new(Function::new(
798                "DATE_ADD".to_string(),
799                f.args,
800            )))),
801
802            // TIMESTAMPDIFF is native to MySQL
803            "TIMESTAMPDIFF" => Ok(Expression::Function(Box::new(f))),
804
805            // CONVERT_TIMEZONE(from_tz, to_tz, timestamp) -> CONVERT_TZ(timestamp, from_tz, to_tz) in MySQL
806            "CONVERT_TIMEZONE" if f.args.len() == 3 => {
807                let mut args = f.args;
808                let from_tz = args.remove(0);
809                let to_tz = args.remove(0);
810                let timestamp = args.remove(0);
811                Ok(Expression::Function(Box::new(Function::new(
812                    "CONVERT_TZ".to_string(),
813                    vec![timestamp, from_tz, to_tz],
814                ))))
815            }
816
817            // UTC_TIMESTAMP is native to MySQL
818            "UTC_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
819
820            // UTC_TIME is native to MySQL
821            "UTC_TIME" => Ok(Expression::Function(Box::new(f))),
822
823            // MAKETIME is native to MySQL (TimeFromParts)
824            "MAKETIME" => Ok(Expression::Function(Box::new(f))),
825
826            // TIME_FROM_PARTS -> MAKETIME
827            "TIME_FROM_PARTS" if f.args.len() == 3 => Ok(Expression::Function(Box::new(Function::new(
828                "MAKETIME".to_string(),
829                f.args,
830            )))),
831
832            // STUFF -> INSERT in MySQL
833            "STUFF" if f.args.len() == 4 => Ok(Expression::Function(Box::new(Function::new(
834                "INSERT".to_string(),
835                f.args,
836            )))),
837
838            // LOCATE is native to MySQL (reverse of POSITION args)
839            "LOCATE" => Ok(Expression::Function(Box::new(f))),
840
841            // FIND_IN_SET is native to MySQL
842            "FIND_IN_SET" => Ok(Expression::Function(Box::new(f))),
843
844            // FORMAT is native to MySQL (NumberToStr)
845            "FORMAT" => Ok(Expression::Function(Box::new(f))),
846
847            // JSON_EXTRACT is native to MySQL
848            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
849
850            // JSON_UNQUOTE is native to MySQL
851            "JSON_UNQUOTE" => Ok(Expression::Function(Box::new(f))),
852
853            // JSON_EXTRACT_PATH_TEXT -> JSON_UNQUOTE(JSON_EXTRACT(...))
854            "JSON_EXTRACT_PATH_TEXT" if f.args.len() >= 2 => {
855                let extract = Expression::Function(Box::new(Function::new(
856                    "JSON_EXTRACT".to_string(),
857                    f.args,
858                )));
859                Ok(Expression::Function(Box::new(Function::new(
860                    "JSON_UNQUOTE".to_string(),
861                    vec![extract],
862                ))))
863            }
864
865            // GEN_RANDOM_UUID / UUID -> UUID()
866            "GEN_RANDOM_UUID" | "GENERATE_UUID" => Ok(Expression::Function(Box::new(Function::new(
867                "UUID".to_string(),
868                vec![],
869            )))),
870
871            // DATABASE() -> SCHEMA() in MySQL (both return current database name)
872            "DATABASE" => Ok(Expression::Function(Box::new(Function::new(
873                "SCHEMA".to_string(),
874                f.args,
875            )))),
876
877            // INSTR -> LOCATE in MySQL (with swapped arguments)
878            // INSTR(str, substr) -> LOCATE(substr, str)
879            "INSTR" if f.args.len() == 2 => {
880                let mut args = f.args;
881                let str_arg = args.remove(0);
882                let substr_arg = args.remove(0);
883                Ok(Expression::Function(Box::new(Function::new(
884                    "LOCATE".to_string(),
885                    vec![substr_arg, str_arg],
886                ))))
887            }
888
889            // TIME_STR_TO_UNIX -> UNIX_TIMESTAMP in MySQL
890            "TIME_STR_TO_UNIX" => Ok(Expression::Function(Box::new(Function::new(
891                "UNIX_TIMESTAMP".to_string(),
892                f.args,
893            )))),
894
895            // TIME_STR_TO_TIME -> CAST AS DATETIME(N) or TIMESTAMP() in MySQL
896            "TIME_STR_TO_TIME" if f.args.len() >= 1 => {
897                let mut args = f.args.into_iter();
898                let arg = args.next().unwrap();
899
900                // If there's a timezone arg, use TIMESTAMP() function instead
901                if args.next().is_some() {
902                    return Ok(Expression::Function(Box::new(Function::new(
903                        "TIMESTAMP".to_string(),
904                        vec![arg],
905                    ))));
906                }
907
908                // Extract sub-second precision from the string literal
909                let precision = if let Expression::Literal(crate::expressions::Literal::String(ref s)) = arg {
910                    // Find fractional seconds: look for .NNN pattern after HH:MM:SS
911                    if let Some(dot_pos) = s.rfind('.') {
912                        let after_dot = &s[dot_pos + 1..];
913                        // Count digits until non-digit
914                        let frac_digits = after_dot.chars().take_while(|c| c.is_ascii_digit()).count();
915                        if frac_digits > 0 {
916                            // Round up: 1-3 digits → 3, 4-6 digits → 6
917                            if frac_digits <= 3 { Some(3) } else { Some(6) }
918                        } else {
919                            None
920                        }
921                    } else {
922                        None
923                    }
924                } else {
925                    None
926                };
927
928                let type_name = match precision {
929                    Some(p) => format!("DATETIME({})", p),
930                    None => "DATETIME".to_string(),
931                };
932
933                Ok(Expression::Cast(Box::new(Cast {
934                    this: arg,
935                    to: DataType::Custom { name: type_name },
936                    trailing_comments: Vec::new(),
937                    double_colon_syntax: false,
938                    format: None,
939                    default: None,
940                })))
941            }
942
943            // UCASE -> UPPER in MySQL
944            "UCASE" => Ok(Expression::Function(Box::new(Function::new(
945                "UPPER".to_string(),
946                f.args,
947            )))),
948
949            // LCASE -> LOWER in MySQL
950            "LCASE" => Ok(Expression::Function(Box::new(Function::new(
951                "LOWER".to_string(),
952                f.args,
953            )))),
954
955            // DAY_OF_MONTH -> DAYOFMONTH in MySQL
956            "DAY_OF_MONTH" => Ok(Expression::Function(Box::new(Function::new(
957                "DAYOFMONTH".to_string(),
958                f.args,
959            )))),
960
961            // DAY_OF_WEEK -> DAYOFWEEK in MySQL
962            "DAY_OF_WEEK" => Ok(Expression::Function(Box::new(Function::new(
963                "DAYOFWEEK".to_string(),
964                f.args,
965            )))),
966
967            // DAY_OF_YEAR -> DAYOFYEAR in MySQL
968            "DAY_OF_YEAR" => Ok(Expression::Function(Box::new(Function::new(
969                "DAYOFYEAR".to_string(),
970                f.args,
971            )))),
972
973            // WEEK_OF_YEAR -> WEEKOFYEAR in MySQL
974            "WEEK_OF_YEAR" => Ok(Expression::Function(Box::new(Function::new(
975                "WEEKOFYEAR".to_string(),
976                f.args,
977            )))),
978
979            // MOD(x, y) -> x % y in MySQL
980            "MOD" if f.args.len() == 2 => {
981                let mut args = f.args;
982                let left = args.remove(0);
983                let right = args.remove(0);
984                Ok(Expression::Mod(Box::new(BinaryOp {
985                    left,
986                    right,
987                    left_comments: Vec::new(),
988                    operator_comments: Vec::new(),
989                    trailing_comments: Vec::new(),
990                })))
991            }
992
993            // PARSE_JSON -> strip in MySQL (just keep the string argument)
994            "PARSE_JSON" if f.args.len() == 1 => Ok(f.args.into_iter().next().unwrap()),
995
996            // GET_PATH(obj, path) -> JSON_EXTRACT(obj, json_path) in MySQL
997            "GET_PATH" if f.args.len() == 2 => {
998                let mut args = f.args;
999                let this = args.remove(0);
1000                let path = args.remove(0);
1001                let json_path = match &path {
1002                    Expression::Literal(Literal::String(s)) => {
1003                        // Convert bracket notation ["key"] to quoted dot notation ."key"
1004                        let s = Self::convert_bracket_to_quoted_path(s);
1005                        let normalized = if s.starts_with('$') {
1006                            s
1007                        } else if s.starts_with('[') {
1008                            format!("${}", s)
1009                        } else {
1010                            format!("$.{}", s)
1011                        };
1012                        Expression::Literal(Literal::String(normalized))
1013                    }
1014                    _ => path,
1015                };
1016                Ok(Expression::JsonExtract(Box::new(JsonExtractFunc {
1017                    this,
1018                    path: json_path,
1019                    returning: None,
1020                    arrow_syntax: false,
1021                    hash_arrow_syntax: false,
1022                    wrapper_option: None,
1023                    quotes_option: None,
1024                    on_scalar_string: false,
1025                    on_error: None,
1026                })))
1027            }
1028
1029            // REGEXP -> REGEXP_LIKE (MySQL standard form)
1030            "REGEXP" if f.args.len() >= 2 => {
1031                Ok(Expression::Function(Box::new(Function::new("REGEXP_LIKE".to_string(), f.args))))
1032            }
1033
1034            // Pass through everything else
1035            _ => Ok(Expression::Function(Box::new(f))),
1036        }
1037    }
1038
1039    fn transform_aggregate_function(
1040        &self,
1041        f: Box<crate::expressions::AggregateFunction>,
1042    ) -> Result<Expression> {
1043        let name_upper = f.name.to_uppercase();
1044        match name_upper.as_str() {
1045            // STRING_AGG -> GROUP_CONCAT
1046            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
1047                "GROUP_CONCAT".to_string(),
1048                f.args,
1049            )))),
1050
1051            // ARRAY_AGG -> GROUP_CONCAT
1052            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
1053                "GROUP_CONCAT".to_string(),
1054                f.args,
1055            )))),
1056
1057            // Pass through everything else
1058            _ => Ok(Expression::AggregateFunction(f)),
1059        }
1060    }
1061}
1062
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;

    /// Parses `sql` with the generic dialect, emits it as MySQL, and returns the
    /// first generated statement. Panics if transpilation fails.
    fn transpile_to_mysql(sql: &str) -> String {
        let generic = Dialect::get(DialectType::Generic);
        let statements = generic
            .transpile_to(sql, DialectType::MySQL)
            .expect("Transpile failed");
        statements[0].clone()
    }

    /// Asserts that transpiling `sql` to MySQL yields text containing `needle`.
    /// `what` fills in the "Expected {what}, got: {output}" failure message.
    fn assert_mysql_contains(sql: &str, needle: &str, what: &str) {
        let output = transpile_to_mysql(sql);
        assert!(output.contains(needle), "Expected {}, got: {}", what, output);
    }

    /// Round-trips `sql` through the MySQL dialect (parse -> transform -> generate)
    /// and asserts that the generated text equals `expected`.
    fn mysql_identity(sql: &str, expected: &str) {
        let mysql = Dialect::get(DialectType::MySQL);
        let statements = mysql.parse(sql).expect("Parse failed");
        let rewritten = mysql
            .transform(statements[0].clone())
            .expect("Transform failed");
        let generated = mysql.generate(&rewritten).expect("Generate failed");
        assert_eq!(generated, expected, "SQL: {}", sql);
    }

    // ---- Generic -> MySQL transpilation ----

    #[test]
    fn test_nvl_to_ifnull() {
        assert_mysql_contains("SELECT NVL(a, b)", "IFNULL", "IFNULL");
    }

    #[test]
    fn test_coalesce_preserved() {
        // COALESCE is native to MySQL, so it must come through untouched.
        assert_mysql_contains(
            "SELECT COALESCE(a, b)",
            "COALESCE",
            "COALESCE to be preserved",
        );
    }

    #[test]
    fn test_random_to_rand() {
        assert_mysql_contains("SELECT RANDOM()", "RAND", "RAND");
    }

    #[test]
    fn test_basic_select() {
        let result = transpile_to_mysql("SELECT a, b FROM users WHERE id = 1");
        assert!(result.contains("SELECT"));
        assert!(result.contains("FROM users"));
    }

    #[test]
    fn test_string_agg_to_group_concat() {
        assert_mysql_contains("SELECT STRING_AGG(name)", "GROUP_CONCAT", "GROUP_CONCAT");
    }

    #[test]
    fn test_array_agg_to_group_concat() {
        assert_mysql_contains("SELECT ARRAY_AGG(name)", "GROUP_CONCAT", "GROUP_CONCAT");
    }

    #[test]
    fn test_to_date_to_str_to_date() {
        assert_mysql_contains("SELECT TO_DATE('2023-01-01')", "STR_TO_DATE", "STR_TO_DATE");
    }

    // ---- Generator configuration ----

    #[test]
    fn test_backtick_identifiers() {
        // MySQL quotes identifiers with backticks.
        assert_eq!(MySQLDialect.generator_config().identifier_quote, '`');
    }

    // ---- MySQL -> MySQL normalisation ----

    #[test]
    fn test_ucase_to_upper() {
        mysql_identity("SELECT UCASE('foo')", "SELECT UPPER('foo')");
    }

    #[test]
    fn test_lcase_to_lower() {
        mysql_identity("SELECT LCASE('foo')", "SELECT LOWER('foo')");
    }

    #[test]
    fn test_day_of_month() {
        mysql_identity("SELECT DAY_OF_MONTH('2023-01-01')", "SELECT DAYOFMONTH('2023-01-01')");
    }

    #[test]
    fn test_day_of_week() {
        mysql_identity("SELECT DAY_OF_WEEK('2023-01-01')", "SELECT DAYOFWEEK('2023-01-01')");
    }

    #[test]
    fn test_day_of_year() {
        mysql_identity("SELECT DAY_OF_YEAR('2023-01-01')", "SELECT DAYOFYEAR('2023-01-01')");
    }

    #[test]
    fn test_week_of_year() {
        mysql_identity("SELECT WEEK_OF_YEAR('2023-01-01')", "SELECT WEEKOFYEAR('2023-01-01')");
    }

    #[test]
    fn test_mod_func_to_percent() {
        // The MOD(x, y) call form is rewritten to the `%` operator.
        mysql_identity("MOD(x, y)", "x % y");
    }

    #[test]
    fn test_database_to_schema() {
        mysql_identity("DATABASE()", "SCHEMA()");
    }

    #[test]
    fn test_and_operator() {
        mysql_identity("SELECT 1 && 0", "SELECT 1 AND 0");
    }

    #[test]
    fn test_or_operator() {
        mysql_identity("SELECT a || b", "SELECT a OR b");
    }
}