Skip to main content

tower_batch/
worker.rs

1use std::{
2    future::Future,
3    mem,
4    ops::Add,
5    pin::Pin,
6    sync::{Arc, Mutex, Weak},
7    task::{Context, Poll},
8    time::{Duration, Instant},
9};
10
11use futures_core::ready;
12use tokio::{
13    sync::{mpsc, Semaphore},
14    time::{sleep_until, Sleep},
15};
16use tower::Service;
17use tracing::{debug, trace};
18
19use super::{
20    error::{Closed, ServiceError},
21    message::{Message, Tx},
22    BatchControl,
23};
24
/// Shared error state between [`Batch`](crate::Batch) (client) and [`Worker`].
///
/// When the worker's inner service fails, the error is stored here so that
/// `Batch::poll_ready` can retrieve it and propagate it to callers.
#[derive(Debug)]
pub(crate) struct Handle {
    // `None` until the inner service fails; set exactly once by
    // `Bridge::failed` and read by `get_error_on_closed`.
    inner: Arc<Mutex<Option<ServiceError>>>,
}
33
/// Manages the message channel between [`Batch`](crate::Batch) handles and the [`Worker`].
///
/// Receives incoming requests from the unbounded mpsc channel, holds on to a
/// message when the inner service is not ready (`current_message`), and
/// propagates errors by storing them in the shared [`Handle`] and closing the
/// semaphore so that all waiting `Batch` handles are woken.
#[derive(Debug)]
struct Bridge<Fut, Request> {
    // Incoming requests from all `Batch` handles.
    rx: mpsc::UnboundedReceiver<Message<Request, Fut>>,
    // Shared error slot, also visible to every `Batch` handle.
    handle: Handle,
    // A message taken from `rx` but not yet accepted by the inner service;
    // retried before polling `rx` again.
    current_message: Option<Message<Request, Fut>>,
    // Weak reference to the buffer's semaphore; `take()`n and closed once,
    // either on failure or on drop.
    close: Option<Weak<Semaphore>>,
    // Set once the inner service has failed; used to error out any messages
    // drained after the failure.
    failed: Option<ServiceError>,
}
48
/// Accumulates batch items with their oneshot response senders.
///
/// Tracks the max-time timer (started when the first item arrives) and
/// dispatches results — or errors — to all collected senders on flush via
/// [`notify`](Lot::notify).
#[derive(Debug)]
struct Lot<Fut> {
    // Flush when this many items have been collected.
    max_size: usize,
    // Flush this long after the first item of a batch arrived.
    max_time: Duration,
    // One entry per collected item: the caller's oneshot sender plus either
    // the service's response future or the error to deliver.
    responses: Vec<(Tx<Fut>, Result<Fut, ServiceError>)>,
    // Timer armed when the first item of a batch is added; cleared on flush.
    time_elapses: Option<Pin<Box<Sleep>>>,
    // Latched when the timer fires so `poll_max_time` signals at most once
    // per batch; reset on flush.
    time_elapsed: bool,
}
62
// Worker state machine.
//
// Transitions:
// - `Collecting` → `Flushing`: batch is full (size) or max time elapsed.
// - `Flushing` → `Collecting`: flush succeeded, ready for new items.
// - `Flushing` → `Finished`: flush failed, worker terminates.
// - `Collecting` → `Finished`: channel closed, no more requests.
pin_project_lite::pin_project! {
    #[project = StateProj]
    #[derive(Debug)]
    enum State<Fut> {
        // Accepting new items into the current `Lot`.
        Collecting,
        Flushing {
            // Why the flush was triggered ("size" or "time"); `Option` so it
            // can be moved out when re-entering `Flushing` with a future.
            reason: Option<String>,
            // In-flight flush future; `None` until the inner service reports
            // readiness and `call(BatchControl::Flush)` has been issued.
            #[pin]
            flush_fut: Option<Fut>,
        },
        // Terminal: the channel closed or the inner service failed.
        Finished
    }
}
83
pin_project_lite::pin_project! {
    /// Task that handles processing the buffer. This type should not be used
    /// directly, instead `Batch` requires an `Executor` that can accept this task.
    ///
    /// The struct is `pub` in the private module and the type is *not* re-exported
    /// as part of the public API. This is the "sealed" pattern to include "private"
    /// types in public traits that are not meant for consumers of the library to
    /// implement (only call).
    #[derive(Debug)]
    pub struct Worker<T, Request>
    where
        T: Service<BatchControl<Request>>,
        T::Error: Into<crate::BoxError>,
    {
        // The inner batching service being driven.
        service: T,
        // Channel/error plumbing between `Batch` handles and this worker.
        bridge: Bridge<T::Future, Request>,
        // Items collected for the current batch, plus the max-time timer.
        lot: Lot<T::Future>,
        // Current position in the Collecting/Flushing/Finished state machine.
        #[pin]
        state: State<T::Future>,
    }
}
105
106// ===== impl Worker =====
107
108impl<T, Request> Worker<T, Request>
109where
110    T: Service<BatchControl<Request>>,
111    T::Error: Into<crate::BoxError>,
112{
113    pub(crate) fn new(
114        rx: mpsc::UnboundedReceiver<Message<Request, T::Future>>,
115        service: T,
116        max_size: usize,
117        max_time: Duration,
118        semaphore: &Arc<Semaphore>,
119    ) -> (Handle, Worker<T, Request>) {
120        trace!("creating Batch worker");
121
122        let handle = Handle {
123            inner: Arc::new(Mutex::new(None)),
124        };
125
126        // The service and worker have a parent - child relationship, so we must
127        // downgrade the Arc to Weak, to ensure a cycle between Arc pointers will
128        // never be deallocated.
129        let semaphore = Arc::downgrade(semaphore);
130        let worker = Self {
131            service,
132            bridge: Bridge {
133                rx,
134                current_message: None,
135                handle: handle.clone(),
136                close: Some(semaphore),
137                failed: None,
138            },
139            lot: Lot::new(max_size, max_time),
140            state: State::Collecting,
141        };
142
143        (handle, worker)
144    }
145}
146
impl<T, Request> Future for Worker<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Error: Into<crate::BoxError>,
{
    type Output = ();

