hugr_core/builder/
module.rs1use super::{
2 build_traits::HugrBuilder,
3 dataflow::{DFGBuilder, FunctionBuilder},
4 BuildError, Container,
5};
6
7use crate::hugr::internal::HugrMutInternals;
8use crate::hugr::views::HugrView;
9use crate::hugr::ValidationError;
10use crate::ops;
11use crate::types::{PolyFuncType, Type, TypeBound};
12
13use crate::ops::handle::{AliasID, FuncID, NodeHandle};
14
15use crate::{Hugr, Node};
16use smol_str::SmolStr;
17
18#[derive(Debug, Clone, PartialEq)]
20pub struct ModuleBuilder<T>(pub(super) T);
21
22impl<T: AsMut<Hugr> + AsRef<Hugr>> Container for ModuleBuilder<T> {
23 #[inline]
24 fn container_node(&self) -> Node {
25 self.0.as_ref().root()
26 }
27
28 #[inline]
29 fn hugr_mut(&mut self) -> &mut Hugr {
30 self.0.as_mut()
31 }
32
33 fn hugr(&self) -> &Hugr {
34 self.0.as_ref()
35 }
36}
37
38impl ModuleBuilder<Hugr> {
39 #[must_use]
41 pub fn new() -> Self {
42 Self(Default::default())
43 }
44}
45
46impl Default for ModuleBuilder<Hugr> {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl HugrBuilder for ModuleBuilder<Hugr> {
53 fn finish_hugr(mut self) -> Result<Hugr, ValidationError> {
54 if cfg!(feature = "extension_inference") {
55 self.0.infer_extensions(false)?;
56 }
57 self.0.validate()?;
58 Ok(self.0)
59 }
60}
61
62impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
63 pub fn define_declaration(
71 &mut self,
72 f_id: &FuncID<false>,
73 ) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
74 let f_node = f_id.node();
75 let ops::FuncDecl { signature, name } = self
76 .hugr()
77 .get_optype(f_node)
78 .as_func_decl()
79 .ok_or(BuildError::UnexpectedType {
80 node: f_node,
81 op_desc: "crate::ops::OpType::FuncDecl",
82 })?
83 .clone();
84 let body = signature.body().clone();
85 self.hugr_mut()
86 .replace_op(f_node, ops::FuncDefn { name, signature })
87 .expect("Replacing a FuncDecl node with a FuncDefn should always be valid");
88
89 let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?;
90 Ok(FunctionBuilder::from_dfg_builder(db))
91 }
92
93 pub fn declare(
100 &mut self,
101 name: impl Into<String>,
102 signature: PolyFuncType,
103 ) -> Result<FuncID<false>, BuildError> {
104 let body = signature.body().clone();
105 let declare_n = self.add_child_node(ops::FuncDecl {
107 signature,
108 name: name.into(),
109 });
110
111 self.use_extensions(
113 body.used_extensions().unwrap_or_else(|e| {
114 panic!("Build-time signatures should have valid extensions. {e}")
115 }),
116 );
117
118 Ok(declare_n.into())
119 }
120
121 pub fn add_alias_def(
127 &mut self,
128 name: impl Into<SmolStr>,
129 typ: Type,
130 ) -> Result<AliasID<true>, BuildError> {
131 let name: SmolStr = name.into();
137 let bound = typ.least_upper_bound();
138 let node = self.add_child_node(ops::AliasDefn {
139 name: name.clone(),
140 definition: typ,
141 });
142
143 Ok(AliasID::new(node, name, bound))
144 }
145
146 pub fn add_alias_declare(
151 &mut self,
152 name: impl Into<SmolStr>,
153 bound: TypeBound,
154 ) -> Result<AliasID<false>, BuildError> {
155 let name: SmolStr = name.into();
156 let node = self.add_child_node(ops::AliasDecl {
157 name: name.clone(),
158 bound,
159 });
160
161 Ok(AliasID::new(node, name, bound))
162 }
163}
164
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use crate::extension::prelude::usize_t;
    use crate::{
        builder::{test::n_identity, Dataflow, DataflowSubContainer},
        types::Signature,
    };

    use super::*;

    /// Declares `main`, then defines it with a body that calls itself, and
    /// checks that the finished module validates.
    #[test]
    fn basic_recurse() -> Result<(), BuildError> {
        let mut module = ModuleBuilder::new();

        let main_sig = Signature::new(vec![usize_t()], vec![usize_t()]);
        let main_id = module.declare("main", main_sig.into())?;

        let mut main_body = module.define_declaration(&main_id)?;
        let recursive_call = main_body.call(&main_id, &[], main_body.input_wires())?;
        main_body.finish_with_outputs(recursive_call.outputs())?;

        let build_result = module.finish_hugr();
        assert_matches!(build_result, Ok(_));
        Ok(())
    }

    /// Declares a type alias and uses it as both the input and the output
    /// type of an identity function.
    #[test]
    fn simple_alias() -> Result<(), BuildError> {
        let mut module = ModuleBuilder::new();

        let qubit_state = module.add_alias_declare("qubit_state", TypeBound::Any)?;

        let main_sig = Signature::new(
            vec![qubit_state.get_alias_type()],
            vec![qubit_state.get_alias_type()],
        );
        let main_body = module.define_function("main", main_sig)?;
        n_identity(main_body)?;

        let build_result = module.finish_hugr();
        assert_matches!(build_result, Ok(_));
        Ok(())
    }

    /// Defines a function nested inside `main` and calls it from `main`.
    #[test]
    fn local_def() -> Result<(), BuildError> {
        let mut module = ModuleBuilder::new();

        let mut main_body = module.define_function(
            "main",
            Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]),
        )?;

        // The nested function duplicates its single input wire.
        let local_body = main_body.define_function(
            "local",
            Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]),
        )?;
        let [input] = local_body.input_wires_arr();
        let local_fn = local_body.finish_with_outputs([input, input])?;

        let local_call = main_body.call(local_fn.handle(), &[], main_body.input_wires())?;
        main_body.finish_with_outputs(local_call.outputs())?;

        let build_result = module.finish_hugr();
        assert_matches!(build_result, Ok(_));
        Ok(())
    }
}