Skip to main content

tract_gpu/ops/
broadcast.rs

1use crate::tensor::DeviceTensorExt;
2use crate::utils::compute_broadcast_strides;
3use tract_core::internal::*;
4
5#[derive(Clone, Debug, PartialEq, Eq, Hash)]
6pub struct GpuMultiBroadcastTo {
7    pub shape: ShapeFact,
8}
9
10impl GpuMultiBroadcastTo {
11    pub fn new(shape: ShapeFact) -> Self {
12        Self { shape }
13    }
14}
15
16impl Op for GpuMultiBroadcastTo {
17    fn name(&self) -> StaticName {
18        "GpuMultiBroadcastTo".into()
19    }
20
21    op_as_typed_op!();
22}
23
24impl EvalOp for GpuMultiBroadcastTo {
25    fn is_stateless(&self) -> bool {
26        true
27    }
28
29    fn eval_with_session(
30        &self,
31        node_id: usize,
32        session: &TurnState,
33        inputs: TVec<TValue>,
34    ) -> TractResult<TVec<TValue>> {
35        let input_value = args_1!(inputs);
36        let input = input_value.to_device_tensor()?;
37        let shape = self.shape.eval_to_usize(&session.resolved_symbols)?;
38        let output = crate::session_handler::make_tensor_for_node(
39            session,
40            node_id,
41            input.datum_type(),
42            &shape,
43        )?;
44
45        // Pad input shape/strides to output rank for broadcasting
46        let mut input_strides = vec![input.strides()[0]; output.rank() - input.rank()];
47        input_strides.extend(input.strides());
48        let mut input_shape = vec![1usize; output.rank() - input.rank()];
49        input_shape.extend(input.shape());
50        let broadcast_strides: TVec<isize> =
51            compute_broadcast_strides(&input_shape, &input_strides)?;
52
53        let ctx = crate::device::get_context()?;
54        ctx.copy_nd(input, 0, &broadcast_strides, &output, 0, output.shape(), output.strides())?;
55        Ok(tvec![output.into_tensor().into_tvalue()])
56    }
57}
58
59impl TypedOp for GpuMultiBroadcastTo {
60    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
61        crate::utils::facts_to_device_facts(inputs, |facts| {
62            let mut fact = facts[0].datum_type.fact(self.shape.clone());
63            fact.uniform.clone_from(&inputs[0].uniform);
64            Ok(tvec!(fact))
65        })
66        .with_context(|| format!("Error while computing facts for {:?}", self.name()))
67    }
68
69    as_op!();
70}