use std::fs::{File, OpenOptions};
use std::io::{BufReader, Seek};
use std::marker::PhantomData;
use std::path::Path;
use memmap2::{Mmap, MmapMut, MmapOptions};
use ferray_core::Array;
use ferray_core::dimension::IxDyn;
use ferray_core::dtype::Element;
use ferray_core::error::{FerrayError, FerrayResult};
use crate::format::MemmapMode;
use crate::npy::NpyElement;
use crate::npy::checked_total_elements;
use crate::npy::header::{self, NpyHeader};
/// Read-only view of the data section of a memory-mapped `.npy` file.
///
/// Created by [`memmap_readonly`]. Elements are borrowed directly from the
/// mapping through [`MemmapArray::as_slice`]; use [`MemmapArray::to_array`]
/// to obtain an owned copy.
pub struct MemmapArray<T: Element> {
    // Owns the mapping; `data_ptr` points into it, so it must live as long as `self`.
    _mmap: Mmap,
    // Start of the element data; checked for `T`'s alignment by the constructor.
    data_ptr: *const T,
    // Dimensions taken from the `.npy` header.
    shape: Vec<usize>,
    // Element count computed from `shape` by `checked_total_elements`.
    len: usize,
    _marker: PhantomData<T>,
}
// SAFETY: `data_ptr` points into the mapping owned by `_mmap`, which is moved
// together with `self`, and this type only hands out shared `&[T]` access.
// NOTE(review): soundness additionally requires `T: Send + Sync`; confirm
// that the `Element` trait implies those bounds.
unsafe impl<T: Element> Send for MemmapArray<T> {}
unsafe impl<T: Element> Sync for MemmapArray<T> {}
impl<T: Element> MemmapArray<T> {
    /// Dimensions of the mapped array, as read from the `.npy` header.
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Borrow the mapped elements as a flat slice, in file order, without copying.
    pub fn as_slice(&self) -> &[T] {
        let (start, count) = (self.data_ptr, self.len);
        // SAFETY: `start` was checked for `T`'s alignment when the mapping was
        // created, the mapping spans `count * size_of::<T>()` bytes, and it is
        // owned by `self._mmap`, which outlives the returned borrow.
        unsafe { std::slice::from_raw_parts(start, count) }
    }

    /// Copy the mapped data into a freshly allocated `Array` with the same shape.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Array::from_vec`.
    pub fn to_array(&self) -> FerrayResult<Array<T, IxDyn>> {
        let copied: Vec<T> = self.as_slice().to_vec();
        Array::from_vec(IxDyn::new(&self.shape), copied)
    }
}
/// Writable view of the data section of a memory-mapped `.npy` file.
///
/// Created by [`memmap_mut`]. In `ReadWrite` mode the mapping is made with
/// `map_mut` over a file opened for writing. In `CopyOnWrite` mode it is made
/// with `map_copy`, and writes are not persisted to the file.
pub struct MemmapArrayMut<T: Element> {
    // Owns the mapping; `data_ptr` points into it, so it must live as long as `self`.
    _mmap: MmapMut,
    // Start of the element data; checked for `T`'s alignment by the constructor.
    data_ptr: *mut T,
    // Dimensions taken from the `.npy` header.
    shape: Vec<usize>,
    // Element count computed from `shape` by `checked_total_elements`.
    len: usize,
    _marker: PhantomData<T>,
}
// SAFETY: `data_ptr` points into the mapping owned by `_mmap`, which is moved
// together with `self`. Mutable access is only handed out through
// `&mut self` (`as_slice_mut`), so shared references only ever read.
// NOTE(review): soundness additionally requires `T: Send + Sync`; confirm
// that the `Element` trait implies those bounds.
unsafe impl<T: Element> Send for MemmapArrayMut<T> {}
unsafe impl<T: Element> Sync for MemmapArrayMut<T> {}
impl<T: Element> MemmapArrayMut<T> {
    /// Dimensions of the mapped array, as read from the `.npy` header.
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Borrow the mapped elements as a flat, read-only slice.
    pub fn as_slice(&self) -> &[T] {
        let start: *const T = self.data_ptr;
        // SAFETY: `start` was checked for `T`'s alignment when the mapping was
        // created, the mapping spans `len * size_of::<T>()` bytes, and it is
        // owned by `self._mmap`, which outlives the returned borrow.
        unsafe { std::slice::from_raw_parts(start, self.len) }
    }

    /// Borrow the mapped elements as a flat, mutable slice.
    pub fn as_slice_mut(&mut self) -> &mut [T] {
        let start = self.data_ptr;
        let count = self.len;
        // SAFETY: same invariants as `as_slice`. In addition, `&mut self`
        // guarantees no other borrow of the data exists, and the mapping was
        // created writable (`map_mut` or `map_copy` in `memmap_mut`).
        unsafe { std::slice::from_raw_parts_mut(start, count) }
    }

