use std::fs::{File, OpenOptions};
use std::io::{BufReader, Seek};
use std::marker::PhantomData;
use std::path::Path;
use memmap2::{Mmap, MmapMut, MmapOptions};
use ferray_core::Array;
use ferray_core::array::view::ArrayView;
use ferray_core::dimension::IxDyn;
use ferray_core::dtype::Element;
use ferray_core::error::{FerrayError, FerrayResult};
use crate::format::MemmapMode;
use crate::npy::NpyElement;
use crate::npy::checked_total_elements;
use crate::npy::header::{self, NpyHeader};
/// A read-only, memory-mapped view over the element data of an `.npy` file.
///
/// Instances are created by `memmap_readonly`, which maps only the data
/// region of the file (everything after the header) and records the shape
/// parsed from that header.
pub struct MemmapArray<T: Element> {
/// Read-only mapping of the file's data region.
mmap: Mmap,
/// Extent of each dimension, taken from the file header.
shape: Vec<usize>,
/// Number of `T` elements exposed through the mapping.
len: usize,
/// Associates the untyped byte mapping with the element type `T`.
_marker: PhantomData<T>,
}
// The methods on `MemmapArray` only ever read through the mapping, so sharing
// or moving it between threads does not introduce writes from this type.
// NOTE(review): these impls apply to every `T: Element`, whereas the
// auto-derived ones would depend on `T` through `PhantomData<T>`. Presumably
// all `Element` types are plain data that is itself `Send + Sync` — confirm.
unsafe impl<T: Element> Send for MemmapArray<T> {}
unsafe impl<T: Element> Sync for MemmapArray<T> {}
impl<T: Element> MemmapArray<T> {
    /// Returns the extent of each dimension of the mapped array.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Start of the mapped region, reinterpreted as a pointer to `T`.
    fn data_ptr(&self) -> *const T {
        let base = self.mmap.as_ptr();
        base.cast::<T>()
    }

    /// Element strides (not byte strides) of a contiguous row-major layout
    /// for `shape`: the last axis has stride 1 and every earlier axis has the
    /// stride of its successor times the successor's extent.
    fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
        let rank = shape.len();
        let mut strides = vec![1usize; rank];
        // Walk from the innermost axis outwards, filling in the axis before it.
        for axis in (1..rank).rev() {
            strides[axis - 1] = strides[axis] * shape[axis];
        }
        strides
    }

    /// Borrows the mapped data as a flat slice of `len` elements.
    #[must_use]
    pub fn as_slice(&self) -> &[T] {
        let (start, count) = (self.data_ptr(), self.len);
        // SAFETY: `memmap_readonly` maps `len * size_of::<T>()` bytes and
        // rejects a misaligned base pointer before building this value, and
        // the mapping lives as long as `self`.
        // NOTE(review): this also assumes every byte pattern in the file is a
        // valid `T` — confirm for element types with invalid bit patterns.
        unsafe { std::slice::from_raw_parts(start, count) }
    }

    /// Copies the mapped elements into an owned array with the same shape.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `Array::from_vec` reports.
    pub fn to_array(&self) -> FerrayResult<Array<T, IxDyn>> {
        let dim = IxDyn::new(&self.shape);
        let owned = self.as_slice().to_vec();
        Array::from_vec(dim, owned)
    }

    /// Borrows the mapped data as an n-dimensional view without copying.
    ///
    /// The view uses the recorded shape with contiguous row-major strides.
    #[must_use]
    pub fn view(&self) -> ArrayView<'_, T, IxDyn> {
        let strides = Self::contiguous_strides(&self.shape);
        // SAFETY: the pointer covers `len` elements laid out contiguously and
        // the strides describe exactly that layout. Presumably this is what
        // `ArrayView::from_shape_ptr` requires — confirm against its docs.
        unsafe { ArrayView::from_shape_ptr(self.data_ptr(), &self.shape, &strides) }
    }
}
/// A writable, memory-mapped view over the element data of an `.npy` file.
///
/// Instances are created by `memmap_mut`, which maps only the data region of
/// the file (everything after the header) either as a shared writable mapping
/// or as a private copy, depending on the requested mode.
pub struct MemmapArrayMut<T: Element> {
/// Writable mapping of the file's data region.
mmap: MmapMut,
/// Extent of each dimension, taken from the file header.
shape: Vec<usize>,
/// Number of `T` elements exposed through the mapping.
len: usize,
/// Associates the untyped byte mapping with the element type `T`.
_marker: PhantomData<T>,
}
// Writes through this type require `&mut self` (`as_slice_mut`); a shared
// reference only permits reads and `flush`.
// NOTE(review): these impls apply to every `T: Element`, whereas the
// auto-derived ones would depend on `T` through `PhantomData<T>`. Presumably
// all `Element` types are plain data that is itself `Send + Sync` — confirm.
unsafe impl<T: Element> Send for MemmapArrayMut<T> {}
unsafe impl<T: Element> Sync for MemmapArrayMut<T> {}
impl<T: Element> MemmapArrayMut<T> {
    /// Returns the extent of each dimension of the mapped array.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Start of the mapped region as a read-only pointer to `T`.
    fn data_ptr(&self) -> *const T {
        let base = self.mmap.as_ptr();
        base.cast::<T>()
    }

