use crate::Server;
use once_cell::sync::OnceCell;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};
use std::sync::Mutex;
/// A lazily created, fixed-size pool of [`Server`]s.
///
/// The constructor is `const` so a pool can live in a `static`; the shared
/// state is only built on the first call to [`ServerPool::get_server`].
#[derive(Debug)]
pub struct ServerPool(OnceCell<InnerPool>, usize);

impl ServerPool {
    /// Creates a pool that will run at most `max_servers` servers.
    ///
    /// Nothing is started until the first checkout.
    pub const fn new(max_servers: usize) -> Self {
        Self(OnceCell::new(), max_servers)
    }

    /// Checks a server out of the pool, blocking while every server is in use.
    ///
    /// # Panics
    ///
    /// Panics if the pool was created with `max_servers == 0`.
    pub fn get_server(&self) -> ServerHandle<'_> {
        let Self(cell, max_servers) = self;
        let inner = cell.get_or_init(|| InnerPool::new(*max_servers));
        inner.get_server()
    }
}
/// Shared state behind a `ServerPool`: a bounded queue of idle servers plus a
/// count of how many servers have been started so far.
#[derive(Debug)]
// The counter is guarded by a lock rather than an atomic so `get_server` can
// check it against capacity and bump it inside one critical section.
#[allow(clippy::mutex_atomic)]
struct InnerPool {
    /// Number of servers started so far; `get_server` never lets it exceed
    /// the channel capacity.
    servers_created: Mutex<usize>,
    /// Sending half, cloned into every handle so it can return its server.
    servers_tx: crossbeam_channel::Sender<Server>,
    /// Receiving half; idle servers wait here until checked out.
    servers_rx: crossbeam_channel::Receiver<Server>,
}
#[allow(clippy::mutex_atomic)]
impl InnerPool {
fn new(max_capacity: usize) -> Self {
assert!(max_capacity > 0);
let (servers_tx, servers_rx) = crossbeam_channel::bounded(max_capacity);
InnerPool {
servers_created: Mutex::new(0),
servers_tx,
servers_rx,
}
}
fn get_server(&self) -> ServerHandle {
if let Ok(server) = self.servers_rx.try_recv() {
return ServerHandle {
servers_tx: self.servers_tx.clone(),
server: Some(server),
lifetime_marker: PhantomData,
};
}
{
let mut servers_created = self.servers_created.lock().expect("poisoned mutex");
if *servers_created < self.servers_tx.capacity().unwrap() {
*servers_created += 1;
return ServerHandle {
servers_tx: self.servers_tx.clone(),
server: Some(Server::run()),
lifetime_marker: PhantomData,
};
}
}
ServerHandle {
servers_tx: self.servers_tx.clone(),
server: Some(
self.servers_rx
.recv()
.expect("all senders unexpectedly dropped"),
),
lifetime_marker: PhantomData,
}
}
}
impl Drop for InnerPool {
    /// Drops every server that has been returned to the pool.
    ///
    /// A handle's lifetime is tied to the pool borrow that produced it. Holding
    /// `&mut self` therefore means no `ServerHandle` from this pool is still
    /// alive, and every server that will ever come back is already queued.
    ///
    /// The drain is non-blocking on purpose. This pool owns a `Sender`, so the
    /// channel can never disconnect, and a blocking `recv` per created server
    /// would hang forever if any server was never returned. That happens if
    /// `verify_and_clear` panics inside `ServerHandle::drop` (the server is
    /// dropped there instead of sent back) or if a handle is leaked with
    /// `mem::forget`.
    fn drop(&mut self) {
        self.servers_rx.try_iter().for_each(drop);
    }
}
/// A server checked out of a pool.
///
/// Dereferences to the underlying [`Server`]. When dropped, it runs the
/// server's `verify_and_clear` and then sends the server back to its pool.
#[derive(Debug)]
pub struct ServerHandle<'a> {
    /// Channel back to the pool the server was taken from.
    servers_tx: crossbeam_channel::Sender<Server>,
    /// The checked-out server; `None` only once `drop` has taken it.
    server: Option<Server>,
    /// Ties the handle to the borrow of the pool that produced it.
    lifetime_marker: PhantomData<&'a ()>,
}
impl Deref for ServerHandle<'_> {
    type Target = Server;

    /// Borrows the checked-out server.
    fn deref(&self) -> &Self::Target {
        // Only `drop` ever takes the server out, so it is present here.
        let slot = &self.server;
        slot.as_ref().unwrap()
    }
}
impl DerefMut for ServerHandle<'_> {
fn deref_mut(&mut self) -> &mut Server {
self.server.as_mut().unwrap()
}
}
impl Drop for ServerHandle<'_> {
    /// Runs `verify_and_clear` on the server, then hands it back to the pool.
    ///
    /// Note: if `verify_and_clear` panics, the server is dropped here during
    /// unwinding and is never returned to the pool.
    fn drop(&mut self) {
        let mut returned = self.server.take().unwrap();
        // Verify before the server becomes visible to the next borrower.
        returned.verify_and_clear();
        let sent = self.servers_tx.send(returned);
        sent.expect("all receivers unexpectedly dropped");
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    const MAX_SERVERS: usize = 5;
    static POOL: ServerPool = ServerPool::new(MAX_SERVERS);

    /// Checks that the pool never hands out more than `MAX_SERVERS` servers at once.
    #[test]
    fn test_max_threads() {
        use std::sync::atomic::{AtomicUsize, Ordering::SeqCst};
        // Must be a multiple of MAX_SERVERS. The barrier releases threads in
        // groups of MAX_SERVERS, so a partial final group would wait forever.
        const THREADS: usize = 2 * MAX_SERVERS;
        let concurrent_server_handles = AtomicUsize::new(0);
        let desired_concurrency_reached = std::sync::Barrier::new(MAX_SERVERS);
        crossbeam_utils::thread::scope(|s| {
            for _ in 0..THREADS {
                s.spawn(|_| {
                    let _server = POOL.get_server();
                    desired_concurrency_reached.wait();
                    // `fetch_add` returns the count *before* this thread, so
                    // this thread holds handle number `prev_value + 1`.
                    let prev_value = concurrent_server_handles.fetch_add(1, SeqCst);
                    if prev_value >= MAX_SERVERS {
                        panic!("too many concurrent server handles: {}", prev_value + 1);
                    }
                    std::thread::sleep(std::time::Duration::from_millis(500));
                    concurrent_server_handles.fetch_sub(1, SeqCst);
                });
            }
        })
        .unwrap();
    }
}