use std::fs::{File, OpenOptions};
use std::io::{BufReader, Seek};
use std::marker::PhantomData;
use std::path::Path;
use memmap2::{Mmap, MmapMut, MmapOptions};
use ferray_core::Array;
use ferray_core::array::view::ArrayView;
use ferray_core::dimension::IxDyn;
use ferray_core::dtype::Element;
use ferray_core::error::{FerrayError, FerrayResult};
use crate::format::MemmapMode;
use crate::npy::NpyElement;
use crate::npy::checked_total_elements;
use crate::npy::header::{self, NpyHeader};
/// Read-only, zero-copy access to the element data of a memory-mapped `.npy` file.
///
/// Built by `memmap_readonly`. The value owns the mapping, so the slices and views it
/// hands out stay valid for as long as it is alive.
pub struct MemmapArray<T: Element> {
/// The mapping itself. It is never read directly; holding it keeps `data_ptr` valid.
_mmap: Mmap,
/// First element of the mapped data region; alignment for `T` is checked at construction.
data_ptr: *const T,
/// Shape recorded in the file header.
shape: Vec<usize>,
/// Number of `T` elements in the mapped region (as computed by `checked_total_elements`).
len: usize,
_marker: PhantomData<T>,
}
// SAFETY: `MemmapArray` owns its mapping and exposes the data only through shared
// borrows tied to `&self` (`as_slice`, `view`). The raw `data_ptr` is the only thing
// suppressing the auto traits.
// NOTE(review): this is sound only if every `T: Element` is itself `Send + Sync`
// (plain numeric data). Confirm `Element`'s supertraits.
unsafe impl<T: Element> Send for MemmapArray<T> {}
unsafe impl<T: Element> Sync for MemmapArray<T> {}
impl<T: Element> MemmapArray<T> {
    /// Returns the array shape as recorded in the `.npy` header.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Borrows every mapped element as one flat slice, in the order stored in the file.
    #[must_use]
    pub const fn as_slice(&self) -> &[T] {
        // SAFETY: `data_ptr` addresses `len` elements of the mapping owned by `self`.
        // Its alignment for `T` was verified when the mapping was created.
        unsafe { std::slice::from_raw_parts(self.data_ptr, self.len) }
    }

    /// Copies the mapped data into a freshly allocated, owned dynamic-rank array.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Array::from_vec`.
    pub fn to_array(&self) -> FerrayResult<Array<T, IxDyn>> {
        Array::from_vec(IxDyn::new(&self.shape), self.as_slice().to_vec())
    }

    /// Borrows the mapped data as a zero-copy, C-contiguous (row-major) view.
    #[must_use]
    pub fn view(&self) -> ArrayView<'_, T, IxDyn> {
        // Strides are counted in elements. The last axis is contiguous, and each
        // earlier axis steps over the product of all later extents.
        let ndim = self.shape.len();
        let mut strides = vec![1usize; ndim];
        for axis in (1..ndim).rev() {
            strides[axis - 1] = strides[axis] * self.shape[axis];
        }
        // SAFETY: the pointer, shape and row-major strides describe exactly the `len`
        // elements owned by `self`. The returned view borrows `self`, so it cannot
        // outlive the mapping.
        unsafe { ArrayView::from_shape_ptr(self.data_ptr, &self.shape, &strides) }
    }
}
/// Mutable, zero-copy access to the element data of a memory-mapped `.npy` file.
///
/// Built by `memmap_mut`. Whether writes reach the file depends on the `MemmapMode`
/// used to create it (read-write vs copy-on-write).
pub struct MemmapArrayMut<T: Element> {
/// Owned writable mapping; `data_ptr` points into it and `flush` syncs it.
mmap: MmapMut,
/// First element of the mapped data region; alignment for `T` is checked at construction.
data_ptr: *mut T,
/// Shape recorded in the file header.
shape: Vec<usize>,
/// Number of `T` elements in the mapped region (as computed by `checked_total_elements`).
len: usize,
_marker: PhantomData<T>,
}
// SAFETY: `MemmapArrayMut` owns its mapping. Mutable access is only handed out via
// `as_slice_mut(&mut self)`, and shared access via `&self`, so the borrow checker
// enforces exclusivity across threads.
// NOTE(review): this is sound only if every `T: Element` is itself `Send + Sync`.
// Confirm `Element`'s supertraits.
unsafe impl<T: Element> Send for MemmapArrayMut<T> {}
unsafe impl<T: Element> Sync for MemmapArrayMut<T> {}
impl<T: Element> MemmapArrayMut<T> {
    /// Returns the array shape as recorded in the `.npy` header.
    #[must_use]
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Borrows every mapped element as one flat slice, in the order stored in the file.
    #[must_use]
    pub const fn as_slice(&self) -> &[T] {
        // SAFETY: `data_ptr` addresses `len` aligned elements of the mapping owned by `self`.
        unsafe { std::slice::from_raw_parts(self.data_ptr, self.len) }
    }