    /// Start of the mapped region as a writable pointer to `T`.
    fn data_ptr_mut(&mut self) -> *mut T {
        let base = self.mmap.as_mut_ptr();
        base.cast::<T>()
    }

    /// Element strides (not byte strides) of a contiguous row-major layout
    /// for `shape`: the last axis has stride 1 and every earlier axis has the
    /// stride of its successor times the successor's extent.
    fn contiguous_strides(shape: &[usize]) -> Vec<usize> {
        let rank = shape.len();
        let mut strides = vec![1usize; rank];
        // Walk from the innermost axis outwards, filling in the axis before it.
        for axis in (1..rank).rev() {
            strides[axis - 1] = strides[axis] * shape[axis];
        }
        strides
    }

    /// Borrows the mapped data as a flat, read-only slice of `len` elements.
    #[must_use]
    pub fn as_slice(&self) -> &[T] {
        let (start, count) = (self.data_ptr(), self.len);
        // SAFETY: `memmap_mut` maps `len * size_of::<T>()` bytes and rejects a
        // misaligned base pointer before building this value, and the mapping
        // lives as long as `self`.
        // NOTE(review): this also assumes every byte pattern in the file is a
        // valid `T` — confirm for element types with invalid bit patterns.
        unsafe { std::slice::from_raw_parts(start, count) }
    }

    /// Borrows the mapped data as a flat, writable slice of `len` elements.
    pub fn as_slice_mut(&mut self) -> &mut [T] {
        // Read the length before taking the mutable pointer so the two
        // borrows of `self` do not overlap.
        let count = self.len;
        let start = self.data_ptr_mut();
        // SAFETY: same size and alignment guarantees as `as_slice`; the
        // exclusive borrow of `self` makes this the only live reference.
        unsafe { std::slice::from_raw_parts_mut(start, count) }
    }

    /// Copies the mapped elements into an owned array with the same shape.
    ///
    /// # Errors
    ///
    /// Propagates whatever error `Array::from_vec` reports.
    pub fn to_array(&self) -> FerrayResult<Array<T, IxDyn>> {
        let dim = IxDyn::new(&self.shape);
        let owned = self.as_slice().to_vec();
        Array::from_vec(dim, owned)
    }

