hugr_core/builder/module.rs

1use super::{
2    build_traits::HugrBuilder,
3    dataflow::{DFGBuilder, FunctionBuilder},
4    BuildError, Container,
5};
6
7use crate::hugr::internal::HugrMutInternals;
8use crate::hugr::views::HugrView;
9use crate::hugr::ValidationError;
10use crate::ops;
11use crate::types::{PolyFuncType, Type, TypeBound};
12
13use crate::ops::handle::{AliasID, FuncID, NodeHandle};
14
15use crate::{Hugr, Node};
16use smol_str::SmolStr;
17
18/// Builder for a HUGR module.
19#[derive(Debug, Clone, PartialEq)]
20pub struct ModuleBuilder<T>(pub(super) T);
21
22impl<T: AsMut<Hugr> + AsRef<Hugr>> Container for ModuleBuilder<T> {
23    #[inline]
24    fn container_node(&self) -> Node {
25        self.0.as_ref().entrypoint()
26    }
27
28    #[inline]
29    fn hugr_mut(&mut self) -> &mut Hugr {
30        self.0.as_mut()
31    }
32
33    fn hugr(&self) -> &Hugr {
34        self.0.as_ref()
35    }
36}
37
38impl ModuleBuilder<Hugr> {
39    /// Begin building a new module.
40    #[must_use]
41    pub fn new() -> Self {
42        Self(Default::default())
43    }
44}
45
46impl Default for ModuleBuilder<Hugr> {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl HugrBuilder for ModuleBuilder<Hugr> {
53    fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>> {
54        self.0.validate()?;
55        Ok(self.0)
56    }
57}
58
59impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
60    /// Replace a [`ops::FuncDecl`] with [`ops::FuncDefn`] and return a builder for
61    /// the defining graph.
62    ///
63    /// # Errors
64    ///
65    /// This function will return an error if there is an error in adding the
66    /// [`crate::ops::OpType::FuncDefn`] node.
67    pub fn define_declaration(
68        &mut self,
69        f_id: &FuncID<false>,
70    ) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
71        let f_node = f_id.node();
72        let ops::FuncDecl { signature, name } = self
73            .hugr()
74            .get_optype(f_node)
75            .as_func_decl()
76            .ok_or(BuildError::UnexpectedType {
77                node: f_node,
78                op_desc: "crate::ops::OpType::FuncDecl",
79            })?
80            .clone();
81        let body = signature.body().clone();
82        self.hugr_mut()
83            .replace_op(f_node, ops::FuncDefn { name, signature });
84
85        let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?;
86        Ok(FunctionBuilder::from_dfg_builder(db))
87    }
88
89    /// Declare a function with `signature` and return a handle to the declaration.
90    ///
91    /// # Errors
92    ///
93    /// This function will return an error if there is an error in adding the
94    /// [`crate::ops::OpType::FuncDecl`] node.
95    pub fn declare(
96        &mut self,
97        name: impl Into<String>,
98        signature: PolyFuncType,
99    ) -> Result<FuncID<false>, BuildError> {
100        let body = signature.body().clone();
101        // TODO add param names to metadata
102        let declare_n = self.add_child_node(ops::FuncDecl {
103            signature,
104            name: name.into(),
105        });
106
107        // Add the extensions used by the function types.
108        self.use_extensions(
109            body.used_extensions().unwrap_or_else(|e| {
110                panic!("Build-time signatures should have valid extensions. {e}")
111            }),
112        );
113
114        Ok(declare_n.into())
115    }
116
117    /// Add a [`crate::ops::OpType::AliasDefn`] node and return a handle to the Alias.
118    ///
119    /// # Errors
120    ///
121    /// Error in adding [`crate::ops::OpType::AliasDefn`] child node.
122    pub fn add_alias_def(
123        &mut self,
124        name: impl Into<SmolStr>,
125        typ: Type,
126    ) -> Result<AliasID<true>, BuildError> {
127        // TODO: add AliasDefn in other containers
128        // This is currently tricky as they are not connected to anything so do
129        // not appear in topological traversals.
130        // Could be fixed by removing single-entry requirement and sorting from
131        // every 0-input node.
132        let name: SmolStr = name.into();
133        let bound = typ.least_upper_bound();
134        let node = self.add_child_node(ops::AliasDefn {
135            name: name.clone(),
136            definition: typ,
137        });
138
139        Ok(AliasID::new(node, name, bound))
140    }
141
142    /// Add a [`crate::ops::OpType::AliasDecl`] node and return a handle to the Alias.
143    /// # Errors
144    ///
145    /// Error in adding [`crate::ops::OpType::AliasDecl`] child node.
146    pub fn add_alias_declare(
147        &mut self,
148        name: impl Into<SmolStr>,
149        bound: TypeBound,
150    ) -> Result<AliasID<false>, BuildError> {
151        let name: SmolStr = name.into();
152        let node = self.add_child_node(ops::AliasDecl {
153            name: name.clone(),
154            bound,
155        });
156
157        Ok(AliasID::new(node, name, bound))
158    }
159}
160
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use crate::extension::prelude::usize_t;
    use crate::{
        builder::{test::n_identity, Dataflow, DataflowSubContainer},
        types::Signature,
    };

    use super::*;

    /// A declared function can later be defined with a body that calls itself.
    #[test]
    fn basic_recurse() -> Result<(), BuildError> {
        let mut builder = ModuleBuilder::new();
        let sig = Signature::new(vec![usize_t()], vec![usize_t()]);

        let main_decl = builder.declare("main", sig.into())?;
        let mut main_def = builder.define_declaration(&main_decl)?;
        let recursive_call = main_def.call(&main_decl, &[], main_def.input_wires())?;
        main_def.finish_with_outputs(recursive_call.outputs())?;

        assert_matches!(builder.finish_hugr(), Ok(_));
        Ok(())
    }

    /// A declared type alias can be used in a function signature.
    #[test]
    fn simple_alias() -> Result<(), BuildError> {
        let mut builder = ModuleBuilder::new();

        let qubit_state = builder.add_alias_declare("qubit_state", TypeBound::Any)?;
        let alias_ty = qubit_state.get_alias_type();

        let main_def = builder.define_function(
            "main",
            Signature::new(vec![alias_ty.clone()], vec![alias_ty]),
        )?;
        n_identity(main_def)?;

        assert_matches!(builder.finish_hugr(), Ok(_));
        Ok(())
    }

    /// A function defined inside another function can be called from it.
    #[test]
    fn local_def() -> Result<(), BuildError> {
        let mut builder = ModuleBuilder::new();
        let dup_sig = || Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]);

        let mut outer = builder.define_function("main", dup_sig())?;

        let inner = outer.define_function("local", dup_sig())?;
        let [input] = inner.input_wires_arr();
        let local_id = inner.finish_with_outputs([input, input])?;

        let local_call = outer.call(local_id.handle(), &[], outer.input_wires())?;
        outer.finish_with_outputs(local_call.outputs())?;

        assert_matches!(builder.finish_hugr(), Ok(_));
        Ok(())
    }
}