// hugr_core/builder/module.rs

1use super::{
2    BuildError, Container,
3    build_traits::HugrBuilder,
4    dataflow::{DFGBuilder, FunctionBuilder},
5};
6
7use crate::hugr::linking::{HugrLinking, NodeLinkingDirectives, NodeLinkingError};
8use crate::hugr::{
9    ValidationError, hugrmut::InsertedForest, internal::HugrMutInternals, views::HugrView,
10};
11use crate::ops;
12use crate::ops::handle::{AliasID, FuncID, NodeHandle};
13use crate::types::{PolyFuncType, Type, TypeBound};
14use crate::{Hugr, Node, Visibility, ops::FuncDefn};
15
16use smol_str::SmolStr;
17
18/// Builder for a HUGR module.
19#[derive(Debug, Default, Clone, PartialEq)]
20pub struct ModuleBuilder<T>(pub(super) T);
21
22impl<T: AsMut<Hugr> + AsRef<Hugr>> Container for ModuleBuilder<T> {
23    #[inline]
24    fn container_node(&self) -> Node {
25        self.0.as_ref().module_root()
26    }
27
28    #[inline]
29    fn hugr_mut(&mut self) -> &mut Hugr {
30        self.0.as_mut()
31    }
32
33    fn hugr(&self) -> &Hugr {
34        self.0.as_ref()
35    }
36}
37
38impl ModuleBuilder<Hugr> {
39    /// Begin building a new module.
40    #[must_use]
41    pub fn new() -> Self {
42        Self::default()
43    }
44}
45
46impl HugrBuilder for ModuleBuilder<Hugr> {
47    fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>> {
48        self.0.validate()?;
49        Ok(self.0)
50    }
51}
52
53impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
54    /// Continue building a module from an existing hugr.
55    #[must_use]
56    pub fn with_hugr(hugr: T) -> Self {
57        ModuleBuilder(hugr)
58    }
59
60    /// Replace a [`ops::FuncDecl`] with [`ops::FuncDefn`] and return a builder for
61    /// the defining graph.
62    ///
63    /// # Errors
64    ///
65    /// This function will return an error if there is an error in adding the
66    /// [`crate::ops::OpType::FuncDefn`] node.
67    pub fn define_declaration(
68        &mut self,
69        f_id: &FuncID<false>,
70    ) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
71        let f_node = f_id.node();
72        let opty = self.hugr_mut().optype_mut(f_node);
73        let ops::OpType::FuncDecl(decl) = opty else {
74            return Err(BuildError::UnexpectedType {
75                node: f_node,
76                op_desc: "crate::ops::OpType::FuncDecl",
77            });
78        };
79
80        let body = decl.signature().body().clone();
81        *opty = ops::FuncDefn::new_vis(
82            decl.func_name(),
83            decl.signature().clone(),
84            decl.visibility().clone(),
85        )
86        .into();
87
88        let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?;
89        Ok(FunctionBuilder::from_dfg_builder(db))
90    }
91
92    /// Add a [`ops::FuncDefn`] node of the specified visibility.
93    /// Returns a builder to define the function body graph.
94    ///
95    /// # Errors
96    ///
97    /// This function will return an error if there is an error in adding the
98    /// [`ops::FuncDefn`] node.
99    pub fn define_function_vis(
100        &mut self,
101        name: impl Into<String>,
102        signature: impl Into<PolyFuncType>,
103        visibility: Visibility,
104    ) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
105        self.define_function_op(FuncDefn::new_vis(name, signature, visibility))
106    }
107
108    fn define_function_op(
109        &mut self,
110        op: FuncDefn,
111    ) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
112        let body = op.signature().body().clone();
113        let f_node = self.add_child_node(op);
114
115        // Add the extensions used by the function types.
116        self.use_extensions(
117            body.used_extensions().unwrap_or_else(|e| {
118                panic!("Build-time signatures should have valid extensions. {e}")
119            }),
120        );
121
122        let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?;
123        Ok(FunctionBuilder::from_dfg_builder(db))
124    }
125
126    /// Declare a [Visibility::Public] function with `signature` and return a handle to the declaration.
127    ///
128    /// # Errors
129    ///
130    /// This function will return an error if there is an error in adding the
131    /// [`crate::ops::OpType::FuncDecl`] node.
132    pub fn declare(
133        &mut self,
134        name: impl Into<String>,
135        signature: PolyFuncType,
136    ) -> Result<FuncID<false>, BuildError> {
137        self.declare_vis(name, signature, Visibility::Public)
138    }
139
140    /// Declare a function with the specified `signature` and [Visibility],
141    /// and return a handle to the declaration.
142    ///
143    /// # Errors
144    ///
145    /// This function will return an error if there is an error in adding the
146    /// [`crate::ops::OpType::FuncDecl`] node.
147    pub fn declare_vis(
148        &mut self,
149        name: impl Into<String>,
150        signature: PolyFuncType,
151        visibility: Visibility,
152    ) -> Result<FuncID<false>, BuildError> {
153        let body = signature.body().clone();
154        // TODO add param names to metadata
155        let declare_n = self.add_child_node(ops::FuncDecl::new_vis(name, signature, visibility));
156
157        // Add the extensions used by the function types.
158        self.use_extensions(
159            body.used_extensions().unwrap_or_else(|e| {
160                panic!("Build-time signatures should have valid extensions. {e}")
161            }),
162        );
163
164        Ok(declare_n.into())
165    }
166
167    /// Adds a [`ops::FuncDefn`] node and returns a builder to define the function
168    /// body graph. The function will be private. (See [Self::define_function_vis].)
169    ///
170    /// # Errors
171    ///
172    /// This function will return an error if there is an error in adding the
173    /// [`ops::FuncDefn`] node.
174    pub fn define_function(
175        &mut self,
176        name: impl Into<String>,
177        signature: impl Into<PolyFuncType>,
178    ) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
179        self.define_function_op(FuncDefn::new(name, signature))
180    }
181
182    /// Add a [`crate::ops::OpType::AliasDefn`] node and return a handle to the Alias.
183    ///
184    /// # Errors
185    ///
186    /// Error in adding [`crate::ops::OpType::AliasDefn`] child node.
187    pub fn add_alias_def(
188        &mut self,
189        name: impl Into<SmolStr>,
190        typ: Type,
191    ) -> Result<AliasID<true>, BuildError> {
192        // TODO: add AliasDefn in other containers
193        // This is currently tricky as they are not connected to anything so do
194        // not appear in topological traversals.
195        // Could be fixed by removing single-entry requirement and sorting from
196        // every 0-input node.
197        let name: SmolStr = name.into();
198        let bound = typ.least_upper_bound();
199        let node = self.add_child_node(ops::AliasDefn {
200            name: name.clone(),
201            definition: typ,
202        });
203
204        Ok(AliasID::new(node, name, bound))
205    }
206
207    /// Add a [`crate::ops::OpType::AliasDecl`] node and return a handle to the Alias.
208    /// # Errors
209    ///
210    /// Error in adding [`crate::ops::OpType::AliasDecl`] child node.
211    pub fn add_alias_declare(
212        &mut self,
213        name: impl Into<SmolStr>,
214        bound: TypeBound,
215    ) -> Result<AliasID<false>, BuildError> {
216        let name: SmolStr = name.into();
217        let node = self.add_child_node(ops::AliasDecl {
218            name: name.clone(),
219            bound,
220        });
221
222        Ok(AliasID::new(node, name, bound))
223    }
224
225    /// Add some module-children of another Hugr to this module, with
226    /// linking directives specified explicitly by [Node].
227    ///
228    /// `children` contains a map from the children of `other` to insert,
229    /// to how they should be combined with the nodes in `self`. Note if
230    /// this map is empty, nothing is added.
231    pub fn link_hugr_by_node(
232        &mut self,
233        other: Hugr,
234        children: NodeLinkingDirectives<Node, Node>,
235    ) -> Result<InsertedForest, NodeLinkingError> {
236        self.hugr_mut()
237            .insert_link_hugr_by_node(None, other, children)
238    }
239
240    /// Copy module-children from a HugrView into this module, with
241    /// linking directives specified explicitly by [Node].
242    ///
243    /// `children` contains a map from the children of `other` to copy,
244    /// to how they should be combined with the nodes in `self`. Note if
245    /// this map is empty, nothing is added.
246    pub fn link_view_by_node<H: HugrView>(
247        &mut self,
248        other: &H,
249        children: NodeLinkingDirectives<H::Node, Node>,
250    ) -> Result<InsertedForest<H::Node>, NodeLinkingError<H::Node>> {
251        self.hugr_mut()
252            .insert_link_view_by_node(None, other, children)
253    }
254}
255
#[cfg(test)]
mod test {
    //! Unit tests for `ModuleBuilder`: declaring and then defining functions,
    //! aliases in signatures, wrapping an existing hugr, and node-directed linking.

