// protosocket_rpc/client/rpc_client.rs
use std::sync::{
    atomic::{AtomicBool, Ordering},
    Arc,
};

use tokio::sync::{mpsc, oneshot};

use super::reactor::completion_reactor::{DoNothingMessageHandler, RpcCompletionReactor};
use super::reactor::completion_registry::{Completion, CompletionGuard, RpcRegistrar};
use super::reactor::{
    completion_streaming::StreamingCompletion, completion_unary::UnaryCompletion,
};
use crate::Message;

12#[derive(Debug, Clone)]
18pub struct RpcClient<Request, Response>
19where
20 Request: Message,
21 Response: Message,
22{
23 #[allow(clippy::type_complexity)]
24 in_flight_submission: RpcRegistrar<Response>,
25 submission_queue: tokio::sync::mpsc::Sender<Request>,
26 is_alive: Arc<AtomicBool>,
27}
28
29impl<Request, Response> RpcClient<Request, Response>
30where
31 Request: Message,
32 Response: Message,
33{
34 pub(crate) fn new(
35 submission_queue: mpsc::Sender<Request>,
36 message_reactor: &RpcCompletionReactor<Response, DoNothingMessageHandler<Response>>,
37 ) -> Self {
38 Self {
39 submission_queue,
40 in_flight_submission: message_reactor.in_flight_submission_handle(),
41 is_alive: message_reactor.alive_handle(),
42 }
43 }
44
45 pub fn is_alive(&self) -> bool {
49 self.is_alive.load(std::sync::atomic::Ordering::Relaxed)
50 }
51
52 #[must_use = "You must await the completion to get the response. If you drop the completion, the request will be cancelled."]
56 pub async fn send_streaming(
57 &self,
58 request: Request,
59 ) -> crate::Result<StreamingCompletion<Response, Request>> {
60 let (sender, completion) = mpsc::unbounded_channel();
61 let completion_guard = self
62 .send_message(Completion::RemoteStreaming(sender), request)
63 .await?;
64
65 let completion = StreamingCompletion::new(completion, completion_guard);
66
67 Ok(completion)
68 }
69
70 #[must_use = "You must await the completion to get the response. If you drop the completion, the request will be cancelled."]
74 pub async fn send_unary(
75 &self,
76 request: Request,
77 ) -> crate::Result<UnaryCompletion<Response, Request>> {
78 let (completor, completion) = oneshot::channel();
79 let completion_guard = self
80 .send_message(Completion::Unary(completor), request)
81 .await?;
82
83 let completion = UnaryCompletion::new(completion, completion_guard);
84
85 Ok(completion)
86 }
87
88 async fn send_message(
89 &self,
90 completion: Completion<Response>,
91 request: Request,
92 ) -> crate::Result<CompletionGuard<Response, Request>> {
93 if !self.is_alive.load(std::sync::atomic::Ordering::Relaxed) {
94 return Err(crate::Error::ConnectionIsClosed);
96 }
97 let completion_guard = self.in_flight_submission.register_completion(
98 request.message_id(),
99 completion,
100 self.submission_queue.clone(),
101 );
102 self.submission_queue
103 .send(request)
104 .await
105 .map_err(|_e| crate::Error::ConnectionIsClosed)
106 .map(|_| completion_guard)
107 }
108}
109
#[cfg(test)]
mod test {
    use std::future::Future;
    use std::pin::pin;
    use std::task::Context;
    use std::task::Poll;

    use futures::task::noop_waker_ref;

    use crate::client::reactor::completion_reactor::DoNothingMessageHandler;
    use crate::client::reactor::completion_reactor::RpcCompletionReactor;
    use crate::Message;

    use super::RpcClient;

    /// The low 32 bits of a test message hold its message id.
    const MESSAGE_ID_MASK: u64 = 0xffff_ffff;

    /// Test message encoding: the low 32 bits are the message id and the high 32 bits are
    /// the control code (0 = normal, 1 = cancel, 2 = end).
    impl Message for u64 {
        fn message_id(&self) -> u64 {
            *self & MESSAGE_ID_MASK
        }

        fn control_code(&self) -> crate::ProtosocketControlCode {
            match *self >> 32 {
                0 => crate::ProtosocketControlCode::Normal,
                1 => crate::ProtosocketControlCode::Cancel,
                2 => crate::ProtosocketControlCode::End,
                _ => unreachable!("invalid control code"),
            }
        }

        fn set_message_id(&mut self, message_id: u64) {
            // Keep the entire high half (read back by `control_code`) and mask the id so it
            // cannot spill into that half.
            *self = (*self & !MESSAGE_ID_MASK) | (message_id & MESSAGE_ID_MASK);
        }

        fn cancelled(message_id: u64) -> Self {
            (1_u64 << 32) | (message_id & MESSAGE_ID_MASK)
        }

        fn ended(message_id: u64) -> Self {
            (2_u64 << 32) | (message_id & MESSAGE_ID_MASK)
        }
    }

    /// Polls `f` to completion with a no-op waker, spinning until it is ready.
    ///
    /// Only suitable for futures that become ready without another task waking them, such
    /// as a send into a channel that still has free capacity; anything else spins forever.
    fn drive_future<F: Future>(f: F) -> F::Output {
        let mut f = pin!(f);
        let mut context = Context::from_waker(noop_waker_ref());
        loop {
            if let Poll::Ready(result) = f.as_mut().poll(&mut context) {
                break result;
            }
        }
    }

    /// Builds a client wired to a fresh reactor. Also returns the receiving end of the
    /// client's submission queue (capacity 10), so tests can observe exactly what it sends.
    // The tuple of (queue receiver, client, reactor) trips clippy's type-complexity lint;
    // a dedicated type alias would add nothing for a single test helper.
    #[allow(clippy::type_complexity)]
    fn get_client() -> (
        tokio::sync::mpsc::Receiver<u64>,
        RpcClient<u64, u64>,
        RpcCompletionReactor<u64, DoNothingMessageHandler<u64>>,
    ) {
        let (sender, remote_end) = tokio::sync::mpsc::channel::<u64>(10);
        let rpc_reactor = RpcCompletionReactor::<u64, _>::new(DoNothingMessageHandler::default());
        let client = RpcClient::new(sender, &rpc_reactor);
        (remote_end, client, rpc_reactor)
    }

    /// Dropping an unawaited unary completion must enqueue a cancel for its message id.
    #[test]
    fn unary_drop_cancel() {
        let (mut remote_end, client, _reactor) = get_client();

        let response = drive_future(client.send_unary(4)).expect("can send");
        assert_eq!(4, remote_end.blocking_recv().expect("a request is sent"));
        assert!(remote_end.is_empty(), "no more messages yet");

        drop(response);

        assert_eq!(
            (1 << 32) + 4,
            remote_end.blocking_recv().expect("a cancel is sent")
        );
    }

    /// Dropping an unawaited streaming completion must enqueue a cancel for its message id.
    #[test]
    fn streaming_drop_cancel() {
        let (mut remote_end, client, _reactor) = get_client();

        let response = drive_future(client.send_streaming(4)).expect("can send");
        assert_eq!(4, remote_end.blocking_recv().expect("a request is sent"));
        assert!(remote_end.is_empty(), "no more messages yet");

        drop(response);

        assert_eq!(
            (1 << 32) + 4,
            remote_end.blocking_recv().expect("a cancel is sent")
        );
    }
}