hermes_core/segment/
vector_data.rs1use std::mem::size_of;
4
5use serde::{Deserialize, Serialize};
6
/// Magic number at the start of every flat-vector binary blob.
///
/// `0x46564432` spells `"FVD2"` in ASCII when read most-significant byte
/// first; like every header field it is stored little-endian.
const FLAT_BINARY_MAGIC: u32 = 0x46564432;

/// Header size in bytes: magic, `dim` and `num_vectors`, each a little-endian `u32`.
const FLAT_BINARY_HEADER_SIZE: usize = 3 * size_of::<u32>();

/// Size in bytes of one stored vector component.
const FLOAT_SIZE: usize = size_of::<f32>();

/// Size in bytes of one serialized doc-id entry: a `u32` doc id followed by a `u16` ordinal.
const DOC_ID_ENTRY_SIZE: usize = size_of::<u32>() + size_of::<u16>();

/// Builds an `io::Error` of kind `InvalidInput` (the caller passed bad arguments).
fn flat_invalid_input<E>(msg: E) -> std::io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    std::io::Error::new(std::io::ErrorKind::InvalidInput, msg)
}

/// Builds an `io::Error` of kind `InvalidData` (the serialized bytes are malformed).
fn flat_invalid_data<E>(msg: E) -> std::io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    std::io::Error::new(std::io::ErrorKind::InvalidData, msg)
}

/// Converts a size to the `u32` stored in the header, rejecting values that
/// a plain `as` cast would silently truncate.
fn flat_header_u32(value: usize, field: &str) -> std::io::Result<u32> {
    if value > u32::MAX as usize {
        return Err(flat_invalid_input(format!(
            "FlatVectorData {} = {} does not fit in a u32 header field",
            field, value
        )));
    }
    Ok(value as u32)
}

/// Encodes the binary header (magic, `dim`, `num_vectors`; all little-endian `u32`).
///
/// Fails with `InvalidInput` when `dim` or `num_vectors` exceeds `u32::MAX`.
fn flat_encode_header(
    dim: usize,
    num_vectors: usize,
) -> std::io::Result<[u8; FLAT_BINARY_HEADER_SIZE]> {
    let dim = flat_header_u32(dim, "dim")?;
    let num_vectors = flat_header_u32(num_vectors, "num_vectors")?;
    let mut header = [0u8; FLAT_BINARY_HEADER_SIZE];
    header[0..4].copy_from_slice(&FLAT_BINARY_MAGIC.to_le_bytes());
    header[4..8].copy_from_slice(&dim.to_le_bytes());
    header[8..12].copy_from_slice(&num_vectors.to_le_bytes());
    Ok(header)
}

/// Encodes one doc-id entry: little-endian `u32` doc id, then little-endian `u16` ordinal.
fn flat_encode_doc_id(doc_id: u32, ordinal: u16) -> [u8; DOC_ID_ENTRY_SIZE] {
    let mut entry = [0u8; DOC_ID_ENTRY_SIZE];
    entry[..4].copy_from_slice(&doc_id.to_le_bytes());
    entry[4..].copy_from_slice(&ordinal.to_le_bytes());
    entry
}

/// Decodes one doc-id entry written by `flat_encode_doc_id`.
///
/// `entry` must be at least `DOC_ID_ENTRY_SIZE` bytes long (panics otherwise).
fn flat_decode_doc_id(entry: &[u8]) -> (u32, u16) {
    (
        u32::from_le_bytes([entry[0], entry[1], entry[2], entry[3]]),
        u16::from_le_bytes([entry[4], entry[5]]),
    )
}

/// Zero-copy view of `f32` values as their raw bytes, in host byte order.
fn f32s_as_bytes(values: &[f32]) -> &[u8] {
    // SAFETY: `f32` is plain data with no padding, so `values` covers exactly
    // `values.len() * FLOAT_SIZE` initialized bytes; `u8` has alignment 1, so
    // the pointer is suitably aligned; the returned slice borrows `values`
    // (elided lifetime), so it cannot outlive the data.
    unsafe { std::slice::from_raw_parts(values.as_ptr() as *const u8, values.len() * FLOAT_SIZE) }
}

