//! tower_batch/worker.rs
1use std::{
2 future::Future,
3 mem,
4 ops::Add,
5 pin::Pin,
6 sync::{Arc, Mutex, Weak},
7 task::{Context, Poll},
8 time::{Duration, Instant},
9};
10
11use futures_core::ready;
12use tokio::{
13 sync::{mpsc, Semaphore},
14 time::{sleep_until, Sleep},
15};
16use tower::Service;
17use tracing::{debug, trace};
18
19use super::{
20 error::{Closed, ServiceError},
21 message::{Message, Tx},
22 BatchControl,
23};
24
/// Shared error state between [`Batch`](crate::Batch) (client) and [`Worker`].
///
/// When the worker's inner service fails, the error is stored here so that
/// `Batch::poll_ready` can retrieve it and propagate it to callers.
#[derive(Debug)]
pub(crate) struct Handle {
    // `None` until the inner service fails. Once set it is never replaced
    // (`Bridge::failed` returns early if an error is already stored), so
    // every `Batch` handle observes the same first error.
    inner: Arc<Mutex<Option<ServiceError>>>,
}
33
/// Manages the message channel between [`Batch`](crate::Batch) handles and the [`Worker`].
///
/// Receives incoming requests from the unbounded mpsc channel, holds on to a
/// message when the inner service is not ready (`current_message`), and
/// propagates errors by storing them in the shared [`Handle`] and closing the
/// semaphore so that all waiting `Batch` handles are woken.
#[derive(Debug)]
struct Bridge<Fut, Request> {
    /// Incoming requests sent by `Batch` handles.
    rx: mpsc::UnboundedReceiver<Message<Request, Fut>>,
    /// Shared error slot read by `Batch::poll_ready`.
    handle: Handle,
    /// A message taken from `rx` but not yet forwarded because the inner
    /// service was not ready; retried first on the next poll.
    current_message: Option<Message<Request, Fut>>,
    /// Weak reference to the capacity semaphore shared with `Batch` handles;
    /// taken (at most once) when closing, to wake all pending waiters.
    close: Option<Weak<Semaphore>>,
    /// Set once the inner service has failed; requests drained afterwards are
    /// answered with clones of this error.
    failed: Option<ServiceError>,
}
48
/// Accumulates batch items with their oneshot response senders.
///
/// Tracks the max-time timer (started when the first item arrives) and
/// dispatches results — or errors — to all collected senders on flush via
/// [`notify`](Lot::notify).
#[derive(Debug)]
struct Lot<Fut> {
    /// Flush when this many items have been collected.
    max_size: usize,
    /// Flush when this much time has passed since the first item.
    max_time: Duration,
    /// One entry per batched item: the caller's response channel plus either
    /// the inner service's response future or the error to report.
    responses: Vec<(Tx<Fut>, Result<Fut, ServiceError>)>,
    /// Timer armed by the first `add` of a batch; `None` while empty.
    time_elapses: Option<Pin<Box<Sleep>>>,
    /// Latch set when the timer fires, so `poll_max_time` only signals a
    /// time-based flush once per batch (see that method's comment).
    time_elapsed: bool,
}
62
// Worker state machine.
//
// Transitions:
// - `Collecting` → `Flushing`: batch is full (size) or max time elapsed.
// - `Flushing` → `Collecting`: flush succeeded, ready for new items.
// - `Flushing` → `Finished`: flush failed, worker terminates.
// - `Collecting` → `Finished`: channel closed, no more requests.
pin_project_lite::pin_project! {
    #[project = StateProj]
    #[derive(Debug)]
    enum State<Fut> {
        // Accepting new batch items from the request channel.
        Collecting,
        // Flushing the accumulated batch through the inner service.
        Flushing {
            // Why the flush was triggered ("size" or "time"); used only for
            // tracing. Held in an `Option` so it can be `take`n and re-stored
            // when the flush future is created.
            reason: Option<String>,
            // `None` until the inner service was ready and
            // `call(BatchControl::Flush)` was issued; then the in-flight
            // flush response future.
            #[pin]
            flush_fut: Option<Fut>,
        },
        // Channel closed or service failed; the worker is done.
        Finished
    }
}
83
pin_project_lite::pin_project! {
    /// Task that handles processing the buffer. This type should not be used
    /// directly, instead `Batch` requires an `Executor` that can accept this task.
    ///
    /// The struct is `pub` in the private module and the type is *not* re-exported
    /// as part of the public API. This is the "sealed" pattern to include "private"
    /// types in public traits that are not meant for consumers of the library to
    /// implement (only call).
    #[derive(Debug)]
    pub struct Worker<T, Request>
    where
        T: Service<BatchControl<Request>>,
        T::Error: Into<crate::BoxError>,
    {
        // The inner batching service being driven.
        service: T,
        // Channel plumbing plus error state shared with `Batch` handles.
        bridge: Bridge<T::Future, Request>,
        // Items collected for the current batch, plus the max-time timer.
        lot: Lot<T::Future>,
        // Current position in the state machine; pinned because `Flushing`
        // holds the inner service's (possibly `!Unpin`) response future.
        #[pin]
        state: State<T::Future>,
    }
}
105
106// ===== impl Worker =====
107
108impl<T, Request> Worker<T, Request>
109where
110 T: Service<BatchControl<Request>>,
111 T::Error: Into<crate::BoxError>,
112{
113 pub(crate) fn new(
114 rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
115 service: T,
116 max_size: usize,
117 max_time: Duration,
118 semaphore: &Arc<Semaphore>,
119 ) -> (Handle, Worker<T, Request>) {
120 trace!("creating Batch worker");
121
122 let handle = Handle {
123 inner: Arc::new(Mutex::new(None)),
124 };
125
126 // The service and worker have a parent - child relationship, so we must
127 // downgrade the Arc to Weak, to ensure a cycle between Arc pointers will
128 // never be deallocated.
129 let semaphore = Arc::downgrade(semaphore);
130 let worker = Self {
131 service,
132 bridge: Bridge {
133 rx,
134 current_message: None,
135 handle: handle.clone(),
136 close: Some(semaphore),
137 failed: None,
138 },
139 lot: Lot::new(max_size, max_time),
140 state: State::Collecting,
141 };
142
143 (handle, worker)
144 }
145}
146
impl<T, Request> Future for Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Error: Into<crate::BoxError>,
{
    type Output = ();

