forked_tarpc/
server.rs

1// Copyright 2018 Google LLC
2//
3// Use of this source code is governed by an MIT-style
4// license that can be found in the LICENSE file or at
5// https://opensource.org/licenses/MIT.
6
7//! Provides a server that concurrently handles many connections sending multiplexed requests.
8
9use crate::{
10    cancellations::{cancellations, CanceledRequests, RequestCancellation},
11    context::{self},
12    trace, ClientMessage, Request, Response, Transport,
13};
14#[cfg(feature = "opentelemetry")]
15use crate::context::SpanExt;
16use ::tokio::sync::mpsc;
17use futures::{
18    future::{AbortRegistration, Abortable},
19    prelude::*,
20    ready,
21    stream::Fuse,
22    task::*,
23};
24use in_flight_requests::{AlreadyExistsError, InFlightRequests};
25use pin_project::pin_project;
26use std::{convert::TryFrom, error::Error, fmt, marker::PhantomData, pin::Pin};
27use tracing::{info_span, instrument::Instrument, Span};
28
// Bookkeeping for requests that a `BaseChannel` is currently processing (provides
// `InFlightRequests` and `AlreadyExistsError`).
mod in_flight_requests;
// Test-only support code; compiled only under `cfg(test)`.
#[cfg(test)]
mod testing;

/// Provides functionality to apply server limits.
pub mod limits;

/// Provides helper methods for streams of Channels.
pub mod incoming;

/// Provides convenience functionality for tokio-enabled applications.
#[cfg(feature = "tokio1")]
#[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
pub mod tokio;
43
/// Tunable settings for [channels](Channel).
#[derive(Clone, Debug)]
pub struct Config {
    /// Capacity of the in-process queue that carries responses from request handlers back to the
    /// [`Channel`]. Once this many responses are waiting to be written out, handlers sending
    /// further responses must wait for room. [`Channel::requests`] panics if this is zero.
    pub pending_response_buffer: usize,
}

