Skip to main content

nautilus_dialect/
mysql.rs

//! MySQL SQL dialect renderer: turns query ASTs into SQL text with `?`
//! placeholders and backtick-quoted identifiers.
2
3use crate::{Dialect, Sql};
4use nautilus_core::{
5    BinaryOp, Delete, Error, Expr, Insert, JsonPathCast, Result, Select, Update, Value,
6};
7
8/// MySQL SQL dialect renderer.
9#[derive(Debug, Clone, Copy)]
10pub struct MysqlDialect;
11
12/// Renders query ASTs into MySQL-compatible SQL with `?` placeholders
13/// and backtick-quoted identifiers.
14impl Dialect for MysqlDialect {
15    fn supports_returning(&self) -> bool {
16        false
17    }
18
19    fn render_select_owned(&self, mut select: Select) -> Result<Sql> {
20        let mut ctx = RenderContext::with_estimate(crate::estimate_select_render(&select));
21        render_select_body_core_mut!(&mut ctx, &mut select, '`', render_expr_owned, false, true);
22        ctx.finish()
23    }
24
25    fn render_insert_owned(&self, mut insert: Insert) -> Result<Sql> {
26        let mut ctx = RenderContext::with_estimate(crate::estimate_insert_render(&insert));
27        render_insert_body_mut!(&mut ctx, &mut insert, '`', false, false);
28        ctx.finish()
29    }
30
31    fn render_update_owned(&self, mut update: Update) -> Result<Sql> {
32        let mut ctx = RenderContext::with_estimate(crate::estimate_update_render(&update));
33        render_update_body_mut!(&mut ctx, &mut update, '`', render_expr_owned, false, false);
34        ctx.finish()
35    }
36
37    fn render_delete_owned(&self, mut delete: Delete) -> Result<Sql> {
38        let mut ctx = RenderContext::with_estimate(crate::estimate_delete_render(&delete));
39        render_delete_body_mut!(&mut ctx, &mut delete, '`', render_expr_owned, false);
40        ctx.finish()
41    }
42}
43
44struct RenderContext {
45    sql: String,
46    params: Vec<Value>,
47    error: Option<Error>,
48}
49
50impl RenderContext {
51    fn with_estimate(estimate: crate::RenderEstimate) -> Self {
52        Self {
53            sql: String::with_capacity(estimate.sql_capacity),
54            params: Vec::with_capacity(estimate.params_capacity),
55            error: None,
56        }
57    }
58
59    fn push_param(&mut self, value: Value) {
60        self.params.push(value);
61        self.sql.push('?');
62    }
63
64    fn take_param(&mut self, value: &mut Value) {
65        self.push_param(std::mem::replace(value, Value::Null));
66    }
67
68    fn fail(&mut self, message: impl Into<String>) {
69        if self.error.is_none() {
70            self.error = Some(Error::InvalidQuery(message.into()));
71        }
72    }
73
74    fn finish(self) -> Result<Sql> {
75        if let Some(err) = self.error {
76            return Err(err);
77        }
78
79        Ok(Sql {
80            text: self.sql,
81            params: self.params,
82        })
83    }
84}
85
86fn render_select_body_owned(ctx: &mut RenderContext, select: &mut crate::Select) {
87    render_select_body_core_mut!(ctx, select, '`', render_expr_owned, false, true);
88}
89
/// Translates a portable function name into its MySQL equivalent.
///
/// `json_agg` becomes `JSON_ARRAYAGG` and `json_build_object` becomes
/// `JSON_OBJECT`. The comparison ignores ASCII case so that `JSON_AGG` and
/// `json_agg` are treated alike, matching how `render_filter_owned`
/// normalises function names before inspecting them. Any other name is
/// returned unchanged.
fn mysql_function_name(name: &str) -> &str {
    if name.eq_ignore_ascii_case("json_agg") {
        "JSON_ARRAYAGG"
    } else if name.eq_ignore_ascii_case("json_build_object") {
        "JSON_OBJECT"
    } else {
        name
    }
}
97
98fn render_case_filtered_aggregate_owned(
99    ctx: &mut RenderContext,
100    fn_name: &str,
101    arg: &mut Expr,
102    predicate: &mut Expr,
103) {
104    ctx.sql.push_str(fn_name);
105    ctx.sql.push_str("(CASE WHEN ");
106    render_expr_owned(ctx, predicate);
107    ctx.sql.push_str(" THEN ");
108    render_expr_owned(ctx, arg);
109    ctx.sql.push_str(" ELSE NULL END)");
110}
111
112fn render_filter_owned(ctx: &mut RenderContext, expr: &mut Expr, predicate: &mut Expr) {
113    let Expr::FunctionCall { name, args } = expr else {
114        ctx.fail("MysqlDialect can only emulate FILTER for aggregate function calls");
115        return;
116    };
117
118    let upper = name.to_ascii_uppercase();
119    match (upper.as_str(), args.as_mut_slice()) {
120        ("COUNT", [Expr::Star]) => {
121            ctx.sql.push_str("COUNT(CASE WHEN ");
122            render_expr_owned(ctx, predicate);
123            ctx.sql.push_str(" THEN 1 ELSE NULL END)");
124        }
125        ("COUNT", [arg]) | ("SUM", [arg]) | ("AVG", [arg]) | ("MIN", [arg]) | ("MAX", [arg]) => {
126            render_case_filtered_aggregate_owned(ctx, upper.as_str(), arg, predicate);
127        }
128        ("JSON_AGG", [_]) => {
129            ctx.fail(
130                "MysqlDialect cannot emulate FILTER for json_agg without changing JSON null semantics",
131            );
132        }
133        (_, []) => {
134            ctx.fail(format!(
135                "MysqlDialect cannot emulate FILTER for function '{}' with zero arguments",
136                name
137            ));
138        }
139        _ => {
140            ctx.fail(format!(
141                "MysqlDialect cannot emulate FILTER for function '{}' with {} arguments",
142                name,
143                args.len()
144            ));
145        }
146    }
147}
148
149fn render_json_extract_unquoted(ctx: &mut RenderContext, table: &str, column: &str, key: &str) {
150    ctx.sql.push_str("JSON_UNQUOTE(JSON_EXTRACT(");
151    crate::push_qualified_identifier(&mut ctx.sql, table, column, '`');
152    ctx.sql.push_str(", ");
153    crate::push_json_object_path_literal(&mut ctx.sql, key);
154    ctx.sql.push_str("))");
155}
156
157fn render_composite_field_owned(
158    ctx: &mut RenderContext,
159    table: &str,
160    column: &str,
161    key: &str,
162    cast: JsonPathCast,
163) {
164    match cast {
165        JsonPathCast::None => render_json_extract_unquoted(ctx, table, column, key),
166        JsonPathCast::Signed => {
167            ctx.sql.push_str("CAST(");
168            render_json_extract_unquoted(ctx, table, column, key);
169            ctx.sql.push_str(" AS SIGNED)");
170        }
171        JsonPathCast::Double => {
172            ctx.sql.push_str("CAST(");
173            render_json_extract_unquoted(ctx, table, column, key);
174            ctx.sql.push_str(" AS DOUBLE)");
175        }
176        JsonPathCast::Decimal => {
177            ctx.sql.push_str("CAST(");
178            render_json_extract_unquoted(ctx, table, column, key);
179            ctx.sql.push_str(" AS DECIMAL(65, 30))");
180        }
181    }
182}
183
184fn render_expr_owned(ctx: &mut RenderContext, expr: &mut Expr) {
185    if ctx.error.is_some() {
186        return;
187    }
188
189    render_expr_common_mut!(ctx, expr, '`', render_expr_owned, render_select_body_owned, {
190        Expr::CompositeField {
191            table,
192            column,
193            json_key,
194            json_cast,
195            ..
196        } => {
197            render_composite_field_owned(ctx, table, column, json_key, *json_cast);
198        }
199        Expr::Param(value) => {
200            if matches!(value, Value::Null) {
201                ctx.sql.push_str("NULL");
202            } else {
203                ctx.take_param(value);
204            }
205        }
206        Expr::Binary { left, op, right } => {
207            if matches!(*op, BinaryOp::In | BinaryOp::NotIn) {
208                ctx.sql.push('(');
209                render_expr_owned(ctx, left.as_mut());
210                ctx.sql.push(' ');
211                ctx.sql
212                    .push_str(if matches!(*op, BinaryOp::In) { "IN" } else { "NOT IN" });
213                ctx.sql.push_str(" (");
214                if let Expr::List(exprs) = right.as_mut() {
215                    for (i, e) in exprs.iter_mut().enumerate() {
216                        if i > 0 {
217                            ctx.sql.push_str(", ");
218                        }
219                        render_expr_owned(ctx, e);
220                    }
221                } else {
222                    render_expr_owned(ctx, right.as_mut());
223                }
224                ctx.sql.push(')');
225                ctx.sql.push(')');
226            } else if matches!(
227                *op,
228                BinaryOp::ArrayContains | BinaryOp::ArrayContainedBy | BinaryOp::ArrayOverlaps
229            ) {
230                match *op {
231                    BinaryOp::ArrayContains => {
232                        ctx.sql.push_str("JSON_CONTAINS(");
233                        render_expr_owned(ctx, left.as_mut());
234                        ctx.sql.push_str(", ");
235                        render_expr_owned(ctx, right.as_mut());
236                        ctx.sql.push(')');
237                    }
238                    BinaryOp::ArrayContainedBy => {
239                        ctx.sql.push_str("JSON_CONTAINS(");
240                        render_expr_owned(ctx, right.as_mut());
241                        ctx.sql.push_str(", ");
242                        render_expr_owned(ctx, left.as_mut());
243                        ctx.sql.push(')');
244                    }
245                    BinaryOp::ArrayOverlaps => {
246                        ctx.fail(
247                            "MysqlDialect does not render ArrayOverlaps generically because JSON_OVERLAPS is unavailable on some supported MySQL-family backends",
248                        );
249                    }
250                    _ => unreachable!(),
251                }
252            } else {
253                ctx.sql.push('(');
254                render_expr_owned(ctx, left.as_mut());
255                ctx.sql.push(' ');
256                ctx.sql.push_str(crate::binary_op_sql(op));
257                ctx.sql.push(' ');
258                render_expr_owned(ctx, right.as_mut());
259                ctx.sql.push(')');
260            }
261        }
262        Expr::FunctionCall { name, args } => {
263            let mysql_name = mysql_function_name(name);
264            ctx.sql.push_str(mysql_name);
265            ctx.sql.push('(');
266            for (i, arg) in args.iter_mut().enumerate() {
267                if i > 0 {
268                    ctx.sql.push_str(", ");
269                }
270                render_expr_owned(ctx, arg);
271            }
272            ctx.sql.push(')');
273        }
274        Expr::Filter { expr, predicate } => {
275            render_filter_owned(ctx, expr.as_mut(), predicate.as_mut());
276        }
277    });
278}
279
#[cfg(test)]
mod tests {
    use super::*;
    use nautilus_core::{ColumnMarker, LiteralSql, OrderDir};

