omendb_core/omen/
header.rs

//! `.omen` file header: a fixed 4 KiB (one page) block.
2
3use crate::omen::section::{SectionEntry, SectionType};
4use std::io::{self, Read};
5
/// File signature: every serialized header starts with the ASCII bytes `OMEN`.
pub const MAGIC: [u8; 4] = [b'O', b'M', b'E', b'N'];

/// Major component of the format version written by this crate.
///
/// Headers whose major version is greater than this are rejected on parse.
pub const VERSION_MAJOR: u16 = 1;

/// Minor component of the format version written by this crate.
pub const VERSION_MINOR: u16 = 0;

/// Size of the serialized header in bytes (one 4 KiB page).
pub const HEADER_SIZE: usize = 4 * 1024;

/// Number of slots in the header's fixed-size section directory.
pub const MAX_SECTIONS: usize = 8;
18
/// One-byte quantization identifier stored in the `.omen` header.
///
/// The explicit discriminants are what gets written to disk (`code as u8`),
/// so changing them would make existing files decode differently. For the
/// runtime API use `crate::vector::QuantizationMode` instead.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum QuantizationCode {
    /// No quantization: vectors are kept as `f32`.
    F32 = 0,
    /// Corresponds to the runtime `QuantizationMode::SQ8`.
    Sq8 = 1,
    /// RaBitQ, decoded with `RaBitQParams::bits4()`.
    RabitQ4 = 2,
    /// RaBitQ, decoded with `RaBitQParams::bits2()`.
    RabitQ2 = 3,
    /// RaBitQ, decoded with `RaBitQParams::bits8()`.
    RabitQ8 = 4,
    /// Corresponds to the runtime `QuantizationMode::Binary`.
    Binary = 5,
}
33
34impl From<u8> for QuantizationCode {
35    fn from(v: u8) -> Self {
36        match v {
37            1 => Self::Sq8,
38            2 => Self::RabitQ4,
39            3 => Self::RabitQ2,
40            4 => Self::RabitQ8,
41            5 => Self::Binary,
42            _ => Self::F32,
43        }
44    }
45}
46
47impl From<&crate::vector::QuantizationMode> for QuantizationCode {
48    fn from(mode: &crate::vector::QuantizationMode) -> Self {
49        use crate::compression::QuantizationBits;
50        match mode {
51            crate::vector::QuantizationMode::Binary => Self::Binary,
52            crate::vector::QuantizationMode::SQ8 => Self::Sq8,
53            crate::vector::QuantizationMode::RaBitQ(params) => match params.bits_per_dim {
54                QuantizationBits::Bits1 => Self::Binary,
55                QuantizationBits::Bits2 => Self::RabitQ2,
56                QuantizationBits::Bits3 | QuantizationBits::Bits4 => Self::RabitQ4,
57                QuantizationBits::Bits5 | QuantizationBits::Bits7 | QuantizationBits::Bits8 => {
58                    Self::RabitQ8
59                }
60            },
61        }
62    }
63}
64
65impl From<crate::vector::QuantizationMode> for QuantizationCode {
66    fn from(mode: crate::vector::QuantizationMode) -> Self {
67        Self::from(&mode)
68    }
69}
70
71impl QuantizationCode {
72    /// Convert to runtime `QuantizationMode`.
73    ///
74    /// Returns `None` for `F32` (no quantization).
75    #[must_use]
76    pub fn to_runtime(self) -> Option<crate::vector::QuantizationMode> {
77        use crate::compression::RaBitQParams;
78        match self {
79            Self::F32 => None,
80            Self::Sq8 => Some(crate::vector::QuantizationMode::SQ8),
81            Self::Binary => Some(crate::vector::QuantizationMode::Binary),
82            Self::RabitQ2 => Some(crate::vector::QuantizationMode::RaBitQ(
83                RaBitQParams::bits2(),
84            )),
85            Self::RabitQ4 => Some(crate::vector::QuantizationMode::RaBitQ(
86                RaBitQParams::bits4(),
87            )),
88            Self::RabitQ8 => Some(crate::vector::QuantizationMode::RaBitQ(
89                RaBitQParams::bits8(),
90            )),
91        }
92    }
93}
94
/// Similarity metric recorded in the `.omen` header (user-facing API type).
///
/// Stored on disk as one byte using the explicit discriminants below. Runtime
/// distance computation uses `crate::vector::hnsw::DistanceFunction`, reachable
/// through the `From` conversions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum Metric {
    /// Euclidean (L2) distance.
    L2 = 0,
    /// Cosine distance, i.e. `1 - cosine similarity`.
    Cosine = 1,
    /// Dot / inner product, for maximum inner product search (MIPS).
    Dot = 2,
}
109
110impl From<u8> for Metric {
111    fn from(v: u8) -> Self {
112        match v {
113            1 => Self::Cosine,
114            2 => Self::Dot,
115            _ => Self::L2,
116        }
117    }
118}
119
120impl From<Metric> for crate::vector::hnsw::DistanceFunction {
121    fn from(m: Metric) -> Self {
122        match m {
123            Metric::L2 => Self::L2,
124            Metric::Cosine => Self::Cosine,
125            Metric::Dot => Self::NegativeDotProduct,
126        }
127    }
128}
129
130impl From<crate::vector::hnsw::DistanceFunction> for Metric {
131    fn from(d: crate::vector::hnsw::DistanceFunction) -> Self {
132        match d {
133            crate::vector::hnsw::DistanceFunction::L2 => Self::L2,
134            crate::vector::hnsw::DistanceFunction::Cosine => Self::Cosine,
135            crate::vector::hnsw::DistanceFunction::NegativeDotProduct => Self::Dot,
136        }
137    }
138}
139
140impl Metric {
141    /// Parse from string (case-insensitive, with aliases).
142    ///
143    /// # Supported values
144    /// - `"l2"` or `"euclidean"`: Euclidean distance (default)
145    /// - `"cosine"`: Cosine distance (1 - cosine similarity)
146    /// - `"dot"` or `"ip"`: Inner product (for MIPS)
147    pub fn parse(s: &str) -> Result<Self, String> {
148        match s.to_lowercase().as_str() {
149            "l2" | "euclidean" => Ok(Self::L2),
150            "cosine" => Ok(Self::Cosine),
151            "dot" | "ip" => Ok(Self::Dot),
152            _ => Err(format!(
153                "Unknown metric: '{s}'. Valid: l2, euclidean, cosine, dot, ip"
154            )),
155        }
156    }
157
158    /// Get the string representation.
159    #[must_use]
160    pub fn as_str(&self) -> &'static str {
161        match self {
162            Self::L2 => "l2",
163            Self::Cosine => "cosine",
164            Self::Dot => "dot",
165        }
166    }
167}
168
169/// .omen file header
170#[derive(Debug, Clone)]
171pub struct OmenHeader {
172    // Magic and version (16 bytes)
173    pub version_major: u16,
174    pub version_minor: u16,
175    pub flags: u64,
176
177    // Database info (32 bytes)
178    pub dimensions: u32,
179    pub count: u64,
180    pub quantization: QuantizationCode,
181    pub distance_fn: Metric,
182
183    // HNSW params (16 bytes)
184    pub m: u16,
185    pub ef_construction: u16,
186    pub ef_search: u16,
187    pub max_level: u8,
188    pub entry_point: u32,
189
190    // Section directory
191    pub sections: [SectionEntry; MAX_SECTIONS],
192
193    // Checksums
194    pub header_checksum: u32,
195    pub data_checksum: u32,
196}
197
198impl Default for OmenHeader {
199    fn default() -> Self {
200        Self {
201            version_major: VERSION_MAJOR,
202            version_minor: VERSION_MINOR,
203            flags: 0,
204            dimensions: 0,
205            count: 0,
206            quantization: QuantizationCode::F32,
207            distance_fn: Metric::L2,
208            m: 16,
209            ef_construction: 100,
210            ef_search: 100,
211            max_level: 0,
212            entry_point: 0,
213            sections: [SectionEntry::default(); MAX_SECTIONS],
214            header_checksum: 0,
215            data_checksum: 0,
216        }
217    }
218}
219
220impl OmenHeader {
221    /// Create a new header with the given dimensions
222    #[must_use]
223    pub fn new(dimensions: u32) -> Self {
224        Self {
225            dimensions,
226            ..Default::default()
227        }
228    }
229
230    /// Serialize header to bytes (4KB)
231    #[must_use]
232    pub fn to_bytes(&self) -> [u8; HEADER_SIZE] {
233        let mut buf = [0u8; HEADER_SIZE];
234        let mut offset = 0;
235
236        // Magic (4 bytes)
237        buf[offset..offset + 4].copy_from_slice(&MAGIC);
238        offset += 4;
239
240        // Version (4 bytes)
241        buf[offset..offset + 2].copy_from_slice(&self.version_major.to_le_bytes());
242        offset += 2;
243        buf[offset..offset + 2].copy_from_slice(&self.version_minor.to_le_bytes());
244        offset += 2;
245
246        // Flags (8 bytes)
247        buf[offset..offset + 8].copy_from_slice(&self.flags.to_le_bytes());
248        offset += 8;
249
250        // Database info (32 bytes)
251        buf[offset..offset + 4].copy_from_slice(&self.dimensions.to_le_bytes());
252        offset += 4;
253        buf[offset..offset + 8].copy_from_slice(&self.count.to_le_bytes());
254        offset += 8;
255        buf[offset] = self.quantization as u8;
256        offset += 1;
257        buf[offset] = self.distance_fn as u8;
258        offset += 1;
259        // 14 bytes reserved (already zeroed)
260        offset += 14;
261
262        // HNSW params (16 bytes)
263        buf[offset..offset + 2].copy_from_slice(&self.m.to_le_bytes());
264        offset += 2;
265        buf[offset..offset + 2].copy_from_slice(&self.ef_construction.to_le_bytes());
266        offset += 2;
267        buf[offset..offset + 2].copy_from_slice(&self.ef_search.to_le_bytes());
268        offset += 2;
269        buf[offset] = self.max_level;
270        offset += 1;
271        buf[offset..offset + 4].copy_from_slice(&self.entry_point.to_le_bytes());
272        offset += 4;
273        // 3 bytes reserved (already zeroed)
274        offset += 3;
275
276        // Sections (8 * 24 bytes = 192 bytes)
277        for section in &self.sections {
278            buf[offset..offset + 24].copy_from_slice(&section.to_bytes());
279            offset += 24;
280        }
281
282        // Checksums (8 bytes)
283        buf[offset..offset + 4].copy_from_slice(&self.header_checksum.to_le_bytes());
284        offset += 4;
285        buf[offset..offset + 4].copy_from_slice(&self.data_checksum.to_le_bytes());
286
287        // Calculate and write header checksum
288        let checksum = crc32fast::hash(&buf[..HEADER_SIZE - 8]);
289        buf[HEADER_SIZE - 8..HEADER_SIZE - 4].copy_from_slice(&checksum.to_le_bytes());
290
291        buf
292    }
293
294    /// Parse header from bytes
295    pub fn from_bytes(buf: &[u8; HEADER_SIZE]) -> io::Result<Self> {
296        // Verify magic
297        if buf[0..4] != MAGIC {
298            return Err(io::Error::new(
299                io::ErrorKind::InvalidData,
300                "Invalid magic bytes",
301            ));
302        }
303
304        // Verify checksum - direct array indexing for fixed-size buffer
305        let stored_checksum = u32::from_le_bytes([
306            buf[HEADER_SIZE - 8],
307            buf[HEADER_SIZE - 7],
308            buf[HEADER_SIZE - 6],
309            buf[HEADER_SIZE - 5],
310        ]);
311        let computed_checksum = crc32fast::hash(&buf[..HEADER_SIZE - 8]);
312        if stored_checksum != computed_checksum {
313            return Err(io::Error::new(
314                io::ErrorKind::InvalidData,
315                "Header checksum mismatch",
316            ));
317        }
318
319        let mut cursor = io::Cursor::new(&buf[4..]); // Skip magic
320
321        let mut u16_buf = [0u8; 2];
322        let mut u32_buf = [0u8; 4];
323        let mut u64_buf = [0u8; 8];
324        let mut u8_buf = [0u8; 1];
325
326        // Version
327        cursor.read_exact(&mut u16_buf)?;
328        let version_major = u16::from_le_bytes(u16_buf);
329        cursor.read_exact(&mut u16_buf)?;
330        let version_minor = u16::from_le_bytes(u16_buf);
331
332        // Check version compatibility
333        if version_major > VERSION_MAJOR {
334            return Err(io::Error::new(
335                io::ErrorKind::InvalidData,
336                format!("Unsupported version: {version_major}.{version_minor}"),
337            ));
338        }
339
340        // Flags
341        cursor.read_exact(&mut u64_buf)?;
342        let flags = u64::from_le_bytes(u64_buf);
343
344        // Database info
345        cursor.read_exact(&mut u32_buf)?;
346        let dimensions = u32::from_le_bytes(u32_buf);
347        cursor.read_exact(&mut u64_buf)?;
348        let count = u64::from_le_bytes(u64_buf);
349        cursor.read_exact(&mut u8_buf)?;
350        let quantization = QuantizationCode::from(u8_buf[0]);
351        cursor.read_exact(&mut u8_buf)?;
352        let distance_fn = Metric::from(u8_buf[0]);
353
354        // Skip reserved
355        let mut reserved = [0u8; 14];
356        cursor.read_exact(&mut reserved)?;
357
358        // HNSW params
359        cursor.read_exact(&mut u16_buf)?;
360        let m = u16::from_le_bytes(u16_buf);
361        cursor.read_exact(&mut u16_buf)?;
362        let ef_construction = u16::from_le_bytes(u16_buf);
363        cursor.read_exact(&mut u16_buf)?;
364        let ef_search = u16::from_le_bytes(u16_buf);
365        cursor.read_exact(&mut u8_buf)?;
366        let max_level = u8_buf[0];
367        cursor.read_exact(&mut u32_buf)?;
368        let entry_point = u32::from_le_bytes(u32_buf);
369
370        // Skip reserved
371        let mut reserved2 = [0u8; 3];
372        cursor.read_exact(&mut reserved2)?;
373
374        // Sections
375        let mut sections = [SectionEntry::default(); MAX_SECTIONS];
376        for section in &mut sections {
377            let mut section_buf = [0u8; 24];
378            cursor.read_exact(&mut section_buf)?;
379            *section = SectionEntry::from_bytes(&section_buf);
380        }
381
382        // Checksums
383        cursor.read_exact(&mut u32_buf)?;
384        let header_checksum = u32::from_le_bytes(u32_buf);
385        cursor.read_exact(&mut u32_buf)?;
386        let data_checksum = u32::from_le_bytes(u32_buf);
387
388        Ok(Self {
389            version_major,
390            version_minor,
391            flags,
392            dimensions,
393            count,
394            quantization,
395            distance_fn,
396            m,
397            ef_construction,
398            ef_search,
399            max_level,
400            entry_point,
401            sections,
402            header_checksum,
403            data_checksum,
404        })
405    }
406
407    /// Get section by type
408    #[must_use]
409    pub fn get_section(&self, section_type: SectionType) -> Option<&SectionEntry> {
410        self.sections
411            .iter()
412            .find(|s| s.section_type == section_type && s.length > 0)
413    }
414
415    /// Set section entry
416    pub fn set_section(&mut self, entry: SectionEntry) {
417        for section in &mut self.sections {
418            if section.section_type == entry.section_type || section.length == 0 {
419                *section = entry;
420                return;
421            }
422        }
423    }
424}
425
#[cfg(test)]
mod tests {
    use super::*;

