Skip to main content

nautilus_dialect/
mysql.rs

//! MySQL SQL dialect renderer.
2
3use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Error, Expr, Insert, Result, Select, Update, Value};
5
/// Zero-sized marker type that selects the MySQL flavour of SQL rendering.
///
/// All behaviour lives in its [`Dialect`] implementation; the type carries no
/// state, so it is `Copy` and can be constructed freely wherever needed.
#[derive(Clone, Copy, Debug)]
pub struct MysqlDialect;
9
10/// Renders query ASTs into MySQL-compatible SQL with `?` placeholders
11/// and backtick-quoted identifiers.
12impl Dialect for MysqlDialect {
13    fn supports_returning(&self) -> bool {
14        false
15    }
16
17    fn render_select(&self, select: &Select) -> Result<Sql> {
18        let mut ctx = RenderContext::new();
19        render_select_body_core!(&mut ctx, select, quote_identifier, render_expr, false, true);
20        ctx.finish()
21    }
22
23    fn render_insert(&self, insert: &Insert) -> Result<Sql> {
24        let mut ctx = RenderContext::new();
25        render_insert_body!(&mut ctx, insert, quote_identifier, false, false);
26        ctx.finish()
27    }
28
29    fn render_update(&self, update: &Update) -> Result<Sql> {
30        let mut ctx = RenderContext::new();
31        render_update_body!(
32            &mut ctx,
33            update,
34            quote_identifier,
35            render_expr,
36            false,
37            false
38        );
39        ctx.finish()
40    }
41
42    fn render_delete(&self, delete: &Delete) -> Result<Sql> {
43        let mut ctx = RenderContext::new();
44        render_delete_body!(&mut ctx, delete, quote_identifier, render_expr, false);
45        ctx.finish()
46    }
47}
48
49fn quote_identifier(name: &str) -> String {
50    crate::backtick_quote_identifier(name)
51}
52
53struct RenderContext {
54    sql: String,
55    params: Vec<Value>,
56    error: Option<Error>,
57}
58
59impl RenderContext {
60    fn new() -> Self {
61        Self {
62            sql: String::new(),
63            params: Vec::new(),
64            error: None,
65        }
66    }
67
68    fn push_param(&mut self, value: Value) -> String {
69        self.params.push(value);
70        "?".to_string()
71    }
72
73    fn fail(&mut self, message: impl Into<String>) {
74        if self.error.is_none() {
75            self.error = Some(Error::InvalidQuery(message.into()));
76        }
77    }
78
79    fn finish(self) -> Result<Sql> {
80        if let Some(err) = self.error {
81            return Err(err);
82        }
83
84        Ok(Sql {
85            text: self.sql,
86            params: self.params,
87        })
88    }
89}
90
91fn render_select_body(ctx: &mut RenderContext, select: &crate::Select) {
92    render_select_body_core!(ctx, select, quote_identifier, render_expr, false, true);
93}
94
/// Translates a portable function name into its MySQL spelling.
///
/// `json_agg` becomes `JSON_ARRAYAGG` and `json_build_object` becomes
/// `JSON_OBJECT`; any other name is returned unchanged, with its original
/// spelling preserved.
///
/// SQL function names are case-insensitive, so the lookup ignores ASCII case.
/// This matches `render_filter`, which upper-cases names before dispatching.
/// Previously `JSON_AGG` slipped through untranslated and produced a
/// function that MySQL does not have.
fn mysql_function_name(name: &str) -> &str {
    const MAPPINGS: [(&str, &str); 2] = [
        ("json_agg", "JSON_ARRAYAGG"),
        ("json_build_object", "JSON_OBJECT"),
    ];

    MAPPINGS
        .iter()
        .find(|(portable, _)| name.eq_ignore_ascii_case(portable))
        .map_or(name, |&(_, mysql)| mysql)
}
102
103fn render_case_filtered_aggregate(
104    ctx: &mut RenderContext,
105    fn_name: &str,
106    arg: &Expr,
107    predicate: &Expr,
108) {
109    ctx.sql.push_str(fn_name);
110    ctx.sql.push_str("(CASE WHEN ");
111    render_expr(ctx, predicate);
112    ctx.sql.push_str(" THEN ");
113    render_expr(ctx, arg);
114    ctx.sql.push_str(" ELSE NULL END)");
115}
116
117fn render_filter(ctx: &mut RenderContext, expr: &Expr, predicate: &Expr) {
118    let Expr::FunctionCall { name, args } = expr else {
119        ctx.fail("MysqlDialect can only emulate FILTER for aggregate function calls");
120        return;
121    };
122
123    let upper = name.to_ascii_uppercase();
124    match (upper.as_str(), args.as_slice()) {
125        ("COUNT", [Expr::Star]) => {
126            ctx.sql.push_str("COUNT(CASE WHEN ");
127            render_expr(ctx, predicate);
128            ctx.sql.push_str(" THEN 1 ELSE NULL END)");
129        }
130        ("COUNT", [arg]) | ("SUM", [arg]) | ("AVG", [arg]) | ("MIN", [arg]) | ("MAX", [arg]) => {
131            render_case_filtered_aggregate(ctx, upper.as_str(), arg, predicate);
132        }
133        ("JSON_AGG", [_]) => {
134            ctx.fail(
135                "MysqlDialect cannot emulate FILTER for json_agg without changing JSON null semantics",
136            );
137        }
138        (_, []) => {
139            ctx.fail(format!(
140                "MysqlDialect cannot emulate FILTER for function '{}' with zero arguments",
141                name
142            ));
143        }
144        _ => {
145            ctx.fail(format!(
146                "MysqlDialect cannot emulate FILTER for function '{}' with {} arguments",
147                name,
148                args.len()
149            ));
150        }
151    }
152}
153
154fn render_expr(ctx: &mut RenderContext, expr: &Expr) {
155    if ctx.error.is_some() {
156        return;
157    }
158
159    render_expr_common!(ctx, expr, quote_identifier, render_expr, render_select_body, {
160        Expr::Param(value) => {
161            if matches!(value, Value::Null) {
162                ctx.sql.push_str("NULL");
163            } else {
164                let placeholder = ctx.push_param(value.clone());
165                ctx.sql.push_str(&placeholder);
166            }
167        }
168        Expr::Binary { left, op, right } => {
169            if matches!(op, BinaryOp::In | BinaryOp::NotIn) {
170                ctx.sql.push('(');
171                render_expr(ctx, left);
172                ctx.sql.push(' ');
173                ctx.sql.push_str(if matches!(op, BinaryOp::In) { "IN" } else { "NOT IN" });
174                ctx.sql.push_str(" (");
175                if let Expr::List(exprs) = right.as_ref() {
176                    for (i, e) in exprs.iter().enumerate() {
177                        if i > 0 { ctx.sql.push_str(", "); }
178                        render_expr(ctx, e);
179                    }
180                } else {
181                    render_expr(ctx, right);
182                }
183                ctx.sql.push(')');
184                ctx.sql.push(')');
185            } else if matches!(op, BinaryOp::ArrayContains | BinaryOp::ArrayContainedBy | BinaryOp::ArrayOverlaps) {
186                // Array operators emulated via MySQL JSON functions.
187                // Arrays are bound as JSON strings by the connector layer.
188                match op {
189                    BinaryOp::ArrayContains => {
190                        // col @> rhs: col contains every element of rhs.
191                        // JSON_CONTAINS(target, candidate) returns 1 when the candidate is a subset of the target.
192                        ctx.sql.push_str("JSON_CONTAINS(");
193                        render_expr(ctx, left);
194                        ctx.sql.push_str(", ");
195                        render_expr(ctx, right);
196                        ctx.sql.push(')');
197                    }
198                    BinaryOp::ArrayContainedBy => {
199                        // col <@ rhs: rhs contains every element of col.
200                        ctx.sql.push_str("JSON_CONTAINS(");
201                        render_expr(ctx, right);
202                        ctx.sql.push_str(", ");
203                        render_expr(ctx, left);
204                        ctx.sql.push(')');
205                    }
206                    BinaryOp::ArrayOverlaps => {
207                        ctx.fail(
208                            "MysqlDialect does not render ArrayOverlaps generically because JSON_OVERLAPS is unavailable on some supported MySQL-family backends",
209                        );
210                    }
211                    _ => unreachable!(),
212                }
213            } else {
214                ctx.sql.push('(');
215                render_expr(ctx, left);
216                ctx.sql.push(' ');
217                ctx.sql.push_str(crate::binary_op_sql(op));
218                ctx.sql.push(' ');
219                render_expr(ctx, right);
220                ctx.sql.push(')');
221            }
222        }
223        Expr::FunctionCall { name, args } => {
224            let mysql_name = mysql_function_name(name);
225            ctx.sql.push_str(mysql_name);
226            ctx.sql.push('(');
227            for (i, arg) in args.iter().enumerate() {
228                if i > 0 { ctx.sql.push_str(", "); }
229                render_expr(ctx, arg);
230            }
231            ctx.sql.push(')');
232        }
233        Expr::Filter { expr, predicate } => {
234            render_filter(ctx, expr, predicate);
235        }
236    });
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use nautilus_core::ColumnMarker;

    /// Predicate `active = ?` with `true` bound; shared by the FILTER tests.
    fn active_is_true() -> Expr {
        Expr::column("active").eq(Expr::param(Value::Bool(true)))
    }

    /// Predicate `id = ?` with `1` bound; shared by the UPDATE/DELETE tests.
    fn id_is_one() -> Expr {
        Expr::column("id").eq(Expr::param(Value::I64(1)))
    }

    #[test]
    fn test_quote_identifier() {
        let cases = [
            ("users", "`users`"),
            ("email", "`email`"),
            ("foo`bar", "`foo``bar`"),
            ("a`b`c", "`a``b``c`"),
        ];
        for (raw, expected) in cases {
            assert_eq!(quote_identifier(raw), expected, "quoting {raw:?}");
        }
    }

    #[test]
    fn test_skip_without_take() {
        let select = Select::from_table("users").skip(20).build().unwrap();
        let rendered = MysqlDialect.render_select(&select).unwrap();

        // MySQL cannot express OFFSET without LIMIT, so the maximum u64 stands in
        // for "no limit".
        assert_eq!(
            rendered.text,
            "SELECT * FROM `users` LIMIT 18446744073709551615 OFFSET 20"
        );
        assert!(rendered.params.is_empty());
    }

    #[test]
    fn test_insert_returning_is_omitted() {
        let insert = Insert::into_table("users")
            .column(ColumnMarker::new("users", "email"))
            .values(vec![Value::String("alice@example.com".to_string())])
            .returning(vec![
                ColumnMarker::new("users", "id"),
                ColumnMarker::new("users", "email"),
            ])
            .build()
            .unwrap();

        let rendered = MysqlDialect.render_insert(&insert).unwrap();

        assert_eq!(rendered.text, "INSERT INTO `users` (`email`) VALUES (?)");
        assert!(!rendered.text.contains("RETURNING"));
    }

    #[test]
    fn test_update_returning_is_omitted() {
        let update = Update::table("users")
            .set(
                ColumnMarker::new("users", "email"),
                Value::String("new@example.com".to_string()),
            )
            .filter(id_is_one())
            .returning(vec![
                ColumnMarker::new("users", "id"),
                ColumnMarker::new("users", "email"),
            ])
            .build()
            .unwrap();

        let rendered = MysqlDialect.render_update(&update).unwrap();

        assert_eq!(rendered.text, "UPDATE `users` SET `email` = ? WHERE (`id` = ?)");
        assert!(!rendered.text.contains("RETURNING"));
    }

    #[test]
    fn test_delete_returning_is_omitted() {
        let delete = Delete::from_table("users")
            .filter(id_is_one())
            .returning(vec![
                ColumnMarker::new("users", "id"),
                ColumnMarker::new("users", "email"),
            ])
            .build()
            .unwrap();

        let rendered = MysqlDialect.render_delete(&delete).unwrap();

        assert_eq!(rendered.text, "DELETE FROM `users` WHERE (`id` = ?)");
        assert!(!rendered.text.contains("RETURNING"));
    }

    #[test]
    fn test_filter_count_star_is_emulated() {
        let active_count =
            Expr::function_call("COUNT", vec![Expr::star()]).filter(active_is_true());
        let select = Select::from_table("users")
            .computed(active_count, "active_count")
            .build()
            .unwrap();

        let rendered = MysqlDialect.render_select(&select).unwrap();

        assert_eq!(
            rendered.text,
            "SELECT (COUNT(CASE WHEN (`active` = ?) THEN 1 ELSE NULL END)) AS `active_count` FROM `users`"
        );
        assert_eq!(rendered.params, vec![Value::Bool(true)]);
    }

    #[test]
    fn test_filter_single_arg_aggregate_is_emulated() {
        let active_score =
            Expr::function_call("SUM", vec![Expr::column("score")]).filter(active_is_true());
        let select = Select::from_table("users")
            .computed(active_score, "active_score")
            .build()
            .unwrap();

        let rendered = MysqlDialect.render_select(&select).unwrap();

        assert_eq!(
            rendered.text,
            "SELECT (SUM(CASE WHEN (`active` = ?) THEN `score` ELSE NULL END)) AS `active_score` FROM `users`"
        );
        assert_eq!(rendered.params, vec![Value::Bool(true)]);
    }

    #[test]
    fn test_filter_multi_arg_function_is_rejected() {
        let payload = Expr::function_call(
            "json_build_object",
            vec![Expr::Literal("score".to_string()), Expr::column("score")],
        )
        .filter(active_is_true());
        let select = Select::from_table("users")
            .computed(payload, "payload")
            .build()
            .unwrap();

        let message = MysqlDialect.render_select(&select).unwrap_err().to_string();

        assert!(message.contains("cannot emulate FILTER for function 'json_build_object'"));
    }

    #[test]
    fn test_array_overlaps_is_rejected() {
        let tags = Value::Array(vec![Value::String("rust".to_string())]);
        let overlaps = Expr::Binary {
            left: Box::new(Expr::column("posts__tags")),
            op: BinaryOp::ArrayOverlaps,
            right: Box::new(Expr::param(tags)),
        };
        let select = Select::from_table("posts").filter(overlaps).build().unwrap();

        let message = MysqlDialect.render_select(&select).unwrap_err().to_string();

        assert!(message.contains("ArrayOverlaps generically"));
    }
}