1use crate::internal::*;
2
3pub fn cast(to: DatumType) -> Cast {
4    Cast { to }
5}
6
7pub fn wire_cast(
8    prefix: impl AsRef<str>,
9    target: &mut TypedModel,
10    inputs: &[OutletId],
11    operating_datum_type: DatumType,
12) -> TractResult<TVec<OutletId>> {
13    let prefix = prefix.as_ref();
14    let mut wires = tvec!();
15    for mut wire in inputs.iter().copied() {
16        if target.outlet_fact(wire)?.datum_type != operating_datum_type {
17            wire = target.wire_node(
18                target.unique_name(format!("{prefix}.cast")),
19                crate::ops::cast::cast(operating_datum_type),
20                &[wire],
21            )?[0];
22        }
23        wires.push(wire);
24    }
25    Ok(wires)
26}
27
28#[derive(Debug, Clone, new, Hash, PartialEq, Eq)]
29pub struct Cast {
30    pub to: DatumType,
31}
32
33impl Op for Cast {
34    fn name(&self) -> StaticName {
35        "Cast".into()
36    }
37
38    op_as_typed_op!();
39    impl_op_same_as!();
40}
41
42impl EvalOp for Cast {
43    fn is_stateless(&self) -> bool {
44        true
45    }
46
47    fn eval_with_session(
48        &self,
49        _node_id: usize,
50        state: &SessionState,
51        inputs: TVec<TValue>,
52    ) -> TractResult<TVec<TValue>> {
53        let input = args_1!(inputs);
54        if input.datum_type() == self.to {
55            Ok(tvec!(input))
56        } else if input.datum_type() == TDim::datum_type() {
57            let mut tmp = Tensor::zero_dt(i64::datum_type(), input.shape())?;
58            for (dim, i) in
59                tract_itertools::izip!(input.as_slice::<TDim>()?, tmp.as_slice_mut::<i64>()?)
60            {
61                *i = dim.eval(&state.resolved_symbols).to_i64()?
62            }
63            Ok(tvec!(tmp.cast_to_dt(self.to)?.into_owned().into_tvalue()))
64        } else {
65            Ok(tvec!(input.cast_to_dt(self.to)?.into_owned().into_tvalue()))
66        }
67    }
68}
69
70impl TypedOp for Cast {
71    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
72        Ok(tvec!(self.to.fact(inputs[0].shape.clone())))
73    }
74
75    fn declutter(
76        &self,
77        model: &TypedModel,
78        node: &TypedNode,
79    ) -> TractResult<Option<TypedModelPatch>> {
80        if model.outlet_fact(node.inputs[0])?.datum_type == self.to {
81            TypedModelPatch::shunt_one_op(model, node)
82        } else {
83            Ok(None)
84        }
85    }
86
87    fn axes_mapping(
88        &self,
89        inputs: &[&TypedFact],
90        outputs: &[&TypedFact],
91    ) -> TractResult<AxesMapping> {
92        AxesMapping::natural(inputs, outputs)
93    }
94
95    fn change_axes(
96        &self,
97        model: &TypedModel,
98        node: &TypedNode,
99        _io: InOut,
100        change: &AxisOp,
101    ) -> TractResult<Option<AxisChangeConsequence>> {
102        Ok(Some(AxisChangeConsequence::new(model, node, None, change)))
103    }
104
105    fn slice(
106        &self,
107        patch: &mut TypedModelPatch,
108        _model: &TypedModel,
109        node: &TypedNode,
110        _prefix: &str,
111        inputs: &[OutletId],
112        _output_axis: usize,
113        _start: &TDim,
114        _end: &TDim,
115    ) -> TractResult<Option<TVec<OutletId>>> {
116        patch.wire_node(&node.name, &node.op, inputs).map(Some)
117    }
118
119    as_op!();
120}