tract-core 0.23.0-dev.4

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
Documentation
use std::any::Any;
use std::fmt::Debug;

use downcast_rs::Downcast;
use dyn_clone::DynClone;
use lazy_static::lazy_static;
use tract_linalg::multithread::Executor;

use crate::internal::*;

/// Options forwarded to [`Runtime::prepare_with_options`].
///
/// The derived `Default` gives `skip_order_opt_ram = false`, no executor
/// override and no memory sizing hints.
#[derive(Clone, Debug, Default)]
pub struct RunOptions {
    /// Use the simple node ordering instead of the newer, memory-friendly one.
    pub skip_order_opt_ram: bool,

    /// Override the default global executor; `None` keeps the default.
    pub executor: Option<Executor>,

    /// Memory sizing hints, expressed as symbol values.
    pub memory_sizing_hints: Option<SymbolValues>,
}

pub trait Runtime: Debug + Send + Sync + 'static {
    fn name(&self) -> StaticName;
    fn prepare(&self, model: TypedModel) -> TractResult<Box<dyn Runnable>> {
        self.prepare_with_options(model, &Default::default())
    }
    fn check(&self) -> TractResult<()>;
    fn prepare_with_options(
        &self,
        model: TypedModel,
        options: &RunOptions,
    ) -> TractResult<Box<dyn Runnable>>;
}

/// A model prepared by a [`Runtime`] and ready to be executed.
///
/// Only `spawn`, `typed_plan` and `typed_model` must be implemented. The other
/// methods have default implementations that fall back on `typed_model()`;
/// runnables returning `None` from `typed_model()` should override them.
pub trait Runnable: Any + Downcast + Debug + Send + Sync + 'static {
    /// Runs the model once on `inputs`, using a freshly spawned [`State`].
    fn run(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        self.spawn()?.run(inputs)
    }

    /// Creates a new execution state for this runnable.
    fn spawn(&self) -> TractResult<Box<dyn State>>;

    /// Number of model inputs.
    ///
    /// # Panics
    /// The default implementation panics if `typed_model()` returns `None`.
    fn input_count(&self) -> usize {
        self.typed_model()
            .expect("Runnable::input_count() default implementation requires typed_model()")
            .inputs
            .len()
    }

    /// Number of model outputs.
    ///
    /// # Panics
    /// The default implementation panics if `typed_model()` returns `None`.
    fn output_count(&self) -> usize {
        self.typed_model()
            .expect("Runnable::output_count() default implementation requires typed_model()")
            .outputs
            .len()
    }

    /// Fact describing input `ix`.
    ///
    /// # Errors
    /// The default implementation fails if `typed_model()` returns `None`, and
    /// propagates any error from the model's own `input_fact`.
    fn input_fact(&self, ix: usize) -> TractResult<&TypedFact> {
        // `?` instead of `unwrap()`: this method already returns a Result, so a
        // missing model is reported to the caller rather than panicking.
        self.typed_model()
            .context("Runnable::input_fact() default implementation requires typed_model()")?
            .input_fact(ix)
    }

    /// Fact describing output `ix`.
    ///
    /// # Errors
    /// The default implementation fails if `typed_model()` returns `None`, and
    /// propagates any error from the model's own `output_fact`.
    fn output_fact(&self, ix: usize) -> TractResult<&TypedFact> {
        self.typed_model()
            .context("Runnable::output_fact() default implementation requires typed_model()")?
            .output_fact(ix)
    }

    /// Model properties, or an empty map when `typed_model()` returns `None`.
    fn properties(&self) -> &HashMap<String, Arc<Tensor>> {
        // A shared empty map, so a reference can be returned without a model.
        lazy_static! {
            static ref NO_PROPERTIES: HashMap<String, Arc<Tensor>> = Default::default();
        }
        self.typed_model().map(|model| &model.properties).unwrap_or(&NO_PROPERTIES)
    }

    /// The underlying simple plan, if this runnable is backed by one.
    fn typed_plan(&self) -> Option<&Arc<TypedSimplePlan>>;

    /// The underlying typed model, if available.
    fn typed_model(&self) -> Option<&Arc<TypedModel>>;
}
impl_downcast!(Runnable);

