use core::panic;
use std::collections::HashMap;
use portgraph::view::{NodeFilter, NodeFiltered};
use portgraph::{LinkMut, NodeIndex, PortMut, PortView, SecondaryMap};
use crate::hugr::views::SiblingSubgraph;
use crate::hugr::{HugrView, Node, OpType, RootTagged};
use crate::hugr::{NodeMetadata, Rewrite};
use crate::{Hugr, IncomingPort, OutgoingPort, Port, PortIndex};
use super::internal::HugrMutInternals;
use super::NodeMetadataMap;
/// Mutating operations on a HUGR.
///
/// The default methods validate their node arguments against `self` and then
/// forward to the underlying [`Hugr`] obtained through
/// [`HugrMutInternals::hugr_mut`].
pub trait HugrMut: HugrMutInternals {
    /// Returns a mutable reference to the metadata value stored under `key`
    /// for `node`.
    ///
    /// If `node` has no metadata map yet an empty one is created. If `key` is
    /// absent it is inserted with a JSON `null` value.
    ///
    /// # Panics
    ///
    /// If `node` is not a valid node of this HUGR.
    fn get_metadata_mut(&mut self, node: Node, key: impl AsRef<str>) -> &mut NodeMetadata {
        panic_invalid_node(self, node);
        let slot = self.hugr_mut().metadata.get_mut(node.pg_index());
        let meta_map = slot.get_or_insert_with(Default::default);
        meta_map
            .entry(key.as_ref())
            .or_insert(serde_json::Value::Null)
    }

    /// Stores `metadata` under `key` for `node`, replacing any previous value.
    ///
    /// # Panics
    ///
    /// If `node` is not a valid node of this HUGR.
    fn set_metadata(
        &mut self,
        node: Node,
        key: impl AsRef<str>,
        metadata: impl Into<NodeMetadata>,
    ) {
        let target = self.get_metadata_mut(node, key);
        *target = metadata.into();
    }

    /// Removes and returns the whole metadata map of `node`.
    ///
    /// Returns `None` if `node` is not valid or has no metadata.
    fn take_node_metadata(&mut self, node: Node) -> Option<NodeMetadataMap> {
        if self.valid_node(node) {
            self.hugr_mut().metadata.take(node.pg_index())
        } else {
            None
        }
    }

    /// Replaces the whole metadata map of `node` (`None` clears it).
    ///
    /// # Panics
    ///
    /// If `node` is not a valid node of this HUGR.
    fn overwrite_node_metadata(&mut self, node: Node, metadata: Option<NodeMetadataMap>) {
        panic_invalid_node(self, node);
        let index = node.pg_index();
        self.hugr_mut().metadata.set(index, metadata);
    }

    /// Adds a node holding `op` as a child of `parent`, returning the new node.
    ///
    /// # Panics
    ///
    /// If `parent` is not a valid node of this HUGR.
    #[inline]
    fn add_node_with_parent(&mut self, parent: Node, op: impl Into<OpType>) -> Node {
        panic_invalid_node(self, parent);
        let base = self.hugr_mut();
        base.add_node_with_parent(parent, op)
    }

    /// Adds a node holding `nodetype` into the hierarchy immediately before
    /// `sibling`, returning the new node.
    ///
    /// # Panics
    ///
    /// If `sibling` is not a valid non-root node of this HUGR.
    #[inline]
    fn add_node_before(&mut self, sibling: Node, nodetype: impl Into<OpType>) -> Node {
        panic_invalid_non_root(self, sibling);
        let base = self.hugr_mut();
        base.add_node_before(sibling, nodetype)
    }

    /// Adds a node holding `op` into the hierarchy immediately after
    /// `sibling`, returning the new node.
    ///
    /// # Panics
    ///
    /// If `sibling` is not a valid non-root node of this HUGR.
    #[inline]
    fn add_node_after(&mut self, sibling: Node, op: impl Into<OpType>) -> Node {
        panic_invalid_non_root(self, sibling);
        let base = self.hugr_mut();
        base.add_node_after(sibling, op)
    }

