use std::ops::Range;
use std::sync::OnceLock;
use itertools::Itertools;
use portgraph::{LinkMut, LinkView, MultiPortGraph, PortMut, PortOffset, PortView};
use crate::core::HugrNode;
use crate::extension::ExtensionRegistry;
use crate::{Direction, Hugr, Node};
use super::HugrView;
use super::views::{panic_invalid_node, panic_invalid_non_entrypoint};
use super::{NodeMetadataMap, OpType};
use crate::ops::handle::NodeHandle;
/// Low-level access to the portgraph representation and per-node metadata
/// underlying a HUGR type.
pub trait HugrInternals {
    /// The portgraph type used to describe a region of the HUGR.
    ///
    /// It must be cloneable, borrow from `Self` for at most `'p`, and use `u32`
    /// for its node, port and port-offset indices. For [`Hugr`] this is a
    /// shared reference to its `MultiPortGraph`.
    type RegionPortgraph<'p>: LinkView<LinkEndpoint: Eq, NodeIndexBase = u32, PortIndexBase = u32, PortOffsetBase = u32>
        + Clone
        + 'p
    where
        Self: 'p;
    /// The node identifier exposed by this HUGR type.
    type Node: Copy + Ord + std::fmt::Debug + std::fmt::Display + std::hash::Hash;
    /// Translates between `Self::Node` values and the node indices of
    /// [`Self::RegionPortgraph`].
    type RegionPortgraphNodes: PortgraphNodeMap<Self::Node>;
    /// Returns the metadata map attached to `node`.
    ///
    /// NOTE(review): the [`Hugr`] implementation returns a shared empty map
    /// for nodes without metadata and panics on invalid nodes; other
    /// implementors may differ — confirm before relying on it generically.
    fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap;
}
/// A bidirectional mapping between nodes of type `N` and the
/// [`portgraph::NodeIndex`]es of an underlying portgraph.
pub trait PortgraphNodeMap<N>: Clone + Sized + std::fmt::Debug {
    /// Returns the portgraph index corresponding to `node`.
    fn to_portgraph(&self, node: N) -> portgraph::NodeIndex;
    /// Returns the node corresponding to the portgraph index `node`.
    // Clippy expects `from_*` methods to be constructors without `self`; here
    // the method is a lookup that needs the map itself, hence the allowance.
    #[allow(clippy::wrong_self_convention)]
    fn from_portgraph(&self, node: portgraph::NodeIndex) -> N;
}
/// The identity [`PortgraphNodeMap`] for [`Node`]: each node is converted
/// directly to and from its own portgraph index, so the map stores no data.
#[derive(
    Copy, Clone, Debug, Default, Eq, PartialEq, Hash, PartialOrd, Ord, derive_more::Display,
)]
pub struct DefaultPGNodeMap;
impl PortgraphNodeMap<Node> for DefaultPGNodeMap {
    /// Unwraps a [`Node`] into the portgraph index it is built on.
    #[inline]
    fn to_portgraph(&self, hugr_node: Node) -> portgraph::NodeIndex {
        hugr_node.into_portgraph()
    }

    /// Wraps a raw portgraph index back into a [`Node`].
    #[inline]
    fn from_portgraph(&self, pg_index: portgraph::NodeIndex) -> Node {
        pg_index.into()
    }
}
impl<N: HugrNode> PortgraphNodeMap<N> for std::collections::HashMap<N, Node> {
    /// Looks up the [`Node`] that `node` is mapped to and returns its
    /// portgraph index.
    ///
    /// # Panics
    ///
    /// If `node` is not a key of the map.
    #[inline]
    fn to_portgraph(&self, node: N) -> portgraph::NodeIndex {
        // Explicit `expect` (instead of `self[&node]`) so the panic message
        // matches the reverse direction below and names the failure.
        self.get(&node)
            .expect("Node not found in map")
            .into_portgraph()
    }

