#![forbid(unsafe_code)]
#![deny(
missing_docs,
rustdoc::broken_intra_doc_links,
rustdoc::missing_crate_level_docs
)]
#![allow(clippy::tabs_in_doc_comments)]
#![cfg_attr(docsrs, feature(doc_auto_cfg))]
pub use ndarray;
pub use ndarray_npy;
use ndarray::{
prelude::*,
{Data, DataOwned},
};
use ndarray_npy::{
ReadNpyError, ReadNpyExt, ReadableElement, ViewElement, ViewMutElement, ViewMutNpyExt,
ViewNpyError, ViewNpyExt, WritableElement, WriteNpyError, WriteNpyExt,
};
use std::{
collections::{BTreeMap, HashMap, HashSet},
error::Error,
fmt,
io::{self, BufWriter, Cursor, Read, Seek, Write},
ops::Range,
};
use zip::{
result::ZipError,
write::SimpleFileOptions,
{CompressionMethod, ZipArchive, ZipWriter},
};
/// Error that can occur while writing an `.npz` archive.
#[derive(Debug)]
pub enum WriteNpzError {
    /// An error reported by the zip layer.
    Zip(ZipError),
    /// An error writing an array in `.npy` format into the archive.
    Npy(WriteNpyError),
}

impl Error for WriteNpzError {
    /// Returns the wrapped error; every variant carries one.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let inner: &(dyn Error + 'static) = match self {
            Self::Zip(zip_err) => zip_err,
            Self::Npy(npy_err) => npy_err,
        };
        Some(inner)
    }
}

impl fmt::Display for WriteNpzError {
    /// Writes a short description followed by the wrapped error's own message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Zip(err) => write!(f, "zip file error: {err}"),
            Self::Npy(err) => write!(f, "error writing npy file to npz archive: {err}"),
        }
    }
}

impl From<ZipError> for WriteNpzError {
    /// Wraps a zip error so that `?` can be used on zip operations.
    fn from(err: ZipError) -> Self {
        Self::Zip(err)
    }
}

impl From<WriteNpyError> for WriteNpzError {
    /// Wraps an `.npy` write error so that `?` can be used on array writes.
    fn from(err: WriteNpyError) -> Self {
        Self::Npy(err)
    }
}
pub struct NpzWriter<W: Write + Seek> {
zip: ZipWriter<W>,
options: SimpleFileOptions,
}
impl<W: Write + Seek> NpzWriter<W> {
#[must_use]
pub fn new(writer: W) -> NpzWriter<W> {
NpzWriter {
zip: ZipWriter::new(writer),
options: SimpleFileOptions::default()
.with_alignment(64)
.compression_method(CompressionMethod::Stored),
}
}
#[cfg(feature = "compressed")]
#[must_use]
pub fn new_compressed(writer: W) -> NpzWriter<W> {
NpzWriter {
zip: ZipWriter::new(writer),
options: SimpleFileOptions::default().compression_method(CompressionMethod::Deflated),
}
}
pub fn add_array<N, S, D>(
&mut self,
name: N,
array: &ArrayBase<S, D>,
) -> Result<(), WriteNpzError>
where
N: Into<String>,
S::Elem: WritableElement,
S: Data,
D: Dimension,
{
self.zip.start_file(name.into(), self.options)?;
array.write_npy(BufWriter::new(&mut self.zip))?;
Ok(())
}
pub fn finish(self) -> Result<W, WriteNpzError> {
let mut writer = self.zip.finish()?;
writer.flush().map_err(ZipError::from)?;
Ok(writer)
}
}
/// Error that can occur while reading an `.npz` archive.
#[derive(Debug)]
pub enum ReadNpzError {
    /// An error reported by the zip layer.
    Zip(ZipError),
    /// An error reading an `.npy` entry of the archive.
    Npy(ReadNpyError),
}

