use std::fmt;
use std::ops::{Add, Neg};
use num::Zero;
use analyser::prelude::*;
use analyser::rules::prelude::*;
use model::TVec;
use Result;
/// State the rules operate on: the tensor facts currently known for the
/// inputs and the outputs.
///
/// The field order matters: the derived `new` constructor takes its arguments
/// in declaration order (`inputs`, then `outputs`).
#[derive(Debug, new)]
pub struct Context {
/// Facts about the input tensors.
pub inputs: TVec<TensorFact>,
/// Facts about the output tensors.
pub outputs: TVec<TensorFact>,
}
impl Context {
    /// Looks up the value found at `path` and converts it from its wrapped
    /// representation into `T`.
    ///
    /// Fails if the lookup fails or if the wrapped value cannot be converted
    /// to `T`.
    pub fn get<T: Output>(&self, path: &Path) -> Result<T> {
        let wrapped = get_path(self, &path[..])?;
        let typed = T::from_wrapped(wrapped)?;
        Ok(typed)
    }

    /// Wraps `value` and stores it at `path`.
    ///
    /// Whatever `set_path` returns on success is discarded; only its error is
    /// propagated.
    pub fn set<T: Output>(&mut self, path: &Path, value: T) -> Result<()> {
        let wrapped = T::into_wrapped(value);
        set_path(self, &path[..], wrapped)?;
        Ok(())
    }
}
/// A constraint that the solver can apply to a `Context`.
pub trait Rule<'rules>: fmt::Debug {
/// Tries to apply the rule to `context`.
///
/// On success returns a pair: a flag that `Solver::infer` uses to mark the
/// rule as used (used rules are not applied again), and a list of new rules
/// that `Solver::infer` adds to its working set.
fn apply(&self, context: &mut Context) -> Result<(bool, Vec<Box<Rule<'rules> + 'rules>>)>;
/// Returns the paths referenced by the rule's expressions.
fn get_paths(&self) -> Vec<&Path>;
}
/// Rule stating that all the expressions in `items` are equal.
struct EqualsRule<T: Output + Fact> {
// Expressions that must all share the same value.
items: Vec<Exp<T>>,
}
impl<T: Output + Fact> EqualsRule<T> {
    /// Builds an equality rule over the given expressions.
    pub fn new(items: Vec<Exp<T>>) -> EqualsRule<T> {
        EqualsRule { items: items }
    }
}
impl<'rules, T: Output + Fact> Rule<'rules> for EqualsRule<T> {
    /// Finds the first item whose current value is concrete and assigns that
    /// value to every item.
    ///
    /// The returned flag is true when at least one assignment reported a
    /// change. If no item is concrete yet, nothing is modified and the flag is
    /// false. This rule never produces new rules.
    fn apply(&self, context: &mut Context) -> Result<(bool, Vec<Box<Rule<'rules> + 'rules>>)> {
        // Scan in order and stop at the first concrete value.
        let mut concrete = None;
        for item in &self.items {
            let candidate = item.get(context)?;
            if candidate.is_concrete() {
                concrete = Some(candidate);
                break;
            }
        }

        let reference = match concrete {
            Some(found) => found,
            None => return Ok((false, vec![])),
        };

        // Every item is assigned, including the one the value came from.
        let mut any_changed = false;
        for item in &self.items {
            if item.set(context, reference.clone())? {
                any_changed = true;
            }
        }
        Ok((any_changed, vec![]))
    }

