//! oxirs_embed/embed_compression/mod.rs — embedding compression utilities:
//! 8-bit / 4-bit scalar quantization and product quantization.
/// An embedding vector stored as 8-bit quantized codes.
///
/// Components are linearly mapped from `[min, max]` onto `0..=255`.
/// Reconstruction is `value ≈ code * scale + zero_point`.
#[derive(Debug, Clone)]
pub struct QuantizedEmbedding {
    /// Dimensionality of the original f32 vector.
    pub original_dim: usize,
    /// One byte per original component (codes are not bit-packed).
    pub quantized_data: Vec<u8>,
    /// Step size between adjacent quantization levels (0.0 for a constant input).
    pub scale: f32,
    /// Value that code 0 maps back to (the minimum of the original vector).
    pub zero_point: f32,
}

impl QuantizedEmbedding {
    /// Quantizes `embedding` to 8 bits per component using affine
    /// (min/max) quantization. An empty slice yields an empty result.
    pub fn quantize(embedding: &[f32]) -> Self {
        if embedding.is_empty() {
            return Self {
                original_dim: 0,
                quantized_data: Vec::new(),
                scale: 0.0,
                zero_point: 0.0,
            };
        }

        let mut lo = f32::INFINITY;
        let mut hi = f32::NEG_INFINITY;
        for &v in embedding {
            lo = lo.min(v);
            hi = hi.max(v);
        }
        let span = hi - lo;

        // A (near-)constant vector degenerates to scale 0: every code is 0
        // and dequantization returns `zero_point` for each component.
        let degenerate = span < 1e-10;
        let scale = if degenerate { 0.0 } else { span / 255.0 };

        let codes: Vec<u8> = embedding
            .iter()
            .map(|&v| {
                if degenerate {
                    0
                } else {
                    ((v - lo) / span * 255.0).round().clamp(0.0, 255.0) as u8
                }
            })
            .collect();

        Self {
            original_dim: embedding.len(),
            quantized_data: codes,
            scale,
            zero_point: lo,
        }
    }

    /// Reconstructs an approximate f32 vector from the stored codes.
    pub fn dequantize(&self) -> Vec<f32> {
        let mut out = Vec::with_capacity(self.quantized_data.len());
        for &code in &self.quantized_data {
            out.push(f32::from(code) * self.scale + self.zero_point);
        }
        out
    }

    /// Approximate in-memory footprint: one byte per code plus the two
    /// f32 fields (`scale`, `zero_point`) and the `original_dim` field.
    pub fn approx_size_bytes(&self) -> usize {
        self.quantized_data.len()
            + 2 * std::mem::size_of::<f32>()
            + std::mem::size_of::<usize>()
    }
}
88
89#[derive(Debug, Clone)]
95pub struct EmbeddingQuantizer {
96 pub bits: u8,
98}
99
100impl EmbeddingQuantizer {
101 pub fn new(bits: u8) -> Self {
105 Self { bits }
106 }
107
108 pub fn quantize_batch(&self, embeddings: &[Vec<f32>]) -> Vec<QuantizedEmbedding> {
110 embeddings.iter().map(|e| self.quantize_single(e)).collect()
111 }
112
113 pub fn dequantize_batch(&self, quantized: &[QuantizedEmbedding]) -> Vec<Vec<f32>> {
115 quantized.iter().map(|q| q.dequantize()).collect()
116 }
117
118 pub fn compression_ratio(&self, original: &[Vec<f32>]) -> f64 {
120 if original.is_empty() {
121 return 1.0;
122 }
123 let original_bytes: usize = original.iter().map(|v| v.len() * 4).sum(); let quantized = self.quantize_batch(original);
125 let quantized_bytes: usize = quantized.iter().map(|q| q.approx_size_bytes()).sum();
126 if quantized_bytes == 0 {
127 return 1.0;
128 }
129 original_bytes as f64 / quantized_bytes as f64
130 }
131
132 fn quantize_single(&self, embedding: &[f32]) -> QuantizedEmbedding {
135 if self.bits <= 4 {
136 self.quantize_4bit(embedding)
137 } else {
138 QuantizedEmbedding::quantize(embedding)
139 }
140 }
141
142 fn quantize_4bit(&self, embedding: &[f32]) -> QuantizedEmbedding {
145 let dim = embedding.len();
146 if dim == 0 {
147 return QuantizedEmbedding {
148 original_dim: 0,
149 quantized_data: vec![],
150 scale: 0.0,
151 zero_point: 0.0,
152 };
153 }
154
155 let min_val = embedding.iter().cloned().fold(f32::INFINITY, f32::min);
156 let max_val = embedding.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
157 let range = max_val - min_val;
158
159 let (scale, zero_point) = if range < 1e-10 {
161 (0.0_f32, min_val)
162 } else {
163 (range / 15.0, min_val)
164 };
165
166 let quantized_data: Vec<u8> = embedding
168 .iter()
169 .map(|&v| {
170 if range < 1e-10 {
171 0_u8
172 } else {
173 ((v - min_val) / range * 15.0).round().clamp(0.0, 15.0) as u8
174 }
175 })
176 .collect();
177
178 QuantizedEmbedding {
179 original_dim: dim,
180 quantized_data,
181 scale,
182 zero_point,
183 }
184 }
185}
186
/// Product quantizer: splits vectors into `subspace_count` contiguous
/// sub-vectors and learns a small k-means codebook per subspace. An
/// encoded vector is one codebook index (byte) per subspace.
#[derive(Debug, Clone)]
pub struct ProductQuantizer {
    /// Number of subspaces (one code byte per subspace).
    pub subspace_count: usize,
    /// Maximum number of centroids per subspace codebook.
    pub codebook_size: usize,
    /// Learned centroids: `codebooks[s][c]` is centroid `c` of subspace `s`.
    pub codebooks: Vec<Vec<Vec<f32>>>,
    /// Dimensionality of each sub-vector (set by `train`).
    pub subvec_dim: usize,
}

