Skip to main content

tract_gpu/
sync.rs

1use crate::fact::{DeviceFact, DeviceTypedFactExt};
2use crate::tensor::{DeviceTensorExt, IntoDevice};
3use derive_new::new;
4use std::fmt;
5use tract_core::internal::*;
6
/// Direction of a host/device synchronization performed by [`DeviceSync`].
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum DeviceSyncKind {
    /// Convert a device tensor back into a host tensor.
    ToHost,
    /// Convert a host tensor into a device tensor.
    ToDevice,
}
12
13impl fmt::Display for DeviceSyncKind {
14    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
15        write!(f, "{self:?}")
16    }
17}
18
19#[derive(Debug, Clone, new, Copy, PartialEq, Eq, Hash)]
20pub struct DeviceSync {
21    pub kind: DeviceSyncKind,
22}
23
24impl Op for DeviceSync {
25    fn name(&self) -> StaticName {
26        format!("DeviceSync{}", self.kind).into()
27    }
28
29    fn same_as(&self, other: &dyn Op) -> bool {
30        let Some(other) = other.downcast_ref::<DeviceSync>() else { return false };
31        self == other
32    }
33
34    op_as_typed_op!();
35}
36
37impl EvalOp for DeviceSync {
38    fn is_stateless(&self) -> bool {
39        true
40    }
41
42    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
43        let input = args_1!(inputs);
44        match self.kind {
45            DeviceSyncKind::ToHost => {
46                let device_tensor = input.to_device_tensor()?;
47
48                let tensor = device_tensor
49                    .to_host()
50                    .with_context(|| "Error while syncing device tensor to host")?;
51                Ok(tvec![tensor.into_tvalue()])
52            }
53            DeviceSyncKind::ToDevice => {
54                let device_input = if let Some(t) = input.as_arc_tensor() {
55                    Arc::clone(t).into_device()?
56                } else {
57                    input.into_tensor().into_device()?
58                };
59                Ok(tvec![device_input.into_opaque_tensor().into()])
60            }
61        }
62    }
63}
64
65impl TypedOp for DeviceSync {
66    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
67        let input = inputs[0];
68        match self.kind {
69            DeviceSyncKind::ToHost => Ok(tvec![input
70                .to_device_fact()
71                .with_context(|| {
72                    "Cannot sync to Host a tensor without DeviceFact as metadata in its TypedFact"
73                })?
74                .clone()
75                .into_typed_fact()]),
76            DeviceSyncKind::ToDevice => {
77                ensure!(
78                    input.datum_type != DatumType::Opaque,
79                    "Cannot sync Opaque Tensor to Device"
80                );
81                Ok(tvec![DeviceFact::from_host(input.clone())?.into_opaque_fact()])
82            }
83        }
84    }
85
86    as_op!();
87}