use core::{
future::Future,
mem,
pin::Pin,
task::{Context, Poll},
};
use std::{
collections::{HashMap, VecDeque},
sync::{Arc, Mutex},
};
use tokio::sync::{Semaphore, SemaphorePermit};
use xitca_io::bytes::BytesMut;
use super::{
client::{Client, ClientBorrowMut},
config::Config,
copy::{r#Copy, CopyIn, CopyOut},
driver::codec::{encode::Encode, Response},
error::Error,
execute::{Execute, ExecuteMut},
iter::AsyncLendingIterator,
prepare::Prepare,
query::Query,
session::Session,
statement::{Statement, StatementNamed},
transaction::Transaction,
types::{Oid, Type},
BoxedFuture, Postgres,
};
/// Builder for a [`Pool`], created by [`Pool::builder`].
pub struct PoolBuilder {
    /// Outcome of converting the user-supplied configuration; an error is surfaced by `build`.
    config: Result<Config, Error>,
    /// Maximum number of connections that may be checked out of the pool at the same time.
    capacity: usize,
}

impl PoolBuilder {
    /// Sets the maximum number of connections the pool may have checked out at once.
    ///
    /// A value of `0` is treated as `1` by [`PoolBuilder::build`]: a pool without any permits
    /// could never hand out a connection.
    pub fn capacity(mut self, cap: usize) -> Self {
        self.capacity = cap;
        self
    }

    /// Finishes building the pool.
    ///
    /// No connection is opened here; connections are established on demand by [`Pool::get`].
    ///
    /// # Errors
    /// Returns the error produced when the configuration given to [`Pool::builder`] could not
    /// be converted into a [`Config`].
    ///
    /// # Panics
    /// May panic if the capacity is larger than a tokio [`Semaphore`] can hold or than the idle
    /// queue can preallocate.
    pub fn build(self) -> Result<Pool, Error> {
        let config = self.config?;
        // A semaphore with zero permits would make every `Pool::get` wait forever.
        let capacity = self.capacity.max(1);
        Ok(Pool {
            conn: Mutex::new(VecDeque::with_capacity(capacity)),
            permits: Semaphore::new(capacity),
            config,
        })
    }
}
/// A bounded pool of PostgreSQL connections.
///
/// At most `capacity` (see [`PoolBuilder::capacity`]) [`PoolConnection`]s can be checked out at
/// once. Idle connections are kept in a FIFO queue and reused before new ones are opened.
pub struct Pool {
    /// Idle connections waiting to be handed out again.
    conn: Mutex<VecDeque<PoolClient>>,
    /// Limits how many connections may be checked out concurrently.
    permits: Semaphore,
    /// Configuration used whenever a new connection has to be opened.
    config: Config,
}

impl Pool {
    /// Starts building a pool from anything convertible into a [`Config`].
    ///
    /// A conversion failure is kept in the builder and reported by [`PoolBuilder::build`].
    /// The capacity defaults to 1.
    pub fn builder<C>(cfg: C) -> PoolBuilder
    where
        Config: TryFrom<C>,
        Error: From<<Config as TryFrom<C>>::Error>,
    {
        PoolBuilder {
            config: cfg.try_into().map_err(Into::into),
            capacity: 1,
        }
    }

    /// Checks out a connection, waiting until one of the pool's permits becomes available.
    ///
    /// The oldest idle connection that is still open is reused. If there is none, a new
    /// connection is established. The permit is owned by the returned [`PoolConnection`] and is
    /// released when that value is dropped.
    ///
    /// # Errors
    /// Returns an error if a new connection has to be opened and connecting fails.
    ///
    /// # Panics
    /// Panics if the internal semaphore has been closed or the idle-queue mutex is poisoned.
    pub async fn get(&self) -> Result<PoolConnection<'_>, Error> {
        let _permit = self.permits.acquire().await.expect("Semaphore must not be closed");
        let conn = match self.pop_idle() {
            Some(conn) => conn,
            None => self.connect().await?,
        };
        Ok(PoolConnection {
            pool: self,
            conn: Some(conn),
            _permit,
        })
    }

    /// Pops the oldest idle connection whose client is still open, discarding closed ones.
    ///
    /// Closedness is also checked when a connection is returned to the pool. However, a client
    /// may close while it sits idle (e.g. the server went away). Handing such a connection out
    /// would make the caller's first query fail.
    fn pop_idle(&self) -> Option<PoolClient> {
        loop {
            // The lock guard is a temporary of this statement, so the mutex is already released
            // when a closed connection is dropped at the end of the iteration.
            let conn = self.conn.lock().unwrap().pop_front()?;
            if !conn.client.closed() {
                return Some(conn);
            }
        }
    }

    /// Opens a new connection and spawns its driver as a tokio task.
    ///
    /// The driver task loops until `try_next` yields `Ok(None)` or an error. Driver errors are
    /// deliberately not reported here. Must be called from within a tokio runtime, because it
    /// uses `tokio::task::spawn`.
    // Boxed and never inlined, presumably so this rarely taken path does not enlarge the future
    // returned by `get`.
    #[inline(never)]
    fn connect(&self) -> BoxedFuture<'_, Result<PoolClient, Error>> {
        Box::pin(async move {
            let (client, mut driver) = Postgres::new(self.config.clone()).connect().await?;
            tokio::task::spawn(async move {
                while let Ok(Some(_)) = driver.try_next().await {}
            });
            Ok(PoolClient::new(client))
        })
    }
}
/// A connection checked out of a [`Pool`].
///
/// Owns one of the pool's semaphore permits for as long as it lives. When dropped, the
/// connection goes back to the pool's idle queue (unless its client is closed), and the permit
/// is released.
pub struct PoolConnection<'a> {
    /// Pool that receives the connection back on drop.
    pool: &'a Pool,
    /// The pooled client; `Some` from construction until `Drop` takes it out.
    conn: Option<PoolClient>,
    /// Permit acquired in `Pool::get`; released when this struct's fields are dropped.
    _permit: SemaphorePermit<'a>,
}

