1use std::fs::{File, OpenOptions};
7use std::io::{BufReader, Seek};
8use std::marker::PhantomData;
9use std::path::Path;
10
11use memmap2::{Mmap, MmapMut, MmapOptions};
12
13use ferray_core::Array;
14use ferray_core::dimension::IxDyn;
15use ferray_core::dtype::Element;
16use ferray_core::error::{FerrayError, FerrayResult};
17
18use crate::format::MemmapMode;
19use crate::npy::NpyElement;
20use crate::npy::checked_total_elements;
21use crate::npy::header::{self, NpyHeader};
22
23pub struct MemmapArray<T: Element> {
28 _mmap: Mmap,
30 data_ptr: *const T,
32 shape: Vec<usize>,
34 len: usize,
36 _marker: PhantomData<T>,
38}
39
40unsafe impl<T: Element> Send for MemmapArray<T> {}
43unsafe impl<T: Element> Sync for MemmapArray<T> {}
44
45impl<T: Element> MemmapArray<T> {
46 pub fn shape(&self) -> &[usize] {
48 &self.shape
49 }
50
51 pub fn as_slice(&self) -> &[T] {
53 unsafe { std::slice::from_raw_parts(self.data_ptr, self.len) }
56 }
57
58 pub fn to_array(&self) -> FerrayResult<Array<T, IxDyn>> {
60 let data = self.as_slice().to_vec();
61 Array::from_vec(IxDyn::new(&self.shape), data)
62 }
63}
64
65pub struct MemmapArrayMut<T: Element> {
69 _mmap: MmapMut,
71 data_ptr: *mut T,
73 shape: Vec<usize>,
75 len: usize,
77 _marker: PhantomData<T>,
79}
80
81unsafe impl<T: Element> Send for MemmapArrayMut<T> {}
82unsafe impl<T: Element> Sync for MemmapArrayMut<T> {}
83
84impl<T: Element> MemmapArrayMut<T> {
85 pub fn shape(&self) -> &[usize] {
87 &self.shape
88 }
89
90 pub fn as_slice(&self) -> &[T] {
92 unsafe { std::slice::from_raw_parts(self.data_ptr, self.len) }
93 }
94
95 pub fn as_slice_mut(&mut self) -> &mut [T] {
100 unsafe { std::slice::from_raw_parts_mut(self.data_ptr, self.len) }
101 }
102
103 pub fn to_array(&self) -> FerrayResult<Array<T, IxDyn>> {
105 let data = self.as_slice().to_vec();
106 Array::from_vec(IxDyn::new(&self.shape), data)
107 }
108
109 pub fn flush(&self) -> FerrayResult<()> {
111 self._mmap
112 .flush()
113 .map_err(|e| FerrayError::io_error(format!("failed to flush mmap: {e}")))
114 }
115}
116
117pub fn memmap_readonly<T: Element + NpyElement, P: AsRef<Path>>(
125 path: P,
126) -> FerrayResult<MemmapArray<T>> {
127 let (header, data_offset) = read_npy_header_with_offset(path.as_ref())?;
128 validate_dtype::<T>(&header)?;
129 validate_native_endian(&header)?;
130
131 let len = checked_total_elements(&header.shape)?;
132 let file = File::open(path.as_ref())?;
133 let mmap = unsafe {
134 MmapOptions::new()
135 .offset(data_offset as u64)
136 .len(len * std::mem::size_of::<T>())
137 .map(&file)
138 .map_err(|e| FerrayError::io_error(format!("mmap failed: {e}")))?
139 };
140 let data_ptr = mmap.as_ptr() as *const T;
141
142 if (data_ptr as usize) % std::mem::align_of::<T>() != 0 {
144 return Err(FerrayError::io_error(
145 "memory-mapped data is not properly aligned for the element type",
146 ));
147 }
148
149 Ok(MemmapArray {
150 _mmap: mmap,
151 data_ptr,
152 shape: header.shape,
153 len,
154 _marker: PhantomData,
155 })
156}
157
158pub fn memmap_mut<T: Element + NpyElement, P: AsRef<Path>>(
169 path: P,
170 mode: MemmapMode,
171) -> FerrayResult<MemmapArrayMut<T>> {
172 if mode == MemmapMode::ReadOnly {
173 return Err(FerrayError::invalid_value(
174 "use memmap_readonly for read-only access",
175 ));
176 }
177
178 let (header, data_offset) = read_npy_header_with_offset(path.as_ref())?;
179 validate_dtype::<T>(&header)?;
180 validate_native_endian(&header)?;
181
182 let len = checked_total_elements(&header.shape)?;
183 let data_bytes = len * std::mem::size_of::<T>();
184
185 let mmap = match mode {
186 MemmapMode::ReadWrite => {
187 let file = OpenOptions::new()
188 .read(true)
189 .write(true)
190 .open(path.as_ref())?;
191 unsafe {
192 MmapOptions::new()
193 .offset(data_offset as u64)
194 .len(data_bytes)
195 .map_mut(&file)
196 .map_err(|e| FerrayError::io_error(format!("mmap_mut failed: {e}")))?
197 }
198 }
199 MemmapMode::CopyOnWrite => {
200 let file = File::open(path.as_ref())?;
201 unsafe {
202 MmapOptions::new()
203 .offset(data_offset as u64)
204 .len(data_bytes)
205 .map_copy(&file)
206 .map_err(|e| FerrayError::io_error(format!("mmap copy-on-write failed: {e}")))?
207 }
208 }
209 MemmapMode::ReadOnly => unreachable!(),
210 };
211
212 let data_ptr = mmap.as_ptr() as *mut T;
213
214 if (data_ptr as usize) % std::mem::align_of::<T>() != 0 {
215 return Err(FerrayError::io_error(
216 "memory-mapped data is not properly aligned for the element type",
217 ));
218 }
219
220 Ok(MemmapArrayMut {
221 _mmap: mmap,
222 data_ptr,
223 shape: header.shape,
224 len,
225 _marker: PhantomData,
226 })
227}
228
229pub fn open_memmap<T: Element + NpyElement, P: AsRef<Path>>(
238 path: P,
239 mode: MemmapMode,
240) -> FerrayResult<Array<T, IxDyn>> {
241 match mode {
242 MemmapMode::ReadOnly => {
243 let mapped = memmap_readonly::<T, _>(path)?;
244 mapped.to_array()
245 }
246 _ => {
247 let mapped = memmap_mut::<T, _>(path, mode)?;
248 mapped.to_array()
249 }
250 }
251}
252
253fn read_npy_header_with_offset(path: &Path) -> FerrayResult<(NpyHeader, usize)> {
258 let file = File::open(path)?;
259 let mut reader = BufReader::new(file);
260 let hdr = header::read_header(&mut reader)?;
261
262 let data_offset = reader
265 .stream_position()
266 .map_err(|e| FerrayError::io_error(format!("failed to get stream position: {e}")))?
267 as usize;
268
269 Ok((hdr, data_offset))
270}
271
272fn validate_dtype<T: Element>(header: &NpyHeader) -> FerrayResult<()> {
273 if header.dtype != T::dtype() {
274 return Err(FerrayError::invalid_dtype(format!(
275 "expected dtype {:?} for type {}, but file has {:?}",
276 T::dtype(),
277 std::any::type_name::<T>(),
278 header.dtype,
279 )));
280 }
281 Ok(())
282}
283
284fn validate_native_endian(header: &NpyHeader) -> FerrayResult<()> {
285 if header.endianness.needs_swap() {
286 return Err(FerrayError::io_error(
287 "memory-mapped arrays require native byte order; file has non-native endianness",
288 ));
289 }
290 Ok(())
291}
292
#[cfg(test)]
mod tests {
    use super::*;
    use crate::npy;
    use ferray_core::dimension::Ix1;

