1use self::error_helper::ErrorHelper;
8use self::row::PgRow;
9use self::serialize::ToSqlHelper;
10use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper};
11use crate::{AnsiTransactionManager, AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection};
12use diesel::connection::statement_cache::{
13 PrepareForCache, QueryFragmentForCachedStatement, StatementCache,
14};
15use diesel::connection::StrQueryHelper;
16use diesel::connection::{CacheSize, Instrumentation};
17use diesel::connection::{DynInstrumentation, InstrumentationEvent};
18use diesel::pg::{
19 Pg, PgMetadataCache, PgMetadataCacheKey, PgMetadataLookup, PgQueryBuilder, PgTypeMetadata,
20};
21use diesel::query_builder::bind_collector::RawBytesBindCollector;
22use diesel::query_builder::{AsQuery, QueryBuilder, QueryFragment, QueryId};
23use diesel::result::{DatabaseErrorKind, Error};
24use diesel::{ConnectionError, ConnectionResult, QueryResult};
25use futures_core::future::BoxFuture;
26use futures_core::stream::BoxStream;
27use futures_util::future::Either;
28use futures_util::stream::TryStreamExt;
29use futures_util::TryFutureExt;
30use futures_util::{FutureExt, StreamExt};
31use std::collections::{HashMap, HashSet};
32use std::future::Future;
33use std::sync::Arc;
34use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
35use tokio_postgres::types::ToSql;
36use tokio_postgres::types::Type;
37use tokio_postgres::Statement;
38
39pub use self::transaction_builder::TransactionBuilder;
40
41mod error_helper;
42mod row;
43mod serialize;
44mod transaction_builder;
45
/// Placeholder OID returned for every custom type during the first bind-collection pass in
/// `construct_bind_data`, before the real OIDs have been looked up in the database.
const FAKE_OID: u32 = 0;
47
/// A connection to a PostgreSQL database backed by a [`tokio_postgres::Client`].
///
/// All mutable state is kept behind an `Arc`, so query futures can be created from a shared
/// `&AsyncPgConnection` and awaited concurrently (pipelining). The [`AsyncConnection`] accessors
/// that need exclusive access (`transaction_state`, `instrumentation`,
/// `set_prepared_statement_cache_size`) panic while another clone of the respective `Arc` is
/// alive, e.g. while a query future created from this connection is still pending.
// NOTE: fields are dropped in declaration order, after `Drop::drop` has signalled shutdown
// to the background connection task.
pub struct AsyncPgConnection {
    // The underlying client; cloned into every query future.
    conn: Arc<tokio_postgres::Client>,
    // Prepared statements, keyed by query id / SQL text and bind metadata.
    stmt_cache: Arc<Mutex<StatementCache<diesel::pg::Pg, Statement>>>,
    transaction_state: Arc<Mutex<AnsiTransactionManager>>,
    // Cache of resolved custom type OIDs (see `with_prepared_statement_after_sql_built`).
    metadata_cache: Arc<Mutex<PgMetadataCache>>,
    // Receives the error reported by the background connection task, if this connection owns it.
    connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
    // Notifications forwarded by the background connection task, if this connection owns it.
    notification_rx: Option<mpsc::UnboundedReceiver<QueryResult<diesel::pg::PgNotification>>>,
    // Used by `Drop` to stop the background connection task.
    shutdown_channel: Option<oneshot::Sender<()>>,
    // A std mutex: the guard is only ever held for a single instrumentation call.
    instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
}
178
179impl SimpleAsyncConnection for AsyncPgConnection {
180 async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
181 SimpleAsyncConnection::batch_execute(&mut &*self, query).await
182 }
183}
184
185impl SimpleAsyncConnection for &AsyncPgConnection {
186 async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
187 self.record_instrumentation(InstrumentationEvent::start_query(&StrQueryHelper::new(
188 query,
189 )));
190 let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
191 let batch_execute = self
192 .conn
193 .batch_execute(query)
194 .map_err(ErrorHelper)
195 .map_err(Into::into);
196
197 let r = drive_future(connection_future, batch_execute).await;
198 let r = {
199 let mut transaction_manager = self.transaction_state.lock().await;
200 update_transaction_manager_status(r, &mut transaction_manager)
201 };
202 self.record_instrumentation(InstrumentationEvent::finish_query(
203 &StrQueryHelper::new(query),
204 r.as_ref().err(),
205 ));
206 r
207 }
208}
209
210impl AsyncConnectionCore for AsyncPgConnection {
211 type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult<Self::Stream<'conn, 'query>>>;
212 type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult<usize>>;
213 type Stream<'conn, 'query> = BoxStream<'static, QueryResult<PgRow>>;
214 type Row<'conn, 'query> = PgRow;
215 type Backend = diesel::pg::Pg;
216
217 fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
218 where
219 T: AsQuery + 'query,
220 T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
221 {
222 AsyncConnectionCore::load(&mut &*self, source)
223 }
224
225 fn execute_returning_count<'conn, 'query, T>(
226 &'conn mut self,
227 source: T,
228 ) -> Self::ExecuteFuture<'conn, 'query>
229 where
230 T: QueryFragment<Self::Backend> + QueryId + 'query,
231 {
232 AsyncConnectionCore::execute_returning_count(&mut &*self, source)
233 }
234}
235
236impl AsyncConnectionCore for &AsyncPgConnection {
237 type LoadFuture<'conn, 'query> =
238 <AsyncPgConnection as AsyncConnectionCore>::LoadFuture<'conn, 'query>;
239
240 type ExecuteFuture<'conn, 'query> =
241 <AsyncPgConnection as AsyncConnectionCore>::ExecuteFuture<'conn, 'query>;
242
243 type Stream<'conn, 'query> = <AsyncPgConnection as AsyncConnectionCore>::Stream<'conn, 'query>;
244
245 type Row<'conn, 'query> = <AsyncPgConnection as AsyncConnectionCore>::Row<'conn, 'query>;
246
247 type Backend = <AsyncPgConnection as AsyncConnectionCore>::Backend;
248
249 fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
250 where
251 T: AsQuery + 'query,
252 T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
253 {
254 let query = source.as_query();
255 let load_future = self.with_prepared_statement(query, load_prepared);
256
257 self.run_with_connection_future(load_future)
258 }
259
260 fn execute_returning_count<'conn, 'query, T>(
261 &'conn mut self,
262 source: T,
263 ) -> Self::ExecuteFuture<'conn, 'query>
264 where
265 T: QueryFragment<Self::Backend> + QueryId + 'query,
266 {
267 let execute = self.with_prepared_statement(source, execute_prepared);
268 self.run_with_connection_future(execute)
269 }
270}
271
272impl AsyncConnection for AsyncPgConnection {
273 type TransactionManager = AnsiTransactionManager;
274
275 async fn establish(database_url: &str) -> ConnectionResult<Self> {
276 let mut instrumentation = DynInstrumentation::default_instrumentation();
277 instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
278 database_url,
279 ));
280 let instrumentation = Arc::new(std::sync::Mutex::new(instrumentation));
281 let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls)
282 .await
283 .map_err(ErrorHelper)?;
284
285 let (error_rx, notification_rx, shutdown_tx) = drive_connection(connection);
286
287 let r = Self::setup(
288 client,
289 Some(error_rx),
290 Some(notification_rx),
291 Some(shutdown_tx),
292 Arc::clone(&instrumentation),
293 )
294 .await;
295
296 instrumentation
297 .lock()
298 .unwrap_or_else(|e| e.into_inner())
299 .on_connection_event(InstrumentationEvent::finish_establish_connection(
300 database_url,
301 r.as_ref().err(),
302 ));
303 r
304 }
305
306 fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
307 if let Some(tm) = Arc::get_mut(&mut self.transaction_state) {
311 tm.get_mut()
312 } else {
313 panic!("Cannot access shared transaction state")
314 }
315 }
316
317 fn instrumentation(&mut self) -> &mut dyn Instrumentation {
318 if let Some(instrumentation) = Arc::get_mut(&mut self.instrumentation) {
322 &mut **(instrumentation.get_mut().unwrap_or_else(|p| p.into_inner()))
323 } else {
324 panic!("Cannot access shared instrumentation")
325 }
326 }
327
328 fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
329 self.instrumentation = Arc::new(std::sync::Mutex::new(instrumentation.into()));
330 }
331
332 fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
333 if let Some(cache) = Arc::get_mut(&mut self.stmt_cache) {
337 cache.get_mut().set_cache_size(size)
338 } else {
339 panic!("Cannot access shared statement cache")
340 }
341 }
342}
343
344impl Drop for AsyncPgConnection {
345 fn drop(&mut self) {
346 if let Some(tx) = self.shutdown_channel.take() {
347 let _ = tx.send(());
348 }
349 }
350}
351
352async fn load_prepared(
353 conn: Arc<tokio_postgres::Client>,
354 stmt: Statement,
355 binds: Vec<ToSqlHelper>,
356) -> QueryResult<BoxStream<'static, QueryResult<PgRow>>> {
357 let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?;
358
359 Ok(res
360 .map_err(|e| diesel::result::Error::from(ErrorHelper(e)))
361 .map_ok(PgRow::new)
362 .boxed())
363}
364
365async fn execute_prepared(
366 conn: Arc<tokio_postgres::Client>,
367 stmt: Statement,
368 binds: Vec<ToSqlHelper>,
369) -> QueryResult<usize> {
370 let binds = binds
371 .iter()
372 .map(|b| b as &(dyn ToSql + Sync))
373 .collect::<Vec<_>>();
374
375 let res = tokio_postgres::Client::execute(&conn, &stmt, &binds as &[_])
376 .await
377 .map_err(ErrorHelper)?;
378 res.try_into()
379 .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))
380}
381
382#[inline(always)]
383fn update_transaction_manager_status<T>(
384 query_result: QueryResult<T>,
385 transaction_manager: &mut AnsiTransactionManager,
386) -> QueryResult<T> {
387 if let Err(diesel::result::Error::DatabaseError(DatabaseErrorKind::SerializationFailure, _)) =
388 query_result
389 {
390 if !transaction_manager.is_commit {
391 transaction_manager
392 .status
393 .set_requires_rollback_maybe_up_to_top_level(true);
394 }
395 }
396 query_result
397}
398
399fn prepare_statement_helper(
400 conn: Arc<tokio_postgres::Client>,
401 sql: &str,
402 _is_for_cache: PrepareForCache,
403 metadata: &[PgTypeMetadata],
404) -> CallbackHelper<
405 impl Future<Output = QueryResult<(Statement, Arc<tokio_postgres::Client>)>> + Send,
406> {
407 let bind_types = metadata
408 .iter()
409 .map(type_from_oid)
410 .collect::<QueryResult<Vec<_>>>();
411 let sql = sql.to_string();
418 CallbackHelper(async move {
419 let bind_types = bind_types?;
420 let stmt = conn
421 .prepare_typed(&sql, &bind_types)
422 .await
423 .map_err(ErrorHelper);
424 Ok((stmt?, conn))
425 })
426}
427
428fn type_from_oid(t: &PgTypeMetadata) -> QueryResult<Type> {
429 let oid = t
430 .oid()
431 .map_err(|e| diesel::result::Error::SerializationError(Box::new(e) as _))?;
432
433 if let Some(tpe) = Type::from_oid(oid) {
434 return Ok(tpe);
435 }
436
437 Ok(Type::new(
438 format!("diesel_custom_type_{oid}"),
439 oid,
440 tokio_postgres::types::Kind::Simple,
441 "public".into(),
442 ))
443}
444
445impl AsyncPgConnection {
446 pub fn build_transaction(&mut self) -> TransactionBuilder<'_, Self> {
473 TransactionBuilder::new(self)
474 }
475
476 pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult<Self> {
478 Self::setup(
479 conn,
480 None,
481 None,
482 None,
483 Arc::new(std::sync::Mutex::new(
484 DynInstrumentation::default_instrumentation(),
485 )),
486 )
487 .await
488 }
489
490 pub async fn try_from_client_and_connection<S>(
493 client: tokio_postgres::Client,
494 conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
495 ) -> ConnectionResult<Self>
496 where
497 S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
498 {
499 let (error_rx, notification_rx, shutdown_tx) = drive_connection(conn);
500
501 Self::setup(
502 client,
503 Some(error_rx),
504 Some(notification_rx),
505 Some(shutdown_tx),
506 Arc::new(std::sync::Mutex::new(
507 DynInstrumentation::default_instrumentation(),
508 )),
509 )
510 .await
511 }
512
513 async fn setup(
514 conn: tokio_postgres::Client,
515 connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
516 notification_rx: Option<mpsc::UnboundedReceiver<QueryResult<diesel::pg::PgNotification>>>,
517 shutdown_channel: Option<oneshot::Sender<()>>,
518 instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
519 ) -> ConnectionResult<Self> {
520 let mut conn = Self {
521 conn: Arc::new(conn),
522 stmt_cache: Arc::new(Mutex::new(StatementCache::new())),
523 transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
524 metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
525 connection_future,
526 notification_rx,
527 shutdown_channel,
528 instrumentation,
529 };
530 conn.set_config_options()
531 .await
532 .map_err(ConnectionError::CouldntSetupConfiguration)?;
533 Ok(conn)
534 }
535
536 pub fn cancel_token(&self) -> tokio_postgres::CancelToken {
538 self.conn.cancel_token()
539 }
540
541 async fn set_config_options(&mut self) -> QueryResult<()> {
542 use crate::run_query_dsl::RunQueryDsl;
543
544 futures_util::future::try_join(
545 diesel::sql_query("SET TIME ZONE 'UTC'").execute(self),
546 diesel::sql_query("SET CLIENT_ENCODING TO 'UTF8'").execute(self),
547 )
548 .await?;
549 Ok(())
550 }
551
552 fn run_with_connection_future<'a, R: 'a>(
553 &self,
554 future: impl Future<Output = QueryResult<R>> + Send + 'a,
555 ) -> BoxFuture<'a, QueryResult<R>> {
556 let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
557 drive_future(connection_future, future).boxed()
558 }
559
560 fn with_prepared_statement<'a, T, F, R>(
561 &self,
562 query: T,
563 callback: fn(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F,
564 ) -> BoxFuture<'a, QueryResult<R>>
565 where
566 T: QueryFragment<diesel::pg::Pg> + QueryId,
567 F: Future<Output = QueryResult<R>> + Send + 'a,
568 R: Send,
569 {
570 self.record_instrumentation(InstrumentationEvent::start_query(&diesel::debug_query(
571 &query,
572 )));
573 let mut query_builder = PgQueryBuilder::default();
581
582 let bind_data = construct_bind_data(&query);
583
584 self.with_prepared_statement_after_sql_built(
586 callback,
587 query.is_safe_to_cache_prepared(&Pg),
588 T::query_id(),
589 query.to_sql(&mut query_builder, &Pg),
590 query_builder,
591 bind_data,
592 )
593 }
594
595 fn with_prepared_statement_after_sql_built<'a, F, R>(
596 &self,
597 callback: fn(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F,
598 is_safe_to_cache_prepared: QueryResult<bool>,
599 query_id: Option<std::any::TypeId>,
600 to_sql_result: QueryResult<()>,
601 query_builder: PgQueryBuilder,
602 bind_data: BindData,
603 ) -> BoxFuture<'a, QueryResult<R>>
604 where
605 F: Future<Output = QueryResult<R>> + Send + 'a,
606 R: Send,
607 {
608 let raw_connection = self.conn.clone();
609 let stmt_cache = self.stmt_cache.clone();
610 let metadata_cache = self.metadata_cache.clone();
611 let tm = self.transaction_state.clone();
612 let instrumentation = self.instrumentation.clone();
613 let BindData {
614 collect_bind_result,
615 fake_oid_locations,
616 generated_oids,
617 mut bind_collector,
618 } = bind_data;
619
620 async move {
621 let sql = to_sql_result.map(|_| query_builder.finish())?;
622 let res = async {
623 let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
624 collect_bind_result?;
625 if let Some(ref unresolved_types) = generated_oids {
630 let metadata_cache = &mut *metadata_cache.lock().await;
631 let mut real_oids = HashMap::new();
632
633 for ((schema, lookup_type_name), (fake_oid, fake_array_oid)) in
634 unresolved_types
635 {
636 let cache_key = PgMetadataCacheKey::new(
640 schema.as_deref().map(Into::into),
641 lookup_type_name.into(),
642 );
643 let real_metadata = if let Some(type_metadata) =
644 metadata_cache.lookup_type(&cache_key)
645 {
646 type_metadata
647 } else {
648 let type_metadata =
649 lookup_type(schema.clone(), lookup_type_name.clone(), &raw_connection)
650 .await?;
651 metadata_cache.store_type(cache_key, type_metadata);
652
653 PgTypeMetadata::from_result(Ok(type_metadata))
654 };
655 let (real_oid, real_array_oid) = unwrap_oids(&real_metadata);
657 real_oids.extend([(*fake_oid, real_oid), (*fake_array_oid, real_array_oid)]);
658 }
659
660 for m in &mut bind_collector.metadata {
662 let (oid, array_oid) = unwrap_oids(m);
663 *m = PgTypeMetadata::new(
664 real_oids.get(&oid).copied().unwrap_or(oid),
665 real_oids.get(&array_oid).copied().unwrap_or(array_oid)
666 );
667 }
668 for (bind_index, byte_index) in fake_oid_locations {
670 replace_fake_oid(&mut bind_collector.binds, &real_oids, bind_index, byte_index)
671 .ok_or_else(|| {
672 Error::SerializationError(
673 format!("diesel_async failed to replace a type OID serialized in bind value {bind_index}").into(),
674 )
675 })?;
676 }
677 }
678 let stmt = {
679 let mut stmt_cache = stmt_cache.lock().await;
680 let helper = QueryFragmentHelper {
681 sql: sql.clone(),
682 safe_to_cache: is_safe_to_cache_prepared,
683 };
684 let instrumentation = Arc::clone(&instrumentation);
685 stmt_cache
686 .cached_statement_non_generic(
687 query_id,
688 &helper,
689 &Pg,
690 &bind_collector.metadata,
691 raw_connection.clone(),
692 prepare_statement_helper,
693 &mut move |event: InstrumentationEvent<'_>| {
694 instrumentation.lock().unwrap_or_else(|e| e.into_inner())
697 .on_connection_event(event);
698 },
699 )
700 .await?
701 .0
702 .clone()
703 };
704
705 let binds = bind_collector
706 .metadata
707 .into_iter()
708 .zip(bind_collector.binds)
709 .map(|(meta, bind)| ToSqlHelper(meta, bind))
710 .collect::<Vec<_>>();
711 callback(raw_connection, stmt.clone(), binds).await
712 };
713 let res = res.await;
714 let mut tm = tm.lock().await;
715 let r = update_transaction_manager_status(res, &mut tm);
716 instrumentation
717 .lock()
718 .unwrap_or_else(|p| p.into_inner())
719 .on_connection_event(InstrumentationEvent::finish_query(
720 &StrQueryHelper::new(&sql),
721 r.as_ref().err(),
722 ));
723
724 r
725 }
726 .boxed()
727 }
728
729 fn record_instrumentation(&self, event: InstrumentationEvent<'_>) {
730 self.instrumentation
731 .lock()
732 .unwrap_or_else(|p| p.into_inner())
733 .on_connection_event(event);
734 }
735
736 pub fn notifications_stream(
778 &mut self,
779 ) -> impl futures_core::Stream<Item = QueryResult<diesel::pg::PgNotification>> + '_ {
780 match &mut self.notification_rx {
781 None => Either::Left(futures_util::stream::pending()),
782 Some(rx) => Either::Right(futures_util::stream::unfold(rx, |rx| async {
783 rx.recv().await.map(move |item| (item, rx))
784 })),
785 }
786 }
787}
788
789struct BindData {
790 collect_bind_result: Result<(), Error>,
791 fake_oid_locations: Vec<(usize, usize)>,
792 generated_oids: GeneratedOidTypeMap,
793 bind_collector: RawBytesBindCollector<Pg>,
794}
795
796fn construct_bind_data(query: &dyn QueryFragment<diesel::pg::Pg>) -> BindData {
797 let mut bind_collector_0 = RawBytesBindCollector::<diesel::pg::Pg>::new();
807 let mut metadata_lookup_0 = PgAsyncMetadataLookup {
808 custom_oid: false,
809 generated_oids: None,
810 oid_generator: |_, _| (FAKE_OID, FAKE_OID),
811 };
812 let collect_bind_result_0 =
813 query.collect_binds(&mut bind_collector_0, &mut metadata_lookup_0, &Pg);
814 if metadata_lookup_0.custom_oid {
824 let mut max_oid = bind_collector_0
828 .metadata
829 .iter()
830 .flat_map(|t| {
831 [
832 t.oid().unwrap_or_default(),
833 t.array_oid().unwrap_or_default(),
834 ]
835 })
836 .max()
837 .unwrap_or_default();
838 let mut bind_collector_1 = RawBytesBindCollector::<diesel::pg::Pg>::new();
839 let mut metadata_lookup_1 = PgAsyncMetadataLookup {
840 custom_oid: false,
841 generated_oids: Some(HashMap::new()),
842 oid_generator: move |_, _| {
843 max_oid += 2;
844 (max_oid, max_oid + 1)
845 },
846 };
847 let collect_bind_result_1 =
848 query.collect_binds(&mut bind_collector_1, &mut metadata_lookup_1, &Pg);
849
850 assert_eq!(
851 bind_collector_0.binds.len(),
852 bind_collector_0.metadata.len()
853 );
854 let fake_oid_locations = std::iter::zip(
855 bind_collector_0
856 .binds
857 .iter()
858 .zip(&bind_collector_0.metadata),
859 &bind_collector_1.binds,
860 )
861 .enumerate()
862 .flat_map(|(bind_index, ((bytes_0, metadata_0), bytes_1))| {
863 let (bytes_0, bytes_1) = if matches!(metadata_0.oid(), Ok(FAKE_OID)) {
867 (
868 bytes_0.as_deref().unwrap_or_default(),
869 bytes_1.as_deref().unwrap_or_default(),
870 )
871 } else {
872 (&[] as &[_], &[] as &[_])
876 };
877 let lookup_map = metadata_lookup_1
878 .generated_oids
879 .as_ref()
880 .map(|map| {
881 map.values()
882 .flat_map(|(oid, array_oid)| [*oid, *array_oid])
883 .collect::<HashSet<_>>()
884 })
885 .unwrap_or_default();
886 std::iter::zip(
887 bytes_0.windows(std::mem::size_of_val(&FAKE_OID)),
888 bytes_1.windows(std::mem::size_of_val(&FAKE_OID)),
889 )
890 .enumerate()
891 .filter_map(move |(byte_index, (l, r))| {
892 let r_val = u32::from_be_bytes(r.try_into().expect("That's the right size"));
901 (l == FAKE_OID.to_be_bytes()
902 && r != FAKE_OID.to_be_bytes()
903 && lookup_map.contains(&r_val))
904 .then_some((bind_index, byte_index))
905 })
906 })
907 .collect::<Vec<_>>();
909 BindData {
910 collect_bind_result: collect_bind_result_0.and(collect_bind_result_1),
911 fake_oid_locations,
912 generated_oids: metadata_lookup_1.generated_oids,
913 bind_collector: bind_collector_1,
914 }
915 } else {
916 BindData {
917 collect_bind_result: collect_bind_result_0,
918 fake_oid_locations: Vec::new(),
919 generated_oids: None,
920 bind_collector: bind_collector_0,
921 }
922 }
923}
924
/// Maps the `(schema, type_name)` of each custom type requested during bind collection to the
/// generated placeholder `(oid, array_oid)` pair. `None` when OIDs are not being recorded.
type GeneratedOidTypeMap = Option<HashMap<(Option<String>, String), (u32, u32)>>;
926
927struct PgAsyncMetadataLookup<F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static> {
930 custom_oid: bool,
931 generated_oids: GeneratedOidTypeMap,
932 oid_generator: F,
933}
934
935impl<F> PgMetadataLookup for PgAsyncMetadataLookup<F>
936where
937 F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static,
938{
939 fn lookup_type(&mut self, type_name: &str, schema: Option<&str>) -> PgTypeMetadata {
940 self.custom_oid = true;
941
942 let oid = if let Some(map) = &mut self.generated_oids {
943 *map.entry((schema.map(ToOwned::to_owned), type_name.to_owned()))
944 .or_insert_with(|| (self.oid_generator)(type_name, schema))
945 } else {
946 (self.oid_generator)(type_name, schema)
947 };
948
949 PgTypeMetadata::from_result(Ok(oid))
950 }
951}
952
953async fn lookup_type(
954 schema: Option<String>,
955 type_name: String,
956 raw_connection: &tokio_postgres::Client,
957) -> QueryResult<(u32, u32)> {
958 let r = if let Some(schema) = schema.as_ref() {
959 raw_connection
960 .query_one(
961 "SELECT pg_type.oid, pg_type.typarray FROM pg_type \
962 INNER JOIN pg_namespace ON pg_type.typnamespace = pg_namespace.oid \
963 WHERE pg_type.typname = $1 AND pg_namespace.nspname = $2 \
964 LIMIT 1",
965 &[&type_name, schema],
966 )
967 .await
968 .map_err(ErrorHelper)?
969 } else {
970 raw_connection
971 .query_one(
972 "SELECT pg_type.oid, pg_type.typarray FROM pg_type \
973 WHERE pg_type.oid = quote_ident($1)::regtype::oid \
974 LIMIT 1",
975 &[&type_name],
976 )
977 .await
978 .map_err(ErrorHelper)?
979 };
980 Ok((r.get(0), r.get(1)))
981}
982
983fn unwrap_oids(metadata: &PgTypeMetadata) -> (u32, u32) {
984 let err_msg = "PgTypeMetadata is supposed to always be Ok here";
985 (
986 metadata.oid().expect(err_msg),
987 metadata.array_oid().expect(err_msg),
988 )
989}
990
/// Overwrites the 4-byte big-endian OID at `binds[bind_index][byte_index..]` with the real OID
/// that `real_oids` maps it to.
///
/// Returns `None` without modifying anything in any of these cases:
/// - the bind does not exist or is `NULL`;
/// - fewer than 4 bytes are available at `byte_index`;
/// - the stored OID has no mapping.
fn replace_fake_oid(
    binds: &mut [Option<Vec<u8>>],
    real_oids: &HashMap<u32, u32>,
    bind_index: usize,
    byte_index: usize,
) -> Option<()> {
    let buffer = binds.get_mut(bind_index).and_then(Option::as_mut)?;
    let slot: &mut [u8; 4] = buffer.get_mut(byte_index..)?.first_chunk_mut()?;

    let placeholder = u32::from_be_bytes(*slot);
    let real = *real_oids.get(&placeholder)?;
    *slot = real.to_be_bytes();
    Some(())
}
1007
1008async fn drive_future<R>(
1009 connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
1010 client_future: impl Future<Output = Result<R, diesel::result::Error>>,
1011) -> Result<R, diesel::result::Error> {
1012 if let Some(mut connection_future) = connection_future {
1013 let client_future = std::pin::pin!(client_future);
1014 let connection_future = std::pin::pin!(connection_future.recv());
1015 match futures_util::future::select(client_future, connection_future).await {
1016 Either::Left((res, _)) => res,
1017 Either::Right((Ok(e), _)) => Err(self::error_helper::from_tokio_postgres_error(e)),
1020 Either::Right((Err(e), _)) => Err(diesel::result::Error::DatabaseError(
1022 DatabaseErrorKind::UnableToSendCommand,
1023 Box::new(e.to_string()),
1024 )),
1025 }
1026 } else {
1027 client_future.await
1028 }
1029}
1030
1031fn drive_connection<S>(
1032 mut conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
1033) -> (
1034 broadcast::Receiver<Arc<tokio_postgres::Error>>,
1035 mpsc::UnboundedReceiver<QueryResult<diesel::pg::PgNotification>>,
1036 oneshot::Sender<()>,
1037)
1038where
1039 S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
1040{
1041 let (error_tx, error_rx) = tokio::sync::broadcast::channel(1);
1042 let (notification_tx, notification_rx) = tokio::sync::mpsc::unbounded_channel();
1043 let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
1044 let mut conn = futures_util::stream::poll_fn(move |cx| conn.poll_message(cx));
1045
1046 tokio::spawn(async move {
1047 loop {
1048 match futures_util::future::select(&mut shutdown_rx, conn.next()).await {
1049 Either::Left(_) | Either::Right((None, _)) => break,
1050 Either::Right((Some(Ok(tokio_postgres::AsyncMessage::Notification(notif))), _)) => {
1051 let _: Result<_, _> = notification_tx.send(Ok(diesel::pg::PgNotification {
1052 process_id: notif.process_id(),
1053 channel: notif.channel().to_owned(),
1054 payload: notif.payload().to_owned(),
1055 }));
1056 }
1057 Either::Right((Some(Ok(_)), _)) => {}
1058 Either::Right((Some(Err(e)), _)) => {
1059 let e = Arc::new(e);
1060 let _: Result<_, _> = error_tx.send(e.clone());
1061 let _: Result<_, _> =
1062 notification_tx.send(Err(error_helper::from_tokio_postgres_error(e)));
1063 break;
1064 }
1065 }
1066 }
1067 });
1068
1069 (error_rx, notification_rx, shutdown_tx)
1070}
1071
#[cfg(any(
    feature = "deadpool",
    feature = "bb8",
    feature = "mobc",
    feature = "r2d2"
))]
impl crate::pooled_connection::PoolableConnection for AsyncPgConnection {
    /// Broken when the transaction manager reports a broken transaction, or when the
    /// underlying client has been closed.
    fn is_broken(&mut self) -> bool {
        use crate::TransactionManager;

        if Self::TransactionManager::is_broken_transaction_manager(self) {
            return true;
        }
        self.conn.is_closed()
    }
}
1085
1086impl QueryFragmentForCachedStatement<Pg> for QueryFragmentHelper {
1087 fn construct_sql(&self, _backend: &Pg) -> QueryResult<String> {
1088 Ok(self.sql.clone())
1089 }
1090
1091 fn is_safe_to_cache_prepared(&self, _backend: &Pg) -> QueryResult<bool> {
1092 Ok(self.safe_to_cache)
1093 }
1094}
1095
// Integration tests for query pipelining. They need a reachable database whose URL is
// provided through the `DATABASE_URL` environment variable.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::run_query_dsl::RunQueryDsl;
    use diesel::sql_types::Integer;
    use diesel::IntoSql;
    use futures_util::future::try_join;
    use scoped_futures::ScopedFutureExt;

    /// Two query futures created from the same `&mut` connection before either is awaited
    /// must both resolve to their own results when joined.
    #[tokio::test]
    async fn pipelining() {
        let database_url =
            std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests");

        let mut conn = crate::AsyncPgConnection::establish(&database_url)
            .await
            .unwrap();

        let q1 = diesel::select(1_i32.into_sql::<Integer>());
        let q2 = diesel::select(2_i32.into_sql::<Integer>());

        // Both futures exist at the same time; neither borrows `conn` after creation.
        let f1 = q1.get_result::<i32>(&mut conn);
        let f2 = q2.get_result::<i32>(&mut conn);

        let (r1, r2) = try_join(f1, f2).await.unwrap();

        assert_eq!(r1, 1);
        assert_eq!(r2, 2);
    }

    /// Pipelining through a shared `&AsyncPgConnection`, with the queries spread across two
    /// helper functions whose futures are themselves joined.
    #[tokio::test]
    async fn pipelining_with_composed_futures() {
        let database_url =
            std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests");

        let conn = crate::AsyncPgConnection::establish(&database_url)
            .await
            .unwrap();

        async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
            let f1 = diesel::select(1_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
            let f2 = diesel::select(2_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);

            try_join(f1, f2).await
        }

        async fn fn34(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
            let f3 = diesel::select(3_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
            let f4 = diesel::select(4_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);

            try_join(f3, f4).await
        }

        let f12 = fn12(&conn);
        let f34 = fn34(&conn);

        let ((r1, r2), (r3, r4)) = try_join(f12, f34).await.unwrap();

        assert_eq!(r1, 1);
        assert_eq!(r2, 2);
        assert_eq!(r3, 3);
        assert_eq!(r4, 4);
    }

    /// Same as above, but inside a transaction and exercising `execute`, `load`,
    /// `get_result`, `get_results` and `first` concurrently. Afterwards the helpers are also
    /// run sequentially.
    #[tokio::test]
    async fn pipelining_with_composed_futures_and_transaction() {
        let database_url =
            std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests");

        let mut conn = crate::AsyncPgConnection::establish(&database_url)
            .await
            .unwrap();

        async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
            let f1 = diesel::select(1_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
            let f2 = diesel::select(2_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);

            try_join(f1, f2).await
        }

        async fn fn37(
            mut conn: &AsyncPgConnection,
        ) -> QueryResult<(usize, (Vec<i32>, (i32, (Vec<i32>, i32))))> {
            let f3 = diesel::select(0_i32.into_sql::<Integer>()).execute(&mut conn);
            let f4 = diesel::select(4_i32.into_sql::<Integer>()).load::<i32>(&mut conn);
            let f5 = diesel::select(5_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
            let f6 = diesel::select(6_i32.into_sql::<Integer>()).get_results::<i32>(&mut conn);
            let f7 = diesel::select(7_i32.into_sql::<Integer>()).first::<i32>(&mut conn);

            try_join(f3, try_join(f4, try_join(f5, try_join(f6, f7)))).await
        }

        conn.transaction(|conn| {
            async move {
                let f12 = fn12(conn);
                let f37 = fn37(conn);

                let ((r1, r2), (r3, (r4, (r5, (r6, r7))))) = try_join(f12, f37).await.unwrap();

                assert_eq!(r1, 1);
                assert_eq!(r2, 2);
                // `execute` of a single-row SELECT reports one affected row.
                assert_eq!(r3, 1);
                assert_eq!(r4, vec![4]);
                assert_eq!(r5, 5);
                assert_eq!(r6, vec![6]);
                assert_eq!(r7, 7);

                fn12(conn).await?;
                fn37(conn).await?;

                QueryResult::<_>::Ok(())
            }
            .scope_boxed()
        })
        .await
        .unwrap();
    }
}