    /// Quotes `name` with backticks through the crate-level helper.
    fn quote_identifier(name: &str) -> String {
        let mut quoted = String::new();
        crate::push_quoted_identifier(&mut quoted, name, '`');
        quoted
    }

    /// Identifiers are wrapped in backticks and embedded backticks are doubled.
    #[test]
    fn test_quote_identifier() {
        let cases = [
            ("users", "`users`"),
            ("email", "`email`"),
            ("foo`bar", "`foo``bar`"),
            ("a`b`c", "`a``b``c`"),
        ];
        for (raw, expected) in cases {
            assert_eq!(quote_identifier(raw), expected);
        }
    }

    /// An offset without a limit is rendered with the maximum `u64` as `LIMIT`.
    #[test]
    fn test_skip_without_take() {
        let select = Select::from_table("users").skip(20).build().unwrap();
        let rendered = MysqlDialect.render_select(&select).unwrap();

        assert_eq!(
            rendered.text,
            "SELECT * FROM `users` LIMIT 18446744073709551615 OFFSET 20"
        );
        assert!(rendered.params.is_empty());
    }

    /// A `RETURNING` list on an insert is dropped from the rendered SQL.
    #[test]
    fn test_insert_returning_is_omitted() {
        let returned = vec![
            ColumnMarker::new("users", "id"),
            ColumnMarker::new("users", "email"),
        ];
        let insert = Insert::into_table("users")
            .column(ColumnMarker::new("users", "email"))
            .values(vec![Value::String("alice@example.com".to_string())])
            .returning(returned)
            .build()
            .unwrap();

        let rendered = MysqlDialect.render_insert(&insert).unwrap();

        assert_eq!(rendered.text, "INSERT INTO `users` (`email`) VALUES (?)");
        assert!(!rendered.text.contains("RETURNING"));
    }

