use std::{
cell::RefCell,
collections::{BTreeSet, HashMap, VecDeque},
vec,
};
use hugr_core::{
Hugr, HugrView, Node,
hugr::patch::{Patch, simple_replace},
};
use itertools::Itertools;
use relrc::HistoryGraph;
use crate::{
Commit, CommitData, CommitId, CommitStateSpace, InvalidCommit, PatchNode, PersistentReplacement,
};
mod cache;
use cache::PersistentHugrCache;
pub mod serial;
/// A HUGR stored as a history of commits on top of a single base [`Hugr`].
///
/// Commits are the nodes of a [`HistoryGraph`]. Exactly one of them carries
/// [`CommitData::Base`]; its ID is kept in `base_commit_id`.
#[derive(Clone, derive_more::Debug)]
pub struct PersistentHugr {
/// History graph holding every commit of this persistent HUGR.
graph: HistoryGraph<CommitData, ()>,
/// ID of the commit whose data is `CommitData::Base` (the base HUGR).
base_commit_id: CommitId,
/// Cache of per-commit children lists used by `children_commits`.
/// Skipped in the `Debug` output.
#[debug(skip)]
cache: RefCell<PersistentHugrCache>,
}
impl PersistentHugr {
/// Create a persistent HUGR whose history consists of `hugr` as base commit.
///
/// A fresh [`CommitStateSpace`] is created and `hugr` is set as its base.
///
/// # Panics
///
/// Panics with "empty state space" if the base cannot be set on the newly
/// created state space, and with "invalid commit" if the resulting commit is
/// rejected by [`PersistentHugr::from_commit`].
pub fn with_base(hugr: Hugr) -> Self {
    // The state space must be bound to a local: the base commit is obtained
    // from it and is only consumed on the last line.
    let space = CommitStateSpace::new();
    let base_commit = space.try_set_base(hugr).expect("empty state space");
    Self::from_commit(base_commit)
}
/// Create a persistent HUGR from `commit` and all of its ancestors.
///
/// # Panics
///
/// Panics with "invalid commit" if [`PersistentHugr::try_new`] returns an
/// error for the resulting history.
pub fn from_commit(commit: Commit) -> Self {
    let built = Self::try_new(std::iter::once(commit));
    built.expect("invalid commit")
}
/// Build a persistent HUGR from `commits` and all of their ancestors.
///
/// # Errors
///
/// - [`InvalidCommit::NonUniqueBase`]`(0)` if `commits` is empty.
/// - [`InvalidCommit::NonUniqueBase`]`(2)` if some commit does not belong to
///   the same state space as the first commit of the collected history.
/// - [`InvalidCommit::IncompatibleHistory`] if two of the selected children
///   of a commit delete the same node of that commit.
/// - [`InvalidCommit::NonUniqueBase`]`(n)` if the number `n` of commits
///   holding [`CommitData::Base`] is not exactly one.
pub fn try_new<'a>(
    commits: impl IntoIterator<Item = Commit<'a>>,
) -> Result<Self, InvalidCommit> {
    let history = get_ancestors_while(commits, |_| true);
    let Some(first) = history.front() else {
        return Err(InvalidCommit::NonUniqueBase(0));
    };
    let state_space = first.state_space();
    let selected: BTreeSet<_> = history.iter().map(|c| c.as_ptr()).collect();

    let mut graph = HistoryGraph::with_registry(state_space.to_registry());
    for commit in history {
        if commit.state_space() != state_space {
            return Err(InvalidCommit::NonUniqueBase(2));
        }
        // Only children that are themselves part of the new history can
        // conflict with each other.
        let relevant_children = commit
            .children(&state_space)
            .filter(|child| selected.contains(&child.as_ptr()));
        if let Some(node) = find_conflicting_node(commit.id(), relevant_children) {
            return Err(InvalidCommit::IncompatibleHistory(commit.id(), node));
        }
        graph.insert_node(commit.into());
    }

    // The base commit is the unique node whose data is `CommitData::Base`.
    let base_commit_id = graph
        .all_node_ids()
        .filter(|&id| {
            let node = graph.get_node(id).expect("valid ID");
            matches!(node.value(), CommitData::Base(_))
        })
        .exactly_one()
        .map_err(|err| InvalidCommit::NonUniqueBase(err.count()))?;

    Ok(Self {
        graph,
        base_commit_id,
        cache: RefCell::new(PersistentHugrCache::default()),
    })
}
/// Add `replacement` as a new commit and return the ID of that commit.
///
/// # Panics
///
/// Panics with "invalid replacement" if
/// [`PersistentHugr::try_add_replacement`] returns an error.
pub fn add_replacement(&mut self, replacement: PersistentReplacement) -> CommitId {
    let outcome = self.try_add_replacement(replacement);
    outcome.expect("invalid replacement")
}
/// Try to add `replacement` as a new commit, returning the new commit's ID.
///
/// # Errors
///
/// - [`InvalidCommit::IncompatibleHistory`] if a node in the replacement's
///   subgraph is already reported by [`PersistentHugr::deleted_nodes`] for
///   the commit that owns it.
/// - Any error returned by `Commit::try_from_replacement` or
///   [`PersistentHugr::try_add_commit`].
pub fn try_add_replacement(
    &mut self,
    replacement: PersistentReplacement,
) -> Result<CommitId, InvalidCommit> {
    // Group the nodes of the replaced subgraph by the commit that owns them.
    let replaced_by_owner = replacement
        .subgraph()
        .nodes()
        .iter()
        .map(|&PatchNode(owner, node)| (owner, node))
        .into_grouping_map()
        .collect::<BTreeSet<_>>();

    for (owner, replaced) in replaced_by_owner {
        let already_deleted: BTreeSet<_> = self.deleted_nodes(owner).collect();
        let clash = replaced.intersection(&already_deleted).next().copied();
        if let Some(node) = clash {
            return Err(InvalidCommit::IncompatibleHistory(owner, node));
        }
    }

    let short_lived = Commit::try_from_replacement(replacement, self.state_space())?;
    // SAFETY (NOTE(review)): the commit is inserted into `self` right below;
    // presumably that is what makes extending its lifetime sound. Confirm
    // against the contract of `Commit::upgrade_lifetime`.
    let commit = unsafe { short_lived.upgrade_lifetime() };
    self.try_add_commit(commit)
}
/// Try to add `commit` (together with any ancestors not yet present) and
/// return its ID.
///
/// # Errors
///
/// Forwards any error returned by [`PersistentHugr::try_add_commits`].
pub fn try_add_commit(&mut self, commit: Commit) -> Result<CommitId, InvalidCommit> {
    self.try_add_commits(std::iter::once(commit.clone()))
        .map(|()| commit.id())
}
/// Try to add `commits` to the history.
///
/// The commits to process are gathered with `get_ancestors_while`, which
/// stops exploring parents at commits already contained in `self`. They are
/// then processed in reverse of the gathered order.
///
/// # Errors
///
/// - [`InvalidCommit::NonUniqueBase`]`(2)` if a commit belongs to a different
///   state space than `self`.
/// - [`InvalidCommit::IncompatibleHistory`] if two children of a processed
///   commit (children already in `self`, or among the gathered commits)
///   delete the same node of it.
///
/// NOTE(review): commits are inserted one at a time, so an error raised for a
/// later commit leaves the commits inserted before it in the graph.
pub fn try_add_commits<'a>(
    &mut self,
    commits: impl IntoIterator<Item = Commit<'a>>,
) -> Result<(), InvalidCommit> {
    let pending = get_ancestors_while(commits, |c| !self.contains(c));

    for candidate in pending.iter().rev() {
        let candidate_id = candidate.id();
        if &candidate.state_space() != self.state_space() {
            return Err(InvalidCommit::NonUniqueBase(2));
        }

        // Children of the candidate that are already part of the history...
        let existing_children = self
            .children_commits(candidate_id)
            .map(|id| self.get_commit(id));
        // ...and those among the gathered commits.
        let incoming_children = pending
            .iter()
            .filter(|&c| c.parents().any(|p| p.as_ptr() == candidate.as_ptr()));
        let all_children = existing_children
            .chain(incoming_children)
            .unique_by(|c| c.as_ptr())
            .map(|c| c.to_owned());

        if let Some(node) = find_conflicting_node(candidate_id, all_children) {
            return Err(InvalidCommit::IncompatibleHistory(candidate_id, node));
        }

        self.graph.insert_node(candidate.clone().into());

        // The parents gained a child: drop their cached children lists.
        for parent in candidate.parents() {
            self.cache.borrow_mut().invalidate_children(parent.id());
        }
    }

    Ok(())
}
/// Check the invariants of the commit history.
///
/// # Errors
///
/// - [`InvalidCommit::NonUniqueBase`]`(2)` if a commit other than
///   `base_commit_id` holds [`CommitData::Base`].
/// - [`InvalidCommit::IncompatibleHistory`] if two children of a commit
///   delete the same node of it.
/// - [`InvalidCommit::NonUniqueBase`]`(0)` if no commit holds
///   [`CommitData::Base`].
pub fn is_valid(&self) -> Result<(), InvalidCommit> {
    let mut base_seen = false;

    for commit_id in self.all_commit_ids() {
        let is_base = matches!(self.get_commit(commit_id).value(), CommitData::Base(_));
        if is_base {
            if commit_id != self.base_commit_id {
                return Err(InvalidCommit::NonUniqueBase(2));
            }
            base_seen = true;
        }

        let children = self
            .children_commits(commit_id)
            .map(|child| self.get_commit(child).clone());
        if let Some(conflict) = find_conflicting_node(commit_id, children) {
            return Err(InvalidCommit::IncompatibleHistory(commit_id, conflict));
        }
    }

    if base_seen {
        Ok(())
    } else {
        Err(InvalidCommit::NonUniqueBase(0))
    }
}
/// Return the [`CommitStateSpace`] view of the history graph's registry.
pub fn state_space(&self) -> &CommitStateSpace {
    let registry = self.graph.registry();
    registry.into()
}
/// Return the ID of the base commit.
pub fn base(&self) -> CommitId {
self.base_commit_id
}
/// Return a reference to the HUGR stored in the base commit.
///
/// # Panics
///
/// Panics with "base commit is not a base hugr" if the commit identified by
/// `base_commit_id` does not hold [`CommitData::Base`], and with
/// "invalid commit ID" if that commit is not in the graph.
pub fn base_hugr(&self) -> &Hugr {
    match self.get_commit(self.base_commit_id).value() {
        CommitData::Base(hugr) => hugr,
        _ => panic!("base commit is not a base hugr"),
    }
}
/// Return the commit with ID `commit_id`.
///
/// # Panics
///
/// Panics with "invalid commit ID" if the history graph has no node with
/// that ID.
pub fn get_commit(&self, commit_id: CommitId) -> &Commit<'_> {
    let node = self.graph.get_node(commit_id).expect("invalid commit ID");
    node.into()
}
/// Whether `commit` is part of this persistent HUGR's history graph.
pub fn contains(&self, commit: &Commit) -> bool {
    let relrc = commit.as_relrc();
    self.graph.contains(relrc)
}
/// Whether the history graph contains a commit with ID `commit_id`.
pub fn contains_id(&self, commit_id: CommitId) -> bool {
self.graph.contains_id(commit_id)
}
/// Return the base commit, i.e. the commit identified by `base_commit_id`.
///
/// # Panics
///
/// Panics with "invalid commit ID" if that commit is not in the graph.
pub fn base_commit(&self) -> &Commit<'_> {
    self.get_commit(self.base_commit_id)
}
/// Iterate over the IDs of all commits in the history graph.
///
/// The iteration order is whatever `HistoryGraph::all_node_ids` yields.
pub fn all_commit_ids(&self) -> impl Iterator<Item = CommitId> + Clone + '_ {
self.graph.all_node_ids()
}
/// Return the IDs of all commits in the order computed by
/// `petgraph::algo::toposort` on the history graph.
///
/// # Panics
///
/// Panics with "history is a DAG" if `toposort` returns an error.
fn toposort_commits(&self) -> Vec<CommitId> {
petgraph::algo::toposort(&self.graph, None).expect("history is a DAG")
}
/// Iterate over the IDs of the children of `commit_id` in this history.
///
/// The list is obtained through the cache's `children_or_insert`, which is
/// handed `self.graph.children(commit_id)` as the way to compute it. The
/// returned iterator owns a copy of the list, so the cache is not borrowed
/// while iterating.
pub fn children_commits(&self, commit_id: CommitId) -> impl Iterator<Item = CommitId> + '_ {
    let children = self
        .cache
        .borrow_mut()
        .children_or_insert(commit_id, || self.graph.children(commit_id).collect())
        .clone();
    children.into_iter()
}
/// Iterate over the IDs of the parents of `commit_id` in the history graph.
///
/// Unlike `children_commits`, this reads the graph directly and does not go
/// through the cache.
pub fn parent_commits(&self, commit_id: CommitId) -> impl Iterator<Item = CommitId> + '_ {
self.graph.parents(commit_id)
}
/// Materialise the persistent HUGR as a plain [`Hugr`].
///
/// This is the first component of [`PersistentHugr::apply_all`]; the node
/// map is discarded.
pub fn to_hugr(&self) -> Hugr {
    let (hugr, _node_map) = self.apply_all();
    hugr
}
/// Apply the replacements of all commits to a clone of the base HUGR.
///
/// Commits are visited in the order given by `toposort_commits`; commits
/// without a replacement are skipped.
///
/// Returns the resulting [`Hugr`] together with a map from [`PatchNode`]s to
/// the corresponding nodes of the returned HUGR. Entries whose target node
/// was removed by a later replacement are dropped from the map.
///
/// # Panics
///
/// - "invalid replacement" if a replacement cannot be mapped onto the current
///   HUGR or fails to apply.
/// - "node not found in node_map" if a removed node has no entry in the map.
/// - In debug builds, if the HUGR fails validation after a replacement.
pub fn apply_all(&self) -> (Hugr, HashMap<PatchNode, Node>) {
    let base_id = self.base();
    let mut hugr = self.base_hugr().clone();
    let mut node_map: HashMap<PatchNode, Node> = hugr
        .nodes()
        .map(|n| (PatchNode(base_id, n), n))
        .collect();

    for commit_id in self.toposort_commits() {
        let Some(patch) = self.get_commit(commit_id).replacement() else {
            continue;
        };

        // Re-express the replacement in terms of nodes of the current `hugr`.
        let applicable = patch
            .map_host_nodes(|n| node_map[&n], &hugr)
            .expect("invalid replacement");
        let outcome = applicable.apply(&mut hugr).expect("invalid replacement");

        debug_assert!(
            hugr.validate().is_ok(),
            "malformed patch in persistent hugr:\n{}",
            hugr.mermaid_string()
        );

        let simple_replace::Outcome {
            node_map: inserted,
            removed_nodes,
        } = outcome;

        // Record the freshly inserted nodes first, then forget the removed ones.
        node_map.extend(
            inserted
                .into_iter()
                .map(|(old, new)| (PatchNode(commit_id, old), new)),
        );
        for removed in removed_nodes.into_keys() {
            // Reverse lookup: linear scan of the map for each removed node.
            let stale_key = node_map
                .iter()
                .find(|&(_, &mapped)| mapped == removed)
                .map(|(&key, _)| key)
                .expect("node not found in node_map");
            node_map.remove(&stale_key);
        }
    }

    (hugr, node_map)
}
/// Iterate over the nodes of commit `commit_id` that are deleted by its
/// children in this history.
///
/// For every child of `commit_id`, the [`PatchNode`]s yielded by the child's
/// `deleted_parent_nodes` are kept only if they are owned by `commit_id`.
pub fn deleted_nodes<'a>(&'a self, commit_id: CommitId) -> impl Iterator<Item = Node> + 'a {
    self.children_commits(commit_id).flat_map(move |child_id| {
        self.get_commit(child_id)
            .deleted_parent_nodes()
            .filter_map(move |PatchNode(owner, node)| {
                if owner == commit_id { Some(node) } else { None }
            })
    })
}
/// Whether the given [`PatchNode`] is a live node of this persistent HUGR.
///
/// Returns `true` only if all of the following hold:
/// - the owning commit is part of the history,
/// - the node is not one of the nodes returned by `get_replacement_io` of the
///   owning commit's replacement (if it has one),
/// - the node is not reported by [`PersistentHugr::deleted_nodes`] for the
///   owning commit.
pub fn contains_node(&self, PatchNode(commit_id, node): PatchNode) -> bool {
    // Checked first so that `get_commit` below is never called on an unknown ID.
    if !self.contains_id(commit_id) {
        return false;
    }

    let is_replacement_io = self
        .get_commit(commit_id)
        .replacement()
        .is_some_and(|repl| repl.get_replacement_io().contains(&node));
    if is_replacement_io {
        return false;
    }

    !self.deleted_nodes(commit_id).contains(&node)
}
}
/// Collect `commits` together with their ancestors, exploring the parents of
/// a commit only if `continue_fn` returns `true` for it.
///
/// A commit for which `continue_fn` returns `false` is still part of the
/// result; only the traversal through it stops. Commits are deduplicated by
/// `as_ptr`.
///
/// Ordering: each commit in `commits` that has not been seen yet starts a
/// breadth-first traversal. The commits found by that traversal are placed,
/// in traversal order, in front of everything collected for earlier inputs.
fn get_ancestors_while<'a>(
    commits: impl IntoIterator<Item = Commit<'a>>,
    continue_fn: impl Fn(&Commit) -> bool,
) -> VecDeque<Commit<'a>> {
    let mut visited = BTreeSet::new();
    let roots = commits.into_iter();
    let mut ordered = VecDeque::with_capacity(roots.size_hint().0);

    for root in roots {
        if !visited.insert(root.as_ptr()) {
            continue;
        }

        // Breadth-first traversal from `root`; `batch` doubles as the queue.
        let mut batch = vec![root];
        let mut cursor = 0;
        while cursor < batch.len() {
            let current = batch[cursor].clone();
            cursor += 1;

            if continue_fn(&current) {
                for parent in current.parents() {
                    if visited.insert(parent.as_ptr()) {
                        batch.push(parent.clone());
                    }
                }
            }
        }

        // Prepend the batch while keeping its internal order.
        for commit in batch.into_iter().rev() {
            ordered.push_front(commit);
        }
    }

    ordered
}
/// Status of a [`PatchNode`] within a [`PersistentHugr`], as computed by
/// `PersistentHugr::node_status`.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum NodeStatus {
/// The node is deleted by the child commit with the given ID.
Deleted(CommitId),
/// The node is one of the nodes returned by `get_replacement_io` of the
/// replacement of its owning commit.
ReplacementIO,
/// The node is neither a replacement IO node nor deleted by a child commit.
Valid,
}
impl PersistentHugr {
/// Pair `node` with `commit_id` into a [`PatchNode`].
///
/// No validation is performed and `self` is not consulted.
pub(crate) fn to_persistent_node(&self, node: Node, commit_id: CommitId) -> PatchNode {
PatchNode(commit_id, node)
}
/// Find a child of the commit owning `node` that deletes `node`.
///
/// Returns the ID of the first child (in `children_commits` order) whose
/// `deleted_parent_nodes` contains `node`, or `None` if there is no such
/// child.
pub(crate) fn find_deleting_commit(
    &self,
    node @ PatchNode(commit_id, _): PatchNode,
) -> Option<CommitId> {
    for child_id in self.children_commits(commit_id) {
        let deletes_node = self
            .get_commit(child_id)
            .deleted_parent_nodes()
            .contains(&node);
        if deletes_node {
            return Some(child_id);
        }
    }
    None
}
/// Compute the [`NodeStatus`] of a [`PatchNode`].
///
/// The replacement IO check takes precedence over the deletion check.
///
/// # Panics
///
/// In debug builds, asserts that the owning commit is part of the history
/// ("unknown commit"). `get_commit` panics with "invalid commit ID" if it is
/// not.
pub(crate) fn node_status(
    &self,
    per_node @ PatchNode(commit_id, node): PatchNode,
) -> NodeStatus {
    debug_assert!(self.contains_id(commit_id), "unknown commit");

    let is_replacement_io = self
        .get_commit(commit_id)
        .replacement()
        .is_some_and(|repl| repl.get_replacement_io().contains(&node));
    if is_replacement_io {
        return NodeStatus::ReplacementIO;
    }

    match self.find_deleting_commit(per_node) {
        Some(deleting_commit) => NodeStatus::Deleted(deleting_commit),
        None => NodeStatus::Valid,
    }
}
}
/// Iterate over clones of all commits of a [`PersistentHugr`], in the order
/// of the history graph's `all_node_ids`.
impl<'a> IntoIterator for &'a PersistentHugr {
    type Item = Commit<'a>;
    type IntoIter = vec::IntoIter<Commit<'a>>;

