hugr-core 0.27.1

Quantinuum's Hierarchical Unified Graph Representation
Documentation
//! Data structure summarizing static nodes of a Hugr and their uses
use std::collections::BTreeMap;

use crate::{HugrView, Node, core::HugrNode, ops::OpType};
use petgraph::{Graph, visit::EdgeRef};

/// Weight for an edge in a [`ModuleGraph`].
///
/// Each variant records the Hugr node (a use) that lies within the edge's source and
/// statically refers to the edge's target.
// Note: the derived `PartialOrd`/`Ord` compare variants in declaration order, so the order
// of variants below is observable.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[non_exhaustive]
pub enum StaticEdge<N = Node> {
    /// Edge corresponds to a [`Call`](OpType::Call) node (specified) in the Hugr;
    /// the edge's target is the called function.
    Call(N),
    /// Edge corresponds to a [`LoadFunction`](OpType::LoadFunction) node (specified) in the Hugr;
    /// the edge's target is the loaded function.
    LoadFunction(N),
    /// Edge corresponds to a [`LoadConstant`](OpType::LoadConstant) node (specified) in the Hugr;
    /// the edge's target is the loaded (module-level) constant.
    LoadConstant(N),
}

/// Weight for a petgraph-node in a [`ModuleGraph`].
///
/// Every variant except [`StaticNode::NonFuncEntrypoint`] records the corresponding
/// module-level Hugr node.
// Note: the derived `PartialOrd`/`Ord` compare variants in declaration order, so the order
// of variants below is observable.
#[derive(Clone, Debug, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[non_exhaustive]
pub enum StaticNode<N = Node> {
    /// petgraph-node corresponds to a [`FuncDecl`](OpType::FuncDecl) node (specified) in the Hugr
    FuncDecl(N),
    /// petgraph-node corresponds to a [`FuncDefn`](OpType::FuncDefn) node (specified) in the Hugr
    FuncDefn(N),
    /// petgraph-node corresponds to the [HugrView::entrypoint], when that is not
    /// a module-level [`FuncDefn`](OpType::FuncDefn), [`FuncDecl`](OpType::FuncDecl) or
    /// [`Const`](OpType::Const) (each of which already has its own petgraph-node).
    /// Note that it will not be a [Module](OpType::Module) either; such an entrypoint
    /// is not represented in the petgraph.
    ///
    /// Outgoing edges record the uses beneath the entrypoint. Those same uses are also
    /// recorded as edges from any module-level function containing the entrypoint.
    /// There are no incoming edges.
    NonFuncEntrypoint,
    /// petgraph-node corresponds to a module-level constant (specified); will have no
    /// outgoing edges, and incoming edges will be [StaticEdge::LoadConstant]
    Const(N),
}

/// Details the [`FuncDefn`]s, [`FuncDecl`]s and module-level [`Const`]s in a Hugr,
/// along with the [`Call`]s, [`LoadFunction`]s, and [`LoadConstant`]s connecting them.
///
/// Each node in the `ModuleGraph` corresponds to a module-level function or const;
/// each edge corresponds to a use of the target contained in the edge's source.
///
/// For Hugrs whose entrypoint is neither the [Module](OpType::Module) root nor a
/// module-level [`FuncDefn`], [`FuncDecl`] or [`Const`], the static graph will have an
/// additional [`StaticNode::NonFuncEntrypoint`] corresponding to the Hugr's entrypoint,
/// with no incoming edges.
///
/// [`Call`]: OpType::Call
/// [`Const`]: OpType::Const
/// [`FuncDecl`]: OpType::FuncDecl
/// [`FuncDefn`]: OpType::FuncDefn
/// [`LoadConstant`]: OpType::LoadConstant
/// [`LoadFunction`]: OpType::LoadFunction
#[derive(Clone, Debug)]
pub struct ModuleGraph<N = Node> {
    /// The petgraph itself; node and edge weights identify the corresponding Hugr nodes.
    g: Graph<StaticNode<N>, StaticEdge<N>>,
    /// Maps each Hugr node represented in `g` to its petgraph index.
    node_to_g: BTreeMap<N, petgraph::graph::NodeIndex<u32>>,
}

