// tract-core 0.23.0-dev.4
//
// Tiny, no-nonsense, self-contained, TensorFlow and ONNX inference
// Documentation
use crate::internal::*;

/// Triangular-part operator over the two innermost axes of a tensor.
///
/// Elements outside the selected triangle (offset by a diagonal `k` supplied at evaluation
/// time) are replaced by the element type's default value.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Trilu {
    /// `true` keeps the upper triangle, `false` keeps the lower triangle.
    pub upper: bool,
}

impl Op for Trilu {
    fn name(&self) -> StaticName {
        "Trilu".into()
    }

    op_as_typed_op!();
}

impl EvalOp for Trilu {
    fn is_stateless(&self) -> bool {
        true
    }

    fn eval(&self, inputs: TVec<TValue>) -> TractResult<TVec<TValue>> {
        let (input, k) = args_2!(inputs);
        let mut input = input.into_tensor();
        let k = *k.try_as_plain()?.to_scalar::<i64>()?;
        fn eval_t<T: Datum>(tensor: &mut Tensor, upper: bool, k: i64) -> TractResult<()> {
            let mut tensor_plain = tensor.try_as_plain_mut()?;
            let mut view = tensor_plain.to_array_view_mut::<T>()?;
            for coords in tract_ndarray::indices(view.shape()) {
                let row = coords[view.ndim() - 2] as i64;
                let col = coords[view.ndim() - 1] as i64;
                if upper {
                    if col < row + k {
                        view[coords] = T::default();
                    }
                } else if col > row + k {
                    view[coords] = T::default();
                }
            }
            Ok(())
        }
        dispatch_datum!(eval_t(input.datum_type())(&mut input, self.upper, k))?;
        Ok(tvec!(input.into_tvalue()))
    }
}

impl TypedOp for Trilu {
    /// The single output fact is the data input's fact (`inputs[0]`) with any attached constant
    /// value stripped via `without_value`; the diagonal input `inputs[1]` is not consulted.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        let data_fact: &TypedFact = inputs[0];
        let output_fact = data_fact.without_value();
        Ok(tvec![output_fact])
    }

    as_op!();
}