1pub mod hugrmut;
4pub(crate) mod ident;
5pub mod internal;
6pub mod linking;
7pub mod patch;
8pub mod serialize;
9pub mod validate;
10pub mod views;
11
12use std::collections::VecDeque;
13use std::io;
14use std::iter;
15
16pub(crate) use self::hugrmut::HugrMut;
17pub use self::validate::ValidationError;
18
19pub use ident::{IdentList, InvalidIdentifier};
20use itertools::Itertools;
21pub use patch::{Patch, SimpleReplacement, SimpleReplacementError};
22
23use portgraph::multiportgraph::MultiPortGraph;
24use portgraph::{Hierarchy, PortMut, PortView, UnmanagedDenseMap};
25use thiserror::Error;
26
27pub use self::views::HugrView;
28use crate::core::NodeIndex;
29use crate::envelope::{self, EnvelopeConfig, EnvelopeError};
30use crate::extension::resolution::{
31    ExtensionResolutionError, WeakExtensionRegistry, resolve_op_extensions,
32    resolve_op_types_extensions,
33};
34use crate::extension::{EMPTY_REG, ExtensionRegistry, ExtensionSet};
35use crate::ops::{self, Module, NamedOp, OpName, OpTag, OpTrait};
36pub use crate::ops::{DEFAULT_OPTYPE, OpType};
37use crate::package::Package;
38use crate::{Direction, Node};
39
/// A HUGR: a port graph of operations, organised into a hierarchy rooted at a
/// module node, with one node designated as the entrypoint.
#[derive(Clone, Debug, PartialEq)]
pub struct Hugr {
    /// The port graph holding the nodes and the links between their ports.
    graph: MultiPortGraph<u32, u32, u32>,

    /// Parent/child relationships between the nodes of `graph`.
    hierarchy: Hierarchy,

    /// The root node of the hierarchy. It is created holding a `Module`
    /// operation (see `make_module_hugr`).
    module_root: portgraph::NodeIndex,

    /// The node designated as the entrypoint. It starts out as `module_root`
    /// and may be set to another node (see `make_module_hugr`).
    entrypoint: portgraph::NodeIndex,

    /// The operation stored at each node.
    op_types: UnmanagedDenseMap<portgraph::NodeIndex, OpType>,

    /// Optional metadata attached to each node, keyed like `op_types`.
    metadata: UnmanagedDenseMap<portgraph::NodeIndex, Option<NodeMetadataMap>>,

    /// The extensions used by the operations in this HUGR. This is recomputed by
    /// `resolve_extension_defs`.
    extensions: ExtensionRegistry,
}
66
67impl Default for Hugr {
68    fn default() -> Self {
69        Self::new()
70    }
71}
72
73impl AsRef<Hugr> for Hugr {
74    fn as_ref(&self) -> &Hugr {
75        self
76    }
77}
78
79impl AsMut<Hugr> for Hugr {
80    fn as_mut(&mut self) -> &mut Hugr {
81        self
82    }
83}
84
/// A single metadata value attached to a node, stored as an arbitrary JSON value.
pub type NodeMetadata = serde_json::Value;

