use std::fmt;
use downcast_rs::Downcast;
use dyn_clone;
#[macro_use]
pub mod macros;
#[macro_use]
pub mod element_wise;
#[macro_use]
pub mod binary;
pub mod array;
pub mod cast;
pub mod change_axes;
pub mod cnn;
pub mod downsample;
pub mod dummy;
pub mod einsum;
pub mod fft;
pub mod identity;
pub mod konst;
pub mod logic;
pub mod math;
pub mod matmul;
pub mod memory;
pub mod nn;
pub mod quant;
pub mod scan;
pub mod source;
pub mod submodel;
pub mod unimpl;
pub use downsample::Downsample;
pub use memory::*;
use crate::internal::*;
use crate::optim::OptimizerSession;
/// How strictly an operator's outputs can be validated.
///
/// `Op::validation` defaults to `Validation::Accurate`.
/// NOTE(review): the per-variant meanings below are inferred from their names — presumably
/// used to pick a comparison tolerance when checking outputs against a reference. Confirm
/// against callers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Validation {
/// Presumably: outputs are not reproducible and cannot be compared to a reference.
Random,
/// Presumably: outputs may differ from a reference by rounding errors.
Rounding,
/// Presumably: outputs are expected to match a reference accurately.
Accurate,
}
/// One category of cost, tagged with the `DatumType` it applies to.
///
/// `TypedOp::cost` reports a list of `(Cost, TDim)` pairs using these categories.
/// `Cost::is_compute` splits them into compute costs (`Div`, `FMA`) and the others.
///
/// `Ord` is derived: costs sort first by variant declaration order
/// (`Div` < `FMA` < `Buffer` < `Params`), then by `DatumType`. Reordering the variants
/// therefore changes the ordering.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)]
pub enum Cost {
/// Divisions (presumably a count of operations — TODO confirm).
Div(DatumType),
/// Fused multiply-add operations (presumably a count — TODO confirm).
FMA(DatumType),
/// Presumably the size of intermediate buffers — TODO confirm unit.
Buffer(DatumType),
/// Presumably the amount of parameters (weights) — TODO confirm unit.
Params(DatumType),
}
impl Cost {
pub fn is_compute(&self) -> bool {
use Cost::*;
match self {
FMA(_) | Div(_) => true,
Buffer(_) | Params(_) => false,
}
}
}
/// Frozen form of an `OpState`, as produced by `OpStateFreeze::freeze`.
///
/// Unlike `OpState`, this trait requires `Send + 'static`. A frozen state can therefore be
/// sent across threads (presumably to snapshot a running state and resume it elsewhere).
/// It is cloneable as a boxed trait object through `DynClone`.
pub trait FrozenOpState: fmt::Debug + dyn_clone::DynClone + Send + 'static {
/// Builds a new live, boxed `OpState` from this frozen state (`self` is only borrowed).
fn unfreeze(&self) -> Box<dyn OpState>;
}
/// Conversion of a state into its `FrozenOpState` form. Every `OpState` must implement it.
pub trait OpStateFreeze {
/// Builds a boxed `FrozenOpState` from this state (`self` is only borrowed).
fn freeze(&self) -> Box<dyn FrozenOpState>;
}
// Implements `Clone` for `Box<dyn FrozenOpState>`, relying on its `DynClone` supertrait.
dyn_clone::clone_trait_object!(FrozenOpState);
/// Mutable evaluation state of an operator, as created by `EvalOp::state` for a node.
///
/// A state can be cloned when boxed (`DynClone`), frozen (`OpStateFreeze`) and downcast to
/// its concrete type (`Downcast`). It is not required to be `Send`; see `FrozenOpState`
/// for that.
pub trait OpState: fmt::Debug + dyn_clone::DynClone + OpStateFreeze + Downcast {
/// Evaluates `op` on `inputs`, with mutable access to both this state and the session.
fn eval(
&mut self,
session: &mut SessionState,
op: &dyn Op,
inputs: TVec<TValue>,
) -> TractResult<TVec<TValue>>;
}
// `Clone` for `Box<dyn OpState>`.
dyn_clone::clone_trait_object!(OpState);
// Downcasting helpers on `dyn OpState` (e.g. `downcast_ref`, `is`).
impl_downcast!(OpState);
/// Evaluation entry points of an operator.
pub trait EvalOp {
/// Stateless evaluation: computes the outputs from `inputs` alone.
///
/// The default returns the error "stateless evaluation not implemented".
#[allow(unused_variables)]
fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
bail!("stateless evaluation not implemented")
}
/// Evaluation with read access to the session state.
///
/// The default ignores `session`, delegates to `EvalOp::eval`, and wraps any error with
/// the context "Running legacy eval".
#[allow(unused_variables)]
fn eval_with_session(
&self,
session: &SessionState,
inputs: TVec<TValue>,
) -> TractResult<TVec<TValue>> {
self.eval(inputs).context("Running legacy eval")
}
/// Creates the mutable `OpState` for node `node_id`, if this operator needs one.
///
/// The default returns `Ok(None)` (no state).
#[allow(unused_variables)]
fn state(
&self,
session: &mut SessionState,
node_id: usize,
) -> TractResult<Option<Box<dyn OpState>>> {
Ok(None)
}
/// Required: whether the operator is stateless.
/// NOTE(review): presumably must agree with `state` returning `None` — confirm.
fn is_stateless(&self) -> bool;
}
/// Base trait of every operator: naming, validation level, equivalence and introspection.
///
/// Operators are thread-safe (`Send + Sync`), cloneable when boxed (`DynClone`) and
/// downcastable to their concrete type (`Downcast`).
pub trait Op: fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp {
/// Name of the operator; this is also what `Display` prints for boxed operators.
fn name(&self) -> Cow<str>;
/// Validation level of this operator's outputs; defaults to `Validation::Accurate`.
fn validation(&self) -> Validation {
Validation::Accurate
}
/// Whether `self` is equivalent to `other`; the default always answers `false`.
fn same_as(&self, _other: &dyn Op) -> bool {
false
}
/// Additional descriptive lines about the operator (presumably for dumps/debugging);
/// the default is empty.
fn info(&self) -> TractResult<Vec<String>> {
Ok(vec![])
}
/// Returns `Some` when this operator is also a `TypedOp`.
fn as_typed(&self) -> Option<&dyn TypedOp>;
}
/// An operator that can be part of a `TypedModel`.
///
/// Only `as_op`, `as_op_mut` and `output_facts` are required. Every other method has a
/// default implementation:
/// - the transformation hooks (`fuse`, `declutter`, `change_axes`, `slice`, `quantize`,
///   `codegen`) return `Ok(None)`;
/// - the list-returning queries return an empty list;
/// - `axes_mapping` returns `AxesMapping::disconnected`;
/// - `concretize_dims` copies the node unchanged.
pub trait TypedOp:
    Op + fmt::Debug + dyn_clone::DynClone + Send + Sync + 'static + Downcast + EvalOp
{
    /// Upcasts to `&dyn Op`.
    fn as_op(&self) -> &dyn Op;

    /// Upcasts to `&mut dyn Op`.
    fn as_op_mut(&mut self) -> &mut dyn Op;

    /// Computes the output facts from the input facts.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>>;

    /// Describes how input and output axes relate to each other.
    ///
    /// The default is `AxesMapping::disconnected(inputs, outputs)`.
    fn axes_mapping(
        &self,
        inputs: &[&TypedFact],
        outputs: &[&TypedFact],
    ) -> TractResult<AxesMapping> {
        AxesMapping::disconnected(inputs, outputs)
    }

    /// Optional patch proposed when fusing; the default proposes none.
    fn fuse(&self, _model: &TypedModel, _node: &TypedNode) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }

    /// Session-aware decluttering hook.
    ///
    /// The default ignores `session` and forwards to `TypedOp::declutter`.
    #[allow(unused_variables)]
    fn declutter_with_session(
        &self,
        session: &mut OptimizerSession,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        self.declutter(model, node)
    }

    /// Optional patch simplifying `node` within `model`; the default proposes none.
    #[allow(unused_variables)]
    fn declutter(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }

    /// Cost estimate as `(Cost, TDim)` pairs; the default reports nothing.
    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
        Ok(tvec!())
    }

    /// Axis changes this operator would like to see applied on its inputs/outputs;
    /// the default suggests none.
    fn suggested_axis_changes(&self) -> TractResult<TVec<(InOut, AxisOp)>> {
        Ok(tvec!())
    }

    /// Hook called with an axis `change` on input or output `io` of `node`.
    ///
    /// The default returns `Ok(None)`. NOTE(review): presumably meaning the operator cannot
    /// absorb the change; confirm with the change-axes pass.
    #[allow(unused_variables)]
    fn change_axes(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        io: InOut,
        change: &AxisOp,
    ) -> TractResult<Option<AxisChangeConsequence>> {
        Ok(None)
    }

    /// Hook for wiring a slice of this operator's computation into `patch`.
    ///
    /// The slice runs from `start` to `end` along `output_axis`, starting from `inputs`.
    /// `prefix` is presumably used to name the new nodes.
    /// The default returns `Ok(None)`.
    #[allow(unused_variables)]
    #[allow(clippy::too_many_arguments)]
    fn slice(
        &self,
        patch: &mut TypedModelPatch,
        prefix: &str,
        inputs: &[OutletId],
        output_axis: usize,
        start: usize,
        end: usize,
    ) -> TractResult<Option<TVec<OutletId>>> {
        Ok(None)
    }

    /// Optional replacement operator for quantization to `dt` with the given `scale` and
    /// `zero_point`; the default returns `Ok(None)`.
    #[allow(unused_variables)]
    fn quantize(
        &self,
        model: &TypedModel,
        node: &TypedNode,
        dt: DatumType,
        scale: f32,
        zero_point: i32,
    ) -> TractResult<Option<Box<dyn TypedOp>>> {
        Ok(None)
    }

    /// Rebuilds `node` from `source` into `target`, returning the new outlets.
    ///
    /// The default wires a clone of the operator unchanged under the same name.
    /// Each input is remapped through `mapping`; `source` and `values` are not used.
    ///
    /// # Errors
    /// Fails if an input of `node` has no entry in `mapping`, or if wiring into `target`
    /// fails.
    #[allow(unused_variables)]
    fn concretize_dims(
        &self,
        source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
        values: &SymbolValues,
    ) -> TractResult<TVec<OutletId>> {
        // A missing mapping is reported as an error instead of panicking on `mapping[i]`.
        let inputs = node
            .inputs
            .iter()
            .map(|i| {
                mapping.get(i).copied().with_context(|| {
                    format!("Concretizing node {}: no mapping for input {:?}", node.name, i)
                })
            })
            .collect::<TractResult<TVec<_>>>()?;
        target.wire_node(&node.name, node.op.clone(), &inputs)
    }

    /// Optional patch proposed during code generation; the default proposes none.
    #[allow(unused_variables)]
    fn codegen(
        &self,
        model: &TypedModel,
        node: &TypedNode,
    ) -> TractResult<Option<TypedModelPatch>> {
        Ok(None)
    }

    /// Named multipliers for nested models, presumably used for cost accounting (TODO
    /// confirm). The default is empty.
    #[allow(unused_variables)]
    fn nested_model_multipliers(&self, inputs: &[&TypedFact]) -> Vec<(Cow<str>, f64)> {
        vec![]
    }
}
// Downcasting helpers (e.g. `downcast_ref`, `is`) on `dyn Op` and `dyn TypedOp`.
impl_downcast!(Op);
impl_downcast!(TypedOp);
// `Clone` for `Box<dyn Op>` and `Box<dyn TypedOp>`, via their `DynClone` supertrait.
dyn_clone::clone_trait_object!(Op);
dyn_clone::clone_trait_object!(TypedOp);
/// Boxes any concrete `Op` into a `Box<dyn Op>`.
impl<O: Op> From<O> for Box<dyn Op> {
    fn from(op: O) -> Box<dyn Op> {
        Box::<O>::new(op)
    }
}