impl<N: HugrNode> ModuleGraph<N> {
    /// Makes a new `ModuleGraph` for a Hugr.
    ///
    /// Every [`FuncDecl`](OpType::FuncDecl), [`FuncDefn`](OpType::FuncDefn) and
    /// [`Const`](OpType::Const) child of the module root becomes a petgraph-node. If the
    /// entrypoint is neither the module root nor one of those nodes, it is added too, as a
    /// [`StaticNode::NonFuncEntrypoint`].
    ///
    /// Then, for each petgraph-node, every [`Call`](OpType::Call),
    /// [`LoadFunction`](OpType::LoadFunction) and [`LoadConstant`](OpType::LoadConstant)
    /// among the descendants of the corresponding Hugr node becomes an edge, provided its
    /// static source is a child of the module root. Loads of local (non-module-level)
    /// constants are not represented. Uses beneath a non-function entrypoint are therefore
    /// recorded both from the entrypoint and from any module-level function containing it.
    ///
    /// # Panics
    ///
    /// * If a static source that is a child of the module root is not a function or constant.
    /// * If a use whose static source is not a child of the module root is anything other than
    ///   a [`LoadConstant`](OpType::LoadConstant) of a [`Const`](OpType::Const).
    #[must_use]
    pub fn new(hugr: &impl HugrView<Node = N>) -> Self {
        let module_root = hugr.module_root();
        let entrypoint = hugr.entrypoint();
        let mut g = Graph::default();
        let mut node_to_g = hugr
            .children(module_root)
            .filter_map(|n| {
                let weight = match hugr.get_optype(n) {
                    OpType::FuncDecl(_) => StaticNode::FuncDecl(n),
                    OpType::FuncDefn(_) => StaticNode::FuncDefn(n),
                    OpType::Const(_) => StaticNode::Const(n),
                    _ => return None,
                };
                Some((n, g.add_node(weight)))
            })
            .collect::<BTreeMap<_, _>>();
        // An entrypoint that is a module-level function/constant is already present;
        // a Module entrypoint is not represented at all.
        if !hugr.entrypoint_optype().is_module() && !node_to_g.contains_key(&entrypoint) {
            node_to_g.insert(entrypoint, g.add_node(StaticNode::NonFuncEntrypoint));
        }
        for (src, &src_idx) in &node_to_g {
            for n in hugr.descendants(*src) {
                let weight = match hugr.get_optype(n) {
                    OpType::Call(_) => StaticEdge::Call(n),
                    OpType::LoadFunction(_) => StaticEdge::LoadFunction(n),
                    OpType::LoadConstant(_) => StaticEdge::LoadConstant(n),
                    _ => continue,
                };
                let Some(target) = hugr.static_source(n) else {
                    continue;
                };
                if hugr.get_parent(target) == Some(module_root) {
                    let target_idx = *node_to_g
                        .get(&target)
                        .expect("module-level static source should be in the graph");
                    g.add_edge(src_idx, target_idx, weight);
                } else {
                    // Local constant (only module-level constants are in the graph). Such a
                    // constant may nonetheless be the entrypoint, in which case it is present
                    // as the `NonFuncEntrypoint`, which has no incoming edges.
                    assert!(target == entrypoint || !node_to_g.contains_key(&target));
                    assert!(hugr.get_optype(n).is_load_constant());
                    assert!(hugr.get_optype(target).is_const());
                }
            }
        }
        ModuleGraph { g, node_to_g }
    }

    /// Allows access to the petgraph
    #[must_use]
    pub fn graph(&self) -> &Graph<StaticNode<N>, StaticEdge<N>> {
        &self.g
    }

    /// Convert a Hugr node into a petgraph node index.
    ///
    /// Result will be `None` unless `n` is a module-level [`FuncDefn`](OpType::FuncDefn),
    /// [`FuncDecl`](OpType::FuncDecl) or [`Const`](OpType::Const), or is the
    /// [HugrView::entrypoint] (when that is not the module root).
    #[must_use]
    pub fn node_index(&self, n: N) -> Option<petgraph::graph::NodeIndex<u32>> {
        self.node_to_g.get(&n).copied()
    }

    /// Returns an iterator over the out-edges from the given Node, i.e.
    /// edges to the functions/constants called/loaded by it.
    ///
    /// If the node is not recognised as a function or the entrypoint,
    /// for example if it is a [`Const`](OpType::Const), the iterator will be empty.
    pub fn out_edges(&self, n: N) -> impl Iterator<Item = (&StaticEdge<N>, &StaticNode<N>)> {
        let g = self.graph();
        self.node_index(n)
            .into_iter()
            .flat_map(move |idx| g.edges(idx).map(move |e| (&g[e.id()], &g[e.target()])))
    }

