hugr_core/builder/
module.rs

1use super::{
2    BuildError, Container,
3    build_traits::HugrBuilder,
4    dataflow::{DFGBuilder, FunctionBuilder},
5};
6
7use crate::hugr::ValidationError;
8use crate::hugr::internal::HugrMutInternals;
9use crate::hugr::views::HugrView;
10use crate::ops;
11use crate::types::{PolyFuncType, Type, TypeBound};
12
13use crate::ops::handle::{AliasID, FuncID, NodeHandle};
14
15use crate::{Hugr, Node};
16use smol_str::SmolStr;
17
18/// Builder for a HUGR module.
19#[derive(Debug, Clone, PartialEq)]
20pub struct ModuleBuilder<T>(pub(super) T);
21
22impl<T: AsMut<Hugr> + AsRef<Hugr>> Container for ModuleBuilder<T> {
23    #[inline]
24    fn container_node(&self) -> Node {
25        self.0.as_ref().entrypoint()
26    }
27
28    #[inline]
29    fn hugr_mut(&mut self) -> &mut Hugr {
30        self.0.as_mut()
31    }
32
33    fn hugr(&self) -> &Hugr {
34        self.0.as_ref()
35    }
36}
37
38impl ModuleBuilder<Hugr> {
39    /// Begin building a new module.
40    #[must_use]
41    pub fn new() -> Self {
42        Self(Default::default())
43    }
44}
45
46impl Default for ModuleBuilder<Hugr> {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl HugrBuilder for ModuleBuilder<Hugr> {
53    fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>> {
54        self.0.validate()?;
55        Ok(self.0)
56    }
57}
58
59impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
60    /// Replace a [`ops::FuncDecl`] with [`ops::FuncDefn`] and return a builder for
61    /// the defining graph.
62    ///
63    /// # Errors
64    ///
65    /// This function will return an error if there is an error in adding the
66    /// [`crate::ops::OpType::FuncDefn`] node.
67    pub fn define_declaration(
68        &mut self,
69        f_id: &FuncID<false>,
70    ) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
71        let f_node = f_id.node();
72        let decl =
73            self.hugr()
74                .get_optype(f_node)
75                .as_func_decl()
76                .ok_or(BuildError::UnexpectedType {
77                    node: f_node,
78                    op_desc: "crate::ops::OpType::FuncDecl",
79                })?;
80        let name = decl.func_name().clone();
81        let sig = decl.signature().clone();
82        let body = sig.body().clone();
83        self.hugr_mut()
84            .replace_op(f_node, ops::FuncDefn::new(name, sig));
85
86        let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?;
87        Ok(FunctionBuilder::from_dfg_builder(db))
88    }
89
90    /// Declare a function with `signature` and return a handle to the declaration.
91    ///
92    /// # Errors
93    ///
94    /// This function will return an error if there is an error in adding the
95    /// [`crate::ops::OpType::FuncDecl`] node.
96    pub fn declare(
97        &mut self,
98        name: impl Into<String>,
99        signature: PolyFuncType,
100    ) -> Result<FuncID<false>, BuildError> {
101        let body = signature.body().clone();
102        // TODO add param names to metadata
103        let declare_n = self.add_child_node(ops::FuncDecl::new(name, signature));
104
105        // Add the extensions used by the function types.
106        self.use_extensions(
107            body.used_extensions().unwrap_or_else(|e| {
108                panic!("Build-time signatures should have valid extensions. {e}")
109            }),
110        );
111
112        Ok(declare_n.into())
113    }
114
115    /// Add a [`crate::ops::OpType::AliasDefn`] node and return a handle to the Alias.
116    ///
117    /// # Errors
118    ///
119    /// Error in adding [`crate::ops::OpType::AliasDefn`] child node.
120    pub fn add_alias_def(
121        &mut self,
122        name: impl Into<SmolStr>,
123        typ: Type,
124    ) -> Result<AliasID<true>, BuildError> {
125        // TODO: add AliasDefn in other containers
126        // This is currently tricky as they are not connected to anything so do
127        // not appear in topological traversals.
128        // Could be fixed by removing single-entry requirement and sorting from
129        // every 0-input node.
130        let name: SmolStr = name.into();
131        let bound = typ.least_upper_bound();
132        let node = self.add_child_node(ops::AliasDefn {
133            name: name.clone(),
134            definition: typ,
135        });
136
137        Ok(AliasID::new(node, name, bound))
138    }
139
140    /// Add a [`crate::ops::OpType::AliasDecl`] node and return a handle to the Alias.
141    /// # Errors
142    ///
143    /// Error in adding [`crate::ops::OpType::AliasDecl`] child node.
144    pub fn add_alias_declare(
145        &mut self,
146        name: impl Into<SmolStr>,
147        bound: TypeBound,
148    ) -> Result<AliasID<false>, BuildError> {
149        let name: SmolStr = name.into();
150        let node = self.add_child_node(ops::AliasDecl {
151            name: name.clone(),
152            bound,
153        });
154
155        Ok(AliasID::new(node, name, bound))
156    }
157}
158
159#[cfg(test)]
160mod test {
161    use cool_asserts::assert_matches;
162
163    use crate::extension::prelude::usize_t;
164    use crate::{
165        builder::{Dataflow, DataflowSubContainer, test::n_identity},
166        types::Signature,
167    };
168
169    use super::*;
170    #[test]
171    fn basic_recurse() -> Result<(), BuildError> {
172        let build_result = {
173            let mut module_builder = ModuleBuilder::new();
174
175            let f_id = module_builder.declare(
176                "main",
177                Signature::new(vec![usize_t()], vec![usize_t()]).into(),
178            )?;
179
180            let mut f_build = module_builder.define_declaration(&f_id)?;
181            let call = f_build.call(&f_id, &[], f_build.input_wires())?;
182
183            f_build.finish_with_outputs(call.outputs())?;
184            module_builder.finish_hugr()
185        };
186        assert_matches!(build_result, Ok(_));
187        Ok(())
188    }
189
190    #[test]
191    fn simple_alias() -> Result<(), BuildError> {
192        let build_result = {
193            let mut module_builder = ModuleBuilder::new();
194
195            let qubit_state_type =
196                module_builder.add_alias_declare("qubit_state", TypeBound::Any)?;
197
198            let f_build = module_builder.define_function(
199                "main",
200                Signature::new(
201                    vec![qubit_state_type.get_alias_type()],
202                    vec![qubit_state_type.get_alias_type()],
203                ),
204            )?;
205            n_identity(f_build)?;
206            module_builder.finish_hugr()
207        };
208        assert_matches!(build_result, Ok(_));
209        Ok(())
210    }
211
212    #[test]
213    fn local_def() -> Result<(), BuildError> {
214        let build_result = {
215            let mut module_builder = ModuleBuilder::new();
216
217            let mut f_build = module_builder.define_function(
218                "main",
219                Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]),
220            )?;
221            let local_build = f_build.define_function(
222                "local",
223                Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]),
224            )?;
225            let [wire] = local_build.input_wires_arr();
226            let f_id = local_build.finish_with_outputs([wire, wire])?;
227
228            let call = f_build.call(f_id.handle(), &[], f_build.input_wires())?;
229
230            f_build.finish_with_outputs(call.outputs())?;
231            module_builder.finish_hugr()
232        };
233        assert_matches!(build_result, Ok(_));
234        Ok(())
235    }
236}