use std::io::{Read, Write};
use std::num::NonZeroU8;
use itertools::Itertools;
use thiserror::Error;
/// The eight bytes that open every serialized envelope header.
pub const MAGIC_NUMBERS: &[u8] = b"HUGRiHJv";
/// Value of the header flags byte when no optional flag is set (`0x40`, ASCII `@`).
///
/// `EnvelopeHeader::read` treats any difference from this value, other than
/// `ZSTD_FLAG`, as an unsupported flag.
const DEFAULT_FLAGS: u8 = 1 << 6;
/// Flags-byte bit marking a zstd-compressed payload.
const ZSTD_FLAG: u8 = 1 << 0;
/// Decoded form of the 10-byte envelope header: the magic number, a format
/// byte and a flags byte.
#[derive(Clone, Copy, Eq, PartialEq, Debug, Default, derive_more::Display)]
#[display(
    "EnvelopeHeader({format}{})",
    if *zstd { ", zstd compressed" } else { "" }
)]
pub struct EnvelopeHeader {
    /// Encoding of the payload described by this header.
    pub format: EnvelopeFormat,
    /// Whether the zstd flag bit is set, i.e. the payload is zstd-compressed.
    pub zstd: bool,
}
mod silenced {
    #![expect(deprecated, reason = "https://github.com/Peternator7/strum/issues/404")]
    // The module-wide expectation above absorbs `deprecated` warnings caused by
    // `PackageJson`; presumably these come from the derive expansions (see the
    // strum issue linked in `reason`).

    /// Encoding of the payload stored in an envelope.
    ///
    /// The explicit discriminants are the format-byte values written by
    /// `EnvelopeHeader::write` and looked up via `from_repr` in
    /// `EnvelopeHeader::read`.
    #[derive(
        Clone, Copy, Eq, PartialEq, Debug, Default, Hash, derive_more::Display, strum::FromRepr,
    )]
    #[non_exhaustive]
    pub enum EnvelopeFormat {
        Model = 1,
        #[default]
        ModelWithExtensions = 2,
        SExpression = 40,
        SExpressionWithExtensions = 41,
        #[deprecated(since = "0.27.0")]
        PackageJson = 63,
    }
}

pub use silenced::EnvelopeFormat;

// The format is stored as a single byte in the header, so the enum must stay one byte wide.
static_assertions::assert_eq_size!(EnvelopeFormat, u8);
impl EnvelopeFormat {
    /// Returns the model version used by this format.
    ///
    /// Every model / s-expression format currently reports `Some(0)`; the
    /// deprecated `PackageJson` format reports `None`.
    ///
    /// The match is deliberately exhaustive (no `_` arm) so that adding a new
    /// format forces a decision here instead of silently returning `None`.
    #[must_use]
    #[expect(deprecated, reason = "`PackageJson` is named for an exhaustive match")]
    pub fn model_version(self) -> Option<u32> {
        match self {
            Self::Model
            | Self::ModelWithExtensions
            | Self::SExpression
            | Self::SExpressionWithExtensions => Some(0),
            Self::PackageJson => None,
        }
    }

    /// Returns `true` for the formats whose payload is ASCII-printable text:
    /// `PackageJson`, `SExpression` and `SExpressionWithExtensions`.
    ///
    /// Exhaustive for the same reason as [`EnvelopeFormat::model_version`].
    #[must_use]
    #[expect(deprecated, reason = "`PackageJson` is named for an exhaustive match")]
    pub fn ascii_printable(self) -> bool {
        match self {
            Self::PackageJson | Self::SExpression | Self::SExpressionWithExtensions => true,
            Self::Model | Self::ModelWithExtensions => false,
        }
    }
}
/// Configuration of an envelope: the payload format and optional compression.
///
/// The derived `Default` uses the default [`EnvelopeFormat`] and no compression.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub struct EnvelopeConfig {
    /// Encoding of the payload.
    pub format: EnvelopeFormat,
    /// Zstd settings, or `None` for an uncompressed payload.
    pub zstd: Option<ZstdConfig>,
}
impl EnvelopeConfig {
    /// Creates a configuration for `format` with compression disabled.
    #[must_use]
    pub fn new(format: EnvelopeFormat) -> Self {
        Self {
            format,
            ..Default::default()
        }
    }

    /// Returns this configuration with zstd compression enabled using `zstd`.
    #[must_use]
    pub fn with_zstd(self, zstd: ZstdConfig) -> Self {
        Self {
            zstd: Some(zstd),
            ..self
        }
    }

    /// Returns this configuration with compression turned off.
    #[must_use]
    pub fn disable_compression(self) -> Self {
        Self { zstd: None, ..self }
    }

