use std::any::Any;
use std::fmt::Debug;
use downcast_rs::Downcast;
use dyn_clone::DynClone;
use lazy_static::lazy_static;
use tract_linalg::multithread::Executor;
use crate::internal::*;
/// Options controlling how a [`Runtime`] turns a [`TypedModel`] into a [`Runnable`].
///
/// Passed to [`Runtime::prepare_with_options`]; [`Runtime::prepare`] uses the
/// `Default` value (all flags `false`, all options `None`).
#[derive(Clone, Debug, Default)]
pub struct RunOptions {
// NOTE(review): by its name this skips a RAM-oriented node-ordering optimisation
// while building the plan — confirm against the plan constructor that reads it.
pub skip_order_opt_ram: bool,
/// Optional `tract_linalg` multithreading executor; `None` by default.
pub executor: Option<Executor>,
/// Optional symbol values given as hints for memory sizing; `None` by default.
// NOTE(review): how these hints are consumed is decided by the plan builder — verify.
pub memory_sizing_hints: Option<SymbolValues>,
}
pub trait Runtime: Debug + Send + Sync + 'static {
fn name(&self) -> StaticName;
fn prepare(&self, model: TypedModel) -> TractResult<Box<dyn Runnable>> {
self.prepare_with_options(model, &Default::default())
}
fn check(&self) -> TractResult<()>;
fn prepare_with_options(
&self,
model: TypedModel,
options: &RunOptions,
) -> TractResult<Box<dyn Runnable>>;
}
/// A model prepared by a [`Runtime`], able to spawn execution [`State`]s.
///
/// Implementors must provide [`spawn`](Runnable::spawn),
/// [`typed_plan`](Runnable::typed_plan) and [`typed_model`](Runnable::typed_model).
/// The remaining methods have fallbacks built on `typed_model()`. Implementors that
/// return `None` from it should override them.
pub trait Runnable: Any + Downcast + Debug + Send + Sync + 'static {
    /// One-shot execution: spawns a fresh [`State`] and runs `inputs` through it.
    fn run(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        self.spawn()?.run(inputs)
    }

    /// Creates a new execution state for this runnable.
    fn spawn(&self) -> TractResult<Box<dyn State>>;

    /// Number of model inputs.
    ///
    /// # Panics
    ///
    /// The fallback panics if `typed_model()` returns `None`.
    fn input_count(&self) -> usize {
        self.typed_model().context("Fallback implementation on typed_model()").unwrap().inputs.len()
    }

    /// Number of model outputs.
    ///
    /// # Panics
    ///
    /// The fallback panics if `typed_model()` returns `None`.
    fn output_count(&self) -> usize {
        self.typed_model()
            .context("Fallback implementation on typed_model()")
            .unwrap()
            .outputs
            .len()
    }

    /// Fact describing model input `ix`.
    ///
    /// # Errors
    ///
    /// The fallback returns an error if `typed_model()` returns `None`, or if the
    /// model's own `input_fact(ix)` lookup fails.
    fn input_fact(&self, ix: usize) -> TractResult<&TypedFact> {
        // The signature is already fallible, so a missing model is reported, not panicked on.
        self.typed_model().context("Fallback implementation on typed_model()")?.input_fact(ix)
    }

    /// Fact describing model output `ix`.
    ///
    /// # Errors
    ///
    /// The fallback returns an error if `typed_model()` returns `None`, or if the
    /// model's own `output_fact(ix)` lookup fails.
    fn output_fact(&self, ix: usize) -> TractResult<&TypedFact> {
        // The signature is already fallible, so a missing model is reported, not panicked on.
        self.typed_model().context("Fallback implementation on typed_model()")?.output_fact(ix)
    }

    /// Named tensors attached to the model.
    ///
    /// The fallback returns the typed model's `properties`, or a shared empty map when
    /// `typed_model()` returns `None`.
    fn properties(&self) -> &HashMap<String, Arc<Tensor>> {
        lazy_static! {
            static ref NO_PROPERTIES: HashMap<String, Arc<Tensor>> = Default::default();
        }
        self.typed_model().map(|model| &model.properties).unwrap_or(&NO_PROPERTIES)
    }

    /// The underlying typed plan, if this runnable is backed by one.
    fn typed_plan(&self) -> Option<&Arc<TypedSimplePlan>>;

    /// The underlying typed model, if available. The fallbacks above read from it.
    fn typed_model(&self) -> Option<&Arc<TypedModel>>;
}
// Allows downcasting a `dyn Runnable` back to its concrete type (downcast_rs).
impl_downcast!(Runnable);
/// A live execution state spawned for a [`Runnable`].
///
/// Unlike `Runnable`, `State` does not require `Send + Sync`.
pub trait State: Any + Downcast + Debug + 'static {
/// Runs `inputs` through the model; `&mut self` lets the state be mutated between calls.
fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>>;
/// The runnable backing this state.
fn runnable(&self) -> &dyn Runnable;
/// Number of inputs; defaults to the backing runnable's count.
fn input_count(&self) -> usize {
self.runnable().input_count()
}
/// Number of outputs; defaults to the backing runnable's count.
fn output_count(&self) -> usize {
self.runnable().output_count()
}
/// Produces a [`FrozenState`] from a borrowed state.
fn freeze(&self) -> Box<dyn FrozenState>;
/// Consuming variant of [`freeze`](State::freeze).
///
/// The default simply calls `freeze()` on the boxed state. Implementors may override
/// it to exploit owning the state (as the `TypedSimpleState` impl does).
fn freeze_into(self: Box<Self>) -> Box<dyn FrozenState> {
self.freeze()
}
}
// Allows downcasting a `dyn State` back to its concrete type (downcast_rs).
impl_downcast!(State);
/// A detached, cloneable form of a [`State`] that can be turned back into a live one.
///
/// Unlike `State`, it requires `Send`, so frozen states can be moved across threads.
pub trait FrozenState: Any + Debug + DynClone + Send {
/// Produces a live [`State`] from this frozen one.
fn unfreeze(&self) -> Box<dyn State>;
/// Number of model inputs.
fn input_count(&self) -> usize;
/// Number of model outputs.
fn output_count(&self) -> usize;
}
// Implements `Clone` for `Box<dyn FrozenState>` through the `DynClone` supertrait.
dyn_clone::clone_trait_object!(FrozenState);
/// The built-in runtime: optimizes the model, then builds a [`TypedSimplePlan`] from it.
#[derive(Debug)]
pub struct DefaultRuntime;

