1use std::fs::{File, OpenOptions};
7use std::io::{BufReader, Seek};
8use std::marker::PhantomData;
9use std::path::Path;
10
11use memmap2::{Mmap, MmapMut, MmapOptions};
12
13use ferray_core::Array;
14use ferray_core::array::view::ArrayView;
15use ferray_core::dimension::IxDyn;
16use ferray_core::dtype::Element;
17use ferray_core::error::{FerrayError, FerrayResult};
18
19use crate::format::MemmapMode;
20use crate::npy::NpyElement;
21use crate::npy::checked_total_elements;
22use crate::npy::header::{self, NpyHeader};
23
24pub struct MemmapArray<T: Element> {
29 _mmap: Mmap,
31 data_ptr: *const T,
33 shape: Vec<usize>,
35 len: usize,
37 _marker: PhantomData<T>,
39}
40
41unsafe impl<T: Element> Send for MemmapArray<T> {}
44unsafe impl<T: Element> Sync for MemmapArray<T> {}
45
46impl<T: Element> MemmapArray<T> {
47 pub fn shape(&self) -> &[usize] {
49 &self.shape
50 }
51
52 pub fn as_slice(&self) -> &[T] {
54 unsafe { std::slice::from_raw_parts(self.data_ptr, self.len) }
57 }
58
59 pub fn to_array(&self) -> FerrayResult<Array<T, IxDyn>> {
61 let data = self.as_slice().to_vec();
62 Array::from_vec(IxDyn::new(&self.shape), data)
63 }
64
65 pub fn view(&self) -> ArrayView<'_, T, IxDyn> {
70 let ndim = self.shape.len();
72 let mut strides = vec![1usize; ndim];
73 for i in (0..ndim.saturating_sub(1)).rev() {
74 strides[i] = strides[i + 1] * self.shape[i + 1];
75 }
76 unsafe { ArrayView::from_shape_ptr(self.data_ptr, &self.shape, &strides) }
81 }
82}
83
84pub struct MemmapArrayMut<T: Element> {
88 _mmap: MmapMut,
90 data_ptr: *mut T,
92 shape: Vec<usize>,
94 len: usize,
96 _marker: PhantomData<T>,
98}
99
100unsafe impl<T: Element> Send for MemmapArrayMut<T> {}
101unsafe impl<T: Element> Sync for MemmapArrayMut<T> {}
102
103impl<T: Element> MemmapArrayMut<T> {
104 pub fn shape(&self) -> &[usize] {
106 &self.shape
107 }
108
109 pub fn as_slice(&self) -> &[T] {
111 unsafe { std::slice::from_raw_parts(self.data_ptr, self.len) }
112 }
113
114 pub fn as_slice_mut(&mut self) -> &mut [T] {
119 unsafe { std::slice::from_raw_parts_mut(self.data_ptr, self.len) }
120 }
121
122 pub fn to_array(&self) -> FerrayResult<Array<T, IxDyn>> {
124 let data = self.as_slice().to_vec();
125 Array::from_vec(IxDyn::new(&self.shape), data)
126 }
127
128 pub fn view(&self) -> ArrayView<'_, T, IxDyn> {
132 let ndim = self.shape.len();
133 let mut strides = vec![1usize; ndim];
134 for i in (0..ndim.saturating_sub(1)).rev() {
135 strides[i] = strides[i + 1] * self.shape[i + 1];
136 }
137 unsafe { ArrayView::from_shape_ptr(self.data_ptr as *const T, &self.shape, &strides) }
140 }
141
142 pub fn flush(&self) -> FerrayResult<()> {
144 self._mmap
145 .flush()
146 .map_err(|e| FerrayError::io_error(format!("failed to flush mmap: {e}")))
147 }
148}
149
150pub fn memmap_readonly<T: Element + NpyElement, P: AsRef<Path>>(
158 path: P,
159) -> FerrayResult<MemmapArray<T>> {
160 let (header, data_offset) = read_npy_header_with_offset(path.as_ref())?;
161 validate_dtype::<T>(&header)?;
162 validate_native_endian(&header)?;
163
164 let len = checked_total_elements(&header.shape)?;
165 let file = File::open(path.as_ref())?;
166 let mmap = unsafe {
167 MmapOptions::new()
168 .offset(data_offset as u64)
169 .len(len * std::mem::size_of::<T>())
170 .map(&file)
171 .map_err(|e| FerrayError::io_error(format!("mmap failed: {e}")))?
172 };
173 let data_ptr = mmap.as_ptr() as *const T;
174
175 if (data_ptr as usize) % std::mem::align_of::<T>() != 0 {
177 return Err(FerrayError::io_error(
178 "memory-mapped data is not properly aligned for the element type",
179 ));
180 }
181
182 Ok(MemmapArray {
183 _mmap: mmap,
184 data_ptr,
185 shape: header.shape,
186 len,
187 _marker: PhantomData,
188 })
189}
190
191pub fn memmap_mut<T: Element + NpyElement, P: AsRef<Path>>(
202 path: P,
203 mode: MemmapMode,
204) -> FerrayResult<MemmapArrayMut<T>> {
205 if mode == MemmapMode::ReadOnly {
206 return Err(FerrayError::invalid_value(
207 "use memmap_readonly for read-only access",
208 ));
209 }
210
211 let (header, data_offset) = read_npy_header_with_offset(path.as_ref())?;
212 validate_dtype::<T>(&header)?;
213 validate_native_endian(&header)?;
214
215 let len = checked_total_elements(&header.shape)?;
216 let data_bytes = len * std::mem::size_of::<T>();
217
218 let mmap = match mode {
219 MemmapMode::ReadWrite => {
220 let file = OpenOptions::new()
221 .read(true)
222 .write(true)
223 .open(path.as_ref())?;
224 unsafe {
225 MmapOptions::new()
226 .offset(data_offset as u64)
227 .len(data_bytes)
228 .map_mut(&file)
229 .map_err(|e| FerrayError::io_error(format!("mmap_mut failed: {e}")))?
230 }
231 }
232 MemmapMode::CopyOnWrite => {
233 let file = File::open(path.as_ref())?;
234 unsafe {
235 MmapOptions::new()
236 .offset(data_offset as u64)
237 .len(data_bytes)
238 .map_copy(&file)
239 .map_err(|e| FerrayError::io_error(format!("mmap copy-on-write failed: {e}")))?
240 }
241 }
242 MemmapMode::ReadOnly => unreachable!(),
243 };
244
245 let data_ptr = mmap.as_ptr() as *mut T;
246
247 if (data_ptr as usize) % std::mem::align_of::<T>() != 0 {
248 return Err(FerrayError::io_error(
249 "memory-mapped data is not properly aligned for the element type",
250 ));
251 }
252
253 Ok(MemmapArrayMut {
254 _mmap: mmap,
255 data_ptr,
256 shape: header.shape,
257 len,
258 _marker: PhantomData,
259 })
260}
261
262pub fn open_memmap<T: Element + NpyElement, P: AsRef<Path>>(
272 path: P,
273 mode: MemmapMode,
274) -> FerrayResult<Array<T, IxDyn>> {
275 match mode {
276 MemmapMode::ReadOnly => {
277 let mapped = memmap_readonly::<T, _>(path)?;
278 mapped.to_array()
279 }
280 _ => {
281 let mapped = memmap_mut::<T, _>(path, mode)?;
282 mapped.to_array()
283 }
284 }
285}
286
287fn read_npy_header_with_offset(path: &Path) -> FerrayResult<(NpyHeader, usize)> {
292 let file = File::open(path)?;
293 let mut reader = BufReader::new(file);
294 let hdr = header::read_header(&mut reader)?;
295
296 let data_offset = reader
299 .stream_position()
300 .map_err(|e| FerrayError::io_error(format!("failed to get stream position: {e}")))?
301 as usize;
302
303 Ok((hdr, data_offset))
304}
305
306fn validate_dtype<T: Element>(header: &NpyHeader) -> FerrayResult<()> {
307 if header.dtype != T::dtype() {
308 return Err(FerrayError::invalid_dtype(format!(
309 "expected dtype {:?} for type {}, but file has {:?}",
310 T::dtype(),
311 std::any::type_name::<T>(),
312 header.dtype,
313 )));
314 }
315 Ok(())
316}
317
318fn validate_native_endian(header: &NpyHeader) -> FerrayResult<()> {
319 if header.endianness.needs_swap() {
320 return Err(FerrayError::io_error(
321 "memory-mapped arrays require native byte order; file has non-native endianness",
322 ));
323 }
324 Ok(())
325}
326
#[cfg(test)]
mod tests {
    use super::*;
    use crate::npy;
    use ferray_core::dimension::Ix1;