    /// Removes `node` from the HUGR and returns the operation it held.
    ///
    /// # Panics
    ///
    /// If `node` is not a valid non-root node of this HUGR.
    #[inline]
    fn remove_node(&mut self, node: Node) -> OpType {
        panic_invalid_non_root(self, node);
        let base = self.hugr_mut();
        base.remove_node(node)
    }

    /// Links output `src_port` of `src` to input `dst_port` of `dst`.
    ///
    /// # Panics
    ///
    /// If `src` or `dst` is not a valid node, or if either port does not exist.
    #[inline]
    fn connect(
        &mut self,
        src: Node,
        src_port: impl Into<OutgoingPort>,
        dst: Node,
        dst_port: impl Into<IncomingPort>,
    ) {
        for endpoint in [src, dst] {
            panic_invalid_node(self, endpoint);
        }
        let base = self.hugr_mut();
        base.connect(src, src_port, dst, dst_port);
    }

    /// Unlinks `port` of `node`.
    ///
    /// # Panics
    ///
    /// If `node` is not a valid node, or if the port does not exist.
    #[inline]
    fn disconnect(&mut self, node: Node, port: impl Into<Port>) {
        panic_invalid_node(self, node);
        let base = self.hugr_mut();
        base.disconnect(node, port);
    }

    /// Links the non-dataflow ("other") output port of `src` to the
    /// non-dataflow input port of `dst`, returning the two ports used.
    ///
    /// # Panics
    ///
    /// If either node is invalid, or if either operation lacks such a port.
    fn add_other_edge(&mut self, src: Node, dst: Node) -> (OutgoingPort, IncomingPort) {
        for endpoint in [src, dst] {
            panic_invalid_node(self, endpoint);
        }
        let base = self.hugr_mut();
        base.add_other_edge(src, dst)
    }

    /// Moves the whole of `other` into this HUGR, attaching its root as a
    /// child of `root`.
    ///
    /// # Panics
    ///
    /// If `root` is not a valid node of this HUGR.
    #[inline]
    fn insert_hugr(&mut self, root: Node, other: Hugr) -> InsertionResult {
        panic_invalid_node(self, root);
        let base = self.hugr_mut();
        base.insert_hugr(root, other)
    }

    /// Copies the nodes of the view `other` into this HUGR, attaching the
    /// view's root as a child of `root`.
    ///
    /// # Panics
    ///
    /// If `root` is not a valid node of this HUGR.
    #[inline]
    fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult {
        panic_invalid_node(self, root);
        let base = self.hugr_mut();
        base.insert_from_view(root, other)
    }

    /// Copies the nodes of `subgraph` (taken from `other`) into this HUGR
    /// under `root`.
    ///
    /// Returns a map from the nodes of `other` to their copies.
    ///
    /// # Panics
    ///
    /// If `root` is not a valid node of this HUGR.
    fn insert_subgraph(
        &mut self,
        root: Node,
        other: &impl HugrView,
        subgraph: &SiblingSubgraph,
    ) -> HashMap<Node, Node> {
        panic_invalid_node(self, root);
        let base = self.hugr_mut();
        base.insert_subgraph(root, other, subgraph)
    }

