use std::fmt;
use std::marker::PhantomData;
use std::ops::{Add, Div, Mul, Neg, Sub};
use tract_num_traits::ToPrimitive;
use tract_num_traits::Zero;
use crate::internal::*;
use self::super::super::factoid::*;
use self::super::path::Path;
use self::super::proxies::*;
use self::super::solver::Context;
pub trait Output: fmt::Debug + Clone + PartialEq {
fn wrap(self) -> Wrapped {
Self::into_wrapped(self)
}
fn into_wrapped(source: Self) -> Wrapped;
fn from_wrapped(wrapped: Wrapped) -> TractResult<Self>;
}
macro_rules! impl_output {
($type:ty, $constr:ident, $name:expr) => {
impl Output for $type {
fn into_wrapped(source: Self) -> Wrapped {
Wrapped::$constr(source)
}
fn from_wrapped(wrapped: Wrapped) -> TractResult<$type> {
if let Wrapped::$constr(v) = wrapped {
Ok(v)
} else {
bail!("Tried to get a {} from {:?}.", $name, wrapped);
}
}
}
};
}
impl_output!(IntFactoid, Int, "Int");
impl_output!(TypeFactoid, Type, "DatumType");
impl_output!(ShapeFactoid, Shape, "Shape");
impl_output!(ValueFact, Tensor, "Tensor");
impl_output!(DimFact, Dim, "TDim");
impl Output for usize {
fn into_wrapped(source: usize) -> Wrapped {
IntFactoid::into_wrapped((source as i64).into())
}
fn from_wrapped(wrapped: Wrapped) -> TractResult<usize> {
IntFactoid::from_wrapped(wrapped.clone())?
.concretize()
.and_then(|u| u.to_usize())
.with_context(|| format!("Tried to convert {:?} to a usize.", wrapped))
}
}
impl Output for i64 {
fn into_wrapped(source: i64) -> Wrapped {
IntFactoid::into_wrapped(source.into())
}
fn from_wrapped(wrapped: Wrapped) -> TractResult<i64> {
IntFactoid::from_wrapped(wrapped.clone())?
.concretize()
.with_context(|| format!("Tried to convert {:?} to a i64.", wrapped))
}
}
impl Output for Arc<Tensor> {
fn into_wrapped(source: Arc<Tensor>) -> Wrapped {
ValueFact::into_wrapped(source.into())
}
fn from_wrapped(wrapped: Wrapped) -> TractResult<Arc<Tensor>> {
ValueFact::from_wrapped(wrapped.clone())?
.concretize()
.with_context(|| format_err!("Tried to convert {:?} to a tensor.", wrapped))
}
}
impl Output for TDim {
fn into_wrapped(source: TDim) -> Wrapped {
DimFact::into_wrapped(source.into())
}
fn from_wrapped(wrapped: Wrapped) -> TractResult<TDim> {
DimFact::from_wrapped(wrapped.clone())?
.concretize()
.with_context(|| format_err!("Tried to convert {:?} to a usize.", wrapped))
}
}
#[derive(Debug, Clone)]
pub enum Wrapped {
Int(IntFactoid),
Type(TypeFactoid),
Shape(ShapeFactoid),
Tensor(ValueFact),
Dim(DimFact),
}
pub trait TExp<T>: fmt::Debug {
fn get(&self, context: &Context) -> TractResult<T>;
fn set(&self, context: &mut Context, value: T) -> TractResult<bool>;
fn get_paths(&self) -> Vec<&Path>;
}
pub struct Exp<T>(Box<dyn TExp<T>>);
impl<T: Factoid + Output + Clone + fmt::Debug> TExp<T> for Exp<T> {
fn get(&self, context: &Context) -> TractResult<T> {
self.0.get(context)
}
fn set(&self, context: &mut Context, value: T) -> TractResult<bool> {
self.0.set(context, value)
}
fn get_paths(&self) -> Vec<&Path> {
self.0.get_paths()
}
}
impl<T> fmt::Debug for Exp<T>
where
T: Factoid + Output + Clone + ::std::fmt::Debug,
{
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "{:?}", self.0)
}
}
pub trait IntoExp<T> {
fn bex(self) -> Exp<T>;
}
#[derive(new)]
pub struct SumExp<T>(Vec<Exp<T>>)
where
T: Factoid + Output + Clone + ::std::fmt::Debug + 'static;
impl<T> TExp<T> for SumExp<T>
where
T: Factoid + Output + Zero + Add<T> + Neg<Output = T> + Clone + ::std::fmt::Debug + 'static,
{
fn get(&self, context: &Context) -> TractResult<T> {
self.0.iter().try_fold(T::zero(), |acc, it| Ok(acc + it.0.get(context)?))
}
fn set(&self, context: &mut Context, value: T) -> TractResult<bool> {
let mut sum = T::zero();
let mut misses = vec![];
for item in &self.0 {
let fact = item.get(context)?;
if fact.is_concrete() {
sum = sum + fact;
} else {
misses.push(item);
}
}
if misses.len() > 1 {
Ok(false)
} else if misses.len() == 1 {
misses[0].set(context, value + -sum)?;
Ok(true)
} else if sum == value {
Ok(false)
} else {
bail!("{:?} set to {:?}, already is {:?}", self, value, sum)
}
}
fn get_paths(&self) -> Vec<&Path> {
self.0.iter().flat_map(|e| e.get_paths()).collect()
}
}
impl<T> fmt::Debug for SumExp<T>
where
T: Factoid + Output + Clone + ::std::fmt::Debug,
{
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
for (ix, t) in self.0.iter().enumerate() {
if ix > 0 {
write!(formatter, " + ")?;
}
t.fmt(formatter)?;
}
Ok(())
}
}
pub struct ConstantExp<T>(T)
where
T: Factoid + Output + Clone + ::std::fmt::Debug;
impl<T> TExp<T> for ConstantExp<T>
where
T: Factoid + Output + Clone + ::std::fmt::Debug,
{
fn get(&self, _: &Context) -> TractResult<T> {
Ok(self.0.clone())
}
fn set(&self, _: &mut Context, value: T) -> TractResult<bool> {
self.0.unify(&value)?;
Ok(false)
}
fn get_paths(&self) -> Vec<&Path> {
vec![]
}
}
impl<T> fmt::Debug for ConstantExp<T>
where
T: Factoid + Output + Clone + ::std::fmt::Debug,
{
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "{:?}", self.0)
}
}
pub struct VariableExp<T>(Path, PhantomData<T>)
where
T: Factoid + Output + Clone + ::std::fmt::Debug;
impl<T> TExp<T> for VariableExp<T>
where
T: Factoid + Output + Clone + ::std::fmt::Debug,
{
fn get(&self, context: &Context) -> TractResult<T> {
context.get(&self.0).with_context(|| format!("while getting {:?}", self.0))
}
fn set(&self, context: &mut Context, value: T) -> TractResult<bool> {
let old = self.get(context)?;
let new = old.unify(&value)?;
let diff = old != new;
context.set(&self.0, new).with_context(|| format!("while setting {:?}", self.0))?;
Ok(diff)
}
fn get_paths(&self) -> Vec<&Path> {
vec![&self.0]
}
}
impl<T> fmt::Debug for VariableExp<T>
where
T: Factoid + Output + Clone + ::std::fmt::Debug,
{
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "{:?}", self.0)
}
}
pub struct ScaledExp<T>(i64, Exp<T>)
where
T: Factoid + Output + Zero + Mul<i64, Output = T> + Div<i64, Output = T> + Clone;
impl<T> TExp<T> for ScaledExp<T>
where
T: Factoid + Output + Zero + Mul<i64, Output = T> + Div<i64, Output = T> + Clone,
{
fn get(&self, context: &Context) -> TractResult<T> {
let v: T = self.1.get(context)?;
Ok(v * self.0)
}
fn set(&self, context: &mut Context, value: T) -> TractResult<bool> {
let k = &self.0;
let m = value;
if m.is_zero() && k.is_zero() {
Ok(false)
} else if m.is_zero() {
self.1.set(context, T::zero())
} else {
let div = m.div(*k);
self.1.set(context, div)
}
}
fn get_paths(&self) -> Vec<&Path> {
self.1.get_paths()
}
}
impl<T> fmt::Debug for ScaledExp<T>
where
T: Factoid + Output + Zero + Mul<i64, Output = T> + Div<i64, Output = T> + Clone,
{
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "{}*{{{:?}}}", self.0, self.1)
}
}
pub struct IntoDimExp(Exp<IntFactoid>);
impl TExp<DimFact> for IntoDimExp {
fn get(&self, context: &Context) -> TractResult<DimFact> {
let v: IntFactoid = self.0.get(context)?;
match v {
GenericFactoid::Only(i) => Ok(GenericFactoid::Only(i.to_dim())),
GenericFactoid::Any => Ok(GenericFactoid::Any),
}
}
fn set(&self, context: &mut Context, value: DimFact) -> TractResult<bool> {
if let Some(concrete) = value.concretize() {
if let Ok(int) = concrete.to_i64() {
return self.0.set(context, GenericFactoid::Only(int));
}
}
Ok(false)
}
fn get_paths(&self) -> Vec<&Path> {
self.0.get_paths()
}
}
impl fmt::Debug for IntoDimExp {
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "{{({:?}) as dim}}", self.0)
}
}
impl<T, E: TExp<T> + 'static> IntoExp<T> for E {
fn bex(self) -> Exp<T> {
Exp(Box::new(self))
}
}
impl IntoExp<TypeFactoid> for TypeProxy {
fn bex(self) -> Exp<TypeFactoid> {
VariableExp(self.get_path().clone(), PhantomData).bex()
}
}
impl<'a> IntoExp<TypeFactoid> for &'a TypeProxy {
fn bex(self) -> Exp<TypeFactoid> {
VariableExp(self.get_path().clone(), PhantomData).bex()
}
}
impl IntoExp<TypeFactoid> for DatumType {
fn bex(self) -> Exp<TypeFactoid> {
ConstantExp(self.into()).bex()
}
}
impl<'a> IntoExp<TypeFactoid> for &'a DatumType {
fn bex(self) -> Exp<TypeFactoid> {
ConstantExp((*self).into()).bex()
}
}
impl<'a> IntoExp<IntFactoid> for &'a IntProxy {
fn bex(self) -> Exp<IntFactoid> {
VariableExp(self.get_path().clone(), PhantomData).bex()
}
}
impl<'a> IntoExp<IntFactoid> for &'a ElementProxy {
fn bex(self) -> Exp<IntFactoid> {
VariableExp(self.get_path().clone(), PhantomData).bex()
}
}
impl IntoExp<IntFactoid> for i64 {
fn bex(self) -> Exp<IntFactoid> {
ConstantExp(self.into()).bex()
}
}
impl<IE: IntoExp<IntFactoid>> Add<IE> for Exp<IntFactoid> {
type Output = Exp<IntFactoid>;
fn add(self, other: IE) -> Exp<IntFactoid> {
SumExp(vec![self.bex(), other.bex()]).bex()
}
}
impl<IE: IntoExp<IntFactoid>> Sub<IE> for Exp<IntFactoid> {
type Output = Exp<IntFactoid>;
fn sub(self, other: IE) -> Exp<IntFactoid> {
SumExp(vec![self.bex(), -1 * other.bex()]).bex()
}
}
impl Mul<Exp<IntFactoid>> for i64 {
type Output = Exp<IntFactoid>;
fn mul(self, other: Exp<IntFactoid>) -> Exp<IntFactoid> {
ScaledExp(self, other).bex()
}
}
impl<'a> IntoExp<DimFact> for &'a DimProxy {
fn bex(self) -> Exp<DimFact> {
VariableExp(self.get_path().clone(), PhantomData).bex()
}
}
impl IntoExp<DimFact> for TDim {
fn bex(self) -> Exp<DimFact> {
ConstantExp(self.into()).bex()
}
}
impl IntoExp<DimFact> for &TDim {
fn bex(self) -> Exp<DimFact> {
ConstantExp(self.clone().into()).bex()
}
}
impl<IE: IntoExp<DimFact>> Add<IE> for Exp<DimFact> {
type Output = Exp<DimFact>;
fn add(self, other: IE) -> Exp<DimFact> {
SumExp(vec![self.bex(), other.bex()]).bex()
}
}
impl<IE: IntoExp<DimFact>> Sub<IE> for Exp<DimFact> {
type Output = Exp<DimFact>;
fn sub(self, other: IE) -> Exp<DimFact> {
SumExp(vec![self.bex(), -1 * other.bex()]).bex()
}
}
impl Mul<Exp<DimFact>> for i64 {
type Output = Exp<DimFact>;
fn mul(self, other: Exp<DimFact>) -> Exp<DimFact> {
ScaledExp(self, other).bex()
}
}
pub trait ToDimExp {
fn to_dim(self) -> Exp<DimFact>;
}
impl ToDimExp for Exp<IntFactoid> {
fn to_dim(self) -> Exp<DimFact> {
IntoDimExp(self).bex()
}
}
impl IntoExp<ShapeFactoid> for ShapeFactoid {
fn bex(self) -> Exp<ShapeFactoid> {
ConstantExp(self).bex()
}
}
impl IntoExp<ShapeFactoid> for ShapeProxy {
fn bex(self) -> Exp<ShapeFactoid> {
VariableExp(self.get_path().clone(), PhantomData).bex()
}
}
impl<'a> IntoExp<ShapeFactoid> for &'a ShapeProxy {
fn bex(self) -> Exp<ShapeFactoid> {
VariableExp(self.get_path().clone(), PhantomData).bex()
}
}
impl IntoExp<ShapeFactoid> for TVec<TDim> {
fn bex(self) -> Exp<ShapeFactoid> {
ConstantExp(self.into_iter().collect()).bex()
}
}
impl IntoExp<ValueFact> for ValueProxy {
fn bex(self) -> Exp<ValueFact> {
VariableExp(self.get_path().clone(), PhantomData).bex()
}
}
impl<'a> IntoExp<ValueFact> for &'a ValueProxy {
fn bex(self) -> Exp<ValueFact> {
VariableExp(self.get_path().clone(), PhantomData).bex()
}
}
impl IntoExp<ValueFact> for Arc<Tensor> {
fn bex(self) -> Exp<ValueFact> {
ConstantExp(self.into()).bex()
}
}