acme_graphs/dcg/
graph.rs

1/*
2    Appellation: graph <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use super::edge::Edge;
6use super::node::Node;
7use super::DynamicGraph;
8use crate::prelude::GraphResult as Result;
9use crate::NodeIndex;
10
11use acme::ops::{Arithmetic, BinaryOp, Op, UnaryOp};
12use acme::prelude::Scalar;
13use core::ops::Index;
14use petgraph::algo::toposort;
15use std::collections::HashMap;
16
/// Shorthand for the `HashMap::entry` API on a map named by a plain identifier.
///
/// - `entry!(map[key])` expands to `map.entry(key).or_default()`;
/// - `entry!(map[key], value)` expands to `map.entry(key).or_insert(value)`;
/// - `entry!(@base map[key].method(arg))` expands to `map.entry(key).method(arg)`;
/// - `entry!(@base map[key])` expands to `map.entry(key)`.
macro_rules! entry {
    ($ctx:ident[$key:expr]) => {
        entry!(@base $ctx[$key]).or_default()
    };
    ($ctx:ident[$key:expr], $val:expr) => {
        entry!(@base $ctx[$key].or_insert($val))
    };
    // Forward an arbitrary `Entry` method call. The transcriber must use `$call`
    // (not `$call:ident`) and start from the bare `entry()` call, not `or_default()`.
    (@base $ctx:ident[$key:expr].$call:ident($val:expr)) => {
        entry!(@base $ctx[$key]).$call($val)
    };
    (@base $ctx:ident[$key:expr]) => {
        $ctx.entry($key)
    };
}
32
/// Pushes `(key, value)` tuples onto anything with a `push((K, V))` method.
///
/// - `push!(stack, (k1, v1), (k2, v2), ...)` pushes every pair, in order;
/// - `push!(stack, k, v)` pushes a single pair;
/// - `push!(@impl stack, k, v)` is the internal rule both forms delegate to.
///
/// The target expression is re-evaluated once per pair pushed.
macro_rules! push {
    // Internal rule: the actual push.
    (@impl $target:expr, $k:expr, $v:expr) => {
        $target.push(($k, $v))
    };
    // The pair-list form must be tried before the three-expression form, otherwise
    // `(k, v)` tuples would be parsed as plain key/value expressions.
    ($target:expr, $(($k:expr, $v:expr)),*) => {
        $(push!(@impl $target, $k, $v);)*
    };
    ($target:expr, $k:expr, $v:expr) => {
        push!(@impl $target, $k, $v)
    };
}
46
/// Generates one `pub fn name(&mut self, lhs: NodeIndex, rhs: NodeIndex) -> NodeIndex`
/// method per listed name, each a shorthand for `self.binary(lhs, rhs, BinaryOp::name())`.
/// Intended to be invoked inside an `impl` block that provides `binary`.
macro_rules! binop {
    (@impl $name:ident) => {
        #[doc = concat!(
            "Shorthand for `self.binary(lhs, rhs, BinaryOp::",
            stringify!($name),
            "())`."
        )]
        pub fn $name(&mut self, lhs: NodeIndex, rhs: NodeIndex) -> NodeIndex {
            let op = BinaryOp::$name();
            self.binary(lhs, rhs, op)
        }
    };
    ($($name:ident),*) => {
        $(binop!(@impl $name);)*
    };
}
57
/// Generates one `pub fn name(&mut self, recv: NodeIndex) -> NodeIndex` method per
/// listed name, each a shorthand for `self.unary(recv, UnaryOp::name())`.
/// Intended to be invoked inside an `impl` block that provides `unary`.
macro_rules! unop {
    (@impl $name:ident) => {
        #[doc = concat!(
            "Shorthand for `self.unary(recv, UnaryOp::",
            stringify!($name),
            "())`."
        )]
        pub fn $name(&mut self, recv: NodeIndex) -> NodeIndex {
            let op = UnaryOp::$name();
            self.unary(recv, op)
        }
    };
    ($($name:ident),*) => {
        $(unop!(@impl $name);)*
    };
}
68
69pub struct Dcg<T> {
70    store: DynamicGraph<T>,
71}
72
73impl<T> Default for Dcg<T> {
74    fn default() -> Self {
75        Self::new()
76    }
77}
78
79impl<T> Dcg<T> {
80    pub fn new() -> Self {
81        Dcg {
82            store: DynamicGraph::new(),
83        }
84    }
85
86    pub fn binary(&mut self, lhs: NodeIndex, rhs: NodeIndex, op: impl Into<BinaryOp>) -> NodeIndex {
87        let c = self.store.add_node(Node::binary(lhs, rhs, op));
88        self.store.add_edge(lhs, c, Edge::new([rhs], lhs));
89        self.store.add_edge(rhs, c, Edge::new([lhs], rhs));
90        c
91    }
92
93    pub fn constant(&mut self, value: T) -> NodeIndex {
94        self.input(false, value)
95    }
96
97    pub fn get(&self, index: NodeIndex) -> Option<&Node<T>> {
98        self.store.node_weight(index)
99    }
100
101    pub fn include(&mut self, node: impl Into<Node<T>>) -> NodeIndex {
102        self.store.add_node(node.into())
103    }
104
105    pub fn input(&mut self, param: bool, value: T) -> NodeIndex {
106        self.store.add_node(Node::input(param, value))
107    }
108
109    pub fn op(
110        &mut self,
111        inputs: impl IntoIterator<Item = NodeIndex>,
112        op: impl Into<Op>,
113    ) -> NodeIndex {
114        let args = Vec::from_iter(inputs);
115
116        let c = self.store.add_node(Node::op(args.clone(), op));
117        for arg in args.iter() {
118            self.store.add_edge(*arg, c, Edge::new(args.clone(), *arg));
119        }
120        c
121    }
122
123    pub fn remove(&mut self, index: NodeIndex) -> Option<Node<T>> {
124        self.store.remove_node(index)
125    }
126
127    pub fn unary(&mut self, input: NodeIndex, op: impl Into<UnaryOp>) -> NodeIndex {
128        let c = self.store.add_node(Node::unary(input, op));
129        self.store.add_edge(input, c, Edge::new([input], input));
130        c
131    }
132
133    pub fn variable(&mut self, value: T) -> NodeIndex {
134        self.input(true, value)
135    }
136
137    binop!(add, div, mul, pow, rem, sub);
138
139    unop!(
140        abs, acos, acosh, asin, asinh, atan, atanh, ceil, cos, cosh, exp, floor, inv, ln, neg,
141        recip, sin, sinh, sqr, sqrt, tan, tanh
142    );
143}
144
145impl<T> Dcg<T>
146where
147    T: Scalar<Real = T>,
148{
149    pub fn backward(&self) -> Result<HashMap<NodeIndex, T>> {
150        let sorted = toposort(&self.store, None)?;
151        let target = sorted.last().unwrap();
152        self.gradient(*target)
153    }
154    pub fn gradient(&self, target: NodeIndex) -> Result<HashMap<NodeIndex, T>> {
155        let mut store = HashMap::<NodeIndex, T>::new();
156        // initialize the stack
157        let mut stack = Vec::<(NodeIndex, T)>::new();
158        // start by computing the gradient of the target w.r.t. itself
159        stack.push((target, T::one()));
160        store.insert(target, T::one());
161
162        while let Some((i, grad)) = stack.pop() {
163            let node = &self[i];
164
165            match node {
166                Node::Binary { lhs, rhs, op } => match op {
167                    BinaryOp::Arith(inner) => match inner {
168                        Arithmetic::Add(_) => {
169                            *entry!(store[*lhs]) += grad;
170                            *entry!(store[*rhs]) += grad;
171
172                            push!(stack, (*lhs, grad), (*rhs, grad));
173                        }
174                        Arithmetic::Div(_) => {
175                            let lhs_grad = grad / self[*rhs].value();
176                            let rhs_grad = grad * self[*lhs].value() / self[*rhs].value().powi(2);
177                            *entry!(store[*lhs]) += lhs_grad;
178                            *entry!(store[*rhs]) += rhs_grad;
179
180                            push!(stack, (*lhs, lhs_grad), (*rhs, rhs_grad));
181                        }
182                        Arithmetic::Mul(_) => {
183                            let lhs_grad = grad * self[*rhs].value();
184                            let rhs_grad = grad * self[*lhs].value();
185                            *entry!(store[*lhs]) += lhs_grad;
186                            *entry!(store[*rhs]) += rhs_grad;
187                            push!(stack, (*lhs, lhs_grad), (*rhs, rhs_grad));
188                        }
189                        Arithmetic::Pow(_) => {
190                            let lhs_grad = grad
191                                * self[*rhs].value()
192                                * self[*lhs].value().powf(self[*rhs].value() - T::one());
193                            let rhs_grad = grad
194                                * self[*lhs].value().powf(self[*rhs].value())
195                                * (self[*lhs].value().ln());
196                            *entry!(store[*lhs]) += lhs_grad;
197                            *entry!(store[*rhs]) += rhs_grad;
198
199                            push!(stack, (*lhs, lhs_grad), (*rhs, rhs_grad));
200                        }
201                        Arithmetic::Sub(_) => {
202                            *entry!(store[*lhs]) += grad;
203                            *entry!(store[*rhs]) -= grad;
204
205                            push!(stack, (*lhs, grad), (*rhs, -grad));
206                        }
207                        _ => todo!(),
208                    },
209                    _ => todo!(),
210                },
211                Node::Unary { .. } => {
212                    unimplemented!();
213                }
214                Node::Input { param, .. } => {
215                    if *param {
216                        continue;
217                    }
218                    *store.entry(i).or_default() += grad;
219                    stack.push((i, grad));
220                }
221                _ => {}
222            }
223        }
224
225        Ok(store)
226    }
227}
228
229impl<T> Index<NodeIndex> for Dcg<T> {
230    type Output = Node<T>;
231
232    fn index(&self, index: NodeIndex) -> &Self::Output {
233        self.get(index).unwrap()
234    }
235}