impl Default for Config {
    /// Returns a configuration with room for 100 pending responses.
    fn default() -> Self {
        let pending_response_buffer = 100;
        Self {
            pending_response_buffer,
        }
    }
}
60
61impl Config {
62    /// Returns a channel backed by `transport` and configured with `self`.
63    pub fn channel<Req, Resp, T>(self, transport: T) -> BaseChannel<Req, Resp, T>
64    where
65        T: Transport<Response<Resp>, ClientMessage<Req>>,
66    {
67        BaseChannel::new(self, transport)
68    }
69}
70
71/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
72pub trait Serve<Req> {
73    /// Type of response.
74    type Resp;
75
76    /// Type of response future.
77    type Fut: Future<Output = Self::Resp>;
78
79    /// Extracts a method name from the request.
80    fn method(&self, _request: &Req) -> Option<&'static str> {
81        None
82    }
83
84    /// Responds to a single request.
85    fn serve(self, ctx: context::Context, req: Req) -> Self::Fut;
86}
87
88impl<Req, Resp, Fut, F> Serve<Req> for F
89where
90    F: FnOnce(context::Context, Req) -> Fut,
91    Fut: Future<Output = Resp>,
92{
93    type Resp = Resp;
94    type Fut = Fut;
95
96    fn serve(self, ctx: context::Context, req: Req) -> Self::Fut {
97        self(ctx, req)
98    }
99}
100
/// BaseChannel is the standard implementation of a [`Channel`].
///
/// BaseChannel manages a [`Transport`] of client [`messages`](ClientMessage) and
/// implements a [`Stream`] of [requests](TrackedRequest). See the [`Channel`] documentation for
/// how to use channels.
///
/// Besides requests, the other type of client message handled by `BaseChannel` is [cancellation
/// messages](ClientMessage::Cancel). `BaseChannel` does not allow direct access to cancellation
/// messages. Instead, it internally handles them by cancelling corresponding requests (removing
/// the corresponding in-flight requests and aborting their handlers).
#[pin_project]
pub struct BaseChannel<Req, Resp, T> {
    /// Settings this channel was created with.
    config: Config,
    /// Writes responses to the wire and reads requests off the wire.
    #[pin]
    transport: Fuse<T>,
    /// Receives the IDs of in-flight requests that were dropped by the server before completion.
    /// The IDs are sent by armed [`ResponseGuard`]s through `request_cancellation`.
    #[pin]
    canceled_requests: CanceledRequests,
    /// Notifies `canceled_requests` when a request is canceled. A clone is placed in the
    /// [`ResponseGuard`] of every request this channel starts.
    request_cancellation: RequestCancellation,
    /// Holds data necessary to clean up in-flight requests.
    in_flight_requests: InFlightRequests,
    /// Types the request and response.
    ghost: PhantomData<(Req, Resp)>,
}
127
128impl<Req, Resp, T> BaseChannel<Req, Resp, T>
129where
130    T: Transport<Response<Resp>, ClientMessage<Req>>,
131{
132    /// Creates a new channel backed by `transport` and configured with `config`.
133    pub fn new(config: Config, transport: T) -> Self {
134        let (request_cancellation, canceled_requests) = cancellations();
135        BaseChannel {
136            config,
137            transport: transport.fuse(),
138            canceled_requests,
139            request_cancellation,
140            in_flight_requests: InFlightRequests::default(),
141            ghost: PhantomData,
142        }
143    }
144
145    /// Creates a new channel backed by `transport` and configured with the defaults.
146    pub fn with_defaults(transport: T) -> Self {
147        Self::new(Config::default(), transport)
148    }
149
150    /// Returns the inner transport over which messages are sent and received.
151    pub fn get_ref(&self) -> &T {
152        self.transport.get_ref()
153    }
154
155    /// Returns the inner transport over which messages are sent and received.
156    pub fn get_pin_ref(self: Pin<&mut Self>) -> Pin<&mut T> {
157        self.project().transport.get_pin_mut()
158    }
159
160    fn in_flight_requests_mut<'a>(self: &'a mut Pin<&mut Self>) -> &'a mut InFlightRequests {
161        self.as_mut().project().in_flight_requests
162    }
163
164    fn canceled_requests_pin_mut<'a>(
165        self: &'a mut Pin<&mut Self>,
166    ) -> Pin<&'a mut CanceledRequests> {
167        self.as_mut().project().canceled_requests
168    }
169
170    fn transport_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut Fuse<T>> {
171        self.as_mut().project().transport
172    }
173
174    fn start_request(
175        mut self: Pin<&mut Self>,
176        mut request: Request<Req>,
177    ) -> Result<TrackedRequest<Req>, AlreadyExistsError> {
178        let span = info_span!(
179            "RPC",
180            rpc.trace_id = %request.context.trace_id(),
181            rpc.deadline = %humantime::format_rfc3339(request.context.deadline),
182            otel.kind = "server",
183            otel.name = tracing::field::Empty,
184        );
185        #[cfg(feature = "opentelemetry")]
186        span.set_context(&request.context);
187        request.context.trace_context = trace::Context::try_from(&span).unwrap_or_else(|_| {
188            tracing::trace!(
189                "OpenTelemetry subscriber not installed; making unsampled \
190                            child context."
191            );
192            request.context.trace_context.new_child()
193        });
194        let entered = span.enter();
195        tracing::info!("ReceiveRequest");
196        let start = self.in_flight_requests_mut().start_request(
197            request.id,
198            request.context.deadline,
199            span.clone(),
200        );
201        match start {
202            Ok(abort_registration) => {
203                drop(entered);
204                Ok(TrackedRequest {
205                    abort_registration,
206                    span,
207                    response_guard: ResponseGuard {
208                        request_id: request.id,
209                        request_cancellation: self.request_cancellation.clone(),
210                        cancel: false,
211                    },
212                    request,
213                })
214            }
215            Err(AlreadyExistsError) => {
216                tracing::trace!("DuplicateRequest");
217                Err(AlreadyExistsError)
218            }
219        }
220    }
221}
222
223impl<Req, Resp, T> fmt::Debug for BaseChannel<Req, Resp, T> {
224    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
225        write!(f, "BaseChannel")
226    }
227}
228
/// A request tracked by a [`Channel`].
///
/// Every `TrackedRequest` originates from [`BaseChannel`]; decorator channels pass them along.
#[derive(Debug)]
pub struct TrackedRequest<Req> {
    /// The request sent by the client.
    pub request: Request<Req>,
    /// A registration to abort a future when the [`Channel`] that produced this request stops
    /// tracking it.
    pub abort_registration: AbortRegistration,
    /// A span representing the server processing of this request.
    pub span: Span,
    /// An inert response guard: dropping it does not cancel the request. [`Requests`] arms it
    /// when converting this request into an [`InFlightRequest`].
    pub response_guard: ResponseGuard,
}
242
243/// The server end of an open connection with a client, receiving requests from, and sending
244/// responses to, the client. `Channel` is a [`Transport`] with request lifecycle management.
245///
246/// The ways to use a Channel, in order of simplest to most complex, is:
247/// 1. [`Channel::execute`] - Requires the `tokio1` feature. This method is best for those who
248///    do not have specific scheduling needs and whose services are `Send + 'static`.
249/// 2. [`Channel::requests`] - This method is best for those who need direct access to individual
250///    requests, or are not using `tokio`, or want control over [futures](Future) scheduling.
251///    [`Requests`] is a stream of [`InFlightRequests`](InFlightRequest), each which has an
252///    [`execute`](InFlightRequest::execute) method. If using `execute`, request processing will
253///    automatically cease when either the request deadline is reached or when a corresponding
254///    cancellation message is received by the Channel.
255/// 3. [`Stream::next`](futures::stream::StreamExt::next) /
256///    [`Sink::send`](futures::sink::SinkExt::send) - A user is free to manually read requests
257///    from, and send responses into, a Channel in lieu of the previous methods. Channels stream
258///    [`TrackedRequests`](TrackedRequest), which, in addition to the request itself, contains the
259///    server [`Span`], request lifetime [`AbortRegistration`], and an inert [`ResponseGuard`].
260///    Wrapping response logic in an [`Abortable`] future using the abort registration will ensure
261///    that the response does not execute longer than the request deadline. The `Channel` itself
262///    will clean up request state once either the deadline expires, or the response guard is
263///    dropped, or a response is sent.
264///
265/// Channels must be implemented using the decorator pattern: the only way to create a
266/// `TrackedRequest` is to get one from another `Channel`. Ultimately, all `TrackedRequests` are
267/// created by [`BaseChannel`].
268pub trait Channel
269where
270    Self: Transport<Response<<Self as Channel>::Resp>, TrackedRequest<<Self as Channel>::Req>>,
271{
272    /// Type of request item.
273    type Req;
274
275    /// Type of response sink item.
276    type Resp;
277
278    /// The wrapped transport.
279    type Transport;
280
281    /// Configuration of the channel.
282    fn config(&self) -> &Config;
283
284    /// Returns the number of in-flight requests over this channel.
285    fn in_flight_requests(&self) -> usize;
286
287    /// Returns the transport underlying the channel.
288    fn transport(&self) -> &Self::Transport;
289
290    /// Caps the number of concurrent requests to `limit`. An error will be returned for requests
291    /// over the concurrency limit.
292    ///
293    /// Note that this is a very
294    /// simplistic throttling heuristic. It is easy to set a number that is too low for the
295    /// resources available to the server. For production use cases, a more advanced throttler is
296    /// likely needed.
297    fn max_concurrent_requests(
298        self,
299        limit: usize,
300    ) -> limits::requests_per_channel::MaxRequests<Self>
301    where
302        Self: Sized,
303    {
304        limits::requests_per_channel::MaxRequests::new(self, limit)
305    }
306
307    /// Returns a stream of requests that automatically handle request cancellation and response
308    /// routing.
309    ///
310    /// This is a terminal operation. After calling `requests`, the channel cannot be retrieved,
311    /// and the only way to complete requests is via [`Requests::execute`] or
312    /// [`InFlightRequest::execute`].
313    fn requests(self) -> Requests<Self>
314    where
315        Self: Sized,
316    {
317        let (responses_tx, responses) = mpsc::channel(self.config().pending_response_buffer);
318
319        Requests {
320            channel: self,
321            pending_responses: responses,
322            responses_tx,
323        }
324    }
325
326    /// Runs the channel until completion by executing all requests using the given service
327    /// function. Request handlers are run concurrently by [spawning](::tokio::spawn) on tokio's
328    /// default executor.
329    #[cfg(feature = "tokio1")]
330    #[cfg_attr(docsrs, doc(cfg(feature = "tokio1")))]
331    fn execute<S>(self, serve: S) -> self::tokio::TokioChannelExecutor<Requests<Self>, S>
332    where
333        Self: Sized,
334        S: Serve<Self::Req, Resp = Self::Resp> + Send + 'static,
335        S::Fut: Send,
336        Self::Req: Send + 'static,
337        Self::Resp: Send + 'static,
338    {
339        self.requests().execute(serve)
340    }
341}
342
/// Critical errors that result in a Channel disconnecting.
///
/// Each variant includes the wrapped error's message in its `Display` output and also exposes
/// that error as its `Error::source`.
#[derive(thiserror::Error, Debug)]
pub enum ChannelError<E>
where
    E: Error + Send + Sync + 'static,
{
    /// An error occurred reading from, or writing to, the transport.
    #[error("an error occurred in the transport: {0}")]
    Transport(#[source] E),
    /// An error occurred while polling expired requests.
    // NOTE(review): no code visible in this file constructs this variant.
    #[error("an error occurred while polling expired requests: {0}")]
    Timer(#[source] ::tokio::time::error::Error),
}
356
impl<Req, Resp, T> Stream for BaseChannel<Req, Resp, T>
where
    T: Transport<Response<Resp>, ClientMessage<Req>>,
{
    type Item = Result<TrackedRequest<Req>, ChannelError<T::Error>>;

    /// Yields each newly received request as a [`TrackedRequest`]. Along the way it internally
    /// handles three kinds of events: server-side cancellations (from armed `ResponseGuard`s),
    /// client cancellation messages, and expired deadlines.
    ///
    /// Transport read errors are yielded as `Err(ChannelError::Transport(_))` items. The stream
    /// ends only when every event source reports closed. That means the transport is exhausted
    /// and `poll_expired` returned `Ready(None)`. Presumably the latter means no in-flight
    /// requests remain; confirm in `in_flight_requests`.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Option<Self::Item>> {
        // Outcome of polling one event source during a loop iteration.
        #[derive(Clone, Copy, Debug)]
        enum ReceiverStatus {
            Ready,
            Pending,
            Closed,
        }

        impl ReceiverStatus {
            // `Ready` wins, since there may be more work to do. Otherwise the result is `Closed`
            // only if both sides are `Closed`, and `Pending` in every other case.
            fn combine(self, other: Self) -> Self {
                use ReceiverStatus::*;
                match (self, other) {
                    (Ready, _) | (_, Ready) => Ready,
                    (Closed, Closed) => Closed,
                    (Pending, Closed) | (Closed, Pending) | (Pending, Pending) => Pending,
                }
            }
        }

        use ReceiverStatus::*;

        loop {
            // 1. Server-side cancellations: stop tracking requests whose handlers were dropped.
            let cancellation_status = match self.canceled_requests_pin_mut().poll_recv(cx) {
                Poll::Ready(Some(request_id)) => {
                    if let Some(span) = self.in_flight_requests_mut().remove_request(request_id) {
                        let _entered = span.enter();
                        tracing::info!("ResponseCancelled");
                    }
                    Ready
                }
                // Pending cancellations don't block Channel closure, because all they do is ensure
                // the Channel's internal state is cleaned up. But Channel closure also cleans up
                // the Channel state, so there's no reason to wait on a cancellation before
                // closing.
                //
                // Ready(None) can't happen, since `self` holds a Cancellation.
                Poll::Pending | Poll::Ready(None) => Closed,
            };

            // 2. Deadlines: let the in-flight bookkeeping process requests whose deadline passed.
            let expiration_status = match self.in_flight_requests_mut().poll_expired(cx) {
                // No need to send a response, since the client wouldn't be waiting for one
                // anymore.
                Poll::Ready(Some(_)) => Ready,
                Poll::Ready(None) => Closed,
                Poll::Pending => Pending,
            };

            // 3. Inbound client messages. A transport error is returned to the caller via `?`.
            let request_status = match self
                .transport_pin_mut()
                .poll_next(cx)
                .map_err(ChannelError::Transport)?
            {
                Poll::Ready(Some(message)) => match message {
                    ClientMessage::Request(request) => {
                        match self.as_mut().start_request(request) {
                            Ok(request) => return Poll::Ready(Some(Ok(request))),
                            Err(AlreadyExistsError) => {
                                // Instead of closing the channel if a duplicate request is sent,
                                // just ignore it, since it's already being processed. Note that we
                                // cannot return Poll::Pending here, since nothing has scheduled a
                                // wakeup yet.
                                continue;
                            }
                        }
                    }
                    ClientMessage::Cancel {
                        trace_context,
                        request_id,
                    } => {
                        // `cancel_request` returns false when the request is no longer tracked.
                        if !self.in_flight_requests_mut().cancel_request(request_id) {
                            tracing::trace!(
                                rpc.trace_id = %trace_context.trace_id,
                                "Received cancellation, but response handler is already complete.",
                            );
                        }
                        Ready
                    }
                },
                Poll::Ready(None) => Closed,
                Poll::Pending => Pending,
            };

            tracing::trace!(
                "Expired requests: {:?}, Inbound: {:?}",
                expiration_status,
                request_status
            );
            // Keep looping while any source made progress; otherwise end or wait for a wakeup.
            match cancellation_status
                .combine(expiration_status)
                .combine(request_status)
            {
                Ready => continue,
                Closed => return Poll::Ready(None),
                Pending => return Poll::Pending,
            }
        }
    }
}
461
462impl<Req, Resp, T> Sink<Response<Resp>> for BaseChannel<Req, Resp, T>
463where
464    T: Transport<Response<Resp>, ClientMessage<Req>>,
465    T::Error: Error,
466{
467    type Error = ChannelError<T::Error>;
468
469    fn poll_ready(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
470        self.project()
471            .transport
472            .poll_ready(cx)
473            .map_err(ChannelError::Transport)
474    }
475
476    fn start_send(mut self: Pin<&mut Self>, response: Response<Resp>) -> Result<(), Self::Error> {
477        if let Some(span) = self
478            .in_flight_requests_mut()
479            .remove_request(response.request_id)
480        {
481            let _entered = span.enter();
482            tracing::info!("SendResponse");
483            self.project()
484                .transport
485                .start_send(response)
486                .map_err(ChannelError::Transport)
487        } else {
488            // If the request isn't tracked anymore, there's no need to send the response.
489            Ok(())
490        }
491    }
492
493    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
494        tracing::trace!("poll_flush");
495        self.project()
496            .transport
497            .poll_flush(cx)
498            .map_err(ChannelError::Transport)
499    }
500
501    fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Result<(), Self::Error>> {
502        self.project()
503            .transport
504            .poll_close(cx)
505            .map_err(ChannelError::Transport)
506    }
507}
508
509impl<Req, Resp, T> AsRef<T> for BaseChannel<Req, Resp, T> {
510    fn as_ref(&self) -> &T {
511        self.transport.get_ref()
512    }
513}
514
515impl<Req, Resp, T> Channel for BaseChannel<Req, Resp, T>
516where
517    T: Transport<Response<Resp>, ClientMessage<Req>>,
518{
519    type Req = Req;
520    type Resp = Resp;
521    type Transport = T;
522
523    fn config(&self) -> &Config {
524        &self.config
525    }
526
527    fn in_flight_requests(&self) -> usize {
528        self.in_flight_requests.len()
529    }
530
531    fn transport(&self) -> &Self::Transport {
532        self.get_ref()
533    }
534}
535
/// A stream of requests coming over a channel. `Requests` also drives the sending of responses, so
/// it must be continually polled to ensure progress.
#[pin_project]
pub struct Requests<C>
where
    C: Channel,
{
    /// The channel that requests are read from and responses are written to.
    #[pin]
    channel: C,
    /// Responses waiting to be written to the wire. The queue holds up to the channel's
    /// `Config::pending_response_buffer` entries (see `Channel::requests`).
    pending_responses: mpsc::Receiver<Response<C::Resp>>,
    /// Handed out to request handlers to fan in responses.
    responses_tx: mpsc::Sender<Response<C::Resp>>,
}
550
551impl<C> Requests<C>
552where
553    C: Channel,
554{
555    /// Returns a reference to the inner channel over which messages are sent and received.
556    pub fn channel(&self) -> &C {
557        &self.channel
558    }
559
560    /// Returns the inner channel over which messages are sent and received.
561    pub fn channel_pin_mut<'a>(self: &'a mut Pin<&mut Self>) -> Pin<&'a mut C> {
562        self.as_mut().project().channel
563    }
564
565    /// Returns the inner channel over which messages are sent and received.
566    pub fn pending_responses_mut<'a>(
567        self: &'a mut Pin<&mut Self>,
568    ) -> &'a mut mpsc::Receiver<Response<C::Resp>> {
569        self.as_mut().project().pending_responses
570    }
571
572    fn pump_read(
573        mut self: Pin<&mut Self>,
574        cx: &mut Context<'_>,
575    ) -> Poll<Option<Result<InFlightRequest<C::Req, C::Resp>, C::Error>>> {
576        self.channel_pin_mut().poll_next(cx).map_ok(
577            |TrackedRequest {
578                 request,
579                 abort_registration,
580                 span,
581                 mut response_guard,
582             }| {
583                // The response guard becomes active once in an InFlightRequest.
584                response_guard.cancel = true;
585                InFlightRequest {
586                    request,
587                    abort_registration,
588                    span,
589                    response_guard,
590                    response_tx: self.responses_tx.clone(),
591                }
592            },
593        )
594    }
595
596    fn pump_write(
597        mut self: Pin<&mut Self>,
598        cx: &mut Context<'_>,
599        read_half_closed: bool,
600    ) -> Poll<Option<Result<(), C::Error>>> {
601        match self.as_mut().poll_next_response(cx)? {
602            Poll::Ready(Some(response)) => {
603                // A Ready result from poll_next_response means the Channel is ready to be written
604                // to. Therefore, we can call start_send without worry of a full buffer.
605                self.channel_pin_mut().start_send(response)?;
606                Poll::Ready(Some(Ok(())))
607            }
608            Poll::Ready(None) => {
609                // Shutdown can't be done before we finish pumping out remaining responses.
610                ready!(self.channel_pin_mut().poll_flush(cx)?);
611                Poll::Ready(None)
612            }
613            Poll::Pending => {
614                // No more requests to process, so flush any requests buffered in the transport.
615                ready!(self.channel_pin_mut().poll_flush(cx)?);
616
617                // Being here means there are no staged requests and all written responses are
618                // fully flushed. So, if the read half is closed and there are no in-flight
619                // requests, then we can close the write half.
620                if read_half_closed && self.channel.in_flight_requests() == 0 {
621                    Poll::Ready(None)
622                } else {
623                    Poll::Pending
624                }
625            }
626        }
627    }
628
629    /// Yields a response ready to be written to the Channel sink.
630    ///
631    /// Note that a response will only be yielded if the Channel is *ready* to be written to (i.e.
632    /// start_send would succeed).
633    fn poll_next_response(
634        mut self: Pin<&mut Self>,
635        cx: &mut Context<'_>,
636    ) -> Poll<Option<Result<Response<C::Resp>, C::Error>>> {
637        ready!(self.ensure_writeable(cx)?);
638
639        match ready!(self.pending_responses_mut().poll_recv(cx)) {
640            Some(response) => Poll::Ready(Some(Ok(response))),
641            None => {
642                // This branch likely won't happen, since the Requests stream is holding a Sender.
643                Poll::Ready(None)
644            }
645        }
646    }
647
648    /// Returns Ready if writing a message to the Channel would not fail due to a full buffer. If
649    /// the Channel is not ready to be written to, flushes it until it is ready.
650    fn ensure_writeable<'a>(
651        self: &'a mut Pin<&mut Self>,
652        cx: &mut Context<'_>,
653    ) -> Poll<Option<Result<(), C::Error>>> {
654        while self.channel_pin_mut().poll_ready(cx)?.is_pending() {
655            ready!(self.channel_pin_mut().poll_flush(cx)?);
656        }
657        Poll::Ready(Some(Ok(())))
658    }
659}
660
661impl<C> fmt::Debug for Requests<C>
662where
663    C: Channel,
664{
665    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
666        write!(fmt, "Requests")
667    }
668}
669
/// A fail-safe to ensure requests are properly canceled if request processing is aborted before
/// completing.
///
/// When dropped while armed (`cancel == true`), it notifies the channel that issued it that
/// request `request_id` was canceled, so the channel stops tracking that request.
#[derive(Debug)]
pub struct ResponseGuard {
    /// Sends cancellation notices to the channel that issued this guard.
    request_cancellation: RequestCancellation,
    /// ID of the guarded request.
    request_id: u64,
    /// Whether dropping the guard cancels the request. It starts `false` (inert) in a
    /// `TrackedRequest` and is set to `true` when the request becomes an `InFlightRequest`.
    cancel: bool,
}
678
679impl Drop for ResponseGuard {
680    fn drop(&mut self) {
681        if self.cancel {
682            self.request_cancellation.cancel(self.request_id);
683        }
684    }
685}
686
/// A request produced by [Channel::requests].
///
/// If dropped without calling [`execute`](InFlightRequest::execute), a cancellation message will
/// be sent to the Channel to clean up associated request state.
#[derive(Debug)]
pub struct InFlightRequest<Req, Res> {
    /// The request received from the client.
    request: Request<Req>,
    /// Aborts the handler future when the channel stops tracking this request.
    abort_registration: AbortRegistration,
    /// Armed guard: cancels the request on drop unless disarmed (see `execute`).
    response_guard: ResponseGuard,
    /// Span covering the server's processing of this request.
    span: Span,
    /// Sends the handler's response back to the `Requests` stream for writing.
    response_tx: mpsc::Sender<Response<Res>>,
}
699
700impl<Req, Res> InFlightRequest<Req, Res> {
701    /// Returns a reference to the request.
702    pub fn get(&self) -> &Request<Req> {
703        &self.request
704    }
705
706    /// Returns a [future](Future) that executes the request using the given [service
707    /// function](Serve). The service function's output is automatically sent back to the [Channel]
708    /// that yielded this request. The request will be executed in the scope of this request's
709    /// context.
710    ///
711    /// The returned future will stop executing when the first of the following conditions is met:
712    ///
713    /// 1. The channel that yielded this request receives a [cancellation
714    ///    message](ClientMessage::Cancel) for this request.
715    /// 2. The request [deadline](crate::context::Context::deadline) is reached.
716    /// 3. The service function completes.
717    ///
718    /// If the returned Future is dropped before completion, a cancellation message will be sent to
719    /// the Channel to clean up associated request state.
720    pub async fn execute<S>(self, serve: S)
721    where
722        S: Serve<Req, Resp = Res>,
723    {
724        let Self {
725            response_tx,
726            mut response_guard,
727            abort_registration,
728            span,
729            request:
730                Request {
731                    context,
732                    message,
733                    id: request_id,
734                },
735        } = self;
736        let method = serve.method(&message);
737        // TODO(https://github.com/rust-lang/rust-clippy/issues/9111)
738        // remove when clippy is fixed
739        #[allow(clippy::needless_borrow)]
740        span.record("otel.name", &method.unwrap_or(""));
741        let _ = Abortable::new(
742            async move {
743                tracing::info!("BeginRequest");
744                let response = serve.serve(context, message).await;
745
746                tracing::info!("CompleteRequest");
747                if context.discard_response {
748                    tracing::info!("DiscardingResponse");
749                } else {
750                    let response = Response {
751                        request_id,
752                        message: Ok(response),
753                    };
754                    let _ = response_tx.send(response).await;
755                    tracing::info!("BufferResponse");
756                }
757            },
758            abort_registration,
759        )
760        .instrument(span)
761        .await;
762        // Request processing has completed, meaning either the channel canceled the request or
763        // a request was sent back to the channel. Either way, the channel will clean up the
764        // request data, so the request does not need to be canceled.
765        response_guard.cancel = false;
766    }
767}
768
impl<C> Stream for Requests<C>
where
    C: Channel,
{
    type Item = Result<InFlightRequest<C::Req, C::Resp>, C::Error>;

    /// Yields the next request as an [`InFlightRequest`], writing pending responses to the
    /// channel along the way.
    ///
    /// Each iteration pumps the read half first, then the write half. The write half is told
    /// whether the read half has closed. Channel errors from either half are yielded as `Err`
    /// items. The stream ends only when both halves report that they are finished.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        loop {
            let read = self.as_mut().pump_read(cx)?;
            let read_closed = matches!(read, Poll::Ready(None));
            match (read, self.as_mut().pump_write(cx, read_closed)?) {
                // No more requests will arrive and no more responses need writing.
                (Poll::Ready(None), Poll::Ready(None)) => {
                    return Poll::Ready(None);
                }
                // A new request arrived (a response may also have been written this iteration).
                (Poll::Ready(Some(request_handler)), _) => {
                    return Poll::Ready(Some(Ok(request_handler)));
                }
                // A response was written but no request is available; keep pumping.
                (_, Poll::Ready(Some(()))) => {}
                // Nothing to yield and no response written this iteration; wait for a wakeup.
                _ => {
                    return Poll::Pending;
                }
            }
        }
    }
}
794
#[cfg(test)]
mod tests {
    use super::{in_flight_requests::AlreadyExistsError, BaseChannel, Channel, Config, Requests};
    use crate::{
        context, trace,
        transport::channel::{self, UnboundedChannel},
        ClientMessage, Request, Response,
    };
    use assert_matches::assert_matches;
    use futures::{
        future::{pending, AbortRegistration, Abortable, Aborted},
        prelude::*,
        Future,
    };
    use futures_test::task::noop_context;
    use std::{pin::Pin, task::Poll};

