use crate::array::vec::VecArray;
use crate::lax::*;
use crate::strict::vec::{FiniteFunction, IndexedCoproduct};
/// A mapping from open hypergraphs with object labels `O1` and operation labels `A1`
/// to open hypergraphs with object labels `O2` and operation labels `A2`.
///
/// NOTE(review): presumably implementations are expected to obey the functor laws
/// (preserve identities and composition); nothing in this trait enforces that.
pub trait Functor<O1, A1, O2, A2> {
/// Maps a single object `o` to a sequence of objects of type `O2`.
///
/// The returned iterator reports its exact length, so one input object may map to
/// zero, one, or several output objects.
fn map_object(&self, o: &O1) -> impl ExactSizeIterator<Item = O2>;
/// Maps a single operation `a`, given the object labels of its `source` and `target`,
/// to an open hypergraph over `O2` and `A2`.
fn map_operation(&self, a: &A1, source: &[O1], target: &[O1]) -> OpenHypergraph<O2, A2>;
/// Maps a whole open hypergraph `f` to an open hypergraph over `O2` and `A2`.
fn map_arrow(&self, f: &OpenHypergraph<O1, A1>) -> OpenHypergraph<O2, A2>;
}
/// Computes the image of the open hypergraph `f` under `functor`, using only the
/// functor's `map_object` and `map_operation` methods.
///
/// Returns `None` when `f.hypergraph.is_strict()` is false, or when assembling the
/// result from the mapped objects and operations fails.
///
/// NOTE(review): judging by the name, this is presumably meant as a building block
/// for implementing [`Functor::map_arrow`] — confirm against callers.
pub fn try_define_map_arrow<O1: Clone, A1, O2: Clone, A2: Clone>(
    functor: &impl Functor<O1, A1, O2, A2>,
    f: &OpenHypergraph<O1, A1>,
) -> Option<OpenHypergraph<O2, A2>> {
    if f.hypergraph.is_strict() {
        // Operations are mapped before objects, matching the established call order.
        let mapped_operations = map_operations(functor, f);
        let mapped_objects = map_objects(functor, f);
        spider_map_arrow(f, &mapped_objects, mapped_operations)
    } else {
        None
    }
}
/// Computes the image of `f` under `functor` together with a witness relating the
/// nodes of `f` to nodes of the result.
///
/// The witness is an [`IndexedCoproduct`] with one segment per node of `f`: segment
/// sizes are the number of objects each node was mapped to, and the values are the
/// consecutive indices `n..2n` into the result's nodes, where `n` is the total number
/// of mapped objects.
///
/// Returns `None` when `f.hypergraph.is_strict()` is false, or when any intermediate
/// construction fails.
///
/// NOTE(review): the `n..2n` range presumably addresses the nodes contributed by the
/// identity factor of the middle composite built in `spider_map_arrow` — confirm.
pub fn map_arrow_witness<O1: Clone, A1, O2: Clone, A2: Clone>(
    functor: &impl Functor<O1, A1, O2, A2>,
    f: &OpenHypergraph<O1, A1>,
) -> Option<(OpenHypergraph<O2, A2>, IndexedCoproduct<FiniteFunction>)> {
    if !f.hypergraph.is_strict() {
        return None;
    }

    let mapped_operations = map_operations(functor, f);
    let mapped_objects = map_objects(functor, f);
    let image = spider_map_arrow(f, &mapped_objects, mapped_operations)?;

    // Number of output objects produced per input node, and their total.
    let segment_lengths: Vec<usize> = mapped_objects.iter().map(Vec::len).collect();
    let total: usize = segment_lengths.iter().sum();

    let values = FiniteFunction::new(
        VecArray((total..2 * total).collect()),
        image.hypergraph.nodes.len(),
    )?;
    // Each segment length is at most `total`, hence the codomain `total + 1`.
    let sizes = FiniteFunction::new(VecArray(segment_lengths), total + 1)?;

    IndexedCoproduct::new(sizes, values).map(|witness| (image, witness))
}
/// Assembles the image of `f` from its mapped objects `fw` (one list per node of `f`)
/// and the tensor `fx` of its mapped operations.
///
/// The result is the lax composite `sx ; (id ⊗ fx) ; yt`, where
/// - `sx` is a spider from the mapped sources of `f` to the flattened mapped objects
///   followed by the mapped sources of every edge, and
/// - `yt` is a spider from the flattened mapped objects followed by the mapped
///   targets of every edge to the mapped targets of `f`.
///
/// Returns `None` if any half-spider, spider, or composition step fails.
fn spider_map_arrow<O1, A1, O2: Clone, A2: Clone>(
    f: &OpenHypergraph<O1, A1>,
    fw: &[Vec<O2>],
    fx: OpenHypergraph<O2, A2>,
) -> Option<OpenHypergraph<O2, A2>> {
    // All mapped objects, concatenated in node order.
    let labels: Vec<O2> = fw.iter().flatten().cloned().collect();
    let identity_map = FiniteFunction::identity(labels.len());

    // Boundary of `f` expressed over the mapped objects.
    let boundary_sources = map_half_spider(fw, &f.sources)?;
    let boundary_targets = map_half_spider(fw, &f.targets)?;

    // Concatenated source and target node ids of every hyperedge, in edge order.
    let mut edge_source_ids: Vec<NodeId> = Vec::new();
    let mut edge_target_ids: Vec<NodeId> = Vec::new();
    for adj in f.hypergraph.adjacency.iter() {
        edge_source_ids.extend(adj.sources.iter().copied());
        edge_target_ids.extend(adj.targets.iter().copied());
    }
    let edge_sources = map_half_spider(fw, &edge_source_ids)?;
    let edge_targets = map_half_spider(fw, &edge_target_ids)?;

    let wires = OpenHypergraph::<O2, A2>::identity(labels.clone());
    let sx = OpenHypergraph::<O2, A2>::spider(
        boundary_sources,
        (&identity_map + &edge_sources)?,
        labels.clone(),
    )?;
    let yt = OpenHypergraph::<O2, A2>::spider(
        (&identity_map + &edge_targets)?,
        boundary_targets,
        labels,
    )?;

    let middle = wires.tensor(&fx);
    let front = sx.lax_compose(&middle)?;
    front.lax_compose(&yt)
}
/// Translates a list of node ids into a finite function over the flattened mapped
/// objects.
///
/// `fw` holds one list of mapped objects per node. This builds a finite function of
/// per-node list lengths and a finite function of the requested node indices (with
/// codomain `fw.len()`), and returns the result of `injections` applied to them.
///
/// Returns `None` if either finite function cannot be constructed or `injections`
/// fails.
///
/// NOTE(review): construction presumably fails when a node id is not below
/// `fw.len()` — confirm against `FiniteFunction::new`.
fn map_half_spider<O>(fw: &[Vec<O>], node_ids: &[NodeId]) -> Option<FiniteFunction> {
    let lengths: Vec<usize> = fw.iter().map(Vec::len).collect();
    let total: usize = lengths.iter().sum();
    // Each length is at most `total`, hence the codomain `total + 1`.
    let segment_sizes = FiniteFunction::new(VecArray(lengths), total + 1)?;

    let indices: Vec<usize> = node_ids.iter().map(|id| id.0).collect();
    let selection = FiniteFunction::new(VecArray(indices), fw.len())?;

    segment_sizes.injections(&selection)
}
fn map_operations<O1: Clone, A1, O2: Clone, A2: Clone>(
functor: &impl Functor<O1, A1, O2, A2>,
f: &OpenHypergraph<O1, A1>,
) -> OpenHypergraph<O2, A2> {
let mut result = OpenHypergraph::empty();
for (i, a) in f.hypergraph.edges.iter().enumerate() {
let source: Vec<O1> = f.hypergraph.adjacency[i]
.sources
.iter()
.map(|nid| f.hypergraph.nodes[nid.0].clone())
.collect();
let target: Vec<O1> = f.hypergraph.adjacency[i]
.targets
.iter()
.map(|nid| f.hypergraph.nodes[nid.0].clone())
.collect();
result.tensor_assign(functor.map_operation(a, &source, &target));
}
result
}
/// Maps every node label of `f` through `functor.map_object`, collecting the objects
/// produced for each node into its own vector.
///
/// The outer vector has one entry per node of `f`, in node order.
fn map_objects<O1, A1, O2, A2>(
    functor: &impl Functor<O1, A1, O2, A2>,
    f: &OpenHypergraph<O1, A1>,
) -> Vec<Vec<O2>> {
    let nodes = &f.hypergraph.nodes;
    let mut mapped: Vec<Vec<O2>> = Vec::with_capacity(nodes.len());
    for label in nodes.iter() {
        mapped.push(functor.map_object(label).collect());
    }
    mapped
}