use derive_more::{Display, Error, From};
use std::collections::HashMap;
use std::path::Path;
use std::{fs, io, mem};
use crate::builder::{Container, Dataflow, DataflowSubContainer, ModuleBuilder};
use crate::extension::{ExtensionRegistry, ExtensionRegistryError};
use crate::hugr::internal::HugrMutInternals;
use crate::hugr::{HugrView, ValidationError};
use crate::ops::{FuncDefn, Module, NamedOp, OpTag, OpTrait, OpType};
use crate::{Extension, Hugr};
/// A collection of HUGRs bundled together with a list of extensions.
///
/// The type is (de)serializable through `serde`; the field order below is the
/// order in which the fields are serialized.
#[derive(Debug, Default, Clone, serde::Serialize, serde::Deserialize)]
pub struct Package {
/// The HUGRs contained in the package.
///
/// `Package::new` rejects HUGRs whose root is not a module, but the field is
/// public, so that property is not enforced by the type itself.
pub modules: Vec<Hugr>,
/// Extensions shipped with the package; `Package::update_validate` registers
/// them before validating the modules.
pub extensions: Vec<Extension>,
}
impl Package {
    /// Creates a package from HUGRs that are already rooted at a module.
    ///
    /// # Errors
    ///
    /// Returns [`PackageError::NonModuleHugr`] with the index and root
    /// operation of the first HUGR whose root is not a module operation.
    pub fn new(
        modules: impl IntoIterator<Item = Hugr>,
        extensions: impl IntoIterator<Item = Extension>,
    ) -> Result<Self, PackageError> {
        let modules: Vec<Hugr> = modules.into_iter().collect();
        for (idx, module) in modules.iter().enumerate() {
            let root_op = module.get_optype(module.root());
            if !root_op.is_module() {
                return Err(PackageError::NonModuleHugr {
                    module_index: idx,
                    root_op: root_op.clone(),
                });
            }
        }
        Ok(Self {
            modules,
            extensions: extensions.into_iter().collect(),
        })
    }

    /// Creates a package from arbitrary HUGRs, converting each one with
    /// `to_module_hugr` so that it ends up rooted at a module.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by `to_module_hugr`, i.e.
    /// [`PackageError::CannotWrapHugr`] when a root operation cannot be
    /// wrapped.
    pub fn from_hugrs(
        modules: impl IntoIterator<Item = Hugr>,
        extensions: impl IntoIterator<Item = Extension>,
    ) -> Result<Self, PackageError> {
        let modules: Vec<Hugr> = modules
            .into_iter()
            .map(to_module_hugr)
            .collect::<Result<_, PackageError>>()?;
        Ok(Self {
            modules,
            extensions: extensions.into_iter().collect(),
        })
    }

    /// Creates a package holding a single HUGR and no extensions.
    ///
    /// The HUGR is converted with `to_module_hugr` first.
    ///
    /// # Errors
    ///
    /// Returns [`PackageError::CannotWrapHugr`] when the root operation
    /// cannot be wrapped in a module.
    pub fn from_hugr(hugr: Hugr) -> Result<Self, PackageError> {
        let mut package = Self::default();
        let module = to_module_hugr(hugr)?;
        package.modules.push(module);
        Ok(package)
    }

    /// Registers the package extensions in `reg` and then runs
    /// `update_validate` on every module against that registry.
    ///
    /// # Errors
    ///
    /// Stops at the first failing extension registration or module
    /// validation. No rollback is attempted: extensions registered before
    /// the failure stay in `reg`.
    pub fn update_validate(
        &mut self,
        reg: &mut ExtensionRegistry,
    ) -> Result<(), PackageValidationError> {
        for ext in &self.extensions {
            reg.register_updated_ref(ext)?;
        }
        for hugr in self.modules.iter_mut() {
            hugr.update_validate(reg)?;
        }
        Ok(())
    }

    /// Validates the package and returns its modules, consuming the package.
    ///
    /// # Errors
    ///
    /// Same as [`Package::update_validate`].
    #[deprecated(since = "0.13.2", note = "Replaced by `Package::update_validate`")]
    pub fn validate(
        mut self,
        reg: &mut ExtensionRegistry,
    ) -> Result<Vec<Hugr>, PackageValidationError> {
        self.update_validate(reg)?;
        Ok(self.modules)
    }

