1use super::edge::Edge;
6use super::node::Node;
7use super::DynamicGraph;
8use crate::prelude::GraphResult as Result;
9use crate::NodeIndex;
10
11use acme::ops::{Arithmetic, BinaryOp, Op, UnaryOp};
12use acme::prelude::Scalar;
13use core::ops::Index;
14use petgraph::algo::toposort;
15use std::collections::HashMap;
16
/// Shorthand for the map entry API: yields a mutable reference to the value
/// stored under a key, inserting a value first when the key is absent.
///
/// * `entry!(map[key])` expands to `map.entry(key).or_default()`.
/// * `entry!(map[key], value)` expands to `map.entry(key).or_insert(value)`.
///
/// The `@base` arms are internal building blocks and are not meant to be
/// invoked directly.
macro_rules! entry {
    // Insert the default value when `$key` is missing.
    ($ctx:ident[$key:expr]) => {
        entry!(@base $ctx[$key]).or_default()
    };
    // Insert `$val` when `$key` is missing.
    ($ctx:ident[$key:expr], $val:expr) => {
        entry!(@base $ctx[$key].or_insert($val))
    };
    // Internal: the raw entry followed by one single-argument method call.
    // This must build on the raw `@base` entry; going through the public arm
    // would append `.or_default()` before `$call` and fail to type-check.
    (@base $ctx:ident[$key:expr].$call:ident($val:expr)) => {
        entry!(@base $ctx[$key]).$call($val)
    };
    // Internal: the raw entry for `$key`.
    (@base $ctx:ident[$key:expr]) => {
        $ctx.entry($key)
    };
}
32
/// Pushes `(key, value)` tuples onto anything that exposes a `push` method.
///
/// * `push!(target, key, value)` pushes a single tuple.
/// * `push!(target, (k1, v1), (k2, v2), ...)` pushes every listed pair in order.
///
/// The `target` expression is re-evaluated for every pushed pair.
macro_rules! push {
    // Internal: a single push of one tuple.
    (@impl $target:expr, $k:expr, $v:expr) => {
        $target.push(($k, $v))
    };
    // Any number of parenthesised pairs. This arm must stay ahead of the
    // bare-argument arm: two pairs would otherwise parse as two tuple
    // expressions and be pushed as one nested tuple.
    ($target:expr, $(($k:expr, $v:expr)),*) => {
        $($target.push(($k, $v));)*
    };
    // One pair written as two bare arguments.
    ($target:expr, $k:expr, $v:expr) => {
        push!(@impl $target, $k, $v)
    };
}
46
/// Generates, for every listed name, a `pub fn name(&mut self, lhs, rhs)`
/// method that forwards to `self.binary(lhs, rhs, BinaryOp::name())`.
///
/// Intended to be invoked inside an `impl` block that provides `binary`.
macro_rules! binop {
    // Internal: emit the forwarding method for a single operation name.
    (@impl $name:ident) => {
        pub fn $name(&mut self, lhs: NodeIndex, rhs: NodeIndex) -> NodeIndex {
            self.binary(lhs, rhs, BinaryOp::$name())
        }
    };
    // Public entry point: a comma-separated list of operation names.
    ($($name:ident),*) => {
        $(binop!(@impl $name);)*
    };
}
57
/// Generates, for every listed name, a `pub fn name(&mut self, recv)` method
/// that forwards to `self.unary(recv, UnaryOp::name())`.
///
/// Intended to be invoked inside an `impl` block that provides `unary`.
macro_rules! unop {
    // Internal: emit the forwarding method for a single operation name.
    (@impl $name:ident) => {
        pub fn $name(&mut self, recv: NodeIndex) -> NodeIndex {
            self.unary(recv, UnaryOp::$name())
        }
    };
    // Public entry point: a comma-separated list of operation names.
    ($($name:ident),*) => {
        $(unop!(@impl $name);)*
    };
}
68
69pub struct Dcg<T> {
70 store: DynamicGraph<T>,
71}
72
73impl<T> Default for Dcg<T> {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl<T> Dcg<T> {
80 pub fn new() -> Self {
81 Dcg {
82 store: DynamicGraph::new(),
83 }
84 }
85
86 pub fn binary(&mut self, lhs: NodeIndex, rhs: NodeIndex, op: impl Into<BinaryOp>) -> NodeIndex {
87 let c = self.store.add_node(Node::binary(lhs, rhs, op));
88 self.store.add_edge(lhs, c, Edge::new([rhs], lhs));
89 self.store.add_edge(rhs, c, Edge::new([lhs], rhs));
90 c
91 }
92
93 pub fn constant(&mut self, value: T) -> NodeIndex {
94 self.input(false, value)
95 }
96
97 pub fn get(&self, index: NodeIndex) -> Option<&Node<T>> {
98 self.store.node_weight(index)
99 }
100
101 pub fn include(&mut self, node: impl Into<Node<T>>) -> NodeIndex {
102 self.store.add_node(node.into())
103 }
104
105 pub fn input(&mut self, param: bool, value: T) -> NodeIndex {
106 self.store.add_node(Node::input(param, value))
107 }
108
109 pub fn op(
110 &mut self,
111 inputs: impl IntoIterator<Item = NodeIndex>,
112 op: impl Into<Op>,
113 ) -> NodeIndex {
114 let args = Vec::from_iter(inputs);
115
116 let c = self.store.add_node(Node::op(args.clone(), op));
117 for arg in args.iter() {
118 self.store.add_edge(*arg, c, Edge::new(args.clone(), *arg));
119 }
120 c
121 }
122
123 pub fn remove(&mut self, index: NodeIndex) -> Option<Node<T>> {
124 self.store.remove_node(index)
125 }
126
127 pub fn unary(&mut self, input: NodeIndex, op: impl Into<UnaryOp>) -> NodeIndex {
128 let c = self.store.add_node(Node::unary(input, op));
129 self.store.add_edge(input, c, Edge::new([input], input));
130 c
131 }
132
133 pub fn variable(&mut self, value: T) -> NodeIndex {
134 self.input(true, value)
135 }
136
137 binop!(add, div, mul, pow, rem, sub);
138
139 unop!(
140 abs, acos, acosh, asin, asinh, atan, atanh, ceil, cos, cosh, exp, floor, inv, ln, neg,
141 recip, sin, sinh, sqr, sqrt, tan, tanh
142 );
143}
144
145impl<T> Dcg<T>
146where
147 T: Scalar<Real = T>,
148{
149 pub fn backward(&self) -> Result<HashMap<NodeIndex, T>> {
150 let sorted = toposort(&self.store, None)?;
151 let target = sorted.last().unwrap();
152 self.gradient(*target)
153 }
154 pub fn gradient(&self, target: NodeIndex) -> Result<HashMap<NodeIndex, T>> {
155 let mut store = HashMap::<NodeIndex, T>::new();
156 let mut stack = Vec::<(NodeIndex, T)>::new();
158 stack.push((target, T::one()));
160 store.insert(target, T::one());
161
162 while let Some((i, grad)) = stack.pop() {
163 let node = &self[i];
164
165 match node {
166 Node::Binary { lhs, rhs, op } => match op {
167 BinaryOp::Arith(inner) => match inner {
168 Arithmetic::Add(_) => {
169 *entry!(store[*lhs]) += grad;
170 *entry!(store[*rhs]) += grad;
171
172 push!(stack, (*lhs, grad), (*rhs, grad));
173 }
174 Arithmetic::Div(_) => {
175 let lhs_grad = grad / self[*rhs].value();
176 let rhs_grad = grad * self[*lhs].value() / self[*rhs].value().powi(2);
177 *entry!(store[*lhs]) += lhs_grad;
178 *entry!(store[*rhs]) += rhs_grad;
179
180 push!(stack, (*lhs, lhs_grad), (*rhs, rhs_grad));
181 }
182 Arithmetic::Mul(_) => {
183 let lhs_grad = grad * self[*rhs].value();
184 let rhs_grad = grad * self[*lhs].value();
185 *entry!(store[*lhs]) += lhs_grad;
186 *entry!(store[*rhs]) += rhs_grad;
187 push!(stack, (*lhs, lhs_grad), (*rhs, rhs_grad));
188 }
189 Arithmetic::Pow(_) => {
190 let lhs_grad = grad
191 * self[*rhs].value()
192 * self[*lhs].value().powf(self[*rhs].value() - T::one());
193 let rhs_grad = grad
194 * self[*lhs].value().powf(self[*rhs].value())
195 * (self[*lhs].value().ln());
196 *entry!(store[*lhs]) += lhs_grad;
197 *entry!(store[*rhs]) += rhs_grad;
198
199 push!(stack, (*lhs, lhs_grad), (*rhs, rhs_grad));
200 }
201 Arithmetic::Sub(_) => {
202 *entry!(store[*lhs]) += grad;
203 *entry!(store[*rhs]) -= grad;
204
205 push!(stack, (*lhs, grad), (*rhs, -grad));
206 }
207 _ => todo!(),
208 },
209 _ => todo!(),
210 },
211 Node::Unary { .. } => {
212 unimplemented!();
213 }
214 Node::Input { param, .. } => {
215 if *param {
216 continue;
217 }
218 *store.entry(i).or_default() += grad;
219 stack.push((i, grad));
220 }
221 _ => {}
222 }
223 }
224
225 Ok(store)
226 }
227}
228
229impl<T> Index<NodeIndex> for Dcg<T> {
230 type Output = Node<T>;
231
232 fn index(&self, index: NodeIndex) -> &Self::Output {
233 self.get(index).unwrap()
234 }
235}