// sqlx_postgres/connection/executor.rs
use crate::error::Error;
use crate::executor::{Execute, Executor};
use crate::io::{PortalId, StatementId};
use crate::logger::QueryLogger;
use crate::message::{
    self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse,
    ParseComplete, Query, RowDescription,
};
use crate::statement::PgStatementMetadata;
use crate::{
    statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
    PgValueFormat, Postgres,
};
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_core::Stream;
use futures_util::TryStreamExt;
use sqlx_core::arguments::Arguments;
use sqlx_core::sql_str::SqlStr;
use sqlx_core::Either;
use std::{pin::pin, sync::Arc};

23async fn prepare(
24 conn: &mut PgConnection,
25 sql: &str,
26 parameters: &[PgTypeInfo],
27 metadata: Option<Arc<PgStatementMetadata>>,
28 persistent: bool,
29 fetch_column_origin: bool,
30) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
31 let id = if persistent {
32 let id = conn.inner.next_statement_id;
33 conn.inner.next_statement_id = id.next();
34 id
35 } else {
36 StatementId::UNNAMED
37 };
38
39 let mut param_types = Vec::with_capacity(parameters.len());
44
45 for ty in parameters {
46 param_types.push(conn.resolve_type_id(&ty.0).await?);
47 }
48
49 conn.wait_until_ready().await?;
51
52 conn.inner.stream.write_msg(Parse {
54 param_types: ¶m_types,
55 query: sql,
56 statement: id,
57 })?;
58
59 if metadata.is_none() {
60 conn.inner
62 .stream
63 .write_msg(message::Describe::Statement(id))?;
64 }
65
66 conn.write_sync();
68 conn.inner.stream.flush().await?;
69
70 conn.inner.stream.recv_expect::<ParseComplete>().await?;
72
73 let metadata = if let Some(metadata) = metadata {
74 conn.recv_ready_for_query().await?;
76
77 metadata
79 } else {
80 let parameters = recv_desc_params(conn).await?;
81
82 let rows = recv_desc_rows(conn).await?;
83
84 conn.recv_ready_for_query().await?;
86
87 let parameters = conn.handle_parameter_description(parameters).await?;
88
89 let (columns, column_names) = conn
90 .handle_row_description(rows, true, fetch_column_origin)
91 .await?;
92
93 conn.wait_until_ready().await?;
96
97 Arc::new(PgStatementMetadata {
98 parameters,
99 columns,
100 column_names: Arc::new(column_names),
101 })
102 };
103
104 Ok((id, metadata))
105}
106
107async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
108 conn.inner.stream.recv_expect().await
109}
110
111async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
112 let rows: Option<RowDescription> = match conn.inner.stream.recv().await? {
113 message if message.format == BackendMessageFormat::RowDescription => {
115 Some(message.decode()?)
116 }
117
118 message if message.format == BackendMessageFormat::NoData => None,
120
121 message => {
122 return Err(err_protocol!(
123 "expecting RowDescription or NoData but received {:?}",
124 message.format
125 ));
126 }
127 };
128
129 Ok(rows)
130}
131
132impl PgConnection {
133 pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
135 while count > 0 {
137 match self.inner.stream.recv().await? {
138 message if message.format == BackendMessageFormat::PortalSuspended => {
139 }
142
143 message if message.format == BackendMessageFormat::CloseComplete => {
144 count -= 1;
146 }
147
148 message => {
149 return Err(err_protocol!(
150 "expecting PortalSuspended or CloseComplete but received {:?}",
151 message.format
152 ));
153 }
154 }
155 }
156
157 Ok(())
158 }
159
160 #[inline(always)]
161 pub(crate) fn write_sync(&mut self) {
162 self.inner
163 .stream
164 .write_msg(message::Sync)
165 .expect("BUG: Sync should not be too big for protocol");
166
167 self.inner.pending_ready_for_query_count += 1;
169 }
170
171 async fn get_or_prepare(
172 &mut self,
173 sql: &str,
174 parameters: &[PgTypeInfo],
175 persistent: bool,
176 metadata: Option<Arc<PgStatementMetadata>>,
179 fetch_column_origin: bool,
180 ) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
181 if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
182 return Ok((*statement).clone());
183 }
184
185 let statement = prepare(
186 self,
187 sql,
188 parameters,
189 metadata,
190 persistent,
191 fetch_column_origin,
192 )
193 .await?;
194
195 if persistent && self.inner.cache_statement.is_enabled() {
196 if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) {
197 self.inner.stream.write_msg(Close::Statement(id))?;
198 self.write_sync();
199
200 self.inner.stream.flush().await?;
201
202 self.wait_for_close_complete(1).await?;
203 self.recv_ready_for_query().await?;
204 }
205 }
206
207 Ok(statement)
208 }
209
210 pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
211 &'c mut self,
212 query: SqlStr,
213 arguments: Option<PgArguments>,
214 persistent: bool,
215 metadata_opt: Option<Arc<PgStatementMetadata>>,
216 ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
217 let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());
218 let sql = logger.sql().as_str();
219
220 self.wait_until_ready().await?;
222
223 let mut metadata: Arc<PgStatementMetadata>;
224
225 let format = if let Some(mut arguments) = arguments {
226 let num_params = u16::try_from(arguments.len()).map_err(|_| {
233 err_protocol!(
234 "PgConnection::run(): too many arguments for query: {}",
235 arguments.len()
236 )
237 })?;
238
239 let (statement, metadata_) = self
242 .get_or_prepare(sql, &arguments.types, persistent, metadata_opt, false)
243 .await?;
244
245 metadata = metadata_;
246
247 arguments.apply_patches(self, &metadata.parameters).await?;
249
250 self.wait_until_ready().await?;
252
253 self.inner.stream.write_msg(Bind {
255 portal: PortalId::UNNAMED,
256 statement,
257 formats: &[PgValueFormat::Binary],
258 num_params,
259 params: &arguments.buffer,
260 result_formats: &[PgValueFormat::Binary],
261 })?;
262
263 self.inner.stream.write_msg(message::Execute {
266 portal: PortalId::UNNAMED,
267 limit: 0,
270 })?;
271 self.inner
281 .stream
282 .write_msg(Close::Portal(PortalId::UNNAMED))?;
283
284 self.write_sync();
290
291 PgValueFormat::Binary
293 } else {
294 self.inner.stream.write_msg(Query(sql))?;
296 self.inner.pending_ready_for_query_count += 1;
297
298 metadata = Arc::new(PgStatementMetadata::default());
300
301 PgValueFormat::Text
303 };
304
305 self.inner.stream.flush().await?;
306
307 Ok(try_stream! {
308 loop {
309 let message = self.inner.stream.recv().await?;
310
311 match message.format {
312 BackendMessageFormat::BindComplete
313 | BackendMessageFormat::ParseComplete
314 | BackendMessageFormat::ParameterDescription
315 | BackendMessageFormat::NoData
316 | BackendMessageFormat::CloseComplete
318 => {
319 }
321
322 BackendMessageFormat::CommandComplete => {
327 let cc: CommandComplete = message.decode()?;
329
330 let rows_affected = cc.rows_affected();
331 logger.increase_rows_affected(rows_affected);
332 r#yield!(Either::Left(PgQueryResult {
333 rows_affected,
334 }));
335 }
336
337 BackendMessageFormat::EmptyQueryResponse => {
338 }
340
341 BackendMessageFormat::PortalSuspended => {}
345
346 BackendMessageFormat::RowDescription => {
347 let (columns, column_names) = self
349 .handle_row_description(Some(message.decode()?), false, false)
350 .await?;
351
352 metadata = Arc::new(PgStatementMetadata {
353 column_names: Arc::new(column_names),
354 columns,
355 parameters: Vec::default(),
356 });
357 }
358
359 BackendMessageFormat::DataRow => {
360 logger.increment_rows_returned();
361
362 let data: DataRow = message.decode()?;
364 let row = PgRow {
365 data,
366 format,
367 metadata: Arc::clone(&metadata),
368 };
369
370 r#yield!(Either::Right(row));
371 }
372
373 BackendMessageFormat::ReadyForQuery => {
374 self.handle_ready_for_query(message)?;
376 break;
377 }
378
379 _ => {
380 return Err(err_protocol!(
381 "execute: unexpected message: {:?}",
382 message.format
383 ));
384 }
385 }
386 }
387
388 Ok(())
389 })
390 }
391}
392
393impl<'c> Executor<'c> for &'c mut PgConnection {
394 type Database = Postgres;
395
396 fn fetch_many<'e, 'q, E>(
397 self,
398 mut query: E,
399 ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
400 where
401 'c: 'e,
402 E: Execute<'q, Self::Database>,
403 'q: 'e,
404 E: 'q,
405 {
406 #[allow(clippy::map_clone)]
408 let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
409 let arguments = query.take_arguments().map_err(Error::Encode);
410 let persistent = query.persistent();
411 let sql = query.sql();
412
413 Box::pin(try_stream! {
414 let arguments = arguments?;
415 let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
416
417 while let Some(v) = s.try_next().await? {
418 r#yield!(v);
419 }
420
421 Ok(())
422 })
423 }
424
425 fn fetch_optional<'e, 'q, E>(self, mut query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
426 where
427 'c: 'e,
428 E: Execute<'q, Self::Database>,
429 'q: 'e,
430 E: 'q,
431 {
432 #[allow(clippy::map_clone)]
434 let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
435 let arguments = query.take_arguments().map_err(Error::Encode);
436 let persistent = query.persistent();
437
438 Box::pin(async move {
439 let sql = query.sql();
440 let arguments = arguments?;
441 let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
442
443 let mut ret = None;
448 while let Some(result) = s.try_next().await? {
449 match result {
450 Either::Right(r) if ret.is_none() => ret = Some(r),
451 _ => {}
452 }
453 }
454 Ok(ret)
455 })
456 }
457
458 fn prepare_with<'e>(
459 self,
460 sql: SqlStr,
461 parameters: &'e [PgTypeInfo],
462 ) -> BoxFuture<'e, Result<PgStatement, Error>>
463 where
464 'c: 'e,
465 {
466 Box::pin(async move {
467 self.wait_until_ready().await?;
468
469 let (_, metadata) = self
470 .get_or_prepare(sql.as_str(), parameters, true, None, true)
471 .await?;
472
473 Ok(PgStatement { sql, metadata })
474 })
475 }
476
477 #[cfg(feature = "offline")]
478 fn describe<'e>(
479 self,
480 sql: SqlStr,
481 ) -> BoxFuture<'e, Result<crate::describe::Describe<Self::Database>, Error>>
482 where
483 'c: 'e,
484 {
485 Box::pin(async move {
486 self.wait_until_ready().await?;
487
488 let (stmt_id, metadata) = self
489 .get_or_prepare(sql.as_str(), &[], true, None, true)
490 .await?;
491
492 let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;
493
494 Ok(crate::describe::Describe {
495 columns: metadata.columns.clone(),
496 nullable,
497 parameters: Some(Either::Left(metadata.parameters.clone())),
498 })
499 })
500 }
501}