use open_hypergraphs::strict::eval::eval;
use open_hypergraphs::strict::vec::*;
use core::ops::{Add, Mul};
use num_traits::{One, Zero};
use std::iter::{Product, Sum};
#[derive(Clone, PartialEq, Eq, Debug)]
pub struct Obj;
#[derive(Clone, PartialEq, Eq, Debug)]
pub enum Arr {
Add,
Zero,
Mul,
One,
Copy,
Discard,
}
pub type Term = OpenHypergraph<Obj, Arr>;
pub trait Semiring: Sized + Add + Zero + Sum + Mul + One + Product + Copy {}
impl Semiring for usize {}
pub fn arr_type(a: &Arr) -> (usize, usize) {
use Arr::*;
match a {
Add => (2, 1),
Zero => (0, 1),
Mul => (2, 1),
One => (0, 1),
Copy => (1, 2),
Discard => (1, 0),
}
}
pub fn mktype(n: usize) -> SemifiniteFunction<Obj> {
SemifiniteFunction::new(VecArray(vec![Obj; n]))
}
pub fn arr(op: Arr) -> Term {
let (a, b) = arr_type(&op);
OpenHypergraph::singleton(op, mktype(a), mktype(b))
}
pub fn apply_op<T: Semiring + Copy>(op: &Arr, args: &[T]) -> Vec<T> {
use Arr::*;
match op {
Add => vec![args.iter().copied().sum()],
Zero => vec![T::zero()],
Mul => vec![args.iter().copied().product()],
One => vec![T::one()],
Copy => vec![args[0], args[0]],
Discard => vec![],
}
}
pub fn apply<T: Clone + PartialEq + Semiring + Copy>(
ops: SemifiniteFunction<Arr>,
args: IndexedCoproduct<SemifiniteFunction<T>>,
) -> IndexedCoproduct<SemifiniteFunction<T>> {
let args: Vec<SemifiniteFunction<T>> = args.into_iter().collect();
let mut coargs = Vec::with_capacity(args.len());
for (op, x) in ops.0.iter().zip(args.iter()) {
coargs.push(apply_op(op, &x.0));
}
let sizes: Vec<usize> = coargs.iter().map(|v| v.len()).collect();
let flat_values: Vec<T> = coargs.into_iter().flatten().collect();
IndexedCoproduct::from_semifinite(
SemifiniteFunction::new(VecArray(sizes)),
SemifiniteFunction::new(VecArray(flat_values)),
)
.expect("Invalid IndexedCoproduct construction")
}
fn square() -> Option<Term> {
use Arr::*;
&arr(Copy) >> &arr(Mul)
}
mod imperative {
use super::{Arr, Obj};
use open_hypergraphs::lax::*;
type Term = open_hypergraphs::lax::OpenHypergraph<Obj, Arr>;
fn binop(state: &mut Term, x: Arr) -> ((NodeId, NodeId), NodeId) {
let (_, (x, y)) = state.new_operation(x, vec![Obj, Obj], vec![Obj]);
((x[0], x[1]), y[0])
}
pub fn mul(state: &mut Term) -> ((NodeId, NodeId), NodeId) {
binop(state, Arr::Mul)
}
pub fn copy(state: &mut Term) -> (NodeId, (NodeId, NodeId)) {
let (_, (x, y)) = state.new_operation(Arr::Copy, vec![Obj], vec![Obj, Obj]);
(x[0], (y[0], y[1]))
}
pub fn square() -> open_hypergraphs::strict::vec::OpenHypergraph<Obj, Arr> {
let mut state = OpenHypergraph::<Obj, Arr>::empty();
let (a, (x0, x1)) = copy(&mut state);
let ((y0, y1), b) = mul(&mut state);
state.unify(x0, y0);
state.unify(x1, y1);
state.sources = vec![a];
state.targets = vec![b];
state.quotient().unwrap();
state.to_strict()
}
}
fn main() {
let f = square().unwrap();
assert_eq!(f.source(), mktype(1));
assert_eq!(f.target(), mktype(1));
let inputs = VecArray(vec![3]);
let result = eval::<VecKind, Obj, Arr, usize>(&f, inputs, apply).expect("eval failed");
println!("3^2 = {:?}", result[0]);
let f = imperative::square();
assert_eq!(f.source(), mktype(1));
assert_eq!(f.target(), mktype(1));
let inputs = VecArray(vec![3]);
let result = eval::<VecKind, Obj, Arr, usize>(&f, inputs, apply).expect("eval failed");
println!("3^2 = {:?}", result[0]);
}