rorm_db/sqlx_impl/executor.rs

1use std::future::{ready, Ready};
2use std::pin::Pin;
3use std::task::{Context, Poll};
4
5use futures::future::{self, BoxFuture, FutureExt, TryFutureExt};
6use futures::stream::{self, BoxStream, TryCollect, TryFilterMap, TryStreamExt};
7use rorm_sql::value::Value;
8use rorm_sql::DBImpl;
9
10use crate::executor::{
11    AffectedRows, All, Executor, Nothing, One, Optional, QueryStrategy, QueryStrategyResult, Stream,
12};
13use crate::internal::any::{AnyExecutor, AnyPool, AnyQueryResult, AnyRow, AnyTransaction};
14use crate::transaction::{Transaction, TransactionGuard};
15use crate::{Database, Error, Row};
16
17impl<'executor> Executor<'executor> for &'executor mut Transaction {
18    fn execute<'data, 'result, Q>(
19        self,
20        query: String,
21        values: Vec<Value<'data>>,
22    ) -> Q::Result<'result>
23    where
24        'executor: 'result,
25        'data: 'result,
26        Q: QueryStrategy,
27    {
28        Q::execute(&mut self.0, query, values)
29    }
30
31    fn dialect(&self) -> DBImpl {
32        match self.0 {
33            #[cfg(feature = "postgres")]
34            AnyTransaction::Postgres(_) => DBImpl::Postgres,
35            #[cfg(feature = "mysql")]
36            AnyTransaction::MySql(_) => DBImpl::MySQL,
37            #[cfg(feature = "sqlite")]
38            AnyTransaction::Sqlite(_) => DBImpl::SQLite,
39        }
40    }
41
42    type EnsureTransactionFuture = Ready<Result<TransactionGuard<'executor>, Error>>;
43
44    fn ensure_transaction(
45        self,
46    ) -> BoxFuture<'executor, Result<TransactionGuard<'executor>, Error>> {
47        Box::pin(ready(Ok(TransactionGuard::Borrowed(self))))
48    }
49}
50
51impl<'executor> Executor<'executor> for &'executor Database {
52    fn execute<'data, 'result, Q>(
53        self,
54        query: String,
55        values: Vec<Value<'data>>,
56    ) -> Q::Result<'result>
57    where
58        'executor: 'result,
59        'data: 'result,
60        Q: QueryStrategy,
61    {
62        Q::execute(&self.0, query, values)
63    }
64
65    fn dialect(&self) -> DBImpl {
66        match self.0 {
67            #[cfg(feature = "postgres")]
68            AnyPool::Postgres(_) => DBImpl::Postgres,
69            #[cfg(feature = "mysql")]
70            AnyPool::MySql(_) => DBImpl::MySQL,
71            #[cfg(feature = "sqlite")]
72            AnyPool::Sqlite(_) => DBImpl::SQLite,
73        }
74    }
75
76    type EnsureTransactionFuture = BoxFuture<'executor, Result<TransactionGuard<'executor>, Error>>;
77
78    fn ensure_transaction(
79        self,
80    ) -> BoxFuture<'executor, Result<TransactionGuard<'executor>, Error>> {
81        Box::pin(async move { self.start_transaction().await.map(TransactionGuard::Owned) })
82    }
83}
84
85pub trait QueryStrategyImpl: QueryStrategyResult {
86    fn execute<'query, E>(
87        executor: E,
88        query: String,
89        values: Vec<Value<'query>>,
90    ) -> Self::Result<'query>
91    where
92        E: AnyExecutor<'query>;
93}
94
95type AnyEither = sqlx::Either<AnyQueryResult, AnyRow>;
96type FetchMany<'a> = BoxStream<'a, Result<AnyEither, sqlx::Error>>;
97
98pub type QueryFuture<T> = QueryWrapper<T>;
99pub type QueryStream<T> = QueryWrapper<T>;
100pub use query_wrapper::QueryWrapper;
101
102/// Private module to contain the internals behind a sound api
103mod query_wrapper {
104    use std::pin::Pin;
105
106    use rorm_sql::value::Value;
107
108    use crate::internal::any::{AnyExecutor, AnyQuery};
109
110    #[doc(hidden)]
111    #[pin_project::pin_project]
112    pub struct QueryWrapper<T> {
113        #[pin]
114        wrapped: T,
115        #[allow(dead_code)] // is used via a reference inside T
116        query_string: String,
117    }
118
119    impl<'query, T: 'query> QueryWrapper<T> {
120        /// Basic constructor which only performs the unsafe lifetime extension to be tested by miri
121        pub(crate) fn new_basic(string: String, wrapped: impl FnOnce(&'query str) -> T) -> Self {
122            let slice: &str = string.as_str();
123
124            // SAFETY: The heap allocation won't be dropped or moved
125            //         until `wrapped` which contains this reference is dropped.
126            let slice: &'query str = unsafe { std::mem::transmute(slice) };
127
128            Self {
129                query_string: string,
130                wrapped: wrapped(slice),
131            }
132        }
133
134        pub fn new<'data: 'query>(
135            executor: impl AnyExecutor<'query>,
136            query_string: String,
137            values: Vec<Value<'data>>,
138            execute: impl FnOnce(AnyQuery<'query>) -> T,
139        ) -> Self {
140            Self::new_basic(query_string, move |query_string| {
141                let mut query = executor.query(query_string);
142                for value in values {
143                    crate::internal::utils::bind_param(&mut query, value);
144                }
145                execute(query)
146            })
147        }
148    }
149
150    impl<T> QueryWrapper<T> {
151        /// Project a [`Pin`] onto the `wrapped` field
152        pub fn project_wrapped(self: Pin<&mut Self>) -> Pin<&mut T> {
153            self.project().wrapped
154        }
155    }
156}
157
158impl<F> future::Future for QueryFuture<F>
159where
160    F: future::Future,
161{
162    type Output = F::Output;
163
164    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
165        self.project_wrapped().poll(cx)
166    }
167}
168impl<S> stream::Stream for QueryStream<S>
169where
170    S: stream::Stream,
171{
172    type Item = S::Item;
173
174    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
175        self.project_wrapped().poll_next(cx)
176    }
177}
178
179impl QueryStrategyResult for Nothing {
180    type Result<'query> = QueryFuture<
181        future::MapOk<
182            TryCollect<
183                stream::ErrInto<stream::MapOk<FetchMany<'query>, fn(AnyEither) -> ()>, Error>,
184                Vec<()>,
185            >,
186            fn(Vec<()>) -> (),
187        >,
188    >;
189}
190
191impl QueryStrategyImpl for Nothing {
192    fn execute<'query, E>(
193        executor: E,
194        query: String,
195        values: Vec<Value<'query>>,
196    ) -> Self::Result<'query>
197    where
198        E: AnyExecutor<'query>,
199    {
200        fn dump<T>(_: T) {}
201        let dump_either: fn(AnyEither) -> () = dump;
202        let dump_vec: fn(Vec<()>) -> () = dump;
203        QueryFuture::new(executor, query, values, |query| {
204            query
205                .fetch_many()
206                .map_ok(dump_either)
207                .err_into()
208                .try_collect()
209                .map_ok(dump_vec)
210        })
211    }
212}
213
214impl QueryStrategyResult for AffectedRows {
215    type Result<'query> = QueryFuture<BoxFuture<'query, Result<u64, Error>>>;
216}
217impl QueryStrategyImpl for AffectedRows {
218    fn execute<'query, E>(
219        executor: E,
220        query: String,
221        values: Vec<Value<'query>>,
222    ) -> Self::Result<'query>
223    where
224        E: AnyExecutor<'query>,
225    {
226        QueryFuture::new(executor, query, values, |query| {
227            (async move { Ok(query.fetch_affected_rows().await?) }).boxed()
228        })
229    }
230}
231
232impl QueryStrategyResult for One {
233    type Result<'query> = QueryFuture<BoxFuture<'query, Result<Row, Error>>>;
234}
235impl QueryStrategyImpl for One {
236    fn execute<'query, E>(
237        executor: E,
238        query: String,
239        values: Vec<Value<'query>>,
240    ) -> Self::Result<'query>
241    where
242        E: AnyExecutor<'query>,
243    {
244        QueryFuture::new(executor, query, values, |query| {
245            (async move {
246                Ok(Row(query
247                    .fetch_optional()
248                    .await?
249                    .ok_or(sqlx::Error::RowNotFound)?))
250            })
251            .boxed()
252        })
253    }
254}
255
256impl QueryStrategyResult for Optional {
257    type Result<'query> = QueryFuture<BoxFuture<'query, Result<Option<Row>, Error>>>;
258}
259impl QueryStrategyImpl for Optional {
260    fn execute<'query, E>(
261        executor: E,
262        query: String,
263        values: Vec<Value<'query>>,
264    ) -> Self::Result<'query>
265    where
266        E: AnyExecutor<'query>,
267    {
268        QueryFuture::new(executor, query, values, |query| {
269            (async move { Ok(query.fetch_optional().await?.map(Row)) }).boxed()
270        })
271    }
272}
273
274/// Function used by [All] and [Stream] in [try_filter_map](TryStreamExt::try_filter_map).
275static TRY_FILTER_MAP: fn(AnyEither) -> Ready<Result<Option<Row>, sqlx::Error>> = {
276    fn convert(either: AnyEither) -> Ready<Result<Option<Row>, sqlx::Error>> {
277        std::future::ready(Ok(match either {
278            AnyEither::Left(_) => None,
279            AnyEither::Right(row) => Some(Row(row)),
280        }))
281    }
282    convert
283};
284
285impl QueryStrategyResult for All {
286    type Result<'query> = QueryFuture<BoxFuture<'query, Result<Vec<Row>, Error>>>;
287}
288impl QueryStrategyImpl for All {
289    fn execute<'query, E>(
290        executor: E,
291        query: String,
292        values: Vec<Value<'query>>,
293    ) -> Self::Result<'query>
294    where
295        E: AnyExecutor<'query>,
296    {
297        QueryFuture::new(executor, query, values, |query| {
298            (async move { Ok(query.fetch_all().await?.into_iter().map(Row).collect()) }).boxed()
299        })
300    }
301}
302
303impl QueryStrategyResult for Stream {
304    type Result<'query> = QueryStream<
305        stream::ErrInto<
306            TryFilterMap<
307                FetchMany<'query>,
308                Ready<Result<Option<Row>, sqlx::Error>>,
309                fn(AnyEither) -> Ready<Result<Option<Row>, sqlx::Error>>,
310            >,
311            Error,
312        >,
313    >;
314}
315impl QueryStrategyImpl for Stream {
316    fn execute<'query, E>(
317        executor: E,
318        query: String,
319        values: Vec<Value<'query>>,
320    ) -> Self::Result<'query>
321    where
322        E: AnyExecutor<'query>,
323    {
324        QueryStream::new(executor, query, values, |query| {
325            query.fetch_many().try_filter_map(TRY_FILTER_MAP).err_into()
326        })
327    }
328}
329
#[cfg(test)]
mod test {
    use crate::internal::executor::QueryWrapper;

    /// Borrows the wrapper's query string and reads it when dropped.
    ///
    /// If the string were already freed at that point, miri would detect it.
    struct BorrowStr<'a>(&'a str);

    impl Drop for BorrowStr<'_> {
        fn drop(&mut self) {
            // Use the borrowed string.
            println!("{}", self.0);
        }
    }

    /// Run this test with miri
    ///
    /// If the drop order of [`QueryWrapper`]'s fields is incorrect,
    /// miri will complain about a use-after-free.
    #[test]
    fn test_drop_order() {
        let _w = QueryWrapper::new_basic(String::from("Hello World"), BorrowStr);
    }

    /// Run this test with miri
    ///
    /// Moving the wrapper moves the `String` handle but not its heap buffer.
    /// The borrow inside `wrapped` must therefore stay valid after the move.
    #[test]
    fn test_move_then_drop() {
        let wrapper = QueryWrapper::new_basic(String::from("Hello World"), BorrowStr);
        let moved = Box::new(wrapper);
        drop(moved);
    }
}