    fn into_iter(self) -> Self::IntoIter {
        // Commits are collected eagerly so that the iterator type can be named.
        let commits: Vec<Commit<'a>> = self
            .graph
            .all_node_ids()
            .map(|id| self.get_commit(id).clone())
            .collect();
        commits.into_iter()
    }
}
/// Find a node of commit `commit_id` that is deleted more than once by
/// `children`.
///
/// The [`PatchNode`]s yielded by each child's `deleted_parent_nodes` are
/// scanned in order, keeping only those owned by `commit_id`. The first node
/// encountered a second time is returned. Returns `None` if every such node
/// occurs at most once.
pub(crate) fn find_conflicting_node<'a>(
    commit_id: CommitId,
    children: impl IntoIterator<Item = Commit<'a>>,
) -> Option<Node> {
    let mut seen = BTreeSet::new();

    for child in children {
        for PatchNode(owner, node) in child.deleted_parent_nodes() {
            // `insert` returns false when the node was already recorded.
            if owner == commit_id && !seen.insert(node) {
                return Some(node);
            }
        }
    }

    None
}
// Unit tests for `get_ancestors_while`, run on small commit histories built
// from a two-`Not` dataflow HUGR.
#[cfg(test)]
mod tests {
use super::*;
use hugr_core::{
builder::{DFGBuilder, Dataflow, DataflowHugr, endo_sig},
extension::prelude::bool_t,
hugr::views::SiblingSubgraph,
ops::handle::NodeHandle,
std_extensions::logic::LogicOp,
};
use rstest::*;
// Build a bool -> bool dataflow HUGR made of two chained `Not` operations.
// Returns the HUGR, the node of the first `Not` and the node of the second.
fn notop_hugr() -> (Hugr, Node, Node) {
let mut builder = DFGBuilder::new(endo_sig(vec![bool_t()])).unwrap();
let [input] = builder.input_wires_arr();
let notop = builder.add_dataflow_op(LogicOp::Not, [input]).unwrap();
let notop2 = builder
.add_dataflow_op(LogicOp::Not, notop.outputs())
.unwrap();
let [out] = notop2.outputs_arr();
(
builder.finish_hugr_with_outputs([out]).unwrap(),
notop.node(),
notop2.node(),
)
}
// Add a commit replacing `node` with a fresh `notop_hugr`.
// Returns the new commit's ID and the `PatchNode` of the first `Not` of the
// replacement, owned by the new commit.
fn add_commit(persistent_hugr: &mut PersistentHugr, node: PatchNode) -> (CommitId, PatchNode) {
let (repl_hugr, repl_not, _) = notop_hugr();
let repl1 = PersistentReplacement::try_new(
SiblingSubgraph::from_node(node, &persistent_hugr),
&persistent_hugr,
repl_hugr,
)
.unwrap();
let commit = persistent_hugr.add_replacement(repl1);
(commit, PatchNode(commit, repl_not))
}
// Fixture: a chain base <- cm1 <- cm2 <- cm3, where each commit replaces the
// first `Not` introduced by the previous one.
// The returned IDs are `[base, cm1, cm2, cm3]`.
#[fixture]
fn linear_commits() -> (PersistentHugr, Vec<CommitId>) {
let (base_hugr, notop, _) = notop_hugr();
let mut persistent_hugr = PersistentHugr::with_base(base_hugr);
let base_not = persistent_hugr.base_commit().to_patch_node(notop);
let (cm1, cm1_not) = add_commit(&mut persistent_hugr, base_not);
let (cm2, cm2_not) = add_commit(&mut persistent_hugr, cm1_not);
let (cm3, _cm3_not) = add_commit(&mut persistent_hugr, cm2_not);
let base_id = persistent_hugr.base();
(persistent_hugr, vec![base_id, cm1, cm2, cm3])
}
// Fixture: two branches off the base commit.
// cm1 replaces the base's first `Not` and cm2 replaces cm1's first `Not`;
// cm3 replaces the base's second `Not` and cm4 replaces cm3's first `Not`.
// The returned IDs are `[base, cm1, cm2, cm3, cm4]`.
#[fixture]
fn branching_commits() -> (PersistentHugr, Vec<CommitId>) {
let (base_hugr, notop, notop2) = notop_hugr();
let mut persistent_hugr = PersistentHugr::with_base(base_hugr);
let base_commit = persistent_hugr.base_commit();
let base_not = base_commit.to_patch_node(notop);
let base_not2 = base_commit.to_patch_node(notop2);
let base_id = base_commit.id();
let (cm1, cm1_not) = add_commit(&mut persistent_hugr, base_not);
let (cm2, _cm2_not) = add_commit(&mut persistent_hugr, cm1_not);
let (cm3, cm3_not) = add_commit(&mut persistent_hugr, base_not2);
let (cm4, _cm4_not) = add_commit(&mut persistent_hugr, cm3_not);
(persistent_hugr, vec![base_id, cm1, cm2, cm3, cm4])
}
// Starting from the tip of a linear chain, every commit is returned, from the
// tip down to the base.
#[rstest]
fn test_get_ancestors_while_linear_chain(linear_commits: (PersistentHugr, Vec<CommitId>)) {
let (persistent_hugr, commit_ids) = linear_commits;
let commits = commit_ids
.iter()
.map(|&id| persistent_hugr.get_commit(id).clone())
.collect_vec();
let ancestors = get_ancestors_while([commits[3].clone()], |_| true);
let ancestor_ids: Vec<_> = ancestors.iter().map(|c| c.id()).collect();
assert_eq!(
ancestor_ids,
vec![commit_ids[3], commit_ids[2], commit_ids[1], commit_ids[0]]
);
}
// Starting from both branch tips, all five commits are returned once, in one
// of the two orderings accepted below (the base always comes last).
#[rstest]
fn test_get_ancestors_while_branching_structure(
branching_commits: (PersistentHugr, Vec<CommitId>),
) {
let (persistent_hugr, commit_ids) = branching_commits;
let commits = commit_ids
.iter()
.map(|&id| persistent_hugr.get_commit(id).clone())
.collect_vec();
let ancestors = get_ancestors_while([commits[2].clone(), commits[4].clone()], |_| true);
let ancestor_ids: Vec<_> = ancestors.iter().map(|c| c.id()).collect();
let valid_orderings = [
vec![
commit_ids[4],
commit_ids[3],
commit_ids[2],
commit_ids[1],
commit_ids[0],
],
vec![
commit_ids[2],
commit_ids[1],
commit_ids[4],
commit_ids[3],
commit_ids[0],
],
];
assert!(valid_orderings.contains(&ancestor_ids));
}
// The traversal stops at `commit1`: it is still part of the result, but its
// parent (the base commit) is not.
#[rstest]
fn test_get_ancestors_while_with_filter(linear_commits: (PersistentHugr, Vec<CommitId>)) {
let (persistent_hugr, commit_ids) = linear_commits;
let commits = commit_ids
.iter()
.map(|&id| persistent_hugr.get_commit(id).clone())
.collect_vec();
let [_base, commit1, commit2, commit3] = commits.try_into().unwrap();
let ancestors = get_ancestors_while([commit3.clone()], |c| c.id() != commit1.id());
let ancestor_ids: Vec<_> = ancestors.iter().map(|c| c.id()).collect();
assert_eq!(ancestor_ids, vec![commit3.id(), commit2.id(), commit1.id()]);
}
}