use std::{future::Future, marker::PhantomData, pin::Pin, sync::Arc, time::Duration};
use tokio::sync::{mpsc, Mutex};
use super::pooled::PooledConnection;
use crate::netlink::{
connection::Connection,
error::{Error, Result},
namespace,
protocol::{
construction::{AsyncConstructible, SyncConstructible},
AsyncProtocolInit, ProtocolState,
},
};
/// Boxed, type-erased future returned by [`Factory::build`].
///
/// It resolves to a newly opened connection, or to the error that prevented
/// opening one. The `'static` bound is the default for `Box<dyn ...>`; it is
/// spelled out here to make explicit that the future owns all of its data.
pub(super) type FactoryFuture<P> =
    Pin<Box<dyn Future<Output = Result<Connection<P>>> + Send + 'static>>;
/// Source of new connections for a pool.
///
/// The pool stores its factory as `Arc<dyn Factory<P>>`, so implementors must be
/// `Send + Sync + 'static`.
pub(super) trait Factory<P>: Send + Sync + 'static
where
    P: ProtocolState,
{
    /// Returns a future that opens exactly one new connection.
    fn build(&self) -> FactoryFuture<P>;
}
/// Shared state of a [`ConnectionPool`], held behind an `Arc` by the pool.
pub(super) struct PoolInner<P: ProtocolState> {
    /// Sending half of the idle-connection channel. The builders seed the channel
    /// through this sender before storing it here. NOTE(review): presumably
    /// `PooledConnection` also uses it to return connections; confirm in `pooled`.
    pub(super) available: mpsc::Sender<Connection<P>>,
    /// Receiving half of the idle-connection channel. `mpsc::Receiver` has a single
    /// consumer, so it sits behind an async mutex that `acquire` holds while it
    /// waits on `recv`.
    pub(super) receiver: Mutex<mpsc::Receiver<Connection<P>>>,
    /// Namespace the connections were opened in, if one was configured.
    pub(super) namespace: Option<String>,
    /// Number of connections opened at build time; also the channel capacity.
    pub(super) size: usize,
    /// Maximum time `acquire` waits before reporting `Error::PoolExhausted`.
    pub(super) acquire_timeout: Duration,
    /// Factory configured like the one used at build time. It is not read in this
    /// file; presumably used elsewhere to open replacement connections (TODO confirm).
    pub(super) factory: Arc<dyn Factory<P>>,
}
/// A pool of netlink connections for protocol `P`.
///
/// Check connections out with [`ConnectionPool::acquire`], which waits at most the
/// configured acquire timeout. Build a pool with [`ConnectionPoolBuilder`] or
/// [`ConnectionPool::for_namespace`].
#[non_exhaustive]
pub struct ConnectionPool<P: ProtocolState> {
    /// Shared pool state.
    pub(super) inner: Arc<PoolInner<P>>,
}
impl<P: ProtocolState> ConnectionPool<P> {
pub async fn acquire(&self) -> Result<PooledConnection<'_, P>> {
let recv_fut = async {
let mut rx = self.inner.receiver.lock().await;
rx.recv().await
};
match tokio::time::timeout(self.inner.acquire_timeout, recv_fut).await {
Ok(Some(conn)) => Ok(PooledConnection::new(self, conn)),
Ok(None) => Err(Error::PoolClosed),
Err(_) => Err(Error::PoolExhausted {
size: self.inner.size,
timeout: self.inner.acquire_timeout,
}),
}
}
pub fn size(&self) -> usize {
self.inner.size
}
pub fn acquire_timeout(&self) -> Duration {
self.inner.acquire_timeout
}
pub fn namespace(&self) -> Option<&str> {
self.inner.namespace.as_deref()
}
}
impl<P> ConnectionPool<P>
where
    P: ProtocolState + Default + SyncConstructible + 'static,
{
    /// Opens a pool of `size` connections inside namespace `ns`, using the
    /// synchronous constructors.
    ///
    /// Shorthand for `ConnectionPoolBuilder::new().namespace(ns).size(size).build()`.
    /// All other settings keep the builder defaults.
    ///
    /// # Errors
    ///
    /// Propagates the first error raised while opening a connection.
    pub async fn for_namespace(ns: impl Into<String>, size: usize) -> Result<Self> {
        let builder = ConnectionPoolBuilder::<P>::new().size(size).namespace(ns);
        builder.build().await
    }
}
/// Builder for [`ConnectionPool`].
///
/// Configure it with the setter methods, then finish with `build` (synchronous
/// protocol constructors) or `build_async` (asynchronous protocol initialisation).
#[non_exhaustive]
pub struct ConnectionPoolBuilder<P: ProtocolState> {
    /// Number of connections to open; the setter keeps this at least 1.
    size: usize,
    /// Maximum time `ConnectionPool::acquire` waits.
    acquire_timeout: Duration,
    /// Namespace to open connections in; `None` uses the non-namespaced constructors.
    namespace: Option<String>,
    /// Ties the builder to `P` without storing a value of it.
    _phantom: PhantomData<fn() -> P>,
}
impl<P: ProtocolState> Default for ConnectionPoolBuilder<P> {
fn default() -> Self {
Self::new()
}
}
impl<P: ProtocolState> ConnectionPoolBuilder<P> {
    /// Creates a builder with the default settings:
    ///
    /// * pool size: [`std::thread::available_parallelism`], or 4 if that fails;
    /// * acquire timeout: 5 seconds;
    /// * namespace: none.
    #[must_use]
    pub fn new() -> Self {
        // `available_parallelism` is always >= 1 when it succeeds; fall back to a
        // small fixed size when the parallelism cannot be queried.
        let size = std::thread::available_parallelism()
            .map(|n| n.get())
            .unwrap_or(4);
        Self {
            size,
            acquire_timeout: Duration::from_secs(5),
            namespace: None,
            _phantom: PhantomData,
        }
    }

