1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::types::{BorrowToSql, IsNull};
5use crate::{Error, Portal, Row, Statement};
6use bytes::{Bytes, BytesMut};
7use futures::{ready, Stream};
8use log::{debug, log_enabled, Level};
9use pin_project_lite::pin_project;
10use postgres_protocol::message::backend::Message;
11use postgres_protocol::message::frontend;
12use std::fmt;
13use std::marker::PhantomPinned;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16
17struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
18
19impl<'a, T> fmt::Debug for BorrowToSqlParamsDebug<'a, T>
20where
21 T: BorrowToSql,
22{
23 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
24 f.debug_list()
25 .entries(self.0.iter().map(|x| x.borrow_to_sql()))
26 .finish()
27 }
28}
29
30pub async fn query<P, I>(
31 client: &InnerClient,
32 statement: Statement,
33 params: I,
34) -> Result<RowStream, Error>
35where
36 P: BorrowToSql,
37 I: IntoIterator<Item = P>,
38 I::IntoIter: ExactSizeIterator,
39{
40 let buf = if log_enabled!(Level::Debug) {
41 let params = params.into_iter().collect::<Vec<_>>();
42 debug!(
43 "executing statement {} with parameters: {:?}",
44 statement.name(),
45 BorrowToSqlParamsDebug(params.as_slice()),
46 );
47 encode(client, &statement, params)?
48 } else {
49 encode(client, &statement, params)?
50 };
51 let responses = start(client, buf).await?;
52 Ok(RowStream {
53 statement,
54 responses,
55 _p: PhantomPinned,
56 })
57}
58
59pub async fn query_portal(
60 client: &InnerClient,
61 portal: &Portal,
62 max_rows: i32,
63) -> Result<RowStream, Error> {
64 let buf = client.with_buf(|buf| {
65 frontend::execute(portal.name(), max_rows, buf).map_err(Error::encode)?;
66 frontend::sync(buf);
67 Ok(buf.split().freeze())
68 })?;
69
70 let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
71
72 Ok(RowStream {
73 statement: portal.statement().clone(),
74 responses,
75 _p: PhantomPinned,
76 })
77}
78
79pub async fn execute<P, I>(
80 client: &InnerClient,
81 statement: Statement,
82 params: I,
83) -> Result<u64, Error>
84where
85 P: BorrowToSql,
86 I: IntoIterator<Item = P>,
87 I::IntoIter: ExactSizeIterator,
88{
89 let buf = if log_enabled!(Level::Debug) {
90 let params = params.into_iter().collect::<Vec<_>>();
91 debug!(
92 "executing statement {} with parameters: {:?}",
93 statement.name(),
94 BorrowToSqlParamsDebug(params.as_slice()),
95 );
96 encode(client, &statement, params)?
97 } else {
98 encode(client, &statement, params)?
99 };
100 let mut responses = start(client, buf).await?;
101
102 let mut rows = 0;
103 loop {
104 match responses.next().await? {
105 Message::DataRow(_) => {}
106 Message::CommandComplete(body) => {
107 rows = body
108 .tag()
109 .map_err(Error::parse)?
110 .rsplit(' ')
111 .next()
112 .unwrap()
113 .parse()
114 .unwrap_or(0);
115 }
116 Message::EmptyQueryResponse => rows = 0,
117 Message::ReadyForQuery(_) => return Ok(rows),
118 _ => return Err(Error::unexpected_message()),
119 }
120 }
121}
122
123async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
124 let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
125
126 match responses.next().await? {
127 Message::BindComplete => {}
128 _ => return Err(Error::unexpected_message()),
129 }
130
131 Ok(responses)
132}
133
134pub fn encode<P, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
135where
136 P: BorrowToSql,
137 I: IntoIterator<Item = P>,
138 I::IntoIter: ExactSizeIterator,
139{
140 client.with_buf(|buf| {
141 encode_bind(statement, params, "", buf)?;
142 frontend::execute("", 0, buf).map_err(Error::encode)?;
143 frontend::sync(buf);
144 Ok(buf.split().freeze())
145 })
146}
147
148pub fn encode_bind<P, I>(
149 statement: &Statement,
150 params: I,
151 portal: &str,
152 buf: &mut BytesMut,
153) -> Result<(), Error>
154where
155 P: BorrowToSql,
156 I: IntoIterator<Item = P>,
157 I::IntoIter: ExactSizeIterator,
158{
159 let params = params.into_iter();
160
161 assert!(
162 statement.params().len() == params.len(),
163 "expected {} parameters but got {}",
164 statement.params().len(),
165 params.len()
166 );
167
168 let mut error_idx = 0;
169 let r = frontend::bind(
170 portal,
171 statement.name(),
172 Some(1),
173 params.zip(statement.params()).enumerate(),
174 |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(ty, buf) {
175 Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
176 Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
177 Err(e) => {
178 error_idx = idx;
179 Err(e)
180 }
181 },
182 Some(1),
183 buf,
184 );
185 match r {
186 Ok(()) => Ok(()),
187 Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
188 Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
189 }
190}
191
192pin_project! {
193 pub struct RowStream {
195 statement: Statement,
196 responses: Responses,
197 #[pin]
198 _p: PhantomPinned,
199 }
200}
201
202impl Stream for RowStream {
203 type Item = Result<Row, Error>;
204
205 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
206 let this = self.project();
207 loop {
208 match ready!(this.responses.poll_next(cx)?) {
209 Message::DataRow(body) => {
210 return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)))
211 }
212 Message::EmptyQueryResponse
213 | Message::CommandComplete(_)
214 | Message::PortalSuspended => {}
215 Message::ReadyForQuery(_) => return Poll::Ready(None),
216 _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
217 }
218 }
219 }
220}