use std::fmt;
use std::marker::PhantomData;
use std::ops::{Div, Mul};
use num::cast::ToPrimitive;
use num::Zero;
use analyser::prelude::*;
use analyser::rules::prelude::*;
use dim::TDim;
use {DatumType, Result, Tensor};
/// A value that can be packed into, and recovered from, a [`Wrapped`].
pub trait Output: fmt::Debug + Clone + PartialEq {
    /// Method-call form of [`Output::into_wrapped`].
    fn wrap(self) -> Wrapped {
        <Self as Output>::into_wrapped(self)
    }

    /// Packs `source` into a [`Wrapped`] value.
    fn into_wrapped(source: Self) -> Wrapped;

    /// Unpacks a `Self` from `wrapped`, returning an error when the wrapped
    /// value cannot be converted to `Self`.
    fn from_wrapped(wrapped: Wrapped) -> Result<Self>;
}
macro_rules! impl_output {
($type:ty, $constr:ident) => {
impl Output for $type {
fn into_wrapped(source: Self) -> Wrapped {
Wrapped::$constr(source)
}
fn from_wrapped(wrapped: Wrapped) -> Result<$type> {
if let Wrapped::$constr(v) = wrapped {
Ok(v)
} else {
bail!("Tried to get a {} from {:?}.", stringify!($ty), wrapped);
}
}
}
};
}
impl_output!(IntFact, Int);
impl_output!(TypeFact, Type);
impl_output!(ShapeFact, Shape);
impl_output!(ValueFact, Value);
impl_output!(DimFact, Dim);
/// `usize` values are carried inside [`Wrapped`] as integer facts.
impl Output for usize {
    fn into_wrapped(source: usize) -> Wrapped {
        // NOTE(review): `as` wraps values above `isize::MAX` to negatives.
        let signed = source as isize;
        IntFact::into_wrapped(signed.into())
    }

    fn from_wrapped(wrapped: Wrapped) -> Result<usize> {
        // Render the message now: `wrapped` is moved into `from_wrapped` below.
        let message = format!("Tried to convert {:?} to a usize.", wrapped);
        let fact = IntFact::from_wrapped(wrapped)?;
        // Fails when the fact has no concrete value or the value is negative.
        match fact.concretize().and_then(|n| n.to_usize()) {
            Some(n) => Ok(n),
            None => Err(message.into()),
        }
    }
}
impl Output for isize {
fn into_wrapped(source: isize) -> Wrapped {
IntFact::into_wrapped(source.into())
}
fn from_wrapped(wrapped: Wrapped) -> Result<isize> {
let message = format!("Tried to convert {:?} to a isize.", wrapped);
IntFact::from_wrapped(wrapped)?
.concretize()
.ok_or(message.into())
}
}
impl Output for Tensor {
fn into_wrapped(source: Tensor) -> Wrapped {
ValueFact::into_wrapped(source.into())
}
fn from_wrapped(wrapped: Wrapped) -> Result<Tensor> {
let message = format!("Tried to convert {:?} to a tensor.", wrapped);
ValueFact::from_wrapped(wrapped)?
.concretize()
.ok_or(message.into())
}
}
/// [`TDim`] values are carried inside [`Wrapped`] as dimension facts.
impl Output for TDim {
    /// Wraps `source` by converting it into a `DimFact`.
    fn into_wrapped(source: TDim) -> Wrapped {
        DimFact::into_wrapped(source.into())
    }

