Skip to main content

tract_gpu/
sync.rs

1use crate::fact::{DeviceFact, DeviceTypedFactExt};
2use crate::tensor::{DeviceTensorExt, IntoDevice};
3use derive_new::new;
4use std::collections::HashMap;
5use std::fmt;
6use std::sync::Arc;
7use tract_core::internal::*;
8
/// Direction of a [`DeviceSync`] transfer.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DeviceSyncKind {
    /// The tensor is moved from the device to the host.
    ToHost,
    /// The tensor is moved from the host to the device.
    ToDevice,
}
14
15impl fmt::Display for DeviceSyncKind {
16    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
17        write!(f, "{self:?}")
18    }
19}
20
21#[derive(Debug, Clone, new, Copy, PartialEq, Eq, Hash)]
22pub struct DeviceSync {
23    pub kind: DeviceSyncKind,
24}
25
26impl Op for DeviceSync {
27    fn name(&self) -> StaticName {
28        format!("DeviceSync{}", self.kind).into()
29    }
30
31    op_as_typed_op!();
32}
33
34impl EvalOp for DeviceSync {
35    fn is_stateless(&self) -> bool {
36        true
37    }
38
39    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
40        let input = args_1!(inputs);
41        match self.kind {
42            DeviceSyncKind::ToHost => {
43                let device_tensor = input.to_device_tensor()?;
44
45                let tensor = device_tensor
46                    .to_host()
47                    .with_context(|| "Error while syncing device tensor to host")?;
48                Ok(tvec![tensor.into_tvalue()])
49            }
50            DeviceSyncKind::ToDevice => {
51                let device_input = if let Some(t) = input.as_arc_tensor() {
52                    Arc::clone(t).into_device()?
53                } else {
54                    input.into_tensor().into_device()?
55                };
56                Ok(tvec![device_input.into_tensor().into()])
57            }
58        }
59    }
60}
61
62impl TypedOp for DeviceSync {
63    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
64        let input = inputs[0];
65        match self.kind {
66            DeviceSyncKind::ToHost => {
67                let mut typed_fact = input
68                    .to_device_fact()
69                    .with_context(|| {
70                        "Cannot sync to Host a tensor without DeviceFact as metadata in its TypedFact"
71                    })?
72                    .clone()
73                    .into_typed_fact();
74                if let Some(konst) = input.konst.clone() {
75                    if let Some(dt) = konst.as_device_tensor() {
76                        typed_fact.konst = Some(dt.to_host()?);
77                    } else {
78                        typed_fact.konst = Some(konst);
79                    }
80                }
81                Ok(tvec!(typed_fact))
82            }
83            DeviceSyncKind::ToDevice => {
84                ensure!(
85                    input.as_device_fact().is_none(),
86                    "Cannot sync to Device a tensor already on Device"
87                );
88                Ok(tvec![DeviceFact::from_host(input.clone())?.into_exotic_fact()])
89            }
90        }
91    }
92
93    as_op!();
94}
95
96/// Map node inputs through the translation mapping, inserting DeviceSync nodes
97/// where needed to move tensors to/from the device.
98pub fn sync_inputs_if_required(
99    model: &mut TypedModel,
100    node: &TypedNode,
101    mapping: &HashMap<OutletId, OutletId>,
102    sync_kind: DeviceSyncKind,
103) -> TractResult<TVec<OutletId>> {
104    let mut mapped_inputs = tvec![];
105    for (i_idx, i) in node.inputs.iter().enumerate() {
106        let in_fact = model.outlet_fact_mut(mapping[i])?;
107        match sync_kind {
108            DeviceSyncKind::ToHost if in_fact.as_device_fact().is_some() => {
109                mapped_inputs.push(
110                    model.wire_node(
111                        format!("{}.to-cpu-{i_idx}", node.name),
112                        DeviceSync::new(sync_kind),
113                        &[mapping[i]],
114                    )?[0],
115                );
116            }
117            DeviceSyncKind::ToDevice if in_fact.as_device_fact().is_none() => {
118                if let Some(ref konst) = in_fact.konst
119                    && konst.as_device_tensor().is_none()
120                {
121                    let device_konst = konst.as_ref().clone().into_device()?.into_tensor();
122                    let device_fact = DeviceFact::from_host(in_fact.clone())?;
123
124                    *in_fact = device_fact.into_exotic_fact();
125
126                    in_fact.konst = Some(Arc::new(device_konst));
127                    mapped_inputs.push(mapping[i]);
128                    continue;
129                }
130                ensure!(
131                    in_fact.datum_type.is_copy(),
132                    "Only copy DatumType can be sync to Device: {:?}",
133                    in_fact.datum_type
134                );
135
136                mapped_inputs.push(
137                    model.wire_node(
138                        format!("{}.to-device-{i_idx}", node.name),
139                        DeviceSync::new(sync_kind),
140                        &[mapping[i]],
141                    )?[0],
142                );
143            }
144            _ => mapped_inputs.push(mapping[i]),
145        }
146    }
147    Ok(mapped_inputs)
148}
149
150/// For model outputs that are on device, insert DeviceSync nodes to move them back to host.
151pub fn sync_model_outputs_if_required(
152    src: &TypedModel,
153    node: &TypedNode,
154    target: &mut TypedModel,
155    target_node_outlet_ids: TVec<OutletId>,
156) -> TractResult<TVec<OutletId>> {
157    let mut outputs = tvec![];
158    for (o_idx, o) in target_node_outlet_ids.into_iter().enumerate() {
159        let is_src_output = src.outputs.contains(&OutletId::new(node.id, o_idx));
160        if target.outlet_fact(o)?.as_device_fact().is_some() && is_src_output {
161            let sync_output = target.wire_node(
162                format!("{}.to-host-{o_idx}-out", node.name),
163                DeviceSync::new(DeviceSyncKind::ToHost),
164                &[o],
165            )?[0];
166            outputs.push(sync_output);
167        } else {
168            outputs.push(o)
169        }
170    }
171    Ok(outputs)
172}