// protosocket_rpc/client/rpc_client.rs

use std::sync::{atomic::AtomicBool, Arc};

use tokio::sync::{mpsc, oneshot};

use super::reactor::completion_reactor::{DoNothingMessageHandler, RpcCompletionReactor};
use super::reactor::completion_registry::{Completion, CompletionGuard, RpcRegistrar};
use super::reactor::{
    completion_streaming::StreamingCompletion, completion_unary::UnaryCompletion,
};
use crate::Message;

12/// A client for sending RPCs to a protosockets rpc server.
13///
14/// It handles sending messages to the server and associating the responses.
15/// Messages are sent and received in any order, asynchronously, and support cancellation.
16/// To cancel an RPC, drop the response future.
17#[derive(Debug, Clone)]
18pub struct RpcClient<Request, Response>
19where
20    Request: Message,
21    Response: Message,
22{
23    #[allow(clippy::type_complexity)]
24    in_flight_submission: RpcRegistrar<Response>,
25    submission_queue: tokio::sync::mpsc::Sender<Request>,
26    is_alive: Arc<AtomicBool>,
27}
28
29impl<Request, Response> RpcClient<Request, Response>
30where
31    Request: Message,
32    Response: Message,
33{
34    pub(crate) fn new(
35        submission_queue: mpsc::Sender<Request>,
36        message_reactor: &RpcCompletionReactor<Response, DoNothingMessageHandler<Response>>,
37    ) -> Self {
38        Self {
39            submission_queue,
40            in_flight_submission: message_reactor.in_flight_submission_handle(),
41            is_alive: message_reactor.alive_handle(),
42        }
43    }
44
45    /// Checking this before using the client does not guarantee that the client is still alive when you send  
46    /// your message. It may be useful for connection pool implementations - for example, [bb8::ManageConnection](https://github.com/djc/bb8/blob/09a043c001b3c15514d9f03991cfc87f7118a000/bb8/src/api.rs#L383-L384)'s  
47    /// is_valid and has_broken could be bound to this function to help the pool cycle out broken connections.  
48    pub fn is_alive(&self) -> bool {
49        self.is_alive.load(std::sync::atomic::Ordering::Relaxed)
50    }
51
52    /// Send a server-streaming rpc to the server.
53    ///
54    /// This function only sends the request. You must consume the completion stream to get the response.
55    #[must_use = "You must await the completion to get the response. If you drop the completion, the request will be cancelled."]
56    pub async fn send_streaming(
57        &self,
58        request: Request,
59    ) -> crate::Result<StreamingCompletion<Response, Request>> {
60        let (sender, completion) = mpsc::unbounded_channel();
61        let completion_guard = self
62            .send_message(Completion::RemoteStreaming(sender), request)
63            .await?;
64
65        let completion = StreamingCompletion::new(completion, completion_guard);
66
67        Ok(completion)
68    }
69
70    /// Send a unary rpc to the server.
71    ///
72    /// This function only sends the request. You must await the completion to get the response.
73    #[must_use = "You must await the completion to get the response. If you drop the completion, the request will be cancelled."]
74    pub async fn send_unary(
75        &self,
76        request: Request,
77    ) -> crate::Result<UnaryCompletion<Response, Request>> {
78        let (completor, completion) = oneshot::channel();
79        let completion_guard = self
80            .send_message(Completion::Unary(completor), request)
81            .await?;
82
83        let completion = UnaryCompletion::new(completion, completion_guard);
84
85        Ok(completion)
86    }
87
88    async fn send_message(
89        &self,
90        completion: Completion<Response>,
91        request: Request,
92    ) -> crate::Result<CompletionGuard<Response, Request>> {
93        if !self.is_alive.load(std::sync::atomic::Ordering::Relaxed) {
94            // early-out if the connection is closed
95            return Err(crate::Error::ConnectionIsClosed);
96        }
97        let completion_guard = self.in_flight_submission.register_completion(
98            request.message_id(),
99            completion,
100            self.submission_queue.clone(),
101        );
102        self.submission_queue
103            .send(request)
104            .await
105            .map_err(|_e| crate::Error::ConnectionIsClosed)
106            .map(|_| completion_guard)
107    }
108}
109
#[cfg(test)]
mod test {
    use std::future::Future;
    use std::pin::pin;
    use std::task::Context;
    use std::task::Poll;

