Skip to main content

tract_gpu/ops/
cast.rs

1use crate::tensor::DeviceTensorExt;
2use tract_core::internal::*;
3
4use crate::tensor::DeviceTensor;
5
6pub type DispatchCastFn = fn(&DeviceTensor, &DeviceTensor) -> TractResult<()>;
7
8#[derive(Clone)]
9pub struct GpuCast {
10    pub to: DatumType,
11    pub backend_name: &'static str,
12    pub dispatch: DispatchCastFn,
13    pub is_supported_dt: fn(DatumType) -> bool,
14}
15
16impl GpuCast {
17    pub fn new(
18        to: DatumType,
19        backend_name: &'static str,
20        dispatch: DispatchCastFn,
21        is_supported_dt: fn(DatumType) -> bool,
22    ) -> Option<Self> {
23        is_supported_dt(to).then_some(Self { to, backend_name, dispatch, is_supported_dt })
24    }
25
26    pub fn is_supported_dt(&self, dt: DatumType) -> bool {
27        (self.is_supported_dt)(dt)
28    }
29}
30
31impl std::fmt::Debug for GpuCast {
32    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
33        write!(f, "{}Cast({:?})", self.backend_name, self.to)
34    }
35}
36
37impl PartialEq for GpuCast {
38    fn eq(&self, other: &Self) -> bool {
39        self.backend_name == other.backend_name && self.to == other.to
40    }
41}
42
43impl Eq for GpuCast {}
44
45impl std::hash::Hash for GpuCast {
46    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
47        self.backend_name.hash(state);
48        self.to.hash(state);
49    }
50}
51
52impl Op for GpuCast {
53    fn name(&self) -> StaticName {
54        format!("{}Cast", self.backend_name).into()
55    }
56
57    op_as_typed_op!();
58}
59
60impl EvalOp for GpuCast {
61    fn is_stateless(&self) -> bool {
62        true
63    }
64
65    fn eval_with_session(
66        &self,
67        node_id: usize,
68        session: &TurnState,
69        inputs: TVec<TValue>,
70    ) -> TractResult<TVec<TValue>> {
71        let input_value = args_1!(inputs);
72        let input = input_value.to_device_tensor()?;
73        if input.datum_type() == self.to {
74            Ok(tvec!(input_value))
75        } else {
76            let output = crate::session_handler::make_tensor_for_node(
77                session,
78                node_id,
79                self.to,
80                input.shape(),
81            )?;
82            (self.dispatch)(input, &output)?;
83            Ok(tvec![output.into_tensor().into_tvalue()])
84        }
85    }
86}
87
88impl TypedOp for GpuCast {
89    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
90        crate::utils::facts_to_device_facts(inputs, |facts| {
91            Ok(tvec!(self.to.fact(facts[0].shape.clone())))
92        })
93        .with_context(|| format!("Error while computing facts for {:?}", self.name()))
94    }
95
96    as_op!();
97}