use derive_more::{Display, Error, From};
use itertools::Itertools;
use std::path::Path;
use std::{fs, io, mem};
use crate::builder::{Container, Dataflow, DataflowSubContainer, ModuleBuilder};
use crate::extension::resolution::ExtensionResolutionError;
use crate::extension::{ExtensionId, ExtensionRegistry};
use crate::hugr::internal::HugrMutInternals;
use crate::hugr::{ExtensionError, HugrView, ValidationError};
use crate::ops::{FuncDefn, Module, NamedOp, OpTag, OpTrait, OpType};
use crate::{Extension, Hugr};
/// A collection of HUGR modules together with a registry of the extensions they use.
///
/// The constructors ([`Package::new`], [`Package::from_hugrs`], [`Package::from_hugr`])
/// fill `extensions` from the modules' own extension sets; [`Package::validate`] checks
/// that every extension a module uses is present in it.
#[derive(Debug, Default, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
pub struct Package {
    /// Module HUGRs included in the package. Each is expected to have a `Module` root.
    pub modules: Vec<Hugr>,
    /// Extensions used by the modules.
    pub extensions: ExtensionRegistry,
}
impl Package {
    /// Creates a package from a list of HUGRs, each of which must already have a
    /// [`Module`] root operation.
    ///
    /// The extensions used by every module are collected into [`Package::extensions`].
    ///
    /// # Errors
    ///
    /// Returns [`PackageError::NonModuleHugr`] for the first HUGR whose root is not a
    /// module. Use [`Package::from_hugrs`] to wrap non-module HUGRs automatically.
    pub fn new(modules: impl IntoIterator<Item = Hugr>) -> Result<Self, PackageError> {
        let modules: Vec<Hugr> = modules.into_iter().collect();
        let mut extensions = ExtensionRegistry::default();
        for (idx, module) in modules.iter().enumerate() {
            let root_op = module.get_optype(module.root());
            if !root_op.is_module() {
                return Err(PackageError::NonModuleHugr {
                    module_index: idx,
                    root_op: root_op.clone(),
                });
            }
            extensions.extend(module.extensions());
        }
        Ok(Self {
            modules,
            extensions,
        })
    }

    /// Creates a package from a list of HUGRs, first wrapping any HUGR that does not
    /// have a [`Module`] root into one.
    ///
    /// The extensions used by every resulting module are collected into
    /// [`Package::extensions`].
    ///
    /// # Errors
    ///
    /// Returns [`PackageError::CannotWrapHugr`] if some HUGR's root operation cannot be
    /// wrapped in a module.
    pub fn from_hugrs(modules: impl IntoIterator<Item = Hugr>) -> Result<Self, PackageError> {
        let modules: Vec<Hugr> = modules
            .into_iter()
            .map(to_module_hugr)
            .collect::<Result<_, PackageError>>()?;
        let mut extensions = ExtensionRegistry::default();
        for module in &modules {
            extensions.extend(module.extensions());
        }
        Ok(Self {
            modules,
            extensions,
        })
    }

    /// Creates a package containing a single HUGR, wrapping it into a [`Module`] root
    /// if it does not already have one.
    ///
    /// # Errors
    ///
    /// Returns [`PackageError::CannotWrapHugr`] if the HUGR's root operation cannot be
    /// wrapped in a module.
    pub fn from_hugr(hugr: Hugr) -> Result<Self, PackageError> {
        let mut package = Self::default();
        let module = to_module_hugr(hugr)?;
        package.extensions = module.extensions().clone();
        package.modules.push(module);
        Ok(package)
    }

    /// Validates every module and checks that all the extensions each one uses are
    /// declared in [`Package::extensions`].
    ///
    /// # Errors
    ///
    /// - [`PackageValidationError::Validation`] if a module fails HUGR validation.
    /// - [`PackageValidationError::MissingExtension`] if a module uses extensions that
    ///   are not present in the package's registry.
    pub fn validate(&self) -> Result<(), PackageValidationError> {
        for hugr in self.modules.iter() {
            hugr.validate()?;
            let missing_exts = hugr
                .extensions()
                .ids()
                .filter(|id| !self.extensions.contains(id))
                .cloned()
                .collect_vec();
            if !missing_exts.is_empty() {
                return Err(PackageValidationError::MissingExtension {
                    missing: missing_exts,
                    available: self.extensions.ids().cloned().collect(),
                });
            }
        }
        Ok(())
    }

