Skip to main content

open_hypergraphs/strict/
eval.rs

1//! An array-backend-agnostic evaluator
2//!
3use crate::array::*;
4use crate::finite_function::*;
5use crate::indexed_coproduct::*;
6use crate::semifinite::*;
7
8use crate::strict::layer::layer;
9use crate::strict::open_hypergraph::*;
10use crate::strict::relation::converse;
11
12use num_traits::Zero;
13use std::default::Default;
14
15// Given a "layering function" `f : N → K` which maps each operation `n ∈ N` into some layer `k ∈
16// K`,
17// return the converse relation `r : K → N*` giving the list of operations in each layer as a list
18// of `FiniteFunction`.
19fn layer_function_to_layers<K: ArrayKind>(f: FiniteFunction<K>) -> Vec<FiniteFunction<K>>
20where
21    K::Type<K::I>: NaturalArray<K>,
22    K::I: Into<usize> + From<usize>,
23{
24    let c = converse(&IndexedCoproduct::elements(f));
25    c.into_iter().collect()
26}
27
28/// Evaluate an acyclic open hypergraph `f` thought of as a function using some specified input
29/// values `s`, and a function `apply` which maps a list of operations and their inputs to their
30/// outputs.
31pub fn eval<K: ArrayKind, O, A, T: Default>(
32    f: &OpenHypergraph<K, O, A>,
33    s: K::Type<T>,
34    apply: impl Fn(
35        SemifiniteFunction<K, A>,
36        IndexedCoproduct<K, SemifiniteFunction<K, T>>,
37    ) -> IndexedCoproduct<K, SemifiniteFunction<K, T>>,
38) -> Option<K::Type<T>>
39where
40    K::I: Into<usize> + From<usize>,
41    K::Type<K::I>: NaturalArray<K>,
42    K::Type<T>: Array<K, T>,
43    K::Type<O>: Array<K, O>,
44    K::Type<A>: Array<K, A>,
45{
46    let (order, unvisited) = layer(f);
47    let layering = layer_function_to_layers(order);
48
49    // Check that max of 'unvisited' is 0: i.e., no unvisited nodes.
50    // TODO: this has to evaluate the whole array, when it could just use 'any'.
51    if unvisited.max().unwrap_or(K::I::zero()) == K::I::zero() {
52        let (_, outputs) = eval_order(f, s, layering, apply);
53        Some(outputs)
54    } else {
55        None
56    }
57}
58
59// Evaluate an acyclic open hypergraph using a specified order of operations.
60fn eval_order<K: ArrayKind, O, A, T: Default>(
61    // The term to evaluate
62    f: &OpenHypergraph<K, O, A>,
63    // Source wire inputs
64    s: K::Type<T>,
65    // A chosen order of operations
66    // TODO: this should be an *iterator* over arrays?
67    //order: &IndexedCoproduct<K, FiniteFunction<K>>,
68    order: Vec<FiniteFunction<K>>,
69    apply: impl Fn(
70        SemifiniteFunction<K, A>,
71        IndexedCoproduct<K, SemifiniteFunction<K, T>>,
72    ) -> IndexedCoproduct<K, SemifiniteFunction<K, T>>,
73) -> (K::Type<T>, K::Type<T>)
74where
75    K::Type<K::I>: NaturalArray<K>,
76    K::Type<T>: Array<K, T>,
77    K::Type<O>: Array<K, O>,
78    K::Type<A>: Array<K, A>,
79{
80    // Create memory prefilled with default data
81    let mut mem: SemifiniteFunction<K, T> =
82        SemifiniteFunction::new(K::Type::<T>::fill(T::default(), f.h.w.len()));
83
84    // Overwrite input locations with values in s.
85    mem.0.scatter_assign(&f.s.table, s);
86
87    for op_ix in order {
88        // Compute *labels* of operations to pass to `apply`.
89        let op_labels = (&op_ix >> &f.h.x).unwrap();
90
91        // Get the wire indices and values which are inputs to the operations in op_ix.
92        let input_indexes = f.h.s.map_indexes(&op_ix).unwrap();
93        let input_values = input_indexes.map_semifinite(&mem).unwrap();
94
95        // Compute an IndexedCoproduct of output values.
96        let outputs = apply(op_labels, input_values);
97
98        let output_indexes = f.h.t.map_indexes(&op_ix).unwrap();
99
100        // write outputs to memory
101        mem.0
102            .scatter_assign(&output_indexes.values.table, outputs.values.0);
103
104        // TODO: evaluate all 'ops' in parallel using a user-supplied function
105    }
106    let outputs = mem.0.gather(f.t.table.get_range(..));
107    (mem.0, outputs)
108}