Skip to main content

diskann_utils/
io.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6//! Read and write vectors in the DiskANN binary format.
7//!
8//! The binary format is:
9//! - 8-byte header
10//!   - `npoints` (u32 LE)
11//!   - `ndims` (u32 LE)
12//! - Payload: `npoints × ndims` elements of `T`, tightly packed in row-major order
13
14use std::io::{Read, Seek, Write};
15
16use diskann_wide::{LoHi, SplitJoin};
17use thiserror::Error;
18
19use crate::views::{Matrix, MatrixView};
20
21/// Read a matrix of `T` from the DiskANN binary format (see [module docs](self)).
22///
23/// Validates that the reader contains enough data before allocating.
24pub fn read_bin<T>(reader: &mut (impl Read + Seek)) -> Result<Matrix<T>, ReadBinError>
25where
26    T: bytemuck::Pod,
27{
28    let metadata = Metadata::read(reader)?;
29    let (npoints, ndims) = (metadata.npoints(), metadata.ndims());
30    let type_size = std::mem::size_of::<T>();
31
32    let expected_bytes = npoints
33        .checked_mul(ndims)
34        .and_then(|n| n.checked_mul(type_size))
35        .ok_or(ReadBinError::Overflow {
36            npoints: metadata.npoints_u32(),
37            ndims: metadata.ndims_u32(),
38            type_size,
39        })?;
40
41    let data_start = reader.stream_position()?;
42    let end = reader.seek(std::io::SeekFrom::End(0))?;
43    let available = end - data_start;
44    reader.seek(std::io::SeekFrom::Start(data_start))?;
45
46    if available < expected_bytes as u64 {
47        return Err(ReadBinError::SizeMismatch {
48            expected: expected_bytes as u64,
49            available,
50            npoints: metadata.npoints_u32(),
51            ndims: metadata.ndims_u32(),
52            type_size,
53        });
54    }
55
56    let mut data = Matrix::new(<T as bytemuck::Zeroable>::zeroed(), npoints, ndims);
57
58    reader.read_exact(bytemuck::must_cast_slice_mut::<T, u8>(data.as_mut_slice()))?;
59    Ok(data)
60}
61
62/// Write a matrix of `T` in the DiskANN binary format (see [module docs](self)).
63///
64/// Returns the total number of bytes written.
65pub fn write_bin<T>(data: MatrixView<'_, T>, writer: &mut impl Write) -> Result<usize, SaveBinError>
66where
67    T: bytemuck::Pod,
68{
69    let metadata =
70        Metadata::new(data.nrows(), data.ncols()).map_err(|_| SaveBinError::DimensionOverflow {
71            nrows: data.nrows(),
72            ncols: data.ncols(),
73        })?;
74    let bytes = metadata.write(writer)?;
75    writer.write_all(bytemuck::must_cast_slice::<T, u8>(data.as_slice()))?;
76    Ok(bytes + std::mem::size_of_val(data.as_slice()))
77}
78
/// Header found at the start of a DiskANN binary file.
///
/// It occupies 8 bytes: `npoints` followed by `ndims`, each stored as a little-endian `u32`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Metadata {
    /// Number of points stored in the file.
    npoints: u32,
    /// Number of dimensions of each point.
    ndims: u32,
}
85
86impl Metadata {
87    /// Construct from any integer types that fit in `u32`.
88    pub fn new<T, U>(npoints: T, ndims: U) -> Result<Self, MetadataError<T::Error, U::Error>>
89    where
90        T: TryInto<u32>,
91        U: TryInto<u32>,
92    {
93        Ok(Self {
94            npoints: npoints.try_into().map_err(MetadataError::NumPoints)?,
95            ndims: ndims.try_into().map_err(MetadataError::Dim)?,
96        })
97    }
98
99    /// Number of points as `usize`.
100    pub fn npoints(&self) -> usize {
101        self.npoints as usize
102    }
103
104    /// Number of points as `u32`.
105    pub fn npoints_u32(&self) -> u32 {
106        self.npoints
107    }
108
109    /// Number of dimensions as `usize`.
110    pub fn ndims(&self) -> usize {
111        self.ndims as usize
112    }
113
114    /// Number of dimensions as `u32`.
115    pub fn ndims_u32(&self) -> u32 {
116        self.ndims
117    }
118
119    /// Destructure into (`npoints`, `ndims`) as `usize`.
120    pub fn into_dims(&self) -> (usize, usize) {
121        (self.npoints(), self.ndims())
122    }
123
124    /// Deserialize the 8-byte header from a reader.
125    pub fn read<R>(reader: &mut R) -> std::io::Result<Self>
126    where
127        R: Read,
128    {
129        let mut bytes = [0u8; 8];
130        reader.read_exact(&mut bytes)?;
131
132        let LoHi {
133            lo: npts_bytes,
134            hi: ndims_bytes,
135        } = bytes.split();
136
137        let npoints = u32::from_le_bytes(npts_bytes);
138        let ndims = u32::from_le_bytes(ndims_bytes);
139        Ok(Metadata { npoints, ndims })
140    }
141
142    /// Serialize the 8-byte header to a writer. Returns the number of bytes written (always 8).
143    pub fn write<W>(&self, writer: &mut W) -> std::io::Result<usize>
144    where
145        W: Write,
146    {
147        let bytes: [u8; 8] = LoHi::new(self.npoints.to_le_bytes(), self.ndims.to_le_bytes()).join();
148        writer.write_all(&bytes)?;
149        Ok(2 * std::mem::size_of::<u32>())
150    }
151}
152
153#[derive(Debug, Error)]
154pub enum MetadataError<T, U> {
155    #[error("num points conversion")]
156    NumPoints(#[source] T),
157    #[error("dim conversion")]
158    Dim(#[source] U),
159}
160
161/// Error type for [`read_bin`].
162#[derive(Debug, Error)]
163pub enum ReadBinError {
164    /// The reader has fewer bytes remaining than the header declares.
165    #[error(
166        "binary data too short: header declares {npoints} points × {ndims} dims × {type_size} bytes = \
167         {expected} bytes, but only {available} bytes available"
168    )]
169    SizeMismatch {
170        expected: u64,
171        available: u64,
172        npoints: u32,
173        ndims: u32,
174        type_size: usize,
175    },
176
177    /// `npoints * ndims` overflows `usize` (corrupt or malicious header).
178    #[error(
179        "header dimensions overflow: {npoints} points × {ndims} dims × {type_size} bytes overflows"
180    )]
181    Overflow {
182        npoints: u32,
183        ndims: u32,
184        type_size: usize,
185    },
186
187    /// Underlying IO failure.
188    #[error(transparent)]
189    Io(#[from] std::io::Error),
190}
191
192/// Error type for [`write_bin`].
193#[derive(Debug, Error)]
194pub enum SaveBinError {
195    /// Matrix dimensions exceed `u32::MAX` and cannot be represented in the binary header.
196    #[error("dimensions overflow u32: {nrows} rows × {ncols} cols")]
197    DimensionOverflow { nrows: usize, ncols: usize },
198
199    /// Underlying IO failure.
200    #[error(transparent)]
201    Io(#[from] std::io::Error),
202}
203
204///////////
205// Tests //
206///////////
207
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use crate::views::Init;

    use super::*;

    /// Serialize only the 8-byte header for the given dimensions.
    fn header_bytes(npoints: u32, ndims: u32) -> Vec<u8> {
        let mut buf = Vec::new();
        Metadata::new(npoints, ndims)
            .unwrap()
            .write(&mut buf)
            .unwrap();
        buf
    }

    /// A 3 × 4 `f32` matrix must survive a write/read cycle unchanged.
    #[test]
    fn round_trip_f32() {
        // Fill the matrix with 1.0, 2.0, ..., 12.0 in element order.
        let mut next = 1.0f32;
        let matrix = Matrix::<f32>::new(
            Init(|| {
                let current = next;
                next += 1.0;
                current
            }),
            3,
            4,
        );

        let expected: Vec<f32> = (1..=12u8).map(f32::from).collect();
        assert_eq!(matrix.as_slice(), expected.as_slice());

        let mut buf = Vec::new();
        let written = write_bin(matrix.as_view(), &mut buf).unwrap();
        assert_eq!(written, 8 + 3 * 4 * 4);

        let mut cursor = Cursor::new(&buf);
        let loaded = read_bin::<f32>(&mut cursor).unwrap();
        assert_eq!(loaded.nrows(), 3);
        assert_eq!(loaded.ncols(), 4);
        assert_eq!(loaded.as_slice(), matrix.as_slice());
    }

    /// A payload shorter than the header declares is reported as `SizeMismatch`.
    #[test]
    fn read_bin_size_mismatch() {
        // Header says 10 points × 4 dims of f32, but only 8 bytes of payload follow.
        let mut buf = header_bytes(10, 4);
        buf.extend_from_slice(&[0u8; 8]);

        let mut cursor = Cursor::new(&buf);
        let err = read_bin::<f32>(&mut cursor).unwrap_err();

        match err {
            ReadBinError::SizeMismatch {
                expected,
                available,
                npoints,
                ndims,
                type_size,
            } => {
                assert_eq!(expected, 10 * 4 * 4);
                assert_eq!(available, 8);
                assert_eq!(npoints, 10);
                assert_eq!(ndims, 4);
                assert_eq!(type_size, 4);
            }
            other => panic!("expected SizeMismatch, got: {other}"),
        }
    }

    /// Header values whose product overflows `usize` are reported as `Overflow`.
    #[test]
    fn read_bin_overflow() {
        // Both header fields are `u32::MAX`, so the byte count overflows.
        let buf: Vec<u8> = [u32::MAX, u32::MAX]
            .iter()
            .flat_map(|field| field.to_le_bytes())
            .collect();

        let mut cursor = Cursor::new(&buf);
        let err = read_bin::<f32>(&mut cursor).unwrap_err();

        match err {
            ReadBinError::Overflow {
                npoints,
                ndims,
                type_size,
            } => {
                assert_eq!(npoints, u32::MAX);
                assert_eq!(ndims, u32::MAX);
                assert_eq!(type_size, 4);
            }
            other => panic!("expected Overflow, got: {other}"),
        }
    }

    /// The `SizeMismatch` message spells out every quantity involved.
    #[test]
    fn read_bin_error_message_is_informative() {
        // Header only, no payload at all.
        let buf = header_bytes(100, 32);

        let mut cursor = Cursor::new(&buf);
        let msg = read_bin::<f32>(&mut cursor).unwrap_err().to_string();

        assert!(msg.contains("100 points"), "missing npoints: {msg}");
        assert!(msg.contains("32 dims"), "missing ndims: {msg}");
        assert!(msg.contains("12800 bytes"), "missing expected: {msg}");
        assert!(
            msg.contains("0 bytes available"),
            "missing available: {msg}"
        );
    }

    /// A header written with `Metadata::write` reads back as an equal `Metadata`.
    #[test]
    fn metadata_read_write_round_trip() {
        let original = Metadata::new(200u32, 128u32).unwrap();
        let mut buf = Vec::new();
        original.write(&mut buf).unwrap();

        let mut cursor = Cursor::new(&buf);
        let loaded = Metadata::read(&mut cursor).unwrap();
        assert_eq!(loaded, original);
    }
}
328}