protosocket_rpc/client/
connection_pool.rs1use std::{
2 cell::RefCell,
3 future::Future,
4 pin::{pin, Pin},
5 sync::Mutex,
6 task::{Context, Poll},
7};
8
9use futures::FutureExt;
10use rand::{Rng, SeedableRng};
11
12use crate::{client::RpcClient, Message};
13
/// A strategy for establishing new [`RpcClient`] connections.
///
/// The connector must be `Clone`: [`ConnectionPool`] clones it for every
/// connection attempt and hands the clone to [`ClientConnector::connect`],
/// which takes `self` by value.
pub trait ClientConnector: Clone {
    /// The request message type sent over connections made by this connector.
    type Request: Message;
    /// The response message type received over connections made by this connector.
    type Response: Message;

    /// Establish a new connection and resolve to a client handle for it.
    ///
    /// The returned future must be `Send + 'static` because the pool runs it
    /// on a spawned tokio task rather than polling it inline.
    ///
    /// # Errors
    ///
    /// The future resolves to `Err` when the connection could not be
    /// established; the pool passes that error on to callers waiting on the
    /// attempt.
    fn connect(
        self,
    ) -> impl Future<Output = crate::Result<RpcClient<Self::Request, Self::Response>>> + Send + 'static;
}
35
36#[derive(Debug)]
46pub struct ConnectionPool<Connector: ClientConnector> {
47 connector: Connector,
48 connections: Vec<Mutex<ConnectionState<Connector::Request, Connector::Response>>>,
49}
50
51impl<Connector: ClientConnector> ConnectionPool<Connector> {
52 pub fn new(connector: Connector, connection_count: usize) -> Self {
56 Self {
57 connector,
58 connections: (0..connection_count)
59 .map(|_| Mutex::new(ConnectionState::Disconnected))
60 .collect(),
61 }
62 }
63
64 pub async fn get_connection_for_key(
66 &self,
67 key: usize,
68 ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
69 let slot = key % self.connections.len();
70
71 self.get_connection_by_slot(slot).await
72 }
73
74 pub async fn get_connection(
76 &self,
77 ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
78 thread_local! {
79 static THREAD_LOCAL_SMALL_RANDOM: RefCell<rand::rngs::SmallRng> = RefCell::new(rand::rngs::SmallRng::from_os_rng());
80 }
81
82 let slot = THREAD_LOCAL_SMALL_RANDOM
84 .with_borrow_mut(|rng| rng.random_range(0..self.connections.len()));
85
86 self.get_connection_by_slot(slot).await
87 }
88
89 async fn get_connection_by_slot(
90 &self,
91 slot: usize,
92 ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
93 let connection_state = &self.connections[slot];
94
95 let connecting_handle = loop {
97 let mut state = connection_state.lock().expect("internal mutex must work");
98 break match &mut *state {
99 ConnectionState::Connected(shared_connection) => {
100 if shared_connection.is_alive() {
101 return Ok(shared_connection.clone());
102 } else {
103 *state = ConnectionState::Disconnected;
104 continue;
105 }
106 }
107 ConnectionState::Connecting(join_handle) => join_handle.clone(),
108 ConnectionState::Disconnected => {
109 let connector = self.connector.clone();
110 let load = SpawnedConnect {
111 inner: tokio::task::spawn(connector.connect()),
112 }
113 .shared();
114 *state = ConnectionState::Connecting(load.clone());
115 continue;
116 }
117 };
118 };
119
120 match connecting_handle.await {
121 Ok(client) => Ok(reconcile_client_slot(connection_state, client)),
122 Err(connect_error) => {
123 let mut state = connection_state.lock().expect("internal mutex must work");
124 *state = ConnectionState::Disconnected;
125 Err(connect_error)
126 }
127 }
128 }
129}
130
131fn reconcile_client_slot<Request, Response>(
132 connection_state: &Mutex<ConnectionState<Request, Response>>,
133 client: RpcClient<Request, Response>,
134) -> RpcClient<Request, Response>
135where
136 Request: Message,
137 Response: Message,
138{
139 let mut state = connection_state.lock().expect("internal mutex must work");
140 match &mut *state {
141 ConnectionState::Connecting(_shared) => {
142 *state = ConnectionState::Connected(client.clone());
146 client
147 }
148 ConnectionState::Connected(rpc_client) => {
149 if rpc_client.is_alive() {
150 rpc_client.clone()
152 } else {
153 *state = ConnectionState::Connected(client.clone());
155 client
156 }
157 }
158 ConnectionState::Disconnected => {
159 *state = ConnectionState::Connected(client.clone());
161 client
162 }
163 }
164}
165
166struct SpawnedConnect<Request, Response>
167where
168 Request: Message,
169 Response: Message,
170{
171 inner: tokio::task::JoinHandle<crate::Result<RpcClient<Request, Response>>>,
172}
173impl<Request, Response> Future for SpawnedConnect<Request, Response>
174where
175 Request: Message,
176 Response: Message,
177{
178 type Output = crate::Result<RpcClient<Request, Response>>;
179
180 fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
181 match pin!(&mut self.inner).poll(context) {
182 Poll::Ready(Ok(client_result)) => Poll::Ready(client_result),
183 Poll::Ready(Err(_join_err)) => Poll::Ready(Err(crate::Error::ConnectionIsClosed)),
184 Poll::Pending => Poll::Pending,
185 }
186 }
187}
188
189#[derive(Debug)]
190enum ConnectionState<Request, Response>
191where
192 Request: Message,
193 Response: Message,
194{
195 Connecting(futures::future::Shared<SpawnedConnect<Request, Response>>),
196 Connected(RpcClient<Request, Response>),
197 Disconnected,
198}