rsdiff_graphs/scg/
graph.rs

1/*
2    Appellation: graph <module>
3    Contrib: FL03 <jo3mccain@icloud.com>
4*/
5use super::{Edge, Node, Operation};
6use crate::prelude::GraphResult as Result;
7use num::traits::{NumAssign, NumOps, Signed};
8use petgraph::algo::toposort;
9use petgraph::prelude::{DiGraph, NodeIndex};
10use rsdiff::ops::{Arithmetic, BinaryOp, Op, UnaryOp};
11use std::collections::BTreeMap;
12
13pub(crate) type ValueStore<T> = BTreeMap<NodeIndex, T>;
14
15#[derive(Clone, Debug)]
16pub struct Scg<T> {
17    graph: DiGraph<Node, Edge<T>>,
18    vals: ValueStore<T>,
19}
20
21impl<T> Default for Scg<T> {
22    fn default() -> Self {
23        Self::new()
24    }
25}
26
27impl<T> Scg<T> {
28    pub fn new() -> Self {
29        Self {
30            graph: DiGraph::new(),
31            vals: BTreeMap::new(),
32        }
33    }
34
35    pub fn clear(&mut self) {
36        self.graph.clear();
37    }
38
39    pub fn get(&self, index: NodeIndex) -> Option<&Node> {
40        self.graph.node_weight(index)
41    }
42
43    pub fn get_value(&self, index: NodeIndex) -> Option<&T> {
44        self.vals.get(&index)
45    }
46
47    pub fn constant(&mut self, name: impl ToString, data: T) -> NodeIndex {
48        let v = self.graph.add_node(Node::placeholder(name));
49        self.vals.insert(v, data);
50        v
51    }
52
53    pub fn operation(
54        &mut self,
55        inputs: impl IntoIterator<Item = NodeIndex>,
56        operation: impl Into<Op>,
57        result: Option<T>,
58    ) -> Result<NodeIndex>
59    where
60        T: Default,
61    {
62        let op = Operation::new(inputs, operation);
63        let node = Node::Operation(op.clone());
64        let v = self.graph.add_node(node.clone());
65        let _ = self.vals.insert(v, result.unwrap_or_default());
66        Ok(v)
67    }
68
69    pub fn variable(&mut self, value: T) -> NodeIndex {
70        let v = self.graph.add_node(Node::default());
71        self.vals.insert(v, value);
72        v
73    }
74}
75
76impl<T> Scg<T>
77where
78    T: Copy + Default + NumAssign + NumOps + Signed + 'static,
79{
80    pub fn backward(&self) -> Result<ValueStore<T>> {
81        // find the topological order of the graph
82        let nodes: Vec<NodeIndex> = toposort(&self.graph, None)?;
83        // compute the gradient w.r.t. the last topological node
84        self.gradient_at(*nodes.last().unwrap())
85    }
86
87    pub fn gradient_at(&self, target: NodeIndex) -> Result<ValueStore<T>> {
88        // initialize the gradient store
89        let mut gradients = ValueStore::new();
90        // initialize the stack
91        let mut stack = Vec::<(NodeIndex, T)>::new();
92        // start by computing the gradient of the target w.r.t. itself
93        gradients.insert(target, T::one());
94        stack.push((target, T::one()));
95        // iterate through the nodes in reverse topological order
96        while let Some((i, grad)) = stack.pop() {
97            // get the current node
98            let node = &self.graph[i];
99            if let Some(inputs) = node.inputs() {
100                if inputs.is_empty() {
101                    continue;
102                }
103                // iterate through the inputs of the current node
104                for (j, input) in inputs.iter().enumerate() {
105                    // calculate the gradient of each input w.r.t. the current node
106                    let dt = if let Some(op) = node.op() {
107                        match op {
108                            Op::Binary(base) => match base {
109                                BinaryOp::Arith(inner) => match inner {
110                                    Arithmetic::Add(_) => grad,
111                                    Arithmetic::Div(_) => {
112                                        let out = self.vals[&i];
113                                        let val = self.vals[input];
114                                        if j % 2 == 0 {
115                                            grad / val
116                                        } else {
117                                            -grad * out / (val * val)
118                                        }
119                                    }
120                                    Arithmetic::Mul(_) => {
121                                        let out = self.vals[&i];
122                                        let val = self.vals[input];
123                                        grad * out / val
124                                    }
125                                    Arithmetic::Sub(_) => {
126                                        if j % 2 == 0 {
127                                            grad
128                                        } else {
129                                            grad.neg()
130                                        }
131                                    }
132                                    _ => todo!(),
133                                },
134                                _ => todo!(),
135                            },
136                            Op::Unary(base) => match base {
137                                UnaryOp::Neg => -grad,
138                                _ => todo!(),
139                            },
140                            _ => todo!(),
141                        }
142                    } else {
143                        T::default()
144                    };
145                    // add or insert the gradient of the input
146                    *gradients.entry(*input).or_default() += dt;
147                    // push the input and its gradient onto the stack
148                    stack.push((*input, dt));
149                }
150            }
151        }
152        Ok(gradients)
153    }
154}
155
156impl<T> Scg<T>
157where
158    T: Copy + Default + NumOps + PartialOrd,
159{
160    pub fn add(&mut self, a: NodeIndex, b: NodeIndex) -> Result<NodeIndex> {
161        let x = self.vals[&a];
162        let y = self.vals[&b];
163        let op = BinaryOp::add();
164        let res = x + y;
165
166        let c = self.operation([a, b], op, Some(res))?;
167        Ok(c)
168    }
169
170    pub fn mul(&mut self, a: NodeIndex, b: NodeIndex) -> Result<NodeIndex> {
171        let x = self.vals[&a];
172        let y = self.vals[&b];
173        let res = x * y;
174        let c = self.operation([a, b], BinaryOp::mul(), Some(res))?;
175
176        Ok(c)
177    }
178}