impl Error for ReadNpzError {
    /// Returns the wrapped error; every variant carries one.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        let inner: &(dyn Error + 'static) = match self {
            Self::Zip(zip_err) => zip_err,
            Self::Npy(npy_err) => npy_err,
        };
        Some(inner)
    }
}

impl fmt::Display for ReadNpzError {
    /// Writes a short description followed by the wrapped error's own message.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Zip(err) => write!(f, "zip file error: {err}"),
            Self::Npy(err) => write!(f, "error reading npy file in npz archive: {err}"),
        }
    }
}

impl From<ZipError> for ReadNpzError {
    /// Wraps a zip error so that `?` can be used on zip operations.
    fn from(err: ZipError) -> Self {
        Self::Zip(err)
    }
}

impl From<ReadNpyError> for ReadNpzError {
    /// Wraps an `.npy` read error so that `?` can be used on array reads.
    fn from(err: ReadNpyError) -> Self {
        Self::Npy(err)
    }
}
pub struct NpzReader<R: Read + Seek> {
zip: ZipArchive<R>,
}
impl<R: Read + Seek> NpzReader<R> {
pub fn new(reader: R) -> Result<NpzReader<R>, ReadNpzError> {
Ok(NpzReader {
zip: ZipArchive::new(reader)?,
})
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.zip.len() == 0
}
#[must_use]
pub fn len(&self) -> usize {
self.zip.len()
}
pub fn names(&mut self) -> Result<Vec<String>, ReadNpzError> {
Ok((0..self.zip.len())
.map(|i| Ok(self.zip.by_index(i)?.name().to_owned()))
.collect::<Result<_, ZipError>>()?)
}
pub fn by_name<S, D>(&mut self, name: &str) -> Result<ArrayBase<S, D>, ReadNpzError>
where
S::Elem: ReadableElement,
S: DataOwned,
D: Dimension,
{
Ok(ArrayBase::<S, D>::read_npy(self.zip.by_name(name)?)?)
}
pub fn by_index<S, D>(&mut self, index: usize) -> Result<ArrayBase<S, D>, ReadNpzError>
where
S::Elem: ReadableElement,
S: DataOwned,
D: Dimension,
{
Ok(ArrayBase::<S, D>::read_npy(self.zip.by_index(index)?)?)
}
}
/// Error that can occur while viewing an `.npz` archive in place.
#[derive(Debug)]
#[non_exhaustive]
pub enum ViewNpzError {
    /// An error reported by the zip layer.
    Zip(ZipError),
    /// An error viewing an `.npy` entry of the archive.
    Npy(ViewNpyError),
    /// The mutable `.npy` file view has already been moved out of its `.npz` file view.
    MovedNpyViewMut,
    /// The requested entry is a directory, which cannot be viewed.
    Directory,
    /// The requested entry is compressed, which cannot be viewed.
    CompressedFile,
    /// The requested entry is encrypted, which cannot be viewed.
    EncryptedFile,
}

impl Error for ViewNpzError {
    /// Returns the wrapped error for the variants that carry one.
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::Zip(zip_err) => Some(zip_err),
            Self::Npy(npy_err) => Some(npy_err),
            Self::MovedNpyViewMut
            | Self::Directory
            | Self::CompressedFile
            | Self::EncryptedFile => None,
        }
    }
}

impl fmt::Display for ViewNpzError {
    /// Writes a short description, followed by the wrapped error's message if there is one.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Zip(err) => write!(f, "zip file error: {err}"),
            Self::Npy(err) => write!(f, "error viewing npy file in npz archive: {err}"),
            Self::MovedNpyViewMut => {
                f.write_str("mutable npy file view already moved out of npz file view")
            }
            Self::Directory => f.write_str("directories cannot be viewed"),
            Self::CompressedFile => f.write_str("compressed files cannot be viewed"),
            Self::EncryptedFile => f.write_str("encrypted files cannot be viewed"),
        }
    }
}

impl From<ZipError> for ViewNpzError {
    /// Wraps a zip error so that `?` can be used on zip operations.
    fn from(err: ZipError) -> Self {
        Self::Zip(err)
    }
}

