#![warn(clippy::all)]
use anyhow::{Error, Result};
use std::collections::HashMap;
mod basic;
mod dynamic;
mod fixed;
mod helpers;
mod memoizing;
pub use basic::BasicInferer;
pub use dynamic::DynamicInferer;
pub use fixed::FixedBatchInferer;
pub use memoizing::MemoizingDynamicInferer;
use crate::{
batcher::{Batched, Batcher, ScratchPadView},
epsilon::{EpsilonInjector, NoiseGenerator},
};
/// Per-agent input data handed to an inferer.
///
/// `data` maps an input name to a flat `f32` buffer. NOTE(review): presumably
/// each buffer is laid out to match the corresponding entry of
/// `Inferer::input_shapes` — confirm against the batcher.
///
/// Derives `Default` for parity with `Response`.
#[derive(Clone, Debug, Default)]
pub struct State<'a> {
    /// Input buffers keyed by input name.
    pub data: HashMap<&'a str, Vec<f32>>,
}

impl<'a> State<'a> {
    /// Returns a state with no input entries; equivalent to `State::default()`.
    pub fn empty() -> Self {
        Self::default()
    }
}
/// Per-agent output produced by an inferer.
///
/// `data` maps an output name to a flat `f32` buffer.
#[derive(Clone, Debug, Default)]
pub struct Response<'a> {
    /// Output buffers keyed by output name.
    pub data: HashMap<&'a str, Vec<f32>>,
}

impl Response<'_> {
    /// Returns a response with no output entries.
    pub fn empty() -> Self {
        Self::default()
    }
}
/// A model runner that performs batched inference over a scratch pad.
///
/// Implementors describe the model's raw I/O (`raw_input_shapes`,
/// `raw_output_shapes`) and execute a staged batch (`infer_raw`). The public
/// `input_shapes`/`output_shapes` fall back to the raw ones, and an
/// implementor may override them to expose a different view.
pub trait Inferer {
    /// Raw input `(name, shape)` pairs.
    fn raw_input_shapes(&self) -> &[(String, Vec<usize>)];

    /// Raw output `(name, shape)` pairs.
    fn raw_output_shapes(&self) -> &[(String, Vec<usize>)];

    /// Input `(name, shape)` pairs exposed to callers.
    ///
    /// Defaults to [`Inferer::raw_input_shapes`].
    fn input_shapes(&self) -> &[(String, Vec<usize>)] {
        self.raw_input_shapes()
    }

    /// Output `(name, shape)` pairs exposed to callers.
    ///
    /// Defaults to [`Inferer::raw_output_shapes`].
    fn output_shapes(&self) -> &[(String, Vec<usize>)] {
        self.raw_output_shapes()
    }

    /// Chooses a batch size given `max_count`.
    ///
    /// NOTE(review): presumably `max_count` is the number of pending items and
    /// the result is how many to run now — confirm with implementors.
    fn select_batch_size(&self, max_count: usize) -> usize;

    /// Runs inference over the data staged in `batch`.
    ///
    /// # Errors
    /// Implementation-defined; returns whatever failure the implementor reports.
    fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error>;

    /// Lifecycle hook for agent `id` starting.
    ///
    /// NOTE(review): call-order contract is not visible here — confirm.
    fn begin_agent(&self, id: u64);

    /// Lifecycle hook for agent `id` finishing.
    ///
    /// NOTE(review): call-order contract is not visible here — confirm.
    fn end_agent(&self, id: u64);
}
/// Source of concrete inferers, with one constructor per batching strategy.
///
/// Every method consumes the provider.
pub trait InfererProvider {
    /// Builds a [`BasicInferer`].
    fn build_basic(self) -> Result<BasicInferer>;

    /// Builds a [`FixedBatchInferer`] configured with `sizes`.
    fn build_fixed(self, sizes: &[usize]) -> Result<FixedBatchInferer>;

    /// Builds a [`DynamicInferer`].
    fn build_dynamic(self) -> Result<DynamicInferer>;

