// hugr_core/builder/module.rs
use super::{
build_traits::HugrBuilder,
dataflow::{DFGBuilder, FunctionBuilder},
BuildError, Container,
};
use crate::hugr::internal::HugrMutInternals;
use crate::hugr::views::HugrView;
use crate::hugr::ValidationError;
use crate::ops;
use crate::types::{PolyFuncType, Type, TypeBound};
use crate::ops::handle::{AliasID, FuncID, NodeHandle};
use crate::{Hugr, Node};
use smol_str::SmolStr;
/// Builder wrapping a module-level HUGR.
///
/// This is a newtype over `T`. The builder methods in this file are
/// implemented for any `T: AsMut<Hugr> + AsRef<Hugr>`, while construction via
/// `new`/`Default` and finishing via `HugrBuilder` are implemented for the
/// owned `ModuleBuilder<Hugr>` only.
#[derive(Debug, Clone, PartialEq)]
pub struct ModuleBuilder<T>(pub(super) T);
impl<T> Container for ModuleBuilder<T>
where
    T: AsMut<Hugr> + AsRef<Hugr>,
{
    /// The container node of a module builder is the root of the wrapped HUGR.
    #[inline]
    fn container_node(&self) -> Node {
        <T as AsRef<Hugr>>::as_ref(&self.0).root()
    }

    /// Mutable access to the wrapped HUGR.
    #[inline]
    fn hugr_mut(&mut self) -> &mut Hugr {
        <T as AsMut<Hugr>>::as_mut(&mut self.0)
    }

    /// Shared access to the wrapped HUGR.
    fn hugr(&self) -> &Hugr {
        <T as AsRef<Hugr>>::as_ref(&self.0)
    }
}
impl ModuleBuilder<Hugr> {
#[must_use]
pub fn new() -> Self {
Self(Default::default())
}
}
impl Default for ModuleBuilder<Hugr> {
fn default() -> Self {
Self::new()
}
}
impl HugrBuilder for ModuleBuilder<Hugr> {
    /// Consumes the builder and returns the wrapped HUGR once it validates.
    ///
    /// When the crate is compiled with the `extension_inference` feature,
    /// `infer_extensions(false)` is run before validation.
    ///
    /// # Errors
    ///
    /// Propagates any error from extension inference (when enabled) or from
    /// `validate` as a [`ValidationError`].
    fn finish_hugr(self) -> Result<Hugr, ValidationError> {
        // Unwrap the newtype once and work on the owned HUGR directly.
        let Self(mut hugr) = self;

        if cfg!(feature = "extension_inference") {
            hugr.infer_extensions(false)?;
        }

        hugr.validate()?;
        Ok(hugr)
    }
}
impl<T> ModuleBuilder<T>
where
    T: AsMut<Hugr> + AsRef<Hugr>,
{
    /// Turns an existing function declaration into a definition and returns a
    /// builder for its body.
    ///
    /// The `FuncDecl` operation at the node of `f_id` is replaced in place by
    /// a `FuncDefn` with the same name and signature. The returned
    /// [`FunctionBuilder`] is created over that node via
    /// `DFGBuilder::create_with_io`, using the body of the signature.
    ///
    /// # Errors
    ///
    /// Returns [`BuildError::UnexpectedType`] if the node does not hold a
    /// `FuncDecl`, and propagates any error from
    /// `DFGBuilder::create_with_io`.
    ///
    /// # Panics
    ///
    /// Panics if replacing the operation on the node fails.
    pub fn define_declaration(
        &mut self,
        f_id: &FuncID<false>,
    ) -> Result<FunctionBuilder<&mut Hugr>, BuildError> {
        let decl_node = f_id.node();

        // Copy the declaration out so the shared borrow of the HUGR ends
        // before the node is mutated below.
        let declaration = match self.hugr().get_optype(decl_node).as_func_decl() {
            Some(decl) => decl.clone(),
            None => {
                return Err(BuildError::UnexpectedType {
                    node: decl_node,
                    op_desc: "crate::ops::OpType::FuncDecl",
                })
            }
        };
        let ops::FuncDecl { signature, name } = declaration;

        // The signature is moved into the new op, so keep a copy of its body
        // for wiring up the dataflow builder.
        let io_signature = signature.body().clone();

        let replaced = self
            .hugr_mut()
            .replace_op(decl_node, ops::FuncDefn { name, signature });
        replaced.expect("Replacing a FuncDecl node with a FuncDefn should always be valid");

        let dfg = DFGBuilder::create_with_io(self.hugr_mut(), decl_node, io_signature)?;
        Ok(FunctionBuilder::from_dfg_builder(dfg))
    }

    /// Adds a `FuncDecl` child node with the given `name` and `signature` and
    /// returns a handle to it.
    ///
    /// The extensions used by the body of the signature are registered on
    /// this builder through `use_extensions`.
    ///
    /// # Errors
    ///
    /// The visible implementation always returns `Ok`.
    ///
    /// # Panics
    ///
    /// Panics if the used extensions of the signature body cannot be
    /// computed.
    pub fn declare(
        &mut self,
        name: impl Into<String>,
        signature: PolyFuncType,
    ) -> Result<FuncID<false>, BuildError> {
        // `signature` is moved into the op, so take the body first.
        let fn_type = signature.body().clone();
        let decl_op = ops::FuncDecl {
            signature,
            name: name.into(),
        };
        let decl_node = self.add_child_node(decl_op);

        let extensions = match fn_type.used_extensions() {
            Ok(found) => found,
            Err(e) => panic!("Build-time signatures should have valid extensions. {e}"),
        };
        self.use_extensions(extensions);

        Ok(decl_node.into())
    }

    /// Adds an `AliasDefn` child node binding `name` to the type `typ` and
    /// returns a handle to the alias.
    ///
    /// The bound recorded in the returned [`AliasID`] is
    /// `typ.least_upper_bound()`.
    ///
    /// # Errors
    ///
    /// The visible implementation always returns `Ok`.
    pub fn add_alias_def(
        &mut self,
        name: impl Into<SmolStr>,
        typ: Type,
    ) -> Result<AliasID<true>, BuildError> {
        let alias_name: SmolStr = name.into();
        // Must be read before `typ` is moved into the op.
        let alias_bound = typ.least_upper_bound();
        let definition_op = ops::AliasDefn {
            name: alias_name.clone(),
            definition: typ,
        };
        let alias_node = self.add_child_node(definition_op);

        Ok(AliasID::new(alias_node, alias_name, alias_bound))
    }

    /// Adds an `AliasDecl` child node declaring `name` with the given type
    /// `bound` and returns a handle to the alias.
    ///
    /// # Errors
    ///
    /// The visible implementation always returns `Ok`.
    pub fn add_alias_declare(
        &mut self,
        name: impl Into<SmolStr>,
        bound: TypeBound,
    ) -> Result<AliasID<false>, BuildError> {
        let alias_name: SmolStr = name.into();
        let declaration_op = ops::AliasDecl {
            name: alias_name.clone(),
            bound,
        };
        let alias_node = self.add_child_node(declaration_op);

        Ok(AliasID::new(alias_node, alias_name, bound))
    }
}
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use super::*;
    use crate::builder::{test::n_identity, Dataflow, DataflowSubContainer};
    use crate::extension::prelude::usize_t;
    use crate::types::Signature;

    // A declared function is defined with a body that calls itself.
    #[test]
    fn basic_recurse() -> Result<(), BuildError> {
        let mut module = ModuleBuilder::new();
        let main_id = module.declare(
            "main",
            Signature::new(vec![usize_t()], vec![usize_t()]).into(),
        )?;

        let mut main_body = module.define_declaration(&main_id)?;
        let inputs = main_body.input_wires();
        let recursive_call = main_body.call(&main_id, &[], inputs)?;
        main_body.finish_with_outputs(recursive_call.outputs())?;

        let outcome = module.finish_hugr();
        assert_matches!(outcome, Ok(_));
        Ok(())
    }

    // A declared alias is used as both the input and output type of a function.
    #[test]
    fn simple_alias() -> Result<(), BuildError> {
        let mut module = ModuleBuilder::new();
        let alias = module.add_alias_declare("qubit_state", TypeBound::Any)?;

        let main_body = module.define_function(
            "main",
            Signature::new(
                vec![alias.get_alias_type()],
                vec![alias.get_alias_type()],
            ),
        )?;
        n_identity(main_body)?;

        let outcome = module.finish_hugr();
        assert_matches!(outcome, Ok(_));
        Ok(())
    }

    // A function defined inside another function is called by its parent.
    #[test]
    fn local_def() -> Result<(), BuildError> {
        let mut module = ModuleBuilder::new();
        let mut main_body = module.define_function(
            "main",
            Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]),
        )?;

        let local_body = main_body.define_function(
            "local",
            Signature::new(vec![usize_t()], vec![usize_t(), usize_t()]),
        )?;
        let [input] = local_body.input_wires_arr();
        let local_fn = local_body.finish_with_outputs([input, input])?;

        let inputs = main_body.input_wires();
        let local_call = main_body.call(local_fn.handle(), &[], inputs)?;
        main_body.finish_with_outputs(local_call.outputs())?;

        let outcome = module.finish_hugr();
        assert_matches!(outcome, Ok(_));
        Ok(())
    }
}