1mod header;
38
39pub use header::{EnvelopeConfig, EnvelopeFormat, ZstdConfig, MAGIC_NUMBERS};
40
41use crate::{
42 extension::ExtensionRegistry,
43 package::{Package, PackageEncodingError, PackageError},
44};
45use header::EnvelopeHeader;
46use std::io::BufRead;
47use std::io::Write;
48
49#[allow(unused_imports)]
50use itertools::Itertools as _;
51
52#[cfg(feature = "model_unstable")]
53use crate::import::ImportError;
54
55pub fn read_envelope(
64 mut reader: impl BufRead,
65 registry: &ExtensionRegistry,
66) -> Result<(EnvelopeConfig, Package), EnvelopeError> {
67 let header = EnvelopeHeader::read(&mut reader)?;
68
69 let package = match header.zstd {
70 #[cfg(feature = "zstd")]
71 true => read_impl(
72 std::io::BufReader::new(zstd::Decoder::new(reader)?),
73 header,
74 registry,
75 ),
76 #[cfg(not(feature = "zstd"))]
77 true => Err(EnvelopeError::ZstdUnsupported),
78 false => read_impl(reader, header, registry),
79 }?;
80 Ok((header.config(), package))
81}
82
83pub fn write_envelope(
88 mut writer: impl Write,
89 package: &Package,
90 config: EnvelopeConfig,
91) -> Result<(), EnvelopeError> {
92 let header = config.make_header();
93 header.write(&mut writer)?;
94
95 match config.zstd {
96 #[cfg(feature = "zstd")]
97 Some(zstd) => {
98 let writer = zstd::Encoder::new(writer, zstd.level())?.auto_finish();
99 write_impl(writer, package, config)?;
100 }
101 #[cfg(not(feature = "zstd"))]
102 Some(_) => return Err(EnvelopeError::ZstdUnsupported),
103 None => write_impl(writer, package, config)?,
104 }
105
106 Ok(())
107}
108
109#[derive(derive_more::Display, derive_more::Error, Debug, derive_more::From)]
111#[non_exhaustive]
112pub enum EnvelopeError {
113 #[display(
115 "Bad magic number. expected 0x{:X} found 0x{:X}",
116 u64::from_be_bytes(*expected),
117 u64::from_be_bytes(*found)
118 )]
119 #[from(ignore)]
120 MagicNumber {
121 expected: [u8; 8],
125 found: [u8; 8],
127 },
128 #[display("Format descriptor {descriptor} is invalid.")]
130 #[from(ignore)]
131 InvalidFormatDescriptor {
132 descriptor: usize,
134 },
135 #[display("Payload format {format} is not supported.{}",
137 match feature {
138 Some(f) => format!(" This requires the '{f}' feature for `hugr`."),
139 None => "".to_string()
140 },
141 )]
142 #[from(ignore)]
143 FormatUnsupported {
144 format: EnvelopeFormat,
146 feature: Option<&'static str>,
148 },
149 #[display("Envelope format {format} cannot be represented as ASCII.")]
153 #[from(ignore)]
154 NonASCIIFormat {
155 format: EnvelopeFormat,
157 },
158 #[display("Zstd compression is not supported. This requires the 'zstd' feature for `hugr`.")]
160 #[from(ignore)]
161 ZstdUnsupported,
162 #[display(
164 "Packages with multiple HUGRs are currently unsupported. Tried to encode {count} HUGRs, when 1 was expected."
165 )]
166 #[from(ignore)]
167 MultipleHugrs {
168 count: usize,
170 },
171 SerdeError {
173 source: serde_json::Error,
175 },
176 IO {
178 source: std::io::Error,
180 },
181 Package {
183 source: PackageError,
185 },
186 PackageEncoding {
188 source: PackageEncodingError,
190 },
191 #[cfg(feature = "model_unstable")]
193 ModelImport {
194 source: ImportError,
196 },
197 #[cfg(feature = "model_unstable")]
199 ModelRead {
200 source: hugr_model::v0::binary::ReadError,
202 },
203 #[cfg(feature = "model_unstable")]
205 ModelWrite {
206 source: hugr_model::v0::binary::WriteError,
208 },
209}
210
211fn read_impl(
213 payload: impl BufRead,
214 header: EnvelopeHeader,
215 registry: &ExtensionRegistry,
216) -> Result<Package, EnvelopeError> {
217 match header.format {
218 #[allow(deprecated)]
219 EnvelopeFormat::PackageJson => Ok(Package::from_json_reader(payload, registry)?),
220 #[cfg(feature = "model_unstable")]
221 EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => {
222 decode_model(payload, registry, header.format)
223 }
224 #[cfg(not(feature = "model_unstable"))]
225 EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => {
226 Err(EnvelopeError::FormatUnsupported {
227 format: header.format,
228 feature: Some("model_unstable"),
229 })
230 }
231 }
232}
233
/// Decodes a package from a binary model payload.
///
/// The stream must contain a model in the binary v0 encoding. When
/// `format.append_extensions()` is true, the model is followed by a JSON
/// list of extensions; these are registered on top of a copy of
/// `extension_registry` before the HUGR is imported.
///
/// # Errors
///
/// - [`EnvelopeError::FormatUnsupported`] if `format` does not report model
///   version 0.
/// - Any error raised while reading the model, parsing the appended
///   extensions, importing the HUGR, or building the package.
#[cfg(feature = "model_unstable")]
fn decode_model(
    mut stream: impl BufRead,
    extension_registry: &ExtensionRegistry,
    format: EnvelopeFormat,
) -> Result<Package, EnvelopeError> {
    use crate::{import::import_hugr, Extension};
    use hugr_model::v0::bumpalo::Bump;

    if format.model_version() != Some(0) {
        return Err(EnvelopeError::FormatUnsupported {
            format,
            feature: None,
        });
    }

    let bump = Bump::default();
    let module_list = hugr_model::v0::binary::read_from_reader(&mut stream, &bump)?;

    let hugr = if format.append_extensions() {
        // The remainder of the stream is the JSON-encoded extension list.
        let extra_extensions: Vec<Extension> = serde_json::from_reader(stream)?;
        // Only this branch mutates the registry, so only this branch pays
        // for a private copy of it.
        let mut extended_registry = extension_registry.clone();
        for ext in extra_extensions {
            extended_registry.register_updated(ext);
        }
        import_hugr(&module_list, &extended_registry)?
    } else {
        // No extra extensions: the caller's registry can be borrowed as-is.
        import_hugr(&module_list, extension_registry)?
    };

    Ok(Package::new([hugr])?)
}
273
274fn write_impl(
276 writer: impl Write,
277 package: &Package,
278 config: EnvelopeConfig,
279) -> Result<(), EnvelopeError> {
280 match config.format {
281 #[allow(deprecated)]
282 EnvelopeFormat::PackageJson => package.to_json_writer(writer)?,
283 #[cfg(feature = "model_unstable")]
284 EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => {
285 encode_model(writer, package, config.format)?
286 }
287 #[cfg(not(feature = "model_unstable"))]
288 EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => {
289 return Err(EnvelopeError::FormatUnsupported {
290 format: config.format,
291 feature: Some("model_unstable"),
292 })
293 }
294 }
295 Ok(())
296}
297
/// Encodes a single-HUGR package into the binary model format.
///
/// The package's only module is exported and written in the binary v0
/// encoding. When `format.append_extensions()` is true, the package's
/// extensions are appended to the same writer as a JSON list.
///
/// # Errors
///
/// - [`EnvelopeError::FormatUnsupported`] if `format` does not report model
///   version 0.
/// - [`EnvelopeError::MultipleHugrs`] if the package does not contain
///   exactly one module.
/// - Any error raised while writing the model or serializing the extensions.
#[cfg(feature = "model_unstable")]
fn encode_model(
    mut writer: impl Write,
    package: &Package,
    format: EnvelopeFormat,
) -> Result<(), EnvelopeError> {
    use crate::export::export_hugr;
    use hugr_model::v0::{binary::write_to_writer, bumpalo::Bump};

    let version_supported = matches!(format.model_version(), Some(0));
    if !version_supported {
        return Err(EnvelopeError::FormatUnsupported {
            format,
            feature: None,
        });
    }

    let count = package.modules.len();
    if count != 1 {
        return Err(EnvelopeError::MultipleHugrs { count });
    }

    let bump = Bump::default();
    let exported = export_hugr(&package.modules[0], &bump);
    write_to_writer(&exported, &mut writer)?;

    if format.append_extensions() {
        let extensions = package.extensions.iter().collect_vec();
        serde_json::to_writer(writer, &extensions)?;
    }

    Ok(())
}
330
#[cfg(test)]
mod tests {
    use super::*;
    use cool_asserts::assert_matches;
    use rstest::rstest;
    use std::io::BufReader;