    /// Extracts a concrete `TDim` from `wrapped`.
    ///
    /// # Errors
    /// Fails if `wrapped` does not hold a `DimFact`, or if that fact's
    /// `concretize()` yields `None`.
    fn from_wrapped(wrapped: Wrapped) -> Result<TDim> {
        // Built before `wrapped` is moved into `DimFact::from_wrapped`. The
        // target type is `TDim` (this message previously said "usize").
        let message = format!("Tried to convert {:?} to a TDim.", wrapped);
        DimFact::from_wrapped(wrapped)?
            .concretize()
            .ok_or(message.into())
    }
}
/// One enum able to carry any of the fact types used as expression outputs.
///
/// `Output::into_wrapped` picks the variant, and `Output::from_wrapped`
/// matches on it to recover the concrete value.
#[derive(Debug, Clone)]
pub enum Wrapped {
/// Integer fact (the `usize` and `isize` outputs are also stored here).
Int(IntFact),
/// Datum-type fact.
Type(TypeFact),
/// Shape fact.
Shape(ShapeFact),
/// Value fact (the `Tensor` output is stored here).
Value(ValueFact),
/// Dimension fact (the `TDim` output is stored here).
Dim(DimFact),
}
/// An expression that can be evaluated against a `Context` (`get`) and
/// constrained to a given value (`set`).
pub trait Expression: fmt::Debug {
/// Type of the value this expression evaluates to.
type Output: Output;
/// Evaluates the expression against `context`.
fn get(&self, context: &Context) -> Result<Self::Output>;
/// Constrains the expression to evaluate to `value`, updating `context`
/// where the implementation can. Implementations may return an error when
/// `value` is incompatible (e.g. a constant set to a different value).
fn set(&self, context: &mut Context, value: Self::Output) -> Result<()>;
/// Context paths this expression refers to (empty for constants).
fn get_paths(&self) -> Vec<&Path>;
}
/// An expression that always evaluates to the same fixed value.
pub struct ConstantExpression<T: Output>(T);

impl<T: Output> Expression for ConstantExpression<T> {
    type Output = T;

    /// Returns a copy of the constant without consulting the context.
    fn get(&self, _: &Context) -> Result<T> {
        let value = self.0.clone();
        Ok(value)
    }

    /// Succeeds only when `value` equals the constant. The context is never
    /// modified.
    fn set(&self, _: &mut Context, value: T) -> Result<()> {
        if self.0 != value {
            bail!(
                "Cannot set the value of constant {:?} to {:?}.",
                self.0,
                value
            );
        }
        Ok(())
    }

    /// A constant refers to no context path.
    fn get_paths(&self) -> Vec<&Path> {
        Vec::new()
    }
}

impl<T: Output> fmt::Debug for ConstantExpression<T> {
    /// Formats as the `Debug` form of the constant itself.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        let ConstantExpression(ref inner) = *self;
        write!(formatter, "{:?}", inner)
    }
}
/// An expression that reads and writes the value stored at a context path.
///
/// `T` is only a marker for the output type; no `T` is stored.
pub struct VariableExpression<T: Output>(Path, PhantomData<T>);

impl<T: Output> Expression for VariableExpression<T> {
    type Output = T;

    /// Reads the value at this expression's path from `context`.
    fn get(&self, context: &Context) -> Result<T> {
        let path = &self.0;
        context.get(path)
    }

    /// Writes `value` at this expression's path in `context`.
    fn set(&self, context: &mut Context, value: T) -> Result<()> {
        let path = &self.0;
        context.set(path, value)
    }

    /// The single path this expression refers to.
    fn get_paths(&self) -> Vec<&Path> {
        let mut paths = Vec::with_capacity(1);
        paths.push(&self.0);
        paths
    }
}

