use std::{
fmt, io,
str::{self, FromStr},
};
use super::MAGIC;
mod parse;
/// Alignment, in bytes, of the serialized header: `Header::write` pads the header so that the
/// total number of bytes it writes is a multiple of this value.
const ALIGN: usize = 64;
/// An npy header: the format version together with the header dictionary.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(super) struct Header {
/// Format version; it determines the width of the header length field.
pub version: Version,
/// Header dictionary, serialized as text after the header length field.
pub dict: HeaderDict,
}
impl Header {
    /// Creates a header from a format version and a header dictionary.
    pub fn new(version: Version, dict: HeaderDict) -> Self {
        Self { version, dict }
    }

    /// Reads and parses a header from `reader`.
    ///
    /// Consumes, in order: the magic number, the two version bytes, the version-dependent header
    /// length field, and exactly that many bytes of header dictionary text (padding included).
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`io::ErrorKind::InvalidData`] if the magic number does not
    /// match, the version bytes are not recognised, or the dictionary text is not valid UTF-8 or
    /// cannot be parsed. Errors from the underlying reader are propagated.
    pub fn read<R>(reader: &mut R) -> io::Result<Self>
    where
        R: io::BufRead,
    {
        let mut magic_buf = [0; MAGIC.len()];
        reader.read_exact(&mut magic_buf)?;

        if magic_buf != MAGIC {
            return Err(io::Error::new(
                io::ErrorKind::InvalidData,
                "unexpected npy magic number",
            ));
        }

        let mut version_buf = [0; 2];
        reader.read_exact(&mut version_buf)?;
        let version = Version::from_header_bytes(version_buf).map_err(|_| {
            io::Error::new(
                io::ErrorKind::InvalidData,
                "invalid npy version specification",
            )
        })?;

        let header_len = version.read_header_len(reader)?;
        let mut dict_buf = vec![0; header_len];
        reader.read_exact(&mut dict_buf)?;

        let dict_str =
            str::from_utf8(&dict_buf).map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
        let dict = HeaderDict::from_str(dict_str)
            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;

        Ok(Self::new(version, dict))
    }

    /// Writes the header to `writer`.
    ///
    /// The output consists of the magic number, the version bytes, the header length field, the
    /// formatted dictionary, and padding. The padding is made of spaces followed by a single
    /// terminating newline, sized so that the total number of bytes written is a multiple of
    /// [`ALIGN`]. At least one padding byte (the newline) is always written; if the unpadded
    /// header is already aligned, a full [`ALIGN`] bytes of padding are added.
    ///
    /// # Errors
    ///
    /// Propagates any error returned while writing to `writer`.
    pub fn write<W>(&self, writer: &mut W) -> io::Result<()>
    where
        W: io::Write,
    {
        writer.write_all(&MAGIC)?;

        let version_bytes = self.version.to_header_bytes();
        writer.write_all(&version_bytes)?;

        let fmt_dict = self.dict.to_string();

        let unpadded_len = MAGIC.len()
            + version_bytes.len()
            + self.version.header_len_bytes_len()
            + fmt_dict.len();

        // Always in 1..=ALIGN: there must be room for the terminating newline even when the
        // unpadded length is already a multiple of ALIGN. Using zero padding in that case would
        // both drop the newline and index the padding buffer at `0 - 1`.
        let pad_len = ALIGN - unpadded_len % ALIGN;
        debug_assert_eq!((unpadded_len + pad_len) % ALIGN, 0);

        // The length field counts the dictionary text plus its padding.
        let header_len = fmt_dict.len() + pad_len;
        self.version.write_header_len(header_len, writer)?;

        writer.write_all(fmt_dict.as_bytes())?;

        let mut pad = vec![b' '; pad_len - 1];
        pad.push(b'\n');
        writer.write_all(&pad)
    }
}
/// The dictionary stored in an npy header.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(super) struct HeaderDict {
/// Element type and byte order; formatted under the `'descr'` key.
pub type_descriptor: TypeDescriptor,
/// Value of the `'fortran_order'` key; formatted as `True` or `False`.
pub fortran_order: bool,
/// Value of the `'shape'` key; formatted as a parenthesised, comma-separated list.
pub shape: Vec<usize>,
}
impl HeaderDict {
pub fn new(type_descriptor: TypeDescriptor, fortran_order: bool, shape: Vec<usize>) -> Self {
Self {
type_descriptor,
fortran_order,
shape,
}
}
}
impl fmt::Display for HeaderDict {
    /// Formats the dictionary as a Python-style dict literal, e.g.
    /// `{'descr': '>u4', 'fortran_order': True, 'shape': (3, 1, 2,), }`.
    ///
    /// A non-empty shape is written with a trailing comma after the last dimension. An empty
    /// shape is written as `()`, since `(,)` is not a valid tuple literal.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let descr = &self.type_descriptor;
        let fortran_order = if self.fortran_order { "True" } else { "False" };

