1use crate::connection::{SqlExecutor, SqlExecutorAsync, SqlExecutorMut};
2use std::future::Future;
3use tracing::{Instrument, Span, error};
4
5pub(super) trait Sealed {}
6
7#[allow(private_bounds)]
13pub trait Statement: Send + Sealed {
14 type Output: Send;
16
17 fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error>;
23
24 fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error>;
30
31 fn execute_async<S: SqlExecutorAsync>(
37 self,
38 connection: &mut S,
39 ) -> impl Future<Output = Result<Self::Output, S::Error>> + Send;
40
41 fn then<Q: Statement>(self, statement: Q) -> Then<Self, Q>
43 where
44 Self: Sized,
45 {
46 Then {
47 a: self,
48 b: statement,
49 }
50 }
51
52 fn pipe<Q: StatementWithInput<Input = Self::Output> + Send>(self, statement: Q) -> Pipe<Self, Q>
55 where
56 Self: Sized,
57 {
58 Pipe {
59 a: self,
60 b: statement,
61 }
62 }
63
64 fn spanned(self, span: Span) -> TracedStatement<Self>
66 where
67 Self: Sized,
68 {
69 TracedStatement::new(self, span)
70 }
71
72 fn spanned_in_current(self) -> TracedStatement<Self>
74 where
75 Self: Sized,
76 {
77 TracedStatement::current(self)
78 }
79}
80
81pub trait StatementWithInput: Send {
85 type Input: Send;
87 type Output: Send;
89
90 fn execute<S: SqlExecutor>(
96 self,
97 connection: &S,
98 input: Self::Input,
99 ) -> Result<Self::Output, S::Error>;
100
101 fn execute_mut<S: SqlExecutorMut>(
107 self,
108 connection: &mut S,
109 input: Self::Input,
110 ) -> Result<Self::Output, S::Error>;
111
112 fn execute_async<S: SqlExecutorAsync>(
118 self,
119 connection: &mut S,
120 input: Self::Input,
121 ) -> impl Future<Output = Result<Self::Output, S::Error>> + Send;
122}
123
124pub struct Then<A: Statement, B: Statement> {
128 a: A,
129 b: B,
130}
131impl<A: Statement, B: Statement> Sealed for Then<A, B> {}
132
133impl<A: Statement + Send, B: Statement + Send> Statement for Then<A, B> {
134 type Output = B::Output;
135
136 fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
137 self.a.execute(connection)?;
138 self.b.execute(connection)
139 }
140
141 fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
142 self.a.execute_mut(connection)?;
143 self.b.execute_mut(connection)
144 }
145
146 async fn execute_async<S: SqlExecutorAsync>(
147 self,
148 connection: &mut S,
149 ) -> Result<Self::Output, S::Error> {
150 self.a.execute_async(connection).await?;
151 self.b.execute_async(connection).await
152 }
153}
154
155pub struct Pipe<A: Statement + Send, B: StatementWithInput<Input = A::Output> + Send> {
157 a: A,
158 b: B,
159}
160
161impl<A: Statement, B: StatementWithInput<Input = A::Output>> Sealed for Pipe<A, B> {}
162impl<A: Statement + Send, B: StatementWithInput<Input = A::Output> + Send> Statement
163 for Pipe<A, B>
164{
165 type Output = B::Output;
166
167 fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
168 let output = self.a.execute(connection)?;
169 self.b.execute(connection, output)
170 }
171
172 fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
173 let output = self.a.execute_mut(connection)?;
174 self.b.execute_mut(connection, output)
175 }
176
177 async fn execute_async<S: SqlExecutorAsync>(
178 self,
179 connection: &mut S,
180 ) -> Result<Self::Output, S::Error> {
181 let output = self.a.execute_async(connection).await?;
182 self.b.execute_async(connection, output).await
183 }
184}
185
186pub(super) struct SqlExecuteStatement<T: AsRef<str>> {
188 query: T,
189}
190
191impl<T: AsRef<str> + Send> SqlExecuteStatement<T> {
192 pub fn new(query: T) -> Self {
193 Self { query }
194 }
195}
196
197impl<T: AsRef<str> + Send> Sealed for SqlExecuteStatement<T> {}
198
199impl<T: AsRef<str> + Send> Statement for SqlExecuteStatement<T> {
200 type Output = ();
201
202 fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
203 connection.sql_execute(self.query.as_ref())
204 }
205
206 fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
207 connection.sql_execute(self.query.as_ref())
208 }
209
210 async fn execute_async<S: SqlExecutorAsync>(
211 self,
212 connection: &mut S,
213 ) -> Result<Self::Output, S::Error> {
214 connection.sql_execute(self.query.as_ref()).await
215 }
216}
217
/// Selects which `BEGIN` statement opens a transaction.
///
/// The enum is a plain tag, so it derives the usual value-type traits; that
/// makes it printable in diagnostics and comparable in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TransactionMode {
    /// The transaction is opened with the plain `BEGIN` statement.
    Temporary,
    /// The transaction is opened with `BEGIN IMMEDIATE`.
    Full,
}
226
227pub(super) struct SqlTransactionStatement<Q: Statement> {
229 statement: Q,
230 mode: TransactionMode,
231}
232
233impl<Q: Statement> SqlTransactionStatement<Q> {
234 pub fn temporary(statement: Q) -> Self {
236 Self {
237 statement,
238 mode: TransactionMode::Temporary,
239 }
240 }
241 #[allow(dead_code)]
243 pub fn full(statement: Q) -> Self {
244 Self {
245 statement,
246 mode: TransactionMode::Full,
247 }
248 }
249
250 fn begin_statement(&self) -> &'static str {
251 match self.mode {
252 TransactionMode::Temporary => BEGIN_TRANSACTION_STATEMENT,
253 TransactionMode::Full => BEGIN_TRANSACTION_IMMEDIATE_STATEMENT,
254 }
255 }
256}
257
258impl<Q: Statement<Output = ()>> Sealed for SqlTransactionStatement<Q> {}
259
260impl<Q: Statement<Output = ()>> Statement for SqlTransactionStatement<Q> {
261 type Output = ();
262
263 fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
264 connection
265 .sql_execute(self.begin_statement())
266 .inspect_err(|e| error!("Failed to start transaction: {e}"))?;
267 if let Err(e) = self.statement.execute(connection) {
268 error!("Statement failed to execute: {e}");
269 if let Err(e) = connection.sql_execute(ROLLBACK_TRANSACTION_STATEMENT) {
270 error!("Failed to rollback transaction: {e}");
271 }
272 return Err(e);
273 }
274 connection
275 .sql_execute(COMMIT_TRANSACTION_STATEMENT)
276 .inspect_err(|e| error!("Failed to commit transaction: {e}"))?;
277 Ok(())
278 }
279
280 fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
281 connection
282 .sql_execute(self.begin_statement())
283 .inspect_err(|e| error!("Failed to start transaction: {e}"))?;
284 if let Err(e) = self.statement.execute_mut(connection) {
285 error!("Statement failed to execute: {e}");
286 if let Err(e) = connection.sql_execute(ROLLBACK_TRANSACTION_STATEMENT) {
287 error!("Failed to rollback transaction: {e}");
288 }
289 return Err(e);
290 }
291 connection
292 .sql_execute(COMMIT_TRANSACTION_STATEMENT)
293 .inspect_err(|e| error!("Failed to commit transaction: {e}"))?;
294 Ok(())
295 }
296 async fn execute_async<S: SqlExecutorAsync>(
297 self,
298 connection: &mut S,
299 ) -> Result<Self::Output, S::Error> {
300 connection
301 .sql_execute(self.begin_statement())
302 .await
303 .inspect_err(|e| error!("Failed to start transaction: {e}"))?;
304 if let Err(e) = self.statement.execute_async(connection).await {
305 error!("Statement failed to execute: {e}");
306 if let Err(e) = connection.sql_execute(ROLLBACK_TRANSACTION_STATEMENT).await {
307 error!("Failed to rollback transaction: {e}");
308 }
309 return Err(e);
310 }
311 connection
312 .sql_execute(COMMIT_TRANSACTION_STATEMENT)
313 .await
314 .inspect_err(|e| error!("Failed to commit transaction: {e}"))?;
315 Ok(())
316 }
317}
318
319pub(super) struct BatchQuery<Q: Statement>(Vec<Q>);
323
324impl<Q: Statement> BatchQuery<Q> {
325 pub fn new(v: impl IntoIterator<Item = Q>) -> Self {
326 Self(Vec::from_iter(v))
327 }
328
329 pub fn push(&mut self, q: Q) {
330 self.0.push(q);
331 }
332
333 pub fn extend<I: IntoIterator<Item = Q>>(&mut self, iter: I) {
334 self.0.extend(iter);
335 }
336}
337
338impl<Q: Statement<Output = ()>> Sealed for BatchQuery<Q> {}
339
340impl<Q: Statement<Output = ()>> Statement for BatchQuery<Q> {
341 type Output = ();
342
343 fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
344 for q in self.0 {
345 q.execute(connection)?;
346 }
347 Ok(())
348 }
349
350 fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
351 for q in self.0 {
352 q.execute_mut(connection)?;
353 }
354 Ok(())
355 }
356 async fn execute_async<S: SqlExecutorAsync>(
357 self,
358 connection: &mut S,
359 ) -> Result<Self::Output, S::Error> {
360 for q in self.0 {
361 q.execute_async(connection).await?;
362 }
363 Ok(())
364 }
365}
366
367impl<Q: Statement> Sealed for Option<Q> {}
368
369impl<Q: Statement> Statement for Option<Q> {
370 type Output = Option<Q::Output>;
371
372 fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
373 Ok(match self {
374 Some(q) => Some(q.execute(connection)?),
375 None => None,
376 })
377 }
378
379 fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
380 Ok(match self {
381 Some(q) => Some(q.execute_mut(connection)?),
382 None => None,
383 })
384 }
385
386 async fn execute_async<S: SqlExecutorAsync>(
387 self,
388 connection: &mut S,
389 ) -> Result<Self::Output, S::Error> {
390 Ok(match self {
391 Some(q) => Some(q.execute_async(connection).await?),
392 None => None,
393 })
394 }
395}
396
397pub struct TracedStatement<Q: Statement> {
398 statement: Q,
399 span: Span,
400}
401
402impl<Q: Statement> TracedStatement<Q> {
403 pub fn new(statement: Q, span: Span) -> Self {
405 Self { statement, span }
406 }
407
408 pub fn current(statement: Q) -> Self {
410 Self::new(statement, Span::current())
411 }
412}
413
414impl<Q: Statement> Sealed for TracedStatement<Q> {}
415
416impl<Q: Statement> Statement for TracedStatement<Q> {
417 type Output = Q::Output;
418 fn execute<S: SqlExecutor>(self, connection: &S) -> Result<Self::Output, S::Error> {
419 let _span = self.span.entered();
420 self.statement.execute(connection)
421 }
422
423 fn execute_mut<S: SqlExecutorMut>(self, connection: &mut S) -> Result<Self::Output, S::Error> {
424 let _span = self.span.entered();
425 self.statement.execute_mut(connection)
426 }
427
428 async fn execute_async<S: SqlExecutorAsync>(
429 self,
430 connection: &mut S,
431 ) -> Result<Self::Output, S::Error> {
432 self.statement
433 .execute_async(connection)
434 .instrument(self.span)
435 .await
436 }
437}
438
/// SQL text that opens a transaction in [`TransactionMode::Temporary`].
const BEGIN_TRANSACTION_STATEMENT: &str = "BEGIN";

/// SQL text that opens a transaction in [`TransactionMode::Full`].
const BEGIN_TRANSACTION_IMMEDIATE_STATEMENT: &str = "BEGIN IMMEDIATE";

/// SQL text that commits the open transaction.
const COMMIT_TRANSACTION_STATEMENT: &str = "COMMIT";

/// SQL text that rolls the open transaction back.
const ROLLBACK_TRANSACTION_STATEMENT: &str = "ROLLBACK";