1use std::marker::PhantomData;
2use std::sync::Weak;
3use std::{convert, fmt, io};
4
5use log::*;
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8use tokio::sync::mpsc;
9use tokio::time::Duration;
10
11use crate::common::{Request, Response, UntypedRequest, UntypedResponse};
12
13mod mailbox;
14pub use mailbox::*;
15
/// Capacity passed to the post office for every mailbox created by [`UntypedChannel::mail`].
// NOTE(review): presumably the number of responses a mailbox can buffer — confirm in `mailbox`.
const CHANNEL_MAILBOX_CAPACITY: usize = 10000;
18
/// Typed view over an [`UntypedChannel`].
///
/// `T` is the payload type of outgoing [`Request`]s and `U` is the payload type expected in
/// incoming [`Response`]s. Neither is stored; both parameters exist only at the type level.
pub struct Channel<T, U> {
    /// Untyped channel that performs the actual sending and mailbox handling.
    inner: UntypedChannel,
    /// Zero-sized marker carrying the request payload type.
    _request: PhantomData<T>,
    /// Zero-sized marker carrying the response payload type.
    _response: PhantomData<U>,
}
29
30impl<T, U> Clone for Channel<T, U> {
32 fn clone(&self) -> Self {
33 Self {
34 inner: self.inner.clone(),
35 _request: self._request,
36 _response: self._response,
37 }
38 }
39}
40
41impl<T, U> fmt::Debug for Channel<T, U> {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43 f.debug_struct("Channel")
44 .field("tx", &self.inner.tx)
45 .field("post_office", &self.inner.post_office)
46 .field("_request", &self._request)
47 .field("_response", &self._response)
48 .finish()
49 }
50}
51
52impl<T, U> Channel<T, U>
53where
54 T: Send + Sync + Serialize + 'static,
55 U: Send + Sync + DeserializeOwned + 'static,
56{
57 pub fn is_closed(&self) -> bool {
59 self.inner.is_closed()
60 }
61
62 pub fn into_untyped_channel(self) -> UntypedChannel {
64 self.inner
65 }
66
67 pub async fn assign_default_mailbox(&self, buffer: usize) -> io::Result<Mailbox<Response<U>>> {
69 Ok(map_to_typed_mailbox(
70 self.inner.assign_default_mailbox(buffer).await?,
71 ))
72 }
73
74 pub async fn remove_default_mailbox(&self) -> io::Result<()> {
77 self.inner.remove_default_mailbox().await
78 }
79
80 pub async fn mail(&mut self, req: impl Into<Request<T>>) -> io::Result<Mailbox<Response<U>>> {
84 Ok(map_to_typed_mailbox(
85 self.inner.mail(req.into().to_untyped_request()?).await?,
86 ))
87 }
88
89 pub async fn mail_timeout(
91 &mut self,
92 req: impl Into<Request<T>>,
93 duration: impl Into<Option<Duration>>,
94 ) -> io::Result<Mailbox<Response<U>>> {
95 match duration.into() {
96 Some(duration) => tokio::time::timeout(duration, self.mail(req))
97 .await
98 .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
99 .and_then(convert::identity),
100 None => self.mail(req).await,
101 }
102 }
103
104 pub async fn send(&mut self, req: impl Into<Request<T>>) -> io::Result<Response<U>> {
107 let mut mailbox = self.mail(req).await?;
109
110 mailbox
112 .next()
113 .await
114 .ok_or_else(|| io::Error::from(io::ErrorKind::ConnectionAborted))
115 }
116
117 pub async fn send_timeout(
119 &mut self,
120 req: impl Into<Request<T>>,
121 duration: impl Into<Option<Duration>>,
122 ) -> io::Result<Response<U>> {
123 match duration.into() {
124 Some(duration) => tokio::time::timeout(duration, self.send(req))
125 .await
126 .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
127 .and_then(convert::identity),
128 None => self.send(req).await,
129 }
130 }
131
132 pub async fn fire(&mut self, req: impl Into<Request<T>>) -> io::Result<()> {
135 self.inner.fire(req.into().to_untyped_request()?).await
136 }
137
138 pub async fn fire_timeout(
140 &mut self,
141 req: impl Into<Request<T>>,
142 duration: impl Into<Option<Duration>>,
143 ) -> io::Result<()> {
144 match duration.into() {
145 Some(duration) => tokio::time::timeout(duration, self.fire(req))
146 .await
147 .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
148 .and_then(convert::identity),
149 None => self.fire(req).await,
150 }
151 }
152}
153
154fn map_to_typed_mailbox<T: Send + DeserializeOwned + 'static>(
155 mailbox: Mailbox<UntypedResponse<'static>>,
156) -> Mailbox<Response<T>> {
157 mailbox.map_opt(|res| match res.to_typed_response() {
158 Ok(res) => Some(res),
159 Err(x) => {
160 if log::log_enabled!(Level::Trace) {
161 trace!(
162 "Invalid response payload: {}",
163 String::from_utf8_lossy(&res.payload)
164 );
165 }
166
167 error!(
168 "Unable to parse response payload into {}: {x}",
169 std::any::type_name::<T>()
170 );
171 None
172 }
173 })
174}
175
/// Channel that sends [`UntypedRequest`]s and hands out mailboxes of [`UntypedResponse`]s.
#[derive(Debug)]
pub struct UntypedChannel {
    /// Sending half of the queue onto which outgoing requests are pushed.
    pub(crate) tx: mpsc::Sender<UntypedRequest<'static>>,

