1use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Error, Expr, Insert, Result, Select, Update, Value};
5
/// SQL dialect that renders queries for MySQL.
///
/// As implemented in this file, the dialect quotes identifiers with backticks,
/// emits `?` placeholders for bound parameters, and reports that it does not
/// support `RETURNING`.
#[derive(Debug, Clone, Copy)]
pub struct MysqlDialect;
9
/// `Dialect` implementation for MySQL.
///
/// Every `render_*` method follows the same three steps: pre-size a
/// `RenderContext` from the crate's estimate for the query, expand the shared
/// rendering macro with '`' as the identifier quote character, and convert the
/// context into the final `Sql` (or the first recorded error).
///
/// The `*_owned` variants take the query by value and hand the `_mut` macros a
/// mutable reference; paired with `render_expr_owned`, this lets parameter
/// values be moved out of the query instead of cloned.
impl Dialect for MysqlDialect {
    /// Reports that this dialect does not emit `RETURNING` clauses.
    fn supports_returning(&self) -> bool {
        false
    }

    /// Renders a borrowed `Select` into SQL text and bound parameters.
    ///
    /// # Errors
    /// Returns the first error recorded on the render context while rendering.
    fn render_select(&self, select: &Select) -> Result<Sql> {
        let mut ctx = RenderContext::with_estimate(crate::estimate_select_render(select));
        // NOTE(review): the trailing `false, true` flags are options of the macro,
        // whose definition is not visible from this file.
        render_select_body_core!(&mut ctx, select, '`', render_expr, false, true);
        ctx.finish()
    }

    /// Renders an owned `Select`; expressions are rendered with `render_expr_owned`.
    ///
    /// # Errors
    /// Returns the first error recorded on the render context while rendering.
    fn render_select_owned(&self, mut select: Select) -> Result<Sql> {
        let mut ctx = RenderContext::with_estimate(crate::estimate_select_render(&select));
        render_select_body_core_mut!(&mut ctx, &mut select, '`', render_expr_owned, false, true);
        ctx.finish()
    }

    /// Renders a borrowed `Insert` into SQL text and bound parameters.
    ///
    /// # Errors
    /// Returns the first error recorded on the render context while rendering.
    fn render_insert(&self, insert: &Insert) -> Result<Sql> {
        let mut ctx = RenderContext::with_estimate(crate::estimate_insert_render(insert));
        // NOTE(review): both flags are `false`; presumably one of them disables
        // `RETURNING` output (the tests below expect none) -- confirm against the macro.
        render_insert_body!(&mut ctx, insert, '`', false, false);
        ctx.finish()
    }

    /// Renders an owned `Insert` through the `_mut` flavour of the insert macro.
    ///
    /// # Errors
    /// Returns the first error recorded on the render context while rendering.
    fn render_insert_owned(&self, mut insert: Insert) -> Result<Sql> {
        let mut ctx = RenderContext::with_estimate(crate::estimate_insert_render(&insert));
        render_insert_body_mut!(&mut ctx, &mut insert, '`', false, false);
        ctx.finish()
    }

    /// Renders a borrowed `Update` into SQL text and bound parameters.
    ///
    /// # Errors
    /// Returns the first error recorded on the render context while rendering.
    fn render_update(&self, update: &Update) -> Result<Sql> {
        let mut ctx = RenderContext::with_estimate(crate::estimate_update_render(update));
        render_update_body!(&mut ctx, update, '`', render_expr, false, false);
        ctx.finish()
    }

    /// Renders an owned `Update`; expressions are rendered with `render_expr_owned`.
    ///
    /// # Errors
    /// Returns the first error recorded on the render context while rendering.
    fn render_update_owned(&self, mut update: Update) -> Result<Sql> {
        let mut ctx = RenderContext::with_estimate(crate::estimate_update_render(&update));
        render_update_body_mut!(&mut ctx, &mut update, '`', render_expr_owned, false, false);
        ctx.finish()
    }