    /// Drives the worker state machine: collects items while the batch has
    /// room and time, flushes through the inner service, and completes when
    /// the request channel closes or the inner service fails.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        trace!("polling worker");

        let mut this = self.project();

        // Flush if the max wait time is reached.
        // (`poll_max_time` only returns `Some` once per batch, so this cannot
        // re-enter `Flushing` in a loop.)
        if let Poll::Ready(Some(())) = this.lot.poll_max_time(cx) {
            this.state.set(State::flushing("time".to_owned(), None))
        }

        loop {
            match this.state.as_mut().project() {
                StateProj::Collecting => {
                    match ready!(this.bridge.poll_next_msg(cx)) {
                        Some((msg, first)) => {
                            // Re-enter the request's tracing span while handling it.
                            let _guard = msg.span.enter();

                            trace!(resumed = !first, message = "worker received request");

                            // Wait for the service to be ready
                            trace!(message = "waiting for service readiness");
                            match this.service.poll_ready(cx) {
                                Poll::Ready(Ok(())) => {
                                    debug!(service.ready = true, message = "adding item");

                                    let response = this.service.call(msg.request.into());
                                    this.lot.add((msg.tx, Ok(response)));

                                    // Flush if the batch is full.
                                    if this.lot.is_full() {
                                        this.state.set(State::flushing("size".to_owned(), None));
                                    }

                                    // Or flush if the max time has elapsed.
                                    if this.lot.poll_max_time(cx).is_ready() {
                                        this.state.set(State::flushing("time".to_owned(), None));
                                    }
                                }
                                Poll::Pending => {
                                    // Leave the span before parking the message.
                                    drop(_guard);
                                    debug!(service.ready = false, message = "delay item addition");
                                    // Stash the message so it is retried first next poll.
                                    this.bridge.return_msg(msg);
                                    return Poll::Pending;
                                }
                                Poll::Ready(Err(e)) => {
                                    drop(_guard);
                                    // Record the failure; `failed` closes the
                                    // channel and semaphore so no new requests
                                    // arrive after the error is exposed.
                                    this.bridge.failed("item addition", e.into());
                                    if let Some(ref e) = this.bridge.failed {
                                        // Ensure the current caller is notified too.
                                        this.lot.add((msg.tx, Err(e.clone())));
                                        this.lot.notify(Some(e.clone()));
                                    }
                                }
                            }
                        }
                        None => {
                            trace!("shutting down, no more requests _ever_");
                            this.state.set(State::Finished);
                            return Poll::Ready(());
                        }
                    }
                }
                StateProj::Flushing { reason, flush_fut } => match flush_fut.as_pin_mut() {
                    // No flush future yet: wait for readiness, then issue the
                    // `BatchControl::Flush` call.
                    None => {
                        trace!(
                            reason = reason.as_mut().unwrap().as_str(),
                            message = "waiting for service readiness"
                        );
                        match this.service.poll_ready(cx) {
                            Poll::Ready(Ok(())) => {
                                debug!(
                                    service.ready = true,
                                    reason = reason.as_mut().unwrap().as_str(),
                                    message = "flushing batch"
                                );
                                let response = this.service.call(BatchControl::Flush);
                                let reason = reason.take().expect("missing reason");
                                this.state.set(State::flushing(reason, Some(response)));
                            }
                            Poll::Pending => {
                                debug!(
                                    service.ready = false,
                                    reason = reason.as_mut().unwrap().as_str(),
                                    message = "delay flush"
                                );
                                return Poll::Pending;
                            }
                            Poll::Ready(Err(e)) => {
                                this.bridge.failed("flush", e.into());
                                if let Some(ref e) = this.bridge.failed {
                                    // Fail every caller waiting on this batch.
                                    this.lot.notify(Some(e.clone()));
                                }
                                this.state.set(State::Finished);
                                return Poll::Ready(());
                            }
                        }
                    }
                    // Flush in progress: drive the inner service's response future.
                    Some(future) => match ready!(future.poll(cx)) {
                        Ok(_) => {
                            debug!(reason = reason.as_mut().unwrap().as_str(), "batch flushed");
                            // Hand each caller its stored response future.
                            this.lot.notify(None);
                            this.state.set(State::Collecting)
                        }
                        Err(e) => {
                            this.bridge.failed("flush", e.into());
                            if let Some(ref e) = this.bridge.failed {
                                this.lot.notify(Some(e.clone()));
                            }
                            this.state.set(State::Finished);
                            return Poll::Ready(());
                        }
                    },
                },
                StateProj::Finished => {
                    // We've already received None and are shutting down
                    return Poll::Ready(());
                }
            }
        }
    }
}
275
276// ===== impl State =====
277
278impl<Fut> State<Fut> {
279 fn flushing(reason: String, f: Option<Fut>) -> Self {
280 Self::Flushing {
281 reason: Some(reason),
282 flush_fut: f,
283 }
284 }
285}
286
287// ===== impl Bridge =====
288
289impl<Fut, Request> Drop for Bridge<Fut, Request> {
290 fn drop(&mut self) {
291 self.close_semaphore()
292 }
293}
294
impl<Fut, Request> Bridge<Fut, Request> {
    /// Closes the buffer's semaphore if it is still open, waking any pending tasks.
    ///
    /// Idempotent: the `Weak` is `take`n on first use, so a later call (e.g.
    /// from `Drop` after `failed` already closed it) is a no-op.
    fn close_semaphore(&mut self) {
        if let Some(close) = self
            .close
            .take()
            .as_ref()
            .and_then(Weak::<Semaphore>::upgrade)
        {
            debug!("buffer closing; waking pending tasks");
            close.close();
        } else {
            trace!("buffer already closed");
        }
    }

