hugr_core/envelope/
header.rs

1//! Definitions for the header of an envelope.
2
3use std::io::{Read, Write};
4use std::num::NonZeroU8;
5
6use itertools::Itertools;
7
8use super::EnvelopeError;
9
/// Magic number identifying the start of an envelope.
///
/// These are the eight ASCII bytes of "`HUGRiHJv`". The second half of the
/// string was randomly generated to avoid accidental collisions with other
/// file formats.
pub const MAGIC_NUMBERS: &[u8] = b"HUGRiHJv";
15
/// The header flags byte with every optional flag unset.
///
/// Bit 6 (counting from zero at the least-significant bit, i.e. `0x40`) is
/// always set so that the flags byte is a printable ASCII character: `@` when
/// no flag is set, `A` when [`ZSTD_FLAG`] is set.
const DEFAULT_FLAGS: u8 = 0b0100_0000;
/// The ZSTD flag bit (bit 0) in the header's flags.
const ZSTD_FLAG: u8 = 0b0000_0001;
21
/// Header at the start of a binary envelope file.
///
/// Records the payload format and whether the payload is zstd-compressed.
/// The `Display` implementation renders as `EnvelopeHeader(<format>)`, with
/// `, zstd compressed` appended inside the parentheses when `zstd` is set.
///
/// See the [`crate::envelope`] module documentation for the binary format.
#[derive(Clone, Copy, Eq, PartialEq, Debug, Default, derive_more::Display)]
#[display("EnvelopeHeader({format}{})",
    if *zstd { ", zstd compressed" } else { "" },
)]
pub(super) struct EnvelopeHeader {
    /// The format used for the payload.
    pub format: EnvelopeFormat,
    /// Whether the payload is compressed with zstd.
    ///
    /// Only the presence of compression is recorded; the compression level is
    /// not part of the header.
    pub zstd: bool,
}
35
/// Encoded format of an envelope payload.
///
/// The discriminant of each variant is the value written as the one-byte
/// format descriptor of the envelope header.
#[derive(
    Clone, Copy, Eq, PartialEq, Debug, Default, Hash, derive_more::Display, strum::FromRepr,
)]
#[non_exhaustive]
pub enum EnvelopeFormat {
    /// `hugr-model` v0 binary capnproto message.
    Model = 1,
    /// `hugr-model` v0 binary capnproto message followed by a json-encoded
    /// [`crate::extension::ExtensionRegistry`].
    ///
    /// This is a temporary format required until the model adds support for
    /// extensions.
    ModelWithExtensions = 2,
    /// Human-readable S-expression encoding using [`hugr_model::v0`].
    ///
    /// Uses a printable ascii value as the discriminant so the envelope can be
    /// read as text.
    ///
    /// **Caution:** This format does not yet support extension encoding, so it
    /// should be avoided.
    //
    // TODO: Update comment once extension encoding is supported.
    ModelText = 40, // '(' in ascii
    /// Human-readable S-expression encoding using [`hugr_model::v0`].
    ///
    /// Uses a printable ascii value as the discriminant so the envelope can be
    /// read as text.
    ///
    /// This is a temporary format required until the model adds support for
    /// extensions.
    //
    // NOTE(review): by analogy with `ModelWithExtensions`, the text payload is
    // presumably followed by the encoded extensions -- confirm and document.
    ModelTextWithExtensions = 41, // ')' in ascii
    /// Json-encoded [`crate::package::Package`]
    ///
    /// Uses a printable ascii value as the discriminant so the envelope can be
    /// read as text.
    ///
    /// This is the default format.
    #[default]
    PackageJson = 63, // '?' in ascii
}

