batched_fn/lib.rs
//! Deep learning models are usually implemented to make efficient use of a GPU by batching inputs together
//! in "mini-batches". However, applications serving these models often receive requests one-by-one.
//! So using a conventional single or multi-threaded server approach will under-utilize the GPU and lead to latency that increases
//! linearly with the volume of requests.
//!
//! `batched-fn` is a drop-in solution for deep learning webservers that queues individual requests and provides them as a batch
//! to your model. It can be added to any application with minimal refactoring simply by inserting the [`batched_fn`](crate::batched_fn)
//! macro into the function that runs requests through the model.
//!
//! ## Features
//!
//! - 🚀 Easy to use: drop the `batched_fn!` macro into existing code.
//! - 🔥 Lightweight and fast: queue system implemented on top of the blazingly fast [flume crate](https://github.com/zesterer/flume).
//! - 🙌 Easy to tune: simply adjust [`max_delay`](crate::batched_fn#config) and [`max_batch_size`](crate::batched_fn#config).
//! - 🛑 [Back pressure](https://medium.com/@jayphelps/backpressure-explained-the-flow-of-data-through-software-2350b3e77ce7) mechanism included:
//!   just set [`channel_cap`](crate::batched_fn#config) and handle
//!   [`Error::Full`](crate::Error#variant.Full) by returning a 503 from your webserver.
//!
//! ## Examples
//!
//! Suppose you have a model API that looks like this:
//!
//! ```rust
//! // `Batch` could be anything that implements the `batched_fn::Batch` trait.
//! type Batch<T> = Vec<T>;
//!
//! #[derive(Debug)]
//! struct Input {
//!     // ...
//! }
//!
//! #[derive(Debug)]
//! struct Output {
//!     // ...
//! }
//!
//! struct Model {
//!     // ...
//! }
//!
//! impl Model {
//!     fn predict(&self, batch: Batch<Input>) -> Batch<Output> {
//!         // ...
//!         # batch.iter().map(|_| Output {}).collect()
//!     }
//!
//!     fn load() -> Self {
//!         // ...
//!         # Self {}
//!     }
//! }
//! ```
//!
//! Without `batched-fn` a webserver route would need to call `Model::predict` on each
//! individual input, resulting in a bottleneck from under-utilizing the GPU:
//!
//! ```rust
//! use once_cell::sync::Lazy;
//! # use batched_fn::{batched_fn, Batch as BatchTrait};
//! # type Batch<T> = Vec<T>;
//! # #[derive(Debug)]
//! # struct Input {}
//! # #[derive(Debug)]
//! # struct Output {}
//! # struct Model {}
//! # impl Model {
//! #     fn predict(&self, batch: Batch<Input>) -> Batch<Output> {
//! #         batch.iter().map(|_| Output {}).collect()
//! #     }
//! #     fn load() -> Self { Self {} }
//! # }
//! static MODEL: Lazy<Model> = Lazy::new(Model::load);
//!
//! fn predict_for_http_request(input: Input) -> Output {
//!     let mut batched_input = Batch::with_capacity(1);
//!     batched_input.push(input);
//!     MODEL.predict(batched_input).pop().unwrap()
//! }
//! ```
//!
//! But by dropping the [`batched_fn`](crate::batched_fn) macro into your code you automatically get batched
//! inference behind the scenes without changing the one-to-one relationship between inputs and
//! outputs:
//!
//! ```rust
//! # use batched_fn::{batched_fn, Batch as BatchTrait};
//! # type Batch<T> = Vec<T>;
//! # #[derive(Debug)]
//! # struct Input {}
//! # #[derive(Debug)]
//! # struct Output {}
//! # struct Model {}
//! # impl Model {
//! #     fn predict(&self, batch: Batch<Input>) -> Batch<Output> {
//! #         batch.iter().map(|_| Output {}).collect()
//! #     }
//! #     fn load() -> Self { Self {} }
//! # }
//! async fn predict_for_http_request(input: Input) -> Output {
//!     let batch_predict = batched_fn! {
//!         handler = |batch: Batch<Input>, model: &Model| -> Batch<Output> {
//!             model.predict(batch)
//!         };
//!         config = {
//!             max_batch_size: 16,
//!             max_delay: 50,
//!         };
//!         context = {
//!             model: Model::load(),
//!         };
//!     };
//!     batch_predict(input).await.unwrap()
//! }
//! ```
//!
//! ❗️ *Note that the `predict_for_http_request` function now has to be `async`.*
//!
//! Here we set the [`max_batch_size`](crate::batched_fn#config) to 16 and [`max_delay`](crate::batched_fn#config)
//! to 50 milliseconds. This means the batched function will wait at most 50 milliseconds after receiving a single
//! input to fill a batch of 16. If 15 more inputs are not received within 50 milliseconds
//! then the partial batch will be run as-is.
//!
//! ## Tuning max batch size and max delay
//!
//! The optimal batch size and delay will depend on the specifics of your use case, such as how big of a batch you can fit in memory
//! (typically on the order of 8, 16, 32, or 64 for a deep learning model) and how long of a delay you can afford.
//! In general you want to set `max_batch_size` as high as you can, assuming the total processing time for `N` examples is minimized
//! with a batch size of `N`, and keep `max_delay` small relative to the time it takes for your
//! handler function to process a batch.
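//!
//! For example, suppose (purely illustrative numbers; measure your own model to pick real values) a full
//! batch of 32 takes roughly 20 milliseconds to run. A reasonable starting point might then be:
//!
//! ```rust
//! # use batched_fn::batched_fn;
//! async fn predict(input: i32) -> i32 {
//!     let batch_predict = batched_fn! {
//!         // A toy handler standing in for a real model.
//!         handler = |batch: Vec<i32>| -> Vec<i32> {
//!             batch
//!         };
//!         config = {
//!             // As large as comfortably fits in (GPU) memory.
//!             max_batch_size: 32,
//!             // Kept small relative to the ~20 ms it takes to process a full batch.
//!             max_delay: 10,
//!         };
//!         context = {};
//!     };
//!     batch_predict(input).await.unwrap()
//! }
//! ```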
//!
//! ## Implementation details
//!
//! When the `batched_fn` macro is invoked it spawns a new thread where the
//! [`handler`](crate::batched_fn#handler) will
//! be run. Within that thread, every object specified in the [`context`](crate::batched_fn#context)
//! is initialized and then passed by reference to the handler each time it is run.
//!
//! The object returned by the macro is just a closure that sends a single input and a callback
//! through an asynchronous channel to the handler thread. When the handler finishes
//! running a batch it invokes the callback corresponding to each input with the corresponding output,
//! which triggers the closure to wake up and return the output.
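//!
//! The sketch below illustrates that round trip using plain `std::sync::mpsc` channels, without any
//! batching or `async`. It is a simplification for intuition only, not the code the macro actually
//! generates:
//!
//! ```rust
//! use std::sync::mpsc;
//! use std::thread;
//!
//! // Each request travels to the handler thread together with its own result
//! // channel, which plays the role of the per-input callback.
//! let (tx, rx) = mpsc::channel::<(i32, mpsc::Sender<i32>)>();
//!
//! // Stand-in for the handler thread (here it handles one input at a time).
//! thread::spawn(move || {
//!     while let Ok((input, result_tx)) = rx.recv() {
//!         result_tx.send(input * 2).ok();
//!     }
//! });
//!
//! // Caller side: send the input plus its callback, then wait for the output.
//! let (result_tx, result_rx) = mpsc::channel::<i32>();
//! tx.send((7, result_tx)).unwrap();
//! assert_eq!(result_rx.recv().unwrap(), 14);
//! ```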

extern crate flume;
extern crate once_cell;

#[doc(hidden)]
pub use flume::{bounded, unbounded, Sender};
#[doc(hidden)]
pub use once_cell::sync::Lazy;

/// The `Batch` trait is essentially an abstraction of `Vec<T>`. The input and output of a batch
/// [`handler`](crate::batched_fn#handler) must implement `Batch`.
///
/// It represents an owned collection of ordered items of a single type.
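///
/// For example, a minimal sketch of a custom batch type (the `FloatBatch` type here is hypothetical,
/// not part of this crate) might look like this:
///
/// ```rust
/// use batched_fn::Batch;
///
/// // A hypothetical batch that owns one feature vector per input.
/// struct FloatBatch {
///     items: Vec<Vec<f32>>,
/// }
///
/// // `Batch` requires `IntoIterator` over the same item type.
/// impl IntoIterator for FloatBatch {
///     type Item = Vec<f32>;
///     type IntoIter = std::vec::IntoIter<Vec<f32>>;
///
///     fn into_iter(self) -> Self::IntoIter {
///         self.items.into_iter()
///     }
/// }
///
/// impl Batch for FloatBatch {
///     type Item = Vec<f32>;
///
///     fn with_capacity(n: usize) -> Self {
///         FloatBatch { items: Vec::with_capacity(n) }
///     }
///
///     fn len(&self) -> usize {
///         self.items.len()
///     }
///
///     fn push(&mut self, item: Vec<f32>) {
///         self.items.push(item);
///     }
/// }
/// ```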
pub trait Batch: IntoIterator<Item = <Self as Batch>::Item> {
    type Item;

    fn with_capacity(n: usize) -> Self;

    fn len(&self) -> usize;

    fn push(&mut self, item: <Self as Batch>::Item);

    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl<T> Batch for Vec<T> {
    type Item = T;

    fn with_capacity(n: usize) -> Vec<T> {
        Vec::<T>::with_capacity(n)
    }

    fn len(&self) -> usize {
        self.len()
    }

    fn push(&mut self, item: T) {
        self.push(item);
    }
}

#[doc(hidden)]
pub struct Config {
    pub max_batch_size: usize,
    pub max_delay: u128,
    pub channel_cap: Option<usize>,
    // Used to avoid clippy linting errors within the macro-generated code
    // when updating the fields of this struct.
    pub _phantom: std::marker::PhantomData<bool>,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            max_batch_size: 8,
            max_delay: 50,
            channel_cap: None,
            _phantom: std::marker::PhantomData,
        }
    }
}

/// Error types that can occur while calling a batched function.
#[derive(Debug, Copy, Clone)]
pub enum Error {
    /// Channel is full.
    ///
    /// This can happen if you've set `channel_cap`, and should usually be handled
    /// by returning a 503 error code from your server to signal that the server is too
    /// busy at the moment to handle any more requests.
    Full,

    /// Channel has been disconnected, most likely due to the handler thread crashing.
    Disconnected,
}

/// Created by the [`batched_fn`](crate::batched_fn) macro.
///
/// A `BatchedFn` is a wrapper around a [`handler`](crate::batched_fn#handler)
/// that provides the interface for evaluating a single input as part of a batch of other inputs.
pub struct BatchedFn<T, R>
where
    T: 'static + Send + Sync + std::fmt::Debug,
    R: 'static + Send + Sync + std::fmt::Debug,
{
    tx: Sender<(T, Sender<R>)>,
}

impl<T, R> BatchedFn<T, R>
where
    T: 'static + Send + Sync + std::fmt::Debug,
    R: 'static + Send + Sync + std::fmt::Debug,
{
    pub fn new(tx: Sender<(T, Sender<R>)>) -> Self {
        Self { tx }
    }

    /// Evaluate a single input as part of a batch of other inputs.
    pub async fn evaluate_in_batch(&self, input: T) -> Result<R, Error> {
        // Can use `unbounded` channel because we already get backpressure from
        // the channel that `self.tx` sends to.
        let (result_tx, result_rx) = unbounded::<R>();
        self.tx.try_send((input, result_tx)).map_err(|e| match e {
            flume::TrySendError::Full(_) => Error::Full,
            flume::TrySendError::Disconnected(_) => Error::Disconnected,
        })?;
        result_rx
            .recv_async()
            .await
            .map_err(|_| Error::Disconnected)
    }
}

#[doc(hidden)]
#[macro_export]
macro_rules! __batched_fn_internal {
    (
        handler = |$batch:ident: $batch_input_type:ty $(, $ctx_arg:ident: &$ctx_arg_ty:ty )*| -> $batch_output_type:ty $fn_body:block ;
        config = {
            $( $cfg:ident: $cfg_init:expr ),* $(,)?
        };
        context = {
            $( $ctx:ident: $ctx_init:expr ),* $(,)?
        } $(;)?
    ) => {{
        static BATCHED_FN: $crate::Lazy<
            $crate::BatchedFn<
                <$batch_input_type as $crate::Batch>::Item,
                <$batch_output_type as $crate::Batch>::Item,
            >,
        > = $crate::Lazy::new(|| {
            let config = $crate::Config {
                $( $cfg: $cfg_init, )*
                ..Default::default()
            };

            let (tx, mut rx) = match config.channel_cap {
                None => {
                    $crate::unbounded::<(
                        <$batch_input_type as $crate::Batch>::Item,
                        $crate::Sender<<$batch_output_type as $crate::Batch>::Item>,
                    )>()
                }
                Some(cap) => {
                    $crate::bounded::<(
                        <$batch_input_type as $crate::Batch>::Item,
                        $crate::Sender<<$batch_output_type as $crate::Batch>::Item>,
                    )>(cap)
                }
            };

            std::thread::spawn(move || {
                // Create handler closure.
                let handler = |$batch: $batch_input_type $(, $ctx_arg: &$ctx_arg_ty )*| -> $batch_output_type {
                    $fn_body
                };

                // Set config vars.
                let max_batch_size: usize = config.max_batch_size;
                let max_delay: u128 = config.max_delay;

                // Initialize handler context.
                struct _Context {
                    $( $ctx_arg: $ctx_arg_ty, )*
                }

                let context = _Context {
                    $( $ctx: $ctx_init, )*
                };

                // Wait for an input.
                while let Ok((input, result_tx)) = rx.recv() {
                    let mut batch_input =
                        <$batch_input_type as $crate::Batch>::with_capacity(max_batch_size);
                    let mut batch_txs = Vec::with_capacity(max_batch_size);
                    batch_input.push(input);
                    batch_txs.push(result_tx);

                    let mut vacancy = max_batch_size - 1;
                    let mut time_left = max_delay as u64;
                    let start = std::time::Instant::now();

                    // While there is still room in the batch we'll wait at most `max_delay`
                    // milliseconds to try to fill it.
                    while vacancy > 0 && time_left > 0 {
                        if let Ok((next_input, next_result_tx)) =
                            rx.recv_timeout(std::time::Duration::from_millis(time_left))
                        {
                            batch_input.push(next_input);
                            batch_txs.push(next_result_tx);
                            vacancy -= 1;
                            let elapsed = start.elapsed().as_millis();
                            time_left = if elapsed > max_delay {
                                0
                            } else {
                                (max_delay - elapsed) as u64
                            };
                        } else {
                            break;
                        }
                    }

                    let batch_output = handler(batch_input $(, &context.$ctx_arg )*);
                    for (output, mut result_tx) in batch_output.into_iter().zip(batch_txs) {
                        result_tx.send(output).ok();
                    }
                }
            });

            $crate::BatchedFn::new(tx)
        });

        |input| BATCHED_FN.evaluate_in_batch(input)
    }};
}

/// Macro for creating a batched function.
///
/// This macro has 3 parameters: [`handler`](#handler), [`config`](#config), and
/// [`context`](#context). It returns an async function that wraps
/// [`BatchedFn::evaluate_in_batch`](struct.BatchedFn.html#method.evaluate_in_batch).
///
/// # Parameters
///
/// ### `handler`
///
/// The handler must be in the form of a closure declaration that takes a batch
/// and any number of references to objects in the context as input and
/// returns a different type of batch.
///
/// ### `config`
///
/// Within the config you can specify the `max_batch_size`, `max_delay`, and `channel_cap`.
///
/// The batched function will wait at most `max_delay` milliseconds after receiving a single
/// input to fill a batch of size `max_batch_size`. If enough inputs to fill a full batch
/// are not received within `max_delay` milliseconds then the partial batch will be run as-is.
///
/// The `channel_cap` option allows you to apply back pressure if too many inputs are waiting for
/// the handler thread to accept another batch. By default `channel_cap` is `None`, but if
/// set to `Some(usize)` then
/// [`BatchedFn::evaluate_in_batch`](struct.BatchedFn.html#method.evaluate_in_batch) will
/// return [`Error::Full`](crate::Error#variant.Full) if the channel between the calling thread and the handler thread is at this
/// capacity. You probably want to set this to some multiple of `max_batch_size`.
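///
/// For example, a hypothetical webserver helper might map the error to an HTTP status code
/// (the `http_status_for` function here is illustrative, not part of this crate):
///
/// ```rust
/// use batched_fn::Error;
///
/// fn http_status_for(err: &Error) -> u16 {
///     match err {
///         // The queue is at `channel_cap`: ask clients to back off.
///         Error::Full => 503,
///         // The handler thread is gone: treat it as an internal server error.
///         Error::Disconnected => 500,
///     }
/// }
/// ```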
///
/// ### `context`
///
/// Any additional reference that the handler takes as input must be defined within
/// the context.
///
/// # Examples
///
/// ```rust
/// # #[macro_use] extern crate batched_fn;
/// use batched_fn::{batched_fn, Error};
///
/// async fn double(x: i32) -> Result<i32, Error> {
///     let batched_double = batched_fn! {
///         handler = |batch: Vec<i32>| -> Vec<i32> {
///             batch.into_iter().map(|x| x * 2).collect()
///         };
///         config = {
///             max_batch_size: 4,
///             max_delay: 50,
///             channel_cap: Some(20),
///         };
///         context = {};
///     };
///
///     batched_double(x).await
/// }
/// ```
///
/// You can also provide an arbitrary number of additional arguments to the handler by reference.
/// All of the objects have to be initialized in the [`context`](#context):
///
/// ```rust
/// # #[macro_use] extern crate batched_fn;
/// # use batched_fn::{batched_fn, Error};
/// async fn multiply(x: i32) -> Result<i32, Error> {
///     let batched_multiply = batched_fn! {
///         handler = |batch: Vec<i32>, factor: &i32| -> Vec<i32> {
///             batch.into_iter().map(|x| *factor * x).collect()
///         };
///         config = {
///             max_batch_size: 4,
///             max_delay: 50
///         };
///         context = {
///             factor: 3
///         };
///     };
///
///     batched_multiply(x).await
/// }
/// ```
#[macro_export]
macro_rules! batched_fn {
    (
        handler = |$batch:ident: $batch_input_type:ty $(, $ctx_arg:ident: &$ctx_arg_ty:ty )*| -> $batch_output_type:ty $fn_body:block ;
        config = {
            $( $cfg:ident: $cfg_init:expr ),* $(,)?
        };
        context = {
            $( $ctx:ident: $ctx_init:expr ),* $(,)?
        } $(;)?
    ) => {
        $crate::__batched_fn_internal!(
            handler = |$batch: $batch_input_type $(, $ctx_arg: &$ctx_arg_ty )*| -> $batch_output_type $fn_body ;
            config = {
                $( $cfg: $cfg_init, )*
            };
            context = {
                $( $ctx: $ctx_init, )*
            };
        );
    };
}