use crate::error::{IoError, Result};
use scirs2_core::ndarray::{ArrayBase, ArrayD, ArrayView, ArrayViewMut, Dimension, IxDyn};
use std::fs::{File, OpenOptions};
use std::io::Write;
use std::marker::PhantomData;
use std::path::Path;
/// Read-only, memory-mapped array of `T` stored in the header-plus-payload
/// format written by `MmapArrayBuilder`: little-endian `u64` header words
/// (`ndim`, each dimension, element size) followed by the raw element bytes.
///
/// Fields are dropped in declaration order, so the mapping is released before
/// the file handle; keep that order when editing.
pub struct MmapArray<T> {
/// Read-only mapping of the whole file (header and payload).
mmap: memmap2::Mmap,
/// File handle kept alive alongside the mapping; never read directly.
_file: File,
/// Number of `T` elements, computed from the header dimensions at open time.
len: usize,
/// Associates the element type without storing a `T`.
_phantom: PhantomData<T>,
}
/// Read-write, memory-mapped array of `T` using the same on-disk format as
/// `MmapArray`; writes through the map reach the file (see `flush`).
///
/// Fields are dropped in declaration order, so the mapping is released before
/// the file handle; keep that order when editing.
pub struct MmapArrayMut<T> {
/// Writable mapping of the whole file (header and payload).
mmap: memmap2::MmapMut,
/// File handle (opened read/write) kept alive alongside the mapping.
_file: File,
/// Number of `T` elements, computed from the header dimensions at open time.
len: usize,
/// Associates the element type without storing a `T`.
_phantom: PhantomData<T>,
}
/// Writes array files in the format read by `MmapArray` / `MmapArrayMut`.
///
/// Construct with `MmapArrayBuilder::new`, adjust with the setter methods, then
/// call `create_from_array` or `create_empty`.
pub struct MmapArrayBuilder<'a> {
/// Destination path, borrowed for the builder's lifetime.
path: &'a Path,
/// Forwarded to `OpenOptions::create`: create the file if it is missing.
create: bool,
/// Forwarded to `OpenOptions::truncate`: discard existing contents on open.
truncate: bool,
/// Upper bound on the number of payload bytes handed to each write call.
buffer_size: usize,
}
/// Tuning options for memory mapping. `Default` gives all flags `false` and
/// `page_size: None`.
///
/// NOTE(review): nothing in this file reads these fields; how (or whether) they
/// are applied is decided elsewhere — confirm before relying on them.
#[derive(Debug, Clone, Default)]
pub struct MmapConfig {
/// Presumably requests read-ahead of mapped pages — TODO confirm with the consumer.
pub prefetch: bool,
/// Optional page-size override; `None` presumably means "use the OS default" — confirm.
pub page_size: Option<usize>,
/// Presumably a sequential-access hint (cf. `madvise(MADV_SEQUENTIAL)`) — confirm.
pub sequential: bool,
/// Presumably a random-access hint; likely meant to exclude `sequential` — confirm.
pub random: bool,
}
impl<'a> MmapArrayBuilder<'a> {
pub fn new<P: AsRef<Path>>(path: &'a P) -> Self {
Self {
path: path.as_ref(),
create: true,
truncate: false,
buffer_size: 64 * 1024, }
}
pub fn create(mut self, create: bool) -> Self {
self.create = create;
self
}
pub fn truncate(mut self, truncate: bool) -> Self {
self.truncate = truncate;
self
}
pub fn buffer_size(mut self, size: usize) -> Self {
self.buffer_size = size;
self
}
pub fn create_from_array<S, D, T>(&self, array: &ArrayBase<S, D>) -> Result<()>
where
S: scirs2_core::ndarray::Data<Elem = T>,
D: Dimension,
T: Clone + bytemuck::Pod,
{
let mut file = OpenOptions::new()
.write(true)
.create(self.create)
.truncate(self.truncate)
.open(self.path)
.map_err(|e| IoError::FileError(format!("Failed to create file: {}", e)))?;
let shape = array.shape();
let ndim = shape.len() as u64;
file.write_all(&ndim.to_le_bytes())
.map_err(|e| IoError::FileError(format!("Failed to write metadata: {}", e)))?;
for &dim in shape {
let dim = dim as u64;
file.write_all(&dim.to_le_bytes())
.map_err(|e| IoError::FileError(format!("Failed to write shape: {}", e)))?;
}
let element_size = std::mem::size_of::<T>() as u64;
file.write_all(&element_size.to_le_bytes())
.map_err(|e| IoError::FileError(format!("Failed to write element size: {}", e)))?;
if array.is_standard_layout() {
let data_slice = bytemuck::cast_slice(array.as_slice().expect("Operation failed"));
let mut written = 0;
while written < data_slice.len() {
let chunk_size = (data_slice.len() - written).min(self.buffer_size);
let chunk = &data_slice[written..written + chunk_size];
file.write_all(chunk)
.map_err(|e| IoError::FileError(format!("Failed to write data: {}", e)))?;
written += chunk_size;
}
} else {
let owned_array = array.to_owned();
let data_slice =
bytemuck::cast_slice(owned_array.as_slice().expect("Operation failed"));
let mut written = 0;
while written < data_slice.len() {
let chunk_size = (data_slice.len() - written).min(self.buffer_size);
let chunk = &data_slice[written..written + chunk_size];
file.write_all(chunk)
.map_err(|e| IoError::FileError(format!("Failed to write data: {}", e)))?;
written += chunk_size;
}
}
file.sync_all()
.map_err(|e| IoError::FileError(format!("Failed to sync file: {}", e)))?;
Ok(())
}
pub fn create_empty<T>(&self, shape: &[usize]) -> Result<()>
where
T: bytemuck::Pod,
{
let mut file = OpenOptions::new()
.write(true)
.create(self.create)
.truncate(self.truncate)
.open(self.path)
.map_err(|e| IoError::FileError(format!("Failed to create file: {}", e)))?;
let ndim = shape.len() as u64;
file.write_all(&ndim.to_le_bytes())
.map_err(|e| IoError::FileError(format!("Failed to write metadata: {}", e)))?;
for &dim in shape {
let dim = dim as u64;
file.write_all(&dim.to_le_bytes())
.map_err(|e| IoError::FileError(format!("Failed to write shape: {}", e)))?;
}
let element_size = std::mem::size_of::<T>() as u64;
file.write_all(&element_size.to_le_bytes())
.map_err(|e| IoError::FileError(format!("Failed to write element size: {}", e)))?;
let total_elements: usize = shape.iter().product();
let total_bytes = total_elements * std::mem::size_of::<T>();
let zero_buffer = vec![0u8; self.buffer_size.min(total_bytes)];
let mut remaining = total_bytes;
while remaining > 0 {
let chunk_size = remaining.min(zero_buffer.len());
file.write_all(&zero_buffer[..chunk_size])
.map_err(|e| IoError::FileError(format!("Failed to write zeros: {}", e)))?;
remaining -= chunk_size;
}
file.sync_all()
.map_err(|e| IoError::FileError(format!("Failed to sync file: {}", e)))?;
Ok(())
}
}
impl<T> MmapArray<T>
where
    T: bytemuck::Pod,
{
    /// Opens `path` read-only and memory-maps it as an array of `T`.
    ///
    /// Only the header is validated here (via `read_metadata`). Whether the file
    /// actually holds enough element bytes is checked lazily by `as_slice`.
    ///
    /// # Errors
    /// - `FileError` if the file cannot be opened, stat'ed or mapped.
    /// - `FormatError` if the file is shorter than 8 bytes, if the header is
    ///   malformed, or if it records an element size other than `size_of::<T>()`.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = File::open(path.as_ref())
            .map_err(|e| IoError::FileError(format!("Failed to open file: {}", e)))?;
        let file_size = file
            .metadata()
            .map_err(|e| IoError::FileError(format!("Failed to get file size: {}", e)))?
            .len();
        if file_size < 8 {
            return Err(IoError::FormatError(
                "File too small to contain valid array".to_string(),
            ));
        }
        // SAFETY: a file-backed map is only sound while nobody, in this process or
        // another, truncates or modifies the file underneath it. This type cannot
        // enforce that, so callers must not mutate the file while the map is alive.
        let mmap = unsafe {
            memmap2::Mmap::map(&file)
                .map_err(|e| IoError::FileError(format!("Failed to create memory map: {}", e)))?
        };
        // The payload offset is recomputed on demand by `data_offset`, so only the
        // element count is stored.
        let (len, _data_offset) = Self::read_metadata(&mmap[..])?;
        Ok(Self {
            mmap,
            _file: file,
            len,
            _phantom: PhantomData,
        })
    }

    /// Parses and validates the header at the start of `mmap`.
    ///
    /// The layout is a sequence of little-endian `u64` words: `ndim` (1..=32),
    /// then `ndim` dimension lengths, then the element size in bytes. Returns
    /// `(element_count, payload_offset)`.
    ///
    /// The header comes from an untrusted file, so every conversion and
    /// multiplication is checked. A corrupt header yields `FormatError` instead
    /// of panicking (debug) or producing a silently wrapped length (release).
    /// `ndim == 0` is rejected, even though the builder can write it for 0-d arrays.
    fn read_metadata(mmap: &[u8]) -> Result<(usize, usize)> {
        if mmap.len() < 8 {
            return Err(IoError::FormatError("Invalid file format".to_string()));
        }
        // Validate `ndim` as a u64 so that huge values cannot be truncated into
        // range on 32-bit targets.
        let ndim = Self::header_word(mmap, 0, "Failed to read ndim")?;
        if ndim == 0 || ndim > 32 {
            return Err(IoError::FormatError(
                "Invalid number of dimensions".to_string(),
            ));
        }
        let mut offset = 8;
        let mut total_elements: usize = 1;
        for _ in 0..ndim {
            if offset + 8 > mmap.len() {
                return Err(IoError::FormatError("Truncated shape data".to_string()));
            }
            let dim = Self::header_word(mmap, offset, "Failed to read dimension")?;
            total_elements = usize::try_from(dim)
                .ok()
                .and_then(|dim| total_elements.checked_mul(dim))
                .ok_or_else(|| {
                    IoError::FormatError("Array element count overflows usize".to_string())
                })?;
            offset += 8;
        }
        if offset + 8 > mmap.len() {
            return Err(IoError::FormatError(
                "Truncated element size data".to_string(),
            ));
        }
        let element_size = Self::header_word(mmap, offset, "Failed to read element size")?;
        offset += 8;
        if element_size != std::mem::size_of::<T>() as u64 {
            return Err(IoError::FormatError("Element size mismatch".to_string()));
        }
        // Guarantees that `len * size_of::<T>()` in `as_slice` cannot overflow.
        if total_elements.checked_mul(std::mem::size_of::<T>()).is_none() {
            return Err(IoError::FormatError(
                "Array byte length overflows usize".to_string(),
            ));
        }
        Ok((total_elements, offset))
    }

    /// Decodes the little-endian `u64` header word that starts at byte `offset`.
    /// Fails with `FormatError(err_msg)` if those 8 bytes are out of range.
    fn header_word(bytes: &[u8], offset: usize, err_msg: &str) -> Result<u64> {
        bytes
            .get(offset..offset + 8)
            .and_then(|word| <[u8; 8]>::try_from(word).ok())
            .map(u64::from_le_bytes)
            .ok_or_else(|| IoError::FormatError(err_msg.to_string()))
    }

    /// Returns the dimensions recorded in the file header, outermost first.
    ///
    /// # Errors
    /// `FormatError` if a header word cannot be read. This is not expected after
    /// a successful `open`, which has already validated the header.
    pub fn shape(&self) -> Result<Vec<usize>> {
        // `open` already checked `ndim <= 32` and that every dimension fits in
        // `usize`, so the `as usize` conversions below are lossless.
        let ndim = Self::header_word(&self.mmap, 0, "Failed to read ndim")? as usize;
        (0..ndim)
            .map(|axis| {
                Self::header_word(&self.mmap, 8 + 8 * axis, "Failed to read dimension")
                    .map(|dim| dim as usize)
            })
            .collect()
    }

    /// Byte offset of the first element. The header is 8 bytes each for `ndim`,
    /// for the `ndim` dimension words, and for the element-size word.
    fn data_offset(&self) -> Result<usize> {
        let ndim = Self::header_word(&self.mmap, 0, "Failed to read ndim")? as usize;
        Ok(8 * (ndim + 2))
    }

    /// Borrows the payload as a slice of `len()` elements in row-major order.
    ///
    /// # Errors
    /// - `FormatError("Insufficient data in file")` if the file is shorter than
    ///   the header promises.
    /// - `FormatError` if the payload cannot be viewed as `[T]`. For example, the
    ///   payload offset is only a multiple of 8, so element types aligned to more
    ///   than 8 bytes may be misaligned.
    pub fn as_slice(&self) -> Result<&[T]> {
        let data_offset = self.data_offset()?;
        // Cannot overflow: `read_metadata` rejected headers where it would.
        let byte_len = self.len * std::mem::size_of::<T>();
        let payload = self
            .mmap
            .get(data_offset..)
            .and_then(|bytes| bytes.get(..byte_len))
            .ok_or_else(|| IoError::FormatError("Insufficient data in file".to_string()))?;
        bytemuck::try_cast_slice(payload).map_err(|e| {
            IoError::FormatError(format!("Cannot view data as element type: {:?}", e))
        })
    }

    /// Returns an n-dimensional view of the payload with the given `shape`.
    ///
    /// # Errors
    /// - Errors from `as_slice`.
    /// - `FormatError` if the product of `shape` (computed with overflow checks)
    ///   differs from `len()`.
    pub fn as_array_view(&self, shape: &[usize]) -> Result<ArrayView<T, IxDyn>> {
        let data_slice = self.as_slice()?;
        // `shape` is caller input: use a checked product so it cannot panic.
        let expected_len = shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim));
        if expected_len != Some(self.len) {
            return Err(IoError::FormatError(format!(
                "Shape mismatch: expected {} elements, got {}",
                expected_len.map_or_else(|| "more than usize::MAX".to_string(), |n| n.to_string()),
                self.len
            )));
        }
        ArrayView::from_shape(IxDyn(shape), data_slice)
            .map_err(|e| IoError::FormatError(format!("Failed to create array view: {}", e)))
    }

    /// Number of elements recorded in the header.
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the header records zero elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}
impl<T> MmapArrayMut<T>
where
    T: bytemuck::Pod,
{
    /// Opens `path` for reading and writing and memory-maps it as an array of `T`.
    ///
    /// The header is validated with the same rules as `MmapArray::open`.
    ///
    /// # Errors
    /// - `FileError` if the file cannot be opened read/write, stat'ed or mapped.
    /// - `FormatError` if the file is shorter than 8 bytes, if the header is
    ///   malformed, or if its element size differs from `size_of::<T>()`.
    pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = OpenOptions::new()
            .read(true)
            .write(true)
            .open(path.as_ref())
            .map_err(|e| IoError::FileError(format!("Failed to open file: {}", e)))?;
        let file_size = file
            .metadata()
            .map_err(|e| IoError::FileError(format!("Failed to get file size: {}", e)))?
            .len();
        if file_size < 8 {
            return Err(IoError::FormatError(
                "File too small to contain valid array".to_string(),
            ));
        }
        // SAFETY: a file-backed map is only sound while nobody else, in this
        // process or another, truncates or modifies the file underneath it. This
        // type cannot enforce that, so callers must ensure exclusive use of the
        // file while the map is alive.
        let mmap = unsafe {
            memmap2::MmapMut::map_mut(&file)
                .map_err(|e| IoError::FileError(format!("Failed to create memory map: {}", e)))?
        };
        // The payload offset is recomputed on demand by `data_offset`.
        let (len, _data_offset) = Self::read_metadata(&mmap)?;
        Ok(Self {
            mmap,
            _file: file,
            len,
            _phantom: PhantomData,
        })
    }

    /// Validates the header by delegating to `MmapArray::read_metadata`, so both
    /// wrappers accept exactly the same files. Returns `(element_count, data_offset)`.
    fn read_metadata(mmap: &memmap2::MmapMut) -> Result<(usize, usize)> {
        MmapArray::<T>::read_metadata(&mmap[..])
    }

    /// Decodes the little-endian `u64` header word that starts at byte `offset`.
    /// Fails with `FormatError(err_msg)` if those 8 bytes are out of range.
    fn header_word(bytes: &[u8], offset: usize, err_msg: &str) -> Result<u64> {
        bytes
            .get(offset..offset + 8)
            .and_then(|word| <[u8; 8]>::try_from(word).ok())
            .map(u64::from_le_bytes)
            .ok_or_else(|| IoError::FormatError(err_msg.to_string()))
    }

    /// Returns the dimensions recorded in the file header, outermost first.
    ///
    /// # Errors
    /// `FormatError` if a header word cannot be read. This is not expected after
    /// a successful `open`.
    pub fn shape(&self) -> Result<Vec<usize>> {
        // `open` validated `ndim <= 32`, so the cast is lossless.
        let ndim = Self::header_word(&self.mmap, 0, "Failed to read ndim")? as usize;
        (0..ndim)
            .map(|axis| {
                Self::header_word(&self.mmap, 8 + 8 * axis, "Failed to read dimension")
                    .map(|dim| dim as usize)
            })
            .collect()
    }

    /// Byte offset of the first element. The header is 8 bytes each for `ndim`,
    /// for the `ndim` dimension words, and for the element-size word.
    fn data_offset(&self) -> Result<usize> {
        let ndim = Self::header_word(&self.mmap, 0, "Failed to read ndim")? as usize;
        Ok(8 * (ndim + 2))
    }

    /// Mutably borrows the payload as a slice of `len()` elements in row-major order.
    ///
    /// # Errors
    /// - `FormatError("Insufficient data in file")` if the file is shorter than
    ///   the header promises, or if the byte range overflows `usize`.
    /// - `FormatError` if the payload cannot be viewed as `[T]`, for example when
    ///   an element type aligned to more than 8 bytes is misaligned.
    pub fn as_slice_mut(&mut self) -> Result<&mut [T]> {
        let data_offset = self.data_offset()?;
        let mapped_len = self.mmap.len();
        // Checked arithmetic: the element count comes from an untrusted header.
        let data_end = self
            .len
            .checked_mul(std::mem::size_of::<T>())
            .and_then(|byte_len| byte_len.checked_add(data_offset))
            .filter(|&end| end <= mapped_len)
            .ok_or_else(|| IoError::FormatError("Insufficient data in file".to_string()))?;
        bytemuck::try_cast_slice_mut(&mut self.mmap[data_offset..data_end]).map_err(|e| {
            IoError::FormatError(format!("Cannot view data as element type: {:?}", e))
        })
    }

    /// Returns a mutable n-dimensional view of the payload with the given `shape`.
    ///
    /// # Errors
    /// - `FormatError` if the product of `shape` (computed with overflow checks)
    ///   differs from `len()`.
    /// - Errors from `as_slice_mut`.
    pub fn as_array_view_mut(&mut self, shape: &[usize]) -> Result<ArrayViewMut<T, IxDyn>> {
        // `shape` is caller input: use a checked product so it cannot panic.
        let expected_len = shape
            .iter()
            .try_fold(1usize, |acc, &dim| acc.checked_mul(dim));
        if expected_len != Some(self.len) {
            return Err(IoError::FormatError(format!(
                "Shape mismatch: expected {} elements, got {}",
                expected_len.map_or_else(|| "more than usize::MAX".to_string(), |n| n.to_string()),
                self.len
            )));
        }
        let data_slice = self.as_slice_mut()?;
        ArrayViewMut::from_shape(IxDyn(shape), data_slice)
            .map_err(|e| IoError::FormatError(format!("Failed to create array view: {}", e)))
    }

    /// Flushes modifications made through the map back to the underlying file
    /// via `memmap2::MmapMut::flush`.
    ///
    /// # Errors
    /// `FileError` if the flush fails.
    pub fn flush(&self) -> Result<()> {
        self.mmap
            .flush()
            .map_err(|e| IoError::FileError(format!("Failed to flush memory map: {}", e)))
    }

    /// Number of elements recorded in the header.
    pub fn len(&self) -> usize {
        self.len
    }

    /// `true` when the header records zero elements.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}
