Skip to main content

protosocket_rpc/client/
rpc_client.rs

1use std::sync::{atomic::AtomicBool, Arc};
2
3use tokio::sync::oneshot;
4
5use super::reactor::completion_reactor::{DoNothingMessageHandler, RpcCompletionReactor};
6use super::reactor::completion_registry::{Completion, CompletionGuard};
7use super::reactor::{
8    completion_streaming::StreamingCompletion, completion_unary::UnaryCompletion,
9};
10use crate::client::reactor::completion_reactor::{CompletableRpc, RpcNotification};
11use crate::Message;
12
/// A client for sending RPCs to a protosockets rpc server.
///
/// It handles sending messages to the server and associating the responses.
/// Messages are sent and received in any order, asynchronously, and support cancellation.
/// To cancel an RPC, drop the response future.
#[derive(Debug)]
pub struct RpcClient<Request, Response>
where
    Request: Message,
    Response: Message,
{
    // Outbound lane: every new RPC and every cancellation is pushed here as an
    // RpcNotification for the connection's reactor to pick up.
    submission_queue: spillway::Sender<RpcNotification<Response, Request>>,
    // Shared liveness flag, owned by the RpcCompletionReactor this client was
    // created from (see `new`); cleared by the reactor when the connection dies.
    is_alive: Arc<AtomicBool>,
}
27
28impl<Request, Response> Clone for RpcClient<Request, Response>
29where
30    Request: Message,
31    Response: Message,
32{
33    fn clone(&self) -> Self {
34        Self {
35            submission_queue: self.submission_queue.clone(),
36            is_alive: self.is_alive.clone(),
37        }
38    }
39}
40
41impl<Request, Response> RpcClient<Request, Response>
42where
43    Request: Message,
44    Response: Message,
45{
46    pub(crate) fn new(
47        submission_queue: spillway::Sender<RpcNotification<Response, Request>>,
48        message_reactor: &RpcCompletionReactor<
49            Response,
50            Request,
51            DoNothingMessageHandler<Response>,
52        >,
53    ) -> Self {
54        Self {
55            submission_queue,
56            is_alive: message_reactor.alive_handle(),
57        }
58    }
59
60    /// Checking this before using the client does not guarantee that the client is still alive when you send  
61    /// your message. It may be useful for connection pool implementations - for example, [bb8::ManageConnection](https://github.com/djc/bb8/blob/09a043c001b3c15514d9f03991cfc87f7118a000/bb8/src/api.rs#L383-L384)'s  
62    /// is_valid and has_broken could be bound to this function to help the pool cycle out broken connections.  
63    pub fn is_alive(&self) -> bool {
64        self.is_alive.load(std::sync::atomic::Ordering::Relaxed)
65    }
66
67    /// Send a server-streaming rpc to the server.
68    ///
69    /// This function only sends the request. You must consume the completion stream to get the response.
70    #[must_use = "You must await the completion to get the response. If you drop the completion, the request will be cancelled."]
71    pub fn send_streaming(
72        &self,
73        request: Request,
74    ) -> crate::Result<StreamingCompletion<Response, Request>> {
75        let (sender, completion) = spillway::channel_with_concurrency(1);
76        let completion_guard = self.send_message(Completion::RemoteStreaming(sender), request)?;
77
78        let completion = StreamingCompletion::new(completion, completion_guard);
79
80        Ok(completion)
81    }
82
83    /// Send a unary rpc to the server.
84    ///
85    /// This function only sends the request. You must await the completion to get the response.
86    #[must_use = "You must await the completion to get the response. If you drop the completion, the request will be cancelled."]
87    pub fn send_unary(
88        &self,
89        request: Request,
90    ) -> crate::Result<UnaryCompletion<Response, Request>> {
91        let (completor, completion) = oneshot::channel();
92        let completion_guard = self.send_message(Completion::Unary(completor), request)?;
93
94        let completion = UnaryCompletion::new(completion, completion_guard);
95
96        Ok(completion)
97    }
98
99    fn send_message(
100        &self,
101        completion: Completion<Response>,
102        request: Request,
103    ) -> crate::Result<CompletionGuard<Response, Request>> {
104        if !self.is_alive.load(std::sync::atomic::Ordering::Relaxed) {
105            // early-out if the connection is closed
106            return Err(crate::Error::ConnectionIsClosed);
107        }
108        let message_id = request.message_id();
109        let completion_guard = CompletionGuard::new(message_id, self.submission_queue.clone());
110
111        self.submission_queue
112            .send(RpcNotification::New(CompletableRpc {
113                message_id,
114                completion,
115                request,
116            }))
117            .map_err(|_e| crate::Error::ConnectionIsClosed)
118            .map(|_| completion_guard)
119    }
120}
121
#[cfg(test)]
mod test {
    use std::future::Future;
    use std::pin::pin;
    use std::sync::atomic::AtomicBool;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::task::Context;
    use std::task::Poll;

