forked_tarpc/
client.rs

1// Copyright 2018 Google LLC
2//
3// Use of this source code is governed by an MIT-style
4// license that can be found in the LICENSE file or at
5// https://opensource.org/licenses/MIT.
6
7//! Provides a client that connects to a server and sends multiplexed requests.
8
9mod in_flight_requests;
10
11use crate::{
12    cancellations::{cancellations, CanceledRequests, RequestCancellation},
13    context, trace, ClientMessage, Request, Response, ServerError, Transport,
14};
15use futures::{prelude::*, ready, stream::Fuse, task::*};
16use in_flight_requests::{DeadlineExceededError, InFlightRequests};
17use pin_project::pin_project;
18use std::{
19    convert::TryFrom,
20    error::Error,
21    fmt,
22    pin::Pin,
23    sync::{
24        atomic::{AtomicUsize, Ordering},
25        Arc,
26    },
27};
28use tokio::sync::{mpsc, oneshot};
29use tracing::Span;
30
/// Settings that control the behavior of the client.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct Config {
    /// The number of requests that can be in flight at once.
    /// `max_in_flight_requests` controls the size of the map used by the client
    /// for storing pending requests. While that many requests are in flight, request dispatch
    /// stops writing new requests to the transport, so a value of `0` prevents any request
    /// from ever being sent.
    pub max_in_flight_requests: usize,
    /// The number of requests that can be buffered client-side before being sent.
    /// `pending_request_buffer` controls the size of the channel clients use
    /// to communicate with the request dispatch task. It must be greater than zero: it is used
    /// as the capacity of a bounded `tokio::sync::mpsc` channel, which rejects a zero capacity.
    pub pending_request_buffer: usize,
}

impl Default for Config {
    /// Allows 1,000 requests in flight and buffers up to 100 not-yet-sent requests.
    fn default() -> Self {
        Self {
            max_in_flight_requests: 1_000,
            pending_request_buffer: 100,
        }
    }
}
53
/// A freshly created client together with the dispatch that drives it.
///
/// The dispatch performs the actual sending and receiving of requests, so it must be polled
/// continuously or spawned.
pub struct NewClient<C, D> {
    /// The new client.
    pub client: C,
    /// The client's dispatch.
    pub dispatch: D,
}

impl<C, D, E> NewClient<C, D>
where
    D: Future<Output = Result<(), E>> + Send + 'static,
    E: std::error::Error + Send + Sync + 'static,
{
    /// Spawns the dispatch with `tokio::spawn` and returns the client.
    ///
    /// An error produced by the dispatch is logged at `warn` level rather than propagated.
    #[cfg(feature = "tokio1")]
    #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
    pub fn spawn(self) -> C {
        let NewClient { client, dispatch } = self;
        let logged_dispatch = dispatch.unwrap_or_else(|err| {
            let err = anyhow::Error::new(err);
            tracing::warn!("Connection broken: {:?}", err);
        });
        tokio::spawn(logged_dispatch);
        client
    }
}