        let shape = if self.shape.is_empty() {
            String::from("()")
        } else {
            let dims = self
                .shape
                .iter()
                .map(|dim| dim.to_string())
                .collect::<Vec<_>>()
                .join(", ");
            format!("({dims},)")
        };

        write!(
            f,
            "{{'descr': '{descr}', 'fortran_order': {fortran_order}, 'shape': {shape}, }}"
        )
    }
}
impl FromStr for HeaderDict {
    type Err = ParseHeaderError;

    /// Parses a header dictionary from its textual form.
    ///
    /// All three entries (descr, fortran order and shape) must be present; if an entry occurs
    /// more than once, the last occurrence wins. When any entry is missing, the error carries a
    /// copy of the whole input.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mut descr: Option<TypeDescriptor> = None;
        let mut order: Option<bool> = None;
        let mut dims: Option<Vec<usize>> = None;

        for entry in parse::parse_header_dict(s)? {
            match entry {
                parse::Entry::Descr(value) => descr = Some(value),
                parse::Entry::FortranOrder(value) => order = Some(value),
                parse::Entry::Shape(value) => dims = Some(value),
            }
        }

        descr
            .zip(order)
            .zip(dims)
            .map(|((descr, order), dims)| Self::new(descr, order, dims))
            .ok_or_else(|| ParseHeaderError(s.to_string()))
    }
}
/// The npy format version, identified by the first of the two version bytes in the header.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(super) enum Version {
/// Version byte 1; the header length field is 2 bytes wide.
V1,
/// Version byte 2; the header length field is 4 bytes wide.
V2,
/// Version byte 3; the header length field is 4 bytes wide.
V3,
}
impl Version {
    /// Decodes the two version bytes of a header.
    ///
    /// Only the first byte is inspected (1, 2 or 3); the second byte is ignored. Unrecognised
    /// input is handed back unchanged as the error value.
    fn from_header_bytes(bytes: [u8; 2]) -> Result<Self, [u8; 2]> {
        match bytes {
            [1, _] => Ok(Self::V1),
            [2, _] => Ok(Self::V2),
            [3, _] => Ok(Self::V3),
            _ => Err(bytes),
        }
    }

    /// Reads the little-endian header length field, which is 2 bytes wide for `V1` and 4 bytes
    /// wide for `V2`/`V3`.
    ///
    /// # Errors
    ///
    /// Propagates reader errors. Returns [`io::ErrorKind::InvalidData`] if a 4-byte length does
    /// not fit in `usize` on the current target.
    fn read_header_len<R>(&self, reader: &mut R) -> io::Result<usize>
    where
        R: io::BufRead,
    {
        match self {
            Version::V1 => {
                let mut header_len_buf = [0; 2];
                reader.read_exact(&mut header_len_buf)?;
                Ok(usize::from(u16::from_le_bytes(header_len_buf)))
            }
            Version::V2 | Version::V3 => {
                let mut header_len_buf = [0; 4];
                reader.read_exact(&mut header_len_buf)?;
                usize::try_from(u32::from_le_bytes(header_len_buf)).map_err(|_| {
                    io::Error::new(
                        io::ErrorKind::InvalidData,
                        "npy header length does not fit in usize",
                    )
                })
            }
        }
    }