    /// Drives the worker state machine: collects items while `Collecting`,
    /// flushes the batch when it is full or the max time elapses, and
    /// terminates when the request channel closes or the service fails.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        trace!("polling worker");

        let mut this = self.project();

        // Flush if the max wait time is reached.
        // NOTE(review): `poll_max_time` latches and signals `Some` at most
        // once per batch, but if the timer fires while a flush future is
        // already in flight this replaces the `Flushing` state and drops that
        // future — confirm the inner service tolerates a re-issued flush.
        if let Poll::Ready(Some(())) = this.lot.poll_max_time(cx) {
            this.state.set(State::flushing("time".to_owned(), None))
        }

        loop {
            match this.state.as_mut().project() {
                StateProj::Collecting => {
                    match ready!(this.bridge.poll_next_msg(cx)) {
                        Some((msg, first)) => {
                            let _guard = msg.span.enter();

                            trace!(resumed = !first, message = "worker received request");

                            // Wait for the service to be ready
                            trace!(message = "waiting for service readiness");
                            match this.service.poll_ready(cx) {
                                Poll::Ready(Ok(())) => {
                                    debug!(service.ready = true, message = "adding item");

                                    // Hand the item to the service; the returned
                                    // future is resolved by the eventual flush.
                                    let response = this.service.call(msg.request.into());
                                    this.lot.add((msg.tx, Ok(response)));

                                    // Flush if the batch is full.
                                    if this.lot.is_full() {
                                        this.state.set(State::flushing("size".to_owned(), None));
                                    }

                                    // Or flush if the max time has elapsed.
                                    if this.lot.poll_max_time(cx).is_ready() {
                                        this.state.set(State::flushing("time".to_owned(), None));
                                    }
                                }
                                Poll::Pending => {
                                    // Drop the span guard before stashing the
                                    // message for the next poll.
                                    drop(_guard);
                                    debug!(service.ready = false, message = "delay item addition");
                                    this.bridge.return_msg(msg);
                                    return Poll::Pending;
                                }
                                Poll::Ready(Err(e)) => {
                                    drop(_guard);
                                    // Record the failure; `failed` closes the
                                    // channel so remaining requests drain out.
                                    this.bridge.failed("item addition", e.into());
                                    if let Some(ref e) = this.bridge.failed {
                                        // Ensure the current caller is notified too.
                                        this.lot.add((msg.tx, Err(e.clone())));
                                        this.lot.notify(Some(e.clone()));
                                    }
                                }
                            }
                        }
                        None => {
                            trace!("shutting down, no more requests _ever_");
                            this.state.set(State::Finished);
                            return Poll::Ready(());
                        }
                    }
                }
                StateProj::Flushing { reason, flush_fut } => match flush_fut.as_pin_mut() {
                    // Flush requested but not yet issued: wait for readiness,
                    // then call the service with `BatchControl::Flush`.
                    None => {
                        trace!(
                            reason = reason.as_mut().unwrap().as_str(),
                            message = "waiting for service readiness"
                        );
                        match this.service.poll_ready(cx) {
                            Poll::Ready(Ok(())) => {
                                debug!(
                                    service.ready = true,
                                    reason = reason.as_mut().unwrap().as_str(),
                                    message = "flushing batch"
                                );
                                let response = this.service.call(BatchControl::Flush);
                                let reason = reason.take().expect("missing reason");
                                this.state.set(State::flushing(reason, Some(response)));
                            }
                            Poll::Pending => {
                                debug!(
                                    service.ready = false,
                                    reason = reason.as_mut().unwrap().as_str(),
                                    message = "delay flush"
                                );
                                return Poll::Pending;
                            }
                            Poll::Ready(Err(e)) => {
                                // Flush could not even start: fail all
                                // collected callers and terminate.
                                this.bridge.failed("flush", e.into());
                                if let Some(ref e) = this.bridge.failed {
                                    this.lot.notify(Some(e.clone()));
                                }
                                this.state.set(State::Finished);
                                return Poll::Ready(());
                            }
                        }
                    }
                    // Flush in flight: wait for it to complete, then deliver
                    // responses and resume collecting (or terminate on error).
                    Some(future) => match ready!(future.poll(cx)) {
                        Ok(_) => {
                            debug!(reason = reason.as_mut().unwrap().as_str(), "batch flushed");
                            this.lot.notify(None);
                            this.state.set(State::Collecting)
                        }
                        Err(e) => {
                            this.bridge.failed("flush", e.into());
                            if let Some(ref e) = this.bridge.failed {
                                this.lot.notify(Some(e.clone()));
                            }
                            this.state.set(State::Finished);
                            return Poll::Ready(());
                        }
                    },
                },
                StateProj::Finished => {
                    // We've already received None and are shutting down
                    return Poll::Ready(());
                }
            }
        }
    }
}
275
276// ===== impl State =====
277
278impl<Fut> State<Fut> {
279    fn flushing(reason: String, f: Option<Fut>) -> Self {
280        Self::Flushing {
281            reason: Some(reason),
282            flush_fut: f,
283        }
284    }
285}
286
287// ===== impl Bridge =====
288
impl<Fut, Request> Drop for Bridge<Fut, Request> {
    fn drop(&mut self) {
        // Wake any `Batch` handles still waiting on the semaphore so they
        // observe the shutdown instead of hanging forever.
        self.close_semaphore()
    }
}
294
impl<Fut, Request> Bridge<Fut, Request> {
    /// Closes the buffer's semaphore if it is still open, waking any pending tasks.
    ///
    /// Idempotent: `close` is `take()`n, so only the first call (from `failed`
    /// or from `Drop`) actually closes the semaphore.
    fn close_semaphore(&mut self) {
        if let Some(close) = self
            .close
            .take()
            .as_ref()
            .and_then(Weak::<Semaphore>::upgrade)
        {
            debug!("buffer closing; waking pending tasks");
            close.close();
        } else {
            trace!("buffer already closed");
        }
    }

