1use crate::array::*;
4use crate::finite_function::*;
5use crate::indexed_coproduct::*;
6use crate::layer::{converse, layer};
7use crate::open_hypergraph::*;
8use crate::semifinite::*;
9
10use num_traits::Zero;
11use std::default::Default;
12
13fn layer_function_to_layers<K: ArrayKind>(f: FiniteFunction<K>) -> Vec<FiniteFunction<K>>
18where
19 K::Type<K::I>: NaturalArray<K>,
20 K::I: Into<usize> + From<usize>,
21{
22 let c = converse(&IndexedCoproduct::elements(f));
23 c.into_iter().collect()
24}
25
26pub fn eval<K: ArrayKind, O, A, T: Default>(
30 f: &OpenHypergraph<K, O, A>,
31 s: K::Type<T>,
32 apply: impl Fn(
33 SemifiniteFunction<K, A>,
34 IndexedCoproduct<K, SemifiniteFunction<K, T>>,
35 ) -> IndexedCoproduct<K, SemifiniteFunction<K, T>>,
36) -> Option<K::Type<T>>
37where
38 K::I: Into<usize> + From<usize>,
39 K::Type<K::I>: NaturalArray<K>,
40 K::Type<T>: Array<K, T>,
41 K::Type<O>: Array<K, O>,
42 K::Type<A>: Array<K, A>,
43{
44 let (order, unvisited) = layer(f);
45 let layering = layer_function_to_layers(order);
46
47 if unvisited.max().unwrap_or(K::I::zero()) == K::I::zero() {
50 let (_, outputs) = eval_order(f, s, layering, apply);
51 Some(outputs)
52 } else {
53 None
54 }
55}
56
57fn eval_order<K: ArrayKind, O, A, T: Default>(
59 f: &OpenHypergraph<K, O, A>,
61 s: K::Type<T>,
63 order: Vec<FiniteFunction<K>>,
67 apply: impl Fn(
68 SemifiniteFunction<K, A>,
69 IndexedCoproduct<K, SemifiniteFunction<K, T>>,
70 ) -> IndexedCoproduct<K, SemifiniteFunction<K, T>>,
71) -> (K::Type<T>, K::Type<T>)
72where
73 K::Type<K::I>: NaturalArray<K>,
74 K::Type<T>: Array<K, T>,
75 K::Type<O>: Array<K, O>,
76 K::Type<A>: Array<K, A>,
77{
78 let mut mem: SemifiniteFunction<K, T> =
80 SemifiniteFunction::new(K::Type::<T>::fill(T::default(), f.h.w.len()));
81
82 mem.0.scatter_assign(&f.s.table, s);
84
85 for op_ix in order {
86 let op_labels = (&op_ix >> &f.h.x).unwrap();
88
89 let input_indexes = f.h.s.map_indexes(&op_ix).unwrap();
91 let input_values = input_indexes.map_semifinite(&mem).unwrap();
92
93 let outputs = apply(op_labels, input_values);
95
96 let output_indexes = f.h.t.map_indexes(&op_ix).unwrap();
97
98 mem.0
100 .scatter_assign(&output_indexes.values.table, outputs.values.0);
101
102 }
104 let outputs = mem.0.gather(f.t.table.get_range(..));
105 (mem.0, outputs)
106}