rsdiff_graphs/dcg/graph.rs

/*
    Appellation: graph <module>
    Contrib: FL03 <jo3mccain@icloud.com>
*/
5use super::DynamicGraph;
6use super::edge::Edge;
7use super::node::Node;
8use crate::NodeIndex;
9use crate::prelude::GraphResult as Result;
10
11use core::ops::Index;
12use petgraph::algo::toposort;
13use rsdiff::ops::{Arithmetic, BinaryOp, Op, UnaryOp};
14use rsdiff::prelude::Scalar;
15use std::collections::HashMap;
16
/// Pushes `(key, value)` pairs onto anything with a `push` method taking a tuple.
///
/// Two call shapes are accepted:
/// - `push!(ctx, (k1, v1), (k2, v2), ...)` pushes every parenthesised pair in order;
/// - `push!(ctx, k, v)` pushes a single bare pair.
macro_rules! push {
    // Internal rule: push exactly one pair.
    (@impl $target:expr, $k:expr, $v:expr) => {
        $target.push(($k, $v))
    };
    // Zero or more parenthesised pairs, pushed left to right.
    ($target:expr, $(($k:expr, $v:expr)),*) => {
        $(push!(@impl $target, $k, $v);)*
    };
    // A single un-parenthesised pair.
    ($target:expr, $k:expr, $v:expr) => {
        push!(@impl $target, $k, $v)
    };
}
30
/// Generates binary-operation convenience methods for a graph builder.
///
/// `binop!(add, mul)` expands, inside an `impl` block, to
/// `pub fn add(&mut self, lhs: NodeIndex, rhs: NodeIndex) -> NodeIndex` (and likewise
/// `mul`), each delegating to `self.binary(lhs, rhs, BinaryOp::<name>())`.
macro_rules! binop {
    // Internal rule: emit one delegating method.
    (@impl $name:ident) => {
        #[doc = concat!(
            "Records `", stringify!($name), "(lhs, rhs)` as a new binary node and returns its index."
        )]
        pub fn $name(&mut self, lhs: NodeIndex, rhs: NodeIndex) -> NodeIndex {
            self.binary(lhs, rhs, BinaryOp::$name())
        }
    };
    // Public entry point: a comma-separated list of operation names.
    ($($name:ident),*) => {
        $(binop!(@impl $name);)*
    };
}
41
/// Generates unary-operation convenience methods for a graph builder.
///
/// `unop!(neg, exp)` expands, inside an `impl` block, to
/// `pub fn neg(&mut self, recv: NodeIndex) -> NodeIndex` (and likewise `exp`),
/// each delegating to `self.unary(recv, UnaryOp::<name>())`.
macro_rules! unop {
    // Internal rule: emit one delegating method.
    (@impl $name:ident) => {
        #[doc = concat!(
            "Records `", stringify!($name), "(recv)` as a new unary node and returns its index."
        )]
        pub fn $name(&mut self, recv: NodeIndex) -> NodeIndex {
            self.unary(recv, UnaryOp::$name())
        }
    };
    // Public entry point: a comma-separated list of operation names.
    ($($name:ident),*) => {
        $(unop!(@impl $name);)*
    };
}
52
53pub struct Dcg<T> {
54    store: DynamicGraph<T>,
55}
56
57impl<T> Default for Dcg<T> {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl<T> Dcg<T> {
64    pub fn new() -> Self {
65        Dcg {
66            store: DynamicGraph::new(),
67        }
68    }
69
70    pub fn binary(&mut self, lhs: NodeIndex, rhs: NodeIndex, op: impl Into<BinaryOp>) -> NodeIndex {
71        let c = self.store.add_node(Node::binary(lhs, rhs, op));
72        self.store.add_edge(lhs, c, Edge::new([rhs], lhs));
73        self.store.add_edge(rhs, c, Edge::new([lhs], rhs));
74        c
75    }
76
77    pub fn constant(&mut self, value: T) -> NodeIndex {
78        self.input(false, value)
79    }
80
81    pub fn get(&self, index: NodeIndex) -> Option<&Node<T>> {
82        self.store.node_weight(index)
83    }
84
85    pub fn include(&mut self, node: impl Into<Node<T>>) -> NodeIndex {
86        self.store.add_node(node.into())
87    }
88
89    pub fn input(&mut self, param: bool, value: T) -> NodeIndex {
90        self.store.add_node(Node::input(param, value))
91    }
92
93    pub fn op(
94        &mut self,
95        inputs: impl IntoIterator<Item = NodeIndex>,
96        op: impl Into<Op>,
97    ) -> NodeIndex {
98        let args = Vec::from_iter(inputs);
99
100        let c = self.store.add_node(Node::op(args.clone(), op));
101        for arg in args.iter() {
102            self.store.add_edge(*arg, c, Edge::new(args.clone(), *arg));
103        }
104        c
105    }
106
107    pub fn remove(&mut self, index: NodeIndex) -> Option<Node<T>> {
108        self.store.remove_node(index)
109    }
110
111    pub fn unary(&mut self, input: NodeIndex, op: impl Into<UnaryOp>) -> NodeIndex {
112        let c = self.store.add_node(Node::unary(input, op));
113        self.store.add_edge(input, c, Edge::new([input], input));
114        c
115    }
116
117    pub fn variable(&mut self, value: T) -> NodeIndex {
118        self.input(true, value)
119    }
120
121    binop!(add, div, mul, pow, rem, sub);
122
123    unop!(
124        abs, acos, acosh, asin, asinh, atan, atanh, ceil, cos, cosh, exp, floor, inv, ln, neg,
125        recip, sin, sinh, sqr, sqrt, tan, tanh
126    );
127}
128
129impl<T> Dcg<T>
130where
131    T: Scalar<Real = T>,
132{
133    pub fn backward(&self) -> Result<HashMap<NodeIndex, T>> {
134        let sorted = toposort(&self.store, None)?;
135        let target = sorted.last().unwrap();
136        self.gradient(*target)
137    }
138    pub fn gradient(&self, target: NodeIndex) -> Result<HashMap<NodeIndex, T>> {
139        let mut store = HashMap::<NodeIndex, T>::new();
140        // initialize the stack
141        let mut stack = Vec::<(NodeIndex, T)>::new();
142        // start by computing the gradient of the target w.r.t. itself
143        stack.push((target, T::one()));
144        store.insert(target, T::one());
145
146        while let Some((i, grad)) = stack.pop() {
147            let node = &self[i];
148
149            match node {
150                Node::Binary { lhs, rhs, op } => match op {
151                    BinaryOp::Arith(inner) => match inner {
152                        Arithmetic::Add(_) => {
153                            *entry!(store[*lhs]) += grad;
154                            *entry!(store[*rhs]) += grad;
155
156                            push!(stack, (*lhs, grad), (*rhs, grad));
157                        }
158                        Arithmetic::Div(_) => {
159                            let lhs_grad = grad / self[*rhs].value();
160                            let rhs_grad = grad * self[*lhs].value() / self[*rhs].value().powi(2);
161                            *entry!(store[*lhs]) += lhs_grad;
162                            *entry!(store[*rhs]) += rhs_grad;
163
164                            push!(stack, (*lhs, lhs_grad), (*rhs, rhs_grad));
165                        }
166                        Arithmetic::Mul(_) => {
167                            let lhs_grad = grad * self[*rhs].value();
168                            let rhs_grad = grad * self[*lhs].value();
169                            *entry!(store[*lhs]) += lhs_grad;
170                            *entry!(store[*rhs]) += rhs_grad;
171                            push!(stack, (*lhs, lhs_grad), (*rhs, rhs_grad));
172                        }
173                        Arithmetic::Pow(_) => {
174                            let lhs_grad = grad
175                                * self[*rhs].value()
176                                * self[*lhs].value().powf(self[*rhs].value() - T::one());
177                            let rhs_grad = grad
178                                * self[*lhs].value().powf(self[*rhs].value())
179                                * (self[*lhs].value().ln());
180                            *entry!(store[*lhs]) += lhs_grad;
181                            *entry!(store[*rhs]) += rhs_grad;
182
183                            push!(stack, (*lhs, lhs_grad), (*rhs, rhs_grad));
184                        }
185                        Arithmetic::Sub(_) => {
186                            *entry!(store[*lhs]) += grad;
187                            *entry!(store[*rhs]) -= grad;
188
189                            push!(stack, (*lhs, grad), (*rhs, -grad));
190                        }
191                        _ => todo!(),
192                    },
193                    _ => todo!(),
194                },
195                Node::Unary { .. } => {
196                    unimplemented!();
197                }
198                Node::Input { param, .. } => {
199                    if *param {
200                        continue;
201                    }
202                    *store.entry(i).or_default() += grad;
203                    stack.push((i, grad));
204                }
205                _ => {}
206            }
207        }
208
209        Ok(store)
210    }
211}
212
213impl<T> Index<NodeIndex> for Dcg<T> {
214    type Output = Node<T>;
215
216    fn index(&self, index: NodeIndex) -> &Self::Output {
217        self.get(index).unwrap()
218    }
219}