use std::{marker::PhantomData, mem};
use delegate::delegate;
use hugr_core::{
Direction, Hugr, HugrView, IncomingPort, Node, OutgoingPort, Port,
hugr::{
NodeMetadataMap, internal::HugrInternals, patch::simple_replace::InvalidReplacement,
views::InvalidSignature,
},
ops::OpType,
};
use itertools::{Either, Itertools};
use relrc::RelRc;
use thiserror::Error;
use crate::{
CommitData, CommitId, CommitStateSpace, PatchNode, PersistentReplacement,
subgraph::InvalidPinnedSubgraph,
};
mod boundary;
/// A commit, represented as a thin wrapper around a [`RelRc`] holding
/// [`CommitData`].
///
/// The lifetime `'a` is carried only by the [`PhantomData`] marker; no
/// reference is actually stored. The constructors in this file tie `'a` to
/// the `&'a CommitStateSpace` they receive.
///
/// `#[repr(transparent)]` is load-bearing: this file transmutes
/// `&RelRc<CommitData, ()>` into `&Commit`, which relies on both types
/// having the same layout.
#[derive(Debug, Clone)]
#[repr(transparent)]
pub struct Commit<'a>(RelRc<CommitData, ()>, PhantomData<&'a ()>);
impl<'a> Commit<'a> {
pub fn try_from_replacement(
replacement: PersistentReplacement,
state_space: &'a CommitStateSpace,
) -> Result<Self, InvalidCommit> {
Self::try_new(replacement, [], state_space)
}
pub(crate) fn new_base(hugr: Hugr, state_space: &'a CommitStateSpace) -> Self {
let commit = RelRc::new(CommitData::Base(hugr));
commit
.try_register_in(state_space.as_registry())
.expect("new node is not yet registered");
Commit(commit, PhantomData)
}
/// Create a new commit applying `replacement`, registered in `state_space`.
///
/// The parents of the new commit are the commits passed in `parents`
/// together with the commits owning the nodes in the replacement's
/// invalidation set (looked up in `state_space`). Duplicates are removed by
/// comparing the underlying `RelRc` pointers.
///
/// # Errors
///
/// - [`InvalidCommit::EmptyReplacement`] if the replacement's subgraph has no
///   nodes.
/// - [`InvalidCommit::UnknownParent`] if a commit owning an invalidated node
///   cannot be found in `state_space`.
/// - [`InvalidCommit::NonUniqueBase`] if the new commit does not have exactly
///   one base commit among its ancestors; the payload is the number found.
///
/// # Panics
///
/// Panics if registering the new commit in the state space's registry
/// fails.
pub fn try_new<'b>(
    replacement: PersistentReplacement,
    parents: impl IntoIterator<Item = Commit<'b>>,
    state_space: &'a CommitStateSpace,
) -> Result<Self, InvalidCommit> {
    if replacement.subgraph().nodes().is_empty() {
        return Err(InvalidCommit::EmptyReplacement);
    }

    // Parents implied by the replacement itself (owners of invalidated nodes).
    let repl_parents = get_parent_commits(&replacement, state_space)?
        .into_iter()
        .map_into::<RelRc<_, ()>>();

    // Explicit parents first, then implied ones, de-duplicated by pointer.
    let parents = parents
        .into_iter()
        .map_into::<RelRc<_, ()>>()
        .chain(repl_parents)
        .unique_by(|p| p.as_ptr());

    // `parents` is already an iterator, so no further conversion is needed.
    let rc = RelRc::with_parents(replacement.into(), parents.map(|p| (p, ())));

    // Validate before registering, so an invalid commit never reaches the
    // registry.
    if let Err(err) = get_base_ancestors(&rc).exactly_one() {
        return Err(InvalidCommit::NonUniqueBase(err.count()));
    }

    rc.try_register_in(state_space.as_registry())
        .expect("new node is not yet registered");
    Ok(Self(rc, PhantomData))
}
pub(crate) unsafe fn from_relrc(rc: RelRc<CommitData, ()>) -> Self {
Self(rc, PhantomData)
}
/// Return the state space this commit is registered in.
///
/// The registry of the underlying `RelRc` is converted into a
/// [`CommitStateSpace`].
///
/// # Panics
///
/// Panics if the commit has no registry.
pub fn state_space(&self) -> CommitStateSpace {
    let registry = self
        .as_relrc()
        .registry()
        .expect("invalid commit: not registered");
    registry.into()
}
/// Return the ID of this commit within its state space.
///
/// # Panics
///
/// Panics if the commit has no registry, or if the state space returns no
/// ID for it.
pub fn id(&self) -> CommitId {
    let space = self.state_space();
    let id = space
        .get_id(self)
        .expect("invalid commit: not registered");
    id
}
/// Iterate over the parent commits of `self`.
///
/// Each parent `RelRc` is viewed as a `&Commit` and its lifetime parameter
/// is then rewritten so that the items have type `&Self`.
pub fn parents(&self) -> impl Iterator<Item = &Self> + '_ {
    self.as_relrc().all_parents().map(|parent_rc| {
        let parent: &Commit = parent_rc.into();
        // SAFETY (as relied upon by this file): only the phantom lifetime
        // parameter is changed; the reference itself stays bound to the
        // borrow of `self`.
        unsafe { upgrade_lifetime(parent) }
    })
}
/// Iterate over the child commits of `self`.
///
/// `_state_space` is never read; it only provides the lifetime `'a` given
/// to the returned commits.
pub fn children(&self, _state_space: &'a CommitStateSpace) -> impl Iterator<Item = Self> + '_ {
    let child_rcs = self.as_relrc().all_children();
    // SAFETY (as relied upon by this file): the children are handed out with
    // the lifetime of the `&'a CommitStateSpace` the caller supplied.
    child_rcs.map(|child_rc| unsafe { Self::from_relrc(child_rc) })
}
/// Check that `self` has exactly one base commit among its ancestors and
/// that it has a registry.
pub fn is_valid(&self) -> bool {
    let rc = self.as_relrc();
    // Checked first, so the registry is only queried when this holds.
    if get_base_ancestors(rc).exactly_one().is_err() {
        return false;
    }
    rc.registry().is_some()
}
pub(crate) fn as_relrc(&self) -> &RelRc<CommitData, ()> {
&self.0
}
/// Iterate over the nodes added by this commit.
///
/// For a base commit these are all nodes of the base HUGR. For a
/// replacement commit these are the entry descendants of the replacement
/// HUGR, minus the first three.
pub fn inserted_nodes(&self) -> impl Iterator<Item = Node> + '_ {
    // Number of leading `entry_descendants()` items left out for replacements.
    // NOTE(review): presumably the entrypoint plus its input and output
    // nodes -- confirm.
    const SKIPPED_LEADING_NODES: usize = 3;

    match self.value() {
        CommitData::Replacement(repl) => Either::Right(
            repl.replacement()
                .entry_descendants()
                .skip(SKIPPED_LEADING_NODES),
        ),
        CommitData::Base(base) => Either::Left(base.nodes()),
    }
}
/// Return the replacement stored in this commit, or `None` for a base
/// commit.
pub fn replacement(&self) -> Option<&PersistentReplacement> {
    match self.value() {
        CommitData::Replacement(repl) => Some(repl),
        CommitData::Base(_) => None,
    }
}
/// Iterate over the invalidation set of this commit's replacement.
///
/// Yields nothing for a base commit, which has no replacement.
pub fn deleted_parent_nodes(&self) -> impl Iterator<Item = PatchNode> + '_ {
    let maybe_replacement = self.replacement();
    maybe_replacement
        .into_iter()
        .flat_map(|repl| repl.invalidation_set())
}
/// Return the HUGR stored in this commit: the base HUGR for a base commit,
/// or the replacement HUGR for a replacement commit.
pub(crate) fn commit_hugr(&self) -> &Hugr {
    match self.value() {
        CommitData::Replacement(repl) => repl.replacement(),
        CommitData::Base(base_hugr) => base_hugr,
    }
}
// Forward the accessors below to the wrapped `RelRc` (`self.0`) through the
// `delegate!` macro instead of hand-writing pass-through methods.
delegate! {
to self.0 {
// The `CommitData` value held by the wrapped `RelRc`.
pub(crate) fn value(&self) -> &CommitData;
// Raw pointer to the inner data of the wrapped `RelRc`.
pub(crate) fn as_ptr(&self) -> *const relrc::node::InnerData<CommitData, ()>;
}
}
/// Return the first base commit found among the ancestors of `self`.
///
/// # Panics
///
/// Panics if no base commit is found.
pub(crate) fn base_commit(&self) -> &Self {
    let base: &Commit = get_base_ancestors(self.as_relrc())
        .next()
        .expect("no base commit found")
        .into();
    // SAFETY (as relied upon by this file): only the phantom lifetime
    // parameter is changed; the reference stays bound to the borrow of
    // `self`.
    unsafe { upgrade_lifetime(base) }
}
/// Check whether `port` of `node` is a value port, according to the
/// operation type of `node` in this commit's HUGR.
///
/// # Panics
///
/// Panics if the operation type reports no kind for `port`.
pub(crate) fn is_value_port(&self, node: Node, port: impl Into<Port>) -> bool {
    let optype = self.commit_hugr().get_optype(node);
    let kind = optype.port_kind(port).expect("invalid port");
    kind.is_value()
}
/// Iterate over the value ports of `node` in direction `dir`, each paired
/// with `node` itself.
pub(crate) fn value_ports(
    &self,
    node: Node,
    dir: Direction,
) -> impl Iterator<Item = (Node, Port)> + '_ {
    self.node_ports(node, dir)
        .filter(move |&port| self.is_value_port(node, port))
        .map(move |port| (node, port))
}
/// Iterate over the outgoing value ports of `node`, each paired with
/// `node`.
///
/// # Panics
///
/// The iterator panics if a port yielded for the outgoing direction cannot
/// be converted to an [`OutgoingPort`].
pub(crate) fn output_value_ports(
    &self,
    node: Node,
) -> impl Iterator<Item = (Node, OutgoingPort)> + '_ {
    self.value_ports(node, Direction::Outgoing)
        .map(|(owner, port)| {
            let outgoing = port.as_outgoing().expect("unexpected port direction");
            (owner, outgoing)
        })
}
/// Iterate over the incoming value ports of `node`, each paired with
/// `node`.
///
/// # Panics
///
/// The iterator panics if a port yielded for the incoming direction cannot
/// be converted to an [`IncomingPort`].
pub(crate) fn input_value_ports(
    &self,
    node: Node,
) -> impl Iterator<Item = (Node, IncomingPort)> + '_ {
    self.value_ports(node, Direction::Incoming)
        .map(|(owner, port)| {
            let incoming = port.as_incoming().expect("unexpected port direction");
            (owner, incoming)
        })
}
pub unsafe fn upgrade_lifetime<'b>(self) -> Commit<'b> {
Commit(self.0, PhantomData)
}
}
/// Reinterpret a `&Commit<'a>` as a `&Commit<'b>`, replacing the commit's
/// lifetime parameter with an arbitrary one.
///
/// The lifetime `'c` of the reference itself is left unchanged.
///
/// # Safety
///
/// `'b` is unconstrained and may be longer than `'a`; nothing is checked
/// here.
/// NOTE(review): presumably the caller must guarantee that the commit's
/// state space outlives `'b` -- confirm the exact contract.
pub(crate) unsafe fn upgrade_lifetime<'a, 'b, 'c>(commit: &'c Commit<'a>) -> &'c Commit<'b> {
// SAFETY: `Commit<'a>` and `Commit<'b>` differ only in the lifetime carried
// by a `PhantomData` marker, so both references have the same layout. The
// validity of `'b` is the caller's obligation.
unsafe { mem::transmute(commit) }
}
impl Commit<'_> {
/// Return the operation type of `node` in this commit's HUGR.
pub fn get_optype(&self, node: Node) -> &OpType {
    self.commit_hugr().get_optype(node)
}
/// Return the number of ports of `node` in direction `dir`, as reported by
/// this commit's HUGR.
pub fn num_ports(&self, node: Node, dir: Direction) -> usize {
    let hugr = self.commit_hugr();
    hugr.num_ports(node, dir)
}
/// Iterate over the outgoing ports of `node`.
///
/// # Panics
///
/// The iterator panics if a port yielded for the outgoing direction cannot
/// be converted to an [`OutgoingPort`].
#[inline]
pub fn node_outputs(&self, node: Node) -> impl Iterator<Item = OutgoingPort> + Clone + '_ {
    let outgoing = self.node_ports(node, Direction::Outgoing);
    outgoing.map(|port| port.as_outgoing().unwrap())
}
/// Iterate over the incoming ports of `node`.
///
/// # Panics
///
/// The iterator panics if a port yielded for the incoming direction cannot
/// be converted to an [`IncomingPort`].
#[inline]
pub fn node_inputs(&self, node: Node) -> impl Iterator<Item = IncomingPort> + Clone + '_ {
    let incoming = self.node_ports(node, Direction::Incoming);
    incoming.map(|port| port.as_incoming().unwrap())
}
/// Iterate over the ports of `node` in direction `dir`, as reported by this
/// commit's HUGR.
pub fn node_ports(
    &self,
    node: Node,
    dir: Direction,
) -> impl Iterator<Item = Port> + Clone + '_ {
    let hugr = self.commit_hugr();
    hugr.node_ports(node, dir)
}
/// Iterate over all ports of `node`, as reported by this commit's HUGR.
pub fn all_node_ports(&self, node: Node) -> impl Iterator<Item = Port> + Clone + '_ {
    let hugr = self.commit_hugr();
    hugr.all_node_ports(node)
}
/// Return the metadata map of `node` in this commit's HUGR.
pub fn node_metadata_map(&self, node: Node) -> &NodeMetadataMap {
    let hugr = self.commit_hugr();
    hugr.node_metadata_map(node)
}
}
/// Iterate over those entries of `arg.all_ancestors()` whose value is a
/// [`CommitData::Base`].
fn get_base_ancestors(arg: &RelRc<CommitData, ()>) -> impl Iterator<Item = &RelRc<CommitData, ()>> {
    arg.all_ancestors().filter(|ancestor| match ancestor.value() {
        CommitData::Base(_) => true,
        CommitData::Replacement(_) => false,
    })
}
impl From<Commit<'_>> for RelRc<CommitData, ()> {
fn from(Commit(data, _): Commit) -> Self {
data
}
}
// View a borrowed `RelRc<CommitData, ()>` as a borrowed `Commit` without
// cloning. The commit's lifetime parameter is set to the lifetime of the
// borrow.
impl<'a> From<&'a RelRc<CommitData, ()>> for &'a Commit<'a> {
fn from(rc: &'a RelRc<CommitData, ()>) -> Self {
// SAFETY: `Commit` is declared `#[repr(transparent)]` and its only field
// besides the `RelRc` is a zero-sized `PhantomData`, so `&RelRc<..>` and
// `&Commit` have the same layout.
unsafe { mem::transmute(rc) }
}
}
/// Errors describing why a commit is invalid or could not be created.
///
/// Marked `#[non_exhaustive]`, so matches outside this crate need a
/// wildcard arm.
#[derive(Debug, Clone, Error)]
#[non_exhaustive]
pub enum InvalidCommit {
/// Children of the given commit conflict at the given node.
#[error("Incompatible history: children of commit {0:?} conflict in {1:?}")]
IncompatibleHistory(CommitId, Node),
/// A parent commit with the given ID could not be found.
#[error("Missing parent commit: {0:?}")]
UnknownParent(CommitId),
/// The commit is not a replacement.
#[error("Commit is not a replacement")]
NotReplacement,
/// The number of base commits found was not exactly one; the payload is the
/// number found.
#[error("{0} base commits found (should be 1)")]
NonUniqueBase(usize),
/// The replacement is empty, which is not allowed.
#[error("Not allowed: empty replacement")]
EmptyReplacement,
/// Wraps an [`InvalidPinnedSubgraph`] (convertible with `?` via `#[from]`).
#[error("Invalid subgraph: {0}")]
InvalidSubgraph(#[from] InvalidPinnedSubgraph),
/// Wraps an [`InvalidReplacement`] (convertible with `?` via `#[from]`).
#[error("Invalid replacement: {0}")]
InvalidReplacement(#[from] InvalidReplacement),
/// Wraps an [`InvalidSignature`] (convertible with `?` via `#[from]`).
#[error("Invalid signature: {0}")]
InvalidSignature(#[from] InvalidSignature),
/// A wire is incomplete because the given node is unpinned.
///
/// Note that the message only displays the node; the port is carried in the
/// variant but not printed.
#[error("Incomplete wire: {0} is unpinned")]
IncompleteWire(PatchNode, Port),
/// The given commit ID is unknown.
#[error("Unknown commit ID: {0:?}")]
UnknownCommitId(CommitId),
}
/// Collect the commits that own the nodes in the invalidation set of
/// `replacement`.
///
/// Each distinct owner ID is looked up once in `state_space`, in order of
/// first appearance.
///
/// # Errors
///
/// Returns [`InvalidCommit::UnknownParent`] for the first owner ID that
/// `state_space` cannot upgrade to a commit.
fn get_parent_commits<'a>(
    replacement: &PersistentReplacement,
    state_space: &'a CommitStateSpace,
) -> Result<Vec<Commit<'a>>, InvalidCommit> {
    let owner_ids = replacement
        .invalidation_set()
        .map(|node| node.owner())
        .unique();

    let mut owners = Vec::new();
    for owner_id in owner_ids {
        let Some(owner) = state_space.try_upgrade(owner_id) else {
            return Err(InvalidCommit::UnknownParent(owner_id));
        };
        owners.push(owner);
    }
    Ok(owners)
}