    /// Renders a borrowed `Delete` into SQL text and bound parameters.
    ///
    /// # Errors
    /// Returns the first error recorded on the render context while rendering.
    fn render_delete(&self, delete: &Delete) -> Result<Sql> {
        let mut ctx = RenderContext::with_estimate(crate::estimate_delete_render(delete));
        render_delete_body!(&mut ctx, delete, '`', render_expr, false);
        ctx.finish()
    }

    /// Renders an owned `Delete`; expressions are rendered with `render_expr_owned`.
    ///
    /// # Errors
    /// Returns the first error recorded on the render context while rendering.
    fn render_delete_owned(&self, mut delete: Delete) -> Result<Sql> {
        let mut ctx = RenderContext::with_estimate(crate::estimate_delete_render(&delete));
        render_delete_body_mut!(&mut ctx, &mut delete, '`', render_expr_owned, false);
        ctx.finish()
    }
}
65
/// Mutable state threaded through one rendering pass.
struct RenderContext {
    /// SQL text accumulated so far.
    sql: String,
    /// Bound values, in the order their `?` placeholders were written to `sql`.
    params: Vec<Value>,
    /// First error recorded via `fail`; when set, `finish` returns it instead of SQL.
    error: Option<Error>,
}
71
72impl RenderContext {
73 fn with_estimate(estimate: crate::RenderEstimate) -> Self {
74 Self {
75 sql: String::with_capacity(estimate.sql_capacity),
76 params: Vec::with_capacity(estimate.params_capacity),
77 error: None,
78 }
79 }
80
81 fn push_param(&mut self, value: Value) {
82 self.params.push(value);
83 self.sql.push('?');
84 }
85
86 fn take_param(&mut self, value: &mut Value) {
87 self.push_param(std::mem::replace(value, Value::Null));
88 }
89
90 fn fail(&mut self, message: impl Into<String>) {
91 if self.error.is_none() {
92 self.error = Some(Error::InvalidQuery(message.into()));
93 }
94 }
95
96 fn finish(self) -> Result<Sql> {
97 if let Some(err) = self.error {
98 return Err(err);
99 }
100
101 Ok(Sql {
102 text: self.sql,
103 params: self.params,
104 })
105 }
106}
107
/// Renders a borrowed `Select` into an existing context.
///
/// Expands the same macro with the same arguments as `MysqlDialect::render_select`,
/// but writes into a caller-supplied context; it is handed to `render_expr_common!`
/// by `render_expr` (presumably for nested selects -- confirm against the macro).
fn render_select_body(ctx: &mut RenderContext, select: &crate::Select) {
    render_select_body_core!(ctx, select, '`', render_expr, false, true);
}
111
/// Mutable-borrow counterpart of `render_select_body`.
///
/// Expands the `_mut` select macro with `render_expr_owned` as the expression
/// renderer; it is handed to `render_expr_common_mut!` by `render_expr_owned`
/// (presumably for nested selects -- confirm against the macro).
fn render_select_body_owned(ctx: &mut RenderContext, select: &mut crate::Select) {
    render_select_body_core_mut!(ctx, select, '`', render_expr_owned, false, true);
}
115
/// Maps a portable function name to its MySQL equivalent.
///
/// `json_agg` becomes `JSON_ARRAYAGG` and `json_build_object` becomes
/// `JSON_OBJECT`; any other name is returned unchanged.
///
/// The comparison is ASCII case-insensitive. SQL function names are not
/// case-sensitive and `render_filter` already upper-cases names before matching,
/// so `JSON_AGG` must be translated just like `json_agg` rather than being
/// passed through to MySQL, which has no function of that name.
fn mysql_function_name(name: &str) -> &str {
    if name.eq_ignore_ascii_case("json_agg") {
        "JSON_ARRAYAGG"
    } else if name.eq_ignore_ascii_case("json_build_object") {
        "JSON_OBJECT"
    } else {
        name
    }
}
123
124fn render_case_filtered_aggregate(
125 ctx: &mut RenderContext,
126 fn_name: &str,
127 arg: &Expr,
128 predicate: &Expr,
129) {
130 ctx.sql.push_str(fn_name);
131 ctx.sql.push_str("(CASE WHEN ");
132 render_expr(ctx, predicate);
133 ctx.sql.push_str(" THEN ");
134 render_expr(ctx, arg);
135 ctx.sql.push_str(" ELSE NULL END)");
136}
137
138fn render_case_filtered_aggregate_owned(
139 ctx: &mut RenderContext,
140 fn_name: &str,
141 arg: &mut Expr,
142 predicate: &mut Expr,
143) {
144 ctx.sql.push_str(fn_name);
145 ctx.sql.push_str("(CASE WHEN ");
146 render_expr_owned(ctx, predicate);
147 ctx.sql.push_str(" THEN ");
148 render_expr_owned(ctx, arg);
149 ctx.sql.push_str(" ELSE NULL END)");
150}
151
152fn render_filter(ctx: &mut RenderContext, expr: &Expr, predicate: &Expr) {
153 let Expr::FunctionCall { name, args } = expr else {
154 ctx.fail("MysqlDialect can only emulate FILTER for aggregate function calls");
155 return;
156 };
157
158 let upper = name.to_ascii_uppercase();
159 match (upper.as_str(), args.as_slice()) {
160 ("COUNT", [Expr::Star]) => {
161 ctx.sql.push_str("COUNT(CASE WHEN ");
162 render_expr(ctx, predicate);
163 ctx.sql.push_str(" THEN 1 ELSE NULL END)");
164 }
165 ("COUNT", [arg]) | ("SUM", [arg]) | ("AVG", [arg]) | ("MIN", [arg]) | ("MAX", [arg]) => {
166 render_case_filtered_aggregate(ctx, upper.as_str(), arg, predicate);
167 }
168 ("JSON_AGG", [_]) => {
169 ctx.fail(
170 "MysqlDialect cannot emulate FILTER for json_agg without changing JSON null semantics",
171 );
172 }
173 (_, []) => {
174 ctx.fail(format!(
175 "MysqlDialect cannot emulate FILTER for function '{}' with zero arguments",
176 name
177 ));
178 }
179 _ => {
180 ctx.fail(format!(
181 "MysqlDialect cannot emulate FILTER for function '{}' with {} arguments",
182 name,
183 args.len()
184 ));
185 }
186 }
187}
188
189fn render_filter_owned(ctx: &mut RenderContext, expr: &mut Expr, predicate: &mut Expr) {
190 let Expr::FunctionCall { name, args } = expr else {
191 ctx.fail("MysqlDialect can only emulate FILTER for aggregate function calls");
192 return;
193 };
194
195 let upper = name.to_ascii_uppercase();
196 match (upper.as_str(), args.as_mut_slice()) {
197 ("COUNT", [Expr::Star]) => {
198 ctx.sql.push_str("COUNT(CASE WHEN ");
199 render_expr_owned(ctx, predicate);
200 ctx.sql.push_str(" THEN 1 ELSE NULL END)");
201 }
202 ("COUNT", [arg]) | ("SUM", [arg]) | ("AVG", [arg]) | ("MIN", [arg]) | ("MAX", [arg]) => {
203 render_case_filtered_aggregate_owned(ctx, upper.as_str(), arg, predicate);
204 }
205 ("JSON_AGG", [_]) => {
206 ctx.fail(
207 "MysqlDialect cannot emulate FILTER for json_agg without changing JSON null semantics",
208 );
209 }
210 (_, []) => {
211 ctx.fail(format!(
212 "MysqlDialect cannot emulate FILTER for function '{}' with zero arguments",
213 name
214 ));
215 }
216 _ => {
217 ctx.fail(format!(
218 "MysqlDialect cannot emulate FILTER for function '{}' with {} arguments",
219 name,
220 args.len()
221 ));
222 }
223 }
224}
225
226fn render_expr(ctx: &mut RenderContext, expr: &Expr) {
227 if ctx.error.is_some() {
228 return;
229 }
230
231 render_expr_common!(ctx, expr, '`', render_expr, render_select_body, {
232 Expr::Param(value) => {
233 if matches!(value, Value::Null) {
234 ctx.sql.push_str("NULL");
235 } else {
236 ctx.push_param(value.clone());
237 }
238 }
239 Expr::Binary { left, op, right } => {
240 if matches!(op, BinaryOp::In | BinaryOp::NotIn) {
241 ctx.sql.push('(');
242 render_expr(ctx, left);
243 ctx.sql.push(' ');
244 ctx.sql.push_str(if matches!(op, BinaryOp::In) { "IN" } else { "NOT IN" });
245 ctx.sql.push_str(" (");
246 if let Expr::List(exprs) = right.as_ref() {
247 for (i, e) in exprs.iter().enumerate() {
248 if i > 0 { ctx.sql.push_str(", "); }
249 render_expr(ctx, e);
250 }
251 } else {
252 render_expr(ctx, right);
253 }
254 ctx.sql.push(')');
255 ctx.sql.push(')');
256 } else if matches!(op, BinaryOp::ArrayContains | BinaryOp::ArrayContainedBy | BinaryOp::ArrayOverlaps) {
257 match op {
260 BinaryOp::ArrayContains => {
261 ctx.sql.push_str("JSON_CONTAINS(");
264 render_expr(ctx, left);
265 ctx.sql.push_str(", ");
266 render_expr(ctx, right);
267 ctx.sql.push(')');
268 }
269 BinaryOp::ArrayContainedBy => {
270 ctx.sql.push_str("JSON_CONTAINS(");
272 render_expr(ctx, right);
273 ctx.sql.push_str(", ");
274 render_expr(ctx, left);
275 ctx.sql.push(')');
276 }
277 BinaryOp::ArrayOverlaps => {
278 ctx.fail(
279 "MysqlDialect does not render ArrayOverlaps generically because JSON_OVERLAPS is unavailable on some supported MySQL-family backends",
280 );
281 }
282 _ => unreachable!(),
283 }
284 } else {
285 ctx.sql.push('(');
286 render_expr(ctx, left);
287 ctx.sql.push(' ');
288 ctx.sql.push_str(crate::binary_op_sql(op));
289 ctx.sql.push(' ');
290 render_expr(ctx, right);
291 ctx.sql.push(')');
292 }
293 }
294 Expr::FunctionCall { name, args } => {
295 let mysql_name = mysql_function_name(name);
296 ctx.sql.push_str(mysql_name);
297 ctx.sql.push('(');
298 for (i, arg) in args.iter().enumerate() {
299 if i > 0 { ctx.sql.push_str(", "); }
300 render_expr(ctx, arg);
301 }
302 ctx.sql.push(')');
303 }
304 Expr::Filter { expr, predicate } => {
305 render_filter(ctx, expr, predicate);
306 }
307 });
308}
309
310fn render_expr_owned(ctx: &mut RenderContext, expr: &mut Expr) {
311 if ctx.error.is_some() {
312 return;
313 }
314
315 render_expr_common_mut!(ctx, expr, '`', render_expr_owned, render_select_body_owned, {
316 Expr::Param(value) => {
317 if matches!(value, Value::Null) {
318 ctx.sql.push_str("NULL");
319 } else {
320 ctx.take_param(value);
321 }
322 }
323 Expr::Binary { left, op, right } => {
324 if matches!(*op, BinaryOp::In | BinaryOp::NotIn) {
325 ctx.sql.push('(');
326 render_expr_owned(ctx, left.as_mut());
327 ctx.sql.push(' ');
328 ctx.sql
329 .push_str(if matches!(*op, BinaryOp::In) { "IN" } else { "NOT IN" });
330 ctx.sql.push_str(" (");
331 if let Expr::List(exprs) = right.as_mut() {
332 for (i, e) in exprs.iter_mut().enumerate() {
333 if i > 0 {
334 ctx.sql.push_str(", ");
335 }
336 render_expr_owned(ctx, e);
337 }
338 } else {
339 render_expr_owned(ctx, right.as_mut());
340 }
341 ctx.sql.push(')');
342 ctx.sql.push(')');
343 } else if matches!(
344 *op,
345 BinaryOp::ArrayContains | BinaryOp::ArrayContainedBy | BinaryOp::ArrayOverlaps
346 ) {
347 match *op {
348 BinaryOp::ArrayContains => {
349 ctx.sql.push_str("JSON_CONTAINS(");
350 render_expr_owned(ctx, left.as_mut());
351 ctx.sql.push_str(", ");
352 render_expr_owned(ctx, right.as_mut());
353 ctx.sql.push(')');
354 }
355 BinaryOp::ArrayContainedBy => {
356 ctx.sql.push_str("JSON_CONTAINS(");
357 render_expr_owned(ctx, right.as_mut());
358 ctx.sql.push_str(", ");
359 render_expr_owned(ctx, left.as_mut());
360 ctx.sql.push(')');
361 }
362 BinaryOp::ArrayOverlaps => {
363 ctx.fail(
364 "MysqlDialect does not render ArrayOverlaps generically because JSON_OVERLAPS is unavailable on some supported MySQL-family backends",
365 );
366 }
367 _ => unreachable!(),
368 }
369 } else {
370 ctx.sql.push('(');
371 render_expr_owned(ctx, left.as_mut());
372 ctx.sql.push(' ');
373 ctx.sql.push_str(crate::binary_op_sql(op));
374 ctx.sql.push(' ');
375 render_expr_owned(ctx, right.as_mut());
376 ctx.sql.push(')');
377 }
378 }
379 Expr::FunctionCall { name, args } => {
380 let mysql_name = mysql_function_name(name);
381 ctx.sql.push_str(mysql_name);
382 ctx.sql.push('(');
383 for (i, arg) in args.iter_mut().enumerate() {
384 if i > 0 {
385 ctx.sql.push_str(", ");
386 }
387 render_expr_owned(ctx, arg);
388 }
389 ctx.sql.push(')');
390 }
391 Expr::Filter { expr, predicate } => {
392 render_filter_owned(ctx, expr.as_mut(), predicate.as_mut());
393 }
394 });
395}
396
// Unit tests pinning the exact SQL text and parameters produced by `MysqlDialect`.
#[cfg(test)]
mod tests {
    use super::*;

