madsim_tokio_postgres/
query.rs

1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::types::{BorrowToSql, IsNull};
5use crate::{Error, Portal, Row, Statement};
6use bytes::{Bytes, BytesMut};
7use futures::{ready, Stream};
8use log::{debug, log_enabled, Level};
9use pin_project_lite::pin_project;
10use postgres_protocol::message::backend::Message;
11use postgres_protocol::message::frontend;
12use std::fmt;
13use std::marker::PhantomPinned;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16
17struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
18
19impl<'a, T> fmt::Debug for BorrowToSqlParamsDebug<'a, T>
20where
21    T: BorrowToSql,
22{
23    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24        f.debug_list()
25            .entries(self.0.iter().map(|x| x.borrow_to_sql()))
26            .finish()
27    }
28}
29
30pub async fn query<P, I>(
31    client: &InnerClient,
32    statement: Statement,
33    params: I,
34) -> Result<RowStream, Error>
35where
36    P: BorrowToSql,
37    I: IntoIterator<Item = P>,
38    I::IntoIter: ExactSizeIterator,
39{
40    let buf = if log_enabled!(Level::Debug) {
41        let params = params.into_iter().collect::<Vec<_>>();
42        debug!(
43            "executing statement {} with parameters: {:?}",
44            statement.name(),
45            BorrowToSqlParamsDebug(params.as_slice()),
46        );
47        encode(client, &statement, params)?
48    } else {
49        encode(client, &statement, params)?
50    };
51    let responses = start(client, buf).await?;
52    Ok(RowStream {
53        statement,
54        responses,
55        _p: PhantomPinned,
56    })
57}
58
59pub async fn query_portal(
60    client: &InnerClient,
61    portal: &Portal,
62    max_rows: i32,
63) -> Result<RowStream, Error> {
64    let buf = client.with_buf(|buf| {
65        frontend::execute(portal.name(), max_rows, buf).map_err(Error::encode)?;
66        frontend::sync(buf);
67        Ok(buf.split().freeze())
68    })?;
69
70    let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
71
72    Ok(RowStream {
73        statement: portal.statement().clone(),
74        responses,
75        _p: PhantomPinned,
76    })
77}
78
79pub async fn execute<P, I>(
80    client: &InnerClient,
81    statement: Statement,
82    params: I,
83) -> Result<u64, Error>
84where
85    P: BorrowToSql,
86    I: IntoIterator<Item = P>,
87    I::IntoIter: ExactSizeIterator,
88{
89    let buf = if log_enabled!(Level::Debug) {
90        let params = params.into_iter().collect::<Vec<_>>();
91        debug!(
92            "executing statement {} with parameters: {:?}",
93            statement.name(),
94            BorrowToSqlParamsDebug(params.as_slice()),
95        );
96        encode(client, &statement, params)?
97    } else {
98        encode(client, &statement, params)?
99    };
100    let mut responses = start(client, buf).await?;
101
102    let mut rows = 0;
103    loop {
104        match responses.next().await? {
105            Message::DataRow(_) => {}
106            Message::CommandComplete(body) => {
107                rows = body
108                    .tag()
109                    .map_err(Error::parse)?
110                    .rsplit(' ')
111                    .next()
112                    .unwrap()
113                    .parse()
114                    .unwrap_or(0);
115            }
116            Message::EmptyQueryResponse => rows = 0,
117            Message::ReadyForQuery(_) => return Ok(rows),
118            _ => return Err(Error::unexpected_message()),
119        }
120    }
121}
122
123async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
124    let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
125
126    match responses.next().await? {
127        Message::BindComplete => {}
128        _ => return Err(Error::unexpected_message()),
129    }
130
131    Ok(responses)
132}
133
134pub fn encode<P, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
135where
136    P: BorrowToSql,
137    I: IntoIterator<Item = P>,
138    I::IntoIter: ExactSizeIterator,
139{
140    client.with_buf(|buf| {
141        encode_bind(statement, params, "", buf)?;
142        frontend::execute("", 0, buf).map_err(Error::encode)?;
143        frontend::sync(buf);
144        Ok(buf.split().freeze())
145    })
146}
147
148pub fn encode_bind<P, I>(
149    statement: &Statement,
150    params: I,
151    portal: &str,
152    buf: &mut BytesMut,
153) -> Result<(), Error>
154where
155    P: BorrowToSql,
156    I: IntoIterator<Item = P>,
157    I::IntoIter: ExactSizeIterator,
158{
159    let params = params.into_iter();
160
161    assert!(
162        statement.params().len() == params.len(),
163        "expected {} parameters but got {}",
164        statement.params().len(),
165        params.len()
166    );
167
168    let mut error_idx = 0;
169    let r = frontend::bind(
170        portal,
171        statement.name(),
172        Some(1),
173        params.zip(statement.params()).enumerate(),
174        |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) {
175            Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
176            Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
177            Err(e) => {
178                error_idx = idx;
179                Err(e)
180            }
181        },
182        Some(1),
183        buf,
184    );
185    match r {
186        Ok(()) => Ok(()),
187        Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
188        Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
189    }
190}
191
192pin_project! {
193    /// A stream of table rows.
194    pub struct RowStream {
195        statement: Statement,
196        responses: Responses,
197        #[pin]
198        _p: PhantomPinned,
199    }
200}
201
202impl Stream for RowStream {
203    type Item = Result<Row, Error>;
204
205    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
206        let this = self.project();
207        loop {
208            match ready!(this.responses.poll_next(cx)?) {
209                Message::DataRow(body) => {
210                    return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)))
211                }
212                Message::EmptyQueryResponse
213                | Message::CommandComplete(_)
214                | Message::PortalSuspended => {}
215                Message::ReadyForQuery(_) => return Poll::Ready(None),
216                _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
217            }
218        }
219    }
220}