1use std::sync::{atomic::AtomicBool, Arc};
2
3use tokio::sync::oneshot;
4
5use super::reactor::completion_reactor::{DoNothingMessageHandler, RpcCompletionReactor};
6use super::reactor::completion_registry::{Completion, CompletionGuard};
7use super::reactor::{
8 completion_streaming::StreamingCompletion, completion_unary::UnaryCompletion,
9};
10use crate::client::reactor::completion_reactor::{CompletableRpc, RpcNotification};
11use crate::Message;
12
13#[derive(Debug)]
19pub struct RpcClient<Request, Response>
20where
21 Request: Message,
22 Response: Message,
23{
24 submission_queue: spillway::Sender<RpcNotification<Response, Request>>,
25 is_alive: Arc<AtomicBool>,
26}
27
28impl<Request, Response> Clone for RpcClient<Request, Response>
29where
30 Request: Message,
31 Response: Message,
32{
33 fn clone(&self) -> Self {
34 Self {
35 submission_queue: self.submission_queue.clone(),
36 is_alive: self.is_alive.clone(),
37 }
38 }
39}
40
41impl<Request, Response> RpcClient<Request, Response>
42where
43 Request: Message,
44 Response: Message,
45{
46 pub(crate) fn new(
47 submission_queue: spillway::Sender<RpcNotification<Response, Request>>,
48 message_reactor: &RpcCompletionReactor<
49 Response,
50 Request,
51 DoNothingMessageHandler<Response>,
52 >,
53 ) -> Self {
54 Self {
55 submission_queue,
56 is_alive: message_reactor.alive_handle(),
57 }
58 }
59
60 pub fn is_alive(&self) -> bool {
64 self.is_alive.load(std::sync::atomic::Ordering::Relaxed)
65 }
66
67 #[must_use = "You must await the completion to get the response. If you drop the completion, the request will be cancelled."]
71 pub fn send_streaming(
72 &self,
73 request: Request,
74 ) -> crate::Result<StreamingCompletion<Response, Request>> {
75 let (sender, completion) = spillway::channel_with_concurrency(1);
76 let completion_guard = self.send_message(Completion::RemoteStreaming(sender), request)?;
77
78 let completion = StreamingCompletion::new(completion, completion_guard);
79
80 Ok(completion)
81 }
82
83 #[must_use = "You must await the completion to get the response. If you drop the completion, the request will be cancelled."]
87 pub fn send_unary(
88 &self,
89 request: Request,
90 ) -> crate::Result<UnaryCompletion<Response, Request>> {
91 let (completor, completion) = oneshot::channel();
92 let completion_guard = self.send_message(Completion::Unary(completor), request)?;
93
94 let completion = UnaryCompletion::new(completion, completion_guard);
95
96 Ok(completion)
97 }
98
99 fn send_message(
100 &self,
101 completion: Completion<Response>,
102 request: Request,
103 ) -> crate::Result<CompletionGuard<Response, Request>> {
104 if !self.is_alive.load(std::sync::atomic::Ordering::Relaxed) {
105 return Err(crate::Error::ConnectionIsClosed);
107 }
108 let message_id = request.message_id();
109 let completion_guard = CompletionGuard::new(message_id, self.submission_queue.clone());
110
111 self.submission_queue
112 .send(RpcNotification::New(CompletableRpc {
113 message_id,
114 completion,
115 request,
116 }))
117 .map_err(|_e| crate::Error::ConnectionIsClosed)
118 .map(|_| completion_guard)
119 }
120}
121
#[cfg(test)]
mod test {
    use std::future::Future;
    use std::pin::pin;
    use std::sync::atomic::AtomicBool;
    use std::sync::Arc;
    use std::sync::Mutex;
    use std::task::Context;
    use std::task::Poll;

    use futures::task::noop_waker_ref;

    use crate::client::connection_pool::ClientConnector;
    use crate::client::connection_pool::ConnectionPool;
    use crate::client::reactor::completion_reactor::CompletableRpc;
    use crate::client::reactor::completion_reactor::DoNothingMessageHandler;
    use crate::client::reactor::completion_reactor::RpcCompletionReactor;
    use crate::client::reactor::completion_reactor::RpcNotification;
    use crate::Message;

