Skip to main content

nautilus_dialect/
mysql.rs

1//! MySQL SQL dialect renderer.
2
3use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Error, Expr, Insert, Result, Select, Update, Value};
5
/// MySQL SQL dialect renderer.
///
/// A stateless unit type. Its [`Dialect`] implementation writes `?` as the
/// placeholder for every bound parameter, quotes identifiers with backticks,
/// and reports that `RETURNING` is not supported.
#[derive(Debug, Clone, Copy)]
pub struct MysqlDialect;
9
/// Renders query ASTs into MySQL-compatible SQL with `?` placeholders
/// and backtick-quoted identifiers.
///
/// Every `render_*` method has the same shape: it creates a `RenderContext`
/// pre-sized from the matching `crate::estimate_*_render` result, expands the
/// shared body macro with the backtick as the identifier quote character, and
/// returns `ctx.finish()`, which yields the first error recorded while
/// rendering or the finished SQL text plus parameters. The `*_owned` variants
/// take the statement by value and hand the macro a mutable borrow of it.
///
/// NOTE(review): the trailing boolean arguments are interpreted by the shared
/// `render_*_body*!` macros, which are defined outside this file; confirm
/// their meaning there before changing them.
impl Dialect for MysqlDialect {
    /// Always `false`: this dialect reports no support for `RETURNING`.
    fn supports_returning(&self) -> bool {
        false
    }

    /// Renders a borrowed `SELECT`; expressions go through `render_expr`,
    /// which clones each bound parameter value.
    fn render_select(&self, select: &Select) -> Result<Sql> {
        let mut ctx = RenderContext::with_estimate(crate::estimate_select_render(select));
        render_select_body_core!(&mut ctx, select, '`', render_expr, false, true);
        ctx.finish()
    }

    /// Renders an owned `SELECT`; expressions go through `render_expr_owned`,
    /// which moves bound parameter values out of the statement.
    fn render_select_owned(&self, mut select: Select) -> Result<Sql> {
        let mut ctx = RenderContext::with_estimate(crate::estimate_select_render(&select));
        render_select_body_core_mut!(&mut ctx, &mut select, '`', render_expr_owned, false, true);
        ctx.finish()
    }

    /// Renders a borrowed `INSERT`.
    fn render_insert(&self, insert: &Insert) -> Result<Sql> {
        let mut ctx = RenderContext::with_estimate(crate::estimate_insert_render(insert));
        render_insert_body!(&mut ctx, insert, '`', false, false);
        ctx.finish()
    }

    /// Renders an owned `INSERT`, giving the macro mutable access to it.
    fn render_insert_owned(&self, mut insert: Insert) -> Result<Sql> {
        let mut ctx = RenderContext::with_estimate(crate::estimate_insert_render(&insert));
        render_insert_body_mut!(&mut ctx, &mut insert, '`', false, false);
        ctx.finish()
    }

    /// Renders a borrowed `UPDATE`; expressions go through `render_expr`.
    fn render_update(&self, update: &Update) -> Result<Sql> {
        let mut ctx = RenderContext::with_estimate(crate::estimate_update_render(update));
        render_update_body!(&mut ctx, update, '`', render_expr, false, false);
        ctx.finish()
    }

    /// Renders an owned `UPDATE`; expressions go through `render_expr_owned`.
    fn render_update_owned(&self, mut update: Update) -> Result<Sql> {
        let mut ctx = RenderContext::with_estimate(crate::estimate_update_render(&update));
        render_update_body_mut!(&mut ctx, &mut update, '`', render_expr_owned, false, false);
        ctx.finish()
    }

    /// Renders a borrowed `DELETE`; expressions go through `render_expr`.
    fn render_delete(&self, delete: &Delete) -> Result<Sql> {
        let mut ctx = RenderContext::with_estimate(crate::estimate_delete_render(delete));
        render_delete_body!(&mut ctx, delete, '`', render_expr, false);
        ctx.finish()
    }