    /// Scratch directory for this test binary, keyed by process id.
    fn test_dir() -> std::path::PathBuf {
        let dir_name = format!("ferray_io_mmap_{}", std::process::id());
        let dir = std::env::temp_dir().join(dir_name);
        // Errors are ignored here; a missing directory would make the later
        // `npy::save(..).unwrap()` fail instead.
        let _ = std::fs::create_dir_all(&dir);
        dir
    }

    /// Path of a named fixture inside the scratch directory.
    fn test_file(name: &str) -> std::path::PathBuf {
        test_dir().join(name)
    }

    #[test]
    fn memmap_readonly_f64() {
        let expected = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let path = test_file("mm_ro_f64.npy");
        let source = Array::<f64, Ix1>::from_vec(Ix1::new([5]), expected.clone()).unwrap();
        npy::save(&path, &source).unwrap();

        let mapped = memmap_readonly::<f64, _>(&path).unwrap();
        assert_eq!(mapped.shape(), &[5]);
        assert_eq!(mapped.as_slice(), expected.as_slice());

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn memmap_to_array() {
        let expected = vec![10i32, 20, 30];
        let path = test_file("mm_to_arr.npy");
        let source = Array::<i32, Ix1>::from_vec(Ix1::new([3]), expected.clone()).unwrap();
        npy::save(&path, &source).unwrap();

        let mapped = memmap_readonly::<i32, _>(&path).unwrap();
        let owned = mapped.to_array().unwrap();
        assert_eq!(owned.shape(), &[3]);
        assert_eq!(owned.as_slice().unwrap(), expected.as_slice());

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn memmap_readwrite_persist() {
        let path = test_file("mm_rw.npy");
        let source = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0_f64, 2.0, 3.0]).unwrap();
        npy::save(&path, &source).unwrap();

        // Write through a read-write mapping; leaving the scope unmaps it.
        {
            let mut mapped = memmap_mut::<f64, _>(&path, MemmapMode::ReadWrite).unwrap();
            mapped.as_slice_mut()[0] = 999.0;
            mapped.flush().unwrap();
        }

        // The change must be visible when the file is read back normally.
        let loaded: Array<f64, Ix1> = npy::load(&path).unwrap();
        let values = loaded.as_slice().unwrap();
        assert_eq!(values[0], 999.0);
        assert_eq!(values[1], 2.0);
        assert_eq!(values[2], 3.0);

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn memmap_copy_on_write() {
        let path = test_file("mm_cow.npy");
        let source = Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0_f64, 2.0, 3.0]).unwrap();
        npy::save(&path, &source).unwrap();

