next_plaid/
mmap.rs

1//! Memory-mapped file support for efficient large index loading.
2//!
3//! This module provides utilities for loading large arrays from disk using
4//! memory-mapped files, avoiding the need to load entire arrays into RAM.
5//!
6//! Two formats are supported:
7//! - Custom raw binary format (legacy): 8-byte header with shape, then raw data
8//! - NPY format: Standard NumPy format with header, used for index files
9
10use std::collections::HashMap;
11use std::fs;
12use std::fs::File;
13use std::io::{BufReader, BufWriter, Write};
14use std::path::Path;
15
16use byteorder::{LittleEndian, ReadBytesExt};
17use memmap2::Mmap;
18use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
19
20use crate::error::{Error, Result};
21
/// A memory-mapped array of f32 values.
///
/// This struct provides zero-copy access to large arrays stored on disk.
pub struct MmapArray2F32 {
    // Owns the mapping; the accessors index into it. The leading underscore
    // is historical — the field IS read by `row`/`load_rows`.
    _mmap: Mmap,
    // (nrows, ncols) as declared by the 16-byte file header.
    shape: (usize, usize),
    // Byte offset of the first f32 element (16 = raw-format header size).
    data_offset: usize,
}
30
31impl MmapArray2F32 {
32    /// Load a 2D f32 array from a raw binary file.
33    ///
34    /// The file format is:
35    /// - 8 bytes: nrows (i64 little-endian)
36    /// - 8 bytes: ncols (i64 little-endian)
37    /// - nrows * ncols * 4 bytes: f32 data (little-endian)
38    pub fn from_raw_file(path: &Path) -> Result<Self> {
39        let file = File::open(path)
40            .map_err(|e| Error::IndexLoad(format!("Failed to open file {:?}: {}", path, e)))?;
41
42        let mmap = unsafe {
43            Mmap::map(&file)
44                .map_err(|e| Error::IndexLoad(format!("Failed to mmap file {:?}: {}", path, e)))?
45        };
46
47        if mmap.len() < 16 {
48            return Err(Error::IndexLoad("File too small for header".into()));
49        }
50
51        // Read shape from header
52        let mut cursor = std::io::Cursor::new(&mmap[..16]);
53        let nrows = cursor
54            .read_i64::<LittleEndian>()
55            .map_err(|e| Error::IndexLoad(format!("Failed to read nrows: {}", e)))?
56            as usize;
57        let ncols = cursor
58            .read_i64::<LittleEndian>()
59            .map_err(|e| Error::IndexLoad(format!("Failed to read ncols: {}", e)))?
60            as usize;
61
62        let expected_size = 16 + nrows * ncols * 4;
63        if mmap.len() < expected_size {
64            return Err(Error::IndexLoad(format!(
65                "File size {} too small for shape ({}, {})",
66                mmap.len(),
67                nrows,
68                ncols
69            )));
70        }
71
72        Ok(Self {
73            _mmap: mmap,
74            shape: (nrows, ncols),
75            data_offset: 16,
76        })
77    }
78
79    /// Get the shape of the array.
80    pub fn shape(&self) -> (usize, usize) {
81        self.shape
82    }
83
84    /// Get the number of rows.
85    pub fn nrows(&self) -> usize {
86        self.shape.0
87    }
88
89    /// Get the number of columns.
90    pub fn ncols(&self) -> usize {
91        self.shape.1
92    }
93
94    /// Get a view of a row.
95    pub fn row(&self, idx: usize) -> ArrayView1<'_, f32> {
96        let start = self.data_offset + idx * self.shape.1 * 4;
97        let bytes = &self._mmap[start..start + self.shape.1 * 4];
98
99        // Safety: We've verified the bounds and alignment
100        let data =
101            unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, self.shape.1) };
102
103        ArrayView1::from_shape(self.shape.1, data).unwrap()
104    }
105
106    /// Load a range of rows into an owned Array2.
107    pub fn load_rows(&self, start: usize, end: usize) -> Array2<f32> {
108        let nrows = end - start;
109        let byte_start = self.data_offset + start * self.shape.1 * 4;
110        let byte_end = self.data_offset + end * self.shape.1 * 4;
111        let bytes = &self._mmap[byte_start..byte_end];
112
113        // Safety: We've verified the bounds
114        let data = unsafe {
115            std::slice::from_raw_parts(bytes.as_ptr() as *const f32, nrows * self.shape.1)
116        };
117
118        Array2::from_shape_vec((nrows, self.shape.1), data.to_vec()).unwrap()
119    }
120
121    /// Convert to an owned Array2 (loads all data into memory).
122    pub fn to_owned(&self) -> Array2<f32> {
123        self.load_rows(0, self.shape.0)
124    }
125}
126
/// A memory-mapped array of u8 values.
pub struct MmapArray2U8 {
    // Owns the mapping; the accessors index into it.
    _mmap: Mmap,
    // (nrows, ncols) as declared by the 16-byte file header.
    shape: (usize, usize),
    // Byte offset of the first data byte (16 = raw-format header size).
    data_offset: usize,
}
133
134impl MmapArray2U8 {
135    /// Load a 2D u8 array from a raw binary file.
136    pub fn from_raw_file(path: &Path) -> Result<Self> {
137        let file = File::open(path)
138            .map_err(|e| Error::IndexLoad(format!("Failed to open file {:?}: {}", path, e)))?;
139
140        let mmap = unsafe {
141            Mmap::map(&file)
142                .map_err(|e| Error::IndexLoad(format!("Failed to mmap file {:?}: {}", path, e)))?
143        };
144
145        if mmap.len() < 16 {
146            return Err(Error::IndexLoad("File too small for header".into()));
147        }
148
149        let mut cursor = std::io::Cursor::new(&mmap[..16]);
150        let nrows = cursor
151            .read_i64::<LittleEndian>()
152            .map_err(|e| Error::IndexLoad(format!("Failed to read nrows: {}", e)))?
153            as usize;
154        let ncols = cursor
155            .read_i64::<LittleEndian>()
156            .map_err(|e| Error::IndexLoad(format!("Failed to read ncols: {}", e)))?
157            as usize;
158
159        let expected_size = 16 + nrows * ncols;
160        if mmap.len() < expected_size {
161            return Err(Error::IndexLoad(format!(
162                "File size {} too small for shape ({}, {})",
163                mmap.len(),
164                nrows,
165                ncols
166            )));
167        }
168
169        Ok(Self {
170            _mmap: mmap,
171            shape: (nrows, ncols),
172            data_offset: 16,
173        })
174    }
175
176    /// Get the shape of the array.
177    pub fn shape(&self) -> (usize, usize) {
178        self.shape
179    }
180
181    /// Get a view of the data as ArrayView2.
182    pub fn view(&self) -> ArrayView2<'_, u8> {
183        let bytes = &self._mmap[self.data_offset..self.data_offset + self.shape.0 * self.shape.1];
184        ArrayView2::from_shape(self.shape, bytes).unwrap()
185    }
186
187    /// Load a range of rows into an owned Array2.
188    pub fn load_rows(&self, start: usize, end: usize) -> Array2<u8> {
189        let nrows = end - start;
190        let byte_start = self.data_offset + start * self.shape.1;
191        let byte_end = self.data_offset + end * self.shape.1;
192        let bytes = &self._mmap[byte_start..byte_end];
193
194        Array2::from_shape_vec((nrows, self.shape.1), bytes.to_vec()).unwrap()
195    }
196
197    /// Convert to an owned Array2.
198    pub fn to_owned(&self) -> Array2<u8> {
199        self.load_rows(0, self.shape.0)
200    }
201}
202
/// A memory-mapped array of i64 values.
pub struct MmapArray1I64 {
    // Owns the mapping; the accessors index into it.
    _mmap: Mmap,
    // Element count as declared by the 8-byte file header.
    len: usize,
    // Byte offset of the first i64 element (8 = raw-format header size).
    data_offset: usize,
}
209
210impl MmapArray1I64 {
211    /// Load a 1D i64 array from a raw binary file.
212    pub fn from_raw_file(path: &Path) -> Result<Self> {
213        let file = File::open(path)
214            .map_err(|e| Error::IndexLoad(format!("Failed to open file {:?}: {}", path, e)))?;
215
216        let mmap = unsafe {
217            Mmap::map(&file)
218                .map_err(|e| Error::IndexLoad(format!("Failed to mmap file {:?}: {}", path, e)))?
219        };
220
221        if mmap.len() < 8 {
222            return Err(Error::IndexLoad("File too small for header".into()));
223        }
224
225        let mut cursor = std::io::Cursor::new(&mmap[..8]);
226        let len = cursor
227            .read_i64::<LittleEndian>()
228            .map_err(|e| Error::IndexLoad(format!("Failed to read length: {}", e)))?
229            as usize;
230
231        let expected_size = 8 + len * 8;
232        if mmap.len() < expected_size {
233            return Err(Error::IndexLoad(format!(
234                "File size {} too small for length {}",
235                mmap.len(),
236                len
237            )));
238        }
239
240        Ok(Self {
241            _mmap: mmap,
242            len,
243            data_offset: 8,
244        })
245    }
246
247    /// Get the length of the array.
248    pub fn len(&self) -> usize {
249        self.len
250    }
251
252    /// Returns true if the array is empty.
253    pub fn is_empty(&self) -> bool {
254        self.len == 0
255    }
256
257    /// Get a value at an index.
258    pub fn get(&self, idx: usize) -> i64 {
259        let start = self.data_offset + idx * 8;
260        let bytes = &self._mmap[start..start + 8];
261        i64::from_le_bytes(bytes.try_into().unwrap())
262    }
263
264    /// Convert to an owned Array1.
265    pub fn to_owned(&self) -> Array1<i64> {
266        let bytes = &self._mmap[self.data_offset..self.data_offset + self.len * 8];
267
268        // Safety: We've verified the bounds
269        let data = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const i64, self.len) };
270
271        Array1::from_vec(data.to_vec())
272    }
273}
274
275/// Write an `Array2<f32>` to a raw binary file format.
276pub fn write_array2_f32(array: &Array2<f32>, path: &Path) -> Result<()> {
277    use std::io::Write;
278
279    let file = File::create(path)
280        .map_err(|e| Error::IndexLoad(format!("Failed to create file {:?}: {}", path, e)))?;
281    let mut writer = std::io::BufWriter::new(file);
282
283    let nrows = array.nrows() as i64;
284    let ncols = array.ncols() as i64;
285
286    writer
287        .write_all(&nrows.to_le_bytes())
288        .map_err(|e| Error::IndexLoad(format!("Failed to write nrows: {}", e)))?;
289    writer
290        .write_all(&ncols.to_le_bytes())
291        .map_err(|e| Error::IndexLoad(format!("Failed to write ncols: {}", e)))?;
292
293    for val in array.iter() {
294        writer
295            .write_all(&val.to_le_bytes())
296            .map_err(|e| Error::IndexLoad(format!("Failed to write data: {}", e)))?;
297    }
298
299    writer
300        .flush()
301        .map_err(|e| Error::IndexLoad(format!("Failed to flush: {}", e)))?;
302
303    Ok(())
304}
305
306/// Write an `Array2<u8>` to a raw binary file format.
307pub fn write_array2_u8(array: &Array2<u8>, path: &Path) -> Result<()> {
308    use std::io::Write;
309
310    let file = File::create(path)
311        .map_err(|e| Error::IndexLoad(format!("Failed to create file {:?}: {}", path, e)))?;
312    let mut writer = std::io::BufWriter::new(file);
313
314    let nrows = array.nrows() as i64;
315    let ncols = array.ncols() as i64;
316
317    writer
318        .write_all(&nrows.to_le_bytes())
319        .map_err(|e| Error::IndexLoad(format!("Failed to write nrows: {}", e)))?;
320    writer
321        .write_all(&ncols.to_le_bytes())
322        .map_err(|e| Error::IndexLoad(format!("Failed to write ncols: {}", e)))?;
323
324    for row in array.rows() {
325        writer
326            .write_all(row.as_slice().unwrap())
327            .map_err(|e| Error::IndexLoad(format!("Failed to write data: {}", e)))?;
328    }
329
330    writer
331        .flush()
332        .map_err(|e| Error::IndexLoad(format!("Failed to flush: {}", e)))?;
333
334    Ok(())
335}
336
337/// Write an `Array1<i64>` to a raw binary file format.
338pub fn write_array1_i64(array: &Array1<i64>, path: &Path) -> Result<()> {
339    use std::io::Write;
340
341    let file = File::create(path)
342        .map_err(|e| Error::IndexLoad(format!("Failed to create file {:?}: {}", path, e)))?;
343    let mut writer = std::io::BufWriter::new(file);
344
345    let len = array.len() as i64;
346
347    writer
348        .write_all(&len.to_le_bytes())
349        .map_err(|e| Error::IndexLoad(format!("Failed to write length: {}", e)))?;
350
351    for val in array.iter() {
352        writer
353            .write_all(&val.to_le_bytes())
354            .map_err(|e| Error::IndexLoad(format!("Failed to write data: {}", e)))?;
355    }
356
357    writer
358        .flush()
359        .map_err(|e| Error::IndexLoad(format!("Failed to flush: {}", e)))?;
360
361    Ok(())
362}
363
364// ============================================================================
365// NPY Format Memory-Mapped Arrays
366// ============================================================================
367
/// NPY file magic bytes: `\x93NUMPY` — the first 6 bytes of every NPY file.
const NPY_MAGIC: &[u8] = b"\x93NUMPY";
370
371/// Parse dtype from NPY header string (e.g., "<f2" for float16, "<f4" for float32)
372fn parse_dtype_from_header(header: &str) -> Result<String> {
373    // Find 'descr': '...'
374    let descr_start = header
375        .find("'descr':")
376        .ok_or_else(|| Error::IndexLoad("No descr in NPY header".into()))?;
377
378    let after_descr = &header[descr_start + 8..];
379    let quote_start = after_descr
380        .find('\'')
381        .ok_or_else(|| Error::IndexLoad("No dtype quote in NPY header".into()))?;
382    let rest = &after_descr[quote_start + 1..];
383    let quote_end = rest
384        .find('\'')
385        .ok_or_else(|| Error::IndexLoad("Unclosed dtype quote in NPY header".into()))?;
386
387    Ok(rest[..quote_end].to_string())
388}
389
390/// Detect NPY file dtype without loading the entire file
391pub fn detect_npy_dtype(path: &Path) -> Result<String> {
392    let file = File::open(path)
393        .map_err(|e| Error::IndexLoad(format!("Failed to open NPY file {:?}: {}", path, e)))?;
394
395    let mmap = unsafe {
396        Mmap::map(&file)
397            .map_err(|e| Error::IndexLoad(format!("Failed to mmap NPY file {:?}: {}", path, e)))?
398    };
399
400    if mmap.len() < 10 {
401        return Err(Error::IndexLoad("NPY file too small".into()));
402    }
403
404    // Check magic
405    if &mmap[..6] != NPY_MAGIC {
406        return Err(Error::IndexLoad("Invalid NPY magic".into()));
407    }
408
409    let major_version = mmap[6];
410
411    // Read header length
412    let header_len = if major_version == 1 {
413        u16::from_le_bytes([mmap[8], mmap[9]]) as usize
414    } else if major_version == 2 {
415        if mmap.len() < 12 {
416            return Err(Error::IndexLoad("NPY v2 file too small".into()));
417        }
418        u32::from_le_bytes([mmap[8], mmap[9], mmap[10], mmap[11]]) as usize
419    } else {
420        return Err(Error::IndexLoad(format!(
421            "Unsupported NPY version: {}",
422            major_version
423        )));
424    };
425
426    let header_start = if major_version == 1 { 10 } else { 12 };
427    let header_end = header_start + header_len;
428
429    if mmap.len() < header_end {
430        return Err(Error::IndexLoad("NPY header exceeds file size".into()));
431    }
432
433    let header_str = std::str::from_utf8(&mmap[header_start..header_end])
434        .map_err(|e| Error::IndexLoad(format!("Invalid NPY header encoding: {}", e)))?;
435
436    parse_dtype_from_header(header_str)
437}
438
439/// Convert a float16 NPY file to float32 in place
440pub fn convert_f16_to_f32_npy(path: &Path) -> Result<()> {
441    use half::f16;
442    use std::io::Read;
443
444    // Read the entire file
445    let mut file = File::open(path)
446        .map_err(|e| Error::IndexLoad(format!("Failed to open {:?}: {}", path, e)))?;
447    let mut data = Vec::new();
448    file.read_to_end(&mut data)
449        .map_err(|e| Error::IndexLoad(format!("Failed to read {:?}: {}", path, e)))?;
450
451    if data.len() < 10 || &data[..6] != NPY_MAGIC {
452        return Err(Error::IndexLoad("Invalid NPY file".into()));
453    }
454
455    let major_version = data[6];
456    let header_start = if major_version == 1 { 10 } else { 12 };
457    let header_len = if major_version == 1 {
458        u16::from_le_bytes([data[8], data[9]]) as usize
459    } else {
460        u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize
461    };
462    let header_end = header_start + header_len;
463
464    // Parse header to get shape
465    let header_str = std::str::from_utf8(&data[header_start..header_end])
466        .map_err(|e| Error::IndexLoad(format!("Invalid header: {}", e)))?;
467    let shape = parse_shape_from_header(header_str)?;
468
469    // Calculate total elements
470    let total_elements: usize = shape.iter().product();
471    let f16_data = &data[header_end..header_end + total_elements * 2];
472
473    // Convert f16 to f32
474    let mut f32_data = Vec::with_capacity(total_elements * 4);
475    for chunk in f16_data.chunks(2) {
476        let f16_val = f16::from_le_bytes([chunk[0], chunk[1]]);
477        let f32_val: f32 = f16_val.to_f32();
478        f32_data.extend_from_slice(&f32_val.to_le_bytes());
479    }
480
481    // Write new file with f32 dtype
482    let file = File::create(path)
483        .map_err(|e| Error::IndexLoad(format!("Failed to create {:?}: {}", path, e)))?;
484    let mut writer = BufWriter::new(file);
485
486    if shape.len() == 1 {
487        write_npy_header_1d(&mut writer, shape[0], "<f4")?;
488    } else if shape.len() == 2 {
489        write_npy_header_2d(&mut writer, shape[0], shape[1], "<f4")?;
490    } else {
491        return Err(Error::IndexLoad("Unsupported shape dimensions".into()));
492    }
493
494    writer
495        .write_all(&f32_data)
496        .map_err(|e| Error::IndexLoad(format!("Failed to write data: {}", e)))?;
497    writer.flush()?;
498
499    Ok(())
500}
501
502/// Convert an int64 NPY file to int32 in place
503pub fn convert_i64_to_i32_npy(path: &Path) -> Result<()> {
504    use std::io::Read;
505
506    // Read the entire file
507    let mut file = File::open(path)
508        .map_err(|e| Error::IndexLoad(format!("Failed to open {:?}: {}", path, e)))?;
509    let mut data = Vec::new();
510    file.read_to_end(&mut data)
511        .map_err(|e| Error::IndexLoad(format!("Failed to read {:?}: {}", path, e)))?;
512
513    if data.len() < 10 || &data[..6] != NPY_MAGIC {
514        return Err(Error::IndexLoad("Invalid NPY file".into()));
515    }
516
517    let major_version = data[6];
518    let header_start = if major_version == 1 { 10 } else { 12 };
519    let header_len = if major_version == 1 {
520        u16::from_le_bytes([data[8], data[9]]) as usize
521    } else {
522        u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize
523    };
524    let header_end = header_start + header_len;
525
526    // Parse header to get shape
527    let header_str = std::str::from_utf8(&data[header_start..header_end])
528        .map_err(|e| Error::IndexLoad(format!("Invalid header: {}", e)))?;
529    let shape = parse_shape_from_header(header_str)?;
530
531    if shape.len() != 1 {
532        return Err(Error::IndexLoad("Expected 1D array for i64->i32".into()));
533    }
534
535    let len = shape[0];
536    let i64_data = &data[header_end..header_end + len * 8];
537
538    // Convert i64 to i32
539    let mut i32_data = Vec::with_capacity(len * 4);
540    for chunk in i64_data.chunks(8) {
541        let i64_val = i64::from_le_bytes(chunk.try_into().unwrap());
542        let i32_val = i64_val as i32;
543        i32_data.extend_from_slice(&i32_val.to_le_bytes());
544    }
545
546    // Write new file with i32 dtype
547    let file = File::create(path)
548        .map_err(|e| Error::IndexLoad(format!("Failed to create {:?}: {}", path, e)))?;
549    let mut writer = BufWriter::new(file);
550
551    write_npy_header_1d(&mut writer, len, "<i4")?;
552
553    writer
554        .write_all(&i32_data)
555        .map_err(|e| Error::IndexLoad(format!("Failed to write data: {}", e)))?;
556    writer.flush()?;
557
558    Ok(())
559}
560
561/// Re-save a u8 NPY file to ensure dtype descriptor is "|u1" (platform-independent)
562///
563/// Note: We can't use ndarray_npy::ReadNpyExt here because it doesn't accept "<u1"
564/// descriptor, so we manually read the raw data and resave with "|u1".
565pub fn normalize_u8_npy(path: &Path) -> Result<()> {
566    use std::io::Read;
567
568    // Read the entire file
569    let mut file = File::open(path)
570        .map_err(|e| Error::IndexLoad(format!("Failed to open {:?}: {}", path, e)))?;
571    let mut data = Vec::new();
572    file.read_to_end(&mut data)
573        .map_err(|e| Error::IndexLoad(format!("Failed to read {:?}: {}", path, e)))?;
574
575    if data.len() < 10 || &data[..6] != NPY_MAGIC {
576        return Err(Error::IndexLoad("Invalid NPY file".into()));
577    }
578
579    let major_version = data[6];
580    let header_start = if major_version == 1 { 10 } else { 12 };
581    let header_len = if major_version == 1 {
582        u16::from_le_bytes([data[8], data[9]]) as usize
583    } else {
584        u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize
585    };
586    let header_end = header_start + header_len;
587
588    // Parse header to get shape
589    let header_str = std::str::from_utf8(&data[header_start..header_end])
590        .map_err(|e| Error::IndexLoad(format!("Invalid header: {}", e)))?;
591    let shape = parse_shape_from_header(header_str)?;
592
593    if shape.len() != 2 {
594        return Err(Error::IndexLoad(
595            "Expected 2D array for u8 normalization".into(),
596        ));
597    }
598
599    let nrows = shape[0];
600    let ncols = shape[1];
601    let u8_data = &data[header_end..header_end + nrows * ncols];
602
603    // Re-write with explicit "|u1" dtype
604    let new_file = File::create(path)
605        .map_err(|e| Error::IndexLoad(format!("Failed to create {:?}: {}", path, e)))?;
606    let mut writer = BufWriter::new(new_file);
607
608    write_npy_header_2d(&mut writer, nrows, ncols, "|u1")?;
609
610    writer
611        .write_all(u8_data)
612        .map_err(|e| Error::IndexLoad(format!("Failed to write data: {}", e)))?;
613    writer.flush()?;
614
615    Ok(())
616}
617
618/// Parse NPY header and return (shape, data_offset, is_fortran_order)
619fn parse_npy_header(mmap: &Mmap) -> Result<(Vec<usize>, usize, bool)> {
620    if mmap.len() < 10 {
621        return Err(Error::IndexLoad("NPY file too small".into()));
622    }
623
624    // Check magic
625    if &mmap[..6] != NPY_MAGIC {
626        return Err(Error::IndexLoad("Invalid NPY magic".into()));
627    }
628
629    let major_version = mmap[6];
630    let _minor_version = mmap[7];
631
632    // Read header length
633    let header_len = if major_version == 1 {
634        u16::from_le_bytes([mmap[8], mmap[9]]) as usize
635    } else if major_version == 2 {
636        if mmap.len() < 12 {
637            return Err(Error::IndexLoad("NPY v2 file too small".into()));
638        }
639        u32::from_le_bytes([mmap[8], mmap[9], mmap[10], mmap[11]]) as usize
640    } else {
641        return Err(Error::IndexLoad(format!(
642            "Unsupported NPY version: {}",
643            major_version
644        )));
645    };
646
647    let header_start = if major_version == 1 { 10 } else { 12 };
648    let header_end = header_start + header_len;
649
650    if mmap.len() < header_end {
651        return Err(Error::IndexLoad("NPY header exceeds file size".into()));
652    }
653
654    // Parse header dict (simplified Python dict parsing)
655    let header_str = std::str::from_utf8(&mmap[header_start..header_end])
656        .map_err(|e| Error::IndexLoad(format!("Invalid NPY header encoding: {}", e)))?;
657
658    // Extract shape from header like: {'descr': '<i8', 'fortran_order': False, 'shape': (12345,), }
659    let shape = parse_shape_from_header(header_str)?;
660    let fortran_order = header_str.contains("'fortran_order': True");
661
662    Ok((shape, header_end, fortran_order))
663}
664
665/// Parse shape tuple from NPY header string
666fn parse_shape_from_header(header: &str) -> Result<Vec<usize>> {
667    // Find 'shape': (...)
668    let shape_start = header
669        .find("'shape':")
670        .ok_or_else(|| Error::IndexLoad("No shape in NPY header".into()))?;
671
672    let after_shape = &header[shape_start + 8..];
673    let paren_start = after_shape
674        .find('(')
675        .ok_or_else(|| Error::IndexLoad("No shape tuple in NPY header".into()))?;
676    let paren_end = after_shape
677        .find(')')
678        .ok_or_else(|| Error::IndexLoad("Unclosed shape tuple in NPY header".into()))?;
679
680    let shape_content = &after_shape[paren_start + 1..paren_end];
681
682    // Parse comma-separated numbers
683    let mut shape = Vec::new();
684    for part in shape_content.split(',') {
685        let trimmed = part.trim();
686        if !trimmed.is_empty() {
687            let dim: usize = trimmed.parse().map_err(|e| {
688                Error::IndexLoad(format!("Invalid shape dimension '{}': {}", trimmed, e))
689            })?;
690            shape.push(dim);
691        }
692    }
693
694    Ok(shape)
695}
696
/// Memory-mapped NPY array for i64 values (used for codes).
///
/// This struct provides zero-copy access to 1D i64 arrays stored in NPY format.
pub struct MmapNpyArray1I64 {
    // Owns the mapping; the accessors index into it.
    _mmap: Mmap,
    // Element count (the first dimension of the NPY shape).
    len: usize,
    // Byte offset where the data section starts (end of the NPY header).
    data_offset: usize,
}
705
706impl MmapNpyArray1I64 {
707    /// Load a 1D i64 array from an NPY file.
708    pub fn from_npy_file(path: &Path) -> Result<Self> {
709        let file = File::open(path)
710            .map_err(|e| Error::IndexLoad(format!("Failed to open NPY file {:?}: {}", path, e)))?;
711
712        let mmap = unsafe {
713            Mmap::map(&file).map_err(|e| {
714                Error::IndexLoad(format!("Failed to mmap NPY file {:?}: {}", path, e))
715            })?
716        };
717
718        let (shape, data_offset, _fortran_order) = parse_npy_header(&mmap)?;
719
720        if shape.is_empty() {
721            return Err(Error::IndexLoad("Empty shape in NPY file".into()));
722        }
723
724        let len = shape[0];
725
726        // Verify file size
727        let expected_size = data_offset + len * 8;
728        if mmap.len() < expected_size {
729            return Err(Error::IndexLoad(format!(
730                "NPY file size {} too small for {} elements",
731                mmap.len(),
732                len
733            )));
734        }
735
736        Ok(Self {
737            _mmap: mmap,
738            len,
739            data_offset,
740        })
741    }
742
743    /// Get the length of the array.
744    pub fn len(&self) -> usize {
745        self.len
746    }
747
748    /// Returns true if the array is empty.
749    pub fn is_empty(&self) -> bool {
750        self.len == 0
751    }
752
753    /// Get a slice of the data as &[i64].
754    ///
755    /// Returns a `Vec<i64>` instead of &[i64] to handle unaligned data safely.
756    ///
757    /// # Safety
758    /// The caller must ensure start <= end <= len.
759    pub fn slice(&self, start: usize, end: usize) -> Vec<i64> {
760        let count = end - start;
761        let mut result = Vec::with_capacity(count);
762
763        for i in start..end {
764            result.push(self.get(i));
765        }
766
767        result
768    }
769
770    /// Get a value at an index.
771    pub fn get(&self, idx: usize) -> i64 {
772        let start = self.data_offset + idx * 8;
773        let bytes = &self._mmap[start..start + 8];
774        i64::from_le_bytes(bytes.try_into().unwrap())
775    }
776}
777
/// Memory-mapped NPY array for f32 values (used for centroids).
///
/// This struct provides zero-copy access to 2D f32 arrays stored in NPY format.
/// Unlike loading into an owned `Array2<f32>`, this approach lets the OS manage
/// paging, reducing resident memory usage for large centroid matrices.
pub struct MmapNpyArray2F32 {
    // Owns the mapping; the accessors index into it.
    _mmap: Mmap,
    // (nrows, ncols) from the NPY header.
    shape: (usize, usize),
    // Byte offset where the data section starts (end of the NPY header).
    data_offset: usize,
}
788
789impl MmapNpyArray2F32 {
790    /// Load a 2D f32 array from an NPY file.
791    pub fn from_npy_file(path: &Path) -> Result<Self> {
792        let file = File::open(path)
793            .map_err(|e| Error::IndexLoad(format!("Failed to open NPY file {:?}: {}", path, e)))?;
794
795        let mmap = unsafe {
796            Mmap::map(&file).map_err(|e| {
797                Error::IndexLoad(format!("Failed to mmap NPY file {:?}: {}", path, e))
798            })?
799        };
800
801        let (shape_vec, data_offset, _fortran_order) = parse_npy_header(&mmap)?;
802
803        if shape_vec.len() != 2 {
804            return Err(Error::IndexLoad(format!(
805                "Expected 2D array, got {}D",
806                shape_vec.len()
807            )));
808        }
809
810        let shape = (shape_vec[0], shape_vec[1]);
811
812        // Verify file size (f32 = 4 bytes)
813        let expected_size = data_offset + shape.0 * shape.1 * 4;
814        if mmap.len() < expected_size {
815            return Err(Error::IndexLoad(format!(
816                "NPY file size {} too small for shape {:?}",
817                mmap.len(),
818                shape
819            )));
820        }
821
822        Ok(Self {
823            _mmap: mmap,
824            shape,
825            data_offset,
826        })
827    }
828
829    /// Get the shape of the array.
830    pub fn shape(&self) -> (usize, usize) {
831        self.shape
832    }
833
834    /// Get the number of rows.
835    pub fn nrows(&self) -> usize {
836        self.shape.0
837    }
838
839    /// Get the number of columns.
840    pub fn ncols(&self) -> usize {
841        self.shape.1
842    }
843
844    /// Get a view of the entire array as ArrayView2.
845    ///
846    /// This provides zero-copy access to the memory-mapped data.
847    pub fn view(&self) -> ArrayView2<'_, f32> {
848        let byte_start = self.data_offset;
849        let byte_end = self.data_offset + self.shape.0 * self.shape.1 * 4;
850        let bytes = &self._mmap[byte_start..byte_end];
851
852        // Safety: We've verified bounds and f32 is 4-byte aligned in NPY format
853        let data = unsafe {
854            std::slice::from_raw_parts(bytes.as_ptr() as *const f32, self.shape.0 * self.shape.1)
855        };
856
857        ArrayView2::from_shape(self.shape, data).unwrap()
858    }
859
860    /// Get a view of a single row.
861    pub fn row(&self, idx: usize) -> ArrayView1<'_, f32> {
862        let byte_start = self.data_offset + idx * self.shape.1 * 4;
863        let bytes = &self._mmap[byte_start..byte_start + self.shape.1 * 4];
864
865        // Safety: We've verified bounds and alignment
866        let data =
867            unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, self.shape.1) };
868
869        ArrayView1::from_shape(self.shape.1, data).unwrap()
870    }
871
872    /// Get a view of rows [start..end] as ArrayView2.
873    pub fn slice_rows(&self, start: usize, end: usize) -> ArrayView2<'_, f32> {
874        let nrows = end - start;
875        let byte_start = self.data_offset + start * self.shape.1 * 4;
876        let byte_end = self.data_offset + end * self.shape.1 * 4;
877        let bytes = &self._mmap[byte_start..byte_end];
878
879        // Safety: We've verified bounds
880        let data = unsafe {
881            std::slice::from_raw_parts(bytes.as_ptr() as *const f32, nrows * self.shape.1)
882        };
883
884        ArrayView2::from_shape((nrows, self.shape.1), data).unwrap()
885    }
886
887    /// Convert to an owned Array2 (loads all data into memory).
888    ///
889    /// Use this only when you need an owned copy; prefer `view()` for read-only access.
890    pub fn to_owned(&self) -> Array2<f32> {
891        self.view().to_owned()
892    }
893}
894
/// Memory-mapped NPY array for u8 values (used for residuals).
///
/// This struct provides zero-copy access to 2D u8 arrays stored in NPY format.
pub struct MmapNpyArray2U8 {
    // Owns the mapping; the accessors index into it.
    _mmap: Mmap,
    // (nrows, ncols) from the NPY header.
    shape: (usize, usize),
    // Byte offset where the data section starts (end of the NPY header).
    data_offset: usize,
}
903
904impl MmapNpyArray2U8 {
905    /// Load a 2D u8 array from an NPY file.
906    pub fn from_npy_file(path: &Path) -> Result<Self> {
907        let file = File::open(path)
908            .map_err(|e| Error::IndexLoad(format!("Failed to open NPY file {:?}: {}", path, e)))?;
909
910        let mmap = unsafe {
911            Mmap::map(&file).map_err(|e| {
912                Error::IndexLoad(format!("Failed to mmap NPY file {:?}: {}", path, e))
913            })?
914        };
915
916        let (shape_vec, data_offset, _fortran_order) = parse_npy_header(&mmap)?;
917
918        if shape_vec.len() != 2 {
919            return Err(Error::IndexLoad(format!(
920                "Expected 2D array, got {}D",
921                shape_vec.len()
922            )));
923        }
924
925        let shape = (shape_vec[0], shape_vec[1]);
926
927        // Verify file size
928        let expected_size = data_offset + shape.0 * shape.1;
929        if mmap.len() < expected_size {
930            return Err(Error::IndexLoad(format!(
931                "NPY file size {} too small for shape {:?}",
932                mmap.len(),
933                shape
934            )));
935        }
936
937        Ok(Self {
938            _mmap: mmap,
939            shape,
940            data_offset,
941        })
942    }
943
944    /// Get the shape of the array.
945    pub fn shape(&self) -> (usize, usize) {
946        self.shape
947    }
948
949    /// Get the number of rows.
950    pub fn nrows(&self) -> usize {
951        self.shape.0
952    }
953
954    /// Get the number of columns.
955    pub fn ncols(&self) -> usize {
956        self.shape.1
957    }
958
959    /// Get a view of rows [start..end] as ArrayView2.
960    pub fn slice_rows(&self, start: usize, end: usize) -> ArrayView2<'_, u8> {
961        let nrows = end - start;
962        let byte_start = self.data_offset + start * self.shape.1;
963        let byte_end = self.data_offset + end * self.shape.1;
964        let bytes = &self._mmap[byte_start..byte_end];
965
966        ArrayView2::from_shape((nrows, self.shape.1), bytes).unwrap()
967    }
968
969    /// Get a view of the entire array.
970    pub fn view(&self) -> ArrayView2<'_, u8> {
971        self.slice_rows(0, self.shape.0)
972    }
973
974    /// Get a single row as a slice.
975    pub fn row(&self, idx: usize) -> &[u8] {
976        let byte_start = self.data_offset + idx * self.shape.1;
977        let byte_end = byte_start + self.shape.1;
978        &self._mmap[byte_start..byte_end]
979    }
980}
981
982// ============================================================================
983// Merged File Creation
984// ============================================================================
985
/// Manifest entry for tracking chunk files
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ChunkManifestEntry {
    /// Row count the chunk had when the merged file was last written.
    pub rows: usize,
    /// Modification time (seconds since epoch) at merge time; used together
    /// with `rows` to detect chunks that changed on disk.
    pub mtime: f64,
}