impl From<ViewNpyError> for ViewNpzError {
    /// Wraps an `.npy` view error so that `?` can be used on array views.
    fn from(err: ViewNpyError) -> Self {
        Self::Npy(err)
    }
}
#[derive(Debug, Clone)]
pub struct NpzView<'a> {
files: HashMap<usize, NpyView<'a>>,
names: HashMap<String, usize>,
directory_names: HashSet<String>,
compressed_names: HashSet<String>,
encrypted_names: HashSet<String>,
}
impl<'a> NpzView<'a> {
pub fn new(bytes: &'a [u8]) -> Result<Self, ViewNpzError> {
let mut zip = ZipArchive::new(Cursor::new(bytes))?;
let mut archive = Self {
files: HashMap::new(),
names: HashMap::new(),
directory_names: HashSet::new(),
compressed_names: HashSet::new(),
encrypted_names: zip.file_names().map(From::from).collect(),
};
let mut index = 0;
for zip_index in 0..zip.len() {
let file = match zip.by_index(zip_index) {
Err(ZipError::UnsupportedArchive(ZipError::PASSWORD_REQUIRED)) => continue,
Err(err) => return Err(err.into()),
Ok(file) => file,
};
let name = file.name().to_string();
archive.encrypted_names.remove(&name);
if file.is_dir() {
archive.directory_names.insert(name);
continue;
}
if file.compression() != CompressionMethod::Stored {
archive.compressed_names.insert(name);
continue;
}
archive.names.insert(name, index);
let file = NpyView {
data: slice_at(bytes, file.data_start(), 0..file.size())?,
central_crc32: slice_at(bytes, file.central_header_start(), 16..20)
.map(as_array_ref)?,
status: ChecksumStatus::default(),
};
archive.files.insert(index, file);
index += 1;
}
Ok(archive)
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.names.is_empty()
}
#[must_use]
pub fn len(&self) -> usize {
self.names.len()
}
pub fn names(&self) -> impl Iterator<Item = &str> {
self.names.keys().map(String::as_str)
}
pub fn directory_names(&self) -> impl Iterator<Item = &str> {
self.directory_names.iter().map(String::as_str)
}
pub fn compressed_names(&self) -> impl Iterator<Item = &str> {
self.compressed_names.iter().map(String::as_str)
}
pub fn encrypted_names(&self) -> impl Iterator<Item = &str> {
self.encrypted_names.iter().map(String::as_str)
}
pub fn by_name(&self, name: &str) -> Result<NpyView<'a>, ViewNpzError> {
self.by_index(self.names.get(name).copied().ok_or_else(|| {
if self.directory_names.contains(name) {
ViewNpzError::Directory
} else if self.compressed_names.contains(name) {
ViewNpzError::CompressedFile
} else if self.encrypted_names.contains(name) {
ViewNpzError::EncryptedFile
} else {
ZipError::FileNotFound.into()
}
})?)
}
pub fn by_index(&self, index: usize) -> Result<NpyView<'a>, ViewNpzError> {
self.files
.get(&index)
.copied()
.ok_or_else(|| ZipError::FileNotFound.into())
}
}
#[derive(Debug, Clone, Copy)]
pub struct NpyView<'a> {
data: &'a [u8],
central_crc32: &'a [u8; 4],
status: ChecksumStatus,
}
impl NpyView<'_> {
#[must_use]
pub fn status(&self) -> ChecksumStatus {
self.status
}
pub fn verify(&mut self) -> Result<u32, ViewNpzError> {
self.status = ChecksumStatus::Outdated;
let crc32 = crc32_verify(self.data, *self.central_crc32)?;
self.status = ChecksumStatus::Correct;
Ok(crc32)
}
pub fn view<A, D>(&self) -> Result<ArrayView<A, D>, ViewNpzError>
where
A: ViewElement,
D: Dimension,
{
Ok(ArrayView::view_npy(self.data)?)
}
}
/// Mutable view of an `.npz` archive held in a byte slice.
///
/// Only entries that are neither directories, nor compressed, nor encrypted can be viewed.
/// Those viewable entries are numbered consecutively, in archive order, starting at zero.
/// Each mutable entry view can be obtained only once because it is moved out of this view.
#[derive(Debug)]
pub struct NpzViewMut<'a> {
    /// Viewable entries keyed by their consecutive index; removed once handed out.
    files: HashMap<usize, NpyViewMut<'a>>,
    /// Maps the name of each viewable entry to its index in `files`.
    names: HashMap<String, usize>,
    /// Names of directory entries.
    directory_names: HashSet<String>,
    /// Names of entries not stored with `CompressionMethod::Stored`.
    compressed_names: HashSet<String>,
    /// Names of entries that could not be opened because a password is required.
    encrypted_names: HashSet<String>,
}

