tower_buffer/worker.rs
1use crate::{
2 error::{Closed, Error, ServiceError},
3 message::Message,
4};
5use futures_core::ready;
6use pin_project::pin_project;
7use std::sync::{Arc, Mutex};
8use std::{
9 future::Future,
10 pin::Pin,
11 task::{Context, Poll},
12};
13use tokio::sync::mpsc;
14use tower_service::Service;
15
16/// Task that handles processing the buffer. This type should not be used
17/// directly, instead `Buffer` requires an `Executor` that can accept this task.
18///
19/// The struct is `pub` in the private module and the type is *not* re-exported
20/// as part of the public API. This is the "sealed" pattern to include "private"
21/// types in public traits that are not meant for consumers of the library to
22/// implement (only call).
23#[pin_project]
24#[derive(Debug)]
25pub struct Worker<T, Request>
26where
27 T: Service<Request>,
28 T::Error: Into<Error>,
29{
30 current_message: Option<Message<Request, T::Future>>,
31 rx: mpsc::Receiver<Message<Request, T::Future>>,
32 service: T,
33 finish: bool,
34 failed: Option<ServiceError>,
35 handle: Handle,
36}
37
38/// Get the error out
39#[derive(Debug)]
40pub(crate) struct Handle {
41 inner: Arc<Mutex<Option<ServiceError>>>,
42}
43
44impl<T, Request> Worker<T, Request>
45where
46 T: Service<Request>,
47 T::Error: Into<Error>,
48{
49 pub(crate) fn new(
50 service: T,
51 rx: mpsc::Receiver<Message<Request, T::Future>>,
52 ) -> (Handle, Worker<T, Request>) {
53 let handle = Handle {
54 inner: Arc::new(Mutex::new(None)),
55 };
56
57 let worker = Worker {
58 current_message: None,
59 finish: false,
60 failed: None,
61 rx,
62 service,
63 handle: handle.clone(),
64 };
65
66 (handle, worker)
67 }
68
69 /// Return the next queued Message that hasn't been canceled.
70 ///
71 /// If a `Message` is returned, the `bool` is true if this is the first time we received this
72 /// message, and false otherwise (i.e., we tried to forward it to the backing service before).
73 fn poll_next_msg(
74 &mut self,
75 cx: &mut Context<'_>,
76 ) -> Poll<Option<(Message<Request, T::Future>, bool)>> {
77 if self.finish {
78 // We've already received None and are shutting down
79 return Poll::Ready(None);
80 }
81
82 tracing::trace!("worker polling for next message");
83 if let Some(mut msg) = self.current_message.take() {
84 // poll_closed returns Poll::Ready is the receiver is dropped.
85 // Returning Pending means it is still alive, so we should still
86 // use it.
87 if msg.tx.poll_closed(cx).is_pending() {
88 tracing::trace!("resuming buffered request");
89 return Poll::Ready(Some((msg, false)));
90 }
91
92 tracing::trace!("dropping cancelled buffered request");
93 }
94
95 // Get the next request
96 while let Some(mut msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
97 if msg.tx.poll_closed(cx).is_pending() {
98 tracing::trace!("processing new request");
99 return Poll::Ready(Some((msg, true)));
100 }
101 // Otherwise, request is canceled, so pop the next one.
102 tracing::trace!("dropping cancelled request");
103 }
104
105 Poll::Ready(None)
106 }
107
108 fn failed(&mut self, error: Error) {
109 // The underlying service failed when we called `poll_ready` on it with the given `error`. We
110 // need to communicate this to all the `Buffer` handles. To do so, we wrap up the error in
111 // an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
112 // requests will also fail with the same error.
113
114 // Note that we need to handle the case where some handle is concurrently trying to send us
115 // a request. We need to make sure that *either* the send of the request fails *or* it
116 // receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
117 // case where we send errors to all outstanding requests, and *then* the caller sends its
118 // request. We do this by *first* exposing the error, *then* closing the channel used to
119 // send more requests (so the client will see the error when the send fails), and *then*
120 // sending the error to all outstanding requests.
121 let error = ServiceError::new(error);
122
123 let mut inner = self.handle.inner.lock().unwrap();
124
125 if inner.is_some() {
126 // Future::poll was called after we've already errored out!
127 return;
128 }
129
130 *inner = Some(error.clone());
131 drop(inner);
132
133 self.rx.close();
134
135 // By closing the mpsc::Receiver, we know that poll_next_msg will soon return Ready(None),
136 // which will trigger the `self.finish == true` phase. We just need to make sure that any
137 // requests that we receive before we've exhausted the receiver receive the error:
138 self.failed = Some(error);
139 }
140}
141
142impl<T, Request> Future for Worker<T, Request>
143where
144 T: Service<Request>,
145 T::Error: Into<Error>,
146{
147 type Output = ();
148
149 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
150 if self.finish {
151 return Poll::Ready(());
152 }
153
154 loop {
155 match ready!(self.poll_next_msg(cx)) {
156 Some((msg, first)) => {
157 let _guard = msg.span.enter();
158 if let Some(ref failed) = self.failed {
159 tracing::trace!("notifying caller about worker failure");
160 let _ = msg.tx.send(Err(failed.clone()));
161 continue;
162 }
163
164 // Wait for the service to be ready
165 tracing::trace!(
166 resumed = !first,
167 message = "worker received request; waiting for service readiness"
168 );
169 match self.service.poll_ready(cx) {
170 Poll::Ready(Ok(())) => {
171 tracing::debug!(service.ready = true, message = "processing request");
172 let response = self.service.call(msg.request);
173
174 // Send the response future back to the sender.
175 //
176 // An error means the request had been canceled in-between
177 // our calls, the response future will just be dropped.
178 tracing::trace!("returning response future");
179 let _ = msg.tx.send(Ok(response));
180 }
181 Poll::Pending => {
182 tracing::trace!(service.ready = false, message = "delay");
183 // Put out current message back in its slot.
184 drop(_guard);
185 self.current_message = Some(msg);
186 return Poll::Pending;
187 }
188 Poll::Ready(Err(e)) => {
189 let error = e.into();
190 tracing::debug!({ %error }, "service failed");
191 drop(_guard);
192 self.failed(error);
193 let _ = msg.tx.send(Err(self
194 .failed
195 .as_ref()
196 .expect("Worker::failed did not set self.failed?")
197 .clone()));
198 }
199 }
200 }
201 None => {
202 // No more more requests _ever_.
203 self.finish = true;
204 return Poll::Ready(());
205 }
206 }
207 }
208 }
209}
210
211impl Handle {
212 pub(crate) fn get_error_on_closed(&self) -> Error {
213 self.inner
214 .lock()
215 .unwrap()
216 .as_ref()
217 .map(|svc_err| svc_err.clone().into())
218 .unwrap_or_else(|| Closed::new().into())
219 }
220}
221
222impl Clone for Handle {
223 fn clone(&self) -> Handle {
224 Handle {
225 inner: self.inner.clone(),
226 }
227 }
228}