tract_hir/ops/
element_wise.rs1use tract_core::ops::cast::wire_cast;
2
3use crate::infer::*;
4use crate::internal::*;
5
6#[derive(Debug, Clone)]
7pub struct ElementWiseOp(pub Box<dyn ElementWiseMiniOp>);
8
9impl Expansion for ElementWiseOp {
10 fn name(&self) -> StaticName {
11 self.0.name().into()
12 }
13
14 fn wire(
15 &self,
16 prefix: &str,
17 target: &mut TypedModel,
18 inputs: &[OutletId],
19 ) -> TractResult<TVec<OutletId>> {
20 let operating_datum_type =
21 self.0.operating_datum_type(target.outlet_fact(inputs[0])?.datum_type);
22 let wires = wire_cast(prefix, target, inputs, operating_datum_type)?;
23 target.wire_node(
24 prefix,
25 tract_core::ops::element_wise::ElementWiseOp(self.0.clone(), None),
26 &wires,
27 )
28 }
29
30 fn rules<'r, 'p: 'r, 's: 'r>(
31 &'s self,
32 s: &mut Solver<'r>,
33 inputs: &'p [TensorProxy],
34 outputs: &'p [TensorProxy],
35 ) -> InferenceResult {
36 check_input_arity(inputs, 1)?;
37 check_output_arity(outputs, 1)?;
38 s.given(&inputs[0].datum_type, move |s, dt| {
39 let dt = self.0.operating_datum_type(dt);
40 if let Some(dt) = self.0.output_type(dt) {
41 s.equals(&outputs[0].datum_type, dt)
42 } else {
43 s.equals(&outputs[0].datum_type, dt)
44 }
45 })?;
46 s.equals(&inputs[0].shape, &outputs[0].shape)?;
47 Ok(())
48 }
49}
50
51pub trait ElementWiseIntoHir {
52 fn into_hir(self) -> Box<dyn InferenceOp>;
53}
54
55impl ElementWiseIntoHir for tract_core::ops::element_wise::ElementWiseOp {
56 fn into_hir(self) -> Box<dyn InferenceOp> {
57 expand(ElementWiseOp(self.0))
58 }
59}