// hugr_core/hugr/views/root_checked.rs
use std::marker::PhantomData;

use delegate::delegate;
use portgraph::MultiPortGraph;

use crate::hugr::internal::{HugrInternals, HugrMutInternals};
use crate::hugr::{HugrError, HugrMut};
use crate::ops::handle::NodeHandle;
use crate::{Hugr, Node};

use super::{check_tag, RootTagged};

13#[derive(Clone)]
16pub struct RootChecked<H, Root = Node>(H, PhantomData<Root>);
17
18impl<H: RootTagged + AsRef<Hugr>, Root: NodeHandle> RootChecked<H, Root> {
19 pub fn try_new(hugr: H) -> Result<Self, HugrError> {
26 if !H::RootHandle::TAG.is_superset(Root::TAG) {
27 return Err(HugrError::InvalidTag {
28 required: H::RootHandle::TAG,
29 actual: Root::TAG,
30 });
31 }
32 check_tag::<Root, _>(&hugr, hugr.root())?;
33 Ok(Self(hugr, PhantomData))
34 }
35}
36
37impl<Root> RootChecked<Hugr, Root> {
38 pub fn into_hugr(self) -> Hugr {
40 self.0
41 }
42}
43
44impl<Root> RootChecked<&mut Hugr, Root> {
45 pub fn borrow(&self) -> RootChecked<&Hugr, Root> {
47 RootChecked(&*self.0, PhantomData)
48 }
49}
50
51impl<H: AsRef<Hugr>, Root> HugrInternals for RootChecked<H, Root> {
52 type Portgraph<'p>
53 = &'p MultiPortGraph
54 where
55 Self: 'p;
56 type Node = Node;
57
58 delegate! {
59 to self.as_ref() {
60 fn portgraph(&self) -> Self::Portgraph<'_>;
61 fn base_hugr(&self) -> &Hugr;
62 fn root_node(&self) -> Node;
63 fn get_pg_index(&self, node: Node) -> portgraph::NodeIndex;
64 fn get_node(&self, index: portgraph::NodeIndex) -> Node;
65 }
66 }
67}
68
69impl<H: AsRef<Hugr>, Root: NodeHandle> RootTagged for RootChecked<H, Root> {
70 type RootHandle = Root;
71}
72
73impl<H: AsRef<Hugr>, Root> AsRef<Hugr> for RootChecked<H, Root> {
74 fn as_ref(&self) -> &Hugr {
75 self.0.as_ref()
76 }
77}
78
79impl<H: HugrMutInternals + AsRef<Hugr>, Root> HugrMutInternals for RootChecked<H, Root>
80where
81 Root: NodeHandle,
82{
83 #[inline(always)]
84 fn hugr_mut(&mut self) -> &mut Hugr {
85 self.0.hugr_mut()
86 }
87}
88
89impl<H: HugrMutInternals + AsRef<Hugr>, Root: NodeHandle> HugrMut for RootChecked<H, Root> {}
90
#[cfg(test)]
mod test {
    use super::RootChecked;
    use crate::extension::prelude::MakeTuple;
    use crate::extension::ExtensionSet;
    use crate::hugr::internal::HugrMutInternals;
    use crate::hugr::{HugrError, HugrMut};
    use crate::ops::handle::{BasicBlockID, CfgID, DataflowParentID, DfgID};
    use crate::ops::{DataflowBlock, OpTag, OpType};
    use crate::{ops, type_row, types::Signature, Hugr, HugrView};

    #[test]
    fn root_checked() {
        // A HUGR whose root is an empty DFG.
        let dfg_op: OpType = ops::DFG {
            signature: Signature::new(vec![], vec![]),
        }
        .into();
        let mut hugr = Hugr::new(dfg_op.clone());

        // The DFG root cannot be viewed through a CFG handle.
        let as_cfg = RootChecked::<&Hugr, CfgID>::try_new(&hugr);
        assert_eq!(
            as_cfg.err(),
            Some(HugrError::InvalidTag {
                required: OpTag::Cfg,
                actual: OpTag::Dfg,
            })
        );

        // Through a DFG handle the view is accepted, but swapping the root for
        // a dataflow block is refused and leaves the root op unchanged.
        let mut dfg_view = RootChecked::<&mut Hugr, DfgID>::try_new(&mut hugr).unwrap();
        let root = dfg_view.root();
        let block_op: OpType = DataflowBlock {
            inputs: type_row![],
            other_outputs: type_row![],
            sum_rows: vec![type_row![]],
            extension_delta: ExtensionSet::new(),
        }
        .into();
        assert_eq!(
            dfg_view.replace_op(root, block_op.clone()),
            Err(HugrError::InvalidTag {
                required: OpTag::Dfg,
                actual: OpTag::DataflowBlock,
            })
        );
        assert_eq!(dfg_view.get_optype(root), &dfg_op);

        // Re-wrapping the DFG-tagged view with a DataflowParent handle is
        // rejected, because Dfg is not a superset of DataflowParent.
        assert_eq!(
            RootChecked::<_, DataflowParentID>::try_new(dfg_view).err(),
            Some(HugrError::InvalidTag {
                required: OpTag::Dfg,
                actual: OpTag::DataflowParent,
            })
        );

        // Viewed only as a DataflowParent, the root may become a dataflow
        // block; the previous op is handed back.
        let mut parent_view =
            RootChecked::<&mut Hugr, DataflowParentID>::try_new(&mut hugr).unwrap();
        assert_eq!(parent_view.replace_op(root, block_op.clone()), Ok(dfg_op));
        assert_eq!(parent_view.get_optype(root), &block_op);

        // Narrowing that view to a basic-block handle succeeds, and children can
        // be attached beneath its root.
        let mut block_view = RootChecked::<_, BasicBlockID>::try_new(parent_view).unwrap();
        let tuple_op = MakeTuple(type_row![]);
        let block_root = block_view.root();
        block_view.add_node_with_parent(block_root, tuple_op);
    }
}