1use crate::{
8 context,
9 util::{Compact, TimeUntil},
10 ClientMessage, PollIo, Request, Response, Transport,
11};
12use fnv::FnvHashMap;
13use futures::{
14 channel::{mpsc, oneshot},
15 prelude::*,
16 ready,
17 stream::Fuse,
18 task::Context,
19 Poll,
20};
21use log::{debug, info, trace};
22use pin_utils::{unsafe_pinned, unsafe_unpinned};
23use std::{
24 io,
25 marker::Unpin,
26 pin::Pin,
27 sync::{
28 atomic::{AtomicU64, Ordering},
29 Arc,
30 },
31};
32use tokio_timer::{timeout, Timeout};
33use trace::SpanId;
34
35use super::{Config, NewClient};
36
37#[derive(Debug)]
39pub struct Channel<Req, Resp> {
40 to_dispatch: mpsc::Sender<DispatchRequest<Req, Resp>>,
41 cancellation: RequestCancellation,
43 next_request_id: Arc<AtomicU64>,
45}
46
47impl<Req, Resp> Clone for Channel<Req, Resp> {
48 fn clone(&self) -> Self {
49 Self {
50 to_dispatch: self.to_dispatch.clone(),
51 cancellation: self.cancellation.clone(),
52 next_request_id: self.next_request_id.clone(),
53 }
54 }
55}
56
57#[derive(Debug)]
59#[must_use = "futures do nothing unless polled"]
60struct Send<'a, Req, Resp> {
61 fut: MapOkDispatchResponse<SendMapErrConnectionReset<'a, Req, Resp>, Resp>,
62}
63
64type SendMapErrConnectionReset<'a, Req, Resp> = MapErrConnectionReset<
65 futures::sink::Send<'a, mpsc::Sender<DispatchRequest<Req, Resp>>, DispatchRequest<Req, Resp>>,
66>;
67
68impl<'a, Req, Resp> Send<'a, Req, Resp> {
69 unsafe_pinned!(
70 fut: MapOkDispatchResponse<
71 MapErrConnectionReset<
72 futures::sink::Send<
73 'a,
74 mpsc::Sender<DispatchRequest<Req, Resp>>,
75 DispatchRequest<Req, Resp>,
76 >,
77 >,
78 Resp,
79 >
80 );
81}
82
83impl<'a, Req, Resp> Future for Send<'a, Req, Resp> {
84 type Output = io::Result<DispatchResponse<Resp>>;
85
86 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
87 self.as_mut().fut().poll(cx)
88 }
89}
90
91#[derive(Debug)]
93#[must_use = "futures do nothing unless polled"]
94pub struct Call<'a, Req, Resp> {
95 fut: AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>,
96}
97
98impl<'a, Req, Resp> Call<'a, Req, Resp> {
99 unsafe_pinned!(fut: AndThenIdent<Send<'a, Req, Resp>, DispatchResponse<Resp>>);
100}
101
102impl<'a, Req, Resp> Future for Call<'a, Req, Resp> {
103 type Output = io::Result<Resp>;
104
105 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
106 self.as_mut().fut().poll(cx)
107 }
108}
109
110impl<Req, Resp> Channel<Req, Resp> {
111 fn send(&mut self, mut ctx: context::Context, request: Req) -> Send<Req, Resp> {
114 ctx.trace_context.parent_id = Some(ctx.trace_context.span_id);
116 ctx.trace_context.span_id = SpanId::random(&mut rand::thread_rng());
117
118 let timeout = ctx.deadline.time_until();
119 trace!(
120 "[{}] Queuing request with timeout {:?}.",
121 ctx.trace_id(),
122 timeout,
123 );
124
125 let (response_completion, response) = oneshot::channel();
126 let cancellation = self.cancellation.clone();
127 let request_id = self.next_request_id.fetch_add(1, Ordering::Relaxed);
128 Send {
129 fut: MapOkDispatchResponse::new(
130 MapErrConnectionReset::new(self.to_dispatch.send(DispatchRequest {
131 ctx,
132 request_id,
133 request,
134 response_completion,
135 })),
136 DispatchResponse {
137 response: Timeout::new(response, timeout),
138 complete: false,
139 request_id,
140 cancellation,
141 ctx,
142 },
143 ),
144 }
145 }
146
147 pub fn call(&mut self, context: context::Context, request: Req) -> Call<Req, Resp> {
150 Call {
151 fut: AndThenIdent::new(self.send(context, request)),
152 }
153 }
154}
155
156#[derive(Debug)]
159struct DispatchResponse<Resp> {
160 response: Timeout<oneshot::Receiver<Response<Resp>>>,
161 ctx: context::Context,
162 complete: bool,
163 cancellation: RequestCancellation,
164 request_id: u64,
165}
166
167impl<Resp> DispatchResponse<Resp> {
168 unsafe_pinned!(ctx: context::Context);
169}
170
171impl<Resp> Future for DispatchResponse<Resp> {
172 type Output = io::Result<Resp>;
173
174 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<Resp>> {
175 let resp = ready!(self.response.poll_unpin(cx));
176
177 Poll::Ready(match resp {
178 Ok(resp) => {
179 self.complete = true;
180 match resp {
181 Ok(resp) => Ok(resp.message?),
182 Err(oneshot::Canceled) => {
183 Err(io::Error::from(io::ErrorKind::ConnectionReset))
187 }
188 }
189 }
190 Err(timeout::Elapsed { .. }) => Err(io::Error::new(
191 io::ErrorKind::TimedOut,
192 "Client dropped expired request.".to_string(),
193 )),
194 })
195 }
196}
197
198impl<Resp> Drop for DispatchResponse<Resp> {
200 fn drop(&mut self) {
201 if !self.complete {
202 self.response.get_mut().close();
213 self.cancellation.cancel(self.request_id);
214 }
215 }
216}
217
218pub fn new<Req, Resp, C>(
221 config: Config,
222 transport: C,
223) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C>>
224where
225 C: Transport<ClientMessage<Req>, Response<Resp>>,
226{
227 let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
228 let (cancellation, canceled_requests) = cancellations();
229 let canceled_requests = canceled_requests.fuse();
230
231 NewClient {
232 client: Channel {
233 to_dispatch,
234 cancellation,
235 next_request_id: Arc::new(AtomicU64::new(0)),
236 },
237 dispatch: RequestDispatch {
238 config,
239 canceled_requests,
240 transport: transport.fuse(),
241 in_flight_requests: FnvHashMap::default(),
242 pending_requests: pending_requests.fuse(),
243 },
244 }
245}
246
/// Future that drives a client connection.
///
/// Each poll reads responses from the transport and hands them to the waiting
/// response futures. It also writes queued requests and cancellation messages to the
/// transport.
#[derive(Debug)]
pub struct RequestDispatch<Req, Resp, C> {
    /// Connection to the server: requests are written to it and responses read from it.
    transport: Fuse<C>,
    /// Requests queued by `Channel`s that have not yet been written.
    pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>,
    /// Ids of requests whose response futures were dropped before completing.
    canceled_requests: Fuse<CanceledRequests>,
    /// Requests that have been written but not yet answered, keyed by request id.
    in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>,
    /// Client configuration; `max_in_flight_requests` caps `in_flight_requests`.
    config: Config,
}

impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
where
    C: Transport<ClientMessage<Req>, Response<Resp>>,
{
    // Pin projections for the fields accessed through `Pin<&mut Self>`.
    unsafe_pinned!(in_flight_requests: FnvHashMap<u64, InFlightData<Resp>>);
    unsafe_pinned!(canceled_requests: Fuse<CanceledRequests>);
    unsafe_pinned!(pending_requests: Fuse<mpsc::Receiver<DispatchRequest<Req, Resp>>>);
    unsafe_pinned!(transport: Fuse<C>);

    /// Reads at most one response from the transport and completes the matching
    /// in-flight request.
    ///
    /// Returns `Ready(Some(Ok(())))` after handling a response (whether or not a matching
    /// request was in flight) and `Ready(None)` once the transport's read half is
    /// exhausted. Transport errors are propagated with `?`.
    fn pump_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
        Poll::Ready(match ready!(self.as_mut().transport().poll_next(cx)?) {
            Some(response) => {
                self.complete(response);
                Some(Ok(()))
            }
            None => None,
        })
    }

    /// Writes at most one message to the transport.
    ///
    /// A queued request is preferred; otherwise a pending cancellation is written.
    /// Returns `Ready(Some(Ok(())))` after starting to send one message. Returns
    /// `Ready(None)` once both the request queue and the cancellation stream are closed
    /// and the transport has been flushed. Otherwise it flushes the transport and
    /// returns `Pending`.
    fn pump_write(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> PollIo<()> {
        enum ReceiverStatus {
            NotReady,
            Closed,
        }

        let pending_requests_status = match self.as_mut().poll_next_request(cx)? {
            Poll::Ready(Some(dispatch_request)) => {
                self.as_mut().write_request(dispatch_request)?;
                return Poll::Ready(Some(Ok(())));
            }
            Poll::Ready(None) => ReceiverStatus::Closed,
            Poll::Pending => ReceiverStatus::NotReady,
        };

        let canceled_requests_status = match self.as_mut().poll_next_cancellation(cx)? {
            Poll::Ready(Some((context, request_id))) => {
                self.as_mut().write_cancel(context, request_id)?;
                return Poll::Ready(Some(Ok(())));
            }
            Poll::Ready(None) => ReceiverStatus::Closed,
            Poll::Pending => ReceiverStatus::NotReady,
        };

        match (pending_requests_status, canceled_requests_status) {
            (ReceiverStatus::Closed, ReceiverStatus::Closed) => {
                // Nothing more will ever be written; flush what is buffered, then finish.
                ready!(self.as_mut().transport().poll_flush(cx)?);
                Poll::Ready(None)
            }
            (ReceiverStatus::NotReady, _) | (_, ReceiverStatus::NotReady) => {
                // Nothing to write right now; push out anything already buffered.
                ready!(self.as_mut().transport().poll_flush(cx)?);

                Poll::Pending
            }
        }
    }

    /// Returns the next queued request whose response future still wants it.
    ///
    /// Requests whose response receiver has been closed are skipped without being sent.
    /// Returns `Ready(None)` once the request queue is closed.
    fn poll_next_request(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> PollIo<DispatchRequest<Req, Resp>> {
        if self.as_mut().in_flight_requests().len() >= self.config.max_in_flight_requests {
            info!(
                "At in-flight request capacity ({}/{}).",
                self.as_mut().in_flight_requests().len(),
                self.config.max_in_flight_requests
            );

            // No wakeup is registered on the request queue here. Capacity is freed by
            // responses (read in `pump_read`) and by cancellations
            // (`poll_next_cancellation`), both of which are polled in the same pass.
            return Poll::Pending;
        }

        // Make sure the transport can take a message before dequeuing a request,
        // flushing buffered output while it is not ready.
        while let Poll::Pending = self.as_mut().transport().poll_ready(cx)? {
            ready!(self.as_mut().transport().poll_flush(cx)?);
        }

        loop {
            match ready!(self.as_mut().pending_requests().poll_next_unpin(cx)) {
                Some(request) => {
                    // The response future was dropped before the request was sent.
                    if request.response_completion.is_canceled() {
                        trace!(
                            "[{}] Request canceled before being sent.",
                            request.ctx.trace_id()
                        );
                        continue;
                    }

                    return Poll::Ready(Some(Ok(request)));
                }
                None => return Poll::Ready(None),
            }
        }
    }

    /// Returns the next cancellation for a request that is still in flight, removing
    /// that request from the in-flight map.
    ///
    /// Cancellations for ids not in the map (already answered, or never sent) are
    /// skipped. Returns `Ready(None)` once the cancellation stream is closed.
    fn poll_next_cancellation(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> PollIo<(context::Context, u64)> {
        // Make sure the transport can take the cancel message before dequeuing one.
        while let Poll::Pending = self.as_mut().transport().poll_ready(cx)? {
            ready!(self.as_mut().transport().poll_flush(cx)?);
        }

        loop {
            let cancellation = self.as_mut().canceled_requests().poll_next_unpin(cx);
            match ready!(cancellation) {
                Some(request_id) => {
                    if let Some(in_flight_data) =
                        self.as_mut().in_flight_requests().remove(&request_id)
                    {
                        // NOTE(review): `compact` comes from `util::Compact`; presumably it
                        // shrinks the map's capacity relative to this ratio. Confirm there.
                        self.as_mut().in_flight_requests().compact(0.1);
                        debug!("[{}] Removed request.", in_flight_data.ctx.trace_id());
                        return Poll::Ready(Some(Ok((in_flight_data.ctx, request_id))));
                    }
                }
                None => return Poll::Ready(None),
            }
        }
    }

    /// Starts sending `dispatch_request` over the transport and records it as in flight.
    fn write_request(
        mut self: Pin<&mut Self>,
        dispatch_request: DispatchRequest<Req, Resp>,
    ) -> io::Result<()> {
        let request_id = dispatch_request.request_id;
        let request = ClientMessage::Request(Request {
            id: request_id,
            message: dispatch_request.request,
            context: context::Context {
                deadline: dispatch_request.ctx.deadline,
                trace_context: dispatch_request.ctx.trace_context,
            },
        });
        self.as_mut().transport().start_send(request)?;
        self.as_mut().in_flight_requests().insert(
            request_id,
            InFlightData {
                ctx: dispatch_request.ctx,
                response_completion: dispatch_request.response_completion,
            },
        );
        Ok(())
    }

    /// Starts sending a cancel message for `request_id` over the transport.
    fn write_cancel(
        mut self: Pin<&mut Self>,
        context: context::Context,
        request_id: u64,
    ) -> io::Result<()> {
        let trace_id = *context.trace_id();
        let cancel = ClientMessage::Cancel {
            trace_context: context.trace_context,
            request_id,
        };
        self.as_mut().transport().start_send(cancel)?;
        trace!("[{}] Cancel message sent.", trace_id);
        Ok(())
    }

    /// Delivers `response` to the response future of the matching in-flight request.
    ///
    /// Returns `true` if a matching request was in flight, `false` otherwise.
    fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
        if let Some(in_flight_data) = self
            .as_mut()
            .in_flight_requests()
            .remove(&response.request_id)
        {
            self.as_mut().in_flight_requests().compact(0.1);

            trace!("[{}] Received response.", in_flight_data.ctx.trace_id());
            // The receiver may already be gone; there is nobody to notify in that case.
            let _ = in_flight_data.response_completion.send(response);
            return true;
        }

        debug!(
            "No in-flight request found for request_id = {}.",
            response.request_id
        );

        false
    }
}
450
451impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
452where
453 C: Transport<ClientMessage<Req>, Response<Resp>>,
454{
455 type Output = io::Result<()>;
456
457 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
458 loop {
459 match (self.as_mut().pump_read(cx)?, self.as_mut().pump_write(cx)?) {
460 (read, Poll::Ready(None)) => {
461 if self.as_mut().in_flight_requests().is_empty() {
462 info!("Shutdown: write half closed, and no requests in flight.");
463 return Poll::Ready(Ok(()));
464 }
465 info!(
466 "Shutdown: write half closed, and {} requests in flight.",
467 self.as_mut().in_flight_requests().len()
468 );
469 match read {
470 Poll::Ready(Some(())) => continue,
471 _ => return Poll::Pending,
472 }
473 }
474 (Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
475 _ => return Poll::Pending,
476 }
477 }
478 }
479}
480
/// A request queued by a `Channel`, waiting to be written by the dispatch future.
#[derive(Debug)]
struct DispatchRequest<Req, Resp> {
    ctx: context::Context,
    request_id: u64,
    request: Req,
    /// Sender for the response; the matching receiver is held by the `DispatchResponse`.
    response_completion: oneshot::Sender<Response<Resp>>,
}

/// Bookkeeping for a request that has been written to the transport and is awaiting
/// its response.
#[derive(Debug)]
struct InFlightData<Resp> {
    ctx: context::Context,
    response_completion: oneshot::Sender<Response<Resp>>,
}
496
497#[derive(Debug, Clone)]
499struct RequestCancellation(mpsc::UnboundedSender<u64>);
500
501#[derive(Debug)]
503struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
504
505fn cancellations() -> (RequestCancellation, CanceledRequests) {
507 let (tx, rx) = mpsc::unbounded();
512 (RequestCancellation(tx), CanceledRequests(rx))
513}
514
515impl RequestCancellation {
516 fn cancel(&mut self, request_id: u64) {
518 let _ = self.0.unbounded_send(request_id);
519 }
520}
521
522impl Stream for CanceledRequests {
523 type Item = u64;
524
525 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<u64>> {
526 self.0.poll_next_unpin(cx)
527 }
528}
529
530#[derive(Debug)]
531#[must_use = "futures do nothing unless polled"]
532struct MapErrConnectionReset<Fut> {
533 future: Fut,
534 finished: Option<()>,
535}
536
537impl<Fut> MapErrConnectionReset<Fut> {
538 unsafe_pinned!(future: Fut);
539 unsafe_unpinned!(finished: Option<()>);
540
541 fn new(future: Fut) -> MapErrConnectionReset<Fut> {
542 MapErrConnectionReset {
543 future,
544 finished: Some(()),
545 }
546 }
547}
548
549impl<Fut: Unpin> Unpin for MapErrConnectionReset<Fut> {}
550
551impl<Fut> Future for MapErrConnectionReset<Fut>
552where
553 Fut: TryFuture,
554{
555 type Output = io::Result<Fut::Ok>;
556
557 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
558 match self.as_mut().future().try_poll(cx) {
559 Poll::Pending => Poll::Pending,
560 Poll::Ready(result) => {
561 self.finished().take().expect(
562 "MapErrConnectionReset must not be polled after it returned `Poll::Ready`",
563 );
564 Poll::Ready(result.map_err(|_| io::Error::from(io::ErrorKind::ConnectionReset)))
565 }
566 }
567 }
568}
569
570#[derive(Debug)]
571#[must_use = "futures do nothing unless polled"]
572struct MapOkDispatchResponse<Fut, Resp> {
573 future: Fut,
574 response: Option<DispatchResponse<Resp>>,
575}
576
577impl<Fut, Resp> MapOkDispatchResponse<Fut, Resp> {
578 unsafe_pinned!(future: Fut);
579 unsafe_unpinned!(response: Option<DispatchResponse<Resp>>);
580
581 fn new(future: Fut, response: DispatchResponse<Resp>) -> MapOkDispatchResponse<Fut, Resp> {
582 MapOkDispatchResponse {
583 future,
584 response: Some(response),
585 }
586 }
587}
588
589impl<Fut: Unpin, Resp> Unpin for MapOkDispatchResponse<Fut, Resp> {}
590
591impl<Fut, Resp> Future for MapOkDispatchResponse<Fut, Resp>
592where
593 Fut: TryFuture,
594{
595 type Output = Result<DispatchResponse<Resp>, Fut::Error>;
596
597 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
598 match self.as_mut().future().try_poll(cx) {
599 Poll::Pending => Poll::Pending,
600 Poll::Ready(result) => {
601 let response = self
602 .as_mut()
603 .response()
604 .take()
605 .expect("MapOk must not be polled after it returned `Poll::Ready`");
606 Poll::Ready(result.map(|_| response))
607 }
608 }
609 }
610}
611
612#[derive(Debug)]
613#[must_use = "futures do nothing unless polled"]
614struct AndThenIdent<Fut1, Fut2> {
615 try_chain: TryChain<Fut1, Fut2>,
616}
617
618impl<Fut1, Fut2> AndThenIdent<Fut1, Fut2>
619where
620 Fut1: TryFuture<Ok = Fut2>,
621 Fut2: TryFuture,
622{
623 unsafe_pinned!(try_chain: TryChain<Fut1, Fut2>);
624
625 fn new(future: Fut1) -> AndThenIdent<Fut1, Fut2> {
627 AndThenIdent {
628 try_chain: TryChain::new(future),
629 }
630 }
631}
632
633impl<Fut1, Fut2> Future for AndThenIdent<Fut1, Fut2>
634where
635 Fut1: TryFuture<Ok = Fut2>,
636 Fut2: TryFuture<Error = Fut1::Error>,
637{
638 type Output = Result<Fut2::Ok, Fut2::Error>;
639
640 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
641 self.try_chain().poll(cx, |result| match result {
642 Ok(ok) => TryChainAction::Future(ok),
643 Err(err) => TryChainAction::Output(Err(err)),
644 })
645 }
646}
647
648#[must_use = "futures do nothing unless polled"]
649#[derive(Debug)]
650enum TryChain<Fut1, Fut2> {
651 First(Fut1),
652 Second(Fut2),
653 Empty,
654}
655
656enum TryChainAction<Fut2>
657where
658 Fut2: TryFuture,
659{
660 Future(Fut2),
661 Output(Result<Fut2::Ok, Fut2::Error>),
662}
663
664impl<Fut1, Fut2> TryChain<Fut1, Fut2>
665where
666 Fut1: TryFuture<Ok = Fut2>,
667 Fut2: TryFuture,
668{
669 fn new(fut1: Fut1) -> TryChain<Fut1, Fut2> {
670 TryChain::First(fut1)
671 }
672
673 fn poll<F>(
674 self: Pin<&mut Self>,
675 cx: &mut Context<'_>,
676 f: F,
677 ) -> Poll<Result<Fut2::Ok, Fut2::Error>>
678 where
679 F: FnOnce(Result<Fut1::Ok, Fut1::Error>) -> TryChainAction<Fut2>,
680 {
681 let mut f = Some(f);
682
683 let this = unsafe { Pin::get_unchecked_mut(self) };
685
686 loop {
687 let output = match this {
688 TryChain::First(fut1) => {
689 match unsafe { Pin::new_unchecked(fut1) }.try_poll(cx) {
691 Poll::Pending => return Poll::Pending,
692 Poll::Ready(output) => output,
693 }
694 }
695 TryChain::Second(fut2) => {
696 return unsafe { Pin::new_unchecked(fut2) }.try_poll(cx);
698 }
699 TryChain::Empty => {
700 panic!("future must not be polled after it returned `Poll::Ready`");
701 }
702 };
703
704 *this = TryChain::Empty; let f = f.take().unwrap();
706 match f(output) {
707 TryChainAction::Future(fut2) => *this = TryChain::Second(fut2),
708 TryChainAction::Output(output) => return Poll::Ready(output),
709 }
710 }
711 }
712}
713
#[cfg(test)]
mod tests {
    use super::{
        cancellations, CanceledRequests, Channel, DispatchResponse, RequestCancellation,
        RequestDispatch,
    };
    use crate::{
        client::Config,
        context,
        transport::{self, channel::UnboundedChannel},
        ClientMessage, Response,
    };
    use fnv::FnvHashMap;
    use futures::{
        channel::{mpsc, oneshot},
        prelude::*,
        task::Context,
        Poll,
    };
    use futures_test::task::noop_waker_ref;
    use std::time::Duration;
    use std::{pin::Pin, sync::atomic::AtomicU64, sync::Arc};
    use tokio::runtime::current_thread;
    use tokio_timer::Timeout;

