use super::{Edge, Node, Operation};
use crate::prelude::GraphResult as Result;
use num::traits::{NumAssign, NumOps, Signed};
use petgraph::algo::toposort;
use petgraph::prelude::{DiGraph, NodeIndex};
use rsdiff::ops::{Arithmetic, BinaryOp, Op, UnaryOp};
use std::collections::BTreeMap;
/// Ordered map from a node's index in the graph to the value stored for that node.
pub(crate) type ValueStore<T> = BTreeMap<NodeIndex, T>;
/// A computational graph over values of type `T`.
///
/// Pairs a directed graph of `Node`s (connected by `Edge<T>` weights) with a
/// store of values keyed by node index.
#[derive(Clone, Debug)]
pub struct Scg<T> {
/// Directed graph holding the nodes of the computation.
graph: DiGraph<Node, Edge<T>>,
/// Values recorded for nodes, keyed by their index in `graph`.
vals: ValueStore<T>,
}
impl<T> Default for Scg<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> Scg<T> {
    /// Creates an empty graph with no nodes and no stored values.
    pub fn new() -> Self {
        Self {
            graph: DiGraph::new(),
            vals: BTreeMap::new(),
        }
    }

    /// Removes every node and edge from the graph and drops all stored values.
    ///
    /// The value store has to be emptied together with the graph: petgraph
    /// hands out node indices again from the start once a graph is cleared,
    /// so leftover entries would be reported by `get_value` for nodes that no
    /// longer exist.
    pub fn clear(&mut self) {
        self.graph.clear();
        self.vals.clear();
    }

    /// Returns the node stored at `index`, or `None` if the graph has no such node.
    pub fn get(&self, index: NodeIndex) -> Option<&Node> {
        self.graph.node_weight(index)
    }

    /// Returns the value recorded for `index`, or `None` if nothing is stored for it.
    pub fn get_value(&self, index: NodeIndex) -> Option<&T> {
        self.vals.get(&index)
    }

    /// Adds a node created with `Node::placeholder(name)`, records `data` as
    /// its value and returns the index of the new node.
    pub fn constant(&mut self, name: impl ToString, data: T) -> NodeIndex {
        let v = self.graph.add_node(Node::placeholder(name));
        self.vals.insert(v, data);
        v
    }

    /// Adds an operation node over `inputs` and returns its index.
    ///
    /// `result` is stored as the node's value; when it is `None` the value
    /// falls back to `T::default()`.
    ///
    /// # Errors
    /// The current body always returns `Ok`; the `Result` is part of the
    /// existing signature.
    pub fn operation(
        &mut self,
        inputs: impl IntoIterator<Item = NodeIndex>,
        operation: impl Into<Op>,
        result: Option<T>,
    ) -> Result<NodeIndex>
    where
        T: Default,
    {
        // NOTE(review): the inputs are only recorded inside the node; no graph
        // edges are added here. `backward` picks its target from a topological
        // sort of the graph - confirm whether edges should be inserted too.
        let node = Node::Operation(Operation::new(inputs, operation));
        let v = self.graph.add_node(node);
        self.vals.insert(v, result.unwrap_or_default());
        Ok(v)
    }

