cervo_core/inferer/
basic.rs1use super::Inferer;
5use crate::{batcher::ScratchPadView, model_api::ModelApi};
6use anyhow::Result;
7use tract_core::prelude::{tvec, TValue, TVec, Tensor, TractResult, TypedModel, TypedSimplePlan};
8use tract_hir::prelude::InferenceModel;
9
10use super::helpers;
11
12pub struct BasicInferer {
25 model: TypedSimplePlan<TypedModel>,
26 model_api: ModelApi,
27}
28
29impl BasicInferer {
30 pub fn from_model(model: InferenceModel) -> TractResult<Self> {
36 let model_api = ModelApi::for_model(&model)?;
37 let model = helpers::build_model(model, &model_api.inputs, 1i32)?;
38
39 Ok(Self { model, model_api })
40 }
41
42 pub fn from_typed(model: TypedModel) -> TractResult<Self> {
43 let model_api = ModelApi::for_typed_model(&model)?;
44 let model = helpers::build_typed(model, 1i32)?;
45
46 Ok(Self { model, model_api })
47 }
48
49 fn build_inputs(&self, obs: &mut ScratchPadView<'_>) -> Result<TVec<TValue>> {
50 let mut inputs = TVec::default();
51
52 for (idx, (name, shape)) in self.model_api.inputs.iter().enumerate() {
53 assert_eq!(name, obs.input_name(idx));
54
55 let mut full_shape = tvec![1];
56 full_shape.extend_from_slice(shape);
57
58 let total_count: usize = full_shape.iter().product();
59 assert_eq!(total_count, obs.input_slot(idx).len());
60
61 let tensor = Tensor::from_shape(&full_shape, obs.input_slot(idx))?;
62
63 inputs.push(tensor.into());
64 }
65
66 Ok(inputs)
67 }
68}
69
70impl Inferer for BasicInferer {
71 fn select_batch_size(&self, _: usize) -> usize {
72 1
73 }
74
75 fn infer_raw(&self, pad: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
76 let inputs = self.build_inputs(pad)?;
77
78 let result = self.model.run(inputs)?;
80
81 for idx in 0..self.model_api.outputs.iter().len() {
82 let value = result[idx].as_slice::<f32>()?;
83 pad.output_slot_mut(idx).copy_from_slice(value);
84 }
85
86 Ok(())
87 }
88
89 fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
90 &self.model_api.inputs
91 }
92
93 fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
94 &self.model_api.outputs
95 }
96
97 fn begin_agent(&self, _id: u64) {}
98 fn end_agent(&self, _id: u64) {}
99}