diesel_async/pg/
mod.rs

1//! Provides types and functions related to working with PostgreSQL
2//!
3//! Much of this module is re-exported from database agnostic locations.
4//! However, if you are writing code specifically to extend Diesel on
5//! PostgreSQL, you may need to work with this module directly.
6
7use self::error_helper::ErrorHelper;
8use self::row::PgRow;
9use self::serialize::ToSqlHelper;
10use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper};
11use crate::{AnsiTransactionManager, AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection};
12use diesel::connection::statement_cache::{
13    PrepareForCache, QueryFragmentForCachedStatement, StatementCache,
14};
15use diesel::connection::StrQueryHelper;
16use diesel::connection::{CacheSize, Instrumentation};
17use diesel::connection::{DynInstrumentation, InstrumentationEvent};
18use diesel::pg::{
19    Pg, PgMetadataCache, PgMetadataCacheKey, PgMetadataLookup, PgQueryBuilder, PgTypeMetadata,
20};
21use diesel::query_builder::bind_collector::RawBytesBindCollector;
22use diesel::query_builder::{AsQuery, QueryBuilder, QueryFragment, QueryId};
23use diesel::result::{DatabaseErrorKind, Error};
24use diesel::{ConnectionError, ConnectionResult, QueryResult};
25use futures_core::future::BoxFuture;
26use futures_core::stream::BoxStream;
27use futures_util::future::Either;
28use futures_util::stream::TryStreamExt;
29use futures_util::TryFutureExt;
30use futures_util::{FutureExt, StreamExt};
31use std::collections::{HashMap, HashSet};
32use std::future::Future;
33use std::sync::Arc;
34use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
35use tokio_postgres::types::ToSql;
36use tokio_postgres::types::Type;
37use tokio_postgres::Statement;
38
39pub use self::transaction_builder::TransactionBuilder;
40
41mod error_helper;
42mod row;
43mod serialize;
44mod transaction_builder;
45
46const FAKE_OID: u32 = 0;
47
48/// A connection to a PostgreSQL database.
49///
50/// Connection URLs should be in the form
51/// `postgres://[user[:password]@]host/database_name`
52///
/// Check out the documentation of the [tokio_postgres]
54/// crate for details about the format
55///
56/// [tokio_postgres]: https://docs.rs/tokio-postgres/0.7.6/tokio_postgres/config/struct.Config.html#url
57///
58/// ## Pipelining
59///
60/// This connection supports *pipelined* requests. Pipelining can improve performance in use cases in which multiple,
61/// independent queries need to be executed. In a traditional workflow, each query is sent to the server after the
62/// previous query completes. In contrast, pipelining allows the client to send all of the queries to the server up
63/// front, minimizing time spent by one side waiting for the other to finish sending data:
64///
65/// ```not_rust
66///             Sequential                              Pipelined
67/// | Client         | Server          |    | Client         | Server          |
68/// |----------------|-----------------|    |----------------|-----------------|
69/// | send query 1   |                 |    | send query 1   |                 |
70/// |                | process query 1 |    | send query 2   | process query 1 |
71/// | receive rows 1 |                 |    | send query 3   | process query 2 |
72/// | send query 2   |                 |    | receive rows 1 | process query 3 |
73/// |                | process query 2 |    | receive rows 2 |                 |
74/// | receive rows 2 |                 |    | receive rows 3 |                 |
75/// | send query 3   |                 |
76/// |                | process query 3 |
77/// | receive rows 3 |                 |
78/// ```
79///
80/// In both cases, the PostgreSQL server is executing the queries **sequentially** - pipelining just allows both sides of
81/// the connection to work concurrently when possible.
82///
83/// Pipelining happens automatically when futures are polled concurrently (for example, by using the futures `join`
84/// combinator):
85///
86/// ```rust
87/// # include!("../doctest_setup.rs");
88/// use diesel_async::RunQueryDsl;
89///
90/// #
91/// # #[tokio::main(flavor = "current_thread")]
92/// # async fn main() {
93/// #     run_test().await.unwrap();
94/// # }
95/// #
96/// # async fn run_test() -> QueryResult<()> {
97/// #     use diesel::sql_types::{Text, Integer};
98/// #     let conn = &mut establish_connection().await;
99///       let q1 = diesel::select(1_i32.into_sql::<Integer>());
100///       let q2 = diesel::select(2_i32.into_sql::<Integer>());
101///
102///       // construct multiple futures for different queries
103///       let f1 = q1.get_result::<i32>(conn);
104///       let f2 = q2.get_result::<i32>(conn);
105///
106///       // wait on both results
107///       let res = futures_util::try_join!(f1, f2)?;
108///
109///       assert_eq!(res.0, 1);
110///       assert_eq!(res.1, 2);
111///       # Ok(())
112/// # }
113/// ```
114///
/// For more complex cases, an immutable reference to the connection needs to be used:
116/// ```rust
117/// # include!("../doctest_setup.rs");
118/// use diesel_async::RunQueryDsl;
119///
120/// #
121/// # #[tokio::main(flavor = "current_thread")]
122/// # async fn main() {
123/// #     run_test().await.unwrap();
124/// # }
125/// #
126/// # async fn run_test() -> QueryResult<()> {
127/// #     use diesel::sql_types::{Text, Integer};
128/// #     let conn = &mut establish_connection().await;
129/// #
130///       async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
131///           let f1 = diesel::select(1_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
132///           let f2 = diesel::select(2_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
133///
134///           futures_util::try_join!(f1, f2)
135///       }
136///
137///       async fn fn34(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
138///           let f3 = diesel::select(3_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
139///           let f4 = diesel::select(4_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
140///
141///           futures_util::try_join!(f3, f4)
142///       }
143///
144///       let f12 = fn12(&conn);
145///       let f34 = fn34(&conn);
146///
147///       let ((r1, r2), (r3, r4)) = futures_util::try_join!(f12, f34).unwrap();
148///
149///       assert_eq!(r1, 1);
150///       assert_eq!(r2, 2);
151///       assert_eq!(r3, 3);
152///       assert_eq!(r4, 4);
153///       # Ok(())
154/// # }
155/// ```
156///
157/// ## TLS
158///
159/// Connections created by [`AsyncPgConnection::establish`] do not support TLS.
160///
161/// TLS support for tokio_postgres connections is implemented by external crates, e.g. [tokio_postgres_rustls].
162///
163/// [`AsyncPgConnection::try_from_client_and_connection`] can be used to construct a connection from an existing
164/// [`tokio_postgres::Connection`] with TLS enabled.
165///
166/// [tokio_postgres_rustls]: https://docs.rs/tokio-postgres-rustls/0.12.0/tokio_postgres_rustls/
167pub struct AsyncPgConnection {
168    conn: Arc<tokio_postgres::Client>,
169    stmt_cache: Arc<Mutex<StatementCache<diesel::pg::Pg, Statement>>>,
170    transaction_state: Arc<Mutex<AnsiTransactionManager>>,
171    metadata_cache: Arc<Mutex<PgMetadataCache>>,
172    connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
173    notification_rx: Option<mpsc::UnboundedReceiver<QueryResult<diesel::pg::PgNotification>>>,
174    shutdown_channel: Option<oneshot::Sender<()>>,
175    // a sync mutex is fine here as we only hold it for a really short time
176    instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
177}
178
179impl SimpleAsyncConnection for AsyncPgConnection {
180    async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
181        SimpleAsyncConnection::batch_execute(&mut &*self, query).await
182    }
183}
184
185impl SimpleAsyncConnection for &AsyncPgConnection {
186    async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
187        self.record_instrumentation(InstrumentationEvent::start_query(&StrQueryHelper::new(
188            query,
189        )));
190        let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
191        let batch_execute = self
192            .conn
193            .batch_execute(query)
194            .map_err(ErrorHelper)
195            .map_err(Into::into);
196
197        let r = drive_future(connection_future, batch_execute).await;
198        let r = {
199            let mut transaction_manager = self.transaction_state.lock().await;
200            update_transaction_manager_status(r, &mut transaction_manager)
201        };
202        self.record_instrumentation(InstrumentationEvent::finish_query(
203            &StrQueryHelper::new(query),
204            r.as_ref().err(),
205        ));
206        r
207    }
208}
209
210impl AsyncConnectionCore for AsyncPgConnection {
211    type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult<Self::Stream<'conn, 'query>>>;
212    type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult<usize>>;
213    type Stream<'conn, 'query> = BoxStream<'static, QueryResult<PgRow>>;
214    type Row<'conn, 'query> = PgRow;
215    type Backend = diesel::pg::Pg;
216
217    fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
218    where
219        T: AsQuery + 'query,
220        T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
221    {
222        AsyncConnectionCore::load(&mut &*self, source)
223    }
224
225    fn execute_returning_count<'conn, 'query, T>(
226        &'conn mut self,
227        source: T,
228    ) -> Self::ExecuteFuture<'conn, 'query>
229    where
230        T: QueryFragment<Self::Backend> + QueryId + 'query,
231    {
232        AsyncConnectionCore::execute_returning_count(&mut &*self, source)
233    }
234}
235
236impl AsyncConnectionCore for &AsyncPgConnection {
237    type LoadFuture<'conn, 'query> =
238        <AsyncPgConnection as AsyncConnectionCore>::LoadFuture<'conn, 'query>;
239
240    type ExecuteFuture<'conn, 'query> =
241        <AsyncPgConnection as AsyncConnectionCore>::ExecuteFuture<'conn, 'query>;
242
243    type Stream<'conn, 'query> = <AsyncPgConnection as AsyncConnectionCore>::Stream<'conn, 'query>;
244
245    type Row<'conn, 'query> = <AsyncPgConnection as AsyncConnectionCore>::Row<'conn, 'query>;
246
247    type Backend = <AsyncPgConnection as AsyncConnectionCore>::Backend;
248
249    fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
250    where
251        T: AsQuery + 'query,
252        T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
253    {
254        let query = source.as_query();
255        let load_future = self.with_prepared_statement(query, load_prepared);
256
257        self.run_with_connection_future(load_future)
258    }
259
260    fn execute_returning_count<'conn, 'query, T>(
261        &'conn mut self,
262        source: T,
263    ) -> Self::ExecuteFuture<'conn, 'query>
264    where
265        T: QueryFragment<Self::Backend> + QueryId + 'query,
266    {
267        let execute = self.with_prepared_statement(source, execute_prepared);
268        self.run_with_connection_future(execute)
269    }
270}
271
272impl AsyncConnection for AsyncPgConnection {
273    type TransactionManager = AnsiTransactionManager;
274
275    async fn establish(database_url: &str) -> ConnectionResult<Self> {
276        let mut instrumentation = DynInstrumentation::default_instrumentation();
277        instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
278            database_url,
279        ));
280        let instrumentation = Arc::new(std::sync::Mutex::new(instrumentation));
281        let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls)
282            .await
283            .map_err(ErrorHelper)?;
284
285        let (error_rx, notification_rx, shutdown_tx) = drive_connection(connection);
286
287        let r = Self::setup(
288            client,
289            Some(error_rx),
290            Some(notification_rx),
291            Some(shutdown_tx),
292            Arc::clone(&instrumentation),
293        )
294        .await;
295
296        instrumentation
297            .lock()
298            .unwrap_or_else(|e| e.into_inner())
299            .on_connection_event(InstrumentationEvent::finish_establish_connection(
300                database_url,
301                r.as_ref().err(),
302            ));
303        r
304    }
305
306    fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
307        // there should be no other pending future when this is called
308        // that means there is only one instance of this arc and
309        // we can simply access the inner data
310        if let Some(tm) = Arc::get_mut(&mut self.transaction_state) {
311            tm.get_mut()
312        } else {
313            panic!("Cannot access shared transaction state")
314        }
315    }
316
317    fn instrumentation(&mut self) -> &mut dyn Instrumentation {
318        // there should be no other pending future when this is called
319        // that means there is only one instance of this arc and
320        // we can simply access the inner data
321        if let Some(instrumentation) = Arc::get_mut(&mut self.instrumentation) {
322            &mut **(instrumentation.get_mut().unwrap_or_else(|p| p.into_inner()))
323        } else {
324            panic!("Cannot access shared instrumentation")
325        }
326    }
327
328    fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
329        self.instrumentation = Arc::new(std::sync::Mutex::new(instrumentation.into()));
330    }
331
332    fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
333        // there should be no other pending future when this is called
334        // that means there is only one instance of this arc and
335        // we can simply access the inner data
336        if let Some(cache) = Arc::get_mut(&mut self.stmt_cache) {
337            cache.get_mut().set_cache_size(size)
338        } else {
339            panic!("Cannot access shared statement cache")
340        }
341    }
342}
343
344impl Drop for AsyncPgConnection {
345    fn drop(&mut self) {
346        if let Some(tx) = self.shutdown_channel.take() {
347            let _ = tx.send(());
348        }
349    }
350}
351
352async fn load_prepared(
353    conn: Arc<tokio_postgres::Client>,
354    stmt: Statement,
355    binds: Vec<ToSqlHelper>,
356) -> QueryResult<BoxStream<'static, QueryResult<PgRow>>> {
357    let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?;
358
359    Ok(res
360        .map_err(|e| diesel::result::Error::from(ErrorHelper(e)))
361        .map_ok(PgRow::new)
362        .boxed())
363}
364
365async fn execute_prepared(
366    conn: Arc<tokio_postgres::Client>,
367    stmt: Statement,
368    binds: Vec<ToSqlHelper>,
369) -> QueryResult<usize> {
370    let binds = binds
371        .iter()
372        .map(|b| b as &(dyn ToSql + Sync))
373        .collect::<Vec<_>>();
374
375    let res = tokio_postgres::Client::execute(&conn, &stmt, &binds as &[_])
376        .await
377        .map_err(ErrorHelper)?;
378    res.try_into()
379        .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))
380}
381
382#[inline(always)]
383fn update_transaction_manager_status<T>(
384    query_result: QueryResult<T>,
385    transaction_manager: &mut AnsiTransactionManager,
386) -> QueryResult<T> {
387    if let Err(diesel::result::Error::DatabaseError(DatabaseErrorKind::SerializationFailure, _)) =
388        query_result
389    {
390        if !transaction_manager.is_commit {
391            transaction_manager
392                .status
393                .set_requires_rollback_maybe_up_to_top_level(true);
394        }
395    }
396    query_result
397}
398
399fn prepare_statement_helper(
400    conn: Arc<tokio_postgres::Client>,
401    sql: &str,
402    _is_for_cache: PrepareForCache,
403    metadata: &[PgTypeMetadata],
404) -> CallbackHelper<
405    impl Future<Output = QueryResult<(Statement, Arc<tokio_postgres::Client>)>> + Send,
406> {
407    let bind_types = metadata
408        .iter()
409        .map(type_from_oid)
410        .collect::<QueryResult<Vec<_>>>();
411    // ideally we wouldn't clone the SQL string here
412    // but as we usually cache statements anyway
413    // this is a fixed one time const
414    //
415    // The probleme with not cloning it is that we then cannot express
416    // the right result lifetime anymore (at least not easily)
417    let sql = sql.to_string();
418    CallbackHelper(async move {
419        let bind_types = bind_types?;
420        let stmt = conn
421            .prepare_typed(&sql, &bind_types)
422            .await
423            .map_err(ErrorHelper);
424        Ok((stmt?, conn))
425    })
426}
427
428fn type_from_oid(t: &PgTypeMetadata) -> QueryResult<Type> {
429    let oid = t
430        .oid()
431        .map_err(|e| diesel::result::Error::SerializationError(Box::new(e) as _))?;
432
433    if let Some(tpe) = Type::from_oid(oid) {
434        return Ok(tpe);
435    }
436
437    Ok(Type::new(
438        format!("diesel_custom_type_{oid}"),
439        oid,
440        tokio_postgres::types::Kind::Simple,
441        "public".into(),
442    ))
443}
444
445impl AsyncPgConnection {
446    /// Build a transaction, specifying additional details such as isolation level
447    ///
448    /// See [`TransactionBuilder`] for more examples.
449    ///
450    /// [`TransactionBuilder`]: crate::pg::TransactionBuilder
451    ///
452    /// ```rust
453    /// # include!("../doctest_setup.rs");
454    /// # use scoped_futures::ScopedFutureExt;
455    /// #
456    /// # #[tokio::main(flavor = "current_thread")]
457    /// # async fn main() {
458    /// #     run_test().await.unwrap();
459    /// # }
460    /// #
461    /// # async fn run_test() -> QueryResult<()> {
462    /// #     use schema::users::dsl::*;
463    /// #     let conn = &mut connection_no_transaction().await;
464    /// conn.build_transaction()
465    ///     .read_only()
466    ///     .serializable()
467    ///     .deferrable()
468    ///     .run(|conn| async move { Ok(()) }.scope_boxed())
469    ///     .await
470    /// # }
471    /// ```
472    pub fn build_transaction(&mut self) -> TransactionBuilder<'_, Self> {
473        TransactionBuilder::new(self)
474    }
475
476    /// Construct a new `AsyncPgConnection` instance from an existing [`tokio_postgres::Client`]
477    pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult<Self> {
478        Self::setup(
479            conn,
480            None,
481            None,
482            None,
483            Arc::new(std::sync::Mutex::new(
484                DynInstrumentation::default_instrumentation(),
485            )),
486        )
487        .await
488    }
489
490    /// Constructs a new `AsyncPgConnection` from an existing [`tokio_postgres::Client`] and
491    /// [`tokio_postgres::Connection`]
492    pub async fn try_from_client_and_connection<S>(
493        client: tokio_postgres::Client,
494        conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
495    ) -> ConnectionResult<Self>
496    where
497        S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
498    {
499        let (error_rx, notification_rx, shutdown_tx) = drive_connection(conn);
500
501        Self::setup(
502            client,
503            Some(error_rx),
504            Some(notification_rx),
505            Some(shutdown_tx),
506            Arc::new(std::sync::Mutex::new(
507                DynInstrumentation::default_instrumentation(),
508            )),
509        )
510        .await
511    }
512
513    async fn setup(
514        conn: tokio_postgres::Client,
515        connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
516        notification_rx: Option<mpsc::UnboundedReceiver<QueryResult<diesel::pg::PgNotification>>>,
517        shutdown_channel: Option<oneshot::Sender<()>>,
518        instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
519    ) -> ConnectionResult<Self> {
520        let mut conn = Self {
521            conn: Arc::new(conn),
522            stmt_cache: Arc::new(Mutex::new(StatementCache::new())),
523            transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
524            metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
525            connection_future,
526            notification_rx,
527            shutdown_channel,
528            instrumentation,
529        };
530        conn.set_config_options()
531            .await
532            .map_err(ConnectionError::CouldntSetupConfiguration)?;
533        Ok(conn)
534    }
535
536    /// Constructs a cancellation token that can later be used to request cancellation of a query running on the connection associated with this client.
537    pub fn cancel_token(&self) -> tokio_postgres::CancelToken {
538        self.conn.cancel_token()
539    }
540
541    async fn set_config_options(&mut self) -> QueryResult<()> {
542        use crate::run_query_dsl::RunQueryDsl;
543
544        futures_util::future::try_join(
545            diesel::sql_query("SET TIME ZONE 'UTC'").execute(self),
546            diesel::sql_query("SET CLIENT_ENCODING TO 'UTF8'").execute(self),
547        )
548        .await?;
549        Ok(())
550    }
551
552    fn run_with_connection_future<'a, R: 'a>(
553        &self,
554        future: impl Future<Output = QueryResult<R>> + Send + 'a,
555    ) -> BoxFuture<'a, QueryResult<R>> {
556        let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
557        drive_future(connection_future, future).boxed()
558    }
559
560    fn with_prepared_statement<'a, T, F, R>(
561        &self,
562        query: T,
563        callback: fn(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F,
564    ) -> BoxFuture<'a, QueryResult<R>>
565    where
566        T: QueryFragment<diesel::pg::Pg> + QueryId,
567        F: Future<Output = QueryResult<R>> + Send + 'a,
568        R: Send,
569    {
570        self.record_instrumentation(InstrumentationEvent::start_query(&diesel::debug_query(
571            &query,
572        )));
573        // we explicilty descruct the query here before going into the async block
574        //
575        // That's required to remove the send bound from `T` as we have translated
576        // the query type to just a string (for the SQL) and a bunch of bytes (for the binds)
577        // which both are `Send`.
578        // We also collect the query id (essentially an integer) and the safe_to_cache flag here
579        // so there is no need to even access the query in the async block below
580        let mut query_builder = PgQueryBuilder::default();
581
582        let bind_data = construct_bind_data(&query);
583
584        // The code that doesn't need the `T` generic parameter is in a separate function to reduce LLVM IR lines
585        self.with_prepared_statement_after_sql_built(
586            callback,
587            query.is_safe_to_cache_prepared(&Pg),
588            T::query_id(),
589            query.to_sql(&mut query_builder, &Pg),
590            query_builder,
591            bind_data,
592        )
593    }
594
595    fn with_prepared_statement_after_sql_built<'a, F, R>(
596        &self,
597        callback: fn(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F,
598        is_safe_to_cache_prepared: QueryResult<bool>,
599        query_id: Option<std::any::TypeId>,
600        to_sql_result: QueryResult<()>,
601        query_builder: PgQueryBuilder,
602        bind_data: BindData,
603    ) -> BoxFuture<'a, QueryResult<R>>
604    where
605        F: Future<Output = QueryResult<R>> + Send + 'a,
606        R: Send,
607    {
608        let raw_connection = self.conn.clone();
609        let stmt_cache = self.stmt_cache.clone();
610        let metadata_cache = self.metadata_cache.clone();
611        let tm = self.transaction_state.clone();
612        let instrumentation = self.instrumentation.clone();
613        let BindData {
614            collect_bind_result,
615            fake_oid_locations,
616            generated_oids,
617            mut bind_collector,
618        } = bind_data;
619
620        async move {
621            let sql = to_sql_result.map(|_| query_builder.finish())?;
622            let res = async {
623            let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
624            collect_bind_result?;
625            // Check whether we need to resolve some types at all
626            //
627            // If the user doesn't use custom types there is no need
628            // to borther with that at all
629            if let Some(ref unresolved_types) = generated_oids {
630                let metadata_cache = &mut *metadata_cache.lock().await;
631                let mut real_oids = HashMap::new();
632
633                for ((schema, lookup_type_name), (fake_oid, fake_array_oid)) in
634                    unresolved_types
635                {
636                    // for each unresolved item
637                    // we check whether it's arleady in the cache
638                    // or perform a lookup and insert it into the cache
639                    let cache_key = PgMetadataCacheKey::new(
640                        schema.as_deref().map(Into::into),
641                        lookup_type_name.into(),
642                    );
643                    let real_metadata = if let Some(type_metadata) =
644                        metadata_cache.lookup_type(&cache_key)
645                    {
646                        type_metadata
647                    } else {
648                        let type_metadata =
649                            lookup_type(schema.clone(), lookup_type_name.clone(), &raw_connection)
650                                .await?;
651                        metadata_cache.store_type(cache_key, type_metadata);
652
653                        PgTypeMetadata::from_result(Ok(type_metadata))
654                    };
655                    // let (fake_oid, fake_array_oid) = metadata_lookup.fake_oids(index);
656                    let (real_oid, real_array_oid) = unwrap_oids(&real_metadata);
657                    real_oids.extend([(*fake_oid, real_oid), (*fake_array_oid, real_array_oid)]);
658                }
659
660                // Replace fake OIDs with real OIDs in `bind_collector.metadata`
661                for m in &mut bind_collector.metadata {
662                    let (oid, array_oid) = unwrap_oids(m);
663                    *m = PgTypeMetadata::new(
664                        real_oids.get(&oid).copied().unwrap_or(oid),
665                        real_oids.get(&array_oid).copied().unwrap_or(array_oid)
666                    );
667                }
668                // Replace fake OIDs with real OIDs in `bind_collector.binds`
669                for (bind_index, byte_index) in fake_oid_locations {
670                    replace_fake_oid(&mut bind_collector.binds, &real_oids, bind_index, byte_index)
671                        .ok_or_else(|| {
672                            Error::SerializationError(
673                                format!("diesel_async failed to replace a type OID serialized in bind value {bind_index}").into(),
674                            )
675                        })?;
676                }
677            }
678            let stmt = {
679                let mut stmt_cache = stmt_cache.lock().await;
680                let helper = QueryFragmentHelper {
681                    sql: sql.clone(),
682                    safe_to_cache: is_safe_to_cache_prepared,
683                };
684                let instrumentation = Arc::clone(&instrumentation);
685                stmt_cache
686                    .cached_statement_non_generic(
687                        query_id,
688                        &helper,
689                        &Pg,
690                        &bind_collector.metadata,
691                        raw_connection.clone(),
692                        prepare_statement_helper,
693                        &mut move |event: InstrumentationEvent<'_>| {
694                            // we wrap this lock into another callback to prevent locking
695                            // the instrumentation longer than necessary
696                            instrumentation.lock().unwrap_or_else(|e| e.into_inner())
697                                .on_connection_event(event);
698                        },
699                    )
700                    .await?
701                    .0
702                    .clone()
703            };
704
705            let binds = bind_collector
706                .metadata
707                .into_iter()
708                .zip(bind_collector.binds)
709                .map(|(meta, bind)| ToSqlHelper(meta, bind))
710                .collect::<Vec<_>>();
711                callback(raw_connection, stmt.clone(), binds).await
712            };
713            let res = res.await;
714            let mut tm = tm.lock().await;
715            let r = update_transaction_manager_status(res, &mut tm);
716            instrumentation
717                .lock()
718                .unwrap_or_else(|p| p.into_inner())
719                .on_connection_event(InstrumentationEvent::finish_query(
720                    &StrQueryHelper::new(&sql),
721                    r.as_ref().err(),
722                ));
723
724            r
725        }
726        .boxed()
727    }
728
729    fn record_instrumentation(&self, event: InstrumentationEvent<'_>) {
730        self.instrumentation
731            .lock()
732            .unwrap_or_else(|p| p.into_inner())
733            .on_connection_event(event);
734    }
735
736    /// See Postgres documentation for SQL commands [NOTIFY][] and [LISTEN][]
737    ///
738    /// The returned stream yields all notifications received by the connection, not only notifications received
739    /// after calling the function. The returned stream will never close, so no notifications will just result
740    /// in a pending state.
741    ///
742    /// If there's no connection available to poll, the stream will yield no notifications and be pending forever.
743    /// This can happen if you created the [`AsyncPgConnection`] by the [`try_from`] constructor.
744    ///
745    /// [NOTIFY]: https://www.postgresql.org/docs/current/sql-notify.html
746    /// [LISTEN]: https://www.postgresql.org/docs/current/sql-listen.html
747    /// [`AsyncPgConnection`]: crate::pg::AsyncPgConnection
748    /// [`try_from`]: crate::pg::AsyncPgConnection::try_from
749    ///
750    /// ```rust
751    /// # include!("../doctest_setup.rs");
752    /// # use scoped_futures::ScopedFutureExt;
753    /// #
754    /// # #[tokio::main(flavor = "current_thread")]
755    /// # async fn main() {
756    /// #     run_test().await.unwrap();
757    /// # }
758    /// #
759    /// # async fn run_test() -> QueryResult<()> {
760    /// #     use diesel_async::RunQueryDsl;
761    /// #     use futures_util::StreamExt;
762    /// #     let conn = &mut connection_no_transaction().await;
763    /// // register the notifications channel we want to receive notifications for
764    /// diesel::sql_query("LISTEN example_channel").execute(conn).await?;
765    /// // send some notification (usually done from a different connection/thread/application)
766    /// diesel::sql_query("NOTIFY example_channel, 'additional data'").execute(conn).await?;
767    ///
768    /// let mut notifications = std::pin::pin!(conn.notifications_stream());
769    /// let mut notification = notifications.next().await.unwrap().unwrap();
770    ///
771    /// assert_eq!(notification.channel, "example_channel");
772    /// assert_eq!(notification.payload, "additional data");
773    /// println!("Notification received from process with id {}", notification.process_id);
774    /// # Ok(())
775    /// # }
776    /// ```
777    pub fn notifications_stream(
778        &mut self,
779    ) -> impl futures_core::Stream<Item = QueryResult<diesel::pg::PgNotification>> + '_ {
780        match &mut self.notification_rx {
781            None => Either::Left(futures_util::stream::pending()),
782            Some(rx) => Either::Right(futures_util::stream::unfold(rx, |rx| async {
783                rx.recv().await.map(move |item| (item, rx))
784            })),
785        }
786    }
787}
788
/// Everything `construct_bind_data` gathers while collecting the bind
/// parameters of a query.
struct BindData {
    /// Outcome of running `collect_binds` on the query (both runs combined
    /// when a second run was necessary).
    collect_bind_result: Result<(), Error>,
    /// `(bind_index, byte_index)` pairs marking where a placeholder OID sits
    /// inside a serialized bind value.
    fake_oid_locations: Vec<(usize, usize)>,
    /// Placeholder OIDs handed out per custom type, `None` if no custom type
    /// lookup was requested.
    generated_oids: GeneratedOidTypeMap,
    /// The collector holding the serialized binds and their type metadata.
    bind_collector: RawBytesBindCollector<Pg>,
}
795
796fn construct_bind_data(query: &dyn QueryFragment<diesel::pg::Pg>) -> BindData {
797    // we don't resolve custom types here yet, we do that later
798    // in the async block below as we might need to perform lookup
799    // queries for that.
800    //
801    // We apply this workaround to prevent requiring all the diesel
802    // serialization code to beeing async
803    //
804    // We give out constant fake oids here to optimize for the "happy" path
805    // without custom type lookup
806    let mut bind_collector_0 = RawBytesBindCollector::<diesel::pg::Pg>::new();
807    let mut metadata_lookup_0 = PgAsyncMetadataLookup {
808        custom_oid: false,
809        generated_oids: None,
810        oid_generator: |_, _| (FAKE_OID, FAKE_OID),
811    };
812    let collect_bind_result_0 =
813        query.collect_binds(&mut bind_collector_0, &mut metadata_lookup_0, &Pg);
814    // we have encountered a custom type oid, so we need to perform more work here.
815    // These oids can occure in two locations:
816    //
817    // * In the collected metadata -> relativly easy to resolve, just need to replace them below
818    // * As part of the seralized bind blob -> hard to replace
819    //
820    // To address the second case, we perform a second run of the bind collector
821    // with a different set of fake oids. Then we compare the output of the two runs
822    // and use that information to infer where to replace bytes in the serialized output
823    if metadata_lookup_0.custom_oid {
824        // we try to get the maxium oid we encountered here
825        // to be sure that we don't accidently give out a fake oid below that collides with
826        // something
827        let mut max_oid = bind_collector_0
828            .metadata
829            .iter()
830            .flat_map(|t| {
831                [
832                    t.oid().unwrap_or_default(),
833                    t.array_oid().unwrap_or_default(),
834                ]
835            })
836            .max()
837            .unwrap_or_default();
838        let mut bind_collector_1 = RawBytesBindCollector::<diesel::pg::Pg>::new();
839        let mut metadata_lookup_1 = PgAsyncMetadataLookup {
840            custom_oid: false,
841            generated_oids: Some(HashMap::new()),
842            oid_generator: move |_, _| {
843                max_oid += 2;
844                (max_oid, max_oid + 1)
845            },
846        };
847        let collect_bind_result_1 =
848            query.collect_binds(&mut bind_collector_1, &mut metadata_lookup_1, &Pg);
849
850        assert_eq!(
851            bind_collector_0.binds.len(),
852            bind_collector_0.metadata.len()
853        );
854        let fake_oid_locations = std::iter::zip(
855            bind_collector_0
856                .binds
857                .iter()
858                .zip(&bind_collector_0.metadata),
859            &bind_collector_1.binds,
860        )
861        .enumerate()
862        .flat_map(|(bind_index, ((bytes_0, metadata_0), bytes_1))| {
863            // custom oids might appear in the serialized bind arguments for arrays or composite (record) types
864            // in both cases the relevant buffer is a custom type on it's own
865            // so we only need to check the cases that contain a fake OID on their own
866            let (bytes_0, bytes_1) = if matches!(metadata_0.oid(), Ok(FAKE_OID)) {
867                (
868                    bytes_0.as_deref().unwrap_or_default(),
869                    bytes_1.as_deref().unwrap_or_default(),
870                )
871            } else {
872                // for all other cases, just return an empty
873                // list to make the iteration below a no-op
874                // and prevent the need of boxing
875                (&[] as &[_], &[] as &[_])
876            };
877            let lookup_map = metadata_lookup_1
878                .generated_oids
879                .as_ref()
880                .map(|map| {
881                    map.values()
882                        .flat_map(|(oid, array_oid)| [*oid, *array_oid])
883                        .collect::<HashSet<_>>()
884                })
885                .unwrap_or_default();
886            std::iter::zip(
887                bytes_0.windows(std::mem::size_of_val(&FAKE_OID)),
888                bytes_1.windows(std::mem::size_of_val(&FAKE_OID)),
889            )
890            .enumerate()
891            .filter_map(move |(byte_index, (l, r))| {
892                // here we infer if some byte sequence is a fake oid
893                // We use the following conditions for that:
894                //
895                // * The first byte sequence matches the constant FAKE_OID
896                // * The second sequence does not match the constant FAKE_OID
897                // * The second sequence is contained in the set of generated oid,
898                //   otherwise we get false positives around the boundary
899                //   of a to be replaced byte sequence
900                let r_val = u32::from_be_bytes(r.try_into().expect("That's the right size"));
901                (l == FAKE_OID.to_be_bytes()
902                    && r != FAKE_OID.to_be_bytes()
903                    && lookup_map.contains(&r_val))
904                .then_some((bind_index, byte_index))
905            })
906        })
907        // Avoid storing the bind collectors in the returned Future
908        .collect::<Vec<_>>();
909        BindData {
910            collect_bind_result: collect_bind_result_0.and(collect_bind_result_1),
911            fake_oid_locations,
912            generated_oids: metadata_lookup_1.generated_oids,
913            bind_collector: bind_collector_1,
914        }
915    } else {
916        BindData {
917            collect_bind_result: collect_bind_result_0,
918            fake_oid_locations: Vec::new(),
919            generated_oids: None,
920            bind_collector: bind_collector_0,
921        }
922    }
923}
924
/// Placeholder OIDs generated per custom type: the key is `(schema, type name)`,
/// the value is the `(oid, array oid)` pair handed out for that type.
/// `None` means no such map is being recorded.
type GeneratedOidTypeMap = Option<HashMap<(Option<String>, String), (u32, u32)>>;
926
/// Collects types that need to be looked up, and causes fake OIDs to be written into the bind collector
/// so they can be replaced with asynchronously fetched OIDs after the original query is dropped
struct PgAsyncMetadataLookup<F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static> {
    /// Set to `true` as soon as a type lookup is requested.
    custom_oid: bool,
    /// When `Some`, remembers the OID pair generated for each `(schema, type name)`
    /// so that repeated lookups of the same type return the same pair.
    generated_oids: GeneratedOidTypeMap,
    /// Produces the `(oid, array oid)` pair for a type; called with the type
    /// name and the optional schema.
    oid_generator: F,
}
934
935impl<F> PgMetadataLookup for PgAsyncMetadataLookup<F>
936where
937    F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static,
938{
939    fn lookup_type(&mut self, type_name: &str, schema: Option<&str>) -> PgTypeMetadata {
940        self.custom_oid = true;
941
942        let oid = if let Some(map) = &mut self.generated_oids {
943            *map.entry((schema.map(ToOwned::to_owned), type_name.to_owned()))
944                .or_insert_with(|| (self.oid_generator)(type_name, schema))
945        } else {
946            (self.oid_generator)(type_name, schema)
947        };
948
949        PgTypeMetadata::from_result(Ok(oid))
950    }
951}
952
953async fn lookup_type(
954    schema: Option<String>,
955    type_name: String,
956    raw_connection: &tokio_postgres::Client,
957) -> QueryResult<(u32, u32)> {
958    let r = if let Some(schema) = schema.as_ref() {
959        raw_connection
960            .query_one(
961                "SELECT pg_type.oid, pg_type.typarray FROM pg_type \
962             INNER JOIN pg_namespace ON pg_type.typnamespace = pg_namespace.oid \
963             WHERE pg_type.typname = $1 AND pg_namespace.nspname = $2 \
964             LIMIT 1",
965                &[&type_name, schema],
966            )
967            .await
968            .map_err(ErrorHelper)?
969    } else {
970        raw_connection
971            .query_one(
972                "SELECT pg_type.oid, pg_type.typarray FROM pg_type \
973             WHERE pg_type.oid = quote_ident($1)::regtype::oid \
974             LIMIT 1",
975                &[&type_name],
976            )
977            .await
978            .map_err(ErrorHelper)?
979    };
980    Ok((r.get(0), r.get(1)))
981}
982
983fn unwrap_oids(metadata: &PgTypeMetadata) -> (u32, u32) {
984    let err_msg = "PgTypeMetadata is supposed to always be Ok here";
985    (
986        metadata.oid().expect(err_msg),
987        metadata.array_oid().expect(err_msg),
988    )
989}
990
/// Overwrites the 4-byte big-endian OID found at `byte_index` inside the bind
/// value `bind_index` with its replacement from `real_oids`.
///
/// Returns `Some(())` after a successful replacement. Returns `None` and
/// leaves `binds` untouched if the bind does not exist or holds no value,
/// if fewer than four bytes are available at `byte_index`, or if the OID
/// read from that position has no entry in `real_oids`.
fn replace_fake_oid(
    binds: &mut [Option<Vec<u8>>],
    real_oids: &HashMap<u32, u32>,
    bind_index: usize,
    byte_index: usize,
) -> Option<()> {
    let buffer = binds.get_mut(bind_index)?.as_mut()?;
    let tail = buffer.get_mut(byte_index..)?;
    let slot = tail.first_chunk_mut::<4>()?;

    let placeholder = u32::from_be_bytes(*slot);
    let real_oid = real_oids.get(&placeholder)?;
    *slot = real_oid.to_be_bytes();
    Some(())
}
1007
1008async fn drive_future<R>(
1009    connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
1010    client_future: impl Future<Output = Result<R, diesel::result::Error>>,
1011) -> Result<R, diesel::result::Error> {
1012    if let Some(mut connection_future) = connection_future {
1013        let client_future = std::pin::pin!(client_future);
1014        let connection_future = std::pin::pin!(connection_future.recv());
1015        match futures_util::future::select(client_future, connection_future).await {
1016            Either::Left((res, _)) => res,
1017            // we got an error from the background task
1018            // return it to the user
1019            Either::Right((Ok(e), _)) => Err(self::error_helper::from_tokio_postgres_error(e)),
1020            // seems like the background thread died for whatever reason
1021            Either::Right((Err(e), _)) => Err(diesel::result::Error::DatabaseError(
1022                DatabaseErrorKind::UnableToSendCommand,
1023                Box::new(e.to_string()),
1024            )),
1025        }
1026    } else {
1027        client_future.await
1028    }
1029}
1030
1031fn drive_connection<S>(
1032    mut conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
1033) -> (
1034    broadcast::Receiver<Arc<tokio_postgres::Error>>,
1035    mpsc::UnboundedReceiver<QueryResult<diesel::pg::PgNotification>>,
1036    oneshot::Sender<()>,
1037)
1038where
1039    S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
1040{
1041    let (error_tx, error_rx) = tokio::sync::broadcast::channel(1);
1042    let (notification_tx, notification_rx) = tokio::sync::mpsc::unbounded_channel();
1043    let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
1044    let mut conn = futures_util::stream::poll_fn(move |cx| conn.poll_message(cx));
1045
1046    tokio::spawn(async move {
1047        loop {
1048            match futures_util::future::select(&mut shutdown_rx, conn.next()).await {
1049                Either::Left(_) | Either::Right((None, _)) => break,
1050                Either::Right((Some(Ok(tokio_postgres::AsyncMessage::Notification(notif))), _)) => {
1051                    let _: Result<_, _> = notification_tx.send(Ok(diesel::pg::PgNotification {
1052                        process_id: notif.process_id(),
1053                        channel: notif.channel().to_owned(),
1054                        payload: notif.payload().to_owned(),
1055                    }));
1056                }
1057                Either::Right((Some(Ok(_)), _)) => {}
1058                Either::Right((Some(Err(e)), _)) => {
1059                    let e = Arc::new(e);
1060                    let _: Result<_, _> = error_tx.send(e.clone());
1061                    let _: Result<_, _> =
1062                        notification_tx.send(Err(error_helper::from_tokio_postgres_error(e)));
1063                    break;
1064                }
1065            }
1066        }
1067    });
1068
1069    (error_rx, notification_rx, shutdown_tx)
1070}
1071
#[cfg(any(
    feature = "deadpool",
    feature = "bb8",
    feature = "mobc",
    feature = "r2d2"
))]
impl crate::pooled_connection::PoolableConnection for AsyncPgConnection {
    /// Reports the connection as broken if the transaction manager is broken
    /// or the underlying client is closed.
    fn is_broken(&mut self) -> bool {
        use crate::TransactionManager;

        if Self::TransactionManager::is_broken_transaction_manager(self) {
            return true;
        }
        self.conn.is_closed()
    }
}
1085
1086impl QueryFragmentForCachedStatement<Pg> for QueryFragmentHelper {
1087    fn construct_sql(&self, _backend: &Pg) -> QueryResult<String> {
1088        Ok(self.sql.clone())
1089    }
1090
1091    fn is_safe_to_cache_prepared(&self, _backend: &Pg) -> QueryResult<bool> {
1092        Ok(self.safe_to_cache)
1093    }
1094}
1095
#[cfg(test)]
mod tests {
    use super::*;
    use crate::run_query_dsl::RunQueryDsl;
    use diesel::sql_types::Integer;
    use diesel::IntoSql;
    use futures_util::future::try_join;
    use scoped_futures::ScopedFutureExt;

