1use core::{
2 future::Future,
3 mem,
4 ops::Deref,
5 pin::Pin,
6 task::{Context, Poll},
7};
8
9use std::{
10 collections::{HashMap, VecDeque},
11 sync::Mutex,
12};
13
14use tokio::sync::{Semaphore, SemaphorePermit};
15use xitca_io::{bytes::BytesMut, io::AsyncIo};
16
17use super::{
18 BoxedFuture, Postgres,
19 client::{Client, ClientBorrowMut},
20 config::Config,
21 copy::{r#Copy, CopyIn, CopyOut},
22 driver::{
23 Driver,
24 codec::{AsParams, Response, encode::Encode},
25 generic::GenericDriver,
26 },
27 error::Error,
28 execute::Execute,
29 iter::AsyncLendingIterator,
30 prepare::Prepare,
31 query::{Query, RowAffected, RowStreamOwned},
32 session::Session,
33 statement::{Statement, StatementNamed, StatementQuery},
34 transaction::{Transaction, TransactionBuilder},
35 types::{Oid, Type},
36};
37
/// Builder for a [`Pool`], created by [`Pool::builder`].
pub struct PoolBuilder {
    // Outcome of converting the user-supplied config; an `Err` is surfaced by `build`.
    config: Result<Config, Error>,
    // Number of permits (concurrent checkouts) and the idle queue's initial capacity.
    capacity: usize,
}
44impl PoolBuilder {
45 pub fn capacity(mut self, cap: usize) -> Self {
50 self.capacity = cap;
51 self
52 }
53
54 pub fn build(self) -> Result<Pool, Error> {
56 let config = self.config?;
57
58 Ok(Pool {
59 conn: Mutex::new(VecDeque::with_capacity(self.capacity)),
60 permits: Semaphore::new(self.capacity),
61 config: Box::new(config),
62 })
63 }
64}
65
/// A connection pool whose number of checked-out connections is bounded by a semaphore.
pub struct Pool {
    // Idle connections: `get` pops from the front, dropping a `PoolConnection` pushes to the back.
    conn: Mutex<VecDeque<PoolClient>>,
    // One permit per connection that may be checked out concurrently.
    permits: Semaphore,
    // Cloned for every new connection opened by `connect`.
    config: Box<Config>,
}
72
73impl Pool {
74 pub fn builder<C>(cfg: C) -> PoolBuilder
76 where
77 Config: TryFrom<C>,
78 Error: From<<Config as TryFrom<C>>::Error>,
79 {
80 PoolBuilder {
81 config: cfg.try_into().map_err(Into::into),
82 capacity: 1,
83 }
84 }
85
86 pub async fn get(&self) -> Result<PoolConnection<'_>, Error> {
90 let _permit = self.permits.acquire().await.expect("Semaphore must not be closed");
91 let conn = self.conn.lock().unwrap().pop_front();
92 let conn = match conn {
93 Some(conn) if !conn.client.closed() => conn,
94 _ => self.connect().await?,
95 };
96 Ok(PoolConnection {
97 pool: self,
98 conn: Some(conn),
99 _permit,
100 })
101 }
102
103 #[cold]
104 #[inline(never)]
105 fn connect(&self) -> BoxedFuture<'_, Result<PoolClient, Error>> {
106 Box::pin(async move {
107 let (client, driver) = Postgres::new(Clone::clone(&*self.config)).connect().await?;
108 match driver {
109 Driver::Tcp(drv) => {
110 #[cfg(feature = "io-uring")]
111 {
112 drive_uring(drv)
113 }
114
115 #[cfg(not(feature = "io-uring"))]
116 {
117 drive(drv)
118 }
119 }
120 Driver::Dynamic(drv) => drive(drv),
121 #[cfg(feature = "tls")]
122 Driver::Tls(drv) => drive(drv),
123 #[cfg(unix)]
124 Driver::Unix(drv) => drive(drv),
125 #[cfg(all(unix, feature = "tls"))]
126 Driver::UnixTls(drv) => drive(drv),
127 #[cfg(feature = "quic")]
128 Driver::Quic(drv) => drive(drv),
129 };
130 Ok(PoolClient::new(client))
131 })
132 }
133}
134
135fn drive(mut drv: GenericDriver<impl AsyncIo + Send + 'static>) {
136 tokio::task::spawn(async move {
137 while drv.try_next().await?.is_some() {
138 }
140 Ok::<_, Error>(())
141 });
142}
143
/// Spawn a task that drives a plain TCP connection through the io-uring driver.
///
/// The task is spawned with `tokio::task::spawn_local`. Per tokio's docs, that requires running
/// inside a `LocalSet` (or a local runtime). The task ends at the first error or when the
/// iterator is exhausted. Its `Result` is not observed because the handle is dropped.
#[cfg(feature = "io-uring")]
fn drive_uring(drv: GenericDriver<xitca_io::net::TcpStream>) {
    use core::{async_iter::AsyncIterator, future::poll_fn, pin::pin};

    tokio::task::spawn_local(async move {
        let mut iter = pin!(crate::driver::io_uring::UringDriver::from_tcp(drv).into_iter());
        // `AsyncIterator` has no `.next()` future here, so it is polled manually via `poll_fn`.
        while let Some(res) = poll_fn(|cx| iter.as_mut().poll_next(cx)).await {
            let _ = res?;
        }
        Ok::<_, Error>(())
    });
}
156
/// A connection checked out from a [`Pool`] via [`Pool::get`].
///
/// On drop, the `Drop` impl hands the connection back to the pool. The permit field is then
/// released as part of the ordinary field drops that follow `Drop::drop`.
pub struct PoolConnection<'a> {
    pool: &'a Pool,
    // Set to `Some` by `Pool::get`; taken in `Drop`.
    conn: Option<PoolClient>,
    // Semaphore permit held for as long as this connection is checked out.
    _permit: SemaphorePermit<'a>,
}
190
191impl PoolConnection<'_> {
192 #[inline]
194 pub fn transaction(&mut self) -> impl Future<Output = Result<Transaction<&mut Self>, Error>> + Send {
195 TransactionBuilder::new().begin(self)
196 }
197
198 #[inline]
200 pub fn transaction_owned(self) -> impl Future<Output = Result<Transaction<Self>, Error>> + Send {
201 TransactionBuilder::new().begin(self)
202 }
203
204 #[inline]
206 pub fn copy_in(&mut self, stmt: &Statement) -> impl Future<Output = Result<CopyIn<'_, Self>, Error>> + Send {
207 CopyIn::new(self, stmt)
208 }
209
210 #[inline]
212 pub async fn copy_out(&self, stmt: &Statement) -> Result<CopyOut, Error> {
213 CopyOut::new(self, stmt).await
214 }
215
216 #[inline(always)]
269 pub fn consume(self) -> Self {
270 self
271 }
272
273 pub fn cancel_token(&self) -> Session {
275 self.conn().client.cancel_token()
276 }
277
278 fn insert_cache(&mut self, named: &str, stmt: Statement) -> &CachedStatement {
279 self.conn_mut()
280 .statements
281 .entry(Box::from(named))
282 .or_insert(CachedStatement { stmt })
283 }
284
285 fn conn(&self) -> &PoolClient {
286 self.conn.as_ref().unwrap()
287 }
288
289 fn conn_mut(&mut self) -> &mut PoolClient {
290 self.conn.as_mut().unwrap()
291 }
292}
293
294impl ClientBorrowMut for PoolConnection<'_> {
295 #[inline]
296 fn _borrow_mut(&mut self) -> &mut Client {
297 &mut self.conn_mut().client
298 }
299}
300
301impl Prepare for PoolConnection<'_> {
302 #[inline]
303 async fn _get_type(&self, oid: Oid) -> Result<Type, Error> {
304 self.conn().client._get_type(oid).await
305 }
306
307 #[inline]
308 fn _get_type_blocking(&self, oid: Oid) -> Result<Type, Error> {
309 self.conn().client._get_type_blocking(oid)
310 }
311}
312
313impl Query for PoolConnection<'_> {
314 #[inline]
315 fn _send_encode_query<S>(&self, stmt: S) -> Result<(S::Output, Response), Error>
316 where
317 S: Encode,
318 {
319 self.conn().client._send_encode_query(stmt)
320 }
321}
322
323impl r#Copy for PoolConnection<'_> {
324 #[inline]
325 fn send_one_way<F>(&self, func: F) -> Result<(), Error>
326 where
327 F: FnOnce(&mut BytesMut) -> Result<(), Error>,
328 {
329 self.conn().client.send_one_way(func)
330 }
331}
332
333impl Drop for PoolConnection<'_> {
334 fn drop(&mut self) {
335 let conn = self.conn.take().unwrap();
336 self.pool.conn.lock().unwrap().push_back(conn);
337 }
338}
339
/// A prepared statement held in a pooled connection's statement cache.
///
/// Dereferences to [`Statement`]; cloning goes through `Statement::duplicate`.
pub struct CachedStatement {
    stmt: Statement,
}
347
348impl Clone for CachedStatement {
349 fn clone(&self) -> Self {
350 Self {
351 stmt: self.stmt.duplicate(),
352 }
353 }
354}
355
356impl Deref for CachedStatement {
357 type Target = Statement;
358
359 fn deref(&self) -> &Self::Target {
360 &self.stmt
361 }
362}
363
/// A pooled client plus its per-connection statement cache.
struct PoolClient {
    client: Client,
    // Cached statements keyed by the statement string passed to `Statement::named`.
    statements: HashMap<Box<str>, CachedStatement>,
}
368
369impl PoolClient {
370 fn new(client: Client) -> Self {
371 Self {
372 client,
373 statements: HashMap::new(),
374 }
375 }
376}
377
/// Prepare a named statement through the pooled connection's statement cache.
///
/// On a cache hit, the returned future resolves immediately with a clone of the cached entry.
/// On a miss, the statement is prepared on the connection, converted with `leak()`, inserted
/// into the cache, and a clone of the cached entry is returned.
///
/// NOTE(review): the cache key is only the statement string (`self.stmt`). Parameter types are
/// not part of the key, so a later request with the same SQL but different types reuses the
/// earlier entry.
impl<'c, 's> Execute<&'c mut PoolConnection<'_>> for StatementNamed<'s>
where
    's: 'c,
{
    type ExecuteOutput = StatementCacheFuture<'c>;
    type QueryOutput = Self::ExecuteOutput;

    fn execute(self, cli: &'c mut PoolConnection) -> Self::ExecuteOutput {
        match cli.conn().statements.get(self.stmt) {
            Some(stmt) => StatementCacheFuture::Cached(stmt.clone()),
            None => StatementCacheFuture::Prepared(Box::pin(async move {
                let name = self.stmt;
                // Prepare through the `Execute` impl for `&PoolConnection`, then cache the result.
                let stmt = self.execute(&*cli).await?.leak();
                Ok(cli.insert_cache(name, stmt).clone())
            })),
        }
    }

    /// Same as `execute`: preparing a statement has no separate query form.
    #[inline]
    fn query(self, cli: &'c mut PoolConnection) -> Self::QueryOutput {
        self.execute(cli)
    }
}
401
/// Execute or query a [`StatementQuery`] on a pooled connection. The statement is prepared and
/// cached on first use. This is the boxed-future variant used without the `nightly` feature.
///
/// NOTE(review): the cache is keyed only by the statement string. `types` is consulted only on
/// a cache miss, so an entry prepared earlier with different types is reused as-is.
#[cfg(not(feature = "nightly"))]
impl<'c, 's, P> Execute<&'c mut PoolConnection<'_>> for StatementQuery<'s, P>
where
    P: AsParams + Send + 'c,
    's: 'c,
{
    type ExecuteOutput = BoxedFuture<'c, Result<RowAffected, Error>>;
    type QueryOutput = BoxedFuture<'c, Result<RowStreamOwned, Error>>;

    /// Run the statement and resolve to its [`RowAffected`] result.
    fn execute(self, conn: &'c mut PoolConnection<'_>) -> Self::ExecuteOutput {
        Box::pin(async move {
            let StatementQuery { stmt, types, params } = self;

            // Cache lookup; on a miss, prepare, insert, then re-borrow the entry from the map.
            let stmt = match conn.conn().statements.get(stmt) {
                Some(stmt) => stmt,
                None => {
                    let prepared_stmt = Statement::named(stmt, types).execute(&conn).await?.leak();
                    conn.insert_cache(stmt, prepared_stmt);
                    // Cannot fail: `insert_cache` guarantees an entry under this key.
                    conn.conn().statements.get(stmt).unwrap()
                }
            };

            stmt.bind(params).query(conn).await.map(RowAffected::from)
        })
    }

    /// Run the statement and resolve to an owned row stream.
    fn query(self, conn: &'c mut PoolConnection<'_>) -> Self::QueryOutput {
        Box::pin(async move {
            let StatementQuery { stmt, types, params } = self;

            // Cache lookup; on a miss, prepare, insert, then re-borrow the entry from the map.
            let stmt = match conn.conn().statements.get(stmt) {
                Some(stmt) => stmt,
                None => {
                    let prepared_stmt = Statement::named(stmt, types).execute(&conn).await?.leak();
                    conn.insert_cache(stmt, prepared_stmt);
                    // Cannot fail: `insert_cache` guarantees an entry under this key.
                    conn.conn().statements.get(stmt).unwrap()
                }
            };

            stmt.bind(params).into_owned().query(conn).await
        })
    }
}
445
/// Execute or query a [`StatementQuery`] on a pooled connection. The statement is prepared and
/// cached on first use. This is the `nightly` variant, which returns unboxed `impl Future` types.
///
/// NOTE(review): the cache is keyed only by the statement string. `types` is consulted only on
/// a cache miss, so an entry prepared earlier with different types is reused as-is.
#[cfg(feature = "nightly")]
impl<'c, 's, 'p, P> Execute<&'c mut PoolConnection<'p>> for StatementQuery<'s, P>
where
    P: AsParams + Send + 'c,
    's: 'c,
    'p: 'c,
{
    type ExecuteOutput = impl Future<Output = Result<RowAffected, Error>> + Send + 'c;
    type QueryOutput = impl Future<Output = Result<RowStreamOwned, Error>> + Send + 'c;

    /// Run the statement and resolve to its [`RowAffected`] result.
    fn execute(self, conn: &'c mut PoolConnection<'p>) -> Self::ExecuteOutput {
        async move {
            let StatementQuery { stmt, types, params } = self;

            // Cache lookup; on a miss, prepare, insert, then re-borrow the entry from the map.
            let stmt = match conn.conn().statements.get(stmt) {
                Some(stmt) => stmt,
                None => {
                    let prepared_stmt = Statement::named(stmt, types).execute(&conn).await?.leak();
                    conn.insert_cache(stmt, prepared_stmt);
                    // Cannot fail: `insert_cache` guarantees an entry under this key.
                    conn.conn().statements.get(stmt).unwrap()
                }
            };

            stmt.bind(params).query(conn).await.map(RowAffected::from)
        }
    }

    /// Run the statement and resolve to an owned row stream.
    fn query(self, conn: &'c mut PoolConnection<'p>) -> Self::QueryOutput {
        async move {
            let StatementQuery { stmt, types, params } = self;

            // Cache lookup; on a miss, prepare, insert, then re-borrow the entry from the map.
            let stmt = match conn.conn().statements.get(stmt) {
                Some(stmt) => stmt,
                None => {
                    let prepared_stmt = Statement::named(stmt, types).execute(&conn).await?.leak();
                    conn.insert_cache(stmt, prepared_stmt);
                    // Cannot fail: `insert_cache` guarantees an entry under this key.
                    conn.conn().statements.get(stmt).unwrap()
                }
            };

            stmt.bind(params).into_owned().query(conn).await
        }
    }
}
490
/// Execute or query a [`StatementQuery`] by checking a connection out of the [`Pool`].
/// This is the boxed-future variant used without the `nightly` feature.
#[cfg(not(feature = "nightly"))]
impl<'c, 's, P> Execute<&'c Pool> for StatementQuery<'s, P>
where
    P: AsParams + Send + 'c,
    's: 'c,
{
    type ExecuteOutput = BoxedFuture<'c, Result<u64, Error>>;
    type QueryOutput = BoxedFuture<'c, Result<RowStreamOwned, Error>>;

    /// Run the statement and resolve to the affected row count.
    #[inline]
    fn execute(self, pool: &'c Pool) -> Self::ExecuteOutput {
        Box::pin(async {
            // The inner block ends, dropping `conn` and returning it to the pool, before the
            // `RowAffected` it produced is awaited for the final count.
            {
                let mut conn = pool.get().await?;
                self.execute(&mut conn).await?
            }
            .await
        })
    }

    /// Run the statement and resolve to an owned row stream.
    #[inline]
    fn query(self, pool: &'c Pool) -> Self::QueryOutput {
        Box::pin(async {
            // `conn` is held until the query future completes, then dropped with this block.
            let mut conn = pool.get().await?;
            self.query(&mut conn).await
        })
    }
}
521
/// Execute or query a [`StatementQuery`] by checking a connection out of the [`Pool`].
/// This is the `nightly` variant, which returns unboxed `impl Future` types.
#[cfg(feature = "nightly")]
impl<'c, 's, P> Execute<&'c Pool> for StatementQuery<'s, P>
where
    P: AsParams + Send + 'c,
    's: 'c,
{
    type ExecuteOutput = impl Future<Output = Result<u64, Error>> + Send + 'c;
    type QueryOutput = impl Future<Output = Result<RowStreamOwned, Error>> + Send + 'c;

    /// Run the statement and resolve to the affected row count.
    #[inline]
    fn execute(self, pool: &'c Pool) -> Self::ExecuteOutput {
        async {
            // The inner block ends, dropping `conn` and returning it to the pool, before the
            // `RowAffected` it produced is awaited for the final count.
            {
                let mut conn = pool.get().await?;
                self.execute(&mut conn).await?
            }
            .await
        }
    }

    /// Run the statement and resolve to an owned row stream.
    #[inline]
    fn query(self, pool: &'c Pool) -> Self::QueryOutput {
        async {
            // `conn` is held until the query future completes, then dropped with this block.
            let mut conn = pool.get().await?;
            self.query(&mut conn).await
        }
    }
}
551
552#[cfg(not(feature = "nightly"))]
554impl<'c, 's, I, P> Execute<&'c Pool> for I
555where
556 I: IntoIterator,
557 I::IntoIter: Iterator<Item = StatementQuery<'s, P>> + Send + 'c,
558 P: AsParams + Send + 'c,
559 's: 'c,
560{
561 type ExecuteOutput = BoxedFuture<'c, Result<u64, Error>>;
562 type QueryOutput = BoxedFuture<'c, Result<Vec<RowStreamOwned>, Error>>;
563
564 #[inline]
565 fn execute(self, pool: &'c Pool) -> Self::ExecuteOutput {
566 Box::pin(execute_iter_with_pool(self.into_iter(), pool))
567 }
568
569 #[inline]
570 fn query(self, pool: &'c Pool) -> Self::QueryOutput {
571 Box::pin(query_iter_with_pool(self.into_iter(), pool))
572 }
573}
574
/// Run a batch of [`StatementQuery`] values on one pooled connection.
/// This is the `nightly` variant, which returns unboxed `impl Future` types.
#[cfg(feature = "nightly")]
impl<'c, 's, I, P> Execute<&'c Pool> for I
where
    I: IntoIterator,
    I::IntoIter: Iterator<Item = StatementQuery<'s, P>> + Send + 'c,
    P: AsParams + Send + 'c,
    's: 'c,
{
    type ExecuteOutput = impl Future<Output = Result<u64, Error>> + Send + 'c;
    type QueryOutput = impl Future<Output = Result<Vec<RowStreamOwned>, Error>> + Send + 'c;

    /// Execute every statement and resolve to the summed affected row count.
    #[inline]
    fn execute(self, pool: &'c Pool) -> Self::ExecuteOutput {
        execute_iter_with_pool(self.into_iter(), pool)
    }

    /// Query every statement and resolve to one owned row stream per statement.
    #[inline]
    fn query(self, pool: &'c Pool) -> Self::QueryOutput {
        query_iter_with_pool(self.into_iter(), pool)
    }
}
596
597async fn execute_iter_with_pool<P>(
598 iter: impl Iterator<Item = StatementQuery<'_, P>> + Send,
599 pool: &Pool,
600) -> Result<u64, Error>
601where
602 P: AsParams + Send,
603{
604 let mut res = Vec::with_capacity(iter.size_hint().0);
605
606 {
607 let mut conn = pool.get().await?;
608
609 for stmt in iter {
610 let fut = stmt.execute(&mut conn).await?;
611 res.push(fut);
612 }
613 }
614
615 let mut num = 0;
616
617 for res in res {
618 num += res.await?;
619 }
620
621 Ok(num)
622}
623
624async fn query_iter_with_pool<P>(
625 iter: impl Iterator<Item = StatementQuery<'_, P>> + Send,
626 pool: &Pool,
627) -> Result<Vec<RowStreamOwned>, Error>
628where
629 P: AsParams + Send,
630{
631 let mut res = Vec::with_capacity(iter.size_hint().0);
632
633 let mut conn = pool.get().await?;
634
635 for stmt in iter {
636 let stream = stmt.query(&mut conn).await?;
637 res.push(stream);
638 }
639
640 Ok(res)
641}
642
/// Future returned when a [`StatementNamed`] is executed against a `&mut PoolConnection`.
pub enum StatementCacheFuture<'c> {
    /// Cache hit: resolves immediately with this statement.
    Cached(CachedStatement),
    /// Cache miss: boxed future that prepares the statement and inserts it into the cache.
    Prepared(BoxedFuture<'c, Result<CachedStatement, Error>>),
    /// Already resolved; polling again panics.
    Done,
}
648
649impl Future for StatementCacheFuture<'_> {
650 type Output = Result<CachedStatement, Error>;
651
652 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
653 let this = self.get_mut();
654 match mem::replace(this, Self::Done) {
655 Self::Cached(stmt) => Poll::Ready(Ok(stmt)),
656 Self::Prepared(mut fut) => {
657 let res = fut.as_mut().poll(cx);
658 if res.is_pending() {
659 drop(mem::replace(this, Self::Prepared(fut)));
660 }
661 res
662 }
663 Self::Done => panic!("StatementCacheFuture polled after finish"),
664 }
665 }
666}
667
#[cfg(not(feature = "io-uring"))]
#[cfg(test)]
mod test {
    use super::*;