    /// Borrows the mapped data as an n-dimensional view without copying.
    ///
    /// The view uses the recorded shape with contiguous row-major strides.
    #[must_use]
    pub fn view(&self) -> ArrayView<'_, T, IxDyn> {
        let strides = Self::contiguous_strides(&self.shape);
        // SAFETY: the pointer covers `len` elements laid out contiguously and
        // the strides describe exactly that layout. Presumably this is what
        // `ArrayView::from_shape_ptr` requires — confirm against its docs.
        unsafe { ArrayView::from_shape_ptr(self.data_ptr(), &self.shape, &strides) }
    }

    /// Flushes outstanding modifications of the mapping.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the underlying `MmapMut::flush` call fails.
    pub fn flush(&self) -> FerrayResult<()> {
        match self.mmap.flush() {
            Ok(()) => Ok(()),
            Err(e) => Err(FerrayError::io_error(format!("failed to flush mmap: {e}"))),
        }
    }
}
/// Memory-maps the data section of the `.npy` file at `path` for read-only access.
///
/// The header is parsed first to obtain the shape, the dtype, the byte order
/// and the offset at which the raw element data starts; only the data region
/// is mapped.
///
/// # Errors
///
/// Returns an error if
/// - the header cannot be read,
/// - the stored dtype differs from `T::dtype()`,
/// - the file is not in native byte order,
/// - the byte size of the array overflows `usize`,
/// - the file is shorter than the data region its header describes,
/// - the mapping itself fails, or
/// - the mapped data is not aligned for `T`.
pub fn memmap_readonly<T: Element + NpyElement, P: AsRef<Path>>(
    path: P,
) -> FerrayResult<MemmapArray<T>> {
    let path = path.as_ref();
    let (header, data_offset) = read_npy_header_with_offset(path)?;
    validate_dtype::<T>(&header)?;
    validate_native_endian(&header)?;
    let len = checked_total_elements(&header.shape)?;

    // The shape comes from the file, so the element count times the element
    // size must not be allowed to wrap around silently.
    let data_bytes = len
        .checked_mul(std::mem::size_of::<T>())
        .ok_or_else(|| FerrayError::io_error("array byte size overflows usize"))?;

    let file = File::open(path)?;

    // A mapping that extends past the end of the file can be created
    // successfully on some platforms and then faults on first access, so
    // reject truncated files up front.
    let file_len = file
        .metadata()
        .map_err(|e| FerrayError::io_error(format!("failed to query file size: {e}")))?
        .len();
    // `usize` -> `u64` is lossless on every target Rust supports today.
    let data_end = (data_offset as u64).checked_add(data_bytes as u64);
    if !matches!(data_end, Some(end) if end <= file_len) {
        return Err(FerrayError::io_error(
            "file is too short for the array described by its header",
        ));
    }

    // SAFETY: memmap2 marks mapping as unsafe because the file may be
    // modified or truncated by someone else while it is mapped. That cannot
    // be ruled out here; callers must not change the file while the returned
    // array is alive.
    let mapped = unsafe {
        MmapOptions::new()
            .offset(data_offset as u64)
            .len(data_bytes)
            .map(&file)
    };
    let mmap = mapped.map_err(|e| FerrayError::io_error(format!("mmap failed: {e}")))?;

    // The data offset depends on the header length, so the start of the data
    // is not guaranteed to be aligned for `T`; typed access requires it.
    let probe_ptr = mmap.as_ptr().cast::<T>();
    if (probe_ptr as usize) % std::mem::align_of::<T>() != 0 {
        return Err(FerrayError::io_error(
            "memory-mapped data is not properly aligned for the element type",
        ));
    }

    Ok(MemmapArray {
        mmap,
        shape: header.shape,
        len,
        _marker: PhantomData,
    })
}
/// Memory-maps the data section of the `.npy` file at `path` for writing.
///
/// With `MemmapMode::ReadWrite` the file is opened for reading and writing
/// and mapped with `map_mut`. With `MemmapMode::CopyOnWrite` the file is
/// opened read-only and mapped with `map_copy`. `MemmapMode::ReadOnly` is
/// rejected; use `memmap_readonly` for that.
///
/// # Errors
///
/// Returns an error if
/// - `mode` is `MemmapMode::ReadOnly`,
/// - the header cannot be read,
/// - the stored dtype differs from `T::dtype()`,
/// - the file is not in native byte order,
/// - the byte size of the array overflows `usize`,
/// - the file cannot be opened in the requested mode,
/// - the file is shorter than the data region its header describes,
/// - the mapping itself fails, or
/// - the mapped data is not aligned for `T`.
pub fn memmap_mut<T: Element + NpyElement, P: AsRef<Path>>(
    path: P,
    mode: MemmapMode,
) -> FerrayResult<MemmapArrayMut<T>> {
    if mode == MemmapMode::ReadOnly {
        return Err(FerrayError::invalid_value(
            "use memmap_readonly for read-only access",
        ));
    }
    let path = path.as_ref();
    let (header, data_offset) = read_npy_header_with_offset(path)?;
    validate_dtype::<T>(&header)?;
    validate_native_endian(&header)?;
    let len = checked_total_elements(&header.shape)?;

    // The shape comes from the file, so the element count times the element
    // size must not be allowed to wrap around silently.
    let data_bytes = len
        .checked_mul(std::mem::size_of::<T>())
        .ok_or_else(|| FerrayError::io_error("array byte size overflows usize"))?;

    // A mapping that extends past the end of the file can be created
    // successfully on some platforms and then faults on first access, so
    // reject truncated files before mapping.
    let ensure_file_holds_data = |file: &File| -> FerrayResult<()> {
        let file_len = file
            .metadata()
            .map_err(|e| FerrayError::io_error(format!("failed to query file size: {e}")))?
            .len();
        // `usize` -> `u64` is lossless on every target Rust supports today.
        let data_end = (data_offset as u64).checked_add(data_bytes as u64);
        if matches!(data_end, Some(end) if end <= file_len) {
            Ok(())
        } else {
            Err(FerrayError::io_error(
                "file is too short for the array described by its header",
            ))
        }
    };

    let mmap = match mode {
        MemmapMode::ReadWrite => {
            let file = OpenOptions::new().read(true).write(true).open(path)?;
            ensure_file_holds_data(&file)?;
            // SAFETY: memmap2 marks mapping as unsafe because the file may be
            // modified or truncated externally while mapped; callers must not
            // do so while the returned array is alive.
            let mapped = unsafe {
                MmapOptions::new()
                    .offset(data_offset as u64)
                    .len(data_bytes)
                    .map_mut(&file)
            };
            mapped.map_err(|e| FerrayError::io_error(format!("mmap_mut failed: {e}")))?
        }
        MemmapMode::CopyOnWrite => {
            let file = File::open(path)?;
            ensure_file_holds_data(&file)?;
            // SAFETY: same external-modification caveat as above.
            let mapped = unsafe {
                MmapOptions::new()
                    .offset(data_offset as u64)
                    .len(data_bytes)
                    .map_copy(&file)
            };
            mapped
                .map_err(|e| FerrayError::io_error(format!("mmap copy-on-write failed: {e}")))?
        }
        // Rejected by the guard at the top of the function.
        MemmapMode::ReadOnly => unreachable!(),
    };

    // The data offset depends on the header length, so the start of the data
    // is not guaranteed to be aligned for `T`; typed access requires it.
    let probe_ptr = mmap.as_ptr().cast::<T>();
    if (probe_ptr as usize) % std::mem::align_of::<T>() != 0 {
        return Err(FerrayError::io_error(
            "memory-mapped data is not properly aligned for the element type",
        ));
    }

    Ok(MemmapArrayMut {
        mmap,
        shape: header.shape,
        len,
        _marker: PhantomData,
    })
}
/// Loads the `.npy` file at `path` through a memory mapping and returns its
/// contents as an owned array.
///
/// `MemmapMode::ReadOnly` goes through `memmap_readonly`; every other mode
/// goes through `memmap_mut`. In both cases the mapped elements are copied
/// with `to_array`, so the returned array no longer refers to the file and
/// the mapping is dropped before this function returns.
///
/// # Errors
///
/// Propagates any error from creating the mapping or from building the array.
pub fn open_memmap<T: Element + NpyElement, P: AsRef<Path>>(
    path: P,
    mode: MemmapMode,
) -> FerrayResult<Array<T, IxDyn>> {
    match mode {
        MemmapMode::ReadOnly => {
            let readonly = memmap_readonly::<T, _>(path)?;
            readonly.to_array()
        }
        writable => {
            let mutable = memmap_mut::<T, _>(path, writable)?;
            mutable.to_array()
        }
    }
}
/// Parses the `.npy` header of the file at `path` and reports where the
/// element data begins.
///
/// Returns the parsed header together with the byte offset, measured from the
/// start of the file, of the first byte after the header.
///
/// # Errors
///
/// Returns an error if the file cannot be opened, if `header::read_header`
/// fails, if the stream position cannot be queried, or if that position does
/// not fit in `usize` on the current target.
fn read_npy_header_with_offset(path: &Path) -> FerrayResult<(NpyHeader, usize)> {
    let mut reader = BufReader::new(File::open(path)?);
    let hdr = header::read_header(&mut reader)?;
    // The reader's logical position after the header is the data offset.
    let position = reader
        .stream_position()
        .map_err(|e| FerrayError::io_error(format!("failed to get stream position: {e}")))?;
    // A plain `as usize` would silently truncate on targets where `usize` is
    // narrower than 64 bits; fail loudly instead. The fully qualified path
    // keeps this independent of which prelude edition is in effect.
    let data_offset = <usize as std::convert::TryFrom<u64>>::try_from(position)
        .map_err(|_| FerrayError::io_error("data offset does not fit in usize"))?;
    Ok((hdr, data_offset))
}
/// Checks that the dtype recorded in `header` equals `T::dtype()`.
///
/// # Errors
///
/// Returns an invalid-dtype error naming the expected dtype, the Rust type
/// `T`, and the dtype found in the file when they differ.
fn validate_dtype<T: Element>(header: &NpyHeader) -> FerrayResult<()> {
    if header.dtype == T::dtype() {
        return Ok(());
    }
    let message = format!(
        "expected dtype {:?} for type {}, but file has {:?}",
        T::dtype(),
        std::any::type_name::<T>(),
        header.dtype,
    );
    Err(FerrayError::invalid_dtype(message))
}
/// Checks that the file's byte order can be used without swapping.
///
/// A memory mapping exposes the file's bytes as they are, so data whose
/// endianness reports `needs_swap()` cannot be served from it.
///
/// # Errors
///
/// Returns an I/O error when `header.endianness.needs_swap()` is true.
fn validate_native_endian(header: &NpyHeader) -> FerrayResult<()> {
    let is_native = !header.endianness.needs_swap();
    if is_native {
        Ok(())
    } else {
        Err(FerrayError::io_error(
            "memory-mapped arrays require native byte order; file has non-native endianness",
        ))
    }
}
#[cfg(test)]
// Exact float comparison is intended here: the tests check that values
// written to a file come back unchanged through the mapping.
#[allow(clippy::float_cmp)]
mod tests {
    use super::*;
    use crate::npy;
    use ferray_core::dimension::{Ix1, Ix2, Ix3};

