use ndarray::*;
use ndarray_linalg::*;
use petgraph;
use petgraph::prelude::*;
use operators::*;
use error::*;
/// A single vertex of the computation `Graph`.
///
/// Both `value` and `deriv` start out as `None` and are filled in by the
/// `Graph` methods that assign or evaluate values and derivatives.
#[derive(Debug, Clone)]
pub struct Node<A: Scalar> {
    /// Value assigned by `Graph::set_value` (variables) or the most recent
    /// result computed by `Graph::eval_value` (operators).
    value: Option<Value<A>>,
    /// Derivative most recently written by `Graph::eval_deriv`, if any.
    deriv: Option<Value<A>>,
    /// What kind of node this is (variable, unary or binary operator).
    prop: Property,
}
/// The kind of a `Node`: a leaf variable, or an operator that reads one or
/// two operand nodes through the graph's incoming edges.
#[derive(Debug, Clone, IntoEnum)]
enum Property {
    /// Leaf node whose value is supplied via `Graph::set_value`.
    Variable(Variable),
    /// Operator with a single operand (e.g. the node built by `Graph::neg`).
    UnaryOperator(UnaryOperatorAny),
    /// Operator with two operands (e.g. the node built by `Graph::add`).
    BinaryOperator(BinaryOperatorAny),
}
impl<A: Scalar> Node<A> {
fn new(prop: Property) -> Self {
Self {
value: None,
deriv: None,
prop,
}
}
pub fn is_variable(&self) -> bool {
match self.prop {
Property::Variable(_) => true,
Property::UnaryOperator(_) => false,
Property::BinaryOperator(_) => false,
}
}
}
/// Payload of a variable node.
#[derive(Debug, Clone)]
struct Variable {
    /// User-supplied name; shown in the panic message when the variable is
    /// evaluated before a value has been set.
    name: String,
}
impl Variable {
fn new(name: &str) -> Self {
Variable { name: name.to_string() }
}
}
/// A computation graph of variables and operators.
///
/// Wraps a petgraph directed graph. Edges run from an operand to the operator
/// node that consumes it. The `impl` below calls the inner graph's API directly
/// on `self` (e.g. `add_node`, `add_edge`, indexing) and builds `Self` with
/// `.into()`. Both rely on what the `NewType` derive generates.
#[derive(Debug, NewType)]
pub struct Graph<A: Scalar>(petgraph::graph::Graph<Node<A>, ()>);
impl<A: Scalar> Graph<A> {
    /// Creates an empty computation graph.
    pub fn new() -> Self {
        petgraph::graph::Graph::new().into()
    }

    /// Adds a variable node called `name` that has no value yet.
    ///
    /// A value must be assigned with [`Graph::set_value`] before any node that
    /// depends on this variable is evaluated; otherwise `eval_value` panics.
    pub fn variable(&mut self, name: &str) -> NodeIndex {
        let var = Variable::new(name);
        self.add_node(Node::new(var.into()))
    }

    /// Adds a variable node called `name` holding the scalar `value`.
    pub fn scalar_variable(&mut self, name: &str, value: A) -> NodeIndex {
        let var = self.variable(name);
        // Cannot fail: `var` was created as a variable node just above.
        self.set_value(var, value)
            .expect("freshly created variable node rejected its value");
        var
    }

    /// Adds a variable node called `name` holding the 1-D array `value`.
    pub fn vector_variable(&mut self, name: &str, value: Array<A, Ix1>) -> NodeIndex {
        let var = self.variable(name);
        // Cannot fail: `var` was created as a variable node just above.
        self.set_value(var, value)
            .expect("freshly created variable node rejected its value");
        var
    }

    /// Assigns `value` to the variable node `node`.
    ///
    /// Values already cached on operator nodes are NOT invalidated. Call
    /// `eval_value(.., false)` to recompute them with the new value.
    ///
    /// # Errors
    ///
    /// Returns a `NodeTypeError` if `node` is an operator, not a variable.
    pub fn set_value<V: Into<Value<A>>>(&mut self, node: NodeIndex, value: V) -> Result<()> {
        if self[node].is_variable() {
            self[node].value = Some(value.into());
            Ok(())
        } else {
            Err(NodeTypeError {}.into())
        }
    }

    /// Adds an addition node `lhs + rhs` and returns its index.
    pub fn add(&mut self, lhs: NodeIndex, rhs: NodeIndex) -> NodeIndex {
        let p = self.add_node(Node::new(add().into()));
        // Edge insertion order matters: `get_arg2` recovers (lhs, rhs) from it.
        self.add_edge(lhs, p, ());
        self.add_edge(rhs, p, ());
        p
    }

    /// Adds a negation node `-arg` and returns its index.
    pub fn neg(&mut self, arg: NodeIndex) -> NodeIndex {
        let n = self.add_node(Node::new(neg().into()));
        self.add_edge(arg, n, ());
        n
    }

    /// Adds `lhs - rhs`, built as `lhs + (-rhs)`, and returns the addition node.
    pub fn sub(&mut self, lhs: NodeIndex, rhs: NodeIndex) -> NodeIndex {
        let m_rhs = self.neg(rhs);
        self.add(lhs, m_rhs)
    }

    /// Returns the single operand of the unary operator node `op`.
    ///
    /// Panics if `op` has no incoming edge.
    fn get_arg1(&self, op: NodeIndex) -> NodeIndex {
        let mut iter = self.neighbors_directed(op, Direction::Incoming);
        iter.next().unwrap()
    }

