hugr_core/envelope/
header.rs

1//! Definitions for the header of an envelope.
2
3use std::io::{Read, Write};
4use std::num::NonZeroU8;
5
6use super::EnvelopeError;
7
/// Byte sequence that marks the beginning of every envelope.
///
/// Read as ascii it spells "`HUGRiHJv`". The trailing four characters were
/// generated at random so that other file formats are unlikely to start with
/// the same bytes by accident.
pub const MAGIC_NUMBERS: &[u8] = b"HUGRiHJv";
13
14/// Header at the start of a binary envelope file.
15///
16/// See the [`crate::envelope`] module documentation for the binary format.
17#[derive(Clone, Copy, Eq, PartialEq, Debug, Default, derive_more::Display)]
18#[display("EnvelopeHeader({format}{})",
19    if *zstd { ", zstd compressed" } else { "" },
20)]
21pub(super) struct EnvelopeHeader {
22    /// The format used for the payload.
23    pub format: EnvelopeFormat,
24    /// Whether the payload is compressed with zstd.
25    pub zstd: bool,
26}
27
28/// Encoded format of an envelope payload.
29#[derive(
30    Clone, Copy, Eq, PartialEq, Debug, Default, Hash, derive_more::Display, strum::FromRepr,
31)]
32#[non_exhaustive]
33pub enum EnvelopeFormat {
34    /// `hugr-model` v0 binary capnproto message.
35    Model = 1,
36    /// `hugr-model` v0 binary capnproto message followed by a json-encoded
37    /// [`crate::extension::ExtensionRegistry`].
38    ///
39    /// This is a temporary format required until the model adds support for
40    /// extensions.
41    ModelWithExtensions = 2,
42    /// Human-readable S-expression encoding using [`hugr_model::v0`].
43    ///
44    /// Uses a printable ascii value as the discriminant so the envelope can be
45    /// read as text.
46    ///
47    /// :caution: This format does not yet support extension encoding, so it should
48    /// be avoided.
49    //
50    // TODO: Update comment once extension encoding is supported.
51    ModelText = 40, // '(' in ascii
52    /// Human-readable S-expression encoding using [`hugr_model::v0`].
53    ///
54    /// Uses a printable ascii value as the discriminant so the envelope can be
55    /// read as text.
56    ///
57    /// This is a temporary format required until the model adds support for
58    /// extensions.
59    ModelTextWithExtensions = 41, // ')' in ascii
60    /// Json-encoded [`crate::package::Package`]
61    ///
62    /// Uses a printable ascii value as the discriminant so the envelope can be
63    /// read as text.
64    #[default]
65    PackageJson = 63, // '?' in ascii
66}
67
68// We use a u8 to represent EnvelopeFormat in the binary format, so we should not
69// add any non-unit variants or ones with discriminants > 255.
70static_assertions::assert_eq_size!(EnvelopeFormat, u8);
71
72impl EnvelopeFormat {
73    /// If the format is a model format, returns its version number.
74    #[must_use]
75    pub fn model_version(self) -> Option<u32> {
76        match self {
77            Self::Model
78            | Self::ModelWithExtensions
79            | Self::ModelText
80            | Self::ModelTextWithExtensions => Some(0),
81            _ => None,
82        }
83    }
84
85    /// Returns whether the encoding format is ASCII-printable.
86    ///
87    /// If true, the encoded envelope can be read as text.
88    #[must_use]
89    pub fn ascii_printable(self) -> bool {
90        matches!(
91            self,
92            Self::PackageJson | Self::ModelText | Self::ModelTextWithExtensions
93        )
94    }
95}
96
97/// Configuration for encoding an envelope.
98#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
99#[non_exhaustive]
100pub struct EnvelopeConfig {
101    /// The format to use for the payload.
102    pub format: EnvelopeFormat,
103    /// Whether to compress the payload with zstd, and the compression level to
104    /// use.
105    pub zstd: Option<ZstdConfig>,
106}
107
108impl EnvelopeConfig {
109    /// Create a new envelope configuration with the specified format.
110    /// `zstd` compression is disabled by default.
111    pub fn new(format: EnvelopeFormat) -> Self {
112        Self {
113            format,
114            ..Default::default()
115        }
116    }
117
118    /// Set the zstd compression configuration for the envelope.
119    pub fn with_zstd(self, zstd: ZstdConfig) -> Self {
120        Self {
121            zstd: Some(zstd),
122            ..self
123        }
124    }
125
126    /// Disable zstd compression in the envelope configuration.
127    pub fn disable_compression(self) -> Self {
128        Self { zstd: None, ..self }
129    }
130
131    /// Create a new envelope header with the specified configuration.
132    pub(super) fn make_header(&self) -> EnvelopeHeader {
133        EnvelopeHeader {
134            format: self.format,
135            zstd: self.zstd.is_some(),
136        }
137    }
138
139    /// Default configuration for a plain-text envelope.
140    #[must_use]
141    pub const fn text() -> Self {
142        Self {
143            format: EnvelopeFormat::PackageJson,
144            zstd: None,
145        }
146    }
147
148    /// Default configuration for a binary envelope.
149    ///
150    /// If the `zstd` feature is enabled, this will use zstd compression.
151    #[must_use]
152    pub const fn binary() -> Self {
153        Self {
154            format: EnvelopeFormat::ModelWithExtensions,
155            zstd: Some(ZstdConfig::default_level()),
156        }
157    }
158}
159
/// Configuration for zstd compression.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub struct ZstdConfig {
    /// The compression level to use.
    ///
    /// The current range is 1-22, where 1 is fastest and 22 is best
    /// compression. Values above 20 should be used with caution, as they
    /// require additional memory.
    ///
    /// If `None`, zstd's default level is used.
    pub level: Option<NonZeroU8>,
}

