Skip to main content

polyglot_sql/dialects/
mysql.rs

//! MySQL Dialect
//!
//! MySQL-specific transformations based on sqlglot patterns.
//! Key differences from standard SQL:
//! - `||` is the OR operator, not string concatenation (use `CONCAT`)
//! - Uses backticks for identifiers
//! - No `TRY_CAST`, no `ILIKE`
//! - Different date/time function names
9
10use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{
13    BinaryFunc, BinaryOp, Cast, DataType, Expression, Function, JsonExtractFunc, LikeOp, Literal,
14    Paren, UnaryFunc,
15};
16use crate::generator::GeneratorConfig;
17use crate::tokens::TokenizerConfig;
18
19/// Helper to wrap JSON arrow expressions in parentheses when they appear
20/// in contexts that require it (Binary, In, Not expressions)
21/// This matches Python sqlglot's WRAPPED_JSON_EXTRACT_EXPRESSIONS behavior
22fn wrap_if_json_arrow(expr: Expression) -> Expression {
23    match &expr {
24        Expression::JsonExtract(f) if f.arrow_syntax => Expression::Paren(Box::new(Paren {
25            this: expr,
26            trailing_comments: Vec::new(),
27        })),
28        Expression::JsonExtractScalar(f) if f.arrow_syntax => Expression::Paren(Box::new(Paren {
29            this: expr,
30            trailing_comments: Vec::new(),
31        })),
32        _ => expr,
33    }
34}
35
36/// Convert JSON arrow expression (-> or ->>) to JSON_EXTRACT function form
37/// This is needed for contexts like MEMBER OF where arrow syntax must become function form
38fn json_arrow_to_function(expr: Expression) -> Expression {
39    match expr {
40        Expression::JsonExtract(f) if f.arrow_syntax => Expression::Function(Box::new(
41            Function::new("JSON_EXTRACT".to_string(), vec![f.this, f.path]),
42        )),
43        Expression::JsonExtractScalar(f) if f.arrow_syntax => {
44            // ->> becomes JSON_UNQUOTE(JSON_EXTRACT(...)) but can be simplified to JSON_EXTRACT_SCALAR
45            // For MySQL, use JSON_UNQUOTE(JSON_EXTRACT(...))
46            let json_extract = Expression::Function(Box::new(Function::new(
47                "JSON_EXTRACT".to_string(),
48                vec![f.this, f.path],
49            )));
50            Expression::Function(Box::new(Function::new(
51                "JSON_UNQUOTE".to_string(),
52                vec![json_extract],
53            )))
54        }
55        other => other,
56    }
57}
58
/// Marker type for the MySQL dialect.
///
/// It carries no state; the MySQL tokenizer settings, generator settings and
/// expression rewrites are provided through its `DialectImpl` implementation.
pub struct MySQLDialect;
61
62impl DialectImpl for MySQLDialect {
63    fn dialect_type(&self) -> DialectType {
64        DialectType::MySQL
65    }
66
67    fn tokenizer_config(&self) -> TokenizerConfig {
68        use crate::tokens::TokenType;
69        let mut config = TokenizerConfig::default();
70        // MySQL uses backticks for identifiers
71        config.identifiers.insert('`', '`');
72        // Remove double quotes from identifiers - in MySQL they are string delimiters
73        // (unless ANSI_QUOTES mode is set, but default mode uses them as strings)
74        config.identifiers.remove(&'"');
75        // MySQL supports double quotes as string literals by default
76        config.quotes.insert("\"".to_string(), "\"".to_string());
77        // MySQL supports backslash escapes in strings
78        config.string_escapes.push('\\');
79        // MySQL has XOR as a logical operator keyword
80        config.keywords.insert("XOR".to_string(), TokenType::Xor);
81        // MySQL: backslash followed by chars NOT in this list -> discard backslash
82        // See: https://dev.mysql.com/doc/refman/8.4/en/string-literals.html
83        config.escape_follow_chars = vec!['0', 'b', 'n', 'r', 't', 'Z', '%', '_'];
84        // MySQL allows identifiers to start with digits (e.g., 1a, 1_a)
85        config.identifiers_can_start_with_digit = true;
86        config
87    }
88
89    fn generator_config(&self) -> GeneratorConfig {
90        use crate::generator::IdentifierQuoteStyle;
91        GeneratorConfig {
92            identifier_quote: '`',
93            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
94            dialect: Some(DialectType::MySQL),
95            // MySQL doesn't support null ordering in most contexts
96            null_ordering_supported: false,
97            // MySQL LIMIT only
98            limit_only_literals: true,
99            // MySQL doesn't support semi/anti join
100            semi_anti_join_with_side: false,
101            // MySQL doesn't support table alias columns in some contexts
102            supports_table_alias_columns: false,
103            // MySQL VALUES not used as table
104            values_as_table: false,
105            // MySQL doesn't support TABLESAMPLE
106            tablesample_requires_parens: false,
107            tablesample_with_method: false,
108            // MySQL doesn't support aggregate FILTER
109            aggregate_filter_supported: false,
110            // MySQL doesn't support TRY
111            try_supported: false,
112            // MySQL doesn't support CONVERT_TIMEZONE
113            supports_convert_timezone: false,
114            // MySQL doesn't support UESCAPE
115            supports_uescape: false,
116            // MySQL doesn't support BETWEEN flags
117            supports_between_flags: false,
118            // MySQL supports EXPLAIN but not query hints in standard way
119            query_hints: false,
120            // MySQL parameter token
121            parameter_token: "?",
122            // MySQL doesn't support window EXCLUDE
123            supports_window_exclude: false,
124            // MySQL doesn't support exploding projections
125            supports_exploding_projections: false,
126            identifiers_can_start_with_digit: true,
127            // MySQL supports FOR UPDATE/SHARE
128            locking_reads_supported: true,
129            ..Default::default()
130        }
131    }
132
133    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
134        match expr {
135            // ===== Data Type Mappings =====
136            Expression::DataType(dt) => self.transform_data_type(dt),
137
138            // NVL -> IFNULL in MySQL
139            Expression::Nvl(f) => Ok(Expression::IfNull(f)),
140
141            // Note: COALESCE is valid in MySQL and should be preserved.
142            // Unlike some other dialects, we do NOT convert COALESCE to IFNULL
143            // as this would break identity tests.
144
145            // TryCast -> CAST or TIMESTAMP() (MySQL doesn't support TRY_CAST)
146            Expression::TryCast(c) => self.transform_cast(*c),
147
148            // SafeCast -> CAST or TIMESTAMP() (MySQL doesn't support safe casts)
149            Expression::SafeCast(c) => self.transform_cast(*c),
150
151            // Cast -> Transform cast type according to MySQL restrictions
152            // CAST AS TIMESTAMP -> TIMESTAMP() function in MySQL
153            Expression::Cast(c) => self.transform_cast(*c),
154
155            // ILIKE -> LOWER() LIKE LOWER() in MySQL
156            Expression::ILike(op) => {
157                // Transform ILIKE to: LOWER(left) LIKE LOWER(right)
158                let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left)));
159                let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right)));
160                Ok(Expression::Like(Box::new(LikeOp {
161                    left: lower_left,
162                    right: lower_right,
163                    escape: op.escape,
164                    quantifier: op.quantifier,
165                })))
166            }
167
168            // Preserve semantic string concatenation expressions.
169            // MySQL generation renders these as CONCAT(...).
170            Expression::Concat(op) => Ok(Expression::Concat(op)),
171
172            // RANDOM -> RAND in MySQL
173            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
174                seed: None,
175                lower: None,
176                upper: None,
177            }))),
178
179            // ArrayAgg -> GROUP_CONCAT in MySQL
180            Expression::ArrayAgg(f) => Ok(Expression::Function(Box::new(Function::new(
181                "GROUP_CONCAT".to_string(),
182                vec![f.this],
183            )))),
184
185            // StringAgg -> GROUP_CONCAT in MySQL
186            Expression::StringAgg(f) => {
187                let mut args = vec![f.this.clone()];
188                if let Some(separator) = &f.separator {
189                    args.push(separator.clone());
190                }
191                Ok(Expression::Function(Box::new(Function::new(
192                    "GROUP_CONCAT".to_string(),
193                    args,
194                ))))
195            }
196
197            // UNNEST -> Not directly supported in MySQL, use JSON_TABLE or inline
198            // For basic cases, pass through (may need manual handling)
199            Expression::Unnest(f) => {
200                // MySQL 8.0+ has JSON_TABLE which can be used for unnesting
201                // For now, pass through with a function call
202                Ok(Expression::Function(Box::new(Function::new(
203                    "JSON_TABLE".to_string(),
204                    vec![f.this],
205                ))))
206            }
207
208            // Substring: Use comma syntax (not FROM/FOR) in MySQL
209            Expression::Substring(mut f) => {
210                f.from_for_syntax = false;
211                Ok(Expression::Substring(f))
212            }
213
214            // ===== Bitwise operations =====
215            // BitwiseAndAgg -> BIT_AND
216            Expression::BitwiseAndAgg(f) => Ok(Expression::Function(Box::new(Function::new(
217                "BIT_AND".to_string(),
218                vec![f.this],
219            )))),
220
221            // BitwiseOrAgg -> BIT_OR
222            Expression::BitwiseOrAgg(f) => Ok(Expression::Function(Box::new(Function::new(
223                "BIT_OR".to_string(),
224                vec![f.this],
225            )))),
226
227            // BitwiseXorAgg -> BIT_XOR
228            Expression::BitwiseXorAgg(f) => Ok(Expression::Function(Box::new(Function::new(
229                "BIT_XOR".to_string(),
230                vec![f.this],
231            )))),
232
233            // BitwiseCount -> BIT_COUNT
234            Expression::BitwiseCount(f) => Ok(Expression::Function(Box::new(Function::new(
235                "BIT_COUNT".to_string(),
236                vec![f.this],
237            )))),
238
239            // TimeFromParts -> MAKETIME
240            Expression::TimeFromParts(f) => {
241                let mut args = Vec::new();
242                if let Some(h) = f.hour {
243                    args.push(*h);
244                }
245                if let Some(m) = f.min {
246                    args.push(*m);
247                }
248                if let Some(s) = f.sec {
249                    args.push(*s);
250                }
251                Ok(Expression::Function(Box::new(Function::new(
252                    "MAKETIME".to_string(),
253                    args,
254                ))))
255            }
256
257            // ===== Boolean aggregates =====
258            // In MySQL, there's no BOOL_AND/BOOL_OR, use MIN/MAX on boolean values
259            // LogicalAnd -> MIN (0 is false, non-0 is true)
260            Expression::LogicalAnd(f) => Ok(Expression::Function(Box::new(Function::new(
261                "MIN".to_string(),
262                vec![f.this],
263            )))),
264
265            // LogicalOr -> MAX
266            Expression::LogicalOr(f) => Ok(Expression::Function(Box::new(Function::new(
267                "MAX".to_string(),
268                vec![f.this],
269            )))),
270
271            // ===== Date/time functions =====
272            // DayOfMonth -> DAYOFMONTH
273            Expression::DayOfMonth(f) => Ok(Expression::Function(Box::new(Function::new(
274                "DAYOFMONTH".to_string(),
275                vec![f.this],
276            )))),
277
278            // DayOfWeek -> DAYOFWEEK
279            Expression::DayOfWeek(f) => Ok(Expression::Function(Box::new(Function::new(
280                "DAYOFWEEK".to_string(),
281                vec![f.this],
282            )))),
283
284            // DayOfYear -> DAYOFYEAR
285            Expression::DayOfYear(f) => Ok(Expression::Function(Box::new(Function::new(
286                "DAYOFYEAR".to_string(),
287                vec![f.this],
288            )))),
289
290            // WeekOfYear -> WEEKOFYEAR
291            Expression::WeekOfYear(f) => Ok(Expression::Function(Box::new(Function::new(
292                "WEEKOFYEAR".to_string(),
293                vec![f.this],
294            )))),
295
296            // DateDiff -> DATEDIFF
297            Expression::DateDiff(f) => Ok(Expression::Function(Box::new(Function::new(
298                "DATEDIFF".to_string(),
299                vec![f.this, f.expression],
300            )))),
301
302            // TimeStrToUnix -> UNIX_TIMESTAMP
303            Expression::TimeStrToUnix(f) => Ok(Expression::Function(Box::new(Function::new(
304                "UNIX_TIMESTAMP".to_string(),
305                vec![f.this],
306            )))),
307
308            // TimestampDiff -> TIMESTAMPDIFF
309            Expression::TimestampDiff(f) => Ok(Expression::Function(Box::new(Function::new(
310                "TIMESTAMPDIFF".to_string(),
311                vec![*f.this, *f.expression],
312            )))),
313
314            // ===== String functions =====
315            // StrPosition -> LOCATE in MySQL
316            // STRPOS(str, substr) -> LOCATE(substr, str) (args are swapped)
317            Expression::StrPosition(f) => {
318                let mut args = vec![];
319                if let Some(substr) = f.substr {
320                    args.push(*substr);
321                }
322                args.push(*f.this);
323                if let Some(pos) = f.position {
324                    args.push(*pos);
325                }
326                Ok(Expression::Function(Box::new(Function::new(
327                    "LOCATE".to_string(),
328                    args,
329                ))))
330            }
331
332            // Stuff -> INSERT in MySQL
333            Expression::Stuff(f) => {
334                let mut args = vec![*f.this];
335                if let Some(start) = f.start {
336                    args.push(*start);
337                }
338                if let Some(length) = f.length {
339                    args.push(Expression::number(length));
340                }
341                args.push(*f.expression);
342                Ok(Expression::Function(Box::new(Function::new(
343                    "INSERT".to_string(),
344                    args,
345                ))))
346            }
347
348            // ===== Session/User functions =====
349            // SessionUser -> SESSION_USER()
350            Expression::SessionUser(_) => Ok(Expression::Function(Box::new(Function::new(
351                "SESSION_USER".to_string(),
352                vec![],
353            )))),
354
355            // CurrentDate -> CURRENT_DATE (no parentheses in MySQL) - keep as CurrentDate
356            Expression::CurrentDate(_) => {
357                Ok(Expression::CurrentDate(crate::expressions::CurrentDate))
358            }
359
360            // ===== Null-safe comparison =====
361            // NullSafeNeq -> NOT (a <=> b) in MySQL
362            Expression::NullSafeNeq(op) => {
363                // Create: NOT (left <=> right)
364                let null_safe_eq = Expression::NullSafeEq(Box::new(crate::expressions::BinaryOp {
365                    left: op.left,
366                    right: op.right,
367                    left_comments: Vec::new(),
368                    operator_comments: Vec::new(),
369                    trailing_comments: Vec::new(),
370                }));
371                Ok(Expression::Not(Box::new(crate::expressions::UnaryOp {
372                    this: null_safe_eq,
373                })))
374            }
375
376            // ParseJson: handled by generator (emits just the string literal for MySQL)
377
378            // JSONExtract with variant_extract (Snowflake colon syntax) -> JSON_EXTRACT
379            Expression::JSONExtract(e) if e.variant_extract.is_some() => {
380                let path = match *e.expression {
381                    Expression::Literal(Literal::String(s)) => {
382                        // Convert bracket notation ["key"] to quoted dot notation ."key"
383                        let s = Self::convert_bracket_to_quoted_path(&s);
384                        let normalized = if s.starts_with('$') {
385                            s
386                        } else if s.starts_with('[') {
387                            format!("${}", s)
388                        } else {
389                            format!("$.{}", s)
390                        };
391                        Expression::Literal(Literal::String(normalized))
392                    }
393                    other => other,
394                };
395                Ok(Expression::Function(Box::new(Function::new(
396                    "JSON_EXTRACT".to_string(),
397                    vec![*e.this, path],
398                ))))
399            }
400
401            // Generic function transformations
402            Expression::Function(f) => self.transform_function(*f),
403
404            // Generic aggregate function transformations
405            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
406
407            // ===== Context-aware JSON arrow wrapping =====
408            // When JSON arrow expressions appear in Binary/In/Not contexts,
409            // they need to be wrapped in parentheses for correct precedence.
410            // This matches Python sqlglot's WRAPPED_JSON_EXTRACT_EXPRESSIONS behavior.
411
412            // Binary operators that need JSON wrapping
413            Expression::Eq(op) => Ok(Expression::Eq(Box::new(BinaryOp {
414                left: wrap_if_json_arrow(op.left),
415                right: wrap_if_json_arrow(op.right),
416                ..*op
417            }))),
418            Expression::Neq(op) => Ok(Expression::Neq(Box::new(BinaryOp {
419                left: wrap_if_json_arrow(op.left),
420                right: wrap_if_json_arrow(op.right),
421                ..*op
422            }))),
423            Expression::Lt(op) => Ok(Expression::Lt(Box::new(BinaryOp {
424                left: wrap_if_json_arrow(op.left),
425                right: wrap_if_json_arrow(op.right),
426                ..*op
427            }))),
428            Expression::Lte(op) => Ok(Expression::Lte(Box::new(BinaryOp {
429                left: wrap_if_json_arrow(op.left),
430                right: wrap_if_json_arrow(op.right),
431                ..*op
432            }))),
433            Expression::Gt(op) => Ok(Expression::Gt(Box::new(BinaryOp {
434                left: wrap_if_json_arrow(op.left),
435                right: wrap_if_json_arrow(op.right),
436                ..*op
437            }))),
438            Expression::Gte(op) => Ok(Expression::Gte(Box::new(BinaryOp {
439                left: wrap_if_json_arrow(op.left),
440                right: wrap_if_json_arrow(op.right),
441                ..*op
442            }))),
443
444            // In expression - wrap the this part if it's JSON arrow
445            Expression::In(mut i) => {
446                i.this = wrap_if_json_arrow(i.this);
447                Ok(Expression::In(i))
448            }
449
450            // Not expression - wrap the this part if it's JSON arrow
451            Expression::Not(mut n) => {
452                n.this = wrap_if_json_arrow(n.this);
453                Ok(Expression::Not(n))
454            }
455
456            // && in MySQL is logical AND, not array overlaps
457            // Transform ArrayOverlaps -> And for MySQL identity
458            Expression::ArrayOverlaps(op) => Ok(Expression::And(op)),
459
460            // MOD(x, y) -> x % y in MySQL
461            Expression::ModFunc(f) => Ok(Expression::Mod(Box::new(BinaryOp {
462                left: f.this,
463                right: f.expression,
464                left_comments: Vec::new(),
465                operator_comments: Vec::new(),
466                trailing_comments: Vec::new(),
467            }))),
468
469            // SHOW SLAVE STATUS -> SHOW REPLICA STATUS
470            Expression::Show(mut s) => {
471                if s.this == "SLAVE STATUS" {
472                    s.this = "REPLICA STATUS".to_string();
473                }
474                if matches!(s.this.as_str(), "INDEX" | "COLUMNS") && s.db.is_none() {
475                    if let Some(Expression::Table(mut t)) = s.target.take() {
476                        if let Some(db_ident) = t.schema.take().or(t.catalog.take()) {
477                            s.db = Some(Expression::Identifier(db_ident));
478                            s.target = Some(Expression::Identifier(t.name));
479                        } else {
480                            s.target = Some(Expression::Table(t));
481                        }
482                    }
483                }
484                Ok(Expression::Show(s))
485            }
486
487            // AT TIME ZONE -> strip timezone (MySQL doesn't support AT TIME ZONE)
488            // But keep it for CURRENT_DATE/CURRENT_TIMESTAMP with timezone (transpiled from BigQuery)
489            Expression::AtTimeZone(atz) => {
490                let is_current = match &atz.this {
491                    Expression::CurrentDate(_) | Expression::CurrentTimestamp(_) => true,
492                    Expression::Function(f) => {
493                        let n = f.name.to_uppercase();
494                        (n == "CURRENT_DATE" || n == "CURRENT_TIMESTAMP") && f.no_parens
495                    }
496                    _ => false,
497                };
498                if is_current {
499                    Ok(Expression::AtTimeZone(atz)) // Keep AT TIME ZONE for CURRENT_DATE/CURRENT_TIMESTAMP
500                } else {
501                    Ok(atz.this) // Strip timezone for other expressions
502                }
503            }
504
505            // MEMBER OF with JSON arrow -> convert arrow to JSON_EXTRACT function
506            // MySQL's MEMBER OF requires JSON_EXTRACT function form, not arrow syntax
507            Expression::MemberOf(mut op) => {
508                op.right = json_arrow_to_function(op.right);
509                Ok(Expression::MemberOf(op))
510            }
511
512            // Pass through everything else
513            _ => Ok(expr),
514        }
515    }
516}
517
518impl MySQLDialect {
519    fn normalize_mysql_date_format(fmt: &str) -> String {
520        fmt.replace("%H:%i:%s", "%T").replace("%H:%i:%S", "%T")
521    }
522
523    /// Convert bracket notation ["key with spaces"] to quoted dot notation ."key with spaces"
524    /// in JSON path strings.
525    fn convert_bracket_to_quoted_path(path: &str) -> String {
526        let mut result = String::new();
527        let mut chars = path.chars().peekable();
528        while let Some(c) = chars.next() {
529            if c == '[' && chars.peek() == Some(&'"') {
530                chars.next(); // consume "
531                let mut key = String::new();
532                while let Some(kc) = chars.next() {
533                    if kc == '"' && chars.peek() == Some(&']') {
534                        chars.next(); // consume ]
535                        break;
536                    }
537                    key.push(kc);
538                }
539                if !result.is_empty() && !result.ends_with('.') {
540                    result.push('.');
541                }
542                result.push('"');
543                result.push_str(&key);
544                result.push('"');
545            } else {
546                result.push(c);
547            }
548        }
549        result
550    }
551
552    /// Transform data types according to MySQL TYPE_MAPPING
553    /// Note: MySQL's TIMESTAMP is kept as TIMESTAMP (not converted to DATETIME)
554    /// because MySQL's TIMESTAMP has timezone awareness built-in
555    fn transform_data_type(&self, dt: crate::expressions::DataType) -> Result<Expression> {
556        use crate::expressions::DataType;
557        let transformed = match dt {
558            // All TIMESTAMP variants (with or without timezone) -> TIMESTAMP in MySQL
559            DataType::Timestamp {
560                precision,
561                timezone: _,
562            } => DataType::Timestamp {
563                precision,
564                timezone: false,
565            },
566            // TIMESTAMPTZ / TIMESTAMPLTZ parsed as Custom -> normalize to TIMESTAMP
567            DataType::Custom { name }
568                if name.to_uppercase() == "TIMESTAMPTZ"
569                    || name.to_uppercase() == "TIMESTAMPLTZ" =>
570            {
571                DataType::Timestamp {
572                    precision: None,
573                    timezone: false,
574                }
575            }
576            // Keep native MySQL types as-is
577            // MySQL supports TEXT, MEDIUMTEXT, LONGTEXT, BLOB, etc. natively
578            other => other,
579        };
580        Ok(Expression::DataType(transformed))
581    }
582
583    /// Transform CAST expression
584    /// MySQL uses TIMESTAMP() function instead of CAST(x AS TIMESTAMP)
585    /// For Generic->MySQL, TIMESTAMP (no tz) is pre-converted to DATETIME in cross_dialect_normalize
586    fn transform_cast(&self, cast: Cast) -> Result<Expression> {
587        // CAST AS TIMESTAMP/TIMESTAMPTZ/TIMESTAMPLTZ -> TIMESTAMP() function
588        match &cast.to {
589            DataType::Timestamp { .. } => Ok(Expression::Function(Box::new(Function::new(
590                "TIMESTAMP".to_string(),
591                vec![cast.this],
592            )))),
593            DataType::Custom { name }
594                if name.to_uppercase() == "TIMESTAMPTZ"
595                    || name.to_uppercase() == "TIMESTAMPLTZ" =>
596            {
597                Ok(Expression::Function(Box::new(Function::new(
598                    "TIMESTAMP".to_string(),
599                    vec![cast.this],
600                ))))
601            }
602            // All other casts go through normal type transformation
603            _ => Ok(Expression::Cast(Box::new(self.transform_cast_type(cast)))),
604        }
605    }
606
607    /// Transform CAST type according to MySQL restrictions
608    /// MySQL doesn't support many types in CAST - they get mapped to CHAR or SIGNED
609    /// Based on Python sqlglot's CHAR_CAST_MAPPING and SIGNED_CAST_MAPPING
610    fn transform_cast_type(&self, cast: Cast) -> Cast {
611        let new_type = match &cast.to {
612            // CHAR_CAST_MAPPING: These types become CHAR in MySQL CAST, preserving length
613            DataType::VarChar { length, .. } => DataType::Char { length: *length },
614            DataType::Text => DataType::Char { length: None },
615
616            // SIGNED_CAST_MAPPING: These integer types become SIGNED in MySQL CAST
617            DataType::BigInt { .. } => DataType::Custom {
618                name: "SIGNED".to_string(),
619            },
620            DataType::Int { .. } => DataType::Custom {
621                name: "SIGNED".to_string(),
622            },
623            DataType::SmallInt { .. } => DataType::Custom {
624                name: "SIGNED".to_string(),
625            },
626            DataType::TinyInt { .. } => DataType::Custom {
627                name: "SIGNED".to_string(),
628            },
629            DataType::Boolean => DataType::Custom {
630                name: "SIGNED".to_string(),
631            },
632
633            // Custom types that need mapping
634            DataType::Custom { name } => {
635                let upper = name.to_uppercase();
636                match upper.as_str() {
637                    // Text/Blob types -> keep as Custom for cross-dialect mapping
638                    // MySQL generator will output CHAR for these in CAST context
639                    "LONGTEXT" | "MEDIUMTEXT" | "TINYTEXT" | "LONGBLOB" | "MEDIUMBLOB"
640                    | "TINYBLOB" => DataType::Custom { name: upper },
641                    // MEDIUMINT -> SIGNED in MySQL CAST
642                    "MEDIUMINT" => DataType::Custom {
643                        name: "SIGNED".to_string(),
644                    },
645                    // Unsigned integer types -> UNSIGNED
646                    "UBIGINT" | "UINT" | "USMALLINT" | "UTINYINT" | "UMEDIUMINT" => {
647                        DataType::Custom {
648                            name: "UNSIGNED".to_string(),
649                        }
650                    }
651                    // Keep other custom types
652                    _ => cast.to.clone(),
653                }
654            }
655
656            // Types that are valid in MySQL CAST - pass through
657            DataType::Binary { .. } => cast.to.clone(),
658            DataType::VarBinary { .. } => cast.to.clone(),
659            DataType::Date => cast.to.clone(),
660            DataType::Time { .. } => cast.to.clone(),
661            DataType::Decimal { .. } => cast.to.clone(),
662            DataType::Json => cast.to.clone(),
663            DataType::Float { .. } => cast.to.clone(),
664            DataType::Double { .. } => cast.to.clone(),
665            DataType::Char { .. } => cast.to.clone(),
666            DataType::CharacterSet { .. } => cast.to.clone(),
667            DataType::Enum { .. } => cast.to.clone(),
668            DataType::Set { .. } => cast.to.clone(),
669            DataType::Timestamp { .. } => cast.to.clone(),
670
671            // All other unsupported types -> CHAR
672            _ => DataType::Char { length: None },
673        };
674
675        Cast {
676            this: cast.this,
677            to: new_type,
678            trailing_comments: cast.trailing_comments,
679            double_colon_syntax: cast.double_colon_syntax,
680            format: cast.format,
681            default: cast.default,
682        }
683    }
684
685    fn transform_function(&self, f: Function) -> Result<Expression> {
686        let name_upper = f.name.to_uppercase();
687        match name_upper.as_str() {
688            // Normalize DATE_FORMAT short-hands to canonical MySQL forms.
689            "DATE_FORMAT" if f.args.len() >= 2 => {
690                let mut f = f;
691                if let Some(Expression::Literal(Literal::String(fmt))) = f.args.get(1) {
692                    let normalized = Self::normalize_mysql_date_format(fmt);
693                    if normalized != *fmt {
694                        f.args[1] = Expression::Literal(Literal::String(normalized));
695                    }
696                }
697                Ok(Expression::Function(Box::new(f)))
698            }
699
700            // NVL -> IFNULL
701            "NVL" if f.args.len() == 2 => {
702                let mut args = f.args;
703                let second = args.pop().unwrap();
704                let first = args.pop().unwrap();
705                Ok(Expression::IfNull(Box::new(BinaryFunc {
706                    original_name: None,
707                    this: first,
708                    expression: second,
709                })))
710            }
711
712            // Note: COALESCE is native to MySQL. We do NOT convert it to IFNULL
713            // because this would break identity tests (Python SQLGlot preserves COALESCE).
714
715            // ARRAY_AGG -> GROUP_CONCAT
716            "ARRAY_AGG" if f.args.len() == 1 => {
717                let mut args = f.args;
718                Ok(Expression::Function(Box::new(Function::new(
719                    "GROUP_CONCAT".to_string(),
720                    vec![args.pop().unwrap()],
721                ))))
722            }
723
724            // STRING_AGG -> GROUP_CONCAT
725            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
726                Function::new("GROUP_CONCAT".to_string(), f.args),
727            ))),
728
729            // RANDOM -> RAND
730            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
731                seed: None,
732                lower: None,
733                upper: None,
734            }))),
735
736            // CURRENT_TIMESTAMP -> NOW() or CURRENT_TIMESTAMP (both work)
737            // Preserve precision if specified: CURRENT_TIMESTAMP(6)
738            "CURRENT_TIMESTAMP" => {
739                let precision =
740                    if let Some(Expression::Literal(crate::expressions::Literal::Number(n))) =
741                        f.args.first()
742                    {
743                        n.parse::<u32>().ok()
744                    } else {
745                        None
746                    };
747                Ok(Expression::CurrentTimestamp(
748                    crate::expressions::CurrentTimestamp {
749                        precision,
750                        sysdate: false,
751                    },
752                ))
753            }
754
755            // POSITION -> LOCATE in MySQL (argument order is different)
756            // POSITION(substr IN str) -> LOCATE(substr, str)
757            "POSITION" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
758                "LOCATE".to_string(),
759                f.args,
760            )))),
761
762            // LENGTH is native to MySQL (returns bytes, not characters)
763            // CHAR_LENGTH for character count
764            "LENGTH" => Ok(Expression::Function(Box::new(f))),
765
766            // CEIL -> CEILING in MySQL (both work)
767            "CEIL" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
768                "CEILING".to_string(),
769                f.args,
770            )))),
771
772            // STDDEV -> STD or STDDEV_POP in MySQL
773            "STDDEV" => Ok(Expression::Function(Box::new(Function::new(
774                "STD".to_string(),
775                f.args,
776            )))),
777
778            // STDDEV_SAMP -> STDDEV in MySQL
779            "STDDEV_SAMP" => Ok(Expression::Function(Box::new(Function::new(
780                "STDDEV".to_string(),
781                f.args,
782            )))),
783
784            // TO_DATE -> STR_TO_DATE in MySQL
785            "TO_DATE" => Ok(Expression::Function(Box::new(Function::new(
786                "STR_TO_DATE".to_string(),
787                f.args,
788            )))),
789
790            // TO_TIMESTAMP -> STR_TO_DATE in MySQL
791            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(Function::new(
792                "STR_TO_DATE".to_string(),
793                f.args,
794            )))),
795
796            // DATE_TRUNC -> Complex transformation
797            // Typically uses DATE() or DATE_FORMAT() depending on unit
798            "DATE_TRUNC" if f.args.len() >= 2 => {
799                // Simplified: DATE_TRUNC('day', x) -> DATE(x)
800                // Full implementation would handle different units
801                let mut args = f.args;
802                let _unit = args.remove(0);
803                let date = args.remove(0);
804                Ok(Expression::Function(Box::new(Function::new(
805                    "DATE".to_string(),
806                    vec![date],
807                ))))
808            }
809
810            // EXTRACT is native but syntax varies
811
812            // COALESCE is native to MySQL (keep as-is for more than 2 args)
813            "COALESCE" if f.args.len() > 2 => Ok(Expression::Function(Box::new(f))),
814
815            // DAYOFMONTH -> DAY (both work)
816            "DAY" => Ok(Expression::Function(Box::new(Function::new(
817                "DAYOFMONTH".to_string(),
818                f.args,
819            )))),
820
821            // DAYOFWEEK is native to MySQL
822            "DAYOFWEEK" => Ok(Expression::Function(Box::new(f))),
823
824            // DAYOFYEAR is native to MySQL
825            "DAYOFYEAR" => Ok(Expression::Function(Box::new(f))),
826
827            // WEEKOFYEAR is native to MySQL
828            "WEEKOFYEAR" => Ok(Expression::Function(Box::new(f))),
829
830            // LAST_DAY is native to MySQL
831            "LAST_DAY" => Ok(Expression::Function(Box::new(f))),
832
833            // TIMESTAMPADD -> DATE_ADD
834            "TIMESTAMPADD" => Ok(Expression::Function(Box::new(Function::new(
835                "DATE_ADD".to_string(),
836                f.args,
837            )))),
838
839            // TIMESTAMPDIFF is native to MySQL
840            "TIMESTAMPDIFF" => Ok(Expression::Function(Box::new(f))),
841
842            // CONVERT_TIMEZONE(from_tz, to_tz, timestamp) -> CONVERT_TZ(timestamp, from_tz, to_tz) in MySQL
843            "CONVERT_TIMEZONE" if f.args.len() == 3 => {
844                let mut args = f.args;
845                let from_tz = args.remove(0);
846                let to_tz = args.remove(0);
847                let timestamp = args.remove(0);
848                Ok(Expression::Function(Box::new(Function::new(
849                    "CONVERT_TZ".to_string(),
850                    vec![timestamp, from_tz, to_tz],
851                ))))
852            }
853
854            // UTC_TIMESTAMP is native to MySQL
855            "UTC_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
856
857            // UTC_TIME is native to MySQL
858            "UTC_TIME" => Ok(Expression::Function(Box::new(f))),
859
860            // MAKETIME is native to MySQL (TimeFromParts)
861            "MAKETIME" => Ok(Expression::Function(Box::new(f))),
862
863            // TIME_FROM_PARTS -> MAKETIME
864            "TIME_FROM_PARTS" if f.args.len() == 3 => Ok(Expression::Function(Box::new(
865                Function::new("MAKETIME".to_string(), f.args),
866            ))),
867
868            // STUFF -> INSERT in MySQL
869            "STUFF" if f.args.len() == 4 => Ok(Expression::Function(Box::new(Function::new(
870                "INSERT".to_string(),
871                f.args,
872            )))),
873
874            // LOCATE is native to MySQL (reverse of POSITION args)
875            "LOCATE" => Ok(Expression::Function(Box::new(f))),
876
877            // FIND_IN_SET is native to MySQL
878            "FIND_IN_SET" => Ok(Expression::Function(Box::new(f))),
879
880            // FORMAT is native to MySQL (NumberToStr)
881            "FORMAT" => Ok(Expression::Function(Box::new(f))),
882
883            // JSON_EXTRACT is native to MySQL
884            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
885
886            // JSON_UNQUOTE is native to MySQL
887            "JSON_UNQUOTE" => Ok(Expression::Function(Box::new(f))),
888
889            // JSON_EXTRACT_PATH_TEXT -> JSON_UNQUOTE(JSON_EXTRACT(...))
890            "JSON_EXTRACT_PATH_TEXT" if f.args.len() >= 2 => {
891                let extract = Expression::Function(Box::new(Function::new(
892                    "JSON_EXTRACT".to_string(),
893                    f.args,
894                )));
895                Ok(Expression::Function(Box::new(Function::new(
896                    "JSON_UNQUOTE".to_string(),
897                    vec![extract],
898                ))))
899            }
900
901            // GEN_RANDOM_UUID / UUID -> UUID()
902            "GEN_RANDOM_UUID" | "GENERATE_UUID" => Ok(Expression::Function(Box::new(
903                Function::new("UUID".to_string(), vec![]),
904            ))),
905
906            // DATABASE() -> SCHEMA() in MySQL (both return current database name)
907            "DATABASE" => Ok(Expression::Function(Box::new(Function::new(
908                "SCHEMA".to_string(),
909                f.args,
910            )))),
911
912            // INSTR -> LOCATE in MySQL (with swapped arguments)
913            // INSTR(str, substr) -> LOCATE(substr, str)
914            "INSTR" if f.args.len() == 2 => {
915                let mut args = f.args;
916                let str_arg = args.remove(0);
917                let substr_arg = args.remove(0);
918                Ok(Expression::Function(Box::new(Function::new(
919                    "LOCATE".to_string(),
920                    vec![substr_arg, str_arg],
921                ))))
922            }
923
924            // TIME_STR_TO_UNIX -> UNIX_TIMESTAMP in MySQL
925            "TIME_STR_TO_UNIX" => Ok(Expression::Function(Box::new(Function::new(
926                "UNIX_TIMESTAMP".to_string(),
927                f.args,
928            )))),
929
930            // TIME_STR_TO_TIME -> CAST AS DATETIME(N) or TIMESTAMP() in MySQL
931            "TIME_STR_TO_TIME" if f.args.len() >= 1 => {
932                let mut args = f.args.into_iter();
933                let arg = args.next().unwrap();
934
935                // If there's a timezone arg, use TIMESTAMP() function instead
936                if args.next().is_some() {
937                    return Ok(Expression::Function(Box::new(Function::new(
938                        "TIMESTAMP".to_string(),
939                        vec![arg],
940                    ))));
941                }
942
943                // Extract sub-second precision from the string literal
944                let precision =
945                    if let Expression::Literal(crate::expressions::Literal::String(ref s)) = arg {
946                        // Find fractional seconds: look for .NNN pattern after HH:MM:SS
947                        if let Some(dot_pos) = s.rfind('.') {
948                            let after_dot = &s[dot_pos + 1..];
949                            // Count digits until non-digit
950                            let frac_digits =
951                                after_dot.chars().take_while(|c| c.is_ascii_digit()).count();
952                            if frac_digits > 0 {
953                                // Round up: 1-3 digits → 3, 4-6 digits → 6
954                                if frac_digits <= 3 {
955                                    Some(3)
956                                } else {
957                                    Some(6)
958                                }
959                            } else {
960                                None
961                            }
962                        } else {
963                            None
964                        }
965                    } else {
966                        None
967                    };
968
969                let type_name = match precision {
970                    Some(p) => format!("DATETIME({})", p),
971                    None => "DATETIME".to_string(),
972                };
973
974                Ok(Expression::Cast(Box::new(Cast {
975                    this: arg,
976                    to: DataType::Custom { name: type_name },
977                    trailing_comments: Vec::new(),
978                    double_colon_syntax: false,
979                    format: None,
980                    default: None,
981                })))
982            }
983
984            // UCASE -> UPPER in MySQL
985            "UCASE" => Ok(Expression::Function(Box::new(Function::new(
986                "UPPER".to_string(),
987                f.args,
988            )))),
989
990            // LCASE -> LOWER in MySQL
991            "LCASE" => Ok(Expression::Function(Box::new(Function::new(
992                "LOWER".to_string(),
993                f.args,
994            )))),
995
996            // DAY_OF_MONTH -> DAYOFMONTH in MySQL
997            "DAY_OF_MONTH" => Ok(Expression::Function(Box::new(Function::new(
998                "DAYOFMONTH".to_string(),
999                f.args,
1000            )))),
1001
1002            // DAY_OF_WEEK -> DAYOFWEEK in MySQL
1003            "DAY_OF_WEEK" => Ok(Expression::Function(Box::new(Function::new(
1004                "DAYOFWEEK".to_string(),
1005                f.args,
1006            )))),
1007
1008            // DAY_OF_YEAR -> DAYOFYEAR in MySQL
1009            "DAY_OF_YEAR" => Ok(Expression::Function(Box::new(Function::new(
1010                "DAYOFYEAR".to_string(),
1011                f.args,
1012            )))),
1013
1014            // WEEK_OF_YEAR -> WEEKOFYEAR in MySQL
1015            "WEEK_OF_YEAR" => Ok(Expression::Function(Box::new(Function::new(
1016                "WEEKOFYEAR".to_string(),
1017                f.args,
1018            )))),
1019
1020            // MOD(x, y) -> x % y in MySQL
1021            "MOD" if f.args.len() == 2 => {
1022                let mut args = f.args;
1023                let left = args.remove(0);
1024                let right = args.remove(0);
1025                Ok(Expression::Mod(Box::new(BinaryOp {
1026                    left,
1027                    right,
1028                    left_comments: Vec::new(),
1029                    operator_comments: Vec::new(),
1030                    trailing_comments: Vec::new(),
1031                })))
1032            }
1033
1034            // PARSE_JSON -> strip in MySQL (just keep the string argument)
1035            "PARSE_JSON" if f.args.len() == 1 => Ok(f.args.into_iter().next().unwrap()),
1036
1037            // GET_PATH(obj, path) -> JSON_EXTRACT(obj, json_path) in MySQL
1038            "GET_PATH" if f.args.len() == 2 => {
1039                let mut args = f.args;
1040                let this = args.remove(0);
1041                let path = args.remove(0);
1042                let json_path = match &path {
1043                    Expression::Literal(Literal::String(s)) => {
1044                        // Convert bracket notation ["key"] to quoted dot notation ."key"
1045                        let s = Self::convert_bracket_to_quoted_path(s);
1046                        let normalized = if s.starts_with('$') {
1047                            s
1048                        } else if s.starts_with('[') {
1049                            format!("${}", s)
1050                        } else {
1051                            format!("$.{}", s)
1052                        };
1053                        Expression::Literal(Literal::String(normalized))
1054                    }
1055                    _ => path,
1056                };
1057                Ok(Expression::JsonExtract(Box::new(JsonExtractFunc {
1058                    this,
1059                    path: json_path,
1060                    returning: None,
1061                    arrow_syntax: false,
1062                    hash_arrow_syntax: false,
1063                    wrapper_option: None,
1064                    quotes_option: None,
1065                    on_scalar_string: false,
1066                    on_error: None,
1067                })))
1068            }
1069
1070            // REGEXP -> REGEXP_LIKE (MySQL standard form)
1071            "REGEXP" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
1072                "REGEXP_LIKE".to_string(),
1073                f.args,
1074            )))),
1075
1076            // Pass through everything else
1077            _ => Ok(Expression::Function(Box::new(f))),
1078        }
1079    }
1080
1081    fn transform_aggregate_function(
1082        &self,
1083        f: Box<crate::expressions::AggregateFunction>,
1084    ) -> Result<Expression> {
1085        let name_upper = f.name.to_uppercase();
1086        match name_upper.as_str() {
1087            // STRING_AGG -> GROUP_CONCAT
1088            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
1089                Function::new("GROUP_CONCAT".to_string(), f.args),
1090            ))),
1091
1092            // ARRAY_AGG -> GROUP_CONCAT
1093            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
1094                "GROUP_CONCAT".to_string(),
1095                f.args,
1096            )))),
1097
1098            // Pass through everything else
1099            _ => Ok(Expression::AggregateFunction(f)),
1100        }
1101    }
1102}
1103
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;

    /// Transpiles `sql` from the generic dialect to MySQL and returns the first
    /// resulting statement.
    fn transpile_to_mysql(sql: &str) -> String {
        let generic = Dialect::get(DialectType::Generic);
        let statements = generic
            .transpile_to(sql, DialectType::MySQL)
            .expect("Transpile failed");
        statements[0].clone()
    }

    /// Asserts that the MySQL transpilation of `sql` contains `needle`, failing
    /// with "`expectation`, got: <output>" otherwise.
    #[track_caller]
    fn assert_transpiled_contains(sql: &str, needle: &str, expectation: &str) {
        let output = transpile_to_mysql(sql);
        assert!(output.contains(needle), "{}, got: {}", expectation, output);
    }

    /// Parses, transforms and regenerates `sql` with the MySQL dialect and checks
    /// the generated text against `expected`.
    fn mysql_identity(sql: &str, expected: &str) {
        let mysql = Dialect::get(DialectType::MySQL);
        let parsed = mysql.parse(sql).expect("Parse failed");
        let rewritten = mysql
            .transform(parsed[0].clone())
            .expect("Transform failed");
        let generated = mysql.generate(&rewritten).expect("Generate failed");
        assert_eq!(generated, expected, "SQL: {}", sql);
    }

    // ---- Generic -> MySQL transpilation ------------------------------------------

    #[test]
    fn test_nvl_to_ifnull() {
        assert_transpiled_contains("SELECT NVL(a, b)", "IFNULL", "Expected IFNULL");
    }

    #[test]
    fn test_coalesce_preserved() {
        // COALESCE should be preserved in MySQL (it's a native function)
        assert_transpiled_contains(
            "SELECT COALESCE(a, b)",
            "COALESCE",
            "Expected COALESCE to be preserved",
        );
    }

    #[test]
    fn test_random_to_rand() {
        assert_transpiled_contains("SELECT RANDOM()", "RAND", "Expected RAND");
    }

    #[test]
    fn test_basic_select() {
        let result = transpile_to_mysql("SELECT a, b FROM users WHERE id = 1");
        assert!(result.contains("SELECT"));
        assert!(result.contains("FROM users"));
    }

    #[test]
    fn test_string_agg_to_group_concat() {
        assert_transpiled_contains(
            "SELECT STRING_AGG(name)",
            "GROUP_CONCAT",
            "Expected GROUP_CONCAT",
        );
    }

    #[test]
    fn test_array_agg_to_group_concat() {
        assert_transpiled_contains(
            "SELECT ARRAY_AGG(name)",
            "GROUP_CONCAT",
            "Expected GROUP_CONCAT",
        );
    }

    #[test]
    fn test_to_date_to_str_to_date() {
        assert_transpiled_contains(
            "SELECT TO_DATE('2023-01-01')",
            "STR_TO_DATE",
            "Expected STR_TO_DATE",
        );
    }

    // ---- Generator configuration -------------------------------------------------

    #[test]
    fn test_backtick_identifiers() {
        // MySQL uses backticks for identifiers
        let quote = MySQLDialect.generator_config().identifier_quote;
        assert_eq!(quote, '`');
    }

    // ---- MySQL -> MySQL round trips ----------------------------------------------

    #[test]
    fn test_ucase_to_upper() {
        mysql_identity("SELECT UCASE('foo')", "SELECT UPPER('foo')");
    }

    #[test]
    fn test_lcase_to_lower() {
        mysql_identity("SELECT LCASE('foo')", "SELECT LOWER('foo')");
    }

    #[test]
    fn test_day_of_month() {
        let input = "SELECT DAY_OF_MONTH('2023-01-01')";
        let output = "SELECT DAYOFMONTH('2023-01-01')";
        mysql_identity(input, output);
    }

    #[test]
    fn test_day_of_week() {
        let input = "SELECT DAY_OF_WEEK('2023-01-01')";
        let output = "SELECT DAYOFWEEK('2023-01-01')";
        mysql_identity(input, output);
    }

    #[test]
    fn test_day_of_year() {
        let input = "SELECT DAY_OF_YEAR('2023-01-01')";
        let output = "SELECT DAYOFYEAR('2023-01-01')";
        mysql_identity(input, output);
    }

    #[test]
    fn test_week_of_year() {
        let input = "SELECT WEEK_OF_YEAR('2023-01-01')";
        let output = "SELECT WEEKOFYEAR('2023-01-01')";
        mysql_identity(input, output);
    }

    #[test]
    fn test_mod_func_to_percent() {
        // MOD(x, y) function is transformed to x % y in MySQL
        mysql_identity("MOD(x, y)", "x % y");
    }

    #[test]
    fn test_database_to_schema() {
        mysql_identity("DATABASE()", "SCHEMA()");
    }

    #[test]
    fn test_and_operator() {
        mysql_identity("SELECT 1 && 0", "SELECT 1 AND 0");
    }

    #[test]
    fn test_or_operator() {
        mysql_identity("SELECT a || b", "SELECT a OR b");
    }
}