1use super::{
2 BuildError, Container,
3 build_traits::HugrBuilder,
4 dataflow::{DFGBuilder, FunctionBuilder},
5};
6
7use crate::hugr::linking::{HugrLinking, NodeLinkingDirectives, NodeLinkingError};
8use crate::hugr::{
9 ValidationError, hugrmut::InsertedForest, internal::HugrMutInternals, views::HugrView,
10};
11use crate::ops;
12use crate::ops::handle::{AliasID, FuncID, NodeHandle};
13use crate::types::{PolyFuncType, Type, TypeBound};
14use crate::{Hugr, Node, Visibility, ops::FuncDefn};
15
16use smol_str::SmolStr;
17
/// Builder for a HUGR module.
///
/// Wraps a value of type `T`. The builder impls below require
/// `T: AsRef<Hugr> + AsMut<Hugr>`, so `T` is either an owned [`Hugr`] or
/// something that can borrow one, such as `&mut Hugr`. The [`Container`] impl
/// uses the wrapped HUGR's module root as the container node.
#[derive(Debug, Default, Clone, PartialEq)]
pub struct ModuleBuilder<T>(pub(super) T);
21
22impl<T: AsMut<Hugr> + AsRef<Hugr>> Container for ModuleBuilder<T> {
23 #[inline]
24 fn container_node(&self) -> Node {
25 self.0.as_ref().module_root()
26 }
27
28 #[inline]
29 fn hugr_mut(&mut self) -> &mut Hugr {
30 self.0.as_mut()
31 }
32
33 fn hugr(&self) -> &Hugr {
34 self.0.as_ref()
35 }
36}
37
38impl ModuleBuilder<Hugr> {
39 #[must_use]
41 pub fn new() -> Self {
42 Self::default()
43 }
44}
45
46impl HugrBuilder for ModuleBuilder<Hugr> {
47 fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>> {
48 self.0.validate()?;
49 Ok(self.0)
50 }
51}
52
53impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
54 #[must_use]
56 pub fn with_hugr(hugr: T) -> Self {
57 ModuleBuilder(hugr)
58 }
59
60 pub fn define_declaration(
68 &mut self,
69 f_id: &FuncID<false>,
70 ) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
71 let f_node = f_id.node();
72 let opty = self.hugr_mut().optype_mut(f_node);
73 let ops::OpType::FuncDecl(decl) = opty else {
74 return Err(BuildError::UnexpectedType {
75 node: f_node,
76 op_desc: "crate::ops::OpType::FuncDecl",
77 });
78 };
79
80 let body = decl.signature().body().clone();
81 *opty = ops::FuncDefn::new_vis(
82 decl.func_name(),
83 decl.signature().clone(),
84 decl.visibility().clone(),
85 )
86 .into();
87
88 let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?;
89 Ok(FunctionBuilder::from_dfg_builder(db))
90 }
91
92 pub fn define_function_vis(
100 &mut self,
101 name: impl Into<String>,
102 signature: impl Into<PolyFuncType>,
103 visibility: Visibility,
104 ) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
105 self.define_function_op(FuncDefn::new_vis(name, signature, visibility))
106 }
107
108 fn define_function_op(
109 &mut self,
110 op: FuncDefn,
111 ) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
112 let body = op.signature().body().clone();
113 let f_node = self.add_child_node(op);
114
115 self.use_extensions(
117 body.used_extensions().unwrap_or_else(|e| {
118 panic!("Build-time signatures should have valid extensions. {e}")
119 }),
120 );
121
122 let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?;
123 Ok(FunctionBuilder::from_dfg_builder(db))
124 }
125
126 pub fn declare(
133 &mut self,
134 name: impl Into<String>,
135 signature: PolyFuncType,
136 ) -> Result<FuncID<false>, BuildError> {
137 self.declare_vis(name, signature, Visibility::Public)
138 }
139
140 pub fn declare_vis(
148 &mut self,
149 name: impl Into<String>,
150 signature: PolyFuncType,
151 visibility: Visibility,
152 ) -> Result<FuncID<false>, BuildError> {
153 let body = signature.body().clone();
154 let declare_n = self.add_child_node(ops::FuncDecl::new_vis(name, signature, visibility));
156
157 self.use_extensions(
159 body.used_extensions().unwrap_or_else(|e| {
160 panic!("Build-time signatures should have valid extensions. {e}")
161 }),
162 );
163
164 Ok(declare_n.into())
165 }
166
167 pub fn define_function(
175 &mut self,
176 name: impl Into<String>,
177 signature: impl Into<PolyFuncType>,
178 ) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
179 self.define_function_op(FuncDefn::new(name, signature))
180 }
181
182 pub fn add_alias_def(
188 &mut self,
189 name: impl Into<SmolStr>,
190 typ: Type,
191 ) -> Result<AliasID<true>, BuildError> {
192 let name: SmolStr = name.into();
198 let bound = typ.least_upper_bound();
199 let node = self.add_child_node(ops::AliasDefn {
200 name: name.clone(),
201 definition: typ,
202 });
203
204 Ok(AliasID::new(node, name, bound))
205 }
206
207 pub fn add_alias_declare(
212 &mut self,
213 name: impl Into<SmolStr>,
214 bound: TypeBound,
215 ) -> Result<AliasID<false>, BuildError> {
216 let name: SmolStr = name.into();
217 let node = self.add_child_node(ops::AliasDecl {
218 name: name.clone(),
219 bound,
220 });
221
222 Ok(AliasID::new(node, name, bound))
223 }
224
225 pub fn link_hugr_by_node(
232 &mut self,
233 other: Hugr,
234 children: NodeLinkingDirectives<Node, Node>,
235 ) -> Result<InsertedForest, NodeLinkingError> {
236 self.hugr_mut()
237 .insert_link_hugr_by_node(None, other, children)
238 }
239
240 pub fn link_view_by_node<H: HugrView>(
247 &mut self,
248 other: &H,
249 children: NodeLinkingDirectives<H::Node, Node>,
250 ) -> Result<InsertedForest<H::Node>, NodeLinkingError<H::Node>> {
251 self.hugr_mut()
252 .insert_link_view_by_node(None, other, children)
253 }
254}
255
#[cfg(test)]
mod test {
    use std::collections::{HashMap, HashSet};

