use open_hypergraphs::array::{vec::*, *};
use open_hypergraphs::finite_function::*;
use open_hypergraphs::indexed_coproduct::*;
use open_hypergraphs::strict::layer::*;
use core::fmt::Debug;
use crate::theory::polycirc::*;
/// Splits an indexed coproduct into one owned `Vec<usize>` per segment.
///
/// The cumulative sum of `c.sources` is used as an offset table into
/// `c.values`, so segment `i` is `values[offsets[i]..offsets[i + 1]]`.
/// Panics (via indexing) if the offset table has fewer than `c.len() + 1`
/// entries.
fn to_slices(c: &IndexedCoproduct<VecKind, FiniteFunction<VecKind>>) -> Vec<Vec<usize>> {
    let offsets = c.sources.table.cumulative_sum();
    (0..c.len())
        .map(|seg| {
            let start = offsets[seg];
            let end = offsets[seg + 1];
            c.values.table[start..end].to_vec()
        })
        .collect()
}
/// Evaluates a single polycirc operation on the values of its input wires.
///
/// `args` holds the input-wire values in order; the result holds the values
/// for the operation's output wires:
/// - `Add` / `Mul`: one output, the sum / product of all `args`.
/// - `Zero` / `One`: one output, the semiring constant (`args` is ignored).
/// - `Copy`: two outputs, both equal to `args[0]`.
/// - `Discard`: no outputs.
///
/// Takes `&[T]` rather than `&Vec<T>`: callers holding a `Vec` or a
/// `VecArray` still pass through deref coercion.
///
/// # Panics
/// `Copy` panics if `args` is empty.
fn apply<T: Semiring + Copy>(op: &Arr, args: &[T]) -> Vec<T> {
    use Arr::*;
    match op {
        Add => vec![args.iter().copied().sum()],
        Zero => vec![T::zero()],
        Mul => vec![args.iter().copied().product()],
        One => vec![T::one()],
        Copy => vec![args[0], args[0]],
        Discard => vec![],
    }
}
/// Evaluates the term `f` on `inputs` and returns the values on its output wires.
///
/// Operations are scheduled with `layered_operations`. If the returned
/// "unvisited" mask flags any operation with `1`, the term could not be fully
/// layered (presumably because it contains a cycle), and `None` is returned.
fn eval<T: Semiring + PartialEq + Clone + Default + Debug>(
    f: &Term,
    inputs: Vec<T>,
) -> Option<Vec<T>> {
    let (layering, unvisited) = layered_operations(f);

    // Refuse to evaluate when some operation was never assigned a layer.
    if unvisited.0.contains(&1) {
        return None;
    }

    let schedule: Vec<Vec<usize>> = layering.into_iter().map(|layer| layer.0).collect();
    let (_, outputs) = eval_layers(f, inputs, &schedule);
    Some(outputs)
}
/// Runs the operations of `f` layer by layer over a wire-indexed memory.
///
/// Memory has one slot per wire (`f.h.w.len()`). `inputs` are scattered onto
/// the term's source wires `f.s`; the remaining slots are filled by `scatter`,
/// presumably with `T::default()` (hence the `Default` bound) — TODO confirm.
/// Each operation index in `layer` reads its source wires, is evaluated with
/// `apply`, and writes its results to its target wires.
///
/// Returns `(memory, outputs)`: the final value of every wire, and the values
/// gathered from the term's target wires `f.t`.
fn eval_layers<T: Semiring + Clone + PartialEq + Default + Debug>(
    f: &Term,
    inputs: Vec<T>,
    layer: &[Vec<usize>],
) -> (Vec<T>, Vec<T>) {
    let wire_count = f.h.w.len();
    let mut mem = VecArray(inputs).scatter(&f.s.table, wire_count);

    // Per-operation lists of input / output wire indices.
    let op_sources = to_slices(&f.h.s);
    let op_targets = to_slices(&f.h.t);

    // Layers are processed in order; within a layer, operations in listed order.
    for &op_ix in layer.iter().flatten() {
        let results = apply(&f.h.x.0[op_ix], &mem.gather(&op_sources[op_ix]));
        for (&wire, value) in op_targets[op_ix].iter().zip(results) {
            mem[wire] = value;
        }
    }

    let outputs = mem.gather(&f.t.table);
    (mem.0, outputs.to_vec())
}
fn square() -> Option<Term> {
use Arr::*;
&arr(Copy) >> &arr(Mul)
}
/// `square` is a 1-in/1-out term, and evaluating it on 3 yields 9.
#[test]
fn test_square() {
    let f = square().unwrap();

    let one_wire = mktype(1);
    assert_eq!(f.source(), one_wire);
    assert_eq!(f.target(), one_wire);

    let computed = eval::<usize>(&f, vec![3]).expect("eval failed");
    assert_eq!(computed, vec![9]);
}
/// Tensoring two `square` circuits (`&f | &f`) gives a 2-in/2-out term that
/// squares each input independently: `[3, 4]` evaluates to `[9, 16]`.
#[test]
fn test_parallel_squares() {
    let f = square().unwrap();
    let g = &f | &f;
    // Check the composite `g`, not `f`: parallel composition concatenates the
    // interfaces, so each side should have two wires.
    assert_eq!(g.source(), mktype(2));
    assert_eq!(g.target(), mktype(2));
    let result = eval::<usize>(&g, vec![3, 4]).expect("eval failed");
    assert_eq!(result, vec![9, 16]);
}