tract_hir/ops/array/broadcast.rs

1use crate::infer::*;
2use crate::internal::*;
3
4use tract_core::ops::array::MultiBroadcastTo as Typed;
5
6#[derive(Debug, Clone, new, Default, Hash)]
7pub struct MultiBroadcastTo;
8
9impl MultiBroadcastTo {
10    fn wire_with_known_target_shape(
11        &self,
12        prefix: &str,
13        model: &mut TypedModel,
14        inputs: &[OutletId],
15        target_shape: &[TDim],
16    ) -> TractResult<TVec<OutletId>> {
17        let left_shape = model.outlet_fact(inputs[0])?.shape.to_tvec();
18        let dims = tract_core::broadcast::multi_broadcast(&[&*left_shape, target_shape])?;
19        let op = Typed::new(dims.into());
20        model.wire_node(prefix, op, &[inputs[0]])
21    }
22}
23
24impl Expansion for MultiBroadcastTo {
25    fn name(&self) -> StaticName {
26        "MultiBroadcastTo".into()
27    }
28
29    fn rules<'r, 'p: 'r, 's: 'r>(
30        &'s self,
31        s: &mut Solver<'r>,
32        inputs: &'p [TensorProxy],
33        outputs: &'p [TensorProxy],
34    ) -> InferenceResult {
35        check_input_arity(inputs, 2)?;
36        check_output_arity(outputs, 1)?;
37        s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
38        s.equals(&inputs[1].rank, 1)?;
39        s.given(&inputs[0].shape, move |s, shape| {
40            s.given(&inputs[1].value, move |s, dims| {
41                let dims = dims.cast_to::<TDim>()?;
42                let dims =
43                    tract_core::broadcast::multi_broadcast(&[dims.as_slice::<TDim>()?, &shape])?;
44                s.equals(&outputs[0].shape, ShapeFactoid::from(dims))
45            })
46        })?;
47        Ok(())
48    }
49
50    fn wire(
51        &self,
52        prefix: &str,
53        model: &mut TypedModel,
54        inputs: &[OutletId],
55    ) -> TractResult<TVec<OutletId>> {
56        if let Some(shape) = model.outlet_fact(inputs[1])?.konst.clone() {
57            let shape = shape.cast_to::<TDim>()?;
58            self.wire_with_known_target_shape(prefix, model, inputs, shape.as_slice()?)
59        } else {
60            bail!("shape input is variable")
61        }
62    }
63
64    fn wire_with_inference_model_and_node(
65        &self,
66        prefix: &str,
67        source: &InferenceModel,
68        node: &InferenceNode,
69        model: &mut TypedModel,
70        inputs: &[OutletId],
71    ) -> TractResult<TVec<OutletId>> {
72        if let Some(shape) = model.outlet_fact(inputs[1])?.konst.clone() {
73            let shape = shape.cast_to::<TDim>()?;
74            self.wire_with_known_target_shape(prefix, model, inputs, shape.as_slice()?)
75        } else if let Some(shape) = source.outlet_fact(node.id.into())?.shape.concretize() {
76            let op = Typed::new(shape.into());
77            model.wire_node(prefix, op, &[inputs[0]])
78        } else {
79            bail!("shape input is variable, of variable length (output can not have variable rank)")
80        }
81    }
82}