    use super::RpcClient;

    // Test message encoding:
    // - the low 32 bits carry the message id;
    // - the bits above carry the control code (0 = normal, 1 = cancel, 2 = end).
    impl Message for u64 {
        fn message_id(&self) -> u64 {
            *self & 0xffff_ffff
        }

        fn control_code(&self) -> crate::ProtosocketControlCode {
            let code = *self >> 32;
            match code {
                0 => crate::ProtosocketControlCode::Normal,
                1 => crate::ProtosocketControlCode::Cancel,
                2 => crate::ProtosocketControlCode::End,
                _ => unreachable!("invalid control code"),
            }
        }

        fn set_message_id(&mut self, message_id: u64) {
            let control_bits = *self & 0xf_0000_0000;
            *self = control_bits | message_id;
        }

        fn cancelled(message_id: u64) -> Self {
            message_id | (1_u64 << 32)
        }

        fn ended(message_id: u64) -> Self {
            message_id | (2_u64 << 32)
        }
    }

    /// One fake connection:
    /// - the receiving end of the client's submission queue;
    /// - the client;
    /// - the reactor that supplied the client's liveness flag.
    type TestConnection = (
        spillway::Receiver<RpcNotification<u64, u64>>,
        RpcClient<u64, u64>,
        RpcCompletionReactor<u64, u64, DoNothingMessageHandler<u64>>,
    );

    /// Polls `future` with a no-op waker until it is ready.
    ///
    /// This spins instead of parking, so it never returns if the future stays
    /// pending forever.
    fn drive_future<F: Future>(future: F) -> F::Output {
        let mut future = pin!(future);
        let mut context = Context::from_waker(noop_waker_ref());
        loop {
            if let Poll::Ready(output) = future.as_mut().poll(&mut context) {
                return output;
            }
        }
    }

    /// Builds a client whose submissions land in an in-memory channel.
    fn get_client() -> TestConnection {
        let (submission_sender, submissions) = spillway::channel();
        let reactor =
            RpcCompletionReactor::<u64, u64, _>::new(DoNothingMessageHandler::default());
        let client = RpcClient::new(submission_sender, &reactor);
        (submissions, client, reactor)
    }

    /// Returns whether `notification` is a `New` RPC carrying `expected_id`.
    fn is_new_rpc(notification: &RpcNotification<u64, u64>, expected_id: u64) -> bool {
        matches!(
            notification,
            RpcNotification::New(CompletableRpc { message_id, .. }) if *message_id == expected_id
        )
    }

    /// Checks the notification sequence for an RPC with id 4: a `New`
    /// submission first, then a `Cancel` once `completion` has been dropped.
    fn assert_cancelled_on_drop<C>(
        remote_end: &mut spillway::Receiver<RpcNotification<u64, u64>>,
        completion: C,
    ) {
        let submitted = drive_future(remote_end.next()).expect("a request is sent");
        assert!(is_new_rpc(&submitted, 4));

        drop(completion);

        let cancellation = drive_future(remote_end.next()).expect("a cancel is sent");
        assert!(matches!(cancellation, RpcNotification::Cancel(4)));
    }

    /// Simulates the connection behind `client` closing by clearing its
    /// shared liveness flag.
    fn mark_dead(client: &RpcClient<u64, u64>) {
        client
            .is_alive
            .store(false, std::sync::atomic::Ordering::Relaxed);
    }

    #[test]
    fn unary_drop_cancel() {
        let (mut remote_end, client, _reactor) = get_client();
        let response = client.send_unary(4).expect("can send");
        assert_cancelled_on_drop(&mut remote_end, response);
    }

    #[test]
    fn streaming_drop_cancel() {
        let (mut remote_end, client, _reactor) = get_client();
        let response = client.send_streaming(4).expect("can send");
        assert_cancelled_on_drop(&mut remote_end, response);
    }

    /// Connector that records every connection it makes and can be told to fail.
    #[derive(Default)]
    struct TestConnector {
        /// Every connection handed out so far, in creation order.
        clients: Mutex<Vec<TestConnection>>,
        /// When set, `connect` returns a simulated I/O error instead of connecting.
        fail_connections: AtomicBool,
    }