/// Manifest for merged files
///
/// Maps a chunk filename (e.g. `"0.codes.npy"`) to the state recorded when
/// the merged file was last rebuilt.
pub type ChunkManifest = HashMap<String, ChunkManifestEntry>;
995
996/// Load manifest from disk if it exists
997fn load_manifest(manifest_path: &Path) -> Option<ChunkManifest> {
998    if manifest_path.exists() {
999        if let Ok(file) = File::open(manifest_path) {
1000            if let Ok(manifest) = serde_json::from_reader(BufReader::new(file)) {
1001                return Some(manifest);
1002            }
1003        }
1004    }
1005    None
1006}
1007
1008/// Save manifest to disk
1009fn save_manifest(manifest_path: &Path, manifest: &ChunkManifest) -> Result<()> {
1010    let file = File::create(manifest_path)
1011        .map_err(|e| Error::IndexLoad(format!("Failed to create manifest: {}", e)))?;
1012    serde_json::to_writer(BufWriter::new(file), manifest)
1013        .map_err(|e| Error::IndexLoad(format!("Failed to write manifest: {}", e)))?;
1014    Ok(())
1015}
1016
1017/// Get file modification time as f64 seconds since epoch
1018fn get_mtime(path: &Path) -> Result<f64> {
1019    let metadata = fs::metadata(path)
1020        .map_err(|e| Error::IndexLoad(format!("Failed to get metadata for {:?}: {}", path, e)))?;
1021    let mtime = metadata
1022        .modified()
1023        .map_err(|e| Error::IndexLoad(format!("Failed to get mtime: {}", e)))?;
1024    let duration = mtime
1025        .duration_since(std::time::UNIX_EPOCH)
1026        .map_err(|e| Error::IndexLoad(format!("Invalid mtime: {}", e)))?;
1027    Ok(duration.as_secs_f64())
1028}
1029
1030/// Write NPY header for a 1D array
1031fn write_npy_header_1d(writer: &mut impl Write, len: usize, dtype: &str) -> Result<usize> {
1032    // Build header dict
1033    let header_dict = format!(
1034        "{{'descr': '{}', 'fortran_order': False, 'shape': ({},), }}",
1035        dtype, len
1036    );
1037
1038    // Pad to 64-byte alignment (NPY requirement)
1039    let header_len = header_dict.len();
1040    let padding = (64 - ((10 + header_len) % 64)) % 64;
1041    let padded_header = format!("{}{}\n", header_dict, " ".repeat(padding));
1042
1043    // Write magic + version
1044    writer
1045        .write_all(NPY_MAGIC)
1046        .map_err(|e| Error::IndexLoad(format!("Failed to write NPY magic: {}", e)))?;
1047    writer
1048        .write_all(&[1, 0])
1049        .map_err(|e| Error::IndexLoad(format!("Failed to write version: {}", e)))?; // v1.0
1050
1051    // Write header length (2 bytes for v1.0)
1052    let header_len_bytes = (padded_header.len() as u16).to_le_bytes();
1053    writer
1054        .write_all(&header_len_bytes)
1055        .map_err(|e| Error::IndexLoad(format!("Failed to write header len: {}", e)))?;
1056
1057    // Write header
1058    writer
1059        .write_all(padded_header.as_bytes())
1060        .map_err(|e| Error::IndexLoad(format!("Failed to write header: {}", e)))?;
1061
1062    Ok(10 + padded_header.len())
1063}
1064
1065/// Write NPY header for a 2D array
1066fn write_npy_header_2d(
1067    writer: &mut impl Write,
1068    nrows: usize,
1069    ncols: usize,
1070    dtype: &str,
1071) -> Result<usize> {
1072    // Build header dict
1073    let header_dict = format!(
1074        "{{'descr': '{}', 'fortran_order': False, 'shape': ({}, {}), }}",
1075        dtype, nrows, ncols
1076    );
1077
1078    // Pad to 64-byte alignment
1079    let header_len = header_dict.len();
1080    let padding = (64 - ((10 + header_len) % 64)) % 64;
1081    let padded_header = format!("{}{}\n", header_dict, " ".repeat(padding));
1082
1083    // Write magic + version
1084    writer
1085        .write_all(NPY_MAGIC)
1086        .map_err(|e| Error::IndexLoad(format!("Failed to write NPY magic: {}", e)))?;
1087    writer
1088        .write_all(&[1, 0])
1089        .map_err(|e| Error::IndexLoad(format!("Failed to write version: {}", e)))?;
1090
1091    // Write header length
1092    let header_len_bytes = (padded_header.len() as u16).to_le_bytes();
1093    writer
1094        .write_all(&header_len_bytes)
1095        .map_err(|e| Error::IndexLoad(format!("Failed to write header len: {}", e)))?;
1096
1097    // Write header
1098    writer
1099        .write_all(padded_header.as_bytes())
1100        .map_err(|e| Error::IndexLoad(format!("Failed to write header: {}", e)))?;
1101
1102    Ok(10 + padded_header.len())
1103}
1104
/// Information about a chunk file for merging
struct ChunkInfo {
    /// Absolute path of the chunk file.
    path: std::path::PathBuf,
    /// File name only (manifest key), e.g. `"0.codes.npy"`.
    filename: String,
    /// Number of data rows in the chunk.
    rows: usize,
    /// Modification time in seconds since epoch, used for change detection.
    mtime: f64,
}
1112
1113/// Merge chunked codes NPY files into a single merged file.
1114///
1115/// Uses incremental persistence with manifest tracking to skip unchanged chunks.
1116/// Returns the path to the merged file.
1117pub fn merge_codes_chunks(
1118    index_path: &Path,
1119    num_chunks: usize,
1120    padding_rows: usize,
1121) -> Result<std::path::PathBuf> {
1122    use ndarray_npy::ReadNpyExt;
1123
1124    let merged_path = index_path.join("merged_codes.npy");
1125    let manifest_path = index_path.join("merged_codes.manifest.json");
1126
1127    // Load previous manifest
1128    let old_manifest = load_manifest(&manifest_path);
1129
1130    // Scan chunks and detect changes
1131    let mut chunks: Vec<ChunkInfo> = Vec::new();
1132    let mut total_rows = 0usize;
1133    let mut chain_broken = false;
1134
1135    for i in 0..num_chunks {
1136        let filename = format!("{}.codes.npy", i);
1137        let path = index_path.join(&filename);
1138
1139        if path.exists() {
1140            let mtime = get_mtime(&path)?;
1141
1142            // Get shape by reading header only
1143            let file = File::open(&path)?;
1144            let arr: Array1<i64> = Array1::read_npy(file)?;
1145            let rows = arr.len();
1146
1147            if rows > 0 {
1148                total_rows += rows;
1149
1150                // Check if this chunk changed
1151                let is_clean = if let Some(ref manifest) = old_manifest {
1152                    manifest
1153                        .get(&filename)
1154                        .is_some_and(|entry| entry.mtime == mtime && entry.rows == rows)
1155                } else {
1156                    false
1157                };
1158
1159                if !is_clean {
1160                    chain_broken = true;
1161                }
1162
1163                chunks.push(ChunkInfo {
1164                    path,
1165                    filename,
1166                    rows,
1167                    mtime,
1168                });
1169            }
1170        }
1171    }
1172
1173    if total_rows == 0 {
1174        return Err(Error::IndexLoad("No data to merge".into()));
1175    }
1176
1177    let final_rows = total_rows + padding_rows;
1178
1179    // Check if we need to rewrite
1180    let needs_full_rewrite = !merged_path.exists() || chain_broken;
1181
1182    if needs_full_rewrite {
1183        // Create new merged file
1184        let file = File::create(&merged_path)?;
1185        let mut writer = BufWriter::new(file);
1186
1187        // Write header
1188        write_npy_header_1d(&mut writer, final_rows, "<i8")?;
1189
1190        // Write chunk data
1191        for chunk in &chunks {
1192            let file = File::open(&chunk.path)?;
1193            let arr: Array1<i64> = Array1::read_npy(file)?;
1194            for &val in arr.iter() {
1195                writer.write_all(&val.to_le_bytes())?;
1196            }
1197        }
1198
1199        // Write padding zeros
1200        for _ in 0..padding_rows {
1201            writer.write_all(&0i64.to_le_bytes())?;
1202        }
1203
1204        writer.flush()?;
1205    }
1206
1207    // Save manifest
1208    let mut new_manifest = ChunkManifest::new();
1209    for chunk in &chunks {
1210        new_manifest.insert(
1211            chunk.filename.clone(),
1212            ChunkManifestEntry {
1213                rows: chunk.rows,
1214                mtime: chunk.mtime,
1215            },
1216        );
1217    }
1218    save_manifest(&manifest_path, &new_manifest)?;
1219
1220    Ok(merged_path)
1221}
1222
1223/// Merge chunked residuals NPY files into a single merged file.
1224pub fn merge_residuals_chunks(
1225    index_path: &Path,
1226    num_chunks: usize,
1227    padding_rows: usize,
1228) -> Result<std::path::PathBuf> {
1229    use ndarray_npy::ReadNpyExt;
1230
1231    let merged_path = index_path.join("merged_residuals.npy");
1232    let manifest_path = index_path.join("merged_residuals.manifest.json");
1233
1234    // Load previous manifest
1235    let old_manifest = load_manifest(&manifest_path);
1236
1237    // Scan chunks and detect changes
1238    let mut chunks: Vec<ChunkInfo> = Vec::new();
1239    let mut total_rows = 0usize;
1240    let mut ncols = 0usize;
1241    let mut chain_broken = false;
1242
1243    for i in 0..num_chunks {
1244        let filename = format!("{}.residuals.npy", i);
1245        let path = index_path.join(&filename);
1246
1247        if path.exists() {
1248            let mtime = get_mtime(&path)?;
1249
1250            // Get shape by reading header
1251            let file = File::open(&path)?;
1252            let arr: Array2<u8> = Array2::read_npy(file)?;
1253            let rows = arr.nrows();
1254            ncols = arr.ncols();
1255
1256            if rows > 0 {
1257                total_rows += rows;
1258
1259                let is_clean = if let Some(ref manifest) = old_manifest {
1260                    manifest
1261                        .get(&filename)
1262                        .is_some_and(|entry| entry.mtime == mtime && entry.rows == rows)
1263                } else {
1264                    false
1265                };
1266
1267                if !is_clean {
1268                    chain_broken = true;
1269                }
1270
1271                chunks.push(ChunkInfo {
1272                    path,
1273                    filename,
1274                    rows,
1275                    mtime,
1276                });
1277            }
1278        }
1279    }
1280
1281    if total_rows == 0 || ncols == 0 {
1282        return Err(Error::IndexLoad("No residual data to merge".into()));
1283    }
1284
1285    let final_rows = total_rows + padding_rows;
1286
1287    let needs_full_rewrite = !merged_path.exists() || chain_broken;
1288
1289    if needs_full_rewrite {
1290        let file = File::create(&merged_path)?;
1291        let mut writer = BufWriter::new(file);
1292
1293        // Write header
1294        write_npy_header_2d(&mut writer, final_rows, ncols, "|u1")?;
1295
1296        // Write chunk data
1297        for chunk in &chunks {
1298            let file = File::open(&chunk.path)?;
1299            let arr: Array2<u8> = Array2::read_npy(file)?;
1300            for row in arr.rows() {
1301                writer.write_all(row.as_slice().unwrap())?;
1302            }
1303        }
1304
1305        // Write padding zeros
1306        let zero_row = vec![0u8; ncols];
1307        for _ in 0..padding_rows {
1308            writer.write_all(&zero_row)?;
1309        }
1310
1311        writer.flush()?;
1312    }
1313
1314    // Save manifest
1315    let mut new_manifest = ChunkManifest::new();
1316    for chunk in &chunks {
1317        new_manifest.insert(
1318            chunk.filename.clone(),
1319            ChunkManifestEntry {
1320                rows: chunk.rows,
1321                mtime: chunk.mtime,
1322            },
1323        );
1324    }
1325    save_manifest(&manifest_path, &new_manifest)?;
1326
1327    Ok(merged_path)
1328}
1329
1330// ============================================================================
1331// Fast-PLAID Compatibility Conversion
1332// ============================================================================
1333
1334/// Convert a fast-plaid index to next-plaid compatible format.
1335///
1336/// This function detects and converts:
1337/// - float16 → float32 for centroids, avg_residual, bucket_cutoffs, bucket_weights
1338/// - int64 → int32 for ivf_lengths
1339/// - `<u1` → `|u1` for residuals
1340///
1341/// Returns true if any conversion was performed, false if already compatible.
1342pub fn convert_fastplaid_to_nextplaid(index_path: &Path) -> Result<bool> {
1343    let mut converted = false;
1344
1345    // Float files to convert from f16 to f32
1346    let float_files = [
1347        "centroids.npy",
1348        "avg_residual.npy",
1349        "bucket_cutoffs.npy",
1350        "bucket_weights.npy",
1351    ];
1352
1353    for filename in float_files {
1354        let path = index_path.join(filename);
1355        if path.exists() {
1356            let dtype = detect_npy_dtype(&path)?;
1357            if dtype == "<f2" {
1358                eprintln!("  Converting {} from float16 to float32", filename);
1359                convert_f16_to_f32_npy(&path)?;
1360                converted = true;
1361            }
1362        }
1363    }
1364
1365    // Convert ivf_lengths from i64 to i32
1366    let ivf_lengths_path = index_path.join("ivf_lengths.npy");
1367    if ivf_lengths_path.exists() {
1368        let dtype = detect_npy_dtype(&ivf_lengths_path)?;
1369        if dtype == "<i8" {
1370            eprintln!("  Converting ivf_lengths.npy from int64 to int32");
1371            convert_i64_to_i32_npy(&ivf_lengths_path)?;
1372            converted = true;
1373        }
1374    }
1375
1376    // Normalize residual files to use "|u1" descriptor
1377    // fast-plaid uses "<u1" which ndarray_npy doesn't accept
1378    for entry in fs::read_dir(index_path)? {
1379        let entry = entry?;
1380        let filename = entry.file_name().to_string_lossy().to_string();
1381        if filename.ends_with(".residuals.npy") {
1382            let path = entry.path();
1383            let dtype = detect_npy_dtype(&path)?;
1384            if dtype == "<u1" {
1385                eprintln!(
1386                    "  Normalizing {} dtype descriptor from <u1 to |u1",
1387                    filename
1388                );
1389                normalize_u8_npy(&path)?;
1390                converted = true;
1391            }
1392        }
1393    }
1394
1395    Ok(converted)
1396}
1397
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    #[test]
    fn test_mmap_array2_f32() {
        // Build a raw-format file: 8-byte nrows, 8-byte ncols, then f32 data.
        let mut tmp = NamedTempFile::new().unwrap();
        tmp.write_all(&3i64.to_le_bytes()).unwrap();
        tmp.write_all(&2i64.to_le_bytes()).unwrap();
        for v in [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] {
            tmp.write_all(&v.to_le_bytes()).unwrap();
        }
        tmp.flush().unwrap();

        // Map it back and check shape, row access, and owned conversion.
        let arr = MmapArray2F32::from_raw_file(tmp.path()).unwrap();
        assert_eq!(arr.shape(), (3, 2));

        let first = arr.row(0);
        assert_eq!(first[0], 1.0);
        assert_eq!(first[1], 2.0);

        let owned = arr.to_owned();
        assert_eq!(owned[[2, 0]], 5.0);
        assert_eq!(owned[[2, 1]], 6.0);
    }

    #[test]
    fn test_mmap_array1_i64() {
        // Raw 1D format: 8-byte length header followed by i64 data.
        let mut tmp = NamedTempFile::new().unwrap();
        tmp.write_all(&4i64.to_le_bytes()).unwrap();
        for v in [10i64, 20, 30, 40] {
            tmp.write_all(&v.to_le_bytes()).unwrap();
        }
        tmp.flush().unwrap();

        let arr = MmapArray1I64::from_raw_file(tmp.path()).unwrap();
        assert_eq!(arr.len(), 4);
        assert_eq!(arr.get(0), 10);
        assert_eq!(arr.get(3), 40);

        let owned = arr.to_owned();
        assert_eq!(owned[1], 20);
        assert_eq!(owned[2], 30);
    }

    #[test]
    fn test_write_read_roundtrip() {
        let tmp = NamedTempFile::new().unwrap();

        // Write an array with the raw writer, then mmap it back and compare.
        let original =
            Array2::from_shape_vec((2, 3), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
        write_array2_f32(&original, tmp.path()).unwrap();

        let reloaded = MmapArray2F32::from_raw_file(tmp.path()).unwrap().to_owned();
        assert_eq!(original, reloaded);
    }
}