protosocket_rpc/client/
connection_pool.rs1use std::{
2 cell::RefCell,
3 future::Future,
4 pin::{pin, Pin},
5 sync::Mutex,
6 task::{Context, Poll},
7};
8
9use futures::FutureExt;
10use rand::{Rng, SeedableRng};
11
12use crate::{client::RpcClient, Message};
13
14pub trait ClientConnector: Clone {
18 type Request: Message;
19 type Response: Message;
20
21 fn connect(
32 self,
33 ) -> impl Future<Output = crate::Result<RpcClient<Self::Request, Self::Response>>> + Send + 'static;
34}
35
36#[derive(Debug)]
46pub struct ConnectionPool<Connector: ClientConnector> {
47 connector: Connector,
48 connections: Vec<Mutex<ConnectionState<Connector::Request, Connector::Response>>>,
49}
50
51impl<Connector: ClientConnector> ConnectionPool<Connector> {
52 pub fn new(connector: Connector, connection_count: usize) -> Self {
56 Self {
57 connector,
58 connections: (0..connection_count)
59 .map(|_| Mutex::new(ConnectionState::Disconnected))
60 .collect(),
61 }
62 }
63
64 pub async fn get_connection(
66 &self,
67 ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
68 thread_local! {
69 static THREAD_LOCAL_SMALL_RANDOM: RefCell<rand::rngs::SmallRng> = RefCell::new(rand::rngs::SmallRng::from_os_rng());
70 }
71
72 let slot = THREAD_LOCAL_SMALL_RANDOM
74 .with_borrow_mut(|rng| rng.random_range(0..self.connections.len()));
75 let connection_state = &self.connections[slot];
76
77 let connecting_handle = loop {
79 let mut state = connection_state.lock().expect("internal mutex must work");
80 break match &mut *state {
81 ConnectionState::Connected(shared_connection) => {
82 if shared_connection.is_alive() {
83 return Ok(shared_connection.clone());
84 } else {
85 *state = ConnectionState::Disconnected;
86 continue;
87 }
88 }
89 ConnectionState::Connecting(join_handle) => join_handle.clone(),
90 ConnectionState::Disconnected => {
91 let connector = self.connector.clone();
92 let load = SpawnedConnect {
93 inner: tokio::task::spawn(connector.connect()),
94 }
95 .shared();
96 *state = ConnectionState::Connecting(load.clone());
97 continue;
98 }
99 };
100 };
101
102 match connecting_handle.await {
103 Ok(client) => Ok(reconcile_client_slot(connection_state, client)),
104 Err(connect_error) => {
105 let mut state = connection_state.lock().expect("internal mutex must work");
106 *state = ConnectionState::Disconnected;
107 Err(connect_error)
108 }
109 }
110 }
111}
112
113fn reconcile_client_slot<Request, Response>(
114 connection_state: &Mutex<ConnectionState<Request, Response>>,
115 client: RpcClient<Request, Response>,
116) -> RpcClient<Request, Response>
117where
118 Request: Message,
119 Response: Message,
120{
121 let mut state = connection_state.lock().expect("internal mutex must work");
122 match &mut *state {
123 ConnectionState::Connecting(_shared) => {
124 *state = ConnectionState::Connected(client.clone());
128 client
129 }
130 ConnectionState::Connected(rpc_client) => {
131 if rpc_client.is_alive() {
132 rpc_client.clone()
134 } else {
135 *state = ConnectionState::Connected(client.clone());
137 client
138 }
139 }
140 ConnectionState::Disconnected => {
141 *state = ConnectionState::Connected(client.clone());
143 client
144 }
145 }
146}
147
148struct SpawnedConnect<Request, Response>
149where
150 Request: Message,
151 Response: Message,
152{
153 inner: tokio::task::JoinHandle<crate::Result<RpcClient<Request, Response>>>,
154}
155impl<Request, Response> Future for SpawnedConnect<Request, Response>
156where
157 Request: Message,
158 Response: Message,
159{
160 type Output = crate::Result<RpcClient<Request, Response>>;
161
162 fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
163 match pin!(&mut self.inner).poll(context) {
164 Poll::Ready(Ok(client_result)) => Poll::Ready(client_result),
165 Poll::Ready(Err(_join_err)) => Poll::Ready(Err(crate::Error::ConnectionIsClosed)),
166 Poll::Pending => Poll::Pending,
167 }
168 }
169}
170
171#[derive(Debug)]
172enum ConnectionState<Request, Response>
173where
174 Request: Message,
175 Response: Message,
176{
177 Connecting(futures::future::Shared<SpawnedConnect<Request, Response>>),
178 Connected(RpcClient<Request, Response>),
179 Disconnected,
180}