    use crate::builder::test::{multi_module_package, simple_package};
    use crate::extension::PRELUDE_REGISTRY;

    /// Decodes an in-memory envelope against the prelude registry, panicking
    /// if decoding fails.
    fn decode(bytes: &[u8]) -> (EnvelopeConfig, Package) {
        read_envelope(BufReader::new(bytes), &PRELUDE_REGISTRY).unwrap()
    }

    /// Checks that the decoded configuration has the same format and the same
    /// compression on/off state as the one used for writing.
    fn assert_config_preserved(written: EnvelopeConfig, decoded: EnvelopeConfig) {
        assert_eq!(written.format, decoded.format);
        assert_eq!(written.zstd.is_some(), decoded.zstd.is_some());
    }

    /// Storing with the binary configuration as a string must be rejected.
    #[rstest]
    fn errors() {
        let package = simple_package();
        let stored = package.store_str(EnvelopeConfig::binary());
        assert_matches!(stored, Err(EnvelopeError::NonASCIIFormat { .. }));
    }

    /// A package written with the text configuration reads back unchanged.
    #[rstest]
    #[case::empty(Package::default())]
    #[case::simple(simple_package())]
    #[case::multi(multi_module_package())]
    fn text_roundtrip(#[case] package: Package) {
        let envelope = package.store_str(EnvelopeConfig::text()).unwrap();
        let new_package = Package::load_str(&envelope, None).unwrap();
        assert_eq!(package, new_package);
    }