    /// Reads a package from a JSON reader.
    ///
    /// The input is first parsed as a serialized package (modules plus a list of
    /// extension definitions). The embedded extensions are resolved against
    /// `extension_registry`. The modules are then resolved against both registries.
    /// If the input is not a package, it is parsed as a single serialized [`Hugr`] and
    /// wrapped with [`Package::from_hugr`].
    ///
    /// # Errors
    ///
    /// Returns [`PackageEncodingError::JsonEncoding`] if the input is not valid JSON, or
    /// if it matches neither format. In that case the error from the package attempt
    /// is reported. Other variants report extension-resolution or module-wrapping
    /// failures.
    pub fn from_json_reader(
        reader: impl io::Read,
        extension_registry: &ExtensionRegistry,
    ) -> Result<Self, PackageEncodingError> {
        // Serialized form of a package: the extensions are stored as a plain list.
        #[derive(Debug, serde::Deserialize)]
        struct PackageDeser {
            pub modules: Vec<Hugr>,
            pub extensions: Vec<Extension>,
        }

        let val: serde_json::Value = serde_json::from_reader(reader)?;

        // Deserialize through a borrowed `Value`. This way the whole JSON tree does not
        // have to be cloned just to keep it available for the single-HUGR fallback below.
        let pkg_load_err = match <PackageDeser as serde::Deserialize>::deserialize(&val) {
            Ok(PackageDeser {
                mut modules,
                extensions: pkg_extensions,
            }) => {
                let mut pkg_extensions = ExtensionRegistry::new_with_extension_resolution(
                    pkg_extensions,
                    &extension_registry.into(),
                )?;
                // Modules may refer to extensions from the caller's registry as well as
                // those embedded in the package, so resolve against both.
                let mut combined_registry = extension_registry.clone();
                combined_registry.extend(&pkg_extensions);
                for module in &mut modules {
                    module.resolve_extension_defs(&combined_registry)?;
                    pkg_extensions.extend(module.extensions());
                }
                return Ok(Package {
                    modules,
                    extensions: pkg_extensions,
                });
            }
            Err(err) => err,
        };

        // Fallback: the input may be a single bare HUGR rather than a package.
        if let Ok(mut hugr) = serde_json::from_value::<Hugr>(val) {
            hugr.resolve_extension_defs(extension_registry)?;
            if cfg!(feature = "extension_inference") {
                hugr.infer_extensions(false)?;
            }
            return Ok(Package::from_hugr(hugr)?);
        }

        // Neither format matched; report why the package format was rejected.
        Err(PackageEncodingError::JsonEncoding(pkg_load_err))
    }

    /// Reads a package from a JSON string.
    ///
    /// See [`Package::from_json_reader`] for the accepted formats and errors.
    pub fn from_json(
        json: impl AsRef<str>,
        extension_registry: &ExtensionRegistry,
    ) -> Result<Self, PackageEncodingError> {
        Self::from_json_reader(json.as_ref().as_bytes(), extension_registry)
    }

    /// Reads a package from a JSON file.
    ///
    /// See [`Package::from_json_reader`] for the accepted formats.
    ///
    /// # Errors
    ///
    /// Returns [`PackageEncodingError::IOError`] if the file cannot be opened. Otherwise
    /// it returns the errors of [`Package::from_json_reader`].
    pub fn from_json_file(
        path: impl AsRef<Path>,
        extension_registry: &ExtensionRegistry,
    ) -> Result<Self, PackageEncodingError> {
        let file = fs::File::open(path)?;
        let reader = io::BufReader::new(file);
        Self::from_json_reader(reader, extension_registry)
    }

    /// Serializes the package as JSON into `writer`.
    ///
    /// # Errors
    ///
    /// Returns [`PackageEncodingError::JsonEncoding`] if serialization or writing fails.
    pub fn to_json_writer(&self, writer: impl io::Write) -> Result<(), PackageEncodingError> {
        serde_json::to_writer(writer, self)?;
        Ok(())
    }

    /// Serializes the package into a JSON string.
    ///
    /// # Errors
    ///
    /// Returns [`PackageEncodingError::JsonEncoding`] if serialization fails.
    pub fn to_json(&self) -> Result<String, PackageEncodingError> {
        let json = serde_json::to_string(self)?;
        Ok(json)
    }

