tract_hir/ops/
element_wise.rs

1use tract_core::ops::cast::wire_cast;
2
3use crate::infer::*;
4use crate::internal::*;
5
6#[derive(Debug, Clone)]
7pub struct ElementWiseOp(pub Box<dyn ElementWiseMiniOp>);
8
9impl Expansion for ElementWiseOp {
10    fn name(&self) -> StaticName {
11        self.0.name().into()
12    }
13
14    fn wire(
15        &self,
16        prefix: &str,
17        target: &mut TypedModel,
18        inputs: &[OutletId],
19    ) -> TractResult<TVec<OutletId>> {
20        let operating_datum_type =
21            self.0.operating_datum_type(target.outlet_fact(inputs[0])?.datum_type);
22        let wires = wire_cast(prefix, target, inputs, operating_datum_type)?;
23        target.wire_node(
24            prefix,
25            tract_core::ops::element_wise::ElementWiseOp(self.0.clone(), None),
26            &wires,
27        )
28    }
29
30    fn rules<'r, 'p: 'r, 's: 'r>(
31        &'s self,
32        s: &mut Solver<'r>,
33        inputs: &'p [TensorProxy],
34        outputs: &'p [TensorProxy],
35    ) -> InferenceResult {
36        check_input_arity(inputs, 1)?;
37        check_output_arity(outputs, 1)?;
38        s.given(&inputs[0].datum_type, move |s, dt| {
39            let dt = self.0.operating_datum_type(dt);
40            if let Some(dt) = self.0.output_type(dt) {
41                s.equals(&outputs[0].datum_type, dt)
42            } else {
43                s.equals(&outputs[0].datum_type, dt)
44            }
45        })?;
46        s.equals(&inputs[0].shape, &outputs[0].shape)?;
47        Ok(())
48    }
49}
50
51pub trait ElementWiseIntoHir {
52    fn into_hir(self) -> Box<dyn InferenceOp>;
53}
54
55impl ElementWiseIntoHir for tract_core::ops::element_wise::ElementWiseOp {
56    fn into_hir(self) -> Box<dyn InferenceOp> {
57        expand(ElementWiseOp(self.0))
58    }
59}