use std::iter;
use context_iterators::{ContextIterator, IntoContextIterator, MapWithCtx};
use itertools::{Itertools, MapInto};
use portgraph::{LinkView, MultiPortGraph, PortIndex, PortView};
use crate::hugr::internal::HugrMutInternals;
use crate::hugr::{HugrError, HugrMut};
use crate::ops::handle::NodeHandle;
use crate::{Direction, Hugr, Node, Port};
use super::{check_tag, ExtractHugr, HierarchyView, HugrInternals, HugrView, RootTagged};
type FlatRegionGraph<'g> = portgraph::view::FlatRegion<'g, &'g MultiPortGraph>;
#[derive(Clone)]
pub struct SiblingGraph<'g, Root = Node> {
root: Node,
graph: FlatRegionGraph<'g>,
hugr: &'g Hugr,
_phantom: std::marker::PhantomData<Root>,
}
macro_rules! impl_base_members {
() => {
type Nodes<'a> = iter::Chain<iter::Once<Node>, MapInto<portgraph::hierarchy::Children<'a>, Node>>
where
Self: 'a;
type NodePorts<'a> = MapInto<<FlatRegionGraph<'g> as PortView>::NodePortOffsets<'a>, Port>
where
Self: 'a;
type Children<'a> = MapInto<portgraph::hierarchy::Children<'a>, Node>
where
Self: 'a;
#[inline]
fn node_count(&self) -> usize {
self.base_hugr().hierarchy.child_count(self.root.pg_index()) + 1
}
#[inline]
fn edge_count(&self) -> usize {
self.nodes()
.map(|n| self.output_neighbours(n).count())
.sum()
}
#[inline]
fn nodes(&self) -> Self::Nodes<'_> {
let children = self
.base_hugr()
.hierarchy
.children(self.root.pg_index())
.map_into();
iter::once(self.root).chain(children)
}
fn children(&self, node: Node) -> Self::Children<'_> {
match node == self.root {
true => self.base_hugr().hierarchy.children(node.pg_index()).map_into(),
false => portgraph::hierarchy::Children::default().map_into(),
}
}
};
}
impl<'g, Root: NodeHandle> HugrView for SiblingGraph<'g, Root> {
type Neighbours<'a> = MapInto<<FlatRegionGraph<'g> as LinkView>::Neighbours<'a>, Node>
where
Self: 'a;
type PortLinks<'a> = MapWithCtx<
<FlatRegionGraph<'g> as LinkView>::PortLinks<'a>,
&'a Self,
(Node, Port),
> where
Self: 'a;
type NodeConnections<'a> = MapWithCtx<
<FlatRegionGraph<'g> as LinkView>::NodeConnections<'a>,
&'a Self,
[Port; 2],
> where
Self: 'a;
impl_base_members! {}
#[inline]
fn contains_node(&self, node: Node) -> bool {
self.graph.contains_node(node.pg_index())
}
#[inline]
fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_> {
self.graph.port_offsets(node.pg_index(), dir).map_into()
}
#[inline]
fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_> {
self.graph.all_port_offsets(node.pg_index()).map_into()
}
fn linked_ports(&self, node: Node, port: impl Into<Port>) -> Self::PortLinks<'_> {
let port = self
.graph
.port_index(node.pg_index(), port.into().pg_offset())
.unwrap();
self.graph
.port_links(port)
.with_context(self)
.map_with_context(|(_, link), region| {
let port: PortIndex = link.into();
let node = region.graph.port_node(port).unwrap();
let offset = region.graph.port_offset(port).unwrap();
(node.into(), offset.into())
})
}
fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> {
self.graph
.get_connections(node.pg_index(), other.pg_index())
.with_context(self)
.map_with_context(|(p1, p2), hugr| {
[p1, p2].map(|link| {
let offset = hugr.graph.port_offset(link).unwrap();
offset.into()
})
})
}
#[inline]
fn num_ports(&self, node: Node, dir: Direction) -> usize {
self.graph.num_ports(node.pg_index(), dir)
}
#[inline]
fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_> {
self.graph.neighbours(node.pg_index(), dir).map_into()
}
#[inline]
fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> {
self.graph.all_neighbours(node.pg_index()).map_into()
}
}
impl<'g, Root: NodeHandle> RootTagged for SiblingGraph<'g, Root> {
type RootHandle = Root;
}
impl<'a, Root: NodeHandle> SiblingGraph<'a, Root> {
fn new_unchecked(hugr: &'a impl HugrView, root: Node) -> Self {
let hugr = hugr.base_hugr();
Self {
root,
graph: FlatRegionGraph::new_flat_region(&hugr.graph, &hugr.hierarchy, root.pg_index()),
hugr,
_phantom: std::marker::PhantomData,
}
}
}
impl<'a, Root> HierarchyView<'a> for SiblingGraph<'a, Root>
where
Root: NodeHandle,
{
fn try_new(hugr: &'a impl HugrView, root: Node) -> Result<Self, HugrError> {
assert!(
hugr.valid_node(root),
"Cannot create a sibling graph from an invalid node {}.",
root
);
check_tag::<Root>(hugr, root)?;
Ok(Self::new_unchecked(hugr, root))
}
}
impl<'g, Root: NodeHandle> ExtractHugr for SiblingGraph<'g, Root> {}
impl<'g, Root> HugrInternals for SiblingGraph<'g, Root>
where
Root: NodeHandle,
{
type Portgraph<'p> = &'p FlatRegionGraph<'g> where Self: 'p;
#[inline]
fn portgraph(&self) -> Self::Portgraph<'_> {
&self.graph
}
#[inline]
fn base_hugr(&self) -> &Hugr {
self.hugr
}
#[inline]
fn root_node(&self) -> Node {
self.root
}
}
pub struct SiblingMut<'g, Root = Node> {
root: Node,
hugr: &'g mut Hugr,
_phantom: std::marker::PhantomData<Root>,
}
impl<'g, Root: NodeHandle> SiblingMut<'g, Root> {
pub fn try_new<Base: HugrMut>(hugr: &'g mut Base, root: Node) -> Result<Self, HugrError> {
if root == hugr.root() && !Base::RootHandle::TAG.is_superset(Root::TAG) {
return Err(HugrError::InvalidTag {
required: Base::RootHandle::TAG,
actual: Root::TAG,
});
}
check_tag::<Root>(hugr, root)?;
Ok(Self {
hugr: hugr.hugr_mut(),
root,
_phantom: std::marker::PhantomData,
})
}
}
impl<'g, Root: NodeHandle> ExtractHugr for SiblingMut<'g, Root> {}
impl<'g, Root: NodeHandle> HugrInternals for SiblingMut<'g, Root> {
type Portgraph<'p> = FlatRegionGraph<'p> where 'g: 'p, Root: 'p;
fn portgraph(&self) -> Self::Portgraph<'_> {
FlatRegionGraph::new_flat_region(
&self.base_hugr().graph,
&self.base_hugr().hierarchy,
self.root.pg_index(),
)
}
fn base_hugr(&self) -> &Hugr {
self.hugr
}
fn root_node(&self) -> Node {
self.root
}
}
impl<'g, Root: NodeHandle> HugrView for SiblingMut<'g, Root> {
type Neighbours<'a> = <Vec<Node> as IntoIterator>::IntoIter
where
Self: 'a;
type PortLinks<'a> = <Vec<(Node, Port)> as IntoIterator>::IntoIter
where
Self: 'a;
type NodeConnections<'a> = <Vec<[Port; 2]> as IntoIterator>::IntoIter where Self: 'a;
impl_base_members! {}
fn contains_node(&self, node: Node) -> bool {
node == self.root || self.base_hugr().get_parent(node) == Some(self.root)
}
fn node_ports(&self, node: Node, dir: Direction) -> Self::NodePorts<'_> {
match self.contains_node(node) {
true => self.base_hugr().node_ports(node, dir),
false => <FlatRegionGraph as PortView>::NodePortOffsets::default().map_into(),
}
}
fn all_node_ports(&self, node: Node) -> Self::NodePorts<'_> {
match self.contains_node(node) {
true => self.base_hugr().all_node_ports(node),
false => <FlatRegionGraph as PortView>::NodePortOffsets::default().map_into(),
}
}
fn linked_ports(&self, node: Node, port: impl Into<Port>) -> Self::PortLinks<'_> {
SiblingGraph::<'_, Node>::new_unchecked(self.hugr, self.root)
.linked_ports(node, port)
.collect::<Vec<_>>()
.into_iter()
}
fn node_connections(&self, node: Node, other: Node) -> Self::NodeConnections<'_> {
SiblingGraph::<'_, Node>::new_unchecked(self.hugr, self.root)
.node_connections(node, other)
.collect::<Vec<_>>()
.into_iter()
}
fn num_ports(&self, node: Node, dir: Direction) -> usize {
match self.contains_node(node) {
true => self.base_hugr().num_ports(node, dir),
false => 0,
}
}
fn neighbours(&self, node: Node, dir: Direction) -> Self::Neighbours<'_> {
SiblingGraph::<'_, Node>::new_unchecked(self.hugr, self.root)
.neighbours(node, dir)
.collect::<Vec<_>>()
.into_iter()
}
fn all_neighbours(&self, node: Node) -> Self::Neighbours<'_> {
SiblingGraph::<'_, Node>::new_unchecked(self.hugr, self.root)
.all_neighbours(node)
.collect::<Vec<_>>()
.into_iter()
}
}
impl<'g, Root: NodeHandle> RootTagged for SiblingMut<'g, Root> {
type RootHandle = Root;
}
impl<'g, Root: NodeHandle> HugrMutInternals for SiblingMut<'g, Root> {
fn hugr_mut(&mut self) -> &mut Hugr {
self.hugr
}
}
impl<'g, Root: NodeHandle> HugrMut for SiblingMut<'g, Root> {}
#[cfg(test)]
mod test {
use rstest::rstest;
use crate::builder::test::simple_dfg_hugr;
use crate::builder::{Container, Dataflow, DataflowSubContainer, HugrBuilder, ModuleBuilder};
use crate::extension::PRELUDE_REGISTRY;
use crate::ops::handle::{CfgID, DataflowParentID, DfgID, FuncID};
use crate::ops::{dataflow::IOTrait, Input, OpTag, Output};
use crate::ops::{OpTrait, OpType};
use crate::type_row;
use crate::types::{FunctionType, Type};
use super::super::descendants::test::make_module_hgr;
use super::*;
#[test]
fn flat_region() -> Result<(), Box<dyn std::error::Error>> {
let (hugr, def, inner) = make_module_hgr()?;
let region: SiblingGraph = SiblingGraph::try_new(&hugr, def)?;
assert_eq!(region.node_count(), 5);
assert!(region
.nodes()
.all(|n| n == def || hugr.get_parent(n) == Some(def)));
assert_eq!(region.children(inner).count(), 0);
Ok(())
}
const NAT: Type = crate::extension::prelude::USIZE_T;
#[test]
fn nested_flat() -> Result<(), Box<dyn std::error::Error>> {
let mut module_builder = ModuleBuilder::new();
let fty = FunctionType::new(type_row![NAT], type_row![NAT]);
let mut fbuild = module_builder.define_function("main", fty.clone())?;
let dfg = fbuild.dfg_builder(fty, fbuild.input_wires())?;
let ins = dfg.input_wires();
let sub_dfg = dfg.finish_with_outputs(ins)?;
let fun = fbuild.finish_with_outputs(sub_dfg.outputs())?;
let h = module_builder.finish_hugr(&PRELUDE_REGISTRY)?;
let sub_dfg = sub_dfg.node();
let dfg_view: SiblingGraph<'_, DfgID> = SiblingGraph::try_new(&h, sub_dfg)?;
let fun_view: SiblingGraph<'_, FuncID<true>> = SiblingGraph::try_new(&h, fun.node())?;
assert_eq!(fun_view.children(sub_dfg).len(), 0);
let nested_dfg_view: SiblingGraph<'_, DfgID> = SiblingGraph::try_new(&fun_view, sub_dfg)?;
let just_io = vec![
Input::new(type_row![NAT]).into(),
Output::new(type_row![NAT]).into(),
];
for d in [dfg_view, nested_dfg_view] {
assert_eq!(
d.children(sub_dfg).map(|n| d.get_optype(n)).collect_vec(),
just_io.iter().collect_vec()
);
}
Ok(())
}
#[rstest]
fn flat_mut(mut simple_dfg_hugr: Hugr) {
simple_dfg_hugr.update_validate(&PRELUDE_REGISTRY).unwrap();
let root = simple_dfg_hugr.root();
let signature = simple_dfg_hugr.get_df_function_type().unwrap().clone();
let sib_mut = SiblingMut::<CfgID>::try_new(&mut simple_dfg_hugr, root);
assert_eq!(
sib_mut.err(),
Some(HugrError::InvalidTag {
required: OpTag::Cfg,
actual: OpTag::Dfg
})
);
let mut sib_mut = SiblingMut::<DfgID>::try_new(&mut simple_dfg_hugr, root).unwrap();
let bad_nodetype: OpType = crate::ops::CFG { signature }.into();
assert_eq!(
sib_mut.replace_op(sib_mut.root(), bad_nodetype.clone()),
Err(HugrError::InvalidTag {
required: OpTag::Dfg,
actual: OpTag::Cfg
})
);
simple_dfg_hugr.replace_op(root, bad_nodetype).unwrap();
assert!(simple_dfg_hugr.validate(&PRELUDE_REGISTRY).is_err());
}
#[rstest]
fn sibling_mut_covariance(mut simple_dfg_hugr: Hugr) {
let root = simple_dfg_hugr.root();
let case_nodetype = crate::ops::Case {
signature: simple_dfg_hugr.root_type().dataflow_signature().unwrap(),
};
let mut sib_mut = SiblingMut::<DfgID>::try_new(&mut simple_dfg_hugr, root).unwrap();
assert_eq!(
sib_mut.replace_op(root, case_nodetype),
Err(HugrError::InvalidTag {
required: OpTag::Dfg,
actual: OpTag::Case
})
);
let nested_sib_mut = SiblingMut::<DataflowParentID>::try_new(&mut sib_mut, root);
assert!(nested_sib_mut.is_err());
}
#[rstest]
fn extract_hugr() -> Result<(), Box<dyn std::error::Error>> {
let (hugr, _def, inner) = make_module_hgr()?;
let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?;
let extracted = region.extract_hugr();
extracted.validate(&PRELUDE_REGISTRY)?;
let region: SiblingGraph = SiblingGraph::try_new(&hugr, inner)?;
assert_eq!(region.node_count(), extracted.node_count());
assert_eq!(region.root_type(), extracted.root_type());
Ok(())
}
}