1use crate::tensor::DeviceTensorExt;
2use tract_core::internal::*;
3
4use crate::tensor::DeviceTensor;
5
6pub type DispatchCastFn = fn(&DeviceTensor, &DeviceTensor) -> TractResult<()>;
7
8#[derive(Clone)]
9pub struct GpuCast {
10 pub to: DatumType,
11 pub backend_name: &'static str,
12 pub dispatch: DispatchCastFn,
13 pub is_supported_dt: fn(DatumType) -> bool,
14}
15
16impl GpuCast {
17 pub fn new(
18 to: DatumType,
19 backend_name: &'static str,
20 dispatch: DispatchCastFn,
21 is_supported_dt: fn(DatumType) -> bool,
22 ) -> Option<Self> {
23 is_supported_dt(to).then_some(Self { to, backend_name, dispatch, is_supported_dt })
24 }
25
26 pub fn is_supported_dt(&self, dt: DatumType) -> bool {
27 (self.is_supported_dt)(dt)
28 }
29}
30
31impl std::fmt::Debug for GpuCast {
32 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
33 write!(f, "{}Cast({:?})", self.backend_name, self.to)
34 }
35}
36
37impl PartialEq for GpuCast {
38 fn eq(&self, other: &Self) -> bool {
39 self.backend_name == other.backend_name && self.to == other.to
40 }
41}
42
43impl Eq for GpuCast {}
44
45impl std::hash::Hash for GpuCast {
46 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
47 self.backend_name.hash(state);
48 self.to.hash(state);
49 }
50}
51
52impl Op for GpuCast {
53 fn name(&self) -> StaticName {
54 format!("{}Cast", self.backend_name).into()
55 }
56
57 op_as_typed_op!();
58}
59
60impl EvalOp for GpuCast {
61 fn is_stateless(&self) -> bool {
62 true
63 }
64
65 fn eval_with_session(
66 &self,
67 node_id: usize,
68 session: &TurnState,
69 inputs: TVec<TValue>,
70 ) -> TractResult<TVec<TValue>> {
71 let input_value = args_1!(inputs);
72 let input = input_value.to_device_tensor()?;
73 if input.datum_type() == self.to {
74 Ok(tvec!(input_value))
75 } else {
76 let output = crate::session_handler::make_tensor_for_node(
77 session,
78 node_id,
79 self.to,
80 input.shape(),
81 )?;
82 (self.dispatch)(input, &output)?;
83 Ok(tvec![output.into_tensor().into_tvalue()])
84 }
85 }
86}
87
88impl TypedOp for GpuCast {
89 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
90 crate::utils::facts_to_device_facts(inputs, |facts| {
91 Ok(tvec!(self.to.fact(facts[0].shape.clone())))
92 })
93 .with_context(|| format!("Error while computing facts for {:?}", self.name()))
94 }
95
96 as_op!();
97}