use std::collections::{BTreeMap, HashMap};
use derive_more::derive::{From, Into};
use hugr_core::{
IncomingPort, Node, OutgoingPort, SimpleReplacement,
builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig, inout_sig},
extension::prelude::bool_t,
hugr::{Hugr, HugrView, patch::Patch, views::SiblingSubgraph},
ops::handle::NodeHandle,
std_extensions::logic::LogicOp,
};
use rstest::*;
use crate::{
Commit, CommitStateSpace, PatchNode, PersistentHugr, PersistentReplacement,
state_space::CommitId,
};
/// Builds a DFG computing `AND(NOT b0, NOT b1)` over two boolean inputs.
///
/// Returns the HUGR together with its `[not0, not1, and]` nodes, in that order.
fn simple_hugr() -> (Hugr, [Node; 3]) {
    let signature = inout_sig(vec![bool_t(), bool_t()], vec![bool_t()]);
    let mut builder = DFGBuilder::new(signature).unwrap();
    let [lhs, rhs] = builder.input_wires_arr();

    // Negate each input independently.
    let not_lhs = builder.add_dataflow_op(LogicOp::Not, vec![lhs]).unwrap();
    let not_rhs = builder.add_dataflow_op(LogicOp::Not, vec![rhs]).unwrap();
    let [lhs_inv] = not_lhs.outputs_arr();
    let [rhs_inv] = not_rhs.outputs_arr();

    // Combine the negated inputs.
    let and_op = builder
        .add_dataflow_op(LogicOp::And, vec![lhs_inv, rhs_inv])
        .unwrap();

    let nodes = [not_lhs.node(), not_rhs.node(), and_op.node()];
    let hugr = builder.finish_hugr_with_outputs(and_op.outputs()).unwrap();
    (hugr, nodes)
}
/// Builds a replacement that substitutes the single node `node_to_replace` of `hugr`
/// (a NOT gate at every call site in this file) with a DFG of two NOT gates in series.
///
/// # Panics
/// Panics if building the replacement DFG, forming the one-node subgraph, or
/// constructing the `SimpleReplacement` fails.
fn create_double_not_replacement(hugr: &Hugr, node_to_replace: Node) -> SimpleReplacement {
    let mut builder = DFGBuilder::new(inout_sig(vec![bool_t()], vec![bool_t()])).unwrap();
    let [input] = builder.input_wires_arr();
    let first_not = builder.add_dataflow_op(LogicOp::Not, vec![input]).unwrap();
    let [between] = first_not.outputs_arr();
    let second_not = builder.add_dataflow_op(LogicOp::Not, vec![between]).unwrap();
    let [output] = second_not.outputs_arr();
    let replacement_hugr = builder.finish_hugr_with_outputs([output]).unwrap();

    // NOTE(review): these boundary maps are built but never handed to
    // `SimpleReplacement::try_new` below; they look like leftovers from an older
    // constructor that took explicit input/output maps.
    let _nu_inp = HashMap::from([(
        (first_not.node(), IncomingPort::from(0)),
        (node_to_replace, IncomingPort::from(0)),
    )]);
    let _nu_out = HashMap::from([(
        (node_to_replace, OutgoingPort::from(0)),
        IncomingPort::from(0),
    )]);

    let subgraph = SiblingSubgraph::try_from_nodes(vec![node_to_replace], hugr).unwrap();
    SimpleReplacement::try_new(subgraph, hugr, replacement_hugr).unwrap()
}
/// Builds a replacement that substitutes an AND gate of `hugr`, together with one of
/// its input neighbours, by a single XOR gate.
///
/// The AND gate is the first node of `hugr` whose op is `LogicOp::And`. The other
/// replaced node is the first node yielded by `input_neighbours` of that AND. In every
/// HUGR this file passes in, that node is a NOT gate.
///
/// # Panics
/// Panics if `hugr` has no AND gate, if that gate has no input neighbours, or if building
/// the replacement fails.
fn create_not_and_to_xor_replacement(hugr: &Hugr) -> SimpleReplacement {
    let and_gate = hugr
        .nodes()
        .find(|&n| hugr.get_optype(n) == &LogicOp::And.into())
        .expect("hugr should contain an AND gate");
    let not_node = hugr
        .input_neighbours(and_gate)
        .next()
        .expect("the AND gate should have an input neighbour");

    let mut builder =
        DFGBuilder::new(inout_sig(vec![bool_t(), bool_t()], vec![bool_t()])).unwrap();
    let [in1, in2] = builder.input_wires_arr();
    let xor_op = builder
        .add_dataflow_op(LogicOp::Xor, vec![in1, in2])
        .unwrap();
    let replacement_hugr = builder
        .finish_hugr_with_outputs(xor_op.outputs())
        .unwrap();

    // `hugr` is already a `&Hugr`, so it is passed through without an extra borrow.
    let subgraph = SiblingSubgraph::try_from_nodes(vec![not_node, and_gate], hugr).unwrap();
    SimpleReplacement::try_new(subgraph, hugr, replacement_hugr).unwrap()
}
/// A [`CommitStateSpace`] bundled with commits created in it, for use as a test fixture.
///
/// The commits are stored with their lifetime erased to `'static` (see
/// [`TestStateSpace::new`]), so the state space they belong to is stored alongside them.
pub struct TestStateSpace {
    // Declared before `state_space` on purpose: Rust drops struct fields in declaration
    // order, so the lifetime-extended commits are always dropped before the state space
    // they were created in.
    commits: Vec<Commit<'static>>,
    // Only read through `state_space()`, which the tests do not currently call.
    #[allow(dead_code)]
    state_space: CommitStateSpace,
}