    // Quotes `name` with backticks through the crate-level helper.
    fn quote_identifier(name: &str) -> String {
        let mut sql = String::new();
        crate::push_quoted_identifier(&mut sql, name, '`');
        sql
    }

    // Identifiers are wrapped in backticks and embedded backticks are doubled.
    #[test]
    fn test_quote_identifier() {
        assert_eq!(quote_identifier("users"), "`users`");
        assert_eq!(quote_identifier("email"), "`email`");
        assert_eq!(quote_identifier("foo`bar"), "`foo``bar`");
        assert_eq!(quote_identifier("a`b`c"), "`a``b``c`");
    }

    // A skip without a take still renders a LIMIT: the expected value
    // 18446744073709551615 is u64::MAX, paired with the requested OFFSET.
    #[test]
    fn test_skip_without_take() {
        let dialect = MysqlDialect;
        let select = Select::from_table("users").skip(20).build().unwrap();
        let sql = dialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT * FROM `users` LIMIT 18446744073709551615 OFFSET 20"
        );
        assert!(sql.params.is_empty());
    }

    // A requested RETURNING list is dropped from INSERT rather than rejected.
    #[test]
    fn test_insert_returning_is_omitted() {
        let dialect = MysqlDialect;
        let insert = Insert::into_table("users")
            .column(nautilus_core::ColumnMarker::new("users", "email"))
            .values(vec![Value::String("alice@example.com".to_string())])
            .returning(vec![
                nautilus_core::ColumnMarker::new("users", "id"),
                nautilus_core::ColumnMarker::new("users", "email"),
            ])
            .build()
            .unwrap();
        let sql = dialect.render_insert(&insert).unwrap();

        assert_eq!(sql.text, "INSERT INTO `users` (`email`) VALUES (?)");
        assert!(!sql.text.contains("RETURNING"));
    }

    // A requested RETURNING list is dropped from UPDATE rather than rejected.
    #[test]
    fn test_update_returning_is_omitted() {
        let dialect = MysqlDialect;
        let update = Update::table("users")
            .set(
                nautilus_core::ColumnMarker::new("users", "email"),
                Value::String("new@example.com".to_string()),
            )
            .filter(Expr::column("id").eq(Expr::param(Value::I64(1))))
            .returning(vec![
                nautilus_core::ColumnMarker::new("users", "id"),
                nautilus_core::ColumnMarker::new("users", "email"),
            ])
            .build()
            .unwrap();
        let sql = dialect.render_update(&update).unwrap();

        assert_eq!(sql.text, "UPDATE `users` SET `email` = ? WHERE (`id` = ?)");
        assert!(!sql.text.contains("RETURNING"));
    }

    // A requested RETURNING list is dropped from DELETE rather than rejected.
    #[test]
    fn test_delete_returning_is_omitted() {
        let dialect = MysqlDialect;
        let delete = Delete::from_table("users")
            .filter(Expr::column("id").eq(Expr::param(Value::I64(1))))
            .returning(vec![
                nautilus_core::ColumnMarker::new("users", "id"),
                nautilus_core::ColumnMarker::new("users", "email"),
            ])
            .build()
            .unwrap();
        let sql = dialect.render_delete(&delete).unwrap();

        assert_eq!(sql.text, "DELETE FROM `users` WHERE (`id` = ?)");
        assert!(!sql.text.contains("RETURNING"));
    }

    // COUNT(*) FILTER (...) is rewritten to COUNT(CASE WHEN ... THEN 1 ELSE NULL END).
    #[test]
    fn test_filter_count_star_is_emulated() {
        let dialect = MysqlDialect;
        let select = Select::from_table("users")
            .computed(
                Expr::function_call("COUNT", vec![Expr::star()])
                    .filter(Expr::column("active").eq(Expr::param(Value::Bool(true)))),
                "active_count",
            )
            .build()
            .unwrap();

        let sql = dialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT (COUNT(CASE WHEN (`active` = ?) THEN 1 ELSE NULL END)) AS `active_count` FROM `users`"
        );
        assert_eq!(sql.params, vec![Value::Bool(true)]);
    }

    // A single-argument aggregate with FILTER wraps its argument in the CASE expression.
    #[test]
    fn test_filter_single_arg_aggregate_is_emulated() {
        let dialect = MysqlDialect;
        let select = Select::from_table("users")
            .computed(
                Expr::function_call("SUM", vec![Expr::column("score")])
                    .filter(Expr::column("active").eq(Expr::param(Value::Bool(true)))),
                "active_score",
            )
            .build()
            .unwrap();

        let sql = dialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT (SUM(CASE WHEN (`active` = ?) THEN `score` ELSE NULL END)) AS `active_score` FROM `users`"
        );
        assert_eq!(sql.params, vec![Value::Bool(true)]);
    }

    // FILTER on a function with more than one argument cannot be emulated and must error.
    #[test]
    fn test_filter_multi_arg_function_is_rejected() {
        let dialect = MysqlDialect;
        let select = Select::from_table("users")
            .computed(
                Expr::function_call(
                    "json_build_object",
                    vec![Expr::Literal("score".to_string()), Expr::column("score")],
                )
                .filter(Expr::column("active").eq(Expr::param(Value::Bool(true)))),
                "payload",
            )
            .build()
            .unwrap();

        let err = dialect.render_select(&select).unwrap_err();
        assert!(err
            .to_string()
            .contains("cannot emulate FILTER for function 'json_build_object'"));
    }

    // ArrayOverlaps has no generic MySQL rendering, so the dialect reports an error.
    #[test]
    fn test_array_overlaps_is_rejected() {
        let dialect = MysqlDialect;
        let expr = Expr::Binary {
            left: Box::new(Expr::column("posts__tags")),
            op: BinaryOp::ArrayOverlaps,
            right: Box::new(Expr::param(Value::Array(vec![Value::String(
                "rust".to_string(),
            )]))),
        };
        let select = Select::from_table("posts").filter(expr).build().unwrap();

        let err = dialect.render_select(&select).unwrap_err();
        assert!(err.to_string().contains("ArrayOverlaps generically"));
    }
}