1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
use toasty_core::{
driver::{Rows, operation},
stmt,
};
use crate::{
Result,
engine::exec::{Action, Exec, Output, VarId},
};
/// Information about a MySQL INSERT with RETURNING that needs special handling.
///
/// MySQL doesn't support RETURNING clauses, but we can work around this for
/// auto-increment columns by using LAST_INSERT_ID().
///
/// Produced by `Exec::process_stmt_insert_with_returning_on_mysql`, which strips
/// the RETURNING clause from the statement. After the driver has returned the
/// generated IDs, it is consumed by `MySQLInsertReturning::reconstruct_returning`.
#[derive(Debug)]
struct MySQLInsertReturning {
    /// Number of rows being inserted, i.e. the number of rows in the INSERT's
    /// `VALUES` source. It is passed to the driver as `last_insert_id_hack`.
    num_rows: u64,
    /// The original returning expression that was removed from the statement.
    /// Its single column reference is rewritten to argument position 0, so the
    /// generated ID can be bound in its place.
    returning_expr: stmt::Expr,
    /// The type of the auto-increment column. IDs returned by the driver are
    /// cast to this type before `returning_expr` is evaluated.
    auto_column_type: stmt::Type,
}
/// An action that executes a single statement on the database connection and
/// stores the result in an output variable.
#[derive(Debug)]
pub(crate) struct ExecStatement {
    /// Where to get arguments for this action.
    ///
    /// Each variable is loaded and collected into a value. The values are
    /// substituted into `stmt`, in this order, before execution.
    pub input: Vec<VarId>,
    /// How to handle output
    pub output: ExecStatementOutput,
    /// The query to execute. This may require input to generate the query.
    pub stmt: stmt::Statement,
    /// When true, the statement is a conditional update without any returning.
    ///
    /// The driver is asked to return one row of two `I64` values. If they
    /// differ, the action fails with a `condition_failed` error. Otherwise the
    /// first value is stored as the `Rows::Count` result.
    pub conditional_update_with_no_returning: bool,
}
/// Describes the result of an `ExecStatement` and where it is stored.
#[derive(Debug)]
pub(crate) struct ExecStatementOutput {
    /// Databases always return rows as a vec of values. This specifies the type
    /// of each value.
    ///
    /// This is normally passed to the driver as the expected row types.
    /// Conditional updates and emulated MySQL RETURNING override it.
    /// When `None`, a statically empty query is reported as `Rows::Count(0)`
    /// rather than as an empty row stream.
    pub ty: Option<Vec<stmt::Type>>,
    /// The variable that receives the result rows, together with its number of
    /// uses.
    pub output: Output,
}
impl Exec<'_> {
    /// Executes an [`ExecStatement`] action and stores its result rows in the
    /// action's output variable.
    ///
    /// The statement is cloned. Input variables are loaded and substituted into
    /// the clone, and the result is re-simplified. If the database does not
    /// support RETURNING on mutations, an INSERT's RETURNING clause is stripped
    /// and later rebuilt from the auto-increment IDs fetched by the driver.
    /// Queries whose body is a statically empty `VALUES` list are answered
    /// without contacting the database.
    ///
    /// # Errors
    ///
    /// - Propagates errors from loading inputs and from the driver.
    /// - Returns a `condition_failed` error when a conditional update's
    ///   condition did not match.
    /// - Returns an `invalid_result` error when the driver's response does not
    ///   have the expected shape.
    ///
    /// # Panics
    ///
    /// Panics if the MySQL RETURNING emulation receives an INSERT it cannot
    /// handle (see `process_stmt_insert_with_returning_on_mysql`).
    pub(super) async fn action_exec_statement(&mut self, action: &ExecStatement) -> Result<()> {
        let mut stmt = action.stmt.clone();

        // Collect input values and substitute them into the statement.
        if !action.input.is_empty() {
            let mut input_values = Vec::with_capacity(action.input.len());
            for var_id in &action.input {
                let values = self.vars.load(*var_id).await?.collect_as_value().await?;
                input_values.push(values);
            }
            stmt.substitute(&input_values);
            self.engine.simplify_stmt(&mut stmt);
        }

        debug_assert!(
            stmt.returning()
                .and_then(|returning| returning.as_expr())
                .map(|expr| expr.is_record())
                .unwrap_or(true),
            "stmt={stmt:#?}"
        );

        // MySQL does not support returning clauses with insert statements,
        // which adds a wrinkle when we want to get the IDs for autoincrement
        // IDs. This strips the RETURNING clause and records how to rebuild it.
        let mysql_insert_returning = self.process_stmt_insert_with_returning_on_mysql(&mut stmt);

        // Short circuit if we can statically determine there are no results.
        if let stmt::Statement::Query(query) = &stmt
            && let stmt::ExprSet::Values(values) = &query.body
            && values.is_empty()
        {
            // Conditional updates are not expected to take this path.
            assert!(!action.conditional_update_with_no_returning);
            let rows = if action.output.ty.is_some() {
                Rows::Stream(stmt::ValueStream::default())
            } else {
                Rows::Count(0)
            };
            self.vars.store(
                action.output.output.var,
                action.output.output.num_uses,
                rows,
            );
            return Ok(());
        }

        let op = operation::QuerySql {
            stmt,
            ret: if action.conditional_update_with_no_returning {
                Some(vec![stmt::Type::I64, stmt::Type::I64])
            } else if mysql_insert_returning.is_some() {
                // For MySQL INSERT with RETURNING, we don't send RETURNING to the database
                // (it doesn't support it). The driver will fetch auto-increment IDs using LAST_INSERT_ID().
                None
            } else {
                action.output.ty.clone()
            },
            last_insert_id_hack: mysql_insert_returning.as_ref().map(|info| info.num_rows),
        };

        let mut res = self.connection.exec(&self.engine.schema, op.into()).await?;

        if action.conditional_update_with_no_returning {
            let Rows::Stream(rows) = res.rows else {
                return Err(toasty_core::Error::invalid_result(format!(
                    "conditional update expected Stream, got {:?}",
                    res.rows
                )));
            };
            let rows = rows.collect().await?;

            // The driver was asked (via `ret` above) for a single row of two
            // `I64` values. A different shape is a bad driver result, so it is
            // reported as an error instead of panicking.
            if rows.len() != 1 {
                return Err(toasty_core::Error::invalid_result(format!(
                    "conditional update expected exactly 1 row, got {}",
                    rows.len()
                )));
            }
            let stmt::Value::Record(record) = &rows[0] else {
                return Err(toasty_core::Error::invalid_result(format!(
                    "conditional update expected Record, got {:?}",
                    rows[0]
                )));
            };
            if record.len() != 2 {
                return Err(toasty_core::Error::invalid_result(format!(
                    "conditional update expected Record with 2 fields, got {}",
                    record.len()
                )));
            }

            // The two values must agree for the update's condition to count as
            // satisfied.
            if record[0] != record[1] {
                return Err(toasty_core::Error::condition_failed(
                    "update condition did not match",
                ));
            }
            res.rows = Rows::Count(record[0].to_u64_unwrap());
        } else if let Some(mysql_info) = mysql_insert_returning {
            res.rows = mysql_info.reconstruct_returning(res.rows).await?;
        }

        self.vars.store(
            action.output.output.var,
            action.output.output.num_uses,
            res.rows,
        );
        Ok(())
    }

    /// Processes INSERT statements with RETURNING on MySQL, which doesn't support RETURNING.
    ///
    /// This only applies when the engine's capability lacks
    /// `returning_from_mutation`. In that case the RETURNING clause is removed
    /// from `stmt`, and the information needed to rebuild its results from
    /// LAST_INSERT_ID() is returned. It returns `None` when the capability is
    /// present, when `stmt` is not an INSERT, or when the INSERT has no
    /// RETURNING clause.
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - the RETURNING clause references a non-auto-increment column (MySQL
    ///   doesn't support RETURNING, and the workaround only covers
    ///   auto-increment columns);
    /// - the clause does not contain exactly one column reference;
    /// - the RETURNING clause is not a `Returning::Expr`;
    /// - the INSERT source is not a `VALUES` list.
    fn process_stmt_insert_with_returning_on_mysql(
        &self,
        stmt: &mut stmt::Statement,
    ) -> Option<MySQLInsertReturning> {
        if self.engine.capability().returning_from_mutation {
            return None;
        }

        let stmt::Statement::Insert(insert) = stmt else {
            return None;
        };

        let returning = insert.returning.take()?;

        // Verify that all columns in the RETURNING clause are auto-increment columns.
        // This is required because MySQL doesn't support RETURNING, but we can work around
        // this limitation for auto-increment columns by using LAST_INSERT_ID().
        let cx = self.engine.expr_cx_for(&*insert);
        let mut ref_count = 0;
        let mut auto_column_type = None;

        stmt::visit::for_each_expr(&returning, |expr| {
            if let stmt::Expr::Reference(expr_ref) = expr {
                let column = cx.resolve_expr_reference(expr_ref).as_column_unwrap();
                assert!(
                    column.auto_increment,
                    "MySQL does not support RETURNING clause for non-auto-increment columns. \
                    Column '{}' in table '{}' is not auto-increment. \
                    Only auto-increment columns can be returned from INSERT statements on MySQL.",
                    column.name, self.engine.schema.db.tables[column.id.table.0].name
                );
                auto_column_type = Some(column.ty.clone());
                ref_count += 1;
            }
        });

        assert_eq!(
            ref_count, 1,
            "MySQL INSERT with RETURNING must have exactly one auto-increment column reference, found {ref_count}"
        );

        let auto_column_type = auto_column_type.expect("auto_column_type should be set");

        // Extract the expression from the RETURNING clause and replace ExprReference with ExprArg
        let mut returning_expr = match returning {
            stmt::Returning::Expr(expr) => expr,
            _ => panic!(
                "MySQL INSERT with RETURNING must have an Expr, got: {:#?}",
                returning
            ),
        };

        // Replace the ExprReference with ExprArg(position: 0) so we can pass the ID as a positional argument
        stmt::visit_mut::for_each_expr_mut(&mut returning_expr, |expr| {
            if matches!(expr, stmt::Expr::Reference(_)) {
                *expr = stmt::Expr::Arg(stmt::ExprArg {
                    position: 0,
                    nesting: 0,
                });
            }
        });

        // Count the number of rows being inserted
        let num_rows = match &insert.source.body {
            stmt::ExprSet::Values(values) => values.rows.len() as u64,
            _ => {
                panic!(
                    "MySQL INSERT with RETURNING only supports VALUES, got: {:#?}",
                    insert.source.body
                );
            }
        };

        Some(MySQLInsertReturning {
            num_rows,
            returning_expr,
            auto_column_type,
        })
    }
}
impl From<ExecStatement> for Action {
fn from(value: ExecStatement) -> Self {
Self::ExecStatement(value.into())
}
}
impl MySQLInsertReturning {
    /// Reconstructs RETURNING results from the ID rows returned by the driver.
    ///
    /// MySQL doesn't support RETURNING, so the driver fetches auto-increment IDs
    /// using LAST_INSERT_ID() (requested through `last_insert_id_hack`). For each
    /// returned ID, this method casts it to the auto-increment column's type and
    /// binds it as argument 0 of the original RETURNING expression. The
    /// expression is then evaluated to produce the expected result row.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_result` error if the driver returns:
    /// - something other than a row stream;
    /// - a number of rows different from the number of inserted rows;
    /// - a row that is not a single-field record.
    ///
    /// Errors from collecting the stream, casting an ID, or evaluating the
    /// expression are propagated.
    async fn reconstruct_returning(self, rows: Rows) -> Result<Rows> {
        // The driver executed SELECT LAST_INSERT_ID() and returned rows with IDs.
        let Rows::Stream(id_rows) = rows else {
            return Err(toasty_core::Error::invalid_result(format!(
                "MySQL INSERT RETURNING expected Stream, got {:?}",
                rows
            )));
        };

        let id_values = id_rows.collect().await?;

        // `num_rows` originates from a `usize` row count, so converting it back
        // to `usize` is lossless.
        let expected_rows = self.num_rows as usize;

        // A count mismatch is a bad driver result, so it is reported as an error
        // (consistent with the shape checks here) rather than as a panic.
        if id_values.len() != expected_rows {
            return Err(toasty_core::Error::invalid_result(format!(
                "Expected {} ID rows from driver, got {}",
                self.num_rows,
                id_values.len()
            )));
        }

        // Reconstruct the RETURNING results by evaluating the original returning
        // expression for each ID row returned by the driver.
        let mut returning_rows = Vec::with_capacity(expected_rows);
        for id_value_raw in id_values {
            // The driver returns a record with one field containing the ID.
            // Extract the ID value from the record wrapper.
            let stmt::Value::Record(record) = id_value_raw else {
                return Err(toasty_core::Error::invalid_result(format!(
                    "MySQL INSERT RETURNING expected Record from driver, got {:?}",
                    id_value_raw
                )));
            };

            if record.fields.len() != 1 {
                return Err(toasty_core::Error::invalid_result(format!(
                    "Expected record with one field from driver, got {} fields",
                    record.fields.len()
                )));
            }

            // Cast the ID to the correct type for the auto-increment column.
            let id_value = self.auto_column_type.cast(record.fields[0].clone())?;
            let input = vec![id_value];

            // Evaluate the returning expression with the auto-increment ID.
            let row_value = self.returning_expr.eval(&input)?;
            returning_rows.push(row_value);
        }

        Ok(Rows::Stream(stmt::ValueStream::from_iter(
            returning_rows.into_iter().map(Ok),
        )))
    }
}