    /// Reads a package from a JSON stream.
    ///
    /// The input is first decoded as a [`Package`]. If that fails, it is
    /// decoded as a single [`Hugr`] and wrapped with [`Package::from_hugr`].
    ///
    /// # Errors
    ///
    /// - The reader does not contain valid JSON.
    /// - The JSON is neither a package nor a HUGR; the error from the
    ///   package decoding attempt is the one reported.
    /// - The JSON is a HUGR that cannot be wrapped in a module.
    pub fn from_json_reader(reader: impl io::Read) -> Result<Self, PackageEncodingError> {
        let val: serde_json::Value = serde_json::from_reader(reader)?;
        // `from_value` consumes its argument, so keep the original around for
        // the single-HUGR fallback below.
        let pkg_load_err = match serde_json::from_value::<Package>(val.clone()) {
            Ok(p) => return Ok(p),
            Err(e) => e,
        };
        if let Ok(hugr) = serde_json::from_value::<Hugr>(val) {
            return Ok(Package::from_hugr(hugr)?);
        }
        // Neither decoding worked: report the failure of the package attempt.
        Err(PackageEncodingError::JsonEncoding(pkg_load_err))
    }

    /// Reads a package from a JSON string.
    ///
    /// # Errors
    ///
    /// Same as [`Package::from_json_reader`].
    pub fn from_json(json: impl AsRef<str>) -> Result<Self, PackageEncodingError> {
        Self::from_json_reader(json.as_ref().as_bytes())
    }

    /// Reads a package from the JSON file at `path`.
    ///
    /// # Errors
    ///
    /// Returns an I/O error when the file cannot be opened, otherwise the
    /// same errors as [`Package::from_json_reader`].
    pub fn from_json_file(path: impl AsRef<Path>) -> Result<Self, PackageEncodingError> {
        let file = fs::File::open(path)?;
        let reader = io::BufReader::new(file);
        Self::from_json_reader(reader)
    }

    /// Serializes the package as JSON into `writer`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error raised while serializing or writing.
    pub fn to_json_writer(&self, writer: impl io::Write) -> Result<(), PackageEncodingError> {
        serde_json::to_writer(writer, self)?;
        Ok(())
    }

    /// Serializes the package into a JSON string.
    ///
    /// # Errors
    ///
    /// Returns the `serde_json` error raised while serializing.
    pub fn to_json(&self) -> Result<String, PackageEncodingError> {
        let json = serde_json::to_string(self)?;
        Ok(json)
    }

