1use std::io::Cursor;
7
8use byteorder::{LittleEndian, ReadBytesExt};
9use diskann::{ANNError, ANNResult};
10use thiserror::Error;
11
12use super::{GraphLayoutVersion, GraphMetadata};
13
14pub struct GraphHeader {
16 metadata: GraphMetadata,
18
19 block_size: u64,
21
22 layout_version: GraphLayoutVersion,
24}
25
26#[derive(Error, Debug, PartialEq)]
27pub enum GraphHeaderError {
28 #[error("Overflow occurred during max_degree calculation.")]
29 MaxDegreeOverflow,
30 #[error("Unsupported graph layout version {0} for max_degree calculation.")]
31 MaxDegreeUnsupportedLayoutVersion(GraphLayoutVersion),
32}
33
34impl From<GraphHeaderError> for ANNError {
35 #[track_caller]
36 fn from(value: GraphHeaderError) -> Self {
37 ANNError::log_index_error(value)
38 }
39}
40
41impl GraphHeader {
42 pub const CURRENT_LAYOUT_VERSION: GraphLayoutVersion = GraphLayoutVersion::new(1, 0);
44
45 pub fn new(
46 metadata: GraphMetadata,
47 block_size: u64,
48 layout_version: GraphLayoutVersion,
49 ) -> Self {
50 Self {
51 metadata,
52 block_size,
53 layout_version,
54 }
55 }
56 pub fn to_bytes(&self) -> ANNResult<Vec<u8>> {
60 let mut buffer = vec![];
61 buffer.extend_from_slice(self.metadata.to_bytes()?.as_ref());
62 buffer.extend_from_slice(self.block_size.to_le_bytes().as_ref());
63 buffer.extend_from_slice(self.layout_version.to_bytes().as_ref());
64
65 Ok(buffer)
66 }
67
68 #[inline]
70 pub fn get_size() -> usize {
71 GraphMetadata::get_size()
72 + std::mem::size_of::<u64>()
73 + std::mem::size_of::<GraphLayoutVersion>()
74 }
75
76 pub fn metadata(&self) -> &GraphMetadata {
77 &self.metadata
78 }
79
80 pub fn block_size(&self) -> u64 {
81 self.block_size
82 }
83
84 pub fn layout_version(&self) -> &GraphLayoutVersion {
85 &self.layout_version
86 }
87
88 pub fn max_degree<DataType>(&self) -> Result<usize, GraphHeaderError> {
93 let supported_versions = [GraphLayoutVersion::new(0, 0), GraphLayoutVersion::new(1, 0)];
94
95 if supported_versions.contains(&self.layout_version) {
96 let vector_len = std::mem::size_of::<DataType>() * self.metadata.dims;
101 let max_degree = (self.metadata.node_len as usize)
102 .checked_sub(vector_len)
103 .and_then(|len| len.checked_sub(self.metadata.associated_data_length))
104 .and_then(|len| len.checked_div(std::mem::size_of::<u32>()))
105 .and_then(|len| len.checked_sub(1));
106
107 match max_degree {
108 Some(degree) => Ok(degree),
109 None => Err(GraphHeaderError::MaxDegreeOverflow),
110 }
111 } else {
112 Err(GraphHeaderError::MaxDegreeUnsupportedLayoutVersion(
113 self.layout_version.clone(),
114 ))
115 }
116 }
117}
118
119impl<'a> TryFrom<&'a [u8]> for GraphHeader {
120 type Error = ANNError;
121 fn try_from(value: &'a [u8]) -> ANNResult<Self> {
126 if value.len() < Self::get_size() {
127 Err(ANNError::log_parse_slice_error(
128 "&[u8]".to_string(),
129 "GraphHeader".to_string(),
130 "The given bytes are not long enough to create a valid graph header.".to_string(),
131 ))
132 } else {
133 let metadata_len = GraphMetadata::get_size();
135 let metadata = GraphMetadata::try_from(&value[0..metadata_len])?;
136
137 let block_size = Cursor::new(&value[metadata_len..]).read_u64::<LittleEndian>()?;
139
140 let layout_version = GraphLayoutVersion::try_from(&value[metadata_len + 8..])?;
142
143 Ok(Self::new(metadata, block_size, layout_version))
144 }
145 }
146}
147
#[cfg(test)]
mod tests {
    use diskann::ANNErrorKind;
    use rstest::rstest;

    use super::*;
    use crate::data_model::{GraphHeader, GraphLayoutVersion, GraphMetadata};

    /// Block size used by every header built in these tests.
    const BLOCK_SIZE: u64 = 128;

    /// Builds a `GraphMetadata` for tests.
    ///
    /// The caller picks the layout-relevant fields: `dims`, `node_len` and the associated
    /// data length. All other fields share the same fixed values across tests.
    fn make_metadata(dims: usize, node_len: u64, data_size: usize) -> GraphMetadata {
        let num_pts = 1000;
        let medoid = 500;
        let num_nodes_per_sector = 4;
        let vamana_frozen_num = 20;
        let vamana_frozen_loc = 50;
        let disk_index_file_size = 1024;

        GraphMetadata::new(
            num_pts,
            dims,
            medoid,
            node_len,
            num_nodes_per_sector,
            vamana_frozen_num,
            vamana_frozen_loc,
            disk_index_file_size,
            data_size,
        )
    }

