batched_fn/
lib.rs

//! Deep learning models are usually implemented to make efficient use of a GPU by batching inputs together
//! in "mini-batches". However, applications serving these models often receive requests one-by-one.
//! So using a conventional single or multi-threaded server approach will under-utilize the GPU and lead to latency that increases
//! linearly with the volume of requests.
//!
//! `batched-fn` is a drop-in solution for deep learning webservers that queues individual requests and provides them as a batch
//! to your model. It can be added to any application with minimal refactoring simply by inserting the [`batched_fn`](crate::batched_fn)
//! macro into the function that runs requests through the model.
//!
//! ## Features
//!
//! - 🚀 Easy to use: drop the `batched_fn!` macro into existing code.
//! - 🔥 Lightweight and fast: queue system implemented on top of the blazingly fast [flume crate](https://github.com/zesterer/flume).
//! - 🙌 Easy to tune: simply adjust [`max_delay`](crate::batched_fn#config) and [`max_batch_size`](crate::batched_fn#config).
//! - 🛑 [Back pressure](https://medium.com/@jayphelps/backpressure-explained-the-flow-of-data-through-software-2350b3e77ce7) mechanism included:
//!   just set [`channel_cap`](crate::batched_fn#config) and handle
//!   [`Error::Full`](crate::Error#variant.Full) by returning a 503 from your webserver.
//!
//! ## Examples
//!
//! Suppose you have a model API that looks like this:
//!
//! ```rust
//! // `Batch` could be anything that implements the `batched_fn::Batch` trait.
//! type Batch<T> = Vec<T>;
//!
//! #[derive(Debug)]
//! struct Input {
//!     // ...
//! }
//!
//! #[derive(Debug)]
//! struct Output {
//!     // ...
//! }
//!
//! struct Model {
//!     // ...
//! }
//!
//! impl Model {
//!     fn predict(&self, batch: Batch<Input>) -> Batch<Output> {
//!         // ...
//!         # batch.iter().map(|_| Output {}).collect()
//!     }
//!
//!     fn load() -> Self {
//!         // ...
//!         # Self {}
//!     }
//! }
//! ```
//!
//! Without `batched-fn` a webserver route would need to call `Model::predict` on each
//! individual input, resulting in a bottleneck from under-utilizing the GPU:
//!
//! ```rust
//! use once_cell::sync::Lazy;
//! # use batched_fn::{batched_fn, Batch as BatchTrait};
//! # type Batch<T> = Vec<T>;
//! # #[derive(Debug)]
//! # struct Input {}
//! # #[derive(Debug)]
//! # struct Output {}
//! # struct Model {}
//! # impl Model {
//! #     fn predict(&self, batch: Batch<Input>) -> Batch<Output> {
//! #         batch.iter().map(|_| Output {}).collect()
//! #     }
//! #     fn load() -> Self { Self {} }
//! # }
//! static MODEL: Lazy<Model> = Lazy::new(Model::load);
//!
//! fn predict_for_http_request(input: Input) -> Output {
//!     let mut batched_input = Batch::with_capacity(1);
//!     batched_input.push(input);
//!     MODEL.predict(batched_input).pop().unwrap()
//! }
//! ```
//!
//! But by dropping the [`batched_fn`](crate::batched_fn) macro into your code you automatically get batched
//! inference behind the scenes without changing the one-to-one relationship between inputs and
//! outputs:
//!
//! ```rust
//! # use batched_fn::{batched_fn, Batch as BatchTrait};
//! # type Batch<T> = Vec<T>;
//! # #[derive(Debug)]
//! # struct Input {}
//! # #[derive(Debug)]
//! # struct Output {}
//! # struct Model {}
//! # impl Model {
//! #     fn predict(&self, batch: Batch<Input>) -> Batch<Output> {
//! #         batch.iter().map(|_| Output {}).collect()
//! #     }
//! #     fn load() -> Self { Self {} }
//! # }
//! async fn predict_for_http_request(input: Input) -> Output {
//!     let batch_predict = batched_fn! {
//!         handler = |batch: Batch<Input>, model: &Model| -> Batch<Output> {
//!             model.predict(batch)
//!         };
//!         config = {
//!             max_batch_size: 16,
//!             max_delay: 50,
//!         };
//!         context = {
//!             model: Model::load(),
//!         };
//!     };
//!     batch_predict(input).await.unwrap()
//! }
//! ```
//!
//! ❗️ *Note that the `predict_for_http_request` function now has to be `async`.*
//!
//! Here we set the [`max_batch_size`](crate::batched_fn#config) to 16 and [`max_delay`](crate::batched_fn#config)
//! to 50 milliseconds. This means the batched function will wait at most 50 milliseconds after receiving a single
//! input to fill a batch of 16. If 15 more inputs are not received within 50 milliseconds
//! then the partial batch will be run as-is.
//!
//! ## Tuning max batch size and max delay
//!
//! The optimal batch size and delay will depend on the specifics of your use case, such as how large a batch you can fit in memory
//! (typically on the order of 8, 16, 32, or 64 for a deep learning model) and how long a delay you can afford.
//! In general you want to set `max_batch_size` as high as you can, assuming the total processing time for `N` examples is minimized
//! with a batch size of `N`, and keep `max_delay` small relative to the time it takes for your
//! handler function to process a batch.
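//!
//! For example, if running a full batch through your handler takes on the order of 100 milliseconds,
//! a config like the following (illustrative numbers only, assuming a batch of 32 fits in memory)
//! keeps the added queueing latency small relative to the model's compute time:
//!
//! ```rust
//! # use batched_fn::{batched_fn, Error};
//! # async fn predict(x: i32) -> Result<i32, Error> {
//! let batch_predict = batched_fn! {
//!     handler = |batch: Vec<i32>| -> Vec<i32> {
//!         // Stand-in for a model forward pass that takes ~100 ms per batch.
//!         batch.into_iter().map(|x| x * 2).collect()
//!     };
//!     config = {
//!         // As large a batch as comfortably fits in memory ...
//!         max_batch_size: 32,
//!         // ... with a delay that is small next to the ~100 ms batch runtime.
//!         max_delay: 10,
//!     };
//!     context = {};
//! };
//! # batch_predict(x).await
//! # }
//! ```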
//!
//! ## Implementation details
//!
//! When the `batched_fn` macro is invoked it spawns a new thread where the
//! [`handler`](crate::batched_fn#handler) will
//! be run. Within that thread, every object specified in the [`context`](crate::batched_fn#context)
//! is initialized and then passed by reference to the handler each time it is run.
//!
//! The object returned by the macro is just a closure that sends a single input and a callback
//! through an asynchronous channel to the handler thread. When the handler finishes
//! running a batch it invokes each input's callback with the corresponding output,
//! which wakes the closure up so it can return the output.
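//!
//! As a rough sketch of that request/callback mechanism, using `std::sync::mpsc` in place of flume
//! for simplicity (the real implementation also collects whole batches and awaits the result
//! asynchronously):
//!
//! ```rust
//! use std::sync::mpsc::{channel, Sender};
//! use std::thread;
//!
//! // Each request pairs an input with a sender on which the handler thread
//! // will send back the corresponding output.
//! type Request = (i32, Sender<i32>);
//!
//! let (tx, rx) = channel::<Request>();
//!
//! // The "handler thread": a real handler would first collect up to `max_batch_size`
//! // inputs (waiting at most `max_delay`) before running the model on the batch.
//! thread::spawn(move || {
//!     while let Ok((input, result_tx)) = rx.recv() {
//!         result_tx.send(input * 2).ok();
//!     }
//! });
//!
//! // The caller side: send one input along with a callback channel, then wait.
//! let (result_tx, result_rx) = channel::<i32>();
//! tx.send((21, result_tx)).unwrap();
//! assert_eq!(result_rx.recv().unwrap(), 42);
//! ```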

extern crate flume;
extern crate once_cell;

#[doc(hidden)]
pub use flume::{bounded, unbounded, Sender};
#[doc(hidden)]
pub use once_cell::sync::Lazy;

/// The `Batch` trait is essentially an abstraction of `Vec<T>`. The input and output of a batch
/// [`handler`](crate::batched_fn#handler) must implement `Batch`.
///
/// It represents an owned collection of ordered items of a single type.
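///
/// A minimal sketch of code that is generic over any `Batch` implementation, such as `Vec<T>`
/// (the `collect_batch` helper below is hypothetical and only for illustration):
///
/// ```rust
/// use batched_fn::Batch;
///
/// // Collect items into any type that implements `Batch`.
/// fn collect_batch<B: Batch<Item = i32>>(items: impl IntoIterator<Item = i32>) -> B {
///     let mut batch = B::with_capacity(4);
///     for item in items {
///         batch.push(item);
///     }
///     batch
/// }
///
/// let batch: Vec<i32> = collect_batch(vec![1, 2, 3]);
/// assert_eq!(batch.len(), 3);
/// ```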
pub trait Batch: IntoIterator<Item = <Self as Batch>::Item> {
    type Item;

    fn with_capacity(n: usize) -> Self;

    fn len(&self) -> usize;

    fn push(&mut self, item: <Self as Batch>::Item);

    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}

impl<T> Batch for Vec<T> {
    type Item = T;

    fn with_capacity(n: usize) -> Vec<T> {
        Vec::<T>::with_capacity(n)
    }

    fn len(&self) -> usize {
        self.len()
    }

    fn push(&mut self, item: T) {
        self.push(item);
    }
}

#[doc(hidden)]
pub struct Config {
    pub max_batch_size: usize,
    pub max_delay: u128,
    pub channel_cap: Option<usize>,
    // Used to avoid clippy linting errors within the macro-generated code
    // when updating the fields of this struct.
    pub _phantom: std::marker::PhantomData<bool>,
}

impl Default for Config {
    fn default() -> Self {
        Self {
            max_batch_size: 8,
            max_delay: 50,
            channel_cap: None,
            _phantom: std::marker::PhantomData,
        }
    }
}

/// Error types that can occur while calling a batched function.
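///
/// In a webserver these errors typically map onto HTTP status codes (e.g. 503 for `Full`).
/// A minimal sketch (the `status_code_for` helper is hypothetical, not part of this crate):
///
/// ```rust
/// use batched_fn::Error;
///
/// fn status_code_for(err: Error) -> u16 {
///     match err {
///         // Too many inputs are already queued; ask the client to retry later.
///         Error::Full => 503,
///         // The handler thread has crashed; treat as an internal server error.
///         Error::Disconnected => 500,
///     }
/// }
///
/// assert_eq!(status_code_for(Error::Full), 503);
/// ```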
#[derive(Debug, Copy, Clone)]
pub enum Error {
    /// Channel is full.
    ///
    /// This can happen if you've set `channel_cap`, and should usually be handled
    /// by returning a 503 error code from your server to signal that the server is too
    /// busy at the moment to handle any more requests.
    Full,

    /// Channel has been disconnected, most likely due to the handler thread crashing.
    Disconnected,
}

/// Created by the [`batched_fn`](crate::batched_fn) macro.
///
/// A `BatchedFn` is a wrapper around a [`handler`](crate::batched_fn#handler)
/// that provides the interface for evaluating a single input as part of a batch of other inputs.
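///
/// You normally obtain a `BatchedFn` through the [`batched_fn`](crate::batched_fn) macro, but the
/// wiring can also be done by hand. A minimal sketch (no batching logic, one input at a time,
/// using the crate's hidden channel re-exports purely for illustration):
///
/// ```rust
/// # async fn example() -> Result<(), batched_fn::Error> {
/// use batched_fn::{unbounded, BatchedFn, Sender};
///
/// let (tx, rx) = unbounded::<(i32, Sender<i32>)>();
///
/// // Handler thread: echo each input doubled. A real handler would collect a batch first.
/// std::thread::spawn(move || {
///     while let Ok((input, result_tx)) = rx.recv() {
///         result_tx.send(input * 2).ok();
///     }
/// });
///
/// let batched: BatchedFn<i32, i32> = BatchedFn::new(tx);
/// assert_eq!(batched.evaluate_in_batch(21).await?, 42);
/// # Ok(())
/// # }
/// ```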
pub struct BatchedFn<T, R>
where
    T: 'static + Send + Sync + std::fmt::Debug,
    R: 'static + Send + Sync + std::fmt::Debug,
{
    tx: Sender<(T, Sender<R>)>,
}

impl<T, R> BatchedFn<T, R>
where
    T: 'static + Send + Sync + std::fmt::Debug,
    R: 'static + Send + Sync + std::fmt::Debug,
{
    pub fn new(tx: Sender<(T, Sender<R>)>) -> Self {
        Self { tx }
    }

    /// Evaluate a single input as part of a batch of other inputs.
    pub async fn evaluate_in_batch(&self, input: T) -> Result<R, Error> {
        // Can use `unbounded` channel because we already get backpressure from
        // the channel that `self.tx` sends to.
        let (result_tx, result_rx) = unbounded::<R>();
        self.tx.try_send((input, result_tx)).map_err(|e| match e {
            flume::TrySendError::Full(_) => Error::Full,
            flume::TrySendError::Disconnected(_) => Error::Disconnected,
        })?;
        result_rx
            .recv_async()
            .await
            .map_err(|_| Error::Disconnected)
    }
}

#[doc(hidden)]
#[macro_export]
macro_rules! __batched_fn_internal {
    (
        handler = |$batch:ident: $batch_input_type:ty $(, $ctx_arg:ident: &$ctx_arg_ty:ty )*| -> $batch_output_type:ty $fn_body:block ;
        config = {
            $( $cfg:ident: $cfg_init:expr ),* $(,)?
        };
        context = {
            $( $ctx:ident: $ctx_init:expr ),* $(,)?
        } $(;)?
    ) => {{
        static BATCHED_FN: $crate::Lazy<
            $crate::BatchedFn<
                <$batch_input_type as $crate::Batch>::Item,
                <$batch_output_type as $crate::Batch>::Item,
            >,
        > = $crate::Lazy::new(|| {
            let config = $crate::Config {
                $( $cfg: $cfg_init, )*
                ..Default::default()
            };

            let (tx, mut rx) = match config.channel_cap {
                None => {
                    $crate::unbounded::<(
                        <$batch_input_type as $crate::Batch>::Item,
                        $crate::Sender<<$batch_output_type as $crate::Batch>::Item>,
                    )>()
                }
                Some(cap) => {
                    $crate::bounded::<(
                        <$batch_input_type as $crate::Batch>::Item,
                        $crate::Sender<<$batch_output_type as $crate::Batch>::Item>,
                    )>(cap)
                }
            };

            std::thread::spawn(move || {
                // Create handler closure.
                let handler = |$batch: $batch_input_type $(, $ctx_arg: &$ctx_arg_ty )*| -> $batch_output_type {
                    $fn_body
                };

                // Set config vars.
                let max_batch_size: usize = config.max_batch_size;
                let max_delay: u128 = config.max_delay;

                // Initialize handler context.
                struct _Context {
                    $( $ctx_arg: $ctx_arg_ty, )*
                }

                let context = _Context {
                    $( $ctx: $ctx_init, )*
                };

                // Wait for an input.
                while let Ok((input, result_tx)) = rx.recv() {
                    let mut batch_input =
                        <$batch_input_type as $crate::Batch>::with_capacity(max_batch_size);
                    let mut batch_txs = Vec::with_capacity(max_batch_size);
                    batch_input.push(input);
                    batch_txs.push(result_tx);

                    let mut vacancy = max_batch_size - 1;
                    let mut time_left = max_delay as u64;
                    let start = std::time::Instant::now();

                    // While there is still room in the batch we'll wait at most `max_delay`
                    // milliseconds to try to fill it.
                    while vacancy > 0 && time_left > 0 {
                        if let Ok((next_input, next_result_tx)) =
                            rx.recv_timeout(std::time::Duration::from_millis(time_left))
                        {
                            batch_input.push(next_input);
                            batch_txs.push(next_result_tx);
                            vacancy -= 1;
                            let elapsed = start.elapsed().as_millis();
                            time_left = if elapsed > max_delay {
                                0
                            } else {
                                (max_delay - elapsed) as u64
                            };
                        } else {
                            break;
                        }
                    }

                    let batch_output = handler(batch_input $(, &context.$ctx_arg )*);
                    for (output, mut result_tx) in batch_output.into_iter().zip(batch_txs) {
                        result_tx.send(output).ok();
                    }
                }
            });

            $crate::BatchedFn::new(tx)
        });

        |input| BATCHED_FN.evaluate_in_batch(input)
    }};

}

/// Macro for creating a batched function.
///
/// This macro has 3 parameters: [`handler`](#handler), [`config`](#config), and
/// [`context`](#context). It returns an async function that wraps
/// [`BatchedFn::evaluate_in_batch`](struct.BatchedFn.html#method.evaluate_in_batch).
///
/// # Parameters
///
/// ### `handler`
///
/// The handler must be in the form of a closure declaration that takes a batch
/// and any number of references to objects in the context as input and
/// returns a different type of batch.
///
/// ### `config`
///
/// Within the config you can specify the `max_batch_size`, `max_delay`, and `channel_cap`.
///
/// The batched function will wait at most `max_delay` milliseconds after receiving a single
/// input to fill a batch of size `max_batch_size`. If enough inputs to fill a full batch
/// are not received within `max_delay` milliseconds then the partial batch will be run as-is.
///
/// The `channel_cap` option allows you to apply back pressure if too many inputs are waiting for
/// the handler thread to accept another batch. By default `channel_cap` is `None`, but if
/// set to `Some(usize)` then
/// [`BatchedFn::evaluate_in_batch`](struct.BatchedFn.html#method.evaluate_in_batch) will
/// return [`Error::Full`](crate::Error#variant.Full) if the channel between the calling thread and the handler thread is at this
/// capacity. You probably want to set this to some multiple of `max_batch_size`.
///
/// ### `context`
///
/// Any additional reference that the handler takes as input must be defined within
/// the context.
///
/// # Examples
///
/// ```rust
/// # #[macro_use] extern crate batched_fn;
/// use batched_fn::{batched_fn, Error};
///
/// async fn double(x: i32) -> Result<i32, Error> {
///     let batched_double = batched_fn! {
///         handler = |batch: Vec<i32>| -> Vec<i32> {
///             batch.into_iter().map(|x| x*2).collect()
///         };
///         config = {
///             max_batch_size: 4,
///             max_delay: 50,
///             channel_cap: Some(20),
///         };
///         context = {};
///     };
///
///     batched_double(x).await
/// }
/// ```
///
/// You can also provide an arbitrary number of additional arguments to the handler by reference.
/// All of the objects have to be initialized in the [`context`](#context):
///
/// ```rust
/// # #[macro_use] extern crate batched_fn;
/// # use batched_fn::{batched_fn, Error};
/// async fn multiply(x: i32) -> Result<i32, Error> {
///     let batched_multiply = batched_fn! {
///         handler = |batch: Vec<i32>, factor: &i32| -> Vec<i32> {
///             batch.into_iter().map(|x| *factor * x ).collect()
///         };
///         config = {
///             max_batch_size: 4,
///             max_delay: 50
///         };
///         context = {
///             factor: 3
///         };
///     };
///
///     batched_multiply(x).await
/// }
/// ```
#[macro_export]
macro_rules! batched_fn {
    (
        handler = |$batch:ident: $batch_input_type:ty $(, $ctx_arg:ident: &$ctx_arg_ty:ty )*| -> $batch_output_type:ty $fn_body:block ;
        config = {
            $( $cfg:ident: $cfg_init:expr ),* $(,)?
        };
        context = {
            $( $ctx:ident: $ctx_init:expr ),* $(,)?
        } $(;)?
    ) => {
        $crate::__batched_fn_internal!(
            handler = |$batch: $batch_input_type $(, $ctx_arg: &$ctx_arg_ty )*| -> $batch_output_type $fn_body ;
            config = {
                $( $cfg: $cfg_init, )*
            };
            context = {
                $( $ctx: $ctx_init, )*
            };
        );
    };
}