use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use crate::memory_optimize::cache_layout::{optimize_layout, LayoutStrategy};
use memmap2::{MmapMut, MmapOptions};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt;
use std::fs::{File, OpenOptions};
use std::io::{Read, Write};
use std::marker::PhantomData;
use std::mem;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use std::time::SystemTime;
/// Tuning options attached to an [`MmapArray`].
///
/// `MmapArray::new` always starts from `MmapConfig::default()`; callers can
/// replace it afterwards with `MmapArray::update_config`.
///
/// NOTE(review): this type is serde-serializable and appears (as
/// `Option<MmapConfig>`) inside `MmapArrayMeta`, which is binary-encoded into
/// the file header. Field order is presumably part of that encoding — confirm
/// before reordering fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MmapConfig {
/// Target data layout. `MmapArray::optimize_layout` treats `RowMajor` as a no-op.
pub layout_strategy: LayoutStrategy,
/// Write-back flag (`true` by default).
/// NOTE(review): not read anywhere in this file — confirm intended semantics.
pub write_back: bool,
/// Which detected access patterns trigger prefetching.
pub prefetch: PrefetchStrategy,
/// Requested alignment (`0` by default).
/// NOTE(review): not read anywhere in this file; units presumably bytes.
pub alignment: usize,
/// Optional page-size override.
/// NOTE(review): not read anywhere in this file.
pub page_size_hint: Option<usize>,
}
/// Controls when `MmapArray` touches upcoming elements ahead of use.
///
/// Variant order is part of the serialized form (this enum derives serde), so
/// do not reorder variants.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub enum PrefetchStrategy {
/// Never prefetch.
None,
/// Prefetch only when the detected pattern is sequential.
Sequential,
/// NOTE(review): no prefetch behaviour is attached to this variant in this file.
Random,
/// Pick a prefetch routine based on the detected pattern (sequential or strided).
Adaptive,
}
/// Per-file access-history tracker, shared between all `MmapArray`s opened on
/// the same path via the global registry.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct AccessPattern {
// Most recent index tuples, oldest first; `track_access` caps this at 100.
recent_accesses: Vec<Vec<usize>>,
// How many times each exact index tuple has been accessed.
access_count: HashMap<Vec<usize>, u64>,
// Wall-clock time of the most recent tracked access.
last_access: SystemTime,
// Classification recomputed from `recent_accesses` on every tracked access.
pattern_type: AccessPatternType,
}
/// Classification of recent element accesses, as computed by
/// `MmapArray::detect_access_pattern`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[allow(dead_code)]
pub enum AccessPatternType {
/// Too little history (fewer than three accesses) to decide.
Unknown,
/// Constant step of +1 along the last axis.
Sequential,
/// Constant step other than +1 along the last axis.
Strided,
/// No constant step along the last axis.
Random,
/// NOTE(review): never produced by the detection code in this file.
Blocked,
}
// Process-wide registry from file path to that file's shared access tracker,
// so every `MmapArray` opened on the same path updates one `AccessPattern`
// (see `get_or_create_access_pattern`). Nothing in this file removes entries.
// NOTE(review): keys are the paths as given (not canonicalised), so two
// spellings of the same file get separate trackers.
lazy_static::lazy_static! {
static ref GLOBAL_MMAP_CACHE: Mutex<HashMap<PathBuf, Arc<Mutex<AccessPattern>>>> =
Mutex::new(HashMap::new());
}
/// An N-dimensional array of `T` stored in a file and accessed through a
/// writable memory mapping.
///
/// The file holds a fixed-size encoded `MmapArrayMeta` header followed by the
/// elements in row-major order.
#[derive(Debug)]
pub struct MmapArray<T: Copy> {
// Writable mapping of the whole file (header + data).
mmap: MmapMut,
// Extent of each axis.
shape: Vec<usize>,
// Total element count (product of `shape`).
size: usize,
// Backing file path; also the key into the global access-pattern registry.
path: PathBuf,
// Runtime tuning options.
config: MmapConfig,
// Access tracker shared with other arrays opened on the same path.
access_pattern: Arc<Mutex<AccessPattern>>,
// Byte offset of the first element (i.e. the header size).
data_offset: usize,
// Page size assumed at construction (currently a fixed 4096).
#[allow(dead_code)]
page_size: usize,
_phantom: PhantomData<T>,
}
/// Header stored at the start of every `MmapArray` file, encoded with
/// `oxicode`'s standard configuration.
///
/// NOTE(review): the encoding is presumably positional, so field order is part
/// of the file format — confirm before reordering or inserting fields.
#[derive(Serialize, Deserialize, Debug)]
pub struct MmapArrayMeta {
/// `std::any::type_name` of the element type; checked on reopen.
pub type_name: String,
/// `size_of` the element type, in bytes.
pub type_size: usize,
/// Extent of each axis; checked on reopen.
pub shape: Vec<usize>,
/// Total element count.
pub size: usize,
/// Format version; `MmapArray::new` writes `1`.
pub version: u8,
/// Optional stored config; `MmapArray::new` writes `None`.
pub config: Option<MmapConfig>,
/// Optional data checksum; `MmapArray::new` writes `None`.
pub checksum: Option<u64>,
/// Creation time in whole seconds since the Unix epoch.
pub created_at: u64,
/// Last-modification time in whole seconds since the Unix epoch.
pub modified_at: u64,
}
impl Default for MmapConfig {
    /// Row-major layout, write-back on, adaptive prefetching, zero alignment and
    /// no page-size hint.
    fn default() -> Self {
        let layout_strategy = LayoutStrategy::RowMajor;
        let prefetch = PrefetchStrategy::Adaptive;
        Self {
            page_size_hint: None,
            alignment: 0,
            prefetch,
            write_back: true,
            layout_strategy,
        }
    }
}
impl Default for AccessPattern {
    /// Empty history with an `Unknown` pattern, timestamped "now".
    fn default() -> Self {
        // Pre-size for the 100-entry history window kept by `track_access`.
        let history: Vec<Vec<usize>> = Vec::with_capacity(100);
        Self {
            pattern_type: AccessPatternType::Unknown,
            last_access: SystemTime::now(),
            access_count: HashMap::new(),
            recent_accesses: history,
        }
    }
}
#[allow(dead_code)]
impl<T: Copy> MmapArray<T> {
pub fn new<P: AsRef<Path>>(path: &P, shape: &[usize], create: bool) -> Result<Self> {
let path = path.as_ref().to_path_buf();
let size: usize = shape.iter().product();
let data_size = size * mem::size_of::<T>();
let meta_size = calculate_meta_size(shape);
let total_size = meta_size + data_size;
let file = if create {
let file = OpenOptions::new()
.read(true)
.write(true)
.create(true)
.truncate(true)
.open(&path)?;
file.set_len(total_size as u64)?;
let meta = MmapArrayMeta {
type_name: std::any::type_name::<T>().to_string(),
type_size: mem::size_of::<T>(),
shape: shape.to_vec(),
size,
version: 1,
config: None,
checksum: None,
created_at: SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
modified_at: SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
let config = oxicode::config::standard();
let meta_bytes = oxicode::serde::encode_to_vec(&meta, config)
.map_err(|e| NumRs2Error::SerializationError(e.to_string()))?;
let mut file = file;
file.write_all(&meta_bytes)?;
file
} else {
let mut file = File::open(&path)?;
let mut meta_bytes = vec![0u8; meta_size];
file.read_exact(&mut meta_bytes)?;
let config = oxicode::config::standard();
let (meta, _): (MmapArrayMeta, usize) =
oxicode::serde::decode_from_slice(&meta_bytes, config)
.map_err(|e| NumRs2Error::DeserializationError(e.to_string()))?;
if meta.type_name != std::any::type_name::<T>() {
return Err(NumRs2Error::InvalidOperation(format!(
"Type mismatch: file contains '{}', but requested '{}'",
meta.type_name,
std::any::type_name::<T>()
)));
}
if meta.shape != shape {
return Err(NumRs2Error::ShapeMismatch {
expected: shape.to_vec(),
actual: meta.shape,
});
}
file
};
let mmap = unsafe { MmapOptions::new().map_mut(&file)? };
if mmap.len() != total_size {
return Err(NumRs2Error::InvalidOperation(format!(
"File size mismatch: expected {} bytes, got {} bytes",
total_size,
mmap.len()
)));
}
let access_pattern = get_or_create_access_pattern(&path);
let config = MmapConfig::default();
let page_size = get_page_size();
let data_offset = meta_size;
Ok(Self {
mmap,
shape: shape.to_vec(),
size,
path,
config,
access_pattern,
data_offset,
page_size,
_phantom: PhantomData,
})
}
pub fn get(&self, indices: &[usize]) -> Result<T> {
if indices.len() != self.shape.len() {
return Err(NumRs2Error::DimensionMismatch(format!(
"Expected {} indices, got {}",
self.shape.len(),
indices.len()
)));
}
let offset = self.calculate_offset(indices)?;
let byte_offset = self.data_offset + offset * mem::size_of::<T>();
if byte_offset + mem::size_of::<T>() > self.mmap.len() {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Index out of bounds: offset {} exceeds mmap size {}",
byte_offset,
self.mmap.len()
)));
}
let bytes = &self.mmap[byte_offset..byte_offset + mem::size_of::<T>()];
let value = unsafe { *(bytes.as_ptr() as *const T) };
Ok(value)
}
pub fn set(&mut self, indices: &[usize], value: T) -> Result<()> {
if indices.len() != self.shape.len() {
return Err(NumRs2Error::DimensionMismatch(format!(
"Expected {} indices, got {}",
self.shape.len(),
indices.len()
)));
}
let offset = self.calculate_offset(indices)?;
let byte_offset = self.data_offset + offset * mem::size_of::<T>();
if byte_offset + mem::size_of::<T>() > self.mmap.len() {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Index out of bounds: offset {} exceeds mmap size {}",
byte_offset,
self.mmap.len()
)));
}
let bytes = unsafe {
std::slice::from_raw_parts(&value as *const T as *const u8, mem::size_of::<T>())
};
self.mmap[byte_offset..byte_offset + mem::size_of::<T>()].copy_from_slice(bytes);
Ok(())
}
fn calculate_offset(&self, indices: &[usize]) -> Result<usize> {
for (i, &idx) in indices.iter().enumerate() {
if idx >= self.shape[i] {
return Err(NumRs2Error::IndexOutOfBounds(format!(
"Index {} out of bounds for dimension {}: {}",
idx, i, self.shape[i]
)));
}
}
let mut offset = 0;
let mut stride = 1;
for i in (0..indices.len()).rev() {
offset += indices[i] * stride;
stride *= self.shape[i];
}
Ok(offset)
}
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn ndim(&self) -> usize {
self.shape.len()
}
pub fn size(&self) -> usize {
self.size
}
pub fn path(&self) -> &Path {
&self.path
}
pub fn flush(&mut self) -> Result<()> {
self.mmap.flush()?;
Ok(())
}
pub fn to_array(&self) -> Result<Array<T>> {
let mut data = Vec::with_capacity(self.size);
let meta_size = calculate_meta_size(&self.shape);
let data_start = meta_size;
let _data_end = meta_size + self.size * mem::size_of::<T>();
for i in 0..self.size {
let byte_offset = data_start + i * mem::size_of::<T>();
let bytes = &self.mmap[byte_offset..byte_offset + mem::size_of::<T>()];
let value = unsafe { *(bytes.as_ptr() as *const T) };
data.push(value);
}
let array = Array::from_vec(data).reshape(&self.shape);
Ok(array)
}
pub fn from_array<P: AsRef<Path>>(array: &Array<T>, path: &P) -> Result<Self> {
let shape = array.shape();
let mut mmap_array = Self::new(path, &shape, true)?;
let data = array.to_vec();
let meta_size = calculate_meta_size(&shape);
let data_start = meta_size;
for (i, &value) in data.iter().enumerate() {
let byte_offset = data_start + i * mem::size_of::<T>();
let bytes = unsafe {
std::slice::from_raw_parts(&value as *const T as *const u8, mem::size_of::<T>())
};
mmap_array.mmap[byte_offset..byte_offset + mem::size_of::<T>()].copy_from_slice(bytes);
}
mmap_array.flush()?;
Ok(mmap_array)
}
fn track_access(&self, indices: &[usize]) {
if let Ok(mut pattern) = self.access_pattern.lock() {
pattern.last_access = SystemTime::now();
if pattern.recent_accesses.len() >= 100 {
pattern.recent_accesses.remove(0);
}
pattern.recent_accesses.push(indices.to_vec());
*pattern.access_count.entry(indices.to_vec()).or_insert(0) += 1;
pattern.pattern_type = self.detect_access_pattern(&pattern.recent_accesses);
}
}
fn detect_access_pattern(&self, accesses: &[Vec<usize>]) -> AccessPatternType {
if accesses.len() < 3 {
return AccessPatternType::Unknown;
}
let mut sequential = true;
let mut stride = None;
for i in 1..accesses.len() {
if accesses[i].len() != accesses[i - 1].len() {
sequential = false;
break;
}
let last_dim = accesses[i].len() - 1;
let current_stride = accesses[i][last_dim] as i64 - accesses[i - 1][last_dim] as i64;
if let Some(expected_stride) = stride {
if current_stride != expected_stride {
sequential = false;
break;
}
} else {
stride = Some(current_stride);
}
}
if sequential {
if stride == Some(1) {
return AccessPatternType::Sequential;
} else if stride.is_some() {
return AccessPatternType::Strided;
}
}
AccessPatternType::Random
}
fn prefetch_if_needed(&self, indices: &[usize]) -> Result<()> {
if self.config.prefetch == PrefetchStrategy::None {
return Ok(());
}
let pattern = self
.access_pattern
.lock()
.expect("Access pattern mutex poisoned");
match (self.config.prefetch, pattern.pattern_type) {
(PrefetchStrategy::Sequential, AccessPatternType::Sequential)
| (PrefetchStrategy::Adaptive, AccessPatternType::Sequential) => {
self.prefetch_sequential(indices)?;
}
(PrefetchStrategy::Adaptive, AccessPatternType::Strided) => {
self.prefetch_strided(indices)?;
}
_ => {}
}
Ok(())
}
fn prefetch_sequential(&self, indices: &[usize]) -> Result<()> {
const PREFETCH_SIZE: usize = 8;
let mut next_indices = indices.to_vec();
let last_dim = next_indices.len() - 1;
for i in 1..=PREFETCH_SIZE {
if next_indices[last_dim] + i < self.shape[last_dim] {
next_indices[last_dim] += 1;
let offset = self.calculate_offset(&next_indices)?;
let byte_offset = self.data_offset + offset * mem::size_of::<T>();
if byte_offset + mem::size_of::<T>() <= self.mmap.len() {
let _ = self.mmap[byte_offset];
}
}
}
Ok(())
}
fn prefetch_strided(&self, _indices: &[usize]) -> Result<()> {
Ok(())
}
fn optimize_layout(&mut self) -> Result<()> {
if self.config.layout_strategy == LayoutStrategy::RowMajor {
return Ok(()); }
let data = self.get_all_data()?;
let mut optimized_data = data;
optimize_layout(&mut optimized_data, self.config.layout_strategy);
self.set_all_data(&optimized_data)?;
Ok(())
}
fn get_all_data(&self) -> Result<Vec<T>> {
let mut data = Vec::with_capacity(self.size);
let data_start = self.data_offset;
let element_size = mem::size_of::<T>();
for i in 0..self.size {
let byte_offset = data_start + i * element_size;
if byte_offset + element_size <= self.mmap.len() {
let bytes = &self.mmap[byte_offset..byte_offset + element_size];
let value = unsafe { *(bytes.as_ptr() as *const T) };
data.push(value);
}
}
Ok(data)
}
fn set_all_data(&mut self, data: &[T]) -> Result<()> {
if data.len() != self.size {
return Err(NumRs2Error::InvalidOperation(format!(
"Data size mismatch: expected {}, got {}",
self.size,
data.len()
)));
}
let data_start = self.data_offset;
let element_size = mem::size_of::<T>();
for (i, &value) in data.iter().enumerate() {
let byte_offset = data_start + i * element_size;
if byte_offset + element_size <= self.mmap.len() {
let bytes = unsafe {
std::slice::from_raw_parts(&value as *const T as *const u8, element_size)
};
self.mmap[byte_offset..byte_offset + element_size].copy_from_slice(bytes);
}
}
Ok(())
}
pub fn config(&self) -> &MmapConfig {
&self.config
}
pub fn update_config(&mut self, config: MmapConfig) {
self.config = config;
}
pub fn access_stats(&self) -> Option<(AccessPatternType, usize)> {
if let Ok(pattern) = self.access_pattern.lock() {
Some((pattern.pattern_type, pattern.recent_accesses.len()))
} else {
None
}
}
}
impl<T: Copy + fmt::Debug> fmt::Display for MmapArray<T> {
    /// Prints the shape and file path, then up to the first ten elements in
    /// row-major order. Unreadable elements print as `<?>`, and `...` marks
    /// truncation.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        const MAX_ELEMENTS: usize = 10;
        writeln!(
            f,
            "MmapArray(shape={:?}, file={})",
            self.shape,
            self.path.display()
        )?;
        f.write_str("Data preview: [")?;