    #[test]
    fn test_graph_header_to_bytes_and_try_from() {
        let layout_version = GraphLayoutVersion::new(1, 0);
        let metadata = make_metadata(32, 64, 256);

        let header = GraphHeader::new(metadata.clone(), BLOCK_SIZE, layout_version.clone());
        let bytes = header.to_bytes().unwrap();
        assert_eq!(bytes.len(), GraphHeader::get_size());

        let deserialized = GraphHeader::try_from(bytes.as_slice()).unwrap();
        let restored = deserialized.metadata();
        assert_eq!(metadata.num_pts, restored.num_pts);
        assert_eq!(metadata.dims, restored.dims);
        assert_eq!(metadata.medoid, restored.medoid);
        assert_eq!(metadata.node_len, restored.node_len);
        assert_eq!(metadata.num_nodes_per_block, restored.num_nodes_per_block);
        assert_eq!(metadata.vamana_frozen_num, restored.vamana_frozen_num);
        assert_eq!(metadata.vamana_frozen_loc, restored.vamana_frozen_loc);
        assert_eq!(metadata.disk_index_file_size, restored.disk_index_file_size);
        assert_eq!(
            metadata.associated_data_length,
            restored.associated_data_length
        );

        assert_eq!(BLOCK_SIZE, deserialized.block_size());
        assert_eq!(&layout_version, deserialized.layout_version());
    }

    #[test]
    fn test_graph_header_try_from_invalid_bytes() {
        let too_short = vec![1u8; GraphHeader::get_size() - 1];
        assert!(GraphHeader::try_from(too_short.as_slice()).is_err());
    }

    // With `Half` (2 bytes) and 384 dims, the vector takes 768 bytes. Of a 1008-byte node,
    // that leaves 240 bytes minus the associated data, counted in u32 slots minus one.
    #[rstest]
    #[case(384, 1008, 0, 59, GraphLayoutVersion::new(0, 0))]
    #[case(384, 1008, 0, 59, GraphLayoutVersion::new(1, 0))]
    #[case(3072, 6384, 0, 59, GraphLayoutVersion::new(0, 0))]
    #[case(3072, 6384, 0, 59, GraphLayoutVersion::new(1, 0))]
    #[case(384, 1008, 0, 59, GraphHeader::CURRENT_LAYOUT_VERSION)]
    // Associated data shrinks the neighbor space: (240 - 8) / 4 - 1 = 57.
    #[case(384, 1008, 8, 57, GraphLayoutVersion::new(1, 0))]
    // A partial trailing u32 slot is discarded: (240 - 2) / 4 - 1 = 58.
    #[case(384, 1008, 2, 58, GraphLayoutVersion::new(1, 0))]
    fn test_graph_header_max_degree(
        #[case] dims: usize,
        #[case] node_len: u64,
        #[case] data_size: usize,
        #[case] expected_max_degree: usize,
        #[case] layout_version: GraphLayoutVersion,
    ) {
        let header = GraphHeader::new(
            make_metadata(dims, node_len, data_size),
            BLOCK_SIZE,
            layout_version,
        );

        assert_eq!(
            header.max_degree::<diskann_vector::Half>(),
            Ok(expected_max_degree)
        );
    }

    #[rstest]
    #[case(1, 1)]
    #[case(2, 0)]
    fn test_graph_header_max_degree_unsupported_layout_version(
        #[case] major_version: u32,
        #[case] minor_version: u32,
    ) {
        let layout_version = GraphLayoutVersion::new(major_version, minor_version);
        let header = GraphHeader::new(
            make_metadata(32, 64, 256),
            BLOCK_SIZE,
            layout_version.clone(),
        );

        assert_eq!(
            header.max_degree::<diskann_vector::Half>(),
            Err(GraphHeaderError::MaxDegreeUnsupportedLayoutVersion(
                layout_version
            ))
        );
    }

    #[test]
    fn test_graph_header_max_degree_overflow() {
        // A 384-byte node cannot even hold the 768-byte `Half` vector.
        let header = GraphHeader::new(
            make_metadata(384, 384, 0),
            BLOCK_SIZE,
            GraphLayoutVersion::new(1, 0),
        );

        assert_eq!(
            header.max_degree::<diskann_vector::Half>(),
            Err(GraphHeaderError::MaxDegreeOverflow)
        );
    }

    #[test]
    fn test_graph_header_error_conversion() {
        let ann_error: ANNError = GraphHeaderError::MaxDegreeOverflow.into();
        assert_eq!(ann_error.kind(), ANNErrorKind::IndexError);

        let layout_version = GraphLayoutVersion::new(1, 0);
        let ann_error: ANNError =
            GraphHeaderError::MaxDegreeUnsupportedLayoutVersion(layout_version).into();
        assert_eq!(ann_error.kind(), ANNErrorKind::IndexError);
    }
}