Skip to main content

tokio_postgres/
query.rs

1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::prepare::get_type;
5use crate::types::{BorrowToSql, IsNull};
6use crate::{Column, Error, Portal, Row, Statement};
7use bytes::{Bytes, BytesMut};
8use fallible_iterator::FallibleIterator;
9use futures_util::Stream;
10use log::{Level, debug, log_enabled};
11use pin_project_lite::pin_project;
12use postgres_protocol::message::backend::{CommandCompleteBody, Message};
13use postgres_protocol::message::frontend;
14use postgres_types::Type;
15use std::fmt;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{Context, Poll, ready};
19
20struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
21
22impl<T> fmt::Debug for BorrowToSqlParamsDebug<'_, T>
23where
24    T: BorrowToSql,
25{
26    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27        f.debug_list()
28            .entries(self.0.iter().map(|x| x.borrow_to_sql()))
29            .finish()
30    }
31}
32
33pub async fn query<P, I>(
34    client: &InnerClient,
35    statement: Statement,
36    params: I,
37) -> Result<RowStream, Error>
38where
39    P: BorrowToSql,
40    I: IntoIterator<Item = P>,
41    I::IntoIter: ExactSizeIterator,
42{
43    let buf = if log_enabled!(Level::Debug) {
44        let params = params.into_iter().collect::<Vec<_>>();
45        debug!(
46            "executing statement {} with parameters: {:?}",
47            statement.name(),
48            BorrowToSqlParamsDebug(params.as_slice()),
49        );
50        encode(client, &statement, params)?
51    } else {
52        encode(client, &statement, params)?
53    };
54    let responses = start(client, buf).await?;
55    Ok(RowStream {
56        statement,
57        responses,
58        rows_affected: None,
59    })
60}
61
62pub async fn query_typed<P, I>(
63    client: &Arc<InnerClient>,
64    query: &str,
65    params: I,
66) -> Result<RowStream, Error>
67where
68    P: BorrowToSql,
69    I: IntoIterator<Item = (P, Type)>,
70{
71    let buf = {
72        let params = params.into_iter().collect::<Vec<_>>();
73        let param_oids = params.iter().map(|(_, t)| t.oid()).collect::<Vec<_>>();
74
75        client.with_buf(|buf| {
76            frontend::parse("", query, param_oids, buf).map_err(Error::parse)?;
77            encode_bind_raw("", params, "", buf)?;
78            frontend::describe(b'S', "", buf).map_err(Error::encode)?;
79            frontend::execute("", 0, buf).map_err(Error::encode)?;
80            frontend::sync(buf);
81
82            Ok(buf.split().freeze())
83        })?
84    };
85
86    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
87
88    loop {
89        match responses.next().await? {
90            Message::ParseComplete | Message::BindComplete | Message::ParameterDescription(_) => {}
91            Message::NoData => {
92                return Ok(RowStream {
93                    statement: Statement::unnamed(vec![], vec![]),
94                    responses,
95                    rows_affected: None,
96                });
97            }
98            Message::RowDescription(row_description) => {
99                let mut columns: Vec<Column> = vec![];
100                let mut it = row_description.fields();
101                while let Some(field) = it.next().map_err(Error::parse)? {
102                    let type_ = get_type(client, field.type_oid()).await?;
103                    let column = Column {
104                        name: field.name().to_string(),
105                        table_oid: Some(field.table_oid()).filter(|n| *n != 0),
106                        column_id: Some(field.column_id()).filter(|n| *n != 0),
107                        type_modifier: field.type_modifier(),
108                        r#type: type_,
109                    };
110                    columns.push(column);
111                }
112                return Ok(RowStream {
113                    statement: Statement::unnamed(vec![], columns),
114                    responses,
115                    rows_affected: None,
116                });
117            }
118            _ => return Err(Error::unexpected_message()),
119        }
120    }
121}
122
123pub async fn execute_typed<P, I>(
124    client: &Arc<InnerClient>,
125    query: &str,
126    params: I,
127) -> Result<u64, Error>
128where
129    P: BorrowToSql,
130    I: IntoIterator<Item = (P, Type)>,
131{
132    let buf = {
133        let params = params.into_iter().collect::<Vec<_>>();
134        let param_oids = params.iter().map(|(_, t)| t.oid()).collect::<Vec<_>>();
135
136        client.with_buf(|buf| {
137            frontend::parse("", query, param_oids, buf).map_err(Error::parse)?;
138            encode_bind_raw("", params, "", buf)?;
139            frontend::describe(b'S', "", buf).map_err(Error::encode)?;
140            frontend::execute("", 0, buf).map_err(Error::encode)?;
141            frontend::sync(buf);
142
143            Ok(buf.split().freeze())
144        })?
145    };
146
147    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
148
149    let mut rows = 0;
150
151    loop {
152        match responses.next().await? {
153            Message::ParseComplete
154            | Message::BindComplete
155            | Message::ParameterDescription(_)
156            | Message::RowDescription(_) => {}
157            Message::NoData => {
158                rows = 0;
159            }
160
161            Message::DataRow(_) => {}
162            Message::CommandComplete(body) => {
163                rows = extract_row_affected(&body)?;
164            }
165
166            Message::EmptyQueryResponse => rows = 0,
167            Message::ReadyForQuery(_) => return Ok(rows),
168            _ => {
169                return Err(Error::unexpected_message());
170            }
171        }
172    }
173}
174
175pub async fn query_portal(
176    client: &InnerClient,
177    portal: &Portal,
178    max_rows: i32,
179) -> Result<RowStream, Error> {
180    let buf = client.with_buf(|buf| {
181        frontend::execute(portal.name(), max_rows, buf).map_err(Error::encode)?;
182        frontend::sync(buf);
183        Ok(buf.split().freeze())
184    })?;
185
186    let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
187
188    Ok(RowStream {
189        statement: portal.statement().clone(),
190        responses,
191        rows_affected: None,
192    })
193}
194
195/// Extract the number of rows affected from [`CommandCompleteBody`].
196pub fn extract_row_affected(body: &CommandCompleteBody) -> Result<u64, Error> {
197    let rows = body
198        .tag()
199        .map_err(Error::parse)?
200        .rsplit(' ')
201        .next()
202        .unwrap()
203        .parse()
204        .unwrap_or(0);
205    Ok(rows)
206}
207
208pub async fn execute<P, I>(
209    client: &InnerClient,
210    statement: Statement,
211    params: I,
212) -> Result<u64, Error>
213where
214    P: BorrowToSql,
215    I: IntoIterator<Item = P>,
216    I::IntoIter: ExactSizeIterator,
217{
218    let buf = if log_enabled!(Level::Debug) {
219        let params = params.into_iter().collect::<Vec<_>>();
220        debug!(
221            "executing statement {} with parameters: {:?}",
222            statement.name(),
223            BorrowToSqlParamsDebug(params.as_slice()),
224        );
225        encode(client, &statement, params)?
226    } else {
227        encode(client, &statement, params)?
228    };
229    let mut responses = start(client, buf).await?;
230
231    let mut rows = 0;
232    loop {
233        match responses.next().await? {
234            Message::DataRow(_) => {}
235            Message::CommandComplete(body) => {
236                rows = extract_row_affected(&body)?;
237            }
238            Message::EmptyQueryResponse => rows = 0,
239            Message::ReadyForQuery(_) => return Ok(rows),
240            _ => return Err(Error::unexpected_message()),
241        }
242    }
243}
244
245async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
246    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
247
248    match responses.next().await? {
249        Message::BindComplete => {}
250        _ => return Err(Error::unexpected_message()),
251    }
252
253    Ok(responses)
254}
255
256pub fn encode<P, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
257where
258    P: BorrowToSql,
259    I: IntoIterator<Item = P>,
260    I::IntoIter: ExactSizeIterator,
261{
262    client.with_buf(|buf| {
263        encode_bind(statement, params, "", buf)?;
264        frontend::execute("", 0, buf).map_err(Error::encode)?;
265        frontend::sync(buf);
266        Ok(buf.split().freeze())
267    })
268}
269
270pub fn encode_bind<P, I>(
271    statement: &Statement,
272    params: I,
273    portal: &str,
274    buf: &mut BytesMut,
275) -> Result<(), Error>
276where
277    P: BorrowToSql,
278    I: IntoIterator<Item = P>,
279    I::IntoIter: ExactSizeIterator,
280{
281    let params = params.into_iter();
282    if params.len() != statement.params().len() {
283        return Err(Error::parameters(params.len(), statement.params().len()));
284    }
285
286    encode_bind_raw(
287        statement.name(),
288        params.zip(statement.params().iter().cloned()),
289        portal,
290        buf,
291    )
292}
293
294fn encode_bind_raw<P, I>(
295    statement_name: &str,
296    params: I,
297    portal: &str,
298    buf: &mut BytesMut,
299) -> Result<(), Error>
300where
301    P: BorrowToSql,
302    I: IntoIterator<Item = (P, Type)>,
303    I::IntoIter: ExactSizeIterator,
304{
305    let (param_formats, params): (Vec<_>, Vec<_>) = params
306        .into_iter()
307        .map(|(p, ty)| (p.borrow_to_sql().encode_format(&ty) as i16, (p, ty)))
308        .unzip();
309
310    let mut error_idx = 0;
311    let r = frontend::bind(
312        portal,
313        statement_name,
314        param_formats,
315        params.into_iter().enumerate(),
316        |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(&ty, buf) {
317            Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
318            Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
319            Err(e) => {
320                error_idx = idx;
321                Err(e)
322            }
323        },
324        Some(1),
325        buf,
326    );
327    match r {
328        Ok(()) => Ok(()),
329        Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
330        Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
331    }
332}
333
334pin_project! {
335    /// A stream of table rows.
336    #[project(!Unpin)]
337    pub struct RowStream {
338        statement: Statement,
339        responses: Responses,
340        rows_affected: Option<u64>,
341    }
342}
343
344impl Stream for RowStream {
345    type Item = Result<Row, Error>;
346
347    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
348        let this = self.project();
349        loop {
350            match ready!(this.responses.poll_next(cx)?) {
351                Message::DataRow(body) => {
352                    return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)));
353                }
354                Message::CommandComplete(body) => {
355                    *this.rows_affected = Some(extract_row_affected(&body)?);
356                }
357                Message::EmptyQueryResponse | Message::PortalSuspended => {}
358                Message::ReadyForQuery(_) => return Poll::Ready(None),
359                _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
360            }
361        }
362    }
363}
364
365impl RowStream {
366    /// Returns the number of rows affected by the query.
367    ///
368    /// This function will return `None` until the stream has been exhausted.
369    pub fn rows_affected(&self) -> Option<u64> {
370        self.rows_affected
371    }
372}
373
374pub async fn sync(client: &InnerClient) -> Result<(), Error> {
375    let buf = Bytes::from_static(b"S\0\0\0\x04");
376    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
377
378    match responses.next().await? {
379        Message::ReadyForQuery(_) => Ok(()),
380        _ => Err(Error::unexpected_message()),
381    }
382}