    /// Finds the key whose mapped [`Node`] has the portgraph index `node`.
    ///
    /// Only the forward direction is stored, so this is a linear scan over
    /// the map. If several keys map to the same node, which key is returned
    /// is unspecified (it depends on the map's iteration order).
    ///
    /// # Panics
    ///
    /// If no key maps to `node`.
    #[inline]
    fn from_portgraph(&self, node: portgraph::NodeIndex) -> N {
        let target: Node = node.into();
        self.iter()
            .find_map(|(&k, &v)| (v == target).then_some(k))
            .expect("Portgraph node not found in map")
    }
}
impl HugrInternals for Hugr {
    type RegionPortgraph<'p>
        = &'p MultiPortGraph<u32, u32, u32>
    where
        Self: 'p;

    type Node = Node;

    type RegionPortgraphNodes = DefaultPGNodeMap;

    /// Returns the metadata attached to `node`, or a shared empty map when the
    /// node has no metadata.
    ///
    /// # Panics
    ///
    /// If `node` is rejected by `panic_invalid_node` (presumably: not a valid
    /// node of this HUGR).
    #[inline]
    fn node_metadata_map(&self, node: Self::Node) -> &NodeMetadataMap {
        // A single process-wide empty map lets nodes without metadata still
        // return a reference without allocating per call.
        static EMPTY: OnceLock<NodeMetadataMap> = OnceLock::new();
        panic_invalid_node(self, node);
        self.metadata
            .get(node.into_portgraph())
            .as_ref()
            // Lazily touch the static only when the node has no metadata.
            .unwrap_or_else(|| EMPTY.get_or_init(Default::default))
    }
}
/// Low-level mutation of a HUGR's port counts, hierarchy, operations,
/// per-node metadata and extension registry.
pub trait HugrMutInternals: HugrView {
    /// Sets `root` as the module root of the HUGR.
    ///
    /// In the [`Hugr`] implementation, `root` is first detached from its
    /// parent in the hierarchy.
    fn set_module_root(&mut self, root: Self::Node);
    /// Sets the number of incoming and outgoing ports of `node`.
    fn set_num_ports(&mut self, node: Self::Node, incoming: usize, outgoing: usize);
    /// Adds `amount` ports to `node` in `direction` (removing ports if
    /// `amount` is negative) and returns the range `old_count..new_count`.
    ///
    /// In the [`Hugr`] implementation, removal saturates at zero ports, and for
    /// a negative `amount` the returned range is empty (`start >= end`).
    fn add_ports(&mut self, node: Self::Node, direction: Direction, amount: isize) -> Range<usize>;
    /// Inserts `amount` new ports on `node` in `direction`, starting at offset
    /// `index`, and returns `index..index + amount`.
    ///
    /// In the [`Hugr`] implementation, links on the ports at offsets `>= index`
    /// are moved up by `amount`, so existing connections are preserved.
    /// NOTE(review): if `index` exceeds the current port count, nothing is
    /// shifted and the new ports end up appended at the end, so the returned
    /// range does not describe them — confirm callers never do this.
    fn insert_ports(
        &mut self,
        node: Self::Node,
        direction: Direction,
        index: usize,
        amount: usize,
    ) -> Range<usize>;
    /// Detaches `node` from its current parent and attaches it as a child of
    /// `parent`.
    fn set_parent(&mut self, node: Self::Node, parent: Self::Node);
    /// Detaches `node` and re-inserts it after its sibling `after` in the
    /// hierarchy.
    fn move_after_sibling(&mut self, node: Self::Node, after: Self::Node);
    /// Detaches `node` and re-inserts it before its sibling `before` in the
    /// hierarchy.
    fn move_before_sibling(&mut self, node: Self::Node, before: Self::Node);
    /// Replaces the operation of `node` with `op`, returning the previous one.
    fn replace_op(&mut self, node: Self::Node, op: impl Into<OpType>) -> OpType;
    /// Returns a mutable reference to the operation of `node`.
    fn optype_mut(&mut self, node: Self::Node) -> &mut OpType;
    /// Returns a mutable reference to the metadata map of `node`.
    ///
    /// In the [`Hugr`] implementation, an empty map is created if the node
    /// has none yet.
    fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut NodeMetadataMap;
    /// Returns a mutable reference to the HUGR's extension registry.
    fn extensions_mut(&mut self) -> &mut ExtensionRegistry;
}
impl HugrMutInternals for Hugr {
    /// Makes `root` the module root, detaching it from any parent first.
    fn set_module_root(&mut self, root: Node) {
        panic_invalid_node(self, root.node());
        let root = root.into_portgraph();
        self.hierarchy.detach(root);
        self.module_root = root;
    }