    /// Collects the paths of every expression in the rule.
    fn get_paths(&self) -> Vec<&Path> {
        let mut paths = Vec::new();
        for item in &self.items {
            paths.extend(item.get_paths());
        }
        paths
    }
}
impl<'rules, T: Output + Fact> fmt::Debug for EqualsRule<T> {
    /// Formats the rule as `a == b == c`.
    ///
    /// A rule with no items writes nothing instead of panicking on an
    /// out-of-bounds index; such a rule can be built by passing an empty
    /// vector to `Solver::equals_all`.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        if let Some((first, rest)) = self.items.split_first() {
            write!(formatter, "{:?}", first)?;
            for item in rest {
                write!(formatter, " == {:?}", item)?;
            }
        }
        Ok(())
    }
}
/// Rule stating that the wrapped expression equals zero.
struct EqualsZeroRule<F>(Exp<F>)
where
F: Fact + Zero + Add<F, Output = F> + Neg<Output = F> + Clone + ::std::fmt::Debug + Output;
impl<'rules, F> Rule<'rules> for EqualsZeroRule<F>
where
F: Fact + Zero + Add<F, Output = F> + Neg<Output = F> + Clone + ::std::fmt::Debug + Output,
{
fn apply(&self, context: &mut Context) -> Result<(bool, Vec<Box<Rule<'rules> + 'rules>>)> {
Ok((self.0.set(context, F::zero())?, vec![]))
}
fn get_paths(&self) -> Vec<&Path> {
self.0.get_paths()
}
}
impl<F> fmt::Debug for EqualsZeroRule<F>
where
    F: Fact + Zero + Add<F, Output = F> + Neg<Output = F> + Clone + ::std::fmt::Debug + Output,
{
    /// Formats the rule as the wrapped expression followed by ` == 0`.
    fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        // The expression is formatted with the caller's formatter directly so
        // that any formatting flags reach it unchanged.
        match self.0.fmt(formatter) {
            Ok(()) => write!(formatter, " == 0"),
            Err(error) => Err(error),
        }
    }
}
/// Rule that hands the current value of an expression to a closure, which can
/// register further rules on the solver it receives.
pub struct WithRule<'rules, T: Fact> {
/// Expression whose value is passed to the closure.
pub item: Exp<T>,
/// Callback invoked with a fresh solver and the expression's value.
pub closure: Box<Fn(&mut Solver<'rules>, T) + 'rules>,
}
impl<'rules, T: Output + Fact> WithRule<'rules, T> {
    /// Builds a `WithRule` from an expression and the callback to run on its
    /// value. The callback is boxed here.
    pub fn new<F>(item: Exp<T>, closure: F) -> WithRule<'rules, T>
    where
        F: Fn(&mut Solver<'rules>, T) + 'rules,
    {
        WithRule {
            item: item,
            closure: Box::new(closure),
        }
    }
}
impl<'rules, T: Output + Fact> Rule<'rules> for WithRule<'rules, T> {
    /// Reads the expression's current value, passes it to the closure along
    /// with a fresh solver, and returns the rules the closure registered.
    ///
    /// The returned flag is always true.
    fn apply(&self, context: &mut Context) -> Result<(bool, Vec<Box<Rule<'rules> + 'rules>>)> {
        let current = self.item.get(context)?;
        trace!(" With rule: {:?} is {:?}", self.item, current);

        // The closure records its rules on a scratch solver; only those rules
        // are handed back to the caller.
        let mut scratch = Solver::default();
        (self.closure)(&mut scratch, current);
        let produced = scratch.take_rules();
        Ok((true, produced))
    }

