Skip to main content

diskann_disk/data_model/
graph_header.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use std::io::Cursor;
7
8use byteorder::{LittleEndian, ReadBytesExt};
9use diskann::{ANNError, ANNResult};
10use thiserror::Error;
11
12use super::{GraphLayoutVersion, GraphMetadata};
13
/// Header of a disk graph index.
///
/// The header is stored in the first sector of the disk index file, or the first segment of
/// the JET stream. It is serialized with [`GraphHeader::to_bytes`] and parsed back with
/// `GraphHeader::try_from(&[u8])`.
pub struct GraphHeader {
    // Graph metadata (dimensions, node length, associated data length, ...).
    metadata: GraphMetadata,

    // Block size; serialized as a little-endian u64.
    block_size: u64,

    // Layout version of the serialized graph; `max_degree` only supports known versions.
    layout_version: GraphLayoutVersion,
}
25
/// Errors produced while deriving values from a [`GraphHeader`].
#[derive(Error, Debug, PartialEq)]
pub enum GraphHeaderError {
    /// A checked arithmetic step in [`GraphHeader::max_degree`] failed, e.g. because the
    /// header's `node_len` is too small to hold the vector data, the associated data and the
    /// neighbor count.
    #[error("Overflow occurred during max_degree calculation.")]
    MaxDegreeOverflow,
    /// [`GraphHeader::max_degree`] does not know the node layout of the carried version.
    #[error("Unsupported graph layout version {0} for max_degree calculation.")]
    MaxDegreeUnsupportedLayoutVersion(GraphLayoutVersion),
}
33
34impl From<GraphHeaderError> for ANNError {
35    #[track_caller]
36    fn from(value: GraphHeaderError) -> Self {
37        ANNError::log_index_error(value)
38    }
39}
40
41impl GraphHeader {
42    /// Update the layout version when the [GraphHeader] layout is modified.
43    pub const CURRENT_LAYOUT_VERSION: GraphLayoutVersion = GraphLayoutVersion::new(1, 0);
44
45    pub fn new(
46        metadata: GraphMetadata,
47        block_size: u64,
48        layout_version: GraphLayoutVersion,
49    ) -> Self {
50        Self {
51            metadata,
52            block_size,
53            layout_version,
54        }
55    }
56    /// Serialize the `GraphHeader` object to a byte vector.
57    /// Layout:
58    /// | GraphMetadata (80 bytes) | BlockSize (8 bytes) | GraphLayoutVersion (8 bytes) |
59    pub fn to_bytes(&self) -> ANNResult<Vec<u8>> {
60        let mut buffer = vec![];
61        buffer.extend_from_slice(self.metadata.to_bytes()?.as_ref());
62        buffer.extend_from_slice(self.block_size.to_le_bytes().as_ref());
63        buffer.extend_from_slice(self.layout_version.to_bytes().as_ref());
64
65        Ok(buffer)
66    }
67
68    /// Get the size of the header after serialization.
69    #[inline]
70    pub fn get_size() -> usize {
71        GraphMetadata::get_size()
72            + std::mem::size_of::<u64>()
73            + std::mem::size_of::<GraphLayoutVersion>()
74    }
75
76    pub fn metadata(&self) -> &GraphMetadata {
77        &self.metadata
78    }
79
80    pub fn block_size(&self) -> u64 {
81        self.block_size
82    }
83
84    pub fn layout_version(&self) -> &GraphLayoutVersion {
85        &self.layout_version
86    }
87
88    /// Returns the maximum degree of the graph
89    ///
90    /// # Type Parameters
91    /// * `DataType` - The type of vector data stored in the graph nodes
92    pub fn max_degree<DataType>(&self) -> Result<usize, GraphHeaderError> {
93        let supported_versions = [GraphLayoutVersion::new(0, 0), GraphLayoutVersion::new(1, 0)];
94
95        if supported_versions.contains(&self.layout_version) {
96            // Calculates max degree based on the node layout:
97            // - Each node contains: vector data + neighbor list + associated data
98            // - Neighbors are stored as u32 indices
99            // - The -1 accounts for the first u32 which stores number of neighbors
100            let vector_len = std::mem::size_of::<DataType>() * self.metadata.dims;
101            let max_degree = (self.metadata.node_len as usize)
102                .checked_sub(vector_len)
103                .and_then(|len| len.checked_sub(self.metadata.associated_data_length))
104                .and_then(|len| len.checked_div(std::mem::size_of::<u32>()))
105                .and_then(|len| len.checked_sub(1));
106
107            match max_degree {
108                Some(degree) => Ok(degree),
109                None => Err(GraphHeaderError::MaxDegreeOverflow),
110            }
111        } else {
112            Err(GraphHeaderError::MaxDegreeUnsupportedLayoutVersion(
113                self.layout_version.clone(),
114            ))
115        }
116    }
117}
118
119impl<'a> TryFrom<&'a [u8]> for GraphHeader {
120    type Error = ANNError;
121    /// Try creating a new `GraphHeader` object from a byte slice. The try_from syntax is used here instead of from because this operation can fail.
122    ///
123    /// Layout:
124    /// | GraphMetadata (80 bytes) | BlockSize (8 bytes) | GraphLayoutVersion (8 bytes) |
125    fn try_from(value: &'a [u8]) -> ANNResult<Self> {
126        if value.len() < Self::get_size() {
127            Err(ANNError::log_parse_slice_error(
128                "&[u8]".to_string(),
129                "GraphHeader".to_string(),
130                "The given bytes are not long enough to create a valid graph header.".to_string(),
131            ))
132        } else {
133            // Parse metadata.
134            let metadata_len = GraphMetadata::get_size();
135            let metadata = GraphMetadata::try_from(&value[0..metadata_len])?;
136
137            // Parse block size.
138            let block_size = Cursor::new(&value[metadata_len..]).read_u64::<LittleEndian>()?;
139
140            // Parse layout version.
141            let layout_version = GraphLayoutVersion::try_from(&value[metadata_len + 8..])?;
142
143            Ok(Self::new(metadata, block_size, layout_version))
144        }
145    }
146}
147
#[cfg(test)]
mod tests {
    use diskann::ANNErrorKind;
    use rstest::rstest;

