mysql_async/queryable/query_result/result_set_stream.rs

// Copyright (c) 2021 mysql_async developers.
//
// Licensed under the Apache License, Version 2.0
// <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0> or the MIT
// license <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
// option. All files in the project carrying such notice may not be copied,
// modified, or distributed except according to those terms.
8
9use std::any::type_name;
10use std::borrow::Cow;
11use std::sync::Arc;
12use std::task::Poll;
13use std::{fmt, marker::PhantomData};
14
15use futures_core::FusedStream;
16use futures_core::{future::BoxFuture, Stream};
17use futures_util::FutureExt;
18use mysql_common::packets::{Column, OkPacket};
19
20use crate::{
21    conn::PendingResult,
22    prelude::{FromRow, Protocol},
23    QueryResult, Row,
24};
25
26enum CowMut<'r, 'a: 'r, 't: 'a, P> {
27    Borrowed(&'r mut QueryResult<'a, 't, P>),
28    Owned(QueryResult<'a, 't, P>),
29}
30
31impl<'r, 'a: 'r, 't: 'a, P> fmt::Debug for CowMut<'r, 'a, 't, P> {
32    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
33        match self {
34            Self::Borrowed(arg0) => f.debug_tuple("Borrowed").field(arg0).finish(),
35            Self::Owned(arg0) => f.debug_tuple("Owned").field(arg0).finish(),
36        }
37    }
38}
39
40impl<'r, 'a: 'r, 't: 'a, P> AsMut<QueryResult<'a, 't, P>> for CowMut<'r, 'a, 't, P> {
41    fn as_mut(&mut self) -> &mut QueryResult<'a, 't, P> {
42        match self {
43            CowMut::Borrowed(q) => q,
44            CowMut::Owned(q) => q,
45        }
46    }
47}
48
49enum ResultSetStreamState<'r, 'a: 'r, 't: 'a, P> {
50    Idle(CowMut<'r, 'a, 't, P>),
51    NextFut(BoxFuture<'r, (crate::Result<Option<Row>>, CowMut<'r, 'a, 't, P>)>),
52}
53
54impl<'r, 'a: 'r, 't: 'a, P> fmt::Debug for ResultSetStreamState<'r, 'a, 't, P> {
55    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
56        match self {
57            Self::Idle(arg0) => f.debug_tuple("Idle").field(arg0).finish(),
58            Self::NextFut(_arg0) => f
59                .debug_tuple("NextFut")
60                .field(&type_name::<
61                    BoxFuture<'r, (crate::Result<Option<Row>>, CowMut<'r, 'a, 't, P>)>,
62                >())
63                .finish(),
64        }
65    }
66}
67
68#[derive(Debug)]
69/// Rows stream for a single result set.
70pub struct ResultSetStream<'r, 'a: 'r, 't: 'a, T, P> {
71    query_result: Option<ResultSetStreamState<'r, 'a, 't, P>>,
72    ok_packet: Option<OkPacket<'static>>,
73    columns: Arc<[Column]>,
74    __from_row_type: PhantomData<T>,
75}
76
77impl<'r, 'a: 'r, 't: 'a, T, P> FusedStream for ResultSetStream<'r, 'a, 't, T, P>
78where
79    P: Protocol + Unpin,
80    T: FromRow + Unpin + Send + 'static,
81{
82    fn is_terminated(&self) -> bool {
83        self.query_result.is_none()
84    }
85}
86
87impl<'r, 'a: 'r, 't: 'a, T, P> ResultSetStream<'r, 'a, 't, T, P> {
88    /// See [`Conn::last_insert_id`][1].
89    ///
90    /// [1]: crate::Conn::last_insert_id
91    pub fn last_insert_id(&self) -> Option<u64> {
92        self.ok_packet.as_ref().and_then(|ok| ok.last_insert_id())
93    }
94
95    /// See [`Conn::affected_rows`][1].
96    ///
97    /// [1]: crate::Conn::affected_rows
98    pub fn affected_rows(&self) -> u64 {
99        self.ok_packet
100            .as_ref()
101            .map(|ok| ok.affected_rows())
102            .unwrap_or_default()
103    }
104
105    /// See [`QueryResult::columns_ref`].
106    pub fn columns_ref(&self) -> &[Column] {
107        &self.columns[..]
108    }
109
110    /// See [`QueryResult::columns`].
111    pub fn columns(&self) -> Arc<[Column]> {
112        self.columns.clone()
113    }
114
115    /// See [`Conn::info`][1].
116    ///
117    /// [1]: crate::Conn::info
118    pub fn info(&self) -> Cow<'_, str> {
119        self.ok_packet
120            .as_ref()
121            .and_then(|ok| ok.info_str())
122            .unwrap_or_default()
123    }
124
125    /// See [`Conn::get_warnings`][1].
126    ///
127    /// [1]: crate::Conn::get_warnings
128    pub fn get_warnings(&self) -> u16 {
129        self.ok_packet
130            .as_ref()
131            .map(|ok| ok.warnings())
132            .unwrap_or_default()
133    }
134
135    /// Returns an `OkPacket` corresponding to this result set.
136    ///
137    /// Will be `None` if there is no OK packet (result set contains an error).
138    pub fn ok_packet(&self) -> Option<&OkPacket<'static>> {
139        self.ok_packet.as_ref()
140    }
141}
142
143impl<'r, 'a: 'r, 't: 'a, T, P> Stream for ResultSetStream<'r, 'a, 't, T, P>
144where
145    P: Protocol + Unpin,
146    T: FromRow + Unpin + Send + 'static,
147{
148    type Item = crate::Result<T>;
149
150    fn poll_next(
151        self: std::pin::Pin<&mut Self>,
152        cx: &mut std::task::Context<'_>,
153    ) -> Poll<Option<Self::Item>> {
154        let this = self.get_mut();
155        loop {
156            let columns = this.columns.clone();
157            match this.query_result.take() {
158                Some(ResultSetStreamState::Idle(mut query_result)) => {
159                    let fut = Box::pin(async move {
160                        let row = query_result.as_mut().next_row_or_next_set2(columns).await;
161                        (row, query_result)
162                    });
163                    this.query_result = Some(ResultSetStreamState::NextFut(fut));
164                }
165                Some(ResultSetStreamState::NextFut(mut fut)) => match fut.poll_unpin(cx) {
166                    Poll::Ready((row, query_result)) => match row {
167                        Ok(Some(row)) => {
168                            this.query_result = Some(ResultSetStreamState::Idle(query_result));
169                            return Poll::Ready(Some(Ok(crate::from_row(row))));
170                        }
171                        Ok(None) => return Poll::Ready(None),
172                        Err(err) => return Poll::Ready(Some(Err(err))),
173                    },
174                    Poll::Pending => {
175                        this.query_result = Some(ResultSetStreamState::NextFut(fut));
176                        return Poll::Pending;
177                    }
178                },
179                None => return Poll::Ready(None),
180            }
181        }
182    }
183}
184
185impl<'a, 't: 'a, P> QueryResult<'a, 't, P>
186where
187    P: Protocol + Unpin,
188{
189    async fn setup_stream(
190        &mut self,
191    ) -> crate::Result<Option<(Option<OkPacket<'static>>, Arc<[Column]>)>> {
192        match self.conn.use_pending_result()? {
193            Some(PendingResult::Taken(meta)) => {
194                let meta = (*meta).clone();
195                self.skip_taken(meta).await?;
196            }
197            Some(_) => (),
198            None => return Ok(None),
199        }
200
201        let ok_packet = self.conn.last_ok_packet().cloned();
202        let columns = match self.conn.take_pending_result()? {
203            Some(meta) => meta.columns().clone(),
204            None => return Ok(None),
205        };
206
207        Ok(Some((ok_packet, columns)))
208    }
209
210    /// Returns a [`Stream`] for the current result set.
211    ///
212    /// The returned stream satisfies [`futures_util::TryStream`],
213    /// so you can use [`futures_util::TryStreamExt`] functions on it.
214    ///
215    /// # Behavior
216    ///
217    /// ## Conversion
218    ///
219    /// This stream will convert each row into `T` using [`FromRow`] implementation.
220    /// If the row type is unknown please use the [`Row`] type for `T`
221    /// to make this conversion infallible.
222    ///
223    /// ## Consumption
224    ///
225    /// The call to [`QueryResult::stream`] entails the consumption of the current result set,
226    /// practically this means that the second call to [`QueryResult::stream`] will return
227    /// the next result set stream even if the stream returned from the first call wasn't
228    /// explicitly consumed:
229    ///
230    /// ```rust
231    /// # use mysql_async::test_misc::get_opts;
232    /// # #[tokio::main]
233    /// # async fn main() -> mysql_async::Result<()> {
234    /// # use mysql_async::*;
235    /// # use mysql_async::prelude::*;
236    /// # use futures_util::StreamExt;
237    /// let mut conn = Conn::new(get_opts()).await?;
238    ///
239    /// // This query result will contain two result sets.
240    /// let mut result = conn.query_iter("SELECT 1; SELECT 2;").await?;
241    ///
242    /// // The first result set stream is dropped here without being consumed,
243    /// let _ = result.stream::<u8>().await?;
244    /// // so it will be implicitly consumed here.
245    /// let mut stream = result.stream::<u8>().await?.expect("the second result set must be here");
246    /// assert_eq!(2_u8, stream.next().await.unwrap()?);
247    ///
248    /// # drop(stream); drop(result); conn.disconnect().await }
249    /// ```
250    ///
251    /// ## Errors
252    ///
253    /// Note, that [`QueryResult::stream`] may error if:
254    ///
255    /// - current result set contains an error,
256    /// - previously unconsumed result set stream contained an error.
257    ///
258    /// ```rust
259    /// # use mysql_async::test_misc::get_opts;
260    /// # #[tokio::main]
261    /// # async fn main() -> mysql_async::Result<()> {
262    /// # use mysql_async::*;
263    /// # use mysql_async::prelude::*;
264    /// # use futures_util::StreamExt;
265    /// let mut conn = Conn::new(get_opts()).await?;
266    ///
267    /// // The second result set of this query will contain an error.
268    /// let mut result = conn.query_iter("SELECT 1; SELECT FOO(); SELECT 2;").await?;
269    ///
270    /// // First result set stream is dropped here without being consumed,
271    /// let _ = result.stream::<Row>().await?;
272    /// // so it will be implicitly consumed on the second call to `QueryResult::stream`
273    /// // that will error complaining about unknown FOO
274    /// assert!(result.stream::<Row>().await.unwrap_err().to_string().contains("FOO"));
275    ///
276    /// # drop(result); conn.disconnect().await }
277    /// ```
278    pub fn stream<'r, T: Unpin + FromRow + Send + 'static>(
279        &'r mut self,
280    ) -> BoxFuture<'r, crate::Result<Option<ResultSetStream<'r, 'a, 't, T, P>>>> {
281        async move {
282            Ok(self
283                .setup_stream()
284                .await?
285                .map(
286                    move |(ok_packet, columns)| ResultSetStream::<'r, 'a, 't, T, P> {
287                        ok_packet,
288                        columns,
289                        query_result: Some(ResultSetStreamState::Idle(CowMut::Borrowed(self))),
290                        __from_row_type: PhantomData,
291                    },
292                ))
293        }
294        .boxed()
295    }
296
297    /// Owned version of the [`QueryResult::stream`].
298    ///
299    /// Returned stream will stop iteration on the first result set boundary.
300    ///
301    /// See also [`Query::stream`][query_stream], [`Queryable::query_stream`][queryable_query_stream],
302    /// [`Queryable::exec_stream`][queryable_exec_stream].
303    ///
304    /// The following example uses the [`Query::stream`][query_stream] function
305    /// that is based on the [`QueryResult::stream_and_drop`]:
306    ///
307    /// ```rust
308    /// # use mysql_async::test_misc::get_opts;
309    /// # #[tokio::main]
310    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
311    /// # use mysql_async::*;
312    /// # use mysql_async::prelude::*;
313    /// # use futures_util::TryStreamExt;
314    /// let pool = Pool::new(get_opts());
315    /// let mut conn = pool.get_conn().await?;
316    ///
317    /// // This example uses the `Query::stream` function that is based on `QueryResult::stream_and_drop`:
318    /// let mut stream = "SELECT 1 UNION ALL SELECT 2".stream::<u8, _>(&mut conn).await?;
319    /// let rows = stream.try_collect::<Vec<_>>().await.unwrap();
320    /// assert_eq!(vec![1, 2], rows);
321    ///
322    /// // Only the first result set will go into the stream:
323    /// let mut stream = r"
324    ///     SELECT 'foo' UNION ALL SELECT 'bar';
325    ///     SELECT 'baz' UNION ALL SELECT 'quux';".stream::<String, _>(&mut conn).await?;
326    /// let rows = stream.try_collect::<Vec<_>>().await.unwrap();
327    /// assert_eq!(vec!["foo".to_owned(), "bar".to_owned()], rows);
328    ///
329    /// // We can also build a `'static` stream by giving away the connection:
330    /// let stream = "SELECT 2 UNION ALL SELECT 3".stream::<u8, _>(conn).await?;
331    /// // `tokio::spawn` requires `'static`
332    /// let handle = tokio::spawn(async move {
333    ///     stream.try_collect::<Vec<_>>().await.unwrap()
334    /// });
335    /// assert_eq!(vec![2, 3], handle.await?);
336    ///
337    /// # Ok(()) }
338    /// ```
339    ///
340    /// [queryable_query_stream]: crate::prelude::Queryable::query_stream
341    /// [queryable_exec_stream]: crate::prelude::Queryable::exec_stream
342    /// [query_stream]: crate::prelude::Query::stream
343    pub fn stream_and_drop<T: Unpin + FromRow + Send + 'static>(
344        mut self,
345    ) -> BoxFuture<'a, crate::Result<Option<ResultSetStream<'a, 'a, 't, T, P>>>> {
346        async move {
347            Ok(self
348                .setup_stream()
349                .await?
350                .map(|(ok_packet, columns)| ResultSetStream::<'a, 'a, 't, T, P> {
351                    ok_packet,
352                    columns,
353                    query_result: Some(ResultSetStreamState::Idle(CowMut::Owned(self))),
354                    __from_row_type: PhantomData,
355                }))
356        }
357        .boxed()
358    }
359}