use std::fmt;
use downcast_rs::Downcast;
use objekt;
#[macro_use]
pub mod macros;
#[macro_use]
pub mod element_wise;
#[macro_use]
pub mod binary;
pub mod axis;
pub mod array;
pub mod cast;
pub mod cnn;
pub mod downsample;
pub mod dummy;
pub mod identity;
pub mod konst;
pub mod logic;
pub mod math;
pub mod nn;
pub mod scan;
pub mod source;
pub mod unimpl;
pub use axis::{AxesInfo, AxisInfo};
pub use downsample::Downsample;
/// Checks that inference rules were given exactly `expected` input proxies.
///
/// # Errors
/// Fails with a "Wrong input number" error when `inputs.len() != expected`.
pub fn check_input_arity(inputs: &[TensorProxy], expected: usize) -> TractResult<()> {
    let actual = inputs.len();
    if actual == expected {
        Ok(())
    } else {
        bail!("Wrong input number. Rules expect {}, node has {}.", expected, actual)
    }
}
/// Checks that inference rules were given exactly `expected` output proxies.
///
/// # Errors
/// Fails with a "Wrong output number" error when `outputs.len() != expected`.
pub fn check_output_arity(outputs: &[TensorProxy], expected: usize) -> TractResult<()> {
    let actual = outputs.len();
    if actual == expected {
        Ok(())
    } else {
        bail!("Wrong output number. Rules expect {}, node has {}.", expected, actual)
    }
}
/// Validation level an operator reports through `Op::validation`.
///
/// NOTE(review): the per-variant meanings below are inferred from the names
/// only — confirm against the code that consumes this value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Validation {
    /// Presumably: outputs are not reproducible and should not be compared.
    Random,
    /// Presumably: outputs may differ by rounding error when compared.
    Rounding,
    /// Default returned by `Op::validation`; presumably exact comparison.
    Accurate,
}
/// Cost category reported by `TypedOp::cost`, where each entry is paired
/// with a `TDim` quantity. Each variant carries the datum type it applies to.
///
/// NOTE(review): the quantity is presumably an operation count — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Cost {
    /// Divisions on values of the given datum type.
    Div(DatumType),
    /// Fused multiply-adds on values of the given datum type.
    FMA(DatumType),
}
use crate::internal::*;
/// Mutable per-session state of an operator, as returned (boxed) by
/// `StatefullOp::state`.
pub trait OpState: fmt::Debug + Send + objekt::Clone {
    /// Evaluates `op` on `inputs`, with mutable access to both this state and
    /// the session, and returns the output tensors.
    fn eval(
        &mut self,
        session: &mut SessionState,
        op: &dyn Op,
        inputs: TVec<Arc<Tensor>>,
    ) -> TractResult<TVec<Arc<Tensor>>>;
}
/// Operator that can be evaluated from its inputs alone: no session or
/// per-session state is passed to `eval`.
pub trait StatelessOp: Op {
    /// Computes the output tensors from the input tensors.
    fn eval(&self, inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>>;
}
/// Operator that may need per-session state to be evaluated.
///
/// A blanket impl in this module covers every `StatelessOp + Clone`. That
/// impl returns no state and exposes the op through `as_stateless`.
pub trait StatefullOp {
    /// Builds the state this operator needs for node `node_id` in `session`.
    ///
    /// Stateless operators (through the blanket impl) return `Ok(None)`.
    // The former `#[allow(unused_variables)]` was dropped: a required method
    // has no body, so the lint can never fire here.
    fn state(
        &self,
        session: &mut SessionState,
        node_id: usize,
    ) -> TractResult<Option<Box<dyn OpState>>>;

