Skip to main content

hermes_core/segment/
vector_data.rs

1//! Vector index data structures shared between builder and reader
2
3use std::mem::size_of;
4
5use serde::{Deserialize, Serialize};
6
/// Magic number identifying the binary flat vector format.
///
/// The hex digits spell "FVD2" in ASCII. Because the value is written with
/// `to_le_bytes`, the byte sequence that actually appears on disk is `2DVF`.
const FLAT_BINARY_MAGIC: u32 = 0x46564432;

/// Binary header: magic(u32) + dim(u32) + num_vectors(u32)
const FLAT_BINARY_HEADER_SIZE: usize = 3 * size_of::<u32>();
/// Per-vector element size
const FLOAT_SIZE: usize = size_of::<f32>();
/// Per-doc_id entry: doc_id(u32) + ordinal(u16)
const DOC_ID_ENTRY_SIZE: usize = size_of::<u32>() + size_of::<u16>();

/// Flat vector data for brute-force search.
///
/// Uses a single contiguous `Vec<f32>` instead of `Vec<Vec<f32>>`.
/// Loading is a single bulk memcpy (1 allocation) instead of N separate
/// allocations with float-by-float parsing. For 3.3M vectors at 768 dims
/// this reduces load time from ~36s to ~1s.
///
/// # Binary layout
/// `magic(u32 LE) | dim(u32 LE) | num_vectors(u32 LE) | num_vectors*dim f32 |
/// num_vectors * (doc_id u32 LE, ordinal u16 LE)`.
///
/// The f32 payload is copied to and from memory verbatim, i.e. in the byte
/// order of the machine doing the copy, whereas the integers are always
/// little-endian.
#[derive(Debug, Clone)]
pub struct FlatVectorData {
    /// Number of f32 components per vector.
    pub dim: usize,
    /// Flat contiguous vector storage: num_vectors * dim f32 values.
    /// Access vector i via `vectors[i*dim .. (i+1)*dim]`.
    vectors: Vec<f32>,
    /// Document IDs with ordinals: (doc_id, ordinal) pairs
    /// Ordinal tracks which vector in a multi-valued field
    pub doc_ids: Vec<(u32, u16)>,
}

impl FlatVectorData {
    /// Build an `InvalidData` error for malformed serialized input.
    fn invalid_data(msg: &'static str) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::InvalidData, msg)
    }