    /// Returns the paths of the watched expression.
    fn get_paths(&self) -> Vec<&Path> {
        let watched = &self.item;
        watched.get_paths()
    }
}
impl<'s, T: Output + Fact> fmt::Debug for WithRule<'s, T> {
    /// Prints the rule name and its expression; the closure is not printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!("WithRule {{ {:?} }}", self.item))
    }
}
/// Rule that runs a closure on the concrete value of an expression, once that
/// value can be concretized.
pub struct GivenRule<'rules, T: Fact> {
/// Expression whose concrete value is awaited.
pub item: Exp<T>,
/// Callback invoked with a fresh solver and the concrete value.
pub closure: Box<Fn(&mut Solver<'rules>, T::Concrete) + 'rules>,
}
impl<'rules, T: Output + Fact> GivenRule<'rules, T> {
    /// Builds a `GivenRule` from an expression and the callback to run on its
    /// concrete value. The callback is boxed here.
    pub fn new<F>(item: Exp<T>, closure: F) -> GivenRule<'rules, T>
    where
        F: Fn(&mut Solver<'rules>, T::Concrete) + 'rules,
    {
        GivenRule {
            item: item,
            closure: Box::new(closure),
        }
    }
}
impl<'rules, T: Output + Fact> Rule<'rules> for GivenRule<'rules, T> {
    /// Runs the closure if the expression's current value can be concretized.
    ///
    /// On success the closure receives a fresh solver and the concrete value,
    /// and the result is `(true, rules registered by the closure)`. When the
    /// value is not concrete yet, the result is `(false, vec![])` so the rule
    /// can be tried again later.
    fn apply(&self, context: &mut Context) -> Result<(bool, Vec<Box<Rule<'rules> + 'rules>>)> {
        let value = self.item.get(context)?;
        // Bound to a local first so `value` is free to be consumed by the
        // failure branch below.
        let concretized = value.concretize();
        match concretized {
            Some(concrete) => {
                trace!(" Given rule: {:?} is {:?}", self.item, concrete);
                let mut solver = Solver::default();
                (self.closure)(&mut solver, concrete);
                Ok((true, solver.take_rules()))
            }
            None => {
                // Reuse the value fetched above rather than querying the
                // context a second time from inside the log macro.
                trace!(
                    "In {:?}, failed to convert {:?} to expected type",
                    self,
                    value.wrap()
                );
                Ok((false, vec![]))
            }
        }
    }

    /// Returns the paths of the awaited expression.
    fn get_paths(&self) -> Vec<&Path> {
        self.item.get_paths()
    }
}
impl<'s, T: Output + Fact> fmt::Debug for GivenRule<'s, T> {
    /// Prints the rule name and its expression; the closure is not printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!("GivenRule {{ {:?} }}", self.item))
    }
}
/// Rule that runs a closure on the concrete values of several expressions,
/// once every one of them can be concretized.
pub struct GivenAllRule<'rules, T: Fact> {
/// Expressions whose concrete values are awaited.
pub items: Vec<Exp<T>>,
/// Callback invoked with a fresh solver and the concrete values, in the
/// same order as `items`.
pub closure: Box<Fn(&mut Solver<'rules>, Vec<T::Concrete>) + 'rules>,
}
impl<'rules, T: Output + Fact> GivenAllRule<'rules, T> {
    /// Builds a `GivenAllRule` from a list of expressions and the callback to
    /// run on their concrete values. The callback is boxed here.
    pub fn new<F>(items: Vec<Exp<T>>, closure: F) -> GivenAllRule<'rules, T>
    where
        F: Fn(&mut Solver<'rules>, Vec<T::Concrete>) + 'rules,
    {
        GivenAllRule {
            items: items,
            closure: Box::new(closure),
        }
    }
}
impl<'rules, T: Output + Fact> Rule<'rules> for GivenAllRule<'rules, T> {
    /// Runs the closure once every expression has a concrete value.
    ///
    /// All values are fetched first; if each one concretizes, the closure
    /// receives a fresh solver and the concrete values, and the result is
    /// `(true, rules registered by the closure)`. Otherwise the result is
    /// `(false, vec![])`.
    fn apply(&self, context: &mut Context) -> Result<(bool, Vec<Box<Rule<'rules> + 'rules>>)> {
        let mut values: Vec<T> = Vec::with_capacity(self.items.len());
        for item in &self.items {
            values.push(item.get(context)?);
        }

        // Keep only the values that concretize; a shorter list means at least
        // one expression is still undetermined.
        let mut concrete = Vec::with_capacity(values.len());
        for value in &values {
            if let Some(known) = value.concretize() {
                concrete.push(known);
            }
        }
        if concrete.len() != self.items.len() {
            return Ok((false, vec![]));
        }

        trace!(" Given all rule: {:?} is {:?}", self.items, values);
        let mut scratch = Solver::default();
        (self.closure)(&mut scratch, concrete);
        Ok((true, scratch.take_rules()))
    }

