use std::marker::PhantomData;
use hugr_core::Hugr;
use relrc::EquivalenceResolver;
use wyhash::wyhash;
use crate::state_space::CommitData;
/// Convenience bound bundling everything required of a commit resolver: it
/// must be cloneable, default-constructible, and able to resolve equivalence
/// between [`CommitData`] values whose edges carry `()` data.
///
/// Every type meeting these bounds implements this trait automatically.
pub trait Resolver
where
    Self: Clone + Default + EquivalenceResolver<CommitData, ()>,
{
}

impl<R> Resolver for R where R: Clone + Default + EquivalenceResolver<CommitData, ()> {}
/// Resolver that treats two values as equivalent only when they are the very
/// same object in memory (pointer identity); it carries no state.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
#[derive(serde::Serialize, serde::Deserialize)]
pub struct PointerEqResolver;
impl<N, E: Clone> EquivalenceResolver<N, E> for PointerEqResolver {
    type MergeMapping = ();
    type DedupKey = *const N;

    fn id(&self) -> String {
        String::from("PointerEqResolver")
    }

    /// Keys a value by its address, so only the same object shares a key.
    fn dedup_key(&self, value: &N, _incoming_edges: &[&E]) -> Self::DedupKey {
        // `&N` coerces implicitly to `*const N`.
        let address: *const N = value;
        address
    }

    /// Succeeds exactly when both references point at the same object.
    fn try_merge_mapping(
        &self,
        a_value: &N,
        _a_incoming_edges: &[&E],
        b_value: &N,
        _b_incoming_edges: &[&E],
    ) -> Result<Self::MergeMapping, relrc::resolver::NotEquivalent> {
        let same_object = std::ptr::eq(a_value, b_value);
        same_object
            .then_some(())
            .ok_or(relrc::resolver::NotEquivalent)
    }

    /// The merge mapping is trivial, so the edge is returned as an unchanged copy.
    fn move_edge_source(&self, _mapping: &Self::MergeMapping, edge: &E) -> E {
        E::clone(edge)
    }
}
/// Resolver that identifies commits by the JSON serialization of their data,
/// after converting it with `CommitData::into_serial::<H>`.
///
/// `H` is only a type-level marker (stored as `PhantomData`). The standard
/// traits are therefore implemented by hand rather than derived:
/// `#[derive(Clone, Debug, PartialEq, Eq)]` would add `H: Clone`, `H: Debug`,
/// etc. bounds even though no `H` value is ever stored. Serde's derives
/// special-case `PhantomData` and add no bounds, so they are kept.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct SerdeHashResolver<H>(#[serde(skip)] PhantomData<H>);

impl<H> Default for SerdeHashResolver<H> {
    fn default() -> Self {
        Self(PhantomData)
    }
}

impl<H> Clone for SerdeHashResolver<H> {
    fn clone(&self) -> Self {
        Self(PhantomData)
    }
}

impl<H> std::fmt::Debug for SerdeHashResolver<H> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Same output as the previous derived impl.
        f.debug_tuple("SerdeHashResolver").field(&self.0).finish()
    }
}

impl<H> PartialEq for SerdeHashResolver<H> {
    /// All instances are interchangeable: the type holds no data.
    fn eq(&self, _other: &Self) -> bool {
        true
    }
}

impl<H> Eq for SerdeHashResolver<H> {}
impl<H> SerdeHashResolver<H> {
    /// Hashes the JSON serialization of `value` with wyhash.
    ///
    /// # Panics
    ///
    /// Panics if `value` cannot be serialized to JSON (i.e. `serde_json::to_vec`
    /// returns an error).
    fn hash(value: &impl serde::Serialize) -> u64 {
        // Fixed seed: identical serialized bytes always produce the same key.
        const SEED: u64 = 0;
        let bytes = serde_json::to_vec(value)
            .expect("SerdeHashResolver: value could not be serialized to JSON");
        wyhash(&bytes, SEED)
    }
}
impl<H: serde::Serialize + From<Hugr>> EquivalenceResolver<CommitData, ()>
for SerdeHashResolver<H>
{
type MergeMapping = ();
type DedupKey = u64;
fn id(&self) -> String {
"SerdeHashResolver".to_string()
}
fn dedup_key(&self, value: &CommitData, _incoming_edges: &[&()]) -> Self::DedupKey {
let ser_value = value.clone().into_serial::<H>();
Self::hash(&ser_value)
}
fn try_merge_mapping(
&self,
a_value: &CommitData,
_a_incoming_edges: &[&()],
b_value: &CommitData,
_b_incoming_edges: &[&()],
) -> Result<Self::MergeMapping, relrc::resolver::NotEquivalent> {
let a_ser_value = a_value.clone().into_serial::<H>();
let b_ser_value = b_value.clone().into_serial::<H>();
if Self::hash(&a_ser_value) == Self::hash(&b_ser_value) {
Ok(())
} else {
Err(relrc::resolver::NotEquivalent)
}
}
fn move_edge_source(&self, _mapping: &Self::MergeMapping, _edge: &()) {}
}
#[cfg(test)]
mod tests {
    use hugr_core::{builder::endo_sig, ops::FuncDefn};

    use super::*;
    use crate::{CommitData, tests::WrappedHugr};

    /// A clone of some commit data must merge with the original, while
    /// commit data built from a different Hugr must be rejected.
    #[test]
    fn test_serde_hash_resolver_equality() {
        let resolver = SerdeHashResolver::<WrappedHugr>::default();

        let empty = CommitData::Base(Hugr::new());
        let empty_copy = empty.clone();
        assert!(resolver
            .try_merge_mapping(&empty, &[], &empty_copy, &[])
            .is_ok());

        let func_hugr =
            Hugr::new_with_entrypoint(FuncDefn::new("dummy", endo_sig(vec![]))).unwrap();
        let with_func = CommitData::Base(func_hugr);
        assert!(resolver
            .try_merge_mapping(&empty, &[], &with_func, &[])
            .is_err());
    }
}