    /// Serializes the package as JSON into the file at `path`.
    ///
    /// The file is created if it does not exist, and truncated if it does.
    ///
    /// # Errors
    ///
    /// - [`PackageEncodingError::IOError`] if the file cannot be created or the buffered
    ///   output cannot be flushed.
    /// - [`PackageEncodingError::JsonEncoding`] if serialization or writing fails.
    pub fn to_json_file(&self, path: impl AsRef<Path>) -> Result<(), PackageEncodingError> {
        // `File::create` opens the file for writing. `File::open` is read-only and would
        // make every write fail.
        let file = fs::File::create(path)?;
        let mut writer = io::BufWriter::new(file);
        self.to_json_writer(&mut writer)?;
        // Flush explicitly. Dropping a `BufWriter` also flushes, but it silently discards
        // any error.
        io::Write::flush(&mut writer)?;
        Ok(())
    }
}
/// Lets a [`Package`] be used wherever a slice of HUGRs is expected.
impl AsRef<[Hugr]> for Package {
    /// Borrows the package's modules as a slice.
    fn as_ref(&self) -> &[Hugr] {
        self.modules.as_slice()
    }
}
/// Converts a HUGR into one with a [`Module`] root, so it can be stored in a [`Package`].
///
/// How the root operation is handled:
/// - A `Module` root: the HUGR is returned unchanged.
/// - Another module-level operation: a new `Module` root is added above it.
/// - A DFG: the root is replaced by a `FuncDefn` named `"main"` with the DFG's
///   signature, and a new `Module` root is added above it.
/// - Any other dataflow child except `Input`/`Output`: it is inserted into the body of
///   a new `"main"` function. The function's inputs feed it and its outputs become the
///   function's outputs.
///
/// # Errors
///
/// Returns [`PackageError::CannotWrapHugr`] if the root operation fits none of the
/// cases above.
///
/// # Panics
///
/// Panics in two cases:
/// - a dataflow root operation reports no dataflow signature;
/// - building or rewriting the wrapping `"main"` function fails.
fn to_module_hugr(mut hugr: Hugr) -> Result<Hugr, PackageError> {
    let root = hugr.root();
    let root_op = hugr.get_optype(root).clone();
    let tag = root_op.tag();

    if root_op.is_module() {
        return Ok(hugr);
    }

    if OpTag::ModuleOp.is_superset(tag) {
        wrap_in_new_module_root(&mut hugr);
        return Ok(hugr);
    }

    if OpTag::Dfg.is_superset(tag) {
        let signature = root_op
            .dataflow_signature()
            .unwrap_or_else(|| panic!("Dataflow child {} without signature", root_op.name()));
        // Give the root 0 inputs and 1 output to match the `FuncDefn` that replaces it
        // (presumably its single static output).
        hugr.set_num_ports(root, 0, 1);
        hugr.replace_op(
            root,
            FuncDefn {
                name: "main".to_string(),
                signature: signature.into_owned().into(),
            },
        )
        .expect("Hugr accepts any root node");
        wrap_in_new_module_root(&mut hugr);
        return Ok(hugr);
    }

    if OpTag::DataflowChild.is_superset(tag) && !root_op.is_input() && !root_op.is_output() {
        let signature = root_op
            .dataflow_signature()
            .unwrap_or_else(|| panic!("Dataflow child {} without signature", root_op.name()))
            .into_owned();
        let mut builder = ModuleBuilder::new();
        let mut func = builder.define_function("main", signature).unwrap();
        let dataflow_node = func.add_hugr_with_wires(hugr, func.input_wires()).unwrap();
        func.finish_with_outputs(dataflow_node.outputs()).unwrap();
        return Ok(mem::take(builder.hugr_mut()));
    }

    // `root_op` is already an owned copy and is not used again, so move it.
    Err(PackageError::CannotWrapHugr { root_op })
}