/// The metadata map of a node: string keys mapping to [`NodeMetadata`] values.
pub type NodeMetadataMap = serde_json::Map<String, NodeMetadata>;
92
93impl Hugr {
95    #[must_use]
97    pub fn new() -> Self {
98        make_module_hugr(Module::new().into(), 0, 0).unwrap()
99    }
100
101    pub fn new_with_entrypoint(entrypoint_op: impl Into<OpType>) -> Result<Self, HugrError> {
115        Self::with_capacity(entrypoint_op, 0, 0)
116    }
117
118    pub fn with_capacity(
132        entrypoint_op: impl Into<OpType>,
133        nodes: usize,
134        ports: usize,
135    ) -> Result<Self, HugrError> {
136        let entrypoint_op: OpType = entrypoint_op.into();
137        let op_name = entrypoint_op.name();
138        make_module_hugr(entrypoint_op, nodes, ports)
139            .ok_or(HugrError::UnsupportedEntrypoint { op: op_name })
140    }
141
142    pub fn reserve(&mut self, nodes: usize, links: usize) {
148        let ports = links * 2;
149        self.graph.reserve(nodes, ports);
150    }
151
152    pub fn load(
160        reader: impl io::BufRead,
161        extensions: Option<&ExtensionRegistry>,
162    ) -> Result<Self, EnvelopeError> {
163        let pkg = Package::load(reader, extensions)?;
164        match pkg.modules.into_iter().exactly_one() {
165            Ok(hugr) => Ok(hugr),
166            Err(e) => Err(EnvelopeError::ExpectedSingleHugr { count: e.count() }),
167        }
168    }
169
170    pub fn load_str(
181        envelope: impl AsRef<str>,
182        extensions: Option<&ExtensionRegistry>,
183    ) -> Result<Self, EnvelopeError> {
184        Self::load(envelope.as_ref().as_bytes(), extensions)
185    }
186
187    pub fn store(
194        &self,
195        writer: impl io::Write,
196        config: EnvelopeConfig,
197    ) -> Result<(), EnvelopeError> {
198        self.store_with_exts(writer, config, &EMPTY_REG)
199    }
200
201    pub fn store_with_exts(
207        &self,
208        writer: impl io::Write,
209        config: EnvelopeConfig,
210        extensions: &ExtensionRegistry,
211    ) -> Result<(), EnvelopeError> {
212        envelope::write_envelope_impl(writer, [self], extensions, config)
213    }
214
215    pub fn store_str(&self, config: EnvelopeConfig) -> Result<String, EnvelopeError> {
226        self.store_str_with_exts(config, &EMPTY_REG)
227    }
228
229    pub fn store_str_with_exts(
239        &self,
240        config: EnvelopeConfig,
241        extensions: &ExtensionRegistry,
242    ) -> Result<String, EnvelopeError> {
243        if !config.format.ascii_printable() {
244            return Err(EnvelopeError::NonASCIIFormat {
245                format: config.format,
246            });
247        }
248
249        let mut buf = Vec::new();
250        self.store_with_exts(&mut buf, config, extensions)?;
251        Ok(String::from_utf8(buf).expect("Envelope is valid utf8"))
252    }
253
254    pub fn resolve_extension_defs(
282        &mut self,
283        extensions: &ExtensionRegistry,
284    ) -> Result<(), ExtensionResolutionError> {
285        let mut used_extensions = ExtensionRegistry::default();
286
287        let weak_extensions: WeakExtensionRegistry = extensions.into();
298        for n in 0..self.graph.node_capacity() {
299            let pg_node = portgraph::NodeIndex::new(n);
300            let node: Node = pg_node.into();
301            if !self.contains_node(node) {
302                continue;
303            }
304
305            let op = &mut self.op_types[pg_node];
306
307            if let Some(extension) = resolve_op_extensions(node, op, extensions)? {
308                used_extensions.register_updated_ref(extension);
309            }
310            used_extensions.extend(
311                resolve_op_types_extensions(Some(node), op, &weak_extensions)?.map(|weak| {
312                    weak.upgrade()
313                        .expect("Extension comes from a valid registry")
314                }),
315            );
316        }
317
318        self.extensions = used_extensions;
319        Ok(())
320    }
321}
322
323impl Hugr {
325    pub(crate) fn add_node(&mut self, nodetype: OpType) -> Node {
327        let node = self
328            .graph
329            .add_node(nodetype.input_count(), nodetype.output_count());
330        self.op_types[node] = nodetype;
331        node.into()
332    }
333
334    fn canonical_order(&self, root: Node) -> impl Iterator<Item = Node> + '_ {
342        let mut queue = VecDeque::from([root]);
344        iter::from_fn(move || {
345            let node = queue.pop_front()?;
346            for child in self.children(node) {
347                queue.push_back(child);
348            }
349            Some(node)
350        })
351    }
352
353    pub fn canonicalize_nodes(&mut self, mut rekey: impl FnMut(Node, Node)) {
361        let ordered = {
363            let mut v = Vec::with_capacity(self.num_nodes());
364            v.extend(self.canonical_order(self.module_root()));
365            v
366        };
367        let mut new_entrypoint = None;
368
369        for position in 0..ordered.len() {
373            let pg_target = portgraph::NodeIndex::new(position);
374            let mut source: Node = ordered[position];
375
376            if source.into_portgraph() == self.entrypoint {
378                let old = new_entrypoint.replace(pg_target);
379                debug_assert!(old.is_none());
380            }
381
382            while position > source.index() {
385                source = ordered[source.index()];
386            }
387
388            let pg_source = source.into_portgraph();
389            if pg_target != pg_source {
390                self.graph.swap_nodes(pg_target, pg_source);
391                self.op_types.swap(pg_target, pg_source);
392                self.hierarchy.swap_nodes(pg_target, pg_source);
393                rekey(source, pg_target.into());
394            }
395        }
396        self.module_root = portgraph::NodeIndex::new(0);
397        self.entrypoint = new_entrypoint.unwrap();
398
399        self.graph.compact_nodes(|_, _| {});
403    }
404}
405
/// A parent node's extension set does not include the extensions required by
/// one of its children.
#[derive(Debug, Clone, PartialEq, Error)]
#[error(
    "Parent node {parent} has extensions {parent_extensions} that are too restrictive for child node {child}, they must include child extensions {child_extensions}"
)]
pub struct ExtensionError {
    /// The parent node.
    parent: Node,
    /// The extensions of the parent node.
    parent_extensions: ExtensionSet,
    /// The child node whose extensions are not covered.
    child: Node,
    /// The extensions of the child, which the parent's must include.
    child_extensions: ExtensionSet,
}
417
/// Errors reported by [`Hugr`] operations.
#[derive(Debug, Clone, PartialEq, Eq, Error)]
#[non_exhaustive]
pub enum HugrError {
    /// An operation tag `actual` was found where a tag contained in `required`
    /// was expected.
    #[error("Invalid tag: required a tag in {required} but found {actual}")]
    #[allow(missing_docs)]
    InvalidTag { required: OpTag, actual: OpTag },
    /// A port direction was given that is not valid in this context.
    #[error("Invalid port direction {0:?}.")]
    InvalidPortDirection(Direction),
    /// The operation cannot be used as the entrypoint of a new HUGR.
    ///
    /// Returned by [`Hugr::with_capacity`] (and so by
    /// [`Hugr::new_with_entrypoint`]).
    #[error("Cannot initialize a HUGR with entrypoint type {op}")]
    UnsupportedEntrypoint {
        /// Name of the rejected operation.
        op: OpName,
    },
}
436
437fn make_module_hugr(root_op: OpType, nodes: usize, ports: usize) -> Option<Hugr> {
453    let mut graph = MultiPortGraph::with_capacity(nodes, ports);
454    let hierarchy = Hierarchy::new();
455    let mut op_types = UnmanagedDenseMap::with_capacity(nodes);
456    let extensions = root_op.used_extensions().unwrap_or_default();
457
458    let tag = root_op.tag();
460    let container_tags = [
461        OpTag::ModuleRoot,
462        OpTag::DataflowParent,
463        OpTag::Cfg,
464        OpTag::Conditional,
465    ];
466    if !container_tags.iter().any(|t| t.is_superset(tag)) {
467        return None;
468    }
469
470    let module = graph.add_node(0, 0);
471    op_types[module] = OpType::Module(ops::Module::new());
472
473    let mut hugr = Hugr {
474        graph,
475        hierarchy,
476        module_root: module,
477        entrypoint: module,
478        op_types,
479        metadata: UnmanagedDenseMap::with_capacity(nodes),
480        extensions,
481    };
482    let module: Node = module.into();
483
484    if root_op.is_module() {
486        }
488    else if OpTag::ModuleOp.is_superset(tag) {
490        let node = hugr.add_node_with_parent(module, root_op);
491        hugr.set_entrypoint(node);
492    }
493    else if OpTag::DataflowChild.is_superset(tag) && !root_op.is_input() && !root_op.is_output() {
496        let signature = root_op
497            .dataflow_signature()
498            .unwrap_or_else(|| panic!("Dataflow child {} without signature", root_op.name()))
499            .into_owned();
500        let dataflow_inputs = signature.input_count();
501        let dataflow_outputs = signature.output_count();
502
503        let func = hugr.add_node_with_parent(module, ops::FuncDefn::new("main", signature.clone()));
504        let inp = hugr.add_node_with_parent(
505            func,
506            ops::Input {
507                types: signature.input.clone(),
508            },
509        );
510        let out = hugr.add_node_with_parent(
511            func,
512            ops::Output {
513                types: signature.output.clone(),
514            },
515        );
516        let entrypoint = hugr.add_node_with_parent(func, root_op);
517
518        for port in 0..dataflow_inputs {
521            hugr.connect(inp, port, entrypoint, port);
522        }
523        for port in 0..dataflow_outputs {
524            hugr.connect(entrypoint, port, out, port);
525        }
526
527        hugr.set_entrypoint(entrypoint);
528    }
529    else {
531        debug_assert!(matches!(
532            root_op,
533            OpType::Input(_)
534                | OpType::Output(_)
535                | OpType::DataflowBlock(_)
536                | OpType::ExitBlock(_)
537                | OpType::Case(_)
538        ));
539        return None;
540    }
541
542    Some(hugr)
543}
544
#[cfg(test)]
pub(crate) mod test {
    use std::{fs::File, io::BufReader};

