1use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Error, Expr, Insert, Result, Select, Update, Value};
5
/// SQL dialect that renders queries for MySQL.
///
/// Identifiers are quoted with backticks (see `quote_identifier`) and bind
/// parameters are emitted as positional `?` placeholders (see
/// `RenderContext::push_param`). This dialect reports that `RETURNING` is not
/// supported; the unit tests assert that `RETURNING` columns requested on
/// INSERT/UPDATE/DELETE are dropped from the rendered SQL.
#[derive(Debug, Clone, Copy)]
pub struct MysqlDialect;

impl Dialect for MysqlDialect {
    /// Always `false`: this dialect never renders a `RETURNING` clause.
    fn supports_returning(&self) -> bool {
        false
    }

    /// Renders a `SELECT` statement into SQL text plus its bound parameters.
    ///
    /// Returns the first error recorded while rendering (e.g. an unsupported
    /// `FILTER` or `ArrayOverlaps` expression), if any.
    fn render_select(&self, select: &Select) -> Result<Sql> {
        let mut ctx = RenderContext::new();
        // NOTE(review): the trailing `false, true` flags are interpreted by
        // `render_select_body_core!` (defined elsewhere) — confirm their meaning
        // there. They must stay in sync with the free function `render_select_body`,
        // which renders nested selects with the same flags.
        render_select_body_core!(&mut ctx, select, quote_identifier, render_expr, false, true);
        ctx.finish()
    }

    /// Renders an `INSERT` statement into SQL text plus its bound parameters.
    fn render_insert(&self, insert: &Insert) -> Result<Sql> {
        let mut ctx = RenderContext::new();
        // Presumably the `false` flags suppress MySQL-unsupported clauses such as
        // `RETURNING` (the tests assert it is omitted) — TODO confirm in the macro.
        render_insert_body!(&mut ctx, insert, quote_identifier, false, false);
        ctx.finish()
    }

    /// Renders an `UPDATE` statement into SQL text plus its bound parameters.
    fn render_update(&self, update: &Update) -> Result<Sql> {
        let mut ctx = RenderContext::new();
        // NOTE(review): flag semantics live in `render_update_body!`; the tests
        // show `RETURNING` is omitted for this dialect.
        render_update_body!(
            &mut ctx,
            update,
            quote_identifier,
            render_expr,
            false,
            false
        );
        ctx.finish()
    }