// The header stores the EnvelopeFormat discriminant in a single byte, so the
// enum must stay the size of a u8: do not add any non-unit variants or ones
// with discriminants > 255.
static_assertions::assert_eq_size!(EnvelopeFormat, u8);
79
80impl EnvelopeFormat {
81    /// If the format is a model format, returns its version number.
82    #[must_use]
83    pub fn model_version(self) -> Option<u32> {
84        match self {
85            Self::Model
86            | Self::ModelWithExtensions
87            | Self::ModelText
88            | Self::ModelTextWithExtensions => Some(0),
89            _ => None,
90        }
91    }
92
93    /// Returns whether the encoding format is ASCII-printable.
94    ///
95    /// If true, the encoded envelope can be read as text.
96    #[must_use]
97    pub fn ascii_printable(self) -> bool {
98        matches!(
99            self,
100            Self::PackageJson | Self::ModelText | Self::ModelTextWithExtensions
101        )
102    }
103}
104
/// Configuration for encoding an envelope.
///
/// The default configuration uses the default [`EnvelopeFormat`] and no
/// compression.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub struct EnvelopeConfig {
    /// The format to use for the payload.
    pub format: EnvelopeFormat,
    /// Whether to compress the payload with zstd, and the compression level to
    /// use.
    ///
    /// `None` means the payload is not compressed.
    pub zstd: Option<ZstdConfig>,
}
115
116impl EnvelopeConfig {
117    /// Create a new envelope configuration with the specified format.
118    /// `zstd` compression is disabled by default.
119    pub fn new(format: EnvelopeFormat) -> Self {
120        Self {
121            format,
122            ..Default::default()
123        }
124    }
125
126    /// Set the zstd compression configuration for the envelope.
127    pub fn with_zstd(self, zstd: ZstdConfig) -> Self {
128        Self {
129            zstd: Some(zstd),
130            ..self
131        }
132    }
133
134    /// Disable zstd compression in the envelope configuration.
135    pub fn disable_compression(self) -> Self {
136        Self { zstd: None, ..self }
137    }
138
139    /// Create a new envelope header with the specified configuration.
140    pub(super) fn make_header(&self) -> EnvelopeHeader {
141        EnvelopeHeader {
142            format: self.format,
143            zstd: self.zstd.is_some(),
144        }
145    }
146
147    /// Default configuration for a plain-text envelope.
148    #[must_use]
149    pub const fn text() -> Self {
150        Self {
151            format: EnvelopeFormat::PackageJson,
152            zstd: None,
153        }
154    }
155
156    /// Default configuration for a binary envelope.
157    ///
158    /// If the `zstd` feature is enabled, this will use zstd compression.
159    #[must_use]
160    pub const fn binary() -> Self {
161        Self {
162            format: EnvelopeFormat::ModelWithExtensions,
163            zstd: Some(ZstdConfig::default_level()),
164        }
165    }
166}
167
/// Configuration for zstd compression.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub struct ZstdConfig {
    /// The compression level to use.
    ///
    /// The current range is 1-22, where 1 is fastest and 22 is best
    /// compression. Values above 20 should be used with caution, as they
    /// require additional memory.
    ///
    /// If `None`, zstd's default level is used.
    pub level: Option<NonZeroU8>,
}

impl ZstdConfig {
    /// Create a new zstd configuration with the specified compression level.
    ///
    /// A `level` of `0` is stored as `None`, which selects the default
    /// compression level (same as [`ZstdConfig::default_level`]).
    #[must_use]
    pub const fn new(level: u8) -> Self {
        Self {
            level: NonZeroU8::new(level),
        }
    }

    /// Create a new zstd configuration with default compression level.
    #[must_use]
    pub const fn default_level() -> Self {
        Self { level: None }
    }