    /// Returns the two version bytes written to a header; the second byte is always 0.
    fn to_header_bytes(&self) -> [u8; 2] {
        match self {
            Version::V1 => [1, 0],
            Version::V2 => [2, 0],
            Version::V3 => [3, 0],
        }
    }

    /// Writes `header_len` as the little-endian header length field, using 2 bytes for `V1` and
    /// 4 bytes for `V2`/`V3`.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidInput`] without writing anything if `header_len` does not
    /// fit in the field for this version. Writer errors are propagated.
    fn write_header_len<W>(&self, header_len: usize, writer: &mut W) -> io::Result<()>
    where
        W: io::Write,
    {
        // An over-long header is a property of the data being written, not a broken invariant,
        // so report it through the `io::Result` rather than panicking.
        let too_long = |field_ty: &str| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                format!("npy header length {header_len} does not fit in {field_ty}"),
            )
        };

        match self {
            Version::V1 => {
                let len = u16::try_from(header_len).map_err(|_| too_long("u16"))?;
                writer.write_all(&len.to_le_bytes())
            }
            Version::V2 | Version::V3 => {
                let len = u32::try_from(header_len).map_err(|_| too_long("u32"))?;
                writer.write_all(&len.to_le_bytes())
            }
        }
    }

    /// Returns the width in bytes of the header length field for this version.
    fn header_len_bytes_len(&self) -> usize {
        match self {
            Version::V1 => 2,
            Version::V2 | Version::V3 => 4,
        }
    }
}
/// Describes how array elements are stored: their byte order and their element type.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(super) struct TypeDescriptor {
// Byte order used when decoding each element.
endian: Endian,
// Element type, which fixes the width and kind of each element.
ty: Type,
}
// Expands to a closure that reads `size_of::<$ty>()` bytes from a reader, decodes them with
// `<$ty>::$fn` (e.g. `from_le_bytes`), and returns the value cast to `f64`.
//
// The closure parameter is typed `&mut R`, so the macro must be invoked where a type named `R`
// is in scope.
macro_rules! impl_get_read_fn {
($ty:ty, $fn:ident) => {{
|reader: &mut R| {
let mut buf = [0; std::mem::size_of::<$ty>()];
reader.read_exact(&mut buf)?;
// `as f64` is a plain numeric cast; for 64-bit integer types it may round.
Ok(<$ty>::$fn(buf) as f64)
}
}};
}
impl TypeDescriptor {
pub fn new(endian: Endian, ty: Type) -> Self {
Self { endian, ty }
}
fn get_read_fn<R>(&self) -> impl Fn(&mut R) -> io::Result<f64>
where
R: io::BufRead,
{
match (&self.endian, &self.ty) {
(Endian::Little, Type::F4) => impl_get_read_fn!(f32, from_le_bytes),
(Endian::Little, Type::F8) => impl_get_read_fn!(f64, from_le_bytes),
(Endian::Little, Type::I1) => impl_get_read_fn!(i8, from_le_bytes),
(Endian::Little, Type::I2) => impl_get_read_fn!(i16, from_le_bytes),
(Endian::Little, Type::I4) => impl_get_read_fn!(i32, from_le_bytes),
(Endian::Little, Type::I8) => impl_get_read_fn!(i64, from_le_bytes),
(Endian::Little, Type::U1) => impl_get_read_fn!(u8, from_le_bytes),
(Endian::Little, Type::U2) => impl_get_read_fn!(u16, from_le_bytes),
(Endian::Little, Type::U4) => impl_get_read_fn!(u32, from_le_bytes),
(Endian::Little, Type::U8) => impl_get_read_fn!(u64, from_le_bytes),
(Endian::Big, Type::F4) => impl_get_read_fn!(f32, from_be_bytes),
(Endian::Big, Type::F8) => impl_get_read_fn!(f64, from_be_bytes),
(Endian::Big, Type::I1) => impl_get_read_fn!(i8, from_be_bytes),
(Endian::Big, Type::I2) => impl_get_read_fn!(i16, from_be_bytes),
(Endian::Big, Type::I4) => impl_get_read_fn!(i32, from_be_bytes),
(Endian::Big, Type::I8) => impl_get_read_fn!(i64, from_be_bytes),
(Endian::Big, Type::U1) => impl_get_read_fn!(u8, from_be_bytes),
(Endian::Big, Type::U2) => impl_get_read_fn!(u16, from_be_bytes),
(Endian::Big, Type::U4) => impl_get_read_fn!(u32, from_be_bytes),
(Endian::Big, Type::U8) => impl_get_read_fn!(u64, from_be_bytes),
}
}
pub(super) fn read<R>(&self, reader: &mut R) -> io::Result<Vec<f64>>
where
R: io::BufRead,
{
let read_fn = self.get_read_fn();
let mut values = Vec::new();
while !reader.fill_buf()?.is_empty() {
values.push(read_fn(reader)?)
}
Ok(values)
}
}
impl fmt::Display for TypeDescriptor {
    /// Writes the byte-order character (`<` or `>`) followed by the two-character type code,
    /// e.g. `<f8`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let byte_order = match &self.endian {
            Endian::Big => ">",
            Endian::Little => "<",
        };

        let type_code = match &self.ty {
            Type::F4 => "f4",
            Type::F8 => "f8",
            Type::I1 => "i1",
            Type::I2 => "i2",
            Type::I4 => "i4",
            Type::I8 => "i8",
            Type::U1 => "u1",
            Type::U2 => "u2",
            Type::U4 => "u4",
            Type::U8 => "u8",
        };

        f.write_str(byte_order)?;
        f.write_str(type_code)
    }
}
impl FromStr for TypeDescriptor {
    type Err = String;