impl TestStateSpace {
    /// Bundles `state_space` with `commits`, erasing the commits' lifetime.
    ///
    /// # Panics
    /// Panics if any commit does not belong to `state_space`.
    fn new(state_space: CommitStateSpace, commits: Vec<Commit<'_>>) -> Self {
        assert!(
            commits.iter().all(|c| c.state_space() == state_space),
            "all commits must belong to the given state space"
        );
        let commits = commits
            .into_iter()
            // SAFETY (NOTE(review): confirm against `Commit::upgrade_lifetime`'s contract):
            // every commit was checked above to belong to `state_space`, which is stored
            // in the returned value and, by field order, dropped only after the commits.
            .map(|c| unsafe { c.upgrade_lifetime() })
            .collect();
        Self {
            state_space,
            commits,
        }
    }

    /// Returns the state space the stored commits belong to.
    #[allow(dead_code)]
    pub fn state_space(&self) -> &CommitStateSpace {
        &self.state_space
    }

    /// Returns the stored commits as a fixed-size array of length `N`.
    ///
    /// # Panics
    /// Panics if `N` differs from the number of stored commits.
    pub fn commits<const N: usize>(&self) -> &[Commit<'_>; N] {
        TryFrom::try_from(self.commits.as_slice()).unwrap_or_else(|_| {
            panic!(
                "expected {} commits, but the test state space holds {}",
                N,
                self.commits.len()
            )
        })
    }
}
/// Fixture: a commit state space whose base is `simple_hugr()`, together with four
/// commits built on top of it.
///
/// The commits are returned in the order `[commit1, commit2, commit3, commit4]`:
/// - `commit1` replaces `not0` with two NOT gates in series.
/// - `commit2` replaces the AND gate and one of its input neighbours by an XOR. It is
///   computed on the HUGR obtained by applying `commit1`'s replacement, then mapped onto
///   patch nodes of `commit1` and the base.
/// - `commit3` is the same AND-to-XOR rewrite, but computed on the base HUGR and mapped
///   onto base patch nodes only.
/// - `commit4` replaces `not1` with two NOT gates in series.
#[fixture]
pub(crate) fn test_state_space() -> TestStateSpace {
let (base_hugr, [not0_node, not1_node, _and_node]) = simple_hugr();
let state_space = CommitStateSpace::new();
let base = state_space.try_set_base(base_hugr).unwrap();
// Replacement 1, expressed on the plain base HUGR: `not0` -> two NOTs in series.
let replacement1 = create_double_not_replacement(base.commit_hugr(), not0_node);
// commit1: replacement 1 re-targeted at patch nodes of the base commit.
let commit1 = {
let new_host = PersistentHugr::try_new([base.clone()]).unwrap();
let replacement1 = replacement1
.map_host_nodes(|n| base.to_patch_node(n), &new_host)
.unwrap();
Commit::try_from_replacement(replacement1, &state_space).unwrap()
};
// commit2: AND-to-XOR rewrite computed on the HUGR obtained by applying replacement 1
// directly, then re-targeted at patch nodes of commit1 / the base.
let commit2 = {
let mut direct_hugr = base.commit_hugr().clone();
let node_map = replacement1
.clone()
.apply(&mut direct_hugr)
.unwrap()
.node_map;
let replacement2 = create_not_and_to_xor_replacement(&direct_hugr);
// Invert `node_map` (presumably replacement node -> node inserted into
// `direct_hugr`), so that inserted nodes can be traced back to commit1.
let inv_node_map = {
let mut inv = BTreeMap::new();
for (repl_node, hugr_node) in node_map {
inv.insert(hugr_node, repl_node);
}
inv
};
// Nodes inserted by replacement 1 map to commit1's patch nodes; every other node
// is mapped as a base node. NOTE(review): this assumes nodes surviving from the
// base keep their indices in `direct_hugr` after `apply` — confirm.
let to_patch_node = {
|n| {
if let Some(&n) = inv_node_map.get(&n) {
commit1.to_patch_node(n)
} else {
base.to_patch_node(n)
}
}
};
let new_host = PersistentHugr::try_new([commit1.clone()]).unwrap();
let replacement2 = replacement2
.map_host_nodes(to_patch_node, &new_host)
.unwrap();
Commit::try_from_replacement(replacement2, &state_space).unwrap()
};
// commit3: the AND-to-XOR rewrite computed on the base HUGR, referring only to base
// nodes. The tests expect it to conflict with commit1 (presumably because both
// replace `not0`).
// NOTE(review): the host passed to `map_host_nodes` is built from commit1 even though
// only base nodes are mapped — confirm this is intended.
let commit3 = {
let replacement3 = create_not_and_to_xor_replacement(base.commit_hugr());
let new_host = PersistentHugr::try_new([commit1.clone()]).unwrap();
let replacement3 = replacement3
.map_host_nodes(|n| base.to_patch_node(n), &new_host)
.unwrap();
Commit::try_from_replacement(replacement3, &state_space).unwrap()
};
// commit4: `not1` -> two NOTs in series, referring only to base nodes.
// NOTE(review): as for commit3, the host is built from commit1 — confirm intended.
let commit4 = {
let replacement4 = create_double_not_replacement(base.commit_hugr(), not1_node);
let new_host = PersistentHugr::try_new([commit1.clone()]).unwrap();
let replacement4 = replacement4
.map_host_nodes(|n| base.to_patch_node(n), &new_host)
.unwrap();
Commit::try_from_replacement(replacement4, &state_space).unwrap()
};
// The commits were created with `&state_space`. Presumably a clone is passed here
// because those borrows are still live. `TestStateSpace::new` asserts that the
// clone compares equal to each commit's state space.
TestStateSpace::new(
state_space.clone(),
vec![commit1, commit2, commit3, commit4],
)
}
#[fixture]
pub(super) fn persistent_hugr_empty_child() -> (PersistentHugr, [CommitId; 2], [PatchNode; 3]) {
let (triple_not_hugr, not_nodes) = {
let mut dfg_builder = DFGBuilder::new(endo_sig([bool_t()])).unwrap();
let [mut w] = dfg_builder.input_wires_arr();
let not_nodes = [(); 3].map(|()| {
let handle = dfg_builder.add_dataflow_op(LogicOp::Not, vec![w]).unwrap();
[w] = handle.outputs_arr();
handle.node()
});
(
dfg_builder.finish_hugr_with_outputs([w]).unwrap(),
not_nodes,
)
};
let mut hugr = PersistentHugr::with_base(triple_not_hugr);
let empty_hugr = {
let dfg_builder = DFGBuilder::new(endo_sig([bool_t()])).unwrap();
let inputs = dfg_builder.input_wires();
dfg_builder.finish_hugr_with_outputs(inputs).unwrap()
};
let subg_nodes = [PatchNode(hugr.base(), not_nodes[1])];
let repl = PersistentReplacement::try_new(
SiblingSubgraph::try_from_nodes(subg_nodes, &hugr).unwrap(),
&hugr,
empty_hugr,
)
.unwrap();
let empty_commit = hugr.try_add_replacement(repl).unwrap();
let base_commit = hugr.base();
(
hugr,
[base_commit, empty_commit],
not_nodes.map(|n| PatchNode(base_commit, n)),
)
}
/// Applying commit1 and then commit2 (which is built on top of commit1) in a persistent
/// HUGR must give the same graph as applying the equivalent replacements in place on a
/// plain `Hugr`.
#[rstest]
fn test_successive_replacements(test_state_space: TestStateSpace) {
    let [commit1, commit2, _commit3, _commit4] = test_state_space.commits();

    // Reference result: apply both replacements directly to a plain `Hugr`.
    // `Patch::apply` consumes its replacement and neither is reused, so no clone is needed.
    let (mut hugr, [not0_node, _not1_node, _and_node]) = simple_hugr();
    let replacement1 = create_double_not_replacement(&hugr, not0_node);
    replacement1.apply(&mut hugr).unwrap();
    let replacement2 = create_not_and_to_xor_replacement(&hugr);
    replacement2.apply(&mut hugr).unwrap();

    let persistent_hugr = PersistentHugr::try_new([commit1.clone(), commit2.clone()])
        .expect("commit1 and commit2 are compatible");
    let persistent_final_hugr = persistent_hugr.to_hugr();

    assert_eq!(persistent_hugr.all_commit_ids().count(), 3);
    assert_eq!(hugr.validate(), Ok(()));
    assert_eq!(persistent_final_hugr.validate(), Ok(()));
    assert_eq!(
        hugr.mermaid_string(),
        persistent_final_hugr.mermaid_string()
    );
}
#[rstest]
fn test_conflicting_replacements(test_state_space: TestStateSpace) {
let [commit1, _commit2, commit3, _commit4] = test_state_space.commits();
let state_space = commit1.state_space();
let (hugr, [not0_node, _not1_node, _and_node]) = simple_hugr();
let hugr1 = {
let mut hugr = hugr.clone();
let replacement1 = create_double_not_replacement(&hugr, not0_node);
replacement1.apply(&mut hugr).unwrap();
hugr
};
let hugr2 = {
let mut hugr = hugr.clone();
let replacement2 = create_not_and_to_xor_replacement(&hugr);
replacement2.apply(&mut hugr).unwrap();
hugr
};
let persistent_hugr1 = PersistentHugr::try_new([commit1.clone()]).unwrap();
let persistent_hugr2 = PersistentHugr::try_new([commit3.clone()]).unwrap();
assert_eq!(persistent_hugr1.to_hugr().validate(), Ok(()));
assert_eq!(persistent_hugr2.to_hugr().validate(), Ok(()));
let result = state_space.try_create(
persistent_hugr1
.all_commit_ids()
.chain(persistent_hugr2.all_commit_ids()),
);
assert!(
result.is_err(),
"Creating history with conflicting patches should fail"
);
assert_eq!(
hugr1.mermaid_string(),
persistent_hugr1.to_hugr().mermaid_string()
);
assert_eq!(
hugr2.mermaid_string(),
persistent_hugr2.to_hugr().mermaid_string()
);
}
/// Two replacements on disjoint nodes combine in one history: commit1 rewrites `not0`
/// and commit4 rewrites `not1`. The result must match applying both directly to a
/// plain `Hugr`.
#[rstest]
fn test_disjoint_replacements(test_state_space: TestStateSpace) {
    let [commit1, _commit2, _commit3, commit4] = test_state_space.commits();

    // Both replacements are built against the unmodified `hugr` before either is applied.
    let (mut hugr, [not0_node, not1_node, _and_node]) = simple_hugr();
    let replacement1 = create_double_not_replacement(&hugr, not0_node);
    let replacement2 = create_double_not_replacement(&hugr, not1_node);
    // `Patch::apply` consumes each replacement and neither is reused, so no clone is needed.
    replacement1.apply(&mut hugr).unwrap();
    replacement2.apply(&mut hugr).unwrap();

    let persistent_hugr = PersistentHugr::try_new([commit1.clone(), commit4.clone()]).unwrap();
    let persistent_final_hugr = persistent_hugr.to_hugr();

    assert_eq!(hugr.validate(), Ok(()));
    assert_eq!(persistent_final_hugr.validate(), Ok(()));
    assert_eq!(persistent_hugr.all_commit_ids().count(), 3);
    assert_eq!(
        hugr.mermaid_string(),
        persistent_final_hugr.mermaid_string()
    );
}
/// Adds replacements to the `[commit1, commit2]` history.
///
/// commit4's replacement is accepted and yields the same HUGR as building the history
/// with commit4 from the start. commit3's replacement is rejected.
#[rstest]
fn test_try_add_replacement(test_state_space: TestStateSpace) {
    let [commit1, commit2, commit3, commit4] = test_state_space.commits();
    let history = PersistentHugr::try_new([commit1.clone(), commit2.clone()]).unwrap();

    // Compatible extension with commit4's replacement.
    let mut extended = history.clone();
    let result = extended.try_add_replacement(commit4.replacement().unwrap().clone());
    assert!(
        result.is_ok(),
        "[commit1, commit2] + [commit4] are compatible. Got {result:?}"
    );
    let expected = PersistentHugr::try_new([commit1.clone(), commit2.clone(), commit4.clone()])
        .unwrap()
        .to_hugr();
    assert_eq!(extended.to_hugr().mermaid_string(), expected.mermaid_string());

    // Conflicting extension with commit3's replacement.
    let mut conflicting = history;
    let result = conflicting.try_add_replacement(commit3.replacement().unwrap().clone());
    assert!(
        result.is_err(),
        "[commit1, commit2] + [commit3] are incompatible. Got {result:?}"
    );
}
/// Re-creates commits from the replacements of commit4 and commit3, then adds them to
/// the `[commit1, commit2]` history.
///
/// The commit built from commit4's replacement is accepted and inserts two nodes. The
/// one built from commit3's replacement is rejected.
#[rstest]
fn test_try_add_commit(test_state_space: TestStateSpace) {
    let [commit1, commit2, commit3, commit4] = test_state_space.commits();
    let state_space = commit1.state_space();
    let history = PersistentHugr::try_new([commit1.clone(), commit2.clone()]).unwrap();

    {
        let mut extended = history.clone();
        let fresh_commit =
            Commit::try_from_replacement(commit4.replacement().unwrap().clone(), &state_space)
                .unwrap();
        let commit_id = extended
            .try_add_commit(fresh_commit)
            .expect("commit4 is compatible");
        assert_eq!(extended.get_commit(commit_id).inserted_nodes().count(), 2);
    }

    {
        let mut conflicting = history;
        let fresh_commit =
            Commit::try_from_replacement(commit3.replacement().unwrap().clone(), &state_space)
                .unwrap();
        conflicting
            .try_add_commit(fresh_commit)
            .expect_err("commit3 is incompatible with [commit1, commit2]");
    }
}
/// Serde-serializable newtype around a [`Hugr`].
///
/// The `From`/`Into` derives (from `derive_more`) convert to and from a plain `Hugr`.
/// The `hugr` field is (de)serialized by the `serial` module in this file, which writes
/// it as a text envelope string.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, From, Into)]
pub(crate) struct WrappedHugr {
// Routed through `serial::serialize` / `serial::deserialize`.
#[serde(with = "serial")]
pub hugr: Hugr,
}
/// Serde `with`-module that stores a [`Hugr`] as a text envelope string.
mod serial {
    use hugr_core::envelope::EnvelopeConfig;
    use hugr_core::std_extensions::STD_REG;
    use serde::Deserialize;