    /// Creates a pinned `BaseChannel` over an unbounded in-memory transport.
    ///
    /// Returns the server-side channel and the client-side end of the transport.
    fn test_channel<Req, Resp>() -> (
        Pin<Box<BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>>>,
        UnboundedChannel<Response<Resp>, ClientMessage<Req>>,
    ) {
        let (tx, rx) = crate::transport::channel::unbounded();
        (Box::pin(BaseChannel::new(Config::default(), rx)), tx)
    }

    /// Like `test_channel`, but wraps the channel in `Requests` (via `Channel::requests`).
    fn test_requests<Req, Resp>() -> (
        Pin<
            Box<
                Requests<
                    BaseChannel<Req, Resp, UnboundedChannel<ClientMessage<Req>, Response<Resp>>>,
                >,
            >,
        >,
        UnboundedChannel<Response<Resp>, ClientMessage<Req>>,
    ) {
        let (tx, rx) = crate::transport::channel::unbounded();
        (
            Box::pin(BaseChannel::new(Config::default(), rx).requests()),
            tx,
        )
    }

    /// Like `test_requests`, but over a transport bounded to `capacity`.
    ///
    /// The pending-response buffer is sized `capacity + 1`.
    fn test_bounded_requests<Req, Resp>(
        capacity: usize,
    ) -> (
        Pin<
            Box<
                Requests<
                    BaseChannel<Req, Resp, channel::Channel<ClientMessage<Req>, Response<Resp>>>,
                >,
            >,
        >,
        channel::Channel<Response<Resp>, ClientMessage<Req>>,
    ) {
        let (tx, rx) = crate::transport::channel::bounded(capacity);
        // Add 1 because a pending-response buffer of capacity 0 is not supported, whereas
        // transport::channel::bounded does accept capacity 0.
        let config = Config {
            pending_response_buffer: capacity + 1,
        };
        (Box::pin(BaseChannel::new(config, rx).requests()), tx)
    }

