// rsdiff_graphs/scg/graph.rs
use super::{Edge, Node, Operation};
use crate::prelude::GraphResult as Result;
use num::traits::{NumAssign, NumOps, Signed};
use petgraph::algo::toposort;
use petgraph::prelude::{DiGraph, NodeIndex};
use rsdiff::ops::{Arithmetic, BinaryOp, Op, UnaryOp};
use std::collections::BTreeMap;
13pub(crate) type ValueStore<T> = BTreeMap<NodeIndex, T>;
14
15#[derive(Clone, Debug)]
16pub struct Scg<T> {
17 graph: DiGraph<Node, Edge<T>>,
18 vals: ValueStore<T>,
19}
20
21impl<T> Default for Scg<T> {
22 fn default() -> Self {
23 Self::new()
24 }
25}
26
27impl<T> Scg<T> {
28 pub fn new() -> Self {
29 Self {
30 graph: DiGraph::new(),
31 vals: BTreeMap::new(),
32 }
33 }
34
35 pub fn clear(&mut self) {
36 self.graph.clear();
37 }
38
39 pub fn get(&self, index: NodeIndex) -> Option<&Node> {
40 self.graph.node_weight(index)
41 }
42
43 pub fn get_value(&self, index: NodeIndex) -> Option<&T> {
44 self.vals.get(&index)
45 }
46
47 pub fn constant(&mut self, name: impl ToString, data: T) -> NodeIndex {
48 let v = self.graph.add_node(Node::placeholder(name));
49 self.vals.insert(v, data);
50 v
51 }
52
53 pub fn operation(
54 &mut self,
55 inputs: impl IntoIterator<Item = NodeIndex>,
56 operation: impl Into<Op>,
57 result: Option<T>,
58 ) -> Result<NodeIndex>
59 where
60 T: Default,
61 {
62 let op = Operation::new(inputs, operation);
63 let node = Node::Operation(op.clone());
64 let v = self.graph.add_node(node.clone());
65 let _ = self.vals.insert(v, result.unwrap_or_default());
66 Ok(v)
67 }
68
69 pub fn variable(&mut self, value: T) -> NodeIndex {
70 let v = self.graph.add_node(Node::default());
71 self.vals.insert(v, value);
72 v
73 }
74}
75
76impl<T> Scg<T>
77where
78 T: Copy + Default + NumAssign + NumOps + Signed + 'static,
79{
80 pub fn backward(&self) -> Result<ValueStore<T>> {
81 let nodes: Vec<NodeIndex> = toposort(&self.graph, None)?;
83 self.gradient_at(*nodes.last().unwrap())
85 }
86
87 pub fn gradient_at(&self, target: NodeIndex) -> Result<ValueStore<T>> {
88 let mut gradients = ValueStore::new();
90 let mut stack = Vec::<(NodeIndex, T)>::new();
92 gradients.insert(target, T::one());
94 stack.push((target, T::one()));
95 while let Some((i, grad)) = stack.pop() {
97 let node = &self.graph[i];
99 if let Some(inputs) = node.inputs() {
100 if inputs.is_empty() {
101 continue;
102 }
103 for (j, input) in inputs.iter().enumerate() {
105 let dt = if let Some(op) = node.op() {
107 match op {
108 Op::Binary(base) => match base {
109 BinaryOp::Arith(inner) => match inner {
110 Arithmetic::Add(_) => grad,
111 Arithmetic::Div(_) => {
112 let out = self.vals[&i];
113 let val = self.vals[input];
114 if j % 2 == 0 {
115 grad / val
116 } else {
117 -grad * out / (val * val)
118 }
119 }
120 Arithmetic::Mul(_) => {
121 let out = self.vals[&i];
122 let val = self.vals[input];
123 grad * out / val
124 }
125 Arithmetic::Sub(_) => {
126 if j % 2 == 0 {
127 grad
128 } else {
129 grad.neg()
130 }
131 }
132 _ => todo!(),
133 },
134 _ => todo!(),
135 },
136 Op::Unary(base) => match base {
137 UnaryOp::Neg => -grad,
138 _ => todo!(),
139 },
140 _ => todo!(),
141 }
142 } else {
143 T::default()
144 };
145 *gradients.entry(*input).or_default() += dt;
147 stack.push((*input, dt));
149 }
150 }
151 }
152 Ok(gradients)
153 }
154}
155
156impl<T> Scg<T>
157where
158 T: Copy + Default + NumOps + PartialOrd,
159{
160 pub fn add(&mut self, a: NodeIndex, b: NodeIndex) -> Result<NodeIndex> {
161 let x = self.vals[&a];
162 let y = self.vals[&b];
163 let op = BinaryOp::add();
164 let res = x + y;
165
166 let c = self.operation([a, b], op, Some(res))?;
167 Ok(c)
168 }
169
170 pub fn mul(&mut self, a: NodeIndex, b: NodeIndex) -> Result<NodeIndex> {
171 let x = self.vals[&a];
172 let y = self.vals[&b];
173 let res = x * y;
174 let c = self.operation([a, b], BinaryOp::mul(), Some(res))?;
175
176 Ok(c)
177 }
178}