Skip to main content

hugr_core/hugr/views/
root_checked.rs

1use std::marker::PhantomData;
2
3use crate::hugr::HugrError;
4use crate::ops::handle::NodeHandle;
5use crate::ops::{OpTag, OpTrait};
6use crate::{Hugr, Node};
7
8use super::HugrView;
9
10mod dfg;
11pub use dfg::InvalidSignature;
12
13/// A wrapper over a Hugr that ensures the entrypoint optype is of the required
14/// [`OpTag`].
15#[derive(Clone)]
16pub struct RootChecked<H, Handle = Node>(H, PhantomData<Handle>);
17
18impl<H: HugrView, Handle: NodeHandle<H::Node>> RootChecked<H, Handle> {
19    /// A tag that can contain the operation of the hugr entrypoint node.
20    const TAG: OpTag = Handle::TAG;
21
22    /// Returns the most specific tag that can be applied to the entrypoint node.
23    pub fn tag(&self) -> OpTag {
24        let tag = self.0.get_optype(self.0.entrypoint()).tag();
25        debug_assert!(Self::TAG.is_superset(tag));
26        tag
27    }
28
29    /// Create a hierarchical view of a whole HUGR
30    ///
31    /// # Errors
32    /// Returns [`HugrError::InvalidTag`] if the entrypoint isn't a node of the required [`OpTag`]
33    ///
34    /// [`OpTag`]: crate::ops::OpTag
35    pub fn try_new(hugr: H) -> Result<Self, HugrError> {
36        Self::check(&hugr)?;
37        Ok(Self(hugr, PhantomData))
38    }
39
40    /// Check if a Hugr is valid for the given [`OpTag`].
41    ///
42    /// To check arbitrary nodes, use [`check_tag`].
43    pub fn check(hugr: &H) -> Result<(), HugrError> {
44        check_tag::<Handle, _>(hugr, hugr.entrypoint())?;
45        Ok(())
46    }
47
48    /// Returns a reference to the underlying Hugr.
49    pub fn hugr(&self) -> &H {
50        &self.0
51    }
52
53    /// Extracts the underlying Hugr
54    pub fn into_hugr(self) -> H {
55        self.0
56    }
57
58    /// Returns a wrapper over a reference to the underlying Hugr.
59    pub fn as_ref(&self) -> RootChecked<&H, Handle> {
60        RootChecked(&self.0, PhantomData)
61    }
62}
63
64impl<H: AsRef<Hugr>, Handle> AsRef<Hugr> for RootChecked<H, Handle> {
65    fn as_ref(&self) -> &Hugr {
66        self.0.as_ref()
67    }
68}
69
70/// Check that the node in a HUGR can be represented by the required tag.
71pub fn check_tag<Required: NodeHandle<N>, N>(
72    hugr: &impl HugrView<Node = N>,
73    node: N,
74) -> Result<(), HugrError> {
75    let actual = hugr.get_optype(node).tag();
76    let required = Required::TAG;
77    if !required.is_superset(actual) {
78        return Err(HugrError::InvalidTag { required, actual });
79    }
80    Ok(())
81}
82
#[cfg(test)]
mod test {
    use super::RootChecked;
    use crate::hugr::HugrError;
    use crate::ops::handle::{CfgID, DfgID};
    use crate::ops::{OpTag, OpType};
    use crate::{Hugr, ops, types::Signature};

    /// A DFG-rooted HUGR must be rejected when a CFG root is required and
    /// accepted when a DFG root is required.
    #[test]
    fn root_checked() {
        let empty_dfg = ops::DFG {
            signature: Signature::new(vec![], vec![]),
        };
        let entry_op: OpType = empty_dfg.into();
        let mut hugr = Hugr::new_with_entrypoint(entry_op).unwrap();

        // Checking against the wrong handle type reports both tags.
        let expected = HugrError::InvalidTag {
            required: OpTag::Cfg,
            actual: OpTag::Dfg,
        };
        assert_eq!(RootChecked::<_, CfgID>::check(&hugr).err(), Some(expected));

        // Wrapping with the matching handle type succeeds.
        let checked = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();
        let tag = checked.tag();
        assert!(OpTag::Dfg.is_superset(tag));
        assert_eq!(checked.as_ref().tag(), tag);
    }
}