use open_hypergraphs::array::vec::*;
use open_hypergraphs::indexed_coproduct::*;
use open_hypergraphs::semifinite::*;
use open_hypergraphs::strict::eval::*;
use crate::theory::polycirc::*;
pub fn apply_op<T: Semiring + Copy>(op: &Arr, args: &[T]) -> Vec<T> {
use Arr::*;
match op {
Add => vec![args.iter().copied().sum()],
Zero => vec![T::zero()],
Mul => vec![args.iter().copied().product()],
One => vec![T::one()],
Copy => vec![args[0], args[0]],
Discard => vec![],
}
}
/// Applies every operation in `ops` to its matching segment of input values.
///
/// `ops` and `args` are paired positionally: the `i`-th operation receives the values of the
/// `i`-th segment of `args` and is evaluated with [`apply_op`]. The per-operation outputs are
/// concatenated back into an [`IndexedCoproduct`]. Its segment sizes record how many values
/// each operation produced.
///
/// # Panics
/// - If the number of operations differs from the number of argument segments.
/// - If [`apply_op`] panics for some operation (e.g. `Copy` with no input).
/// - If the resulting `IndexedCoproduct` cannot be constructed.
pub fn apply<T: Clone + PartialEq + Semiring + Copy>(
    ops: SemifiniteFunction<VecKind, Arr>,
    args: IndexedCoproduct<VecKind, SemifiniteFunction<VecKind, T>>,
) -> IndexedCoproduct<VecKind, SemifiniteFunction<VecKind, T>> {
    let args: Vec<SemifiniteFunction<VecKind, T>> = args.into_iter().collect();

    // `zip` would silently drop the excess of the longer side and return a wrong but
    // plausible-looking result. A length mismatch is a caller bug, so fail loudly instead.
    assert_eq!(
        ops.0.len(),
        args.len(),
        "apply: number of operations must match number of argument segments"
    );

    let coargs: Vec<Vec<T>> = ops
        .0
        .iter()
        .zip(&args)
        .map(|(op, x)| apply_op(op, &x.0))
        .collect();

    // Segment sizes first (borrowing `coargs`), then flatten the values (consuming it).
    let sizes: Vec<usize> = coargs.iter().map(Vec::len).collect();
    let flat_values: Vec<T> = coargs.into_iter().flatten().collect();

    IndexedCoproduct::from_semifinite(
        SemifiniteFunction(VecArray(sizes)),
        SemifiniteFunction(VecArray(flat_values)),
    )
    .expect("Invalid IndexedCoproduct construction")
}
fn square() -> Option<Term> {
use Arr::*;
&arr(Copy) >> &arr(Mul)
}
/// `square` should be a 1→1 circuit that maps 3 to 9.
#[test]
fn test_square() {
    let circuit = square().unwrap();
    let one_wire = mktype(1);
    assert_eq!(circuit.source(), one_wire);
    assert_eq!(circuit.target(), one_wire);

    let squared = eval::<VecKind, Obj, Arr, usize>(&circuit, VecArray(vec![3]), apply)
        .expect("eval failed");
    assert_eq!(squared, VecArray(vec![9]));
}
/// The tensor `square | square` should square each of its two inputs independently.
#[test]
fn test_parallel_squares() {
    let f = square().unwrap();
    let g = &f | &f;
    // Check the composite `g`, not `f`: tensoring two 1→1 circuits must give a 2→2 circuit.
    assert_eq!(g.source(), mktype(2));
    assert_eq!(g.target(), mktype(2));
    let result =
        eval::<VecKind, Obj, Arr, usize>(&g, VecArray(vec![3, 4]), apply).expect("eval failed");
    assert_eq!(result, VecArray(vec![9, 16]));
}