1use self::error_helper::ErrorHelper;
8use self::row::PgRow;
9use self::serialize::ToSqlHelper;
10use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper};
11use crate::{AnsiTransactionManager, AsyncConnection, AsyncConnectionCore, SimpleAsyncConnection};
12use diesel::connection::statement_cache::{
13 PrepareForCache, QueryFragmentForCachedStatement, StatementCache,
14};
15use diesel::connection::StrQueryHelper;
16use diesel::connection::{CacheSize, Instrumentation};
17use diesel::connection::{DynInstrumentation, InstrumentationEvent};
18use diesel::pg::{
19 Pg, PgMetadataCache, PgMetadataCacheKey, PgMetadataLookup, PgQueryBuilder, PgTypeMetadata,
20};
21use diesel::query_builder::bind_collector::RawBytesBindCollector;
22use diesel::query_builder::{AsQuery, QueryBuilder, QueryFragment, QueryId};
23use diesel::result::{DatabaseErrorKind, Error};
24use diesel::{ConnectionError, ConnectionResult, QueryResult};
25use futures_core::future::BoxFuture;
26use futures_core::stream::BoxStream;
27use futures_util::future::Either;
28use futures_util::stream::TryStreamExt;
29use futures_util::TryFutureExt;
30use futures_util::{FutureExt, StreamExt};
31use std::collections::{HashMap, HashSet};
32use std::future::Future;
33use std::sync::Arc;
34use tokio::sync::{broadcast, mpsc, oneshot, Mutex};
35use tokio_postgres::types::ToSql;
36use tokio_postgres::types::Type;
37use tokio_postgres::Statement;
38
39pub use self::transaction_builder::TransactionBuilder;
40
41mod error_helper;
42mod row;
43mod serialize;
44mod transaction_builder;
45
/// Placeholder OID serialized for custom types while bind values are collected without
/// database access; the real OIDs are patched into the metadata and the serialized bytes
/// before the statement is prepared.
const FAKE_OID: u32 = u32::MIN;
47
/// A connection to a PostgreSQL database built on top of `tokio_postgres`.
///
/// All mutable state sits behind `tokio::sync::Mutex`es, so queries can also be issued
/// through a shared reference (`&AsyncPgConnection`); this is what allows several queries
/// to be pipelined on one connection.
pub struct AsyncPgConnection {
    /// Client used to prepare and execute statements.
    conn: tokio_postgres::Client,
    /// Cache of prepared statements.
    stmt_cache: Mutex<StatementCache<diesel::pg::Pg, Statement>>,
    /// Transaction bookkeeping shared by all queries issued on this connection.
    transaction_state: Mutex<AnsiTransactionManager>,
    /// Cache of OIDs resolved for custom types.
    metadata_cache: Mutex<PgMetadataCache>,
    /// Receives the error reported by the background connection task; `None` when this type
    /// does not drive the `tokio_postgres::Connection` itself.
    connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
    /// Notifications forwarded by the background connection task, if there is one.
    notification_rx: Option<mpsc::UnboundedReceiver<QueryResult<diesel::pg::PgNotification>>>,
    /// Signalled on drop to stop the background connection task, if there is one.
    shutdown_channel: Option<oneshot::Sender<()>>,
    /// Instrumentation sink; wrapped in an `Arc` so in-flight query futures can hold a clone.
    instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
}
178
179impl SimpleAsyncConnection for AsyncPgConnection {
180 async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
181 SimpleAsyncConnection::batch_execute(&mut &*self, query).await
182 }
183}
184
185impl SimpleAsyncConnection for &AsyncPgConnection {
186 async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
187 self.record_instrumentation(InstrumentationEvent::start_query(&StrQueryHelper::new(
188 query,
189 )));
190 let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
191 let batch_execute = self
192 .conn
193 .batch_execute(query)
194 .map_err(ErrorHelper)
195 .map_err(Into::into);
196
197 let r = drive_future(connection_future, batch_execute).await;
198 let r = {
199 let mut transaction_manager = self.transaction_state.lock().await;
200 update_transaction_manager_status(r, &mut transaction_manager)
201 };
202 self.record_instrumentation(InstrumentationEvent::finish_query(
203 &StrQueryHelper::new(query),
204 r.as_ref().err(),
205 ));
206 r
207 }
208}
209
210impl AsyncConnectionCore for AsyncPgConnection {
211 type LoadFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<Self::Stream<'conn, 'query>>>;
213 type ExecuteFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<usize>>;
214 type Stream<'conn, 'query> = BoxStream<'static, QueryResult<PgRow>>;
215 type Row<'conn, 'query> = PgRow;
216 type Backend = diesel::pg::Pg;
217
218 fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
219 where
220 T: AsQuery + 'query,
221 T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
222 {
223 let query = source.as_query();
224 let load_future = self.with_prepared_statement(query, load_prepared);
225
226 self.run_with_connection_future(load_future)
227 }
228
229 fn execute_returning_count<'conn, 'query, T>(
230 &'conn mut self,
231 source: T,
232 ) -> Self::ExecuteFuture<'conn, 'query>
233 where
234 T: QueryFragment<Self::Backend> + QueryId + 'query,
235 {
236 let execute = self.with_prepared_statement(source, execute_prepared);
237 self.run_with_connection_future(execute)
238 }
239}
240
241impl<'a> AsyncConnectionCore for &'a AsyncPgConnection {
244 type LoadFuture<'conn, 'query> = BoxFuture<'a, QueryResult<Self::Stream<'conn, 'query>>>;
245 type ExecuteFuture<'conn, 'query> = BoxFuture<'a, QueryResult<usize>>;
246 type Stream<'conn, 'query> = BoxStream<'static, QueryResult<PgRow>>;
247 type Row<'conn, 'query> = PgRow;
248 type Backend = diesel::pg::Pg;
249
250 fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
251 where
252 T: AsQuery + 'query,
253 T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
254 {
255 let query = source.as_query();
256 let load_future = self.with_prepared_statement(query, load_prepared);
257
258 self.run_with_connection_future(load_future)
259 }
260
261 fn execute_returning_count<'conn, 'query, T>(
262 &'conn mut self,
263 source: T,
264 ) -> Self::ExecuteFuture<'conn, 'query>
265 where
266 T: QueryFragment<Self::Backend> + QueryId + 'query,
267 {
268 let execute = self.with_prepared_statement(source, execute_prepared);
269 self.run_with_connection_future(execute)
270 }
271}
272
273impl AsyncConnection for AsyncPgConnection {
274 type TransactionManager = AnsiTransactionManager;
275
276 async fn establish(database_url: &str) -> ConnectionResult<Self> {
277 let mut instrumentation = DynInstrumentation::default_instrumentation();
278 instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
279 database_url,
280 ));
281 let instrumentation = Arc::new(std::sync::Mutex::new(instrumentation));
282 let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls)
283 .await
284 .map_err(ErrorHelper)?;
285
286 let (error_rx, notification_rx, shutdown_tx) = drive_connection(connection);
287
288 let r = Self::setup(
289 client,
290 Some(error_rx),
291 Some(notification_rx),
292 Some(shutdown_tx),
293 Arc::clone(&instrumentation),
294 )
295 .await;
296
297 instrumentation
298 .lock()
299 .unwrap_or_else(|e| e.into_inner())
300 .on_connection_event(InstrumentationEvent::finish_establish_connection(
301 database_url,
302 r.as_ref().err(),
303 ));
304 r
305 }
306
307 fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
308 self.transaction_state.get_mut()
309 }
310
311 fn instrumentation(&mut self) -> &mut dyn Instrumentation {
312 if let Some(instrumentation) = Arc::get_mut(&mut self.instrumentation) {
316 &mut **(instrumentation.get_mut().unwrap_or_else(|p| p.into_inner()))
317 } else {
318 panic!("Cannot access shared instrumentation")
319 }
320 }
321
322 fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
323 self.instrumentation = Arc::new(std::sync::Mutex::new(instrumentation.into()));
324 }
325
326 fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
327 self.stmt_cache.get_mut().set_cache_size(size)
328 }
329}
330
331impl Drop for AsyncPgConnection {
332 fn drop(&mut self) {
333 if let Some(tx) = self.shutdown_channel.take() {
334 let _ = tx.send(());
335 }
336 }
337}
338
339async fn load_prepared(
340 conn: &tokio_postgres::Client,
341 stmt: Statement,
342 binds: Vec<ToSqlHelper>,
343) -> QueryResult<BoxStream<'static, QueryResult<PgRow>>> {
344 let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?;
345
346 Ok(res
347 .map_err(|e| diesel::result::Error::from(ErrorHelper(e)))
348 .map_ok(PgRow::new)
349 .boxed())
350}
351
352async fn execute_prepared(
353 conn: &tokio_postgres::Client,
354 stmt: Statement,
355 binds: Vec<ToSqlHelper>,
356) -> QueryResult<usize> {
357 let binds = binds
358 .iter()
359 .map(|b| b as &(dyn ToSql + Sync))
360 .collect::<Vec<_>>();
361
362 let res = tokio_postgres::Client::execute(conn, &stmt, &binds as &[_])
363 .await
364 .map_err(ErrorHelper)?;
365 res.try_into()
366 .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))
367}
368
369#[inline(always)]
370fn update_transaction_manager_status<T>(
371 query_result: QueryResult<T>,
372 transaction_manager: &mut AnsiTransactionManager,
373) -> QueryResult<T> {
374 if let Err(diesel::result::Error::DatabaseError(DatabaseErrorKind::SerializationFailure, _)) =
375 query_result
376 {
377 if !transaction_manager.is_commit {
378 transaction_manager
379 .status
380 .set_requires_rollback_maybe_up_to_top_level(true);
381 }
382 }
383 query_result
384}
385
386fn prepare_statement_helper<'conn>(
387 conn: &'conn tokio_postgres::Client,
388 sql: &str,
389 _is_for_cache: PrepareForCache,
390 metadata: &[PgTypeMetadata],
391) -> CallbackHelper<
392 impl Future<Output = QueryResult<(Statement, &'conn tokio_postgres::Client)>> + Send,
393> {
394 let bind_types = metadata
395 .iter()
396 .map(type_from_oid)
397 .collect::<QueryResult<Vec<_>>>();
398 let sql = sql.to_string();
405 CallbackHelper(async move {
406 let bind_types = bind_types?;
407 let stmt = conn
408 .prepare_typed(&sql, &bind_types)
409 .await
410 .map_err(ErrorHelper);
411 Ok((stmt?, conn))
412 })
413}
414
415fn type_from_oid(t: &PgTypeMetadata) -> QueryResult<Type> {
416 let oid = t
417 .oid()
418 .map_err(|e| diesel::result::Error::SerializationError(Box::new(e) as _))?;
419
420 if let Some(tpe) = Type::from_oid(oid) {
421 return Ok(tpe);
422 }
423
424 Ok(Type::new(
425 format!("diesel_custom_type_{oid}"),
426 oid,
427 tokio_postgres::types::Kind::Simple,
428 "public".into(),
429 ))
430}
431
432impl AsyncPgConnection {
433 pub fn build_transaction(&mut self) -> TransactionBuilder<'_, Self> {
460 TransactionBuilder::new(self)
461 }
462
463 pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult<Self> {
465 Self::setup(
466 conn,
467 None,
468 None,
469 None,
470 Arc::new(std::sync::Mutex::new(
471 DynInstrumentation::default_instrumentation(),
472 )),
473 )
474 .await
475 }
476
477 pub async fn try_from_client_and_connection<S>(
480 client: tokio_postgres::Client,
481 conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
482 ) -> ConnectionResult<Self>
483 where
484 S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
485 {
486 let (error_rx, notification_rx, shutdown_tx) = drive_connection(conn);
487
488 Self::setup(
489 client,
490 Some(error_rx),
491 Some(notification_rx),
492 Some(shutdown_tx),
493 Arc::new(std::sync::Mutex::new(
494 DynInstrumentation::default_instrumentation(),
495 )),
496 )
497 .await
498 }
499
500 async fn setup(
501 conn: tokio_postgres::Client,
502 connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
503 notification_rx: Option<mpsc::UnboundedReceiver<QueryResult<diesel::pg::PgNotification>>>,
504 shutdown_channel: Option<oneshot::Sender<()>>,
505 instrumentation: Arc<std::sync::Mutex<DynInstrumentation>>,
506 ) -> ConnectionResult<Self> {
507 let mut conn = Self {
508 conn,
509 stmt_cache: Mutex::new(StatementCache::new()),
510 transaction_state: Mutex::new(AnsiTransactionManager::default()),
511 metadata_cache: Mutex::new(PgMetadataCache::new()),
512 connection_future,
513 notification_rx,
514 shutdown_channel,
515 instrumentation,
516 };
517 conn.set_config_options()
518 .await
519 .map_err(ConnectionError::CouldntSetupConfiguration)?;
520 Ok(conn)
521 }
522
523 pub fn cancel_token(&self) -> tokio_postgres::CancelToken {
525 self.conn.cancel_token()
526 }
527
528 async fn set_config_options(&mut self) -> QueryResult<()> {
529 use crate::run_query_dsl::RunQueryDsl;
530
531 futures_util::future::try_join(
532 diesel::sql_query("SET TIME ZONE 'UTC'").execute(&mut &*self),
533 diesel::sql_query("SET CLIENT_ENCODING TO 'UTF8'").execute(&mut &*self),
534 )
535 .await?;
536 Ok(())
537 }
538
539 fn run_with_connection_future<'a, R: 'a>(
540 &self,
541 future: impl Future<Output = QueryResult<R>> + Send + 'a,
542 ) -> BoxFuture<'a, QueryResult<R>> {
543 let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
544 drive_future(connection_future, future).boxed()
545 }
546
547 fn with_prepared_statement<'a, T, F, R>(
548 &'a self,
549 query: T,
550 callback: fn(&'a tokio_postgres::Client, Statement, Vec<ToSqlHelper>) -> F,
551 ) -> BoxFuture<'a, QueryResult<R>>
552 where
553 T: QueryFragment<diesel::pg::Pg> + QueryId,
554 F: Future<Output = QueryResult<R>> + Send + 'a,
555 R: Send,
556 {
557 self.record_instrumentation(InstrumentationEvent::start_query(&diesel::debug_query(
558 &query,
559 )));
560 let mut query_builder = PgQueryBuilder::default();
568
569 let bind_data = construct_bind_data(&query);
570
571 self.with_prepared_statement_after_sql_built(
573 callback,
574 query.is_safe_to_cache_prepared(&Pg),
575 T::query_id(),
576 query.to_sql(&mut query_builder, &Pg),
577 query_builder,
578 bind_data,
579 )
580 }
581
582 fn with_prepared_statement_after_sql_built<'a, F, R>(
583 &'a self,
584 callback: fn(&'a tokio_postgres::Client, Statement, Vec<ToSqlHelper>) -> F,
585 is_safe_to_cache_prepared: QueryResult<bool>,
586 query_id: Option<std::any::TypeId>,
587 to_sql_result: QueryResult<()>,
588 query_builder: PgQueryBuilder,
589 bind_data: BindData,
590 ) -> BoxFuture<'a, QueryResult<R>>
591 where
592 F: Future<Output = QueryResult<R>> + Send + 'a,
593 R: Send,
594 {
595 let raw_connection = &self.conn;
596 let stmt_cache = &self.stmt_cache;
597 let metadata_cache = &self.metadata_cache;
598 let tm = &self.transaction_state;
599 let instrumentation = self.instrumentation.clone();
600 let BindData {
601 collect_bind_result,
602 fake_oid_locations,
603 generated_oids,
604 mut bind_collector,
605 } = bind_data;
606
607 async move {
608 let sql = to_sql_result.map(|_| query_builder.finish())?;
609 let res = async {
610 let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
611 collect_bind_result?;
612 if let Some(ref unresolved_types) = generated_oids {
617 let metadata_cache = &mut *metadata_cache.lock().await;
618 let mut real_oids = HashMap::new();
619
620 for ((schema, lookup_type_name), (fake_oid, fake_array_oid)) in
621 unresolved_types
622 {
623 let cache_key = PgMetadataCacheKey::new(
627 schema.as_deref().map(Into::into),
628 lookup_type_name.into(),
629 );
630 let real_metadata = if let Some(type_metadata) =
631 metadata_cache.lookup_type(&cache_key)
632 {
633 type_metadata
634 } else {
635 let type_metadata =
636 lookup_type(schema.clone(), lookup_type_name.clone(), raw_connection)
637 .await?;
638 metadata_cache.store_type(cache_key, type_metadata);
639
640 PgTypeMetadata::from_result(Ok(type_metadata))
641 };
642 let (real_oid, real_array_oid) = unwrap_oids(&real_metadata);
644 real_oids.extend([(*fake_oid, real_oid), (*fake_array_oid, real_array_oid)]);
645 }
646
647 for m in &mut bind_collector.metadata {
649 let (oid, array_oid) = unwrap_oids(m);
650 *m = PgTypeMetadata::new(
651 real_oids.get(&oid).copied().unwrap_or(oid),
652 real_oids.get(&array_oid).copied().unwrap_or(array_oid)
653 );
654 }
655 for (bind_index, byte_index) in fake_oid_locations {
657 replace_fake_oid(&mut bind_collector.binds, &real_oids, bind_index, byte_index)
658 .ok_or_else(|| {
659 Error::SerializationError(
660 format!("diesel_async failed to replace a type OID serialized in bind value {bind_index}").into(),
661 )
662 })?;
663 }
664 }
665 let stmt = {
666 let mut stmt_cache = stmt_cache.lock().await;
667 let helper = QueryFragmentHelper {
668 sql: sql.clone(),
669 safe_to_cache: is_safe_to_cache_prepared,
670 };
671 let instrumentation = Arc::clone(&instrumentation);
672 stmt_cache
673 .cached_statement_non_generic(
674 query_id,
675 &helper,
676 &Pg,
677 &bind_collector.metadata,
678 raw_connection,
679 prepare_statement_helper,
680 &mut move |event: InstrumentationEvent<'_>| {
681 instrumentation.lock().unwrap_or_else(|e| e.into_inner())
684 .on_connection_event(event);
685 },
686 )
687 .await?
688 .0
689 .clone()
690 };
691
692 let binds = bind_collector
693 .metadata
694 .into_iter()
695 .zip(bind_collector.binds)
696 .map(|(meta, bind)| ToSqlHelper(meta, bind))
697 .collect::<Vec<_>>();
698 callback(raw_connection, stmt.clone(), binds).await
699 };
700 let res = res.await;
701 let mut tm = tm.lock().await;
702 let r = update_transaction_manager_status(res, &mut tm);
703 instrumentation
704 .lock()
705 .unwrap_or_else(|p| p.into_inner())
706 .on_connection_event(InstrumentationEvent::finish_query(
707 &StrQueryHelper::new(&sql),
708 r.as_ref().err(),
709 ));
710
711 r
712 }
713 .boxed()
714 }
715
716 fn record_instrumentation(&self, event: InstrumentationEvent<'_>) {
717 self.instrumentation
718 .lock()
719 .unwrap_or_else(|p| p.into_inner())
720 .on_connection_event(event);
721 }
722
723 pub fn notifications_stream(
765 &mut self,
766 ) -> impl futures_core::Stream<Item = QueryResult<diesel::pg::PgNotification>> + '_ {
767 match &mut self.notification_rx {
768 None => Either::Left(futures_util::stream::pending()),
769 Some(rx) => Either::Right(futures_util::stream::unfold(rx, |rx| async {
770 rx.recv().await.map(move |item| (item, rx))
771 })),
772 }
773 }
774}
775
776struct BindData {
777 collect_bind_result: Result<(), Error>,
778 fake_oid_locations: Vec<(usize, usize)>,
779 generated_oids: GeneratedOidTypeMap,
780 bind_collector: RawBytesBindCollector<Pg>,
781}
782
783fn construct_bind_data(query: &dyn QueryFragment<diesel::pg::Pg>) -> BindData {
784 let mut bind_collector_0 = RawBytesBindCollector::<diesel::pg::Pg>::new();
794 let mut metadata_lookup_0 = PgAsyncMetadataLookup {
795 custom_oid: false,
796 generated_oids: None,
797 oid_generator: |_, _| (FAKE_OID, FAKE_OID),
798 };
799 let collect_bind_result_0 =
800 query.collect_binds(&mut bind_collector_0, &mut metadata_lookup_0, &Pg);
801 if metadata_lookup_0.custom_oid {
811 let mut max_oid = bind_collector_0
815 .metadata
816 .iter()
817 .flat_map(|t| {
818 [
819 t.oid().unwrap_or_default(),
820 t.array_oid().unwrap_or_default(),
821 ]
822 })
823 .max()
824 .unwrap_or_default();
825 let mut bind_collector_1 = RawBytesBindCollector::<diesel::pg::Pg>::new();
826 let mut metadata_lookup_1 = PgAsyncMetadataLookup {
827 custom_oid: false,
828 generated_oids: Some(HashMap::new()),
829 oid_generator: move |_, _| {
830 max_oid += 2;
831 (max_oid, max_oid + 1)
832 },
833 };
834 let collect_bind_result_1 =
835 query.collect_binds(&mut bind_collector_1, &mut metadata_lookup_1, &Pg);
836
837 assert_eq!(
838 bind_collector_0.binds.len(),
839 bind_collector_0.metadata.len()
840 );
841 let fake_oid_locations = std::iter::zip(
842 bind_collector_0
843 .binds
844 .iter()
845 .zip(&bind_collector_0.metadata),
846 &bind_collector_1.binds,
847 )
848 .enumerate()
849 .flat_map(|(bind_index, ((bytes_0, metadata_0), bytes_1))| {
850 let (bytes_0, bytes_1) = if matches!(metadata_0.oid(), Ok(FAKE_OID)) {
854 (
855 bytes_0.as_deref().unwrap_or_default(),
856 bytes_1.as_deref().unwrap_or_default(),
857 )
858 } else {
859 (&[] as &[_], &[] as &[_])
863 };
864 let lookup_map = metadata_lookup_1
865 .generated_oids
866 .as_ref()
867 .map(|map| {
868 map.values()
869 .flat_map(|(oid, array_oid)| [*oid, *array_oid])
870 .collect::<HashSet<_>>()
871 })
872 .unwrap_or_default();
873 std::iter::zip(
874 bytes_0.windows(std::mem::size_of_val(&FAKE_OID)),
875 bytes_1.windows(std::mem::size_of_val(&FAKE_OID)),
876 )
877 .enumerate()
878 .filter_map(move |(byte_index, (l, r))| {
879 let r_val = u32::from_be_bytes(r.try_into().expect("That's the right size"));
888 (l == FAKE_OID.to_be_bytes()
889 && r != FAKE_OID.to_be_bytes()
890 && lookup_map.contains(&r_val))
891 .then_some((bind_index, byte_index))
892 })
893 })
894 .collect::<Vec<_>>();
896 BindData {
897 collect_bind_result: collect_bind_result_0.and(collect_bind_result_1),
898 fake_oid_locations,
899 generated_oids: metadata_lookup_1.generated_oids,
900 bind_collector: bind_collector_1,
901 }
902 } else {
903 BindData {
904 collect_bind_result: collect_bind_result_0,
905 fake_oid_locations: Vec::new(),
906 generated_oids: None,
907 bind_collector: bind_collector_0,
908 }
909 }
910}
911
/// Placeholder OIDs handed out per custom type, keyed by `(schema, type_name)` and mapping to
/// `(oid, array_oid)`. `None` means lookups are not memoized.
type GeneratedOidTypeMap = Option<HashMap<(Option<String>, String), (u32, u32)>>;
913
914struct PgAsyncMetadataLookup<F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static> {
917 custom_oid: bool,
918 generated_oids: GeneratedOidTypeMap,
919 oid_generator: F,
920}
921
922impl<F> PgMetadataLookup for PgAsyncMetadataLookup<F>
923where
924 F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static,
925{
926 fn lookup_type(&mut self, type_name: &str, schema: Option<&str>) -> PgTypeMetadata {
927 self.custom_oid = true;
928
929 let oid = if let Some(map) = &mut self.generated_oids {
930 *map.entry((schema.map(ToOwned::to_owned), type_name.to_owned()))
931 .or_insert_with(|| (self.oid_generator)(type_name, schema))
932 } else {
933 (self.oid_generator)(type_name, schema)
934 };
935
936 PgTypeMetadata::from_result(Ok(oid))
937 }
938}
939
940async fn lookup_type(
941 schema: Option<String>,
942 type_name: String,
943 raw_connection: &tokio_postgres::Client,
944) -> QueryResult<(u32, u32)> {
945 let r = if let Some(schema) = schema.as_ref() {
946 raw_connection
947 .query_one(
948 "SELECT pg_type.oid, pg_type.typarray FROM pg_type \
949 INNER JOIN pg_namespace ON pg_type.typnamespace = pg_namespace.oid \
950 WHERE pg_type.typname = $1 AND pg_namespace.nspname = $2 \
951 LIMIT 1",
952 &[&type_name, schema],
953 )
954 .await
955 .map_err(ErrorHelper)?
956 } else {
957 raw_connection
958 .query_one(
959 "SELECT pg_type.oid, pg_type.typarray FROM pg_type \
960 WHERE pg_type.oid = quote_ident($1)::regtype::oid \
961 LIMIT 1",
962 &[&type_name],
963 )
964 .await
965 .map_err(ErrorHelper)?
966 };
967 Ok((r.get(0), r.get(1)))
968}
969
970fn unwrap_oids(metadata: &PgTypeMetadata) -> (u32, u32) {
971 let err_msg = "PgTypeMetadata is supposed to always be Ok here";
972 (
973 metadata.oid().expect(err_msg),
974 metadata.array_oid().expect(err_msg),
975 )
976}
977
/// Overwrites the big-endian placeholder OID stored at `binds[bind_index][byte_index..+4]`
/// with its real OID from `real_oids`.
///
/// Returns `None`, leaving `binds` untouched, if any of these holds:
/// - the bind is missing or NULL;
/// - fewer than four bytes remain at `byte_index`;
/// - the stored value has no entry in `real_oids`.
fn replace_fake_oid(
    binds: &mut [Option<Vec<u8>>],
    real_oids: &HashMap<u32, u32>,
    bind_index: usize,
    byte_index: usize,
) -> Option<()> {
    let bytes = binds.get_mut(bind_index)?.as_mut()?;
    let slot = bytes.get_mut(byte_index..)?.first_chunk_mut::<4>()?;
    let placeholder = u32::from_be_bytes(*slot);
    let real = *real_oids.get(&placeholder)?;
    *slot = real.to_be_bytes();
    Some(())
}
994
995async fn drive_future<R>(
996 connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
997 client_future: impl Future<Output = Result<R, diesel::result::Error>>,
998) -> Result<R, diesel::result::Error> {
999 if let Some(mut connection_future) = connection_future {
1000 let client_future = std::pin::pin!(client_future);
1001 let connection_future = std::pin::pin!(connection_future.recv());
1002 match futures_util::future::select(client_future, connection_future).await {
1003 Either::Left((res, _)) => res,
1004 Either::Right((Ok(e), _)) => Err(self::error_helper::from_tokio_postgres_error(e)),
1007 Either::Right((Err(e), _)) => Err(diesel::result::Error::DatabaseError(
1009 DatabaseErrorKind::UnableToSendCommand,
1010 Box::new(e.to_string()),
1011 )),
1012 }
1013 } else {
1014 client_future.await
1015 }
1016}
1017
1018fn drive_connection<S>(
1019 mut conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
1020) -> (
1021 broadcast::Receiver<Arc<tokio_postgres::Error>>,
1022 mpsc::UnboundedReceiver<QueryResult<diesel::pg::PgNotification>>,
1023 oneshot::Sender<()>,
1024)
1025where
1026 S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
1027{
1028 let (error_tx, error_rx) = tokio::sync::broadcast::channel(1);
1029 let (notification_tx, notification_rx) = tokio::sync::mpsc::unbounded_channel();
1030 let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel();
1031 let mut conn = futures_util::stream::poll_fn(move |cx| conn.poll_message(cx));
1032
1033 tokio::spawn(async move {
1034 loop {
1035 match futures_util::future::select(&mut shutdown_rx, conn.next()).await {
1036 Either::Left(_) | Either::Right((None, _)) => break,
1037 Either::Right((Some(Ok(tokio_postgres::AsyncMessage::Notification(notif))), _)) => {
1038 let _: Result<_, _> = notification_tx.send(Ok(diesel::pg::PgNotification {
1039 process_id: notif.process_id(),
1040 channel: notif.channel().to_owned(),
1041 payload: notif.payload().to_owned(),
1042 }));
1043 }
1044 Either::Right((Some(Ok(_)), _)) => {}
1045 Either::Right((Some(Err(e)), _)) => {
1046 let e = Arc::new(e);
1047 let _: Result<_, _> = error_tx.send(e.clone());
1048 let _: Result<_, _> =
1049 notification_tx.send(Err(error_helper::from_tokio_postgres_error(e)));
1050 break;
1051 }
1052 }
1053 }
1054 });
1055
1056 (error_rx, notification_rx, shutdown_tx)
1057}
1058
#[cfg(any(
    feature = "deadpool",
    feature = "bb8",
    feature = "mobc",
    feature = "r2d2"
))]
impl crate::pooled_connection::PoolableConnection for AsyncPgConnection {
    /// A pooled connection is unusable when its transaction manager is broken or the
    /// underlying client has been closed.
    fn is_broken(&mut self) -> bool {
        use crate::TransactionManager;

        let transaction_broken = Self::TransactionManager::is_broken_transaction_manager(self);
        transaction_broken || self.conn.is_closed()
    }
}
1072
1073impl QueryFragmentForCachedStatement<Pg> for QueryFragmentHelper {
1074 fn construct_sql(&self, _backend: &Pg) -> QueryResult<String> {
1075 Ok(self.sql.clone())
1076 }
1077
1078 fn is_safe_to_cache_prepared(&self, _backend: &Pg) -> QueryResult<bool> {
1079 Ok(self.safe_to_cache)
1080 }
1081}
1082
#[cfg(test)]
mod tests {
    use super::*;
    use crate::run_query_dsl::RunQueryDsl;
    use diesel::sql_types::Integer;
    use diesel::IntoSql;
    use futures_util::future::try_join;
    use scoped_futures::ScopedFutureExt;

