use std::future::Future;
use std::marker::Unpin;
use std::pin::Pin;
use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;
use async_channel::bounded;
use async_channel::unbounded;
use async_channel::Receiver;
use async_channel::Sender;
use futures::future::Either;
use futures::FutureExt;
/// Shared state behind every [`Pool`] handle.
///
/// `F` is the factory closure, `U` the future it returns and `E` the error
/// that future may resolve to.
struct PoolInternal<T, F, U, E>
where
    F: Fn() -> U,
    U: Future<Output = Result<T, E>>,
{
    /// Sending half of the channel; items are pushed here to make them
    /// available to later `get` calls.
    sender: Sender<T>,
    /// Receiving half of the same channel; `get` takes items from here.
    receiver: Receiver<T>,
    /// Bookkeeping counter that is compared against `cap` to decide whether
    /// a new item may still be created.
    out: AtomicUsize,
    /// Factory invoked whenever a new item has to be created.
    gen: F,
    /// Capacity limit; `0` selects an unbounded channel and disables the
    /// limit checks.
    cap: usize,
}
/// Handle to a pool of `T` values produced on demand by the factory `F`.
///
/// The handle only wraps an `Arc`, so every clone refers to the same
/// underlying pool state.
pub struct Pool<T, F, U, E>(Arc<PoolInternal<T, F, U, E>>)
where
    T: 'static,
    F: Fn() -> U,
    U: Future<Output = Result<T, E>> + Unpin;