    use super::*;

    use crate::builder::{Container, Dataflow, DataflowSubContainer, ModuleBuilder};
    use crate::envelope::{EnvelopeError, PackageEncodingError};
    use crate::extension::prelude::bool_t;
    use crate::ops::OpaqueOp;
    use crate::ops::handle::NodeHandle;
    use crate::types::Signature;
    use crate::{Visibility, test_file};
    use cool_asserts::assert_matches;
    use itertools::Either;
    use portgraph::LinkView;
    use rstest::rstest;

    /// Asserts that two HUGRs are equal up to node renumbering.
    ///
    /// Both inputs are cloned and canonicalized before comparison, so they may
    /// use different node indices. Constant operations are not compared. An
    /// `ExtensionOp` counts as equal to an `OpaqueOp` when it converts into that
    /// `OpaqueOp`.
    pub(crate) fn check_hugr_equality(lhs: &Hugr, rhs: &Hugr) {
        let mut lhs = lhs.clone();
        // Canonicalize both sides so that equal HUGRs share node indices.
        lhs.canonicalize_nodes(|_, _| {});
        let mut rhs = rhs.clone();
        rhs.canonicalize_nodes(|_, _| {});

        assert_eq!(rhs.module_root(), lhs.module_root());
        assert_eq!(rhs.entrypoint(), lhs.entrypoint());
        assert_eq!(rhs.hierarchy, lhs.hierarchy);
        assert_eq!(rhs.metadata, lhs.metadata);

        for node in rhs.nodes() {
            // Compare operations node by node. Constants are skipped, and extension
            // ops may appear in either resolved or opaque form.
            let new_op = rhs.get_optype(node);
            let old_op = lhs.get_optype(node);
            if !new_op.is_const() {
                match (new_op, old_op) {
                    (OpType::ExtensionOp(ext), OpType::OpaqueOp(opaque))
                    | (OpType::OpaqueOp(opaque), OpType::ExtensionOp(ext)) => {
                        let ext_opaque: OpaqueOp = ext.clone().into();
                        assert_eq!(ext_opaque, opaque.clone());
                    }
                    _ => assert_eq!(new_op, old_op),
                }
            }
        }

        let new_graph = &rhs.graph;
        // Compare the underlying port graphs: overall sizes, per-node port counts
        // and each node's output neighbours.
        let old_graph = &lhs.graph;
        assert_eq!(new_graph.node_count(), old_graph.node_count());
        assert_eq!(new_graph.port_count(), old_graph.port_count());
        assert_eq!(new_graph.link_count(), old_graph.link_count());
        for n in old_graph.nodes_iter() {
            assert_eq!(new_graph.num_inputs(n), old_graph.num_inputs(n));
            assert_eq!(new_graph.num_outputs(n), old_graph.num_outputs(n));
            assert_eq!(
                new_graph.output_neighbours(n).collect_vec(),
                old_graph.output_neighbours(n).collect_vec()
            );
        }
    }

