tract-core 0.19.2

Tiny, no-nonsense, self-contained TensorFlow and ONNX inference
Documentation
//! Partial and complete tensor type representations.
use crate::internal::*;
use downcast_rs::Downcast;
use std::fmt;

/// Shape of a tensor, stored as one `TDim` per axis.
///
/// Alongside the `TDim` dimensions, a `usize` version of the shape is cached
/// in `concrete` whenever every dimension converts to `usize`.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct ShapeFact {
    /// Dimensions of the shape, one entry per axis.
    dims: TVec<TDim>,
    /// `usize` copy of `dims`; `None` when at least one dimension does not
    /// convert to `usize`.
    concrete: Option<TVec<usize>>,
}

impl ShapeFact {
    #[inline]
    pub fn rank(&self) -> usize {
        self.dims.len()
    }

    fn compute_concrete(&mut self) {
        assert!(self.dims.iter().all(|d| d.to_isize().map(|d| d >= 0).unwrap_or(true)));
        self.concrete =
            self.dims.iter().map(|d| d.to_usize()).collect::<TractResult<TVec<_>>>().ok()
    }

    /// Shape of the tensor, unless it has symbolic dimensions.
    #[inline]
    pub fn as_concrete(&self) -> Option<&[usize]> {
        self.concrete.as_deref()
    }

    /// Do we have a symbol-less value ?
    #[inline]
    pub fn is_concrete(&self) -> bool {
        self.concrete.is_some()
    }

    /// Iterator over dimension of the shape.
    pub fn iter(&self) -> impl Iterator<Item = TDim> + '_ {
        self.dims.iter().cloned()
    }

    /// Convert the shape to an array of extended dimensions.
    #[inline]
    pub fn to_tvec(&self) -> TVec<TDim> {
        self.dims.clone()
    }

    /// Compute the volume of the tensor.
    #[inline]
    pub fn volume(&self) -> TDim {
        self.dims.iter().product()
    }

    #[inline]
    pub fn eval(&self, values: &SymbolValues) -> TractResult<Cow<ShapeFact>> {
        if self.is_concrete() {
            Ok(Cow::Borrowed(self))
        } else {
            Ok(Cow::Owned(self.iter().map(|d| d.eval(values)).collect::<ShapeFact>()))
        }
    }

    #[inline]
    pub fn eval_to_usize(&self, values: &SymbolValues) -> TractResult<Cow<TVec<usize>>> {
        if let Some(c) = &self.concrete {
            Ok(Cow::Borrowed(c))
        } else {
            Ok(Cow::Owned(
                self.iter().map(|d| d.eval(values).to_usize()).collect::<TractResult<TVec<_>>>()?,
            ))
        }
    }

    #[inline]
    pub fn eval_to_isize(&self, values: &SymbolValues) -> TractResult<Cow<TVec<isize>>> {
        if let Some(c) = &self.concrete {
            Ok(unsafe { std::mem::transmute(Cow::Borrowed(c)) })
        } else {
            Ok(Cow::Owned(
                self.iter().map(|d| d.eval(values).to_isize()).collect::<TractResult<TVec<_>>>()?,
            ))
        }
    }

    pub fn from_dims<D: ToDim, T: IntoIterator<Item = D>>(it: T) -> ShapeFact {
        let mut dims =
            ShapeFact { dims: it.into_iter().map(|d| d.to_dim()).collect(), concrete: None };
        dims.compute_concrete();
        dims
    }

    pub fn set(&mut self, ix: usize, dim: TDim) {
        self.dims[ix] = dim;
        self.compute_concrete();
    }

    pub fn insert_axis(&mut self, axis: usize) -> TractResult<()> {
        self.dims.insert(axis, 1.into());
        if let Some(concrete) = &mut self.concrete {
            concrete.insert(axis, 1);
        }
        Ok(())
    }

    pub fn remove_axis(&mut self, axis: usize) -> TractResult<()> {
        self.dims.remove(axis);
        if let Some(concrete) = &mut self.concrete {
            concrete.remove(axis);
        }
        Ok(())
    }

    pub fn compatible_with(&self, _other: &ShapeFact) -> bool {
        if self.rank() == _other.rank() {
            self.dims
                .iter()
                .zip(_other.dims.iter())
                .all(|(dim, other_dim)| dim.compatible_with(other_dim))
        } else {
            false
        }
    }

    pub fn scalar() -> ShapeFact {
        let void: &[usize] = &[];
        Self::from(void)
    }
}