    /// Mutably borrows every mapped element as one flat slice.
    ///
    /// Writes go straight into the mapping. Whether they reach the file depends on the
    /// mode the mapping was opened with.
    pub const fn as_slice_mut(&mut self) -> &mut [T] {
        // SAFETY: as in `as_slice`. `&mut self` guarantees this is the only live borrow.
        unsafe { std::slice::from_raw_parts_mut(self.data_ptr, self.len) }
    }

    /// Copies the current mapped data into a freshly allocated, owned dynamic-rank array.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Array::from_vec`.
    pub fn to_array(&self) -> FerrayResult<Array<T, IxDyn>> {
        Array::from_vec(IxDyn::new(&self.shape), self.as_slice().to_vec())
    }

    /// Borrows the mapped data as a zero-copy, C-contiguous (row-major) view.
    #[must_use]
    pub fn view(&self) -> ArrayView<'_, T, IxDyn> {
        // Strides are counted in elements. The last axis is contiguous, and each
        // earlier axis steps over the product of all later extents.
        let ndim = self.shape.len();
        let mut strides = vec![1usize; ndim];
        for axis in (1..ndim).rev() {
            strides[axis - 1] = strides[axis] * self.shape[axis];
        }
        // SAFETY: pointer, shape and strides cover exactly the `len` owned elements.
        // The view borrows `self`, so no `&mut` access can overlap it.
        unsafe { ArrayView::from_shape_ptr(self.data_ptr.cast_const(), &self.shape, &strides) }
    }

    /// Asks the OS to write outstanding modifications of the mapping to storage.
    ///
    /// NOTE(review): for copy-on-write mappings, modifications are private to the
    /// mapping (see the copy-on-write test), so flushing is not expected to change
    /// the file. Confirm.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the underlying `flush` fails.
    pub fn flush(&self) -> FerrayResult<()> {
        self.mmap
            .flush()
            .map_err(|err| FerrayError::io_error(format!("failed to flush mmap: {err}")))
    }
}
/// Memory-maps the element data of an `.npy` file read-only, without copying it.
///
/// The header is parsed first to learn the dtype, the shape and the byte offset where
/// element data begins. Only that data region is mapped.
///
/// NOTE(review): nothing here rejects Fortran-ordered files, but `MemmapArray::view`
/// assumes C (row-major) order. Confirm whether `header::read_header` already rejects
/// `fortran_order: True`.
///
/// # Errors
///
/// Returns an error if:
/// - the file cannot be opened or its header cannot be parsed;
/// - the stored dtype is not `T`'s, or the data is not in native byte order;
/// - the data size in bytes overflows `usize`;
/// - the file is shorter than its header claims;
/// - mapping fails, or the mapped data is not aligned for `T`.
pub fn memmap_readonly<T: Element + NpyElement, P: AsRef<Path>>(
    path: P,
) -> FerrayResult<MemmapArray<T>> {
    /// Fails unless `file` holds at least `data_bytes` bytes starting at `data_offset`.
    fn ensure_file_covers_data(
        file: &File,
        data_offset: usize,
        data_bytes: usize,
    ) -> FerrayResult<()> {
        let file_len = file.metadata()?.len();
        // usize -> u64 is lossless on supported targets. Saturating means a bogus header
        // can only make this check fail; it can never wrap around and pass.
        let required = (data_offset as u64).saturating_add(data_bytes as u64);
        if file_len < required {
            return Err(FerrayError::io_error(format!(
                "npy file is truncated: header requires {required} bytes but file has {file_len}"
            )));
        }
        Ok(())
    }

    let path = path.as_ref();
    let (header, data_offset) = read_npy_header_with_offset(path)?;
    validate_dtype::<T>(&header)?;
    validate_native_endian(&header)?;
    let len = checked_total_elements(&header.shape)?;
    // An unchecked multiply would wrap in release builds, map too few bytes, and let
    // `as_slice` read past the end of the mapping.
    let data_bytes = len
        .checked_mul(std::mem::size_of::<T>())
        .ok_or_else(|| FerrayError::io_error("npy data size in bytes overflows usize"))?;

    let file = File::open(path)?;
    // Touching mapped pages that lie beyond end-of-file raises SIGBUS. A truncated file
    // must therefore be rejected here, before mapping, not discovered on first read.
    ensure_file_covers_data(&file, data_offset, data_bytes)?;

    // SAFETY: memmap2 cannot stop other processes from modifying or truncating the file
    // while it is mapped. Like every file-backed map, this relies on the file staying
    // unchanged for the lifetime of the returned value.
    let mmap = unsafe {
        MmapOptions::new()
            .offset(data_offset as u64)
            .len(data_bytes)
            .map(&file)
    }
    .map_err(|e| FerrayError::io_error(format!("mmap failed: {e}")))?;

    let data_ptr = mmap.as_ptr().cast::<T>();
    // `data_offset` is set by the header length and need not be a multiple of `T`'s
    // alignment. Building a slice from a misaligned pointer would be UB, so reject it.
    if (data_ptr as usize) % std::mem::align_of::<T>() != 0 {
        return Err(FerrayError::io_error(
            "memory-mapped data is not properly aligned for the element type",
        ));
    }
    Ok(MemmapArray {
        _mmap: mmap,
        data_ptr,
        shape: header.shape,
        len,
        _marker: PhantomData,
    })
}
/// Memory-maps the element data of an `.npy` file for in-place mutation.
///
/// Modes:
/// - `MemmapMode::ReadWrite` opens the file for reading and writing and maps it shared,
///   so writes (made durable with `MemmapArrayMut::flush`) reach the file.
/// - `MemmapMode::CopyOnWrite` opens the file read-only and creates a private
///   copy-on-write mapping. Writes are visible through the returned value but are not
///   written to the file.
///
/// # Errors
///
/// Returns `invalid_value` for `MemmapMode::ReadOnly`; use `memmap_readonly` instead.
/// Otherwise returns an error if:
/// - the file cannot be opened or its header cannot be parsed;
/// - the stored dtype is not `T`'s, or the data is not in native byte order;
/// - the data size in bytes overflows `usize`;
/// - the file is shorter than its header claims;
/// - mapping fails, or the mapped data is not aligned for `T`.
pub fn memmap_mut<T: Element + NpyElement, P: AsRef<Path>>(
    path: P,
    mode: MemmapMode,
) -> FerrayResult<MemmapArrayMut<T>> {
    /// Fails unless `file` holds at least `data_bytes` bytes starting at `data_offset`.
    fn ensure_file_covers_data(
        file: &File,
        data_offset: usize,
        data_bytes: usize,
    ) -> FerrayResult<()> {
        let file_len = file.metadata()?.len();
        // usize -> u64 is lossless on supported targets. Saturating means a bogus header
        // can only make this check fail; it can never wrap around and pass.
        let required = (data_offset as u64).saturating_add(data_bytes as u64);
        if file_len < required {
            return Err(FerrayError::io_error(format!(
                "npy file is truncated: header requires {required} bytes but file has {file_len}"
            )));
        }
        Ok(())
    }

    if mode == MemmapMode::ReadOnly {
        return Err(FerrayError::invalid_value(
            "use memmap_readonly for read-only access",
        ));
    }
    let path = path.as_ref();
    let (header, data_offset) = read_npy_header_with_offset(path)?;
    validate_dtype::<T>(&header)?;
    validate_native_endian(&header)?;
    let len = checked_total_elements(&header.shape)?;
    // An unchecked multiply would wrap in release builds, map too few bytes, and let
    // the slice accessors run past the end of the mapping.
    let data_bytes = len
        .checked_mul(std::mem::size_of::<T>())
        .ok_or_else(|| FerrayError::io_error("npy data size in bytes overflows usize"))?;

    // Each arm rejects truncated files before mapping: touching mapped pages beyond
    // end-of-file raises SIGBUS.
    // SAFETY (both maps): this relies on no other process modifying or truncating the
    // file while it is mapped, which memmap2 cannot enforce.
    let mut mmap = match mode {
        MemmapMode::ReadWrite => {
            let file = OpenOptions::new()
                .read(true)
                .write(true)
                .open(path)?;
            ensure_file_covers_data(&file, data_offset, data_bytes)?;
            unsafe {
                MmapOptions::new()
                    .offset(data_offset as u64)
                    .len(data_bytes)
                    .map_mut(&file)
            }
            .map_err(|e| FerrayError::io_error(format!("mmap_mut failed: {e}")))?
        }
        MemmapMode::CopyOnWrite => {
            let file = File::open(path)?;
            ensure_file_covers_data(&file, data_offset, data_bytes)?;
            unsafe {
                MmapOptions::new()
                    .offset(data_offset as u64)
                    .len(data_bytes)
                    .map_copy(&file)
            }
            .map_err(|e| FerrayError::io_error(format!("mmap copy-on-write failed: {e}")))?
        }
        MemmapMode::ReadOnly => unreachable!("read-only mode is rejected at the top of memmap_mut"),
    };

    // Take the pointer from the mutable accessor. `as_slice_mut` writes through it, and
    // a pointer obtained from a shared borrow and then cast to `*mut` has no write
    // permission under Rust's aliasing model.
    let data_ptr = mmap.as_mut_ptr().cast::<T>();
    // `data_offset` is set by the header length and need not be a multiple of `T`'s
    // alignment. Building a slice from a misaligned pointer would be UB, so reject it.
    if (data_ptr as usize) % std::mem::align_of::<T>() != 0 {
        return Err(FerrayError::io_error(
            "memory-mapped data is not properly aligned for the element type",
        ));
    }
    Ok(MemmapArrayMut {
        mmap,
        data_ptr,
        shape: header.shape,
        len,
        _marker: PhantomData,
    })
}
/// Maps an `.npy` file in the given mode and returns an owned copy of its data.
///
/// The mapping only lives for the duration of this call. The data is copied out with
/// `to_array`, so changes to the returned array are never written back to the file, in
/// any mode. `ReadWrite` still requires the file to be openable for writing.
///
/// # Errors
///
/// Propagates errors from `memmap_readonly` / `memmap_mut` and from `to_array`.
pub fn open_memmap<T: Element + NpyElement, P: AsRef<Path>>(
    path: P,
    mode: MemmapMode,
) -> FerrayResult<Array<T, IxDyn>> {
    match mode {
        MemmapMode::ReadOnly => memmap_readonly::<T, _>(path)?.to_array(),
        writable => memmap_mut::<T, _>(path, writable)?.to_array(),
    }
}
/// Parses the `.npy` header at `path` and returns it together with the byte offset at
/// which the element data begins.
///
/// # Errors
///
/// Returns an error if:
/// - the file cannot be opened;
/// - the header cannot be parsed;
/// - the stream position cannot be queried;
/// - the offset does not fit in `usize` (possible on 32-bit targets).
fn read_npy_header_with_offset(path: &Path) -> FerrayResult<(NpyHeader, usize)> {
    let file = File::open(path)?;
    let mut reader = BufReader::new(file);
    let hdr = header::read_header(&mut reader)?;
    // `BufReader::stream_position` discounts bytes that are buffered but not yet
    // consumed. The result is therefore where the header parser stopped, not how far
    // the buffer read ahead.
    let position = reader
        .stream_position()
        .map_err(|e| FerrayError::io_error(format!("failed to get stream position: {e}")))?;
    // A plain `as usize` would silently truncate on 32-bit targets and produce a bogus
    // mapping offset.
    let data_offset = usize::try_from(position).map_err(|_| {
        FerrayError::io_error(format!("npy data offset {position} does not fit in usize"))
    })?;
    Ok((hdr, data_offset))
}
/// Checks that the dtype stored in `header` is exactly `T`'s dtype.
///
/// # Errors
///
/// Returns `invalid_dtype` naming the expected dtype, the Rust type, and the file's dtype.
fn validate_dtype<T: Element>(header: &NpyHeader) -> FerrayResult<()> {
    let expected = T::dtype();
    if header.dtype == expected {
        return Ok(());
    }
    Err(FerrayError::invalid_dtype(format!(
        "expected dtype {:?} for type {}, but file has {:?}",
        expected,
        std::any::type_name::<T>(),
        header.dtype,
    )))
}
/// Checks that the file's byte order matches the host's.
///
/// Mapped bytes are reinterpreted in place, so data that would need byte swapping
/// cannot be exposed through a mapping.
///
/// # Errors
///
/// Returns an I/O error when the header's endianness requires swapping.
fn validate_native_endian(header: &NpyHeader) -> FerrayResult<()> {
    if !header.endianness.needs_swap() {
        return Ok(());
    }
    Err(FerrayError::io_error(
        "memory-mapped arrays require native byte order; file has non-native endianness",
    ))
}
#[cfg(test)]
// Exact float equality is intended: values are saved and read back bit-for-bit,
// with no arithmetic in between.
#[allow(clippy::float_cmp)] mod tests {
use super::*;
use crate::npy;
use ferray_core::dimension::Ix1;
/// Per-process scratch directory under the system temp dir.
/// Creation errors are ignored here; presumably a failure then surfaces in `npy::save`.
fn test_dir() -> std::path::PathBuf {
let dir = std::env::temp_dir().join(format!("ferray_io_mmap_{}", std::process::id()));
let _ = std::fs::create_dir_all(&dir);
dir
}
/// Path for a test file. Each test uses a distinct `name` so tests can run in parallel.
fn test_file(name: &str) -> std::path::PathBuf {
test_dir().join(name)
}
/// Read-only mapping exposes the saved shape and elements unchanged.
#[test]
fn memmap_readonly_f64() {
let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
let arr = Array::<f64, Ix1>::from_vec(Ix1::new([5]), data.clone()).unwrap();
let path = test_file("mm_ro_f64.npy");
npy::save(&path, &arr).unwrap();
let mapped = memmap_readonly::<f64, _>(&path).unwrap();
assert_eq!(mapped.shape(), &[5]);
assert_eq!(mapped.as_slice(), &data[..]);
let _ = std::fs::remove_file(&path);
}
/// `to_array` produces an owned copy with the same shape and contents.
#[test]
fn memmap_to_array() {
let data = vec![10i32, 20, 30];
let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();
let path = test_file("mm_to_arr.npy");
npy::save(&path, &arr).unwrap();
let mapped = memmap_readonly::<i32, _>(&path).unwrap();
let owned = mapped.to_array().unwrap();
assert_eq!(owned.shape(), &[3]);
assert_eq!(owned.as_slice().unwrap(), &data[..]);
let _ = std::fs::remove_file(&path);
}
/// ReadWrite writes reach the file after `flush`; untouched elements stay intact.
#[test]
fn memmap_readwrite_persist() {
let data = vec![1.0_f64, 2.0, 3.0];
let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();
let path = test_file("mm_rw.npy");
npy::save(&path, &arr).unwrap();
// Scope drops the mapping before the file is re-read with `npy::load`.
{
let mut mapped = memmap_mut::<f64, _>(&path, MemmapMode::ReadWrite).unwrap();
mapped.as_slice_mut()[0] = 999.0;
mapped.flush().unwrap();
}
let loaded: Array<f64, Ix1> = npy::load(&path).unwrap();
assert_eq!(loaded.as_slice().unwrap()[0], 999.0);
assert_eq!(loaded.as_slice().unwrap()[1], 2.0);
assert_eq!(loaded.as_slice().unwrap()[2], 3.0);
let _ = std::fs::remove_file(&path);
}
/// CopyOnWrite writes are visible through the mapping but never reach the file.
#[test]
fn memmap_copy_on_write() {
let data = vec![1.0_f64, 2.0, 3.0];
let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();
let path = test_file("mm_cow.npy");
npy::save(&path, &arr).unwrap();
{
let mut mapped = memmap_mut::<f64, _>(&path, MemmapMode::CopyOnWrite).unwrap();
mapped.as_slice_mut()[0] = 999.0;
assert_eq!(mapped.as_slice()[0], 999.0);
}
let loaded: Array<f64, Ix1> = npy::load(&path).unwrap();
assert_eq!(loaded.as_slice().unwrap()[0], 1.0);
let _ = std::fs::remove_file(&path);
}
/// Mapping an f64 file as f32 is rejected by the dtype check.
#[test]
fn memmap_wrong_dtype_error() {
let data = vec![1.0_f64, 2.0];
let arr = Array::<f64, Ix1>::from_vec(Ix1::new([2]), data).unwrap();
let path = test_file("mm_wrong_dt.npy");
npy::save(&path, &arr).unwrap();
let result = memmap_readonly::<f32, _>(&path);
assert!(result.is_err());
let _ = std::fs::remove_file(&path);
}
/// `open_memmap` in ReadOnly mode returns an owned copy of the file's data.
#[test]
fn open_memmap_readonly() {
let data = vec![1.0_f64, 2.0, 3.0];
let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();
let path = test_file("mm_open_ro.npy");
npy::save(&path, &arr).unwrap();
let loaded = open_memmap::<f64, _>(&path, MemmapMode::ReadOnly).unwrap();
assert_eq!(loaded.shape(), &[3]);
assert_eq!(loaded.as_slice().unwrap(), &data[..]);
let _ = std::fs::remove_file(&path);
}
/// A 2-D view over the mapping iterates elements in row-major (saved) order.
#[test]
fn memmap_view_borrows_underlying_data() {
let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
let arr = Array::<f64, ferray_core::dimension::Ix2>::from_vec(
ferray_core::dimension::Ix2::new([2, 3]),
data.clone(),
)
.unwrap();
let path = test_file("mm_view.npy");
npy::save(&path, &arr).unwrap();
let mapped = memmap_readonly::<f64, _>(&path).unwrap();
let view = mapped.view();
assert_eq!(view.shape(), &[2, 3]);
let collected: Vec<f64> = view.iter().copied().collect();
assert_eq!(collected, data);
let _ = std::fs::remove_file(&path);
}
}