    /// Returns `(lhs, rhs)` for the binary operator node `op`.
    ///
    /// petgraph lists a node's neighbors most-recently-added edge first, and
    /// `add` connects `lhs` before `rhs`, so `rhs` comes out first.
    /// Panics if `op` has fewer than two incoming edges.
    fn get_arg2(&self, op: NodeIndex) -> (NodeIndex, NodeIndex) {
        let mut iter = self.neighbors_directed(op, Direction::Incoming);
        let rhs = iter.next().unwrap();
        let lhs = iter.next().unwrap();
        (lhs, rhs)
    }

    /// Returns the value stored on `node`, if it has been set or evaluated.
    pub fn get_value(&self, node: NodeIndex) -> Option<&Value<A>> {
        self[node].value.as_ref()
    }

    /// Returns the derivative stored on `node` by the last `eval_deriv` call.
    pub fn get_deriv(&self, node: NodeIndex) -> Option<&Value<A>> {
        self[node].deriv.as_ref()
    }

    /// Evaluates `node`, recursively evaluating its operands first, and stores
    /// each result on the corresponding node.
    ///
    /// With `use_cached == true`, operator nodes that already hold a value are
    /// not recomputed. With `false`, every operator reached is recomputed.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by an operator's `eval_value`.
    ///
    /// # Panics
    ///
    /// Panics if a variable reached during evaluation has no value assigned.
    pub fn eval_value(&mut self, node: NodeIndex, use_cached: bool) -> Result<()> {
        // Cloned so `self` can be borrowed mutably while matching on it.
        let prop = self[node].prop.clone();
        let value_exists = self[node].value.is_some();
        match prop {
            Property::Variable(ref v) => {
                if value_exists {
                    return Ok(());
                }
                panic!("Variable '{}' is evaluated before set value", v.name)
            }
            Property::UnaryOperator(ref op) => {
                if use_cached && value_exists {
                    return Ok(());
                }
                let arg = self.get_arg1(node);
                self.eval_value(arg, use_cached)?;
                let res = op.eval_value(self.get_value(arg).unwrap())?;
                self[node].value = Some(res);
            }
            Property::BinaryOperator(ref op) => {
                if use_cached && value_exists {
                    return Ok(());
                }
                let (lhs, rhs) = self.get_arg2(node);
                self.eval_value(rhs, use_cached)?;
                self.eval_value(lhs, use_cached)?;
                let res = op.eval_value(
                    self.get_value(lhs).unwrap(),
                    self.get_value(rhs).unwrap(),
                )?;
                self[node].value = Some(res);
            }
        };
        Ok(())
    }

    /// Pushes the contribution `der` (derivative of the output w.r.t. `node`
    /// along one path) down to `node`'s operands, then adds it to `node`'s total.
    ///
    /// A node reachable along several paths (e.g. `x` in `add(x, x)`) receives
    /// one contribution per path. The contributions are summed, as the chain
    /// rule requires, rather than the last one overwriting the others.
    fn deriv_recur(&mut self, node: NodeIndex, der: Value<A>) -> Result<()> {
        let prop = self[node].prop.clone();
        match prop {
            Property::Variable(_) => {}
            Property::UnaryOperator(ref op) => {
                let arg = self.get_arg1(node);
                // Use this path's contribution, not the node's running total:
                // the total may already include other paths' contributions.
                let arg_der = op.eval_deriv(self.get_value(arg).unwrap(), &der)?;
                self.deriv_recur(arg, arg_der)?;
            }
            Property::BinaryOperator(ref op) => {
                let (lhs, rhs) = self.get_arg2(node);
                let (l_der, r_der) = op.eval_deriv(
                    self.get_value(lhs).unwrap(),
                    self.get_value(rhs).unwrap(),
                    &der,
                )?;
                self.deriv_recur(lhs, l_der)?;
                self.deriv_recur(rhs, r_der)?;
            }
        };
        self.accumulate_deriv(node, der)
    }

    /// Adds `der` to the derivative already stored on `node`, or stores it if
    /// there is none yet. Summation reuses the graph's own `add` operator.
    fn accumulate_deriv(&mut self, node: NodeIndex, der: Value<A>) -> Result<()> {
        let total = match self[node].deriv {
            Some(ref prev) => add().eval_value(prev, &der)?,
            None => der,
        };
        self[node].deriv = Some(total);
        Ok(())
    }

    /// Computes the derivative of `node` w.r.t. every node it depends on, by
    /// reverse-mode propagation starting from `node`.
    ///
    /// Results are read with [`Graph::get_deriv`]. Derivatives from any
    /// previous call are cleared first, so nodes that `node` does not depend on
    /// end up with no derivative. The seed is `Value::identity()`, presumably
    /// d(node)/d(node) (confirm against `Value`).
    ///
    /// Propagation follows every path from `node` to each dependency, so a
    /// heavily shared subexpression is visited once per path.
    ///
    /// # Errors
    ///
    /// Propagates any error from an operator's `eval_deriv` or from summing
    /// contributions.
    ///
    /// # Panics
    ///
    /// Panics if a value needed for propagation has not been evaluated; call
    /// `eval_value(node, ..)` first.
    pub fn eval_deriv(&mut self, node: NodeIndex) -> Result<()> {
        // Contributions are summed into `deriv`, so stale results from an
        // earlier call must not leak into this one.
        for idx in self.node_indices() {
            self[idx].deriv = None;
        }
        self.deriv_recur(node, Value::identity())
    }
}