    /// Renders a `DELETE` statement into SQL text plus its bound parameters.
    fn render_delete(&self, delete: &Delete) -> Result<Sql> {
        let mut ctx = RenderContext::new();
        // NOTE(review): flag semantics live in `render_delete_body!`; the tests
        // show `RETURNING` is omitted for this dialect.
        render_delete_body!(&mut ctx, delete, quote_identifier, render_expr, false);
        ctx.finish()
    }
}
48
/// Quotes `name` as a MySQL identifier.
///
/// Delegates to `crate::backtick_quote_identifier`. As exercised by
/// `test_quote_identifier`, the name is wrapped in backticks and any embedded
/// backtick is escaped by doubling it (`foo`bar` -> `` `foo``bar` ``).
fn quote_identifier(name: &str) -> String {
    crate::backtick_quote_identifier(name)
}
52
53struct RenderContext {
54 sql: String,
55 params: Vec<Value>,
56 error: Option<Error>,
57}
58
59impl RenderContext {
60 fn new() -> Self {
61 Self {
62 sql: String::new(),
63 params: Vec::new(),
64 error: None,
65 }
66 }
67
68 fn push_param(&mut self, value: Value) -> String {
69 self.params.push(value);
70 "?".to_string()
71 }
72
73 fn fail(&mut self, message: impl Into<String>) {
74 if self.error.is_none() {
75 self.error = Some(Error::InvalidQuery(message.into()));
76 }
77 }
78
79 fn finish(self) -> Result<Sql> {
80 if let Some(err) = self.error {
81 return Err(err);
82 }
83
84 Ok(Sql {
85 text: self.sql,
86 params: self.params,
87 })
88 }
89}
90
/// Renders the body of a `SELECT` into an existing `ctx`.
///
/// This is the select renderer handed to `render_expr_common!` from
/// `render_expr`, presumably so nested subqueries share the top-level MySQL
/// settings — NOTE(review): confirm in the macro definition. The macro flags
/// (`false, true`) must stay in sync with `MysqlDialect::render_select`.
fn render_select_body(ctx: &mut RenderContext, select: &crate::Select) {
    render_select_body_core!(ctx, select, quote_identifier, render_expr, false, true);
}
94
/// Translates a function name from the generic (PostgreSQL-style) query AST
/// into its MySQL equivalent.
///
/// Returns `name` unchanged when no translation is known.
///
/// Matching is ASCII case-insensitive. SQL function names are not
/// case-sensitive, so `JSON_AGG(...)` must be rewritten exactly like
/// `json_agg(...)`; this also keeps the lookup consistent with `render_filter`,
/// which compares function names case-insensitively.
fn mysql_function_name(name: &str) -> &str {
    // (generic name, MySQL name). Add new translations here.
    const RENAMES: [(&str, &str); 2] = [
        ("json_agg", "JSON_ARRAYAGG"),
        ("json_build_object", "JSON_OBJECT"),
    ];

    RENAMES
        .iter()
        .find(|(generic, _)| name.eq_ignore_ascii_case(generic))
        .map_or(name, |&(_, mysql)| mysql)
}
102
103fn render_case_filtered_aggregate(
104 ctx: &mut RenderContext,
105 fn_name: &str,
106 arg: &Expr,
107 predicate: &Expr,
108) {
109 ctx.sql.push_str(fn_name);
110 ctx.sql.push_str("(CASE WHEN ");
111 render_expr(ctx, predicate);
112 ctx.sql.push_str(" THEN ");
113 render_expr(ctx, arg);
114 ctx.sql.push_str(" ELSE NULL END)");
115}
116
117fn render_filter(ctx: &mut RenderContext, expr: &Expr, predicate: &Expr) {
118 let Expr::FunctionCall { name, args } = expr else {
119 ctx.fail("MysqlDialect can only emulate FILTER for aggregate function calls");
120 return;
121 };
122
123 let upper = name.to_ascii_uppercase();
124 match (upper.as_str(), args.as_slice()) {
125 ("COUNT", [Expr::Star]) => {
126 ctx.sql.push_str("COUNT(CASE WHEN ");
127 render_expr(ctx, predicate);
128 ctx.sql.push_str(" THEN 1 ELSE NULL END)");
129 }
130 ("COUNT", [arg]) | ("SUM", [arg]) | ("AVG", [arg]) | ("MIN", [arg]) | ("MAX", [arg]) => {
131 render_case_filtered_aggregate(ctx, upper.as_str(), arg, predicate);
132 }
133 ("JSON_AGG", [_]) => {
134 ctx.fail(
135 "MysqlDialect cannot emulate FILTER for json_agg without changing JSON null semantics",
136 );
137 }
138 (_, []) => {
139 ctx.fail(format!(
140 "MysqlDialect cannot emulate FILTER for function '{}' with zero arguments",
141 name
142 ));
143 }
144 _ => {
145 ctx.fail(format!(
146 "MysqlDialect cannot emulate FILTER for function '{}' with {} arguments",
147 name,
148 args.len()
149 ));
150 }
151 }
152}
153
154fn render_expr(ctx: &mut RenderContext, expr: &Expr) {
155 if ctx.error.is_some() {
156 return;
157 }
158
159 render_expr_common!(ctx, expr, quote_identifier, render_expr, render_select_body, {
160 Expr::Param(value) => {
161 if matches!(value, Value::Null) {
162 ctx.sql.push_str("NULL");
163 } else {
164 let placeholder = ctx.push_param(value.clone());
165 ctx.sql.push_str(&placeholder);
166 }
167 }
168 Expr::Binary { left, op, right } => {
169 if matches!(op, BinaryOp::In | BinaryOp::NotIn) {
170 ctx.sql.push('(');
171 render_expr(ctx, left);
172 ctx.sql.push(' ');
173 ctx.sql.push_str(if matches!(op, BinaryOp::In) { "IN" } else { "NOT IN" });
174 ctx.sql.push_str(" (");
175 if let Expr::List(exprs) = right.as_ref() {
176 for (i, e) in exprs.iter().enumerate() {
177 if i > 0 { ctx.sql.push_str(", "); }
178 render_expr(ctx, e);
179 }
180 } else {
181 render_expr(ctx, right);
182 }
183 ctx.sql.push(')');
184 ctx.sql.push(')');
185 } else if matches!(op, BinaryOp::ArrayContains | BinaryOp::ArrayContainedBy | BinaryOp::ArrayOverlaps) {
186 match op {
189 BinaryOp::ArrayContains => {
190 ctx.sql.push_str("JSON_CONTAINS(");
193 render_expr(ctx, left);
194 ctx.sql.push_str(", ");
195 render_expr(ctx, right);
196 ctx.sql.push(')');
197 }
198 BinaryOp::ArrayContainedBy => {
199 ctx.sql.push_str("JSON_CONTAINS(");
201 render_expr(ctx, right);
202 ctx.sql.push_str(", ");
203 render_expr(ctx, left);
204 ctx.sql.push(')');
205 }
206 BinaryOp::ArrayOverlaps => {
207 ctx.fail(
208 "MysqlDialect does not render ArrayOverlaps generically because JSON_OVERLAPS is unavailable on some supported MySQL-family backends",
209 );
210 }
211 _ => unreachable!(),
212 }
213 } else {
214 ctx.sql.push('(');
215 render_expr(ctx, left);
216 ctx.sql.push(' ');
217 ctx.sql.push_str(crate::binary_op_sql(op));
218 ctx.sql.push(' ');
219 render_expr(ctx, right);
220 ctx.sql.push(')');
221 }
222 }
223 Expr::FunctionCall { name, args } => {
224 let mysql_name = mysql_function_name(name);
225 ctx.sql.push_str(mysql_name);
226 ctx.sql.push('(');
227 for (i, arg) in args.iter().enumerate() {
228 if i > 0 { ctx.sql.push_str(", "); }
229 render_expr(ctx, arg);
230 }
231 ctx.sql.push(')');
232 }
233 Expr::Filter { expr, predicate } => {
234 render_filter(ctx, expr, predicate);
235 }
236 });
237}
238
#[cfg(test)]
mod tests {
    use super::*;
    use nautilus_core::ColumnMarker;