    use super::*;
    use crate::data_model::{GraphHeader, GraphLayoutVersion, GraphMetadata};

    /// Builds metadata with the fixed values shared by every test.
    ///
    /// Only the fields that influence the behavior under test are parameters.
    fn test_metadata(dims: usize, node_len: u64, data_size: usize) -> GraphMetadata {
        let num_pts = 1000;
        let medoid = 500;
        let num_nodes_per_sector = 4;
        let vamana_frozen_num = 20;
        let vamana_frozen_loc = 50;
        let disk_index_file_size = 1024;

        GraphMetadata::new(
            num_pts,
            dims,
            medoid,
            node_len,
            num_nodes_per_sector,
            vamana_frozen_num,
            vamana_frozen_loc,
            disk_index_file_size,
            data_size,
        )
    }

    #[test]
    fn test_graph_header_to_bytes_and_try_from() {
        let layout_version = GraphLayoutVersion::new(1, 0);
        let block_size = 128;
        let metadata = test_metadata(32, 64, 256);

        let header = GraphHeader::new(metadata.clone(), block_size, layout_version.clone());
        let bytes = header.to_bytes().unwrap();
        assert_eq!(bytes.len(), GraphHeader::get_size());

        let deserialized_header = GraphHeader::try_from(bytes.as_slice()).unwrap();
        let parsed = &deserialized_header.metadata;
        assert_eq!(metadata.num_pts, parsed.num_pts);
        assert_eq!(metadata.dims, parsed.dims);
        assert_eq!(metadata.medoid, parsed.medoid);
        assert_eq!(metadata.node_len, parsed.node_len);
        assert_eq!(metadata.num_nodes_per_block, parsed.num_nodes_per_block);
        assert_eq!(metadata.vamana_frozen_num, parsed.vamana_frozen_num);
        assert_eq!(metadata.vamana_frozen_loc, parsed.vamana_frozen_loc);
        assert_eq!(metadata.disk_index_file_size, parsed.disk_index_file_size);
        assert_eq!(
            metadata.associated_data_length,
            parsed.associated_data_length
        );

        assert_eq!(block_size, deserialized_header.block_size);
        assert_eq!(layout_version, deserialized_header.layout_version);
    }

    #[test]
    fn test_graph_header_try_from_invalid_bytes() {
        let invalid_bytes = vec![1; GraphHeader::get_size() - 1];
        let result = GraphHeader::try_from(&invalid_bytes[..]);
        assert!(result.is_err());
    }

    #[rstest]
    #[case(384, 1008, 0, 59, GraphLayoutVersion::new(0, 0))]
    #[case(384, 1008, 0, 59, GraphLayoutVersion::new(1, 0))]
    #[case(3072, 6384, 0, 59, GraphLayoutVersion::new(0, 0))]
    #[case(3072, 6384, 0, 59, GraphLayoutVersion::new(1, 0))]
    // Current layout version should support max degree calculation
    #[case(384, 1008, 0, 59, GraphHeader::CURRENT_LAYOUT_VERSION)]
    fn test_graph_header_max_degree(
        #[case] dims: usize,
        #[case] node_len: u64,
        #[case] data_size: usize,
        #[case] expected_max_degree: usize,
        #[case] layout_version: GraphLayoutVersion,
    ) {
        let header = GraphHeader::new(test_metadata(dims, node_len, data_size), 128, layout_version);

        // Compare the whole Result so a failure reports the actual error.
        assert_eq!(
            header.max_degree::<diskann_vector::Half>(),
            Ok(expected_max_degree)
        );
    }

    #[rstest]
    #[case(1, 1)]
    #[case(2, 0)]
    fn test_graph_header_max_degree_unsupported_layout_version(
        #[case] major_version: u32,
        #[case] minor_version: u32,
    ) {
        let layout_version = GraphLayoutVersion::new(major_version, minor_version);
        let header = GraphHeader::new(test_metadata(32, 64, 256), 128, layout_version.clone());

        assert_eq!(
            header.max_degree::<diskann_vector::Half>(),
            Err(GraphHeaderError::MaxDegreeUnsupportedLayoutVersion(
                layout_version
            ))
        );
    }

    #[test]
    fn test_graph_header_max_degree_overflow() {
        // 384 `Half` elements (2 bytes each) need 768 bytes, more than the 384-byte node.
        let header = GraphHeader::new(test_metadata(384, 384, 0), 128, GraphLayoutVersion::new(1, 0));

        assert_eq!(
            header.max_degree::<diskann_vector::Half>(),
            Err(GraphHeaderError::MaxDegreeOverflow)
        );
    }

    // test cases for GraphHeaderError conversion to ANNError
    #[test]
    fn test_graph_header_error_conversion() {
        let error = GraphHeaderError::MaxDegreeOverflow;
        let ann_error: ANNError = error.into();
        assert_eq!(ann_error.kind(), ANNErrorKind::IndexError);

        let layout_version = GraphLayoutVersion::new(1, 0);
        let error = GraphHeaderError::MaxDegreeUnsupportedLayoutVersion(layout_version);
        let ann_error: ANNError = error.into();
        assert_eq!(ann_error.kind(), ANNErrorKind::IndexError);
    }
}