    /// Every field written by `to_bytes` is recovered by `from_bytes`.
    #[test]
    fn test_header_roundtrip() {
        let mut header = OmenHeader::new(768);
        header.flags = 0xDEAD_BEEF;
        header.count = 1000;
        header.quantization = QuantizationCode::RabitQ4;
        header.distance_fn = Metric::Cosine;
        header.m = 32;
        header.ef_construction = 200;
        header.ef_search = 64;
        header.max_level = 3;
        header.entry_point = 42;
        header.data_checksum = 7;

        let bytes = header.to_bytes();
        let parsed = OmenHeader::from_bytes(&bytes).unwrap();

        assert_eq!(parsed.version_major, VERSION_MAJOR);
        assert_eq!(parsed.version_minor, VERSION_MINOR);
        assert_eq!(parsed.flags, 0xDEAD_BEEF);
        assert_eq!(parsed.dimensions, 768);
        assert_eq!(parsed.count, 1000);
        assert_eq!(parsed.quantization, QuantizationCode::RabitQ4);
        assert_eq!(parsed.distance_fn, Metric::Cosine);
        assert_eq!(parsed.m, 32);
        assert_eq!(parsed.ef_construction, 200);
        assert_eq!(parsed.ef_search, 64);
        assert_eq!(parsed.max_level, 3);
        assert_eq!(parsed.entry_point, 42);
        assert_eq!(parsed.data_checksum, 7);
    }