        // The write is visible through the mapping itself...
        {
            let mut mapped = memmap_mut::<f64, _>(&path, MemmapMode::CopyOnWrite).unwrap();
            mapped.as_slice_mut()[0] = 999.0;
            assert_eq!(mapped.as_slice()[0], 999.0);
        }

        // ...but never reaches the file.
        let loaded: Array<f64, Ix1> = npy::load(&path).unwrap();
        let values = loaded.as_slice().unwrap();
        assert_eq!(values[0], 1.0);

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn memmap_wrong_dtype_error() {
        let path = test_file("mm_wrong_dt.npy");
        let source = Array::<f64, Ix1>::from_vec(Ix1::new([2]), vec![1.0_f64, 2.0]).unwrap();
        npy::save(&path, &source).unwrap();

        // The file stores f64, so requesting f32 must be rejected.
        let outcome = memmap_readonly::<f32, _>(&path);
        assert!(outcome.is_err());

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn open_memmap_readonly() {
        let expected = vec![1.0_f64, 2.0, 3.0];
        let path = test_file("mm_open_ro.npy");
        let source = Array::<f64, Ix1>::from_vec(Ix1::new([3]), expected.clone()).unwrap();
        npy::save(&path, &source).unwrap();

        let loaded = open_memmap::<f64, _>(&path, MemmapMode::ReadOnly).unwrap();
        assert_eq!(loaded.shape(), &[3]);
        assert_eq!(loaded.as_slice().unwrap(), expected.as_slice());

        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn memmap_view_borrows_underlying_data() {
        use ferray_core::dimension::Ix2;

        let expected = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let path = test_file("mm_view.npy");
        let source = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), expected.clone()).unwrap();
        npy::save(&path, &source).unwrap();

        // A 2-D view over the mapping must yield elements in row-major order.
        let mapped = memmap_readonly::<f64, _>(&path).unwrap();
        let view = mapped.view();
        assert_eq!(view.shape(), &[2, 3]);
        let flattened: Vec<f64> = view.iter().copied().collect();
        assert_eq!(flattened, expected);

        let _ = std::fs::remove_file(&path);
    }
}