Skip to main content

nautilus_dialect/
mysql.rs

1//! MySQL SQL dialect renderer.
2
3use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Error, Expr, Insert, Result, Select, Update, Value};
5
/// SQL dialect for MySQL and MySQL-family servers.
///
/// This is a stateless, zero-sized marker type. All of the rendering logic lives in its
/// [`Dialect`] implementation.
#[derive(Clone, Copy, Debug)]
pub struct MysqlDialect;
9
10/// Renders query ASTs into MySQL-compatible SQL with `?` placeholders
11/// and backtick-quoted identifiers.
12impl Dialect for MysqlDialect {
13    fn supports_returning(&self) -> bool {
14        false
15    }
16
17    fn render_select_owned(&self, mut select: Select) -> Result<Sql> {
18        let mut ctx = RenderContext::with_estimate(crate::estimate_select_render(&select));
19        render_select_body_core_mut!(&mut ctx, &mut select, '`', render_expr_owned, false, true);
20        ctx.finish()
21    }
22
23    fn render_insert_owned(&self, mut insert: Insert) -> Result<Sql> {
24        let mut ctx = RenderContext::with_estimate(crate::estimate_insert_render(&insert));
25        render_insert_body_mut!(&mut ctx, &mut insert, '`', false, false);
26        ctx.finish()
27    }
28
29    fn render_update_owned(&self, mut update: Update) -> Result<Sql> {
30        let mut ctx = RenderContext::with_estimate(crate::estimate_update_render(&update));
31        render_update_body_mut!(&mut ctx, &mut update, '`', render_expr_owned, false, false);
32        ctx.finish()
33    }
34
35    fn render_delete_owned(&self, mut delete: Delete) -> Result<Sql> {
36        let mut ctx = RenderContext::with_estimate(crate::estimate_delete_render(&delete));
37        render_delete_body_mut!(&mut ctx, &mut delete, '`', render_expr_owned, false);
38        ctx.finish()
39    }
40}
41
42struct RenderContext {
43    sql: String,
44    params: Vec<Value>,
45    error: Option<Error>,
46}
47
48impl RenderContext {
49    fn with_estimate(estimate: crate::RenderEstimate) -> Self {
50        Self {
51            sql: String::with_capacity(estimate.sql_capacity),
52            params: Vec::with_capacity(estimate.params_capacity),
53            error: None,
54        }
55    }
56
57    fn push_param(&mut self, value: Value) {
58        self.params.push(value);
59        self.sql.push('?');
60    }
61
62    fn take_param(&mut self, value: &mut Value) {
63        self.push_param(std::mem::replace(value, Value::Null));
64    }
65
66    fn fail(&mut self, message: impl Into<String>) {
67        if self.error.is_none() {
68            self.error = Some(Error::InvalidQuery(message.into()));
69        }
70    }
71
72    fn finish(self) -> Result<Sql> {
73        if let Some(err) = self.error {
74            return Err(err);
75        }
76
77        Ok(Sql {
78            text: self.sql,
79            params: self.params,
80        })
81    }
82}
83
/// Renders a `SELECT` body into `ctx` using MySQL conventions (backtick identifier quoting).
///
/// `render_expr_owned` passes this function to `render_expr_common_mut!` as the select
/// renderer, presumably for nested sub-selects. It expands the same macro with the same
/// arguments as `MysqlDialect::render_select_owned`.
// NOTE(review): the meaning of the trailing `false, true` flags is defined by
// `render_select_body_core_mut!` (not visible here); confirm there before changing them.
fn render_select_body_owned(ctx: &mut RenderContext, select: &mut crate::Select) {
    render_select_body_core_mut!(ctx, select, '`', render_expr_owned, false, true);
}
87
/// Maps a PostgreSQL-style function name onto its MySQL equivalent.
///
/// Names are compared ASCII-case-insensitively. This matches how SQL treats function names
/// and how `render_filter_owned` classifies aggregates, so `json_agg` and `JSON_AGG` both
/// become `JSON_ARRAYAGG`. Any name not in the table is returned unchanged.
fn mysql_function_name(name: &str) -> &str {
    // (portable name, MySQL name)
    const RENAMES: [(&str, &str); 2] = [
        ("json_agg", "JSON_ARRAYAGG"),
        ("json_build_object", "JSON_OBJECT"),
    ];

    RENAMES
        .iter()
        .find(|(source, _)| name.eq_ignore_ascii_case(source))
        .map_or(name, |&(_, mysql)| mysql)
}
95
96fn render_case_filtered_aggregate_owned(
97    ctx: &mut RenderContext,
98    fn_name: &str,
99    arg: &mut Expr,
100    predicate: &mut Expr,
101) {
102    ctx.sql.push_str(fn_name);
103    ctx.sql.push_str("(CASE WHEN ");
104    render_expr_owned(ctx, predicate);
105    ctx.sql.push_str(" THEN ");
106    render_expr_owned(ctx, arg);
107    ctx.sql.push_str(" ELSE NULL END)");
108}
109
110fn render_filter_owned(ctx: &mut RenderContext, expr: &mut Expr, predicate: &mut Expr) {
111    let Expr::FunctionCall { name, args } = expr else {
112        ctx.fail("MysqlDialect can only emulate FILTER for aggregate function calls");
113        return;
114    };
115
116    let upper = name.to_ascii_uppercase();
117    match (upper.as_str(), args.as_mut_slice()) {
118        ("COUNT", [Expr::Star]) => {
119            ctx.sql.push_str("COUNT(CASE WHEN ");
120            render_expr_owned(ctx, predicate);
121            ctx.sql.push_str(" THEN 1 ELSE NULL END)");
122        }
123        ("COUNT", [arg]) | ("SUM", [arg]) | ("AVG", [arg]) | ("MIN", [arg]) | ("MAX", [arg]) => {
124            render_case_filtered_aggregate_owned(ctx, upper.as_str(), arg, predicate);
125        }
126        ("JSON_AGG", [_]) => {
127            ctx.fail(
128                "MysqlDialect cannot emulate FILTER for json_agg without changing JSON null semantics",
129            );
130        }
131        (_, []) => {
132            ctx.fail(format!(
133                "MysqlDialect cannot emulate FILTER for function '{}' with zero arguments",
134                name
135            ));
136        }
137        _ => {
138            ctx.fail(format!(
139                "MysqlDialect cannot emulate FILTER for function '{}' with {} arguments",
140                name,
141                args.len()
142            ));
143        }
144    }
145}
146
147fn render_expr_owned(ctx: &mut RenderContext, expr: &mut Expr) {
148    if ctx.error.is_some() {
149        return;
150    }
151
152    render_expr_common_mut!(ctx, expr, '`', render_expr_owned, render_select_body_owned, {
153        Expr::Param(value) => {
154            if matches!(value, Value::Null) {
155                ctx.sql.push_str("NULL");
156            } else {
157                ctx.take_param(value);
158            }
159        }
160        Expr::Binary { left, op, right } => {
161            if matches!(*op, BinaryOp::In | BinaryOp::NotIn) {
162                ctx.sql.push('(');
163                render_expr_owned(ctx, left.as_mut());
164                ctx.sql.push(' ');
165                ctx.sql
166                    .push_str(if matches!(*op, BinaryOp::In) { "IN" } else { "NOT IN" });
167                ctx.sql.push_str(" (");
168                if let Expr::List(exprs) = right.as_mut() {
169                    for (i, e) in exprs.iter_mut().enumerate() {
170                        if i > 0 {
171                            ctx.sql.push_str(", ");
172                        }
173                        render_expr_owned(ctx, e);
174                    }
175                } else {
176                    render_expr_owned(ctx, right.as_mut());
177                }
178                ctx.sql.push(')');
179                ctx.sql.push(')');
180            } else if matches!(
181                *op,
182                BinaryOp::ArrayContains | BinaryOp::ArrayContainedBy | BinaryOp::ArrayOverlaps
183            ) {
184                match *op {
185                    BinaryOp::ArrayContains => {
186                        ctx.sql.push_str("JSON_CONTAINS(");
187                        render_expr_owned(ctx, left.as_mut());
188                        ctx.sql.push_str(", ");
189                        render_expr_owned(ctx, right.as_mut());
190                        ctx.sql.push(')');
191                    }
192                    BinaryOp::ArrayContainedBy => {
193                        ctx.sql.push_str("JSON_CONTAINS(");
194                        render_expr_owned(ctx, right.as_mut());
195                        ctx.sql.push_str(", ");
196                        render_expr_owned(ctx, left.as_mut());
197                        ctx.sql.push(')');
198                    }
199                    BinaryOp::ArrayOverlaps => {
200                        ctx.fail(
201                            "MysqlDialect does not render ArrayOverlaps generically because JSON_OVERLAPS is unavailable on some supported MySQL-family backends",
202                        );
203                    }
204                    _ => unreachable!(),
205                }
206            } else {
207                ctx.sql.push('(');
208                render_expr_owned(ctx, left.as_mut());
209                ctx.sql.push(' ');
210                ctx.sql.push_str(crate::binary_op_sql(op));
211                ctx.sql.push(' ');
212                render_expr_owned(ctx, right.as_mut());
213                ctx.sql.push(')');
214            }
215        }
216        Expr::FunctionCall { name, args } => {
217            let mysql_name = mysql_function_name(name);
218            ctx.sql.push_str(mysql_name);
219            ctx.sql.push('(');
220            for (i, arg) in args.iter_mut().enumerate() {
221                if i > 0 {
222                    ctx.sql.push_str(", ");
223                }
224                render_expr_owned(ctx, arg);
225            }
226            ctx.sql.push(')');
227        }
228        Expr::Filter { expr, predicate } => {
229            render_filter_owned(ctx, expr.as_mut(), predicate.as_mut());
230        }
231    });
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use nautilus_core::ColumnMarker;

    /// Quotes `name` with MySQL backticks via the crate's shared quoting helper.
    fn quote_identifier(name: &str) -> String {
        let mut quoted = String::new();
        crate::push_quoted_identifier(&mut quoted, name, '`');
        quoted
    }

    #[test]
    fn test_quote_identifier() {
        let cases = [
            ("users", "`users`"),
            ("email", "`email`"),
            // Embedded backticks are escaped by doubling them.
            ("foo`bar", "`foo``bar`"),
            ("a`b`c", "`a``b``c`"),
        ];
        for (raw, expected) in cases {
            assert_eq!(quote_identifier(raw), expected);
        }
    }

    #[test]
    fn test_skip_without_take() {
        let select = Select::from_table("users")
            .skip(20)
            .build()
            .expect("select should build");
        let rendered = MysqlDialect
            .render_select(&select)
            .expect("select should render");

        // An OFFSET without a LIMIT is rendered with the u64::MAX row limit.
        assert_eq!(
            rendered.text,
            "SELECT * FROM `users` LIMIT 18446744073709551615 OFFSET 20"
        );
        assert!(rendered.params.is_empty());
    }

    #[test]
    fn test_insert_returning_is_omitted() {
        let insert = Insert::into_table("users")
            .column(ColumnMarker::new("users", "email"))
            .values(vec![Value::String("alice@example.com".to_string())])
            .returning(vec![
                ColumnMarker::new("users", "id"),
                ColumnMarker::new("users", "email"),
            ])
            .build()
            .expect("insert should build");
        let rendered = MysqlDialect
            .render_insert(&insert)
            .expect("insert should render");

        assert_eq!(rendered.text, "INSERT INTO `users` (`email`) VALUES (?)");
        assert!(!rendered.text.contains("RETURNING"));
    }

    #[test]
    fn test_update_returning_is_omitted() {
        let by_id = Expr::column("id").eq(Expr::param(Value::I64(1)));
        let update = Update::table("users")
            .set(
                ColumnMarker::new("users", "email"),
                Value::String("new@example.com".to_string()),
            )
            .filter(by_id)
            .returning(vec![
                ColumnMarker::new("users", "id"),
                ColumnMarker::new("users", "email"),
            ])
            .build()
            .expect("update should build");
        let rendered = MysqlDialect
            .render_update(&update)
            .expect("update should render");

        assert_eq!(
            rendered.text,
            "UPDATE `users` SET `email` = ? WHERE (`id` = ?)"
        );
        assert!(!rendered.text.contains("RETURNING"));
    }

    #[test]
    fn test_delete_returning_is_omitted() {
        let by_id = Expr::column("id").eq(Expr::param(Value::I64(1)));
        let delete = Delete::from_table("users")
            .filter(by_id)
            .returning(vec![
                ColumnMarker::new("users", "id"),
                ColumnMarker::new("users", "email"),
            ])
            .build()
            .expect("delete should build");
        let rendered = MysqlDialect
            .render_delete(&delete)
            .expect("delete should render");

        assert_eq!(rendered.text, "DELETE FROM `users` WHERE (`id` = ?)");
        assert!(!rendered.text.contains("RETURNING"));
    }

    #[test]
    fn test_filter_count_star_is_emulated() {
        let active = Expr::column("active").eq(Expr::param(Value::Bool(true)));
        let select = Select::from_table("users")
            .computed(
                Expr::function_call("COUNT", vec![Expr::star()]).filter(active),
                "active_count",
            )
            .build()
            .expect("select should build");
        let rendered = MysqlDialect
            .render_select(&select)
            .expect("select should render");

        assert_eq!(
            rendered.text,
            "SELECT (COUNT(CASE WHEN (`active` = ?) THEN 1 ELSE NULL END)) AS `active_count` FROM `users`"
        );
        assert_eq!(rendered.params, vec![Value::Bool(true)]);
    }

    #[test]
    fn test_filter_single_arg_aggregate_is_emulated() {
        let active = Expr::column("active").eq(Expr::param(Value::Bool(true)));
        let select = Select::from_table("users")
            .computed(
                Expr::function_call("SUM", vec![Expr::column("score")]).filter(active),
                "active_score",
            )
            .build()
            .expect("select should build");
        let rendered = MysqlDialect
            .render_select(&select)
            .expect("select should render");

        assert_eq!(
            rendered.text,
            "SELECT (SUM(CASE WHEN (`active` = ?) THEN `score` ELSE NULL END)) AS `active_score` FROM `users`"
        );
        assert_eq!(rendered.params, vec![Value::Bool(true)]);
    }

    #[test]
    fn test_filter_multi_arg_function_is_rejected() {
        let active = Expr::column("active").eq(Expr::param(Value::Bool(true)));
        let payload = Expr::function_call(
            "json_build_object",
            vec![
                Expr::Literal(nautilus_core::LiteralSql::from_static("score")),
                Expr::column("score"),
            ],
        )
        .filter(active);
        let select = Select::from_table("users")
            .computed(payload, "payload")
            .build()
            .expect("select should build");

        let message = MysqlDialect
            .render_select(&select)
            .expect_err("FILTER on a multi-argument function must be rejected")
            .to_string();
        assert!(message.contains("cannot emulate FILTER for function 'json_build_object'"));
    }

    #[test]
    fn test_array_overlaps_is_rejected() {
        let tags = Expr::param(Value::Array(vec![Value::String("rust".to_string())]));
        let overlaps = Expr::Binary {
            left: Box::new(Expr::column("posts__tags")),
            op: BinaryOp::ArrayOverlaps,
            right: Box::new(tags),
        };
        let select = Select::from_table("posts")
            .filter(overlaps)
            .build()
            .expect("select should build");

        let message = MysqlDialect
            .render_select(&select)
            .expect_err("ArrayOverlaps must be rejected")
            .to_string();
        assert!(message.contains("ArrayOverlaps generically"));
    }
}