/// A `ShapeFact` dereferences to the slice of its dimensions.
impl std::ops::Deref for ShapeFact {
    type Target = [TDim];

    fn deref(&self) -> &Self::Target {
        &self.dims[..]
    }
}

/// Any iterable of dimension-like items converts to a shape through
/// `ShapeFact::from_dims`.
impl<D: ToDim, T: IntoIterator<Item = D>> From<T> for ShapeFact {
    fn from(it: T) -> ShapeFact {
        Self::from_dims(it)
    }
}

/// Type information about a tensor: shape, and element type, in various state
/// of determination.
pub trait Fact: std::fmt::Debug + Downcast + dyn_clone::DynClone + Send + Sync + 'static {
    fn to_typed_fact(&self) -> TractResult<Cow<TypedFact>>;

    fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
        self.to_typed_fact()?.matches(t, symbols)
    }

    fn same_as(&self, _other: &dyn Fact) -> bool;

    /// Ensure that self is same type as another fact or a subtype
    fn compatible_with(&self, _other: &dyn Fact) -> bool;

    fn datum_type(&self) -> Option<DatumType>;
}

impl_downcast!(Fact);
dyn_clone::clone_trait_object!(Fact);

/// Collect dimension-like items into a shape.
impl<D: ToDim> std::iter::FromIterator<D> for ShapeFact {
    fn from_iter<T: IntoIterator<Item = D>>(iter: T) -> Self {
        let dims = iter.into_iter().map(|dim| dim.to_dim());
        Self::from_dims(dims)
    }
}

/// Debug output is the list of dimensions separated by commas.
impl fmt::Debug for ShapeFact {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        use tract_itertools::Itertools;
        let joined = self.dims.iter().join(",");
        fmt.write_str(&joined)
    }
}

/// Borrow the dimensions of the shape as a slice.
impl AsRef<[TDim]> for ShapeFact {
    fn as_ref(&self) -> &[TDim] {
        self.dims.as_slice()
    }
}

/// Fully determined tensor information for TypedModel.
///
/// Element type and shape are always present; `konst` and `uniform` carry
/// optional knowledge about the values. `TypedFact::consistent` checks that
/// these fields agree with each other.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct TypedFact {
    /// tensor element type
    pub datum_type: DatumType,
    /// tensor shape
    pub shape: ShapeFact,
    /// optional constant value: the tensor itself when it is known
    pub konst: Option<Arc<Tensor>>,
    /// optional uniform value; must have the same element type as the fact
    pub uniform: Option<Arc<Tensor>>,
}

// NOTE(review): presumably wires `TypedFact` into the crate's dynamic hashing
// support — see the definition of `impl_dyn_hash!` to confirm.
impl_dyn_hash!(TypedFact);

impl TypedFact {
    /// Fact for a scalar tensor whose elements are of type `T`.
    pub fn scalar<T: Datum>() -> TypedFact {
        Self::dt_scalar(T::datum_type())
    }

    /// Fact for a tensor of `T` elements with the given shape.
    pub fn shape<T: Datum, S: Into<ShapeFact>>(shape: S) -> TypedFact {
        Self::dt_shape(T::datum_type(), shape)
    }

    /// Fact for a scalar tensor of the given element type, without any value.
    pub fn dt_scalar(datum_type: DatumType) -> TypedFact {
        Self::dt_shape(datum_type, ShapeFact::scalar())
    }

    /// Fact for a tensor of the given element type and shape, without any value.
    pub fn dt_shape<S: Into<ShapeFact>>(datum_type: DatumType, shape: S) -> TypedFact {
        let shape: ShapeFact = shape.into();
        TypedFact { datum_type, shape, konst: None, uniform: None }
    }

    /// Number of axes of the tensor.
    ///
    /// With debug assertions enabled, panics if the fact is not consistent.
    pub fn rank(&self) -> usize {
        let check_invariants = cfg!(debug_assertions);
        if check_invariants {
            self.consistent().unwrap();
        }
        self.shape.rank()
    }

