use std::{path::Path, sync::Arc};
use nix::sys::socket::UnixAddr;
use tokio::sync::{Mutex, MutexGuard, RwLock};
use super::{IOError, Listener, SocketAddress, Stream};
#[derive(Debug)]
pub enum PoolError {
NoAddressesSpecified,
InvalidSourceAddress,
}
#[derive(Debug)]
pub struct PoolGuard<'item> {
value: MutexGuard<'item, Stream>,
}
#[derive(Debug)]
pub struct StreamPool {
addresses: Vec<SocketAddress>, handles: Vec<Mutex<Stream>>,
}
pub type SharedStreamPool = Arc<RwLock<StreamPool>>;
impl Drop for PoolGuard<'_> {
fn drop(&mut self) {
self.value.reset();
}
}
impl<'item> PoolGuard<'item> {
#[must_use]
pub fn new(value: MutexGuard<'item, Stream>) -> Self {
Self { value }
}
}
impl std::ops::Deref for PoolGuard<'_> {
type Target = Stream;
fn deref(&self) -> &Self::Target {
&self.value
}
}
impl std::ops::DerefMut for PoolGuard<'_> {
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.value
}
}
impl StreamPool {
pub fn new(
start_address: SocketAddress,
mut count: u8,
) -> Result<Self, IOError> {
if count == 0 {
return Err(IOError::PoolError(PoolError::NoAddressesSpecified));
}
println!("StreamPool start address: {start_address}");
let mut addresses = Vec::new();
let mut addr = start_address;
while count > 0 {
addresses.push(addr.clone());
count -= 1;
if count == 0 {
break; }
addr = addr.next_address()?;
}
Ok(Self::with_addresses(addresses))
}
pub fn single(address: SocketAddress) -> Result<Self, IOError> {
Self::new(address, 1)
}
#[must_use]
fn with_addresses(
addresses: impl IntoIterator<Item = SocketAddress>,
) -> Self {
let addresses: Vec<SocketAddress> = addresses.into_iter().collect();
let streams: Vec<Stream> = addresses.iter().map(Stream::new).collect();
let handles = streams.into_iter().map(Mutex::new).collect();
Self { addresses, handles }
}
#[must_use]
pub fn shared(self) -> SharedStreamPool {
Arc::new(RwLock::new(self))
}
#[must_use]
pub fn len(&self) -> usize {
self.addresses.len()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[must_use]
pub fn busy(&self) -> bool {
for h in &self.handles {
if h.try_lock().is_ok() {
return false;
}
}
true
}
pub async fn get(&self) -> PoolGuard<'_> {
assert!(
!self.handles.is_empty(),
"empty handles in AsyncPool. Bad init?"
);
let iter = self.handles.iter().map(|h| {
let l = h.lock();
Box::pin(l)
});
let (guard, _, _) = futures::future::select_all(iter).await;
PoolGuard::new(guard)
}
pub fn listen(&self) -> Result<Vec<Listener>, IOError> {
let mut listeners = Vec::new();
for addr in &self.addresses {
let listener = Listener::listen(addr)?;
listeners.push(listener);
}
Ok(listeners)
}
pub fn expand_to(&mut self, size: u8) -> Result<(), IOError> {
println!("StreamPool: expanding async pool to {size}");
let size = size as usize;
if let Some(last_address) = self.addresses.last().cloned() {
let mut next = last_address;
let count = self.addresses.len();
for _ in count..size {
next = next.next_address()?;
self.handles.push(Mutex::new(Stream::new(&next)));
self.addresses.push(next.clone());
}
}
Ok(())
}
pub fn listen_to(&mut self, size: u8) -> Result<Vec<Listener>, IOError> {
println!("StreamPool: listening async pool to {size}");
let size = size as usize;
let mut listeners = Vec::new();
if let Some(last_address) = self.addresses.last().cloned() {
let mut next = last_address;
let count = self.addresses.len();
for _ in count..size {
next = next.next_address()?;
self.addresses.push(next.clone());
let listener = Listener::listen(&next)?;
listeners.push(listener);
}
}
Ok(listeners)
}
#[must_use]
pub fn to_streams(self) -> Vec<Stream> {
self.handles.into_iter().map(tokio::sync::Mutex::into_inner).collect()
}
}
fn next_usock_path(path: &Path) -> Result<String, IOError> {
let path =
path.as_os_str().to_str().ok_or(IOError::ConnectAddressInvalid)?;
if let Some(underscore_index) = path.rfind('_') {
let num_str = &path[underscore_index + 1..];
let num = num_str.parse::<usize>();
Ok(match num {
Ok(index) => {
format!("{}_{}", &path[0..underscore_index], index + 1)
}
Err(_) => format!("{path}_0"), })
} else {
Ok(format!("{path}_0"))
}
}
impl SocketAddress {
pub(crate) fn next_address(&self) -> Result<Self, IOError> {
match self {
Self::Unix(usock) => match usock.path() {
Some(path) => {
let path: &str = &next_usock_path(path)?;
let addr = UnixAddr::new(path)?;
Ok(Self::Unix(addr))
}
None => {
Err(IOError::PoolError(PoolError::InvalidSourceAddress))
}
},
#[cfg(feature = "vm")]
Self::Vsock(vsock) => Ok(Self::new_vsock(
vsock.cid(),
vsock.port() + 1,
super::vsock_svm_flags(*vsock),
)),
}
}
}
#[cfg(test)]
mod test {
use std::path::PathBuf;
use super::*;
#[tokio::test]
async fn test_async_pool_available() {
let start_addr = SocketAddress::new_unix("/tmp/never.sock");
let pool = StreamPool::new(start_addr, 2).unwrap();
let first = pool.get().await;
let second = pool.get().await;
drop(first);
let third = pool.get().await;
let result = tokio::time::timeout(
std::time::Duration::from_millis(200),
async {
let _fourth = pool.get().await;
},
)
.await;
drop(third);
drop(second);
assert!(result.is_err()); }
#[tokio::test]
async fn test_pool_guard_hatch() {
let server_addr =
SocketAddress::new_unix("/tmp/test_pool_guard_hatch.sock");
let server = Listener::listen(&server_addr).unwrap();
let server_task = tokio::task::spawn(async move {
let _stream = server.accept().await.unwrap();
tokio::time::sleep(std::time::Duration::from_secs(1)).await;
});
let pool = StreamPool::new(server_addr, 1).unwrap().shared();
let pool_clone = pool.clone();
let client_task = tokio::task::spawn(async move {
let borrowed_pool = pool_clone.read().await;
let mut stream = borrowed_pool.get().await;
let _ = stream.call(&[1]).await;
});
tokio::time::sleep(std::time::Duration::from_millis(300)).await;
client_task.abort();
let borrowed_pool = pool.read().await;
let stream = borrowed_pool.get().await;
assert!(!stream.is_connected());
server_task.abort();
}
#[test]
fn next_usock_path_works() {
assert_eq!(
next_usock_path(&PathBuf::from("basic")).unwrap(),
"basic_0"
);
assert_eq!(next_usock_path(&PathBuf::from("")).unwrap(), "_0");
assert_eq!(
next_usock_path(&PathBuf::from("with_underscore_elsewhere"))
.unwrap(),
"with_underscore_elsewhere_0"
);
assert_eq!(
next_usock_path(&PathBuf::from("with_underscore_at_end_")).unwrap(),
"with_underscore_at_end__0"
);
assert_eq!(
next_usock_path(&PathBuf::from("good_num_2")).unwrap(),
"good_num_3"
);
assert_eq!(
next_usock_path(&PathBuf::from("good_num_34")).unwrap(),
"good_num_35"
);
assert_eq!(
next_usock_path(&PathBuf::from("good_num_999")).unwrap(),
"good_num_1000"
);
}
}