    use super::*;

    /// Serializes `hugr` as a text envelope string with its `"encoder"` version field
    /// removed. Presumably this keeps serialized output stable across `hugr-rs` releases.
    ///
    /// # Errors
    /// Returns the serializer's custom error if the envelope cannot be written.
    pub(crate) fn serialize<S>(hugr: &Hugr, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut envelope = hugr
            .store_str(EnvelopeConfig::text())
            .map_err(serde::ser::Error::custom)?;
        remove_encoder_version(&mut envelope);
        serializer.serialize_str(&envelope)
    }

    /// Removes the first `"encoder":"hugr-rs v<version>"` member from the JSON text in
    /// `json`, together with one adjacent comma so the enclosing object stays well-formed.
    ///
    /// The text is left unchanged in three cases: no such member exists, its string value
    /// is unterminated, or it has no neighbouring comma.
    fn remove_encoder_version(json: &mut String) {
        const ENCODER_PATTERN: &str = r#""encoder":"hugr-rs v"#;
        let Some(start) = json.find(ENCODER_PATTERN) else {
            return;
        };
        let version_start = start + ENCODER_PATTERN.len();
        // Stop at the quote closing this value. Searching further (e.g. for the next
        // `","` anywhere in the document) could delete unrelated members when the encoder
        // is the last member of its object.
        let Some(version_len) = json[version_start..].find('"') else {
            return;
        };
        // Exclusive end of the member: just past its closing quote.
        let member_end = version_start + version_len + 1;
        if json[member_end..].starts_with(',') {
            // `"encoder":"…",` followed by another member: drop the trailing comma too.
            json.replace_range(start..=member_end, "");
        } else if json[..start].ends_with(',') {
            // `…,"encoder":"…"` as the last member: drop the preceding comma instead.
            json.replace_range(start - 1..member_end, "");
        }
    }

    /// Deserializes a [`Hugr`] from a text envelope string, passing `STD_REG` as the
    /// extension registry.
    ///
    /// # Errors
    /// Returns the deserializer's custom error if the input is not a string or the
    /// envelope cannot be loaded.
    pub(crate) fn deserialize<'de, D>(deserializer: D) -> Result<Hugr, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let str = String::deserialize(deserializer)?;
        Hugr::load_str(str, Some(&STD_REG)).map_err(serde::de::Error::custom)
    }
}