impl<'a> NpzViewMut<'a> {
    /// Creates a mutable view of the archive contained in `bytes`.
    ///
    /// For every viewable entry, three disjoint mutable sub-slices of `bytes` are
    /// extracted: the entry's data, the CRC-32 next to the data (in the local header or,
    /// if the data-descriptor flag is set, in the data descriptor) and the CRC-32 in the
    /// central directory header.
    ///
    /// # Errors
    ///
    /// Returns `ViewNpzError::Zip` if the zip archive cannot be opened, if an entry cannot
    /// be accessed for a reason other than a missing password, or if the byte ranges
    /// derived from the entries' metadata overflow, overlap, are ambiguous, or lie outside
    /// of `bytes`.
    pub fn new(mut bytes: &'a mut [u8]) -> Result<Self, ViewNpzError> {
        let mut zip = ZipArchive::new(Cursor::new(&bytes))?;
        let mut archive = Self {
            files: HashMap::new(),
            names: HashMap::new(),
            directory_names: HashSet::new(),
            compressed_names: HashSet::new(),
            // Every name starts out as "encrypted" and is cleared once its entry opens.
            encrypted_names: zip.file_names().map(From::from).collect(),
        };
        // Byte ranges per viewable entry, keyed by its consecutive index.
        let mut ranges = HashMap::new();
        // Start -> end of every range to carve out, ordered by start offset.
        let mut splits = BTreeMap::new();
        let mut index = 0;
        for zip_index in 0..zip.len() {
            let file = match zip.by_index(zip_index) {
                // Entries requiring a password stay listed in `encrypted_names`.
                Err(ZipError::UnsupportedArchive(ZipError::PASSWORD_REQUIRED)) => continue,
                Err(err) => return Err(err.into()),
                Ok(file) => file,
            };
            let name = file.name().to_string();
            archive.encrypted_names.remove(&name);
            if file.is_dir() {
                archive.directory_names.insert(name);
                continue;
            }
            if file.compression() != CompressionMethod::Stored {
                archive.compressed_names.insert(name);
                continue;
            }
            archive.names.insert(name, index);
            let data_range = range_at(file.data_start(), 0..file.size())?;
            // Bytes 8..10 of the central directory header hold the general purpose flags.
            let central_flag_range = range_at(file.central_header_start(), 8..10)?;
            let central_flag = u16_at(bytes, central_flag_range);
            // Bytes 16..20 of the central directory header hold the CRC-32.
            let central_crc32_range = range_at(file.central_header_start(), 16..20)?;
            // Flag bit 3 announces a data descriptor following the entry's data.
            let use_data_descriptor = central_flag & (1 << 3) != 0;
            let crc32_range = if use_data_descriptor {
                let crc32_range = range_at(data_range.end, 0..4)?;
                // The entry size comes from the archive and is untrusted: report a
                // truncated archive as an error instead of panicking in `u32_at`.
                if crc32_range.end > bytes.len() {
                    return Err(ZipError::InvalidArchive("Range exceeds EOF".into()).into());
                }
                let crc32 = u32_at(bytes, crc32_range.clone());
                // The descriptor may start with the optional signature 0x08074b50.
                if crc32 == 0x0807_4b50 {
                    let central_crc32 = u32_at(bytes, central_crc32_range.clone());
                    // If the CRC-32 itself equals the signature value, the first word
                    // could be either of them.
                    if crc32 == central_crc32 {
                        return Err(ZipError::InvalidArchive(
                            "Ambiguous CRC-32 location in data descriptor".into(),
                        )
                        .into());
                    }
                    // Skip the signature; the CRC-32 is the next word.
                    range_at(data_range.end, 4..8)?
                } else {
                    crc32_range
                }
            } else {
                // Bytes 14..18 of the local file header hold the CRC-32.
                range_at(file.header_start(), 14..18)?
            };
            splits.insert(crc32_range.start, crc32_range.end);
            splits.insert(data_range.start, data_range.end);
            splits.insert(central_crc32_range.start, central_crc32_range.end);
            ranges.insert(index, (data_range, crc32_range, central_crc32_range));
            index += 1;
        }
        // Carve the ranges out of `bytes` in ascending order. `offset` is the absolute
        // position of the first byte still owned by `bytes`.
        let mut offset = 0;
        let mut slices = HashMap::new();
        for (&start, &end) in &splits {
            let mid = start
                .checked_sub(offset)
                .ok_or(ZipError::InvalidArchive("Overlapping ranges".into()))?;
            if mid > bytes.len() {
                return Err(ZipError::InvalidArchive("Offset exceeds EOF".into()).into());
            }
            // Drop the gap in front of the range.
            let (slice, remaining_bytes) = bytes.split_at_mut(mid);
            offset += slice.len();
            let mid = end - offset;
            if mid > remaining_bytes.len() {
                return Err(ZipError::InvalidArchive("Length exceeds EOF".into()).into());
            }
            // Keep the range itself, keyed by its absolute start offset.
            let (slice, remaining_bytes) = remaining_bytes.split_at_mut(mid);
            offset += slice.len();
            slices.insert(start, slice);
            bytes = remaining_bytes;
        }
        for (&index, (data_range, crc32_range, central_crc32_range)) in &ranges {
            let ambiguous_offset = || ZipError::InvalidArchive("Ambiguous offsets".into());
            // `splits` keeps a single end per start offset, so in a malformed archive the
            // slice found at a checksum offset may belong to a range of another length.
            // Check the length first so that `as_array_mut` cannot panic.
            let file = NpyViewMut {
                data: slices
                    .remove(&data_range.start)
                    .ok_or_else(ambiguous_offset)?,
                crc32: slices
                    .remove(&crc32_range.start)
                    .filter(|slice| slice.len() == 4)
                    .map(as_array_mut)
                    .ok_or_else(ambiguous_offset)?,
                central_crc32: slices
                    .remove(&central_crc32_range.start)
                    .filter(|slice| slice.len() == 4)
                    .map(as_array_mut)
                    .ok_or_else(ambiguous_offset)?,
                status: ChecksumStatus::default(),
            };
            archive.files.insert(index, file);
        }
        Ok(archive)
    }