    /// Creates a fresh temporary directory and a path for `name` inside it.
    ///
    /// The directory guard is returned as well so the directory stays alive
    /// for as long as the caller holds it.
    fn temp_path(name: &str) -> (tempfile::TempDir, std::path::PathBuf) {
        let dir = tempfile::TempDir::new().expect("failed to create test TempDir");
        let path = dir.path().join(name);
        (dir, path)
    }

    #[test]
    fn memmap_readonly_f64() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([5]), data.clone()).unwrap();
        let (_dir, path) = temp_path("mm_ro_f64.npy");
        npy::save(&path, &arr).unwrap();
        let mapped = memmap_readonly::<f64, _>(&path).unwrap();
        assert_eq!(mapped.shape(), &[5]);
        assert_eq!(mapped.as_slice(), &data[..]);
    }

    #[test]
    fn memmap_to_array() {
        let data = vec![10i32, 20, 30];
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();
        let (_dir, path) = temp_path("mm_to_arr.npy");
        npy::save(&path, &arr).unwrap();
        let mapped = memmap_readonly::<i32, _>(&path).unwrap();
        let owned = mapped.to_array().unwrap();
        assert_eq!(owned.shape(), &[3]);
        assert_eq!(owned.as_slice().unwrap(), &data[..]);
    }

    #[test]
    fn memmap_readwrite_persist() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();
        let (_dir, path) = temp_path("mm_rw.npy");
        npy::save(&path, &arr).unwrap();
        {
            // Scope the mapping so it is dropped before the file is reloaded.
            let mut mapped = memmap_mut::<f64, _>(&path, MemmapMode::ReadWrite).unwrap();
            mapped.as_slice_mut()[0] = 999.0;
            mapped.flush().unwrap();
        }
        let loaded: Array<f64, Ix1> = npy::load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap()[0], 999.0);
        assert_eq!(loaded.as_slice().unwrap()[1], 2.0);
        assert_eq!(loaded.as_slice().unwrap()[2], 3.0);
    }

    #[test]
    fn memmap_copy_on_write() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();
        let (_dir, path) = temp_path("mm_cow.npy");
        npy::save(&path, &arr).unwrap();
        {
            let mut mapped = memmap_mut::<f64, _>(&path, MemmapMode::CopyOnWrite).unwrap();
            mapped.as_slice_mut()[0] = 999.0;
            assert_eq!(mapped.as_slice()[0], 999.0);
        }
        // The write above must not have reached the file.
        let loaded: Array<f64, Ix1> = npy::load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap()[0], 1.0);
    }

    #[test]
    fn memmap_wrong_dtype_error() {
        let data = vec![1.0_f64, 2.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([2]), data).unwrap();
        let (_dir, path) = temp_path("mm_wrong_dt.npy");
        npy::save(&path, &arr).unwrap();
        let result = memmap_readonly::<f32, _>(&path);
        assert!(result.is_err());
    }

    #[test]
    fn memmap_mut_rejects_readonly_mode() {
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![1.0_f64, 2.0]).unwrap();
        let (_dir, path) = temp_path("mm_mut_ro.npy");
        npy::save(&path, &arr).unwrap();
        let result = memmap_mut::<f64, _>(&path, MemmapMode::ReadOnly);
        assert!(result.is_err());
    }

    #[test]
    fn memmap_missing_file_error() {
        // The directory exists but the file inside it was never created.
        let (_dir, path) = temp_path("mm_missing.npy");
        assert!(memmap_readonly::<f64, _>(&path).is_err());
        assert!(memmap_mut::<f64, _>(&path, MemmapMode::ReadWrite).is_err());
    }

    #[test]
    fn open_memmap_readonly() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();
        let (_dir, path) = temp_path("mm_open_ro.npy");
        npy::save(&path, &arr).unwrap();
        let loaded = open_memmap::<f64, _>(&path, MemmapMode::ReadOnly).unwrap();
        assert_eq!(loaded.shape(), &[3]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
    }

    #[test]
    fn open_memmap_copy_on_write() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();
        let (_dir, path) = temp_path("mm_open_cow.npy");
        npy::save(&path, &arr).unwrap();
        let loaded = open_memmap::<f64, _>(&path, MemmapMode::CopyOnWrite).unwrap();
        assert_eq!(loaded.shape(), &[3]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
    }

    #[test]
    fn memmap_view_borrows_underlying_data() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), data.clone()).unwrap();
        let (_dir, path) = temp_path("mm_view.npy");
        npy::save(&path, &arr).unwrap();
        let mapped = memmap_readonly::<f64, _>(&path).unwrap();
        let view = mapped.view();
        assert_eq!(view.shape(), &[2, 3]);
        let collected: Vec<f64> = view.iter().copied().collect();
        assert_eq!(collected, data);
    }

    #[test]
    fn memmap_readonly_2d_shape_and_data() {
        let data: Vec<f64> = (0..12).map(|i| i as f64 + 0.5).collect();
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([3, 4]), data.clone()).unwrap();
        let (_dir, path) = temp_path("mm_ro_2d.npy");
        npy::save(&path, &arr).unwrap();
        let mapped = memmap_readonly::<f64, _>(&path).unwrap();
        assert_eq!(mapped.shape(), &[3, 4]);
        assert_eq!(mapped.as_slice(), &data[..]);
    }

    #[test]
    fn memmap_readonly_3d_shape_and_data() {
        let data: Vec<i32> = (0..24).collect();
        let arr = Array::<i32, Ix3>::from_vec(Ix3::new([2, 3, 4]), data.clone()).unwrap();
        let (_dir, path) = temp_path("mm_ro_3d.npy");
        npy::save(&path, &arr).unwrap();
        let mapped = memmap_readonly::<i32, _>(&path).unwrap();
        assert_eq!(mapped.shape(), &[2, 3, 4]);
        assert_eq!(mapped.as_slice(), &data[..]);
    }

    #[test]
    fn memmap_readwrite_2d_persists() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), data).unwrap();
        let (_dir, path) = temp_path("mm_rw_2d.npy");
        npy::save(&path, &arr).unwrap();
        {
            // Scope the mapping so it is dropped before the file is reloaded.
            let mut mapped = memmap_mut::<f64, _>(&path, MemmapMode::ReadWrite).unwrap();
            assert_eq!(mapped.shape(), &[2, 3]);
            mapped.as_slice_mut()[5] = -42.0;
            mapped.flush().unwrap();
        }
        let loaded: Array<f64, Ix2> = npy::load(&path).unwrap();
        assert_eq!(loaded.shape(), &[2, 3]);
        // All six elements in iteration order, not just one row.
        let flat: Vec<f64> = loaded.iter().copied().collect();
        assert_eq!(flat[5], -42.0);
        assert_eq!(flat[..5], [1.0, 2.0, 3.0, 4.0, 5.0]);
    }

    #[test]
    fn memmap_view_3d_strides_match_row_major() {
        let data: Vec<f64> = (0..24).map(|i| i as f64).collect();
        let arr = Array::<f64, Ix3>::from_vec(Ix3::new([2, 3, 4]), data.clone()).unwrap();
        let (_dir, path) = temp_path("mm_view_3d.npy");
        npy::save(&path, &arr).unwrap();
        let mapped = memmap_readonly::<f64, _>(&path).unwrap();
        let view = mapped.view();
        assert_eq!(view.shape(), &[2, 3, 4]);
        let collected: Vec<f64> = view.iter().copied().collect();
        assert_eq!(collected, data);
    }

    #[test]
    fn memmap_open_memmap_2d_readonly() {
        let data: Vec<f32> = (0..15).map(|i| i as f32 * 0.1).collect();
        let arr = Array::<f32, Ix2>::from_vec(Ix2::new([3, 5]), data.clone()).unwrap();
        let (_dir, path) = temp_path("mm_open_ro_2d.npy");
        npy::save(&path, &arr).unwrap();
        let loaded = open_memmap::<f32, _>(&path, MemmapMode::ReadOnly).unwrap();
        assert_eq!(loaded.shape(), &[3, 5]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
    }

    #[test]
    fn memmap_copy_on_write_2d_isolates_changes() {
        let data = vec![10.0_f64, 20.0, 30.0, 40.0];
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 2]), data.clone()).unwrap();
        let (_dir, path) = temp_path("mm_cow_2d.npy");
        npy::save(&path, &arr).unwrap();
        {
            let mut mapped = memmap_mut::<f64, _>(&path, MemmapMode::CopyOnWrite).unwrap();
            assert_eq!(mapped.shape(), &[2, 2]);
            mapped.as_slice_mut()[3] = 999.0;
            assert_eq!(mapped.as_slice()[3], 999.0);
        }
        // The write above must not have reached the file.
        let loaded: Array<f64, Ix2> = npy::load(&path).unwrap();
        let flat: Vec<f64> = loaded.iter().copied().collect();
        assert_eq!(flat, data);
    }
}