1pub mod hugrmut;
4
5pub(crate) mod ident;
6pub mod internal;
7pub mod rewrite;
8pub mod serialize;
9pub mod validate;
10pub mod views;
11
12use std::collections::VecDeque;
13use std::io::Read;
14use std::iter;
15
16pub(crate) use self::hugrmut::HugrMut;
17pub use self::validate::ValidationError;
18
19pub use ident::{IdentList, InvalidIdentifier};
20pub use rewrite::{Rewrite, SimpleReplacement, SimpleReplacementError};
21
22use portgraph::multiportgraph::MultiPortGraph;
23use portgraph::{Hierarchy, PortMut, PortView, UnmanagedDenseMap};
24use thiserror::Error;
25
26pub use self::views::{HugrView, RootTagged};
27use crate::core::NodeIndex;
28use crate::extension::resolution::{
29 resolve_op_extensions, resolve_op_types_extensions, ExtensionResolutionError,
30 WeakExtensionRegistry,
31};
32use crate::extension::{ExtensionRegistry, ExtensionSet, TO_BE_INFERRED};
33use crate::ops::{OpTag, OpTrait};
34pub use crate::ops::{OpType, DEFAULT_OPTYPE};
35use crate::{Direction, Node};
36
/// The Hugr data structure.
///
/// Combines a port graph with a node hierarchy, and stores an operation type
/// and an optional metadata map for every node.
#[derive(Clone, Debug, PartialEq)]
pub struct Hugr {
    /// The graph holding the nodes and their ports.
    graph: MultiPortGraph,

    /// The parent/child hierarchy between the nodes.
    hierarchy: Hierarchy,

    /// The root node of the hierarchy.
    root: portgraph::NodeIndex,

    /// The operation type of each node, keyed by node index.
    op_types: UnmanagedDenseMap<portgraph::NodeIndex, OpType>,

    /// The optional metadata map of each node, keyed by node index.
    metadata: UnmanagedDenseMap<portgraph::NodeIndex, Option<NodeMetadataMap>>,

