protosocket_rpc/client/rpc_client.rs

use std::sync::{atomic::AtomicBool, Arc};

use tokio::sync::{mpsc, oneshot};

use super::reactor::completion_reactor::{DoNothingMessageHandler, RpcCompletionReactor};
use super::reactor::completion_registry::{Completion, CompletionGuard, RpcRegistrar};
use super::reactor::{
    completion_streaming::StreamingCompletion, completion_unary::UnaryCompletion,
};
use crate::Message;

/// A client for sending RPCs to a protosocket rpc server.
///
/// It handles sending messages to the server and associating responses with their requests.
/// Messages are sent and received in any order, asynchronously, and support cancellation.
/// To cancel an RPC, drop the response future.
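///
/// A rough usage sketch (hypothetical: assumes a connected `client` and a
/// `request` message, and that the returned completion is awaited as a future):
///
/// ```ignore
/// // Sending resolves once the request has been queued to the connection.
/// let completion = client.send_unary(request).await?;
/// // Await the completion to receive the server's response.
/// let response = completion.await?;
/// // Dropping `completion` without awaiting it would cancel the RPC instead.
/// ```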
#[derive(Debug)]
pub struct RpcClient<Request, Response>
where
    Request: Message,
    Response: Message,
{
    #[allow(clippy::type_complexity)]
    in_flight_submission: RpcRegistrar<Response>,
    submission_queue: tokio::sync::mpsc::Sender<Request>,
    is_alive: Arc<AtomicBool>,
}

impl<Request, Response> Clone for RpcClient<Request, Response>
where
    Request: Message,
    Response: Message,
{
    fn clone(&self) -> Self {
        Self {
            in_flight_submission: self.in_flight_submission.clone(),
            submission_queue: self.submission_queue.clone(),
            is_alive: self.is_alive.clone(),
        }
    }
}

impl<Request, Response> RpcClient<Request, Response>
where
    Request: Message,
    Response: Message,
{
    pub(crate) fn new(
        submission_queue: mpsc::Sender<Request>,
        message_reactor: &RpcCompletionReactor<Response, DoNothingMessageHandler<Response>>,
    ) -> Self {
        Self {
            submission_queue,
            in_flight_submission: message_reactor.in_flight_submission_handle(),
            is_alive: message_reactor.alive_handle(),
        }
    }

    /// Returns whether the connection is believed to be alive.
    ///
    /// Checking this before using the client does not guarantee that the client is still alive
    /// when you send your message. It may be useful for connection pool implementations - for
    /// example, [bb8::ManageConnection](https://github.com/djc/bb8/blob/09a043c001b3c15514d9f03991cfc87f7118a000/bb8/src/api.rs#L383-L384)'s
    /// `is_valid` and `has_broken` could be bound to this function to help the pool cycle out
    /// broken connections.
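    ///
    /// A hedged sketch of that binding (the connector and error types here are
    /// illustrative, not part of this crate, and the pool's `connect` method is elided):
    ///
    /// ```ignore
    /// impl bb8::ManageConnection for MyConnector {
    ///     type Connection = RpcClient<MyRequest, MyResponse>;
    ///     type Error = MyError;
    ///
    ///     async fn is_valid(&self, client: &mut Self::Connection) -> Result<(), Self::Error> {
    ///         // Best effort: the connection can still drop after this returns Ok.
    ///         if client.is_alive() { Ok(()) } else { Err(MyError::ConnectionClosed) }
    ///     }
    ///
    ///     fn has_broken(&self, client: &mut Self::Connection) -> bool {
    ///         !client.is_alive()
    ///     }
    /// }
    /// ```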
    pub fn is_alive(&self) -> bool {
        self.is_alive.load(std::sync::atomic::Ordering::Relaxed)
    }

    /// Send a server-streaming rpc to the server.
    ///
    /// This function only sends the request. You must consume the completion stream to get the responses.
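    ///
    /// A consumption sketch (assuming the completion can be drained as a
    /// `futures::Stream` via `StreamExt`):
    ///
    /// ```ignore
    /// use futures::StreamExt;
    ///
    /// let mut responses = client.send_streaming(request).await?;
    /// while let Some(response) = responses.next().await {
    ///     // handle each streamed response as it arrives
    /// }
    /// ```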
    #[must_use = "You must consume the completion stream to get the responses. If you drop the completion, the request will be cancelled."]
    pub async fn send_streaming(
        &self,
        request: Request,
    ) -> crate::Result<StreamingCompletion<Response, Request>> {
        let (sender, completion) = mpsc::unbounded_channel();
        let completion_guard = self
            .send_message(Completion::RemoteStreaming(sender), request)
            .await?;

        let completion = StreamingCompletion::new(completion, completion_guard);

        Ok(completion)
    }

    /// Send a unary rpc to the server.
    ///
    /// This function only sends the request. You must await the completion to get the response.
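    ///
    /// A usage sketch (assuming the completion is awaited as a future yielding
    /// the response):
    ///
    /// ```ignore
    /// let completion = client.send_unary(request).await?;
    /// let response = completion.await?;
    /// ```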
    #[must_use = "You must await the completion to get the response. If you drop the completion, the request will be cancelled."]
    pub async fn send_unary(
        &self,
        request: Request,
    ) -> crate::Result<UnaryCompletion<Response, Request>> {
        let (completor, completion) = oneshot::channel();
        let completion_guard = self
            .send_message(Completion::Unary(completor), request)
            .await?;

        let completion = UnaryCompletion::new(completion, completion_guard);

        Ok(completion)
    }

    async fn send_message(
        &self,
        completion: Completion<Response>,
        request: Request,
    ) -> crate::Result<CompletionGuard<Response, Request>> {
        if !self.is_alive.load(std::sync::atomic::Ordering::Relaxed) {
            // early-out if the connection is already known to be closed
            return Err(crate::Error::ConnectionIsClosed);
        }
        let completion_guard = self.in_flight_submission.register_completion(
            request.message_id(),
            completion,
            self.submission_queue.clone(),
        );
        self.submission_queue
            .send(request)
            .await
            .map_err(|_e| crate::Error::ConnectionIsClosed)
            .map(|_| completion_guard)
    }
}

#[cfg(test)]
mod test {
    use std::future::Future;
    use std::pin::pin;
    use std::sync::atomic::AtomicBool;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::task::Context;
    use std::task::Poll;

    use futures::task::noop_waker_ref;
    use tokio::sync::mpsc;

    use crate::client::connection_pool::ClientConnector;
    use crate::client::connection_pool::ConnectionPool;
    use crate::client::reactor::completion_reactor::DoNothingMessageHandler;
    use crate::client::reactor::completion_reactor::RpcCompletionReactor;
    use crate::Message;

    use super::RpcClient;

    // Test messages are plain u64s: the low 32 bits carry the message id and the
    // bits from bit 32 up carry the control code (0 = normal, 1 = cancel, 2 = end).
    impl Message for u64 {
        fn message_id(&self) -> u64 {
            *self & 0xffffffff
        }

        fn control_code(&self) -> crate::ProtosocketControlCode {
            match *self >> 32 {
                0 => crate::ProtosocketControlCode::Normal,
                1 => crate::ProtosocketControlCode::Cancel,
                2 => crate::ProtosocketControlCode::End,
                _ => unreachable!("invalid control code"),
            }
        }

        fn set_message_id(&mut self, message_id: u64) {
            // keep the control code bits, replace the message id bits
            *self = (*self & 0xf_0000_0000) | message_id;
        }

        fn cancelled(message_id: u64) -> Self {
            (1_u64 << 32) | message_id
        }

        fn ended(message_id: u64) -> Self {
            (2_u64 << 32) | message_id
        }
    }

    // Poll a future to completion on the current thread with a no-op waker.
    // This busy-polls, which is fine in these tests because the polled futures
    // become ready without needing a real wakeup.
    fn drive_future<F: Future>(f: F) -> F::Output {
        let mut f = pin!(f);
        loop {
            let next = f.as_mut().poll(&mut Context::from_waker(noop_waker_ref()));
            if let Poll::Ready(result) = next {
                break result;
            }
        }
    }

    #[allow(clippy::type_complexity)]
    fn get_client() -> (
        tokio::sync::mpsc::Receiver<u64>,
        RpcClient<u64, u64>,
        RpcCompletionReactor<u64, DoNothingMessageHandler<u64>>,
    ) {
        let (sender, remote_end) = tokio::sync::mpsc::channel::<u64>(10);
        let rpc_reactor = RpcCompletionReactor::<u64, _>::new(DoNothingMessageHandler::default());
        let client = RpcClient::new(sender, &rpc_reactor);
        (remote_end, client, rpc_reactor)
    }

    #[test]
    fn unary_drop_cancel() {
        let (mut remote_end, client, _reactor) = get_client();

        let response = drive_future(client.send_unary(4)).expect("can send");
        assert_eq!(4, remote_end.blocking_recv().expect("a request is sent"));
        assert!(remote_end.is_empty(), "no more messages yet");

        drop(response);

        assert_eq!(
            (1 << 32) + 4,
            remote_end.blocking_recv().expect("a cancel is sent")
        );
    }

    #[test]
    fn streaming_drop_cancel() {
        let (mut remote_end, client, _reactor) = get_client();

        let response = drive_future(client.send_streaming(4)).expect("can send");
        assert_eq!(4, remote_end.blocking_recv().expect("a request is sent"));
        assert!(remote_end.is_empty(), "no more messages yet");

        drop(response);

        assert_eq!(
            (1 << 32) + 4,
            remote_end.blocking_recv().expect("a cancel is sent")
        );
    }

    #[allow(clippy::type_complexity)]
    #[derive(Default)]
    struct TestConnector {
        clients: Mutex<
            Vec<(
                mpsc::Receiver<u64>,
                RpcClient<u64, u64>,
                RpcCompletionReactor<u64, DoNothingMessageHandler<u64>>,
            )>,
        >,
        fail_connections: AtomicBool,
    }
    impl ClientConnector for Arc<TestConnector> {
        type Request = u64;
        type Response = u64;

        async fn connect(self) -> crate::Result<RpcClient<Self::Request, Self::Response>> {
            if self
                .fail_connections
                .load(std::sync::atomic::Ordering::Relaxed)
            {
                return Err(crate::Error::IoFailure(Arc::new(std::io::Error::other(
                    "simulated connection failure",
                ))));
            }
            // normally I'd just call `protosocket_rpc::client::connect` in here
            let (remote_end, client, reactor) = get_client();
            self.clients
                .lock()
                .expect("mutex works")
                .push((remote_end, client.clone(), reactor));

            Ok(client)
        }
    }

    // have to use tokio::test for the connection pool tests because the pool uses tokio::spawn
    #[tokio::test]
    async fn connection_pool() {
        let connector = Arc::new(TestConnector::default());
        let pool = ConnectionPool::new(connector.clone(), 1);

        let rpc_client_a = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool");
        assert_eq!(
            1,
            connector.clients.lock().expect("mutex works").len(),
            "one connection created"
        );

        let rpc_client_b = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool");
        assert_eq!(
            1,
            connector.clients.lock().expect("mutex works").len(),
            "still one connection created"
        );

        assert!(
            Arc::ptr_eq(&rpc_client_a.is_alive, &rpc_client_b.is_alive),
            "same connection shared"
        );

        let _reply_a = rpc_client_a.send_unary(42).await.expect("can send");
        let _reply_b = rpc_client_b.send_unary(43).await.expect("can send");

        let (mut remote_end, _client, _reactor) = {
            let mut clients = connector.clients.lock().expect("mutex works");
            clients.pop().expect("one client exists")
        };
        assert_eq!(42, remote_end.recv().await.expect("request a is received"));
        assert_eq!(43, remote_end.recv().await.expect("request b is received"));
    }

    #[tokio::test]
    async fn connection_pool_reconnect() {
        let connector = Arc::new(TestConnector::default());
        let pool = ConnectionPool::new(connector.clone(), 1);

        let rpc_client_a = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool");
        assert_eq!(
            1,
            connector.clients.lock().expect("mutex works").len(),
            "one connection created"
        );

        rpc_client_a
            .is_alive
            .store(false, std::sync::atomic::Ordering::Relaxed);

        let rpc_client_b = pool.get_connection().await.expect("can get a connection from the pool even when the previous connection is dead, as long as the connection attempt succeeds");
        assert_eq!(
            2,
            connector.clients.lock().expect("mutex works").len(),
            "a new connection was created, so the connector was asked to make a new connection"
        );
        // Note that the connection pool holds a plain Vec of individual clients, so it cannot create more connections than it started with.

        assert!(
            !Arc::ptr_eq(&rpc_client_a.is_alive, &rpc_client_b.is_alive),
            "new connection created"
        );
    }

    #[tokio::test]
    async fn connection_pool_failure() {
        let connector = Arc::new(TestConnector::default());
        let pool = ConnectionPool::new(connector.clone(), 1);
        connector
            .fail_connections
            .store(true, std::sync::atomic::Ordering::Relaxed);

        pool.get_connection().await.expect_err("connection attempt fails, and the calling code gets the error. It does not try forever without surfacing errors.");
    }

    #[tokio::test]
    async fn connection_pool_reconnect_failure_recovery() {
        let connector = Arc::new(TestConnector::default());
        let pool = ConnectionPool::new(connector.clone(), 1);
        let rpc_client_a = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool");

        rpc_client_a
            .is_alive
            .store(false, std::sync::atomic::Ordering::Relaxed);
        connector
            .fail_connections
            .store(true, std::sync::atomic::Ordering::Relaxed);

        pool.get_connection().await.expect_err("connection attempt fails, and the calling code gets the error. It does not try forever without surfacing errors.");

        connector
            .fail_connections
            .store(false, std::sync::atomic::Ordering::Relaxed);

        let rpc_client_b = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool now");
        assert!(
            !Arc::ptr_eq(&rpc_client_a.is_alive, &rpc_client_b.is_alive),
            "new connection created"
        );
    }
}