mod scratch;
mod wrapper;
use self::scratch::ScratchPad;
use crate::inferer::{Inferer, Response, State};
pub use scratch::ScratchPadView;
use std::collections::HashMap;
pub use wrapper::Batched;
/// Collects per-id input [`State`]s into a shared scratch pad so they can be
/// run through an [`Inferer`] in batches via [`Batcher::execute`].
pub struct Batcher {
    scratch: ScratchPad,
}

impl Batcher {
    /// Creates a batcher whose scratch pad is laid out for `inferer`'s raw
    /// input and output shapes.
    pub fn new(inferer: &dyn Inferer) -> Self {
        Self {
            scratch: ScratchPad::new_for_shapes(
                inferer.raw_input_shapes(),
                inferer.raw_output_shapes(),
            ),
        }
    }

    /// Like [`Batcher::new`], but forwards `size` to the scratch pad
    /// constructor (presumably an initial item capacity — TODO confirm
    /// against `ScratchPad::new_with_size`).
    pub fn new_sized(inferer: &dyn Inferer, size: usize) -> Self {
        Self {
            scratch: ScratchPad::new_with_size(
                inferer.raw_input_shapes(),
                inferer.raw_output_shapes(),
                size,
            ),
        }
    }

    /// Returns the index of the scratch input slot named `name`, if any.
    fn input_slot(&self, name: &str) -> Option<usize> {
        self.scratch
            .inputs
            .iter()
            .position(|slot| slot.name == name)
    }

    /// Queues one item identified by `id`.
    ///
    /// Every key of `state.data` must name one of the scratch pad's inputs.
    ///
    /// # Errors
    ///
    /// Returns an error if any key does not match an input. All keys are
    /// resolved before the batcher is modified, so a rejected state leaves
    /// the batcher unchanged.
    pub fn push(&mut self, id: u64, state: State<'_>) -> anyhow::Result<()> {
        // Resolve every key up front: `scratch.next(id)` starts a new item,
        // and bailing out after it would leave a partially written entry
        // queued for the next `execute`.
        let mut resolved = Vec::new();
        for (k, v) in state.data {
            let slot = self
                .input_slot(k)
                .ok_or_else(|| anyhow::anyhow!("key doesn't match an input: {:?}", k))?;
            resolved.push((slot, v));
        }

        self.scratch.next(id);
        for (slot, v) in resolved {
            self.scratch.push(slot, v);
        }
        Ok(())
    }

    /// Queues every `(id, state)` pair from `states`, in iteration order.
    ///
    /// # Errors
    ///
    /// Stops at the first state rejected by [`Batcher::push`] and returns
    /// its error. States pushed before the failing one remain queued.
    pub fn extend<'a, Iter: IntoIterator<Item = (u64, State<'a>)>>(
        &mut self,
        states: Iter,
    ) -> anyhow::Result<()> {
        for (id, state) in states {
            self.push(id, state)?;
        }
        Ok(())
    }

    /// Runs inference on all queued items and returns each item's outputs
    /// keyed by the id it was pushed with.
    ///
    /// The queue is fed to `inferer` in chunks whose size is chosen by
    /// `inferer.select_batch_size`.
    ///
    /// # Errors
    ///
    /// Returns an error if `select_batch_size` yields 0 while items are
    /// still queued, or if `infer_raw` fails. On error the queued ids are
    /// not drained.
    ///
    /// # Panics
    ///
    /// Panics if one of `inferer`'s output names has no matching scratch
    /// output slot, i.e. `inferer` is not the one this batcher was built for.
    pub fn execute<'b>(
        &mut self,
        inferer: &'b dyn Inferer,
    ) -> anyhow::Result<HashMap<u64, Response<'b>>> {
        let mut total_offset = 0;
        while self.scratch.batch_size > 0 {
            let preferred_batch_size = inferer.select_batch_size(self.scratch.batch_size);
            // A zero-sized chunk consumes nothing, so the loop would spin forever.
            anyhow::ensure!(
                preferred_batch_size > 0,
                "inferer selected a batch size of 0 with {} items queued",
                self.scratch.batch_size
            );
            let mut view = self.scratch.chunk(total_offset, preferred_batch_size);
            inferer.infer_raw(&mut view)?;
            total_offset += preferred_batch_size;
        }

        // Split each output slot back into one response per queued item.
        let mut outputs = vec![Response::empty(); self.scratch.ids.len()];
        for shape in inferer.output_shapes().iter() {
            let slot_name = &shape.0;
            let scratch_slot = self
                .scratch
                .lookup_output_slot(slot_name)
                .expect("invalid inferer passed to `Batcher::execute`");
            for (idx, o) in outputs.iter_mut().enumerate() {
                let slot_response = self.scratch.output_slot(scratch_slot, idx..idx + 1);
                o.data.insert(slot_name, slot_response.to_owned());
            }
        }

        Ok(self.scratch.ids.drain(..).zip(outputs).collect())
    }

    /// Returns `true` when no items are queued.
    pub fn is_empty(&self) -> bool {
        self.scratch.batch_size == 0
    }

    /// Returns the number of queued items (the scratch pad's current batch size).
    pub fn len(&self) -> usize {
        self.scratch.batch_size
    }
}