    /// Returns the zstd compression level to pass to the zstd library.
    ///
    /// If no level is set, this is `zstd::DEFAULT_COMPRESSION_LEVEL` when the
    /// `zstd` feature is enabled, and `0` otherwise.
    #[must_use]
    pub fn level(&self) -> i32 {
        // Exactly one of these two bindings is compiled in, depending on
        // whether the `zstd` crate is available.
        #[cfg(feature = "zstd")]
        let fallback = zstd::DEFAULT_COMPRESSION_LEVEL;
        #[cfg(not(feature = "zstd"))]
        let fallback = 0;

        self.level.map_or(fallback, |l| i32::from(l.get()))
    }
}
209
210impl EnvelopeHeader {
211    /// Returns the envelope configuration corresponding to this header.
212    ///
213    /// Note that zstd compression level is not stored in the header.
214    pub fn config(&self) -> EnvelopeConfig {
215        EnvelopeConfig {
216            format: self.format,
217            zstd: if self.zstd {
218                Some(ZstdConfig { level: None })
219            } else {
220                None
221            },
222        }
223    }
224
225    /// Write an envelope header to a writer.
226    ///
227    /// See the [`crate::envelope`] module documentation for the binary format.
228    pub fn write(&self, writer: &mut impl Write) -> Result<(), EnvelopeError> {
229        // The first 8 bytes are the magic number in little-endian.
230        writer.write_all(MAGIC_NUMBERS)?;
231        // Next is the format descriptor.
232        let format_bytes = [self.format as u8];
233        writer.write_all(&format_bytes)?;
234        // Next is the flags byte.
235        let mut flags = DEFAULT_FLAGS;
236        if self.zstd {
237            flags |= ZSTD_FLAG;
238        }
239        writer.write_all(&[flags])?;
240
241        Ok(())
242    }
243
244    /// Reads an envelope header from a reader.
245    ///
246    /// Consumes exactly 10 bytes from the reader.
247    /// See the [`crate::envelope`] module documentation for the binary format.
248    pub fn read(reader: &mut impl Read) -> Result<EnvelopeHeader, EnvelopeError> {
249        // The first 8 bytes are the magic number in little-endian.
250        let mut magic = [0; 8];
251        reader.read_exact(&mut magic)?;
252        if magic != MAGIC_NUMBERS {
253            return Err(EnvelopeError::MagicNumber {
254                expected: MAGIC_NUMBERS.try_into().unwrap(),
255                found: magic,
256            });
257        }
258
259        // Next is the format descriptor.
260        let mut format_bytes = [0; 1];
261        reader.read_exact(&mut format_bytes)?;
262        let format_discriminant = format_bytes[0] as usize;
263        let Some(format) = EnvelopeFormat::from_repr(format_discriminant) else {
264            return Err(EnvelopeError::InvalidFormatDescriptor {
265                descriptor: format_discriminant,
266            });
267        };
268
269        // Next is the flags byte.
270        let mut flags_bytes = [0; 1];
271        reader.read_exact(&mut flags_bytes)?;
272        let flags: u8 = flags_bytes[0];
273
274        let zstd = flags & ZSTD_FLAG != 0;
275
276        // Check if there's any unrecognized flags.
277        let other_flags = (flags ^ DEFAULT_FLAGS) & !ZSTD_FLAG;
278        if other_flags != 0 {
279            let flag_ids = (0..8).filter(|i| other_flags & (1 << i) != 0).collect_vec();
280            return Err(EnvelopeError::FlagUnsupported { flag_ids });
281        }
282
283        Ok(Self { format, zstd })
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use cool_asserts::assert_matches;
    use rstest::rstest;

    /// Writing a header and reading it back yields the same header, for every
    /// format, both with and without zstd compression.
    #[rstest]
    #[case(EnvelopeFormat::Model)]
    #[case(EnvelopeFormat::ModelWithExtensions)]
    #[case(EnvelopeFormat::ModelText)]
    #[case(EnvelopeFormat::ModelTextWithExtensions)]
    #[case(EnvelopeFormat::PackageJson)]
    fn header_round_trip(#[case] format: EnvelopeFormat) {
        for zstd in [true, false] {
            let header = EnvelopeHeader { format, zstd };

            let mut buffer = Vec::new();
            header.write(&mut buffer).unwrap();

            // Layout: 8 magic bytes, 1 format descriptor byte, 1 flags byte.
            assert_eq!(buffer.len(), 10);
            assert_eq!(&buffer[..8], MAGIC_NUMBERS);
            assert_eq!(buffer[8], format as u8);
            assert_eq!(buffer[9] & ZSTD_FLAG != 0, zstd);

            let read_header = EnvelopeHeader::read(&mut buffer.as_slice()).unwrap();
            assert_eq!(header, read_header);
        }
    }

    /// Each kind of malformed header is rejected with the matching error.
    #[test]
    fn header_errors() {
        let header = EnvelopeHeader {
            format: EnvelopeFormat::Model,
            zstd: false,
        };
        let mut buffer = Vec::new();
        header.write(&mut buffer).unwrap();

        assert_eq!(buffer.len(), 10);
        let flags = buffer[9];
        assert_eq!(flags, DEFAULT_FLAGS);

        // Invalid magic
        let mut invalid_magic = buffer.clone();
        invalid_magic[7] = 0xFF;
        assert_matches!(
            EnvelopeHeader::read(&mut invalid_magic.as_slice()),
            Err(EnvelopeError::MagicNumber { .. })
        );

        // Invalid format descriptor: no format uses discriminant 0.
        let mut invalid_format = buffer.clone();
        invalid_format[8] = 0;
        assert_matches!(
            EnvelopeHeader::read(&mut invalid_format.as_slice()),
            Err(EnvelopeError::InvalidFormatDescriptor { descriptor: 0 })
        );

        // Unrecognised flags
        let mut unrecognised_flags = buffer.clone();
        unrecognised_flags[9] |= 0b0001_0010;
        assert_matches!(
            EnvelopeHeader::read(&mut unrecognised_flags.as_slice()),
            Err(EnvelopeError::FlagUnsupported { flag_ids })
            => assert_eq!(flag_ids, vec![1, 4])
        );

        // Clearing the always-set marker bit is reported as flag 6.
        let mut missing_marker = buffer.clone();
        missing_marker[9] = 0;
        assert_matches!(
            EnvelopeHeader::read(&mut missing_marker.as_slice()),
            Err(EnvelopeError::FlagUnsupported { flag_ids })
            => assert_eq!(flag_ids, vec![6])
        );

        // Truncated input: the flags byte is missing.
        assert!(EnvelopeHeader::read(&mut &buffer[..9]).is_err());
    }
}