use crate::TractResult;
use std::fmt;
use std::iter::FromIterator;
use std::ops::{Add, Div, Mul, Neg, Sub};
use num_traits::Zero;
use crate::internal::*;
/// A piece of partial knowledge about some value, as manipulated by the analyser.
///
/// A fact may be fully known (it concretizes to a value) or only partially known.
/// Two facts about the same thing can be merged with [`Fact::unify`].
pub trait Fact: fmt::Debug + Clone + PartialEq + Default {
    /// The type of the fully-known value this fact describes.
    type Concrete: fmt::Debug;

    /// Returns the concrete value if the fact is fully known, `None` otherwise.
    fn concretize(&self) -> Option<Self::Concrete>;

    /// Whether [`Fact::concretize`] would return a value.
    fn is_concrete(&self) -> bool {
        match self.concretize() {
            Some(_) => true,
            None => false,
        }
    }

    /// Merges two facts about the same thing into one carrying the knowledge of
    /// both, or returns an error when they contradict each other.
    fn unify(&self, other: &Self) -> TractResult<Self>;
}
/// Partial knowledge about a tensor: its datum type, its shape and its value,
/// each of which may independently be known or unknown.
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[derive(Clone, PartialEq, Default)]
pub struct TensorFact {
    /// What is known about the element type.
    pub datum_type: TypeFact,
    /// What is known about the dimensions.
    pub shape: ShapeFact,
    /// What is known about the actual content.
    pub value: ValueFact,
}
impl TensorFact {
    /// Builds a fact with nothing known about the tensor (same as `default()`).
    pub fn new() -> TensorFact {
        TensorFact::default()
    }

    /// Builds a fact that matches any tensor (same as `default()`).
    pub fn any() -> TensorFact {
        TensorFact::default()
    }

    /// Builds a fact that only pins the datum type.
    pub fn dt(dt: DatumType) -> TensorFact {
        TensorFact::default().with_datum_type(dt)
    }

    /// Builds a fact that pins both the datum type and the shape.
    pub fn dt_shape<S: Into<ShapeFact>>(dt: DatumType, shape: S) -> TensorFact {
        TensorFact::dt(dt).with_shape(shape)
    }

    /// Builds a fact that only pins the shape.
    pub fn shape<S: Into<ShapeFact>>(shape: S) -> TensorFact {
        TensorFact::default().with_shape(shape)
    }

    /// Returns `self` with its datum type replaced by `dt`.
    pub fn with_datum_type(self, dt: DatumType) -> TensorFact {
        TensorFact { datum_type: dt.into(), ..self }
    }

    /// Returns `self` with its shape replaced by `shape`.
    pub fn with_shape<S: Into<ShapeFact>>(self, shape: S) -> TensorFact {
        TensorFact { shape: shape.into(), ..self }
    }

    /// Returns `self` with a shape built from `shape`, where each `None` entry
    /// becomes the streaming dimension `TDim::s()` and each `Some(n)` a fixed
    /// dimension `n`.
    pub fn with_streaming_shape<S: IntoIterator<Item = Option<usize>>>(
        self,
        shape: S,
    ) -> TensorFact {
        use crate::dim::ToDim;
        let shape: ShapeFact = shape
            .into_iter()
            .map(|d| d.map(|d| (d as isize).to_dim()).unwrap_or(TDim::s()))
            .collect();
        self.with_shape(shape)
    }

    /// Streaming information of the shape; delegates to `ShapeFact::stream_info`.
    pub fn stream_info(&self) -> TractResult<Option<StreamInfo>> {
        self.shape.stream_info()
    }

    /// Renders the datum type and shape for diagnostics.
    ///
    /// An unknown datum type is shown as `___`. A closed rank-0 shape prints the
    /// datum type alone; otherwise the output is `<shape debug>x<datum type>`.
    pub fn format_dt_shape(&self) -> String {
        // Computed once and shared by both branches; the fallback is only
        // allocated when the datum type is actually unknown.
        let dt = self
            .datum_type
            .concretize()
            .map(|dt| format!("{:?}", dt))
            .unwrap_or_else(|| "___".to_string());
        if !self.shape.open && self.shape.dims.is_empty() {
            dt
        } else {
            format!("{:?}x{}", self.shape, dt)
        }
    }
}
/// A tensor fact is concrete exactly when its value is known; unification
/// merges the datum type, shape and value facts component by component.
impl Fact for TensorFact {
    type Concrete = SharedTensor;

    fn concretize(&self) -> Option<Self::Concrete> {
        self.value.concretize()
    }

