1use std::sync::{atomic::AtomicBool, Arc};
2
3use tokio::sync::{mpsc, oneshot};
4
5use super::reactor::completion_reactor::{DoNothingMessageHandler, RpcCompletionReactor};
6use super::reactor::completion_registry::{Completion, CompletionGuard, RpcRegistrar};
7use super::reactor::{
8 completion_streaming::StreamingCompletion, completion_unary::UnaryCompletion,
9};
10use crate::Message;
11
12#[derive(Debug)]
18pub struct RpcClient<Request, Response>
19where
20 Request: Message,
21 Response: Message,
22{
23 #[allow(clippy::type_complexity)]
24 in_flight_submission: RpcRegistrar<Response>,
25 submission_queue: tokio::sync::mpsc::Sender<Request>,
26 is_alive: Arc<AtomicBool>,
27}
28
29impl<Request, Response> Clone for RpcClient<Request, Response>
30where
31 Request: Message,
32 Response: Message,
33{
34 fn clone(&self) -> Self {
35 Self {
36 in_flight_submission: self.in_flight_submission.clone(),
37 submission_queue: self.submission_queue.clone(),
38 is_alive: self.is_alive.clone(),
39 }
40 }
41}
42
43impl<Request, Response> RpcClient<Request, Response>
44where
45 Request: Message,
46 Response: Message,
47{
48 pub(crate) fn new(
49 submission_queue: mpsc::Sender<Request>,
50 message_reactor: &RpcCompletionReactor<Response, DoNothingMessageHandler<Response>>,
51 ) -> Self {
52 Self {
53 submission_queue,
54 in_flight_submission: message_reactor.in_flight_submission_handle(),
55 is_alive: message_reactor.alive_handle(),
56 }
57 }
58
59 pub fn is_alive(&self) -> bool {
63 self.is_alive.load(std::sync::atomic::Ordering::Relaxed)
64 }
65
66 #[must_use = "You must await the completion to get the response. If you drop the completion, the request will be cancelled."]
70 pub async fn send_streaming(
71 &self,
72 request: Request,
73 ) -> crate::Result<StreamingCompletion<Response, Request>> {
74 let (sender, completion) = mpsc::unbounded_channel();
75 let completion_guard = self
76 .send_message(Completion::RemoteStreaming(sender), request)
77 .await?;
78
79 let completion = StreamingCompletion::new(completion, completion_guard);
80
81 Ok(completion)
82 }
83
84 #[must_use = "You must await the completion to get the response. If you drop the completion, the request will be cancelled."]
88 pub async fn send_unary(
89 &self,
90 request: Request,
91 ) -> crate::Result<UnaryCompletion<Response, Request>> {
92 let (completor, completion) = oneshot::channel();
93 let completion_guard = self
94 .send_message(Completion::Unary(completor), request)
95 .await?;
96
97 let completion = UnaryCompletion::new(completion, completion_guard);
98
99 Ok(completion)
100 }
101
102 async fn send_message(
103 &self,
104 completion: Completion<Response>,
105 request: Request,
106 ) -> crate::Result<CompletionGuard<Response, Request>> {
107 if !self.is_alive.load(std::sync::atomic::Ordering::Relaxed) {
108 return Err(crate::Error::ConnectionIsClosed);
110 }
111 let completion_guard = self.in_flight_submission.register_completion(
112 request.message_id(),
113 completion,
114 self.submission_queue.clone(),
115 );
116 self.submission_queue
117 .send(request)
118 .await
119 .map_err(|_e| crate::Error::ConnectionIsClosed)
120 .map(|_| completion_guard)
121 }
122}
123
#[cfg(test)]
mod test {
    use std::future::Future;
    use std::pin::pin;
    use std::sync::atomic::AtomicBool;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::task::Context;
    use std::task::Poll;

    use futures::task::noop_waker_ref;
    use tokio::sync::mpsc;

    use crate::client::connection_pool::ClientConnector;
    use crate::client::connection_pool::ConnectionPool;
    use crate::client::reactor::completion_reactor::DoNothingMessageHandler;
    use crate::client::reactor::completion_reactor::RpcCompletionReactor;
    use crate::Message;

    use super::RpcClient;

