hugr_core/envelope/
header.rs1use std::io::{Read, Write};
4use std::num::NonZeroU8;
5
6use itertools::Itertools;
7
8use super::EnvelopeError;
9
/// Magic number identifying an envelope.
///
/// These eight bytes are the first thing `EnvelopeHeader::write` emits, and
/// `EnvelopeHeader::read` rejects any input that does not start with them.
pub const MAGIC_NUMBERS: &[u8] = b"HUGRiHJv";

/// Flags byte written when no optional flag is set.
///
/// Only bit 6 is set, which makes the byte ASCII `@`.
// NOTE(review): presumably chosen so the flags byte stays printable — confirm.
const DEFAULT_FLAGS: u8 = 0x40;

/// Bit of the flags byte that records zstd compression (bit 0).
const ZSTD_FLAG: u8 = 0x01;
21
22#[derive(Clone, Copy, Eq, PartialEq, Debug, Default, derive_more::Display)]
26#[display("EnvelopeHeader({format}{})",
27 if *zstd { ", zstd compressed" } else { "" },
28)]
29pub(super) struct EnvelopeHeader {
30 pub format: EnvelopeFormat,
32 pub zstd: bool,
34}
35
36#[derive(
38 Clone, Copy, Eq, PartialEq, Debug, Default, Hash, derive_more::Display, strum::FromRepr,
39)]
40#[non_exhaustive]
41pub enum EnvelopeFormat {
42 Model = 1,
44 ModelWithExtensions = 2,
50 ModelText = 40, ModelTextWithExtensions = 41, #[default]
73 PackageJson = 63, }
75
76static_assertions::assert_eq_size!(EnvelopeFormat, u8);
79
80impl EnvelopeFormat {
81 #[must_use]
83 pub fn model_version(self) -> Option<u32> {
84 match self {
85 Self::Model
86 | Self::ModelWithExtensions
87 | Self::ModelText
88 | Self::ModelTextWithExtensions => Some(0),
89 _ => None,
90 }
91 }
92
93 #[must_use]
97 pub fn ascii_printable(self) -> bool {
98 matches!(
99 self,
100 Self::PackageJson | Self::ModelText | Self::ModelTextWithExtensions
101 )
102 }
103}
104
105#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
107#[non_exhaustive]
108pub struct EnvelopeConfig {
109 pub format: EnvelopeFormat,
111 pub zstd: Option<ZstdConfig>,
114}
115
116impl EnvelopeConfig {
117 pub fn new(format: EnvelopeFormat) -> Self {
120 Self {
121 format,
122 ..Default::default()
123 }
124 }
125
126 pub fn with_zstd(self, zstd: ZstdConfig) -> Self {
128 Self {
129 zstd: Some(zstd),
130 ..self
131 }
132 }
133
134 pub fn disable_compression(self) -> Self {
136 Self { zstd: None, ..self }
137 }
138
139 pub(super) fn make_header(&self) -> EnvelopeHeader {
141 EnvelopeHeader {
142 format: self.format,
143 zstd: self.zstd.is_some(),
144 }
145 }
146
147 #[must_use]
149 pub const fn text() -> Self {
150 Self {
151 format: EnvelopeFormat::PackageJson,
152 zstd: None,
153 }
154 }
155
156 #[must_use]
160 pub const fn binary() -> Self {
161 Self {
162 format: EnvelopeFormat::ModelWithExtensions,
163 zstd: Some(ZstdConfig::default_level()),
164 }
165 }
166}
167
/// Settings for zstd compression.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub struct ZstdConfig {
    /// Requested compression level; `None` selects the default level.
    pub level: Option<NonZeroU8>,
}

impl ZstdConfig {
    /// Creates a configuration with the given compression level.
    ///
    /// A `level` of `0` is stored as `None`, i.e. it selects the default level.
    pub fn new(level: u8) -> Self {
        let level = NonZeroU8::new(level);
        Self { level }
    }

    /// Creates a configuration that uses the default compression level.
    #[must_use]
    pub const fn default_level() -> Self {
        Self { level: None }
    }