    #[test]
    fn test_invalid_magic() {
        let mut buf = [0u8; HEADER_SIZE];
        buf[0..4].copy_from_slice(b"NOPE");

        let err = OmenHeader::from_bytes(&buf).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        assert!(err.to_string().contains("magic"));
    }

    /// A header from a newer major format version is refused, even though
    /// its CRC is valid.
    #[test]
    fn test_unsupported_major_version_rejected() {
        let mut header = OmenHeader::new(8);
        header.version_major = VERSION_MAJOR + 1;

        let err = OmenHeader::from_bytes(&header.to_bytes()).unwrap_err();
        assert!(err.to_string().contains("Unsupported version"));
    }

    #[test]
    fn test_corrupted_header_detected() {
        let header = OmenHeader::new(768);
        let mut bytes = header.to_bytes();

        // Flip a byte inside the CRC-covered region: offset 20 is the first
        // byte of the `count` field (`dimensions` occupies offsets 16..20).
        bytes[20] ^= 0xFF;

        let result = OmenHeader::from_bytes(&bytes);
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("checksum mismatch"));
    }

    #[test]
    fn test_checksum_calculated_correctly() {
        let mut header = OmenHeader::new(768);
        header.count = 12345;
        header.m = 32;
        header.ef_construction = 200;

        let bytes = header.to_bytes();

        // The trailer CRC occupies HEADER_SIZE-8..HEADER_SIZE-4.
        let stored_checksum =
            u32::from_le_bytes(bytes[HEADER_SIZE - 8..HEADER_SIZE - 4].try_into().unwrap());

        // Zero would indicate the checksum was never written.
        assert_ne!(stored_checksum, 0);

        // Parsing succeeds only if the checksum is correct.
        let parsed = OmenHeader::from_bytes(&bytes).unwrap();
        assert_eq!(parsed.dimensions, 768);
        assert_eq!(parsed.count, 12345);

        // Re-serializing a parsed header reproduces the exact same bytes.
        assert_eq!(parsed.to_bytes(), bytes);
    }
}