hugr_core/hugr/views/
root_checked.rs

1use std::marker::PhantomData;
2
3use crate::hugr::HugrError;
4use crate::ops::handle::NodeHandle;
5use crate::ops::{OpTag, OpTrait};
6use crate::{Hugr, Node};
7
8use super::HugrView;
9
10mod dfg;
11pub use dfg::InvalidSignature;
12
13/// A wrapper over a Hugr that ensures the entrypoint optype is of the required
14/// [`OpTag`].
15#[derive(Clone)]
16pub struct RootChecked<H, Handle = Node>(H, PhantomData<Handle>);
17
18impl<H: HugrView, Handle: NodeHandle<H::Node>> RootChecked<H, Handle> {
19    /// A tag that can contain the operation of the hugr entrypoint node.
20    const TAG: OpTag = Handle::TAG;
21
22    /// Returns the most specific tag that can be applied to the entrypoint node.
23    pub fn tag(&self) -> OpTag {
24        let tag = self.0.get_optype(self.0.entrypoint()).tag();
25        debug_assert!(Self::TAG.is_superset(tag));
26        tag
27    }
28
29    /// Create a hierarchical view of a whole HUGR
30    ///
31    /// # Errors
32    /// Returns [`HugrError::InvalidTag`] if the entrypoint isn't a node of the required [`OpTag`]
33    ///
34    /// [`OpTag`]: crate::ops::OpTag
35    pub fn try_new(hugr: H) -> Result<Self, HugrError> {
36        Self::check(&hugr)?;
37        Ok(Self(hugr, PhantomData))
38    }
39
40    /// Check if a Hugr is valid for the given [`OpTag`].
41    ///
42    /// To check arbitrary nodes, use [`check_tag`].
43    pub fn check(hugr: &H) -> Result<(), HugrError> {
44        check_tag::<Handle, _>(hugr, hugr.entrypoint())?;
45        Ok(())
46    }
47
48    /// Returns a reference to the underlying Hugr.
49    pub fn hugr(&self) -> &H {
50        &self.0
51    }
52
53    /// Extracts the underlying Hugr
54    pub fn into_hugr(self) -> H {
55        self.0
56    }
57
58    /// Returns a wrapper over a reference to the underlying Hugr.
59    pub fn as_ref(&self) -> RootChecked<&H, Handle> {
60        RootChecked(&self.0, PhantomData)
61    }
62}
63
64impl<H: AsRef<Hugr>, Handle> AsRef<Hugr> for RootChecked<H, Handle> {
65    fn as_ref(&self) -> &Hugr {
66        self.0.as_ref()
67    }
68}
69
70/// A trait for types that can be checked for a specific [`OpTag`] at their entrypoint node.
71///
72/// This is used mainly specifying function inputs that may either be a [`HugrView`] or an already checked [`RootChecked`].
73pub trait RootCheckable<H: HugrView, Handle: NodeHandle<H::Node>>: Sized {
74    /// Wrap the Hugr in a [`RootChecked`] if it is valid for the required [`OpTag`].
75    ///
76    /// If `Self` is already a [`RootChecked`], it is a no-op.
77    fn try_into_checked(self) -> Result<RootChecked<H, Handle>, HugrError>;
78}
79impl<H: HugrView, Handle: NodeHandle<H::Node>> RootCheckable<H, Handle> for H {
80    fn try_into_checked(self) -> Result<RootChecked<H, Handle>, HugrError> {
81        RootChecked::try_new(self)
82    }
83}
84impl<H: HugrView, Handle: NodeHandle<H::Node>> RootCheckable<H, Handle> for RootChecked<H, Handle> {
85    fn try_into_checked(self) -> Result<RootChecked<H, Handle>, HugrError> {
86        Ok(self)
87    }
88}
89
90/// Check that the node in a HUGR can be represented by the required tag.
91pub fn check_tag<Required: NodeHandle<N>, N>(
92    hugr: &impl HugrView<Node = N>,
93    node: N,
94) -> Result<(), HugrError> {
95    let actual = hugr.get_optype(node).tag();
96    let required = Required::TAG;
97    if !required.is_superset(actual) {
98        return Err(HugrError::InvalidTag { required, actual });
99    }
100    Ok(())
101}
102
#[cfg(test)]
mod test {
    use super::{RootCheckable, RootChecked, check_tag};
    use crate::hugr::{HugrError, HugrView};
    use crate::ops::handle::{CfgID, DfgID};
    use crate::ops::{OpTag, OpType};
    use crate::{Hugr, ops, types::Signature};

    /// Builds a HUGR whose entrypoint is an empty `DFG` node.
    fn dfg_hugr() -> Hugr {
        let root_type: OpType = ops::DFG {
            signature: Signature::new(vec![], vec![]),
        }
        .into();
        Hugr::new_with_entrypoint(root_type).unwrap()
    }

    #[test]
    fn root_checked() {
        let mut h = dfg_hugr();
        let cfg_v = RootChecked::<_, CfgID>::check(&h);
        assert_eq!(
            cfg_v.err(),
            Some(HugrError::InvalidTag {
                required: OpTag::Cfg,
                actual: OpTag::Dfg
            })
        );
        // This should succeed
        let dfg_v = RootChecked::<&mut Hugr, DfgID>::try_new(&mut h).unwrap();
        assert!(OpTag::Dfg.is_superset(dfg_v.tag()));
        assert_eq!(dfg_v.as_ref().tag(), dfg_v.tag());
    }

    #[test]
    fn try_new_rejects_wrong_tag() {
        let h = dfg_hugr();
        assert_eq!(
            RootChecked::<_, CfgID>::try_new(&h).err(),
            Some(HugrError::InvalidTag {
                required: OpTag::Cfg,
                actual: OpTag::Dfg
            })
        );
    }

    #[test]
    fn check_tag_on_entrypoint() {
        let h = dfg_hugr();
        let entry = h.entrypoint();
        assert_eq!(check_tag::<DfgID, _>(&h, entry), Ok(()));
        assert_eq!(
            check_tag::<CfgID, _>(&h, entry),
            Err(HugrError::InvalidTag {
                required: OpTag::Cfg,
                actual: OpTag::Dfg
            })
        );
    }

    #[test]
    fn root_checkable() {
        let h = dfg_hugr();
        let checked: RootChecked<Hugr, DfgID> = h.clone().try_into_checked().unwrap();
        // An already-checked wrapper is passed through unchanged.
        let rechecked: RootChecked<Hugr, DfgID> = checked.try_into_checked().unwrap();
        assert_eq!(rechecked.hugr().entrypoint(), h.entrypoint());
        // An unchecked HUGR with the wrong entrypoint tag is rejected.
        let wrong: Result<RootChecked<Hugr, CfgID>, _> = h.try_into_checked();
        assert!(wrong.is_err());
    }
}