    /// Stateless view of this operator, if any. Defaults to `None`.
    fn as_stateless(&self) -> Option<&dyn StatelessOp> {
        None
    }
}
/// Any cloneable stateless operator is trivially a `StatefullOp`. It never
/// asks for a state object and advertises itself as stateless.
impl<O> StatefullOp for O
where
    O: StatelessOp + Clone,
{
    fn state(&self, _: &mut SessionState, _: usize) -> TractResult<Option<Box<dyn OpState>>> {
        // No per-session state to build.
        Ok(None)
    }

    fn as_stateless(&self) -> Option<&dyn StatelessOp> {
        Some(self as &dyn StatelessOp)
    }
}
/// Translates one node of a `ModelImpl<TI1, O1>` into a `ModelImpl<TI2, O2>`,
/// given an extra translation context of type `Ctx`.
pub trait Translate<TI1, O1, TI2, O2, Ctx>
where
    TI1: Fact + Clone + 'static,
    TI2: Fact + Clone + 'static,
    O1: fmt::Display + fmt::Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
    O2: fmt::Display + fmt::Debug + AsRef<dyn Op> + AsMut<dyn Op> + Clone + 'static,
{
    /// Wires the translation of `node` (taken from `source`) into `target`.
    ///
    /// `mapping` presumably maps outlets of `source` to the corresponding
    /// outlets already wired in `target`. The returned outlets are presumably
    /// the translated node's outputs.
    fn translate(
        &self,
        source: &ModelImpl<TI1, O1>,
        node: &BaseNode<TI1, O1>,
        target: &mut ModelImpl<TI2, O2>,
        mapping: &HashMap<OutletId, OutletId>,
        ctx: &Ctx,
    ) -> TractResult<TVec<OutletId>>;
}
/// Base trait shared by every operator.
///
/// Apart from `name` and `as_typed`, every method has a conservative default
/// ("no patch", "nothing nested", "not equal to anything", ...), so concrete
/// operators only override what they actually support.
pub trait Op: fmt::Debug + objekt::Clone + Send + Sync + 'static + Downcast + StatefullOp {
    /// Human-readable operator name.
    fn name(&self) -> Cow<str>;

    /// Optional rewrite of `_node` within an inference model.
    /// Defaults to proposing no patch.
    fn incorporate(
        &self,
        _model: &InferenceModel,
        _node: &InferenceNode,
    ) -> TractResult<Option<InferenceModelPatch>> {
        Ok(None)
    }

    /// Optional fusing rewrite of `_node` within a typed model.
    /// Defaults to proposing no patch.
    fn fuse(&self, _model: &TypedModel, _node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }

    /// Labelled sub-models embedded in this operator. Empty by default.
    fn nested_models(&self) -> Vec<(Cow<str>, &dyn Model)> {
        Vec::new()
    }

    /// Validation level of this operator. `Validation::Accurate` by default.
    fn validation(&self) -> Validation {
        Validation::Accurate
    }

    /// Equivalence test against another operator.
    /// The default conservatively answers `false`.
    fn same_as(&self, _other: &dyn Op) -> bool {
        false
    }

    /// Extra free-form description lines. Empty by default.
    fn info(&self) -> TractResult<Vec<String>> {
        Ok(Vec::new())
    }

    /// Typed view of this operator, if it has one.
    fn as_typed(&self) -> Option<&dyn TypedOp>;

    /// Pulsed view of this operator, if it has one. `None` by default.
    fn as_pulsed(&self) -> Option<&dyn PulsedOp> {
        None
    }

