use std::collections::HashMap;
use std::fs;
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;

use byteorder::{LittleEndian, ReadBytesExt};
use memmap2::Mmap;
use ndarray::{Array1, Array2, ArrayView1, ArrayView2};

use crate::error::{Error, Result};

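/// Memory-mapped 2-D `f32` array in a simple raw format: a 16-byte header
/// holding `nrows` and `ncols` as little-endian `i64`, followed by the
/// row-major little-endian `f32` data.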
pub struct MmapArray2F32 {
    _mmap: Mmap,
    shape: (usize, usize),
    data_offset: usize,
}

impl MmapArray2F32 {
    pub fn from_raw_file(path: &Path) -> Result<Self> {
        let file = File::open(path)
            .map_err(|e| Error::IndexLoad(format!("Failed to open file {:?}: {}", path, e)))?;

        let mmap = unsafe {
            Mmap::map(&file)
                .map_err(|e| Error::IndexLoad(format!("Failed to mmap file {:?}: {}", path, e)))?
        };

        if mmap.len() < 16 {
            return Err(Error::IndexLoad("File too small for header".into()));
        }

        let mut cursor = std::io::Cursor::new(&mmap[..16]);
        let nrows = cursor
            .read_i64::<LittleEndian>()
            .map_err(|e| Error::IndexLoad(format!("Failed to read nrows: {}", e)))?
            as usize;
        let ncols = cursor
            .read_i64::<LittleEndian>()
            .map_err(|e| Error::IndexLoad(format!("Failed to read ncols: {}", e)))?
            as usize;

        let expected_size = 16 + nrows * ncols * 4;
        if mmap.len() < expected_size {
            return Err(Error::IndexLoad(format!(
                "File size {} too small for shape ({}, {})",
                mmap.len(),
                nrows,
                ncols
            )));
        }

        Ok(Self {
            _mmap: mmap,
            shape: (nrows, ncols),
            data_offset: 16,
        })
    }

    pub fn shape(&self) -> (usize, usize) {
        self.shape
    }

    pub fn nrows(&self) -> usize {
        self.shape.0
    }

    pub fn ncols(&self) -> usize {
        self.shape.1
    }

    pub fn row(&self, idx: usize) -> ArrayView1<'_, f32> {
        let start = self.data_offset + idx * self.shape.1 * 4;
        let bytes = &self._mmap[start..start + self.shape.1 * 4];

        // SAFETY: the mapping is page-aligned and `data_offset` (16) is a
        // multiple of 4, so the f32 alignment requirement holds; the view
        // borrows `self`, so the mapping outlives it.
        let data =
            unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, self.shape.1) };

        ArrayView1::from_shape(self.shape.1, data).unwrap()
    }

    pub fn load_rows(&self, start: usize, end: usize) -> Array2<f32> {
        let nrows = end - start;
        let byte_start = self.data_offset + start * self.shape.1 * 4;
        let byte_end = self.data_offset + end * self.shape.1 * 4;
        let bytes = &self._mmap[byte_start..byte_end];

        // SAFETY: same alignment argument as in `row`; the slice is copied
        // into an owned vector before the borrow ends.
        let data = unsafe {
            std::slice::from_raw_parts(bytes.as_ptr() as *const f32, nrows * self.shape.1)
        };

        Array2::from_shape_vec((nrows, self.shape.1), data.to_vec()).unwrap()
    }

    pub fn to_owned(&self) -> Array2<f32> {
        self.load_rows(0, self.shape.0)
    }
}

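/// Memory-mapped 2-D `u8` array in the same raw layout as [`MmapArray2F32`]:
/// a 16-byte header (`nrows`, `ncols` as little-endian `i64`) followed by
/// row-major byte data.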
pub struct MmapArray2U8 {
    _mmap: Mmap,
    shape: (usize, usize),
    data_offset: usize,
}

impl MmapArray2U8 {
    pub fn from_raw_file(path: &Path) -> Result<Self> {
        let file = File::open(path)
            .map_err(|e| Error::IndexLoad(format!("Failed to open file {:?}: {}", path, e)))?;

        let mmap = unsafe {
            Mmap::map(&file)
                .map_err(|e| Error::IndexLoad(format!("Failed to mmap file {:?}: {}", path, e)))?
        };

        if mmap.len() < 16 {
            return Err(Error::IndexLoad("File too small for header".into()));
        }

        let mut cursor = std::io::Cursor::new(&mmap[..16]);
        let nrows = cursor
            .read_i64::<LittleEndian>()
            .map_err(|e| Error::IndexLoad(format!("Failed to read nrows: {}", e)))?
            as usize;
        let ncols = cursor
            .read_i64::<LittleEndian>()
            .map_err(|e| Error::IndexLoad(format!("Failed to read ncols: {}", e)))?
            as usize;

        let expected_size = 16 + nrows * ncols;
        if mmap.len() < expected_size {
            return Err(Error::IndexLoad(format!(
                "File size {} too small for shape ({}, {})",
                mmap.len(),
                nrows,
                ncols
            )));
        }

        Ok(Self {
            _mmap: mmap,
            shape: (nrows, ncols),
            data_offset: 16,
        })
    }

    pub fn shape(&self) -> (usize, usize) {
        self.shape
    }

    pub fn view(&self) -> ArrayView2<'_, u8> {
        let bytes = &self._mmap[self.data_offset..self.data_offset + self.shape.0 * self.shape.1];
        ArrayView2::from_shape(self.shape, bytes).unwrap()
    }

    pub fn load_rows(&self, start: usize, end: usize) -> Array2<u8> {
        let nrows = end - start;
        let byte_start = self.data_offset + start * self.shape.1;
        let byte_end = self.data_offset + end * self.shape.1;
        let bytes = &self._mmap[byte_start..byte_end];

        Array2::from_shape_vec((nrows, self.shape.1), bytes.to_vec()).unwrap()
    }

    pub fn to_owned(&self) -> Array2<u8> {
        self.load_rows(0, self.shape.0)
    }
}

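/// Memory-mapped 1-D `i64` array: an 8-byte little-endian `i64` length
/// header followed by the little-endian `i64` elements.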
pub struct MmapArray1I64 {
    _mmap: Mmap,
    len: usize,
    data_offset: usize,
}

impl MmapArray1I64 {
    pub fn from_raw_file(path: &Path) -> Result<Self> {
        let file = File::open(path)
            .map_err(|e| Error::IndexLoad(format!("Failed to open file {:?}: {}", path, e)))?;

        let mmap = unsafe {
            Mmap::map(&file)
                .map_err(|e| Error::IndexLoad(format!("Failed to mmap file {:?}: {}", path, e)))?
        };

        if mmap.len() < 8 {
            return Err(Error::IndexLoad("File too small for header".into()));
        }

        let mut cursor = std::io::Cursor::new(&mmap[..8]);
        let len = cursor
            .read_i64::<LittleEndian>()
            .map_err(|e| Error::IndexLoad(format!("Failed to read length: {}", e)))?
            as usize;

        let expected_size = 8 + len * 8;
        if mmap.len() < expected_size {
            return Err(Error::IndexLoad(format!(
                "File size {} too small for length {}",
                mmap.len(),
                len
            )));
        }

        Ok(Self {
            _mmap: mmap,
            len,
            data_offset: 8,
        })
    }

    pub fn len(&self) -> usize {
        self.len
    }

    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    pub fn get(&self, idx: usize) -> i64 {
        let start = self.data_offset + idx * 8;
        let bytes = &self._mmap[start..start + 8];
        i64::from_le_bytes(bytes.try_into().unwrap())
    }

    pub fn to_owned(&self) -> Array1<i64> {
        let bytes = &self._mmap[self.data_offset..self.data_offset + self.len * 8];

        // SAFETY: the mapping is page-aligned and `data_offset` (8) is a
        // multiple of 8, so the i64 alignment requirement holds; the slice is
        // copied before the borrow ends.
        let data = unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const i64, self.len) };

        Array1::from_vec(data.to_vec())
    }
}

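/// Writes a 2-D `f32` array in the raw format read by [`MmapArray2F32`]:
/// a 16-byte header (`nrows`, `ncols` as little-endian `i64`) followed by
/// the elements in row-major order, little-endian.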
pub fn write_array2_f32(array: &Array2<f32>, path: &Path) -> Result<()> {
    let file = File::create(path)
        .map_err(|e| Error::IndexLoad(format!("Failed to create file {:?}: {}", path, e)))?;
    let mut writer = BufWriter::new(file);

    let nrows = array.nrows() as i64;
    let ncols = array.ncols() as i64;

    writer
        .write_all(&nrows.to_le_bytes())
        .map_err(|e| Error::IndexLoad(format!("Failed to write nrows: {}", e)))?;
    writer
        .write_all(&ncols.to_le_bytes())
        .map_err(|e| Error::IndexLoad(format!("Failed to write ncols: {}", e)))?;

    // `iter` visits elements in logical (row-major) order regardless of the
    // array's memory layout, matching the reader.
    for val in array.iter() {
        writer
            .write_all(&val.to_le_bytes())
            .map_err(|e| Error::IndexLoad(format!("Failed to write data: {}", e)))?;
    }

    writer
        .flush()
        .map_err(|e| Error::IndexLoad(format!("Failed to flush: {}", e)))?;

    Ok(())
}

/// Writes a 2-D `u8` array in the raw format read by [`MmapArray2U8`].
pub fn write_array2_u8(array: &Array2<u8>, path: &Path) -> Result<()> {
    let file = File::create(path)
        .map_err(|e| Error::IndexLoad(format!("Failed to create file {:?}: {}", path, e)))?;
    let mut writer = BufWriter::new(file);

    let nrows = array.nrows() as i64;
    let ncols = array.ncols() as i64;

    writer
        .write_all(&nrows.to_le_bytes())
        .map_err(|e| Error::IndexLoad(format!("Failed to write nrows: {}", e)))?;
    writer
        .write_all(&ncols.to_le_bytes())
        .map_err(|e| Error::IndexLoad(format!("Failed to write ncols: {}", e)))?;

    for row in array.rows() {
        // `as_slice` is Some for rows of a standard (C-order) array, which is
        // what this crate produces; a Fortran-order input would panic here.
        writer
            .write_all(row.as_slice().unwrap())
            .map_err(|e| Error::IndexLoad(format!("Failed to write data: {}", e)))?;
    }

    writer
        .flush()
        .map_err(|e| Error::IndexLoad(format!("Failed to flush: {}", e)))?;

    Ok(())
}

/// Writes a 1-D `i64` array in the raw format read by [`MmapArray1I64`]:
/// an 8-byte little-endian length followed by the little-endian elements.
pub fn write_array1_i64(array: &Array1<i64>, path: &Path) -> Result<()> {
    let file = File::create(path)
        .map_err(|e| Error::IndexLoad(format!("Failed to create file {:?}: {}", path, e)))?;
    let mut writer = BufWriter::new(file);

    let len = array.len() as i64;

    writer
        .write_all(&len.to_le_bytes())
        .map_err(|e| Error::IndexLoad(format!("Failed to write length: {}", e)))?;

    for val in array.iter() {
        writer
            .write_all(&val.to_le_bytes())
            .map_err(|e| Error::IndexLoad(format!("Failed to write data: {}", e)))?;
    }

    writer
        .flush()
        .map_err(|e| Error::IndexLoad(format!("Failed to flush: {}", e)))?;

    Ok(())
}

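/// Magic bytes at the start of every NPY (NumPy `.npy`) file.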
const NPY_MAGIC: &[u8] = b"\x93NUMPY";

/// Extracts the `'descr'` dtype string (e.g. `"<f4"`) from an NPY header dict.
fn parse_dtype_from_header(header: &str) -> Result<String> {
    let descr_start = header
        .find("'descr':")
        .ok_or_else(|| Error::IndexLoad("No descr in NPY header".into()))?;

    let after_descr = &header[descr_start + 8..];
    let quote_start = after_descr
        .find('\'')
        .ok_or_else(|| Error::IndexLoad("No dtype quote in NPY header".into()))?;
    let rest = &after_descr[quote_start + 1..];
    let quote_end = rest
        .find('\'')
        .ok_or_else(|| Error::IndexLoad("Unclosed dtype quote in NPY header".into()))?;

    Ok(rest[..quote_end].to_string())
}

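/// Reads just enough of an NPY file (v1 or v2 header) to return its dtype
/// descriptor string, e.g. `"<f4"` or `"<i8"`.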
pub fn detect_npy_dtype(path: &Path) -> Result<String> {
    let file = File::open(path)
        .map_err(|e| Error::IndexLoad(format!("Failed to open NPY file {:?}: {}", path, e)))?;

    let mmap = unsafe {
        Mmap::map(&file)
            .map_err(|e| Error::IndexLoad(format!("Failed to mmap NPY file {:?}: {}", path, e)))?
    };

    if mmap.len() < 10 {
        return Err(Error::IndexLoad("NPY file too small".into()));
    }

    if &mmap[..6] != NPY_MAGIC {
        return Err(Error::IndexLoad("Invalid NPY magic".into()));
    }

    let major_version = mmap[6];

    let header_len = if major_version == 1 {
        u16::from_le_bytes([mmap[8], mmap[9]]) as usize
    } else if major_version == 2 {
        if mmap.len() < 12 {
            return Err(Error::IndexLoad("NPY v2 file too small".into()));
        }
        u32::from_le_bytes([mmap[8], mmap[9], mmap[10], mmap[11]]) as usize
    } else {
        return Err(Error::IndexLoad(format!(
            "Unsupported NPY version: {}",
            major_version
        )));
    };

    let header_start = if major_version == 1 { 10 } else { 12 };
    let header_end = header_start + header_len;

    if mmap.len() < header_end {
        return Err(Error::IndexLoad("NPY header exceeds file size".into()));
    }

    let header_str = std::str::from_utf8(&mmap[header_start..header_end])
        .map_err(|e| Error::IndexLoad(format!("Invalid NPY header encoding: {}", e)))?;

    parse_dtype_from_header(header_str)
}

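/// Rewrites an NPY file of `f16` values as `f32` in place, preserving the
/// 1-D or 2-D shape. The whole file is read into memory, converted, and
/// written back out.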
pub fn convert_f16_to_f32_npy(path: &Path) -> Result<()> {
    use half::f16;
    use std::io::Read;

    let mut file = File::open(path)
        .map_err(|e| Error::IndexLoad(format!("Failed to open {:?}: {}", path, e)))?;
    let mut data = Vec::new();
    file.read_to_end(&mut data)
        .map_err(|e| Error::IndexLoad(format!("Failed to read {:?}: {}", path, e)))?;

    if data.len() < 10 || &data[..6] != NPY_MAGIC {
        return Err(Error::IndexLoad("Invalid NPY file".into()));
    }

    let major_version = data[6];
    let header_start = if major_version == 1 { 10 } else { 12 };
    let header_len = if major_version == 1 {
        u16::from_le_bytes([data[8], data[9]]) as usize
    } else {
        u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize
    };
    let header_end = header_start + header_len;

    if data.len() < header_end {
        return Err(Error::IndexLoad("NPY header exceeds file size".into()));
    }

    let header_str = std::str::from_utf8(&data[header_start..header_end])
        .map_err(|e| Error::IndexLoad(format!("Invalid header: {}", e)))?;
    let shape = parse_shape_from_header(header_str)?;

    let total_elements: usize = shape.iter().product();
    if data.len() < header_end + total_elements * 2 {
        return Err(Error::IndexLoad("NPY file truncated".into()));
    }
    let f16_data = &data[header_end..header_end + total_elements * 2];

    let mut f32_data = Vec::with_capacity(total_elements * 4);
    for chunk in f16_data.chunks_exact(2) {
        let f16_val = f16::from_le_bytes([chunk[0], chunk[1]]);
        f32_data.extend_from_slice(&f16_val.to_f32().to_le_bytes());
    }

    let file = File::create(path)
        .map_err(|e| Error::IndexLoad(format!("Failed to create {:?}: {}", path, e)))?;
    let mut writer = BufWriter::new(file);

    if shape.len() == 1 {
        write_npy_header_1d(&mut writer, shape[0], "<f4")?;
    } else if shape.len() == 2 {
        write_npy_header_2d(&mut writer, shape[0], shape[1], "<f4")?;
    } else {
        return Err(Error::IndexLoad("Unsupported shape dimensions".into()));
    }

    writer
        .write_all(&f32_data)
        .map_err(|e| Error::IndexLoad(format!("Failed to write data: {}", e)))?;
    writer.flush()?;

    Ok(())
}

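/// Rewrites a 1-D NPY file of `i64` values as `i32` in place. Values are
/// narrowed with `as i32`, which wraps values outside the `i32` range, so
/// callers must ensure the values fit.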
pub fn convert_i64_to_i32_npy(path: &Path) -> Result<()> {
    use std::io::Read;

    let mut file = File::open(path)
        .map_err(|e| Error::IndexLoad(format!("Failed to open {:?}: {}", path, e)))?;
    let mut data = Vec::new();
    file.read_to_end(&mut data)
        .map_err(|e| Error::IndexLoad(format!("Failed to read {:?}: {}", path, e)))?;

    if data.len() < 10 || &data[..6] != NPY_MAGIC {
        return Err(Error::IndexLoad("Invalid NPY file".into()));
    }

    let major_version = data[6];
    let header_start = if major_version == 1 { 10 } else { 12 };
    let header_len = if major_version == 1 {
        u16::from_le_bytes([data[8], data[9]]) as usize
    } else {
        u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize
    };
    let header_end = header_start + header_len;

    if data.len() < header_end {
        return Err(Error::IndexLoad("NPY header exceeds file size".into()));
    }

    let header_str = std::str::from_utf8(&data[header_start..header_end])
        .map_err(|e| Error::IndexLoad(format!("Invalid header: {}", e)))?;
    let shape = parse_shape_from_header(header_str)?;

    if shape.len() != 1 {
        return Err(Error::IndexLoad("Expected 1D array for i64->i32".into()));
    }

    let len = shape[0];
    if data.len() < header_end + len * 8 {
        return Err(Error::IndexLoad("NPY file truncated".into()));
    }
    let i64_data = &data[header_end..header_end + len * 8];

    let mut i32_data = Vec::with_capacity(len * 4);
    for chunk in i64_data.chunks_exact(8) {
        let i64_val = i64::from_le_bytes(chunk.try_into().unwrap());
        i32_data.extend_from_slice(&(i64_val as i32).to_le_bytes());
    }

    let file = File::create(path)
        .map_err(|e| Error::IndexLoad(format!("Failed to create {:?}: {}", path, e)))?;
    let mut writer = BufWriter::new(file);

    write_npy_header_1d(&mut writer, len, "<i4")?;

    writer
        .write_all(&i32_data)
        .map_err(|e| Error::IndexLoad(format!("Failed to write data: {}", e)))?;
    writer.flush()?;

    Ok(())
}

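/// Rewrites a 2-D `u8` NPY file in place so its header carries the canonical
/// `|u1` dtype descriptor (byte data has no byte order, so `<u1` and `|u1`
/// describe the same values).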
pub fn normalize_u8_npy(path: &Path) -> Result<()> {
    use std::io::Read;

    let mut file = File::open(path)
        .map_err(|e| Error::IndexLoad(format!("Failed to open {:?}: {}", path, e)))?;
    let mut data = Vec::new();
    file.read_to_end(&mut data)
        .map_err(|e| Error::IndexLoad(format!("Failed to read {:?}: {}", path, e)))?;

    if data.len() < 10 || &data[..6] != NPY_MAGIC {
        return Err(Error::IndexLoad("Invalid NPY file".into()));
    }

    let major_version = data[6];
    let header_start = if major_version == 1 { 10 } else { 12 };
    let header_len = if major_version == 1 {
        u16::from_le_bytes([data[8], data[9]]) as usize
    } else {
        u32::from_le_bytes([data[8], data[9], data[10], data[11]]) as usize
    };
    let header_end = header_start + header_len;

    if data.len() < header_end {
        return Err(Error::IndexLoad("NPY header exceeds file size".into()));
    }

    let header_str = std::str::from_utf8(&data[header_start..header_end])
        .map_err(|e| Error::IndexLoad(format!("Invalid header: {}", e)))?;
    let shape = parse_shape_from_header(header_str)?;

    if shape.len() != 2 {
        return Err(Error::IndexLoad(
            "Expected 2D array for u8 normalization".into(),
        ));
    }

    let nrows = shape[0];
    let ncols = shape[1];
    if data.len() < header_end + nrows * ncols {
        return Err(Error::IndexLoad("NPY file truncated".into()));
    }
    let u8_data = &data[header_end..header_end + nrows * ncols];

    let new_file = File::create(path)
        .map_err(|e| Error::IndexLoad(format!("Failed to create {:?}: {}", path, e)))?;
    let mut writer = BufWriter::new(new_file);

    write_npy_header_2d(&mut writer, nrows, ncols, "|u1")?;

    writer
        .write_all(u8_data)
        .map_err(|e| Error::IndexLoad(format!("Failed to write data: {}", e)))?;
    writer.flush()?;

    Ok(())
}

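/// Parses an NPY v1/v2 header from a mapped file and returns
/// `(shape, data_offset, fortran_order)`, where `data_offset` is the byte
/// offset at which the array data begins.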
fn parse_npy_header(mmap: &Mmap) -> Result<(Vec<usize>, usize, bool)> {
    if mmap.len() < 10 {
        return Err(Error::IndexLoad("NPY file too small".into()));
    }

    if &mmap[..6] != NPY_MAGIC {
        return Err(Error::IndexLoad("Invalid NPY magic".into()));
    }

    let major_version = mmap[6];
    let _minor_version = mmap[7];

    let header_len = if major_version == 1 {
        u16::from_le_bytes([mmap[8], mmap[9]]) as usize
    } else if major_version == 2 {
        if mmap.len() < 12 {
            return Err(Error::IndexLoad("NPY v2 file too small".into()));
        }
        u32::from_le_bytes([mmap[8], mmap[9], mmap[10], mmap[11]]) as usize
    } else {
        return Err(Error::IndexLoad(format!(
            "Unsupported NPY version: {}",
            major_version
        )));
    };

    let header_start = if major_version == 1 { 10 } else { 12 };
    let header_end = header_start + header_len;

    if mmap.len() < header_end {
        return Err(Error::IndexLoad("NPY header exceeds file size".into()));
    }

    let header_str = std::str::from_utf8(&mmap[header_start..header_end])
        .map_err(|e| Error::IndexLoad(format!("Invalid NPY header encoding: {}", e)))?;

    let shape = parse_shape_from_header(header_str)?;
    let fortran_order = header_str.contains("'fortran_order': True");

    Ok((shape, header_end, fortran_order))
}

/// Extracts the `'shape'` tuple from an NPY header dict, e.g. `(3, 128)`
/// becomes `vec![3, 128]`.
fn parse_shape_from_header(header: &str) -> Result<Vec<usize>> {
    let shape_start = header
        .find("'shape':")
        .ok_or_else(|| Error::IndexLoad("No shape in NPY header".into()))?;

    let after_shape = &header[shape_start + 8..];
    let paren_start = after_shape
        .find('(')
        .ok_or_else(|| Error::IndexLoad("No shape tuple in NPY header".into()))?;
    let paren_end = after_shape
        .find(')')
        .ok_or_else(|| Error::IndexLoad("Unclosed shape tuple in NPY header".into()))?;

    let shape_content = &after_shape[paren_start + 1..paren_end];

    let mut shape = Vec::new();
    for part in shape_content.split(',') {
        let trimmed = part.trim();
        if !trimmed.is_empty() {
            let dim: usize = trimmed.parse().map_err(|e| {
                Error::IndexLoad(format!("Invalid shape dimension '{}': {}", trimmed, e))
            })?;
            shape.push(dim);
        }
    }

    Ok(shape)
}

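/// Memory-mapped 1-D `i64` array backed by an NPY file. The element type is
/// assumed to be little-endian `i64` (`<i8`); the header's `descr` is not
/// validated here.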
pub struct MmapNpyArray1I64 {
    _mmap: Mmap,
    len: usize,
    data_offset: usize,
}

impl MmapNpyArray1I64 {
    pub fn from_npy_file(path: &Path) -> Result<Self> {
        let file = File::open(path)
            .map_err(|e| Error::IndexLoad(format!("Failed to open NPY file {:?}: {}", path, e)))?;

        let mmap = unsafe {
            Mmap::map(&file).map_err(|e| {
                Error::IndexLoad(format!("Failed to mmap NPY file {:?}: {}", path, e))
            })?
        };

        let (shape, data_offset, _fortran_order) = parse_npy_header(&mmap)?;

        if shape.is_empty() {
            return Err(Error::IndexLoad("Empty shape in NPY file".into()));
        }

        let len = shape[0];

        let expected_size = data_offset + len * 8;
        if mmap.len() < expected_size {
            return Err(Error::IndexLoad(format!(
                "NPY file size {} too small for {} elements",
                mmap.len(),
                len
            )));
        }

        Ok(Self {
            _mmap: mmap,
            len,
            data_offset,
        })
    }

    pub fn len(&self) -> usize {
        self.len
    }

    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    pub fn slice(&self, start: usize, end: usize) -> Vec<i64> {
        let count = end - start;
        let mut result = Vec::with_capacity(count);

        for i in start..end {
            result.push(self.get(i));
        }

        result
    }

    pub fn get(&self, idx: usize) -> i64 {
        let start = self.data_offset + idx * 8;
        let bytes = &self._mmap[start..start + 8];
        i64::from_le_bytes(bytes.try_into().unwrap())
    }
}

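/// Memory-mapped 2-D `f32` array backed by an NPY file. The element type is
/// assumed to be little-endian `f32` (`<f4`) in C (row-major) order; the
/// header's `descr` and `fortran_order` fields are not validated here.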
pub struct MmapNpyArray2F32 {
    _mmap: Mmap,
    shape: (usize, usize),
    data_offset: usize,
}

impl MmapNpyArray2F32 {
    pub fn from_npy_file(path: &Path) -> Result<Self> {
        let file = File::open(path)
            .map_err(|e| Error::IndexLoad(format!("Failed to open NPY file {:?}: {}", path, e)))?;

        let mmap = unsafe {
            Mmap::map(&file).map_err(|e| {
                Error::IndexLoad(format!("Failed to mmap NPY file {:?}: {}", path, e))
            })?
        };

        let (shape_vec, data_offset, _fortran_order) = parse_npy_header(&mmap)?;

        if shape_vec.len() != 2 {
            return Err(Error::IndexLoad(format!(
                "Expected 2D array, got {}D",
                shape_vec.len()
            )));
        }

        let shape = (shape_vec[0], shape_vec[1]);

        let expected_size = data_offset + shape.0 * shape.1 * 4;
        if mmap.len() < expected_size {
            return Err(Error::IndexLoad(format!(
                "NPY file size {} too small for shape {:?}",
                mmap.len(),
                shape
            )));
        }

        Ok(Self {
            _mmap: mmap,
            shape,
            data_offset,
        })
    }

    pub fn shape(&self) -> (usize, usize) {
        self.shape
    }

    pub fn nrows(&self) -> usize {
        self.shape.0
    }

    pub fn ncols(&self) -> usize {
        self.shape.1
    }

    pub fn view(&self) -> ArrayView2<'_, f32> {
        let byte_start = self.data_offset;
        let byte_end = self.data_offset + self.shape.0 * self.shape.1 * 4;
        let bytes = &self._mmap[byte_start..byte_end];

        // SAFETY: NPY headers are padded so the data section starts on a
        // 16-byte (64-byte in modern writers) boundary, which satisfies f32
        // alignment; the view borrows `self`, so the mapping outlives it.
        let data = unsafe {
            std::slice::from_raw_parts(bytes.as_ptr() as *const f32, self.shape.0 * self.shape.1)
        };

        ArrayView2::from_shape(self.shape, data).unwrap()
    }

    pub fn row(&self, idx: usize) -> ArrayView1<'_, f32> {
        let byte_start = self.data_offset + idx * self.shape.1 * 4;
        let bytes = &self._mmap[byte_start..byte_start + self.shape.1 * 4];

        // SAFETY: same alignment argument as in `view`.
        let data =
            unsafe { std::slice::from_raw_parts(bytes.as_ptr() as *const f32, self.shape.1) };

        ArrayView1::from_shape(self.shape.1, data).unwrap()
    }

    pub fn slice_rows(&self, start: usize, end: usize) -> ArrayView2<'_, f32> {
        let nrows = end - start;
        let byte_start = self.data_offset + start * self.shape.1 * 4;
        let byte_end = self.data_offset + end * self.shape.1 * 4;
        let bytes = &self._mmap[byte_start..byte_end];

        // SAFETY: same alignment argument as in `view`.
        let data = unsafe {
            std::slice::from_raw_parts(bytes.as_ptr() as *const f32, nrows * self.shape.1)
        };

        ArrayView2::from_shape((nrows, self.shape.1), data).unwrap()
    }

    pub fn to_owned(&self) -> Array2<f32> {
        self.view().to_owned()
    }
}

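/// Memory-mapped 2-D `u8` array backed by an NPY file (dtype assumed to be
/// `|u1`, C order).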
pub struct MmapNpyArray2U8 {
    _mmap: Mmap,
    shape: (usize, usize),
    data_offset: usize,
}

impl MmapNpyArray2U8 {
    pub fn from_npy_file(path: &Path) -> Result<Self> {
        let file = File::open(path)
            .map_err(|e| Error::IndexLoad(format!("Failed to open NPY file {:?}: {}", path, e)))?;

        let mmap = unsafe {
            Mmap::map(&file).map_err(|e| {
                Error::IndexLoad(format!("Failed to mmap NPY file {:?}: {}", path, e))
            })?
        };

        let (shape_vec, data_offset, _fortran_order) = parse_npy_header(&mmap)?;

        if shape_vec.len() != 2 {
            return Err(Error::IndexLoad(format!(
                "Expected 2D array, got {}D",
                shape_vec.len()
            )));
        }

        let shape = (shape_vec[0], shape_vec[1]);

        let expected_size = data_offset + shape.0 * shape.1;
        if mmap.len() < expected_size {
            return Err(Error::IndexLoad(format!(
                "NPY file size {} too small for shape {:?}",
                mmap.len(),
                shape
            )));
        }

        Ok(Self {
            _mmap: mmap,
            shape,
            data_offset,
        })
    }

    pub fn shape(&self) -> (usize, usize) {
        self.shape
    }

    pub fn nrows(&self) -> usize {
        self.shape.0
    }

    pub fn ncols(&self) -> usize {
        self.shape.1
    }

    pub fn slice_rows(&self, start: usize, end: usize) -> ArrayView2<'_, u8> {
        let nrows = end - start;
        let byte_start = self.data_offset + start * self.shape.1;
        let byte_end = self.data_offset + end * self.shape.1;
        let bytes = &self._mmap[byte_start..byte_end];

        ArrayView2::from_shape((nrows, self.shape.1), bytes).unwrap()
    }

    pub fn view(&self) -> ArrayView2<'_, u8> {
        self.slice_rows(0, self.shape.0)
    }

    pub fn row(&self, idx: usize) -> &[u8] {
        let byte_start = self.data_offset + idx * self.shape.1;
        let byte_end = byte_start + self.shape.1;
        &self._mmap[byte_start..byte_end]
    }
}

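/// Per-chunk bookkeeping for incremental merges: a previously merged artifact
/// is reused only when every source chunk still has the row count and mtime
/// that were recorded when the artifact was built.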
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ChunkManifestEntry {
    pub rows: usize,
    pub mtime: f64,
}

pub type ChunkManifest = HashMap<String, ChunkManifestEntry>;

fn load_manifest(manifest_path: &Path) -> Option<ChunkManifest> {
    if manifest_path.exists() {
        if let Ok(file) = File::open(manifest_path) {
            if let Ok(manifest) = serde_json::from_reader(BufReader::new(file)) {
                return Some(manifest);
            }
        }
    }
    None
}

fn save_manifest(manifest_path: &Path, manifest: &ChunkManifest) -> Result<()> {
    let file = File::create(manifest_path)
        .map_err(|e| Error::IndexLoad(format!("Failed to create manifest: {}", e)))?;
    serde_json::to_writer(BufWriter::new(file), manifest)
        .map_err(|e| Error::IndexLoad(format!("Failed to write manifest: {}", e)))?;
    Ok(())
}

fn get_mtime(path: &Path) -> Result<f64> {
    let metadata = fs::metadata(path)
        .map_err(|e| Error::IndexLoad(format!("Failed to get metadata for {:?}: {}", path, e)))?;
    let mtime = metadata
        .modified()
        .map_err(|e| Error::IndexLoad(format!("Failed to get mtime: {}", e)))?;
    let duration = mtime
        .duration_since(std::time::UNIX_EPOCH)
        .map_err(|e| Error::IndexLoad(format!("Invalid mtime: {}", e)))?;
    Ok(duration.as_secs_f64())
}

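/// Writes an NPY v1.0 header for a 1-D array and returns the total header
/// size in bytes (magic + version + length field + padded dict). The dict is
/// space-padded and newline-terminated so that the data section starts on a
/// 64-byte boundary.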
fn write_npy_header_1d(writer: &mut impl Write, len: usize, dtype: &str) -> Result<usize> {
    let header_dict = format!(
        "{{'descr': '{}', 'fortran_order': False, 'shape': ({},), }}",
        dtype, len
    );

    // Pad so that magic (6) + version (2) + length field (2) + dict + trailing
    // newline totals a multiple of 64 bytes, as the NPY format specifies; the
    // +1 accounts for the newline appended below.
    let header_len = header_dict.len();
    let padding = (64 - ((10 + header_len + 1) % 64)) % 64;
    let padded_header = format!("{}{}\n", header_dict, " ".repeat(padding));

    writer
        .write_all(NPY_MAGIC)
        .map_err(|e| Error::IndexLoad(format!("Failed to write NPY magic: {}", e)))?;
    writer
        .write_all(&[1, 0])
        .map_err(|e| Error::IndexLoad(format!("Failed to write version: {}", e)))?;

    let header_len_bytes = (padded_header.len() as u16).to_le_bytes();
    writer
        .write_all(&header_len_bytes)
        .map_err(|e| Error::IndexLoad(format!("Failed to write header len: {}", e)))?;

    writer
        .write_all(padded_header.as_bytes())
        .map_err(|e| Error::IndexLoad(format!("Failed to write header: {}", e)))?;

    Ok(10 + padded_header.len())
}

/// Writes an NPY v1.0 header for a 2-D C-order array; see
/// [`write_npy_header_1d`] for the layout and padding rules.
fn write_npy_header_2d(
    writer: &mut impl Write,
    nrows: usize,
    ncols: usize,
    dtype: &str,
) -> Result<usize> {
    let header_dict = format!(
        "{{'descr': '{}', 'fortran_order': False, 'shape': ({}, {}), }}",
        dtype, nrows, ncols
    );

    // Same 64-byte alignment rule as in `write_npy_header_1d`; the +1
    // accounts for the trailing newline.
    let header_len = header_dict.len();
    let padding = (64 - ((10 + header_len + 1) % 64)) % 64;
    let padded_header = format!("{}{}\n", header_dict, " ".repeat(padding));

    writer
        .write_all(NPY_MAGIC)
        .map_err(|e| Error::IndexLoad(format!("Failed to write NPY magic: {}", e)))?;
    writer
        .write_all(&[1, 0])
        .map_err(|e| Error::IndexLoad(format!("Failed to write version: {}", e)))?;

    let header_len_bytes = (padded_header.len() as u16).to_le_bytes();
    writer
        .write_all(&header_len_bytes)
        .map_err(|e| Error::IndexLoad(format!("Failed to write header len: {}", e)))?;

    writer
        .write_all(padded_header.as_bytes())
        .map_err(|e| Error::IndexLoad(format!("Failed to write header: {}", e)))?;

    Ok(10 + padded_header.len())
}

struct ChunkInfo {
    path: std::path::PathBuf,
    filename: String,
    rows: usize,
    mtime: f64,
}

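/// Concatenates the `{i}.codes.npy` chunk files under `index_path` into a
/// single `merged_codes.npy`, appending `padding_rows` zero rows at the end.
/// The rewrite is skipped when the saved manifest shows every chunk is
/// unchanged. Returns the path of the merged file.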
pub fn merge_codes_chunks(
    index_path: &Path,
    num_chunks: usize,
    padding_rows: usize,
) -> Result<std::path::PathBuf> {
    use ndarray_npy::ReadNpyExt;

    let merged_path = index_path.join("merged_codes.npy");
    let manifest_path = index_path.join("merged_codes.manifest.json");

    let old_manifest = load_manifest(&manifest_path);

    let mut chunks: Vec<ChunkInfo> = Vec::new();
    let mut total_rows = 0usize;
    let mut chain_broken = false;

    for i in 0..num_chunks {
        let filename = format!("{}.codes.npy", i);
        let path = index_path.join(&filename);

        if path.exists() {
            let mtime = get_mtime(&path)?;

            let file = File::open(&path)?;
            let arr: Array1<i64> = Array1::read_npy(file)?;
            let rows = arr.len();

            if rows > 0 {
                total_rows += rows;

                let is_clean = if let Some(ref manifest) = old_manifest {
                    manifest
                        .get(&filename)
                        .is_some_and(|entry| entry.mtime == mtime && entry.rows == rows)
                } else {
                    false
                };

                if !is_clean {
                    chain_broken = true;
                }

                chunks.push(ChunkInfo {
                    path,
                    filename,
                    rows,
                    mtime,
                });
            }
        }
    }

    // A chunk recorded in the old manifest but no longer present also
    // invalidates the merged file.
    if let Some(ref manifest) = old_manifest {
        if manifest.len() != chunks.len() {
            chain_broken = true;
        }
    }

    if total_rows == 0 {
        return Err(Error::IndexLoad("No data to merge".into()));
    }

    let final_rows = total_rows + padding_rows;

    let needs_full_rewrite = !merged_path.exists() || chain_broken;

    if needs_full_rewrite {
        let file = File::create(&merged_path)?;
        let mut writer = BufWriter::new(file);

        write_npy_header_1d(&mut writer, final_rows, "<i8")?;

        for chunk in &chunks {
            let file = File::open(&chunk.path)?;
            let arr: Array1<i64> = Array1::read_npy(file)?;
            for &val in arr.iter() {
                writer.write_all(&val.to_le_bytes())?;
            }
        }

        for _ in 0..padding_rows {
            writer.write_all(&0i64.to_le_bytes())?;
        }

        writer.flush()?;
    }

    let mut new_manifest = ChunkManifest::new();
    for chunk in &chunks {
        new_manifest.insert(
            chunk.filename.clone(),
            ChunkManifestEntry {
                rows: chunk.rows,
                mtime: chunk.mtime,
            },
        );
    }
    save_manifest(&manifest_path, &new_manifest)?;

    Ok(merged_path)
}

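/// Like [`merge_codes_chunks`], but concatenates the 2-D `u8`
/// `{i}.residuals.npy` chunk files into `merged_residuals.npy`.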
pub fn merge_residuals_chunks(
    index_path: &Path,
    num_chunks: usize,
    padding_rows: usize,
) -> Result<std::path::PathBuf> {
    use ndarray_npy::ReadNpyExt;

    let merged_path = index_path.join("merged_residuals.npy");
    let manifest_path = index_path.join("merged_residuals.manifest.json");

    let old_manifest = load_manifest(&manifest_path);

    let mut chunks: Vec<ChunkInfo> = Vec::new();
    let mut total_rows = 0usize;
    let mut ncols = 0usize;
    let mut chain_broken = false;

    for i in 0..num_chunks {
        let filename = format!("{}.residuals.npy", i);
        let path = index_path.join(&filename);

        if path.exists() {
            let mtime = get_mtime(&path)?;

            let file = File::open(&path)?;
            let arr: Array2<u8> = Array2::read_npy(file)?;
            let rows = arr.nrows();
            if ncols == 0 {
                ncols = arr.ncols();
            } else if arr.ncols() != ncols {
                return Err(Error::IndexLoad(format!(
                    "Residual chunk {} has {} columns, expected {}",
                    filename,
                    arr.ncols(),
                    ncols
                )));
            }

            if rows > 0 {
                total_rows += rows;

                let is_clean = if let Some(ref manifest) = old_manifest {
                    manifest
                        .get(&filename)
                        .is_some_and(|entry| entry.mtime == mtime && entry.rows == rows)
                } else {
                    false
                };

                if !is_clean {
                    chain_broken = true;
                }

                chunks.push(ChunkInfo {
                    path,
                    filename,
                    rows,
                    mtime,
                });
            }
        }
    }

    // A chunk recorded in the old manifest but no longer present also
    // invalidates the merged file.
    if let Some(ref manifest) = old_manifest {
        if manifest.len() != chunks.len() {
            chain_broken = true;
        }
    }

    if total_rows == 0 || ncols == 0 {
        return Err(Error::IndexLoad("No residual data to merge".into()));
    }

    let final_rows = total_rows + padding_rows;

    let needs_full_rewrite = !merged_path.exists() || chain_broken;

    if needs_full_rewrite {
        let file = File::create(&merged_path)?;
        let mut writer = BufWriter::new(file);

        write_npy_header_2d(&mut writer, final_rows, ncols, "|u1")?;

        for chunk in &chunks {
            let file = File::open(&chunk.path)?;
            let arr: Array2<u8> = Array2::read_npy(file)?;
            for row in arr.rows() {
                writer.write_all(row.as_slice().unwrap())?;
            }
        }

        let zero_row = vec![0u8; ncols];
        for _ in 0..padding_rows {
            writer.write_all(&zero_row)?;
        }

        writer.flush()?;
    }

    let mut new_manifest = ChunkManifest::new();
    for chunk in &chunks {
        new_manifest.insert(
            chunk.filename.clone(),
            ChunkManifestEntry {
                rows: chunk.rows,
                mtime: chunk.mtime,
            },
        );
    }
    save_manifest(&manifest_path, &new_manifest)?;

    Ok(merged_path)
}

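/// Converts a fastplaid index directory in place to the formats expected
/// here: float16 arrays become float32, `ivf_lengths.npy` becomes int32, and
/// `.residuals.npy` files get the canonical `|u1` dtype descriptor. Returns
/// `true` if any file was rewritten.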
pub fn convert_fastplaid_to_nextplaid(index_path: &Path) -> Result<bool> {
    let mut converted = false;

    let float_files = [
        "centroids.npy",
        "avg_residual.npy",
        "bucket_cutoffs.npy",
        "bucket_weights.npy",
    ];

    for filename in float_files {
        let path = index_path.join(filename);
        if path.exists() {
            let dtype = detect_npy_dtype(&path)?;
            if dtype == "<f2" {
                eprintln!(" Converting {} from float16 to float32", filename);
                convert_f16_to_f32_npy(&path)?;
                converted = true;
            }
        }
    }

    let ivf_lengths_path = index_path.join("ivf_lengths.npy");
    if ivf_lengths_path.exists() {
        let dtype = detect_npy_dtype(&ivf_lengths_path)?;
        if dtype == "<i8" {
            eprintln!(" Converting ivf_lengths.npy from int64 to int32");
            convert_i64_to_i32_npy(&ivf_lengths_path)?;
            converted = true;
        }
    }

    for entry in fs::read_dir(index_path)? {
        let entry = entry?;
        let filename = entry.file_name().to_string_lossy().to_string();
        if filename.ends_with(".residuals.npy") {
            let path = entry.path();
            let dtype = detect_npy_dtype(&path)?;
            if dtype == "<u1" {
                eprintln!(
                    " Normalizing {} dtype descriptor from <u1 to |u1",
                    filename
                );
                normalize_u8_npy(&path)?;
                converted = true;
            }
        }
    }

    Ok(converted)
}

1397
1398#[cfg(test)]
1399mod tests {
1400 use super::*;
1401 use std::io::Write;
1402 use tempfile::NamedTempFile;
1403
1404 #[test]
1405 fn test_mmap_array2_f32() {
1406 let mut file = NamedTempFile::new().unwrap();
1408
1409 file.write_all(&3i64.to_le_bytes()).unwrap();
1411 file.write_all(&2i64.to_le_bytes()).unwrap();
1412
1413 for val in [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0] {
1415 file.write_all(&val.to_le_bytes()).unwrap();
1416 }
1417
1418 file.flush().unwrap();
1419
1420 let mmap = MmapArray2F32::from_raw_file(file.path()).unwrap();
1422 assert_eq!(mmap.shape(), (3, 2));
1423
1424 let row0 = mmap.row(0);
1425 assert_eq!(row0[0], 1.0);
1426 assert_eq!(row0[1], 2.0);
1427
1428 let owned = mmap.to_owned();
1429 assert_eq!(owned[[2, 0]], 5.0);
1430 assert_eq!(owned[[2, 1]], 6.0);
1431 }
1432
1433 #[test]
1434 fn test_mmap_array1_i64() {
1435 let mut file = NamedTempFile::new().unwrap();
1436
1437 file.write_all(&4i64.to_le_bytes()).unwrap();
1439
1440 for val in [10i64, 20, 30, 40] {
1442 file.write_all(&val.to_le_bytes()).unwrap();
1443 }
1444
1445 file.flush().unwrap();
1446
1447 let mmap = MmapArray1I64::from_raw_file(file.path()).unwrap();
1448 assert_eq!(mmap.len(), 4);
1449 assert_eq!(mmap.get(0), 10);
1450 assert_eq!(mmap.get(3), 40);
1451
1452 let owned = mmap.to_owned();
1453 assert_eq!(owned[1], 20);
1454 assert_eq!(owned[2], 30);
1455 }
1456
1457 #[test]
1458 fn test_write_read_roundtrip() {
1459 let file = NamedTempFile::new().unwrap();
1460 let path = file.path();
1461
1462 let array = Array2::from_shape_vec((2, 3), vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap();
1464
1465 write_array2_f32(&array, path).unwrap();
1467
1468 let mmap = MmapArray2F32::from_raw_file(path).unwrap();
1470 let loaded = mmap.to_owned();
1471
1472 assert_eq!(array, loaded);
1473 }
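
    #[test]
    fn test_npy_header_1d_roundtrip() {
        // Round-trips a header written by `write_npy_header_1d` through
        // `MmapNpyArray1I64`, and checks the 64-byte alignment of the data
        // section that the NPY format calls for.
        let mut file = NamedTempFile::new().unwrap();

        let data_offset = {
            let mut writer = BufWriter::new(file.as_file_mut());
            let offset = write_npy_header_1d(&mut writer, 4, "<i8").unwrap();
            for val in [10i64, 20, 30, 40] {
                writer.write_all(&val.to_le_bytes()).unwrap();
            }
            writer.flush().unwrap();
            offset
        };
        assert_eq!(data_offset % 64, 0);

        let arr = MmapNpyArray1I64::from_npy_file(file.path()).unwrap();
        assert_eq!(arr.len(), 4);
        assert_eq!(arr.get(0), 10);
        assert_eq!(arr.slice(1, 4), vec![20, 30, 40]);
    }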
}