tract_hir/ops/array/
concat.rs1use crate::infer::*;
2use crate::internal::*;
3
4pub use tract_core::ops::array::TypedConcat;
5use tract_core::ops::cast::wire_cast;
6
7#[derive(Debug, Clone, new, Hash)]
9pub struct Concat {
10 axis: i64,
11}
12
13
14
15impl Concat {
16 fn resolve_axis(&self, rank: i64) -> TractResult<usize> {
17 if 0 <= self.axis && self.axis < rank {
18 Ok(self.axis as usize)
19 } else if -rank <= self.axis && self.axis < 0 {
20 Ok((self.axis + rank) as usize)
21 } else {
22 bail!("Illegal combination of values for rank and axis: {} and {}", rank, self.axis)
23 }
24 }
25}
26
27impl Expansion for Concat {
28 fn name(&self) -> StaticName {
29 "InferenceConcat".into()
30 }
31
32 fn rules<'r, 'p: 'r, 's: 'r>(
33 &'s self,
34 s: &mut Solver<'r>,
35 inputs: &'p [TensorProxy],
36 outputs: &'p [TensorProxy],
37 ) -> InferenceResult {
38 check_output_arity(outputs, 1)?;
39 s.equals(&outputs[0].rank, &inputs[0].rank)?;
40 let n = inputs.len();
41 s.equals_all((0..n).map(|i| (&inputs[i].rank).bex()).collect())?;
42 s.given_all((0..n).map(|i| (&inputs[i].datum_type).bex()), move |s, dts| {
43 let super_type: DatumType = DatumType::super_type_for(&dts)
44 .with_context(|| format!("No supertype found for {dts:?}"))?;
45 s.equals(&outputs[0].datum_type, super_type)
46 })?;
47 s.given(&inputs[0].rank, move |s, rank| {
48 let axis = self.resolve_axis(rank)?;
49 s.equals(
50 rules::expr::SumExp::new((0..n).map(|i| (&inputs[i].shape[axis]).bex()).collect()),
51 &outputs[0].shape[axis],
52 )?;
53 for axis in 0..axis {
54 s.equals(&outputs[0].shape[axis], &inputs[0].shape[axis])?;
55 s.equals_all((0..n).map(|i| inputs[i].shape[axis].bex()).collect())?;
56 }
57 for axis in (axis + 1)..(rank as usize) {
58 s.equals(&outputs[0].shape[axis], &inputs[0].shape[axis])?;
59 s.equals_all((0..n).map(|i| inputs[i].shape[axis].bex()).collect())?;
60 }
61 Ok(())
62 })?;
63 Ok(())
64 }
65
66 fn wire(
67 &self,
68 prefix: &str,
69 target: &mut TypedModel,
70 inputs: &[OutletId],
71 ) -> TractResult<TVec<OutletId>> {
72 let facts = inputs
73 .iter()
74 .map(|i| target.outlet_fact(*i).cloned())
75 .collect::<TractResult<TVec<_>>>()?;
76
77 let super_type = if let Some(super_type) =
78 DatumType::super_type_for(facts.iter().map(|x| x.datum_type))
79 {
80 super_type
81 } else {
82 bail!("Can not type op");
83 };
84
85 let axis = self.resolve_axis(facts[0].shape.rank() as i64)?;
86
87 let inputs = wire_cast(prefix, target, inputs, super_type)?;
88 let op = TypedConcat::new(axis);
89 target.wire_node(prefix, op, &inputs)
90 }
91}