hugr_core/
package.rs

//! Bundles of hugr modules along with the extensions required to load them.
2
3use derive_more::{Display, Error, From};
4use itertools::Itertools;
5use std::path::Path;
6use std::{fs, io, mem};
7
8use crate::builder::{Container, Dataflow, DataflowSubContainer, ModuleBuilder};
9use crate::envelope::{read_envelope, write_envelope, EnvelopeConfig, EnvelopeError};
10use crate::extension::resolution::ExtensionResolutionError;
11use crate::extension::{ExtensionId, ExtensionRegistry, PRELUDE_REGISTRY};
12use crate::hugr::internal::HugrMutInternals;
13use crate::hugr::{ExtensionError, HugrView, ValidationError};
14use crate::ops::{FuncDefn, Module, NamedOp, OpTag, OpTrait, OpType};
15use crate::{Extension, Hugr};
16
17#[derive(Debug, Default, Clone, PartialEq, serde::Serialize, serde::Deserialize)]
18/// Package of module HUGRs.
19pub struct Package {
20    /// Module HUGRs included in the package.
21    pub modules: Vec<Hugr>,
22    /// Extensions used in the modules.
23    ///
24    /// This is a superset of the extensions used in the modules.
25    pub extensions: ExtensionRegistry,
26}
27
28impl Package {
29    /// Create a new package from a list of hugrs.
30    ///
31    /// All the HUGRs must have a `Module` operation at the root.
32    ///
33    /// Collects the extensions used in the modules and stores them in top-level
34    /// `extensions` attribute.
35    ///
36    /// # Errors
37    ///
38    /// Returns an error if any of the HUGRs does not have a `Module` root.
39    pub fn new(modules: impl IntoIterator<Item = Hugr>) -> Result<Self, PackageError> {
40        let modules: Vec<Hugr> = modules.into_iter().collect();
41        let mut extensions = ExtensionRegistry::default();
42        for (idx, module) in modules.iter().enumerate() {
43            let root_op = module.get_optype(module.root());
44            if !root_op.is_module() {
45                return Err(PackageError::NonModuleHugr {
46                    module_index: idx,
47                    root_op: root_op.clone(),
48                });
49            }
50            extensions.extend(module.extensions());
51        }
52        Ok(Self {
53            modules,
54            extensions,
55        })
56    }
57
58    /// Create a new package from a list of hugrs.
59    ///
60    /// HUGRs that do not have a `Module` root will be wrapped in a new `Module` root,
61    /// depending on the root optype.
62    ///
63    /// - Currently all non-module roots will raise [PackageError::CannotWrapHugr].
64    ///
65    /// # Errors
66    ///
67    /// Returns an error if any of the HUGRs cannot be wrapped in a module.
68    pub fn from_hugrs(modules: impl IntoIterator<Item = Hugr>) -> Result<Self, PackageError> {
69        let modules: Vec<Hugr> = modules
70            .into_iter()
71            .map(to_module_hugr)
72            .collect::<Result<_, PackageError>>()?;
73
74        let mut extensions = ExtensionRegistry::default();
75        for module in &modules {
76            extensions.extend(module.extensions());
77        }
78
79        Ok(Self {
80            modules,
81            extensions,
82        })
83    }
84
85    /// Create a new package containing a single HUGR.
86    ///
87    /// If the Hugr is not a module, a new [OpType::Module] root will be added.
88    /// This behaviours depends on the root optype.
89    ///
90    /// - Currently all non-module roots will raise [PackageError::CannotWrapHugr].
91    ///
92    /// # Errors
93    ///
94    /// Returns an error if the hugr cannot be wrapped in a module.
95    pub fn from_hugr(hugr: Hugr) -> Result<Self, PackageError> {
96        let mut package = Self::default();
97        let module = to_module_hugr(hugr)?;
98        package.extensions = module.extensions().clone();
99        package.modules.push(module);
100        Ok(package)
101    }
102
103    /// Validate the modules of the package.
104    ///
105    /// Ensures that the top-level extension list is a superset of the extensions used in the modules.
106    pub fn validate(&self) -> Result<(), PackageValidationError> {
107        for hugr in self.modules.iter() {
108            hugr.validate()?;
109
110            let missing_exts = hugr
111                .extensions()
112                .ids()
113                .filter(|id| !self.extensions.contains(id))
114                .cloned()
115                .collect_vec();
116            if !missing_exts.is_empty() {
117                return Err(PackageValidationError::MissingExtension {
118                    missing: missing_exts,
119                    available: self.extensions.ids().cloned().collect(),
120                });
121            }
122        }
123        Ok(())
124    }
125
126    /// Read a Package from a HUGR envelope.
127    pub fn load(
128        reader: impl io::BufRead,
129        extensions: Option<&ExtensionRegistry>,
130    ) -> Result<Self, EnvelopeError> {
131        let extensions = extensions.unwrap_or(&PRELUDE_REGISTRY);
132        let (_, pkg) = read_envelope(reader, extensions)?;
133        Ok(pkg)
134    }
135
136    /// Read a Package from a HUGR envelope encoded in a string.
137    ///
138    /// Note that not all envelopes are valid strings. In the general case,
139    /// it is recommended to use `Package::load` with a bytearray instead.
140    pub fn load_str(
141        envelope: impl AsRef<str>,
142        extensions: Option<&ExtensionRegistry>,
143    ) -> Result<Self, EnvelopeError> {
144        Self::load(envelope.as_ref().as_bytes(), extensions)
145    }
146
147    /// Store the Package in a HUGR envelope.
148    pub fn store(
149        &self,
150        writer: impl io::Write,
151        config: EnvelopeConfig,
152    ) -> Result<(), EnvelopeError> {
153        write_envelope(writer, self, config)
154    }
155
156    /// Store the Package in a HUGR envelope encoded in a string.
157    ///
158    /// Note that not all envelopes are valid strings. In the general case,
159    /// it is recommended to use `Package::store` with a bytearray instead.
160    /// See [EnvelopeFormat::ascii_printable][crate::envelope::EnvelopeFormat::ascii_printable].
161    pub fn store_str(&self, config: EnvelopeConfig) -> Result<String, EnvelopeError> {
162        if !config.format.ascii_printable() {
163            return Err(EnvelopeError::NonASCIIFormat {
164                format: config.format,
165            });
166        }
167
168        let mut buf = Vec::new();
169        self.store(&mut buf, config)?;
170        Ok(String::from_utf8(buf).expect("Envelope is valid utf8"))
171    }
172
173    /// Read a Package in json format from an io reader.
174    ///
175    /// If the json encodes a single [Hugr] instead, it will be inserted in a new [Package].
176    //
177    // TODO: Make this a private method only used by the envelope reader, and remove the automatic HUGR fallback.
178    #[deprecated(
179        since = "0.14.5",
180        note = "Json encoding of packages is deprecated. Use `Package::load` instead"
181    )]
182    #[cfg_attr(coverage_nightly, coverage(off))]
183    pub fn from_json_reader(
184        reader: impl io::Read,
185        extension_registry: &ExtensionRegistry,
186    ) -> Result<Self, PackageEncodingError> {
187        let val: serde_json::Value = serde_json::from_reader(reader)?;
188
189        // Try to load a package json.
190        // Defers the extension registry loading so we can call [`ExtensionRegistry::load_json_value`] directly.
191        #[derive(Debug, serde::Deserialize)]
192        struct PackageDeser {
193            pub modules: Vec<Hugr>,
194            pub extensions: Vec<Extension>,
195        }
196        let loaded_pkg = serde_json::from_value::<PackageDeser>(val.clone());
197
198        if let Ok(PackageDeser {
199            mut modules,
200            extensions: pkg_extensions,
201        }) = loaded_pkg
202        {
203            let mut pkg_extensions = ExtensionRegistry::new_with_extension_resolution(
204                pkg_extensions,
205                &extension_registry.into(),
206            )?;
207
208            // Resolve the operations in the modules using the defined registries.
209            let mut combined_registry = extension_registry.clone();
210            combined_registry.extend(&pkg_extensions);
211
212            for module in &mut modules {
213                module.resolve_extension_defs(&combined_registry)?;
214                pkg_extensions.extend(module.extensions());
215            }
216
217            return Ok(Package {
218                modules,
219                extensions: pkg_extensions,
220            });
221        };
222        let pkg_load_err = loaded_pkg.unwrap_err();
223
224        // As a fallback, try to load a hugr json.
225        if let Ok(mut hugr) = serde_json::from_value::<Hugr>(val) {
226            hugr.resolve_extension_defs(extension_registry)?;
227            if cfg!(feature = "extension_inference") {
228                hugr.infer_extensions(false)?;
229            }
230            return Ok(Package::from_hugr(hugr)?);
231        }
232
233        // Return the original error from parsing the package.
234        Err(PackageEncodingError::JsonEncoding(pkg_load_err))
235    }
236
237    /// Read a Package from a json string.
238    ///
239    /// If the json encodes a single [Hugr] instead, it will be inserted in a new [Package].
240    #[deprecated(
241        since = "0.14.5",
242        note = "Json encoding of packages is deprecated. Use `Package::load_str` instead"
243    )]
244    #[cfg_attr(coverage_nightly, coverage(off))]
245    pub fn from_json(
246        json: impl AsRef<str>,
247        extension_registry: &ExtensionRegistry,
248    ) -> Result<Self, PackageEncodingError> {
249        #[allow(deprecated)]
250        Self::from_json_reader(json.as_ref().as_bytes(), extension_registry)
251    }
252
253    /// Read a Package from a json file.
254    ///
255    /// If the json encodes a single [Hugr] instead, it will be inserted in a new [Package].
256    #[deprecated(
257        since = "0.14.5",
258        note = "Json encoding of packages is deprecated. Use `Package::load` instead"
259    )]
260    #[cfg_attr(coverage_nightly, coverage(off))]
261    pub fn from_json_file(
262        path: impl AsRef<Path>,
263        extension_registry: &ExtensionRegistry,
264    ) -> Result<Self, PackageEncodingError> {
265        let file = fs::File::open(path)?;
266        let reader = io::BufReader::new(file);
267        #[allow(deprecated)]
268        Self::from_json_reader(reader, extension_registry)
269    }
270
271    /// Write the Package in json format into an io writer.
272    #[deprecated(
273        since = "0.14.5",
274        note = "Json encoding of packages is deprecated. Use `Package::store` instead"
275    )]
276    #[cfg_attr(coverage_nightly, coverage(off))]
277    pub fn to_json_writer(&self, writer: impl io::Write) -> Result<(), PackageEncodingError> {
278        serde_json::to_writer(writer, self)?;
279        Ok(())
280    }
281
282    /// Write the Package into a json string.
283    ///
284    /// If the json encodes a single [Hugr] instead, it will be inserted in a new [Package].
285    #[deprecated(
286        since = "0.14.5",
287        note = "Json encoding of packages is deprecated. Use `Package::store_str` instead"
288    )]
289    #[cfg_attr(coverage_nightly, coverage(off))]
290    pub fn to_json(&self) -> Result<String, PackageEncodingError> {
291        let json = serde_json::to_string(self)?;
292        Ok(json)
293    }
294
295    /// Write the Package into a json file.
296    #[deprecated(
297        since = "0.14.5",
298        note = "Json encoding of packages is deprecated. Use `Package::store` instead"
299    )]
300    #[cfg_attr(coverage_nightly, coverage(off))]
301    pub fn to_json_file(&self, path: impl AsRef<Path>) -> Result<(), PackageEncodingError> {
302        let file = fs::OpenOptions::new()
303            .write(true)
304            .truncate(true)
305            .create(true)
306            .open(path)?;
307        let writer = io::BufWriter::new(file);
308        #[allow(deprecated)]
309        self.to_json_writer(writer)
310    }
311}
312
313impl AsRef<[Hugr]> for Package {
314    fn as_ref(&self) -> &[Hugr] {
315        &self.modules
316    }
317}
318
319/// Alter an arbitrary hugr to contain an [OpType::Module] root.
320///
321/// The behaviour depends on the root optype. See [Package::from_hugr] for details.
322///
323/// # Errors
324///
325/// Returns [PackageError::]
326fn to_module_hugr(mut hugr: Hugr) -> Result<Hugr, PackageError> {
327    let root = hugr.root();
328    let root_op = hugr.get_optype(root).clone();
329    let tag = root_op.tag();
330
331    // Modules can be returned as is.
332    if root_op.is_module() {
333        return Ok(hugr);
334    }
335    // If possible, wrap the hugr directly in a module.
336    if OpTag::ModuleOp.is_superset(tag) {
337        let new_root = hugr.add_node(Module::new().into());
338        hugr.set_root(new_root);
339        hugr.set_parent(root, new_root);
340        return Ok(hugr);
341    }
342    // If it is a DFG, make it into a "main" function definition and insert it into a module.
343    if OpTag::Dfg.is_superset(tag) {
344        let signature = root_op
345            .dataflow_signature()
346            .unwrap_or_else(|| panic!("Dataflow child {} without signature", root_op.name()));
347
348        // Convert the DFG into a `FuncDefn`
349        hugr.set_num_ports(root, 0, 1);
350        hugr.replace_op(
351            root,
352            FuncDefn {
353                name: "main".to_string(),
354                signature: signature.into_owned().into(),
355            },
356        )
357        .expect("Hugr accepts any root node");
358
359        // Wrap it in a module.
360        let new_root = hugr.add_node(Module::new().into());
361        hugr.set_root(new_root);
362        hugr.set_parent(root, new_root);
363        return Ok(hugr);
364    }
365    // Wrap it in a function definition named "main" inside the module otherwise.
366    if OpTag::DataflowChild.is_superset(tag) && !root_op.is_input() && !root_op.is_output() {
367        let signature = root_op
368            .dataflow_signature()
369            .unwrap_or_else(|| panic!("Dataflow child {} without signature", root_op.name()))
370            .into_owned();
371        let mut new_hugr = ModuleBuilder::new();
372        let mut func = new_hugr.define_function("main", signature).unwrap();
373        let dataflow_node = func.add_hugr_with_wires(hugr, func.input_wires()).unwrap();
374        func.finish_with_outputs(dataflow_node.outputs()).unwrap();
375        return Ok(mem::take(new_hugr.hugr_mut()));
376    }
377    // Reject all other hugrs.
378    Err(PackageError::CannotWrapHugr {
379        root_op: root_op.clone(),
380    })
381}
382
383/// Error raised while loading a package.
384#[derive(Debug, Display, Error, PartialEq)]
385#[non_exhaustive]
386pub enum PackageError {
387    /// A hugr in the package does not have an [OpType::Module] root.
388    #[display("Module {module_index} in the package does not have an OpType::Module root, but {}", root_op.name())]
389    NonModuleHugr {
390        /// The module index.
391        module_index: usize,
392        /// The invalid root operation.
393        root_op: OpType,
394    },
395    /// Tried to initialize a package with a hugr that cannot be wrapped in a module.
396    #[display("A hugr with optype {} cannot be wrapped in a module.", root_op.name())]
397    CannotWrapHugr {
398        /// The invalid root operation.
399        root_op: OpType,
400    },
401}
402
403/// Error raised while loading a package.
404#[derive(Debug, Display, Error, From)]
405#[non_exhaustive]
406pub enum PackageEncodingError {
407    /// Error raised while parsing the package json.
408    JsonEncoding(serde_json::Error),
409    /// Error raised while reading from a file.
410    IOError(io::Error),
411    /// Improper package definition.
412    Package(PackageError),
413    /// Could not resolve the extension needed to encode the hugr.
414    ExtensionResolution(ExtensionResolutionError),
415    /// Could not resolve the runtime extensions for the hugr.
416    RuntimeExtensionResolution(ExtensionError),
417}
418
419/// Error raised while validating a package.
420#[derive(Debug, Display, From, Error)]
421#[non_exhaustive]
422pub enum PackageValidationError {
423    /// Error raised while processing the package extensions.
424    #[display("The package modules use the extension{} {} not present in the defined set. The declared extensions are {}",
425            if missing.len() > 1 {"s"} else {""},
426            missing.iter().map(|id| id.to_string()).collect::<Vec<_>>().join(", "),
427            available.iter().map(|id| id.to_string()).collect::<Vec<_>>().join(", "),
428        )]
429    MissingExtension {
430        /// The missing extensions.
431        missing: Vec<ExtensionId>,
432        /// The available extensions.
433        available: Vec<ExtensionId>,
434    },
435    /// Error raised while validating the package hugrs.
436    Validation(ValidationError),
437}
438
#[cfg(test)]
mod test {
    use cool_asserts::assert_matches;
    use rstest::{fixture, rstest};