    /// Records an inner-service failure and propagates it to all handles.
    ///
    /// The statement ordering below is load-bearing — see the inline comments.
    fn failed(&mut self, action: &str, error: crate::BoxError) {
        debug!(action,  %error , "service failed");

        // The underlying service failed when we called `poll_ready` on it with the given `error`.
        // We need to communicate this to all the `Buffer` handles. To do so, we wrap up the error
        // in an `Arc`, send that `Arc<E>` to all pending requests, and store it so that subsequent
        // requests will also fail with the same error.

        // Note that we need to handle the case where some handle is concurrently trying to send us
        // a request. We need to make sure that *either* the send of the request fails *or* it
        // receives an error on the `oneshot` it constructed. Specifically, we want to avoid the
        // case where we send errors to all outstanding requests, and *then* the caller sends its
        // request. We do this by *first* exposing the error, *then* closing the channel used to
        // send more requests (so the client will see the error when the send fails), and *then*
        // sending the error to all outstanding requests.
        let error = ServiceError::new(error);

        let mut inner = self.handle.inner.lock().unwrap();

        if inner.is_some() {
            // Future::poll was called after we've already errored out!
            return;
        }

        *inner = Some(error.clone());
        drop(inner);

        self.rx.close();

        // Wake any tasks waiting on channel capacity.
        self.close_semaphore();

        // By closing the mpsc::Receiver, we know that that the run() loop will drain all pending
        // requests. We just need to make sure that any requests that we receive before we've
        // exhausted the receiver receive the error:
        self.failed = Some(error);
    }