    #[test]
    fn dispatch_response_cancels_on_timeout() {
        let (_response_completion, response) = oneshot::channel();
        let (cancellation, mut canceled_requests) = cancellations();
        let resp = DispatchResponse::<u64> {
            // Zero timeout: the deadline has already passed.
            response: Timeout::new(response, Duration::from_secs(0)),
            complete: false,
            request_id: 3,
            cancellation,
            ctx: context::current(),
        };
        {
            pin_utils::pin_mut!(resp);
            let timer = tokio_timer::Timer::default();
            let handle = timer.handle();
            let _guard = tokio_timer::set_default(&handle);

            let _ = resp
                .as_mut()
                .poll(&mut Context::from_waker(noop_waker_ref()));
        }
        // `resp` was dropped at the end of the scope above without completing, so it
        // must have queued a cancellation.
        assert_eq!(canceled_requests.0.try_next().unwrap(), Some(3));
    }

    #[test]
    fn stage_request() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let dispatch = Pin::new(&mut dispatch);
        let cx = &mut Context::from_waker(noop_waker_ref());

        let _resp = send_request(&mut channel, "hi");

        let req = dispatch.poll_next_request(cx).ready();
        assert!(req.is_some());

        let req = req.unwrap();
        assert_eq!(req.request_id, 0);
        assert_eq!(req.request, "hi".to_string());
    }

    fn block_on<F: Future>(f: F) -> F::Output {
        current_thread::Runtime::new().unwrap().block_on(f)
    }

    #[test]
    fn stage_request_channel_dropped_doesnt_panic() {
        let (mut dispatch, mut channel, mut server_channel) = set_up();
        let mut dispatch = Pin::new(&mut dispatch);
        let cx = &mut Context::from_waker(noop_waker_ref());

        let _ = send_request(&mut channel, "hi");
        drop(channel);

        assert!(dispatch.as_mut().poll(cx).is_ready());
        send_response(
            &mut server_channel,
            Response {
                request_id: 0,
                message: Ok("hello".into()),
            },
        );
        block_on(dispatch).unwrap();
    }

    #[test]
    fn stage_request_response_future_dropped_is_canceled_before_sending() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let dispatch = Pin::new(&mut dispatch);
        let cx = &mut Context::from_waker(noop_waker_ref());

        let _ = send_request(&mut channel, "hi");

        // The response future was dropped immediately, so the request is skipped and
        // the closed queue yields `None`.
        drop(channel);
        assert!(dispatch.poll_next_request(cx).ready().is_none());
    }

    #[test]
    fn stage_request_response_future_dropped_is_canceled_after_sending() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let cx = &mut Context::from_waker(noop_waker_ref());
        let mut dispatch = Pin::new(&mut dispatch);

        let req = send_request(&mut channel, "hi");

        assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
        assert!(!dispatch.as_mut().in_flight_requests().is_empty());

        drop(req);
        match dispatch.as_mut().poll_next_cancellation(cx).unwrap() {
            Poll::Ready(Some(_)) => {}
            _ => panic!("Expected request to be cancelled"),
        }
        assert!(dispatch.in_flight_requests().is_empty());
    }

    #[test]
    fn stage_request_response_closed_skipped() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let dispatch = Pin::new(&mut dispatch);
        let cx = &mut Context::from_waker(noop_waker_ref());

        // Closing the receiver marks the request as unwanted before it is sent.
        let mut resp = send_request(&mut channel, "hi");
        resp.response.get_mut().close();

        assert!(dispatch.poll_next_request(cx).is_pending());
    }

    fn set_up() -> (
        RequestDispatch<String, String, UnboundedChannel<Response<String>, ClientMessage<String>>>,
        Channel<String, String>,
        UnboundedChannel<ClientMessage<String>, Response<String>>,
    ) {
        let _ = env_logger::try_init();

        let (to_dispatch, pending_requests) = mpsc::channel(1);
        let (cancel_tx, canceled_requests) = mpsc::unbounded();
        let (client_channel, server_channel) = transport::channel::unbounded();

        let dispatch = RequestDispatch::<String, String, _> {
            transport: client_channel.fuse(),
            pending_requests: pending_requests.fuse(),
            canceled_requests: CanceledRequests(canceled_requests).fuse(),
            in_flight_requests: FnvHashMap::default(),
            config: Config::default(),
        };

        let cancellation = RequestCancellation(cancel_tx);
        let channel = Channel {
            to_dispatch,
            cancellation,
            next_request_id: Arc::new(AtomicU64::new(0)),
        };

        (dispatch, channel, server_channel)
    }

    fn send_request(
        channel: &mut Channel<String, String>,
        request: &str,
    ) -> DispatchResponse<String> {
        block_on(channel.send(context::current(), request.to_string())).unwrap()
    }

    fn send_response(
        channel: &mut UnboundedChannel<ClientMessage<String>, Response<String>>,
        response: Response<String>,
    ) {
        block_on(channel.send(response)).unwrap();
    }

    /// Test helpers for unpacking `Poll<Option<Result<T, E>>>` values.
    trait PollTest {
        type T;
        fn unwrap(self) -> Poll<Self::T>;
        fn ready(self) -> Self::T;
    }

    impl<T, E> PollTest for Poll<Option<Result<T, E>>>
    where
        E: ::std::fmt::Display,
    {
        type T = Option<T>;

        fn unwrap(self) -> Poll<Option<T>> {
            match self {
                Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)),
                Poll::Ready(None) => Poll::Ready(None),
                Poll::Ready(Some(Err(e))) => panic!("{}", e),
                Poll::Pending => Poll::Pending,
            }
        }

        fn ready(self) -> Option<T> {
            match self {
                Poll::Ready(Some(Ok(t))) => Some(t),
                Poll::Ready(None) => None,
                Poll::Ready(Some(Err(e))) => panic!("{}", e),
                Poll::Pending => panic!("Pending"),
            }
        }
    }
}