1use std::marker::PhantomData;
2use std::sync::Weak;
3use std::{convert, fmt, io};
4
5use log::*;
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8use tokio::sync::mpsc;
9use tokio::time::Duration;
10
11use crate::common::{Request, Response, UntypedRequest, UntypedResponse};
12
13mod mailbox;
14pub use mailbox::*;
15
/// Capacity passed to `PostOffice::make_mailbox` for every mailbox created by
/// [`UntypedChannel::mail`].
const CHANNEL_MAILBOX_CAPACITY: usize = 10_000;
18
19pub struct Channel<T, U> {
25    inner: UntypedChannel,
26    _request: PhantomData<T>,
27    _response: PhantomData<U>,
28}
29
30impl<T, U> Clone for Channel<T, U> {
32    fn clone(&self) -> Self {
33        Self {
34            inner: self.inner.clone(),
35            _request: self._request,
36            _response: self._response,
37        }
38    }
39}
40
41impl<T, U> fmt::Debug for Channel<T, U> {
42    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
43        f.debug_struct("Channel")
44            .field("tx", &self.inner.tx)
45            .field("post_office", &self.inner.post_office)
46            .field("_request", &self._request)
47            .field("_response", &self._response)
48            .finish()
49    }
50}
51
52impl<T, U> Channel<T, U>
53where
54    T: Send + Sync + Serialize + 'static,
55    U: Send + Sync + DeserializeOwned + 'static,
56{
57    pub fn is_closed(&self) -> bool {
59        self.inner.is_closed()
60    }
61
62    pub fn into_untyped_channel(self) -> UntypedChannel {
64        self.inner
65    }
66
67    pub async fn assign_default_mailbox(&self, buffer: usize) -> io::Result<Mailbox<Response<U>>> {
69        Ok(map_to_typed_mailbox(
70            self.inner.assign_default_mailbox(buffer).await?,
71        ))
72    }
73
74    pub async fn remove_default_mailbox(&self) -> io::Result<()> {
77        self.inner.remove_default_mailbox().await
78    }
79
80    pub async fn mail(&mut self, req: impl Into<Request<T>>) -> io::Result<Mailbox<Response<U>>> {
84        Ok(map_to_typed_mailbox(
85            self.inner.mail(req.into().to_untyped_request()?).await?,
86        ))
87    }
88
89    pub async fn mail_timeout(
91        &mut self,
92        req: impl Into<Request<T>>,
93        duration: impl Into<Option<Duration>>,
94    ) -> io::Result<Mailbox<Response<U>>> {
95        match duration.into() {
96            Some(duration) => tokio::time::timeout(duration, self.mail(req))
97                .await
98                .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
99                .and_then(convert::identity),
100            None => self.mail(req).await,
101        }
102    }
103
104    pub async fn send(&mut self, req: impl Into<Request<T>>) -> io::Result<Response<U>> {
107        let mut mailbox = self.mail(req).await?;
109
110        mailbox
112            .next()
113            .await
114            .ok_or_else(|| io::Error::from(io::ErrorKind::ConnectionAborted))
115    }
116
117    pub async fn send_timeout(
119        &mut self,
120        req: impl Into<Request<T>>,
121        duration: impl Into<Option<Duration>>,
122    ) -> io::Result<Response<U>> {
123        match duration.into() {
124            Some(duration) => tokio::time::timeout(duration, self.send(req))
125                .await
126                .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
127                .and_then(convert::identity),
128            None => self.send(req).await,
129        }
130    }
131
132    pub async fn fire(&mut self, req: impl Into<Request<T>>) -> io::Result<()> {
135        self.inner.fire(req.into().to_untyped_request()?).await
136    }
137
138    pub async fn fire_timeout(
140        &mut self,
141        req: impl Into<Request<T>>,
142        duration: impl Into<Option<Duration>>,
143    ) -> io::Result<()> {
144        match duration.into() {
145            Some(duration) => tokio::time::timeout(duration, self.fire(req))
146                .await
147                .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
148                .and_then(convert::identity),
149            None => self.fire(req).await,
150        }
151    }
152}
153
154fn map_to_typed_mailbox<T: Send + DeserializeOwned + 'static>(
155    mailbox: Mailbox<UntypedResponse<'static>>,
156) -> Mailbox<Response<T>> {
157    mailbox.map_opt(|res| match res.to_typed_response() {
158        Ok(res) => Some(res),
159        Err(x) => {
160            if log::log_enabled!(Level::Trace) {
161                trace!(
162                    "Invalid response payload: {}",
163                    String::from_utf8_lossy(&res.payload)
164                );
165            }
166
167            error!(
168                "Unable to parse response payload into {}: {x}",
169                std::any::type_name::<T>()
170            );
171            None
172        }
173    })
174}
175
/// Channel that sends [`UntypedRequest`]s and receives the matching [`UntypedResponse`]s
/// through mailboxes handed out by a [`PostOffice`]. [`Channel`] wraps it to provide typed
/// request/response payloads.
#[derive(Debug)]
pub struct UntypedChannel {
    /// Outgoing request queue; [`UntypedChannel::fire`] pushes owned requests into it.
    pub(crate) tx: mpsc::Sender<UntypedRequest<'static>>,

