#![allow(missing_docs)]
use anyhow::Result;
use rstest::{fixture, rstest};
use std::str::FromStr;
use std::sync::Arc;
use hugr::{
Extension, Hugr,
builder::{Dataflow as _, DataflowHugr as _},
envelope::{EnvelopeConfig, EnvelopeFormat, read_envelope, write_envelope},
extension::{TypeDefBound, Version, prelude::bool_t},
package::Package,
std_extensions::std_reg,
types::Signature,
};
use hugr_core::{export::export_package, import::import_package};
use hugr_model::v0 as model;
/// Round-trips the model text `source` using the standard extension registry.
///
/// Convenience wrapper that calls `roundtrip_with_reg` with `std_reg()`; see that
/// function for the individual stages and error behaviour.
fn roundtrip(source: &str) -> Result<String> {
    let registry = std_reg();
    roundtrip_with_reg(source, &registry)
}
/// Parses model text, imports it into core HUGRs, exports it again and prints the result.
///
/// Stages:
/// 1. parse `source` into a model AST and resolve it into a bump-allocated table;
/// 2. import the table as a core package using `reg`, then resolve each module's
///    extension definitions against `reg`;
/// 3. export the modules and extensions back to a model table and print its AST.
///
/// # Errors
/// Propagates any error from parsing, resolving, importing or extension resolution.
///
/// # Panics
/// Panics if the exported table cannot be converted back into an AST (`as_ast()`
/// is unwrapped).
fn roundtrip_with_reg(source: &str, reg: &hugr::extension::ExtensionRegistry) -> Result<String> {
    // Arena backing both the parsed and the exported model tables.
    let bump = model::bumpalo::Bump::new();

    // Text -> AST -> resolved table.
    let parsed = model::ast::Package::from_str(source)?;
    let table = parsed.resolve(&bump)?;

    // Table -> core package, with extension references resolved against `reg`.
    let mut imported = import_package(&table, Default::default(), reg)?;
    for hugr_module in imported.modules.iter_mut() {
        hugr_module.resolve_extension_defs(reg)?;
    }

    // Core package -> table -> AST -> text.
    let exported = export_package(&imported.modules, &imported.extensions, &bump);
    let printed = exported.as_ast().unwrap().to_string();
    Ok(printed)
}
/// Fixture: the `model.versioned` extension at version `0.2.3`.
///
/// It declares a copyable type `MyType` with no type parameters and an op `my_op`
/// whose signature has no inputs and no outputs. Both are registered with an empty
/// description.
#[fixture]
fn versioned_extension() -> Arc<Extension> {
    Extension::new_arc(
        "model.versioned".try_into().unwrap(),
        Version::new(0, 2, 3),
        |ext, ext_ref| {
            let type_params = vec![];
            ext.add_type(
                "MyType".into(),
                type_params,
                String::new(),
                TypeDefBound::copyable(),
                ext_ref,
            )
            .unwrap();

            let op_signature = Signature::new(vec![], vec![]);
            ext.add_op("my_op".into(), String::new(), op_signature, ext_ref)
                .unwrap();
        },
    )
}
/// Builds a small model module that imports and uses symbols of the
/// `model.versioned` extension.
///
/// `symbol_suffix` is appended verbatim to both imported symbol names
/// (`model.versioned.my_op` and `model.versioned.MyType`) in the `import` lines.
/// Callers pass either an explicit version suffix such as `"@0.2.3"` or `""`.
/// Uses of the symbols in the body (`typed`'s signature and the `my_op` node)
/// are written without the suffix.
fn versioned_source(symbol_suffix: &str) -> String {
format!(
r#"(hugr 0)
(mod)
(import model.versioned.my_op{symbol_suffix})
(import model.versioned.MyType{symbol_suffix})
(declare-func public typed (core.fn [] [model.versioned.MyType]))
(define-func public main (core.fn [] [])
(dfg [] []
(signature (core.fn [] []))
((model.versioned.my_op) [] []
(signature (core.fn [] [])))))
"#
)
}
/// Loads a textual model fixture as a [`Package`] and validates it.
///
/// The fixture text is prefixed with the hard-coded envelope header `HUGRiHJv(@`
/// so that `Package::load_str` accepts it. No extra extension registry is
/// supplied (`None`).
///
/// NOTE(review): `HUGRiHJv` looks like the envelope magic, followed by a format
/// byte (`(`, presumably the s-expression format) and a flags byte (`@`,
/// presumably no compression). Confirm against the envelope header definition.
///
/// # Errors
/// Returns an error if loading or validation fails.
fn validate_fixture(source: String) -> Result<()> {
    let envelope = format!("HUGRiHJv(@{source}");
    let package = Package::load_str(envelope, None)?;
    package.validate()?;
    Ok(())
}
/// Defines a snapshot test that round-trips a model fixture through `roundtrip`.
///
/// * `$name` is the generated test function's name, which insta also uses for the
///   snapshot name.
/// * `$file` is a path passed to `include_str!`, so it is resolved relative to this
///   source file at compile time.
///
/// A roundtrip error panics with its `Debug` representation. The test is skipped
/// under Miri.
macro_rules! test_roundtrip {
    ($name: ident, $file: expr) => {
        #[test]
        #[cfg_attr(miri, ignore)]
        pub fn $name() {
            let ast = match roundtrip(include_str!($file)) {
                Ok(printed) => printed,
                Err(err) => panic!("{:?}", err),
            };
            insta::assert_snapshot!(ast)
        }
    };
}
// One roundtrip snapshot test per model fixture. Each path is handed to
// `include_str!`, so it is relative to this test file.
test_roundtrip! {
    test_roundtrip_add,
    "../../hugr-model/tests/fixtures/model-add.edn"
}
test_roundtrip! {
    test_roundtrip_call,
    "../../hugr-model/tests/fixtures/model-call.edn"
}
test_roundtrip! {
    test_roundtrip_alias,
    "../../hugr-model/tests/fixtures/model-alias.edn"
}
test_roundtrip! {
    test_roundtrip_cfg,
    "../../hugr-model/tests/fixtures/model-cfg.edn"
}
test_roundtrip! {
    test_roundtrip_cond,
    "../../hugr-model/tests/fixtures/model-cond.edn"
}
test_roundtrip! {
    test_roundtrip_loop,
    "../../hugr-model/tests/fixtures/model-loop.edn"
}
test_roundtrip! {
    test_roundtrip_params,
    "../../hugr-model/tests/fixtures/model-params.edn"
}
test_roundtrip! {
    test_roundtrip_constraints,
    "../../hugr-model/tests/fixtures/model-constraints.edn"
}
test_roundtrip! {
    test_roundtrip_const,
    "../../hugr-model/tests/fixtures/model-const.edn"
}
test_roundtrip! {
    test_roundtrip_order,
    "../../hugr-model/tests/fixtures/model-order.edn"
}
test_roundtrip! {
    test_roundtrip_entrypoint,
    "../../hugr-model/tests/fixtures/model-entrypoint.edn"
}
/// Fixture: a dataflow HUGR with one `bool_t` input wired straight to its single
/// `bool_t` output (an identity DFG).
#[fixture]
fn simple_dfg_hugr() -> Hugr {
    let signature = Signature::new(vec![bool_t()], vec![bool_t()]);
    let builder = hugr::builder::DFGBuilder::new(signature).unwrap();
    let [input_wire] = builder.input_wires_arr();
    builder.finish_hugr_with_outputs([input_wire]).unwrap()
}
#[rstest]
#[case(EnvelopeFormat::SExpressionWithExtensions)]
#[case(EnvelopeFormat::ModelWithExtensions)]
fn import_package_with_extensions(#[case] format: EnvelopeFormat, simple_dfg_hugr: Hugr) {
let ext = Extension::new_arc(
"miniquantum".try_into().unwrap(),
hugr::extension::Version::new(0, 1, 0),
|_, _| {},
);
let mut package = Package::new([simple_dfg_hugr]);
package.extensions.register(ext);
let mut bytes: Vec<u8> = Vec::new();
write_envelope(&mut bytes, &package, EnvelopeConfig::new(format)).unwrap();
let buff = std::io::BufReader::new(bytes.as_slice());
let (_, loaded_pkg) = read_envelope(buff, &std_reg()).unwrap();
assert_eq!(loaded_pkg.extensions.len(), 1);
let read_ext = loaded_pkg.extensions.iter_all().next().unwrap();
assert_eq!(read_ext.name(), &"miniquantum".try_into().unwrap());
assert_eq!(package, loaded_pkg);
}
/// Checks that the exported model always refers to `model.versioned` symbols with
/// the registered extension version `@0.2.3`.
///
/// This holds whether the source imports the symbols with an explicit `@0.2.3`
/// suffix or with no suffix at all.
#[rstest]
#[case("@0.2.3")]
#[case("")]
fn core_versions(#[case] symbol_suffix: &str, versioned_extension: Arc<Extension>) {
    let mut reg = std_reg();
    reg.register(versioned_extension);
    // Pass the registry by reference (the original token here was corrupted to `®`).
    let ast = roundtrip_with_reg(&versioned_source(symbol_suffix), &reg).unwrap();
    assert!(
        ast.contains("model.versioned.MyType@0.2.3"),
        "exported model lacks versioned `MyType` symbol:\n{ast}"
    );
    assert!(
        ast.contains("model.versioned.my_op@0.2.3"),
        "exported model lacks versioned `my_op` symbol:\n{ast}"
    );
}
/// Validates every `.edn` model fixture matched by the glob below.
///
/// rstest generates one test case per matching file and passes its contents as
/// `&str` (`#[mode = str]`). Each case is skipped under Miri.
#[rstest]
#[cfg_attr(miri, ignore)]
pub fn test_fixtures_validate(
    #[files("../hugr-model/tests/fixtures/*.edn")]
    #[mode = str]
    source: &str,
) -> Result<()> {
    let owned_source = source.to_owned();
    validate_fixture(owned_source)
}