impl<C, D> fmt::Debug for NewClient<C, D> {
    // Only the type name is printed, so neither `C` nor `D` has to implement `Debug`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("NewClient")
    }
}
86
/// Compile-time guard backing the `u64::try_from(..).unwrap()` conversions of the `usize`
/// request-ID counter: the conversion can never fail as long as `usize` is no wider than `u64`.
const _CHECK_USIZE: () = assert!(usize::BITS <= u64::BITS, "usize is too big to fit in u64");
91
92/// Handles communication from the client to request dispatch.
93#[derive(Debug)]
94pub struct Channel<Req, Resp> {
95    to_dispatch: mpsc::Sender<DispatchRequest<Req, Resp>>,
96    /// Channel to send a cancel message to the dispatcher.
97    cancellation: RequestCancellation,
98    /// The ID to use for the next request to stage.
99    next_request_id: Arc<AtomicUsize>,
100}
101
102impl<Req, Resp> Clone for Channel<Req, Resp> {
103    fn clone(&self) -> Self {
104        Self {
105            to_dispatch: self.to_dispatch.clone(),
106            cancellation: self.cancellation.clone(),
107            next_request_id: self.next_request_id.clone(),
108        }
109    }
110}
111
112impl<Req, Resp> Channel<Req, Resp> {
113    /// Sends a request to the dispatch task to forward to the server, returning a [`Future`] that
114    /// resolves to the response.
115    #[tracing::instrument(
116        name = "RPC",
117        skip(self, ctx, request_name, request),
118        fields(
119            rpc.trace_id = tracing::field::Empty,
120            rpc.deadline = %humantime::format_rfc3339(ctx.deadline),
121            otel.kind = "client",
122            otel.name = request_name)
123        )]
124    pub async fn call(
125        &self,
126        mut ctx: context::Context,
127        request_name: &'static str,
128        request: Req,
129    ) -> Result<Resp, RpcError> {
130        let span = Span::current();
131        ctx.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
132            tracing::trace!(
133                "OpenTelemetry subscriber not installed; making unsampled child context."
134            );
135            ctx.trace_context.new_child()
136        });
137        span.record("rpc.trace_id", &tracing::field::display(ctx.trace_id()));
138        let (response_completion, mut response) = oneshot::channel();
139        let request_id =
140            u64::try_from(self.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
141
142        // ResponseGuard impls Drop to cancel in-flight requests. It should be created before
143        // sending out the request; otherwise, the response future could be dropped after the
144        // request is sent out but before ResponseGuard is created, rendering the cancellation
145        // logic inactive.
146        let response_guard = ResponseGuard {
147            response: &mut response,
148            request_id,
149            cancellation: &self.cancellation,
150            cancel: true,
151        };
152        self.to_dispatch
153            .send(DispatchRequest {
154                ctx,
155                span,
156                request_id,
157                request,
158                response_completion,
159            })
160            .await
161            .map_err(|mpsc::error::SendError(_)| RpcError::Disconnected)?;
162        response_guard.response().await
163    }
164}
165
/// A server response that is completed by request dispatch when the corresponding response
/// arrives off the wire.
///
/// While `cancel` is `true`, dropping the guard sends a cancellation for `request_id`.
struct ResponseGuard<'a, Resp> {
    /// Receives the response, or a deadline-exceeded error, from request dispatch.
    response: &'a mut oneshot::Receiver<Result<Response<Resp>, DeadlineExceededError>>,
    /// Handle used to tell request dispatch that this request was abandoned.
    cancellation: &'a RequestCancellation,
    /// The ID of the request this guard is waiting on.
    request_id: u64,
    /// Whether dropping the guard should send a cancellation; cleared once a result is received.
    cancel: bool,
}
174
175/// An error that can occur in the processing of an RPC. This is not request-specific errors but
176/// rather cross-cutting errors that can always occur.
177#[derive(thiserror::Error, Clone, Debug, PartialEq, Eq, Hash)]
178#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
179pub enum RpcError {
180    /// The client disconnected from the server.
181    #[error("the client disconnected from the server")]
182    Disconnected,
183    /// The request exceeded its deadline.
184    #[error("the request exceeded its deadline")]
185    DeadlineExceeded,
186    /// The server aborted request processing.
187    #[error("the server aborted request processing")]
188    Server(#[from] ServerError),
189}
190
191impl From<DeadlineExceededError> for RpcError {
192    fn from(_: DeadlineExceededError) -> Self {
193        RpcError::DeadlineExceeded
194    }
195}
196
197impl<Resp> ResponseGuard<'_, Resp> {
198    async fn response(mut self) -> Result<Resp, RpcError> {
199        let response = (&mut self.response).await;
200        // Cancel drop logic once a response has been received.
201        self.cancel = false;
202        match response {
203            Ok(resp) => Ok(resp?.message?),
204            Err(oneshot::error::RecvError { .. }) => {
205                // The oneshot is Canceled when the dispatch task ends. In that case,
206                // there's nothing listening on the other side, so there's no point in
207                // propagating cancellation.
208                Err(RpcError::Disconnected)
209            }
210        }
211    }
212}
213
214// Cancels the request when dropped, if not already complete.
215impl<Resp> Drop for ResponseGuard<'_, Resp> {
216    fn drop(&mut self) {
217        // The receiver needs to be closed to handle the edge case that the request has not
218        // yet been received by the dispatch task. It is possible for the cancel message to
219        // arrive before the request itself, in which case the request could get stuck in the
220        // dispatch map forever if the server never responds (e.g. if the server dies while
221        // responding). Even if the server does respond, it will have unnecessarily done work
222        // for a client no longer waiting for a response. To avoid this, the dispatch task
223        // checks if the receiver is closed before inserting the request in the map. By
224        // closing the receiver before sending the cancel message, it is guaranteed that if the
225        // dispatch task misses an early-arriving cancellation message, then it will see the
226        // receiver as closed.
227        self.response.close();
228        if self.cancel {
229            self.cancellation.cancel(self.request_id);
230        }
231    }
232}
233
234/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the
235/// channel.
236pub fn new<Req, Resp, C>(
237    config: Config,
238    transport: C,
239) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C>>
240where
241    C: Transport<ClientMessage<Req>, Response<Resp>>,
242{
243    let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
244    let (cancellation, canceled_requests) = cancellations();
245    let canceled_requests = canceled_requests;
246
247    NewClient {
248        client: Channel {
249            to_dispatch,
250            cancellation,
251            next_request_id: Arc::new(AtomicUsize::new(0)),
252        },
253        dispatch: RequestDispatch {
254            config,
255            canceled_requests,
256            transport: transport.fuse(),
257            in_flight_requests: InFlightRequests::default(),
258            pending_requests,
259        },
260    }
261}
262
/// Handles the lifecycle of requests, writing requests to the wire, managing cancellations,
/// and dispatching responses to the appropriate channel.
///
/// This type implements `Future`; it must be polled (or spawned) for requests issued through
/// the paired [`Channel`] to be written and for their responses to be delivered.
#[must_use]
#[pin_project]
#[derive(Debug)]
pub struct RequestDispatch<Req, Resp, C> {
    /// Writes requests to the wire and reads responses off the wire.
    #[pin]
    transport: Fuse<C>,
    /// Requests waiting to be written to the wire.
    pending_requests: mpsc::Receiver<DispatchRequest<Req, Resp>>,
    /// IDs of requests that were dropped by their callers and may need a cancel message.
    canceled_requests: CanceledRequests,
    /// Requests already written to the wire that haven't yet received responses.
    in_flight_requests: InFlightRequests<Resp>,
    /// Configures limits to prevent unlimited resource usage.
    config: Config,
}
281
/// Critical errors that result in a Channel disconnecting.
///
/// When request dispatch hits one of these, its future resolves with the error and stops.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError<E>
where
    E: Error + Send + Sync + 'static,
{
    /// Could not read from the transport (polling it for the next response failed).
    #[error("could not read from the transport")]
    Read(#[source] E),
    /// The transport failed while being readied for writes (`poll_ready` failed).
    #[error("could not ready the transport for writes")]
    Ready(#[source] E),
    /// Could not write to the transport (`start_send` failed).
    #[error("could not write to the transport")]
    Write(#[source] E),
    /// Could not flush the transport.
    #[error("could not flush the transport")]
    Flush(#[source] E),
    /// Could not close the write end of the transport.
    #[error("could not close the write end of the transport")]
    Close(#[source] E),
    /// Could not poll expired requests.
    #[error("could not poll expired requests")]
    Timer(#[source] tokio::time::error::Error),
}
307
impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
where
    C: Transport<ClientMessage<Req>, Response<Resp>>,
{
    /// Projects to the map of requests written to the wire but not yet answered.
    fn in_flight_requests<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests<Resp> {
        self.as_mut().project().in_flight_requests
    }

    /// Projects to the pinned transport.
    fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<C>> {
        self.as_mut().project().transport
    }

    /// Polls the transport for write readiness, tagging failures as [`ChannelError::Ready`].
    fn poll_ready<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), ChannelError<C::Error>>> {
        self.transport_pin_mut()
            .poll_ready(cx)
            .map_err(ChannelError::Ready)
    }

    /// Hands `message` to the transport, tagging failures as [`ChannelError::Write`].
    ///
    /// Callers are expected to have seen `poll_ready` succeed first (see `ensure_writeable`).
    fn start_send(
        self: &mut Pin<&mut Self>,
        message: ClientMessage<Req>,
    ) -> Result<(), ChannelError<C::Error>> {
        self.transport_pin_mut()
            .start_send(message)
            .map_err(ChannelError::Write)
    }

    /// Flushes the transport, tagging failures as [`ChannelError::Flush`].
    fn poll_flush<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), ChannelError<C::Error>>> {
        self.transport_pin_mut()
            .poll_flush(cx)
            .map_err(ChannelError::Flush)
    }

    /// Closes the transport's write half, tagging failures as [`ChannelError::Close`].
    fn poll_close<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Result<(), ChannelError<C::Error>>> {
        self.transport_pin_mut()
            .poll_close(cx)
            .map_err(ChannelError::Close)
    }

    /// Projects to the stream of IDs of requests whose callers dropped them.
    fn canceled_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut CanceledRequests {
        self.as_mut().project().canceled_requests
    }

    /// Projects to the receiver of requests waiting to be written.
    fn pending_requests_mut<'a>(
        self: &'a mut Pin<&mut Self>,
    ) -> &'a mut mpsc::Receiver<DispatchRequest<Req, Resp>> {
        self.as_mut().project().pending_requests
    }

    /// Reads at most one response off the transport and hands it to the waiting request.
    ///
    /// Returns `Ready(None)` once the transport's read half has ended.
    fn pump_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
        self.transport_pin_mut()
            .poll_next(cx)
            .map_err(ChannelError::Read)
            .map_ok(|response| {
                self.complete(response);
            })
    }

    /// Makes one unit of write-side progress: writes one request, writes one cancellation, or
    /// expires one in-flight request.
    ///
    /// Returns `Ready(Some(Ok(())))` when progress was made. Returns `Ready(None)` once both the
    /// pending-request and cancellation channels are closed and the transport has been closed.
    /// Otherwise returns `Pending`, after trying to flush anything buffered in the transport.
    fn pump_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
        enum ReceiverStatus {
            Pending,
            Closed,
        }

        let pending_requests_status = match self.as_mut().poll_write_request(cx)? {
            Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
            Poll::Ready(None) => ReceiverStatus::Closed,
            Poll::Pending => ReceiverStatus::Pending,
        };

        let canceled_requests_status = match self.as_mut().poll_write_cancel(cx)? {
            Poll::Ready(Some(())) => return Poll::Ready(Some(Ok(()))),
            Poll::Ready(None) => ReceiverStatus::Closed,
            Poll::Pending => ReceiverStatus::Pending,
        };

        // Receiving Poll::Ready(None) when polling expired requests never indicates "Closed",
        // because there can temporarily be zero in-flight requests. Therefore, there is no need to
        // track the status like is done with pending and cancelled requests.
        if let Poll::Ready(Some(_)) = self.in_flight_requests().poll_expired(cx) {
            // Expired requests are considered complete; there is no compelling reason to send a
            // cancellation message to the server, since it will have already exhausted its
            // allotted processing time.
            return Poll::Ready(Some(Ok(())));
        }

        match (pending_requests_status, canceled_requests_status) {
            (ReceiverStatus::Closed, ReceiverStatus::Closed) => {
                ready!(self.poll_close(cx)?);
                Poll::Ready(None)
            }
            (ReceiverStatus::Pending, _) | (_, ReceiverStatus::Pending) => {
                // No more messages to process, so flush any messages buffered in the transport.
                ready!(self.poll_flush(cx)?);

                // Even if we fully-flush, we return Pending, because we have no more requests
                // or cancellations right now.
                Poll::Pending
            }
        }
    }

    /// Yields the next pending request, if one is ready to be sent.
    ///
    /// Note that a request will only be yielded if the transport is *ready* to be written to (i.e.
    /// start_send would succeed). Requests whose response receiver is already closed (the caller
    /// gave up before the request was sent) are skipped without being written.
    fn poll_next_request(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<DispatchRequest<Req, Resp>, ChannelError<C::Error>>>> {
        if self.in_flight_requests().len() >= self.config.max_in_flight_requests {
            tracing::info!(
                "At in-flight request capacity ({}/{}).",
                self.in_flight_requests().len(),
                self.config.max_in_flight_requests
            );

            // No need to schedule a wakeup, because timers and responses are responsible
            // for clearing out in-flight requests.
            return Poll::Pending;
        }

        ready!(self.ensure_writeable(cx)?);

        loop {
            match ready!(self.pending_requests_mut().poll_recv(cx)) {
                Some(request) => {
                    // The caller's ResponseGuard closes this receiver when dropped, so a closed
                    // receiver means nobody is waiting for the response.
                    if request.response_completion.is_closed() {
                        let _entered = request.span.enter();
                        tracing::info!("AbortRequest");
                        continue;
                    }

                    return Poll::Ready(Some(Ok(request)));
                }
                None => return Poll::Ready(None),
            }
        }
    }

    /// Yields the next pending cancellation, and, if one is ready, cancels the associated request.
    ///
    /// Note that a request to cancel will only be yielded if the transport is *ready* to be
    /// written to (i.e. start_send would succeed). IDs that `cancel_request` does not find in the
    /// in-flight map are skipped.
    fn poll_next_cancellation(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(context::Context, Span, u64), ChannelError<C::Error>>>> {
        ready!(self.ensure_writeable(cx)?);

        loop {
            match ready!(self.canceled_requests_mut().poll_next_unpin(cx)) {
                Some(request_id) => {
                    if let Some((ctx, span)) = self.in_flight_requests().cancel_request(request_id)
                    {
                        return Poll::Ready(Some(Ok((ctx, span, request_id))));
                    }
                }
                None => return Poll::Ready(None),
            }
        }
    }

    /// Returns Ready if writing a message to the transport (i.e. via write_request or
    /// write_cancel) would not fail due to a full buffer. If the transport is not ready to be
    /// written to, flushes it until it is ready.
    fn ensure_writeable<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
        while self.poll_ready(cx)?.is_pending() {
            ready!(self.poll_flush(cx)?);
        }
        Poll::Ready(Some(Ok(())))
    }

    /// Writes the next pending request to the transport and records it as in flight.
    ///
    /// Returns `Ready(None)` once the pending-request channel is closed.
    fn poll_write_request<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
        let DispatchRequest {
            ctx,
            span,
            request_id,
            request,
            response_completion,
        } = match ready!(self.as_mut().poll_next_request(cx)?) {
            Some(dispatch_request) => dispatch_request,
            None => return Poll::Ready(None),
        };
        let entered = span.enter();
        // poll_next_request only returns Ready if there is room to buffer another request.
        // Therefore, we can call write_request without fear of erroring due to a full
        // buffer.
        let request_id = request_id;
        let request = ClientMessage::Request(Request {
            id: request_id,
            message: request,
            context: context::Context {
                deadline: ctx.deadline,
                // The caller is awaiting this response via its ResponseGuard.
                discard_response: false,
                trace_context: ctx.trace_context,
            },
        });
        self.start_send(request)?;
        tracing::info!("SendRequest");
        // Exit the span before `span` is moved into the in-flight map below.
        drop(entered);

        self.in_flight_requests()
            .insert_request(request_id, ctx, span, response_completion)
            .expect("Request IDs should be unique");
        Poll::Ready(Some(Ok(())))
    }

    /// Writes a cancel message for the next canceled in-flight request.
    ///
    /// Returns `Ready(None)` once the cancellation stream has ended.
    fn poll_write_cancel<'a>(
        self: &'a mut Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<(), ChannelError<C::Error>>>> {
        let (context, span, request_id) = match ready!(self.as_mut().poll_next_cancellation(cx)?) {
            Some(triple) => triple,
            None => return Poll::Ready(None),
        };
        let _entered = span.enter();

        let cancel = ClientMessage::Cancel {
            trace_context: context.trace_context,
            request_id,
        };
        self.start_send(cancel)?;
        tracing::info!("CancelRequest");
        Poll::Ready(Some(Ok(())))
    }

    /// Sends a server response to the client task that initiated the associated request.
    ///
    /// Returns whatever `InFlightRequests::complete_request` reports; presumably whether a
    /// matching in-flight request was found (TODO confirm). Callers in this file ignore it.
    fn complete(mut self: Pin<&mut Self>, response: Response<Resp>) -> bool {
        self.in_flight_requests().complete_request(response)
    }
}
561
562impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
563where
564    C: Transport<ClientMessage<Req>, Response<Resp>>,
565{
566    type Output = Result<(), ChannelError<C::Error>>;
567
568    fn poll(
569        mut self: Pin<&mut Self>,
570        cx: &mut Context<'_>,
571    ) -> Poll<Result<(), ChannelError<C::Error>>> {
572        loop {
573            match (self.as_mut().pump_read(cx)?, self.as_mut().pump_write(cx)?) {
574                (Poll::Ready(None), _) => {
575                    tracing::info!("Shutdown: read half closed, so shutting down.");
576                    return Poll::Ready(Ok(()));
577                }
578                (read, Poll::Ready(None)) => {
579                    if self.in_flight_requests.is_empty() {
580                        tracing::info!("Shutdown: write half closed, and no requests in flight.");
581                        return Poll::Ready(Ok(()));
582                    }
583                    tracing::info!(
584                        "Shutdown: write half closed, and {} requests in flight.",
585                        self.in_flight_requests().len()
586                    );
587                    match read {
588                        Poll::Ready(Some(())) => continue,
589                        _ => return Poll::Pending,
590                    }
591                }
592                (Poll::Ready(Some(())), _) | (_, Poll::Ready(Some(()))) => {}
593                _ => return Poll::Pending,
594            }
595        }
596    }
597}
598
/// A server-bound request sent from a [`Channel`] to request dispatch, which will then manage
/// the lifecycle of the request.
#[derive(Debug)]
struct DispatchRequest<Req, Resp> {
    /// The caller's request context; its deadline and trace context are copied onto the wire.
    pub ctx: context::Context,
    /// The tracing span of the originating `Channel::call`.
    pub span: Span,
    /// Client-assigned ID used to match the response and any cancellation to this request.
    pub request_id: u64,
    /// The request payload.
    pub request: Req,
    /// Delivers the response, or a deadline-exceeded error, back to the waiting caller.
    pub response_completion: oneshot::Sender<Result<Response<Resp>, DeadlineExceededError>>,
}
609
#[cfg(test)]
mod tests {
    use super::{cancellations, Channel, DispatchRequest, RequestDispatch, ResponseGuard};
    use crate::{
        client::{
            in_flight_requests::{DeadlineExceededError, InFlightRequests},
            Config,
        },
        context,
        transport::{self, channel::UnboundedChannel},
        ClientMessage, Response,
    };
    use assert_matches::assert_matches;
    use futures::{prelude::*, task::*};
    use std::{
        convert::TryFrom,
        pin::Pin,
        sync::atomic::{AtomicUsize, Ordering},
        sync::Arc,
    };
    use tokio::sync::{mpsc, oneshot};
    use tracing::Span;