    // Integration test: presumably needs a PostgreSQL server reachable at localhost:5432
    // with the postgres/postgres credentials from the URL below.
    #[tokio::test]
    async fn pool() {
        let pool = Pool::builder("postgres://postgres:postgres@localhost:5432")
            .build()
            .unwrap();

        {
            let mut conn = pool.get().await.unwrap();

            // Prepare through the per-connection cache, then execute on the same connection.
            let stmt = Statement::named("SELECT 1", &[]).execute(&mut conn).await.unwrap();
            stmt.execute(&conn.consume()).await.unwrap();

            // Query directly against the pool; capacity is 1, so this waits on the permit
            // only if `conn` above were still alive (it was consumed by value above).
            let num = Statement::named("SELECT 1", &[])
                .bind_none()
                .query(&pool)
                .await
                .unwrap()
                .try_next()
                .await
                .unwrap()
                .unwrap()
                .get::<i32>(0);

            assert_eq!(num, 1);
        }

        // Batch query over one pooled connection: one row stream per statement.
        let res = [
            Statement::named("SELECT 1", &[]).bind_none(),
            Statement::named("SELECT 1", &[]).bind_none(),
        ]
        .query(&pool)
        .await
        .unwrap();

        for mut res in res {
            let num = res.try_next().await.unwrap().unwrap().get::<i32>(0);
            assert_eq!(num, 1);
        }

        // Parameters that do not match the statement; the result is deliberately ignored.
        let _ = vec![
            Statement::named("SELECT 1", &[]).bind_dyn(&[&1]),
            Statement::named("SELECT 1", &[]).bind_dyn(&[&"123"]),
            Statement::named("SELECT 1", &[]).bind_dyn(&[&String::new()]),
        ]
        .query(&pool)
        .await;
    }
}