    /// Copy the mapped data into a freshly allocated `Array` with the same shape.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Array::from_vec`.
    pub fn to_array(&self) -> FerrayResult<Array<T, IxDyn>> {
        let copied: Vec<T> = self.as_slice().to_vec();
        Array::from_vec(IxDyn::new(&self.shape), copied)
    }

    /// Ask the OS to write outstanding modifications of the mapping back to the file.
    ///
    /// For copy-on-write mappings the changes are private to the mapping and
    /// are not expected to reach the file.
    ///
    /// # Errors
    ///
    /// Returns an I/O error if the underlying flush fails.
    pub fn flush(&self) -> FerrayResult<()> {
        match self._mmap.flush() {
            Ok(()) => Ok(()),
            Err(e) => Err(FerrayError::io_error(format!("failed to flush mmap: {e}"))),
        }
    }
}
/// Memory-map the data section of an `.npy` file for read-only access.
///
/// The header is parsed to obtain the shape, dtype, byte order and the byte
/// offset of the data. Only the data section is mapped, and its elements are
/// exposed through [`MemmapArray::as_slice`] without copying.
///
/// # Errors
///
/// Returns an error if:
/// - the header cannot be read;
/// - the file's dtype does not match `T`, or the file is not in native byte order;
/// - the element count or the data size in bytes overflows;
/// - the file is shorter than the data section its header describes;
/// - the mapping fails, or the mapped data is not aligned for `T`.
pub fn memmap_readonly<T: Element + NpyElement, P: AsRef<Path>>(
    path: P,
) -> FerrayResult<MemmapArray<T>> {
    let path: &Path = path.as_ref();
    let (header, data_offset) = read_npy_header_with_offset(path)?;
    validate_dtype::<T>(&header)?;
    validate_native_endian(&header)?;
    let len = checked_total_elements(&header.shape)?;
    // The element count is checked, but the byte size is not. A wrapped
    // multiplication would map fewer bytes than `as_slice` later exposes.
    let data_bytes = len
        .checked_mul(std::mem::size_of::<T>())
        .ok_or_else(|| FerrayError::invalid_value("npy data size in bytes overflows usize"))?;

    let file = File::open(path)?;
    // Mapping past end-of-file succeeds, but touching those pages faults
    // (SIGBUS on POSIX). Reject files shorter than the header claims.
    let required_len = (data_offset as u64)
        .checked_add(data_bytes as u64)
        .ok_or_else(|| FerrayError::invalid_value("npy data size in bytes overflows u64"))?;
    let file_len = file.metadata()?.len();
    if file_len < required_len {
        return Err(FerrayError::io_error(format!(
            "npy file is truncated: data section needs {required_len} bytes but file has {file_len}"
        )));
    }

    // SAFETY: memmap2 cannot stop another process from modifying or truncating
    // the file while it is mapped. Like any mmap-based reader, this relies on
    // the file not being changed underneath the mapping.
    let mmap = unsafe {
        MmapOptions::new()
            .offset(data_offset as u64)
            .len(data_bytes)
            .map(&file)
    }
    .map_err(|e| FerrayError::io_error(format!("mmap failed: {e}")))?;

    let data_ptr = mmap.as_ptr() as *const T;
    // The data offset is wherever the header ended, so alignment for `T` is
    // not guaranteed and must be checked before slices are built over it.
    if (data_ptr as usize) % std::mem::align_of::<T>() != 0 {
        return Err(FerrayError::io_error(
            "memory-mapped data is not properly aligned for the element type",
        ));
    }
    Ok(MemmapArray {
        _mmap: mmap,
        data_ptr,
        shape: header.shape,
        len,
        _marker: PhantomData,
    })
}
/// Memory-map the data section of an `.npy` file with write access.
///
/// `mode` selects how writes behave:
/// - `MemmapMode::ReadWrite`: the file is opened for reading and writing and
///   mapped with `map_mut`. Use [`MemmapArrayMut::flush`] to push changes out.
/// - `MemmapMode::CopyOnWrite`: the file is opened read-only and mapped with
///   `map_copy`, so writes stay private to the mapping.
///
/// # Errors
///
/// Returns an error if:
/// - `mode` is `MemmapMode::ReadOnly` (use [`memmap_readonly`] instead);
/// - the header cannot be read;
/// - the dtype or byte order does not match;
/// - the element count or the data size in bytes overflows;
/// - the file is shorter than the data section its header describes;
/// - the mapping fails, or the mapped data is not aligned for `T`.
pub fn memmap_mut<T: Element + NpyElement, P: AsRef<Path>>(
    path: P,
    mode: MemmapMode,
) -> FerrayResult<MemmapArrayMut<T>> {
    if mode == MemmapMode::ReadOnly {
        return Err(FerrayError::invalid_value(
            "use memmap_readonly for read-only access",
        ));
    }
    let path: &Path = path.as_ref();
    let (header, data_offset) = read_npy_header_with_offset(path)?;
    validate_dtype::<T>(&header)?;
    validate_native_endian(&header)?;
    let len = checked_total_elements(&header.shape)?;
    // The element count is checked, but the byte size is not. A wrapped
    // multiplication would map fewer bytes than the slices later expose.
    let data_bytes = len
        .checked_mul(std::mem::size_of::<T>())
        .ok_or_else(|| FerrayError::invalid_value("npy data size in bytes overflows usize"))?;

    // Only a read-write mapping needs a writable file handle.
    let file = match mode {
        MemmapMode::ReadWrite => OpenOptions::new().read(true).write(true).open(path)?,
        MemmapMode::CopyOnWrite => File::open(path)?,
        // Rejected by the guard at the top of the function.
        MemmapMode::ReadOnly => unreachable!(),
    };

    // Mapping past end-of-file succeeds, but touching those pages faults
    // (SIGBUS on POSIX). Reject files shorter than the header claims.
    let required_len = (data_offset as u64)
        .checked_add(data_bytes as u64)
        .ok_or_else(|| FerrayError::invalid_value("npy data size in bytes overflows u64"))?;
    let file_len = file.metadata()?.len();
    if file_len < required_len {
        return Err(FerrayError::io_error(format!(
            "npy file is truncated: data section needs {required_len} bytes but file has {file_len}"
        )));
    }

    let mut options = MmapOptions::new();
    options.offset(data_offset as u64).len(data_bytes);
    let mut mmap = match mode {
        MemmapMode::ReadWrite => {
            // SAFETY: memmap2 cannot stop other processes from modifying or
            // truncating the file while it is mapped. This relies on the file
            // not being changed underneath the mapping.
            let mapped = unsafe { options.map_mut(&file) };
            mapped.map_err(|e| FerrayError::io_error(format!("mmap_mut failed: {e}")))?
        }
        MemmapMode::CopyOnWrite => {
            // SAFETY: as above; a private mapping still observes external
            // changes to pages it has not yet written.
            let mapped = unsafe { options.map_copy(&file) };
            mapped.map_err(|e| {
                FerrayError::io_error(format!("mmap copy-on-write failed: {e}"))
            })?
        }
        // Rejected by the guard at the top of the function.
        MemmapMode::ReadOnly => unreachable!(),
    };

    // Take the pointer from the mutable accessor, since writes will go through it.
    let data_ptr = mmap.as_mut_ptr() as *mut T;
    // The data offset is wherever the header ended, so alignment for `T` is
    // not guaranteed and must be checked before slices are built over it.
    if (data_ptr as usize) % std::mem::align_of::<T>() != 0 {
        return Err(FerrayError::io_error(
            "memory-mapped data is not properly aligned for the element type",
        ));
    }
    Ok(MemmapArrayMut {
        _mmap: mmap,
        data_ptr,
        shape: header.shape,
        len,
        _marker: PhantomData,
    })
}
/// Open an `.npy` file through a memory map and return an owned copy of its data.
///
/// `MemmapMode::ReadOnly` goes through [`memmap_readonly`]; the other modes go
/// through [`memmap_mut`], so `ReadWrite` still opens the file for writing.
/// In every mode the mapped data is copied into a new `Array`. Changes made to
/// the returned array are therefore never written back to the file.
///
/// # Errors
///
/// Propagates errors from the mapping function and from building the owned array.
pub fn open_memmap<T: Element + NpyElement, P: AsRef<Path>>(
    path: P,
    mode: MemmapMode,
) -> FerrayResult<Array<T, IxDyn>> {
    match mode {
        MemmapMode::ReadOnly => memmap_readonly::<T, _>(path)?.to_array(),
        MemmapMode::ReadWrite | MemmapMode::CopyOnWrite => {
            memmap_mut::<T, _>(path, mode)?.to_array()
        }
    }
}
/// Parse the `.npy` header of `path`.
///
/// Returns the header together with the byte offset at which the array data begins.
///
/// The offset is the reader's logical position right after the header.
/// `BufReader`'s `stream_position` subtracts bytes that are buffered but not
/// yet consumed, so read-ahead does not inflate it.
fn read_npy_header_with_offset(path: &Path) -> FerrayResult<(NpyHeader, usize)> {
    let mut reader = BufReader::new(File::open(path)?);
    let parsed = header::read_header(&mut reader)?;
    let position = match reader.stream_position() {
        Ok(pos) => pos,
        Err(e) => {
            return Err(FerrayError::io_error(format!(
                "failed to get stream position: {e}"
            )))
        }
    };
    Ok((parsed, position as usize))
}
/// Check that the dtype recorded in the header matches the element type `T`.
///
/// # Errors
///
/// Returns an invalid-dtype error naming the expected dtype, the Rust type,
/// and the dtype found in the file.
fn validate_dtype<T: Element>(header: &NpyHeader) -> FerrayResult<()> {
    let expected = T::dtype();
    if header.dtype == expected {
        return Ok(());
    }
    Err(FerrayError::invalid_dtype(format!(
        "expected dtype {:?} for type {}, but file has {:?}",
        expected,
        std::any::type_name::<T>(),
        header.dtype,
    )))
}
/// Reject files whose byte order would require swapping.
///
/// Mapped bytes are reinterpreted in place as `T`, so there is no opportunity
/// to byte-swap them.
///
/// # Errors
///
/// Returns an I/O error when `header.endianness.needs_swap()` is true.
fn validate_native_endian(header: &NpyHeader) -> FerrayResult<()> {
    if header.endianness.needs_swap() {
        Err(FerrayError::io_error(
            "memory-mapped arrays require native byte order; file has non-native endianness",
        ))
    } else {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::npy;
    use ferray_core::dimension::Ix1;

    // Per-process scratch directory so concurrent test runs do not collide.
    fn test_dir() -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!("ferray_io_mmap_{}", std::process::id()));
        // Best-effort: if creation fails, the later `npy::save` surfaces the error.
        let _ = std::fs::create_dir_all(&dir);
        dir
    }

    // Each test uses a distinct file name because tests run in parallel.
    fn test_file(name: &str) -> std::path::PathBuf {
        test_dir().join(name)
    }

    // A read-only mapping exposes the saved shape and values without copying.
    #[test]
    fn memmap_readonly_f64() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([5]), data.clone()).unwrap();
        let path = test_file("mm_ro_f64.npy");
        npy::save(&path, &arr).unwrap();
        let mapped = memmap_readonly::<f64, _>(&path).unwrap();
        assert_eq!(mapped.shape(), &[5]);
        assert_eq!(mapped.as_slice(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    // `to_array` produces an owned array matching the mapped data.
    #[test]
    fn memmap_to_array() {
        let data = vec![10i32, 20, 30];
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();
        let path = test_file("mm_to_arr.npy");
        npy::save(&path, &arr).unwrap();
        let mapped = memmap_readonly::<i32, _>(&path).unwrap();
        let owned = mapped.to_array().unwrap();
        assert_eq!(owned.shape(), &[3]);
        assert_eq!(owned.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    // Writes through a ReadWrite mapping persist to the file after `flush`;
    // untouched elements are unchanged.
    #[test]
    fn memmap_readwrite_persist() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();
        let path = test_file("mm_rw.npy");
        npy::save(&path, &arr).unwrap();
        {
            let mut mapped = memmap_mut::<f64, _>(&path, MemmapMode::ReadWrite).unwrap();
            mapped.as_slice_mut()[0] = 999.0;
            mapped.flush().unwrap();
        }
        let loaded: Array<f64, Ix1> = npy::load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap()[0], 999.0);
        assert_eq!(loaded.as_slice().unwrap()[1], 2.0);
        assert_eq!(loaded.as_slice().unwrap()[2], 3.0);
        let _ = std::fs::remove_file(&path);
    }

    // Writes through a CopyOnWrite mapping are visible in the mapping but
    // never reach the file.
    #[test]
    fn memmap_copy_on_write() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();
        let path = test_file("mm_cow.npy");
        npy::save(&path, &arr).unwrap();
        {
            let mut mapped = memmap_mut::<f64, _>(&path, MemmapMode::CopyOnWrite).unwrap();
            mapped.as_slice_mut()[0] = 999.0;
            assert_eq!(mapped.as_slice()[0], 999.0);
        }
        let loaded: Array<f64, Ix1> = npy::load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap()[0], 1.0);
        let _ = std::fs::remove_file(&path);
    }

    // Mapping an f64 file as f32 is rejected by the dtype check.
    #[test]
    fn memmap_wrong_dtype_error() {
        let data = vec![1.0_f64, 2.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([2]), data).unwrap();
        let path = test_file("mm_wrong_dt.npy");
        npy::save(&path, &arr).unwrap();
        let result = memmap_readonly::<f32, _>(&path);
        assert!(result.is_err());
        let _ = std::fs::remove_file(&path);
    }

    // `open_memmap` in ReadOnly mode returns an owned copy of the file's data.
    #[test]
    fn open_memmap_readonly() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();
        let path = test_file("mm_open_ro.npy");
        npy::save(&path, &arr).unwrap();
        let loaded = open_memmap::<f64, _>(&path, MemmapMode::ReadOnly).unwrap();
        assert_eq!(loaded.shape(), &[3]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }
}