    impl TestConnector {
        /// Number of connections created so far.
        fn connection_count(&self) -> usize {
            self.clients.lock().expect("mutex works").len()
        }

        /// Removes and returns the most recently created connection.
        ///
        /// The lock is released before this returns, so callers may `.await`
        /// afterwards.
        fn take_latest_connection(&self) -> TestConnection {
            let mut clients = self.clients.lock().expect("mutex works");
            clients.pop().expect("one client exists")
        }

        /// Makes subsequent `connect` calls fail (`true`) or succeed (`false`).
        fn set_failing(&self, failing: bool) {
            self.fail_connections
                .store(failing, std::sync::atomic::Ordering::Relaxed);
        }

        fn is_failing(&self) -> bool {
            self.fail_connections
                .load(std::sync::atomic::Ordering::Relaxed)
        }
    }

    impl ClientConnector for Arc<TestConnector> {
        type Request = u64;
        type Response = u64;

        async fn connect(self) -> crate::Result<RpcClient<Self::Request, Self::Response>> {
            if self.is_failing() {
                let simulated = std::io::Error::other("simulated connection failure");
                return Err(crate::Error::IoFailure(Arc::new(simulated)));
            }

            let (remote_end, client, reactor) = get_client();
            let recorded = client.clone();
            self.clients
                .lock()
                .expect("mutex works")
                .push((remote_end, recorded, reactor));

            Ok(client)
        }
    }

    #[tokio::test]
    async fn connection_pool() {
        let connector = Arc::new(TestConnector::default());
        let pool = ConnectionPool::new(connector.clone(), 1);

        let rpc_client_a = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool");
        assert_eq!(1, connector.connection_count(), "one connection created");

        let rpc_client_b = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool");
        assert_eq!(
            1,
            connector.connection_count(),
            "still one connection created"
        );

        assert!(
            Arc::ptr_eq(&rpc_client_a.is_alive, &rpc_client_b.is_alive),
            "same connection shared"
        );

        let _reply_a = rpc_client_a.send_unary(42).expect("can send");
        let _reply_b = rpc_client_b.send_unary(43).expect("can send");

        let (mut remote_end, _client, _reactor) = connector.take_latest_connection();
        assert!(is_new_rpc(
            &remote_end.next().await.expect("request a is received"),
            42
        ));
        assert!(is_new_rpc(
            &remote_end.next().await.expect("request b is received"),
            43
        ));
    }

    #[tokio::test]
    async fn connection_pool_reconnect() {
        let connector = Arc::new(TestConnector::default());
        let pool = ConnectionPool::new(connector.clone(), 1);

        let rpc_client_a = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool");
        assert_eq!(1, connector.connection_count(), "one connection created");

        mark_dead(&rpc_client_a);

        let rpc_client_b = pool.get_connection().await.expect("can get a connection from the pool even when the previous connection is dead, as long as the connection attempt succeeds");
        assert_eq!(
            2,
            connector.connection_count(),
            "a new connection was created, so the connector was asked to make a new connection"
        );
        assert!(
            !Arc::ptr_eq(&rpc_client_a.is_alive, &rpc_client_b.is_alive),
            "new connection created"
        );
    }

    #[tokio::test]
    async fn connection_pool_failure() {
        let connector = Arc::new(TestConnector::default());
        let pool = ConnectionPool::new(connector.clone(), 1);
        connector.set_failing(true);

        pool.get_connection().await.expect_err("connection attempt fails, and the calling code gets the error. It does not try forever without surfacing errors.");
    }

    #[tokio::test]
    async fn connection_pool_reconnect_failure_recovery() {
        let connector = Arc::new(TestConnector::default());
        let pool = ConnectionPool::new(connector.clone(), 1);
        let rpc_client_a = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool");

        mark_dead(&rpc_client_a);
        connector.set_failing(true);

        pool.get_connection().await.expect_err("connection attempt fails, and the calling code gets the error. It does not try forever without surfacing errors.");

        connector.set_failing(false);

        let rpc_client_b = pool
            .get_connection()
            .await
            .expect("can get a connection from the pool now");
        assert!(
            !Arc::ptr_eq(&rpc_client_a.is_alive, &rpc_client_b.is_alive),
            "new connection created"
        );
    }
}