hugr_core/envelope/header.rs

1//! Definitions for the header of an envelope.
2
3use std::io::{Read, Write};
4use std::num::NonZeroU8;
5
6use itertools::Itertools;
7use thiserror::Error;
8
9use super::EnvelopeError;
10
/// Magic number identifying the start of an envelope.
///
/// In ASCII, this is "`HUGRiHJv`". The second half is a randomly generated string
/// to avoid accidental collisions with other file formats. It is written to and
/// compared against the stream verbatim, byte by byte.
pub const MAGIC_NUMBERS: &[u8] = b"HUGRiHJv";

/// The all-unset header flags configuration.
///
/// Bit 6 (counting from 0 at the least significant bit, the same numbering used
/// when reporting unsupported flags) is always set so that the flags byte is a
/// printable ASCII character: `'@'`, or `'A'` when [`ZSTD_FLAG`] is also set.
const DEFAULT_FLAGS: u8 = 0b0100_0000u8;
/// The zstd flag bit (bit 0) in the header's flags.
const ZSTD_FLAG: u8 = 0b0000_0001;
22
23/// Header at the start of a binary envelope file.
24///
25/// See the [`crate::envelope`] module documentation for the binary format.
26#[derive(Clone, Copy, Eq, PartialEq, Debug, Default, derive_more::Display)]
27#[display("EnvelopeHeader({format}{})",
28    if *zstd { ", zstd compressed" } else { "" },
29)]
30pub struct EnvelopeHeader {
31    /// The format used for the payload.
32    pub format: EnvelopeFormat,
33    /// Whether the payload is compressed with zstd.
34    pub zstd: bool,
35}
36
37/// Encoded format of an envelope payload.
38#[derive(
39    Clone, Copy, Eq, PartialEq, Debug, Default, Hash, derive_more::Display, strum::FromRepr,
40)]
41#[non_exhaustive]
42pub enum EnvelopeFormat {
43    /// `hugr-model` v0 binary capnproto message.
44    Model = 1,
45    /// `hugr-model` v0 binary capnproto message followed by a json-encoded
46    /// [`crate::extension::ExtensionRegistry`].
47    ///
48    /// This is a temporary format required until the model adds support for
49    /// extensions.
50    ModelWithExtensions = 2,
51    /// Human-readable S-expression encoding using [`hugr_model::v0`].
52    ///
53    /// Uses a printable ascii value as the discriminant so the envelope can be
54    /// read as text.
55    ///
56    /// :caution: This format does not yet support extension encoding, so it should
57    /// be avoided.
58    //
59    // TODO: Update comment once extension encoding is supported.
60    ModelText = 40, // '(' in ascii
61    /// Human-readable S-expression encoding using [`hugr_model::v0`].
62    ///
63    /// Uses a printable ascii value as the discriminant so the envelope can be
64    /// read as text.
65    ///
66    /// This is a temporary format required until the model adds support for
67    /// extensions.
68    ModelTextWithExtensions = 41, // ')' in ascii
69    /// Json-encoded [`crate::package::Package`]
70    ///
71    /// Uses a printable ascii value as the discriminant so the envelope can be
72    /// read as text.
73    #[default]
74    PackageJson = 63, // '?' in ascii
75}
76
77// We use a u8 to represent EnvelopeFormat in the binary format, so we should not
78// add any non-unit variants or ones with discriminants > 255.
79static_assertions::assert_eq_size!(EnvelopeFormat, u8);
80
81impl EnvelopeFormat {
82    /// If the format is a model format, returns its version number.
83    #[must_use]
84    pub fn model_version(self) -> Option<u32> {
85        match self {
86            Self::Model
87            | Self::ModelWithExtensions
88            | Self::ModelText
89            | Self::ModelTextWithExtensions => Some(0),
90            _ => None,
91        }
92    }
93
94    /// Returns whether the encoding format is ASCII-printable.
95    ///
96    /// If true, the encoded envelope can be read as text.
97    #[must_use]
98    pub fn ascii_printable(self) -> bool {
99        matches!(
100            self,
101            Self::PackageJson | Self::ModelText | Self::ModelTextWithExtensions
102        )
103    }
104}
105
106/// Configuration for encoding an envelope.
107#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
108#[non_exhaustive]
109pub struct EnvelopeConfig {
110    /// The format to use for the payload.
111    pub format: EnvelopeFormat,
112    /// Whether to compress the payload with zstd, and the compression level to
113    /// use.
114    pub zstd: Option<ZstdConfig>,
115}
116
117impl EnvelopeConfig {
118    /// Create a new envelope configuration with the specified format.
119    /// `zstd` compression is disabled by default.
120    pub fn new(format: EnvelopeFormat) -> Self {
121        Self {
122            format,
123            ..Default::default()
124        }
125    }
126
127    /// Set the zstd compression configuration for the envelope.
128    pub fn with_zstd(self, zstd: ZstdConfig) -> Self {
129        Self {
130            zstd: Some(zstd),
131            ..self
132        }
133    }
134
135    /// Disable zstd compression in the envelope configuration.
136    pub fn disable_compression(self) -> Self {
137        Self { zstd: None, ..self }
138    }
139
140    /// Create a new envelope header with the specified configuration.
141    pub(super) fn make_header(&self) -> EnvelopeHeader {
142        EnvelopeHeader {
143            format: self.format,
144            zstd: self.zstd.is_some(),
145        }
146    }
147
148    /// Default configuration for a plain-text envelope.
149    #[must_use]
150    pub const fn text() -> Self {
151        Self {
152            format: EnvelopeFormat::PackageJson,
153            zstd: None,
154        }
155    }
156
157    /// Default configuration for a binary envelope.
158    ///
159    /// If the `zstd` feature is enabled, this will use zstd compression.
160    #[must_use]
161    pub const fn binary() -> Self {
162        Self {
163            format: EnvelopeFormat::ModelWithExtensions,
164            zstd: Some(ZstdConfig::default_level()),
165        }
166    }
167}
168
/// zstd compression settings.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub struct ZstdConfig {
    /// Compression level.
    ///
    /// zstd currently accepts levels 1 through 22. Lower levels are faster and
    /// higher levels compress better. Levels above 20 need considerably more
    /// memory and should be used with care.
    ///
    /// `None` selects zstd's default level.
    pub level: Option<NonZeroU8>,
}
182
183impl ZstdConfig {
184    /// Create a new zstd configuration with the specified compression level.
185    pub fn new(level: u8) -> Self {
186        Self {
187            level: NonZeroU8::new(level),
188        }
189    }
190    /// Create a new zstd configuration with default compression level.
191    #[must_use]
192    pub const fn default_level() -> Self {
193        Self { level: None }
194    }
195
196    /// Returns the zstd compression level to pass to the zstd library.
197    ///
198    /// Uses [`zstd::DEFAULT_COMPRESSION_LEVEL`] if the level is not set.
199    #[must_use]
200    pub fn level(&self) -> i32 {
201        #[allow(unused_assignments, unused_mut)]
202        let mut default = 0;
203        #[cfg(feature = "zstd")]
204        {
205            default = zstd::DEFAULT_COMPRESSION_LEVEL;
206        }
207        self.level.map_or(default, |l| i32::from(l.get()))
208    }
209}
210
211#[derive(Debug, Error, derive_more::Display)]
212#[display("Error reading the envelope header. {_0}")]
213pub struct HeaderError(HeaderErrorInner);
214
215#[derive(Debug, Error)]
216#[non_exhaustive]
217pub(super) enum HeaderErrorInner {
218    /// Bad magic number.
219    #[error(
220        "Bad magic number. expected 0x{:X} found 0x{:X}",
221        u64::from_be_bytes(*expected),
222        u64::from_be_bytes(*found)
223    )]
224    MagicNumber {
225        /// The expected magic number.
226        ///
227        /// See [`MAGIC_NUMBERS`].
228        expected: [u8; 8],
229        /// The magic number in the envelope.
230        found: [u8; 8],
231    },
232    /// The specified payload format is invalid.
233    #[error("Format descriptor {descriptor} is invalid.")]
234    InvalidFormatDescriptor {
235        /// The unsupported format.
236        descriptor: usize,
237    },
238    /// IO read/write error.
239    #[error(transparent)]
240    IO {
241        /// The source error.
242        #[from]
243        source: std::io::Error,
244    },
245    /// The specified payload format is not supported.
246    #[error(
247        "The envelope configuration has unknown {}. Please update your HUGR version.",
248        if flag_ids.len() == 1 {format!("flag #{}", flag_ids[0])} else {format!("flags {}", flag_ids.iter().join(", "))}
249    )]
250    FlagUnsupported {
251        /// The unrecognized flag bits.
252        flag_ids: Vec<usize>,
253    },
254    #[cfg(not(feature = "zstd"))]
255    /// Envelope encoding required zstd compression, but the feature is not enabled.
256    #[error("Zstd compression is not supported. This requires the 'zstd' feature for `hugr`.")]
257    ZstdUnsupported,
258}
259impl From<HeaderError> for EnvelopeError {
260    fn from(err: HeaderError) -> Self {
261        match err.0 {
262            HeaderErrorInner::IO { source } => EnvelopeError::IO { source },
263            HeaderErrorInner::MagicNumber { expected, found } => {
264                EnvelopeError::MagicNumber { expected, found }
265            }
266            HeaderErrorInner::InvalidFormatDescriptor { descriptor } => {
267                EnvelopeError::InvalidFormatDescriptor { descriptor }
268            }
269            HeaderErrorInner::FlagUnsupported { flag_ids } => {
270                EnvelopeError::FlagUnsupported { flag_ids }
271            }
272            #[cfg(not(feature = "zstd"))]
273            HeaderErrorInner::ZstdUnsupported => EnvelopeError::ZstdUnsupported,
274        }
275    }
276}
277
278impl<T: Into<HeaderErrorInner>> From<T> for HeaderError {
279    fn from(value: T) -> Self {
280        Self(value.into())
281    }
282}
283impl EnvelopeHeader {
284    /// Returns the envelope configuration corresponding to this header.
285    ///
286    /// Note that zstd compression level is not stored in the header.
287    pub fn config(&self) -> EnvelopeConfig {
288        EnvelopeConfig {
289            format: self.format,
290            zstd: if self.zstd {
291                Some(ZstdConfig { level: None })
292            } else {
293                None
294            },
295        }
296    }
297
298    /// Write an envelope header to a writer.
299    ///
300    /// See the [`crate::envelope`] module documentation for the binary format.
301    pub fn write(&self, writer: &mut impl Write) -> Result<(), EnvelopeError> {
302        // The first 8 bytes are the magic number in little-endian.
303        writer.write_all(MAGIC_NUMBERS)?;
304        // Next is the format descriptor.
305        let format_bytes = [self.format as u8];
306        writer.write_all(&format_bytes)?;
307        // Next is the flags byte.
308        let mut flags = DEFAULT_FLAGS;
309        if self.zstd {
310            flags |= ZSTD_FLAG;
311        }
312        writer.write_all(&[flags])?;
313
314        Ok(())
315    }
316
317    /// Reads an envelope header from a reader.
318    ///
319    /// Consumes exactly 10 bytes from the reader.
320    /// See the [`crate::envelope`] module documentation for the binary format.
321    pub fn read(reader: &mut impl Read) -> Result<EnvelopeHeader, HeaderError> {
322        // The first 8 bytes are the magic number in little-endian.
323        let mut magic = [0; 8];
324        reader.read_exact(&mut magic)?;
325        if magic != MAGIC_NUMBERS {
326            return Err(HeaderErrorInner::MagicNumber {
327                expected: MAGIC_NUMBERS.try_into().unwrap(),
328                found: magic,
329            }
330            .into());
331        }
332
333        // Next is the format descriptor.
334        let mut format_bytes = [0; 1];
335        reader.read_exact(&mut format_bytes)?;
336        let format_discriminant = format_bytes[0] as usize;
337        let Some(format) = EnvelopeFormat::from_repr(format_discriminant) else {
338            return Err(HeaderErrorInner::InvalidFormatDescriptor {
339                descriptor: format_discriminant,
340            }
341            .into());
342        };
343
344        // Next is the flags byte.
345        let mut flags_bytes = [0; 1];
346        reader.read_exact(&mut flags_bytes)?;
347        let flags: u8 = flags_bytes[0];
348
349        let zstd = flags & ZSTD_FLAG != 0;
350
351        // Check if there's any unrecognized flags.
352        let other_flags = (flags ^ DEFAULT_FLAGS) & !ZSTD_FLAG;
353        if other_flags != 0 {
354            let flag_ids = (0..8).filter(|i| other_flags & (1 << i) != 0).collect_vec();
355            return Err(HeaderErrorInner::FlagUnsupported { flag_ids }.into());
356        }
357
358        Ok(Self { format, zstd })
359    }
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use cool_asserts::assert_matches;
    use rstest::rstest;

    /// Writes `header` to a buffer, checks its size, and reads it back.
    fn write_and_read(header: EnvelopeHeader) -> EnvelopeHeader {
        let mut buffer = Vec::new();
        header.write(&mut buffer).unwrap();
        assert_eq!(buffer.len(), 10);
        EnvelopeHeader::read(&mut buffer.as_slice()).unwrap()
    }

    #[rstest]
    #[case(EnvelopeFormat::Model)]
    #[case(EnvelopeFormat::ModelWithExtensions)]
    #[case(EnvelopeFormat::ModelText)]
    #[case(EnvelopeFormat::ModelTextWithExtensions)]
    #[case(EnvelopeFormat::PackageJson)]
    fn header_round_trip(#[case] format: EnvelopeFormat) {
        // Exercise the header both with and without zstd compression.
        for zstd in [true, false] {
            let header = EnvelopeHeader { format, zstd };
            let read_header = write_and_read(header);
            assert_eq!(header, read_header);
            // The configuration derived from a header must reproduce it.
            assert_eq!(read_header.config().make_header(), header);
        }
    }

    #[test]
    fn header_errors() {
        let header = EnvelopeHeader {
            format: EnvelopeFormat::Model,
            zstd: false,
        };
        let mut buffer = Vec::new();
        header.write(&mut buffer).unwrap();

        assert_eq!(buffer.len(), 10);
        let flags = buffer[9];
        assert_eq!(flags, DEFAULT_FLAGS);

        // Invalid magic
        let mut invalid_magic = buffer.clone();
        invalid_magic[7] = 0xFF;
        assert_matches!(
            EnvelopeHeader::read(&mut invalid_magic.as_slice()),
            Err(HeaderError(HeaderErrorInner::MagicNumber { .. }))
        );

        // Invalid format descriptor: 0 is not assigned to any format.
        let mut invalid_format = buffer.clone();
        invalid_format[8] = 0;
        assert_matches!(
            EnvelopeHeader::read(&mut invalid_format.as_slice()),
            Err(HeaderError(HeaderErrorInner::InvalidFormatDescriptor { descriptor }))
            => assert_eq!(descriptor, 0)
        );

        // Unrecognised flags
        let mut unrecognised_flags = buffer.clone();
        unrecognised_flags[9] |= 0b0001_0010;
        assert_matches!(
            EnvelopeHeader::read(&mut unrecognised_flags.as_slice()),
            Err(HeaderError(HeaderErrorInner::FlagUnsupported { flag_ids }))
            => assert_eq!(flag_ids, vec![1, 4])
        );

        // Truncated input: the flags byte is missing.
        assert_matches!(
            EnvelopeHeader::read(&mut &buffer[..9]),
            Err(HeaderError(HeaderErrorInner::IO { .. }))
        );
    }
}