radiate_gp/collections/graphs/
eval.rs1use super::{Graph, GraphNode, iter::GraphIterator};
2use crate::{Eval, EvalMut, NodeType, eval::EvalIntoMut, node::Node};
3#[cfg(feature = "serde")]
4use serde::{Deserialize, Serialize};
5use smallvec::SmallVec;
6use std::ops::Range;
7
/// Reusable evaluation state for a [`GraphEvaluator`].
///
/// Holds the precomputed evaluation order and index tables together with the value
/// buffers that are written during evaluation. It can be detached from an evaluator with
/// [`GraphEvaluator::take_cache`] and handed back through the
/// `From<(&G, GraphEvalCache<V>)>` impl, which may reuse it (including the node outputs it
/// carries) instead of rebuilding everything from the graph.
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub struct GraphEvalCache<V> {
    // Node indices in the order they are evaluated (topological order when built by
    // `GraphEvaluator::new`).
    eval_order: Vec<usize>,
    // Most recent output value of every node, indexed by node index.
    outputs: Vec<V>,
    // Flat scratch buffer into which each node's incoming values are gathered.
    inputs: Vec<V>,
    // Per node index: the sub-range of `inputs` reserved for that node's incoming values.
    input_ranges: Vec<Range<usize>>,
    // Indices of the output nodes, in the order their values are written to the result.
    output_indices: SmallVec<[usize; 8]>,
}
24
/// Evaluates the nodes of a graph against input values, reusing buffers across calls.
///
/// Borrows the graph's nodes for `'a` and owns a [`GraphEvalCache`]. Node outputs live in
/// the cache and are not reset between calls, so a later evaluation can observe values
/// produced by an earlier one (this is what lets recurrent edges carry state).
pub struct GraphEvaluator<'a, T, V> {
    // The graph's nodes, indexed by node index.
    nodes: &'a [GraphNode<T>],
    // Evaluation order, index tables and value buffers.
    inner: GraphEvalCache<V>,
}
31
32impl<'a, T, V> GraphEvaluator<'a, T, V>
33where
34 T: Eval<[V], V>,
35 V: Default + Clone,
36{
37 #[inline]
45 pub fn new<N>(graph: &'a N) -> GraphEvaluator<'a, T, V>
46 where
47 N: AsRef<[GraphNode<T>]>,
48 {
49 let nodes = graph.as_ref();
50
51 let mut total_inputs = 0;
52 let mut input_ranges = Vec::with_capacity(nodes.len());
53
54 for node in nodes {
55 let k = node.incoming().len();
56 input_ranges.push(total_inputs..total_inputs + k);
57 total_inputs += k;
58 }
59
60 let mut output_indices: SmallVec<[usize; 8]> = SmallVec::new();
61 for (i, n) in nodes.iter().enumerate() {
62 if n.node_type() == NodeType::Output {
63 output_indices.push(i);
64 }
65 }
66
67 GraphEvaluator {
68 nodes,
69 inner: GraphEvalCache {
70 inputs: vec![V::default(); total_inputs],
71 outputs: vec![V::default(); nodes.len()],
72 eval_order: nodes.iter_topological().map(|n| n.index()).collect(),
73 input_ranges,
74 output_indices,
75 },
76 }
77 }
78
79 pub fn take_cache(self) -> GraphEvalCache<V> {
80 self.inner
81 }
82}
83
84impl<T, V> EvalMut<[V], Vec<V>> for GraphEvaluator<'_, T, V>
85where
86 T: Eval<[V], V>,
87 V: Clone + Default,
88{
89 #[inline]
90 fn eval_mut(&mut self, input: &[V]) -> Vec<V> {
91 let out_len = self.inner.output_indices.len();
92 let mut buffer: Vec<V> = vec![V::default(); out_len];
93 self.eval_into_mut(input, &mut buffer[..]);
94 buffer
95 }
96}
97
98impl<T, V> EvalIntoMut<[V], [V]> for GraphEvaluator<'_, T, V>
99where
100 T: Eval<[V], V>,
101 V: Clone + Default,
102{
103 #[inline]
104 fn eval_into_mut(&mut self, input: &[V], buffer: &mut [V]) {
105 for &index in self.inner.eval_order.iter() {
106 let node = &self.nodes[index];
107 let incoming = node.incoming();
108
109 if incoming.is_empty() {
110 self.inner.outputs[index] = node.eval(input);
111 } else {
112 let range = &self.inner.input_ranges[index];
113 let buf = &mut self.inner.inputs[range.clone()];
114
115 for (dst, &src_idx) in buf.iter_mut().zip(incoming.iter()) {
116 *dst = self.inner.outputs[src_idx].clone();
117 }
118
119 self.inner.outputs[index] = node.eval(buf);
120 }
121 }
122
123 let mut count = 0;
124 for &idx in &self.inner.output_indices {
125 buffer[count] = self.inner.outputs[idx].clone();
126 count += 1;
127 }
128 }
129}
130
131impl<T, V> Eval<[Vec<V>], Vec<Vec<V>>> for Graph<T>
132where
133 T: Eval<[V], V>,
134 V: Clone + Default,
135{
136 #[inline]
145 fn eval(&self, input: &[Vec<V>]) -> Vec<Vec<V>> {
146 let mut evaluator = GraphEvaluator::new(self);
147 input
148 .iter()
149 .map(|input| evaluator.eval_mut(input))
150 .collect()
151 }
152}
153
154impl<T, V> Eval<[V], V> for GraphNode<T>
155where
156 T: Eval<[V], V>,
157 V: Clone,
158{
159 #[inline]
166 fn eval(&self, inputs: &[V]) -> V {
167 self.value().eval(inputs)
168 }
169}
170
171impl<'a, G, T, V> From<(&'a G, GraphEvalCache<V>)> for GraphEvaluator<'a, T, V>
172where
173 G: AsRef<[GraphNode<T>]>,
174 T: Eval<[V], V>,
175 V: Default + Clone,
176{
177 fn from((graph, cache): (&'a G, GraphEvalCache<V>)) -> Self {
178 if cache.eval_order.is_empty() || graph.as_ref().len() != cache.eval_order.len() {
179 return GraphEvaluator::new(graph);
180 }
181
182 GraphEvaluator {
183 nodes: graph.as_ref(),
184 inner: cache,
185 }
186 }
187}
188
189impl<'a, T, V> From<&'a Graph<T>> for GraphEvaluator<'a, T, V>
190where
191 T: Eval<[V], V>,
192 V: Default + Clone,
193{
194 fn from(graph: &'a Graph<T>) -> Self {
195 GraphEvaluator::new(graph)
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Graph, Op};

    /// Rounds `value` to `places` decimal places so float outputs can be compared exactly.
    fn round(value: f32, places: u32) -> f32 {
        let factor = 10_f32.powi(places as i32);
        (value * factor).round() / factor
    }

    #[test]
    fn test_graph_eval_simple() {
        // output = linear(var(0) + 5)
        let mut graph = Graph::<Op<f32>>::default();

        let input = graph.insert(NodeType::Input, Op::var(0));
        let five = graph.insert(NodeType::Input, Op::constant(5_f32));
        let sum = graph.insert(NodeType::Vertex, Op::add());
        let output = graph.insert(NodeType::Output, Op::linear());

        graph
            .attach(input, sum)
            .attach(five, sum)
            .attach(sum, output);

        // Each `graph.eval` call builds its own evaluator.
        for (x, expected) in [(1_f32, 6_f32), (2_f32, 7_f32), (3_f32, 8_f32)] {
            assert_eq!(graph.eval(&[vec![x]]), vec![vec![expected]]);
        }
        assert_eq!(graph.len(), 4);
    }

    #[test]
    fn test_graph_eval_recurrent() {
        let mut graph = Graph::<Op<f32>>::default();

        graph.insert(NodeType::Input, Op::var(0));
        graph.insert(NodeType::Vertex, Op::diff());
        graph.insert(NodeType::Output, Op::sigmoid());
        graph.insert(NodeType::Edge, Op::weight_with(-1.41));
        graph.insert(NodeType::Vertex, Op::sigmoid());
        graph.insert(NodeType::Vertex, Op::exp());
        graph.insert(NodeType::Edge, Op::weight_with(-1.10));
        graph.insert(NodeType::Vertex, Op::exp());
        graph.insert(NodeType::Vertex, Op::exp());
        graph.insert(NodeType::Vertex, Op::div());

        // (source, target) pairs; (1, 1) and (9, 9) are self-loops.
        let edges = [
            (0, 1),
            (1, 1),
            (4, 1),
            (7, 1),
            (8, 1),
            (1, 2),
            (3, 2),
            (6, 2),
            (5, 3),
            (1, 4),
            (0, 5),
            (9, 6),
            (4, 7),
            (7, 8),
            (0, 9),
            (9, 9),
        ];
        for &(source, target) in edges.iter() {
            graph.attach(source, target);
        }

        graph.set_cycles(vec![]);

        let mut evaluator: GraphEvaluator<'_, Op<f32>, f32> = GraphEvaluator::new(&graph);

        // The same evaluator is reused for every step, so recurrent edges carry state from
        // one step to the next: (input, expected output rounded to 3 places).
        let steps: [(f32, f32); 7] = [
            (0.0, 0.196),
            (0.0, 0.000),
            (0.0, 0.902),
            (1.0, 0.000),
            (0.0, 0.000),
            (0.0, 0.000),
            (0.0, 1.000),
        ];
        for (step, &(input, expected)) in steps.iter().enumerate() {
            let output = evaluator.eval_mut(&[input])[0];
            assert_eq!(
                round(output, 3),
                expected,
                "step {} (input {})",
                step,
                input
            );
        }
    }
}