cervo_core/inferer/memoizing.rs

use super::{helpers, Inferer};
use crate::{batcher::ScratchPadView, model_api::ModelApi};
use anyhow::Result;
use parking_lot::{RwLock, RwLockReadGuard, RwLockUpgradableReadGuard, RwLockWriteGuard};
use std::{
    collections::{hash_map::Entry, HashMap},
    ops::Deref,
};
use tract_core::prelude::*;
use tract_hir::prelude::*;

/// The dynamic memoizing batch inferer generates execution plans to
/// fit each batch perfectly, achieving near-optimal performance no
/// matter how much data you have - at a hefty up-front cost for
/// each new batch size.
///
/// The dynamic batcher has the highest potential throughput when the
/// amount of data isn't known. By dynamically generating execution
/// plans to fit the exact number of elements in each batch, it gives
/// tract optimal knowledge for execution every time. The downside is
/// that setting up a new plan is fairly costly, so doing this for a
/// batch size that is only seen once wastes memory and compute
/// resources.
///
/// While plans are cached, this still means that if your batch size
/// can vary greatly, you'll see noticeable latency spikes each time a
/// new plan is generated. If you know you'll have one or a few batch
/// sizes - but not their exact values - this batcher will provide
/// good value and can inform tuning for a fixed batcher later.
///
/// If you know some batch sizes but not all, you can preload the
/// batcher with those plans to avoid having to build them at runtime.
///
/// # Pros
///
/// * Optimal amortized performance for any amount of data
/// * Requires no tuning for good results
///
/// # Cons
///
/// * For small amounts of data and large models the plan-building
///   spikes can significantly offset the amortized gains
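///
/// # Example
///
/// A minimal sketch of preloading. `load_model` is a hypothetical stand-in
/// for however you obtain an [`InferenceModel`], and the import path assumes
/// this type is exported from `cervo_core::inferer`:
///
/// ```no_run
/// use cervo_core::inferer::MemoizingDynamicInferer;
///
/// # fn load_model() -> tract_hir::prelude::InferenceModel { unimplemented!() }
/// # fn main() -> tract_core::prelude::TractResult<()> {
/// // Preload plans for the batch sizes we expect to see often, so the first
/// // inference at those sizes doesn't pay the plan-building cost.
/// let inferer = MemoizingDynamicInferer::from_model(load_model(), &[1, 64])?;
/// # Ok(())
/// # }
/// ```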
pub struct MemoizingDynamicInferer {
    symbol: Symbol,
    model: TypedModel,
    model_api: ModelApi,
    model_cache: RwLock<HashMap<usize, TypedSimplePlan<TypedModel>>>,
}

impl MemoizingDynamicInferer {
    /// Create an inferer for the provided `inference` model.
    ///
    /// # Errors
    ///
    /// Will only forward errors from the [`tract_core::model::Graph`] optimization and graph building steps.
    pub fn from_model(model: InferenceModel, preloaded_sizes: &[usize]) -> TractResult<Self> {
        let model_api = ModelApi::for_model(&model)?;

        let (symbol, model) = helpers::build_symbolic_model(model, &model_api.inputs)?;
        let this = Self {
            symbol,
            model,
            model_api,
            model_cache: Default::default(),
        };

        for size in preloaded_sizes {
            this.get_concrete_model(*size)?;
        }

        Ok(this)
    }

    /// Create an inferer for the provided `typed` model.
    ///
    /// # Errors
    ///
    /// Will only forward errors from the [`tract_core::model::Graph`] optimization and graph building steps.
    pub fn from_typed(mut model: TypedModel, preloaded_sizes: &[usize]) -> TractResult<Self> {
        let model_api = ModelApi::for_typed_model(&model)?;

        let symbol = helpers::build_symbolic_typed(&mut model)?;
        let this = Self {
            symbol,
            model,
            model_api,
            model_cache: Default::default(),
        };

        for size in preloaded_sizes {
            this.get_concrete_model(*size)?;
        }

        Ok(this)
    }

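    /// Build the full input tensors for `batch`, prepending the batch size to
    /// each declared input shape and checking names and element counts
    /// against the scratch pad.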
    fn build_inputs(&self, batch: &mut ScratchPadView<'_>) -> Result<TVec<TValue>> {
        let size = batch.len();

        let mut inputs = TVec::default();

        for (idx, (name, shape)) in self.model_api.inputs.iter().enumerate() {
            assert_eq!(name, batch.input_name(idx));

            // The model API stores per-element shapes; the batch dimension
            // goes in front.
            let mut full_shape = tvec![size];
            full_shape.extend_from_slice(shape);

            let total_count: usize = full_shape.iter().product();
            assert_eq!(total_count, batch.input_slot(idx).len());

            let tensor = Tensor::from_shape(&full_shape, batch.input_slot(idx))?;

            inputs.push(tensor.into());
        }

        Ok(inputs)
    }

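    /// Fetch the memoized plan for `size`, building and caching it on first
    /// use. The upgradable read lock lets a cache miss be promoted to a
    /// write without racing another builder for the same size; the guard is
    /// downgraded to a plain read before the plan is handed back.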
    fn get_concrete_model(
        &self,
        size: usize,
    ) -> Result<impl Deref<Target = TypedSimplePlan<TypedModel>> + '_> {
        let cache = self.model_cache.upgradable_read();
        let cache = {
            if !cache.contains_key(&size) {
                let mut content = RwLockUpgradableReadGuard::upgrade(cache);
                if let Entry::Vacant(e) = content.entry(size) {
                    // Concretize the symbolic batch dimension to `size`, then
                    // optimize and build a runnable plan for exactly this
                    // batch size.
                    let p = self
                        .model
                        .concretize_dims(&SymbolValues::default().with(&self.symbol, size as i64))?
                        .into_optimized()?
                        .into_decluttered()?
                        .into_runnable()?;

                    e.insert(p);
                }

                RwLockWriteGuard::downgrade(content)
            } else {
                RwLockUpgradableReadGuard::downgrade(cache)
            }
        };

        Ok(RwLockReadGuard::map(cache, |c| &c[&size]))
    }
}

impl Inferer for MemoizingDynamicInferer {
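    // A plan can be generated for any batch size, so always accept the full
    // count on offer.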
    fn select_batch_size(&self, max_count: usize) -> usize {
        max_count
    }

    fn infer_raw(&self, pad: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
        let count = pad.len();
        let inputs = self.build_inputs(pad)?;

        let result = self.get_concrete_model(count)?.run(inputs)?;

        // Copy each output tensor back into the scratch pad's output slots.
        for idx in 0..self.model_api.outputs.len() {
            let value = result[idx].as_slice::<f32>()?;
            pad.output_slot_mut(idx).copy_from_slice(value);
        }

        Ok(())
    }

    fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
        &self.model_api.inputs
    }

    fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
        &self.model_api.outputs
    }

    fn begin_agent(&self, _id: u64) {}
    fn end_agent(&self, _id: u64) {}
}