    /// Builds the [`EnvelopeHeader`] describing this configuration.
    ///
    /// Only the presence of zstd compression is recorded; the compression
    /// level is not part of the header.
    #[must_use]
    pub(super) fn make_header(&self) -> EnvelopeHeader {
        EnvelopeHeader {
            format: self.format,
            zstd: self.zstd.is_some(),
        }
    }

    /// Text configuration: `SExpressionWithExtensions`, uncompressed.
    #[must_use]
    pub const fn text() -> Self {
        Self {
            format: EnvelopeFormat::SExpressionWithExtensions,
            zstd: None,
        }
    }

    /// Binary configuration: `ModelWithExtensions`, zstd-compressed at the
    /// default level.
    #[must_use]
    pub const fn binary() -> Self {
        Self {
            format: EnvelopeFormat::ModelWithExtensions,
            zstd: Some(ZstdConfig::default_level()),
        }
    }
}
/// Zstd compression settings for an envelope.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
#[non_exhaustive]
pub struct ZstdConfig {
    /// Explicit compression level; `None` selects the default level.
    pub level: Option<NonZeroU8>,
}
impl ZstdConfig {
    /// Creates a config with compression `level`; a level of `0` selects the
    /// default level.
    pub fn new(level: u8) -> Self {
        let level = NonZeroU8::new(level);
        Self { level }
    }

    /// Config that uses the default compression level.
    #[must_use]
    pub const fn default_level() -> Self {
        Self { level: None }
    }

    /// Returns the compression level to use, as an `i32`.
    ///
    /// Without an explicit level this is `zstd::DEFAULT_COMPRESSION_LEVEL` when
    /// the `zstd` feature is enabled, and `0` otherwise.
    #[must_use]
    pub fn level(&self) -> i32 {
        #[cfg(feature = "zstd")]
        let fallback = zstd::DEFAULT_COMPRESSION_LEVEL;
        #[cfg(not(feature = "zstd"))]
        let fallback = 0;

        match self.level {
            Some(explicit) => i32::from(explicit.get()),
            None => fallback,
        }
    }
}
/// Error returned by `EnvelopeHeader::read` and `EnvelopeHeader::write`.
///
/// Wraps a `HeaderErrorInner` (visible only to the parent module) that
/// describes the specific failure; its message is appended to the display text.
#[derive(Debug, Error, derive_more::Display)]
#[display("Error reading the envelope header. {_0}")]
pub struct HeaderError(HeaderErrorInner);
/// The specific failure wrapped by a [`HeaderError`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub(super) enum HeaderErrorInner {
    /// The first eight bytes were not [`MAGIC_NUMBERS`].
    // Both values are zero-padded to 16 hex digits so leading zero bytes in
    // `found` stay visible and the two numbers line up byte-for-byte.
    #[error(
        "Bad magic number. expected 0x{:016X} found 0x{:016X}",
        u64::from_be_bytes(*expected),
        u64::from_be_bytes(*found)
    )]
    MagicNumber {
        /// The magic bytes this reader accepts.
        expected: [u8; 8],
        /// The bytes actually read.
        found: [u8; 8],
    },
    /// The format byte does not match any [`EnvelopeFormat`] discriminant.
    #[error("Format descriptor {descriptor} is invalid.")]
    InvalidFormatDescriptor {
        /// The format byte that was read.
        descriptor: usize,
    },
    /// The underlying reader or writer failed (including a premature end of input).
    #[error(transparent)]
    IO {
        #[from]
        source: std::io::Error,
    },
    /// The flags byte set bits other than the default flags and the zstd flag.
    #[error(
        "The envelope configuration has unknown {}. Please update your HUGR version.",
        if flag_ids.len() == 1 {format!("flag #{}", flag_ids[0])} else {format!("flags {}", flag_ids.iter().join(", "))}
    )]
    FlagUnsupported {
        /// Positions of the offending bits, 0 being the least significant.
        flag_ids: Vec<usize>,
    },
    /// Zstd compression was requested but the `zstd` feature is disabled.
    #[cfg(not(feature = "zstd"))]
    #[error("Zstd compression is not supported. This requires the 'zstd' feature for `hugr`.")]
    ZstdUnsupported,
}
impl<T: Into<HeaderErrorInner>> From<T> for HeaderError {
fn from(value: T) -> Self {
Self(value.into())
}
}
impl EnvelopeHeader {
    /// Returns the [`EnvelopeConfig`] described by this header.
    ///
    /// The header records only *whether* zstd is used, not the level, so a
    /// compressed header maps to a [`ZstdConfig`] with no explicit level.
    pub fn config(&self) -> EnvelopeConfig {
        let zstd = self.zstd.then(|| ZstdConfig { level: None });
        EnvelopeConfig {
            format: self.format,
            zstd,
        }
    }