        let shown = self.size.min(MAX_ELEMENTS);
        for flat in 0..shown {
            if flat != 0 {
                f.write_str(", ")?;
            }
            // Unravel the flat row-major position into per-axis indices,
            // filling from the last axis backwards.
            let mut coords = vec![0usize; self.shape.len()];
            let mut rest = flat;
            for (slot, &extent) in coords.iter_mut().zip(&self.shape).rev() {
                *slot = rest % extent;
                rest /= extent;
            }
            if let Ok(value) = self.get(&coords) {
                write!(f, "{:?}", value)?;
            } else {
                f.write_str("<?>")?;
            }
        }

        let closing = if self.size > MAX_ELEMENTS { ", ...]" } else { "]" };
        f.write_str(closing)
    }
}
fn calculate_meta_size(_shape: &[usize]) -> usize {
    // Every file reserves the same fixed-size header for its encoded metadata,
    // regardless of shape; element data begins immediately after it.
    const HEADER_BYTES: usize = 1 << 10;
    HEADER_BYTES
}
fn get_page_size() -> usize {
    // Assumed, not queried from the OS: a 4 KiB page.
    const PAGE_BYTES: usize = 4 * 1024;
    PAGE_BYTES
}
/// Rounds `size` up to the next multiple of `page_size`.
///
/// Works for any nonzero `page_size`, not only powers of two. The previous
/// bit-mask form returned wrong results for other values. No intermediate
/// value exceeds the result, so this only overflows if the rounded size
/// itself does not fit in `usize`.
///
/// # Panics
/// If `page_size` is zero (division by zero).
#[allow(dead_code)]
fn align_to_page(size: usize, page_size: usize) -> usize {
    match size % page_size {
        0 => size,
        remainder => size + (page_size - remainder),
    }
}
/// Placeholder for passing access hints (e.g. from `config.prefetch`) to the
/// OS for `mmap`. It currently does nothing on any platform.
#[allow(dead_code)]
fn apply_memory_advice(_mmap: &mut MmapMut, _config: &MmapConfig) {}
/// Returns the shared access tracker for `path`, inserting a fresh default
/// tracker on first use.
///
/// # Panics
/// If the global registry mutex is poisoned.
fn get_or_create_access_pattern(path: &Path) -> Arc<Mutex<AccessPattern>> {
    let mut registry = GLOBAL_MMAP_CACHE
        .lock()
        .expect("Global mmap cache mutex poisoned");
    // `Arc<Mutex<AccessPattern>>::default()` is `Arc::new(Mutex::new(AccessPattern::default()))`.
    let shared = registry.entry(path.to_path_buf()).or_default();
    Arc::clone(shared)
}
pub fn open_mmap_info<P: AsRef<Path>>(path: &P) -> Result<MmapArrayMeta> {
let mut file = File::open(path)?;
let meta_size = 1024; let mut meta_bytes = vec![0u8; meta_size];
file.read_exact(&mut meta_bytes)?;
let config = oxicode::config::standard();
let (meta, _): (MmapArrayMeta, usize) = oxicode::serde::decode_from_slice(&meta_bytes, config)
.map_err(|e| NumRs2Error::DeserializationError(e.to_string()))?;
Ok(meta)
}