tract_hir/ops/array/
broadcast.rs1use crate::infer::*;
2use crate::internal::*;
3
4use tract_core::ops::array::MultiBroadcastTo as Typed;
5
6#[derive(Debug, Clone, new, Default, Hash)]
7pub struct MultiBroadcastTo;
8
9impl MultiBroadcastTo {
10 fn wire_with_known_target_shape(
11 &self,
12 prefix: &str,
13 model: &mut TypedModel,
14 inputs: &[OutletId],
15 target_shape: &[TDim],
16 ) -> TractResult<TVec<OutletId>> {
17 let left_shape = model.outlet_fact(inputs[0])?.shape.to_tvec();
18 let dims = tract_core::broadcast::multi_broadcast(&[&*left_shape, target_shape])?;
19 let op = Typed::new(dims.into());
20 model.wire_node(prefix, op, &[inputs[0]])
21 }
22}
23
24impl Expansion for MultiBroadcastTo {
25 fn name(&self) -> StaticName {
26 "MultiBroadcastTo".into()
27 }
28
29 fn rules<'r, 'p: 'r, 's: 'r>(
30 &'s self,
31 s: &mut Solver<'r>,
32 inputs: &'p [TensorProxy],
33 outputs: &'p [TensorProxy],
34 ) -> InferenceResult {
35 check_input_arity(inputs, 2)?;
36 check_output_arity(outputs, 1)?;
37 s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
38 s.equals(&inputs[1].rank, 1)?;
39 s.given(&inputs[0].shape, move |s, shape| {
40 s.given(&inputs[1].value, move |s, dims| {
41 let dims = dims.cast_to::<TDim>()?;
42 let dims =
43 tract_core::broadcast::multi_broadcast(&[dims.as_slice::<TDim>()?, &shape])?;
44 s.equals(&outputs[0].shape, ShapeFactoid::from(dims))
45 })
46 })?;
47 Ok(())
48 }
49
50 fn wire(
51 &self,
52 prefix: &str,
53 model: &mut TypedModel,
54 inputs: &[OutletId],
55 ) -> TractResult<TVec<OutletId>> {
56 if let Some(shape) = model.outlet_fact(inputs[1])?.konst.clone() {
57 let shape = shape.cast_to::<TDim>()?;
58 self.wire_with_known_target_shape(prefix, model, inputs, shape.as_slice()?)
59 } else {
60 bail!("shape input is variable")
61 }
62 }
63
64 fn wire_with_inference_model_and_node(
65 &self,
66 prefix: &str,
67 source: &InferenceModel,
68 node: &InferenceNode,
69 model: &mut TypedModel,
70 inputs: &[OutletId],
71 ) -> TractResult<TVec<OutletId>> {
72 if let Some(shape) = model.outlet_fact(inputs[1])?.konst.clone() {
73 let shape = shape.cast_to::<TDim>()?;
74 self.wire_with_known_target_shape(prefix, model, inputs, shape.as_slice()?)
75 } else if let Some(shape) = source.outlet_fact(node.id.into())?.shape.concretize() {
76 let op = Typed::new(shape.into());
77 model.wire_node(prefix, op, &[inputs[0]])
78 } else {
79 bail!("shape input is variable, of variable length (output can not have variable rank)")
80 }
81 }
82}