hugr_core/builder/
module.rs

1use super::{
2    build_traits::HugrBuilder,
3    dataflow::{DFGBuilder, FunctionBuilder},
4    BuildError, Container,
5};
6
7use crate::hugr::internal::HugrMutInternals;
8use crate::hugr::views::HugrView;
9use crate::hugr::ValidationError;
10use crate::ops;
11use crate::types::{PolyFuncType, Type, TypeBound};
12
13use crate::ops::handle::{AliasID, FuncID, NodeHandle};
14
15use crate::{Hugr, Node};
16use smol_str::SmolStr;
17
18/// Builder for a HUGR module.
19#[derive(Debug, Clone, PartialEq)]
20pub struct ModuleBuilder<T>(pub(super) T);
21
22impl<T: AsMut<Hugr> + AsRef<Hugr>> Container for ModuleBuilder<T> {
23    #[inline]
24    fn container_node(&self) -> Node {
25        self.0.as_ref().root()
26    }
27
28    #[inline]
29    fn hugr_mut(&mut self) -> &mut Hugr {
30        self.0.as_mut()
31    }
32
33    fn hugr(&self) -> &Hugr {
34        self.0.as_ref()
35    }
36}
37
38impl ModuleBuilder<Hugr> {
39    /// Begin building a new module.
40    #[must_use]
41    pub fn new() -> Self {
42        Self(Default::default())
43    }
44}
45
46impl Default for ModuleBuilder<Hugr> {
47    fn default() -> Self {
48        Self::new()
49    }
50}
51
52impl HugrBuilder for ModuleBuilder<Hugr> {
53    fn finish_hugr(mut self) -> Result<Hugr, ValidationError> {
54        if cfg!(feature = "extension_inference") {
55            self.0.infer_extensions(false)?;
56        }
57        self.0.validate()?;
58        Ok(self.0)
59    }
60}
61
62impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
63    /// Replace a [`ops::FuncDecl`] with [`ops::FuncDefn`] and return a builder for
64    /// the defining graph.
65    ///
66    /// # Errors
67    ///
68    /// This function will return an error if there is an error in adding the
69    /// [`crate::ops::OpType::FuncDefn`] node.
70    pub fn define_declaration(
71        &mut self,
72        f_id: &FuncID<false>,
73    ) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
74        let f_node = f_id.node();
75        let ops::FuncDecl { signature, name } = self
76            .hugr()
77            .get_optype(f_node)
78            .as_func_decl()
79            .ok_or(BuildError::UnexpectedType {
80                node: f_node,
81                op_desc: "crate::ops::OpType::FuncDecl",
82            })?
83            .clone();
84        let body = signature.body().clone();
85        self.hugr_mut()
86            .replace_op(f_node, ops::FuncDefn { name, signature })
87            .expect("Replacing a FuncDecl node with a FuncDefn should always be valid");
88
89        let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?;
90        Ok(FunctionBuilder::from_dfg_builder(db))
91    }
92
93    /// Declare a function with `signature` and return a handle to the declaration.
94    ///
95    /// # Errors
96    ///
97    /// This function will return an error if there is an error in adding the
98    /// [`crate::ops::OpType::FuncDecl`] node.
99    pub fn declare(
100        &mut self,
101        name: impl Into<String>,
102        signature: PolyFuncType,
103    ) -> Result<FuncID<false>, BuildError> {
104        let body = signature.body().clone();
105        // TODO add param names to metadata
106        let declare_n = self.add_child_node(ops::FuncDecl {
107            signature,
108            name: name.into(),
109        });
110
111        // Add the extensions used by the function types.
112        self.use_extensions(
113            body.used_extensions().unwrap_or_else(|e| {
114                panic!("Build-time signatures should have valid extensions. {e}")
115            }),
116        );
117
118        Ok(declare_n.into())
119    }
120
121    /// Add a [`crate::ops::OpType::AliasDefn`] node and return a handle to the Alias.
122    ///
123    /// # Errors
124    ///
125    /// Error in adding [`crate::ops::OpType::AliasDefn`] child node.
126    pub fn add_alias_def(
127        &mut self,
128        name: impl Into<SmolStr>,
129        typ: Type,
130    ) -> Result<AliasID<true>, BuildError> {
131        // TODO: add AliasDefn in other containers
132        // This is currently tricky as they are not connected to anything so do
133        // not appear in topological traversals.
134        // Could be fixed by removing single-entry requirement and sorting from
135        // every 0-input node.
136        let name: SmolStr = name.into();
137        let bound = typ.least_upper_bound();
138        let node = self.add_child_node(ops::AliasDefn {
139            name: name.clone(),
140            definition: typ,
141        });
142
143        Ok(AliasID::new(node, name, bound))
144    }
145
146    /// Add a [`crate::ops::OpType::AliasDecl`] node and return a handle to the Alias.
147    /// # Errors
148    ///
149    /// Error in adding [`crate::ops::OpType::AliasDecl`] child node.
150    pub fn add_alias_declare(
151        &mut self,
152        name: impl Into<SmolStr>,
153        bound: TypeBound,
154    ) -> Result<AliasID<false>, BuildError> {
155        let name: SmolStr = name.into();
156        let node = self.add_child_node(ops::AliasDecl {
157            name: name.clone(),
158            bound,
159        });
160
161        Ok(AliasID::new(node, name, bound))
162    }
163}
164
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use crate::extension::prelude::usize_t;
    use crate::{
        builder::{test::n_identity, Dataflow, DataflowSubContainer},
        types::Signature,
    };

    use super::*;

    /// A function declared first and defined afterwards can call itself.
    #[test]
    fn basic_recurse() -> Result<(), BuildError> {
        let mut builder = ModuleBuilder::new();

        let decl = builder.declare(
            "main",
            Signature::new(vec![usize_t()], vec![usize_t()]).into(),
        )?;

        let mut body = builder.define_declaration(&decl)?;
        let recursive_call = body.call(&decl, &[], body.input_wires())?;
        body.finish_with_outputs(recursive_call.outputs())?;

        let hugr = builder.finish_hugr();
        assert_matches!(hugr, Ok(_));
        Ok(())
    }

    /// A declared alias can be used as a type in a function signature.
    #[test]
    fn simple_alias() -> Result<(), BuildError> {
        let mut builder = ModuleBuilder::new();

        let alias = builder.add_alias_declare("qubit_state", TypeBound::Any)?;
        let alias_ty = alias.get_alias_type();

        let main = builder.define_function(
            "main",
            Signature::new(vec![alias_ty.clone()], vec![alias_ty]),
        )?;
        n_identity(main)?;

        let hugr = builder.finish_hugr();
        assert_matches!(hugr, Ok(_));
        Ok(())
    }

    /// A function nested inside another function can be called by its parent.
    #[test]
    fn local_def() -> Result<(), BuildError> {
        let mut builder = ModuleBuilder::new();
        let sig = || Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]);

        let mut outer = builder.define_function("main", sig())?;
        let inner = outer.define_function("local", sig())?;
        let [input] = inner.input_wires_arr();
        let inner_id = inner.finish_with_outputs([input, input])?;

        let call = outer.call(inner_id.handle(), &[], outer.input_wires())?;
        outer.finish_with_outputs(call.outputs())?;

        let hugr = builder.finish_hugr();
        assert_matches!(hugr, Ok(_));
        Ok(())
    }
}