protosocket_rpc/client/
connection_pool.rs1use std::{
2 cell::RefCell,
3 future::Future,
4 pin::{pin, Pin},
5 sync::Mutex,
6 task::{Context, Poll},
7};
8
9use futures::FutureExt;
10use rand::{Rng, SeedableRng};
11
12use crate::{client::RpcClient, Message};
13
14pub trait ClientConnector: Clone {
18 type Request: Message;
20 type Response: Message;
22
23 fn connect(
34 self,
35 ) -> impl Future<Output = crate::Result<RpcClient<Self::Request, Self::Response>>> + Send + 'static;
36}
37
38#[derive(Debug)]
48pub struct ConnectionPool<Connector: ClientConnector> {
49 connector: Connector,
50 connections: Vec<Mutex<ConnectionState<Connector::Request, Connector::Response>>>,
51}
52
53impl<Connector: ClientConnector> ConnectionPool<Connector> {
54 pub fn new(connector: Connector, connection_count: usize) -> Self {
58 Self {
59 connector,
60 connections: (0..connection_count)
61 .map(|_| Mutex::new(ConnectionState::Disconnected))
62 .collect(),
63 }
64 }
65
66 pub async fn get_connection_for_key(
68 &self,
69 key: usize,
70 ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
71 let slot = key % self.connections.len();
72
73 self.get_connection_by_slot(slot).await
74 }
75
76 pub async fn get_connection(
78 &self,
79 ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
80 thread_local! {
81 static THREAD_LOCAL_SMALL_RANDOM: RefCell<rand::rngs::SmallRng> = RefCell::new(rand::rngs::SmallRng::from_os_rng());
82 }
83
84 let slot = THREAD_LOCAL_SMALL_RANDOM
86 .with_borrow_mut(|rng| rng.random_range(0..self.connections.len()));
87
88 self.get_connection_by_slot(slot).await
89 }
90
91 async fn get_connection_by_slot(
92 &self,
93 slot: usize,
94 ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
95 let connection_state = &self.connections[slot];
96
97 let connecting_handle = loop {
99 let mut state = connection_state.lock().expect("internal mutex must work");
100 break match &mut *state {
101 ConnectionState::Connected(shared_connection) => {
102 if shared_connection.is_alive() {
103 return Ok(shared_connection.clone());
104 } else {
105 *state = ConnectionState::Disconnected;
106 continue;
107 }
108 }
109 ConnectionState::Connecting(join_handle) => join_handle.clone(),
110 ConnectionState::Disconnected => {
111 let connector = self.connector.clone();
112 let load = SpawnedConnect {
113 inner: tokio::task::spawn(connector.connect()),
114 }
115 .shared();
116 *state = ConnectionState::Connecting(load.clone());
117 continue;
118 }
119 };
120 };
121
122 match connecting_handle.await {
123 Ok(client) => Ok(reconcile_client_slot(connection_state, client)),
124 Err(connect_error) => {
125 let mut state = connection_state.lock().expect("internal mutex must work");
126 *state = ConnectionState::Disconnected;
127 Err(connect_error)
128 }
129 }
130 }
131}
132
133fn reconcile_client_slot<Request, Response>(
134 connection_state: &Mutex<ConnectionState<Request, Response>>,
135 client: RpcClient<Request, Response>,
136) -> RpcClient<Request, Response>
137where
138 Request: Message,
139 Response: Message,
140{
141 let mut state = connection_state.lock().expect("internal mutex must work");
142 match &mut *state {
143 ConnectionState::Connecting(_shared) => {
144 *state = ConnectionState::Connected(client.clone());
148 client
149 }
150 ConnectionState::Connected(rpc_client) => {
151 if rpc_client.is_alive() {
152 rpc_client.clone()
154 } else {
155 *state = ConnectionState::Connected(client.clone());
157 client
158 }
159 }
160 ConnectionState::Disconnected => {
161 *state = ConnectionState::Connected(client.clone());
163 client
164 }
165 }
166}
167
168struct SpawnedConnect<Request, Response>
169where
170 Request: Message,
171 Response: Message,
172{
173 inner: tokio::task::JoinHandle<crate::Result<RpcClient<Request, Response>>>,
174}
175impl<Request, Response> Future for SpawnedConnect<Request, Response>
176where
177 Request: Message,
178 Response: Message,
179{
180 type Output = crate::Result<RpcClient<Request, Response>>;
181
182 fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
183 match pin!(&mut self.inner).poll(context) {
184 Poll::Ready(Ok(client_result)) => Poll::Ready(client_result),
185 Poll::Ready(Err(_join_err)) => Poll::Ready(Err(crate::Error::ConnectionIsClosed)),
186 Poll::Pending => Poll::Pending,
187 }
188 }
189}
190
191#[derive(Debug)]
192enum ConnectionState<Request, Response>
193where
194 Request: Message,
195 Response: Message,
196{
197 Connecting(futures::future::Shared<SpawnedConnect<Request, Response>>),
198 Connected(RpcClient<Request, Response>),
199 Disconnected,
200}