1use blake3::hash;
13use serde::{Deserialize, Serialize};
14
15use crate::vec::VecSearchHit;
16use crate::{MemvidError, Result, types::FrameId};
17
18fn vec_config() -> impl bincode::config::Config {
19 bincode::config::standard()
20 .with_fixed_int_encoding()
21 .with_little_endian()
22}
23
24const VEC_DECODE_LIMIT: usize = crate::MAX_INDEX_BYTES as usize;
25
26const NUM_SUBSPACES: usize = 96; const SUBSPACE_DIM: usize = 4; const NUM_CENTROIDS: usize = 256; const TOTAL_DIM: usize = NUM_SUBSPACES * SUBSPACE_DIM; #[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct SubspaceCodebook {
35 centroids: Vec<f32>,
37}
38
39impl SubspaceCodebook {
40 fn new() -> Self {
41 Self {
42 centroids: vec![0.0; NUM_CENTROIDS * SUBSPACE_DIM],
43 }
44 }
45
46 fn get_centroid(&self, index: u8) -> &[f32] {
47 let start = (index as usize) * SUBSPACE_DIM;
48 &self.centroids[start..start + SUBSPACE_DIM]
49 }
50
51 fn set_centroid(&mut self, index: u8, values: &[f32]) {
52 assert_eq!(values.len(), SUBSPACE_DIM);
53 let start = (index as usize) * SUBSPACE_DIM;
54 self.centroids[start..start + SUBSPACE_DIM].copy_from_slice(values);
55 }
56
57 fn quantize(&self, subspace: &[f32]) -> u8 {
59 assert_eq!(subspace.len(), SUBSPACE_DIM);
60
61 let mut best_idx = 0u8;
62 let mut best_dist = f32::INFINITY;
63
64 for i in 0..NUM_CENTROIDS {
65 let centroid = self.get_centroid(i as u8);
66 let dist = l2_distance_squared(subspace, centroid);
67 if dist < best_dist {
68 best_dist = dist;
69 best_idx = i as u8;
70 }
71 }
72
73 best_idx
74 }
75}
76
/// Product quantizer: splits a vector into `NUM_SUBSPACES` contiguous subspaces of
/// `SUBSPACE_DIM` floats and encodes each one as a `u8` index into that subspace's
/// codebook.
///
/// Field order is part of the serialized format read by `QuantizedVecIndex::decode`;
/// an older on-disk format lacked the `dimension` field.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    /// One codebook per subspace (`NUM_SUBSPACES` of them when built by `new`).
    codebooks: Vec<SubspaceCodebook>,
    /// Vector dimension; `ProductQuantizer::new` only accepts `TOTAL_DIM`.
    dimension: u32,
}
84
85impl ProductQuantizer {
86 pub fn new(dimension: u32) -> Result<Self> {
88 if dimension as usize != TOTAL_DIM {
89 return Err(MemvidError::InvalidQuery {
90 reason: format!(
91 "PQ only supports {}-dim vectors, got {}",
92 TOTAL_DIM, dimension
93 ),
94 });
95 }
96
97 Ok(Self {
98 codebooks: vec![SubspaceCodebook::new(); NUM_SUBSPACES],
99 dimension,
100 })
101 }
102
103 pub fn train(&mut self, training_vectors: &[Vec<f32>], max_iterations: usize) -> Result<()> {
105 if training_vectors.is_empty() {
106 return Err(MemvidError::InvalidQuery {
107 reason: "Cannot train PQ with empty training set".to_string(),
108 });
109 }
110
111 for vec in training_vectors {
113 if vec.len() != TOTAL_DIM {
114 return Err(MemvidError::InvalidQuery {
115 reason: format!(
116 "Training vector has wrong dimension: expected {}, got {}",
117 TOTAL_DIM,
118 vec.len()
119 ),
120 });
121 }
122 }
123
124 for subspace_idx in 0..NUM_SUBSPACES {
126 let start_dim = subspace_idx * SUBSPACE_DIM;
127 let end_dim = start_dim + SUBSPACE_DIM;
128
129 let subspace_vecs: Vec<Vec<f32>> = training_vectors
131 .iter()
132 .map(|v| v[start_dim..end_dim].to_vec())
133 .collect();
134
135 let centroids = kmeans(&subspace_vecs, NUM_CENTROIDS, max_iterations)?;
137
138 for (i, centroid) in centroids.iter().enumerate() {
140 self.codebooks[subspace_idx].set_centroid(i as u8, centroid);
141 }
142 }
143
144 Ok(())
145 }
146
147 pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
149 if vector.len() != TOTAL_DIM {
150 return Err(MemvidError::InvalidQuery {
151 reason: format!(
152 "Vector dimension mismatch: expected {}, got {}",
153 TOTAL_DIM,
154 vector.len()
155 ),
156 });
157 }
158
159 let mut codes = Vec::with_capacity(NUM_SUBSPACES);
160
161 for subspace_idx in 0..NUM_SUBSPACES {
162 let start_dim = subspace_idx * SUBSPACE_DIM;
163 let end_dim = start_dim + SUBSPACE_DIM;
164 let subspace = &vector[start_dim..end_dim];
165
166 let code = self.codebooks[subspace_idx].quantize(subspace);
167 codes.push(code);
168 }
169
170 Ok(codes)
171 }
172
173 pub fn decode(&self, codes: &[u8]) -> Result<Vec<f32>> {
175 if codes.len() != NUM_SUBSPACES {
176 return Err(MemvidError::InvalidQuery {
177 reason: format!(
178 "Invalid PQ codes length: expected {}, got {}",
179 NUM_SUBSPACES,
180 codes.len()
181 ),
182 });
183 }
184
185 let mut vector = Vec::with_capacity(TOTAL_DIM);
186
187 for (subspace_idx, &code) in codes.iter().enumerate() {
188 let centroid = self.codebooks[subspace_idx].get_centroid(code);
189 vector.extend_from_slice(centroid);
190 }
191
192 Ok(vector)
193 }
194
195 pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
198 if query.len() != TOTAL_DIM || codes.len() != NUM_SUBSPACES {
199 return f32::INFINITY;
200 }
201
202 let mut total_dist_sq = 0.0f32;
203
204 for subspace_idx in 0..NUM_SUBSPACES {
205 let start_dim = subspace_idx * SUBSPACE_DIM;
206 let end_dim = start_dim + SUBSPACE_DIM;
207 let query_subspace = &query[start_dim..end_dim];
208
209 let code = codes[subspace_idx];
210 let centroid = self.codebooks[subspace_idx].get_centroid(code);
211
212 total_dist_sq += l2_distance_squared(query_subspace, centroid);
213 }
214
215 total_dist_sq.sqrt()
216 }
217}
218
/// One indexed vector: the frame it belongs to and its product-quantization codes.
///
/// Field order is part of the serialized index format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedVecDocument {
    /// Frame the embedding was added for.
    pub frame_id: FrameId,
    /// One code per subspace, as produced by `ProductQuantizer::encode`.
    pub codes: Vec<u8>,
}
226
/// Builds a quantized vector index incrementally: train the quantizer, add
/// documents, then call `finish`.
#[derive(Default)]
pub struct QuantizedVecIndexBuilder {
    /// Documents encoded so far with `quantizer`.
    documents: Vec<QuantizedVecDocument>,
    /// `None` until `train_quantizer` succeeds.
    quantizer: Option<ProductQuantizer>,
}
233
234impl QuantizedVecIndexBuilder {
235 pub fn new() -> Self {
236 Self::default()
237 }
238
239 pub fn train_quantizer(&mut self, training_vectors: &[Vec<f32>], dimension: u32) -> Result<()> {
241 let mut pq = ProductQuantizer::new(dimension)?;
242 pq.train(training_vectors, 25)?; self.quantizer = Some(pq);
244 Ok(())
245 }
246
247 pub fn add_document(&mut self, frame_id: FrameId, embedding: Vec<f32>) -> Result<()> {
249 let quantizer = self
250 .quantizer
251 .as_ref()
252 .ok_or_else(|| MemvidError::InvalidQuery {
253 reason: "Quantizer not trained. Call train_quantizer first".to_string(),
254 })?;
255
256 let codes = quantizer.encode(&embedding)?;
257
258 self.documents
259 .push(QuantizedVecDocument { frame_id, codes });
260
261 Ok(())
262 }
263
264 pub fn finish(self) -> Result<QuantizedVecIndexArtifact> {
265 let quantizer = self.quantizer.ok_or_else(|| MemvidError::InvalidQuery {
266 reason: "Quantizer not trained".to_string(),
267 })?;
268
269 let vector_count = self.documents.len() as u64;
270 let bytes =
271 bincode::serde::encode_to_vec(&(quantizer.clone(), self.documents), vec_config())?;
272 let checksum = *hash(&bytes).as_bytes();
273
274 Ok(QuantizedVecIndexArtifact {
275 bytes,
276 vector_count,
277 dimension: quantizer.dimension,
278 checksum,
279 compression_ratio: 16.0, })
281 }
282}
283
/// Serialized output of `QuantizedVecIndexBuilder::finish`.
#[derive(Debug, Clone)]
pub struct QuantizedVecIndexArtifact {
    /// bincode encoding (fixed-width, little-endian integers) of
    /// `(ProductQuantizer, Vec<QuantizedVecDocument>)`; readable by
    /// `QuantizedVecIndex::decode`.
    pub bytes: Vec<u8>,
    /// Number of documents encoded in `bytes`.
    pub vector_count: u64,
    /// Vector dimension of the quantizer.
    pub dimension: u32,
    /// BLAKE3 hash of `bytes`.
    pub checksum: [u8; 32],
    /// Nominal per-vector compression ratio reported by the builder (16.0: a
    /// `TOTAL_DIM`-float vector becomes `NUM_SUBSPACES` one-byte codes). Codebook
    /// storage is not counted here, unlike `CompressionStats::compression_ratio`.
    pub compression_ratio: f64,
}
292
/// In-memory quantized vector index. Search scores every document against the
/// query (brute force).
#[derive(Debug, Clone)]
pub struct QuantizedVecIndex {
    /// Codebooks used to score documents.
    quantizer: ProductQuantizer,
    /// All indexed documents, in decode order.
    documents: Vec<QuantizedVecDocument>,
}
298
299impl QuantizedVecIndex {
300 pub fn decode(bytes: &[u8]) -> Result<Self> {
301 let config = bincode::config::standard()
303 .with_fixed_int_encoding()
304 .with_little_endian()
305 .with_limit::<VEC_DECODE_LIMIT>();
306
307 if let Ok(((quantizer, documents), read)) = bincode::serde::decode_from_slice::<
308 (ProductQuantizer, Vec<QuantizedVecDocument>),
309 _,
310 >(bytes, config.clone())
311 {
312 if read == bytes.len() {
313 return Ok(Self {
314 quantizer,
315 documents,
316 });
317 }
318 }
319
320 #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
323 struct OldProductQuantizer {
324 codebooks: Vec<SubspaceCodebook>,
325 }
326
327 let ((old_quantizer, documents), read): (
328 (OldProductQuantizer, Vec<QuantizedVecDocument>),
329 usize,
330 ) = bincode::serde::decode_from_slice(bytes, config)?;
331
332 if read != bytes.len() {
333 return Err(MemvidError::InvalidToc {
334 reason: "unsupported quantized vector index encoding".into(),
335 });
336 }
337
338 let quantizer = ProductQuantizer {
340 codebooks: old_quantizer.codebooks,
341 dimension: (NUM_SUBSPACES * SUBSPACE_DIM) as u32,
342 };
343
344 Ok(Self {
345 quantizer,
346 documents,
347 })
348 }
349
350 pub fn search(&self, query: &[f32], limit: usize) -> Vec<VecSearchHit> {
352 if query.is_empty() {
353 return Vec::new();
354 }
355
356 let mut hits: Vec<VecSearchHit> = self
357 .documents
358 .iter()
359 .map(|doc| {
360 let distance = self.quantizer.asymmetric_distance(query, &doc.codes);
361 VecSearchHit {
362 frame_id: doc.frame_id,
363 distance,
364 }
365 })
366 .collect();
367
368 hits.sort_by(|a, b| {
369 a.distance
370 .partial_cmp(&b.distance)
371 .unwrap_or(std::cmp::Ordering::Equal)
372 });
373
374 hits.truncate(limit);
375 hits
376 }
377
378 pub fn remove(&mut self, frame_id: FrameId) {
379 self.documents.retain(|doc| doc.frame_id != frame_id);
380 }
381
382 pub fn compression_stats(&self) -> CompressionStats {
384 let original_bytes = self.documents.len() * TOTAL_DIM * std::mem::size_of::<f32>();
385 let compressed_bytes = self.documents.len() * NUM_SUBSPACES; let codebook_bytes =
387 NUM_SUBSPACES * NUM_CENTROIDS * SUBSPACE_DIM * std::mem::size_of::<f32>();
388
389 CompressionStats {
390 vector_count: self.documents.len() as u64,
391 original_bytes: original_bytes as u64,
392 compressed_bytes: compressed_bytes as u64,
393 codebook_bytes: codebook_bytes as u64,
394 total_bytes: (compressed_bytes + codebook_bytes) as u64,
395 compression_ratio: original_bytes as f64 / (compressed_bytes + codebook_bytes) as f64,
396 }
397 }
398}
399
/// Storage accounting returned by `QuantizedVecIndex::compression_stats`.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Number of documents in the index.
    pub vector_count: u64,
    /// Bytes the vectors would occupy as raw `f32` (`vector_count * TOTAL_DIM * 4`).
    pub original_bytes: u64,
    /// Bytes of PQ codes (`vector_count * NUM_SUBSPACES`, one byte per subspace).
    pub compressed_bytes: u64,
    /// Bytes of all codebooks (`NUM_SUBSPACES * NUM_CENTROIDS * SUBSPACE_DIM * 4`).
    pub codebook_bytes: u64,
    /// `compressed_bytes + codebook_bytes`.
    pub total_bytes: u64,
    /// `original_bytes / total_bytes`. Codebook overhead is included, so small
    /// indexes can report a ratio below 1.
    pub compression_ratio: f64,
}
409
410fn kmeans(vectors: &[Vec<f32>], k: usize, max_iterations: usize) -> Result<Vec<Vec<f32>>> {
412 if vectors.is_empty() {
413 return Err(MemvidError::InvalidQuery {
414 reason: "Cannot run k-means on empty vector set".to_string(),
415 });
416 }
417
418 let dim = vectors[0].len();
419
420 let mut centroids = kmeans_plus_plus_init(vectors, k)?;
422
423 for _iteration in 0..max_iterations {
424 let mut assignments = vec![Vec::new(); k];
426
427 for vec in vectors {
428 let mut best_cluster = 0;
429 let mut best_dist = f32::INFINITY;
430
431 for (cluster_idx, centroid) in centroids.iter().enumerate() {
432 let dist = l2_distance_squared(vec, centroid);
433 if dist < best_dist {
434 best_dist = dist;
435 best_cluster = cluster_idx;
436 }
437 }
438
439 assignments[best_cluster].push(vec.clone());
440 }
441
442 let mut changed = false;
444 for (cluster_idx, cluster_vecs) in assignments.iter().enumerate() {
445 if cluster_vecs.is_empty() {
446 centroids[cluster_idx] = vectors[cluster_idx % vectors.len()].clone();
448 changed = true;
449 continue;
450 }
451
452 let mut new_centroid = vec![0.0f32; dim];
453 for vec in cluster_vecs {
454 for (i, &val) in vec.iter().enumerate() {
455 new_centroid[i] += val;
456 }
457 }
458 for val in &mut new_centroid {
459 *val /= cluster_vecs.len() as f32;
460 }
461
462 if l2_distance_squared(¢roids[cluster_idx], &new_centroid) > 1e-6 {
464 changed = true;
465 }
466
467 centroids[cluster_idx] = new_centroid;
468 }
469
470 if !changed {
471 break; }
473 }
474
475 Ok(centroids)
476}
477
478fn kmeans_plus_plus_init(vectors: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>> {
480 if vectors.is_empty() || k == 0 {
481 return Err(MemvidError::InvalidQuery {
482 reason: "Invalid k-means++ initialization".to_string(),
483 });
484 }
485
486 let mut centroids = Vec::new();
487
488 centroids.push(vectors[0].clone());
490
491 for _ in 1..k {
493 let mut distances = Vec::new();
494
495 for vec in vectors {
497 let mut min_dist = f32::INFINITY;
498 for centroid in ¢roids {
499 let dist = l2_distance_squared(vec, centroid);
500 min_dist = min_dist.min(dist);
501 }
502 distances.push(min_dist);
503 }
504
505 let max_idx = distances
507 .iter()
508 .enumerate()
509 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
510 .map(|(idx, _)| idx)
511 .unwrap_or(0);
512
513 centroids.push(vectors[max_idx].clone());
514 }
515
516 Ok(centroids)
517}
518
/// Squared Euclidean distance between `a` and `b`.
///
/// Only the overlapping prefix is compared. If the slices differ in length, the
/// extra elements of the longer one are ignored.
fn l2_distance_squared(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let diff = lhs - rhs;
            diff.powi(2)
        })
        .sum()
}
523
#[cfg(test)]
mod tests {
    use super::*;

