use core::{f64, panic};
use std::io::BufReader;
use std::sync::Arc;
use itertools::Itertools;
use rstest::rstest;
use crate::builder::{
DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, HugrBuilder, ModuleBuilder,
};
use crate::envelope::EnvelopeConfig;
use crate::extension::prelude::{ConstUsize, bool_t, usize_custom_t, usize_t};
use crate::extension::resolution::WeakExtensionRegistry;
use crate::extension::resolution::{resolve_op_extensions, resolve_op_types_extensions};
use crate::extension::{
ExtensionId, ExtensionRegistry, ExtensionSet, PRELUDE, PRELUDE_REGISTRY, TypeDefBound,
};
use crate::ops::constant::CustomConst;
use crate::ops::constant::test::CustomTestValue;
use crate::ops::{CallIndirect, ExtensionOp, Input, NamedOp, OpType, Tag, Value};
use crate::package::Package;
use crate::std_extensions::arithmetic::conversions::{self, ConvertOpDef};
use crate::std_extensions::arithmetic::float_types::{self, ConstF64, float64_type};
use crate::std_extensions::arithmetic::int_ops;
use crate::std_extensions::arithmetic::int_types::{self, int_type};
use crate::std_extensions::collections::list::ListValue;
use crate::std_extensions::std_reg;
use crate::types::type_param::TypeParam;
use crate::types::{PolyFuncType, Signature, Type, TypeBound};
use crate::{Extension, Hugr, HugrView, type_row};
#[rstest]
#[case::empty(Input { types: type_row![]}, ExtensionRegistry::default())]
#[case::parametric_op(int_ops::IntOpDef::ieq.with_log_width(4),
ExtensionRegistry::new([int_ops::EXTENSION.to_owned(), int_types::EXTENSION.to_owned(), PRELUDE.to_owned()]
))]
fn collect_type_extensions(#[case] op: impl Into<OpType>, #[case] extensions: ExtensionRegistry) {
let op = op.into();
let resolved = op.used_extensions().unwrap();
assert_eq!(resolved, extensions);
}
#[rstest]
#[case::empty(Input { types: type_row![]}, ExtensionRegistry::default())]
#[case::parametric_op(int_ops::IntOpDef::ieq.with_log_width(4),
ExtensionRegistry::new([int_ops::EXTENSION.to_owned(), int_types::EXTENSION.to_owned(), PRELUDE.to_owned()]
))]
fn resolve_type_extensions(#[case] op: impl Into<OpType>, #[case] extensions: ExtensionRegistry) {
let op = op.into();
let ser = serde_json::to_string(&op).unwrap();
let mut deser_op: OpType = serde_json::from_str(&ser).unwrap();
let dummy_node = portgraph::NodeIndex::new(0).into();
resolve_op_extensions(dummy_node, &mut deser_op, &extensions).unwrap();
let weak_extensions: WeakExtensionRegistry = (&extensions).into();
resolve_op_types_extensions(Some(dummy_node), &mut deser_op, &weak_extensions)
.unwrap()
.for_each(|_| ());
let deser_extensions = deser_op.used_extensions().unwrap();
assert_eq!(
deser_extensions, extensions,
"{deser_extensions} != {extensions}"
);
}
fn make_extension(name: &str, op_name: &str) -> (Arc<Extension>, OpType) {
let ext = Extension::new_test_arc(ExtensionId::new_unchecked(name), |ext, extension_ref| {
ext.add_op(
op_name.into(),
String::new(),
Signature::new_endo([bool_t()]),
extension_ref,
)
.unwrap();
});
let op_def = ext.get_op(op_name).unwrap();
let op = ExtensionOp::new(op_def.clone(), vec![]).unwrap();
(ext, op.into())
}
fn make_extension_self_referencing(name: &str, op_name: &str, type_name: &str) -> Arc<Extension> {
Extension::new_test_arc(ExtensionId::new_unchecked(name), |ext, extension_ref| {
let type_def = ext
.add_type(
type_name.into(),
vec![],
String::new(),
TypeDefBound::any(),
extension_ref,
)
.unwrap();
let typ = type_def.instantiate([]).unwrap();
ext.add_op(
op_name.into(),
String::new(),
Signature::new(vec![typ.into()], vec![usize_t()]),
extension_ref,
)
.unwrap();
})
}
fn check_extension_resolution(mut hugr: Hugr) {
let build_extensions = hugr.extensions().clone();
let mut resolution_extensions = std_reg();
resolution_extensions.extend(&build_extensions);
let collected_exts = ExtensionRegistry::new(hugr.nodes().flat_map(|node| {
hugr.get_optype(node)
.used_extensions()
.unwrap_or_default()
.into_iter()
}));
assert_eq!(
collected_exts, build_extensions,
"{collected_exts} != {build_extensions}"
);
hugr.resolve_extension_defs(&resolution_extensions).unwrap();
assert_eq!(
hugr.extensions(),
&build_extensions,
"{} != {build_extensions}",
hugr.extensions()
);
let ser = hugr.store_str(EnvelopeConfig::text()).unwrap();
let deser_hugr = Hugr::load_str(&ser, Some(&resolution_extensions)).unwrap();
assert_eq!(
deser_hugr.extensions(),
&build_extensions,
"{} != {build_extensions}",
deser_hugr.extensions()
);
}
#[rstest]
fn resolve_hugr_extensions_simple() {
let mut build = DFGBuilder::new(Signature::new(vec![], vec![float64_type()])).unwrap();
let f_const = build.add_load_const(Value::extension(ConstF64::new(f64::consts::PI)));
let mut hugr = build
.finish_hugr_with_outputs([f_const])
.unwrap_or_else(|e| panic!("{e}"));
let build_extensions = hugr.extensions().clone();
let mut collected_exts = ExtensionRegistry::default();
for node in hugr.nodes() {
let op = hugr.get_optype(node);
collected_exts.extend(op.used_extensions().unwrap());
}
assert_eq!(
collected_exts, build_extensions,
"{collected_exts} != {build_extensions}"
);
hugr.resolve_extension_defs(&build_extensions).unwrap();
assert_eq!(
hugr.extensions(),
&build_extensions,
"{} != {build_extensions}",
hugr.extensions()
);
let ser = hugr.store_str(EnvelopeConfig::text()).unwrap();
let deser_hugr = Hugr::load_str(&ser, Some(&build_extensions)).unwrap();
assert_eq!(
deser_hugr.extensions(),
&build_extensions,
"{} != {build_extensions}",
hugr.extensions()
);
}
#[rstest]
fn resolve_hugr_extensions() {
let (ext_a, op_a) = make_extension("dummy.a", "op_a");
let (ext_b, op_b) = make_extension("dummy.b", "op_b");
let (ext_c, op_c) = make_extension("dummy.c", "op_c");
let (ext_d, op_d) = make_extension("dummy.d", "op_d");
let (_ext_e, op_e) = make_extension("dummy.e", "op_e");
let mut module = ModuleBuilder::new();
let decl = module
.declare(
"dummy_declaration",
Signature::new_endo([float64_type()]).into(),
)
.unwrap();
let mut func = module
.define_function(
"dummy_fn",
Signature::new(vec![float64_type(), bool_t()], vec![]),
)
.unwrap();
let [func_i0, func_i1] = func.input_wires_arr();
func.call(&decl, &[], vec![func_i0]).unwrap();
let loaded_func = func.load_func(&decl, &[]).unwrap();
func.add_dataflow_op(
CallIndirect {
signature: Signature::new_endo([float64_type()]),
},
vec![loaded_func, func_i0],
)
.unwrap();
func.add_dataflow_op(op_a, vec![func_i1]).unwrap();
let mut dfg = func.dfg_builder_endo([(bool_t(), func_i1)]).unwrap();
let dfg_inputs = dfg.input_wires().collect_vec();
dfg.add_dataflow_op(op_b, dfg_inputs.clone()).unwrap();
dfg.finish_with_outputs(dfg_inputs).unwrap();
func.add_dataflow_op(
Tag::new(0, vec![vec![bool_t()].into(), vec![int_type(4)].into()]),
vec![func_i1],
)
.unwrap();
let mut tail_loop = func
.tail_loop_builder([(bool_t(), func_i1)], [], vec![].into())
.unwrap();
let tl_inputs = tail_loop.input_wires().collect_vec();
tail_loop.add_dataflow_op(op_c, tl_inputs).unwrap();
let tl_tag = tail_loop.add_load_const(Value::true_val());
let tl_tag = tail_loop
.add_dataflow_op(
Tag::new(0, vec![vec![Type::new_unit_sum(2)].into(), vec![].into()]),
vec![tl_tag],
)
.unwrap()
.out_wire(0);
tail_loop.finish_with_outputs(tl_tag, vec![]).unwrap();
let cond_tag = func.add_load_const(Value::unary_unit_sum());
let mut cond = func
.conditional_builder(([type_row![]], cond_tag), [], type_row![])
.unwrap();
let mut case = cond.case_builder(0).unwrap();
case.add_dataflow_op(op_e, [func_i1]).unwrap();
case.finish_with_outputs([]).unwrap();
let mut cfg = func
.cfg_builder([(bool_t(), func_i1)], vec![].into())
.unwrap();
let mut cfg_entry = cfg.entry_builder([type_row![]], type_row![]).unwrap();
let [cfg_i0] = cfg_entry.input_wires_arr();
cfg_entry.add_dataflow_op(op_d, [cfg_i0]).unwrap();
let cfg_tag = cfg_entry.add_load_const(Value::unary_unit_sum());
let cfg_entry_wire = cfg_entry.finish_with_outputs(cfg_tag, []).unwrap();
let cfg_exit = cfg.exit_block();
cfg.branch(&cfg_entry_wire, 0, &cfg_exit).unwrap();
func.finish_with_outputs(vec![]).unwrap();
let hugr = module.finish_hugr().unwrap_or_else(|e| panic!("{e}"));
let build_extensions = hugr.extensions().clone();
assert!(build_extensions.contains(ext_a.name()));
assert!(build_extensions.contains(ext_b.name()));
assert!(build_extensions.contains(ext_c.name()));
assert!(build_extensions.contains(ext_d.name()));
check_extension_resolution(hugr);
}
#[rstest]
#[case::usize(ConstUsize::new(42))]
#[case::list(ListValue::new(
float64_type(),
[ConstF64::new(f64::consts::PI).into()],
))]
#[case::custom(CustomTestValue(usize_custom_t(
&Arc::downgrade(&PRELUDE),
)))]
fn resolve_custom_const(#[case] custom_const: impl CustomConst) {
let mut dfg_builder = DFGBuilder::new(Signature::new(vec![], vec![])).unwrap();
dfg_builder.add_load_const(Value::extension(custom_const));
let hugr = dfg_builder
.finish_hugr_with_outputs([])
.unwrap_or_else(|e| panic!("{e}"));
check_extension_resolution(hugr);
}
#[rstest]
fn resolve_call() {
let dummy_fn_sig = PolyFuncType::new(
vec![TypeParam::RuntimeType(TypeBound::Linear)],
Signature::new(vec![], vec![bool_t()]),
);
let generic_type_1 = float64_type().into();
let generic_type_2 = int_type(6).into();
let expected_exts = [
float_types::EXTENSION_ID.clone(),
int_types::EXTENSION_ID.clone(),
]
.into_iter()
.collect::<ExtensionSet>();
let mut module = ModuleBuilder::new();
let dummy_fn = module.declare("called_fn", dummy_fn_sig).unwrap();
let mut func = module
.define_function("caller_fn", Signature::new(vec![], vec![bool_t()]))
.unwrap();
let _load_func = func.load_func(&dummy_fn, &[generic_type_1]).unwrap();
let call = func.call(&dummy_fn, &[generic_type_2], vec![]).unwrap();
func.finish_with_outputs(call.outputs()).unwrap();
let hugr = module.finish_hugr().unwrap();
for ext in expected_exts {
assert!(hugr.extensions().contains(&ext));
}
check_extension_resolution(hugr);
}
#[rstest]
fn resolve_transitive_extension_deps() {
let mut build = DFGBuilder::new(Signature::new(vec![int_type(6)], vec![usize_t()])).unwrap();
let [input] = build.input_wires_arr();
let out = build
.add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [input])
.unwrap();
let hugr = build
.finish_hugr_with_outputs(out.outputs())
.unwrap_or_else(|e| panic!("{e}"));
assert!(hugr.extensions().contains(&conversions::EXTENSION_ID));
assert!(hugr.extensions().contains(&float_types::EXTENSION_ID));
check_extension_resolution(hugr);
}
#[rstest]
fn resolve_lower_func_extensions() {
let (inner_ext, inner_op) = make_extension("inner.extension", "InnerOp");
let mut dfg = DFGBuilder::new(Signature::new_endo(vec![bool_t()])).unwrap();
let [input] = dfg.input_wires_arr();
let inner_result = dfg.add_dataflow_op(inner_op.clone(), [input]).unwrap();
let lower_hugr = dfg
.finish_hugr_with_outputs(inner_result.outputs())
.unwrap();
let mut lower_pkg = Package::from_hugr(lower_hugr);
lower_pkg.extensions.register_updated(inner_ext.clone());
let inner_ext_name = inner_ext.name().clone();
let outer_ext = Extension::new_test_arc(
ExtensionId::new_unchecked("outer.extension"),
|ext, ext_ref| {
let op_def = ext
.add_op(
"OuterOp".into(),
String::new(),
Signature::new_endo(vec![bool_t()]),
ext_ref,
)
.unwrap();
op_def.add_lower_func(crate::extension::op_def::LowerFunc::FixedHugr {
extensions: ExtensionSet::singleton(inner_ext_name),
pkg: Box::new(lower_pkg),
});
},
);
let mut outer_pkg = Package::default();
outer_pkg.extensions.register_updated(outer_ext.clone());
let mut buffer = Vec::new();
outer_pkg
.store(&mut buffer, EnvelopeConfig::binary())
.expect("serialize outer package");
let loaded_pkg =
Package::load(BufReader::new(buffer.as_slice()), None).expect("load outer package");
let loaded_outer_ext = loaded_pkg
.extensions
.get(outer_ext.name())
.expect("outer extension missing after load");
let loaded_op_def = loaded_outer_ext
.get_op("OuterOp")
.expect("OuterOp missing after load");
assert_eq!(
loaded_op_def.lower_funcs.len(),
1,
"expected 1 lower func after load"
);
let available = ExtensionSet::singleton(inner_ext.name().clone());
let lower_hugr = loaded_op_def
.try_lower(&[], &available)
.expect("try_lower should succeed when inner extension is available");
for node in lower_hugr.nodes() {
let op = lower_hugr.get_optype(node);
assert!(
!matches!(op, crate::ops::OpType::OpaqueOp(_)),
"Op {op:?} on {node} is an unresolved OpaqueOp after loading"
);
}
lower_hugr.nodes().any(|node| {
let op = lower_hugr.get_optype(node);
let crate::ops::OpType::ExtensionOp(op) = op else {
return false;
};
op.extension_id() == inner_ext.name() && op.name() == inner_op.name()
});
}
#[test]
fn register_new_cyclic() {
let ext_id = ExtensionId::new("ext").unwrap();
let ext = make_extension_self_referencing(&ext_id, "my_op", "my_type");
let reg = ExtensionRegistry::new([ext]);
let ser = serde_json::to_string(®).unwrap();
let new_reg = ExtensionRegistry::load_json(ser.as_bytes(), &PRELUDE_REGISTRY).unwrap();
assert!(new_reg.contains(&ext_id));
new_reg.validate().unwrap();
}