diesel_async/pg/
mod.rs

1//! Provides types and functions related to working with PostgreSQL
2//!
3//! Much of this module is re-exported from database agnostic locations.
4//! However, if you are writing code specifically to extend Diesel on
5//! PostgreSQL, you may need to work with this module directly.
6
7use self::error_helper::ErrorHelper;
8use self::row::PgRow;
9use self::serialize::ToSqlHelper;
10use crate::statement_cache::CacheSize;
11use crate::statement_cache::{PrepareForCache, QueryFragmentForCachedStatement, StatementCache};
12use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper};
13use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection};
14use diesel::connection::Instrumentation;
15use diesel::connection::InstrumentationEvent;
16use diesel::connection::StrQueryHelper;
17use diesel::pg::{
18    Pg, PgMetadataCache, PgMetadataCacheKey, PgMetadataLookup, PgQueryBuilder, PgTypeMetadata,
19};
20use diesel::query_builder::bind_collector::RawBytesBindCollector;
21use diesel::query_builder::{AsQuery, QueryBuilder, QueryFragment, QueryId};
22use diesel::result::{DatabaseErrorKind, Error};
23use diesel::{ConnectionError, ConnectionResult, QueryResult};
24use futures_core::future::BoxFuture;
25use futures_core::stream::BoxStream;
26use futures_util::future::Either;
27use futures_util::stream::TryStreamExt;
28use futures_util::TryFutureExt;
29use futures_util::{FutureExt, StreamExt};
30use std::collections::{HashMap, HashSet};
31use std::future::Future;
32use std::sync::Arc;
33use tokio::sync::broadcast;
34use tokio::sync::oneshot;
35use tokio::sync::Mutex;
36use tokio_postgres::types::ToSql;
37use tokio_postgres::types::Type;
38use tokio_postgres::Statement;
39
40pub use self::transaction_builder::TransactionBuilder;
41
42mod error_helper;
43mod row;
44mod serialize;
45mod transaction_builder;
46
/// Placeholder OID handed out for every custom type during the first bind-collection
/// pass in `construct_bind_data`. The byte positions where it ends up inside serialized
/// bind values are found by comparing that pass against a second pass that hands out
/// distinct generated OIDs instead.
const FAKE_OID: u32 = 0;
48
/// A connection to a PostgreSQL database.
///
/// Connection URLs should be in the form
/// `postgres://[user[:password]@]host/database_name`
///
/// Check out the documentation of the [tokio_postgres]
/// crate for details about the format
///
/// [tokio_postgres]: https://docs.rs/tokio-postgres/0.7.6/tokio_postgres/config/struct.Config.html#url
///
/// ## Pipelining
///
/// This connection supports *pipelined* requests. Pipelining can improve performance in use cases in which multiple,
/// independent queries need to be executed. In a traditional workflow, each query is sent to the server after the
/// previous query completes. In contrast, pipelining allows the client to send all of the queries to the server up
/// front, minimizing time spent by one side waiting for the other to finish sending data:
///
/// ```not_rust
///             Sequential                              Pipelined
/// | Client         | Server          |    | Client         | Server          |
/// |----------------|-----------------|    |----------------|-----------------|
/// | send query 1   |                 |    | send query 1   |                 |
/// |                | process query 1 |    | send query 2   | process query 1 |
/// | receive rows 1 |                 |    | send query 3   | process query 2 |
/// | send query 2   |                 |    | receive rows 1 | process query 3 |
/// |                | process query 2 |    | receive rows 2 |                 |
/// | receive rows 2 |                 |    | receive rows 3 |                 |
/// | send query 3   |                 |
/// |                | process query 3 |
/// | receive rows 3 |                 |
/// ```
///
/// In both cases, the PostgreSQL server is executing the queries **sequentially** - pipelining just allows both sides of
/// the connection to work concurrently when possible.
///
/// Pipelining happens automatically when futures are polled concurrently (for example, by using the futures `join`
/// combinator):
///
/// ```rust
/// # include!("../doctest_setup.rs");
/// use diesel_async::RunQueryDsl;
///
/// #
/// # #[tokio::main(flavor = "current_thread")]
/// # async fn main() {
/// #     run_test().await.unwrap();
/// # }
/// #
/// # async fn run_test() -> QueryResult<()> {
/// #     use diesel::sql_types::{Text, Integer};
/// #     let conn = &mut establish_connection().await;
///       let q1 = diesel::select(1_i32.into_sql::<Integer>());
///       let q2 = diesel::select(2_i32.into_sql::<Integer>());
///
///       // construct multiple futures for different queries
///       let f1 = q1.get_result::<i32>(conn);
///       let f2 = q2.get_result::<i32>(conn);
///
///       // wait on both results
///       let res = futures_util::try_join!(f1, f2)?;
///
///       assert_eq!(res.0, 1);
///       assert_eq!(res.1, 2);
///       # Ok(())
/// # }
/// ```
///
/// ## TLS
///
/// Connections created by [`AsyncPgConnection::establish`] do not support TLS.
///
/// TLS support for tokio_postgres connections is implemented by external crates, e.g. [tokio_postgres_rustls].
///
/// [`AsyncPgConnection::try_from_client_and_connection`] can be used to construct a connection from an existing
/// [`tokio_postgres::Connection`] with TLS enabled.
///
/// [tokio_postgres_rustls]: https://docs.rs/tokio-postgres-rustls/0.12.0/tokio_postgres_rustls/
pub struct AsyncPgConnection {
    // Shared client handle; a clone of this `Arc` is moved into every query future.
    conn: Arc<tokio_postgres::Client>,
    // Prepared statement cache, shared (via `Arc`) with in-flight query futures.
    stmt_cache: Arc<Mutex<StatementCache<diesel::pg::Pg, Statement>>>,
    // Transaction bookkeeping, shared with in-flight query futures so they can flag
    // serialization failures.
    transaction_state: Arc<Mutex<AnsiTransactionManager>>,
    // Cache of resolved custom type OIDs, shared with in-flight query futures.
    metadata_cache: Arc<Mutex<PgMetadataCache>>,
    // Error receiver returned by `drive_connection`; `None` when built via `try_from`.
    // Each query resubscribes to it.
    connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
    // Shutdown sender returned by `drive_connection`; fired from `Drop`.
    // `None` when built via `try_from`.
    shutdown_channel: Option<oneshot::Sender<()>>,
    // a sync mutex is fine here as we only hold it for a really short time
    instrumentation: Arc<std::sync::Mutex<Option<Box<dyn Instrumentation>>>>,
}
136
137impl SimpleAsyncConnection for AsyncPgConnection {
138    async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
139        self.record_instrumentation(InstrumentationEvent::start_query(&StrQueryHelper::new(
140            query,
141        )));
142        let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
143        let batch_execute = self
144            .conn
145            .batch_execute(query)
146            .map_err(ErrorHelper)
147            .map_err(Into::into);
148
149        let r = drive_future(connection_future, batch_execute).await;
150        let r = {
151            let mut transaction_manager = self.transaction_state.lock().await;
152            update_transaction_manager_status(r, &mut transaction_manager)
153        };
154        self.record_instrumentation(InstrumentationEvent::finish_query(
155            &StrQueryHelper::new(query),
156            r.as_ref().err(),
157        ));
158        r
159    }
160}
161
162impl AsyncConnection for AsyncPgConnection {
163    type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult<Self::Stream<'conn, 'query>>>;
164    type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult<usize>>;
165    type Stream<'conn, 'query> = BoxStream<'static, QueryResult<PgRow>>;
166    type Row<'conn, 'query> = PgRow;
167    type Backend = diesel::pg::Pg;
168    type TransactionManager = AnsiTransactionManager;
169
170    async fn establish(database_url: &str) -> ConnectionResult<Self> {
171        let mut instrumentation = diesel::connection::get_default_instrumentation();
172        instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
173            database_url,
174        ));
175        let instrumentation = Arc::new(std::sync::Mutex::new(instrumentation));
176        let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls)
177            .await
178            .map_err(ErrorHelper)?;
179
180        let (error_rx, shutdown_tx) = drive_connection(connection);
181
182        let r = Self::setup(
183            client,
184            Some(error_rx),
185            Some(shutdown_tx),
186            Arc::clone(&instrumentation),
187        )
188        .await;
189
190        instrumentation
191            .lock()
192            .unwrap_or_else(|e| e.into_inner())
193            .on_connection_event(InstrumentationEvent::finish_establish_connection(
194                database_url,
195                r.as_ref().err(),
196            ));
197        r
198    }
199
200    fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
201    where
202        T: AsQuery + 'query,
203        T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
204    {
205        let query = source.as_query();
206        let load_future = self.with_prepared_statement(query, load_prepared);
207
208        self.run_with_connection_future(load_future)
209    }
210
211    fn execute_returning_count<'conn, 'query, T>(
212        &'conn mut self,
213        source: T,
214    ) -> Self::ExecuteFuture<'conn, 'query>
215    where
216        T: QueryFragment<Self::Backend> + QueryId + 'query,
217    {
218        let execute = self.with_prepared_statement(source, execute_prepared);
219        self.run_with_connection_future(execute)
220    }
221
222    fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
223        // there should be no other pending future when this is called
224        // that means there is only one instance of this arc and
225        // we can simply access the inner data
226        if let Some(tm) = Arc::get_mut(&mut self.transaction_state) {
227            tm.get_mut()
228        } else {
229            panic!("Cannot access shared transaction state")
230        }
231    }
232
233    fn instrumentation(&mut self) -> &mut dyn Instrumentation {
234        // there should be no other pending future when this is called
235        // that means there is only one instance of this arc and
236        // we can simply access the inner data
237        if let Some(instrumentation) = Arc::get_mut(&mut self.instrumentation) {
238            &mut *(instrumentation.get_mut().unwrap_or_else(|p| p.into_inner()))
239        } else {
240            panic!("Cannot access shared instrumentation")
241        }
242    }
243
244    fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
245        self.instrumentation = Arc::new(std::sync::Mutex::new(Some(Box::new(instrumentation))));
246    }
247
248    fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
249        // there should be no other pending future when this is called
250        // that means there is only one instance of this arc and
251        // we can simply access the inner data
252        if let Some(cache) = Arc::get_mut(&mut self.stmt_cache) {
253            cache.get_mut().set_cache_size(size)
254        } else {
255            panic!("Cannot access shared statement cache")
256        }
257    }
258}
259
260impl Drop for AsyncPgConnection {
261    fn drop(&mut self) {
262        if let Some(tx) = self.shutdown_channel.take() {
263            let _ = tx.send(());
264        }
265    }
266}
267
268async fn load_prepared(
269    conn: Arc<tokio_postgres::Client>,
270    stmt: Statement,
271    binds: Vec<ToSqlHelper>,
272) -> QueryResult<BoxStream<'static, QueryResult<PgRow>>> {
273    let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?;
274
275    Ok(res
276        .map_err(|e| diesel::result::Error::from(ErrorHelper(e)))
277        .map_ok(PgRow::new)
278        .boxed())
279}
280
281async fn execute_prepared(
282    conn: Arc<tokio_postgres::Client>,
283    stmt: Statement,
284    binds: Vec<ToSqlHelper>,
285) -> QueryResult<usize> {
286    let binds = binds
287        .iter()
288        .map(|b| b as &(dyn ToSql + Sync))
289        .collect::<Vec<_>>();
290
291    let res = tokio_postgres::Client::execute(&conn, &stmt, &binds as &[_])
292        .await
293        .map_err(ErrorHelper)?;
294    res.try_into()
295        .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))
296}
297
298#[inline(always)]
299fn update_transaction_manager_status<T>(
300    query_result: QueryResult<T>,
301    transaction_manager: &mut AnsiTransactionManager,
302) -> QueryResult<T> {
303    if let Err(diesel::result::Error::DatabaseError(DatabaseErrorKind::SerializationFailure, _)) =
304        query_result
305    {
306        transaction_manager
307            .status
308            .set_requires_rollback_maybe_up_to_top_level(true)
309    }
310    query_result
311}
312
313fn prepare_statement_helper(
314    conn: Arc<tokio_postgres::Client>,
315    sql: &str,
316    _is_for_cache: PrepareForCache,
317    metadata: &[PgTypeMetadata],
318) -> CallbackHelper<
319    impl Future<Output = QueryResult<(Statement, Arc<tokio_postgres::Client>)>> + Send,
320> {
321    let bind_types = metadata
322        .iter()
323        .map(type_from_oid)
324        .collect::<QueryResult<Vec<_>>>();
325    // ideally we wouldn't clone the SQL string here
326    // but as we usually cache statements anyway
327    // this is a fixed one time const
328    //
329    // The probleme with not cloning it is that we then cannot express
330    // the right result lifetime anymore (at least not easily)
331    let sql = sql.to_string();
332    CallbackHelper(async move {
333        let bind_types = bind_types?;
334        let stmt = conn
335            .prepare_typed(&sql, &bind_types)
336            .await
337            .map_err(ErrorHelper);
338        Ok((stmt?, conn))
339    })
340}
341
342fn type_from_oid(t: &PgTypeMetadata) -> QueryResult<Type> {
343    let oid = t
344        .oid()
345        .map_err(|e| diesel::result::Error::SerializationError(Box::new(e) as _))?;
346
347    if let Some(tpe) = Type::from_oid(oid) {
348        return Ok(tpe);
349    }
350
351    Ok(Type::new(
352        format!("diesel_custom_type_{oid}"),
353        oid,
354        tokio_postgres::types::Kind::Simple,
355        "public".into(),
356    ))
357}
358
359impl AsyncPgConnection {
360    /// Build a transaction, specifying additional details such as isolation level
361    ///
362    /// See [`TransactionBuilder`] for more examples.
363    ///
364    /// [`TransactionBuilder`]: crate::pg::TransactionBuilder
365    ///
366    /// ```rust
367    /// # include!("../doctest_setup.rs");
368    /// # use scoped_futures::ScopedFutureExt;
369    /// #
370    /// # #[tokio::main(flavor = "current_thread")]
371    /// # async fn main() {
372    /// #     run_test().await.unwrap();
373    /// # }
374    /// #
375    /// # async fn run_test() -> QueryResult<()> {
376    /// #     use schema::users::dsl::*;
377    /// #     let conn = &mut connection_no_transaction().await;
378    /// conn.build_transaction()
379    ///     .read_only()
380    ///     .serializable()
381    ///     .deferrable()
382    ///     .run(|conn| async move { Ok(()) }.scope_boxed())
383    ///     .await
384    /// # }
385    /// ```
386    pub fn build_transaction(&mut self) -> TransactionBuilder<'_, Self> {
387        TransactionBuilder::new(self)
388    }
389
390    /// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
391    pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult<Self> {
392        Self::setup(
393            conn,
394            None,
395            None,
396            Arc::new(std::sync::Mutex::new(
397                diesel::connection::get_default_instrumentation(),
398            )),
399        )
400        .await
401    }
402
403    /// Constructs a new `AsyncPgConnection` from an existing [`tokio_postgres::Client`] and
404    /// [`tokio_postgres::Connection`]
405    pub async fn try_from_client_and_connection<S>(
406        client: tokio_postgres::Client,
407        conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
408    ) -> ConnectionResult<Self>
409    where
410        S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
411    {
412        let (error_rx, shutdown_tx) = drive_connection(conn);
413
414        Self::setup(
415            client,
416            Some(error_rx),
417            Some(shutdown_tx),
418            Arc::new(std::sync::Mutex::new(
419                diesel::connection::get_default_instrumentation(),
420            )),
421        )
422        .await
423    }
424
425    async fn setup(
426        conn: tokio_postgres::Client,
427        connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
428        shutdown_channel: Option<oneshot::Sender<()>>,
429        instrumentation: Arc<std::sync::Mutex<Option<Box<dyn Instrumentation>>>>,
430    ) -> ConnectionResult<Self> {
431        let mut conn = Self {
432            conn: Arc::new(conn),
433            stmt_cache: Arc::new(Mutex::new(StatementCache::new())),
434            transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
435            metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
436            connection_future,
437            shutdown_channel,
438            instrumentation,
439        };
440        conn.set_config_options()
441            .await
442            .map_err(ConnectionError::CouldntSetupConfiguration)?;
443        Ok(conn)
444    }
445
446    /// Constructs a cancellation token that can later be used to request cancellation of a query running on the connection associated with this client.
447    pub fn cancel_token(&self) -> tokio_postgres::CancelToken {
448        self.conn.cancel_token()
449    }
450
451    async fn set_config_options(&mut self) -> QueryResult<()> {
452        use crate::run_query_dsl::RunQueryDsl;
453
454        futures_util::future::try_join(
455            diesel::sql_query("SET TIME ZONE 'UTC'").execute(self),
456            diesel::sql_query("SET CLIENT_ENCODING TO 'UTF8'").execute(self),
457        )
458        .await?;
459        Ok(())
460    }
461
462    fn run_with_connection_future<'a, R: 'a>(
463        &self,
464        future: impl Future<Output = QueryResult<R>> + Send + 'a,
465    ) -> BoxFuture<'a, QueryResult<R>> {
466        let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
467        drive_future(connection_future, future).boxed()
468    }
469
470    fn with_prepared_statement<'a, T, F, R>(
471        &mut self,
472        query: T,
473        callback: fn(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F,
474    ) -> BoxFuture<'a, QueryResult<R>>
475    where
476        T: QueryFragment<diesel::pg::Pg> + QueryId,
477        F: Future<Output = QueryResult<R>> + Send + 'a,
478        R: Send,
479    {
480        self.record_instrumentation(InstrumentationEvent::start_query(&diesel::debug_query(
481            &query,
482        )));
483        // we explicilty descruct the query here before going into the async block
484        //
485        // That's required to remove the send bound from `T` as we have translated
486        // the query type to just a string (for the SQL) and a bunch of bytes (for the binds)
487        // which both are `Send`.
488        // We also collect the query id (essentially an integer) and the safe_to_cache flag here
489        // so there is no need to even access the query in the async block below
490        let mut query_builder = PgQueryBuilder::default();
491
492        let bind_data = construct_bind_data(&query);
493
494        // The code that doesn't need the `T` generic parameter is in a separate function to reduce LLVM IR lines
495        self.with_prepared_statement_after_sql_built(
496            callback,
497            query.is_safe_to_cache_prepared(&Pg),
498            T::query_id(),
499            query.to_sql(&mut query_builder, &Pg),
500            query_builder,
501            bind_data,
502        )
503    }
504
505    fn with_prepared_statement_after_sql_built<'a, F, R>(
506        &mut self,
507        callback: fn(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F,
508        is_safe_to_cache_prepared: QueryResult<bool>,
509        query_id: Option<std::any::TypeId>,
510        to_sql_result: QueryResult<()>,
511        query_builder: PgQueryBuilder,
512        bind_data: BindData,
513    ) -> BoxFuture<'a, QueryResult<R>>
514    where
515        F: Future<Output = QueryResult<R>> + Send + 'a,
516        R: Send,
517    {
518        let raw_connection = self.conn.clone();
519        let stmt_cache = self.stmt_cache.clone();
520        let metadata_cache = self.metadata_cache.clone();
521        let tm = self.transaction_state.clone();
522        let instrumentation = self.instrumentation.clone();
523        let BindData {
524            collect_bind_result,
525            fake_oid_locations,
526            generated_oids,
527            mut bind_collector,
528        } = bind_data;
529
530        async move {
531            let sql = to_sql_result.map(|_| query_builder.finish())?;
532            let res = async {
533            let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
534            collect_bind_result?;
535            // Check whether we need to resolve some types at all
536            //
537            // If the user doesn't use custom types there is no need
538            // to borther with that at all
539            if let Some(ref unresolved_types) = generated_oids {
540                let metadata_cache = &mut *metadata_cache.lock().await;
541                let mut real_oids = HashMap::new();
542
543                for ((schema, lookup_type_name), (fake_oid, fake_array_oid)) in
544                    unresolved_types
545                {
546                    // for each unresolved item
547                    // we check whether it's arleady in the cache
548                    // or perform a lookup and insert it into the cache
549                    let cache_key = PgMetadataCacheKey::new(
550                        schema.as_deref().map(Into::into),
551                        lookup_type_name.into(),
552                    );
553                    let real_metadata = if let Some(type_metadata) =
554                        metadata_cache.lookup_type(&cache_key)
555                    {
556                        type_metadata
557                    } else {
558                        let type_metadata =
559                            lookup_type(schema.clone(), lookup_type_name.clone(), &raw_connection)
560                                .await?;
561                        metadata_cache.store_type(cache_key, type_metadata);
562
563                        PgTypeMetadata::from_result(Ok(type_metadata))
564                    };
565                    // let (fake_oid, fake_array_oid) = metadata_lookup.fake_oids(index);
566                    let (real_oid, real_array_oid) = unwrap_oids(&real_metadata);
567                    real_oids.extend([(*fake_oid, real_oid), (*fake_array_oid, real_array_oid)]);
568                }
569
570                // Replace fake OIDs with real OIDs in `bind_collector.metadata`
571                for m in &mut bind_collector.metadata {
572                    let (oid, array_oid) = unwrap_oids(m);
573                    *m = PgTypeMetadata::new(
574                        real_oids.get(&oid).copied().unwrap_or(oid),
575                        real_oids.get(&array_oid).copied().unwrap_or(array_oid)
576                    );
577                }
578                // Replace fake OIDs with real OIDs in `bind_collector.binds`
579                for (bind_index, byte_index) in fake_oid_locations {
580                    replace_fake_oid(&mut bind_collector.binds, &real_oids, bind_index, byte_index)
581                        .ok_or_else(|| {
582                            Error::SerializationError(
583                                format!("diesel_async failed to replace a type OID serialized in bind value {bind_index}").into(),
584                            )
585                        })?;
586                }
587            }
588            let stmt = {
589                let mut stmt_cache = stmt_cache.lock().await;
590                let helper = QueryFragmentHelper {
591                    sql: sql.clone(),
592                    safe_to_cache: is_safe_to_cache_prepared,
593                };
594                let instrumentation = Arc::clone(&instrumentation);
595                stmt_cache
596                    .cached_statement_non_generic(
597                        query_id,
598                        &helper,
599                        &Pg,
600                        &bind_collector.metadata,
601                        raw_connection.clone(),
602                        prepare_statement_helper,
603                        &mut move |event: InstrumentationEvent<'_>| {
604                            // we wrap this lock into another callback to prevent locking
605                            // the instrumentation longer than necessary
606                            instrumentation.lock().unwrap_or_else(|e| e.into_inner())
607                                .on_connection_event(event);
608                        },
609                    )
610                    .await?
611                    .0
612                    .clone()
613            };
614
615            let binds = bind_collector
616                .metadata
617                .into_iter()
618                .zip(bind_collector.binds)
619                .map(|(meta, bind)| ToSqlHelper(meta, bind))
620                .collect::<Vec<_>>();
621                callback(raw_connection, stmt.clone(), binds).await
622            };
623            let res = res.await;
624            let mut tm = tm.lock().await;
625            let r = update_transaction_manager_status(res, &mut tm);
626            instrumentation
627                .lock()
628                .unwrap_or_else(|p| p.into_inner())
629                .on_connection_event(InstrumentationEvent::finish_query(
630                    &StrQueryHelper::new(&sql),
631                    r.as_ref().err(),
632                ));
633
634            r
635        }
636        .boxed()
637    }
638
639    fn record_instrumentation(&self, event: InstrumentationEvent<'_>) {
640        self.instrumentation
641            .lock()
642            .unwrap_or_else(|p| p.into_inner())
643            .on_connection_event(event);
644    }
645}
646
647struct BindData {
648    collect_bind_result: Result<(), Error>,
649    fake_oid_locations: Vec<(usize, usize)>,
650    generated_oids: GeneratedOidTypeMap,
651    bind_collector: RawBytesBindCollector<Pg>,
652}
653
654fn construct_bind_data(query: &dyn QueryFragment<diesel::pg::Pg>) -> BindData {
655    // we don't resolve custom types here yet, we do that later
656    // in the async block below as we might need to perform lookup
657    // queries for that.
658    //
659    // We apply this workaround to prevent requiring all the diesel
660    // serialization code to beeing async
661    //
662    // We give out constant fake oids here to optimize for the "happy" path
663    // without custom type lookup
664    let mut bind_collector_0 = RawBytesBindCollector::<diesel::pg::Pg>::new();
665    let mut metadata_lookup_0 = PgAsyncMetadataLookup {
666        custom_oid: false,
667        generated_oids: None,
668        oid_generator: |_, _| (FAKE_OID, FAKE_OID),
669    };
670    let collect_bind_result_0 =
671        query.collect_binds(&mut bind_collector_0, &mut metadata_lookup_0, &Pg);
672    // we have encountered a custom type oid, so we need to perform more work here.
673    // These oids can occure in two locations:
674    //
675    // * In the collected metadata -> relativly easy to resolve, just need to replace them below
676    // * As part of the seralized bind blob -> hard to replace
677    //
678    // To address the second case, we perform a second run of the bind collector
679    // with a different set of fake oids. Then we compare the output of the two runs
680    // and use that information to infer where to replace bytes in the serialized output
681    if metadata_lookup_0.custom_oid {
682        // we try to get the maxium oid we encountered here
683        // to be sure that we don't accidently give out a fake oid below that collides with
684        // something
685        let mut max_oid = bind_collector_0
686            .metadata
687            .iter()
688            .flat_map(|t| {
689                [
690                    t.oid().unwrap_or_default(),
691                    t.array_oid().unwrap_or_default(),
692                ]
693            })
694            .max()
695            .unwrap_or_default();
696        let mut bind_collector_1 = RawBytesBindCollector::<diesel::pg::Pg>::new();
697        let mut metadata_lookup_1 = PgAsyncMetadataLookup {
698            custom_oid: false,
699            generated_oids: Some(HashMap::new()),
700            oid_generator: move |_, _| {
701                max_oid += 2;
702                (max_oid, max_oid + 1)
703            },
704        };
705        let collect_bind_result_1 =
706            query.collect_binds(&mut bind_collector_1, &mut metadata_lookup_1, &Pg);
707
708        assert_eq!(
709            bind_collector_0.binds.len(),
710            bind_collector_0.metadata.len()
711        );
712        let fake_oid_locations = std::iter::zip(
713            bind_collector_0
714                .binds
715                .iter()
716                .zip(&bind_collector_0.metadata),
717            &bind_collector_1.binds,
718        )
719        .enumerate()
720        .flat_map(|(bind_index, ((bytes_0, metadata_0), bytes_1))| {
721            // custom oids might appear in the serialized bind arguments for arrays or composite (record) types
722            // in both cases the relevant buffer is a custom type on it's own
723            // so we only need to check the cases that contain a fake OID on their own
724            let (bytes_0, bytes_1) = if matches!(metadata_0.oid(), Ok(FAKE_OID)) {
725                (
726                    bytes_0.as_deref().unwrap_or_default(),
727                    bytes_1.as_deref().unwrap_or_default(),
728                )
729            } else {
730                // for all other cases, just return an empty
731                // list to make the iteration below a no-op
732                // and prevent the need of boxing
733                (&[] as &[_], &[] as &[_])
734            };
735            let lookup_map = metadata_lookup_1
736                .generated_oids
737                .as_ref()
738                .map(|map| {
739                    map.values()
740                        .flat_map(|(oid, array_oid)| [*oid, *array_oid])
741                        .collect::<HashSet<_>>()
742                })
743                .unwrap_or_default();
744            std::iter::zip(
745                bytes_0.windows(std::mem::size_of_val(&FAKE_OID)),
746                bytes_1.windows(std::mem::size_of_val(&FAKE_OID)),
747            )
748            .enumerate()
749            .filter_map(move |(byte_index, (l, r))| {
750                // here we infer if some byte sequence is a fake oid
751                // We use the following conditions for that:
752                //
753                // * The first byte sequence matches the constant FAKE_OID
754                // * The second sequence does not match the constant FAKE_OID
755                // * The second sequence is contained in the set of generated oid,
756                //   otherwise we get false positives around the boundary
757                //   of a to be replaced byte sequence
758                let r_val = u32::from_be_bytes(r.try_into().expect("That's the right size"));
759                (l == FAKE_OID.to_be_bytes()
760                    && r != FAKE_OID.to_be_bytes()
761                    && lookup_map.contains(&r_val))
762                .then_some((bind_index, byte_index))
763            })
764        })
765        // Avoid storing the bind collectors in the returned Future
766        .collect::<Vec<_>>();
767        BindData {
768            collect_bind_result: collect_bind_result_0.and(collect_bind_result_1),
769            fake_oid_locations,
770            generated_oids: metadata_lookup_1.generated_oids,
771            bind_collector: bind_collector_1,
772        }
773    } else {
774        BindData {
775            collect_bind_result: collect_bind_result_0,
776            fake_oid_locations: Vec::new(),
777            generated_oids: None,
778            bind_collector: bind_collector_0,
779        }
780    }
781}
782
/// Optional cache of generated OIDs.
///
/// Keys are `(schema, type_name)` pairs (schema is `None` when the lookup was unqualified).
/// Values are `(oid, array_oid)` pairs. When stored in `PgAsyncMetadataLookup`, `None`
/// means lookups are not cached and the OID generator is invoked on every call.
type GeneratedOidTypeMap = Option<HashMap<(Option<String>, String), (u32, u32)>>;
784
/// Collects types that need to be looked up, and causes fake OIDs to be written into the bind collector
/// so they can be replaced with asynchronously fetched OIDs after the original query is dropped.
///
/// Its `PgMetadataLookup::lookup_type` implementation never talks to the database: it
/// returns whatever `(oid, array_oid)` pair `oid_generator` produces (optionally cached in
/// `generated_oids`).
// NOTE(review): the `FnMut` bound sits on the struct definition rather than only on the impl.
// It may be load-bearing for closure-signature inference where the struct is constructed
// with a closure literal — verify before moving it.
struct PgAsyncMetadataLookup<F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static> {
    // Set to `true` by every `lookup_type` call; presumably read by the caller to decide
    // whether fake OIDs need patching — TODO confirm against caller.
    custom_oid: bool,
    // Cache of generated `(oid, array_oid)` pairs keyed by `(schema, type_name)`;
    // `None` disables caching so `oid_generator` runs on every lookup.
    generated_oids: GeneratedOidTypeMap,
    // Produces the `(oid, array_oid)` pair for a `(type_name, schema)` request.
    oid_generator: F,
}
792
793impl<F> PgMetadataLookup for PgAsyncMetadataLookup<F>
794where
795    F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static,
796{
797    fn lookup_type(&mut self, type_name: &str, schema: Option<&str>) -> PgTypeMetadata {
798        self.custom_oid = true;
799
800        let oid = if let Some(map) = &mut self.generated_oids {
801            *map.entry((schema.map(ToOwned::to_owned), type_name.to_owned()))
802                .or_insert_with(|| (self.oid_generator)(type_name, schema))
803        } else {
804            (self.oid_generator)(type_name, schema)
805        };
806
807        PgTypeMetadata::from_result(Ok(oid))
808    }
809}
810
811async fn lookup_type(
812    schema: Option<String>,
813    type_name: String,
814    raw_connection: &tokio_postgres::Client,
815) -> QueryResult<(u32, u32)> {
816    let r = if let Some(schema) = schema.as_ref() {
817        raw_connection
818            .query_one(
819                "SELECT pg_type.oid, pg_type.typarray FROM pg_type \
820             INNER JOIN pg_namespace ON pg_type.typnamespace = pg_namespace.oid \
821             WHERE pg_type.typname = $1 AND pg_namespace.nspname = $2 \
822             LIMIT 1",
823                &[&type_name, schema],
824            )
825            .await
826            .map_err(ErrorHelper)?
827    } else {
828        raw_connection
829            .query_one(
830                "SELECT pg_type.oid, pg_type.typarray FROM pg_type \
831             WHERE pg_type.oid = quote_ident($1)::regtype::oid \
832             LIMIT 1",
833                &[&type_name],
834            )
835            .await
836            .map_err(ErrorHelper)?
837    };
838    Ok((r.get(0), r.get(1)))
839}
840
841fn unwrap_oids(metadata: &PgTypeMetadata) -> (u32, u32) {
842    let err_msg = "PgTypeMetadata is supposed to always be Ok here";
843    (
844        metadata.oid().expect(err_msg),
845        metadata.array_oid().expect(err_msg),
846    )
847}
848
/// Overwrites one serialized fake OID inside a bind buffer with its real OID.
///
/// The four bytes starting at `byte_index` of bind number `bind_index` are decoded as a
/// big-endian `u32` and looked up in `real_oids`. The mapped value is then written back
/// big-endian in the same position.
///
/// Returns `Some(())` after a replacement. Returns `None` and leaves `binds` untouched
/// in any of these cases:
/// - the bind index is out of range;
/// - the bind is `None`;
/// - fewer than four bytes remain at `byte_index`;
/// - the decoded value has no entry in `real_oids`.
fn replace_fake_oid(
    binds: &mut [Option<Vec<u8>>],
    real_oids: &HashMap<u32, u32>,
    bind_index: usize,
    byte_index: usize,
) -> Option<()> {
    let bind_bytes = binds.get_mut(bind_index)?.as_mut()?;
    let oid_bytes = bind_bytes.get_mut(byte_index..)?.first_chunk_mut::<4>()?;
    let fake_oid = u32::from_be_bytes(*oid_bytes);
    let real_oid = real_oids.get(&fake_oid).copied()?;
    *oid_bytes = real_oid.to_be_bytes();
    Some(())
}
865
866async fn drive_future<R>(
867    connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
868    client_future: impl Future<Output = Result<R, diesel::result::Error>>,
869) -> Result<R, diesel::result::Error> {
870    if let Some(mut connection_future) = connection_future {
871        let client_future = std::pin::pin!(client_future);
872        let connection_future = std::pin::pin!(connection_future.recv());
873        match futures_util::future::select(client_future, connection_future).await {
874            Either::Left((res, _)) => res,
875            // we got an error from the background task
876            // return it to the user
877            Either::Right((Ok(e), _)) => Err(self::error_helper::from_tokio_postgres_error(e)),
878            // seems like the background thread died for whatever reason
879            Either::Right((Err(e), _)) => Err(diesel::result::Error::DatabaseError(
880                DatabaseErrorKind::UnableToSendCommand,
881                Box::new(e.to_string()),
882            )),
883        }
884    } else {
885        client_future.await
886    }
887}
888
889fn drive_connection<S>(
890    conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
891) -> (
892    broadcast::Receiver<Arc<tokio_postgres::Error>>,
893    oneshot::Sender<()>,
894)
895where
896    S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
897{
898    let (error_tx, error_rx) = tokio::sync::broadcast::channel(1);
899    let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
900
901    tokio::spawn(async move {
902        match futures_util::future::select(shutdown_rx, conn).await {
903            Either::Left(_) | Either::Right((Ok(_), _)) => {}
904            Either::Right((Err(e), _)) => {
905                let _ = error_tx.send(Arc::new(e));
906            }
907        }
908    });
909
910    (error_rx, shutdown_tx)
911}
912
#[cfg(any(
    feature = "deadpool",
    feature = "bb8",
    feature = "mobc",
    feature = "r2d2"
))]
impl crate::pooled_connection::PoolableConnection for AsyncPgConnection {
    /// Reports the connection as broken when the transaction manager says so,
    /// or otherwise when the underlying client reports it is closed.
    fn is_broken(&mut self) -> bool {
        use crate::TransactionManager;

        // Checked first; the client is only consulted if this passes.
        if Self::TransactionManager::is_broken_transaction_manager(self) {
            return true;
        }
        self.conn.is_closed()
    }
}
926
927impl QueryFragmentForCachedStatement<Pg> for QueryFragmentHelper {
928    fn construct_sql(&self, _backend: &Pg) -> QueryResult<String> {
929        Ok(self.sql.clone())
930    }
931
932    fn is_safe_to_cache_prepared(&self, _backend: &Pg) -> QueryResult<bool> {
933        Ok(self.safe_to_cache)
934    }
935}
936
#[cfg(test)]
mod tests {
    use super::*;
    use crate::run_query_dsl::RunQueryDsl;
    use diesel::sql_types::Integer;
    use diesel::IntoSql;

    /// Builds two query futures on the same connection before awaiting either,
    /// then awaits them jointly and checks both results.
    /// Requires a live database at `DATABASE_URL`.
    #[tokio::test]
    async fn pipelining() {
        let url =
            std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests");
        let mut connection = crate::AsyncPgConnection::establish(&url).await.unwrap();

        let select_one = diesel::select(1_i32.into_sql::<Integer>());
        let select_two = diesel::select(2_i32.into_sql::<Integer>());

        let first = select_one.get_result::<i32>(&mut connection);
        let second = select_two.get_result::<i32>(&mut connection);

        let (one, two) = futures_util::future::try_join(first, second)
            .await
            .unwrap();

        assert_eq!(one, 1);
        assert_eq!(two, 2);
    }
}