1use crate::tensor::DeviceTensorExt;
2use tract_core::internal::*;
3
/// Concatenation operator working on device tensors.
///
/// All inputs are joined along a single axis; every other dimension must be
/// identical across inputs (enforced when output facts are computed).
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct GpuConcat {
    /// Index of the dimension along which inputs are concatenated.
    pub axis: usize,
}
8
9impl GpuConcat {
10 pub fn new(axis: usize) -> Self {
11 Self { axis }
12 }
13
14 pub fn offsets(&self, inputs: &[&TypedFact]) -> TractResult<Vec<TDim>> {
15 let mut offsets = vec![0.to_dim()];
16 for slice in inputs {
17 let len = slice.shape[self.axis].clone();
18 let offset = len + offsets.last().unwrap();
19 offsets.push(offset)
20 }
21 Ok(offsets)
22 }
23}
24
25impl Op for GpuConcat {
26 fn name(&self) -> StaticName {
27 "GpuConcat".into()
28 }
29
30 fn info(&self) -> TractResult<Vec<String>> {
31 Ok(vec![format!("axis: {}", self.axis)])
32 }
33
34 op_as_typed_op!();
35}
36
37impl EvalOp for GpuConcat {
38 fn is_stateless(&self) -> bool {
39 true
40 }
41
42 fn eval_with_session(
43 &self,
44 node_id: usize,
45 session: &TurnState,
46 inputs: TVec<TValue>,
47 ) -> TractResult<TVec<TValue>> {
48 let inputs =
49 inputs.iter().map(|it| it.to_device_tensor()).collect::<TractResult<TVec<_>>>()?;
50
51 let mut output_shape = inputs[0].shape().to_vec();
52 output_shape[self.axis] = inputs.iter().map(|it| it.shape()[self.axis]).sum();
53 let output = crate::session_handler::make_tensor_for_node(
54 session,
55 node_id,
56 inputs[0].datum_type(),
57 &output_shape,
58 )?;
59
60 let ctx = crate::device::get_context()?;
61 let mut cursor = 0usize;
62 for input in &inputs {
63 let slice_len = input.shape()[self.axis];
64 if slice_len == 0 {
65 continue;
66 }
67 let zone_shape = input.shape();
69 let dst_offset =
71 cursor * output.strides()[self.axis] as usize * output.datum_type().size_of();
72
73 ctx.copy_nd(
74 input,
75 0,
76 input.strides(),
77 &output,
78 dst_offset,
79 zone_shape,
80 output.strides(),
81 )
82 .with_context(|| {
83 format!(
84 "Error in concat dispatch for slice at offset {} (shape {:?})",
85 cursor, zone_shape
86 )
87 })?;
88 cursor += slice_len;
89 }
90
91 Ok(tvec!(output.into_tensor().into_tvalue()))
92 }
93}
94
95impl TypedOp for GpuConcat {
96 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
97 crate::utils::facts_to_device_facts(inputs, |facts| {
98 let mut fact = facts[0].without_value();
99 for input in facts {
100 if input.rank() != fact.rank()
101 || input
102 .shape
103 .iter()
104 .zip(fact.shape.iter())
105 .enumerate()
106 .filter(|(ax, _)| *ax != self.axis)
107 .any(|(_, (i, f))| i != f)
108 {
109 bail!("Inconsistent {:?} inputs: {:?}", self, facts);
110 }
111 }
112 fact.shape.set(self.axis, self.offsets(facts)?.pop().unwrap());
113 Ok(tvec!(fact))
114 })
115 .with_context(|| format!("Error while computing facts for {:?}", self.name()))
116 }
117
118 as_op!();
119}