1use super::DynamicGraph;
6use super::edge::Edge;
7use super::node::Node;
8use crate::NodeIndex;
9use crate::prelude::GraphResult as Result;
10
11use core::ops::Index;
12use petgraph::algo::toposort;
13use rsdiff::ops::{Arithmetic, BinaryOp, Op, UnaryOp};
14use rsdiff::prelude::Scalar;
15use std::collections::HashMap;
16
/// Pushes `(key, value)` tuples onto a `Vec`-like context expression.
///
/// Accepted forms:
/// - `push!(ctx, (k1, v1), (k2, v2), ...)` pushes each parenthesised pair, in order;
/// - `push!(ctx, k, v)` pushes the single tuple `(k, v)`.
///
/// The `@impl` form is an internal helper and is not meant to be called directly.
macro_rules! push {
    // Internal helper: a single `ctx.push((key, val))`.
    (@impl $target:expr, $k:expr, $v:expr) => {
        $target.push(($k, $v))
    };
    // Any number of parenthesised pairs. This arm must stay ahead of the two-argument
    // arm below, otherwise `push!(ctx, (a, b), (c, d))` would be read as one tuple of pairs.
    ($target:expr, $(($k:expr, $v:expr)),*) => {
        $( push!(@impl $target, $k, $v); )*
    };
    // A single pair given as two separate arguments.
    ($target:expr, $k:expr, $v:expr) => {
        push!(@impl $target, $k, $v)
    };
}
30
/// Generates inherent methods that build a binary-operation node.
///
/// For every identifier `name` passed in, this expands to
/// `pub fn name(&mut self, lhs: NodeIndex, rhs: NodeIndex) -> NodeIndex`, which forwards
/// to `self.binary(lhs, rhs, BinaryOp::name())`. It must therefore be invoked inside an
/// `impl` block whose type provides a `binary` method.
macro_rules! binop {
    // Internal helper: emit the method for a single operator.
    (@impl $name:ident) => {
        pub fn $name(&mut self, lhs: NodeIndex, rhs: NodeIndex) -> NodeIndex {
            self.binary(lhs, rhs, BinaryOp::$name())
        }
    };
    // Public entry point: a comma-separated list of operator names.
    ($($name:ident),*) => {
        $( binop!(@impl $name); )*
    };
}
41
/// Generates inherent methods that build a unary-operation node.
///
/// For every identifier `name` passed in, this expands to
/// `pub fn name(&mut self, recv: NodeIndex) -> NodeIndex`, which forwards to
/// `self.unary(recv, UnaryOp::name())`. It must therefore be invoked inside an `impl`
/// block whose type provides a `unary` method.
macro_rules! unop {
    // Internal helper: emit the method for a single operator.
    (@impl $name:ident) => {
        pub fn $name(&mut self, recv: NodeIndex) -> NodeIndex {
            self.unary(recv, UnaryOp::$name())
        }
    };
    // Public entry point: a comma-separated list of operator names.
    ($($name:ident),*) => {
        $( unop!(@impl $name); )*
    };
}
52
53pub struct Dcg<T> {
54 store: DynamicGraph<T>,
55}
56
57impl<T> Default for Dcg<T> {
58 fn default() -> Self {
59 Self::new()
60 }
61}
62
63impl<T> Dcg<T> {
64 pub fn new() -> Self {
65 Dcg {
66 store: DynamicGraph::new(),
67 }
68 }
69
70 pub fn binary(&mut self, lhs: NodeIndex, rhs: NodeIndex, op: impl Into<BinaryOp>) -> NodeIndex {
71 let c = self.store.add_node(Node::binary(lhs, rhs, op));
72 self.store.add_edge(lhs, c, Edge::new([rhs], lhs));
73 self.store.add_edge(rhs, c, Edge::new([lhs], rhs));
74 c
75 }
76
77 pub fn constant(&mut self, value: T) -> NodeIndex {
78 self.input(false, value)
79 }
80
81 pub fn get(&self, index: NodeIndex) -> Option<&Node<T>> {
82 self.store.node_weight(index)
83 }
84
85 pub fn include(&mut self, node: impl Into<Node<T>>) -> NodeIndex {
86 self.store.add_node(node.into())
87 }
88
89 pub fn input(&mut self, param: bool, value: T) -> NodeIndex {
90 self.store.add_node(Node::input(param, value))
91 }
92
93 pub fn op(
94 &mut self,
95 inputs: impl IntoIterator<Item = NodeIndex>,
96 op: impl Into<Op>,
97 ) -> NodeIndex {
98 let args = Vec::from_iter(inputs);
99
100 let c = self.store.add_node(Node::op(args.clone(), op));
101 for arg in args.iter() {
102 self.store.add_edge(*arg, c, Edge::new(args.clone(), *arg));
103 }
104 c
105 }
106
107 pub fn remove(&mut self, index: NodeIndex) -> Option<Node<T>> {
108 self.store.remove_node(index)
109 }
110
111 pub fn unary(&mut self, input: NodeIndex, op: impl Into<UnaryOp>) -> NodeIndex {
112 let c = self.store.add_node(Node::unary(input, op));
113 self.store.add_edge(input, c, Edge::new([input], input));
114 c
115 }
116
117 pub fn variable(&mut self, value: T) -> NodeIndex {
118 self.input(true, value)
119 }
120
121 binop!(add, div, mul, pow, rem, sub);
122
123 unop!(
124 abs, acos, acosh, asin, asinh, atan, atanh, ceil, cos, cosh, exp, floor, inv, ln, neg,
125 recip, sin, sinh, sqr, sqrt, tan, tanh
126 );
127}
128
129impl<T> Dcg<T>
130where
131 T: Scalar<Real = T>,
132{
133 pub fn backward(&self) -> Result<HashMap<NodeIndex, T>> {
134 let sorted = toposort(&self.store, None)?;
135 let target = sorted.last().unwrap();
136 self.gradient(*target)
137 }
138 pub fn gradient(&self, target: NodeIndex) -> Result<HashMap<NodeIndex, T>> {
139 let mut store = HashMap::<NodeIndex, T>::new();
140 let mut stack = Vec::<(NodeIndex, T)>::new();
142 stack.push((target, T::one()));
144 store.insert(target, T::one());
145
146 while let Some((i, grad)) = stack.pop() {
147 let node = &self[i];
148
149 match node {
150 Node::Binary { lhs, rhs, op } => match op {
151 BinaryOp::Arith(inner) => match inner {
152 Arithmetic::Add(_) => {
153 *entry!(store[*lhs]) += grad;
154 *entry!(store[*rhs]) += grad;
155
156 push!(stack, (*lhs, grad), (*rhs, grad));
157 }
158 Arithmetic::Div(_) => {
159 let lhs_grad = grad / self[*rhs].value();
160 let rhs_grad = grad * self[*lhs].value() / self[*rhs].value().powi(2);
161 *entry!(store[*lhs]) += lhs_grad;
162 *entry!(store[*rhs]) += rhs_grad;
163
164 push!(stack, (*lhs, lhs_grad), (*rhs, rhs_grad));
165 }
166 Arithmetic::Mul(_) => {
167 let lhs_grad = grad * self[*rhs].value();
168 let rhs_grad = grad * self[*lhs].value();
169 *entry!(store[*lhs]) += lhs_grad;
170 *entry!(store[*rhs]) += rhs_grad;
171 push!(stack, (*lhs, lhs_grad), (*rhs, rhs_grad));
172 }
173 Arithmetic::Pow(_) => {
174 let lhs_grad = grad
175 * self[*rhs].value()
176 * self[*lhs].value().powf(self[*rhs].value() - T::one());
177 let rhs_grad = grad
178 * self[*lhs].value().powf(self[*rhs].value())
179 * (self[*lhs].value().ln());
180 *entry!(store[*lhs]) += lhs_grad;
181 *entry!(store[*rhs]) += rhs_grad;
182
183 push!(stack, (*lhs, lhs_grad), (*rhs, rhs_grad));
184 }
185 Arithmetic::Sub(_) => {
186 *entry!(store[*lhs]) += grad;
187 *entry!(store[*rhs]) -= grad;
188
189 push!(stack, (*lhs, grad), (*rhs, -grad));
190 }
191 _ => todo!(),
192 },
193 _ => todo!(),
194 },
195 Node::Unary { .. } => {
196 unimplemented!();
197 }
198 Node::Input { param, .. } => {
199 if *param {
200 continue;
201 }
202 *store.entry(i).or_default() += grad;
203 stack.push((i, grad));
204 }
205 _ => {}
206 }
207 }
208
209 Ok(store)
210 }
211}
212
213impl<T> Index<NodeIndex> for Dcg<T> {
214 type Output = Node<T>;
215
216 fn index(&self, index: NodeIndex) -> &Self::Output {
217 self.get(index).unwrap()
218 }
219}