hugr_core/envelope/
reader.rs

1use std::io::{BufRead, Read};
2use std::str::FromStr as _;
3
4use hugr_model::v0::table;
5use itertools::{Either, Itertools as _};
6
7use crate::HugrView as _;
8use crate::envelope::description::PackageDesc;
9use crate::envelope::header::{EnvelopeFormat, HeaderError};
10use crate::envelope::{
11    EnvelopeError, EnvelopeHeader, ExtensionBreakingError, FormatUnsupportedError,
12};
13use crate::extension::resolution::{ExtensionResolutionError, WeakExtensionRegistry};
14use crate::extension::{Extension, ExtensionRegistry};
15use crate::import::{ImportError, import_described_hugr};
16use crate::package::Package;
17
18use super::{check_breaking_extensions, check_model_version, package_json::PackageEncodingError};
19use thiserror::Error;
20
21use hugr_model::v0::bumpalo::Bump;
/// Reader type stored in the `Either::Right` arm of [`MaybeZstdRead`].
///
/// With the `zstd` feature enabled this is a buffered zstd decoder wrapping a
/// buffered `R`.
#[cfg(feature = "zstd")]
type RightType<R> = std::io::BufReader<zstd::Decoder<'static, std::io::BufReader<R>>>;
/// Reader type stored in the `Either::Right` arm of [`MaybeZstdRead`].
///
/// Without the `zstd` feature this is only a stand-in so that the type still
/// exists; `EnvelopeReader::new` returns an error instead of constructing the
/// `Right` arm in that configuration.
#[cfg(not(feature = "zstd"))]
type RightType<R> = std::io::BufReader<R>;
26
/// A reader that is either the raw underlying reader (`Either::Left`) or a
/// [`RightType`] wrapper around it (`Either::Right`).
///
/// Implements `Read` and `BufRead` by forwarding to whichever arm is active.
pub(crate) struct MaybeZstdRead<R>(Either<R, RightType<R>>);
28
29impl<R> std::io::Read for MaybeZstdRead<R>
30where
31    R: std::io::Read,
32{
33    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
34        match &mut self.0 {
35            Either::Left(r) => r.read(buf),
36            Either::Right(r) => r.read(buf),
37        }
38    }
39}
40
41impl<R> std::io::BufRead for MaybeZstdRead<R>
42where
43    R: std::io::BufRead,
44{
45    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
46        match &mut self.0 {
47            Either::Left(r) => r.fill_buf(),
48            Either::Right(r) => r.fill_buf(),
49        }
50    }
51
52    fn consume(&mut self, amt: usize) {
53        match &mut self.0 {
54            Either::Left(r) => r.consume(amt),
55            Either::Right(r) => r.consume(amt),
56        }
57    }
58}
59
/// Reader for HUGR envelopes.
///
/// To read a package from an envelope, first create an `EnvelopeReader` using
/// [`EnvelopeReader::new`], then call [`EnvelopeReader::read`].
pub(super) struct EnvelopeReader<R> {
    /// Description of the package, created from the envelope header and filled
    /// in further while the payload is read.
    description: PackageDesc,
    /// Payload reader, positioned just after the envelope header.
    reader: MaybeZstdRead<R>,
    /// Extension registry used when decoding; starts as a clone of the
    /// registry passed to `new` and is extended with packaged extensions.
    registry: ExtensionRegistry,
}
69
70impl<R: BufRead> EnvelopeReader<R> {
71    /// Create a new `EnvelopeReader` from a reader and an extension registry.
72    ///
73    /// # Errors
74    ///
75    /// - If the header is invalid.
76    /// - If zstd decompression is requested but the `zstd` feature is not
77    ///   enabled.
78    pub(super) fn new(mut reader: R, registry: &ExtensionRegistry) -> Result<Self, HeaderError> {
79        let header = EnvelopeHeader::read(&mut reader)?;
80        let reader = match header.zstd {
81            #[cfg(feature = "zstd")]
82            true => Either::Right(std::io::BufReader::new(zstd::Decoder::new(reader)?)),
83            #[cfg(not(feature = "zstd"))]
84            true => Err(super::header::HeaderErrorInner::ZstdUnsupported)?,
85            false => Either::Left(reader),
86        };
87        Ok(Self {
88            description: PackageDesc::new(header),
89            reader: MaybeZstdRead(reader),
90            registry: registry.clone(),
91        })
92    }
93
94    pub(crate) fn description(&self) -> &PackageDesc {
95        &self.description
96    }
97
98    fn header(&self) -> &EnvelopeHeader {
99        &self.description.header
100    }
101
102    fn register_packaged(&mut self, extensions: &ExtensionRegistry) {
103        self.registry.extend(extensions);
104    }
105
106    fn read_impl(&mut self) -> Result<Package, PayloadError> {
107        let mut package = match self.header().format {
108            EnvelopeFormat::PackageJson => self.decode_json()?,
109            EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => self.decode_model()?,
110            EnvelopeFormat::ModelText | EnvelopeFormat::ModelTextWithExtensions => {
111                self.decode_model_ast()?
112            }
113        };
114        self.description.set_n_modules(package.modules.len());
115        for (index, module) in package.modules.iter_mut().enumerate() {
116            let desc = &mut self.description.modules[index];
117            let desc = desc.get_or_insert_default();
118            desc.load_used_extensions_generator(module)
119                .map_err(ExtensionBreakingError::from)?;
120            if let Some(used_exts) = &mut desc.used_extensions_generator {
121                check_breaking_extensions(module.extensions(), used_exts.drain(..))?;
122            }
123
124            module.resolve_extension_defs(&self.registry)?;
125            // overwrite the description with the actual module read,
126            // cheap so ok to repeat.
127            desc.load_from_hugr(&module);
128        }
129
130        for (index, ext) in package.extensions.iter().enumerate() {
131            self.description.set_packaged_extension(index, ext);
132        }
133        Ok(package)
134    }
135
136    /// Read the package and return the description and the package or an error.
137    ///
138    /// The description is always returned, even if reading the package fails,
139    /// it may be incomplete. Minimally it contains the header, but may also
140    /// contain any information gathered prior to the error.
141    ///
142    /// # Errors
143    ///
144    /// - If reading the package payload fails.
145    pub(super) fn read(mut self) -> (PackageDesc, Result<Package, PayloadError>) {
146        let res = self.read_impl();
147
148        (self.description, res)
149    }
150
151    /// Read a Package in json format from an io reader.
152    /// Returns package and the combined extension registry
153    /// of the provided registry and the package extensions.
154    fn decode_json(&mut self) -> Result<Package, PackageEncodingError> {
155        let super::package_json::PackageDeser {
156            modules,
157            extensions: pkg_extensions,
158        } = serde_json::from_reader(&mut self.reader)?;
159        let modules = modules.into_iter().map(|h| h.0).collect_vec();
160        let pkg_extensions = ExtensionRegistry::new_with_extension_resolution(
161            pkg_extensions,
162            &WeakExtensionRegistry::from(&self.registry),
163        )?;
164
165        // Resolve the operations in the modules using the defined registries.
166        self.register_packaged(&pkg_extensions);
167        Ok(Package {
168            modules,
169            extensions: pkg_extensions,
170        })
171    }
172    /// Read a HUGR model payload from a reader.
173    fn decode_model(&mut self) -> Result<Package, ModelBinaryReadError> {
174        check_model_version(self.header().format)?;
175        let bump = Bump::default();
176        let model_package = hugr_model::v0::binary::read_from_reader(&mut self.reader, &bump)?;
177
178        let packaged_extensions = if self.header().format == EnvelopeFormat::ModelWithExtensions {
179            ExtensionRegistry::load_json(&mut self.reader, &self.registry)?
180        } else {
181            ExtensionRegistry::new([])
182        };
183        self.register_packaged(&packaged_extensions);
184
185        self.import_package(&model_package, packaged_extensions)
186            .map_err(Into::into)
187    }
188
189    /// Read a HUGR model text payload from a reader.
190    fn decode_model_ast(&mut self) -> Result<Package, ModelTextReadError> {
191        let format = self.header().format;
192        check_model_version(format)?;
193
194        let packaged_extensions = if format == EnvelopeFormat::ModelTextWithExtensions {
195            let deserializer = serde_json::Deserializer::from_reader(&mut self.reader);
196            // Deserialize the first json object, leaving the rest of the reader unconsumed.
197            let extra_extensions = deserializer
198                .into_iter::<Vec<Extension>>()
199                .next()
200                .unwrap_or(Ok(vec![]))?;
201            ExtensionRegistry::new(extra_extensions.into_iter().map(std::sync::Arc::new))
202        } else {
203            ExtensionRegistry::new([])
204        };
205
206        // Read the package into a string, then parse it.
207        //
208        // Due to how `to_string` works, we cannot append extensions after the package.
209        let mut buffer = String::new();
210        self.reader.read_to_string(&mut buffer)?;
211        let ast_package = hugr_model::v0::ast::Package::from_str(&buffer)?;
212
213        let bump = Bump::default();
214        let model_package = ast_package.resolve(&bump)?;
215
216        self.import_package(&model_package, packaged_extensions)
217            .map_err(Into::into)
218    }
219
220    fn import_package(
221        &mut self,
222        package: &table::Package,
223        packaged_extensions: ExtensionRegistry,
224    ) -> Result<Package, crate::import::ImportError> {
225        self.description.set_n_modules(package.modules.len());
226
227        let modules = package
228            .modules
229            .iter()
230            .enumerate()
231            .map(|(index, module)| {
232                let (desc, result) = import_described_hugr(module, &self.registry);
233                self.description.set_module(index, desc);
234                result
235            })
236            .collect::<Result<Vec<_>, _>>()?;
237
238        // This does not panic since the import already requires a module root.
239        let mut package = Package::new(modules);
240        package.extensions = packaged_extensions;
241        Ok(package)
242    }
243}
244
/// Error decoding an envelope payload.
///
/// Opaque wrapper around the private `PayloadErrorInner` enum; its display
/// output is that of the wrapped error.
#[derive(Error, Debug)]
#[non_exhaustive]
#[error(transparent)]
pub struct PayloadError(PayloadErrorInner);
250
/// Error decoding an envelope payload with enumerated variants.
///
/// Every variant forwards its display output to the error it wraps.
#[derive(Error, Debug)]
#[non_exhaustive]
#[error(transparent)]
enum PayloadErrorInner {
    /// Error decoding a JSON format package.
    JsonRead(#[from] PackageEncodingError),
    /// Error decoding a binary model format package.
    ModelBinary(#[from] ModelBinaryReadError),
    /// Error decoding a text model format package.
    ModelText(#[from] ModelTextReadError),
    /// Error raised while checking for breaking extension version mismatch.
    ExtensionsBreaking(#[from] ExtensionBreakingError),
    /// Error resolving extensions while decoding the payload.
    ExtensionResolution(#[from] ExtensionResolutionError),
}
267impl From<PayloadError> for EnvelopeError {
268    fn from(value: PayloadError) -> Self {
269        match value.0 {
270            PayloadErrorInner::JsonRead(e) => e.into(),
271            PayloadErrorInner::ModelBinary(e) => e.into(),
272            PayloadErrorInner::ModelText(e) => e.into(),
273            #[expect(deprecated)]
274            PayloadErrorInner::ExtensionsBreaking(e) => super::WithGenerator {
275                inner: Box::new(e),
276                generator: None,
277            }
278            .into(),
279            PayloadErrorInner::ExtensionResolution(e) => e.into(),
280        }
281    }
282}
283
284impl<T: Into<PayloadErrorInner>> From<T> for PayloadError {
285    fn from(value: T) -> Self {
286        Self(value.into())
287    }
288}
289
/// Error reading a text model format payload.
///
/// Every variant forwards its display output to the error it wraps.
#[derive(Debug, Error)]
#[error(transparent)]
enum ModelTextReadError {
    /// The model text could not be parsed.
    ParseString(#[from] hugr_model::v0::ast::ParseError),
    /// A module could not be imported into a HUGR.
    Import(#[from] ImportError),
    /// An extension registry could not be loaded.
    ExtensionLoad(#[from] crate::extension::ExtensionRegistryLoadError),
    /// The envelope format failed the model version check.
    FormatUnsupported(#[from] FormatUnsupportedError),
    /// The packaged extensions could not be deserialized from JSON.
    ExtensionDeserialize(#[from] serde_json::Error),
    /// The payload could not be read into a string.
    StringRead(#[from] std::io::Error),
    /// The parsed model text could not be resolved.
    ResolveError(#[from] hugr_model::v0::ast::ResolveError),
}
301impl From<ModelTextReadError> for EnvelopeError {
302    fn from(value: ModelTextReadError) -> Self {
303        match value {
304            ModelTextReadError::FormatUnsupported(e) => EnvelopeError::FormatUnsupported {
305                format: e.format,
306                feature: e.feature,
307            },
308            ModelTextReadError::ParseString(e) => e.into(),
309            ModelTextReadError::Import(e) => e.into(),
310            ModelTextReadError::ExtensionLoad(e) => e.into(),
311            ModelTextReadError::ExtensionDeserialize(e) => e.into(),
312            ModelTextReadError::StringRead(e) => e.into(),
313            ModelTextReadError::ResolveError(e) => e.into(),
314        }
315    }
316}
317
/// Error reading a binary model format payload.
///
/// Every variant forwards its display output to the error it wraps.
#[derive(Debug, Error)]
#[error(transparent)]
enum ModelBinaryReadError {
    /// A model text parse error.
    // NOTE(review): not produced by the binary decoder visible in this file —
    // confirm whether this variant is still needed.
    ParseString(#[from] hugr_model::v0::ast::ParseError),
    /// The binary model could not be read.
    ReadBinary(#[from] hugr_model::v0::binary::ReadError),
    /// A module could not be imported into a HUGR.
    Import(#[from] ImportError),
    /// The packaged extensions could not be loaded.
    Extensions(#[from] crate::extension::ExtensionRegistryLoadError),
    /// The envelope format failed the model version check.
    FormatUnsupported(#[from] FormatUnsupportedError),
}
327
328impl From<ModelBinaryReadError> for EnvelopeError {
329    fn from(value: ModelBinaryReadError) -> Self {
330        match value {
331            ModelBinaryReadError::FormatUnsupported(e) => EnvelopeError::FormatUnsupported {
332                format: e.format,
333                feature: e.feature,
334            },
335            ModelBinaryReadError::ParseString(e) => e.into(),
336            ModelBinaryReadError::ReadBinary(e) => e.into(),
337            ModelBinaryReadError::Import(e) => e.into(),
338            ModelBinaryReadError::Extensions(e) => e.into(),
339        }
340    }
341}
342
#[cfg(test)]
mod test {
    use super::*;

    use crate::extension::ExtensionRegistry;

    use crate::envelope::header::EnvelopeHeader;
    use cool_asserts::assert_matches;

    use std::io::{Cursor, Write as _};

    /// Write a default header with the given format followed by `payload`
    /// into an in-memory buffer, then read it back with an `EnvelopeReader`
    /// using an empty extension registry.
    ///
    /// Returns the header that was written together with the outcome of
    /// [`EnvelopeReader::read`].
    fn read_envelope(
        format: EnvelopeFormat,
        payload: &[u8],
    ) -> (EnvelopeHeader, PackageDesc, Result<Package, PayloadError>) {
        let header = EnvelopeHeader {
            format,
            ..Default::default()
        };
        let mut cursor = Cursor::new(Vec::new());
        header.write(&mut cursor).unwrap();
        cursor.write_all(payload).unwrap();
        cursor.set_position(0);

        let registry = ExtensionRegistry::new([]);
        let reader = EnvelopeReader::new(cursor, &registry).unwrap();
        let (description, result) = reader.read();
        (header, description, result)
    }

    #[test]
    fn test_read_invalid_header() {
        // An empty buffer does not contain a header at all.
        let empty = Cursor::new(Vec::new());
        let registry = ExtensionRegistry::new([]);
        assert!(EnvelopeReader::new(empty, &registry).is_err());
    }

    #[test]
    fn test_read_invalid_json_payload() {
        let (header, description, result) =
            read_envelope(EnvelopeFormat::PackageJson, b"invalid json");

        assert_matches!(result, Err(PayloadError(PayloadErrorInner::JsonRead(_))));
        assert_eq!(description.header, header);
    }

    #[test]
    fn test_read_text_format() {
        // Header only, no payload.
        let (header, description, result) =
            read_envelope(EnvelopeFormat::ModelTextWithExtensions, b"");

        assert_matches!(result, Err(PayloadError(PayloadErrorInner::ModelText(_))));
        assert_eq!(description.header, header);
    }

    #[test]
    fn test_partial_description_on_error() {
        // Well-formed JSON, but the module entry has an invalid structure.
        let (header, description, result) =
            read_envelope(EnvelopeFormat::PackageJson, b"{\"modules\": [\"invalid\"]}");

        assert_matches!(result, Err(PayloadError(PayloadErrorInner::JsonRead(_))));
        assert_eq!(description.header, header);
        assert_eq!(description.n_modules(), 0); // No valid modules should be set
    }
}