radiate_gp/collections/graphs/
eval.rs

1use super::{Graph, GraphNode, iter::GraphIterator};
2use crate::{Eval, EvalMut, NodeType, eval::EvalIntoMut, node::Node};
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use smallvec::SmallVec;
6use std::ops::Range;
7
8/// A cache for storing intermediate results during graph evaluation.
9///
10/// This cache is used to store the inputs and outputs of each node in the graph
11/// during evaluation, allowing for more efficient re-evaluation of nodes when
12/// their inputs change. If we want to save a graph's evaluation between different evals,
13/// we need to keep track of the inputs and outputs from previous runs incase of recurrent
14/// structures. This cache is the answer to that.
15#[derive(Clone, Debug, PartialEq)]
16#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
17pub struct GraphEvalCache<V> {
18    eval_order: Vec<usize>,
19    outputs: Vec<V>,
20    inputs: Vec<V>,
21    input_ranges: Vec<Range<usize>>,
22    output_indices: SmallVec<[usize; 8]>,
23}
24
25/// [GraphEvaluator] is a struct that is used to evaluate a [Graph] of [GraphNode]'s. It uses the [GraphIterator]
26/// to traverse the [Graph] in a pseudo-topological order and evaluate the nodes in the correct order.
27pub struct GraphEvaluator<'a, T, V> {
28    nodes: &'a [GraphNode<T>],
29    inner: GraphEvalCache<V>,
30}
31
32impl<'a, T, V> GraphEvaluator<'a, T, V>
33where
34    T: Eval<[V], V>,
35    V: Default + Clone,
36{
37    /// Creates a new [GraphEvaluator] with the given [Graph]. We pre-allocate the necessary
38    /// storage for inputs and outputs based on the structure of the graph on creation.
39    /// This way, we can reuse the same evaluator for multiple evaluations of the same graph
40    /// without needing to reallocate memory each time.
41    ///
42    /// # Arguments
43    /// * graph - The [Graph] to reduce.
44    #[inline]
45    pub fn new<N>(graph: &'a N) -> GraphEvaluator<'a, T, V>
46    where
47        N: AsRef<[GraphNode<T>]>,
48    {
49        let nodes = graph.as_ref();
50
51        let mut total_inputs = 0;
52        let mut input_ranges = Vec::with_capacity(nodes.len());
53
54        for node in nodes {
55            let k = node.incoming().len();
56            input_ranges.push(total_inputs..total_inputs + k);
57            total_inputs += k;
58        }
59
60        let mut output_indices: SmallVec<[usize; 8]> = SmallVec::new();
61        for (i, n) in nodes.iter().enumerate() {
62            if n.node_type() == NodeType::Output {
63                output_indices.push(i);
64            }
65        }
66
67        GraphEvaluator {
68            nodes,
69            inner: GraphEvalCache {
70                inputs: vec![V::default(); total_inputs],
71                outputs: vec![V::default(); nodes.len()],
72                eval_order: nodes.iter_topological().map(|n| n.index()).collect(),
73                input_ranges,
74                output_indices,
75            },
76        }
77    }
78
79    pub fn take_cache(self) -> GraphEvalCache<V> {
80        self.inner
81    }
82}
83
84impl<T, V> EvalMut<[V], Vec<V>> for GraphEvaluator<'_, T, V>
85where
86    T: Eval<[V], V>,
87    V: Clone + Default,
88{
89    #[inline]
90    fn eval_mut(&mut self, input: &[V]) -> Vec<V> {
91        let out_len = self.inner.output_indices.len();
92        let mut buffer: Vec<V> = vec![V::default(); out_len];
93        self.eval_into_mut(input, &mut buffer[..]);
94        buffer
95    }
96}
97
98impl<T, V> EvalIntoMut<[V], [V]> for GraphEvaluator<'_, T, V>
99where
100    T: Eval<[V], V>,
101    V: Clone + Default,
102{
103    #[inline]
104    fn eval_into_mut(&mut self, input: &[V], buffer: &mut [V]) {
105        for &index in self.inner.eval_order.iter() {
106            let node = &self.nodes[index];
107            let incoming = node.incoming();
108
109            if incoming.is_empty() {
110                self.inner.outputs[index] = node.eval(input);
111            } else {
112                let range = &self.inner.input_ranges[index];
113                let buf = &mut self.inner.inputs[range.clone()];
114
115                for (dst, &src_idx) in buf.iter_mut().zip(incoming.iter()) {
116                    *dst = self.inner.outputs[src_idx].clone();
117                }
118
119                self.inner.outputs[index] = node.eval(buf);
120            }
121        }
122
123        let mut count = 0;
124        for &idx in &self.inner.output_indices {
125            buffer[count] = self.inner.outputs[idx].clone();
126            count += 1;
127        }
128    }
129}
130
131impl<T, V> Eval<[Vec<V>], Vec<Vec<V>>> for Graph<T>
132where
133    T: Eval<[V], V>,
134    V: Clone + Default,
135{
136    /// Evaluates the [Graph] with the given input `Vec<Vec<T>>`. Returns the output of the [Graph] as `Vec<Vec<T>>`.
137    /// This is intended to be used when evaluating a batch of inputs.
138    ///
139    /// # Arguments
140    /// * `input` - A `Vec<Vec<T>>` to evaluate the [Graph] with.
141    ///
142    /// # Returns
143    /// * A `Vec<Vec<T>>` which is the output of the [Graph].
144    #[inline]
145    fn eval(&self, input: &[Vec<V>]) -> Vec<Vec<V>> {
146        let mut evaluator = GraphEvaluator::new(self);
147        input
148            .iter()
149            .map(|input| evaluator.eval_mut(input))
150            .collect()
151    }
152}
153
154impl<T, V> Eval<[V], V> for GraphNode<T>
155where
156    T: Eval<[V], V>,
157    V: Clone,
158{
159    /// Evaluates the [GraphNode] with the given input. Returns the output of the [GraphNode].
160    /// # Arguments
161    /// * `inputs` - A `Vec` of `V` to evaluate the [GraphNode] with.
162    ///
163    /// # Returns
164    /// * A `V` which is the output of the [GraphNode].
165    #[inline]
166    fn eval(&self, inputs: &[V]) -> V {
167        self.value().eval(inputs)
168    }
169}
170
171impl<'a, G, T, V> From<(&'a G, GraphEvalCache<V>)> for GraphEvaluator<'a, T, V>
172where
173    G: AsRef<[GraphNode<T>]>,
174    T: Eval<[V], V>,
175    V: Default + Clone,
176{
177    fn from((graph, cache): (&'a G, GraphEvalCache<V>)) -> Self {
178        if cache.eval_order.is_empty() || graph.as_ref().len() != cache.eval_order.len() {
179            return GraphEvaluator::new(graph);
180        }
181
182        GraphEvaluator {
183            nodes: graph.as_ref(),
184            inner: cache,
185        }
186    }
187}
188
189impl<'a, T, V> From<&'a Graph<T>> for GraphEvaluator<'a, T, V>
190where
191    T: Eval<[V], V>,
192    V: Default + Clone,
193{
194    fn from(graph: &'a Graph<T>) -> Self {
195        GraphEvaluator::new(graph)
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Graph, Op};

    /// Rounds `value` to `places` decimal places.
    fn round(value: f32, places: u32) -> f32 {
        let scale = 10_f32.powi(places as i32);
        (value * scale).round() / scale
    }

    /// A feed-forward graph computing `linear(var(0) + 5)` is evaluated on three inputs.
    #[test]
    fn test_graph_eval_simple() {
        let mut graph = Graph::<Op<f32>>::default();

        let variable = graph.insert(NodeType::Input, Op::var(0));
        let constant = graph.insert(NodeType::Input, Op::constant(5_f32));
        let sum = graph.insert(NodeType::Vertex, Op::add());
        let output = graph.insert(NodeType::Output, Op::linear());

        graph
            .attach(variable, sum)
            .attach(constant, sum)
            .attach(sum, output);

        // Run every evaluation first, then check the results.
        let six = graph.eval(&[vec![1_f32]]);
        let seven = graph.eval(&[vec![2_f32]]);
        let eight = graph.eval(&[vec![3_f32]]);

        let results = [six, seven, eight];
        let expected = [6_f32, 7_f32, 8_f32];
        for (got, &want) in results.iter().zip(expected.iter()) {
            assert_eq!(*got, vec![vec![want]]);
        }

        assert_eq!(graph.len(), 4);
    }

    /// A graph with self-loops and back edges is evaluated repeatedly with one evaluator,
    /// so each result depends on the outputs cached by the previous calls.
    #[test]
    fn test_graph_eval_recurrent() {
        let mut graph = Graph::<Op<f32>>::default();

        graph.insert(NodeType::Input, Op::var(0));
        graph.insert(NodeType::Vertex, Op::diff());
        graph.insert(NodeType::Output, Op::sigmoid());
        graph.insert(NodeType::Edge, Op::weight_with(-1.41));
        graph.insert(NodeType::Vertex, Op::sigmoid());
        graph.insert(NodeType::Vertex, Op::exp());
        graph.insert(NodeType::Edge, Op::weight_with(-1.10));
        graph.insert(NodeType::Vertex, Op::exp());
        graph.insert(NodeType::Vertex, Op::exp());
        graph.insert(NodeType::Vertex, Op::div());

        // (source, target) pairs; the order fixes the order of each node's incoming list.
        let edges = [
            (0, 1),
            (1, 1),
            (4, 1),
            (7, 1),
            (8, 1),
            (1, 2),
            (3, 2),
            (6, 2),
            (5, 3),
            (1, 4),
            (0, 5),
            (9, 6),
            (4, 7),
            (7, 8),
            (0, 9),
            (9, 9),
        ];
        for &(source, target) in edges.iter() {
            graph.attach(source, target);
        }

        graph.set_cycles(vec![]);

        let mut evaluator = GraphEvaluator::new(&graph);

        let stimuli: [f32; 7] = [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0];
        let expected: [f32; 7] = [0.196, 0.000, 0.902, 0.000, 0.000, 0.000, 1.000];

        // Run all evaluations first, then compare, so the evaluator state evolves exactly
        // as it would with seven consecutive calls.
        let outputs: Vec<f32> = stimuli
            .iter()
            .map(|&x| evaluator.eval_mut(&vec![x])[0])
            .collect();

        for (&got, &want) in outputs.iter().zip(expected.iter()) {
            assert_eq!(round(got, 3), want);
        }
    }
}
285}