    /// Render "shape,datum_type" (or the datum type alone for a scalar)
    /// without running any consistency check.
    fn format_dt_shape_nocheck(&self) -> String {
        match self.shape.rank() {
            0 => format!("{:?}", self.datum_type),
            _ => format!("{:?},{:?}", self.shape, self.datum_type),
        }
    }

    /// Render "shape,datum_type" (or the datum type alone for a scalar).
    ///
    /// With debug assertions enabled, panics if the fact is not consistent.
    pub fn format_dt_shape(&self) -> String {
        let check_invariants = cfg!(debug_assertions);
        if check_invariants {
            self.consistent().unwrap()
        }
        self.format_dt_shape_nocheck()
    }

    /// Verify that the optional values agree with the rest of the fact.
    ///
    /// Three checks are made: the constant matches type and shape, the uniform
    /// value has the fact's element type, and when both are present the
    /// constant is uniform and equal to the uniform value.
    pub fn consistent(&self) -> TractResult<()> {
        // the constant must fit the declared type and shape
        if let Some(k) = &self.konst {
            let fits = self.matches(k.as_ref(), None)?;
            if !fits {
                bail!("fact says {}, constant is {:?}", self.format_dt_shape_nocheck(), k);
            }
        }

        // the uniform value must be of the declared element type
        match &self.uniform {
            Some(u) if self.datum_type != u.datum_type() => {
                bail!("fact as uniform value {:?}, but is of type {:?}", u, self.datum_type)
            }
            _ => {}
        }

        // constant and uniform value must tell the same story
        if let (Some(u), Some(k)) = (self.uniform.as_deref(), self.konst.as_deref()) {
            match k.as_uniform() {
                Some(k_uniform) if &k_uniform != u => {
                    bail!("Uniform value and uniform constant mismatch: {:?}, {:?}", u, k_uniform)
                }
                Some(_) => {}
                None => {
                    bail!("Fact said to be uniform ({:?}) and equal to {:?} which is not.", u, k)
                }
            }
        }
        Ok(())
    }

    /// Copy of this fact keeping element type and shape but dropping the
    /// constant and uniform values.
    pub fn without_value(&self) -> Self {
        TypedFact {
            datum_type: self.datum_type,
            shape: self.shape.clone(),
            konst: None,
            uniform: None,
        }
    }
}

impl Fact for TypedFact {
    /// Borrow `self` as a typed fact.
    ///
    /// With debug assertions enabled, an inconsistent fact is reported as an
    /// error instead.
    fn to_typed_fact(&self) -> TractResult<Cow<TypedFact>> {
        if cfg!(debug_assertions) {
            self.consistent()?
        }
        Ok(Cow::Borrowed(self))
    }

    /// Check that `t` has the element type and rank of this fact, and that
    /// each of its dimensions equals the matching dimension of the fact.
    ///
    /// Dimensions of the fact are evaluated against `symbols` (or an empty set
    /// of values when `None`); a dimension that still does not resolve to a
    /// `usize` is not checked.
    fn matches(&self, t: &Tensor, symbols: Option<&SymbolValues>) -> TractResult<bool> {
        if self.datum_type != t.datum_type() || self.shape.len() != t.rank() {
            return Ok(false);
        }
        // Build the fallback symbol table once rather than once per axis.
        let no_symbols = SymbolValues::default();
        let symbols = symbols.unwrap_or(&no_symbols);
        // Ranks are equal at this point, so zipping visits every axis.
        let dims_match = self.shape.dims.iter().zip(t.shape().iter()).all(|(dim, &actual)| {
            match dim.eval(symbols).to_usize() {
                Ok(expected) => expected == actual,
                // unresolved dimension: nothing to compare against
                Err(_) => true,
            }
        });
        Ok(dims_match)
    }

    /// True if `other` is a `TypedFact` equal to `self`.
    ///
    /// With debug assertions enabled, panics if either fact is inconsistent.
    fn same_as(&self, other: &dyn Fact) -> bool {
        if cfg!(debug_assertions) {
            self.consistent().unwrap()
        }
        if let Some(other) = other.downcast_ref::<Self>() {
            if cfg!(debug_assertions) {
                other.consistent().unwrap()
            }
            self == other
        } else {
            false
        }
    }