    fn unify(&self, other: &Self) -> TractResult<Self> {
        // Each component is unified in turn; the first conflict aborts.
        let datum_type = self.datum_type.unify(&other.datum_type)?;
        let shape = self.shape.unify(&other.shape)?;
        let value = self.value.unify(&other.value)?;
        let merged = TensorFact { datum_type, shape, value };
        trace!("Unifying {:?} with {:?} into {:?}.", self, other, merged);
        Ok(merged)
    }
}
/// Builds a fully-known fact from a concrete tensor: datum type, shape and
/// value are all taken from it.
impl<V: Into<SharedTensor>> From<V> for TensorFact {
    fn from(v: V) -> TensorFact {
        let tensor: SharedTensor = v.into();
        let datum_type = GenericFact::Only(tensor.datum_type());
        let shape = ShapeFact::from(tensor.shape());
        TensorFact { datum_type, shape, value: GenericFact::Only(tensor) }
    }
}
/// Prints the concrete tensor when the value is known, otherwise the
/// datum-type/shape summary from `format_dt_shape`.
impl fmt::Debug for TensorFact {
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        match self.value.concretize() {
            Some(tensor) => write!(formatter, "{:?}", tensor),
            None => formatter.write_str(&self.format_dt_shape()),
        }
    }
}
/// Partial knowledge about a value of type `T`: either exactly known or unknown.
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[derive(Clone, PartialEq)]
pub enum GenericFact<T>
where
    T: fmt::Debug + Clone + PartialEq,
{
    /// The value is known to be exactly this.
    Only(T),
    /// Nothing is known about the value.
    Any,
}

// A fact is trivially copyable whenever its payload is.
impl<T> Copy for GenericFact<T> where T: Copy + Clone + fmt::Debug + PartialEq {}
/// `Only(v)` concretizes to `v`; `Any` does not concretize. Unification lets
/// `Any` yield to the other side and requires two known values to be equal.
impl<T: fmt::Debug + Clone + PartialEq> Fact for GenericFact<T> {
    type Concrete = T;

    fn concretize(&self) -> Option<T> {
        if let GenericFact::Only(known) = self {
            Some(known.clone())
        } else {
            None
        }
    }

    fn unify(&self, other: &Self) -> TractResult<Self> {
        match (self, other) {
            (_, GenericFact::Any) => Ok(self.clone()),
            (GenericFact::Any, _) => Ok(other.clone()),
            _ if self == other => Ok(self.clone()),
            _ => bail!("Impossible to unify {:?} with {:?}.", self, other),
        }
    }
}
impl<T: fmt::Debug + Clone + PartialEq> Default for GenericFact<T> {
fn default() -> Self {
GenericFact::Any
}
}
impl<T: fmt::Debug + Clone + PartialEq> From<T> for GenericFact<T> {
fn from(t: T) -> Self {
GenericFact::Only(t)
}
}
/// Prints `?` for an unknown fact and the value's own debug form otherwise.
impl<T> fmt::Debug for GenericFact<T>
where
    T: fmt::Debug + Clone + PartialEq,
{
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        if let GenericFact::Only(known) = self {
            write!(formatter, "{:?}", known)
        } else {
            formatter.write_str("?")
        }
    }
}
/// Partial knowledge about a tensor's datum type.
pub type TypeFact = GenericFact<DatumType>;
/// Partial knowledge about a tensor shape.
#[cfg_attr(feature = "serialize", derive(Serialize))]
#[derive(Clone, PartialEq)]
pub struct ShapeFact {
    /// When true, the rank is unknown: more dimensions may follow `dims`.
    open: bool,
    /// Known/unknown dimensions as integers; a streaming dimension is stored as
    /// the sentinel `-1` and its real value lives in `stream`.
    dims: TVec<GenericFact<i32>>,
    /// Axis and length of the first streaming dimension, if one was found.
    stream: Option<StreamInfo>,
}
impl ShapeFact {
    /// Builds an open shape fact (more dimensions may follow) from `dims`.
    ///
    /// The first dimension that is known to be a streaming dimension is recorded
    /// as the shape's stream info; every streaming dimension is stored as the
    /// sentinel `-1`.
    ///
    /// # Panics
    /// If a known, non-streaming dimension fails `to_integer`.
    pub fn open(dims: TVec<DimFact>) -> ShapeFact {
        let stream = dims.iter().enumerate().find_map(|(axis, d)| match d {
            GenericFact::Only(len) if len.is_stream() => Some(StreamInfo { axis, len: len.clone() }),
            _ => None,
        });
        // Without a streaming dim the first arm never matches, so a single
        // mapping covers both the streaming and the non-streaming case.
        let dims = dims
            .iter()
            .map(|d| match d {
                GenericFact::Only(d) if d.is_stream() => GenericFact::Only(-1),
                GenericFact::Only(d) => GenericFact::Only(d.to_integer().unwrap()),
                GenericFact::Any => GenericFact::Any,
            })
            .collect();
        ShapeFact { open: true, dims, stream }
    }

