1use std::{
10 convert::TryInto,
11 io::{BufReader, Read, Seek, SeekFrom, Write},
12 mem,
13};
14
15use bytemuck::Pod;
16use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
17use diskann::{ANNError, ANNErrorKind, ANNResult, utils::read_exact_into};
18use diskann_wide::{LoHi, SplitJoin};
19use thiserror::Error;
20use tracing::info;
21
22use crate::utils::DatasetDto;
23
/// Capacity in bytes (1 MiB) of the `BufReader` that `copy_aligned_data` wraps around its input.
const DEFAULT_BUF_SIZE: usize = 1 << 20;
25
/// Header of a binary data file: how many points it holds and how many dimensions each has.
///
/// Returned by `read_metadata`, which widens the two 32-bit on-disk values to `usize`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Metadata {
    /// Number of points recorded in the header.
    pub npoints: usize,
    /// Number of dimensions recorded in the header.
    pub ndims: usize,
}
32
/// Errors returned by `write_metadata`.
///
/// `T` is the error type of the point-count conversion and `U` the error type of the
/// dimension-count conversion; both are carried as the error's source.
#[derive(Debug, Error)]
pub enum MetadataError<T, U> {
    /// The number of points could not be converted to `u32`.
    #[error("num points conversion")]
    NumPoints(#[source] T),
    /// The number of dimensions could not be converted to `u32`.
    #[error("dim conversion")]
    Dim(#[source] U),
    /// Writing the encoded header to the writer failed.
    #[error("writing binary results")]
    Write(#[source] std::io::Error),
}
43
44impl<T, U> From<MetadataError<T, U>> for ANNError
45where
46 T: std::error::Error + Send + Sync + 'static,
47 U: std::error::Error + Send + Sync + 'static,
48{
49 #[track_caller]
50 fn from(err: MetadataError<T, U>) -> Self {
51 ANNError::new(ANNErrorKind::IOError, err)
52 }
53}
54
55pub fn read_metadata<Reader: Read>(reader: &mut Reader) -> std::io::Result<Metadata> {
65 let raw = reader.read_u64::<LittleEndian>()?;
66 let bytes: [u8; 8] = bytemuck::cast(raw);
67 let LoHi {
68 lo: npts_bytes,
69 hi: ndims_bytes,
70 } = bytes.split();
71 let npoints = u32::from_le_bytes(npts_bytes) as usize;
72 let ndims = u32::from_le_bytes(ndims_bytes) as usize;
73 Ok(Metadata { npoints, ndims })
74}
75
76pub fn write_metadata<Writer: Write, N, D>(
90 writer: &mut Writer,
91 npts: N,
92 ndims: D,
93) -> Result<usize, MetadataError<N::Error, D::Error>>
94where
95 N: TryInto<u32>,
96 D: TryInto<u32>,
97 N::Error: std::error::Error + 'static,
98 D::Error: std::error::Error + 'static,
99{
100 let npts_u32 = npts.try_into().map_err(MetadataError::NumPoints)?;
101 let ndims_u32 = ndims.try_into().map_err(MetadataError::Dim)?;
102
103 let bytes: [u8; 8] = LoHi::new(npts_u32.to_le_bytes(), ndims_u32.to_le_bytes()).join();
104 writer.write_all(&bytes).map_err(MetadataError::Write)?;
105
106 Ok(2 * std::mem::size_of::<u32>())
107}
108
109pub fn load_vector_ids<Reader: Read>(reader: &mut Reader) -> std::io::Result<(usize, Vec<u32>)> {
111 let mut reader = BufReader::new(reader);
115 let num_ids = reader.read_u32::<LittleEndian>()? as usize;
116
117 let mut ids = Vec::with_capacity(num_ids);
118 for _ in 0..num_ids {
119 let id = reader.read_u32::<LittleEndian>()?;
120 ids.push(id);
121 }
122
123 Ok((num_ids, ids))
124}
125
126#[cfg(target_endian = "little")]
138pub fn copy_aligned_data<T: Default + bytemuck::Pod, Reader: Read>(
139 reader: &mut Reader,
140 dataset_dto: DatasetDto<T>,
141 pts_offset: usize,
142) -> std::io::Result<(usize, usize)> {
143 let mut reader = BufReader::with_capacity(DEFAULT_BUF_SIZE, reader);
144
145 let metadata = read_metadata(&mut reader)?;
146 let (npts, dim) = (metadata.npoints, metadata.ndims);
147 let rounded_dim = dataset_dto.rounded_dim;
148 let offset = pts_offset * rounded_dim;
149
150 for i in 0..npts {
151 let data_slice =
152 &mut dataset_dto.data[offset + i * rounded_dim..offset + i * rounded_dim + dim];
153
154 let byte_slice: &mut [u8] = bytemuck::must_cast_slice_mut(data_slice);
156 reader.read_exact(byte_slice)?;
157
158 let remaining = &mut dataset_dto.data
159 [offset + i * rounded_dim + dim..offset + i * rounded_dim + rounded_dim];
160 remaining.fill_with(Default::default);
161 }
162
163 Ok((npts, dim))
164}
165
166pub fn load_bin<T: Pod + Default, Reader: Read + Seek>(
171 reader: &mut Reader,
172 offset: usize,
173) -> std::io::Result<(Vec<T>, usize, usize)> {
174 let mut reader = BufReader::new(reader);
175 reader.seek(std::io::SeekFrom::Start(offset as u64))?;
176 let metadata = read_metadata(&mut reader)?;
177 let (npts, dim) = (metadata.npoints, metadata.ndims);
178
179 let size = npts * dim * std::mem::size_of::<T>();
180
181 let buf: Vec<T> = read_exact_into(&mut reader, npts * dim)?;
182 info!(
183 "bin: #pts = {}, #dims = {}, offset = {} size = {}B",
184 npts, dim, offset, size
185 );
186
187 Ok((buf, npts, dim))
188}
189
190pub fn save_bytes<Writer: Write + Seek>(
192 writer: &mut Writer,
193 data: &[u8],
194 npts: usize,
195 ndims: usize,
196 offset: usize,
197) -> ANNResult<usize> {
198 writer.seek(std::io::SeekFrom::Start(offset as u64))?;
199 write_metadata(writer, npts, ndims)?;
200 writer.write_all(data)?;
201 writer.flush()?;
202
203 Ok(data.len() + 2 * std::mem::size_of::<u32>())
204}
205
206pub fn save_data_in_base_dimensions<T: Default + Copy + bytemuck::Pod, Writer: Write + Seek>(
215 writer: &mut Writer,
216 data: &[T],
217 npts: usize,
218 ndims: usize,
219 aligned_dim: usize,
220 offset: usize,
221) -> ANNResult<usize> {
222 let bytes_written = 2 * std::mem::size_of::<u32>() + npts * ndims * (std::mem::size_of::<T>());
223
224 writer.seek(std::io::SeekFrom::Start(offset as u64))?;
225 write_metadata(writer, npts, ndims)?;
226
227 for i in 0..npts {
228 let start = i * aligned_dim;
229 let end = start + ndims;
230 let vector_slice = &data[start..end];
231 let bytes: &[u8] = bytemuck::must_cast_slice(vector_slice);
233 writer.write_all(bytes)?;
234 }
235 writer.flush()?;
236 Ok(bytes_written)
237}
238
239macro_rules! save_bin {
240 ($name:ident, $t:ty, $write_func:ident) => {
241 pub fn $name<W: Write + Seek>(
243 writer: &mut W,
244 data: &[$t],
245 num_pts: usize,
246 dims: usize,
247 offset: usize,
248 ) -> ANNResult<usize> {
249 writer.seek(SeekFrom::Start(offset as u64))?;
250 let bytes_written = num_pts * dims * mem::size_of::<$t>() + 2 * mem::size_of::<u32>();
251
252 write_metadata(writer, num_pts, dims)?;
253 info!(
254 "bin: #pts = {}, #dims = {}, size = {}B",
255 num_pts, dims, bytes_written
256 );
257
258 for item in data.iter() {
259 writer.$write_func::<LittleEndian>(*item)?;
260 }
261
262 writer.flush()?;
263
264 info!("Finished writing bin.");
265 Ok(bytes_written)
266 }
267 };
268}
269
270save_bin!(save_bin_f32, f32, write_f32);
271save_bin!(save_bin_u64, u64, write_u64);
272save_bin!(save_bin_u32, u32, write_u32);
273
/// Tests for the metadata header helpers and the load/save functions, run against an
/// in-memory `VirtualStorageProvider` or a temporary file.
#[cfg(test)]
mod storage_util_test {
    use crate::storage::{StorageReadProvider, StorageWriteProvider, VirtualStorageProvider};
    use tempfile::tempfile;

    use super::*;
    pub const DIM_8: usize = 8;

    /// A raw 8-byte header decodes into the expected point and dimension counts.
    #[test]
    fn read_metadata_test() {
        let file_name = "/test_read_metadata_test.bin";
        // 200 and 128, each as a little-endian u32.
        let data = [200, 0, 0, 0, 128, 0, 0, 0];
        let storage_provider = VirtualStorageProvider::new_memory();
        {
            let mut file = storage_provider
                .create_for_write(file_name)
                .expect("Could not create file");
            file.write_all(&data)
                .expect("Should be able to write sample file");
        }

        let mut reader = storage_provider.open_reader(file_name).unwrap();
        // A failed read must fail the test instead of being silently accepted.
        let metadata = read_metadata(&mut reader).expect("Should be able to read metadata");
        assert_eq!(metadata.npoints, 200);
        assert_eq!(metadata.ndims, 128);

        storage_provider
            .delete(file_name)
            .expect("Should be able to delete sample file");
    }

    /// A header written as two little-endian `i32` values is read back unchanged.
    #[test]
    fn read_metadata_i32_compatibility_test() {
        let file_name = "/test_read_metadata_i32_compat.bin";
        let npts = 200i32;
        let dims = 128i32;
        let storage_provider = VirtualStorageProvider::new_memory();
        {
            let mut file = storage_provider
                .create_for_write(file_name)
                .expect("Could not create file");
            file.write_i32::<LittleEndian>(npts).unwrap();
            file.write_i32::<LittleEndian>(dims).unwrap();
        }

        let mut reader = storage_provider.open_reader(file_name).unwrap();
        let metadata = read_metadata(&mut reader).unwrap();

        assert_eq!(metadata.npoints, 200);
        assert_eq!(metadata.ndims, 128);

        storage_provider
            .delete(file_name)
            .expect("Should be able to delete sample file");
    }

    /// A count-prefixed id list round-trips through `load_vector_ids`.
    #[test]
    fn load_vector_ids_test() {
        let file_name = "/load_vector_ids_test";
        let ids = vec![0u32, 1u32, 2u32];
        let num_ids = ids.len();
        let storage_provider = VirtualStorageProvider::new_memory();
        {
            let mut writer = storage_provider.create_for_write(file_name).unwrap();
            writer.write_u32::<LittleEndian>(num_ids as u32).unwrap();
            for item in ids.iter() {
                writer.write_u32::<LittleEndian>(*item).unwrap();
            }
        }

        let load_data =
            load_vector_ids(&mut storage_provider.open_reader(file_name).unwrap()).unwrap();
        assert_eq!(load_data, (num_ids, ids));
        storage_provider
            .delete(file_name)
            .expect("Should be able to delete sample file");
    }

    /// Data written with `save_bin_u64` at offset 0 is read back by `load_bin`.
    #[test]
    fn load_bin_test() {
        let file_name = "/load_bin_test";
        let data = vec![0u64, 1u64, 2u64];
        let num_pts = data.len();
        let dims = 1;
        let storage_provider = VirtualStorageProvider::new_memory();
        let bytes_written = save_bin_u64(
            &mut storage_provider.create_for_write(file_name).unwrap(),
            &data,
            num_pts,
            dims,
            0,
        )
        .unwrap();
        // 8-byte header + 3 * 8 bytes of payload.
        assert_eq!(bytes_written, 32);

        let (load_data, load_num_pts, load_dims) =
            load_bin::<u64, _>(&mut storage_provider.open_reader(file_name).unwrap(), 0).unwrap();
        assert_eq!(load_num_pts, num_pts);
        assert_eq!(load_dims, dims);
        assert_eq!(load_data, data);
        storage_provider.delete(file_name).unwrap();
    }

    /// Same round trip as `load_bin_test`, but with the block placed at a non-zero offset.
    #[test]
    fn load_bin_offset_test() {
        let offset: usize = 32;
        let file_name = "/load_bin_offset_test";
        let data = vec![0u64, 1u64, 2u64];
        let num_pts = data.len();
        let dims = 1;
        let storage_provider = VirtualStorageProvider::new_memory();
        let bytes_written = save_bin_u64(
            &mut storage_provider.create_for_write(file_name).unwrap(),
            &data,
            num_pts,
            dims,
            offset,
        )
        .unwrap();
        assert_eq!(bytes_written, 32);

        let (load_data, load_num_pts, load_dims) = load_bin::<u64, _>(
            &mut storage_provider.open_reader(file_name).unwrap(),
            offset,
        )
        .unwrap();
        assert_eq!(load_num_pts, num_pts);
        assert_eq!(load_dims, dims);
        assert_eq!(load_data, data);
        storage_provider.delete(file_name).unwrap();
    }

    /// `save_data_in_base_dimensions` creates the file and reports header + payload size.
    #[test]
    fn save_data_in_base_dimensions_test() {
        let data: [u8; 72] = [
            2, 0, 0, 0, 8, 0, 0, 0, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00,
            0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40,
            0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00,
            0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41,
            0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41,
        ];
        let num_points = 2;
        let dim = DIM_8;
        let data_file = "/save_data_in_base_dimensions_test.data";
        let storage_provider = VirtualStorageProvider::new_memory();
        match save_data_in_base_dimensions(
            &mut storage_provider.create_for_write(data_file).unwrap(),
            &data,
            num_points,
            dim,
            DIM_8,
            0,
        ) {
            Ok(num) => {
                assert!(storage_provider.exists(data_file));
                assert_eq!(
                    num,
                    2 * std::mem::size_of::<u32>() + num_points * dim * std::mem::size_of::<u8>()
                );
                storage_provider
                    .delete(data_file)
                    .expect("Failed to delete file");
            }
            Err(e) => {
                // Clean up before reporting the failure.
                storage_provider
                    .delete(data_file)
                    .expect("Failed to delete file");
                panic!("{}", e)
            }
        }
    }

    /// `save_bin_u64` output on a real temporary file decodes back to the same header
    /// and values.
    #[test]
    fn save_bin_test() {
        let data = vec![0u64, 1u64, 2u64];
        let num_pts = data.len();
        let dims = 1;
        let mut file = tempfile().unwrap();
        let bytes_written = save_bin_u64::<_>(&mut file, &data, num_pts, dims, 0).unwrap();
        assert_eq!(bytes_written, 32);

        let mut buffer = vec![];
        file.seek(SeekFrom::Start(0)).unwrap();
        let metadata = read_metadata(&mut file).unwrap();

        // Everything after the header is the little-endian u64 payload.
        file.read_to_end(&mut buffer).unwrap();
        let data_read: Vec<u64> = buffer
            .chunks_exact(8)
            .map(|b| u64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]))
            .collect();

        assert_eq!(num_pts, metadata.npoints);
        assert_eq!(dims, metadata.ndims);
        assert_eq!(data, data_read);
    }

    /// `write_metadata` accepts `u32`, `usize` and mixed argument types, and its output is
    /// readable by `read_metadata`.
    #[test]
    fn write_metadata_unified_test() {
        let mut buffer = Vec::new();

        let result = write_metadata(&mut buffer, 200u32, 128u32);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 8);

        buffer.clear();
        let result = write_metadata(&mut buffer, 200usize, 128usize);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 8);