    /// True if `other` is a `TypedFact` with the same element type and a
    /// compatible shape. Constant and uniform values are not compared.
    ///
    /// With debug assertions enabled, panics if either fact is inconsistent.
    fn compatible_with(&self, other: &dyn Fact) -> bool {
        if cfg!(debug_assertions) {
            self.consistent().unwrap()
        }
        if let Some(other) = other.downcast_ref::<Self>() {
            if cfg!(debug_assertions) {
                other.consistent().unwrap()
            }
            self.datum_type == other.datum_type && self.shape.compatible_with(&other.shape)
        } else {
            false
        }
    }

    /// Element type of the tensor: always known for a typed fact.
    fn datum_type(&self) -> Option<DatumType> {
        Some(self.datum_type)
    }
}

impl From<Tensor> for TypedFact {
    fn from(t: Tensor) -> TypedFact {
        TypedFact::from(t.into_arc_tensor())
    }
}

impl<'t> From<&'t Tensor> for TypedFact {
    fn from(t: &'t Tensor) -> TypedFact {
        TypedFact::from(t.clone())
    }
}

/// Build the fact of a shared tensor: type and shape are read from the tensor,
/// the tensor itself becomes the constant value, and the uniform value is
/// whatever `as_uniform` reports.
impl From<Arc<Tensor>> for TypedFact {
    fn from(t: Arc<Tensor>) -> TypedFact {
        let datum_type = t.datum_type();
        let shape = ShapeFact::from_dims(t.shape().iter().map(TDim::from));
        let uniform = t.as_uniform().map(Arc::new);
        TypedFact { datum_type, shape, uniform, konst: Some(t) }
    }
}

impl<'a> From<&'a TypedFact> for TypedFact {
    fn from(fact: &TypedFact) -> TypedFact {
        fact.clone()
    }
}

/// Debug output: the constant tensor when there is one, otherwise
/// "shape,datum_type" (the datum type alone for a scalar).
impl fmt::Debug for TypedFact {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        match &self.konst {
            Some(k) => write!(fmt, "{k:?}"),
            // Read the rank from the shape directly: `TypedFact::rank` unwraps
            // a consistency check in debug builds, and formatting a fact must
            // never panic (it is typically what gets printed to report an
            // inconsistent fact).
            None if self.shape.rank() > 0 => {
                write!(fmt, "{:?},{:?}", self.shape, self.datum_type)
            }
            None => write!(fmt, "{:?}", self.datum_type),
        }
    }
}

/// Build `TypedFact`s straight from an element type.
pub trait DatumExt {
    /// Fact of a scalar tensor of this element type.
    fn scalar_fact() -> TypedFact;

    /// Fact of a tensor of this element type with the given shape.
    fn fact<S: Into<ShapeFact>>(shape: S) -> TypedFact;
}

/// Every `Datum` type can produce facts for tensors of its elements.
impl<T: Datum> DatumExt for T {
    fn scalar_fact() -> TypedFact {
        // an empty list of dimensions is the shape of a scalar
        let no_dims: &[usize] = &[];
        TypedFact::shape::<Self, _>(no_dims)
    }

    fn fact<S: Into<ShapeFact>>(shape: S) -> TypedFact {
        TypedFact::shape::<Self, S>(shape)
    }
}

/// Build `TypedFact`s from a runtime element type value.
pub trait DatumTypeExt {
    /// Fact of a scalar tensor of this element type.
    fn scalar_fact(&self) -> TypedFact;

    /// Fact of a tensor of this element type with the given shape.
    fn fact<S: Into<ShapeFact>>(&self, shape: S) -> TypedFact;
}

/// `DatumType` values produce facts through `TypedFact::dt_shape`.
impl DatumTypeExt for DatumType {
    fn scalar_fact(&self) -> TypedFact {
        // an empty list of dimensions is the shape of a scalar
        let no_dims: &[usize] = &[];
        TypedFact::dt_shape(*self, no_dims)
    }

    fn fact<S: Into<ShapeFact>>(&self, shape: S) -> TypedFact {
        let datum_type = *self;
        TypedFact::dt_shape(datum_type, shape)
    }
}