    /// Connects to the database named by the `DATABASE_URL` environment variable.
    ///
    /// Panics if the variable is unset or the connection cannot be established.
    async fn connect() -> AsyncPgConnection {
        let database_url =
            std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests");
        crate::AsyncPgConnection::establish(&database_url)
            .await
            .unwrap()
    }

    /// Two queries created through the same `&AsyncPgConnection` can be awaited together.
    #[tokio::test]
    async fn pipelining() {
        let conn = connect().await;

        let one = diesel::select(1_i32.into_sql::<Integer>()).get_result::<i32>(&mut &conn);
        let two = diesel::select(2_i32.into_sql::<Integer>()).get_result::<i32>(&mut &conn);

        let (r1, r2) = try_join(one, two).await.unwrap();

        assert_eq!(r1, 1);
        assert_eq!(r2, 2);
    }

    /// Futures that themselves join pipelined queries can be joined again.
    #[tokio::test]
    async fn pipelining_with_composed_futures() {
        let conn = connect().await;

        async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
            let f1 = diesel::select(1_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
            let f2 = diesel::select(2_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
            try_join(f1, f2).await
        }

        async fn fn34(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
            let f3 = diesel::select(3_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
            let f4 = diesel::select(4_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
            try_join(f3, f4).await
        }

        let ((r1, r2), (r3, r4)) = try_join(fn12(&conn), fn34(&conn)).await.unwrap();

        assert_eq!(r1, 1);
        assert_eq!(r2, 2);
        assert_eq!(r3, 3);
        assert_eq!(r4, 4);
    }

    /// Pipelined query futures of every kind also work inside a transaction.
    #[tokio::test]
    async fn pipelining_with_composed_futures_and_transaction() {
        let mut conn = connect().await;

        async fn fn12(mut conn: &AsyncPgConnection) -> QueryResult<(i32, i32)> {
            let f1 = diesel::select(1_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
            let f2 = diesel::select(2_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
            try_join(f1, f2).await
        }

        async fn fn37(
            mut conn: &AsyncPgConnection,
        ) -> QueryResult<(usize, (Vec<i32>, (i32, (Vec<i32>, i32))))> {
            let f3 = diesel::select(0_i32.into_sql::<Integer>()).execute(&mut conn);
            let f4 = diesel::select(4_i32.into_sql::<Integer>()).load::<i32>(&mut conn);
            let f5 = diesel::select(5_i32.into_sql::<Integer>()).get_result::<i32>(&mut conn);
            let f6 = diesel::select(6_i32.into_sql::<Integer>()).get_results::<i32>(&mut conn);
            let f7 = diesel::select(7_i32.into_sql::<Integer>()).first::<i32>(&mut conn);
            try_join(f3, try_join(f4, try_join(f5, try_join(f6, f7)))).await
        }

        conn.transaction(|conn| {
            async move {
                let ((r1, r2), (r3, (r4, (r5, (r6, r7))))) =
                    try_join(fn12(conn), fn37(conn)).await.unwrap();

                assert_eq!(r1, 1);
                assert_eq!(r2, 2);
                assert_eq!(r3, 1);
                assert_eq!(r4, vec![4]);
                assert_eq!(r5, 5);
                assert_eq!(r6, vec![6]);
                assert_eq!(r7, 7);

                fn12(conn).await?;
                fn37(conn).await?;

                QueryResult::<_>::Ok(())
            }
            .scope_boxed()
        })
        .await
        .unwrap();
    }
}