// sqlx-postgres — src/connection/executor.rs

1use crate::error::Error;
2use crate::executor::{Execute, Executor};
3use crate::io::{PortalId, StatementId};
4use crate::logger::QueryLogger;
5use crate::message::{
6    self, BackendMessageFormat, Bind, Close, CommandComplete, DataRow, ParameterDescription, Parse,
7    ParseComplete, Query, RowDescription,
8};
9use crate::statement::PgStatementMetadata;
10use crate::{
11    statement::PgStatement, PgArguments, PgConnection, PgQueryResult, PgRow, PgTypeInfo,
12    PgValueFormat, Postgres,
13};
14use futures_core::future::BoxFuture;
15use futures_core::stream::BoxStream;
16use futures_core::Stream;
17use futures_util::TryStreamExt;
18use sqlx_core::arguments::Arguments;
19use sqlx_core::sql_str::SqlStr;
20use sqlx_core::Either;
21use std::{pin::pin, sync::Arc};
22
23async fn prepare(
24    conn: &mut PgConnection,
25    sql: &str,
26    parameters: &[PgTypeInfo],
27    metadata: Option<Arc<PgStatementMetadata>>,
28    persistent: bool,
29    fetch_column_origin: bool,
30) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
31    let id = if persistent {
32        let id = conn.inner.next_statement_id;
33        conn.inner.next_statement_id = id.next();
34        id
35    } else {
36        StatementId::UNNAMED
37    };
38
39    // build a list of type OIDs to send to the database in the PARSE command
40    // we have not yet started the query sequence, so we are *safe* to cleanly make
41    // additional queries here to get any missing OIDs
42
43    let mut param_types = Vec::with_capacity(parameters.len());
44
45    for ty in parameters {
46        param_types.push(conn.resolve_type_id(&ty.0).await?);
47    }
48
49    // flush and wait until we are re-ready
50    conn.wait_until_ready().await?;
51
52    // next we send the PARSE command to the server
53    conn.inner.stream.write_msg(Parse {
54        param_types: &param_types,
55        query: sql,
56        statement: id,
57    })?;
58
59    if metadata.is_none() {
60        // get the statement columns and parameters
61        conn.inner
62            .stream
63            .write_msg(message::Describe::Statement(id))?;
64    }
65
66    // we ask for the server to immediately send us the result of the PARSE command
67    conn.write_sync();
68    conn.inner.stream.flush().await?;
69
70    // indicates that the SQL query string is now successfully parsed and has semantic validity
71    conn.inner.stream.recv_expect::<ParseComplete>().await?;
72
73    let metadata = if let Some(metadata) = metadata {
74        // each SYNC produces one READY FOR QUERY
75        conn.recv_ready_for_query().await?;
76
77        // we already have metadata
78        metadata
79    } else {
80        let parameters = recv_desc_params(conn).await?;
81
82        let rows = recv_desc_rows(conn).await?;
83
84        // each SYNC produces one READY FOR QUERY
85        conn.recv_ready_for_query().await?;
86
87        let parameters = conn.handle_parameter_description(parameters).await?;
88
89        let (columns, column_names) = conn
90            .handle_row_description(rows, true, fetch_column_origin)
91            .await?;
92
93        // ensure that if we did fetch custom data, we wait until we are fully ready before
94        // continuing
95        conn.wait_until_ready().await?;
96
97        Arc::new(PgStatementMetadata {
98            parameters,
99            columns,
100            column_names: Arc::new(column_names),
101        })
102    };
103
104    Ok((id, metadata))
105}
106
107async fn recv_desc_params(conn: &mut PgConnection) -> Result<ParameterDescription, Error> {
108    conn.inner.stream.recv_expect().await
109}
110
111async fn recv_desc_rows(conn: &mut PgConnection) -> Result<Option<RowDescription>, Error> {
112    let rows: Option<RowDescription> = match conn.inner.stream.recv().await? {
113        // describes the rows that will be returned when the statement is eventually executed
114        message if message.format == BackendMessageFormat::RowDescription => {
115            Some(message.decode()?)
116        }
117
118        // no data would be returned if this statement was executed
119        message if message.format == BackendMessageFormat::NoData => None,
120
121        message => {
122            return Err(err_protocol!(
123                "expecting RowDescription or NoData but received {:?}",
124                message.format
125            ));
126        }
127    };
128
129    Ok(rows)
130}
131
132impl PgConnection {
133    // wait for CloseComplete to indicate a statement was closed
134    pub(super) async fn wait_for_close_complete(&mut self, mut count: usize) -> Result<(), Error> {
135        // we need to wait for the [CloseComplete] to be returned from the server
136        while count > 0 {
137            match self.inner.stream.recv().await? {
138                message if message.format == BackendMessageFormat::PortalSuspended => {
139                    // there was an open portal
140                    // this can happen if the last time a statement was used it was not fully executed
141                }
142
143                message if message.format == BackendMessageFormat::CloseComplete => {
144                    // successfully closed the statement (and freed up the server resources)
145                    count -= 1;
146                }
147
148                message => {
149                    return Err(err_protocol!(
150                        "expecting PortalSuspended or CloseComplete but received {:?}",
151                        message.format
152                    ));
153                }
154            }
155        }
156
157        Ok(())
158    }
159
160    #[inline(always)]
161    pub(crate) fn write_sync(&mut self) {
162        self.inner
163            .stream
164            .write_msg(message::Sync)
165            .expect("BUG: Sync should not be too big for protocol");
166
167        // all SYNC messages will return a ReadyForQuery
168        self.inner.pending_ready_for_query_count += 1;
169    }
170
171    async fn get_or_prepare(
172        &mut self,
173        sql: &str,
174        parameters: &[PgTypeInfo],
175        persistent: bool,
176        // optional metadata that was provided by the user, this means they are reusing
177        // a statement object
178        metadata: Option<Arc<PgStatementMetadata>>,
179        fetch_column_origin: bool,
180    ) -> Result<(StatementId, Arc<PgStatementMetadata>), Error> {
181        if let Some(statement) = self.inner.cache_statement.get_mut(sql) {
182            return Ok((*statement).clone());
183        }
184
185        let statement = prepare(
186            self,
187            sql,
188            parameters,
189            metadata,
190            persistent,
191            fetch_column_origin,
192        )
193        .await?;
194
195        if persistent && self.inner.cache_statement.is_enabled() {
196            if let Some((id, _)) = self.inner.cache_statement.insert(sql, statement.clone()) {
197                self.inner.stream.write_msg(Close::Statement(id))?;
198                self.write_sync();
199
200                self.inner.stream.flush().await?;
201
202                self.wait_for_close_complete(1).await?;
203                self.recv_ready_for_query().await?;
204            }
205        }
206
207        Ok(statement)
208    }
209
210    pub(crate) async fn run<'e, 'c: 'e, 'q: 'e>(
211        &'c mut self,
212        query: SqlStr,
213        arguments: Option<PgArguments>,
214        persistent: bool,
215        metadata_opt: Option<Arc<PgStatementMetadata>>,
216    ) -> Result<impl Stream<Item = Result<Either<PgQueryResult, PgRow>, Error>> + 'e, Error> {
217        let mut logger = QueryLogger::new(query, self.inner.log_settings.clone());
218        let sql = logger.sql().as_str();
219
220        // before we continue, wait until we are "ready" to accept more queries
221        self.wait_until_ready().await?;
222
223        let mut metadata: Arc<PgStatementMetadata>;
224
225        let format = if let Some(mut arguments) = arguments {
226            // Check this before we write anything to the stream.
227            //
228            // Note: Postgres actually interprets this value as unsigned,
229            // making the max number of parameters 65535, not 32767
230            // https://github.com/launchbadge/sqlx/issues/3464
231            // https://www.postgresql.org/docs/current/limits.html
232            let num_params = u16::try_from(arguments.len()).map_err(|_| {
233                err_protocol!(
234                    "PgConnection::run(): too many arguments for query: {}",
235                    arguments.len()
236                )
237            })?;
238
239            // prepare the statement if this our first time executing it
240            // always return the statement ID here
241            let (statement, metadata_) = self
242                .get_or_prepare(sql, &arguments.types, persistent, metadata_opt, false)
243                .await?;
244
245            metadata = metadata_;
246
247            // patch holes created during encoding
248            arguments.apply_patches(self, &metadata.parameters).await?;
249
250            // consume messages till `ReadyForQuery` before bind and execute
251            self.wait_until_ready().await?;
252
253            // bind to attach the arguments to the statement and create a portal
254            self.inner.stream.write_msg(Bind {
255                portal: PortalId::UNNAMED,
256                statement,
257                formats: &[PgValueFormat::Binary],
258                num_params,
259                params: &arguments.buffer,
260                result_formats: &[PgValueFormat::Binary],
261            })?;
262
263            // executes the portal up to the passed limit
264            // the protocol-level limit acts nearly identically to the `LIMIT` in SQL
265            self.inner.stream.write_msg(message::Execute {
266                portal: PortalId::UNNAMED,
267                // Non-zero limits cause query plan pessimization by disabling parallel workers:
268                // https://github.com/launchbadge/sqlx/issues/3673
269                limit: 0,
270            })?;
271            // From https://www.postgresql.org/docs/current/protocol-flow.html:
272            //
273            // "An unnamed portal is destroyed at the end of the transaction, or as
274            // soon as the next Bind statement specifying the unnamed portal as
275            // destination is issued. (Note that a simple Query message also
276            // destroys the unnamed portal."
277
278            // we ask the database server to close the unnamed portal and free the associated resources
279            // earlier - after the execution of the current query.
280            self.inner
281                .stream
282                .write_msg(Close::Portal(PortalId::UNNAMED))?;
283
284            // finally, [Sync] asks postgres to process the messages that we sent and respond with
285            // a [ReadyForQuery] message when it's completely done. Theoretically, we could send
286            // dozens of queries before a [Sync] and postgres can handle that. Execution on the server
287            // is still serial but it would reduce round-trips. Some kind of builder pattern that is
288            // termed batching might suit this.
289            self.write_sync();
290
291            // prepared statements are binary
292            PgValueFormat::Binary
293        } else {
294            // Query will trigger a ReadyForQuery
295            self.inner.stream.write_msg(Query(sql))?;
296            self.inner.pending_ready_for_query_count += 1;
297
298            // metadata starts out as "nothing"
299            metadata = Arc::new(PgStatementMetadata::default());
300
301            // and unprepared statements are text
302            PgValueFormat::Text
303        };
304
305        self.inner.stream.flush().await?;
306
307        Ok(try_stream! {
308            loop {
309                let message = self.inner.stream.recv().await?;
310
311                match message.format {
312                    BackendMessageFormat::BindComplete
313                    | BackendMessageFormat::ParseComplete
314                    | BackendMessageFormat::ParameterDescription
315                    | BackendMessageFormat::NoData
316                    // unnamed portal has been closed
317                    | BackendMessageFormat::CloseComplete
318                    => {
319                        // harmless messages to ignore
320                    }
321
322                    // "Execute phase is always terminated by the appearance of
323                    // exactly one of these messages: CommandComplete,
324                    // EmptyQueryResponse (if the portal was created from an
325                    // empty query string), ErrorResponse, or PortalSuspended"
326                    BackendMessageFormat::CommandComplete => {
327                        // a SQL command completed normally
328                        let cc: CommandComplete = message.decode()?;
329
330                        let rows_affected = cc.rows_affected();
331                        logger.increase_rows_affected(rows_affected);
332                        r#yield!(Either::Left(PgQueryResult {
333                            rows_affected,
334                        }));
335                    }
336
337                    BackendMessageFormat::EmptyQueryResponse => {
338                        // empty query string passed to an unprepared execute
339                    }
340
341                    // Message::ErrorResponse is handled in self.stream.recv()
342
343                    // incomplete query execution has finished
344                    BackendMessageFormat::PortalSuspended => {}
345
346                    BackendMessageFormat::RowDescription => {
347                        // indicates that a *new* set of rows are about to be returned
348                        let (columns, column_names) = self
349                            .handle_row_description(Some(message.decode()?), false, false)
350                            .await?;
351
352                        metadata = Arc::new(PgStatementMetadata {
353                            column_names: Arc::new(column_names),
354                            columns,
355                            parameters: Vec::default(),
356                        });
357                    }
358
359                    BackendMessageFormat::DataRow => {
360                        logger.increment_rows_returned();
361
362                        // one of the set of rows returned by a SELECT, FETCH, etc query
363                        let data: DataRow = message.decode()?;
364                        let row = PgRow {
365                            data,
366                            format,
367                            metadata: Arc::clone(&metadata),
368                        };
369
370                        r#yield!(Either::Right(row));
371                    }
372
373                    BackendMessageFormat::ReadyForQuery => {
374                        // processing of the query string is complete
375                        self.handle_ready_for_query(message)?;
376                        break;
377                    }
378
379                    _ => {
380                        return Err(err_protocol!(
381                            "execute: unexpected message: {:?}",
382                            message.format
383                        ));
384                    }
385                }
386            }
387
388            Ok(())
389        })
390    }
391}
392
393impl<'c> Executor<'c> for &'c mut PgConnection {
394    type Database = Postgres;
395
396    fn fetch_many<'e, 'q, E>(
397        self,
398        mut query: E,
399    ) -> BoxStream<'e, Result<Either<PgQueryResult, PgRow>, Error>>
400    where
401        'c: 'e,
402        E: Execute<'q, Self::Database>,
403        'q: 'e,
404        E: 'q,
405    {
406        // False positive: https://github.com/rust-lang/rust-clippy/issues/12560
407        #[allow(clippy::map_clone)]
408        let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
409        let arguments = query.take_arguments().map_err(Error::Encode);
410        let persistent = query.persistent();
411        let sql = query.sql();
412
413        Box::pin(try_stream! {
414            let arguments = arguments?;
415            let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
416
417            while let Some(v) = s.try_next().await? {
418                r#yield!(v);
419            }
420
421            Ok(())
422        })
423    }
424
425    fn fetch_optional<'e, 'q, E>(self, mut query: E) -> BoxFuture<'e, Result<Option<PgRow>, Error>>
426    where
427        'c: 'e,
428        E: Execute<'q, Self::Database>,
429        'q: 'e,
430        E: 'q,
431    {
432        // False positive: https://github.com/rust-lang/rust-clippy/issues/12560
433        #[allow(clippy::map_clone)]
434        let metadata = query.statement().map(|s| Arc::clone(&s.metadata));
435        let arguments = query.take_arguments().map_err(Error::Encode);
436        let persistent = query.persistent();
437
438        Box::pin(async move {
439            let sql = query.sql();
440            let arguments = arguments?;
441            let mut s = pin!(self.run(sql, arguments, persistent, metadata).await?);
442
443            // With deferred constraints we need to check all responses as we
444            // could get a OK response (with uncommitted data), only to get an
445            // error response after (when the deferred constraint is actually
446            // checked).
447            let mut ret = None;
448            while let Some(result) = s.try_next().await? {
449                match result {
450                    Either::Right(r) if ret.is_none() => ret = Some(r),
451                    _ => {}
452                }
453            }
454            Ok(ret)
455        })
456    }
457
458    fn prepare_with<'e>(
459        self,
460        sql: SqlStr,
461        parameters: &'e [PgTypeInfo],
462    ) -> BoxFuture<'e, Result<PgStatement, Error>>
463    where
464        'c: 'e,
465    {
466        Box::pin(async move {
467            self.wait_until_ready().await?;
468
469            let (_, metadata) = self
470                .get_or_prepare(sql.as_str(), parameters, true, None, true)
471                .await?;
472
473            Ok(PgStatement { sql, metadata })
474        })
475    }
476
477    #[cfg(feature = "offline")]
478    fn describe<'e>(
479        self,
480        sql: SqlStr,
481    ) -> BoxFuture<'e, Result<crate::describe::Describe<Self::Database>, Error>>
482    where
483        'c: 'e,
484    {
485        Box::pin(async move {
486            self.wait_until_ready().await?;
487
488            let (stmt_id, metadata) = self
489                .get_or_prepare(sql.as_str(), &[], true, None, true)
490                .await?;
491
492            let nullable = self.get_nullable_for_columns(stmt_id, &metadata).await?;
493
494            Ok(crate::describe::Describe {
495                columns: metadata.columns.clone(),
496                nullable,
497                parameters: Some(Either::Left(metadata.parameters.clone())),
498            })
499        })
500    }
501}