    /// Canonicity flag. `false` by default.
    fn is_canonic(&self) -> bool {
        false
    }
}
pub trait TypedOp:
Op + fmt::Debug + objekt::Clone + Send + Sync + 'static + Downcast + StatefullOp
{
fn as_op(&self) -> &dyn Op;
fn as_op_mut(&mut self) -> &mut dyn Op;
fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>>;
fn axes_info(&self, _model: &TypedModel, _node: &TypedNode) -> TractResult<AxesInfo> {
Ok(tvec![].into())
}
fn declutter(
&self,
_model: &TypedModel,
_node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
Ok(None)
}
fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
Ok(tvec!())
}
#[allow(unused_variables)]
fn dispose_dummy_axis(
&self,
model: &TypedModel,
node: &TypedNode,
axis: usize,
) -> TractResult<Option<Box<dyn TypedOp>>> {
Ok(None)
}
fn pulsify(
&self,
_source: &NormalizedModel,
node: &NormalizedNode,
_target: &mut PulsedModel,
_mapping: &HashMap<OutletId, OutletId>,
_pulse: usize,
) -> TractResult<TVec<OutletId>> {
debug!("{:?}", node);
bail!("Operator {} do not support pulsification", self.name())
}
fn codegen(
&self,
_model: &TypedModel,
_node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
Ok(None)
}
#[allow(unused_variables)]
fn nested_model_multipliers(&self, inputs: &[&TypedFact]) -> Vec<(Cow<str>, f32)> {
vec![]
}
}
/// Operator living in a pulsed model.
pub trait PulsedOp:
    Op + fmt::Debug + objekt::Clone + Send + Sync + 'static + Downcast + StatefullOp
{
    /// Upcast to `&dyn Op`.
    fn as_op(&self) -> &dyn Op;
    /// Upcast to `&mut dyn Op`.
    fn as_op_mut(&mut self) -> &mut dyn Op;
    /// Typed counterpart of this operator. It is used when a pulsed model is
    /// translated back to a typed model.
    fn to_typed(&self) -> Box<dyn TypedOp>;
    /// Computes pulsed output facts from pulsed input facts.
    fn pulsed_output_facts(&self, inputs: &[&PulsedFact]) -> TractResult<TVec<PulsedFact>>;
}
/// Normalized -> pulsed translation. The translation context is the pulse
/// size, which is forwarded to `TypedOp::pulsify`.
impl
    crate::ops::Translate<
        NormalizedFact,
        Box<dyn TypedOp>,
        crate::pulse::PulsedFact,
        Box<dyn PulsedOp>,
        usize,
    > for Box<dyn TypedOp>
{
    fn translate(
        &self,
        model: &NormalizedModel,
        node: &NormalizedNode,
        pulsed: &mut PulsedModel,
        wiring: &HashMap<OutletId, OutletId>,
        pulse: &usize,
    ) -> TractResult<TVec<OutletId>> {
        let pulse = *pulse;
        self.pulsify(model, node, pulsed, wiring, pulse)
    }
}
pub trait InferenceOp:
Op + fmt::Debug + objekt::Clone + Send + Sync + 'static + Downcast + StatefullOp
{
fn infer(
&mut self,
inputs: TVec<&InferenceFact>,
outputs: TVec<&InferenceFact>,
observed: TVec<&InferenceFact>,
) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)> {
let (infered_inputs, infered_outputs, observed) =
self.infer_facts(inputs, outputs, observed)?;
if let Some(stateless) = self.as_stateless() {
if infered_inputs.iter().all(|i| i.value.is_concrete()) {
let input_values = infered_inputs
.iter()
.map(|i| i.value.concretize().unwrap().clone().into())
.collect();
match stateless.eval(input_values) {
Ok(values) => {
let output_values =
values.into_iter().map(|t| t.into()).collect::<TVec<_>>();
return Ok((infered_inputs, output_values, observed));
}
Err(e) => match e {
TractError(TractErrorKind::StreamTensor, _) => (),
e => return Err(e),
},
}
}
}
return Ok((infered_inputs, infered_outputs, observed));
}
fn observe_outlets(
&self,
_model: &InferenceModel,
_node: &InferenceNode,
) -> TractResult<Vec<OutletId>> {
Ok(vec![])
}
fn infer_facts(
&mut self,
inputs: TVec<&InferenceFact>,
outputs: TVec<&InferenceFact>,
observed: TVec<&InferenceFact>,
) -> TractResult<(TVec<InferenceFact>, TVec<InferenceFact>, TVec<InferenceFact>)>;
fn nboutputs(&self) -> TractResult<usize> {
Ok(1)
}
fn as_op(&self) -> &dyn Op;
fn as_op_mut(&mut self) -> &mut dyn Op;
fn to_typed(
&self,
_source: &InferenceModel,
_node: &InferenceNode,
_target: &mut TypedModel,
_mapping: &HashMap<OutletId, OutletId>,
) -> TractResult<TVec<OutletId>> {
bail!("Operator can not be made a TypedOp.")
}
}
impl crate::ops::Translate<InferenceFact, Box<dyn InferenceOp>, TypedFact, Box<dyn TypedOp>, ()>
for Box<dyn InferenceOp>
{
fn translate(
&self,
source: &InferenceModel,
node: &InferenceNode,
target: &mut TypedModel,
mapping: &HashMap<OutletId, OutletId>,
_ctx: &(),
) -> TractResult<TVec<OutletId>> {
self.to_typed(source, node, target, mapping)
}
}
/// Pulsed -> typed translation. The node's typed counterpart is wired into
/// `target`, fed by the already-translated versions of its inputs.
impl crate::ops::Translate<PulsedFact, Box<dyn PulsedOp>, TypedFact, Box<dyn TypedOp>, ()>
    for Box<dyn PulsedOp>
{
    fn translate(
        &self,
        _source: &PulsedModel,
        node: &PulsedNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
        _ctx: &(),
    ) -> TractResult<TVec<OutletId>> {
        // Indexing panics if an input outlet has not been translated yet.
        let wires = node.inputs.iter().map(|outlet| mapping[outlet]).collect::<TVec<OutletId>>();
        let typed_op = node.op.to_typed();
        target.wire_node(&*node.name, typed_op, &*wires)
    }
}
// Allow `dyn Op` to be downcast to its concrete operator type (downcast_rs).
impl_downcast!(Op);
// Make `Box<dyn Trait>` cloneable for each operator trait object (objekt).
clone_trait_object!(Op);
clone_trait_object!(StatelessOp);
clone_trait_object!(TypedOp);
clone_trait_object!(InferenceOp);
clone_trait_object!(PulsedOp);
/// Boxes any concrete operator as a `Box<dyn Op>`.
impl<O: Op> From<O> for Box<dyn Op> {
    fn from(op: O) -> Self {
        Box::new(op)
    }
}