    /// Generates `count` vectors of length `TOTAL_DIM`, where element `j` of vector
    /// `i` is `value(i, j)`.
    fn make_vectors(count: usize, value: impl Fn(usize, usize) -> f32) -> Vec<Vec<f32>> {
        (0..count)
            .map(|i| (0..TOTAL_DIM).map(|j| value(i, j)).collect())
            .collect()
    }

    #[test]
    fn test_subspace_codebook() {
        let expected = [1.0f32, 2.0, 3.0, 4.0];
        let mut codebook = SubspaceCodebook::new();
        codebook.set_centroid(0, &expected);

        assert_eq!(codebook.get_centroid(0), &expected);

        // Every other centroid is still zero, so a point close to `expected` maps to 0.
        assert_eq!(codebook.quantize(&[1.1, 2.1, 3.1, 4.1]), 0);
    }

    #[test]
    fn test_product_quantizer_roundtrip() {
        let training_vecs =
            make_vectors(100, |i, j| ((i * TOTAL_DIM + j) % 100) as f32 / 100.0);

        let mut pq = ProductQuantizer::new(TOTAL_DIM as u32).unwrap();
        pq.train(&training_vecs, 10).unwrap();

        let original = &training_vecs[0];
        let codes = pq.encode(original).unwrap();
        assert_eq!(codes.len(), NUM_SUBSPACES);

        let reconstructed = pq.decode(&codes).unwrap();
        assert_eq!(reconstructed.len(), TOTAL_DIM);

        let dist = l2_distance_squared(original, &reconstructed).sqrt();
        assert!(dist < 10.0, "Reconstruction error too large: {}", dist);
    }

