use crate::batcher::ScratchPadView;
use crate::inferer::{
BasicInferer, DynamicInferer, FixedBatchInferer, Inferer, MemoizingDynamicInferer,
};
/// A layer that sits between callers and an [`Inferer`] and may intercept every call to it.
///
/// Each method receives the inferer being wrapped so an implementation can delegate to it
/// (as `BaseWrapper` does by forwarding every call unchanged).
pub trait InfererWrapper {
    /// Input slots (name, shape) as exposed to callers of the wrapped inferer.
    ///
    /// NOTE(review): presumably a wrapper may report shapes different from the inferer's own
    /// if it adds or consumes inputs — confirm against concrete wrapper implementations.
    fn input_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)];
    /// Output slots (name, shape) as exposed to callers of the wrapped inferer.
    fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)];
    /// Runs inference for `batch`, typically by (eventually) calling `inferer.infer_raw`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the wrapper itself or from the wrapped inferer.
    fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()>;
    /// Hook invoked for agent `id` when the wrapping inferer's `begin_agent` is called.
    fn begin_agent(&self, inferer: &dyn Inferer, id: u64);
    /// Hook invoked for agent `id` when the wrapping inferer's `end_agent` is called.
    fn end_agent(&self, inferer: &dyn Inferer, id: u64);
}
/// The identity wrapper: forwards every call unchanged to the wrapped [`Inferer`].
///
/// Useful whenever an [`InfererWrapper`] is required but no interception is wanted.
/// It is a zero-sized unit type, so it derives the common traits (cheap to copy, default-
/// constructible, and debuggable when embedded in larger types).
#[derive(Debug, Clone, Copy, Default)]
pub struct BaseWrapper;

impl InfererWrapper for BaseWrapper {
    /// Returns the inferer's own `input_shapes()` unchanged.
    fn input_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
        inferer.input_shapes()
    }

    /// Returns the inferer's own `output_shapes()` unchanged.
    fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
        inferer.output_shapes()
    }

    /// Runs the inferer directly on `batch` via `infer_raw`.
    fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> {
        inferer.infer_raw(batch)
    }

    /// Forwards the agent-start notification to the inferer.
    fn begin_agent(&self, inferer: &dyn Inferer, id: u64) {
        inferer.begin_agent(id);
    }

    /// Forwards the agent-end notification to the inferer.
    fn end_agent(&self, inferer: &dyn Inferer, id: u64) {
        inferer.end_agent(id);
    }
}
/// Lets a type-erased wrapper (`Box<dyn InfererWrapper>`) be used anywhere a concrete
/// `InfererWrapper` is expected, for example as the `WrapStack` of a `StatefulInferer`.
///
/// Every method dereferences the box and dispatches to the boxed wrapper's implementation.
impl InfererWrapper for Box<dyn InfererWrapper> {
    fn input_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
        (**self).input_shapes(inferer)
    }

    fn output_shapes<'a>(&'a self, inferer: &'a dyn Inferer) -> &'a [(String, Vec<usize>)] {
        (**self).output_shapes(inferer)
    }

    fn invoke(&self, inferer: &dyn Inferer, batch: &mut ScratchPadView<'_>) -> anyhow::Result<()> {
        // Bind the trait object explicitly so dispatch clearly goes through the vtable
        // rather than back into this impl.
        let boxed: &dyn InfererWrapper = &**self;
        boxed.invoke(inferer, batch)
    }

    fn begin_agent(&self, inferer: &dyn Inferer, id: u64) {
        let boxed: &dyn InfererWrapper = &**self;
        boxed.begin_agent(inferer, id);
    }

    fn end_agent(&self, inferer: &dyn Inferer, id: u64) {
        let boxed: &dyn InfererWrapper = &**self;
        boxed.end_agent(inferer, id);
    }
}
/// An [`Inferer`] paired with a wrapper stack that every inference and agent-lifecycle call
/// is routed through.
///
/// The `InfererWrapper`/`Inferer` bounds live on the `impl` blocks rather than on the struct
/// definition, so the type can be named in other generic code without repeating them.
pub struct StatefulInferer<WrapStack, Inf> {
    /// Wrapper (or stack of wrappers) that receives every call before the inferer.
    wrapper_stack: WrapStack,
    /// The wrapped inferer; handed to each wrapper-stack call.
    inferer: Inf,
}
impl<WrapStack: InfererWrapper, Inf: Inferer> StatefulInferer<WrapStack, Inf> {
    /// Creates a stateful inferer that routes calls on `inferer` through `wrapper_stack`.
    pub fn new(wrapper_stack: WrapStack, inferer: Inf) -> Self {
        Self {
            wrapper_stack,
            inferer,
        }
    }

    /// Swaps in an inferer of a (possibly different) type while keeping the wrapper stack.
    ///
    /// # Errors
    ///
    /// If `new_inferer`'s raw input/output slots are not compatible with the current
    /// inferer's (see [`Self::check_compatible_shapes`]), the original `self` is returned
    /// unchanged together with the error, so the caller does not lose the wrapper stack.
    pub fn with_new_inferer<NewInf: Inferer>(
        self,
        new_inferer: NewInf,
    ) -> Result<StatefulInferer<WrapStack, NewInf>, (Self, anyhow::Error)> {
        if let Err(e) = Self::check_compatible_shapes(&self.inferer, &new_inferer) {
            return Err((self, e));
        }
        Ok(StatefulInferer {
            wrapper_stack: self.wrapper_stack,
            inferer: new_inferer,
        })
    }

