Skip to main content

diskann_providers/utils/
storage_utils.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
//! Utilities for reading and writing data from the storage layer with a generic reader/writer.
//! This is a replacement for the functions in file_util.rs, using a generic reader/writer.
8
9use std::{
10    convert::TryInto,
11    io::{BufReader, Read, Seek, SeekFrom, Write},
12    mem,
13};
14
15use bytemuck::Pod;
16use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
17use diskann::{ANNError, ANNErrorKind, ANNResult, utils::read_exact_into};
18use diskann_wide::{LoHi, SplitJoin};
19use thiserror::Error;
20use tracing::info;
21
22use crate::utils::DatasetDto;
23
/// Capacity, in bytes (1 MiB), of the `BufReader` used by `copy_aligned_data`.
const DEFAULT_BUF_SIZE: usize = 1 << 20;
25
/// Header of a binary data file: how many points it holds and how many
/// dimensions each point has.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Metadata {
    /// Number of points.
    pub npoints: usize,
    /// Number of dimensions per point.
    pub ndims: usize,
}
32
33/// Error type for metadata I/O operations
34#[derive(Debug, Error)]
35pub enum MetadataError<T, U> {
36    #[error("num points conversion")]
37    NumPoints(#[source] T),
38    #[error("dim conversion")]
39    Dim(#[source] U),
40    #[error("writing binary results")]
41    Write(#[source] std::io::Error),
42}
43
44impl<T, U> From<MetadataError<T, U>> for ANNError
45where
46    T: std::error::Error + Send + Sync + 'static,
47    U: std::error::Error + Send + Sync + 'static,
48{
49    #[track_caller]
50    fn from(err: MetadataError<T, U>) -> Self {
51        ANNError::new(ANNErrorKind::IOError, err)
52    }
53}
54
55/// Read binary metadata header (number of points and dimension) from a reader.
56///
57/// Reads 8 bytes total:
58/// - First 4 bytes: number of points (u32, little-endian)
59/// - Next 4 bytes: number of dimensions (u32, little-endian)
60///
61/// # Returns
62/// * `Ok(Metadata)` - Metadata containing number of points and dimensions
63/// * `Err(io::Error)` - If reading fails
64pub fn read_metadata<Reader: Read>(reader: &mut Reader) -> std::io::Result<Metadata> {
65    let raw = reader.read_u64::<LittleEndian>()?;
66    let bytes: [u8; 8] = bytemuck::cast(raw);
67    let LoHi {
68        lo: npts_bytes,
69        hi: ndims_bytes,
70    } = bytes.split();
71    let npoints = u32::from_le_bytes(npts_bytes) as usize;
72    let ndims = u32::from_le_bytes(ndims_bytes) as usize;
73    Ok(Metadata { npoints, ndims })
74}
75
76/// Write binary metadata header (number of points and dimension) to a writer.
77///
78/// Writes 8 bytes total:
79/// - First 4 bytes: number of points (u32, little-endian)
80/// - Next 4 bytes: number of dimensions (u32, little-endian)
81///
82/// This unified function accepts both `u32` and `usize` values, handling conversion appropriately:
83/// - `u32` values are written directly (no conversion overhead)
84/// - `usize` values are safely converted using `TryInto<u32>` (returns error on overflow)
85///
86/// # Returns
87/// * `Ok(usize)` - Number of bytes written (always 8)
88/// * `Err(MetadataError)` - If writing fails or conversion fails (usize > u32::MAX)
89pub fn write_metadata<Writer: Write, N, D>(
90    writer: &mut Writer,
91    npts: N,
92    ndims: D,
93) -> Result<usize, MetadataError<N::Error, D::Error>>
94where
95    N: TryInto<u32>,
96    D: TryInto<u32>,
97    N::Error: std::error::Error + 'static,
98    D::Error: std::error::Error + 'static,
99{
100    let npts_u32 = npts.try_into().map_err(MetadataError::NumPoints)?;
101    let ndims_u32 = ndims.try_into().map_err(MetadataError::Dim)?;
102
103    let bytes: [u8; 8] = LoHi::new(npts_u32.to_le_bytes(), ndims_u32.to_le_bytes()).join();
104    writer.write_all(&bytes).map_err(MetadataError::Write)?;
105
106    Ok(2 * std::mem::size_of::<u32>())
107}
108
109/// Load a list of vector ids from the stream.
110pub fn load_vector_ids<Reader: Read>(reader: &mut Reader) -> std::io::Result<(usize, Vec<u32>)> {
111    // The first 4 bytes are the number of vector ids.
112    // The rest of the file are the vector ids in the format of usize.
113    // The vector ids are sorted in ascending order.
114    let mut reader = BufReader::new(reader);
115    let num_ids = reader.read_u32::<LittleEndian>()? as usize;
116
117    let mut ids = Vec::with_capacity(num_ids);
118    for _ in 0..num_ids {
119        let id = reader.read_u32::<LittleEndian>()?;
120        ids.push(id);
121    }
122
123    Ok((num_ids, ids))
124}
125
126/// Copies data from a reader into a dataset with alignment.
127/// This function reads vector data and aligns it within the given dataset.
128///
129/// # Arguments
130/// * `reader` - A mutable reference to a type implementing the `Read` trait, where the data is read from.
131/// * `dataset_dto` - Destination dataset DTO to which the data is copied. It must have the correct rounded dimension.
132/// * `pts_offset` - Offset of points. Data will be loaded after this point in the dataset.
133///
134/// # Returns
135/// * `npts` - Number of points read from the reader.
136/// * `dim` - Point dimension read from the reader.
137#[cfg(target_endian = "little")]
138pub fn copy_aligned_data<T: Default + bytemuck::Pod, Reader: Read>(
139    reader: &mut Reader,
140    dataset_dto: DatasetDto<T>,
141    pts_offset: usize,
142) -> std::io::Result<(usize, usize)> {
143    let mut reader = BufReader::with_capacity(DEFAULT_BUF_SIZE, reader);
144
145    let metadata = read_metadata(&mut reader)?;
146    let (npts, dim) = (metadata.npoints, metadata.ndims);
147    let rounded_dim = dataset_dto.rounded_dim;
148    let offset = pts_offset * rounded_dim;
149
150    for i in 0..npts {
151        let data_slice =
152            &mut dataset_dto.data[offset + i * rounded_dim..offset + i * rounded_dim + dim];
153
154        // Casting Pod type to bytes always succeeds (u8 has alignment of 1)
155        let byte_slice: &mut [u8] = bytemuck::must_cast_slice_mut(data_slice);
156        reader.read_exact(byte_slice)?;
157
158        let remaining = &mut dataset_dto.data
159            [offset + i * rounded_dim + dim..offset + i * rounded_dim + rounded_dim];
160        remaining.fill_with(Default::default);
161    }
162
163    Ok((npts, dim))
164}
165
166/// Load a list of type T data from a stream.
167/// # Arguments
168/// * `reader` - a stream reader.
169/// * `offset` - start offset of the data.
170pub fn load_bin<T: Pod + Default, Reader: Read + Seek>(
171    reader: &mut Reader,
172    offset: usize,
173) -> std::io::Result<(Vec<T>, usize, usize)> {
174    let mut reader = BufReader::new(reader);
175    reader.seek(std::io::SeekFrom::Start(offset as u64))?;
176    let metadata = read_metadata(&mut reader)?;
177    let (npts, dim) = (metadata.npoints, metadata.ndims);
178
179    let size = npts * dim * std::mem::size_of::<T>();
180
181    let buf: Vec<T> = read_exact_into(&mut reader, npts * dim)?;
182    info!(
183        "bin: #pts = {}, #dims = {}, offset = {} size = {}B",
184        npts, dim, offset, size
185    );
186
187    Ok((buf, npts, dim))
188}
189
190/// Save the byte array to storage.
191pub fn save_bytes<Writer: Write + Seek>(
192    writer: &mut Writer,
193    data: &[u8],
194    npts: usize,
195    ndims: usize,
196    offset: usize,
197) -> ANNResult<usize> {
198    writer.seek(std::io::SeekFrom::Start(offset as u64))?;
199    write_metadata(writer, npts, ndims)?;
200    writer.write_all(data)?;
201    writer.flush()?;
202
203    Ok(data.len() + 2 * std::mem::size_of::<u32>())
204}
205
206/// Save vector data to stream with aligned dimension.
207/// # Arguments
208/// * `writer` - A writer to write the data to storage system.
209/// * `data` - information data
210/// * `npts` - number of points
211/// * `ndims` - point dimension
212/// * `aligned_dim` - aligned dimension
213/// * `offset` - data offset in file
214pub fn save_data_in_base_dimensions<T: Default + Copy + bytemuck::Pod, Writer: Write + Seek>(
215    writer: &mut Writer,
216    data: &[T],
217    npts: usize,
218    ndims: usize,
219    aligned_dim: usize,
220    offset: usize,
221) -> ANNResult<usize> {
222    let bytes_written = 2 * std::mem::size_of::<u32>() + npts * ndims * (std::mem::size_of::<T>());
223
224    writer.seek(std::io::SeekFrom::Start(offset as u64))?;
225    write_metadata(writer, npts, ndims)?;
226
227    for i in 0..npts {
228        let start = i * aligned_dim;
229        let end = start + ndims;
230        let vector_slice = &data[start..end];
231        // Casting Pod type to bytes always succeeds (u8 has alignment of 1)
232        let bytes: &[u8] = bytemuck::must_cast_slice(vector_slice);
233        writer.write_all(bytes)?;
234    }
235    writer.flush()?;
236    Ok(bytes_written)
237}
238
239macro_rules! save_bin {
240    ($name:ident, $t:ty, $write_func:ident) => {
241        /// Write data into the storage system.
242        pub fn $name<W: Write + Seek>(
243            writer: &mut W,
244            data: &[$t],
245            num_pts: usize,
246            dims: usize,
247            offset: usize,
248        ) -> ANNResult<usize> {
249            writer.seek(SeekFrom::Start(offset as u64))?;
250            let bytes_written = num_pts * dims * mem::size_of::<$t>() + 2 * mem::size_of::<u32>();
251
252            write_metadata(writer, num_pts, dims)?;
253            info!(
254                "bin: #pts = {}, #dims = {}, size = {}B",
255                num_pts, dims, bytes_written
256            );
257
258            for item in data.iter() {
259                writer.$write_func::<LittleEndian>(*item)?;
260            }
261
262            writer.flush()?;
263
264            info!("Finished writing bin.");
265            Ok(bytes_written)
266        }
267    };
268}
269
270save_bin!(save_bin_f32, f32, write_f32);
271save_bin!(save_bin_u64, u64, write_u64);
272save_bin!(save_bin_u32, u32, write_u32);
273
274#[cfg(test)]
275mod storage_util_test {
276    use crate::storage::{StorageReadProvider, StorageWriteProvider, VirtualStorageProvider};
277    use tempfile::tempfile;
278    use vfs::MemoryFS;
279
280    use super::*;
281    pub const DIM_8: usize = 8;
282
283    #[test]
284    fn read_metadata_test() {
285        let file_name = "/test_read_metadata_test.bin";
286        let data = [200, 0, 0, 0, 128, 0, 0, 0]; // 200 and 128 in little endian bytes (u32)
287        let vfs = MemoryFS::default();
288        let storage_provider = VirtualStorageProvider::new(vfs);
289        {
290            let mut file = storage_provider
291                .create_for_write(file_name)
292                .expect("Could not create file");
293            file.write_all(&data)
294                .expect("Should be able to write sample file");
295        }
296
297        let mut reader = storage_provider.open_reader(file_name).unwrap();
298        match read_metadata(&mut reader) {
299            Ok(metadata) => {
300                assert_eq!(metadata.npoints, 200);
301                assert_eq!(metadata.ndims, 128);
302            }
303            Err(_e) => {}
304        }
305        storage_provider
306            .delete(file_name)
307            .expect("Should be able to delete sample file");
308    }
309
310    #[test]
311    fn read_metadata_i32_compatibility_test() {
312        // Test that read_metadata (u32) can read data written as i32
313        let file_name = "/test_read_metadata_i32_compat.bin";
314        let npts = 200i32;
315        let dims = 128i32;
316        let vfs = MemoryFS::default();
317        let storage_provider = VirtualStorageProvider::new(vfs);
318        {
319            let mut file = storage_provider
320                .create_for_write(file_name)
321                .expect("Could not create file");
322            // Write as i32 (old format)
323            file.write_i32::<LittleEndian>(npts).unwrap();
324            file.write_i32::<LittleEndian>(dims).unwrap();
325        }
326
327        // Read as u32 (new format)
328        let mut reader = storage_provider.open_reader(file_name).unwrap();
329        let metadata = read_metadata(&mut reader).unwrap();
330
331        assert_eq!(metadata.npoints, 200);
332        assert_eq!(metadata.ndims, 128);
333
334        storage_provider
335            .delete(file_name)
336            .expect("Should be able to delete sample file");
337    }
338
339    #[test]
340    fn load_vector_ids_test() {
341        let file_name = "/load_vector_ids_test";
342        let ids = vec![0u32, 1u32, 2u32];
343        let num_ids = ids.len();
344        let vfs = MemoryFS::new();
345        let storage_provider = VirtualStorageProvider::new(vfs);
346        {
347            let mut writer = storage_provider.create_for_write(file_name).unwrap();
348            writer.write_u32::<LittleEndian>(num_ids as u32).unwrap();
349            for item in ids.iter() {
350                writer.write_u32::<LittleEndian>(*item).unwrap();
351            }
352        }
353
354        let load_data =
355            load_vector_ids(&mut storage_provider.open_reader(file_name).unwrap()).unwrap();
356        assert_eq!(load_data, (num_ids, ids));
357        storage_provider
358            .delete(file_name)
359            .expect("Should be able to delete sample file");
360    }
361
362    #[test]
363    fn load_bin_test() {
364        let file_name = "/load_bin_test";
365        let data = vec![0u64, 1u64, 2u64];
366        let num_pts = data.len();
367        let dims = 1;
368        let vfs = MemoryFS::new();
369        let storage_provider = VirtualStorageProvider::new(vfs);
370        let bytes_written = save_bin_u64(
371            &mut storage_provider.create_for_write(file_name).unwrap(),
372            &data,
373            num_pts,
374            dims,
375            0,
376        )
377        .unwrap();
378        assert_eq!(bytes_written, 32);
379
380        let (load_data, load_num_pts, load_dims) =
381            load_bin::<u64, _>(&mut storage_provider.open_reader(file_name).unwrap(), 0).unwrap();
382        assert_eq!(load_num_pts, num_pts);
383        assert_eq!(load_dims, dims);
384        assert_eq!(load_data, data);
385        storage_provider.delete(file_name).unwrap();
386    }
387
388    #[test]
389    fn load_bin_offset_test() {
390        let offset: usize = 32;
391        let file_name = "/load_bin_offset_test";
392        let data = vec![0u64, 1u64, 2u64];
393        let num_pts = data.len();
394        let dims = 1;
395        let vfs = MemoryFS::new();
396        let storage_provider = VirtualStorageProvider::new(vfs);
397        let bytes_written = save_bin_u64(
398            &mut storage_provider.create_for_write(file_name).unwrap(),
399            &data,
400            num_pts,
401            dims,
402            offset,
403        )
404        .unwrap();
405        assert_eq!(bytes_written, 32);
406
407        let (load_data, load_num_pts, load_dims) = load_bin::<u64, _>(
408            &mut storage_provider.open_reader(file_name).unwrap(),
409            offset,
410        )
411        .unwrap();
412        assert_eq!(load_num_pts, num_pts);
413        assert_eq!(load_dims, dims);
414        assert_eq!(load_data, data);
415        storage_provider.delete(file_name).unwrap();
416    }
417
418    #[test]
419    fn save_data_in_base_dimensions_test() {
420        //npoints=2, dim=8
421        let data: [u8; 72] = [
422            2, 0, 0, 0, 8, 0, 0, 0, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00,
423            0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40,
424            0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00,
425            0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41,
426            0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41,
427        ];
428        let num_points = 2;
429        let dim = DIM_8;
430        let data_file = "/save_data_in_base_dimensions_test.data";
431        let vfs = MemoryFS::new();
432        let storage_provider = VirtualStorageProvider::new(vfs);
433        match save_data_in_base_dimensions(
434            &mut storage_provider.create_for_write(data_file).unwrap(),
435            &data,
436            num_points,
437            dim,
438            DIM_8,
439            0,
440        ) {
441            Ok(num) => {
442                assert!(storage_provider.exists(data_file));
443                assert_eq!(
444                    num,
445                    2 * std::mem::size_of::<u32>() + num_points * dim * std::mem::size_of::<u8>()
446                );
447                storage_provider
448                    .delete(data_file)
449                    .expect("Failed to delete file");
450            }
451            Err(e) => {
452                storage_provider
453                    .delete(data_file)
454                    .expect("Failed to delete file");
455                panic!("{}", e)
456            }
457        }
458    }
459
460    #[test]
461    fn save_bin_test() {
462        let data = vec![0u64, 1u64, 2u64];
463        let num_pts = data.len();
464        let dims = 1;
465        let mut file = tempfile().unwrap();
466        let bytes_written = save_bin_u64::<_>(&mut file, &data, num_pts, dims, 0).unwrap();
467        assert_eq!(bytes_written, 32);
468
469        let mut buffer = vec![];
470        file.seek(SeekFrom::Start(0)).unwrap();
471        let metadata = read_metadata(&mut file).unwrap();
472
473        file.read_to_end(&mut buffer).unwrap();
474        let data_read: Vec<u64> = buffer
475            .chunks_exact(8)
476            .map(|b| u64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]))
477            .collect();
478
479        assert_eq!(num_pts, metadata.npoints);
480        assert_eq!(dims, metadata.ndims);
481        assert_eq!(data, data_read);
482    }
483
484    #[test]
485    fn write_metadata_unified_test() {
486        let mut buffer = Vec::new();
487
488        // Test with u32 values (no conversion)
489        let result = write_metadata(&mut buffer, 200u32, 128u32);
490        assert!(result.is_ok());
491        assert_eq!(result.unwrap(), 8);
492
493        // Test with usize values (safe conversion)
494        buffer.clear();
495        let result = write_metadata(&mut buffer, 200usize, 128usize);
496        assert!(result.is_ok());
497        assert_eq!(result.unwrap(), 8);
498
499        // Test mixed types
500        buffer.clear();
501        let result = write_metadata(&mut buffer, 200usize, 128u32);
502        assert!(result.is_ok());
503
504        // Verify the written data
505        let mut cursor = std::io::Cursor::new(&buffer);
506        let metadata = read_metadata(&mut cursor).unwrap();
507        assert_eq!(metadata.npoints, 200);
508        assert_eq!(metadata.ndims, 128);
509    }
510
511    #[test]
512    fn metadata_error_types_test() {
513        // Test NumPoints error
514        let large_value = u32::MAX as usize + 1;
515        let result = write_metadata(&mut Vec::new(), large_value, 128usize);
516        assert!(matches!(result, Err(MetadataError::NumPoints(_))));
517
518        // Test Dim error
519        let result = write_metadata(&mut Vec::new(), 128usize, large_value);
520        assert!(matches!(result, Err(MetadataError::Dim(_))));
521
522        // Test Write error
523        struct FailingWriter;
524        impl std::io::Write for FailingWriter {
525            fn write(&mut self, _: &[u8]) -> std::io::Result<usize> {
526                Err(std::io::Error::new(
527                    std::io::ErrorKind::PermissionDenied,
528                    "fail",
529                ))
530            }
531            fn flush(&mut self) -> std::io::Result<()> {
532                Ok(())
533            }
534        }
535
536        let result = write_metadata(&mut FailingWriter, 200u32, 128u32);
537        assert!(matches!(result, Err(MetadataError::Write(_))));
538    }
539
540    #[test]
541    fn metadata_error_to_ann_error_test() {
542        use diskann::{ANNError, ANNErrorKind};
543
544        // Test MetadataError -> ANNError conversion
545        let large_value = u32::MAX as usize + 1;
546        let result = write_metadata(&mut Vec::new(), large_value, 128usize);
547        let metadata_err = result.unwrap_err();
548        let ann_error: ANNError = metadata_err.into();
549
550        assert_eq!(ann_error.kind(), ANNErrorKind::IOError);
551
552        // Check that the error message contains information about the conversion
553        let error_str = ann_error.to_string();
554        assert!(error_str.contains("num points conversion"));
555    }
556}
557
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    /// Two 2-dimensional points are copied into rows of width 4, with the
    /// unused tail of each row set to the default value.
    #[test]
    fn test_copy_aligned_data() -> std::io::Result<()> {
        // Header (2 points, 2 dimensions) followed by four f32 components.
        let mut data: Vec<u8> = Vec::with_capacity(24);
        for header_field in [2_i32, 2_i32] {
            data.extend_from_slice(&header_field.to_le_bytes());
        }
        for component in [1_f32, 2_f32, 3_f32, 4_f32] {
            data.extend_from_slice(&component.to_le_bytes());
        }
        let mut reader = Cursor::new(data);

        let rounded_dim = 4;
        let mut aligned_data = vec![0f32; 2 * rounded_dim];

        let (npts, dim) = copy_aligned_data(
            &mut reader,
            DatasetDto::<f32> {
                data: &mut aligned_data,
                rounded_dim,
            },
            0,
        )?;

        assert_eq!(npts, 2);
        assert_eq!(dim, 2);
        assert_eq!(aligned_data, vec![1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0]);

        Ok(())
    }
}
592}