    /// Registry of the extensions used by the operations in the Hugr.
    extensions: ExtensionRegistry,
}
58
59impl Default for Hugr {
60 fn default() -> Self {
61 Self::new(crate::ops::Module::new())
62 }
63}
64
65impl AsRef<Hugr> for Hugr {
66 fn as_ref(&self) -> &Hugr {
67 self
68 }
69}
70
71impl AsMut<Hugr> for Hugr {
72 fn as_mut(&mut self) -> &mut Hugr {
73 self
74 }
75}
76
/// A single metadata value attached to a node: an arbitrary JSON value.
pub type NodeMetadata = serde_json::Value;
81
/// The metadata entries of a node, as a map from string keys to
/// [`NodeMetadata`] values.
pub type NodeMetadataMap = serde_json::Map<String, NodeMetadata>;
84
85impl Hugr {
87 pub fn new(root_node: impl Into<OpType>) -> Self {
89 Self::with_capacity(root_node.into(), 0, 0)
90 }
91
92 pub fn load_json(
100 reader: impl Read,
101 extension_registry: &ExtensionRegistry,
102 ) -> Result<Self, LoadHugrError> {
103 let mut hugr: Hugr = serde_json::from_reader(reader)?;
104
105 hugr.resolve_extension_defs(extension_registry)?;
106 hugr.validate_no_extensions()?;
107
108 if cfg!(feature = "extension_inference") {
109 hugr.infer_extensions(false)?;
110 hugr.validate_extensions()?;
111 }
112
113 Ok(hugr)
114 }
115
    /// Infer extension deltas for the container nodes of the Hugr.
    ///
    /// The hierarchy is walked from the root, processing children before
    /// their parent. Only nodes whose operation is a `DFG`, `DataflowBlock`,
    /// `TailLoop`, `CFG`, `Conditional` or `Case` can have their delta
    /// rewritten; every other node just reports its current
    /// `extension_delta()`.
    ///
    /// For each rewritable node:
    /// - if its delta contains [`TO_BE_INFERRED`], it is replaced by the
    ///   union of its children's deltas and its own current delta, filtered
    ///   through `missing_from` against a `TO_BE_INFERRED` singleton
    ///   (presumably dropping the marker itself, which is consistent with
    ///   the tests in this file);
    /// - otherwise, if `remove` is true, the delta must be a superset of
    ///   every child's delta, and it is then replaced by the union of the
    ///   children's deltas only;
    /// - otherwise the delta is left unchanged.
    ///
    /// # Errors
    ///
    /// Returns an [`ExtensionError`] when `remove` is true and a rewritable
    /// node without [`TO_BE_INFERRED`] has a delta that is not a superset of
    /// one of its children's deltas.
    pub fn infer_extensions(&mut self, remove: bool) -> Result<(), ExtensionError> {
        /// Mutable access to the extension delta stored in the operation,
        /// for the operation kinds whose delta may be rewritten.
        fn delta_mut(optype: &mut OpType) -> Option<&mut ExtensionSet> {
            match optype {
                OpType::DFG(dfg) => Some(&mut dfg.signature.runtime_reqs),
                OpType::DataflowBlock(dfb) => Some(&mut dfb.extension_delta),
                OpType::TailLoop(tl) => Some(&mut tl.extension_delta),
                OpType::CFG(cfg) => Some(&mut cfg.signature.runtime_reqs),
                OpType::Conditional(c) => Some(&mut c.extension_delta),
                OpType::Case(c) => Some(&mut c.signature.runtime_reqs),
                // Any other operation keeps its delta as-is.
                _ => None,
            }
        }
        /// Recursively process `node` and its descendants, returning the
        /// (possibly updated) delta of `node`.
        fn infer(h: &mut Hugr, node: Node, remove: bool) -> Result<ExtensionSet, ExtensionError> {
            let mut child_sets = h
                .children(node)
                // Collect first, so `h` is not borrowed across the recursive call.
                .collect::<Vec<_>>()
                .into_iter()
                .map(|ch| Ok((ch, infer(h, ch, remove)?)))
                .collect::<Result<Vec<_>, _>>()?;

            let Some(es) = delta_mut(h.op_types.get_mut(node.pg_index())) else {
                return Ok(h.get_optype(node).extension_delta());
            };
            if es.contains(&TO_BE_INFERRED) {
                // The node's own current delta takes part in the union below;
                // only the set of each pair is used from here on.
                child_sets.push((node, es.clone()));
            } else if remove {
                // Without the marker nothing may be added, so every child
                // must already be covered by the current delta.
                child_sets.iter().try_for_each(|(ch, ch_exts)| {
                    if !es.is_superset(ch_exts) {
                        return Err(ExtensionError {
                            parent: node,
                            parent_extensions: es.clone(),
                            child: *ch,
                            child_extensions: ch_exts.clone(),
                        });
                    }
                    Ok(())
                })?;
            } else {
                // Neither adding nor removing is allowed: keep the delta.
                return Ok(es.clone());
            }
            let merged = ExtensionSet::union_over(child_sets.into_iter().map(|(_, e)| e));
            *es = ExtensionSet::singleton(TO_BE_INFERRED).missing_from(&merged);

            Ok(es.clone())
        }
        infer(self, self.root(), remove)?;
        Ok(())
    }
186
187 pub fn resolve_extension_defs(
220 &mut self,
221 extensions: &ExtensionRegistry,
222 ) -> Result<(), ExtensionResolutionError> {
223 let mut used_extensions = ExtensionRegistry::default();
224
225 let weak_extensions: WeakExtensionRegistry = extensions.into();
236 for n in 0..self.graph.node_capacity() {
237 let pg_node = portgraph::NodeIndex::new(n);
238 let node: Node = pg_node.into();
239 if !self.contains_node(node) {
240 continue;
241 }
242
243 let op = &mut self.op_types[pg_node];
244
245 if let Some(extension) = resolve_op_extensions(node, op, extensions)? {
246 used_extensions.register_updated_ref(extension);
247 }
248 used_extensions.extend(
249 resolve_op_types_extensions(Some(node), op, &weak_extensions)?.map(|weak| {
250 weak.upgrade()
251 .expect("Extension comes from a valid registry")
252 }),
253 );
254 }
255
256 self.extensions = used_extensions;
257 Ok(())
258 }
259}
260
261impl Hugr {
263 pub(crate) fn with_capacity(root_node: OpType, nodes: usize, ports: usize) -> Self {
265 let mut graph = MultiPortGraph::with_capacity(nodes, ports);
266 let hierarchy = Hierarchy::new();
267 let mut op_types = UnmanagedDenseMap::with_capacity(nodes);
268 let root = graph.add_node(root_node.input_count(), root_node.output_count());
269 let extensions = root_node.used_extensions();
270 op_types[root] = root_node;
271
272 Self {
273 graph,
274 hierarchy,
275 root,
276 op_types,
277 metadata: UnmanagedDenseMap::with_capacity(nodes),
278 extensions: extensions.unwrap_or_default(),
279 }
280 }
281
282 pub(crate) fn set_root(&mut self, root: Node) {
284 self.hierarchy.detach(self.root);
285 self.root = root.pg_index();
286 }
287
288 pub(crate) fn add_node(&mut self, nodetype: OpType) -> Node {
290 let node = self
291 .graph
292 .add_node(nodetype.input_count(), nodetype.output_count());
293 self.op_types[node] = nodetype;
294 node.into()
295 }
296
297 fn canonical_order(&self, root: Node) -> impl Iterator<Item = Node> + '_ {
305 let mut queue = VecDeque::from([root]);
307 iter::from_fn(move || {
308 let node = queue.pop_front()?;
309 for child in self.children(node) {
310 queue.push_back(child);
311 }
312 Some(node)
313 })
314 }
315
316 pub fn canonicalize_nodes(&mut self, mut rekey: impl FnMut(Node, Node)) {
324 let mut ordered = Vec::with_capacity(self.node_count());
326 let root = self.root();
327 ordered.extend(self.as_mut().canonical_order(root));
328
329 for position in 0..ordered.len() {
333 let mut source: Node = ordered[position];
336 while position > source.index() {
337 source = ordered[source.index()];
338 }
339
340 let target: Node = portgraph::NodeIndex::new(position).into();
341 if target != source {
342 let pg_target = target.pg_index();
343 let pg_source = source.pg_index();
344 self.graph.swap_nodes(pg_target, pg_source);
345 self.op_types.swap(pg_target, pg_source);
346 self.hierarchy.swap_nodes(pg_target, pg_source);
347 rekey(source, target);
348 }
349 }
350 self.root = portgraph::NodeIndex::new(0);
351
352 self.graph.compact_nodes(|_, _| {});
356 }
357}
358
/// Error raised when the extension set of a parent node does not include the
/// extensions required by one of its children.
#[derive(Debug, Clone, PartialEq, Error)]
#[error("Parent node {parent} has extensions {parent_extensions} that are too restrictive for child node {child}, they must include child extensions {child_extensions}")]
pub struct ExtensionError {
    /// The parent node whose extension set is too restrictive.
    parent: Node,
    /// The extension set of the parent node.
    parent_extensions: ExtensionSet,
    /// The child node whose requirements are not met.
    child: Node,
    /// The extension set of the child node.
    child_extensions: ExtensionSet,
}
368
/// Errors reporting an invalid operation tag or an invalid port direction.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum HugrError {
    /// A tag within `required` was expected, but `actual` was found.
    #[error("Invalid tag: required a tag in {required} but found {actual}")]
    #[allow(missing_docs)]
    InvalidTag { required: OpTag, actual: OpTag },
    /// The given port direction is not valid.
    #[error("Invalid port direction {0:?}.")]
    InvalidPortDirection(Direction),
}
384
/// Errors that can occur while loading a Hugr, as returned by
/// [`Hugr::load_json`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum LoadHugrError {
    /// The JSON input could not be deserialized into a Hugr.
    #[error("Error while loading the Hugr from JSON: {0}")]
    Load(#[from] serde_json::Error),
    /// The loaded Hugr failed validation.
    #[error(transparent)]
    Validation(#[from] ValidationError),
    /// The extensions used by the loaded Hugr could not be resolved.
    #[error(transparent)]
    Extension(#[from] ExtensionResolutionError),
    /// Extension inference failed on the loaded Hugr.
    #[error(transparent)]
    RuntimeInference(#[from] ExtensionError),
}
402
403#[cfg(test)]
404mod test {
405 use std::sync::Arc;
406 use std::{fs::File, io::BufReader};
407
408 use super::internal::HugrMutInternals;
409 #[cfg(feature = "extension_inference")]
410 use super::ValidationError;
411 use super::{ExtensionError, Hugr, HugrMut, HugrView, Node};
412 use crate::extension::{ExtensionId, ExtensionSet, PRELUDE_REGISTRY, TO_BE_INFERRED};
413 use crate::ops::{ExtensionOp, OpName};
414 use crate::types::type_param::TypeParam;
415 use crate::types::{
416 FuncValueType, PolyFuncTypeRV, Signature, Type, TypeArg, TypeBound, TypeRV, TypeRow,
417 };
418
419 use crate::{const_extension_ids, ops, test_file, type_row, Extension};
420 use cool_asserts::assert_matches;
421 use lazy_static::lazy_static;
422 use rstest::rstest;
423
    // Identifier of the test-only extension defined by `LIFT_EXT` below.
    const_extension_ids! {
        pub(crate) const LIFT_EXT_ID: ExtensionId = "LIFT_EXT_ID";
    }
    lazy_static! {
        pub(crate) static ref LIFT_EXT: Arc<Extension> = {
            // Test-only extension providing a single "Lift" operation.
            Extension::new_arc(
                LIFT_EXT_ID,
                hugr::extension::Version::new(0, 0, 0),
                |ext, extension_ref| {
                    ext.add_op(
                        OpName::new_inline("Lift"),
                        "".into(),
                        // Two type parameters: an extension set (index 0) and
                        // a list of types (index 1). The signature maps the
                        // row given by parameter 1 to itself, with parameter 0
                        // as its extension delta.
                        PolyFuncTypeRV::new(
                            vec![TypeParam::Extensions, TypeParam::new_list(TypeBound::Any)],
                            FuncValueType::new_endo(TypeRV::new_row_var_use(1, TypeBound::Any))
                                .with_extension_delta(ExtensionSet::type_var(0)),
                        ),
                        extension_ref,
                    )
                    .unwrap();
                },
            )
        };
    }
449
450 pub(crate) fn lift_op(
451 type_row: impl Into<TypeRow>,
452 extensions: impl Into<ExtensionSet>,
453 ) -> ExtensionOp {
454 LIFT_EXT
455 .instantiate_extension_op(
456 "Lift",
457 [
458 TypeArg::Extensions {
459 es: extensions.into(),
460 },
461 TypeArg::Sequence {
462 elems: type_row
463 .into()
464 .iter()
465 .map(|t| TypeArg::Type { ty: t.clone() })
466 .collect(),
467 },
468 ],
469 )
470 .unwrap()
471 }
472
    /// Compile-time check that [`Hugr`] is `Send` and `Sync`.
    #[test]
    fn impls_send_and_sync() {
        // The impl below only compiles if `Hugr` satisfies the supertraits
        // of `Test`; nothing needs to run.
        #[allow(dead_code)]
        trait Test: Send + Sync {}
        impl Test for Hugr {}
    }
481
482 #[test]
483 fn io_node() {
484 use crate::builder::test::simple_dfg_hugr;
485 use cool_asserts::assert_matches;
486
487 let hugr = simple_dfg_hugr();
488 assert_matches!(hugr.get_io(hugr.root()), Some(_));
489 }
490
491 #[test]
492 #[cfg_attr(miri, ignore)] fn hugr_validation_0() {
494 let hugr = Hugr::load_json(
496 BufReader::new(File::open(test_file!("hugr-0.json")).unwrap()),
497 &PRELUDE_REGISTRY,
498 );
499 assert_matches!(hugr, Err(_));
500 }
501
502 #[test]
503 #[cfg_attr(miri, ignore)] fn hugr_validation_1() {
505 let hugr = Hugr::load_json(
507 BufReader::new(File::open(test_file!("hugr-1.json")).unwrap()),
508 &PRELUDE_REGISTRY,
509 );
510 assert_matches!(&hugr, Ok(_));
511 }
512
513 #[test]
514 #[cfg_attr(miri, ignore)] fn hugr_validation_2() {
516 let hugr = Hugr::load_json(
518 BufReader::new(File::open(test_file!("hugr-2.json")).unwrap()),
519 &PRELUDE_REGISTRY,
520 );
521 assert_matches!(hugr, Err(_));
522 }
523
524 #[test]
525 #[cfg_attr(miri, ignore)] fn hugr_validation_3() {
527 let hugr = Hugr::load_json(
529 BufReader::new(File::open(test_file!("hugr-3.json")).unwrap()),
530 &PRELUDE_REGISTRY,
531 );
532 assert_matches!(&hugr, Ok(_));
533 }
534
    // Two extension ids used as sample deltas by the inference tests below.
    const_extension_ids! {
        const XA: ExtensionId = "EXT_A";
        const XB: ExtensionId = "EXT_B";
    }