    /// Serializes the header: [`MAGIC_NUMBERS`], one format byte, then one flags byte.
    ///
    /// # Errors
    ///
    /// Returns a [`HeaderError`] if `writer` fails.
    pub fn write(&self, writer: &mut impl Write) -> Result<(), HeaderError> {
        let flags = if self.zstd {
            DEFAULT_FLAGS | ZSTD_FLAG
        } else {
            DEFAULT_FLAGS
        };
        writer.write_all(MAGIC_NUMBERS)?;
        writer.write_all(&[self.format as u8])?;
        writer.write_all(&[flags])?;
        Ok(())
    }

    /// Parses a header in the layout produced by [`EnvelopeHeader::write`],
    /// reading at most 10 bytes from `reader` and stopping at the first problem.
    ///
    /// # Errors
    ///
    /// Returns a [`HeaderError`] if the reader fails or ends early, if the first
    /// eight bytes are not [`MAGIC_NUMBERS`], if the format byte is not a known
    /// [`EnvelopeFormat`], or if the flags byte sets any bit besides the default
    /// flags and the zstd flag.
    pub fn read(reader: &mut impl Read) -> Result<EnvelopeHeader, HeaderError> {
        let mut found = [0u8; 8];
        reader.read_exact(&mut found)?;
        if found != MAGIC_NUMBERS {
            let expected: [u8; 8] = MAGIC_NUMBERS.try_into().unwrap();
            return Err(HeaderErrorInner::MagicNumber { expected, found }.into());
        }

        let mut byte = [0u8; 1];

        reader.read_exact(&mut byte)?;
        let descriptor = usize::from(byte[0]);
        let format = EnvelopeFormat::from_repr(descriptor)
            .ok_or(HeaderErrorInner::InvalidFormatDescriptor { descriptor })?;

        reader.read_exact(&mut byte)?;
        let flags = byte[0];
        // Any deviation from DEFAULT_FLAGS other than the zstd bit is a flag
        // this version does not understand.
        let unknown = (flags ^ DEFAULT_FLAGS) & !ZSTD_FLAG;
        if unknown != 0 {
            let flag_ids: Vec<usize> = (0..8)
                .filter(|&bit| ((unknown >> bit) & 1) == 1)
                .collect();
            return Err(HeaderErrorInner::FlagUnsupported { flag_ids }.into());
        }

        Ok(Self {
            format,
            zstd: flags & ZSTD_FLAG != 0,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use cool_asserts::assert_matches;
    use rstest::rstest;

    /// Serializes `header` into a fresh byte buffer.
    fn encode(header: EnvelopeHeader) -> Vec<u8> {
        let mut bytes = Vec::new();
        header.write(&mut bytes).unwrap();
        bytes
    }

    #[rstest]
    #[case(EnvelopeFormat::Model)]
    #[case(EnvelopeFormat::ModelWithExtensions)]
    #[case(EnvelopeFormat::SExpression)]
    #[case(EnvelopeFormat::SExpressionWithExtensions)]
    #[case(EnvelopeFormat::PackageJson)]
    #[allow(deprecated)]
    fn header_round_trip(#[case] format: EnvelopeFormat) {
        // Exercise both the compressed and uncompressed flag settings.
        for zstd in [true, false] {
            let header = EnvelopeHeader { format, zstd };
            let bytes = encode(header);
            let decoded = EnvelopeHeader::read(&mut bytes.as_slice()).unwrap();
            assert_eq!(header, decoded);
        }
    }

    #[rstest]
    fn header_errors() {
        let bytes = encode(EnvelopeHeader {
            format: EnvelopeFormat::Model,
            zstd: false,
        });
        assert_eq!(bytes.len(), 10);
        assert_eq!(bytes[9], DEFAULT_FLAGS);

        // Corrupt the last magic byte.
        let mut bad_magic = bytes.clone();
        bad_magic[7] = 0xFF;
        assert_matches!(
            EnvelopeHeader::read(&mut bad_magic.as_slice()),
            Err(HeaderError(HeaderErrorInner::MagicNumber { .. }))
        );

        // Set flag bits 1 and 4, neither of which is understood.
        let mut bad_flags = bytes;
        bad_flags[9] |= 0b0001_0010;
        assert_matches!(
            EnvelopeHeader::read(&mut bad_flags.as_slice()),
            Err(HeaderError(HeaderErrorInner::FlagUnsupported { flag_ids }))
            => assert_eq!(flag_ids, vec![1, 4])
        );
    }
}