1use blake3::hash;
13use serde::{Deserialize, Serialize};
14
15use crate::vec::VecSearchHit;
16use crate::{MemvidError, Result, types::FrameId};
17
18fn vec_config() -> impl bincode::config::Config {
19 bincode::config::standard()
20 .with_fixed_int_encoding()
21 .with_little_endian()
22}
23
/// Byte limit handed to bincode's `with_limit` when decoding a quantized index
/// in `QuantizedVecIndex::decode`.
// NOTE(review): the allow implies `crate::MAX_INDEX_BYTES` is wider than `usize`
// (probably u64), so this cast could truncate on 32-bit targets — confirm.
#[allow(clippy::cast_possible_truncation)]
const VEC_DECODE_LIMIT: usize = crate::MAX_INDEX_BYTES as usize;
26
/// Number of subspaces each vector is split into; `encode` emits one code byte per subspace.
const NUM_SUBSPACES: usize = 96;
/// Number of consecutive vector dimensions covered by each subspace.
const SUBSPACE_DIM: usize = 4;
/// Centroids per subspace codebook; 256 so every centroid index fits in a `u8` code.
const NUM_CENTROIDS: usize = 256;
/// The only full vector dimension the quantizer accepts (96 * 4 = 384).
const TOTAL_DIM: usize = NUM_SUBSPACES * SUBSPACE_DIM;