    /// Weak handle to the post office that creates and manages mailboxes. It is upgraded on
    /// demand; operations that need it fail with `NotConnected` once it can no longer be
    /// upgraded.
    pub(crate) post_office: Weak<PostOffice<UntypedResponse<'static>>>,
}
192
193impl Clone for UntypedChannel {
195 fn clone(&self) -> Self {
196 Self {
197 tx: self.tx.clone(),
198 post_office: Weak::clone(&self.post_office),
199 }
200 }
201}
202
203impl UntypedChannel {
204 pub fn is_closed(&self) -> bool {
206 self.tx.is_closed()
207 }
208
209 pub fn into_typed_channel<T, U>(self) -> Channel<T, U> {
211 Channel {
212 inner: self,
213 _request: PhantomData,
214 _response: PhantomData,
215 }
216 }
217
218 pub async fn assign_default_mailbox(
220 &self,
221 buffer: usize,
222 ) -> io::Result<Mailbox<UntypedResponse<'static>>> {
223 match Weak::upgrade(&self.post_office) {
224 Some(post_office) => Ok(post_office.assign_default_mailbox(buffer).await),
225 None => Err(io::Error::new(
226 io::ErrorKind::NotConnected,
227 "Channel's post office is no longer available",
228 )),
229 }
230 }
231
232 pub async fn remove_default_mailbox(&self) -> io::Result<()> {
235 match Weak::upgrade(&self.post_office) {
236 Some(post_office) => {
237 post_office.remove_default_mailbox().await;
238 Ok(())
239 }
240 None => Err(io::Error::new(
241 io::ErrorKind::NotConnected,
242 "Channel's post office is no longer available",
243 )),
244 }
245 }
246
247 pub async fn mail(
251 &mut self,
252 req: UntypedRequest<'_>,
253 ) -> io::Result<Mailbox<UntypedResponse<'static>>> {
254 let mailbox = Weak::upgrade(&self.post_office)
256 .ok_or_else(|| {
257 io::Error::new(
258 io::ErrorKind::NotConnected,
259 "Channel's post office is no longer available",
260 )
261 })?
262 .make_mailbox(req.id.clone().into_owned(), CHANNEL_MAILBOX_CAPACITY)
263 .await;
264
265 self.fire(req).await?;
267
268 Ok(mailbox)
270 }
271
272 pub async fn mail_timeout(
274 &mut self,
275 req: UntypedRequest<'_>,
276 duration: impl Into<Option<Duration>>,
277 ) -> io::Result<Mailbox<UntypedResponse<'static>>> {
278 match duration.into() {
279 Some(duration) => tokio::time::timeout(duration, self.mail(req))
280 .await
281 .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
282 .and_then(convert::identity),
283 None => self.mail(req).await,
284 }
285 }
286
287 pub async fn send(&mut self, req: UntypedRequest<'_>) -> io::Result<UntypedResponse<'static>> {
290 let mut mailbox = self.mail(req).await?;
292
293 mailbox
295 .next()
296 .await
297 .ok_or_else(|| io::Error::from(io::ErrorKind::ConnectionAborted))
298 }
299
300 pub async fn send_timeout(
302 &mut self,
303 req: UntypedRequest<'_>,
304 duration: impl Into<Option<Duration>>,
305 ) -> io::Result<UntypedResponse<'static>> {
306 match duration.into() {
307 Some(duration) => tokio::time::timeout(duration, self.send(req))
308 .await
309 .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
310 .and_then(convert::identity),
311 None => self.send(req).await,
312 }
313 }
314
315 pub async fn fire(&mut self, req: UntypedRequest<'_>) -> io::Result<()> {
318 self.tx
319 .send(req.into_owned())
320 .await
321 .map_err(|x| io::Error::new(io::ErrorKind::BrokenPipe, x.to_string()))
322 }
323
324 pub async fn fire_timeout(
326 &mut self,
327 req: UntypedRequest<'_>,
328 duration: impl Into<Option<Duration>>,
329 ) -> io::Result<()> {
330 match duration.into() {
331 Some(duration) => tokio::time::timeout(duration, self.fire(req))
332 .await
333 .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
334 .and_then(convert::identity),
335 None => self.fire(req).await,
336 }
337 }
338}
339
340#[cfg(test)]
341mod tests {
342 use super::*;
343
    // Tests exercising the typed `Channel<u8, u8>` wrapper.
    mod typed {
        use std::sync::Arc;
        use std::time::Duration;

        use test_log::test;

        use super::*;

        type TestChannel = Channel<u8, u8>;
        // Channel under test, the receiving end of its request queue, and the post office
        // that backs its mailboxes.
        type Setup = (
            TestChannel,
            mpsc::Receiver<UntypedRequest<'static>>,
            Arc<PostOffice<UntypedResponse<'static>>>,
        );

        // Builds a typed channel whose request queue holds up to `buffer` requests. The post
        // office is returned as a strong reference because the channel only keeps a weak one.
        fn setup(buffer: usize) -> Setup {
            let post_office = Arc::new(PostOffice::default());
            let (tx, rx) = mpsc::channel(buffer);
            let channel = {
                let post_office = Arc::downgrade(&post_office);
                UntypedChannel { tx, post_office }
            };

            (channel.into_typed_channel(), rx, post_office)
        }

        #[test(tokio::test)]
        async fn mail_should_return_mailbox_that_receives_responses_until_post_office_drops_it() {
            let (mut channel, _server, post_office) = setup(100);

            // The response's origin id is the request's id.
            let req = Request::new(0);
            let res = Response::new(req.id.clone(), 1);

            let mut mailbox = channel.mail(req).await.unwrap();

            // First delivery must be accepted and show up in the mailbox.
            assert!(
                post_office
                    .deliver_untyped_response(res.to_untyped_response().unwrap().into_owned())
                    .await,
                "Failed to deliver: {res:?}"
            );
            assert_eq!(mailbox.next().await, Some(res.clone()));

            // A second delivery for the same origin id must also arrive.
            assert!(
                post_office
                    .deliver_untyped_response(res.to_untyped_response().unwrap().into_owned())
                    .await,
                "Failed to deliver: {res:?}"
            );
            assert_eq!(mailbox.next().await, Some(res.clone()));

            // Start waiting on the mailbox in a separate task, then yield so that task gets a
            // chance to run before the mailbox is cancelled.
            let next_task = tokio::spawn(async move { mailbox.next().await });
            tokio::task::yield_now().await;

            // Cancel the mailbox by its origin id.
            post_office.cancel(&res.origin_id).await;

            // The pending `next` must resolve to `None` rather than a response.
            match next_task.await {
                Ok(None) => {}
                x => panic!("Unexpected response: {:?}", x),
            }
        }

        #[test(tokio::test)]
        async fn send_should_wait_until_response_received() {
            let (mut channel, _server, post_office) = setup(100);

            let req = Request::new(0);
            let res = Response::new(req.id.clone(), 1);

            // Drive the send and the delivery of its response concurrently.
            let (actual, _) = tokio::join!(
                channel.send(req),
                post_office
                    .deliver_untyped_response(res.to_untyped_response().unwrap().into_owned())
            );
            match actual {
                Ok(actual) => assert_eq!(actual, res),
                x => panic!("Unexpected response: {:?}", x),
            }
        }

        #[test(tokio::test)]
        async fn send_timeout_should_fail_if_response_not_received_in_time() {
            let (mut channel, mut server, _post_office) = setup(100);

            // No response is ever delivered, so the send must time out.
            let req = Request::new(0);
            match channel.send_timeout(req, Duration::from_millis(30)).await {
                Err(x) => assert_eq!(x.kind(), io::ErrorKind::TimedOut),
                x => panic!("Unexpected response: {:?}", x),
            }

            // The request itself must still have reached the receiving end.
            let _frame = server.recv().await.unwrap();
        }

        #[test(tokio::test)]
        async fn fire_should_send_request_and_not_wait_for_response() {
            let (mut channel, mut server, _post_office) = setup(100);

            // `fire` must complete even though no response is ever delivered.
            let req = Request::new(0);
            match channel.fire(req).await {
                Ok(_) => {}
                x => panic!("Unexpected response: {:?}", x),
            }

            // The request must have reached the receiving end.
            let _frame = server.recv().await.unwrap();
        }
    }
455
    // Tests exercising `UntypedChannel` directly.
    mod untyped {
        use std::sync::Arc;
        use std::time::Duration;

        use test_log::test;

        use super::*;

        type TestChannel = UntypedChannel;
        // Channel under test, the receiving end of its request queue, and the post office
        // that backs its mailboxes.
        type Setup = (
            TestChannel,
            mpsc::Receiver<UntypedRequest<'static>>,
            Arc<PostOffice<UntypedResponse<'static>>>,
        );

        // Builds an untyped channel whose request queue holds up to `buffer` requests. The
        // post office is returned as a strong reference because the channel only keeps a weak
        // one.
        fn setup(buffer: usize) -> Setup {
            let post_office = Arc::new(PostOffice::default());
            let (tx, rx) = mpsc::channel(buffer);
            let channel = {
                let post_office = Arc::downgrade(&post_office);
                TestChannel { tx, post_office }
            };

            (channel, rx, post_office)
        }

        #[test(tokio::test)]
        async fn mail_should_return_mailbox_that_receives_responses_until_post_office_drops_it() {
            let (mut channel, _server, post_office) = setup(100);

            // Build owned untyped forms; the response's origin id is the request's id.
            let req = Request::new(0).to_untyped_request().unwrap().into_owned();
            let res = Response::new(req.id.clone().into_owned(), 1)
                .to_untyped_response()
                .unwrap()
                .into_owned();

            let mut mailbox = channel.mail(req).await.unwrap();

            // First delivery must be accepted and show up in the mailbox.
            assert!(
                post_office.deliver_untyped_response(res.clone()).await,
                "Failed to deliver: {res:?}"
            );
            assert_eq!(mailbox.next().await, Some(res.clone()));

            // A second delivery for the same origin id must also arrive.
            assert!(
                post_office.deliver_untyped_response(res.clone()).await,
                "Failed to deliver: {res:?}"
            );
            assert_eq!(mailbox.next().await, Some(res.clone()));

            // Start waiting on the mailbox in a separate task, then yield so that task gets a
            // chance to run before the mailbox is cancelled.
            let next_task = tokio::spawn(async move { mailbox.next().await });
            tokio::task::yield_now().await;

            // Cancel the mailbox by its origin id.
            post_office
                .cancel(&res.origin_id.clone().into_owned())
                .await;

            // The pending `next` must resolve to `None` rather than a response.
            match next_task.await {
                Ok(None) => {}
                x => panic!("Unexpected response: {:?}", x),
            }
        }

        #[test(tokio::test)]
        async fn send_should_wait_until_response_received() {
            let (mut channel, _server, post_office) = setup(100);

            let req = Request::new(0).to_untyped_request().unwrap().into_owned();
            let res = Response::new(req.id.clone().into_owned(), 1)
                .to_untyped_response()
                .unwrap()
                .into_owned();

            // Drive the send and the delivery of its response concurrently.
            let (actual, _) = tokio::join!(
                channel.send(req),
                post_office.deliver_untyped_response(res.clone())
            );
            match actual {
                Ok(actual) => assert_eq!(actual, res),
                x => panic!("Unexpected response: {:?}", x),
            }
        }

        #[test(tokio::test)]
        async fn send_timeout_should_fail_if_response_not_received_in_time() {
            let (mut channel, mut server, _post_office) = setup(100);

            // No response is ever delivered, so the send must time out.
            let req = Request::new(0).to_untyped_request().unwrap().into_owned();
            match channel.send_timeout(req, Duration::from_millis(30)).await {
                Err(x) => assert_eq!(x.kind(), io::ErrorKind::TimedOut),
                x => panic!("Unexpected response: {:?}", x),
            }

            // The request itself must still have reached the receiving end.
            let _frame = server.recv().await.unwrap();
        }

        #[test(tokio::test)]
        async fn fire_should_send_request_and_not_wait_for_response() {
            let (mut channel, mut server, _post_office) = setup(100);

            // `fire` must complete even though no response is ever delivered.
            let req = Request::new(0).to_untyped_request().unwrap().into_owned();
            match channel.fire(req).await {
                Ok(_) => {}
                x => panic!("Unexpected response: {:?}", x),
            }

            // The request must have reached the receiving end.
            let _frame = server.recv().await.unwrap();
        }
    }
570}