impl<T, F, U, E> Clone for Pool<T, F, U, E>
where
F: Fn() -> U,
U: Future<Output = Result<T, E>> + Unpin,
{
fn clone(&self) -> Self {
Pool(self.0.clone())
}
}
impl<T, F, U, E> Pool<T, F, U, E>
where
    F: Fn() -> U,
    U: Future<Output = Result<T, E>> + Unpin,
{
    /// Creates an empty pool.
    ///
    /// `cap` is the maximum number of items the pool will hand out or keep
    /// idle; `0` means "no limit" and backs the pool with an unbounded
    /// channel. `gen` is called every time a new item has to be created.
    pub fn new(cap: usize, gen: F) -> Self {
        let (sender, receiver) = if cap > 0 { bounded(cap) } else { unbounded() };
        Pool(Arc::new(PoolInternal {
            sender,
            receiver,
            out: AtomicUsize::new(0),
            gen,
            cap,
        }))
    }

    // FIXME(review): this feature-gated variant cannot compile when the
    // `timeout` feature is enabled: it clashes with the ungated `get` below,
    // matches on the un-awaited `recv()` future, omits the `dirty` field of
    // `PoolGuard`, passes two arguments to the one-argument
    // `PoolFuture::new`, and `Duration` is not imported in the visible part
    // of this file. It is kept untouched because a real implementation needs
    // a timer, which nothing imported here provides.
    #[cfg(feature = "timeout")]
    pub async fn get(self, timeout: Duration) -> Result<PoolGuard<T, F, U, E>, E> {
        match self.0.receiver.recv() {
            Ok(t) => Ok(PoolGuard {
                pool: self.clone(),
                item: Some(t),
            }),
            Err(_) => {
                let closure = |pool| {
                    move |t| PoolGuard {
                        pool,
                        item: Some(t),
                    }
                };
                if self.0.cap > 0 {
                    let out = self.0.out.load(Ordering::SeqCst);
                    if out < self.0.cap {
                        self.0.out.fetch_add(1, Ordering::SeqCst);
                        (self.0.gen)().await.map(closure(self.clone()))
                    } else {
                        PoolFuture::new(self.clone(), timeout).await
                    }
                } else {
                    (self.0.gen)().await.map(closure(self.clone()))
                }
            }
        }
    }

    /// Checks an item out of the pool.
    ///
    /// An idle item is reused when one is available. Otherwise a new item is
    /// created through the factory, unless the pool is bounded and already
    /// has `cap` items checked out, in which case the call waits on a
    /// [`PoolFuture`].
    ///
    /// # Errors
    ///
    /// Returns the factory's error when creating a new item fails. The slot
    /// reserved for that item is released again, so a failed creation does
    /// not permanently shrink the pool.
    pub async fn get(self) -> Result<PoolGuard<T, F, U, E>, E> {
        // Fast path: reuse an idle item. The guard decrements `out` when it
        // is released, so the checkout has to be counted here as well;
        // otherwise the counter underflows on the second get/drop cycle.
        if let Ok(item) = self.0.receiver.try_recv() {
            if self.0.cap > 0 {
                self.0.out.fetch_add(1, Ordering::SeqCst);
            }
            return Ok(PoolGuard {
                pool: self,
                item: Some(item),
                dirty: false,
            });
        }

        // Unbounded pools never consult `out`; just create a fresh item.
        if self.0.cap == 0 {
            return (self.0.gen)().await.map(|item| PoolGuard {
                pool: self,
                item: Some(item),
                dirty: false,
            });
        }

        // Check the limit and reserve the slot in one atomic step so that
        // concurrent callers cannot all pass the check and overshoot `cap`.
        let cap = self.0.cap;
        let reserved = self
            .0
            .out
            .fetch_update(Ordering::SeqCst, Ordering::SeqCst, |n| {
                if n < cap {
                    Some(n + 1)
                } else {
                    None
                }
            })
            .is_ok();
        if !reserved {
            return PoolFuture::new(self).await;
        }

        match (self.0.gen)().await {
            Ok(item) => Ok(PoolGuard {
                pool: self,
                item: Some(item),
                dirty: false,
            }),
            Err(err) => {
                // Nothing was created: give the reserved slot back.
                self.0.out.fetch_sub(1, Ordering::SeqCst);
                Err(err)
            }
        }
    }

    /// Number of idle items currently waiting in the pool's channel.
    pub fn len(&self) -> usize {
        self.0.receiver.len()
    }

    /// Puts `item` into the pool without blocking.
    ///
    /// # Errors
    ///
    /// Returns the `TrySendError` from the underlying channel (which carries
    /// the item back) when the item cannot be queued, e.g. because a bounded
    /// pool is already full.
    pub fn add(&self, item: T) -> Result<(), async_channel::TrySendError<T>> {
        self.0.sender.try_send(item)
    }
}
/// Future used by `Pool::get` when it cannot hand out an item right away.
struct PoolFuture<T, F, U, E>
where
    T: 'static,
    F: Fn() -> U,
    U: Future<Output = Result<T, E>> + Unpin,
{
    /// Pool the resulting guard will belong to.
    pool: Pool<T, F, U, E>,
    /// Operation in flight, if any: `Left` is a factory future creating a
    /// new item, `Right` is a boxed future that resolves to an item. `None`
    /// until the first poll has decided which one to start.
    internal: Option<Either<U, Pin<Box<dyn Future<Output = Result<T, E>>>>>>,
}
impl<'a, T, F, U, E> PoolFuture<T, F, U, E>
where
F: Fn() -> U,
U: Future<Output = Result<T, E>> + Unpin,
{
fn new(pool: Pool<T, F, U, E>) -> Self {
PoolFuture {
internal: None,
pool,
}
}
}
impl<T, F, U, E> Future for PoolFuture<T, F, U, E>
where
F: Fn() -> U,
U: Future<Output = Result<T, E>> + Unpin,
{
type Output = Result<PoolGuard<T, F, U, E>, E>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
match &mut self.internal {
None => match self.pool.0.receiver.try_recv() {
Ok(t) => Poll::Ready(Ok(PoolGuard {
pool: self.pool.clone(),
item: Some(t),
dirty: false,
})),
Err(_) => {
if self.pool.0.out.load(Ordering::SeqCst) < self.pool.0.cap {
self.internal = Some(Either::Left((self.pool.0.gen)()));
cx.waker().clone().wake();
Poll::Pending
} else {
let recv = self.pool.0.receiver.clone();
self.internal = Some(Either::Right(
async move { Ok(recv.recv().await.unwrap()) }.boxed_local(),
));
cx.waker().clone().wake();
Poll::Pending
}
}
},
Some(ref mut fut) => match Pin::new(fut).poll(cx) {
Poll::Ready(Ok(t)) => Poll::Ready(Ok(PoolGuard {
pool: self.pool.clone(),
item: Some(t),
dirty: false,
})),
Poll::Ready(Err(e)) => Poll::Ready(Err(e)),
Poll::Pending => Poll::Pending,
},
}
}
}
/// Exclusive handle to an item checked out of a [`Pool`].
///
/// Dereferences to the item. What happens to the item when the guard goes
/// away is decided by the guard's `Drop` implementation and the `dirty` flag.
pub struct PoolGuard<T, F, U, E>
where
    T: 'static,
    F: Fn() -> U,
    U: Future<Output = Result<T, E>> + Unpin,
{
    /// Pool the item came from.
    pool: Pool<T, F, U, E>,
    /// The checked-out item; `None` once it has been taken out of the guard.
    item: Option<T>,
    /// Set through `mark_dirty`; a dirty item is not put back into the pool.
    dirty: bool,
}
impl<T, F, U, E> PoolGuard<T, F, U, E>
where
    F: Fn() -> U,
    U: Future<Output = Result<T, E>> + Unpin,
{
    /// Flags the item as unfit for reuse; it will be discarded instead of
    /// being returned to the pool when the guard is dropped.
    pub fn mark_dirty(&mut self) {
        self.dirty = true;
    }

    /// Clears a previous [`mark_dirty`](Self::mark_dirty), so the item is
    /// returned to the pool again when the guard is dropped.
    pub fn mark_clean(&mut self) {
        self.dirty = false;
    }

    /// Takes the item out of the pool for good and returns it.
    ///
    /// The slot the item occupied is released, so a bounded pool may create
    /// a replacement later.
    pub fn detach(mut self) -> T {
        let item = self.item.take();
        if self.pool.0.cap > 0 {
            self.pool.0.out.fetch_sub(1, Ordering::SeqCst);
        }
        // The slot has just been released. Clear the flag so the `Drop`
        // that runs when `self` goes out of scope does not release it a
        // second time for a guard that was marked dirty.
        self.dirty = false;
        item.expect("guard holds its item until it is detached or dropped")
    }

    /// Discards the item instead of returning it to the pool.
    pub fn destroy(self) {
        self.detach();
    }
}
impl<T, F, U, E> std::ops::Deref for PoolGuard<T, F, U, E>
where
    F: Fn() -> U,
    U: Future<Output = Result<T, E>> + Unpin,
{
    type Target = T;

    /// Borrows the guarded item.
    ///
    /// # Panics
    ///
    /// Panics if the item has already been taken out of the guard.
    fn deref(&self) -> &Self::Target {
        let held = self.item.as_ref();
        held.unwrap()
    }
}
impl<T, F, U, E> std::ops::DerefMut for PoolGuard<T, F, U, E>
where
    F: Fn() -> U,
    U: Future<Output = Result<T, E>> + Unpin,
{
    /// Mutably borrows the guarded item.
    ///
    /// # Panics
    ///
    /// Panics if the item has already been taken out of the guard.
    fn deref_mut(&mut self) -> &mut Self::Target {
        let held = self.item.as_mut();
        held.unwrap()
    }
}
impl<T, F, U, E> Drop for PoolGuard<T, F, U, E>
where
    F: Fn() -> U,
    U: Future<Output = Result<T, E>> + Unpin,
{
    /// Releases the guard's slot and, unless the item is dirty, tries to put
    /// the item back into the pool.
    fn drop(&mut self) {
        // `detach` empties the guard and releases the slot itself; there is
        // nothing left to do then, not even for a guard marked dirty.
        let item = match self.item.take() {
            Some(item) => item,
            None => return,
        };

        // Only bounded pools track slots, matching the check in `detach`.
        if self.pool.0.cap > 0 {
            self.pool.0.out.fetch_sub(1, Ordering::SeqCst);
        }

        if self.dirty {
            // Dirty items are discarded rather than reused.
            drop(item);
        } else {
            // Best effort: if the channel refuses the item (e.g. it is
            // full) the item is simply dropped. Its slot was released above
            // either way, so a rejected item cannot leak capacity.
            let _ = self.pool.add(item);
        }
    }
}
/// End-to-end check: builds a pool of Postgres clients, runs one query
/// through a pooled client and verifies the client goes back into the pool.
///
/// Needs a Postgres server reachable at the hard-coded connection string.
#[cfg(test)]
async fn test_fut() -> Result<(), failure::Error> {
    use futures::TryFutureExt;
    let pool = Pool::new(20, || {
        tokio_postgres::connect(
            "postgres://aiden:pass@localhost:5432/pgdb",
            tokio_postgres::NoTls,
        )
        .map_ok(|(client, connection)| {
            // The connection object drives the socket; run it in the
            // background and only log its errors.
            let connection = connection.map_err(|e| eprintln!("connection error: {}", e));
            tokio::spawn(connection);
            client
        })
        .boxed()
    });
    let client = pool.clone().get().await?;
    let stmt = client.prepare("SELECT $1::TEXT").await?;
    let rows = client.query(&stmt, &[&"hello".to_owned()]).await?;
    let hello: String = rows[0].get(0);
    println!("{}", hello);
    assert_eq!("hello", &hello);
    // The client only returns to the pool when its guard is dropped, so
    // release it before checking the idle count.
    drop(client);
    println!("len: {}", pool.len());
    assert_eq!(1, pool.len());
    Ok(())
}
/// Runs the async pool check to completion on a fresh Tokio runtime.
#[test]
fn test() {
    // The runtime stays a temporary so this works whether `block_on` takes
    // `&self` or `&mut self`.
    let outcome = tokio::runtime::Runtime::new().unwrap().block_on(test_fut());
    outcome.unwrap();
}