/// A live execution context obtained from a [`Runnable`].
pub trait State: Any + Downcast + Debug + 'static {
    /// The runnable backing this state.
    fn runnable(&self) -> &dyn Runnable;

    /// Executes the model on `inputs` and returns its outputs.
    fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>>;

    /// Number of inputs; delegates to the backing runnable by default.
    fn input_count(&self) -> usize {
        let backing = self.runnable();
        backing.input_count()
    }

    /// Number of outputs; delegates to the backing runnable by default.
    fn output_count(&self) -> usize {
        let backing = self.runnable();
        backing.output_count()
    }

    /// Captures this state as a [`FrozenState`] without consuming it.
    fn freeze(&self) -> Box<dyn FrozenState>;

    /// Consuming freeze: moves data instead of cloning.
    ///
    /// The default simply calls [`State::freeze`]; implementors can override
    /// it to avoid copies.
    fn freeze_into(self: Box<Self>) -> Box<dyn FrozenState> {
        State::freeze(&*self)
    }
}
impl_downcast!(State);

/// A cloneable, `Send` snapshot of a [`State`], turned back into a live state
/// with `unfreeze`.
pub trait FrozenState: Any + Debug + DynClone + Send {
    /// Number of inputs of the frozen model.
    fn input_count(&self) -> usize;

    /// Number of outputs of the frozen model.
    fn output_count(&self) -> usize;

    /// Rebuilds a live [`State`] from this snapshot.
    fn unfreeze(&self) -> Box<dyn State>;
}
// Makes `Box<dyn FrozenState>` cloneable through `DynClone`.
dyn_clone::clone_trait_object!(FrozenState);

/// The built-in runtime, registered under the name `"default"`: it optimizes
/// the model and wraps it in a `TypedSimplePlan`.
#[derive(Debug)]
pub struct DefaultRuntime;

impl Runtime for DefaultRuntime {
    /// Always usable: this runtime has no environment prerequisites.
    fn check(&self) -> TractResult<()> {
        Ok(())
    }

    /// Registered and looked up as `"default"`.
    fn name(&self) -> StaticName {
        Cow::Borrowed("default")
    }

    /// Optimizes `model`, then builds a simple plan from it using `options`.
    fn prepare_with_options(
        &self,
        model: TypedModel,
        options: &RunOptions,
    ) -> TractResult<Box<dyn Runnable>> {
        let optimized = model.into_optimized()?;
        let plan = TypedSimplePlan::new_with_options(optimized, options)?;
        Ok(Box::new(plan))
    }
}

/// The simple plan is its own runnable: every query is answered straight from
/// the model it wraps.
impl Runnable for Arc<TypedRunnableModel> {
    fn spawn(&self) -> TractResult<Box<dyn State>> {
        // Not recursive: method resolution is expected to pick the plan's
        // inherent `spawn`, since inherent methods take precedence over trait ones.
        let state = self.spawn()?;
        Ok(Box::new(state))
    }

    fn input_count(&self) -> usize {
        let model = &self.model;
        model.inputs.len()
    }

    fn output_count(&self) -> usize {
        let model = &self.model;
        model.outputs.len()
    }

    fn input_fact(&self, ix: usize) -> TractResult<&TypedFact> {
        let model = &self.model;
        model.input_fact(ix)
    }

    fn output_fact(&self, ix: usize) -> TractResult<&TypedFact> {
        let model = &self.model;
        model.output_fact(ix)
    }

    fn typed_model(&self) -> Option<&Arc<TypedModel>> {
        Some(&self.model)
    }

    fn typed_plan(&self) -> Option<&Self> {
        Some(self)
    }
}

/// Bridges the simple-plan state onto the [`State`] trait by forwarding to its
/// inherent methods.
impl State for TypedSimpleState {
    fn runnable(&self) -> &dyn Runnable {
        &self.plan
    }