    /// The low 32 bits of a test `u64` message hold the message id; the high 32 bits
    /// hold the control code.
    const MESSAGE_ID_MASK: u64 = 0xffff_ffff;

    /// Test-only message encoding.
    ///
    /// The id lives in the low 32 bits. The control code is the high 32 bits:
    /// 0 = Normal, 1 = Cancel, 2 = End.
    impl Message for u64 {
        fn message_id(&self) -> u64 {
            *self & MESSAGE_ID_MASK
        }

        fn control_code(&self) -> crate::ProtosocketControlCode {
            match *self >> 32 {
                0 => crate::ProtosocketControlCode::Normal,
                1 => crate::ProtosocketControlCode::Cancel,
                2 => crate::ProtosocketControlCode::End,
                _ => unreachable!("invalid control code"),
            }
        }

        fn set_message_id(&mut self, message_id: u64) {
            // Preserve the entire control-code half (the old `0xf00000000` mask kept
            // only bits 32..=35). Also confine the id to its own 32 bits so it cannot
            // overwrite the control code.
            *self = (*self & !MESSAGE_ID_MASK) | (message_id & MESSAGE_ID_MASK);
        }

        fn cancelled(message_id: u64) -> Self {
            (1_u64 << 32) | (message_id & MESSAGE_ID_MASK)
        }

        fn ended(message_id: u64) -> Self {
            (2_u64 << 32) | (message_id & MESSAGE_ID_MASK)
        }
    }

    /// Busy-polls `f` with a no-op waker until it is ready.
    ///
    /// Only suitable for futures that can finish without being woken by something
    /// else. A future that stays pending spins forever.
    fn drive_future<F: Future>(f: F) -> F::Output {
        let mut f = pin!(f);
        loop {
            let next = f.as_mut().poll(&mut Context::from_waker(noop_waker_ref()));
            if let Poll::Ready(result) = next {
                break result;
            }
        }
    }

    /// Builds a client wired to a capacity-10 channel.
    ///
    /// Returns the channel's receiving end (standing in for the remote peer), the
    /// client, and its reactor. The reactor must be kept alive for the test's duration.
    #[allow(clippy::type_complexity)]
    fn get_client() -> (
        tokio::sync::mpsc::Receiver<u64>,
        RpcClient<u64, u64>,
        RpcCompletionReactor<u64, DoNothingMessageHandler<u64>>,
    ) {
        let (sender, remote_end) = tokio::sync::mpsc::channel::<u64>(10);
        let rpc_reactor = RpcCompletionReactor::<u64, _>::new(DoNothingMessageHandler::default());
        let client = RpcClient::new(sender, &rpc_reactor);
        (remote_end, client, rpc_reactor)
    }

    /// Dropping a unary completion sends a cancel for its message id.
    #[test]
    fn unary_drop_cancel() {
        let (mut remote_end, client, _reactor) = get_client();

        let response = drive_future(client.send_unary(4)).expect("can send");
        assert_eq!(4, remote_end.blocking_recv().expect("a request is sent"));
        assert!(remote_end.is_empty(), "no more messages yet");

        drop(response);

        assert_eq!(
            (1 << 32) + 4,
            remote_end.blocking_recv().expect("a cancel is sent")
        );
    }

    /// Dropping a streaming completion sends a cancel for its message id.
    #[test]
    fn streaming_drop_cancel() {
        let (mut remote_end, client, _reactor) = get_client();

        let response = drive_future(client.send_streaming(4)).expect("can send");
        assert_eq!(4, remote_end.blocking_recv().expect("a request is sent"));
        assert!(remote_end.is_empty(), "no more messages yet");

        drop(response);

        assert_eq!(
            (1 << 32) + 4,
            remote_end.blocking_recv().expect("a cancel is sent")
        );
    }

    /// Connector that records every connection it creates.
    ///
    /// Setting `fail_connections` makes `connect` fail with a simulated I/O error.
    #[allow(clippy::type_complexity)]
    #[derive(Default)]
    struct TestConnector {
        clients: Mutex<
            Vec<(
                mpsc::Receiver<u64>,
                RpcClient<u64, u64>,
                RpcCompletionReactor<u64, DoNothingMessageHandler<u64>>,
            )>,
        >,
        fail_connections: AtomicBool,
    }

    impl ClientConnector for Arc<TestConnector> {
        type Request = u64;
        type Response = u64;