    /// Whether further dimensions may follow the known ones.
    pub fn is_open(&self) -> bool {
        self.open
    }

    /// Builds a closed shape fact (rank exactly `dims.len()`) from `dims`.
    pub fn closed(dims: TVec<DimFact>) -> ShapeFact {
        ShapeFact { open: false, ..Self::open(dims) }
    }

    /// The rank: known for a closed shape, `Any` for an open one.
    pub fn rank(&self) -> IntFact {
        if self.open {
            GenericFact::Any
        } else {
            GenericFact::Only(self.dims.len() as i32)
        }
    }

    /// Iterates over the dimension facts, with the `-1` streaming sentinel
    /// replaced by the stored stream length.
    ///
    /// # Panics
    /// If a `-1` dimension is stored while no stream info is present.
    pub fn dims(&self) -> impl Iterator<Item = DimFact> {
        let stream = self.stream;
        self.dims.clone().into_iter().map(move |d| match d {
            GenericFact::Only(-1) => GenericFact::Only(
                stream.expect("ShapeFact holds a streaming (-1) dim but no stream info").len,
            ),
            GenericFact::Only(d) => GenericFact::Only(d.to_dim()),
            GenericFact::Any => GenericFact::Any,
        })
    }

    /// Locates the streaming dimension of a fully-known shape.
    ///
    /// # Errors
    /// If the shape is not concrete, or has more than one streaming dimension.
    pub fn stream_info(&self) -> TractResult<Option<StreamInfo>> {
        let concrete = self
            .concretize()
            .ok_or("Shape has unknown dims, can not find streaming dim for sure.")?;
        let count = concrete.iter().filter(|&d| d.is_stream()).count();
        if count > 1 {
            bail!("Shape has more than one streaming dim. This is terribly wrong.")
        }
        Ok(concrete
            .into_iter()
            .enumerate()
            .find(|(_, d)| d.is_stream())
            .map(|(axis, len)| StreamInfo { axis, len }))
    }