    use futures::task::noop_waker_ref;

    use crate::client::reactor::completion_reactor::DoNothingMessageHandler;
    use crate::client::reactor::completion_reactor::RpcCompletionReactor;
    use crate::Message;

    use super::RpcClient;

    // Test wire encoding for `u64` messages:
    // - the low 32 bits hold the message id;
    // - the high 32 bits hold the control code (0 = normal, 1 = cancel, 2 = end).
    const MESSAGE_ID_MASK: u64 = 0xffff_ffff;
    const CONTROL_CODE_SHIFT: u32 = 32;
    const NORMAL_CODE: u64 = 0;
    const CANCEL_CODE: u64 = 1;
    const END_CODE: u64 = 2;

    impl Message for u64 {
        fn message_id(&self) -> u64 {
            *self & MESSAGE_ID_MASK
        }

        fn control_code(&self) -> crate::ProtosocketControlCode {
            match *self >> CONTROL_CODE_SHIFT {
                NORMAL_CODE => crate::ProtosocketControlCode::Normal,
                CANCEL_CODE => crate::ProtosocketControlCode::Cancel,
                END_CODE => crate::ProtosocketControlCode::End,
                _ => unreachable!("invalid control code"),
            }
        }

        fn set_message_id(&mut self, message_id: u64) {
            // Keep the control-code bits and replace only the id bits. Masking the new id stops
            // an id wider than 32 bits from corrupting the control code.
            *self = (*self & !MESSAGE_ID_MASK) | (message_id & MESSAGE_ID_MASK);
        }

        fn cancelled(message_id: u64) -> Self {
            (CANCEL_CODE << CONTROL_CODE_SHIFT) | (message_id & MESSAGE_ID_MASK)
        }

        fn ended(message_id: u64) -> Self {
            (END_CODE << CONTROL_CODE_SHIFT) | (message_id & MESSAGE_ID_MASK)
        }
    }

    /// Polls `f` to completion on the current thread with a no-op waker.
    ///
    /// This spins, re-polling until the future is ready, so it must only be used with futures
    /// that will become ready without this thread doing anything else.
    fn drive_future<F: Future>(f: F) -> F::Output {
        let mut f = pin!(f);
        let mut context = Context::from_waker(noop_waker_ref());
        loop {
            if let Poll::Ready(result) = f.as_mut().poll(&mut context) {
                break result;
            }
        }
    }

    /// Builds a client plus the receiving end of its submission queue (capacity 10).
    ///
    /// The reactor is also returned so each test can hold it for the test's duration.
    // The three-part tuple trips clippy's type_complexity lint; a type alias is not worth it here.
    #[allow(clippy::type_complexity)]
    fn get_client() -> (
        tokio::sync::mpsc::Receiver<u64>,
        RpcClient<u64, u64>,
        RpcCompletionReactor<u64, DoNothingMessageHandler<u64>>,
    ) {
        let (sender, remote_end) = tokio::sync::mpsc::channel::<u64>(10);
        let rpc_reactor = RpcCompletionReactor::<u64, _>::new(DoNothingMessageHandler::default());
        let client = RpcClient::new(sender, &rpc_reactor);
        (remote_end, client, rpc_reactor)
    }

    #[test]
    fn unary_drop_cancel() {
        let (mut remote_end, client, _reactor) = get_client();

        let response = drive_future(client.send_unary(4)).expect("can send");
        assert_eq!(4, remote_end.blocking_recv().expect("a request is sent"));
        assert!(remote_end.is_empty(), "no more messages yet");

        // Dropping the unary completion must emit a cancel for message id 4.
        drop(response);

        assert_eq!(
            (1 << 32) + 4,
            remote_end.blocking_recv().expect("a cancel is sent")
        );
    }

    #[test]
    fn streaming_drop_cancel() {
        let (mut remote_end, client, _reactor) = get_client();

        let response = drive_future(client.send_streaming(4)).expect("can send");
        assert_eq!(4, remote_end.blocking_recv().expect("a request is sent"));
        assert!(remote_end.is_empty(), "no more messages yet");

        // Dropping the streaming completion must emit a cancel for message id 4.
        drop(response);

        assert_eq!(
            (1 << 32) + 4,
            remote_end.blocking_recv().expect("a cancel is sent")
        );
    }
}