/// Writes `array` to `path` in the memory-mappable format, using the default
/// `MmapArrayBuilder` settings (create if missing, no truncation, 64 KiB
/// write chunks).
///
/// # Errors
/// Propagates any error from `MmapArrayBuilder::create_from_array`.
#[allow(dead_code)]
pub fn create_mmap_array<P, S, D, T>(path: P, array: &ArrayBase<S, D>) -> Result<()>
where
    P: AsRef<Path>,
    S: scirs2_core::ndarray::Data<Elem = T>,
    D: Dimension,
    T: Clone + bytemuck::Pod,
{
    let builder = MmapArrayBuilder::new(&path);
    builder.create_from_array(array)
}
/// Loads the whole array stored at `path` into an owned, dynamically
/// dimensioned array. The shape is taken from the file header.
///
/// The file stays mapped only until the copy has been made.
///
/// # Errors
/// Propagates errors from `MmapArray::open`, `MmapArray::shape` and
/// `MmapArray::as_array_view`.
#[allow(dead_code)]
pub fn read_mmap_array<P, T>(path: P) -> Result<ArrayD<T>>
where
    P: AsRef<Path>,
    T: bytemuck::Pod + Clone,
{
    let mapped = MmapArray::<T>::open(path)?;
    let dims = mapped.shape()?;
    let owned = mapped.as_array_view(&dims)?.to_owned();
    Ok(owned)
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array1, Array2};
    use tempfile::tempdir;

    /// Round-trips a 1-D `f64` array and checks its shape and every element.
    #[test]
    fn test_mmap_array_1d() {
        let dir = tempdir().expect("Operation failed");
        let path = dir.path().join("test_1d.bin");
        let expected = Array1::from(vec![1.0f64, 2.0, 3.0, 4.0, 5.0]);
        create_mmap_array(&path, &expected).expect("Operation failed");

        let mapped: MmapArray<f64> = MmapArray::open(&path).expect("Operation failed");
        let dims = mapped.shape().expect("Operation failed");
        assert_eq!(dims, vec![5]);

        let view = mapped.as_array_view(&dims).expect("Operation failed");
        assert_eq!(view.len(), 5);
        for (got, want) in view.iter().zip(expected.iter()) {
            assert_eq!(got, want);
        }
    }

    /// Round-trips a 2x3 array and checks that the mapped payload is row-major.
    #[test]
    fn test_mmap_array_2d() {
        let dir = tempdir().expect("Operation failed");
        let path = dir.path().join("test_2d.bin");
        let expected = array![[1.0f64, 2.0, 3.0], [4.0, 5.0, 6.0]];
        create_mmap_array(&path, &expected).expect("Operation failed");

        let mapped: MmapArray<f64> = MmapArray::open(&path).expect("Operation failed");
        let dims = mapped.shape().expect("Operation failed");
        assert_eq!(dims, vec![2, 3]);

        let view = mapped.as_array_view(&dims).expect("Operation failed");
        assert_eq!(view.shape(), &[2, 3]);
        // Row-major payload: element (row, col) lives at row * 3 + col.
        let flat = view.as_slice().expect("Operation failed");
        for ((row, col), &want) in expected.indexed_iter() {
            assert_eq!(flat[row * 3 + col], want);
        }
    }

    /// Writes through a mutable mapping, flushes, and re-reads from disk.
    #[test]
    fn test_mmap_array_mutable() {
        const SIDE: usize = 10;
        let dir = tempdir().expect("Operation failed");
        let path = dir.path().join("test_mut.bin");
        let zeros: Array2<f64> = Array2::zeros((SIDE, SIDE));
        create_mmap_array(&path, &zeros).expect("Operation failed");

        let mut writable: MmapArrayMut<f64> =
            MmapArrayMut::open(&path).expect("Operation failed");
        let dims = writable.shape().expect("Operation failed");
        {
            // Scope the mutable view so `writable` can be flushed afterwards.
            let mut view = writable
                .as_array_view_mut(&dims)
                .expect("Operation failed");
            let cells = view.as_slice_mut().expect("Operation failed");
            cells[5 * SIDE + 5] = 42.0;
            cells[SIDE + 2] = 13.7;
        }
        writable.flush().expect("Operation failed");

        let reread: ArrayD<f64> = read_mmap_array(&path).expect("Operation failed");
        let values = reread.as_slice().expect("Operation failed");
        assert_eq!(values[5 * SIDE + 5], 42.0);
        assert_eq!(values[SIDE + 2], 13.7);
        assert_eq!(values[0], 0.0);
    }

    /// `create_mmap_array` followed by `read_mmap_array` reproduces the input.
    #[test]
    fn test_convenience_functions() {
        let dir = tempdir().expect("Operation failed");
        let path = dir.path().join("test_convenience.bin");
        let original = Array2::from_shape_fn((100, 50), |(row, col)| (row + col) as f64);
        create_mmap_array(&path, &original).expect("Operation failed");

        let restored: ArrayD<f64> = read_mmap_array(&path).expect("Operation failed");
        assert_eq!(original.shape(), restored.shape());
        original
            .iter()
            .zip(restored.iter())
            .for_each(|(before, after)| assert_eq!(before, after));
    }

    /// `create_empty` produces a correctly shaped, zero-filled array.
    #[test]
    fn test_empty_array_creation() {
        let dir = tempdir().expect("Operation failed");
        let path = dir.path().join("test_empty.bin");
        let dims = vec![100, 200];
        MmapArrayBuilder::new(&path)
            .create_empty::<f64>(&dims)
            .expect("Operation failed");

        let mapped = MmapArray::<f64>::open(&path).expect("Operation failed");
        let read_dims = mapped.shape().expect("Operation failed");
        assert_eq!(read_dims, dims);
        assert_eq!(mapped.len(), 100 * 200);

        let view = mapped.as_array_view(&dims).expect("Operation failed");
        assert!(view.iter().all(|&value| value == 0.0));
    }
}