1use crate::{Dialect, Sql};
4use nautilus_core::{BinaryOp, Delete, Error, Expr, Insert, Result, Select, Update, Value};
5
/// Stateless SQL dialect that renders queries for MySQL-family databases.
///
/// Being a zero-sized `Copy` unit struct, it can be created anywhere simply as
/// `MysqlDialect`.
#[derive(Clone, Copy, Debug)]
pub struct MysqlDialect;
9
10impl Dialect for MysqlDialect {
13 fn supports_returning(&self) -> bool {
14 false
15 }
16
17 fn render_select_owned(&self, mut select: Select) -> Result<Sql> {
18 let mut ctx = RenderContext::with_estimate(crate::estimate_select_render(&select));
19 render_select_body_core_mut!(&mut ctx, &mut select, '`', render_expr_owned, false, true);
20 ctx.finish()
21 }
22
23 fn render_insert_owned(&self, mut insert: Insert) -> Result<Sql> {
24 let mut ctx = RenderContext::with_estimate(crate::estimate_insert_render(&insert));
25 render_insert_body_mut!(&mut ctx, &mut insert, '`', false, false);
26 ctx.finish()
27 }
28
29 fn render_update_owned(&self, mut update: Update) -> Result<Sql> {
30 let mut ctx = RenderContext::with_estimate(crate::estimate_update_render(&update));
31 render_update_body_mut!(&mut ctx, &mut update, '`', render_expr_owned, false, false);
32 ctx.finish()
33 }
34
35 fn render_delete_owned(&self, mut delete: Delete) -> Result<Sql> {
36 let mut ctx = RenderContext::with_estimate(crate::estimate_delete_render(&delete));
37 render_delete_body_mut!(&mut ctx, &mut delete, '`', render_expr_owned, false);
38 ctx.finish()
39 }
40}
41
42struct RenderContext {
43 sql: String,
44 params: Vec<Value>,
45 error: Option<Error>,
46}
47
48impl RenderContext {
49 fn with_estimate(estimate: crate::RenderEstimate) -> Self {
50 Self {
51 sql: String::with_capacity(estimate.sql_capacity),
52 params: Vec::with_capacity(estimate.params_capacity),
53 error: None,
54 }
55 }
56
57 fn push_param(&mut self, value: Value) {
58 self.params.push(value);
59 self.sql.push('?');
60 }
61
62 fn take_param(&mut self, value: &mut Value) {
63 self.push_param(std::mem::replace(value, Value::Null));
64 }
65
66 fn fail(&mut self, message: impl Into<String>) {
67 if self.error.is_none() {
68 self.error = Some(Error::InvalidQuery(message.into()));
69 }
70 }
71
72 fn finish(self) -> Result<Sql> {
73 if let Some(err) = self.error {
74 return Err(err);
75 }
76
77 Ok(Sql {
78 text: self.sql,
79 params: self.params,
80 })
81 }
82}
83
84fn render_select_body_owned(ctx: &mut RenderContext, select: &mut crate::Select) {
85 render_select_body_core_mut!(ctx, select, '`', render_expr_owned, false, true);
86}
87
/// Maps a builder function name to the spelling MySQL understands.
///
/// The PostgreSQL names `json_agg` and `json_build_object` become
/// `JSON_ARRAYAGG` and `JSON_OBJECT`; every other name is returned unchanged.
///
/// SQL function names are case-insensitive, and `render_filter_owned` already
/// recognises names regardless of case. Matching here therefore ignores ASCII
/// case too, so `JSON_AGG(...)` is translated instead of being emitted as a
/// function MySQL does not have.
fn mysql_function_name(name: &str) -> &str {
    // (builder name, MySQL name) pairs.
    const RENAMES: [(&str, &str); 2] = [
        ("json_agg", "JSON_ARRAYAGG"),
        ("json_build_object", "JSON_OBJECT"),
    ];

    RENAMES
        .iter()
        .find(|(from, _)| name.eq_ignore_ascii_case(from))
        .map_or(name, |&(_, to)| to)
}
95
96fn render_case_filtered_aggregate_owned(
97 ctx: &mut RenderContext,
98 fn_name: &str,
99 arg: &mut Expr,
100 predicate: &mut Expr,
101) {
102 ctx.sql.push_str(fn_name);
103 ctx.sql.push_str("(CASE WHEN ");
104 render_expr_owned(ctx, predicate);
105 ctx.sql.push_str(" THEN ");
106 render_expr_owned(ctx, arg);
107 ctx.sql.push_str(" ELSE NULL END)");
108}
109
110fn render_filter_owned(ctx: &mut RenderContext, expr: &mut Expr, predicate: &mut Expr) {
111 let Expr::FunctionCall { name, args } = expr else {
112 ctx.fail("MysqlDialect can only emulate FILTER for aggregate function calls");
113 return;
114 };
115
116 let upper = name.to_ascii_uppercase();
117 match (upper.as_str(), args.as_mut_slice()) {
118 ("COUNT", [Expr::Star]) => {
119 ctx.sql.push_str("COUNT(CASE WHEN ");
120 render_expr_owned(ctx, predicate);
121 ctx.sql.push_str(" THEN 1 ELSE NULL END)");
122 }
123 ("COUNT", [arg]) | ("SUM", [arg]) | ("AVG", [arg]) | ("MIN", [arg]) | ("MAX", [arg]) => {
124 render_case_filtered_aggregate_owned(ctx, upper.as_str(), arg, predicate);
125 }
126 ("JSON_AGG", [_]) => {
127 ctx.fail(
128 "MysqlDialect cannot emulate FILTER for json_agg without changing JSON null semantics",
129 );
130 }
131 (_, []) => {
132 ctx.fail(format!(
133 "MysqlDialect cannot emulate FILTER for function '{}' with zero arguments",
134 name
135 ));
136 }
137 _ => {
138 ctx.fail(format!(
139 "MysqlDialect cannot emulate FILTER for function '{}' with {} arguments",
140 name,
141 args.len()
142 ));
143 }
144 }
145}
146
147fn render_expr_owned(ctx: &mut RenderContext, expr: &mut Expr) {
148 if ctx.error.is_some() {
149 return;
150 }
151
152 render_expr_common_mut!(ctx, expr, '`', render_expr_owned, render_select_body_owned, {
153 Expr::Param(value) => {
154 if matches!(value, Value::Null) {
155 ctx.sql.push_str("NULL");
156 } else {
157 ctx.take_param(value);
158 }
159 }
160 Expr::Binary { left, op, right } => {
161 if matches!(*op, BinaryOp::In | BinaryOp::NotIn) {
162 ctx.sql.push('(');
163 render_expr_owned(ctx, left.as_mut());
164 ctx.sql.push(' ');
165 ctx.sql
166 .push_str(if matches!(*op, BinaryOp::In) { "IN" } else { "NOT IN" });
167 ctx.sql.push_str(" (");
168 if let Expr::List(exprs) = right.as_mut() {
169 for (i, e) in exprs.iter_mut().enumerate() {
170 if i > 0 {
171 ctx.sql.push_str(", ");
172 }
173 render_expr_owned(ctx, e);
174 }
175 } else {
176 render_expr_owned(ctx, right.as_mut());
177 }
178 ctx.sql.push(')');
179 ctx.sql.push(')');
180 } else if matches!(
181 *op,
182 BinaryOp::ArrayContains | BinaryOp::ArrayContainedBy | BinaryOp::ArrayOverlaps
183 ) {
184 match *op {
185 BinaryOp::ArrayContains => {
186 ctx.sql.push_str("JSON_CONTAINS(");
187 render_expr_owned(ctx, left.as_mut());
188 ctx.sql.push_str(", ");
189 render_expr_owned(ctx, right.as_mut());
190 ctx.sql.push(')');
191 }
192 BinaryOp::ArrayContainedBy => {
193 ctx.sql.push_str("JSON_CONTAINS(");
194 render_expr_owned(ctx, right.as_mut());
195 ctx.sql.push_str(", ");
196 render_expr_owned(ctx, left.as_mut());
197 ctx.sql.push(')');
198 }
199 BinaryOp::ArrayOverlaps => {
200 ctx.fail(
201 "MysqlDialect does not render ArrayOverlaps generically because JSON_OVERLAPS is unavailable on some supported MySQL-family backends",
202 );
203 }
204 _ => unreachable!(),
205 }
206 } else {
207 ctx.sql.push('(');
208 render_expr_owned(ctx, left.as_mut());
209 ctx.sql.push(' ');
210 ctx.sql.push_str(crate::binary_op_sql(op));
211 ctx.sql.push(' ');
212 render_expr_owned(ctx, right.as_mut());
213 ctx.sql.push(')');
214 }
215 }
216 Expr::FunctionCall { name, args } => {
217 let mysql_name = mysql_function_name(name);
218 ctx.sql.push_str(mysql_name);
219 ctx.sql.push('(');
220 for (i, arg) in args.iter_mut().enumerate() {
221 if i > 0 {
222 ctx.sql.push_str(", ");
223 }
224 render_expr_owned(ctx, arg);
225 }
226 ctx.sql.push(')');
227 }
228 Expr::Filter { expr, predicate } => {
229 render_filter_owned(ctx, expr.as_mut(), predicate.as_mut());
230 }
231 });
232}
233
#[cfg(test)]
mod tests {
    use super::*;
    use nautilus_core::ColumnMarker;