impl PoolConnection<'_> {
    /// Begins a transaction on this connection via `Transaction::builder().begin`.
    #[inline]
    pub fn transaction(&mut self) -> impl Future<Output = Result<Transaction<Self>, Error>> + Send {
        Transaction::<Self>::builder().begin(self)
    }

    /// Starts a `COPY ... FROM` style data upload for `stmt` on this connection.
    #[inline]
    pub fn copy_in(&mut self, stmt: &Statement) -> impl Future<Output = Result<CopyIn<Self>, Error>> + Send {
        CopyIn::new(self, stmt)
    }

    /// Starts a `COPY ... TO` style data download for `stmt` on this connection.
    #[inline]
    pub async fn copy_out(&self, stmt: &Statement) -> Result<CopyOut, Error> {
        CopyOut::new(self, stmt).await
    }

    /// Takes the connection by value and hands it straight back.
    ///
    /// Handy for moving a connection into a temporary so that it is returned to the pool at
    /// the end of the enclosing statement.
    #[inline(always)]
    pub fn consume(self) -> Self {
        self
    }

    /// Returns the cancellation [`Session`] of the underlying client.
    pub fn cancel_token(&self) -> Session {
        let pooled = self.conn();
        pooled.client.cancel_token()
    }

    /// Stores `stmt` in this connection's statement cache under the key `named`.
    ///
    /// Returns a shared handle to the stored statement. An existing entry with the same key is
    /// replaced.
    fn insert_cache(&mut self, named: &str, stmt: Statement) -> Arc<Statement> {
        let shared = Arc::new(stmt);
        let key: Box<str> = named.into();
        self.conn_mut().statements.insert(key, Arc::clone(&shared));
        shared
    }

    /// Shared access to the pooled client; `conn` is only `None` inside `Drop`.
    fn conn(&self) -> &PoolClient {
        self.conn.as_ref().unwrap()
    }

    /// Exclusive access to the pooled client; `conn` is only `None` inside `Drop`.
    fn conn_mut(&mut self) -> &mut PoolClient {
        self.conn.as_mut().unwrap()
    }
}
/// Mutable access to the [`Client`] wrapped by a pooled connection.
impl ClientBorrowMut for PoolConnection<'_> {
    #[inline]
    fn _borrow_mut(&mut self) -> &mut Client {
        let pooled = self.conn_mut();
        &mut pooled.client
    }
}
/// Type lookups on a pooled connection are forwarded to the wrapped [`Client`].
impl Prepare for PoolConnection<'_> {
    #[inline]
    fn _get_type(&self, oid: Oid) -> BoxedFuture<'_, Result<Type, Error>> {
        let client = &self.conn().client;
        client._get_type(oid)
    }

    #[inline]
    fn _get_type_blocking(&self, oid: Oid) -> Result<Type, Error> {
        let client = &self.conn().client;
        client._get_type_blocking(oid)
    }
}
/// Query encoding and sending is forwarded to the wrapped [`Client`].
impl Query for PoolConnection<'_> {
    #[inline]
    fn _send_encode_query<S>(&self, stmt: S) -> Result<(S::Output, Response), Error>
    where
        S: Encode,
    {
        let client = &self.conn().client;
        client._send_encode_query(stmt)
    }
}
/// One-way message sending used by the COPY protocol, forwarded to the wrapped [`Client`].
impl r#Copy for PoolConnection<'_> {
    #[inline]
    fn send_one_way<F>(&self, func: F) -> Result<(), Error>
    where
        F: FnOnce(&mut BytesMut) -> Result<(), Error>,
    {
        let client = &self.conn().client;
        client.send_one_way(func)
    }
}
/// Hands the connection back to the pool's idle queue unless its client has been closed.
///
/// `drop` runs before the struct's fields are dropped. The connection is therefore back in the
/// queue before `_permit` is released and another waiter in `Pool::get` can proceed.
impl Drop for PoolConnection<'_> {
    fn drop(&mut self) {
        let conn = self.conn.take().unwrap();
        // Closed clients are discarded instead of being recycled.
        if !conn.client.closed() {
            self.pool.conn.lock().unwrap().push_back(conn);
        }
    }
}
/// A client owned by the pool, together with the statements prepared on it.
struct PoolClient {
    client: Client,
    /// Per-connection statement cache filled by `execute_mut`, keyed by the statement string.
    statements: HashMap<Box<str>, Arc<Statement>>,
}

