1use blake3::hash;
2use serde::{Deserialize, Serialize};
3
4use crate::{MemvidError, Result, types::FrameId};
5
6fn vec_config() -> impl bincode::config::Config {
7 bincode::config::standard()
8 .with_fixed_int_encoding()
9 .with_little_endian()
10}
11
12const VEC_DECODE_LIMIT: usize = crate::MAX_INDEX_BYTES as usize;
13
/// One stored vector: the frame it is associated with and its raw embedding.
///
/// A `Vec<VecDocument>` is what gets bincode-encoded into an uncompressed index.
/// serde serializes fields in declaration order, so the field order below is part
/// of the encoded layout — do not reorder it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VecDocument {
    /// Identifier of the frame this embedding belongs to.
    pub frame_id: FrameId,
    /// Embedding components, stored as plain `f32`s.
    pub embedding: Vec<f32>,
}
19
20#[derive(Default)]
21pub struct VecIndexBuilder {
22 documents: Vec<VecDocument>,
23}
24
25impl VecIndexBuilder {
26 pub fn new() -> Self {
27 Self::default()
28 }
29
30 pub fn add_document<I>(&mut self, frame_id: FrameId, embedding: I)
31 where
32 I: Into<Vec<f32>>,
33 {
34 self.documents.push(VecDocument {
35 frame_id,
36 embedding: embedding.into(),
37 });
38 }
39
40 pub fn finish(self) -> Result<VecIndexArtifact> {
41 let bytes = bincode::serde::encode_to_vec(&self.documents, vec_config())?;
42
43 let checksum = *hash(&bytes).as_bytes();
44 let dimension = self
45 .documents
46 .first()
47 .map(|doc| doc.embedding.len() as u32)
48 .unwrap_or(0);
49 #[cfg(feature = "parallel_segments")]
50 let bytes_uncompressed = self
51 .documents
52 .iter()
53 .map(|doc| doc.embedding.len() * std::mem::size_of::<f32>())
54 .sum::<usize>() as u64;
55 Ok(VecIndexArtifact {
56 bytes,
57 vector_count: self.documents.len() as u64,
58 dimension,
59 checksum,
60 #[cfg(feature = "parallel_segments")]
61 bytes_uncompressed,
62 })
63 }
64}
65
/// An encoded vector index plus metadata, as populated by
/// [`VecIndexBuilder::finish`].
#[derive(Debug, Clone)]
pub struct VecIndexArtifact {
    /// bincode-encoded `Vec<VecDocument>`.
    pub bytes: Vec<u8>,
    /// Number of documents encoded in `bytes`.
    pub vector_count: u64,
    /// Embedding length of the first document (0 when the index is empty).
    pub dimension: u32,
    /// BLAKE3 hash of `bytes`.
    pub checksum: [u8; 32],
    /// Total size in bytes of all embeddings as raw `f32`s (frame ids excluded).
    #[cfg(feature = "parallel_segments")]
    pub bytes_uncompressed: u64,
}
75
76#[derive(Debug, Clone)]
77pub enum VecIndex {
78 Uncompressed { documents: Vec<VecDocument> },
79 Compressed(crate::vec_pq::QuantizedVecIndex),
80}
81
82impl VecIndex {
83 pub fn decode(bytes: &[u8]) -> Result<Self> {
86 Self::decode_with_compression(bytes, crate::VectorCompression::None)
87 }
88
89 pub fn decode_with_compression(
96 bytes: &[u8],
97 _compression: crate::VectorCompression,
98 ) -> Result<Self> {
99 match bincode::serde::decode_from_slice::<Vec<VecDocument>, _>(
103 bytes,
104 bincode::config::standard()
105 .with_fixed_int_encoding()
106 .with_little_endian()
107 .with_limit::<VEC_DECODE_LIMIT>(),
108 ) {
109 Ok((documents, read)) if read == bytes.len() => {
110 tracing::debug!(
111 bytes_len = bytes.len(),
112 docs_count = documents.len(),
113 "decoded as uncompressed"
114 );
115 return Ok(Self::Uncompressed { documents });
116 }
117 Ok((_, read)) => {
118 tracing::debug!(
119 bytes_len = bytes.len(),
120 read = read,
121 "uncompressed decode partial read, trying PQ"
122 );
123 }
124 Err(err) => {
125 tracing::debug!(
126 error = %err,
127 bytes_len = bytes.len(),
128 "uncompressed decode failed, trying PQ"
129 );
130 }
131 }
132
133 match crate::vec_pq::QuantizedVecIndex::decode(bytes) {
135 Ok(quantized_index) => {
136 tracing::debug!(bytes_len = bytes.len(), "decoded as PQ");
137 Ok(Self::Compressed(quantized_index))
138 }
139 Err(err) => {
140 tracing::debug!(
141 error = %err,
142 bytes_len = bytes.len(),
143 "PQ decode also failed"
144 );
145 Err(MemvidError::InvalidToc {
146 reason: "unsupported vector index encoding".into(),
147 })
148 }
149 }
150 }
151
152 pub fn search(&self, query: &[f32], limit: usize) -> Vec<VecSearchHit> {
153 if query.is_empty() {
154 return Vec::new();
155 }
156 match self {
157 VecIndex::Uncompressed { documents } => {
158 let mut hits: Vec<VecSearchHit> = documents
159 .iter()
160 .map(|doc| {
161 let distance = l2_distance(query, &doc.embedding);
162 VecSearchHit {
163 frame_id: doc.frame_id,
164 distance,
165 }
166 })
167 .collect();
168 hits.sort_by(|a, b| {
169 a.distance
170 .partial_cmp(&b.distance)
171 .unwrap_or(std::cmp::Ordering::Equal)
172 });
173 hits.truncate(limit);
174 hits
175 }
176 VecIndex::Compressed(quantized) => quantized.search(query, limit),
177 }
178 }
179
180 pub fn entries(&self) -> Box<dyn Iterator<Item = (FrameId, &[f32])> + '_> {
181 match self {
182 VecIndex::Uncompressed { documents } => Box::new(
183 documents
184 .iter()
185 .map(|doc| (doc.frame_id, doc.embedding.as_slice())),
186 ),
187 VecIndex::Compressed(_) => {
188 Box::new(std::iter::empty())
190 }
191 }
192 }
193
194 pub fn embedding_for(&self, frame_id: FrameId) -> Option<&[f32]> {
195 match self {
196 VecIndex::Uncompressed { documents } => documents
197 .iter()
198 .find(|doc| doc.frame_id == frame_id)
199 .map(|doc| doc.embedding.as_slice()),
200 VecIndex::Compressed(_) => {
201 None
203 }
204 }
205 }
206
207 pub fn remove(&mut self, frame_id: FrameId) {
208 match self {
209 VecIndex::Uncompressed { documents } => {
210 documents.retain(|doc| doc.frame_id != frame_id);
211 }
212 VecIndex::Compressed(_quantized) => {
213 }
216 }
217 }
218}
219
/// A single result returned by `VecIndex::search`.
#[derive(Debug, Clone, PartialEq)]
pub struct VecSearchHit {
    /// Frame of the matching document.
    pub frame_id: FrameId,
    /// Distance from the query; results are ordered with smaller distances first.
    /// For uncompressed indexes this is the Euclidean (L2) distance. Quantized
    /// indexes compute their own value (presumably an approximation — see `vec_pq`).
    pub distance: f32,
}
225
/// Euclidean (L2) distance between `a` and `b`.
///
/// The slices are paired element-wise with `zip`, so when their lengths differ only
/// the shared prefix is compared and the surplus elements are ignored.
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let sum_of_squares: f32 = a
        .iter()
        .zip(b)
        .map(|(lhs, rhs)| (lhs - rhs).powi(2))
        .sum();
    sum_of_squares.sqrt()
}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn builder_roundtrip() {
        let mut builder = VecIndexBuilder::new();
        builder.add_document(1, vec![0.0, 1.0, 2.0]);
        builder.add_document(2, vec![1.0, 2.0, 3.0]);
        let artifact = builder.finish().expect("finish");
        assert_eq!(artifact.vector_count, 2);
        assert_eq!(artifact.dimension, 3);

        let index = VecIndex::decode(&artifact.bytes).expect("decode");
        let hits = index.search(&[0.0, 1.0, 2.0], 10);
        assert_eq!(hits[0].frame_id, 1);
    }

    #[test]
    fn empty_builder_produces_decodable_empty_index() {
        let artifact = VecIndexBuilder::new().finish().expect("finish");
        assert_eq!(artifact.vector_count, 0);
        assert_eq!(artifact.dimension, 0);

        let index = VecIndex::decode(&artifact.bytes).expect("decode");
        assert!(index.search(&[1.0], 5).is_empty());
        assert_eq!(index.entries().count(), 0);
    }

    #[test]
    fn search_sorts_by_distance_and_truncates() {
        let mut builder = VecIndexBuilder::new();
        builder.add_document(1, vec![10.0, 10.0]);
        builder.add_document(2, vec![0.0, 0.0]);
        builder.add_document(3, vec![1.0, 1.0]);
        let index =
            VecIndex::decode(&builder.finish().expect("finish").bytes).expect("decode");

        let ids: Vec<_> = index
            .search(&[0.0, 0.0], 2)
            .iter()
            .map(|hit| hit.frame_id)
            .collect();
        assert_eq!(ids, vec![2, 3]);

        // An empty query never matches anything.
        assert!(index.search(&[], 10).is_empty());
    }

    #[test]
    fn embedding_lookup_and_remove() {
        let mut builder = VecIndexBuilder::new();
        builder.add_document(1, vec![0.0, 1.0]);
        builder.add_document(2, vec![2.0, 3.0]);
        let mut index =
            VecIndex::decode(&builder.finish().expect("finish").bytes).expect("decode");

        assert_eq!(index.embedding_for(1), Some(&[0.0_f32, 1.0][..]));

        index.remove(1);
        assert!(index.embedding_for(1).is_none());
        let remaining: Vec<_> = index.entries().map(|(id, _)| id).collect();
        assert_eq!(remaining, vec![2]);
    }

    #[test]
    fn checksum_is_deterministic() {
        let build = |value: f32| {
            let mut builder = VecIndexBuilder::new();
            builder.add_document(7, vec![value, value]);
            builder.finish().expect("finish")
        };
        assert_eq!(build(1.0).checksum, build(1.0).checksum);
        assert_ne!(build(1.0).checksum, build(2.0).checksum);
    }

    #[test]
    fn l2_distance_behaves() {
        let d = l2_distance(&[0.0, 0.0], &[3.0, 4.0]);
        assert!((d - 5.0).abs() < 1e-6);
    }
}