tower_batch/service.rs
use std::{
    fmt::Debug,
    sync::Arc,
    task::{Context, Poll},
};

use futures_core::ready;
use tokio::sync::{mpsc, oneshot, OwnedSemaphorePermit, Semaphore};
use tokio_util::sync::PollSemaphore;
use tower::Service;

use super::{
    future::ResponseFuture,
    message::Message,
    worker::{Handle, Worker},
    BatchControl,
};

/// Allows batch processing of requests.
///
/// See the module documentation for more details.
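///
/// # Example
///
/// A minimal usage sketch, not taken from this crate's tests: `verifier`
/// stands in for any inner service implementing
/// `Service<BatchControl<Item>>`, and `Item` for its request type; both
/// names and the flush parameters are illustrative.
///
/// ```ignore
/// use std::time::Duration;
/// use tower::{Service, ServiceExt};
///
/// // Assumed semantics: batches are capped at 32 items, and a partial
/// // batch is flushed once the time bound elapses.
/// let mut batch = Batch::new(verifier, 32, Duration::from_millis(100));
///
/// // Each caller reserves channel capacity in `poll_ready` (via `ready`),
/// // then hands its request to the batch worker with `call`.
/// let response = batch.ready().await?.call(item).await?;
/// ```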
#[derive(Debug)]
pub struct Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    // Note: this actually _is_ bounded, but rather than using Tokio's bounded
    // channel, we use Tokio's semaphore separately to implement the bound.
    tx: mpsc::UnboundedSender<Message<Request, T::Future>>,

    // When the buffer's channel is full, we want to exert backpressure in
    // `poll_ready`, so that callers such as load balancers can choose to call
    // another service rather than waiting for buffer capacity.
    //
    // Unfortunately, this can't be done easily using Tokio's bounded MPSC
    // channel, because it doesn't expose a polling-based interface, only an
    // `async fn ready` that borrows the sender. Therefore, we implement our
    // own bounded MPSC on top of the unbounded channel, using a semaphore to
    // limit how many items are in the channel.
    semaphore: PollSemaphore,

    // The current semaphore permit, if one has been acquired.
    //
    // This is acquired in `poll_ready` and taken in `call`.
    permit: Option<OwnedSemaphorePermit>,
    handle: Handle,
}

impl<T, Request> Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Error: Into<crate::BoxError>,
{
    /// Creates a new `Batch` wrapping `service`.
    ///
    /// The wrapper is responsible for telling the inner service when to flush a
    /// batch of requests.
    ///
    /// The default Tokio executor is used to run the given service, which means
    /// that this method must be called while on the Tokio runtime.
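    ///
    /// # Example
    ///
    /// A sketch of construction on the Tokio runtime; `verifier` is a
    /// placeholder for a `Service<BatchControl<Item>>` implementation, and the
    /// size and latency values are illustrative.
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// #[tokio::main]
    /// async fn main() {
    ///     // `new` spawns the batch worker onto the ambient Tokio runtime.
    ///     let batch = Batch::new(verifier, 32, Duration::from_millis(100));
    /// }
    /// ```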
    pub fn new(service: T, size: usize, time: std::time::Duration) -> Self
    where
        T: Send + 'static,
        T::Future: Send,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        let (service, worker) = Self::pair(service, size, time);
        tokio::spawn(worker);
        service
    }

    /// Creates a new `Batch` wrapping `service`, but returns the background worker.
    ///
    /// This is useful if you do not want to spawn directly onto the `tokio`
    /// runtime but instead want to use your own executor. This will return the
    /// `Batch` and the background `Worker` that you can then spawn.
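    ///
    /// # Example
    ///
    /// A sketch, assuming you want to drive the worker on your own executor;
    /// `verifier` and `my_executor` are illustrative placeholders.
    ///
    /// ```ignore
    /// use std::time::Duration;
    ///
    /// let (batch, worker) = Batch::pair(verifier, 32, Duration::from_millis(100));
    ///
    /// // Requests sent through `batch` only make progress once the `Worker`
    /// // future is being polled by some executor.
    /// my_executor.spawn(worker);
    /// ```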
    pub fn pair(service: T, size: usize, time: std::time::Duration) -> (Self, Worker<T, Request>)
    where
        T: Send + 'static,
        T::Future: Send,
        T::Error: Send + Sync,
        Request: Send + 'static,
    {
        // The semaphore bound limits the maximum number of concurrent requests
        // (specifically, requests which got a `Ready` from `poll_ready`, but haven't
        // used their semaphore reservation in a `call` yet).
        // We choose a bound that allows callers to check readiness for every item in
        // a batch, then actually submit those items.
        let (tx, rx) = mpsc::unbounded_channel();
        let bound = size;
        let semaphore = Arc::new(Semaphore::new(bound));

        let (handle, worker) = Worker::new(rx, service, size, time, &semaphore);

        let batch = Self {
            tx,
            semaphore: PollSemaphore::new(semaphore),
            permit: None,
            handle,
        };
        (batch, worker)
    }

    fn get_worker_error(&self) -> crate::BoxError {
        self.handle.get_error_on_closed()
    }
}

impl<T, Request> Service<Request> for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
    T::Error: Into<crate::BoxError>,
{
    // Our response is effectively the response of the service used by the Worker.
    type Response = T::Response;
    type Error = crate::BoxError;
    type Future = ResponseFuture<T::Future>;

    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        tracing::debug!("checking if service is ready");

        // First, check if the worker is still alive.
        if self.tx.is_closed() {
            // If the inner service has errored, then we error here.
            return Poll::Ready(Err(self.get_worker_error()));
        }

        // Then, check if we've already acquired a permit.
        if self.permit.is_some() {
            // We've already reserved capacity to send a request. We're ready!
            return Poll::Ready(Ok(()));
        }

        // Finally, if we haven't already acquired a permit, poll the semaphore to
        // acquire one. If we acquire a permit, then there's enough buffer capacity
        // to send a new request. Otherwise, we need to wait for capacity.
        //
        // The current task must be scheduled for wakeup every time we return
        // `Poll::Pending`. If `poll_acquire` returns `Pending`, the semaphore
        // schedules the current task for wakeup when the next permit becomes
        // available.
        let permit =
            ready!(self.semaphore.poll_acquire(cx)).ok_or_else(|| self.get_worker_error())?;
        self.permit = Some(permit);

        Poll::Ready(Ok(()))
    }

    fn call(&mut self, request: Request) -> Self::Future {
        tracing::debug!("sending request to batch worker");

        let _permit = self
            .permit
            .take()
            .expect("batch full; poll_ready must be called first");

        // Get the current Span so that we can explicitly propagate it to the worker.
        // If we didn't do this, events on the worker related to this span wouldn't be
        // counted towards that span, since the worker would have no way of entering it.
        let span = tracing::Span::current();

        // If we've made it here, then a semaphore permit has already been acquired,
        // so we can freely allocate a oneshot.
        let (tx, rx) = oneshot::channel();

        // The worker is in control of completing the request now.
        match self.tx.send(Message {
            request,
            span,
            tx,
            _permit,
        }) {
            Err(_) => ResponseFuture::failed(self.get_worker_error()),
            Ok(_) => ResponseFuture::new(rx),
        }
    }
}

impl<T, Request> Clone for Batch<T, Request>
where
    T: Service<BatchControl<Request>>,
{
    fn clone(&self) -> Self {
        Self {
            tx: self.tx.clone(),
            semaphore: self.semaphore.clone(),
            handle: self.handle.clone(),

            // The new clone hasn't acquired a permit yet. It will acquire one the
            // next time its readiness is polled.
            permit: None,
        }
    }
}