pub mod hugrmut;
pub(crate) mod ident;
pub mod internal;
pub mod rewrite;
pub mod serialize;
pub mod validate;
pub mod views;
use std::collections::VecDeque;
use std::iter;
pub(crate) use self::hugrmut::HugrMut;
pub use self::validate::ValidationError;
pub use ident::{IdentList, InvalidIdentifier};
pub use rewrite::{Rewrite, SimpleReplacement, SimpleReplacementError};
use portgraph::multiportgraph::MultiPortGraph;
use portgraph::{Hierarchy, PortMut, UnmanagedDenseMap};
use thiserror::Error;
pub use self::views::{HugrView, RootTagged};
use crate::core::NodeIndex;
use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED};
use crate::ops::custom::resolve_extension_ops;
use crate::ops::{OpTag, OpTrait};
pub use crate::ops::{OpType, DEFAULT_OPTYPE};
use crate::{Direction, Node};
/// A hierarchical graph of operations.
///
/// Nodes and their ports/edges live in a [`MultiPortGraph`]; parent/child nesting is tracked
/// separately in a [`Hierarchy`]. Per-node data (the operation and optional metadata) is kept
/// in dense maps keyed by the graph's node index.
#[derive(Clone, Debug, PartialEq)]
pub struct Hugr {
    /// The underlying graph of nodes, ports and edges.
    graph: MultiPortGraph,
    /// Parent/child relationships between nodes.
    hierarchy: Hierarchy,
    /// Index of the root node.
    root: portgraph::NodeIndex,
    /// The operation of each node, keyed by node index.
    op_types: UnmanagedDenseMap<portgraph::NodeIndex, OpType>,
    /// Optional metadata attached to each node, keyed by node index.
    metadata: UnmanagedDenseMap<portgraph::NodeIndex, Option<NodeMetadataMap>>,
}
impl Default for Hugr {
fn default() -> Self {
Self::new(crate::ops::Module::new())
}
}
/// Identity borrow, so that a `Hugr` can be passed where an `impl AsRef<Hugr>` is expected.
impl AsRef<Hugr> for Hugr {
    fn as_ref(&self) -> &Self {
        self
    }
}
/// Identity mutable borrow, so that a `Hugr` can be passed where an `impl AsMut<Hugr>` is
/// expected.
impl AsMut<Hugr> for Hugr {
    fn as_mut(&mut self) -> &mut Self {
        self
    }
}
/// A single metadata value attached to a node: an arbitrary JSON value.
pub type NodeMetadata = serde_json::Value;
/// All metadata of a single node: a JSON object mapping string keys to [`NodeMetadata`] values.
pub type NodeMetadataMap = serde_json::Map<String, NodeMetadata>;
/// Public construction, validation and extension-inference entry points for [`Hugr`].
impl Hugr {
    /// Creates a new Hugr containing only a root node with the given operation.
    ///
    /// No node or port storage is preallocated (capacities of 0 are requested).
    pub fn new(root_node: impl Into<OpType>) -> Self {
        Self::with_capacity(root_node.into(), 0, 0)
    }
    /// Resolves extension operations against `extension_registry`, then validates the Hugr.
    ///
    /// Steps, in order:
    /// 1. `resolve_extension_ops(self, extension_registry)`, which is given mutable access to
    ///    the Hugr (presumably to replace unresolved ops with resolved ones — confirm).
    /// 2. `validate_no_extensions(extension_registry)`.
    /// 3. Only with the `extension_inference` feature: `infer_extensions(false)` followed by
    ///    `validate_extensions()`.
    ///
    /// # Errors
    /// Returns the first error produced by any step, converted into a [`ValidationError`].
    pub fn update_validate(
        &mut self,
        extension_registry: &ExtensionRegistry,
    ) -> Result<(), ValidationError> {
        resolve_extension_ops(self, extension_registry)?;
        self.validate_no_extensions(extension_registry)?;
        #[cfg(feature = "extension_inference")]
        {
            // Inference here never removes explicitly-listed extensions (`remove = false`).
            self.infer_extensions(false)?;
            self.validate_extensions()?;
        }
        Ok(())
    }
    /// Infers and/or narrows the extension deltas of container operations, processing the
    /// hierarchy bottom-up (children before parents) starting from the root.
    ///
    /// Only the deltas of `DFG`, `DataflowBlock`, `TailLoop`, `CFG`, `Conditional` and `Case`
    /// nodes are ever modified. Once a node's children have been processed:
    /// - if its delta contains [`TO_BE_INFERRED`], the delta is replaced by the union of its
    ///   other (explicit) elements and all of its children's sets, minus `TO_BE_INFERRED`;
    /// - otherwise, if `remove` is true, the delta must already include every child's set,
    ///   and is then narrowed to exactly the union of the children's sets;
    /// - otherwise the delta is left unchanged.
    ///
    /// Any other node reports its operation's `extension_delta()` to its parent, unchanged.
    /// Its children are still visited, so containers nested below it are also processed.
    ///
    /// # Errors
    /// Returns an [`ExtensionError`] when `remove` is true and a node's fixed (non-inferred)
    /// delta does not include all extensions of one of its children. Nodes processed before
    /// the failing node keep their already-updated deltas; nothing is rolled back.
    pub fn infer_extensions(&mut self, remove: bool) -> Result<(), ExtensionError> {
        /// Mutable access to the extension delta of the container operations whose delta can
        /// be inferred; `None` for every other operation.
        fn delta_mut(optype: &mut OpType) -> Option<&mut ExtensionSet> {
            match optype {
                OpType::DFG(dfg) => Some(&mut dfg.signature.extension_reqs),
                OpType::DataflowBlock(dfb) => Some(&mut dfb.extension_delta),
                OpType::TailLoop(tl) => Some(&mut tl.extension_delta),
                OpType::CFG(cfg) => Some(&mut cfg.signature.extension_reqs),
                OpType::Conditional(c) => Some(&mut c.extension_delta),
                OpType::Case(c) => Some(&mut c.signature.extension_reqs),
                _ => None,
            }
        }
        /// Processes the subtree under `node` (children first) and returns the extension set
        /// that `node` contributes to its parent. Recursion depth equals hierarchy depth.
        fn infer(h: &mut Hugr, node: Node, remove: bool) -> Result<ExtensionSet, ExtensionError> {
            // Materialise the child list first: the `children` iterator borrows `h`, which
            // must be free for the mutable recursive call.
            let mut child_sets = h
                .children(node)
                .collect::<Vec<_>>() .into_iter()
                .map(|ch| Ok((ch, infer(h, ch, remove)?)))
                .collect::<Result<Vec<_>, _>>()?;
            // Non-container nodes are never modified; report their operation's own delta.
            let Some(es) = delta_mut(h.op_types.get_mut(node.pg_index())) else {
                return Ok(h.get_optype(node).extension_delta());
            };
            // Case 1: the delta is (partly) to be inferred. Its explicit elements are a lower
            // bound, so they are merged in alongside the children's sets (the `node` paired
            // with them is never read).
            // Case 2: fixed delta and `remove` is set. Check that it covers every child, then
            // fall through and shrink it to the children's union.
            // Case 3: fixed delta, not removing. Nothing can change.
            if es.contains(&TO_BE_INFERRED) {
                child_sets.push((node, es.clone())); } else if remove {
                child_sets.iter().try_for_each(|(ch, ch_exts)| {
                    if !es.is_superset(ch_exts) {
                        return Err(ExtensionError {
                            parent: node,
                            parent_extensions: es.clone(),
                            child: *ch,
                            child_extensions: ch_exts.clone(),
                        });
                    }
                    Ok(())
                })?;
            } else {
                return Ok(es.clone()); }
            // The new delta is the union of the collected sets with the `TO_BE_INFERRED`
            // marker dropped (i.e. `merged` minus `{TO_BE_INFERRED}`).
            let merged = ExtensionSet::union_over(child_sets.into_iter().map(|(_, e)| e));
            *es = ExtensionSet::singleton(&TO_BE_INFERRED).missing_from(&merged);
            Ok(es.clone())
        }
        infer(self, self.root(), remove)?;
        Ok(())
    }
}
/// Internal construction and node-reordering helpers for [`Hugr`].
impl Hugr {
    /// Creates a Hugr containing only a root node with operation `root_node`, reserving
    /// storage for `nodes` nodes and `ports` ports.
    pub(crate) fn with_capacity(root_node: OpType, nodes: usize, ports: usize) -> Self {
        let mut graph = MultiPortGraph::with_capacity(nodes, ports);
        let hierarchy = Hierarchy::new();
        let mut op_types = UnmanagedDenseMap::with_capacity(nodes);
        // NOTE(review): the root is created with no ports, unlike `add_node`, which sizes
        // ports from the operation's input/output counts. Presumably this is because the
        // root never carries edges; confirm before relying on the root's port counts.
        let root = graph.add_node(0, 0);
        op_types[root] = root_node;
        Self {
            graph,
            hierarchy,
            root,
            op_types,
            metadata: UnmanagedDenseMap::with_capacity(nodes),
        }
    }

