use self::error_helper::ErrorHelper;
use self::row::PgRow;
use self::serialize::ToSqlHelper;
use crate::statement_cache::CacheSize;
use crate::statement_cache::{PrepareForCache, QueryFragmentForCachedStatement, StatementCache};
use crate::stmt_cache::{CallbackHelper, QueryFragmentHelper};
use crate::{AnsiTransactionManager, AsyncConnection, SimpleAsyncConnection};
use diesel::connection::Instrumentation;
use diesel::connection::InstrumentationEvent;
use diesel::connection::StrQueryHelper;
use diesel::pg::{
    Pg, PgMetadataCache, PgMetadataCacheKey, PgMetadataLookup, PgQueryBuilder, PgTypeMetadata,
};
use diesel::query_builder::bind_collector::RawBytesBindCollector;
use diesel::query_builder::{AsQuery, QueryBuilder, QueryFragment, QueryId};
use diesel::result::{DatabaseErrorKind, Error};
use diesel::{ConnectionError, ConnectionResult, QueryResult};
use futures_core::future::BoxFuture;
use futures_core::stream::BoxStream;
use futures_util::future::Either;
use futures_util::stream::TryStreamExt;
use futures_util::TryFutureExt;
use futures_util::{FutureExt, StreamExt};
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::sync::Arc;
use tokio::sync::broadcast;
use tokio::sync::oneshot;
use tokio::sync::Mutex;
use tokio_postgres::types::ToSql;
use tokio_postgres::types::Type;
use tokio_postgres::Statement;

pub use self::transaction_builder::TransactionBuilder;

mod error_helper;
mod row;
mod serialize;
mod transaction_builder;

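// Placeholder OID assigned to custom user-defined types during the first bind
// collection pass. The real OIDs are resolved later via `lookup_type` and
// patched into the serialized bind data (see `construct_bind_data` and
// `replace_fake_oid`).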
const FAKE_OID: u32 = 0;

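/// A PostgreSQL connection for use with `diesel-async`, backed by
/// [`tokio_postgres`].
///
/// The client, prepared-statement cache, transaction state and type metadata
/// cache are kept behind `Arc`/`Mutex`, so the futures returned by the query
/// methods own everything they need; several such futures can be polled
/// concurrently on the same connection (see the `pipelining` test at the
/// bottom of this module).
///
/// A minimal usage sketch, not compile-tested here and assuming the crate's
/// usual public re-exports (`diesel_async::{AsyncConnection, RunQueryDsl}`)
/// plus a reachable database behind `DATABASE_URL`:
///
/// ```ignore
/// use diesel::IntoSql;
/// use diesel_async::{AsyncConnection, AsyncPgConnection, RunQueryDsl};
///
/// let database_url = std::env::var("DATABASE_URL")?;
/// let mut conn = AsyncPgConnection::establish(&database_url).await?;
/// let answer = diesel::select(42_i32.into_sql::<diesel::sql_types::Integer>())
///     .get_result::<i32>(&mut conn)
///     .await?;
/// assert_eq!(answer, 42);
/// ```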
pub struct AsyncPgConnection {
    conn: Arc<tokio_postgres::Client>,
    stmt_cache: Arc<Mutex<StatementCache<diesel::pg::Pg, Statement>>>,
    transaction_state: Arc<Mutex<AnsiTransactionManager>>,
    metadata_cache: Arc<Mutex<PgMetadataCache>>,
    connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
    shutdown_channel: Option<oneshot::Sender<()>>,
    instrumentation: Arc<std::sync::Mutex<Option<Box<dyn Instrumentation>>>>,
}

impl SimpleAsyncConnection for AsyncPgConnection {
    async fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
        self.record_instrumentation(InstrumentationEvent::start_query(&StrQueryHelper::new(
            query,
        )));
        let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
        let batch_execute = self
            .conn
            .batch_execute(query)
            .map_err(ErrorHelper)
            .map_err(Into::into);

        let r = drive_future(connection_future, batch_execute).await;
        let r = {
            let mut transaction_manager = self.transaction_state.lock().await;
            update_transaction_manager_status(r, &mut transaction_manager)
        };
        self.record_instrumentation(InstrumentationEvent::finish_query(
            &StrQueryHelper::new(query),
            r.as_ref().err(),
        ));
        r
    }
}

impl AsyncConnection for AsyncPgConnection {
    type LoadFuture<'conn, 'query> = BoxFuture<'query, QueryResult<Self::Stream<'conn, 'query>>>;
    type ExecuteFuture<'conn, 'query> = BoxFuture<'query, QueryResult<usize>>;
    type Stream<'conn, 'query> = BoxStream<'static, QueryResult<PgRow>>;
    type Row<'conn, 'query> = PgRow;
    type Backend = diesel::pg::Pg;
    type TransactionManager = AnsiTransactionManager;

    async fn establish(database_url: &str) -> ConnectionResult<Self> {
        let mut instrumentation = diesel::connection::get_default_instrumentation();
        instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
            database_url,
        ));
        let instrumentation = Arc::new(std::sync::Mutex::new(instrumentation));
        let (client, connection) = tokio_postgres::connect(database_url, tokio_postgres::NoTls)
            .await
            .map_err(ErrorHelper)?;

        let (error_rx, shutdown_tx) = drive_connection(connection);

        let r = Self::setup(
            client,
            Some(error_rx),
            Some(shutdown_tx),
            Arc::clone(&instrumentation),
        )
        .await;

        instrumentation
            .lock()
            .unwrap_or_else(|e| e.into_inner())
            .on_connection_event(InstrumentationEvent::finish_establish_connection(
                database_url,
                r.as_ref().err(),
            ));
        r
    }

    fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
    where
        T: AsQuery + 'query,
        T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
    {
        let query = source.as_query();
        let load_future = self.with_prepared_statement(query, load_prepared);

        self.run_with_connection_future(load_future)
    }

    fn execute_returning_count<'conn, 'query, T>(
        &'conn mut self,
        source: T,
    ) -> Self::ExecuteFuture<'conn, 'query>
    where
        T: QueryFragment<Self::Backend> + QueryId + 'query,
    {
        let execute = self.with_prepared_statement(source, execute_prepared);
        self.run_with_connection_future(execute)
    }

    fn transaction_state(&mut self) -> &mut AnsiTransactionManager {
        if let Some(tm) = Arc::get_mut(&mut self.transaction_state) {
            tm.get_mut()
        } else {
            panic!("Cannot access shared transaction state")
        }
    }

    fn instrumentation(&mut self) -> &mut dyn Instrumentation {
        if let Some(instrumentation) = Arc::get_mut(&mut self.instrumentation) {
            &mut *(instrumentation.get_mut().unwrap_or_else(|p| p.into_inner()))
        } else {
            panic!("Cannot access shared instrumentation")
        }
    }

    fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
        self.instrumentation = Arc::new(std::sync::Mutex::new(Some(Box::new(instrumentation))));
    }

    fn set_prepared_statement_cache_size(&mut self, size: CacheSize) {
        if let Some(cache) = Arc::get_mut(&mut self.stmt_cache) {
            cache.get_mut().set_cache_size(size)
        } else {
            panic!("Cannot access shared statement cache")
        }
    }
}

impl Drop for AsyncPgConnection {
    fn drop(&mut self) {
        if let Some(tx) = self.shutdown_channel.take() {
            let _ = tx.send(());
        }
    }
}

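// Callback used by `with_prepared_statement` for `load`: runs the prepared
// statement and maps the returned rows into a stream of `PgRow` values.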
async fn load_prepared(
    conn: Arc<tokio_postgres::Client>,
    stmt: Statement,
    binds: Vec<ToSqlHelper>,
) -> QueryResult<BoxStream<'static, QueryResult<PgRow>>> {
    let res = conn.query_raw(&stmt, binds).await.map_err(ErrorHelper)?;

    Ok(res
        .map_err(|e| diesel::result::Error::from(ErrorHelper(e)))
        .map_ok(PgRow::new)
        .boxed())
}

async fn execute_prepared(
    conn: Arc<tokio_postgres::Client>,
    stmt: Statement,
    binds: Vec<ToSqlHelper>,
) -> QueryResult<usize> {
    let binds = binds
        .iter()
        .map(|b| b as &(dyn ToSql + Sync))
        .collect::<Vec<_>>();

    let res = tokio_postgres::Client::execute(&conn, &stmt, &binds as &[_])
        .await
        .map_err(ErrorHelper)?;
    res.try_into()
        .map_err(|e| diesel::result::Error::DeserializationError(Box::new(e)))
}

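// A serialization failure leaves the current transaction in a state where only
// a rollback is valid, so flag the transaction manager accordingly before
// handing the result back to the caller.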
#[inline(always)]
fn update_transaction_manager_status<T>(
    query_result: QueryResult<T>,
    transaction_manager: &mut AnsiTransactionManager,
) -> QueryResult<T> {
    if let Err(diesel::result::Error::DatabaseError(DatabaseErrorKind::SerializationFailure, _)) =
        query_result
    {
        transaction_manager
            .status
            .set_requires_rollback_maybe_up_to_top_level(true)
    }
    query_result
}

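// Prepares a statement on the underlying `tokio_postgres` client, translating
// diesel's type metadata into `tokio_postgres` types so that the bind
// parameter types are declared up front.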
fn prepare_statement_helper(
    conn: Arc<tokio_postgres::Client>,
    sql: &str,
    _is_for_cache: PrepareForCache,
    metadata: &[PgTypeMetadata],
) -> CallbackHelper<
    impl Future<Output = QueryResult<(Statement, Arc<tokio_postgres::Client>)>> + Send,
> {
    let bind_types = metadata
        .iter()
        .map(type_from_oid)
        .collect::<QueryResult<Vec<_>>>();
    let sql = sql.to_string();
    CallbackHelper(async move {
        let bind_types = bind_types?;
        let stmt = conn
            .prepare_typed(&sql, &bind_types)
            .await
            .map_err(ErrorHelper);
        Ok((stmt?, conn))
    })
}

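// Maps a diesel `PgTypeMetadata` to a `tokio_postgres` `Type`, falling back to
// a synthetic `Kind::Simple` type for OIDs that `tokio_postgres` does not know
// about (i.e. custom user-defined types).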
fn type_from_oid(t: &PgTypeMetadata) -> QueryResult<Type> {
    let oid = t
        .oid()
        .map_err(|e| diesel::result::Error::SerializationError(Box::new(e) as _))?;

    if let Some(tpe) = Type::from_oid(oid) {
        return Ok(tpe);
    }

    Ok(Type::new(
        format!("diesel_custom_type_{oid}"),
        oid,
        tokio_postgres::types::Kind::Simple,
        "public".into(),
    ))
}

impl AsyncPgConnection {
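    /// Build a transaction, specifying additional details such as isolation
    /// level or read-only mode via the returned [`TransactionBuilder`].
    ///
    /// A rough usage sketch, not compile-tested here and assuming the
    /// `scoped_futures` helper that the async transaction API is commonly
    /// used with:
    ///
    /// ```ignore
    /// use scoped_futures::ScopedFutureExt;
    ///
    /// conn.build_transaction()
    ///     .read_only()
    ///     .serializable()
    ///     .run(|conn| {
    ///         async move {
    ///             // run queries against `conn` here
    ///             diesel::sql_query("SELECT 1").execute(conn).await?;
    ///             diesel::QueryResult::Ok(())
    ///         }
    ///         .scope_boxed()
    ///     })
    ///     .await?;
    /// ```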
    pub fn build_transaction(&mut self) -> TransactionBuilder<'_, Self> {
        TransactionBuilder::new(self)
    }

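    /// Constructs a connection from an already established
    /// [`tokio_postgres::Client`].
    ///
    /// Note that the caller remains responsible for driving the matching
    /// [`tokio_postgres::Connection`]; errors reported by that task are not
    /// observed by connections created this way (`connection_future` is
    /// `None`).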
    pub async fn try_from(conn: tokio_postgres::Client) -> ConnectionResult<Self> {
        Self::setup(
            conn,
            None,
            None,
            Arc::new(std::sync::Mutex::new(
                diesel::connection::get_default_instrumentation(),
            )),
        )
        .await
    }

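    /// Constructs a connection from a separately created
    /// [`tokio_postgres::Client`] and [`tokio_postgres::Connection`] pair,
    /// for example when a TLS connector other than `NoTls` is required.
    ///
    /// The connection is driven by a spawned task that is shut down when the
    /// returned `AsyncPgConnection` is dropped.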
    pub async fn try_from_client_and_connection<S>(
        client: tokio_postgres::Client,
        conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
    ) -> ConnectionResult<Self>
    where
        S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
    {
        let (error_rx, shutdown_tx) = drive_connection(conn);

        Self::setup(
            client,
            Some(error_rx),
            Some(shutdown_tx),
            Arc::new(std::sync::Mutex::new(
                diesel::connection::get_default_instrumentation(),
            )),
        )
        .await
    }

    async fn setup(
        conn: tokio_postgres::Client,
        connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
        shutdown_channel: Option<oneshot::Sender<()>>,
        instrumentation: Arc<std::sync::Mutex<Option<Box<dyn Instrumentation>>>>,
    ) -> ConnectionResult<Self> {
        let mut conn = Self {
            conn: Arc::new(conn),
            stmt_cache: Arc::new(Mutex::new(StatementCache::new())),
            transaction_state: Arc::new(Mutex::new(AnsiTransactionManager::default())),
            metadata_cache: Arc::new(Mutex::new(PgMetadataCache::new())),
            connection_future,
            shutdown_channel,
            instrumentation,
        };
        conn.set_config_options()
            .await
            .map_err(ConnectionError::CouldntSetupConfiguration)?;
        Ok(conn)
    }

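    /// Constructs a cancellation token that can later be used to request
    /// cancellation of a query running on this connection.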
    pub fn cancel_token(&self) -> tokio_postgres::CancelToken {
        self.conn.cancel_token()
    }

    async fn set_config_options(&mut self) -> QueryResult<()> {
        use crate::run_query_dsl::RunQueryDsl;

        futures_util::future::try_join(
            diesel::sql_query("SET TIME ZONE 'UTC'").execute(self),
            diesel::sql_query("SET CLIENT_ENCODING TO 'UTF8'").execute(self),
        )
        .await?;
        Ok(())
    }

    fn run_with_connection_future<'a, R: 'a>(
        &self,
        future: impl Future<Output = QueryResult<R>> + Send + 'a,
    ) -> BoxFuture<'a, QueryResult<R>> {
        let connection_future = self.connection_future.as_ref().map(|rx| rx.resubscribe());
        drive_future(connection_future, future).boxed()
    }

    fn with_prepared_statement<'a, T, F, R>(
        &mut self,
        query: T,
        callback: fn(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F,
    ) -> BoxFuture<'a, QueryResult<R>>
    where
        T: QueryFragment<diesel::pg::Pg> + QueryId,
        F: Future<Output = QueryResult<R>> + Send + 'a,
        R: Send,
    {
        self.record_instrumentation(InstrumentationEvent::start_query(&diesel::debug_query(
            &query,
        )));
        let mut query_builder = PgQueryBuilder::default();

        let bind_data = construct_bind_data(&query);

        self.with_prepared_statement_after_sql_built(
            callback,
            query.is_safe_to_cache_prepared(&Pg),
            T::query_id(),
            query.to_sql(&mut query_builder, &Pg),
            query_builder,
            bind_data,
        )
    }

    fn with_prepared_statement_after_sql_built<'a, F, R>(
        &mut self,
        callback: fn(Arc<tokio_postgres::Client>, Statement, Vec<ToSqlHelper>) -> F,
        is_safe_to_cache_prepared: QueryResult<bool>,
        query_id: Option<std::any::TypeId>,
        to_sql_result: QueryResult<()>,
        query_builder: PgQueryBuilder,
        bind_data: BindData,
    ) -> BoxFuture<'a, QueryResult<R>>
    where
        F: Future<Output = QueryResult<R>> + Send + 'a,
        R: Send,
    {
        let raw_connection = self.conn.clone();
        let stmt_cache = self.stmt_cache.clone();
        let metadata_cache = self.metadata_cache.clone();
        let tm = self.transaction_state.clone();
        let instrumentation = self.instrumentation.clone();
        let BindData {
            collect_bind_result,
            fake_oid_locations,
            generated_oids,
            mut bind_collector,
        } = bind_data;

        async move {
            let sql = to_sql_result.map(|_| query_builder.finish())?;
            let res = async {
                let is_safe_to_cache_prepared = is_safe_to_cache_prepared?;
                collect_bind_result?;
                if let Some(ref unresolved_types) = generated_oids {
                    let metadata_cache = &mut *metadata_cache.lock().await;
                    let mut real_oids = HashMap::new();

                    for ((schema, lookup_type_name), (fake_oid, fake_array_oid)) in
                        unresolved_types
                    {
                        let cache_key = PgMetadataCacheKey::new(
                            schema.as_deref().map(Into::into),
                            lookup_type_name.into(),
                        );
                        let real_metadata = if let Some(type_metadata) =
                            metadata_cache.lookup_type(&cache_key)
                        {
                            type_metadata
                        } else {
                            let type_metadata =
                                lookup_type(schema.clone(), lookup_type_name.clone(), &raw_connection)
                                    .await?;
                            metadata_cache.store_type(cache_key, type_metadata);

                            PgTypeMetadata::from_result(Ok(type_metadata))
                        };
                        let (real_oid, real_array_oid) = unwrap_oids(&real_metadata);
                        real_oids.extend([(*fake_oid, real_oid), (*fake_array_oid, real_array_oid)]);
                    }

                    for m in &mut bind_collector.metadata {
                        let (oid, array_oid) = unwrap_oids(m);
                        *m = PgTypeMetadata::new(
                            real_oids.get(&oid).copied().unwrap_or(oid),
                            real_oids.get(&array_oid).copied().unwrap_or(array_oid),
                        );
                    }
                    for (bind_index, byte_index) in fake_oid_locations {
                        replace_fake_oid(&mut bind_collector.binds, &real_oids, bind_index, byte_index)
                            .ok_or_else(|| {
                                Error::SerializationError(
                                    format!("diesel_async failed to replace a type OID serialized in bind value {bind_index}").into(),
                                )
                            })?;
                    }
                }
                let stmt = {
                    let mut stmt_cache = stmt_cache.lock().await;
                    let helper = QueryFragmentHelper {
                        sql: sql.clone(),
                        safe_to_cache: is_safe_to_cache_prepared,
                    };
                    let instrumentation = Arc::clone(&instrumentation);
                    stmt_cache
                        .cached_statement_non_generic(
                            query_id,
                            &helper,
                            &Pg,
                            &bind_collector.metadata,
                            raw_connection.clone(),
                            prepare_statement_helper,
                            &mut move |event: InstrumentationEvent<'_>| {
                                instrumentation
                                    .lock()
                                    .unwrap_or_else(|e| e.into_inner())
                                    .on_connection_event(event);
                            },
                        )
                        .await?
                        .0
                        .clone()
                };

                let binds = bind_collector
                    .metadata
                    .into_iter()
                    .zip(bind_collector.binds)
                    .map(|(meta, bind)| ToSqlHelper(meta, bind))
                    .collect::<Vec<_>>();
                callback(raw_connection, stmt.clone(), binds).await
            };
            let res = res.await;
            let mut tm = tm.lock().await;
            let r = update_transaction_manager_status(res, &mut tm);
            instrumentation
                .lock()
                .unwrap_or_else(|p| p.into_inner())
                .on_connection_event(InstrumentationEvent::finish_query(
                    &StrQueryHelper::new(&sql),
                    r.as_ref().err(),
                ));

            r
        }
        .boxed()
    }

    fn record_instrumentation(&self, event: InstrumentationEvent<'_>) {
        self.instrumentation
            .lock()
            .unwrap_or_else(|p| p.into_inner())
            .on_connection_event(event);
    }
}

struct BindData {
    collect_bind_result: Result<(), Error>,
    fake_oid_locations: Vec<(usize, usize)>,
    generated_oids: GeneratedOidTypeMap,
    bind_collector: RawBytesBindCollector<Pg>,
}

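// Collects the bind data for a query. If the query references custom types
// whose OIDs are not known yet, the binds are collected twice: once with
// `FAKE_OID` as a placeholder and once with unique generated OIDs. Comparing
// the two serialized forms reveals the byte positions at which type OIDs were
// written into the binds, so the real OIDs can later be patched in via
// `replace_fake_oid` once they have been looked up on the server.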
fn construct_bind_data(query: &dyn QueryFragment<diesel::pg::Pg>) -> BindData {
    let mut bind_collector_0 = RawBytesBindCollector::<diesel::pg::Pg>::new();
    let mut metadata_lookup_0 = PgAsyncMetadataLookup {
        custom_oid: false,
        generated_oids: None,
        oid_generator: |_, _| (FAKE_OID, FAKE_OID),
    };
    let collect_bind_result_0 =
        query.collect_binds(&mut bind_collector_0, &mut metadata_lookup_0, &Pg);
    if metadata_lookup_0.custom_oid {
        let mut max_oid = bind_collector_0
            .metadata
            .iter()
            .flat_map(|t| {
                [
                    t.oid().unwrap_or_default(),
                    t.array_oid().unwrap_or_default(),
                ]
            })
            .max()
            .unwrap_or_default();
        let mut bind_collector_1 = RawBytesBindCollector::<diesel::pg::Pg>::new();
        let mut metadata_lookup_1 = PgAsyncMetadataLookup {
            custom_oid: false,
            generated_oids: Some(HashMap::new()),
            oid_generator: move |_, _| {
                max_oid += 2;
                (max_oid, max_oid + 1)
            },
        };
        let collect_bind_result_1 =
            query.collect_binds(&mut bind_collector_1, &mut metadata_lookup_1, &Pg);

        assert_eq!(
            bind_collector_0.binds.len(),
            bind_collector_0.metadata.len()
        );
        let fake_oid_locations = std::iter::zip(
            bind_collector_0
                .binds
                .iter()
                .zip(&bind_collector_0.metadata),
            &bind_collector_1.binds,
        )
        .enumerate()
        .flat_map(|(bind_index, ((bytes_0, metadata_0), bytes_1))| {
            let (bytes_0, bytes_1) = if matches!(metadata_0.oid(), Ok(FAKE_OID)) {
                (
                    bytes_0.as_deref().unwrap_or_default(),
                    bytes_1.as_deref().unwrap_or_default(),
                )
            } else {
                (&[] as &[_], &[] as &[_])
            };
            let lookup_map = metadata_lookup_1
                .generated_oids
                .as_ref()
                .map(|map| {
                    map.values()
                        .flat_map(|(oid, array_oid)| [*oid, *array_oid])
                        .collect::<HashSet<_>>()
                })
                .unwrap_or_default();
            std::iter::zip(
                bytes_0.windows(std::mem::size_of_val(&FAKE_OID)),
                bytes_1.windows(std::mem::size_of_val(&FAKE_OID)),
            )
            .enumerate()
            .filter_map(move |(byte_index, (l, r))| {
                let r_val = u32::from_be_bytes(r.try_into().expect("That's the right size"));
                (l == FAKE_OID.to_be_bytes()
                    && r != FAKE_OID.to_be_bytes()
                    && lookup_map.contains(&r_val))
                .then_some((bind_index, byte_index))
            })
        })
        .collect::<Vec<_>>();
        BindData {
            collect_bind_result: collect_bind_result_0.and(collect_bind_result_1),
            fake_oid_locations,
            generated_oids: metadata_lookup_1.generated_oids,
            bind_collector: bind_collector_1,
        }
    } else {
        BindData {
            collect_bind_result: collect_bind_result_0,
            fake_oid_locations: Vec::new(),
            generated_oids: None,
            bind_collector: bind_collector_0,
        }
    }
}

type GeneratedOidTypeMap = Option<HashMap<(Option<String>, String), (u32, u32)>>;

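// Metadata lookup implementation used during bind collection. Instead of
// querying the database it records that a custom type was encountered and
// hands out OIDs produced by `oid_generator`; the real OIDs are resolved
// afterwards in `with_prepared_statement_after_sql_built`.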
struct PgAsyncMetadataLookup<F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static> {
    custom_oid: bool,
    generated_oids: GeneratedOidTypeMap,
    oid_generator: F,
}

impl<F> PgMetadataLookup for PgAsyncMetadataLookup<F>
where
    F: FnMut(&str, Option<&str>) -> (u32, u32) + 'static,
{
    fn lookup_type(&mut self, type_name: &str, schema: Option<&str>) -> PgTypeMetadata {
        self.custom_oid = true;

        let oid = if let Some(map) = &mut self.generated_oids {
            *map.entry((schema.map(ToOwned::to_owned), type_name.to_owned()))
                .or_insert_with(|| (self.oid_generator)(type_name, schema))
        } else {
            (self.oid_generator)(type_name, schema)
        };

        PgTypeMetadata::from_result(Ok(oid))
    }
}

async fn lookup_type(
    schema: Option<String>,
    type_name: String,
    raw_connection: &tokio_postgres::Client,
) -> QueryResult<(u32, u32)> {
    let r = if let Some(schema) = schema.as_ref() {
        raw_connection
            .query_one(
                "SELECT pg_type.oid, pg_type.typarray FROM pg_type \
                 INNER JOIN pg_namespace ON pg_type.typnamespace = pg_namespace.oid \
                 WHERE pg_type.typname = $1 AND pg_namespace.nspname = $2 \
                 LIMIT 1",
                &[&type_name, schema],
            )
            .await
            .map_err(ErrorHelper)?
    } else {
        raw_connection
            .query_one(
                "SELECT pg_type.oid, pg_type.typarray FROM pg_type \
                 WHERE pg_type.oid = quote_ident($1)::regtype::oid \
                 LIMIT 1",
                &[&type_name],
            )
            .await
            .map_err(ErrorHelper)?
    };
    Ok((r.get(0), r.get(1)))
}

fn unwrap_oids(metadata: &PgTypeMetadata) -> (u32, u32) {
    let err_msg = "PgTypeMetadata is supposed to always be Ok here";
    (
        metadata.oid().expect(err_msg),
        metadata.array_oid().expect(err_msg),
    )
}

fn replace_fake_oid(
    binds: &mut [Option<Vec<u8>>],
    real_oids: &HashMap<u32, u32>,
    bind_index: usize,
    byte_index: usize,
) -> Option<()> {
    let serialized_oid = binds
        .get_mut(bind_index)?
        .as_mut()?
        .get_mut(byte_index..)?
        .first_chunk_mut::<4>()?;
    *serialized_oid = real_oids
        .get(&u32::from_be_bytes(*serialized_oid))?
        .to_be_bytes();
    Some(())
}

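// Races the query future against error notifications from the background
// connection task (if one exists), so that errors reported by the connection
// itself are returned to the caller as database errors.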
async fn drive_future<R>(
    connection_future: Option<broadcast::Receiver<Arc<tokio_postgres::Error>>>,
    client_future: impl Future<Output = Result<R, diesel::result::Error>>,
) -> Result<R, diesel::result::Error> {
    if let Some(mut connection_future) = connection_future {
        let client_future = std::pin::pin!(client_future);
        let connection_future = std::pin::pin!(connection_future.recv());
        match futures_util::future::select(client_future, connection_future).await {
            Either::Left((res, _)) => res,
            Either::Right((Ok(e), _)) => Err(self::error_helper::from_tokio_postgres_error(e)),
            Either::Right((Err(e), _)) => Err(diesel::result::Error::DatabaseError(
                DatabaseErrorKind::UnableToSendCommand,
                Box::new(e.to_string()),
            )),
        }
    } else {
        client_future.await
    }
}

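// Spawns a task that drives the `tokio_postgres` connection until either a
// shutdown signal arrives (sent from `Drop`) or the connection fails, in which
// case the error is broadcast to all interested query futures.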
fn drive_connection<S>(
    conn: tokio_postgres::Connection<tokio_postgres::Socket, S>,
) -> (
    broadcast::Receiver<Arc<tokio_postgres::Error>>,
    oneshot::Sender<()>,
)
where
    S: tokio_postgres::tls::TlsStream + Unpin + Send + 'static,
{
    let (error_tx, error_rx) = tokio::sync::broadcast::channel(1);
    let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();

    tokio::spawn(async move {
        match futures_util::future::select(shutdown_rx, conn).await {
            Either::Left(_) | Either::Right((Ok(_), _)) => {}
            Either::Right((Err(e), _)) => {
                let _ = error_tx.send(Arc::new(e));
            }
        }
    });

    (error_rx, shutdown_tx)
}

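// For pooling purposes a connection counts as broken if its transaction
// manager is in a broken state or the underlying `tokio_postgres` client
// reports the socket as closed.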
#[cfg(any(
    feature = "deadpool",
    feature = "bb8",
    feature = "mobc",
    feature = "r2d2"
))]
impl crate::pooled_connection::PoolableConnection for AsyncPgConnection {
    fn is_broken(&mut self) -> bool {
        use crate::TransactionManager;

        Self::TransactionManager::is_broken_transaction_manager(self) || self.conn.is_closed()
    }
}

impl QueryFragmentForCachedStatement<Pg> for QueryFragmentHelper {
    fn construct_sql(&self, _backend: &Pg) -> QueryResult<String> {
        Ok(self.sql.clone())
    }

    fn is_safe_to_cache_prepared(&self, _backend: &Pg) -> QueryResult<bool> {
        Ok(self.safe_to_cache)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::run_query_dsl::RunQueryDsl;
    use diesel::sql_types::Integer;
    use diesel::IntoSql;

    #[tokio::test]
    async fn pipelining() {
        let database_url =
            std::env::var("DATABASE_URL").expect("DATABASE_URL must be set in order to run tests");
        let mut conn = crate::AsyncPgConnection::establish(&database_url)
            .await
            .unwrap();

        let q1 = diesel::select(1_i32.into_sql::<Integer>());
        let q2 = diesel::select(2_i32.into_sql::<Integer>());

        let f1 = q1.get_result::<i32>(&mut conn);
        let f2 = q2.get_result::<i32>(&mut conn);

        let (r1, r2) = futures_util::try_join!(f1, f2).unwrap();

        assert_eq!(r1, 1);
        assert_eq!(r2, 2);
    }
}