    /// A `RETURNING` list on an update is dropped from the rendered SQL.
    #[test]
    fn test_update_returning_is_omitted() {
        let by_id = Expr::column("id").eq(Expr::param(Value::I64(1)));
        let returned = vec![
            ColumnMarker::new("users", "id"),
            ColumnMarker::new("users", "email"),
        ];
        let update = Update::table("users")
            .set(
                ColumnMarker::new("users", "email"),
                Value::String("new@example.com".to_string()),
            )
            .filter(by_id)
            .returning(returned)
            .build()
            .unwrap();

        let rendered = MysqlDialect.render_update(&update).unwrap();

        assert_eq!(
            rendered.text,
            "UPDATE `users` SET `email` = ? WHERE (`id` = ?)"
        );
        assert!(!rendered.text.contains("RETURNING"));
    }

    /// A `RETURNING` list on a delete is dropped from the rendered SQL.
    #[test]
    fn test_delete_returning_is_omitted() {
        let by_id = Expr::column("id").eq(Expr::param(Value::I64(1)));
        let returned = vec![
            ColumnMarker::new("users", "id"),
            ColumnMarker::new("users", "email"),
        ];
        let delete = Delete::from_table("users")
            .filter(by_id)
            .returning(returned)
            .build()
            .unwrap();

        let rendered = MysqlDialect.render_delete(&delete).unwrap();

        assert_eq!(rendered.text, "DELETE FROM `users` WHERE (`id` = ?)");
        assert!(!rendered.text.contains("RETURNING"));
    }

