1use std::io::{Read, Write};
4use std::num::NonZeroU8;
5
6use itertools::Itertools;
7use thiserror::Error;
8
9use super::EnvelopeError;
10
/// Byte sequence that opens every envelope: the ASCII text `HUGRiHJv`.
///
/// [`EnvelopeHeader::write`] emits these 8 bytes first, and
/// [`EnvelopeHeader::read`] rejects any input that does not start with them.
pub const MAGIC_NUMBERS: &[u8] = b"HUGRiHJv";

/// Flags byte written when no optional flag is set.
///
/// Only bit 6 is set (`0x40`, ASCII `@`).
// NOTE(review): the always-set bit presumably keeps the flags byte printable — confirm.
const DEFAULT_FLAGS: u8 = 0b0100_0000;
/// Bit of the flags byte that is set when the header's `zstd` field is `true`.
const ZSTD_FLAG: u8 = 0b0000_0001;
22
23#[derive(Clone, Copy, Eq, PartialEq, Debug, Default, derive_more::Display)]
27#[display("EnvelopeHeader({format}{})",
28 if *zstd { ", zstd compressed" } else { "" },
29)]
30pub struct EnvelopeHeader {
31 pub format: EnvelopeFormat,
33 pub zstd: bool,
35}
36
37#[derive(
39 Clone, Copy, Eq, PartialEq, Debug, Default, Hash, derive_more::Display, strum::FromRepr,
40)]
41#[non_exhaustive]
42pub enum EnvelopeFormat {
43 Model = 1,
45 ModelWithExtensions = 2,
51 ModelText = 40, ModelTextWithExtensions = 41, #[default]
74 PackageJson = 63, }
76
77static_assertions::assert_eq_size!(EnvelopeFormat, u8);
80
81impl EnvelopeFormat {
82 #[must_use]
84 pub fn model_version(self) -> Option<u32> {
85 match self {
86 Self::Model
87 | Self::ModelWithExtensions
88 | Self::ModelText
89 | Self::ModelTextWithExtensions => Some(0),
90 _ => None,
91 }
92 }
93
94 #[must_use]
98 pub fn ascii_printable(self) -> bool {
99 matches!(
100 self,
101 Self::PackageJson | Self::ModelText | Self::ModelTextWithExtensions
102 )
103 }
104}
105
106#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
108#[non_exhaustive]
109pub struct EnvelopeConfig {
110 pub format: EnvelopeFormat,
112 pub zstd: Option<ZstdConfig>,
115}
116
117impl EnvelopeConfig {
118 pub fn new(format: EnvelopeFormat) -> Self {
121 Self {
122 format,
123 ..Default::default()
124 }
125 }
126
127 pub fn with_zstd(self, zstd: ZstdConfig) -> Self {
129 Self {
130 zstd: Some(zstd),
131 ..self
132 }
133 }
134
135 pub fn disable_compression(self) -> Self {
137 Self { zstd: None, ..self }
138 }
139
140 pub(super) fn make_header(&self) -> EnvelopeHeader {
142 EnvelopeHeader {
143 format: self.format,
144 zstd: self.zstd.is_some(),
145 }
146 }
147
148 #[must_use]
150 pub const fn text() -> Self {
151 Self {
152 format: EnvelopeFormat::PackageJson,
153 zstd: None,
154 }
155 }
156
157 #[must_use]
161 pub const fn binary() -> Self {
162 Self {
163 format: EnvelopeFormat::ModelWithExtensions,
164 zstd: Some(ZstdConfig::default_level()),
165 }
166 }
167}
168
/// Settings for zstd compression.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
#[non_exhaustive]
pub struct ZstdConfig {
    /// Requested compression level.
    ///
    /// `None` selects the fallback level described on [`ZstdConfig::level`].
    pub level: Option<NonZeroU8>,
}