impl Runtime for DefaultRuntime {
    /// Always `"default"`.
    fn name(&self) -> StaticName {
        Cow::Borrowed("default")
    }

    /// The default runtime has no environmental requirements, so it is always usable.
    fn check(&self) -> TractResult<()> {
        Ok(())
    }

    /// Runs `into_optimized()` on `model`, then wraps it in a plan built with `options`.
    fn prepare_with_options(
        &self,
        model: TypedModel,
        options: &RunOptions,
    ) -> TractResult<Box<dyn Runnable>> {
        let optimized = model.into_optimized()?;
        let plan = TypedSimplePlan::new_with_options(optimized, options)?;
        Ok(Box::new(plan))
    }
}
/// Lets an `Arc`-wrapped typed plan be used directly as a [`Runnable`].
///
/// Counts and facts are read straight from the plan's `model` field instead of going
/// through the trait's `typed_model()`-based fallbacks.
impl Runnable for Arc<TypedRunnableModel> {
/// Spawns a new state for this plan and boxes it as a `dyn State`.
// NOTE(review): `self.spawn()` must resolve to the plan's inherent `spawn` (receiver
// `&Arc<Self>`), which method lookup prefers over this trait method at the same
// probe step. If that inherent receiver ever changes, this call would recurse.
fn spawn(&self) -> TractResult<Box<dyn State>> {
Ok(Box::new(self.spawn()?))
}
/// Always `Some(self)`: this runnable is itself the typed plan.
// NOTE(review): `Option<&Self>` only matches the trait's `Option<&Arc<TypedSimplePlan>>`
// if `TypedRunnableModel` and `TypedSimplePlan` name the same type — confirm the aliases.
fn typed_plan(&self) -> Option<&Self> {
Some(self)
}
/// Always `Some`, borrowing the plan's `model` field.
fn typed_model(&self) -> Option<&Arc<TypedModel>> {
Some(&self.model)
}
/// Length of the model's `inputs` list.
fn input_count(&self) -> usize {
self.model.inputs.len()
}
/// Length of the model's `outputs` list.
fn output_count(&self) -> usize {
self.model.outputs.len()
}
/// Delegates to the model's `input_fact`.
fn input_fact(&self, ix: usize) -> TractResult<&TypedFact> {
self.model.input_fact(ix)
}
/// Delegates to the model's `output_fact`.
fn output_fact(&self, ix: usize) -> TractResult<&TypedFact> {
self.model.output_fact(ix)
}
}
/// Exposes the concrete [`TypedSimpleState`] through the object-safe [`State`] trait.
///
/// Each method forwards to the inherent method of the same name. The path syntax
/// `TypedSimpleState::method` selects the inherent item over this trait's own method.
impl State for TypedSimpleState {
    fn runnable(&self) -> &dyn Runnable {
        // The stored plan is itself a `Runnable`.
        &self.plan
    }

    fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        TypedSimpleState::run(self, inputs)
    }

    fn freeze(&self) -> Box<dyn FrozenState> {
        let frozen = TypedSimpleState::freeze(self);
        Box::new(frozen)
    }

    fn freeze_into(self: Box<Self>) -> Box<dyn FrozenState> {
        // Unbox so the inherent by-value `freeze_into` can consume the state.
        let owned: TypedSimpleState = *self;
        Box::new(TypedSimpleState::freeze_into(owned))
    }
}
/// Exposes a frozen typed state through the [`FrozenState`] trait.
impl FrozenState for TypedFrozenSimpleState {
    /// Forwards to the inherent `unfreeze` and boxes the resulting live state.
    fn unfreeze(&self) -> Box<dyn State> {
        Box::new(TypedFrozenSimpleState::unfreeze(self))
    }

    /// Length of the plan model's input list.
    ///
    /// Reads `inputs` directly, as the `Arc<TypedRunnableModel>` impl does, instead of
    /// unwrapping the `Result` returned by `input_outlets()`.
    fn input_count(&self) -> usize {
        self.plan().model().inputs.len()
    }

    /// Length of the plan model's output list.
    ///
    /// Reads `outputs` directly, as the `Arc<TypedRunnableModel>` impl does, instead of
    /// unwrapping the `Result` returned by `output_outlets()`.
    fn output_count(&self) -> usize {
        self.plan().model().outputs.len()
    }
}
/// Registry entry wrapping a `'static` [`Runtime`]. Entries are gathered through
/// `inventory` and exposed by `runtimes()` and `runtime_for_name()`.
pub struct InventorizedRuntime(pub &'static dyn Runtime);
/// Transparent forwarding of every [`Runtime`] method to the wrapped runtime.
impl Runtime for InventorizedRuntime {
    fn name(&self) -> StaticName {
        self.0.name()
    }

    /// Forwarded explicitly so that an inner runtime which overrides `prepare` is not
    /// bypassed by the trait's default (which would call `prepare_with_options`).
    fn prepare(&self, model: TypedModel) -> TractResult<Box<dyn Runnable>> {
        self.0.prepare(model)
    }

    fn prepare_with_options(
        &self,
        model: TypedModel,
        options: &RunOptions,
    ) -> TractResult<Box<dyn Runnable>> {
        self.0.prepare_with_options(model, options)
    }

    fn check(&self) -> TractResult<()> {
        self.0.check()
    }
}
/// Formats exactly like the wrapped runtime, without any wrapper name around it.
impl Debug for InventorizedRuntime {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Debug::fmt(self.0, f)
    }
}

// Declares `InventorizedRuntime` as an inventory-collected type so submitted
// entries can be iterated with `inventory::iter`.
inventory::collect!(InventorizedRuntime);
/// Iterates over every registered runtime whose `check()` succeeds.
///
/// Runtimes that fail their check are silently skipped.
pub fn runtimes() -> impl Iterator<Item = &'static dyn Runtime> {
    inventory::iter::<InventorizedRuntime>()
        .filter_map(|registered| registered.check().is_ok().then_some(registered.0))
}
/// Looks up a registered runtime by name.
///
/// Returns `Ok(None)` when no runtime has that name. If the first runtime with that
/// name fails its `check()`, that error is returned.
pub fn runtime_for_name(s: &str) -> TractResult<Option<&'static dyn Runtime>> {
    match inventory::iter::<InventorizedRuntime>().find(|candidate| candidate.name() == s) {
        Some(found) => {
            found.check()?;
            Ok(Some(found.0))
        }
        None => Ok(None),
    }
}
/// Registers a runtime so it is listed by `runtimes()` and found by `runtime_for_name()`.
///
/// Usage: `register_runtime!(MyRuntime = MyRuntime);`. The value must be usable as a
/// `static` initializer. The invoking crate must be able to name the `inventory` crate.
#[macro_export]
macro_rules! register_runtime {
    ($type: ty= $val:expr) => {
        // The anonymous const scope keeps the helper `static` out of the caller's
        // namespace, so the macro can be invoked more than once in the same module.
        const _: () = {
            static D: $type = $val;
            inventory::submit! { $crate::runtime::InventorizedRuntime(&D) }
        };
    };
}

register_runtime!(DefaultRuntime = DefaultRuntime);