    /// Records a failure of the inner service during `action`, poisoning all
    /// current and future requests with clones of the same [`ServiceError`].
    fn failed(&mut self, action: &str, error: crate::BoxError) {
        debug!(action, %error , "service failed");

        // The underlying service failed when we called `poll_ready` on it with the given `error`.
        // We need to communicate this to all the `Batch` handles. To do so, we wrap up the error
        // in an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
        // requests will also fail with the same error.

        // Note that we need to handle the case where some handle is concurrently trying to send us
        // a request. We need to make sure that *either* the send of the request fails *or* it
        // receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
        // case where we send errors to all outstanding requests, and *then* the caller sends its
        // request. We do this by *first* exposing the error, *then* closing the channel used to
        // send more requests (so the client will see the error when the send fails), and *then*
        // sending the error to all outstanding requests.
        let error = ServiceError::new(error);

        let mut inner = self.handle.inner.lock().unwrap();

        if inner.is_some() {
            // Future::poll was called after we've already errored out!
            return;
        }

        *inner = Some(error.clone());
        drop(inner);

        self.rx.close();

        // Wake any tasks waiting on channel capacity.
        self.close_semaphore();

        // By closing the mpsc::Receiver, we know that the run() loop will drain all pending
        // requests. We just need to make sure that any requests that we receive before we've
        // exhausted the receiver receive the error:
        self.failed = Some(error);
    }

