use std::{
cell::RefCell,
future::Future,
pin::{pin, Pin},
sync::Mutex,
task::{Context, Poll},
};
use futures::FutureExt;
use rand::{Rng, SeedableRng};
use crate::{client::RpcClient, Message};
/// Factory that opens new [`RpcClient`] connections for a [`ConnectionPool`].
///
/// The pool clones the connector every time a slot needs a fresh connection, and
/// [`ClientConnector::connect`] consumes that clone.
pub trait ClientConnector: Clone {
    /// Request message type of the clients this connector produces.
    type Request: Message;
    /// Response message type of the clients this connector produces.
    type Response: Message;

    /// Opens a new client connection.
    ///
    /// The returned future must be `Send + 'static`: the pool hands it to
    /// `tokio::task::spawn`, so it runs independently of any single caller.
    fn connect(
        self,
    ) -> impl Future<Output = crate::Result<RpcClient<Self::Request, Self::Response>>> + Send + 'static;
}
/// A fixed number of lazily connected RPC client slots.
///
/// Each slot caches at most one client. A slot connects on first use. Concurrent
/// callers on the same slot share a single in-flight connect attempt. A cached client
/// that no longer reports itself alive is replaced by a new connection.
#[derive(Debug)]
pub struct ConnectionPool<Connector>
where
    Connector: ClientConnector,
{
    /// Cloned whenever a slot needs a new connection.
    connector: Connector,
    /// One lock-protected state machine per slot.
    connections: Vec<Mutex<ConnectionState<Connector::Request, Connector::Response>>>,
}
impl<Connector: ClientConnector> ConnectionPool<Connector> {
    /// Creates a pool with `connection_count` slots, all initially disconnected.
    ///
    /// No connection is opened here. Each slot connects lazily the first time it is
    /// requested.
    ///
    /// A pool built with `connection_count == 0` has no slot to hand out, so
    /// [`Self::get_connection`] and [`Self::get_connection_for_key`] panic on it.
    pub fn new(connector: Connector, connection_count: usize) -> Self {
        Self {
            connector,
            connections: (0..connection_count)
                .map(|_| Mutex::new(ConnectionState::Disconnected))
                .collect(),
        }
    }

    /// Returns a client from the slot selected by `key`.
    ///
    /// Keys map to slots via `key % slot_count`, so a given key always uses the same
    /// slot of a given pool.
    ///
    /// # Errors
    ///
    /// Returns the connect error if the slot has no live client and connecting fails.
    ///
    /// # Panics
    ///
    /// Panics if the pool was created with zero slots.
    pub async fn get_connection_for_key(
        &self,
        key: usize,
    ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
        let slot = key % self.connections.len();
        self.get_connection_by_slot(slot).await
    }

    /// Returns a client from a randomly chosen slot.
    ///
    /// The slot is drawn from a per-thread `SmallRng`, which avoids any shared RNG state
    /// between threads.
    ///
    /// # Errors
    ///
    /// Returns the connect error if the chosen slot has no live client and connecting
    /// fails.
    ///
    /// # Panics
    ///
    /// Panics if the pool was created with zero slots (the sampled range is empty).
    pub async fn get_connection(
        &self,
    ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
        thread_local! {
            static THREAD_LOCAL_SMALL_RANDOM: RefCell<rand::rngs::SmallRng> =
                RefCell::new(rand::rngs::SmallRng::from_os_rng());
        }
        let slot = THREAD_LOCAL_SMALL_RANDOM
            .with_borrow_mut(|rng| rng.random_range(0..self.connections.len()));
        self.get_connection_by_slot(slot).await
    }