    use std::collections::{HashMap, HashSet};

    use cool_asserts::assert_matches;

    use crate::builder::test::dfg_calling_defn_decl;
    use crate::builder::{Dataflow, DataflowSubContainer, test::n_identity};
    use crate::extension::prelude::usize_t;
    use crate::{hugr::linking::NodeLinkingDirective, ops::OpType, types::Signature};

    use super::*;
    // A function declared first can later be defined via `define_declaration`,
    // and its body may call the function through the declaration's handle
    // (i.e. recursion), yielding a valid hugr.
    #[test]
    fn basic_recurse() -> Result<(), BuildError> {
        let build_result = {
            let mut module_builder = ModuleBuilder::new();

            let f_id = module_builder.declare(
                "main",
                Signature::new(vec![usize_t()], vec![usize_t()]).into(),
            )?;

            let mut f_build = module_builder.define_declaration(&f_id)?;
            // The call targets the same node that was just turned into a FuncDefn.
            let call = f_build.call(&f_id, &[], f_build.input_wires())?;

            f_build.finish_with_outputs(call.outputs())?;
            module_builder.finish_hugr()
        };
        assert_matches!(build_result, Ok(_));
        Ok(())
    }

    // A declared (linear) alias type can be used in a function signature.
    #[test]
    fn simple_alias() -> Result<(), BuildError> {
        let build_result = {
            let mut module_builder = ModuleBuilder::new();

            let qubit_state_type =
                module_builder.add_alias_declare("qubit_state", TypeBound::Linear)?;

            let f_build = module_builder.define_function(
                "main",
                Signature::new(
                    vec![qubit_state_type.get_alias_type()],
                    vec![qubit_state_type.get_alias_type()],
                ),
            )?;
            // NOTE(review): `n_identity` is a shared test helper, presumably
            // wiring inputs straight to outputs and finishing the function.
            n_identity(f_build)?;
            module_builder.finish_hugr()
        };
        assert_matches!(build_result, Ok(_));
        Ok(())
    }