    /// Returns an iterator over the in-edges to the given Node, i.e.
    /// edges from the functions (or the non-function entrypoint) that call/load it.
    ///
    /// If the node is not recognised as a function or constant,
    /// for example if it is a non-function entrypoint, the iterator will be empty.
    pub fn in_edges(&self, n: N) -> impl Iterator<Item = (&StaticNode<N>, &StaticEdge<N>)> {
        let g = self.graph();
        self.node_index(n).into_iter().flat_map(move |idx| {
            g.edges_directed(idx, petgraph::Direction::Incoming)
                .map(move |e| (&g[e.source()], &g[e.id()]))
        })
    }
}

#[cfg(test)]
mod test {
    use itertools::Itertools as _;
    use rstest::rstest;
    use std::collections::HashMap;

    use crate::builder::{
        Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder,
        ModuleBuilder, endo_sig, inout_sig,
    };
    use crate::extension::prelude::{ConstUsize, usize_t};
    use crate::hugr::hugrmut::HugrMut;
    use crate::ops::{Value, handle::NodeHandle};

    use super::*;

    /// A module with one constant and two functions, where `caller` loads the constant
    /// and calls `callee`: checks the in- and out-edges of all three petgraph-nodes.
    #[test]
    fn edges() {
        let mut module = ModuleBuilder::new();
        let konst = module.add_constant(Value::from(ConstUsize::new(42)));

        let callee_fb = module.define_function("callee", endo_sig([usize_t()])).unwrap();
        let passthrough = callee_fb.input_wires();
        let callee = callee_fb.finish_with_outputs(passthrough).unwrap();

        let mut caller_fb = module
            .define_function("caller", inout_sig(vec![], [usize_t()]))
            .unwrap();
        let loaded = caller_fb.load_const(&konst);
        let call = caller_fb.call(callee.handle(), &[], vec![loaded]).unwrap();
        let caller = caller_fb.finish_with_outputs(call.outputs()).unwrap();
        let hugr = module.finish_hugr().unwrap();

        let graph = ModuleGraph::new(&hugr);
        let call_edge = StaticEdge::Call(call.node());
        let load_edge = StaticEdge::LoadConstant(loaded.node());
        let caller_node = StaticNode::FuncDefn(caller.node());
        let callee_node = StaticNode::FuncDefn(callee.node());
        let const_node = StaticNode::Const(konst.node());

        // `callee` uses nothing, and is called only by `caller`.
        assert_eq!(graph.out_edges(callee.node()).next(), None);
        assert_eq!(
            graph.in_edges(callee.node()).collect_vec(),
            [(&caller_node, &call_edge)]
        );

        // `caller` uses both `callee` and the constant, and is itself unused.
        assert_eq!(
            graph.out_edges(caller.node()).collect_vec(),
            [(&call_edge, &callee_node), (&load_edge, &const_node)]
        );
        assert_eq!(graph.in_edges(caller.node()).next(), None);

        // The constant uses nothing, and is loaded only by `caller`.
        assert_eq!(graph.out_edges(konst.node()).next(), None);
        assert_eq!(
            graph.in_edges(konst.node()).collect_vec(),
            [(&caller_node, &load_edge)]
        );
    }

    /// A call beneath a non-function entrypoint (either the DFG inside `main`, or the
    /// `Call` node itself) is recorded both from the entrypoint and from `main`.
    #[rstest]
    fn entrypoint(#[values(true, false)] single_node: bool) {
        let mut dfg = DFGBuilder::new(endo_sig([usize_t()])).unwrap();
        let called = dfg
            .module_root_builder()
            .declare("called", endo_sig([usize_t()]).into())
            .unwrap();
        let inputs = dfg.input_wires();
        let call = dfg.call(&called, &[], inputs).unwrap();
        let call_node = call.node();
        let mut hugr = dfg.finish_hugr_with_outputs(call.outputs()).unwrap();
        let main = hugr.get_parent(hugr.entrypoint()).unwrap();
        if single_node {
            hugr.set_entrypoint(call_node);
        }

        let graph = ModuleGraph::new(&hugr);
        let call_edge = StaticEdge::Call(call_node);
        let main_node = StaticNode::FuncDefn(main);
        let decl_node = StaticNode::FuncDecl(called.node());

        let mut callers: HashMap<_, _> = graph.in_edges(called.node()).collect();
        assert_eq!(
            callers.remove(&StaticNode::NonFuncEntrypoint),
            Some(&call_edge)
        );
        assert_eq!(callers, HashMap::from([(&main_node, &call_edge)]));

        let expected_out = vec![(&call_edge, &decl_node)];
        for n in [hugr.entrypoint(), main] {
            assert_eq!(graph.out_edges(n).collect_vec(), expected_out);
        }
    }
}