    /// Return the next queued Message that hasn't been canceled.
    ///
    /// If a `Message` is returned, the `bool` is true if this is the first time we received this
    /// message, and false otherwise (i.e., we tried to forward it to the backing service before).
    fn poll_next_msg(
        &mut self,
        cx: &mut Context<'_>,
    ) -> Poll<Option<(Message<Request, Fut>, bool)>> {
        trace!("worker polling for next message");

        // Pick any delayed request first
        if let Some(msg) = self.current_message.take() {
            // If the oneshot sender is closed, then the receiver is dropped, and nobody cares about
            // the response. If this is the case, we should continue to the next request.
            if !msg.tx.is_closed() {
                trace!("resuming buffered request");
                return Poll::Ready(Some((msg, false)));
            }

            trace!("dropping cancelled buffered request");
        }

        // Get the next request
        while let Some(msg) = ready!(Pin::new(&mut self.rx).poll_recv(cx)) {
            if !msg.tx.is_closed() {
                trace!("processing new request");
                return Poll::Ready(Some((msg, true)));
            }

            // Otherwise, request is canceled, so pop the next one.
            trace!("dropping cancelled request");
        }

        Poll::Ready(None)
    }

    /// Stashes a message the service wasn't ready for; it is retried first
    /// by the next `poll_next_msg` call.
    fn return_msg(&mut self, msg: Message<Request, Fut>) {
        self.current_message = Some(msg)
    }
}
389
390// ===== impl Lot =====
391
392impl<Fut> Lot<Fut> {
393    fn new(max_size: usize, max_time: Duration) -> Self {
394        Self {
395            max_size,
396            max_time,
397            responses: Vec::with_capacity(max_size),
398            time_elapses: None,
399            time_elapsed: false,
400        }
401    }
402
403    fn poll_max_time(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
404        // When the Worker is polled and the time has elapsed, we return `Some` to let the Worker
405        // know it's time to enter the Flushing state. Subsequent polls (e.g. by the Flush future)
406        // will return None to prevent the Worker from getting stuck in an endless loop of entering
407        // the Flushing state.
408        if self.time_elapsed {
409            return Poll::Ready(None);
410        }
411
412        if let Some(ref mut sleep) = self.time_elapses {
413            if Pin::new(sleep).poll(cx).is_ready() {
414                self.time_elapsed = true;
415                return Poll::Ready(Some(()));
416            }
417        }
418
419        Poll::Pending
420    }
421
422    fn is_full(&self) -> bool {
423        self.responses.len() == self.max_size
424    }
425
426    fn add(&mut self, item: (Tx<Fut>, Result<Fut, ServiceError>)) {
427        if self.responses.is_empty() {
428            self.time_elapses = Some(Box::pin(sleep_until(
429                Instant::now().add(self.max_time).into(),
430            )));
431        }
432        self.responses.push(item);
433    }
434
435    fn notify(&mut self, err: Option<ServiceError>) {
436        for (tx, response) in mem::replace(&mut self.responses, Vec::with_capacity(self.max_size)) {
437            if let Some(ref response) = err {
438                let _ = tx.send(Err(response.clone()));
439            } else {
440                let _ = tx.send(response);
441            }
442        }
443        self.time_elapses = None;
444        self.time_elapsed = false;
445    }
446}
447
448// ===== impl Handle =====
449
450impl Handle {
451    pub(crate) fn get_error_on_closed(&self) -> crate::BoxError {
452        self.inner
453            .lock()
454            .unwrap()
455            .as_ref()
456            .map(|svc_err| svc_err.clone().into())
457            .unwrap_or_else(|| Closed::new().into())
458    }
459}
460
461impl Clone for Handle {
462    fn clone(&self) -> Self {
463        Handle {
464            inner: self.inner.clone(),
465        }
466    }
467}