hugr_core/envelope/
header.rs1use std::io::{Read, Write};
4use std::num::NonZeroU8;
5
6use super::EnvelopeError;
7
/// Magic number identifying the start of an envelope.
///
/// These eight bytes spell `HUGRiHJv` in ASCII and are written first by
/// [`EnvelopeHeader::write`].
pub const MAGIC_NUMBERS: &[u8] = b"HUGRiHJv";
13
14#[derive(Clone, Copy, Eq, PartialEq, Debug, Default, derive_more::Display)]
18#[display("EnvelopeHeader({format}{})",
19 if *zstd { ", zstd compressed" } else { "" },
20)]
21pub(super) struct EnvelopeHeader {
22 pub format: EnvelopeFormat,
24 pub zstd: bool,
26}
27
28#[derive(
30 Clone, Copy, Eq, PartialEq, Debug, Default, Hash, derive_more::Display, strum::FromRepr,
31)]
32#[non_exhaustive]
33pub enum EnvelopeFormat {
34 Model = 1,
36 ModelWithExtensions = 2,
40 #[default]
45 PackageJson = 63, }
47
48static_assertions::assert_eq_size!(EnvelopeFormat, u8);
51
52impl EnvelopeFormat {
53 pub fn append_extensions(self) -> bool {
55 matches!(self, Self::ModelWithExtensions)
56 }
57
58 pub fn model_version(self) -> Option<u32> {
60 match self {
61 Self::Model | Self::ModelWithExtensions => Some(0),
62 _ => None,
63 }
64 }
65
66 pub fn ascii_printable(self) -> bool {
70 matches!(self, Self::PackageJson)
71 }
72}
73
74#[derive(Clone, Copy, Debug, Eq, PartialEq)]
76#[non_exhaustive]
77pub struct EnvelopeConfig {
78 pub format: EnvelopeFormat,
80 pub zstd: Option<ZstdConfig>,
83}
84
85impl Default for EnvelopeConfig {
86 fn default() -> Self {
87 let format = Default::default();
88 let zstd = if cfg!(feature = "zstd") {
89 Some(ZstdConfig::default())
90 } else {
91 None
92 };
93 Self { format, zstd }
94 }
95}
96
97impl EnvelopeConfig {
98 pub(super) fn make_header(&self) -> EnvelopeHeader {
100 EnvelopeHeader {
101 format: self.format,
102 zstd: self.zstd.is_some(),
103 }
104 }
105
106 pub const fn text() -> Self {
108 Self {
109 format: EnvelopeFormat::PackageJson,
110 zstd: None,
111 }
112 }
113
114 pub const fn binary() -> Self {
118 Self {
119 format: EnvelopeFormat::Model,
120 zstd: Some(ZstdConfig { level: None }),
121 }
122 }
123}
124
/// Zstd compression settings for an envelope payload.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub struct ZstdConfig {
    /// Explicit compression level, or `None` to use the default level.
    pub level: Option<NonZeroU8>,
}

impl ZstdConfig {
    /// Returns the compression level to use.
    ///
    /// An explicit `level` is returned unchanged. Otherwise the result is
    /// `zstd::DEFAULT_COMPRESSION_LEVEL` when the `zstd` feature is enabled,
    /// and `0` when it is not.
    pub fn level(&self) -> i32 {
        // Pick the fallback at compile time. The `zstd` crate is referenced
        // only when the feature is enabled, and no mutable binding (with its
        // `#[allow(unused_*)]` escape hatch) is needed.
        #[cfg(feature = "zstd")]
        let default = zstd::DEFAULT_COMPRESSION_LEVEL;
        #[cfg(not(feature = "zstd"))]
        let default = 0;

        // `u8 -> i32` is lossless, so use `From` rather than an `as` cast.
        self.level.map_or(default, |level| i32::from(level.get()))
    }
}
153
154impl EnvelopeHeader {
155 pub fn config(&self) -> EnvelopeConfig {
159 EnvelopeConfig {
160 format: self.format,
161 zstd: match self.zstd {
162 true => Some(ZstdConfig { level: None }),
163 false => None,
164 },
165 }
166 }
167
168 pub fn write(&self, writer: &mut impl Write) -> Result<(), EnvelopeError> {
172 writer.write_all(MAGIC_NUMBERS)?;
174 let format_bytes = [self.format as u8];
176 writer.write_all(&format_bytes)?;
177 let mut flags = 0b01000000u8;
179 flags |= self.zstd as u8;
180 writer.write_all(&[flags])?;
181
182 Ok(())
183 }
184
185 pub fn read(reader: &mut impl Read) -> Result<EnvelopeHeader, EnvelopeError> {
190 let mut magic = [0; 8];
192 reader.read_exact(&mut magic)?;
193 if magic != MAGIC_NUMBERS {
194 return Err(EnvelopeError::MagicNumber {
195 expected: MAGIC_NUMBERS.try_into().unwrap(),
196 found: magic,
197 });
198 }
199
200 let mut format_bytes = [0; 1];
202 reader.read_exact(&mut format_bytes)?;
203 let format_discriminant = format_bytes[0] as usize;
204 let Some(format) = EnvelopeFormat::from_repr(format_discriminant) else {
205 return Err(EnvelopeError::InvalidFormatDescriptor {
206 descriptor: format_discriminant,
207 });
208 };
209
210 let mut flags_bytes = [0; 1];
212 reader.read_exact(&mut flags_bytes)?;
213 let zstd = flags_bytes[0] & 0x1 != 0;
214
215 Ok(Self { format, zstd })
216 }
217}
218
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    /// Writing then reading a header reproduces it, for every format and both
    /// compression settings.
    #[rstest]
    #[case(EnvelopeFormat::Model)]
    #[case(EnvelopeFormat::ModelWithExtensions)]
    #[case(EnvelopeFormat::PackageJson)]
    fn header_round_trip(#[case] format: EnvelopeFormat) {
        for zstd in [true, false] {
            let header = EnvelopeHeader { format, zstd };

            let mut buffer = Vec::new();
            header.write(&mut buffer).unwrap();
            // Magic number, then one format byte and one flags byte.
            assert_eq!(buffer.len(), MAGIC_NUMBERS.len() + 2);
            if format.ascii_printable() {
                assert!(buffer.iter().all(u8::is_ascii_graphic));
            }

            let read_header = EnvelopeHeader::read(&mut buffer.as_slice()).unwrap();
            assert_eq!(header, read_header);
        }
    }

    /// A stream that does not start with the magic number is rejected.
    #[test]
    fn header_rejects_bad_magic() {
        let mut bytes = b"NOTAHUGR".to_vec();
        bytes.extend([EnvelopeFormat::PackageJson as u8, 0b0100_0000]);
        let err = EnvelopeHeader::read(&mut bytes.as_slice()).unwrap_err();
        assert!(matches!(err, EnvelopeError::MagicNumber { .. }));
    }

    /// A format byte that matches no `EnvelopeFormat` discriminant is rejected.
    #[test]
    fn header_rejects_unknown_format() {
        let mut bytes = MAGIC_NUMBERS.to_vec();
        bytes.extend([0, 0b0100_0000]);
        let err = EnvelopeHeader::read(&mut bytes.as_slice()).unwrap_err();
        assert!(matches!(
            err,
            EnvelopeError::InvalidFormatDescriptor { descriptor: 0 }
        ));
    }

    /// Input that ends before the format and flags bytes is an error.
    #[test]
    fn header_rejects_truncated_input() {
        let bytes = MAGIC_NUMBERS.to_vec();
        assert!(EnvelopeHeader::read(&mut bytes.as_slice()).is_err());
    }
}