Skip to main content

diskann_providers/utils/
storage_utils.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
//! Utilities for reading and writing data from the storage layer through generic readers and
//! writers. This is a replacement for the functions in `file_util.rs`, reworked to operate on
//! generic reader/writer types.
8
9use std::{
10    convert::TryInto,
11    io::{BufReader, Read, Seek, SeekFrom, Write},
12    mem,
13};
14
15use bytemuck::Pod;
16use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt};
17use diskann::{ANNError, ANNErrorKind, ANNResult, utils::read_exact_into};
18use diskann_wide::{LoHi, SplitJoin};
19use thiserror::Error;
20use tracing::info;
21
22use crate::utils::DatasetDto;
23
/// Capacity in bytes (1 MiB) of the `BufReader` that `copy_aligned_data` wraps around its input.
const DEFAULT_BUF_SIZE: usize = 1 << 20;
25
/// Decoded binary-file header: how many points the file holds and the dimension of each point.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Metadata {
    /// Number of points (vectors) announced by the header.
    pub npoints: usize,
    /// Number of dimensions of every point.
    pub ndims: usize,
}
32
33/// Error type for metadata I/O operations
34#[derive(Debug, Error)]
35pub enum MetadataError<T, U> {
36    #[error("num points conversion")]
37    NumPoints(#[source] T),
38    #[error("dim conversion")]
39    Dim(#[source] U),
40    #[error("writing binary results")]
41    Write(#[source] std::io::Error),
42}
43
44impl<T, U> From<MetadataError<T, U>> for ANNError
45where
46    T: std::error::Error + Send + Sync + 'static,
47    U: std::error::Error + Send + Sync + 'static,
48{
49    #[track_caller]
50    fn from(err: MetadataError<T, U>) -> Self {
51        ANNError::new(ANNErrorKind::IOError, err)
52    }
53}
54
55/// Read binary metadata header (number of points and dimension) from a reader.
56///
57/// Reads 8 bytes total:
58/// - First 4 bytes: number of points (u32, little-endian)
59/// - Next 4 bytes: number of dimensions (u32, little-endian)
60///
61/// # Returns
62/// * `Ok(Metadata)` - Metadata containing number of points and dimensions
63/// * `Err(io::Error)` - If reading fails
64pub fn read_metadata<Reader: Read>(reader: &mut Reader) -> std::io::Result<Metadata> {
65    let raw = reader.read_u64::<LittleEndian>()?;
66    let bytes: [u8; 8] = bytemuck::cast(raw);
67    let LoHi {
68        lo: npts_bytes,
69        hi: ndims_bytes,
70    } = bytes.split();
71    let npoints = u32::from_le_bytes(npts_bytes) as usize;
72    let ndims = u32::from_le_bytes(ndims_bytes) as usize;
73    Ok(Metadata { npoints, ndims })
74}
75
76/// Write binary metadata header (number of points and dimension) to a writer.
77///
78/// Writes 8 bytes total:
79/// - First 4 bytes: number of points (u32, little-endian)
80/// - Next 4 bytes: number of dimensions (u32, little-endian)
81///
82/// This unified function accepts both `u32` and `usize` values, handling conversion appropriately:
83/// - `u32` values are written directly (no conversion overhead)
84/// - `usize` values are safely converted using `TryInto<u32>` (returns error on overflow)
85///
86/// # Returns
87/// * `Ok(usize)` - Number of bytes written (always 8)
88/// * `Err(MetadataError)` - If writing fails or conversion fails (usize > u32::MAX)
89pub fn write_metadata<Writer: Write, N, D>(
90    writer: &mut Writer,
91    npts: N,
92    ndims: D,
93) -> Result<usize, MetadataError<N::Error, D::Error>>
94where
95    N: TryInto<u32>,
96    D: TryInto<u32>,
97    N::Error: std::error::Error + 'static,
98    D::Error: std::error::Error + 'static,
99{
100    let npts_u32 = npts.try_into().map_err(MetadataError::NumPoints)?;
101    let ndims_u32 = ndims.try_into().map_err(MetadataError::Dim)?;
102
103    let bytes: [u8; 8] = LoHi::new(npts_u32.to_le_bytes(), ndims_u32.to_le_bytes()).join();
104    writer.write_all(&bytes).map_err(MetadataError::Write)?;
105
106    Ok(2 * std::mem::size_of::<u32>())
107}
108
109/// Load a list of vector ids from the stream.
110pub fn load_vector_ids<Reader: Read>(reader: &mut Reader) -> std::io::Result<(usize, Vec<u32>)> {
111    // The first 4 bytes are the number of vector ids.
112    // The rest of the file are the vector ids in the format of usize.
113    // The vector ids are sorted in ascending order.
114    let mut reader = BufReader::new(reader);
115    let num_ids = reader.read_u32::<LittleEndian>()? as usize;
116
117    let mut ids = Vec::with_capacity(num_ids);
118    for _ in 0..num_ids {
119        let id = reader.read_u32::<LittleEndian>()?;
120        ids.push(id);
121    }
122
123    Ok((num_ids, ids))
124}
125
126/// Copies data from a reader into a dataset with alignment.
127/// This function reads vector data and aligns it within the given dataset.
128///
129/// # Arguments
130/// * `reader` - A mutable reference to a type implementing the `Read` trait, where the data is read from.
131/// * `dataset_dto` - Destination dataset DTO to which the data is copied. It must have the correct rounded dimension.
132/// * `pts_offset` - Offset of points. Data will be loaded after this point in the dataset.
133///
134/// # Returns
135/// * `npts` - Number of points read from the reader.
136/// * `dim` - Point dimension read from the reader.
137#[cfg(target_endian = "little")]
138pub fn copy_aligned_data<T: Default + bytemuck::Pod, Reader: Read>(
139    reader: &mut Reader,
140    dataset_dto: DatasetDto<T>,
141    pts_offset: usize,
142) -> std::io::Result<(usize, usize)> {
143    let mut reader = BufReader::with_capacity(DEFAULT_BUF_SIZE, reader);
144
145    let metadata = read_metadata(&mut reader)?;
146    let (npts, dim) = (metadata.npoints, metadata.ndims);
147    let rounded_dim = dataset_dto.rounded_dim;
148    let offset = pts_offset * rounded_dim;
149
150    for i in 0..npts {
151        let data_slice =
152            &mut dataset_dto.data[offset + i * rounded_dim..offset + i * rounded_dim + dim];
153
154        // Casting Pod type to bytes always succeeds (u8 has alignment of 1)
155        let byte_slice: &mut [u8] = bytemuck::must_cast_slice_mut(data_slice);
156        reader.read_exact(byte_slice)?;
157
158        let remaining = &mut dataset_dto.data
159            [offset + i * rounded_dim + dim..offset + i * rounded_dim + rounded_dim];
160        remaining.fill_with(Default::default);
161    }
162
163    Ok((npts, dim))
164}
165
166/// Load a list of type T data from a stream.
167/// # Arguments
168/// * `reader` - a stream reader.
169/// * `offset` - start offset of the data.
170pub fn load_bin<T: Pod + Default, Reader: Read + Seek>(
171    reader: &mut Reader,
172    offset: usize,
173) -> std::io::Result<(Vec<T>, usize, usize)> {
174    let mut reader = BufReader::new(reader);
175    reader.seek(std::io::SeekFrom::Start(offset as u64))?;
176    let metadata = read_metadata(&mut reader)?;
177    let (npts, dim) = (metadata.npoints, metadata.ndims);
178
179    let size = npts * dim * std::mem::size_of::<T>();
180
181    let buf: Vec<T> = read_exact_into(&mut reader, npts * dim)?;
182    info!(
183        "bin: #pts = {}, #dims = {}, offset = {} size = {}B",
184        npts, dim, offset, size
185    );
186
187    Ok((buf, npts, dim))
188}
189
190/// Save the byte array to storage.
191pub fn save_bytes<Writer: Write + Seek>(
192    writer: &mut Writer,
193    data: &[u8],
194    npts: usize,
195    ndims: usize,
196    offset: usize,
197) -> ANNResult<usize> {
198    writer.seek(std::io::SeekFrom::Start(offset as u64))?;
199    write_metadata(writer, npts, ndims)?;
200    writer.write_all(data)?;
201    writer.flush()?;
202
203    Ok(data.len() + 2 * std::mem::size_of::<u32>())
204}
205
206/// Save vector data to stream with aligned dimension.
207/// # Arguments
208/// * `writer` - A writer to write the data to storage system.
209/// * `data` - information data
210/// * `npts` - number of points
211/// * `ndims` - point dimension
212/// * `aligned_dim` - aligned dimension
213/// * `offset` - data offset in file
214pub fn save_data_in_base_dimensions<T: Default + Copy + bytemuck::Pod, Writer: Write + Seek>(
215    writer: &mut Writer,
216    data: &[T],
217    npts: usize,
218    ndims: usize,
219    aligned_dim: usize,
220    offset: usize,
221) -> ANNResult<usize> {
222    let bytes_written = 2 * std::mem::size_of::<u32>() + npts * ndims * (std::mem::size_of::<T>());
223
224    writer.seek(std::io::SeekFrom::Start(offset as u64))?;
225    write_metadata(writer, npts, ndims)?;
226
227    for i in 0..npts {
228        let start = i * aligned_dim;
229        let end = start + ndims;
230        let vector_slice = &data[start..end];
231        // Casting Pod type to bytes always succeeds (u8 has alignment of 1)
232        let bytes: &[u8] = bytemuck::must_cast_slice(vector_slice);
233        writer.write_all(bytes)?;
234    }
235    writer.flush()?;
236    Ok(bytes_written)
237}
238
239macro_rules! save_bin {
240    ($name:ident, $t:ty, $write_func:ident) => {
241        /// Write data into the storage system.
242        pub fn $name<W: Write + Seek>(
243            writer: &mut W,
244            data: &[$t],
245            num_pts: usize,
246            dims: usize,
247            offset: usize,
248        ) -> ANNResult<usize> {
249            writer.seek(SeekFrom::Start(offset as u64))?;
250            let bytes_written = num_pts * dims * mem::size_of::<$t>() + 2 * mem::size_of::<u32>();
251
252            write_metadata(writer, num_pts, dims)?;
253            info!(
254                "bin: #pts = {}, #dims = {}, size = {}B",
255                num_pts, dims, bytes_written
256            );
257
258            for item in data.iter() {
259                writer.$write_func::<LittleEndian>(*item)?;
260            }
261
262            writer.flush()?;
263
264            info!("Finished writing bin.");
265            Ok(bytes_written)
266        }
267    };
268}
269
270save_bin!(save_bin_f32, f32, write_f32);
271save_bin!(save_bin_u64, u64, write_u64);
272save_bin!(save_bin_u32, u32, write_u32);
273
#[cfg(test)]
mod storage_util_test {
    use crate::storage::{StorageReadProvider, StorageWriteProvider, VirtualStorageProvider};
    use tempfile::tempfile;

    use super::*;
    pub const DIM_8: usize = 8;

    #[test]
    fn read_metadata_test() {
        let file_name = "/test_read_metadata_test.bin";
        let data = [200, 0, 0, 0, 128, 0, 0, 0]; // 200 and 128 in little endian bytes (u32)
        let storage_provider = VirtualStorageProvider::new_memory();
        {
            let mut file = storage_provider
                .create_for_write(file_name)
                .expect("Could not create file");
            file.write_all(&data)
                .expect("Should be able to write sample file");
        }

        // A read failure must fail the test rather than being silently ignored.
        let mut reader = storage_provider.open_reader(file_name).unwrap();
        let metadata = read_metadata(&mut reader).expect("Should be able to read metadata");
        assert_eq!(metadata.npoints, 200);
        assert_eq!(metadata.ndims, 128);

        storage_provider
            .delete(file_name)
            .expect("Should be able to delete sample file");
    }

    #[test]
    fn read_metadata_truncated_input_test() {
        // Fewer than the 8 header bytes must be reported as an error, not decoded.
        let mut cursor = std::io::Cursor::new(vec![200u8, 0, 0, 0, 128]);
        assert!(read_metadata(&mut cursor).is_err());
    }

    #[test]
    fn read_metadata_i32_compatibility_test() {
        // Test that read_metadata (u32) can read data written as i32
        let file_name = "/test_read_metadata_i32_compat.bin";
        let npts = 200i32;
        let dims = 128i32;
        let storage_provider = VirtualStorageProvider::new_memory();
        {
            let mut file = storage_provider
                .create_for_write(file_name)
                .expect("Could not create file");
            // Write as i32 (old format)
            file.write_i32::<LittleEndian>(npts).unwrap();
            file.write_i32::<LittleEndian>(dims).unwrap();
        }

        // Read as u32 (new format)
        let mut reader = storage_provider.open_reader(file_name).unwrap();
        let metadata = read_metadata(&mut reader).unwrap();

        assert_eq!(metadata.npoints, 200);
        assert_eq!(metadata.ndims, 128);

        storage_provider
            .delete(file_name)
            .expect("Should be able to delete sample file");
    }

    #[test]
    fn load_vector_ids_test() {
        let file_name = "/load_vector_ids_test";
        let ids = vec![0u32, 1u32, 2u32];
        let num_ids = ids.len();
        let storage_provider = VirtualStorageProvider::new_memory();
        {
            let mut writer = storage_provider.create_for_write(file_name).unwrap();
            writer.write_u32::<LittleEndian>(num_ids as u32).unwrap();
            for item in ids.iter() {
                writer.write_u32::<LittleEndian>(*item).unwrap();
            }
        }

        let load_data =
            load_vector_ids(&mut storage_provider.open_reader(file_name).unwrap()).unwrap();
        assert_eq!(load_data, (num_ids, ids));
        storage_provider
            .delete(file_name)
            .expect("Should be able to delete sample file");
    }

    #[test]
    fn load_bin_test() {
        let file_name = "/load_bin_test";
        let data = vec![0u64, 1u64, 2u64];
        let num_pts = data.len();
        let dims = 1;
        let storage_provider = VirtualStorageProvider::new_memory();
        let bytes_written = save_bin_u64(
            &mut storage_provider.create_for_write(file_name).unwrap(),
            &data,
            num_pts,
            dims,
            0,
        )
        .unwrap();
        assert_eq!(bytes_written, 32);

        let (load_data, load_num_pts, load_dims) =
            load_bin::<u64, _>(&mut storage_provider.open_reader(file_name).unwrap(), 0).unwrap();
        assert_eq!(load_num_pts, num_pts);
        assert_eq!(load_dims, dims);
        assert_eq!(load_data, data);
        storage_provider.delete(file_name).unwrap();
    }

    #[test]
    fn load_bin_offset_test() {
        let offset: usize = 32;
        let file_name = "/load_bin_offset_test";
        let data = vec![0u64, 1u64, 2u64];
        let num_pts = data.len();
        let dims = 1;
        let storage_provider = VirtualStorageProvider::new_memory();
        let bytes_written = save_bin_u64(
            &mut storage_provider.create_for_write(file_name).unwrap(),
            &data,
            num_pts,
            dims,
            offset,
        )
        .unwrap();
        assert_eq!(bytes_written, 32);

        let (load_data, load_num_pts, load_dims) = load_bin::<u64, _>(
            &mut storage_provider.open_reader(file_name).unwrap(),
            offset,
        )
        .unwrap();
        assert_eq!(load_num_pts, num_pts);
        assert_eq!(load_dims, dims);
        assert_eq!(load_data, data);
        storage_provider.delete(file_name).unwrap();
    }

    #[test]
    fn save_data_in_base_dimensions_test() {
        //npoints=2, dim=8
        let data: [u8; 72] = [
            2, 0, 0, 0, 8, 0, 0, 0, 0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x40, 0x00, 0x00,
            0x40, 0x40, 0x00, 0x00, 0x80, 0x40, 0x00, 0x00, 0xa0, 0x40, 0x00, 0x00, 0xc0, 0x40,
            0x00, 0x00, 0xe0, 0x40, 0x00, 0x00, 0x00, 0x41, 0x00, 0x00, 0x10, 0x41, 0x00, 0x00,
            0x20, 0x41, 0x00, 0x00, 0x30, 0x41, 0x00, 0x00, 0x40, 0x41, 0x00, 0x00, 0x50, 0x41,
            0x00, 0x00, 0x60, 0x41, 0x00, 0x00, 0x70, 0x41, 0x00, 0x00, 0x80, 0x41,
        ];
        let num_points = 2;
        let dim = DIM_8;
        let data_file = "/save_data_in_base_dimensions_test.data";
        let storage_provider = VirtualStorageProvider::new_memory();
        match save_data_in_base_dimensions(
            &mut storage_provider.create_for_write(data_file).unwrap(),
            &data,
            num_points,
            dim,
            DIM_8,
            0,
        ) {
            Ok(num) => {
                assert!(storage_provider.exists(data_file));
                assert_eq!(
                    num,
                    2 * std::mem::size_of::<u32>() + num_points * dim * std::mem::size_of::<u8>()
                );
                storage_provider
                    .delete(data_file)
                    .expect("Failed to delete file");
            }
            Err(e) => {
                storage_provider
                    .delete(data_file)
                    .expect("Failed to delete file");
                panic!("{}", e)
            }
        }
    }

    #[test]
    fn save_bin_test() {
        let data = vec![0u64, 1u64, 2u64];
        let num_pts = data.len();
        let dims = 1;
        let mut file = tempfile().unwrap();
        let bytes_written = save_bin_u64::<_>(&mut file, &data, num_pts, dims, 0).unwrap();
        assert_eq!(bytes_written, 32);

        let mut buffer = vec![];
        file.seek(SeekFrom::Start(0)).unwrap();
        let metadata = read_metadata(&mut file).unwrap();

        file.read_to_end(&mut buffer).unwrap();
        let data_read: Vec<u64> = buffer
            .chunks_exact(8)
            .map(|b| u64::from_le_bytes([b[0], b[1], b[2], b[3], b[4], b[5], b[6], b[7]]))
            .collect();

        assert_eq!(num_pts, metadata.npoints);
        assert_eq!(dims, metadata.ndims);
        assert_eq!(data, data_read);
    }

    #[test]
    fn write_metadata_unified_test() {
        let mut buffer = Vec::new();

        // Test with u32 values (no conversion)
        let result = write_metadata(&mut buffer, 200u32, 128u32);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 8);

        // Test with usize values (safe conversion)
        buffer.clear();
        let result = write_metadata(&mut buffer, 200usize, 128usize);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), 8);

        // Test mixed types
        buffer.clear();
        let result = write_metadata(&mut buffer, 200usize, 128u32);
        assert!(result.is_ok());

        // Verify the written data
        let mut cursor = std::io::Cursor::new(&buffer);
        let metadata = read_metadata(&mut cursor).unwrap();
        assert_eq!(metadata.npoints, 200);
        assert_eq!(metadata.ndims, 128);
    }

    #[test]
    fn metadata_error_types_test() {
        // Test NumPoints error
        let large_value = u32::MAX as usize + 1;
        let result = write_metadata(&mut Vec::new(), large_value, 128usize);
        assert!(matches!(result, Err(MetadataError::NumPoints(_))));

        // Test Dim error
        let result = write_metadata(&mut Vec::new(), 128usize, large_value);
        assert!(matches!(result, Err(MetadataError::Dim(_))));

        // Test Write error
        struct FailingWriter;
        impl std::io::Write for FailingWriter {
            fn write(&mut self, _: &[u8]) -> std::io::Result<usize> {
                Err(std::io::Error::new(
                    std::io::ErrorKind::PermissionDenied,
                    "fail",
                ))
            }
            fn flush(&mut self) -> std::io::Result<()> {
                Ok(())
            }
        }

        let result = write_metadata(&mut FailingWriter, 200u32, 128u32);
        assert!(matches!(result, Err(MetadataError::Write(_))));
    }

    #[test]
    fn metadata_error_to_ann_error_test() {
        use diskann::{ANNError, ANNErrorKind};

        // Test MetadataError -> ANNError conversion
        let large_value = u32::MAX as usize + 1;
        let result = write_metadata(&mut Vec::new(), large_value, 128usize);
        let metadata_err = result.unwrap_err();
        let ann_error: ANNError = metadata_err.into();

        assert_eq!(ann_error.kind(), ANNErrorKind::IOError);

        // Check that the error message contains information about the conversion
        let error_str = ann_error.to_string();
        assert!(error_str.contains("num points conversion"));
    }
}
550
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use super::*;

    /// `copy_aligned_data` must place each point at the start of its `rounded_dim`-wide row and
    /// default-fill (zero) the padding after it.
    #[test]
    fn test_copy_aligned_data() -> std::io::Result<()> {
        // Header: 2 points of 2 dimensions, followed by the four f32 values.
        let header = [2_i32, 2_i32].iter().flat_map(|v| v.to_le_bytes());
        let values = [1_f32, 2_f32, 3_f32, 4_f32]
            .iter()
            .flat_map(|v| v.to_le_bytes());
        let bytes: Vec<u8> = header.chain(values).collect();
        let mut reader = Cursor::new(bytes);

        const ROUNDED_DIM: usize = 4;
        let mut aligned_data = vec![0f32; 2 * ROUNDED_DIM];
        let destination = DatasetDto::<f32> {
            data: &mut aligned_data,
            rounded_dim: ROUNDED_DIM,
        };

        let (npts, dim) = copy_aligned_data(&mut reader, destination, 0)?;

        assert_eq!((npts, dim), (2, 2));
        assert_eq!(aligned_data, vec![1.0, 2.0, 0.0, 0.0, 3.0, 4.0, 0.0, 0.0]);
        Ok(())
    }
}