hugr_core/envelope/
header.rs

1//! Definitions for the header of an envelope.
2
3use std::io::{Read, Write};
4use std::num::NonZeroU8;
5
6use super::EnvelopeError;
7
/// Byte sequence that every envelope starts with.
///
/// Read as ASCII it spells "`HUGRiHJv`". The "HUGR" prefix is followed by a
/// randomly generated suffix so the signature is unlikely to collide with the
/// start of other file formats.
pub const MAGIC_NUMBERS: &[u8] = b"HUGRiHJv";
13
14/// Header at the start of a binary envelope file.
15///
16/// See the [`crate::envelope`] module documentation for the binary format.
17#[derive(Clone, Copy, Eq, PartialEq, Debug, Default, derive_more::Display)]
18#[display("EnvelopeHeader({format}{})",
19    if *zstd { ", zstd compressed" } else { "" },
20)]
21pub(super) struct EnvelopeHeader {
22    /// The format used for the payload.
23    pub format: EnvelopeFormat,
24    /// Whether the payload is compressed with zstd.
25    pub zstd: bool,
26}
27
28/// Encoded format of an envelope payload.
29#[derive(
30    Clone, Copy, Eq, PartialEq, Debug, Default, Hash, derive_more::Display, strum::FromRepr,
31)]
32#[non_exhaustive]
33pub enum EnvelopeFormat {
34    /// `hugr-model` v0 binary capnproto message.
35    Model = 1,
36    /// `hugr-model` v0 binary capnproto message followed by a json-encoded
37    /// [`crate::extension::ExtensionRegistry`].
38    ///
39    /// This is a temporary format required until the model adds support for
40    /// extensions.
41    ModelWithExtensions = 2,
42    /// Human-readable S-expression encoding using [`hugr_model::v0`].
43    ///
44    /// Uses a printable ascii value as the discriminant so the envelope can be
45    /// read as text.
46    ///
47    /// :caution: This format does not yet support extension encoding, so it should
48    /// be avoided.
49    //
50    // TODO: Update comment once extension encoding is supported.
51    ModelText = 40, // '(' in ascii
52    /// Human-readable S-expression encoding using [`hugr_model::v0`].
53    ///
54    /// Uses a printable ascii value as the discriminant so the envelope can be
55    /// read as text.
56    ///
57    /// This is a temporary format required until the model adds support for
58    /// extensions.
59    ModelTextWithExtensions = 41, // ')' in ascii
60    /// Json-encoded [`crate::package::Package`]
61    ///
62    /// Uses a printable ascii value as the discriminant so the envelope can be
63    /// read as text.
64    #[default]
65    PackageJson = 63, // '?' in ascii
66}
67
68// We use a u8 to represent EnvelopeFormat in the binary format, so we should not
69// add any non-unit variants or ones with discriminants > 255.
70static_assertions::assert_eq_size!(EnvelopeFormat, u8);
71
72impl EnvelopeFormat {
73    /// If the format is a model format, returns its version number.
74    #[must_use]
75    pub fn model_version(self) -> Option<u32> {
76        match self {
77            Self::Model
78            | Self::ModelWithExtensions
79            | Self::ModelText
80            | Self::ModelTextWithExtensions => Some(0),
81            _ => None,
82        }
83    }
84
85    /// Returns whether the encoding format is ASCII-printable.
86    ///
87    /// If true, the encoded envelope can be read as text.
88    #[must_use]
89    pub fn ascii_printable(self) -> bool {
90        matches!(
91            self,
92            Self::PackageJson | Self::ModelText | Self::ModelTextWithExtensions
93        )
94    }
95}
96
97/// Configuration for encoding an envelope.
98#[derive(Clone, Copy, Debug, Eq, PartialEq)]
99#[non_exhaustive]
100pub struct EnvelopeConfig {
101    /// The format to use for the payload.
102    pub format: EnvelopeFormat,
103    /// Whether to compress the payload with zstd, and the compression level to
104    /// use.
105    pub zstd: Option<ZstdConfig>,
106}
107
108impl Default for EnvelopeConfig {
109    fn default() -> Self {
110        let format = Default::default();
111        let zstd = if cfg!(feature = "zstd") {
112            Some(ZstdConfig::default())
113        } else {
114            None
115        };
116        Self { format, zstd }
117    }
118}
119
120impl EnvelopeConfig {
121    /// Create a new envelope header with the specified configuration.
122    pub(super) fn make_header(&self) -> EnvelopeHeader {
123        EnvelopeHeader {
124            format: self.format,
125            zstd: self.zstd.is_some(),
126        }
127    }
128
129    /// Default configuration for a plain-text envelope.
130    #[must_use]
131    pub const fn text() -> Self {
132        Self {
133            format: EnvelopeFormat::PackageJson,
134            zstd: None,
135        }
136    }
137
138    /// Default configuration for a binary envelope.
139    ///
140    /// If the `zstd` feature is enabled, this will use zstd compression.
141    #[must_use]
142    pub const fn binary() -> Self {
143        Self {
144            format: EnvelopeFormat::ModelWithExtensions,
145            zstd: Some(ZstdConfig::default_level()),
146        }
147    }
148}
149
/// Tuning options for zstd compression of an envelope payload.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub struct ZstdConfig {
    /// Requested zstd compression level.
    ///
    /// zstd currently accepts levels 1 through 22: 1 is the fastest and 22
    /// compresses best. Levels above 20 need additional memory, so use them
    /// with caution.
    ///
    /// `None` means zstd's default level is used.
    pub level: Option<NonZeroU8>,
}
163
164impl ZstdConfig {
165    /// Create a new zstd configuration with default compression level.
166    #[must_use]
167    pub const fn default_level() -> Self {
168        Self { level: None }
169    }
170
171    /// Returns the zstd compression level to pass to the zstd library.
172    ///
173    /// Uses [`zstd::DEFAULT_COMPRESSION_LEVEL`] if the level is not set.
174    #[must_use]
175    pub fn level(&self) -> i32 {
176        #[allow(unused_assignments, unused_mut)]
177        let mut default = 0;
178        #[cfg(feature = "zstd")]
179        {
180            default = zstd::DEFAULT_COMPRESSION_LEVEL;
181        }
182        self.level.map_or(default, |l| i32::from(l.get()))
183    }
184}
185
186impl EnvelopeHeader {
187    /// Returns the envelope configuration corresponding to this header.
188    ///
189    /// Note that zstd compression level is not stored in the header.
190    pub fn config(&self) -> EnvelopeConfig {
191        EnvelopeConfig {
192            format: self.format,
193            zstd: if self.zstd {
194                Some(ZstdConfig { level: None })
195            } else {
196                None
197            },
198        }
199    }
200
201    /// Write an envelope header to a writer.
202    ///
203    /// See the [`crate::envelope`] module documentation for the binary format.
204    pub fn write(&self, writer: &mut impl Write) -> Result<(), EnvelopeError> {
205        // The first 8 bytes are the magic number in little-endian.
206        writer.write_all(MAGIC_NUMBERS)?;
207        // Next is the format descriptor.
208        let format_bytes = [self.format as u8];
209        writer.write_all(&format_bytes)?;
210        // Next is the flags byte.
211        let mut flags = 0b01000000u8;
212        flags |= u8::from(self.zstd);
213        writer.write_all(&[flags])?;
214
215        Ok(())
216    }
217
218    /// Reads an envelope header from a reader.
219    ///
220    /// Consumes exactly 10 bytes from the reader.
221    /// See the [`crate::envelope`] module documentation for the binary format.
222    pub fn read(reader: &mut impl Read) -> Result<EnvelopeHeader, EnvelopeError> {
223        // The first 8 bytes are the magic number in little-endian.
224        let mut magic = [0; 8];
225        reader.read_exact(&mut magic)?;
226        if magic != MAGIC_NUMBERS {
227            return Err(EnvelopeError::MagicNumber {
228                expected: MAGIC_NUMBERS.try_into().unwrap(),
229                found: magic,
230            });
231        }
232
233        // Next is the format descriptor.
234        let mut format_bytes = [0; 1];
235        reader.read_exact(&mut format_bytes)?;
236        let format_discriminant = format_bytes[0] as usize;
237        let Some(format) = EnvelopeFormat::from_repr(format_discriminant) else {
238            return Err(EnvelopeError::InvalidFormatDescriptor {
239                descriptor: format_discriminant,
240            });
241        };
242
243        // Next is the flags byte.
244        let mut flags_bytes = [0; 1];
245        reader.read_exact(&mut flags_bytes)?;
246        let zstd = flags_bytes[0] & 0x1 != 0;
247
248        Ok(Self { format, zstd })
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Encodes `header` into a fresh buffer.
    fn encode(header: EnvelopeHeader) -> Vec<u8> {
        let mut buffer = Vec::new();
        header.write(&mut buffer).unwrap();
        buffer
    }

    /// Every format survives a write/read cycle, with and without zstd.
    /// Reading must consume exactly the bytes that were written.
    #[rstest]
    #[case(EnvelopeFormat::Model)]
    #[case(EnvelopeFormat::ModelWithExtensions)]
    #[case(EnvelopeFormat::ModelText)]
    #[case(EnvelopeFormat::ModelTextWithExtensions)]
    #[case(EnvelopeFormat::PackageJson)]
    fn header_round_trip(#[case] format: EnvelopeFormat) {
        for zstd in [true, false] {
            let header = EnvelopeHeader { format, zstd };
            let buffer = encode(header);
            assert_eq!(buffer.len(), 10);

            let mut remaining = buffer.as_slice();
            let read_header = EnvelopeHeader::read(&mut remaining).unwrap();
            assert_eq!(header, read_header);
            assert!(remaining.is_empty());

            // The header <-> config conversion must round-trip as well.
            assert_eq!(read_header.config().make_header(), header);
        }
    }

    /// Pins the exact on-disk layout of a text-format header.
    #[test]
    fn header_bytes_are_stable() {
        let plain = EnvelopeHeader {
            format: EnvelopeFormat::PackageJson,
            zstd: false,
        };
        assert_eq!(encode(plain), b"HUGRiHJv?@");

        let compressed = EnvelopeHeader {
            format: EnvelopeFormat::PackageJson,
            zstd: true,
        };
        assert_eq!(encode(compressed), b"HUGRiHJv?A");
    }

    #[test]
    fn read_rejects_wrong_magic_number() {
        let result = EnvelopeHeader::read(&mut &b"NOTAHUGR?@"[..]);
        assert!(matches!(result, Err(EnvelopeError::MagicNumber { .. })));
    }

    #[test]
    fn read_rejects_unknown_format_descriptor() {
        let result = EnvelopeHeader::read(&mut &b"HUGRiHJv\xff@"[..]);
        assert!(matches!(
            result,
            Err(EnvelopeError::InvalidFormatDescriptor { descriptor: 255 })
        ));
    }

    #[test]
    fn read_fails_on_truncated_header() {
        assert!(EnvelopeHeader::read(&mut &b"HUGRiHJv?"[..]).is_err());
    }
}