    /// Collects the paths of every awaited expression.
    fn get_paths(&self) -> Vec<&Path> {
        let mut paths = Vec::new();
        for item in &self.items {
            paths.extend(item.get_paths());
        }
        paths
    }
}
impl<'s, T: Output + Fact> fmt::Debug for GivenAllRule<'s, T> {
    /// Prints the rule name and its expressions; the closure is not printed.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.write_fmt(format_args!("GivenAllRule {:?}", self.items))
    }
}
/// Collects rules and applies them to tensor facts.
///
/// The derived `Default` yields a solver with no rules.
#[derive(Default)]
pub struct Solver<'rules> {
/// Rules registered so far.
pub rules: Vec<Box<Rule<'rules> + 'rules>>,
}
impl<'rules> Solver<'rules> {
    /// Consumes the solver and returns the rules registered on it.
    pub fn take_rules(self) -> Vec<Box<Rule<'rules> + 'rules>> {
        self.rules
    }

    /// Applies the registered rules to the given `(inputs, outputs)` facts and
    /// returns the resulting facts.
    ///
    /// Every fact is reduced before and after solving. Rules are applied in
    /// passes; a rule that reports itself as used is skipped in later passes,
    /// and rules produced during a pass join the working set once that pass
    /// ends. Passes repeat until one neither uses a rule nor produces new
    /// ones.
    ///
    /// # Errors
    ///
    /// Returns the first error raised by a rule.
    pub fn infer(
        self,
        mut facts: (TVec<TensorFact>, TVec<TensorFact>),
    ) -> Result<(TVec<TensorFact>, TVec<TensorFact>)> {
        for f in &mut facts.0 {
            f.reduce();
        }
        for f in &mut facts.1 {
            f.reduce();
        }

        let mut context = Context::new(facts.0, facts.1);
        let mut changed = true;
        let mut added_rules = vec![];
        // Each rule is paired with a flag recording whether it was already used.
        let mut rules: Vec<_> = self.rules.into_iter().map(|r| (false, r)).collect();

        while changed {
            changed = false;
            for (used, rule) in &mut rules {
                // Rules that were already used are not applied again.
                if *used {
                    continue;
                }
                trace!(" Applying rule {:?}", rule);
                let (step_used, mut step_added) = rule.apply(&mut context)?;
                *used |= step_used;
                // Using a rule or producing new ones both call for another pass.
                changed |= step_used || !step_added.is_empty();
                added_rules.append(&mut step_added);
            }
            trace!(" Applying all rules");
            // New rules are added only after the pass, since `rules` is
            // borrowed while it is being iterated.
            rules.extend(added_rules.drain(..).map(|rule| (false, rule)));
        }

        trace!(" Solver exiting {:?}", context);
        for i in &mut context.inputs {
            i.reduce();
        }
        for o in &mut context.outputs {
            o.reduce();
        }
        Ok((context.inputs, context.outputs))
    }

    /// Registers a rule stating that `left` and `right` are equal.
    ///
    /// Returns the solver so calls can be chained.
    pub fn equals<T, A, B>(&mut self, left: A, right: B) -> &mut Solver<'rules>
    where
        T: Output + Fact + 'static,
        A: IntoExp<T>,
        B: IntoExp<T>,
    {
        let items: Vec<Exp<T>> = vec![left.bex(), right.bex()];
        let rule = EqualsRule::new(items);
        self.rules.push(Box::new(rule));
        self
    }

    /// Registers a rule stating that all the given expressions are equal.
    ///
    /// Returns the solver so calls can be chained.
    pub fn equals_all<T>(&mut self, items: Vec<Exp<T>>) -> &mut Solver<'rules>
    where
        T: Output + Fact + 'static,
    {
        let rule = EqualsRule::new(items);
        self.rules.push(Box::new(rule));
        self
    }

