1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110
use crate::internal::*;
use tract_core::internal::*;

/// Turns an [`Expansion`] into a boxed inference operator.
///
/// The value is first boxed as a `Box<dyn Expansion>`; that box (which
/// implements the operator traits below) is then boxed again and returned as
/// a `Box<dyn InferenceOp>`.
pub fn expand<E: Expansion>(e: E) -> Box<dyn InferenceOp> {
    let expansion: Box<dyn Expansion> = Box::new(e);
    Box::new(expansion)
}

/// An operator described by how it wires itself into a [`TypedModel`]
/// ([`Expansion::wire`]) and by the inference rules it contributes to a
/// solver ([`Expansion::rules`]).
pub trait Expansion:
    tract_core::dyn_clone::DynClone
    + std::fmt::Debug
    + Send
    + Sync
    + tract_core::downcast_rs::Downcast
    + tract_core::internal::DynHash
{
    /// Name of the operator.
    fn name(&self) -> Cow<str>;

    /// Operator families this expansion belongs to.
    fn op_families(&self) -> &'static [&'static str];

    /// Validation level of the operator; `Validation::Accurate` unless
    /// overridden.
    fn validation(&self) -> Validation {
        Validation::Accurate
    }

    /// Informational strings about the operator; empty unless overridden.
    fn info(&self) -> TractResult<Vec<String>> {
        Ok(Vec::new())
    }

    /// Number of outputs of the operator; `1` unless overridden.
    fn nboutputs(&self) -> TractResult<usize> {
        Ok(1)
    }

    /// Wires this operator into `model`, reading from the `inputs` outlets,
    /// and returns the resulting outlets. `prefix` is supplied by the caller
    /// (the node name in `to_typed`, `"adhoc"` in `eval`).
    fn wire(
        &self,
        prefix: &str,
        model: &mut TypedModel,
        inputs: &[OutletId],
    ) -> TractResult<TVec<OutletId>>;

    /// Registers this operator's inference rules on solver `s` for the given
    /// input and output tensor proxies.
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult;
}

tract_core::dyn_clone::clone_trait_object!(Expansion);

impl Hash for Box<dyn Expansion> {
    /// Feeds a type id followed by the dynamic hash into `state`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        // NOTE(review): `self` is a `&Box<dyn Expansion>` here, so which impl
        // `type_id()` and `dyn_hash()` resolve to (the box's or the inner
        // object's) depends on the traits in scope; confirm it is the
        // intended one.
        Hash::hash(&self.type_id(), state);
        self.dyn_hash(state)
    }
}

tract_linalg::impl_dyn_hash!(Box<dyn Expansion>);

/// `Op` for a boxed expansion: every method forwards to the wrapped
/// [`Expansion`].
impl Op for Box<dyn Expansion> {
    fn name(&self) -> Cow<str> {
        self.as_ref().name()
    }

    fn op_families(&self) -> &'static [&'static str] {
        self.as_ref().op_families()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        self.as_ref().info()
    }

    not_a_typed_op!();
    not_a_pulsed_op!();
}

impl StatelessOp for Box<dyn Expansion> {
    /// Evaluates the expansion on concrete tensors.
    ///
    /// A throwaway [`TypedModel`] is built with one source per input (its
    /// fact derived from the input tensor), the expansion is wired onto
    /// those sources, the wired outlets become the model outputs, and the
    /// model is run through a [`SimplePlan`] fed with the inputs.
    fn eval(&self, inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
        let mut adhoc = TypedModel::default();

        // One source node per input tensor, in input order.
        let mut sources: TVec<OutletId> = TVec::new();
        for (ix, tensor) in inputs.iter().enumerate() {
            let fact = TypedFact::from(&**tensor);
            sources.push(adhoc.add_source(format!("adhoc-source-{}", ix), fact)?);
        }

        let outputs = self.wire("adhoc", &mut adhoc, &*sources)?;
        adhoc.set_output_outlets(&*outputs)?;

        let plan = SimplePlan::new(adhoc)?;
        let feeds = inputs.into_iter().map(|t| t.into_tensor()).collect();
        plan.run(feeds)
    }
}

impl InferenceRulesOp for Box<dyn Expansion> {
    /// Forwards to the wrapped expansion's rules.
    fn rules<'r, 'p: 'r, 's: 'r>(
        &'s self,
        s: &mut Solver<'r>,
        inputs: &'p [TensorProxy],
        outputs: &'p [TensorProxy],
    ) -> InferenceResult {
        self.as_ref().rules(s, inputs, outputs)
    }

    /// Translates `node` into `target` by wiring the expansion under the
    /// node's name, with the node's inputs looked up in `mapping`.
    fn to_typed(
        &self,
        _source: &InferenceModel,
        node: &InferenceNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        // Indexing the map panics if an input outlet has no entry in `mapping`.
        let inputs: Vec<OutletId> = node.inputs.iter().map(|outlet| mapping[outlet]).collect();
        self.wire(&node.name, target, &inputs)
    }

    /// Forwards to the wrapped expansion's output count.
    fn nboutputs(&self) -> TractResult<usize> {
        self.as_ref().nboutputs()
    }

    as_op!();
}