    #[inline]
    fn set_num_ports(&mut self, node: Node, incoming: usize, outgoing: usize) {
        panic_invalid_node(self, node);
        self.graph
            .set_num_ports(node.into_portgraph(), incoming, outgoing, |_, _| {});
    }

    /// Grows (or, for a negative `amount`, shrinks) the port list in
    /// `direction`, leaving the other direction's count unchanged.
    ///
    /// Returns `old_count..new_count`. Removal saturates at zero ports; for a
    /// negative `amount` the returned range is empty (`start >= end`).
    fn add_ports(&mut self, node: Node, direction: Direction, amount: isize) -> Range<usize> {
        panic_invalid_node(self, node);
        let mut incoming = self.graph.num_inputs(node.into_portgraph());
        let mut outgoing = self.graph.num_outputs(node.into_portgraph());
        let increment = |num: &mut usize| {
            let new = num.saturating_add_signed(amount);
            let range = *num..new;
            *num = new;
            range
        };
        let range = match direction {
            Direction::Incoming => increment(&mut incoming),
            Direction::Outgoing => increment(&mut outgoing),
        };
        self.graph
            .set_num_ports(node.into_portgraph(), incoming, outgoing, |_, _| {});
        range
    }

    /// Inserts `amount` ports at offset `index`, shifting the links of every
    /// port at offset `>= index` up by `amount` so connections are preserved.
    ///
    /// Returns `index..index + amount`.
    ///
    /// # Panics
    ///
    /// If the node is rejected by `panic_invalid_node`, or if `amount` does
    /// not fit in an `isize`.
    ///
    /// NOTE(review): if `index` exceeds the current port count, no links are
    /// shifted and the new ports are appended at the end, so the returned
    /// range does not describe them.
    fn insert_ports(
        &mut self,
        node: Node,
        direction: Direction,
        index: usize,
        amount: usize,
    ) -> Range<usize> {
        panic_invalid_node(self, node);
        let old_num_ports = self.graph.num_ports(node.into_portgraph(), direction);
        // A plain `as isize` cast would wrap for huge values and turn the
        // insertion into a removal of ports.
        let signed_amount =
            isize::try_from(amount).expect("Number of ports to insert must fit in an isize");
        self.add_ports(node, direction, signed_amount);
        // Walk from the last old port down to `index`, so each link is moved
        // into a port that is either freshly added or has already been vacated.
        for swap_from_port in (index..old_num_ports).rev() {
            let swap_to_port = swap_from_port + amount;
            // Port indices are looked up after `add_ports`, which may have
            // changed the node's port layout.
            let [from_port_index, to_port_index] = [swap_from_port, swap_to_port].map(|p| {
                self.graph
                    .port_index(node.into_portgraph(), PortOffset::new(direction, p))
                    .unwrap()
            });
            // Collect first: the graph is mutated while re-linking below.
            let linked_ports = self
                .graph
                .port_links(from_port_index)
                .map(|(_, to_subport)| to_subport.port())
                .collect_vec();
            self.graph.unlink_port(from_port_index);
            for linked_port_index in linked_ports {
                let _ = self
                    .graph
                    .link_ports(to_port_index, linked_port_index)
                    .expect("Ports exist");
            }
        }
        index..index + amount
    }

    /// Detaches `node` and appends it as a child of `parent`.
    fn set_parent(&mut self, node: Node, parent: Node) {
        panic_invalid_node(self, parent);
        panic_invalid_node(self, node);
        self.hierarchy.detach(node.into_portgraph());
        self.hierarchy
            .push_child(node.into_portgraph(), parent.into_portgraph())
            .expect("Inserting a newly-created node into the hierarchy should never fail.");
    }