    /// A zstd-compressed JSON envelope roundtrips when the `zstd` feature is
    /// enabled, and is rejected with `ZstdUnsupported` otherwise.
    #[rstest]
    #[case::empty(Package::default())]
    #[case::simple(simple_package())]
    #[case::multi(multi_module_package())]
    fn compressed_roundtrip(#[case] package: Package) {
        let config = EnvelopeConfig {
            format: EnvelopeFormat::PackageJson,
            zstd: Some(ZstdConfig::default()),
        };
        let mut buffer = Vec::new();
        let res = package.store(&mut buffer, config);

        if !cfg!(feature = "zstd") {
            assert_matches!(res, Err(EnvelopeError::ZstdUnsupported));
            return;
        }
        res.unwrap();

        let (decoded_config, new_package) = decode(buffer.as_slice());

        assert_config_preserved(config, decoded_config);
        assert_eq!(package, new_package);
    }

    /// A model envelope with appended extensions roundtrips.
    #[rstest]
    #[case::simple(simple_package())]
    #[cfg(feature = "model_unstable")]
    fn module_exts_roundtrip(#[case] package: Package) {
        let config = EnvelopeConfig {
            format: EnvelopeFormat::ModelWithExtensions,
            ..Default::default()
        };
        let mut buffer = Vec::new();
        package.store(&mut buffer, config).unwrap();

        let (decoded_config, new_package) = decode(buffer.as_slice());

        assert_config_preserved(config, decoded_config);
        assert_eq!(package, new_package);
    }

    /// A plain model envelope roundtrips when the `model_unstable` feature is
    /// enabled, and is rejected with `FormatUnsupported` otherwise.
    #[rstest]
    #[case::simple(simple_package())]
    fn module_roundtrip(#[case] package: Package) {
        let config = EnvelopeConfig {
            format: EnvelopeFormat::Model,
            ..Default::default()
        };
        let mut buffer = Vec::new();
        let res = package.store(&mut buffer, config);

        if !cfg!(feature = "model_unstable") {
            assert_matches!(res, Err(EnvelopeError::FormatUnsupported { .. }));
            return;
        }
        res.unwrap();

        let (decoded_config, new_package) = decode(buffer.as_slice());

        assert_config_preserved(config, decoded_config);
        assert_eq!(package, new_package);
    }
}