    /// Applies the rewrite `rw` to this HUGR, returning its result or error.
    fn apply_rewrite<R, E>(&mut self, rw: impl Rewrite<ApplyResult = R, Error = E>) -> Result<R, E>
    where
        Self: Sized,
    {
        rw.apply(self)
    }
}
/// Records the result of inserting a HUGR or a view via
/// [`HugrMut::insert_hugr`] or [`HugrMut::insert_from_view`].
///
/// Derives the standard traits so callers can inspect, copy and compare
/// insertion results (e.g. in tests) without unpacking them by hand.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InsertionResult {
    /// The node in the target HUGR that the inserted root was copied to.
    ///
    /// This is also the value stored in [`InsertionResult::node_map`] under
    /// the source root.
    pub new_root: Node,
    /// Map from nodes of the inserted HUGR/view to their copies in the
    /// target HUGR.
    pub node_map: HashMap<Node, Node>,
}
/// Converts a map between raw port-graph indices into the equivalent map
/// between [`Node`]s.
fn translate_indices(node_map: HashMap<NodeIndex, NodeIndex>) -> HashMap<Node, Node> {
    node_map
        .into_iter()
        .map(|(source, target)| (source.into(), target.into()))
        .collect()
}
/// [`HugrMut`] for anything that can be viewed as a mutable [`Hugr`] whose
/// root handle is a plain [`Node`].
///
/// These methods override the trait defaults and operate directly on the
/// underlying [`Hugr`]. Each one validates its node arguments up front, using
/// the same `panic_invalid_*` helpers as the default methods. Invalid input is
/// therefore reported with a descriptive panic before any mutation, rather
/// than after a new node has already been added to the graph.
impl<T: RootTagged<RootHandle = Node> + AsMut<Hugr>> HugrMut for T {
    fn add_node_with_parent(&mut self, parent: Node, node: impl Into<OpType>) -> Node {
        // Validate before adding the node so an invalid parent cannot leave a
        // freshly created node behind in the graph.
        panic_invalid_node(self, parent);
        let node = self.as_mut().add_node(node.into());
        self.as_mut()
            .hierarchy
            .push_child(node.pg_index(), parent.pg_index())
            .expect("Inserting a newly-created node into the hierarchy should never fail.");
        node
    }

    fn add_node_before(&mut self, sibling: Node, nodetype: impl Into<OpType>) -> Node {
        // The root has no parent, so it cannot have siblings.
        panic_invalid_non_root(self, sibling);
        let node = self.as_mut().add_node(nodetype.into());
        self.as_mut()
            .hierarchy
            .insert_before(node.pg_index(), sibling.pg_index())
            .expect("Inserting a newly-created node into the hierarchy should never fail.");
        node
    }

    fn add_node_after(&mut self, sibling: Node, op: impl Into<OpType>) -> Node {
        // The root has no parent, so it cannot have siblings.
        panic_invalid_non_root(self, sibling);
        let node = self.as_mut().add_node(op.into());
        self.as_mut()
            .hierarchy
            .insert_after(node.pg_index(), sibling.pg_index())
            .expect("Inserting a newly-created node into the hierarchy should never fail.");
        node
    }

    fn remove_node(&mut self, node: Node) -> OpType {
        panic_invalid_non_root(self, node);
        self.as_mut().hierarchy.remove(node.pg_index());
        self.as_mut().graph.remove_node(node.pg_index());
        self.as_mut().op_types.take(node.pg_index())
    }

    fn connect(
        &mut self,
        src: Node,
        src_port: impl Into<OutgoingPort>,
        dst: Node,
        dst_port: impl Into<IncomingPort>,
    ) {
        let src_port = src_port.into();
        let dst_port = dst_port.into();
        // Port validation also rejects invalid nodes.
        panic_invalid_port(self, src, src_port);
        panic_invalid_port(self, dst, dst_port);
        self.as_mut()
            .graph
            .link_nodes(
                src.pg_index(),
                src_port.index(),
                dst.pg_index(),
                dst_port.index(),
            )
            .expect("The ports should exist at this point.");
    }

    fn disconnect(&mut self, node: Node, port: impl Into<Port>) {
        let port = port.into();
        let offset = port.pg_offset();
        panic_invalid_port(self, node, port);
        let port = self
            .as_mut()
            .graph
            .port_index(node.pg_index(), offset)
            .expect("The port should exist at this point.");
        self.as_mut().graph.unlink_port(port);
    }

    fn add_other_edge(&mut self, src: Node, dst: Node) -> (OutgoingPort, IncomingPort) {
        // Check the nodes before reading their op types.
        panic_invalid_node(self, src);
        panic_invalid_node(self, dst);
        let src_port = self
            .get_optype(src)
            .other_output_port()
            .expect("Source operation has no non-dataflow outgoing edges");
        let dst_port = self
            .get_optype(dst)
            .other_input_port()
            .expect("Destination operation has no non-dataflow incoming edges");
        self.connect(src, src_port, dst, dst_port);
        (src_port, dst_port)
    }

