use open_hypergraphs::lax::functor::Functor;
use open_hypergraphs::lax::*;
use open_hypergraphs::strict::vec::FiniteFunction;
use std::fmt::Debug;
use thiserror::Error;
use crate::ssa::{SSAError, ssa};
use crate::theory::Theory;
use crate::tree::*;
use crate::{dual, dual::Dual};
use hexpr::Operation;
/// Errors returned by [`check`], [`eval_type`] and [`eval_type_with`].
#[derive(Debug, Error)]
pub enum Error<O> {
    /// The SSA decomposition of the (strictified) type term failed.
    #[error("SSA decomposition failed")]
    SSAError(#[from] SSAError),
    /// In [`check`], lax-composing source, type map and target returned `None`.
    #[error("Type maps had invalid arity/coarity")]
    InvalidTypeMaps,
    /// Evaluation stopped early; carries the cause and the node assignment reached so far.
    #[error("Error during type map evaluation {0:?}")]
    PartialResult(#[from] PartialResult<O>),
    /// A `quotient()` call failed; carries the `FiniteFunction` it returned as its error.
    #[error("Unable to quotient type map {0:?}")]
    InvalidQuotient(FiniteFunction),
}
/// An evaluation failure together with the partial node assignment computed before it.
///
/// NOTE(review): this `#[derive(Error)]` has no `#[error("...")]` attribute, so thiserror does
/// not generate a `Display` impl for it. One must be provided elsewhere for the generated
/// `std::error::Error` impl to compile — confirm. Also, `cause` is neither named `source` nor
/// marked `#[source]`, so it is not reported by `std::error::Error::source` — confirm intent.
#[derive(Debug, Error)]
pub struct PartialResult<O> {
    /// Value of each node (indexed by node index) at the point of failure;
    /// `None` means the node had not been assigned yet.
    pub partial_result: Vec<Option<Tree<(), O>>>,
    /// The error that stopped evaluation.
    pub cause: EvalError,
}
/// Reasons a single evaluation step in [`eval_type_with`] can fail.
#[derive(Debug, Error)]
pub enum EvalError {
    /// A node was assigned two different trees. Holds the `Debug` renderings of the
    /// existing value and of the rejected new value (produced by [`merge`]).
    #[error("Could not merge values {0} and {1}")]
    MergeError(String, String),
    /// A reverse (`Dual::Rev`) step could not pop its operation off one of its inputs.
    /// Holds the offending edge and a `Debug` description of the operation, possibly
    /// followed by an explanation such as "(children didn't match)".
    #[error("Could not pop symbol {1} of {0:?}")]
    MatchError(EdgeId, String),
}
pub fn check(
theory: &Theory,
source: OpenHypergraph<(), Operation>,
target: OpenHypergraph<(), Operation>,
arrow: &mut OpenHypergraph<(), Operation>,
) -> Result<Vec<Tree<(), Operation>>, Error<Operation>> {
let mut fwd = dual::into_fwd(source);
let mut rev = dual::into_rev(target);
fwd.quotient().map_err(Error::InvalidQuotient)?;
rev.quotient().map_err(Error::InvalidQuotient)?;
arrow.quotient().map_err(Error::InvalidQuotient)?;
let type_map = AsType(theory).map_arrow(arrow);
let mut type_term = fwd
.lax_compose(&type_map)
.and_then(|f| f.lax_compose(&rev))
.ok_or(Error::<Operation>::InvalidTypeMaps)?;
let q = type_term.quotient().map_err(Error::InvalidQuotient)?;
let offset = fwd.hypergraph.nodes.len();
let size = arrow.hypergraph.nodes.len();
let indices = (offset..offset + size).map(|i| q.table[i]);
let results = eval_type(type_term)?;
Ok(indices.map(|i| results[i].clone()).collect())
}
/// Evaluate the type term `f` starting with no node values known.
///
/// Equivalent to [`eval_type_with`] where every node of `f` starts unassigned (`None`).
///
/// # Errors
/// Same as [`eval_type_with`].
pub fn eval_type<O: Clone + Eq + Debug + std::fmt::Display>(
    f: OpenHypergraph<(), Dual<O>>,
) -> Result<Vec<Tree<(), O>>, Error<O>> {
    let unassigned: Vec<Option<Tree<(), O>>> =
        (0..f.hypergraph.nodes.len()).map(|_| None).collect();
    eval_type_with(f, unassigned)
}
/// Evaluate the type term `f`, starting from the node assignment `state`.
///
/// `state` is indexed by node index; `None` marks a node whose value is not yet known.
/// The operations of `f` are visited in the order produced by its SSA decomposition:
/// - `Dual::Fwd(arr)`: output `i` is assigned `Tree::Node(arr, i, inputs)`.
/// - `Dual::Rev(op)`: every input `i` must be `Tree::Node(op, i, children)`, and all inputs
///   must carry identical `children`. Those children are then assigned to the outputs in
///   order. If there are no inputs, every output is assigned `Tree::Empty`.
///
/// Assignments go through [`merge`], so a node that receives two different trees is an error.
/// Inputs that are still unknown are read as `Tree::Leaf(node_index, ())`. The same fallback
/// is used for nodes left unassigned in the returned vector.
///
/// # Errors
/// - [`Error::SSAError`] if the SSA decomposition of `f` fails.
/// - [`Error::PartialResult`] wrapping [`EvalError::MergeError`] (conflicting assignment) or
///   [`EvalError::MatchError`] (a `Rev` step could not pop its inputs), together with the
///   state reached at the point of failure.
///
/// # Panics
/// If a node index produced by the SSA decomposition is out of range for `state`.
/// NOTE(review): callers presumably must pass `f.hypergraph.nodes.len()` entries — confirm.
pub fn eval_type_with<O: Clone + Eq + Debug + std::fmt::Display>(
    f: OpenHypergraph<(), Dual<O>>,
    mut state: Vec<Option<Tree<(), O>>>,
) -> Result<Vec<Tree<(), O>>, Error<O>> {
    for ssa_value in ssa(f.to_strict())? {
        // Current values of this operation's inputs; unknown nodes are read as leaves.
        let source_values: Vec<Tree<(), O>> = ssa_value
            .sources
            .into_iter()
            .map(|i| {
                state[i.0.0]
                    .clone()
                    .unwrap_or_else(|| Tree::Leaf(i.0.0, ()))
            })
            .collect();
        match ssa_value.op {
            // Forward step: output `i` becomes `Tree::Node(arr, i, inputs)`.
            Dual::Fwd(arr) => {
                for (i, node_id) in ssa_value.targets.iter().enumerate() {
                    merge(
                        &mut state[node_id.0.0],
                        Tree::Node(arr.clone(), i, source_values.clone()),
                    )
                    .map_err(|cause| PartialResult {
                        cause,
                        partial_result: state.clone(),
                    })?;
                }
            }
            // Reverse step: pop `op` off every input and hand its children to the outputs.
            Dual::Rev(op) => {
                // Children shared by all inputs; `None` until the first input is popped.
                let mut children = None;
                for (i, v) in source_values.into_iter().enumerate() {
                    match v {
                        // Input `i` must be the node built for output `i` of `op`.
                        Tree::Node(arr, j, node_children) if i == j && arr == op => {
                            children = match children {
                                None => Some(node_children),
                                Some(children) if children == node_children => Some(children),
                                // Inputs disagree about the children of `op`.
                                _ => {
                                    return Err(PartialResult {
                                        partial_result: state,
                                        cause: EvalError::MatchError(
                                            ssa_value.edge_id,
                                            format!("{op:?} (children didn't match)"),
                                        ),
                                    }
                                    .into());
                                }
                            }
                        }
                        // Anything else (a leaf, another operation, or the wrong index)
                        // cannot be popped.
                        _ => {
                            return Err(PartialResult {
                                partial_result: state,
                                cause: EvalError::MatchError(ssa_value.edge_id, format!("{op:?}")),
                            }
                            .into());
                        }
                    }
                }
                // No inputs: nothing was popped, so every output is `Tree::Empty`.
                let children =
                    children.unwrap_or_else(|| vec![Tree::Empty; ssa_value.targets.len()]);
                // NOTE(review): `zip` stops at the shorter side, so a length mismatch between
                // the popped children and `targets` is silently ignored — confirm arities
                // always agree here.
                for (node_id, child) in ssa_value.targets.iter().zip(children.into_iter()) {
                    merge(&mut state[node_id.0.0], child).map_err(|cause| PartialResult {
                        cause,
                        partial_result: state.clone(),
                    })?;
                }
            }
        };
    }
    // Nodes never assigned are returned as leaves carrying their own index.
    Ok(state
        .into_iter()
        .enumerate()
        .map(|(i, opt)| opt.unwrap_or_else(|| Tree::Leaf(i, ())))
        .collect())
}
/// Record `new` as the value of the slot `value`, checking consistency.
///
/// - An empty slot is filled with `new`.
/// - An occupied slot holding a tree equal to `new` is left untouched.
///
/// # Errors
/// [`EvalError::MergeError`] with the `Debug` renderings of the existing and the new tree
/// when the slot already holds a different tree. The slot is not modified in that case.
pub fn merge<O: Debug + Eq>(
    value: &mut Option<Tree<(), O>>,
    new: Tree<(), O>,
) -> Result<(), EvalError> {
    match value.as_ref() {
        Some(current) if *current != new => Err(EvalError::MergeError(
            format!("{:?}", current),
            format!("{:?}", new),
        )),
        Some(_) => Ok(()),
        None => {
            *value = Some(new);
            Ok(())
        }
    }
}
/// Functor that sends each operation to its type map in the wrapped [`Theory`].
///
/// Every object is sent to a single `()` object. Each operation is sent to a term built
/// from the theory's type maps for that operation (see the `Functor` impl for `AsType`).
#[derive(Clone)]
struct AsType<'a>(pub &'a Theory);
impl Functor<(), Operation, (), Dual<Operation>> for AsType<'_> {
    /// Every object is mapped to exactly one `()` object.
    fn map_object(&self, _: &()) -> impl ExactSizeIterator<Item = ()> {
        std::iter::once(())
    }

    /// Map operation `a` to its type map from the theory. The result is `into_rev` of the
    /// theory's source type map composed with `into_fwd` of its target type map.
    ///
    /// # Panics
    /// - If `a` has no arrow in the theory.
    /// - If `source` / `target` lengths differ from the number of targets of the respective
    ///   type maps.
    /// - If the composition fails.
    fn map_operation(
        &self,
        a: &Operation,
        source: &[()],
        target: &[()],
    ) -> OpenHypergraph<(), Dual<Operation>> {
        let entry = self.0.get_arrow(a).expect("missing arrow in theory");
        let (source_map, target_map) = &entry.type_maps;
        assert_eq!(source.len(), source_map.targets.len());
        assert_eq!(target.len(), target_map.targets.len());
        let reversed_source = dual::into_rev(source_map.clone());
        let forward_target = dual::into_fwd(target_map.clone());
        reversed_source.compose(&forward_target).unwrap()
    }

    /// Apply the functor to a whole hypergraph, panicking if
    /// `functor::try_define_map_arrow` fails.
    fn map_arrow(&self, f: &OpenHypergraph<(), Operation>) -> OpenHypergraph<(), Dual<Operation>> {
        functor::try_define_map_arrow(self, f).unwrap()
    }
}