    /// Builds a `Request` with the given `id` and `message`.
    ///
    /// The context comes from `context::current()`, evaluated at the call site. Tests that pause
    /// or advance tokio time rely on that timing.
    fn new_request<Req>(id: u64, message: Req) -> Request<Req> {
        Request {
            context: context::current(),
            id,
            message,
        }
    }

    /// Wraps a request with id 0 in a `ClientMessage`, as a client would send it.
    fn fake_request<Req>(req: Req) -> ClientMessage<Req> {
        ClientMessage::Request(new_request(0, req))
    }

    /// Returns a future that never completes on its own.
    ///
    /// It resolves to `Err(Aborted)` only if `abort_registration`'s handle aborts it.
    fn test_abortable(
        abort_registration: AbortRegistration,
    ) -> impl Future<Output = Result<(), Aborted>> {
        Abortable::new(pending(), abort_registration)
    }

    #[tokio::test]
    async fn base_channel_start_send_duplicate_request_returns_error() {
        let (mut channel, _tx) = test_channel::<(), ()>();

        channel.as_mut().start_request(new_request(0, ())).unwrap();
        assert_matches!(
            channel.as_mut().start_request(new_request(0, ())),
            Err(AlreadyExistsError)
        );
    }

    #[tokio::test]
    async fn base_channel_poll_next_aborts_multiple_requests() {
        let (mut channel, _tx) = test_channel::<(), ()>();

        tokio::time::pause();
        let req0 = channel.as_mut().start_request(new_request(0, ())).unwrap();
        let req1 = channel.as_mut().start_request(new_request(1, ())).unwrap();
        tokio::time::advance(std::time::Duration::from_secs(1000)).await;

        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Pending
        );
        assert_matches!(test_abortable(req0.abort_registration).await, Err(Aborted));
        assert_matches!(test_abortable(req1.abort_registration).await, Err(Aborted));
    }

    #[tokio::test]
    async fn base_channel_poll_next_aborts_canceled_request() {
        let (mut channel, mut tx) = test_channel::<(), ()>();

        tokio::time::pause();
        let req = channel.as_mut().start_request(new_request(0, ())).unwrap();

        tx.send(ClientMessage::Cancel {
            trace_context: trace::Context::default(),
            request_id: 0,
        })
        .await
        .unwrap();

        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Pending
        );

        assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted));
    }

    #[tokio::test]
    async fn base_channel_with_closed_transport_and_in_flight_request_returns_pending() {
        let (mut channel, tx) = test_channel::<(), ()>();

        tokio::time::pause();
        let _abort_registration = channel.as_mut().start_request(new_request(0, ())).unwrap();

        drop(tx);
        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Pending
        );
    }

    #[tokio::test]
    async fn base_channel_with_closed_transport_and_no_in_flight_requests_returns_closed() {
        let (mut channel, tx) = test_channel::<(), ()>();
        drop(tx);
        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Ready(None)
        );
    }

    #[tokio::test]
    async fn base_channel_poll_next_yields_request() {
        let (mut channel, mut tx) = test_channel::<(), ()>();
        tx.send(fake_request(())).await.unwrap();

        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Ready(Some(Ok(_)))
        );
    }

    #[tokio::test]
    async fn base_channel_poll_next_aborts_request_and_yields_request() {
        let (mut channel, mut tx) = test_channel::<(), ()>();

        tokio::time::pause();
        let req = channel.as_mut().start_request(new_request(0, ())).unwrap();
        tokio::time::advance(std::time::Duration::from_secs(1000)).await;

        tx.send(fake_request(())).await.unwrap();

        assert_matches!(
            channel.as_mut().poll_next(&mut noop_context()),
            Poll::Ready(Some(Ok(_)))
        );
        assert_matches!(test_abortable(req.abort_registration).await, Err(Aborted));
    }

    #[tokio::test]
    async fn base_channel_start_send_removes_in_flight_request() {
        let (mut channel, _tx) = test_channel::<(), ()>();

        channel.as_mut().start_request(new_request(0, ())).unwrap();
        assert_eq!(channel.in_flight_requests(), 1);
        channel
            .as_mut()
            .start_send(Response {
                request_id: 0,
                message: Ok(()),
            })
            .unwrap();
        assert_eq!(channel.in_flight_requests(), 0);
    }

    #[tokio::test]
    async fn in_flight_request_drop_cancels_request() {
        let (mut requests, mut tx) = test_requests::<(), ()>();
        tx.send(fake_request(())).await.unwrap();

        let request = match requests.as_mut().poll_next(&mut noop_context()) {
            Poll::Ready(Some(Ok(request))) => request,
            result => panic!("Unexpected result: {:?}", result),
        };
        drop(request);

        let poll = requests
            .as_mut()
            .channel_pin_mut()
            .poll_next(&mut noop_context());
        assert!(poll.is_pending());
        let in_flight_requests = requests.channel().in_flight_requests();
        assert_eq!(in_flight_requests, 0);
    }

    #[tokio::test]
    async fn in_flight_requests_successful_execute_doesnt_cancel_request() {
        let (mut requests, mut tx) = test_requests::<(), ()>();
        tx.send(fake_request(())).await.unwrap();

        let request = match requests.as_mut().poll_next(&mut noop_context()) {
            Poll::Ready(Some(Ok(request))) => request,
            result => panic!("Unexpected result: {:?}", result),
        };
        request.execute(|_, _| async {}).await;
        assert!(requests
            .as_mut()
            .channel_pin_mut()
            .canceled_requests
            .poll_recv(&mut noop_context())
            .is_pending());
    }

    #[tokio::test]
    async fn requests_poll_next_response_returns_pending_when_buffer_full() {
        let (mut requests, _tx) = test_bounded_requests::<(), ()>(0);

        // Response written to the transport.
        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(new_request(0, ()))
            .unwrap();
        requests
            .as_mut()
            .channel_pin_mut()
            .start_send(Response {
                request_id: 0,
                message: Ok(()),
            })
            .unwrap();

        // Response waiting to be written.
        requests
            .as_mut()
            .project()
            .responses_tx
            .send(Response {
                request_id: 1,
                message: Ok(()),
            })
            .await
            .unwrap();

        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(new_request(1, ()))
            .unwrap();

        assert_matches!(
            requests.as_mut().poll_next_response(&mut noop_context()),
            Poll::Pending
        );
    }

    #[tokio::test]
    async fn requests_pump_write_returns_pending_when_buffer_full() {
        let (mut requests, _tx) = test_bounded_requests::<(), ()>(0);

        // Response written to the transport.
        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(new_request(0, ()))
            .unwrap();
        requests
            .as_mut()
            .channel_pin_mut()
            .start_send(Response {
                request_id: 0,
                message: Ok(()),
            })
            .unwrap();

        // Response waiting to be written.
        requests
            .as_mut()
            .channel_pin_mut()
            .start_request(new_request(1, ()))
            .unwrap();
        requests
            .as_mut()
            .project()
            .responses_tx
            .send(Response {
                request_id: 1,
                message: Ok(()),
            })
            .await
            .unwrap();

        assert_matches!(
            requests.as_mut().pump_write(&mut noop_context(), true),
            Poll::Pending
        );
        // Assert that the pending response was not polled while the channel was blocked.
        assert_matches!(
            requests.as_mut().pending_responses_mut().recv().await,
            Some(_)
        );
    }

    #[tokio::test]
    async fn requests_pump_read() {
        let (mut requests, mut tx) = test_requests::<(), ()>();

        // Request sent by the client over the transport.
        tx.send(fake_request(())).await.unwrap();

        assert_matches!(
            requests.as_mut().pump_read(&mut noop_context()),
            Poll::Ready(Some(Ok(_)))
        );
        assert_eq!(requests.channel.in_flight_requests(), 1);
    }
}