1use crate::client::{InnerClient, Responses};
2use crate::codec::FrontendMessage;
3use crate::connection::RequestMessages;
4use crate::prepare::get_type;
5use crate::types::{BorrowToSql, IsNull};
6use crate::{Column, Error, Portal, Row, Statement};
7use bytes::{Bytes, BytesMut};
8use fallible_iterator::FallibleIterator;
9use futures_util::Stream;
10use log::{debug, log_enabled, Level};
11use pin_project_lite::pin_project;
12use postgres_protocol::message::backend::{CommandCompleteBody, Message};
13use postgres_protocol::message::frontend;
14use postgres_types::Type;
15use std::fmt;
16use std::pin::Pin;
17use std::sync::Arc;
18use std::task::{ready, Context, Poll};
19
20struct BorrowToSqlParamsDebug<'a, T>(&'a [T]);
21
22impl<T> fmt::Debug for BorrowToSqlParamsDebug<'_, T>
23where
24 T: BorrowToSql,
25{
26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
27 f.debug_list()
28 .entries(self.0.iter().map(|x| x.borrow_to_sql()))
29 .finish()
30 }
31}
32
33pub async fn query<P, I>(
34 client: &InnerClient,
35 statement: Statement,
36 params: I,
37) -> Result<RowStream, Error>
38where
39 P: BorrowToSql,
40 I: IntoIterator<Item = P>,
41 I::IntoIter: ExactSizeIterator,
42{
43 let buf = if log_enabled!(Level::Debug) {
44 let params = params.into_iter().collect::<Vec<_>>();
45 debug!(
46 "executing statement {} with parameters: {:?}",
47 statement.name(),
48 BorrowToSqlParamsDebug(params.as_slice()),
49 );
50 encode(client, &statement, params)?
51 } else {
52 encode(client, &statement, params)?
53 };
54 let responses = start(client, buf).await?;
55 Ok(RowStream {
56 statement,
57 responses,
58 rows_affected: None,
59 })
60}
61
62pub async fn query_typed<P, I>(
63 client: &Arc<InnerClient>,
64 query: &str,
65 params: I,
66) -> Result<RowStream, Error>
67where
68 P: BorrowToSql,
69 I: IntoIterator<Item = (P, Type)>,
70{
71 let buf = {
72 let params = params.into_iter().collect::<Vec<_>>();
73 let param_oids = params.iter().map(|(_, t)| t.oid()).collect::<Vec<_>>();
74
75 client.with_buf(|buf| {
76 frontend::parse("", query, param_oids.into_iter(), buf).map_err(Error::parse)?;
77 encode_bind_raw("", params, "", buf)?;
78 frontend::describe(b'S', "", buf).map_err(Error::encode)?;
79 frontend::execute("", 0, buf).map_err(Error::encode)?;
80 frontend::sync(buf);
81
82 Ok(buf.split().freeze())
83 })?
84 };
85
86 let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
87
88 loop {
89 match responses.next().await? {
90 Message::ParseComplete | Message::BindComplete | Message::ParameterDescription(_) => {}
91 Message::NoData => {
92 return Ok(RowStream {
93 statement: Statement::unnamed(vec![], vec![]),
94 responses,
95 rows_affected: None,
96 });
97 }
98 Message::RowDescription(row_description) => {
99 let mut columns: Vec<Column> = vec![];
100 let mut it = row_description.fields();
101 while let Some(field) = it.next().map_err(Error::parse)? {
102 let type_ = get_type(client, field.type_oid()).await?;
103 let column = Column {
104 name: field.name().to_string(),
105 table_oid: Some(field.table_oid()).filter(|n| *n != 0),
106 column_id: Some(field.column_id()).filter(|n| *n != 0),
107 type_modifier: field.type_modifier(),
108 r#type: type_,
109 };
110 columns.push(column);
111 }
112 return Ok(RowStream {
113 statement: Statement::unnamed(vec![], columns),
114 responses,
115 rows_affected: None,
116 });
117 }
118 _ => return Err(Error::unexpected_message()),
119 }
120 }
121}
122
123pub async fn query_portal(
124 client: &InnerClient,
125 portal: &Portal,
126 max_rows: i32,
127) -> Result<RowStream, Error> {
128 let buf = client.with_buf(|buf| {
129 frontend::execute(portal.name(), max_rows, buf).map_err(Error::encode)?;
130 frontend::sync(buf);
131 Ok(buf.split().freeze())
132 })?;
133
134 let responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
135
136 Ok(RowStream {
137 statement: portal.statement().clone(),
138 responses,
139 rows_affected: None,
140 })
141}
142
143pub fn extract_row_affected(body: &CommandCompleteBody) -> Result<u64, Error> {
145 let rows = body
146 .tag()
147 .map_err(Error::parse)?
148 .rsplit(' ')
149 .next()
150 .unwrap()
151 .parse()
152 .unwrap_or(0);
153 Ok(rows)
154}
155
156pub async fn execute<P, I>(
157 client: &InnerClient,
158 statement: Statement,
159 params: I,
160) -> Result<u64, Error>
161where
162 P: BorrowToSql,
163 I: IntoIterator<Item = P>,
164 I::IntoIter: ExactSizeIterator,
165{
166 let buf = if log_enabled!(Level::Debug) {
167 let params = params.into_iter().collect::<Vec<_>>();
168 debug!(
169 "executing statement {} with parameters: {:?}",
170 statement.name(),
171 BorrowToSqlParamsDebug(params.as_slice()),
172 );
173 encode(client, &statement, params)?
174 } else {
175 encode(client, &statement, params)?
176 };
177 let mut responses = start(client, buf).await?;
178
179 let mut rows = 0;
180 loop {
181 match responses.next().await? {
182 Message::DataRow(_) => {}
183 Message::CommandComplete(body) => {
184 rows = extract_row_affected(&body)?;
185 }
186 Message::EmptyQueryResponse => rows = 0,
187 Message::ReadyForQuery(_) => return Ok(rows),
188 _ => return Err(Error::unexpected_message()),
189 }
190 }
191}
192
193async fn start(client: &InnerClient, buf: Bytes) -> Result<Responses, Error> {
194 let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
195
196 match responses.next().await? {
197 Message::BindComplete => {}
198 _ => return Err(Error::unexpected_message()),
199 }
200
201 Ok(responses)
202}
203
204pub fn encode<P, I>(client: &InnerClient, statement: &Statement, params: I) -> Result<Bytes, Error>
205where
206 P: BorrowToSql,
207 I: IntoIterator<Item = P>,
208 I::IntoIter: ExactSizeIterator,
209{
210 client.with_buf(|buf| {
211 encode_bind(statement, params, "", buf)?;
212 frontend::execute("", 0, buf).map_err(Error::encode)?;
213 frontend::sync(buf);
214 Ok(buf.split().freeze())
215 })
216}
217
218pub fn encode_bind<P, I>(
219 statement: &Statement,
220 params: I,
221 portal: &str,
222 buf: &mut BytesMut,
223) -> Result<(), Error>
224where
225 P: BorrowToSql,
226 I: IntoIterator<Item = P>,
227 I::IntoIter: ExactSizeIterator,
228{
229 let params = params.into_iter();
230 if params.len() != statement.params().len() {
231 return Err(Error::parameters(params.len(), statement.params().len()));
232 }
233
234 encode_bind_raw(
235 statement.name(),
236 params.zip(statement.params().iter().cloned()),
237 portal,
238 buf,
239 )
240}
241
242fn encode_bind_raw<P, I>(
243 statement_name: &str,
244 params: I,
245 portal: &str,
246 buf: &mut BytesMut,
247) -> Result<(), Error>
248where
249 P: BorrowToSql,
250 I: IntoIterator<Item = (P, Type)>,
251 I::IntoIter: ExactSizeIterator,
252{
253 let (param_formats, params): (Vec<_>, Vec<_>) = params
254 .into_iter()
255 .map(|(p, ty)| (p.borrow_to_sql().encode_format(&ty) as i16, (p, ty)))
256 .unzip();
257
258 let mut error_idx = 0;
259 let r = frontend::bind(
260 portal,
261 statement_name,
262 param_formats,
263 params.into_iter().enumerate(),
264 |(idx, (param, ty)), buf| match param.borrow_to_sql().to_sql_checked(&ty, buf) {
265 Ok(IsNull::No) => Ok(postgres_protocol::IsNull::No),
266 Ok(IsNull::Yes) => Ok(postgres_protocol::IsNull::Yes),
267 Err(e) => {
268 error_idx = idx;
269 Err(e)
270 }
271 },
272 Some(1),
273 buf,
274 );
275 match r {
276 Ok(()) => Ok(()),
277 Err(frontend::BindError::Conversion(e)) => Err(Error::to_sql(e, error_idx)),
278 Err(frontend::BindError::Serialization(e)) => Err(Error::encode(e)),
279 }
280}
281
pin_project! {
    /// A stream of table rows.
    #[project(!Unpin)]
    pub struct RowStream {
        // Statement cloned into every `Row` the stream yields.
        statement: Statement,
        // Backend messages for the in-flight request, polled by `poll_next`.
        responses: Responses,
        // Filled in from a `CommandComplete` message; `None` until one has been seen.
        rows_affected: Option<u64>,
    }
}
291
292impl Stream for RowStream {
293 type Item = Result<Row, Error>;
294
295 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
296 let this = self.project();
297 loop {
298 match ready!(this.responses.poll_next(cx)?) {
299 Message::DataRow(body) => {
300 return Poll::Ready(Some(Ok(Row::new(this.statement.clone(), body)?)))
301 }
302 Message::CommandComplete(body) => {
303 *this.rows_affected = Some(extract_row_affected(&body)?);
304 }
305 Message::EmptyQueryResponse | Message::PortalSuspended => {}
306 Message::ReadyForQuery(_) => return Poll::Ready(None),
307 _ => return Poll::Ready(Some(Err(Error::unexpected_message()))),
308 }
309 }
310 }
311}
312
313impl RowStream {
314 pub fn rows_affected(&self) -> Option<u64> {
318 self.rows_affected
319 }
320}
321
322pub async fn sync(client: &InnerClient) -> Result<(), Error> {
323 let buf = Bytes::from_static(b"S\0\0\0\x04");
324 let mut responses = client.send(RequestMessages::Single(FrontendMessage::Raw(buf)))?;
325
326 match responses.next().await? {
327 Message::ReadyForQuery(_) => Ok(()),
328 _ => Err(Error::unexpected_message()),
329 }
330}