open_hypergraphs/
eval.rs

1//! An array-backend-agnostic evaluator
2//!
3use crate::array::*;
4use crate::finite_function::*;
5use crate::indexed_coproduct::*;
6use crate::layer::{converse, layer};
7use crate::open_hypergraph::*;
8use crate::semifinite::*;
9
10use num_traits::Zero;
11
12// Given a "layering function" `f : N → K` which maps each operation `n ∈ N` into some layer `k ∈
13// K`,
14// return the converse relation `r : K → N*` giving the list of operations in each layer as a list
15// of `FiniteFunction`.
16fn layer_function_to_layers<K: ArrayKind>(f: FiniteFunction<K>) -> Vec<FiniteFunction<K>>
17where
18    K::Type<K::I>: NaturalArray<K>,
19    K::I: Into<usize> + From<usize>,
20{
21    let c = converse(&IndexedCoproduct::elements(f));
22    c.into_iter().collect()
23}
24
25/// Evaluate an acyclic open hypergraph `f` thought of as a function using some specified input
26/// values `s`, and a function `apply` which maps a list of operations and their inputs to their
27/// outputs.
28pub fn eval<K: ArrayKind, O, A, T>(
29    f: &OpenHypergraph<K, O, A>,
30    s: K::Type<T>,
31    apply: impl Fn(
32        SemifiniteFunction<K, A>,
33        IndexedCoproduct<K, SemifiniteFunction<K, T>>,
34    ) -> IndexedCoproduct<K, SemifiniteFunction<K, T>>,
35) -> Option<K::Type<T>>
36where
37    K::I: Into<usize> + From<usize>,
38    K::Type<K::I>: NaturalArray<K>,
39    K::Type<T>: Array<K, T>,
40    K::Type<O>: Array<K, O>,
41    K::Type<A>: Array<K, A>,
42{
43    let (order, unvisited) = layer(f);
44    let layering = layer_function_to_layers(order);
45
46    // Check that max of 'unvisited' is 0: i.e., no unvisited nodes.
47    // TODO: this has to evaluate the whole array, when it could just use 'any'.
48    if unvisited.max().unwrap_or(K::I::zero()) == K::I::zero() {
49        let (_, outputs) = eval_order(f, s, layering, apply);
50        Some(outputs)
51    } else {
52        None
53    }
54}
55
56// Evaluate an acyclic open hypergraph using a specified order of operations.
57fn eval_order<K: ArrayKind, O, A, T>(
58    // The term to evaluate
59    f: &OpenHypergraph<K, O, A>,
60    // Source wire inputs
61    s: K::Type<T>,
62    // A chosen order of operations
63    // TODO: this should be an *iterator* over arrays?
64    //order: &IndexedCoproduct<K, FiniteFunction<K>>,
65    order: Vec<FiniteFunction<K>>,
66    apply: impl Fn(
67        SemifiniteFunction<K, A>,
68        IndexedCoproduct<K, SemifiniteFunction<K, T>>,
69    ) -> IndexedCoproduct<K, SemifiniteFunction<K, T>>,
70) -> (K::Type<T>, K::Type<T>)
71where
72    K::Type<K::I>: NaturalArray<K>,
73    K::Type<T>: Array<K, T>,
74    K::Type<O>: Array<K, O>,
75    K::Type<A>: Array<K, A>,
76{
77    // Create memory prefilled with input data
78    let mut mem = SemifiniteFunction::new(s.scatter(f.s.table.get_range(..), f.h.w.len()));
79
80    for op_ix in order {
81        // Compute *labels* of operations to pass to `apply`.
82        let op_labels = (&op_ix >> &f.h.x).unwrap();
83
84        // Get the wire indices and values which are inputs to the operations in op_ix.
85        let input_indexes = f.h.s.map_indexes(&op_ix).unwrap();
86        let input_values = input_indexes.map_semifinite(&mem).unwrap();
87
88        // Compute an IndexedCoproduct of output values.
89        let outputs = apply(op_labels, input_values);
90
91        let output_indexes = f.h.t.map_indexes(&op_ix).unwrap();
92
93        // write outputs to memory
94        mem.0
95            .scatter_assign(&output_indexes.values.table, outputs.values.0);
96
97        // TODO: evaluate all 'ops' in parallel using a user-supplied function
98    }
99    let outputs = mem.0.gather(f.t.table.get_range(..));
100    (mem.0, outputs)
101}