    // A response arriving on the transport for an in-flight request is delivered to that
    // request's oneshot receiver.
    #[tokio::test]
    async fn response_completes_request_future() {
        let (mut dispatch, mut _channel, mut server_channel) = set_up();
        let cx = &mut Context::from_waker(noop_waker_ref());
        let (tx, mut rx) = oneshot::channel();

        dispatch
            .in_flight_requests
            .insert_request(0, context::current(), Span::current(), tx)
            .unwrap();
        server_channel
            .send(Response {
                request_id: 0,
                message: Ok("Resp".into()),
            })
            .await
            .unwrap();
        assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
        assert_matches!(rx.try_recv(), Ok(Ok(Response { request_id: 0, message: Ok(resp) })) if resp == "Resp");
    }

    // Dropping an armed ResponseGuard emits a cancellation for its request ID.
    #[tokio::test]
    async fn dispatch_response_cancels_on_drop() {
        let (cancellation, mut canceled_requests) = cancellations();
        let (_, mut response) = oneshot::channel();
        drop(ResponseGuard::<u32> {
            response: &mut response,
            cancellation: &cancellation,
            request_id: 3,
            cancel: true,
        });
        // resp's drop() is run, which should send a cancel message.
        let cx = &mut Context::from_waker(noop_waker_ref());
        assert_eq!(canceled_requests.poll_recv(cx), Poll::Ready(Some(3)));
    }