    /// Renders an owned `DELETE`; expressions go through `render_expr_owned`.
    fn render_delete_owned(&self, mut delete: Delete) -> Result<Sql> {
        let mut ctx = RenderContext::with_estimate(crate::estimate_delete_render(&delete));
        render_delete_body_mut!(&mut ctx, &mut delete, '`', render_expr_owned, false);
        ctx.finish()
    }
}
65
/// Mutable state accumulated while rendering one statement.
struct RenderContext {
    /// SQL text written so far.
    sql: String,
    /// Bound parameter values, in the order their `?` placeholders were written.
    params: Vec<Value>,
    /// First error recorded during rendering, if any; later failures are ignored.
    error: Option<Error>,
}
71
72impl RenderContext {
73    fn with_estimate(estimate: crate::RenderEstimate) -> Self {
74        Self {
75            sql: String::with_capacity(estimate.sql_capacity),
76            params: Vec::with_capacity(estimate.params_capacity),
77            error: None,
78        }
79    }
80
81    fn push_param(&mut self, value: Value) {
82        self.params.push(value);
83        self.sql.push('?');
84    }
85
86    fn take_param(&mut self, value: &mut Value) {
87        self.push_param(std::mem::replace(value, Value::Null));
88    }
89
90    fn fail(&mut self, message: impl Into<String>) {
91        if self.error.is_none() {
92            self.error = Some(Error::InvalidQuery(message.into()));
93        }
94    }
95
96    fn finish(self) -> Result<Sql> {
97        if let Some(err) = self.error {
98            return Err(err);
99        }
100
101        Ok(Sql {
102            text: self.sql,
103            params: self.params,
104        })
105    }
106}
107
/// Renders the body of `select` into `ctx` with backtick-quoted identifiers,
/// using `render_expr` for expressions.
///
/// Function-form wrapper around `render_select_body_core!`; `render_expr`
/// passes it to `render_expr_common!` as the SELECT renderer.
/// NOTE(review): presumably used for SELECTs nested inside expressions;
/// confirm in the macro definition.
fn render_select_body(ctx: &mut RenderContext, select: &crate::Select) {
    render_select_body_core!(ctx, select, '`', render_expr, false, true);
}
111
/// Renders the body of `select` into `ctx` with backtick-quoted identifiers,
/// using `render_expr_owned` for expressions so bound values can be moved out
/// of the statement.
///
/// Function-form wrapper around `render_select_body_core_mut!`;
/// `render_expr_owned` passes it to `render_expr_common_mut!` as the SELECT
/// renderer.
/// NOTE(review): presumably used for SELECTs nested inside expressions;
/// confirm in the macro definition.
fn render_select_body_owned(ctx: &mut RenderContext, select: &mut crate::Select) {
    render_select_body_core_mut!(ctx, select, '`', render_expr_owned, false, true);
}
115
/// Translates function names that MySQL spells differently.
///
/// * `json_agg` becomes `JSON_ARRAYAGG`
/// * `json_build_object` becomes `JSON_OBJECT`
///
/// Names are compared ASCII case-insensitively, matching how `render_filter`
/// recognises function names; SQL function names are not case-sensitive, so
/// `JSON_AGG` must be translated just like `json_agg`. Any other name is
/// returned unchanged, borrowed from `name`.
fn mysql_function_name(name: &str) -> &str {
    if name.eq_ignore_ascii_case("json_agg") {
        "JSON_ARRAYAGG"
    } else if name.eq_ignore_ascii_case("json_build_object") {
        "JSON_OBJECT"
    } else {
        name
    }
}
123
124fn render_case_filtered_aggregate(
125    ctx: &mut RenderContext,
126    fn_name: &str,
127    arg: &Expr,
128    predicate: &Expr,
129) {
130    ctx.sql.push_str(fn_name);
131    ctx.sql.push_str("(CASE WHEN ");
132    render_expr(ctx, predicate);
133    ctx.sql.push_str(" THEN ");
134    render_expr(ctx, arg);
135    ctx.sql.push_str(" ELSE NULL END)");
136}
137
138fn render_case_filtered_aggregate_owned(
139    ctx: &mut RenderContext,
140    fn_name: &str,
141    arg: &mut Expr,
142    predicate: &mut Expr,
143) {
144    ctx.sql.push_str(fn_name);
145    ctx.sql.push_str("(CASE WHEN ");
146    render_expr_owned(ctx, predicate);
147    ctx.sql.push_str(" THEN ");
148    render_expr_owned(ctx, arg);
149    ctx.sql.push_str(" ELSE NULL END)");
150}
151
152fn render_filter(ctx: &mut RenderContext, expr: &Expr, predicate: &Expr) {
153    let Expr::FunctionCall { name, args } = expr else {
154        ctx.fail("MysqlDialect can only emulate FILTER for aggregate function calls");
155        return;
156    };
157
158    let upper = name.to_ascii_uppercase();
159    match (upper.as_str(), args.as_slice()) {
160        ("COUNT", [Expr::Star]) => {
161            ctx.sql.push_str("COUNT(CASE WHEN ");
162            render_expr(ctx, predicate);
163            ctx.sql.push_str(" THEN 1 ELSE NULL END)");
164        }
165        ("COUNT", [arg]) | ("SUM", [arg]) | ("AVG", [arg]) | ("MIN", [arg]) | ("MAX", [arg]) => {
166            render_case_filtered_aggregate(ctx, upper.as_str(), arg, predicate);
167        }
168        ("JSON_AGG", [_]) => {
169            ctx.fail(
170                "MysqlDialect cannot emulate FILTER for json_agg without changing JSON null semantics",
171            );
172        }
173        (_, []) => {
174            ctx.fail(format!(
175                "MysqlDialect cannot emulate FILTER for function '{}' with zero arguments",
176                name
177            ));
178        }
179        _ => {
180            ctx.fail(format!(
181                "MysqlDialect cannot emulate FILTER for function '{}' with {} arguments",
182                name,
183                args.len()
184            ));
185        }
186    }
187}
188
189fn render_filter_owned(ctx: &mut RenderContext, expr: &mut Expr, predicate: &mut Expr) {
190    let Expr::FunctionCall { name, args } = expr else {
191        ctx.fail("MysqlDialect can only emulate FILTER for aggregate function calls");
192        return;
193    };
194
195    let upper = name.to_ascii_uppercase();
196    match (upper.as_str(), args.as_mut_slice()) {
197        ("COUNT", [Expr::Star]) => {
198            ctx.sql.push_str("COUNT(CASE WHEN ");
199            render_expr_owned(ctx, predicate);
200            ctx.sql.push_str(" THEN 1 ELSE NULL END)");
201        }
202        ("COUNT", [arg]) | ("SUM", [arg]) | ("AVG", [arg]) | ("MIN", [arg]) | ("MAX", [arg]) => {
203            render_case_filtered_aggregate_owned(ctx, upper.as_str(), arg, predicate);
204        }
205        ("JSON_AGG", [_]) => {
206            ctx.fail(
207                "MysqlDialect cannot emulate FILTER for json_agg without changing JSON null semantics",
208            );
209        }
210        (_, []) => {
211            ctx.fail(format!(
212                "MysqlDialect cannot emulate FILTER for function '{}' with zero arguments",
213                name
214            ));
215        }
216        _ => {
217            ctx.fail(format!(
218                "MysqlDialect cannot emulate FILTER for function '{}' with {} arguments",
219                name,
220                args.len()
221            ));
222        }
223    }
224}
225
226fn render_expr(ctx: &mut RenderContext, expr: &Expr) {
227    if ctx.error.is_some() {
228        return;
229    }
230
231    render_expr_common!(ctx, expr, '`', render_expr, render_select_body, {
232        Expr::Param(value) => {
233            if matches!(value, Value::Null) {
234                ctx.sql.push_str("NULL");
235            } else {
236                ctx.push_param(value.clone());
237            }
238        }
239        Expr::Binary { left, op, right } => {
240            if matches!(op, BinaryOp::In | BinaryOp::NotIn) {
241                ctx.sql.push('(');
242                render_expr(ctx, left);
243                ctx.sql.push(' ');
244                ctx.sql.push_str(if matches!(op, BinaryOp::In) { "IN" } else { "NOT IN" });
245                ctx.sql.push_str(" (");
246                if let Expr::List(exprs) = right.as_ref() {
247                    for (i, e) in exprs.iter().enumerate() {
248                        if i > 0 { ctx.sql.push_str(", "); }
249                        render_expr(ctx, e);
250                    }
251                } else {
252                    render_expr(ctx, right);
253                }
254                ctx.sql.push(')');
255                ctx.sql.push(')');
256            } else if matches!(op, BinaryOp::ArrayContains | BinaryOp::ArrayContainedBy | BinaryOp::ArrayOverlaps) {
257                // Array operators emulated via MySQL JSON functions.
258                // Arrays are bound as JSON strings by the connector layer.
259                match op {
260                    BinaryOp::ArrayContains => {
261                        // col @> rhs: col contains every element of rhs.
262                        // JSON_CONTAINS(target, candidate) returns 1 when the candidate is a subset of the target.
263                        ctx.sql.push_str("JSON_CONTAINS(");
264                        render_expr(ctx, left);
265                        ctx.sql.push_str(", ");
266                        render_expr(ctx, right);
267                        ctx.sql.push(')');
268                    }
269                    BinaryOp::ArrayContainedBy => {
270                        // col <@ rhs: rhs contains every element of col.
271                        ctx.sql.push_str("JSON_CONTAINS(");
272                        render_expr(ctx, right);
273                        ctx.sql.push_str(", ");
274                        render_expr(ctx, left);
275                        ctx.sql.push(')');
276                    }
277                    BinaryOp::ArrayOverlaps => {
278                        ctx.fail(
279                            "MysqlDialect does not render ArrayOverlaps generically because JSON_OVERLAPS is unavailable on some supported MySQL-family backends",
280                        );
281                    }
282                    _ => unreachable!(),
283                }
284            } else {
285                ctx.sql.push('(');
286                render_expr(ctx, left);
287                ctx.sql.push(' ');
288                ctx.sql.push_str(crate::binary_op_sql(op));
289                ctx.sql.push(' ');
290                render_expr(ctx, right);
291                ctx.sql.push(')');
292            }
293        }
294        Expr::FunctionCall { name, args } => {
295            let mysql_name = mysql_function_name(name);
296            ctx.sql.push_str(mysql_name);
297            ctx.sql.push('(');
298            for (i, arg) in args.iter().enumerate() {
299                if i > 0 { ctx.sql.push_str(", "); }
300                render_expr(ctx, arg);
301            }
302            ctx.sql.push(')');
303        }
304        Expr::Filter { expr, predicate } => {
305            render_filter(ctx, expr, predicate);
306        }
307    });
308}
309
310fn render_expr_owned(ctx: &mut RenderContext, expr: &mut Expr) {
311    if ctx.error.is_some() {
312        return;
313    }
314
315    render_expr_common_mut!(ctx, expr, '`', render_expr_owned, render_select_body_owned, {
316        Expr::Param(value) => {
317            if matches!(value, Value::Null) {
318                ctx.sql.push_str("NULL");
319            } else {
320                ctx.take_param(value);
321            }
322        }
323        Expr::Binary { left, op, right } => {
324            if matches!(*op, BinaryOp::In | BinaryOp::NotIn) {
325                ctx.sql.push('(');
326                render_expr_owned(ctx, left.as_mut());
327                ctx.sql.push(' ');
328                ctx.sql
329                    .push_str(if matches!(*op, BinaryOp::In) { "IN" } else { "NOT IN" });
330                ctx.sql.push_str(" (");
331                if let Expr::List(exprs) = right.as_mut() {
332                    for (i, e) in exprs.iter_mut().enumerate() {
333                        if i > 0 {
334                            ctx.sql.push_str(", ");
335                        }
336                        render_expr_owned(ctx, e);
337                    }
338                } else {
339                    render_expr_owned(ctx, right.as_mut());
340                }
341                ctx.sql.push(')');
342                ctx.sql.push(')');
343            } else if matches!(
344                *op,
345                BinaryOp::ArrayContains | BinaryOp::ArrayContainedBy | BinaryOp::ArrayOverlaps
346            ) {
347                match *op {
348                    BinaryOp::ArrayContains => {
349                        ctx.sql.push_str("JSON_CONTAINS(");
350                        render_expr_owned(ctx, left.as_mut());
351                        ctx.sql.push_str(", ");
352                        render_expr_owned(ctx, right.as_mut());
353                        ctx.sql.push(')');
354                    }
355                    BinaryOp::ArrayContainedBy => {
356                        ctx.sql.push_str("JSON_CONTAINS(");
357                        render_expr_owned(ctx, right.as_mut());
358                        ctx.sql.push_str(", ");
359                        render_expr_owned(ctx, left.as_mut());
360                        ctx.sql.push(')');
361                    }
362                    BinaryOp::ArrayOverlaps => {
363                        ctx.fail(
364                            "MysqlDialect does not render ArrayOverlaps generically because JSON_OVERLAPS is unavailable on some supported MySQL-family backends",
365                        );
366                    }
367                    _ => unreachable!(),
368                }
369            } else {
370                ctx.sql.push('(');
371                render_expr_owned(ctx, left.as_mut());
372                ctx.sql.push(' ');
373                ctx.sql.push_str(crate::binary_op_sql(op));
374                ctx.sql.push(' ');
375                render_expr_owned(ctx, right.as_mut());
376                ctx.sql.push(')');
377            }
378        }
379        Expr::FunctionCall { name, args } => {
380            let mysql_name = mysql_function_name(name);
381            ctx.sql.push_str(mysql_name);
382            ctx.sql.push('(');
383            for (i, arg) in args.iter_mut().enumerate() {
384                if i > 0 {
385                    ctx.sql.push_str(", ");
386                }
387                render_expr_owned(ctx, arg);
388            }
389            ctx.sql.push(')');
390        }
391        Expr::Filter { expr, predicate } => {
392            render_filter_owned(ctx, expr.as_mut(), predicate.as_mut());
393        }
394    });
395}
396
/// Unit tests for the MySQL renderer: identifier quoting, `LIMIT`/`OFFSET`
/// output, omission of `RETURNING`, `FILTER` emulation, and rejection of
/// constructs the dialect refuses to render.
#[cfg(test)]
mod tests {
    use super::*;