/// Centroid table for a single subspace.
///
/// Centroids are stored flat: centroid `i` occupies
/// `centroids[i * SUBSPACE_DIM..(i + 1) * SUBSPACE_DIM]`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubspaceCodebook {
    /// `NUM_CENTROIDS * SUBSPACE_DIM` floats when created via `new`; the length of
    /// deserialized data is not checked by this type.
    centroids: Vec<f32>,
}
39
40impl SubspaceCodebook {
41 fn new() -> Self {
42 Self {
43 centroids: vec![0.0; NUM_CENTROIDS * SUBSPACE_DIM],
44 }
45 }
46
47 fn get_centroid(&self, index: u8) -> &[f32] {
48 let start = (index as usize) * SUBSPACE_DIM;
49 &self.centroids[start..start + SUBSPACE_DIM]
50 }
51
52 fn set_centroid(&mut self, index: u8, values: &[f32]) {
53 assert_eq!(values.len(), SUBSPACE_DIM);
54 let start = (index as usize) * SUBSPACE_DIM;
55 self.centroids[start..start + SUBSPACE_DIM].copy_from_slice(values);
56 }
57
58 fn quantize(&self, subspace: &[f32]) -> u8 {
60 assert_eq!(subspace.len(), SUBSPACE_DIM);
61
62 let mut best_idx = 0u8;
63 let mut best_dist = f32::INFINITY;
64
65 for i in 0..NUM_CENTROIDS {
66 #[allow(clippy::cast_possible_truncation)]
67 let centroid = self.get_centroid(i as u8);
68 let dist = l2_distance_squared(subspace, centroid);
69 if dist < best_dist {
70 best_dist = dist;
71 #[allow(clippy::cast_possible_truncation)]
72 {
73 best_idx = i as u8;
74 }
75 }
76 }
77
78 best_idx
79 }
80}
81
/// Product quantizer: one codebook per subspace, turning a `TOTAL_DIM`-float vector
/// into `NUM_SUBSPACES` one-byte codes.
///
/// Field order is part of the bincode on-disk layout (`codebooks`, then `dimension`).
/// `QuantizedVecIndex::decode` also accepts an older layout that lacks `dimension`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProductQuantizer {
    /// One codebook per subspace, indexed by subspace number.
    codebooks: Vec<SubspaceCodebook>,
    /// Vector dimension; `new` only accepts `TOTAL_DIM`.
    dimension: u32,
}
89
90impl ProductQuantizer {
91 pub fn new(dimension: u32) -> Result<Self> {
93 if dimension as usize != TOTAL_DIM {
94 return Err(MemvidError::InvalidQuery {
95 reason: format!("PQ only supports {TOTAL_DIM}-dim vectors, got {dimension}"),
96 });
97 }
98
99 Ok(Self {
100 codebooks: vec![SubspaceCodebook::new(); NUM_SUBSPACES],
101 dimension,
102 })
103 }
104
105 pub fn train(&mut self, training_vectors: &[Vec<f32>], max_iterations: usize) -> Result<()> {
107 if training_vectors.is_empty() {
108 return Err(MemvidError::InvalidQuery {
109 reason: "Cannot train PQ with empty training set".to_string(),
110 });
111 }
112
113 for vec in training_vectors {
115 if vec.len() != TOTAL_DIM {
116 return Err(MemvidError::InvalidQuery {
117 reason: format!(
118 "Training vector has wrong dimension: expected {}, got {}",
119 TOTAL_DIM,
120 vec.len()
121 ),
122 });
123 }
124 }
125
126 for subspace_idx in 0..NUM_SUBSPACES {
128 let start_dim = subspace_idx * SUBSPACE_DIM;
129 let end_dim = start_dim + SUBSPACE_DIM;
130
131 let subspace_vecs: Vec<Vec<f32>> = training_vectors
133 .iter()
134 .map(|v| v[start_dim..end_dim].to_vec())
135 .collect();
136
137 let centroids = kmeans(&subspace_vecs, NUM_CENTROIDS, max_iterations)?;
139
140 for (i, centroid) in centroids.iter().enumerate() {
142 #[allow(clippy::cast_possible_truncation)]
143 self.codebooks[subspace_idx].set_centroid(i as u8, centroid);
144 }
145 }
146
147 Ok(())
148 }
149
150 pub fn encode(&self, vector: &[f32]) -> Result<Vec<u8>> {
152 if vector.len() != TOTAL_DIM {
153 return Err(MemvidError::InvalidQuery {
154 reason: format!(
155 "Vector dimension mismatch: expected {}, got {}",
156 TOTAL_DIM,
157 vector.len()
158 ),
159 });
160 }
161
162 let mut codes = Vec::with_capacity(NUM_SUBSPACES);
163
164 for subspace_idx in 0..NUM_SUBSPACES {
165 let start_dim = subspace_idx * SUBSPACE_DIM;
166 let end_dim = start_dim + SUBSPACE_DIM;
167 let subspace = &vector[start_dim..end_dim];
168
169 let code = self.codebooks[subspace_idx].quantize(subspace);
170 codes.push(code);
171 }
172
173 Ok(codes)
174 }
175
176 pub fn decode(&self, codes: &[u8]) -> Result<Vec<f32>> {
178 if codes.len() != NUM_SUBSPACES {
179 return Err(MemvidError::InvalidQuery {
180 reason: format!(
181 "Invalid PQ codes length: expected {}, got {}",
182 NUM_SUBSPACES,
183 codes.len()
184 ),
185 });
186 }
187
188 let mut vector = Vec::with_capacity(TOTAL_DIM);
189
190 for (subspace_idx, &code) in codes.iter().enumerate() {
191 let centroid = self.codebooks[subspace_idx].get_centroid(code);
192 vector.extend_from_slice(centroid);
193 }
194
195 Ok(vector)
196 }
197
198 #[must_use]
201 pub fn asymmetric_distance(&self, query: &[f32], codes: &[u8]) -> f32 {
202 if query.len() != TOTAL_DIM || codes.len() != NUM_SUBSPACES {
203 return f32::INFINITY;
204 }
205
206 let mut total_dist_sq = 0.0f32;
207
208 for subspace_idx in 0..NUM_SUBSPACES {
209 let start_dim = subspace_idx * SUBSPACE_DIM;
210 let end_dim = start_dim + SUBSPACE_DIM;
211 let query_subspace = &query[start_dim..end_dim];
212
213 let code = codes[subspace_idx];
214 let centroid = self.codebooks[subspace_idx].get_centroid(code);
215
216 total_dist_sq += l2_distance_squared(query_subspace, centroid);
217 }
218
219 total_dist_sq.sqrt()
220 }
221}
222
/// A stored vector in compressed form: the frame it belongs to plus the PQ codes
/// produced by `ProductQuantizer::encode`.
///
/// Field order is part of the bincode on-disk layout.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QuantizedVecDocument {
    /// Frame this embedding belongs to.
    pub frame_id: FrameId,
    /// PQ codes; `NUM_SUBSPACES` bytes for documents added through the builder.
    pub codes: Vec<u8>,
}
230
231#[derive(Default)]
233pub struct QuantizedVecIndexBuilder {
234 documents: Vec<QuantizedVecDocument>,
235 quantizer: Option<ProductQuantizer>,
236}
237
238impl QuantizedVecIndexBuilder {
239 #[must_use]
240 pub fn new() -> Self {
241 Self::default()
242 }
243
244 pub fn train_quantizer(&mut self, training_vectors: &[Vec<f32>], dimension: u32) -> Result<()> {
246 let mut pq = ProductQuantizer::new(dimension)?;
247 pq.train(training_vectors, 25)?; self.quantizer = Some(pq);
249 Ok(())
250 }
251
252 pub fn add_document(&mut self, frame_id: FrameId, embedding: Vec<f32>) -> Result<()> {
254 let quantizer = self
255 .quantizer
256 .as_ref()
257 .ok_or_else(|| MemvidError::InvalidQuery {
258 reason: "Quantizer not trained. Call train_quantizer first".to_string(),
259 })?;
260
261 let codes = quantizer.encode(&embedding)?;
262
263 self.documents
264 .push(QuantizedVecDocument { frame_id, codes });
265
266 Ok(())
267 }
268
269 pub fn finish(self) -> Result<QuantizedVecIndexArtifact> {
270 let quantizer = self.quantizer.ok_or_else(|| MemvidError::InvalidQuery {
271 reason: "Quantizer not trained".to_string(),
272 })?;
273
274 let vector_count = self.documents.len() as u64;
275 let bytes =
276 bincode::serde::encode_to_vec(&(quantizer.clone(), self.documents), vec_config())?;
277 let checksum = *hash(&bytes).as_bytes();
278
279 Ok(QuantizedVecIndexArtifact {
280 bytes,
281 vector_count,
282 dimension: quantizer.dimension,
283 checksum,
284 compression_ratio: 16.0, })
286 }
287}
288
/// Serialized quantized index produced by `QuantizedVecIndexBuilder::finish`.
#[derive(Debug, Clone)]
pub struct QuantizedVecIndexArtifact {
    /// Bincode encoding (fixed-width ints, little-endian) of the quantizer followed by
    /// the document list; readable by `QuantizedVecIndex::decode`.
    pub bytes: Vec<u8>,
    /// Number of documents encoded in `bytes`.
    pub vector_count: u64,
    /// Vector dimension of the quantizer.
    pub dimension: u32,
    /// BLAKE3 hash of `bytes`.
    pub checksum: [u8; 32],
    /// Size ratio of raw `f32` vectors to their PQ codes, as set by `finish`
    /// (384 * 4 bytes -> 96 bytes); codebook storage is not counted.
    pub compression_ratio: f64,
}
297
/// Searchable quantized vector index, obtained from artifact bytes via
/// `QuantizedVecIndex::decode`.
#[derive(Debug, Clone)]
pub struct QuantizedVecIndex {
    /// Codebooks used to score queries against stored codes.
    quantizer: ProductQuantizer,
    /// Stored documents; `search` scores every one of them (linear scan).
    documents: Vec<QuantizedVecDocument>,
}
303
304impl QuantizedVecIndex {
305 pub fn decode(bytes: &[u8]) -> Result<Self> {
306 let config = bincode::config::standard()
308 .with_fixed_int_encoding()
309 .with_little_endian()
310 .with_limit::<VEC_DECODE_LIMIT>();
311
312 if let Ok(((quantizer, documents), read)) = bincode::serde::decode_from_slice::<
313 (ProductQuantizer, Vec<QuantizedVecDocument>),
314 _,
315 >(bytes, config)
316 {
317 if read == bytes.len() {
318 return Ok(Self {
319 quantizer,
320 documents,
321 });
322 }
323 }
324
325 #[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
328 struct OldProductQuantizer {
329 codebooks: Vec<SubspaceCodebook>,
330 }
331
332 let ((old_quantizer, documents), read): (
333 (OldProductQuantizer, Vec<QuantizedVecDocument>),
334 usize,
335 ) = bincode::serde::decode_from_slice(bytes, config)?;
336
337 if read != bytes.len() {
338 return Err(MemvidError::InvalidToc {
339 reason: "unsupported quantized vector index encoding".into(),
340 });
341 }
342
343 let quantizer = ProductQuantizer {
345 codebooks: old_quantizer.codebooks,
346 dimension: u32::try_from(NUM_SUBSPACES * SUBSPACE_DIM).unwrap_or(u32::MAX),
347 };
348
349 Ok(Self {
350 quantizer,
351 documents,
352 })
353 }
354
355 #[must_use]
357 pub fn search(&self, query: &[f32], limit: usize) -> Vec<VecSearchHit> {
358 if query.is_empty() {
359 return Vec::new();
360 }
361
362 let mut hits: Vec<VecSearchHit> = self
363 .documents
364 .iter()
365 .map(|doc| {
366 let distance = self.quantizer.asymmetric_distance(query, &doc.codes);
367 VecSearchHit {
368 frame_id: doc.frame_id,
369 distance,
370 }
371 })
372 .collect();
373
374 hits.sort_by(|a, b| {
375 a.distance
376 .partial_cmp(&b.distance)
377 .unwrap_or(std::cmp::Ordering::Equal)
378 });
379
380 hits.truncate(limit);
381 hits
382 }
383
384 pub fn remove(&mut self, frame_id: FrameId) {
385 self.documents.retain(|doc| doc.frame_id != frame_id);
386 }
387
388 #[must_use]
390 pub fn compression_stats(&self) -> CompressionStats {
391 let original_bytes = self.documents.len() * TOTAL_DIM * std::mem::size_of::<f32>();
392 let compressed_bytes = self.documents.len() * NUM_SUBSPACES; let codebook_bytes =
394 NUM_SUBSPACES * NUM_CENTROIDS * SUBSPACE_DIM * std::mem::size_of::<f32>();
395
396 CompressionStats {
397 vector_count: self.documents.len() as u64,
398 original_bytes: original_bytes as u64,
399 compressed_bytes: compressed_bytes as u64,
400 codebook_bytes: codebook_bytes as u64,
401 total_bytes: (compressed_bytes + codebook_bytes) as u64,
402 compression_ratio: original_bytes as f64 / (compressed_bytes + codebook_bytes) as f64,
403 }
404 }
405}
406
/// Storage breakdown reported by `QuantizedVecIndex::compression_stats`.
#[derive(Debug, Clone)]
pub struct CompressionStats {
    /// Number of documents in the index.
    pub vector_count: u64,
    /// Bytes the documents would take as raw `f32` vectors (`TOTAL_DIM` floats each).
    pub original_bytes: u64,
    /// Bytes taken by PQ codes (`NUM_SUBSPACES` bytes per document).
    pub compressed_bytes: u64,
    /// Bytes taken by all codebooks; independent of `vector_count`.
    pub codebook_bytes: u64,
    /// `compressed_bytes + codebook_bytes`.
    pub total_bytes: u64,
    /// `original_bytes / total_bytes`; 0.0 for an empty index.
    pub compression_ratio: f64,
}
416
417fn kmeans(vectors: &[Vec<f32>], k: usize, max_iterations: usize) -> Result<Vec<Vec<f32>>> {
419 if vectors.is_empty() {
420 return Err(MemvidError::InvalidQuery {
421 reason: "Cannot run k-means on empty vector set".to_string(),
422 });
423 }
424
425 let dim = vectors[0].len();
426
427 let mut centroids = kmeans_plus_plus_init(vectors, k)?;
429
430 for _iteration in 0..max_iterations {
431 let mut assignments = vec![Vec::new(); k];
433
434 for vec in vectors {
435 let mut best_cluster = 0;
436 let mut best_dist = f32::INFINITY;
437
438 for (cluster_idx, centroid) in centroids.iter().enumerate() {
439 let dist = l2_distance_squared(vec, centroid);
440 if dist < best_dist {
441 best_dist = dist;
442 best_cluster = cluster_idx;
443 }
444 }
445
446 assignments[best_cluster].push(vec.clone());
447 }
448
449 let mut changed = false;
451 for (cluster_idx, cluster_vecs) in assignments.iter().enumerate() {
452 if cluster_vecs.is_empty() {
453 centroids[cluster_idx] = vectors[cluster_idx % vectors.len()].clone();
455 changed = true;
456 continue;
457 }
458
459 let mut new_centroid = vec![0.0f32; dim];
460 for vec in cluster_vecs {
461 for (i, &val) in vec.iter().enumerate() {
462 new_centroid[i] += val;
463 }
464 }
465 for val in &mut new_centroid {
466 *val /= cluster_vecs.len() as f32;
467 }
468
469 if l2_distance_squared(¢roids[cluster_idx], &new_centroid) > 1e-6 {
471 changed = true;
472 }
473
474 centroids[cluster_idx] = new_centroid;
475 }
476
477 if !changed {
478 break; }
480 }
481
482 Ok(centroids)
483}
484
485fn kmeans_plus_plus_init(vectors: &[Vec<f32>], k: usize) -> Result<Vec<Vec<f32>>> {
487 if vectors.is_empty() || k == 0 {
488 return Err(MemvidError::InvalidQuery {
489 reason: "Invalid k-means++ initialization".to_string(),
490 });
491 }
492
493 let mut centroids = Vec::new();
494
495 centroids.push(vectors[0].clone());
497
498 for _ in 1..k {
500 let mut distances = Vec::new();
501
502 for vec in vectors {
504 let mut min_dist = f32::INFINITY;
505 for centroid in ¢roids {
506 let dist = l2_distance_squared(vec, centroid);
507 min_dist = min_dist.min(dist);
508 }
509 distances.push(min_dist);
510 }
511
512 let max_idx = distances
514 .iter()
515 .enumerate()
516 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
517 .map_or(0, |(idx, _)| idx);
518
519 centroids.push(vectors[max_idx].clone());
520 }
521
522 Ok(centroids)
523}
524
525fn l2_distance_squared(a: &[f32], b: &[f32]) -> f32 {
527 crate::simd::l2_distance_squared_simd(a, b)
528}
529
#[cfg(test)]
mod tests {
    use super::*;