    /// The dimensions as plain integers when the shape is fully known and not
    /// streaming; `Ok(None)` otherwise.
    ///
    /// # Errors
    /// Propagates errors from [`ShapeFact::stream_info`].
    pub fn as_concrete_finite(&self) -> TractResult<Option<TVec<usize>>> {
        if !self.is_concrete() || self.stream_info()?.is_some() {
            return Ok(None);
        }
        Ok(Some(self.dims.iter().map(|i| i.concretize().unwrap() as usize).collect()))
    }
}
impl Fact for ShapeFact {
type Concrete = TVec<TDim>;
fn concretize(self: &ShapeFact) -> Option<TVec<TDim>> {
if self.open {
return None;
}
let dims: TVec<_> = self.dims().filter_map(|d| d.concretize()).collect();
if dims.len() < self.dims.len() {
None
} else {
Some(dims)
}
}
fn unify(&self, other: &Self) -> TractResult<Self> {
let (x, y) = (self, other);
use itertools::EitherOrBoth::{Both, Left, Right};
use itertools::Itertools;
let xi = x.dims();
let yi = y.dims();
let dimensions: TVec<_> = xi
.zip_longest(yi)
.map(|r| match r {
Both(a, b) => a.unify(&b),
Left(d) if y.open => Ok(d),
Right(d) if x.open => Ok(d),
Left(_) | Right(_) => bail!(
"Impossible to unify closed shapes of different rank (found {:?} and {:?}).",
x,
y
),
})
.collect::<TractResult<_>>()
.map_err(|e| format!("Unifying shapes {:?} and {:?}, {}", x, y, e))?;
if x.open && y.open {
Ok(ShapeFact::open(dimensions))
} else {
Ok(ShapeFact::closed(dimensions))
}
}
}
impl Default for ShapeFact {
fn default() -> ShapeFact {
ShapeFact::open(tvec![])
}
}
impl FromIterator<TDim> for ShapeFact {
fn from_iter<I: IntoIterator<Item = TDim>>(iter: I) -> ShapeFact {
ShapeFact::closed(iter.into_iter().map(|d| GenericFact::Only(d)).collect())
}
}
impl FromIterator<usize> for ShapeFact {
fn from_iter<I: IntoIterator<Item = usize>>(iter: I) -> ShapeFact {
ShapeFact::closed(iter.into_iter().map(|d| GenericFact::Only(d.to_dim())).collect())
}
}
impl<D: ToDim, I: IntoIterator<Item = D>> From<I> for ShapeFact {
fn from(it: I) -> ShapeFact {
ShapeFact::closed(it.into_iter().map(|d| GenericFact::Only(d.to_dim())).collect())
}
}
/// Prints dimensions separated by `x`. The streaming axis shows its stream
/// length, and an open shape ends with `..` (or `x..` after some dims).
impl fmt::Debug for ShapeFact {
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        for (axis, dim) in self.dims.iter().enumerate() {
            if axis > 0 {
                write!(formatter, "x")?;
            }
            match self.stream {
                Some(info) if info.axis == axis => write!(formatter, "{:?}", info.len)?,
                _ => write!(formatter, "{:?}", dim)?,
            }
        }
        if self.open {
            let tail = if self.dims.is_empty() { ".." } else { "x.." };
            write!(formatter, "{}", tail)?;
        }
        Ok(())
    }
}
/// Partial knowledge about a single tensor dimension (possibly a streaming one).
pub type DimFact = GenericFact<TDim>;
/// Partial knowledge about a tensor's full value.
pub type ValueFact = GenericFact<SharedTensor>;
/// Partial knowledge about an integer, e.g. a shape's rank (see `ShapeFact::rank`).
pub type IntFact = GenericFact<i32>;
/// Zero is the known value `T::zero()`; an unknown fact is never zero.
impl<T> Zero for GenericFact<T>
where
    T: Add<T, Output = T> + Zero + PartialEq + Copy + Clone + ::std::fmt::Debug,
{
    fn zero() -> GenericFact<T> {
        GenericFact::Only(T::zero())
    }

    fn is_zero(&self) -> bool {
        if let GenericFact::Only(value) = self {
            value.is_zero()
        } else {
            false
        }
    }
}
/// Negates a known value; an unknown fact stays unknown.
impl<T> Neg for GenericFact<T>
where
    T: Neg<Output = T> + PartialEq + Copy + Clone + ::std::fmt::Debug,
{
    type Output = GenericFact<T>;

    fn neg(self) -> GenericFact<T> {
        match self {
            GenericFact::Any => GenericFact::Any,
            GenericFact::Only(value) => GenericFact::Only(-value),
        }
    }
}
/// Sum of two facts: known only when both operands are known.
impl<T, I> Add<I> for GenericFact<T>
where
    T: Add<T, Output = T> + PartialEq + Copy + Clone + ::std::fmt::Debug,
    I: Into<GenericFact<T>>,
{
    type Output = GenericFact<T>;

    fn add(self, rhs: I) -> Self::Output {
        let rhs: GenericFact<T> = rhs.into();
        self.concretize()
            .and_then(|lhs| rhs.concretize().map(|other| lhs + other))
            .map(GenericFact::Only)
            .unwrap_or(GenericFact::Any)
    }
}
/// Difference of two facts: known only when both operands are known.
impl<T> Sub<GenericFact<T>> for GenericFact<T>
where
    T: Sub<T, Output = T> + PartialEq + Copy + Clone + ::std::fmt::Debug,
{
    type Output = GenericFact<T>;

    fn sub(self, rhs: GenericFact<T>) -> Self::Output {
        if let (Some(lhs), Some(other)) = (self.concretize(), rhs.concretize()) {
            GenericFact::Only(lhs - other)
        } else {
            GenericFact::Any
        }
    }
}
/// Scales a known value by `rhs`; an unknown fact stays unknown.
impl<T, R> Mul<R> for GenericFact<T>
where
    T: Mul<R, Output = T> + PartialEq + Copy + Clone + ::std::fmt::Debug,
{
    type Output = GenericFact<T>;

    fn mul(self, rhs: R) -> Self::Output {
        match self.concretize() {
            Some(lhs) => GenericFact::Only(lhs * rhs),
            None => GenericFact::Any,
        }
    }
}
/// Divides a known value by `rhs` (with `T`'s own division semantics); an
/// unknown fact stays unknown.
impl<T, R> Div<R> for GenericFact<T>
where
    T: Div<R, Output = T> + PartialEq + Copy + Clone + ::std::fmt::Debug,
{
    type Output = GenericFact<T>;

    fn div(self, rhs: R) -> Self::Output {
        match self.concretize() {
            Some(lhs) => GenericFact::Only(lhs / rhs),
            None => GenericFact::Any,
        }
    }
}