radiate_gp/collections/graphs/
eval.rs

1use super::{Graph, GraphNode, iter::GraphIterator};
2use crate::{Eval, EvalMut, NodeType, node::Node};
3
4/// `GraphReducer` is a struct that is used to evaluate a `Graph` of `Node`s. It uses the `GraphIterator`
5/// to traverse the `Graph` in a sudo-topological order and evaluate the nodes in the correct order.
6///
7/// On the first iteration it caches the order of nodes in the `Graph` and then uses that order to
8/// evaluate the nodes in the correct order. This is a massive performance improvement.
9pub struct GraphEvaluator<'a, T, V> {
10    nodes: &'a [GraphNode<T>],
11    output_size: usize,
12    eval_order: Vec<usize>,
13    outputs: Vec<V>,
14    inputs: Vec<Vec<V>>,
15}
16
17impl<'a, T, V> GraphEvaluator<'a, T, V>
18where
19    T: Eval<[V], V>,
20    V: Default + Clone,
21{
22    /// Creates a new `GraphEvaluator` with the given `Graph`. Will cache the order of nodes in
23    /// the `Graph` on the first iteration. On initialization the `GraphEvaluator` will cache the
24    /// output size of the `Graph` to be used in the `reduce` method and create a vec of `Tracer`
25    /// which will be used to evaluate the `Graph` in the `reduce` method.
26    ///
27    /// # Arguments
28    /// * `graph` - The `Graph` to reduce.
29    pub fn new<N>(graph: &'a N) -> GraphEvaluator<'a, T, V>
30    where
31        N: AsRef<[GraphNode<T>]>,
32    {
33        let nodes = graph.as_ref();
34
35        GraphEvaluator {
36            nodes,
37            output_size: nodes
38                .iter()
39                .filter(|node| node.node_type() == NodeType::Output)
40                .count(),
41            inputs: nodes
42                .iter()
43                .map(|node| vec![V::default(); node.incoming().len()])
44                .collect::<Vec<Vec<V>>>(),
45            eval_order: nodes.iter_topological().map(|node| node.index()).collect(),
46            outputs: vec![V::default(); nodes.len()],
47        }
48    }
49}
50
51/// Implements the `EvalMut` trait for `GraphEvaluator`.
52impl<'a, T, V> EvalMut<[V], Vec<V>> for GraphEvaluator<'a, T, V>
53where
54    T: Eval<[V], V>,
55    V: Clone + Default,
56{
57    /// Evaluates the `Graph` with the given input. Returns the output of the `Graph`.
58    /// The `reduce` method will cache the order of nodes in the `Graph` on the first iteration.
59    /// On subsequent iterations it will use the cached order to evaluate the `Graph` in the correct order.
60    ///
61    /// # Arguments
62    /// * `input` - A `Vec` of `T` to evaluate the `Graph` with.
63    ///
64    ///  # Returns
65    /// * A `Vec` of `T` which is the output of the `Graph`.
66    #[inline]
67    fn eval_mut(&mut self, input: &[V]) -> Vec<V> {
68        let mut output = Vec::with_capacity(self.output_size);
69        for index in self.eval_order.iter() {
70            let node = &self.nodes[*index];
71            if node.incoming().is_empty() {
72                self.outputs[node.index()] = node.eval(input);
73            } else {
74                for (idx, incoming) in node.incoming().iter().enumerate() {
75                    self.inputs[node.index()][idx] = self.outputs[*incoming].clone();
76                }
77
78                self.outputs[node.index()] = node.eval(&self.inputs[node.index()]);
79            }
80
81            if node.node_type() == NodeType::Output {
82                output.push(self.outputs[node.index()].clone());
83            }
84        }
85
86        output
87    }
88}
89
90impl<T, V> Eval<Vec<Vec<V>>, Vec<Vec<V>>> for Graph<T>
91where
92    T: Eval<[V], V>,
93    V: Clone + Default,
94{
95    /// Evaluates the `Graph` with the given input 'Vec<Vec<T>>'. Returns the output of the `Graph` as 'Vec<Vec<T>>'.
96    /// This is inteded to be used when evaluating a batch of inputs.
97    ///
98    /// # Arguments
99    /// * `input` - A `Vec<Vec<T>>` to evaluate the `Graph` with.
100    ///
101    /// # Returns
102    /// * A `Vec<Vec<T>>` which is the output of the `Graph`.
103    #[inline]
104    fn eval(&self, input: &Vec<Vec<V>>) -> Vec<Vec<V>> {
105        let mut output = Vec::with_capacity(self.len());
106        let mut evaluator = GraphEvaluator::new(self);
107
108        for inputs in input.iter() {
109            output.push(evaluator.eval_mut(inputs));
110        }
111
112        output
113    }
114}
115
116impl<T: Eval<[V], V>, V: Clone> Eval<[V], V> for GraphNode<T> {
117    /// Evaluates the `GraphNode` with the given input. Returns the output of the `GraphNode`.
118    /// # Arguments
119    /// * `inputs` - A `Vec` of `T` to evaluate the `GraphNode` with.
120    ///
121    /// # Returns
122    /// * A `T` which is the output of the `GraphNode`.
123    #[inline]
124    fn eval(&self, inputs: &[V]) -> V {
125        self.value().eval(inputs)
126    }
127}
128
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{Graph, Op};

    /// Builds a four-node graph computing `linear(var(0) + 5)`: input index 0 and the constant
    /// `5` feed an `add` vertex, whose result feeds a single `linear` output node.
    fn plus_five_graph() -> Graph<Op<f32>> {
        let mut graph = Graph::<Op<f32>>::default();

        let idx_one = graph.insert(NodeType::Input, Op::var(0));
        let idx_two = graph.insert(NodeType::Input, Op::constant(5_f32));
        let idx_three = graph.insert(NodeType::Vertex, Op::add());
        let idx_four = graph.insert(NodeType::Output, Op::linear());

        graph
            .attach(idx_one, idx_three)
            .attach(idx_two, idx_three)
            .attach(idx_three, idx_four);

        graph
    }

    #[test]
    fn test_graph_eval_simple() {
        let graph = plus_five_graph();

        let six = graph.eval(&vec![vec![1_f32]]);
        let seven = graph.eval(&vec![vec![2_f32]]);
        let eight = graph.eval(&vec![vec![3_f32]]);

        assert_eq!(six, vec![vec![6_f32]]);
        assert_eq!(seven, vec![vec![7_f32]]);
        assert_eq!(eight, vec![vec![8_f32]]);
        assert_eq!(graph.len(), 4);
    }

    #[test]
    fn test_graph_eval_batch() {
        let graph = plus_five_graph();

        // One evaluator is reused across every row of the batch. Each row must still be
        // evaluated independently and returned in input order.
        let batch = graph.eval(&vec![vec![1_f32], vec![2_f32], vec![3_f32]]);
        assert_eq!(batch, vec![vec![6_f32], vec![7_f32], vec![8_f32]]);

        let empty: Vec<Vec<f32>> = graph.eval(&Vec::<Vec<f32>>::new());
        assert!(empty.is_empty());
    }
}