    /// Quotes `name` with backticks using the crate's shared quoting helper.
    fn quote_identifier(name: &str) -> String {
        let mut quoted = String::new();
        crate::push_quoted_identifier(&mut quoted, name, '`');
        quoted
    }

    #[test]
    fn test_quote_identifier() {
        // Plain names are wrapped; embedded backticks are escaped by doubling them.
        let cases = [
            ("users", "`users`"),
            ("email", "`email`"),
            ("foo`bar", "`foo``bar`"),
            ("a`b`c", "`a``b``c`"),
        ];
        for &(raw, expected) in cases.iter() {
            assert_eq!(quote_identifier(raw), expected);
        }
    }

    #[test]
    fn test_skip_without_take() {
        let select = Select::from_table("users").skip(20).build().unwrap();
        let rendered = MysqlDialect.render_select(&select).unwrap();

        // MySQL needs a LIMIT before OFFSET, so u64::MAX stands in for "no limit".
        assert!(rendered.params.is_empty());
        assert_eq!(
            rendered.text,
            "SELECT * FROM `users` LIMIT 18446744073709551615 OFFSET 20"
        );
    }

    #[test]
    fn test_insert_returning_is_omitted() {
        let insert = Insert::into_table("users")
            .column(ColumnMarker::new("users", "email"))
            .values(vec![Value::String("alice@example.com".to_string())])
            .returning(vec![
                ColumnMarker::new("users", "id"),
                ColumnMarker::new("users", "email"),
            ])
            .build()
            .unwrap();

        let rendered = MysqlDialect.render_insert(&insert).unwrap();

        assert!(!rendered.text.contains("RETURNING"));
        assert_eq!(rendered.text, "INSERT INTO `users` (`email`) VALUES (?)");
    }

