hugr_core/builder/
module.rs1use super::{
2 build_traits::HugrBuilder,
3 dataflow::{DFGBuilder, FunctionBuilder},
4 BuildError, Container,
5};
6
7use crate::hugr::internal::HugrMutInternals;
8use crate::hugr::views::HugrView;
9use crate::hugr::ValidationError;
10use crate::ops;
11use crate::types::{PolyFuncType, Type, TypeBound};
12
13use crate::ops::handle::{AliasID, FuncID, NodeHandle};
14
15use crate::{Hugr, Node};
16use smol_str::SmolStr;
17
18#[derive(Debug, Clone, PartialEq)]
20pub struct ModuleBuilder<T>(pub(super) T);
21
22impl<T: AsMut<Hugr> + AsRef<Hugr>> Container for ModuleBuilder<T> {
23 #[inline]
24 fn container_node(&self) -> Node {
25 self.0.as_ref().entrypoint()
26 }
27
28 #[inline]
29 fn hugr_mut(&mut self) -> &mut Hugr {
30 self.0.as_mut()
31 }
32
33 fn hugr(&self) -> &Hugr {
34 self.0.as_ref()
35 }
36}
37
38impl ModuleBuilder<Hugr> {
39 #[must_use]
41 pub fn new() -> Self {
42 Self(Default::default())
43 }
44}
45
46impl Default for ModuleBuilder<Hugr> {
47 fn default() -> Self {
48 Self::new()
49 }
50}
51
52impl HugrBuilder for ModuleBuilder<Hugr> {
53 fn finish_hugr(self) -> Result<Hugr, ValidationError<Node>> {
54 self.0.validate()?;
55 Ok(self.0)
56 }
57}
58
59impl<T: AsMut<Hugr> + AsRef<Hugr>> ModuleBuilder<T> {
60 pub fn define_declaration(
68 &mut self,
69 f_id: &FuncID<false>,
70 ) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
71 let f_node = f_id.node();
72 let ops::FuncDecl { signature, name } = self
73 .hugr()
74 .get_optype(f_node)
75 .as_func_decl()
76 .ok_or(BuildError::UnexpectedType {
77 node: f_node,
78 op_desc: "crate::ops::OpType::FuncDecl",
79 })?
80 .clone();
81 let body = signature.body().clone();
82 self.hugr_mut()
83 .replace_op(f_node, ops::FuncDefn { name, signature });
84
85 let db = DFGBuilder::create_with_io(self.hugr_mut(), f_node, body)?;
86 Ok(FunctionBuilder::from_dfg_builder(db))
87 }
88
89 pub fn declare(
96 &mut self,
97 name: impl Into<String>,
98 signature: PolyFuncType,
99 ) -> Result<FuncID<false>, BuildError> {
100 let body = signature.body().clone();
101 let declare_n = self.add_child_node(ops::FuncDecl {
103 signature,
104 name: name.into(),
105 });
106
107 self.use_extensions(
109 body.used_extensions().unwrap_or_else(|e| {
110 panic!("Build-time signatures should have valid extensions. {e}")
111 }),
112 );
113
114 Ok(declare_n.into())
115 }
116
117 pub fn add_alias_def(
123 &mut self,
124 name: impl Into<SmolStr>,
125 typ: Type,
126 ) -> Result<AliasID<true>, BuildError> {
127 let name: SmolStr = name.into();
133 let bound = typ.least_upper_bound();
134 let node = self.add_child_node(ops::AliasDefn {
135 name: name.clone(),
136 definition: typ,
137 });
138
139 Ok(AliasID::new(node, name, bound))
140 }
141
142 pub fn add_alias_declare(
147 &mut self,
148 name: impl Into<SmolStr>,
149 bound: TypeBound,
150 ) -> Result<AliasID<false>, BuildError> {
151 let name: SmolStr = name.into();
152 let node = self.add_child_node(ops::AliasDecl {
153 name: name.clone(),
154 bound,
155 });
156
157 Ok(AliasID::new(node, name, bound))
158 }
159}
160
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use crate::extension::prelude::usize_t;
    use crate::{
        builder::{test::n_identity, Dataflow, DataflowSubContainer},
        types::Signature,
    };

    use super::*;

    /// A declared function can later be defined, and its body may call itself.
    #[test]
    fn basic_recurse() -> Result<(), BuildError> {
        let mut builder = ModuleBuilder::new();
        let sig = Signature::new(vec![usize_t()], vec![usize_t()]);
        let main_decl = builder.declare("main", sig.into())?;

        let mut main_fn = builder.define_declaration(&main_decl)?;
        let self_call = main_fn.call(&main_decl, &[], main_fn.input_wires())?;
        main_fn.finish_with_outputs(self_call.outputs())?;

        assert_matches!(builder.finish_hugr(), Ok(_));
        Ok(())
    }

    /// A declared alias type can be used in a function signature.
    #[test]
    fn simple_alias() -> Result<(), BuildError> {
        let mut builder = ModuleBuilder::new();
        let qubit_state = builder.add_alias_declare("qubit_state", TypeBound::Any)?;
        let alias_ty = qubit_state.get_alias_type();

        let main_fn = builder.define_function(
            "main",
            Signature::new(vec![alias_ty.clone()], vec![alias_ty]),
        )?;
        n_identity(main_fn)?;

        assert_matches!(builder.finish_hugr(), Ok(_));
        Ok(())
    }

    /// A function defined inside another function can be called from it.
    #[test]
    fn local_def() -> Result<(), BuildError> {
        let mut builder = ModuleBuilder::new();
        let sig = Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]);

        let mut main_fn = builder.define_function("main", sig.clone())?;
        let local_fn = main_fn.define_function("local", sig)?;
        let [input] = local_fn.input_wires_arr();
        let local_id = local_fn.finish_with_outputs([input, input])?;

        let call = main_fn.call(local_id.handle(), &[], main_fn.input_wires())?;
        main_fn.finish_with_outputs(call.outputs())?;

        assert_matches!(builder.finish_hugr(), Ok(_));
        Ok(())
    }
}