use std::io::{BufRead, Read};
use std::str::FromStr as _;
use hugr_model::v0::table;
use itertools::{Either, Itertools as _};
use crate::HugrView as _;
use crate::envelope::description::{ExtensionDesc, ModuleDesc, PackageDesc};
use crate::envelope::header::{EnvelopeFormat, HeaderError};
use crate::envelope::{EnvelopeHeader, ExtensionBreakingError, FormatUnsupportedError};
use crate::extension::resolution::{ExtensionResolutionError, WeakExtensionRegistry};
use crate::extension::{Extension, ExtensionRegistry, ExtensionRegistryLoadError};
use crate::import::{ImportError, import_described_hugr};
use crate::package::Package;
use super::{check_breaking_extensions, check_model_version, package_json::PackageEncodingError};
use thiserror::Error;
use hugr_model::v0::bumpalo::Bump;
/// Type stored in the `Either::Right` arm of [`MaybeZstdRead`].
///
/// With the `zstd` feature enabled this is a buffered zstd decoder; the
/// decoder created in `EnvelopeReader::new` wraps the raw reader in its own
/// inner `BufReader<R>`.
#[cfg(feature = "zstd")]
type RightType<R> = std::io::BufReader<zstd::Decoder<'static, std::io::BufReader<R>>>;
/// Type stored in the `Either::Right` arm of [`MaybeZstdRead`].
///
/// Without the `zstd` feature this only keeps [`MaybeZstdRead`] well-typed:
/// `EnvelopeReader::new` rejects compressed envelopes with `ZstdUnsupported`
/// and never constructs the `Right` arm.
#[cfg(not(feature = "zstd"))]
type RightType<R> = std::io::BufReader<R>;
/// Payload reader that is either the caller's reader used as-is
/// (`Either::Left`) or a `RightType<R>` wrapper around it (`Either::Right`).
///
/// `EnvelopeReader::new` picks the arm from the `zstd` flag of the header.
pub(crate) struct MaybeZstdRead<R>(Either<R, RightType<R>>);
impl<R> std::io::Read for MaybeZstdRead<R>
where
R: std::io::Read,
{
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
match &mut self.0 {
Either::Left(r) => r.read(buf),
Either::Right(r) => r.read(buf),
}
}
}
impl<R> std::io::BufRead for MaybeZstdRead<R>
where
R: std::io::BufRead,
{
fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
match &mut self.0 {
Either::Left(r) => r.fill_buf(),
Either::Right(r) => r.fill_buf(),
}
}
fn consume(&mut self, amt: usize) {
match &mut self.0 {
Either::Left(r) => r.consume(amt),
Either::Right(r) => r.consume(amt),
}
}
}
/// Decodes a [`Package`] from an envelope stream while building up a
/// [`PackageDesc`] describing what was read.
pub(super) struct EnvelopeReader<R> {
/// Description filled in as reading progresses; `read` returns it even when
/// decoding the payload fails.
description: PackageDesc,
/// Reader over the payload, i.e. the input after the header has been read.
reader: MaybeZstdRead<R>,
/// Extensions used for resolution: a clone of the registry passed to `new`,
/// extended with any extensions packaged in the envelope.
registry: ExtensionRegistry,
}
impl<R: BufRead> EnvelopeReader<R> {
/// Reads the envelope header from `reader` and prepares to decode the payload.
///
/// The header's `zstd` flag decides whether the remaining input is read
/// directly or through a zstd decoder. `registry` is cloned, so extensions
/// registered later on the reader do not affect the caller's registry.
///
/// # Errors
///
/// Returns a [`HeaderError`] if the header cannot be read, if the zstd
/// decoder cannot be created, or if the header requests zstd while the
/// `zstd` feature is disabled.
pub(super) fn new(mut reader: R, registry: &ExtensionRegistry) -> Result<Self, HeaderError> {
    let header = EnvelopeHeader::read(&mut reader)?;
    let payload = match header.zstd {
        #[cfg(feature = "zstd")]
        true => {
            let decoder = zstd::Decoder::new(reader)?;
            Either::Right(std::io::BufReader::new(decoder))
        }
        #[cfg(not(feature = "zstd"))]
        true => Err(super::header::HeaderErrorInner::ZstdUnsupported)?,
        false => Either::Left(reader),
    };
    let description = PackageDesc::new(header);
    Ok(Self {
        description,
        reader: MaybeZstdRead(payload),
        registry: registry.clone(),
    })
}
/// Returns the envelope header stored in the package description.
fn header(&self) -> &EnvelopeHeader {
    let description = &self.description;
    &description.header
}
/// Adds `extensions` to this reader's registry, making them available to
/// the module import and extension resolution steps that use `self.registry`.
fn register_packaged(&mut self, extensions: &ExtensionRegistry) {
self.registry.extend(extensions);
}
/// Records on `desc` the extensions that `err` reports as missing.
///
/// Each missing extension id is added as an `ExtensionDesc::new_unversioned`
/// entry. Only the `MissingOpExtension`, `MissingTypeExtension` and
/// `InvalidConstTypes` variants carry such ids; any other error leaves
/// `desc` untouched.
fn handle_resolution_error(desc: &mut ModuleDesc, err: &ExtensionResolutionError) {
    match err {
        // Several extensions may be missing for a single constant.
        ExtensionResolutionError::InvalidConstTypes {
            missing_extensions, ..
        } => {
            let missing = missing_extensions
                .iter()
                .map(ExtensionDesc::new_unversioned);
            desc.extend_used_extensions_resolved(missing);
        }
        // Exactly one extension is missing for an op or a type.
        ExtensionResolutionError::MissingOpExtension {
            missing_extension, ..
        }
        | ExtensionResolutionError::MissingTypeExtension {
            missing_extension, ..
        } => {
            let missing = [ExtensionDesc::new_unversioned(missing_extension)];
            desc.extend_used_extensions_resolved(missing);
        }
        _ => {}
    }
}
fn read_impl(&mut self) -> Result<Package, PayloadError> {
let mut package = match self.header().format {
#[expect(deprecated)]
EnvelopeFormat::PackageJson => self.decode_json()?,
EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => self.decode_model()?,
EnvelopeFormat::SExpression | EnvelopeFormat::SExpressionWithExtensions => {
self.decode_model_ast()?
}
};
self.description.set_n_modules(package.modules.len());
for (index, module) in package.modules.iter_mut().enumerate() {
let desc = &mut self.description.modules[index];
let desc = desc.get_or_insert_default();
desc.load_used_extensions_generator(module)
.map_err(ExtensionBreakingError::from)?;
if let Some(used_exts) = &desc.used_extensions_generator {
check_breaking_extensions(module.extensions(), used_exts)?;
}
module
.resolve_extension_defs(&self.registry)
.inspect_err(|err| Self::handle_resolution_error(desc, err))?;
desc.load_from_hugr(&module);
}
for (index, ext) in package.extensions.iter().enumerate() {
self.description.set_packaged_extension(index, ext);
}
Ok(package)
}
pub(super) fn read(mut self) -> (PackageDesc, Result<Package, PayloadError>) {
let res = self.read_impl();
(self.description, res)
}
/// Decodes a payload in the deprecated JSON package format.
///
/// The packaged extensions are resolved against the current registry and
/// then registered on the reader.
///
/// # Errors
///
/// Returns `PackageEncodingError::JsonEncoding` if the JSON cannot be
/// deserialized and `PackageEncodingError::ExtensionResolution` if the
/// packaged extensions cannot be resolved.
#[expect(deprecated)]
fn decode_json(&mut self) -> Result<Package, PackageEncodingError> {
    let deser: super::package_json::PackageDeser = serde_json::from_reader(&mut self.reader)
        .map_err(PackageEncodingError::JsonEncoding)?;

    let modules: Vec<_> = deser.modules.into_iter().map(|wrapped| wrapped.0).collect();

    let known = WeakExtensionRegistry::from(&self.registry);
    let extensions = ExtensionRegistry::new_with_extension_resolution(deser.extensions, &known)
        .map_err(PackageEncodingError::ExtensionResolution)?;
    self.register_packaged(&extensions);

    Ok(Package {
        modules,
        extensions,
    })
}
/// Decodes a payload in the binary model format.
///
/// For `ModelWithExtensions` the packaged extensions are loaded as JSON from
/// the bytes following the binary model; otherwise no extensions are packaged.
///
/// # Errors
///
/// Returns a [`ModelBinaryReadError`] if the version check, the binary read,
/// the extension load, or the import into a package fails.
fn decode_model(&mut self) -> Result<Package, ModelBinaryReadError> {
    let format = self.header().format;
    check_model_version(format)?;

    // The model package borrows from this arena, so it must outlive the import.
    let arena = Bump::default();
    let table_package = hugr_model::v0::binary::read_from_reader(&mut self.reader, &arena)?;

    let packaged = if format == EnvelopeFormat::ModelWithExtensions {
        ExtensionRegistry::load_json(&mut self.reader, &self.registry)?
    } else {
        ExtensionRegistry::new([])
    };

    let package = self.import_package(&table_package, packaged)?;
    Ok(package)
}
fn decode_model_ast(&mut self) -> Result<Package, SExpressionReadError> {
let format = self.header().format;
check_model_version(format)?;
let packaged_extensions = if format == EnvelopeFormat::SExpressionWithExtensions {
let deserializer = serde_json::Deserializer::from_reader(&mut self.reader);
let extra_extensions = deserializer
.into_iter::<Vec<Extension>>()
.next()
.unwrap_or(Ok(vec![]))?;
let weak_registry: WeakExtensionRegistry = (&self.registry).into();
ExtensionRegistry::new_with_extension_resolution(extra_extensions, &weak_registry)
.map_err(ExtensionRegistryLoadError::from)?
} else {
ExtensionRegistry::new([])
};
let mut buffer = String::new();
self.reader.read_to_string(&mut buffer)?;
let ast_package = hugr_model::v0::ast::Package::from_str(&buffer)?;
let bump = Bump::default();
let model_package = ast_package.resolve(&bump)?;
self.import_package(&model_package, packaged_extensions)
.map_err(Into::into)
}
/// Imports every module of a model `package` into a [`Package`].
///
/// `packaged_extensions` are registered on the reader before importing, and
/// end up as the `extensions` of the returned package. The description of
/// each module is stored as soon as that module has been processed.
///
/// # Errors
///
/// Stops at the first module that fails to import and returns its error.
/// That module's description is still recorded; later modules are not
/// visited.
fn import_package(
    &mut self,
    package: &table::Package,
    packaged_extensions: ExtensionRegistry,
) -> Result<Package, crate::import::ImportError> {
    let module_count = package.modules.len();
    self.description.set_n_modules(module_count);
    self.register_packaged(&packaged_extensions);

    let mut hugrs = Vec::with_capacity(module_count);
    for (idx, module) in package.modules.iter().enumerate() {
        let (module_desc, imported) = import_described_hugr(module, &self.registry);
        // Record the description before propagating a possible import error.
        self.description.set_module(idx, module_desc);
        hugrs.push(imported?);
    }

    let mut imported_package = Package::new(hugrs);
    imported_package.extensions = packaged_extensions;
    Ok(imported_package)
}
}
/// Error raised while decoding the payload of an envelope.
///
/// Opaque wrapper around the crate-internal `PayloadErrorInner`; being
/// `transparent`, its `Display` and `source` are those of the wrapped error.
#[derive(Error, Debug)]
#[non_exhaustive]
#[error(transparent)]
pub struct PayloadError(PayloadErrorInner);
/// Crate-internal causes of a [`PayloadError`], one variant per decoding
/// path or post-processing step of `EnvelopeReader`.
#[derive(Error, Debug)]
#[non_exhaustive]
#[error(transparent)]
pub(crate) enum PayloadErrorInner {
/// Failure decoding the deprecated JSON package format.
#[deprecated(since = "0.27.0")]
JsonRead(#[from] PackageEncodingError),
/// Failure decoding the binary model format.
ModelBinary(#[from] ModelBinaryReadError),
/// Failure decoding the textual (S-expression) model format.
SExpression(#[from] SExpressionReadError),
/// Failure while loading or checking the extensions recorded by the
/// generator of a module.
ExtensionsBreaking(#[from] ExtensionBreakingError),
/// Failure resolving the extension definitions used by a module.
ExtensionResolution(#[from] ExtensionResolutionError),
}
impl<T: Into<PayloadErrorInner>> From<T> for PayloadError {
fn from(value: T) -> Self {
Self(value.into())
}
}
/// Errors that can occur while decoding a textual (S-expression) model
/// payload in `EnvelopeReader::decode_model_ast`.
#[derive(Debug, Error)]
#[error(transparent)]
pub(crate) enum SExpressionReadError {
/// The payload text could not be parsed into a model AST.
ParseString(#[from] hugr_model::v0::ast::ParseError),
/// The resolved model could not be imported into a package.
Import(#[from] ImportError),
/// The packaged extensions could not be loaded into a registry.
ExtensionLoad(#[from] crate::extension::ExtensionRegistryLoadError),
/// The envelope format is not supported.
// NOTE(review): presumably produced by `check_model_version` — confirm.
FormatUnsupported(#[from] FormatUnsupportedError),
/// The JSON list of packaged extensions could not be deserialized.
ExtensionDeserialize(#[from] serde_json::Error),
/// The payload could not be read into a string.
StringRead(#[from] std::io::Error),
/// The parsed AST could not be resolved into a model package.
ResolveError(#[from] hugr_model::v0::ast::ResolveError),
}
/// Errors that can occur while decoding a binary model payload in
/// `EnvelopeReader::decode_model`.
#[derive(Debug, Error)]
#[error(transparent)]
pub(crate) enum ModelBinaryReadError {
/// A model AST parse error.
// NOTE(review): `decode_model` does not visibly parse text; confirm whether
// this variant is still produced anywhere.
ParseString(#[from] hugr_model::v0::ast::ParseError),
/// The binary model could not be read.
ReadBinary(#[from] hugr_model::v0::binary::ReadError),
/// The model could not be imported into a package.
Import(#[from] ImportError),
/// The packaged extensions could not be loaded.
Extensions(#[from] crate::extension::ExtensionRegistryLoadError),
/// The envelope format is not supported.
// NOTE(review): presumably produced by `check_model_version` — confirm.
FormatUnsupported(#[from] FormatUnsupportedError),
}
#[cfg(test)]
mod test {
use super::*;
use crate::Hugr;
use crate::builder::test::simple_module_hugr;
use crate::envelope::header::EnvelopeHeader;
use crate::envelope::{EnvelopeConfig, EnvelopeFormat, read_envelope};
use crate::extension::{ExtensionId, ExtensionRegistry};
use crate::hugr::HugrMut;
use cool_asserts::assert_matches;
use rstest::rstest;
use std::io::{BufReader, Cursor, Write as _};
/// An empty stream contains no envelope header, so building the reader fails.
#[test]
fn test_read_invalid_header() {
    let cursor = Cursor::new(Vec::new());
    let registry = ExtensionRegistry::new([]);
    let result = EnvelopeReader::new(cursor, &registry);
    assert!(result.is_err());
}
/// A valid JSON-format header followed by garbage yields a `JsonRead` error,
/// while the returned description still carries the header.
#[test]
#[expect(deprecated)]
fn test_read_invalid_json_payload() {
    let header = EnvelopeHeader {
        format: EnvelopeFormat::PackageJson,
        ..Default::default()
    };
    let mut cursor = Cursor::new(Vec::new());
    header.write(&mut cursor).unwrap();
    cursor.write_all(b"invalid json").unwrap();
    // Rewind so the reader starts at the header.
    cursor.set_position(0);

    let registry = ExtensionRegistry::new([]);
    let reader = EnvelopeReader::new(cursor, &registry).unwrap();
    let (description, result) = reader.read();

    assert_matches!(result, Err(PayloadError(PayloadErrorInner::JsonRead(_))));
    assert_eq!(description.header, header);
}
/// A text-format envelope that consists of a header only (no payload) yields
/// an `SExpression` error, while the description still carries the header.
#[test]
fn test_read_text_format() {
    let header = EnvelopeHeader {
        format: EnvelopeFormat::SExpressionWithExtensions,
        ..Default::default()
    };
    let mut cursor = Cursor::new(Vec::new());
    header.write(&mut cursor).unwrap();
    // Rewind so the reader starts at the header.
    cursor.set_position(0);

    let registry = ExtensionRegistry::new([]);
    let reader = EnvelopeReader::new(cursor, &registry).unwrap();
    let (description, result) = reader.read();

    assert_matches!(result, Err(PayloadError(PayloadErrorInner::SExpression(_))));
    assert_eq!(description.header, header);
}
/// When JSON decoding fails, the description returned alongside the error
/// still has the header, and reports zero modules because the module count is
/// only recorded after a successful decode.
#[test]
#[expect(deprecated)]
fn test_partial_description_on_error() {
    let header = EnvelopeHeader {
        format: EnvelopeFormat::PackageJson,
        ..Default::default()
    };
    let mut cursor = Cursor::new(Vec::new());
    header.write(&mut cursor).unwrap();
    // Well-formed JSON whose single module entry is not a valid module.
    cursor.write_all(b"{\"modules\": [\"invalid\"]}").unwrap();
    // Rewind so the reader starts at the header.
    cursor.set_position(0);

    let registry = ExtensionRegistry::new([]);
    let reader = EnvelopeReader::new(cursor, &registry).unwrap();
    let (description, result) = reader.read();

    assert_matches!(result, Err(PayloadError(PayloadErrorInner::JsonRead(_))));
    assert_eq!(description.header, header);
    assert_eq!(description.n_modules(), 0);
}
/// Checks, variant by variant, which extension ids `handle_resolution_error`
/// records on the module description.
#[test]
fn test_handle_resolution_error() {
    use crate::extension::ExtensionId;
    use crate::ops::{OpName, constant::ValueName};
    use crate::types::TypeName;

    let mut module_desc = ModuleDesc::default();

    // The function is associated with a generic type, so pin `R` once here.
    let record = |d: &mut ModuleDesc, err: &ExtensionResolutionError| {
        EnvelopeReader::<Cursor<Vec<u8>>>::handle_resolution_error(d, err)
    };
    // Asserts that exactly `expected_ids` were recorded, each with version 0.0.0.
    let expect_resolved = |d: &ModuleDesc, expected_ids: &[&ExtensionId]| {
        let resolved = d.used_extensions_resolved.as_ref().unwrap();
        assert_eq!(resolved.len(), expected_ids.len());
        let names: Vec<_> = resolved.iter().map(|e| &e.name).collect();
        for ext_id in expected_ids {
            assert!(names.contains(&&ext_id.to_string()));
        }
        assert!(
            resolved
                .iter()
                .all(|e| e.version == crate::extension::Version::new(0, 0, 0))
        );
    };

    // Missing op extension: the single missing id is recorded.
    let op_ext = ExtensionId::new("test.extension").unwrap();
    let missing_op = ExtensionResolutionError::MissingOpExtension {
        node: None,
        op: OpName::new("test.op"),
        missing_extension: op_ext.clone(),
        available_extensions: vec![],
    };
    record(&mut module_desc, &missing_op);
    expect_resolved(&module_desc, &[&op_ext]);

    // Missing type extension: the single missing id is recorded.
    module_desc.used_extensions_resolved = None;
    let type_ext = ExtensionId::new("test.extension2").unwrap();
    let missing_type = ExtensionResolutionError::MissingTypeExtension {
        node: None,
        ty: TypeName::new("test.type"),
        missing_extension: type_ext.clone(),
        available_extensions: vec![],
    };
    record(&mut module_desc, &missing_type);
    expect_resolved(&module_desc, &[&type_ext]);

    // Invalid constant types: every missing id in the set is recorded.
    module_desc.used_extensions_resolved = None;
    let const_ext_a = ExtensionId::new("test.extension3").unwrap();
    let const_ext_b = ExtensionId::new("test.extension4").unwrap();
    let mut missing_set = crate::extension::ExtensionSet::new();
    missing_set.insert(const_ext_a.clone());
    missing_set.insert(const_ext_b.clone());
    let invalid_const = ExtensionResolutionError::InvalidConstTypes {
        value: ValueName::new("test.value"),
        missing_extensions: missing_set,
    };
    record(&mut module_desc, &invalid_const);
    expect_resolved(&module_desc, &[&const_ext_a, &const_ext_b]);

    // Any other variant leaves the description untouched.
    module_desc.used_extensions_resolved = None;
    let unrelated = ExtensionResolutionError::WrongTypeDefExtension {
        extension: ExtensionId::new("ext1").unwrap(),
        def: TypeName::new("def"),
        wrong_extension: ExtensionId::new("ext2").unwrap(),
    };
    record(&mut module_desc, &unrelated);
    assert!(module_desc.used_extensions_resolved.is_none());
}
/// Decoding a text-format envelope that packages an extension registers that
/// extension on the reader and returns it as part of the package.
#[test]
fn test_decode_model_ast_with_packaged_extensions() {
    let mut simple_package = crate::builder::test::simple_package();
    let ext_id = ExtensionId::new("test.packaged.extension").unwrap();
    let extension = Extension::new(ext_id.clone(), crate::extension::Version::new(1, 0, 0));
    simple_package
        .extensions
        .register(std::sync::Arc::new(extension))
        .unwrap();

    let header = EnvelopeHeader {
        format: EnvelopeFormat::SExpressionWithExtensions,
        ..Default::default()
    };
    let mut cursor = Cursor::new(Vec::new());
    simple_package.store(&mut cursor, header.config()).unwrap();
    // Rewind so the reader starts at the header.
    cursor.set_position(0);

    // Start from an empty registry so the extension can only come from the envelope.
    let registry = ExtensionRegistry::new([]);
    let mut reader = EnvelopeReader::new(cursor, &registry).unwrap();
    assert!(!reader.registry.contains(&ext_id));

    let result = reader.decode_model_ast();
    assert!(result.is_ok());
    assert!(reader.registry.contains(&ext_id));

    let package = result.unwrap();
    assert!(package.extensions.contains(&ext_id));
}
/// Fixture: a simple module HUGR whose entrypoint carries a 64 MiB string
/// under the `"big"` metadata key.
#[rstest::fixture]
fn big_hugr() -> Hugr {
    const PAYLOAD_SIZE: usize = 64 * 1024 * 1024;
    let mut hugr = simple_module_hugr();
    let entrypoint = hugr.entrypoint();
    let payload: String = "a".repeat(PAYLOAD_SIZE);
    hugr.set_metadata_any(entrypoint, "big", payload);
    hugr
}
/// Round-trips the large HUGR through an uncompressed envelope in several
/// formats and validates the decoded result.
#[rstest]
#[case::model_with_extensions(EnvelopeFormat::ModelWithExtensions)]
#[case::model_text_with_extensions(EnvelopeFormat::SExpressionWithExtensions)]
#[case::package_json(EnvelopeFormat::PackageJson)]
#[ignore = "This test takes > 15s due to the large payload size."]
#[allow(deprecated)]
fn big_hugr_payload(#[case] format: EnvelopeFormat, big_hugr: Hugr) {
    // Pre-size the buffer to the order of magnitude of the payload.
    let mut encoded = Vec::with_capacity(64 * 1024 * 1024);
    big_hugr
        .store(&mut encoded, EnvelopeConfig { format, zstd: None })
        .unwrap();

    let source = BufReader::new(encoded.as_slice());
    let (desc, decoded) = read_envelope(source, big_hugr.extensions()).unwrap();

    assert_eq!(desc.header.config().format, format);
    decoded.validate().unwrap();
}
}