impl ProductQuantizer {
    /// Creates an untrained quantizer; call [`ProductQuantizer::train`]
    /// before encoding.
    pub fn new(subspace_count: usize, codebook_size: usize) -> Self {
        Self {
            subspace_count,
            codebook_size,
            codebooks: Vec::new(),
            subvec_dim: 0,
        }
    }

    /// Learns per-subspace codebooks from `embeddings` using a few rounds
    /// of Lloyd's k-means, seeded deterministically with an LCG.
    ///
    /// No-op for an empty batch or `subspace_count == 0`. All embeddings
    /// are assumed to share the dimensionality of the first one.
    pub fn train(&mut self, embeddings: &[Vec<f32>]) {
        if embeddings.is_empty() || self.subspace_count == 0 {
            return;
        }
        let dim = embeddings[0].len();
        self.subvec_dim = (dim / self.subspace_count).max(1);

        self.codebooks = (0..self.subspace_count)
            .map(|s| {
                // Clamp `start` as well as `end`: when dim < subspace_count,
                // subvec_dim is forced to 1 and trailing subspaces would
                // otherwise panic with an out-of-bounds range (start > dim).
                let start = (s * self.subvec_dim).min(dim);
                let end = ((s + 1) * self.subvec_dim).min(dim);

                let subvecs: Vec<Vec<f32>> =
                    embeddings.iter().map(|e| e[start..end].to_vec()).collect();

                // Can't have more centroids than samples.
                let n_codes = self.codebook_size.min(subvecs.len());
                let mut centroids: Vec<Vec<f32>> = Vec::with_capacity(n_codes);

                // Deterministic seeding: draw n_codes distinct sample indices
                // via a 64-bit LCG. Terminates because n_codes <= subvecs.len().
                let mut lcg_state: u64 = (s as u64 + 1).wrapping_mul(6_364_136_223_846_793_005);
                let mut used = std::collections::HashSet::new();
                while centroids.len() < n_codes {
                    lcg_state = lcg_state
                        .wrapping_mul(6_364_136_223_846_793_005)
                        .wrapping_add(1_442_695_040_888_963_407);
                    let idx = (lcg_state >> 33) as usize % subvecs.len();
                    if used.insert(idx) {
                        centroids.push(subvecs[idx].clone());
                    }
                }

                // Fixed, small number of Lloyd iterations (assign + recenter).
                for _ in 0..5 {
                    let assignments: Vec<usize> = subvecs
                        .iter()
                        .map(|sv| nearest_centroid(sv, &centroids))
                        .collect();

                    let sub_dim = end - start;
                    let mut new_centroids = vec![vec![0.0_f32; sub_dim]; n_codes];
                    let mut counts = vec![0usize; n_codes];

                    for (sv, &c) in subvecs.iter().zip(assignments.iter()) {
                        for (i, &v) in sv.iter().enumerate() {
                            if i < new_centroids[c].len() {
                                new_centroids[c][i] += v;
                            }
                        }
                        counts[c] += 1;
                    }

                    for (c, count) in counts.iter().enumerate() {
                        // Empty clusters keep their previous centroid.
                        if *count > 0 {
                            let n = *count as f32;
                            new_centroids[c].iter_mut().for_each(|x| *x /= n);
                            centroids[c] = new_centroids[c].clone();
                        }
                    }
                }

                centroids
            })
            .collect();
    }