    fn insert_hugr(&mut self, root: Node, mut other: Hugr) -> InsertionResult {
        panic_invalid_node(self, root);
        let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, &other);
        // `other` is owned, so op types and metadata are moved rather than cloned.
        for (&node, &new_node) in node_map.iter() {
            let optype = other.op_types.take(node);
            self.as_mut().op_types.set(new_node, optype);
            let meta = other.metadata.take(node);
            self.as_mut().metadata.set(new_node, meta);
        }
        debug_assert_eq!(
            Some(&new_root.pg_index()),
            node_map.get(&other.root().pg_index())
        );
        InsertionResult {
            new_root,
            node_map: translate_indices(node_map),
        }
    }

    fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult {
        panic_invalid_node(self, root);
        let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, other);
        // `other` is borrowed, so op types and metadata are cloned.
        for (&node, &new_node) in node_map.iter() {
            let nodetype = other.get_optype(node.into());
            self.as_mut().op_types.set(new_node, nodetype.clone());
            let meta = other.base_hugr().metadata.get(node);
            self.as_mut().metadata.set(new_node, meta.clone());
        }
        debug_assert_eq!(
            Some(&new_root.pg_index()),
            node_map.get(&other.root().pg_index())
        );
        InsertionResult {
            new_root,
            node_map: translate_indices(node_map),
        }
    }

    fn insert_subgraph(
        &mut self,
        root: Node,
        other: &impl HugrView,
        subgraph: &SiblingSubgraph,
    ) -> HashMap<Node, Node> {
        panic_invalid_node(self, root);
        // Restrict the port graph of `other` to the nodes listed by the subgraph.
        let portgraph: NodeFiltered<_, NodeFilter<&[Node]>, &[Node]> =
            NodeFiltered::new_node_filtered(
                other.portgraph(),
                |node, ctx| ctx.contains(&node.into()),
                subgraph.nodes(),
            );
        let node_map = insert_subgraph_internal(self.as_mut(), root, other, &portgraph);
        for (&node, &new_node) in node_map.iter() {
            let nodetype = other.get_optype(node.into());
            self.as_mut().op_types.set(new_node, nodetype.clone());
            let meta = other.base_hugr().metadata.get(node);
            self.as_mut().metadata.set(new_node, meta.clone());
        }
        translate_indices(node_map)
    }
}
/// Copies the port graph and hierarchy of `other` into `hugr`.
///
/// The root of `other` is attached as a child of `root`. Op types and
/// metadata are NOT copied here; callers transfer them using the returned map.
///
/// Returns the new node corresponding to `other`'s root, together with a map
/// from `other`'s node indices to the new indices in `hugr`.
///
/// # Panics
///
/// If inserting the port graph or attaching a node in the hierarchy fails.
fn insert_hugr_internal(
    hugr: &mut Hugr,
    root: Node,
    other: &impl HugrView,
) -> (Node, HashMap<NodeIndex, NodeIndex>) {
    let source_graph = other.portgraph();
    let node_map = hugr
        .graph
        .insert_graph(&source_graph)
        .unwrap_or_else(|e| panic!("Internal error while inserting a hugr into another: {e}"));

    let copied_root = node_map[&other.root().pg_index()];
    hugr.hierarchy
        .push_child(copied_root, root.pg_index())
        .expect("Inserting a newly-created node into the hierarchy should never fail.");

    // Recreate the parent/child relations among the copied nodes.
    for (&old_parent, &new_parent) in &node_map {
        for child in other.children(old_parent.into()) {
            let new_child = node_map[&child.pg_index()];
            hugr.hierarchy
                .push_child(new_child, new_parent)
                .expect("Inserting a newly-created node into the hierarchy should never fail.");
        }
    }

    // Size the copied root's ports explicitly from its optype.
    // NOTE(review): presumably a view's port graph may not expose the root's
    // ports; confirm.
    let root_op = other.get_optype(other.root());
    let (inputs, outputs) = (root_op.input_count(), root_op.output_count());
    hugr.set_num_ports(copied_root.into(), inputs, outputs);

    (copied_root.into(), node_map)
}
/// Copies the nodes and links of `portgraph` (a filtered view of `other`)
/// into `hugr`.
///
/// A copied node whose original parent was also copied keeps that parent (in
/// its new position). Every other copied node becomes a child of `root`.
/// Op types and metadata are left for the caller to transfer.
///
/// Returns a map from `other`'s node indices to the new indices in `hugr`.
///
/// # Panics
///
/// If inserting the port graph or attaching a node in the hierarchy fails.
fn insert_subgraph_internal(
    hugr: &mut Hugr,
    root: Node,
    other: &impl HugrView,
    portgraph: &impl portgraph::LinkView,
) -> HashMap<NodeIndex, NodeIndex> {
    let node_map = hugr
        .graph
        .insert_graph(&portgraph)
        .expect("Internal error while inserting a subgraph into another");

    let fallback_parent = root.pg_index();
    for (&old_node, &new_node) in &node_map {
        let copied_parent = other
            .get_parent(old_node.into())
            .and_then(|p| node_map.get(&p.pg_index()).copied());
        let parent = match copied_parent {
            Some(p) => p,
            None => fallback_parent,
        };
        hugr.hierarchy
            .push_child(new_node, parent)
            .expect("Inserting a newly-created node into the hierarchy should never fail.");
    }

    node_map
}
/// Panics with a Mermaid rendering of `hugr` if `node` is not valid in it.
#[track_caller]
pub(super) fn panic_invalid_node<H: HugrView + ?Sized>(hugr: &H, node: Node) {
    if hugr.valid_node(node) {
        return;
    }
    panic!(
        "Received an invalid node {node} while mutating a HUGR:\n\n {}",
        hugr.mermaid_string()
    );
}
/// Panics with a Mermaid rendering of `hugr` if `node` is invalid or is the
/// root of `hugr`.
#[track_caller]
pub(super) fn panic_invalid_non_root<H: HugrView + ?Sized>(hugr: &H, node: Node) {
    if hugr.valid_non_root(node) {
        return;
    }
    panic!(
        "Received an invalid non-root node {node} while mutating a HUGR:\n\n {}",
        hugr.mermaid_string()
    );
}
/// Panics with a Mermaid rendering of `hugr` if `port` does not exist on
/// `node`.
///
/// A node that is not in the port graph has no ports, so this check also
/// rejects invalid nodes.
#[track_caller]
pub(super) fn panic_invalid_port<H: HugrView + ?Sized>(
    hugr: &H,
    node: Node,
    port: impl Into<Port>,
) {
    let port: Port = port.into();
    let port_exists = hugr
        .portgraph()
        .port_index(node.pg_index(), port.pg_offset())
        .is_some();
    if !port_exists {
        panic!(
            "Received an invalid port {port} for node {node} while mutating a HUGR:\n\n {}",
            hugr.mermaid_string()
        );
    }
}
#[cfg(test)]
mod test {
    use crate::{
        extension::{
            prelude::{Noop, USIZE_T},
            PRELUDE_REGISTRY,
        },
        macros::type_row,
        ops::{self, dataflow::IOTrait},
        types::{Signature, Type},
    };

    use super::*;

    /// Wire type used throughout these tests.
    const NAT: Type = USIZE_T;

    /// Builds a module with one function that routes its single input through
    /// a no-op into both of its outputs, then checks that the HUGR validates.
    #[test]
    fn simple_function() -> Result<(), Box<dyn std::error::Error>> {
        let mut hugr = Hugr::default();
        let module: Node = hugr.root();

        let signature = Signature::new(type_row![NAT], type_row![NAT, NAT]).with_prelude();
        let func: Node = hugr.add_node_with_parent(
            module,
            ops::FuncDefn {
                name: "main".into(),
                signature: signature.into(),
            },
        );

        let input = hugr.add_node_with_parent(func, ops::Input::new(type_row![NAT]));
        let output = hugr.add_node_with_parent(func, ops::Output::new(type_row![NAT, NAT]));
        let noop = hugr.add_node_with_parent(func, Noop(NAT));

        // The single no-op output feeds both function outputs.
        hugr.connect(input, 0, noop, 0);
        hugr.connect(noop, 0, output, 0);
        hugr.connect(noop, 0, output, 1);

        hugr.update_validate(&PRELUDE_REGISTRY)?;
        Ok(())
    }
}