impl ZstdConfig {
    /// Creates a configuration with the given compression level.
    ///
    /// A `level` of `0` is stored as `None`, i.e. it selects the default
    /// level.
    pub fn new(level: u8) -> Self {
        let level = NonZeroU8::new(level);
        Self { level }
    }

    /// Creates a configuration that uses the default compression level.
    #[must_use]
    pub const fn default_level() -> Self {
        Self { level: None }
    }

    /// Returns the compression level as an `i32`.
    ///
    /// When no explicit level is set, this is
    /// `zstd::DEFAULT_COMPRESSION_LEVEL` if the `zstd` feature is enabled and
    /// `0` otherwise.
    #[must_use]
    pub fn level(&self) -> i32 {
        // Exactly one of these two bindings survives conditional compilation.
        #[cfg(feature = "zstd")]
        let fallback: i32 = zstd::DEFAULT_COMPRESSION_LEVEL;
        #[cfg(not(feature = "zstd"))]
        let fallback: i32 = 0;

        match self.level {
            Some(level) => i32::from(level.get()),
            None => fallback,
        }
    }
}
210
211#[derive(Debug, Error, derive_more::Display)]
212#[display("Error reading the envelope header. {_0}")]
213pub struct HeaderError(HeaderErrorInner);
214
215#[derive(Debug, Error)]
216#[non_exhaustive]
217pub(super) enum HeaderErrorInner {
218 #[error(
220 "Bad magic number. expected 0x{:X} found 0x{:X}",
221 u64::from_be_bytes(*expected),
222 u64::from_be_bytes(*found)
223 )]
224 MagicNumber {
225 expected: [u8; 8],
229 found: [u8; 8],
231 },
232 #[error("Format descriptor {descriptor} is invalid.")]
234 InvalidFormatDescriptor {
235 descriptor: usize,
237 },
238 #[error(transparent)]
240 IO {
241 #[from]
243 source: std::io::Error,
244 },
245 #[error(
247 "The envelope configuration has unknown {}. Please update your HUGR version.",
248 if flag_ids.len() == 1 {format!("flag #{}", flag_ids[0])} else {format!("flags {}", flag_ids.iter().join(", "))}
249 )]
250 FlagUnsupported {
251 flag_ids: Vec<usize>,
253 },
254 #[cfg(not(feature = "zstd"))]
255 #[error("Zstd compression is not supported. This requires the 'zstd' feature for `hugr`.")]
257 ZstdUnsupported,
258}
259impl From<HeaderError> for EnvelopeError {
260 fn from(err: HeaderError) -> Self {
261 match err.0 {
262 HeaderErrorInner::IO { source } => EnvelopeError::IO { source },
263 HeaderErrorInner::MagicNumber { expected, found } => {
264 EnvelopeError::MagicNumber { expected, found }
265 }
266 HeaderErrorInner::InvalidFormatDescriptor { descriptor } => {
267 EnvelopeError::InvalidFormatDescriptor { descriptor }
268 }
269 HeaderErrorInner::FlagUnsupported { flag_ids } => {
270 EnvelopeError::FlagUnsupported { flag_ids }
271 }
272 #[cfg(not(feature = "zstd"))]
273 HeaderErrorInner::ZstdUnsupported => EnvelopeError::ZstdUnsupported,
274 }
275 }
276}
277
278impl<T: Into<HeaderErrorInner>> From<T> for HeaderError {
279 fn from(value: T) -> Self {
280 Self(value.into())
281 }
282}
283impl EnvelopeHeader {
284 pub fn config(&self) -> EnvelopeConfig {
288 EnvelopeConfig {
289 format: self.format,
290 zstd: if self.zstd {
291 Some(ZstdConfig { level: None })
292 } else {
293 None
294 },
295 }
296 }
297
298 pub fn write(&self, writer: &mut impl Write) -> Result<(), EnvelopeError> {
302 writer.write_all(MAGIC_NUMBERS)?;
304 let format_bytes = [self.format as u8];
306 writer.write_all(&format_bytes)?;
307 let mut flags = DEFAULT_FLAGS;
309 if self.zstd {
310 flags |= ZSTD_FLAG;
311 }
312 writer.write_all(&[flags])?;
313
314 Ok(())
315 }
316
317 pub fn read(reader: &mut impl Read) -> Result<EnvelopeHeader, HeaderError> {
322 let mut magic = [0; 8];
324 reader.read_exact(&mut magic)?;
325 if magic != MAGIC_NUMBERS {
326 return Err(HeaderErrorInner::MagicNumber {
327 expected: MAGIC_NUMBERS.try_into().unwrap(),
328 found: magic,
329 }
330 .into());
331 }
332
333 let mut format_bytes = [0; 1];
335 reader.read_exact(&mut format_bytes)?;
336 let format_discriminant = format_bytes[0] as usize;
337 let Some(format) = EnvelopeFormat::from_repr(format_discriminant) else {
338 return Err(HeaderErrorInner::InvalidFormatDescriptor {
339 descriptor: format_discriminant,
340 }
341 .into());
342 };
343
344 let mut flags_bytes = [0; 1];
346 reader.read_exact(&mut flags_bytes)?;
347 let flags: u8 = flags_bytes[0];
348
349 let zstd = flags & ZSTD_FLAG != 0;
350
351 let other_flags = (flags ^ DEFAULT_FLAGS) & !ZSTD_FLAG;
353 if other_flags != 0 {
354 let flag_ids = (0..8).filter(|i| other_flags & (1 << i) != 0).collect_vec();
355 return Err(HeaderErrorInner::FlagUnsupported { flag_ids }.into());
356 }
357
358 Ok(Self { format, zstd })
359 }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use cool_asserts::assert_matches;
    use rstest::rstest;

    /// Writing a header and reading it back yields the same header, with and
    /// without the zstd flag, for every format.
    #[rstest]
    #[case(EnvelopeFormat::Model)]
    #[case(EnvelopeFormat::ModelWithExtensions)]
    #[case(EnvelopeFormat::ModelText)]
    #[case(EnvelopeFormat::ModelTextWithExtensions)]
    #[case(EnvelopeFormat::PackageJson)]
    fn header_round_trip(#[case] format: EnvelopeFormat) {
        for zstd in [true, false] {
            let header = EnvelopeHeader { format, zstd };

            let mut buffer = Vec::new();
            header.write(&mut buffer).unwrap();
            let read_header = EnvelopeHeader::read(&mut buffer.as_slice()).unwrap();
            assert_eq!(header, read_header);
        }
    }

    /// Corrupted magic numbers and unknown flag bits are rejected with the
    /// matching error variant.
    #[rstest]
    fn header_errors() {
        let mut buffer = Vec::new();
        EnvelopeHeader {
            format: EnvelopeFormat::Model,
            zstd: false,
        }
        .write(&mut buffer)
        .unwrap();

        // 8 magic bytes + 1 format byte + 1 flags byte.
        assert_eq!(buffer.len(), 10);
        assert_eq!(buffer[9], DEFAULT_FLAGS);

        // Corrupt the last byte of the magic number.
        let mut invalid_magic = buffer.clone();
        invalid_magic[7] = 0xFF;
        assert_matches!(
            EnvelopeHeader::read(&mut invalid_magic.as_slice()),
            Err(HeaderError(HeaderErrorInner::MagicNumber { .. }))
        );

        // Set two flag bits (1 and 4) that the reader does not know about.
        let mut unrecognised_flags = buffer;
        unrecognised_flags[9] |= 0b0001_0010;
        assert_matches!(
            EnvelopeHeader::read(&mut unrecognised_flags.as_slice()),
            Err(HeaderError(HeaderErrorInner::FlagUnsupported { flag_ids }))
                => assert_eq!(flag_ids, vec![1, 4])
        );
    }
}