impl ZstdConfig {
    /// Create a new zstd configuration with the specified compression level.
    ///
    /// A `level` of `0` cannot be stored in a [`NonZeroU8`] and therefore
    /// selects the default level, exactly like [`ZstdConfig::default_level`].
    /// The value is not range-checked here.
    #[must_use]
    pub const fn new(level: u8) -> Self {
        Self {
            level: NonZeroU8::new(level),
        }
    }

    /// Create a new zstd configuration with default compression level.
    #[must_use]
    pub const fn default_level() -> Self {
        Self { level: None }
    }

    /// Returns the zstd compression level to pass to the zstd library.
    ///
    /// Uses [`zstd::DEFAULT_COMPRESSION_LEVEL`] if the level is not set. When
    /// the crate is built without the `zstd` feature that constant is not
    /// available and `0` is returned instead.
    #[must_use]
    pub fn level(&self) -> i32 {
        // Pick the fallback at compile time so no mutable placeholder (and no
        // lint suppression) is needed for the feature-less build.
        #[cfg(feature = "zstd")]
        const FALLBACK: i32 = zstd::DEFAULT_COMPRESSION_LEVEL;
        #[cfg(not(feature = "zstd"))]
        const FALLBACK: i32 = 0;

        self.level.map_or(FALLBACK, |l| i32::from(l.get()))
    }
}
201
202impl EnvelopeHeader {
203    /// Returns the envelope configuration corresponding to this header.
204    ///
205    /// Note that zstd compression level is not stored in the header.
206    pub fn config(&self) -> EnvelopeConfig {
207        EnvelopeConfig {
208            format: self.format,
209            zstd: if self.zstd {
210                Some(ZstdConfig { level: None })
211            } else {
212                None
213            },
214        }
215    }
216
217    /// Write an envelope header to a writer.
218    ///
219    /// See the [`crate::envelope`] module documentation for the binary format.
220    pub fn write(&self, writer: &mut impl Write) -> Result<(), EnvelopeError> {
221        // The first 8 bytes are the magic number in little-endian.
222        writer.write_all(MAGIC_NUMBERS)?;
223        // Next is the format descriptor.
224        let format_bytes = [self.format as u8];
225        writer.write_all(&format_bytes)?;
226        // Next is the flags byte.
227        let mut flags = 0b01000000u8;
228        flags |= u8::from(self.zstd);
229        writer.write_all(&[flags])?;
230
231        Ok(())
232    }
233
234    /// Reads an envelope header from a reader.
235    ///
236    /// Consumes exactly 10 bytes from the reader.
237    /// See the [`crate::envelope`] module documentation for the binary format.
238    pub fn read(reader: &mut impl Read) -> Result<EnvelopeHeader, EnvelopeError> {
239        // The first 8 bytes are the magic number in little-endian.
240        let mut magic = [0; 8];
241        reader.read_exact(&mut magic)?;
242        if magic != MAGIC_NUMBERS {
243            return Err(EnvelopeError::MagicNumber {
244                expected: MAGIC_NUMBERS.try_into().unwrap(),
245                found: magic,
246            });
247        }
248
249        // Next is the format descriptor.
250        let mut format_bytes = [0; 1];
251        reader.read_exact(&mut format_bytes)?;
252        let format_discriminant = format_bytes[0] as usize;
253        let Some(format) = EnvelopeFormat::from_repr(format_discriminant) else {
254            return Err(EnvelopeError::InvalidFormatDescriptor {
255                descriptor: format_discriminant,
256            });
257        };
258
259        // Next is the flags byte.
260        let mut flags_bytes = [0; 1];
261        reader.read_exact(&mut flags_bytes)?;
262        let zstd = flags_bytes[0] & 0x1 != 0;
263
264        Ok(Self { format, zstd })
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// A header that is written and then read back must compare equal to the
    /// original, for every format, both with and without the zstd flag.
    #[rstest]
    #[case(EnvelopeFormat::Model)]
    #[case(EnvelopeFormat::ModelWithExtensions)]
    #[case(EnvelopeFormat::ModelText)]
    #[case(EnvelopeFormat::ModelTextWithExtensions)]
    #[case(EnvelopeFormat::PackageJson)]
    fn header_round_trip(#[case] format: EnvelopeFormat) {
        // First with zstd compression, then without.
        for zstd in [true, false] {
            let header = EnvelopeHeader { format, zstd };

            let mut buffer = Vec::new();
            header.write(&mut buffer).unwrap();

            let read_header = EnvelopeHeader::read(&mut buffer.as_slice()).unwrap();
            assert_eq!(header, read_header);
        }
    }
}