hugr_core/builder/
module.rs1use super::{
2 BuildError, Container,
3 build_traits::HugrBuilder,
4 dataflow::{DFGBuilder, FunctionBuilder},
5};
6
7use crate::hugr::ValidationError;
8use crate::hugr::internal::HugrMutInternals;
9use crate::hugr::views::HugrView;
10use crate::ops;
11use crate::types::{PolyFuncType, Type, TypeBound};
12
13use crate::ops::handle::{AliasID, FuncID, NodeHandle};
14
15use crate::{Hugr, Node};
16use smol_str::SmolStr;
17
18#[derive(Debug, Clone, PartialEq)]
20pub struct ModuleBuilder<T>(pub(super) T);
21
22impl<T: AsMut<Hugr> + AsRef<Hugr>> Container for ModuleBuilder<T> {
23 #[inline]
24 fn container_node(&self) -> Node {
25 self.0.as_ref().entrypoint()
26 }
27
28 #[inline]
29 fn hugr_mut(&mut self) -> &mut Hugr {
30 self.0.as_mut()
31 }
32
33 fn hugr(&self) -> &Hugr {
34 self.0.as_ref()
35 }
36}
37
38impl ModuleBuilder<Hugr> {
39 #[must_use]
41 pub fn new() -> Self {
42 Self(Default::default())
43 }
44}
45
46impl Default for ModuleBuilder<Hugr> {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl HugrBuilder for ModuleBuilder<Hugr> {
53 fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>> {
54 self.0.validate()?;
55 Ok(self.0)
56 }
57}
58
59impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
60 pub fn define_declaration(
68 &mut self,
69 f_id: &FuncID<false>,
70 ) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
71 let f_node = f_id.node();
72 let decl =
73 self.hugr()
74 .get_optype(f_node)
75 .as_func_decl()
76 .ok_or(BuildError::UnexpectedType {
77 node: f_node,
78 op_desc: "crate::ops::OpType::FuncDecl",
79 })?;
80 let name = decl.func_name().clone();
81 let sig = decl.signature().clone();
82 let body = sig.body().clone();
83 self.hugr_mut()
84 .replace_op(f_node, ops::FuncDefn::new(name, sig));
85
86 let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?;
87 Ok(FunctionBuilder::from_dfg_builder(db))
88 }
89
90 pub fn declare(
97 &mut self,
98 name: impl Into<String>,
99 signature: PolyFuncType,
100 ) -> Result<FuncID<false>, BuildError> {
101 let body = signature.body().clone();
102 let declare_n = self.add_child_node(ops::FuncDecl::new(name, signature));
104
105 self.use_extensions(
107 body.used_extensions().unwrap_or_else(|e| {
108 panic!("Build-time signatures should have valid extensions. {e}")
109 }),
110 );
111
112 Ok(declare_n.into())
113 }
114
115 pub fn add_alias_def(
121 &mut self,
122 name: impl Into<SmolStr>,
123 typ: Type,
124 ) -> Result<AliasID<true>, BuildError> {
125 let name: SmolStr = name.into();
131 let bound = typ.least_upper_bound();
132 let node = self.add_child_node(ops::AliasDefn {
133 name: name.clone(),
134 definition: typ,
135 });
136
137 Ok(AliasID::new(node, name, bound))
138 }
139
140 pub fn add_alias_declare(
145 &mut self,
146 name: impl Into<SmolStr>,
147 bound: TypeBound,
148 ) -> Result<AliasID<false>, BuildError> {
149 let name: SmolStr = name.into();
150 let node = self.add_child_node(ops::AliasDecl {
151 name: name.clone(),
152 bound,
153 });
154
155 Ok(AliasID::new(node, name, bound))
156 }
157}
158
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use crate::extension::prelude::usize_t;
    use crate::{
        builder::{Dataflow, DataflowSubContainer, test::n_identity},
        types::Signature,
    };

    use super::*;

    /// A declared function can be defined later and can call itself.
    #[test]
    fn basic_recurse() -> Result<(), BuildError> {
        let mut builder = ModuleBuilder::new();
        let endo_usize = Signature::new(vec![usize_t()], vec![usize_t()]);
        let main_decl = builder.declare("main", endo_usize.into())?;

        let mut main_body = builder.define_declaration(&main_decl)?;
        let self_call = main_body.call(&main_decl, &[], main_body.input_wires())?;
        main_body.finish_with_outputs(self_call.outputs())?;

        assert_matches!(builder.finish_hugr(), Ok(_));
        Ok(())
    }

    /// A declared type alias can appear in a function signature.
    #[test]
    fn simple_alias() -> Result<(), BuildError> {
        let mut builder = ModuleBuilder::new();
        let qubit_state = builder.add_alias_declare("qubit_state", TypeBound::Any)?;
        let alias_ty = qubit_state.get_alias_type();

        let main_fn = builder.define_function(
            "main",
            Signature::new(vec![alias_ty.clone()], vec![alias_ty]),
        )?;
        n_identity(main_fn)?;

        assert_matches!(builder.finish_hugr(), Ok(_));
        Ok(())
    }

    /// A function defined inside another function can be called from it.
    #[test]
    fn local_def() -> Result<(), BuildError> {
        let mut builder = ModuleBuilder::new();
        let dup_sig = Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]);

        let mut main_fn = builder.define_function("main", dup_sig.clone())?;
        let local_fn = main_fn.define_function("local", dup_sig)?;
        let [input] = local_fn.input_wires_arr();
        let local_id = local_fn.finish_with_outputs([input, input])?;

        let local_call = main_fn.call(local_id.handle(), &[], main_fn.input_wires())?;
        main_fn.finish_with_outputs(local_call.outputs())?;

        assert_matches!(builder.finish_hugr(), Ok(_));
        Ok(())
    }
}