use memmap2::{Mmap, MmapOptions};
use std::fs::File;
use std::path::Path;
use crate::error::{Result, SynaError};
/// Read-only accessor over a memory-mapped file.
///
/// Wraps an [`Mmap`] and hands out borrowed views of the mapped bytes, either
/// raw or reinterpreted as slices of primitive numeric types.
#[derive(Debug)]
pub struct MmapReader {
    /// Mapping of the file, created by [`MmapReader::open`].
    mmap: Mmap,
}
impl MmapReader {
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self> {
let file = File::open(path)?;
let mmap = unsafe { MmapOptions::new().map(&file)? };
Ok(Self { mmap })
}
#[inline]
pub fn len(&self) -> usize {
self.mmap.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.mmap.is_empty()
}
#[inline]
pub fn slice(&self, offset: usize, len: usize) -> &[u8] {
&self.mmap[offset..offset + len]
}
#[inline]
pub fn try_slice(&self, offset: usize, len: usize) -> Option<&[u8]> {
let end = offset.checked_add(len)?;
if end <= self.mmap.len() {
Some(&self.mmap[offset..end])
} else {
None
}
}
#[inline]
pub fn as_f32_slice(&self, offset: usize, count: usize) -> &[f32] {
let byte_len = count * std::mem::size_of::<f32>();
let bytes = &self.mmap[offset..offset + byte_len];
let (prefix, floats, _) = unsafe { bytes.align_to::<f32>() };
debug_assert!(
prefix.is_empty(),
"mmap data misaligned for f32 at offset {}",
offset
);
&floats[..count]
}
pub fn try_as_f32_slice(&self, offset: usize, count: usize) -> Result<&[f32]> {
let byte_len =
count
.checked_mul(std::mem::size_of::<f32>())
.ok_or(SynaError::ShapeMismatch {
data_size: usize::MAX,
expected_size: 0,
})?;
let end = offset
.checked_add(byte_len)
.ok_or_else(|| SynaError::ShapeMismatch {
data_size: usize::MAX,
expected_size: self.mmap.len(),
})?;
if end > self.mmap.len() {
return Err(SynaError::ShapeMismatch {
data_size: end,
expected_size: self.mmap.len(),
});
}
let bytes = &self.mmap[offset..end];
let (prefix, floats, _) = unsafe { bytes.align_to::<f32>() };
if !prefix.is_empty() {
return Err(SynaError::ShapeMismatch {
data_size: offset,
expected_size: 0, });
}
Ok(&floats[..count])
}
#[inline]
pub fn as_f64_slice(&self, offset: usize, count: usize) -> &[f64] {
let byte_len = count * std::mem::size_of::<f64>();
let bytes = &self.mmap[offset..offset + byte_len];
let (prefix, doubles, _) = unsafe { bytes.align_to::<f64>() };
debug_assert!(
prefix.is_empty(),
"mmap data misaligned for f64 at offset {}",
offset
);
&doubles[..count]
}
pub fn try_as_f64_slice(&self, offset: usize, count: usize) -> Result<&[f64]> {
let byte_len =
count
.checked_mul(std::mem::size_of::<f64>())
.ok_or(SynaError::ShapeMismatch {
data_size: usize::MAX,
expected_size: 0,
})?;
let end = offset
.checked_add(byte_len)
.ok_or_else(|| SynaError::ShapeMismatch {
data_size: usize::MAX,
expected_size: self.mmap.len(),
})?;
if end > self.mmap.len() {
return Err(SynaError::ShapeMismatch {
data_size: end,
expected_size: self.mmap.len(),
});
}
let bytes = &self.mmap[offset..end];
let (prefix, doubles, _) = unsafe { bytes.align_to::<f64>() };
if !prefix.is_empty() {
return Err(SynaError::ShapeMismatch {
data_size: offset,
expected_size: 0,
});
}
Ok(&doubles[..count])
}
#[inline]
pub fn as_i32_slice(&self, offset: usize, count: usize) -> &[i32] {
let byte_len = count * std::mem::size_of::<i32>();
let bytes = &self.mmap[offset..offset + byte_len];
let (prefix, ints, _) = unsafe { bytes.align_to::<i32>() };
debug_assert!(
prefix.is_empty(),
"mmap data misaligned for i32 at offset {}",
offset
);
&ints[..count]
}
#[inline]
pub fn as_i64_slice(&self, offset: usize, count: usize) -> &[i64] {
let byte_len = count * std::mem::size_of::<i64>();
let bytes = &self.mmap[offset..offset + byte_len];
let (prefix, longs, _) = unsafe { bytes.align_to::<i64>() };
debug_assert!(
prefix.is_empty(),
"mmap data misaligned for i64 at offset {}",
offset
);
&longs[..count]
}
#[inline]
pub fn as_ptr(&self) -> *const u8 {
self.mmap.as_ptr()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    // The typed tests serialize values with `to_le_bytes` while the reader
    // reinterprets the mapped bytes in native order, so they assume a
    // little-endian host.

    /// Writes `bytes` to a fresh temporary file, flushes it, and maps it.
    ///
    /// The file handle is returned next to the reader so the temporary file
    /// stays alive for as long as the test uses the mapping.
    fn reader_over(bytes: &[u8]) -> (NamedTempFile, MmapReader) {
        let mut file = NamedTempFile::new().unwrap();
        file.write_all(bytes).unwrap();
        file.flush().unwrap();
        let reader = MmapReader::open(file.path()).unwrap();
        (file, reader)
    }

    #[test]
    fn test_mmap_reader_open() {
        let (_file, reader) = reader_over(b"Hello, World!");
        assert_eq!(reader.len(), 13);
        assert!(!reader.is_empty());
    }

    #[test]
    fn test_mmap_reader_slice() {
        let (_file, reader) = reader_over(b"Hello, World!");
        assert_eq!(reader.slice(0, 5), b"Hello");
        assert_eq!(reader.slice(7, 5), b"World");
    }

    #[test]
    fn test_mmap_reader_try_slice() {
        let (_file, reader) = reader_over(b"Hello");
        assert!(reader.try_slice(0, 5).is_some());
        assert!(reader.try_slice(0, 100).is_none());
        assert!(reader.try_slice(10, 1).is_none());
    }

    #[test]
    fn test_mmap_reader_try_slice_overflow() {
        let (_file, reader) = reader_over(b"Hello");
        // `offset + len` would wrap around; this must be rejected, not panic.
        assert!(reader.try_slice(usize::MAX, 1).is_none());
        assert!(reader.try_slice(1, usize::MAX).is_none());
        // An empty range ending exactly at the end of the mapping is valid.
        assert_eq!(reader.try_slice(5, 0), Some(&b""[..]));
    }

    #[test]
    fn test_mmap_reader_f32_slice() {
        let bytes: Vec<u8> = [1.0f32, 2.0, 3.0, 4.0]
            .iter()
            .flat_map(|v| v.to_le_bytes())
            .collect();
        let (_file, reader) = reader_over(&bytes);
        let slice = reader.as_f32_slice(0, 4);
        assert_eq!(slice.len(), 4);
        assert_eq!(slice[0], 1.0);
        assert_eq!(slice[1], 2.0);
        assert_eq!(slice[2], 3.0);
        assert_eq!(slice[3], 4.0);
    }

    #[test]
    fn test_mmap_reader_f64_slice() {
        let bytes: Vec<u8> = [1.5f64, 2.5, 3.5]
            .iter()
            .flat_map(|v| v.to_le_bytes())
            .collect();
        let (_file, reader) = reader_over(&bytes);
        let slice = reader.as_f64_slice(0, 3);
        assert_eq!(slice.len(), 3);
        assert_eq!(slice[0], 1.5);
        assert_eq!(slice[1], 2.5);
        assert_eq!(slice[2], 3.5);
    }

    #[test]
    fn test_mmap_reader_try_f32_slice_bounds() {
        let bytes: Vec<u8> = [1.0f32, 2.0]
            .iter()
            .flat_map(|v| v.to_le_bytes())
            .collect();
        let (_file, reader) = reader_over(&bytes);
        assert!(reader.try_as_f32_slice(0, 2).is_ok());
        assert!(reader.try_as_f32_slice(0, 100).is_err());
    }

    #[test]
    fn test_mmap_reader_try_f64_slice_bounds() {
        let bytes: Vec<u8> = [1.0f64, 2.0]
            .iter()
            .flat_map(|v| v.to_le_bytes())
            .collect();
        let (_file, reader) = reader_over(&bytes);
        assert!(reader.try_as_f64_slice(0, 2).is_ok());
        assert!(reader.try_as_f64_slice(0, 100).is_err());
    }

    #[test]
    fn test_mmap_reader_try_typed_slice_overflow() {
        let (_file, reader) = reader_over(&[0u8; 16]);
        // `count * size_of::<T>()` overflows.
        assert!(reader.try_as_f32_slice(0, usize::MAX).is_err());
        assert!(reader.try_as_f64_slice(0, usize::MAX).is_err());
        // `offset + byte_len` overflows.
        assert!(reader.try_as_f32_slice(usize::MAX, 1).is_err());
        assert!(reader.try_as_f64_slice(usize::MAX, 1).is_err());
    }

    #[test]
    fn test_mmap_reader_try_typed_slice_misaligned() {
        let (_file, reader) = reader_over(&[0u8; 16]);
        // Like the offset-0 tests above, this relies on the mapping starting at
        // an address aligned for the element type, so offset 1 is misaligned.
        assert!(reader.try_as_f32_slice(1, 1).is_err());
        assert!(reader.try_as_f64_slice(1, 1).is_err());
        assert!(reader.try_as_f32_slice(4, 1).is_ok());
        assert!(reader.try_as_f64_slice(8, 1).is_ok());
    }

    #[test]
    fn test_mmap_reader_i32_slice() {
        let bytes: Vec<u8> = [10i32, 20, 30]
            .iter()
            .flat_map(|v| v.to_le_bytes())
            .collect();
        let (_file, reader) = reader_over(&bytes);
        let slice = reader.as_i32_slice(0, 3);
        assert_eq!(slice, &[10, 20, 30]);
    }

    #[test]
    fn test_mmap_reader_i64_slice() {
        let bytes: Vec<u8> = [100i64, 200, 300]
            .iter()
            .flat_map(|v| v.to_le_bytes())
            .collect();
        let (_file, reader) = reader_over(&bytes);
        let slice = reader.as_i64_slice(0, 3);
        assert_eq!(slice, &[100, 200, 300]);
    }

    #[test]
    fn test_mmap_reader_offset_access() {
        // Eight padding bytes, then three floats starting at byte offset 8.
        let mut bytes = vec![0u8; 8];
        bytes.extend([1.0f32, 2.0, 3.0].iter().flat_map(|v| v.to_le_bytes()));
        let (_file, reader) = reader_over(&bytes);
        let slice = reader.as_f32_slice(8, 3);
        assert_eq!(slice, &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_mmap_reader_empty_file() {
        let (_file, reader) = reader_over(&[]);
        assert_eq!(reader.len(), 0);
        assert!(reader.is_empty());
    }

    #[test]
    fn test_mmap_reader_as_ptr() {
        let (_file, reader) = reader_over(b"test");
        let ptr = reader.as_ptr();
        // SAFETY: the mapping holds four bytes and `reader` is still alive, so
        // reading the first byte through `ptr` is in bounds.
        unsafe {
            assert_eq!(*ptr, b't');
        }
    }
}