    #[test]
    fn test_update_returning_is_omitted() {
        let update = Update::table("users")
            .set(
                ColumnMarker::new("users", "email"),
                Value::String("new@example.com".to_string()),
            )
            .filter(Expr::column("id").eq(Expr::param(Value::I64(1))))
            .returning(vec![
                ColumnMarker::new("users", "id"),
                ColumnMarker::new("users", "email"),
            ])
            .build()
            .unwrap();

        let rendered = MysqlDialect.render_update(&update).unwrap();

        assert!(!rendered.text.contains("RETURNING"));
        assert_eq!(rendered.text, "UPDATE `users` SET `email` = ? WHERE (`id` = ?)");
    }

    #[test]
    fn test_delete_returning_is_omitted() {
        let delete = Delete::from_table("users")
            .filter(Expr::column("id").eq(Expr::param(Value::I64(1))))
            .returning(vec![
                ColumnMarker::new("users", "id"),
                ColumnMarker::new("users", "email"),
            ])
            .build()
            .unwrap();

        let rendered = MysqlDialect.render_delete(&delete).unwrap();

        assert!(!rendered.text.contains("RETURNING"));
        assert_eq!(rendered.text, "DELETE FROM `users` WHERE (`id` = ?)");
    }

    #[test]
    fn test_filter_count_star_is_emulated() {
        let select = Select::from_table("users")
            .computed(
                Expr::function_call("COUNT", vec![Expr::star()])
                    .filter(Expr::column("active").eq(Expr::param(Value::Bool(true)))),
                "active_count",
            )
            .build()
            .unwrap();

        let rendered = MysqlDialect.render_select(&select).unwrap();

        assert_eq!(rendered.params, vec![Value::Bool(true)]);
        assert_eq!(
            rendered.text,
            "SELECT (COUNT(CASE WHEN (`active` = ?) THEN 1 ELSE NULL END)) AS `active_count` FROM `users`"
        );
    }

    #[test]
    fn test_filter_single_arg_aggregate_is_emulated() {
        let select = Select::from_table("users")
            .computed(
                Expr::function_call("SUM", vec![Expr::column("score")])
                    .filter(Expr::column("active").eq(Expr::param(Value::Bool(true)))),
                "active_score",
            )
            .build()
            .unwrap();

        let rendered = MysqlDialect.render_select(&select).unwrap();

        assert_eq!(rendered.params, vec![Value::Bool(true)]);
        assert_eq!(
            rendered.text,
            "SELECT (SUM(CASE WHEN (`active` = ?) THEN `score` ELSE NULL END)) AS `active_score` FROM `users`"
        );
    }

    #[test]
    fn test_filter_multi_arg_function_is_rejected() {
        let payload = Expr::function_call(
            "json_build_object",
            vec![
                Expr::Literal(nautilus_core::LiteralSql::from_static("score")),
                Expr::column("score"),
            ],
        )
        .filter(Expr::column("active").eq(Expr::param(Value::Bool(true))));
        let select = Select::from_table("users")
            .computed(payload, "payload")
            .build()
            .unwrap();

        let message = MysqlDialect.render_select(&select).unwrap_err().to_string();

        assert!(message.contains("cannot emulate FILTER for function 'json_build_object'"));
    }

    #[test]
    fn test_array_overlaps_is_rejected() {
        let tags = Value::Array(vec![Value::String("rust".to_string())]);
        let overlaps = Expr::Binary {
            left: Box::new(Expr::column("posts__tags")),
            op: BinaryOp::ArrayOverlaps,
            right: Box::new(Expr::param(tags)),
        };
        let select = Select::from_table("posts").filter(overlaps).build().unwrap();

        let message = MysqlDialect.render_select(&select).unwrap_err().to_string();

        assert!(message.contains("ArrayOverlaps generically"));
    }
}