use std::{
net::SocketAddr,
pin::Pin,
sync::{
Arc,
atomic::{AtomicBool, AtomicUsize, Ordering},
},
};
use bb8::Pool;
use bincode::Encode;
use futures::{Future, future::select_all};
use tokio::sync::Notify;
use super::{
ClientPoolConfig,
handle::PoolHandle,
lease::PooledClientLease,
manager::WireframeConnectionManager,
scheduler::PoolScheduler,
slot::PoolSlot,
};
use crate::{
client::{ClientError, connect_parts::ClientBuildParts},
serializer::Serializer,
};
/// A pool slot together with the owned semaphore permit that was obtained
/// from it (see `PoolSlot::try_acquire_permit` / `PoolSlot::acquire_permit`).
type AcquirePermit<S, P, C> = (
    Arc<PoolSlot<S, P, C>>,
    tokio::sync::OwnedSemaphorePermit,
);
/// Type-erased, boxed future that resolves to an [`AcquirePermit`].
///
/// Boxing lets the per-slot waits be stored in a single `Vec` and raced
/// against each other with `select_all`.
type AcquirePermitFuture<S, P, C> = Pin<
    Box<dyn Future<Output = Result<AcquirePermit<S, P, C>, ClientError>> + Send>,
>;
/// Shared state behind a [`WireframeClientPool`] and the handles/leases that
/// reference it through an `Arc`.
pub(crate) struct ClientPoolInner<S, P = (), C = ()>
where
S: Serializer + Clone + Send + Sync + 'static,
P: Encode + Clone + Send + Sync + 'static,
C: Send + 'static,
{
/// One entry per pooled socket; each slot wraps its own bb8 pool built with
/// `max_size(1)` in `WireframeClientPool::connect`.
pub(crate) slots: Arc<[Arc<PoolSlot<S, P, C>>]>,
/// Round-robin cursor: bumped on every `ordered_slots` call to rotate which
/// slot is tried first.
pub(crate) next_slot: AtomicUsize,
/// Scheduler that hands out handle ids (`register_handle`) and is told about
/// shutdown (`notify_shutdown`).
pub(crate) scheduler: Arc<PoolScheduler<S, P, C>>,
/// Set (Release) by `shutdown`, read (Acquire) by `is_shutdown`.
shutdown: AtomicBool,
/// Wakes tasks waiting in `shutdown_notified` via `notify_waiters`.
shutdown_notify: Notify,
}
/// Pool of wireframe client connections to a single address.
///
/// The pool value itself only holds an `Arc` to the shared
/// `ClientPoolInner`; handles created via `handle()` hold their own `Arc`
/// to the same state.
pub struct WireframeClientPool<S, P = (), C = ()>
where
S: Serializer + Clone + Send + Sync + 'static,
P: Encode + Clone + Send + Sync + 'static,
C: Send + 'static,
{
/// Shared pool state (slots, scheduler, shutdown flag).
inner: Arc<ClientPoolInner<S, P, C>>,
}
impl<S, P, C> WireframeClientPool<S, P, C>
where
    S: Serializer + Clone + Send + Sync + 'static,
    P: Encode + Clone + Send + Sync + 'static,
    C: Send + 'static,
{
    /// Builds a pool of `pool_size` slots targeting `addr`.
    ///
    /// Slots are built one after another. Each slot gets its own bb8 pool
    /// capped at a single connection (`max_size(1)`). That bb8 pool uses the
    /// configured idle timeout both as its idle timeout and as its reaper
    /// interval.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while building a slot's bb8 pool.
    /// Slots built before the failure are dropped.
    pub(crate) async fn connect(
        addr: SocketAddr,
        pool_config: ClientPoolConfig,
        parts: ClientBuildParts<S, P, C>,
    ) -> Result<Self, ClientError> {
        let fairness_policy = pool_config.fairness_policy_value();
        let slot_count = pool_config.pool_size_value();
        let idle_timeout = pool_config.idle_timeout_value();

        let mut slots = Vec::with_capacity(slot_count);
        while slots.len() < slot_count {
            let connection_manager = WireframeConnectionManager::new(addr, parts.clone());
            let single_socket_pool = Pool::builder()
                .max_size(1)
                .idle_timeout(Some(idle_timeout))
                .reaper_rate(idle_timeout)
                .build(connection_manager)
                .await?;
            let slot = PoolSlot::new(
                single_socket_pool,
                pool_config.max_in_flight_per_socket_value(),
                idle_timeout,
            );
            slots.push(Arc::new(slot));
        }

        let shared = ClientPoolInner {
            slots: slots.into(),
            next_slot: AtomicUsize::new(0),
            scheduler: Arc::new(PoolScheduler::new(fairness_policy)),
            shutdown: AtomicBool::new(false),
            shutdown_notify: Notify::new(),
        };
        Ok(Self {
            inner: Arc::new(shared),
        })
    }

    /// Registers a new handle with the pool scheduler and returns it.
    ///
    /// The returned handle shares this pool's state.
    #[must_use]
    pub fn handle(&self) -> PoolHandle<S, P, C> {
        let registered_id = self.inner.scheduler.register_handle();
        let shared = Arc::clone(&self.inner);
        PoolHandle::new(shared, registered_id)
    }

    /// Acquires a lease through a freshly registered, single-use handle.
    ///
    /// # Errors
    ///
    /// Forwards whatever error `PoolHandle::acquire` returns.
    pub async fn acquire(&self) -> Result<PooledClientLease<S, P, C>, ClientError> {
        self.handle().acquire().await
    }

    /// Shuts the pool down, yields once to the runtime, then drops this
    /// pool value.
    ///
    /// Other handles or leases still holding the shared state keep it alive.
    /// They can observe the shutdown through the shared flag.
    pub async fn close(self) {
        let pool = self;
        pool.inner.shutdown();
        // Yield once so that tasks woken by the shutdown notification get a
        // chance to run before this reference is released (presumed intent).
        tokio::task::yield_now().await;
        drop(pool);
    }
}
impl<S, P, C> ClientPoolInner<S, P, C>
where
    S: Serializer + Clone + Send + Sync + 'static,
    P: Encode + Clone + Send + Sync + 'static,
    C: Send + 'static,
{
    /// Returns `true` once [`Self::shutdown`] has been called.
    pub(crate) fn is_shutdown(&self) -> bool { self.shutdown.load(Ordering::Acquire) }

    /// Resolves once the pool has been shut down.
    ///
    /// Returns immediately if shutdown already happened. `shutdown` wakes
    /// waiters with `Notify::notify_waiters`, which stores no permit. Without
    /// care this causes two problems:
    ///
    /// - A bare `notified().await` started after `shutdown()` would never
    ///   complete.
    /// - A caller that checked `is_shutdown()` just before the notification
    ///   fired would miss it.
    ///
    /// Creating the `Notified` future *before* re-checking the flag closes
    /// both gaps. Tokio guarantees that a `Notified` receives `notify_waiters`
    /// wakeups from the moment it is created, even before it is first polled.
    pub(crate) async fn shutdown_notified(&self) {
        let notified = self.shutdown_notify.notified();
        if self.is_shutdown() {
            return;
        }
        notified.await;
    }

    /// Marks the pool as shut down.
    ///
    /// Wakes every task currently registered in [`Self::shutdown_notified`]
    /// and forwards the shutdown to the scheduler.
    pub(crate) fn shutdown(&self) {
        // Publish the flag before notifying so that woken tasks observe it.
        self.shutdown.store(true, Ordering::Release);
        self.shutdown_notify.notify_waiters();
        self.scheduler.notify_shutdown();
    }

    /// Tries to take a permit from any slot without waiting.
    ///
    /// Slots are tried in round-robin order (see `ordered_slots`).
    ///
    /// Returns `None` if the pool is shut down or no slot has a permit
    /// available right now. On success, the returned lease also holds a
    /// reference back to this pool state.
    pub(crate) fn try_acquire_immediately(self: &Arc<Self>) -> Option<PooledClientLease<S, P, C>> {
        if self.is_shutdown() {
            return None;
        }
        self.ordered_slots().into_iter().find_map(|slot| {
            slot.try_acquire_permit()
                .map(|permit| PooledClientLease::new(slot, permit, Some(Arc::clone(self))))
        })
    }

    /// Waits on every slot at once and returns the first one to produce a
    /// permit.
    ///
    /// The per-slot waits are raced with `select_all`, so the first future
    /// to *finish* wins even if its result is an error. All other waits are
    /// dropped at that point.
    ///
    /// # Errors
    ///
    /// Returns the error of the first slot wait that completes with `Err`
    /// before any other slot yields a permit.
    ///
    /// # Panics
    ///
    /// Panics if the pool has no slots, because `futures::future::select_all`
    /// panics when given an empty set of futures.
    pub(crate) async fn acquire_slot_permit(&self) -> Result<AcquirePermit<S, P, C>, ClientError> {
        let waiters = self
            .ordered_slots()
            .into_iter()
            .map(|slot| {
                Box::pin(async move {
                    let permit = slot.acquire_permit().await?;
                    Ok::<_, ClientError>((slot, permit))
                }) as AcquirePermitFuture<S, P, C>
            })
            .collect::<Vec<_>>();
        let (result, ..) = select_all(waiters).await;
        result
    }

    /// Returns every slot, rotated so that successive calls start at
    /// successive slots (round-robin starting point).
    fn ordered_slots(&self) -> Vec<Arc<PoolSlot<S, P, C>>> {
        let mut ordered = self.slots.to_vec();
        let len = ordered.len();
        if len > 0 {
            // Relaxed is enough: the cursor only spreads load across slots
            // and publishes no other data.
            let start = self.next_slot.fetch_add(1, Ordering::Relaxed);
            ordered.rotate_left(start % len);
        }
        ordered
    }
}