    /// Opens a connection to the database named by `DATABASE_URL`.
    async fn connect() -> AsyncPgConnection {
        let url =
            std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests");

        crate::AsyncPgConnection::establish(&url).await.unwrap()
    }

    /// Issues `SELECT 1` and `SELECT 2` through a shared reference and joins them.
    async fn select_1_and_2(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
        let one = diesel::select(1_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
        let two = diesel::select(2_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);

        try_join(one, two).await
    }

    /// Issues `SELECT 3` and `SELECT 4` through a shared reference and joins them.
    async fn select_3_and_4(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
        let three = diesel::select(3_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
        let four = diesel::select(4_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);

        try_join(three, four).await
    }

    /// Runs one query through each of `execute`, `load`, `get_result`,
    /// `get_results` and `first`, joined into a single future.
    async fn select_via_every_method(
        mut conn: &AsyncPgConnection,
    ) -> QueryResult<(usize, (Vec<i32>, (i32, (Vec<i32>, i32))))> {
        let executed = diesel::select(0_i32.into_sql::<Integer>()).execute(&mut conn);
        let loaded = diesel::select(4_i32.into_sql::<Integer>()).load::<i32>(&mut conn);
        let single = diesel::select(5_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
        let multiple = diesel::select(6_i32.into_sql::<Integer>()).get_results::<i32>(&mut conn);
        let first = diesel::select(7_i32.into_sql::<Integer>()).first::<i32>(&mut conn);

        try_join(
            executed,
            try_join(loaded, try_join(single, try_join(multiple, first))),
        )
        .await
    }

    #[tokio::test]
    async fn pipelining() {
        let mut conn = connect().await;

        let select_one = diesel::select(1_i32.into_sql::<Integer>());
        let select_two = diesel::select(2_i32.into_sql::<Integer>());

        let first = select_one.get_result::<i32>(&mut conn);
        let second = select_two.get_result::<i32>(&mut conn);

        let (r1, r2) = try_join(first, second).await.unwrap();

        assert_eq!(r1, 1);
        assert_eq!(r2, 2);
    }

    #[tokio::test]
    async fn pipelining_with_composed_futures() {
        let conn = connect().await;

        let left = select_1_and_2(&conn);
        let right = select_3_and_4(&conn);

        let ((r1, r2), (r3, r4)) = try_join(left, right).await.unwrap();

        assert_eq!(r1, 1);
        assert_eq!(r2, 2);
        assert_eq!(r3, 3);
        assert_eq!(r4, 4);
    }

    #[tokio::test]
    async fn pipelining_with_composed_futures_and_transaction() {
        let mut conn = connect().await;

        conn.transaction(|conn| {
            async move {
                let left = select_1_and_2(conn);
                let right = select_via_every_method(conn);

                let ((r1, r2), (r3, (r4, (r5, (r6, r7))))) = try_join(left, right).await.unwrap();

                assert_eq!(r1, 1);
                assert_eq!(r2, 2);
                assert_eq!(r3, 1);
                assert_eq!(r4, vec![4]);
                assert_eq!(r5, 5);
                assert_eq!(r6, vec![6]);
                assert_eq!(r7, 7);

                // run both groups once more, this time one after the other
                select_1_and_2(conn).await?;
                select_via_every_method(conn).await?;

                QueryResult::<_>::Ok(())
            }
            .scope_boxed()
        })
        .await
        .unwrap();
    }
}