1use crate::{Dialect, Sql};
4use nautilus_core::{
5 BinaryOp, Delete, Error, Expr, Insert, JsonPathCast, Result, Select, Update, Value,
6};
7
/// SQL dialect for MySQL-family backends.
///
/// Identifiers are quoted with backticks, bind values are emitted as
/// positional `?` placeholders, and no `RETURNING` clause is ever rendered.
#[derive(Clone, Copy, Debug)]
pub struct MysqlDialect;
11
12impl Dialect for MysqlDialect {
15 fn supports_returning(&self) -> bool {
16 false
17 }
18
19 fn render_select_owned(&self, mut select: Select) -> Result<Sql> {
20 let mut ctx = RenderContext::with_estimate(crate::estimate_select_render(&select));
21 render_select_body_core_mut!(&mut ctx, &mut select, '`', render_expr_owned, false, true);
22 ctx.finish()
23 }
24
25 fn render_insert_owned(&self, mut insert: Insert) -> Result<Sql> {
26 let mut ctx = RenderContext::with_estimate(crate::estimate_insert_render(&insert));
27 render_insert_body_mut!(&mut ctx, &mut insert, '`', false, false);
28 ctx.finish()
29 }
30
31 fn render_update_owned(&self, mut update: Update) -> Result<Sql> {
32 let mut ctx = RenderContext::with_estimate(crate::estimate_update_render(&update));
33 render_update_body_mut!(&mut ctx, &mut update, '`', render_expr_owned, false, false);
34 ctx.finish()
35 }
36
37 fn render_delete_owned(&self, mut delete: Delete) -> Result<Sql> {
38 let mut ctx = RenderContext::with_estimate(crate::estimate_delete_render(&delete));
39 render_delete_body_mut!(&mut ctx, &mut delete, '`', render_expr_owned, false);
40 ctx.finish()
41 }
42}
43
44struct RenderContext {
45 sql: String,
46 params: Vec<Value>,
47 error: Option<Error>,
48}
49
50impl RenderContext {
51 fn with_estimate(estimate: crate::RenderEstimate) -> Self {
52 Self {
53 sql: String::with_capacity(estimate.sql_capacity),
54 params: Vec::with_capacity(estimate.params_capacity),
55 error: None,
56 }
57 }
58
59 fn push_param(&mut self, value: Value) {
60 self.params.push(value);
61 self.sql.push('?');
62 }
63
64 fn take_param(&mut self, value: &mut Value) {
65 self.push_param(std::mem::replace(value, Value::Null));
66 }
67
68 fn fail(&mut self, message: impl Into<String>) {
69 if self.error.is_none() {
70 self.error = Some(Error::InvalidQuery(message.into()));
71 }
72 }
73
74 fn finish(self) -> Result<Sql> {
75 if let Some(err) = self.error {
76 return Err(err);
77 }
78
79 Ok(Sql {
80 text: self.sql,
81 params: self.params,
82 })
83 }
84}
85
86fn render_select_body_owned(ctx: &mut RenderContext, select: &mut crate::Select) {
87 render_select_body_core_mut!(ctx, select, '`', render_expr_owned, false, true);
88}
89
/// Translates a PostgreSQL-style function name into its MySQL equivalent.
///
/// - `json_agg` becomes `JSON_ARRAYAGG`.
/// - `json_build_object` becomes `JSON_OBJECT`.
///
/// Any other name is returned unchanged. Matching ignores ASCII case, which
/// agrees with the FILTER emulation's case-insensitive handling of names.
/// Without this, `JSON_AGG(...)` would be emitted verbatim, and MySQL has no
/// such function.
fn mysql_function_name(name: &str) -> &str {
    if name.eq_ignore_ascii_case("json_agg") {
        "JSON_ARRAYAGG"
    } else if name.eq_ignore_ascii_case("json_build_object") {
        "JSON_OBJECT"
    } else {
        name
    }
}
97
98fn render_case_filtered_aggregate_owned(
99 ctx: &mut RenderContext,
100 fn_name: &str,
101 arg: &mut Expr,
102 predicate: &mut Expr,
103) {
104 ctx.sql.push_str(fn_name);
105 ctx.sql.push_str("(CASE WHEN ");
106 render_expr_owned(ctx, predicate);
107 ctx.sql.push_str(" THEN ");
108 render_expr_owned(ctx, arg);
109 ctx.sql.push_str(" ELSE NULL END)");
110}
111
112fn render_filter_owned(ctx: &mut RenderContext, expr: &mut Expr, predicate: &mut Expr) {
113 let Expr::FunctionCall { name, args } = expr else {
114 ctx.fail("MysqlDialect can only emulate FILTER for aggregate function calls");
115 return;
116 };
117
118 let upper = name.to_ascii_uppercase();
119 match (upper.as_str(), args.as_mut_slice()) {
120 ("COUNT", [Expr::Star]) => {
121 ctx.sql.push_str("COUNT(CASE WHEN ");
122 render_expr_owned(ctx, predicate);
123 ctx.sql.push_str(" THEN 1 ELSE NULL END)");
124 }
125 ("COUNT", [arg]) | ("SUM", [arg]) | ("AVG", [arg]) | ("MIN", [arg]) | ("MAX", [arg]) => {
126 render_case_filtered_aggregate_owned(ctx, upper.as_str(), arg, predicate);
127 }
128 ("JSON_AGG", [_]) => {
129 ctx.fail(
130 "MysqlDialect cannot emulate FILTER for json_agg without changing JSON null semantics",
131 );
132 }
133 (_, []) => {
134 ctx.fail(format!(
135 "MysqlDialect cannot emulate FILTER for function '{}' with zero arguments",
136 name
137 ));
138 }
139 _ => {
140 ctx.fail(format!(
141 "MysqlDialect cannot emulate FILTER for function '{}' with {} arguments",
142 name,
143 args.len()
144 ));
145 }
146 }
147}
148
149fn render_json_extract_unquoted(ctx: &mut RenderContext, table: &str, column: &str, key: &str) {
150 ctx.sql.push_str("JSON_UNQUOTE(JSON_EXTRACT(");
151 crate::push_qualified_identifier(&mut ctx.sql, table, column, '`');
152 ctx.sql.push_str(", ");
153 crate::push_json_object_path_literal(&mut ctx.sql, key);
154 ctx.sql.push_str("))");
155}
156
157fn render_composite_field_owned(
158 ctx: &mut RenderContext,
159 table: &str,
160 column: &str,
161 key: &str,
162 cast: JsonPathCast,
163) {
164 match cast {
165 JsonPathCast::None => render_json_extract_unquoted(ctx, table, column, key),
166 JsonPathCast::Signed => {
167 ctx.sql.push_str("CAST(");
168 render_json_extract_unquoted(ctx, table, column, key);
169 ctx.sql.push_str(" AS SIGNED)");
170 }
171 JsonPathCast::Double => {
172 ctx.sql.push_str("CAST(");
173 render_json_extract_unquoted(ctx, table, column, key);
174 ctx.sql.push_str(" AS DOUBLE)");
175 }
176 JsonPathCast::Decimal => {
177 ctx.sql.push_str("CAST(");
178 render_json_extract_unquoted(ctx, table, column, key);
179 ctx.sql.push_str(" AS DECIMAL(65, 30))");
180 }
181 }
182}
183
184fn render_expr_owned(ctx: &mut RenderContext, expr: &mut Expr) {
185 if ctx.error.is_some() {
186 return;
187 }
188
189 render_expr_common_mut!(ctx, expr, '`', render_expr_owned, render_select_body_owned, {
190 Expr::CompositeField {
191 table,
192 column,
193 json_key,
194 json_cast,
195 ..
196 } => {
197 render_composite_field_owned(ctx, table, column, json_key, *json_cast);
198 }
199 Expr::Param(value) => {
200 if matches!(value, Value::Null) {
201 ctx.sql.push_str("NULL");
202 } else {
203 ctx.take_param(value);
204 }
205 }
206 Expr::Binary { left, op, right } => {
207 if matches!(*op, BinaryOp::In | BinaryOp::NotIn) {
208 ctx.sql.push('(');
209 render_expr_owned(ctx, left.as_mut());
210 ctx.sql.push(' ');
211 ctx.sql
212 .push_str(if matches!(*op, BinaryOp::In) { "IN" } else { "NOT IN" });
213 ctx.sql.push_str(" (");
214 if let Expr::List(exprs) = right.as_mut() {
215 for (i, e) in exprs.iter_mut().enumerate() {
216 if i > 0 {
217 ctx.sql.push_str(", ");
218 }
219 render_expr_owned(ctx, e);
220 }
221 } else {
222 render_expr_owned(ctx, right.as_mut());
223 }
224 ctx.sql.push(')');
225 ctx.sql.push(')');
226 } else if matches!(
227 *op,
228 BinaryOp::ArrayContains | BinaryOp::ArrayContainedBy | BinaryOp::ArrayOverlaps
229 ) {
230 match *op {
231 BinaryOp::ArrayContains => {
232 ctx.sql.push_str("JSON_CONTAINS(");
233 render_expr_owned(ctx, left.as_mut());
234 ctx.sql.push_str(", ");
235 render_expr_owned(ctx, right.as_mut());
236 ctx.sql.push(')');
237 }
238 BinaryOp::ArrayContainedBy => {
239 ctx.sql.push_str("JSON_CONTAINS(");
240 render_expr_owned(ctx, right.as_mut());
241 ctx.sql.push_str(", ");
242 render_expr_owned(ctx, left.as_mut());
243 ctx.sql.push(')');
244 }
245 BinaryOp::ArrayOverlaps => {
246 ctx.fail(
247 "MysqlDialect does not render ArrayOverlaps generically because JSON_OVERLAPS is unavailable on some supported MySQL-family backends",
248 );
249 }
250 _ => unreachable!(),
251 }
252 } else {
253 ctx.sql.push('(');
254 render_expr_owned(ctx, left.as_mut());
255 ctx.sql.push(' ');
256 ctx.sql.push_str(crate::binary_op_sql(op));
257 ctx.sql.push(' ');
258 render_expr_owned(ctx, right.as_mut());
259 ctx.sql.push(')');
260 }
261 }
262 Expr::FunctionCall { name, args } => {
263 let mysql_name = mysql_function_name(name);
264 ctx.sql.push_str(mysql_name);
265 ctx.sql.push('(');
266 for (i, arg) in args.iter_mut().enumerate() {
267 if i > 0 {
268 ctx.sql.push_str(", ");
269 }
270 render_expr_owned(ctx, arg);
271 }
272 ctx.sql.push(')');
273 }
274 Expr::Filter { expr, predicate } => {
275 render_filter_owned(ctx, expr.as_mut(), predicate.as_mut());
276 }
277 });
278}
279
280#[cfg(test)]
281mod tests {
282 use super::*;
283
284 fn quote_identifier(name: &str) -> String {
285 let mut sql = String::new();
286 crate::push_quoted_identifier(&mut sql, name, '`');
287 sql
288 }
289
290 #[test]
291 fn test_quote_identifier() {
292 assert_eq!(quote_identifier("users"), "`users`");
293 assert_eq!(quote_identifier("email"), "`email`");
294 assert_eq!(quote_identifier("foo`bar"), "`foo``bar`");
295 assert_eq!(quote_identifier("a`b`c"), "`a``b``c`");
296 }
297
298 #[test]
299 fn test_skip_without_take() {
300 let dialect = MysqlDialect;
301 let select = Select::from_table("users").skip(20).build().unwrap();
302 let sql = dialect.render_select(&select).unwrap();
303
304 assert_eq!(
305 sql.text,
306 "SELECT * FROM `users` LIMIT 18446744073709551615 OFFSET 20"
307 );
308 assert!(sql.params.is_empty());
309 }
310
311 #[test]
312 fn test_insert_returning_is_omitted() {
313 let dialect = MysqlDialect;
314 let insert = Insert::into_table("users")
315 .column(nautilus_core::ColumnMarker::new("users", "email"))
316 .values(vec![Value::String("alice@example.com".to_string())])
317 .returning(vec![
318 nautilus_core::ColumnMarker::new("users", "id"),
319 nautilus_core::ColumnMarker::new("users", "email"),
320 ])
321 .build()
322 .unwrap();
323 let sql = dialect.render_insert(&insert).unwrap();
324
325 assert_eq!(sql.text, "INSERT INTO `users` (`email`) VALUES (?)");
326 assert!(!sql.text.contains("RETURNING"));
327 }
328
329 #[test]
330 fn test_update_returning_is_omitted() {
331 let dialect = MysqlDialect;
332 let update = Update::table("users")
333 .set(
334 nautilus_core::ColumnMarker::new("users", "email"),
335 Value::String("new@example.com".to_string()),
336 )
337 .filter(Expr::column("id").eq(Expr::param(Value::I64(1))))
338 .returning(vec![
339 nautilus_core::ColumnMarker::new("users", "id"),
340 nautilus_core::ColumnMarker::new("users", "email"),
341 ])
342 .build()
343 .unwrap();
344 let sql = dialect.render_update(&update).unwrap();
345
346 assert_eq!(sql.text, "UPDATE `users` SET `email` = ? WHERE (`id` = ?)");
347 assert!(!sql.text.contains("RETURNING"));
348 }
349
350 #[test]
351 fn test_delete_returning_is_omitted() {
352 let dialect = MysqlDialect;
353 let delete = Delete::from_table("users")
354 .filter(Expr::column("id").eq(Expr::param(Value::I64(1))))
355 .returning(vec![
356 nautilus_core::ColumnMarker::new("users", "id"),
357 nautilus_core::ColumnMarker::new("users", "email"),
358 ])
359 .build()
360 .unwrap();
361 let sql = dialect.render_delete(&delete).unwrap();
362
363 assert_eq!(sql.text, "DELETE FROM `users` WHERE (`id` = ?)");
364 assert!(!sql.text.contains("RETURNING"));
365 }
366
367 #[test]
368 fn test_filter_count_star_is_emulated() {
369 let dialect = MysqlDialect;
370 let select = Select::from_table("users")
371 .computed(
372 Expr::function_call("COUNT", vec![Expr::star()])
373 .filter(Expr::column("active").eq(Expr::param(Value::Bool(true)))),
374 "active_count",
375 )
376 .build()
377 .unwrap();
378
379 let sql = dialect.render_select(&select).unwrap();
380
381 assert_eq!(
382 sql.text,
383 "SELECT (COUNT(CASE WHEN (`active` = ?) THEN 1 ELSE NULL END)) AS `active_count` FROM `users`"
384 );
385 assert_eq!(sql.params, vec![Value::Bool(true)]);
386 }
387
388 #[test]
389 fn test_filter_single_arg_aggregate_is_emulated() {
390 let dialect = MysqlDialect;
391 let select = Select::from_table("users")
392 .computed(
393 Expr::function_call("SUM", vec![Expr::column("score")])
394 .filter(Expr::column("active").eq(Expr::param(Value::Bool(true)))),
395 "active_score",
396 )
397 .build()
398 .unwrap();
399
400 let sql = dialect.render_select(&select).unwrap();
401
402 assert_eq!(
403 sql.text,
404 "SELECT (SUM(CASE WHEN (`active` = ?) THEN `score` ELSE NULL END)) AS `active_score` FROM `users`"
405 );
406 assert_eq!(sql.params, vec![Value::Bool(true)]);
407 }
408
409 #[test]
410 fn test_filter_multi_arg_function_is_rejected() {
411 let dialect = MysqlDialect;
412 let select = Select::from_table("users")
413 .computed(
414 Expr::function_call(
415 "json_build_object",
416 vec![
417 Expr::Literal(nautilus_core::LiteralSql::from_static("score")),
418 Expr::column("score"),
419 ],
420 )
421 .filter(Expr::column("active").eq(Expr::param(Value::Bool(true)))),
422 "payload",
423 )
424 .build()
425 .unwrap();
426
427 let err = dialect.render_select(&select).unwrap_err();
428 assert!(err
429 .to_string()
430 .contains("cannot emulate FILTER for function 'json_build_object'"));
431 }
432
433 #[test]
434 fn test_array_overlaps_is_rejected() {
435 let dialect = MysqlDialect;
436 let expr = Expr::Binary {
437 left: Box::new(Expr::column("posts__tags")),
438 op: BinaryOp::ArrayOverlaps,
439 right: Box::new(Expr::param(Value::Array(vec![Value::String(
440 "rust".to_string(),
441 )]))),
442 };
443 let select = Select::from_table("posts").filter(expr).build().unwrap();
444
445 let err = dialect.render_select(&select).unwrap_err();
446 assert!(err.to_string().contains("ArrayOverlaps generically"));
447 }
448
449 #[test]
450 fn composite_field_ordering_uses_json_extract_with_numeric_cast() {
451 let dialect = MysqlDialect;
452 let select = Select::from_table("shipments")
453 .order_by_expr(
454 Expr::composite_field(
455 "shipments",
456 "delivery_snapshot",
457 "eta_minutes",
458 "etaMinutes",
459 JsonPathCast::Signed,
460 ),
461 nautilus_core::OrderDir::Asc,
462 )
463 .build()
464 .unwrap();
465 let sql = dialect.render_select(&select).unwrap();
466
467 assert_eq!(
468 sql.text,
469 "SELECT * FROM `shipments` ORDER BY CAST(JSON_UNQUOTE(JSON_EXTRACT(`shipments`.`delivery_snapshot`, '$.\"etaMinutes\"')) AS SIGNED) ASC"
470 );
471 }
472}