    /// Returns the slot's live client, or connects (or joins an in-flight connect) and
    /// caches the result.
    async fn get_connection_by_slot(
        &self,
        slot: usize,
    ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
        let connection_state = &self.connections[slot];

        // Resolve the slot to a connect attempt while holding the lock, then release it.
        // It is a blocking `std::sync::Mutex`, so the guard must not live across `.await`.
        let connecting_handle = loop {
            let mut state = connection_state.lock().expect("internal mutex must work");
            match &mut *state {
                ConnectionState::Connected(shared_connection) => {
                    if shared_connection.is_alive() {
                        return Ok(shared_connection.clone());
                    }
                    // The cached client died: forget it and start a new attempt on the
                    // next iteration.
                    *state = ConnectionState::Disconnected;
                }
                ConnectionState::Connecting(join_handle) => break join_handle.clone(),
                ConnectionState::Disconnected => {
                    let connector = self.connector.clone();
                    // Spawned so the attempt completes even if every waiter is cancelled.
                    // `Shared` lets concurrent callers await the same attempt.
                    let load = SpawnedConnect {
                        inner: tokio::task::spawn(connector.connect()),
                    }
                    .shared();
                    *state = ConnectionState::Connecting(load.clone());
                    break load;
                }
            }
        };

        // Await a clone so `connecting_handle` itself is never polled. A `Shared` that has
        // been polled to completion drops its inner pointer, and the identity check below
        // needs that pointer.
        match connecting_handle.clone().await {
            Ok(client) => Ok(reconcile_client_slot(connection_state, client)),
            Err(connect_error) => {
                let mut state = connection_state.lock().expect("internal mutex must work");
                // Only clear the slot if it still holds the attempt that just failed.
                // A concurrent caller may already have started a newer attempt or
                // installed a live client. Clobbering that would drop a good connection
                // from the pool and force a redundant reconnect.
                let still_this_attempt = matches!(
                    &*state,
                    ConnectionState::Connecting(current) if current.ptr_eq(&connecting_handle)
                );
                if still_this_attempt {
                    *state = ConnectionState::Disconnected;
                }
                Err(connect_error)
            }
        }
    }
}
/// Stores a freshly connected `client` in the slot unless a live client is already there.
///
/// Returns whichever client now occupies the slot: the pre-existing one if it is still
/// alive, otherwise `client`. This lets every caller that finishes a connect attempt
/// converge on the same cached client.
fn reconcile_client_slot<Request: Message, Response: Message>(
    connection_state: &Mutex<ConnectionState<Request, Response>>,
    client: RpcClient<Request, Response>,
) -> RpcClient<Request, Response> {
    let mut slot = connection_state.lock().expect("internal mutex must work");

    // A healthy client already installed (e.g. by a concurrent caller) wins.
    if let ConnectionState::Connected(installed) = &mut *slot {
        if installed.is_alive() {
            return installed.clone();
        }
    }

    // Connecting, Disconnected, or a dead Connected client: install ours.
    *slot = ConnectionState::Connected(client.clone());
    client
}
/// Future wrapper around a spawned [`ClientConnector::connect`] task.
///
/// Polling it polls the task's `JoinHandle`. The task's own result is passed through
/// unchanged.
struct SpawnedConnect<Request: Message, Response: Message> {
    /// Handle to the Tokio task that is opening the connection.
    inner: tokio::task::JoinHandle<crate::Result<RpcClient<Request, Response>>>,
}
impl<Request, Response> Future for SpawnedConnect<Request, Response>
where
Request: Message,
Response: Message,
{
type Output = crate::Result<RpcClient<Request, Response>>;
fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
match pin!(&mut self.inner).poll(context) {
Poll::Ready(Ok(client_result)) => Poll::Ready(client_result),
Poll::Ready(Err(_join_err)) => Poll::Ready(Err(crate::Error::ConnectionIsClosed)),
Poll::Pending => Poll::Pending,
}
}
}
/// State machine for a single pool slot.
#[derive(Debug)]
enum ConnectionState<Request: Message, Response: Message> {
    /// A connect attempt is in flight; callers clone and await this shared future.
    Connecting(futures::future::Shared<SpawnedConnect<Request, Response>>),
    /// A client is cached and handed out (cloned) to callers while it stays alive.
    Connected(RpcClient<Request, Response>),
    /// No client and no attempt in flight; the next caller starts a connect.
    Disconnected,
}