    /// Writes the package as JSON to the file at `path`.
    ///
    /// The file is created when it does not exist and truncated when it does.
    ///
    /// # Errors
    ///
    /// Returns an I/O error when the file cannot be created or flushed, and
    /// the `serde_json` error when serialization fails.
    pub fn to_json_file(&self, path: impl AsRef<Path>) -> Result<(), PackageEncodingError> {
        // `File::open` yields a read-only handle and fails on a missing file,
        // so the output file has to be obtained through `File::create`.
        let file = fs::File::create(path)?;
        let mut writer = io::BufWriter::new(file);
        self.to_json_writer(&mut writer)?;
        // Dropping a `BufWriter` discards flush errors; flush explicitly so a
        // failed write is reported to the caller.
        io::Write::flush(&mut writer)?;
        Ok(())
    }
}
/// Two packages are equal when their modules are equal (order included) and
/// they carry the same extensions, regardless of extension order.
impl PartialEq for Package {
    fn eq(&self, other: &Self) -> bool {
        let comparable =
            self.modules == other.modules && self.extensions.len() == other.extensions.len();
        if !comparable {
            return false;
        }

        // Index our extensions by name so the order of the lists is irrelevant.
        let ours_by_name: HashMap<_, _> = self
            .extensions
            .iter()
            .map(|ext| (&ext.name, ext))
            .collect();

        // Every extension on the other side needs an equal, same-named
        // counterpart on ours.
        other.extensions.iter().all(|theirs| {
            matches!(ours_by_name.get(&theirs.name), Some(&ours) if theirs == ours)
        })
    }
}
/// Exposes the modules of the package as a slice of HUGRs.
impl AsRef<[Hugr]> for Package {
    fn as_ref(&self) -> &[Hugr] {
        self.modules.as_slice()
    }
}
/// Turns `hugr` into a HUGR rooted at a module operation.
///
/// The strategy depends on the root operation:
/// - a module is returned untouched;
/// - an operation tagged as a module child is re-parented under a new module;
/// - a DFG is turned into a `FuncDefn` named `"main"` and placed under a new
///   module;
/// - any other dataflow child, except input and output nodes, is inserted
///   into the body of a new `"main"` function of a fresh module.
///
/// # Errors
///
/// Returns [`PackageError::CannotWrapHugr`] for any other root operation.
///
/// # Panics
///
/// Panics when a dataflow root operation does not report a signature, or when
/// building the wrapping function fails.
fn to_module_hugr(mut hugr: Hugr) -> Result<Hugr, PackageError> {
    /// Adds a module node, makes it the root, and hangs the previous root
    /// below it.
    fn wrap_root_in_module(mut inner: Hugr) -> Hugr {
        let previous_root = inner.root();
        let module_root = inner.add_node(Module::new().into());
        inner.set_root(module_root);
        inner.set_parent(previous_root, module_root);
        inner
    }

    let old_root = hugr.root();
    let op = hugr.get_optype(old_root).clone();
    let op_tag = op.tag();

    if op.is_module() {
        // Already in the desired shape.
        Ok(hugr)
    } else if OpTag::ModuleOp.is_superset(op_tag) {
        // The root can live directly under a module.
        Ok(wrap_root_in_module(hugr))
    } else if OpTag::Dfg.is_superset(op_tag) {
        let Some(signature) = op.dataflow_signature() else {
            panic!("Dataflow child {} without signature", op.name())
        };

        // Rewrite the root in place as the "main" function definition.
        hugr.set_num_ports(old_root, 0, 1);
        let main_defn = FuncDefn {
            name: "main".to_string(),
            signature: signature.into(),
        };
        hugr.replace_op(old_root, main_defn)
            .expect("Hugr accepts any root node");

        Ok(wrap_root_in_module(hugr))
    } else if OpTag::DataflowChild.is_superset(op_tag) && !(op.is_input() || op.is_output()) {
        let Some(signature) = op.dataflow_signature() else {
            panic!("Dataflow child {} without signature", op.name())
        };

        // Build a fresh module with a "main" function whose body is the
        // original HUGR, wired from the function inputs to its outputs.
        let mut module_builder = ModuleBuilder::new();
        let mut main_fn = module_builder.define_function("main", signature).unwrap();
        let inserted = main_fn
            .add_hugr_with_wires(hugr, main_fn.input_wires())
            .unwrap();
        main_fn.finish_with_outputs(inserted.outputs()).unwrap();
        Ok(mem::take(module_builder.hugr_mut()))
    } else {
        Err(PackageError::CannotWrapHugr { root_op: op })
    }
}
/// Errors raised while assembling a [`Package`] from HUGRs.
#[derive(Debug, Display, Error, PartialEq)]
#[non_exhaustive]
pub enum PackageError {
/// A HUGR handed to `Package::new` has a root operation that is not a module.
#[display("Module {module_index} in the package does not have an OpType::Module root, but {}", root_op.name())]
NonModuleHugr {
/// Position of the offending HUGR in the list of modules.
module_index: usize,
/// Root operation that was found instead of a module.
root_op: OpType,
},
/// The root operation of the HUGR is one that `to_module_hugr` cannot wrap
/// in a module.
#[display("A hugr with optype {} cannot be wrapped in a module.", root_op.name())]
CannotWrapHugr {
/// Root operation of the HUGR that could not be wrapped.
root_op: OpType,
},
}
/// Errors raised while reading or writing the JSON encoding of a [`Package`].
///
/// The derived `From` lets each wrapped error type be propagated with `?`.
#[derive(Debug, Display, Error, From)]
#[non_exhaustive]
pub enum PackageEncodingError {
/// Error reported by `serde_json` while decoding or encoding.
JsonEncoding(serde_json::Error),
/// I/O error, e.g. raised when opening a package file.
IOError(io::Error),
/// The decoded HUGR could not be turned into a package.
Package(PackageError),
}
/// Errors raised by `Package::update_validate`.
///
/// `From` is derived only for the non-deprecated variants; the deprecated ones
/// opt out through `#[from(ignore)]`.
#[derive(Debug, From)]
#[non_exhaustive]
pub enum PackageValidationError {
/// Error raised while registering the package extensions in the registry.
Extension(ExtensionRegistryError),
/// Error raised while validating one of the package modules.
Validation(ValidationError),
/// Deprecated alias of [`PackageValidationError::Validation`].
#[from(ignore)]
#[deprecated(
since = "0.13.2",
note = "Replaced by `PackageValidationError::Validation`"
)]
Validate(ValidationError),
/// Deprecated alias of [`PackageValidationError::Extension`].
#[from(ignore)]
#[deprecated(
since = "0.13.2",
note = "Replaced by `PackageValidationError::Extension`"
)]
ExtReg(ExtensionRegistryError),
}
#[allow(deprecated)]
impl std::error::Error for PackageValidationError {
fn source(&self) -> Option<&(dyn derive_more::Error + 'static)> {
match self {
PackageValidationError::Extension(source) => Some(source),
PackageValidationError::Validation(source) => Some(source),
PackageValidationError::Validate(source) => Some(source),
PackageValidationError::ExtReg(source) => Some(source),
}
}
}
impl Display for PackageValidationError {
    /// Writes a short description of the failing step followed by the wrapped
    /// error.
    // The deprecated variants are still formatted, hence the lint exemption.
    #[allow(deprecated)]
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Extension(err) => write!(f, "Error processing extensions: {}", err),
            // Both the current and the deprecated validation variants share
            // one message.
            Self::Validation(err) | Self::Validate(err) => {
                write!(f, "Error validating HUGR: {}", err)
            }
            Self::ExtReg(err) => write!(f, "Error registering extension: {}", err),
        }
    }
}
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;

    use crate::builder::test::{
        simple_cfg_hugr, simple_dfg_hugr, simple_funcdef_hugr, simple_module_hugr,
    };
    use crate::extension::{ExtensionId, EMPTY_REG};
    use crate::ops::dataflow::IOTrait;
    use crate::ops::Input;

    use super::*;
    use rstest::{fixture, rstest};
    use semver::Version;

    /// Package with two simple module HUGRs and two empty extensions.
    #[fixture]
    fn simple_package() -> Package {
        let ext1 = Extension::new(ExtensionId::new("ext1").unwrap(), Version::new(2, 4, 8));
        let ext2 = Extension::new(ExtensionId::new("ext2").unwrap(), Version::new(1, 0, 0));
        Package {
            modules: vec![simple_module_hugr(), simple_module_hugr()],
            extensions: vec![ext1, ext2],
        }
    }

    /// HUGR whose root is an input node; used as a root that cannot be wrapped.
    #[fixture]
    fn simple_input_node() -> Hugr {
        Hugr::new(Input::new(vec![]))
    }

    /// Encoding a package to JSON and decoding it again yields an equal package.
    #[rstest]
    #[case::empty(Package::default())]
    #[case::simple(simple_package())]
    fn package_roundtrip(#[case] package: Package) {
        let encoded = package.to_json().unwrap();
        let decoded = Package::from_json(&encoded).unwrap();
        assert_eq!(package, decoded);
    }

    /// `Package::from_hugr` either wraps the HUGR in a module (checked against
    /// a snapshot) or fails, as dictated by `errors`.
    #[rstest]
    #[case::module("module", simple_module_hugr(), false)]
    #[case::funcdef("funcdef", simple_funcdef_hugr(), false)]
    #[case::dfg("dfg", simple_dfg_hugr(), false)]
    #[case::cfg("cfg", simple_cfg_hugr(), false)]
    #[case::unsupported_input("input", simple_input_node(), true)]
    fn hugr_to_package(#[case] test_name: &str, #[case] hugr: Hugr, #[case] errors: bool) {
        let outcome = Package::from_hugr(hugr);
        let package = match (&outcome, errors) {
            // Expected failure: nothing else to check.
            (Err(_), true) => return,
            (Ok(package), false) => package,
            _ => panic!("Unexpected result {:?}", outcome),
        };

        assert_eq!(package.modules.len(), 1);
        let wrapped = &package.modules[0];
        assert!(wrapped.get_optype(wrapped.root()).is_module());
        insta::assert_snapshot!(test_name, wrapped.mermaid_string());
    }

    /// `Package::new` rejects non-module HUGRs, whereas `Package::from_hugrs`
    /// wraps them and produces a package that validates.
    #[rstest]
    fn package_properties() {
        let module = simple_module_hugr();
        let dfg = simple_dfg_hugr();

        // The DFG sits at index 1 and is reported as the offending module.
        assert_matches!(
            Package::new([module.clone(), dfg.clone()], []),
            Err(PackageError::NonModuleHugr {
                module_index: 1,
                root_op: OpType::DFG(_),
            })
        );

        let mut wrapped_pkg = Package::from_hugrs([module, dfg], []).unwrap();
        let mut registry = EMPTY_REG.clone();
        wrapped_pkg.update_validate(&mut registry).unwrap();
        assert_eq!(wrapped_pkg.modules.len(), 2);
    }
}