/// Boxes any concrete inference operator as a `Box<dyn InferenceOp>`.
impl<O: InferenceOp> From<O> for Box<dyn InferenceOp> {
    fn from(op: O) -> Self {
        Box::new(op)
    }
}

/// Boxes any concrete typed operator as a `Box<dyn TypedOp>`.
impl<O: TypedOp> From<O> for Box<dyn TypedOp> {
    fn from(op: O) -> Self {
        Box::new(op)
    }
}

/// Boxes any concrete pulsed operator as a `Box<dyn PulsedOp>`.
impl<O: PulsedOp> From<O> for Box<dyn PulsedOp> {
    fn from(op: O) -> Self {
        Box::new(op)
    }
}
impl AsRef<dyn Op> for dyn InferenceOp {
fn as_ref(&self) -> &dyn Op {
self.as_op()
}
}
impl AsRef<dyn Op> for Box<dyn InferenceOp> {
fn as_ref(&self) -> &dyn Op {
self.as_op()
}
}
impl AsMut<dyn Op> for dyn InferenceOp {
fn as_mut(&mut self) -> &mut dyn Op {
self.as_op_mut()
}
}
impl AsMut<dyn Op> for Box<dyn InferenceOp> {
fn as_mut(&mut self) -> &mut dyn Op {
self.as_op_mut()
}
}
impl AsRef<dyn Op> for dyn TypedOp {
fn as_ref(&self) -> &dyn Op {
self.as_op()
}
}
impl AsRef<dyn Op> for Box<dyn TypedOp> {
fn as_ref(&self) -> &dyn Op {
self.as_op()
}
}
impl AsMut<dyn Op> for dyn TypedOp {
fn as_mut(&mut self) -> &mut dyn Op {
self.as_op_mut()
}
}
impl AsMut<dyn Op> for Box<dyn PulsedOp> {
fn as_mut(&mut self) -> &mut dyn Op {
self.as_op_mut()
}
}
impl AsRef<dyn Op> for dyn PulsedOp {
fn as_ref(&self) -> &dyn Op {
self.as_op()
}
}
impl AsRef<dyn Op> for Box<dyn PulsedOp> {
fn as_ref(&self) -> &dyn Op {
self.as_op()
}
}
impl AsMut<dyn Op> for dyn PulsedOp {
fn as_mut(&mut self) -> &mut dyn Op {
self.as_op_mut()
}
}
impl AsMut<dyn Op> for Box<dyn TypedOp> {
fn as_mut(&mut self) -> &mut dyn Op {
self.as_op_mut()
}
}
/// A boxed operator displays as its `name()`.
impl std::fmt::Display for Box<dyn Op> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.name())
    }
}

/// A boxed inference operator displays as its `name()`.
impl std::fmt::Display for Box<dyn InferenceOp> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.name())
    }
}

/// A boxed typed operator displays as its `name()`.
impl std::fmt::Display for Box<dyn TypedOp> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.name())
    }
}

/// A boxed pulsed operator displays as its `name()`.
impl std::fmt::Display for Box<dyn PulsedOp> {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_str(&self.name())
    }
}