    // `ModuleBuilder::with_hugr` can wrap a `&mut Hugr` produced by another
    // builder; adding a declaration through it leaves the hugr valid.
    #[test]
    fn builder_from_existing() -> Result<(), BuildError> {
        let hugr = Hugr::new();

        let fn_builder = FunctionBuilder::with_hugr(hugr, "main", Signature::new_endo(vec![]))?;
        let mut hugr = fn_builder.finish_hugr()?;

        let mut module_builder = ModuleBuilder::with_hugr(&mut hugr);
        module_builder.declare("other", Signature::new_endo(vec![]).into())?;

        hugr.validate()?;

        Ok(())
    }

    // Exercises `link_view_by_node` followed by `link_hugr_by_node`, mixing the
    // `add`, `UseExisting` and `replace` directives.
    #[test]
    fn link_by_node() {
        let mut mb = ModuleBuilder::new();
        // NOTE(review): from its use below, `dfg` is a hugr whose entrypoint sits
        // inside a FuncDefn named "main", plus module-level `defn` and `decl`.
        let (dfg, defn, decl) = dfg_calling_defn_decl();
        // Step 1: copy just the defn and decl into the (empty) module.
        let added = mb
            .link_view_by_node(
                &dfg,
                HashMap::from([
                    (defn.node(), NodeLinkingDirective::add()),
                    (decl.node(), NodeLinkingDirective::add()),
                ]),
            )
            .unwrap();
        let n_defn = added.node_map[&defn.node()];
        let n_decl = added.node_map[&decl.node()];
        let h = mb.hugr();
        assert_eq!(h.children(h.module_root()).count(), 2);
        h.validate().unwrap();
        // Rename the copied defn to "new" so it can be told apart from the
        // original defn, which is inserted again in step 2.
        let old_name = match mb.hugr_mut().optype_mut(n_defn) {
            OpType::FuncDefn(fd) => std::mem::replace(fd.func_name_mut(), "new".to_string()),
            _ => panic!(),
        };
        let main = dfg.get_parent(dfg.entrypoint()).unwrap();
        assert_eq!(
            dfg.get_optype(main).as_func_defn().unwrap().func_name(),
            "main"
        );
        // Step 2: insert "main", resolve `decl` to the renamed existing defn,
        // and let the incoming `defn` replace the previously copied decl.
        mb.link_hugr_by_node(
            dfg,
            HashMap::from([
                (main, NodeLinkingDirective::add()),
                (decl.node(), NodeLinkingDirective::UseExisting(n_defn)),
                (defn.node(), NodeLinkingDirective::replace([n_decl])),
            ]),
        )
        .unwrap();
        let h = mb.finish_hugr().unwrap();
        // Every module child is now a FuncDefn: "main", the renamed copy, and
        // the freshly inserted defn carrying the original name.
        assert_eq!(
            h.children(h.module_root())
                .map(|n| h.get_optype(n).as_func_defn().unwrap().func_name().as_str())
                .collect::<HashSet<_>>(),
            HashSet::from(["main", "new", old_name.as_str()])
        );
    }
}
369}