    /// Build an `InvalidInput` error for arguments that cannot be serialized.
    fn invalid_input(msg: String) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::InvalidInput, msg)
    }

    /// Error used when sizes derived from a header do not fit in `usize`.
    fn oversized_header() -> std::io::Error {
        Self::invalid_data("FlatVectorData binary header declares a size that overflows usize")
    }

    /// Narrow `value` to the `u32` stored in the header, failing instead of truncating.
    fn checked_u32(value: usize, what: &str) -> std::io::Result<u32> {
        <u32 as std::convert::TryFrom<usize>>::try_from(value).map_err(|_| {
            Self::invalid_input(format!(
                "FlatVectorData {} ({}) does not fit in the u32 header field",
                what, value
            ))
        })
    }

    /// View a float slice as its underlying bytes (native byte order).
    #[inline]
    fn floats_as_bytes(floats: &[f32]) -> &[u8] {
        // SAFETY: `floats` is a valid initialized slice, `u8` has alignment 1 and
        // accepts every bit pattern, the byte length is exactly the size of the
        // slice in memory, and the result borrows for the same lifetime as the input.
        unsafe {
            std::slice::from_raw_parts(
                floats.as_ptr() as *const u8,
                std::mem::size_of_val(floats),
            )
        }
    }

    /// Number of vectors
    ///
    /// Returns 0 when `dim` is 0.
    #[inline]
    pub fn num_vectors(&self) -> usize {
        self.vectors.len().checked_div(self.dim).unwrap_or(0)
    }

    /// Get vector at index as a &[f32] slice.
    ///
    /// # Panics
    /// Panics if `idx >= num_vectors`.
    #[inline]
    pub fn get_vector(&self, idx: usize) -> &[f32] {
        let start = idx * self.dim;
        &self.vectors[start..start + self.dim]
    }

    /// Get doc_id and ordinal at index.
    ///
    /// # Panics
    /// Panics if `idx >= doc_ids.len()`.
    #[inline]
    pub fn get_doc_id(&self, idx: usize) -> (u32, u16) {
        self.doc_ids[idx]
    }

    /// Raw flat vector storage as byte slice for bulk streaming.
    /// Returns the contiguous f32 data reinterpreted as bytes.
    #[inline]
    pub fn vectors_as_bytes(&self) -> &[u8] {
        Self::floats_as_bytes(&self.vectors)
    }

    /// Stream vectors and doc_ids from this FlatVectorData to a writer.
    /// `doc_id_offset` is added to each doc_id for multi-segment merges.
    ///
    /// No header is written; only the vector payload followed by the
    /// (doc_id, ordinal) table.
    ///
    /// # Errors
    /// Returns `InvalidInput` (before writing anything) if `doc_id_offset` plus
    /// the largest doc_id would overflow `u32`, and propagates writer errors.
    pub fn stream_to_writer(
        &self,
        writer: &mut dyn std::io::Write,
        doc_id_offset: u32,
    ) -> std::io::Result<()> {
        // Validate up front so a bad offset never leaves a half-written stream.
        if let Some(max_doc_id) = self.doc_ids.iter().map(|&(doc_id, _)| doc_id).max() {
            if doc_id_offset.checked_add(max_doc_id).is_none() {
                return Err(Self::invalid_input(format!(
                    "doc_id_offset {} + doc_id {} overflows u32",
                    doc_id_offset, max_doc_id
                )));
            }
        }

        // Bulk write all vector bytes
        writer.write_all(self.vectors_as_bytes())?;
        // Write doc_ids with offset
        for &(doc_id, ordinal) in &self.doc_ids {
            // Cannot overflow: the largest doc_id was checked above.
            writer.write_all(&(doc_id_offset + doc_id).to_le_bytes())?;
            writer.write_all(&ordinal.to_le_bytes())?;
        }
        Ok(())
    }

    /// Write the binary header (magic + dim + num_vectors) to a writer.
    ///
    /// # Errors
    /// Returns `InvalidInput` (before writing anything) if `dim` or
    /// `num_vectors` does not fit in a `u32`, and propagates writer errors.
    pub fn write_binary_header(
        dim: usize,
        num_vectors: usize,
        writer: &mut dyn std::io::Write,
    ) -> std::io::Result<()> {
        let dim = Self::checked_u32(dim, "dim")?;
        let num_vectors = Self::checked_u32(num_vectors, "num_vectors")?;

        writer.write_all(&FLAT_BINARY_MAGIC.to_le_bytes())?;
        writer.write_all(&dim.to_le_bytes())?;
        writer.write_all(&num_vectors.to_le_bytes())?;
        Ok(())
    }

    /// Estimate memory usage
    ///
    /// Counts the allocated capacity of both vectors plus the struct itself.
    pub fn estimated_memory_bytes(&self) -> usize {
        let vectors_bytes = self.vectors.capacity() * FLOAT_SIZE;
        let doc_ids_bytes = self.doc_ids.capacity() * size_of::<(u32, u16)>();
        vectors_bytes + doc_ids_bytes + size_of::<Self>()
    }

    /// Deserialize from binary format. Single bulk memcpy for vectors.
    ///
    /// Parses the 12-byte header, bulk-copies vectors as one contiguous Vec<f32>,
    /// and parses doc_ids. For 3.3M vectors at 768 dims this is ~1s vs ~36s
    /// with the old Vec<Vec<f32>> approach (1 allocation vs 3.3M allocations).
    ///
    /// Bytes past the end of the doc_id table are ignored.
    ///
    /// # Errors
    /// Returns `InvalidData` if the input is shorter than the header, the magic
    /// does not match, the sizes implied by the header overflow `usize`, or the
    /// input is shorter than the header says it should be.
    pub fn from_binary_bytes(data: &[u8]) -> std::io::Result<Self> {
        if data.len() < FLAT_BINARY_HEADER_SIZE {
            return Err(Self::invalid_data("FlatVectorData binary too short"));
        }

        let header_u32 =
            |at: usize| u32::from_le_bytes([data[at], data[at + 1], data[at + 2], data[at + 3]]);

        if header_u32(0) != FLAT_BINARY_MAGIC {
            return Err(Self::invalid_data("Invalid FlatVectorData binary magic"));
        }

        let dim = header_u32(4) as usize;
        let num_vectors = header_u32(8) as usize;

        // The header is untrusted: every derived size is computed with checked
        // arithmetic so a hostile or corrupt header yields an error rather than
        // a panic or a wrapped (too small) length.
        let total_floats = num_vectors
            .checked_mul(dim)
            .ok_or_else(Self::oversized_header)?;
        let vectors_byte_len = total_floats
            .checked_mul(FLOAT_SIZE)
            .ok_or_else(Self::oversized_header)?;
        let doc_ids_byte_len = num_vectors
            .checked_mul(DOC_ID_ENTRY_SIZE)
            .ok_or_else(Self::oversized_header)?;
        let doc_ids_start = FLAT_BINARY_HEADER_SIZE
            .checked_add(vectors_byte_len)
            .ok_or_else(Self::oversized_header)?;
        let doc_ids_end = doc_ids_start
            .checked_add(doc_ids_byte_len)
            .ok_or_else(Self::oversized_header)?;

        // Checked before allocating, so the allocation below never exceeds the
        // amount of data actually supplied.
        if data.len() < doc_ids_end {
            return Err(Self::invalid_data("FlatVectorData binary truncated"));
        }

        // Bulk memcpy: one allocation of num_vectors*dim floats
        let vector_bytes = &data[FLAT_BINARY_HEADER_SIZE..doc_ids_start];
        let mut vectors = vec![0f32; total_floats];
        // SAFETY: `vector_bytes` is a bounds-checked slice of exactly
        // `total_floats * FLOAT_SIZE` bytes, which is the size of the freshly
        // allocated `vectors` buffer; the two regions cannot overlap; writing
        // through a `*mut u8` has no alignment requirement; and every bit
        // pattern is a valid `f32`.
        unsafe {
            std::ptr::copy_nonoverlapping(
                vector_bytes.as_ptr(),
                vectors.as_mut_ptr() as *mut u8,
                vector_bytes.len(),
            );
        }

        // Parse doc_ids (small: ~6 bytes per vector vs ~3072 bytes per 768-dim vector)
        let doc_ids: Vec<(u32, u16)> = data[doc_ids_start..doc_ids_end]
            .chunks_exact(DOC_ID_ENTRY_SIZE)
            .map(|entry| {
                let doc_id = u32::from_le_bytes([entry[0], entry[1], entry[2], entry[3]]);
                let ordinal = u16::from_le_bytes([entry[4], entry[5]]);
                (doc_id, ordinal)
            })
            .collect();

        Ok(FlatVectorData {
            dim,
            vectors,
            doc_ids,
        })
    }

    /// Compute the serialized size without actually serializing.
    ///
    /// Saturates at `usize::MAX` instead of wrapping when the size is not
    /// representable, so an impossible request can never look small.
    pub fn serialized_binary_size(index_dim: usize, num_vectors: usize) -> usize {
        let vectors_bytes = num_vectors
            .saturating_mul(index_dim)
            .saturating_mul(FLOAT_SIZE);
        let doc_ids_bytes = num_vectors.saturating_mul(DOC_ID_ENTRY_SIZE);
        FLAT_BINARY_HEADER_SIZE
            .saturating_add(vectors_bytes)
            .saturating_add(doc_ids_bytes)
    }

    /// Stream directly from flat f32 storage to a writer (zero-buffer serialization).
    ///
    /// `flat_vectors` is contiguous storage of dim*n floats.
    /// `original_dim` is the dimension in flat_vectors (may differ from index_dim for MRL).
    ///
    /// Writes the header, then each vector cut down to its first `index_dim`
    /// components, then the (doc_id, ordinal) table. The number of vectors is
    /// `doc_ids.len()`.
    ///
    /// # Errors
    /// Returns `InvalidInput` (before writing anything) if `index_dim` exceeds
    /// `original_dim`, if `flat_vectors.len()` is not exactly
    /// `doc_ids.len() * original_dim`, or if `index_dim` / the vector count do
    /// not fit in the `u32` header fields. Propagates writer errors.
    pub fn serialize_binary_from_flat_streaming(
        index_dim: usize,
        flat_vectors: &[f32],
        original_dim: usize,
        doc_ids: &[(u32, u16)],
        writer: &mut dyn std::io::Write,
    ) -> std::io::Result<()> {
        let num_vectors = doc_ids.len();

        // Trimming can only shorten a vector; a larger index_dim would read
        // into the neighbouring vector.
        if index_dim > original_dim {
            return Err(Self::invalid_input(format!(
                "index_dim {} exceeds original_dim {}",
                index_dim, original_dim
            )));
        }
        // The reader locates the doc_id table from the header alone, so the
        // payload must contain exactly one vector per doc_id entry.
        if num_vectors.checked_mul(original_dim) != Some(flat_vectors.len()) {
            return Err(Self::invalid_input(format!(
                "flat_vectors holds {} floats but {} vectors of dimension {} were expected",
                flat_vectors.len(),
                num_vectors,
                original_dim
            )));
        }

        Self::write_binary_header(index_dim, num_vectors, &mut *writer)?;

        if index_dim == original_dim {
            // No trimming — write all floats directly
            writer.write_all(Self::floats_as_bytes(flat_vectors))?;
        } else {
            // Trim each vector to index_dim (matryoshka/MRL).
            // Here index_dim < original_dim, so original_dim is non-zero and
            // `chunks_exact` yields exactly `num_vectors` rows.
            for row in flat_vectors.chunks_exact(original_dim) {
                writer.write_all(Self::floats_as_bytes(&row[..index_dim]))?;
            }
        }

        for &(doc_id, ordinal) in doc_ids {
            writer.write_all(&doc_id.to_le_bytes())?;
            writer.write_all(&ordinal.to_le_bytes())?;
        }

        Ok(())
    }
}
224
225/// IVF-RaBitQ index data with embedded centroids and codebook
226#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct IVFRaBitQIndexData {
228    pub index: crate::structures::IVFRaBitQIndex,
229    pub centroids: crate::structures::CoarseCentroids,
230    pub codebook: crate::structures::RaBitQCodebook,
231}
232
233impl IVFRaBitQIndexData {
234    pub fn to_bytes(&self) -> std::io::Result<Vec<u8>> {
235        serde_json::to_vec(self)
236            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
237    }
238
239    pub fn from_bytes(data: &[u8]) -> std::io::Result<Self> {
240        serde_json::from_slice(data)
241            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
242    }
243}
244
245/// ScaNN index data with embedded centroids and codebook
246#[derive(Debug, Clone, Serialize, Deserialize)]
247pub struct ScaNNIndexData {
248    pub index: crate::structures::IVFPQIndex,
249    pub centroids: crate::structures::CoarseCentroids,
250    pub codebook: crate::structures::PQCodebook,
251}
252
253impl ScaNNIndexData {
254    pub fn to_bytes(&self) -> std::io::Result<Vec<u8>> {
255        serde_json::to_vec(self)
256            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
257    }
258
259    pub fn from_bytes(data: &[u8]) -> std::io::Result<Self> {
260        serde_json::from_slice(data)
261            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
262    }
263}