// cervo_core/batcher.rs

// Author: Tom Solberg <tom.solberg@embark-studios.com>
// Copyright © 2022, Embark Studios, all rights reserved.
// Created: 22 July 2022

/*!
Tools for batching and batched execution.

Batching leads to lower memory pressure by reusing data-gathering
allocations, and to higher performance by being able to run larger
kernels. This is especially noticeable for networks with large matrix
multiplications where the weights do not fit in the CPU cache.
*/
13
14mod scratch;
15mod wrapper;
16
17use self::scratch::ScratchPad;
18use crate::inferer::{Inferer, Response, State};
19pub use scratch::ScratchPadView;
20use std::collections::HashMap;
21pub use wrapper::Batched;
22
/// Low-level batch builder to help transition from per-entity code to
/// batched inference. Consider using the [`Batched`] wrapper instead
/// to avoid tracking two objects.
///
/// Reusing this across frames gives a noticeable performance benefit
/// for large model inputs or outputs, and reduces memory pressure.
///
/// Note that Batchers are specific to the inferer used for
/// initialization.
pub struct Batcher {
    // Reusable input/output storage, sized for the inferer's tensor
    // shapes, that accumulates per-entity data between executions.
    scratch: ScratchPad,
}
36
37impl Batcher {
38    /// Create a new batcher for the provided inferer.
39    pub fn new(inferer: &dyn Inferer) -> Self {
40        Self {
41            scratch: ScratchPad::new_for_shapes(
42                inferer.raw_input_shapes(),
43                inferer.raw_output_shapes(),
44            ),
45        }
46    }
47
48    /// Create a new batcher for the provided inferer with space for the specified batch size.
49    pub fn new_sized(inferer: &dyn Inferer, size: usize) -> Self {
50        Self {
51            scratch: ScratchPad::new_with_size(
52                inferer.raw_input_shapes(),
53                inferer.raw_output_shapes(),
54                size,
55            ),
56        }
57    }
58
59    #[inline]
60    fn input_slot(&self, name: &str) -> Option<usize> {
61        self.scratch
62            .inputs
63            .iter()
64            .position(|slot| slot.name == name)
65    }
66
67    /// Insert a single element into the batch to include in the next execution.
68    pub fn push(&mut self, id: u64, state: State<'_>) -> anyhow::Result<()> {
69        self.scratch.next(id);
70        for (k, v) in state.data {
71            let slot = self
72                .input_slot(k)
73                .ok_or_else(|| anyhow::anyhow!("key doesn't match an input: {:?}", k))?;
74
75            self.scratch.push(slot, v);
76        }
77
78        Ok(())
79    }
80
81    /// Insert a sequence of elements into the batch to include in the next execution.
82    pub fn extend<'a, Iter: IntoIterator<Item = (u64, State<'a>)>>(
83        &mut self,
84        states: Iter,
85    ) -> anyhow::Result<()> {
86        for (id, state) in states {
87            self.push(id, state)?;
88        }
89
90        Ok(())
91    }
92
93    /// Run the provided inferer on the data that has been enqueued previously.
94    pub fn execute<'b>(
95        &mut self,
96        inferer: &'b dyn Inferer,
97    ) -> anyhow::Result<HashMap<u64, Response<'b>>> {
98        // pick up as many items as possible (by slicing the stores) and feed into the model.
99        // this builds up a set of output stores that'll feed in sequence.
100        let mut total_offset = 0;
101        while self.scratch.batch_size > 0 {
102            let preferred_batch_size = inferer.select_batch_size(self.scratch.batch_size);
103
104            let mut view = self.scratch.chunk(total_offset, preferred_batch_size);
105
106            inferer.infer_raw(&mut view)?;
107            total_offset += preferred_batch_size;
108        }
109
110        let mut outputs = vec![Response::empty(); self.scratch.ids.len()];
111
112        for slot in 0..inferer.output_shapes().len() {
113            let slot_name = &inferer.output_shapes()[slot].0;
114            let scratch_slot = self
115                .scratch
116                .lookup_output_slot(slot_name)
117                .expect("invalid inferer passed to `Batcher::execute`");
118
119            for (idx, o) in outputs.iter_mut().enumerate() {
120                let slot_response = self.scratch.output_slot(scratch_slot, idx..idx + 1);
121                o.data.insert(slot_name, slot_response.to_owned());
122            }
123        }
124
125        Ok(self.scratch.ids.drain(..).zip(outputs).collect::<_>())
126    }
127
128    /// Check if there is any data to run on here.
129    pub fn is_empty(&self) -> bool {
130        self.scratch.batch_size == 0
131    }
132
133    /// Amount of elements to run on in the batch here.
134    pub fn len(&self) -> usize {
135        self.scratch.batch_size
136    }
137}