1use std::{
10 convert::TryInto,
11 io::{BufReader, Read, Seek, SeekFrom, Write},
12 mem,
13};
14
15use bytemuck::Pod;
16use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
17use diskann::{ANNError, ANNErrorKind, ANNResult, utils::read_exact_into};
18use diskann_wide::{LoHi, SplitJoin};
19use thiserror::Error;
20use tracing::info;
21
22use crate::utils::DatasetDto;
23
/// Default buffer capacity for buffered I/O in this module: 1 MiB.
const DEFAULT_BUF_SIZE: usize = 1 << 20;
25
/// Decoded 8-byte header of a binary vector file.
///
/// On disk both fields are stored as little-endian `u32`s (point count first);
/// they are widened to `usize` once read.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Metadata {
    /// Number of points (vectors) declared by the header.
    pub npoints: usize,
    /// Number of dimensions per point declared by the header.
    pub ndims: usize,
}
32
33#[derive(Debug, Error)]
35pub enum MetadataError<T, U> {
36 #[error("num points conversion")]
37 NumPoints(#[source] T),
38 #[error("dim conversion")]
39 Dim(#[source] U),
40 #[error("writing binary results")]
41 Write(#[source] std::io::Error),
42}
43
44impl<T, U> From<MetadataError<T, U>> for ANNError
45where
46 T: std::error::Error + Send + Sync + 'static,
47 U: std::error::Error + Send + Sync + 'static,
48{
49 #[track_caller]
50 fn from(err: MetadataError<T, U>) -> Self {
51 ANNError::new(ANNErrorKind::IOError, err)
52 }
53}
54
55pub fn read_metadata<Reader: Read>(reader: &mut Reader) -> std::io::Result<Metadata> {
65 let raw = reader.read_u64::<LittleEndian>()?;
66 let bytes: [u8; 8] = bytemuck::cast(raw);
67 let LoHi {
68 lo: npts_bytes,
69 hi: ndims_bytes,
70 } = bytes.split();
71 let npoints = u32::from_le_bytes(npts_bytes) as usize;
72 let ndims = u32::from_le_bytes(ndims_bytes) as usize;
73 Ok(Metadata { npoints, ndims })
74}
75
76pub fn write_metadata<Writer: Write, N, D>(
90 writer: &mut Writer,
91 npts: N,
92 ndims: D,
93) -> Result<usize, MetadataError<N::Error, D::Error>>
94where
95 N: TryInto<u32>,
96 D: TryInto<u32>,
97 N::Error: std::error::Error + 'static,
98 D::Error: std::error::Error + 'static,
99{
100 let npts_u32 = npts.try_into().map_err(MetadataError::NumPoints)?;
101 let ndims_u32 = ndims.try_into().map_err(MetadataError::Dim)?;
102
103 let bytes: [u8; 8] = LoHi::new(npts_u32.to_le_bytes(), ndims_u32.to_le_bytes()).join();
104 writer.write_all(&bytes).map_err(MetadataError::Write)?;
105
106 Ok(2 * std::mem::size_of::<u32>())
107}
108
109pub fn load_vector_ids<Reader: Read>(reader: &mut Reader) -> std::io::Result<(usize, Vec<u32>)> {
111 let mut reader = BufReader::new(reader);
115 let num_ids = reader.read_u32::<LittleEndian>()? as usize;
116
117 let mut ids = Vec::with_capacity(num_ids);
118 for _ in 0..num_ids {
119 let id = reader.read_u32::<LittleEndian>()?;
120 ids.push(id);
121 }
122
123 Ok((num_ids, ids))
124}
125
126#[cfg(target_endian = "little")]
138pub fn copy_aligned_data<T: Default + bytemuck::Pod, Reader: Read>(
139 reader: &mut Reader,
140 dataset_dto: DatasetDto<T>,
141 pts_offset: usize,
142) -> std::io::Result<(usize, usize)> {
143 let mut reader = BufReader::with_capacity(DEFAULT_BUF_SIZE, reader);
144
145 let metadata = read_metadata(&mut reader)?;
146 let (npts, dim) = (metadata.npoints, metadata.ndims);
147 let rounded_dim = dataset_dto.rounded_dim;
148 let offset = pts_offset * rounded_dim;
149
150 for i in 0..npts {
151 let data_slice =
152 &mut dataset_dto.data[offset + i * rounded_dim..offset + i * rounded_dim + dim];
153
154 let byte_slice: &mut [u8] = bytemuck::must_cast_slice_mut(data_slice);
156 reader.read_exact(byte_slice)?;
157
158 let remaining = &mut dataset_dto.data
159 [offset + i * rounded_dim + dim..offset + i * rounded_dim + rounded_dim];
160 remaining.fill_with(Default::default);
161 }
162
163 Ok((npts, dim))
164}
165
166pub fn load_bin<T: Pod + Default, Reader: Read + Seek>(
171 reader: &mut Reader,
172 offset: usize,
173) -> std::io::Result<(Vec<T>, usize, usize)> {
174 let mut reader = BufReader::new(reader);
175 reader.seek(std::io::SeekFrom::Start(offset as u64))?;
176 let metadata = read_metadata(&mut reader)?;
177 let (npts, dim) = (metadata.npoints, metadata.ndims);
178
179 let size = npts * dim * std::mem::size_of::<T>();
180
181 let buf: Vec<T> = read_exact_into(&mut reader, npts * dim)?;
182 info!(
183 "bin: #pts = {}, #dims = {}, offset = {} size = {}B",
184 npts, dim, offset, size
185 );
186
187 Ok((buf, npts, dim))
188}
189
190pub fn save_bytes<Writer: Write + Seek>(
192 writer: &mut Writer,
193 data: &[u8],
194 npts: usize,
195 ndims: usize,
196 offset: usize,
197) -> ANNResult<usize> {
198 writer.seek(std::io::SeekFrom::Start(offset as u64))?;
199 write_metadata(writer, npts, ndims)?;
200 writer.write_all(data)?;
201 writer.flush()?;
202
203 Ok(data.len() + 2 * std::mem::size_of::<u32>())
204}
205
206pub fn save_data_in_base_dimensions<T: Default + Copy + bytemuck::Pod, Writer: Write + Seek>(
215 writer: &mut Writer,
216 data: &[T],
217 npts: usize,
218 ndims: usize,
219 aligned_dim: usize,
220 offset: usize,
221) -> ANNResult<usize> {
222 let bytes_written = 2 * std::mem::size_of::<u32>() + npts * ndims * (std::mem::size_of::<T>());
223
224 writer.seek(std::io::SeekFrom::Start(offset as u64))?;
225 write_metadata(writer, npts, ndims)?;
226
227 for i in 0..npts {
228 let start = i * aligned_dim;
229 let end = start + ndims;
230 let vector_slice = &data[start..end];
231 let bytes: &[u8] = bytemuck::must_cast_slice(vector_slice);
233 writer.write_all(bytes)?;
234 }
235 writer.flush()?;
236 Ok(bytes_written)
237}
238
/// Generates `pub fn $name` that writes a slice of `$t` as a binary file,
/// encoding each element with the byteorder method `$write_func`.
macro_rules! save_bin {
    ($name:ident, $t:ty, $write_func:ident) => {
        /// Writes a header (`num_pts`, `dims` as little-endian `u32`) followed by
        /// every element of `data` in little-endian order. Output starts `offset`
        /// bytes into `writer` and is flushed at the end.
        ///
        /// Returns the size the header declares:
        /// `num_pts * dims * size_of::<T>() + 8` bytes.
        ///
        /// NOTE: all of `data` is written even if `data.len()` differs from
        /// `num_pts * dims`; the return value is derived from the header
        /// arguments, not from `data.len()`.
        ///
        /// # Errors
        /// Fails on seek/write/flush errors, or if `num_pts`/`dims` do not fit
        /// in a `u32`.
        pub fn $name<W: Write + Seek>(
            writer: &mut W,
            data: &[$t],
            num_pts: usize,
            dims: usize,
            offset: usize,
        ) -> ANNResult<usize> {
            writer.seek(SeekFrom::Start(offset as u64))?;
            let bytes_written = num_pts * dims * mem::size_of::<$t>() + 2 * mem::size_of::<u32>();

            write_metadata(writer, num_pts, dims)?;
            info!(
                "bin: #pts = {}, #dims = {}, size = {}B",
                num_pts, dims, bytes_written
            );

            // `writer` may be unbuffered (e.g. a bare `File`), where one
            // `write` per element is one syscall per element. Stage the payload
            // through a buffer sized to the payload, capped at DEFAULT_BUF_SIZE.
            {
                let mut out = std::io::BufWriter::with_capacity(
                    DEFAULT_BUF_SIZE.min(mem::size_of_val(data)),
                    &mut *writer,
                );
                for item in data.iter() {
                    out.$write_func::<LittleEndian>(*item)?;
                }
                out.flush()?;
            }

            writer.flush()?;

            info!("Finished writing bin.");
            Ok(bytes_written)
        }
    };
}
269
270save_bin!(save_bin_f32, f32, write_f32);
271save_bin!(save_bin_u64, u64, write_u64);
272save_bin!(save_bin_u32, u32, write_u32);
273
#[cfg(test)]
mod storage_util_test {
    use crate::storage::{StorageReadProvider, StorageWriteProvider, VirtualStorageProvider};
    use tempfile::tempfile;
    use vfs::MemoryFS;

    use super::*;

    /// Dimension used by the fixed-size fixtures below.
    pub const DIM_8: usize = 8;

    #[test]
    fn read_metadata_test() {
        let file_name = "/test_read_metadata_test.bin";
        // npoints = 200, ndims = 128, both little-endian u32.
        let data = [200, 0, 0, 0, 128, 0, 0, 0];
        let vfs = MemoryFS::default();
        let storage_provider = VirtualStorageProvider::new(vfs);
        {
            let mut file = storage_provider
                .create_for_write(file_name)
                .expect("Could not create file");
            file.write_all(&data)
                .expect("Should be able to write sample file");
        }

        let mut reader = storage_provider.open_reader(file_name).unwrap();
        // Unwrap rather than match: a read failure must fail the test instead
        // of silently skipping the assertions.
        let metadata = read_metadata(&mut reader).expect("Should be able to read metadata");
        assert_eq!(metadata.npoints, 200);
        assert_eq!(metadata.ndims, 128);

        storage_provider
            .delete(file_name)
            .expect("Should be able to delete sample file");
    }

    #[test]
    fn read_metadata_i32_compatibility_test() {
        let file_name = "/test_read_metadata_i32_compat.bin";
        // Headers written as non-negative i32 share the u32 little-endian
        // encoding, so they must read back unchanged.
        let npts = 200i32;
        let dims = 128i32;
        let vfs = MemoryFS::default();
        let storage_provider = VirtualStorageProvider::new(vfs);
        {
            let mut file = storage_provider
                .create_for_write(file_name)
                .expect("Could not create file");
            file.write_i32::<LittleEndian>(npts).unwrap();
            file.write_i32::<LittleEndian>(dims).unwrap();
        }

        let mut reader = storage_provider.open_reader(file_name).unwrap();
        let metadata = read_metadata(&mut reader).unwrap();

        assert_eq!(metadata.npoints, 200);
        assert_eq!(metadata.ndims, 128);

        storage_provider
            .delete(file_name)
            .expect("Should be able to delete sample file");
    }

    #[test]
    fn load_vector_ids_test() {
        let file_name = "/load_vector_ids_test";
        let ids = vec![0u32, 1u32, 2u32];
        let num_ids = ids.len();
        let vfs = MemoryFS::new();
        let storage_provider = VirtualStorageProvider::new(vfs);
        {
            let mut writer = storage_provider.create_for_write(file_name).unwrap();
            writer.write_u32::<LittleEndian>(num_ids as u32).unwrap();
            for item in ids.iter() {
                writer.write_u32::<LittleEndian>(*item).unwrap();
            }
        }

        let load_data =
            load_vector_ids(&mut storage_provider.open_reader(file_name).unwrap()).unwrap();
        assert_eq!(load_data, (num_ids, ids));
        storage_provider
            .delete(file_name)
            .expect("Should be able to delete sample file");
    }

    #[test]
    fn load_bin_test() {
        let file_name = "/load_bin_test";
        let data = vec![0u64, 1u64, 2u64];
        let num_pts = data.len();
        let dims = 1;
        let vfs = MemoryFS::new();
        let storage_provider = VirtualStorageProvider::new(vfs);
        let bytes_written = save_bin_u64(
            &mut storage_provider.create_for_write(file_name).unwrap(),
            &data,
            num_pts,
            dims,
            0,
        )
        .unwrap();
        assert_eq!(bytes_written, 32);

        let (load_data, load_num_pts, load_dims) =
            load_bin::<u64, _>(&mut storage_provider.open_reader(file_name).unwrap(), 0).unwrap();
        assert_eq!(load_num_pts, num_pts);
        assert_eq!(load_dims, dims);
        assert_eq!(load_data, data);
        storage_provider.delete(file_name).unwrap();
    }

    #[test]
    fn load_bin_offset_test() {
        let offset: usize = 32;
        let file_name = "/load_bin_offset_test";
        let data = vec![0u64, 1u64, 2u64];
        let num_pts = data.len();
        let dims = 1;
        let vfs = MemoryFS::new();
        let storage_provider = VirtualStorageProvider::new(vfs);
        let bytes_written = save_bin_u64(
            &mut storage_provider.create_for_write(file_name).unwrap(),
            &data,
            num_pts,
            dims,
            offset,
        )
        .unwrap();
        assert_eq!(bytes_written, 32);

        let (load_data, load_num_pts, load_dims) = load_bin::<u64, _>(
            &mut storage_provider.open_reader(file_name).unwrap(),
            offset,
        )
        .unwrap();
        assert_eq!(load_num_pts, num_pts);
        assert_eq!(load_dims, dims);
        assert_eq!(load_data, data);
        storage_provider.delete(file_name).unwrap();
    }

    #[test]
    fn save_data_in_base_dimensions_test() {
        // Raw bytes: a header (2, 8) followed by little-endian f32 values 1.0..=16.0.
        let data: [u8; 72] = [
            2, 0, 0, 0, 8, 0, 0, 0, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00,
            0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40,
            0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00,
            0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41,
            0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41,
        ];
        let num_points = 2;
        let dim = DIM_8;
        let data_file = "/save_data_in_base_dimensions_test.data";
        let vfs = MemoryFS::new();
        let storage_provider = VirtualStorageProvider::new(vfs);
        match save_data_in_base_dimensions(
            &mut storage_provider.create_for_write(data_file).unwrap(),
            &data,
            num_points,
            dim,
            DIM_8,
            0,
        ) {
            Ok(num) => {
                assert!(storage_provider.exists(data_file));
                assert_eq!(
                    num,
                    2 * std::mem::size_of::<u32>() + num_points * dim * std::mem::size_of::<u8>()
                );
                storage_provider
                    .delete(data_file)
                    .expect("Failed to delete file");
            }
            Err(e) => {
                storage_provider
                    .delete(data_file)
                    .expect("Failed to delete file");
                panic!("{}", e)
            }
        }
    }

    #[test]
    fn save_bin_test() {
        let data = vec![0u64, 1u64, 2u64];
        let num_pts = data.len();
        let dims = 1;
        let mut file = tempfile().unwrap();
        let bytes_written = save_bin_u64::<_>(&mut file, &data, num_pts, dims, 0).unwrap();
        assert_eq!(bytes_written, 32);

        let mut buffer = vec![];
        file.seek(SeekFrom::Start(0)).unwrap();
        let metadata = read_metadata(&mut file).unwrap();

        file.read_to_end(&mut buffer).unwrap();
        let data_read: Vec<u64> = buffer
            .chunks_exact(8)
            .map(|b| u64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]))
            .collect();

        assert_eq!(num_pts, metadata.npoints);
        assert_eq!(dims, metadata.ndims);
        assert_eq!(data, data_read);
    }

    #[test]
    fn write_metadata_unified_test() {
        let mut buffer = Vec::new();

        // u32 arguments.
        let result = write_metadata(&mut buffer, 200u32, 128u32);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 8);

        buffer.clear();
        // usize arguments.
        let result = write_metadata(&mut buffer, 200usize, 128usize);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 8);

        buffer.clear();
        // Mixed argument types.
        let result = write_metadata(&mut buffer, 200usize, 128u32);
        assert!(result.is_ok());

        // The header must round-trip through read_metadata.
        let mut cursor = std::io::Cursor::new(&buffer);
        let metadata = read_metadata(&mut cursor).unwrap();
        assert_eq!(metadata.npoints, 200);
        assert_eq!(metadata.ndims, 128);
    }

    #[test]
    fn metadata_error_types_test() {
        let large_value = u32::MAX as usize + 1;
        // Point count too large for u32.
        let result = write_metadata(&mut Vec::new(), large_value, 128usize);
        assert!(matches!(result, Err(MetadataError::NumPoints(_))));

        // Dimension too large for u32.
        let result = write_metadata(&mut Vec::new(), 128usize, large_value);
        assert!(matches!(result, Err(MetadataError::Dim(_))));

        // A writer that always fails surfaces as MetadataError::Write.
        struct FailingWriter;

        impl std::io::Write for FailingWriter {
            fn write(&mut self, _: &[u8]) -> std::io::Result<usize> {
                Err(std::io::Error::new(
                    std::io::ErrorKind::PermissionDenied,
                    "fail",
                ))
            }
            fn flush(&mut self) -> std::io::Result<()> {
                Ok(())
            }
        }

        let result = write_metadata(&mut FailingWriter, 200u32, 128u32);
        assert!(matches!(result, Err(MetadataError::Write(_))));
    }

    #[test]
    fn metadata_error_to_ann_error_test() {
        use diskann::{ANNError, ANNErrorKind};

        let large_value = u32::MAX as usize + 1;
        let result = write_metadata(&mut Vec::new(), large_value, 128usize);
        let metadata_err = result.unwrap_err();
        let ann_error: ANNError = metadata_err.into();

        assert_eq!(ann_error.kind(), ANNErrorKind::IOError);

        // The rendered error should mention the underlying conversion failure.
        let error_str = ann_error.to_string();
        assert!(error_str.contains("num points conversion"));
    }
}
557
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    #[test]
    fn test_copy_aligned_data() -> std::io::Result<()> {
        // Header (2 points, 2 dims) followed by four f32 values, all little-endian.
        let header = [2_i32, 2].iter().flat_map(|v| v.to_le_bytes());
        let values = [1_f32, 2.0, 3.0, 4.0].iter().flat_map(|v| v.to_le_bytes());
        let mut reader = Cursor::new(header.chain(values).collect::<Vec<u8>>());

        let rounded_dim = 4;
        let mut aligned_data = vec![0f32; 2 * rounded_dim];

        let (npts, dim) = copy_aligned_data(
            &mut reader,
            DatasetDto::<f32> {
                data: &mut aligned_data,
                rounded_dim,
            },
            0,
        )?;

        assert_eq!((npts, dim), (2, 2));
        // Each row holds its two values followed by default padding up to rounded_dim.
        assert_eq!(aligned_data, vec![1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0]);

        Ok(())
    }
}