use crate::tfpb::node_def::NodeDef;
use tract_core::internal::*;
#[derive(Debug, Clone, new)]
pub struct Transpose {
t: DatumType,
t_perm: DatumType,
}
pub fn transpose(pb: &NodeDef) -> TractResult<Box<Op>> {
let t = pb.get_attr_datum_type("T")?;
let t_perm = pb.get_attr_datum_type("Tperm")?;
Ok(Box::new(Transpose::new(t, t_perm)))
}
impl Transpose {
fn compute_shape<D: DimLike>(shape: &[D], perm: &[i32]) -> TVec<D> {
let mut new_shape = tvec![D::zero(); shape.len()];
for (ix, &d) in perm.iter().enumerate() {
new_shape[ix] = shape[d as usize];
}
new_shape
}
fn eval_t<T: Datum>(
&self,
input: Arc<Tensor>,
perm: &[usize],
) -> TractResult<TVec<Arc<Tensor>>> {
Ok(tvec![input.into_tensor().into_array::<T>()?.permuted_axes(perm).into_arc_tensor()])
}
}
impl Op for Transpose {
fn name(&self) -> Cow<str> {
"tf.Transpose".into()
}
fn declutter(
&self,
model: &TypedModel,
node: &TypedNode,
) -> TractResult<Option<TypedModelPatch>> {
let inputs = model.node_input_facts(node.id)?;
if let Some(ref perm) = inputs[1].konst {
let perm: Vec<usize> =
perm.cast_to::<i32>()?.as_slice::<i32>()?.iter().map(|&x| x as usize).collect();
let op = ::tract_core::ops::array::PermuteAxes::new(Some(perm));
return Ok(Some(TypedModelPatch::single_unary_op(&model, &node, op)?));
}
Ok(None)
}
}
impl StatelessOp for Transpose {
fn eval(&self, mut inputs: TVec<Arc<Tensor>>) -> TractResult<TVec<Arc<Tensor>>> {
let (data, perm) = args_2!(inputs);
let perm: TVec<usize> =
perm.cast_to::<i32>()?.as_slice::<i32>()?.iter().map(|&x| x as usize).collect();
dispatch_datum!(Self::eval_t(data.datum_type())(self, data, &*perm))
}
}
impl InferenceRulesOp for Transpose {
fn rules<'r, 'p: 'r, 's: 'r>(
&'s self,
s: &mut Solver<'r>,
inputs: &'p [TensorProxy],
outputs: &'p [TensorProxy],
) -> InferenceResult {
check_output_arity(&inputs, 2)?;
check_output_arity(&outputs, 1)?;
s.equals(&inputs[0].datum_type, self.t)?;
s.equals(&inputs[1].datum_type, self.t_perm)?;
s.equals(&outputs[0].datum_type, &inputs[0].datum_type)?;
s.equals(&outputs[0].rank, &inputs[0].rank)?;
s.equals(&inputs[1].rank, 1)?;
s.equals(&inputs[1].shape[0], inputs[0].rank.bex().to_dim())?;
s.given_2(&inputs[0].shape, &inputs[1].value, move |s, shape, perm| {
let perm = perm.cast_to::<i32>()?;
let output_shape = Self::compute_shape(&shape, perm.as_slice::<i32>()?);
s.equals(&outputs[0].shape, output_shape)
})
}
}