    /// `TOTAL_DIM` as the `u32` that the public constructors take.
    fn total_dim_u32() -> u32 {
        u32::try_from(TOTAL_DIM).unwrap()
    }

    #[test]
    fn test_subspace_codebook() {
        let stored: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
        let mut codebook = SubspaceCodebook::new();
        codebook.set_centroid(0, &stored);

        assert_eq!(codebook.get_centroid(0), &stored);
        // All other centroids are still zero, so a point next to `stored` maps to code 0.
        assert_eq!(codebook.quantize(&[1.1, 2.1, 3.1, 4.1]), 0);
    }

    #[test]
    fn test_product_quantizer_roundtrip() {
        let training_vecs: Vec<Vec<f32>> = (0..100usize)
            .map(|i| {
                (0..TOTAL_DIM)
                    .map(|j| ((i * TOTAL_DIM + j) % 100) as f32 / 100.0)
                    .collect()
            })
            .collect();

        let mut pq = ProductQuantizer::new(total_dim_u32()).unwrap();
        pq.train(&training_vecs, 10).unwrap();

        let original = &training_vecs[0];
        let codes = pq.encode(original).unwrap();
        assert_eq!(codes.len(), NUM_SUBSPACES);

        let reconstructed = pq.decode(&codes).unwrap();
        assert_eq!(reconstructed.len(), TOTAL_DIM);

        let dist = l2_distance_squared(original, &reconstructed).sqrt();
        assert!(dist < 10.0, "Reconstruction error too large: {}", dist);
    }