    /// `COUNT(*) FILTER (...)` is rewritten as a `CASE` counting a constant.
    #[test]
    fn test_filter_count_star_is_emulated() {
        let is_active = Expr::column("active").eq(Expr::param(Value::Bool(true)));
        let filtered_count = Expr::function_call("COUNT", vec![Expr::star()]).filter(is_active);
        let select = Select::from_table("users")
            .computed(filtered_count, "active_count")
            .build()
            .unwrap();

        let rendered = MysqlDialect.render_select(&select).unwrap();

        assert_eq!(
            rendered.text,
            "SELECT (COUNT(CASE WHEN (`active` = ?) THEN 1 ELSE NULL END)) AS `active_count` FROM `users`"
        );
        assert_eq!(rendered.params, vec![Value::Bool(true)]);
    }

    /// A filtered single-argument aggregate wraps its argument in a `CASE`.
    #[test]
    fn test_filter_single_arg_aggregate_is_emulated() {
        let is_active = Expr::column("active").eq(Expr::param(Value::Bool(true)));
        let filtered_sum =
            Expr::function_call("SUM", vec![Expr::column("score")]).filter(is_active);
        let select = Select::from_table("users")
            .computed(filtered_sum, "active_score")
            .build()
            .unwrap();

        let rendered = MysqlDialect.render_select(&select).unwrap();

        assert_eq!(
            rendered.text,
            "SELECT (SUM(CASE WHEN (`active` = ?) THEN `score` ELSE NULL END)) AS `active_score` FROM `users`"
        );
        assert_eq!(rendered.params, vec![Value::Bool(true)]);
    }

    /// `FILTER` on a call with more than one argument is reported as an error.
    #[test]
    fn test_filter_multi_arg_function_is_rejected() {
        let is_active = Expr::column("active").eq(Expr::param(Value::Bool(true)));
        let object_args = vec![
            Expr::Literal(LiteralSql::from_static("score")),
            Expr::column("score"),
        ];
        let filtered_object =
            Expr::function_call("json_build_object", object_args).filter(is_active);
        let select = Select::from_table("users")
            .computed(filtered_object, "payload")
            .build()
            .unwrap();

        let message = MysqlDialect.render_select(&select).unwrap_err().to_string();

        assert!(message.contains("cannot emulate FILTER for function 'json_build_object'"));
    }

    /// `ArrayOverlaps` has no generic rendering and must fail.
    #[test]
    fn test_array_overlaps_is_rejected() {
        let tags = Value::Array(vec![Value::String("rust".to_string())]);
        let overlaps = Expr::Binary {
            left: Box::new(Expr::column("posts__tags")),
            op: BinaryOp::ArrayOverlaps,
            right: Box::new(Expr::param(tags)),
        };
        let select = Select::from_table("posts")
            .filter(overlaps)
            .build()
            .unwrap();

        let message = MysqlDialect.render_select(&select).unwrap_err().to_string();

        assert!(message.contains("ArrayOverlaps generically"));
    }

    /// Ordering by a composite field extracts the JSON key and casts it.
    #[test]
    fn composite_field_ordering_uses_json_extract_with_numeric_cast() {
        let eta = Expr::composite_field(
            "shipments",
            "delivery_snapshot",
            "eta_minutes",
            "etaMinutes",
            JsonPathCast::Signed,
        );
        let select = Select::from_table("shipments")
            .order_by_expr(eta, OrderDir::Asc)
            .build()
            .unwrap();

        let rendered = MysqlDialect.render_select(&select).unwrap();

        assert_eq!(
            rendered.text,
            "SELECT * FROM `shipments` ORDER BY CAST(JSON_UNQUOTE(JSON_EXTRACT(`shipments`.`delivery_snapshot`, '$.\"etaMinutes\"')) AS SIGNED) ASC"
        );
    }
}
472}