tower_buffer/
worker.rs

1use crate::{
2    error::{Closed, Error, ServiceError},
3    message::Message,
4};
5use futures_core::ready;
6use pin_project::pin_project;
7use std::sync::{Arc, Mutex};
8use std::{
9    future::Future,
10    pin::Pin,
11    task::{Context, Poll},
12};
13use tokio::sync::mpsc;
14use tower_service::Service;
15
16/// Task that handles processing the buffer. This type should not be used
17/// directly, instead `Buffer` requires an `Executor` that can accept this task.
18///
19/// The struct is `pub` in the private module and the type is *not* re-exported
20/// as part of the public API. This is the "sealed" pattern to include "private"
21/// types in public traits that are not meant for consumers of the library to
22/// implement (only call).
23#[pin_project]
24#[derive(Debug)]
25pub struct Worker<T, Request>
26where
27    T: Service<Request>,
28    T::Error: Into<Error>,
29{
30    current_message: Option<Message<Request, T::Future>>,
31    rx: mpsc::Receiver<Message<Request, T::Future>>,
32    service: T,
33    finish: bool,
34    failed: Option<ServiceError>,
35    handle: Handle,
36}
37
38/// Get the error out
39#[derive(Debug)]
40pub(crate) struct Handle {
41    inner: Arc<Mutex<Option<ServiceError>>>,
42}
43
44impl<T, Request> Worker<T, Request>
45where
46    T: Service<Request>,
47    T::Error: Into<Error>,
48{
49    pub(crate) fn new(
50        service: T,
51        rx: mpsc::Receiver<Message<Request, T::Future>>,
52    ) -> (Handle, Worker<T, Request>) {
53        let handle = Handle {
54            inner: Arc::new(Mutex::new(None)),
55        };
56
57        let worker = Worker {
58            current_message: None,
59            finish: false,
60            failed: None,
61            rx,
62            service,
63            handle: handle.clone(),
64        };
65
66        (handle, worker)
67    }
68
69    /// Return the next queued Message that hasn't been canceled.
70    ///
71    /// If a `Message` is returned, the `bool` is true if this is the first time we received this
72    /// message, and false otherwise (i.e., we tried to forward it to the backing service before).
73    fn poll_next_msg(
74        &mut self,
75        cx: &mut Context<'_>,
76    ) -> Poll<Option<(Message<Request, T::Future>, bool)>> {
77        if self.finish {
78            // We've already received None and are shutting down
79            return Poll::Ready(None);
80        }
81
82        tracing::trace!("worker polling for next message");
83        if let Some(mut msg) = self.current_message.take() {
84            // poll_closed returns Poll::Ready is the receiver is dropped.
85            // Returning Pending means it is still alive, so we should still
86            // use it.
87            if msg.tx.poll_closed(cx).is_pending() {
88                tracing::trace!("resuming buffered request");
89                return Poll::Ready(Some((msg, false)));
90            }
91
92            tracing::trace!("dropping cancelled buffered request");
93        }
94
95        // Get the next request
96        while let Some(mut msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
97            if msg.tx.poll_closed(cx).is_pending() {
98                tracing::trace!("processing new request");
99                return Poll::Ready(Some((msg, true)));
100            }
101            // Otherwise, request is canceled, so pop the next one.
102            tracing::trace!("dropping cancelled request");
103        }
104
105        Poll::Ready(None)
106    }
107
108    fn failed(&mut self, error: Error) {
109        // The underlying service failed when we called `poll_ready` on it with the given `error`. We
110        // need to communicate this to all the `Buffer` handles. To do so, we wrap up the error in
111        // an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
112        // requests will also fail with the same error.
113
114        // Note that we need to handle the case where some handle is concurrently trying to send us
115        // a request. We need to make sure that *either* the send of the request fails *or* it
116        // receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
117        // case where we send errors to all outstanding requests, and *then* the caller sends its
118        // request. We do this by *first* exposing the error, *then* closing the channel used to
119        // send more requests (so the client will see the error when the send fails), and *then*
120        // sending the error to all outstanding requests.
121        let error = ServiceError::new(error);
122
123        let mut inner = self.handle.inner.lock().unwrap();
124
125        if inner.is_some() {
126            // Future::poll was called after we've already errored out!
127            return;
128        }
129
130        *inner = Some(error.clone());
131        drop(inner);
132
133        self.rx.close();
134
135        // By closing the mpsc::Receiver, we know that poll_next_msg will soon return Ready(None),
136        // which will trigger the `self.finish == true` phase. We just need to make sure that any
137        // requests that we receive before we've exhausted the receiver receive the error:
138        self.failed = Some(error);
139    }
140}
141
142impl<T, Request> Future for Worker<T, Request>
143where
144    T: Service<Request>,
145    T::Error: Into<Error>,
146{
147    type Output = ();
148
149    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
150        if self.finish {
151            return Poll::Ready(());
152        }
153
154        loop {
155            match ready!(self.poll_next_msg(cx)) {
156                Some((msg, first)) => {
157                    let _guard = msg.span.enter();
158                    if let Some(ref failed) = self.failed {
159                        tracing::trace!("notifying caller about worker failure");
160                        let _ = msg.tx.send(Err(failed.clone()));
161                        continue;
162                    }
163
164                    // Wait for the service to be ready
165                    tracing::trace!(
166                        resumed = !first,
167                        message = "worker received request; waiting for service readiness"
168                    );
169                    match self.service.poll_ready(cx) {
170                        Poll::Ready(Ok(())) => {
171                            tracing::debug!(service.ready = true, message = "processing request");
172                            let response = self.service.call(msg.request);
173
174                            // Send the response future back to the sender.
175                            //
176                            // An error means the request had been canceled in-between
177                            // our calls, the response future will just be dropped.
178                            tracing::trace!("returning response future");
179                            let _ = msg.tx.send(Ok(response));
180                        }
181                        Poll::Pending => {
182                            tracing::trace!(service.ready = false, message = "delay");
183                            // Put out current message back in its slot.
184                            drop(_guard);
185                            self.current_message = Some(msg);
186                            return Poll::Pending;
187                        }
188                        Poll::Ready(Err(e)) => {
189                            let error = e.into();
190                            tracing::debug!({ %error }, "service failed");
191                            drop(_guard);
192                            self.failed(error);
193                            let _ = msg.tx.send(Err(self
194                                .failed
195                                .as_ref()
196                                .expect("Worker::failed did not set self.failed?")
197                                .clone()));
198                        }
199                    }
200                }
201                None => {
202                    // No more more requests _ever_.
203                    self.finish = true;
204                    return Poll::Ready(());
205                }
206            }
207        }
208    }
209}
210
211impl Handle {
212    pub(crate) fn get_error_on_closed(&self) -> Error {
213        self.inner
214            .lock()
215            .unwrap()
216            .as_ref()
217            .map(|svc_err| svc_err.clone().into())
218            .unwrap_or_else(|| Closed::new().into())
219    }
220}
221
222impl Clone for Handle {
223    fn clone(&self) -> Handle {
224        Handle {
225            inner: self.inner.clone(),
226        }
227    }
228}