    /// Post office used to create and remove mailboxes. It is held weakly, so this channel does
    /// not keep it alive; operations that need it fail with `NotConnected` once it is dropped.
    pub(crate) post_office: Weak<PostOffice<UntypedResponse<'static>>>,
}
192
193impl Clone for UntypedChannel {
195    fn clone(&self) -> Self {
196        Self {
197            tx: self.tx.clone(),
198            post_office: Weak::clone(&self.post_office),
199        }
200    }
201}
202
203impl UntypedChannel {
204    pub fn is_closed(&self) -> bool {
206        self.tx.is_closed()
207    }
208
209    pub fn into_typed_channel<T, U>(self) -> Channel<T, U> {
211        Channel {
212            inner: self,
213            _request: PhantomData,
214            _response: PhantomData,
215        }
216    }
217
218    pub async fn assign_default_mailbox(
220        &self,
221        buffer: usize,
222    ) -> io::Result<Mailbox<UntypedResponse<'static>>> {
223        match Weak::upgrade(&self.post_office) {
224            Some(post_office) => Ok(post_office.assign_default_mailbox(buffer).await),
225            None => Err(io::Error::new(
226                io::ErrorKind::NotConnected,
227                "Channel's post office is no longer available",
228            )),
229        }
230    }
231
232    pub async fn remove_default_mailbox(&self) -> io::Result<()> {
235        match Weak::upgrade(&self.post_office) {
236            Some(post_office) => {
237                post_office.remove_default_mailbox().await;
238                Ok(())
239            }
240            None => Err(io::Error::new(
241                io::ErrorKind::NotConnected,
242                "Channel's post office is no longer available",
243            )),
244        }
245    }
246
247    pub async fn mail(
251        &mut self,
252        req: UntypedRequest<'_>,
253    ) -> io::Result<Mailbox<UntypedResponse<'static>>> {
254        let mailbox = Weak::upgrade(&self.post_office)
256            .ok_or_else(|| {
257                io::Error::new(
258                    io::ErrorKind::NotConnected,
259                    "Channel's post office is no longer available",
260                )
261            })?
262            .make_mailbox(req.id.clone().into_owned(), CHANNEL_MAILBOX_CAPACITY)
263            .await;
264
265        self.fire(req).await?;
267
268        Ok(mailbox)
270    }
271
272    pub async fn mail_timeout(
274        &mut self,
275        req: UntypedRequest<'_>,
276        duration: impl Into<Option<Duration>>,
277    ) -> io::Result<Mailbox<UntypedResponse<'static>>> {
278        match duration.into() {
279            Some(duration) => tokio::time::timeout(duration, self.mail(req))
280                .await
281                .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
282                .and_then(convert::identity),
283            None => self.mail(req).await,
284        }
285    }
286
287    pub async fn send(&mut self, req: UntypedRequest<'_>) -> io::Result<UntypedResponse<'static>> {
290        let mut mailbox = self.mail(req).await?;
292
293        mailbox
295            .next()
296            .await
297            .ok_or_else(|| io::Error::from(io::ErrorKind::ConnectionAborted))
298    }
299
300    pub async fn send_timeout(
302        &mut self,
303        req: UntypedRequest<'_>,
304        duration: impl Into<Option<Duration>>,
305    ) -> io::Result<UntypedResponse<'static>> {
306        match duration.into() {
307            Some(duration) => tokio::time::timeout(duration, self.send(req))
308                .await
309                .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
310                .and_then(convert::identity),
311            None => self.send(req).await,
312        }
313    }
314
315    pub async fn fire(&mut self, req: UntypedRequest<'_>) -> io::Result<()> {
318        self.tx
319            .send(req.into_owned())
320            .await
321            .map_err(|x| io::Error::new(io::ErrorKind::BrokenPipe, x.to_string()))
322    }
323
324    pub async fn fire_timeout(
326        &mut self,
327        req: UntypedRequest<'_>,
328        duration: impl Into<Option<Duration>>,
329    ) -> io::Result<()> {
330        match duration.into() {
331            Some(duration) => tokio::time::timeout(duration, self.fire(req))
332                .await
333                .map_err(|x| io::Error::new(io::ErrorKind::TimedOut, x))
334                .and_then(convert::identity),
335            None => self.fire(req).await,
336        }
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    // Tests for the typed `Channel<T, U>` wrapper.
    mod typed {
        use std::sync::Arc;
        use std::time::Duration;

        use test_log::test;

        use super::*;

        type TestChannel = Channel<u8, u8>;
        type Setup = (
            TestChannel,
            mpsc::Receiver<UntypedRequest<'static>>,
            Arc<PostOffice<UntypedResponse<'static>>>,
        );

        /// Builds a typed channel plus the receiving end of its request queue (standing in for
        /// the server) and a strong handle to its post office.
        ///
        /// The channel only holds a weak reference to the post office, so the returned `Arc`
        /// is what keeps the post office alive for the test.
        fn setup(buffer: usize) -> Setup {
            let post_office = Arc::new(PostOffice::default());
            let (tx, rx) = mpsc::channel(buffer);
            let channel = {
                let post_office = Arc::downgrade(&post_office);
                UntypedChannel { tx, post_office }
            };

            (channel.into_typed_channel(), rx, post_office)
        }

        #[test(tokio::test)]
        async fn mail_should_return_mailbox_that_receives_responses_until_post_office_drops_it() {
            let (mut channel, _server, post_office) = setup(100);

            let req = Request::new(0);
            let res = Response::new(req.id.clone(), 1);

            let mut mailbox = channel.mail(req).await.unwrap();

            // A response addressed to the request's id reaches the mailbox.
            assert!(
                post_office
                    .deliver_untyped_response(res.to_untyped_response().unwrap().into_owned())
                    .await,
                "Failed to deliver: {res:?}"
            );
            assert_eq!(mailbox.next().await, Some(res.clone()));

            // The mailbox stays open for further responses with the same id.
            assert!(
                post_office
                    .deliver_untyped_response(res.to_untyped_response().unwrap().into_owned())
                    .await,
                "Failed to deliver: {res:?}"
            );
            assert_eq!(mailbox.next().await, Some(res.clone()));

            // Wait for the next item in a separate task, and yield so that task likely starts
            // waiting before the mailbox is cancelled.
            let next_task = tokio::spawn(async move { mailbox.next().await });
            tokio::task::yield_now().await;

            // Cancelling the id in the post office should end the mailbox stream.
            post_office.cancel(&res.origin_id).await;

            match next_task.await {
                Ok(None) => {}
                x => panic!("Unexpected response: {:?}", x),
            }
        }

        #[test(tokio::test)]
        async fn send_should_wait_until_response_received() {
            let (mut channel, _server, post_office) = setup(100);

            let req = Request::new(0);
            let res = Response::new(req.id.clone(), 1);

            // Poll the send and the delivery concurrently; send must resolve with the response.
            let (actual, _) = tokio::join!(
                channel.send(req),
                post_office
                    .deliver_untyped_response(res.to_untyped_response().unwrap().into_owned())
            );
            match actual {
                Ok(actual) => assert_eq!(actual, res),
                x => panic!("Unexpected response: {:?}", x),
            }
        }

        #[test(tokio::test)]
        async fn send_timeout_should_fail_if_response_not_received_in_time() {
            let (mut channel, mut server, _post_office) = setup(100);

            // Nothing ever delivers a response, so the call must time out.
            let req = Request::new(0);
            match channel.send_timeout(req, Duration::from_millis(30)).await {
                Err(x) => assert_eq!(x.kind(), io::ErrorKind::TimedOut),
                x => panic!("Unexpected response: {:?}", x),
            }

            // The request itself should still have reached the server side.
            let _frame = server.recv().await.unwrap();
        }

        #[test(tokio::test)]
        async fn fire_should_send_request_and_not_wait_for_response() {
            let (mut channel, mut server, _post_office) = setup(100);

            let req = Request::new(0);
            match channel.fire(req).await {
                Ok(_) => {}
                x => panic!("Unexpected response: {:?}", x),
            }

            // The fired request is observable on the server side.
            let _frame = server.recv().await.unwrap();
        }
    }

    // Tests for `UntypedChannel`, mirroring the typed tests above.
    mod untyped {
        use std::sync::Arc;
        use std::time::Duration;

        use test_log::test;

        use super::*;

        type TestChannel = UntypedChannel;
        type Setup = (
            TestChannel,
            mpsc::Receiver<UntypedRequest<'static>>,
            Arc<PostOffice<UntypedResponse<'static>>>,
        );

        /// Builds an untyped channel plus the receiving end of its request queue (standing in
        /// for the server) and a strong handle to its post office.
        ///
        /// The channel only holds a weak reference, so the returned `Arc` keeps the post office
        /// alive.
        fn setup(buffer: usize) -> Setup {
            let post_office = Arc::new(PostOffice::default());
            let (tx, rx) = mpsc::channel(buffer);
            let channel = {
                let post_office = Arc::downgrade(&post_office);
                TestChannel { tx, post_office }
            };

            (channel, rx, post_office)
        }

        #[test(tokio::test)]
        async fn mail_should_return_mailbox_that_receives_responses_until_post_office_drops_it() {
            let (mut channel, _server, post_office) = setup(100);

            let req = Request::new(0).to_untyped_request().unwrap().into_owned();
            let res = Response::new(req.id.clone().into_owned(), 1)
                .to_untyped_response()
                .unwrap()
                .into_owned();

            let mut mailbox = channel.mail(req).await.unwrap();

            // A response addressed to the request's id reaches the mailbox.
            assert!(
                post_office.deliver_untyped_response(res.clone()).await,
                "Failed to deliver: {res:?}"
            );
            assert_eq!(mailbox.next().await, Some(res.clone()));

            // The mailbox stays open for further responses with the same id.
            assert!(
                post_office.deliver_untyped_response(res.clone()).await,
                "Failed to deliver: {res:?}"
            );
            assert_eq!(mailbox.next().await, Some(res.clone()));

            // Wait for the next item in a separate task, and yield so that task likely starts
            // waiting before the mailbox is cancelled.
            let next_task = tokio::spawn(async move { mailbox.next().await });
            tokio::task::yield_now().await;

            // Cancelling the id in the post office should end the mailbox stream.
            post_office
                .cancel(&res.origin_id.clone().into_owned())
                .await;

            match next_task.await {
                Ok(None) => {}
                x => panic!("Unexpected response: {:?}", x),
            }
        }

        #[test(tokio::test)]
        async fn send_should_wait_until_response_received() {
            let (mut channel, _server, post_office) = setup(100);

            let req = Request::new(0).to_untyped_request().unwrap().into_owned();
            let res = Response::new(req.id.clone().into_owned(), 1)
                .to_untyped_response()
                .unwrap()
                .into_owned();

            // Poll the send and the delivery concurrently; send must resolve with the response.
            let (actual, _) = tokio::join!(
                channel.send(req),
                post_office.deliver_untyped_response(res.clone())
            );
            match actual {
                Ok(actual) => assert_eq!(actual, res),
                x => panic!("Unexpected response: {:?}", x),
            }
        }

        #[test(tokio::test)]
        async fn send_timeout_should_fail_if_response_not_received_in_time() {
            let (mut channel, mut server, _post_office) = setup(100);

            // Nothing ever delivers a response, so the call must time out.
            let req = Request::new(0).to_untyped_request().unwrap().into_owned();
            match channel.send_timeout(req, Duration::from_millis(30)).await {
                Err(x) => assert_eq!(x.kind(), io::ErrorKind::TimedOut),
                x => panic!("Unexpected response: {:?}", x),
            }

            // The request itself should still have reached the server side.
            let _frame = server.recv().await.unwrap();
        }

        #[test(tokio::test)]
        async fn fire_should_send_request_and_not_wait_for_response() {
            let (mut channel, mut server, _post_office) = setup(100);

            let req = Request::new(0).to_untyped_request().unwrap().into_owned();
            match channel.fire(req).await {
                Ok(_) => {}
                x => panic!("Unexpected response: {:?}", x),
            }

            // The fired request is observable on the server side.
            let _frame = server.recv().await.unwrap();
        }
    }
}