    #[test]
    fn test_quote_identifier() {
        // Backticks wrap the name; embedded backticks are escaped by doubling.
        let cases = [
            ("users", "`users`"),
            ("email", "`email`"),
            ("foo`bar", "`foo``bar`"),
            ("a`b`c", "`a``b``c`"),
        ];
        for (raw, expected) in cases {
            assert_eq!(quote_identifier(raw), expected);
        }
    }

    #[test]
    fn test_skip_without_take() {
        let select = Select::from_table("users").skip(20).build().unwrap();

        let sql = MysqlDialect.render_select(&select).unwrap();

        // MySQL needs a LIMIT before OFFSET, so the maximum u64 stands in for "no limit".
        assert_eq!(
            sql.text,
            "SELECT * FROM `users` LIMIT 18446744073709551615 OFFSET 20"
        );
        assert!(sql.params.is_empty());
    }

    #[test]
    fn test_insert_returning_is_omitted() {
        let insert = Insert::into_table("users")
            .column(ColumnMarker::new("users", "email"))
            .values(vec![Value::String("alice@example.com".to_string())])
            .returning(vec![
                ColumnMarker::new("users", "id"),
                ColumnMarker::new("users", "email"),
            ])
            .build()
            .unwrap();

        let sql = MysqlDialect.render_insert(&insert).unwrap();

        assert_eq!(sql.text, "INSERT INTO `users` (`email`) VALUES (?)");
        assert!(!sql.text.contains("RETURNING"));
    }