        buffer.clear();
        let result = write_metadata(&mut buffer, 200usize, 128u32);
        assert!(result.is_ok());

        let mut cursor = std::io::Cursor::new(&buffer);
        let metadata = read_metadata(&mut cursor).unwrap();
        assert_eq!(metadata.npoints, 200);
        assert_eq!(metadata.ndims, 128);
    }

    /// Each failure mode of `write_metadata` maps to its own `MetadataError` variant.
    #[test]
    fn metadata_error_types_test() {
        let large_value = u32::MAX as usize + 1;
        let result = write_metadata(&mut Vec::new(), large_value, 128usize);
        assert!(matches!(result, Err(MetadataError::NumPoints(_))));

        let result = write_metadata(&mut Vec::new(), 128usize, large_value);
        assert!(matches!(result, Err(MetadataError::Dim(_))));

        // Writer whose every write fails, to exercise the `Write` variant.
        struct FailingWriter;
        impl std::io::Write for FailingWriter {
            fn write(&mut self, _: &[u8]) -> std::io::Result<usize> {
                Err(std::io::Error::new(
                    std::io::ErrorKind::PermissionDenied,
                    "fail",
                ))
            }
            fn flush(&mut self) -> std::io::Result<()> {
                Ok(())
            }
        }

        let result = write_metadata(&mut FailingWriter, 200u32, 128u32);
        assert!(matches!(result, Err(MetadataError::Write(_))));
    }

    /// Converting a `MetadataError` into an `ANNError` yields an `IOError` whose message
    /// contains the original error text.
    #[test]
    fn metadata_error_to_ann_error_test() {
        use diskann::{ANNError, ANNErrorKind};

        let large_value = u32::MAX as usize + 1;
        let result = write_metadata(&mut Vec::new(), large_value, 128usize);
        let metadata_err = result.unwrap_err();
        let ann_error: ANNError = metadata_err.into();

        assert_eq!(ann_error.kind(), ANNErrorKind::IOError);

        let error_str = ann_error.to_string();
        assert!(error_str.contains("num points conversion"));
    }
}
550
/// Tests for `copy_aligned_data`.
///
/// Gated on little-endian targets as well as `test`, because `copy_aligned_data` itself is
/// only compiled under `cfg(target_endian = "little")`.
#[cfg(all(test, target_endian = "little"))]
mod tests {
    use std::io::Cursor;

    use super::*;

    /// Two 2-dimensional points are copied into rows of width 4, with the padding zeroed.
    #[test]
    fn test_copy_aligned_data() -> std::io::Result<()> {
        // Header (npts = 2, dim = 2) followed by four little-endian f32 values.
        let mut data = Vec::with_capacity(24);
        data.extend_from_slice(&(2_i32.to_le_bytes()));
        data.extend_from_slice(&(2_i32.to_le_bytes()));
        data.extend_from_slice(&(1_f32.to_le_bytes()));
        data.extend_from_slice(&(2_f32.to_le_bytes()));
        data.extend_from_slice(&(3_f32.to_le_bytes()));
        data.extend_from_slice(&(4_f32.to_le_bytes()));

        let mut reader = Cursor::new(data);

        let rounded_dim = 4;
        let mut aligned_data = vec![0f32; 2 * rounded_dim];
        let dataset_dto = DatasetDto::<f32> {
            data: &mut aligned_data,
            rounded_dim,
        };

        let (npts, dim) = copy_aligned_data(&mut reader, dataset_dto, 0)?;

        assert_eq!(npts, 2);
        assert_eq!(dim, 2);

        assert_eq!(aligned_data, vec![1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0]);

        Ok(())
    }
}