1#![allow(deprecated)]
38mod header;
44mod package_json;
45pub mod serde_with;
46
47pub use header::{EnvelopeConfig, EnvelopeFormat, MAGIC_NUMBERS, ZstdConfig};
48pub use package_json::PackageEncodingError;
49
50use crate::{Hugr, HugrView};
51use crate::{
52 extension::{ExtensionRegistry, Version},
53 package::Package,
54};
55use header::EnvelopeHeader;
56use std::io::BufRead;
57use std::io::Write;
58use std::str::FromStr;
59use thiserror::Error;
60
61#[allow(unused_imports)]
62use itertools::Itertools as _;
63
64use crate::import::ImportError;
65use crate::{Extension, import::import_package};
66
/// Metadata key on a module root holding a description of the generator of
/// the HUGR. When present, it is appended to some error messages as
/// `generated by ...`.
pub const GENERATOR_KEY: &str = "core.generator";
/// Metadata key on a module root holding a JSON array of `{ "name", "version" }`
/// entries describing the extensions the HUGR was built with. These are
/// checked against the registry for version compatibility.
pub const USED_EXTENSIONS_KEY: &str = "core.used_extensions";
71
72fn get_generator<H: HugrView>(modules: &[H]) -> Option<String> {
77 let generators: Vec<String> = modules
78 .iter()
79 .filter_map(|hugr| hugr.get_metadata(hugr.module_root(), GENERATOR_KEY))
80 .map(|v| v.to_string())
81 .collect();
82 if generators.is_empty() {
83 return None;
84 }
85
86 Some(generators.join(", "))
87}
88
/// Render the optional generator as an error-message suffix.
///
/// Produces `"\ngenerated by <g>"` when a generator is present, and an empty
/// string otherwise.
fn gen_str(generator: &Option<String>) -> String {
    generator
        .as_deref()
        .map(|g| format!("\ngenerated by {g}"))
        .unwrap_or_default()
}
95
/// Error wrapper that annotates an inner error with the generator(s) of the
/// HUGR modules involved.
///
/// The `Display` output is the inner error's message. When generator
/// metadata was found, it is followed by a new line of the form
/// `generated by <generators>`.
#[derive(Error, Debug)]
#[error("{inner}{}", gen_str(&self.generator))]
pub struct WithGenerator<E: std::fmt::Display> {
    /// The wrapped error.
    inner: Box<E>,
    /// Comma-separated generator metadata of the modules, if any module had it.
    generator: Option<String>,
}
104
105impl<E: std::fmt::Display> WithGenerator<E> {
106 fn new(err: E, modules: &[impl HugrView]) -> Self {
107 Self {
108 inner: Box::new(err),
109 generator: get_generator(modules),
110 }
111 }
112}
113
114pub fn read_envelope(
123 mut reader: impl BufRead,
124 registry: &ExtensionRegistry,
125) -> Result<(EnvelopeConfig, Package), EnvelopeError> {
126 let header = EnvelopeHeader::read(&mut reader)?;
127
128 let package = match header.zstd {
129 #[cfg(feature = "zstd")]
130 true => read_impl(
131 std::io::BufReader::new(zstd::Decoder::new(reader)?),
132 header,
133 registry,
134 ),
135 #[cfg(not(feature = "zstd"))]
136 true => Err(EnvelopeError::ZstdUnsupported),
137 false => read_impl(reader, header, registry),
138 }?;
139 Ok((header.config(), package))
140}
141
142pub fn write_envelope(
147 writer: impl Write,
148 package: &Package,
149 config: EnvelopeConfig,
150) -> Result<(), EnvelopeError> {
151 write_envelope_impl(writer, &package.modules, &package.extensions, config)
152}
153
154pub(crate) fn write_envelope_impl<'h>(
159 mut writer: impl Write,
160 hugrs: impl IntoIterator<Item = &'h Hugr>,
161 extensions: &ExtensionRegistry,
162 config: EnvelopeConfig,
163) -> Result<(), EnvelopeError> {
164 let header = config.make_header();
165 header.write(&mut writer)?;
166
167 match config.zstd {
168 #[cfg(feature = "zstd")]
169 Some(zstd) => {
170 let writer = zstd::Encoder::new(writer, zstd.level())?.auto_finish();
171 write_impl(writer, hugrs, extensions, config)?;
172 }
173 #[cfg(not(feature = "zstd"))]
174 Some(_) => return Err(EnvelopeError::ZstdUnsupported),
175 None => write_impl(writer, hugrs, extensions, config)?,
176 }
177
178 Ok(())
179}
180
/// Errors that can occur while reading or writing an envelope.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum EnvelopeError {
    /// The envelope did not start with the expected magic number.
    #[error(
        "Bad magic number. expected 0x{:X} found 0x{:X}",
        u64::from_be_bytes(*expected),
        u64::from_be_bytes(*found)
    )]
    MagicNumber {
        /// The expected magic number bytes (rendered big-endian in the message).
        expected: [u8; 8],
        /// The bytes actually found (rendered big-endian in the message).
        found: [u8; 8],
    },
    /// The format descriptor read from the envelope is invalid.
    #[error("Format descriptor {descriptor} is invalid.")]
    InvalidFormatDescriptor {
        /// The raw descriptor value.
        descriptor: usize,
    },
    /// The payload format is not supported, e.g. an unsupported model
    /// version or a format gated behind a disabled cargo feature.
    #[error("Payload format {format} is not supported.{}",
        match feature {
            Some(f) => format!(" This requires the '{f}' feature for `hugr`."),
            None => String::new()
        },
    )]
    FormatUnsupported {
        /// The unsupported format.
        format: EnvelopeFormat,
        /// The `hugr` cargo feature that would enable this format, if any.
        feature: Option<&'static str>,
    },
    /// The requested envelope format cannot be represented as ASCII text.
    #[error("Envelope format {format} cannot be represented as ASCII.")]
    NonASCIIFormat {
        /// The non-ASCII format that was requested.
        format: EnvelopeFormat,
    },
    /// Zstd compression was requested or encountered, but the `zstd`
    /// feature is disabled.
    #[error("Zstd compression is not supported. This requires the 'zstd' feature for `hugr`.")]
    ZstdUnsupported,
    /// The envelope was expected to contain exactly one HUGR.
    #[error("Expected an envelope containing a single hugr, but it contained {}.", if *count == 0 {
        "none".to_string()
    } else {
        count.to_string()
    })]
    ExpectedSingleHugr {
        /// The number of HUGRs actually contained in the envelope.
        count: usize,
    },
    /// JSON (de)serialization error.
    #[error(transparent)]
    SerdeError {
        /// The underlying `serde_json` error.
        #[from]
        source: serde_json::Error,
    },
    /// I/O error while reading or writing the envelope.
    #[error(transparent)]
    IO {
        /// The underlying I/O error.
        #[from]
        source: std::io::Error,
    },
    /// Error from the JSON package encoding.
    #[error(transparent)]
    PackageEncoding {
        /// The underlying package encoding error.
        #[from]
        source: PackageEncodingError,
    },
    /// Error importing a decoded model package into HUGRs.
    #[error(transparent)]
    ModelImport {
        /// The underlying import error.
        #[from]
        source: ImportError,
    },
    /// Error reading the binary model format.
    #[error(transparent)]
    ModelRead {
        /// The underlying binary read error.
        #[from]
        source: hugr_model::v0::binary::ReadError,
    },
    /// Error writing the binary model format.
    #[error(transparent)]
    ModelWrite {
        /// The underlying binary write error.
        #[from]
        source: hugr_model::v0::binary::WriteError,
    },
    /// Error parsing the textual model format.
    #[error(transparent)]
    ModelTextRead {
        /// The underlying parse error.
        #[from]
        source: hugr_model::v0::ast::ParseError,
    },
    /// Error resolving a parsed textual model.
    #[error(transparent)]
    ModelTextResolve {
        /// The underlying resolve error.
        #[from]
        source: hugr_model::v0::ast::ResolveError,
    },
    /// Error loading the extensions bundled in the envelope.
    #[error(transparent)]
    ExtensionLoad {
        /// The underlying registry load error.
        #[from]
        source: crate::extension::ExtensionRegistryLoadError,
    },
}
304
305fn read_impl(
307 payload: impl BufRead,
308 header: EnvelopeHeader,
309 registry: &ExtensionRegistry,
310) -> Result<Package, EnvelopeError> {
311 match header.format {
312 #[allow(deprecated)]
313 EnvelopeFormat::PackageJson => Ok(package_json::from_json_reader(payload, registry)?),
314 EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => {
315 decode_model(payload, registry, header.format)
316 }
317 EnvelopeFormat::ModelText | EnvelopeFormat::ModelTextWithExtensions => {
318 decode_model_ast(payload, registry, header.format)
319 }
320 }
321}
322
323fn decode_model(
331 mut stream: impl BufRead,
332 extension_registry: &ExtensionRegistry,
333 format: EnvelopeFormat,
334) -> Result<Package, EnvelopeError> {
335 use hugr_model::v0::bumpalo::Bump;
336
337 if format.model_version() != Some(0) {
338 return Err(EnvelopeError::FormatUnsupported {
339 format,
340 feature: None,
341 });
342 }
343
344 let bump = Bump::default();
345 let model_package = hugr_model::v0::binary::read_from_reader(&mut stream, &bump)?;
346
347 let mut extension_registry = extension_registry.clone();
348 if format == EnvelopeFormat::ModelWithExtensions {
349 let extra_extensions = ExtensionRegistry::load_json(stream, &extension_registry)?;
350 extension_registry.extend(extra_extensions);
351 }
352
353 Ok(import_package(&model_package, &extension_registry)?)
354}
355
356fn decode_model_ast(
364 mut stream: impl BufRead,
365 extension_registry: &ExtensionRegistry,
366 format: EnvelopeFormat,
367) -> Result<Package, EnvelopeError> {
368 use crate::import::import_package;
369 use hugr_model::v0::bumpalo::Bump;
370
371 if format.model_version() != Some(0) {
372 return Err(EnvelopeError::FormatUnsupported {
373 format,
374 feature: None,
375 });
376 }
377
378 let mut extension_registry = extension_registry.clone();
379 if format == EnvelopeFormat::ModelTextWithExtensions {
380 let deserializer = serde_json::Deserializer::from_reader(&mut stream);
381 let extra_extensions = deserializer
383 .into_iter::<Vec<Extension>>()
384 .next()
385 .unwrap_or(Ok(vec![]))?;
386 for ext in extra_extensions {
387 extension_registry.register_updated(ext);
388 }
389 }
390
391 let mut buffer = String::new();
395 stream.read_to_string(&mut buffer)?;
396 let ast_package = hugr_model::v0::ast::Package::from_str(&buffer)?;
397
398 let bump = Bump::default();
399 let model_package = ast_package.resolve(&bump)?;
400
401 Ok(import_package(&model_package, &extension_registry)?)
402}
403
404fn write_impl<'h>(
406 writer: impl Write,
407 hugrs: impl IntoIterator<Item = &'h Hugr>,
408 extensions: &ExtensionRegistry,
409 config: EnvelopeConfig,
410) -> Result<(), EnvelopeError> {
411 match config.format {
412 #[allow(deprecated)]
413 EnvelopeFormat::PackageJson => package_json::to_json_writer(hugrs, extensions, writer)?,
414 EnvelopeFormat::Model
415 | EnvelopeFormat::ModelWithExtensions
416 | EnvelopeFormat::ModelText
417 | EnvelopeFormat::ModelTextWithExtensions => {
418 encode_model(writer, hugrs, extensions, config.format)?;
419 }
420 }
421 Ok(())
422}
423
424fn encode_model<'h>(
425 mut writer: impl Write,
426 hugrs: impl IntoIterator<Item = &'h Hugr>,
427 extensions: &ExtensionRegistry,
428 format: EnvelopeFormat,
429) -> Result<(), EnvelopeError> {
430 use hugr_model::v0::{binary::write_to_writer, bumpalo::Bump};
431
432 use crate::export::export_package;
433
434 if format.model_version() != Some(0) {
435 return Err(EnvelopeError::FormatUnsupported {
436 format,
437 feature: None,
438 });
439 }
440
441 if format == EnvelopeFormat::ModelTextWithExtensions {
443 serde_json::to_writer(&mut writer, &extensions.iter().collect_vec())?;
444 }
445
446 let bump = Bump::default();
447 let model_package = export_package(hugrs, extensions, &bump);
448
449 match format {
450 EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => {
451 write_to_writer(&model_package, &mut writer)?;
452 }
453 EnvelopeFormat::ModelText | EnvelopeFormat::ModelTextWithExtensions => {
454 let model_package = model_package.as_ast().unwrap();
455 writeln!(writer, "{model_package}")?;
456 }
457 _ => unreachable!(),
458 }
459
460 if format == EnvelopeFormat::ModelWithExtensions {
462 serde_json::to_writer(writer, &extensions.iter().collect_vec())?;
463 }
464
465 Ok(())
466}
467
/// One entry of the [`USED_EXTENSIONS_KEY`] metadata array: an extension
/// name and the version recorded as used by the HUGR.
#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize)]
struct UsedExtension {
    /// Name of the extension, looked up in the extension registry.
    name: String,
    /// Version of the extension recorded in the metadata.
    version: Version,
}
473
/// A used extension's recorded version is incompatible with the version
/// found in the registry.
#[derive(Debug, Error)]
#[error(
    "Extension '{name}' version mismatch: registered version is {registered}, but used version is {used}"
)]
pub struct ExtensionVersionMismatch {
    /// Name of the extension.
    name: String,
    /// Version of the extension in the registry.
    registered: Version,
    /// Version recorded in the HUGR's used-extensions metadata.
    used: Version,
}
485
/// Errors from checking a HUGR's used extensions against a registry.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum ExtensionBreakingError {
    /// A used extension has an incompatible version in the registry.
    #[error("{0}")]
    ExtensionVersionMismatch(ExtensionVersionMismatch),

    /// The used-extensions metadata could not be deserialized into the
    /// expected list of `{ name, version }` entries.
    #[error("Failed to deserialize used extensions metadata")]
    Deserialization(#[from] serde_json::Error),
}
498fn check_breaking_extensions(
503 hugr: impl crate::HugrView,
504 registry: &ExtensionRegistry,
505) -> Result<(), ExtensionBreakingError> {
506 let Some(exts) = hugr.get_metadata(hugr.module_root(), USED_EXTENSIONS_KEY) else {
507 return Ok(()); };
509 let used_exts: Vec<UsedExtension> = serde_json::from_value(exts.clone())?; for ext in used_exts {
512 let Some(registered) = registry.get(ext.name.as_str()) else {
513 continue; };
515 if !compatible_versions(registered.version(), &ext.version) {
516 return Err(ExtensionBreakingError::ExtensionVersionMismatch(
519 ExtensionVersionMismatch {
520 name: ext.name,
521 registered: registered.version().clone(),
522 used: ext.version,
523 },
524 ));
525 }
526 }
527
528 Ok(())
529}
530
531fn compatible_versions(v1: &Version, v2: &Version) -> bool {
535 if v1.major != v2.major {
536 return false; }
538
539 if v1.major == 0 {
540 return v1.minor == v2.minor;
542 }
543
544 true
545}
546
#[cfg(test)]
pub(crate) mod test {
    // Tests for envelope (de)serialization and used-extension version checks.