    #[test]
    fn test_quantized_index_builder() {
        let training_vecs = make_vectors(50, |i, j| ((i + j) % 10) as f32);

        let mut builder = QuantizedVecIndexBuilder::new();
        builder
            .train_quantizer(&training_vecs, TOTAL_DIM as u32)
            .unwrap();

        // Frame ids are 1-based: training vector 0 becomes frame 1.
        for (frame_id, vec) in (1..).zip(training_vecs.iter().take(10)) {
            builder.add_document(frame_id, vec.clone()).unwrap();
        }

        let artifact = builder.finish().unwrap();
        assert_eq!(artifact.vector_count, 10);
        assert_eq!(artifact.dimension, TOTAL_DIM as u32);
        assert!(artifact.compression_ratio > 10.0);

        let index = QuantizedVecIndex::decode(&artifact.bytes).unwrap();
        let hits = index.search(&training_vecs[0], 5);

        assert!(!hits.is_empty());
        // The query is exactly frame 1's embedding.
        assert_eq!(hits[0].frame_id, 1);
    }

    #[test]
    fn test_kmeans_simple() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![10.0, 10.0],
            vec![10.1, 10.1],
        ];

        let centroids = kmeans(&vectors, 2, 100).unwrap();
        assert_eq!(centroids.len(), 2);

        let has_low_cluster = centroids.iter().any(|c| c[0] < 5.0 && c[1] < 5.0);
        let has_high_cluster = centroids.iter().any(|c| c[0] > 5.0 && c[1] > 5.0);
        assert!(has_low_cluster && has_high_cluster);
    }
}