cervo_core/inferer/
basic.rs

1/*!
2A basic unbatched inferer that doesn't require a lot of custom setup or management.
3 */
4use super::Inferer;
5use crate::{batcher::ScratchPadView, model_api::ModelApi};
6use anyhow::Result;
7use tract_core::prelude::{tvec, TValue, TVec, Tensor, TractResult, TypedModel, TypedSimplePlan};
8use tract_hir::prelude::InferenceModel;
9
10use super::helpers;
11
12/// The most basic inferer provided will deal with a single element at
13/// a time, at the cost of reduced (but predictable) performance per
14/// element.
15///
16/// # Pros
17///
18/// * Requires no tuning
19/// * Very predictable performance across different workloads
20///
21/// # Cons
22///
23/// * Scales linearly unless it's the only code executing
24pub struct BasicInferer {
25    model: TypedSimplePlan<TypedModel>,
26    model_api: ModelApi,
27}
28
29impl BasicInferer {
30    /// Create an inferer for the provided `inference` model.
31    ///
32    /// # Errors
33    ///
34    /// Will only forward errors from the [`tract_core::model::Graph`] optimization and graph building steps.
35    pub fn from_model(model: InferenceModel) -> TractResult<Self> {
36        let model_api = ModelApi::for_model(&model)?;
37        let model = helpers::build_model(model, &model_api.inputs, 1i32)?;
38
39        Ok(Self { model, model_api })
40    }
41
42    pub fn from_typed(model: TypedModel) -> TractResult<Self> {
43        let model_api = ModelApi::for_typed_model(&model)?;
44        let model = helpers::build_typed(model, 1i32)?;
45
46        Ok(Self { model, model_api })
47    }
48
49    fn build_inputs(&self, obs: &mut ScratchPadView<'_>) -> Result<TVec<TValue>> {
50        let mut inputs = TVec::default();
51
52        for (idx, (name, shape)) in self.model_api.inputs.iter().enumerate() {
53            assert_eq!(name, obs.input_name(idx));
54
55            let mut full_shape = tvec![1];
56            full_shape.extend_from_slice(shape);
57
58            let total_count: usize = full_shape.iter().product();
59            assert_eq!(total_count, obs.input_slot(idx).len());
60
61            let tensor = Tensor::from_shape(&full_shape, obs.input_slot(idx))?;
62
63            inputs.push(tensor.into());
64        }
65
66        Ok(inputs)
67    }
68}
69
70impl Inferer for BasicInferer {
71    fn select_batch_size(&self, _: usize) -> usize {
72        1
73    }
74
75    fn infer_raw(&self, pad: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
76        let inputs = self.build_inputs(pad)?;
77
78        // Run the optimized plan to get actions back!
79        let result = self.model.run(inputs)?;
80
81        for idx in 0..self.model_api.outputs.iter().len() {
82            let value = result[idx].as_slice::<f32>()?;
83            pad.output_slot_mut(idx).copy_from_slice(value);
84        }
85
86        Ok(())
87    }
88
89    fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
90        &self.model_api.inputs
91    }
92
93    fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
94        &self.model_api.outputs
95    }
96
97    fn begin_agent(&self, _id: u64) {}
98    fn end_agent(&self, _id: u64) {}
99}