    fn run(&mut self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        // Resolves to the inherent `run`, which takes precedence over this trait method.
        self.run(inputs)
    }

    fn freeze(&self) -> Box<dyn FrozenState> {
        let frozen = TypedSimpleState::freeze(self);
        Box::new(frozen)
    }

    fn freeze_into(self: Box<Self>) -> Box<dyn FrozenState> {
        let owned: TypedSimpleState = *self;
        let frozen = TypedSimpleState::freeze_into(owned);
        Box::new(frozen)
    }
}

/// Bridges the frozen simple-plan state onto the [`FrozenState`] trait.
impl FrozenState for TypedFrozenSimpleState {
    /// Rebuilds a live state via the inherent `unfreeze`.
    fn unfreeze(&self) -> Box<dyn State> {
        Box::new(TypedFrozenSimpleState::unfreeze(self))
    }

    /// Number of inputs of the plan's model.
    ///
    /// Reads the model's `inputs` list directly, as the plan's `Runnable` impl
    /// does, instead of unwrapping a `Result` from `input_outlets()`.
    fn input_count(&self) -> usize {
        self.plan().model().inputs.len()
    }

    /// Number of outputs of the plan's model, read from its `outputs` list.
    fn output_count(&self) -> usize {
        self.plan().model().outputs.len()
    }
}

/// Wrapper around a `'static` runtime, collected through `inventory`.
/// Instances are submitted by the `register_runtime!` macro.
pub struct InventorizedRuntime(pub &'static dyn Runtime);

impl Runtime for InventorizedRuntime {
    fn name(&self) -> StaticName {
        self.0.name()
    }

    fn prepare_with_options(
        &self,
        model: TypedModel,
        options: &RunOptions,
    ) -> TractResult<Box<dyn Runnable>> {
        self.0.prepare_with_options(model, options)
    }

    fn check(&self) -> TractResult<()> {
        self.0.check()
    }
}

/// Formats exactly like the wrapped runtime, forwarding the formatter (and its
/// flags) unchanged.
impl Debug for InventorizedRuntime {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        Debug::fmt(self.0, f)
    }
}

// Declares the `inventory` collection that `register_runtime!` submits into and
// that `runtimes()` / `runtime_for_name()` iterate over.
inventory::collect!(InventorizedRuntime);

/// Iterates over every registered runtime whose `check()` succeeds.
///
/// Runtimes that fail their check are silently skipped.
pub fn runtimes() -> impl Iterator<Item = &'static dyn Runtime> {
    inventory::iter::<InventorizedRuntime>()
        .filter_map(|registered| registered.check().is_ok().then_some(registered.0))
}

/// Looks up a registered runtime by name.
///
/// Only the first runtime whose `name()` equals `s` is considered. Returns
/// `Ok(None)` when none matches, and propagates the error when the matching
/// runtime's `check()` fails.
pub fn runtime_for_name(s: &str) -> TractResult<Option<&'static dyn Runtime>> {
    match inventory::iter::<InventorizedRuntime>().find(|candidate| candidate.name() == s) {
        None => Ok(None),
        Some(found) => {
            found.check()?;
            Ok(Some(found.0))
        }
    }
}

/// Registers a runtime so that `runtimes()` and `runtime_for_name()` can find it.
///
/// Usage: `register_runtime!(MyRuntime = MyRuntime);`. The expression on the
/// right is used as a `static` initializer, so it must be const-evaluable.
///
/// The generated `static` lives inside an anonymous `const _` scope. This lets
/// the macro be invoked several times in one module (or next to a user item
/// named `D`) without a duplicate-definition error. The invoking crate needs
/// `inventory` as a dependency, because the expansion calls `inventory::submit!`.
#[macro_export]
macro_rules! register_runtime {
    ($type: ty= $val:expr) => {
        const _: () = {
            static D: $type = $val;
            inventory::submit! { $crate::runtime::InventorizedRuntime(&D) }
        };
    };
}

// Makes the built-in runtime discoverable as "default".
register_runtime!(DefaultRuntime = DefaultRuntime);