    /// Registers a rule stating that the given expression equals zero.
    ///
    /// Returns the solver so calls can be chained.
    pub fn equals_zero<F>(&mut self, items: Exp<F>) -> &mut Solver<'rules>
    where
        F: Fact
            + Zero
            + Add<F, Output = F>
            + Neg<Output = F>
            + Clone
            + ::std::fmt::Debug
            + Output
            + 'rules,
    {
        let rule = EqualsZeroRule(items);
        self.rules.push(Box::new(rule));
        self
    }

    /// Registers a rule that passes the current value of `item` to `closure`,
    /// concrete or not, together with a solver on which the closure can
    /// register further rules.
    ///
    /// Returns the solver so calls can be chained.
    pub fn with<T, A, F>(&mut self, item: A, closure: F) -> &mut Solver<'rules>
    where
        T: Fact + Output + 'static,
        A: IntoExp<T>,
        F: Fn(&mut Solver<'rules>, T) + 'rules,
    {
        let rule = WithRule::new(item.bex(), closure);
        self.rules.push(Box::new(rule));
        self
    }

    /// Registers a rule that passes the concrete value of `item` to `closure`
    /// once that value can be concretized, together with a solver on which
    /// the closure can register further rules.
    ///
    /// Returns the solver so calls can be chained.
    pub fn given<T, A, F>(&mut self, item: A, closure: F) -> &mut Solver<'rules>
    where
        T: Fact + Output + 'static,
        A: IntoExp<T>,
        F: Fn(&mut Solver<'rules>, T::Concrete) + 'rules,
    {
        let rule = GivenRule::new(item.bex(), closure);
        self.rules.push(Box::new(rule));
        self
    }

    /// Registers a rule that passes the concrete values of all `items` to
    /// `closure` once every one of them can be concretized, together with a
    /// solver on which the closure can register further rules.
    ///
    /// Returns the solver so calls can be chained.
    pub fn given_all<T, I, A, F>(&mut self, items: I, closure: F) -> &mut Solver<'rules>
    where
        T: Fact + Output + 'static,
        A: IntoExp<T>,
        I: IntoIterator<Item = A>,
        F: Fn(&mut Solver<'rules>, Vec<T::Concrete>) + 'rules,
    {
        let rule = GivenAllRule::new(items.into_iter().map(|it| it.bex()).collect(), closure);
        self.rules.push(Box::new(rule));
        self
    }
}
// Unit tests for `Solver`: each test registers a few rules, runs `infer` on
// some starting facts and checks either the resulting facts or that solving
// fails.
#[cfg(test)]
mod tests {
use super::*;
use DatumType;
// Returns an empty solver plus two tensor proxies, built from paths `[0]`
// and `[1]`; the tests use them as the inputs and outputs proxies.
fn bootstrap<'s>() -> (Solver<'s>, TensorsProxy, TensorsProxy) {
(
Solver::default(),
TensorsProxy::new(tvec![0].into()),
TensorsProxy::new(tvec![1].into()),
)
}
// Requiring two inputs while providing none is expected to fail.
#[test]
#[should_panic]
fn solver_wrong_size_1() {
let (mut solver, inputs, _) = bootstrap();
solver.equals(&inputs.len, 2);
solver.infer((tvec![].into(), tvec![].into())).unwrap();
}
// Constraining the rank of input 0 while providing no input is expected to fail.
#[test]
#[should_panic]
fn solver_wrong_size_2() {
let (mut solver, inputs, _) = bootstrap();
solver.equals(&inputs[0].rank, 2);
solver.infer((tvec![].into(), tvec![].into())).unwrap();
}
// Requiring one input while providing exactly one leaves the facts unchanged.
#[test]
fn solver_exact_size() {
let (mut solver, inputs, _) = bootstrap();
solver.equals(&inputs.len, 1);
let facts = solver
.infer((tvec![TensorFact::new()].into(), tvec![].into()))
.unwrap();
assert_eq!(facts, (tvec![TensorFact::new()].into(), tvec![].into()));
}
// Setting the datum type of input 1 only affects that input's fact.
#[test]
fn solver_dynamic_size() {
let (mut solver, inputs, _) = bootstrap();
solver.equals(&inputs[1].datum_type, DatumType::I32);
let facts = solver
.infer((tvec![TensorFact::new(), TensorFact::new()], tvec![]))
.unwrap();
let expected = (
tvec![
TensorFact::new(),
TensorFact {
datum_type: typefact!(DatumType::I32),
..TensorFact::new()
},
],
tvec![],
);
assert_eq!(facts, expected);
}
// Fixing the rank to 2 yields a shape with two unknown dimensions.
#[test]
fn solver_exact_rank() {
let (mut solver, inputs, _) = bootstrap();
solver.equals(&inputs[0].rank, 2);
let facts = solver.infer((tvec![TensorFact::new()], tvec![])).unwrap();
let expected = (
tvec![TensorFact {
shape: shapefact![_, _],
..TensorFact::new()
}],
tvec![],
);
assert_eq!(facts, expected);
}
// Constraining dimension 1 on a fact of unknown rank yields the open
// shape `[_, 0; ..]`.
#[test]
fn solver_dynamic_rank() {
let (mut solver, inputs, _) = bootstrap();
solver.equals(&inputs[0].shape[1], 0.to_dim());
let facts = solver.infer((tvec![TensorFact::new()], tvec![])).unwrap();
let expected = (
tvec![TensorFact {
shape: shapefact![_, 0; ..],
..TensorFact::new()
}],
tvec![],
);
assert_eq!(facts, expected);
}
// A chain of equalities propagates the single known dimension to all three.
#[test]
fn solver_ranks() {
let (mut solver, inputs, _) = bootstrap();
solver.equals(&inputs[0].rank, 3);
solver.equals(&inputs[0].shape[0], &inputs[0].shape[1]);
solver.equals(&inputs[0].shape[1], &inputs[0].shape[2]);
solver.equals(&inputs[0].shape[1], 3.to_dim());
let facts = solver.infer((tvec![TensorFact::new()], tvec![])).unwrap();
let expected = (
tvec![TensorFact {
shape: shapefact![3, 3, 3],
..TensorFact::new()
}],
tvec![],
);
assert_eq!(facts, expected);
}
// Equating two different constants is expected to fail.
#[test]
#[should_panic]
fn solver_wrong_constant() {
let (mut solver, _, _) = bootstrap();
solver.equals(1, 2);
solver.infer((tvec![], tvec![])).unwrap();
}
// Equating two identical constants succeeds.
#[test]
fn solver_right_constant() {
let (mut solver, _, _) = bootstrap();
solver.equals(2, 2);
solver.infer((tvec![], tvec![])).unwrap();
}
// Equating two unknown dimensions leaves both facts unchanged.
#[test]
fn solver_backward_1() {
let (mut solver, inputs, outputs) = bootstrap();
solver.equals(&inputs[0].shape[1], &outputs[0].shape[1]);
let facts = solver
.infer((tvec![TensorFact::new()], tvec![TensorFact::new()]))
.unwrap();
let expected = (tvec![TensorFact::new()], tvec![TensorFact::new()]);
assert_eq!(facts, expected);
}
// A dimension known on the output is propagated to the input.
#[test]
fn solver_backward_2() {
let (mut solver, inputs, outputs) = bootstrap();
solver.equals(&inputs[0].shape[1], &outputs[0].shape[1]);
let output = TensorFact {
shape: shapefact![_, 2, _],
..TensorFact::new()
};
let facts = solver
.infer((tvec![TensorFact::new()], tvec![output.clone()]))
.unwrap();
let expected = (
tvec![TensorFact {
shape: shapefact![_, 2; ..],
..TensorFact::new()
}],
tvec![output.clone()],
);
assert_eq!(facts, expected);
}
}