// tract_hir/ops/array/concat.rs

1use crate::infer::*;
2use crate::internal::*;
3
4pub use tract_core::ops::array::TypedConcat;
5use tract_core::ops::cast::wire_cast;
6
7/// Concat: high level concat op
8#[derive(Debug, Clone, new, Hash)]
9pub struct Concat {
10    axis: i64,
11}
12
13
14
15impl Concat {
16    fn resolve_axis(&self, rank: i64) -> TractResult<usize> {
17        if 0 <= self.axis && self.axis < rank {
18            Ok(self.axis as usize)
19        } else if -rank <= self.axis && self.axis < 0 {
20            Ok((self.axis + rank) as usize)
21        } else {
22            bail!("Illegal combination of values for rank and axis: {} and {}", rank, self.axis)
23        }
24    }
25}
26
27impl Expansion for Concat {
28    fn name(&self) -> StaticName {
29        "InferenceConcat".into()
30    }
31
32    fn rules<'r, 'p: 'r, 's: 'r>(
33        &'s self,
34        s: &mut Solver<'r>,
35        inputs: &'p [TensorProxy],
36        outputs: &'p [TensorProxy],
37    ) -> InferenceResult {
38        check_output_arity(outputs, 1)?;
39        s.equals(&outputs[0].rank, &inputs[0].rank)?;
40        let n = inputs.len();
41        s.equals_all((0..n).map(|i| (&inputs[i].rank).bex()).collect())?;
42        s.given_all((0..n).map(|i| (&inputs[i].datum_type).bex()), move |s, dts| {
43            let super_type: DatumType = DatumType::super_type_for(&dts)
44                .with_context(|| format!("No supertype found for {dts:?}"))?;
45            s.equals(&outputs[0].datum_type, super_type)
46        })?;
47        s.given(&inputs[0].rank, move |s, rank| {
48            let axis = self.resolve_axis(rank)?;
49            s.equals(
50                rules::expr::SumExp::new((0..n).map(|i| (&inputs[i].shape[axis]).bex()).collect()),
51                &outputs[0].shape[axis],
52            )?;
53            for axis in 0..axis {
54                s.equals(&outputs[0].shape[axis], &inputs[0].shape[axis])?;
55                s.equals_all((0..n).map(|i| inputs[i].shape[axis].bex()).collect())?;
56            }
57            for axis in (axis + 1)..(rank as usize) {
58                s.equals(&outputs[0].shape[axis], &inputs[0].shape[axis])?;
59                s.equals_all((0..n).map(|i| inputs[i].shape[axis].bex()).collect())?;
60            }
61            Ok(())
62        })?;
63        Ok(())
64    }
65
66    fn wire(
67        &self,
68        prefix: &str,
69        target: &mut TypedModel,
70        inputs: &[OutletId],
71    ) -> TractResult<TVec<OutletId>> {
72        let facts = inputs
73            .iter()
74            .map(|i| target.outlet_fact(*i).cloned())
75            .collect::<TractResult<TVec<_>>>()?;
76
77        let super_type = if let Some(super_type) =
78            DatumType::super_type_for(facts.iter().map(|x| x.datum_type))
79        {
80            super_type
81        } else {
82            bail!("Can not type op");
83        };
84
85        let axis = self.resolve_axis(facts[0].shape.rank() as i64)?;
86
87        let inputs = wire_cast(prefix, target, inputs, super_type)?;
88        let op = TypedConcat::new(axis);
89        target.wire_node(prefix, op, &inputs)
90    }
91}