Skip to main content

polyglot_sql/dialects/
mysql.rs

1//! MySQL Dialect
2//!
3//! MySQL-specific transformations based on sqlglot patterns.
4//! Key differences from standard SQL:
5//! - || is OR operator, not string concatenation (use CONCAT)
6//! - Uses backticks for identifiers
7//! - No TRY_CAST, no ILIKE
8//! - Different date/time function names
9
10use super::{DialectImpl, DialectType};
11use crate::error::Result;
12use crate::expressions::{
13    BinaryFunc, BinaryOp, Cast, DataType, Expression, Function, JsonExtractFunc, LikeOp, Literal,
14    Paren, UnaryFunc,
15};
16use crate::generator::GeneratorConfig;
17use crate::tokens::TokenizerConfig;
18
19/// Helper to wrap JSON arrow expressions in parentheses when they appear
20/// in contexts that require it (Binary, In, Not expressions)
21/// This matches Python sqlglot's WRAPPED_JSON_EXTRACT_EXPRESSIONS behavior
22fn wrap_if_json_arrow(expr: Expression) -> Expression {
23    match &expr {
24        Expression::JsonExtract(f) if f.arrow_syntax => Expression::Paren(Box::new(Paren {
25            this: expr,
26            trailing_comments: Vec::new(),
27        })),
28        Expression::JsonExtractScalar(f) if f.arrow_syntax => Expression::Paren(Box::new(Paren {
29            this: expr,
30            trailing_comments: Vec::new(),
31        })),
32        _ => expr,
33    }
34}
35
36/// Convert JSON arrow expression (-> or ->>) to JSON_EXTRACT function form
37/// This is needed for contexts like MEMBER OF where arrow syntax must become function form
38fn json_arrow_to_function(expr: Expression) -> Expression {
39    match expr {
40        Expression::JsonExtract(f) if f.arrow_syntax => Expression::Function(Box::new(
41            Function::new("JSON_EXTRACT".to_string(), vec![f.this, f.path]),
42        )),
43        Expression::JsonExtractScalar(f) if f.arrow_syntax => {
44            // ->> becomes JSON_UNQUOTE(JSON_EXTRACT(...)) but can be simplified to JSON_EXTRACT_SCALAR
45            // For MySQL, use JSON_UNQUOTE(JSON_EXTRACT(...))
46            let json_extract = Expression::Function(Box::new(Function::new(
47                "JSON_EXTRACT".to_string(),
48                vec![f.this, f.path],
49            )));
50            Expression::Function(Box::new(Function::new(
51                "JSON_UNQUOTE".to_string(),
52                vec![json_extract],
53            )))
54        }
55        other => other,
56    }
57}
58
/// MySQL dialect.
///
/// A stateless unit struct; all MySQL-specific behavior lives in its
/// `DialectImpl` implementation and its inherent helper methods.
pub struct MySQLDialect;
61
62impl DialectImpl for MySQLDialect {
    /// Identifies this implementation as the MySQL dialect.
    fn dialect_type(&self) -> DialectType {
        DialectType::MySQL
    }
66
67    fn tokenizer_config(&self) -> TokenizerConfig {
68        use crate::tokens::TokenType;
69        let mut config = TokenizerConfig::default();
70        // MySQL uses backticks for identifiers
71        config.identifiers.insert('`', '`');
72        // Remove double quotes from identifiers - in MySQL they are string delimiters
73        // (unless ANSI_QUOTES mode is set, but default mode uses them as strings)
74        config.identifiers.remove(&'"');
75        // MySQL supports double quotes as string literals by default
76        config.quotes.insert("\"".to_string(), "\"".to_string());
77        // MySQL supports backslash escapes in strings
78        config.string_escapes.push('\\');
79        // MySQL has XOR as a logical operator keyword
80        config.keywords.insert("XOR".to_string(), TokenType::Xor);
81        // MySQL: backslash followed by chars NOT in this list -> discard backslash
82        // See: https://dev.mysql.com/doc/refman/8.4/en/string-literals.html
83        config.escape_follow_chars = vec!['0', 'b', 'n', 'r', 't', 'Z', '%', '_'];
84        // MySQL allows identifiers to start with digits (e.g., 1a, 1_a)
85        config.identifiers_can_start_with_digit = true;
86        config
87    }
88
89    fn generator_config(&self) -> GeneratorConfig {
90        use crate::generator::IdentifierQuoteStyle;
91        GeneratorConfig {
92            identifier_quote: '`',
93            identifier_quote_style: IdentifierQuoteStyle::BACKTICK,
94            dialect: Some(DialectType::MySQL),
95            // MySQL doesn't support null ordering in most contexts
96            null_ordering_supported: false,
97            // MySQL LIMIT only
98            limit_only_literals: true,
99            // MySQL doesn't support semi/anti join
100            semi_anti_join_with_side: false,
101            // MySQL doesn't support table alias columns in some contexts
102            supports_table_alias_columns: false,
103            // MySQL VALUES not used as table
104            values_as_table: false,
105            // MySQL doesn't support TABLESAMPLE
106            tablesample_requires_parens: false,
107            tablesample_with_method: false,
108            // MySQL doesn't support aggregate FILTER
109            aggregate_filter_supported: false,
110            // MySQL doesn't support TRY
111            try_supported: false,
112            // MySQL doesn't support CONVERT_TIMEZONE
113            supports_convert_timezone: false,
114            // MySQL doesn't support UESCAPE
115            supports_uescape: false,
116            // MySQL doesn't support BETWEEN flags
117            supports_between_flags: false,
118            // MySQL supports EXPLAIN but not query hints in standard way
119            query_hints: false,
120            // MySQL parameter token
121            parameter_token: "?",
122            // MySQL doesn't support window EXCLUDE
123            supports_window_exclude: false,
124            // MySQL doesn't support exploding projections
125            supports_exploding_projections: false,
126            identifiers_can_start_with_digit: true,
127            // MySQL supports FOR UPDATE/SHARE
128            locking_reads_supported: true,
129            ..Default::default()
130        }
131    }
132
133    fn transform_expr(&self, expr: Expression) -> Result<Expression> {
134        match expr {
135            // ===== Data Type Mappings =====
136            Expression::DataType(dt) => self.transform_data_type(dt),
137
138            // NVL -> IFNULL in MySQL
139            Expression::Nvl(f) => Ok(Expression::IfNull(f)),
140
141            // Note: COALESCE is valid in MySQL and should be preserved.
142            // Unlike some other dialects, we do NOT convert COALESCE to IFNULL
143            // as this would break identity tests.
144
145            // TryCast -> CAST or TIMESTAMP() (MySQL doesn't support TRY_CAST)
146            Expression::TryCast(c) => self.transform_cast(*c),
147
148            // SafeCast -> CAST or TIMESTAMP() (MySQL doesn't support safe casts)
149            Expression::SafeCast(c) => self.transform_cast(*c),
150
151            // Cast -> Transform cast type according to MySQL restrictions
152            // CAST AS TIMESTAMP -> TIMESTAMP() function in MySQL
153            Expression::Cast(c) => self.transform_cast(*c),
154
155            // ILIKE -> LOWER() LIKE LOWER() in MySQL
156            Expression::ILike(op) => {
157                // Transform ILIKE to: LOWER(left) LIKE LOWER(right)
158                let lower_left = Expression::Lower(Box::new(UnaryFunc::new(op.left)));
159                let lower_right = Expression::Lower(Box::new(UnaryFunc::new(op.right)));
160                Ok(Expression::Like(Box::new(LikeOp {
161                    left: lower_left,
162                    right: lower_right,
163                    escape: op.escape,
164                    quantifier: op.quantifier,
165                })))
166            }
167
168            // || (Concat operator) -> OR in MySQL
169            // MySQL uses || as OR by default (unless PIPES_AS_CONCAT mode is set)
170            // For identity preservation, we transform to OR
171            Expression::Concat(op) => Ok(Expression::Or(op)),
172
173            // RANDOM -> RAND in MySQL
174            Expression::Random(_) => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
175                seed: None,
176                lower: None,
177                upper: None,
178            }))),
179
180            // ArrayAgg -> GROUP_CONCAT in MySQL
181            Expression::ArrayAgg(f) => Ok(Expression::Function(Box::new(Function::new(
182                "GROUP_CONCAT".to_string(),
183                vec![f.this],
184            )))),
185
186            // StringAgg -> GROUP_CONCAT in MySQL
187            Expression::StringAgg(f) => {
188                let mut args = vec![f.this.clone()];
189                if let Some(separator) = &f.separator {
190                    args.push(separator.clone());
191                }
192                Ok(Expression::Function(Box::new(Function::new(
193                    "GROUP_CONCAT".to_string(),
194                    args,
195                ))))
196            }
197
198            // UNNEST -> Not directly supported in MySQL, use JSON_TABLE or inline
199            // For basic cases, pass through (may need manual handling)
200            Expression::Unnest(f) => {
201                // MySQL 8.0+ has JSON_TABLE which can be used for unnesting
202                // For now, pass through with a function call
203                Ok(Expression::Function(Box::new(Function::new(
204                    "JSON_TABLE".to_string(),
205                    vec![f.this],
206                ))))
207            }
208
209            // Substring: Use comma syntax (not FROM/FOR) in MySQL
210            Expression::Substring(mut f) => {
211                f.from_for_syntax = false;
212                Ok(Expression::Substring(f))
213            }
214
215            // ===== Bitwise operations =====
216            // BitwiseAndAgg -> BIT_AND
217            Expression::BitwiseAndAgg(f) => Ok(Expression::Function(Box::new(Function::new(
218                "BIT_AND".to_string(),
219                vec![f.this],
220            )))),
221
222            // BitwiseOrAgg -> BIT_OR
223            Expression::BitwiseOrAgg(f) => Ok(Expression::Function(Box::new(Function::new(
224                "BIT_OR".to_string(),
225                vec![f.this],
226            )))),
227
228            // BitwiseXorAgg -> BIT_XOR
229            Expression::BitwiseXorAgg(f) => Ok(Expression::Function(Box::new(Function::new(
230                "BIT_XOR".to_string(),
231                vec![f.this],
232            )))),
233
234            // BitwiseCount -> BIT_COUNT
235            Expression::BitwiseCount(f) => Ok(Expression::Function(Box::new(Function::new(
236                "BIT_COUNT".to_string(),
237                vec![f.this],
238            )))),
239
240            // TimeFromParts -> MAKETIME
241            Expression::TimeFromParts(f) => {
242                let mut args = Vec::new();
243                if let Some(h) = f.hour {
244                    args.push(*h);
245                }
246                if let Some(m) = f.min {
247                    args.push(*m);
248                }
249                if let Some(s) = f.sec {
250                    args.push(*s);
251                }
252                Ok(Expression::Function(Box::new(Function::new(
253                    "MAKETIME".to_string(),
254                    args,
255                ))))
256            }
257
258            // ===== Boolean aggregates =====
259            // In MySQL, there's no BOOL_AND/BOOL_OR, use MIN/MAX on boolean values
260            // LogicalAnd -> MIN (0 is false, non-0 is true)
261            Expression::LogicalAnd(f) => Ok(Expression::Function(Box::new(Function::new(
262                "MIN".to_string(),
263                vec![f.this],
264            )))),
265
266            // LogicalOr -> MAX
267            Expression::LogicalOr(f) => Ok(Expression::Function(Box::new(Function::new(
268                "MAX".to_string(),
269                vec![f.this],
270            )))),
271
272            // ===== Date/time functions =====
273            // DayOfMonth -> DAYOFMONTH
274            Expression::DayOfMonth(f) => Ok(Expression::Function(Box::new(Function::new(
275                "DAYOFMONTH".to_string(),
276                vec![f.this],
277            )))),
278
279            // DayOfWeek -> DAYOFWEEK
280            Expression::DayOfWeek(f) => Ok(Expression::Function(Box::new(Function::new(
281                "DAYOFWEEK".to_string(),
282                vec![f.this],
283            )))),
284
285            // DayOfYear -> DAYOFYEAR
286            Expression::DayOfYear(f) => Ok(Expression::Function(Box::new(Function::new(
287                "DAYOFYEAR".to_string(),
288                vec![f.this],
289            )))),
290
291            // WeekOfYear -> WEEKOFYEAR
292            Expression::WeekOfYear(f) => Ok(Expression::Function(Box::new(Function::new(
293                "WEEKOFYEAR".to_string(),
294                vec![f.this],
295            )))),
296
297            // DateDiff -> DATEDIFF
298            Expression::DateDiff(f) => Ok(Expression::Function(Box::new(Function::new(
299                "DATEDIFF".to_string(),
300                vec![f.this, f.expression],
301            )))),
302
303            // TimeStrToUnix -> UNIX_TIMESTAMP
304            Expression::TimeStrToUnix(f) => Ok(Expression::Function(Box::new(Function::new(
305                "UNIX_TIMESTAMP".to_string(),
306                vec![f.this],
307            )))),
308
309            // TimestampDiff -> TIMESTAMPDIFF
310            Expression::TimestampDiff(f) => Ok(Expression::Function(Box::new(Function::new(
311                "TIMESTAMPDIFF".to_string(),
312                vec![*f.this, *f.expression],
313            )))),
314
315            // ===== String functions =====
316            // StrPosition -> LOCATE in MySQL
317            // STRPOS(str, substr) -> LOCATE(substr, str) (args are swapped)
318            Expression::StrPosition(f) => {
319                let mut args = vec![];
320                if let Some(substr) = f.substr {
321                    args.push(*substr);
322                }
323                args.push(*f.this);
324                if let Some(pos) = f.position {
325                    args.push(*pos);
326                }
327                Ok(Expression::Function(Box::new(Function::new(
328                    "LOCATE".to_string(),
329                    args,
330                ))))
331            }
332
333            // Stuff -> INSERT in MySQL
334            Expression::Stuff(f) => {
335                let mut args = vec![*f.this];
336                if let Some(start) = f.start {
337                    args.push(*start);
338                }
339                if let Some(length) = f.length {
340                    args.push(Expression::number(length));
341                }
342                args.push(*f.expression);
343                Ok(Expression::Function(Box::new(Function::new(
344                    "INSERT".to_string(),
345                    args,
346                ))))
347            }
348
349            // ===== Session/User functions =====
350            // SessionUser -> SESSION_USER()
351            Expression::SessionUser(_) => Ok(Expression::Function(Box::new(Function::new(
352                "SESSION_USER".to_string(),
353                vec![],
354            )))),
355
356            // CurrentDate -> CURRENT_DATE (no parentheses in MySQL) - keep as CurrentDate
357            Expression::CurrentDate(_) => {
358                Ok(Expression::CurrentDate(crate::expressions::CurrentDate))
359            }
360
361            // ===== Null-safe comparison =====
362            // NullSafeNeq -> NOT (a <=> b) in MySQL
363            Expression::NullSafeNeq(op) => {
364                // Create: NOT (left <=> right)
365                let null_safe_eq = Expression::NullSafeEq(Box::new(crate::expressions::BinaryOp {
366                    left: op.left,
367                    right: op.right,
368                    left_comments: Vec::new(),
369                    operator_comments: Vec::new(),
370                    trailing_comments: Vec::new(),
371                }));
372                Ok(Expression::Not(Box::new(crate::expressions::UnaryOp {
373                    this: null_safe_eq,
374                })))
375            }
376
377            // ParseJson: handled by generator (emits just the string literal for MySQL)
378
379            // JSONExtract with variant_extract (Snowflake colon syntax) -> JSON_EXTRACT
380            Expression::JSONExtract(e) if e.variant_extract.is_some() => {
381                let path = match *e.expression {
382                    Expression::Literal(Literal::String(s)) => {
383                        // Convert bracket notation ["key"] to quoted dot notation ."key"
384                        let s = Self::convert_bracket_to_quoted_path(&s);
385                        let normalized = if s.starts_with('$') {
386                            s
387                        } else if s.starts_with('[') {
388                            format!("${}", s)
389                        } else {
390                            format!("$.{}", s)
391                        };
392                        Expression::Literal(Literal::String(normalized))
393                    }
394                    other => other,
395                };
396                Ok(Expression::Function(Box::new(Function::new(
397                    "JSON_EXTRACT".to_string(),
398                    vec![*e.this, path],
399                ))))
400            }
401
402            // Generic function transformations
403            Expression::Function(f) => self.transform_function(*f),
404
405            // Generic aggregate function transformations
406            Expression::AggregateFunction(f) => self.transform_aggregate_function(f),
407
408            // ===== Context-aware JSON arrow wrapping =====
409            // When JSON arrow expressions appear in Binary/In/Not contexts,
410            // they need to be wrapped in parentheses for correct precedence.
411            // This matches Python sqlglot's WRAPPED_JSON_EXTRACT_EXPRESSIONS behavior.
412
413            // Binary operators that need JSON wrapping
414            Expression::Eq(op) => Ok(Expression::Eq(Box::new(BinaryOp {
415                left: wrap_if_json_arrow(op.left),
416                right: wrap_if_json_arrow(op.right),
417                ..*op
418            }))),
419            Expression::Neq(op) => Ok(Expression::Neq(Box::new(BinaryOp {
420                left: wrap_if_json_arrow(op.left),
421                right: wrap_if_json_arrow(op.right),
422                ..*op
423            }))),
424            Expression::Lt(op) => Ok(Expression::Lt(Box::new(BinaryOp {
425                left: wrap_if_json_arrow(op.left),
426                right: wrap_if_json_arrow(op.right),
427                ..*op
428            }))),
429            Expression::Lte(op) => Ok(Expression::Lte(Box::new(BinaryOp {
430                left: wrap_if_json_arrow(op.left),
431                right: wrap_if_json_arrow(op.right),
432                ..*op
433            }))),
434            Expression::Gt(op) => Ok(Expression::Gt(Box::new(BinaryOp {
435                left: wrap_if_json_arrow(op.left),
436                right: wrap_if_json_arrow(op.right),
437                ..*op
438            }))),
439            Expression::Gte(op) => Ok(Expression::Gte(Box::new(BinaryOp {
440                left: wrap_if_json_arrow(op.left),
441                right: wrap_if_json_arrow(op.right),
442                ..*op
443            }))),
444
445            // In expression - wrap the this part if it's JSON arrow
446            Expression::In(mut i) => {
447                i.this = wrap_if_json_arrow(i.this);
448                Ok(Expression::In(i))
449            }
450
451            // Not expression - wrap the this part if it's JSON arrow
452            Expression::Not(mut n) => {
453                n.this = wrap_if_json_arrow(n.this);
454                Ok(Expression::Not(n))
455            }
456
457            // && in MySQL is logical AND, not array overlaps
458            // Transform ArrayOverlaps -> And for MySQL identity
459            Expression::ArrayOverlaps(op) => Ok(Expression::And(op)),
460
461            // MOD(x, y) -> x % y in MySQL
462            Expression::ModFunc(f) => Ok(Expression::Mod(Box::new(BinaryOp {
463                left: f.this,
464                right: f.expression,
465                left_comments: Vec::new(),
466                operator_comments: Vec::new(),
467                trailing_comments: Vec::new(),
468            }))),
469
470            // SHOW SLAVE STATUS -> SHOW REPLICA STATUS
471            Expression::Show(mut s) => {
472                if s.this == "SLAVE STATUS" {
473                    s.this = "REPLICA STATUS".to_string();
474                }
475                if matches!(s.this.as_str(), "INDEX" | "COLUMNS") && s.db.is_none() {
476                    if let Some(Expression::Table(mut t)) = s.target.take() {
477                        if let Some(db_ident) = t.schema.take().or(t.catalog.take()) {
478                            s.db = Some(Expression::Identifier(db_ident));
479                            s.target = Some(Expression::Identifier(t.name));
480                        } else {
481                            s.target = Some(Expression::Table(t));
482                        }
483                    }
484                }
485                Ok(Expression::Show(s))
486            }
487
488            // AT TIME ZONE -> strip timezone (MySQL doesn't support AT TIME ZONE)
489            // But keep it for CURRENT_DATE/CURRENT_TIMESTAMP with timezone (transpiled from BigQuery)
490            Expression::AtTimeZone(atz) => {
491                let is_current = match &atz.this {
492                    Expression::CurrentDate(_) | Expression::CurrentTimestamp(_) => true,
493                    Expression::Function(f) => {
494                        let n = f.name.to_uppercase();
495                        (n == "CURRENT_DATE" || n == "CURRENT_TIMESTAMP") && f.no_parens
496                    }
497                    _ => false,
498                };
499                if is_current {
500                    Ok(Expression::AtTimeZone(atz)) // Keep AT TIME ZONE for CURRENT_DATE/CURRENT_TIMESTAMP
501                } else {
502                    Ok(atz.this) // Strip timezone for other expressions
503                }
504            }
505
506            // MEMBER OF with JSON arrow -> convert arrow to JSON_EXTRACT function
507            // MySQL's MEMBER OF requires JSON_EXTRACT function form, not arrow syntax
508            Expression::MemberOf(mut op) => {
509                op.right = json_arrow_to_function(op.right);
510                Ok(Expression::MemberOf(op))
511            }
512
513            // Pass through everything else
514            _ => Ok(expr),
515        }
516    }
517}
518
519impl MySQLDialect {
520    fn normalize_mysql_date_format(fmt: &str) -> String {
521        fmt.replace("%H:%i:%s", "%T").replace("%H:%i:%S", "%T")
522    }
523
524    /// Convert bracket notation ["key with spaces"] to quoted dot notation ."key with spaces"
525    /// in JSON path strings.
526    fn convert_bracket_to_quoted_path(path: &str) -> String {
527        let mut result = String::new();
528        let mut chars = path.chars().peekable();
529        while let Some(c) = chars.next() {
530            if c == '[' && chars.peek() == Some(&'"') {
531                chars.next(); // consume "
532                let mut key = String::new();
533                while let Some(kc) = chars.next() {
534                    if kc == '"' && chars.peek() == Some(&']') {
535                        chars.next(); // consume ]
536                        break;
537                    }
538                    key.push(kc);
539                }
540                if !result.is_empty() && !result.ends_with('.') {
541                    result.push('.');
542                }
543                result.push('"');
544                result.push_str(&key);
545                result.push('"');
546            } else {
547                result.push(c);
548            }
549        }
550        result
551    }
552
553    /// Transform data types according to MySQL TYPE_MAPPING
554    /// Note: MySQL's TIMESTAMP is kept as TIMESTAMP (not converted to DATETIME)
555    /// because MySQL's TIMESTAMP has timezone awareness built-in
556    fn transform_data_type(&self, dt: crate::expressions::DataType) -> Result<Expression> {
557        use crate::expressions::DataType;
558        let transformed = match dt {
559            // All TIMESTAMP variants (with or without timezone) -> TIMESTAMP in MySQL
560            DataType::Timestamp {
561                precision,
562                timezone: _,
563            } => DataType::Timestamp {
564                precision,
565                timezone: false,
566            },
567            // TIMESTAMPTZ / TIMESTAMPLTZ parsed as Custom -> normalize to TIMESTAMP
568            DataType::Custom { name }
569                if name.to_uppercase() == "TIMESTAMPTZ"
570                    || name.to_uppercase() == "TIMESTAMPLTZ" =>
571            {
572                DataType::Timestamp {
573                    precision: None,
574                    timezone: false,
575                }
576            }
577            // Keep native MySQL types as-is
578            // MySQL supports TEXT, MEDIUMTEXT, LONGTEXT, BLOB, etc. natively
579            other => other,
580        };
581        Ok(Expression::DataType(transformed))
582    }
583
584    /// Transform CAST expression
585    /// MySQL uses TIMESTAMP() function instead of CAST(x AS TIMESTAMP)
586    /// For Generic->MySQL, TIMESTAMP (no tz) is pre-converted to DATETIME in cross_dialect_normalize
587    fn transform_cast(&self, cast: Cast) -> Result<Expression> {
588        // CAST AS TIMESTAMP/TIMESTAMPTZ/TIMESTAMPLTZ -> TIMESTAMP() function
589        match &cast.to {
590            DataType::Timestamp { .. } => Ok(Expression::Function(Box::new(Function::new(
591                "TIMESTAMP".to_string(),
592                vec![cast.this],
593            )))),
594            DataType::Custom { name }
595                if name.to_uppercase() == "TIMESTAMPTZ"
596                    || name.to_uppercase() == "TIMESTAMPLTZ" =>
597            {
598                Ok(Expression::Function(Box::new(Function::new(
599                    "TIMESTAMP".to_string(),
600                    vec![cast.this],
601                ))))
602            }
603            // All other casts go through normal type transformation
604            _ => Ok(Expression::Cast(Box::new(self.transform_cast_type(cast)))),
605        }
606    }
607
608    /// Transform CAST type according to MySQL restrictions
609    /// MySQL doesn't support many types in CAST - they get mapped to CHAR or SIGNED
610    /// Based on Python sqlglot's CHAR_CAST_MAPPING and SIGNED_CAST_MAPPING
611    fn transform_cast_type(&self, cast: Cast) -> Cast {
612        let new_type = match &cast.to {
613            // CHAR_CAST_MAPPING: These types become CHAR in MySQL CAST, preserving length
614            DataType::VarChar { length, .. } => DataType::Char { length: *length },
615            DataType::Text => DataType::Char { length: None },
616
617            // SIGNED_CAST_MAPPING: These integer types become SIGNED in MySQL CAST
618            DataType::BigInt { .. } => DataType::Custom {
619                name: "SIGNED".to_string(),
620            },
621            DataType::Int { .. } => DataType::Custom {
622                name: "SIGNED".to_string(),
623            },
624            DataType::SmallInt { .. } => DataType::Custom {
625                name: "SIGNED".to_string(),
626            },
627            DataType::TinyInt { .. } => DataType::Custom {
628                name: "SIGNED".to_string(),
629            },
630            DataType::Boolean => DataType::Custom {
631                name: "SIGNED".to_string(),
632            },
633
634            // Custom types that need mapping
635            DataType::Custom { name } => {
636                let upper = name.to_uppercase();
637                match upper.as_str() {
638                    // Text/Blob types -> keep as Custom for cross-dialect mapping
639                    // MySQL generator will output CHAR for these in CAST context
640                    "LONGTEXT" | "MEDIUMTEXT" | "TINYTEXT" | "LONGBLOB" | "MEDIUMBLOB"
641                    | "TINYBLOB" => DataType::Custom { name: upper },
642                    // MEDIUMINT -> SIGNED in MySQL CAST
643                    "MEDIUMINT" => DataType::Custom {
644                        name: "SIGNED".to_string(),
645                    },
646                    // Unsigned integer types -> UNSIGNED
647                    "UBIGINT" | "UINT" | "USMALLINT" | "UTINYINT" | "UMEDIUMINT" => {
648                        DataType::Custom {
649                            name: "UNSIGNED".to_string(),
650                        }
651                    }
652                    // Keep other custom types
653                    _ => cast.to.clone(),
654                }
655            }
656
657            // Types that are valid in MySQL CAST - pass through
658            DataType::Binary { .. } => cast.to.clone(),
659            DataType::VarBinary { .. } => cast.to.clone(),
660            DataType::Date => cast.to.clone(),
661            DataType::Time { .. } => cast.to.clone(),
662            DataType::Decimal { .. } => cast.to.clone(),
663            DataType::Json => cast.to.clone(),
664            DataType::Float { .. } => cast.to.clone(),
665            DataType::Double { .. } => cast.to.clone(),
666            DataType::Char { .. } => cast.to.clone(),
667            DataType::CharacterSet { .. } => cast.to.clone(),
668            DataType::Enum { .. } => cast.to.clone(),
669            DataType::Set { .. } => cast.to.clone(),
670            DataType::Timestamp { .. } => cast.to.clone(),
671
672            // All other unsupported types -> CHAR
673            _ => DataType::Char { length: None },
674        };
675
676        Cast {
677            this: cast.this,
678            to: new_type,
679            trailing_comments: cast.trailing_comments,
680            double_colon_syntax: cast.double_colon_syntax,
681            format: cast.format,
682            default: cast.default,
683        }
684    }
685
686    fn transform_function(&self, f: Function) -> Result<Expression> {
687        let name_upper = f.name.to_uppercase();
688        match name_upper.as_str() {
689            // Normalize DATE_FORMAT short-hands to canonical MySQL forms.
690            "DATE_FORMAT" if f.args.len() >= 2 => {
691                let mut f = f;
692                if let Some(Expression::Literal(Literal::String(fmt))) = f.args.get(1) {
693                    let normalized = Self::normalize_mysql_date_format(fmt);
694                    if normalized != *fmt {
695                        f.args[1] = Expression::Literal(Literal::String(normalized));
696                    }
697                }
698                Ok(Expression::Function(Box::new(f)))
699            }
700
701            // NVL -> IFNULL
702            "NVL" if f.args.len() == 2 => {
703                let mut args = f.args;
704                let second = args.pop().unwrap();
705                let first = args.pop().unwrap();
706                Ok(Expression::IfNull(Box::new(BinaryFunc {
707                    original_name: None,
708                    this: first,
709                    expression: second,
710                })))
711            }
712
713            // Note: COALESCE is native to MySQL. We do NOT convert it to IFNULL
714            // because this would break identity tests (Python SQLGlot preserves COALESCE).
715
716            // ARRAY_AGG -> GROUP_CONCAT
717            "ARRAY_AGG" if f.args.len() == 1 => {
718                let mut args = f.args;
719                Ok(Expression::Function(Box::new(Function::new(
720                    "GROUP_CONCAT".to_string(),
721                    vec![args.pop().unwrap()],
722                ))))
723            }
724
725            // STRING_AGG -> GROUP_CONCAT
726            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
727                Function::new("GROUP_CONCAT".to_string(), f.args),
728            ))),
729
730            // RANDOM -> RAND
731            "RANDOM" => Ok(Expression::Rand(Box::new(crate::expressions::Rand {
732                seed: None,
733                lower: None,
734                upper: None,
735            }))),
736
737            // CURRENT_TIMESTAMP -> NOW() or CURRENT_TIMESTAMP (both work)
738            // Preserve precision if specified: CURRENT_TIMESTAMP(6)
739            "CURRENT_TIMESTAMP" => {
740                let precision =
741                    if let Some(Expression::Literal(crate::expressions::Literal::Number(n))) =
742                        f.args.first()
743                    {
744                        n.parse::<u32>().ok()
745                    } else {
746                        None
747                    };
748                Ok(Expression::CurrentTimestamp(
749                    crate::expressions::CurrentTimestamp {
750                        precision,
751                        sysdate: false,
752                    },
753                ))
754            }
755
756            // POSITION -> LOCATE in MySQL (argument order is different)
757            // POSITION(substr IN str) -> LOCATE(substr, str)
758            "POSITION" if f.args.len() == 2 => Ok(Expression::Function(Box::new(Function::new(
759                "LOCATE".to_string(),
760                f.args,
761            )))),
762
763            // LENGTH is native to MySQL (returns bytes, not characters)
764            // CHAR_LENGTH for character count
765            "LENGTH" => Ok(Expression::Function(Box::new(f))),
766
767            // CEIL -> CEILING in MySQL (both work)
768            "CEIL" if f.args.len() == 1 => Ok(Expression::Function(Box::new(Function::new(
769                "CEILING".to_string(),
770                f.args,
771            )))),
772
773            // STDDEV -> STD or STDDEV_POP in MySQL
774            "STDDEV" => Ok(Expression::Function(Box::new(Function::new(
775                "STD".to_string(),
776                f.args,
777            )))),
778
779            // STDDEV_SAMP -> STDDEV in MySQL
780            "STDDEV_SAMP" => Ok(Expression::Function(Box::new(Function::new(
781                "STDDEV".to_string(),
782                f.args,
783            )))),
784
785            // TO_DATE -> STR_TO_DATE in MySQL
786            "TO_DATE" => Ok(Expression::Function(Box::new(Function::new(
787                "STR_TO_DATE".to_string(),
788                f.args,
789            )))),
790
791            // TO_TIMESTAMP -> STR_TO_DATE in MySQL
792            "TO_TIMESTAMP" => Ok(Expression::Function(Box::new(Function::new(
793                "STR_TO_DATE".to_string(),
794                f.args,
795            )))),
796
797            // DATE_TRUNC -> Complex transformation
798            // Typically uses DATE() or DATE_FORMAT() depending on unit
799            "DATE_TRUNC" if f.args.len() >= 2 => {
800                // Simplified: DATE_TRUNC('day', x) -> DATE(x)
801                // Full implementation would handle different units
802                let mut args = f.args;
803                let _unit = args.remove(0);
804                let date = args.remove(0);
805                Ok(Expression::Function(Box::new(Function::new(
806                    "DATE".to_string(),
807                    vec![date],
808                ))))
809            }
810
811            // EXTRACT is native but syntax varies
812
813            // COALESCE is native to MySQL (keep as-is for more than 2 args)
814            "COALESCE" if f.args.len() > 2 => Ok(Expression::Function(Box::new(f))),
815
816            // DAYOFMONTH -> DAY (both work)
817            "DAY" => Ok(Expression::Function(Box::new(Function::new(
818                "DAYOFMONTH".to_string(),
819                f.args,
820            )))),
821
822            // DAYOFWEEK is native to MySQL
823            "DAYOFWEEK" => Ok(Expression::Function(Box::new(f))),
824
825            // DAYOFYEAR is native to MySQL
826            "DAYOFYEAR" => Ok(Expression::Function(Box::new(f))),
827
828            // WEEKOFYEAR is native to MySQL
829            "WEEKOFYEAR" => Ok(Expression::Function(Box::new(f))),
830
831            // LAST_DAY is native to MySQL
832            "LAST_DAY" => Ok(Expression::Function(Box::new(f))),
833
834            // TIMESTAMPADD -> DATE_ADD
835            "TIMESTAMPADD" => Ok(Expression::Function(Box::new(Function::new(
836                "DATE_ADD".to_string(),
837                f.args,
838            )))),
839
840            // TIMESTAMPDIFF is native to MySQL
841            "TIMESTAMPDIFF" => Ok(Expression::Function(Box::new(f))),
842
843            // CONVERT_TIMEZONE(from_tz, to_tz, timestamp) -> CONVERT_TZ(timestamp, from_tz, to_tz) in MySQL
844            "CONVERT_TIMEZONE" if f.args.len() == 3 => {
845                let mut args = f.args;
846                let from_tz = args.remove(0);
847                let to_tz = args.remove(0);
848                let timestamp = args.remove(0);
849                Ok(Expression::Function(Box::new(Function::new(
850                    "CONVERT_TZ".to_string(),
851                    vec![timestamp, from_tz, to_tz],
852                ))))
853            }
854
855            // UTC_TIMESTAMP is native to MySQL
856            "UTC_TIMESTAMP" => Ok(Expression::Function(Box::new(f))),
857
858            // UTC_TIME is native to MySQL
859            "UTC_TIME" => Ok(Expression::Function(Box::new(f))),
860
861            // MAKETIME is native to MySQL (TimeFromParts)
862            "MAKETIME" => Ok(Expression::Function(Box::new(f))),
863
864            // TIME_FROM_PARTS -> MAKETIME
865            "TIME_FROM_PARTS" if f.args.len() == 3 => Ok(Expression::Function(Box::new(
866                Function::new("MAKETIME".to_string(), f.args),
867            ))),
868
869            // STUFF -> INSERT in MySQL
870            "STUFF" if f.args.len() == 4 => Ok(Expression::Function(Box::new(Function::new(
871                "INSERT".to_string(),
872                f.args,
873            )))),
874
875            // LOCATE is native to MySQL (reverse of POSITION args)
876            "LOCATE" => Ok(Expression::Function(Box::new(f))),
877
878            // FIND_IN_SET is native to MySQL
879            "FIND_IN_SET" => Ok(Expression::Function(Box::new(f))),
880
881            // FORMAT is native to MySQL (NumberToStr)
882            "FORMAT" => Ok(Expression::Function(Box::new(f))),
883
884            // JSON_EXTRACT is native to MySQL
885            "JSON_EXTRACT" => Ok(Expression::Function(Box::new(f))),
886
887            // JSON_UNQUOTE is native to MySQL
888            "JSON_UNQUOTE" => Ok(Expression::Function(Box::new(f))),
889
890            // JSON_EXTRACT_PATH_TEXT -> JSON_UNQUOTE(JSON_EXTRACT(...))
891            "JSON_EXTRACT_PATH_TEXT" if f.args.len() >= 2 => {
892                let extract = Expression::Function(Box::new(Function::new(
893                    "JSON_EXTRACT".to_string(),
894                    f.args,
895                )));
896                Ok(Expression::Function(Box::new(Function::new(
897                    "JSON_UNQUOTE".to_string(),
898                    vec![extract],
899                ))))
900            }
901
902            // GEN_RANDOM_UUID / UUID -> UUID()
903            "GEN_RANDOM_UUID" | "GENERATE_UUID" => Ok(Expression::Function(Box::new(
904                Function::new("UUID".to_string(), vec![]),
905            ))),
906
907            // DATABASE() -> SCHEMA() in MySQL (both return current database name)
908            "DATABASE" => Ok(Expression::Function(Box::new(Function::new(
909                "SCHEMA".to_string(),
910                f.args,
911            )))),
912
913            // INSTR -> LOCATE in MySQL (with swapped arguments)
914            // INSTR(str, substr) -> LOCATE(substr, str)
915            "INSTR" if f.args.len() == 2 => {
916                let mut args = f.args;
917                let str_arg = args.remove(0);
918                let substr_arg = args.remove(0);
919                Ok(Expression::Function(Box::new(Function::new(
920                    "LOCATE".to_string(),
921                    vec![substr_arg, str_arg],
922                ))))
923            }
924
925            // TIME_STR_TO_UNIX -> UNIX_TIMESTAMP in MySQL
926            "TIME_STR_TO_UNIX" => Ok(Expression::Function(Box::new(Function::new(
927                "UNIX_TIMESTAMP".to_string(),
928                f.args,
929            )))),
930
931            // TIME_STR_TO_TIME -> CAST AS DATETIME(N) or TIMESTAMP() in MySQL
932            "TIME_STR_TO_TIME" if f.args.len() >= 1 => {
933                let mut args = f.args.into_iter();
934                let arg = args.next().unwrap();
935
936                // If there's a timezone arg, use TIMESTAMP() function instead
937                if args.next().is_some() {
938                    return Ok(Expression::Function(Box::new(Function::new(
939                        "TIMESTAMP".to_string(),
940                        vec![arg],
941                    ))));
942                }
943
944                // Extract sub-second precision from the string literal
945                let precision =
946                    if let Expression::Literal(crate::expressions::Literal::String(ref s)) = arg {
947                        // Find fractional seconds: look for .NNN pattern after HH:MM:SS
948                        if let Some(dot_pos) = s.rfind('.') {
949                            let after_dot = &s[dot_pos + 1..];
950                            // Count digits until non-digit
951                            let frac_digits =
952                                after_dot.chars().take_while(|c| c.is_ascii_digit()).count();
953                            if frac_digits > 0 {
954                                // Round up: 1-3 digits → 3, 4-6 digits → 6
955                                if frac_digits <= 3 {
956                                    Some(3)
957                                } else {
958                                    Some(6)
959                                }
960                            } else {
961                                None
962                            }
963                        } else {
964                            None
965                        }
966                    } else {
967                        None
968                    };
969
970                let type_name = match precision {
971                    Some(p) => format!("DATETIME({})", p),
972                    None => "DATETIME".to_string(),
973                };
974
975                Ok(Expression::Cast(Box::new(Cast {
976                    this: arg,
977                    to: DataType::Custom { name: type_name },
978                    trailing_comments: Vec::new(),
979                    double_colon_syntax: false,
980                    format: None,
981                    default: None,
982                })))
983            }
984
985            // UCASE -> UPPER in MySQL
986            "UCASE" => Ok(Expression::Function(Box::new(Function::new(
987                "UPPER".to_string(),
988                f.args,
989            )))),
990
991            // LCASE -> LOWER in MySQL
992            "LCASE" => Ok(Expression::Function(Box::new(Function::new(
993                "LOWER".to_string(),
994                f.args,
995            )))),
996
997            // DAY_OF_MONTH -> DAYOFMONTH in MySQL
998            "DAY_OF_MONTH" => Ok(Expression::Function(Box::new(Function::new(
999                "DAYOFMONTH".to_string(),
1000                f.args,
1001            )))),
1002
1003            // DAY_OF_WEEK -> DAYOFWEEK in MySQL
1004            "DAY_OF_WEEK" => Ok(Expression::Function(Box::new(Function::new(
1005                "DAYOFWEEK".to_string(),
1006                f.args,
1007            )))),
1008
1009            // DAY_OF_YEAR -> DAYOFYEAR in MySQL
1010            "DAY_OF_YEAR" => Ok(Expression::Function(Box::new(Function::new(
1011                "DAYOFYEAR".to_string(),
1012                f.args,
1013            )))),
1014
1015            // WEEK_OF_YEAR -> WEEKOFYEAR in MySQL
1016            "WEEK_OF_YEAR" => Ok(Expression::Function(Box::new(Function::new(
1017                "WEEKOFYEAR".to_string(),
1018                f.args,
1019            )))),
1020
1021            // MOD(x, y) -> x % y in MySQL
1022            "MOD" if f.args.len() == 2 => {
1023                let mut args = f.args;
1024                let left = args.remove(0);
1025                let right = args.remove(0);
1026                Ok(Expression::Mod(Box::new(BinaryOp {
1027                    left,
1028                    right,
1029                    left_comments: Vec::new(),
1030                    operator_comments: Vec::new(),
1031                    trailing_comments: Vec::new(),
1032                })))
1033            }
1034
1035            // PARSE_JSON -> strip in MySQL (just keep the string argument)
1036            "PARSE_JSON" if f.args.len() == 1 => Ok(f.args.into_iter().next().unwrap()),
1037
1038            // GET_PATH(obj, path) -> JSON_EXTRACT(obj, json_path) in MySQL
1039            "GET_PATH" if f.args.len() == 2 => {
1040                let mut args = f.args;
1041                let this = args.remove(0);
1042                let path = args.remove(0);
1043                let json_path = match &path {
1044                    Expression::Literal(Literal::String(s)) => {
1045                        // Convert bracket notation ["key"] to quoted dot notation ."key"
1046                        let s = Self::convert_bracket_to_quoted_path(s);
1047                        let normalized = if s.starts_with('$') {
1048                            s
1049                        } else if s.starts_with('[') {
1050                            format!("${}", s)
1051                        } else {
1052                            format!("$.{}", s)
1053                        };
1054                        Expression::Literal(Literal::String(normalized))
1055                    }
1056                    _ => path,
1057                };
1058                Ok(Expression::JsonExtract(Box::new(JsonExtractFunc {
1059                    this,
1060                    path: json_path,
1061                    returning: None,
1062                    arrow_syntax: false,
1063                    hash_arrow_syntax: false,
1064                    wrapper_option: None,
1065                    quotes_option: None,
1066                    on_scalar_string: false,
1067                    on_error: None,
1068                })))
1069            }
1070
1071            // REGEXP -> REGEXP_LIKE (MySQL standard form)
1072            "REGEXP" if f.args.len() >= 2 => Ok(Expression::Function(Box::new(Function::new(
1073                "REGEXP_LIKE".to_string(),
1074                f.args,
1075            )))),
1076
1077            // Pass through everything else
1078            _ => Ok(Expression::Function(Box::new(f))),
1079        }
1080    }
1081
1082    fn transform_aggregate_function(
1083        &self,
1084        f: Box<crate::expressions::AggregateFunction>,
1085    ) -> Result<Expression> {
1086        let name_upper = f.name.to_uppercase();
1087        match name_upper.as_str() {
1088            // STRING_AGG -> GROUP_CONCAT
1089            "STRING_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(
1090                Function::new("GROUP_CONCAT".to_string(), f.args),
1091            ))),
1092
1093            // ARRAY_AGG -> GROUP_CONCAT
1094            "ARRAY_AGG" if !f.args.is_empty() => Ok(Expression::Function(Box::new(Function::new(
1095                "GROUP_CONCAT".to_string(),
1096                f.args,
1097            )))),
1098
1099            // Pass through everything else
1100            _ => Ok(Expression::AggregateFunction(f)),
1101        }
1102    }
1103}
1104
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dialects::Dialect;

    /// Transpiles `sql` from the generic dialect to MySQL and returns the first statement.
    fn transpile_to_mysql(sql: &str) -> String {
        let generic = Dialect::get(DialectType::Generic);
        let statements = generic
            .transpile_to(sql, DialectType::MySQL)
            .expect("Transpile failed");
        statements[0].clone()
    }

    /// Asserts that the MySQL transpilation of `sql` contains `fragment`.
    /// `expectation` is the text placed after "Expected " in the failure message.
    fn assert_mysql_contains(sql: &str, fragment: &str, expectation: &str) {
        let result = transpile_to_mysql(sql);
        assert!(
            result.contains(fragment),
            "Expected {}, got: {}",
            expectation,
            result
        );
    }

    /// Parses, transforms and regenerates `sql` with the MySQL dialect and checks that the
    /// output equals `expected`.
    fn mysql_identity(sql: &str, expected: &str) {
        let mysql = Dialect::get(DialectType::MySQL);
        let parsed = mysql.parse(sql).expect("Parse failed");
        let rewritten = mysql
            .transform(parsed[0].clone())
            .expect("Transform failed");
        let rendered = mysql.generate(&rewritten).expect("Generate failed");
        assert_eq!(rendered, expected, "SQL: {}", sql);
    }

    #[test]
    fn test_nvl_to_ifnull() {
        assert_mysql_contains("SELECT NVL(a, b)", "IFNULL", "IFNULL");
    }

    #[test]
    fn test_coalesce_preserved() {
        // COALESCE is a native MySQL function, so it must survive transpilation.
        assert_mysql_contains(
            "SELECT COALESCE(a, b)",
            "COALESCE",
            "COALESCE to be preserved",
        );
    }

    #[test]
    fn test_random_to_rand() {
        assert_mysql_contains("SELECT RANDOM()", "RAND", "RAND");
    }

    #[test]
    fn test_basic_select() {
        let result = transpile_to_mysql("SELECT a, b FROM users WHERE id = 1");
        assert!(result.contains("SELECT"));
        assert!(result.contains("FROM users"));
    }

    #[test]
    fn test_string_agg_to_group_concat() {
        assert_mysql_contains("SELECT STRING_AGG(name)", "GROUP_CONCAT", "GROUP_CONCAT");
    }

    #[test]
    fn test_array_agg_to_group_concat() {
        assert_mysql_contains("SELECT ARRAY_AGG(name)", "GROUP_CONCAT", "GROUP_CONCAT");
    }

    #[test]
    fn test_to_date_to_str_to_date() {
        assert_mysql_contains(
            "SELECT TO_DATE('2023-01-01')",
            "STR_TO_DATE",
            "STR_TO_DATE",
        );
    }

    #[test]
    fn test_backtick_identifiers() {
        // The MySQL generator configuration quotes identifiers with backticks.
        let config = MySQLDialect.generator_config();
        assert_eq!(config.identifier_quote, '`');
    }

    #[test]
    fn test_ucase_to_upper() {
        mysql_identity("SELECT UCASE('foo')", "SELECT UPPER('foo')");
    }

    #[test]
    fn test_lcase_to_lower() {
        mysql_identity("SELECT LCASE('foo')", "SELECT LOWER('foo')");
    }

    #[test]
    fn test_day_of_month() {
        mysql_identity(
            "SELECT DAY_OF_MONTH('2023-01-01')",
            "SELECT DAYOFMONTH('2023-01-01')",
        );
    }

    #[test]
    fn test_day_of_week() {
        mysql_identity(
            "SELECT DAY_OF_WEEK('2023-01-01')",
            "SELECT DAYOFWEEK('2023-01-01')",
        );
    }

    #[test]
    fn test_day_of_year() {
        mysql_identity(
            "SELECT DAY_OF_YEAR('2023-01-01')",
            "SELECT DAYOFYEAR('2023-01-01')",
        );
    }

    #[test]
    fn test_week_of_year() {
        mysql_identity(
            "SELECT WEEK_OF_YEAR('2023-01-01')",
            "SELECT WEEKOFYEAR('2023-01-01')",
        );
    }

    #[test]
    fn test_mod_func_to_percent() {
        // The MOD(x, y) function form is rendered with the % operator.
        mysql_identity("MOD(x, y)", "x % y");
    }

    #[test]
    fn test_database_to_schema() {
        mysql_identity("DATABASE()", "SCHEMA()");
    }

    #[test]
    fn test_and_operator() {
        mysql_identity("SELECT 1 && 0", "SELECT 1 AND 0");
    }

    #[test]
    fn test_or_operator() {
        mysql_identity("SELECT a || b", "SELECT a OR b");
    }
}