1use std::io::{BufRead, Read};
2use std::str::FromStr as _;
3
4use hugr_model::v0::table;
5use itertools::{Either, Itertools as _};
6
7use crate::HugrView as _;
8use crate::envelope::description::PackageDesc;
9use crate::envelope::header::{EnvelopeFormat, HeaderError};
10use crate::envelope::{
11 EnvelopeError, EnvelopeHeader, ExtensionBreakingError, FormatUnsupportedError,
12};
13use crate::extension::resolution::{ExtensionResolutionError, WeakExtensionRegistry};
14use crate::extension::{Extension, ExtensionRegistry};
15use crate::import::{ImportError, import_described_hugr};
16use crate::package::Package;
17
18use super::{check_breaking_extensions, check_model_version, package_json::PackageEncodingError};
19use thiserror::Error;
20
21use hugr_model::v0::bumpalo::Bump;
/// Reader used for compressed payloads: a buffered zstd decoder over the
/// (buffered) input stream.
#[cfg(feature = "zstd")]
type RightType<R> = std::io::BufReader<zstd::Decoder<'static, std::io::BufReader<R>>>;
/// Placeholder used when the `zstd` feature is disabled. `EnvelopeReader::new`
/// rejects compressed payloads in this configuration, so this variant is never built.
#[cfg(not(feature = "zstd"))]
type RightType<R> = std::io::BufReader<R>;
26
27pub(crate) struct MaybeZstdRead<R>(Either<R, RightType<R>>);
28
29impl<R> std::io::Read for MaybeZstdRead<R>
30where
31 R: std::io::Read,
32{
33 fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
34 match &mut self.0 {
35 Either::Left(r) => r.read(buf),
36 Either::Right(r) => r.read(buf),
37 }
38 }
39}
40
41impl<R> std::io::BufRead for MaybeZstdRead<R>
42where
43 R: std::io::BufRead,
44{
45 fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
46 match &mut self.0 {
47 Either::Left(r) => r.fill_buf(),
48 Either::Right(r) => r.fill_buf(),
49 }
50 }
51
52 fn consume(&mut self, amt: usize) {
53 match &mut self.0 {
54 Either::Left(r) => r.consume(amt),
55 Either::Right(r) => r.consume(amt),
56 }
57 }
58}
59
/// Reader for a HUGR envelope: the header is parsed by [`EnvelopeReader::new`] and
/// the payload is decoded by [`EnvelopeReader::read`].
pub(super) struct EnvelopeReader<R> {
    /// Description of the package, filled in as decoding progresses.
    description: PackageDesc,
    /// Payload stream positioned after the header (decompressed when required).
    reader: MaybeZstdRead<R>,
    /// Extensions available for importing/resolving the package; extensions shipped
    /// inside the package may be added to it via `register_packaged`.
    registry: ExtensionRegistry,
}
69
70impl<R: BufRead> EnvelopeReader<R> {
71 pub(super) fn new(mut reader: R, registry: &ExtensionRegistry) -> Result<Self, HeaderError> {
79 let header = EnvelopeHeader::read(&mut reader)?;
80 let reader = match header.zstd {
81 #[cfg(feature = "zstd")]
82 true => Either::Right(std::io::BufReader::new(zstd::Decoder::new(reader)?)),
83 #[cfg(not(feature = "zstd"))]
84 true => Err(super::header::HeaderErrorInner::ZstdUnsupported)?,
85 false => Either::Left(reader),
86 };
87 Ok(Self {
88 description: PackageDesc::new(header),
89 reader: MaybeZstdRead(reader),
90 registry: registry.clone(),
91 })
92 }
93
94 pub(crate) fn description(&self) -> &PackageDesc {
95 &self.description
96 }
97
98 fn header(&self) -> &EnvelopeHeader {
99 &self.description.header
100 }
101
102 fn register_packaged(&mut self, extensions: &ExtensionRegistry) {
103 self.registry.extend(extensions);
104 }
105
106 fn read_impl(&mut self) -> Result<Package, PayloadError> {
107 let mut package = match self.header().format {
108 EnvelopeFormat::PackageJson => self.decode_json()?,
109 EnvelopeFormat::Model | EnvelopeFormat::ModelWithExtensions => self.decode_model()?,
110 EnvelopeFormat::ModelText | EnvelopeFormat::ModelTextWithExtensions => {
111 self.decode_model_ast()?
112 }
113 };
114 self.description.set_n_modules(package.modules.len());
115 for (index, module) in package.modules.iter_mut().enumerate() {
116 let desc = &mut self.description.modules[index];
117 let desc = desc.get_or_insert_default();
118 desc.load_used_extensions_generator(module)
119 .map_err(ExtensionBreakingError::from)?;
120 if let Some(used_exts) = &mut desc.used_extensions_generator {
121 check_breaking_extensions(module.extensions(), used_exts.drain(..))?;
122 }
123
124 module.resolve_extension_defs(&self.registry)?;
125 desc.load_from_hugr(&module);
128 }
129
130 for (index, ext) in package.extensions.iter().enumerate() {
131 self.description.set_packaged_extension(index, ext);
132 }
133 Ok(package)
134 }
135
136 pub(super) fn read(mut self) -> (PackageDesc, Result<Package, PayloadError>) {
146 let res = self.read_impl();
147
148 (self.description, res)
149 }
150
151 fn decode_json(&mut self) -> Result<Package, PackageEncodingError> {
155 let super::package_json::PackageDeser {
156 modules,
157 extensions: pkg_extensions,
158 } = serde_json::from_reader(&mut self.reader)?;
159 let modules = modules.into_iter().map(|h| h.0).collect_vec();
160 let pkg_extensions = ExtensionRegistry::new_with_extension_resolution(
161 pkg_extensions,
162 &WeakExtensionRegistry::from(&self.registry),
163 )?;
164
165 self.register_packaged(&pkg_extensions);
167 Ok(Package {
168 modules,
169 extensions: pkg_extensions,
170 })
171 }
172 fn decode_model(&mut self) -> Result<Package, ModelBinaryReadError> {
174 check_model_version(self.header().format)?;
175 let bump = Bump::default();
176 let model_package = hugr_model::v0::binary::read_from_reader(&mut self.reader, &bump)?;
177
178 let packaged_extensions = if self.header().format == EnvelopeFormat::ModelWithExtensions {
179 ExtensionRegistry::load_json(&mut self.reader, &self.registry)?
180 } else {
181 ExtensionRegistry::new([])
182 };
183 self.register_packaged(&packaged_extensions);
184
185 self.import_package(&model_package, packaged_extensions)
186 .map_err(Into::into)
187 }
188
189 fn decode_model_ast(&mut self) -> Result<Package, ModelTextReadError> {
191 let format = self.header().format;
192 check_model_version(format)?;
193
194 let packaged_extensions = if format == EnvelopeFormat::ModelTextWithExtensions {
195 let deserializer = serde_json::Deserializer::from_reader(&mut self.reader);
196 let extra_extensions = deserializer
198 .into_iter::<Vec<Extension>>()
199 .next()
200 .unwrap_or(Ok(vec![]))?;
201 ExtensionRegistry::new(extra_extensions.into_iter().map(std::sync::Arc::new))
202 } else {
203 ExtensionRegistry::new([])
204 };
205
206 let mut buffer = String::new();
210 self.reader.read_to_string(&mut buffer)?;
211 let ast_package = hugr_model::v0::ast::Package::from_str(&buffer)?;
212
213 let bump = Bump::default();
214 let model_package = ast_package.resolve(&bump)?;
215
216 self.import_package(&model_package, packaged_extensions)
217 .map_err(Into::into)
218 }
219
220 fn import_package(
221 &mut self,
222 package: &table::Package,
223 packaged_extensions: ExtensionRegistry,
224 ) -> Result<Package, crate::import::ImportError> {
225 self.description.set_n_modules(package.modules.len());
226
227 let modules = package
228 .modules
229 .iter()
230 .enumerate()
231 .map(|(index, module)| {
232 let (desc, result) = import_described_hugr(module, &self.registry);
233 self.description.set_module(index, desc);
234 result
235 })
236 .collect::<Result<Vec<_>, _>>()?;
237
238 let mut package = Package::new(modules);
240 package.extensions = packaged_extensions;
241 Ok(package)
242 }
243}
244
/// Error raised while decoding an envelope payload.
///
/// The concrete cause is kept in a private inner enum; convert into
/// [`EnvelopeError`] to inspect it.
#[derive(Error, Debug)]
#[non_exhaustive]
#[error(transparent)]
pub struct PayloadError(PayloadErrorInner);
250
/// Private set of causes wrapped by [`PayloadError`].
#[derive(Error, Debug)]
#[non_exhaustive]
#[error(transparent)]
enum PayloadErrorInner {
    /// Decoding a `PackageJson` payload failed.
    JsonRead(#[from] PackageEncodingError),
    /// Decoding a binary model payload failed.
    ModelBinary(#[from] ModelBinaryReadError),
    /// Decoding a textual model payload failed.
    ModelText(#[from] ModelTextReadError),
    /// Loading or checking the extensions recorded by the module's generator failed.
    ExtensionsBreaking(#[from] ExtensionBreakingError),
    /// Resolving a module's extension definitions against the registry failed.
    ExtensionResolution(#[from] ExtensionResolutionError),
}
267impl From<PayloadError> for EnvelopeError {
268 fn from(value: PayloadError) -> Self {
269 match value.0 {
270 PayloadErrorInner::JsonRead(e) => e.into(),
271 PayloadErrorInner::ModelBinary(e) => e.into(),
272 PayloadErrorInner::ModelText(e) => e.into(),
273 #[expect(deprecated)]
274 PayloadErrorInner::ExtensionsBreaking(e) => super::WithGenerator {
275 inner: Box::new(e),
276 generator: None,
277 }
278 .into(),
279 PayloadErrorInner::ExtensionResolution(e) => e.into(),
280 }
281 }
282}
283
284impl<T: Into<PayloadErrorInner>> From<T> for PayloadError {
285 fn from(value: T) -> Self {
286 Self(value.into())
287 }
288}
289
/// Errors raised while decoding a textual model payload
/// (`ModelText` / `ModelTextWithExtensions`).
#[derive(Debug, Error)]
#[error(transparent)]
enum ModelTextReadError {
    /// The model text could not be parsed.
    ParseString(#[from] hugr_model::v0::ast::ParseError),
    /// Importing the resolved model into HUGRs failed.
    Import(#[from] ImportError),
    /// Loading packaged extensions failed.
    // NOTE(review): not produced by the text decoder as visible here — confirm it is still needed.
    ExtensionLoad(#[from] crate::extension::ExtensionRegistryLoadError),
    /// The envelope format is not supported (presumably from `check_model_version`).
    FormatUnsupported(#[from] FormatUnsupportedError),
    /// The leading JSON list of packaged extensions could not be deserialized.
    ExtensionDeserialize(#[from] serde_json::Error),
    /// Reading the model text from the stream failed.
    StringRead(#[from] std::io::Error),
    /// Resolving the parsed model failed.
    ResolveError(#[from] hugr_model::v0::ast::ResolveError),
}
301impl From<ModelTextReadError> for EnvelopeError {
302 fn from(value: ModelTextReadError) -> Self {
303 match value {
304 ModelTextReadError::FormatUnsupported(e) => EnvelopeError::FormatUnsupported {
305 format: e.format,
306 feature: e.feature,
307 },
308 ModelTextReadError::ParseString(e) => e.into(),
309 ModelTextReadError::Import(e) => e.into(),
310 ModelTextReadError::ExtensionLoad(e) => e.into(),
311 ModelTextReadError::ExtensionDeserialize(e) => e.into(),
312 ModelTextReadError::StringRead(e) => e.into(),
313 ModelTextReadError::ResolveError(e) => e.into(),
314 }
315 }
316}
317
/// Errors raised while decoding a binary model payload
/// (`Model` / `ModelWithExtensions`).
#[derive(Debug, Error)]
#[error(transparent)]
enum ModelBinaryReadError {
    /// Text parse failure.
    // NOTE(review): no binary decoding path visible here produces this — confirm it is still needed.
    ParseString(#[from] hugr_model::v0::ast::ParseError),
    /// Reading the binary model from the stream failed.
    ReadBinary(#[from] hugr_model::v0::binary::ReadError),
    /// Importing the model into HUGRs failed.
    Import(#[from] ImportError),
    /// Loading the packaged extensions that follow the model failed.
    Extensions(#[from] crate::extension::ExtensionRegistryLoadError),
    /// The envelope format is not supported (presumably from `check_model_version`).
    FormatUnsupported(#[from] FormatUnsupportedError),
}
327
328impl From<ModelBinaryReadError> for EnvelopeError {
329 fn from(value: ModelBinaryReadError) -> Self {
330 match value {
331 ModelBinaryReadError::FormatUnsupported(e) => EnvelopeError::FormatUnsupported {
332 format: e.format,
333 feature: e.feature,
334 },
335 ModelBinaryReadError::ParseString(e) => e.into(),
336 ModelBinaryReadError::ReadBinary(e) => e.into(),
337 ModelBinaryReadError::Import(e) => e.into(),
338 ModelBinaryReadError::Extensions(e) => e.into(),
339 }
340 }
341}
342
#[cfg(test)]
mod test {
    use super::*;

    use crate::extension::ExtensionRegistry;

    use crate::envelope::header::EnvelopeHeader;
    use cool_asserts::assert_matches;

    use std::io::{Cursor, Write as _};

    /// An empty stream contains no header, so construction must fail.
    #[test]
    fn test_read_invalid_header() {
        let cursor = Cursor::new(Vec::new());
        let registry = ExtensionRegistry::new([]);
        let result = EnvelopeReader::new(cursor, &registry);
        assert!(result.is_err());
    }

    /// A valid header followed by non-JSON bytes yields a JSON read error, and the
    /// header is still reported in the description.
    #[test]
    fn test_read_invalid_json_payload() {
        let header = EnvelopeHeader {
            format: EnvelopeFormat::PackageJson,
            ..Default::default()
        };
        let mut cursor = Cursor::new(Vec::new());
        header.write(&mut cursor).unwrap();
        cursor.write_all(b"invalid json").unwrap();
        cursor.set_position(0);

        let registry = ExtensionRegistry::new([]);
        let reader = EnvelopeReader::new(cursor, &registry).unwrap();
        let (description, result) = reader.read();

        assert_matches!(result, Err(PayloadError(PayloadErrorInner::JsonRead(_))));
        assert_eq!(description.header, header);
    }

    /// A text-format header with an empty payload results in a text-model read error.
    #[test]
    fn test_read_text_format() {
        let header = EnvelopeHeader {
            format: EnvelopeFormat::ModelTextWithExtensions,
            ..Default::default()
        };
        let mut cursor = Cursor::new(Vec::new());
        header.write(&mut cursor).unwrap();
        cursor.set_position(0);

        let registry = ExtensionRegistry::new([]);
        let reader = EnvelopeReader::new(cursor, &registry).unwrap();
        let (description, result) = reader.read();

        assert_matches!(result, Err(PayloadError(PayloadErrorInner::ModelText(_))));
        assert_eq!(description.header, header);
    }

    /// A JSON payload that fails to decode leaves the module count at zero, while the
    /// header is still reported.
    #[test]
    fn test_partial_description_on_error() {
        let header = EnvelopeHeader {
            format: EnvelopeFormat::PackageJson,
            ..Default::default()
        };
        let mut cursor = Cursor::new(Vec::new());
        header.write(&mut cursor).unwrap();
        cursor.write_all(b"{\"modules\": [\"invalid\"]}").unwrap();
        cursor.set_position(0);

        let registry = ExtensionRegistry::new([]);
        let reader = EnvelopeReader::new(cursor, &registry).unwrap();
        let (description, result) = reader.read();

        assert_matches!(result, Err(PayloadError(PayloadErrorInner::JsonRead(_))));
        assert_eq!(description.header, header);
        assert_eq!(description.n_modules(), 0);
    }
}