    use super::*;
    use crate::builder::test::{
        simple_cfg_hugr, simple_dfg_hugr, simple_funcdef_hugr, simple_module_hugr,
    };
    use crate::ops::dataflow::IOTrait;
    use crate::ops::Input;

    /// A hugr rooted at a lone `Input` node, which cannot be wrapped in a module.
    #[fixture]
    fn simple_input_node() -> Hugr {
        Hugr::new(Input::new(vec![]))
    }

    /// Each hugr becomes a single-module package (compared against a snapshot),
    /// except for the cases flagged with `errors`, which must be rejected.
    #[rstest]
    #[case::module("module", simple_module_hugr(), false)]
    #[case::funcdef("funcdef", simple_funcdef_hugr(), false)]
    #[case::dfg("dfg", simple_dfg_hugr(), false)]
    #[case::cfg("cfg", simple_cfg_hugr(), false)]
    #[case::unsupported_input("input", simple_input_node(), true)]
    #[cfg_attr(miri, ignore)] // Opening files is not supported in (isolated) miri
    fn hugr_to_package(#[case] test_name: &str, #[case] hugr: Hugr, #[case] errors: bool) {
        let package = match (Package::from_hugr(hugr), errors) {
            (Err(_), true) => return,
            (Ok(package), false) => package,
            (unexpected, _) => panic!("Unexpected result {:?}", unexpected),
        };

        assert_eq!(package.modules.len(), 1);
        let module = &package.modules[0];
        assert!(module.get_optype(module.root()).is_module());

        insta::assert_snapshot!(test_name, module.mermaid_string());
    }

    #[rstest]
    fn package_properties() {
        let module = simple_module_hugr();
        let dfg = simple_dfg_hugr();

        // `Package::new` rejects the DFG in second position...
        assert_matches!(
            Package::new([module.clone(), dfg.clone()]),
            Err(PackageError::NonModuleHugr {
                module_index: 1,
                root_op: OpType::DFG(_),
            })
        );

        // ...while `Package::from_hugrs` wraps it into a module.
        let pkg = Package::from_hugrs([module, dfg]).unwrap();
        pkg.validate().unwrap();
        assert_eq!(pkg.modules.len(), 2);
    }
}