use crate::array::*;
use crate::finite_function::*;
use crate::indexed_coproduct::*;
use crate::semifinite::*;
use crate::strict::layer::layer;
use crate::strict::open_hypergraph::*;
use crate::strict::relation::converse;
use num_traits::Zero;
use std::default::Default;
/// Turn a layering function into an explicit list of layers.
///
/// `f` is the first component returned by [`layer`] (as passed in by `eval`). This builds
/// `IndexedCoproduct::elements(f)`, takes its [`converse`], and collects the components of
/// the result into a `Vec`, one [`FiniteFunction`] per component.
///
/// NOTE(review): presumably component `i` lists the operations that `f` assigns to layer
/// `i` — confirm against the definitions of `elements` and `converse`.
fn layer_function_to_layers<K: ArrayKind>(f: FiniteFunction<K>) -> Vec<FiniteFunction<K>>
where
    K::Type<K::I>: NaturalArray<K>,
    K::I: Into<usize> + From<usize>,
{
    let elements = IndexedCoproduct::elements(f);
    converse(&elements).into_iter().collect()
}
/// Evaluate the open hypergraph `f` on the source-wire values `s`.
///
/// The operations of `f` are grouped into layers using [`layer`]. The layers are then run
/// in order. For each layer, `apply` receives the labels of the layer's operations and the
/// values on their input wires. It returns the values to write onto their output wires.
///
/// Returns `None` if `layer` reports any unvisited operation, that is, if the `unvisited`
/// result has a nonzero maximum. NOTE(review): presumably this means `f` is cyclic and has
/// no evaluation order. Otherwise returns the values on the target wires `f.t`.
pub fn eval<K: ArrayKind, O, A, T: Default>(
    f: &OpenHypergraph<K, O, A>,
    s: K::Type<T>,
    apply: impl Fn(
        SemifiniteFunction<K, A>,
        IndexedCoproduct<K, SemifiniteFunction<K, T>>,
    ) -> IndexedCoproduct<K, SemifiniteFunction<K, T>>,
) -> Option<K::Type<T>>
where
    K::I: Into<usize> + From<usize>,
    K::Type<K::I>: NaturalArray<K>,
    K::Type<T>: Array<K, T>,
    K::Type<O>: Array<K, O>,
    K::Type<A>: Array<K, A>,
{
    let (layer_of, unvisited) = layer(f);
    let layers = layer_function_to_layers(layer_of);

    // An empty `unvisited` (no maximum) counts as "everything was visited".
    let all_visited = unvisited.max().map_or(true, |m| m == K::I::zero());

    // Only evaluate when a complete ordering exists; keep just the output-wire values.
    all_visited.then(|| eval_order(f, s, layers, apply).1)
}
/// Evaluate the operations of `f` batch by batch, threading wire values through memory.
///
/// * `f` – the open hypergraph to evaluate.
/// * `s` – values for the source wires `f.s`. They are written into memory before any
///   operation runs.
/// * `order` – a sequence of batches. Each [`FiniteFunction`] selects operations of `f.h`,
///   and the batches are run one after another in vector order.
/// * `apply` – given the labels of a batch's operations and the values on their input wires,
///   returns the values to write onto their output wires.
///
/// Memory holds one value per wire (`f.h.w.len()` slots) and starts as `T::default()`.
/// Any wire that is neither a source wire nor an output of an evaluated operation therefore
/// keeps the default value.
///
/// Returns `(memory, outputs)`:
/// * `memory` is the final value of every wire.
/// * `outputs` is the values gathered from the target wires `f.t`.
///
/// # Panics
///
/// Panics in either of these cases:
/// * a batch in `order` cannot be composed with the operation labels `f.h.x`, or cannot be
///   used to index the input/output wire lists `f.h.s` / `f.h.t`;
/// * the resulting input wire indexes cannot be read from memory.
///
/// NOTE(review): when `order` comes from `layer(f)` (as in `eval`), these are expected to be
/// invariants — confirm that `layer` guarantees them.
fn eval_order<K: ArrayKind, O, A, T: Default>(
    f: &OpenHypergraph<K, O, A>,
    s: K::Type<T>,
    order: Vec<FiniteFunction<K>>,
    apply: impl Fn(
        SemifiniteFunction<K, A>,
        IndexedCoproduct<K, SemifiniteFunction<K, T>>,
    ) -> IndexedCoproduct<K, SemifiniteFunction<K, T>>,
) -> (K::Type<T>, K::Type<T>)
where
    K::Type<K::I>: NaturalArray<K>,
    K::Type<T>: Array<K, T>,
    K::Type<O>: Array<K, O>,
    K::Type<A>: Array<K, A>,
{
    // One slot per wire, prefilled so every wire has a value before anything writes to it.
    let mut mem: SemifiniteFunction<K, T> =
        SemifiniteFunction::new(K::Type::<T>::fill(T::default(), f.h.w.len()));

    // Seed memory with the caller-supplied values on the source wires.
    mem.0.scatter_assign(&f.s.table, s);

    for op_ix in order {
        // Labels of the operations in this batch.
        let op_labels = (&op_ix >> &f.h.x)
            .expect("eval_order: batch must index operations of f (composing with f.h.x)");

        // Wire indexes of the batch's inputs, and the values currently stored on them.
        let input_indexes = f
            .h
            .s
            .map_indexes(&op_ix)
            .expect("eval_order: batch must index operations of f (input wires f.h.s)");
        let input_values = input_indexes
            .map_semifinite(&mem)
            .expect("eval_order: input wire indexes must be compatible with wire memory");

        let outputs = apply(op_labels, input_values);

        // Write the results back onto the batch's output wires.
        let output_indexes = f
            .h
            .t
            .map_indexes(&op_ix)
            .expect("eval_order: batch must index operations of f (output wires f.h.t)");
        mem.0
            .scatter_assign(&output_indexes.values.table, outputs.values.0);
    }

    let outputs = mem.0.gather(f.t.table.get_range(..));
    (mem.0, outputs)
}