/// Flat (uncompressed) vector storage.
///
/// Holds `num_vectors()` vectors of `dim` `f32` components, stored
/// contiguously, plus one `(doc_id, ordinal)` pair per vector.
///
/// Binary layout (written by `write_binary_header` + `stream_to_writer`, or by
/// `serialize_binary_from_flat_streaming`; read by `from_binary_bytes`):
///
/// | size in bytes            | content                                   |
/// |--------------------------|-------------------------------------------|
/// | 4                        | `FLAT_BINARY_MAGIC`, little-endian `u32`  |
/// | 4                        | `dim`, little-endian `u32`                |
/// | 4                        | `num_vectors`, little-endian `u32`        |
/// | `num_vectors * dim * 4`  | vector components, host-endian `f32`      |
/// | `num_vectors * 6`        | per vector: LE `u32` doc id, LE `u16` ordinal |
///
/// NOTE(review): vector components are copied as raw memory (host byte
/// order) while header and doc ids are explicitly little-endian; blobs are
/// therefore only portable between hosts of the same endianness.
#[derive(Debug, Clone)]
pub struct FlatVectorData {
    /// Number of `f32` components per vector.
    pub dim: usize,
    /// All vectors concatenated: vector `i` occupies `vectors[i * dim..(i + 1) * dim]`.
    vectors: Vec<f32>,
    /// `(doc_id, ordinal)` for vector `i` at index `i`.
    pub doc_ids: Vec<(u32, u16)>,
}

impl FlatVectorData {
    /// Number of stored vectors; `0` when `dim` is `0`.
    #[inline]
    pub fn num_vectors(&self) -> usize {
        self.vectors.len().checked_div(self.dim).unwrap_or(0)
    }

    /// Returns the components of vector `idx`.
    ///
    /// # Panics
    /// If `idx >= self.num_vectors()`.
    #[inline]
    pub fn get_vector(&self, idx: usize) -> &[f32] {
        let start = idx * self.dim;
        &self.vectors[start..start + self.dim]
    }

    /// Returns the `(doc_id, ordinal)` pair of vector `idx`.
    ///
    /// # Panics
    /// If `idx` is out of bounds for `doc_ids`.
    #[inline]
    pub fn get_doc_id(&self, idx: usize) -> (u32, u16) {
        self.doc_ids[idx]
    }

    /// Zero-copy byte view of all vector components, in host byte order.
    #[inline]
    pub fn vectors_as_bytes(&self) -> &[u8] {
        f32s_as_bytes(&self.vectors)
    }

    /// Writes the vector section followed by the doc-id section (no header;
    /// pair with `write_binary_header`). `doc_id_offset` is added to every
    /// doc id on the way out.
    ///
    /// # Errors
    /// `InvalidInput` if `doc_id_offset + doc_id` overflows `u32` for any
    /// entry. This check runs before anything is written, so the writer is
    /// untouched on that error. Any error from `writer` is propagated.
    pub fn stream_to_writer(
        &self,
        writer: &mut dyn std::io::Write,
        doc_id_offset: u32,
    ) -> std::io::Result<()> {
        // Validate up front so an overflowing offset never yields wrapped ids
        // (release) or a mid-stream panic (debug) with half the data written.
        if let Some(&(doc_id, _)) = self
            .doc_ids
            .iter()
            .find(|&&(doc_id, _)| doc_id_offset.checked_add(doc_id).is_none())
        {
            return Err(flat_invalid_input(format!(
                "doc_id_offset {} + doc_id {} overflows u32",
                doc_id_offset, doc_id
            )));
        }

        writer.write_all(self.vectors_as_bytes())?;
        for &(doc_id, ordinal) in &self.doc_ids {
            // Cannot overflow: every entry was checked above.
            writer.write_all(&flat_encode_doc_id(doc_id_offset + doc_id, ordinal))?;
        }
        Ok(())
    }

    /// Writes the 12-byte header: magic, `dim`, `num_vectors` (little-endian `u32`s).
    ///
    /// # Errors
    /// `InvalidInput` (with nothing written) if `dim` or `num_vectors` exceeds
    /// `u32::MAX`; otherwise any error from `writer`.
    pub fn write_binary_header(
        dim: usize,
        num_vectors: usize,
        writer: &mut dyn std::io::Write,
    ) -> std::io::Result<()> {
        writer.write_all(&flat_encode_header(dim, num_vectors)?)
    }

