// Author: Tom Olsson <tom.olsson@embark-studios.com>
// Copyright © 2020, Embark Studios, all rights reserved.
// Created: 18 May 2020

#![warn(clippy::all)]

/*!
Inferers are the main access point for cervo, providing a higher-level API on top of `tract`. Regardless of
which inferer flavour you choose, cervo tries to provide a uniform API for batched, dictionary-based inference.

Dictionary-based inference comes with a performance overhead, but helps maintain some generality. Our use-case
hasn't shown the overhead to be significant enough to warrant the fiddliness of other approaches - interning,
slot-markers - or the fragility of delegating input-building a layer up.
14
## Choosing an inferer

<p style="background:rgba(255,181,77,0.16);padding:0.75em;"> <strong>Note:</strong> Our inferer setup has been iterated
on since 2019, and we've gone through a few variants and tested a number of different inference setups. See the rule of
thumb for selecting an inferer below, but we suggest benchmarking your own workload. While undocumented, you can use
the code in the `perf-test` folder on GitHub to run various benchmarks.</p>

Cervo currently provides four different inferers: two that we've used historically (basic and fixed) and two based on
newer tract functionality that we haven't tested as much yet. You'll find more detail on each inferer's page, but here
is a quick rundown of the various use cases:

| Inferer   | Batch size   | Memory use                                          | Performance                            |
| --------- | ------------ | --------------------------------------------------- | -------------------------------------- |
| Basic     | 1            | Fixed                                               | Linear with number of elements         |
| Fixed     | Known, exact | Fixed, linear with number of configured batch sizes | Optimal if exact match                 |
| Memoizing | Unknown      | Linear with number of batch sizes                   | Optimal, high cost for new batch sizes |
| Dynamic   | Unknown      | Fixed                                               | Good scaling, but high overhead        |

As a rule of thumb: use a basic inferer if you'll almost always pass a single item. If you need more items and know
exactly how many, use a fixed inferer. Otherwise, use a memoizing inferer if you can afford the latency spikes for new
batch sizes and the potential memory growth. As a last resort, use the dynamic inferer, which keeps memory use fixed
at the cost of extra overhead per batch.
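
## Example

A minimal sketch of building and querying an inferer. Here `provider` is a placeholder for any concrete
[`InfererProvider`], and `"observation"` is a hypothetical input name, so the snippet is not compiled as a doc-test.
Note that `infer_single` comes from the [`InfererExt`] extension trait.

```ignore
// Build the simplest (batch size 1) inferer flavour.
let inferer = InfererBuilder::new(provider).build_basic()?;

// Fill one element's named inputs.
let mut state = State::empty();
state.data.insert("observation", vec![0.0; 10]);

let response = inferer.infer_single(state)?;
```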
 */

use anyhow::{Error, Result};
use std::collections::HashMap;
40
mod basic;
mod dynamic;
mod fixed;
mod helpers;
mod memoizing;

pub use basic::BasicInferer;
pub use dynamic::DynamicInferer;
pub use fixed::FixedBatchInferer;
pub use memoizing::MemoizingDynamicInferer;

use crate::{
    batcher::{Batched, Batcher, ScratchPadView},
    epsilon::{EpsilonInjector, NoiseGenerator},
};

/// The data of one element in a batch.
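///
/// Each entry in `data` maps an input name to that element's flat `f32`
/// values. A minimal sketch, using a hypothetical input named `"observation"`:
///
/// ```ignore
/// let mut state = State::empty();
/// state.data.insert("observation", vec![0.0, 1.0, 2.0]);
/// ```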
#[derive(Clone, Debug)]
pub struct State<'a> {
    pub data: HashMap<&'a str, Vec<f32>>,
}

impl<'a> State<'a> {
    /// Create a new empty state to fill with data.
    pub fn empty() -> Self {
        Self {
            data: Default::default(),
        }
    }
}

/// The output for one batch element.
#[derive(Clone, Debug, Default)]
pub struct Response<'a> {
    pub data: HashMap<&'a str, Vec<f32>>,
}

impl<'a> Response<'a> {
    /// Create a new empty response to fill with data.
    pub fn empty() -> Self {
        Self {
            data: Default::default(),
        }
    }
}

/// The main workhorse shared by all components in Cervo.
pub trait Inferer {
    /// Query the inferer for how many elements it can handle in a
    /// single batch, given `max_count` pending elements.
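    ///
    /// Callers can use this to split pending elements into chunks the model
    /// supports; a sketch of such a loop, assuming some `inferer` and `n`
    /// pending elements:
    ///
    /// ```ignore
    /// let mut remaining = n;
    /// while remaining > 0 {
    ///     let chunk = inferer.select_batch_size(remaining);
    ///     // ... gather `chunk` elements and run them through the inferer ...
    ///     remaining -= chunk;
    /// }
    /// ```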
    fn select_batch_size(&self, max_count: usize) -> usize;

    /// Execute the model on the provided pre-batched data.
    fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error>;

    /// Retrieve the names and shapes of the model inputs. This covers
    /// only the external API, so code-based transforms outside the
    /// model are hidden.
    fn input_shapes(&self) -> &[(String, Vec<usize>)] {
        self.raw_input_shapes()
    }

    /// Retrieve the names and shapes of the model outputs. This covers
    /// only the external API, so code-based transforms outside the
    /// model are hidden.
    fn output_shapes(&self) -> &[(String, Vec<usize>)] {
        self.raw_output_shapes()
    }

    /// Retrieve the names and shapes of the raw model inputs.
    fn raw_input_shapes(&self) -> &[(String, Vec<usize>)];

    /// Retrieve the names and shapes of the raw model outputs.
    fn raw_output_shapes(&self) -> &[(String, Vec<usize>)];

    /// Notify the inferer that an agent with the given `id` has been added.
    fn begin_agent(&self, id: u64);

    /// Notify the inferer that the agent with the given `id` has been removed.
    fn end_agent(&self, id: u64);
}

/// Helper trait providing construction methods for loadable models.
pub trait InfererProvider {
    /// Build a [`BasicInferer`].
    fn build_basic(self) -> Result<BasicInferer>;

    /// Build a [`FixedBatchInferer`].
    fn build_fixed(self, sizes: &[usize]) -> Result<FixedBatchInferer>;

    /// Build a [`MemoizingDynamicInferer`].
    fn build_memoizing(self, preload_sizes: &[usize]) -> Result<MemoizingDynamicInferer>;

    /// Build a [`DynamicInferer`].
    fn build_dynamic(self) -> Result<DynamicInferer>;
}

/// Builder for inferers.
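///
/// A minimal sketch, assuming `provider` is some concrete
/// [`InfererProvider`] and the workload always batches 2 or 4 elements:
///
/// ```ignore
/// let inferer = InfererBuilder::new(provider).build_fixed(&[2, 4])?;
/// ```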
pub struct InfererBuilder<P: InfererProvider> {
    provider: P,
}

impl<P> InfererBuilder<P>
where
    P: InfererProvider,
{
    /// Begin the building process from the provided model provider.
    pub fn new(provider: P) -> Self {
        Self { provider }
    }

    /// Build a [`BasicInferer`].
    pub fn build_basic(self) -> Result<BasicInferer> {
        self.provider.build_basic()
    }

    /// Build a [`FixedBatchInferer`].
    pub fn build_fixed(self, sizes: &[usize]) -> Result<FixedBatchInferer> {
        self.provider.build_fixed(sizes)
    }

    /// Build a [`DynamicInferer`].
    pub fn build_dynamic(self) -> Result<DynamicInferer> {
        self.provider.build_dynamic()
    }

    /// Build a [`MemoizingDynamicInferer`].
    pub fn build_memoizing(self, preload_sizes: &[usize]) -> Result<MemoizingDynamicInferer> {
        self.provider.build_memoizing(preload_sizes)
    }
}

/// Extension trait adding convenience methods to [`Inferer`].
// TODO[TSolberg]: This was intended to be part of the builder, but it becomes an awful state machine and is hard to extend.
pub trait InfererExt: Inferer + Sized {
    /// Add an epsilon injector using the default noise kind.
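    ///
    /// A sketch, assuming the model takes a noise input named `"epsilon"`:
    ///
    /// ```ignore
    /// let inferer = inferer.with_default_epsilon("epsilon")?;
    /// ```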
    fn with_default_epsilon(self, key: &str) -> Result<EpsilonInjector<Self>> {
        EpsilonInjector::wrap(self, key)
    }

    /// Add an epsilon injector with a specific noise generator.
    fn with_epsilon<G: NoiseGenerator>(
        self,
        generator: G,
        key: &str,
    ) -> Result<EpsilonInjector<Self, G>> {
        EpsilonInjector::with_generator(self, generator, key)
    }

    /// Wrap in a batching interface.
    fn into_batched(self) -> Batched<Self> {
        Batched::wrap(self)
    }

    /// Execute the model on the provided batch of elements.
    #[deprecated(
        note = "Please use the more explicit 'infer_batch' instead.",
        since = "0.3.0"
    )]
    fn infer(
        &mut self,
        observations: HashMap<u64, State<'_>>,
    ) -> Result<HashMap<u64, Response<'_>>, Error> {
        self.infer_batch(observations)
    }

    /// Execute the model on the provided batch of elements, keyed by agent id.
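    ///
    /// A minimal sketch, with hypothetical agent ids and an input named
    /// `"observation"`:
    ///
    /// ```ignore
    /// let mut batch = HashMap::new();
    /// for id in [0u64, 1, 2] {
    ///     let mut state = State::empty();
    ///     state.data.insert("observation", vec![0.0; 10]);
    ///     batch.insert(id, state);
    /// }
    /// let responses = inferer.infer_batch(batch)?;
    /// ```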
    fn infer_batch<'this>(
        &'this self,
        batch: HashMap<u64, State<'_>>,
    ) -> Result<HashMap<u64, Response<'this>>, anyhow::Error> {
        let mut batcher = Batcher::new_sized(self, batch.len());
        batcher.extend(batch)?;

        batcher.execute(self)
    }

    /// Execute the model on a single element.
    fn infer_single<'this>(
        &'this self,
        input: State<'_>,
    ) -> Result<Response<'this>, anyhow::Error> {
        let mut batcher = Batcher::new_sized(self, 1);
        batcher.push(0, input)?;

        Ok(batcher.execute(self)?.remove(&0).unwrap())
    }
}

impl<T> InfererExt for T where T: Inferer + Sized {}

impl Inferer for Box<dyn Inferer + Send> {
    fn select_batch_size(&self, max_count: usize) -> usize {
        self.as_ref().select_batch_size(max_count)
    }

    fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
        self.as_ref().infer_raw(batch)
    }

    fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
        self.as_ref().raw_input_shapes()
    }

    fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
        self.as_ref().raw_output_shapes()
    }

    fn begin_agent(&self, id: u64) {
        self.as_ref().begin_agent(id);
    }

    fn end_agent(&self, id: u64) {
        self.as_ref().end_agent(id);
    }
}

impl Inferer for Box<dyn Inferer> {
    fn select_batch_size(&self, max_count: usize) -> usize {
        self.as_ref().select_batch_size(max_count)
    }

    fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
        self.as_ref().infer_raw(batch)
    }

    fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
        self.as_ref().raw_input_shapes()
    }

    fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
        self.as_ref().raw_output_shapes()
    }

    fn begin_agent(&self, id: u64) {
        self.as_ref().begin_agent(id);
    }

    fn end_agent(&self, id: u64) {
        self.as_ref().end_agent(id);
    }
}