use ndarray::prelude::*;
/// Round-trips in-place edits through `NpzViewMut` and checks that the
/// per-entry CRC-32 values stay consistent.
///
/// Part 1 writes three zero-filled `f64` arrays with `NpzWriter`. It then
/// mutates `x.npy` and `z.npy` in place through `NpzViewMut` and re-reads
/// every entry with `NpzView`. Part 2 does the same for the `b8.npy` entry
/// of the `tests/examples_data_descriptor.npz` fixture. For that fixture,
/// the CRC-32 is read from the ZIP data descriptor rather than from the
/// local file header.
#[test]
fn npz_view_mut() {
use aligned_vec::AVec;
use ndarray_npz::{NpzView, NpzViewMut, NpzWriter};
use std::{fs::read, io::Cursor};
// ZIP local file header signature ("PK\x03\x04").
let header = [0x50, 0x4b, 0x03, 0x04];
// ZIP data descriptor signature ("PK\x07\x08").
let data_descriptor = [0x50, 0x4b, 0x07, 0x08];
// Reads the little-endian CRC-32 field of the local file header that starts
// at `offset`. The field occupies bytes 14..18 of the header.
let header_crc32 = |buffer: &[u8], offset: usize| {
u32::from_le_bytes((&buffer[offset + 14..offset + 18]).try_into().unwrap())
};
// Reads the little-endian CRC-32 field of the data descriptor that starts at
// `offset`. The field occupies bytes 4..8, directly after the signature.
let data_descriptor_crc32 = |buffer: &[u8], offset: usize| {
u32::from_le_bytes((&buffer[offset + 4..offset + 8]).try_into().unwrap())
};
// Part 1: serialize three zero-filled arrays (lengths 5, 7 and 9) into
// `buffer`. The writer is scoped so that its borrow of `buffer` ends here.
let mut buffer = Vec::<u8>::new();
{
let mut npz = NpzWriter::new(Cursor::new(&mut buffer));
npz.add_array("x.npy", &Array1::<f64>::zeros(5)).unwrap();
npz.add_array("y.npy", &Array1::<f64>::zeros(7)).unwrap();
npz.add_array("z.npy", &Array1::<f64>::zeros(9)).unwrap();
}
// Copy into a 64-byte-aligned buffer. Presumably this lets the `f64`
// payloads be viewed in place without an alignment error (TODO confirm).
let mut buffer = AVec::<u8>::from_slice(64, &buffer);
// Locate the three local file headers. The expected offsets depend on the
// exact bytes `NpzWriter` emits for these arrays.
let offsets = find_subsequence(&buffer, &header);
assert_eq!(&offsets, &[0, 232, 504]);
assert_eq!(&buffer[offsets[0]..offsets[0] + 4], &header);
assert_eq!(&buffer[offsets[1]..offsets[1] + 4], &header);
assert_eq!(&buffer[offsets[2]..offsets[2] + 4], &header);
// CRC-32 values stored in the local headers before any mutation.
let x_crc32 = header_crc32(&buffer, offsets[0]);
let y_crc32 = header_crc32(&buffer, offsets[1]);
let z_crc32 = header_crc32(&buffer, offsets[2]);
// Mutate entries in place and collect the CRC-32 each entry should carry
// afterwards. The scope releases the mutable borrow of `buffer` before the
// local headers are re-read below.
let (x_central_crc32, y_central_crc32, z_central_crc32) = {
let mut npz = NpzViewMut::new(&mut buffer).unwrap();
let mut x_npy_view_mut = npz.by_name("x.npy").unwrap();
// Before mutation, `verify()` succeeds and returns the same CRC-32 that the
// local header holds.
let x_central_crc32 = x_npy_view_mut.verify().unwrap();
assert_eq!(x_crc32, x_central_crc32);
let mut x_array_view_mut = x_npy_view_mut.view_mut::<f64, Ix1>().unwrap();
x_array_view_mut[0] = 1.0;
x_array_view_mut[3] = 8.0;
x_array_view_mut[4] = 7.0;
// The stored CRC-32 no longer matches the modified data, so `verify()` must
// fail until `update()` recomputes it. `update()` returns the new value.
x_npy_view_mut.verify().unwrap_err();
let x_central_crc32 = x_npy_view_mut.update();
// `y.npy` is left untouched and must still verify against its header.
let mut y_npy_view_mut = npz.by_name("y.npy").unwrap();
let y_central_crc32 = y_npy_view_mut.verify().unwrap();
assert_eq!(y_crc32, y_central_crc32);
// `z.npy` goes through the same mutate / failed verify / update sequence as
// `x.npy`.
let mut z_npy_view_mut = npz.by_name("z.npy").unwrap();
let z_central_crc32 = z_npy_view_mut.verify().unwrap();
assert_eq!(z_crc32, z_central_crc32);
let mut z_array_view_mut = z_npy_view_mut.view_mut::<f64, Ix1>().unwrap();
z_array_view_mut[0] = 3.0;
z_array_view_mut[2] = 3.5;
z_array_view_mut[6] = 9.0;
z_npy_view_mut.verify().unwrap_err();
let z_central_crc32 = z_npy_view_mut.update();
(x_central_crc32, y_central_crc32, z_central_crc32)
};
// The local headers must now hold the CRC-32 values returned above: the
// updated values for `x.npy` and `z.npy`, and the unchanged one for `y.npy`.
let x_crc32 = header_crc32(&buffer, offsets[0]);
let y_crc32 = header_crc32(&buffer, offsets[1]);
let z_crc32 = header_crc32(&buffer, offsets[2]);
assert_eq!(x_crc32, x_central_crc32);
assert_eq!(y_crc32, y_central_crc32);
assert_eq!(z_crc32, z_central_crc32);
// Re-open read-only. Every entry must verify and hold the values written
// through the mutable views.
{
let npz = NpzView::new(&buffer).unwrap();
let mut x_npy_view = npz.by_name("x.npy").unwrap();
x_npy_view.verify().unwrap();
let x_array_view = x_npy_view.view::<f64, Ix1>().unwrap();
assert_eq!(x_array_view, ArrayView1::from(&[1.0, 0.0, 0.0, 8.0, 7.0]));
let mut y_npy_view = npz.by_name("y.npy").unwrap();
y_npy_view.verify().unwrap();
let y_array_view = y_npy_view.view::<f64, Ix1>().unwrap();
assert_eq!(y_array_view, ArrayView1::from(&[0.0; 7]));
let mut z_npy_view = npz.by_name("z.npy").unwrap();
z_npy_view.verify().unwrap();
let z_array_view = z_npy_view.view::<f64, Ix1>().unwrap();
assert_eq!(
z_array_view,
ArrayView1::from(&[3.0, 0.0, 3.5, 0.0, 0.0, 0.0, 9.0, 0.0, 0.0])
);
}
// Part 2: an archive whose entries are followed by ZIP data descriptors.
// The path is relative to the working directory (the package root under
// `cargo test`). The expected offsets are specific to this fixture file.
{
let mut buffer = read("tests/examples_data_descriptor.npz").unwrap();
let offsets = find_subsequence(&buffer, &data_descriptor);
assert_eq!(&offsets, &[194, 404, 614]);
assert_eq!(&buffer[offsets[0]..offsets[0] + 4], &data_descriptor);
assert_eq!(&buffer[offsets[1]..offsets[1] + 4], &data_descriptor);
assert_eq!(&buffer[offsets[2]..offsets[2] + 4], &data_descriptor);
// NOTE(review): this assumes the first data descriptor (`offsets[0]`)
// belongs to `b8.npy`. Confirm against the fixture's entry order.
let crc32 = data_descriptor_crc32(&buffer, offsets[0]);
let central_crc32 = {
let mut npz = NpzViewMut::new(&mut buffer).unwrap();
let mut x_npy_view_mut = npz.by_name("b8.npy").unwrap();
let central_crc32 = x_npy_view_mut.verify().unwrap();
assert_eq!(crc32, central_crc32);
let mut x_array_view_mut = x_npy_view_mut.view_mut::<bool, Ix1>().unwrap();
// Overwrite the first two elements. The failing `verify()` below asserts
// that this actually changed the stored data.
x_array_view_mut[0] = false;
x_array_view_mut[1] = true;
x_npy_view_mut.verify().unwrap_err();
x_npy_view_mut.update()
};
// The data descriptor must now hold the CRC-32 returned by `update()`.
let crc32 = data_descriptor_crc32(&buffer, offsets[0]);
assert_eq!(crc32, central_crc32);
}
}
/// Returns the start indices of all non-overlapping occurrences of `needle`
/// in `haystack`, in ascending order.
///
/// After a match at index `i`, scanning resumes at `i + needle.len()`, so
/// overlapping occurrences are not reported. For example, searching for
/// `[1, 1]` in `[1, 1, 1]` yields only `[0]`.
///
/// An empty `needle` yields no matches. Without that guard, the call would
/// reach `slice::windows(0)`, which panics.
fn find_subsequence<T>(haystack: &[T], needle: &[T]) -> Vec<usize>
where
    for<'a> &'a [T]: PartialEq,
{
    let mut positions = Vec::new();
    if needle.is_empty() {
        return positions;
    }
    let mut start = 0;
    // `start` never exceeds `haystack.len()`: a match at `start + offset`
    // ends at `start + offset + needle.len() <= haystack.len()`, so the
    // slice `haystack[start..]` below is always in bounds.
    while let Some(offset) = haystack[start..]
        .windows(needle.len())
        .position(|window| window == needle)
    {
        let position = start + offset;
        positions.push(position);
        start = position + needle.len();
    }
    positions
}