    /// Makes `root` the root node of the Hugr.
    ///
    /// The previous root is detached from its parent in the hierarchy (if it has one). The
    /// hierarchy position of the new `root` is not changed here; presumably callers ensure it
    /// has no parent.
    pub(crate) fn set_root(&mut self, root: Node) {
        self.hierarchy.detach(self.root);
        self.root = root.pg_index();
    }

    /// Adds a new node running `nodetype`, with as many input and output ports as the
    /// operation reports.
    ///
    /// The node is not attached to any parent in the hierarchy; that is left to the caller.
    pub(crate) fn add_node(&mut self, nodetype: OpType) -> Node {
        let node = self
            .graph
            .add_node(nodetype.input_count(), nodetype.output_count());
        self.op_types[node] = nodetype;
        node.into()
    }

    /// Iterates over the subtree rooted at `root` in breadth-first order, starting with
    /// `root` and visiting each node's children in the order `children` yields them.
    fn canonical_order(&self, root: Node) -> impl Iterator<Item = Node> + '_ {
        let mut queue = VecDeque::from([root]);
        iter::from_fn(move || {
            let node = queue.pop_front()?;
            for child in self.children(node) {
                queue.push_back(child);
            }
            Some(node)
        })
    }

    /// Renumbers the nodes so that their indices follow the breadth-first
    /// [`canonical_order`](Self::canonical_order) of the hierarchy from the root.
    ///
    /// Afterwards the root is at index 0, and the nodes below it occupy the consecutive
    /// indices that follow. The graph, the operations, the hierarchy and the node metadata
    /// are all permuted together, so every node keeps its operation and metadata.
    ///
    /// `rekey(a, b)` is called after each swap of the nodes at indices `a` and `b`. Here `b`
    /// is the final (canonical) index of the node previously at `a`, and the node previously
    /// at `b` now sits at `a`.
    ///
    /// Finally, the underlying graph's node storage is compacted.
    pub fn canonicalize_nodes(&mut self, mut rekey: impl FnMut(Node, Node)) {
        let mut ordered = Vec::with_capacity(self.node_count());
        let root = self.root();
        ordered.extend(self.canonical_order(root));

        // Invariant: every index before `position` already holds its canonical node.
        for position in 0..ordered.len() {
            // `ordered` records each node's *original* index. If that index is below
            // `position`, the node was displaced by an earlier swap. Follow the chain of
            // swaps to find where it currently is.
            let mut source: Node = ordered[position];
            while position > source.index() {
                source = ordered[source.index()];
            }
            let target: Node = portgraph::NodeIndex::new(position).into();
            if target != source {
                let pg_target = target.pg_index();
                let pg_source = source.pg_index();
                self.graph.swap_nodes(pg_target, pg_source);
                self.op_types.swap(pg_target, pg_source);
                self.hierarchy.swap_nodes(pg_target, pg_source);
                // Metadata is keyed by node index exactly like `op_types`. Without this swap
                // it would stay at the old index and end up attached to a different node.
                self.metadata.swap(pg_target, pg_source);
                rekey(source, target);
            }
        }
        self.root = portgraph::NodeIndex::new(0);

        // The operation nodes already occupy the leading indices. Presumably this only
        // packs the graph's internal (copy) nodes and free slots; confirm against portgraph.
        self.graph.compact_nodes(|_, _| {});
    }
}
/// Error from [`Hugr::infer_extensions`]: a parent's fixed extension delta does not include
/// all the extensions required by one of its children.
#[derive(Debug, Clone, PartialEq, Error)]
#[error("Parent node {parent} has extensions {parent_extensions} that are too restrictive for child node {child}, they must include child extensions {child_extensions}")]
pub struct ExtensionError {
    /// The container node whose delta is too small.
    parent: Node,
    /// The parent's extension delta at the time of the check.
    parent_extensions: ExtensionSet,
    /// The child whose requirements are not covered.
    child: Node,
    /// The extensions the child requires.
    child_extensions: ExtensionSet,
}
/// Errors that can occur while operating on a [`Hugr`].
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum HugrError {
    /// An operation tag (`actual`) did not fall within the `required` tag.
    #[error("Invalid tag: required a tag in {required} but found {actual}")]
    #[allow(missing_docs)]
    InvalidTag { required: OpTag, actual: OpTag },
    /// A port was given with a direction that is not valid in this context.
    #[error("Invalid port direction {0:?}.")]
    InvalidPortDirection(Direction),
}
#[cfg(test)]
mod test {
    use std::{fs::File, io::BufReader};
    use super::internal::HugrMutInternals;
    #[cfg(feature = "extension_inference")]
    use super::ValidationError;
    use super::{ExtensionError, Hugr, HugrMut, HugrView, Node};
    use crate::extension::{
        ExtensionId, ExtensionSet, EMPTY_REG, PRELUDE_REGISTRY, TO_BE_INFERRED,
    };
    use crate::types::{FunctionType, Type};
    use crate::{const_extension_ids, ops, test_file, type_row};
    use rstest::rstest;
    // Compile-time check: `impl Test for Hugr` only type-checks if `Hugr: Send + Sync`.
    #[test]
    fn impls_send_and_sync() {
        #[allow(dead_code)]
        trait Test: Send + Sync {}
        impl Test for Hugr {}
    }
    // `get_io` on the root of the simple DFG fixture must find something.
    #[test]
    fn io_node() {
        use crate::builder::test::simple_dfg_hugr;
        use cool_asserts::assert_matches;
        let hugr = simple_dfg_hugr();
        assert_matches!(hugr.get_io(hugr.root()), Some(_));
    }
    // The hugr_validation_* tests load serialized fixtures from disk. They are ignored under
    // miri, presumably because its isolation mode forbids file access.
    // NOTE(review): hugr_validation_0 is `#[should_panic]`, so one of the unwraps or the
    // assert below is currently expected to fail. The reason is not visible here — confirm.
    #[test]
    #[cfg_attr(miri, ignore)] #[should_panic] fn hugr_validation_0() {
        let mut hugr: Hugr = serde_json::from_reader(BufReader::new(
            File::open(test_file!("hugr-0.json")).unwrap(),
        ))
        .unwrap();
        assert!(
            hugr.update_validate(&PRELUDE_REGISTRY).is_err(),
            "HUGR should not validate."
        );
    }
    // hugr-1.json must validate.
    #[test]
    #[cfg_attr(miri, ignore)] fn hugr_validation_1() {
        let mut hugr: Hugr = serde_json::from_reader(BufReader::new(
            File::open(test_file!("hugr-1.json")).unwrap(),
        ))
        .unwrap();
        assert!(hugr.update_validate(&PRELUDE_REGISTRY).is_ok());
    }
    // hugr-2.json must fail validation.
    #[test]
    #[cfg_attr(miri, ignore)] fn hugr_validation_2() {
        let mut hugr: Hugr = serde_json::from_reader(BufReader::new(
            File::open(test_file!("hugr-2.json")).unwrap(),
        ))
        .unwrap();
        assert!(
            hugr.update_validate(&PRELUDE_REGISTRY).is_err(),
            "HUGR should not validate."
        );
    }
    // hugr-3.json must validate.
    #[test]
    #[cfg_attr(miri, ignore)] fn hugr_validation_3() {
        let mut hugr: Hugr = serde_json::from_reader(BufReader::new(
            File::open(test_file!("hugr-3.json")).unwrap(),
        ))
        .unwrap();
        assert!(hugr.update_validate(&PRELUDE_REGISTRY).is_ok());
    }
    // Two dummy extension ids used by the inference tests.
    const_extension_ids! {
        const XA: ExtensionId = "EXT_A";
        const XB: ExtensionId = "EXT_B";
    }
    // A DFG root with delta `parent ∪ {TO_BE_INFERRED}` containing a single Lift to XA.
    // Whether or not `remove` is set, inference must produce exactly `result`: the explicit
    // elements are kept as a lower bound, and XA is added.
    #[rstest]
    #[case([], XA.into())]
    #[case([XA], XA.into())]
    #[case([XB], ExtensionSet::from_iter([XA, XB]))]
    fn infer_single_delta(
        #[case] parent: impl IntoIterator<Item = ExtensionId>,
        #[values(true, false)] remove: bool, #[case] result: ExtensionSet,
    ) {
        let parent = ExtensionSet::from_iter(parent).union(TO_BE_INFERRED.into());
        let (mut h, _) = build_ext_dfg(parent);
        h.infer_extensions(remove).unwrap();
        assert_eq!(h, build_ext_dfg(result).0);
    }
    // A fixed (non-inferred) delta {XA, XB} is left alone when `remove` is false, and
    // narrowed to the {XA} actually required by the Lift when `remove` is true.
    #[test]
    fn infer_removes_from_delta() {
        let parent = ExtensionSet::from_iter([XA, XB]);
        let mut h = build_ext_dfg(parent.clone()).0;
        let backup = h.clone();
        h.infer_extensions(false).unwrap();
        assert_eq!(h, backup); h.infer_extensions(true).unwrap();
        assert_eq!(h, build_ext_dfg(XA.into()).0);
    }
    // A fixed delta {XB} is too small for the XA Lift. Without `remove`, inference leaves it
    // unchanged. Validation reports the mismatch only with the `extension_inference` feature,
    // and inference with `remove` reports it as an ExtensionError.
    #[test]
    fn infer_bad_remove() {
        let (mut h, mid) = build_ext_dfg(XB.into());
        let backup = h.clone();
        h.infer_extensions(false).unwrap();
        assert_eq!(h, backup); let val_res = h.validate(&EMPTY_REG);
        let expected_err = ExtensionError {
            parent: h.root(),
            parent_extensions: XB.into(),
            child: mid,
            child_extensions: XA.into(),
        };
        #[cfg(feature = "extension_inference")]
        assert_eq!(
            val_res,
            Err(ValidationError::ExtensionError(expected_err.clone()))
        );
        #[cfg(not(feature = "extension_inference"))]
        assert!(val_res.is_ok());
        let inf_res = h.infer_extensions(true);
        assert_eq!(inf_res, Err(expected_err));
    }
    // Builds a Hugr whose root is a DFG with extension delta `parent`, containing
    // Input -> Lift(XA) -> Output. Returns the Hugr and the Lift node.
    fn build_ext_dfg(parent: ExtensionSet) -> (Hugr, Node) {
        // The value type flowing through the Lift (a `[] -> []` function type); presumably
        // any type would do.
        let ty = Type::new_function(FunctionType::new_endo(type_row![]));
        let mut h = Hugr::new(ops::DFG {
            signature: FunctionType::new_endo(ty.clone()).with_extension_delta(parent.clone()),
        });
        let root = h.root();
        let mid = add_inliftout(&mut h, root, ty);
        (h, mid)
    }
    // Adds Input, Output and a Lift to XA as children of `p`, wiring Input -> Lift -> Output
    // on port 0. Returns the Lift node.
    fn add_inliftout(h: &mut Hugr, p: Node, ty: Type) -> Node {
        let inp = h.add_node_with_parent(
            p,
            ops::Input {
                types: ty.clone().into(),
            },
        );
        let out = h.add_node_with_parent(
            p,
            ops::Output {
                types: ty.clone().into(),
            },
        );
        let mid = h.add_node_with_parent(
            p,
            ops::Lift {
                type_row: ty.into(),
                new_extension: XA,
            },
        );
        h.connect(inp, 0, mid, 0);
        h.connect(mid, 0, out, 0);
        mid
    }
    // Conditional (delta `grandparent`) > Case (delta `parent`) > Input/Lift(XA)/Output.
    // Columns: grandparent delta, parent delta, whether `infer_extensions(true)` succeeds,
    // and the expected result. On success `result` is the final delta of both Case and
    // Conditional. On failure it is the Case's inferred set named in the error against
    // the Conditional.
    #[rstest]
    #[case([XA], [TO_BE_INFERRED], true, [XA])]
    #[case([XA, XB], [TO_BE_INFERRED], true, [XA])]
    #[case([XB], [TO_BE_INFERRED], false, [XA])]
    #[case([XB], [XA, TO_BE_INFERRED], false, [XA])]
    #[case([XA], [XB, TO_BE_INFERRED], false, [XA, XB])]
    #[case([XA, XB], [XB, TO_BE_INFERRED], true, [XA, XB])]
    #[case([TO_BE_INFERRED], [TO_BE_INFERRED, XB], true, [XA, XB])]
    #[case([XA], [XA, XB], true, [XA])]
    fn infer_three_generations(
        #[case] grandparent: impl IntoIterator<Item = ExtensionId>,
        #[case] parent: impl IntoIterator<Item = ExtensionId>,
        #[case] success: bool,
        #[case] result: impl IntoIterator<Item = ExtensionId>,
    ) {
        let ty = Type::new_function(FunctionType::new_endo(type_row![]));
        let grandparent = ExtensionSet::from_iter(grandparent);
        let result = ExtensionSet::from_iter(result);
        let root_ty = ops::Conditional {
            sum_rows: vec![type_row![]],
            other_inputs: ty.clone().into(),
            outputs: ty.clone().into(),
            extension_delta: grandparent.clone(),
        };
        let mut h = Hugr::new(root_ty.clone());
        let p = h.add_node_with_parent(
            h.root(),
            ops::Case {
                signature: FunctionType::new_endo(ty.clone())
                    .with_extension_delta(ExtensionSet::from_iter(parent)),
            },
        );
        add_inliftout(&mut h, p, ty.clone());
        // Every case starts out with inconsistent extensions.
        assert!(h.validate_extensions().is_err());
        let backup = h.clone();
        let inf_res = h.infer_extensions(true);
        if success {
            assert!(inf_res.is_ok());
            // Expected Hugr: the original with both Case and Conditional deltas set to `result`.
            let expected_p = ops::Case {
                signature: FunctionType::new_endo(ty).with_extension_delta(result.clone()),
            };
            let mut expected = backup;
            expected.replace_op(p, expected_p).unwrap();
            let expected_gp = ops::Conditional {
                extension_delta: result,
                ..root_ty
            };
            expected.replace_op(h.root(), expected_gp).unwrap();
            assert_eq!(h, expected);
        } else {
            assert_eq!(
                inf_res,
                Err(ExtensionError {
                    parent: h.root(),
                    parent_extensions: grandparent,
                    child: p,
                    child_extensions: result
                })
            );
        }
    }
}