use super::{helpers, Inferer};
use crate::{batcher::ScratchPadView, model_api::ModelApi};
use anyhow::{Context, Result};
use tract_core::prelude::{tvec, TValue, TVec, Tensor, TractResult, TypedModel, TypedSimplePlan};
use tract_hir::prelude::InferenceModel;
/// Inferer that pre-compiles one tract execution plan per fixed batch size.
///
/// At inference time the plan whose size exactly equals the incoming batch
/// length is selected (`infer_raw`); a size-1 plan is always available
/// because `fixup_sizes` appends 1 when missing.
pub struct FixedBatchInferer {
    /// Input/output names and shapes shared by every compiled plan.
    model_api: ModelApi,
    /// Plans ordered by descending batch size (see `fixup_sizes`), so
    /// `select_batch_size` finds the largest size that fits first.
    models: Vec<BatchedModel>,
}
/// Normalize the requested batch sizes: guarantee a fallback size of 1 is
/// present, then order the list from largest to smallest so that batch-size
/// selection can pick the biggest plan that fits.
fn fixup_sizes(sizes: &[usize]) -> Vec<usize> {
    let mut normalized: Vec<usize> = sizes.to_vec();
    if !normalized.iter().any(|&size| size == 1) {
        normalized.push(1);
    }
    // Descending order in one pass instead of sort-then-reverse.
    normalized.sort_unstable_by(|a, b| b.cmp(a));
    normalized
}
impl FixedBatchInferer {
    /// Builds an inferer from an inference-mode model, compiling one
    /// optimized plan per batch size in `sizes` (a size-1 plan is always
    /// added; see `fixup_sizes`).
    ///
    /// # Errors
    /// Propagates any failure from API extraction or plan compilation.
    pub fn from_model(model: InferenceModel, sizes: &[usize]) -> TractResult<Self> {
        let model_api = ModelApi::for_model(&model)?;
        let sizes = fixup_sizes(sizes);
        let models = sizes
            .into_iter()
            .map(|size| {
                // `size as i32` matches the helper's signature; sizes are
                // small batch counts so the narrowing cast is safe here.
                helpers::build_model(model.clone(), &model_api.inputs, size as i32)
                    .map(|m| BatchedModel { size, plan: m })
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { models, model_api })
    }

    /// Builds an inferer from an already-typed model; otherwise identical to
    /// [`Self::from_model`].
    ///
    /// # Errors
    /// Propagates any failure from API extraction or plan compilation.
    pub fn from_typed(model: TypedModel, sizes: &[usize]) -> TractResult<Self> {
        // Fixed: previously `&model.clone()` — a full model clone made only
        // to take a reference. Borrowing directly avoids the copy.
        let model_api = ModelApi::for_typed_model(&model)?;
        let sizes = fixup_sizes(sizes);
        let models = sizes
            .into_iter()
            .map(|size| {
                helpers::build_typed(model.clone(), size as i32)
                    .map(|m| BatchedModel { size, plan: m })
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(Self { models, model_api })
    }
}
impl Inferer for FixedBatchInferer {
    /// Runs inference over `batch` using the pre-built plan whose size
    /// matches the batch length exactly; errors if no such plan exists.
    fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
        let wanted = batch.len();
        let plan = self
            .models
            .iter()
            .find(|candidate| candidate.size == wanted)
            .with_context(|| anyhow::anyhow!("looking for a plan with size {:?}", wanted))?;
        plan.execute(batch, &self.model_api)
    }

    /// Returns the largest configured batch size that does not exceed
    /// `max_count`. Panics when `max_count` is 0 — the smallest configured
    /// size is always 1 (guaranteed by `fixup_sizes`), so any positive
    /// `max_count` succeeds.
    fn select_batch_size(&self, max_count: usize) -> usize {
        self.models
            .iter()
            .find(|plan| plan.size <= max_count)
            .map(|plan| plan.size)
            .unwrap()
    }

    /// Input names and per-element shapes (without the batch dimension).
    fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
        self.model_api.inputs.as_slice()
    }

    /// Output names and per-element shapes (without the batch dimension).
    fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
        self.model_api.outputs.as_slice()
    }

    // This inferer keeps no per-agent state, so lifecycle hooks are no-ops.
    fn begin_agent(&self, _id: u64) {}
    fn end_agent(&self, _id: u64) {}
}
/// A compiled tract execution plan specialized for one fixed batch size.
struct BatchedModel {
    /// The exact batch size this plan was compiled for.
    size: usize,
    /// The optimized, ready-to-run plan.
    plan: TypedSimplePlan<TypedModel>,
}
impl BatchedModel {
    /// Assembles the model's input tensors from the scratch-pad slots.
    ///
    /// Each tensor gets shape `[self.size, ..declared_shape]` and is filled
    /// from the corresponding flat input slot. Asserts that the batch length
    /// equals this plan's size, that input names line up with the scratch
    /// pad, and that each slot holds exactly the expected element count.
    fn build_inputs(
        &self,
        batch: &mut ScratchPadView<'_>,
        model_api: &ModelApi,
    ) -> Result<TVec<TValue>> {
        assert_eq!(batch.len(), self.size);
        let mut inputs = TVec::default();
        for (idx, (name, shape)) in model_api.inputs.iter().enumerate() {
            assert_eq!(name, batch.input_name(idx));
            // Prepend the batch dimension to the per-element shape.
            let mut full_shape = tvec![self.size];
            full_shape.extend_from_slice(shape);
            let expected: usize = full_shape.iter().product();
            let slot = batch.input_slot(idx);
            assert_eq!(
                expected,
                slot.len(),
                "mismatched number of features: expected {:?}, got {:?} for shape {:?}",
                expected,
                slot.len(),
                full_shape
            );
            inputs.push(Tensor::from_shape(&full_shape, slot)?.into());
        }
        Ok(inputs)
    }

    /// Runs the plan over the scratch pad: builds the inputs, executes the
    /// plan, and copies every f32 output back into its matching output slot.
    fn execute(&self, pad: &mut ScratchPadView<'_>, model_api: &ModelApi) -> Result<()> {
        let inputs = self.build_inputs(pad, model_api)?;
        let outputs = self.plan.run(inputs)?;
        for idx in 0..model_api.outputs.len() {
            let values = outputs[idx].as_slice::<f32>()?;
            pad.output_slot_mut(idx).copy_from_slice(values);
        }
        Ok(())
    }
}