impl<T: Output> fmt::Debug for VariableExpression<T> {
    /// Formats as the `Debug` form of the path.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        let VariableExpression(ref path, _) = *self;
        write!(formatter, "{:?}", path)
    }
}
pub struct ProductExpression<E, V>(isize, E)
where
V: Zero + Mul<isize, Output = V> + Div<isize, Output = V> + Clone + Output,
E: Expression<Output = V>;
impl<E, V> Expression for ProductExpression<E, V>
where
V: Zero + Mul<isize, Output = V> + Div<isize, Output = V> + Clone + Output,
E: Expression<Output = V>,
{
type Output = V;
fn get(&self, context: &Context) -> Result<V> {
let v: V = self.1.get(context)?;
Ok(v * self.0)
}
fn set(&self, context: &mut Context, value: V) -> Result<()> {
let k = &self.0;
let m = value;
if m.is_zero() && k.is_zero() {
Ok(())
} else if m.is_zero() {
self.1.set(context, V::zero())
} else {
let div = m.div(*k);
self.1.set(context, div)
}
}
fn get_paths(&self) -> Vec<&Path> {
self.1.get_paths()
}
}
impl<E, V> fmt::Debug for ProductExpression<E, V>
where
V: Zero + Mul<isize, Output = V> + Div<isize, Output = V> + Clone + Output,
E: Expression<Output = V>,
{
fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
write!(formatter, "{}*{{{:?}}}", self.0, self.1)
}
}
/// Conversion of a value into an expression of type `T`.
///
/// Implemented in this module for literals (constant expressions), proxies
/// (variable expressions) and `(factor, expr)` tuples (product expressions).
pub trait IntoExpression<T> {
/// Builds the expression, consuming `self`.
fn into_expr(self) -> T;
}
/// An integer literal becomes a constant `IntFact` expression.
impl IntoExpression<ConstantExpression<IntFact>> for isize {
    fn into_expr(self) -> ConstantExpression<IntFact> {
        let fact: IntFact = self.into();
        ConstantExpression(fact)
    }
}

/// Borrowed form of the integer-literal conversion.
impl<'a> IntoExpression<ConstantExpression<IntFact>> for &'a isize {
    fn into_expr(self) -> ConstantExpression<IntFact> {
        let value: isize = *self;
        let fact: IntFact = value.into();
        ConstantExpression(fact)
    }
}
/// A datum type becomes a constant `TypeFact` expression.
impl IntoExpression<ConstantExpression<TypeFact>> for DatumType {
    fn into_expr(self) -> ConstantExpression<TypeFact> {
        let fact: TypeFact = self.into();
        ConstantExpression(fact)
    }
}

/// Borrowed form of the datum-type conversion; the datum type is copied out.
impl<'a> IntoExpression<ConstantExpression<TypeFact>> for &'a DatumType {
    fn into_expr(self) -> ConstantExpression<TypeFact> {
        let datum_type: DatumType = *self;
        let fact: TypeFact = datum_type.into();
        ConstantExpression(fact)
    }
}
/// Any fact that is also an [`Output`] becomes a constant expression holding
/// that fact unchanged.
impl<T: Fact + Output> IntoExpression<ConstantExpression<T>> for T {
    fn into_expr(self) -> ConstantExpression<T> {
        let fact: T = self;
        ConstantExpression(fact)
    }
}
/// A dimension becomes a constant `DimFact` expression.
impl IntoExpression<ConstantExpression<DimFact>> for TDim {
    fn into_expr(self) -> ConstantExpression<DimFact> {
        let fact: DimFact = self.into();
        ConstantExpression(fact)
    }
}
impl<T> IntoExpression<VariableExpression<T::Output>> for T
where
T: ComparableProxy,
{
fn into_expr(self) -> VariableExpression<T::Output> {
VariableExpression(self.get_path().clone().into(), PhantomData)
}
}
/// A `(factor, expr)` pair with an `isize` factor becomes the product
/// expression `factor * expr`.
impl<E, V, I> IntoExpression<ProductExpression<E, V>> for (isize, I)
where
    V: Zero + Mul<isize, Output = V> + Div<isize, Output = V> + Clone + Output,
    E: Expression<Output = V>,
    I: IntoExpression<E>,
{
    fn into_expr(self) -> ProductExpression<E, V> {
        let factor = self.0;
        let inner: E = self.1.into_expr();
        ProductExpression(factor, inner)
    }
}

/// Same as the `isize` form, so unsuffixed integer literals (which default to
/// `i32`) can be used as factors.
impl<E, V, I> IntoExpression<ProductExpression<E, V>> for (i32, I)
where
    V: Zero + Mul<isize, Output = V> + Div<isize, Output = V> + Clone + Output,
    E: Expression<Output = V>,
    I: IntoExpression<E>,
{
    fn into_expr(self) -> ProductExpression<E, V> {
        // Lossless when `isize` is at least 32 bits wide.
        let factor = self.0 as isize;
        let inner: E = self.1.into_expr();
        ProductExpression(factor, inner)
    }
}