    use futures::task::noop_waker_ref;

    use crate::client::connection_pool::ClientConnector;
    use crate::client::connection_pool::ConnectionPool;
    use crate::client::reactor::completion_reactor::CompletableRpc;
    use crate::client::reactor::completion_reactor::DoNothingMessageHandler;
    use crate::client::reactor::completion_reactor::RpcCompletionReactor;
    use crate::client::reactor::completion_reactor::RpcNotification;
    use crate::Message;

    use super::RpcClient;

    // Test-only Message impl: a u64 packs the message id into the low 32 bits
    // and the control code into the bits above them (0 = Normal, 1 = Cancel,
    // 2 = End).
    impl Message for u64 {
        fn message_id(&self) -> u64 {
            *self & 0xffffffff
        }

        fn control_code(&self) -> crate::ProtosocketControlCode {
            // Only 0..=2 are ever produced by this impl, so any other value
            // indicates a broken test fixture.
            match *self >> 32 {
                0 => crate::ProtosocketControlCode::Normal,
                1 => crate::ProtosocketControlCode::Cancel,
                2 => crate::ProtosocketControlCode::End,
                _ => unreachable!("invalid control code"),
            }
        }

        fn set_message_id(&mut self, message_id: u64) {
            // Preserve the control-code bits, replace the id bits.
            *self = (*self & 0xf00000000) | message_id;
        }

        fn cancelled(message_id: u64) -> Self {
            (1_u64 << 32) | message_id
        }

        fn ended(message_id: u64) -> Self {
            (2 << 32) | message_id
        }
    }

    /// Poll `f` to completion on the current thread using a no-op waker.
    ///
    /// This busy-polls, so it only terminates for futures that become ready
    /// without needing a real wakeup — true for the in-process channels used
    /// in these tests (presumably; spillway's receiver semantics are defined
    /// elsewhere — TODO confirm).
    fn drive_future<F: Future>(f: F) -> F::Output {
        let mut f = pin!(f);
        loop {
            let next = f.as_mut().poll(&mut Context::from_waker(noop_waker_ref()));
            if let Poll::Ready(result) = next {
                break result;
            }
        }
    }

    /// Build a client wired to an in-process queue, returning the remote end
    /// (where the reactor would read), the client, and the reactor whose
    /// alive flag the client shares.
    #[allow(clippy::type_complexity)]
    fn get_client() -> (
        spillway::Receiver<RpcNotification<u64, u64>>,
        RpcClient<u64, u64>,
        RpcCompletionReactor<u64, u64, DoNothingMessageHandler<u64>>,
    ) {
        let (sender, remote_end) = spillway::channel();
        let rpc_reactor =
            RpcCompletionReactor::<u64, u64, _>::new(DoNothingMessageHandler::default());
        let client = RpcClient::new(sender, &rpc_reactor);
        (remote_end, client, rpc_reactor)
    }

    // Dropping an unawaited unary completion must emit a Cancel notification
    // for the same message id.
    #[test]
    fn unary_drop_cancel() {
        let (mut remote_end, client, _reactor) = get_client();

        let response = client.send_unary(4).expect("can send");

        let notification = drive_future(remote_end.next()).expect("a request is sent");
        assert!(matches!(
            notification,
            RpcNotification::New(CompletableRpc { message_id: 4, .. })
        ));

        drop(response);

        let cancellation = drive_future(remote_end.next()).expect("a cancel is sent");
        assert!(matches!(cancellation, RpcNotification::Cancel(4)));
    }

    // Same cancellation contract as unary_drop_cancel, for streaming rpcs.
    #[test]
    fn streaming_drop_cancel() {
        let (mut remote_end, client, _reactor) = get_client();

        let response = client.send_streaming(4).expect("can send");

        let notification = drive_future(remote_end.next()).expect("a request is sent");
        assert!(matches!(
            notification,
            RpcNotification::New(CompletableRpc { message_id: 4, .. })
        ));

        drop(response);

        let cancellation = drive_future(remote_end.next()).expect("a cancel is sent");
        assert!(matches!(cancellation, RpcNotification::Cancel(4)));
    }

    // Connector stub for pool tests: records every connection it creates and
    // can be switched to fail connection attempts.
    #[allow(clippy::type_complexity)]
    #[derive(Default)]
    struct TestConnector {
        // Keeps each created connection's remote end, client, and reactor
        // alive so tests can inspect traffic after the fact.
        clients: Mutex<
            Vec<(
                spillway::Receiver<RpcNotification<u64, u64>>,
                RpcClient<u64, u64>,
                RpcCompletionReactor<u64, u64, DoNothingMessageHandler<u64>>,
            )>,
        >,
        // When true, connect() returns a simulated I/O failure.
        fail_connections: AtomicBool,
    }
    impl ClientConnector for Arc<TestConnector> {
        type Request = u64;
        type Response = u64;