    /// Per-process scratch directory shared by the tests in this module.
    fn test_dir() -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!("ferray_io_mmap_{}", std::process::id()));
        let _ = std::fs::create_dir_all(&dir);
        dir
    }

    /// Path of a named scratch file inside the per-process test directory.
    fn test_file(name: &str) -> std::path::PathBuf {
        test_dir().join(name)
    }

    /// Removes the wrapped file when dropped, including while unwinding from a
    /// failed assertion, so failing tests do not leave files behind.
    ///
    /// Bind the guard *before* any mapping of the file. Locals drop in reverse
    /// declaration order, so the mapping is released before the file is
    /// deleted; some platforms refuse to delete a file that is still mapped.
    struct RemoveOnDrop<'a>(&'a Path);

    impl Drop for RemoveOnDrop<'_> {
        fn drop(&mut self) {
            // Best effort: the file may not exist if the test failed early.
            let _ = std::fs::remove_file(self.0);
        }
    }

    #[test]
    fn memmap_readonly_f64() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([5]), data.clone()).unwrap();

        let path = test_file("mm_ro_f64.npy");
        let _cleanup = RemoveOnDrop(path.as_path());
        npy::save(&path, &arr).unwrap();

        let mapped = memmap_readonly::<f64, _>(&path).unwrap();
        assert_eq!(mapped.shape(), &[5]);
        assert_eq!(mapped.as_slice(), &data[..]);
    }

    #[test]
    fn memmap_to_array() {
        let data = vec![10i32, 20, 30];
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();

        let path = test_file("mm_to_arr.npy");
        let _cleanup = RemoveOnDrop(path.as_path());
        npy::save(&path, &arr).unwrap();

        let mapped = memmap_readonly::<i32, _>(&path).unwrap();
        let owned = mapped.to_array().unwrap();
        assert_eq!(owned.shape(), &[3]);
        assert_eq!(owned.as_slice().unwrap(), &data[..]);
    }

    #[test]
    fn memmap_readwrite_persist() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();

        let path = test_file("mm_rw.npy");
        let _cleanup = RemoveOnDrop(path.as_path());
        npy::save(&path, &arr).unwrap();

        {
            // Modify through a read-write mapping; the write must reach the file.
            let mut mapped = memmap_mut::<f64, _>(&path, MemmapMode::ReadWrite).unwrap();
            mapped.as_slice_mut()[0] = 999.0;
            mapped.flush().unwrap();
        }

        let loaded: Array<f64, Ix1> = npy::load(&path).unwrap();
        let values = loaded.as_slice().unwrap();
        assert_eq!(values[0], 999.0);
        assert_eq!(values[1], 2.0);
        assert_eq!(values[2], 3.0);
    }

    #[test]
    fn memmap_copy_on_write() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();

        let path = test_file("mm_cow.npy");
        let _cleanup = RemoveOnDrop(path.as_path());
        npy::save(&path, &arr).unwrap();

        {
            // The write is visible through the mapping but must not reach the file.
            let mut mapped = memmap_mut::<f64, _>(&path, MemmapMode::CopyOnWrite).unwrap();
            mapped.as_slice_mut()[0] = 999.0;
            assert_eq!(mapped.as_slice()[0], 999.0);
        }

        let loaded: Array<f64, Ix1> = npy::load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap()[0], 1.0);
    }

    #[test]
    fn memmap_wrong_dtype_error() {
        let data = vec![1.0_f64, 2.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([2]), data).unwrap();

        let path = test_file("mm_wrong_dt.npy");
        let _cleanup = RemoveOnDrop(path.as_path());
        npy::save(&path, &arr).unwrap();

        let result = memmap_readonly::<f32, _>(&path);
        assert!(result.is_err());
    }

    #[test]
    fn memmap_mut_rejects_readonly_mode() {
        // The mode is checked before the file is touched, so no file is needed.
        let path = test_file("mm_mut_ro_never_created.npy");
        let result = memmap_mut::<f64, _>(&path, MemmapMode::ReadOnly);
        assert!(result.is_err());
    }

    #[test]
    fn open_memmap_readonly() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();

        let path = test_file("mm_open_ro.npy");
        let _cleanup = RemoveOnDrop(path.as_path());
        npy::save(&path, &arr).unwrap();

        let loaded = open_memmap::<f64, _>(&path, MemmapMode::ReadOnly).unwrap();
        assert_eq!(loaded.shape(), &[3]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
    }
}