    /// Replaces the current inferer in place with another of the same type.
    ///
    /// # Errors
    ///
    /// Returns the compatibility error from [`Self::check_compatible_shapes`] and leaves the
    /// current inferer untouched if the slots differ.
    pub fn replace_inferer(&mut self, new_inferer: Inf) -> anyhow::Result<()> {
        Self::check_compatible_shapes(&self.inferer, &new_inferer)?;
        self.inferer = new_inferer;
        Ok(())
    }

    /// Checks that `new` exposes exactly the same raw input and output slots as `old`:
    /// the same number of slots, and the same name and shape at every position.
    ///
    /// # Errors
    ///
    /// Describes the first mismatch found: a slot-count difference, a name difference at a
    /// given index, or a shape difference for a named slot. Inputs are checked before outputs.
    pub fn check_compatible_shapes<Old: Inferer, New: Inferer>(
        old: &Old,
        new: &New,
    ) -> Result<(), anyhow::Error> {
        Self::check_compatible_slots("input", old.raw_input_shapes(), new.raw_input_shapes())?;
        Self::check_compatible_slots("output", old.raw_output_shapes(), new.raw_output_shapes())
    }

    /// Compares one list of `(name, shape)` slots; `kind` ("input"/"output") labels errors.
    fn check_compatible_slots(
        kind: &str,
        old: &[(String, Vec<usize>)],
        new: &[(String, Vec<usize>)],
    ) -> Result<(), anyhow::Error> {
        // `zip` stops at the shorter list, so without this check an inferer with extra or
        // missing slots would be accepted as compatible.
        if old.len() != new.len() {
            return Err(anyhow::format_err!(
                "{kind} count mismatch: {} != {}",
                old.len(),
                new.len(),
            ));
        }
        for (i, ((old_name, old_shape), (new_name, new_shape))) in
            old.iter().zip(new).enumerate()
        {
            if old_name != new_name {
                return Err(anyhow::format_err!(
                    "name mismatch for {kind} {i}: '{old_name}' != '{new_name}'",
                ));
            }
            if old_shape != new_shape {
                return Err(anyhow::format_err!(
                    "shape mismatch for {kind} '{old_name}': {old_shape:?} != {new_shape:?}",
                ));
            }
        }
        Ok(())
    }

    /// Input slots as exposed by the wrapper stack (which may differ from the inferer's raw
    /// inputs returned by `raw_input_shapes`).
    pub fn input_shapes(&self) -> &[(String, Vec<usize>)] {
        self.wrapper_stack.input_shapes(&self.inferer)
    }

    /// Output slots as exposed by the wrapper stack (which may differ from the inferer's raw
    /// outputs returned by `raw_output_shapes`).
    pub fn output_shapes(&self) -> &[(String, Vec<usize>)] {
        self.wrapper_stack.output_shapes(&self.inferer)
    }
}
impl<WrapStack: InfererWrapper, Inf: Inferer> Inferer for StatefulInferer<WrapStack, Inf> {
fn select_batch_size(&self, max_count: usize) -> usize {
self.inferer.select_batch_size(max_count)
}
fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> anyhow::Result<(), anyhow::Error> {
self.wrapper_stack.invoke(&self.inferer, batch)
}
fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
self.inferer.raw_input_shapes()
}
fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
self.inferer.raw_output_shapes()
}
fn begin_agent(&self, id: u64) {
self.wrapper_stack.begin_agent(&self.inferer, id);
}
fn end_agent(&self, id: u64) {
self.wrapper_stack.end_agent(&self.inferer, id);
}
}
/// Convenience conversion from a concrete [`Inferer`] into a `StatefulInferer`.
pub trait IntoStateful: Inferer + Sized {
    /// Consumes this inferer and pairs it with `wrapper_stack`, which every inference call
    /// will subsequently be routed through.
    fn into_stateful<WrapStack>(self, wrapper_stack: WrapStack) -> StatefulInferer<WrapStack, Self>
    where
        WrapStack: InfererWrapper,
    {
        StatefulInferer::<WrapStack, Self>::new(wrapper_stack, self)
    }
}
impl IntoStateful for BasicInferer {}
impl IntoStateful for DynamicInferer {}
impl IntoStateful for MemoizingDynamicInferer {}
impl IntoStateful for FixedBatchInferer {}
/// Extension methods available on every sized [`InfererWrapper`].
pub trait InfererWrapperExt: InfererWrapper + Sized {
    /// Consumes this wrapper and attaches it to `inferer`, producing a `StatefulInferer`
    /// that routes the inferer's calls through `self`.
    fn wrap<Inf>(self, inferer: Inf) -> StatefulInferer<Self, Inf>
    where
        Inf: Inferer,
    {
        StatefulInferer::<Self, Inf>::new(self, inferer)
    }
}

/// Blanket impl: any wrapper (including `Box<dyn InfererWrapper>`) gets `wrap`.
impl<W> InfererWrapperExt for W where W: InfererWrapper {}