    // Once a response has been received, dropping the guard must not emit a cancellation.
    #[tokio::test]
    async fn dispatch_response_doesnt_cancel_after_complete() {
        let (cancellation, mut canceled_requests) = cancellations();
        let (tx, mut response) = oneshot::channel();
        tx.send(Ok(Response {
            request_id: 0,
            message: Ok("well done"),
        }))
        .unwrap();
        // resp's drop() is run, but should not send a cancel message.
        ResponseGuard {
            response: &mut response,
            cancellation: &cancellation,
            request_id: 3,
            cancel: true,
        }
        .response()
        .await
        .unwrap();
        drop(cancellation);
        let cx = &mut Context::from_waker(noop_waker_ref());
        assert_eq!(canceled_requests.poll_recv(cx), Poll::Ready(None));
    }

    // A request sent through the channel is yielded by poll_next_request unchanged.
    #[tokio::test]
    async fn stage_request() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let cx = &mut Context::from_waker(noop_waker_ref());
        let (tx, mut rx) = oneshot::channel();

        let _resp = send_request(&mut channel, "hi", tx, &mut rx).await;

        // NOTE(review): presumably needed because the test helper `ready` collides with an
        // unstable inherent `Poll::ready` in std.
        #[allow(unstable_name_collisions)]
        let req = dispatch.as_mut().poll_next_request(cx).ready();
        assert!(req.is_some());