    /// Encodes `embedding` as one codebook index per subspace.
    /// Returns all-zero codes when untrained.
    pub fn encode(&self, embedding: &[f32]) -> Vec<u8> {
        if self.codebooks.is_empty() || self.subvec_dim == 0 {
            return vec![0; self.subspace_count];
        }
        let dim = embedding.len();
        (0..self.subspace_count)
            .map(|s| {
                // Clamp like `train` so short inputs can't slice out of bounds.
                let start = (s * self.subvec_dim).min(dim);
                let end = ((s + 1) * self.subvec_dim).min(dim);
                let code = nearest_centroid(&embedding[start..end], &self.codebooks[s]);
                code.min(255) as u8
            })
            .collect()
    }

    /// Reconstructs an approximate vector by concatenating the centroid
    /// selected in each subspace. Returns an empty vector when untrained.
    pub fn decode(&self, codes: &[u8]) -> Vec<f32> {
        if self.codebooks.is_empty() {
            return vec![];
        }
        let mut result = Vec::new();
        for (s, &code) in codes.iter().enumerate().take(self.subspace_count) {
            if s >= self.codebooks.len() {
                break;
            }
            // Skip degenerate empty codebooks (codebook_size == 0) instead
            // of panicking on the index below.
            if self.codebooks[s].is_empty() {
                continue;
            }
            // Out-of-range codes clamp to the last centroid.
            let c_idx = (code as usize).min(self.codebooks[s].len() - 1);
            result.extend_from_slice(&self.codebooks[s][c_idx]);
        }
        result
    }

    /// Approximate squared L2 distance between two encoded vectors,
    /// computed centroid-to-centroid per subspace. Returns 0.0 when
    /// untrained.
    pub fn approx_distance(&self, codes1: &[u8], codes2: &[u8]) -> f32 {
        if self.codebooks.is_empty() {
            return 0.0;
        }
        let mut total = 0.0_f32;
        for s in 0..self.subspace_count.min(codes1.len()).min(codes2.len()) {
            if s >= self.codebooks.len() {
                break;
            }
            // Skip degenerate empty codebooks rather than indexing them.
            if self.codebooks[s].is_empty() {
                continue;
            }
            let c1 = (codes1[s] as usize).min(self.codebooks[s].len() - 1);
            let c2 = (codes2[s] as usize).min(self.codebooks[s].len() - 1);
            let sq_dist: f32 = self.codebooks[s][c1]
                .iter()
                .zip(self.codebooks[s][c2].iter())
                .map(|(a, b)| (a - b) * (a - b))
                .sum();
            total += sq_dist;
        }
        total
    }

    /// True once [`ProductQuantizer::train`] has produced codebooks.
    pub fn is_trained(&self) -> bool {
        !self.codebooks.is_empty()
    }
}