    #[test]
    fn test_quantized_index_builder() {
        let training_vecs: Vec<Vec<f32>> = (0..50usize)
            .map(|i| (0..TOTAL_DIM).map(|j| ((i + j) % 10) as f32).collect())
            .collect();

        let mut builder = QuantizedVecIndexBuilder::new();
        builder
            .train_quantizer(&training_vecs, total_dim_u32())
            .unwrap();

        for (offset, vec) in training_vecs.iter().take(10).enumerate() {
            builder
                .add_document((offset + 1) as FrameId, vec.clone())
                .unwrap();
        }

        let artifact = builder.finish().unwrap();
        assert_eq!(artifact.vector_count, 10);
        assert_eq!(artifact.dimension, total_dim_u32());
        assert!(artifact.compression_ratio > 10.0);

        let index = QuantizedVecIndex::decode(&artifact.bytes).unwrap();
        let hits = index.search(&training_vecs[0], 5);

        // Frame 1 stores training_vecs[0] itself, so it is expected to rank first.
        assert!(!hits.is_empty());
        assert_eq!(hits[0].frame_id, 1);
    }

    #[test]
    fn test_kmeans_simple() {
        let vectors = vec![
            vec![0.0, 0.0],
            vec![0.1, 0.1],
            vec![10.0, 10.0],
            vec![10.1, 10.1],
        ];

        let centroids = kmeans(&vectors, 2, 100).unwrap();
        assert_eq!(centroids.len(), 2);

        let has_low = centroids.iter().any(|c| c[0] < 5.0 && c[1] < 5.0);
        let has_high = centroids.iter().any(|c| c[0] > 5.0 && c[1] > 5.0);
        assert!(has_low && has_high);
    }
}