        async fn connect(self) -> crate::Result<RpcClient<Self::Request, Self::Response>> {
            if self
                .fail_connections
                .load(std::sync::atomic::Ordering::Relaxed)
            {
                return Err(crate::Error::IoFailure(Arc::new(std::io::Error::other(
                    "simulated connection failure",
                ))));
            }
            // normally I'd just call `protosocket_rpc::client::connect` in here
            let (remote_end, client, reactor) = get_client();
            self.clients
                .lock()
                .expect("mutex works")
                .push((remote_end, client.clone(), reactor));

            Ok(client)
        }
    }

    // have to use tokio::test for the connection pool because it uses tokio::spawn
    // A size-1 pool hands out the same underlying connection to every caller
    // while that connection stays alive.
    #[tokio::test]
    async fn connection_pool() {
        let connector = Arc::new(TestConnector::default());
        let pool = ConnectionPool::new(connector.clone(), 1);

        let rpc_client_a = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool");
        assert_eq!(
            1,
            connector.clients.lock().expect("mutex works").len(),
            "one connection created"
        );

        let rpc_client_b = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool");
        assert_eq!(
            1,
            connector.clients.lock().expect("mutex works").len(),
            "still one connection created"
        );

        // Sharing the same Arc<AtomicBool> proves both handles wrap the same
        // underlying connection.
        assert!(
            Arc::ptr_eq(&rpc_client_a.is_alive, &rpc_client_b.is_alive),
            "same connection shared"
        );

        let _reply_a = rpc_client_a.send_unary(42).expect("can send");
        let _reply_b = rpc_client_b.send_unary(43).expect("can send");

        let (mut remote_end, _client, _reactor) = {
            let mut clients = connector.clients.lock().expect("mutex works");
            clients.pop().expect("one client exists")
        };
        assert!(matches!(
            remote_end.next().await.expect("request a is received"),
            RpcNotification::New(CompletableRpc { message_id: 42, .. })
        ));
        assert!(matches!(
            remote_end.next().await.expect("request b is received"),
            RpcNotification::New(CompletableRpc { message_id: 43, .. })
        ));
    }

    // When a pooled connection dies, the pool asks the connector for a fresh one.
    #[tokio::test]
    async fn connection_pool_reconnect() {
        let connector = Arc::new(TestConnector::default());
        let pool = ConnectionPool::new(connector.clone(), 1);

        let rpc_client_a = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool");
        assert_eq!(
            1,
            connector.clients.lock().expect("mutex works").len(),
            "one connection created"
        );

        // Simulate connection death by clearing the shared liveness flag directly.
        rpc_client_a
            .is_alive
            .store(false, std::sync::atomic::Ordering::Relaxed);

        let rpc_client_b = pool.get_connection().await.expect("can get a connection from the pool even when the previous connection is dead, as long as the connection attempt succeeds");
        assert_eq!(
            2,
            connector.clients.lock().expect("mutex works").len(),
            "a new connection was created, so the connector was asked to make a new connection"
        );
        // Note that the connection pool holds a plain Vec of individual clients, so it cannot create more connections than it started with.

        assert!(
            !Arc::ptr_eq(&rpc_client_a.is_alive, &rpc_client_b.is_alive),
            "new connection created"
        );
    }

    // A failing connect must surface the error to the caller instead of
    // retrying silently forever.
    #[tokio::test]
    async fn connection_pool_failure() {
        let connector = Arc::new(TestConnector::default());
        let pool = ConnectionPool::new(connector.clone(), 1);
        connector
            .fail_connections
            .store(true, std::sync::atomic::Ordering::Relaxed);

        pool.get_connection().await.expect_err("connection attempt fails, and the calling code gets the error. It does not try forever without surfacing errors.");
    }

    // After a failed reconnect attempt, the pool recovers and reconnects once
    // the connector starts succeeding again.
    #[tokio::test]
    async fn connection_pool_reconnect_failure_recovery() {
        let connector = Arc::new(TestConnector::default());
        let pool = ConnectionPool::new(connector.clone(), 1);
        let rpc_client_a = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool");

        // Kill the live connection and make the next connect attempt fail.
        rpc_client_a
            .is_alive
            .store(false, std::sync::atomic::Ordering::Relaxed);
        connector
            .fail_connections
            .store(true, std::sync::atomic::Ordering::Relaxed);

        pool.get_connection().await.expect_err("connection attempt fails, and the calling code gets the error. It does not try forever without surfacing errors.");

        connector
            .fail_connections
            .store(false, std::sync::atomic::Ordering::Relaxed);

        let rpc_client_b = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool now");
        assert!(
            !Arc::ptr_eq(&rpc_client_a.is_alive, &rpc_client_b.is_alive),
            "new connection created"
        );
    }
}
389}