    /// Detaches `node` and re-inserts it after its sibling `after`.
    fn move_after_sibling(&mut self, node: Node, after: Node) {
        panic_invalid_non_entrypoint(self, node);
        panic_invalid_non_entrypoint(self, after);
        self.hierarchy.detach(node.into_portgraph());
        self.hierarchy
            .insert_after(node.into_portgraph(), after.into_portgraph())
            .expect("Inserting a newly-created node into the hierarchy should never fail.");
    }

    /// Detaches `node` and re-inserts it before its sibling `before`.
    fn move_before_sibling(&mut self, node: Node, before: Node) {
        panic_invalid_non_entrypoint(self, node);
        panic_invalid_non_entrypoint(self, before);
        self.hierarchy.detach(node.into_portgraph());
        self.hierarchy
            .insert_before(node.into_portgraph(), before.into_portgraph())
            .expect("Inserting a newly-created node into the hierarchy should never fail.");
    }

    /// Swaps in `op` as the operation of `node` and returns the old one.
    fn replace_op(&mut self, node: Node, op: impl Into<OpType>) -> OpType {
        panic_invalid_node(self, node);
        std::mem::replace(self.optype_mut(node), op.into())
    }

    fn optype_mut(&mut self, node: Node) -> &mut OpType {
        panic_invalid_node(self, node);
        let node = node.into_portgraph();
        self.op_types.get_mut(node)
    }

    /// Returns the node's metadata map, creating an empty one if absent.
    fn node_metadata_map_mut(&mut self, node: Self::Node) -> &mut NodeMetadataMap {
        panic_invalid_node(self, node);
        self.metadata
            .get_mut(node.into_portgraph())
            .get_or_insert_with(Default::default)
    }

    fn extensions_mut(&mut self) -> &mut ExtensionRegistry {
        &mut self.extensions
    }
}
#[cfg(test)]
mod test {
    use crate::{
        Direction, Hugr, HugrView as _, Node,
        builder::{Container, DFGBuilder, Dataflow, DataflowHugr},
        extension::prelude::Noop,
        hugr::internal::HugrMutInternals as _,
        ops::handle::NodeHandle,
        types::{Signature, Type},
    };

    /// Builds a DFG over one `Type::UNIT` wire that routes its input through a
    /// `Noop`, plus an extra (other-port) wire from the `Noop` to the DFG's
    /// output node. Returns the `Noop` node and the finished Hugr.
    fn noop_dfg() -> (Node, Hugr) {
        let mut builder = DFGBuilder::new(Signature::new_endo([Type::UNIT])).unwrap();
        let [dfg_input] = builder.input_wires_arr();
        let noop = builder
            .add_dataflow_op(Noop::new(Type::UNIT), [dfg_input])
            .unwrap();
        let output_node = builder.output().node();
        builder.add_other_wire(noop.node(), output_node);
        let [noop_output] = noop.outputs_arr();
        let hugr = builder.finish_hugr_with_outputs([noop_output]).unwrap();
        (noop.node(), hugr)
    }

    #[test]
    fn insert_ports() {
        let (noop, mut hugr) = noop_dfg();
        let [input, output] = hugr.get_io(hugr.entrypoint()).unwrap();

        // Two new inputs at the front push the existing input from 0 to 2.
        let new_inputs = hugr.insert_ports(noop, Direction::Incoming, 0, 2);
        assert_eq!(0..2, new_inputs);
        // Two new outputs at offset 1 push the existing port 1 up to 3.
        let new_outputs = hugr.insert_ports(noop, Direction::Outgoing, 1, 2);
        assert_eq!(1..3, new_outputs);

        assert_eq!(hugr.single_linked_input(input, 0), Some((noop, 2.into())));
        assert_eq!(hugr.single_linked_output(output, 0), Some((noop, 0.into())));
        assert_eq!(hugr.single_linked_output(output, 1), Some((noop, 3.into())));
    }
}