radiate_gp/collections/graphs/
eval.rs1use super::{Graph, GraphNode, iter::GraphIterator};
2use crate::{Eval, EvalMut, NodeType, node::Node};
3
4pub struct GraphEvaluator<'a, T, V> {
10 nodes: &'a [GraphNode<T>],
11 output_size: usize,
12 eval_order: Vec<usize>,
13 outputs: Vec<V>,
14 inputs: Vec<Vec<V>>,
15}
16
17impl<'a, T, V> GraphEvaluator<'a, T, V>
18where
19 T: Eval<[V], V>,
20 V: Default + Clone,
21{
22 pub fn new<N>(graph: &'a N) -> GraphEvaluator<'a, T, V>
30 where
31 N: AsRef<[GraphNode<T>]>,
32 {
33 let nodes = graph.as_ref();
34
35 GraphEvaluator {
36 nodes,
37 output_size: nodes
38 .iter()
39 .filter(|node| node.node_type() == NodeType::Output)
40 .count(),
41 inputs: nodes
42 .iter()
43 .map(|node| vec![V::default(); node.incoming().len()])
44 .collect::<Vec<Vec<V>>>(),
45 eval_order: nodes.iter_topological().map(|node| node.index()).collect(),
46 outputs: vec![V::default(); nodes.len()],
47 }
48 }
49}
50
51impl<'a, T, V> EvalMut<[V], Vec<V>> for GraphEvaluator<'a, T, V>
53where
54 T: Eval<[V], V>,
55 V: Clone + Default,
56{
57 #[inline]
67 fn eval_mut(&mut self, input: &[V]) -> Vec<V> {
68 let mut output = Vec::with_capacity(self.output_size);
69 for index in self.eval_order.iter() {
70 let node = &self.nodes[*index];
71 if node.incoming().is_empty() {
72 self.outputs[node.index()] = node.eval(input);
73 } else {
74 for (idx, incoming) in node.incoming().iter().enumerate() {
75 self.inputs[node.index()][idx] = self.outputs[*incoming].clone();
76 }
77
78 self.outputs[node.index()] = node.eval(&self.inputs[node.index()]);
79 }
80
81 if node.node_type() == NodeType::Output {
82 output.push(self.outputs[node.index()].clone());
83 }
84 }
85
86 output
87 }
88}
89
90impl<T, V> Eval<Vec<Vec<V>>, Vec<Vec<V>>> for Graph<T>
91where
92 T: Eval<[V], V>,
93 V: Clone + Default,
94{
95 #[inline]
104 fn eval(&self, input: &Vec<Vec<V>>) -> Vec<Vec<V>> {
105 let mut output = Vec::with_capacity(self.len());
106 let mut evaluator = GraphEvaluator::new(self);
107
108 for inputs in input.iter() {
109 output.push(evaluator.eval_mut(inputs));
110 }
111
112 output
113 }
114}
115
116impl<T: Eval<[V], V>, V: Clone> Eval<[V], V> for GraphNode<T> {
117 #[inline]
124 fn eval(&self, inputs: &[V]) -> V {
125 self.value().eval(inputs)
126 }
127}
128
#[cfg(test)]
mod tests {

    use super::*;
    use crate::{Graph, Op};

    /// Builds a four-node graph computing `linear(var(0) + 5)`.
    fn add_five_graph() -> Graph<Op<f32>> {
        let mut graph = Graph::<Op<f32>>::default();

        let idx_one = graph.insert(NodeType::Input, Op::var(0));
        let idx_two = graph.insert(NodeType::Input, Op::constant(5_f32));
        let idx_three = graph.insert(NodeType::Vertex, Op::add());
        let idx_four = graph.insert(NodeType::Output, Op::linear());

        graph
            .attach(idx_one, idx_three)
            .attach(idx_two, idx_three)
            .attach(idx_three, idx_four);

        graph
    }

    #[test]
    fn test_graph_eval_simple() {
        let graph = add_five_graph();

        let six = graph.eval(&vec![vec![1_f32]]);
        let seven = graph.eval(&vec![vec![2_f32]]);
        let eight = graph.eval(&vec![vec![3_f32]]);

        assert_eq!(six, vec![vec![6_f32]]);
        assert_eq!(seven, vec![vec![7_f32]]);
        assert_eq!(eight, vec![vec![8_f32]]);
        assert_eq!(graph.len(), 4);
    }

    /// Several rows in one call reuse a single evaluator; each row's result must
    /// match what an independent single-row evaluation would produce.
    #[test]
    fn test_graph_eval_batch() {
        let graph = add_five_graph();

        let rows = vec![vec![1_f32], vec![2_f32], vec![3_f32]];
        let outputs = graph.eval(&rows);

        assert_eq!(outputs, vec![vec![6_f32], vec![7_f32], vec![8_f32]]);
    }

    /// An empty batch produces no results.
    #[test]
    fn test_graph_eval_empty_batch() {
        let graph = add_five_graph();

        let rows: Vec<Vec<f32>> = Vec::new();
        let outputs = graph.eval(&rows);

        assert!(outputs.is_empty());
    }
}