/// Boxes any concrete `TypedOp` into a `Box<dyn TypedOp>`.
impl<O: TypedOp> From<O> for Box<dyn TypedOp> {
    fn from(op: O) -> Box<dyn TypedOp> {
        Box::<O>::new(op)
    }
}

/// Clones a borrowed `Box<dyn TypedOp>` through its `dyn_clone`-provided `Clone` impl.
impl<'a> From<&'a Box<dyn TypedOp>> for Box<dyn TypedOp> {
    fn from(boxed: &'a Box<dyn TypedOp>) -> Box<dyn TypedOp> {
        Box::clone(boxed)
    }
}
impl AsRef<dyn Op> for dyn TypedOp {
fn as_ref(&self) -> &dyn Op {
self.as_op()
}
}
impl AsRef<dyn Op> for Box<dyn TypedOp> {
fn as_ref(&self) -> &dyn Op {
self.as_op()
}
}
impl AsMut<dyn Op> for dyn TypedOp {
fn as_mut(&mut self) -> &mut dyn Op {
self.as_op_mut()
}
}
impl AsMut<dyn Op> for Box<dyn TypedOp> {
fn as_mut(&mut self) -> &mut dyn Op {
self.as_op_mut()
}
}
/// Displays a boxed operator as its `Op::name`.
///
/// Uses `Formatter::pad`, so width, fill/alignment and precision flags (e.g. `{:>20}`) are
/// honoured. Without flags the output is exactly the name.
impl std::fmt::Display for Box<dyn Op> {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.pad(&self.name())
    }
}

/// Displays a boxed typed operator as its `Op::name`.
///
/// Uses `Formatter::pad`, so width, fill/alignment and precision flags are honoured.
/// Without flags the output is exactly the name.
impl std::fmt::Display for Box<dyn TypedOp> {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.pad(&self.name())
    }
}