    /// Sets how many connections the pool opens when it is built.
    ///
    /// A `size` of 0 is clamped to 1. The pool's channel is created with this
    /// capacity, and `tokio::sync::mpsc::channel` panics on a capacity of 0.
    #[must_use]
    pub fn size(mut self, size: usize) -> Self {
        self.size = size.max(1);
        self
    }

    /// Sets the maximum time `ConnectionPool::acquire` waits before failing with
    /// `Error::PoolExhausted`.
    #[must_use]
    pub fn acquire_timeout(mut self, timeout: Duration) -> Self {
        self.acquire_timeout = timeout;
        self
    }

    /// Opens every pool connection in the namespace `name`, via
    /// `namespace::connection_for` or `namespace::connection_for_async`.
    #[must_use]
    pub fn namespace(mut self, name: impl Into<String>) -> Self {
        self.namespace = Some(name.into());
        self
    }
}
impl<P> ConnectionPoolBuilder<P>
where
    P: ProtocolState + Default + SyncConstructible + 'static,
{
    /// Opens `size` connections with the synchronous constructors and returns the pool.
    ///
    /// If a namespace is set, connections are opened with `namespace::connection_for`;
    /// otherwise they use `Connection::new`. The initial connections come from the same
    /// [`SyncFactory`] that is stored in the pool, so both paths build connections the
    /// same way.
    ///
    /// # Errors
    ///
    /// Returns the first error raised while opening a connection. Connections opened
    /// before the failure are dropped along with the channel.
    pub async fn build(self) -> Result<ConnectionPool<P>> {
        let factory: Arc<dyn Factory<P>> = Arc::new(SyncFactory {
            namespace: self.namespace.clone(),
            _phantom: PhantomData,
        });
        let (tx, rx) = mpsc::channel(self.size);
        for _ in 0..self.size {
            let conn = factory.build().await?;
            // Not expected to fail: `rx` is still alive and the channel was sized for
            // exactly `size` connections.
            tx.send(conn).await.map_err(|_| Error::PoolClosed)?;
        }
        Ok(ConnectionPool {
            inner: Arc::new(PoolInner {
                available: tx,
                receiver: Mutex::new(rx),
                namespace: self.namespace,
                size: self.size,
                acquire_timeout: self.acquire_timeout,
                factory,
            }),
        })
    }
}
/// [`Factory`] that opens connections with the synchronous constructors.
///
/// The trait bounds live on the `Factory` impl, not on the struct, because the
/// struct itself only stores a namespace name.
struct SyncFactory<P> {
    /// Namespace to open connections in; `None` means `Connection::new()` is used.
    namespace: Option<String>,
    /// Ties the factory to `P`. `fn() -> P` keeps the struct `Send + Sync`
    /// whatever `P` is.
    _phantom: PhantomData<fn() -> P>,
}
impl<P> Factory<P> for SyncFactory<P>
where
    P: ProtocolState + Default + SyncConstructible + 'static,
{
    /// Returns a future that, when polled, opens one connection synchronously.
    ///
    /// It uses `namespace::connection_for` when a namespace is configured and
    /// `Connection::new` otherwise.
    fn build(&self) -> FactoryFuture<P> {
        // Clone the name so the boxed future owns its data and satisfies `'static`.
        let target_ns: Option<String> = self.namespace.clone();
        let fut = async move {
            if let Some(ns) = &target_ns {
                namespace::connection_for::<P>(ns)
            } else {
                Connection::<P>::new()
            }
        };
        Box::pin(fut)
    }
}
impl<P> ConnectionPoolBuilder<P>
where
    P: ProtocolState + AsyncProtocolInit + AsyncConstructible + 'static,
{
    /// Opens `size` connections with asynchronous protocol initialisation and
    /// returns the pool.
    ///
    /// If a namespace is set, connections are opened with
    /// `namespace::connection_for_async`. Otherwise a `NetlinkSocket` is created for
    /// `P::PROTOCOL` and its state is resolved with `P::resolve_async`. The initial
    /// connections come from the same [`AsyncFactory`] that is stored in the pool,
    /// so both paths build connections the same way.
    ///
    /// # Errors
    ///
    /// Returns the first error raised while opening a connection. Connections opened
    /// before the failure are dropped along with the channel.
    pub async fn build_async(self) -> Result<ConnectionPool<P>> {
        let factory: Arc<dyn Factory<P>> = Arc::new(AsyncFactory {
            namespace: self.namespace.clone(),
            _phantom: PhantomData,
        });
        let (tx, rx) = mpsc::channel(self.size);
        for _ in 0..self.size {
            let conn = factory.build().await?;
            // Not expected to fail: `rx` is still alive and the channel was sized for
            // exactly `size` connections.
            tx.send(conn).await.map_err(|_| Error::PoolClosed)?;
        }
        Ok(ConnectionPool {
            inner: Arc::new(PoolInner {
                available: tx,
                receiver: Mutex::new(rx),
                namespace: self.namespace,
                size: self.size,
                acquire_timeout: self.acquire_timeout,
                factory,
            }),
        })
    }
}
/// [`Factory`] that opens connections with asynchronous protocol initialisation.
///
/// The trait bounds live on the `Factory` impl, not on the struct, because the
/// struct itself only stores a namespace name.
struct AsyncFactory<P> {
    /// Namespace to open connections in. `None` means a plain socket is created
    /// and its state resolved asynchronously.
    namespace: Option<String>,
    /// Ties the factory to `P`. `fn() -> P` keeps the struct `Send + Sync`
    /// whatever `P` is.
    _phantom: PhantomData<fn() -> P>,
}
impl<P> Factory<P> for AsyncFactory<P>
where
    P: ProtocolState + AsyncProtocolInit + AsyncConstructible + 'static,
{
    /// Returns a future that opens one connection asynchronously.
    ///
    /// It uses `namespace::connection_for_async` when a namespace is configured.
    /// Otherwise it creates a `NetlinkSocket` for `P::PROTOCOL`, resolves the
    /// protocol state with `P::resolve_async`, and joins the two with
    /// `Connection::from_parts`.
    fn build(&self) -> FactoryFuture<P> {
        // Clone the name so the boxed future owns its data and satisfies `'static`.
        let target_ns: Option<String> = self.namespace.clone();
        Box::pin(async move {
            if let Some(ns) = &target_ns {
                return namespace::connection_for_async::<P>(ns).await;
            }
            let socket = crate::netlink::socket::NetlinkSocket::new(P::PROTOCOL)?;
            let state = P::resolve_async(&socket).await?;
            Ok(Connection::from_parts(socket, state))
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::netlink::protocol::Route;

    /// `new()` picks a non-zero size, a 5-second timeout and no namespace.
    #[test]
    fn builder_defaults() {
        let b = ConnectionPoolBuilder::<Route>::new();
        assert!(b.size >= 1);
        assert_eq!(b.acquire_timeout, Duration::from_secs(5));
        assert_eq!(b.namespace, None);
    }

    /// `Default` must produce the same configuration as `new()`.
    #[test]
    fn builder_default_matches_new() {
        let from_default = ConnectionPoolBuilder::<Route>::default();
        let from_new = ConnectionPoolBuilder::<Route>::new();
        assert_eq!(from_default.size, from_new.size);
        assert_eq!(from_default.acquire_timeout, from_new.acquire_timeout);
        assert_eq!(from_default.namespace, from_new.namespace);
    }

    /// A size of 0 would make `mpsc::channel` panic, so it is clamped to 1.
    #[test]
    fn builder_size_clamped_to_at_least_one() {
        let b = ConnectionPoolBuilder::<Route>::new().size(0);
        assert_eq!(b.size, 1);
    }

    /// Positive sizes pass through unchanged.
    #[test]
    fn builder_size_setter_keeps_positive_values() {
        let b = ConnectionPoolBuilder::<Route>::new().size(3);
        assert_eq!(b.size, 3);
    }

    #[test]
    fn builder_acquire_timeout_setter() {
        let b = ConnectionPoolBuilder::<Route>::new().acquire_timeout(Duration::from_millis(250));
        assert_eq!(b.acquire_timeout, Duration::from_millis(250));
    }

    #[test]
    fn builder_namespace_setter() {
        let b = ConnectionPoolBuilder::<Route>::new().namespace("myns");
        assert_eq!(b.namespace.as_deref(), Some("myns"));
    }
}