    /// Test helper: quotes `name` with backticks via the crate-level
    /// `push_quoted_identifier` and returns the result as a new `String`.
    fn quote_identifier(name: &str) -> String {
        let mut sql = String::new();
        crate::push_quoted_identifier(&mut sql, name, '`');
        sql
    }

    /// Identifiers are wrapped in backticks and embedded backticks are doubled.
    #[test]
    fn test_quote_identifier() {
        assert_eq!(quote_identifier("users"), "`users`");
        assert_eq!(quote_identifier("email"), "`email`");
        assert_eq!(quote_identifier("foo`bar"), "`foo``bar`");
        assert_eq!(quote_identifier("a`b`c"), "`a``b``c`");
    }

    /// A skip without a take still emits a `LIMIT`, using `u64::MAX`
    /// (18446744073709551615) as the row count so only the offset applies.
    #[test]
    fn test_skip_without_take() {
        let dialect = MysqlDialect;
        let select = Select::from_table("users").skip(20).build().unwrap();
        let sql = dialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM `users` LIMIT 18446744073709551615 OFFSET 20"
        );
        assert!(sql.params.is_empty());
    }

    /// A `returning` list on an INSERT is dropped from the rendered SQL.
    #[test]
    fn test_insert_returning_is_omitted() {
        let dialect = MysqlDialect;
        let insert = Insert::into_table("users")
            .column(nautilus_core::ColumnMarker::new("users", "email"))
            .values(vec![Value::String("alice@example.com".to_string())])
            .returning(vec![
                nautilus_core::ColumnMarker::new("users", "id"),
                nautilus_core::ColumnMarker::new("users", "email"),
            ])
            .build()
            .unwrap();
        let sql = dialect.render_insert(&insert).unwrap();

        assert_eq!(sql.text, "INSERT INTO `users` (`email`) VALUES (?)");
        assert!(!sql.text.contains("RETURNING"));
    }

    /// A `returning` list on an UPDATE is dropped from the rendered SQL.
    #[test]
    fn test_update_returning_is_omitted() {
        let dialect = MysqlDialect;
        let update = Update::table("users")
            .set(
                nautilus_core::ColumnMarker::new("users", "email"),
                Value::String("new@example.com".to_string()),
            )
            .filter(Expr::column("id").eq(Expr::param(Value::I64(1))))
            .returning(vec![
                nautilus_core::ColumnMarker::new("users", "id"),
                nautilus_core::ColumnMarker::new("users", "email"),
            ])
            .build()
            .unwrap();
        let sql = dialect.render_update(&update).unwrap();

        assert_eq!(sql.text, "UPDATE `users` SET `email` = ? WHERE (`id` = ?)");
        assert!(!sql.text.contains("RETURNING"));
    }

    /// A `returning` list on a DELETE is dropped from the rendered SQL.
    #[test]
    fn test_delete_returning_is_omitted() {
        let dialect = MysqlDialect;
        let delete = Delete::from_table("users")
            .filter(Expr::column("id").eq(Expr::param(Value::I64(1))))
            .returning(vec![
                nautilus_core::ColumnMarker::new("users", "id"),
                nautilus_core::ColumnMarker::new("users", "email"),
            ])
            .build()
            .unwrap();
        let sql = dialect.render_delete(&delete).unwrap();

        assert_eq!(sql.text, "DELETE FROM `users` WHERE (`id` = ?)");
        assert!(!sql.text.contains("RETURNING"));
    }

    /// `COUNT(*) FILTER (...)` is rewritten to count a constant inside a
    /// `CASE`, and the predicate's parameter is still bound.
    #[test]
    fn test_filter_count_star_is_emulated() {
        let dialect = MysqlDialect;
        let select = Select::from_table("users")
            .computed(
                Expr::function_call("COUNT", vec![Expr::star()])
                    .filter(Expr::column("active").eq(Expr::param(Value::Bool(true)))),
                "active_count",
            )
            .build()
            .unwrap();

        let sql = dialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT (COUNT(CASE WHEN (`active` = ?) THEN 1 ELSE NULL END)) AS `active_count` FROM `users`"
        );
        assert_eq!(sql.params, vec![Value::Bool(true)]);
    }

    /// A filtered single-argument aggregate wraps its argument in a `CASE`.
    #[test]
    fn test_filter_single_arg_aggregate_is_emulated() {
        let dialect = MysqlDialect;
        let select = Select::from_table("users")
            .computed(
                Expr::function_call("SUM", vec![Expr::column("score")])
                    .filter(Expr::column("active").eq(Expr::param(Value::Bool(true)))),
                "active_score",
            )
            .build()
            .unwrap();

        let sql = dialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT (SUM(CASE WHEN (`active` = ?) THEN `score` ELSE NULL END)) AS `active_score` FROM `users`"
        );
        assert_eq!(sql.params, vec![Value::Bool(true)]);
    }

    /// `FILTER` on a function with more than one argument is an error whose
    /// message names the original (untranslated) function.
    #[test]
    fn test_filter_multi_arg_function_is_rejected() {
        let dialect = MysqlDialect;
        let select = Select::from_table("users")
            .computed(
                Expr::function_call(
                    "json_build_object",
                    vec![Expr::Literal("score".to_string()), Expr::column("score")],
                )
                .filter(Expr::column("active").eq(Expr::param(Value::Bool(true)))),
                "payload",
            )
            .build()
            .unwrap();

        let err = dialect.render_select(&select).unwrap_err();
        assert!(err
            .to_string()
            .contains("cannot emulate FILTER for function 'json_build_object'"));
    }

    /// `ArrayOverlaps` is rejected with an error rather than rendered.
    #[test]
    fn test_array_overlaps_is_rejected() {
        let dialect = MysqlDialect;
        let expr = Expr::Binary {
            left: Box::new(Expr::column("posts__tags")),
            op: BinaryOp::ArrayOverlaps,
            right: Box::new(Expr::param(Value::Array(vec![Value::String(
                "rust".to_string(),
            )]))),
        };
        let select = Select::from_table("posts").filter(expr).build().unwrap();

        let err = dialect.render_select(&select).unwrap_err();
        assert!(err.to_string().contains("ArrayOverlaps generically"));
    }
}
562}