1use crate::omen::section::{SectionEntry, SectionType};
4use std::io::{self, Read};
5
/// Four-byte signature written at the very start of every serialized header.
pub const MAGIC: [u8; 4] = [b'O', b'M', b'E', b'N'];

/// Major format version written by this code. `OmenHeader::from_bytes`
/// rejects headers whose major version is greater than this value.
pub const VERSION_MAJOR: u16 = 1;
/// Minor format version written by this code.
pub const VERSION_MINOR: u16 = 0;

/// Size in bytes of the serialized header.
pub const HEADER_SIZE: usize = 4 * 1024;

/// Number of slots in the header's section table.
pub const MAX_SECTIONS: usize = 8;
18
/// One-byte code recorded in the header to identify the quantization scheme.
///
/// The explicit discriminants are the byte values written by
/// `OmenHeader::to_bytes` (via `as u8`) and decoded by the `From<u8>` impl
/// below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum QuantizationCode {
    F32 = 0,
    Sq8 = 1,
    RabitQ4 = 2,
    RabitQ2 = 3,
    RabitQ8 = 4,
    Binary = 5,
}

impl From<u8> for QuantizationCode {
    /// Decodes a stored byte into a code.
    ///
    /// Any byte that is not the discriminant of a known variant decodes to
    /// [`QuantizationCode::F32`]; this conversion never fails.
    fn from(v: u8) -> Self {
        // Every variant, so decoding is driven by the declared discriminants
        // rather than a second hand-maintained list of numbers.
        const VARIANTS: [QuantizationCode; 6] = [
            QuantizationCode::F32,
            QuantizationCode::Sq8,
            QuantizationCode::RabitQ4,
            QuantizationCode::RabitQ2,
            QuantizationCode::RabitQ8,
            QuantizationCode::Binary,
        ];

        VARIANTS
            .iter()
            .copied()
            .find(|code| *code as u8 == v)
            .unwrap_or(Self::F32)
    }
}
46
47impl From<&crate::vector::QuantizationMode> for QuantizationCode {
48 fn from(mode: &crate::vector::QuantizationMode) -> Self {
49 use crate::compression::QuantizationBits;
50 match mode {
51 crate::vector::QuantizationMode::Binary => Self::Binary,
52 crate::vector::QuantizationMode::SQ8 => Self::Sq8,
53 crate::vector::QuantizationMode::RaBitQ(params) => match params.bits_per_dim {
54 QuantizationBits::Bits1 => Self::Binary,
55 QuantizationBits::Bits2 => Self::RabitQ2,
56 QuantizationBits::Bits3 | QuantizationBits::Bits4 => Self::RabitQ4,
57 QuantizationBits::Bits5 | QuantizationBits::Bits7 | QuantizationBits::Bits8 => {
58 Self::RabitQ8
59 }
60 },
61 }
62 }
63}
64
65impl From<crate::vector::QuantizationMode> for QuantizationCode {
66 fn from(mode: crate::vector::QuantizationMode) -> Self {
67 Self::from(&mode)
68 }
69}
70
71impl QuantizationCode {
72 #[must_use]
76 pub fn to_runtime(self) -> Option<crate::vector::QuantizationMode> {
77 use crate::compression::RaBitQParams;
78 match self {
79 Self::F32 => None,
80 Self::Sq8 => Some(crate::vector::QuantizationMode::SQ8),
81 Self::Binary => Some(crate::vector::QuantizationMode::Binary),
82 Self::RabitQ2 => Some(crate::vector::QuantizationMode::RaBitQ(
83 RaBitQParams::bits2(),
84 )),
85 Self::RabitQ4 => Some(crate::vector::QuantizationMode::RaBitQ(
86 RaBitQParams::bits4(),
87 )),
88 Self::RabitQ8 => Some(crate::vector::QuantizationMode::RaBitQ(
89 RaBitQParams::bits8(),
90 )),
91 }
92 }
93}
94
/// One-byte code recorded in the header to identify the distance metric.
///
/// The explicit discriminants are the byte values written by
/// `OmenHeader::to_bytes` (via `as u8`) and decoded by the `From<u8>` impl
/// below.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Metric {
    L2 = 0,
    Cosine = 1,
    Dot = 2,
}

