1use crate::fact::{DeviceFact, DeviceTypedFactExt};
2use crate::tensor::{DeviceTensorExt, IntoDevice};
3use derive_new::new;
4use std::fmt;
5use tract_core::internal::*;
6
/// Direction of the transfer performed by a `DeviceSync` operator.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DeviceSyncKind {
    /// Bring a device tensor back to the host.
    ToHost,
    /// Send a host tensor to the device.
    ToDevice,
}

/// Prints the bare variant name, exactly as the derived `Debug` does.
impl fmt::Display for DeviceSyncKind {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::ToHost => "ToHost",
            Self::ToDevice => "ToDevice",
        };
        f.write_str(label)
    }
}
18
19#[derive(Debug, Clone, new, Copy, PartialEq, Eq, Hash)]
20pub struct DeviceSync {
21 pub kind: DeviceSyncKind,
22}
23
24impl Op for DeviceSync {
25 fn name(&self) -> StaticName {
26 format!("DeviceSync{}", self.kind).into()
27 }
28
29 fn same_as(&self, other: &dyn Op) -> bool {
30 let Some(other) = other.downcast_ref::<DeviceSync>() else { return false };
31 self == other
32 }
33
34 op_as_typed_op!();
35}
36
37impl EvalOp for DeviceSync {
38 fn is_stateless(&self) -> bool {
39 true
40 }
41
42 fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
43 let input = args_1!(inputs);
44 match self.kind {
45 DeviceSyncKind::ToHost => {
46 let device_tensor = input.to_device_tensor()?;
47
48 let tensor = device_tensor
49 .to_host()
50 .with_context(|| "Error while syncing device tensor to host")?;
51 Ok(tvec![tensor.into_tvalue()])
52 }
53 DeviceSyncKind::ToDevice => {
54 let device_input = if let Some(t) = input.as_arc_tensor() {
55 Arc::clone(t).into_device()?
56 } else {
57 input.into_tensor().into_device()?
58 };
59 Ok(tvec![device_input.into_opaque_tensor().into()])
60 }
61 }
62 }
63}
64
65impl TypedOp for DeviceSync {
66 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
67 let input = inputs[0];
68 match self.kind {
69 DeviceSyncKind::ToHost => {
70 let mut typed_fact = input
71 .to_device_fact()
72 .with_context(|| {
73 "Cannot sync to Host a tensor without DeviceFact as metadata in its TypedFact"
74 })?
75 .clone()
76 .into_typed_fact();
77 if let Some(konst) = input.konst.clone() {
78 typed_fact.konst = Some(konst.to_device_tensor()?.to_host()?);
79 }
80 Ok(tvec!(typed_fact))
81 }
82 DeviceSyncKind::ToDevice => {
83 ensure!(
84 input.datum_type != DatumType::Opaque,
85 "Cannot sync Opaque Tensor to Device"
86 );
87 Ok(tvec![DeviceFact::from_host(input.clone())?.into_opaque_fact()])
88 }
89 }
90 }
91
92 as_op!();
93}