Skip to main content

rorm_db/internal/
executor.rs

1use std::future;
2use std::future::{ready, Ready};
3use std::pin::Pin;
4use std::task::{ready, Context, Poll};
5
6use futures_core::stream;
7use rorm_sql::value::Value;
8use rorm_sql::DBImpl;
9use tracing::debug;
10
11use crate::executor::{
12    AffectedRows, All, DynamicExecutor, Executor, Nothing, One, Optional, QueryStrategy,
13    QueryStrategyResult, Stream,
14};
15use crate::internal::any::{AnyExecutor, AnyPool, AnyQueryResult, AnyRow, AnyTransaction};
16use crate::transaction::{Transaction, TransactionGuard};
17use crate::{Database, Error, Row};
18
19impl<'executor> Executor<'executor> for &'executor mut Transaction {
20    fn execute<'data, 'result, Q>(
21        self,
22        query: String,
23        values: Vec<Value<'data>>,
24    ) -> Q::Result<'result>
25    where
26        'executor: 'result,
27        'data: 'result,
28        Q: QueryStrategy,
29    {
30        debug!(
31            target: "rorm_db::executor",
32            sql = query,
33            values.len = values.len(),
34            "Executing statement"
35        );
36        Q::execute(&mut self.sqlx, query, values)
37    }
38
39    fn into_dyn(self) -> DynamicExecutor<'executor> {
40        DynamicExecutor::Transaction(self)
41    }
42
43    fn dialect(&self) -> DBImpl {
44        match self.sqlx {
45            #[cfg(feature = "postgres")]
46            AnyTransaction::Postgres(_) => DBImpl::Postgres,
47            #[cfg(feature = "sqlite")]
48            AnyTransaction::Sqlite(_) => DBImpl::SQLite,
49        }
50    }
51
52    type EnsureTransactionFuture = Ready<Result<TransactionGuard<'executor>, Error>>;
53
54    fn ensure_transaction(
55        self,
56    ) -> BoxFuture<'executor, Result<TransactionGuard<'executor>, Error>> {
57        Box::pin(ready(Ok(TransactionGuard::Borrowed(self))))
58    }
59}
60
61impl<'executor> Executor<'executor> for &'executor Database {
62    fn execute<'data, 'result, Q>(
63        self,
64        query: String,
65        values: Vec<Value<'data>>,
66    ) -> Q::Result<'result>
67    where
68        'executor: 'result,
69        'data: 'result,
70        Q: QueryStrategy,
71    {
72        debug!(
73            target: "rorm_db::executor",
74            sql = query,
75            values.len = values.len(),
76            "Executing statement"
77        );
78        Q::execute(&self.0, query, values)
79    }
80
81    fn into_dyn(self) -> DynamicExecutor<'executor> {
82        DynamicExecutor::Database(self)
83    }
84
85    fn dialect(&self) -> DBImpl {
86        match self.0 {
87            #[cfg(feature = "postgres")]
88            AnyPool::Postgres(_) => DBImpl::Postgres,
89            #[cfg(feature = "sqlite")]
90            AnyPool::Sqlite(_) => DBImpl::SQLite,
91        }
92    }
93
94    type EnsureTransactionFuture = BoxFuture<'executor, Result<TransactionGuard<'executor>, Error>>;
95
96    fn ensure_transaction(
97        self,
98    ) -> BoxFuture<'executor, Result<TransactionGuard<'executor>, Error>> {
99        Box::pin(async move { self.start_transaction().await.map(TransactionGuard::Owned) })
100    }
101}
102
103pub trait QueryStrategyImpl: QueryStrategyResult {
104    fn execute<'query, E>(
105        executor: E,
106        query: String,
107        values: Vec<Value<'query>>,
108    ) -> Self::Result<'query>
109    where
110        E: AnyExecutor<'query>;
111}
112
113pub type QueryFuture<T> = QueryWrapper<T>;
114pub type QueryStream<T> = QueryWrapper<T>;
115
116pub use query_wrapper::QueryWrapper;
117
118use crate::futures_util::{BoxFuture, BoxStream};
119
120/// Private module to contain the internals behind a sound api
121mod query_wrapper {
122    use std::pin::Pin;
123
124    use rorm_sql::value::Value;
125
126    use crate::internal::any::{AnyExecutor, AnyQuery};
127    use crate::internal::bind_params::bind_param;
128
129    #[doc(hidden)]
130    #[pin_project::pin_project]
131    pub struct QueryWrapper<T> {
132        #[pin]
133        wrapped: T,
134        #[allow(dead_code)] // is used via a reference inside T
135        query_string: String,
136    }
137
138    impl<'query, T: 'query> QueryWrapper<T> {
139        /// Basic constructor which only performs the unsafe lifetime extension to be tested by miri
140        pub(crate) fn new_basic(string: String, wrapped: impl FnOnce(&'query str) -> T) -> Self {
141            let slice: &str = string.as_str();
142
143            // SAFETY: The heap allocation won't be dropped or moved
144            //         until `wrapped` which contains this reference is dropped.
145            let slice: &'query str = unsafe { std::mem::transmute(slice) };
146
147            Self {
148                query_string: string,
149                wrapped: wrapped(slice),
150            }
151        }
152
153        pub fn new<'data: 'query>(
154            executor: impl AnyExecutor<'query>,
155            query_string: String,
156            values: Vec<Value<'data>>,
157            execute: impl FnOnce(AnyQuery<'query>) -> T,
158        ) -> Self {
159            Self::new_basic(query_string, move |query_string| {
160                let mut query = executor.query(query_string);
161                for value in values {
162                    bind_param(&mut query, value);
163                }
164                execute(query)
165            })
166        }
167    }
168
169    impl<T> QueryWrapper<T> {
170        /// Project a [`Pin`] onto the `wrapped` field
171        pub fn project_wrapped(self: Pin<&mut Self>) -> Pin<&mut T> {
172            self.project().wrapped
173        }
174    }
175}
176
177impl<F> future::Future for QueryFuture<F>
178where
179    F: future::Future,
180{
181    type Output = F::Output;
182
183    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
184        self.project_wrapped().poll(cx)
185    }
186}
187
188impl<S> stream::Stream for QueryStream<S>
189where
190    S: stream::Stream,
191{
192    type Item = S::Item;
193
194    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
195        self.project_wrapped().poll_next(cx)
196    }
197}
198
199impl QueryStrategyResult for Nothing {
200    type Result<'query> = QueryFuture<NothingFuture<'query>>;
201}
202
203impl QueryStrategyImpl for Nothing {
204    fn execute<'query, E>(
205        executor: E,
206        query: String,
207        values: Vec<Value<'query>>,
208    ) -> Self::Result<'query>
209    where
210        E: AnyExecutor<'query>,
211    {
212        QueryFuture::new(executor, query, values, |query| NothingFuture {
213            stream: query.fetch_many(),
214        })
215    }
216}
217
218/// [`QueryStrategyResult::Result`] of [`Nothing`]
219pub struct NothingFuture<'stream> {
220    stream: BoxStream<'stream, sqlx::Result<sqlx::Either<AnyQueryResult, AnyRow>>>,
221}
222
223impl future::Future for NothingFuture<'_> {
224    type Output = Result<(), Error>;
225
226    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
227        loop {
228            return Poll::Ready(match ready!(self.stream.as_mut().poll_next(cx)) {
229                None => Ok(()),
230                Some(Err(error)) => Err(error.into()),
231                Some(_either) => continue,
232            });
233        }
234    }
235}
236
237impl QueryStrategyResult for AffectedRows {
238    type Result<'query> = QueryFuture<BoxFuture<'query, Result<u64, Error>>>;
239}
240
241impl QueryStrategyImpl for AffectedRows {
242    fn execute<'query, E>(
243        executor: E,
244        query: String,
245        values: Vec<Value<'query>>,
246    ) -> Self::Result<'query>
247    where
248        E: AnyExecutor<'query>,
249    {
250        QueryFuture::new(executor, query, values, |query| {
251            Box::pin(async move { Ok(query.fetch_affected_rows().await?) }) as BoxFuture<_>
252        })
253    }
254}
255
256impl QueryStrategyResult for One {
257    type Result<'query> = QueryFuture<BoxFuture<'query, Result<Row, Error>>>;
258}
259
260impl QueryStrategyImpl for One {
261    fn execute<'query, E>(
262        executor: E,
263        query: String,
264        values: Vec<Value<'query>>,
265    ) -> Self::Result<'query>
266    where
267        E: AnyExecutor<'query>,
268    {
269        QueryFuture::new(executor, query, values, |query| {
270            Box::pin(async move {
271                Ok(Row(query
272                    .fetch_optional()
273                    .await?
274                    .ok_or(sqlx::Error::RowNotFound)?))
275            }) as BoxFuture<_>
276        })
277    }
278}
279
280impl QueryStrategyResult for Optional {
281    type Result<'query> = QueryFuture<BoxFuture<'query, Result<Option<Row>, Error>>>;
282}
283
284impl QueryStrategyImpl for Optional {
285    fn execute<'query, E>(
286        executor: E,
287        query: String,
288        values: Vec<Value<'query>>,
289    ) -> Self::Result<'query>
290    where
291        E: AnyExecutor<'query>,
292    {
293        QueryFuture::new(executor, query, values, |query| {
294            Box::pin(async move { Ok(query.fetch_optional().await?.map(Row)) }) as BoxFuture<_>
295        })
296    }
297}
298
299impl QueryStrategyResult for All {
300    type Result<'query> = QueryFuture<BoxFuture<'query, Result<Vec<Row>, Error>>>;
301}
302
303impl QueryStrategyImpl for All {
304    fn execute<'query, E>(
305        executor: E,
306        query: String,
307        values: Vec<Value<'query>>,
308    ) -> Self::Result<'query>
309    where
310        E: AnyExecutor<'query>,
311    {
312        QueryFuture::new(executor, query, values, |query| {
313            Box::pin(async move { Ok(query.fetch_all().await?.into_iter().map(Row).collect()) })
314                as BoxFuture<_>
315        })
316    }
317}
318
319impl QueryStrategyResult for Stream {
320    type Result<'query> = QueryStream<StreamResult<'query>>;
321}
322
323impl QueryStrategyImpl for Stream {
324    fn execute<'query, E>(
325        executor: E,
326        query: String,
327        values: Vec<Value<'query>>,
328    ) -> Self::Result<'query>
329    where
330        E: AnyExecutor<'query>,
331    {
332        QueryStream::new(executor, query, values, |query| StreamResult {
333            stream: query.fetch_many(),
334        })
335    }
336}
337
338/// [`QueryStrategyResult::Result`] of [`Stream`]
339pub struct StreamResult<'stream> {
340    stream: BoxStream<'stream, sqlx::Result<sqlx::Either<AnyQueryResult, AnyRow>>>,
341}
342
343impl Unpin for StreamResult<'_> {}
344impl stream::Stream for StreamResult<'_> {
345    type Item = Result<Row, Error>;
346
347    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
348        loop {
349            return Poll::Ready(match ready!(self.stream.as_mut().poll_next(cx)) {
350                None => None,
351                Some(Err(error)) => Some(Err(error.into())),
352                Some(Ok(sqlx::Either::Right(row))) => Some(Ok(Row(row))),
353                Some(Ok(sqlx::Either::Left(_result))) => continue,
354            });
355        }
356    }
357}
358
#[cfg(test)]
mod test {
    use crate::internal::executor::QueryWrapper;

    /// Run this test with miri
    ///
    /// If the drop order of [`QueryWrapper`]'s fields is incorrect,
    /// miri will complain about a use-after-free.
    #[test]
    fn test_drop_order() {
        /// Stand-in for a query: holds the borrowed string and reads it on drop.
        struct BorrowStr<'a>(&'a str);
        impl Drop for BorrowStr<'_> {
            fn drop(&mut self) {
                // Use the borrowed string.
                // If it were already dropped, miri would detect it.
                println!("{}", self.0);
            }
        }
        // A plain owned string is all that is needed; `format!` without
        // arguments would only add formatting machinery for no benefit.
        let _w = QueryWrapper::new_basic(String::from("Hello World"), BorrowStr);
    }
}