        let req = req.unwrap();
        assert_eq!(req.request_id, 0);
        assert_eq!(req.request, "hi".to_string());
    }

    // Regression test for  https://github.com/google/tarpc/issues/220
    #[tokio::test]
    async fn stage_request_channel_dropped_doesnt_panic() {
        let (mut dispatch, mut channel, mut server_channel) = set_up();
        let cx = &mut Context::from_waker(noop_waker_ref());
        let (tx, mut rx) = oneshot::channel();

        let _ = send_request(&mut channel, "hi", tx, &mut rx).await;
        drop(channel);

        assert!(dispatch.as_mut().poll(cx).is_ready());
        send_response(
            &mut server_channel,
            Response {
                request_id: 0,
                message: Ok("hello".into()),
            },
        )
        .await;
        dispatch.await.unwrap();
    }

    #[allow(unstable_name_collisions)]
    #[tokio::test]
    async fn stage_request_response_future_dropped_is_canceled_before_sending() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let cx = &mut Context::from_waker(noop_waker_ref());
        let (tx, mut rx) = oneshot::channel();

        let _ = send_request(&mut channel, "hi", tx, &mut rx).await;

        // Drop the channel so polling returns none if no requests are currently ready.
        drop(channel);
        // Test that a request future dropped before it's processed by dispatch will cause the request
        // to not be added to the in-flight request map.
        assert!(dispatch.as_mut().poll_next_request(cx).ready().is_none());
    }

    #[allow(unstable_name_collisions)]
    #[tokio::test]
    async fn stage_request_response_future_dropped_is_canceled_after_sending() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let cx = &mut Context::from_waker(noop_waker_ref());
        let (tx, mut rx) = oneshot::channel();

        let req = send_request(&mut channel, "hi", tx, &mut rx).await;

        assert!(dispatch.as_mut().pump_write(cx).ready().is_some());
        assert!(!dispatch.in_flight_requests.is_empty());

        // Test that a request future dropped after it's processed by dispatch will cause the request
        // to be removed from the in-flight request map.
        drop(req);
        assert_matches!(
            dispatch.as_mut().poll_next_cancellation(cx),
            Poll::Ready(Some(Ok(_)))
        );
        assert!(dispatch.in_flight_requests.is_empty());
    }

    #[tokio::test]
    async fn stage_request_response_closed_skipped() {
        let (mut dispatch, mut channel, _server_channel) = set_up();
        let cx = &mut Context::from_waker(noop_waker_ref());
        let (tx, mut rx) = oneshot::channel();

        // Test that a request future that's closed its receiver but not yet canceled its request --
        // i.e. still in `drop fn` -- will cause the request to not be added to the in-flight request
        // map.
        let resp = send_request(&mut channel, "hi", tx, &mut rx).await;
        resp.response.close();

        assert!(dispatch.as_mut().poll_next_request(cx).is_pending());
    }

    // Builds a dispatch wired to an in-memory unbounded transport. The pending-request channel
    // has capacity 1, and the default Config is used. Returns (dispatch, client channel,
    // server end of the transport).
    fn set_up() -> (
        Pin<
            Box<
                RequestDispatch<
                    String,
                    String,
                    UnboundedChannel<Response<String>, ClientMessage<String>>,
                >,
            >,
        >,
        Channel<String, String>,
        UnboundedChannel<ClientMessage<String>, Response<String>>,
    ) {
        let _ = tracing_subscriber::fmt().with_test_writer().try_init();

        let (to_dispatch, pending_requests) = mpsc::channel(1);
        let (cancellation, canceled_requests) = cancellations();
        let (client_channel, server_channel) = transport::channel::unbounded();

        let dispatch = RequestDispatch::<String, String, _> {
            transport: client_channel.fuse(),
            pending_requests,
            canceled_requests,
            in_flight_requests: InFlightRequests::default(),
            config: Config::default(),
        };

        let channel = Channel {
            to_dispatch,
            cancellation,
            next_request_id: Arc::new(AtomicUsize::new(0)),
        };

        (Box::pin(dispatch), channel, server_channel)
    }

    // Mirrors the send path of `Channel::call`: allocates an ID, arms a ResponseGuard, and
    // hands the request to dispatch. Returns the armed guard.
    async fn send_request<'a>(
        channel: &'a mut Channel<String, String>,
        request: &str,
        response_completion: oneshot::Sender<Result<Response<String>, DeadlineExceededError>>,
        response: &'a mut oneshot::Receiver<Result<Response<String>, DeadlineExceededError>>,
    ) -> ResponseGuard<'a, String> {
        let request_id =
            u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
        let request = DispatchRequest {
            ctx: context::current(),
            span: Span::current(),
            request_id,
            request: request.to_string(),
            response_completion,
        };
        let response_guard = ResponseGuard {
            response,
            cancellation: &channel.cancellation,
            request_id,
            cancel: true,
        };
        channel.to_dispatch.send(request).await.unwrap();
        response_guard
    }

    // Writes `response` from the server end of the in-memory transport.
    async fn send_response(
        channel: &mut UnboundedChannel<ClientMessage<String>, Response<String>>,
        response: Response<String>,
    ) {
        channel.send(response).await.unwrap();
    }

    // Test helpers for `Poll<Option<Result<T, E>>>`: both panic on `Err`; `ready` also panics
    // on `Pending`.
    trait PollTest {
        type T;
        fn unwrap(self) -> Poll<Self::T>;
        fn ready(self) -> Self::T;
    }

    impl<T, E> PollTest for Poll<Option<Result<T, E>>>
    where
        E: ::std::fmt::Display,
    {
        type T = Option<T>;

        fn unwrap(self) -> Poll<Option<T>> {
            match self {
                Poll::Ready(Some(Ok(t))) => Poll::Ready(Some(t)),
                Poll::Ready(None) => Poll::Ready(None),
                Poll::Ready(Some(Err(e))) => panic!("{}", e.to_string()),
                Poll::Pending => Poll::Pending,
            }
        }

        fn ready(self) -> Option<T> {
            match self {
                Poll::Ready(Some(Ok(t))) => Some(t),
                Poll::Ready(None) => None,
                Poll::Ready(Some(Err(e))) => panic!("{}", e.to_string()),
                Poll::Pending => panic!("Pending"),
            }
        }
    }
}