    #[test]
    fn test_update_returning_is_omitted() {
        let update = Update::table("users")
            .set(
                ColumnMarker::new("users", "email"),
                Value::String("new@example.com".to_string()),
            )
            .filter(Expr::column("id").eq(Expr::param(Value::I64(1))))
            .returning(vec![
                ColumnMarker::new("users", "id"),
                ColumnMarker::new("users", "email"),
            ])
            .build()
            .unwrap();

        let sql = MysqlDialect.render_update(&update).unwrap();

        assert_eq!(sql.text, "UPDATE `users` SET `email` = ? WHERE (`id` = ?)");
        assert!(!sql.text.contains("RETURNING"));
    }

    #[test]
    fn test_delete_returning_is_omitted() {
        let delete = Delete::from_table("users")
            .filter(Expr::column("id").eq(Expr::param(Value::I64(1))))
            .returning(vec![
                ColumnMarker::new("users", "id"),
                ColumnMarker::new("users", "email"),
            ])
            .build()
            .unwrap();

        let sql = MysqlDialect.render_delete(&delete).unwrap();

        assert_eq!(sql.text, "DELETE FROM `users` WHERE (`id` = ?)");
        assert!(!sql.text.contains("RETURNING"));
    }

    #[test]
    fn test_filter_count_star_is_emulated() {
        let active_count = Expr::function_call("COUNT", vec![Expr::star()])
            .filter(Expr::column("active").eq(Expr::param(Value::Bool(true))));
        let select = Select::from_table("users")
            .computed(active_count, "active_count")
            .build()
            .unwrap();

        let sql = MysqlDialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT (COUNT(CASE WHEN (`active` = ?) THEN 1 ELSE NULL END)) AS `active_count` FROM `users`"
        );
        assert_eq!(sql.params, vec![Value::Bool(true)]);
    }

    #[test]
    fn test_filter_single_arg_aggregate_is_emulated() {
        let active_score = Expr::function_call("SUM", vec![Expr::column("score")])
            .filter(Expr::column("active").eq(Expr::param(Value::Bool(true))));
        let select = Select::from_table("users")
            .computed(active_score, "active_score")
            .build()
            .unwrap();

        let sql = MysqlDialect.render_select(&select).unwrap();

        assert_eq!(
            sql.text,
            "SELECT (SUM(CASE WHEN (`active` = ?) THEN `score` ELSE NULL END)) AS `active_score` FROM `users`"
        );
        assert_eq!(sql.params, vec![Value::Bool(true)]);
    }

    #[test]
    fn test_filter_multi_arg_function_is_rejected() {
        let payload = Expr::function_call(
            "json_build_object",
            vec![Expr::Literal("score".to_string()), Expr::column("score")],
        )
        .filter(Expr::column("active").eq(Expr::param(Value::Bool(true))));
        let select = Select::from_table("users")
            .computed(payload, "payload")
            .build()
            .unwrap();

        let message = MysqlDialect
            .render_select(&select)
            .unwrap_err()
            .to_string();

        assert!(
            message.contains("cannot emulate FILTER for function 'json_build_object'"),
            "unexpected error: {}",
            message
        );
    }

    #[test]
    fn test_array_overlaps_is_rejected() {
        let overlaps = Expr::Binary {
            left: Box::new(Expr::column("posts__tags")),
            op: BinaryOp::ArrayOverlaps,
            right: Box::new(Expr::param(Value::Array(vec![Value::String(
                "rust".to_string(),
            )]))),
        };
        let select = Select::from_table("posts").filter(overlaps).build().unwrap();

        let message = MysqlDialect
            .render_select(&select)
            .unwrap_err()
            .to_string();

        assert!(
            message.contains("ArrayOverlaps generically"),
            "unexpected error: {}",
            message
        );
    }
}
398}