impl From<u8> for Metric {
    /// Decodes a stored byte into a metric.
    ///
    /// Any byte that is not the discriminant of a known variant decodes to
    /// [`Metric::L2`]; this conversion never fails.
    fn from(v: u8) -> Self {
        // Decode against the declared discriminants instead of repeating the
        // numeric values.
        const VARIANTS: [Metric; 3] = [Metric::L2, Metric::Cosine, Metric::Dot];

        VARIANTS
            .iter()
            .copied()
            .find(|metric| *metric as u8 == v)
            .unwrap_or(Self::L2)
    }
}
119
120impl From<Metric> for crate::vector::hnsw::DistanceFunction {
121 fn from(m: Metric) -> Self {
122 match m {
123 Metric::L2 => Self::L2,
124 Metric::Cosine => Self::Cosine,
125 Metric::Dot => Self::NegativeDotProduct,
126 }
127 }
128}
129
130impl From<crate::vector::hnsw::DistanceFunction> for Metric {
131 fn from(d: crate::vector::hnsw::DistanceFunction) -> Self {
132 match d {
133 crate::vector::hnsw::DistanceFunction::L2 => Self::L2,
134 crate::vector::hnsw::DistanceFunction::Cosine => Self::Cosine,
135 crate::vector::hnsw::DistanceFunction::NegativeDotProduct => Self::Dot,
136 }
137 }
138}
139
140impl Metric {
141 pub fn parse(s: &str) -> Result<Self, String> {
148 match s.to_lowercase().as_str() {
149 "l2" | "euclidean" => Ok(Self::L2),
150 "cosine" => Ok(Self::Cosine),
151 "dot" | "ip" => Ok(Self::Dot),
152 _ => Err(format!(
153 "Unknown metric: '{s}'. Valid: l2, euclidean, cosine, dot, ip"
154 )),
155 }
156 }
157
158 #[must_use]
160 pub fn as_str(&self) -> &'static str {
161 match self {
162 Self::L2 => "l2",
163 Self::Cosine => "cosine",
164 Self::Dot => "dot",
165 }
166 }
167}
168
169#[derive(Debug, Clone)]
171pub struct OmenHeader {
172 pub version_major: u16,
174 pub version_minor: u16,
175 pub flags: u64,
176
177 pub dimensions: u32,
179 pub count: u64,
180 pub quantization: QuantizationCode,
181 pub distance_fn: Metric,
182
183 pub m: u16,
185 pub ef_construction: u16,
186 pub ef_search: u16,
187 pub max_level: u8,
188 pub entry_point: u32,
189
190 pub sections: [SectionEntry; MAX_SECTIONS],
192
193 pub header_checksum: u32,
195 pub data_checksum: u32,
196}
197
198impl Default for OmenHeader {
199 fn default() -> Self {
200 Self {
201 version_major: VERSION_MAJOR,
202 version_minor: VERSION_MINOR,
203 flags: 0,
204 dimensions: 0,
205 count: 0,
206 quantization: QuantizationCode::F32,
207 distance_fn: Metric::L2,
208 m: 16,
209 ef_construction: 100,
210 ef_search: 100,
211 max_level: 0,
212 entry_point: 0,
213 sections: [SectionEntry::default(); MAX_SECTIONS],
214 header_checksum: 0,
215 data_checksum: 0,
216 }
217 }
218}
219
220impl OmenHeader {
221 #[must_use]
223 pub fn new(dimensions: u32) -> Self {
224 Self {
225 dimensions,
226 ..Default::default()
227 }
228 }
229
230 #[must_use]
232 pub fn to_bytes(&self) -> [u8; HEADER_SIZE] {
233 let mut buf = [0u8; HEADER_SIZE];
234 let mut offset = 0;
235
236 buf[offset..offset + 4].copy_from_slice(&MAGIC);
238 offset += 4;
239
240 buf[offset..offset + 2].copy_from_slice(&self.version_major.to_le_bytes());
242 offset += 2;
243 buf[offset..offset + 2].copy_from_slice(&self.version_minor.to_le_bytes());
244 offset += 2;
245
246 buf[offset..offset + 8].copy_from_slice(&self.flags.to_le_bytes());
248 offset += 8;
249
250 buf[offset..offset + 4].copy_from_slice(&self.dimensions.to_le_bytes());
252 offset += 4;
253 buf[offset..offset + 8].copy_from_slice(&self.count.to_le_bytes());
254 offset += 8;
255 buf[offset] = self.quantization as u8;
256 offset += 1;
257 buf[offset] = self.distance_fn as u8;
258 offset += 1;
259 offset += 14;
261
262 buf[offset..offset + 2].copy_from_slice(&self.m.to_le_bytes());
264 offset += 2;
265 buf[offset..offset + 2].copy_from_slice(&self.ef_construction.to_le_bytes());
266 offset += 2;
267 buf[offset..offset + 2].copy_from_slice(&self.ef_search.to_le_bytes());
268 offset += 2;
269 buf[offset] = self.max_level;
270 offset += 1;
271 buf[offset..offset + 4].copy_from_slice(&self.entry_point.to_le_bytes());
272 offset += 4;
273 offset += 3;
275
276 for section in &self.sections {
278 buf[offset..offset + 24].copy_from_slice(§ion.to_bytes());
279 offset += 24;
280 }
281
282 buf[offset..offset + 4].copy_from_slice(&self.header_checksum.to_le_bytes());
284 offset += 4;
285 buf[offset..offset + 4].copy_from_slice(&self.data_checksum.to_le_bytes());
286
287 let checksum = crc32fast::hash(&buf[..HEADER_SIZE - 8]);
289 buf[HEADER_SIZE - 8..HEADER_SIZE - 4].copy_from_slice(&checksum.to_le_bytes());
290
291 buf
292 }
293
294 pub fn from_bytes(buf: &[u8; HEADER_SIZE]) -> io::Result<Self> {
296 if buf[0..4] != MAGIC {
298 return Err(io::Error::new(
299 io::ErrorKind::InvalidData,
300 "Invalid magic bytes",
301 ));
302 }
303
304 let stored_checksum = u32::from_le_bytes([
306 buf[HEADER_SIZE - 8],
307 buf[HEADER_SIZE - 7],
308 buf[HEADER_SIZE - 6],
309 buf[HEADER_SIZE - 5],
310 ]);
311 let computed_checksum = crc32fast::hash(&buf[..HEADER_SIZE - 8]);
312 if stored_checksum != computed_checksum {
313 return Err(io::Error::new(
314 io::ErrorKind::InvalidData,
315 "Header checksum mismatch",
316 ));
317 }
318
319 let mut cursor = io::Cursor::new(&buf[4..]); let mut u16_buf = [0u8; 2];
322 let mut u32_buf = [0u8; 4];
323 let mut u64_buf = [0u8; 8];
324 let mut u8_buf = [0u8; 1];
325
326 cursor.read_exact(&mut u16_buf)?;
328 let version_major = u16::from_le_bytes(u16_buf);
329 cursor.read_exact(&mut u16_buf)?;
330 let version_minor = u16::from_le_bytes(u16_buf);
331
332 if version_major > VERSION_MAJOR {
334 return Err(io::Error::new(
335 io::ErrorKind::InvalidData,
336 format!("Unsupported version: {version_major}.{version_minor}"),
337 ));
338 }
339
340 cursor.read_exact(&mut u64_buf)?;
342 let flags = u64::from_le_bytes(u64_buf);
343
344 cursor.read_exact(&mut u32_buf)?;
346 let dimensions = u32::from_le_bytes(u32_buf);
347 cursor.read_exact(&mut u64_buf)?;
348 let count = u64::from_le_bytes(u64_buf);
349 cursor.read_exact(&mut u8_buf)?;
350 let quantization = QuantizationCode::from(u8_buf[0]);
351 cursor.read_exact(&mut u8_buf)?;
352 let distance_fn = Metric::from(u8_buf[0]);
353
354 let mut reserved = [0u8; 14];
356 cursor.read_exact(&mut reserved)?;
357
358 cursor.read_exact(&mut u16_buf)?;
360 let m = u16::from_le_bytes(u16_buf);
361 cursor.read_exact(&mut u16_buf)?;
362 let ef_construction = u16::from_le_bytes(u16_buf);
363 cursor.read_exact(&mut u16_buf)?;
364 let ef_search = u16::from_le_bytes(u16_buf);
365 cursor.read_exact(&mut u8_buf)?;
366 let max_level = u8_buf[0];
367 cursor.read_exact(&mut u32_buf)?;
368 let entry_point = u32::from_le_bytes(u32_buf);
369
370 let mut reserved2 = [0u8; 3];
372 cursor.read_exact(&mut reserved2)?;
373
374 let mut sections = [SectionEntry::default(); MAX_SECTIONS];
376 for section in &mut sections {
377 let mut section_buf = [0u8; 24];
378 cursor.read_exact(&mut section_buf)?;
379 *section = SectionEntry::from_bytes(§ion_buf);
380 }
381
382 cursor.read_exact(&mut u32_buf)?;
384 let header_checksum = u32::from_le_bytes(u32_buf);
385 cursor.read_exact(&mut u32_buf)?;
386 let data_checksum = u32::from_le_bytes(u32_buf);
387
388 Ok(Self {
389 version_major,
390 version_minor,
391 flags,
392 dimensions,
393 count,
394 quantization,
395 distance_fn,
396 m,
397 ef_construction,
398 ef_search,
399 max_level,
400 entry_point,
401 sections,
402 header_checksum,
403 data_checksum,
404 })
405 }
406
407 #[must_use]
409 pub fn get_section(&self, section_type: SectionType) -> Option<&SectionEntry> {
410 self.sections
411 .iter()
412 .find(|s| s.section_type == section_type && s.length > 0)
413 }
414
415 pub fn set_section(&mut self, entry: SectionEntry) {
417 for section in &mut self.sections {
418 if section.section_type == entry.section_type || section.length == 0 {
419 *section = entry;
420 return;
421 }
422 }
423 }
424}
425
#[cfg(test)]
mod tests {
    use super::*;