impl PoolClient {
    /// Wraps a freshly connected client with an empty statement cache.
    fn new(client: Client) -> Self {
        let statements = HashMap::default();
        Self { client, statements }
    }
}
impl<'c, 's> ExecuteMut<'c, PoolConnection<'_>> for StatementNamed<'s>
where
's: 'c,
{
type ExecuteMutOutput = StatementCacheFuture<'c>;
type QueryMutOutput = Self::ExecuteMutOutput;
fn execute_mut(self, cli: &'c mut PoolConnection) -> Self::ExecuteMutOutput {
match cli.conn().statements.get(self.stmt) {
Some(stmt) => StatementCacheFuture::Cached(stmt.clone()),
None => StatementCacheFuture::Prepared(Box::pin(async move {
let stmt = self.execute(cli).await?.leak();
Ok(cli.insert_cache(self.stmt, stmt))
})),
}
}
#[inline]
fn query_mut(self, cli: &'c mut PoolConnection) -> Self::QueryMutOutput {
self.execute_mut(cli)
}
}
/// Future returned by `execute_mut` for a [`StatementNamed`] on a [`PoolConnection`].
///
/// Resolves immediately on a cache hit, or drives the boxed prepare-and-cache future on a miss.
pub enum StatementCacheFuture<'c> {
    /// Cache hit: the statement is already prepared on this connection.
    Cached(Arc<Statement>),
    /// Cache miss: prepares the statement and inserts it into the connection's cache.
    Prepared(BoxedFuture<'c, Result<Arc<Statement>, Error>>),
    /// Terminal state once the output has been produced; polling again panics.
    Done,
}

impl Future for StatementCacheFuture<'_> {
    type Output = Result<Arc<Statement>, Error>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.get_mut();
        // Move the current state out, leaving `Done` behind. Only a still-pending inner future
        // is put back below.
        let mut pending = match mem::replace(this, Self::Done) {
            Self::Prepared(fut) => fut,
            Self::Cached(stmt) => return Poll::Ready(Ok(stmt)),
            Self::Done => panic!("StatementCacheFuture polled after finish"),
        };
        match pending.as_mut().poll(cx) {
            Poll::Pending => {
                *this = Self::Prepared(pending);
                Poll::Pending
            }
            ready => ready,
        }
    }
}
#[cfg(test)]
mod test {
    use super::*;

    // Integration test: requires a PostgreSQL server reachable at `url`.
    #[tokio::test]
    async fn pool() {
        let url = "postgres://postgres:postgres@localhost:5432";
        let pool = Pool::builder(url).build().unwrap();

        let mut conn = pool.get().await.unwrap();

        // The first `execute_mut` prepares the statement and stores it in the connection cache.
        let stmt = Statement::named("SELECT 1", &[])
            .execute_mut(&mut conn)
            .await
            .unwrap();

        // `consume` moves the connection into a temporary, so it is returned to the pool as
        // soon as this statement completes.
        stmt.execute(&conn.consume()).await.unwrap();
    }
}