    /// Builds a [`MemoizingDynamicInferer`] configured with `preload_sizes`.
    ///
    /// NOTE(review): presumably these sizes are prepared eagerly — confirm.
    fn build_memoizing(self, preload_sizes: &[usize]) -> Result<MemoizingDynamicInferer>;
}
pub struct InfererBuilder<P: InfererProvider> {
provider: P,
}
impl<P> InfererBuilder<P>
where
P: InfererProvider,
{
pub fn new(provider: P) -> Self {
Self { provider }
}
pub fn build_basic(self) -> Result<BasicInferer> {
self.provider.build_basic()
}
pub fn build_fixed(self, sizes: &[usize]) -> Result<FixedBatchInferer> {
self.provider.build_fixed(sizes)
}
pub fn build_dynamic(self) -> Result<DynamicInferer> {
self.provider.build_dynamic()
}
pub fn build_memoizing(self, preload_sizes: &[usize]) -> Result<MemoizingDynamicInferer> {
self.provider.build_memoizing(preload_sizes)
}
}
pub trait InfererExt: Inferer + Sized {
fn with_default_epsilon(self, key: &str) -> Result<EpsilonInjector<Self>> {
EpsilonInjector::wrap(self, key)
}
fn with_epsilon<G: NoiseGenerator>(
self,
generator: G,
key: &str,
) -> Result<EpsilonInjector<Self, G>> {
EpsilonInjector::with_generator(self, generator, key)
}
fn into_batched(self) -> Batched<Self> {
Batched::wrap(self)
}
#[deprecated(
note = "Please use the more explicit 'infer_batch' instead.",
since = "0.3.0"
)]
fn infer(
&mut self,
observations: HashMap<u64, State<'_>>,
) -> Result<HashMap<u64, Response<'_>>, Error> {
self.infer_batch(observations)
}
fn infer_batch<'this>(
&'this self,
batch: HashMap<u64, State<'_>>,
) -> Result<HashMap<u64, Response<'this>>, anyhow::Error> {
let mut batcher = Batcher::new_sized(self, batch.len());
batcher.extend(batch)?;
batcher.execute(self)
}
fn infer_single<'this>(
&'this self,
input: State<'_>,
) -> Result<Response<'this>, anyhow::Error> {
let mut batcher = Batcher::new_sized(self, 1);
batcher.push(0, input)?;
Ok(batcher.execute(self)?.remove(&0).unwrap())
}
}
impl<T> InfererExt for T where T: Inferer + Sized {}
/// Forwards every [`Inferer`] method to the boxed `Send` trait object.
///
/// `input_shapes`/`output_shapes` are forwarded explicitly. Relying on the
/// trait defaults would route them to `raw_*_shapes` and silently bypass any
/// override on the inner inferer.
impl Inferer for Box<dyn Inferer + Send> {
    fn select_batch_size(&self, max_count: usize) -> usize {
        self.as_ref().select_batch_size(max_count)
    }

    fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
        self.as_ref().infer_raw(batch)
    }

    fn input_shapes(&self) -> &[(String, Vec<usize>)] {
        self.as_ref().input_shapes()
    }

    fn output_shapes(&self) -> &[(String, Vec<usize>)] {
        self.as_ref().output_shapes()
    }

    fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
        self.as_ref().raw_input_shapes()
    }

    fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
        self.as_ref().raw_output_shapes()
    }

    fn begin_agent(&self, id: u64) {
        self.as_ref().begin_agent(id);
    }

    fn end_agent(&self, id: u64) {
        self.as_ref().end_agent(id);
    }
}
/// Forwards every [`Inferer`] method to the boxed trait object.
///
/// `input_shapes`/`output_shapes` are forwarded explicitly. Relying on the
/// trait defaults would route them to `raw_*_shapes` and silently bypass any
/// override on the inner inferer.
impl Inferer for Box<dyn Inferer> {
    fn select_batch_size(&self, max_count: usize) -> usize {
        self.as_ref().select_batch_size(max_count)
    }

    fn infer_raw(&self, batch: &mut ScratchPadView<'_>) -> Result<(), anyhow::Error> {
        self.as_ref().infer_raw(batch)
    }

    fn input_shapes(&self) -> &[(String, Vec<usize>)] {
        self.as_ref().input_shapes()
    }

    fn output_shapes(&self) -> &[(String, Vec<usize>)] {
        self.as_ref().output_shapes()
    }

    fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
        self.as_ref().raw_input_shapes()
    }

    fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
        self.as_ref().raw_output_shapes()
    }

    fn begin_agent(&self, id: u64) {
        self.as_ref().begin_agent(id);
    }

    fn end_agent(&self, id: u64) {
        self.as_ref().end_agent(id);
    }
}