    /// Return the next queued Message that hasn't been canceled.
    ///
    /// If a `Message` is returned, the `bool` is true if this is the first time we received this
    /// message, and false otherwise (i.e., we tried to forward it to the backing service before).
    fn poll_next_msg(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<(Message<Request, Fut>, bool)>> {
        trace!("worker polling for next message");

        // Pick any delayed request first
        if let Some(msg) = self.current_message.take() {
            // If the oneshot sender is closed, then the receiver is dropped, and nobody cares about
            // the response. If this is the case, we should continue to the next request.
            if !msg.tx.is_closed() {
                trace!("resuming buffered request");
                return Poll::Ready(Some((msg, false)));
            }

            trace!("dropping cancelled buffered request");
        }

        // Get the next request
        while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
            if !msg.tx.is_closed() {
                trace!("processing new request");
                return Poll::Ready(Some((msg, true)));
            }

            // Otherwise, request is canceled, so pop the next one.
            trace!("dropping cancelled request");
        }

        // Channel closed and fully drained: no more requests will ever arrive.
        Poll::Ready(None)
    }

    /// Stashes `msg` so the next [`poll_next_msg`](Self::poll_next_msg) call
    /// retries it before reading new requests from the channel.
    fn return_msg(&mut self, msg: Message<Request, Fut>) {
        self.current_message = Some(msg)
    }
}
389
390// ===== impl Lot =====
391
392impl<Fut> Lot<Fut> {
393 fn new(max_size: usize, max_time: Duration) -> Self {
394 Self {
395 max_size,
396 max_time,
397 responses: Vec::with_capacity(max_size),
398 time_elapses: None,
399 time_elapsed: false,
400 }
401 }
402
403 fn poll_max_time(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
404 // When the Worker is polled and the time has elapsed, we return `Some` to let the Worker
405 // know it's time to enter the Flushing state. Subsequent polls (e.g. by the Flush future)
406 // will return None to prevent the Worker from getting stuck in an endless loop of entering
407 // the Flushing state.
408 if self.time_elapsed {
409 return Poll::Ready(None);
410 }
411
412 if let Some(ref mut sleep) = self.time_elapses {
413 if Pin::new(sleep).poll(cx).is_ready() {
414 self.time_elapsed = true;
415 return Poll::Ready(Some(()));
416 }
417 }
418
419 Poll::Pending
420 }
421
422 fn is_full(&self) -> bool {
423 self.responses.len() == self.max_size
424 }
425
426 fn add(&mut self, item: (Tx<Fut>, Result<Fut, ServiceError>)) {
427 if self.responses.is_empty() {
428 self.time_elapses = Some(Box::pin(sleep_until(
429 Instant::now().add(self.max_time).into(),
430 )));
431 }
432 self.responses.push(item);
433 }
434
435 fn notify(&mut self, err: Option<ServiceError>) {
436 for (tx, response) in mem::replace(&mut self.responses, Vec::with_capacity(self.max_size)) {
437 if let Some(ref response) = err {
438 let _ = tx.send(Err(response.clone()));
439 } else {
440 let _ = tx.send(response);
441 }
442 }
443 self.time_elapses = None;
444 self.time_elapsed = false;
445 }
446}
447
448// ===== impl Handle =====
449
450impl Handle {
451 pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
452 self.inner
453 .lock()
454 .unwrap()
455 .as_ref()
456 .map(|svc_err| svc_err.clone().into())
457 .unwrap_or_else(|| Closed::new().into())
458 }
459}
460
461impl Clone for Handle {
462 fn clone(&self) -> Self {
463 Handle {
464 inner: self.inner.clone(),
465 }
466 }
467}