    #[test]
    fn impls_send_and_sync() {
        // Compile-time check: the impl below only builds if `Hugr: Send + Sync`.
        // The trait is never used at runtime, hence `dead_code`.
        #[allow(dead_code)]
        trait Test: Send + Sync {}
        impl Test for Hugr {}
    }

    #[test]
    fn io_node() {
        use crate::builder::test::simple_dfg_hugr;

        // The entrypoint of a simple DFG HUGR must expose its Input/Output nodes.
        let hugr = simple_dfg_hugr();
        assert_matches!(hugr.get_io(hugr.entrypoint()), Some(_));
    }

    // The file-loading tests below are skipped under miri, presumably because
    // they read files from disk.
    #[test]
    #[cfg_attr(miri, ignore)] fn hugr_validation_0() {
        // `hugr-0.hugr` is expected to fail decoding with a JSON encoding error.
        let hugr = Hugr::load(
            BufReader::new(File::open(test_file!("hugr-0.hugr")).unwrap()),
            None,
        );
        assert_matches!(
            hugr,
            Err(EnvelopeError::PackageEncoding {
                source: PackageEncodingError::JsonEncoding(_)
            })
        );
    }

    #[test]
    #[cfg_attr(miri, ignore)] fn hugr_validation_1() {
        // `hugr-1.hugr` is expected to load successfully.
        let hugr = Hugr::load(
            BufReader::new(File::open(test_file!("hugr-1.hugr")).unwrap()),
            None,
        );
        assert_matches!(&hugr, Ok(_));
    }

    #[test]
    #[cfg_attr(miri, ignore)] fn hugr_validation_2() {
        // `hugr-2.hugr` is expected to load, but to fail validation.
        let hugr = Hugr::load(
            BufReader::new(File::open(test_file!("hugr-2.hugr")).unwrap()),
            None,
        )
        .unwrap();
        assert_matches!(hugr.validate(), Err(_));
    }

    #[test]
    #[cfg_attr(miri, ignore)] fn hugr_validation_3() {
        // `hugr-3.hugr` is expected to load successfully.
        let hugr = Hugr::load(
            BufReader::new(File::open(test_file!("hugr-3.hugr")).unwrap()),
            None,
        );
        assert_matches!(&hugr, Ok(_));
    }

    #[test]
    #[cfg_attr(miri, ignore)] fn load_funcs_no_visibility() {
        // Functions stored without an explicit visibility: the declaration is
        // expected to load as public and the definition as private.
        let hugr = Hugr::load(
            BufReader::new(File::open(test_file!("hugr-no-visibility.hugr")).unwrap()),
            None,
        )
        .unwrap();

        let [_mod, decl, defn] = hugr.nodes().take(3).collect_array().unwrap();
        assert_eq!(
            hugr.get_optype(decl).as_func_decl().unwrap().visibility(),
            &Visibility::Public
        );
        assert_eq!(
            hugr.get_optype(defn).as_func_defn().unwrap().visibility(),
            &Visibility::Private
        );
    }

    /// Builds a module with two functions, `a` and `b`.
    ///
    /// Function `a` contains a nested DFG holding a `Call`. After construction,
    /// the call's static input is linked to `b`, and the nested DFG is made the
    /// entrypoint.
    /// NOTE(review): the name presumably refers to issue #2262; confirm.
    fn hugr_failing_2262() -> Hugr {
        let sig = Signature::new(vec![bool_t(); 2], bool_t());
        let mut mb = ModuleBuilder::new();
        let mut fa = mb.define_function("a", sig.clone()).unwrap();
        let mut dfg = fa.dfg_builder(sig.clone(), fa.input_wires()).unwrap();
        let call_op = ops::Call::try_new(sig.clone().into(), []).unwrap();
        // The call is added without a static link; it is linked to `b` further down.
        let call = dfg.add_dataflow_op(call_op, dfg.input_wires()).unwrap();
        let dfg = dfg.finish_with_outputs(call.outputs()).unwrap();
        fa.finish_with_outputs(dfg.outputs()).unwrap();
        let fb = mb.define_function("b", sig).unwrap();
        let [fst, _] = fb.input_wires_arr();
        let fb = fb.finish_with_outputs([fst]).unwrap();
        let mut h = mb.hugr().clone();

        // Point the entrypoint at the nested DFG, then wire the call's (currently
        // unlinked) static input to function `b`.
        h.set_entrypoint(dfg.node()); let static_in = h.get_optype(call.node()).static_input_port().unwrap();
        let static_out = h.get_optype(fb.node()).static_output_port().unwrap();
        assert_eq!(h.single_linked_output(call.node(), static_in), None);
        h.disconnect(call.node(), static_in);
        h.connect(fb.node(), static_out, call.node(), static_in);
        h
    }

    // File-based cases are skipped under miri. `hugr-2.hugr` is not listed: it
    // fails validation (see `hugr_validation_2`), which this test unwraps.
    #[rstest]
    #[cfg_attr(not(miri), case(Either::Left(test_file!("hugr-1.hugr"))))]
    #[cfg_attr(not(miri), case(Either::Left(test_file!("hugr-3.hugr"))))]
    #[case(Either::Right(hugr_failing_2262()))]
    fn canonicalize_entrypoint(#[case] file_or_hugr: Either<&str, Hugr>) {
        let hugr = match file_or_hugr {
            Either::Left(file) => {
                Hugr::load(BufReader::new(File::open(file).unwrap()), None).unwrap()
            }
            Either::Right(hugr) => hugr,
        };
        hugr.validate().unwrap();

        // For every node that is a valid entrypoint, canonicalizing must keep
        // that node's operation at the entrypoint.
        for n in hugr.nodes() {
            let mut h2 = hugr.clone();
            h2.set_entrypoint(n);
            if h2.validate().is_ok() {
                h2.canonicalize_nodes(|_, _| {});
                assert_eq!(hugr.get_optype(n), h2.entrypoint_optype());
            }
        }
    }
}