    /// Parses a three-byte type descriptor such as `<f8` or `>i2`.
    ///
    /// The first character selects the byte order: `<` and `|` map to little-endian, `>` to
    /// big-endian. The remaining two characters select the element type.
    ///
    /// # Errors
    ///
    /// Returns a message naming the input if it is not exactly three bytes long or if either
    /// part is not recognised.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let create_err = || Err(format!("invalid type descriptor '{s}'"));

        // `split_at` panics if the index is not on a char boundary, which happens for a
        // three-byte input that starts with a multi-byte character. Reject such input up front.
        if s.len() != 3 || !s.is_char_boundary(1) {
            return create_err();
        }

        let (endian_str, type_str) = s.split_at(1);

        let endian = match endian_str {
            "<" | "|" => Endian::Little,
            ">" => Endian::Big,
            _ => return create_err(),
        };

        let ty = match type_str {
            "f4" => Type::F4,
            "f8" => Type::F8,
            "i1" => Type::I1,
            "i2" => Type::I2,
            "i4" => Type::I4,
            "i8" => Type::I8,
            "u1" => Type::U1,
            "u2" => Type::U2,
            "u4" => Type::U4,
            "u8" => Type::U8,
            _ => return create_err(),
        };

        Ok(Self::new(endian, ty))
    }
}
/// Byte order of stored elements.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(super) enum Endian {
/// Little-endian; formatted as `<`.
Little,
/// Big-endian; formatted as `>`.
Big,
}
/// Element type of stored values. Each variant is named after its two-character type code.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(super) enum Type {
/// `f4`, decoded as `f32`.
F4,
/// `f8`, decoded as `f64`.
F8,
/// `i1`, decoded as `i8`.
I1,
/// `i2`, decoded as `i16`.
I2,
/// `i4`, decoded as `i32`.
I4,
/// `i8`, decoded as `i64`.
I8,
/// `u1`, decoded as `u8`.
U1,
/// `u2`, decoded as `u16`.
U2,
/// `u4`, decoded as `u32`.
U4,
/// `u8`, decoded as `u64`.
U8,
}
/// Error returned when text cannot be parsed as an npy header dictionary.
///
/// Wraps the text that failed to parse.
#[derive(Debug, PartialEq, Eq)]
pub struct ParseHeaderError(String);