    /// Fields set before serialization come back unchanged after parsing.
    #[test]
    fn test_header_roundtrip() {
        let original = OmenHeader {
            count: 1000,
            m: 32,
            ef_construction: 200,
            entry_point: 42,
            ..OmenHeader::new(768)
        };

        let parsed = OmenHeader::from_bytes(&original.to_bytes()).unwrap();

        assert_eq!(parsed.dimensions, 768);
        assert_eq!(parsed.count, 1000);
        assert_eq!(parsed.m, 32);
        assert_eq!(parsed.ef_construction, 200);
        assert_eq!(parsed.entry_point, 42);
    }

    /// A buffer that does not start with the magic bytes is rejected.
    #[test]
    fn test_invalid_magic() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[..4].copy_from_slice(b"NOPE");

        assert!(OmenHeader::from_bytes(&buf).is_err());
    }

    /// Flipping a byte inside the checksummed region is reported as a
    /// checksum mismatch.
    #[test]
    fn test_corrupted_header_detected() {
        let mut bytes = OmenHeader::new(768).to_bytes();

        // Byte 20 is the first byte of the `count` field.
        bytes[20] ^= 0xFF;

        let err = OmenHeader::from_bytes(&bytes).expect_err("corrupted header must be rejected");
        assert!(err.to_string().contains("checksum mismatch"));
    }

    /// The serialized header carries a non-zero CRC that parsing accepts.
    #[test]
    fn test_checksum_calculated_correctly() {
        let header = OmenHeader {
            count: 12345,
            m: 32,
            ef_construction: 200,
            ..OmenHeader::new(768)
        };

        let bytes = header.to_bytes();

        // The CRC occupies the four bytes starting at HEADER_SIZE - 8.
        let mut crc = [0u8; 4];
        crc.copy_from_slice(&bytes[HEADER_SIZE - 8..HEADER_SIZE - 4]);
        assert_ne!(u32::from_le_bytes(crc), 0);

        let parsed = OmenHeader::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.dimensions, 768);
        assert_eq!(parsed.count, 12345);
    }
}