protosocket_rpc/client/connection_pool.rs

use std::{
    cell::RefCell,
    future::Future,
    pin::{pin, Pin},
    sync::Mutex,
    task::{Context, Poll},
};

use futures::FutureExt;
use rand::{Rng, SeedableRng};

use crate::{client::RpcClient, Message};

/// A connection strategy for protosocket RPC clients.
///
/// This is called asynchronously by the connection pool to create new connections.
pub trait ClientConnector: Clone {
    type Request: Message;
    type Response: Message;

    /// Connect to the server and return a new RpcClient. See [`crate::client::connect`] for
    /// the typical way to connect. This is called by the ConnectionPool when it needs a new
    /// connection.
    ///
    /// Your returned future needs to be `'static`, and your connector needs to be cheap to
    /// clone. One easy way to achieve both is to impl ClientConnector on `Arc<YourConnectorType>`
    /// instead of directly on `YourConnectorType`, as sketched below.
    ///
    /// If you have rolling credentials, initialization messages, changing endpoints, or other
    /// adaptive connection logic, this is the place to handle it or consult those sources of truth.
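    ///
    /// A minimal sketch of that `Arc` pattern, assuming the crate's `Result`, `RpcClient`,
    /// and `ClientConnector` are in scope; `MyConnector`, `MyRequest`, `MyResponse`, and the
    /// body of `connect` are illustrative stand-ins for your own connection setup:
    ///
    /// ```ignore
    /// use std::{future::Future, net::SocketAddr, sync::Arc};
    ///
    /// struct MyConnector {
    ///     address: SocketAddr,
    /// }
    ///
    /// impl ClientConnector for Arc<MyConnector> {
    ///     type Request = MyRequest;   // your request Message type
    ///     type Response = MyResponse; // your response Message type
    ///
    ///     fn connect(
    ///         self,
    ///     ) -> impl Future<Output = Result<RpcClient<Self::Request, Self::Response>>> + Send + 'static {
    ///         async move {
    ///             // Connect to self.address, spawn the connection's driver future onto
    ///             // your runtime, and return the resulting RpcClient.
    ///             todo!()
    ///         }
    ///     }
    /// }
    /// ```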
    fn connect(
        self,
    ) -> impl Future<Output = crate::Result<RpcClient<Self::Request, Self::Response>>> + Send + 'static;
}

/// A connection pool for protosocket RPC clients.
///
/// Protosocket-rpc connections are shared and multiplexed, so this vends cloned handles.
/// You can hold onto a handle from the pool for as long as you want. There is a small
/// synchronization cost to getting a handle from the pool, so caching one is a good idea.
/// If you want to spread load further, make a pool with as many "slots" as you like to
/// dilute contention on the connection state locks. Each lock is typically held only for
/// the time it takes to clone an `Arc`, so the synchronization is nanosecond-scale per
/// connection; with several connections you will rarely contend.
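///
/// A hedged usage sketch (`my_connector` is any [`ClientConnector`] implementation, and
/// the slot count of 4 and the `tenant_hash` key are arbitrary examples):
///
/// ```ignore
/// let pool = ConnectionPool::new(my_connector, 4);
///
/// // Pick a uniformly random slot; the slot connects lazily on first use.
/// let client = pool.get_connection().await?;
///
/// // Or pin related work to a consistent slot, e.g. by hashing a tenant id.
/// let client_for_tenant = pool.get_connection_for_key(tenant_hash).await?;
/// ```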
#[derive(Debug)]
pub struct ConnectionPool<Connector: ClientConnector> {
    connector: Connector,
    connections: Vec<Mutex<ConnectionState<Connector::Request, Connector::Response>>>,
}

impl<Connector: ClientConnector> ConnectionPool<Connector> {
    /// Create a new connection pool.
    ///
    /// The pool lazily connects up to `connection_count` slots on demand and replaces any
    /// connection that has died the next time its slot is requested.
    pub fn new(connector: Connector, connection_count: usize) -> Self {
        Self {
            connector,
            connections: (0..connection_count)
                .map(|_| Mutex::new(ConnectionState::Disconnected))
                .collect(),
        }
    }

    /// Get a consistent connection from the pool for a given key.
    ///
    /// The same key always maps to the same slot (`key % connection_count`), so requests
    /// that share a key share a connection.
    pub async fn get_connection_for_key(
        &self,
        key: usize,
    ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
        let slot = key % self.connections.len();

        self.get_connection_by_slot(slot).await
    }

    /// Get a connection from the pool, choosing a slot uniformly at random.
    pub async fn get_connection(
        &self,
    ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
        thread_local! {
            // A cheap, non-cryptographic RNG per thread: slot selection needs speed, not unpredictability.
            static THREAD_LOCAL_SMALL_RANDOM: RefCell<rand::rngs::SmallRng> = RefCell::new(rand::rngs::SmallRng::from_os_rng());
        }

        // This is the only place the thread-local is borrowed, so the mutable borrow in
        // `with_borrow_mut` cannot conflict with another borrow.
        let slot = THREAD_LOCAL_SMALL_RANDOM
            .with_borrow_mut(|rng| rng.random_range(0..self.connections.len()));

        self.get_connection_by_slot(slot).await
    }

    async fn get_connection_by_slot(
        &self,
        slot: usize,
    ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
        let connection_state = &self.connections[slot];

        // The connection state lives behind a std Mutex, so every await must stay outside the
        // guard's scope (this satisfies clippy's `await_holding_lock` lint and avoids holding
        // the lock while a connection is being established).
        let connecting_handle = loop {
            let mut state = connection_state.lock().expect("internal mutex must work");
            break match &mut *state {
                ConnectionState::Connected(shared_connection) => {
                    if shared_connection.is_alive() {
                        return Ok(shared_connection.clone());
                    } else {
                        *state = ConnectionState::Disconnected;
                        continue;
                    }
                }
                ConnectionState::Connecting(join_handle) => join_handle.clone(),
                ConnectionState::Disconnected => {
                    let connector = self.connector.clone();
                    let load = SpawnedConnect {
                        inner: tokio::task::spawn(connector.connect()),
                    }
                    .shared();
                    *state = ConnectionState::Connecting(load.clone());
                    continue;
                }
            };
        };

        match connecting_handle.await {
            Ok(client) => Ok(reconcile_client_slot(connection_state, client)),
            Err(connect_error) => {
                let mut state = connection_state.lock().expect("internal mutex must work");
                *state = ConnectionState::Disconnected;
                Err(connect_error)
            }
        }
    }
}

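/// Record a freshly connected client in its pool slot, preferring whichever live client
/// won any race. Every caller that awaited the shared connect future passes through here,
/// so the slot converges on a single, healthy, shared client.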
fn reconcile_client_slot<Request, Response>(
    connection_state: &Mutex<ConnectionState<Request, Response>>,
    client: RpcClient<Request, Response>,
) -> RpcClient<Request, Response>
where
    Request: Message,
    Response: Message,
{
    let mut state = connection_state.lock().expect("internal mutex must work");
    match &mut *state {
        ConnectionState::Connecting(_shared) => {
            // Here we drop the shared handle. Any other tasks still waiting on it will be notified
            // when the spawned connection task completes; when they come to reconcile with the
            // connection slot, they will favor this connection and drop their own.
            *state = ConnectionState::Connected(client.clone());
            client
        }
        ConnectionState::Connected(rpc_client) => {
            if rpc_client.is_alive() {
                // someone else beat us to it
                rpc_client.clone()
            } else {
                // this one is dead too, so replace it with our new client
                *state = ConnectionState::Connected(client.clone());
                client
            }
        }
        ConnectionState::Disconnected => {
            // we raced with a disconnect, but we have a new client, so use it
            *state = ConnectionState::Connected(client.clone());
            client
        }
    }
}

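/// A connect attempt spawned onto the tokio runtime, wrapped as a future so it can be
/// handed to `futures::future::Shared` and awaited by every caller hitting the same slot.
/// A failure to join the spawned task (panic or cancellation) is reported as
/// `Error::ConnectionIsClosed` rather than as a join error.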
struct SpawnedConnect<Request, Response>
where
    Request: Message,
    Response: Message,
{
    inner: tokio::task::JoinHandle<crate::Result<RpcClient<Request, Response>>>,
}
impl<Request, Response> Future for SpawnedConnect<Request, Response>
where
    Request: Message,
    Response: Message,
{
    type Output = crate::Result<RpcClient<Request, Response>>;

    fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
        match pin!(&mut self.inner).poll(context) {
            Poll::Ready(Ok(client_result)) => Poll::Ready(client_result),
            Poll::Ready(Err(_join_err)) => Poll::Ready(Err(crate::Error::ConnectionIsClosed)),
            Poll::Pending => Poll::Pending,
        }
    }
}

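/// The lifecycle of one pool slot: `Connecting` holds a shared handle to an in-flight
/// connect attempt so concurrent callers await the same task, `Connected` holds the
/// multiplexed client that callers clone, and `Disconnected` means the slot will be
/// (re)connected on the next request.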
#[derive(Debug)]
enum ConnectionState<Request, Response>
where
    Request: Message,
    Response: Message,
{
    Connecting(futures::future::Shared<SpawnedConnect<Request, Response>>),
    Connected(RpcClient<Request, Response>),
    Disconnected,
}