/// Adds a fresh `Module` node to `hugr`, makes it the root, and re-parents the previous
/// root beneath it.
fn wrap_in_new_module_root(hugr: &mut Hugr) {
    let old_root = hugr.root();
    let new_root = hugr.add_node(Module::new().into());
    hugr.set_root(new_root);
    hugr.set_parent(old_root, new_root);
}
/// Errors that can occur while building a [`Package`] from HUGRs.
#[derive(Debug, Display, Error, PartialEq)]
#[non_exhaustive]
pub enum PackageError {
    /// A HUGR passed to [`Package::new`] does not have an `OpType::Module` root.
    #[display("Module {module_index} in the package does not have an OpType::Module root, but {}", root_op.name())]
    NonModuleHugr {
        /// Position of the offending HUGR in the input list.
        module_index: usize,
        /// The root operation found instead of a `Module`.
        root_op: OpType,
    },
    /// The HUGR's root operation cannot be wrapped in a module.
    #[display("A hugr with optype {} cannot be wrapped in a module.", root_op.name())]
    CannotWrapHugr {
        /// The root operation of the HUGR that could not be wrapped.
        root_op: OpType,
    },
}
/// Errors that can occur while reading or writing a [`Package`] as JSON.
#[derive(Debug, Display, Error, From)]
#[non_exhaustive]
pub enum PackageEncodingError {
    /// JSON parsing, deserialization or serialization failed.
    JsonEncoding(serde_json::Error),
    /// A file could not be accessed.
    IOError(io::Error),
    /// The decoded contents could not be turned into a package.
    Package(PackageError),
    /// Extensions referenced by the package or its modules could not be resolved.
    ExtensionResolution(ExtensionResolutionError),
    /// An [`ExtensionError`] was raised.
    // NOTE(review): presumably produced by extension inference on the single-HUGR
    // fallback path. TODO confirm.
    RuntimeExtensionResolution(ExtensionError),
}
/// Errors reported by [`Package::validate`].
#[derive(Debug, Display, From, Error)]
#[non_exhaustive]
pub enum PackageValidationError {
    /// A module uses extensions that are not declared in the package's registry.
    #[display("The package modules use the extension{} {} not present in the defined set. The declared extensions are {}",
        if missing.len() > 1 {"s"} else {""},
        missing.iter().map(|id| id.to_string()).collect::<Vec<_>>().join(", "),
        available.iter().map(|id| id.to_string()).collect::<Vec<_>>().join(", "),
    )]
    MissingExtension {
        /// Extensions used by a module but absent from the package.
        missing: Vec<ExtensionId>,
        /// Extensions declared by the package.
        available: Vec<ExtensionId>,
    },
    /// A module failed HUGR validation.
    Validation(ValidationError),
}
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use crate::builder::test::{
        simple_cfg_hugr, simple_dfg_hugr, simple_funcdef_hugr, simple_module_hugr,
    };
    use crate::ops::dataflow::IOTrait;
    use crate::ops::Input;

    use super::*;
    use rstest::{fixture, rstest};

    /// Package holding two copies of the simple module HUGR.
    #[fixture]
    fn simple_package() -> Package {
        let modules = [simple_module_hugr(), simple_module_hugr()];
        Package::new(modules).unwrap()
    }

    /// HUGR made of a lone `Input` node, which cannot be wrapped in a module.
    #[fixture]
    fn simple_input_node() -> Hugr {
        Hugr::new(Input::new(vec![]))
    }

    /// Encoding to JSON and decoding again yields an identical package.
    #[rstest]
    #[case::empty(Package::default())]
    #[case::simple(simple_package())]
    fn package_roundtrip(#[case] package: Package) {
        use crate::extension::PRELUDE_REGISTRY;

        let encoded = package.to_json().unwrap();
        let decoded = Package::from_json(&encoded, &PRELUDE_REGISTRY).unwrap();
        assert_eq!(package, decoded);
    }

    /// Each supported root kind is wrapped into a single module (snapshotted);
    /// unsupported roots produce an error.
    #[rstest]
    #[case::module("module", simple_module_hugr(), false)]
    #[case::funcdef("funcdef", simple_funcdef_hugr(), false)]
    #[case::dfg("dfg", simple_dfg_hugr(), false)]
    #[case::cfg("cfg", simple_cfg_hugr(), false)]
    #[case::unsupported_input("input", simple_input_node(), true)]
    #[cfg_attr(miri, ignore)]
    fn hugr_to_package(#[case] test_name: &str, #[case] hugr: Hugr, #[case] errors: bool) {
        let outcome = Package::from_hugr(hugr);
        match (&outcome, errors) {
            (Ok(package), false) => {
                assert_eq!(package.modules.len(), 1);
                let module = &package.modules[0];
                assert!(module.get_optype(module.root()).is_module());
                insta::assert_snapshot!(test_name, module.mermaid_string());
            }
            (Err(_), true) => {}
            (unexpected, _) => panic!("Unexpected result {:?}", unexpected),
        }
    }

    /// `new` rejects non-module roots, while `from_hugrs` wraps them.
    #[rstest]
    fn package_properties() {
        let module = simple_module_hugr();
        let dfg = simple_dfg_hugr();

        // The DFG at index 1 has no module root, so `new` refuses it...
        assert_matches!(
            Package::new([module.clone(), dfg.clone()]),
            Err(PackageError::NonModuleHugr {
                module_index: 1,
                root_op: OpType::DFG(_),
            })
        );

        // ...whereas `from_hugrs` wraps it, producing a valid two-module package.
        let wrapped = Package::from_hugrs([module, dfg]).unwrap();
        wrapped.validate().unwrap();
        assert_eq!(wrapped.modules.len(), 2);
    }
}