hugr_core/envelope/
header.rs

1//! Definitions for the header of an envelope.
2
3use std::io::{Read, Write};
4use std::num::NonZeroU8;
5
6use super::EnvelopeError;
7
/// Magic bytes that open every envelope.
///
/// Read as ascii these spell "HUGRiHJv". The second half ("iHJv") is a
/// randomly generated string, chosen so that other file formats are unlikely
/// to start with the same bytes by accident.
pub const MAGIC_NUMBERS: &[u8] = b"HUGRiHJv";
13
14/// Header at the start of a binary envelope file.
15///
16/// See the [crate::envelope] module documentation for the binary format.
17#[derive(Clone, Copy, Eq, PartialEq, Debug, Default, derive_more::Display)]
18#[display("EnvelopeHeader({format}{})",
19    if *zstd { ", zstd compressed" } else { "" },
20)]
21pub(super) struct EnvelopeHeader {
22    /// The format used for the payload.
23    pub format: EnvelopeFormat,
24    /// Whether the payload is compressed with zstd.
25    pub zstd: bool,
26}
27
28/// Encoded format of an envelope payload.
29#[derive(
30    Clone, Copy, Eq, PartialEq, Debug, Default, Hash, derive_more::Display, strum::FromRepr,
31)]
32#[non_exhaustive]
33pub enum EnvelopeFormat {
34    /// `hugr-model` v0 binary capnproto message.
35    Model = 1,
36    /// `hugr-model` v0 binary capnproto message followed by a json-encoded [crate::extension::ExtensionRegistry].
37    //
38    // This is a temporary format required until the model adds support for extensions.
39    ModelWithExtensions = 2,
40    /// Json-encoded [crate::package::Package]
41    ///
42    /// Uses a printable ascii value as the discriminant so the envelope can be
43    /// read as text.
44    #[default]
45    PackageJson = 63, // '?' in ascii
46}
47
48// We use a u8 to represent EnvelopeFormat in the binary format, so we should not
49// add any non-unit variants or ones with discriminants > 255.
50static_assertions::assert_eq_size!(EnvelopeFormat, u8);
51
52impl EnvelopeFormat {
53    /// Returns whether to encode the extensions as json after the hugr payload.
54    pub fn append_extensions(self) -> bool {
55        matches!(self, Self::ModelWithExtensions)
56    }
57
58    /// If the format is a model format, returns its version number.
59    pub fn model_version(self) -> Option<u32> {
60        match self {
61            Self::Model | Self::ModelWithExtensions => Some(0),
62            _ => None,
63        }
64    }
65
66    /// Returns whether the encoding format is ASCII-printable.
67    ///
68    /// If true, the encoded envelope can be read as text.
69    pub fn ascii_printable(self) -> bool {
70        matches!(self, Self::PackageJson)
71    }
72}
73
74/// Configuration for encoding an envelope.
75#[derive(Clone, Copy, Debug, Eq, PartialEq)]
76#[non_exhaustive]
77pub struct EnvelopeConfig {
78    /// The format to use for the payload.
79    pub format: EnvelopeFormat,
80    /// Whether to compress the payload with zstd, and the compression level to
81    /// use.
82    pub zstd: Option<ZstdConfig>,
83}
84
85impl Default for EnvelopeConfig {
86    fn default() -> Self {
87        let format = Default::default();
88        let zstd = if cfg!(feature = "zstd") {
89            Some(ZstdConfig::default())
90        } else {
91            None
92        };
93        Self { format, zstd }
94    }
95}
96
97impl EnvelopeConfig {
98    /// Create a new envelope header with the specified configuration.
99    pub(super) fn make_header(&self) -> EnvelopeHeader {
100        EnvelopeHeader {
101            format: self.format,
102            zstd: self.zstd.is_some(),
103        }
104    }
105
106    /// Default configuration for a plain-text envelope.
107    pub const fn text() -> Self {
108        Self {
109            format: EnvelopeFormat::PackageJson,
110            zstd: None,
111        }
112    }
113
114    /// Default configuration for a binary envelope.
115    ///
116    /// If the `zstd` feature is enabled, this will use zstd compression.
117    pub const fn binary() -> Self {
118        Self {
119            format: EnvelopeFormat::Model,
120            zstd: Some(ZstdConfig { level: None }),
121        }
122    }
123}
124
/// Configuration for zstd compression.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub struct ZstdConfig {
    /// The compression level to use.
    ///
    /// The current range is 1-22, where 1 is fastest and 22 is best
    /// compression. Values above 20 should be used with caution, as they
    /// require additional memory.
    ///
    /// If `None`, zstd's default level is used.
    pub level: Option<NonZeroU8>,
}

