1use crate::array::*;
4use crate::finite_function::*;
5use crate::indexed_coproduct::*;
6use crate::layer::{converse, layer};
7use crate::open_hypergraph::*;
8use crate::semifinite::*;
9
10use num_traits::Zero;
11
12fn layer_function_to_layers<K: ArrayKind>(f: FiniteFunction<K>) -> Vec<FiniteFunction<K>>
17where
18 K::Type<K::I>: NaturalArray<K>,
19 K::I: Into<usize> + From<usize>,
20{
21 let c = converse(&IndexedCoproduct::elements(f));
22 c.into_iter().collect()
23}
24
25pub fn eval<K: ArrayKind, O, A, T>(
29 f: &OpenHypergraph<K, O, A>,
30 s: K::Type<T>,
31 apply: impl Fn(
32 SemifiniteFunction<K, A>,
33 IndexedCoproduct<K, SemifiniteFunction<K, T>>,
34 ) -> IndexedCoproduct<K, SemifiniteFunction<K, T>>,
35) -> Option<K::Type<T>>
36where
37 K::I: Into<usize> + From<usize>,
38 K::Type<K::I>: NaturalArray<K>,
39 K::Type<T>: Array<K, T>,
40 K::Type<O>: Array<K, O>,
41 K::Type<A>: Array<K, A>,
42{
43 let (order, unvisited) = layer(f);
44 let layering = layer_function_to_layers(order);
45
46 if unvisited.max().unwrap_or(K::I::zero()) == K::I::zero() {
49 let (_, outputs) = eval_order(f, s, layering, apply);
50 Some(outputs)
51 } else {
52 None
53 }
54}
55
56fn eval_order<K: ArrayKind, O, A, T>(
58 f: &OpenHypergraph<K, O, A>,
60 s: K::Type<T>,
62 order: Vec<FiniteFunction<K>>,
66 apply: impl Fn(
67 SemifiniteFunction<K, A>,
68 IndexedCoproduct<K, SemifiniteFunction<K, T>>,
69 ) -> IndexedCoproduct<K, SemifiniteFunction<K, T>>,
70) -> (K::Type<T>, K::Type<T>)
71where
72 K::Type<K::I>: NaturalArray<K>,
73 K::Type<T>: Array<K, T>,
74 K::Type<O>: Array<K, O>,
75 K::Type<A>: Array<K, A>,
76{
77 let mut mem = SemifiniteFunction::new(s.scatter(f.s.table.get_range(..), f.h.w.len()));
79
80 for op_ix in order {
81 let op_labels = (&op_ix >> &f.h.x).unwrap();
83
84 let input_indexes = f.h.s.map_indexes(&op_ix).unwrap();
86 let input_values = input_indexes.map_semifinite(&mem).unwrap();
87
88 let outputs = apply(op_labels, input_values);
90
91 let output_indexes = f.h.t.map_indexes(&op_ix).unwrap();
92
93 mem.0
95 .scatter_assign(&output_indexes.values.table, outputs.values.0);
96
97 }
99 let outputs = mem.0.gather(f.t.table.get_range(..));
100 (mem.0, outputs)
101}