rorm_db/sqlx_impl/
executor.rs

1use std::future;
2use std::future::{ready, Ready};
3use std::pin::Pin;
4use std::task::{ready, Context, Poll};
5
6use futures_core::stream;
7use rorm_sql::value::Value;
8use rorm_sql::DBImpl;
9use tracing::debug;
10
11use crate::executor::{
12    AffectedRows, All, DynamicExecutor, Executor, Nothing, One, Optional, QueryStrategy,
13    QueryStrategyResult, Stream,
14};
15use crate::internal::any::{AnyExecutor, AnyPool, AnyQueryResult, AnyRow, AnyTransaction};
16use crate::transaction::{Transaction, TransactionGuard};
17use crate::{Database, Error, Row};
18
19impl<'executor> Executor<'executor> for &'executor mut Transaction {
20    fn execute<'data, 'result, Q>(
21        self,
22        query: String,
23        values: Vec<Value<'data>>,
24    ) -> Q::Result<'result>
25    where
26        'executor: 'result,
27        'data: 'result,
28        Q: QueryStrategy,
29    {
30        debug!(
31            target: "rorm_db::executor",
32            sql = query,
33            values.len = values.len(),
34            "Executing statement"
35        );
36        Q::execute(&mut self.0, query, values)
37    }
38
39    fn into_dyn(self) -> DynamicExecutor<'executor> {
40        DynamicExecutor::Transaction(self)
41    }
42
43    fn dialect(&self) -> DBImpl {
44        match self.0 {
45            #[cfg(feature = "postgres")]
46            AnyTransaction::Postgres(_) => DBImpl::Postgres,
47            #[cfg(feature = "mysql")]
48            AnyTransaction::MySql(_) => DBImpl::MySQL,
49            #[cfg(feature = "sqlite")]
50            AnyTransaction::Sqlite(_) => DBImpl::SQLite,
51        }
52    }
53
54    type EnsureTransactionFuture = Ready<Result<TransactionGuard<'executor>, Error>>;
55
56    fn ensure_transaction(
57        self,
58    ) -> BoxFuture<'executor, Result<TransactionGuard<'executor>, Error>> {
59        Box::pin(ready(Ok(TransactionGuard::Borrowed(self))))
60    }
61}
62
63impl<'executor> Executor<'executor> for &'executor Database {
64    fn execute<'data, 'result, Q>(
65        self,
66        query: String,
67        values: Vec<Value<'data>>,
68    ) -> Q::Result<'result>
69    where
70        'executor: 'result,
71        'data: 'result,
72        Q: QueryStrategy,
73    {
74        debug!(
75            target: "rorm_db::executor",
76            sql = query,
77            values.len = values.len(),
78            "Executing statement"
79        );
80        Q::execute(&self.0, query, values)
81    }
82
83    fn into_dyn(self) -> DynamicExecutor<'executor> {
84        DynamicExecutor::Database(self)
85    }
86
87    fn dialect(&self) -> DBImpl {
88        match self.0 {
89            #[cfg(feature = "postgres")]
90            AnyPool::Postgres(_) => DBImpl::Postgres,
91            #[cfg(feature = "mysql")]
92            AnyPool::MySql(_) => DBImpl::MySQL,
93            #[cfg(feature = "sqlite")]
94            AnyPool::Sqlite(_) => DBImpl::SQLite,
95        }
96    }
97
98    type EnsureTransactionFuture = BoxFuture<'executor, Result<TransactionGuard<'executor>, Error>>;
99
100    fn ensure_transaction(
101        self,
102    ) -> BoxFuture<'executor, Result<TransactionGuard<'executor>, Error>> {
103        Box::pin(async move { self.start_transaction().await.map(TransactionGuard::Owned) })
104    }
105}
106
107pub trait QueryStrategyImpl: QueryStrategyResult {
108    fn execute<'query, E>(
109        executor: E,
110        query: String,
111        values: Vec<Value<'query>>,
112    ) -> Self::Result<'query>
113    where
114        E: AnyExecutor<'query>;
115}
116
/// Name for a [`QueryWrapper`] whose wrapped value is a future; the wrapper forwards `poll` to it
pub type QueryFuture<T> = QueryWrapper<T>;
/// Name for a [`QueryWrapper`] whose wrapped value is a stream; the wrapper forwards `poll_next` to it
pub type QueryStream<T> = QueryWrapper<T>;
119
120pub use query_wrapper::QueryWrapper;
121
122use crate::futures_util::{BoxFuture, BoxStream};
123
124/// Private module to contain the internals behind a sound api
125mod query_wrapper {
126    use std::pin::Pin;
127
128    use rorm_sql::value::Value;
129
130    use crate::internal::any::{AnyExecutor, AnyQuery};
131
132    #[doc(hidden)]
133    #[pin_project::pin_project]
134    pub struct QueryWrapper<T> {
135        #[pin]
136        wrapped: T,
137        #[allow(dead_code)] // is used via a reference inside T
138        query_string: String,
139    }
140
141    impl<'query, T: 'query> QueryWrapper<T> {
142        /// Basic constructor which only performs the unsafe lifetime extension to be tested by miri
143        pub(crate) fn new_basic(string: String, wrapped: impl FnOnce(&'query str) -> T) -> Self {
144            let slice: &str = string.as_str();
145
146            // SAFETY: The heap allocation won't be dropped or moved
147            //         until `wrapped` which contains this reference is dropped.
148            let slice: &'query str = unsafe { std::mem::transmute(slice) };
149
150            Self {
151                query_string: string,
152                wrapped: wrapped(slice),
153            }
154        }
155
156        pub fn new<'data: 'query>(
157            executor: impl AnyExecutor<'query>,
158            query_string: String,
159            values: Vec<Value<'data>>,
160            execute: impl FnOnce(AnyQuery<'query>) -> T,
161        ) -> Self {
162            Self::new_basic(query_string, move |query_string| {
163                let mut query = executor.query(query_string);
164                for value in values {
165                    crate::internal::utils::bind_param(&mut query, value);
166                }
167                execute(query)
168            })
169        }
170    }
171
172    impl<T> QueryWrapper<T> {
173        /// Project a [`Pin`] onto the `wrapped` field
174        pub fn project_wrapped(self: Pin<&mut Self>) -> Pin<&mut T> {
175            self.project().wrapped
176        }
177    }
178}
179
180impl<F> future::Future for QueryFuture<F>
181where
182    F: future::Future,
183{
184    type Output = F::Output;
185
186    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
187        self.project_wrapped().poll(cx)
188    }
189}
190
191impl<S> stream::Stream for QueryStream<S>
192where
193    S: stream::Stream,
194{
195    type Item = S::Item;
196
197    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
198        self.project_wrapped().poll_next(cx)
199    }
200}
201
202impl QueryStrategyResult for Nothing {
203    type Result<'query> = QueryFuture<NothingFuture<'query>>;
204}
205
206impl QueryStrategyImpl for Nothing {
207    fn execute<'query, E>(
208        executor: E,
209        query: String,
210        values: Vec<Value<'query>>,
211    ) -> Self::Result<'query>
212    where
213        E: AnyExecutor<'query>,
214    {
215        QueryFuture::new(executor, query, values, |query| NothingFuture {
216            stream: query.fetch_many(),
217        })
218    }
219}
220
221/// [`QueryStrategyResult::Result`] of [`Nothing`]
222pub struct NothingFuture<'stream> {
223    stream: BoxStream<'stream, sqlx::Result<sqlx::Either<AnyQueryResult, AnyRow>>>,
224}
225
226impl future::Future for NothingFuture<'_> {
227    type Output = Result<(), Error>;
228
229    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
230        loop {
231            return Poll::Ready(match ready!(self.stream.as_mut().poll_next(cx)) {
232                None => Ok(()),
233                Some(Err(error)) => Err(error.into()),
234                Some(_either) => continue,
235            });
236        }
237    }
238}
239
240impl QueryStrategyResult for AffectedRows {
241    type Result<'query> = QueryFuture<BoxFuture<'query, Result<u64, Error>>>;
242}
243
244impl QueryStrategyImpl for AffectedRows {
245    fn execute<'query, E>(
246        executor: E,
247        query: String,
248        values: Vec<Value<'query>>,
249    ) -> Self::Result<'query>
250    where
251        E: AnyExecutor<'query>,
252    {
253        QueryFuture::new(executor, query, values, |query| {
254            Box::pin(async move { Ok(query.fetch_affected_rows().await?) }) as BoxFuture<_>
255        })
256    }
257}
258
259impl QueryStrategyResult for One {
260    type Result<'query> = QueryFuture<BoxFuture<'query, Result<Row, Error>>>;
261}
262
263impl QueryStrategyImpl for One {
264    fn execute<'query, E>(
265        executor: E,
266        query: String,
267        values: Vec<Value<'query>>,
268    ) -> Self::Result<'query>
269    where
270        E: AnyExecutor<'query>,
271    {
272        QueryFuture::new(executor, query, values, |query| {
273            Box::pin(async move {
274                Ok(Row(query
275                    .fetch_optional()
276                    .await?
277                    .ok_or(sqlx::Error::RowNotFound)?))
278            }) as BoxFuture<_>
279        })
280    }
281}
282
283impl QueryStrategyResult for Optional {
284    type Result<'query> = QueryFuture<BoxFuture<'query, Result<Option<Row>, Error>>>;
285}
286
287impl QueryStrategyImpl for Optional {
288    fn execute<'query, E>(
289        executor: E,
290        query: String,
291        values: Vec<Value<'query>>,
292    ) -> Self::Result<'query>
293    where
294        E: AnyExecutor<'query>,
295    {
296        QueryFuture::new(executor, query, values, |query| {
297            Box::pin(async move { Ok(query.fetch_optional().await?.map(Row)) }) as BoxFuture<_>
298        })
299    }
300}
301
302impl QueryStrategyResult for All {
303    type Result<'query> = QueryFuture<BoxFuture<'query, Result<Vec<Row>, Error>>>;
304}
305
306impl QueryStrategyImpl for All {
307    fn execute<'query, E>(
308        executor: E,
309        query: String,
310        values: Vec<Value<'query>>,
311    ) -> Self::Result<'query>
312    where
313        E: AnyExecutor<'query>,
314    {
315        QueryFuture::new(executor, query, values, |query| {
316            Box::pin(async move { Ok(query.fetch_all().await?.into_iter().map(Row).collect()) })
317                as BoxFuture<_>
318        })
319    }
320}
321
322impl QueryStrategyResult for Stream {
323    type Result<'query> = QueryStream<StreamResult<'query>>;
324}
325
326impl QueryStrategyImpl for Stream {
327    fn execute<'query, E>(
328        executor: E,
329        query: String,
330        values: Vec<Value<'query>>,
331    ) -> Self::Result<'query>
332    where
333        E: AnyExecutor<'query>,
334    {
335        QueryStream::new(executor, query, values, |query| StreamResult {
336            stream: query.fetch_many(),
337        })
338    }
339}
340
341/// [`QueryStrategyResult::Result`] of [`Stream`]
342pub struct StreamResult<'stream> {
343    stream: BoxStream<'stream, sqlx::Result<sqlx::Either<AnyQueryResult, AnyRow>>>,
344}
345
346impl Unpin for StreamResult<'_> {}
347impl stream::Stream for StreamResult<'_> {
348    type Item = Result<Row, Error>;
349
350    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
351        loop {
352            return Poll::Ready(match ready!(self.stream.as_mut().poll_next(cx)) {
353                None => None,
354                Some(Err(error)) => Some(Err(error.into())),
355                Some(Ok(sqlx::Either::Right(row))) => Some(Ok(Row(row))),
356                Some(Ok(sqlx::Either::Left(_result))) => continue,
357            });
358        }
359    }
360}
361
#[cfg(test)]
mod test {
    use crate::internal::executor::QueryWrapper;

    /// Run this test with miri
    ///
    /// The wrapped value reads the borrowed query string while being dropped.
    /// If [`QueryWrapper`] dropped its fields in the wrong order,
    /// that read would be a use-after-free which miri reports.
    #[test]
    fn test_drop_order() {
        /// Reads the borrowed string in its destructor
        struct PrintOnDrop<'a>(&'a str);
        impl Drop for PrintOnDrop<'_> {
            fn drop(&mut self) {
                // Touch the borrowed string.
                // Were its allocation already freed, miri would notice this access.
                let Self(borrowed) = self;
                println!("{}", borrowed);
            }
        }

        let _wrapper = QueryWrapper::new_basic(String::from("Hello World"), PrintOnDrop);
    }
}