// tract_gpu/ops/concat.rs

use crate::tensor::DeviceTensorExt;
use tract_core::internal::*;

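/// Concatenation of device tensors along a single, fixed axis.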
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct GpuConcat {
    pub axis: usize,
}

impl GpuConcat {
    pub fn new(axis: usize) -> Self {
        Self { axis }
    }

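    /// Cumulative offsets of each input along the concatenation axis:
    /// a leading 0, then one running total per input. The final entry is
    /// the full extent of the output along that axis.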
    pub fn offsets(&self, inputs: &[&TypedFact]) -> TractResult<Vec<TDim>> {
        let mut offsets = vec![0.to_dim()];
        for slice in inputs {
            let len = slice.shape[self.axis].clone();
            let offset = len + offsets.last().unwrap();
            offsets.push(offset)
        }
        Ok(offsets)
    }
}

impl Op for GpuConcat {
    fn name(&self) -> StaticName {
        "GpuConcat".into()
    }

    fn info(&self) -> TractResult<Vec<String>> {
        Ok(vec![format!("axis: {}", self.axis)])
    }

    op_as_typed_op!();
}

impl EvalOp for GpuConcat {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval_with_session(
        &self,
        node_id: usize,
        session: &TurnState,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
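        // Resolve every input value to its device-side tensor view.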
        let inputs =
            inputs.iter().map(|it| it.to_device_tensor()).collect::<TractResult<TVec<_>>>()?;

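        // The output shape matches the first input everywhere except on the
        // concat axis, where it is the sum of all the slices' lengths.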
        let mut output_shape = inputs[0].shape().to_vec();
        output_shape[self.axis] = inputs.iter().map(|it| it.shape()[self.axis]).sum();
        let output = crate::session_handler::make_tensor_for_node(
            session,
            node_id,
            inputs[0].datum_type(),
            &output_shape,
        )?;

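        // Copy each input into its zone of the output: one strided n-d copy
        // per non-empty slice, advancing the cursor along the concat axis.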
        let ctx = crate::device::get_context()?;
        let mut cursor = 0usize;
        for input in &inputs {
            let slice_len = input.shape()[self.axis];
            if slice_len == 0 {
                continue;
            }
            // The copy zone covers the whole input slice.
            let zone_shape = input.shape();
            // Byte offset of this slice's zone in the output, along the concat axis.
            let dst_offset =
                cursor * output.strides()[self.axis] as usize * output.datum_type().size_of();

            ctx.copy_nd(
                input,
                0,
                input.strides(),
                &output,
                dst_offset,
                zone_shape,
                output.strides(),
            )
            .with_context(|| {
                format!(
                    "Error in concat dispatch for slice at offset {} (shape {:?})",
                    cursor, zone_shape
                )
            })?;
            cursor += slice_len;
        }

        Ok(tvec!(output.into_tensor().into_tvalue()))
    }
}

impl TypedOp for GpuConcat {
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
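        // All inputs must share the first input's rank and agree on every
        // dimension except the concat axis, which sums up across inputs.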
        crate::utils::facts_to_device_facts(inputs, |facts| {
            let mut fact = facts[0].without_value();
            for input in facts {
                if input.rank() != fact.rank()
                    || input
                        .shape
                        .iter()
                        .zip(fact.shape.iter())
                        .enumerate()
                        .filter(|(ax, _)| *ax != self.axis)
                        .any(|(_, (i, f))| i != f)
                {
                    bail!("Inconsistent {:?} inputs: {:?}", self, facts);
                }
            }
            fact.shape.set(self.axis, self.offsets(facts)?.pop().unwrap());
            Ok(tvec!(fact))
        })
        .with_context(|| format!("Error while computing facts for {:?}", self.name()))
    }

    as_op!();
}
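
#[cfg(test)]
mod test {
    // A minimal sketch of the offsets() bookkeeping, not taken from the
    // crate's actual test suite. It assumes tract_core's `f32::fact` helper
    // (exact signature may differ across tract versions) to build TypedFacts
    // with concrete shapes.
    use super::*;

    #[test]
    fn offsets_accumulate_along_axis() -> TractResult<()> {
        let op = GpuConcat::new(0);
        let a = f32::fact([2, 3]);
        let b = f32::fact([4, 3]);
        // A leading 0, then cumulative extents: 2, then 2 + 4 = 6.
        let offsets = op.offsets(&[&a, &b])?;
        assert_eq!(offsets, vec![0.to_dim(), 2.to_dim(), 6.to_dim()]);
        Ok(())
    }
}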