    /// Approximate heap + inline footprint, based on the vectors' capacities.
    pub fn estimated_memory_bytes(&self) -> usize {
        let vectors_bytes = self.vectors.capacity() * FLOAT_SIZE;
        let doc_ids_bytes = self.doc_ids.capacity() * size_of::<(u32, u16)>();
        vectors_bytes + doc_ids_bytes + size_of::<Self>()
    }

    /// Parses a blob in the layout documented on the type.
    ///
    /// Bytes after the doc-id section are ignored.
    ///
    /// # Errors
    /// `InvalidData` if the blob is shorter than the header, the magic does
    /// not match, or the sizes declared in the header exceed the available
    /// bytes. This includes sizes too large to represent in `usize`.
    pub fn from_binary_bytes(data: &[u8]) -> std::io::Result<Self> {
        if data.len() < FLAT_BINARY_HEADER_SIZE {
            return Err(flat_invalid_data("FlatVectorData binary too short"));
        }

        let read_u32 =
            |at: usize| u32::from_le_bytes([data[at], data[at + 1], data[at + 2], data[at + 3]]);

        if read_u32(0) != FLAT_BINARY_MAGIC {
            return Err(flat_invalid_data("Invalid FlatVectorData binary magic"));
        }

        let dim = read_u32(4) as usize;
        let num_vectors = read_u32(8) as usize;

        // The header is untrusted input: size the payload with checked
        // arithmetic so a corrupt header is reported as truncated instead of
        // overflowing (panic in debug, wrapped length + huge alloc in release).
        let layout = num_vectors.checked_mul(dim).and_then(|total_floats| {
            let vectors_byte_len = total_floats.checked_mul(FLOAT_SIZE)?;
            let doc_ids_start = FLAT_BINARY_HEADER_SIZE.checked_add(vectors_byte_len)?;
            let doc_ids_end = num_vectors
                .checked_mul(DOC_ID_ENTRY_SIZE)
                .and_then(|len| doc_ids_start.checked_add(len))?;
            Some((total_floats, doc_ids_start, doc_ids_end))
        });
        let (total_floats, doc_ids_start, doc_ids_end) = match layout {
            Some((floats, start, end)) if end <= data.len() => (floats, start, end),
            _ => return Err(flat_invalid_data("FlatVectorData binary truncated")),
        };

        // Bounded by `data.len()` thanks to the check above.
        let vector_bytes = &data[FLAT_BINARY_HEADER_SIZE..doc_ids_start];
        let mut vectors = vec![0f32; total_floats];
        // SAFETY: `vectors` owns `total_floats * FLOAT_SIZE == vector_bytes.len()`
        // initialized bytes; the source is a separate borrowed buffer, so the
        // regions cannot overlap; `u8` writes need no alignment; and every bit
        // pattern is a valid `f32`.
        unsafe {
            std::ptr::copy_nonoverlapping(
                vector_bytes.as_ptr(),
                vectors.as_mut_ptr() as *mut u8,
                vector_bytes.len(),
            );
        }

        let doc_ids = data[doc_ids_start..doc_ids_end]
            .chunks_exact(DOC_ID_ENTRY_SIZE)
            .map(flat_decode_doc_id)
            .collect();

        Ok(FlatVectorData {
            dim,
            vectors,
            doc_ids,
        })
    }

    /// Exact byte size of a blob holding `num_vectors` vectors of `index_dim`
    /// components (header + vectors + doc ids). Arithmetic is unchecked.
    pub fn serialized_binary_size(index_dim: usize, num_vectors: usize) -> usize {
        FLAT_BINARY_HEADER_SIZE
            + num_vectors * index_dim * FLOAT_SIZE
            + num_vectors * DOC_ID_ENTRY_SIZE
    }

