protosocket_rpc/client/connection_pool.rs

use std::{
    cell::RefCell,
    future::Future,
    pin::{pin, Pin},
    sync::Mutex,
    task::{Context, Poll},
};

use futures::FutureExt;
use rand::{Rng, SeedableRng};

use crate::{client::RpcClient, Message};

/// A connection strategy for protosocket rpc clients.
///
/// This is called asynchronously by the connection pool to create new connections.
pub trait ClientConnector: Clone {
    type Request: Message;
    type Response: Message;

    /// Connect to the server and return a new RpcClient. See [`crate::client::connect`] for
    /// the typical way to connect. This is called by the ConnectionPool when it needs a new
    /// connection.
    ///
    /// Your returned future needs to be `'static`, and your connector needs to be cheap to
    /// clone. One easy way to do that is to impl ClientConnector on `Arc<YourConnectorType>`
    /// instead of directly on `YourConnectorType`.
    ///
    /// If you have rolling credentials, initialization messages, changing endpoints, or other
    /// adaptive connection logic, this is the place to do it or consult those sources of truth.
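    ///
    /// A minimal sketch of an implementation, assuming hypothetical `MyRequest` and
    /// `MyResponse` message types and a hypothetical address field (neither is part
    /// of this crate):
    ///
    /// ```ignore
    /// use std::future::Future;
    /// use std::sync::Arc;
    ///
    /// use protosocket_rpc::client::{ClientConnector, RpcClient};
    ///
    /// #[derive(Clone)]
    /// struct MyConnector {
    ///     address: Arc<str>, // hypothetical: kept cheap to clone
    /// }
    ///
    /// impl ClientConnector for MyConnector {
    ///     type Request = MyRequest; // hypothetical Message impls
    ///     type Response = MyResponse;
    ///
    ///     fn connect(
    ///         self,
    ///     ) -> impl Future<Output = protosocket_rpc::Result<RpcClient<MyRequest, MyResponse>>>
    ///            + Send
    ///            + 'static {
    ///         async move {
    ///             // Consult credentials or service discovery here if you need to, then
    ///             // establish the connection - see protosocket_rpc::client::connect.
    ///             todo!("connect to self.address and return the RpcClient")
    ///         }
    ///     }
    /// }
    /// ```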
    fn connect(
        self,
    ) -> impl Future<Output = crate::Result<RpcClient<Self::Request, Self::Response>>> + Send + 'static;
}

/// A connection pool for protosocket rpc clients.
///
/// Protosocket-rpc connections are shared and multiplexed, so this vends cloned handles.
/// You can hold onto a handle from the pool for as long as you want. There is a small
/// synchronization cost to getting a handle from a pool, so caching handles is a good
/// idea - but if you load balance heavily, you can make a pool with as many "slots" as
/// you want to dilute contention on the connection state locks. Each lock is typically
/// held only for as long as it takes to clone an `Arc`, so the synchronization cost is
/// nanosecond-scale per connection; with several connections you'll rarely contend.
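///
/// A usage sketch, assuming a hypothetical `my_connector` value that implements
/// [`ClientConnector`]:
///
/// ```ignore
/// // Four slots: get_connection picks one at random, diluting lock contention.
/// let pool = ConnectionPool::new(my_connector, 4);
/// let client = pool.get_connection().await?;
/// // `client` is a cheap cloned handle to a shared, multiplexed connection;
/// // it can be cached and used for as long as you like.
/// ```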
#[derive(Debug)]
pub struct ConnectionPool<Connector: ClientConnector> {
    connector: Connector,
    connections: Vec<Mutex<ConnectionState<Connector::Request, Connector::Response>>>,
}

impl<Connector: ClientConnector> ConnectionPool<Connector> {
    /// Create a new connection pool.
    ///
    /// It lazily establishes up to `connection_count` healthy connections, reconnecting
    /// slots on demand when they are found dead.
    pub fn new(connector: Connector, connection_count: usize) -> Self {
        Self {
            connector,
            connections: (0..connection_count)
                .map(|_| Mutex::new(ConnectionState::Disconnected))
                .collect(),
        }
    }

    /// Get a connection from the pool.
    ///
    /// This picks a random slot, then returns that slot's live connection, awaits its
    /// in-flight connection attempt, or starts a new connection attempt as appropriate.
    pub async fn get_connection(
        &self,
    ) -> crate::Result<RpcClient<Connector::Request, Connector::Response>> {
        thread_local! {
            static THREAD_LOCAL_SMALL_RANDOM: RefCell<rand::rngs::SmallRng> = RefCell::new(rand::rngs::SmallRng::from_os_rng());
        }

        // The thread-local rng is only borrowed here, briefly and non-reentrantly, so
        // this borrow cannot conflict with any other.
        let slot = THREAD_LOCAL_SMALL_RANDOM
            .with_borrow_mut(|rng| rng.random_range(0..self.connections.len()));
        let connection_state = &self.connections[slot];

        // The connection state requires a mutex, so I need to keep `await` out of the
        // lock's scope to satisfy clippy (and for paranoia).
        let connecting_handle = loop {
            let mut state = connection_state.lock().expect("internal mutex must work");
            break match &mut *state {
                ConnectionState::Connected(shared_connection) => {
                    if shared_connection.is_alive() {
                        return Ok(shared_connection.clone());
                    } else {
                        *state = ConnectionState::Disconnected;
                        continue;
                    }
                }
                ConnectionState::Connecting(join_handle) => join_handle.clone(),
                ConnectionState::Disconnected => {
                    let connector = self.connector.clone();
                    let load = SpawnedConnect {
                        inner: tokio::task::spawn(connector.connect()),
                    }
                    .shared();
                    *state = ConnectionState::Connecting(load.clone());
                    continue;
                }
            };
        };

        match connecting_handle.await {
            Ok(client) => Ok(reconcile_client_slot(connection_state, client)),
            Err(connect_error) => {
                let mut state = connection_state.lock().expect("internal mutex must work");
                *state = ConnectionState::Disconnected;
                Err(connect_error)
            }
        }
    }
}

/// Store a freshly connected client into the slot, unless a live connection raced in
/// ahead of us, in which case we adopt that one instead.
fn reconcile_client_slot<Request, Response>(
    connection_state: &Mutex<ConnectionState<Request, Response>>,
    client: RpcClient<Request, Response>,
) -> RpcClient<Request, Response>
where
    Request: Message,
    Response: Message,
{
    let mut state = connection_state.lock().expect("internal mutex must work");
    match &mut *state {
        ConnectionState::Connecting(_shared) => {
            // Here we drop the shared handle. If there is another task still waiting on it,
            // they will get notified when the spawned connection task completes. When they
            // come to reconcile with the connection slot, they will favor this connection
            // and drop their own.
            *state = ConnectionState::Connected(client.clone());
            client
        }
        ConnectionState::Connected(rpc_client) => {
            if rpc_client.is_alive() {
                // someone else beat us to it
                rpc_client.clone()
            } else {
                // well, this one is broken too, so we should just replace it with our new one
                *state = ConnectionState::Connected(client.clone());
                client
            }
        }
        ConnectionState::Disconnected => {
            // we raced with a disconnect, but we have a new client, so use it
            *state = ConnectionState::Connected(client.clone());
            client
        }
    }
}

/// A future for the spawned connect task. It flattens the task's `JoinError` (panic or
/// cancellation, surfaced as `ConnectionIsClosed`) into `crate::Result`, so the output
/// is cloneable and can be `Shared` among all the tasks waiting on the same slot.
struct SpawnedConnect<Request, Response>
where
    Request: Message,
    Response: Message,
{
    inner: tokio::task::JoinHandle<crate::Result<RpcClient<Request, Response>>>,
}
impl<Request, Response> Future for SpawnedConnect<Request, Response>
where
    Request: Message,
    Response: Message,
{
    type Output = crate::Result<RpcClient<Request, Response>>;

    fn poll(mut self: Pin<&mut Self>, context: &mut Context<'_>) -> Poll<Self::Output> {
        match pin!(&mut self.inner).poll(context) {
            Poll::Ready(Ok(client_result)) => Poll::Ready(client_result),
            Poll::Ready(Err(_join_err)) => Poll::Ready(Err(crate::Error::ConnectionIsClosed)),
            Poll::Pending => Poll::Pending,
        }
    }
}

#[derive(Debug)]
enum ConnectionState<Request, Response>
where
    Request: Message,
    Response: Message,
{
    Connecting(futures::future::Shared<SpawnedConnect<Request, Response>>),
    Connected(RpcClient<Request, Response>),
    Disconnected,
}