impl std::error::Error for ParseHeaderError {}

impl fmt::Display for ParseHeaderError {
    /// Reports the offending input, quoted, in a fixed message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ParseHeaderError(input) = self;
        f.write_fmt(format_args!(
            "failed to parse '{}' as npy format header",
            input
        ))
    }
}
// Unit tests for type descriptor decoding and for header parsing, formatting, reading and
// writing.
#[cfg(test)]
mod tests {
use super::*;
// Ten big-endian `i16` values (0 through 9) decode to the same numbers as `f64`.
#[test]
fn test_type_descriptor_read() -> io::Result<()> {
let src: Vec<u8> = (0i16..10).flat_map(|x| x.to_be_bytes()).collect();
let expected: Vec<f64> = (0..10).map(|x| x as f64).collect();
assert_eq!(
TypeDescriptor::new(Endian::Big, Type::I2).read(&mut &src[..])?,
expected
);
Ok(())
}
// The input lists the keys in a different order than `Display` writes them and has no trailing
// commas; it must still parse.
#[test]
fn test_parse_header_dict() {
assert_eq!(
"{ 'descr': '<f8', 'shape': (15, 3), 'fortran_order': False }".parse(),
Ok(HeaderDict::new(
TypeDescriptor::new(Endian::Little, Type::F8),
false,
vec![15, 3]
))
)
}
// Pins the exact textual form produced by `Display`, including the trailing commas.
#[test]
fn test_display_header_dict() {
assert_eq!(
HeaderDict::new(
TypeDescriptor::new(Endian::Big, Type::U4),
true,
vec![3, 1, 2]
)
.to_string(),
String::from("{'descr': '>u4', 'fortran_order': True, 'shape': (3, 1, 2,), }"),
)
}
// Reads a hand-assembled version 1 header and compares it with the expected `Header`.
#[test]
fn test_read_header() -> io::Result<()> {
let header_dict = HeaderDict::new(
TypeDescriptor::new(Endian::Little, Type::F8),
false,
vec![2, 3],
);
// 0x93 followed by ASCII "NUMPY" (presumably `MAGIC`), the version bytes 1 and 0, then the
// header length 118 as a little-endian u16.
let mut src = vec![
147, 78, 85, 77, 80, 89, 1, 0, 118, 0, ];
src.extend(header_dict.to_string().as_bytes());
// NOTE(review): the formatted dict is 60 bytes, so the declared length of 118 covers only the
// 58 spaces; the final newline byte lies past the header and is never consumed by
// `Header::read`. Confirm whether 57 spaces plus the newline was intended.
src.extend([32; 58]); src.extend([10]);
assert_eq!(
Header::read(&mut &src[..])?,
Header::new(Version::V1, header_dict)
);
Ok(())
}
// Writes a version 2 header and compares the output byte for byte.
#[test]
fn test_write_header() -> io::Result<()> {
let header_dict =
HeaderDict::new(TypeDescriptor::new(Endian::Big, Type::F4), false, vec![2]);
let fmt_dict = header_dict.to_string();
let header = Header::new(Version::V2, header_dict);
let mut dest = Vec::new();
header.write(&mut dest)?;
// Same six leading bytes as above, the version bytes 2 and 0, then the header length 116 as a
// little-endian u32.
let mut expected = vec![
147, 78, 85, 77, 80, 89, 2, 0, 116, 0, 0, 0, ];
expected.extend(fmt_dict.as_bytes());
// Padding: 58 spaces and the terminating newline, 59 bytes in total.
expected.extend([32; 58]); expected.extend([10]);
assert_eq!(dest, expected);
Ok(())
}
}