    /// Returns the effective compression level.
    ///
    /// When no explicit level is set this is `zstd::DEFAULT_COMPRESSION_LEVEL`
    /// if the `zstd` feature is enabled, and `0` otherwise.
    #[must_use]
    pub fn level(&self) -> i32 {
        #[cfg(feature = "zstd")]
        let fallback = zstd::DEFAULT_COMPRESSION_LEVEL;
        #[cfg(not(feature = "zstd"))]
        let fallback = 0;

        match self.level {
            Some(explicit) => i32::from(explicit.get()),
            None => fallback,
        }
    }
}
209
210impl EnvelopeHeader {
211 pub fn config(&self) -> EnvelopeConfig {
215 EnvelopeConfig {
216 format: self.format,
217 zstd: if self.zstd {
218 Some(ZstdConfig { level: None })
219 } else {
220 None
221 },
222 }
223 }
224
225 pub fn write(&self, writer: &mut impl Write) -> Result<(), EnvelopeError> {
229 writer.write_all(MAGIC_NUMBERS)?;
231 let format_bytes = [self.format as u8];
233 writer.write_all(&format_bytes)?;
234 let mut flags = DEFAULT_FLAGS;
236 if self.zstd {
237 flags |= ZSTD_FLAG;
238 }
239 writer.write_all(&[flags])?;
240
241 Ok(())
242 }
243
244 pub fn read(reader: &mut impl Read) -> Result<EnvelopeHeader, EnvelopeError> {
249 let mut magic = [0; 8];
251 reader.read_exact(&mut magic)?;
252 if magic != MAGIC_NUMBERS {
253 return Err(EnvelopeError::MagicNumber {
254 expected: MAGIC_NUMBERS.try_into().unwrap(),
255 found: magic,
256 });
257 }
258
259 let mut format_bytes = [0; 1];
261 reader.read_exact(&mut format_bytes)?;
262 let format_discriminant = format_bytes[0] as usize;
263 let Some(format) = EnvelopeFormat::from_repr(format_discriminant) else {
264 return Err(EnvelopeError::InvalidFormatDescriptor {
265 descriptor: format_discriminant,
266 });
267 };
268
269 let mut flags_bytes = [0; 1];
271 reader.read_exact(&mut flags_bytes)?;
272 let flags: u8 = flags_bytes[0];
273
274 let zstd = flags & ZSTD_FLAG != 0;
275
276 let other_flags = (flags ^ DEFAULT_FLAGS) & !ZSTD_FLAG;
278 if other_flags != 0 {
279 let flag_ids = (0..8).filter(|i| other_flags & (1 << i) != 0).collect_vec();
280 return Err(EnvelopeError::FlagUnsupported { flag_ids });
281 }
282
283 Ok(Self { format, zstd })
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use cool_asserts::assert_matches;
    use rstest::rstest;

    /// Every format survives a write/read round trip, both with and without
    /// the zstd flag.
    #[rstest]
    #[case(EnvelopeFormat::Model)]
    #[case(EnvelopeFormat::ModelWithExtensions)]
    #[case(EnvelopeFormat::ModelText)]
    #[case(EnvelopeFormat::ModelTextWithExtensions)]
    #[case(EnvelopeFormat::PackageJson)]
    fn header_round_trip(#[case] format: EnvelopeFormat) {
        for zstd in [true, false] {
            let header = EnvelopeHeader { format, zstd };

            let mut buffer = Vec::new();
            header.write(&mut buffer).unwrap();
            let read_header = EnvelopeHeader::read(&mut buffer.as_slice()).unwrap();
            assert_eq!(header, read_header);
        }
    }

    /// Corrupted headers are rejected with the matching error variant.
    #[rstest]
    fn header_errors() {
        let header = EnvelopeHeader {
            format: EnvelopeFormat::Model,
            zstd: false,
        };
        let mut buffer = Vec::new();
        header.write(&mut buffer).unwrap();

        // Layout: 8 magic bytes, 1 format byte, 1 flags byte.
        assert_eq!(buffer.len(), 10);
        let flags = buffer[9];
        assert_eq!(flags, DEFAULT_FLAGS);

        // Corrupt the last byte of the magic number.
        let mut invalid_magic = buffer.clone();
        invalid_magic[7] = 0xFF;
        let magic_result = EnvelopeHeader::read(&mut invalid_magic.as_slice());
        assert_matches!(magic_result, Err(EnvelopeError::MagicNumber { .. }));

        // Set two flag bits (1 and 4) that the reader does not recognise.
        let mut unrecognised_flags = buffer.clone();
        unrecognised_flags[9] |= 0b0001_0010;
        let flags_result = EnvelopeHeader::read(&mut unrecognised_flags.as_slice());
        assert_matches!(
            flags_result,
            Err(EnvelopeError::FlagUnsupported { flag_ids })
                => assert_eq!(flag_ids, vec![1, 4])
        );
    }
}