/// Index of the centroid closest to `query` in squared L2 distance;
/// returns 0 when `centroids` is empty.
fn nearest_centroid(query: &[f32], centroids: &[Vec<f32>]) -> usize {
    let mut best_idx = 0;
    let mut best_dist = f32::INFINITY;
    for (i, c) in centroids.iter().enumerate() {
        let d: f32 = query
            .iter()
            .zip(c.iter())
            .map(|(a, b)| (a - b) * (a - b))
            .sum();
        if d < best_dist {
            best_dist = d;
            best_idx = i;
        }
    }
    best_idx
}
374
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random embedding of length `dim` with
    /// components roughly in [-1, 1], generated by an LCG from `seed`.
    fn sample_embedding(seed: u32, dim: usize) -> Vec<f32> {
        let mut v = Vec::with_capacity(dim);
        let mut s = seed;
        for _ in 0..dim {
            s = s.wrapping_mul(1664525).wrapping_add(1013904223);
            v.push((s as f32 / u32::MAX as f32) * 2.0 - 1.0);
        }
        v
    }

    /// Batch of `n` sample embeddings using consecutive seeds.
    fn sample_batch(n: usize, dim: usize, base_seed: u32) -> Vec<Vec<f32>> {
        (0..n)
            .map(|i| sample_embedding(base_seed + i as u32, dim))
            .collect()
    }

    // ---- QuantizedEmbedding ----

    #[test]
    fn test_quantize_dequantize_roundtrip() {
        let emb = sample_embedding(1, 16);
        let q = QuantizedEmbedding::quantize(&emb);
        let deq = q.dequantize();
        assert_eq!(deq.len(), emb.len());
        // Worst-case 8-bit reconstruction error is about one quantization step.
        let range = emb.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
            - emb.iter().cloned().fold(f32::INFINITY, f32::min);
        let max_err = range / 255.0 + 1e-5;
        for (orig, rec) in emb.iter().zip(deq.iter()) {
            assert!(
                (orig - rec).abs() <= max_err + 1e-4,
                "reconstruction error too large: {} vs {} (max_err={})",
                orig,
                rec,
                max_err
            );
        }
    }

    #[test]
    fn test_quantize_output_in_range() {
        let emb = sample_embedding(2, 8);
        let q = QuantizedEmbedding::quantize(&emb);
        assert_eq!(q.quantized_data.len(), 8);
        assert!(!q.quantized_data.is_empty());
    }

    #[test]
    fn test_quantize_empty_embedding() {
        let q = QuantizedEmbedding::quantize(&[]);
        assert_eq!(q.original_dim, 0);
        assert!(q.quantized_data.is_empty());
    }

    #[test]
    fn test_quantize_constant_embedding() {
        // Zero-range input: every component should dequantize back to the
        // constant (stored as zero_point).
        let val = 3.125_f32;
        let emb = vec![val; 8];
        let q = QuantizedEmbedding::quantize(&emb);
        let deq = q.dequantize();
        for &v in &deq {
            assert!(
                (v - val).abs() < 0.5,
                "constant embedding should dequantize close to {val}, got {v}"
            );
        }
    }

    #[test]
    fn test_approx_size_bytes() {
        let emb = sample_embedding(3, 64);
        let q = QuantizedEmbedding::quantize(&emb);
        let sz = q.approx_size_bytes();
        assert!(sz > 0);
        assert!(sz < 64 * 4, "quantized size should be smaller than f32");
    }

    // ---- EmbeddingQuantizer ----

    #[test]
    fn test_quantizer_8bit_creation() {
        let q = EmbeddingQuantizer::new(8);
        assert_eq!(q.bits, 8);
    }

    #[test]
    fn test_quantizer_4bit_creation() {
        let q = EmbeddingQuantizer::new(4);
        assert_eq!(q.bits, 4);
    }

    #[test]
    fn test_quantize_batch_count() {
        let q = EmbeddingQuantizer::new(8);
        let batch = sample_batch(10, 16, 100);
        let out = q.quantize_batch(&batch);
        assert_eq!(out.len(), 10);
    }

    #[test]
    fn test_dequantize_batch_count() {
        let q = EmbeddingQuantizer::new(8);
        let batch = sample_batch(5, 16, 200);
        let quantized = q.quantize_batch(&batch);
        let deq = q.dequantize_batch(&quantized);
        assert_eq!(deq.len(), 5);
        assert_eq!(deq[0].len(), 16);
    }

    #[test]
    fn test_compression_ratio_8bit() {
        let q = EmbeddingQuantizer::new(8);
        let batch = sample_batch(10, 64, 300);
        let ratio = q.compression_ratio(&batch);
        assert!(
            ratio > 1.0,
            "8-bit quantization should compress: ratio={ratio}"
        );
    }

    #[test]
    fn test_compression_ratio_4bit() {
        let q = EmbeddingQuantizer::new(4);
        let batch = sample_batch(10, 64, 400);
        let ratio = q.compression_ratio(&batch);
        assert!(
            ratio > 1.0,
            "4-bit quantization should compress: ratio={ratio}"
        );
    }

    #[test]
    fn test_compression_ratio_empty() {
        let q = EmbeddingQuantizer::new(8);
        let ratio = q.compression_ratio(&[]);
        assert_eq!(ratio, 1.0);
    }

    #[test]
    fn test_4bit_quantize_dequantize() {
        let q = EmbeddingQuantizer::new(4);
        let batch = sample_batch(3, 16, 500);
        let quantized = q.quantize_batch(&batch);
        let deq = q.dequantize_batch(&quantized);
        for (orig, rec) in batch.iter().zip(deq.iter()) {
            // 4-bit has 16 levels, so tolerate range/15 per component.
            let range = orig.iter().cloned().fold(f32::NEG_INFINITY, f32::max)
                - orig.iter().cloned().fold(f32::INFINITY, f32::min);
            let max_err = range / 15.0 + 1e-3;
            for (o, r) in orig.iter().zip(rec.iter()) {
                assert!(
                    (o - r).abs() <= max_err + 0.1,
                    "4-bit error too large: {o} vs {r}"
                );
            }
        }
    }

    // ---- ProductQuantizer ----

    #[test]
    fn test_pq_creation() {
        let pq = ProductQuantizer::new(4, 16);
        assert_eq!(pq.subspace_count, 4);
        assert_eq!(pq.codebook_size, 16);
        assert!(!pq.is_trained());
    }

    #[test]
    fn test_pq_train() {
        let mut pq = ProductQuantizer::new(4, 8);
        let batch = sample_batch(50, 16, 1000);
        pq.train(&batch);
        assert!(pq.is_trained());
        assert_eq!(pq.codebooks.len(), 4);
    }

    #[test]
    fn test_pq_encode_length() {
        let mut pq = ProductQuantizer::new(4, 8);
        let batch = sample_batch(30, 16, 1100);
        pq.train(&batch);
        let codes = pq.encode(&batch[0]);
        assert_eq!(codes.len(), 4);
    }

    #[test]
    fn test_pq_decode_length() {
        let mut pq = ProductQuantizer::new(4, 8);
        let batch = sample_batch(30, 16, 1200);
        pq.train(&batch);
        let codes = pq.encode(&batch[0]);
        let decoded = pq.decode(&codes);
        assert!(!decoded.is_empty());
    }

    #[test]
    fn test_pq_approx_distance_same_code() {
        let mut pq = ProductQuantizer::new(4, 8);
        let batch = sample_batch(30, 16, 1300);
        pq.train(&batch);
        let codes = pq.encode(&batch[0]);
        let dist = pq.approx_distance(&codes, &codes);
        assert!(
            dist.abs() < 1e-6,
            "distance to self should be ~0, got {dist}"
        );
    }

    #[test]
    fn test_pq_approx_distance_different_codes() {
        let mut pq = ProductQuantizer::new(4, 8);
        let batch = sample_batch(40, 16, 1400);
        pq.train(&batch);
        let c0 = pq.encode(&batch[0]);
        let c1 = pq.encode(&batch[20]);
        let dist = pq.approx_distance(&c0, &c1);
        assert!(dist >= 0.0, "distance should be non-negative");
        assert!(dist.is_finite(), "distance should be finite");
    }

    #[test]
    fn test_pq_encode_before_train_returns_zeros() {
        let pq = ProductQuantizer::new(4, 8);
        let emb = sample_embedding(1, 16);
        let codes = pq.encode(&emb);
        assert!(codes.iter().all(|&c| c == 0));
    }

    #[test]
    fn test_pq_codebook_size_capped_by_data() {
        // Requesting 256 codes with only 10 samples: codebooks are capped.
        let mut pq = ProductQuantizer::new(2, 256);
        let batch = sample_batch(10, 8, 2000);
        pq.train(&batch);
        for cb in &pq.codebooks {
            assert!(cb.len() <= 256);
        }
    }

    #[test]
    fn test_pq_reconstruction_quality() {
        let mut pq = ProductQuantizer::new(2, 8);
        let batch = sample_batch(50, 8, 3000);
        pq.train(&batch);
        let orig = &batch[0];
        let codes = pq.encode(orig);
        let decoded = pq.decode(&codes);
        assert!(!decoded.is_empty());
        assert!(decoded.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_pq_train_empty_no_panic() {
        let mut pq = ProductQuantizer::new(4, 8);
        pq.train(&[]);
        assert!(!pq.is_trained());
    }
}