mysql_async/queryable/query_result/result_set_stream.rs

1// Copyright (c) 2021 mysql_async developers.
2//
3// Licensed under the Apache License, Version 2.0
4// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
5// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. All files in the project carrying such notice may not be copied,
7// modified, or distributed except according to those terms.
8
9use std::any::type_name;
10use std::borrow::Cow;
11use std::sync::Arc;
12use std::task::Poll;
13use std::{fmt, marker::PhantomData};
14
15use futures_core::FusedStream;
16use futures_core::{future::BoxFuture, Stream};
17use futures_util::FutureExt;
18use mysql_common::packets::{Column, OkPacket};
19
20use crate::{
21    conn::PendingResult,
22    prelude::{FromRow, Protocol},
23    QueryResult, Row,
24};
25
26enum CowMut<'r, 'a: 'r, 't: 'a, P> {
27    Borrowed(&'r mut QueryResult<'a, 't, P>),
28    Owned(QueryResult<'a, 't, P>),
29}
30
31impl<'r, 'a: 'r, 't: 'a, P> fmt::Debug for CowMut<'r, 'a, 't, P> {
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        match self {
34            Self::Borrowed(arg0) => f.debug_tuple("Borrowed").field(arg0).finish(),
35            Self::Owned(arg0) => f.debug_tuple("Owned").field(arg0).finish(),
36        }
37    }
38}
39
40impl<'r, 'a: 'r, 't: 'a, P> AsMut<QueryResult<'a, 't, P>> for CowMut<'r, 'a, 't, P> {
41    fn as_mut(&mut self) -> &mut QueryResult<'a, 't, P> {
42        match self {
43            CowMut::Borrowed(q) => q,
44            CowMut::Owned(q) => q,
45        }
46    }
47}
48
49enum ResultSetStreamState<'r, 'a: 'r, 't: 'a, P> {
50    Idle(CowMut<'r, 'a, 't, P>),
51    NextFut(BoxFuture<'r, (crate::Result<Option<Row>>, CowMut<'r, 'a, 't, P>)>),
52}
53
54impl<'r, 'a: 'r, 't: 'a, P> fmt::Debug for ResultSetStreamState<'r, 'a, 't, P> {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        match self {
57            Self::Idle(arg0) => f.debug_tuple("Idle").field(arg0).finish(),
58            Self::NextFut(_arg0) => f
59                .debug_tuple("NextFut")
60                .field(&type_name::<
61                    BoxFuture<'r, (crate::Result<Option<Row>>, CowMut<'r, 'a, 't, P>)>,
62                >())
63                .finish(),
64        }
65    }
66}
67
68#[derive(Debug)]
69/// Rows stream for a single result set.
70pub struct ResultSetStream<'r, 'a: 'r, 't: 'a, T, P> {
71    query_result: Option<ResultSetStreamState<'r, 'a, 't, P>>,
72    ok_packet: Option<OkPacket<'static>>,
73    columns: Arc<[Column]>,
74    __from_row_type: PhantomData<T>,
75}
76
77impl<'r, 'a: 'r, 't: 'a, T, P> FusedStream for ResultSetStream<'r, 'a, 't, T, P>
78where
79    P: Protocol + Unpin,
80    T: FromRow + Unpin + Send + 'static,
81{
82    fn is_terminated(&self) -> bool {
83        self.query_result.is_none()
84    }
85}
86
87impl<'r, 'a: 'r, 't: 'a, T, P> ResultSetStream<'r, 'a, 't, T, P> {
88    /// See [`Conn::last_insert_id`][1].
89    ///
90    /// [1]: crate::Conn::last_insert_id
91    pub fn last_insert_id(&self) -> Option<u64> {
92        self.ok_packet.as_ref().and_then(|ok| ok.last_insert_id())
93    }
94
95    /// See [`Conn::affected_rows`][1].
96    ///
97    /// [1]: crate::Conn::affected_rows
98    pub fn affected_rows(&self) -> u64 {
99        self.ok_packet
100            .as_ref()
101            .map(|ok| ok.affected_rows())
102            .unwrap_or_default()
103    }
104
105    /// See [`Conn::info`][1].
106    ///
107    /// [1]: crate::Conn::info
108    pub fn info(&self) -> Cow<'_, str> {
109        self.ok_packet
110            .as_ref()
111            .and_then(|ok| ok.info_str())
112            .unwrap_or_default()
113    }
114
115    /// See [`Conn::get_warnings`][1].
116    ///
117    /// [1]: crate::Conn::get_warnings
118    pub fn get_warnings(&self) -> u16 {
119        self.ok_packet
120            .as_ref()
121            .map(|ok| ok.warnings())
122            .unwrap_or_default()
123    }
124
125    /// Returns an `OkPacket` corresponding to this result set.
126    ///
127    /// Will be `None` if there is no OK packet (result set contains an error).
128    pub fn ok_packet(&self) -> Option<&OkPacket<'static>> {
129        self.ok_packet.as_ref()
130    }
131}
132
133impl<'r, 'a: 'r, 't: 'a, T, P> Stream for ResultSetStream<'r, 'a, 't, T, P>
134where
135    P: Protocol + Unpin,
136    T: FromRow + Unpin + Send + 'static,
137{
138    type Item = crate::Result<T>;
139
140    fn poll_next(
141        self: std::pin::Pin<&mut Self>,
142        cx: &mut std::task::Context<'_>,
143    ) -> Poll<Option<Self::Item>> {
144        let this = self.get_mut();
145        loop {
146            let columns = this.columns.clone();
147            match this.query_result.take() {
148                Some(ResultSetStreamState::Idle(mut query_result)) => {
149                    let fut = Box::pin(async move {
150                        let row = query_result.as_mut().next_row_or_next_set2(columns).await;
151                        (row, query_result)
152                    });
153                    this.query_result = Some(ResultSetStreamState::NextFut(fut));
154                }
155                Some(ResultSetStreamState::NextFut(mut fut)) => match fut.poll_unpin(cx) {
156                    Poll::Ready((row, query_result)) => match row {
157                        Ok(Some(row)) => {
158                            this.query_result = Some(ResultSetStreamState::Idle(query_result));
159                            return Poll::Ready(Some(Ok(crate::from_row(row))));
160                        }
161                        Ok(None) => return Poll::Ready(None),
162                        Err(err) => return Poll::Ready(Some(Err(err))),
163                    },
164                    Poll::Pending => {
165                        this.query_result = Some(ResultSetStreamState::NextFut(fut));
166                        return Poll::Pending;
167                    }
168                },
169                None => return Poll::Ready(None),
170            }
171        }
172    }
173}
174
175impl<'a, 't: 'a, P> QueryResult<'a, 't, P>
176where
177    P: Protocol + Unpin,
178{
179    async fn setup_stream(
180        &mut self,
181    ) -> crate::Result<Option<(Option<OkPacket<'static>>, Arc<[Column]>)>> {
182        match self.conn.use_pending_result()? {
183            Some(PendingResult::Taken(meta)) => {
184                let meta = (*meta).clone();
185                self.skip_taken(meta).await?;
186            }
187            Some(_) => (),
188            None => return Ok(None),
189        }
190
191        let ok_packet = self.conn.last_ok_packet().cloned();
192        let columns = match self.conn.take_pending_result()? {
193            Some(meta) => meta.columns().clone(),
194            None => return Ok(None),
195        };
196
197        Ok(Some((ok_packet, columns)))
198    }
199
200    /// Returns a [`Stream`] for the current result set.
201    ///
202    /// The returned stream satisfies [`futures_util::TryStream`],
203    /// so you can use [`futures_util::TryStreamExt`] functions on it.
204    ///
205    /// # Behavior
206    ///
207    /// ## Conversion
208    ///
209    /// This stream will convert each row into `T` using [`FromRow`] implementation.
210    /// If the row type is unknown please use the [`Row`] type for `T`
211    /// to make this conversion infallible.
212    ///
213    /// ## Consumption
214    ///
215    /// The call to [`QueryResult::stream`] entails the consumption of the current result set,
216    /// practically this means that the second call to [`QueryResult::stream`] will return
217    /// the next result set stream even if the stream returned from the first call wasn't
218    /// explicitly consumed:
219    ///
220    /// ```rust
221    /// # use mysql_async::test_misc::get_opts;
222    /// # #[tokio::main]
223    /// # async fn main() -> mysql_async::Result<()> {
224    /// # use mysql_async::*;
225    /// # use mysql_async::prelude::*;
226    /// # use futures_util::StreamExt;
227    /// let mut conn = Conn::new(get_opts()).await?;
228    ///
229    /// // This query result will contain two result sets.
230    /// let mut result = conn.query_iter("SELECT 1; SELECT 2;").await?;
231    ///
232    /// // The first result set stream is dropped here without being consumed,
233    /// let _ = result.stream::<u8>().await?;
234    /// // so it will be implicitly consumed here.
235    /// let mut stream = result.stream::<u8>().await?.expect("the second result set must be here");
236    /// assert_eq!(2_u8, stream.next().await.unwrap()?);
237    ///
238    /// # drop(stream); drop(result); conn.disconnect().await }
239    /// ```
240    ///
241    /// ## Errors
242    ///
243    /// Note, that [`QueryResult::stream`] may error if:
244    ///
245    /// - current result set contains an error,
246    /// - previously unconsumed result set stream contained an error.
247    ///
248    /// ```rust
249    /// # use mysql_async::test_misc::get_opts;
250    /// # #[tokio::main]
251    /// # async fn main() -> mysql_async::Result<()> {
252    /// # use mysql_async::*;
253    /// # use mysql_async::prelude::*;
254    /// # use futures_util::StreamExt;
255    /// let mut conn = Conn::new(get_opts()).await?;
256    ///
257    /// // The second result set of this query will contain an error.
258    /// let mut result = conn.query_iter("SELECT 1; SELECT FOO(); SELECT 2;").await?;
259    ///
260    /// // First result set stream is dropped here without being consumed,
261    /// let _ = result.stream::<Row>().await?;
262    /// // so it will be implicitly consumed on the second call to `QueryResult::stream`
263    /// // that will error complaining about unknown FOO
264    /// assert!(result.stream::<Row>().await.unwrap_err().to_string().contains("FOO"));
265    ///
266    /// # drop(result); conn.disconnect().await }
267    /// ```
268    pub fn stream<'r, T: Unpin + FromRow + Send + 'static>(
269        &'r mut self,
270    ) -> BoxFuture<'r, crate::Result<Option<ResultSetStream<'r, 'a, 't, T, P>>>> {
271        async move {
272            Ok(self
273                .setup_stream()
274                .await?
275                .map(
276                    move |(ok_packet, columns)| ResultSetStream::<'r, 'a, 't, T, P> {
277                        ok_packet,
278                        columns,
279                        query_result: Some(ResultSetStreamState::Idle(CowMut::Borrowed(self))),
280                        __from_row_type: PhantomData,
281                    },
282                ))
283        }
284        .boxed()
285    }
286
287    /// Owned version of the [`QueryResult::stream`].
288    ///
289    /// Returned stream will stop iteration on the first result set boundary.
290    ///
291    /// See also [`Query::stream`][query_stream], [`Queryable::query_stream`][queryable_query_stream],
292    /// [`Queryable::exec_stream`][queryable_exec_stream].
293    ///
294    /// The following example uses the [`Query::stream`][query_stream] function
295    /// that is based on the [`QueryResult::stream_and_drop`]:
296    ///
297    /// ```rust
298    /// # use mysql_async::test_misc::get_opts;
299    /// # #[tokio::main]
300    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
301    /// # use mysql_async::*;
302    /// # use mysql_async::prelude::*;
303    /// # use futures_util::TryStreamExt;
304    /// let pool = Pool::new(get_opts());
305    /// let mut conn = pool.get_conn().await?;
306    ///
307    /// // This example uses the `Query::stream` function that is based on `QueryResult::stream_and_drop`:
308    /// let mut stream = "SELECT 1 UNION ALL SELECT 2".stream::<u8, _>(&mut conn).await?;
309    /// let rows = stream.try_collect::<Vec<_>>().await.unwrap();
310    /// assert_eq!(vec![1, 2], rows);
311    ///
312    /// // Only the first result set will go into the stream:
313    /// let mut stream = r"
314    ///     SELECT 'foo' UNION ALL SELECT 'bar';
315    ///     SELECT 'baz' UNION ALL SELECT 'quux';".stream::<String, _>(&mut conn).await?;
316    /// let rows = stream.try_collect::<Vec<_>>().await.unwrap();
317    /// assert_eq!(vec!["foo".to_owned(), "bar".to_owned()], rows);
318    ///
319    /// // We can also build a `'static` stream by giving away the connection:
320    /// let stream = "SELECT 2 UNION ALL SELECT 3".stream::<u8, _>(conn).await?;
321    /// // `tokio::spawn` requires `'static`
322    /// let handle = tokio::spawn(async move {
323    ///     stream.try_collect::<Vec<_>>().await.unwrap()
324    /// });
325    /// assert_eq!(vec![2, 3], handle.await?);
326    ///
327    /// # Ok(()) }
328    /// ```
329    ///
330    /// [queryable_query_stream]: crate::prelude::Queryable::query_stream
331    /// [queryable_exec_stream]: crate::prelude::Queryable::exec_stream
332    /// [query_stream]: crate::prelude::Query::stream
333    pub fn stream_and_drop<T: Unpin + FromRow + Send + 'static>(
334        mut self,
335    ) -> BoxFuture<'a, crate::Result<Option<ResultSetStream<'a, 'a, 't, T, P>>>> {
336        async move {
337            Ok(self
338                .setup_stream()
339                .await?
340                .map(|(ok_packet, columns)| ResultSetStream::<'a, 'a, 't, T, P> {
341                    ok_packet,
342                    columns,
343                    query_result: Some(ResultSetStreamState::Idle(CowMut::Owned(self))),
344                    __from_row_type: PhantomData,
345                }))
346        }
347        .boxed()
348    }
349}