hugr_core/
envelope.rs

1//! Envelope format for HUGR packages.
2//!
3//! The format is designed to be extensible and backwards-compatible. It
4//! consists of a header declaring the format used to encode the HUGR, followed
5//! by the encoded HUGR itself.
6//!
7//! Use [`read_envelope`] and [`write_envelope`] for reading and writing
8//! envelopes from/to readers and writers, or call [`Package::load`] and
9//! [`Package::store`] directly.
10//!
11//! ## Payload formats
12//!
13//! The envelope may encode the HUGR in different formats, listed in
14//! [`EnvelopeFormat`]. The payload may also be compressed with zstd.
15//!
16//! Some formats can be represented as ASCII, as indicated by the
17//! [`EnvelopeFormat::ascii_printable`] method. When this is the case, the
18//! whole envelope can be stored in a string.
19//!
20//! ## Envelope header
21//!
22//! The binary header format is 10 bytes, with the following fields:
23//!
24//! | Field  | Size (bytes) | Description |
25//! |--------|--------------|-------------|
26//! | Magic  | 8            | [`MAGIC_NUMBERS`] constant identifying the envelope format. |
27//! | Format | 1            | [`EnvelopeFormat`] describing the payload format. |
28//! | Flags  | 1            | Additional configuration flags. |
29//!
30//! Flags:
31//!
32//! - Bit 0: Whether the payload is compressed with zstd.
33//! - Bits 1-5: Reserved for future use.
//! - Bits 7-6: Constant "01" to make some headers ASCII-printable.
35//!
36
37mod header;
38mod package_json;
39pub mod serde_with;
40
41pub use header::{EnvelopeConfig, EnvelopeFormat, MAGIC_NUMBERS, ZstdConfig};
42use hugr_model::v0::bumpalo::Bump;
43pub use package_json::PackageEncodingError;
44
45use crate::{Hugr, HugrView};
46use crate::{
47    extension::{ExtensionRegistry, Version},
48    package::Package,
49};
50use header::EnvelopeHeader;
51use std::io::BufRead;
52use std::io::Write;
53use std::str::FromStr;
54use thiserror::Error;
55
56#[allow(unused_imports)]
57use itertools::Itertools as _;
58
59use crate::import::ImportError;
60use crate::{Extension, import::import_package};
61
62/// Key used to store the name of the generator that produced the envelope.
63pub const GENERATOR_KEY: &str = "core.generator";
64/// Key used to store the list of used extensions in the metadata of a HUGR.
65pub const USED_EXTENSIONS_KEY: &str = "core.used_extensions";
66
67/// Get the name of the generator from the metadata of the HUGR modules.
68///
69/// If multiple modules have different generators, a comma-separated list is returned in
70/// module order.
71/// If no generator is found, `None` is returned.
72pub fn get_generator<H: HugrView>(modules: &[H]) -> Option<String> {
73    let generators: Vec<String> = modules
74        .iter()
75        .filter_map(|hugr| hugr.get_metadata(hugr.module_root(), GENERATOR_KEY))
76        .map(format_generator)
77        .collect();
78    if generators.is_empty() {
79        return None;
80    }
81
82    Some(generators.join(", "))
83}
84
85/// Format a generator value from the metadata.
86pub fn format_generator(json_val: &serde_json::Value) -> String {
87    match json_val {
88        serde_json::Value::String(s) => s.clone(),
89        serde_json::Value::Object(obj) => {
90            if let (Some(name), version) = (
91                obj.get("name").and_then(|v| v.as_str()),
92                obj.get("version").and_then(|v| v.as_str()),
93            ) {
94                if let Some(version) = version {
95                    // Expected format: {"name": "generator", "version": "1.0.0"}
96                    format!("{name}-v{version}")
97                } else {
98                    name.to_string()
99                }
100            } else {
101                // just print the whole object as a string
102                json_val.to_string()
103            }
104        }
105        // Raw JSON string fallback
106        _ => json_val.to_string(),
107    }
108}
109
/// Build the suffix that error messages use to name the envelope's generator.
///
/// Yields `"\ngenerated by {g}"` when a generator is known, or an empty string otherwise.
fn gen_str(generator: &Option<String>) -> String {
    generator
        .as_deref()
        .map_or_else(String::new, |name| format!("\ngenerated by {name}"))
}
116
117/// Wrap an error with a generator string.
118#[derive(Error, Debug)]
119#[error("{inner}{}", gen_str(&self.generator))]
120pub struct WithGenerator<E: std::fmt::Display> {
121    inner: Box<E>,
122    /// The name of the generator that produced the envelope, if any.
123    generator: Option<String>,
124}
125
126impl<E: std::fmt::Display> WithGenerator<E> {
127    fn new(err: E, modules: &[impl HugrView]) -> Self {
128        Self {
129            inner: Box::new(err),
130            generator: get_generator(modules),
131        }
132    }
133
134    /// Get a reference to the inner error.
135    pub fn inner(&self) -> &E {
136        &self.inner
137    }
138
139    /// Get the name of the generator that produced the envelope, if any.
140    pub fn generator(&self) -> Option<&String> {
141        self.generator.as_ref()
142    }
143}
144
145/// Read a HUGR envelope from a reader.
146///
147/// Returns the deserialized package and the configuration used to encode it.
148///
149/// Parameters:
150/// - `reader`: The reader to read the envelope from.
151/// - `registry`: An extension registry with additional extensions to use when
152///   decoding the HUGR, if they are not already included in the package.
153pub fn read_envelope(
154    mut reader: impl BufRead,
155    registry: &ExtensionRegistry,
156) -> Result<(EnvelopeConfig, Package), EnvelopeError> {
157    let header = EnvelopeHeader::read(&mut reader)?;
158
159    let package = match header.zstd {
160        #[cfg(feature = "zstd")]
161        true => read_impl(
162            std::io::BufReader::new(zstd::Decoder::new(reader)?),
163            header,
164            registry,
165        ),
166        #[cfg(not(feature = "zstd"))]
167        true => Err(EnvelopeError::ZstdUnsupported),
168        false => read_impl(reader, header, registry),
169    }?;
170    Ok((header.config(), package))
171}
172
173/// Write a HUGR package into an envelope, using the specified configuration.
174///
175/// It is recommended to use a buffered writer for better performance.
176/// See [`std::io::BufWriter`] for more information.
177pub fn write_envelope(
178    writer: impl Write,
179    package: &Package,
180    config: EnvelopeConfig,
181) -> Result<(), EnvelopeError> {
182    write_envelope_impl(writer, &package.modules, &package.extensions, config)
183}
184
185/// Write a deconstructed HUGR package into an envelope, using the specified configuration.
186///
187/// It is recommended to use a buffered writer for better performance.
188/// See [`std::io::BufWriter`] for more information.
189pub(crate) fn write_envelope_impl<'h>(
190    mut writer: impl Write,
191    hugrs: impl IntoIterator<Item = &'h Hugr>,
192    extensions: &ExtensionRegistry,
193    config: EnvelopeConfig,
194) -> Result<(), EnvelopeError> {
195    let header = config.make_header();
196    header.write(&mut writer)?;
197
198    match config.zstd {
199        #[cfg(feature = "zstd")]
200        Some(zstd) => {
201            let writer = zstd::Encoder::new(writer, zstd.level())?.auto_finish();
202            write_impl(writer, hugrs, extensions, config)?;
203        }
204        #[cfg(not(feature = "zstd"))]
205        Some(_) => return Err(EnvelopeError::ZstdUnsupported),
206        None => write_impl(writer, hugrs, extensions, config)?,
207    }
208
209    Ok(())
210}
211
212/// Error type for envelope operations.
213#[derive(Debug, Error)]
214#[non_exhaustive]
215pub enum EnvelopeError {
216    /// Bad magic number.
217    #[error(
218        "Bad magic number. expected 0x{:X} found 0x{:X}",
219        u64::from_be_bytes(*expected),
220        u64::from_be_bytes(*found)
221    )]
222    MagicNumber {
223        /// The expected magic number.
224        ///
225        /// See [`MAGIC_NUMBERS`].
226        expected: [u8; 8],
227        /// The magic number in the envelope.
228        found: [u8; 8],
229    },
230    /// The specified payload format is invalid.
231    #[error("Format descriptor {descriptor} is invalid.")]
232    InvalidFormatDescriptor {
233        /// The unsupported format.
234        descriptor: usize,
235    },
236    /// The specified payload format is not supported.
237    #[error("Payload format {format} is not supported.{}",
238        match feature {
239            Some(f) => format!(" This requires the '{f}' feature for `hugr`."),
240            None => String::new()
241        },
242    )]
243    FormatUnsupported {
244        /// The unsupported format.
245        format: EnvelopeFormat,
246        /// Optionally, the feature required to support this format.
247        feature: Option<&'static str>,
248    },
249    /// Not all envelope formats can be represented as ASCII.
250    ///
251    /// This error is used when trying to store the envelope into a string.
252    #[error("Envelope format {format} cannot be represented as ASCII.")]
253    NonASCIIFormat {
254        /// The unsupported format.
255        format: EnvelopeFormat,
256    },
257    /// Envelope encoding required zstd compression, but the feature is not enabled.
258    #[error("Zstd compression is not supported. This requires the 'zstd' feature for `hugr`.")]
259    ZstdUnsupported,
260    /// Expected the envelope to contain a single HUGR.
261    #[error("Expected an envelope containing a single hugr, but it contained {}.", if *count == 0 {
262        "none".to_string()
263    } else {
264        count.to_string()
265    })]
266    ExpectedSingleHugr {
267        /// The number of HUGRs in the package.
268        count: usize,
269    },
270    /// JSON serialization error.
271    #[error(transparent)]
272    SerdeError {
273        /// The source error.
274        #[from]
275        source: serde_json::Error,
276    },
277    /// IO read/write error.
278    #[error(transparent)]
279    IO {
280        /// The source error.
281        #[from]
282        source: std::io::Error,
283    },
284    /// Error writing a json package to the payload.
285    #[error(transparent)]
286    PackageEncoding {
287        /// The source error.
288        #[from]
289        source: PackageEncodingError,
290    },
291    /// Error importing a HUGR from a hugr-model payload.
292    #[error(transparent)]
293    ModelImport {
294        /// The source error.
295        #[from]
296        source: ImportError,
297        // TODO add generator to model import errors
298    },
299    /// Error reading a HUGR model payload.
300    #[error(transparent)]
301    ModelRead {
302        /// The source error.
303        #[from]
304        source: hugr_model::v0::binary::ReadError,
305    },
306    /// Error writing a HUGR model payload.
307    #[error(transparent)]
308    ModelWrite {
309        /// The source error.
310        #[from]
311        source: hugr_model::v0::binary::WriteError,
312    },
313    /// Error reading a HUGR model payload.
314    #[error("Model text parsing error")]
315    ModelTextRead {
316        /// The source error.
317        #[from]
318        source: hugr_model::v0::ast::ParseError,
319    },
320    /// Error reading a HUGR model payload.
321    #[error(transparent)]
322    ModelTextResolve {
323        /// The source error.
324        #[from]
325        source: hugr_model::v0::ast::ResolveError,
326    },
327    /// Error reading a list of extensions from the envelope.
328    #[error(transparent)]
329    ExtensionLoad {
330        /// The source error.
331        #[from]
332        source: crate::extension::ExtensionRegistryLoadError,
333    },
334    /// The specified payload format is not supported.
335    #[error(
336        "The envelope configuration has unknown {}. Please update your HUGR version.",
337        if flag_ids.len() == 1 {format!("flag #{}", flag_ids[0])} else {format!("flags {}", flag_ids.iter().join(", "))}
338    )]
339    FlagUnsupported {
340        /// The unrecognized flag bits.
341        flag_ids: Vec<usize>,
342    },
343    /// Error raised while checking for breaking extension version mismatch.
344    #[error(transparent)]
345    ExtensionVersion {
346        /// The source error.
347        #[from]
348        source: WithGenerator<ExtensionBreakingError>,
349    },
350}
351
352/// Internal implementation of [`read_envelope`] to call with/without the zstd decompression wrapper.
353fn read_impl(
354    payload: impl BufRead,
355    header: EnvelopeHeader,
356    registry: &ExtensionRegistry,
357) -> Result<Package, EnvelopeError> {
358    let package = match header.format {
359        EnvelopeFormat::PackageJson => Ok(package_json::from_json_reader(payload, registry)?),
360        EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => {
361            decode_model(payload, registry, header.format)
362        }
363        EnvelopeFormat::ModelText | EnvelopeFormat::ModelTextWithExtensions => {
364            decode_model_ast(payload, registry, header.format)
365        }
366    }?;
367
368    package.modules.iter().try_for_each(|module| {
369        check_breaking_extensions(module).map_err(|err| WithGenerator::new(err, &package.modules))
370    })?;
371    Ok(package)
372}
373
374/// Read a HUGR model payload from a reader.
375///
376/// Parameters:
377/// - `stream`: The reader to read the envelope from.
378/// - `extension_registry`: An extension registry with additional extensions to use when
379///   decoding the HUGR, if they are not already included in the package.
380/// - `format`: The format of the payload.
381///
382/// Returns package and the combined extension registry
383/// of the provided registry and the package extensions.
384fn decode_model(
385    mut stream: impl BufRead,
386    extension_registry: &ExtensionRegistry,
387    format: EnvelopeFormat,
388) -> Result<Package, EnvelopeError> {
389    check_model_version(format)?;
390    let bump = Bump::default();
391    let model_package = hugr_model::v0::binary::read_from_reader(&mut stream, &bump)?;
392
393    let packaged_extensions = if format == EnvelopeFormat::ModelWithExtensions {
394        ExtensionRegistry::load_json(stream, extension_registry)?
395    } else {
396        ExtensionRegistry::new([])
397    };
398
399    let package = import_package(&model_package, packaged_extensions, extension_registry)?;
400
401    Ok(package)
402}
403
404fn check_model_version(format: EnvelopeFormat) -> Result<(), EnvelopeError> {
405    if format.model_version() != Some(0) {
406        return Err(EnvelopeError::FormatUnsupported {
407            format,
408            feature: None,
409        });
410    }
411    Ok(())
412}
413
414/// Read a HUGR model text payload from a reader.
415///
416/// Parameters:
417/// - `stream`: The reader to read the envelope from.
418/// - `extension_registry`: An extension registry with additional extensions to use when
419///   decoding the HUGR, if they are not already included in the package.
420/// - `format`: The format of the payload.
421fn decode_model_ast(
422    mut stream: impl BufRead,
423    extension_registry: &ExtensionRegistry,
424    format: EnvelopeFormat,
425) -> Result<Package, EnvelopeError> {
426    check_model_version(format)?;
427
428    let packaged_extensions = if format == EnvelopeFormat::ModelTextWithExtensions {
429        let deserializer = serde_json::Deserializer::from_reader(&mut stream);
430        // Deserialize the first json object, leaving the rest of the reader unconsumed.
431        let extra_extensions = deserializer
432            .into_iter::<Vec<Extension>>()
433            .next()
434            .unwrap_or(Ok(vec![]))?;
435        ExtensionRegistry::new(extra_extensions.into_iter().map(std::sync::Arc::new))
436    } else {
437        ExtensionRegistry::new([])
438    };
439
440    // Read the package into a string, then parse it.
441    //
442    // Due to how `to_string` works, we cannot append extensions after the package.
443    let mut buffer = String::new();
444    stream.read_to_string(&mut buffer)?;
445    let ast_package = hugr_model::v0::ast::Package::from_str(&buffer)?;
446
447    let bump = Bump::default();
448    let model_package = ast_package.resolve(&bump)?;
449
450    let package = import_package(&model_package, packaged_extensions, extension_registry)?;
451
452    Ok(package)
453}
454
455/// Internal implementation of [`write_envelope`] to call with/without the zstd compression wrapper.
456fn write_impl<'h>(
457    writer: impl Write,
458    hugrs: impl IntoIterator<Item = &'h Hugr>,
459    extensions: &ExtensionRegistry,
460    config: EnvelopeConfig,
461) -> Result<(), EnvelopeError> {
462    match config.format {
463        EnvelopeFormat::PackageJson => package_json::to_json_writer(hugrs, extensions, writer)?,
464        EnvelopeFormat::Model
465        | EnvelopeFormat::ModelWithExtensions
466        | EnvelopeFormat::ModelText
467        | EnvelopeFormat::ModelTextWithExtensions => {
468            encode_model(writer, hugrs, extensions, config.format)?;
469        }
470    }
471    Ok(())
472}
473
474fn encode_model<'h>(
475    mut writer: impl Write,
476    hugrs: impl IntoIterator<Item = &'h Hugr>,
477    extensions: &ExtensionRegistry,
478    format: EnvelopeFormat,
479) -> Result<(), EnvelopeError> {
480    use hugr_model::v0::{binary::write_to_writer, bumpalo::Bump};
481
482    use crate::export::export_package;
483
484    if format.model_version() != Some(0) {
485        return Err(EnvelopeError::FormatUnsupported {
486            format,
487            feature: None,
488        });
489    }
490
491    // Prepend extensions for binary model.
492    if format == EnvelopeFormat::ModelTextWithExtensions {
493        serde_json::to_writer(&mut writer, &extensions.iter().collect_vec())?;
494    }
495
496    let bump = Bump::default();
497    let model_package = export_package(hugrs, extensions, &bump);
498
499    match format {
500        EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => {
501            write_to_writer(&model_package, &mut writer)?;
502        }
503        EnvelopeFormat::ModelText | EnvelopeFormat::ModelTextWithExtensions => {
504            let model_package = model_package.as_ast().unwrap();
505            writeln!(writer, "{model_package}")?;
506        }
507        _ => unreachable!(),
508    }
509
510    // Append extensions for binary model.
511    if format == EnvelopeFormat::ModelWithExtensions {
512        serde_json::to_writer(writer, &extensions.iter().collect_vec())?;
513    }
514
515    Ok(())
516}
517
518#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize)]
519struct UsedExtension {
520    name: String,
521    version: Version,
522}
523
524#[derive(Debug, Error)]
525#[error(
526    "Extension '{name}' version mismatch: registered version is {registered}, but used version is {used}"
527)]
528/// Error raised when the reported used version of an extension
529/// does not match the registered version in the extension registry.
530pub struct ExtensionVersionMismatch {
531    /// The name of the extension.
532    pub name: String,
533    /// The registered version of the extension in the loaded registry.
534    pub registered: Version,
535    /// The version of the extension reported as used in the HUGR metadata.
536    pub used: Version,
537}
538
539#[derive(Debug, Error)]
540#[non_exhaustive]
541/// Error raised when checking for breaking changes in used extensions.
542pub enum ExtensionBreakingError {
543    /// The extension version in the metadata does not match the registered version.
544    #[error("{0}")]
545    ExtensionVersionMismatch(ExtensionVersionMismatch),
546
547    /// Error deserializing the used extensions metadata.
548    #[error("Failed to deserialize used extensions metadata")]
549    Deserialization(#[from] serde_json::Error),
550}
551/// If HUGR metadata contains a list of used extensions, under the key [`USED_EXTENSIONS_KEY`],
552/// and extension is resolved in the HUGR, check that the
553/// version of the extension in the metadata matches the resolved version.
554/// Version compatibility is defined by [`compatible_versions`].
555fn check_breaking_extensions(hugr: impl crate::HugrView) -> Result<(), ExtensionBreakingError> {
556    check_breaking_extensions_against_registry(&hugr, hugr.extensions())
557}
558
559/// If HUGR metadata contains a list of used extensions, under the key [`USED_EXTENSIONS_KEY`],
560/// and extension is registered in the given registry, check that the
561/// version of the extension in the metadata matches the registered version.
562/// Version compatibility is defined by [`compatible_versions`].
563fn check_breaking_extensions_against_registry(
564    hugr: &impl crate::HugrView,
565    registry: &ExtensionRegistry,
566) -> Result<(), ExtensionBreakingError> {
567    let Some(exts) = hugr.get_metadata(hugr.module_root(), USED_EXTENSIONS_KEY) else {
568        return Ok(()); // No used extensions metadata, nothing to check
569    };
570    let used_exts: Vec<UsedExtension> = serde_json::from_value(exts.clone())?; // TODO handle errors properly
571
572    for ext in used_exts {
573        let Some(registered) = registry.get(ext.name.as_str()) else {
574            continue; // Extension not registered, ignore
575        };
576        if !compatible_versions(registered.version(), &ext.version) {
577            // This is a breaking change, raise an error.
578
579            return Err(ExtensionBreakingError::ExtensionVersionMismatch(
580                ExtensionVersionMismatch {
581                    name: ext.name,
582                    registered: registered.version().clone(),
583                    used: ext.version,
584                },
585            ));
586        }
587    }
588
589    Ok(())
590}
591
592/// Check if two versions are compatible according to:
593/// - Major version must match.
594/// - If major version is 0, minor version must match.
595/// - The registered version must be greater than or equal to the used version.
596fn compatible_versions(registered: &Version, used: &Version) -> bool {
597    if used.major != registered.major {
598        return false;
599    }
600    if used.major == 0 && used.minor != registered.minor {
601        return false;
602    }
603
604    registered >= used
605}
606
607#[cfg(test)]
608pub(crate) mod test {
609    use super::*;
610    use cool_asserts::assert_matches;
611    use rstest::rstest;
612    use std::borrow::Cow;
613    use std::io::BufReader;
614
615    use crate::HugrView;
616    use crate::builder::test::{multi_module_package, simple_package};
617    use crate::extension::{Extension, ExtensionRegistry, Version};
618    use crate::extension::{ExtensionId, PRELUDE_REGISTRY};
619    use crate::hugr::HugrMut;
620    use crate::hugr::test::check_hugr_equality;
621    use crate::std_extensions::STD_REG;
622    use serde_json::json;
623    use std::sync::Arc;
624
625    /// Returns an `ExtensionRegistry` with the extensions from both
626    /// sets. Avoids cloning if the first one already contains all
627    /// extensions from the second one.
628    fn join_extensions<'a>(
629        extensions: &'a ExtensionRegistry,
630        other: &ExtensionRegistry,
631    ) -> Cow<'a, ExtensionRegistry> {
632        if other.iter().all(|e| extensions.contains(e.name())) {
633            Cow::Borrowed(extensions)
634        } else {
635            let mut extensions = extensions.clone();
636            extensions.extend(other);
637            Cow::Owned(extensions)
638        }
639    }
640
641    /// Serialize and deserialize a HUGR into an envelope with the given config,
642    /// and check that the result is the same as the original.
643    ///
644    /// We do not compare the before and after `Hugr`s for equality directly,
645    /// because impls of `CustomConst` are not required to implement equality
646    /// checking.
647    ///
648    /// Returns the deserialized HUGR.
649    pub(crate) fn check_hugr_roundtrip(hugr: &Hugr, config: EnvelopeConfig) -> Hugr {
650        let mut buffer = Vec::new();
651        hugr.store(&mut buffer, config).unwrap();
652
653        let extensions = join_extensions(&STD_REG, hugr.extensions());
654
655        let reader = BufReader::new(buffer.as_slice());
656        let extracted = Hugr::load(reader, Some(&extensions)).unwrap();
657
658        check_hugr_equality(&extracted, hugr);
659        extracted
660    }
661
662    #[rstest]
663    fn errors() {
664        let package = simple_package();
665        assert_matches!(
666            package.store_str(EnvelopeConfig::binary()),
667            Err(EnvelopeError::NonASCIIFormat { .. })
668        );
669    }
670
671    #[rstest]
672    #[case::empty(Package::default())]
673    #[case::simple(simple_package())]
674    #[case::multi(multi_module_package())]
675    fn text_roundtrip(#[case] package: Package) {
676        let envelope = package.store_str(EnvelopeConfig::text()).unwrap();
677        let new_package = Package::load_str(&envelope, None).unwrap();
678        assert_eq!(package, new_package);
679    }
680
681    #[rstest]
682    #[case::empty(Package::default())]
683    #[case::simple(simple_package())]
684    #[case::multi(multi_module_package())]
685    #[cfg_attr(all(miri, feature = "zstd"), ignore)] // FFI calls (required to compress with zstd) are not supported in miri
686    fn compressed_roundtrip(#[case] package: Package) {
687        let mut buffer = Vec::new();
688        let config = EnvelopeConfig {
689            format: EnvelopeFormat::PackageJson,
690            zstd: Some(ZstdConfig::default()),
691        };
692        let res = package.store(&mut buffer, config);
693
694        match cfg!(feature = "zstd") {
695            true => res.unwrap(),
696            false => {
697                assert_matches!(res, Err(EnvelopeError::ZstdUnsupported));
698                return;
699            }
700        }
701
702        let (decoded_config, new_package) =
703            read_envelope(BufReader::new(buffer.as_slice()), &PRELUDE_REGISTRY).unwrap();
704
705        assert_eq!(config.format, decoded_config.format);
706        assert_eq!(config.zstd.is_some(), decoded_config.zstd.is_some());
707        assert_eq!(package, new_package);
708    }
709
710    #[rstest]
711    // Empty packages
712    #[case::empty_model(Package::default(), EnvelopeFormat::Model)]
713    #[case::empty_model_exts(Package::default(), EnvelopeFormat::ModelWithExtensions)]
714    #[case::empty_text(Package::default(), EnvelopeFormat::ModelText)]
715    #[case::empty_text_exts(Package::default(), EnvelopeFormat::ModelTextWithExtensions)]
716    // Single hugrs
717    #[case::simple_bin(simple_package(), EnvelopeFormat::Model)]
718    #[case::simple_bin_exts(simple_package(), EnvelopeFormat::ModelWithExtensions)]
719    #[case::simple_text(simple_package(), EnvelopeFormat::ModelText)]
720    #[case::simple_text_exts(simple_package(), EnvelopeFormat::ModelTextWithExtensions)]
721    // Multiple hugrs
722    #[case::multi_bin(multi_module_package(), EnvelopeFormat::Model)]
723    #[case::multi_bin_exts(multi_module_package(), EnvelopeFormat::ModelWithExtensions)]
724    #[case::multi_text(multi_module_package(), EnvelopeFormat::ModelText)]
725    #[case::multi_text_exts(multi_module_package(), EnvelopeFormat::ModelTextWithExtensions)]
726    fn model_roundtrip(#[case] package: Package, #[case] format: EnvelopeFormat) {
727        let mut buffer = Vec::new();
728        let config = EnvelopeConfig { format, zstd: None };
729        package.store(&mut buffer, config).unwrap();
730
731        let (decoded_config, new_package) =
732            read_envelope(BufReader::new(buffer.as_slice()), &PRELUDE_REGISTRY).unwrap();
733
734        assert_eq!(config.format, decoded_config.format);
735        assert_eq!(config.zstd.is_some(), decoded_config.zstd.is_some());
736
737        assert_eq!(package, new_package);
738    }
739
740    /// Test helper to call `check_breaking_extensions_against_registry`
741    fn check(hugr: &Hugr, registry: &ExtensionRegistry) -> Result<(), ExtensionBreakingError> {
742        check_breaking_extensions_against_registry(hugr, registry)
743    }
744
745    #[rstest]
746    #[case::simple(simple_package())]
747    fn test_check_breaking_extensions(#[case] mut package: Package) {
748        // extension with major version 0
749        let test_ext_v0 =
750            Extension::new(ExtensionId::new_unchecked("test-v0"), Version::new(0, 2, 3));
751        //  extension with major version > 0
752        let test_ext_v1 =
753            Extension::new(ExtensionId::new_unchecked("test-v1"), Version::new(1, 2, 3));
754
755        // Create a registry with the test extensions
756        let registry =
757            ExtensionRegistry::new([Arc::new(test_ext_v0.clone()), Arc::new(test_ext_v1.clone())]);
758        let mut hugr = package.modules.remove(0);
759
760        // No metadata - should pass
761        assert_matches!(check(&hugr, &registry), Ok(()));
762
763        // Matching version for v0 - should pass
764        let used_exts = json!([{ "name": "test-v0", "version": "0.2.3" }]);
765        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
766        assert_matches!(check(&hugr, &registry), Ok(()));
767
768        // Matching major/minor but lower patch for v0 - should pass
769        let used_exts = json!([{ "name": "test-v0", "version": "0.2.2" }]);
770        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
771        assert_matches!(check(&hugr, &registry), Ok(()));
772
773        //Different minor version for v0 - should fail
774        let used_exts = json!([{ "name": "test-v0", "version": "0.3.3" }]);
775        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
776        assert_matches!(
777            check(&hugr, &registry),
778            Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
779                name,
780                registered,
781                used
782            })) if name == "test-v0" && registered == Version::new(0, 2, 3) && used == Version::new(0, 3, 3)
783        );
784
785        assert!(
786            check_breaking_extensions(&hugr).is_ok(),
787            "Extension is not actually used in the HUGR, should be ignored by full check"
788        );
789
790        // Different major version for v0 - should fail
791        let used_exts = json!([{ "name": "test-v0", "version": "1.2.3" }]);
792        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
793        assert_matches!(
794            check(&hugr, &registry),
795            Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
796                name,
797                registered,
798                used
799            })) if name == "test-v0" && registered == Version::new(0, 2, 3) && used == Version::new(1, 2, 3)
800        );
801
802        // Higher patch version for v0 - should fail
803        let used_exts = json!([{ "name": "test-v0", "version": "0.2.4" }]);
804        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
805        assert_matches!(
806            check(&hugr, &registry),
807            Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
808                name,
809                registered,
810                used
811            })) if name == "test-v0" && registered == Version::new(0, 2, 3) && used == Version::new(0, 2, 4)
812        );
813
814        // Matching version for v1 - should pass
815        let used_exts = json!([{ "name": "test-v1", "version": "1.2.3" }]);
816        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
817        assert_matches!(check(&hugr, &registry), Ok(()));
818
819        // Lower minor version for v1 - should pass
820        let used_exts = json!([{ "name": "test-v1", "version": "1.1.0" }]);
821        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
822        assert_matches!(check(&hugr, &registry), Ok(()));
823
824        // Lower patch for v1 - should pass
825        let used_exts = json!([{ "name": "test-v1", "version": "1.2.2" }]);
826        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
827        assert_matches!(check(&hugr, &registry), Ok(()));
828
829        // Different major version for v1 - should fail
830        let used_exts = json!([{ "name": "test-v1", "version": "2.2.3" }]);
831        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
832        assert_matches!(
833            check(&hugr, &registry),
834            Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
835                name,
836                registered,
837                used
838            })) if name == "test-v1" && registered == Version::new(1, 2, 3) && used == Version::new(2, 2, 3)
839        );
840
841        // Higher minor version for v1 - should fail
842        let used_exts = json!([{ "name": "test-v1", "version": "1.3.0" }]);
843        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
844        assert_matches!(
845            check(&hugr, &registry),
846            Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
847                name,
848                registered,
849                used
850            })) if name == "test-v1" && registered == Version::new(1, 2, 3) && used == Version::new(1, 3, 0)
851        );
852
853        // Higher patch version for v1 - should fail
854        let used_exts = json!([{ "name": "test-v1", "version": "1.2.4" }]);
855        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
856        assert_matches!(
857            check(&hugr, &registry),
858            Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
859                name,
860                registered,
861                used
862            })) if name == "test-v1" && registered == Version::new(1, 2, 3) && used == Version::new(1, 2, 4)
863        );
864
865        // Non-registered extension - should pass
866        let used_exts = json!([{ "name": "unknown", "version": "1.0.0" }]);
867        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
868        assert_matches!(check(&hugr, &registry), Ok(()));
869
870        // Multiple extensions - one mismatch should fail
871        let used_exts = json!([
872            { "name": "unknown", "version": "1.0.0" },
873            { "name": "test-v1", "version": "2.0.0" }
874        ]);
875        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
876        assert_matches!(
877            check(&hugr, &registry),
878            Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
879                name,
880                registered,
881                used
882            })) if name == "test-v1" && registered == Version::new(1, 2, 3) && used == Version::new(2, 0, 0)
883        );
884
885        // Invalid metadata format - should fail with deserialization error
886        hugr.set_metadata(
887            hugr.module_root(),
888            USED_EXTENSIONS_KEY,
889            json!("not an array"),
890        );
891        assert_matches!(
892            check(&hugr, &registry),
893            Err(ExtensionBreakingError::Deserialization(_))
894        );
895
896        //  Multiple extensions with all compatible versions - should pass
897        let used_exts = json!([
898            { "name": "test-v0", "version": "0.2.2" },
899            { "name": "test-v1", "version": "1.1.9" }
900        ]);
901        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
902        assert_matches!(check(&hugr, &registry), Ok(()));
903    }
904
905    #[test]
906    fn test_with_generator_error_message() {
907        let test_ext = Extension::new(ExtensionId::new_unchecked("test"), Version::new(1, 0, 0));
908        let registry = ExtensionRegistry::new([Arc::new(test_ext)]);
909
910        let mut hugr = simple_package().modules.remove(0);
911
912        // Set a generator name in the metadata
913        let generator_name = json!({ "name": "TestGenerator", "version": "1.2.3" });
914        hugr.set_metadata(hugr.module_root(), GENERATOR_KEY, generator_name.clone());
915
916        // Set incompatible extension version in metadata
917        let used_exts = json!([{ "name": "test", "version": "2.0.0" }]);
918        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
919
920        // Create the error and wrap it with WithGenerator
921        let err = check_breaking_extensions_against_registry(&hugr, &registry).unwrap_err();
922        let with_gen = WithGenerator::new(err, &[&hugr]);
923
924        let err_msg = with_gen.to_string();
925        assert!(err_msg.contains("Extension 'test' version mismatch"));
926        assert!(err_msg.contains("TestGenerator-v1.2.3"));
927    }
928}