    /// Serializes a full blob directly from a flat `[num_vectors * original_dim]`
    /// buffer, keeping only the first `index_dim` components of each vector.
    ///
    /// `num_vectors` is `doc_ids.len()`. Exactly
    /// `serialized_binary_size(index_dim, doc_ids.len())` bytes are written;
    /// any surplus floats in `flat_vectors` beyond `num_vectors * original_dim`
    /// are ignored.
    ///
    /// # Errors
    /// `InvalidInput`, before anything is written, if any of these hold:
    /// - `index_dim > original_dim`;
    /// - `flat_vectors` holds fewer than `doc_ids.len() * original_dim` floats;
    /// - `index_dim` or `doc_ids.len()` exceeds `u32::MAX`.
    ///
    /// Otherwise any error from `writer` is propagated.
    pub fn serialize_binary_from_flat_streaming(
        index_dim: usize,
        flat_vectors: &[f32],
        original_dim: usize,
        doc_ids: &[(u32, u16)],
        writer: &mut dyn std::io::Write,
    ) -> std::io::Result<()> {
        let num_vectors = doc_ids.len();

        if index_dim > original_dim {
            return Err(flat_invalid_input(format!(
                "index_dim {} exceeds original_dim {}",
                index_dim, original_dim
            )));
        }
        let needed_floats = num_vectors
            .checked_mul(original_dim)
            .filter(|&needed| needed <= flat_vectors.len())
            .ok_or_else(|| {
                flat_invalid_input(format!(
                    "flat_vectors has {} floats, need {} vectors x {} dims",
                    flat_vectors.len(),
                    num_vectors,
                    original_dim
                ))
            })?;
        // Encoding validates the u32 header fields before any byte is written.
        let header = flat_encode_header(index_dim, num_vectors)?;

        writer.write_all(&header)?;

        // Only the vectors the header announces; surplus input would otherwise
        // shift the doc-id section and corrupt the blob.
        let source = &flat_vectors[..needed_floats];
        if index_dim == original_dim {
            writer.write_all(f32s_as_bytes(source))?;
        } else {
            // Here index_dim < original_dim, so original_dim >= 1; chunks_exact
            // yields exactly num_vectors chunks since source.len() == num_vectors * original_dim.
            for vector in source.chunks_exact(original_dim) {
                writer.write_all(f32s_as_bytes(&vector[..index_dim]))?;
            }
        }

        for &(doc_id, ordinal) in doc_ids {
            writer.write_all(&flat_encode_doc_id(doc_id, ordinal))?;
        }

        Ok(())
    }
}
224
/// Persistable bundle for an IVF-RaBitQ vector index: the index structure
/// together with its coarse centroids and RaBitQ codebook.
///
/// Encoded to and from JSON (serde_json) by `to_bytes` / `from_bytes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IVFRaBitQIndexData {
    /// The IVF-RaBitQ index structure.
    pub index: crate::structures::IVFRaBitQIndex,
    /// Coarse centroids bundled with the index (presumably the IVF cluster
    /// centers — confirm in `crate::structures`).
    pub centroids: crate::structures::CoarseCentroids,
    /// RaBitQ codebook bundled with the index.
    pub codebook: crate::structures::RaBitQCodebook,
}
232
233impl IVFRaBitQIndexData {
234 pub fn to_bytes(&self) -> std::io::Result<Vec<u8>> {
235 serde_json::to_vec(self)
236 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
237 }
238
239 pub fn from_bytes(data: &[u8]) -> std::io::Result<Self> {
240 serde_json::from_slice(data)
241 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
242 }
243}
244
/// Persistable bundle for the ScaNN-style index: an IVF-PQ index structure
/// together with its coarse centroids and product-quantization codebook.
///
/// Encoded to and from JSON (serde_json) by `to_bytes` / `from_bytes`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ScaNNIndexData {
    /// The IVF-PQ index structure.
    pub index: crate::structures::IVFPQIndex,
    /// Coarse centroids bundled with the index (presumably the IVF cluster
    /// centers — confirm in `crate::structures`).
    pub centroids: crate::structures::CoarseCentroids,
    /// PQ codebook bundled with the index.
    pub codebook: crate::structures::PQCodebook,
}
252
253impl ScaNNIndexData {
254 pub fn to_bytes(&self) -> std::io::Result<Vec<u8>> {
255 serde_json::to_vec(self)
256 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
257 }
258
259 pub fn from_bytes(data: &[u8]) -> std::io::Result<Self> {
260 serde_json::from_slice(data)
261 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))
262 }
263}