// protosocket_rpc/client/connection_pool.rs
1use std::{
2    cell::RefCell,
3    future::Future,
4    pin::{pin, Pin},
5    sync::Mutex,
6    task::{Context, Poll},
7};
8
9use futures::FutureExt;
10use rand::{Rng, SeedableRng};
11
12use crate::{client::RpcClient, Message};
13
/// A connection strategy for protosocket rpc clients.
///
/// This is called asynchronously by the connection pool to create new connections.
///
/// The `Clone` supertrait is required because the pool clones the connector once
/// per connection attempt it spawns.
pub trait ClientConnector: Clone {
    /// Request message type
    type Request: Message;
    /// Response message type
    type Response: Message;

    /// Connect to the server and return a new RpcClient. See [`crate::client::connect`] for
    /// the typical way to connect. This is called by the ConnectionPool when it needs a new
    /// connection.
    ///
    /// Your returned future needs to be `'static`, and your connector needs to be cheap to
    /// clone. One easy way to do that is to just impl ClientConnector on `Arc<YourConnectorType>`
    /// instead of directly on `YourConnectorType`.
    ///
    /// If you have rolling credentials, initialization messages, changing endpoints, or other
    /// adaptive connection logic, this is the place to do it or consult those sources of truth.
    fn connect(
        self,
    ) -> impl Future<Output = crate::Result<RpcClient<Self::Request, Self::Response>>> + Send + 'static;
}
37
/// A connection pool for protosocket rpc clients.
///
/// Protosocket-rpc connections are shared and multiplexed, so this vends cloned handles.
/// You can hold onto a handle from the pool for as long as you want. There is a small
/// synchronization cost to getting a handle from a pool, so caching is a good idea - but
/// if you want to load balance a lot, you can just make a pool with as many "slots" as you
/// want to dilute any contention on connection state locks. The locks are typically held for
/// the time it takes to clone an `Arc`, so it's usually nanosecond-scale synchronization,
/// per connection. So if you have several connections, you'll rarely contend.
#[derive(Debug)]
pub struct ConnectionPool<Connector: ClientConnector> {
    /// Cloned once per connection attempt to produce new clients.
    connector: Connector,
    /// One lock-guarded slot per desired connection. Slots start `Disconnected`
    /// and are connected lazily by the getters.
    connections: Vec<Mutex<ConnectionState<Connector::Request, Connector::Response>>>,
}
52
53impl<Connector: ClientConnector> ConnectionPool<Connector> {
54    /// Create a new connection pool.
55    ///
56    /// It will try to maintain `connection_count` healthy connections.
57    pub fn new(connector: Connector, connection_count: usize) -> Self {
58        Self {
59            connector,
60            connections: (0..connection_count)
61                .map(|_| Mutex::new(ConnectionState::Disconnected))
62                .collect(),
63        }
64    }
65
66    /// Get a consistent connection from the pool for a given key.
67    pub async fn get_connection_for_key(
68        &self,
69        key: usize,
70    ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
71        let slot = key % self.connections.len();
72
73        self.get_connection_by_slot(slot).await
74    }
75
76    /// Get a connection from the pool.
77    pub async fn get_connection(
78        &self,
79    ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
80        thread_local! {
81            static THREAD_LOCAL_SMALL_RANDOM: RefCell<rand::rngs::SmallRng> = RefCell::new(rand::rngs::SmallRng::from_os_rng());
82        }
83
84        // Safety: This is executed on a thread, in only one place. It cannot be borrowed anywhere else.
85        let slot = THREAD_LOCAL_SMALL_RANDOM
86            .with_borrow_mut(|rng| rng.random_range(0..self.connections.len()));
87
88        self.get_connection_by_slot(slot).await
89    }
90
91    async fn get_connection_by_slot(
92        &self,
93        slot: usize,
94    ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
95        let connection_state = &self.connections[slot];
96
97        // The connection state requires a mutex, so I need to keep await out of the scope to satisfy clippy (and for paranoia).
98        let connecting_handle = loop {
99            let mut state = connection_state.lock().expect("internal mutex must work");
100            break match &mut *state {
101                ConnectionState::Connected(shared_connection) => {
102                    if shared_connection.is_alive() {
103                        return Ok(shared_connection.clone());
104                    } else {
105                        *state = ConnectionState::Disconnected;
106                        continue;
107                    }
108                }
109                ConnectionState::Connecting(join_handle) => join_handle.clone(),
110                ConnectionState::Disconnected => {
111                    let connector = self.connector.clone();
112                    let load = SpawnedConnect {
113                        inner: tokio::task::spawn(connector.connect()),
114                    }
115                    .shared();
116                    *state = ConnectionState::Connecting(load.clone());
117                    continue;
118                }
119            };
120        };
121
122        match connecting_handle.await {
123            Ok(client) => Ok(reconcile_client_slot(connection_state, client)),
124            Err(connect_error) => {
125                let mut state = connection_state.lock().expect("internal mutex must work");
126                *state = ConnectionState::Disconnected;
127                Err(connect_error)
128            }
129        }
130    }
131}
132
133fn reconcile_client_slot<Request, Response>(
134    connection_state: &Mutex<ConnectionState<Request, Response>>,
135    client: RpcClient<Request, Response>,
136) -> RpcClient<Request, Response>
137where
138    Request: Message,
139    Response: Message,
140{
141    let mut state = connection_state.lock().expect("internal mutex must work");
142    match &mut *state {
143        ConnectionState::Connecting(_shared) => {
144            // Here we drop the shared handle. If there is another task still waiting on it, they will get notified when
145            // the spawned connection task completes. When they come to reconcile with the connection slot, they will
146            // favor this connection and drop their own.
147            *state = ConnectionState::Connected(client.clone());
148            client
149        }
150        ConnectionState::Connected(rpc_client) => {
151            if rpc_client.is_alive() {
152                // someone else beat us to it
153                rpc_client.clone()
154            } else {
155                // well this one is broken too, so we should just replace it with our new one
156                *state = ConnectionState::Connected(client.clone());
157                client
158            }
159        }
160        ConnectionState::Disconnected => {
161            // we raced with a disconnect, but we have a new client, so use it
162            *state = ConnectionState::Connected(client.clone());
163            client
164        }
165    }
166}
167
/// A spawned connection attempt.
///
/// Wraps the `tokio::task::JoinHandle` of a spawned `ClientConnector::connect`
/// so the join layer can be flattened: the `Future` impl below converts a join
/// failure into a `crate::Error` instead of surfacing `JoinError`.
/// NOTE(review): presumably this adapter exists because `JoinError` is not
/// `Clone`, which `futures::future::Shared` requires of the output - confirm.
struct SpawnedConnect<Request, Response>
where
    Request: Message,
    Response: Message,
{
    inner: tokio::task::JoinHandle<crate::Result<RpcClient<Request, Response>>>,
}
175impl<Request, Response> Future for SpawnedConnect<Request, Response>
176where
177    Request: Message,
178    Response: Message,
179{
180    type Output = crate::Result<RpcClient<Request, Response>>;
181
182    fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
183        match pin!(&mut self.inner).poll(context) {
184            Poll::Ready(Ok(client_result)) => Poll::Ready(client_result),
185            Poll::Ready(Err(_join_err)) => Poll::Ready(Err(crate::Error::ConnectionIsClosed)),
186            Poll::Pending => Poll::Pending,
187        }
188    }
189}
190
/// Per-slot lifecycle state for a pooled connection.
#[derive(Debug)]
enum ConnectionState<Request, Response>
where
    Request: Message,
    Response: Message,
{
    /// A connect attempt is in flight; all waiters clone and await the same
    /// shared spawned future.
    Connecting(futures::future::Shared<SpawnedConnect<Request, Response>>),
    /// A connection was established; it may have died since (checked with
    /// `is_alive` by the getters).
    Connected(RpcClient<Request, Response>),
    /// No connection and no attempt in flight; the next getter starts one.
    Disconnected,
}