539
540 #[rstest]
541 #[case([], XA.into())]
542 #[case([XA], XA.into())]
543 #[case([XB], ExtensionSet::from_iter([XA, XB]))]
544
545 fn infer_single_delta(
546 #[case] parent: impl IntoIterator<Item = ExtensionId>,
547 #[values(true, false)] remove: bool, #[case] result: ExtensionSet,
549 ) {
550 let parent = ExtensionSet::from_iter(parent).union(TO_BE_INFERRED.into());
551 let (mut h, _) = build_ext_dfg(parent);
552 h.infer_extensions(remove).unwrap();
553 assert_eq!(h, build_ext_dfg(result.union(LIFT_EXT_ID.into())).0);
554 }
555
556 #[test]
557 fn infer_removes_from_delta() {
558 let parent = ExtensionSet::from_iter([XA, XB, LIFT_EXT_ID]);
559 let mut h = build_ext_dfg(parent.clone()).0;
560 let backup = h.clone();
561 h.infer_extensions(false).unwrap();
562 assert_eq!(h, backup); h.infer_extensions(true).unwrap();
564 assert_eq!(
565 h,
566 build_ext_dfg(ExtensionSet::from_iter([XA, LIFT_EXT_ID])).0
567 );
568 }
569
570 #[test]
571 fn infer_bad_remove() {
572 let (mut h, mid) = build_ext_dfg(XB.into());
573 let backup = h.clone();
574 h.infer_extensions(false).unwrap();
575 assert_eq!(h, backup); let val_res = h.validate();
577 let expected_err = ExtensionError {
578 parent: h.root(),
579 parent_extensions: XB.into(),
580 child: mid,
581 child_extensions: ExtensionSet::from_iter([XA, LIFT_EXT_ID]),
582 };
583 #[cfg(feature = "extension_inference")]
584 assert_eq!(
585 val_res,
586 Err(ValidationError::ExtensionError(expected_err.clone()))
587 );
588 #[cfg(not(feature = "extension_inference"))]
589 assert!(val_res.is_ok());
590
591 let inf_res = h.infer_extensions(true);
592 assert_eq!(inf_res, Err(expected_err));
593 }
594
595 fn build_ext_dfg(parent: ExtensionSet) -> (Hugr, Node) {
596 let ty = Type::new_function(Signature::new_endo(type_row![]));
597 let mut h = Hugr::new(ops::DFG {
598 signature: Signature::new_endo(ty.clone()).with_extension_delta(parent.clone()),
599 });
600 let root = h.root();
601 let mid = add_inliftout(&mut h, root, ty);
602 (h, mid)
603 }
604
605 fn add_inliftout(h: &mut Hugr, p: Node, ty: Type) -> Node {
606 let inp = h.add_node_with_parent(
607 p,
608 ops::Input {
609 types: ty.clone().into(),
610 },
611 );
612 let out = h.add_node_with_parent(
613 p,
614 ops::Output {
615 types: ty.clone().into(),
616 },
617 );
618 let mid = h.add_node_with_parent(p, lift_op(ty, XA));
619 h.connect(inp, 0, mid, 0);
620 h.connect(mid, 0, out, 0);
621 mid
622 }
623
    /// Inference over three levels: a `Conditional` root (grandparent)
    /// containing a `Case` (parent) containing an Input -> Lift -> Output
    /// chain. `LIFT_EXT_ID` is added to every set given in the cases.
    ///
    /// On success both containers are expected to end up with `result`; on
    /// failure the error is expected to name the root as parent, the `Case`
    /// as child, and `result` as the child's extensions.
    #[rstest]
    // Success: the delta inferred for the parent equals the grandparent's.
    #[case([XA], [TO_BE_INFERRED], true, [XA])]
    // Success: the inferred parent delta is a subset of the grandparent's,
    // whose unused XB is then removed.
    #[case([XA, XB], [TO_BE_INFERRED], true, [XA])]
    // Failure: XA is inferred for the parent but the grandparent only has XB.
    #[case([XB], [TO_BE_INFERRED], false, [XA])]
    // Failure: as above; the extra XA already on the parent changes nothing.
    #[case([XB], [XA, TO_BE_INFERRED], false, [XA])]
    // Failure: the parent additionally keeps XB, which the grandparent lacks.
    #[case([XA], [XB, TO_BE_INFERRED], false, [XA, XB])]
    // Success: the grandparent includes the extra XB kept by the parent.
    #[case([XA, XB], [XB, TO_BE_INFERRED], true, [XA, XB])]
    // Success: the grandparent is inferred too, so it picks up the parent's XB.
    #[case([TO_BE_INFERRED], [TO_BE_INFERRED, XB], true, [XA, XB])]
    // Success without markers: the parent's unused XB is removed.
    #[case([XA], [XA, XB], true, [XA])]
    fn infer_three_generations(
        #[case] grandparent: impl IntoIterator<Item = ExtensionId>,
        #[case] parent: impl IntoIterator<Item = ExtensionId>,
        #[case] success: bool,
        #[case] result: impl IntoIterator<Item = ExtensionId>,
    ) {
        let ty = Type::new_function(Signature::new_endo(type_row![]));
        let grandparent = ExtensionSet::from_iter(grandparent).union(LIFT_EXT_ID.into());
        let parent = ExtensionSet::from_iter(parent).union(LIFT_EXT_ID.into());
        let result = ExtensionSet::from_iter(result).union(LIFT_EXT_ID.into());
        let root_ty = ops::Conditional {
            sum_rows: vec![type_row![]],
            other_inputs: ty.clone().into(),
            outputs: ty.clone().into(),
            extension_delta: grandparent.clone(),
        };
        let mut h = Hugr::new(root_ty.clone());
        let p = h.add_node_with_parent(
            h.root(),
            ops::Case {
                signature: Signature::new_endo(ty.clone()).with_extension_delta(parent),
            },
        );
        add_inliftout(&mut h, p, ty.clone());
        // Every case starts from a Hugr whose extensions do not validate.
        assert!(h.validate_extensions().is_err());
        let backup = h.clone();
        let inf_res = h.infer_extensions(true);
        if success {
            assert!(inf_res.is_ok());
            // Build the expected Hugr from the backup, by giving both
            // containers the `result` delta.
            let expected_p = ops::Case {
                signature: Signature::new_endo(ty).with_extension_delta(result.clone()),
            };
            let mut expected = backup;
            expected.replace_op(p, expected_p).unwrap();
            let expected_gp = ops::Conditional {
                extension_delta: result,
                ..root_ty
            };
            expected.replace_op(h.root(), expected_gp).unwrap();

            assert_eq!(h, expected);
        } else {
            assert_eq!(
                inf_res,
                Err(ExtensionError {
                    parent: h.root(),
                    parent_extensions: grandparent,
                    child: p,
                    child_extensions: result
                })
            );
        }
    }
694}