    use super::*;
    use cool_asserts::assert_matches;
    use rstest::rstest;
    use std::borrow::Cow;
    use std::io::BufReader;

    use crate::HugrView;
    use crate::builder::test::{multi_module_package, simple_package};
    use crate::extension::{Extension, ExtensionRegistry, Version};
    use crate::extension::{ExtensionId, PRELUDE_REGISTRY};
    use crate::hugr::HugrMut;
    use crate::hugr::test::check_hugr_equality;
    use crate::std_extensions::STD_REG;
    use serde_json::json;
    use std::sync::Arc;

    /// Returns `extensions` extended with `other`. The result is borrowed
    /// when `extensions` already contains every extension of `other`.
    fn join_extensions<'a>(
        extensions: &'a ExtensionRegistry,
        other: &ExtensionRegistry,
    ) -> Cow<'a, ExtensionRegistry> {
        if other.iter().all(|e| extensions.contains(e.name())) {
            Cow::Borrowed(extensions)
        } else {
            let mut extensions = extensions.clone();
            extensions.extend(other);
            Cow::Owned(extensions)
        }
    }

    /// Store `hugr` with `config`, load it back, and check the loaded HUGR
    /// equals the original. Returns the loaded HUGR.
    pub(crate) fn check_hugr_roundtrip(hugr: &Hugr, config: EnvelopeConfig) -> Hugr {
        let mut buffer = Vec::new();
        hugr.store(&mut buffer, config).unwrap();

        let extensions = join_extensions(&STD_REG, hugr.extensions());

        let reader = BufReader::new(buffer.as_slice());
        let extracted = Hugr::load(reader, Some(&extensions)).unwrap();

        check_hugr_equality(&extracted, hugr);
        extracted
    }

    #[rstest]
    fn errors() {
        let package = simple_package();
        assert_matches!(
            package.store_str(EnvelopeConfig::binary()),
            Err(EnvelopeError::NonASCIIFormat { .. })
        );
    }

    #[rstest]
    #[case::empty(Package::default())]
    #[case::simple(simple_package())]
    #[case::multi(multi_module_package())]
    fn text_roundtrip(#[case] package: Package) {
        let envelope = package.store_str(EnvelopeConfig::text()).unwrap();
        let new_package = Package::load_str(&envelope, None).unwrap();
        assert_eq!(package, new_package);
    }

    #[rstest]
    #[case::empty(Package::default())]
    #[case::simple(simple_package())]
    #[case::multi(multi_module_package())]
    #[cfg_attr(all(miri, feature = "zstd"), ignore)]
    fn compressed_roundtrip(#[case] package: Package) {
        let mut buffer = Vec::new();
        let config = EnvelopeConfig {
            format: EnvelopeFormat::PackageJson,
            zstd: Some(ZstdConfig::default()),
        };
        let res = package.store(&mut buffer, config);

        // Without the `zstd` feature, storing must fail with a dedicated error.
        match cfg!(feature = "zstd") {
            true => res.unwrap(),
            false => {
                assert_matches!(res, Err(EnvelopeError::ZstdUnsupported));
                return;
            }
        }

        let (decoded_config, new_package) =
            read_envelope(BufReader::new(buffer.as_slice()), &PRELUDE_REGISTRY).unwrap();

        assert_eq!(config.format, decoded_config.format);
        assert_eq!(config.zstd.is_some(), decoded_config.zstd.is_some());
        assert_eq!(package, new_package);
    }

    #[rstest]
    #[case::empty_model(Package::default(), EnvelopeFormat::Model)]
    #[case::empty_model_exts(Package::default(), EnvelopeFormat::ModelWithExtensions)]
    #[case::empty_text(Package::default(), EnvelopeFormat::ModelText)]
    #[case::empty_text_exts(Package::default(), EnvelopeFormat::ModelTextWithExtensions)]
    #[case::simple_bin(simple_package(), EnvelopeFormat::Model)]
    #[case::simple_bin_exts(simple_package(), EnvelopeFormat::ModelWithExtensions)]
    #[case::simple_text(simple_package(), EnvelopeFormat::ModelText)]
    #[case::simple_text_exts(simple_package(), EnvelopeFormat::ModelTextWithExtensions)]
    #[case::multi_bin(multi_module_package(), EnvelopeFormat::Model)]
    #[case::multi_bin_exts(multi_module_package(), EnvelopeFormat::ModelWithExtensions)]
    #[case::multi_text(multi_module_package(), EnvelopeFormat::ModelText)]
    #[case::multi_text_exts(multi_module_package(), EnvelopeFormat::ModelTextWithExtensions)]
    fn model_roundtrip(#[case] package: Package, #[case] format: EnvelopeFormat) {
        let mut buffer = Vec::new();
        let config = EnvelopeConfig { format, zstd: None };
        package.store(&mut buffer, config).unwrap();

        let (decoded_config, new_package) =
            read_envelope(BufReader::new(buffer.as_slice()), &PRELUDE_REGISTRY).unwrap();

        assert_eq!(config.format, decoded_config.format);
        assert_eq!(config.zstd.is_some(), decoded_config.zstd.is_some());

        assert_eq!(package, new_package);
    }

    #[rstest]
    #[case::simple(simple_package())]
    fn test_check_breaking_extensions(#[case] mut package: Package) {
        let test_ext_v0 =
            Extension::new(ExtensionId::new_unchecked("test-v0"), Version::new(0, 2, 3));
        let test_ext_v1 =
            Extension::new(ExtensionId::new_unchecked("test-v1"), Version::new(1, 2, 3));

        let registry =
            ExtensionRegistry::new([Arc::new(test_ext_v0.clone()), Arc::new(test_ext_v1.clone())]);
        let mut hugr = package.modules.remove(0);

        // No used-extensions metadata at all.
        assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));

        // Major version 0: exact and patch-level differences are compatible.
        let used_exts = json!([{ "name": "test-v0", "version": "0.2.3" }]);
        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
        assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));

        let used_exts = json!([{ "name": "test-v0", "version": "0.2.4" }]);
        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
        assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));

        // Major version 0: a minor bump is breaking.
        let used_exts = json!([{ "name": "test-v0", "version": "0.3.3" }]);
        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
        assert_matches!(
            check_breaking_extensions(&hugr, &registry),
            Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
                name,
                registered,
                used
            })) if name == "test-v0" && registered == Version::new(0, 2, 3) && used == Version::new(0, 3, 3)
        );

        let used_exts = json!([{ "name": "test-v0", "version": "1.2.3" }]);
        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
        assert_matches!(
            check_breaking_extensions(&hugr, &registry),
            Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
                name,
                registered,
                used
            })) if name == "test-v0" && registered == Version::new(0, 2, 3) && used == Version::new(1, 2, 3)
        );

        // Major version >= 1: minor and patch differences are compatible.
        let used_exts = json!([{ "name": "test-v1", "version": "1.2.3" }]);
        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
        assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));

        let used_exts = json!([{ "name": "test-v1", "version": "1.3.0" }]);
        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
        assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));

        let used_exts = json!([{ "name": "test-v1", "version": "1.2.4" }]);
        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
        assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));

        // Major version >= 1: a major bump is breaking.
        let used_exts = json!([{ "name": "test-v1", "version": "2.2.3" }]);
        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
        assert_matches!(
            check_breaking_extensions(&hugr, &registry),
            Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
                name,
                registered,
                used
            })) if name == "test-v1" && registered == Version::new(1, 2, 3) && used == Version::new(2, 2, 3)
        );

        // Unregistered extensions are ignored.
        let used_exts = json!([{ "name": "unknown", "version": "1.0.0" }]);
        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
        assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));

        let used_exts = json!([
            { "name": "unknown", "version": "1.0.0" },
            { "name": "test-v1", "version": "2.0.0" }
        ]);
        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
        assert_matches!(
            check_breaking_extensions(&hugr, &registry),
            Err(ExtensionBreakingError::ExtensionVersionMismatch(ExtensionVersionMismatch {
                name,
                registered,
                used
            })) if name == "test-v1" && registered == Version::new(1, 2, 3) && used == Version::new(2, 0, 0)
        );

        // Malformed metadata is reported as a deserialization error.
        hugr.set_metadata(
            hugr.module_root(),
            USED_EXTENSIONS_KEY,
            json!("not an array"),
        );
        assert_matches!(
            check_breaking_extensions(&hugr, &registry),
            Err(ExtensionBreakingError::Deserialization(_))
        );

        let used_exts = json!([
            { "name": "test-v0", "version": "0.2.5" },
            { "name": "test-v1", "version": "1.9.9" }
        ]);
        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);
        assert_matches!(check_breaking_extensions(&hugr, &registry), Ok(()));
    }

    #[test]
    fn test_with_generator_error_message() {
        let test_ext = Extension::new(ExtensionId::new_unchecked("test"), Version::new(1, 0, 0));
        let registry = ExtensionRegistry::new([Arc::new(test_ext)]);

        let mut hugr = simple_package().modules.remove(0);

        let generator_name = json!({ "name": "TestGenerator", "version": "1.2.3" });
        hugr.set_metadata(hugr.module_root(), GENERATOR_KEY, generator_name.clone());

        let used_exts = json!([{ "name": "test", "version": "2.0.0" }]);
        hugr.set_metadata(hugr.module_root(), USED_EXTENSIONS_KEY, used_exts);

        let err = check_breaking_extensions(&hugr, &registry).unwrap_err();
        let with_gen = WithGenerator::new(err, &[&hugr]);

        // The message must contain both the inner error and the generator.
        let err_msg = with_gen.to_string();
        assert!(err_msg.contains("Extension 'test' version mismatch"));
        assert!(err_msg.contains(generator_name.to_string().as_str()));
    }
}