hugr_core/hugr/views/
root_checked.rs1use std::marker::PhantomData;
2
3use crate::hugr::HugrError;
4use crate::ops::handle::NodeHandle;
5use crate::ops::{OpTag, OpTrait};
6use crate::{Hugr, Node};
7
8use super::HugrView;
9
10mod dfg;
11pub use dfg::InvalidSignature;
12
/// A wrapper around a HUGR `H` whose entrypoint operation has been checked
/// against the [`OpTag`] required by the `Handle` type parameter.
///
/// [`RootChecked::try_new`] performs that tag check before wrapping the HUGR.
/// `Handle` only appears as a type-level marker (stored in a zero-sized
/// [`PhantomData`]); it defaults to [`Node`].
#[derive(Clone)]
pub struct RootChecked<H, Handle = Node>(H, PhantomData<Handle>);
17
18impl<H: HugrView, Handle: NodeHandle<H::Node>> RootChecked<H, Handle> {
19 const TAG: OpTag = Handle::TAG;
21
22 pub fn tag(&self) -> OpTag {
24 let tag = self.0.get_optype(self.0.entrypoint()).tag();
25 debug_assert!(Self::TAG.is_superset(tag));
26 tag
27 }
28
29 pub fn try_new(hugr: H) -> Result<Self, HugrError> {
36 Self::check(&hugr)?;
37 Ok(Self(hugr, PhantomData))
38 }
39
40 pub fn check(hugr: &H) -> Result<(), HugrError> {
44 check_tag::<Handle, _>(hugr, hugr.entrypoint())?;
45 Ok(())
46 }
47
48 pub fn hugr(&self) -> &H {
50 &self.0
51 }
52
53 pub fn into_hugr(self) -> H {
55 self.0
56 }
57
58 pub fn as_ref(&self) -> RootChecked<&H, Handle> {
60 RootChecked(&self.0, PhantomData)
61 }
62}
63
64impl<H: AsRef<Hugr>, Handle> AsRef<Hugr> for RootChecked<H, Handle> {
65 fn as_ref(&self) -> &Hugr {
66 self.0.as_ref()
67 }
68}
69
70pub fn check_tag<Required: NodeHandle<N>, N>(
72 hugr: &impl HugrView<Node = N>,
73 node: N,
74) -> Result<(), HugrError> {
75 let actual = hugr.get_optype(node).tag();
76 let required = Required::TAG;
77 if !required.is_superset(actual) {
78 return Err(HugrError::InvalidTag { required, actual });
79 }
80 Ok(())
81}
82
#[cfg(test)]
mod test {
    use super::RootChecked;
    use crate::hugr::HugrError;
    use crate::ops::handle::{CfgID, DfgID};
    use crate::ops::{OpTag, OpType};
    use crate::{Hugr, ops, types::Signature};

    /// A DFG-rooted HUGR is rejected when a CFG entrypoint is required, and
    /// accepted (reporting a DFG tag) when a DFG entrypoint is required.
    #[test]
    fn root_checked() {
        let dfg_op: OpType = ops::DFG {
            signature: Signature::new(vec![], vec![]),
        }
        .into();
        let mut hugr = Hugr::new_with_entrypoint(dfg_op).unwrap();

        // Wrong handle: the error must report both the required and actual tags.
        let expected_err = HugrError::InvalidTag {
            required: OpTag::Cfg,
            actual: OpTag::Dfg,
        };
        assert_eq!(RootChecked::<_, CfgID>::check(&hugr), Err(expected_err));

        // Matching handle: construction succeeds and the tag round-trips
        // through the borrowing view.
        let checked = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();
        let tag = checked.tag();
        assert!(OpTag::Dfg.is_superset(tag));
        assert_eq!(checked.as_ref().tag(), tag);
    }
}