impl ZstdConfig {
    /// Returns the zstd compression level to pass to the zstd library.
    ///
    /// An explicitly configured level is returned unchanged. If no level is
    /// set, this returns `zstd::DEFAULT_COMPRESSION_LEVEL` when the `zstd`
    /// feature is enabled, and `0` otherwise.
    /// NOTE: zstd's C API documents level `0` as "use the default level".
    pub fn level(&self) -> i32 {
        // The fallback is selected at compile time, because the `zstd` crate
        // is only available with the feature. Two cfg-gated bindings avoid a
        // mutable variable and the lint suppressions it required.
        #[cfg(feature = "zstd")]
        let default_level = zstd::DEFAULT_COMPRESSION_LEVEL;
        #[cfg(not(feature = "zstd"))]
        let default_level = 0;

        // `u8 -> i32` is lossless, so `From` is used instead of an `as` cast.
        self.level
            .map_or(default_level, |level| i32::from(level.get()))
    }
}
153
154impl EnvelopeHeader {
155    /// Returns the envelope configuration corresponding to this header.
156    ///
157    /// Note that zstd compression level is not stored in the header.
158    pub fn config(&self) -> EnvelopeConfig {
159        EnvelopeConfig {
160            format: self.format,
161            zstd: match self.zstd {
162                true => Some(ZstdConfig { level: None }),
163                false => None,
164            },
165        }
166    }
167
168    /// Write an envelope header to a writer.
169    ///
170    /// See the [crate::envelope] module documentation for the binary format.
171    pub fn write(&self, writer: &mut impl Write) -> Result<(), EnvelopeError> {
172        // The first 8 bytes are the magic number in little-endian.
173        writer.write_all(MAGIC_NUMBERS)?;
174        // Next is the format descriptor.
175        let format_bytes = [self.format as u8];
176        writer.write_all(&format_bytes)?;
177        // Next is the flags byte.
178        let mut flags = 0b01000000u8;
179        flags |= self.zstd as u8;
180        writer.write_all(&[flags])?;
181
182        Ok(())
183    }
184
185    /// Reads an envelope header from a reader.
186    ///
187    /// Consumes exactly 10 bytes from the reader.
188    /// See the [crate::envelope] module documentation for the binary format.
189    pub fn read(reader: &mut impl Read) -> Result<EnvelopeHeader, EnvelopeError> {
190        // The first 8 bytes are the magic number in little-endian.
191        let mut magic = [0; 8];
192        reader.read_exact(&mut magic)?;
193        if magic != MAGIC_NUMBERS {
194            return Err(EnvelopeError::MagicNumber {
195                expected: MAGIC_NUMBERS.try_into().unwrap(),
196                found: magic,
197            });
198        }
199
200        // Next is the format descriptor.
201        let mut format_bytes = [0; 1];
202        reader.read_exact(&mut format_bytes)?;
203        let format_discriminant = format_bytes[0] as usize;
204        let Some(format) = EnvelopeFormat::from_repr(format_discriminant) else {
205            return Err(EnvelopeError::InvalidFormatDescriptor {
206                descriptor: format_discriminant,
207            });
208        };
209
210        // Next is the flags byte.
211        let mut flags_bytes = [0; 1];
212        reader.read_exact(&mut flags_bytes)?;
213        let zstd = flags_bytes[0] & 0x1 != 0;
214
215        Ok(Self { format, zstd })
216    }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// All currently defined formats, for exhaustive loops.
    const ALL_FORMATS: [EnvelopeFormat; 3] = [
        EnvelopeFormat::Model,
        EnvelopeFormat::ModelWithExtensions,
        EnvelopeFormat::PackageJson,
    ];

    /// Serializes `header` into a fresh buffer.
    fn encode(header: EnvelopeHeader) -> Vec<u8> {
        let mut buffer = Vec::new();
        header.write(&mut buffer).unwrap();
        buffer
    }

    #[rstest]
    #[case(EnvelopeFormat::Model)]
    #[case(EnvelopeFormat::ModelWithExtensions)]
    #[case(EnvelopeFormat::PackageJson)]
    fn header_round_trip(#[case] format: EnvelopeFormat) {
        // Check both with and without zstd compression.
        for zstd in [true, false] {
            let header = EnvelopeHeader { format, zstd };
            let buffer = encode(header);

            // The header always occupies exactly 10 bytes.
            assert_eq!(buffer.len(), 10);

            let read_header = EnvelopeHeader::read(&mut buffer.as_slice()).unwrap();
            assert_eq!(header, read_header);

            // Text formats must produce a header that is itself readable as text.
            if format.ascii_printable() {
                assert!(buffer.iter().all(u8::is_ascii_graphic));
            }
        }
    }

    #[test]
    fn header_config_round_trip() {
        for format in ALL_FORMATS {
            for zstd in [true, false] {
                let header = EnvelopeHeader { format, zstd };
                assert_eq!(header.config().make_header(), header);
            }
        }
    }

    #[test]
    fn read_rejects_bad_magic() {
        let bytes = b"NOTAHUGR?@";
        let err = EnvelopeHeader::read(&mut bytes.as_slice()).unwrap_err();
        assert!(matches!(
            err,
            EnvelopeError::MagicNumber { found, .. } if found == *b"NOTAHUGR"
        ));
    }

    #[test]
    fn read_rejects_unknown_format() {
        let mut bytes = MAGIC_NUMBERS.to_vec();
        bytes.extend_from_slice(&[0, 0b0100_0000]);
        let err = EnvelopeHeader::read(&mut bytes.as_slice()).unwrap_err();
        assert!(matches!(
            err,
            EnvelopeError::InvalidFormatDescriptor { descriptor: 0 }
        ));
    }

    #[test]
    fn read_fails_on_truncated_input() {
        assert!(EnvelopeHeader::read(&mut &MAGIC_NUMBERS[..4]).is_err());
    }
}