        async fn connect(self) -> crate::Result<RpcClient<Self::Request, Self::Response>> {
            if self
                .fail_connections
                .load(std::sync::atomic::Ordering::Relaxed)
            {
                return Err(crate::Error::IoFailure(Arc::new(std::io::Error::other(
                    "simulated connection failure",
                ))));
            }
            let (remote_end, client, reactor) = get_client();
            self.clients
                .lock()
                .expect("mutex works")
                .push((remote_end, client.clone(), reactor));

            Ok(client)
        }
    }

    /// With a pool size of 1, two `get_connection` calls share one connection.
    #[tokio::test]
    async fn connection_pool() {
        let connector = Arc::new(TestConnector::default());
        let pool = ConnectionPool::new(connector.clone(), 1);

        let rpc_client_a = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool");
        assert_eq!(
            1,
            connector.clients.lock().expect("mutex works").len(),
            "one connection created"
        );

        let rpc_client_b = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool");
        assert_eq!(
            1,
            connector.clients.lock().expect("mutex works").len(),
            "still one connection created"
        );

        assert!(
            Arc::ptr_eq(&rpc_client_a.is_alive, &rpc_client_b.is_alive),
            "same connection shared"
        );

        // Keep the completions alive; dropping them would cancel the requests.
        let _reply_a = rpc_client_a.send_unary(42).await.expect("can send");
        let _reply_b = rpc_client_b.send_unary(43).await.expect("can send");

        // Pop inside a block so the std mutex guard is released before awaiting.
        let (mut remote_end, _client, _reactor) = {
            let mut clients = connector.clients.lock().expect("mutex works");
            clients.pop().expect("one client exists")
        };
        assert_eq!(42, remote_end.recv().await.expect("request a is received"));
        assert_eq!(43, remote_end.recv().await.expect("request b is received"));
    }

    /// Once the pooled connection is marked dead, the pool opens a new one.
    #[tokio::test]
    async fn connection_pool_reconnect() {
        let connector = Arc::new(TestConnector::default());
        let pool = ConnectionPool::new(connector.clone(), 1);

        let rpc_client_a = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool");
        assert_eq!(
            1,
            connector.clients.lock().expect("mutex works").len(),
            "one connection created"
        );

        rpc_client_a
            .is_alive
            .store(false, std::sync::atomic::Ordering::Relaxed);

        let rpc_client_b = pool.get_connection().await.expect("can get a connection from the pool even when the previous connection is dead, as long as the connection attempt succeeds");
        assert_eq!(
            2,
            connector.clients.lock().expect("mutex works").len(),
            "a new connection was created, so the connector was asked to make a new connection"
        );
        assert!(
            !Arc::ptr_eq(&rpc_client_a.is_alive, &rpc_client_b.is_alive),
            "new connection created"
        );
    }

    /// A failing connector error reaches the caller instead of being retried forever.
    #[tokio::test]
    async fn connection_pool_failure() {
        let connector = Arc::new(TestConnector::default());
        let pool = ConnectionPool::new(connector.clone(), 1);
        connector
            .fail_connections
            .store(true, std::sync::atomic::Ordering::Relaxed);

        pool.get_connection().await.expect_err("connection attempt fails, and the calling code gets the error. It does not try forever without surfacing errors.");
    }

    /// After a failed reconnect, the pool recovers once the connector works again.
    #[tokio::test]
    async fn connection_pool_reconnect_failure_recovery() {
        let connector = Arc::new(TestConnector::default());
        let pool = ConnectionPool::new(connector.clone(), 1);
        let rpc_client_a = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool");

        rpc_client_a
            .is_alive
            .store(false, std::sync::atomic::Ordering::Relaxed);
        connector
            .fail_connections
            .store(true, std::sync::atomic::Ordering::Relaxed);

        pool.get_connection().await.expect_err("connection attempt fails, and the calling code gets the error. It does not try forever without surfacing errors.");

        connector
            .fail_connections
            .store(false, std::sync::atomic::Ordering::Relaxed);

        let rpc_client_b = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool now");
        assert!(
            !Arc::ptr_eq(&rpc_client_a.is_alive, &rpc_client_b.is_alive),
            "new connection created"
        );
    }
}