    /// Adds a node created with `Node::default()`, records `value` for it and
    /// returns the index of the new node.
    pub fn variable(&mut self, value: T) -> NodeIndex {
        let v = self.graph.add_node(Node::default());
        self.vals.insert(v, value);
        v
    }
}
impl<T> Scg<T>
where
    T: Copy + Default + NumAssign + NumOps + Signed + 'static,
{
    /// Computes gradients with respect to the last node of the graph's
    /// topological order.
    ///
    /// Returns an empty store when the graph has no nodes.
    ///
    /// # Errors
    /// Propagates the failure reported by `toposort`.
    ///
    /// # Panics
    /// See [`Scg::gradient_at`].
    pub fn backward(&self) -> Result<ValueStore<T>> {
        let order: Vec<NodeIndex> = toposort(&self.graph, None)?;
        match order.last() {
            Some(&target) => self.gradient_at(target),
            // An empty graph has nothing to differentiate.
            None => Ok(ValueStore::new()),
        }
    }

    /// Accumulates the gradient of `target` with respect to every node
    /// reachable from it through operation inputs.
    ///
    /// The walk starts at `target` with a gradient of one. For each input of a
    /// visited node the local derivative is multiplied into the incoming
    /// gradient, added to that input's entry in the returned store, and the
    /// input is queued to be visited with that contribution. Contributions
    /// arriving along different paths are therefore summed.
    ///
    /// Supported operations: binary `Add`, `Sub`, `Mul`, `Div` and unary `Neg`.
    /// For `Sub` and `Div` the input at an even position is treated as the
    /// left-hand operand and the one at an odd position as the right-hand
    /// operand.
    ///
    /// # Errors
    /// The current body always returns `Ok`.
    ///
    /// # Panics
    /// Panics if `target` (or a recorded input) is not a node of the graph, if
    /// a value needed for a derivative is missing from the store, if a
    /// division node has no divisor following its dividend, or if an
    /// unsupported operation is encountered (`todo!`).
    pub fn gradient_at(&self, target: NodeIndex) -> Result<ValueStore<T>> {
        let mut gradients = ValueStore::new();
        gradients.insert(target, T::one());
        // Work list of (node, gradient reaching that node along one path).
        let mut stack = vec![(target, T::one())];

        while let Some((i, grad)) = stack.pop() {
            let node = &self.graph[i];
            if let Some(inputs) = node.inputs() {
                for (j, input) in inputs.iter().enumerate() {
                    let dt = if let Some(op) = node.op() {
                        match op {
                            Op::Binary(base) => match base {
                                BinaryOp::Arith(inner) => match inner {
                                    Arithmetic::Add(_) => grad,
                                    Arithmetic::Div(_) => {
                                        if j % 2 == 0 {
                                            // d(a / b)/da = 1 / b, where b is the
                                            // input that follows the dividend.
                                            let divisor = inputs
                                                .iter()
                                                .nth(j + 1)
                                                .expect("division node is missing its divisor input");
                                            grad / self.vals[divisor]
                                        } else {
                                            // d(a / b)/db = -a / b^2 = -(a / b) / b;
                                            // a / b is this node's stored output.
                                            -grad * self.vals[&i] / self.vals[input]
                                        }
                                    }
                                    // The derivative with respect to one factor is
                                    // the product of the remaining factors. Using
                                    // the sibling values directly (instead of
                                    // `out / val`) stays defined when a factor is
                                    // zero.
                                    Arithmetic::Mul(_) => inputs
                                        .iter()
                                        .enumerate()
                                        .filter(|(k, _)| *k != j)
                                        .fold(grad, |acc, (_, other)| acc * self.vals[other]),
                                    Arithmetic::Sub(_) => {
                                        if j % 2 == 0 {
                                            grad
                                        } else {
                                            grad.neg()
                                        }
                                    }
                                    _ => todo!(),
                                },
                                _ => todo!(),
                            },
                            Op::Unary(base) => match base {
                                UnaryOp::Neg => -grad,
                                _ => todo!(),
                            },
                            _ => todo!(),
                        }
                    } else {
                        // A node that lists inputs but carries no operation
                        // contributes `T::default()`.
                        T::default()
                    };
                    *gradients.entry(*input).or_default() += dt;
                    stack.push((*input, dt));
                }
            }
        }
        Ok(gradients)
    }
}
impl<T> Scg<T>
where
    T: Copy + Default + NumOps + PartialOrd,
{
    /// Adds an addition node over `a` and `b` whose stored value is the sum of
    /// the values recorded for the two operands.
    ///
    /// # Panics
    /// Panics if either operand has no value in the store.
    pub fn add(&mut self, a: NodeIndex, b: NodeIndex) -> Result<NodeIndex> {
        let sum = self.vals[&a] + self.vals[&b];
        self.operation([a, b], BinaryOp::add(), Some(sum))
    }

    /// Adds a multiplication node over `a` and `b` whose stored value is the
    /// product of the values recorded for the two operands.
    ///
    /// # Panics
    /// Panics if either operand has no value in the store.
    pub fn mul(&mut self, a: NodeIndex, b: NodeIndex) -> Result<NodeIndex> {
        let product = self.vals[&a] * self.vals[&b];
        self.operation([a, b], BinaryOp::mul(), Some(product))
    }
}