    use cool_asserts::assert_matches;

    use crate::builder::test::dfg_calling_defn_decl;
    use crate::builder::{Dataflow, DataflowSubContainer, test::n_identity};
    use crate::extension::prelude::usize_t;
    use crate::{hugr::linking::NodeLinkingDirective, ops::OpType, types::Signature};

    use super::*;
    // Declare `main`, then define that same declaration with a body that calls
    // `main` itself (recursion). The finished module must validate.
    #[test]
    fn basic_recurse() -> Result<(), BuildError> {
        let build_result = {
            let mut module_builder = ModuleBuilder::new();

            let f_id = module_builder.declare(
                "main",
                Signature::new(vec![usize_t()], vec![usize_t()]).into(),
            )?;

            // The declaration's node is reused: it becomes the FuncDefn whose
            // body is built here.
            let mut f_build = module_builder.define_declaration(&f_id)?;
            // Call the function currently being defined.
            let call = f_build.call(&f_id, &[], f_build.input_wires())?;

            f_build.finish_with_outputs(call.outputs())?;
            module_builder.finish_hugr()
        };
        assert_matches!(build_result, Ok(_));
        Ok(())
    }

    // A declared alias (no definition, Linear bound) can appear in a function
    // signature. The function body comes from the `n_identity` test helper
    // (presumably passing inputs straight to outputs), and the module must
    // validate.
    #[test]
    fn simple_alias() -> Result<(), BuildError> {
        let build_result = {
            let mut module_builder = ModuleBuilder::new();

            let qubit_state_type =
                module_builder.add_alias_declare("qubit_state", TypeBound::Linear)?;

            let f_build = module_builder.define_function(
                "main",
                Signature::new(
                    vec![qubit_state_type.get_alias_type()],
                    vec![qubit_state_type.get_alias_type()],
                ),
            )?;
            n_identity(f_build)?;
            module_builder.finish_hugr()
        };
        assert_matches!(build_result, Ok(_));
        Ok(())
    }

    // A ModuleBuilder can wrap `&mut Hugr` for an already-finished HUGR (here
    // one built by `FunctionBuilder::with_hugr`) and add a declaration to it.
    // The HUGR must still validate afterwards.
    #[test]
    fn builder_from_existing() -> Result<(), BuildError> {
        let hugr = Hugr::new();

        let fn_builder = FunctionBuilder::with_hugr(hugr, "main", Signature::new_endo(vec![]))?;
        let mut hugr = fn_builder.finish_hugr()?;

        let mut module_builder = ModuleBuilder::with_hugr(&mut hugr);
        module_builder.declare("other", Signature::new_endo(vec![]).into())?;

        hugr.validate()?;

        Ok(())
    }

    // Exercises both linking entry points in sequence. The order of the steps
    // matters: the second link refers to nodes produced by the first.
    #[test]
    fn link_by_node() {
        let mut mb = ModuleBuilder::new();
        let (dfg, defn, decl) = dfg_calling_defn_decl();
        // Step 1: copy `defn` and `decl` out of `dfg` (taken as a view) into
        // the empty module.
        let added = mb
            .link_view_by_node(
                &dfg,
                HashMap::from([
                    (defn.node(), NodeLinkingDirective::add()),
                    (decl.node(), NodeLinkingDirective::add()),
                ]),
            )
            .unwrap();
        // `node_map` maps nodes of `dfg` to the nodes created in the module.
        let n_defn = added.node_map[&defn.node()];
        let n_decl = added.node_map[&decl.node()];
        let h = mb.hugr();
        assert_eq!(h.children(h.module_root()).count(), 2);
        h.validate().unwrap();
        // Rename the copied definition to "new", keeping its original name.
        let old_name = match mb.hugr_mut().optype_mut(n_defn) {
            OpType::FuncDefn(fd) => std::mem::replace(fd.func_name_mut(), "new".to_string()),
            _ => panic!(),
        };
        // `main` is the FuncDefn that contains `dfg`'s entrypoint.
        let main = dfg.get_parent(dfg.entrypoint()).unwrap();
        assert_eq!(
            dfg.get_optype(main).as_func_defn().unwrap().func_name(),
            "main"
        );
        // Step 2: link `dfg` by value.
        // - `main` is added.
        // - `decl` is resolved to the existing (renamed) `n_defn`.
        // - `defn` replaces the earlier copy `n_decl`.
        // These meanings follow the directive names.
        mb.link_hugr_by_node(
            dfg,
            HashMap::from([
                (main, NodeLinkingDirective::add()),
                (decl.node(), NodeLinkingDirective::UseExisting(n_defn)),
                (defn.node(), NodeLinkingDirective::replace([n_decl])),
            ]),
        )
        .unwrap();
        // Every module child must now be a FuncDefn (the `unwrap` checks this):
        // `main`, the renamed copy "new", and a definition named `old_name`.
        let h = mb.finish_hugr().unwrap();
        assert_eq!(
            h.children(h.module_root())
                .map(|n| h.get_optype(n).as_func_defn().unwrap().func_name().as_str())
                .collect::<HashSet<_>>(),
            HashSet::from(["main", "new", old_name.as_str()])
        );
    }
}