    /// Returns `true` if the archive has no viewable entries.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.names.is_empty()
    }

    /// Returns the number of distinct names of viewable entries.
    ///
    /// The count is not reduced when entry views are moved out.
    #[must_use]
    pub fn len(&self) -> usize {
        self.names.len()
    }

    /// Iterates over the names of the viewable entries in arbitrary order.
    pub fn names(&self) -> impl Iterator<Item = &str> {
        self.names.keys().map(String::as_str)
    }

    /// Iterates over the names of directory entries in arbitrary order.
    pub fn directory_names(&self) -> impl Iterator<Item = &str> {
        self.directory_names.iter().map(String::as_str)
    }

    /// Iterates over the names of compressed entries in arbitrary order.
    pub fn compressed_names(&self) -> impl Iterator<Item = &str> {
        self.compressed_names.iter().map(String::as_str)
    }

    /// Iterates over the names of encrypted entries in arbitrary order.
    pub fn encrypted_names(&self) -> impl Iterator<Item = &str> {
        self.encrypted_names.iter().map(String::as_str)
    }

    /// Moves the mutable view of the entry called `name` out of this view.
    ///
    /// # Errors
    ///
    /// Returns `ViewNpzError::Directory`, `ViewNpzError::CompressedFile` or
    /// `ViewNpzError::EncryptedFile` if the name belongs to an entry of that kind,
    /// `ViewNpzError::Zip` with `ZipError::FileNotFound` if the name is unknown, and
    /// `ViewNpzError::MovedNpyViewMut` if the entry's view has already been moved out.
    pub fn by_name(&mut self, name: &str) -> Result<NpyViewMut<'a>, ViewNpzError> {
        let index = match self.names.get(name) {
            Some(&index) => index,
            None if self.directory_names.contains(name) => return Err(ViewNpzError::Directory),
            None if self.compressed_names.contains(name) => {
                return Err(ViewNpzError::CompressedFile)
            }
            None if self.encrypted_names.contains(name) => {
                return Err(ViewNpzError::EncryptedFile)
            }
            None => return Err(ZipError::FileNotFound.into()),
        };
        self.by_index(index)
    }

    /// Moves the mutable view of the viewable entry at `index` out of this view.
    ///
    /// # Errors
    ///
    /// Returns `ViewNpzError::Zip` with `ZipError::FileNotFound` if `index` is not below
    /// `len()`, and `ViewNpzError::MovedNpyViewMut` if the entry's view has already been
    /// moved out.
    pub fn by_index(&mut self, index: usize) -> Result<NpyViewMut<'a>, ViewNpzError> {
        // Valid indices are `0..len()`; `index == len()` never existed, so it must not be
        // reported as "already moved".
        if index >= self.names.len() {
            Err(ZipError::FileNotFound.into())
        } else {
            self.files
                .remove(&index)
                .ok_or(ViewNpzError::MovedNpyViewMut)
        }
    }
}
#[derive(Debug)]
pub struct NpyViewMut<'a> {
data: &'a mut [u8],
crc32: &'a mut [u8; 4],
central_crc32: &'a mut [u8; 4],
status: ChecksumStatus,
}
impl NpyViewMut<'_> {
#[must_use]
pub fn status(&self) -> ChecksumStatus {
self.status
}
pub fn verify(&mut self) -> Result<u32, ViewNpzError> {
self.status = ChecksumStatus::Outdated;
let crc32 = crc32_verify(self.data, *self.central_crc32)?;
self.status = ChecksumStatus::Correct;
Ok(crc32)
}
pub fn update(&mut self) -> u32 {
self.status = ChecksumStatus::Correct;
let crc32 = crc32_update(self.data);
*self.central_crc32 = crc32.to_le_bytes();
*self.crc32 = *self.central_crc32;
crc32
}
pub fn view<A, D>(&self) -> Result<ArrayView<A, D>, ViewNpzError>
where
A: ViewElement,
D: Dimension,
{
Ok(ArrayView::<A, D>::view_npy(self.data)?)
}
pub fn view_mut<A, D>(&mut self) -> Result<ArrayViewMut<A, D>, ViewNpzError>
where
A: ViewMutElement,
D: Dimension,
{
self.status = ChecksumStatus::Outdated;
Ok(ArrayViewMut::<A, D>::view_mut_npy(self.data)?)
}
}
impl Drop for NpyViewMut<'_> {
fn drop(&mut self) {
if self.status == ChecksumStatus::Outdated {
self.update();
}
}
}
/// Checksum status of an `.npy` file view.
///
/// The default is `Unverified`, derived via `#[default]` instead of a hand-written
/// `Default` implementation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ChecksumStatus {
    /// The stored checksum has not been compared with the data yet.
    #[default]
    Unverified,
    /// The stored checksum matched the data on the last verification, or has just been
    /// recomputed from the data.
    Correct,
    /// The stored checksum did not match on the last verification, or a mutable array
    /// view has been requested since the checksum was last known to be correct.
    Outdated,
}
fn crc32_verify(bytes: &[u8], crc32: [u8; 4]) -> Result<u32, ZipError> {
let crc32 = u32::from_le_bytes(crc32);
if crc32_update(bytes) == crc32 {
Ok(crc32)
} else {
Err(ZipError::Io(io::Error::other("Invalid checksum")))
}
}
/// Computes the CRC-32 checksum of `bytes`.
#[must_use]
fn crc32_update(bytes: &[u8]) -> u32 {
    // One-shot helper equivalent to feeding a fresh hasher and finalizing it.
    crc32fast::hash(bytes)
}
/// Shifts `range` by `index` and converts the result into a `usize` range.
///
/// Returns a `ZipError::InvalidArchive` with the message "Range overflow" if any of the
/// three values does not fit into `usize` or if adding `index` to a bound overflows.
fn range_at<T>(index: T, range: Range<T>) -> Result<Range<usize>, ZipError>
where
    T: TryInto<usize> + Copy,
{
    // Collect every fallible step behind a single `Option`.
    let shifted = || -> Option<Range<usize>> {
        let base: usize = index.try_into().ok()?;
        let lower: usize = range.start.try_into().ok()?;
        let upper: usize = range.end.try_into().ok()?;
        Some(lower.checked_add(base)?..upper.checked_add(base)?)
    };
    match shifted() {
        Some(span) => Ok(span),
        None => Err(ZipError::InvalidArchive("Range overflow".into())),
    }
}
/// Returns the sub-slice of `bytes` covered by `range` shifted by `index`.
///
/// Propagates the error of `range_at` and returns a `ZipError::InvalidArchive` with the
/// message "Range exceeds EOF" if the shifted range is not contained in `bytes`.
fn slice_at<T>(bytes: &[u8], index: T, range: Range<T>) -> Result<&[u8], ZipError>
where
    T: TryInto<usize> + Copy,
{
    let span = range_at(index, range)?;
    match bytes.get(span) {
        Some(window) => Ok(window),
        None => Err(ZipError::InvalidArchive("Range exceeds EOF".into())),
    }
}
/// Reads a little-endian `u16` from `bytes` at `range`.
///
/// # Panics
///
/// Panics if `range` is not contained in `bytes` or does not span exactly two bytes.
#[must_use]
fn u16_at(bytes: &[u8], range: Range<usize>) -> u16 {
    let window = bytes.get(range).unwrap();
    let raw: [u8; 2] = window.try_into().unwrap();
    u16::from_le_bytes(raw)
}
/// Reads a little-endian `u32` from `bytes` at `range`.
///
/// # Panics
///
/// Panics if `range` is not contained in `bytes` or does not span exactly four bytes.
#[must_use]
fn u32_at(bytes: &[u8], range: Range<usize>) -> u32 {
    let window = bytes.get(range).unwrap();
    let raw: [u8; 4] = window.try_into().unwrap();
    u32::from_le_bytes(raw)
}
/// Reinterprets a four-byte slice as a reference to a four-byte array.
///
/// # Panics
///
/// Panics if `slice` is not exactly four bytes long.
#[must_use]
fn as_array_ref(slice: &[u8]) -> &[u8; 4] {
    <&[u8; 4]>::try_from(slice).unwrap()
}
/// Reinterprets a mutable four-byte slice as a mutable reference to a four-byte array.
///
/// # Panics
///
/// Panics if `slice` is not exactly four bytes long.
#[must_use]
fn as_array_mut(slice: &mut [u8]) -> &mut [u8; 4] {
    <&mut [u8; 4]>::try_from(slice).unwrap()
}